在图像分类任务中,我们经常会遇到一个令人头疼的问题:某些类别的样本数量远远多于其他类别。比如在医疗影像分析中,正常样本可能占总数据的90%,而异常样本只占10%。这种情况下,如果直接使用普通的交叉熵损失函数,模型会倾向于把所有样本都预测为多数类,因为这样就能轻松获得很高的准确率。
我曾在工业质检项目中遇到过类似情况。当时我们需要检测产品表面的缺陷,但缺陷样本只占总数据的5%。最初使用普通交叉熵训练时,模型准确率高达95%,看似不错,但实际上它把所有样本都预测为"正常",完全没学会识别缺陷。这就是典型的类别不平衡问题。
加权交叉熵损失函数的聪明之处在于,它给少数类分配更高的权重,让模型在训练时更关注这些容易被忽略的样本。具体来说,如果一个类别在数据集中出现频率较低,我们就给它一个较大的权重值,这样当模型误判这个类别的样本时,就会受到更大的惩罚。
要理解加权交叉熵,我们得先回顾下普通交叉熵损失函数。对于一个C类分类问题,给定真实标签y和预测概率p,交叉熵损失定义为:
python复制def cross_entropy(y, p):
return -torch.sum(y * torch.log(p))
这个公式计算的是预测概率分布与真实分布之间的"距离"。当预测完全正确时(p=y),损失值为0;预测越不准,损失值越大。
加权交叉熵在此基础上引入了一个权重向量w,其中w_i表示第i个类别的权重。改进后的公式变为:
python复制def weighted_cross_entropy(y, p, w):
return -torch.sum(w * y * torch.log(p))
这个简单的改动带来了显著的效果提升。权重w就像一个调节器,可以控制模型对不同类别错误的敏感程度。在实践中,我们通常会将少数类的权重设得比多数类大,迫使模型更关注这些样本。
如何确定各个类别的权重呢?常见的方法有:
逆类别频率法:权重与类别频率成反比
python复制weights = 1.0 / class_counts
平滑逆频率法:在分母上加一个平滑项防止极端值
python复制weights = 1.0 / (class_counts + epsilon)
有效样本数法:考虑类别间的相对关系
python复制beta = 0.99
effective_num = 1.0 - beta**class_counts
weights = (1.0 - beta) / effective_num
我在实际项目中发现,第三种方法在极端不平衡的数据集上(比如1:100的比例)表现尤为出色。
PyTorch已经内置了加权交叉熵的实现,使用起来非常简单:
python复制import torch.nn as nn
# 假设我们有3个类别,其样本比例为 [100, 50, 10]
class_weights = torch.tensor([1.0, 2.0, 10.0])
criterion = nn.CrossEntropyLoss(weight=class_weights)
这里的关键是weight参数,它接受一个与类别数相同长度的张量。注意权重不需要求和为1,PyTorch会自动进行归一化处理。
有时候固定的权重可能不够灵活。我们可以实现一个动态调整权重的版本:
python复制class DynamicWeightedCE(nn.Module):
def __init__(self, num_classes):
super().__init__()
self.num_classes = num_classes
def forward(self, inputs, targets):
# 计算当前batch的类别分布
batch_counts = torch.bincount(targets, minlength=self.num_classes)
batch_weights = 1.0 / (batch_counts.float() + 1e-6)
return nn.functional.cross_entropy(
inputs, targets, weight=batch_weights.to(inputs.device)
)
这种方法会根据每个batch的实际类别分布动态调整权重,特别适合在线学习场景。
当遇到极端不平衡的情况时(比如1:1000),我发现以下技巧很有效:
对权重进行对数缩放:
python复制weights = torch.log(1.0 / (class_counts + 1e-6) + 1.0)
对少数类样本进行梯度裁剪:
python复制loss = criterion(outputs, labels)
loss = torch.clamp(loss, max=10.0) # 防止少数类样本梯度爆炸
结合Focal Loss使用:
python复制pt = torch.exp(-loss)
focal_loss = (1 - pt)**gamma * loss # gamma通常取2
让我们看一个真实案例。假设我们要建立一个胸部X光片分类模型,区分正常、肺炎和COVID-19三种情况。数据分布如下:
| 类别 | 样本数 | 权重 |
|---|---|---|
| 正常 | 8000 | 1.0 |
| 肺炎 | 2000 | 4.0 |
| COVID-19 | 500 | 16.0 |
python复制model = torchvision.models.resnet50(pretrained=True)
model.fc = nn.Linear(model.fc.in_features, 3)
weights = torch.tensor([1.0, 4.0, 16.0]).cuda()
criterion = nn.CrossEntropyLoss(weight=weights)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
使用普通交叉熵和加权交叉熵的对比结果:
| 指标 | 普通CE | 加权CE |
|---|---|---|
| 整体准确率 | 89% | 86% |
| COVID-19召回率 | 12% | 68% |
| 肺炎F1分数 | 45% | 72% |
虽然加权版本的整体准确率略有下降,但对关键少数类的识别能力大幅提升。在医疗场景中,这比整体准确率更重要——我们宁可多检查一些假阳性的COVID-19病例,也不能漏诊真正的患者。
在将加权交叉熵模型部署到生产环境时,有几个要点需要注意:
我在一个实际项目中就遇到过这样的情况:初期COVID-19样本很少,我们给了很高的权重;随着疫情发展,这类样本增多后,原来的权重设置反而影响了模型表现。后来我们改为每月自动重新计算权重,问题就解决了。
加权交叉熵不是解决类别不平衡的唯一方法。让我们看看它与其他常见方法的对比:
| 方法 | 优点 | 缺点 |
|---|---|---|
| 过采样 | 保持所有原始信息 | 可能导致过拟合,增加训练时间 |
| 欠采样 | 训练速度快 | 丢失多数类有用信息 |
| 加权交叉熵 | 不改变数据分布,计算高效 | 需要仔细调权 |
Focal Loss是另一种处理不平衡的损失函数,它通过降低易分类样本的权重来聚焦难样本。两者对比:
| 特性 | 加权交叉熵 | Focal Loss |
|---|---|---|
| 调整方式 | 按类别调整 | 按样本难度调整 |
| 超参数 | 类别权重 | 聚焦参数γ |
| 计算复杂度 | 低 | 稍高 |
| 适用场景 | 类别间不平衡 | 类别内难易样本不平衡 |
在实践中,我发现对于纯粹的类别不平衡问题,加权交叉熵通常足够;而对于同时存在类别不平衡和样本难度差异的情况(如目标检测),Focal Loss可能更合适。
有时候,将加权交叉熵与其他技术结合会得到更好的效果。比如:
加权交叉熵 + 类别平衡采样:
python复制sampler = WeightedRandomSampler(weights, num_samples)
dataloader = DataLoader(dataset, sampler=sampler)
加权交叉熵 + 迁移学习:
先在不平衡数据上预训练,再在平衡子集上微调
加权交叉熵 + 模型集成:
训练多个不同权重的模型,然后集成预测
在我的一个项目中,使用加权交叉熵+类别平衡采样的组合,将少数类的召回率从55%提升到了82%,而整体准确率只下降了3个百分点。
当少数类权重设置过大时,可能会遇到梯度爆炸或训练震荡的问题。解决方法包括:
对权重进行归一化:
python复制weights = weights / weights.max()
使用梯度裁剪:
python复制torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
降低学习率:
python复制optimizer = Adam(model.parameters(), lr=1e-5)
有时验证损失在下降,但关心的业务指标(如召回率)却没有改善。这时可以:
对于多标签场景(一个样本可能属于多个类别),需要对加权交叉熵稍作修改:
python复制def weighted_bce_with_logits(logits, targets, weights):
loss = weights * (targets * -torch.log_sigmoid(logits) +
(1 - targets) * -torch.log_sigmoid(-logits))
return loss.mean()
这个实现会给每个标签独立的权重,适用于标签间也存在不平衡的情况。
与其手动设置权重,我们可以让模型自动学习最优权重:
python复制class LearnableWeightedCE(nn.Module):
def __init__(self, num_classes):
super().__init__()
self.weights = nn.Parameter(torch.ones(num_classes))
def forward(self, inputs, targets):
return nn.functional.cross_entropy(
inputs, targets, weight=torch.softmax(self.weights, dim=0)
)
这种方法特别适合类别分布经常变化或未知的场景。