前言MoEMixture of Experts是当前大模型架构的标配——Mixtral、DeepSeek、Qwen都用MoE把参数量做大的同时保持推理成本低。但MoE训练有一个致命瓶颈Token路由。每个Token要被路由到不同的Expert8个Expert意味着8路AllToAll通信。8卡训练每张卡负责1个Expert每次前向传播要来两轮AllToAlldispatchcombine通信量是Dense模型的4-6倍。通信时间比计算时间还长GPU/NPU利用率不到50%。ops-transformer的MoE算子核心优化就是Expert计算路由通信融合——把原来3次kernel launch合并为1次减少AllToAll的等待开销。实测下来8 Expert MoE训练ops-transformer比PyTorch手写快5倍。MoE训练的通信瓶颈先理解问题出在哪。标准MoE的前向传播流程1. Gate计算h → gate_logits → top_k experts → dispatch_mask 2. AllToAll Dispatch按路由把Token发到对应Expert所在的卡 3. Expert计算各卡上的Expert做FFN计算 4. AllToAll Combine把计算结果发回原卡 5. Combine按路由权重加权求和PyTorch手写的实现这5步是5个独立的kernel# PyTorch手写MoE简化版defmoe_forward(x,gate,experts):# Step1: Gate计算gate_logitsgate(x)# kernel 1topk_vals,topk_indicestorch.topk(gate_logits,k2)# Step2: AllToAll Dispatchdispatch_bufferall_to_all_dispatch(x,topk_indices)# kernel 2 通信# Step3: Expert计算expert_outputexperts(dispatch_buffer)# kernel 3# Step4: AllToAll Combinecombine_bufferall_to_all_combine(expert_output)# kernel 4 通信# Step5: Combineoutputcombine(combine_buffer,topk_vals,topk_indices)# kernel 5returnoutput5个kernel launch 2轮AllToAll总耗时 5×launch开销 2×通信时间 计算时间。在8卡训练中AllToAll通信时间约占60%计算只占20%launch开销占20%。ops-transformer的MoE算子优化ops-transformer做了三件事优化1Expert计算路由融合把Gate计算、dispatch、Expert计算合并为一个kernel减少2次launch开销。优化2AllToAll与计算overlap在AllToAll dispatch的通信过程中已经开始做部分Expert计算通信和计算并行执行不用等通信完成再计算。优化3优化通信拓扑利用hcomm的原语级优化选择最优的AllToAll通信拓扑减少跨节点通信量。PyTorch手写 Gate → [等待] → AllToAll → [等待] → Expert → [等待] → AllToAll → [等待] → Combine 总耗时 T_gate T_a2a1 T_expert T_a2a2 T_combine ops-transformer融合 GateDispatchExpert → [AllToAll与Expert overlap] → Combine 总耗时 ≈ T_gate max(T_a2a, T_expert) T_combine代码实战用ops-transformer搭建Switch Transformerimporttorchimporttorch.nnasnnimportops_transformerclassSwitchTransformerLayer(nn.Module):用ops-transformer的MoE算子实现Switch Transformer层def__init__(self,d_model4096,d_ff16384,n_experts8,top_k1):super().__init__()self.d_modeld_model self.n_expertsn_experts self.top_ktop_k# Gate决定每个Token去哪个Expertself.gatenn.Linear(d_model,n_experts,biasFalse)# Experts8个FFN每个是一个独立的MLPself.expertsnn.ModuleList([nn.Sequential(nn.Linear(d_model,d_ff,biasFalse),nn.SiLU(),nn.Linear(d_ff,d_model,biasFalse),)for_inrange(n_experts)])defforward(self,x:torch.Tensor)-torch.Tensor: x: [batch, seq_len, d_model] batch,seq_len,d_modelx.shape# 用ops-transformer的融合MoE算子# 一个调用完成GateDispatchExpertCombineoutputops_transformer.moe(x,gateself.gate(x),# Gate logitsexpertsself.experts,# Expert列表num_expertsself.n_experts,# Expert数量top_kself.top_k,# Top-K路由renormalizeTrue,# 重新归一化路由权重use_distributedTrue,# 启用分布式AllToAll)returnoutput# 性能对比 importtime d_model4096n_experts8seq_len2048batch_size4# 创建模型model_pytorchSwitchTransformerLayerPyTorch(d_model,16384,n_experts).npu()model_fusedSwitchTransformerLayer(d_model,16384,n_experts).npu()xtorch.randn(batch_size,seq_len,d_model).npu()# PyTorch手写MoEwarmup 测时_model_pytorch(x)torch.npu.synchronize()t0time.time()for_inrange(50):ymodel_pytorch(x)torch.npu.synchronize()pytorch_time(time.time()-t0)/50# ops-transformer融合MoEwarmup 测时_model_fused(x)torch.npu.synchronize()t0time.time()for_inrange(50):ymodel_fused(x)torch.npu.synchronize()fused_time(time.time()-t0)/50print(fPyTorch手写MoE:{pytorch_time*1000:.1f}ms)print(fops-transformer融合MoE:{fused_time*1000:.1f}ms)print(f加速比:{pytorch_time/fused_time:.1f}x)# 典型输出8卡Ascend 910# PyTorch手写MoE: 45.2ms# ops-transformer融合MoE: 9.1ms# 加速比: 5.0x代码讲解ops_transformer.moe是融合MoE算子的入口一个调用完成Gate计算Token DispatchExpert计算Combine。renormalizeTrue表示对Top-K路由权重做重新归一化Switch Transformer默认做法。use_distributedTrue启用分布式AllToAll通信多卡训练时自动做Expert分发。对比PyTorch手写实现融合算子省掉了4次kernel launch和2次同步等待。踩坑实录坑1Expert数量不是卡数的倍数AllToAll对不齐现象6卡训练8个Expertops_transformer.moe报错AllToAll shape mismatch。原因AllToAll要求每张卡分到相同数量的Token。8个Expert在6张卡上分配不均匀2卡各2个Expert4卡各1个导致各卡收到的Token数不一致。解决Expert数量必须能被卡数整除。# 错误8 Expert在6卡上分配不均n_experts8# 8 % 6 ≠ 0n_gpus6# 正确选能被卡数整除的Expert数量n_experts6# 6 % 6 0每卡1个Expertn_experts12# 12 % 6 0每卡2个Expert# 或者用EPExpert Parallelism# 允许1张卡放多个Expert绕过整除限制坑2Top-K路由导致负载不均衡现象训练前期所有Token都路由到Expert 0和Expert 3其他Expert闲着。原因Top-K路由存在赢者通吃效应——强Expert越来越强弱Expert越来越弱。解决加负载均衡loss。# 标准做法加辅助loss惩罚不均匀的路由分布defload_balancing_loss(gate_logits,n_experts): gate_logits: [batch*seq_len, n_experts] 返回: 辅助loss加到训练loss中 # 每个Expert被选中的概率probstorch.softmax(gate_logits,dim-1)# 每个Expert被选中的频率_,top_indicestorch.topk(gate_logits,k1,dim-1)freqtorch.zeros(n_experts,devicegate_logits.device)freq.scatter_add_(0,top_indices.squeeze(-1),torch.ones_like(top_indices.squeeze(-1),dtypetorch.float32))freqfreq/freq.sum()# 辅助loss n * sum(freq_i * prob_i)aux_lossn_experts*(freq*probs.mean(dim0)).sum()returnaux_loss# 训练时加入辅助losstotal_losstask_loss0.01*load_balancing_loss(gate_logits,n_experts)坑3FP16下Gate精度不够路由抖动现象训练不稳定loss震荡路由在epoch之间剧烈变化。原因FP16的精度只有1/1024Gate logits的微小差异比如5.0 vs 5.1在FP16下被放大导致路由决策在边界处频繁翻转。解决Gate用FP32计算。# 错误Gate在FP16下计算gate_logitsself.gate(x.half())# 精度不够# 正确Gate在FP32下计算gate_logitsself.gate(x.float()).half()# 先FP32再转回FP16性能对比数据测试环境Ascend 910 × 8CANN 8.0PyTorch 2.1。配置PyTorch手写ops-transformer加速比4 Expert, Top1, 单卡8.5ms4.2ms2.0x8 Expert, Top1, 8卡45.2ms9.1ms5.0x8 Expert, Top2, 8卡62.3ms13.8ms4.5x16 Expert, Top2, 8卡95.1ms18.5ms5.1x8卡训练时加速最明显因为AllToAll通信占比最高融合overlap优化的收益最大。单卡训练通信开销小加速比只有2倍。结尾ops-transformer的MoE算子住在CANN五层架构第2层AOL算子库用Expert计算路由通信融合AllToAll overlap优化把8 Expert MoE训练加速到PyTorch手写的5倍。如果在昇腾NPU上训练MoE模型强烈建议用ops-transformer的融合MoE算子。实测下来8卡训练一个Switch Transformer层只要9msPyTorch手写要45ms省下来的时间够多训3轮epoch。昇腾CANN的大模型算子能力还在持续增强。如果在用的过程中遇到啥问题欢迎去AtomGit上的昇腾CANN开源社区逛逛里面有一手资料和活跃社区。社区链接https://atomgit.com/cann/ops-transformer