深夜的实验室只剩下机箱的嗡鸣,屏幕上loss曲线正在稳步下降——突然一道闪电划过,整栋楼陷入黑暗。这种场景对深度学习开发者而言无异于噩梦,但有了正确的断点续训策略,你完全可以从容应对突发断电。本文将手把手教你构建一个健壮的PyTorch训练存档系统,让意外关机变得像游戏暂停一样无害。
一个完善的checkpoint应该像时光胶囊一样完整保存训练现场。以下是我们推荐的字典结构:
# A checkpoint is a full snapshot of the training run -- enough to resume
# exactly where the run stopped, not just the model weights.
checkpoint = {}
checkpoint['epoch'] = current_epoch + 1  # the epoch to start from on resume
checkpoint['model_state'] = model.state_dict()
checkpoint['optimizer_state'] = optimizer.state_dict()
# Optional components are stored as None when the feature is not in use.
checkpoint['scheduler_state'] = scheduler.state_dict() if scheduler else None
checkpoint['best_score'] = best_score
# RNG states keep data shuffling / dropout reproducible across a restart.
checkpoint['rng_state'] = torch.get_rng_state()
if torch.cuda.is_available():
    checkpoint['cuda_rng_state'] = torch.cuda.get_rng_state_all()
else:
    checkpoint['cuda_rng_state'] = None
# AMP gradient-scaler state (None when mixed precision is not used).
checkpoint['grad_scaler_state'] = grad_scaler.state_dict() if grad_scaler else None
关键改进点:相比简单保存模型参数,我们额外捕获了随机数生成器状态、梯度缩放器状态等容易被忽视但影响训练连续性的要素。
固定间隔保存可能造成关键epoch数据丢失。更聪明的做法是动态调整保存频率:
def should_save_checkpoint(epoch, is_best=False):
    """Decide whether the current epoch deserves a checkpoint.

    A checkpoint is written on the regular save interval, whenever a new
    best score is reached, and densely over the final three epochs.
    Reads ``args.save_interval`` and ``args.max_epochs`` from module scope.
    """
    on_interval = epoch % args.save_interval == 0
    near_the_end = epoch >= args.max_epochs - 3
    return on_interval or is_best or near_the_end
这段代码会自动扫描输出目录,找到最新的有效检查点文件:
def find_latest_checkpoint(checkpoint_dir):
    """Return the path of the newest valid checkpoint in *checkpoint_dir*.

    Scans for files named ``checkpoint_<number>.pth`` and returns the one
    with the highest number, or ``None`` when no numbered checkpoint exists.
    """
    pattern = re.compile(r'checkpoint_(\d+)\.pth$')  # dot escaped; anchored
    numbered = []
    for path in glob.glob(os.path.join(checkpoint_dir, 'checkpoint_*.pth')):
        match = pattern.search(path)
        # Skip files such as checkpoint_best.pth: they match the glob but
        # carry no step number (the original crashed on these with
        # AttributeError because re.search returned None).
        if match:
            numbered.append((int(match.group(1)), path))
    if not numbered:
        return None
    return max(numbered)[1]
加载检查点时需要考虑多种异常情况:
def safe_load_checkpoint(checkpoint_path, model, optimizer, device='cuda'):
    """Load a training checkpoint, tolerating common failure modes.

    Restores model weights (partially when the architecture changed),
    optimizer state, and CPU/CUDA RNG states. Returns
    ``(start_epoch, best_score)``, or ``(0, 0)`` when the checkpoint cannot
    be loaded at all.
    """
    try:
        checkpoint = torch.load(checkpoint_path, map_location=device)
        model_state = checkpoint['model_state']
        # Tolerate architecture drift: fall back to a partial, non-strict
        # load when the saved keys no longer match the current model.
        if any(k not in model.state_dict() for k in model_state):
            print("Warning: 模型结构不匹配,尝试部分加载...")
            model.load_state_dict(model_state, strict=False)
        else:
            model.load_state_dict(model_state)
        optimizer.load_state_dict(checkpoint['optimizer_state'])
        # Restore RNG states so shuffling/dropout continue deterministically.
        # The CPU state was saved by the checkpoint builder but was never
        # restored here before.
        if checkpoint.get('rng_state') is not None:
            torch.set_rng_state(checkpoint['rng_state'])
        # Guard against None: checkpoints written on CPU-only machines store
        # cuda_rng_state=None, which set_rng_state_all would choke on --
        # the resulting exception used to discard the whole load.
        cuda_state = checkpoint.get('cuda_rng_state')
        if cuda_state is not None and torch.cuda.is_available():
            torch.cuda.set_rng_state_all(cuda_state)
        return checkpoint['epoch'], checkpoint.get('best_score', 0)
    except Exception as e:
        # Deliberate best-effort fallback: a corrupt or missing checkpoint
        # simply restarts training from scratch.
        print(f"加载检查点失败: {str(e)}")
        return 0, 0
当需要在不同设备间迁移时,这个工具函数非常有用:
def make_state_dict_compatible(state_dict, target_device):
    """Return a copy of *state_dict* with every tensor moved to *target_device*.

    Non-tensor entries (step counters, metadata, ...) pass through untouched,
    which makes this safe for optimizer state dicts as well as model ones.
    """
    return {
        key: value.to(target_device) if isinstance(value, torch.Tensor) else value
        for key, value in state_dict.items()
    }
DataParallel/DistributedDataParallel训练时需要特别注意:
# When saving: unwrap DataParallel / DistributedDataParallel so the
# checkpoint keys carry no 'module.' prefix and load cleanly into a
# plain (unwrapped) model later.
wrapped = isinstance(model, (nn.DataParallel, nn.parallel.DistributedDataParallel))
model_state = model.module.state_dict() if wrapped else model.state_dict()

# When loading an unwrapped checkpoint into a wrapped model: re-add the
# 'module.' prefix so the key names line up.
target_has_prefix = any(k.startswith('module.') for k in model.state_dict())
state_has_prefix = any(k.startswith('module.') for k in model_state)
if target_has_prefix and not state_has_prefix:
    model_state = {'module.' + k: v for k, v in model_state.items()}
def train_loop(model, train_loader, optimizer, start_epoch, num_epochs, checkpoint=None):
    """Run training epochs ``[start_epoch, num_epochs)``, resumable after a crash.

    Relies on module-level ``device``, ``criterion``, ``save_checkpoint`` and
    ``should_save_checkpoint``. ``checkpoint`` (optional, backward-compatible
    addition) is the dict loaded from disk; its RNG state is restored once so
    data shuffling resumes where the previous run stopped.
    """
    # Restore the RNG exactly once, before training starts. The original
    # referenced an undefined free variable `checkpoint` and re-applied the
    # same RNG state at the top of EVERY epoch, which forced each resumed
    # epoch to reuse identical randomness.
    if checkpoint is not None and checkpoint.get('rng_state') is not None:
        torch.set_rng_state(checkpoint['rng_state'])
    for epoch in range(start_epoch, num_epochs):
        model.train()
        for batch_idx, (data, target) in enumerate(train_loader):
            data, target = data.to(device), target.to(device)
            optimizer.zero_grad()
            output = model(data)
            loss = criterion(output, target)
            loss.backward()
            optimizer.step()
            # Poll for an external stop request every 100 batches so a
            # graceful shutdown never loses more than ~100 batches of work.
            if batch_idx % 100 == 0 and os.path.exists('stop.tmp'):
                print("检测到中断信号,保存检查点...")
                save_checkpoint(epoch, batch_idx)
                os.remove('stop.tmp')
                return
        # End-of-epoch save on qualifying epochs (interval / best / final few).
        if should_save_checkpoint(epoch):
            save_checkpoint(epoch + 1)
通过信号监听实现可控中断:
import signal

class GracefulExiter:
    """Turn Ctrl-C (SIGINT) into a polite stop request.

    The handler records the request both as an in-memory flag and as a
    ``stop.tmp`` marker file, so the training loop can finish the current
    batch, save a checkpoint and exit cleanly instead of dying mid-step.
    """

    def __init__(self):
        # Becomes True once a SIGINT has been caught.
        self.state = False
        signal.signal(signal.SIGINT, self.change_state)

    def change_state(self, signum, frame):
        """Signal handler: flag the stop request and drop the marker file."""
        print("捕获中断信号,准备安全退出...")
        # Bug fix: the original initialized self.state but never set it here,
        # leaving the attribute permanently False and unused.
        self.state = True
        open('stop.tmp', 'w').close()  # marker file is visible to other processes too

    def should_exit(self):
        """Return True once a stop has been requested (flag or marker file)."""
        return self.state or os.path.exists('stop.tmp')
# Install the handler once at startup; it stays registered for the whole run.
exiter = GracefulExiter()
# Inside the training loop, poll for a stop request and bail out safely:
if exiter.should_exit():
    save_checkpoint(...)
    break
对于大型模型,可以只保存变化的参数:
def save_delta_checkpoint(model, base_checkpoint, delta_path):
    """Save only the parameter difference against a base checkpoint.

    For large models this keeps incremental checkpoints small; the full
    state can be rebuilt later as ``base + delta``. The keys to diff are
    taken from the base checkpoint's ``model_state``.
    """
    current_state = model.state_dict()
    base_state = base_checkpoint['model_state']
    delta = {}
    for name in base_state:
        delta[name] = current_state[name] - base_state[name]
    torch.save(delta, delta_path)
处理超大模型时的加载优化:
def load_large_checkpoint(path):
    """Load a big checkpoint onto CPU first, pinning tensors for fast GPU upload.

    ``map_location='cpu'`` avoids a transient GPU memory spike during
    deserialization. Weight tensors are then moved into page-locked (pinned)
    host memory -- note this is NOT memory mapping, despite what the original
    comment claimed -- which speeds up subsequent host-to-device copies.
    """
    checkpoint = torch.load(path, map_location='cpu')
    # pin_memory() requires a CUDA runtime; on CPU-only builds the original
    # code raised RuntimeError here, so skip pinning when CUDA is absent.
    if torch.cuda.is_available():
        state = checkpoint['model_state']
        for key, value in state.items():
            if isinstance(value, torch.Tensor):
                state[key] = value.pin_memory()
    return checkpoint
在真实项目中,我发现最实用的技巧是在训练脚本启动时就自动备份一份代码快照,这样即使几个月后需要复现结果,也能确保代码版本与检查点完全匹配。可以简单地在训练开始时添加:
import shutil
from datetime import datetime

# Snapshot the source tree at launch so any checkpoint produced by this run
# can later be reproduced with exactly the code that trained it.
timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
code_backup_dir = f"code_snapshots/{timestamp}"
os.makedirs(code_backup_dir, exist_ok=True)
shutil.copytree('./src', os.path.join(code_backup_dir, 'src'))