文章

triton 学习

triton 学习

🧾 问题场景

  • 输入数组(即一维 Tensor) A: A = [100, 99, 98, ..., 1](长度 N = 100)
  • 我们希望 把所有元素乘 2
  • 我们设定:BLOCK_SIZE = 32(每个线程处理 32 个元素)

🧠 核心代码

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
import torch
import triton
import triton.language as tl

@triton.jit
def multiply_kernel(A_ptr, N, BLOCK_SIZE: tl.constexpr):
    # 获取当前程序实例(Program)的 ID,等价于线程编号
    # 类似于 POSIX 中的:pthread_t tid = pthread_self();
    pid = tl.program_id(0)

    # 计算当前程序负责处理的数据索引范围
    # tl.arange(0, BLOCK_SIZE) 返回一个向量 [0, 1, ..., BLOCK_SIZE-1]
    # 每个 program 处理 BLOCK_SIZE 个元素,偏移量取决于 pid
    # 例如:pid=1 → offsets = [32, 33, ..., 63]
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)

    # 创建掩码(mask),用于屏蔽越界访问
    # 只处理 offsets 中小于 N 的合法索引(N 是数组总长度)
    mask = offsets < N

    # 从全局内存中加载对应索引位置的数据
    # A_ptr 是数组 A 的指针,offsets 是该 program 负责的偏移
    # 加上 mask 是为了避免访问越界位置
    # a 变量所在的位置是program自己的私有sram
    a = tl.load(A_ptr + offsets, mask=mask)

    # 对数据进行逐元素乘以 2 的操作
    a *= 2

    # 将结果写回原始位置,仍使用 mask 避免越界写入
    tl.store(A_ptr + offsets, a, mask=mask)

🧮 执行

  1. 计算所需 grid 数量
  2. 调用 triton 框架开始运行
1
2
3
4
5
6
7
8
9
10
11
12
13
14
# 🔢 准备输入数据:逆序 [100, 99, ..., 1]
x = torch.arange(100, 0, -1, dtype=torch.float32, device='cuda')
N = x.numel()
MY_BLOCK_SIZE = 32

# 🧮 计算启动多少个 program(每个处理 MY_BLOCK_SIZE 个元素)
# Triton 启动网格必须是元组格式 (1D, ),哪怕是一维也不能省略逗号
grid = (triton.cdiv(N, MY_BLOCK_SIZE),)

# 解释:grid = ((N + MY_BLOCK_SIZE - 1) // MY_BLOCK_SIZE,)
# 对于 N = 100,BLOCK_SIZE = 32,结果是 4,最后一个 program 只处理 4 个元素

# 🧠 可选写法:使用 lambda 动态计算(推荐用于自动调优)
# grid = lambda meta: (triton.cdiv(N, meta['BLOCK_SIZE']),)
1
2
3
4
5
6
# 🚀 启动 Triton kernel
multiply_kernel[grid](
    x,              # A_ptr:PyTorch Tensor 会自动转换为设备指针传入 GPU
    N,              # N:待处理元素数量
    MY_BLOCK_SIZE
)

multiply_kernel[grid]() 这句话是 triton 自定义语法,实际就是重载了 python __getitem__ 方法

python 中的 obj[key] ⇨ 实际调用的是:obj.__getitem__(key),默认功能为下标访问

通过重载让[]的含义从下标访问变为创建一个带 grid 的 kernel 实例对象,这里的 grid 的类型可以是一个 Tuple[int],也可以是一个 lambda(这个 lambda 本质也是返回一个Tuple[int],结果是一样的)。当使用 lambda 类型时,triton 框架会将后面的方法的最后一个参数转换为 dict 格式传递给这个 lambda 的捕获参数(即示例中的meta),lambda 内部就能通过 meta['BLOCK_SIZE'] 访问到(这里用到了反射特性,将 BLOCK_SIZE 这个参数名反射成了 meta 的 key 值,变成了字符串)。

后面的 () 符号就是调用这个对象的 __call__ 方法,这个方法内部就会调用上面定义的用 @triton.jit 修饰的 multiply_kernel 方法,参数个数和定义和 multiply_kernel 方法相同

然后 Triton 会并发启动 4 个 program(线程实例)

  • pid = 0
  • pid = 1
  • pid = 2
  • pid = 3

🔁 运行时过程

这 4 个 Program 可以理解为操作系统层面的 4 个并发线程,每个处理数组的一块元素。

🎯 Program 0(pid = 0

  • offsets = 0 * 32 + [0..31] = [0, 1, ..., 31]
  • mask = [True, True, ..., True](全部在范围内)
  • a = A[0..31] = [100, 99, 98, ..., 69](数组开头的 32 个元素)
  • a *= 2[200, 198, 196, ..., 138]
  • 写回到 A[0..31]

🎯 Program 1(pid = 1

  • offsets = 1 * 32 + [0..31] = [32, 33, ..., 63]
  • mask = [True, ..., True](全部合法)
  • a = A[32..63] = [68, 67, ..., 37]
  • a *= 2
  • 写回 A[32..63]

🎯 Program 2(pid = 2

  • offsets = 2 * 32 + [0..31] = [64, ..., 95]
  • mask = [True, ..., True]
  • a = A[64..95] = [36, 35, ..., 5]
  • a *= 2
  • 写回 A[64..95]

🎯 Program 3(pid = 3

  • offsets = 3 * 32 + [0..31] = [96, ..., 127]
  • mask = [True, True, True, True, False, ..., False] 只有前 4 个元素 [96, 97, 98, 99] 是合法的
  • a = A[96..99] = [4, 3, 2, 1]
  • a *= 2
  • 写回 A[96..99],其余跳过

🖼️ 最终结果

1
2
3
4
5
6
7
8
9
10
11
print(x)

tensor([200., 198., 196., 194., 192., 190., 188., 186., 184., 182., 180., 178.,
        176., 174., 172., 170., 168., 166., 164., 162., 160., 158., 156., 154.,
        152., 150., 148., 146., 144., 142., 140., 138., 136., 134., 132., 130.,
        128., 126., 124., 122., 120., 118., 116., 114., 112., 110., 108., 106.,
        104., 102., 100.,  98.,  96.,  94.,  92.,  90.,  88.,  86.,  84.,  82.,
         80.,  78.,  76.,  74.,  72.,  70.,  68.,  66.,  64.,  62.,  60.,  58.,
         56.,  54.,  52.,  50.,  48.,  46.,  44.,  42.,  40.,  38.,  36.,  34.,
         32.,  30.,  28.,  26.,  24.,  22.,  20.,  18.,  16.,  14.,  12.,  10.,
          8.,   6.,   4.,   2.], device='cuda:0')

数组 A 被 并发分成 4 块处理,每块最多 32 个元素(最后一块不足补齐):

1
2
原始 A: [100, 99, 98, ..., 1]
处理后 A: [200, 198, 196, ..., 2]

🎯 图示结构

program 1program 2program 3program 4
pid = 0pid = 1pid = 2pid = 3
A[0..31]A[32..63]A[64..95]A[96..99]
32 个元素32 个元素32 个元素4 个元素
  • 每个 pid 是一个 独立并发的 program 实例
  • offsets 是该程序负责的 数组元素索引
  • mask 保证 不会访问越界元素

✅ 总结

Triton 行为C++/并发类比
并发程序数(N + BLOCK_SIZE - 1) / BLOCK_SIZE
每个程序处理BLOCK_SIZE 个元素
offsets当前程序负责处理的索引区间
mask越界保护(类似边界检查)
整体执行类似用线程池并发执行分片任务

环境搭建

  1. vscode 安装 jupyter 插件
  2. 配置 venv,安装 torch 和 triton。如果自己已经在别的目录创建了 venv,需要在那里安装 pip install jupyter,然后执行 python -m ipykernel install --user --name=myenv --display-name "Python (myenv)" 进行注册
  3. 新建文件 triton_multiply.ipynb,在编辑器的左上角右键->插入单元格,然后插入代码运行

alt text

alt text

alt text

本文由作者按照 CC BY-SA 4.0 进行授权