Gated Delta Net 学习

推理框架
Published

April 21, 2026

学习 Qwen3.5-35B-A3B 的 GDN 实现,本文的目标是用 numpy reference 把 GDN 的每一步都手动走一遍,尤其是理解 “delta rule” 到底在做什么、为什么 state 可以是固定大小的矩阵而不是随序列增长的 KV cache。

1. GDN的设计动机

总所周知,Full Attention decode 的瓶颈在于 KV cache 随序列长度线性增长:每新增一个 token,cache 多一份 (k, v),计算时要重新扫全部历史。Linear Attention 想解决的正是这个问题:它保留“query 按 key 相似度去汇总 value”这个核心思想,但不再显式保留全部历史 token,而是把历史压缩进一个固定大小的状态里。

机制 标准 Attention / FlashAttention Linear Attention
历史表示 显式保留全部 \(K, V\) 压缩成固定状态 \(S\)
读取公式 \(o = \mathrm{softmax}(qK^\top)V\) \(o = qS\)
相似度权重 先算全部 \(q \cdot k_i\),再做 softmax 归一化 \(S\) 的构造隐式决定
decode 代价 随历史长度增长 与历史长度解耦

它的核心想法是用一个固定大小的矩阵 \(S \in \mathbb{R}^{d_k \times d_v}\) 当“可用向量寻址的哈希表”。这里的“哈希表”只是类比,不是真的离散 dict;更准确地说,\(S\) 是一个可微分的 key-value 存储。历史 token 把自己的 \((k_i, v_i)\) 写进去,之后当前 query \(q\) 再按和各个 key 的相似度把相关的 value 读出来。

其中:

  • \(q\) 是当前 token 的 query
  • \(k_i\) 是第 \(i\) 个历史 token 的 key
  • \(v_i\) 是第 \(i\) 个历史 token 的 value
  • \(S\) 是把历史信息累计压缩后的记忆矩阵

对应的写入和读取公式是:

  • 写入:\(S \mathrel{+}= k \otimes v\)
  • 读取:\(o = q \cdot S\)

先看 \(S\) 是怎么来的。每来一个历史 token,就把它的 \((k_i, v_i)\) 用一次外积写进状态:

\[ S \mathrel{+}= k_i \otimes v_i \]

所以写完整段历史之后,状态矩阵就是:

\[ S = \sum_i k_i \otimes v_i \]

这时读取公式虽然写成 \(o = q \cdot S\),但把上面的 \(S\) 展开后,就得到:

\[ o = qS = q\left(\sum_i k_i \otimes v_i\right) = \sum_i (q \cdot k_i) v_i \]

也就是说,实现上读取就是 qS;但语义上它等价于“让 query 和每个历史 key 做一次相似度计算,再用这个分数去加权对应的 value”。如果不同历史 token 的 \(k\) 互相接近正交,那么这种召回会比较干净;这就是 Hopfield 关联记忆 / Widrow-Hoff 一类方法背后的直觉。

问题也随之而来。朴素 Linear Attention 至少有三个缺点:

  • 容量有限:因为 \(S = \sum_i k_i \otimes v_i\) 必须把全部历史都压进一个固定大小的矩阵里,历史一长,不同 token 的贡献就会在同一个状态里叠加。读取时 \(o = \sum_i (q \cdot k_i)v_i\),如果有很多 \(k_i\) 都和当前 \(q\) 有非零相似度,就会有很多无关项一起混进输出,形成干扰。这里不一定只是“应该只看最近几个 token”,更准确地说,是模型想选中的少数记忆,无法总是和其它记忆干净分离
  • 没有覆盖:如果两个 token 落在相近的可寻址方向上,第二次写入不会替换第一次,而只会继续累加。例如连续两次写入同一个 key \(k\),得到 \(S = k \otimes v_1 + k \otimes v_2\);此时再用 \(q = k\) 去读,结果更接近 \((k \cdot k)(v_1 + v_2)\),而不是只返回最新的 \(v_2\)
  • 没有遗忘:对朴素写法来说,状态更新只有 \(S_t = S_{t-1} + k_t \otimes v_t\),没有任何衰减项。因此只要写进去过,旧信息就会一直留在 \(S\) 里,不会自己变淡。

GDN 本质上就是针对后两个问题加了两种机制: - delta rule,解决 没有覆盖 问题,让新信息更像“覆盖”而不是“累加”。 - gate,解决 没有遗忘 问题,让旧记忆逐步淡出

2. 没有覆盖问题:朴素 Linear Attention vs Delta Rule

先设计一个实验验证这个问题,也就是朴素线性写入会把两次记忆混在一起

同一个 key \(k\) 连续写入两个不同的 value,然后用 \(k\) 作为 query 去查。理想行为应该是召回第二次写的那个 \(v_2\)(覆盖语义),而不是 \(v_1 + v_2\)(累加)。

下面开始实验:

Code
import numpy as np
np.set_printoptions(precision=3, suppress=True)

# 先造一个可控的 key 空间: 两个正交的单位向量
k_a = np.array([1.0, 0.0, 0.0, 0.0])  # key A
k_b = np.array([0.0, 1.0, 0.0, 0.0])  # key B (与 A 正交)

# 要存进去的 values
v_a1 = np.array([10.0, 0.0, 0.0, 0.0])    # 第一次给 key A 存的 value
v_a2 = np.array([0.0,  0.0, 99.0, 0.0])    # 第二次给 key A 存的 value (想覆盖 v_a1)
v_b  = np.array([0.0, -5.0, 0.0, 0.0])    # 给 key B 存的

d_k, d_v = 4, 4

2.1 朴素 Linear Attention: S += k ⊗ v

Code
def naive_write(S, k, v):
    """朴素累加式写入: 没有覆盖机制"""
    return S + np.outer(k, v)

def read(S, q):
    """读取: o = q · S"""
    return q @ S

S = np.zeros((d_k, d_v))
S = naive_write(S, k_a, v_a1)  # 存 (A, v_a1)
S = naive_write(S, k_b, v_b)   # 存 (B, v_b)
print('用 k_a 写入:', v_a1 , '查询得到: ', read(S, k_a))
print('用 k_b 写入:', v_b , '查询得到: ', read(S, k_b))

S = naive_write(S, k_a, v_a2)  # 再次存 (A, v_a2) —— 期望覆盖 v_a1
print('用 k_a 写入:', v_a2 , '查询得到: ', read(S, k_a), "❌ 查询结果和写入不一致, 这就是没有覆盖问题!")
print()
用 k_a 写入: [10.  0.  0.  0.] 查询得到:  [10.  0.  0.  0.]
用 k_b 写入: [ 0. -5.  0.  0.] 查询得到:  [ 0. -5.  0.  0.]
用 k_a 写入: [ 0.  0. 99.  0.] 查询得到:  [10.  0. 99.  0.] ❌ 查询结果和写入不一致, 这就是没有覆盖问题!

2.2 Delta Rule: 先读、减、再写

核心只有 3 行。注意这里的“读”不是正常推理阶段用 query 去读,而是在写入当前 (k, v) 之前,先临时拿这个 key 自己去查一下:如果现在就用 k 查询,旧状态 \(S\) 会返回什么?

retrieval = k @ S            # 用当前 key 当 query,先看 S 现在会返回什么
delta     = v - retrieval    # 我们想存 v,和当前返回差多少?
S         = S + k ⊗ delta    # 只把差值写回

展开来看是 \(S \leftarrow S + k \otimes (v - k^\top S) = (I - k k^\top) S + k \otimes v\),这是 Widrow-Hoff 最小均方(LMS)学习律的一步梯度下降。

Code
def delta_write(S, k, v):
    """Delta rule 写入: 先用当前 key 读一次,再写 delta"""
    retrieval = k @ S                    # [d_v] 这里把当前 key 临时当成 query,检查旧记忆对它的返回值
    delta     = v - retrieval            # [d_v] —— 这就是 'delta' 一词的来历
    return S + np.outer(k, delta)

S = np.zeros((d_k, d_v))
S = delta_write(S, k_a, v_a1)
S = delta_write(S, k_b, v_b)
print('用 k_a 写入:', v_a1 , '查询得到: ', read(S, k_a))
print('用 k_b 写入:', v_b , '查询得到: ', read(S, k_b))

S = delta_write(S, k_a, v_a2)   # 再次存 (A, v_a2)
print('用 k_a 写入:', v_a2 , '查询得到: ', read(S, k_a), '✅ 这次查询结果和写入一致,说明旧值被覆盖了')
print()
用 k_a 写入: [10.  0.  0.  0.] 查询得到:  [10.  0.  0.  0.]
用 k_b 写入: [ 0. -5.  0.  0.] 查询得到:  [ 0. -5.  0.  0.]
用 k_a 写入: [ 0.  0. 99.  0.] 查询得到:  [ 0.  0. 99.  0.] ✅ 这次查询结果和写入一致,说明旧值被覆盖了

Delta rule 和朴素线性注意力的区别就是写入前先读一次,只写入当前 retrieval 和目标 v 的差。代价是每次 update 多一次 \(k \cdot S\) 的乘法(\(O(d_k \cdot d_v)\)),收益是 \(S\) 真正拥有了”覆盖”语义,不会随 token 数堆积污染。

3. 没有遗忘问题:加上 Gate,让旧记忆逐步淡出

如果历史里 1000 个 token 的 key 几乎都不相同,\(S\) 依然会被塞满、互相干扰。解决办法就是在 delta rule 前面再加一个 gate,让旧记忆在每一步先整体衰减一次。如果把它拆成 3 个动作,其实并不复杂:

  1. 先衰减旧状态:\(S'_t = \alpha_t S_{t-1}\)
  2. 再读取当前 key 在衰减后状态上的返回值:\(r_t = k_t^\top S'_t\)
  3. 最后按 delta rule 写回:\(S_t = S'_t + \beta_t \, k_t \otimes (v_t - r_t)\)

把第 1 步和第 3 步合起来,就是论文里常见的一行式:

\[ S_t = \alpha_t \cdot (I - \beta_t\, k_t k_t^\top)\, S_{t-1} + \beta_t \cdot k_t \otimes v_t \]

这个式子可以直接从上面三步展开得到:

\[ S_t = S'_t + \beta_t \, k_t \otimes (v_t - k_t^\top S'_t) \]

\[ = \alpha_t S_{t-1} + \beta_t \, k_t \otimes v_t - \beta_t \, k_t \otimes (k_t^\top \alpha_t S_{t-1}) \]

\[ = \alpha_t (I - \beta_t k_t k_t^\top) S_{t-1} + \beta_t \, k_t \otimes v_t \]

所以这一个大公式对应就是Delta Rule + Gate。 其中:

  • \(\alpha_t \in (0,1)\)global gate,对所有记忆统一做指数衰减。由输入 \(x_t\) 经过 in_proj_a → softplus → exp(-exp(A_log)·softplus(...)) 产生,per head 一个标量
  • \(\beta_t \in (0,1)\)delta strength,控制本次 delta 写入的强度。由输入 \(x_t\) 经过 in_proj_b → sigmoid 产生,per head 一个标量

整个 update 一步里同时做了两件事:

  1. 所有旧记忆按 \(\alpha_t\) 淡化
  2. \(k_t\) 方向上按 \(\beta_t\) 做 delta 覆盖

下面做一个实验:存完 (k_a, v_a1) 后间隔很多步空转(只做 gate 衰减、不写任何东西),看 \(S\)\(k_a\) 的召回强度如何指数衰减。

Code
def gated_delta_write(S, k, v, alpha, beta):
    """Gated Delta Rule (一行代码概括 GDN 的核心)."""
    S_decayed = alpha * S
    retrieval = k @ S_decayed
    delta     = v - retrieval
    return S_decayed + beta * np.outer(k, delta)

alpha = 0.9   # 假设每步保留 90% 旧记忆
beta  = 1.0   # delta 全强度写入

S = np.zeros((d_k, d_v))
S = gated_delta_write(S, k_a, v_a1, alpha, beta)
print('step 0  (刚写入 k_a):', read(S, k_a))

# 空转 20 步 (只 gate, 不 write any key)
for t in range(1, 21):
    S = alpha * S
    if t in (1, 5, 10, 20):
        print(f'step {t:>2} (空转):      {read(S, k_a)}   # 应约等于 v_a1 × {alpha**t:.3f}')
step 0  (刚写入 k_a): [10.  0.  0.  0.]
step  1 (空转):      [9. 0. 0. 0.]   # 应约等于 v_a1 × 0.900
step  5 (空转):      [5.905 0.    0.    0.   ]   # 应约等于 v_a1 × 0.590
step 10 (空转):      [3.487 0.    0.    0.   ]   # 应约等于 v_a1 × 0.349
step 20 (空转):      [1.216 0.    0.    0.   ]   # 应约等于 v_a1 × 0.122

可以看到 \(S\)\(k_a\) 的召回值就是 \(v_{a1} \times \alpha^t\) —— 相当于给每条记忆贴了一个指数衰减的时间戳

在实际场景下,\(\alpha\) 不是常数 0.9,而是每个 token、每个 head 独立由网络决定,这就是 Mamba-style selective 机制:

\[ \alpha_t = \exp\bigl(-\exp(A_{\log}) \cdot \mathrm{softplus}(a_t + \mathrm{dt\_bias})\bigr) \]

  • A_log 是学习到的 per-head 常数(log-decay rate,保证 exp 后恒正)
  • a_t = in_proj_a(x_t) 是依赖输入的标量
  • 整体形式保证 \(\alpha_t \in (0, 1)\),且网络可以根据 token 内容决定忘快还是忘慢

4. 完整的 GDN 层实现

下面直接按 HF Qwen3_5MoeForConditionalGeneration 的 linear attention decode 路径写一个 numpy 版单步 reference。这里随机初始化了权重,但投影拆分方式、张量布局和数据流都尽量贴近真实实现;整段 prefill 路径先不展开,只聚焦 decode step = 1

  1. in_proj_qkv 只负责产生 q/k/vin_proj_zin_proj_ain_proj_b 是三组独立投影。

  2. q/k 的 head 数是 16v/state 的 head 数是 32,因此需要通过 GQA 的 repeat_interleaveq/k16 个 head 扩到 32 个 head。

  3. conv1d 作用在拼接后的 QKV 上。按当前 HF 这版实现,cache 里保存的 conv_state 形状是 [batch, conv_C, kernel],对本模型就是 [batch, 8192, 4];卷积后立刻接一个 SiLU

  4. A_logdt_biasin_proj_ain_proj_b 都是 per value headHv = 32),不是 per key head。

  5. RMSNormGated(·silu(z)) 也是按 HF 的形式做:先对每个 value head 的 head_v_dim 做 RMSNorm,再乘上 silu(z)

Code
# ============================================================
# HF Qwen3.5-35B-A3B linear attention config
# ============================================================
B        = 1
H        = 2048
Hk       = 16     # linear_num_key_heads
Hv       = 32     # linear_num_value_heads
KV_REPEAT = Hv // Hk
D_k      = 128    # linear_key_head_dim
D_v      = 128    # linear_value_head_dim
K_conv   = 4      # linear_conv_kernel_dim

Qd = Hk * D_k
Kd = Hk * D_k
Vd = Hv * D_v
conv_C = Qd + Kd + Vd

# ============================================================
# Weights (随机初始化代替 HF safetensors)
# ============================================================
np.random.seed(0)
W_qkv = (np.random.randn(H, Qd + Kd + Vd).astype(np.float32) * 0.02)
W_z   = (np.random.randn(H, Vd).astype(np.float32) * 0.02)
W_a   = (np.random.randn(H, Hv).astype(np.float32) * 0.02)
W_b   = (np.random.randn(H, Hv).astype(np.float32) * 0.02)
# HF conv1d.weight 的形状是 [conv_C, 1, K_conv],这里 squeeze 成 [conv_C, K_conv]
W_conv = (np.random.randn(conv_C, K_conv).astype(np.float32) * 0.02)
A_log = (np.random.randn(Hv).astype(np.float32) * 0.02)
dt_bias = (np.random.randn(Hv).astype(np.float32) * 0.02)
norm_weight = np.ones(D_v, dtype=np.float32)
W_o = (np.random.randn(Vd, H).astype(np.float32) * 0.02)

print(f'in_proj_qkv: {W_qkv.shape}   (hidden -> Q|K|V = {Qd}|{Kd}|{Vd})')
print(f'in_proj_z:   {W_z.shape}     (hidden -> z = {Vd})')
print(f'in_proj_a:   {W_a.shape}     (hidden -> a, per-value-head scalars)')
print(f'in_proj_b:   {W_b.shape}     (hidden -> b, per-value-head scalars)')
print(f'conv1d:      {W_conv.shape}  (depthwise, squeezed from [conv_C, 1, {K_conv}])')
print(f'GQA repeat:  {KV_REPEAT}       (Hk={Hk} -> Hv={Hv})')
print(f'A_log:       {A_log.shape}     dt_bias: {dt_bias.shape}')
in_proj_qkvz: (2048, 12288)   (hidden -> Q|K|V|z = 2048|2048|4096|4096)
in_proj_ba:   (2048, 32)     (hidden -> b|a, per-key-head scalars)
conv1d:       (4, 8192)   (depthwise, kernel=4)
GQA repeat:   2       (Hk=16 -> Hv=32)
A_log:        (16,)     dt_bias: (16,)

4.1 单步推理过程

注意这里我们只看 单步 decodeSconv_state 都是跨 step 持活的缓存,它们的尺寸固定,不随序列长度增长:

  • S per layer per request: [batch, Hv, D_k, D_v] —— 每个 value head 一份 GDN 状态
  • conv_state per layer per request: [batch, conv_C, K_conv] —— 按当前 HF 这版 Qwen3_5MoeDynamicCache,这里实际缓存的是长度 K_conv 的卷积窗口
  • q/k 先按 [batch, seq, Hk, D_k] 投影,再通过 repeat_interleave(KV_REPEAT) 扩到 [batch, seq, Hv, D_k]
  • a/b/A_log/dt_bias 都是 per value head,因此它们的 head 维也是 Hv = 32
  • 路径上要区分:prefill 时会先把 conv_state 写好,但卷积本身走整段 conv1d+siludecode 时才会真正读取已有 conv_state,走 causal_conv1d_update
Code
def softplus(x):
    return np.log1p(np.exp(x))

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def silu(x):
    return x * sigmoid(x)

def rmsnorm_lastdim(x, eps=1e-6):
    return x / np.sqrt((x**2).mean(axis=-1, keepdims=True) + eps)

def repeat_interleave_heads(x, repeat):
    """x: [B, S, Hk, D]f32 -> [B, S, Hv, D]f32"""
    return np.repeat(x, repeat, axis=2)

def depthwise_conv1d_update(mixed_qkv, conv_state):
    """严格按 HF `torch_causal_conv1d_update` 的逻辑展开。"""
    # mixed_qkv: [B, conv_C, 1]f32
    # conv_state: [B, conv_C, K_conv]f32
    state_len = conv_state.shape[-1]
    hidden_states_new = np.concatenate([conv_state, mixed_qkv], axis=-1)        # [B, conv_C, K_conv+1]f32
    conv_state = hidden_states_new[:, :, -state_len:]                           # [B, conv_C, K_conv]f32
    out = np.empty((B, conv_C, hidden_states_new.shape[-1] - K_conv + 1), dtype=np.float32)
    for t in range(out.shape[-1]):
        window = hidden_states_new[:, :, t : t + K_conv]                        # [B, conv_C, K_conv]f32
        out[:, :, t] = (window * W_conv[None, :, :]).sum(axis=-1)              # [B, conv_C]f32
    out = silu(out[:, :, -mixed_qkv.shape[-1]:])                                # [B, conv_C, 1]f32
    return out, conv_state

def rmsnorm_gated(core_attn_out, z):
    # core_attn_out / z: [B, S, Hv, D_v]f32
    normed = rmsnorm_lastdim(core_attn_out.astype(np.float32))          # [B, S, Hv, D_v]f32
    gated = normed * silu(z.astype(np.float32))                         # [B, S, Hv, D_v]f32
    return gated * norm_weight[None, None, None, :]                     # [B, S, Hv, D_v]f32

def gdn_decode_step(hidden_states, S, conv_state):
    """一个和 HF linear_attn decode 数据流对齐的单步 reference."""
    # hidden_states: [B, 1, H]f32
    hidden_states = hidden_states.astype(np.float32)
    hidden_states = rmsnorm_lastdim(hidden_states)                      # [B, 1, H]f32

    # --- L2: four independent projections ---
    qkv = hidden_states @ W_qkv                                         # [B, 1, Qd+Kd+Vd]f32
    z = (hidden_states @ W_z).reshape(B, 1, Hv, D_v)                    # [B, 1, Hv, D_v]f32
    a = hidden_states @ W_a                                             # [B, 1, Hv]f32
    b = hidden_states @ W_b                                             # [B, 1, Hv]f32

    # --- L3: causal depthwise conv1d + silu on concatenated QKV ---
    mixed_qkv = np.transpose(qkv, (0, 2, 1))                            # [B, conv_C, 1]f32
    mixed_qkv, conv_state = depthwise_conv1d_update(mixed_qkv, conv_state)
    mixed_qkv = np.transpose(mixed_qkv.astype(np.float32), (0, 2, 1))   # [B, 1, conv_C]f32
    query, key, value = np.split(mixed_qkv, [Qd, Qd + Kd], axis=-1)     # [B,1,Qd], [B,1,Kd], [B,1,Vd]
    query = query.reshape(B, 1, Hk, D_k)                                # [B, 1, Hk, D_k]f32
    key = key.reshape(B, 1, Hk, D_k)                                    # [B, 1, Hk, D_k]f32
    value = value.reshape(B, 1, Hv, D_v)                                # [B, 1, Hv, D_v]f32

    # --- L3.5: GQA repeat_interleave from Hk to Hv ---
    query = repeat_interleave_heads(query, KV_REPEAT)                   # [B, 1, Hv, D_k]f32
    key = repeat_interleave_heads(key, KV_REPEAT)                       # [B, 1, Hv, D_k]f32

    # --- L4: Gated Delta Rule ---
    # HF 里直接算 g = -exp(A_log) * softplus(a + dt_bias);这里显式转成 alpha = exp(g)
    g = -np.exp(A_log)[None, None, :] * softplus(a + dt_bias[None, None, :])  # [B, 1, Hv]f32
    beta = sigmoid(b)                                                   # [B, 1, Hv]f32

    core_attn_out = np.empty((B, 1, Hv, D_v), dtype=np.float32)         # [B, 1, Hv, D_v]f32
    for h in range(Hv):
        alpha_h = np.exp(g[0, 0, h]).astype(np.float32)                 # scalar
        S_dec = alpha_h * S[0, h]                                       # [D_k, D_v]f32
        retrieval = key[0, 0, h] @ S_dec                                # [D_v]f32
        delta = value[0, 0, h] - retrieval                              # [D_v]f32
        S[0, h] = S_dec + beta[0, 0, h] * np.outer(key[0, 0, h], delta) # [D_k, D_v]f32
        core_attn_out[0, 0, h] = query[0, 0, h] @ S[0, h]              # [D_v]f32

    # --- L4n: RMSNormGated ---
    out = rmsnorm_gated(core_attn_out, z)                               # [B, 1, Hv, D_v]f32

    # --- L6: out_proj ---
    y = out.reshape(B, 1, Vd) @ W_o                                     # [B, 1, H]f32
    return y.astype(np.float32), S, conv_state

# 初始化持久化 state
S = np.zeros((B, Hv, D_k, D_v), dtype=np.float32)                       # [B, Hv, D_k, D_v]f32
conv_state = np.zeros((B, conv_C, K_conv), dtype=np.float32)            # [B, conv_C, K_conv]f32

# 跑 3 个 decode step
for t in range(3):
    hidden_states = np.random.randn(B, 1, H).astype(np.float32)         # [B, 1, H]f32
    y, S, conv_state = gdn_decode_step(hidden_states, S, conv_state)
    print(f'step {t}: y.shape={y.shape}, conv_state.shape={conv_state.shape}, ||S||_F={np.linalg.norm(S):.3f}')
step 0: y.shape=(2048,), dtype=float32, ||S||_F=0.039
step 1: y.shape=(2048,), dtype=float32, ||S||_F=0.069
step 2: y.shape=(2048,), dtype=float32, ||S||_F=0.104

4.2 decode 时 conv_state 更新逻辑

HF 所实现的 torch_causal_conv1d_update,decode 路径的核心逻辑:

hidden_states_new = torch.cat([conv_state, hidden_states], dim=-1)
conv_state.copy_(hidden_states_new[:, :, -state_len:])
out = F.conv1d(hidden_states_new, weight.unsqueeze(1), bias, padding=0, groups=hidden_size)
out = F.silu(out[:, :, -seq_len:])

设当前是单步 decode,seq_len = 1,并且当前 HF cache 里保存的 conv_state 形状是 [batch, conv_C, K_conv]。对本模型来说就是 [1, 8192, 4]。这时:

  • 当前 token 的 mixed_qkv 形状是 [1, 8192, 1]
  • 先和旧 conv_state 在时间维拼起来,得到 hidden_states_new: [1, 8192, 5]
  • 然后把最后 state_len=4 个位置拷回 conv_state

也就是说,假设某个通道原来的 cache 是:

\[ [s_0, s_1, s_2, s_3] \]

当前 token 的这个通道值是:

\[ [x_t] \]

那么拼接后就是:

\[ [s_0, s_1, s_2, s_3, x_t] \]

更新后的 conv_state 会变成最后 4 个位置:

\[ [s_1, s_2, s_3, x_t] \]

接着 F.conv1d(...) 会在长度为 5 的序列上、用 kernel size 4 做 depthwise 卷积,所以每个通道会产生两个输出位置:

  1. 窗口 [s_0, s_1, s_2, s_3]
  2. 窗口 [s_1, s_2, s_3, x_t]

而 decode 只关心当前 token 对应的最后一个输出位置,所以源码才会取:

out = F.silu(out[:, :, -seq_len:])

因为这里 seq_len=1,所以最终保留下来的就是第二个窗口的卷积结果,也就是当前 token 真正送进后续 Gated Delta Rule 的 mixed_qkv

5. 容量实验:\(S\) 到底能存多少个 (k, v) 对?

线性注意力的根本代价是 \(S\) 是固定尺寸的矩阵,必然有容量上限。直觉上:

  • 如果所有 key 互相正交,delta rule 能清晰存下 \(d_k\) 对 (k,v),召回几乎完美
  • 如果 key 是随机不正交的,容量会显著低于 \(d_k\),delta 的优势也会被削弱

下面分两个场景验证。\(d_k = d_v = 16\),逐步增加 key 数量 \(N\),测召回和真实 \(v\) 的 cosine similarity 均值。

Code
def capacity_test(N, d_k=16, d_v=16, use_delta=True, orthogonal=False, seed=0):
    rng = np.random.default_rng(seed)
    if orthogonal:
        # 用 Gram-Schmidt 造一组正交单位向量 (最多 d_k 个)
        assert N <= d_k
        keys = np.linalg.qr(rng.standard_normal((d_k, d_k)))[0][:N]
    else:
        keys  = rng.standard_normal((N, d_k))
        keys /= np.linalg.norm(keys, axis=1, keepdims=True)
    vals = rng.standard_normal((N, d_v))

    S = np.zeros((d_k, d_v))
    for k, v in zip(keys, vals):
        if use_delta:
            S = S + np.outer(k, v - k @ S)
        else:
            S = S + np.outer(k, v)

    recalled = keys @ S
    cos = (recalled * vals).sum(-1) / (
        np.linalg.norm(recalled, axis=-1) * np.linalg.norm(vals, axis=-1) + 1e-8
    )
    return cos.mean()

print('=== 情景 A: 完全正交 key (理想情况, d_k=16 所以 N<=16) ===')
print(f'{"N":>4}  {"delta":>8}  {"naive":>8}')
for N in [1, 4, 8, 12, 16]:
    d = capacity_test(N, use_delta=True,  orthogonal=True)
    n = capacity_test(N, use_delta=False, orthogonal=True)
    print(f'{N:>4}  {d:>8.3f}  {n:>8.3f}')

print()
print('=== 情景 B: 随机单位 key (更接近真实) ===')
print(f'{"N":>4}  {"delta":>8}  {"naive":>8}')
for N in [1, 4, 8, 16, 32, 64]:
    d = capacity_test(N, use_delta=True,  orthogonal=False)
    n = capacity_test(N, use_delta=False, orthogonal=False)
    print(f'{N:>4}  {d:>8.3f}  {n:>8.3f}')
=== 情景 A: 完全正交 key (理想情况, d_k=16 所以 N<=16) ===
   N     delta     naive
   1     1.000     1.000
   4     1.000     1.000
   8     1.000     1.000
  12     1.000     1.000
  16     1.000     1.000

=== 情景 B: 随机单位 key (更接近真实) ===
   N     delta     naive
   1     1.000     1.000
   4     0.967     0.934
   8     0.774     0.802
  16     0.715     0.732
  32     0.472     0.625
  64     0.231     0.438

观察:

情景 A (正交 key): delta rule 在 \(N \leq d_k\) 时召回完美 (\(\cos = 1\)),而朴素线性注意力完美召回直到 \(N = d_k\) —— 因为正交 key 互不干扰,连累加都不会污染(\(q \cdot S = \sum_i \delta_{ij} v_i = v_j\))。

情景 B (随机 key): delta 和 naive 差距不大,两者都随 \(N\) 增加而退化。这里的真实优势要等到同一个 key 被多次写入(像 §2 那种覆盖场景)才体现出来。

所以 delta rule 的核心价值其实不是提高容量,而是允许覆盖:网络可以学到用类似的 \(k\) 重复写入来更新某个 slot 的内容,而不是被动累积。结合 §3 的 gate 之后,GDN 就既有:

  • 定向覆盖\(k\) 相似时 delta rule 生效)
  • 无方向遗忘(每步全局指数衰减)

但要注意:这两件事都不能从根本上消除“容量有限”。gate 和 delta rule 解决的是“怎么管理已有容量”,不是“把容量变成无限大”。真正应对容量上限的办法主要有两层:

  • 在 GDN 内部,模型通过 遗忘 + 覆盖 把有限状态尽量留给更重要、更新鲜的信息
  • 在整体架构上,再混入周期性的 Full Attention,负责那些 GDN 无法精确保存的长距离 token-level 检索

也就是说,GDN 解决的是“有限工作记忆如何更有效地用”,而不是单独解决“固定状态如何精确记住任意长上下文”。两种记忆管理手段,这是它在实用中超过纯 GLA / Mamba2 的原因。

关于混合架构\(S\) 的容量约为 \(O(d_k)\) 的工作记忆,无法精确召回任意远处的 token。Qwen3.5 用 \(d_k = 128\) 的 GDN 承载大部分上下文压缩,每 4 层插 1 层 Full Attention 负责精确长距离查询——这就是 [L,L,L,F]x10 混合架构的工作分工。

参考

  • Yang et al. 2024, Parallelizing Linear Transformers with the Delta Rule over Sequences
  • Yang et al. 2024, Gated Delta Networks: Improving Mamba2 with Delta Rule
  • Hugging Face Qwen/Qwen3.5-35B-A3B