什么是Triton/Pytorch/CUDA

  • 显然,接触最多的pytorch是最上层的框架,在pytorch中,最小的计算单位是 Tensor, 我们无需关心矩阵乘法在GPU内部的调度,内存管理等细节,屏蔽了底层所有的硬件细节
  • 对于cuda来说,视角是单个线程。手动管理Grid, BLock控制每一个Thread行为,例如线程读取什么内存,计算什么元素,什么时候同步等;
    1
    2
    3
    4
    int idx = blockIDx.x + blockDim.x + threadIdx.x
    if (idx < N){
    C[idx] = A[idx] + B[idx]
    }
    此外还要手动管理共享内存,寄存器使用等榨干GPU的性能能
  • 而Triton就显然是折中的方案,它使用分块变成范式,不需要管理但单个线程,而是以Tile为粒度描述计算
    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    @triton.jit
    def add_kernel(A_ptr,B_ptr,C_ptr,N,BLOCK_SIZE: tl.constexpr):
    pid = tl.program_id(axis=0) # 获得当前Program ID
    # 获得Tile内的全部偏移
    offsets = pid * BLOCK_SIZE + tl.arange(0,BLOCK_SIZE)
    mask = offsets < N

    a = tl.load(A_ptr + offsets)
    b = tl.load(B_ptr + offsets)

    c = a + b
    tl.store(C_ptr + offsets, c)
    这看起来像是标量计算,但是注意tl.arange实际上生成了一个向量化的偏移量数组,tl.load 一次性加载一整块数据,triton的编译器则会负责将这块的计算映射到底层线程上
对比维度 CUDA Triton Pytorch
编程语言 C++ Python+装饰器标注 Python
线程管理 显式线程Grid、Block、Thread 分块Program 张量运算
共享内存 手动分配与管理 编译器自动分配管理 完全透明
上手难度
平台 Nvidia Nvidia、AMD等 大多数现代GPU

基础的Kernel

这一部分是最核心基础部分,大部分的Triton Kernel都遵循以下模板:
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
import torch
import triton
import triton.language as tl
@triton..jit
def my_kernel(
input_ptr, # 输入数据指针
output_ptr, # 输出数据指针
N, #总元素数量
BLOCK_SIZE: tl.constexpr # 编译器常量,每个block处理的元素数
):
pid = tl.program_id(axis=0) # 获得当前BLock id

block_start = pid * BLOCK_SIZE
offsets = block_start + tl.arange(0,BLOCK_SIZE) # 计算当前block负责的偏移量范围

mask = offsets < N # 越界保护掩码

x = tl.load(input_ptr + offests,mask=mask) # 从全局内存加载数据
y = your_computation(x) # 执行具体计算

tl.store(output_dir + offsets,y,mask=mask) # 将结果写回全局内存

## host端封装函数
由于kernel不能直接被用户调用,需要一个Python函数进行封装
1
2
3
4
5
6
7
8
9
10
11
12
13
def my_operator(x:torch.Tensor) -> torch.Tensor:
assert x.is_cuda and x.is_contiguous()

output = torch.empty_like(x)
N = x.numel()

BLOCK_SIZE = 256

grid = (triton.cdiv(N,BLOCK_SIZE),)

my_kernel[gird](x,output,N,BLOCK_SIZE=BLOCK_SIZE)

return output

自动调优

在实际项目中,BLOCK_SIZE 等参数需要根据问题规模动态选择,因此Triton提供 @triton.autotune 装饰器

1
2
3
4
5
6
7
8
9
10
11
12
@triton.autotune(
configs = [
trinton.Config({'BLOCK_SIZE': 128},num_warps=4) # 每个Block分配的Warp数量,一个Warp中有32个线程
trinton.Config({'BLOCK_SIZE': 256},num_warps=48)
trinton.Config({'BLOCK_SIZE': 512},num_warps=16)
],
key = ['N'],
)
@triton.jit
def my_kernel_autotuned(input_ptr,output_dir,N,BLOCK_SIZE: tl.constexprt):
#kernel 的内容
...

基准测试以及正确性验证

1
2
3
4
5
6
7
8
9
10
11
12
13
def test_operator():
x = torch.randn(10000, device='cuda')

y_ref = torch.sigmoid(x)

y_triton = my_operator(x)

torch.testing.assert_close(y_triton,y_ref,rtol=1e-3,atol=1e-5)

ms_ref = triton.testing.do_bench(lambda: torch.sigmoid(x)) * 1000
ms_triton = triton.testing.do_bench(lambda: my_operator(x)) * 1000
print(f"PyTorch: {ms_ref:.4f} ms, Triton: {ms_triton:.4f} ms")

Pytorch 算子注册

对于常用的算子,我们可以将其注册到Pytorch的 torch.ops 中

1
2
3
4
5
from torch.library import custom_op
@custom_op("my_ops::sigmoid_custom",mutates_args=())
def sigmoid_custom_cuda(x: torch.Tensor) -> torch.Tensor:
return my_operator(x)
y = torch.ops.my_ops.sigmoid_custom(x)

写triton需要考虑的问题

  • 正确性: 数据类型是否支持,输入指针是否有效,数据规模是否合理,边界处理是否正确
  • 性能:Block_size 和 num_wraps如何调优,是否需要共享内存,是否需要多维grid,是否需要流水线预处理
  • 可维护性: 算子是否需要继承,是否需要支持多种GPU架构,是否具有足够的可扩展性.