想象一下你在教一个小朋友做数学题。最开始教1+1=2很简单,但随着题目难度增加,比如要计算(1+1)*3-5/2,小朋友可能会因为步骤太多而忘记中间结果。深度神经网络也面临类似问题——当网络层数越来越多时,信息在传递过程中就像经过太多人的传话游戏,最终可能面目全非。
2015年之前,研究人员发现一个奇怪现象:56层的神经网络识别准确率居然比20层的还要低。这完全违背了"越深越好"的直觉,问题就出在梯度消失上。反向传播时,梯度要经过几十层网络,就像用微弱的信号穿越长长的隧道,到达起点时几乎衰减为零。残差连接的发明者何恺明用了个巧妙的办法:在每层旁边修条"捷径",让信号可以跳过某些层直达后方。
传统神经网络层可以表示为y=F(x),而残差块的设计是y=F(x)+x。这个简单的加法操作带来了革命性变化:
python复制# 普通卷积块
def forward(x):
return conv2d(relu(conv2d(x)))
# 残差块
def forward(x):
return conv2d(relu(conv2d(x))) + x # 关键就在这个加法
用装修房子来类比:传统网络要求工人直接装修出理想中的房子,而残差网络只要求工人给出"装修方案"(F(x)),实际效果是原始房子(x)加上装修变化量。当某些层不需要修改时,F(x)可以轻松学习为0,保留原始信息。
反向传播时,导数计算遵循链式法则。对于残差连接,梯度有两条回流路径:
即使常规路径的梯度变得极小,捷径路径也能保证至少∂L/∂y的梯度能直接传回。这就像在陡坡上修建之字形公路的同时,保留了一条直通的紧急通道。
ResNet34的残差块结构值得仔细研究:
python复制class ResidualBlock(nn.Module):
def __init__(self, in_channels, out_channels, stride=1):
super().__init__()
self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1)
self.bn1 = nn.BatchNorm2d(out_channels)
self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1)
self.bn2 = nn.BatchNorm2d(out_channels)
# 当输入输出维度不一致时需要投影
self.shortcut = nn.Sequential()
if stride != 1 or in_channels != out_channels:
self.shortcut = nn.Sequential(
nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride),
nn.BatchNorm2d(out_channels)
)
def forward(self, x):
out = F.relu(self.bn1(self.conv1(x)))
out = self.bn2(self.conv2(out))
out += self.shortcut(x) # 关键操作
return F.relu(out)
实际使用时要注意几个细节:
原始Transformer论文中的残差连接稍有不同:
python复制# 编码器层的实现示例
class TransformerLayer(nn.Module):
def __init__(self, d_model, nhead):
super().__init__()
self.self_attn = nn.MultiheadAttention(d_model, nhead)
self.linear = nn.Linear(d_model, d_model)
self.norm1 = nn.LayerNorm(d_model)
self.norm2 = nn.LayerNorm(d_model)
def forward(self, src):
# 第一处残差连接
src2 = self.self_attn(src, src, src)[0]
src = self.norm1(src + src2)
# 第二处残差连接
src2 = self.linear(src)
src = self.norm2(src + src2)
return src
与CNN的主要区别:
传统神经网络常用He初始化或Xavier初始化,但加入残差连接后需要特别注意:
python复制# 正确的初始化方式
def _init_weights(self):
for m in self.modules():
if isinstance(m, nn.Conv2d):
nn.init.kaiming_normal_(m.weight, mode='fan_out')
if m.bias is not None:
nn.init.zeros_(m.bias)
elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
nn.init.ones_(m.weight)
nn.init.zeros_(m.bias)
elif isinstance(m, nn.Linear):
nn.init.normal_(m.weight, 0, 0.01)
nn.init.zeros_(m.bias)
# 特别处理残差块最后一层
nn.init.zeros_(self.conv2.weight)
虽然残差连接缓解了梯度消失,但可能引发梯度爆炸。实践中发现:
我在训练ResNet50时遇到过典型情况:当学习率设为0.1时,前几个epoch的loss出现NaN,调整到0.01后训练过程就稳定了。
残差连接的思想正在衍生出更多变体:
最近的一个有趣发现是残差连接实际上在隐式地进行模型集成。每个残差块可以看作是一个浅层网络,整个系统相当于多个网络的加权组合。这解释了为什么残差网络对超参数不那么敏感。
在部署到移动端时,我们还发现残差结构有个意外优势:由于大部分层的输出变化不大,可以跳过某些块的计算(类似early exiting),实际推理速度比理论FLOPs要快20-30%。