迁移学习|代码实现

图片

还记得我们之前实现的猫狗分类器吗?在哪里,我们设计了一个网络,这个网络接受一张图片,最后输出这张图片属于猫还是狗。实现分类器的过程比较复杂,准备的数据也比较少。所以我们是否可以使用一种方法,在数据很少的情况下仍然可以训练出较好的模型。

借助已经训练好的模型是个不错的想法。因此我们将学习如何使用预训练好的模型来构建只需要很少数据的先进的猫狗图像分类器。

图片

图片

首先,加载一个预训练的模型,例如ResNet18。

借助torchvision库,我们很容易获得一组已经训练好的模型。这些模型大多数接受一个称为pretrained的参数,当这个参数为True时,它会下载为ImageNet分类问题调整好的权重。就像这样:

from torchvision import modelsnetwork1=models.resnet18(pretrained=True)

当代码第一次运行时,需要一点时间...

接着,我们需要冻结所有层,所有权重不会随训练而更新。​​​​​​​

for param in network1.parameters():    param.requires_grad=False

图片

当然,这个模型并不是针对2分类问题,所以,我们需要将其最后一层的输出特征从1000改为2

首先我们要知道最后一层的名字:​​​​​​​

network1ResNet(  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)...  ...  ...  (fc): Linear(in_features=512, out_features=1000, bias=True)

最后一层是个全连接层,名为fc。

所以,我们就可将最后一层替换为输出特征为2的全连接层

network1.fc=nn.Linear(512,2)

注:此时,因为该层为新的层,所以其requires_grad=True,这样整个网络仅有这一层可以更新权重

打印网络​​​​​​​

network1ResNet(  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)  (relu): ReLU(inplace=True)  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)...  ...  ...  (fc): Linear(in_features=512, out_features=2, bias=True)

此时,network1就是一个符合猫狗分类问题的模型

最后,既然我们只对最后一层训练,那么我们只需要将最后一层的参数传入优化器

optimizer=optim.SGD(network1.fc.parameters(),lr=...,momentnum=...)

总结一下代码:​​​​​​​

from torchvision import modelsimport torch.nn as nnimport torch.optim as optim#网络搭建network1=models.resnet18(pretrained=True)
for param in network1.parameters():    param.requires_grad=False
network1.fc=nn.Linear(512,2)#损失函数criterion=nn.CrossEntropyLoss()#优化器optimizer=optim.SGD(network1.fc.parameters(),lr=...,momentnum=...)

其实,我们就是利用已经训练好的模型的主要目的就是它已经能够提取出非常好的特征,最后一层接受前面层提取的特征,然后误差反向传播,仅更新这一层的权重,不断迭代,最后达到一个非常好的效果。

图片

我们这里只对最后一层进行了调整,只训练这一层,主要原因就是数据太少;如果数据较多,可以把预训练的前面一些层权重固定住,后面层不固定,修改最后一层以满足任务,然后训练;如果数据很多,算力充沛,那么可以对所有层进行精调,只把预训练的模型的参数作为初始化参数。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/299810.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

基于多反应堆的高并发服务器【C/C++/Reactor】(中)添加 删除 修改 释放

在上篇文章(处理任务队列中的任务)中我们讲解了处理任务队列中的任务的具体流程,eventLoopProcessTask函数的作用: 处理队列中的任务,需要遍历链表并根据type进行对应处理,也就是处理dispatcher中的任务。 // 处理任…

Linux之Ubuntu环境Jenkins部署前端项目

今天分享Ubuntu环境Jenkins部署前端vue项目 一、插件安装 1、前端项目依赖nodejs,需要安装相关插件 点击插件管理,输入node模糊查询 选择NodeJS安装 安装成功 2、配置nodejs 点击后进入 点击新增 NodeJS 配置脚手架类型:如果不填 默认npm …

华为HarmonyOS 创建第一个鸿蒙应用 运行Hello World

使用DevEco Studio创建第一个项目 Hello World 1.创建项目 创建第一个项目,命名为HelloWorld,点击Finish 选择Empty Ability模板,点击Next Hello World 项目已经成功创建,接来下看看效果 2.预览 Hello World 点击右侧的预…

[VUE]2-vue的基本使用

目录 vue基本使用方式 1、vue 组件 2、文本插值 3、属性绑定 4、事件绑定 5、双向绑定 6、条件渲染 7、axios 8、⭐跨域问题 🍃作者介绍:双非本科大三网络工程专业在读,阿里云专家博主,专注于Java领域学习,擅…

RPC基础知识总结

RPC 是什么? RPC(Remote Procedure Call) 即远程过程调用,通过名字我们就能看出 RPC 关注的是远程调用而非本地调用。 为什么要 RPC ? 因为,两个不同的服务器上的服务提供的方法不在一个内存空间,所以&am…

【UML】第17篇 包图

目录 一、什么是包图 二、包图的作用: 三、应用场景: 四、绘图符号的说明: 五、语法: 六、其他要说的 一、什么是包图 包图(Package Diagram)是一种用于描述系统中包和包之间关系的UML图。包是一种将…

Thonny开发ESP32点灯

简介 ESP32是一款功能强大的低功耗微控制器,由乐鑫(Espressif)公司开发。它集成了Wi-Fi和蓝牙功能,适用于各种物联网应用。Thonny是一款基于Python的开源集成开发环境(IDE),专为MicroPython设计…

【数据分享】2024年我国主要城市地铁站点和线路数据

地铁站点与线路数据是我们经常会用到的一种基础数据。去哪里获取该数据呢? 今天我们就给大家分享一份2024年1月采集的全国有地铁城市的地铁站点与线路数据,数据格式为shp,数据坐标为wgs1984地理坐标。数据中不仅包括地铁,也包括轻…

Java Swing手搓坦克大战遇到的问题和思考

1.游戏中的坐标系颇为复杂 像素坐标系还有行列坐标,都要使用,这之间的互相转化使用也要注意 2.游戏中坦克拐弯的处理,非常重要 由于坦克中心点是要严格对齐到一条网格线,并沿着这条线前进的,如果拐弯不做处理&#…

法线变换矩阵的推导

背景 在冯氏光照模型中,其中的漫反射项需要我们对法向量和光线做点乘计算。 从顶点着色器中读入的法向量数据处于模型空间,我们需要将法向量转换到世界空间,然后在世界空间中让法向量和光线做运算。这里便有一个问题,如何将法线…

AI爆文变现:怼量也有技巧!如何提升你的创作收益

做AI爆文项目,赚小钱是没有问题的。 想要赚大钱,就是要做矩阵,怼量。 之前参加训练营的时候,也是要求怼量。 怼量,加高质量文章,让你的收益更高。 如何提升文章质量,减少AI味,AI…

性能优化-OpenMP基础教程(二)

本文主要介绍OpenMP并行编程技术,编程模型、指令和函数的介绍、以及OpenMP实战的几个例子。希望给OpenMP并行编程者提供指导。 🎬个人简介:一个全栈工程师的升级之路! 📋个人专栏:高性能(HPC&am…

动手学深度学习之卷积神经网络之池化层

池化层 卷积层对位置太敏感了,可能一点点变化就会导致输出的变化,这时候就需要池化层了,池化层的主要作用就是缓解卷积层对位置的敏感性 二维最大池化 这里有一个窗口,来滑动,每次我们将窗口中最大的值给拿出来 还是上…

更改ERPNEXT源

更改ERPNEXT源 一, 更改源 针对已经安装了erpnext的,需要更改源的情况: 1, 更改为官方默认源, 进入frapp-bench的目录, 然后执行: bench remote-reset-url frappe //重设frappe的源为官方github地址。 bench remote-reset-url…

ARM工控机Node-red使用教程

嵌入式ARM工控机Node-red安装教程 从前车马很慢书信很远,而现在人们不停探索“科技改变生活”。 智能终端的出现改变了我们的生活方式,钡铼技术嵌入式工控机协助您灵活布建能源管理、大楼自动化、工业自动化、电动车充电站等各种多元性IoT应用&#xff…

IDEA 控制台中文乱码问题解决方法(UTF-8 编码)

设置 IDEA 编码格式 1:打开 IntelliJ IDEA>File>Setting>Editor>File Encodings,将 Global Encoding、Project Encoding、Default encodeing for properties files 这三项都设置成 UTF-8 2:将 vm option 参数改为: -…

spring bean对象request字段自动注入javax.servlet.ServletRequest

##spring bean对象request字段自动注入javax.servlet.ServletRequest 对象里面request全局变量是否线程安全呢? 答案:线程安全 ##代码如下 import org.springframework.beans.factory.annotation.Autowired; import org.springframework.stereotype.Controller; …

解决“invalid UTF-8 encoding”

有如下一个程序 package mainimport"fmt"func main(){fmt.Println("hello,2024年") }go run xxx.go出现以下的问题 问题“invalid UTF-8 encoding”,无效的utf8编码。有可能是文件的编码不是“utf8” 为了验证猜想,看一下“xxx.go”…

【vue】emit 的理解与使用

文章目录 介绍流程示例效果父组件子组件 介绍 $emit 是 Vue 组件实例中的一个方法,用来触发自定义事件,并向父组件传递信息它接受两个参数:事件名称和可选参数this.$emit(事件名称, 参数);流程 示例 效果 触发前 触发后 父组件 父组件使…

windows安装nvm以及nvm常用命令

目录 1.什么是nvm以及为啥要用nvm 1.什么是nvm 2.为什么要用nvm 2.安装nvm 1. 下载 2. 安装 1.双击解压后的文件,nvm-setup.exe 2.同意 3.安装路径 4.下一步,这里有建议改成自己的文件夹,这个是用来存储通过nvm切换node后版本的存储路径 5.安装…