Swin Transformer作为计算机视觉领域的重要突破,其核心创新在于将传统Transformer的全局注意力机制改进为基于滑动窗口的局部注意力。这种设计巧妙地结合了CNN的局部性和Transformer的全局建模能力。我第一次在实际项目中尝试Swin Transformer时,最直观的感受就是它比传统ViT模型更高效,特别是在处理高分辨率图像时优势明显。
滑动窗口机制包含两种关键操作:W-MSA(Window Multi-head Self-Attention)和SW-MSA(Shifted Window Multi-head Self-Attention)。W-MSA将图像划分为不重叠的局部窗口,在每个窗口内计算自注意力。这种设计将计算复杂度从图像尺寸的平方降低到线性关系,使得模型能够处理更大尺寸的输入。而SW-MSA则通过窗口偏移操作,在不同层之间建立跨窗口连接,有效解决了局部窗口带来的信息隔离问题。
层级下采样(Patch Merging)是另一个精妙设计。它类似于CNN中的池化操作,但实现方式更加灵活。通过四个stage的逐步下采样,模型能够构建多尺度特征表示,这对于目标检测、语义分割等需要多尺度信息的任务尤为重要。我在实际使用中发现,这种层级结构特别适合处理不同尺度的视觉对象。
Patch Embedding是Swin Transformer处理图像输入的第一道工序,它的作用是将二维图像转换为适合Transformer处理的一维序列。这个过程的代码实现看似简单,但包含了许多值得注意的细节。
python复制class PatchEmbed(nn.Module):
def __init__(self, patch_size=4, in_c=3, embed_dim=96, norm_layer=None):
super().__init__()
self.patch_size = (patch_size, patch_size)
self.proj = nn.Conv2d(in_c, embed_dim,
kernel_size=patch_size,
stride=patch_size)
self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
def forward(self, x):
_, _, H, W = x.shape
# 处理非整数倍尺寸的padding
if H % self.patch_size[0] != 0 or W % self.patch_size[1] != 0:
x = F.pad(x, (0, self.patch_size[1] - W % self.patch_size[1],
0, self.patch_size[0] - H % self.patch_size[0],
0, 0))
x = self.proj(x)
_, _, H, W = x.shape
x = x.flatten(2).transpose(1, 2)
x = self.norm(x)
return x, H, W
这段代码有几个关键点值得关注:首先,它使用卷积操作实现patch划分,kernel_size和stride都设置为patch_size,这种实现方式比传统的分割+展平操作更高效。其次,forward方法中包含了自动padding处理,这保证了无论输入图像尺寸如何,都能被正确划分为整数个patch。我在实际项目中就遇到过因为忽略padding而导致模型崩溃的情况,这个细节处理非常实用。
输出维度方面,Patch Embedding会将输入图像从[B,C,H,W]转换为[B,L,C]的形式,其中L=H/patch_size * W/patch_size。这种表示方式既保留了空间信息的相对位置关系,又适合后续的Transformer处理。
滑动窗口注意力是Swin Transformer最具创新性的部分,其PyTorch实现包含了许多精妙的设计选择。WindowAttention类实现了核心的窗口注意力计算,其中相对位置编码的处理尤为关键。
python复制class WindowAttention(nn.Module):
def __init__(self, dim, window_size, num_heads, qkv_bias=True, attn_drop=0., proj_drop=0.):
super().__init__()
self.dim = dim
self.window_size = window_size
self.num_heads = num_heads
self.scale = (dim // num_heads) ** -0.5
# 相对位置编码表
self.relative_position_bias_table = nn.Parameter(
torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads))
# 相对位置索引
coords_h = torch.arange(window_size[0])
coords_w = torch.arange(window_size[1])
coords = torch.stack(torch.meshgrid([coords_h, coords_w], indexing="ij"))
coords_flatten = torch.flatten(coords, 1)
relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]
relative_coords = relative_coords.permute(1, 2, 0).contiguous()
relative_coords[:, :, 0] += window_size[0] - 1
relative_coords[:, :, 1] += window_size[1] - 1
relative_coords[:, :, 0] *= 2 * window_size[1] - 1
relative_position_index = relative_coords.sum(-1)
self.register_buffer("relative_position_index", relative_position_index)
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
self.attn_drop = nn.Dropout(attn_drop)
self.proj = nn.Linear(dim, dim)
self.proj_drop = nn.Dropout(proj_drop)
self.softmax = nn.Softmax(dim=-1)
相对位置编码的实现有几个巧妙之处:首先,它使用了一个可学习的相对位置偏置表(relative_position_bias_table),而不是固定的正弦编码。其次,通过精心设计的索引计算,将二维相对位置映射到一维的偏置表中,这种实现既节省内存又高效。在实际应用中,我发现这种相对位置编码对小目标检测特别有帮助。
注意力计算部分采用了标准的QKV形式,但加入了相对位置偏置。这种设计使得模型能够感知patch之间的相对位置关系,同时保持了平移等变性。我在自定义窗口大小时曾遇到过注意力权重不稳定的问题,后来发现是因为忽略了scale因子的调整,这点需要特别注意。
Swin Transformer的层级结构通过Patch Merging实现,这是模型能够处理多尺度信息的关键。Patch Merging的操作类似于CNN中的池化层,但实现方式更具Transformer特色。
python复制class PatchMerging(nn.Module):
def __init__(self, dim, norm_layer=nn.LayerNorm):
super().__init__()
self.dim = dim
self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
self.norm = norm_layer(4 * dim)
def forward(self, x, H, W):
B, L, C = x.shape
assert L == H * W, "input feature has wrong size"
x = x.view(B, H, W, C)
# 处理奇数尺寸的padding
if H % 2 == 1 or W % 2 == 1:
x = F.pad(x, (0, 0, 0, W % 2, 0, H % 2))
# 间隔采样并拼接
x0 = x[:, 0::2, 0::2, :]
x1 = x[:, 1::2, 0::2, :]
x2 = x[:, 0::2, 1::2, :]
x3 = x[:, 1::2, 1::2, :]
x = torch.cat([x0, x1, x2, x3], -1)
x = x.view(B, -1, 4 * C)
x = self.norm(x)
x = self.reduction(x)
return x
Patch Merging的实现有几个值得注意的细节:首先,它采用间隔采样的方式将2x2邻域的特征图拼接起来,这比直接使用最大池化保留了更多信息。其次,通过线性变换将通道数从4C降到2C,实现了特征压缩。我在实验中发现,这种设计比简单的池化操作能更好地保留空间信息。
层级结构的另一个关键点是BasicLayer的实现,它组合了多个Swin Transformer Block和一个可选的Patch Merging层。这种设计使得模型能够在不同尺度上建立远程依赖关系,对于密集预测任务特别有效。在实际部署时,可以根据任务需求灵活调整各stage的深度和通道数。
将各个组件集成为完整的Swin Transformer模型时,有许多实践经验值得分享。模型初始化、深度衰减率设置等细节都会显著影响最终性能。
python复制class SwinTransformer(nn.Module):
def __init__(self, patch_size=4, in_chans=3, num_classes=1000, embed_dim=96,
depths=(2, 2, 6, 2), num_heads=(3, 6, 12, 24), window_size=7,
mlp_ratio=4., qkv_bias=True, drop_rate=0., attn_drop_rate=0.,
drop_path_rate=0.1, norm_layer=nn.LayerNorm, patch_norm=True):
super().__init__()
self.num_layers = len(depths)
self.embed_dim = embed_dim
self.patch_norm = patch_norm
# stochastic depth衰减规则
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]
# 构建各stage
self.layers = nn.ModuleList()
for i_layer in range(self.num_layers):
layers = BasicLayer(
dim=int(embed_dim * 2 ** i_layer),
depth=depths[i_layer],
num_heads=num_heads[i_layer],
window_size=window_size,
mlp_ratio=mlp_ratio,
qkv_bias=qkv_bias,
drop=drop_rate,
attn_drop=attn_drop_rate,
drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])],
norm_layer=norm_layer,
downsample=PatchMerging if (i_layer < self.num_layers - 1) else None)
self.layers.append(layers)
self.norm = norm_layer(self.num_features)
self.avgpool = nn.AdaptiveAvgPool1d(1)
self.head = nn.Linear(self.num_features, num_classes)
self.apply(self._init_weights)
def _init_weights(self, m):
if isinstance(m, nn.Linear):
nn.init.trunc_normal_(m.weight, std=.02)
if m.bias is not None:
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.LayerNorm):
nn.init.constant_(m.bias, 0)
nn.init.constant_(m.weight, 1.0)
模型初始化采用截断正态分布,这对训练稳定性很重要。drop path rate采用线性递增策略,浅层使用较小的drop率,深层使用较大的drop率,这种设计符合深度网络的训练特性。我在实验中发现,合理设置drop path rate可以显著提升模型性能,特别是在数据量不足的情况下。
实际部署时,window_size的选择需要权衡计算效率和模型性能。较大的窗口能捕获更长距离的依赖关系,但会显著增加计算量。对于图像分类任务,7x7的窗口通常是不错的选择;而对于目标检测任务,可能需要尝试更大的窗口尺寸。