学习pytorch20 pytorch完整的模型验证套路

pytorch完整的模型验证套路

  • 使用非数据集的测试数据,测试训练好模型的效果
  • 代码
  • 预测结果
  • 解决报错

B站小土堆pytorch学习视频 https://www.bilibili.com/video/BV1hE411t7RN/?p=32&spm_id_from=pageDriver&vd_source=9607a6d9d829b667f8f0ccaaaa142fcb

在这里插入图片描述

使用非数据集的测试数据,测试训练好模型的效果

 测试:训练好的模型,提供对外真实数据的一个实际应用

从网上下载两张图片,整理图片的输入格式,输入模型测试模型效果
请添加图片描述
请添加图片描述

代码

import torch
from torch import nn
from torchvision import transforms
from PIL import Image
import cv2

dog_path = './images/dog.jpg'
airplane_path = './images/airplane.jpg'
model_path = './images/net_epoch9_gpu.pth'

dog_pil = Image.open(dog_path)
airp_pil = Image.open(airplane_path)
print(dog_pil)  # <PIL.JpegImagePlugin.JpegImageFile image mode=RGB size=258x174 at 0x237CC2CBE50>
# RGB 3通道 匹配模型输入的通道数
dog_pil = dog_pil.convert('RGB')  # def convert(self, mode=None, matrix=None, dither=None, palette=Palette.WEB, colors=256):
airp_pil = airp_pil.convert('RGB')
# dog_cv = cv2.imread(dog_path)  # numpy.array
# # print(dog_cv)
# img_trans = torchvision.transforms.ToTensor()  # 实例化转tensor的类
# dog_tensor = img_trans(dog_pil)
# dog_cv_tensor = img_trans(dog_cv)
# print(dog_tensor)
# print(dog_tensor.shape)
# print(dog_cv_tensor)
# 输入模型shape 需要是32*32大小的
transform = transforms.Compose([transforms.Resize((32,32)),
                                transforms.ToTensor()])
dog_tensor = transform(dog_pil)
airp_tensor = transform(airp_pil)
# print(dog_tensor)
print(dog_tensor.shape, airp_tensor.shape)


class Cifar10Net(nn.Module):
    def __init__(self):
        super(Cifar10Net, self).__init__()
        self.net = nn.Sequential(
            nn.Conv2d(in_channels=3, out_channels=32, kernel_size=5, stride=1, padding=2),
            nn.MaxPool2d(kernel_size=2),
            nn.Conv2d(32, 32, 5, 1, 2),
            nn.MaxPool2d(kernel_size=2),
            nn.Conv2d(32, 64, 5, 1, 2),
            nn.MaxPool2d(kernel_size=2),
            nn.Flatten(),
            nn.Linear(64*4*4, 64),
            nn.ReLU(),
            nn.Linear(64, 10)
        )

    def forward(self, x):
        x = self.net(x)
        return x

# 加载模型要考虑是以哪种形式保存的模型  模型保存方式1:保存模型结构和参数 方式二:只保存模型参数
model = torch.load(model_path, map_location=torch.device('cpu'))
print(model)
dog_tensor = dog_tensor.reshape((1, 3, 32, 32))
airp_tensor = airp_tensor.reshape((1, 3, 32, 32))
model.eval() # 设置模型为测试状态 网络层有dropout batchNormal层不加eval函数会有问题
with torch.no_grad(): # 测试不做梯度计算 节省算力
    dog_output = model(dog_tensor)
    airp_output = model(airp_tensor)
print(dog_output)
print(dog_output.argmax())
print(dog_output.argmax(1))
print(airp_output)
print(airp_output.argmax(1))  # 概率值不便于解读 使用argmax 可以很方便的读出模型预测的是哪个类别

预测结果

在这里插入图片描述

tensor([[ 1.1317, -4.3441,  3.2116,  2.8930,  2.6749,  4.6079, -3.2860,  3.1357,
         -3.0432, -4.1703]])
tensor(5)
tensor([5])
tensor([[ 5.5993, -0.6140,  4.4758,  0.8463,  1.6311, -1.0217, -3.9990, -2.8343,
          1.1050, -1.6423]])
tensor([0])

预测结果和训练数据的标注一直,预测正确

解决报错

RuntimeError: mat1 and mat2 shapes cannot be multiplied (64x16 and 1024x64)
解决: dog_tensor = dog_tensor.reshape((1, 3, 32, 32)) 转换输入是4维的, 模型输入有一个batch-size维度

RuntimeError: Attempting to deserialize object on a CUDA device but torch.cuda.is_available() is False. If you are running on a CPU-only machine, please use torch.load with map_location=torch.device(‘cpu’) to map your storages to the CPU.
解决:model = torch.load(model_path, map_location=torch.device(‘cpu’))
在gpu上训练的模型,要在cpu上测试,模型加载的时候指定cpu设备

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/239943.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

postman接口测试之Postman配置环境变量和全局变量

前言  我们在测试的过程中&#xff0c;遇到最多的问题也可以是环境的问题了吧&#xff0c;今天开发用了这个测试环境&#xff0c;明天又换了另一个测试环境&#xff0c;这样对于我们测试非常的麻烦&#xff0c;特别最接口的时候需要来回的输入环境地址比较麻烦&#xff0c;今天…

基于Python+WaveNet+MFCC+Tensorflow智能方言分类—深度学习算法应用(含全部工程源码)(四)

目录 前言引言总体设计系统整体结构图系统流程图 运行环境模块实现1. 数据预处理2. 模型构建3. 模型训练及保存4. 模型生成 系统测试1. 训练准确率2. 测试效果 相关其它博客工程源代码下载其它资料下载 前言 博主前段时间发布了一篇有关方言识别和分类模型训练的博客&#xff…

动态规划学习——通符串匹配,正则表达式

目录 ​编辑 一&#xff0c;通符串匹配 1.题目 2.题目接口 3&#xff0c;解题思路及其代码 二&#xff0c;正则表达 1.题目 2.题目接口 3.解题思路及其代码 三&#xff0c;交错字符串 1.题目 2&#xff0c;题目接口 3.解题思路及其代码 一&#xff0c;通符串匹配 1…

基于Python+WaveNet+MFCC+Tensorflow智能方言分类—深度学习算法应用(含全部工程源码)(三)

目录 前言引言总体设计系统整体结构图系统流程图 运行环境模块实现1. 数据预处理2. 模型构建1&#xff09;定义模型结构2&#xff09;优化损失函数 3. 模型训练及保存1&#xff09;模型训练2&#xff09;模型保存3&#xff09;映射保存 相关其它博客工程源代码下载其它资料下载…

让植被管理更精准:数据可视化的新利器

【小编整理了300可视化大屏源文件&#xff0c;需要可后台私~&#xff01;】 在当今时代&#xff0c;数据可视化技术已经成为了一个非常重要的技术。对于植被管理来说&#xff0c;数据可视化也有着非常重要的作用。通过将植被管理数据可视化&#xff0c;我们可以更加清晰地了解植…

Apache Flink(十一):Flink集群部署-Standalone集群部署

🏡 个人主页:IT贫道_大数据OLAP体系技术栈,Apache Doris,Clickhouse 技术-CSDN博客 🚩 私聊博主:加入大数据技术讨论群聊,获取更多大数据资料。 🔔 博主个人B栈地址:豹哥教你大数据的个人空间-豹哥教你大数据个人主页-哔哩哔哩视频 目录 1. 节点划分

SpringCloud-高级篇(七)

前面在微服务里整合了Seata&#xff0c;下面利用Seata去解决分布式事务的问题&#xff0c;回去学习Seata中的四种解决方案 &#xff1a;首先学习XA模式 &#xff08;1&#xff09;XA模式 RM在前面讲的是资源管理器&#xff0c;在XA标准中RM都是由数据库来实现的&#xff0c;数…

数据挖掘目标(Kaggle Titanic 生存测试)

import numpy as np import pandas as pd import matplotlib.pyplot as plt import seaborn as sns1.数据导入 In [2]: train_data pd.read_csv(r../老师文件/train.csv) test_data pd.read_csv(r../老师文件/test.csv) labels pd.read_csv(r../老师文件/label.csv)[Su…

oracle详细安装教程(附带百度网盘资源)

一,下载安装包途径 1.官网 Unauthorized Request 2.百度网盘分析 https://pan.baidu.com/s/1n221gdTK0Fcho839oRab9g 提取码1q2w 二&#xff0c;安装教程 1.下载完安装包后点击 setup.exe 如果出现一下的问题&#xff0c;使用windows10等系统安装oracle 11g等版本的数据库…

二叉树的最大深度

问题描述&#xff1a; 给定一个二叉树 root &#xff0c;返回其最大深度。 二叉树的 最大深度 是指从根节点到最远叶子节点的最长路径上的节点数。 示例 1&#xff1a; 输入&#xff1a;root [3,9,20,null,null,15,7] 输出&#xff1a;3示例 2&#xff1a; 输入&#xff1…

ue4 解决角度万向锁的问题 蓝图节点

问题&#xff1a;当角度值从359-1变化的时候&#xff0c;数值会经历358、357… 解决方法&#xff1a;勾上Shortest Path&#xff0c;角度值的会从359-1

Ajax原理以及优缺点

Ajax原理 1.Ajax的原理简单来说是在用户和服务器之间加了—个中间层(AJAX引擎)&#xff0c;通过XmlHttpRequest对象来向服务器发异步请求&#xff0c; 2.从服务器获得数据&#xff0c;然后用javascript来操作DOM而更新页面。使用户操作与服务器响应异步化。 3.这其中最关键的一…

「Verilog学习笔记」简易秒表

专栏前言 本专栏的内容主要是记录本人学习Verilog过程中的一些知识点&#xff0c;刷题网站用的是牛客网 timescale 1ns/1nsmodule count_module(input clk,input rst_n,output reg [5:0]second,output reg [5:0]minute);always (posedge clk or negedge rst_n) begin if (~rst…

Axure电商产品移动端交互原型,移动端高保真Axure原型图(RP源文件手机app界面UI设计模板)

本作品是一套 Axure8 高保真移动端电商APP产品原型模板&#xff0c;包含了用户中心、会员成长、优惠券、积分、互动社区、运营推广、内容推荐、商品展示、订单流程、订单管理、售后及服务等完整的电商体系功能架构和业务流程。 本模板由一百三十多个界面上千个交互元件及事件组…

各地加速“双碳”落地,数字能源供应商怎么选?

作者 | 曾响铃 文 | 响铃说 随着我国力争2030年前实现“碳达峰”、2060年前实现“碳中和”的“双碳”目标提出&#xff0c;为各地区、各行业的低碳转型和绿色可持续发展制定“倒计时”时间表&#xff0c;一场围绕“数字能源”、“智慧能源”、“新能源”等关键词的创新探索进…

二百一十六、Flume——Flume拓扑结构之负载均衡和故障转移的开发案例(亲测,附截图)

一、目的 对于Flume的负载均衡和故障转移拓扑结构&#xff0c;进行一个开发测试 二、负载均衡和故障转移 &#xff08;一&#xff09;结构含义 Flume支持使用将多个sink逻辑上分到一个sink组 &#xff08;二&#xff09;结构特征 sink组配合不同的SinkProcessor可以实现负…

《地理信息系统原理》笔记/期末复习资料(10. 空间数据挖掘与空间决策支持系统)

目录 10. 空间数据挖掘与空间决策支持系统 10.1. 空间数据挖掘 10.1.1. 空间数据挖掘的概念 10.1.2. 空间数据挖掘的方法与过程 10.1.3. 空间数据挖掘的应用 10.2. 空间决策支持系统 10.2.1. 空间决策支持系统的概念 10.2.2. 空间决策支持系统的结构 10.2.3. 空间决策…

Onlyoffice本地部署超详细教程(附协作空间2.0新资讯)

陈老老老板&#x1f934; &#x1f9d9;‍♂️本文专栏&#xff1a;生活&#xff08;主要讲一下自己生活相关的内容&#xff09;生活就像海洋,只有意志坚强的人,才能到达彼岸。 &#x1f9d9;‍♂️本文简述&#xff1a;ONLYOFFICE相信大家已经有所了解&#xff0c;本篇讲一下o…

mjpg-streamer配置其它端口访问视频

环境 树莓派4B ubuntu 20.04 U口摄像头 确认摄像头可访问 lsusb查看 在dev下可查看到video* sudo mplayer tv://可打开摄像头并访问到视频 下载mjpg-streamer并编译安装 在github下载zip包&#xff0c;下载的源码&#xff0c;需要编译安装 unzip解压 cd mjpg-streamer/mjp…

聚观早报 |一加12首销;华为智能手表释放科技温暖

【聚观365】12月12日消息 一加12首销 华为智能手表释放科技温暖 卡尔动力获地平线战略投资 英伟达希望在越南建立基地 努比亚Z60 Ultra影像规格揭晓 一加12首销 现在有最新消息&#xff0c;近日一加12该机已于昨日开售&#xff0c;售价4299元起。 外观方面&#xff0c;全…