#14.5 Hand-Written Coding Questions: Complete Reference Implementations (New)

Earlier sections listed the hand-written coding question bank and the scoring criteria, but when preparing for interviews there is still a gap between knowing what is tested and being able to write it out by hand. Below are complete, runnable reference implementations of the two highest-frequency hand-written questions, annotated step by step with shape changes, key design intent, and common pitfalls. Try writing them yourself first, then check your attempt against the reference to find the gaps.


#Reference Implementation 1: Multi-Head Attention by Hand (PyTorch)

Problem statement: implement Scaled Dot-Product Attention over Q/K/V plus the multi-head split and merge, supporting both padding masks and causal masks, and keeping the computation numerically stable.

```python
import torch
import torch.nn as nn
import math
from typing import Optional

class MultiHeadAttention(nn.Module):
    """
    手写 Multi-Head Attention 完整实现。

    面试评分点对应:
    1. shape 是否正确流转                    -> 每步都有 shape 注释
    2. 是否记得除以 sqrt(d_k)                -> scaled_dot_product_attention 内
    3. causal mask 方向是否正确               -> apply_mask 内,上三角为 -inf
    4. 数值稳定性(softmax 前减最大值)        -> 不需要,因为除以 sqrt(d_k) 已控制量级
    5. 多头拆分与合并的维度操作是否正确         -> view + transpose / transpose + view
    """

    def __init__(self, d_model: int, num_heads: int, dropout: float = 0.1):
        super().__init__()
        assert d_model % num_heads == 0, "d_model must be divisible by num_heads"

        self.d_model = d_model          # model hidden size, e.g. 512, 768, 4096
        self.num_heads = num_heads      # number of attention heads, e.g. 8, 12, 32
        self.head_dim = d_model // num_heads   # per-head size d_k, e.g. 64

        # Linear projections: map the input into the Q/K/V spaces
        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)

        # Output projection: map the concatenated heads back to d_model
        self.W_o = nn.Linear(d_model, d_model)

        self.dropout = nn.Dropout(dropout)
        self.scale = math.sqrt(self.head_dim)

    def forward(self, x: torch.Tensor, mask: Optional[torch.Tensor] = None, is_causal: bool = False):
        B, T, _ = x.shape

        # Step 1: linear projections
        Q = self.W_q(x)   # (B, T, d_model)
        K = self.W_k(x)   # (B, T, d_model)
        V = self.W_v(x)   # (B, T, d_model)

        # Step 2: split into heads: (B, T, d_model) -> (B, num_heads, T, head_dim)
        Q = Q.view(B, T, self.num_heads, self.head_dim).transpose(1, 2)
        K = K.view(B, T, self.num_heads, self.head_dim).transpose(1, 2)
        V = V.view(B, T, self.num_heads, self.head_dim).transpose(1, 2)

        # Step 3: Scaled Dot-Product Attention
        attn_output, attn_weights = self.scaled_dot_product_attention(
            Q, K, V, mask=mask, is_causal=is_causal
        )

        # Step 4: merge the heads back together
        attn_output = attn_output.transpose(1, 2).contiguous()   # (B, T, H, d_k)
        attn_output = attn_output.view(B, T, self.d_model)       # (B, T, d_model)
        # Pitfall: contiguous() is required after transpose, otherwise view may raise

        # Step 5: output projection
        output = self.W_o(attn_output)
        output = self.dropout(output)
        return output, attn_weights

    def scaled_dot_product_attention(self, Q, K, V, mask=None, is_causal=False):
        B, H, T, d_k = Q.shape

        # Attention scores: (B, H, T, d_k) @ (B, H, d_k, T) -> (B, H, T, T)
        scores = torch.matmul(Q, K.transpose(-2, -1)) / self.scale

        # Padding mask: broadcast (B, T) -> (B, 1, 1, T) over heads and query positions
        if mask is not None:
            if mask.dim() == 2:
                mask = mask.unsqueeze(1).unsqueeze(2)
            scores = scores.masked_fill(mask == 0, -1e9)

        if is_causal:
            # Upper triangle (diagonal=1) marks future positions each query must not see
            causal_mask = torch.triu(torch.ones(T, T, device=scores.device, dtype=torch.bool), diagonal=1)
            scores = scores.masked_fill(causal_mask.unsqueeze(0).unsqueeze(0), -1e9)

        attn_weights = torch.softmax(scores, dim=-1)
        attn_weights = self.dropout(attn_weights)

        # Weighted sum: (B, H, T, T) @ (B, H, T, d_k) -> (B, H, T, d_k)
        output = torch.matmul(attn_weights, V)
        return output, attn_weights


# Quick tests
def test_mha():
    B, T, d = 2, 10, 512
    num_heads = 8
    mha = MultiHeadAttention(d_model=d, num_heads=num_heads)
    mha.eval()   # disable dropout so the mask checks below are deterministic
    x = torch.randn(B, T, d)

    # Test 1: no mask
    out, attn = mha(x)
    assert out.shape == (B, T, d)
    assert attn.shape == (B, num_heads, T, T)
    print("Test 1 通过: 无 mask 前向")

    # Test 2: padding mask
    pad_mask = torch.ones(B, T, dtype=torch.bool)
    pad_mask[:, 8:] = False
    out2, attn2 = mha(x, mask=pad_mask)
    assert torch.allclose(attn2[:, :, :, 8:].sum(), torch.tensor(0.0), atol=1e-6)
    print("Test 2 通过: padding mask 正确")

    # Test 3: causal mask
    out3, attn3 = mha(x, is_causal=True)
    upper_tri = torch.triu(torch.ones(T, T), diagonal=1).bool()
    assert torch.allclose(attn3[:, :, upper_tri].sum(), torch.tensor(0.0), atol=1e-6)
    print("Test 3 通过: causal mask 正确")

    print("\n所有测试通过!")

if __name__ == "__main__":
    test_mha()
```
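
A further sanity check beyond the hand-written tests: compare the inner attention against PyTorch's built-in `torch.nn.functional.scaled_dot_product_attention` (available since PyTorch 2.0). A minimal sketch, assuming the class above is in scope; `test_against_builtin` is a name introduced here:

```python
import torch
import torch.nn.functional as F

def test_against_builtin():
    mha = MultiHeadAttention(d_model=512, num_heads=8)
    mha.eval()   # dropout off, so both sides are deterministic
    # Tensors already split into heads: (B, H, T, d_k)
    q, k, v = (torch.randn(2, 8, 10, 64) for _ in range(3))
    ours, _ = mha.scaled_dot_product_attention(q, k, v)
    ref = F.scaled_dot_product_attention(q, k, v)
    assert torch.allclose(ours, ref, atol=1e-5)
```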

#Likely Interviewer Follow-Ups

  1. Why divide by sqrt(d_k)? Assuming the elements of Q and K are i.i.d. N(0, 1), their dot product has variance d_k. Dividing by sqrt(d_k) pulls the variance back to 1, preventing overly large softmax inputs from causing vanishing gradients.
  2. Why does masked_fill use -1e9 instead of -inf? On some hardware, -inf through softmax can produce NaN (when an entire row is masked). -1e9 is small enough while staying safe. (Note that -1e9 is not representable in fp16; torch.finfo(dtype).min is safer there.)
  3. Why contiguous() after transpose? PyTorch's view requires contiguous memory. transpose only changes the strides without copying data, so you must call contiguous() before view (or use reshape, which copies when necessary).
  4. Can W_q/W_k/W_v be merged? Yes: many implementations use a single nn.Linear(d_model, 3*d_model) projection and then slice it, which is slightly more efficient; writing them separately is clearer in an interview. See the sketch after this list.
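
A minimal sketch of the fused projection from point 4, assuming the imports from the implementation above; the class name FusedQKV is an illustrative choice, not from the original:

```python
class FusedQKV(nn.Module):
    """Illustrative fused Q/K/V projection: one matmul instead of three."""
    def __init__(self, d_model: int, num_heads: int):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = d_model // num_heads
        self.qkv = nn.Linear(d_model, 3 * d_model)   # single fused projection

    def forward(self, x):
        B, T, _ = x.shape
        q, k, v = self.qkv(x).chunk(3, dim=-1)       # each (B, T, d_model)
        # Split into heads exactly as in MultiHeadAttention above
        q = q.view(B, T, self.num_heads, self.head_dim).transpose(1, 2)
        k = k.view(B, T, self.num_heads, self.head_dim).transpose(1, 2)
        v = v.view(B, T, self.num_heads, self.head_dim).transpose(1, 2)
        return q, k, v
```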

#Reference Implementation 2: Simplified PPO by Hand (PyTorch)

Problem statement: implement a simplified PPO with an Actor-Critic network, GAE, the PPO-Clip objective, and KL-based early stopping.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.distributions import Categorical

class ActorCritic(nn.Module):
    """Shared-trunk actor-critic: one MLP feeds both the policy and value heads."""

    def __init__(self, obs_dim: int, action_dim: int, hidden_dim: int = 256):
        super().__init__()
        self.shared = nn.Sequential(
            nn.Linear(obs_dim, hidden_dim), nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim), nn.ReLU(),
        )
        self.actor_head = nn.Linear(hidden_dim, action_dim)   # policy logits
        self.critic_head = nn.Linear(hidden_dim, 1)           # state value

    def forward(self, obs):
        feat = self.shared(obs)
        return self.actor_head(feat), self.critic_head(feat)

    def get_action_and_value(self, obs, action=None):
        # action=None: sample a new action (rollout);
        # action given: re-evaluate its log prob under the current policy (PPO update)
        logits, value = self.forward(obs)
        dist = Categorical(logits=logits)
        if action is None:
            action = dist.sample()
        return action, dist.log_prob(action), dist.entropy(), value.squeeze(-1)


def compute_gae(rewards, values, dones, gamma=0.99, lam=0.95):
    """GAE, computed backwards from the last timestep.

    Simplification: the final step bootstraps with 0, i.e. the trajectory is
    assumed to end at a terminal state (otherwise bootstrap with the last value).
    """
    T = len(rewards)
    advantages = torch.zeros_like(rewards)
    last_gae = 0.0
    for t in reversed(range(T)):
        next_value = 0.0 if t == T - 1 else values[t + 1]
        # TD error; (1 - dones[t]) cuts the bootstrap across episode boundaries
        delta = rewards[t] + gamma * next_value * (1 - dones[t]) - values[t]
        last_gae = delta + gamma * lam * (1 - dones[t]) * last_gae
        advantages[t] = last_gae
    return advantages, advantages + values   # (advantages, returns)
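
# Hand-checkable example (added illustration, not part of the original):
# T=2, rewards=[1, 1], values=[0.5, 0.5], dones=[0, 1], gamma=0.99, lam=0.95
#   t=1: delta = 1 - 0.5 = 0.5                -> adv[1] = 0.5
#   t=0: delta = 1 + 0.99*0.5 - 0.5 = 0.995   -> adv[0] = 0.995 + 0.9405*0.5 = 1.46525
def _check_gae():
    adv, _ = compute_gae(torch.tensor([1.0, 1.0]),
                         torch.tensor([0.5, 0.5]),
                         torch.tensor([0.0, 1.0]))
    assert torch.allclose(adv, torch.tensor([1.46525, 0.5]), atol=1e-5)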


class PPOTrainer:
    def __init__(self, ac, lr=3e-4, gamma=0.99, lam=0.95,
                 clip_eps=0.2, value_coef=0.5, entropy_coef=0.01,
                 max_kl=0.015, num_epochs=4, batch_size=64, device="cpu"):
        self.ac = ac.to(device)
        self.optimizer = torch.optim.Adam(self.ac.parameters(), lr=lr)
        self.gamma, self.lam = gamma, lam
        self.clip_eps = clip_eps
        self.value_coef = value_coef
        self.entropy_coef = entropy_coef
        self.max_kl = max_kl
        self.num_epochs = num_epochs
        self.batch_size = batch_size
        self.device = device

    def update(self, traj):
        obs = traj["obs"].to(self.device)
        actions = traj["actions"].to(self.device)
        rewards = traj["rewards"].to(self.device)
        dones = traj["dones"].to(self.device)
        old_log_probs = traj["old_log_probs"].to(self.device)
        old_values = traj["old_values"].to(self.device)
        T = len(obs)

        # Step 1: GAE, then normalize advantages (stabilizes training)
        with torch.no_grad():
            advantages, returns = compute_gae(rewards, old_values, dones, self.gamma, self.lam)
            advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)

        # Step 2: PPO epochs over shuffled minibatches
        for epoch in range(self.num_epochs):
            indices = torch.randperm(T)
            for start in range(0, T, self.batch_size):
                idx = indices[start:start + self.batch_size]
                _, new_log_probs, entropy, new_values = self.ac.get_action_and_value(
                    obs[idx], actions[idx]
                )

                # PPO-Clip: take the pessimistic min of unclipped and clipped surrogates
                ratio = torch.exp(new_log_probs - old_log_probs[idx])
                surr1 = ratio * advantages[idx]
                surr2 = torch.clamp(ratio, 1 - self.clip_eps, 1 + self.clip_eps) * advantages[idx]
                policy_loss = -torch.min(surr1, surr2).mean()

                value_loss = F.mse_loss(new_values, returns[idx])
                loss = policy_loss + self.value_coef * value_loss - self.entropy_coef * entropy.mean()

                self.optimizer.zero_grad()
                loss.backward()
                nn.utils.clip_grad_norm_(self.ac.parameters(), 0.5)
                self.optimizer.step()

                # Approximate KL for early stopping. Note: new_log_probs was
                # computed before the optimizer step above, so this is a (common)
                # stale estimate of the post-update KL.
                with torch.no_grad():
                    kl = (old_log_probs[idx] - new_log_probs).mean().abs()
                if kl > self.max_kl:
                    print(f"KL={kl:.4f} > {self.max_kl}, early stopping")
                    return
```
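
A minimal smoke test for the trainer; the random trajectory below is an illustrative stand-in (a real run would collect it from an environment), and smoke_test_ppo is a name introduced here:

```python
def smoke_test_ppo():
    obs_dim, action_dim, T = 4, 2, 128
    ac = ActorCritic(obs_dim, action_dim)
    obs = torch.randn(T, obs_dim)
    with torch.no_grad():
        actions, log_probs, _, values = ac.get_action_and_value(obs)
    traj = {
        "obs": obs,
        "actions": actions,
        "rewards": torch.randn(T),   # placeholder rewards
        "dones": torch.zeros(T),     # single non-terminated segment
        "old_log_probs": log_probs,
        "old_values": values,
    }
    PPOTrainer(ac).update(traj)

if __name__ == "__main__":
    smoke_test_ppo()
```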

#Interviewer Scoring Checklist for Hand-Written PPO

| Scoring item | Weight | What is assessed |
| --- | --- | --- |
| Actor-Critic structure | 20% | Separate policy and value heads |
| GAE computation | 20% | Backward recursion; correct truncation at done |
| PPO-Clip objective | 25% | ratio, clip, and min all correct |
| Loss composition | 15% | Signs and coefficients of policy + value + entropy terms |
| KL monitoring | 10% | Awareness of early stopping |
| Engineering details | 10% | Gradient clipping, advantage normalization |

#Reference Implementation 3: Minimal Attention (5-Minute Version)

```python
import torch, math

def attention(q, k, v, mask=None):
    # scores: (..., T_q, T_k), scaled by sqrt(d_k)
    scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(q.size(-1))
    if mask is not None:
        scores = scores.masked_fill(mask == 0, float('-inf'))
    attn = torch.softmax(scores, dim=-1)
    return torch.matmul(attn, v), attn
```

Follow-up questions: How do you add a causal mask (see the sketch below)? Why divide by sqrt(d_k)? Where does the multi-head split happen?
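
For the first follow-up, a minimal sketch reusing the attention function above; the mask convention (0 = blocked) matches its masked_fill check, and the shapes are illustrative:

```python
import torch

T = 6
q = k = v = torch.randn(1, T, 8)
# Lower triangle allowed, upper triangle (future positions) blocked
causal = torch.tril(torch.ones(T, T, dtype=torch.bool))
out, attn = attention(q, k, v, mask=causal)
# No query attends to a future position
assert torch.allclose(attn[0].triu(1).sum(), torch.tensor(0.0))
```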