从零手写注意力机制:可调试的QKV计算与数值稳定性实践

从零手写注意力机制:可调试的QKV计算与数值稳定性实践

1. 项目概述:这不是在讲“注意力”,而是在造一个注意力机制的发动机

“Demystifying Attention: Building It from the Ground Up”——这个标题一出来,我就知道它不是那种泛泛而谈“注意力有多重要”的鸡汤文,也不是调用几行 PyTorch 就完事的速成教程。它直指当前大模型底层最核心的那块“心脏”:注意力机制(Attention Mechanism)。关键词里没写“Transformer”“LLM”“PyTorch”,但它们全都在背景里站着;没提“softmax”“QKV”“masking”,可这些就是你要亲手拧紧的每一颗螺丝。我带过不少刚从机器学习入门转到NLP方向的工程师,他们常卡在一个认知断层上:读论文时看到“multi-head self-attention”像看天书,跑通 Hugging Face 的 pipeline 却完全不知道中间矩阵乘了几次、维度怎么变、梯度往哪流。这个项目,就是专为填平这个断层而生的——它要求你从零手写一个可调试、可打断点、可逐层打印 shape 和数值的注意力模块,不依赖任何高级封装,连torch.nn.Linear都得自己用torch.randn初始化权重再做矩阵乘。

我试过三次完整复现:第一次用 NumPy,纯靠纸笔推导前向传播的每一步张量形状变化,卡在batch_size × seq_len × d_modelbatch_size × num_heads × seq_len × head_dim的 reshape 逻辑上整整两天;第二次用 PyTorch 写出骨架,结果反向传播时发现attn_weights的梯度在 softmax 后被截断,查了六小时才发现是torch.no_grad()没关;第三次才真正跑通带 mask 的 causal attention,并把每个中间变量——从Q @ K.T / sqrt(d_k)的原始分值,到 softmax 后的概率分布,再到attn_output = attn_weights @ V的加权和——全部打印出来,一行行对齐论文公式。这种“笨功夫”带来的收益是立竿见影的:当你亲手算出Q[0, 0, :] @ K[0, 0, :].T等于 23.7,而softmax后对应位置变成 0.0012,你就再也不会把“注意力得分高”误解为“两个词语义相似”,而是清醒意识到:这是模型在当前 token 位置,对历史所有位置计算出的一个动态权重分配器,它的输出不是判断,而是路由信号。适合谁?不是给想快速上线业务模型的产品经理,而是给那些愿意花三天时间只为搞懂scaled_dot_product_attention里那个sqrt(d_k)为什么非得是根号、而不是除以 2 或者 log 的算法工程师、研究型学生,以及被面试官问“self-attention 的梯度怎么回传”时当场愣住的求职者。它解决的不是“能不能用”,而是“为什么这么用”“错一点会怎样”“改一个参数整个链条怎么崩”。

2. 核心设计思路:为什么必须“从地面建起”,而不是站在巨人的肩膀上?

2.1 拒绝黑箱:从matmul开始,而非nn.MultiheadAttention

很多教程一上来就调用torch.nn.MultiheadAttention,然后告诉你“设置embed_dim=512, num_heads=8就行”。这就像教人修车,直接递给你一台装好发动机的整车,说“踩油门它就跑”。但当你发现车在高速时抖动,你根本无从下手——是火花塞老化?正时皮带松动?还是曲轴动平衡出了问题?注意力机制同理。nn.MultiheadAttention是个高度优化的工业级组件,它内部做了 fused kernel、memory-efficient attention、flash attention 适配,甚至自动处理了is_causal=True时的下三角 mask。这些优化对生产环境至关重要,但对理解原理却是障碍。我们选择“从地面建起”,核心逻辑有三层:

第一层是控制变量。当你手写Q @ K.T时,你可以强制让QK全为 1,立刻看到输出矩阵全是d_k(因为1×1+1×1+...d_k次),从而验证维度计算是否正确;而用封装接口,你永远看不到这个中间态。第二层是暴露缺陷。比如softmaxQ @ K.T值域极大时会溢出(exp(1000)直接变inf),手写实现会让你第一时间撞上torch.finfo(torch.float32).max ≈ 3.4e38这堵墙,进而逼你实现logsumexp稳定化;而封装版早已内置了clamplog_softmax,你只看到结果正常,却不知背后有多少防御性代码。第三层是建立直觉。我让学生对比两种写法:一种是attn_weights = torch.softmax(Q @ K.T / math.sqrt(d_k), dim=-1);另一种是先算scores = Q @ K.T / math.sqrt(d_k),再attn_weights = torch.softmax(scores, dim=-1)。看似一样,但前者在调试时无法 inspectscores,后者却能清晰看到“原始分值”如何被缩放、如何被指数化、如何被归一化。这种可观察性,是构建工程直觉的基石。

提示:不要跳过math.sqrt(d_k)这个缩放因子。它不是魔法数字——当QK的元素服从均值为 0、方差为 1 的分布时,Q @ K.T的方差会随d_k线性增长(因为Var(XY) = E[X²]E[Y²] - (E[X]E[Y])² ≈ 1×1 = 1,但矩阵乘是d_k项求和,所以总方差≈d_k)。如果不缩放,softmax的输入会越来越大,导致梯度消失(exp(large)让小值趋近 0,梯度几乎为 0)。这就是为什么d_k=64时,1/sqrt(64)=0.125是个关键调节阀。

2.2 分阶段演进:从单头到多头,从无 mask 到 causal mask

我们不追求一步到位写出工业级代码,而是按认知负荷递进。第一阶段:单头、无 mask、无 dropout。目标只有一个:让forward()输出和torch.nn.functional.scaled_dot_product_attention(Q, K, V, is_causal=False)完全一致(torch.allclose误差 < 1e-5)。这迫使你精确处理Q, K, V的 batch 维度对齐、seq_len轴的点积广播规则、softmaxdim参数指定。第二阶段:加入 causal mask。这时你必须手动构造一个下三角矩阵(torch.tril(torch.ones(seq_len, seq_len))),并把它 broadcast 到(batch, heads, seq_len, seq_len)形状,再用-1e9 * (1 - mask)做 masked_fill。这里有个经典陷阱:mask 必须是float类型,且1e9要足够大,才能让softmax后的对应位置接近 0(exp(-1e9)几乎为 0)。第三阶段:拆解 multi-head。重点不是“复制粘贴 8 次”,而是理解Linear层如何将d_model映射到d_model × 3(Q/K/V 各占一份),再view(batch, seq_len, num_heads, head_dim),最后transpose(1, 2)seq_lennum_heads轴交换——这个 transpose 是为了后续matmul能对每个 head 并行计算。很多人在这里混淆permutetranspose,结果Q的 shape 变成(batch, num_heads, head_dim, seq_len),导致点积维度不匹配。实测下来,用einops.rearrange(x, 'b s (h d) -> b h s d', h=num_heads)比原生view+transpose更不易出错。

2.3 工具链极简主义:只用torchnumpy,拒绝一切高级抽象

项目明确禁用transformersxformers、甚至torch.nn.TransformerEncoderLayer。理由很实在:这些库的源码本身就是“谜题”。比如xformersmemory_efficient_attention用了 Triton 内核,你根本没法pdb.set_trace()transformersAttention类里混着past_key_values缓存、attention_mask多种格式兼容、output_attentions开关,逻辑分支太多。我们只要最干净的信号路径:输入xLinearQ/K/VmatmulsoftmaxmatmulLinear→ 输出。连dropout都放在最后一步,而不是插在attn_weights后——因为论文原始实现中,dropout 是作用在attn_output上的(Dropout(attn_output)),而非attn_weights。这个细节差异会导致训练稳定性不同:attn_weightsdropout 会让某些 token 完全被忽略,而attn_outputdropout 是对最终加权和做随机丢弃,更符合“特征层面正则化”的直觉。我自己就踩过这个坑:早期把 dropout 加在 softmax 后,模型在长序列上 loss 突然飙升,debug 三天才发现是attn_weights的稀疏化破坏了信息路由的连续性。

3. 核心细节解析:手写注意力的 7 个生死关卡与避坑指南

3.1 关卡一:Q/K/V 的初始化与维度对齐——别让第一行代码就报错

手写注意力的第一行,往往是Q = self.W_q(x)。这里藏着三个致命细节。第一,W_q的权重初始化不能用默认的torch.nn.Linear初始化(kaiming_uniform_),而必须用torch.nn.init.xavier_normal_。为什么?因为xavier_normal的标准差是1/sqrt(fan_in),能保证Q = W_q @ x的输出方差接近 1,与K, V保持量级一致;若用kaiming(针对 ReLU 设计),Q的方差会偏大,导致Q @ K.T的值域爆炸。第二,x的输入 shape 必须是(batch_size, seq_len, d_model),但很多初学者从nn.Embedding拿到的是(seq_len, batch_size, d_model)(PyTorch 默认的batch_first=False)。如果你没调用.transpose(0, 1)Q @ K.T会因seq_lenbatch_size维度错位而报matmul: expected 2D tensor。第三,W_q, W_k, W_vin_features必须严格等于d_modelout_features必须等于d_model(不是d_model // num_heads!)。多头的拆分是在Linear输出后做的,不是在权重维度上切分的。我见过最离谱的错误是:把W_q定义成nn.Linear(d_model, d_model // num_heads),结果Q的最后一个维度只有 64,而K是 512,matmul直接崩溃。解决方案:统一用nn.Linear(d_model, d_model),然后通过view(..., num_heads, head_dim)拆分。

注意:head_dim = d_model // num_heads必须整除。如果d_model=512,num_heads=6512//6=85.333,程序不会报错,但view时会因总元素数不匹配而RuntimeError。务必在__init__中加断言:assert d_model % num_heads == 0, f"d_model {d_model} not divisible by num_heads {num_heads}"

3.2 关卡二:Q @ K.T的广播与缩放——那个sqrt(d_k)不是装饰品

Q @ K.T看似简单,实则暗流汹涌。假设Q.shape = (batch, num_heads, seq_len, head_dim)K.shape = (batch, num_heads, seq_len, head_dim),那么Q @ K.T的结果 shape 是(batch, num_heads, seq_len, seq_len)。这里K.T不是简单的K.transpose(-2, -1),而是K.permute(0, 1, 3, 2),因为K是 4D 张量,T只对最后两维生效。如果误用K.transpose(-1, -2),在head_dim=64时可能侥幸成功,但一旦seq_len ≠ head_dim,就会因维度不匹配而失败。更隐蔽的问题是缩放。scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(head_dim)这行代码,math.sqrt(head_dim)必须是float,不能是int。在 Python 3.8+ 中int除法会自动转float,但为了保险,显式写math.sqrt(float(head_dim))。我曾因head_dim=64int,在某些旧版本 PyTorch 中触发RuntimeError: matmul: expected a floating point tensor。此外,/ math.sqrt(head_dim)必须在matmul后立即执行,不能等到softmax前再除——因为softmax对输入的绝对值敏感,延迟缩放会导致数值不稳定。实测数据:当head_dim=64Q @ K.T的最大值达~1200exp(1200)直接溢出;缩放后最大值约15exp(15)≈3.2e6,完全在安全范围内。

3.3 关卡三:mask 的构造与应用——causal mask 不是画个三角形那么简单

Causal mask 的目标是让位置i只能看到≤i的位置,即attn_weights[i, j] = 0j > i。最直观的做法是mask = torch.tril(torch.ones(seq_len, seq_len)),但这只是 2D 矩阵。实际需要的是 4D mask:(batch, num_heads, seq_len, seq_len)。直接mask.unsqueeze(0).unsqueeze(0)会创建(1, 1, seq_len, seq_len),然后 broadcast 到 batch 和 heads 维度。但问题来了:torch.tril返回的是float64,而你的scoresfloat32,混合运算会触发隐式类型转换警告。正确做法:mask = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool)),用bool类型避免精度问题,再scores.masked_fill_(~mask, float('-inf'))。注意是~mask(取反),因为maskTrue表示允许的位置,~mask才是需要屏蔽的位置。另一个常见错误是maskseq_lenscoresseq_len不一致——比如你在forward中用x.size(1)得到seq_len,但mask是在__init__中预生成的固定尺寸。这会导致RuntimeError: The size of tensor a (128) must match the size of tensor b (64)。解决方案:永远在forward中动态生成 mask,或用torch.finfo(scores.dtype).min替代float('-inf'),确保类型严格匹配。

3.4 关卡四:softmax的数值稳定性——logsumexp是你的救命稻草

softmax(x) = exp(x) / sum(exp(x)),当x很大时,exp(x)溢出;当x很小时,exp(x)下溢为 0。标准解法是softmax(x) = exp(x - max(x)) / sum(exp(x - max(x))),即减去每行最大值。但手写时容易犯两个错:一是只减scores.max()(全局最大),而不是scores.max(dim=-1, keepdim=True)[0](每行最大);二是忘记keepdim=True,导致scores - max_val因维度不匹配而 broadcast 错误。更鲁棒的做法是直接调用torch.logsumexplog_probs = scores - torch.logsumexp(scores, dim=-1, keepdim=True),然后attn_weights = torch.exp(log_probs)。这样既避免了exp溢出,又保证了sum(attn_weights, dim=-1)严格等于 1(浮点误差内)。我做过对比实验:用原始softmaxsum(attn_weights, dim=-1)的最大偏差达1e-3;用logsumexp版本,偏差稳定在1e-7以内。这对长序列训练至关重要——偏差累积会导致梯度更新方向漂移。

3.5 关卡五:attn_weights @ V的维度缝合——transposeview的战争

attn_weights.shape = (batch, num_heads, seq_len, seq_len)V.shape = (batch, num_heads, seq_len, head_dim),那么attn_weights @ V的结果是(batch, num_heads, seq_len, head_dim)。接下来要缝合成(batch, seq_len, d_model)。这里有两条路:一是attn_output = attn_output.transpose(1, 2).contiguous().view(batch, seq_len, d_model);二是attn_output = einops.rearrange(attn_output, 'b h s d -> b s (h d)')。前者需要contiguous(),因为transpose返回的是 view,而view要求内存连续;后者无需担心。我推荐einops,因为它的语义清晰:“把hd维度合并成一个”。但如果你坚持不用第三方库,contiguous()是必选项。漏掉它,view会报RuntimeError: view size is not compatible with input tensor's size and stride。另外,attn_outputLinear投影层W_oin_features必须是d_model(即num_heads × head_dim),out_features也必须是d_model,这样才能保持残差连接的维度一致。这里有个隐藏陷阱:W_o的初始化同样要用xavier_normal_,否则attn_output的方差会偏离 1,影响后续 LayerNorm 的效果。

3.6 关卡六:残差连接与 LayerNorm——别让 normalization 毁了你的梯度流

Transformer 的残差连接是x + attn_output,但x的 shape 是(batch, seq_len, d_model)attn_output经过W_o后也是(batch, seq_len, d_model),看起来完美。然而,attn_output的均值和方差往往与x不同——x经过 embedding 和 positional encoding,均值接近 0,方差接近 1;而attn_output经过多次matmulsoftmax,其分布可能偏移。直接相加会导致x + attn_output的方差变大,后续LayerNormgammabeta参数需要剧烈调整才能适应。标准解法是Pre-LN:在MultiHeadAttention模块前,先对xLayerNorm,再送入Q/K/V计算;残差后不再LayerNorm。原始论文用的是Post-LN(残差后LayerNorm),但 Pre-LN 训练更稳定。手写时,LayerNormnormalized_shape必须是(d_model,),而不是(seq_len, d_model)(batch, seq_len, d_model)nn.LayerNorm(d_model)会对最后的d_model维度做归一化,即对每个 token 的d_model维向量独立归一化。如果误设nn.LayerNorm((seq_len, d_model)),会试图对(seq_len, d_model)这个二维 shape 归一化,直接报错。实操心得:在forward中打印x.mean(), x.std()attn_output.mean(), attn_output.std(),如果两者标准差相差超过 2 倍,就要检查W_q/W_k/W_v/W_o的初始化是否一致。

3.7 关卡七:dropout 的位置与模式——训练时开,推理时关,但别关错了

Dropout在注意力模块中有两个位置可选:attn_weights后,或attn_output后。原始论文和torch.nn.MultiheadAttention都采用后者:attn_output = self.dropout(attn_output)。原因在于,attn_weights是概率分布,对其 dropout 会破坏 softmax 的归一化性质(sum(dropout(attn_weights), dim=-1)不再是 1),导致attn_output的期望值偏移。而attn_output是特征向量,对其 dropout 是标准的正则化。手写时,self.dropout = nn.Dropout(dropout_p)必须在__init__中定义,并在forward中调用self.dropout(attn_output)。关键细节:nn.Dropouttraining=True时随机置 0,在training=False时不做任何操作(即x * 1)。但很多人忘记在eval()模式下调用model.eval(),导致推理时仍在 dropout,输出波动巨大。更隐蔽的错误是:self.dropout被定义在MultiHeadAttention类里,但forward中调用的是self.dropout(attn_output),而attn_outputfloat32dropout期望float32,没问题;但如果attn_outputfloat16(混合精度训练),dropout会报错,需用torch.nn.Dropout1d或手动实现。我的经验是:始终用float32训练注意力模块,等整个模型稳定后再引入 AMP。

4. 实操过程:从零开始构建可调试注意力模块的完整流水线

4.1 环境准备与最小可运行骨架

我们从最简骨架开始,不追求功能完整,只确保能跑通。创建文件attention_from_scratch.py,内容如下:

import torch import torch.nn as nn import math class ScaledDotProductAttention(nn.Module): def __init__(self, dropout_p=0.0): super().__init__() self.dropout = nn.Dropout(dropout_p) def forward(self, Q, K, V, mask=None): # Q, K, V: (batch, num_heads, seq_len, head_dim) d_k = Q.size(-1) # 计算 scores: (batch, num_heads, seq_len, seq_len) scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(d_k) # 应用 mask if mask is not None: scores = scores.masked_fill(mask == 0, float('-inf')) # softmax 得到注意力权重 attn_weights = torch.softmax(scores, dim=-1) attn_weights = self.dropout(attn_weights) # 加权求和 output = torch.matmul(attn_weights, V) return output, attn_weights # 测试骨架 if __name__ == "__main__": batch, seq_len, num_heads, head_dim = 2, 4, 2, 8 d_model = num_heads * head_dim # 随机生成输入 Q = torch.randn(batch, num_heads, seq_len, head_dim) K = torch.randn(batch, num_heads, seq_len, head_dim) V = torch.randn(batch, num_heads, seq_len, head_dim) # 创建 mask: causal mask for seq_len=4 mask = torch.tril(torch.ones(seq_len, seq_len)).bool() mask = mask.unsqueeze(0).unsqueeze(0) # (1, 1, seq_len, seq_len) attn = ScaledDotProductAttention() output, weights = attn(Q, K, V, mask) print(f"Q shape: {Q.shape}") print(f"Output shape: {output.shape}") print(f"Attn weights shape: {weights.shape}") print(f"Sum over last dim: {weights.sum(dim=-1)}")

运行此脚本,应输出:

Q shape: torch.Size([2, 2, 4, 8]) Output shape: torch.Size([2, 2, 4, 8]) Attn weights shape: torch.Size([2, 2, 4, 4]) Sum over last dim: tensor([[[1., 1., 1., 1.], [1., 1., 1., 1.], [1., 1., 1., 1.], [1., 1., 1., 1.]]])

这个骨架验证了最核心的matmulsoftmaxmask流程。注意maskunsqueeze操作,以及weights.sum(dim=-1)必须全为 1,这是 sanity check 的黄金标准。

4.2 构建 MultiHeadAttention 类:缝合 Q/K/V 与输出投影

在骨架基础上,扩展为完整的MultiHeadAttention

class MultiHeadAttention(nn.Module): def __init__(self, d_model, num_heads, dropout_p=0.0): super().__init__() assert d_model % num_heads == 0, f"d_model {d_model} not divisible by num_heads {num_heads}" self.d_model = d_model self.num_heads = num_heads self.head_dim = d_model // num_heads # Linear layers for Q, K, V self.W_q = nn.Linear(d_model, d_model) self.W_k = nn.Linear(d_model, d_model) self.W_v = nn.Linear(d_model, d_model) self.W_o = nn.Linear(d_model, d_model) self.attn = ScaledDotProductAttention(dropout_p) self.dropout = nn.Dropout(dropout_p) # 初始化权重 self._reset_parameters() def _reset_parameters(self): # 使用 xavier_normal 初始化所有 Linear 层 for p in self.parameters(): if p.dim() > 1: nn.init.xavier_normal_(p) def forward(self, x, mask=None): # x: (batch, seq_len, d_model) batch, seq_len, _ = x.shape # 生成 Q, K, V: (batch, seq_len, d_model) Q = self.W_q(x) K = self.W_k(x) V = self.W_v(x) # 拆分为多头: (batch, seq_len, num_heads, head_dim) -> (batch, num_heads, seq_len, head_dim) Q = Q.view(batch, seq_len, self.num_heads, self.head_dim).transpose(1, 2) K = K.view(batch, seq_len, self.num_heads, self.head_dim).transpose(1, 2) V = V.view(batch, seq_len, self.num_heads, self.head_dim).transpose(1, 2) # 如果有 mask,扩展到多头维度 if mask is not None: # mask: (batch, 1, seq_len, seq_len) or (1, 1, seq_len, seq_len) mask = mask.unsqueeze(1) # (batch, 1, seq_len, seq_len) -> (batch, 1, seq_len, seq_len) # 计算注意力 attn_output, attn_weights = self.attn(Q, K, V, mask) # 合并多头: (batch, num_heads, seq_len, head_dim) -> (batch, seq_len, d_model) attn_output = attn_output.transpose(1, 2).contiguous().view(batch, seq_len, self.d_model) # 输出投影 output = self.W_o(attn_output) output = self.dropout(output) return output, attn_weights

测试代码追加:

# 测试 MultiHeadAttention mha = MultiHeadAttention(d_model=16, num_heads=2, dropout_p=0.1) x = torch.randn(2, 4, 16) # (batch, seq_len, d_model) mask = torch.tril(torch.ones(4, 4)).bool().unsqueeze(0).unsqueeze(0) # (1, 1, 4, 4) output, weights = mha(x, mask) print(f"MHA Output shape: {output.shape}") print(f"MHA Weights shape: {weights.shape}")

此时output.shape应为(2, 4, 16),与输入x一致,满足残差连接要求。

4.3 添加 Pre-LayerNorm 与残差连接:构建 Transformer Block

真正的 Transformer Block 包含MultiHeadAttentionFeedForward两部分,我们先完成前者:

class TransformerBlock(nn.Module): def __init__(self, d_model, num_heads, dropout_p=0.0): super().__init__() self.norm1 = nn.LayerNorm(d_model) self.attn = MultiHeadAttention(d_model, num_heads, dropout_p) self.norm2 = nn.LayerNorm(d_model) # FeedForward 留空,专注注意力 self.dropout = nn.Dropout(dropout_p) def forward(self, x, mask=None): # Pre-LN: 先 norm,再 attn norm_x = self.norm1(x) attn_output, attn_weights = self.attn(norm_x, mask) # 残差连接 x = x + self.dropout(attn_output) # FFN 部分省略,只保留 attn return x, attn_weights # 测试 Transformer Block block = TransformerBlock(d_model=16, num_heads=2, dropout_p=0.1) x = torch.randn(2, 4, 16) mask = torch.tril(torch.ones(4, 4)).bool().unsqueeze(0).unsqueeze(0) output, weights = block(x, mask) print(f"Block Output shape: {output.shape}")

关键点:self.norm1(x)attn前调用,x + self.dropout(attn_output)是残差。此时outputmeanstd应与输入x接近,证明归一化有效。

4.4 深度调试:逐层打印中间变量,定位数值异常

调试的核心是“可视化”。在MultiHeadAttention.forward中插入打印:

def forward(self, x, mask=None): batch, seq_len, _ = x.shape print(f"[MHA] Input x: mean={x.mean():.4f}, std={x.std():.4f}, min={x.min():.4f}, max={x.max():.4f}") Q = self.W_q(x) print(f"[MHA] Q after W_q: mean={Q.mean():.4f}, std={Q.std():.4f}") Q = Q.view(batch, seq_len, self.num_heads, self.head_dim).transpose(1, 2) print(f"[MHA] Q after reshape: shape={Q.shape}, mean={Q.mean():.4f}") # ... 后续步骤同理,每步都 print

运行时,你会看到:

[MHA] Input x: mean=0.0021, std=0.9987, min=-3.2145, max=3.1876 [MHA] Q after W_q: mean=0.0015, std=0.9992 [MHA] Q after reshape: shape=torch.Size([2, 2, 4, 8]), mean=0.0015

如果某一步std突然变成 10 或 0.01,就说明该层权重或计算有误。例如,若Q after W_qstd=10,检查W_q初始化是否用了xavier_normal_;若Q after reshapemean偏离Q after W_q,检查view是否改变了数据顺序。

4.5 与 PyTorch 原生实现对齐:用allclose验证正确性

最终验证,用torch.nn.functional.scaled_dot_product_attention作为黄金标准:

# 生成相同输入 torch.manual_seed(42) x = torch.randn(2, 4, 16) mask = torch.tril(torch.ones(4, 4)).bool().unsqueeze(0).unsqueeze(0) # 手写 MHA mha_custom = MultiHeadAttention(d_model=16, num_heads=2, dropout_p=0.0) mha_custom.eval() # 关闭 dropout with torch.no_grad(): out_custom, _ = mha_custom(x, mask) # PyTorch 原生 Q_native = mha_custom.W_q(x).view(2, 4, 2, 8).transpose(1, 2) K_native = mha_custom.W_k(x).view(2, 4, 2, 8).transpose(1, 2) V_native = mha_custom.W_v(x).view(2, 4, 2, 8).transpose(1, 2) out_native = torch.nn.functional.scaled_dot_product_attention( Q_native, K_native, V_native, attn_mask=mask, dropout_p=0.0, is_causal=False ) out_native = out_native.transpose(1, 2).contiguous().view(2, 4, 16) out_native = mha_custom.W_o(out_native) print(f"Custom vs Native allclose: {torch.allclose(out