第一次遇到数值溢出问题是在训练一个简单的文本分类模型时。模型在训练过程中突然崩溃,日志里赫然写着"NaN"——这个让所有机器学习工程师头皮发麻的提示。经过排查,发现问题出在交叉熵损失函数的计算上:当模型对某个类别的预测置信度过高时,直接计算softmax会导致指数运算结果超出浮点数表示范围。
数值稳定性问题在深度学习中有两种典型表现:
这两种情况都会导致梯度计算出现异常。比如在计算交叉熵损失时:
python复制def unsafe_softmax(x):
exps = np.exp(x)
return exps / np.sum(exps) # 当x中有很大正值时,这里可能得到inf/inf=NaN
标准的Softmax计算存在明显的数值稳定性问题:
python复制import numpy as np
x = np.array([1000, 0, -1000]) # 极端输入案例
naive_softmax = np.exp(x) / np.sum(np.exp(x)) # 直接计算会得到[NaN, NaN, NaN]
LogSumExp技巧的核心在于数学上的等价变换。设b=max(x),我们有:
code复制log(∑exp(x_i)) = log(exp(b) * ∑exp(x_i - b))
= b + log(∑exp(x_i - b))
这个变换的妙处在于:
基于LogSumExp的稳定实现:
python复制def safe_softmax(x):
b = np.max(x)
exps = np.exp(x - b)
return exps / np.sum(exps)
x = np.array([1000, 0, -1000])
print(safe_softmax(x)) # 正确输出[1., 0., 0.]
这个实现可以处理任意大小的输入值,因为:
交叉熵损失的标准形式:
code复制CE = -log(exp(x_true)/∑exp(x_i))
= -x_true + log(∑exp(x_i))
直接实现时,log(∑exp)部分容易出问题。使用LogSumExp技巧:
python复制def stable_cross_entropy(logits, labels):
shift = logits - np.max(logits, axis=-1, keepdims=True)
log_probs = shift - np.log(np.sum(np.exp(shift), axis=-1, keepdims=True))
return -np.sum(labels * log_probs, axis=-1)
LogSumExp的梯度恰好就是Softmax输出:
code复制∂LSE(x)/∂x_i = exp(x_i) / ∑exp(x_j)
这意味着:
实际测试对比:
python复制# 不稳定实现
x = np.array([500, 300, 200], dtype=np.float32)
grad = np.exp(x) / np.sum(np.exp(x)) # 得到[NaN, NaN, NaN]
# 稳定实现
b = np.max(x)
stable_grad = np.exp(x - b) / np.sum(np.exp(x - b)) # 正确输出[1., 0., 0.]
实际训练中我们通常处理批量数据。高效的向量化实现:
python复制def batch_softmax(logits):
max_logits = np.max(logits, axis=-1, keepdims=True)
exps = np.exp(logits - max_logits)
return exps / np.sum(exps, axis=-1, keepdims=True)
在使用FP16混合精度训练时,数值范围更小(最大约6.5e4),LogSumExp更为关键:
python复制def mixed_precision_softmax(logits): # logits是FP16
logits_fp32 = logits.astype(np.float32)
max_logits = np.max(logits_fp32, axis=-1, keepdims=True)
exps = np.exp(logits_fp32 - max_logits)
return (exps / np.sum(exps, axis=-1, keepdims=True)).astype(np.float16)
LogSumExp技巧同样适用于:
例如带温度参数的稳定实现:
python复制def tempered_softmax(logits, temperature):
scaled = logits / temperature
b = np.max(scaled, axis=-1, keepdims=True)
exps = np.exp(scaled - b)
return exps / np.sum(exps, axis=-1, keepdims=True)