CUTLASS 是 NVIDIA 的矩阵乘模板库catlass 是昇腾的对应物——用 C 模板元编程在编译期生成算子运行时零开销。核心思路把算子拆成可组合的模板参数编译期决定一切tile 大小、数据布局、指令选择运行期只做数据搬运和计算。为什么不用普通函数GEMM 有 15 个可调参数tile M/N/K、向量化宽度、是否用 TF32、是否 fuse ReLU——15 维参数空间手写 15 个特化版本不现实。模板元编程让编译器帮你生成。核心抽象Tile Iteratorcatlass 的一切围绕TileIterator——一个编译期知道 tile 大小、数据布局、向量化宽度的迭代器。它不是 runtime 的for循环是编译期的类型计算。// catlass/include/cutlass/tile_itererator.h// 模板参数// - Shape_MNKtile 的 M/N/K 大小编译期常量// - Element数据类型float16/float32/bfloat16// - Layout数据布局RowMajor/ColumnMajor/VoltaLayout// - VectorWidth向量化加载宽度4/8/16 个元素一次加载templatetypenameShape_MNK,// cutlass::MatrixShape128, 128, 32typenameElement,// float16typenameLayout,// cutlass::layout::RowMajorintVectorWidth8// 一次加载 8 个 float16 128 bitsclassTileIterator{public:usingShapeShape_MNK;usingAccessTypeArrayElement,VectorWidth;// 向量化访问类型// 编译期计算一个 tile 需要几次向量化加载staticconstexprintkAccessCountShape::kMN/VectorWidth*Shape::kN/VectorWidth*Shape::kK/VectorWidth;// 构造函数绑定到全局内存的基地址CUTLASS_DEVICETileIterator(Element*ptr,intstride):ptr_(reinterpret_castAccessType*(ptr)),stride_(stride/VectorWidth){}// 加载一个 tile 到寄存器向量化CUTLASS_DEVICEvoidLoad(TileShape,Elementtile,intm,intn,intk){// 编译期展开kAccessCount 次向量化加载CUTLASS_PRAGMA(unroll)for(inti0;ikAccessCount;i){// 地址计算编译期决定无运行时开销AccessType*addrptr_(m*stride_n)i;tile.data[i]*addr;// 向量化加载128 bits 一次}}// 写回一个 tileCUTLASS_DEVICEvoidStore(constTileShape,Elementtile,intm,intn,intk){CUTLASS_PRAGMA(unroll)for(inti0;ikAccessCount;i){AccessType*addrptr_(m*stride_n)i;*addrtile.data[i];}}private:AccessType*ptr_;// 向量化指针不是 Element*intstride_;// 以 AccessType 为单位的 stride};CUTLASS_PRAGMA(unroll)让编译器把循环完全展开——最终生成的代码没有循环只有 16 条load128向量化加载指令。这是零成本抽象的含义模板代码写起来像泛型编译后和手写汇编一样高效。算子融合的模板实现GEMM Bias ReLU独立的 GEMM 算子C A × B。融合版本C ReLU(A × B Bias)。不用融合的话需要三个 kernelGEMM → AddBias → ReLU两次 HBM 往返。catlass 的融合在模板层面完成——不是 runtime 的if是编译期生成专用的融合 kernel。// catlass/include/cutlass/gemm/kernel/gemm_with_fusion.h// 融合策略模板参数决定融合什么// - Gemm基础矩阵乘// - Epilogue尾部操作bias add / activation / elementwisetemplatetypenameGemmShape,// 128, 128, 32 — tile 大小typenameEpilogueOp,// cutlass::epilogue::thread::ReLUtypenameElementA,// float16typenameElementB,// float16typenameElementC,// float32输出类型typenameElementBias// float32bias 类型classGemmWithFusion{public:usingEpilogueEpilogueOp;// kernel 主函数CUTLASS_DEVICEvoidoperator()(ElementC*ptr_C,intstride_C,ElementA*ptr_A,intstride_A,ElementB*ptr_B,intstride_B,ElementBias*ptr_bias// bias 指针融合用){// 阶段 1分块加载 A 和 B TileIteratorAiteator_A(ptr_A,stride_A);TileIteratorBiteator_B(ptr_B,stride_B);// Tile 在寄存器/SMEM 中TileGemmShape,ElementAtile_A;TileGemmShape,ElementBtile_B;iterator_A.Load(tile_A,blockIdx.x*128,0,threadIdx.x);iterator_B.Load(tile_B,0,blockIdx.y*128,threadIdx.x);// 阶段 2矩阵乘MMA 指令// Ascend NPU 用 Cube 单元做矩阵乘// catlass 把 MMA 封装成模板——编译期选择指令usingMmaOptypenameMmaPromotionGemmShape,ElementA,ElementB,ElementC::Op;MmaOp mma_op;TileGemmShape,ElementCaccum;// 累加器mma_op(accum,tile_A,tile_B,accum);// C A × B C// 阶段 3Epilogue融合的尾部操作// 这是融合的核心——epilogue 在矩阵乘完成后立即执行// 不需要写回 HBM 再读出来Epilogue epilogue;// Step 3.1加载 bias如果融合了这个操作ifconstexpr(Epilogue::kHasBias){TileGemmShape,ElementBiastile_bias;// bias 是 [N] 向量广播到 [M, N] tileLoadBiasbranch::RowBroadcast(tile_bias,ptr_bias,blockIdx.y*128);epilogue.AddBias(accum,tile_bias);}// Step 3.2激活函数ReLU / GELU / SiLUifconstexpr(Epilogue::kHasActivation){epilogue.ApplyActivation(accum);// ReLU: max(0, x)}// Step 3.3写回这是唯一一次 HBM 写TileIteratorCiteator_C(ptr_C,stride_C);iterator_C.Store(accum,blockIdx.x*128,blockIdx.y*128,0);}};融合的关键if constexpr是 C17 的特性——编译期 if。如果Epilogue::kHasBias false整段 bias 加载代码在编译期被删除最终二进制里不存在。这是零成本的另一层含义没用到的融合组件不占代码空间。编译期计算Tensor Core 指令选择Ascend NPU 的 Cube 单元支持多种矩阵乘指令MMA_F16float16 输入float32 累加、MMA_BF16bfloat16、MMA_TF32TensorFloat32仅 Ampere。catlass 用模板特化在编译期选择。// catlass/include/cutlass/arch/mma.h// 基础 MMA 操作描述编译期常量templatetypenameShape,typenameElementA,typenameElementB,typenameElementCstructMmaPromotion;// 特化 1float16 × float16 → float32最常用templateintM,intN,intKstructMmaPromotionMatrixShapeM,N,K,float16,float16,float32{usingOpMmaSychronizedMatrixShape16,8,16,// Tensor Core 的 warp-level tile 大小float16,// A 类型float16,// B 类型float32,// C 类型累加器layout::RowMajor// C 的布局;// 一条 MMA 指令处理 16×8×16 的 tile// 256 个 thread一个 warp 16 warps × 16 threads// 每个 thread 负责 1 个输出元素 → 16 warps × 16 threads 256 elements};// 特化 2bfloat16 × bfloat16 → float32 templateintM,intN,intKstructMmaPromotionMatrixShapeM,N,K,bfloat16,bfloat16,float32{usingOpMmaSychronizedMatrixShape16,8,16,bfloat16,bfloat16,float32,layout::RowMajor;};// 特化 3TensorFloat32TF32Ampere 专用templateintM,intN,intKstructMmaPromotionMatrixShapeM,N,K,tfloat32,tfloat32,float32{usingOpMmaSychronizedMatrixShape16,8,8,// TF32 的 K 维度只有 8精度降低tfloat32,tfloat32,float32,layout::RowMajor;};// 用法自动选择 templatetypenameShape,typenameElementA,typenameElementB,typenameElementCCUTLASS_DEVICEvoidGemmKernel(ElementC*C,ElementA*A,ElementB*B,intM,intN,intK){usingMmatypenameMmaPromotionShape,ElementA,ElementB,ElementC::Op;Mma mma;// 编译期根据 ElementA/B/C 自动选择 MMA 指令// 不需要 runtime if—编译器帮你选TileShape,ElementCaccum;mma(accum,A,B,accum);// ...}Ascend NPU 没有 TF32这是 NVIDIA 的专用格式但 catlass 保持了和 CUTLASS 相同的接口——方便从 NVIDIA 迁移代码。分层模板架构catlass 的模板不是一团乱——分了 4 层每层负责不同的抽象第 1 层Kernel 层最顶层 ↓ 决定GEMM / Conv / Transform ↓ 决定融合策略epilogue 操作 GemmWithFusionShape, Epilogue 第 2 层Tile 层 ↓ 决定tile 大小128×128×32 / 256×128×64 ↓ 决定数据布局RowMajor / ColumnMajor TileIteratorA / TileIteratorB 第 3 层Warp 层 ↓ 决定Warp 内部的 MMA 指令映射 ↓ 决定寄存器分配 MmaSychronizedWarpShape, ElementA, ... 第 4 层指令层最底层 ↓ 直接映射到 NPU 指令 ↓ Cube 单元MMA / MMA_SYNC ↓ Vector 单元FMA / FFMA cute::asmsm(mma.sync..., ...)每层只和上下层交互——改 tile 大小只需改第 2 层不影响第 3/4 层。实战自定义融合算子GEMM SiLUSiLUSigmoid Linear Unitx × sigmoid(x)LLaMA 的激活函数。融合到 GEMM 的 epilogue// catlass/examples/gemm_silu/gemm_silu.cu// 步骤 1定义 Epilogue 操作SiLUnamespacecutlass{namespaceepilogue{namespacethread{classSiLU{public:// Epilogue 操作必须实现 operator()对 tile 的每个元素应用templatetypenameElementCUTLASS_DEVICE Elementoperator()(Element x)const{// SiLU(x) x × sigmoid(x)// sigmoid(x) 1 / (1 exp(-x))floatx_ffloat(x);floatsigmoid1.0f/(1.0fexpf(-x_f));returnElement(x_f*sigmoid);}// 融合标识编译期常量staticconstexprboolkHasBiasfalse;staticconstexprboolkHasActivationtrue;staticconstexprboolkHasSilUtrue;// 新增SiLU 标识};}// namespace thread}// namespace epilogue}// namespace cutlass// 步骤 2用自定义 Epilogue 实例化 GEMMusingGemmShapecutlass::MatrixShape128,128,32;usingGemmSiLUcutlass::gemm::kernel::GemmWithFusionGemmShape,cutlass::epilogue::thread::SiLU,// 自定义 epiloguefloat16,// ElementAfloat16,// ElementBfloat32,// ElementCfloat32// ElementBias未用;// 步骤 3启动 kernelvoidLaunchGemmSiLU(float32*C,float16*A,float16*B,intM,intN,intK){dim3grid((M128-1)/128,(N128-1)/128);dim3block(256);// 一个 warpgroup 256 threadsGemmSiLU kernel;kernelgrid,block(C,M,A,K,B,N,nullptr);}编译期展开后最终生成的 PTX并行线程执行汇编大致是; 伪代码展开后的融合 kernel ld.shared.f16 %fA, [tile_A]; ; 从 shared memory 加载 A ld.shared.f16 %fB, [tile_B]; ; 加载 B mma.sync.aligned.m16n8k16.f16.f16.f32 {...}; ; 矩阵乘 add.f32 %fC, %fC, %fAB; ; 累加到 C ; EpilogueSiLU融合在这不写回 HBM ex2.approx.ftz.f32 %fneg, -%fC; ; exp(-x) 近似 add.f32 %fdenom, 1.0, %fneg; ; 1 exp(-x) div.approx.f32 %fSigmoid, 1.0, %fdenom; ; sigmoid mul.f32 %fSiLU, %fC, %fSigmoid; ; x × sigmoid st.global.f32 [%rdC], %fSiLU; ; 唯一一次 HBM 写对比非融合版本3 个 kernel; 非融合版本 ; Kernel 1: GEMM mma.sync...; add.f32 %fC, ...; st.global.f32 [%rdC], %fC; ; ← 写 HBM第一次 ; Kernel 2: AddBias重新加载 ld.global.f32 %fC, [%rdC]; ; ← 读 HBM第一次 add.f32 %fC_bias, %fC, %fBias; st.global.f32 [%rdC], %fC_bias; ; ← 写 HBM第二次 ; Kernel 3: SiLU重新加载 ld.global.f32 %fC, [%rdC]; ; ← 读 HBM第二次 ex2.approx...; div...; mul...; st.global.f32 [%rdC], %fSiLU; ; ← 写 HBM第三次三次 HBM 往返 vs 一次——延迟差 3×假设 HBM 带宽 900GB/s。踩坑一模板递归深度超过编译限制catlass 大量使用模板递归不是 runtime 递归展开循环。编译期递归深度默认限制是 256——但某些大 tile512×512×64的递归深度会超过。// ❌ 模板递归深度 1024 → 编译错误recursive template instantiation depth exceededtemplateintNstructFactorial{staticconstexprintvalueN*FactorialN-1::value;};templatestructFactorial0{staticconstexprintvalue1;};intxFactorial1024::value;// 编译错误// ✅ 用 constexpr 函数替代C14无递归深度限制constexprintFactorial(intN){intresult1;for(inti1;iN;i)result*i;returnresult;}constexprintxFactorial(1024);// OK编译期计算catlass 自己的代码用了大量模板递归——如果遇到编译错误在 g 下加-ftemplate-depth1024在 clang 下加-ftemplate-depth1024。踩坑二Auto-tuning 搜索空间爆炸catlass 的 tile 大小、VectorWidth、epilogue 融合组合是 15 维参数空间。暴力搜索每个组合跑一遍 benchmark需要 2^15 32768 次编译运行——不现实。# ❌ 暴力搜索32768 种组合每种编译 30 秒 → 273 小时fortile_min[64,128,256]:fortile_nin[64,128,256]:fortile_kin[16,32,64]:forvec_widthin[4,8,16]:forepiloguein[None,ReLU,GELU,SiLU]:CompileAndBenchmark(...)# ✅ 分层搜索先定 tile 大小再调 epilogue# 第 1 步固定 epilogueNone搜索最优 tilebest_tileSearchTileSize(epilogueNone)# 只搜索 3×3×3 27 种# 第 2 步用 best_tile搜索最优 epilogue 融合best_epilogueSearchEpilogue(tilebest_tile)# 只搜索 4 种# 总时间27 4 31 次编译 → 15 分钟分层搜索的理论依据tile 大小对性能的影响是数量级共享内存占用、寄存器压力epilogue 融合的影响是百分比省 1-2 次 HBM 往返。先调数量级再调百分比。踩坑三模板报错信息完全不可读模板元编程的报错信息尤其是类型不匹配可以达到 500 行——因为编译器要展开所有模板实例化路径才报错。error: no matching function for call to Load note: candidate: templateclass T void Load(TileShape, Element, int, int, int) [with T float16; Shape MatrixShape128, 128, 32; Element float32] note: mismatched types Element (float32) vs argument type float16 [128]这个报错的意思是Load(tile_float32, ..., ptr_float16)——tile 是 float32但指针是 float16。// ❌ 报错 500 行TileShape128x128x32,float32tile;// accum 是 float32float16*ptr...;iterator.Load(tile,0,0,0);// ptr 是 float16 → 类型不匹配// ✅ 用静态断言static_assert在编译期给出可读错误templatetypenameElement,typenameAccessTypeCUTLASS_DEVICEvoidLoad(TileShape,Elementtile,AccessType*ptr,...){static_assert(std::is_sameElement,AccessType::value,Load: Element type must match AccessType);// ...}static_assert会在类型不匹配时输出error: static assertion failed: Load: Element type must match AccessType——一行不是 500 行。catlass 的代码大量使用static_assert和类型别名using来让报错可读。遇到模板报错先搜static_assert再看类型推导。catlass 的本质用 C 模板元编程在编译期生成算子运行时零开销。核心抽象TileIterator向量化加载 循环展开 → 最终代码和手写汇编一样高效。融合算子GEMM epilogue靠if constexpr编译期消除未用代码——二进制里没有 dead code。踩坑集中在三方面模板递归深度超限加编译选项、搜索空间爆炸分层搜索、报错不可读用 static_assert。