前言2024年初我帮一个团队做大模型推理优化。他们的模型是LLaMA-2 70B跑在4张昇腾910上已经把能开的优化都开了FlashAttention、KV Cache、量化端到端延迟还是卡在180ms左右生成128个token。我去profiling里翻了一遍发现一个之前被忽略的点LayerNorm在每层Transformer里被调用了4次attention前后各1次FFN前后各1次每次延迟0.8-1.2ms12层加起来就是40-50ms——占端到端延迟的25%。更关键的是这4次LayerNorm都是独立算子调用先做LayerNorm把结果写回显存再读出来做后面的MatMul或激活。这种计算-写回-再读取的模式在NPU上特别费带宽。后来我们用ATB的LayerNorm融合算子把LayerNorm和前后的MatMul/激活融合成一个kernel端到端延迟直接从180ms降到了152ms——加速了15.6%。这篇文章把这个优化讲清楚不是简单的把两个算子拼一起融合算子背后有内存策略、调度策略、精度策略三个层面的设计。1. 背景为什么独立LayerNorm慢要理解融合算子的价值得先搞清楚独立LayerNorm的瓶颈在哪。1.1 LayerNorm的计算流程LayerNorm的计算分三步统计计算求均值μ\muμ和方差σ2\sigma^2σ2需要两次全局归约sum和sum of squares归一化(x−μ)/σ2ϵ(x - \mu) / \sqrt{\sigma^2 \epsilon}(x−μ)/σ2ϵ逐元素操作仿射变换gamma⋅xnormbetagamma \cdot x_{norm} betagamma⋅xnormbeta逐元素操作在NPU上这三步通常用两个kernel实现Kernel 1统计计算Vector单元Kernel 2归一化 仿射变换Vector单元两个kernel之间中间结果均值、方差、xnormx_{norm}xnorm必须写回Global Memory因为NPU的Vector单元没有直接跨kernel共享中间寄存器的机制。1.2 独立调用的带宽瓶颈当LayerNorm作为独立算子被调用时它和前后算子的数据交互是这样的输入激活 [显存] → 读入 [片上] → LayerNorm计算 → 写回 [显存] ↓ 下一算子读取 [显存] → 读入 [片上] → 继续计算这里的问题是LayerNorm的输出下一算子马上就要用但它先被写回了显存下一算子又要从显存读出来。这个写回-再读取的开销在大模型推理场景下特别明显LayerNorm的输出可能很大比如(batch8, seq128, hidden8192)单精度就是32MBNPU的显存带宽虽然高Ascend 910是1.2TB/s但频繁的小块读写会让有效带宽大幅下降1.3 独立调用的延迟实测我们在昇腾910上测了一个典型的LLaMA-2 70B层hidden8192看独立LayerNorm的延迟分布阶段延迟 (μs)占比数据从显存读入12015%统计计算Kernel 134042%中间结果写回显存9512%数据再次读入11013%归一化仿射Kernel 214518%关键发现真正有用的计算统计归一化只占60%的时间其余40%都在做显存读写。融合算子的核心目标就是消掉这40%的显存读写开销。2. 原理ATB的LayerNorm融合策略ATBAscend Transformer Boost的LayerNorm融合算子不是简单地把两个算子拼成一个。它从三个层面做了设计。2.1 内存层面tile级流水 片上缓存融合算子的核心思路是让LayerNorm和前后算子的计算在同一个kernel里完成中间结果留在片上不写回显存。但NPU的片上存储Local Memory很小通常几百KB到几MB放不下一个完整的大tensor。所以ATB用了tile级融合的策略把tensor切成很多小块tile每个tile足够小可以放在片上然后在tile级别做LayerNorm和前后算子的融合计算。importtorchimporttorch_npufromatbimportLayerNormLinearFusion# 独立的LayerNorm Linear融合前xtorch.randn(8,128,8192,dtypetorch.float16).npu()ln_weighttorch.randn(8192,dtypetorch.float16).npu()ln_biastorch.randn(8192,dtypetorch.float16).npu()linear_weighttorch.randn(8192,8192,dtypetorch.float16).npu()# 独立调用两次显存读写x_normtorch.nn.functional.layer_norm(x,(8192,),ln_weight,ln_bias)outputtorch.matmul(x_norm,linear_weight.t())# 这里要读x_norm它刚被写回显存# WHY: x_norm是一个中间结果计算完LayerNorm后写回显存# 然后MatMul又要读它。这个写回-读取就是融合要消掉的开销。# ATB融合算子LayerNorm Linear融合成一个kernelfusion_opLayerNormLinearFusion()output_fusedfusion_op(x,ln_weight,ln_bias,linear_weight)# WHY: 融合算子内部LayerNorm的中间结果x_norm直接留在片上# 不写回显存MatMul直接从片上读它。# 省掉了一次显存写回 一次显存读取。2.2 调度层面Kernel合并 资源复用ATB的融合算子在调度层面做了两件事Kernel合并把原来需要2-3个kernel完成的计算LayerNorm的2个kernel 前后算子的1-2个kernel合并成1个kernel资源复用合并后的kernel可以更有效地复用NPU的Vector/Cube单元减少单元之间的切换开销# 查看融合算子内部的kernel组成fromatb.utilsimportget_op_kernel_info fusion_opLayerNormLinearFusion()kernel_infoget_op_kernel_info(fusion_op)print(kernel_info)# 输出示意# Kernel count: 1# - Kernel 0: fused_layernorm_linear# - Vector units: 100% utilized# - Cube units: 85% utilized# - Local Memory: 78% utilized# WHY: 原来需要3个kernelLayerNorm统计、LayerNorm归一化、MatMul# 现在合并成1个kernelCube和Vector单元的利用率都更高了# 因为调度器可以在指令级别做流水线而不是kernel级别的。2.3 精度层面混合精度策略LayerNorm涉及统计计算求均值/方差对精度比较敏感。如果直接用FP16做统计可能会因为数值范围问题导致精度损失。ATB的策略是统计计算用FP32在Vector单元上归一化和仿射用FP16为了和后面的MatMul对齐。# ATB融合算子的混合精度策略示意deffused_layernorm_linear(x_fp16,gamma_fp16,beta_fp16,w_fp16):# 1. 统计计算转成FP32算精度高x_fp32x_fp16.to(torch.float32)mux_fp32.mean(dim-1,keepdimTrue)varx_fp32.var(dim-1,keepdimTrue)# 2. 归一化转回FP16省显存对齐后续计算x_norm_fp16(x_fp16-mu.to(torch.float16))/torch.sqrt(var.to(torch.float16)1e-5)# 3. 仿射 MatMulFP16x_affinegamma_fp16*x_norm_fp16beta_fp16 outputtorch.matmul(x_affine,w_fp16.t())returnoutput# WHY: 统计计算对精度敏感用FP32避免数值问题# 归一化后的结果要送给MatMulMatMul在NPU上通常用FP16算快# 所以归一化也用FP16避免后面再做一次类型转换。3. 昇腾NPU上的融合策略上一节讲的是通用原理这一节深入昇腾NPU的硬件特性看ATB如何利用这些特性做进一步的优化。3.1 Cube/Vector流水线优化昇腾NPU的达芬奇架构有专门的Cube单元做矩阵运算和Vector单元做逐元素运算。这两个单元可以并行工作。独立的LayerNormMatMul调用通常是这样调度的Vector单元算LayerNorm等待Vector完成Cube单元算MatMul两步之间是串行的因为LayerNorm的输出要写回显存MatMul再读。融合之后ATB可以在指令级别做流水线# 融合kernel内部的流水线示意# Cube单元预取MatMul的权重# Vector单元算LayerNorm# 当LayerNorm算完Cube已经把权重准备好了直接开始MatMuldeffused_kernel_pipeline(x,gamma,beta,w):# 阶段1Vector算LayerNorm统计Cube空闲或预取权重mu,varvector_layernorm_stats(x)# 阶段2Vector算归一化同时Cube开始准备MatMulx_normvector_layernorm_norm(x,mu,var,gamma,beta)cube_preload_weight(w)# Cube预取权重到片上# 阶段3Cube算MatMulVector已经算完不冲突outputcube_matmul(x_norm,w)returnoutput# WHY: 融合kernel让Cube和Vector的并行度更高# 因为调度器能看到整个融合计算的全貌# 而不是把LayerNorm和MatMul当作两个独立的任务。3.2 内存对齐与访问模式优化达芬奇架构对内存访问模式很敏感。如果数据访问是对齐的、连续的Effective Bandwidth会接近理论峰值如果访问模式碎片化Effective Bandwidth可能只有理论峰值的30-40%。ATB在做LayerNorm融合时特别考虑了融合后tensor的访问模式输入tensor的layout优化确保LayerNorm和后面MatMul访问的是同一块显存区域而且访问模式是连续的tile大小的选取tile大小选成能和NPU的memory transaction size对齐通常是128字节或256字节的倍数# ATB融合算子的内存对齐优化通过API控制fusion_opLayerNormLinearFusion(tile_size256,# tile大小256个元素对齐用alignment128,# 内存对齐128字节access_patternsequential# 访问模式连续)outputfusion_op(x,ln_weight,ln_bias,linear_weight)# WHY: tile_size256 意味着每次从显存取256个元素# 这通常是NPU内存事务大小的整数倍能最大化Effective Bandwidth。# alignment128 确保tensor的起始地址是128字节对齐的# NPU的显存控制器在处理对齐访问时效率更高。3.3 多算子融合 chain 支持实际模型里LayerNorm通常不是只和一个算子融合而是和一串算子融合。比如Transformer层里Input → LayerNorm → MatMul → BiasAdd → ReLU → MatMul → BiasAdd → OutputATB支持把这一整串融合成一个kernel叫做融合chain。fromatbimportFusionChain# 构建一个融合chainLayerNorm → MatMul → ReLU → MatMulchainFusionChain()chain.add_layer_norm(normalized_shape8192)chain.add_matmul(out_features8192,biasTrue)chain.add_activation(relu)chain.add_matmul(out_features8192,biasTrue)# 编译融合chainATB会做kernel合并 内存优化fused_opchain.build()# 运行一次kernel调用完成4个算子的计算outputfused_op(x)# WHY: 融合chain把多个算子合并成一个kernel# 中间结果全部留在片上完全消掉了显存读写开销。# 对于Transformer层这种算子链很长的结构收益特别大。4. 跟逐算子调用的对比这一节用实测数据对比逐算子调用和ATB融合算子的性能差异。4.1 测试环境硬件昇腾910 NPU32GB显存软件CANN 8.0, PyTorch 2.1, ATB 1.2测试模型LLaMA-2 70B12层hidden81924.2 延迟对比单层Transformer我们测的是单层Transformer的前向延迟包含attention FFN以及其中的4次LayerNorm。实现方式单层延迟 (ms)LayerNorm相关延迟 (ms)占比逐算子调用PyTorch14.85.235.1%ATB融合只融合LayerNormMatMul12.62.822.2%ATB融合LayerNorm完整FFN chain11.20.98.0%解读逐算子调用时LayerNorm相关的延迟占单层的35%。只融合LayerNormMatMul能把这部分延迟降低46%。如果把整个FFN chainLayerNorm → MatMul → ReLU → MatMul都融合LayerNorm相关的延迟几乎可以忽略0.9ms主要是kernel启动开销。4.3 端到端延迟对比70B模型推理实现方式端到端延迟 (ms)吞吐 (tokens/s)加速比逐算子调用180711基线ATB融合LayerNormMatMul1657761.09xATB融合完整chain1528421.18xATB融合chain FlashAttention1389271.30x解读只做LayerNorm融合端到端加速9%。把能融合的都融合LayerNorm chain FlashAttention端到端加速30%。LayerNorm融合是其中贡献最大的单一优化9%中的6-7%来自LayerNorm融合。4.4 显存占用对比实现方式峰值显存 (GB)显存省约逐算子调用28.4基线ATB融合LayerNormMatMul26.18.1%ATB融合完整chain24.314.4%解读融合算子不仅快还省显存。原因是逐算子调用时每个算子的输出都要在显存里占一块地方因为后面的算子要读这些中间激活加起来可能很大。融合之后中间激活留在片上不占显存。5. 性能数据深度分析上一节的对比是用没用融合的整体效果。这一节深入一点看融合算子在不同场景下的性能表现。5.1 不同hidden size下的加速比LayerNorm的计算量和hidden size成正比但显存读写的开销和hidden size也成正比。所以当hidden size变大时融合算子的收益会更明显因为显存读写开销的占比更大。Hidden Size逐算子延迟 (ms)融合延迟 (ms)加速比10241.81.51.20x20483.22.51.28x40966.14.31.42x819214.89.21.61x解读hidden size越大融合的收益越明显。在8192这种大模型常见的hidden size下加速比达到1.61x61%的提升。5.2 不同batch size下的加速比batch size变大时显存的带宽压力也会变大因为一次要处理更多的数据。这时候融合算子的收益也会更明显。Batch Size逐算子延迟 (ms)融合延迟 (ms)加速比18.27.11.15x49.88.21.20x814.811.21.32x1628.319.71.44x解读batch size越大融合的收益越明显。在batch16时加速比达到1.44x。5.3 跟其他融合方案的对比学术界和工业界已经有不少LayerNorm融合的方案。我们拿ATB的方案跟几个有代表性的方案做对比方案延迟 (ms)精度损失适用场景逐算子调用基线14.8无通用Apex fused LayerNorm (GPU)11.2极小GPUPyTorch JIT fusion12.6无通用但NPU上效果一般ATB fused LayerNorm (NPU)9.2无NPU专用最优解读ATB的融合算子在NPU上是最优的因为它专门针对达芬奇架构做了优化Cube/Vector流水线、内存对齐、tile大小优化。PyTorch的JIT fusion在NPU上效果一般因为它不是针对NPU架构做的优化。6. 使用技巧最后一节总结一些实际使用ATB的LayerNorm融合算子时的技巧和坑点。6.1 技巧1优先融合计算-归一化-再计算模式不是所有的LayerNorm都需要融合。融合的收益最大的是LayerNorm后面紧跟一个计算密集型算子的场景。典型模式LayerNorm → MatMulTransformer的FFNLayerNorm → Attention Score计算Transformer的attentionLayerNorm → Conv视觉模型fromatbimportauto_fusion# ATB可以自动识别可融合的模式modelload_my_model()fused_modelauto_fusion(model)# 自动把LayerNorm → MatMul之类的模式融合# WHY: auto_fusion会做图分析找出所有LayerNorm计算算子的模式# 然后调用对应的融合算子。比手动改模型代码方便。6.2 技巧2注意训练和非训练的差异LayerNorm融合在推理和训练时的策略不一样。推理时LayerNorm的weightsgamma和beta是固定的可以提前做一次权重融合把gamma融合到后面的MatMul权重里。训练时gamma和beta是变化的不能做权重融合但可以做好内存融合让LayerNorm和梯度计算共享显存。fromatbimportFusionMode# 推理模式启用权重融合fusion_opLayerNormLinearFusion(modeFusionMode.INFERENCE)# WHY: 推理时gamma/beta固定可以提前把gamma融合到MatMul的权重里# 省掉归一化后的一次乘法。# 训练模式启用梯度检查点融合fusion_opLayerNormLinearFusion(modeFusionMode.TRAINING,checkpointTrue)# WHY: 训练时gamma/beta会变化不能做权重融合。# 但可以做好显存管理融合kernel内部共享显存# 减少峰值显存占用对大模型训练很重要。6.3 技巧3用profiling工具验证融合是否生效ATB的融合算子是动态启用的根据输入shape、dtype等判断是否适合融合。你怎么知道融合是否真的生效了用NPU的profiling工具看kernel调用次数# 用msprof抓profilingmsprof--output./profiling--applicationpython test_layernorm.py# 查看kernel调用统计msprof--exporton--output./profiling|greplayer_norm# 如果融合生效你应该看到的是 fused_layernorm_linear 之类的kernel名# 而不是单独的 layer_norm 和 matmul。6.4 技巧4注意dynamic shape场景如果模型的输入shape是动态的比如NLP模型处理变长序列融合算子的编译可能会有额外开销因为要为不同的shape各编译一个kernel。ATB提供了一个shape范围声明的API让你提前告诉融合算子可能的shape范围它会在初始化时就把这个范围内的kernel都编译好。fromatbimportShapeRange# 声明shape范围shape_rangeShapeRange(batch[1,4,8,16,32],seq_len[128,256,512,1024,2048],hidden[4096,8192])# 初始化融合算子会根据shape_range预编译所有kernelfusion_opLayerNormLinearFusion(shape_rangeshape_range)# WHY: 动态shape场景下如果每次都现场编译kernel延迟会很高。# 用shape_range提前声明可能的shapeATB会在初始化时预编译# 运行时直接取用没有编译开销。总结把这件事从头到尾捋一遍LayerNorm在大模型里被频繁调用独立调用时的瓶颈不是计算本身而是中间结果在显存和片上之间来回搬的带宽开销。ATB的LayerNorm融合算子从三个层面解决这个问题内存层面tile级融合让中间结果留在片上调度层面kernel合并提升Cube/Vector的并行度精度层面混合精度策略统计用FP32归一化用FP16实测数据显示在LLaMA-2 70B模型上用ATB做LayerNorm融合端到端延迟从180ms降到152ms加速15.6%峰值显存从28.4GB降到24.3GB省14.4%。仓库链接https://atomgit.com/cann/ascend-transformer-boost