当前位置: 首页 > news >正文

【Tilelang入门】Tilelang Puzzles 08

Tillang Puzzles

一个开源仓库https://github.com/tile-ai/tilelang-puzzles/tree/main

给出用tilelang实现经典算子的例子,附带讲解。分为10个puzzle,每个问题都有待补全文件,和参考实现,以及文字讲解。

采用循序渐进的思路,难度逐渐递增,01-05熟悉语法,06-09实现经典算子,10为挑战复杂实战算子

08 Matirx

矩阵乘法是整个算子优化的核心,大部分经典算子都可以规约到矩阵乘法,比如前一节的注意力里有QKTQK^TQKT,前向传播有参数*输入,反向传播的求偏导也是矩阵乘法,卷积通过im2col转化后也能变成矩阵乘法。

GEMV

先来个基础的,矩阵乘向量,可以看成矩阵乘法的特殊情况,N=1
定义

foriinrange(M):ACC=0# float32 累加器forkinrange(K):ACC+=A[i,k]*B[k]C[i]=ACC# 转换回 float16

实际上也可以看成规约求和的特殊情况,看成带权规约,普通规约权重都是1,这里的权重是B[k]

# Reduce Sum (Puzzle 05)foriinrange(N):C[i]=sum(A[i,:])# GEMV (Puzzle 08)foriinrange(M):C[i]=sum(A[i,:]*B[:])# 加权求和

baseline

defref_gemv(A:torch.Tensor,B:torch.Tensor):assert A.shape==(M,K)assert B.shape==(K,)assert A.dtype==B.dtype==torch.float16returntorch.matmul(input=A,other=B)# 返回[M,]
  • C_local = T.alloc_fragment((BLOCK_M,), accum_dtype),累加类型使用fp32,不同于输入类型fp16,因为fp16不管是精度还是值域都太小了,矩阵乘法有乘法,有累加,数值很大,用32才能保证不溢出+精度过关
  • AB_local[i, j] = A_local[i, j].astype(accum_dtype) * B_local[j].astype(accum_dtype)按前面说的,看成带权规约,先计算乘上权重后的结果。由于输入是fp16,还想保证精度,计算时先显式转成fp32,类似于cpp里的ans += 1ll * x * y
  • 然后规约T.reduce_sum(AB_local, C_local, dim=1, clear=False)
@tilelang.jitdeftl_gemv(A,B,BLOCK_M:int,BLOCK_K:int):M,K=T.const("M, K")dtype=T.float16 accum_dtype=T.float32 A:T.Tensor((M,K),dtype)B:T.Tensor((K,),dtype)C=T.empty((M,),dtype)# TODO: Implement this functionwithT.Kernel(T.ceildiv(M,BLOCK_M),threads=128)asbx:A_local=T.alloc_fragment((BLOCK_M,BLOCK_K),dtype)B_local=T.alloc_fragment((BLOCK_K,),dtype)C_local=T.alloc_fragment((BLOCK_M,),accum_dtype)AB_local=T.alloc_fragment((BLOCK_M,BLOCK_K),accum_dtype)T.clear(C_local)forbkinT.Serial(T.ceildiv(K,BLOCK_K)):T.copy(A[bx*BLOCK_M,bk*BLOCK_K],A_local)T.copy(B[bk*BLOCK_K],B_local)fori,jinT.Parallel(BLOCK_M,BLOCK_K):AB_local[i,j]=A_local[i,j].astype(accum_dtype)*B_local[j].astype(accum_dtype)T.reduce_sum(AB_local,C_local,dim=1,clear=False)T.copy(C_local,C[bx*BLOCK_M])returnC


性能很差,这算是不叫暴力的做法

朴素GEMM

  • T.gemm(A_local, B_local, C_local)和前面唯一的区别,把手动乘上权重,再逐行规约,改成调用gemm接口计算一个块的结果了,只需传入两个输入矩阵,一个接收矩阵。
@tilelang.jitdeftl_matmul_naive(A,B,BLOCK_M:int,BLOCK_N:int,BLOCK_K:int):M,N,K=T.const("M, N, K")dtype=T.float16 accum_dtype=T.float32 A:T.Tensor((M,K),dtype)B:T.Tensor((K,N),dtype)C=T.empty((M,N),dtype)# TODO: Implement this functionwithT.Kernel(T.ceildiv(M,BLOCK_M),T.ceildiv(N,BLOCK_N),threads=128)as(bx,by):A_local=T.alloc_fragment((BLOCK_M,BLOCK_K),dtype)B_local=T.alloc_fragment((BLOCK_K,BLOCK_N),dtype)C_local=T.alloc_fragment((BLOCK_M,BLOCK_N),accum_dtype)T.clear(C_local)forbkinT.Serial(T.ceildiv(K,BLOCK_K)):T.copy(A[bx*BLOCK_M,bk*BLOCK_K],A_local)T.copy(B[bk*BLOCK_K,by*BLOCK_N],B_local)T.gemm(A_local,B_local,C_local)T.copy(C_local,C[bx*BLOCK_M,by*BLOCK_N])returnC


性能还是很差,而且看起来5x比前面的GEMV的3x左右还要差?难道gemm还不如手动规约高效?显然不是,因为GEMM两个输入都是矩阵,计算规模更大了,那么效率差距会被放大。实际GEMM肯定效率是比手动规约高的。

GEMM内部一般会直接调用MMA接口,使用Tensor Core计算。Tensor Core是矩阵计算专用单元,只能级算特定大小的矩阵乘法,不灵活,但是计算吞吐量大;前面的reduce和拷贝这些操作,都是CUDA Core执行的,可以执行通用计算,但是效率低。所以调用Tensor Core的GEMM接口一般性能肯定比CUDA Core的reduce高。

这里我们和torch还有差距,说明还有优化没用上。

优化版GEMM

  • B_local = T.alloc_shared((BLOCK_K, BLOCK_N), dtype)这里把张量从寄存器换到了共享内存上,明明寄存器更快,这是为什么?因为寄存器资源是很有限的,基本装下一个(BLOCK_M, BLOCK_K)大的张量就快满了,两个就不够了,而溢出部分的数据,会被直接存到全局内存,全局内存的延迟是最高的,整体效率被这个环节完全拖慢了,什么优化都没用了。所以,只有最频繁用到的累加数组,我们考虑安排在寄存器上,两个输入分块安排在共享内存,共享内存一般很大,容纳多个块都没问题。同时,访问速度也还可以接受。
  • for bk in T.Pipelined(T.ceildiv(K, BLOCK_K), num_stages=3):又是tilelang的一个强大接口,这里可以在并行循环里,增加一个参数num_stage,指定流水线级数。就可以把这个循环流水线化!理论上合适的流水阶段划分,可以实现等同于级数的加速比!
  • 这里虽然轻飘飘的一行,实际内部优化思想非常重要,注意到流水线建立前,每一轮的循环执行的是,先拷贝,再计算。但是我们前面提到过,现代GPU的内存带宽远小于计算吞吐量,也就是说大部分时间,计算核心都处于阻塞,等待内存搬运,这正是适合流水线优化的地方,可以把搬运和计算解耦,流水线一个阶段负责搬运,一个阶段负责计算,这样搬运的时候也可以计算,大大提升效率,整体瓶颈只取决于最慢的部分,也就是搬运,计算延迟几乎被完全隐藏了。
@tilelang.jitdeftl_matmul_opt(A,B,BLOCK_M:int,BLOCK_N:int,BLOCK_K:int):M,N,K=T.const("M, N, K")dtype=T.float16 accum_dtype=T.float32 A:T.Tensor((M,K),dtype)B:T.Tensor((K,N),dtype)C=T.empty((M,N),dtype)# TODO: Implement this functionwithT.Kernel(T.ceildiv(M,BLOCK_M),T.ceildiv(N,BLOCK_N),threads=128)as(bx,by):A_local=T.alloc_shared((BLOCK_M,BLOCK_K),dtype)B_local=T.alloc_shared((BLOCK_K,BLOCK_N),dtype)C_local=T.alloc_fragment((BLOCK_M,BLOCK_N),accum_dtype)T.clear(C_local)forbkinT.Pipelined(T.ceildiv(K,BLOCK_K),num_stages=3):T.copy(A[bx*BLOCK_M,bk*BLOCK_K],A_local)T.copy(B[bk*BLOCK_K,by*BLOCK_N],B_local)T.gemm(A_local,B_local,C_local)T.copy(C_local,C[bx*BLOCK_M,by*BLOCK_N])returnC

优化都用上后和torch实现差的不多了。

另外来验证一下,所有张量都申请在寄存器上会不会导致性能退化。可以看到几乎退化到朴素GEMM版本了,这是合理的,因为朴素版本就是直接访问全局内存,这里内存溢出后,张量也是存在全局内存上的,访问延迟自然和直接存在全局内存上相近。

http://www.zskr.cn/news/1451301.html

相关文章:

  • 【AI监控融合实战指南】:20年运维专家亲授5大落地陷阱与避坑清单
  • 保姆级教程:在Windows/Linux上为YOLOv8s模型生成GradCAM热力图(避坑CUDA/CPU环境配置)
  • 告别GPIO模拟时序:用STM32的FSMC外设高效驱动TFTLCD,性能提升实测
  • 从日常运维到脚本编写:详解Windows批处理中find与findstr的10个经典使用场景
  • 智慧电网电力设施目标检测数据集|输电线天线风机烟囱识别YOLO深度学习数据集10148期
  • 告别“狗牙”圆:Bresenham画圆算法在嵌入式屏幕(如STM32+LCD)上的C语言实战
  • [智能体-226]:大模型 ↔ 计算机硬件全套类比详解(冯・诺依曼架构对齐),智能体完整复刻冯诺依曼计算机运行范式
  • 手把手用Python复现Robbins-Monro算法:从求根到在线均值估计的完整代码示例
  • 2026年口碑好的西安新房装修/西安装修优选公司推荐 - 行业平台推荐
  • 从Kaggle竞赛入门:用随机森林搞定泰坦尼克号预测的完整避坑指南(含特征工程与调参)
  • 从Fluent面板到理论公式:一文讲透ANSYS Help文档的四种正确打开方式
  • 做了springAI项目中的三个功能总结的心得
  • 避开蓝桥杯DS1302的坑:从时间加减乱码到稳定显示的完整避坑指南
  • Ansaldo cpu684 印刷电路板
  • 别再踩LONG数据类型的坑了!从Oracle官方文档看CLOB如何优雅替代(附迁移脚本)
  • CrewAI实战:如何用分层流程(Hierarchical Process)和本地Ollama模型打造一个‘经理+员工’的AI团队
  • 抖音批量下载工具技术深度解析:从API逆向到智能编排的完整实现
  • 抖音无水印下载终极指南:5分钟掌握douyin-downloader完整使用技巧
  • YOLO26涨点改进| TGRS 2025 |独家创新首发、卷积改进篇| 引入SFD空间-频率解耦模块,通过“空间分支 + 频率分支”对退化图像进行双域解耦与增强,助力目标检测、图像增强任务有效涨点
  • LabVIEW直连GPU加速环境安装包(含NVIDIA/AMD驱动与运行库)
  • 如何用3个简单设置让猫抓成为你的专属资源猎手?
  • 硅胶制品厂主要集中在哪些地方?
  • 从4K到2M:动手实验对比Linux大页(HugePages)下,一二级页表的内存开销与性能影响
  • 从AI小白到提示词高手,我只用了这10个技巧
  • 深入RK3568 USB3.0控制器:从DTS设备树配置到内核驱动加载的底层原理剖析
  • 3分钟掌握DamaiHelper:告别手速焦虑,轻松抢到心仪演唱会门票
  • 避坑指南:在CentOS 7上手动编译安装SPECCPU2017,解决gcc/gfortran依赖的那些事儿
  • 别再手动翻文件夹了!用Windows批处理+for命令,5分钟搞定照片/文档的批量提取
  • 告别电脑束缚!用CW-Writer实现离线烧录CW32芯片的保姆级教程
  • 拆解D3D12渲染管线:用“画三角形”的例子,彻底搞懂命令队列、PSO和围栏