【猫狗分类】Pytorch VGG16 实现猫狗分类5-预测新图片

背景
 

好了,现在开尝试预测新的图片,并且让vgg16模型判断是狗还是猫吧。

声明:整个数据和代码来自于b站,链接:使用pytorch框架手把手教你利用VGG16网络编写猫狗分类程序_哔哩哔哩_bilibili

预测

1、导包

from torchvision import transforms
from PIL import Image
import matplotlib.pyplot as plt
import torch
import torch.nn.functional as F
from net import vgg16

2、设置新照片的路径

test_pth=r'.\img.png'#设置可以检测的图像
test=Image.open(test_pth)

3、处理图片:图片变成tensor

transform=transforms.Compose([transforms.Resize((224,224)),transforms.ToTensor()])
image=transform(test)
  • transforms.Compose:这是一个类,可以将多个变换操作组合在一起。当你需要对数据执行一系列变换时,就会用到它。它接受一个变换函数列表作为参数。

4、设置设备

device=torch.device("cuda" if torch.cuda.is_available() else "cpu")#CPU与GPU的选择

5、加载网络(vgg16net)

net =vgg16()#输入网络

6、加载模型(权重模型)

model=torch.load(r".\DogandCat5.pth",map_location=device)#已训练完成的结果权重输入
net.load_state_dict(model)#模型导入

网络是网络,模型是模型!模型是训练出来的权重模型!网络是认为设定的!

7、模式选择(是训练模式还是推理模式)

net.eval()#设置为推测模式
  • 在PyTorch中,net.eval()是一个非常重要的方法调用,它用于改变模型的状态,使其从训练模式切换到推理(推测)模式。理解这一点很重要,因为模型在两种模式下的行为有所不同:
  • 训练模式 (net.train()): 在这种模式下,模型中的所有层都会处于活跃状态,包括像Dropout和Batch Normalization这样的层,它们会在每次前向传播时根据训练数据进行更新,引入随机性和依赖于批次的统计信息。这对于学习模型参数是非常必要的。

  • 推理模式 (net.eval()): 调用net.eval()后,模型会进入推理模式。这时,Dropout层将不起作用(即总是通过),而Batch Normalization层会使用在训练过程中计算得到的移动平均和方差,而不是 mini-batch 中的统计信息。这意味着模型的输出对于相同的输入将变得确定性,这对于测试和预测非常重要,因为你希望对同一输入多次运行模型时得到相同的结果。

  • 总结来说,当你准备好使用训练好的模型对新数据进行预测,而不是继续修改模型参数时,就应该调用net.eval()来确保模型以正确、一致的方式进行推理

8、传图片到网络,调整输入维度为四维张量

image=torch.reshape(image,(1,3,224,224))#四维图形,RGB三个通道

在PyTorch中,使用torch.reshape或者更常用的torch.Tensor.view方法可以改变张量的形状。对于图像数据,特别是当您准备将图像输入到深度学习模型时,将其调整为适合模型输入维度的四维张量是很常见的操作。

9、开始预测

with torch.no_grad():
    out=net(image)
out=F.softmax(out,dim=1) #softmax转为概率学问题
out=out.data.cpu().numpy()
print(out)
a=int(out.argmax(1))#输出最大值位置
  • with torch.no_grad():: 这一行代码用来指示PyTorch在接下来的代码块中不记录任何梯度信息。这对于推理(预测)阶段是非常重要的,因为它可以减少内存使用并加速计算过程,因为不需要为反向传播做准备。

  • out=net(image): 在上下文管理器torch.no_grad()内,将处理过的图像image输入到神经网络模型net中进行前向传播,得到模型的原始输出out。这个输出通常是未经处理的概率分布,对于分类任务,它通常代表每个类别的得分

  • out=F.softmax(out, dim=1): 使用F.softmax函数对模型输出out进行处理,该函数会将每一行的数据转换为概率分布,确保所有元素之和为1。这里dim=1表示沿着类别维度(通常对应于神经网络输出的最后一维)进行softmax操作,使得每个样本的预测结果可以解释为各类别的概率。

例举:假设你有一个简单的分类任务,模型需要区分猫、狗、鸟三种动物,即共有3个类别。你使用一个神经网络模型进行预测,对于一个批次内单个样本的输出可能看起来像这样(在未经过softmax处理前):

out_before_softmax = torch.tensor([2.0, 1.0, 0.5], dtype=torch.float32)

这里的输出张量out_before_softmax表示模型对于这个样本属于三个类别的原始打分或logits。注意,这些数值没有直接的概率意义,它们可以是任意实数。

应用Softmax

为了将这些原始分数转化为概率分布,你将使用F.softmax函数,并且指定dim=1,因为在这个一维张量的情况下,类别维度自然就是最后一维。执行操作后:

import torch.nn.functional as F
out_after_softmax = F.softmax(out_before_softmax, dim=1)
print(out_after_softmax)

输出解释

执行上述代码后,你可能会看到类似以下的输出(具体数值可能因四舍五入略有不同):

tensor([0.5561, 0.2476, 0.1963])

现在,out_after_softmax中的每个元素代表样本属于对应类别的概率,且所有概率之和为1(或接近1,由于浮点运算的精度限制)。例如,这里模型认为该样本有大约55.61%的概率是猫,24.76%的概率是狗,以及19.63%的概率是鸟。

总结

通过指定dim=1,你告诉softmax函数沿张量的最后一维进行操作,这在多分类任务中至关重要,因为它确保了每个样本的预测能够被合理地解释为各类别的概率分布。

  • out=out.data.cpu().numpy(): 将张量out从GPU(如果有的话)复制到CPU上,并转换为numpy数组,以便于进一步的处理和显示。这样做是因为后续的操作可能涉及到非PyTorch的库,如matplotlib用于绘图。

  • a=int(out.argmax(1)): 找出概率最大的类别索引,即预测的类别。argmax(1)沿着第1维度(类别维度)找到最大值的索引。argmax函数是用来找出数组或张量中最大值所在的位置(索引)。

10、显示图像

plt.figure()
list=['Cat','Dog']
plt.suptitle("Classes:{}:{:.1%}".format(list[a],out[0,a]))
plt.imshow(test)
plt.show()

  • list=['Cat','Dog']: 定义了一个类别标签列表,这里简化为猫和狗两类。
  • plt.suptitle(...): 设置图表的主标题,显示预测的类别名称及最高概率的百分比。
  • plt.imshow(test): 显示原始测试图像。
  • plt.show(): 显示整个图表,包括图像和标题。

给一张柴犬的照片,预测下:

                        

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mfbz.cn/a/714804.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

分数布朗运动FBM期权定价模型

BS定价模型和蒙特卡洛模拟期权定价方法都 假设标的资产价格的对数服从布朗运动 . 但是实际 的金融市场中标的资产价格运动过程具有 “尖峰厚尾 ” 现象 , 运用分数布朗运动 (FBM )来刻画标的资产 价格的运动过程可能更加合适。 …

HP惠普暗影精灵10 OMEN Gaming Laptop 16-wf1xxx原厂Win11系统镜像下载

惠普hp暗影精灵10笔记本电脑16-wf1000TX原装出厂Windows11,恢复开箱状态oem预装系统安装包,带恢复重置还原 适用型号:16-wf1xxx 16-wf1000TX,16-wf1023TX,16-wf1024TX,16-wf1025TX, 16-wf1026TX,16-wf1027TX,16-wf1028TX,16-wf1029TX, 16-wf1030TX,16-…

docker拉取镜像太慢解决方案

前言 这是我在这个网站整理的笔记,有错误的地方请指出,关注我,接下来还会持续更新。 作者:神的孩子都在歌唱 创建daemon.json文件,输入以下信息 vim /etc/docker/daemon.json{"registry-mirrors": ["https://9cpn8tt6.mirror…

JAVA开发 选择多个文件,系统运行后自动生成ZIP压缩包

选择多个文件,系统运行后自动生成ZIP压缩包 实现方法1.1 代码块1.2 运行结果截取 相关知识 实现方法 案例简述:通过启动java代码来打开文件选择器对话框,用户选择确认需要进行压缩的文件,可一次性选择多个文件,选择完…

边缘计算采集网关解决方案:为企业提供高效、灵活的数据处理方案-天拓四方

一、企业背景 某大型制造企业,位于国内某经济发达的工业园区内,拥有多个生产线和智能化设备,致力于提高生产效率、降低运营成本。随着企业规模的扩大和生产自动化的推进,该企业面临着海量数据处理、实时响应和网络安全等多重挑战…

安卓在Fragment控制状态栏显示隐藏

废话不多上效果 隐藏 显示 核心代码 首先是Framgrent package com.zx.tab;import android.content.Context; import android.os.Bundle; import android.view.LayoutInflater; import android.view.View; import android.view.ViewGroup; import android.widget.Button;impor…

Django中间件探索:揭秘中间件在Web应用中的守护角色与实战应用

系列文章目录 Django入门全攻略:从零搭建你的第一个Web项目Django ORM入门指南:从概念到实践,掌握模型创建、迁移与视图操作Django ORM实战:模型字段与元选项配置,以及链式过滤与QF查询详解Django ORM深度游&#xff…

Cisco Packet Tracer实验(四)

生成树协议(Spanning Tree Protocol) 交换机在目的地址未知或接收到广播帧时是要进行广播的。如果交换机之间存在回路/环路,那么就会产生广播循环风暴,从而严重影响网络性能。 而交换机中运行的STP协议能避免交换机之间发生广播…

登录MySQL方式

登录MySQL方式 方式一:通过MySQL自带的客户端 MySQL 客户端输入命令即可 方式二:通过window自带的客户端 从命令端(cmd)进入 mysql -h localhost -P 3306 -u root -p Enter password:密码登录方式: mysql -h 主…

【Java03】Java中数组在内存中的机制

1. 内存中的数组 Java中的数组是一种引用类型,数组变量(引用)和数组元素在内存中是分开的。 Java中的数组变量其实就是指针。 如果想要访问数组元素,只能通过这个数组的引用变量(指针)来访问。 实际数组对…

上海计算机考研避雷,25考研慎报

上大计算机一直很热 408考研er重来没有让我失望过,现在上大的专业课是11408,按理说,这个专业课的难度是很高的,但是408er给卷出了新高度,大家可以去上大官网看看今年最新的数据,我也帮大家统计了24年最新的…

【YOLOv10轻量级涨点改进:block优化 | 华为诺亚2023极简的神经网络模型 VanillaNet】

本文属于原创独家改进:一种极简的神经网络模型VanillaNet,以极简主义的设计为理念,网络中仅仅包含最简单的卷积计算,去掉了残差和注意力模块 计算量参数量比较,8.4 GFLOPs降低至6.1 GFLOPs YOLOv10n summary: 385 layers, 2709380 parameters, 2709364 gradients, 8.4 GF…

分享一个 .NET Core 使用选项方式读取配置内容的详细例子

前言 在 .NET Core 中,可以使用选项模式(Options Pattern)来读取和管理应用程序的配置内容。 选项模式通过创建一个 POCO(Plain Old CLR Object)来表示配置选项,并将其注册到依赖注入容器中,方…

图像处理:Python使用OpenCV进行图像锐化 (非锐化掩模、拉普拉斯滤波器)

文章目录 非锐化掩模 (Unsharp Masking)拉普拉斯滤波器 (Laplacian Filter)效果对比总结 在图像处理中,锐化操作用于增强图像的边缘和细节,使图像看起来更清晰。常见的图像锐化方法包括非锐化掩模(Unsharp Masking)和拉普拉斯滤波…

Linux:基础IO(二.缓冲区、模拟一下缓冲区、详细讲解文件系统)

上次介绍了:Linux:基础IO(一.C语言文件接口与系统调用、默认打开的文件流、详解文件描述符与dup2系统调用) 文章目录 1.缓冲区1.1概念1.2作用与意义 2.语言级别的缓冲区2.1刷新策略2.2具体在哪里2.3支持格式化 3.自己来模拟一下缓…

Burp Suite Professional 2024.5 (macOS, Linux, Windows) - Web 应用安全、测试和扫描

Burp Suite Professional 2024.5 (macOS, Linux, Windows) - Web 应用安全、测试和扫描 Burp Suite Professional, Test, find, and exploit vulnerabilities. 请访问原文链接:Burp Suite Professional 2024.5 (macOS, Linux, Windows) - Web 应用安全、测试和扫描…

JWT令牌、过滤器Filter、拦截器Interceptor

目录 JWT令牌 简介 JWT生成 解析JWT 登陆后下发令牌 过滤器(Filter) Filter快速入门 Filter拦截路径 过滤器链 登录校验Filter-流程 拦截器(Interceptor) Interceptor 快速入门 拦截路径 登录校验流程 JWT令牌 简介 全称:JSON Web Token(https://iwt.io/) …

springboot与flowable(12):网关服务(包容网关)

一、绘制流程图 包容网关可以看作是排他网关和并行网关的结合体。和排他网关一样,可以在外出顺序流上定义条件,包容网关会解析它们。但是主要的区别是包容网关可以选择多余一条顺序流,这和并行网关一样。包容网关的功能是基于进入和外出顺序流…

axure使用中继器画柱状图

源文件在顶部。 在axure通过读取中继器中的数据来画柱状图,如下图: 1)创建一个中继器,在里面创建两列:1列是柱状图底部的名称、2列是柱的高度,如下图: 2)双击中继器,画一…

华为OD机试 - 多段线数据压缩(Java 2024 D卷 100分)

华为OD机试 2024D卷题库疯狂收录中,刷题点这里 专栏导读 本专栏收录于《华为OD机试(JAVA)真题(D卷C卷A卷B卷)》。 刷的越多,抽中的概率越大,每一题都有详细的答题思路、详细的代码注释、样例测…