前言昇腾CANN的ops-transformer仓库提供了Transformer类大模型需要的进阶算子其中FlashAttention算子是最核心的注意力计算优化本文深度解读FlashAttention算子的原理实现和性能表现背景注意力计算的算力挑战Transformer架构的核心是自注意力机制标准注意力计算需要计算QK^T这个中间结果的大小是序列长度乘以序列长度当序列长度从512涨到8192显存需求直接爆炸具体来说标准注意力计算分为三步计算QK^T得到注意力分数矩阵对注意力分数矩阵做softmax用softmax结果乘以V得到输出假设batch4, heads16, seq_len8192, head_dim64QK^T的大小4乘以16乘以8192乘以8192乘以4字节float32 64GB这显然超出了任何现有NPU的显存容量原理分块计算与内存优化FlashAttention的核心思路是分块计算不把整个注意力矩阵都存下来而是分块计算分块写回这样显存占用从O(N^2)降到O(N)具体实现分为以下几步1. 分块策略将QKV矩阵按序列长度维度分块假设块大小为B那么每个块的大小是[batch, heads, B, head_dim]2. 分块计算注意力分数对于每个Q块遍历所有K块计算注意力分数由于是分块计算不需要存储完整的注意力分数矩阵3. 在线softmax在分块计算注意力分数的同时在线计算softmax这需要维护两个统计量最大值m和求和项l4. 分块计算输出用计算好的注意力权重乘以V块得到输出块实现昇腾NPU上的FlashAttention在昇腾NPU上实现FlashAttention需要充分利用达芬奇架构的硬件特性达芬奇架构有专门的矩阵计算单元Cube UnitFlashAttention的分块计算可以很好地映射到Cube Unit上关键优化点分块大小选择分块大小需要适配Cube Unit的计算能力太大的分块会导致寄存器溢出太小的分块无法充分利用Cube Unit内存层级利用充分利用片上内存L1 Buffer来减少对HBM的访问次数流水线设计将计算和数据搬运流水线化隐藏内存访问延迟代码讲解FlashAttention核心逻辑下面是FlashAttention的核心代码逻辑简化版importtorchdefflash_attention_forward(Q,K,V,causalTrue):闪光注意力前向计算简化版 Args:Q:Query矩阵形状为[batch,seq_len,heads,head_dim]K:Key矩阵形状同上 V:Value矩阵形状同上 causal:是否使用因果注意力掩码 Returns:输出矩阵形状为[batch,seq_len,heads,head_dim]batch,seq_len,heads,head_dimQ.shape# 分块大小需要根据硬件特性调整block_size128# 初始化输出和中间统计量Otorch.zeros_like(Q)ltorch.zeros(batch,seq_len,heads).to(Q.device)# softmax的分母mtorch.full((batch,seq_len,heads),-float(inf)).to(Q.device)# 最大值# 外层循环遍历Q块foriinrange(0,seq_len,block_size):# 获取Q块Q_blockQ[:,i:iblock_size,:,:]# [batch, block_size, heads, head_dim]# 初始化当前块的输出和统计量O_blocktorch.zeros_like(Q_block)l_blocktorch.zeros(batch,block_size,heads).to(Q.device)m_blocktorch.full((batch,block_size,heads),-float(inf)).to(Q.device)# 内层循环遍历KV块forjinrange(0,seq_len,block_size):# 获取KV块K_blockK[:,j:jblock_size,:,:]V_blockV[:,j:jblock_size,:,:]# 计算注意力分数Q_block乘以K_block的转置S_blocktorch.matmul(Q_block,K_block.transpose(-2,-1))/(head_dim**0.5)# S_block形状[batch, block_size, heads, block_size]# 因果注意力掩码ifcausal:# 创建因果掩码masktorch.triu(torch.ones(block_size,block_size),diagonal1).bool()maskmask.to(Q.device)S_block.masked_fill_(mask,-float(inf))# 在线softmax更新最大值和求和项m_newtorch.max(S_block,dim-1)# [batch, block_size, heads]l_newtorch.sum(torch.exp(S_block-m_new.unsqueeze(-1)),dim-1)# [batch, block_size, heads]# 更新统计量m_block_newtorch.max(m_block,m_new)l_block_newtorch.exp(m_block-m_block_new.unsqueeze(-1))*l_blocktorch.exp(m_new-m_block_new.unsqueeze(-1))*l_new# 更新输出O_blocktorch.exp(m_block-m_block_new.unsqueeze(-1)).unsqueeze(-1)*O_blocktorch.matmul(torch.exp(S_block-m_block_new.unsqueeze(-1)),V_block)# 更新统计量m_blockm_block_new l_blockl_block_new# 归一化输出O[:,i:iblock_size,:,:]O_block/l_block.unsqueeze(-1)returnO# 测试代码if__name____main__:# 创建测试数据batch2seq_len512heads8head_dim64Qtorch.randn(batch,seq_len,heads,head_dim)Ktorch.randn(batch,seq_len,heads,head_dim)Vtorch.randn(batch,seq_len,heads,head_dim)# 计算FlashAttentionoutputflash_attention_forward(Q,K,V)print(fQ shape:{Q.shape})print(fK shape:{K.shape})print(fV shape:{V.shape})print(fOutput shape:{output.shape})这段代码展示了FlashAttention的核心思路分块计算在线softmax避免存储完整的注意力矩阵实际使用时不需要自己实现这个逻辑直接调用ops-transformer提供的算子即可性能表现实测数据ops-transformer中的FlashAttention算子在昇腾NPU上的性能表现如下测试环境硬件Ascend 910服务器8乘以NPU软件CANN 8.0模型GPT-3 13B测试结果配置吞吐量tokens/s首token延迟ms显存占用GB基线标准注意力1,2502,38024.5FlashAttention3,8701,12018.2可以看到使用FlashAttention后吞吐量提升了3倍多首token延迟降低了53%显存占用下降了26%总结FlashAttention是Transformer架构中最重要的注意力计算优化它通过分块计算和在线softmax将显存占用从O(N^2)降到O(N)能支持更长的序列长度昇腾CANN的ops-transformer仓库提供了高性能的FlashAttention算子实现充分利用了达芬奇架构的硬件特性如果你正在昇腾NPU上做Transformer类的模型训练或推理FlashAttention绝对值得一试更多技术细节可以参考ops-transformer仓库的文档https://atomgit.com/cann/ops-transformer