9-pytorch-现有模型使用及修改

b站小土堆pytorch教程学习笔记

1 使用ImageNet测试模型vgg16

train_data=torchvision.datasets.ImageNet('dataset/ImageNet',train=True
,download=True
,transform=torchvision.transforms.ToTensor())

代码运行报错:ImageNet数据集过大,导致现在无法公开访问。

2 查看VGG16是否预训练:

vgg16_false=torchvision.models.vgg16(pretrained=False)
vgg16_true=torchvision.models.vgg16(pretrained=True)

vgg16_false:想当于直接download一个模型,网络参数为默认初始化
在这里插入图片描述
vgg16_true:使用其他数据集预训练完成的包括一定参数的模型
在这里插入图片描述
输出模型:
VGG(
(features): Sequential(
(0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(1): ReLU(inplace=True)
(2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(3): ReLU(inplace=True)
(4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
(5): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(6): ReLU(inplace=True)
(7): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(8): ReLU(inplace=True)
(9): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
(10): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(11): ReLU(inplace=True)
(12): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(13): ReLU(inplace=True)
(14): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(15): ReLU(inplace=True)
(16): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
(17): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(18): ReLU(inplace=True)
(19): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(20): ReLU(inplace=True)
(21): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(22): ReLU(inplace=True)
(23): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
(24): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(25): ReLU(inplace=True)
(26): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(27): ReLU(inplace=True)
(28): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(29): ReLU(inplace=True)
(30): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
)
(avgpool): AdaptiveAvgPool2d(output_size=(7, 7))
(classifier): Sequential(
(0): Linear(in_features=25088, out_features=4096, bias=True)
(1): ReLU(inplace=True)
(2): Dropout(p=0.5, inplace=False)
(3): Linear(in_features=4096, out_features=4096, bias=True)
(4): ReLU(inplace=True)
(5): Dropout(p=0.5, inplace=False)
(6): Linear(in_features=4096, out_features=1000, bias=True)
)
)

out_features=1000即输出分类为1000

3 如何使用VGG训练CIFAR10 十分类数据集

  1. 直接将out_features=1000改为out_features=10
vgg16_false.classifier[6]=nn.Linear(4096,10)#将分类下的第六个层修改
print(vgg16_false)

(classifier): Sequential(
(0): Linear(in_features=25088, out_features=4096, bias=True)
(1): ReLU(inplace=True)
(2): Dropout(p=0.5, inplace=False)
(3): Linear(in_features=4096, out_features=4096, bias=True)
(4): ReLU(inplace=True)
(5): Dropout(p=0.5, inplace=False)
(6): Linear(in_features=4096, out_features=10, bias=True)
)

  1. 在最后一层线性层后再加一个输出层使得Linear(in_features=1000, out_features=10, bias=True)
vgg16_true.add_module('add_linear',nn.Linear(1000,10))#直接在最后加
print((vgg16_true))


(add_linear): Linear(in_features=1000, out_features=10, bias=True)
)
Process finished with exit code 0

3. 实际中,常见将VGG16当做前置网络结构用来提取特征,后接特定结构来完成特定任务

4 模型加载与保存

方式1

#保存
torch.save(vgg16_false,'vgg16_false_method1.pth')#保存模型结构及其参数
#加载
model=torch.load('vgg16_false_method1.pth')

方式2(官方推荐)

#保存
torch.save(vgg16_false.state_dict(),'vgg16_false_method2.pth')#将模型参数保存为Python中的字典格式
#加载
vgg16=torchvision.models.vgg16(pretrained=False)
vgg16.load_state_dict(torch.load('vgg16_false_method2.pth'))

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/408735.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

聊聊 Go 边界检查消除

前言 在这篇文章中碰巧看到了Go边界检查消除相关的讨论. 我也借此简单聊聊. 有这样一段代码, 非常简单, 就是一段求向量点积的程序: func sum(a, b []int) int {if len(a) ! len(b) {panic("must be same len")}ret : 0for i : 0; i < len(a); i {ret a[i] * …

SAM轻量化的终点竟然是RepViT + SAM

本文首发&#xff1a;AIWalker&#xff0c;欢迎关注~~ 殊途同归&#xff01;SAM轻量化的终点竟然是RepViT SAM&#xff0c;移动端速度可达38.7fps。 对于 2023 年的计算机视觉领域来说&#xff0c;「分割一切」&#xff08;Segment Anything Model&#xff09;是备受关注的一项…

0-1背包问题-动态规划

解法归纳&#xff1a; 一、如果装不下当前物品&#xff0c;那么前n个物品的最佳组合和前n-1个物品的最佳组合是一样的。 二、如果装得下当前物品。 假设1 :装当前物品&#xff0c;在给当前物品预留了相应空间的情况下&#xff0c;前n-1 个物品的最佳组 合加上当前物品的价值就…

作业 找单身狗2

方法一&#xff1a; 思路&#xff1a; 我们可以先创建一个新的数组&#xff0c;初始化为0&#xff0c;然后让原来的数组里面的元素作为新数组的下标 如果该下标对应的值为0&#xff0c;说明没有出现过该数&#xff0c;赋值为1作为标记&#xff0c;表示出现过1次 如果该下标…

#FPGA(基础知识)

1.IDE:Quartus II 2.设备&#xff1a;Cyclone II EP2C8Q208C8N 3.实验&#xff1a;正点原子-verilog基础知识 4.时序图&#xff1a; 5.步骤 6.代码&#xff1a;

代码随想录刷题第41天

首先是01背包的基础理论&#xff0c;背包问题&#xff0c;即如何在有限数量的货物中选取使具有一定容量的背包中所装货物价值最大。使用动规五步曲进行分析&#xff0c;使用二维数组do[i][j]表示下标从0到i货物装在容量为j背包中的最大价值&#xff0c;dp[i][j]可由不放物品i&a…

物理备份的方式

完全备份恢复流程 停止数据库清理环境重演回滚&#xff0d;&#xff0d;> 恢复数据修改权限启动数据库 1.关闭数据库&#xff1a; [rootmysql-server ~]# systemctl stop mysqld [rootmysql-server ~]# rm -rf /var/lib/mysql/* //删除所有数据// [rootmysql-server ~]# …

Sora:颠覆性AI视频生成工具

Sora是一款基于人工智能&#xff08;AI&#xff09;技术的视频生成工具&#xff0c;它彻底改变了传统视频制作的模式&#xff0c;为创作者提供了高效、便捷、高质量的视频内容生成方式。通过深度学习和自然语言处理等先进技术&#xff0c;Sora实现了从文字描述到视频画面的自动…

并发编程(5)共享模型之不可变

7 共享模型之不可变 本章内容 不可变类的使用不可变类设计无状态类设计 7.1 日期转换的问题 问题提出 下面的代码在运行时&#xff0c;由于 SimpleDateFormat 不是线程安全的, 有很大几率出现 java.lang.NumberFormatException 或者出现不正确的日期解析结果&#xff0c;…

SpringCloud Alibaba 2022之Nacos学习

SpringCloud Alibaba 2022使用 SpringCloud Alibaba 2022需要Spring Boot 3.0以上的版本&#xff0c;同时JDK需要是17及以上的版本。具体的可以看官网的说明。 Spring Cloud Alibaba版本说明 环境搭建 这里搭建的是一个聚合项目。项目结构如下&#xff1a; 父项目的pom.xm…

(拦截器)学习SpringMVC的第三天

一 .拦截器简介 拦截器的几个处理阶段 二 . 拦截器快速入门 2.1 实现拦截器接口 public class MyInterceptor1 implements HandlerInterceptor {Overridepublic boolean preHandle(HttpServletRequest request, HttpServletResponse response, Object handler) throws Excep…

微信小程序开启横屏调试

我们先打开小程序项目 开启真机运行 目前是一个竖屏的 然后打开全局配置文件 app.json 给下面的 window 对象 下面加一个 pageOrientation 属性 值为 landscape 运行结果如下 然后 我们开启真机运行 此时 就变成了个横屏的效果

(done) Positive Semidefinite Matrices 什么是半正定矩阵?如何证明一个矩阵是半正定矩阵? 可以使用特征值

参考视频&#xff1a;https://www.bilibili.com/video/BV1Vg41197ew/?vd_source7a1a0bc74158c6993c7355c5490fc600 参考资料(半正定矩阵的定义)&#xff1a;https://baike.baidu.com/item/%E5%8D%8A%E6%AD%A3%E5%AE%9A%E7%9F%A9%E9%98%B5/2152711?frge_ala 看看半正定矩阵的…

ubantu设置mysql开机启动

阅读本文之前请参阅----MySQL 数据库安装教程详解&#xff08;linux系统和windows系统&#xff09; 在Ubuntu系统中设置MySQL开机启动&#xff0c;通常有以下几种方法&#xff1a; 1. **使用systemctl命令**&#xff1a; Ubuntu 16.04及更高版本使用systemd作为…

Facebook群控:利用代理IP克服多账号关联

拥有多个 Facebook 帐户对于区分您的个人和企业在线形象或维护客户页面非常有用。然而&#xff0c;Facebook 的服务条款正式限制用户只能使用一个个人帐户&#xff0c;想要多账号运营&#xff0c;下面的干货必须看&#xff01; 一、Facebook群控是什么&#xff1f; Facebook群…

HDL FPGA 学习 - FPGA基本要素,开发流程,Verilog语法和规范、编写技巧

目录 Altera FPGA 基本要素 FPGA 开发流程和适用范围 设计和实施规范 顶层设计的要点 Verilog HDL 语法规范 编写规范 设计技巧 编辑整理 by Staok&#xff0c;始于 2021.2 且无终稿。转载请注明作者及出处。整理不易&#xff0c;请多支持。 本文件是“瞰百易”计划的…

线程计数器(CountDownLatch)

&#x1f96d;线程计数器&#xff08;CountDownLatch&#xff09; CountDownLatch也属于共享锁&#xff0c;其内部有一个int类型的属性表示可以同时并发并行的线程的数量 同时等待N个任务执行结束 举例说明&#xff1a; 比如跑步比赛&#xff0c;必须等所有运动员通过终点才…

Oracle EBS GL 外币折算逻辑

背景 由于公司财务在10月份期间某汇率维护错误,导致帐套折算以后并合传送至合并帐套生成合并日记帐凭证的借贷金额特别大,但是财务核对的科目余额有没有问题,始终觉得合并日记帐生成会计分发有问题,需要我们给出外币折算逻辑。 基础设置 汇率 Path: GL->设置->币种-&…

PHP语言检测用户输入密码及调用Python脚本

现在有一份计算流体力学N-S方程的Python脚本&#xff0c;想要在用户登录网站后可以可以运行该脚本&#xff0c;然后将脚本运行后绘制的图片显示在用户网页上。 建一个名为N_S.py的python脚本文件&#xff0c;这个脚本在生成图像后会自行关闭&#xff0c;随后将图片保存在指定的…

Stable Diffusion 3重磅发布

刚不久&#xff0c;Stability AI发布了Stable Diffusion 3.0&#xff0c;这一版本采用了与备受瞩目的爆火Sora相同的DiT架构。通过这一更新&#xff0c;画面质量、文字渲染以及对复杂对象的理解能力都得到了显著提升。由于这些改进&#xff0c;先前的技术Midjourney和DALL-E 3在…