手撕经典算法 #1 Attention 篇整理
这份笔记把原文的 Attention 手写代码整理成一条更稳定的学习路线:先理解 SDPA 的张量对象,再扩展到 MHA、KV Cache、MQA/GQA,最后指出原实现里值得面试和工程实践特别留意的 shape、mask 与缓存边界。
来源与导入方式
本页是站内整理版,不是外部页面全文镜像。原文提供了多种注意力机制的 PyTorch 示例;这里保留来源脉络,重写关键实现,并补充实现风险、统一抽象和面试复盘清单。
hewei2001.pages.dev/Manual-Coding-1,页面题为“手撕经典算法 #1 Attention篇”。
原页面标注最后更新为 2025-03-20;本次抓取和整理时间为 2026-05-21。
原文 MLA 部分仍为 TODO;本页不虚构 MLA 实现,只说明它与 KV 压缩的关系和后续补充方向。
先看共同结构
所有这些变体都在回答同一个问题:当前位置的 query 应该从哪些 key 对应的 value 中取信息?差异主要来自两个地方:是否拆成多个 head,以及 K/V 是否在多个 query head 之间共享。
当前位置发出的“我要找什么”。在 MHA 中常见形状是:
(batch, heads, q_len, head_dim)
每个历史位置提供的“我有什么特征”。GQA/MQA 会减少 K head 数。
(batch, kv_heads, kv_len, head_dim)
真正被加权汇总的信息内容。它的 head 共享策略通常和 Key 一致。
(batch, kv_heads, kv_len, head_dim)
| 变体 | 核心变化 | 收益 | 代价/边界 |
|---|---|---|---|
| SDPA | 直接计算 \(QK^\top\),不关心投影和多头拆分。 | 最小机制,适合理解公式和 mask。 | 不是完整 Transformer attention layer。 |
| MHA | 把 hidden 拆成多个 query/key/value head,最后 concat 后输出投影。 | 不同 head 可以学习不同关系。 | KV cache 随 head 数线性增长。 |
| MQA | 多个 Q head 共享同一组 K/V。 | 大幅降低 decode 阶段 KV cache 带宽和显存。 | 共享过强,表达能力可能受影响。 |
| GQA | 介于 MHA 和 MQA:多个 Q head 分组共享 K/V。 | 显存/带宽与效果之间更平衡。 | 需要保证 query heads 能整除 kv heads。 |
| MLA | 通过低秩 latent 表示压缩 KV 相关状态。 | 进一步降低缓存成本。 | 原文未展开;完整实现需要 RoPE、低秩投影和 cache 格式一起考虑。 |
SDPA:把公式写对
缩放点积注意力是后面所有变体的内核。手写时最重要的不是类结构,而是确认 score 的最后两个维度是 \(q\_len \times k\_len\),mask 只加在 key 维度或未来位置上。
推荐的最小实现
import math
import torch
def scaled_dot_product_attention(query, key, value, attn_mask=None):
"""
query: (batch, heads, q_len, head_dim)
key: (batch, heads, k_len, head_dim)
value: (batch, heads, k_len, head_dim)
attn_mask: broadcastable to (batch, heads, q_len, k_len);
True means the position is masked.
"""
scale = 1.0 / math.sqrt(query.size(-1))
scores = torch.matmul(query, key.transpose(-1, -2)) * scale
if attn_mask is not None:
scores = scores.masked_fill(attn_mask, torch.finfo(scores.dtype).min)
weights = torch.softmax(scores, dim=-1)
return torch.matmul(weights, value), weights
SDPA 是纯计算内核,没有参数。把它写成函数更符合 KISS:输入 Q/K/V 和 mask,输出 attention result。完整 Attention layer 的可学习参数应该放在 MHA/GQA 类里。
Mask 的关键边界
| Mask 类型 | 典型形状 | 语义 | 最容易错的点 |
|---|---|---|---|
| Causal mask | (q_len, k_len) 或 (1, 1, q_len, k_len) |
当前位置不能看未来 token。 | decode 单步时 q_len=1,k_len=历史长度+1,不再是完整方阵。 |
| Padding mask | (batch, k_len) 转成 (batch, 1, 1, k_len) |
不让模型关注 padding token。 | 如果 score 不是四维,盲目 unsqueeze 两次会触发错误广播。 |
MHA:多头只是换形状,不是换公式
多头注意力的主线是:对输入做 Q/K/V 三个线性投影,把 hidden_size 拆成 num_heads × head_dim,在每个 head 内做 SDPA,再把所有 head 拼回 hidden_size。
更稳健的 MHA 写法
from torch import nn
class MultiHeadAttention(nn.Module):
def __init__(self, hidden_size, num_heads, dropout=0.0):
super().__init__()
if hidden_size % num_heads != 0:
raise ValueError("hidden_size must be divisible by num_heads")
self.hidden_size = hidden_size
self.num_heads = num_heads
self.head_dim = hidden_size // num_heads
self.q_proj = nn.Linear(hidden_size, hidden_size)
self.k_proj = nn.Linear(hidden_size, hidden_size)
self.v_proj = nn.Linear(hidden_size, hidden_size)
self.o_proj = nn.Linear(hidden_size, hidden_size)
self.dropout = nn.Dropout(dropout)
def _split_heads(self, x):
batch, seq_len, _ = x.shape
x = x.view(batch, seq_len, self.num_heads, self.head_dim)
return x.transpose(1, 2)
def _merge_heads(self, x):
batch, _, seq_len, _ = x.shape
x = x.transpose(1, 2).contiguous()
return x.view(batch, seq_len, self.hidden_size)
def forward(self, hidden_states, attn_mask=None):
query = self._split_heads(self.q_proj(hidden_states))
key = self._split_heads(self.k_proj(hidden_states))
value = self._split_heads(self.v_proj(hidden_states))
output, weights = scaled_dot_product_attention(query, key, value, attn_mask)
output = self.o_proj(self._merge_heads(output))
return output, weights
多头不是把模型“复制多份”,而是把同一个 hidden 表示投影到多个子空间。每个 head 的 score 矩阵独立,最后 concat 再用 \(W_o\) 混合。
transpose 后再 view 要注意内存连续性;输出合并时用 .contiguous().view(...) 或 .reshape(...),不要默认 stride 一定兼容。
KV Cache:真正省的是 decode 阶段的重复 K/V
KV Cache 的语义是:生成第 \(t\) 个 token 时,历史 token 的 K/V 不需要重新投影和保存一遍。注意它主要服务自回归 decode;prefill 阶段通常一次性处理完整 prompt,并建立初始 cache。
单层缓存接口的写法
class CachedMultiHeadAttention(MultiHeadAttention):
def forward(self, hidden_states, attn_mask=None, past_key_value=None, use_cache=False):
query = self._split_heads(self.q_proj(hidden_states))
key = self._split_heads(self.k_proj(hidden_states))
value = self._split_heads(self.v_proj(hidden_states))
if past_key_value is not None:
past_key, past_value = past_key_value
key = torch.cat([past_key, key], dim=2)
value = torch.cat([past_value, value], dim=2)
output, weights = scaled_dot_product_attention(query, key, value, attn_mask)
output = self.o_proj(self._merge_heads(output))
new_cache = (key, value) if use_cache else None
return output, new_cache, weights
| 阶段 | 输入长度 | 是否使用已有 cache | 核心指标 |
|---|---|---|---|
| Prefill | prompt 全长 | 通常没有已有 cache,但会产出初始 cache。 | TTFT,算力密集,矩阵乘较大。 |
| Decode | 通常每步 1 个 token | 复用历史 K/V,只追加当前 token 的 K/V。 | TPOT,带宽和 KV cache 读写更关键。 |
KV Cache 示例的测试函数需要先实例化 attention 模块;此外,真实模型的 cache 不是单层变量,而是每一层各自保存一份 K/V。生成服务里还要管理 batch 内不同序列长度、cache eviction、prefix cache 命中和 P/D 分离。
MQA/GQA:用一个抽象统一 K/V 共享
MQA 和 GQA 最好不要写成两套重复代码。更清晰的抽象是:query_heads 固定为 num_heads,key/value_heads 可配置。KV heads 等于 num_heads 就是 MHA,等于 1 就是 MQA,介于中间就是 GQA。
统一版 GQA/MQA
def repeat_kv(x, repeats):
# x: (batch, kv_heads, seq_len, head_dim)
if repeats == 1:
return x
batch, kv_heads, seq_len, head_dim = x.shape
x = x[:, :, None, :, :].expand(batch, kv_heads, repeats, seq_len, head_dim)
return x.reshape(batch, kv_heads * repeats, seq_len, head_dim)
class GroupedQueryAttention(nn.Module):
def __init__(self, hidden_size, num_heads, kv_heads):
super().__init__()
if hidden_size % num_heads != 0:
raise ValueError("hidden_size must be divisible by num_heads")
if num_heads % kv_heads != 0:
raise ValueError("num_heads must be divisible by kv_heads")
self.hidden_size = hidden_size
self.num_heads = num_heads
self.kv_heads = kv_heads
self.head_dim = hidden_size // num_heads
self.q_proj = nn.Linear(hidden_size, hidden_size)
self.k_proj = nn.Linear(hidden_size, kv_heads * self.head_dim)
self.v_proj = nn.Linear(hidden_size, kv_heads * self.head_dim)
self.o_proj = nn.Linear(hidden_size, hidden_size)
def _split(self, x, heads):
batch, seq_len, _ = x.shape
x = x.view(batch, seq_len, heads, self.head_dim)
return x.transpose(1, 2)
def forward(self, hidden_states, attn_mask=None):
query = self._split(self.q_proj(hidden_states), self.num_heads)
key = self._split(self.k_proj(hidden_states), self.kv_heads)
value = self._split(self.v_proj(hidden_states), self.kv_heads)
repeats = self.num_heads // self.kv_heads
key = repeat_kv(key, repeats)
value = repeat_kv(value, repeats)
output, weights = scaled_dot_product_attention(query, key, value, attn_mask)
batch, _, seq_len, _ = output.shape
output = output.transpose(1, 2).contiguous().view(batch, seq_len, self.hidden_size)
return self.o_proj(output), weights
| 配置 | 等价变体 | KV cache 相对大小 | 适用直觉 |
|---|---|---|---|
kv_heads = num_heads |
MHA | 最大 | 表达最自由,成本最高。 |
kv_heads = 1 |
MQA | 约为 MHA 的 \(1 / num\_heads\) | decode 服务很省,但共享最强。 |
1 < kv_heads < num_heads |
GQA | 约为 MHA 的 \(kv\_heads / num\_heads\) | 现代 LLM 常用折中方案。 |
实现风险复盘
手撕题里最有价值的部分,是能指出代码为什么“看起来能跑”,但一换 batch、device、mask 或生成阶段就可能出错。
三维 SDPA 和四维 MHA 的 padding mask 不能用同一套 unsqueeze 规则。要先写出 score shape,再让 mask broadcast 到同一 shape。
不要在 forward 里随手创建 CPU tensor 做 scale。用 math.sqrt 或 Python float,避免 GPU 运行时隐式问题。
transpose 之后的张量通常不是 contiguous。合并 head 前用 contiguous 或 reshape。
hidden_size % num_heads == 0、num_heads % kv_heads == 0 应该显式校验,不要让 shape error 延迟到 matmul。
cache 的 seq_len 是历史长度,不一定等于当前输入长度。decode 单步 attention 的 score 通常是 (B,H,1,T)。
MLA 不是简单把 GQA 的 K/V head 再减少。它涉及低秩 latent cache、RoPE 部分维度和投影分解,需要单独推导。
面试复述模板
如果要在面试里手撕 Attention,建议按下面顺序回答。这样不会陷入“代码背出来了,但 shape 说不清”的问题。
- 先定义输入:
hidden_states是(B,T,C),C = H * D。 - 说明 Q/K/V:三个线性层把同一个输入投影到 query、key、value 空间。
- 拆 head:从
(B,T,C)变为(B,H,T,D),便于每个 head 独立做注意力。 - 计算 score:
query @ key.transpose(-1, -2) / sqrt(D),得到(B,H,T,T)。 - 加 mask:causal mask 管未来,padding mask 管无效 token。
- softmax 后乘 value:得到每个位置的上下文向量。
- 合并 head:转回
(B,T,C),再经过输出投影。 - 如果追问推理优化,再讲 KV Cache、MQA/GQA 为什么减少 decode 阶段的内存带宽。
我的判断
这篇材料适合作为 Attention 手撕题的入口,因为它覆盖了从 SDPA 到 GQA 的主线。但真正要把它变成可面试、可工程使用的知识,需要把“实现片段”提升成“张量不变量”:每一步都能写出 batch/head/query/key/value 的形状,知道 mask 加在哪个维度,知道 cache 复用的是历史 K/V 而不是历史输出。
最值得沉淀的 insight 是:MQA/GQA 并不是新的注意力公式,而是 K/V 共享策略的改变。这个视角会让实现更 DRY,也更容易解释为什么现代 LLM 在推理服务里偏好 GQA:decode 阶段的瓶颈经常不是算不出矩阵乘,而是读写庞大的 KV cache 太贵。
后续如果继续补全这套“手撕经典算法”笔记,优先补 MLA、RoPE、RMSNorm、SwiGLU、MoE router 和 FlashAttention。它们和 Attention 的关系很紧:一个处理位置,一个处理归一化,一个处理 FFN 表达,一个处理稀疏专家,一个处理 IO-aware attention kernel。