1.导包
import torch
import matplotlib.pyplot as plt
import json
from model import AlexNet
from PIL import Image
from torchvision import transforms
2.数据预处理
data_transform = transforms.Compose(
[transforms.Resize((224, 224)), # 将图片重新裁剪
transforms.ToTensor(), # 转化为tensor
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]) # 标准化数据
3.加载测试图片
# load image
img = Image.open("1.jpeg") # 网上随便下载,放到好找的路径下
plt.imshow(img) # 直接载入图像
img = data_transform(img) 在预处理过程中吧channel提到前面
img = torch.unsqueeze(img, dim=0) # 添加batch维度
4.读取分类文件
# read class_indent
try:
# 读取保存在json文件中索引对应的类别名称
json_file = open('./class_indices,json', 'r')
class_indict = json.load(json_file) # 将json文件解码成字典格式
except Exception as e:
print(e)
exit(-1)
5.初始化网络
output = torch.squeeze(model(img)):先将图片通过正向传播得到输出,再把输出的batch压缩
predict = torch.softmax(output, dim=0):通过softmax得到一个概率分布
predict_cla = torch.argmax(predict).numpy():找到概率最大处所对应的索引值
print将类别名称和预测概率输出
# create model
model = AlexNet(num_classes=5)
model_weight_path = "./AlexNet.pth"
model.load_state_dict(torch.load(model_weight_path)) # 载入网络模型
model.eval() # 关闭dropout
with torch.no_grad():
output = torch.squeeze(model(img))
predict = torch.softmax(output, dim=0)
predict_cla = torch.argmax(predict).numpy()
print(class_indict[str(predict_cla)], predict[predict_cla].item())
plt.show()
6.预测结果
容易把玫瑰识别成郁金香,把蒲公英识别成向日葵,郁金香,向日葵,小雏菊可以很好的识别出来,模型的准确率还是有点低。大家自己尝试测试一下吧哈哈。
PyTorch搭建AlexNet网络合集:
PyTorch搭建AlexNet网络模型-CSDN博客
PyTorch搭建AlexNet训练集-CSDN博客
Pytorch搭建AlexNet 预测实现-CSDN博客