Self-Attention Mechanism

1. About Q, K and V

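In Self-Attention, every word (token) is turned into three vectors:

Q = Query: what information am I looking for right now?
K = Key: what information do I hold, and which queries should match me?
V = Value: if other tokens attend to me, what content do I actually contribute?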

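Self-Attention works as follows:

Each token matches its Query against the Keys of all tokens (itself included) to get weights that say which tokens it should attend to. It then uses those weights to take a weighted sum of all tokens' Values, and that sum becomes the token's new representation.

In short: Self-Attention = compute how relevant each word is to every other word, then re-mix the context information according to that relevance.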
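A minimal pure-Python sketch of this update for a single token follows. It assumes the Query, Key and Value vectors are already given as plain lists; the function name `attend_one` and the toy vectors are hypothetical, and the scores are scaled by \sqrt{d_k} as in the standard formulation.

```python
import math


def attend_one(query, keys, values):
    """New representation for one token: softmax-weighted sum of all Values."""
    d_k = len(query)
    # 1. Match the Query against every Key (scaled dot product)
    scores = [sum(q * k for q, k in zip(query, key)) / math.sqrt(d_k) for key in keys]
    # 2. Turn the scores into weights that sum to 1 (softmax, shifted by the max for stability)
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    weights = [e / total for e in exps]
    # 3. Weighted sum of the Values
    dim = len(values[0])
    output = [sum(w * v[j] for w, v in zip(weights, values)) for j in range(dim)]
    return output, weights


# Hypothetical toy vectors for 3 tokens
keys = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
values = [[1.0, 0.0], [0.0, 1.0], [0.5, 0.5]]
query = [1.0, 0.0]  # this token is "looking for" the first feature
out, w = attend_one(query, keys, values)
print([round(x, 3) for x in w], [round(x, 3) for x in out])
```

Because this query matches the first and third keys equally well, those two tokens receive equal weight, and the second token gets less.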


2. Computation

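Suppose the input sequence is:

X = [x_1, x_2, x_3, ..., x_n]

i.e. n tokens, where each x_i is a d-dimensional vector, so X is an n×d matrix. Queries, Keys and Values are obtained by linear projections:

Q = X W_Q,   K = X W_K,   V = X W_V

where W_Q, W_K and W_V are the matrices that map the input to Query, Key and Value. They are learned during training, and each has shape d×d_k (d_k is often set equal to d, which gives d×d, as in the PyTorch implementation in Section 4).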
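A minimal sketch of the projection step with plain Python lists, assuming d = 3 input features projected to d_k = 2. The matrices `W_Q`, `W_K`, `W_V` below are made-up stand-ins for learned weights and `matmul` is a hypothetical helper; the point is the shapes: X is n×d, each W is d×d_k, so Q, K and V are n×d_k, one row per token.

```python
def matmul(a, b):
    """Multiply an (m x p) matrix by a (p x q) matrix, both stored as lists of rows."""
    assert len(a[0]) == len(b), "inner dimensions must match"
    return [[sum(a[i][t] * b[t][j] for t in range(len(b))) for j in range(len(b[0]))]
            for i in range(len(a))]


# n = 4 tokens, each a d = 3 dimensional vector (toy values)
X = [[1.0, 0.0, 0.0],
     [0.0, 1.0, 0.0],
     [0.0, 0.0, 1.0],
     [1.0, 1.0, 1.0]]

# Hypothetical "learned" projections, each d x d_k = 3 x 2
W_Q = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
W_K = [[0.5, 0.0], [0.0, 0.5], [0.5, 0.5]]
W_V = [[1.0, 1.0], [0.0, 1.0], [1.0, 0.0]]

Q, K, V = matmul(X, W_Q), matmul(X, W_K), matmul(X, W_V)
print(len(Q), len(Q[0]))  # rows = tokens, columns = d_k
```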

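Next, compute the attention scores:

S = Q K^T

The resulting n×n matrix has entries S_ij = q_i · k_j, which measure how much token i attends to token j.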
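Continuing with plain lists, the score matrix can be sketched as below, assuming toy Q and K values (hypothetical). Entry S[i][j] is simply the dot product of token i's Query with token j's Key, so S is n×n whatever d_k is. S is generally not symmetric, since token i's interest in token j need not equal token j's interest in token i.

```python
def dot(u, v):
    return sum(a * b for a, b in zip(u, v))


# Toy Queries and Keys for n = 3 tokens with d_k = 2 (hypothetical values)
Q = [[1.0, 0.0], [0.0, 2.0], [1.0, 1.0]]
K = [[0.0, 1.0], [1.0, 0.0], [1.0, 1.0]]

# S[i][j] = q_i . k_j: how strongly token i matches token j (before scaling and softmax)
S = [[dot(q, k) for k in K] for q in Q]
for row in S:
    print(row)
```

Here S[0][1] = 1 but S[1][0] = 2, which shows the asymmetry.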

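To keep these values from growing too large, the scores are divided by \sqrt{d_k}, which is the standard deviation of the dot product (not its variance): if the components of q and k are independent with mean 0 and variance 1, then q · k has variance d_k. This matters because of how softmax behaves: when its inputs are large, it outputs extreme probabilities such as [0.999, 0.001], where the gradients are close to zero (vanishing gradients).

Put simply, for numbers in the same ratio, the larger they are, the further apart their softmax probabilities end up. For example, [8, 4] and [4, 2] have the same 2:1 ratio, but softmax turns them into [0.982, 0.018] and [0.881, 0.119] respectively.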
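Both claims can be checked numerically. The sketch below (helper names hypothetical) reproduces the two softmax examples and estimates the variance of q · k for random vectors with i.i.d. standard-normal components, assuming d_k = 64; the variance comes out near d_k before scaling and near 1 after dividing by \sqrt{d_k}.

```python
import math
import random


def softmax(xs):
    m = max(xs)  # subtract the max for numerical stability; the result is unchanged
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]


def variance(xs):
    mean = sum(xs) / len(xs)
    return sum((x - mean) ** 2 for x in xs) / len(xs)


# Same 2:1 ratio, different magnitude
print([round(p, 3) for p in softmax([8.0, 4.0])])  # -> [0.982, 0.018]
print([round(p, 3) for p in softmax([4.0, 2.0])])  # -> [0.881, 0.119]

# Variance of q . k with i.i.d. N(0, 1) components, before and after scaling
random.seed(0)
d_k, trials = 64, 5000
raw = []
for _ in range(trials):
    q = [random.gauss(0.0, 1.0) for _ in range(d_k)]
    k = [random.gauss(0.0, 1.0) for _ in range(d_k)]
    raw.append(sum(a * b for a, b in zip(q, k)))
scaled = [s / math.sqrt(d_k) for s in raw]
print(round(variance(raw), 1), round(variance(scaled), 2))
```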

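Next, apply softmax to each row to turn the scores into probabilities (each row then sums to 1): A = softmax(Q K^T / \sqrt{d_k})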
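A minimal sketch of the row-wise softmax, assuming the scaled scores are stored as a list of rows (toy values, hypothetical). Each row is normalized independently, so every token's attention weights sum to 1. Subtracting the row maximum before `math.exp` is a common trick that avoids overflow without changing the result.

```python
import math


def softmax_rows(scores):
    """Apply softmax to each row of a matrix independently."""
    out = []
    for row in scores:
        m = max(row)
        exps = [math.exp(s - m) for s in row]
        total = sum(exps)
        out.append([e / total for e in exps])
    return out


# Toy scaled scores; the second row would overflow math.exp without the max shift
scores = [[0.5, 1.0, -0.5],
          [1000.0, 999.0, 998.0]]
A = softmax_rows(scores)
for row in A:
    print([round(a, 3) for a in row], "row sum:", round(sum(row), 6))
```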

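Finally, use these weights to take a weighted sum of the Values, which gives the full formula: Attention(Q, K, V) = softmax(Q K^T / \sqrt{d_k}) V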
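Putting the steps together, a dependency-free sketch of the whole formula might look like this. The helper names are hypothetical and the matrices are toy values; a real model would use a tensor library instead of nested lists.

```python
import math


def transpose(m):
    return [list(col) for col in zip(*m)]


def matmul(a, b):
    bt = transpose(b)
    return [[sum(x * y for x, y in zip(row, col)) for col in bt] for row in a]


def softmax_rows(m):
    out = []
    for row in m:
        mx = max(row)
        exps = [math.exp(v - mx) for v in row]
        s = sum(exps)
        out.append([e / s for e in exps])
    return out


def attention(Q, K, V):
    """softmax(Q K^T / sqrt(d_k)) V for matrices stored as lists of rows."""
    d_k = len(K[0])
    scores = [[s / math.sqrt(d_k) for s in row] for row in matmul(Q, transpose(K))]
    weights = softmax_rows(scores)
    return matmul(weights, V), weights


# Toy n = 3 tokens, d_k = d_v = 2 (hypothetical values)
Q = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
K = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
V = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]
out, w = attention(Q, K, V)
print(len(out), len(out[0]))
```

Each output row is a convex combination of the rows of V, so its entries always lie between the smallest and largest entries of the corresponding column of V.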


3. A simple worked example

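Suppose a sentence has 3 words, each represented by a 2-dimensional vector, so X is a 3×2 matrix.

For simplicity, take Q = K = V = X (that is, W_Q = W_K = W_V = I, with d_k = 2) and substitute into the formula.

This gives a 3×3 attention matrix whose entry (i, j) is how much token i attends to token j.

Finally, weighting the Values with these attention weights gives the output.

Note that the output has the same shape as the original input X: it is still 3×2.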
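The same steps can be reproduced numerically. The sketch below uses a hypothetical 3×2 input X chosen only for illustration (not the values of the original example), sets Q = K = V = X, and prints the 3×3 attention matrix and the 3×2 output.

```python
import math

# Hypothetical 3 x 2 input chosen for illustration
X = [[1.0, 0.0],
     [0.0, 1.0],
     [1.0, 1.0]]
Q = K = V = X  # simplification: W_Q = W_K = W_V = identity
d_k = len(X[0])

# Scaled scores S[i][j] = q_i . k_j / sqrt(d_k)
S = [[sum(a * b for a, b in zip(q, k)) / math.sqrt(d_k) for k in K] for q in Q]

# Row-wise softmax -> attention matrix A (3 x 3)
A = []
for row in S:
    m = max(row)
    exps = [math.exp(s - m) for s in row]
    total = sum(exps)
    A.append([e / total for e in exps])

# Output = A V, one new 2-d vector per token
out = [[sum(A[i][j] * V[j][c] for j in range(len(V))) for c in range(d_k)]
       for i in range(len(A))]

for name, m in (("A", A), ("output", out)):
    print(name)
    for row in m:
        print("  ", [round(x, 3) for x in row])
print("shape of output:", len(out), "x", len(out[0]))
```

With Q = K, a token does not necessarily attend most to itself: here the first token gives the same weight to itself and to the third token, because both keys have the same dot product with its query.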


4. A small PyTorch implementation

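import math

import torch
import torch.nn as nn
import torch.nn.functional as F


class SelfAttention(nn.Module):
    """Single-head self-attention with learned Q/K/V projections."""

    def __init__(self, embed_dim):
        super().__init__()
        self.embed_dim = embed_dim
        # Learned projections W_Q, W_K, W_V (embed_dim x embed_dim; nn.Linear also adds a bias)
        self.W_q = nn.Linear(embed_dim, embed_dim)
        self.W_k = nn.Linear(embed_dim, embed_dim)
        self.W_v = nn.Linear(embed_dim, embed_dim)

    def forward(self, x):
        # x: (seq_len, embed_dim) or (batch, seq_len, embed_dim)
        Q = self.W_q(x)
        K = self.W_k(x)
        V = self.W_v(x)
        # scores[i, j] = q_i . k_j; transpose only the last two dims so batched input also works
        scores = Q @ K.transpose(-2, -1)
        # Scale by sqrt(d_k) to keep softmax out of its saturated region
        scores = scores / math.sqrt(self.embed_dim)
        # Row-wise softmax: each row of weights sums to 1
        attention_weights = F.softmax(scores, dim=-1)
        # Weighted sum of the Values gives the new token representations
        output = attention_weights @ V
        return output, attention_weights


x = torch.randn(4, 8)  # 4 tokens, 8-dimensional embeddings
attention = SelfAttention(embed_dim=8)
output, weights = attention(x)
print("output shape:", output.shape)
print("attention weights shape:", weights.shape)
print(weights)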

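Output (the values differ between runs, since x and the layer weights are randomly initialized and no seed is set):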

output shape: torch.Size([4, 8])
attention weights shape: torch.Size([4, 4])
tensor([[0.3009, 0.1964, 0.2414, 0.2612],
        [0.3054, 0.2573, 0.2479, 0.1895],
        [0.2112, 0.3020, 0.2613, 0.2255],
        [0.2292, 0.2222, 0.2339, 0.3147]], grad_fn=<SoftmaxBackward0>)

5. A simple simulated embodied VLA scenario

We simulate a robot arm grasping with a gripper; the scene contains a red block and a blue cup.

5.1 Self-attention demo

from __future__ import annotations

import math

import torch

torch.set_printoptions(precision=3, sci_mode=False)


def scaled_dot_product_attention(x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    """Compute single-head self-attention without nn.MultiheadAttention."""
    d_model = x.shape[-1]
    # Fixed projections keep the example readable: Q, K, V all use x directly.
    q = x
    k = x
    v = x
    scores = q @ k.T / math.sqrt(d_model)
    weights = torch.softmax(scores, dim=-1)
    output = weights @ v
    return output, weights


def build_embodied_tokens() -> tuple[list[str], torch.Tensor]:
    """Create interpretable toy tokens.

    Feature layout: [red, blue, block, cup, robot, target, x, y]
    """
    token_names = [
        "language: target is red block",
        "vision: red block at (0.8, 0.7)",
        "vision: blue cup at (0.2, 0.3)",
        "robot: gripper at (0.1, 0.6)",
        "action query: where to move next",
    ]
    x = torch.tensor(
        [
            [1.0, 0.0, 1.0, 0.0, 0.0, 1.0, 0.0, 0.0],  # language instruction
            [1.0, 0.0, 1.0, 0.0, 0.0, 0.8, 0.8, 0.7],  # red block observation
            [0.0, 1.0, 0.0, 1.0, 0.0, 0.1, 0.2, 0.3],  # blue cup observation
            [0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.1, 0.6],  # gripper state
            [1.0, 0.0, 1.0, 0.0, 1.0, 1.0, 0.1, 0.6],  # action query
        ],
        dtype=torch.float32,
    )
    return token_names, x


def main() -> None:
    token_names, x = build_embodied_tokens()
    output, weights = scaled_dot_product_attention(x)

    print("Tokens:")
    for index, name in enumerate(token_names):
        print(f"{index}: {name}")

    print("\nAttention weights:")
    print(weights)

    # The last token is the action query; inspect its row of the attention matrix
    action_index = len(token_names) - 1
    print("\nWhat the action query attends to:")
    for name, weight in zip(token_names, weights[action_index]):
        print(f"{weight.item():.3f} {name}")

    print("\nNew action-token representation:")
    print(output[action_index])


if __name__ == "__main__":
    main()

Output:

Tokens:
0: language: target is red block
1: vision: red block at (0.8, 0.7)
2: vision: blue cup at (0.2, 0.3)
3: robot: gripper at (0.1, 0.6)
4: action query: where to move next

Attention weights:
tensor([[0.275, 0.256, 0.099, 0.095, 0.275],
        [0.223, 0.314, 0.097, 0.099, 0.266],
        [0.159, 0.180, 0.327, 0.164, 0.170],
        [0.154, 0.183, 0.165, 0.249, 0.249],
        [0.214, 0.237, 0.082, 0.120, 0.347]])

What the action query attends to:
0.214 language: target is red block
0.237 vision: red block at (0.8, 0.7)
0.082 vision: blue cup at (0.2, 0.3)
0.120 robot: gripper at (0.1, 0.6)
0.347 action query: where to move next

New action-token representation:
tensor([0.798, 0.082, 0.798, 0.082, 0.467, 0.758, 0.253, 0.471])
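As a sanity check on the action row, the sketch below recomputes it in plain Python from the token vectors in `build_embodied_tokens` (copied unchanged, no new data): the action token's dot products with every token, scaled by \sqrt{8}, followed by a softmax. It should match the printed row 0.214, 0.237, 0.082, 0.120, 0.347 up to rounding, and the raw dot products show why the red-block tokens outweigh the blue cup: the action token shares the red, block and target features with them.

```python
import math

# Token vectors copied from build_embodied_tokens();
# feature layout: [red, blue, block, cup, robot, target, x, y]
tokens = {
    "language: target is red block": [1.0, 0.0, 1.0, 0.0, 0.0, 1.0, 0.0, 0.0],
    "vision: red block at (0.8, 0.7)": [1.0, 0.0, 1.0, 0.0, 0.0, 0.8, 0.8, 0.7],
    "vision: blue cup at (0.2, 0.3)": [0.0, 1.0, 0.0, 1.0, 0.0, 0.1, 0.2, 0.3],
    "robot: gripper at (0.1, 0.6)": [0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.1, 0.6],
    "action query: where to move next": [1.0, 0.0, 1.0, 0.0, 1.0, 1.0, 0.1, 0.6],
}
names = list(tokens)
action = tokens[names[-1]]
d_model = len(action)

# Raw dot products q_action . k_j, then scale by sqrt(d_model)
raw = [sum(a * b for a, b in zip(action, tokens[n])) for n in names]
scaled = [r / math.sqrt(d_model) for r in raw]

# Softmax over the five tokens
m = max(scaled)
exps = [math.exp(s - m) for s in scaled]
total = sum(exps)
weights = [e / total for e in exps]

for n, r, w in zip(names, raw, weights):
    print(f"dot={r:5.2f}  weight={w:.3f}  {n}")
```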

[Figure: visualization of the attention weights]

5.2 Comparison

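Both policies use hand-designed rules, with no learning or training. One selects its target with attention (the language instruction serves as the Query over the two object tokens, so strictly speaking this step is cross-attention rather than self-attention); the other is hard-coded to always move toward the red block. We compare their average scores.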

from __future__ import annotations

import math
import random
from dataclasses import dataclass

import torch

FEATURES = ["red", "blue", "block", "cup", "robot", "target", "x", "y"]

# Language tokens for the two possible instructions (target flag set, no position)
VOCAB = {
    "red_block": torch.tensor([1, 0, 1, 0, 0, 1, 0, 0], dtype=torch.float32),
    "blue_cup": torch.tensor([0, 1, 0, 1, 0, 1, 0, 0], dtype=torch.float32),
}


@dataclass(frozen=True)
class Scene:
    red_block: torch.Tensor
    blue_cup: torch.Tensor
    robot: torch.Tensor
    target_name: str


def make_scene() -> Scene:
    # Random (x, y) positions in the unit square and a random instruction
    red_block = torch.rand(2)
    blue_cup = torch.rand(2)
    robot = torch.rand(2)
    target_name = random.choice(["red_block", "blue_cup"])
    return Scene(red_block=red_block, blue_cup=blue_cup, robot=robot, target_name=target_name)


def object_token(name: str, xy: torch.Tensor) -> torch.Tensor:
    if name == "red_block":
        base = torch.tensor([1, 0, 1, 0, 0, 0], dtype=torch.float32)
    elif name == "blue_cup":
        base = torch.tensor([0, 1, 0, 1, 0, 0], dtype=torch.float32)
    else:
        raise ValueError(f"Unknown object: {name}")
    return torch.cat([base, xy])


def robot_token(xy: torch.Tensor) -> torch.Tensor:
    return torch.tensor([0, 0, 0, 0, 1, 0, xy[0].item(), xy[1].item()], dtype=torch.float32)


def attention_policy(scene: Scene) -> torch.Tensor:
    language = VOCAB[scene.target_name]
    red = object_token("red_block", scene.red_block)
    blue = object_token("blue_cup", scene.blue_cup)
    robot = robot_token(scene.robot)  # built for completeness; not used below

    # Query asks for the target object. Keys are candidate objects.
    q = language
    keys = torch.stack([red, blue])
    values = torch.stack([scene.red_block, scene.blue_cup])

    # Match on the 6 semantic features only, ignoring x/y
    semantic_keys = keys[:, :6]
    semantic_query = q[:6]
    # The factor 4.0 sharpens the softmax (acts as an inverse temperature)
    scores = 4.0 * (semantic_keys @ semantic_query) / math.sqrt(semantic_query.numel())
    weights = torch.softmax(scores, dim=0)
    attended_target_xy = weights @ values

    # Unit vector from the gripper toward the attended position
    direction = attended_target_xy - scene.robot
    return direction / (direction.norm() + 1e-8)


def naive_policy(scene: Scene) -> torch.Tensor:
    # Always moves to red block, even when the instruction asks for blue cup.
    direction = scene.red_block - scene.robot
    return direction / (direction.norm() + 1e-8)


def target_direction(scene: Scene) -> torch.Tensor:
    # Ground-truth unit vector toward the instructed object
    target_xy = scene.red_block if scene.target_name == "red_block" else scene.blue_cup
    direction = target_xy - scene.robot
    return direction / (direction.norm() + 1e-8)


def cosine(a: torch.Tensor, b: torch.Tensor) -> float:
    # a and b are already unit vectors, so their dot product is the cosine similarity
    return torch.dot(a, b).item()


def evaluate(num_scenes: int = 1000) -> None:
    random.seed(9)
    torch.manual_seed(9)
    attention_scores = []
    naive_scores = []
    for _ in range(num_scenes):
        scene = make_scene()
        gold = target_direction(scene)
        attention_scores.append(cosine(attention_policy(scene), gold))
        naive_scores.append(cosine(naive_policy(scene), gold))
    attention_mean = sum(attention_scores) / len(attention_scores)
    naive_mean = sum(naive_scores) / len(naive_scores)
    print(f"Scenes: {num_scenes}")
    print(f"Attention policy: {attention_mean:.3f}")
    print(f"Naive fixed policy: {naive_mean:.3f}")


if __name__ == "__main__":
    evaluate()

Output:

Scenes: 1000
Attention policy: 0.999
Naive fixed policy: 0.707

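The attention policy almost always picks the right object, with a mean score of 0.999. The hard-coded policy keeps heading for the red block even when the instruction asks for the blue cup; in those scenes its direction is off by the angle between the two objects as seen from the gripper, so its mean score drops to 0.707. This shows that attention lets the policy choose its target from the instruction.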
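The attention policy scores 0.999 rather than exactly 1.0 because softmax never puts all of the weight on one object. With the semantic features used in `attention_policy`, the matching object scores 2 (color and shape agree) and the other scores 0, so for the red-block instruction the weights are softmax([4·2/\sqrt{6}, 0]), about 0.963 on the red block. The remaining few percent pull the attended position slightly toward the wrong object. The sketch below only re-derives that number from the code above; the helper name `target_weights` and the comparison scales 1.0 and 16.0 are hypothetical, added to show how the 4.0 factor sharpens the distribution.

```python
import math

# Semantic features [red, blue, block, cup, robot, target], as in attention_policy
query_red = [1, 0, 1, 0, 0, 1]  # VOCAB["red_block"][:6]
key_red = [1, 0, 1, 0, 0, 0]    # object_token("red_block", ...)[:6]
key_blue = [0, 1, 0, 1, 0, 0]   # object_token("blue_cup", ...)[:6]


def target_weights(scale):
    """Softmax weights over (red block, blue cup) for the red-block instruction."""
    d = len(query_red)
    scores = [scale * sum(q * k for q, k in zip(query_red, key)) / math.sqrt(d)
              for key in (key_red, key_blue)]
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


for scale in (1.0, 4.0, 16.0):
    w = target_weights(scale)
    print(f"scale={scale:4.1f}  weight on red block={w[0]:.3f}")
```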