当前位置: 首页 > news >正文

CANN-昇腾NPU梯度累积-显存不够时怎么模拟大batch训练

大模型训练的最佳 batch size 通常在 1M-4M tokens。8 卡 Atlas 800I A2 的总显存 512GBbatch size 能开到 50 万 tokens 左右——不够。梯度累积让你用小 batch 跑多次前向累积梯度后一次性更新等效于大 batch 训练。梯度累积的原理标准训练batch4 x1 → forward → backward → step x2 → forward → backward → step x3 → forward → backward → step x4 → forward → backward → step 梯度累积accumulation_steps4 x1 → forward → backward (不step累积梯度) x2 → forward → backward (累积梯度) x3 → forward → backward (累积梯度) x4 → forward → backward (累积梯度) → step4 次 micro-batch 累积后做 1 次参数更新等效 batch size 4 × micro_batch_size。torch_npu 实现fromtorch_npu.npuimportamp accumulation_steps4modelMyModel().to(npu:0)optimizertorch.optim.AdamW(model.parameters())scaleramp.GradScaler()fori,datainenumerate(dataloader):withamp.autocast(dtypetorch.bfloat16):lossmodel(data)/accumulation_steps# 注意loss 要除以累积步数scaler.scale(loss).backward()if(i1)%accumulation_steps0:scaler.step(optimizer)scaler.update()optimizer.zero_grad()两个关键点loss 除以 accumulation_steps等效于梯度平均跟大 batch 训练的梯度一致zero_grad 在 step 之后累积期间不清零梯度显存开销梯度累积的显存开销跟不累积的区别标准训练batch4 激活4 × seq × hidden × 2 bytes × 32 层 梯度参数量 × 2 bytes 梯度累积micro_batch1, steps4 激活1 × seq × hidden × 2 bytes × 32 层 ← 只有 1/4 梯度参数量 × 2 bytes ← 一样 累积梯度参数量 × 2 bytes ← 额外一份激活显存减少 75%只多了累积梯度的开销约等于参数量大小。Llama2-7B 参数 14GB额外 14GB 的累积梯度换来 3 倍的激活显存节省。跟 DDP 的配合分布式数据并行DDP 梯度累积importtorch.distributedasdist dist.init_process_group(backendhccl)modelDDP(model)fori,datainenumerate(dataloader):withamp.autocast(dtypetorch.bfloat16):lossmodel(data)/accumulation_steps loss.backward()if(i1)%accumulation_steps0:# DDP 自动在 backward 时做 All-Reduce# 累积步结束时梯度已经是跨卡平均后的optimizer.step()optimizer.zero_grad()DDP 的 All-Reduce 在每次 backward 时触发不是在 step 时。这意味着累积 4 步会做 4 次 All-Reduce每次传当前 micro-batch 的梯度而不是 1 次 All-Reduce传总梯度。通信量不变4 次小量 1 次大量但延迟多了 4 次通信启动开销。在 HCCS 带宽高的单机内影响可忽略。跟 MC2 的配合MoE 训练用 MC2 做通算融合梯度累积不影响 MC2 的执行。MC2 通信的是 token 路由数据不是梯度——每次 micro-batch 的 token 路由独立不跨步累积。精度影响梯度累积的梯度值跟大 batch 训练在数学上等价float32 下。但 float16/bf16 下有微小差异大 batch一次计算 4 个 micro-batch 的激活然后一次性算梯度。梯度只在最后一步做 bf16→fp32 转换。梯度累积4 次独立计算梯度每次做 bf16→fp32 转换。4 次 fp32 梯度累加。差异约 1e-6在训练的统计噪声范围内。梯度累积是大模型训练的标准操作——不是什么高级技巧是显存不够时的必要手段。实现简单记得 loss 除以步数就行。跟 DDP 和 MC2 都兼容。仓库在这里https://atomgit.com/cann/torch_npu
http://www.zskr.cn/news/1356687.html

相关文章:

  • 揭秘银泰百货卡回收方法!线上回收教你快速变现 - 团团收购物卡回收
  • 2026年常德黄金回收避坑指南 福运来等六家靠谱实测 - 黄金回收
  • 2026年AI论文写作软件测评:5款神器从选题到格式全流程护航
  • 10分钟掌握Markdown Here:浏览器扩展一键转换Markdown到富文本
  • 让AI读书系列——Claude的读后感
  • 合金低阻贴片电阻:攻克电流采样精度、温漂与可靠性挑战
  • OpenPilot智能驾驶系统:如何实现300+车型的自动驾驶辅助?
  • 2026年热门声音转换成文字工具实测对比,多场景准确率比拼,低调黑马才是真王者
  • 计算机视觉学习全攻略:从核心概念到深度学习实战
  • 2026国产在线PH计十大品牌排行榜|市政污水与工业水处理实测选型指南 - 仪表品牌榜
  • 终极MQTT客户端快速入门指南:5分钟掌握跨平台物联网通信
  • 颠覆性自动驾驶革命:openpilot如何重塑驾驶辅助系统的未来
  • 4Gb密度+256M×16组织:K4B4G1646E-BCNB的DDR3-2133内存颗粒参数解析
  • 2026宁波公司注册代办机构优选推荐,本地十大正规工商落地服务口碑榜单 - 品牌智鉴榜
  • 天虹购物卡回收注意事项:最全的使用范围与心得分享 - 团团收购物卡回收
  • 如何用puppeteer-extra-plugin-stealth突破网站反爬虫检测:18种规避技术深度解析
  • 终极指南:在Windows上无缝安装安卓应用的免费神器
  • 嵌入式RTOS核心概念:任务、线程与进程的区别与应用
  • 智能穿戴设备快速开发:从概念到原型的低代码平台实践
  • 嵌入式系统如何成为医疗设备核心引擎:从需求到落地的全流程解析
  • 2026年华东蒸发器源头厂家推荐:蒸发器 / MVR 蒸发器 / 多效蒸发器 / 高盐废水蒸发器 / 选择指南 - 海棠依旧大
  • Focus-DETR:基于前景特征选择的高效目标检测模型解析
  • 五分钟搞定Nodejs项目对接多模型API的配置教程
  • 0欧电阻:电路设计中的瑞士军刀,从原理到实战全解析
  • GPU加速多波束相控阵雷达:异构计算架构与工程实践
  • [实战指南] 2026年制造业MSA测量系统分析:核心方法论与数字化实施路径
  • 高危作业零穿戴管控,无感定位彻底规避UWB电气安全隐患
  • 【独家首发】保险业首个AI Agent成熟度评估模型(5级量化标准+12项KPI基线数据)
  • HR流程自动化卡点全诊断,从招聘到离职的12个Lindy可配置节点及失效预警清单
  • 对比直接调用厂商API,使用Taotoken聚合端在容灾方面的体验