在目标检测领域,传统方法如Faster R-CNN和YOLO系列长期占据主导地位,它们都依赖于一个共同的设计元素——锚框(Anchor Boxes)。这些预定义的边界框虽然有效,却也带来了复杂的超参数调整和计算开销。2019年提出的CenterNet以完全不同的思路刷新了目标检测的范式,本文将带您深入理解这一创新方法,并基于PyTorch框架从零实现一个完整的检测系统。
传统目标检测方法的核心痛点在于锚框的设计。以YOLOv3为例,每个特征图位置需要预设9个不同比例大小的锚框,这不仅增加了模型复杂度,还引入了大量负样本。CenterNet的作者提出了一个革命性的观点:将目标视为其边界框的中心点,从根本上摆脱了对锚框的依赖。
这种思想带来三个显著优势:
CenterNet的检测流程可以概括为三个关键步骤:
python复制# CenterNet的三大预测头结构示意
class CenterNetHead(nn.Module):
def __init__(self, in_channels, num_classes):
super().__init__()
# 热力图预测分支
self.heatmap = nn.Sequential(
nn.Conv2d(in_channels, 64, 3, padding=1),
nn.ReLU(),
nn.Conv2d(64, num_classes, 1)
)
# 宽高预测分支
self.wh = nn.Sequential(
nn.Conv2d(in_channels, 64, 3, padding=1),
nn.ReLU(),
nn.Conv2d(64, 2, 1)
)
# 偏移量预测分支
self.offset = nn.Sequential(
nn.Conv2d(in_channels, 64, 3, padding=1),
nn.ReLU(),
nn.Conv2d(64, 2, 1)
)
CenterNet支持多种骨干网络,论文中主要评估了Hourglass-104、ResNet和DLA-34三种架构。考虑到计算资源限制和实现难度,我们选择ResNet-50作为基础特征提取器。ResNet的残差连接设计能有效缓解深层网络的梯度消失问题,其分阶段下采样的特性也适合目标检测任务。
ResNet-50的特征提取过程可分为四个阶段:
python复制class ResNetBackbone(nn.Module):
def __init__(self, pretrained=True):
super().__init__()
original = torchvision.models.resnet50(pretrained=pretrained)
# 拆解ResNet获取各阶段特征
self.conv1 = original.conv1
self.bn1 = original.bn1
self.relu = original.relu
self.maxpool = original.maxpool
self.layer1 = original.layer1 # 输出256x128x128
self.layer2 = original.layer2 # 输出512x64x64
self.layer3 = original.layer3 # 输出1024x32x32
self.layer4 = original.layer4 # 输出2048x16x16
def forward(self, x):
x = self.conv1(x)
x = self.bn1(x)
x = self.relu(x)
x = self.maxpool(x)
x1 = self.layer1(x) # 1/4分辨率
x2 = self.layer2(x1) # 1/8分辨率
x3 = self.layer3(x2) # 1/16分辨率
x4 = self.layer4(x3) # 1/32分辨率
return x4
直接从骨干网络输出的16x16特征图分辨率过低,难以精确定位目标中心。CenterNet采用三级反卷积模块逐步上采样特征图:
| 上采样阶段 | 输入尺寸 | 卷积核 | 输出尺寸 | 通道数 |
|---|---|---|---|---|
| 第一次反卷积 | 16x16 | 4x4 | 32x32 | 256 |
| 第二次反卷积 | 32x32 | 4x4 | 64x64 | 128 |
| 第三次反卷积 | 64x64 | 4x4 | 128x128 | 64 |
这种渐进式上采样策略能在保持计算效率的同时恢复空间细节。实验表明,将特征图放大到输入图像的1/4分辨率(128x128)在精度和速度间取得了良好平衡。
python复制class DeconvLayer(nn.Module):
def __init__(self, in_channels):
super().__init__()
self.deconv1 = nn.ConvTranspose2d(in_channels, 256, 4, stride=2, padding=1)
self.bn1 = nn.BatchNorm2d(256)
self.deconv2 = nn.ConvTranspose2d(256, 128, 4, stride=2, padding=1)
self.bn2 = nn.BatchNorm2d(128)
self.deconv3 = nn.ConvTranspose2d(128, 64, 4, stride=2, padding=1)
self.bn3 = nn.BatchNorm2d(64)
def forward(self, x):
x = F.relu(self.bn1(self.deconv1(x))) # 16->32
x = F.relu(self.bn2(self.deconv2(x))) # 32->64
x = F.relu(self.bn3(self.deconv3(x))) # 64->128
return x
热力图是CenterNet最核心的创新点,每个像素值表示该位置存在目标中心点的置信度。对于标注的真实框,我们首先计算其中心点坐标,然后在热力图上以该点为中心生成二维高斯分布:
code复制高斯半径计算公式:
r = √(w*h)/6
其中w和h是目标框的宽度和高度
这种设计使得靠近中心点的位置获得较高响应,同时考虑了目标大小的影响。在代码实现中,我们需要特别注意边界情况的处理:
python复制def draw_gaussian(heatmap, center, radius, k=1):
diameter = 2 * radius + 1
gaussian = gaussian2D((diameter, diameter), sigma=diameter/6)
x, y = int(center[0]), int(center[1])
height, width = heatmap.shape[:2]
# 处理边界情况
left, right = min(x, radius), min(width - x, radius + 1)
top, bottom = min(y, radius), min(height - y, radius + 1)
masked_heatmap = heatmap[y-top:y+bottom, x-left:x+right]
masked_gaussian = gaussian[radius-top:radius+bottom, radius-left:radius+right]
if min(masked_gaussian.shape) > 0 and min(masked_heatmap.shape) > 0:
np.maximum(masked_heatmap, masked_gaussian*k, out=masked_heatmap)
return heatmap
CenterNet的损失函数由三部分组成,每部分都有其独特的考量:
热力图损失(Focal Loss):
偏移量损失(L1 Loss):
宽高损失(L1 Loss):
python复制class CenterNetLoss(nn.Module):
def __init__(self):
super().__init__()
def forward(self, pred_hm, pred_wh, pred_offset, gt_hm, gt_wh, gt_offset, reg_mask):
# 热力图损失
pos_inds = gt_hm.eq(1).float()
neg_inds = gt_hm.lt(1).float()
neg_weights = torch.pow(1 - gt_hm, 4)
pred_hm = torch.clamp(pred_hm, 1e-6, 1-1e-6)
pos_loss = torch.log(pred_hm) * torch.pow(1 - pred_hm, 2) * pos_inds
neg_loss = torch.log(1 - pred_hm) * torch.pow(pred_hm, 2) * neg_weights * neg_inds
num_pos = pos_inds.sum()
hm_loss = -(pos_loss.sum() + neg_loss.sum()) / max(1, num_pos)
# 偏移量损失
pred_offset = pred_offset.permute(0,2,3,1).contiguous()
gt_offset = gt_offset.permute(0,2,3,1).contiguous()
reg_mask = reg_mask.unsqueeze(-1).expand_as(gt_offset)
offset_loss = F.l1_loss(pred_offset*reg_mask, gt_offset*reg_mask, reduction='sum')
offset_loss = offset_loss / (num_pos + 1e-4)
# 宽高损失
pred_wh = pred_wh.permute(0,2,3,1).contiguous()
gt_wh = gt_wh.permute(0,2,3,1).contiguous()
wh_loss = F.l1_loss(pred_wh*reg_mask, gt_wh*reg_mask, reduction='sum')
wh_loss = 0.1 * wh_loss / (num_pos + 1e-4)
return hm_loss + offset_loss + wh_loss
有效的数据增强对目标检测至关重要,特别是CenterNet这种依赖中心点定位的方法。我们采用以下增强策略的组合:
几何变换:
色彩扰动:
python复制class CenterNetDataset(Dataset):
def __getitem__(self, index):
# 读取原始图像和标注
img, bboxes = self.load_annotation(index)
# 随机缩放
scale = random.uniform(0.6, 1.4)
new_h = int(img.shape[0] * scale)
new_w = int(img.shape[1] * scale)
img = cv2.resize(img, (new_w, new_h))
bboxes[:, [0,2]] *= scale
bboxes[:, [1,3]] *= scale
# 随机翻转
if random.random() < 0.5:
img = img[:, ::-1]
bboxes[:, [0,2]] = img.shape[1] - bboxes[:, [2,0]]
# 色彩扰动
img = self.color_jitter(img)
# 标准化和通道调整
img = (img.astype(np.float32) / 255. - self.mean) / self.std
img = img.transpose(2, 0, 1)
return img, bboxes
采用两阶段训练策略能有效平衡训练效率和模型性能:
冻结阶段(前50个epoch):
解冻阶段(后50个epoch):
python复制def train():
# 初始化模型
model = CenterNet(backbone='resnet50', num_classes=20)
# 第一阶段:冻结骨干网络
for param in model.backbone.parameters():
param.requires_grad = False
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
train_loader = DataLoader(..., batch_size=16)
for epoch in range(50):
train_one_epoch(model, optimizer, train_loader)
# 第二阶段:解冻全部参数
for param in model.parameters():
param.requires_grad = True
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
train_loader = DataLoader(..., batch_size=8)
for epoch in range(50, 100):
train_one_epoch(model, optimizer, train_loader)
预测阶段首先对热力图应用3x3最大池化进行非极大抑制,保留局部最大值点作为候选中心点:
python复制def heatmap_nms(heatmap, kernel=3):
pad = (kernel - 1) // 2
hmax = F.max_pool2d(heatmap, kernel, stride=1, padding=pad)
keep = (hmax == heatmap).float()
return heatmap * keep
从热力图中提取topK个中心点后,结合偏移量和宽高预测生成最终检测框:
python复制def decode_bbox(pred_hm, pred_wh, pred_offset, threshold=0.3, topk=100):
# 热力图非极大抑制
pred_hm = heatmap_nms(pred_hm)
# 获取批次大小和类别数
batch, num_classes, height, width = pred_hm.shape
# 解码每个样本的预测
detections = []
for b in range(batch):
# 展平热力图并获取topK得分和索引
heatmap = pred_hm[b].permute(1, 2, 0).view(-1, num_classes)
scores, indices = heatmap.topk(topk, dim=0)
# 获取类别ID和坐标
class_ids = torch.arange(num_classes).view(1, -1)
class_ids = class_ids.expand(topk, num_classes).contiguous().view(-1)
indices = indices.view(-1)
# 过滤低置信度预测
keep = scores.view(-1) > threshold
if keep.sum() == 0:
detections.append([])
continue
# 获取保留的预测
scores = scores.view(-1)[keep]
class_ids = class_ids[keep]
indices = indices[keep]
# 计算坐标
y = indices // width
x = indices % width
# 应用偏移量
offset = pred_offset[b].permute(1, 2, 0).view(-1, 2)
offset = offset[indices]
x += offset[:, 0]
y += offset[:, 1]
# 应用宽高预测
wh = pred_wh[b].permute(1, 2, 0).view(-1, 2)
wh = wh[indices]
boxes = torch.stack([
x - wh[:, 0]/2, # x1
y - wh[:, 1]/2, # y1
x + wh[:, 0]/2, # x2
y + wh[:, 1]/2 # y2
], dim=1)
# 归一化到0-1范围
boxes[:, [0, 2]] /= width
boxes[:, [1, 3]] /= height
# 组合最终检测结果
detections.append(torch.cat([
boxes,
scores.unsqueeze(1),
class_ids.unsqueeze(1).float()
], dim=1))
return detections
在实际部署中,我们可以采用以下技巧提升推理速度:
python复制# TensorRT优化示例
def build_engine(onnx_path, engine_path):
logger = trt.Logger(trt.Logger.WARNING)
builder = trt.Builder(logger)
network = builder.create_network(1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH))
parser = trt.OnnxParser(network, logger)
with open(onnx_path, 'rb') as model:
if not parser.parse(model.read()):
for error in range(parser.num_errors):
print(parser.get_error(error))
return None
config = builder.create_builder_config()
config.set_memory_pool_limit(trt.MemoryPoolType.WORKSPACE, 1 << 30)
serialized_engine = builder.build_serialized_network(network, config)
with open(engine_path, 'wb') as f:
f.write(serialized_engine)
return serialized_engine
在真实场景部署CenterNet时,需要注意以下问题:
小目标检测:
密集目标场景:
类别不平衡:
经过完整实现和优化后,基于ResNet-50的CenterNet在VOC测试集上可以达到72.3%的mAP,同时保持45 FPS的推理速度(在RTX 2080Ti上测试)。相比传统的锚框方法,这种无锚框设计不仅简化了实现流程,还展现了更优雅的检测范式。