当前位置: 首页 > news >正文

昇腾CANN ops-blas Batched GEMM:多头注意力的小矩阵乘批处理实战

Transformer 的 Multi-Head Attention 有 H 个注意力头——每个头独立做矩阵乘Qh×Kh^T、Attn×Vh。H32 时一个 BatchNorm 后面紧跟着 32 个小矩阵乘每个头独立。单独启动 32 次 GEMM 会有 32 次 launch 开销~50μs/次 → 1.6ms 总开销加上 32 次 kernel 启动带来的流水线 flush。ops-blas 的 Batched GEMM 把 32 个小矩阵乘合并成一个 kernel——一次 launch 处理全部 32 个头。Batched GEMM 的三种策略ops-blas 根据 batched GEMM 的形状自动选择策略策略选择逻辑 if (batch_count 32 M * N * K 4096): → 策略 1Interleaved Batching交错批处理 把 32 个小 GEMM 交织在一个 block 内执行 elif (batch_count 16 M * N * K 4096): → 策略 2Parallel Batching并行批处理 给每个小 GEMM 分配独立 block else: → 策略 3Hybrid Batching混合批处理 分组内交错的组外并行策略 1Interleaved Batching// ops-blas/kernels/batched_gemm_interleaved.cpp__aicore__voidBatchedGEMMInterleaved(GlobalTensorfloat16A_batched,// [batch, M, K]GlobalTensorfloat16B_batched,// [batch, K, N]GlobalTensorfloat16C_batched,// [batch, M, N]intbatch,intM,intN,intK){// 每个 block 处理一个 batch 的 GEMMfor(intb0;bbatch;b){intblock_idb%gridDim.x;// 轮询分配 block// 在 L1 中交错存储 32 个 batch 的 tile// 单个 tile 大小 tile_M × tile_K 16 × 16 256 elementsLocalTensorfloat16A_tile(tile_M*tile_K);LocalTensorfloat16B_tile(tile_K*tile_N);LocalTensorfloat16C_tile(tile_M*tile_N);intA_offsetb*M*K;intB_offsetb*K*N;intC_offsetb*M*N;// 分块矩阵乘for(intm0;mM;mtile_M){for(intn0;nN;ntile_N){// 初始化累加器C_tile0.0f;for(intk0;kK;ktile_K){// 加载 A 和 B 的 tile 到 L1DataCopy(A_tile,A_batchedA_offsetm*Kk,tile_M*tile_K);DataCopy(B_tile,B_batchedB_offsetk*Nn,tile_K*tile_N);// Cube 单元矩阵乘累加MMA(C_tile,A_tile,B_tile,tile_M,tile_N,tile_K);}// 写回结果DataCopy(C_batchedC_offsetm*Nn,C_tile,tile_M*tile_N);}}}}策略 2Parallel Batching// ops-blas/kernels/batched_gemm_parallel.cpp__aicore__voidBatchedGEMMParallel(GlobalTensorfloat16A_batched,// [batch, M, K]GlobalTensorfloat16B_batched,// [batch, K, N]GlobalTensorfloat16C_batched,// [batch, M, N]intbatch,intM,intN,intK){// 每个 block 处理一个独立的 batch不是所有 block 处理同一 batch// block 分配block_id b % num_batch_blocks// num_batch_blocks gridDim.x / batchintnum_batch_blocksgridDim.x/batch;if(num_batch_blocks1)num_batch_blocks1;// 每个 batch 有 num_batch_blocks 个 block 在并行处理intbatch_idblockIdx.x/num_batch_blocks;intbatch_blockblockIdx.x%num_batch_blocks;intA_offsetbatch_id*M*K;intB_offsetbatch_id*K*N;intC_offsetbatch_id*M*N;// batch_block 决定此 block 处理矩阵的哪一部分// 把 M 维度均分给 num_batch_blocks 个 blockintm_startbatch_block*(M/num_batch_blocks);intm_end(batch_block1)*(M/num_batch_blocks);for(intmm_start;mm_end;mtile_M){for(intn0;nN;ntile_N){LocalTensorfloat16C_tile(tile_M*tile_N);C_tile0.0f;for(intk0;kK;ktile_K){LocalTensorfloat16A_tile(tile_M*tile_K);LocalTensorfloat16B_tile(tile_K*tile_N);DataCopy(A_tile,A_batchedA_offsetm*Kk,tile_M*tile_K);DataCopy(B_tile,B_batchedB_offsetk*Nn,tile_K*tile_N);MMA(C_tile,A_tile,B_tile,tile_M,tile_N,tile_K);}DataCopy(C_batchedC_offsetm*Nn,C_tile,tile_M*tile_N);}}}策略 3Hybrid Batching// ops-blas/kernels/batched_gemm_hybrid.cpp__aicore__voidBatchedGEMMHybrid(GlobalTensorfloat16A_batched,// [batch, M, K]GlobalTensorfloat16B_batched,// [batch, K, N]GlobalTensorfloat16C_batched,// [batch, M, N]intbatch,intM,intN,intK){// 分组每 group_size 个 batch 为一组// 组内用 Interleaved充分利用 L1组间用 Parallelintgroup_size4;// 每组 4 个 batchintnum_groups(batchgroup_size-1)/group_size;intgroup_idblockIdx.x%num_groups;// 每个 block 处理一个 group// 组间并行处理intbatch_startgroup_id*group_size;intbatch_endmin(batch_startgroup_size,batch);// 组内 Interleavedfor(intbbatch_start;bbatch_end;b){intA_offsetb*M*K;intB_offsetb*K*N;intC_offsetb*M*N;// 分块矩阵乘同 Interleaved 策略for(intm0;mM;mtile_M){for(intn0;nN;ntile_N){LocalTensorfloat16C_tile(tile_M*tile_N);C_tile0.0f;for(intk0;kK;ktile_K){LocalTensorfloat16A_tile(tile_M*tile_K);LocalTensorfloat16B_tile(tile_K*tile_N);DataCopy(A_tile,A_batchedA_offsetm*Kk,tile_M*tile_K);DataCopy(B_tile,B_batchedB_offsetk*Nn,tile_K*tile_N);MMA(C_tile,A_tile,B_tile,tile_M,tile_N,tile_K);}DataCopy(C_batchedC_offsetm*Nn,C_tile,tile_M*tile_N);}}}}Multi-Head Attention 的 Batched GEMM 应用Transformer 中 Multi-Head Attention 的三种 GEMM 都可以用 Batched GEMM 加速# PyTorch 自动路由到 ops-blas 的 Batched GEMMimporttorchimporttorch_npu# MHA 的三个 GEMM 步骤# 输入x [batch, seq, d_model] (如 [1, 2048, 4096])# H32 heads, d_head d_model // H 128# 1. QKV projection每个头独立共 3H 个小 GEMM# x W_q[head] → Q[head] [batch, seq, d_head]# 转成 batched form: [batch*seq, d_model] [3, head, d_model, d_head]qkvtorch.nn.functional.linear(x,W_qkv)# 底层用 Batched GEMM# 2. Attention score每个头独立H 个小 GEMM# Q[head] K[head]^T → scores[head] [batch, seq, seq]# batched form: [head, batch*seq, d_head] [head, d_head, batch*seq]attn_scorestorch.bmm(Q.reshape(-1,seq,d_head).transpose(0,1),K.reshape(-1,seq,d_head).transpose(0,1).transpose(1,2))# 底层用 Batched GEMM一次 launch 处理 H 个头# 3. Output projection每个头独立H 个小 GEMM# attn[head] V[head] → output[head] [batch, seq, d_head]# batched form 同理outputtorch.bmm(attn_weights,V.reshape(-1,seq,d_head).transpose(0,1))关键python 侧看到的torch.bmm(batched matrix multiplication)——底层自动映射到 ops-blas 的 Batched GEMM。踩坑一batch 维度的 stride 不连续标准 Batched GEMM 假设 A 和 B 的 batch 维度是连续存储的 ([batch, M, K])。但 MHA 中 QKV projection 的 weight 是 [num_heads, d_model, d_head]——head 维度的 stride d_model * d_head不是 K * d_head。修复ops-blas 的 Batched GEMM 支持 stride 参数// 支持 stride 参数__aicore__voidBatchedGEMMStrided(GlobalTensorfloat16A_batched,GlobalTensorfloat16B_batched,GlobalTensorfloat16C_batched,intbatch,intM,intN,intK,intstride_A,// A 的 batch stride不连续时 M*Kintstride_B,// B 的 batch strideintstride_C// C 的 batch stride){for(intb0;bbatch;b){// 使用 stride 替代 M*KintA_offsetb*stride_A;// 不是 b * M * KintB_offsetb*stride_B;intC_offsetb*stride_C;// ... 其余同 Interleaved}}PyTorch 侧# 非连续 batch → 指定 strideoutputtorch_npu.batched_gemm(A_strided,B_strided,stride_Ad_model*d_head,stride_Bd_head*seq)踩坑二batch 中 GEMM 形状不一致MHA 的 32 个头可能形状不同某些头是 padding 头不需要计算。Batched GEMM 默认假设所有 batch 的 shape 相同——形状不一致时padding 头浪费计算。修复使用 mask 跳过不需要的 batch__aicore__voidBatchedGEMMMasked(GlobalTensorfloat16A_batched,GlobalTensorfloat16B_batched,GlobalTensorfloat16C_batched,GlobalTensoruint8batch_mask,// [batch] 1有效, 0跳过intbatch,intM,intN,intK){for(intb0;bbatch;b){if(!batch_mask[b]){continue;// 跳过这个 batch — 节省 Cube 和时间}// ... 正常计算}}Mask 由上层ATB传入——对于 padding 头batch_mask 0。踩坑三Batched GEMM 和单次大 GEMM 的取舍Merge QKV projection把 3H 个小 GEMM 合并成 1 次大 GEMM——x [W_q, W_k, W_v]。形状是[batch*seq, d_model] [d_model, 3*head*d_head]——一次 GEMM 代替 3H 次小 GEMM。选择逻辑# ops-blas 自动判断ifM4096orK4096:# 大矩阵 → Merge 成一次大 GEMM# 好处Cube 利用率高tile 填满returnmerged_GEMM(x,W_merged)elifbatch_count32:# 很多小 GEMM → Batched GEMM# 好处一次 launch减少开销returnbatched_GEMM(x,W_batched)else:# 中等规模 → 混合策略returnhybrid_GEMM(x,W_batched)经验规则MHA 推理batch1, seq128, head32→ Batched GEMM32 个小矩阵MHA 训练batch8, seq2048, head32→ Merged GEMM1 次大矩阵大 GEMM形状阈值M×K 4096×4096 → Merge否则 → BatchedBatched GEMM 解决的不只是计算效率——而是 launch 开销和流水线中断。32 次 HEAD MM 各 launch 一次32×50μs1.6ms 开销vs 一次 Batched GEMM launch50μs。在推理管线的 2ms 总时间中launch 开销占比从 80% 降到 2.5%。ops-blas 的 Batched GEMM 自动选择策略Interleaved/Parallel/Hybrid、支持 stride 和 mask——让 MHA 的 H 个小矩阵乘变成一次 kernel 调用。
http://www.zskr.cn/news/1363697.html

相关文章:

  • Unity Mod Manager底层原理与模组生命周期管理
  • 别再只用chmod了!麒麟KYLINOS文件权限进阶:用ACL实现更精细的访问控制(含setfacl命令详解)
  • 数据增强在软件工程中的评估陷阱:以Flaky测试分类为例
  • 缺失数据下的因果推断:mDR与mEP学习器原理与实战
  • 2024 iOS自动化测试环境搭建:Appium 2.5+适配Xcode 15.3与iOS 17.4
  • lucie:智能加载UCI数据集的Python工具,解决格式兼容难题
  • 全局量子门变分方法:释放硬件原生优势的量子态制备新范式
  • 【考研英语一·翻译专攻】长难句翻译的“分治策略”:从底层拆分到逻辑重构(1997-2010真题高频陷阱与红笔纠偏)
  • 多速率信号处理与图像量化:从奈奎斯特到工程实践
  • Kruskal-Wallis检验在自动驾驶用户信任度研究中的应用与实操
  • 智能AI图像识别之工地积水识别数据集 道路积水数据集 管道泄漏漏水数据集 图像yolov8图像数据集 积水识别yolo第10260期
  • 信念传播算法:从图模型推理到消息传递原理与应用
  • 核能消费对循环经济的影响:基于DYNARDL模型与机器学习的实证研究
  • 基于OCT-H与特征增强的流体多臂老虎机最优控制策略学习
  • ZygiskFrida:安卓逆向的Zygote层动态插桩新范式
  • RISC-V SoC中的DSP加速器设计与边缘计算优化
  • 基于QR分解与肘部法则的稀疏传感器优化布置方法
  • 基于多维度聚类分析的住宅供暖能耗模式识别与节能策略研究
  • [智能体-37]:协同共生:大模型、智能体与专业工具的系统生产力之道
  • 数值自举与弦论振幅:用SDPB最小化纠缠矩定位开超弦
  • 2026年比较好的深圳淘宝纸箱/深圳物流纸箱/宝安纸箱/纸箱优质公司推荐 - 行业平台推荐
  • 观察 Taotoken 模型广场如何辅助开发者进行初步模型选型
  • 基于Graphlet的网络嵌入:从局部结构到生物功能模块发现
  • 外观专利和实用新型
  • OAuth 2.0授权机制本质与四大模式实战解析
  • TWA方法:利用细粒度错误标注优化机器翻译模型
  • 抖音批量下载神器:轻松保存喜欢的视频、音乐和图集
  • MACE-MP-MOF0:基于机器学习势函数高效计算MOF声子谱与热力学性质
  • 机器学习公平性实战:三大工具库对比与偏见缓解指南
  • 2026年比较好的海口配电控制开关/海口家装照明开关/海南家装照明开关公司对比推荐 - 行业平台推荐