当你第一次看到Self-Attention的公式时,可能会被那一串矩阵运算吓到。但别担心,我们可以用一个生活中的例子来理解它。想象你在阅读一篇文章时,眼睛会自动聚焦在关键词上——比如看到"火灾"会立刻警觉,而忽略"的"、"了"这样的助词。Self-Attention就是让AI学会这种能力。
具体到技术实现,整个过程就像是在做三件事:
用代码来说,这就是三个矩阵乘法:
python复制Q = X @ W_Q # 问题卡片
K = X @ W_K # 答案手册
V = X @ W_V # 信息包
我第一次实现时犯过一个典型错误——忘记了对QK^T的结果进行缩放。这会导致softmax后的梯度爆炸,模型根本无法训练。后来才明白那个√d_k的重要性:就像调节音量旋钮,太大就会失真,太小又听不清。
让我们用PyTorch从零开始实现这个过程。假设我们的输入是一个包含3个单词的句子,每个词用4维向量表示(实际中通常是768维):
python复制import torch
import torch.nn.functional as F
# 输入矩阵:3个token,每个4维
X = torch.tensor([[1.0, 0.0, 1.0, 0.0],
[0.0, 1.0, 0.0, 1.0],
[1.0, 1.0, 0.0, 0.0]])
# 初始化权重矩阵 (通常用xavier初始化)
W_Q = torch.randn(4, 4, requires_grad=True)
W_K = torch.randn(4, 4, requires_grad=True)
W_V = torch.randn(4, 4, requires_grad=True)
计算Q、K、V时要注意矩阵形状的变化。我调试时经常打印shape来验证:
python复制Q = X @ W_Q # (3,4) @ (4,4) -> (3,4)
K = X @ W_K # 同上
V = X @ W_V # 同上
print(f"Q shape: {Q.shape}, K shape: {K.shape}, V shape: {V.shape}")
这里有个实用技巧:用einops库可以更直观地操作张量维度。比如rearrange(Q, 'b s d -> b d s')可以快速转置矩阵,比原生PyTorch更易读。
计算注意力权重时最容易出错的是缩放点积这一步。来看具体实现:
python复制d_k = Q.size(-1) # 特征维度4
scores = Q @ K.transpose(-2, -1) / torch.sqrt(torch.tensor(d_k)) # (3,3)
attn_weights = F.softmax(scores, dim=-1)
我曾遇到过两个典型问题:
理解权重矩阵的物理意义很重要。假设我们计算结果是:
code复制[[0.8, 0.1, 0.1],
[0.2, 0.7, 0.1],
[0.1, 0.2, 0.7]]
这表示:
可视化这些权重可以帮助调试。用matplotlib画热力图是我常用的方法:
python复制import matplotlib.pyplot as plt
plt.imshow(attn_weights.detach().numpy(), cmap='hot')
plt.colorbar()
现在我们把所有步骤整合成一个完整的Self-Attention层:
python复制class SelfAttention(nn.Module):
def __init__(self, embed_size):
super().__init__()
self.embed_size = embed_size
self.W_Q = nn.Linear(embed_size, embed_size)
self.W_K = nn.Linear(embed_size, embed_size)
self.W_V = nn.Linear(embed_size, embed_size)
def forward(self, X):
Q = self.W_Q(X)
K = self.W_K(X)
V = self.W_V(X)
d_k = self.embed_size
scores = Q @ K.transpose(-2, -1) / torch.sqrt(torch.tensor(d_k))
attn_weights = F.softmax(scores, dim=-1)
output = attn_weights @ V
return output
验证反向传播是否正常很重要。我的检查方法是:
python复制model = SelfAttention(4)
optimizer = torch.optim.Adam(model.parameters())
loss_fn = nn.MSELoss()
# 模拟训练步骤
for _ in range(100):
optimizer.zero_grad()
output = model(X)
loss = loss_fn(output, torch.randn_like(output)) # 随机目标
loss.backward()
optimizer.step()
print(loss.item()) # 应该看到loss下降
如果loss不下降,可能是梯度消失/爆炸。这时需要检查初始化方式,或者调整缩放因子。
单头注意力就像只用一只眼睛看世界,而多头则是用多只眼睛从不同角度观察。实现时最需要注意的是维度的拆分与合并:
python复制class MultiHeadAttention(nn.Module):
def __init__(self, embed_size, num_heads):
super().__init__()
assert embed_size % num_heads == 0
self.head_size = embed_size // num_heads
self.num_heads = num_heads
self.W_Q = nn.Linear(embed_size, embed_size)
self.W_K = nn.Linear(embed_size, embed_size)
self.W_V = nn.Linear(embed_size, embed_size)
self.fc_out = nn.Linear(embed_size, embed_size)
def forward(self, X):
batch_size = X.size(0)
# 线性变换后拆分多头
Q = self.W_Q(X).view(batch_size, -1, self.num_heads, self.head_size)
K = self.W_K(X).view(batch_size, -1, self.num_heads, self.head_size)
V = self.W_V(X).view(batch_size, -1, self.num_heads, self.head_size)
# 计算注意力
scores = Q @ K.transpose(-2, -1) / torch.sqrt(torch.tensor(self.head_size))
attn_weights = F.softmax(scores, dim=-1)
output = attn_weights @ V
# 合并多头
output = output.transpose(1, 2).contiguous().view(batch_size, -1, self.embed_size)
return self.fc_out(output)
调试多头注意力时,我总结了几点经验:
embed_size能被num_heads整除.contiguous()避免内存问题在真实项目中,我们还需要考虑以下几点优化:
1. 掩码处理:
python复制# 创建下三角掩码 (用于解码器)
mask = torch.tril(torch.ones(seq_len, seq_len))
scores = scores.masked_fill(mask == 0, float('-inf'))
2. 注意力dropout:
python复制attn_weights = F.dropout(attn_weights, p=0.1, training=self.training)
3. 缓存KV(用于推理加速):
python复制if use_cache:
self.K = torch.cat([self.K, K], dim=1)
self.V = torch.cat([self.V, V], dim=1)
K, V = self.K, self.V
我曾在处理长文本时遇到OOM问题,后来采用以下方法解决:
抛开代码,从数学上看Self-Attention实际上是在学习一种动态的内容寻址机制。与传统的查找表不同,这里的"地址"是通过QK^T计算得到的相似度。
具体来说,整个过程可以分解为:
这种设计有几点精妙之处:
我在复现原始论文时发现,去掉√d_k这个缩放因子后,模型在深层的梯度要么趋近于0,要么爆炸。这印证了论文中的理论分析——缩放是为了保持梯度在合理范围内。