LSTM 与 GAN 实现

本项目围绕两个经典建模任务展开:手动实现 LSTM 单元用于回文序列预测,以及在 MNIST 数据集上训练全连接 GAN 并进行潜在空间插值实验。两部分均接近底层实现——不使用高层 API 封装——以直接展示对底层机制的理解。

在 GitHub 上查看


亮点

  • 从头实现 LSTM 门控逻辑(输入门、遗忘门、细胞门、输出门;隐藏状态与细胞状态更新),不使用 torch.nn.LSTM
  • 在不同长度的回文序列上验证自定义 LSTM,分析预测精度随序列长度增加的退化情况
  • 从头构建全连接 GAN 用于 MNIST:生成器使用 BatchNorm 和 Tanh,判别器使用 LeakyReLU 和 sigmoid,交替进行对抗优化
  • 通过潜在向量插值实验,验证生成器所学表示的连续性
  • 完整保留训练产物:各阶段样本网格图、损失曲线、准确率图、序列化生成器 checkpoint 及书面技术报告

Part 1 — 手动 LSTM 序列建模

LSTM 在 lstm.py 中实现,未使用 torch.nn.LSTM。所有四个门和细胞/隐藏状态更新方程均通过 nn.Linear 层显式编写,使递归逻辑完全透明。

训练设置: 可配置序列长度的回文数据集,梯度裁剪,逐 epoch 损失和准确率追踪。

训练损失 — 序列长度 10
准确率 — 序列长度 10
训练损失 — 序列长度 30(更长依赖)
准确率 — 序列长度 30

模型对中短长度回文序列表现可靠,随序列长度增加退化较慢,体现了门控机制在保持长距离信息方面的作用。


Part 2 — 在 MNIST 上训练全连接 GAN

基于标准生成器-判别器对抗框架,在 MNIST 上训练多层感知机 GAN:

  • 生成器 — 堆叠线性层,BatchNorm,LeakyReLU 激活,Tanh 输出
  • 判别器 — 线性层,LeakyReLU,sigmoid 输出
  • 训练 — 交替更新生成器和判别器,BCE 损失;定期保存样本网格图
训练初期 — 噪声、无结构样本
训练中期 — 数字结构初现
训练末期 — 可辨认的数字样本

使用全连接层(而非卷积层)使得每个训练阶段的目标函数和模型行为更易于观察和分析。


潜在空间插值

在两对潜在向量之间进行插值,验证生成器是否学到了平滑、连续的表示,而非仅记忆离散样本。

两个潜在向量之间的插值
平滑过渡验证了连续的潜在空间结构

技术概览

   
语言 Python 3
框架 PyTorch
模型 手动实现 LSTM,全连接 GAN(MLP)
任务 序列建模、图像生成、潜在空间插值
关键组件 自定义门控方程、对抗训练循环、BatchNorm、梯度裁剪
产物 损失/准确率曲线、分阶段样本网格图、生成器 checkpoint、技术报告