01——LenNet网络结构,图片识别

目录

1、model.py文件 (预训练的模型)

2、train.py文件(会产生训练好的.th文件)

3、predict.py文件(预测文件)

4、结果展示:


1、model.py文件 (预训练的模型)

import torch.nn as nn
import torch.nn.functional as F


class LeNet(nn.Module):
    def __init__(self):
        super(LeNet, self).__init__()
        # RGB图像;  这里用了16个卷积核;卷积核的尺寸为5x5的
        self.conv1 = nn.Conv2d(3, 16, 5)  # 输入的是RBG图片,所以in_channel为3; out_channels=卷积核个数;kernel_size:5x5的
        self.pool1 = nn.MaxPool2d(2, 2)  # kernal_size:2x2   stride:2
        self.conv2 = nn.Conv2d(16, 32, 5)  # 这里使用32个卷积核;kernal_size:5x5
        self.pool2 = nn.MaxPool2d(2, 2)
        self.fc1 = nn.Linear(32*5*5, 120)  # 全连接层的输入,是一个一维向量,所以我们要把输入的特征向量展平。
                                           # 将得到的self.poolx(x) 的output(32,5,5)展开;  图片上给的全连接层是120
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)  # 这里的10,是需要根据训练集修改的

    def forward(self, x):   # 正向传播
        # Pytorch Tensor的通道排序:[channel,height,width]
        '''
            卷积后的尺寸大小计算:
                N = (W-F+2P)/S + 1
                其中,默认的padding:0   stride:1
                    ①输入图片大小:WxW
                    ②Filter大小 FxF  (卷积核大小)
                    ③步长S
                    ④padding的像素数P
        '''
        x = F.relu(self.conv1(x))   # 输入特征图为32x32大小的RGB图片;  input(3,32,32)  output(16,28,28)
        x = self.pool1(x)           # 经过最大下采样会将图片的高度和宽度:缩小为原来的一半  output(16,14,14)   池化层,只改变特征矩阵的高和宽;
        x = F.relu(self.conv2(x))   # output(32, 10, 10)  因为第二个卷积层的卷积核大小是32个,这里就是32
        x = self.pool2(x)           # 经过最大下采样会将图片的高度和宽度:缩小为原来的一半output(32, 5, 5)

        x = x.view(-1, 32*5*5)   # x.view()  将其展开成一维向量,-1表示第一个维度batch需要自动推理
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x
# 测试下
# import torch
# input1 = torch.rand([32,3,32,32])
# model = LeNet()
# print(model)
# output = model(input1)

2、train.py文件(会产生训练好的.th文件)

import matplotlib.pyplot as plt
import numpy as np
import torch.utils.data
import torchvision
from torch import nn, optim
from torchvision import transforms

from pilipala_pytorch.pytorch_learning.Test1_pytorch_demo.model import LeNet

# 1、下载数据集
# 图形预处理 ;其中transforms.Compose()是用来组合多个图像转换操作的,使得这些操作可以顺序地应用于图像。
transform = transforms.Compose(
    [transforms.ToTensor(),   # 将PIL图像或ndarray转换为torch.Tensor,并将像素值的范围从[0,255]缩放到[0.0, 1.0]
     transforms.Normalize((0.5, 0.5, 0.5),(0.5, 0.5, 0.5))]   # 对图像进行标准化;标准化通常用于使模型的训练更加稳定。
)
# 50000张训练图片
train_ds = torchvision.datasets.CIFAR10('data',
                                        train=True,
                                        transform=transform,
                                        download=False)
# 10000张测试图片
test_ds = torchvision.datasets.CIFAR10('data',
                                       train=False,
                                       transform=transform,
                                       download=False)
# 2、加载数据集
train_dl = torch.utils.data.DataLoader(train_ds, batch_size=36, shuffle=True, num_workers=0)    # shuffle数据是否是随机提取的,一般设置为True
test_dl = torch.utils.data.DataLoader(test_ds, batch_size=10000, shuffle=True, num_workers=0)

test_image,test_label = next(iter(test_dl))  # 将test_dl 转换为一个可迭代的迭代器,通过next()方法获取数据

classes = ('plane','car','bird','cat','deer','dog','frog','horse','ship','truck')

'''
    标准化处理:output = (input - 0.5) / 0.5
  反标准化处理: input = output * 0.5 + 0.5 = output / 2 + 0.5
'''
# 测试下展示图片
# def imshow(img):
#     img = img / 2 + 0.5   # unnormalize  反标准化处理
#     npimg = img.numpy()
#     plt.imshow(np.transpose(npimg, (1,2,0)))
#     plt.show()
#
# # 打印标签
# print(' '.join('%5s' % classes[labels[j]] for j in range(4)))
# imshow(torchvision.utils.make_grid(test_image))


# 实例化网络模型
net = LeNet()
# 定义相关参数
loss_function = nn.CrossEntropyLoss()  # 定义损失函数
optimizer = optim.Adam(net.parameters(), lr=0.001)  # 定义优化器, 这里使用的是Adam优化器
# 训练过程
for epoch in range(5):  # 定义循环,将训练集迭代多少轮
    running_loss = 0.0  # 叠加,训练过程中的损失
    for step,data in enumerate(train_dl,start=0):  # 遍历训练集样本
        inputs, labels = data   # 获取图像及其对应的标签
        optimizer.zero_grad()  # 将历史梯度清零;如果不清除历史梯度,就会对计算的历史梯度进行累加

        outputs = net(inputs)   # 将输入的图片输入到网络,进行正向传播
        loss = loss_function(outputs, labels)  # outputs网络预测的值, labels真实标签
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        if step % 500 == 499:
            with torch.no_grad():  # with 是一个上下文管理器
                outputs = net(test_image)  # [batch,10]
                predict_y = torch.max(outputs, dim=1)[1]   # 网络预测最大的那个
                accuracy = (predict_y == test_label).sum().item() / test_label.size(0)  # 得到的是tensor  (predict_y == test_label).sum()  要通过item()拿到数值
                print("[%d, %5d] train_loss: %.3f test_accuracy:%.3f" % (epoch + 1, step + 1, running_loss / 500, accuracy))
                running_loss = 0.0
print('Finished Training')

save_path = './Lenet.pth'  # 保存模型
torch.save(net.state_dict(), save_path)  # net.state_dict() 模型字典;save_path 模型路径

测试下展示图片:

运行下,train.py文件,看下正确率、损失率:

3、predict.py文件(预测文件)

import torch
import torchvision.transforms as transforms
from PIL import Image
from model import LeNet

transform = transforms.Compose(
    [transforms.Resize((32, 32)),
     transforms.ToTensor(),
     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
     ])

classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship' , 'truck')

net = LeNet()
net.load_state_dict(torch.load('Lenet.pth'))  # 加载train里面的训练好 产生的模型。

im = Image.open('2.jpg')  # 载入准备好的图片
im = transform(im)  # 如果要将图片放入网络,进行正向传播,就得转换下格式   得到的结果为:[C,H,W]
im = torch.unsqueeze(im, dim=0)    # 增加一个维度;得到 [N,C,H,W],从而模拟一个批量大小为1的输入。

with torch.no_grad():  # 不需要计算损失梯度
    outputs = net(im)
    predict = torch.max(outputs, dim=1)[1].data.numpy()   # outputs是一个张量;torch.max()用于找到张量在指定维度上的最大值;
                                    # torch.max()函数返回两个张量,一个包含最大值,另一个包含最大值的作用。
                                    # .data()属性用于从变量中提取底层的张量数据。直接使用.data()已经被认为是不安全的,推荐使用.detach()
                                    # .numpy() 表示将pytorch转换成numpy数组,从而使用numpy库的各种功能来操作数据。
print(classes[int(predict)])

#     predict = torch.softmax(outputs,dim=1)  # 可以返回概率
# print(predict)

4、结果展示:

返回结果:预测是猫的概率为 86%。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/461395.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

Day17 深入类加载机制

Day17 深入类加载机制 文章目录 Day17 深入类加载机制一、初识类加载过程二、深入类加载过程三、利用类加载过程理解面试题四、类加载器五、类加载器分类六、类加载器之间的层次关系七、双亲委派模型 - 概念八、双亲委派模型 - 工作过程九、双亲委派模型 - 好处十、双亲委派原则…

Jmeter---分布式

分布式:多台机协作,以集群的方式完成测试任务,可以提高测试效率。 分布式架构:控制机(分发任务)与多台执行机(执行任务) 环境搭建: 不同的测试机上安装 Jmeter 配置基…

Sparse Convolution 讲解

文章目录 1. 标准卷积与Sparse Conv对比(1)普通卷积(2) 稀疏卷积(3) 改进的稀疏卷积(subm)2 Sparse Conv 官方API3. Sparse Conv 计算3. 1 Sparse Conv 计算流程3. 2 案例3.2.1 普通稀疏卷积3.2.2 subm模式的稀疏卷积3D点云数据非常稀疏,尤其体素化处理后(比如200k的点放…

【算法篇】七大基于比较的排序算法精讲

目录 排序 1.直接插入排序 2.希尔排序 3.直接选择排序 4.堆排序 5.冒泡排序 6.快速排序 7.归并排序 排序 排序算法的稳定性:假设在待排序的序列中,有多个相同的关键字,经过排序后,这些关键字的先后顺序不发生改变&#…

Spring项目问题—前后端交互:Method Not Allowed

问题 前后端交互时出现Method Not Allowed问题 Ajax中使用的是get,方法仍然出现post方法报错 Resolved [org.springframework.web.HttpRequestMethodNotSupportedException: Request method POST not supported] 浏览器中没有报错,只是接收不到后端返…

解锁数据潜力:OceanBase国产数据库学习不容错过的秘密!

介绍:OceanBase是一款由阿里巴巴和蚂蚁金服自主研发的通用分布式关系型数据库,它专为企业级应用而设计,具有金融级别的可靠性。以下是对OceanBase的详细介绍: 高可用性:OceanBase通过实现Paxos多数派协议和多副本特性&…

MySql入门教程--MySQL数据库基础操作

꒰˃͈꒵˂͈꒱ write in front ꒰˃͈꒵˂͈꒱ ʕ̯•͡˔•̯᷅ʔ大家好,我是xiaoxie.希望你看完之后,有不足之处请多多谅解,让我们一起共同进步૮₍❀ᴗ͈ . ᴗ͈ აxiaoxieʕ̯•͡˔•̯᷅ʔ—CSDN博客 本文由xiaoxieʕ̯•͡˔•̯᷅ʔ 原创 CSDN …

C语言从入门到熟悉------第四阶段

指针 地址和指针的概念 要明白什么是指针,必须先要弄清楚数据在内存中是如何存储的,又是如何被读取的。如果在程序中定义了一个变量,在对程序进行编译时,系统就会为这个变量分配内存单元。编译系统根据程序中定义的变量类型分配…

【滤波专题-第8篇】ICA降噪方法——类EMD联合ICA降噪及MATLAB代码实现(以VMD-ICA为例)

今天来介绍一种效果颇为不错的降噪方法。(针对高频白噪声) 上一篇文章我们讲到了FastICA方法。在现实世界的许多情况下,噪声往往接近高斯分布,而有用的信号(如语音、图像特征等)往往表现出非高斯的特性。F…

unity学习(60)——选择角色界面--MapHandler2-MapHandler.cs

1.新建一个脚本&#xff0c;里面有static变量loadingPlayerList using System; using System.Collections.Generic; using System.Linq; using System.Text; using System.Threading.Tasks;namespace Assets.Scripts.Model {internal class LoadData{public static List<Pl…

3D地图在BI大屏中的应用实践

前言 随着商业智能的不断发展&#xff0c;数据可视化已成为一项重要工具&#xff0c;有助于用户更好地理解数据和分析结果。其中&#xff0c;3D地图作为一种可视化工具&#xff0c;已经在BI大屏中得到了广泛地应用。 3D地图通过将地理信息与数据相结合&#xff0c;以更加直观…

【AI】用iOS的ML(机器学习)创建自己的AI App

用iOS的ML(机器学习)创建自己的AI App 目录 用iOS的ML(机器学习)创建自己的AI App机器学习如同迭代过程CoreML 的使用方法?软件要求硬件开始吧!!构建管道:设计和训练网络Keras 转 CoreML将模型集成到 Xcode 中结论推荐超级课程: Docker快速入门到精通Kubernetes入门到…

计算机网络——物理层(数据通信基础知识)

计算机网络——物理层&#xff08;1&#xff09; 物理层的基本概念数据通信的基本知识一些专业术语消息和数据信号码元 传输速率的两种表示方法带宽串行传输和并行传输同步传输和异步传输 信道基带信号调制常用编码方式 我们今天进入物理层的学习&#xff0c;如果还没有了解OSI…

Transformer代码从零解读【Pytorch官方版本】

文章目录 1、Transformer大致有3大应用2、Transformer的整体结构图3、如何处理batch-size句子长度不一致问题4、MultiHeadAttention&#xff08;多头注意力机制&#xff09;5、前馈神经网络6、Encoder中的输入masked7、完整代码补充知识&#xff1a; 1、Transformer大致有3大应…

C++ 入门篇

目录 1、了解C 2、C关键字 2、命名空间 2.1 命名空间的定义 2.2 命名空间的使用 3. C输入与输出 4.缺省参数 4.1 缺省参数的概念 4.2 缺省参数的分类 5. 函数重载 5.1 函数重载的概念 5.2 C中支持函数重载的原理--名字修饰 6. 引用 6.1 引用概念 6.2 引用…

Docker 系列2【docker安装mysql】【开启远程连接】

文章目录 前言开始步骤1.增加mysql挂载目录2.下载镜像2.启动容器具体步骤4.无法连接5.测试连接 总结 前言 本文开始&#xff0c;默认已经安装docker&#xff0c;如果你还没有完成这个步骤&#xff0c;请查看这一篇文章【docker安装与使用】 开始步骤 1.增加mysql挂载目录 m…

网络原理(1)——UDP协议

目录 一、应用层 举个例子&#xff1a;点外卖 约定数据格式简单粗暴的例子 客户端和服务器的交互&#xff1a; 序列化和返序列化 xml、json、protobuffer 1、xml 2、json 3、protobuffer 二、传输层 端口 端口号范围划分 认识知名的端口号 三、UDP协议 端口 U…

宜搭faas服务器报错Network response was not OK

[error] https://api.dingtalk.com/v1.0/yida/forms/instances? fetch error Error: Network response was not OK 不出意外的话肯定是请求代码的某个部分出了问题&#xff1a;其中formInstanceId和updateFormDataJson是业务的内容 我检查过是没问题的。appType和systemToken…

面试经典-MySQL篇

一、MySQL组成 MySQL数据库的连接池&#xff1a;由一个线程来监听一个连接上请求以及读取请求数据&#xff0c;解析出来一条我们发送过去的SQL语句SQL接口&#xff1a;负责处理接收到的SQL语句查询解析器&#xff1a;让MySQL能看懂SQL语句查询优化器&#xff1a;选择最优的查询…

【C#】【SAP2000】读取SAP2000中所有Frame对象在指定工况的温度荷载值到Grasshopper中

if (build true) {// 连接到正在运行的 SAP2000// 使用 COM 接口获取 SAP2000 的 API 对象cOAPI mySapObject (cOAPI)System.Runtime.InteropServices.Marshal.GetActiveObject("CSI.SAP2000.API.SapObject");// 获取 SAP2000 模型对象cSapModel mySapModel mySap…