Manual Coding · Attention

手撕经典算法 #1 Attention 篇整理

这份笔记把原文的 Attention 手写代码整理成一条更稳定的学习路线:先理解 SDPA 的张量对象,再扩展到 MHA、KV Cache、MQA/GQA,最后指出原实现里值得面试和工程实践特别留意的 shape、mask 与缓存边界。

5类 Attention 变体
3个核心张量轴
2类 Mask 边界
1统一 GQA/MQA 视角

来源与导入方式

本页是站内整理版,不是外部页面全文镜像。原文提供了多种注意力机制的 PyTorch 示例;这里保留来源脉络,重写关键实现,并补充实现风险、统一抽象和面试复盘清单。

原始材料

hewei2001.pages.dev/Manual-Coding-1,页面题为“手撕经典算法 #1 Attention篇”。

页面状态

原页面标注最后更新为 2025-03-20;本次抓取和整理时间为 2026-05-21。

导入边界

原文 MLA 部分仍为 TODO;本页不虚构 MLA 实现,只说明它与 KV 压缩的关系和后续补充方向。

先看共同结构

所有这些变体都在回答同一个问题:当前位置的 query 应该从哪些 key 对应的 value 中取信息?差异主要来自两个地方:是否拆成多个 head,以及 K/V 是否在多个 query head 之间共享。

\[ \operatorname{Attention}(Q, K, V)=\operatorname{softmax}\left(\frac{QK^\top}{\sqrt{d_k}} + M\right)V \]
Query

当前位置发出的“我要找什么”。在 MHA 中常见形状是:

(batch, heads, q_len, head_dim)
Key

每个历史位置提供的“我有什么特征”。GQA/MQA 会减少 K head 数。

(batch, kv_heads, kv_len, head_dim)
Value

真正被加权汇总的信息内容。它的 head 共享策略通常和 Key 一致。

(batch, kv_heads, kv_len, head_dim)
变体 核心变化 收益 代价/边界
SDPA 直接计算 \(QK^\top\),不关心投影和多头拆分。 最小机制,适合理解公式和 mask。 不是完整 Transformer attention layer。
MHA 把 hidden 拆成多个 query/key/value head,最后 concat 后输出投影。 不同 head 可以学习不同关系。 KV cache 随 head 数线性增长。
MQA 多个 Q head 共享同一组 K/V。 大幅降低 decode 阶段 KV cache 带宽和显存。 共享过强,表达能力可能受影响。
GQA 介于 MHA 和 MQA:多个 Q head 分组共享 K/V。 显存/带宽与效果之间更平衡。 需要保证 query heads 能整除 kv heads。
MLA 通过低秩 latent 表示压缩 KV 相关状态。 进一步降低缓存成本。 原文未展开;完整实现需要 RoPE、低秩投影和 cache 格式一起考虑。

SDPA:把公式写对

缩放点积注意力是后面所有变体的内核。手写时最重要的不是类结构,而是确认 score 的最后两个维度是 \(q\_len \times k\_len\),mask 只加在 key 维度或未来位置上。

推荐的最小实现

import math
import torch


def scaled_dot_product_attention(query, key, value, attn_mask=None):
    """
    query: (batch, heads, q_len, head_dim)
    key:   (batch, heads, k_len, head_dim)
    value: (batch, heads, k_len, head_dim)
    attn_mask: broadcastable to (batch, heads, q_len, k_len);
               True means the position is masked.
    """
    scale = 1.0 / math.sqrt(query.size(-1))
    scores = torch.matmul(query, key.transpose(-1, -2)) * scale

    if attn_mask is not None:
        scores = scores.masked_fill(attn_mask, torch.finfo(scores.dtype).min)

    weights = torch.softmax(scores, dim=-1)
    return torch.matmul(weights, value), weights
为什么这里用函数而不是先写类?

SDPA 是纯计算内核,没有参数。把它写成函数更符合 KISS:输入 Q/K/V 和 mask,输出 attention result。完整 Attention layer 的可学习参数应该放在 MHA/GQA 类里。

Mask 的关键边界

Mask 类型 典型形状 语义 最容易错的点
Causal mask (q_len, k_len)(1, 1, q_len, k_len) 当前位置不能看未来 token。 decode 单步时 q_len=1,k_len=历史长度+1,不再是完整方阵。
Padding mask (batch, k_len) 转成 (batch, 1, 1, k_len) 不让模型关注 padding token。 如果 score 不是四维,盲目 unsqueeze 两次会触发错误广播。

MHA:多头只是换形状,不是换公式

多头注意力的主线是:对输入做 Q/K/V 三个线性投影,把 hidden_size 拆成 num_heads × head_dim,在每个 head 内做 SDPA,再把所有 head 拼回 hidden_size。

\[ Q=XW_q,\quad K=XW_k,\quad V=XW_v,\quad \operatorname{MHA}(X)=\operatorname{Concat}(head_1,\ldots,head_h)W_o \]

更稳健的 MHA 写法

from torch import nn


class MultiHeadAttention(nn.Module):
    def __init__(self, hidden_size, num_heads, dropout=0.0):
        super().__init__()
        if hidden_size % num_heads != 0:
            raise ValueError("hidden_size must be divisible by num_heads")

        self.hidden_size = hidden_size
        self.num_heads = num_heads
        self.head_dim = hidden_size // num_heads

        self.q_proj = nn.Linear(hidden_size, hidden_size)
        self.k_proj = nn.Linear(hidden_size, hidden_size)
        self.v_proj = nn.Linear(hidden_size, hidden_size)
        self.o_proj = nn.Linear(hidden_size, hidden_size)
        self.dropout = nn.Dropout(dropout)

    def _split_heads(self, x):
        batch, seq_len, _ = x.shape
        x = x.view(batch, seq_len, self.num_heads, self.head_dim)
        return x.transpose(1, 2)

    def _merge_heads(self, x):
        batch, _, seq_len, _ = x.shape
        x = x.transpose(1, 2).contiguous()
        return x.view(batch, seq_len, self.hidden_size)

    def forward(self, hidden_states, attn_mask=None):
        query = self._split_heads(self.q_proj(hidden_states))
        key = self._split_heads(self.k_proj(hidden_states))
        value = self._split_heads(self.v_proj(hidden_states))

        output, weights = scaled_dot_product_attention(query, key, value, attn_mask)
        output = self.o_proj(self._merge_heads(output))
        return output, weights
面试解释重点

多头不是把模型“复制多份”,而是把同一个 hidden 表示投影到多个子空间。每个 head 的 score 矩阵独立,最后 concat 再用 \(W_o\) 混合。

工程实现重点

transpose 后再 view 要注意内存连续性;输出合并时用 .contiguous().view(...).reshape(...),不要默认 stride 一定兼容。

KV Cache:真正省的是 decode 阶段的重复 K/V

KV Cache 的语义是:生成第 \(t\) 个 token 时,历史 token 的 K/V 不需要重新投影和保存一遍。注意它主要服务自回归 decode;prefill 阶段通常一次性处理完整 prompt,并建立初始 cache。

单层缓存接口的写法

class CachedMultiHeadAttention(MultiHeadAttention):
    def forward(self, hidden_states, attn_mask=None, past_key_value=None, use_cache=False):
        query = self._split_heads(self.q_proj(hidden_states))
        key = self._split_heads(self.k_proj(hidden_states))
        value = self._split_heads(self.v_proj(hidden_states))

        if past_key_value is not None:
            past_key, past_value = past_key_value
            key = torch.cat([past_key, key], dim=2)
            value = torch.cat([past_value, value], dim=2)

        output, weights = scaled_dot_product_attention(query, key, value, attn_mask)
        output = self.o_proj(self._merge_heads(output))
        new_cache = (key, value) if use_cache else None
        return output, new_cache, weights
阶段 输入长度 是否使用已有 cache 核心指标
Prefill prompt 全长 通常没有已有 cache,但会产出初始 cache。 TTFT,算力密集,矩阵乘较大。
Decode 通常每步 1 个 token 复用历史 K/V,只追加当前 token 的 K/V。 TPOT,带宽和 KV cache 读写更关键。
原页面代码里值得修正的细节

KV Cache 示例的测试函数需要先实例化 attention 模块;此外,真实模型的 cache 不是单层变量,而是每一层各自保存一份 K/V。生成服务里还要管理 batch 内不同序列长度、cache eviction、prefix cache 命中和 P/D 分离。

MQA/GQA:用一个抽象统一 K/V 共享

MQA 和 GQA 最好不要写成两套重复代码。更清晰的抽象是:query_heads 固定为 num_heads,key/value_heads 可配置。KV heads 等于 num_heads 就是 MHA,等于 1 就是 MQA,介于中间就是 GQA。

统一版 GQA/MQA

def repeat_kv(x, repeats):
    # x: (batch, kv_heads, seq_len, head_dim)
    if repeats == 1:
        return x
    batch, kv_heads, seq_len, head_dim = x.shape
    x = x[:, :, None, :, :].expand(batch, kv_heads, repeats, seq_len, head_dim)
    return x.reshape(batch, kv_heads * repeats, seq_len, head_dim)


class GroupedQueryAttention(nn.Module):
    def __init__(self, hidden_size, num_heads, kv_heads):
        super().__init__()
        if hidden_size % num_heads != 0:
            raise ValueError("hidden_size must be divisible by num_heads")
        if num_heads % kv_heads != 0:
            raise ValueError("num_heads must be divisible by kv_heads")

        self.hidden_size = hidden_size
        self.num_heads = num_heads
        self.kv_heads = kv_heads
        self.head_dim = hidden_size // num_heads

        self.q_proj = nn.Linear(hidden_size, hidden_size)
        self.k_proj = nn.Linear(hidden_size, kv_heads * self.head_dim)
        self.v_proj = nn.Linear(hidden_size, kv_heads * self.head_dim)
        self.o_proj = nn.Linear(hidden_size, hidden_size)

    def _split(self, x, heads):
        batch, seq_len, _ = x.shape
        x = x.view(batch, seq_len, heads, self.head_dim)
        return x.transpose(1, 2)

    def forward(self, hidden_states, attn_mask=None):
        query = self._split(self.q_proj(hidden_states), self.num_heads)
        key = self._split(self.k_proj(hidden_states), self.kv_heads)
        value = self._split(self.v_proj(hidden_states), self.kv_heads)

        repeats = self.num_heads // self.kv_heads
        key = repeat_kv(key, repeats)
        value = repeat_kv(value, repeats)

        output, weights = scaled_dot_product_attention(query, key, value, attn_mask)
        batch, _, seq_len, _ = output.shape
        output = output.transpose(1, 2).contiguous().view(batch, seq_len, self.hidden_size)
        return self.o_proj(output), weights
配置 等价变体 KV cache 相对大小 适用直觉
kv_heads = num_heads MHA 最大 表达最自由,成本最高。
kv_heads = 1 MQA 约为 MHA 的 \(1 / num\_heads\) decode 服务很省,但共享最强。
1 < kv_heads < num_heads GQA 约为 MHA 的 \(kv\_heads / num\_heads\) 现代 LLM 常用折中方案。

实现风险复盘

手撕题里最有价值的部分,是能指出代码为什么“看起来能跑”,但一换 batch、device、mask 或生成阶段就可能出错。

Mask 广播

三维 SDPA 和四维 MHA 的 padding mask 不能用同一套 unsqueeze 规则。要先写出 score shape,再让 mask broadcast 到同一 shape。

Device 与 dtype

不要在 forward 里随手创建 CPU tensor 做 scale。用 math.sqrt 或 Python float,避免 GPU 运行时隐式问题。

连续内存

transpose 之后的张量通常不是 contiguous。合并 head 前用 contiguousreshape

整除约束

hidden_size % num_heads == 0num_heads % kv_heads == 0 应该显式校验,不要让 shape error 延迟到 matmul。

Cache 语义

cache 的 seq_len 是历史长度,不一定等于当前输入长度。decode 单步 attention 的 score 通常是 (B,H,1,T)

MLA 不可硬补

MLA 不是简单把 GQA 的 K/V head 再减少。它涉及低秩 latent cache、RoPE 部分维度和投影分解,需要单独推导。

面试复述模板

如果要在面试里手撕 Attention,建议按下面顺序回答。这样不会陷入“代码背出来了,但 shape 说不清”的问题。

  1. 先定义输入:hidden_states(B,T,C)C = H * D
  2. 说明 Q/K/V:三个线性层把同一个输入投影到 query、key、value 空间。
  3. 拆 head:从 (B,T,C) 变为 (B,H,T,D),便于每个 head 独立做注意力。
  4. 计算 score:query @ key.transpose(-1, -2) / sqrt(D),得到 (B,H,T,T)
  5. 加 mask:causal mask 管未来,padding mask 管无效 token。
  6. softmax 后乘 value:得到每个位置的上下文向量。
  7. 合并 head:转回 (B,T,C),再经过输出投影。
  8. 如果追问推理优化,再讲 KV Cache、MQA/GQA 为什么减少 decode 阶段的内存带宽。

我的判断

这篇材料适合作为 Attention 手撕题的入口,因为它覆盖了从 SDPA 到 GQA 的主线。但真正要把它变成可面试、可工程使用的知识,需要把“实现片段”提升成“张量不变量”:每一步都能写出 batch/head/query/key/value 的形状,知道 mask 加在哪个维度,知道 cache 复用的是历史 K/V 而不是历史输出。

最值得沉淀的 insight 是:MQA/GQA 并不是新的注意力公式,而是 K/V 共享策略的改变。这个视角会让实现更 DRY,也更容易解释为什么现代 LLM 在推理服务里偏好 GQA:decode 阶段的瓶颈经常不是算不出矩阵乘,而是读写庞大的 KV cache 太贵。

后续如果继续补全这套“手撕经典算法”笔记,优先补 MLA、RoPE、RMSNorm、SwiGLU、MoE router 和 FlashAttention。它们和 Attention 的关系很紧:一个处理位置,一个处理归一化,一个处理 FFN 表达,一个处理稀疏专家,一个处理 IO-aware attention kernel。