033、微调

之——高级炼丹术

目录

之——高级炼丹术

杂谈

正文

1.标注数据集是很贵的

2.微调的思想

3.尝试

小结


杂谈

        微调(Fine-tuning)是深度学习中的一种常见策略,它通常用于预训练模型在特定任务上的性能提升。微调的过程涉及在一个已经在大规模数据上进行了预训练的模型的基础上,通过使用目标任务的相关数据集进行进一步训练。

        一般的微调步骤如下:

  1. 预训练阶段: 在大规模的数据集上,使用一个深度学习模型进行预训练。这个数据集通常是非常庞大的,例如ImageNet包含数百万张图像。

  2. 微调阶段: 将预训练的模型应用于特定的任务,使用与该任务相关的小型数据集进行微调。这个小型数据集可能是与任务密切相关的领域或者任务本身的数据。

  3. 模型调整: 在微调阶段,可以调整模型的一些超参数,如学习率、优化器等,以适应特定任务的要求。

        微调的一个常见应用是使用在大规模图像分类任务上预训练的模型,然后在较小的数据集上微调以进行特定任务,比如医学图像分类或自定义的图像分类任务。应用迁移学习(transfer learning)将从源数据集学到的知识迁移到目标数据集


正文

1.标注数据集是很贵的

        大的公开数据集花销很大,而且强调泛化性,并不强调特定任务,而自己实际情况下的任务需求一般来说没有这么大的数据集,所以我们一般使用在公开数据集上预训练的模型拓展到我们自己的任务需求上。


2.微调的思想

  1. 在源数据集(例如ImageNet数据集)上预训练神经网络模型,即源模型

  2. 创建一个新的神经网络模型,即目标模型。这将复制源模型上的所有模型设计及其参数(输出层除外)。我们假定这些模型参数包含从源数据集中学到的知识,这些知识也将适用于目标数据集。我们还假设源模型的输出层与源数据集的标签密切相关;因此不在目标模型中使用该层。

  3. 向目标模型添加输出层,其输出数是目标数据集中的类别数。然后随机初始化该层的模型参数。

  4. 在目标数据集(如椅子数据集)上训练目标模型。输出层将从头开始进行训练,而所有其他层的参数将根据源模型的参数进行微调

        训练:

         一些技巧:

        神经网络越接近度量层输出层越是与数据集相关,而底层的则更加通用,固定底层的一些细节去优化语义是更有效率的选择,因为底层的通用特征是通过大数据集预训练出来的,是有很高质量的。


3.尝试

        下载数据集并查看:

import os
import torch
import torchvision
from torch import nn
from d2l import torch as d2l

#下载hotdog数据集
d2l.DATA_HUB['hotdog'] = (d2l.DATA_URL + 'hotdog.zip',
                         'fba480ffa8aa7e0febbb511d181409f899b9baa5')

data_dir = d2l.download_extract('hotdog')
#%%

#读入数据集
train_imgs = torchvision.datasets.ImageFolder(os.path.join(data_dir, 'train'))
test_imgs = torchvision.datasets.ImageFolder(os.path.join(data_dir, 'test'))


#展示
hotdogs = [train_imgs[i][0] for i in range(8)]
not_hotdogs = [train_imgs[-i - 1][0] for i in range(8)]
d2l.show_images(hotdogs + not_hotdogs, 2, 8, scale=1.4)

        数据增广:

#数据增广
# 使用RGB通道的均值和标准差,以标准化每个通道
# imagenet做了这个事情,所以等搬过来,参数来自imagenet
normalize = torchvision.transforms.Normalize(
    [0.485, 0.456, 0.406], [0.229, 0.224, 0.225])

train_augs = torchvision.transforms.Compose([
    torchvision.transforms.RandomResizedCrop(224),
    torchvision.transforms.RandomHorizontalFlip(),
    torchvision.transforms.RandomVerticalFlip(),
    torchvision.transforms.ToTensor(),
    normalize])

test_augs = torchvision.transforms.Compose([
    torchvision.transforms.Resize([256, 256]),
    torchvision.transforms.CenterCrop(224),
    torchvision.transforms.ToTensor(),
    normalize])

        导入预训练数据集,修改输出层:

#%%
#查看预训练好的模型
pretrained_net = torchvision.models.resnet18(pretrained=True)

pretrained_net.fc
#%%
#由于我们的类别不一样,所以我们考虑copy预训练但改变最后的输出层
finetune_net = torchvision.models.resnet18(pretrained=True)
finetune_net.fc = nn.Linear(finetune_net.fc.in_features, 2)
nn.init.xavier_uniform_(finetune_net.fc.weight)
finetune_net.fc

         训练:

# 如果param_group=True,输出层中的模型参数将使用十倍的学习率
def train_fine_tuning(net, learning_rate, batch_size=128, num_epochs=5,
                      param_group=True):
    train_iter = torch.utils.data.DataLoader(torchvision.datasets.ImageFolder(
        os.path.join(data_dir, 'train'), transform=train_augs),
        batch_size=batch_size, shuffle=True)
    test_iter = torch.utils.data.DataLoader(torchvision.datasets.ImageFolder(
        os.path.join(data_dir, 'test'), transform=test_augs),
        batch_size=batch_size)
    devices = d2l.try_all_gpus()
    loss = nn.CrossEntropyLoss(reduction="none")
    if param_group:
        params_1x = [param for name, param in net.named_parameters()
             if name not in ["fc.weight", "fc.bias"]]
        #别的层不变,最后一层10倍学习率
        trainer = torch.optim.SGD([{'params': params_1x},
                                   {'params': net.fc.parameters(),
                                    'lr': learning_rate * 10}],
                                lr=learning_rate, weight_decay=0.001)
    else:
        trainer = torch.optim.SGD(net.parameters(), lr=learning_rate,
                                  weight_decay=0.001)
    d2l.train_ch13(net, train_iter, test_iter, loss, trainer, num_epochs,
                   devices)


#%%
#较小的学习率,通过微调预训练获得的模型参数
train_fine_tuning(finetune_net, 5e-5)

        结果:

        有预训练的5个epoch:

        没有预训练的:

         可见明显差别。


小结

  • 迁移学习将从源数据集中学到的知识迁移到目标数据集,微调是迁移学习的常见技巧。

  • 除输出层外,目标模型从源模型中复制所有模型设计及其参数,并根据目标数据集对这些参数进行微调。但是,目标模型的输出层需要从头开始训练。

  • 通常,微调参数使用较小的学习率,而从头开始训练输出层可以使用更大的学习率。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/165556.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

7.22 SpringBoot项目实战【收藏 和 取消收藏】

文章目录 前言一、编写控制器二、编写服务层三、Postman测试最后前言 本系统还支持 收藏图书,就是对心仪的书加一下收藏,大家都懂,这是一个很常见的功能。 那么我们来看看怎么来做,先分析一下:【一个人】对【一本书】只需【收藏一次】,但可以【收藏N本】不同的书,收藏…

2024年csdn最新最全面的fiddler教程【1】

Fiddler简介 Fiddler是比较好用的web代理调试工具之一,它能记录并检查所有客户端与服务端的HTTP/HTTPS请求,能够设置断点,篡改及伪造Request/Response的数据,修改hosts,限制网速,http请求性能统计&#xff…

基于springboot实现校园在线拍卖系统项目【项目源码】计算机毕业设计

基于springboot实现校园在线拍卖系统演示 Javar技术 JavaScript是一种网络脚本语言,广泛运用于web应用开发,可以用来添加网页的格式动态效果,该语言不用进行预编译就直接运行,可以直接嵌入HTML语言中,写成js语言&…

高效管理文件:如何通过文件数量归类提高工作效率

在日常生活和工作中,需要处理大量的文件和资料。然而,如果这些文件没有得到妥善的管理,就会使得我们花费大量的时间和精力去寻找和整理它们。对于大量文件,按照数量归类可以使得文件管理更加有序和规范。根据文件的数量建立相应的…

Go 语言中切片的使用和理解

切片与数组类似,但更强大和灵活。与数组一样,切片也用于在单个变量中存储相同类型的多个值。然而,与数组不同的是,切片的长度可以根据需要增长和缩小。在 Go 中,有几种创建切片的方法: 使用[]datatype{val…

深信服AC流量管理技术

拓扑图 一.保证通道针对修仙部,访问网站,邮件,DNS,IM,办工 OA,微博论坛网上银行等常见应用保证带宽最低 50%,最高 100% 1. 先新建线路带宽 2.新增流量管理通道(保证关键应用&#x…

Selenium UI 自动化

一、Selenium 自动化 1、什么是Selenium? Selenium是web应用中基于UI的自动化测试框架。 2、Selenium的特点? 支持多平台、多浏览器、多语言。 3、自动化工作原理? 通过上图,我们可以注意到3个角色,下面具体讲解一…

16. @PostConstruct注解和开关原理(验证码开关、IP开关)

1►PostConstruct注解 PostConstruct是java自带的注解,会在java项目启动的时候先执行下面的方法 2►开关原理(验证码开关) 我们的项目具有验证码功能,旧版不支持关闭,新版已经支持关闭了。 我们打开页面“参数管…

腾讯云轻量数据库性能如何?轻量数据库租用配置价格表

腾讯云轻量数据库测评,轻量数据库100%兼容MySQL 5.7和8.0,腾讯云提供1C1G20GB、1C1G40GB、1C2G80GB、2C4G120GB、2C8G240GB五种规格轻量数据库,腾讯云百科txybk.com分享腾讯云轻量数据库测评、轻量数据库详细介绍、特性、配置价格和常见问题解…

运行ps显示msvcp140.dll丢失怎么恢复?msvcp140.dll快速解决的4个不同方法

msvcp140.dll无法继续执行代码的主要原因有以下几点 系统缺失:msvcp140.dll是Visual Studio 2015编译的程序默认的库文件,如果系统中没有这个库文件,那么在运行相关程序时就会出现找不到msvcp140.dll的错误提示。 文件损坏:如果…

C语言入门笔记—static、extern、define、指针、结构体

一、static static修饰局部变量的时候,局部变量出了作用域,不销毁。本质上,static修饰局部变量的时候,改变了变量的存储位置。详见下图,当a不被static修饰和被static修饰的时候。 C/C static关键字详解&#xff…

TensorRT量化实战课YOLOv7量化:YOLOv7-QAT量化

目录 前言1. YOLOv7-QAT流程2. QAT训练流程 前言 手写 AI 推出的全新 TensorRT 模型量化实战课程,链接。记录下个人学习笔记,仅供自己参考。 该实战课程主要基于手写 AI 的 Latte 老师所出的 TensorRT下的模型量化,在其课程的基础上&#xff…

SPASS-聚类和判别分析

聚类与判别分析概述 基本概念 聚类分析 聚类分析的基本思想是找出一些能够度量样本或指标之间相似程度的统计量,以这些统计量为划分类型的依据,把一些相似程度较大的样本(或指标)聚合为一类,把另外一些彼此之间相似程…

VMware——WindowServer2012R2安装jdk1.8及环境变量配置

一、安装 双击【jdk-8u161-windows-x64.exe】程序包,弹出窗口点击【下一步】,如下图: 指定安装目录为【Java\jdk1.8.0_161】,磁盘目录自定义,如下图: 点击【下一步】一直到有个【更改】按钮,可…

Network(五)数值介绍与子网划分

一 数值 1 数值介绍 (1)带宽 在一定时间内通过某一网络连接的信息量 基本单位:比特每秒 (bit/s) 在计算机软件方面用字节每秒为单位 (2)存储量 计算机存储量可以用位和字节计量 (3)常用…

C语言——2.安装并使用VS

文章目录 1.编译器是什么2.编译器的选择2.1.VS2019/2022 的初步了解2.2.为什么不选择其他编译器呢? 3.编译器的安装过程(保姆级别教学)3.1.检查电脑版本3.2.下载安装包3.3.选择安装选项3.4.重启电脑3.5.创建账户登录3.6.颜色配置3.7.VS&#…

【C语言.oj刷题】有序#整型矩阵元素查找##{思路+C源码}

目录 题目信息 题目分析: 法一: 遍历二维数组(低效) 思路 源码 局限性 法二: 对每一行二分查找(有所提效) 思路 源码 局限性 法三: 利用一切有利条件使用二分查找 思路 …

Linux系统之lsof命令的基本使用

Linux系统之lsof命令的基本使用 一、lsof命令的基本使用二、lsof命令的使用帮助2.1 lsof命令的help帮助信息2.2 lsof命令帮助解释 三、lsof的基本使用3.1 直接使用lsof命令3.2 查看某个进程打开的所有文件3.3 查看某个用户打开的所有文件3.4 查看某个文件被哪些进程打开3.5 查看…

卷积神经网络(CNN)天气识别

文章目录 前期工作1. 设置GPU(如果使用的是CPU可以忽略这步)我的环境: 2. 导入数据3. 查看数据 二、数据预处理1. 加载数据2. 可视化数据3. 再次检查数据4. 配置数据集 三、构建CNN网络四、编译五、训练模型六、模型评估 前期工作 1. 设置GP…

汽车虚拟仿真视频数据理解--CLIP模型原理

CLIP模型原理 CLIP的全称是Contrastive Language-Image Pre-Training,中文是对比语言-图像预训练,是一个预训练模型,简称为CLIP。该模型是 OpenAI 在 2021 年发布的,最初用于匹配图像和文本的预训练神经网络模型,这个任…