在图像分类任务中,数据不平衡是常见但棘手的问题。想象一下,当你兴致勃勃地收集了11000张猫狗图片准备训练模型时,却发现其中10000张是狗,只有1000张是猫——这种9:1的极端比例会让模型严重偏向多数类。我曾在实际项目中遇到过这种情况:模型在验证集上准确率高达90%,但细看发现它把所有图片都预测成了狗!这种"偷懒"的模型显然无法满足真实需求。
解决这类问题,PyTorch提供的WeightedRandomSampler是个优雅的方案。不同于简单的过采样或欠采样,它能在每个epoch动态调整采样策略,既保证类别平衡,又充分利用所有数据。下面我将通过完整的代码示例,带你一步步实现这个解决方案。
数据不平衡不只是准确率数字的游戏。在实际业务场景中,少数类往往才是关键。比如医疗诊断中,患病样本远少于健康样本;金融风控中,欺诈交易占比可能不足1%。这些场景下,单纯追求整体准确率毫无意义。
以我们的猫狗分类为例,当数据比例为10:1时:
常见解决方案对比:
| 方法 | 优点 | 缺点 |
|---|---|---|
| 过采样 | 不丢失信息 | 可能导致过拟合 |
| 欠采样 | 计算效率高 | 丢弃有价值数据 |
| 类别权重 | 实现简单 | 对极端不平衡效果有限 |
| WeightedRandomSampler | 动态平衡 | 需额外计算开销 |
首先我们需要组织数据并创建PyTorch Dataset。假设图片存储在/data/cats_dogs目录下,结构如下:
code复制cats_dogs/
├── train/
│ ├── cat/ # 1000张图片
│ └── dog/ # 10000张图片
└── val/
├── cat/
└── dog/
创建自定义Dataset类:
python复制from torch.utils.data import Dataset
from PIL import Image
import os
class CatDogDataset(Dataset):
def __init__(self, root_dir, transform=None):
self.root_dir = root_dir
self.transform = transform
self.cat_dir = os.path.join(root_dir, 'cat')
self.dog_dir = os.path.join(root_dir, 'dog')
# 获取所有文件路径
self.cat_files = [os.path.join(self.cat_dir, f)
for f in os.listdir(self.cat_dir)]
self.dog_files = [os.path.join(self.dog_dir, f)
for f in os.listdir(self.dog_dir)]
self.all_files = self.cat_files + self.dog_files
self.labels = [0]*len(self.cat_files) + [1]*len(self.dog_files)
def __len__(self):
return len(self.all_files)
def __getitem__(self, idx):
img_path = self.all_files[idx]
image = Image.open(img_path).convert('RGB')
label = self.labels[idx] # 0 for cat, 1 for dog
if self.transform:
image = self.transform(image)
return image, label
关键步骤是计算每个样本的采样权重。我们需要:
python复制def create_weighted_sampler(dataset):
# 获取类别分布
class_counts = torch.bincount(torch.tensor(dataset.labels))
num_samples = sum(class_counts)
# 计算每个类别的权重
class_weights = 1. / class_counts
# 为每个样本分配权重
sample_weights = [class_weights[label] for label in dataset.labels]
# 创建sampler
sampler = WeightedRandomSampler(
weights=sample_weights,
num_samples=num_samples, # 通常设为数据集大小
replacement=True # 允许重复采样
)
return sampler
参数详解:
weights:每个样本的采样概率,不要求归一化num_samples:每个epoch采样的总数replacement:设为True允许重复采样少数类样本提示:当数据极度不平衡时,建议设置
replacement=True,否则少数类样本可能无法充分参与训练。
现在整合所有组件,创建数据加载器和训练循环:
python复制import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import transforms
from torch.utils.data import DataLoader
# 数据预处理
transform = transforms.Compose([
transforms.Resize((224, 224)),
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225])
])
# 创建Dataset
train_dataset = CatDogDataset('/data/cats_dogs/train', transform=transform)
# 创建WeightedRandomSampler
sampler = create_weighted_sampler(train_dataset)
# 创建DataLoader
train_loader = DataLoader(
train_dataset,
batch_size=32,
sampler=sampler, # 使用自定义sampler
num_workers=4
)
# 简单CNN模型
class SimpleCNN(nn.Module):
def __init__(self):
super().__init__()
self.conv1 = nn.Conv2d(3, 16, 3, padding=1)
self.conv2 = nn.Conv2d(16, 32, 3, padding=1)
self.fc1 = nn.Linear(32*56*56, 2)
def forward(self, x):
x = F.relu(self.conv1(x))
x = F.max_pool2d(x, 2)
x = F.relu(self.conv2(x))
x = F.max_pool2d(x, 2)
x = x.view(x.size(0), -1)
x = self.fc1(x)
return x
model = SimpleCNN()
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters())
# 训练循环
for epoch in range(10):
for images, labels in train_loader:
optimizer.zero_grad()
outputs = model(images)
loss = criterion(outputs, labels)
loss.backward()
optimizer.step()
print(f'Epoch {epoch+1}, Loss: {loss.item():.4f}')
在实际项目中,还有一些值得注意的细节:
动态调整采样策略:
python复制# 动态权重调整示例
def dynamic_weights(current_epoch, max_epoch):
alpha = 1 - (current_epoch / max_epoch) # 线性衰减
return [alpha*w + (1-alpha) for w in original_weights]
结合其他技术:
性能优化:
python复制# 优化后的DataLoader配置
train_loader = DataLoader(
dataset,
batch_size=64,
sampler=sampler,
num_workers=8,
pin_memory=True,
persistent_workers=True
)
训练完成后,我们需要验证采样策略的效果。关键指标应包括:
python复制from sklearn.metrics import classification_report
def evaluate(model, dataloader):
model.eval()
all_preds = []
all_labels = []
with torch.no_grad():
for images, labels in dataloader:
outputs = model(images)
_, preds = torch.max(outputs, 1)
all_preds.extend(preds.cpu().numpy())
all_labels.extend(labels.cpu().numpy())
print(classification_report(all_labels, all_preds, target_names=['cat', 'dog']))
在我的实验中,使用WeightedRandomSampler后,猫类的召回率从不足10%提升到了78%,而整体准确率保持85%以上。这种平衡的性能在实际应用中远比单纯的准确率数字更有价值。