PyTorch实战:用WeightedRandomSampler解决猫狗数据集不平衡问题(附完整代码)

林脸脸

PyTorch实战:用WeightedRandomSampler解决猫狗数据集不平衡问题(附完整代码)

在图像分类任务中,数据不平衡是常见但棘手的问题。想象一下,当你兴致勃勃地收集了11000张猫狗图片准备训练模型时,却发现其中10000张是狗,只有1000张是猫——这种9:1的极端比例会让模型严重偏向多数类。我曾在实际项目中遇到过这种情况:模型在验证集上准确率高达90%,但细看发现它把所有图片都预测成了狗!这种"偷懒"的模型显然无法满足真实需求。

解决这类问题,PyTorch提供的WeightedRandomSampler是个优雅的方案。不同于简单的过采样或欠采样,它能在每个epoch动态调整采样策略,既保证类别平衡,又充分利用所有数据。下面我将通过完整的代码示例,带你一步步实现这个解决方案。

1. 理解数据不平衡的影响

数据不平衡不只是准确率数字的游戏。在实际业务场景中,少数类往往才是关键。比如医疗诊断中,患病样本远少于健康样本;金融风控中,欺诈交易占比可能不足1%。这些场景下,单纯追求整体准确率毫无意义。

以我们的猫狗分类为例,当数据比例为10:1时:

  • 模型可能学会"永远预测为狗"的简单策略,验证准确率仍达90%
  • 对猫的召回率可能趋近于0,模型完全失效
  • 损失函数被多数类主导,梯度更新几乎不反映少数类特征

常见解决方案对比

方法 优点 缺点
过采样 不丢失信息 可能导致过拟合
欠采样 计算效率高 丢弃有价值数据
类别权重 实现简单 对极端不平衡效果有限
WeightedRandomSampler 动态平衡 需额外计算开销

2. 准备数据集与自定义Dataset

首先我们需要组织数据并创建PyTorch Dataset。假设图片存储在/data/cats_dogs目录下,结构如下:

code复制cats_dogs/
    ├── train/
    │   ├── cat/  # 1000张图片
    │   └── dog/  # 10000张图片
    └── val/
        ├── cat/
        └── dog/

创建自定义Dataset类:

python复制from torch.utils.data import Dataset
from PIL import Image
import os

class CatDogDataset(Dataset):
    def __init__(self, root_dir, transform=None):
        self.root_dir = root_dir
        self.transform = transform
        self.cat_dir = os.path.join(root_dir, 'cat')
        self.dog_dir = os.path.join(root_dir, 'dog')
        
        # 获取所有文件路径
        self.cat_files = [os.path.join(self.cat_dir, f) 
                         for f in os.listdir(self.cat_dir)]
        self.dog_files = [os.path.join(self.dog_dir, f)
                         for f in os.listdir(self.dog_dir)]
        
        self.all_files = self.cat_files + self.dog_files
        self.labels = [0]*len(self.cat_files) + [1]*len(self.dog_files)
    
    def __len__(self):
        return len(self.all_files)
    
    def __getitem__(self, idx):
        img_path = self.all_files[idx]
        image = Image.open(img_path).convert('RGB')
        label = self.labels[idx]  # 0 for cat, 1 for dog
        
        if self.transform:
            image = self.transform(image)
            
        return image, label

3. 实现WeightedRandomSampler

关键步骤是计算每个样本的采样权重。我们需要:

  1. 统计每个类别的样本数量
  2. 为每个样本分配权重(少数类权重高)
  3. 创建WeightedRandomSampler实例
python复制def create_weighted_sampler(dataset):
    # 获取类别分布
    class_counts = torch.bincount(torch.tensor(dataset.labels))
    num_samples = sum(class_counts)
    
    # 计算每个类别的权重
    class_weights = 1. / class_counts
    
    # 为每个样本分配权重
    sample_weights = [class_weights[label] for label in dataset.labels]
    
    # 创建sampler
    sampler = WeightedRandomSampler(
        weights=sample_weights,
        num_samples=num_samples,  # 通常设为数据集大小
        replacement=True  # 允许重复采样
    )
    return sampler

参数详解

  • weights:每个样本的采样概率,不要求归一化
  • num_samples:每个epoch采样的总数
  • replacement:设为True允许重复采样少数类样本

提示:当数据极度不平衡时,建议设置replacement=True,否则少数类样本可能无法充分参与训练。

4. 构建完整训练流程

现在整合所有组件,创建数据加载器和训练循环:

python复制import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import transforms
from torch.utils.data import DataLoader

# 数据预处理
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225])
])

# 创建Dataset
train_dataset = CatDogDataset('/data/cats_dogs/train', transform=transform)

# 创建WeightedRandomSampler
sampler = create_weighted_sampler(train_dataset)

# 创建DataLoader
train_loader = DataLoader(
    train_dataset,
    batch_size=32,
    sampler=sampler,  # 使用自定义sampler
    num_workers=4
)

# 简单CNN模型
class SimpleCNN(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 16, 3, padding=1)
        self.conv2 = nn.Conv2d(16, 32, 3, padding=1)
        self.fc1 = nn.Linear(32*56*56, 2)
        
    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = F.max_pool2d(x, 2)
        x = F.relu(self.conv2(x))
        x = F.max_pool2d(x, 2)
        x = x.view(x.size(0), -1)
        x = self.fc1(x)
        return x

model = SimpleCNN()
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters())

# 训练循环
for epoch in range(10):
    for images, labels in train_loader:
        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
    
    print(f'Epoch {epoch+1}, Loss: {loss.item():.4f}')

5. 高级技巧与优化建议

在实际项目中,还有一些值得注意的细节:

动态调整采样策略

  • 随着训练进行,可以逐渐降低采样权重差异
  • 监控各类别的验证准确率,针对性调整
python复制# 动态权重调整示例
def dynamic_weights(current_epoch, max_epoch):
    alpha = 1 - (current_epoch / max_epoch)  # 线性衰减
    return [alpha*w + (1-alpha) for w in original_weights]

结合其他技术

  • 在损失函数中使用类别权重
  • 应用mixup或cutmix等数据增强
  • 使用focal loss处理难样本

性能优化

  • 对于极大数据集,考虑缓存权重计算
  • 使用pin_memory加速GPU数据传输
  • 适当调整num_workers提升IO效率
python复制# 优化后的DataLoader配置
train_loader = DataLoader(
    dataset,
    batch_size=64,
    sampler=sampler,
    num_workers=8,
    pin_memory=True,
    persistent_workers=True
)

6. 验证与结果分析

训练完成后,我们需要验证采样策略的效果。关键指标应包括:

  • 各类别的精确率、召回率、F1分数
  • 混淆矩阵
  • 训练过程中的loss曲线对比
python复制from sklearn.metrics import classification_report

def evaluate(model, dataloader):
    model.eval()
    all_preds = []
    all_labels = []
    
    with torch.no_grad():
        for images, labels in dataloader:
            outputs = model(images)
            _, preds = torch.max(outputs, 1)
            all_preds.extend(preds.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())
    
    print(classification_report(all_labels, all_preds, target_names=['cat', 'dog']))

在我的实验中,使用WeightedRandomSampler后,猫类的召回率从不足10%提升到了78%,而整体准确率保持85%以上。这种平衡的性能在实际应用中远比单纯的准确率数字更有价值。

内容推荐

8086微处理器:从BIU与EU的协同到现代CPU架构的基石
本文深入解析8086微处理器的革命性设计,重点探讨BIU与EU的协同工作机制及其对现代CPU架构的影响。通过分析8086的双核架构、指令预取机制和总线周期优化,揭示其如何奠定现代处理器的基础设计理念,包括流水线技术、缓存体系和并行计算。
给程序员讲线性代数:用NumPy和几何动画理解基底与线性变换
本文从程序员视角解析线性代数的核心概念,通过NumPy实现和几何动画演示基底变换与线性变换的工程应用。详细讲解如何用代码实现坐标架变换、图形变形及逆矩阵操作,揭示行列式在空间变换中的实际意义,并探讨游戏引擎和性能优化中的实用技巧,帮助开发者将抽象数学转化为可视化解决方案。
用K230开发板给AI模型拍训练集照片:一个物理按键搞定数据采集
本文详细介绍了如何利用K230开发板构建智能数据采集系统,通过物理按键一键完成AI模型训练数据集的拍摄。从硬件配置到软件实现,再到数据质量控制,提供了完整的解决方案,特别适合个人开发者和教育场景使用。
【Mermaid】从零开始:手把手教你用代码绘制专业流程图
本文详细介绍了如何使用Mermaid代码绘制专业流程图,从基础语法到实战技巧全面解析。Mermaid作为一种基于文本的图表生成工具,具有版本控制友好、修改成本低和跨平台兼容等优势,特别适合技术文档和项目设计。文章包含5分钟快速上手指南、实战案例和高级技巧,帮助读者高效掌握这一强大工具。
性能飙升60%!手把手教你用Tool.Net 3.0.0的TcpFrame构建高性能字节流服务
本文详细介绍了如何使用Tool.Net 3.0.0的TcpFrame模块构建高性能字节流服务,实现60%的性能提升。通过协议栈重构、内存管理优化和智能心跳机制,Tool.Net显著提升了数据传输效率,适用于实时通信和物联网场景。文章还提供了实战示例和进阶调优技巧,帮助开发者快速掌握这一技术。
超越基础配置:SAP QM主检验特性(MIC)的三种‘模型’实战解析(Copy/Reference/Incomplete)
本文深入解析SAP QM主检验特性(MIC)的三种模型(Copy/Reference/Incomplete Copy)及其在质量管理中的实战应用。通过对比不同模型的数据同步机制和修改权限,帮助企业在QS21中合理配置检验特性,确保质量数据的准确性和合规性,特别适用于乳制品、制药和汽车零部件等行业。
MCP2515调试笔记----SPI时序与CS引脚操作要点
本文详细解析了MCP2515调试中的SPI时序与CS引脚操作要点,揭示了初始化过程中的常见陷阱及解决方案。通过硬件设计建议、软件优化方案和完整初始化流程示例,帮助工程师避免通信失败,提升MCP2515的稳定性和可靠性。
跨平台文件同步利器:Beyond Compare实战指南
本文详细介绍了跨平台文件同步工具Beyond Compare的实战应用,涵盖安装配置、核心比较模式、远程服务器连接及自动化脚本等高级功能。通过具体案例展示如何提升文件同步效率300%,特别适合开发者和团队协作场景,有效避免手动同步导致的错误。
YOLOv8 Detect Head 核心机制:从特征图到预测框的完整解码
本文深入解析了YOLOv8 Detect Head的核心机制,详细介绍了从多尺度特征图到预测框的完整解码过程。重点探讨了特征整合、位置预测和类别判断三大任务,以及Anchor-Free网格点生成、DFL边界框解码等关键技术,帮助开发者深入理解YOLOv8物体检测的实现原理。
PX4从入门到实战(三):外部控制与指令系统深度解析
本文深入解析PX4飞控系统的外部控制与指令系统,重点介绍OFFBOARD模式的核心原理与实战应用。通过详细代码示例,展示如何实现位置控制、轨迹跟踪及COMMAND接口的全面应用,帮助开发者掌握PX4的高级控制功能,提升无人机开发效率。
别再让PD图吃灰了!手把手教你用Python实现持续同调矢量化(附代码)
本文详细介绍了如何利用Python将持续同调图(PD)转化为机器学习模型可用的特征向量,涵盖五大矢量化方法:持续性图像(PI)、持续性景观(PL)、持续同调熵(PE)、贝蒂曲线和加权轮廓。通过实战代码和性能对比,帮助读者高效处理拓扑数据,提升模型表现。
深入OpenSfM的config.yaml:调参实战指南,让你的3D重建效果提升一个档次
本文深入解析OpenSfM的config.yaml配置文件,提供从特征提取到BA优化的全流程调参实战指南。通过场景化参数调整策略,帮助用户解决3D重建中的点云稀疏、模型断裂等问题,显著提升重建质量和效率。特别针对建筑、小物体等不同场景,给出具体参数优化方案。
ASN.1编码规则解析:从BER到XER的演进与应用
本文深入解析ASN.1编码规则,从基础的BER到高效的PER再到可读性强的XER,详细介绍了各种编码规则的特点、应用场景及实际开发经验。文章通过具体案例展示了不同编码规则在网络协议、金融交易、物联网等领域的应用,帮助开发者根据需求选择合适的编码方式,提升系统性能和兼容性。
告别卡顿!用Vue 3的transition实现一个丝滑的移动端跑马灯组件
本文介绍如何利用Vue 3的transition特性实现一个高性能的移动端跑马灯组件,解决传统setInterval方案导致的卡顿问题。通过动态文本宽度适配、CSS硬件加速和内存泄漏防护等优化技巧,确保动画在低端设备上也能丝滑流畅运行。
PyJWT Subject Must Be a String: Debugging Authentication Errors in Python APIs
本文详细解析了Python API中常见的JWT认证错误'Subject must be a string',深入探讨了PyJWT库对subject字段的严格类型检查机制。通过实际案例展示了如何调试和修复Flask-JWT-Extended中的类型不匹配问题,并提供了不同数据库环境下的解决方案和防御性编程实践,帮助开发者避免类似认证错误。
优雅处理JSON反序列化:空字符串到空集合的转换策略
本文详细探讨了JSON反序列化过程中空字符串到空集合的转换策略,重点介绍了如何使用Jackson自定义反序列化器优雅处理这一常见问题。通过实现EmptyStringListDeserializer类,开发者可以灵活应对前端传递空字符串的场景,同时确保类型安全。文章还提供了多种优化方案和测试用例,帮助开发者选择最适合项目的解决方案。
SolidWorks机械臂模型转STL导入Matlab保姆级教程(含Robotics Toolbox代码)
本文提供SolidWorks机械臂模型转STL并导入Matlab的完整教程,涵盖模型预处理、STL导出参数设置、Robotics Toolbox初始化及可视化渲染等关键步骤。特别针对机械臂运动学仿真中的坐标系对齐、模型缩放等问题提供解决方案,帮助研究者高效实现3D模型在Matlab环境中的精准导入与应用。
Qt横向流式布局实战:从官方Demo到自定义增强,打造灵活UI的2种核心方案
本文深入探讨了Qt横向流式布局的两种核心实现方案:QListView方案和自定义FlowLayout方案,并对比了它们的优缺点。通过官方Demo解析和自定义增强接口实战,帮助开发者打造灵活、高效的UI布局,特别适用于标签云、工具箱等动态排列场景。
VSCode调试C++程序时,那个烦人的‘gdb32.exe’错误到底怎么破?
本文详细解析了VSCode调试C++程序时常见的'gdb32.exe'错误,提供了从快速修复到专业配置的多重解决方案。通过分析GDB版本兼容性问题、优化launch.json配置及环境变量管理,帮助开发者彻底解决这一典型问题,提升调试效率。特别适用于使用MinGW-W64和GCC工具链的C++开发者。
从仿真到联动:手把手教你用MoveIt和Gazebo搭建机械臂闭环仿真环境
本文详细介绍了如何使用MoveIt和Gazebo搭建机械臂闭环仿真环境,涵盖环境准备、控制器配置、启动文件整合及高级调试技巧。通过ROS平台实现RViz与Gazebo的联动,帮助开发者验证机械臂控制算法在物理仿真中的表现,提升开发效率并降低硬件测试风险。
已经到底了哦
精选内容
热门内容
最新内容
Linux服务器频繁报soft lockup?这10种硬件和配置问题你排查了吗
本文详细解析了Linux服务器频繁报soft lockup的10种硬件和配置问题排查方法,包括电源系统、CPU与散热系统、内存子系统等关键部件的检测与优化。通过专业工具和实用技巧,帮助硬件工程师和系统运维人员快速定位并解决kernel:NMI watchdog触发的CPU异常问题,提升服务器稳定性。
Valgrind工具在嵌入式开发中的交叉编译实践与性能优化策略
本文详细介绍了Valgrind工具在嵌入式开发中的交叉编译实践与性能优化策略。通过ARM架构下的交叉编译实战,解决glibc符号问题,并提供编译期和运行时的优化参数,显著提升工具在资源受限环境中的效率。文章还分享了嵌入式场景下的诊断技巧和高级内存问题定位方法,帮助开发者更高效地使用Valgrind进行内存调试和性能优化。
RS485电路设计实战:从模块解析到工业场景可靠性保障
本文深入探讨了RS485电路设计在工业自动化领域的实战应用,从模块解析到工业场景可靠性保障。通过分析电磁兼容性、环境应力等工业特殊因素,详细介绍了信号隔离、ESD保护等核心电路设计要点,并提供了长距离传输、多节点网络优化等解决方案,助力工程师实现高可靠性的工业通讯系统。
零成本搞定日语视频字幕:从识别、处理到翻译的全链路实践
本文详细介绍了一套零成本制作日语视频字幕的全链路方案,涵盖语音识别、字幕优化和翻译三个核心环节。通过autosub、SrtEdit和字幕组机翻小助手等免费工具的组合使用,即使没有编程基础的用户也能轻松完成从日语识别到中文字幕生成的全流程,处理一小时视频仅需20-30分钟。
Android TV一键播放功能实战:HDMI-CEC底层实现与避坑指南
本文深入解析Android TV中HDMI-CEC协议的底层实现与一键播放功能开发实战。从HdmiControlService框架层到HAL硬件抽象层,详细剖析跨设备兼容性解决方案,并提供MTK平台特殊处理、多设备兼容性矩阵等实用技巧,帮助开发者规避常见陷阱,优化CEC响应性能。
从原理到实战:LCMV与GSC波束形成算法对比及MATLAB仿真全解析
本文深入解析了LCMV与GSC两种波束形成算法的原理、MATLAB实现及工程应用对比。通过详细的数学推导和仿真示例,展示了LCMV在多约束场景下的精确控制能力,以及GSC在实时自适应处理中的优势。针对算法选择、性能优化和工程实践中的常见问题提供了实用解决方案,为信号处理工程师提供了宝贵的参考指南。
从零搭建一辆ROS小车(四)激光雷达SLAM实战:Hector与Gmapping对比
本文详细对比了Hector与Gmapping两种激光雷达SLAM算法在ROS小车建图中的应用。Hector_mapping以零配置快速建图见长,适合资源有限或手持建图场景;而slam_gmapping则依赖里程计但精度更高,适合精细建图需求。通过实战参数配置和RPLIDAR实测数据,为开发者提供算法选型建议和优化技巧。
内网环境K8s部署Harbor避坑指南:从Helm Chart下载到Ceph S3存储配置全流程
本文详细介绍了在内网Kubernetes环境中部署Harbor镜像仓库的全流程,包括Helm Chart离线安装、Ceph S3存储配置等关键步骤。针对金融行业核心系统容器化改造场景,提供了从资源筹备到高可用架构设计的实战经验,帮助用户规避常见部署陷阱,实现稳定高效的企业级镜像管理。
Linux设备文件传输新思路:巧用telnet登录与busybox ftpget/ftpput
本文介绍了在Linux设备上通过telnet登录结合busybox的ftpget/ftpput工具实现高效文件传输的创新方法。针对嵌入式系统资源有限、网络环境封闭等场景,详细讲解了从telnet连接到文件上传下载的完整流程,包括实用技巧、常见问题解决方案及安全注意事项,为远程文件传输提供了轻量级替代方案。
ESP32-S3与ST7789屏幕的完美结合:1.3寸屏驱动实战指南
本文详细介绍了ESP32-S3与ST7789屏幕的硬件搭配与驱动开发实战指南。通过SPI通信优化、开发环境搭建避坑、TFT_eSPI库深度配置等技巧,帮助开发者快速实现1.3寸屏的高效驱动,适用于智能穿戴、便携设备等场景。