You eagerly load a pretrained model, ready to get to work, and out pops "loaded state dict contains a parameter group that doesn't match the size of optimizer's group". Instant blood-pressure spike, right? What this error is really saying is that the parameter layout the optimizer remembers (its saved state_dict) no longer lines up with how your current model's parameters are organized (its param_groups).
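Here is a minimal sketch that reproduces the error; the layer sizes are made up purely for illustration:

```python
import torch
import torch.nn as nn

# Optimizer state saved for a model with 2 parameter tensors (weight + bias)
old_model = nn.Linear(4, 2)
old_opt = torch.optim.SGD(old_model.parameters(), lr=0.1)
saved = old_opt.state_dict()

# Loading it into an optimizer over 4 parameter tensors fails:
# the single param group now has a different size
new_model = nn.Sequential(nn.Linear(4, 2), nn.Linear(2, 1))
new_opt = torch.optim.SGD(new_model.parameters(), lr=0.1)
new_opt.load_state_dict(saved)  # ValueError: loaded state dict contains a
                                # parameter group that doesn't match the size
                                # of optimizer's group
```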
I stepped into this exact pit last year on an image classification project. To improve the model, I added two fully connected layers on top of a pretrained ResNet, and loading the optimizer state promptly crashed. It turned out the cause was simple: the checkpoint's optimizer state still described the old parameter layout, so the number and shapes of parameters per group no longer matched my modified model.
Here's a concrete example. Suppose the original model looks like this:
```python
import torch.nn as nn

original_model = nn.Sequential(
    nn.Conv2d(3, 16, 3),
    nn.ReLU(),
    nn.Flatten(),                # flatten before the linear layer
    nn.Linear(16 * 26 * 26, 10)  # a 28x28 input becomes 26x26 after the 3x3 conv
)
```
And the modified model becomes:
```python
modified_model = nn.Sequential(
    nn.Conv2d(3, 32, 5),          # channels 16 -> 32, kernel 3 -> 5
    nn.ReLU(),
    nn.Flatten(),                 # flatten before the linear layers
    nn.Linear(32 * 24 * 24, 10),  # a 5x5 conv shrinks 28x28 to 24x24
    nn.Linear(10, 5)              # newly added fully connected layer
)
```
If you now load the original model's optimizer state directly, it errors out because both the parameter shapes and the number of parameters per group have changed.
When you hit this error, don't rush to hack at your code. I've distilled a diagnostic routine that helps you pinpoint the problem quickly.
First, print the differences between the current model and the checkpoint with this snippet:
```python
import torch

# Load the checkpoint
checkpoint = torch.load('your_checkpoint.pth')

# Compare model parameter key names
print("=== Model keys ===")
print(list(model.state_dict().keys()))
print("=== Checkpoint keys ===")
print(list(checkpoint['model_state_dict'].keys()))

# Compare optimizer parameter groups
print("\n=== Current optimizer groups ===")
print([len(g['params']) for g in optimizer.param_groups])
print("=== Checkpoint optimizer groups ===")
print([len(g['params']) for g in checkpoint['optimizer_state_dict']['param_groups']])
```
In my experience, the problem usually boils down to one of a few situations: layers were added or removed (the parameter count per group changes), a layer's shape was altered (different channel counts or kernel sizes), parameters were renamed during a refactor, or the param_groups themselves were constructed differently (say, separate groups for backbone and head).
For complex models, I recommend this comparison helper that makes the diff easy to read:
```python
import json

def compare_state_dicts(current, loaded):
    diff = {}
    for k in loaded:
        if k not in current:
            diff[f'missing_in_current::{k}'] = list(loaded[k].shape)
        elif loaded[k].shape != current[k].shape:
            diff[f'shape_mismatch::{k}'] = f"{loaded[k].shape} vs {current[k].shape}"
    for k in current:
        if k not in loaded:
            diff[f'missing_in_loaded::{k}'] = list(current[k].shape)
    return diff

# Usage
diff = compare_state_dicts(model.state_dict(), checkpoint['model_state_dict'])
print(json.dumps(diff, indent=2))
```
This is the most lightweight solution, suited to cases where the parameter structure hasn't changed much. The core idea: load only the parameters that match, and skip the rest.
```python
def load_with_filter(model, optimizer, checkpoint_path):
    checkpoint = torch.load(checkpoint_path)

    # Filter model parameters: keep only keys whose name and shape both match
    model_state_dict = model.state_dict()
    filtered_model_state = {k: v for k, v in checkpoint['model_state_dict'].items()
                            if k in model_state_dict and v.shape == model_state_dict[k].shape}
    model.load_state_dict(filtered_model_state, strict=False)

    # Filter the optimizer state
    if 'optimizer_state_dict' in checkpoint:
        opt_state = checkpoint['optimizer_state_dict']
        current_opt_state = optimizer.state_dict()
        # Optimizer state is keyed by integer parameter indices; keep an entry
        # only if the index still exists and its tensors (exp_avg and friends)
        # match the current parameter's shape (0-dim 'step' always passes)
        params = [p for g in optimizer.param_groups for p in g['params']]
        filtered_state = {}
        for idx, saved in opt_state['state'].items():
            if idx < len(params) and all(
                    not torch.is_tensor(v) or v.dim() == 0
                    or v.shape == params[idx].shape
                    for v in saved.values()):
                filtered_state[idx] = saved
        # Rebuild the optimizer state_dict around the current param_groups
        new_opt_state = {
            'state': filtered_state,
            'param_groups': current_opt_state['param_groups'],
        }
        optimizer.load_state_dict(new_opt_state)
    return model, optimizer
```
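In use it's a drop-in replacement for the usual load calls (the checkpoint path here is just a placeholder):

```python
model, optimizer = load_with_filter(model, optimizer, 'your_checkpoint.pth')
```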
When parameter names have changed, you can build a mapping between old and new keys:
```python
key_mapping = {
    'old_conv.weight': 'new_conv.weight',
    'old_bn.running_mean': 'new_bn.running_mean',
    # ... other mappings
}

def map_keys(state_dict, mapping):
    new_state = {}
    for k, v in state_dict.items():
        new_key = mapping.get(k, k)  # fall back to the original key
        new_state[new_key] = v
    return new_state

# Usage
mapped_state = map_keys(checkpoint['model_state_dict'], key_mapping)
model.load_state_dict(mapped_state, strict=False)
```
When the model structure has changed substantially, the safer move is to rebuild the optimizer. You lose the previous optimizer state (momentum and the like), but the parameter groups are guaranteed to match.
```python
def rebuild_optimizer(model, old_optimizer, checkpoint_path):
    checkpoint = torch.load(checkpoint_path)

    # Load model parameters (strict=False allows a partial load)
    model.load_state_dict(checkpoint['model_state_dict'], strict=False)

    # Create a fresh optimizer of the same type with the same hyperparameters
    optimizer = type(old_optimizer)(model.parameters(), **old_optimizer.defaults)

    # Restore whatever per-parameter state still fits
    if 'optimizer_state_dict' in checkpoint:
        old_state = checkpoint['optimizer_state_dict']
        new_state = optimizer.state_dict()
        params = [p for g in optimizer.param_groups for p in g['params']]
        for idx, saved in old_state['state'].items():
            # Keep an entry only if the index is still valid and its tensors
            # match the current parameter's shape (0-dim 'step' always passes)
            if idx < len(params) and all(
                    not torch.is_tensor(v) or v.dim() == 0
                    or v.shape == params[idx].shape
                    for v in saved.values()):
                new_state['state'][idx] = saved
        optimizer.load_state_dict(new_state)
    return optimizer
```
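Calling it is a one-liner; pass the old optimizer so its type and hyperparameters carry over (the path is a placeholder):

```python
optimizer = rebuild_optimizer(model, optimizer, 'your_checkpoint.pth')
```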
Rebuilding the optimizer does throw away some state, but you can compensate by adjusting the learning rate:
```python
# After rebuilding the optimizer
for param_group in optimizer.param_groups:
    param_group['lr'] *= 0.5                       # lower the learning rate a bit
    param_group['initial_lr'] = param_group['lr']  # keep LR schedulers consistent
```
For advanced users who need fine-grained control, you can build the parameter mapping by hand. This is the approach I reach for most often when migrating complex models.
```python
def create_param_mapping(current_model, checkpoint_path):
    checkpoint = torch.load(checkpoint_path)
    old_params = checkpoint['model_state_dict']
    new_params = current_model.state_dict()
    mapping = {}

    # Automatically match parameters with identical names and shapes
    for new_k in new_params:
        if new_k in old_params and new_params[new_k].shape == old_params[new_k].shape:
            mapping[new_k] = old_params[new_k]

    # Manually add special mappings (example keys; adapt these to your model)
    mapping.update({
        'new_conv1.weight': old_params['old_conv.weight'][:16],           # take the first 16 output channels
        'new_bn.running_var': old_params['old_bn.running_var'].repeat(2)  # duplicate stats when channels double
    })
    return mapping
```
```python
def load_with_mapping(model, optimizer, checkpoint_path):
    checkpoint = torch.load(checkpoint_path)
    mapping = create_param_mapping(model, checkpoint_path)

    # Load model parameters: start from the current state, overlay the mapping
    model_state = model.state_dict()
    model_state.update(mapping)
    model.load_state_dict(model_state)

    # Handle the optimizer state
    if 'optimizer_state_dict' in checkpoint:
        old_opt_state = checkpoint['optimizer_state_dict']
        new_opt_state = optimizer.state_dict()
        # The state is keyed by integer parameter indices; carry over entries
        # whose index is still valid and whose tensors fit the current parameter
        params = [p for g in optimizer.param_groups for p in g['params']]
        state_mapping = {}
        for idx, saved in old_opt_state['state'].items():
            if idx < len(params) and all(
                    not torch.is_tensor(v) or v.dim() == 0
                    or v.shape == params[idx].shape
                    for v in saved.values()):
                state_mapping[idx] = saved
        new_opt_state['state'] = state_mapping
        optimizer.load_state_dict(new_opt_state)
    return model, optimizer
```
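Usage mirrors the filtered loader above (placeholder path again):

```python
model, optimizer = load_with_mapping(model, optimizer, 'your_checkpoint.pth')
```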
Last year, on a medical image classification project, I ran into a textbook case of this mismatch. The original model was a ResNet34 pretrained on ImageNet, while our task required (among other changes) inserting attention modules into the backbone and replacing the classifier head for our own label set.
Loading the optimizer state directly in this situation is guaranteed to fail. My solution:
```python
# 1. Load the checkpoint (assumed to hold both model and optimizer state)
#    and pull out the base conv-layer weights
checkpoint = torch.load('resnet34.pth')
pretrained_dict = checkpoint['model_state_dict']
model_dict = model.state_dict()

# 2. Keep only the parameters whose name and shape both match
pretrained_dict = {k: v for k, v in pretrained_dict.items()
                   if k in model_dict and v.shape == model_dict[k].shape}

# 3. Special-case the BN running_mean/running_var: for BN layers inside the
#    new attention modules, copy the statistics of the neighbouring layer3 blocks
for k in list(pretrained_dict.keys()):
    if 'running_mean' in k or 'running_var' in k:
        new_k = k.replace('layer3', 'attention')
        if new_k != k and new_k in model_dict \
                and model_dict[new_k].shape == pretrained_dict[k].shape:
            pretrained_dict[new_k] = pretrained_dict[k].clone()

# 4. Load the model parameters
model_dict.update(pretrained_dict)
model.load_state_dict(model_dict, strict=False)

# 5. Rebuild the optimizer but keep the state of parameters that still fit
optimizer = torch.optim.AdamW(model.parameters())
if 'optimizer_state_dict' in checkpoint:
    old_state = checkpoint['optimizer_state_dict']
    new_state = optimizer.state_dict()
    params = [p for g in optimizer.param_groups for p in g['params']]
    # State entries are keyed by integer indices; restore one only when the
    # index is still valid and its tensors match the current parameter's shape
    for idx, saved in old_state['state'].items():
        if idx < len(params) and all(
                not torch.is_tensor(v) or v.dim() == 0
                or v.shape == params[idx].shape
                for v in saved.values()):
            new_state['state'][idx] = saved
    optimizer.load_state_dict(new_state)
```
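To check how much of the optimizer state actually carried over, a quick count using the variables from the snippet above does the job:

```python
restored = len(new_state['state'])
total = len(old_state['state'])
print(f"restored {restored}/{total} optimizer state entries "
      f"({100 * restored / max(total, 1):.0f}%)")
```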
This approach recovered more than 90% of the optimizer state, and the model performed well right from the start of fine-tuning.