Flash Attention: Algorithm Principles


1.1 Flash Attention Step-by-Step

 

[Figure: Flash Attention step-by-step]

 

1.2 Softmax

\( \text{softmax}(x_i) = \frac{e^{x_i}}{\sum_{k=1}^{N} e^{x_k}} \)

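import torch

A = torch.randn(2, 6)                         # 2 rows of 6 random logits
A_exp = torch.exp(A)                          # element-wise exponential
A_sum = torch.sum(A_exp, dim=1).unsqueeze(1)  # per-row sums, shape (2, 1)
P = A_exp / A_sum                             # broadcasting divides each row by its sum
print(A)
print(P)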

Output:

tensor([[ 1.0668, -0.3969, -0.2226,  0.7207,  1.0509, -1.0740],[ 0.6774,  1.0916, -1.8402, -1.0806,  0.9309,  2.4612]])
tensor([[0.3016, 0.0698, 0.0831, 0.2133, 0.2968, 0.0355],[0.0999, 0.1512, 0.0081, 0.0172, 0.1288, 0.5948]])

1.3 Safe Softmax

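The naive softmax is numerically unstable: \( e^{x_i} \) overflows when \( x_i \) is large. Safe Softmax subtracts the maximum \( m \) from every input before exponentiating, which leaves the result unchanged but keeps every exponent at most 0.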

\( \text{SafeSoftmax}(x_i) = \frac{e^{x_i - m}}{\sum_{k=1}^{N} e^{x_k - m}} \quad \text{where } m = \max(x_1, x_2, \dots, x_N) \)
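The overflow problem and its fix can be illustrated with a small sketch. It is written in plain Python with the standard `math` module rather than torch so that each step mirrors the formula element by element; the function names `naive_softmax` and `safe_softmax` and the sample input are assumptions made for illustration.

```python
import math

def naive_softmax(xs):
    # Direct translation of the softmax formula: exp(x_i) / sum_k exp(x_k).
    exps = [math.exp(x) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]

def safe_softmax(xs):
    # Subtract the maximum first, so every exponent is <= 0 and exp() cannot overflow.
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]

xs = [1000.0, 1001.0, 1002.0]

# math.exp raises OverflowError for arguments this large.
try:
    naive_softmax(xs)
    naive_overflowed = False
except OverflowError:
    naive_overflowed = True

probs = safe_softmax(xs)
print(naive_overflowed)
print(probs)
```

Because only the differences \( x_i - m \) enter the computation, `safe_softmax([1000.0, 1001.0, 1002.0])` yields the same probabilities as `safe_softmax([0.0, 1.0, 2.0])`.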

1.4 Online Softmax

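Let \( m_i = \max(x_1, \dots, x_i) \) be the running maximum over the first \( i \) elements, so that \( m_i = \max(m_{i-1}, x_i) \), and let \( d_i' \) be the Safe Softmax denominator over those same elements. Then \( d_i' \) can be updated recursively from \( d_{i-1}' \):

\( \begin{aligned} d_i' &= \sum_{j=1}^{i} e^{x_j - m_i} \\ &= \sum_{j=1}^{i-1} e^{x_j - m_i} + e^{x_i - m_i} \\ &= \sum_{j=1}^{i-1} e^{x_j - m_{i-1} + m_{i-1} - m_i} + e^{x_i - m_i} \\ &= \left( \sum_{j=1}^{i-1} e^{x_j - m_{i-1}} \right) e^{m_{i-1} - m_i} + e^{x_i - m_i} \\ &= d_{i-1}' e^{m_{i-1} - m_i} + e^{x_i - m_i} \end{aligned} \)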
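The recurrence above can be sketched as a single loop that maintains the running maximum and the running denominator together, instead of one pass for the maximum and another for the sum. This is a minimal scalar illustration in plain Python, assuming finite inputs; the function name `online_softmax` and the variables `m`, `d`, and `m_new` are chosen here for illustration.

```python
import math

def online_softmax(xs):
    m = float("-inf")  # running maximum m_i (m_0 = -inf)
    d = 0.0            # running denominator d_i' (d_0' = 0)
    for x in xs:
        m_new = max(m, x)
        # d_i' = d_{i-1}' * exp(m_{i-1} - m_i) + exp(x_i - m_i)
        d = d * math.exp(m - m_new) + math.exp(x - m_new)
        m = m_new
    # A final pass normalizes each element with the finished m and d.
    return [math.exp(x - m) / d for x in xs]

row = [1.0668, -0.3969, -0.2226, 0.7207, 1.0509, -1.0740]
probs = online_softmax(row)
print(probs)
```

On the first iteration `m` is negative infinity, so `math.exp(m - m_new)` evaluates to 0.0 and `d` starts at exactly 1. Whenever a new maximum appears, the factor `exp(m - m_new)` rescales the sum accumulated so far, which is what lets the maximum and the denominator be computed in the same pass.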