什么是Triton/Pytorch/CUDA
- 显然,接触最多的pytorch是最上层的框架,在pytorch中,最小的计算单位是 Tensor, 我们无需关心矩阵乘法在GPU内部的调度,内存管理等细节,屏蔽了底层所有的硬件细节
- 对于cuda来说,视角是单个线程。手动管理Grid, BLock控制每一个Thread行为,例如线程读取什么内存,计算什么元素,什么时候同步等;
1 2 3 4
| int idx = blockIDx.x + blockDim.x + threadIdx.x if (idx < N){ C[idx] = A[idx] + B[idx] }
|
此外还要手动管理共享内存,寄存器使用等榨干GPU的性能能
- 而Triton就显然是折中的方案,它使用分块变成范式,不需要管理但单个线程,而是以Tile为粒度描述计算
1 2 3 4 5 6 7 8 9 10 11 12
| @triton.jit def add_kernel(A_ptr,B_ptr,C_ptr,N,BLOCK_SIZE: tl.constexpr): pid = tl.program_id(axis=0) offsets = pid * BLOCK_SIZE + tl.arange(0,BLOCK_SIZE) mask = offsets < N
a = tl.load(A_ptr + offsets) b = tl.load(B_ptr + offsets)
c = a + b tl.store(C_ptr + offsets, c)
|
这看起来像是标量计算,但是注意tl.arange实际上生成了一个向量化的偏移量数组,tl.load 一次性加载一整块数据,triton的编译器则会负责将这块的计算映射到底层线程上
| 对比维度 |
CUDA |
Triton |
Pytorch |
| 编程语言 |
C++ |
Python+装饰器标注 |
Python |
| 线程管理 |
显式线程Grid、Block、Thread |
分块Program |
张量运算 |
| 共享内存 |
手动分配与管理 |
编译器自动分配管理 |
完全透明 |
| 上手难度 |
高 |
中 |
低 |
| 平台 |
Nvidia |
Nvidia、AMD等 |
大多数现代GPU |
基础的Kernel
这一部分是最核心基础部分,大部分的Triton Kernel都遵循以下模板:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21
| import torch import triton import triton.language as tl @triton..jit def my_kernel( input_ptr, output_ptr, N, BLOCK_SIZE: tl.constexpr ): pid = tl.program_id(axis=0)
block_start = pid * BLOCK_SIZE offsets = block_start + tl.arange(0,BLOCK_SIZE)
mask = offsets < N
x = tl.load(input_ptr + offests,mask=mask) y = your_computation(x)
tl.store(output_dir + offsets,y,mask=mask)
|
## host端封装函数
由于kernel不能直接被用户调用,需要一个Python函数进行封装
1 2 3 4 5 6 7 8 9 10 11 12 13
| def my_operator(x:torch.Tensor) -> torch.Tensor: assert x.is_cuda and x.is_contiguous()
output = torch.empty_like(x) N = x.numel()
BLOCK_SIZE = 256
grid = (triton.cdiv(N,BLOCK_SIZE),)
my_kernel[gird](x,output,N,BLOCK_SIZE=BLOCK_SIZE)
return output
|
自动调优
在实际项目中,BLOCK_SIZE 等参数需要根据问题规模动态选择,因此Triton提供 @triton.autotune 装饰器
1 2 3 4 5 6 7 8 9 10 11 12
| @triton.autotune( configs = [ trinton.Config({'BLOCK_SIZE': 128},num_warps=4) trinton.Config({'BLOCK_SIZE': 256},num_warps=48) trinton.Config({'BLOCK_SIZE': 512},num_warps=16) ], key = ['N'], ) @triton.jit def my_kernel_autotuned(input_ptr,output_dir,N,BLOCK_SIZE: tl.constexprt): ...
|
基准测试以及正确性验证
1 2 3 4 5 6 7 8 9 10 11 12 13
| def test_operator(): x = torch.randn(10000, device='cuda')
y_ref = torch.sigmoid(x)
y_triton = my_operator(x)
torch.testing.assert_close(y_triton,y_ref,rtol=1e-3,atol=1e-5)
ms_ref = triton.testing.do_bench(lambda: torch.sigmoid(x)) * 1000 ms_triton = triton.testing.do_bench(lambda: my_operator(x)) * 1000 print(f"PyTorch: {ms_ref:.4f} ms, Triton: {ms_triton:.4f} ms")
|
Pytorch 算子注册
对于常用的算子,我们可以将其注册到Pytorch的 torch.ops 中
1 2 3 4 5
| from torch.library import custom_op @custom_op("my_ops::sigmoid_custom",mutates_args=()) def sigmoid_custom_cuda(x: torch.Tensor) -> torch.Tensor: return my_operator(x) y = torch.ops.my_ops.sigmoid_custom(x)
|
写triton需要考虑的问题
- 正确性: 数据类型是否支持,输入指针是否有效,数据规模是否合理,边界处理是否正确
- 性能:Block_size 和 num_wraps如何调优,是否需要共享内存,是否需要多维grid,是否需要流水线预处理
- 可维护性: 算子是否需要继承,是否需要支持多种GPU架构,是否具有足够的可扩展性.