1. PyTorch深度学习框架概述
PyTorch作为当前最受欢迎的深度学习框架之一,其设计理念和架构特点使其在学术界和工业界都获得了广泛应用。让我们从框架的核心特性开始,逐步深入理解PyTorch的强大之处。
1.1 PyTorch的核心设计哲学
PyTorch的设计遵循了几个关键原则,这些原则使其在众多深度学习框架中脱颖而出:
-
Python优先:PyTorch完全拥抱Python生态系统,API设计符合Python编程习惯,使得开发者能够用熟悉的Python语法进行深度学习开发。这种设计降低了学习曲线,也让调试变得更加直观。
-
动态计算图(Define-by-Run):与静态图框架不同,PyTorch的计算图是在代码执行过程中动态构建的。这意味着你可以像编写普通Python程序一样编写神经网络,使用常规的控制流语句(如if条件、for循环),而无需预先定义完整的计算图。
-
即时执行(Eager Execution):PyTorch采用即时执行模式,代码即执行,无需预编译。这种模式提供了更好的调试体验,开发者可以使用标准的Python调试工具(如pdb)逐步执行代码。
-
模块化设计:PyTorch将神经网络组件(nn.Module)、优化器、数据加载器等核心功能分离设计,使得代码复用性更强,也更容易扩展。
1.2 PyTorch 2.3+的新特性
PyTorch 2.x系列带来了显著的性能提升和新功能,以下是2.3+版本的核心特性:
| 特性 | 版本引入 | 功能描述 | 性能提升 |
|---|---|---|---|
| torch.compile | 2.0+ | 图编译优化,自动融合算子 | 1.5-2倍 |
| SDPA | 2.0+ | 统一注意力接口,自动选择最优实现 | 2-4倍(Transformer) |
| torch.export | 2.1+ | 模型导出为可移植格式 | 部署友好 |
| Compile Optimizer | 2.2+ | 优化器编译加速 | 1.2-1.5倍 |
| Custom Operators | 2.3+ | 增强自定义算子支持 | 灵活性提升 |
这些新特性使得PyTorch在保持动态图灵活性的同时,也能获得接近静态图框架的性能表现。
1.3 PyTorch安装指南
安装PyTorch非常简单,官方提供了多种安装选项:
bash复制# CPU版本基础安装
pip install torch torchvision torchaudio
# CUDA 12.1 GPU版本
pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu121
# CUDA 11.8 GPU版本
pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118
安装完成后,可以通过以下命令验证安装:
python复制import torch
print(f'PyTorch版本: {torch.__version__}')
print(f'CUDA可用: {torch.cuda.is_available()}')
提示:选择GPU版本时,请确保CUDA驱动版本与PyTorch要求的版本匹配。可以使用
nvidia-smi命令查看当前CUDA版本。
2. PyTorch与TensorFlow深度对比
2.1 核心架构差异
PyTorch和TensorFlow在计算图机制上有本质区别:
-
PyTorch(动态图/Eager模式):
- 代码执行时动态构建计算图
- 每轮迭代可以改变图结构
- 调试如同普通Python代码
- 开发效率高,适合研究和原型开发
-
TensorFlow(静态图/Graph模式):
- 需要先定义完整计算图
- 图结构固定后执行
- 需要特殊工具调试
- 部署性能优化空间大
2.2 全面功能对比
| 对比维度 | PyTorch | TensorFlow | 说明 |
|---|---|---|---|
| 计算图 | 动态图 | 静态图 | PyTorch 2.x支持编译优化 |
| 调试体验 | 原生Python调试 | 需要特殊工具 | PyTorch更直观 |
| 学习曲线 | 平缓 | 较陡峭 | PyTorch API更Pythonic |
| 学术界 | 主导地位 | 广泛使用 | 70%+论文使用PyTorch |
| 工业界 | 逐渐普及 | 成熟稳定 | TensorFlow部署生态更完善 |
| 生产部署 | TorchServe, ONNX | TensorFlow Serving | TensorFlow工具更丰富 |
| 可视化 | TensorBoard/wandb | TensorBoard | 两者都支持TensorBoard |
| 移动端 | PyTorch Mobile | TensorFlow Lite | TFLite生态更成熟 |
2.3 框架选型建议
选择PyTorch的场景:
- 学术研究和论文复现
- 需要动态网络结构(如可变长度序列处理)
- 快速原型开发
- 复杂模型调试
- Transformer架构开发(Hugging Face生态)
选择TensorFlow的场景:
- 大规模生产部署
- 移动端/嵌入式设备应用
- 需要完整MLOps工具链
- 企业级机器学习平台
- 与Google Cloud深度集成
经验分享:在实际项目中,研究阶段通常使用PyTorch,而部署阶段可能会转换为TensorFlow。但随着PyTorch 2.x的编译优化和部署工具完善,这种分工正在变得模糊。
3. 动态计算图原理深度解析
3.1 计算图基础概念
计算图是深度学习框架的核心抽象,用于表示数学运算的数据流。例如,对于数学表达式:
code复制z = (x + y) * w
对应的计算图表示为:
code复制x ──┐
├──→ [+] ──→ [*] ──→ z
y ──┘ ↑
w ─────────┘
- 节点(Node):表示运算操作(如加法、乘法)
- 边(Edge):表示数据依赖关系(张量流动)
3.2 动态图与静态图对比
| 特性 | 动态图(PyTorch) | 静态图(TensorFlow 1.x) |
|---|---|---|
| 构建时机 | 运行时即时构建 | 预定义后执行 |
| 图结构 | 每轮可变化 | 固定不变 |
| 调试 | 原生Python调试 | 需要会话执行 |
| 优化空间 | 有限(2.x改进) | 全局优化潜力大 |
| 控制流 | Python原生if/for | 特殊控制流算子 |
3.3 PyTorch自动微分实现
PyTorch的自动微分(Autograd)系统是其核心功能之一。让我们通过一个具体例子来理解其工作原理:
python复制import torch
# 创建需要计算梯度的张量
x = torch.tensor(2.0, requires_grad=True)
w = torch.tensor(3.0, requires_grad=True)
b = torch.tensor(1.0, requires_grad=True)
# 前向传播 - 动态构建计算图
u = x * w # 乘法节点
v = u + b # 加法节点
y = v ** 2 # 幂运算节点
# 计算图结构:
# x(2.0) ──[*]──→ u(6.0) ──[+]──→ v(7.0) ──[**2]──→ y(49.0)
# ↑ ↑
# w(3.0)───┘ b(1.0)───┘
# 反向传播 - 自动计算梯度
y.backward()
# 查看梯度
print(f"dy/dx = {x.grad}") # 2*v*w = 2*7*3 = 42
print(f"dy/dw = {w.grad}") # 2*v*x = 2*7*2 = 28
print(f"dy/db = {b.grad}") # 2*v = 2*7 = 14
梯度计算遵循链式法则:
code复制y = v², v = u + b, u = x * w
∂y/∂x = ∂y/∂v * ∂v/∂u * ∂u/∂x
= 2v * 1 * w
= 2 * 7 * 3 = 42
∂y/∂w = ∂y/∂v * ∂v/∂u * ∂u/∂w
= 2v * 1 * x
= 2 * 7 * 2 = 28
∂y/∂b = ∂y/∂v * ∂v/∂b
= 2v * 1
= 2 * 7 = 14
3.4 计算图的生命周期
-
前向传播构建图:
- 输入张量(叶子节点)
- 通过Function节点记录操作
- 输出张量保存grad_fn引用
-
反向传播遍历图:
- 从损失张量开始
- 调用grad_fn.backward()
- 递归传播到叶子节点
3.5 动态图的优势示例
动态RNN处理变长序列:
python复制def dynamic_rnn(inputs, hidden_size):
"""动态RNN - 每轮迭代图结构不同"""
batch_size, seq_len, feat_dim = inputs.shape
hidden = torch.zeros(batch_size, hidden_size)
outputs = []
# 根据实际序列长度动态展开
for t in range(seq_len):
# 每轮循环都创建新的计算图节点
hidden = torch.tanh(
inputs[:, t, :] @ W_ih + hidden @ W_hh + b
)
outputs.append(hidden)
return torch.stack(outputs)
条件控制流示例:
python复制class DynamicNet(torch.nn.Module):
def forward(self, x):
# 根据输入动态决定网络路径
if x.sum() > 0:
return self.branch_a(x)
else:
return self.branch_b(x)
# 可以在前向中插入调试语句
print(f"中间值: {x.mean()}")
import pdb; pdb.set_trace()
注意事项:动态图的灵活性带来了调试便利,但也可能导致性能开销。PyTorch 2.x的torch.compile可以在保持动态图API的同时,通过编译优化获得更好的性能。
4. PyTorch基础:张量与自动微分
4.1 张量操作大全
张量(Tensor)是PyTorch中最基本的数据结构,类似于NumPy的多维数组,但支持GPU加速和自动微分。以下是常用张量操作:
python复制import torch
import numpy as np
# ===== 张量创建 =====
# 从数据创建
scalar = torch.tensor(3.14) # 标量
vector = torch.tensor([1, 2, 3]) # 向量
matrix = torch.tensor([[1, 2], [3, 4]]) # 矩阵
# 特殊张量
zeros = torch.zeros(3, 3) # 全零
ones = torch.ones(2, 2) # 全一
random = torch.rand(3, 3) # 均匀分布
normal = torch.randn(3, 3) # 标准正态分布
eye = torch.eye(3) # 单位矩阵
# ===== 张量属性 =====
t = torch.randn(2, 3, 4)
print(f"形状: {t.shape}") # torch.Size([2, 3, 4])
print(f"数据类型: {t.dtype}") # torch.float32
print(f"设备: {t.device}") # cpu/cuda
# ===== 张量运算 =====
a = torch.tensor([[1, 2], [3, 4]], dtype=torch.float32)
b = torch.tensor([[5, 6], [7, 8]], dtype=torch.float32)
# 元素级运算
c = a + b # 加法
c = a * b # 元素乘法
c = a / b # 除法
# 矩阵运算
c = a @ b # 矩阵乘法
c = torch.matmul(a, b) # 同上
# 维度操作
c = a.t() # 转置
c = a.reshape(1, 4) # 变形
c = a.squeeze() # 去除大小为1的维度
4.2 GPU加速与设备管理
PyTorch提供了简单的API来管理GPU设备:
python复制# 检查GPU可用性
print(f"CUDA可用: {torch.cuda.is_available()}")
print(f"GPU数量: {torch.cuda.device_count()}")
if torch.cuda.is_available():
print(f"当前GPU: {torch.cuda.get_device_name(0)}")
# 设备选择策略
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# 张量设备转移
x = torch.rand(3, 3)
x_gpu = x.to(device) # 转移到GPU
x_cpu = x_gpu.cpu() # 转回CPU
# GPU内存管理
print(f"已分配内存: {torch.cuda.memory_allocated() / 1e9:.2f} GB")
print(f"预留内存: {torch.cuda.memory_reserved() / 1e9:.2f} GB")
torch.cuda.empty_cache() # 清空缓存
实操技巧:使用
.to(device)而不是直接调用.cuda(),这样代码可以兼容CPU和GPU环境。对于大型模型,可以使用pin_memory=True加速数据加载。
4.3 自动微分实战
PyTorch的自动微分系统(Autograd)是其核心功能之一:
python复制# ===== 基础梯度计算 =====
x = torch.tensor(3.0, requires_grad=True)
y = x**2 + 2*x + 1
y.backward()
print(f"dy/dx at x=3: {x.grad}") # 2*3 + 2 = 8
# ===== 多变量梯度 =====
x = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)
y = torch.sum(x**2) # y = x1² + x2² + x3²
y.backward()
print(f"梯度: {x.grad}") # [2, 4, 6]
# ===== 高阶导数 =====
x = torch.tensor(2.0, requires_grad=True)
y = x**3
# 一阶导数
dy_dx = torch.autograd.grad(y, x, create_graph=True)[0]
print(f"一阶导数: {dy_dx}") # 3*x² = 12
# 二阶导数
d2y_dx2 = torch.autograd.grad(dy_dx, x)[0]
print(f"二阶导数: {d2y_dx2}") # 6*x = 12
# ===== 梯度管理 =====
W = torch.randn(3, 3, requires_grad=True)
optimizer = torch.optim.SGD([W], lr=0.01)
for i in range(3):
loss = (W**2).sum()
# 梯度清零(重要!)
optimizer.zero_grad()
loss.backward()
optimizer.step()
# ===== 禁用梯度 =====
with torch.no_grad():
y = model(x) # 不计算梯度
y_detached = y.detach() # 创建无梯度副本
常见问题:忘记调用
zero_grad()是导致训练不收敛的常见原因之一。在每次反向传播前,务必清空梯度。
5. 构建神经网络模型
5.1 使用nn.Module构建全连接网络
PyTorch提供了nn.Module基类来构建神经网络,这是构建模型的推荐方式:
python复制import torch.nn as nn
import torch.nn.functional as F
class NeuralNetwork(nn.Module):
"""全连接神经网络示例"""
def __init__(self, input_size=784, hidden_size=256, num_classes=10):
super(NeuralNetwork, self).__init__()
# 定义网络层
self.fc1 = nn.Linear(input_size, hidden_size)
self.bn1 = nn.BatchNorm1d(hidden_size)
self.dropout = nn.Dropout(0.2)
self.fc2 = nn.Linear(hidden_size, hidden_size)
self.bn2 = nn.BatchNorm1d(hidden_size)
self.fc3 = nn.Linear(hidden_size, num_classes)
# 初始化权重
self._initialize_weights()
def _initialize_weights(self):
"""He初始化"""
for m in self.modules():
if isinstance(m, nn.Linear):
nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
if m.bias is not None:
nn.init.constant_(m.bias, 0)
def forward(self, x):
# 展平输入
x = x.view(x.size(0), -1)
# 第一层
x = self.fc1(x)
x = self.bn1(x)
x = F.relu(x)
x = self.dropout(x)
# 第二层
x = self.fc2(x)
x = self.bn2(x)
x = F.relu(x)
x = self.dropout(x)
# 输出层
x = self.fc3(x)
return x
# 创建模型实例
model = NeuralNetwork(input_size=784, hidden_size=256, num_classes=10)
# 查看模型结构
print(model)
# 统计参数量
total_params = sum(p.numel() for p in model.parameters())
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"总参数量: {total_params:,}")
print(f"可训练参数量: {trainable_params:,}")
5.2 构建卷积神经网络(CNN)
对于图像任务,卷积神经网络(CNN)通常是更好的选择:
python复制class ConvNet(nn.Module):
"""卷积神经网络用于图像分类"""
def __init__(self, num_classes=10, in_channels=1):
super(ConvNet, self).__init__()
# 第一个卷积块
self.conv1 = nn.Sequential(
nn.Conv2d(in_channels, 32, kernel_size=3, padding=1),
nn.BatchNorm2d(32),
nn.ReLU(inplace=True),
nn.Conv2d(32, 32, kernel_size=3, padding=1),
nn.BatchNorm2d(32),
nn.ReLU(inplace=True),
nn.MaxPool2d(kernel_size=2, stride=2),
nn.Dropout2d(0.25)
)
# 第二个卷积块
self.conv2 = nn.Sequential(
nn.Conv2d(32, 64, kernel_size=3, padding=1),
nn.BatchNorm2d(64),
nn.ReLU(inplace=True),
nn.Conv2d(64, 64, kernel_size=3, padding=1),
nn.BatchNorm2d(64),
nn.ReLU(inplace=True),
nn.MaxPool2d(kernel_size=2, stride=2),
nn.Dropout2d(0.25)
)
# 第三个卷积块
self.conv3 = nn.Sequential(
nn.Conv2d(64, 128, kernel_size=3, padding=1),
nn.BatchNorm2d(128),
nn.ReLU(inplace=True),
nn.MaxPool2d(kernel_size=2, stride=2),
nn.Dropout2d(0.25)
)
# 全连接层
self.fc = nn.Sequential(
nn.Linear(128 * 3 * 3, 128),
nn.BatchNorm1d(128),
nn.ReLU(inplace=True),
nn.Dropout(0.5),
nn.Linear(128, num_classes)
)
def forward(self, x):
x = self.conv1(x)
x = self.conv2(x)
x = self.conv3(x)
# 展平
x = x.view(x.size(0), -1)
x = self.fc(x)
return x
# 创建模型实例
model = ConvNet(num_classes=10, in_channels=1)
设计技巧:使用
nn.Sequential组织网络层可以使代码更清晰。对于复杂网络,合理使用批归一化(BatchNorm)和Dropout可以显著提高模型性能。
5.3 使用SDPA实现注意力机制
PyTorch 2.0+引入了优化的注意力实现(SDPA),可以自动选择最优实现方式:
python复制import torch.nn.functional as F
class AttentionBlock(nn.Module):
"""使用PyTorch 2.0+ SDPA的注意力模块"""
def __init__(self, embed_dim, num_heads, dropout=0.1):
super().__init__()
self.num_heads = num_heads
self.head_dim = embed_dim // num_heads
self.scale = self.head_dim ** -0.5
self.qkv = nn.Linear(embed_dim, embed_dim * 3)
self.proj = nn.Linear(embed_dim, embed_dim)
self.dropout = nn.Dropout(dropout)
def forward(self, x, mask=None):
B, N, C = x.shape
# 生成Q, K, V
qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
q, k, v = qkv[0], qkv[1], qkv[2]
# 使用SDPA (自动选择最优实现)
x = F.scaled_dot_product_attention(
q, k, v,
attn_mask=mask,
dropout_p=self.dropout.p if self.training else 0.0,
is_causal=False
)
x = x.transpose(1, 2).reshape(B, N, C)
x = self.proj(x)
x = self.dropout(x)
return x
SDPA会自动根据硬件和输入大小选择最优实现:
| 实现方式 | 内存复杂度 | 速度 | 适用场景 |
|---|---|---|---|
| 原生实现 | O(N²) | 基准 | 通用 |
| Flash Attention | O(N) | 最快 | CUDA, 长序列 |
| Memory Efficient | O(N) | 快 | 显存受限 |
| Math (默认) | O(N²) | 一般 | CPU/通用 |
性能提示:对于Transformer类模型,使用SDPA可以显著减少内存占用并提高训练速度,特别是在长序列情况下。
6. PyTorch 2.3+编译优化(torch.compile)
6.1 torch.compile工作原理
torch.compile是PyTorch 2.0引入的核心特性,它通过将动态图转换为优化的静态图来提高性能:
code复制PyTorch代码 → 图捕获 → 中间表示(IR) → 优化 → 后端代码生成
↓ ↓ ↓ ↓ ↓
Eager Mode Dynamo FX Graph 优化pass Triton
(前端) (中间层) (融合等) (GPU)
主要优化策略包括:
- 算子融合(Operator Fusion)
- 内存规划(Memory Planning)
- 布局优化(Layout Optimization)
- 常量传播(Constant Propagation)
6.2 使用torch.compile
python复制import torch
# 基础用法 - 一行代码加速模型
model = ConvNet(num_classes=10).cuda()
compiled_model = torch.compile(model)
# 完整配置选项
compiled_model = torch.compile(
model,
mode="default", # default/reduce-overhead/max-autotune
fullgraph=False, # 是否要求完整图捕获
dynamic=False, # 是否支持动态形状
)
# 编译模式对比
"""
| 模式 | 编译时间 | 运行时性能 | 适用场景 |
|:---|:---|:---|:---|
| default | 中等 | 好 | 通用 |
| reduce-overhead | 短 | 较好 | 小模型/多小批次 |
| max-autotune | 长 | 最佳 | 大模型/长训练 |
"""
# 编译优化器 (PyTorch 2.2+)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
compiled_optimizer = torch.compile(optimizer, mode="reduce-overhead")
6.3 编译性能对比
python复制import time
def benchmark(model, input_tensor, num_iterations=100):
"""基准测试函数"""
# 预热
for _ in range(10):
_ = model(input_tensor)
# 同步GPU
if torch.cuda.is_available():
torch.cuda.synchronize()
# 计时
start = time.time()
for _ in range(num_iterations):
output = model(input_tensor)
if torch.cuda.is_available():
torch.cuda.synchronize()
end = time.time()
return (end - start) / num_iterations * 1000 # ms
# 创建测试模型和数据
model = ConvNet(num_classes=10).cuda()
dummy_input = torch.randn(64, 1, 28, 28).cuda()
# 测试Eager模式
eager_time = benchmark(model, dummy_input)
print(f"Eager模式: {eager_time:.2f} ms/iter")
# 测试Compiled模式
compiled_model = torch.compile(model, mode="max-autotune")
compiled_time = benchmark(compiled_model, dummy_input)
print(f"Compiled模式: {compiled_time:.2f} ms/iter")
print(f"加速比: {eager_time / compiled_time:.2f}x")
实测数据:在ResNet50上,torch.compile通常可以获得1.5-2倍的训练速度提升,内存占用也能减少10-20%。
7. 数据加载与预处理
7.1 使用DataLoader加载数据
PyTorch提供了DataLoader类来高效加载和预处理数据:
python复制from torch.utils.data import DataLoader
from torchvision import datasets, transforms
# 数据变换
train_transform = transforms.Compose([
transforms.RandomRotation(10),
transforms.RandomAffine(degrees=0, translate=(0.1, 0.1)),
transforms.ToTensor(),
transforms.Normalize((0.1307,), (0.3081,)) # MNIST均值和标准差
])
test_transform = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.1307,), (0.3081,))
])
# 加载数据集
train_dataset = datasets.MNIST(
root='./data',
train=True,
download=True,
transform=train_transform
)
test_dataset = datasets.MNIST(
root='./data',
train=False,
download=True,
transform=test_transform
)
# 创建DataLoader
train_loader = DataLoader(
train_dataset,
batch_size=64,
shuffle=True,
num_workers=4, # 多进程数据加载
pin_memory=True, # 加速GPU数据传输
persistent_workers=True # 保持worker进程
)
test_loader = DataLoader(
test_dataset,
batch_size=64,
shuffle=False,
num_workers=4,
pin_memory=True
)
7.2 自定义数据集实现
对于非标准数据集,可以继承Dataset类实现自定义数据加载:
python复制from torch.utils.data import Dataset
from PIL import Image
class CustomImageDataset(Dataset):
"""自定义图像数据集"""
def __init__(self, image_paths, labels, transform=None):
self.image_paths = image_paths
self.labels = labels
self.transform = transform
def __len__(self):
return len(self.image_paths)
def __getitem__(self, idx):
# 加载图像
image = Image.open(self.image_paths[idx]).convert('RGB')
label = self.labels[idx]
# 应用变换
if self.transform:
image = self.transform(image)
return image, label
对于类别不平衡的数据集,可以实现平衡采样器:
python复制import numpy as np
class BalancedBatchSampler(torch.utils.data.Sampler):
"""类别平衡采样器"""
def __init__(self, dataset, labels, n_samples_per_class):
self.labels = np.array(labels)
self.n_samples_per_class = n_samples_per_class
self.n_classes = len(np.unique(labels))
# 为每个类别创建索引列表
self.class_indices = [
np.where(self.labels == c)[0] for c in range(self.n_classes)
]
def __iter__(self):
for _ in range(len(self)):
batch_indices = []
for class_idx in self.class_indices:
# 从每个类别随机采样
selected = np.random.choice(
class_idx,
self.n_samples_per_class,
replace=len(class_idx) < self.n_samples_per_class
)
batch_indices.extend(selected)
np.random.shuffle(batch_indices)
yield batch_indices
def __len__(self):
return len(self.labels) // (self.n_classes * self.n_samples_per_class)
数据加载优化:设置
num_workers=4(根据CPU核心数调整)和pin_memory=True可以显著提高数据加载速度,特别是在使用GPU时。
8. 完整实战:CIFAR-10图像分类
8.1 项目结构
code复制cifar10_classification/
├── data/ # 数据目录
├── models/ # 模型定义
│ └── resnet.py
├── utils/ # 工具函数
│ ├── train.py
│ └── evaluate.py
├── config.py # 配置文件
└── main.py # 主程序
8.2 数据准备与增强
python复制import torchvision
import torchvision.transforms as transforms
from torch.utils.data import random_split
# 高级数据增强
train_transform = transforms.Compose([
transforms.RandomCrop(32, padding=4),
transforms.RandomHorizontalFlip(p=0.5),
transforms.AutoAugment(transforms.AutoAugmentPolicy.CIFAR10),
transforms.ToTensor(),
transforms.Normalize(
mean=[0.4914, 0.4822, 0.4465],
std=[0.2470, 0.2435, 0.2616]
),
transforms.RandomErasing(p=0.5, scale=(0.02, 0.33)), # Cutout变体
])
test_transform = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize(
mean=[0.4914, 0.4822, 0.4465],
std=[0.2470, 0.2435, 0.2616]
)
])
# 加载CIFAR-10
full_train_dataset = torchvision.datasets.CIFAR10(
root='./data',
train=True,
download=True,
transform=train_transform
)
test_dataset = torchvision.datasets.CIFAR10(
root='./data',
train=False,
download=True,
transform=test_transform
)
# 划分训练集和验证集
train_size = int(0.9 * len(full_train_dataset))
val_size = len(full_train_dataset) - train_size
train_dataset, val_dataset = random_split(
full_train_dataset, [train_size, val_size],
generator=torch.Generator().manual_seed(42)
)
# 验证集使用测试变换
val_dataset.dataset.transform = test_transform
# 类别名称
classes = ['airplane', 'automobile', 'bird', 'cat', 'deer',
'dog', 'frog', 'horse', 'ship', 'truck']
8.3 ResNet模型实现
python复制import torch.nn as nn
import torch.nn.functional as F
class BasicBlock(nn.Module):
"""ResNet基础块"""
expansion = 1
def __init__(self, in_channels, out_channels, stride=1):
super(BasicBlock, self).__init__()
self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3,
stride=stride, padding=1, bias=False)
self.bn1 = nn.BatchNorm2d(out_channels)
self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3,
stride=1, padding=1, bias=False)
self.bn2 = nn.BatchNorm2d(out_channels)
self.shortcut = nn.Sequential()
if stride != 1 or in_channels != out_channels:
self.shortcut = nn.Sequential(
nn.Conv2d(in_channels, out_channels, kernel_size=1,
stride=stride, bias=False),
nn.BatchNorm2d(out_channels)
)
def forward(self, x):
out = F.relu(self.bn1(self.conv1(x)))
out = self.bn2(self.conv2(out))
out += self.shortcut(x)
out = F.relu(out)
return out
class ResNet(nn.Module):
"""ResNet for CIFAR-10"""
def __init__(self, block, num_blocks, num_classes=10):
super(ResNet, self).__init__()
self.in_channels = 64
self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
self.bn1 = nn.BatchNorm2d(64)
self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1)
self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2)
self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2)
self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2)
self.linear = nn.Linear(512 * block.expansion, num_classes)
self._initialize_weights()
def _make_layer(self, block, out_channels, num_blocks, stride):
strides = [stride] + [1] * (num_blocks - 1)
layers = []
for stride in strides:
layers.append(block(self.in_channels, out_channels, stride))
self.in_channels = out_channels * block.expansion
return nn.Sequential(*layers)