1. RNN基础实现解析
1.1 模型参数初始化
在RNN实现中,我们首先需要初始化模型参数。这段代码展示了如何为RNN创建和初始化权重矩阵和偏置向量:
python复制def get_params(vocab_size, num_hiddens, device):
num_inputs = num_output = vocab_size
def normal(shape):
return torch.randn(size=shape, device=device) * 0.01
W_xh = normal((num_inputs, num_hiddens)) # 输入到隐藏层的权重
W_hh = normal((num_hiddens, num_hiddens)) # 隐藏层到隐藏层的权重
b_h = torch.zeros(num_hiddens, device=device) # 隐藏层偏置
W_hy = normal((num_hiddens, num_output)) # 隐藏层到输出的权重
b_y = torch.zeros(num_output, device=device) # 输出层偏置
params = [W_xh, W_hh, b_h, W_hy, b_y]
for param in params:
param.requires_grad_(True)
return params
这里有几个关键点需要注意:
- 权重初始化采用小随机数(乘以0.01),这是为了防止初始值过大导致梯度爆炸
- 偏置初始化为零,这是常见的做法
- 所有参数都设置为需要梯度计算(requires_grad=True),以便后续反向传播
- 参数会根据设备(CPU或GPU)进行初始化,确保计算在正确的设备上进行
1.2 隐藏状态初始化
RNN的核心特点是具有隐藏状态,需要在每个序列开始时初始化:
python复制def init_rnn_state(batch_size, num_hiddens, device):
return (torch.zeros(batch_size, num_hiddens, device=device),)
这个函数创建了一个全零的隐藏状态张量,形状为(batch_size, num_hiddens)。使用元组包装是为了与更复杂的RNN变体(如LSTM)保持接口一致。
2. RNN前向传播实现
2.1 单时间步计算
RNN的核心计算逻辑体现在前向传播函数中:
python复制def rnn(inputs, state, params):
W_xh, W_hh, b_h, W_hy, b_y = params
H, = state
outputs = []
for X in inputs:
H = torch.tanh(torch.mm(X, W_xh) + torch.mm(H, W_hh) + b_h)
Y = torch.mm(H, W_hy) + b_y
outputs.append(Y)
return torch.cat(outputs, dim=0), (H,)
这个函数实现了RNN的标准计算流程:
- 对每个时间步的输入X,计算新的隐藏状态H = tanh(XW_xh + HW_hh + b_h)
- 使用当前隐藏状态计算输出Y = HW_hy + b_y
- 将所有时间步的输出拼接起来返回
- 返回最终的隐藏状态用于下一个序列的计算
注意:这里使用tanh作为激活函数,它可以将值压缩到(-1,1)之间,有助于缓解梯度爆炸问题。
2.2 RNN模型封装
为了更方便地使用RNN,我们将其封装为一个类:
python复制class RNNModelScratch:
def __init__(self, vocab_size, num_hiddens, device,
get_params, init_rnn_state, forward_fn):
self.vocab_size = vocab_size
self.num_hiddens = num_hiddens
self.params = get_params(vocab_size, num_hiddens, device)
self.init_rnn_state = init_rnn_state
self.forward_fn = forward_fn
def __call__(self, X, state):
X = F.one_hot(X.T, self.vocab_size).type(torch.float32)
return self.forward_fn(X, state, self.params)
def begin_state(self, batch_size, device):
return self.init_rnn_state(batch_size, self.num_hiddens, device)
这个封装提供了几个便利:
- 统一管理模型参数和状态初始化
- 自动将输入token转换为one-hot编码
- 提供了清晰的接口用于获取初始状态
3. 序列预测与训练
3.1 序列预测实现
序列预测是RNN的典型应用之一,这里实现了自回归预测:
python复制def predict_ch8(prefix, num_preds, net, vocab, device):
state = net.begin_state(batch_size=1, device=device)
outputs = [vocab[prefix[0]]]
get_input = lambda: torch.tensor([outputs[-1]], device=device).reshape(1,1)
# 预热期:使用已知前缀初始化状态
for y in prefix[1:]:
_, state = net(get_input(), state)
outputs.append(vocab[y])
# 预测期:使用模型自身输出作为下一步输入
for _ in range(num_preds):
y, state = net(get_input(), state)
outputs.append(int(y.argmax(dim=1).reshape(1)))
return ''.join([vocab.idx_to_token[i] for i in outputs])
这个预测函数的工作流程:
- 使用给定的前缀初始化隐藏状态
- 在预热期,使用前缀中的真实字符更新状态
- 在预测期,使用模型预测的下一个字符作为输入(自回归)
- 将预测的token索引转换回字符
3.2 梯度裁剪技术
RNN训练中常见的问题是梯度爆炸,梯度裁剪是有效的解决方案:
python复制def grad_clipping(net, theta):
if isinstance(net, nn.Module):
params = [p for p in net.parameters() if p.requires_grad]
else:
params = net.params
norm = torch.sqrt(sum(torch.sum((p.grad ** 2)) for p in params))
if norm > theta:
for param in params:
param.grad[:] *= theta / norm
梯度裁剪的原理:
- 计算所有参数梯度的L2范数
- 如果范数超过阈值theta,按比例缩小所有梯度
- 这样可以防止梯度值过大导致的参数剧烈变化
3.3 训练循环实现
完整的训练过程包括以下几个部分:
python复制def train_epoch_ch8(net, train_iter, loss, updater, device, use_random_iter):
state, timer = None, d2l.Timer()
metric = d2l.Accumulator(2) # 累计损失和token数
for X, Y in train_iter:
if state is None or use_random_iter:
state = net.begin_state(batch_size=X.shape[0], device=device)
else:
if isinstance(net, nn.Module):
state.detach_()
else:
for s in state:
s.detach_()
y = Y.T.reshape(-1)
X, y = X.to(device), y.to(device)
y_hat, state = net(X, state)
l = loss(y_hat, y.long()).mean()
if isinstance(updater, torch.optim.Optimizer):
updater.zero_grad()
l.backward()
grad_clipping(net, 1)
updater.step()
else:
l.backward()
grad_clipping(net, 1)
updater(batch_size=1)
metric.add(l * y.numel(), y.numel())
return math.exp(metric[0] / metric[1]) # 返回困惑度
训练中的关键点:
- 正确处理隐藏状态的初始化与分离(防止梯度跨序列传播)
- 使用交叉熵损失计算预测误差
- 应用梯度裁剪防止梯度爆炸
- 计算困惑度(perplexity)作为评估指标
4. 实际训练与结果分析
4.1 训练配置
python复制num_hiddens = 512
net = RNNModelScratch(len(vocab), num_hiddens, d2l.try_gpu(),
get_params, init_rnn_state, rnn)
num_epochs = 500
lr = 1
train_ch8(net, train_iter, vocab, lr, num_epochs, d2l.try_gpu())
这个配置中:
- 使用512维的隐藏状态
- 训练500个epoch
- 学习率设为1(对于SGD优化器来说较大,但有梯度裁剪保护)
- 自动检测并使用GPU加速
4.2 训练结果分析
从训练曲线和预测结果可以看出:
- 初始阶段模型输出几乎是随机的(困惑度很高)
- 随着训练进行,困惑度逐渐下降
- 最终模型能够生成部分有意义的序列,但仍不完美
这反映了RNN的一些固有局限性:
- 难以捕捉长距离依赖关系
- 训练过程不稳定(即使有梯度裁剪)
- 需要大量数据和训练时间才能达到较好效果
5. 关键问题与改进方向
5.1 常见问题排查
-
梯度消失/爆炸:
- 现象:模型无法学习或参数值变得异常
- 解决方案:使用梯度裁剪;考虑LSTM或GRU结构
-
模式坍塌:
- 现象:模型总是预测相同的输出
- 解决方案:检查初始化;调整学习率;增加数据多样性
-
过拟合:
- 现象:训练损失下降但验证损失上升
- 解决方案:增加正则化;使用dropout;获取更多数据
5.2 改进方向
-
模型结构改进:
- 使用LSTM或GRU替代基础RNN
- 增加双向结构捕捉上下文信息
- 尝试注意力机制
-
训练技巧:
- 使用学习率调度
- 尝试不同的优化器(如Adam)
- 实现早停机制
-
数据预处理:
- 更精细的tokenization
- 增加数据增强
- 更好的批量组织策略
在实际应用中,基础RNN往往表现不佳,现代深度学习通常使用其改进版本(如LSTM、GRU或Transformer)。但这个实现很好地展示了RNN的核心思想和工作原理,是理解更复杂模型的基础。