大模型训练的最佳 batch size 通常在 1M-4M tokens。8 卡 Atlas 800I A2 的总显存 512GBbatch size 能开到 50 万 tokens 左右——不够。梯度累积让你用小 batch 跑多次前向累积梯度后一次性更新等效于大 batch 训练。梯度累积的原理标准训练batch4 x1 → forward → backward → step x2 → forward → backward → step x3 → forward → backward → step x4 → forward → backward → step 梯度累积accumulation_steps4 x1 → forward → backward (不step累积梯度) x2 → forward → backward (累积梯度) x3 → forward → backward (累积梯度) x4 → forward → backward (累积梯度) → step4 次 micro-batch 累积后做 1 次参数更新等效 batch size 4 × micro_batch_size。torch_npu 实现fromtorch_npu.npuimportamp accumulation_steps4modelMyModel().to(npu:0)optimizertorch.optim.AdamW(model.parameters())scaleramp.GradScaler()fori,datainenumerate(dataloader):withamp.autocast(dtypetorch.bfloat16):lossmodel(data)/accumulation_steps# 注意loss 要除以累积步数scaler.scale(loss).backward()if(i1)%accumulation_steps0:scaler.step(optimizer)scaler.update()optimizer.zero_grad()两个关键点loss 除以 accumulation_steps等效于梯度平均跟大 batch 训练的梯度一致zero_grad 在 step 之后累积期间不清零梯度显存开销梯度累积的显存开销跟不累积的区别标准训练batch4 激活4 × seq × hidden × 2 bytes × 32 层 梯度参数量 × 2 bytes 梯度累积micro_batch1, steps4 激活1 × seq × hidden × 2 bytes × 32 层 ← 只有 1/4 梯度参数量 × 2 bytes ← 一样 累积梯度参数量 × 2 bytes ← 额外一份激活显存减少 75%只多了累积梯度的开销约等于参数量大小。Llama2-7B 参数 14GB额外 14GB 的累积梯度换来 3 倍的激活显存节省。跟 DDP 的配合分布式数据并行DDP 梯度累积importtorch.distributedasdist dist.init_process_group(backendhccl)modelDDP(model)fori,datainenumerate(dataloader):withamp.autocast(dtypetorch.bfloat16):lossmodel(data)/accumulation_steps loss.backward()if(i1)%accumulation_steps0:# DDP 自动在 backward 时做 All-Reduce# 累积步结束时梯度已经是跨卡平均后的optimizer.step()optimizer.zero_grad()DDP 的 All-Reduce 在每次 backward 时触发不是在 step 时。这意味着累积 4 步会做 4 次 All-Reduce每次传当前 micro-batch 的梯度而不是 1 次 All-Reduce传总梯度。通信量不变4 次小量 1 次大量但延迟多了 4 次通信启动开销。在 HCCS 带宽高的单机内影响可忽略。跟 MC2 的配合MoE 训练用 MC2 做通算融合梯度累积不影响 MC2 的执行。MC2 通信的是 token 路由数据不是梯度——每次 micro-batch 的 token 路由独立不跨步累积。精度影响梯度累积的梯度值跟大 batch 训练在数学上等价float32 下。但 float16/bf16 下有微小差异大 batch一次计算 4 个 micro-batch 的激活然后一次性算梯度。梯度只在最后一步做 bf16→fp32 转换。梯度累积4 次独立计算梯度每次做 bf16→fp32 转换。4 次 fp32 梯度累加。差异约 1e-6在训练的统计噪声范围内。梯度累积是大模型训练的标准操作——不是什么高级技巧是显存不够时的必要手段。实现简单记得 loss 除以步数就行。跟 DDP 和 MC2 都兼容。仓库在这里https://atomgit.com/cann/torch_npu