在计算机视觉领域,ResNet架构因其简洁的模块化设计成为众多下游任务的首选主干网络。然而,传统ResNet在跨通道交互和特征注意力机制上存在明显局限。本文将手把手教你用PyTorch实现ResNeSt的核心创新——Split-Attention模块,并展示如何将其无缝集成到目标检测和语义分割模型中。
Split-Attention的核心思想是通过多层次特征分组与动态权重分配,实现更精细的跨通道交互。与SE-Net等传统注意力机制不同,它在两个维度上对特征进行分解:
这种双重划分创造了K×R个特征子空间,每个子空间都能学习独特的特征表示。模块通过全局上下文信息动态计算各子空间的注意力权重,实现特征的自适应融合。
关键计算公式如下:
python复制# 基数组表示计算(沿splits维度求和)
group_rep = sum(split_rep for split_rep in splits)
# 注意力权重计算(基于全局池化)
gap = F.adaptive_avg_pool2d(group_rep, (1, 1))
attention = softmax(FC(FC(gap))) # 两层全连接+softmax
# 加权特征融合
weighted_rep = sum(attention[i] * splits[i] for i in range(radix))
下面我们逐步构建完整的Split-Attention模块:
python复制import torch
import torch.nn as nn
import torch.nn.functional as F
class RadixSoftmax(nn.Module):
def __init__(self, radix, cardinality):
super().__init__()
self.radix = radix
self.cardinality = cardinality
def forward(self, x):
batch = x.size(0)
if self.radix > 1:
x = x.view(batch, self.cardinality, self.radix, -1).transpose(1, 2)
x = F.softmax(x, dim=1)
x = x.reshape(batch, -1)
else:
x = torch.sigmoid(x)
return x
class SplitAttention(nn.Module):
def __init__(self, in_channels, out_channels=None, kernel_size=3,
stride=1, radix=2, reduction_factor=4):
super().__init__()
out_channels = out_channels or in_channels
self.radix = radix
mid_channels = in_channels * radix
# 特征变换层
self.conv = nn.Conv2d(
in_channels, mid_channels, kernel_size,
stride=stride, padding=kernel_size//2,
groups=radix, bias=False)
# 注意力计算路径
attn_channels = max(in_channels // reduction_factor, 32)
self.fc1 = nn.Conv2d(out_channels, attn_channels, 1)
self.fc2 = nn.Conv2d(attn_channels, mid_channels, 1)
self.rsoftmax = RadixSoftmax(radix, 1)
def forward(self, x):
x = self.conv(x)
# 计算注意力权重
B, C, H, W = x.shape
if self.radix > 1:
splits = x.view(B, self.radix, C//self.radix, H, W)
gap = splits.sum(dim=1)
else:
gap = x
gap = F.adaptive_avg_pool2d(gap, 1)
attn = self.fc1(gap)
attn = self.fc2(attn)
attn = self.rsoftmax(attn).view(B, -1, 1, 1)
# 应用注意力
if self.radix > 1:
attn = attn.view(B, self.radix, C//self.radix, 1, 1)
out = (splits * attn).sum(dim=1)
else:
out = x * attn
return out
提示:实际应用中,radix通常设置为2-4之间,过大的值会增加计算量但收益递减
将Split-Attention嵌入到标准的残差块结构中:
python复制class ResNeStBottleneck(nn.Module):
expansion = 4
def __init__(self, inplanes, planes, stride=1, radix=2):
super().__init__()
width = planes * (64 // 64) # base_width=64
# 1x1降维
self.conv1 = nn.Conv2d(inplanes, width, 1, bias=False)
self.bn1 = nn.BatchNorm2d(width)
# 3x3 SplitAttention卷积
self.conv2 = SplitAttention(
width, width, kernel_size=3,
stride=stride, radix=radix)
self.bn2 = nn.BatchNorm2d(width)
# 1x1升维
self.conv3 = nn.Conv2d(width, planes*self.expansion, 1, bias=False)
self.bn3 = nn.BatchNorm2d(planes*self.expansion)
# 下采样路径
self.downsample = None
if stride != 1 or inplanes != planes*self.expansion:
self.downsample = nn.Sequential(
nn.Conv2d(inplanes, planes*self.expansion,
1, stride=stride, bias=False),
nn.BatchNorm2d(planes*self.expansion))
self.relu = nn.ReLU(inplace=True)
def forward(self, x):
identity = x
out = self.conv1(x)
out = self.bn1(out)
out = self.relu(out)
out = self.conv2(out)
out = self.bn2(out)
out = self.relu(out)
out = self.conv3(out)
out = self.bn3(out)
if self.downsample is not None:
identity = self.downsample(x)
out += identity
out = self.relu(out)
return out
关键参数配置建议:
| 参数 | 典型值 | 作用说明 |
|---|---|---|
| radix | 2 | 每个基数组内的划分数量 |
| cardinality | 1 | 基数分组数(类似ResNeXt) |
| reduction_factor | 4 | 注意力路径的通道缩减比例 |
以Faster R-CNN为例,替换主干网络的步骤:
python复制from torchvision.models.detection import FasterRCNN
from torchvision.models.detection.backbone_utils import resnet_fpn_backbone
# 使用ResNeSt作为骨干网络
backbone = resnet_fpn_backbone(
'resnest101',
pretrained=True,
trainable_layers=5)
model = FasterRCNN(backbone, num_classes=91)
以DeepLabV3+为例的集成方案:
python复制from torchvision.models.segmentation import deeplabv3_resnet101
model = deeplabv3_resnet101(pretrained=False, num_classes=21)
# 替换骨干网络为ResNeSt
from timm.models import resnest
backbone = resnest.resnest101d(pretrained=True)
model.backbone = backbone
性能优化要点:
在COCO数据集上的性能表现:
| 模型 | mAP@0.5 | 参数量(M) | FLOPS(G) |
|---|---|---|---|
| Faster R-CNN (ResNet50) | 37.4 | 41.5 | 180.5 |
| Faster R-CNN (ResNeSt50) | 41.2 | 42.3 | 182.1 |
| Faster R-CNN (ResNeSt101) | 43.7 | 63.6 | 315.4 |
在ADE20K语义分割上的表现:
| 模型 | mIoU (%) | 训练周期 |
|---|---|---|
| DeepLabV3+ (ResNet101) | 45.7 | 80 |
| DeepLabV3+ (ResNeSt101) | 48.2 | 80 |
python复制# 余弦退火+线性warmup
scheduler = torch.optim.lr_scheduler.OneCycleLR(
optimizer,
max_lr=0.1,
steps_per_epoch=len(train_loader),
epochs=100)
python复制model = Model(
...
drop_path_rate=0.2, # Stochastic Depth
norm_layer=partial(nn.BatchNorm2d, eps=1e-5, momentum=0.1)
)
python复制from torch.utils.checkpoint import checkpoint
def forward(self, x):
return checkpoint(self._forward, x)
在多个实际项目中,ResNeSt相比传统ResNet通常能带来1.5-3%的mAP提升,而计算开销仅增加约5-8%。特别是在小目标检测和细粒度分割任务上,Split-Attention机制展现出明显优势。