笔记3:torch训练测试VGG网络

(1)利用Netron查看网络实际情况

在这里插入图片描述
上图链接
python生成上图代码如下,其中GETVGGnet是搭建VGG网络的程序GETVGGnet.py,VGGnet是该程序中的搭建网络类。netron是需要pip安装的可视化库,注意do_constant_folding=False可以防止Netron中不显示Batchnorm2D层,禁用参数隐藏。

import torch
from torch.autograd import Variable
from GetVGGnet import VGGnet
import netron

net = VGGnet()
x = Variable(torch.FloatTensor(1,3,28,28))
y = net(x)
print(y.data.shape)
onnx_path = "./save_model/VGGnet.onnx"
torch.onnx.export(net, x, onnx_path,do_constant_folding=False)
print(net)
netron.start(onnx_path)

(2)VGG训练测试全过程

此次训练在CPU上进行,迭代次epoch = 10,迭代内轮次batch=300,训练集10000张,测试集2000张。
train loss和train corre分别代表损失和正确率,横轴是不同迭代下每一个伦次的loss&corre累加,一个迭代进行33个轮次,每个迭代最后一个伦次数据不足被网络舍弃,10个迭代总共320次。test loss和test corre是每个一个迭代下所有伦次的正确率平均值。根据图可以看出,训练和测试结果都较好。
在这里插入图片描述
训练的损失和正确率在波动,但总体趋势较好。
在这里插入图片描述
数据集大小可以在此处修改:在这里插入图片描述

代码:cifar10_handle和GetVGGnet在上几篇文章有说明

#!/usr/bin/env python
# -*- coding:utf-8 -*-
"""
@author: 楠楠星球
@time: 2024/5/10 10:15 
@file: VGGTrain.py-->test
@project: pythonProject
@# ------------------------------------------(one)--------------------------------------
@# ------------------------------------------(two)--------------------------------------
"""
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from GetVGGnet import VGGnet
from cifar10_handle import train_dataset,test_dataset
import matplotlib.pyplot as plt

epoch = 10  #迭代次数
learn_rate = 0.01 #初始学习率

net = VGGnet().to(device='cpu') #模型实例化
loss_fun = nn.CrossEntropyLoss() #调用损失函数
train_data_loder = DataLoader(dataset=train_dataset,
                              batch_size=300,  #每一次迭代的调用的波次
                              shuffle=True,    #这个波次是否打乱数据集
                              num_workers=4,   # 线程数
                              drop_last=True)  # 最后一个波次数据不足是否舍去

test_data_loder = DataLoader(dataset=test_dataset,
                             batch_size=300,
                             shuffle=False,
                             num_workers=4,
                             drop_last=True)

# optimizer = torch.optim.Adam(net.parameters(), lr=learn_rate)
optimizer = torch.optim.SGD(net.parameters(), lr=learn_rate, momentum=0.5) #优化器

# scheduler = torch.optim.lr_scheduler.StepLR(optijumizer, step_size=5, gamma=0.9) #step_size=1表示每迭代一次更新一下学习率
scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.7) #学习率调整器


def train(epoch_num,train_net):
	# ------------------------------------------()--------------------------------------
	loss_base = []
	corre_base = []
	
	test_loss_base = []
	test_corre_base =[]
	for epoch in range(epoch_num):
		# ------------------------------------------(TRAIN)--------------------------------------
		train_net.train()
		for i, data in enumerate(train_data_loder):
			input_tensor, label = data
			input_tensor = input_tensor.to(device='cpu')
			label = label.to(device='cpu')
			
			output_tensor = train_net(input_tensor)
			loss = loss_fun(output_tensor, label)
			
			optimizer.zero_grad()
			loss.backward()
			optimizer.step()
			
			_, pred = torch.max(output_tensor.data, dim=1)
			correct = pred.eq(label.data).cpu().sum()
			
			print(f"训练中:第{epoch + 1}次迭代的小迭代{i}的损失率为:{1.00 * loss.item()},正确率为:{100.00 * correct / 300}")
			loss_base.append(loss.item())
			corre_base.append(100.00 * correct.item() / 300)
	
		scheduler.step()
		
		# ------------------------------------------(TEST)--------------------------------------
		sum_test_loss = 0
		sum_test_corre = 0
		train_net.eval()
		for i, test_data in enumerate(test_data_loder):
			input_tensor, label = test_data
			input_tensor = input_tensor.to(device='cpu')
			label = label.to(device='cpu')
			
			output_tensor = train_net(input_tensor)
			loss = loss_fun(output_tensor, label)
			_, pred = torch.max(output_tensor.data, dim=1)
			correct = pred.eq(label.data).cpu().sum()

			sum_test_loss += loss.item()
			sum_test_corre += correct.item()
			
		
		test_loss = sum_test_loss * 1.0 / len(test_data_loder)
		test_corre = sum_test_corre * 100.0 / len(test_data_loder) / 300
		test_loss_base.append(test_loss)
		test_corre_base.append(test_corre)
		print(f"测试中:当前迭代的测试集损失为:{test_loss},正确率为:{test_corre}")
	return loss_base,corre_base,test_loss_base,test_corre_base
	# ------------------------------------------()--------------------------------------

if __name__ == '__main__':
	[train_loss,train_corre,test_loss,test_corr] = train(epoch,net)
	fig, axes = plt.subplots(2, 2)
	
	axes[0, 0].plot(list(range(1, len(train_loss)+1 )), train_loss,color ='r')
	axes[0, 0].set_title('train loss')
	
	axes[0, 1].plot(list(range(1, len(train_corre) + 1)), train_corre, color ='r')
	axes[0, 1].set_title('train corre')
	
	axes[1, 0].plot(list(range(1, len(test_loss) + 1)), test_loss,color ='r')
	axes[1, 0].set_title('test loss')

	axes[1, 1].plot(list(range(1, len(test_corr) + 1)), test_corr,color ='r')
	axes[1, 1].set_title('test corre')
	plt.show()
	
	# torch.save(net.state_dict(), './save_model/example1.pt')

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/629679.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

26 分钟惊讶世界,GPT-4o 引领未来人机交互

前言 原文链接:OpenAI最新模型——GPT-4o,实时语音视频交互,未来人机交互近在眼前 - Kaiho小站 北京时间 5 月 14 日凌晨,OpenAI 发布新一代模型——GPT-4o,仅在 ChatGPT 面世 17 个月后,OpenAI 再次通过…

掌握这些神器,让你的编程之路更加“丝滑”

前言: 在软件开发的旅程中,程序员的实用神器确实如同指南针,帮助他们在复杂的代码海洋中导航。以下是从三个方向——自动化测试工具、持续集成/持续部署(CI/CD)以及代码审查与质量分析——来探讨这些实用神器的应用和影…

2024软件测试必问的常见面试题1000问!

01、您所熟悉的测试用例设计方法都有哪些?请分别以具体的例子来说明这些方法在测试用例设计工作中的应用。 答:有黑盒和白盒两种测试种类,黑盒有等价类划分法,边界分析法,因果图法和错误猜测法。白盒有逻辑覆盖法&…

一种基于电场连续性的高压MOSFET紧凑模型,用于精确表征电容特性

来源:A Compact Model of High-Voltage MOSFET Based on Electric Field Continuity for Accurate Characterization of Capacitance(TED 24年) 摘要 本文提出了一种新的高压MOSFET(HV MOS)紧凑模型,以消…

机器学习(3)

目录 3-1线性回归 3-2最小二乘解 3-3多元线性回归 3-4广义线性模型 3-5对率回归 3-6对率回归求解 3-7线性判别分析 3-8LDA的多类推广 3-9多分类学习基本思路 3-10类别不平衡 3-1线性回归 线性模型为什么重要? 人类在考虑问题时,通常…

用c++实现快速排序、最大子段和问题

6.2.2 快速排序 【问题】快速排序(quick sort)的分治策略如下(图6-5)。 (1)划分:(选定一个记录作为轴值,以轴值为基准将整个序列划分为两个子序列,轴值的位置在划分的过程中确定,并且左侧子序列的所有记录…

全网最全的基于电机控制的38类simulink仿真全家桶----新手大礼包

整理了基于电机的38种simulink仿真全家桶,包含多种资料,类型齐全十分适合新手学习使用。包括但是不局限于以下: 1、基于多电平逆变器的无刷直流电机驱动simulink仿真 2、基于负载转矩的感应电机速度控制simulink仿真 3、基于滑膜观测器的永…

【全开源】JAVA情侣扭蛋机情侣游戏系统源码支持微信小程序+微信公众号+H5

让爱情更添趣味与惊喜 在繁忙的生活中,情侣们总是渴望找到一种新颖而有趣的方式来增进彼此的感情。为此,我们特别推出了“情侣扭蛋机情侣游戏系统”,让你们的爱情之旅更加充满趣味与惊喜。 情侣扭蛋机不仅是一个简单的游戏工具,…

计算机的内存是如何实现的

你好,我是 shengjk1,多年大厂经验,努力构建 通俗易懂的、好玩的编程语言教程。 欢迎关注!你会有如下收益: 了解大厂经验拥有和大厂相匹配的技术等 希望看什么,评论或者私信告诉我! 文章目录 一…

成功解决Uncaught TypeError: Failed to resolve module specifier “vue“.

成功解决Uncaught TypeError: Failed to resolve module specifier “vue”. 一、问题背景 俗话说,温故而知新。首先,非常感谢我许哥,教会了我网页相关的知识,其他方面我也受益良多。言归正传,最近由于要运行Python&a…

【C语言习题】12.扫雷游戏

文章目录 1.扫雷游戏分析和设计1.1 扫雷游戏的功能说明1.2游戏界面:1.3游戏的分析和设计1.2.1 数据结构的分析1.2.2 ⽂件结构设计 2.扫雷游戏的代码实现3.代码讲解 1.扫雷游戏分析和设计 1.1 扫雷游戏的功能说明 使用控制台实现经典的扫雷游戏游戏可以通过菜单实现…

记笔记从学Typora开始--------------------(1)下载、安装、购买、激活

一、登录Typora官网 官网地址:Typora 二、鼠标往下滑,点击下载按钮 三、下载得到安装包,双击 四、一直点击下一步,进行安装 五、安装完成 六、启动Typoera 七、针对欢迎界面点击下一页 八、一直点击直到弹出以下软件激活界面 九…

深度盘点在当今经济形势下资深项目经理或PMO的或去或从

在当今经济形势下,资深项目经理(Project Manager)或项目管理办公室(PMO)的去向和选择受到多种因素的影响。以下是对他们可能面临的或去或从的深度盘点: 1、发展去向 1. 深化专业领域:在经济形势…

跨ROS系统通信:使用TCP实现节点间的直连

当涉及到在机器人操作系统(ROS)环境中的通信时,标准做法通常是在同一个ROS网络内通过话题和服务进行。但在某些特定情况下,比如当你有两个分布在不同网络中的ROS系统时,标准的通信方法可能不太适用。此时,一…

超实用的excel进销存管理系统(75份),自带库存预警,直接用!

进销存(Inventory Management)是企业管理中的一个核心组成部分,它涉及到商品的采购(进货)、销售和存储(库存)等环节。有效的进销存管理可以帮助企业降低成本、提高效率和客户满意度。 1. 采购管…

线程池的一些问题

核心线程数1.最大线程5.队列5.存活时间10s 1.场景一 如果核心线程数.被一直占用得不到释放.新进来1个任务.会怎么样?答: 会在队列中中死等. 只要进来的任务.不超过队列的长度,就会一直挡在队列中死等 package com.lin;import java.util.concurrent.Executors; import java.u…

knife4j案例

1.导入 <dependency><groupId>com.github.xiaoymin</groupId><artifactId>knife4j-spring-boot-starter</artifactId> </dependency>2.在配置类中加入 knife4j 相关配置并设置静态资源映射&#xff08;否则接口文档页面无法访问&#xff…

基于Python的jieba库分析《斗破苍穹》文本中的高频词汇

分析《斗破苍穹》文本中的高频词汇 在进行文本分析时&#xff0c;了解文本中出现频率较高的词汇对于把握文本的主题和风格非常有帮助。本文将介绍如何使用Python的jieba库对《斗破苍穹》这部小说的文本进行分词处理&#xff0c;并统计高频词汇的出现次数&#xff08;本文只统计…

【机器学习】:基于决策树与随机森林对数据分类

机器学习实验报告&#xff1a;决策树与随机森林数据分类 实验背景与目的 在机器学习领域&#xff0c;决策树和随机森林是两种常用的分类算法。决策树以其直观的树形结构和易于理解的特点被广泛应用于分类问题。随机森林则是一种集成学习算法&#xff0c;通过构建多个决策树并…

图解堆排序【一眼看穿逻辑思路】

P. S.&#xff1a;以下代码均在VS2019环境下测试&#xff0c;不代表所有编译器均可通过。 P. S.&#xff1a;测试代码均未展示头文件stdio.h的声明&#xff0c;使用时请自行添加。 目录 1、堆的概念2、实现堆排序前的准备工作3、堆排序的思路3.1 第一步3.2 第二步 4、结语 1、…