目标检测任务中最大的痛点之一就是样本不平衡问题。想象一下,在一张城市街景图中,车辆和行人可能只占整张图的5%不到,剩下95%都是背景。这种极端不平衡的数据分布会导致模型训练时被大量简单背景样本"带偏",就像班级里90%的学生都能轻松考满分,老师自然会把精力放在剩下10%的困难学生身上。
传统交叉熵损失(CE)对所有样本一视同仁的缺点在这里暴露无遗。我曾在某交通监控项目中遇到过这种情况:使用普通CE损失训练后,模型对遮挡车辆和小尺寸行人的召回率不足30%。后来分析发现,这些困难样本的梯度被海量简单样本的梯度淹没,就像微弱信号被噪声覆盖。
OHEM(Online Hard Example Mining)的聪明之处在于它像一位经验丰富的教练,能自动识别哪些样本对当前模型最有挑战性。具体来说,它会:
先看传统交叉熵损失:
code复制CE = -[y*log(p) + (1-y)*log(1-p)]
OHEM-CE在此基础上增加了三重过滤机制:
用代码表示这个逻辑更直观:
python复制if 是正样本:
保留
elif 是负样本且loss > thresh:
保留
elif 保留样本数 < n_min:
补充选择top n_min高loss样本
else:
忽略
thresh参数是OHEM的核心开关,它决定了什么样的样本算"困难"。这里有个容易混淆的点:代码中的self.thresh实际是-log(thresh)。比如设置thresh=0.7时:
python复制self.thresh = -torch.log(torch.tensor(0.7)) # ≈0.3567
这意味着当样本预测概率p<0.7时,其损失值-log(p)就会大于0.3567,从而被判定为困难样本。我在多个项目实验中发现,0.6-0.8是比较通用的推荐范围:
ignore_simple_sample_factor这个参数决定了n_min的大小:
python复制n_min = 总有效像素数 // ignore_simple_sample_factor
经过大量实验验证,我总结出这些经验:
OHEM与Focal Loss是互补关系而非替代:
典型组合配置示例:
python复制loss_func = OhemCELoss(
thresh=0.7,
lb_ignore=255,
ignore_simple_sample_factor=16
)
在MMDetection中集成OHEM只需两步:
python复制model=dict(
train_cfg=dict(
rpn=dict(
ohem=dict(
enable=True,
thresh=0.7,
n_min=256
)
)
)
)
python复制loss_cls=dict(
type='OhemCrossEntropyLoss',
thresh=0.7,
min_kept=256
)
基于实际项目经验分享几个关键点:
一个完整的训练周期示例:
python复制# 初始化
optimizer = torch.optim.SGD(
params=model.parameters(),
lr=0.02*0.7, # 常规学习率的70%
momentum=0.9,
weight_decay=0.0001
)
# 学习率调度
lr_config = dict(
policy='step',
warmup='linear',
warmup_iters=500,
warmup_ratio=0.001,
step=[8, 11] # 适当延后衰减点
)
loss突然变NaN:
验证指标不升反降:
GPU内存溢出:
torch.where替代布尔索引加速筛选:python复制# 优化前
loss_hard = loss[loss > self.thresh]
# 优化后
mask = (loss > self.thresh).float()
loss_hard = torch.sum(loss * mask) / (torch.sum(mask) + 1e-6)
python复制# 根据目标尺寸动态调整阈值
scale = get_scale_factor(targets) # 自定义尺度计算
dynamic_thresh = base_thresh * scale
python复制with autocast():
loss = criterion(logits, labels)
# 需要手动处理scaler.scale(loss).backward()