在MindSpore框架下训练复合任务模型时,最令人头疼的问题莫过于训练过程中突然出现的NaN损失值。我最近在开发一个同时处理图像分类和边界框回归的多任务网络时,就遇到了这个典型问题:模型在前几个epoch表现正常,损失值稳步下降,但到第3-5个epoch时,损失值突然变成NaN,导致训练完全崩溃。
从日志中可以清晰看到这个恶化过程:
code复制epoch 1, loss: 2.3456, cls_loss: 1.2345, reg_loss: 1.1111
epoch 2, loss: 1.7890, cls_loss: 0.9876, reg_loss: 0.8014
epoch 3, loss: 1.2345, cls_loss: 0.7654, reg_loss: 0.4691
epoch 4, loss: nan, cls_loss: nan, reg_loss: nan
问题出现在自定义的复合损失函数中,该函数同时结合了分类任务的交叉熵损失和回归任务的MSE损失:
python复制class CustomLoss(nn.LossBase):
def __init__(self, weight_factor=0.5):
super(CustomLoss, self).__init__()
self.weight_factor = weight_factor
self.ce_loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')
self.mse_loss = nn.MSELoss(reduction='mean')
def construct(self, logits, labels, regression_pred, regression_target):
cls_loss = self.ce_loss(logits, labels) # 分类损失
reg_loss = self.mse_loss(regression_pred, regression_target) # 回归损失
total_loss = cls_loss + self.weight_factor * reg_loss # 复合损失
return total_loss
通过观察问题表现,可以总结出几个关键特征:
这些特征表明,问题很可能与梯度流动和数值稳定性有关,而非简单的代码错误。
首先需要确认计算过程中是否出现了数值溢出。在浮点运算中,以下情况会导致NaN:
在复合损失函数中,MSE损失的计算需要特别关注:
python复制reg_loss = (pred - target)^2 # 如果pred和target差异很大,会导致平方项爆炸
通过MindSpore的调试工具检查梯度变化,发现当NaN出现时:
这种正反馈循环最终使整个网络的数值计算崩溃。
检查回归目标的统计特性:
python复制print(f"回归目标范围: [{regression_target.min()}, {regression_target.max()}]")
print(f"回归目标均值: {regression_target.mean()}, 标准差: {regression_target.std()}")
发现目标值范围在[0, 1000]之间,且标准差很大(约300),这会导致MSE损失值初始就很大。
对回归目标进行标准化处理:
python复制# 计算训练集的均值和标准差
reg_mean = regression_target.mean()
reg_std = regression_target.std()
# 标准化处理
regression_target = (regression_target - reg_mean) / reg_std
同时,在模型预测后需要反向转换:
python复制# 预测时
reg_output = self.regressor(features) * reg_std + reg_mean
改进后的损失函数增加了多项保护措施:
python复制class CustomLoss(nn.LossBase):
def __init__(self, weight_factor=0.5, clip_value=10.0):
super(CustomLoss, self).__init__()
self.weight_factor = weight_factor
self.clip_value = clip_value # 损失截断阈值
self.ce_loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')
self.mse_loss = nn.MSELoss(reduction='mean')
def construct(self, logits, labels, regression_pred, regression_target):
cls_loss = self.ce_loss(logits, labels)
# 对回归损失进行截断
reg_loss = self.mse_loss(regression_pred, regression_target)
reg_loss = ops.clip_by_value(reg_loss,
clip_value_min=0,
clip_value_max=self.clip_value)
# 复合损失
total_loss = cls_loss + self.weight_factor * reg_loss
# 增强的数值检查
if ops.isnan(total_loss) or ops.isinf(total_loss):
print(f"数值异常 - cls_loss: {cls_loss}, reg_loss: {reg_loss}")
# 返回安全值避免训练中断
return ops.zeros_like(total_loss) + 1.0
return total_loss
python复制from mindspore import amp
loss_scale_manager = amp.DynamicLossScaleManager(init_loss_scale=2**16,
scale_factor=2,
scale_window=1000)
python复制from mindspore.nn import ClipByGlobalNorm
optimizer = nn.Adam(network.trainable_params(),
learning_rate=0.001)
optimizer = ClipByGlobalNorm(optimizer, clip_norm=1.0)
在回归分支添加BatchNorm层稳定数值分布:
python复制self.regressor = nn.SequentialCell([
nn.Dense(64*16*16, 256),
nn.BatchNorm1d(256),
nn.ReLU(),
nn.Dense(256, 4)
])
使用MindInsight可视化工具监控:
在网络关键位置添加检查点:
python复制def construct(self, x):
features = self.feature_extractor(x)
# 特征检查
if ops.isnan(features).any():
print("NaN detected in features!")
return ops.zeros_like(x), ops.zeros_like(x)
features = features.view(features.shape[0], -1)
cls_output = self.classifier(features)
reg_output = self.regressor(features)
# 输出检查
if ops.isnan(reg_output).any():
print("NaN in regression output!")
reg_output = ops.clip_by_value(reg_output,
clip_value_min=-1e6,
clip_value_max=1e6)
return cls_output, reg_output
采用渐进式训练策略:
通过这次调试经历,我总结了以下几点关键经验:
数据规范化是基础:回归任务的输入和目标值必须进行适当的标准化处理,建议使用Z-score标准化
梯度管理不可忽视:复合任务中,不同分支的梯度量级可能差异很大,必须使用梯度裁剪或动态损失缩放
防御性编程很重要:在网络关键位置添加数值检查,可以及早发现问题而非等到训练崩溃
监控工具要善用:MindSpore提供的调试和可视化工具能极大提高问题定位效率
渐进式训练更稳定:复杂任务建议先分模块训练,再整体微调,而非一开始就端到端训练
一个更健壮的复合损失函数实现应该包含:
这些措施虽然增加了少量计算开销,但能显著提高训练稳定性。在实际项目中,训练时间的少量增加远比因NaN问题导致训练失败要划算得多。