pytorch实现图像分类器

pytorch实现图像分类器

  • 一、定义LeNet网络模型
    • 1,卷积 Conv2d
    • 2,池化 MaxPool2d
    • 3,Tensor的展平:view()
    • 4,全连接 Linear
    • 5,代码:定义 LeNet 网络模型
  • 二、训练并保存网络参数
    • 1,数据预处理
    • 2,数据集
    • 3,代码
  • 三、图像分类测试

一、定义LeNet网络模型

pytorch 中的卷积、池化、输入输出层中参数的含义与位置,可参考下图:
在这里插入图片描述

1,卷积 Conv2d

常用的卷积(Conv2d)在pytorch中对应的函数是

# in_channels:输入特征矩阵的深度。如输入一张RGB彩色图像,那in_channels=3
# out_channels:输入特征矩阵的深度。也等于卷积核的个数,使用n个卷积核输出的特征矩阵深度就是n
# kernel_size:卷积核的尺寸。可以是int类型,如3 代表卷积核的height=width=3,也可以是tuple类型如(3, 5)代表卷积核的height=3,width=5
# stride:卷积核的步长。默认为1,和kernel_size一样输入可以是int型,也可以是tuple类型
# padding:补零操作,默认为0。可以为int型如1即补一圈0,如果输入为tuple型如(2, 1) 代表在上下补2行,左右补1列。
torch.nn.Conv2d(in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True, padding_mode='zeros')

经卷积后的输出层尺寸计算公式为:
在这里插入图片描述
例如:当定义 Conv2d(3, 16, 5) 和 input(3, 32, 32),步长 S 为 1,P 为0时,此时卷积核尺度 F 为5,W 为32,计算得到 output(16, 28, 28)

2,池化 MaxPool2d

最大池化(MaxPool2d)在 pytorch 中对应的函数是:

MaxPool2d(kernel_size, stride)

3,Tensor的展平:view()

注意到,在经过第二个池化层后,数据还是一个三维的Tensor (32, 5, 5),需要先经过展平后(3255)再传到全连接层:

  x = self.pool2(x)            # 第二个池化层 output(32, 5, 5)
  x = x.view(-1, 32*5*5)       # 展平 output(32*5*5)
  x = F.relu(self.fc1(x))      # 传到全连接层 output(120)

4,全连接 Linear

全连接(Linear)在 pytorch 中对应的函数是:

Linear(in_features, out_features, bias=True)

5,代码:定义 LeNet 网络模型

model.py

# 定义LeNet网络模型
import torch.nn as nn
import torch.nn.functional as F

class LeNet(nn.Module):                  # 继承于nn.Module这个父类
    def __init__(self):                  # 初始化网络结构
        super(LeNet, self).__init__()    # 多继承需用到super函数
        self.conv1 = nn.Conv2d(3, 16, 5)
        self.pool1 = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(16, 32, 5)
        self.pool2 = nn.MaxPool2d(2, 2)
        self.fc1 = nn.Linear(32*5*5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):            # 正向传播过程
        x = F.relu(self.conv1(x))    # input(3, 32, 32) output(16, 28, 28)
        x = self.pool1(x)            # output(16, 14, 14)
        x = F.relu(self.conv2(x))    # output(32, 10, 10)
        x = self.pool2(x)            # output(32, 5, 5)
        x = x.view(-1, 32*5*5)       # output(32*5*5)
        x = F.relu(self.fc1(x))      # output(120)
        x = F.relu(self.fc2(x))      # output(84)
        x = self.fc3(x)              # output(10)
        return x

二、训练并保存网络参数

1,数据预处理

ToTensor:把输入的图像数据为 shape (H x W x C) in the range [0, 255] 转化为 shape (C x H x W) in the range [0.0, 1.0],同时将 image 和 numpy 输入格式转化为 tensor
Normalize:标准化

transform = transforms.Compose(
    [transforms.ToTensor(),
     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

2,数据集

用的是CIFAR10数据集,是 pytorch 自带的一个很经典的图像分类数据集,一共包含 10 个类别的 RGB 彩色图片。
在这里插入图片描述
注意:第一次运行程序,需要下载数据集到本地,所以第一次运行训练集下载时download=True为True,下载完成后改为False。测试集的加载则不用变化。

3,代码

名词定义
epoch对训练集的全部数据进行一次完整的训练,称为 一次 epoch
batch由于硬件算力有限,实际训练时将训练集分成多个批次训练,每批数据的大小为 batch_size
iteration 或 step对一个batch的数据训练的过程称为 一个 iteration 或 step
# 加载数据集并训练,训练集计算loss,测试集计算accuracy,保存训练好的网络参数
import torch
import torchvision
import torch.nn as nn
from model import LeNet 
import torch.optim as optim
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import numpy as np
import time

# 数据预处理
transform = transforms.Compose(
    [transforms.ToTensor(),
     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

# 导入、加载训练集
# 导入50000张训练图片
train_set = torchvision.datasets.CIFAR10(root='./data',      # 数据集存放目录
                                        train=True,          # 表示是数据集中的训练集
                                        download=False,       # 第一次运行时为True,下载数据集,下载完成后改为False
                                        transform=transform) # 预处理过程
# 加载训练集,实际过程需要分批次(batch)训练                                        
train_loader = torch.utils.data.DataLoader(train_set,      # 导入的训练集
                                           batch_size=50,  # 每批训练的样本数
                                           shuffle=False,  # 是否打乱训练集
                                           num_workers=0)  # 使用线程数,在windows下设置为0

# 导入测试集
# 导入10000张测试图片
test_set = torchvision.datasets.CIFAR10(root='./data', 
                                        train=False,    # 表示是数据集中的测试集
                                        download=False,transform=transform)
# 加载测试集
test_loader = torch.utils.data.DataLoader(test_set, 
                                          batch_size=10000, # 每批用于验证的样本数
                                          shuffle=False, num_workers=0)
# 获取测试集中的图像和标签,用于accuracy计算
test_data_iter = iter(test_loader)
test_image, test_label = test_data_iter.next()

#训练过程
net = LeNet()                                       # 定义训练的网络模型
loss_function = nn.CrossEntropyLoss()               # 定义损失函数为交叉熵损失函数 
optimizer = optim.Adam(net.parameters(), lr=0.001)  # 定义优化器(训练参数,学习率)

for epoch in range(5):  # 一个epoch即对整个训练集进行一次训练
    running_loss = 0.0	# 累加过程中的损失
    time_start = time.perf_counter()
    
    for step, data in enumerate(train_loader, start=0):   # enumerate遍历训练集,可以同时返回 data 和 步数,step从0开始计算
        inputs, labels = data 	# 获取训练集的图像和标签
        optimizer.zero_grad()   # 清除历史损失梯度
        
        # forward + backward + optimize
        outputs = net(inputs)  				  # 正向传播
        loss = loss_function(outputs, labels) # 计算损失
        loss.backward() 					  # 反向传播
        optimizer.step() 					  # 优化器更新参数

        # 打印耗时、损失、准确率等数据
        running_loss += loss.item()
        if step % 1000 == 999:    # print every 1000 mini-batches,每1000步打印一次
            with torch.no_grad(): # 在以下步骤中(验证过程中)不用计算每个节点的损失梯度,防止内存占用
                outputs = net(test_image) 				 # 测试集传入网络(test_batch_size=10000),output维度为[10000,10]
                predict_y = torch.max(outputs, dim=1)[1] # 以output中值最大位置对应的索引(标签)作为预测输出
                accuracy = (predict_y == test_label).sum().item() / test_label.size(0)
                
                print('[%d, %5d] train_loss: %.3f  test_accuracy: %.3f' %  # 打印epoch,step,loss,accuracy
                      (epoch + 1, step + 1, running_loss / 500, accuracy))
                
                print('%f s' % (time.perf_counter() - time_start))        # 打印耗时
                running_loss = 0.0

print('Finished Training')

# 保存训练得到的参数
save_path = './Lenet.pth'
torch.save(net.state_dict(), save_path)

三、图像分类测试

使用训练并保存好的网络参数,从数据集外找一张图像进行分类测试

# 导入包
import torch
import torchvision.transforms as transforms
from PIL import Image
from model import LeNet

# 数据预处理
transform = transforms.Compose(
    [transforms.Resize((32, 32)), # 首先需resize成跟训练集图像一样的大小
     transforms.ToTensor(),
     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])    # 数据标准化

# 导入要测试的图像
im = Image.open('./car.jpg').convert('RGB')    # 若图像为4通道,则用 convert('RGB') 转化为3通道,否则 transform 会报错
im = transform(im)  # [C, H, W]
im = torch.unsqueeze(im, dim=0)  # 对数据增加一个新维度,因为tensor的参数是[batch, channel, height, width] 

# 实例化网络,加载训练好的模型参数
net = LeNet()
net.load_state_dict(torch.load('Lenet.pth'))

# 预测
classes = ('plane', 'car', 'bird', 'cat',
           'deer', 'dog', 'frog', 'horse', 'ship', 'truck')
with torch.no_grad():
    outputs = net(im)
    predict = torch.max(outputs, dim=1)[1].data.numpy()    # 找出最大概率的下标
	predicts = torch.softmax(outputs , dim=1)    # 所有分类的预测概率
print(classes[int(predict)])
print(predicts)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/28936.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

Exception in thread “main“ java.lang.UnsupportedClassVersionError 50报错处理

之间正常走jenkinsdocker自动化部署的项目,今天改了一个文件,点了一下,竟然没有部署上去,提示如上,如下 Exception in thread "main" java.lang.UnsupportedClassVersionError: com/coocaa/tsp/sys/user/Use…

采用UWB定位技术开发的室内定位系统源码

UWB精准定位系统源码 UWB是什么? UWB(Ultra Wideband)超宽带技术是一种全新的、与传统通信技术有极大差异的通信新技术。它不需要使用传统通信体制中的载波,而是通过发送和接收具有纳秒或纳秒级以下的极窄脉冲来传输数据,实现精准定位。 技术…

机器鸟实现摆动尾巴功能

1. 功能说明 本文示例将实现R329样机机器鸟摆动尾巴的功能。 2. 电子硬件 在这个示例中,我们采用了以下硬件,请大家参考: 主控板 Basra主控板(兼容Arduino Uno)‍ 扩展板 Bigfish2.1扩展板‍ 电池7.4V锂电池 电路连接…

【头歌-Python】9.1 X射线衍射曲线绘制(project)第1~2关

第1关:X 射线衍射曲线 任务描述 本关任务:读文件中的数据绘制线图形。 相关知识 为了完成本关任务,你需要掌握: 1.python 读取文件 2.使用 matplotlib 绘制图形 python 读取文件 python读取文件可以用以下函数实现&#xff1…

chatgpt赋能python:Python收费介绍

Python收费介绍 什么是Python? Python是一种高级的、解释性、面向对象、纯粹的动态语言,多用于快速应用程序开发、脚本编写、系统管理任务等。它有一个简单直观优美的语法,非常容易学习。 Python的收费形式 Python语言本身是免费的,任何…

如何使用Jmeter进行http接口测试?

目录 前言: 一、开发接口测试案例的整体方案: 二、接口自动化适用场景: 三、接口测试环境准备 四、创建工程: 总结: 前言: 本文主要针对http接口进行测试,使用Jmeter工具实现。 Jmter工具设…

1.2 Scala变量与数据类型

一、变量声明 (一)、利用val声明变量 案例演示 (二)利用var声明变量 案例演示 (三)换行输入语句(续行) (四)同时声明多个变量 Scala还可以将多个变量放在一起…

射频电路layout总结

射频电路板设计由于在理论上还有很多不确定性,因此常被形容为一种“黑色艺术”,但这个观点只有部分正确,RF电路板设计也有许多可以遵循的准则和不应该被忽视的法则。在实际设计时,真正实用的技巧是当这些准则和法则因各种设计约束…

RTU遥测终端机的应用场景有哪些?

遥测终端机又称智能RTU遥测终端机,是一种用于采集、传输和处理遥测数据的设备。在现代科技的发展中,遥测终端机扮演着重要的角色。它是一种能够实现远程监测和控制的关键设备,广泛应用于各个领域,包括水文水利、环境监测、工业自动…

创新案例|专注在线 协作平台 设计产品中国首家PLG独角兽企业蓝湖如何实现98%的头部企业渗透率

蓝湖起步于2015年,是一款服务于产品经理、设计师、工程师的产品设计研发在线协作工具, 2021年10月,蓝湖宣布完成C轮融资,融资额高达10亿人民币,称为中国2B市场中首家采用PLG发展的独角兽企业,并实现了从100…

Web 自动化测试Selenium 之PO 模型

目录 1. po 模型介绍 2. PageObject 设计模式 3. PO 的核心要素 4. 非PO 实现 5. PO 实现 6. 总结 7. PO 模式的特点 总结: 1. po 模型介绍 在自动化中,Selenium 自动化测试中有一个名字经常被提及 PageObject (思想与面向对象的特征相同)&#x…

MySQL数据库高级操作

目录 MySQL中6种常见的约束克隆表清空表的数据记录临时表创建外键约束,保证数据的完整性和一致性。 MySQL中6种常见的约束 主键约束(primary key)外键约束(foreign key)非空约束(not null)唯一…

安卓端Google隐私沙盒归因报告聚焦

自2022年2月Google首次提出将推出隐私沙盒至今已一年有余。现在,安卓端的隐私沙盒Beta测试已针对特定Android13设备正式开始。作为早期测试者,Adjust很高兴与 Google一同迈出增强用户隐私的第一步,并在接下来的旅程中继续携手同行。为帮助移动…

opencv初学记录

准备工作: 1.找一张图片 2.准备python运行环境,并导入库,pip install opencv-python 读取文件,并打印维度 import cv2 #为什么是cv2不是cv呢,这个2指的是c的api,是为了兼容老板,cv指的就是c&am…

今天面了一个来字节要求月薪23K,明显感觉他背了很多面试题...

最近有朋友去字节面试,面试前后进行了20天左右,包含4轮电话面试、1轮笔试、1轮主管视频面试、1轮hr视频面试。 据他所说,80%的人都会栽在第一轮面试,要不是他面试前做足准备,估计都坚持不完后面几轮面试。 其实&…

ChatGPT/InstructGPT详解

前言 GPT系列是OpenAI的一系列预训练文章,GPT的全称是Generative Pre-Trained Transformer,顾名思义,GPT的目的就是通过Transformer为基础模型,使用预训练技术得到通用的文本模型。目前已经公布论文的有文本预训练GPT-1&#xff…

【操作系统】计算机操作系统知识点总结

文章目录 前言一、操作系统的概念与发展二、操作系统的结构与功能1、操作系统的结构2、操作系统的功能 三、进程管理1、进程2、进程的创建3、进程管理的实现4、进程控制块 四、内存管理1、内存2、内存管理3、内存管理的实现 五、文件系统1、文件系统2、文件系统的主要任务3、文…

Java开发 - 带你了解集群间的相互调用,你还在等什么?

目录 前言 导读 项目准备 集群准备 父工程引入子项目 服务调用方HelloService准备 pom文件 yml文件 Controller文件 服务提供方HelloWorld准备 pom文件 yml文件 Controller文件 运行此两个工程 hello_world组集群 集群调用测试 RestTemplate换成Dubbo行不行…

Word 2021入门指南:详细解读常用功能

软件安装:办公神器office2021安装教程,让你快速上手_正经人_____的博客-CSDN博客 一、 新建文档 打开Word 2021后,可以看到左上角的“文件”选项,点击它,在弹出的菜单中选择“新建”选项。然后可以选择空白文档或者使…

Linux安装和配置VCenter

Linux安装和配置VCenter 以下演示安装 Linux VCenter,也就是使用VMware-VCSA-all-6.7.0-13010631.iso 镜像包。通过一台 Windows服务器远程连接 ESXI 服务器安装 Linux 版本的 VCenter。也就是Windows 服务器只是安装的界面的一个载体。 Linux VCenter环境搭建 下…