第一次接触PyTorch和CIFAR-10数据集时,我花了不少时间在环境配置上。建议直接使用Anaconda创建Python 3.8+的虚拟环境,这样能避免很多依赖冲突问题。安装PyTorch时,记得根据你的显卡选择CUDA版本,没有N卡的同学用CPU版本也能跑,只是训练速度会慢些。
CIFAR-10数据集包含6万张32x32的彩色图片,分为10个类别。我第一次看到这些迷你图片时,差点以为下载错了——飞机、汽车都像马赛克一样模糊。但正是这种小尺寸让它成为入门练手的绝佳选择,毕竟不是谁都有显卡能处理高清大图。
用PyTorch加载数据特别简单,几行代码就能搞定:
python复制import torchvision
transform = torchvision.transforms.Compose([
torchvision.transforms.ToTensor(),
torchvision.transforms.Normalize((0.5,0.5,0.5), (0.5,0.5,0.5))
])
train_data = torchvision.datasets.CIFAR10(
root='./data',
train=True,
download=True,
transform=transform
)
这里有个坑我踩过:忘记做数据标准化(Normalize)。刚开始训练时损失值死活不下降,后来发现是像素值(0-255)没归一化到(-1,1)区间。建议新手务必加上这个转换,能显著提升训练稳定性。
设计CNN结构就像搭积木,但第一次我堆得太复杂,结果模型根本训不动。后来发现对于CIFAR-10这种小图片,3-4个卷积层就够了。下面这个结构是我调试多次后比较稳定的版本:
python复制import torch.nn as nn
class SimpleCNN(nn.Module):
def __init__(self):
super().__init__()
self.conv1 = nn.Conv2d(3, 32, 3, padding=1)
self.pool = nn.MaxPool2d(2, 2)
self.conv2 = nn.Conv2d(32, 64, 3, padding=1)
self.fc1 = nn.Linear(64 * 8 * 8, 256)
self.fc2 = nn.Linear(256, 10)
def forward(self, x):
x = self.pool(nn.functional.relu(self.conv1(x)))
x = self.pool(nn.functional.relu(self.conv2(x)))
x = torch.flatten(x, 1)
x = nn.functional.relu(self.fc1(x))
x = self.fc2(x)
return x
几个关键点:
第一次我忘了加flatten操作,导致全连接层报维度错误,debug了半小时才找到问题。建议新手每写完一个模块就打印下张量形状,比如在forward里加print(x.shape)。
训练代码看着简单,但调参才是真正的玄学。这是我总结的黄金参数组合:
python复制model = SimpleCNN().to(device)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.1)
关键技巧:
训练循环中最重要的就是记录损失和准确率:
python复制for epoch in range(20):
running_loss = 0.0
for i, data in enumerate(trainloader):
inputs, labels = data
optimizer.zero_grad()
outputs = model(inputs)
loss = criterion(outputs, labels)
loss.backward()
optimizer.step()
running_loss += loss.item()
scheduler.step()
print(f'Epoch {epoch+1} loss: {running_loss/len(trainloader):.3f}')
我习惯用tensorboard记录训练曲线,这样能直观看到模型是否在收敛。如果发现损失值震荡太大,可以尝试减小学习率或增大批量大小。
训练完别急着高兴,测试集才是试金石。我第一次跑出来的测试准确率只有58%,比瞎猜好不了多少。通过以下改进逐步提升到了75%+:
python复制transform_train = transforms.Compose([
transforms.RandomHorizontalFlip(),
transforms.RandomCrop(32, padding=4),
transforms.ToTensor(),
transforms.Normalize(...)
])
最终我的测试代码长这样:
python复制correct = 0
total = 0
with torch.no_grad():
for data in testloader:
images, labels = data
outputs = model(images)
_, predicted = torch.max(outputs.data, 1)
total += labels.size(0)
correct += (predicted == labels).sum().item()
print(f'Accuracy: {100 * correct / total:.2f}%')
如果准确率卡在某个值上不去,可能是模型容量不够或数据不够多样。这时可以考虑换更复杂的架构如ResNet,但要注意CIFAR-10图片太小,直接套用大模型可能适得其反。
保存和加载模型时要注意设备兼容性:
python复制# 保存
torch.save(model.state_dict(), 'cifar10_cnn.pth')
# 加载
model = SimpleCNN()
model.load_state_dict(torch.load('cifar10_cnn.pth'))
最后分享一个可视化卷积核的小技巧:
python复制import matplotlib.pyplot as plt
weights = model.conv1.weight.detach()
fig, axs = plt.subplots(4, 8, figsize=(12,6))
for idx, ax in enumerate(axs.flat):
ax.imshow(weights[idx].permute(1,2,0))
ax.axis('off')
这能帮你理解模型到底学到了哪些特征。刚开始我的卷积核都是噪声,训练后才出现有意义的边缘检测器。