In computer vision, attention mechanisms have become a core component of modern vision Transformer architectures. BiFormer, presented at CVPR 2023, strikes a notable balance between computational efficiency and model performance through its novel **Bi-Level Routing Attention (BRA)** mechanism. This article walks through a from-scratch PyTorch implementation of this attention mechanism and digs into the engineering details.
The core innovation of BRA is its dynamically sparse attention design. Unlike a standard Transformer, where every query attends to every key-value pair, BRA uses a two-level routing strategy:

- Region level: the feature map is partitioned into regions, and a coarse region-to-region affinity matrix selects, for each region, the top-k regions most worth attending to.
- Token level: fine-grained token-to-token attention is then computed only over the key-value pairs gathered from those selected regions.
This design brings three major advantages:

- Content-aware sparsity: the attended regions are selected per query region rather than fixed by a hand-crafted pattern.
- Lower computation and memory cost: each query only attends to a small, routed subset of key-value pairs.
- Hardware friendliness: the remaining computation consists of dense matrix multiplications that map well onto GPUs.
```python
# Pseudocode for the BRA computation flow
def bra_forward(x):
    # input x: (B, H, W, C)
    # 1. region partition and projection
    q, kv = qkv_projection(x)                   # queries and packed key-values
    # 2. region-level routing
    region_q = avg_pool(q)                      # region-level queries
    region_k = avg_pool(kv[..., :qk_dim])       # region-level keys
    affinity = region_q @ region_k.transpose()  # region-to-region affinity
    topk_idx = topk(affinity)                   # top-k neighbour regions per region
    # 3. token-level attention
    gathered_kv = gather(kv, topk_idx)          # gather the routed key-values
    output = attention(q, gathered_kv)          # local attention within routed regions
    return output
```
The routing module is the heart of BRA: it decides which other regions each region should attend to. Its implementation needs to account for the following engineering details:
```python
import torch
import torch.nn as nn


class TopkRouting(nn.Module):
    def __init__(self, qk_dim, topk=4, qk_scale=None, diff_routing=True):
        super().__init__()
        self.topk = topk
        self.scale = qk_scale or qk_dim ** -0.5
        self.diff_routing = diff_routing
        # learnable projection to strengthen the routing features
        self.proj = nn.Linear(qk_dim, qk_dim)
        self.act = nn.Softmax(dim=-1)

    def forward(self, query, key):
        """
        Inputs:  query / key - (B, num_regions, qk_dim)
        Outputs:
            routing_weights: (B, num_regions, topk)
            topk_indices:    (B, num_regions, topk)
        """
        if not self.diff_routing:
            # stop gradients so routing is not trained end-to-end
            query, key = query.detach(), key.detach()
        # enhance the feature representation
        query, key = self.proj(query), self.proj(key)
        # region-to-region affinity
        affinity = (query * self.scale) @ key.transpose(-2, -1)
        # top-k selection
        weights, indices = torch.topk(affinity, k=self.topk, dim=-1)
        return self.act(weights), indices
```
Tip: for actual deployment, consider replacing the topk operation with sparse matrix operations to further reduce memory usage.
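Before moving on, a quick shape check can confirm the routing behaves as expected; the batch size, region count, and qk_dim below are arbitrary values chosen only for illustration:

```python
# Illustrative shape check for TopkRouting (sizes are arbitrary)
B, num_regions, qk_dim = 2, 49, 32
router = TopkRouting(qk_dim=qk_dim, topk=4)

region_q = torch.randn(B, num_regions, qk_dim)
region_k = torch.randn(B, num_regions, qk_dim)

weights, indices = router(region_q, region_k)
print(weights.shape, indices.shape)  # torch.Size([2, 49, 4]) torch.Size([2, 49, 4])
```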
Once routing is done, the relevant key-value pairs have to be gathered efficiently before attention is computed. There are two key optimization points here: collecting the key-value pairs of all routed regions with a single `gather` call, and folding the routing weights back in as a soft re-weighting of the gathered key-values:
```python
class KVGather(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, r_idx, r_weight, kv):
        """
        Inputs:
            r_idx:    (B, num_regions, topk)
            r_weight: (B, num_regions, topk)
            kv:       (B, num_regions, tokens_per_region, dim)
        Output:
            gathered_kv: (B, num_regions, topk, tokens_per_region, dim)
        """
        B, N, K = r_idx.shape
        _, _, T, D = kv.shape
        # expand kv so that every query region can index every key region,
        # then gather along the key-region dimension in a single call
        kv_expanded = kv.unsqueeze(1).expand(-1, N, -1, -1, -1)
        expanded_idx = r_idx.view(B, N, K, 1, 1).expand(-1, -1, -1, T, D)
        gathered = kv_expanded.gather(2, expanded_idx)
        # soft re-weighting with the routing weights
        return r_weight.view(B, N, K, 1, 1) * gathered


class BRAttention(nn.Module):
    """Standard multi-head self-attention over a (B, H, W, C) feature map."""
    def __init__(self, dim, heads=8):
        super().__init__()
        self.dim = dim
        self.heads = heads
        self.scale = (dim // heads) ** -0.5
        # QKV projection and output layer
        self.to_qkv = nn.Linear(dim, dim * 3)
        self.to_out = nn.Linear(dim, dim)

    def forward(self, x):
        B, H, W, C = x.shape
        q, k, v = self.to_qkv(x).chunk(3, dim=-1)
        # split heads: (B, heads, H*W, C // heads)
        q = q.view(B, H * W, self.heads, -1).transpose(1, 2)
        k = k.view(B, H * W, self.heads, -1).transpose(1, 2)
        v = v.view(B, H * W, self.heads, -1).transpose(1, 2)
        # scaled dot-product attention
        attn = (q @ k.transpose(-2, -1)) * self.scale
        attn = attn.softmax(dim=-1)
        out = attn @ v
        # merge heads and restore the spatial layout
        out = out.transpose(1, 2).reshape(B, H, W, C)
        return self.to_out(out)
```
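A similar shape check for `KVGather`, again with arbitrary illustrative sizes (here the last dimension packs a 32-channel key and a 64-channel value):

```python
# Illustrative shape check for KVGather
B, num_regions, topk = 2, 49, 4
tokens_per_region, kv_dim = 64, 96  # e.g. qk_dim=32 for k plus dim=64 for v

r_idx = torch.randint(0, num_regions, (B, num_regions, topk))
r_weight = torch.rand(B, num_regions, topk)
kv = torch.randn(B, num_regions, tokens_per_region, kv_dim)

gathered = KVGather()(r_idx, r_weight, kv)
print(gathered.shape)  # torch.Size([2, 49, 4, 64, 96])
```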
Integrating these sub-modules into the complete BRA module requires handling the following engineering details:
```python
import torch.nn.functional as F
from einops import rearrange


class BiLevelRoutingAttention(nn.Module):
    def __init__(self, dim, n_win=7, num_heads=8, topk=4,
                 kv_downsample_ratio=4, kv_downsample_mode='identity',
                 side_dwconv=5, auto_pad=True):
        super().__init__()
        self.dim = dim
        self.n_win = n_win
        self.num_heads = num_heads
        self.topk = topk
        self.auto_pad = auto_pad
        # q/k use dim//2 channels, v keeps the full dim
        self.qk_dim = dim // 2
        self.scale = (self.qk_dim // num_heads) ** -0.5
        # sub-modules
        self.router = TopkRouting(qk_dim=self.qk_dim, topk=topk)
        self.kv_gather = KVGather()
        # local positional encoding (LePE): depth-wise convolution on v
        self.lepe = nn.Conv2d(dim, dim, kernel_size=side_dwconv,
                              stride=1, padding=side_dwconv // 2, groups=dim)
        # optional key/value down-sampling inside each region
        if kv_downsample_mode == 'avgpool':
            self.kv_down = nn.AvgPool2d(kv_downsample_ratio)
        elif kv_downsample_mode == 'maxpool':
            self.kv_down = nn.MaxPool2d(kv_downsample_ratio)
        else:
            self.kv_down = nn.Identity()
        # fused QKV projection: qk_dim (q) + qk_dim (k) + dim (v)
        self.qkv = nn.Linear(dim, self.qk_dim * 2 + dim)
        self.proj = nn.Linear(dim, dim)

    def forward(self, x):
        # x: (B, H, W, C), channel-last
        # auto-padding so that H and W are divisible by n_win
        if self.auto_pad:
            N, H_in, W_in, C = x.size()
            pad_r = (self.n_win - W_in % self.n_win) % self.n_win
            pad_b = (self.n_win - H_in % self.n_win) % self.n_win
            x = F.pad(x, (0, 0, 0, pad_r, 0, pad_b))
        N, H, W, C = x.shape
        h_r, w_r = H // self.n_win, W // self.n_win  # region height/width in tokens
        # region partition: (B, num_regions, tokens_per_region, C)
        x_region = rearrange(x, 'n (j h) (i w) c -> n (j i) (h w) c',
                             j=self.n_win, i=self.n_win)
        # QKV projection
        q, k, v = self.qkv(x_region).split([self.qk_dim, self.qk_dim, self.dim], dim=-1)
        # region-level routing on region-averaged queries/keys
        q_region = q.mean(dim=2)
        k_region = k.mean(dim=2)
        r_weight, r_idx = self.router(q_region, k_region)
        # optional per-region down-sampling of keys/values
        kv = torch.cat([k, v], dim=-1)
        kv = rearrange(kv, 'n r (h w) c -> (n r) c h w', h=h_r, w=w_r)
        kv = self.kv_down(kv)
        kv = rearrange(kv, '(n r) c h w -> n r (h w) c', n=N)
        # gather keys/values from the routed regions and run token-level attention
        kv_selected = self.kv_gather(r_idx, r_weight, kv)
        k_selected, v_selected = kv_selected.split([self.qk_dim, self.dim], dim=-1)
        out = self._attention(q, k_selected, v_selected)  # (B, num_regions, T, dim)
        # back to the spatial layout
        out = rearrange(out, 'n (j i) (h w) c -> n (j h) (i w) c',
                        j=self.n_win, i=self.n_win, h=h_r, w=w_r)
        # add the local positional encoding computed from v
        v_spatial = rearrange(v, 'n (j i) (h w) c -> n c (j h) (i w)',
                              j=self.n_win, i=self.n_win, h=h_r, w=w_r)
        lepe = rearrange(self.lepe(v_spatial), 'n c h w -> n h w c')
        out = out + lepe
        # remove the auto-padding
        if self.auto_pad and (pad_r > 0 or pad_b > 0):
            out = out[:, :H_in, :W_in, :]
        return self.proj(out)

    def _attention(self, q, k, v):
        # q: (B, R, T_q, qk_dim); k: (B, R, topk, T_kv, qk_dim); v: (B, R, topk, T_kv, dim)
        B, R, Tq, Cq = q.shape
        _, _, K, Tkv, Cv = v.shape
        # split heads
        q = q.view(B, R, Tq, self.num_heads, -1).transpose(2, 3)
        k = k.view(B, R, K, Tkv, self.num_heads, -1).permute(0, 1, 4, 2, 3, 5)
        v = v.view(B, R, K, Tkv, self.num_heads, -1).permute(0, 1, 4, 2, 3, 5)
        # flatten the top-k routed regions into a single key/value sequence
        k = k.reshape(B, R, self.num_heads, K * Tkv, Cq // self.num_heads)
        v = v.reshape(B, R, self.num_heads, K * Tkv, Cv // self.num_heads)
        # scaled dot-product attention
        attn = (q @ k.transpose(-2, -1)) * self.scale
        attn = attn.softmax(dim=-1)
        out = attn @ v
        # merge heads: (B, R, T_q, dim)
        out = out.transpose(2, 3).reshape(B, R, Tq, Cv)
        return out
```
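A minimal smoke test of the assembled module, assuming a 56x56, 64-channel feature map (sizes are illustrative, not from the paper):

```python
# Minimal smoke test for BiLevelRoutingAttention (sizes are illustrative)
bra = BiLevelRoutingAttention(dim=64, n_win=7, num_heads=8, topk=4)
x = torch.randn(2, 56, 56, 64)   # (B, H, W, C), channel-last
out = bra(x)
print(out.shape)  # torch.Size([2, 56, 56, 64])
```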
When integrating the BRA module into an actual network, keep the following points in mind:
BRA can be a drop-in replacement for the attention module of a standard Transformer. In pyramid architectures such as Swin and PVT, simply swap the original attention module for BRA and keep everything else unchanged.
```python
class BRABlock(nn.Module):
    def __init__(self, dim, num_heads, window_size=7):
        super().__init__()
        self.norm1 = nn.LayerNorm(dim)
        self.attn = BiLevelRoutingAttention(dim, n_win=window_size, num_heads=num_heads)
        self.norm2 = nn.LayerNorm(dim)
        self.mlp = nn.Sequential(
            nn.Linear(dim, dim * 4),
            nn.GELU(),
            nn.Linear(dim * 4, dim)
        )

    def forward(self, x):
        # pre-norm residual structure, identical to a standard Transformer block
        x = x + self.attn(self.norm1(x))
        x = x + self.mlp(self.norm2(x))
        return x
```
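As a sketch, a stage of the network could then be a simple stack of such blocks; the dimensions and depth below are placeholders rather than a recommended configuration:

```python
# Hypothetical stage built from BRABlocks (dimensions and depth are placeholders)
stage = nn.Sequential(*[BRABlock(dim=96, num_heads=4, window_size=7) for _ in range(2)])

feat = torch.randn(2, 56, 56, 96)   # (B, H, W, C) feature map from a patch-embedding layer
feat = stage(feat)
print(feat.shape)  # torch.Size([2, 56, 56, 96])
```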
Adjust the following key parameters according to the needs of your task:
| Parameter | Typical values | Controls | Guidance |
|---|---|---|---|
| n_win | 7-14 | Region size | Larger values suit high-resolution inputs |
| topk | 2-8 | Number of routed regions | Smaller values favor speed, larger values favor accuracy |
| kv_downsample_ratio | 1-8 | Key/value down-sampling ratio | Higher ratios reduce computation |
| num_heads | 4-16 | Number of attention heads | More heads increase model capacity |
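For example, a configuration aimed at high-resolution inputs might combine a larger n_win with key/value down-sampling to keep memory in check; the values below are illustrative, not tuned:

```python
# Illustrative configuration for a high-resolution setting (values not tuned)
attn_hires = BiLevelRoutingAttention(
    dim=128,
    n_win=14,                      # larger regions for high-resolution inputs
    num_heads=8,
    topk=4,                        # moderate routing budget
    kv_downsample_ratio=2,         # halve key/value resolution inside each region
    kv_downsample_mode='avgpool',
)
```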
Common issues and the first things to try:

- Out of memory: reduce topk or increase kv_downsample_ratio (see the table above).
- Training instability: lower the learning rate; BRA is more sensitive to it than standard attention (see the fine-tuning note below).
- Poor performance: increase topk or num_heads to give the routing and attention more capacity.
When fine-tuning on a custom dataset, it is advisable to start with a small topk (e.g. 2-4) and increase it gradually. In practice, BRA is fairly sensitive to the learning rate and usually needs a smaller one than standard attention (roughly 0.5-0.8x).
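One way to encode that rule of thumb is a separate optimizer parameter group for the attention modules; the model definition, the `.attn.` name filter, and the 0.5x factor below are placeholders, not values from the paper:

```python
# Hypothetical optimizer setup: BRA parameters get a smaller learning rate
model = nn.Sequential(BRABlock(dim=96, num_heads=4), BRABlock(dim=96, num_heads=4))

base_lr = 1e-4
attn_params = [p for n, p in model.named_parameters() if '.attn.' in n]
other_params = [p for n, p in model.named_parameters() if '.attn.' not in n]

optimizer = torch.optim.AdamW([
    {'params': other_params, 'lr': base_lr},
    {'params': attn_params,  'lr': base_lr * 0.5},  # ~0.5-0.8x, per the note above
], weight_decay=0.05)
```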