以下是使用Mermaid语法绘制的详细流程图,解释了这段MNIST手写数字识别代码的完整执行流程:

评估模型
禁用梯度
设置评估模式
遍历测试集
前向传播
预测类别
统计正确数
计算准确率
打印准确率
训练模型
遍历epochs
设置训练模式
遍历batches
梯度清零
前向传播
计算损失
反向传播
参数更新
记录损失
打印epoch结果
构建神经网络
输出
初始化三层全连接层
定义SimpleNN类
定义ReLU和Softmax
模型实例
数据预处理
输出
加载MNIST数据集
定义transform
创建DataLoader
训练集/测试集
开始
数据预处理
构建神经网络
训练模型
评估模型
结束

流程详解:

  1. 数据预处理阶段

    ToTensor
    Normalize
    download=True
    batch_size=64
    B1
    B2
    B3
    B4
    • 先定义图像转换规则(归一化+标准化)
    • 自动下载MNIST数据集
    • 创建可分批次加载的数据加载器
  2. 模型构建阶段

    flowchart LR
        C1 -->|继承nn.Module| C2
        C2 -->|784→128→64→10| C3
        C3 -->|激活函数| C4
    
    • 构建三层的全连接网络
    • 每层后接ReLU激活函数
    • 输出层使用Softmax
  3. 训练阶段

    epochs=5
    batch=64
    forward
    CrossEntropy
    backward
    Adam优化
    D2
    D3
    D5
    D6
    D7
    D8
    D9
    • 前向传播计算预测值
    • 用交叉熵计算损失
    • 反向传播更新参数
    • 循环5个epoch
  4. 评估阶段

    flowchart LR
        E3 -->|10000张测试图| E4
        E4 -->|model(images)| E5
        E5 -->|torch.max| E6
        E6 -->|correct/total| E7
    
    • 禁用梯度计算模式
    • 统计预测正确的数量
    • 计算最终准确率

关键点说明:

  • 数据流向:原始图像 → 张量 → 标准化 → 网络各层 → 概率输出
  • 训练循环:每次处理64张图片,共重复5轮完整数据集遍历
  • 梯度更新:Adam优化器根据损失梯度调整784×128+128×64+64×10=109,386个参数
  • 评估机制:比较预测概率最大的类别与实际标签是否一致

这个流程图展示了从数据加载到模型评估的完整机器学习pipeline,每个步骤都包含PyTorch的关键操作。实际执行时,数据会像流水线一样依次通过这些处理阶段。

Logo

GitCode 天启AI是一款由 GitCode 团队打造的智能助手,基于先进的LLM(大语言模型)与多智能体 Agent 技术构建,致力于为用户提供高效、智能、多模态的创作与开发支持。它不仅支持自然语言对话,还具备处理文件、生成 PPT、撰写分析报告、开发 Web 应用等多项能力,真正做到“一句话,让 Al帮你完成复杂任务”。

更多推荐