用deepseek学大模型08-循环神经网络
通过上述步骤,您可系统掌握 RNN 的核心理论、实现及优化方法。控制历史信息保留,避免传统 RNN 的连乘梯度,缓解消失问题。,导致梯度消失/爆炸。LSTM 通过细胞状态。
从入门到精通循环神经网络 (RNN)
https://www.dxy.cn/bbs/newweb/pc/post/50883341
https://wenku.csdn.net/column/kbnq75axws
1. RNN 基础
RNN 通过隐藏状态传递序列信息,核心公式:
- 隐藏状态:
ht=tanh(Whhht−1+Wxhxt+bh)\mathbf{h}_t = \tanh(\mathbf{W}_{hh} \mathbf{h}_{t-1} + \mathbf{W}_{xh} \mathbf{x}_t + \mathbf{b}_h)ht=tanh(Whhht−1+Wxhxt+bh) - 输出:
yt=Whyht+by\mathbf{y}_t = \mathbf{W}_{hy} \mathbf{h}_t + \mathbf{b}_yyt=Whyht+by
2. 目标函数与损失函数
- 目标函数:最小化预测与真实值的差距。
- 损失函数(以 MSE 为例):
L=12T∑t=1T(yt−y^t)2L = \frac{1}{2T} \sum_{t=1}^T (\mathbf{y}_t - \mathbf{\hat{y}}_t)^2L=2T1∑t=1T(yt−y^t)2
3. 梯度下降与数学推导
标量形式(以WhhW_{hh}Whh为例):
∂L∂Whh=∑t=1T∂L∂yt⋅∂yt∂ht⋅(∏k=1t∂hk∂hk−1)⋅∂h1∂Whh \frac{\partial L}{\partial W_{hh}} = \sum_{t=1}^T \frac{\partial L}{\partial \mathbf{y}_t} \cdot \frac{\partial \mathbf{y}_t}{\partial \mathbf{h}_t} \cdot \left( \prod_{k=1}^t \frac{\partial \mathbf{h}_k}{\partial \mathbf{h}_{k-1}} \right) \cdot \frac{\partial \mathbf{h}_1}{\partial W_{hh}} ∂Whh∂L=t=1∑T∂yt∂L⋅∂ht∂yt⋅(k=1∏t∂hk−1∂hk)⋅∂Whh∂h1
其中,∂hk∂hk−1=WhhT⋅diag(1−tanh2(⋅))\frac{\partial \mathbf{h}_k}{\partial \mathbf{h}_{k-1}} = \mathbf{W}_{hh}^T \cdot \text{diag}(1 - \tanh^2(\cdot))∂hk−1∂hk=WhhT⋅diag(1−tanh2(⋅)),导致梯度消失/爆炸。
矩阵形式:
∂L∂Whh=∑t=1Tdiag(1−ht2)⋅ht−1T⋅(WhyT(y^t−yt)∏k=t1WhhTdiag(1−hk2)) \frac{\partial L}{\partial \mathbf{W}_{hh}} = \sum_{t=1}^T \text{diag}(1 - \mathbf{h}_t^2) \cdot \mathbf{h}_{t-1}^T \cdot \left( \mathbf{W}_{hy}^T (\mathbf{\hat{y}}_t - \mathbf{y}_t) \prod_{k=t}^1 \mathbf{W}_{hh}^T \text{diag}(1 - \mathbf{h}_k^2) \right) ∂Whh∂L=t=1∑Tdiag(1−ht2)⋅ht−1T⋅(WhyT(y^t−yt)k=t∏1WhhTdiag(1−hk2))
4. PyTorch 代码案例
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
# 数据生成
seq_len = 20
time = torch.arange(0, seq_len, 0.1)
data = torch.sin(time) + torch.randn(seq_len * 10) * 0.1
# 转换为序列数据
def create_dataset(data, window=5):
X, y = [], []
for i in range(len(data)-window):
X.append(data[i:i+window])
y.append(data[i+window])
return torch.stack(X), torch.stack(y)
X, y = create_dataset(data, window=5)
X = X.unsqueeze(-1).float() # (samples, window, features)
y = y.unsqueeze(-1).float()
# 定义模型
class RNN(nn.Module):
def __init__(self, input_size, hidden_size, output_size):
super().__init__()
self.rnn = nn.RNN(input_size, hidden_size, batch_first=True)
self.fc = nn.Linear(hidden_size, output_size)
def forward(self, x):
out, _ = self.rnn(x) # out: (batch, seq, hidden)
out = self.fc(out[:, -1, :]) # 取最后一个时间步
return out
model = RNN(1, 32, 1)
criterion = nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
# 训练
epochs = 100
losses = []
for epoch in range(epochs):
optimizer.zero_grad()
outputs = model(X)
loss = criterion(outputs, y)
loss.backward()
torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5) # 梯度裁剪
optimizer.step()
losses.append(loss.item())
# 可视化损失
plt.plot(losses)
plt.title('Training Loss')
plt.show()
# 预测
with torch.no_grad():
pred = model(X)
plt.plot(time[5:], y.numpy(), label='True')
plt.plot(time[5:], pred.numpy(), label='Predicted')
plt.legend()
plt.show()
5. 可视化展示
- 损失曲线:展示训练过程中损失下降。
- 预测对比:真实值与预测值的时间序列对比。
- 隐藏状态可视化(可选):通过 PCA 降维展示隐藏状态变化。
6. 应用场景与优缺点
- 应用:时间序列预测、文本生成、机器翻译。
- 优点:处理变长序列,捕捉时序依赖。
- 缺点:梯度消失/爆炸,长程依赖困难,计算效率低。
7. 改进方法
- 结构改进:使用 LSTM/GRU 的门控机制,例如 LSTM 的遗忘门:
ft=σ(Wf[ht−1,xt]+bf)f_t = \sigma(\mathbf{W}_f [\mathbf{h}_{t-1}, \mathbf{x}_t] + \mathbf{b}_f)ft=σ(Wf[ht−1,xt]+bf) - 梯度裁剪:限制梯度最大值,防止爆炸。
- 优化算法:Adam 自适应学习率。
- 注意力机制:增强长距离依赖捕捉能力。
8. 数学推导改进(LSTM 示例)
LSTM 通过细胞状态Ct\mathbf{C}_tCt传递信息,梯度流动更稳定:
Ct=ft⊙Ct−1+it⊙C~t \mathbf{C}_t = f_t \odot \mathbf{C}_{t-1} + i_t \odot \tilde{\mathbf{C}}_t Ct=ft⊙Ct−1+it⊙C~t
其中遗忘门ftf_tft控制历史信息保留,避免传统 RNN 的连乘梯度,缓解消失问题。
通过上述步骤,您可系统掌握 RNN 的核心理论、实现及优化方法。

GitCode 天启AI是一款由 GitCode 团队打造的智能助手,基于先进的LLM(大语言模型)与多智能体 Agent 技术构建,致力于为用户提供高效、智能、多模态的创作与开发支持。它不仅支持自然语言对话,还具备处理文件、生成 PPT、撰写分析报告、开发 Web 应用等多项能力,真正做到“一句话,让 Al帮你完成复杂任务”。
更多推荐
所有评论(0)