Flash Attention Notes
A brief record of the derivation and implementation of flash attention.
Naive Attention
\[ \begin{aligned} S &= QK^T \in \mathbb{R}^{N\times N} \\ P &= softmax(S) \\ O &= PV \in \mathbb{R}^{N \times d} \end{aligned} \]
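In essence, this is two matmuls with a pseudo-elementwise operation in between. The main obstacle is that every output of softmax depends on all of its inputs (an entire row of \(S\)), which prevents fusing the two matmuls with tiling. The reason is explained below.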
SoftMax
Naive softmax:
\[ \begin{aligned} softmax(x_1, \ldots, x_N)_i = \frac{e ^ {x_i}}{\sum_{j=1}^{N} e^{x_j}}, i \in [1, N] \end{aligned} \]
Because \(e^x\) grows very quickly, it easily overflows, which is why safe softmax is used instead.
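A quick way to see the overflow, as a minimal pure-Python sketch (naive_softmax is a name made up for this illustration). Python's math.exp raises OverflowError once its argument exceeds about 709.78, the float64 limit; in float32, the usual kernel precision, the limit is about 88.7, and NumPy or CUDA return inf instead of raising, which turns into NaN after the division.

```python
import math


def naive_softmax(xs):
    """softmax(x_1, ..., x_N)_i = e^{x_i} / sum_j e^{x_j}, computed directly."""
    exps = [math.exp(x) for x in xs]  # raises OverflowError for x > ~709.78
    total = sum(exps)
    return [e / total for e in exps]


small = naive_softmax([1.0, 2.0, 3.0])
print(small)  # fine for small inputs: roughly [0.090, 0.245, 0.665]

try:
    naive_softmax([1000.0, 1001.0])
    overflowed = False
except OverflowError:
    overflowed = True
print("overflowed:", overflowed)
```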
Safe softmax:
\[ \begin{aligned} max_{N} &= \max_{i \in [1, N]} x_i \\ softmax(x_1, \ldots, x_N)_i &= \frac{e ^ {x_i - max_{N}}}{\sum_{j=1}^{N} e^{x_j - max_{N}}}, i \in [1, N] \end{aligned} \]
However, the implementation needs three loops, because \(max_N\) and \(sum_N\) each require a separate loop before the outputs can be computed (starting from \(max_0 = -\infty\) and \(sum_0 = 0\)):
\[ \begin{aligned} max_i &= \max(max_{i-1}, x_i)\\ sum_i &= sum_{i-1} + e^{x_i - max_N} \\ a_i &= \frac{e^{x_i - max_N}}{sum_N}, \ i \in [1, N] \end{aligned} \]
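The three loops translate directly into code. Below is a minimal pure-Python sketch (safe_softmax_3pass is a hypothetical name). Each loop re-reads the whole input, which is exactly the extra memory traffic the 2-pass version removes.

```python
import math


def safe_softmax_3pass(xs):
    # Pass 1: max_i = max(max_{i-1}, x_i), starting from max_0 = -inf
    m = -math.inf
    for x in xs:
        m = max(m, x)
    # Pass 2: sum_i = sum_{i-1} + e^{x_i - max_N}, starting from sum_0 = 0
    s = 0.0
    for x in xs:
        s += math.exp(x - m)
    # Pass 3: a_i = e^{x_i - max_N} / sum_N
    return [math.exp(x - m) / s for x in xs]


big = safe_softmax_3pass([1000.0, 1001.0, 1002.0])  # no overflow: every exponent is <= 0
print(big)
```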
2-pass softmax
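Honestly, I could never have come up with the formula that fuses max and sum on my own; the authors were probably inspired by Welford's method. First expand the formula for \(sum_N\), and use \(e^{a+b} = e^a e^b\) to split \(-max_N\) into two parts: \[ \begin{aligned} sum_{N} &= \sum_{j = 1} ^ N e^{x_j - max_N} \\ \\ &= \sum_{j = 1} ^ {N-1} e^{x_j - max_N} + e^{x_N - max_N} \\ &= \sum_{j = 1} ^ {N-1} e^{x_j - max_{N-1} + max_{N-1} - max_N} + e^{x_N - max_N} \\ &= (\sum_{j = 1} ^ {N-1} e^{x_j - max_{N-1}}) e^{max_{N-1} - max_N} +e^{x_N - max_N} \\ \end{aligned} \]
Looking at this formula, the partial sums become a recurrence if we define them from a different angle, normalizing the sum of the first \(i\) terms by the running max \(max_i\) instead of the global \(max_N\): \[ \begin{aligned} \text{let}\ sum_{i}^\prime &=\sum_{j=1}^{i} e^{x_j - max_i}, \quad \text{so}\ sum_{N}^\prime = sum_{N} \\ sum_{N}^\prime &= (\sum_{j = 1} ^ {N-1} e^{x_j - max_{N-1}}) e^{max_{N-1} - max_N} +e^{x_N - max_N} \\ &= sum_{N-1}^\prime e^{max_{N-1} - max_N} +e^{x_N - max_N} \\ sum_i ^\prime &= sum_{i-1} ^\prime e^{max_{i-1} - max_i} +e^{x_i - max_i} \end{aligned} \]
This change of viewpoint decouples \(max_N\) from \(sum_{i}\), and when the iteration reaches the end, \(sum_{N}^\prime = sum_{N}\). The 2-pass version adds an extra multiplication by \(e^{max_{i-1} - max_i}\) in every iteration, but that is clearly far cheaper than the memory traffic of an extra pass over the input.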
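A minimal pure-Python sketch of the recurrence (online_softmax_2pass and direct_softmax are hypothetical names): the first loop maintains \(max_i\) and \(sum_i^\prime\) together, and the second loop only normalizes. The result should match a direct computation up to rounding.

```python
import math


def online_softmax_2pass(xs):
    # Pass 1: update max_i and sum'_i in the same loop
    m, s = -math.inf, 0.0  # max_0 = -inf, sum'_0 = 0
    for x in xs:
        m_new = max(m, x)
        # sum'_i = sum'_{i-1} * e^{max_{i-1} - max_i} + e^{x_i - max_i}
        s = s * math.exp(m - m_new) + math.exp(x - m_new)
        m = m_new
    # Now m == max_N and s == sum'_N == sum_N
    # Pass 2: normalize
    return [math.exp(x - m) / s for x in xs]


def direct_softmax(xs):
    m = max(xs)
    total = sum(math.exp(x - m) for x in xs)
    return [math.exp(x - m) / total for x in xs]


xs = [3.0, -1.0, 5.0, 2.0, 5.5, 0.25]
online = online_softmax_2pass(xs)
print(max(abs(a - b) for a, b in zip(online, direct_softmax(xs))))  # rounding-level difference only
```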
Flash Attention
2-pass Attention
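First, implement attention with the 2-pass softmax. To avoid confusing the query length with the sequence (key) length, the query row is indexed by \(k\) and the key position by \(i\).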
\[ \begin{aligned} \text{for i in [1, N]}:&\\ x_i &= Q[k, :]K^T[:, i]\\ max_i &= \max(max_{i-1}, x_i) \\ sum_i^\prime &= sum_{i-1} ^\prime e^{max_{i-1} - max_i} +e^{x_i - max_i}\\ \text{end} \qquad \qquad \\ \text{for i in [1, N]}:&\\ a_i &= \frac{e^{x_i - max_N}}{sum_N^\prime} \\ o_i &= o_{i-1} + a_i V[i,:] \\ \text{end} \qquad \qquad \\ O[k,:] & = o_N \end{aligned} \]
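The two loops above, sketched in pure Python for a single query row (lists stand in for \(Q[k,:]\) and the rows of \(K\) and \(V\); attention_row_2pass is a hypothetical name). The second loop cannot start until the first one has finished, because every \(a_i\) needs the final \(max_N\) and \(sum_N^\prime\). The formula reuses \(x_i\) in the second loop; this sketch simply stores them.

```python
import math


def dot(u, v):
    return sum(a * b for a, b in zip(u, v))


def attention_row_2pass(q, K, V):
    """O[k,:] for one query row q; K[i] and V[i] are the i-th key and value rows."""
    # Loop 1: scores x_i plus the online max_i / sum'_i
    xs, m, s = [], -math.inf, 0.0
    for k_row in K:
        x = dot(q, k_row)  # x_i = Q[k,:] K^T[:, i]
        xs.append(x)
        m_new = max(m, x)
        s = s * math.exp(m - m_new) + math.exp(x - m_new)
        m = m_new
    # Loop 2: a_i needs the final max_N and sum'_N
    o = [0.0] * len(V[0])
    for x, v_row in zip(xs, V):
        a = math.exp(x - m) / s
        o = [oj + a * vj for oj, vj in zip(o, v_row)]  # o_i = o_{i-1} + a_i V[i,:]
    return o


# Identical keys give uniform weights, so the output is the mean of the value rows.
print(attention_row_2pass([1.0, 2.0], [[1.0, 0.0]] * 3, [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]))  # approximately [3.0, 4.0]
```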
1-pass Attention
When multiplying by V, every \(o_i\) still depends on \(max_N\). The next step is to remove this dependency. Following the 2-pass softmax trick, define the output over the first \(i\) keys, normalized with the running statistics, \(o_i^\prime = \sum_{j=1}^{i} \frac{e^{x_j - max_i}}{sum_i^\prime} V[j,:]\), so that \(o_N^\prime\) is the final output: \[
\begin{aligned}
o_N^\prime &= \sum_{i = 1} ^ {N} a_i V[i,:] \\
&= \sum_{i = 1} ^ {N} \frac{e^{x_i - max_N}}{sum_N^\prime} V[i,:] \\
&= (\sum_{i = 1} ^ {N-1} \frac{e^{x_i - max_N}}{sum_N^\prime} V[i,:]) + \frac{e^{x_N - max_N}}{sum_N^\prime} V[N,:] \\
&= (\sum_{i = 1} ^ {N-1} \frac{e^{x_i - max_N}}{sum_N^\prime} \frac{sum_{N-1}^\prime}{sum_{N-1}^\prime} \frac{e^{x_i - max_{N-1}}}{e^{x_i - max_{N-1}}} V[i,:]) + \frac{e^{x_N - max_N}}{sum_N^\prime} V[N,:] \\
&= (\sum_{i = 1} ^ {N-1} \frac{e^{x_i - max_{N-1}}}{sum_{N-1}^\prime} \frac{sum_{N-1}^\prime}{sum_{N}^\prime}\frac{e^{x_i - max_N}}{e^{x_i - max_{N-1}}} V[i,:]) + \frac{e^{x_N - max_N}}{sum_N^\prime} V[N,:] \\
&= (\sum_{i = 1} ^ {N-1} \frac{e^{x_i - max_{N-1}}}{sum_{N-1}^\prime} V[i,:]) \frac{sum_{N-1}^\prime}{sum_{N}^\prime}e^{max_{N-1} - max_N} + \frac{e^{x_N - max_N}}{sum_N^\prime} V[N,:] \\
&= o_{N-1}^\prime \frac{sum_{N-1}^\prime}{sum_{N}^\prime}e^{max_{N-1} - max_N} + \frac{e^{x_N - max_N}}{sum_N^\prime} V[N,:] \\
\end{aligned}
\] The ratio \(\frac{e^{x_i - max_N}}{e^{x_i - max_{N-1}}} = e^{max_{N-1} - max_N}\) does not depend on \(i\), which is what allows it to be pulled out of the sum. Generalizing by induction gives a formula for \(o_i^\prime\) that no longer contains \(max_N\): \[
\begin{aligned}
o_i^\prime &= o_{i-1}^\prime \frac{sum_{i-1}^\prime}{sum_{i}^\prime}e^{max_{i-1} - max_i} + \frac{e^{x_i - max_i}}{sum_i^\prime} V[i,:]
\end{aligned}
\]
Finally, the scalar 1-pass attention, starting from \(max_0 = -\infty\), \(sum_0^\prime = 0\) and \(o_0^\prime = 0\): \[ \begin{aligned} \text{for i in [1, N]}:&\\ x_i &= Q[k, :]K^T[:, i]\\ max_i &= \max(max_{i-1}, x_i) \\ sum_i^\prime &= sum_{i-1} ^\prime e^{max_{i-1} - max_i} +e^{x_i - max_i}\\ o_i^\prime &= o_{i-1}^\prime \frac{sum_{i-1}^\prime}{sum_{i}^\prime}e^{max_{i-1} - max_i} + \frac{e^{x_i - max_i}}{sum_i^\prime} V[i,:] \\ \text{end} \qquad \qquad\\ O[k,:] & = o_N^\prime \end{aligned} \]
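A minimal pure-Python sketch of the 1-pass recurrence for one query row (helper names are hypothetical). The running \(o^\prime\) is always the normalized attention output over the keys seen so far, so no second loop is needed; reference_attention_row, a plain softmax followed by a weighted sum, is included only to check the result.

```python
import math


def dot(u, v):
    return sum(a * b for a, b in zip(u, v))


def attention_row_1pass(q, K, V):
    m, s = -math.inf, 0.0  # max_0 = -inf, sum'_0 = 0
    o = [0.0] * len(V[0])  # o'_0 = 0
    for k_row, v_row in zip(K, V):
        x = dot(q, k_row)
        m_new = max(m, x)
        s_new = s * math.exp(m - m_new) + math.exp(x - m_new)
        # o'_i = o'_{i-1} * (sum'_{i-1} / sum'_i) * e^{max_{i-1} - max_i}
        #        + e^{x_i - max_i} / sum'_i * V[i,:]
        keep = s / s_new * math.exp(m - m_new)
        w = math.exp(x - m_new) / s_new
        o = [oj * keep + w * vj for oj, vj in zip(o, v_row)]
        m, s = m_new, s_new
    return o  # o'_N = O[k,:]


def reference_attention_row(q, K, V):
    xs = [dot(q, k_row) for k_row in K]
    m = max(xs)
    ws = [math.exp(x - m) for x in xs]
    total = sum(ws)
    return [sum(w * v_row[t] for w, v_row in zip(ws, V)) / total for t in range(len(V[0]))]


q = [0.5, -1.0, 2.0]
K = [[1.0, 0.0, 0.5], [-0.5, 1.0, 1.5], [2.0, 1.0, -1.0], [0.0, 0.0, 3.0]]
V = [[1.0, 2.0], [0.0, -1.0], [3.0, 1.0], [-2.0, 0.5]]
print(attention_row_1pass(q, K, V))
print(reference_attention_row(q, K, V))
```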
Flash Attention v1
The 1-pass attention derived above is a scalar loop, while flash attention computes tile by tile, so the formulas need a small modification.
First, write out the ordinary softmax:
\[ \begin{aligned} X & = [x_1, \ldots , x_N] \\ max_N & = \max(X) \\ &= \max([x_1, \ldots , x_N]) \\ f(X) & = [f(x_1), \ldots , f(x_N)] \\ &= [e^{x_1 - max_N}, \ldots, e^{x_N - max_N}] \\ sum_N & = \sum_{i = 1}^N f(x_i) \\ & = \sum_{i = 1}^N e^{x_i - max_N} \\ softmax(X) & = \frac{ f(X)}{sum_N} \end{aligned} \]
Now derive the tiled softmax. Suppose \(X\) consists of two sub-vectors of length \(N\). First treat it as a single vector, then split the computation into a divide-and-conquer form: \[ \begin{aligned} X & = [x^1, x^2] \\ max_{2N} & = \max([\max(x^1),\max(x^2)]) \\ & = \max([max_N^1, max_N^2]) \\ f(X) &= \left[ [e^{x_1^1 - max_{2N}},\ldots, e^{x_N^1 - max_{2N}}] , [e^{x_1^2 - max_{2N}},\ldots, e^{x_N^2 - max_{2N}}] \right] \\ &= \left[ e^{max_N^1 - max_{2N}} [e^{x_1^1 - max_N^1},\ldots, e^{x_N^1 - max_N^1}] , e^{max_N^2 - max_{2N}} [e^{x_1^2 - max_N^2},\ldots, e^{x_N^2 - max_N^2}] \right] \\ &= \left[ e^{max_N^1 - max_{2N}} f(x^1) , e^{max_N^2 - max_{2N}} f(x^2) \right] \\ sum_{2N} & = \sum_{i = 1}^N e^{x_i^1 - max_{2N}} + \sum_{i = 1}^N e^{x_i^2 - max_{2N}} \\ & = e^{max_N^1 - max_{2N}} \sum_{i = 1}^N e^{x_i^1 - max_{N}^1} + e^{max_N^2 - max_{2N}} \sum_{i = 1}^N e^{x_i^2 - max_{N}^2} \\ & = e^{max_N^1 - max_{2N}} sum_N^1 + e^{max_N^2 - max_{2N}} sum_N^2 \\ softmax(X) &= \frac{f(X)}{sum_{2N}} \end{aligned} \]
So besides the per-tile \(max^j\) used to compute \(f(x^j)\) and \(sum^j\), we also need to maintain the overall \(max\) and \(sum\) to compute the final result.
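The split-and-merge rule can be checked with a small pure-Python sketch (tile_stats, merge_stats and tiled_softmax are hypothetical names). Each tile computes \(f(x^j)\) and \(sum^j\) against its own local max; merging only needs the two \((max, sum)\) pairs, never the other tile's raw values.

```python
import math


def tile_stats(x):
    """Local statistics of one tile: (max_N^j, sum_N^j = sum_i e^{x_i^j - max_N^j})."""
    m = max(x)
    return m, sum(math.exp(v - m) for v in x)


def merge_stats(m1, s1, m2, s2):
    """max_2N = max(m1, m2); sum_2N = e^{m1 - max_2N} s1 + e^{m2 - max_2N} s2."""
    m = max(m1, m2)
    return m, math.exp(m1 - m) * s1 + math.exp(m2 - m) * s2


def tiled_softmax(x1, x2):
    m1, s1 = tile_stats(x1)
    m2, s2 = tile_stats(x2)
    m, s = merge_stats(m1, s1, m2, s2)
    # f(X) = [e^{m1 - m} f(x^1), e^{m2 - m} f(x^2)], each f(x^j) uses its local max
    f1 = [math.exp(m1 - m) * math.exp(v - m1) for v in x1]
    f2 = [math.exp(m2 - m) * math.exp(v - m2) for v in x2]
    return [v / s for v in f1 + f2]


def softmax(x):
    m = max(x)
    total = sum(math.exp(v - m) for v in x)
    return [math.exp(v - m) / total for v in x]


x1, x2 = [1.0, 4.0, -2.0], [3.0, 0.5, 6.0]
print(tiled_softmax(x1, x2))
print(softmax(x1 + x2))
```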
Flash attention's tiling vectorizes \(x_i\): each step processes a block of \(b\) keys. Starting from the 1-pass attention formula and combining it with the tiled softmax, only the computation of \(max_i\) and \(sum_i\) needs a small change to obtain the flash attention formula: \[ \begin{aligned} \text{for i in [1, N/b]}:&\\ x_i &= Q[k, :]K^T[:, (i-1)b+1 : ib] \in \mathbb{R}^{b}\\ max_i^{local} &= \max(x_i) \\ max_i &= \max(max_{i-1}, max_i^{local}) \\ sum_i^\prime &= sum_{i-1} ^\prime e^{max_{i-1} - max_i} + \sum_{j=1}^{b} e^{x_i[j] - max_i}\\ o_i^\prime &= o_{i-1}^\prime \frac{sum_{i-1}^\prime}{sum_{i}^\prime}e^{max_{i-1} - max_i} + \sum_{j=1}^b \frac{e^{x_i[j] - max_i}}{sum_i^\prime} V[(i-1)b+j,:] \\ \text{end}\qquad \qquad\\ O[k,:] & = o_{N/b}^\prime \end{aligned} \]
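The tiled loop for one query row, sketched in pure Python (attention_row_tiled and the block size b are illustrative; len(K) is assumed to be a multiple of b). With b = 1 it reduces to the scalar 1-pass recurrence, and with b = len(K) there is a single tile, i.e. ordinary safe-softmax attention, so comparing block sizes is a quick consistency check.

```python
import math


def attention_row_tiled(q, K, V, b):
    """Tiled 1-pass attention for one query row, b keys per step."""
    assert len(K) % b == 0, "this sketch assumes len(K) is a multiple of b"
    d = len(V[0])
    m, s, o = -math.inf, 0.0, [0.0] * d  # max_0, sum'_0, o'_0
    for start in range(0, len(K), b):
        # x_i = Q[k,:] K^T[:, (i-1)b+1 : ib]  (0-based here: keys start .. start+b-1)
        x = [sum(qa * ka for qa, ka in zip(q, k_row)) for k_row in K[start:start + b]]
        m_new = max(m, max(x))  # max_i = max(max_{i-1}, max_i^local)
        p = [math.exp(v - m_new) for v in x]  # e^{x_i[j] - max_i}
        s_new = s * math.exp(m - m_new) + sum(p)
        keep = s / s_new * math.exp(m - m_new)
        # sum_j e^{x_i[j] - max_i} V[(i-1)b+j, :], a tiny P @ V for this block
        pv = [sum(pj * v_row[t] for pj, v_row in zip(p, V[start:start + b])) for t in range(d)]
        o = [oj * keep + pvt / s_new for oj, pvt in zip(o, pv)]
        m, s = m_new, s_new
    return o  # o'_{N/b} = O[k,:]


q = [0.3, -0.7]
K = [[1.0, 2.0], [0.5, -1.0], [-2.0, 0.0], [3.0, 1.0], [0.0, 0.5], [1.5, -2.5]]
V = [[1.0, 0.0], [0.0, 1.0], [2.0, 2.0], [-1.0, 3.0], [0.5, 0.5], [4.0, -1.0]]
for b in (1, 2, 3, 6):
    print(b, attention_row_tiled(q, K, V, b))
```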
Here is a simple flash attention implementation for reference (NumPy, checked against PyTorch with pytest):
import pytest
import torch
import torch.nn.functional as F
import math
import numpy as np
np.set_printoptions(suppress=True)


def flash_attn(query: np.ndarray, key: np.ndarray, value: np.ndarray, attn_mask=None, dropout_p=0.0,
               is_causal=False, scale=None, enable_gqa=False):
    L, S = query.shape[-2], key.shape[-2]
    scale_factor = 1 / math.sqrt(query.shape[-1]) if scale is None else scale
    # additive bias: 0 where attention is allowed, -inf where it is masked
    attn_bias = np.zeros((L, S), dtype=query.dtype)
    if is_causal:
        assert attn_mask is None
        temp_mask = np.tril(np.ones((L, S), dtype=np.bool_), k=0)
        attn_bias[~temp_mask] = -np.inf
    if attn_mask is not None:
        if attn_mask.dtype == np.bool_:
            attn_bias[~attn_mask] = -np.inf
        else:
            attn_bias = attn_mask + attn_bias
    assert enable_gqa is False, "GQA not implemented"
    assert dropout_p == 0.0, "dropout not implemented"
    out = np.zeros(query.shape, dtype=np.float32)  # [head, query_len, dim]
    for head in range(query.shape[0]):
        Q = query[head]  # [query_len, dim]
        K = key[head]  # [seq_len, dim]
        V = value[head]  # [seq_len, dim]
        O = np.zeros_like(Q, dtype=np.float32)  # [query_len, dim]
        Tc = 4  # key/value tile size (columns of S)
        Tr = 16  # query tile size (rows of S)
        assert L % Tr == 0
        assert S % Tc == 0
        # running max / sum per query row, stored as [Tr, number of query tiles]
        global_maxs = np.zeros([Tr, L // Tr], dtype=np.float32)
        global_sums = np.zeros([Tr, L // Tr], dtype=np.float32)
        for j in range(0, S, Tc):
            # outer loop is seq_len, because seq_len is `K` dimension, we can reuse Kj,Vj `query_len/Tr` times
            Kj = K[j:j + Tc, :]  # [Tc, dim]
            Vj = V[j:j + Tc, :]  # [Tc, dim]
            for (ii, i) in enumerate(range(0, L, Tr)):
                # load
                Qi = Q[i:i + Tr, :]  # [Tr, dim]
                O_last = O[i:i + Tr, :]  # [Tr, dim]
                max_last = (np.zeros([Tr, 1], dtype=np.float32) - np.inf) if j == 0 else global_maxs[:, ii:ii + 1]
                sum_last = np.zeros([Tr, 1], dtype=np.float32) if j == 0 else global_sums[:, ii:ii + 1]
                a_ij = (Qi @ Kj.T) * scale_factor  # [Tr, Tc]
                a_ij += attn_bias[i:i + Tr, j:j + Tc]
                max_local = np.max(a_ij, axis=1, keepdims=True)  # [Tr, 1]
                max_i = np.maximum(max_last, max_local)  # [Tr, 1]
                p_ij = np.exp(a_ij - max_i)  # [Tr, Tc]
                sum_i = sum_last * np.exp(max_last - max_i) + np.sum(p_ij, axis=1, keepdims=True)  # [Tr, 1]
                O_i = O_last * (sum_last * np.exp(max_last - max_i)) / sum_i + (p_ij / sum_i) @ Vj  # [Tr, dim]
                # store
                O[i:i + Tr, :] = O_i
                global_maxs[:, ii:ii + 1] = max_i
                global_sums[:, ii:ii + 1] = sum_i
        out[head] = O
    return out


@pytest.mark.parametrize("head_q, head_kv", [(1, 1)])
@pytest.mark.parametrize("query_len, seq_len", [(64, 64)])
@pytest.mark.parametrize("dim", [128])
@pytest.mark.parametrize("is_causal", [False, True])
@pytest.mark.parametrize("scale", [1.0])
def test_flash_attention(head_q, head_kv, query_len, seq_len, dim, is_causal, scale):
    query = np.random.rand(head_q, query_len, dim).astype(np.float32)  # [head_q, query_len, dim]
    key = np.random.rand(head_kv, seq_len, dim).astype(np.float32)  # [head_kv, seq_len, dim]
    value = np.random.rand(head_kv, seq_len, dim).astype(np.float32)  # [head_kv, seq_len, dim]
    o = F.scaled_dot_product_attention(
        torch.tensor(query), torch.tensor(key), torch.tensor(value), is_causal=is_causal, scale=scale)
    o_np = o.numpy()  # [head_q, query_len, dim]
    o_actual = flash_attn(query, key, value, is_causal=is_causal, scale=scale)
    assert o_actual.shape == o_np.shape
    assert np.allclose(o_np, o_actual, atol=1e-7)


if __name__ == "__main__":
    pytest.main([__file__, "-vvs"])
Implementing FA3 on Hopper with TileLang
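The speedup from FA2 to FA3 comes mainly from scheduling, not from the math. The core idea of the FA3 paper is to overlap the GEMMs (on the tensor cores) with the exp in softmax (which runs on a separate multi-function unit): ideally, exp should be computed while the tensor cores are doing the matmul. The paper describes two levels of overlap:
Ping-pong (between warpgroups). Use two consumer warpgroups and bar.sync named barriers to force warpgroup 1's GEMMs (this iteration's PV and the next iteration's QKᵀ) to be scheduled before warpgroup 2's GEMMs; warpgroup 1's softmax is then scheduled exactly while warpgroup 2 runs its GEMMs, and vice versa. The two WGs take turns on the tensor cores like a ping-pong rally: while one computes softmax, the other runs its matmul.
Intra-warpgroup overlap (inside one warpgroup). Within a single warpgroup, the GEMMs and softmax are also pipelined: the second WGMMA of iteration j (PV) overlaps with the softmax of iteration j+1, at the cost of an extra set of registers to hold the next iteration's scores.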
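The ping-pong handoff can be illustrated with a CPU analogy: two Python threads stand in for the two consumer warpgroups, and two threading.Semaphore objects stand in for named barriers 1 and 2. This is only a sketch of the ordering protocol used by the kernel below (WG0 waits on bar1 and arrives on bar2, WG1 does the reverse, and WG1 primes bar1 once so WG0 goes first). A hardware named barrier completes when a given number of threads has arrived rather than counting tokens, and nothing here models the tensor cores.

```python
import threading

ROUNDS = 4
bar1 = threading.Semaphore(0)  # stands in for named barrier 1 (WG0 waits on it)
bar2 = threading.Semaphore(0)  # stands in for named barrier 2 (WG1 waits on it)
gemm_issue_log = []  # order in which the "warpgroups" issue their GEMMs


def consumer(name, my_bar, nxt_bar, prime=None):
    if prime is not None:
        prime.release()  # WG1 primes bar1 once, so WG0 issues first
    for k in range(ROUNDS):
        my_bar.acquire()  # ~ sync on my_bar: wait until the peer has issued its GEMMs
        gemm_issue_log.append((name, k))  # ~ issue this round's QK / PV GEMMs
        nxt_bar.release()  # ~ arrive on nxt_bar: let the peer issue its GEMMs
        # softmax for round k would run here, overlapping the peer's GEMMs


threads = [
    threading.Thread(target=consumer, args=("WG0", bar1, bar2)),
    threading.Thread(target=consumer, args=("WG1", bar2, bar1, bar1)),
]
for t in threads:
    t.start()
for t in threads:
    t.join()
print(gemm_issue_log)  # alternates WG0, WG1, WG0, WG1, ...
```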
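The TileLang implementation below puts both overlaps into practice: 384 threads = 1 TMA producer warpgroup + 2 consumer warpgroups, with ping-pong between the two consumers and the pipelining described above inside each consumer.
Full implementation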
k_fa3pp.py
"""Faithful FlashInfer FA3 port WITH 2-warpgroup ping-pong (CUTLASS WarpScheduler).
- 384 threads = 1 producer WG (TMA) + 2 consumer WG
- each consumer WG does its OWN m=64 query rows; per-WG Qs[g] (A-operand
  offset must be 0, so split buffers instead of slicing rows)
- register-P (rs-wgmma): P held in a fp16 fragment (pcast), fed straight to PV.
  The cast stays AFTER wait_wgmma(0) so PV(k-1) finishes reading pcast before
  it is overwritten (fusing exp->pcast earlier serializes PV -> 114us).
- setmaxnreg: producer dealloc to 24 regs, consumers alloc 240 -> rs-P fits
  (without this, rs-P spilled/serialized and lost to smem-P).
- WarpScheduler ping-pong via named barriers: WG0 sync(1)/arrive(2),
  WG1 sync(2)/arrive(1); WG1 primes bar1 once -> WG0 softmax overlaps WG1 mma.
- delayed PV + wait_wgmma(1): softmax(k) overlaps in-flight PV(k-1).
"""
import tilelang
import tilelang.language as T
from tilelang.layout import make_swizzled_layout
import os as _os

BLOCK_M = 128
BLOCK_N = int(_os.environ.get("BN", "128"))
NSK = int(_os.environ.get("NSK", "2"))  # K smem pipeline stages
NSV = int(_os.environ.get("NSV", "2"))  # V smem pipeline stages
THREADS = 384
NMMA = 256  # threads in the two consumer WGs
_FM = _os.environ.get("FM", "1") == "1"
_pc = {tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: _FM,
       tilelang.PassConfigKey.TL_DISABLE_THREAD_STORAGE_SYNC: True}
_cf = [
    "-O3",
] + (["--use_fast_math"] if _FM else []) + [
    "-Wno-deprecated-declarations",
    "-U__CUDA_NO_HALF_OPERATORS__",
    "-U__CUDA_NO_HALF_CONVERSIONS__",
    "-U__CUDA_NO_HALF2_OPERATORS__",
    "-U__CUDA_NO_BFLOAT16_CONVERSIONS__",
    "--expt-relaxed-constexpr",
    "--expt-extended-lambda",
    "-DNDEBUG",
]


@tilelang.jit(out_idx=[3], pass_configs=_pc, compile_flags=_cf)
def fa3_prefill(B, H, Hkv, Sq, Skv, D, dtype,
                block_M=BLOCK_M, block_N=BLOCK_N, nsK=NSK, nsV=NSV, threads=THREADS):
    scale = (1.0 / D) ** 0.5 * 1.44269504  # softmax scale * log2(e), so exp() can be done with exp2()
    groups = H // Hkv  # query heads per KV head (GQA)
    co = Skv - Sq  # causal offset: query row r may attend to keys <= r + co
    accum = "float"
    half = block_M // 2  # query rows per consumer WG
    Pol = T.GemmWarpPolicy.FullRow

    @T.prim_func
    def main(
            Q: T.Tensor([B, Sq, H, D], dtype),
            K: T.Tensor([B, Skv, Hkv, D], dtype),
            V: T.Tensor([B, Skv, Hkv, D], dtype),
            O: T.Tensor([B, Sq, H, D], dtype),
    ):
        with T.Kernel(T.ceildiv(Sq, block_M), H, B, threads=threads) as (bx, by, bz):
            Qs = T.alloc_shared([2, half, D], dtype)
            Ks = T.alloc_shared([nsK, block_N, D], dtype)
            Vs = T.alloc_shared([nsV, block_N, D], dtype)
            Os = T.alloc_shared([2, half, D], dtype)  # per-WG smem-staged output (FlashInfer epilogue)
            T.annotate_layout({Qs: make_swizzled_layout(Qs), Ks: make_swizzled_layout(Ks),
                               Vs: make_swizzled_layout(Vs)})
            q_bar = T.alloc_barrier([32])  # 1-warp producer (FlashInfer NUM_PRODUCER_THREADS=32)
            kready = T.alloc_barrier([32] * nsK)
            kfree = T.alloc_barrier([NMMA] * nsK)
            vready = T.alloc_barrier([32] * nsV)
            vfree = T.alloc_barrier([NMMA] * nsV)
            cv = by // groups
            q0 = bx * block_M
            eff = T.min(T.ceildiv(Skv, block_N),
                        T.ceildiv(q0 + block_M + co, block_N))
            tx = T.get_thread_binding()
            if tx >= 256:  # ================= producer =================
                T.set_max_nreg(24, 0)  # producer is TMA-only: release regs to consumers
                if tx >= 256 and tx < 288:  # only 1 warp issues TMA + waits (rest of WG idle)
                    T.tma_copy(Q[bz, q0:q0 + half, by, :], Qs[0, :, :], barrier=q_bar)
                    T.tma_copy(Q[bz, q0 + half:q0 + block_M, by, :], Qs[1, :, :], barrier=q_bar)
                    T.mbarrier_arrive(q_bar)
                    for k in T.serial(eff):
                        sk = k % nsK
                        T.mbarrier_wait_parity(kfree[sk], ((k // nsK) % 2) ^ 1)
                        T.tma_copy(K[bz, k * block_N:(k + 1) * block_N, cv, :], Ks[sk, :, :], barrier=kready[sk])
                        T.mbarrier_arrive(kready[sk])
                        sv = k % nsV
                        T.mbarrier_wait_parity(vfree[sv], ((k // nsV) % 2) ^ 1)
                        T.tma_copy(V[bz, k * block_N:(k + 1) * block_N, cv, :], Vs[sv, :, :], barrier=vready[sv])
                        T.mbarrier_arrive(vready[sv])
            with T.ws(0):
                T.set_max_nreg(240, 1)  # consumer grabs producer's released regs
                r0 = 0 * half
                my_bar = 1
                nxt_bar = 2
                acc_s = T.alloc_fragment([half, block_N], accum)
                pcast = T.alloc_fragment([half, block_N], dtype)  # register-P (rs-wgmma)
                acc_o = T.alloc_fragment([half, D], accum)
                sm = T.alloc_fragment([half], accum)
                smp = T.alloc_fragment([half], accum)
                alpha = T.alloc_fragment([half], accum)
                ss = T.alloc_fragment([half], accum)
                logsum = T.alloc_fragment([half], accum)
                T.fill(acc_o, 0)
                T.fill(logsum, 0)
                T.fill(alpha, 1.0)  # pipelined rescale: carried alpha, 1st rescale no-op
                T.fill(sm, -T.infinity(accum))
                T.mbarrier_wait_parity(q_bar, 0)
                pass  # WG0 goes first
                # prologue: tile 0, QK + softmax (no PV)
                T.sync_threads(my_bar, NMMA)
                T.mbarrier_wait_parity(kready[0], 0)
                T.wgmma_gemm(Qs[0, :, :], Ks[0, :, :], acc_s, transpose_B=True, policy=Pol, clear_accum=True)
                T.named_barrier_arrive(nxt_bar, NMMA)
                T.wait_wgmma(0)
                T.mbarrier_arrive(kfree[0])
                # causal mask for tile 0: needed when the diagonal falls inside it (e.g. bx == 0 with co == 0)
                for i, j in T.Parallel(half, block_N):
                    acc_s[i, j] = T.if_then_else(
                        q0 + r0 + i + co >= j, acc_s[i, j], -T.infinity(accum))
                T.reduce_max(acc_s, sm, dim=1, clear=False)
                for i, j in T.Parallel(half, block_N):
                    acc_s[i, j] = T.exp2(acc_s[i, j] * scale - sm[i] * scale)
                T.reduce_sum(acc_s, ss, dim=1)
                for i in T.Parallel(half):
                    logsum[i] = ss[i]
                T.copy(acc_s, pcast)
                # tiles in [1, nu) are fully visible to every row of this WG: no causal mask needed
                nu = T.max(1, T.min(eff, T.floordiv(q0 + r0 + co + 1, block_N)))
                for k in T.serial(1, nu):
                    sk = k % nsK
                    svp = (k - 1) % nsV
                    T.sync_threads(my_bar, NMMA)
                    T.mbarrier_wait_parity(kready[sk], (k // nsK) % 2)
                    T.wgmma_gemm(Qs[0, :, :], Ks[sk, :, :], acc_s, transpose_B=True, policy=Pol, clear_accum=True)
                    for i, j in T.Parallel(half, D):  # pipelined rescale (prev alpha) covers QK latency
                        acc_o[i, j] *= alpha[i]
                    T.mbarrier_wait_parity(vready[svp], ((k - 1) // nsV) % 2)
                    T.wgmma_gemm(pcast, Vs[svp, :, :], acc_o, policy=Pol, clear_accum=False)
                    T.named_barrier_arrive(nxt_bar, NMMA)
                    T.wait_wgmma(1)
                    T.mbarrier_arrive(kfree[sk])
                    T.copy(sm, smp)
                    T.reduce_max(acc_s, sm, dim=1, clear=False)
                    for i in T.Parallel(half):
                        alpha[i] = T.exp2(smp[i] * scale - sm[i] * scale)
                    for i, j in T.Parallel(half, block_N):
                        acc_s[i, j] = T.exp2(acc_s[i, j] * scale - sm[i] * scale)
                    T.reduce_sum(acc_s, ss, dim=1)
                    T.wait_wgmma(0)
                    T.mbarrier_arrive(vfree[svp])
                    for i in T.Parallel(half):
                        logsum[i] = logsum[i] * alpha[i] + ss[i]
                    T.copy(acc_s, pcast)
                for k in T.serial(nu, eff):
                    sk = k % nsK
                    svp = (k - 1) % nsV
                    T.sync_threads(my_bar, NMMA)
                    T.mbarrier_wait_parity(kready[sk], (k // nsK) % 2)
                    T.wgmma_gemm(Qs[0, :, :], Ks[sk, :, :], acc_s, transpose_B=True, policy=Pol, clear_accum=True)
                    for i, j in T.Parallel(half, D):  # pipelined rescale (prev alpha)
                        acc_o[i, j] *= alpha[i]
                    T.mbarrier_wait_parity(vready[svp], ((k - 1) // nsV) % 2)
                    T.wgmma_gemm(pcast, Vs[svp, :, :], acc_o, policy=Pol, clear_accum=False)
                    T.named_barrier_arrive(nxt_bar, NMMA)
                    T.wait_wgmma(1)
                    T.mbarrier_arrive(kfree[sk])
                    for i, j in T.Parallel(half, block_N):
                        acc_s[i, j] = T.if_then_else(
                            q0 + r0 + i + co >= k * block_N + j, acc_s[i, j], -T.infinity(accum))
                    T.copy(sm, smp)
                    T.reduce_max(acc_s, sm, dim=1, clear=False)
                    for i in T.Parallel(half):
                        alpha[i] = T.exp2(smp[i] * scale - sm[i] * scale)
                    for i, j in T.Parallel(half, block_N):
                        acc_s[i, j] = T.exp2(acc_s[i, j] * scale - sm[i] * scale)
                    T.reduce_sum(acc_s, ss, dim=1)
                    T.wait_wgmma(0)
                    T.mbarrier_arrive(vfree[svp])
                    for i in T.Parallel(half):
                        logsum[i] = logsum[i] * alpha[i] + ss[i]
                    T.copy(acc_s, pcast)
                svp = (eff - 1) % nsV
                for i, j in T.Parallel(half, D):  # pipelined rescale: final alpha before last PV
                    acc_o[i, j] *= alpha[i]
                T.mbarrier_wait_parity(vready[svp], ((eff - 1) // nsV) % 2)
                T.wgmma_gemm(pcast, Vs[svp, :, :], acc_o, policy=Pol, clear_accum=False)
                T.wait_wgmma(0)
                T.mbarrier_arrive(vfree[svp])
                for i, j in T.Parallel(half, D):
                    acc_o[i, j] /= logsum[i]
                T.copy(acc_o, Os[0, :, :])  # registers -> smem
                T.copy(Os[0, :, :], O[bz, q0 + r0:q0 + r0 + half, by, :])  # smem -> global (coalesced)
            with T.ws(1):
                T.set_max_nreg(240, 1)  # consumer grabs producer's released regs
                r0 = 1 * half
                my_bar = 2
                nxt_bar = 1
                acc_s = T.alloc_fragment([half, block_N], accum)
                pcast = T.alloc_fragment([half, block_N], dtype)  # register-P (rs-wgmma)
                acc_o = T.alloc_fragment([half, D], accum)
                sm = T.alloc_fragment([half], accum)
                smp = T.alloc_fragment([half], accum)
                alpha = T.alloc_fragment([half], accum)
                ss = T.alloc_fragment([half], accum)
                logsum = T.alloc_fragment([half], accum)
                T.fill(acc_o, 0)
                T.fill(logsum, 0)
                T.fill(alpha, 1.0)  # pipelined rescale: carried alpha, 1st rescale no-op
                T.fill(sm, -T.infinity(accum))
                T.mbarrier_wait_parity(q_bar, 0)
                T.named_barrier_arrive(1, NMMA)  # prime WG0
                # prologue: tile 0, QK + softmax (no PV)
                T.sync_threads(my_bar, NMMA)
                T.mbarrier_wait_parity(kready[0], 0)
                T.wgmma_gemm(Qs[1, :, :], Ks[0, :, :], acc_s, transpose_B=True, policy=Pol, clear_accum=True)
                T.named_barrier_arrive(nxt_bar, NMMA)
                T.wait_wgmma(0)
                T.mbarrier_arrive(kfree[0])
                # causal mask for tile 0: needed when the diagonal falls inside it (e.g. bx == 0 with co == 0)
                for i, j in T.Parallel(half, block_N):
                    acc_s[i, j] = T.if_then_else(
                        q0 + r0 + i + co >= j, acc_s[i, j], -T.infinity(accum))
                T.reduce_max(acc_s, sm, dim=1, clear=False)
                for i, j in T.Parallel(half, block_N):
                    acc_s[i, j] = T.exp2(acc_s[i, j] * scale - sm[i] * scale)
                T.reduce_sum(acc_s, ss, dim=1)
                for i in T.Parallel(half):
                    logsum[i] = ss[i]
                T.copy(acc_s, pcast)
                # tiles in [1, nu) are fully visible to every row of this WG: no causal mask needed
                nu = T.max(1, T.min(eff, T.floordiv(q0 + r0 + co + 1, block_N)))
                for k in T.serial(1, nu):
                    sk = k % nsK
                    svp = (k - 1) % nsV
                    T.sync_threads(my_bar, NMMA)
                    T.mbarrier_wait_parity(kready[sk], (k // nsK) % 2)
                    T.wgmma_gemm(Qs[1, :, :], Ks[sk, :, :], acc_s, transpose_B=True, policy=Pol, clear_accum=True)
                    for i, j in T.Parallel(half, D):  # pipelined rescale (prev alpha) covers QK latency
                        acc_o[i, j] *= alpha[i]
                    T.mbarrier_wait_parity(vready[svp], ((k - 1) // nsV) % 2)
                    T.wgmma_gemm(pcast, Vs[svp, :, :], acc_o, policy=Pol, clear_accum=False)
                    T.named_barrier_arrive(nxt_bar, NMMA)
                    T.wait_wgmma(1)
                    T.mbarrier_arrive(kfree[sk])
                    T.copy(sm, smp)
                    T.reduce_max(acc_s, sm, dim=1, clear=False)
                    for i in T.Parallel(half):
                        alpha[i] = T.exp2(smp[i] * scale - sm[i] * scale)
                    for i, j in T.Parallel(half, block_N):
                        acc_s[i, j] = T.exp2(acc_s[i, j] * scale - sm[i] * scale)
                    T.reduce_sum(acc_s, ss, dim=1)
                    T.wait_wgmma(0)
                    T.mbarrier_arrive(vfree[svp])
                    for i in T.Parallel(half):
                        logsum[i] = logsum[i] * alpha[i] + ss[i]
                    T.copy(acc_s, pcast)
                for k in T.serial(nu, eff):
                    sk = k % nsK
                    svp = (k - 1) % nsV
                    T.sync_threads(my_bar, NMMA)
                    T.mbarrier_wait_parity(kready[sk], (k // nsK) % 2)
                    T.wgmma_gemm(Qs[1, :, :], Ks[sk, :, :], acc_s, transpose_B=True, policy=Pol, clear_accum=True)
                    for i, j in T.Parallel(half, D):  # pipelined rescale (prev alpha)
                        acc_o[i, j] *= alpha[i]
                    T.mbarrier_wait_parity(vready[svp], ((k - 1) // nsV) % 2)
                    T.wgmma_gemm(pcast, Vs[svp, :, :], acc_o, policy=Pol, clear_accum=False)
                    T.named_barrier_arrive(nxt_bar, NMMA)
                    T.wait_wgmma(1)
                    T.mbarrier_arrive(kfree[sk])
                    for i, j in T.Parallel(half, block_N):
                        acc_s[i, j] = T.if_then_else(
                            q0 + r0 + i + co >= k * block_N + j, acc_s[i, j], -T.infinity(accum))
                    T.copy(sm, smp)
                    T.reduce_max(acc_s, sm, dim=1, clear=False)
                    for i in T.Parallel(half):
                        alpha[i] = T.exp2(smp[i] * scale - sm[i] * scale)
                    for i, j in T.Parallel(half, block_N):
                        acc_s[i, j] = T.exp2(acc_s[i, j] * scale - sm[i] * scale)
                    T.reduce_sum(acc_s, ss, dim=1)
                    T.wait_wgmma(0)
                    T.mbarrier_arrive(vfree[svp])
                    for i in T.Parallel(half):
                        logsum[i] = logsum[i] * alpha[i] + ss[i]
                    T.copy(acc_s, pcast)
                svp = (eff - 1) % nsV
                for i, j in T.Parallel(half, D):  # pipelined rescale: final alpha before last PV
                    acc_o[i, j] *= alpha[i]
                T.mbarrier_wait_parity(vready[svp], ((eff - 1) // nsV) % 2)
                T.wgmma_gemm(pcast, Vs[svp, :, :], acc_o, policy=Pol, clear_accum=False)
                T.wait_wgmma(0)
                T.mbarrier_arrive(vfree[svp])
                for i, j in T.Parallel(half, D):
                    acc_o[i, j] /= logsum[i]
                T.copy(acc_o, Os[1, :, :])  # registers -> smem
                T.copy(Os[1, :, :], O[bz, q0 + r0:q0 + r0 + half, by, :])  # smem -> global (coalesced)
    return main
Logical order vs. actual issue order
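In online softmax, each K-block logically goes through the same four steps in a fixed order:
QK → softmax (update the max, compute P, update the sum) → rescale (acc_o *= exp(m_prev − m_new), correcting the output accumulated so far to the new max) → PV (acc_o += P·V)
To obtain the two overlaps described above, however, the actual kernel splits these four steps apart and issues them out of step:
- delayed PV: iteration j issues the previous block's PV(j-1), so that its wgmma can run underneath this iteration's softmax(j) (intra-warpgroup overlap).
- carried-alpha rescale: rescale uses the alpha computed by the previous iteration's softmax, and is moved to after QK and before PV, where it hides QK's latency.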
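As a result, one physical iteration j mixes operations from two logical tiles:
| Block in the figure | Logical owner |
|---|---|
| QK·j, softmax·j | First half of tile j: compute this block's scores and P |
| rescale·(j-1), PV·(j-1) | Second half of tile j-1: correct the output with the previous block's alpha, then accumulate the previous block's P·V |
Boundary cases: tile 0 only does QK·0 and softmax·0, in the prologue (there is no history to accumulate, so no rescale or PV is needed). Rescale starts at rescale·0, but that is a no-op with alpha = 1 (acc_o is still 0 at that point), so the first rescale that actually does anything is rescale·1. Symmetrically, the last tile's rescale and PV are issued after the loop, in an epilogue, before the final division by logsum.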
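To convince yourself that the staggered order computes the same result, here is a pure-Python sketch for one query row. All names are hypothetical; it uses tiles of b keys, the natural exp instead of the kernel's exp2 with log2(e) folded into the scale, and no causal mask. pipelined_order follows the issue order described above (a prologue with QK·0 and softmax·0 only; iterations that do QK·k, rescale·(k-1) with the carried alpha, PV·(k-1) and softmax·k; an epilogue with the last rescale and PV), while logical_order does all four steps on the same tile.

```python
import math


def scores(q, K_tile):
    return [sum(a * c for a, c in zip(q, k_row)) for k_row in K_tile]


def add_pv(acc, p, V_tile):
    """acc += p @ V_tile for one query row (in place)."""
    for pj, v_row in zip(p, V_tile):
        for t in range(len(acc)):
            acc[t] += pj * v_row[t]


def logical_order(q, K, V, b):
    """QK -> softmax -> rescale -> PV, all on the same tile."""
    acc, m, logsum = [0.0] * len(V[0]), -math.inf, 0.0
    for st in range(0, len(K), b):
        s = scores(q, K[st:st + b])  # QK
        m_new = max(m, max(s))  # softmax: new max
        alpha = math.exp(m - m_new)
        p = [math.exp(x - m_new) for x in s]  # softmax: P
        logsum = logsum * alpha + sum(p)  # softmax: sum
        acc = [a * alpha for a in acc]  # rescale
        add_pv(acc, p, V[st:st + b])  # PV
        m = m_new
    return [a / logsum for a in acc]


def pipelined_order(q, K, V, b):
    """Kernel issue order: iteration k does QK.k, rescale.(k-1), PV.(k-1), softmax.k."""
    def tile(k):
        return slice(k * b, (k + 1) * b)

    n_tiles = len(K) // b
    acc = [0.0] * len(V[0])
    # prologue: QK.0 + softmax.0, no PV; alpha = 1 makes rescale.0 a no-op
    s = scores(q, K[tile(0)])
    m = max(s)
    p = [math.exp(x - m) for x in s]
    logsum, alpha = sum(p), 1.0
    for k in range(1, n_tiles):
        s = scores(q, K[tile(k)])  # QK.k (issued first)
        acc = [a * alpha for a in acc]  # rescale.(k-1) with the carried alpha
        add_pv(acc, p, V[tile(k - 1)])  # PV.(k-1), delayed by one iteration
        m_prev, m = m, max(m, max(s))  # softmax.k
        alpha = math.exp(m_prev - m)
        p = [math.exp(x - m) for x in s]
        logsum = logsum * alpha + sum(p)
    # epilogue: last rescale and PV, then normalize by the running sum
    acc = [a * alpha for a in acc]
    add_pv(acc, p, V[tile(n_tiles - 1)])
    return [a / logsum for a in acc]


q = [0.4, -1.2, 0.9]
K = [[float((3 * i + j) % 5) - 2.0 for j in range(3)] for i in range(8)]
V = [[float(i), float(i * i % 7), -float(i)] for i in range(8)]
print(logical_order(q, K, V, 2))
print(pipelined_order(q, K, V, 2))
```

Because only the issue order changes and every tile's P is still combined with the same alpha factors, the two functions agree up to rounding for any block size that divides len(K).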
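What the pipeline looks like
This staggered issue order is easiest to understand when drawn directly as a timeline. I wrote an instrumented version: at every tile-level operation boundary in all three warpgroups, it records a clock64() timestamp into a log tensor. I ran a single CTA (B = H = 1, 4 K-tiles) and plotted the timeline as an interactive full-screen figure (scroll to pan left and right, zoom buttons to stretch the time axis, hover to see each block's cycle count). The figure is wide and too cramped to embed in the text, so it has a page of its own.
How to read the figure:
- Each consumer WG has two lanes. The upper one, Consumer WGx (SIMT), shows what the warpgroup's threads are doing in program order (the various waits, rescale, softmax); the lower one, ⤷ WGx tensor core, shows the interval during which a wgmma is asynchronously in flight (from its issue to the wait that drains it).
- Intra-warpgroup asynchrony appears as tensor-core bars lying underneath the SIMT lane above them. For example, PV·(k-1) lies entirely under softmax·k, meaning PV runs on the tensor cores while the threads compute softmax.
- Softmax dominates (about 2000 cycles per block), while issuing QK/PV takes only an instant; the actual matmul time is hidden under softmax.
- The purple arrows are ping-pong handoffs: a WG finishes issuing its wgmma, arrives on a named barrier, and releases the other WG's sync (the tiny wait peer block at the start of each row). The midpoint of each arrow is labeled with whether bar1 or bar2 was used.
- The producer spends long stretches in wait buffer free, which shows that TMA is far from being the bottleneck: it has already loaded the data and is waiting for the consumers to free a buffer.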
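This is the instrumented version: the clock64 reads and the extra global stores perturb the timing slightly, so don't use it to read absolute performance. Moreover, per-thread clocks cannot accurately capture barrier causality across warpgroups (occasionally an arrow points backwards because the arrive timestamp is later than the peer's sync), so the figure is good for seeing structure and overlap, not for measuring precise latencies. The full instrumented kernel and the plotting script accompany the tilelang_fa3.py posted here.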