文章目录
- 1. 准备工作
- 2. 训练网络
- 3. 测试网络
- 4. 训练和测试循环
- 5. 模型保存
- 6. 最终完整代码
- 7. 结果截图
使用PyTorch训练一个手写数字识别模型(MNIST)
在这篇博客中,使用了PyTorch构建一个简单的神经网络来识别手写数字。将使用MNIST数据集,这是一个经典的机器学习基准数据集,用于测试各种模型的性能。MNIST数据集包含大约60,000个训练样本和10,000个测试样本,每个样本都是一个28x28像素的手写数字图像,标签为0到9。
1. 准备工作
首先导入必要的库,并定义一个简单的神经网络结构。这个神经网络由三个线性层组成,每个线性层之间使用ReLU激活函数进行激活。最后一层使用log softmax作为输出。类似下图
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
class Net(nn.Module):
def __init__(self):
super(Net, self).__init__()
self.fc1 = nn.Linear(28 * 28, 128)
self.fc2 = nn.Linear(128, 64)
self.fc3 = nn.Linear(64, 10)
def forward(self, x):
x = x.view(-1, 28 * 28)
x = F.relu(self.fc1(x))
x = F.relu(self.fc2(x))
x = self.fc3(x)
return F.log_softmax(x, dim=1)
初始化神经网络、损失函数和优化器。
net = Net()
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(), lr=0.01, momentum=0.9)
定义数据预处理的转换。将图像转换为PyTorch张量,并对图像进行标准化处理。
transform = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.1307<