想象一下你在教一个小朋友背古诗。如果让他从头到尾完整背诵,可能需要反复练习很多遍。但如果把诗句拆成几部分,同时让几个小朋友分别背诵不同段落,最后再组合起来,效率就会高很多。这就是Transformer模型采用Masked Multi-Head Attention实现并行训练的核心思路。
在传统的序列建模中(比如早期的RNN),模型必须像串珠子一样逐个处理每个单词。处理"I love you"时,必须先看完"I",才能处理"love",最后才能处理"you"。这种串行方式有两个致命缺陷:一是训练速度慢,二是误差会像滚雪球一样累积。就像多米诺骨牌,前面倒了一块,后面会跟着倒下一片。
而Masked Multi-Head Attention通过三个关键设计解决了这些问题:
让我们用"Hello world"翻译成"Bonjour le monde"的具体例子,看看掩码如何工作。假设输入序列的向量表示为[h1, h2],目标输出为[b1, b2, b3]。
在传统串行训练中:
这种模式下,b2的预测已经受到b1'误差的影响,b3更是累积了前两步的误差。
而使用掩码的并行训练:
python复制# 模拟三个并行训练步骤
step1_input = [h1,h2] # 掩码屏蔽全部输出
step2_input = [h1,h2,b1] # 掩码屏蔽b2,b3
step3_input = [h1,h2,b1,b2] # 掩码屏蔽b3
关键技巧在于:虽然实际操作是并行的,但通过掩码让每个位置"以为"自己是在顺序处理。就像给三个学生分别发考卷时,用遮挡板盖住他们不应该看到的部分。
具体到代码层面,最常见的实现方式是在计算注意力权重时,给需要屏蔽的位置赋值为负无穷:
python复制def scaled_dot_product_attention(Q, K, V, mask=None):
scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(d_k)
if mask is not None:
scores = scores.masked_fill(mask == 0, -1e9) # 屏蔽位置设为极小数
p_attn = F.softmax(scores, dim=-1)
return torch.matmul(p_attn, V)
这种实现有三大优势:
当模型部署上线后,游戏规则突然改变:现在没有标准答案可供参考了。就像考试时不再提供参考答案,必须自己一步步推导。这时候Masked Multi-Head Attention展现出另一面重要价值——实现自回归生成。
以文本续写任务为例,生成"今天天气真好"的过程:
虽然每次还是计算整个序列的注意力,但通过掩码确保每个位置只能看到当前位置及之前的信息。这就好比写作文时,虽然你可以随时回顾前面写的内容,但无法预知后面要写什么。
很多初学者会困惑:这看起来不又回到RNN的串行模式了吗?关键差异在于:
这种设计使得Transformer在生成长文本时,依然能保持前后一致性。就像虽然每次只写一个字,但作者始终把握着整篇文章的脉络。
单头注意力就像只用一种颜色的荧光笔标记文本,而多头机制相当于同时使用多种颜色标记不同重点。举个例子,在分析句子"这个苹果很好吃"时:
实验表明,8个头左右的配置通常能在计算成本和模型性能间取得较好平衡。每个头会学习不同的注意力模式:
| 头编号 | 主要关注特征 | 适用任务示例 |
|---|---|---|
| 头1 | 局部语法结构 | 词性标注 |
| 头2 | 长程依赖 | 指代消解 |
| 头3 | 语义角色 | 关系抽取 |
当多头遇上掩码,需要注意几个工程细节:
一个典型的多头实现如下:
python复制class MultiHeadAttention(nn.Module):
def __init__(self, h, d_model):
super().__init__()
self.d_k = d_model // h
self.h = h
self.linears = clones(nn.Linear(d_model, d_model), 4)
def forward(self, query, key, value, mask=None):
if mask is not None:
mask = mask.unsqueeze(1) # 增加头维度
nbatches = query.size(0)
# 线性变换后分割多头
query, key, value = [
lin(x).view(nbatches, -1, self.h, self.d_k).transpose(1, 2)
for lin, x in zip(self.linears, (query, key, value))
]
# 计算注意力
x = scaled_dot_product_attention(query, key, value, mask)
# 合并多头结果
x = x.transpose(1, 2).contiguous().view(nbatches, -1, self.h * self.d_k)
return self.linears[-1](x)
在真实项目中配置Masked Multi-Head Attention时,这些经验可能帮你少走弯路:
学习率设置:由于并行训练的特性,初始学习率可以比RNN大2-5倍。但需要使用warmup策略,比如先用500步从0线性增加到目标学习率。
掩码初始化:对于生成任务,建议在验证集上测试不同掩码强度。有时适当放宽掩码限制(如让当前位置能看到前2-3个未来token)能提升生成流畅度。
内存优化:当处理长序列时,可以采用块稀疏注意力模式。比如将序列分成若干块,只在块内和部分跨块位置计算注意力。