#十四点五、手撕代码完整参考实现(新增)
前面列出了手撕代码的题单和评分要点,但面试准备时"知道考什么"和"能手写出来"之间还有一段距离。下面给出两道最高频手撕题的完整可运行参考实现,外加一个 5 分钟极简版;关键步骤都标注了 shape 变化、设计意图和常见踩坑点。建议先自己尝试写,再对照参考实现查漏补缺。
#参考实现 1:手写 Multi-Head Attention(PyTorch)
题目要求:给定 Q/K/V 三个张量,实现 Scaled Dot-Product Attention + Multi-Head 拆分与合并,支持 padding mask 和 causal mask,并保证数值稳定性。
import torch
import torch.nn as nn
import math
class MultiHeadAttention(nn.Module):
"""
手写 Multi-Head Attention 完整实现。
面试评分点对应:
1. shape 是否正确流转 -> 每步都有 shape 注释
2. 是否记得除以 sqrt(d_k) -> scaled_dot_product_attention 内
3. causal mask 方向是否正确 -> apply_mask 内,上三角为 -inf
    4. 数值稳定性(softmax 前减最大值) -> torch.softmax 内部已做减最大值处理;若被要求手写 softmax,需自行按行减去最大值
5. 多头拆分与合并的维度操作是否正确 -> view + transpose / transpose + view
"""
def __init__(self, d_model: int, num_heads: int, dropout: float = 0.1):
super().__init__()
assert d_model % num_heads == 0, "d_model 必须能被 num_heads 整除"
self.d_model = d_model # 模型隐藏维度,如 512, 768, 4096
self.num_heads = num_heads # 注意力头数,如 8, 12, 32
self.head_dim = d_model // num_heads # 每头维度 d_k,如 64
# 线性投影:把输入映射到 Q/K/V 空间
self.W_q = nn.Linear(d_model, d_model)
self.W_k = nn.Linear(d_model, d_model)
self.W_v = nn.Linear(d_model, d_model)
# 输出投影:把多头拼接后的结果映射回 d_model
self.W_o = nn.Linear(d_model, d_model)
self.dropout = nn.Dropout(dropout)
self.scale = math.sqrt(self.head_dim)
def forward(self, x: torch.Tensor, mask: torch.Tensor = None, is_causal: bool = False):
B, T, _ = x.shape
# Step 1: 线性投影
Q = self.W_q(x) # (B, T, d_model)
K = self.W_k(x) # (B, T, d_model)
V = self.W_v(x) # (B, T, d_model)
# Step 2: 拆分为多头: (B, T, d) -> (B, num_heads, T, head_dim)
Q = Q.view(B, T, self.num_heads, self.head_dim).transpose(1, 2)
K = K.view(B, T, self.num_heads, self.head_dim).transpose(1, 2)
V = V.view(B, T, self.num_heads, self.head_dim).transpose(1, 2)
# Step 3: Scaled Dot-Product Attention
attn_output, attn_weights = self.scaled_dot_product_attention(
Q, K, V, mask=mask, is_causal=is_causal
)
# Step 4: 多头合并
attn_output = attn_output.transpose(1, 2).contiguous() # (B, T, H, d_k)
attn_output = attn_output.view(B, T, self.d_model) # (B, T, d_model)
# 踩坑点:transpose 后必须 contiguous(),否则 view 可能报错
# Step 5: 输出投影
output = self.W_o(attn_output)
output = self.dropout(output)
return output, attn_weights
def scaled_dot_product_attention(self, Q, K, V, mask=None, is_causal=False):
B, H, T, d_k = Q.shape
# 注意力分数: (B, H, T, d_k) @ (B, H, d_k, T) -> (B, H, T, T)
scores = torch.matmul(Q, K.transpose(-2, -1)) / self.scale
# 应用 mask
if mask is not None:
if mask.dim() == 2:
mask = mask.unsqueeze(1).unsqueeze(2)
scores = scores.masked_fill(mask == 0, -1e9)
if is_causal:
causal_mask = torch.triu(torch.ones(T, T, device=scores.device, dtype=torch.bool), diagonal=1)
scores = scores.masked_fill(causal_mask.unsqueeze(0).unsqueeze(0), -1e9)
attn_weights = torch.softmax(scores, dim=-1)
attn_weights = self.dropout(attn_weights)
# 加权求和: (B, H, T, T) @ (B, H, T, d_k) -> (B, H, T, d_k)
output = torch.matmul(attn_weights, V)
return output, attn_weights
# 快速测试
def test_mha():
B, T, d = 2, 10, 512
num_heads = 8
    mha = MultiHeadAttention(d_model=d, num_heads=num_heads)
    mha.eval()  # 关闭 dropout,保证注意力权重断言可复现
x = torch.randn(B, T, d)
# Test 1: 无 mask
out, attn = mha(x)
assert out.shape == (B, T, d)
assert attn.shape == (B, num_heads, T, T)
print("Test 1 通过: 无 mask 前向")
# Test 2: padding mask
pad_mask = torch.ones(B, T, dtype=torch.bool)
pad_mask[:, 8:] = False
out2, attn2 = mha(x, mask=pad_mask)
assert torch.allclose(attn2[:, :, :, 8:].sum(), torch.tensor(0.0), atol=1e-6)
print("Test 2 通过: padding mask 正确")
# Test 3: causal mask
out3, attn3 = mha(x, is_causal=True)
upper_tri = torch.triu(torch.ones(T, T), diagonal=1).bool()
assert torch.allclose(attn3[:, :, upper_tri].sum(), torch.tensor(0.0), atol=1e-6)
print("Test 3 通过: causal mask 正确")
print("\n所有测试通过!")
if __name__ == "__main__":
test_mha()
#面试官追问要点
- 为什么除以 `sqrt(d_k)`?假设 Q、K 元素独立同分布于 N(0,1),则点积的方差为 d_k。除以 `sqrt(d_k)` 把方差拉回 1,避免 softmax 输入过大导致梯度消失(下方附一个数值小实验)。
- 为什么 `masked_fill` 填 -1e9 而不是 `-inf`?当某一行被整行 mask 时,`-inf` 会让 softmax 出现 0/0 而产生 NaN;-1e9 足够小又安全。
- 为什么 transpose 后要 `contiguous()`?PyTorch 的 `view` 要求张量内存连续,而 `transpose` 只改变 strides、不复制数据,所以必须先 `contiguous()` 再 `view`(同见下方小实验)。
- W_q/W_k/W_v 能否合并?可以,很多实现用 `nn.Linear(d_model, 3*d_model)` 一次性投影再按维切分,效率略高;面试时写分开的更清晰。
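下面是一段可直接运行的小实验,用随机张量验证上面第 1 条和第 3 条;其中 d_k=64 和样本数 10000 都是随手选的假设值,只看量级即可。
import torch
# 实验 1:随机 Q/K 的点积方差约等于 d_k,除以 sqrt(d_k) 后约等于 1
d_k = 64
q = torch.randn(10000, d_k)
k = torch.randn(10000, d_k)
dots = (q * k).sum(dim=-1)
print(dots.var())                  # 约 64
print((dots / d_k ** 0.5).var())   # 约 1
# 实验 2:transpose 后张量非连续,直接 view 报错,contiguous() 后恢复
x = torch.randn(2, 3, 4).transpose(1, 2)
print(x.is_contiguous())           # False
try:
    x.view(2, 12)
except RuntimeError as e:
    print("view 失败:", e)
print(x.contiguous().view(2, 12).shape)  # torch.Size([2, 12])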
#参考实现 2:手写简化 PPO 算法(PyTorch)
题目要求:实现简化版 PPO,包含 Actor-Critic、GAE、PPO-Clip 目标、KL early stopping。
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.distributions import Categorical
import numpy as np
class ActorCritic(nn.Module):
def __init__(self, obs_dim: int, action_dim: int, hidden_dim: int = 256):
super().__init__()
self.shared = nn.Sequential(
nn.Linear(obs_dim, hidden_dim), nn.ReLU(),
nn.Linear(hidden_dim, hidden_dim), nn.ReLU(),
)
self.actor_head = nn.Linear(hidden_dim, action_dim)
self.critic_head = nn.Linear(hidden_dim, 1)
def forward(self, obs):
feat = self.shared(obs)
return self.actor_head(feat), self.critic_head(feat)
def get_action_and_value(self, obs, action=None):
logits, value = self.forward(obs)
dist = Categorical(logits=logits)
if action is None:
action = dist.sample()
return action, dist.log_prob(action), dist.entropy(), value.squeeze(-1)
def compute_gae(rewards, values, dones, gamma=0.99, lam=0.95):
"""GAE: 反向计算,从最后一个时间步往前推。"""
T = len(rewards)
advantages = torch.zeros_like(rewards)
last_gae = 0.0
for t in reversed(range(T)):
        next_value = 0.0 if t == T - 1 else values[t + 1]  # 简化:末步不做 bootstrap,默认轨迹在此结束
delta = rewards[t] + gamma * next_value * (1 - dones[t]) - values[t]
last_gae = delta + gamma * lam * (1 - dones[t]) * last_gae
advantages[t] = last_gae
return advantages, advantages + values
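# --- GAE 自检示意(可删):3 步玩具轨迹,数值为手算的假设值 ---
# rewards=[1,1,1], values=[0.5,0.5,0.5], dones=[0,0,1], gamma=0.99, lam=0.95
# t=2: delta = 1 - 0.5 = 0.5              -> adv[2] = 0.5
# t=1: delta = 1 + 0.99*0.5 - 0.5 = 0.995 -> adv[1] = 0.995 + 0.99*0.95*0.5 = 1.46525
# t=0: delta = 0.995                      -> adv[0] = 0.995 + 0.9405*1.46525 ≈ 2.37807
_adv, _ = compute_gae(torch.tensor([1., 1., 1.]),
                      torch.tensor([0.5, 0.5, 0.5]),
                      torch.tensor([0., 0., 1.]))
assert torch.allclose(_adv, torch.tensor([2.3781, 1.4653, 0.5000]), atol=1e-3)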
class PPOTrainer:
def __init__(self, ac, lr=3e-4, gamma=0.99, lam=0.95,
clip_eps=0.2, value_coef=0.5, entropy_coef=0.01,
max_kl=0.015, num_epochs=4, batch_size=64, device="cpu"):
self.ac = ac.to(device)
self.optimizer = torch.optim.Adam(self.ac.parameters(), lr=lr)
self.gamma, self.lam = gamma, lam
self.clip_eps = clip_eps
self.value_coef = value_coef
self.entropy_coef = entropy_coef
self.max_kl = max_kl
self.num_epochs = num_epochs
self.batch_size = batch_size
self.device = device
def update(self, traj):
obs = traj["obs"].to(self.device)
actions = traj["actions"].to(self.device)
rewards = traj["rewards"].to(self.device)
dones = traj["dones"].to(self.device)
old_log_probs = traj["old_log_probs"].to(self.device)
old_values = traj["old_values"].to(self.device)
T = len(obs)
# Step 1: GAE
with torch.no_grad():
advantages, returns = compute_gae(rewards, old_values, dones, self.gamma, self.lam)
advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)
# Step 2: PPO epochs
for epoch in range(self.num_epochs):
indices = torch.randperm(T)
for start in range(0, T, self.batch_size):
idx = indices[start:start + self.batch_size]
_, new_log_probs, entropy, new_values = self.ac.get_action_and_value(
obs[idx], actions[idx]
)
# PPO-Clip
ratio = torch.exp(new_log_probs - old_log_probs[idx])
surr1 = ratio * advantages[idx]
surr2 = torch.clamp(ratio, 1 - self.clip_eps, 1 + self.clip_eps) * advantages[idx]
policy_loss = -torch.min(surr1, surr2).mean()
value_loss = F.mse_loss(new_values, returns[idx])
loss = policy_loss + self.value_coef * value_loss - self.entropy_coef * entropy.mean()
self.optimizer.zero_grad()
loss.backward()
nn.utils.clip_grad_norm_(self.ac.parameters(), 0.5)
self.optimizer.step()
# KL early stopping
with torch.no_grad():
kl = (old_log_probs[idx] - new_log_probs).mean().abs()
if kl > self.max_kl:
print(f"KL={kl:.4f} > {self.max_kl}, early stopping")
return
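最后补一段用法示意,演示 update 期望的 traj 字典字段约定。这里用随机观测伪造一条轨迹(obs_dim、action_dim、轨迹长度均为假设值),不对应任何真实环境,只为走通流程:
if __name__ == "__main__":
    torch.manual_seed(0)
    obs_dim, action_dim, T = 4, 2, 128      # 假设值
    ac = ActorCritic(obs_dim, action_dim)
    trainer = PPOTrainer(ac)
    obs = torch.randn(T, obs_dim)           # 真实场景应来自环境 rollout
    with torch.no_grad():
        actions, log_probs, _, values = ac.get_action_and_value(obs)
    traj = {
        "obs": obs,
        "actions": actions,
        "rewards": torch.randn(T),          # 随机奖励,仅为演示
        "dones": torch.zeros(T),
        "old_log_probs": log_probs,
        "old_values": values,
    }
    trainer.update(traj)
    print("PPO update 跑通")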
#手撕 PPO 的面试官评分 checklist
| 评分项 | 权重 | 考察内容 |
|---|---|---|
| Actor-Critic 结构 | 20% | policy 和 value 输出头是否分开 |
| GAE 计算 | 20% | 反向递推、done 截断是否正确 |
| PPO-Clip 目标 | 25% | ratio、clip、min 是否正确 |
| Loss 组合 | 15% | policy + value + entropy 符号和系数 |
| KL 监控 | 10% | 是否有 early stopping 意识 |
| 工程细节 | 10% | 梯度裁剪、advantage 归一化 |
#参考实现 3:极简 Attention(5 分钟版)
import torch, math
def attention(q, k, v, mask=None):
scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(q.size(-1))
if mask is not None:
        scores = scores.masked_fill(mask == 0, float('-inf'))  # 若可能整行被 mask,参照实现 1 改填 -1e9 防 NaN
attn = torch.softmax(scores, dim=-1)
return torch.matmul(attn, v), attn
追问点:causal mask 怎么加(下面给一种写法)?为什么除以 sqrt(d_k)?多头在哪里拆?
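其中 causal mask 的一种加法示意如下,沿用上面 attention 的约定(mask 为 0/False 的位置被屏蔽);张量形状为假设的测试值:
q = k = v = torch.randn(2, 10, 64)  # (B, T, D),假设形状
T = q.size(-2)
# 下三角(含对角线)为 True 表示可见;上三角被填 -inf,对角线可见所以不会整行 mask
causal = torch.tril(torch.ones(T, T, dtype=torch.bool, device=q.device))
out, attn = attention(q, k, v, mask=causal)
upper = torch.triu(torch.ones(T, T), diagonal=1).bool()
assert torch.allclose(attn[:, upper].sum(), torch.tensor(0.0))  # 上三角注意力应为 0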