Pytorch实战01——CIAR10数据集

目录

1、model.py文件 (预训练的模型)

2、train.py文件(会产生训练好的.th文件)

3、predict.py文件(预测文件)

4、结果展示:


1、model.py文件 (预训练的模型)

import torch.nn as nn
import torch.nn.functional as F


class LeNet(nn.Module):
    def __init__(self):
        super(LeNet, self).__init__()
        # RGB图像;  这里用了16个卷积核;卷积核的尺寸为5x5的
        self.conv1 = nn.Conv2d(3, 16, 5)  # 输入的是RBG图片,所以in_channel为3; out_channels=卷积核个数;kernel_size:5x5的
        self.pool1 = nn.MaxPool2d(2, 2)  # kernal_size:2x2   stride:2
        self.conv2 = nn.Conv2d(16, 32, 5)  # 这里使用32个卷积核;kernal_size:5x5
        self.pool2 = nn.MaxPool2d(2, 2)
        self.fc1 = nn.Linear(32*5*5, 120)  # 全连接层的输入,是一个一维向量,所以我们要把输入的特征向量展平。
                                           # 将得到的self.poolx(x) 的output(32,5,5)展开;  图片上给的全连接层是120
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)  # 这里的10,是需要根据训练集修改的

    def forward(self, x):   # 正向传播
        # Pytorch Tensor的通道排序:[channel,height,width]
        '''
            卷积后的尺寸大小计算:
                N = (W-F+2P)/S + 1
                其中,默认的padding:0   stride:1
                    ①输入图片大小:WxW
                    ②Filter大小 FxF  (卷积核大小)
                    ③步长S
                    ④padding的像素数P
        '''
        x = F.relu(self.conv1(x))   # 输入特征图为32x32大小的RGB图片;  input(3,32,32)  output(16,28,28)
        x = self.pool1(x)           # 经过最大下采样会将图片的高度和宽度:缩小为原来的一半  output(16,14,14)   池化层,只改变特征矩阵的高和宽;
        x = F.relu(self.conv2(x))   # output(32, 10, 10)  因为第二个卷积层的卷积核大小是32个,这里就是32
        x = self.pool2(x)           # 经过最大下采样会将图片的高度和宽度:缩小为原来的一半output(32, 5, 5)

        x = x.view(-1, 32*5*5)   # x.view()  将其展开成一维向量,-1代表第一个维度
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x
# 测试下
# import torch
# input1 = torch.rand([32,3,32,32])
# model = LeNet()
# print(model)
# output = model(input1)

2、train.py文件(会产生训练好的.th文件)

import matplotlib.pyplot as plt
import numpy as np
import torch.utils.data
import torchvision
from torch import nn, optim
from torchvision import transforms

from pilipala_pytorch.pytorch_learning.Test1_pytorch_demo.model import LeNet

# 1、下载数据集
# 图形预处理 ;其中transforms.Compose()是用来组合多个图像转换操作的,使得这些操作可以顺序地应用于图像。
transform = transforms.Compose(
    [transforms.ToTensor(),   # 将PIL图像或ndarray转换为torch.Tensor,并将像素值的范围从[0,255]缩放到[0.0, 1.0]
     transforms.Normalize((0.5, 0.5, 0.5),(0.5, 0.5, 0.5))]   # 对图像进行标准化;标准化通常用于使模型的训练更加稳定。
)
# 50000张训练图片
train_ds = torchvision.datasets.CIFAR10('data',
                                        train=True,
                                        transform=transform,
                                        download=False)
# 10000张测试图片
test_ds = torchvision.datasets.CIFAR10('data',
                                       train=False,
                                       transform=transform,
                                       download=False)
# 2、加载数据集
train_dl = torch.utils.data.DataLoader(train_ds, batch_size=36, shuffle=True, num_workers=0)    # shuffle数据是否是随机提取的,一般设置为True
test_dl = torch.utils.data.DataLoader(test_ds, batch_size=10000, shuffle=True, num_workers=0)

test_image,test_label = next(iter(test_dl))  # 将test_dl 转换为一个可迭代的迭代器,通过next()方法获取数据

classes = ('plane','car','bird','cat','deer','dog','frog','horse','ship','truck')

'''
    标准化处理:output = (input - 0.5) / 0.5
  反标准化处理: input = output * 0.5 + 0.5 = output / 2 + 0.5
'''
# 测试下展示图片
# def imshow(img):
#     img = img / 2 + 0.5   # unnormalize  反标准化处理
#     npimg = img.numpy()
#     plt.imshow(np.transpose(npimg, (1,2,0)))
#     plt.show()
#
# # 打印标签
# print(' '.join('%5s' % classes[labels[j]] for j in range(4)))
# imshow(torchvision.utils.make_grid(test_image))


# 实例化网络模型
net = LeNet()
# 定义相关参数
loss_function = nn.CrossEntropyLoss()  # 定义损失函数
optimizer = optim.Adam(net.parameters(), lr=0.001)  # 定义优化器, 这里使用的是Adam优化器
# 训练过程
for epoch in range(5):  # 定义循环,将训练集迭代多少轮
    running_loss = 0.0  # 叠加,训练过程中的损失
    for step,data in enumerate(train_dl,start=0):  # 遍历训练集样本
        inputs, labels = data   # 获取图像及其对应的标签
        optimizer.zero_grad()  # 将历史梯度清零;如果不清除历史梯度,就会对计算的历史梯度进行累加

        outputs = net(inputs)   # 将输入的图片输入到网络,进行正向传播
        loss = loss_function(outputs, labels)  # outputs网络预测的值, labels真实标签
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        if step % 500 == 499:
            with torch.no_grad():  # with 是一个上下文管理器
                outputs = net(test_image)  # [batch,10]
                predict_y = torch.max(outputs, dim=1)[1]   # 网络预测最大的那个
                accuracy = (predict_y == test_label).sum().item() / test_label.size(0)  # 得到的是tensor  (predict_y == test_label).sum()  要通过item()拿到数值
                print("[%d, %5d] train_loss: %.3f test_accuracy:%.3f" % (epoch + 1, step + 1, running_loss / 500, accuracy))
                running_loss = 0.0
print('Finished Training')

save_path = './Lenet.pth'  # 保存模型
torch.save(net.state_dict(), save_path)  # net.state_dict() 模型字典;save_path 模型路径

3、predict.py文件(预测文件)

import torch
import torchvision.transforms as transforms
from PIL import Image
from model import LeNet

transform = transforms.Compose(
    [transforms.Resize((32, 32)),
     transforms.ToTensor(),
     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
     ])

classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship' , 'truck')

net = LeNet()
net.load_state_dict(torch.load('Lenet.pth'))  # 加载train里面的训练好 产生的模型。

im = Image.open('2.jpg')  # 载入准备好的图片
im = transform(im)  # 如果要将图片放入网络,进行正向传播,就得转换下格式   得到的结果为:[C,H,W]
im = torch.unsqueeze(im, dim=0)    # 增加一个维度;得到 [N,C,H,W],从而模拟一个批量大小为1的输入。

with torch.no_grad():  # 不需要计算损失梯度
    outputs = net(im)
    predict = torch.max(outputs, dim=1)[1].data.numpy()   # outputs是一个张量;torch.max()用于找到张量在指定维度上的最大值;
                                    # torch.max()函数返回两个张量,一个包含最大值,另一个包含最大值的作用。
                                    # .data()属性用于从变量中提取底层的张量数据。直接使用.data()已经被认为是不安全的,推荐使用.detach()
                                    # .numpy() 表示将pytorch转换成numpy数组,从而使用numpy库的各种功能来操作数据。
print(classes[int(predict)])

#     predict = torch.softmax(outputs,dim=1)  # 可以返回概率
# print(predict)

4、结果展示:

返回结果:预测是猫的概率为 86%。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/453923.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

QGIS 开发之旅二《构建插件工程》

上一篇文章写了二次开发环境的构建,这一章我们从零开始构建插件工程,并理解下QIGIS 如何识别插件程序的。 1、创建QGIS 工程 新建项目,选择下面的空工程 工程创建成功后,是下面的样子,没有任何文件 2、配置QGIS工程 …

docker私有仓库-harbor的搭建

docker 官方提供的私有仓库 registry,用起来虽然简单 ,但在管理的功能上存在不足。 Harbor是一个用于存储和分发Docker镜像的企业级Registry服务器,harbor使用的是官方的docker registry(v2命名是distribution)服务去完成。harbor在docker di…

基于Java (spring-boot)的人才招聘系统

一、项目介绍 公司: IT公司的注册与管理 招聘要求的发布与维护 站内私信 求职者: 招聘需求浏览 招聘需求筛选(按岗位、薪酬、城市、地区等) 简历编辑,建立投递等 站内私信 管理员: 用户信息维护 岗…

vue学习笔记24-组件事件配合v-model使用

搜索时v-model绑定的search数据时时发生变化 watch侦听器时时监察变化&#xff0c;一旦数据发生变化 &#xff0c;就实时发送数据给父组件 子组件的完整代码&#xff1a; <template>搜索&#xff1a;<input type"text" v-model"search"> <…

数学建模【时间序列】

一、时间序列简介 时间序列也称动态序列&#xff0c;是指将某种现象的指标数值按照时间顺序排列而成的数值序列。时间序列分析大致可分成三大部分&#xff0c;分别是描述过去、分析规律和预测未来&#xff0c;本篇将主要介绍时间序列分析中常用的三种模型&#xff1a;季节分解…

通过对话式人工智能实现个性化用户体验

智能交流新时代&#xff1a;如何选择对话式人工智能产品 在快速发展的数字环境中&#xff0c;对话式人工智能正在彻底改变企业与客户互动的方式。 通过集成机器学习、自然语言处理和语音识别等先进技术&#xff0c;对话式人工智能可提供个性化、无缝的用户体验。 了解对话式人…

数字工厂环境中3D开发工具HOOPS的应用及其价值

随着新一代工业技术的推进&#xff0c;数字工厂的理念逐渐被企业所接受和推广。其中&#xff0c;HOOPS技术平台在数字工厂的设定和运营中扮演了重要的角色&#xff0c;其应用带来了诸多优势。 HOOPS技术平台包含HOOPS Visualize, HOOPS Exchange和HOOPS Communicator等强大工具…

代码随想录day19(2)二叉树:二叉树的最大深度(leetcode104)

题目要求&#xff1a;求出二叉树的最大深度 思路&#xff1a;首先要区分二叉树的高度与深度。二叉树的高度是任一结点到叶子结点的距离&#xff0c;而二叉树的深度指的是任一节点到根节点的距离&#xff08;从1开始&#xff09;。所以求高度使用后序遍历&#xff08;从下往上&…

自动化测试系列-Selenium三种等待详解

第一种也是最简单粗暴的一种办法就是强制等待sleep(time)&#xff0c;强制让程序等time秒时间&#xff0c;不管程序能不能跟上速度&#xff0c;还是已经提前到了&#xff0c;都必须等time时长。 如下代码案例所示: from selenium import webdriverfrom time import sleepdriv…

软件测试Pytest实现接口自动化应该如何在用例执行后打印日志到日志目录生成日志文件?

Pytest可以使用内置的logging模块来实现接口自动化测试用例执行后打印日志到日志目录以生成日志文件。以下是实现步骤&#xff1a; 1、在pytest配置文件&#xff08;conftest.py&#xff09;中&#xff0c;定义一个日志输出路径&#xff0c;并设置logging模块。 import loggi…

【PLC】现场总线和工业以太网汇总

1、 现场总线 1.1 什么是现场总线 1&#xff09;非专业描述&#xff1a; 如下图&#xff1a;“人机界面”一般通过以太网连接“控制器(PLC)”&#xff0c;“控制器(PLC)”通过 “现场总线”和现场设备连接。 2&#xff09;专业描述&#xff08;维基百科&#xff09; 现场总线…

「建议收藏」常用adb操作命令详解

1、查看当前运行的所有设备 adb devices 返回当前设备列表 这个命令是查看当前连接的设备, 连接到计算机的android设备或者模拟器将会列出显示 2、安装软件 adb install 验证是否成功。需要到设备的 data/app路径下查看是否有该包名 这个命令将指定的apk文件安装到设备上 …

案例分析篇00-【历年案例分析真题考点汇总】与【专栏文章案例分析高频考点目录】(2024年软考高级系统架构设计师冲刺知识点总结-案例分析篇-先导篇)

专栏系列文章&#xff1a; 2024高级系统架构设计师备考资料&#xff08;高频考点&真题&经验&#xff09;https://blog.csdn.net/seeker1994/category_12593400.html 案例分析篇01&#xff1a;软件架构设计考点架构风格及质量属性 案例分析篇11&#xff1a;UML设计考…

Windows系统安装OpenSSH结合VS Code远程ssh连接Ubuntu【内网穿透】

&#x1f308;个人主页: Aileen_0v0 &#x1f525;热门专栏: 华为鸿蒙系统学习|计算机网络|数据结构与算法|MySQL| ​&#x1f4ab;个人格言:“没有罗马,那就自己创造罗马~” #mermaid-svg-mEkKUraSFHLKkzIj {font-family:"trebuchet ms",verdana,arial,sans-serif;f…

MySQL 数据库 下载地址 国内阿里云站点

https://mirrors.aliyun.com/mysql/ 以 MySQL 5.7 为例 https://mirrors.aliyun.com/mysql/MySQL-5.7/ 各个版本很齐全&#xff0c;在这里下载要比去 MySQL 官网下载快很多&#xff0c;安逸得很 那我们下期见&#xff0c;拜拜&#xff01;

什么洗地机值得推荐?旗舰洗地机希亦、追觅、西屋、海尔实际表现如何?

洗地机这个产品相信大家已经不陌生了&#xff0c;它集合吸尘器和电动扫地拖把的功能&#xff0c;轻轻推拉便可以解决地面上的赃物&#xff0c;且不用我们手动清洗滚刷&#xff0c;深得家务人的喜爱&#xff0c;可是&#xff0c;当我们真正要去选购的时候&#xff0c;还是很纠结…

linux统计一个文件有多少行

第一种方法 cat -n anaconda-ks.cfg&#xff08;文件&#xff09; 第二种 wc 文件名 wc是 words count之意

Linux系统安全②SNAT与DNAT

一.SNAT 1.定义 利用SNAT技术实现2台私网地址都可以访问公网 2.实验环境准备 &#xff08;1&#xff09;三台服务器&#xff1a;PC1客户端、PC2网关、PC3服务端。 &#xff08;2&#xff09;硬件要求&#xff1a;PC1和PC3均只需一块网卡、PC2需要2块网卡 &#xff08;3&a…

[linux][调度] linux 下如何观察线程调度延时 ?

1 什么是调度延时 在实际业务中&#xff0c;大多数情况下&#xff0c;线程都不是一直占着 cpu 在运行的。在读写文件&#xff0c;收发网络数据时&#xff0c;可能会阻塞&#xff0c;阻塞就会进入睡眠&#xff0c;触发调度&#xff1b;在线程等待锁的时候&#xff0c;也可能会触…

ELK 安装部署

文章目录 1.日志收集规划2.Elasticsearch部署2.1.Elasticsearch安装2.2.Elasticsearch-head安装2.3.Elasticsearch设置分片数2.4.elasticsearch健康检查 3.Kibana部署4.Logstash部署5.Filebeat部署 开源中间件 # Elastic Stackhttps://iothub.org.cn/docs/middleware/ https:/…