当你面对一个猫狗分类数据集,发现其中90%的样本都是狗,只有10%是猫时,传统的随机采样方法会让模型严重偏向于识别狗。这种数据不平衡问题在医疗影像分析、金融欺诈检测等领域尤为常见。本文将带你深入理解PyTorch中的WeightedRandomSampler,从底层原理到完整代码实现,彻底解决少数类样本被模型忽视的问题。
数据不平衡是指不同类别的样本数量存在显著差异。以信用卡欺诈检测为例,正常交易可能占99.9%,而欺诈交易仅占0.1%。这种极端不平衡会导致模型倾向于预测多数类,因为即使完全忽略少数类,也能获得很高的准确率。
数据不平衡带来的主要问题包括:
常见解决方案对比:
| 方法 | 优点 | 缺点 |
|---|---|---|
| 过采样少数类 | 不丢失信息 | 可能导致过拟合 |
| 欠采样多数类 | 计算效率高 | 丢失有价值信息 |
| 类别权重调整 | 简单易实现 | 对极端不平衡效果有限 |
| WeightedRandomSampler | 平衡效果好 | 需要合理设置权重 |
WeightedRandomSampler是PyTorch提供的一种采样策略,它通过为每个样本分配不同的采样权重,使得少数类样本有更高的概率被选中。其核心参数包括:
python复制torch.utils.data.WeightedRandomSampler(
weights, # 每个样本的权重序列
num_samples, # 需要采样的总数
replacement=True # 是否允许重复采样
)
关键点解析:
注意:权重是赋给每个样本而非类别,需要为数据集中的每个样本单独计算权重
让我们通过一个猫狗分类的完整示例,演示如何应用WeightedRandomSampler。
首先,我们需要计算每个样本的权重。假设我们的数据集分布如下:
python复制# 样本标签分布示例
labels = ['cat', 'cat', 'dog', 'dog', 'dog', 'dog', 'dog']
计算权重的完整代码:
python复制import numpy as np
from collections import Counter
def calculate_sample_weights(labels):
# 统计每个类别的出现次数
class_counts = Counter(labels)
# 计算每个类别的权重(与频率成反比)
class_weights = {cls: 1.0/count for cls, count in class_counts.items()}
# 为每个样本分配对应类别的权重
sample_weights = [class_weights[cls] for cls in labels]
return sample_weights
# 示例使用
labels = ['cat', 'cat', 'dog', 'dog', 'dog', 'dog', 'dog']
weights = calculate_sample_weights(labels)
print(f"样本权重: {weights}")
输出结果:
code复制样本权重: [0.5, 0.5, 0.2, 0.2, 0.2, 0.2, 0.2]
接下来,我们将权重应用到PyTorch的数据加载流程中:
python复制import torch
from torch.utils.data import Dataset, DataLoader, WeightedRandomSampler
class CatDogDataset(Dataset):
def __init__(self, images, labels):
self.images = images
self.labels = labels
def __len__(self):
return len(self.labels)
def __getitem__(self, idx):
return self.images[idx], self.labels[idx]
# 假设我们已经准备好了图像数据和标签
images = [...] # 实际的图像数据
labels = [...] # 实际的标签数据
# 创建数据集
dataset = CatDogDataset(images, labels)
# 计算样本权重
weights = calculate_sample_weights(labels)
weights = torch.DoubleTensor(weights)
# 创建采样器
sampler = WeightedRandomSampler(
weights,
num_samples=len(weights), # 通常与数据集大小相同
replacement=True # 允许重复采样
)
# 创建DataLoader
dataloader = DataLoader(
dataset,
batch_size=32,
sampler=sampler,
num_workers=4
)
在训练循环中,使用带采样器的DataLoader与常规方式无异:
python复制model = ... # 你的模型定义
criterion = ... # 损失函数
optimizer = ... # 优化器
for epoch in range(num_epochs):
for batch_images, batch_labels in dataloader:
# 前向传播
outputs = model(batch_images)
# 计算损失
loss = criterion(outputs, batch_labels)
# 反向传播与优化
optimizer.zero_grad()
loss.backward()
optimizer.step()
提示:即使使用了WeightedRandomSampler,在计算损失时也可以额外应用类别权重,双重保障模型对少数类的关注
对于极度不平衡的数据(如1:1000),可以考虑以下策略:
当数据集非常大时,WeightedRandomSampler可能会成为性能瓶颈。可以考虑:
需要注意的是:
python复制# 验证集DataLoader应不使用采样器
val_dataloader = DataLoader(
val_dataset,
batch_size=32,
shuffle=False, # 验证集通常不shuffle
num_workers=4
)
除了WeightedRandomSampler,还有其他应对数据不平衡的方法:
损失函数层面的解决方案:
python复制# 使用带权重的交叉熵损失
class_weights = torch.tensor([1.0, 5.0]) # 假设类别1是少数类
criterion = torch.nn.CrossEntropyLoss(weight=class_weights)
数据增强技术:
模型架构调整:
在实际项目中,我通常会尝试组合多种策略。例如,同时使用WeightedRandomSampler和带权重的损失函数,并在少数类样本上应用数据增强。这种组合方法在医疗影像分析任务中取得了显著效果,将少数类的召回率从15%提升到了68%。