我们建议将本教程作为 notebook 而不是脚本运行。要下载 notebook(.ipynb
)文件,请点击页面顶部的链接。
PyTorch 提供了精心设计的模块和类,如 torch.nn
、torch.optim
、Dataset
和 DataLoader
,帮助你创建和训练神经网络。为了充分利用这些工具的强大功能并根据你的问题进行定制,你需要真正理解它们的工作原理。为了帮助你理解,我们将首先在 MNIST 数据集上训练一个基本的神经网络,而不使用这些模型中的任何功能;最初我们只会使用最基本的 PyTorch 张量操作。然后,我们将逐步添加一个特性,来自 torch.nn
、torch.optim
、Dataset
或 DataLoader
,逐个展示每个部分的作用,以及它如何使代码变得更加简洁或灵活。
本教程假设你已经安装了 PyTorch,并且熟悉张量操作的基础知识。(如果你熟悉 Numpy 数组操作,你会发现这里使用的 PyTorch 张量操作几乎是一样的)。
1. MNIST 数据集设置
我们将使用经典的 MNIST 数据集,它由手绘数字(0 到 9)构