当你第一次看到Transformer论文中那些复杂的矩阵运算时,是否感到一头雾水?作为现代深度学习最重要的架构之一,Transformer的核心——Self-Attention机制,其实可以通过代码变得直观易懂。今天我们不谈数学推导,直接动手用PyTorch构建一个完整的Self-Attention层,让你真正理解Q、K、V矩阵如何运作。
在开始编码前,我们需要明确几个关键概念:
这三个矩阵都来自同一个输入,只是通过不同的权重矩阵进行线性变换得到。这种设计允许模型灵活地学习输入序列中各个元素之间的关系。
python复制import torch
import torch.nn as nn
import torch.nn.functional as F
class SelfAttention(nn.Module):
def __init__(self, embed_size, heads):
super(SelfAttention, self).__init__()
self.embed_size = embed_size
self.heads = heads
self.head_dim = embed_size // heads
assert (
self.head_dim * heads == embed_size
), "Embedding size needs to be divisible by heads"
self.values = nn.Linear(self.head_dim, self.head_dim, bias=False)
self.keys = nn.Linear(self.head_dim, self.head_dim, bias=False)
self.queries = nn.Linear(self.head_dim, self.head_dim, bias=False)
self.fc_out = nn.Linear(heads * self.head_dim, embed_size)
让我们分解Self-Attention的前向传播过程。首先,输入张量的形状应该是(batch_size, sequence_length, embed_size)。我们需要将其分割为多个头(heads),然后分别计算Q、K、V。
python复制def forward(self, values, keys, query, mask):
N = query.shape[0] # 批大小
value_len, key_len, query_len = values.shape[1], keys.shape[1], query.shape[1]
# 分割嵌入维度为多个头
values = values.reshape(N, value_len, self.heads, self.head_dim)
keys = keys.reshape(N, key_len, self.heads, self.head_dim)
queries = query.reshape(N, query_len, self.heads, self.head_dim)
# 计算Q、K、V
values = self.values(values)
keys = self.keys(keys)
queries = self.queries(queries)
注意:在实际应用中,我们通常使用相同的输入作为values、keys和queries,这就是"Self"-Attention名称的由来——它关注输入序列内部的关系。
注意力分数的计算是Self-Attention的核心。我们通过矩阵乘法计算query和key的点积,然后缩放以避免梯度消失问题。
python复制# 计算注意力分数
energy = torch.einsum("nqhd,nkhd->nhqk", [queries, keys])
if mask is not None:
energy = energy.masked_fill(mask == 0, float("-1e20"))
# 应用softmax获取注意力权重
attention = torch.softmax(energy / (self.embed_size ** (1/2)), dim=3)
# 应用注意力权重到values上
out = torch.einsum("nhql,nlhd->nqhd", [attention, values])
out = out.reshape(N, query_len, self.heads * self.head_dim)
这里有几个关键点需要注意:
torch.einsum是爱因斯坦求和约定,用于高效实现多维张量运算embed_size ** (1/2)防止点积结果过大导致softmax梯度消失多头注意力的优势在于允许模型同时关注不同位置的不同表示子空间。我们需要将各个头的输出拼接起来,然后通过一个全连接层进行整合。
python复制# 拼接多头输出并通过全连接层
out = self.fc_out(out)
return out
完整的Self-Attention层现在可以像这样使用:
python复制embed_size = 256
heads = 8
attention = SelfAttention(embed_size, heads)
x = torch.rand((32, 10, embed_size)) # 批大小32,序列长度10,嵌入维度256
out = attention(x, x, x, mask=None)
print(out.shape) # torch.Size([32, 10, 256])
在实现Self-Attention时,你可能会遇到以下问题:
调试时可以关注这些关键张量的形状:
| 张量名称 | 预期形状 | 说明 |
|---|---|---|
| queries | (N, query_len, heads, head_dim) | 查询矩阵 |
| keys | (N, key_len, heads, head_dim) | 键矩阵 |
| values | (N, value_len, heads, head_dim) | 值矩阵 |
| energy | (N, heads, query_len, key_len) | 注意力分数 |
| output | (N, query_len, embed_size) | 最终输出 |
原始Self-Attention缺少位置信息,我们需要添加位置编码。以下是常用的正弦位置编码实现:
python复制class PositionalEncoding(nn.Module):
def __init__(self, d_model, max_len=5000):
super(PositionalEncoding, self).__init__()
pe = torch.zeros(max_len, d_model)
position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
pe[:, 0::2] = torch.sin(position * div_term)
pe[:, 1::2] = torch.cos(position * div_term)
pe = pe.unsqueeze(0)
self.register_buffer('pe', pe)
def forward(self, x):
return x + self.pe[:, :x.size(1)]
在实际Transformer中,Self-Attention通常与Layer Normalization和残差连接配合使用:
python复制class TransformerBlock(nn.Module):
def __init__(self, embed_size, heads, dropout, forward_expansion):
super(TransformerBlock, self).__init__()
self.attention = SelfAttention(embed_size, heads)
self.norm1 = nn.LayerNorm(embed_size)
self.norm2 = nn.LayerNorm(embed_size)
self.feed_forward = nn.Sequential(
nn.Linear(embed_size, forward_expansion * embed_size),
nn.ReLU(),
nn.Linear(forward_expansion * embed_size, embed_size)
)
self.dropout = nn.Dropout(dropout)
def forward(self, value, key, query, mask):
attention = self.attention(value, key, query, mask)
x = self.dropout(self.norm1(attention + query))
forward = self.feed_forward(x)
out = self.dropout(self.norm2(forward + x))
return out
让我们看一个简单的文本分类任务如何使用Self-Attention:
python复制class TextClassifier(nn.Module):
def __init__(self, vocab_size, embed_size, num_classes, heads):
super(TextClassifier, self).__init__()
self.embedding = nn.Embedding(vocab_size, embed_size)
self.position = PositionalEncoding(embed_size)
self.attention = SelfAttention(embed_size, heads)
self.fc = nn.Linear(embed_size, num_classes)
def forward(self, x):
embedded = self.embedding(x)
embedded = self.position(embedded)
attended = self.attention(embedded, embedded, embedded, None)
# 取序列第一个位置的输出作为分类依据
return self.fc(attended[:, 0, :])
这个简单的模型已经能够捕捉输入文本中词语间的重要关系。在实际项目中,你可以堆叠多个Self-Attention层,并加入更多技巧如:
实现Self-Attention最难的部分不是写代码,而是理解其背后的设计思想。当你亲手实现过几次后,那些看似复杂的矩阵运算会变得非常直观。我在第一次实现时最大的收获是理解了为什么需要三个不同的矩阵(Q、K、V),而不是直接用输入计算注意力——这种分离让模型能够更灵活地学习不同的关系模式。