1. 为什么需要自定义数据集加载
在深度学习项目中,数据就像燃料一样重要。但现实世界的数据往往不像MNIST或CIFAR-10那样整齐划一。我遇到过太多这样的情况:客户发来的图像分散在几十个文件夹里,医疗数据需要特殊预处理,工业检测图片的命名毫无规律...这就是为什么PyTorch的Dataset类如此重要。
Dataset类就像是一个智能的数据管家,它能帮你:
- 统一处理各种"非标准"数据格式
- 在训练过程中动态进行数据增强
- 实现高效的内存管理(特别是处理大型数据集时)
- 与DataLoader配合实现多进程数据加载
2. 理解Dataset类的核心机制
2.1 Dataset类的三大必备方法
每个自定义Dataset必须继承torch.utils.data.Dataset并实现三个核心方法:
python复制class CustomDataset(Dataset):
def __init__(self, ...):
# 初始化:读取元数据、定义转换规则等
pass
def __len__(self):
# 返回数据集总样本数
return len(self.samples)
def __getitem__(self, idx):
# 根据索引返回单个样本(数据+标签)
return sample, label
注意:
__getitem__必须返回一个样本的数据和标签元组,这是PyTorch的约定俗成
2.2 数据加载的工作流程
当DataLoader请求数据时,背后发生了这些事:
- DataLoader决定要获取哪些索引(考虑batch_size, shuffle等)
- 对每个索引调用Dataset的
__getitem__ - 将返回的样本堆叠成batch张量
- 返回batch给训练循环
3. 实战:构建图像分类数据集
3.1 处理文件夹结构的图像数据
假设我们有如下目录结构:
code复制data/
class1/
img1.jpg
img2.jpg
class2/
img1.jpg
...
实现方案:
python复制from PIL import Image
import os
class ImageFolderDataset(Dataset):
def __init__(self, root_dir, transform=None):
self.root_dir = root_dir
self.transform = transform
self.classes = os.listdir(root_dir)
self.class_to_idx = {c:i for i,c in enumerate(self.classes)}
# 收集所有图像路径和对应标签
self.samples = []
for class_name in self.classes:
class_dir = os.path.join(root_dir, class_name)
for img_name in os.listdir(class_dir):
self.samples.append((
os.path.join(class_dir, img_name),
self.class_to_idx[class_name]
))
def __len__(self):
return len(self.samples)
def __getitem__(self, idx):
img_path, label = self.samples[idx]
image = Image.open(img_path).convert('RGB')
if self.transform:
image = self.transform(image)
return image, label
3.2 添加数据增强
使用torchvision.transforms组合各种增强:
python复制from torchvision import transforms
train_transform = transforms.Compose([
transforms.RandomResizedCrop(224),
transforms.RandomHorizontalFlip(),
transforms.ColorJitter(brightness=0.2, contrast=0.2),
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225])
])
val_transform = transforms.Compose([
transforms.Resize(256),
transforms.CenterCrop(224),
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225])
])
4. 高级技巧与性能优化
4.1 懒加载 vs 预加载
懒加载(推荐大型数据集):
__init__只存储文件路径__getitem__时才读取文件- 节省内存但IO压力大
预加载(适合小数据集):
__init__时将所有数据读入内存__getitem__直接返回- 内存占用高但训练速度快
4.2 使用内存映射文件
对于超大型数组数据(如numpy格式):
python复制class MemmapDataset(Dataset):
def __init__(self, npz_path):
self.data = np.load(npz_path, mmap_mode='r')
def __getitem__(self, idx):
return self.data[idx]
4.3 多模态数据加载
处理同时包含图像和文本的数据:
python复制class MultiModalDataset(Dataset):
def __init__(self, img_dir, text_path, transform=None):
self.img_dir = img_dir
self.texts = self._load_texts(text_path)
self.transform = transform
def __getitem__(self, idx):
img = Image.open(os.path.join(self.img_dir, f"{idx}.jpg"))
text = self.texts[idx]
if self.transform:
img = self.transform(img)
return {"image": img, "text": text}
5. 常见问题与解决方案
5.1 内存泄漏排查
症状:训练时间越来越长,内存持续增长
可能原因:
- 在
__getitem__中意外保留了引用 - 没有正确关闭文件句柄
解决方案:
python复制def __getitem__(self, idx):
with Image.open(self.paths[idx]) as img: # 使用上下文管理器
return self.transform(img), self.labels[idx]
5.2 多进程加载问题
当num_workers>0时注意:
- Dataset必须能在子进程中正确初始化
- 避免使用不能pickle的对象
- Linux上性能更好(Windows多进程实现不同)
5.3 数据不均衡处理
方案1:在Dataset中实现加权采样
python复制class WeightedDataset(Dataset):
def __init__(self, ...):
# 计算每个类的权重
class_counts = np.bincount(labels)
self.weights = 1. / class_counts[labels]
def get_weights(self):
return self.weights
然后在DataLoader中使用:
python复制weights = dataset.get_weights()
sampler = WeightedRandomSampler(weights, len(weights))
dataloader = DataLoader(dataset, batch_size=32, sampler=sampler)
6. 与DataLoader的配合技巧
6.1 关键参数解析
python复制DataLoader(
dataset,
batch_size=32,
shuffle=True, # 训练集通常设为True
num_workers=4, # 根据CPU核心数调整
pin_memory=True, # GPU训练时建议开启
drop_last=False # 是否丢弃最后不足batch_size的样本
)
6.2 自定义collate_fn
处理不等长序列或特殊数据结构:
python复制def collate_fn(batch):
# batch是__getitem__返回的样本列表
images = torch.stack([item[0] for item in batch])
labels = torch.tensor([item[1] for item in batch])
return images, labels
6.3 预取加速技巧
python复制class PrefetchLoader:
def __init__(self, loader):
self.loader = loader
self.stream = torch.cuda.Stream()
def __iter__(self):
for batch in self.loader:
with torch.cuda.stream(self.stream):
batch = [b.cuda(non_blocking=True) for b in batch]
yield batch
7. 真实项目经验分享
在医疗影像项目中,我遇到过这些实际挑战:
- DICOM格式处理:
python复制import pydicom
def __getitem__(self, idx):
dicom = pydicom.dcmread(self.paths[idx])
img = dicom.pixel_array.astype(np.float32)
img = (img - img.min()) / (img.max() - img.min()) # 归一化
return torch.from_numpy(img), self.labels[idx]
- 超大图像分块加载:
python复制class PatchDataset(Dataset):
def __getitem__(self, idx):
# 计算原始图像中的位置
img_idx = idx // self.patches_per_image
patch_idx = idx % self.patches_per_image
with Image.open(self.paths[img_idx]) as img:
# 计算patch坐标并裁剪
left = ... # 根据patch_idx计算
patch = img.crop((left, top, left+size, top+size))
return self.transform(patch)
- 数据版本控制:
- 在
__init__中记录数据哈希值 - 保存预处理参数的完整配置
- 使用
@property实现动态数据过滤
python复制class VersionedDataset(Dataset):
def __init__(self, ...):
self.config = {
'data_hash': self._compute_hash(),
'transform': str(transform),
'filter_cond': filter_cond
}
@property
def filtered_indices(self):
return [i for i in range(len(self)) if self._filter(i)]