视觉Transformer(ViT)的核心思想是将标准Transformer架构直接应用于图像数据。与CNN不同,ViT完全摒弃了卷积操作,采用纯注意力机制处理图像。这种设计带来了几个显著特点:
首先,ViT将输入图像分割为固定大小的patch(通常为16×16像素),每个patch经过线性投影后成为类似NLP中的token。这种处理方式彻底改变了传统计算机视觉的范式——不再通过滑动窗口提取局部特征,而是将图像视为一个token序列。
具体实现时,假设输入图像大小为224×224×3(高×宽×通道),使用16×16的patch大小,则会得到196个patch(224/16=14,14×14=196)。每个patch被展平为16×16×3=768维向量,这正是ViT-Base模型采用的维度。
模型的关键组件包括:
python复制class PatchEmbed(nn.Module):
def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768):
super().__init__()
self.proj = nn.Conv2d(in_chans, embed_dim,
kernel_size=patch_size,
stride=patch_size)
def forward(self, x):
x = self.proj(x) # [B, C, H, W] -> [B, D, H/P, W/P]
x = x.flatten(2).transpose(1, 2) # [B, D, N] -> [B, N, D]
return x
Patch Embedding模块负责将2D图像转换为1D token序列。虽然可以使用卷积操作实现(如上代码所示),但实际应用中需要注意几个关键点:
实测发现,对于224×224输入图像,16×16分块在准确率和计算效率上达到较好平衡。当处理更高分辨率图像时(如384×384),保持相同patch大小会显著增加序列长度,需要相应调整模型深度。
ViT中的位置编码是可学习的参数,而非Transformer原版的正弦函数。这种设计带来了几个优势:
python复制class VisionTransformer(nn.Module):
def __init__(self, img_size=224, patch_size=16, num_classes=1000):
super().__init__()
self.patch_embed = PatchEmbed(img_size, patch_size)
num_patches = (img_size // patch_size) ** 2
self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
def forward(self, x):
B = x.shape[0]
x = self.patch_embed(x) # [B, N, D]
cls_tokens = self.cls_token.expand(B, -1, -1)
x = torch.cat((cls_tokens, x), dim=1)
x = x + self.pos_embed
return x
位置编码的可视化显示,ViT确实学会了表示2D空间关系——相邻patch的位置编码相似度高,同行/列的patch也表现出明显的相关性。
ViT的Encoder与原始Transformer基本相同,但针对视觉任务做了以下优化:
每个Transformer Block的实现如下:
python复制class Block(nn.Module):
def __init__(self, dim, num_heads, mlp_ratio=4.):
super().__init__()
self.norm1 = nn.LayerNorm(dim)
self.attn = Attention(dim, num_heads)
self.norm2 = nn.LayerNorm(dim)
self.mlp = Mlp(dim, int(dim * mlp_ratio))
def forward(self, x):
x = x + self.attn(self.norm1(x))
x = x + self.mlp(self.norm2(x))
return x
在实际应用中,发现以下经验性结论:
结合上述模块,完整的ViT实现需要考虑以下组件:
python复制class ViT(nn.Module):
def __init__(self, depth=12, num_heads=12, mlp_ratio=4.):
super().__init__()
self.patch_embed = PatchEmbed()
self.blocks = nn.ModuleList([
Block(embed_dim, num_heads, mlp_ratio)
for _ in range(depth)
])
self.norm = nn.LayerNorm(embed_dim)
self.head = nn.Linear(embed_dim, num_classes)
def forward(self, x):
x = self.patch_embed(x)
for blk in self.blocks:
x = blk(x)
x = self.norm(x[:, 0]) # 取cls token
return self.head(x)
经过多次实验,总结出以下实用技巧:
对于不同规模的数据集,建议配置:
在ImageNet上从头训练ViT-Base约需300epoch达到80%+准确率,使用预训练权重可大幅缩短微调时间。