我第一次接触自注意力机制是在处理一个中文分词项目时。当时团队尝试用传统RNN模型做词性标注(POS tagging),结果发现模型对长距离依赖关系处理得非常糟糕。比如"他把苹果放进冰箱"这句话,"苹果"作为名词的识别准确率会受到前面动词"把"的影响,但传统RNN在传递到第五个词时已经丢失了大量前面信息。这促使我开始研究self-attention这个神奇的结构。
Sequence Labeling任务可以看作是最早催生自注意力机制的土壤。这类任务的特点是输入输出序列等长,每个位置都需要结合上下文做出决策。传统解决方案是用滑动窗口或RNN,但都存在明显缺陷:
自注意力机制的革命性在于,它让每个位置都能直接"看到"序列所有位置,并通过可学习的权重动态决定关注哪些上下文。举个例子,在分析"银行账户余额"时,"银行"这个词的语义理解需要同时关注后面的"账户"(更相关)和前面的语境(可能相关性较低)。self-attention通过计算query-key-value的三元组,完美实现了这种动态注意力分配。
自注意力最精妙的设计在于其数学表达的简洁性。核心计算可以分解为三个步骤:
用Python代码表示核心计算过程:
python复制import torch
def self_attention(Q, K, V):
# Q,K,V shape: (batch_size, seq_len, d_model)
scores = torch.matmul(Q, K.transpose(-2, -1)) / torch.sqrt(d_model)
weights = torch.softmax(scores, dim=-1)
return torch.matmul(weights, V)
这种设计有三大优势:
单头注意力就像只用一只眼睛看世界,而多头机制相当于给了模型多组"视觉系统"。在实践中,我发现设置8个头效果通常不错:
| 头数量 | 训练速度 | 准确率 | 适用场景 |
|---|---|---|---|
| 1 | 最快 | 最低 | 简单任务 |
| 4 | 较快 | 中等 | 中等复杂度 |
| 8 | 标准 | 较高 | 主流选择 |
| 16+ | 较慢 | 可能过拟合 | 超大模型 |
多头注意力的核心思想是让不同头学习不同的注意力模式。比如在翻译任务中,有的头可能专注主语-动词关系,有的头关注时间状语,还有的头捕捉形容词修饰关系。这种"分而治之"的策略大幅提升了模型容量。
早期我犯过一个错误:直接去掉位置编码训练Transformer,结果模型完全无法理解语句顺序。这让我意识到,纯自注意力需要显式的位置信息注入。常见的位置编码方案包括:
python复制PE(pos,2i) = sin(pos/10000^(2i/d_model))
PE(pos,2i+1) = cos(pos/10000^(2i/d_model))
在图像处理中,我还尝试过将二维坐标信息编码进去,效果比单纯展平图像要好很多。
Transformer的另一个精妙设计是残差结构。有次我尝试去掉残差连接,模型深度超过6层后性能急剧下降。这是因为:
实际配置时要注意:先LayerNorm再进入注意力层的效果通常更好,这与原始论文的Pre-LN设计一致。
在文本分类任务上,我做过一组对比实验:
| 模型类型 | 准确率 | 训练速度(样本/秒) | 显存占用 |
|---|---|---|---|
| CNN(kernel=3) | 88.2% | 1200 | 2.1GB |
| CNN(kernel=5) | 88.7% | 950 | 2.3GB |
| Self-Attention | 90.1% | 800 | 3.5GB |
| Hybrid模型 | 90.3% | 700 | 3.8GB |
发现当数据量较小时,CNN仍有优势;但当数据量超过10万条时,自注意力开始显现威力。后来我们开发了一个混合架构,在浅层用CNN捕捉局部特征,深层用自注意力处理全局关系,取得了最佳平衡。
在序列标注任务中,RNN家族(LSTM/GRU)与自注意力的关键差异:
长距离依赖:
计算效率:
内存占用:
实际项目中,对于实时性要求高的短文本处理,我们仍会考虑GRU;但对于质量优先的长文本场景,Transformer已成标配。
d_model(模型维度)的设置需要权衡:
经验公式:
code复制d_model ≈ 4 × (预期最大序列长度)^(1/2)
例如处理512长度的文本时,设置d_model=64-128比较合适。同时要保证:
处理变长序列时,mask的使用容易出错。正确的做法是:
python复制# 创建padding mask
mask = (x != pad_idx).unsqueeze(1) # [batch, 1, seq_len]
# 创建因果mask(解码器用)
seq_len = x.size(1)
causal_mask = torch.tril(torch.ones(seq_len, seq_len)).bool()
在混合精度训练时,记得将mask转换为与注意力分数相同的dtype,否则可能引发数值溢出。
当第一次将Vision Transformer应用到图像分类时,效果令人惊艳。与CNN相比:
一个有趣的发现:浅层注意力头往往关注局部纹理,深层头则捕捉全局语义关系。
在语音识别任务中,原始自注意力需要调整:
解决方案:
在我们的实验中,Conformer(CNN+Transformer混合)模型将WER降低了23%。
尽管自注意力表现出色,但在以下场景仍需谨慎:
最近我们在处理法律文书时,就遇到了万字符级别的长文档挑战。最终采用的方案是:
这种方案在保持95%准确率的同时,将显存占用降低了70%。