Flash Attention Notes

Inference Frameworks
Published

September 26, 2025

A brief record of the derivation and implementation of flash attention.

Naive Attention

\[ \begin{aligned} S &= QK^T \in \mathbb{R}^{N\times N} \\ P &= softmax(S) \\ O &= PV \in \mathbb{R}^{N \times d} \end{aligned} \]
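As a concrete reference point, the three steps above can be sketched in plain Python for a single head. This is only a minimal sketch: the names `naive_softmax_row` and `naive_attention` are illustrative, nested lists stand in for matrices, and the usual \(1/\sqrt{d}\) scaling is left out to match the formula.

```python
import math

def naive_softmax_row(row):
    """Plain softmax over one row of scores (no overflow protection)."""
    exps = [math.exp(x) for x in row]
    total = sum(exps)
    return [e / total for e in exps]

def naive_attention(Q, K, V):
    """O = softmax(Q K^T) V, materializing the full N x N matrices S and P."""
    # S = Q K^T: every query row dotted with every key row.
    S = [[sum(q * k for q, k in zip(q_row, k_row)) for k_row in K] for q_row in Q]
    # P = softmax(S), applied row by row.
    P = [naive_softmax_row(s_row) for s_row in S]
    # O = P V.
    d = len(V[0])
    return [[sum(p * v_row[c] for p, v_row in zip(p_row, V)) for c in range(d)]
            for p_row in P]

Q = [[1.0, 0.0], [0.0, 1.0]]
K = [[1.0, 0.0], [0.0, 1.0]]
V = [[1.0, 2.0], [3.0, 4.0]]
O = naive_attention(Q, K, V)
```

Both `S` and `P` are full \(N \times N\) intermediates here, which is exactly what the rest of the derivation tries to avoid.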

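In essence this is two matmuls with a pseudo-elementwise operation in between. The difficulty is that every softmax output depends on all of the inputs in its row, so the three steps cannot be tiled and fused directly. The reason is as follows.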

SoftMax

Naive softmax implementation:

\[ \begin{aligned} softmax(x_1, \ldots, x_N)_i = \frac{e ^ {x_i}}{\sum_{j=1}^{N} e^{x_j}}, \ i \in [1, N] \end{aligned} \]

Because \(e^x\) can become very large and easily overflows, safe softmax was introduced.

Safe softmax implementation:

\[ \begin{aligned} max_{N} &= max(x_i), \ i \in [1, N] \\ softmax(x_1, \ldots, x_N)_i &= \frac{e ^ {x_i - max_{N}}}{\sum_{j=1}^{N} e^{x_j - max_{N}}}, \ i \in [1, N] \end{aligned} \]

However, the implementation needs three loops, because \(max_N\) and \(sum_N\) each require a separate pass before the final normalization:

\[ \begin{aligned} max_i &= max(max_{i-1}, x_i)\\ sum_i &= sum_{i-1} + e^{x_i - max_N} \\ a_i &= \frac{e^{x_i - max_N}}{sum_N}, \ i \in [1, N] \end{aligned} \]
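The three loops can be sketched in plain Python as follows. This is a minimal illustration, with `naive_softmax` and `safe_softmax_3pass` as hypothetical names. Note that Python's `math.exp` raises `OverflowError` where a floating-point kernel would typically produce `inf`.

```python
import math

def naive_softmax(xs):
    exps = [math.exp(x) for x in xs]  # overflows once x is large enough
    total = sum(exps)
    return [e / total for e in exps]

def safe_softmax_3pass(xs):
    # Pass 1: running maximum max_i = max(max_{i-1}, x_i).
    m = -math.inf
    for x in xs:
        m = max(m, x)
    # Pass 2: sum of exponentials shifted by the final maximum.
    s = 0.0
    for x in xs:
        s += math.exp(x - m)
    # Pass 3: normalize every element.
    return [math.exp(x - m) / s for x in xs]

xs = [1000.0, 1001.0, 1002.0]
try:
    naive_softmax(xs)
    overflowed = False
except OverflowError:
    overflowed = True
probs = safe_softmax_3pass(xs)
```

Passes 2 and 3 both need the final maximum, which is why the input has to be read three times.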

2-pass softmax

Honestly, I would never have come up with a formula that fuses the max and the sum on my own; perhaps the authors were inspired by Welford's method. First, expand the formula for \(sum_N\) and use the properties of \(exp\) to split the \(- max_N\) term into two parts: \[ \begin{aligned} sum_{N} &= \sum_{j = 1} ^ N e^{x_j - max_N} \\ \\ &= \sum_{j = 1} ^ {N-1} e^{x_j - max_N} + e^{x_N - max_N} \\ &= \sum_{j = 1} ^ {N-1} e^{x_j - max_{N-1} + max_{N-1} - max_N} + e^{x_N - max_N} \\ &= (\sum_{j = 1} ^ {N-1} e^{x_j - max_{N-1}}) e^{max_{N-1} - max_N} +e^{x_N - max_N} \\ \end{aligned} \]

Looking at the formula above, defining the variable from a different point of view turns it into a recurrence: \[ \begin{aligned} \text{let}\ sum_{N}^\prime &=\sum_{j=1}^{N} e^{x_j - max_N} = sum_{N} \\ &= (\sum_{j = 1} ^ {N-1} e^{x_j - max_{N-1}}) e^{max_{N-1} - max_N} +e^{x_N - max_N} \\ &= sum_{N-1}^\prime e^{max_{N-1} - max_N} +e^{x_N - max_N} \\ sum_i ^\prime &= sum_{i-1} ^\prime e^{max_{i-1} - max_i} +e^{x_i - max_i} \end{aligned} \]

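This change of viewpoint decouples \(sum_{i}^\prime\) from \(max_N\): each step only needs the running maximum \(max_i\), and at the end of the iteration \(sum_{N}^\prime = sum_{N}\). Although the 2-pass approach adds an extra multiplication by \(e^{max_{i-1} - max_i}\) in every iteration, that is clearly far cheaper than another pass over memory.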
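The recurrence for \(sum_i^\prime\) can be sketched in plain Python as below. It is a minimal illustration under the definitions above; the name `online_softmax` is hypothetical.

```python
import math

def online_softmax(xs):
    # Pass 1: fused running max and rescaled running sum.
    m = -math.inf
    s = 0.0
    for x in xs:
        m_new = max(m, x)
        # exp(m - m_new) moves the old sum onto the new maximum.
        # On the first step m is -inf, so the factor is exp(-inf) = 0.
        s = s * math.exp(m - m_new) + math.exp(x - m_new)
        m = m_new
    # Pass 2: normalize with the final max and sum.
    return [math.exp(x - m) / s for x in xs]

probs = online_softmax([1.0, 2.0, 3.0])
```

The input is now read twice instead of three times, at the price of one extra `exp` and one multiply per element in the first pass.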

Flash Attention

2-pass Attention

First, implement attention with the 2-pass softmax. To avoid confusing the query length with the sequence length, they are indexed by k and i respectively.

\[ \begin{aligned} \text{for i in [1, N]}:&\\ x_i &= Q[k, :]K^T[:, i]\\ max_i &= \max(max_{i-1}, x_i) \\ sum_i^\prime &= sum_{i-1} ^\prime e^{max_{i-1} - max_i} +e^{x_i - max_i}\\ \text{end} \qquad \qquad \\ \text{for i in [1, N]}:&\\ a_i &= \frac{e^{x_i - max_N}}{sum_N^\prime} \\ o_i &= o_{i-1} + a_i V[i,:] \\ \text{end} \qquad \qquad \\ O[k,:] & = o_N \end{aligned} \]
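For a single query row k, the two loops above can be sketched in plain Python as follows. This is a minimal illustration: `attention_row_2pass` is a hypothetical name, `q` stands for `Q[k, :]`, and `K`, `V` are lists of rows.

```python
import math

def attention_row_2pass(q, K, V):
    """Attention output for one query row q against all rows of K and V."""
    n, d = len(K), len(V[0])
    # Pass 1: online softmax statistics over the scores x_i = q . K[i].
    m, s = -math.inf, 0.0
    for i in range(n):
        x = sum(a * b for a, b in zip(q, K[i]))
        m_new = max(m, x)
        s = s * math.exp(m - m_new) + math.exp(x - m_new)
        m = m_new
    # Pass 2: recompute each score, normalize it, and accumulate a_i * V[i, :].
    o = [0.0] * d
    for i in range(n):
        x = sum(a * b for a, b in zip(q, K[i]))
        a_i = math.exp(x - m) / s
        o = [o_c + a_i * v_c for o_c, v_c in zip(o, V[i])]
    return o

row = attention_row_2pass([1.0, 0.0], [[1.0, 0.0], [0.0, 1.0]], [[1.0, 2.0], [3.0, 4.0]])
```

The second loop still needs the final `m` and `s`, so the scores are computed (or loaded) twice.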

1-pass Attention

When multiplying by V, every \(o_i\) still depends on \(max_N\). The next step is to eliminate this dependency. Following the same trick as in the 2-pass softmax, first define: \[ \begin{aligned} o_N^\prime &= \sum_{i = 1} ^ {N} a_i V[i,:] \\ &= \sum_{i = 1} ^ {N} \frac{e^{x_i - max_N}}{sum_N^\prime} V[i,:] \\ &= (\sum_{i = 1} ^ {N-1} \frac{e^{x_i - max_N}}{sum_N^\prime} V[i,:]) + \frac{e^{x_N - max_N}}{sum_N^\prime} V[N,:] \\ &= (\sum_{i = 1} ^ {N-1} \frac{e^{x_i - max_N}}{sum_N^\prime} \frac{sum_{N-1}^\prime}{sum_{N-1}^\prime} \frac{e^{x_i - max_{N-1}}}{e^{x_i - max_{N-1}}} V[i,:]) + \frac{e^{x_N - max_N}}{sum_N^\prime} V[N,:] \\ &= (\sum_{i = 1} ^ {N-1} \frac{e^{x_i - max_{N-1}}}{sum_{N-1}^\prime} V[i,:]) \frac{sum_{N-1}^\prime}{sum_{N}^\prime}\frac{e^{x_i - max_N}}{e^{x_i - max_{N-1}}} + \frac{e^{x_N - max_N}}{sum_N^\prime} V[N,:] \\ &= (\sum_{i = 1} ^ {N-1} \frac{e^{x_i - max_{N-1}}}{sum_{N-1}^\prime} V[i,:]) \frac{sum_{N-1}^\prime}{sum_{N}^\prime}e^{max_{N-1} - max_N} + \frac{e^{x_N - max_N}}{sum_N^\prime} V[N,:] \\ &= o_{N-1}^\prime \frac{sum_{N-1}^\prime}{sum_{N}^\prime}e^{max_{N-1} - max_N} + \frac{e^{x_N - max_N}}{sum_N^\prime} V[N,:] \\ \end{aligned} \] The ratio \(\frac{e^{x_i - max_N}}{e^{x_i - max_{N-1}}}\) equals \(e^{max_{N-1} - max_N}\) and does not depend on \(i\), which is why it can be pulled out of the sum. Generalizing by induction gives a formula for \(o_i^\prime\) that no longer contains \(max_N\): \[ \begin{aligned} o_i^\prime &= o_{i-1}^\prime \frac{sum_{i-1}^\prime}{sum_{i}^\prime}e^{max_{i-1} - max_i} + \frac{e^{x_i - max_i}}{sum_i^\prime} V[i,:] \end{aligned} \]

Finally, the scalar form of 1-pass attention is: \[ \begin{aligned} \text{for i in [1, N]}:&\\ x_i &= Q[k, :]K^T[:, i]\\ max_i &= \max(max_{i-1}, x_i) \\ sum_i^\prime &= sum_{i-1} ^\prime e^{max_{i-1} - max_i} +e^{x_i - max_i}\\ o_i^\prime &= o_{i-1}^\prime \frac{sum_{i-1}^\prime}{sum_{i}^\prime}e^{max_{i-1} - max_i} + \frac{e^{x_i - max_i}}{sum_i^\prime} V[i,:] \\ \text{end} \qquad \qquad\\ O[k,:] & = o_N^\prime \end{aligned} \]
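The scalar 1-pass loop can be sketched in plain Python for one query row. This is a minimal illustration of the recurrence above; `attention_row_1pass` is a hypothetical name and `q` stands for `Q[k, :]`.

```python
import math

def attention_row_1pass(q, K, V):
    """Single loop over keys: max, sum and output are all updated online."""
    n, d = len(K), len(V[0])
    m, s = -math.inf, 0.0
    o = [0.0] * d
    for i in range(n):
        x = sum(a * b for a, b in zip(q, K[i]))
        m_new = max(m, x)
        corr = math.exp(m - m_new)   # e^{max_{i-1} - max_i}
        p = math.exp(x - m_new)      # e^{x_i - max_i}
        s_new = s * corr + p         # sum_i'
        # o_i' = o_{i-1}' * (sum_{i-1}' / sum_i') * corr + (p / sum_i') * V[i, :]
        o = [o_c * (s * corr / s_new) + (p / s_new) * v_c
             for o_c, v_c in zip(o, V[i])]
        m, s = m_new, s_new
    return o

row = attention_row_1pass([1.0, 0.0], [[1.0, 0.0], [0.0, 1.0]], [[1.0, 2.0], [3.0, 4.0]])
```

After the last key the accumulator already holds the normalized output, so no second loop is needed.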

Flash Attention v1

The 1-pass attention derived above is based on a scalar loop, whereas flash attention has to compute tile by tile, so the formulas need a slight modification.

First, write out the ordinary softmax computation:

\[ \begin{aligned} X & = [x_1, \ldots , x_N] \\ max_N & = \max(X) \\ &= \max([x_1, \ldots , x_N]) \\ f(X) & = [f(x_1), \ldots , f(x_N)] \\ &= [e^{x_1 - max_N}, \ldots, e^{x_N - max_N}] \\ sum_N & = \sum_{i = 1}^N f(x_i) \\ & = \sum_{i = 1}^N e^{x_i - max_N} \\ softmax(X) & = \frac{ f(X)}{sum_N} \end{aligned} \]

Now derive the tiled softmax. Suppose \(X\) consists of two sub-vectors of length \(N\) each. First treat it as a single vector, then split the computation into a form that can be evaluated by divide and conquer: \[ \begin{aligned} X & = [x^1, x^2] \\ max_{2N} & = \max([\max(x^1),\max(x^2)]) \\ & = \max([max_N^1, max_N^2]) \\ f(X) &= \left[ [e^{x_1^1 - max_{2N}},\ldots, e^{x_N^1 - max_{2N}}] , [e^{x_1^2 - max_{2N}},\ldots, e^{x_N^2 - max_{2N}}] \right] \\ &= \left[ e^{max_N^1 - max_{2N}} [e^{x_1^1 - max_N^1},\ldots, e^{x_N^1 - max_N^1}] , e^{max_N^2 - max_{2N}} [e^{x_1^2 - max_N^2},\ldots, e^{x_N^2 - max_N^2}] \right] \\ &= \left[ e^{max_N^1 - max_{2N}} f(x^1) , e^{max_N^2 - max_{2N}} f(x^2) \right] \\ sum_{2N} & = \sum_{i = 1}^N e^{x_i^1 - max_{2N}} + \sum_{i = 1}^N e^{x_i^2 - max_{2N}} \\ & = e^{max_N^1 - max_{2N}} \sum_{i = 1}^N e^{x_i^1 - max_{N}^1} + e^{max_N^2 - max_{2N}} \sum_{i = 1}^N e^{x_i^2 - max_{N}^2} \\ & = e^{max_N^1 - max_{2N}} sum_N^1 + e^{max_N^2 - max_{2N}} sum_N^2 \\ softmax(X) &= \frac{f(X)}{sum_{2N}} \end{aligned} \]

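At this point it becomes clear that, besides each sub-vector's own \(max^j\), which is used to compute \(f(x^j)\) and \(sum^j\), the global \(max\) and \(sum\) must also be maintained to compute the final result.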
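The merge of two tiles' statistics can be sketched in plain Python as follows. This is a minimal illustration of the formulas above; the names `tile_stats`, `merge_stats` and `tiled_softmax` are hypothetical.

```python
import math

def tile_stats(xs):
    """Local max and local sum of shifted exponentials for one tile."""
    m = max(xs)
    return m, sum(math.exp(x - m) for x in xs)

def merge_stats(m1, s1, m2, s2):
    """Combine two tiles: each local sum is rescaled onto the global max."""
    m = max(m1, m2)
    return m, s1 * math.exp(m1 - m) + s2 * math.exp(m2 - m)

def tiled_softmax(x1, x2):
    m1, s1 = tile_stats(x1)
    m2, s2 = tile_stats(x2)
    m, s = merge_stats(m1, s1, m2, s2)
    # f(X) = [e^{max^1 - max} f(x^1), e^{max^2 - max} f(x^2)], then divide by the merged sum.
    out1 = [math.exp(m1 - m) * math.exp(x - m1) / s for x in x1]
    out2 = [math.exp(m2 - m) * math.exp(x - m2) / s for x in x2]
    return out1 + out2

probs = tiled_softmax([1.0, 5.0, -2.0], [8.0, 0.5, 3.0])
```

Because `merge_stats` only needs `(max, sum)` pairs, the same merge applies to any number of tiles processed in sequence.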

The tiling in flash attention amounts to vectorizing \(x_i\), so that each step handles a block of \(b\) keys. Starting from the 1-pass attention formula and combining it with the tiled softmax formula, only the computation of \(max_i\) and \(sum_i\) needs a small change to obtain the flash attention formula: \[ \begin{aligned} \text{for i in [1, N/b]}:&\\ x_i &= Q[k, :]K^T[:, (i-1)b+1:ib]\\ max_i^{local} &= \max(x_i) \\ max_i &= \max(max_{i-1}, max_i^{local}) \\ sum_i^\prime &= sum_{i-1} ^\prime e^{max_{i-1} - max_i} + \sum_{j=1}^{b} e^{x_i[j] - max_i}\\ o_i^\prime &= o_{i-1}^\prime \frac{sum_{i-1}^\prime}{sum_{i}^\prime}e^{max_{i-1} - max_i} + \sum_{j=1}^b \frac{e^{x_i[j] - max_i}}{sum_i^\prime} V[(i-1)b+j,:] \\ \text{end}\qquad \qquad\\ O[k,:] & = o_{N/b}^\prime \end{aligned} \]

A simple flash attention implementation is attached for reference:

import pytest
import torch
import torch.nn.functional as F
import math
import numpy as np


np.set_printoptions(suppress=True)


def flash_attn(query: np.ndarray, key: np.ndarray, value: np.ndarray, attn_mask=None, dropout_p=0.0,
               is_causal=False, scale=None, enable_gqa=False):
  L, S = query.shape[-2], key.shape[-2]
  # np.ndarray.size is an int attribute, not a method; use shape[-1] for the head dimension
  scale_factor = 1 / math.sqrt(query.shape[-1]) if scale is None else scale
  attn_bias = np.zeros((L, S), dtype=query.dtype)
  if is_causal:
    assert attn_mask is None
    temp_mask = np.tril(np.ones((L, S), dtype=np.bool_), k=0)
    attn_bias[False == temp_mask] = float("-inf")

  if attn_mask is not None:
    if attn_mask.dtype == np.bool_:
      attn_bias[False == attn_mask] = -np.inf
    else:
      attn_bias = attn_mask + attn_bias

  assert enable_gqa is False, "GQA not implemented"

  outputs = []
  for head in range(query.shape[0]):
    Q = query[head]  # [query_len, dim]
    K = key[head]  # [seq_len, dim]
    V = value[head]  # [seq_len, dim]
    O = np.zeros_like(Q, dtype=np.float32)  # [query_len, dim]
    Tc = 4
    Tr = 16
    assert L % Tr == 0
    assert S % Tc == 0

    global_maxs = np.zeros([Tr, L // Tr], dtype=np.float32)
    global_sums = np.zeros([Tr, L // Tr], dtype=np.float32)
    for j in range(0, S, Tc):
      # outer loop is seq_len, because seq_len is `K` dimension, we can reuse Kj,Vj `query_len/Tr` times
      Kj = K[j:j + Tc, :]  # [Tc, dim]
      Vj = V[j:j + Tc, :]  # [Tc, dim]
      for (ii, i) in enumerate(range(0, L, Tr)):
        # load
        Qi = Q[i:i + Tr, :]  # [Tr, dim]
        O_last = O[i:i + Tr, :]  # [Tr, dim]
        max_last = (np.zeros([Tr, 1], dtype=np.float32) - np.inf) if j == 0 else global_maxs[:, ii:ii + 1]
        sum_last = np.zeros([Tr, 1], dtype=np.float32) if j == 0 else global_sums[:, ii:ii + 1]

        a_ij = (Qi @ Kj.T) * scale_factor  # [Tr, Tc]
        a_ij += attn_bias[i:i + Tr, j:j + Tc]
        max_local = np.max(a_ij, axis=1, keepdims=True)  # [Tr, 1]
        max_i = np.maximum(max_last, max_local)  # [Tr, 1]
        p_ij = np.exp(a_ij - max_i)  # [Tr, Tc]
        sum_i = sum_last * np.exp(max_last - max_i) + np.sum(p_ij, axis=1, keepdims=True)  # [Tr, 1]
        O_i = O_last * (sum_last * np.exp(max_last - max_i)) / sum_i + (p_ij / sum_i) @ Vj  # [Tr, dim]

        # store
        O[i:i + Tr, :] = O_i
        global_maxs[:, ii:ii + 1] = max_i
        global_sums[:, ii:ii + 1] = sum_i
    outputs.append(O)
  # stack per-head results into [head, query_len, dim] instead of returning after the first head
  return np.stack(outputs)


@pytest.mark.parametrize("head_q, head_kv", [(1, 1)])
@pytest.mark.parametrize("query_len, seq_len", [(64, 64)])
@pytest.mark.parametrize("dim", [128])
@pytest.mark.parametrize("is_causal", [False, True])
@pytest.mark.parametrize("scale", [1.0])
def test_flash_attention(head_q, head_kv, query_len, seq_len, dim, is_causal, scale):
  query = np.random.rand(head_q, query_len, dim).astype(np.float32)  # [head_q, query_len, dim]
  key = np.random.rand(head_kv, seq_len, dim).astype(np.float32)  # [head_kv, seq_len, dim]
  value = np.random.rand(head_kv, seq_len, dim).astype(np.float32)  # [head_kv, seq_len, dim]

  o = F.scaled_dot_product_attention(
      torch.tensor(query), torch.tensor(key), torch.tensor(value), is_causal=is_causal, scale=scale)
  o_np = o.numpy()  # [q_head,query,dim]

  o_actual = flash_attn(query, key, value, is_causal=is_causal, scale=scale)

  assert np.allclose(o_np, o_actual, atol=1e-7)


if __name__ == "__main__":
  pytest.main([__file__, "-vvs"])

Implementing FA3 on Hopper with TileLang

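The speedup from FA2 to FA3 comes mainly from scheduling, not from math. The core idea of the FA3 paper is to overlap the GEMMs (tensor cores) with the exp inside softmax (which runs on the separate multi-function unit): ideally, exp is computed while the tensor cores are busy with a matmul. The paper describes two levels of overlap: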

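Ping-pong (between warpgroups). Launch two consumer warpgroups and use bar.sync named barriers to force warpgroup 1's GEMMs (this iteration's PV and the next iteration's QKᵀ) to be scheduled before warpgroup 2's GEMMs. Warpgroup 1's softmax is then scheduled exactly while warpgroup 2 runs its GEMMs, and vice versa. The two WGs take turns on the tensor cores like a game of ping-pong: while one computes softmax, the other runs matmuls.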

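Intra-warpgroup overlap (within a warpgroup). A single warpgroup also pipelines GEMM and softmax internally: the second WGMMA (PV) of iteration j overlaps with the softmax of iteration j+1, at the cost of an extra set of registers to hold the next iteration's scores.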

The TileLang implementation below puts both overlaps into practice: 384 threads = 1 TMA producer warpgroup + 2 consumer warpgroups, with ping-pong between the two consumers and the pipelining described above inside each consumer.

Full implementation

"""Faithful FlashInfer FA3 port WITH 2-warpgroup ping-pong (CUTLASS WarpScheduler).

  - 384 threads = 1 producer WG (TMA) + 2 consumer WG
  - each consumer WG does its OWN m=64 query rows; per-WG Qs[g] (A-operand
    offset must be 0, so split buffers instead of slicing rows)
  - register-P (rs-wgmma): P held in a fp16 fragment (pcast), fed straight to PV.
    The cast stays AFTER wait_wgmma(0) so PV(k-1) finishes reading pcast before
    it is overwritten (fusing exp->pcast earlier serializes PV -> 114us).
  - setmaxnreg: producer dealloc to 24 regs, consumers alloc 240 -> rs-P fits
    (without this, rs-P spilled/serialized and lost to smem-P).
  - WarpScheduler ping-pong via named barriers: WG0 sync(1)/arrive(2),
    WG1 sync(2)/arrive(1); WG1 primes bar1 once -> WG0 softmax overlaps WG1 mma.
  - delayed PV + wait_wgmma(1): softmax(k) overlaps in-flight PV(k-1).
"""
import tilelang
import tilelang.language as T
from tilelang.layout import make_swizzled_layout

import os as _os
BLOCK_M = 128
BLOCK_N = int(_os.environ.get("BN", "128"))
NSK = int(_os.environ.get("NSK", "2"))
NSV = int(_os.environ.get("NSV", "2"))
THREADS = 384
NMMA = 256


_FM = _os.environ.get("FM", "1") == "1"
_pc = {tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: _FM,
       tilelang.PassConfigKey.TL_DISABLE_THREAD_STORAGE_SYNC: True}
_cf = [
    "-O3",
] + (["--use_fast_math"] if _FM else []) + [
    "-Wno-deprecated-declarations",
    "-U__CUDA_NO_HALF_OPERATORS__",
    "-U__CUDA_NO_HALF_CONVERSIONS__",
    "-U__CUDA_NO_HALF2_OPERATORS__",
    "-U__CUDA_NO_BFLOAT16_CONVERSIONS__",
    "--expt-relaxed-constexpr",
    "--expt-extended-lambda",
    "-DNDEBUG",
]


@tilelang.jit(out_idx=[3], pass_configs=_pc, compile_flags=_cf)
def fa3_prefill(B, H, Hkv, Sq, Skv, D, dtype,
                block_M=BLOCK_M, block_N=BLOCK_N, nsK=NSK, nsV=NSV, threads=THREADS):
    scale = (1.0 / D) ** 0.5 * 1.44269504
    groups = H // Hkv
    co = Skv - Sq
    accum = "float"
    half = block_M // 2
    Pol = T.GemmWarpPolicy.FullRow

    @T.prim_func
    def main(
        Q: T.Tensor([B, Sq, H, D], dtype),
        K: T.Tensor([B, Skv, Hkv, D], dtype),
        V: T.Tensor([B, Skv, Hkv, D], dtype),
        O: T.Tensor([B, Sq, H, D], dtype),
    ):
        with T.Kernel(T.ceildiv(Sq, block_M), H, B, threads=threads) as (bx, by, bz):
            Qs = T.alloc_shared([2, half, D], dtype)
            Ks = T.alloc_shared([nsK, block_N, D], dtype)
            Vs = T.alloc_shared([nsV, block_N, D], dtype)
            Os = T.alloc_shared([2, half, D], dtype)  # per-WG smem-staged output (FlashInfer epilogue)
            T.annotate_layout({Qs: make_swizzled_layout(Qs), Ks: make_swizzled_layout(Ks),
                               Vs: make_swizzled_layout(Vs)})

            q_bar = T.alloc_barrier([32])             # 1-warp producer (FlashInfer NUM_PRODUCER_THREADS=32)
            kready = T.alloc_barrier([32] * nsK)
            kfree = T.alloc_barrier([NMMA] * nsK)
            vready = T.alloc_barrier([32] * nsV)
            vfree = T.alloc_barrier([NMMA] * nsV)

            cv = by // groups
            q0 = bx * block_M
            eff = T.min(T.ceildiv(Skv, block_N),
                        T.ceildiv(q0 + block_M + co, block_N))
            tx = T.get_thread_binding()

            if tx >= 256:  # ================= producer =================
                T.set_max_nreg(24, 0)  # producer is TMA-only: release regs to consumers
            if tx >= 256 and tx < 288:  # only 1 warp issues TMA + waits (rest of WG idle)
                T.tma_copy(Q[bz, q0:q0 + half, by, :], Qs[0, :, :], barrier=q_bar)
                T.tma_copy(Q[bz, q0 + half:q0 + block_M, by, :], Qs[1, :, :], barrier=q_bar)
                T.mbarrier_arrive(q_bar)
                for k in T.serial(eff):
                    sk = k % nsK
                    T.mbarrier_wait_parity(kfree[sk], ((k // nsK) % 2) ^ 1)
                    T.tma_copy(K[bz, k * block_N:(k + 1) * block_N, cv, :], Ks[sk, :, :], barrier=kready[sk])
                    T.mbarrier_arrive(kready[sk])
                    sv = k % nsV
                    T.mbarrier_wait_parity(vfree[sv], ((k // nsV) % 2) ^ 1)
                    T.tma_copy(V[bz, k * block_N:(k + 1) * block_N, cv, :], Vs[sv, :, :], barrier=vready[sv])
                    T.mbarrier_arrive(vready[sv])

            with T.ws(0):
                    T.set_max_nreg(240, 1)  # consumer grabs producer's released regs
                    r0 = 0 * half
                    my_bar = 1
                    nxt_bar = 2
                    acc_s = T.alloc_fragment([half, block_N], accum)
                    pcast = T.alloc_fragment([half, block_N], dtype)  # register-P (rs-wgmma)
                    acc_o = T.alloc_fragment([half, D], accum)
                    sm = T.alloc_fragment([half], accum)
                    smp = T.alloc_fragment([half], accum)
                    alpha = T.alloc_fragment([half], accum)
                    ss = T.alloc_fragment([half], accum)
                    logsum = T.alloc_fragment([half], accum)

                    T.fill(acc_o, 0)
                    T.fill(logsum, 0)
                    T.fill(alpha, 1.0)  # pipelined rescale: carried alpha, 1st rescale no-op
                    T.fill(sm, -T.infinity(accum))
                    T.mbarrier_wait_parity(q_bar, 0)
                    pass  # WG0 goes first

                    # prologue: tile 0, QK + softmax (no PV)
                    T.sync_threads(my_bar, NMMA)
                    T.mbarrier_wait_parity(kready[0], 0)
                    T.wgmma_gemm(Qs[0, :, :], Ks[0, :, :], acc_s, transpose_B=True, policy=Pol, clear_accum=True)
                    T.named_barrier_arrive(nxt_bar, NMMA)
                    T.wait_wgmma(0)
                    T.mbarrier_arrive(kfree[0])
                    T.reduce_max(acc_s, sm, dim=1, clear=False)
                    for i, j in T.Parallel(half, block_N):
                        acc_s[i, j] = T.exp2(acc_s[i, j] * scale - sm[i] * scale)
                    T.reduce_sum(acc_s, ss, dim=1)
                    for i in T.Parallel(half):
                        logsum[i] = ss[i]
                    T.copy(acc_s, pcast)

                    nu = T.max(1, T.min(eff, T.floordiv(q0 + r0 + co + 1, block_N)))
                    for k in T.serial(1, nu):
                        sk = k % nsK
                        svp = (k - 1) % nsV
                        T.sync_threads(my_bar, NMMA)
                        T.mbarrier_wait_parity(kready[sk], (k // nsK) % 2)
                        T.wgmma_gemm(Qs[0, :, :], Ks[sk, :, :], acc_s, transpose_B=True, policy=Pol, clear_accum=True)
                        for i, j in T.Parallel(half, D):  # pipelined rescale (prev alpha) covers QK latency
                            acc_o[i, j] *= alpha[i]
                        T.mbarrier_wait_parity(vready[svp], ((k - 1) // nsV) % 2)
                        T.wgmma_gemm(pcast, Vs[svp, :, :], acc_o, policy=Pol, clear_accum=False)
                        T.named_barrier_arrive(nxt_bar, NMMA)
                        T.wait_wgmma(1)
                        T.mbarrier_arrive(kfree[sk])
                        T.copy(sm, smp)
                        T.reduce_max(acc_s, sm, dim=1, clear=False)
                        for i in T.Parallel(half):
                            alpha[i] = T.exp2(smp[i] * scale - sm[i] * scale)
                        for i, j in T.Parallel(half, block_N):
                            acc_s[i, j] = T.exp2(acc_s[i, j] * scale - sm[i] * scale)
                        T.reduce_sum(acc_s, ss, dim=1)
                        T.wait_wgmma(0)
                        T.mbarrier_arrive(vfree[svp])
                        for i in T.Parallel(half):
                            logsum[i] = logsum[i] * alpha[i] + ss[i]
                        T.copy(acc_s, pcast)
                    for k in T.serial(nu, eff):
                        sk = k % nsK
                        svp = (k - 1) % nsV
                        T.sync_threads(my_bar, NMMA)
                        T.mbarrier_wait_parity(kready[sk], (k // nsK) % 2)
                        T.wgmma_gemm(Qs[0, :, :], Ks[sk, :, :], acc_s, transpose_B=True, policy=Pol, clear_accum=True)
                        for i, j in T.Parallel(half, D):  # pipelined rescale (prev alpha)
                            acc_o[i, j] *= alpha[i]
                        T.mbarrier_wait_parity(vready[svp], ((k - 1) // nsV) % 2)
                        T.wgmma_gemm(pcast, Vs[svp, :, :], acc_o, policy=Pol, clear_accum=False)
                        T.named_barrier_arrive(nxt_bar, NMMA)
                        T.wait_wgmma(1)
                        T.mbarrier_arrive(kfree[sk])
                        for i, j in T.Parallel(half, block_N):
                            acc_s[i, j] = T.if_then_else(
                                q0 + r0 + i + co >= k * block_N + j, acc_s[i, j], -T.infinity(accum))
                        T.copy(sm, smp)
                        T.reduce_max(acc_s, sm, dim=1, clear=False)
                        for i in T.Parallel(half):
                            alpha[i] = T.exp2(smp[i] * scale - sm[i] * scale)
                        for i, j in T.Parallel(half, block_N):
                            acc_s[i, j] = T.exp2(acc_s[i, j] * scale - sm[i] * scale)
                        T.reduce_sum(acc_s, ss, dim=1)
                        T.wait_wgmma(0)
                        T.mbarrier_arrive(vfree[svp])
                        for i in T.Parallel(half):
                            logsum[i] = logsum[i] * alpha[i] + ss[i]
                        T.copy(acc_s, pcast)

                    svp = (eff - 1) % nsV
                    for i, j in T.Parallel(half, D):  # pipelined rescale: final alpha before last PV
                        acc_o[i, j] *= alpha[i]
                    T.mbarrier_wait_parity(vready[svp], ((eff - 1) // nsV) % 2)
                    T.wgmma_gemm(pcast, Vs[svp, :, :], acc_o, policy=Pol, clear_accum=False)
                    T.wait_wgmma(0)
                    T.mbarrier_arrive(vfree[svp])
                    for i, j in T.Parallel(half, D):
                        acc_o[i, j] /= logsum[i]
                    T.copy(acc_o, Os[0, :, :])                                  # registers -> smem
                    T.copy(Os[0, :, :], O[bz, q0 + r0:q0 + r0 + half, by, :])  # smem -> global (coalesced)

            with T.ws(1):
                    T.set_max_nreg(240, 1)  # consumer grabs producer's released regs
                    r0 = 1 * half
                    my_bar = 2
                    nxt_bar = 1
                    acc_s = T.alloc_fragment([half, block_N], accum)
                    pcast = T.alloc_fragment([half, block_N], dtype)  # register-P (rs-wgmma)
                    acc_o = T.alloc_fragment([half, D], accum)
                    sm = T.alloc_fragment([half], accum)
                    smp = T.alloc_fragment([half], accum)
                    alpha = T.alloc_fragment([half], accum)
                    ss = T.alloc_fragment([half], accum)
                    logsum = T.alloc_fragment([half], accum)

                    T.fill(acc_o, 0)
                    T.fill(logsum, 0)
                    T.fill(alpha, 1.0)  # pipelined rescale: carried alpha, 1st rescale no-op
                    T.fill(sm, -T.infinity(accum))
                    T.mbarrier_wait_parity(q_bar, 0)
                    T.named_barrier_arrive(1, NMMA)  # prime WG0

                    # prologue: tile 0, QK + softmax (no PV)
                    T.sync_threads(my_bar, NMMA)
                    T.mbarrier_wait_parity(kready[0], 0)
                    T.wgmma_gemm(Qs[1, :, :], Ks[0, :, :], acc_s, transpose_B=True, policy=Pol, clear_accum=True)
                    T.named_barrier_arrive(nxt_bar, NMMA)
                    T.wait_wgmma(0)
                    T.mbarrier_arrive(kfree[0])
                    T.reduce_max(acc_s, sm, dim=1, clear=False)
                    for i, j in T.Parallel(half, block_N):
                        acc_s[i, j] = T.exp2(acc_s[i, j] * scale - sm[i] * scale)
                    T.reduce_sum(acc_s, ss, dim=1)
                    for i in T.Parallel(half):
                        logsum[i] = ss[i]
                    T.copy(acc_s, pcast)

                    nu = T.max(1, T.min(eff, T.floordiv(q0 + r0 + co + 1, block_N)))
                    for k in T.serial(1, nu):
                        sk = k % nsK
                        svp = (k - 1) % nsV
                        T.sync_threads(my_bar, NMMA)
                        T.mbarrier_wait_parity(kready[sk], (k // nsK) % 2)
                        T.wgmma_gemm(Qs[1, :, :], Ks[sk, :, :], acc_s, transpose_B=True, policy=Pol, clear_accum=True)
                        for i, j in T.Parallel(half, D):  # pipelined rescale (prev alpha) covers QK latency
                            acc_o[i, j] *= alpha[i]
                        T.mbarrier_wait_parity(vready[svp], ((k - 1) // nsV) % 2)
                        T.wgmma_gemm(pcast, Vs[svp, :, :], acc_o, policy=Pol, clear_accum=False)
                        T.named_barrier_arrive(nxt_bar, NMMA)
                        T.wait_wgmma(1)
                        T.mbarrier_arrive(kfree[sk])
                        T.copy(sm, smp)
                        T.reduce_max(acc_s, sm, dim=1, clear=False)
                        for i in T.Parallel(half):
                            alpha[i] = T.exp2(smp[i] * scale - sm[i] * scale)
                        for i, j in T.Parallel(half, block_N):
                            acc_s[i, j] = T.exp2(acc_s[i, j] * scale - sm[i] * scale)
                        T.reduce_sum(acc_s, ss, dim=1)
                        T.wait_wgmma(0)
                        T.mbarrier_arrive(vfree[svp])
                        for i in T.Parallel(half):
                            logsum[i] = logsum[i] * alpha[i] + ss[i]
                        T.copy(acc_s, pcast)
                    for k in T.serial(nu, eff):
                        sk = k % nsK
                        svp = (k - 1) % nsV
                        T.sync_threads(my_bar, NMMA)
                        T.mbarrier_wait_parity(kready[sk], (k // nsK) % 2)
                        T.wgmma_gemm(Qs[1, :, :], Ks[sk, :, :], acc_s, transpose_B=True, policy=Pol, clear_accum=True)
                        for i, j in T.Parallel(half, D):  # pipelined rescale (prev alpha)
                            acc_o[i, j] *= alpha[i]
                        T.mbarrier_wait_parity(vready[svp], ((k - 1) // nsV) % 2)
                        T.wgmma_gemm(pcast, Vs[svp, :, :], acc_o, policy=Pol, clear_accum=False)
                        T.named_barrier_arrive(nxt_bar, NMMA)
                        T.wait_wgmma(1)
                        T.mbarrier_arrive(kfree[sk])
                        for i, j in T.Parallel(half, block_N):
                            acc_s[i, j] = T.if_then_else(
                                q0 + r0 + i + co >= k * block_N + j, acc_s[i, j], -T.infinity(accum))
                        T.copy(sm, smp)
                        T.reduce_max(acc_s, sm, dim=1, clear=False)
                        for i in T.Parallel(half):
                            alpha[i] = T.exp2(smp[i] * scale - sm[i] * scale)
                        for i, j in T.Parallel(half, block_N):
                            acc_s[i, j] = T.exp2(acc_s[i, j] * scale - sm[i] * scale)
                        T.reduce_sum(acc_s, ss, dim=1)
                        T.wait_wgmma(0)
                        T.mbarrier_arrive(vfree[svp])
                        for i in T.Parallel(half):
                            logsum[i] = logsum[i] * alpha[i] + ss[i]
                        T.copy(acc_s, pcast)

                    svp = (eff - 1) % nsV
                    for i, j in T.Parallel(half, D):  # pipelined rescale: final alpha before last PV
                        acc_o[i, j] *= alpha[i]
                    T.mbarrier_wait_parity(vready[svp], ((eff - 1) // nsV) % 2)
                    T.wgmma_gemm(pcast, Vs[svp, :, :], acc_o, policy=Pol, clear_accum=False)
                    T.wait_wgmma(0)
                    T.mbarrier_arrive(vfree[svp])
                    for i, j in T.Parallel(half, D):
                        acc_o[i, j] /= logsum[i]
                    T.copy(acc_o, Os[1, :, :])                                  # registers -> smem
                    T.copy(Os[1, :, :], O[bz, q0 + r0:q0 + r0 + half, by, :])  # smem -> global (coalesced)

    return main

Logical order vs. actual issue timing

For online softmax, the logical order for each K-block is a fixed sequence of four steps:

QK → softmax (update max, compute P, update sum) → rescale (acc_o *= exp(m_prev − m_new), correcting the accumulated output to the new max) → PV (acc_o += P·V)

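To obtain the two kinds of overlap described above, however, the actual kernel breaks these four steps apart and issues them staggered: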

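  • delayed PV: iteration j issues the previous block's PV(j-1), so that its wgmma can hide underneath this iteration's softmax(j) (intra-warpgroup overlap).
  • carried-alpha rescale: the rescale uses the alpha computed by the previous iteration's softmax, and is moved to after QK and before PV so that it covers the QK latency.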

As a result, one physical iteration j mixes operations from two logical tiles:

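Block in the figure | Logical owner
QK·j, softmax·j | first half of tile j: compute this block's scores and P
rescale·(j-1), PV·(j-1) | second half of tile j-1: correct with the previous block's alpha, then accumulate the previous block's P·V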

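Two boundary cases: tile 0 only does QK·0 and softmax·0 in the prologue (there is no history to accumulate, so no rescale/PV is needed); rescale starts from rescale·0, but that one is a no-op with alpha=1 (acc_o is still 0 at that point), and the first rescale that actually has an effect is rescale·1.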
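The staggered order (prologue, delayed PV, carried alpha, epilogue) can be checked numerically with a small plain-Python emulation for one query row. This is only a sketch under simplifying assumptions: `pipelined_attention_row` is a hypothetical name, there is no causal mask, and `math.exp` replaces the kernel's scaled `exp2`. It mirrors only the order of operations, not the asynchrony.

```python
import math

def dot(a, b):
    return sum(x * y for x, y in zip(a, b))

def pipelined_attention_row(q, K, V, block):
    """One query row, processed in the staggered order: delayed PV, carried alpha."""
    n_tiles = len(K) // block
    d = len(V[0])
    acc_o = [0.0] * d
    alpha = 1.0  # carried alpha: the first rescale is a no-op

    def pv(p, tile):
        # acc_o += P . V for one tile of V
        for j, p_j in enumerate(p):
            v_row = V[tile * block + j]
            for c in range(d):
                acc_o[c] += p_j * v_row[c]

    # Prologue: QK.0 and softmax.0 only (no rescale, no PV).
    s = [dot(q, K[j]) for j in range(block)]
    m = max(s)
    p = [math.exp(x - m) for x in s]
    logsum = sum(p)

    for k in range(1, n_tiles):
        s = [dot(q, K[k * block + j]) for j in range(block)]  # QK.k
        for c in range(d):                                     # rescale.(k-1)
            acc_o[c] *= alpha
        pv(p, k - 1)                                           # PV.(k-1)
        m_prev, m = m, max(m, max(s))                          # softmax.k
        alpha = math.exp(m_prev - m)
        p = [math.exp(x - m) for x in s]
        logsum = logsum * alpha + sum(p)

    # Epilogue: final rescale and last PV, then normalize once by the sum.
    for c in range(d):
        acc_o[c] *= alpha
    pv(p, n_tiles - 1)
    return [o / logsum for o in acc_o]
```

Each loop iteration touches tile k (QK, softmax) and tile k-1 (rescale, PV), matching the table above, and the output is divided by the running sum only once at the end, as the kernel does with `acc_o /= logsum`.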

What the pipeline looks like

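This staggered issue order is easiest to understand when drawn as a timeline. I wrote an instrumented version that records a clock64() timestamp into a log tensor at every tile-level operation boundary in the three warpgroups, ran a single CTA (B=H=1, 4 K-tiles), and drew the timeline as an interactive full-screen figure (scroll to pan left and right, use the zoom in/out buttons to stretch the time axis, hover to see the cycle count of each block). The figure is wide and would be cramped inline, so it lives on its own page: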

▶ Open the interactive pipeline diagram (new tab, full screen)

How to read the figure:

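  • Each consumer WG has two swimlanes: the upper one, Consumer WGx (SIMT), shows what the warpgroup's threads are doing in program order (the various waits, rescale, softmax); the lower one, ⤷ WGx tensor core, shows the asynchronous in-flight interval of each wgmma (from issue to the wait that drains it).
  • Intra-warpgroup asynchrony shows up as bars in the tensor-core lane sitting underneath the SIMT lane above. For example, PV·(k-1) lies entirely under softmax·k, meaning PV runs on the tensor cores while the threads compute softmax.
  • softmax dominates (about 2000 cycles per block), while issuing QK/PV takes only an instant; the real matmul time is hidden underneath softmax.
  • Purple arrows are ping-pong handoffs: a WG finishes issuing its wgmma → arrive → releases the peer's sync (the tiny wait peer block at the start of each row). The midpoint of each arrow is labeled with whether bar1 or bar2 is used.
  • The producer spends long stretches in wait buffer free, which shows that TMA is far from being the bottleneck: it has the data loaded early and is waiting for the consumers to free a buffer.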

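This is the instrumented version: the clock64 reads and the extra global stores slightly perturb the timing, so do not read absolute performance from it. A clock sampled from a single thread also cannot be aligned well enough to show barrier causality across warpgroups (occasionally an arrow points backward, with the arrive timestamp later than the peer's sync). The figure is suited to examining structure and overlap, not to measuring exact latencies. The full instrumented kernel and plotting script accompany the tilelang_fa3.py shown here.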