#十八点五、ZeRO 与 3D 并行深度解析(新增:显存定量分析与组合策略)

前面的分布式训练专项已经覆盖了 DP/TP/PP/ZeRO 的基本概念,但面试中最能拉开差距的追问是:"具体能省多少显存?""如果模型从 7B 涨到 70B,你的并行策略怎么变?""3D 并行中每一维度的通信量是多少?"

本节从定量分析工程决策两个维度,把分布式训练从"知道名词"提升到"能算、能配、能排障"。


#一、ZeRO 的显存节省:一步步算清楚

#1. 单机单卡训练时的显存构成

以 AdamW + BF16 训练一个 L 层、d_model=D 的模型为例:

显存组件 计算 相对参数量
参数(BF16) P × 2 bytes
梯度(BF16) P × 2 bytes
优化器状态(FP32) P × 2 × 4 bytes 4×(一阶矩+二阶矩)
激活值 取决于 seq_len, batch, L 通常 1-4×
总计 约 12-16× 参数量

其中 P 是模型参数量。

具体例子:7B 模型

参数:    7B × 2B  = 14 GB
梯度:    7B × 2B  = 14 GB
Optimizer State: 7B × 8B = 56 GB
激活(seq=4K, batch=1): 约 20-40 GB
-----------------------------------
总计: 约 104-124 GB

单卡 A100 (80 GB) 放不下!这就是为什么需要 ZeRO。


#2. ZeRO-1/2/3 的显存节省对比

假设有 N 张 GPU,模型参数量为 P。

ZeRO-1:切分 Optimizer States

每张卡存储:
  参数: P × 2B (完整)
  梯度: P × 2B (完整)
  优化器状态: P × 8B / N (分片)

单卡显存(不含激活)= 2P + 2P + 8P/N = 4P + 8P/N

N=8 时: 4P + P = 5P(对比原来的 12P,节省 58%)

ZeRO-2:切分 Optimizer States + Gradients

每张卡存储:
  参数: P × 2B (完整)
  梯度: P × 2B / N (分片)
  优化器状态: P × 8B / N (分片)

单卡显存(不含激活)= 2P + 2P/N + 8P/N = 2P + 10P/N

N=8 时: 2P + 1.25P = 3.25P(对比 12P,节省 73%)

ZeRO-3:切分 Parameters + Gradients + Optimizer States

每张卡存储:
  参数: P × 2B / N (分片)
  梯度: P × 2B / N (分片)
  优化器状态: P × 8B / N (分片)

单卡显存(不含激活)= 2P/N + 2P/N + 8P/N = 12P/N

N=8 时: 12P/8 = 1.5P(对比 12P,节省 87.5%)

ZeRO-Offload(CPU/NVMe):把优化器状态和/或梯度放到 CPU 内存甚至 NVMe 硬盘上:

单卡显存(仅参数 + 部分激活)≈ 2P + 激活

但代价是:CPU-GPU 通信成为瓶颈,训练速度显著下降。

#3. 数值对比表(7B 模型,8 张 A100)

方案 单卡显存(不含激活) 能否放下(80GB 卡) 备注
无 ZeRO 84 GB ❌ 不行 仅参数+梯度+状态
ZeRO-1 35 GB ✅ 可以 状态分 8 份
ZeRO-2 22.75 GB ✅ 可以 状态+梯度分 8 份
ZeRO-3 10.5 GB ✅ 可以 全部状态分 8 份
ZeRO-3 + Offload ~7 GB ✅ 很宽松 CPU 通信成瓶颈

#二、3D 并行(DP + TP + PP)的系统性讲解

#1. 为什么要 3D 并行?

单一并行策略只能解决一个问题:

  • DP:解决吞吐问题(加卡加数据),但要求单卡放下完整模型;
  • TP:解决单层矩阵太大的问题(如 hidden_dim=8192 的线性层),但要求卡间高速互联;
  • PP:解决模型层数太多的问题(如 L=80),但引入流水线气泡。

当模型大到单卡完全放不下时(如 70B+),必须组合使用


#2. 3D 并行的组合方式

典型配置:8 节点 × 8 GPU = 64 卡

最外层: DP(数据并行)
  └─ 每个 DP 组内有: TP(张量并行)+ PP(流水线并行)

例:
- TP size = 4(每 4 张卡张量并行,要求 NVLink)
- PP size = 4(每 4 个 stage 流水线并行)
- DP size = 64 / (4 × 4) = 4

总卡数 = DP × TP × PP = 4 × 4 × 4 = 64

物理拓扑对应

Node 0: GPU 0-1-2-3 (TP group 0, Stage 0) ─┐
Node 1: GPU 4-5-6-7 (TP group 1, Stage 0) ─┤ Pipeline Stage 0
                                             │
Node 2: GPU 8-9-10-11 (TP group 2, Stage 1)─┤
Node 3: GPU 12-13-14-15 (TP group 3, Stage 1)┤ Pipeline Stage 1
                                              │
...                                          │
                                             │
Node 6-7: Stage 2-3                         ┘ Pipeline Stage 2-3

#3. 各维度的通信量分析

TP(张量并行)通信

  • 每层的 AllReduce / AllGather:通信量 ≈ 2 × batch_size × seq_len × d_model × bytes
  • 发生在每层的前向和反向,频率最高
  • 因此 TP 要求卡间有最高速互联(NVLink/NVSwitch),通常限制在单机内。

PP(流水线并行)通信

  • 每个 stage 边界传输激活值:通信量 ≈ batch_size × seq_len × d_model × bytes
  • 频率低(每 stage 一次),但传输的数据块大;
  • 可以用 IB/RoCE 跨机,不需要 NVLink 级别带宽。

DP(数据并行)通信

  • 梯度同步 AllReduce:通信量 ≈ 模型参数量 × bytes
  • 每 step 一次,通信量大但频率低;
  • 在 ZeRO-1/2 中,AllReduce 可以 overlap;在 ZeRO-3/FSDP 中,用 reduce-scatter + all-gather。

通信量对比表(以 batch=4, seq=2K, d=4096, BF16 为例):

并行维度 单次通信量 每 step 次数 对互联要求
TP ~134 MB 2×L 极高(NVLink)
PP ~67 MB 2×PP 中(IB/RoCE)
DP (AllReduce) ~14 GB (7B model) 1 中(IB/RoCE)

#4. 混合并行的决策流程

Step 1: 模型能否放入单卡?
├─ 是 ─> 只用 DP(最简单)
└─ 否 ─> 继续 Step 2

Step 2: 单层矩阵能否放入单卡?
├─ 否 ─> 先用 TP 切单层(TP size = 2/4/8)
└─ 是 ─> 继续 Step 3

Step 3: 模型层数能否放入单卡(在 TP 后)?
├─ 否 ─> 加 PP 切层(PP size = 2/4/8)
└─ 是 ─> 继续 Step 4

Step 4: 还需要更多吞吐?
├─ 是 ─> 加 DP 扩数据并行
└─ 否 ─> 完成

Step 5: 显存还是不够?
├─ 是 ─> 加 ZeRO-2/3 或 FSDP 切状态
└─ 否 ─> 完成

典型模型配置示例

模型大小 TP PP DP ZeRO 总卡数 说明
7B 1 1 8 ZeRO-1 8 单层放得下,纯 DP
13B 1 1 8 ZeRO-2 8 状态更大,切梯度
30B 2 4 4 ZeRO-1 32 TP=2 切单层,PP=4 切层
70B 4 4 4 ZeRO-1 64 TP=4 需要 NVLink
175B 8 8 8 ZeRO-1 512 3D 并行全上
1T+ 8 16+ 16+ ZeRO-3/FSDP 2048+ 必须全分片

#三、FSDP vs ZeRO-3:工程实现差异

维度 DeepSpeed ZeRO-3 PyTorch FSDP
分片粒度 参数级(更细) 模块级(如 Transformer Block)
通信模式 all-gather + reduce-scatter all-gather + reduce-scatter
reshard 策略 step 后立刻释放 可配置(full shard / sharded grad / hybrid)
offload 支持 原生支持 CPU/NVMe offload 需配合 offloading 扩展
使用方式 配置文件 + DeepSpeed engine PyTorch 原生 API
性能 大模型优化更深入 与 PyTorch 生态更无缝

面试回答要点:两者本质都是"全分片"思想,差异主要在框架生态和实现粒度。真正影响性能的不是选哪个框架,而是:

  1. 分片粒度是否和模型结构匹配;
  2. 通信是否能和计算有效重叠;
  3. 是否避免了不必要的全量参数聚合。

#四、Pipeline Parallel 的 Bubble 定量分析

Pipeline Bubble 是 PP 的核心性能损失来源。

假设

  • PP size = p(p 个 stage)
  • 每个 stage 处理一个 micro-batch 的时间 = t_f(前向)+ t_b(反向)
  • 为了简化,假设 t_f = t_b = t
  • micro-batch 数量 = m

Bubble 时间(GPipe 风格,all-forward-then-all-backward):

Bubble = (p - 1) × (t_f + t_b) = 2(p - 1)t

总时间 = m × p × t + bubble = mp t + 2(p-1)t
Pipeline 效率 = (mpt) / (mpt + 2(p-1)t) = mp / (mp + 2p - 2)

当 m >> p 时,效率趋近于 1

具体数值

PP=4, micro-batch=8:
  效率 = 32 / (32 + 6) = 32/38 ≈ 84%

PP=8, micro-batch=8:
  效率 = 64 / (64 + 14) = 64/78 ≈ 82%

PP=8, micro-batch=32:
  效率 = 256 / (256 + 14) = 256/270 ≈ 95%

关键结论:micro-batch 数量越大,bubble 占比越小。但 micro-batch 不能无限增大(受显存限制)。

面试追问:怎么缓解 bubble?

  1. 增加 micro-batch 数量(显存允许时);
  2. 使用 interleaved pipeline scheduling(如 PipeDream、1F1B);
  3. 通信计算重叠(overlap forward/backward 与梯度同步)。

#五、训练系统排障:从症状到根因

症状 可能根因 排查方法
Loss 突然变 NaN 学习率过大 / 梯度爆炸 / 数值溢出 检查梯度范数、loss scaling 回退次数
多卡 loss 不一致 通信错误 / 数据加载不一致 检查 all-reduce 一致性、数据分片
加卡后加速比低 通信瓶颈 / bubble / 负载不均 profiling 通信时间、检查各卡利用率
显存爆了 激活过大 / batch 太大 / 未开 checkpointing 逐层检查显存峰值、减小 seq_len/batch
ZeRO-3 特别慢 all-gather 未重叠 / 分片太细 检查 overlap 配置、增大 bucket size
PP 某 stage 特别慢 负载不均 / 某层计算太重 检查各 stage 的 FLOPs 是否均衡