李沐37_微调——自学笔记

标注数据集很贵

网络架构

1.一般神经网络分为两块,一是特征抽取原始像素变成容易线性分割的特征,二是线性分类器来做分类

微调

1.原数据集不能直接使用,因为标号发生改变,通过微调可以仍然对我数据集做特征提取

2.pre-train源数据集后,迁移模型中的架构,调整权重

3.是一个目标数据集上的正常训练任务,但使用更强的正则化(更小的学习率、更少的数据迭代)

4.源数据集远复杂于目标数据,通常微调效果更好

重用分类器权重

1.源数据集可能也有目标数据中的部分标号

2.可以使用预训练好的模型分类器中对应标号对应的向量来初始化

固定一些层

1.神经网络通常学习有层次的特征表达,低层次的特征更加通用,高层次的特征则和数据集相关

2.可以固定底部一些层的参数,不参与更新,更强的正则化

总结

1.微调通过使用在大数据上得到的预训练好的模型来初始化模型权重来提升精度

2.预训练模型质量很重要

3.微调通常速度更快、精度更高

代码实现——热狗识别

%matplotlib inline
import os
import torch
import torchvision
from torch import nn
from d2l import torch as d2l

获取数据集

解压下载热狗数据集,一共1400张热狗“正向”图片,其他尽可能多的其他食物“负向”图片,解压下载的数据集,我们获得了两个文件夹hotdog/train和hotdog/test。 这两个文件夹都有hotdog(有热狗)和not-hotdog(无热狗)两个子文件夹, 子文件夹内都包含相应类的图像。

d2l.DATA_HUB['hotdog'] = (d2l.DATA_URL + 'hotdog.zip',
                         'fba480ffa8aa7e0febbb511d181409f899b9baa5')

data_dir = d2l.download_extract('hotdog')
Downloading ../data/hotdog.zip from http://d2l-data.s3-accelerate.amazonaws.com/hotdog.zip...

读取训练集和测试集的图片

train_imgs = torchvision.datasets.ImageFolder(os.path.join(data_dir, 'train'))
test_imgs = torchvision.datasets.ImageFolder(os.path.join(data_dir, 'test'))

前8个正类、后8个负类,图片的大小和高宽比不同。

hotdogs = [train_imgs[i][0] for i in range(8)]
not_hotdogs = [train_imgs[-i - 1][0] for i in range(8)]
d2l.show_images(hotdogs + not_hotdogs, 2, 8, scale=1.4);

在这里插入图片描述

训练:从图像中裁切随机大小和随机长宽比的区域,然后将该区域缩放为224X224
输入图像。
测试:我们将图像的高度和宽度都缩放到256像素,然后裁剪中央224X224
区域作为输入。

此外,对于RGB(红、绿和蓝)颜色通道,我们分别标准化每个通道。 具体而言,该通道的每个值减去该通道的平均值,然后将结果除以该通道的标准差。

# 使用RGB通道的均值和标准差,以标准化每个通道
normalize = torchvision.transforms.Normalize(
    [0.485, 0.456, 0.406], [0.229, 0.224, 0.225])

train_augs = torchvision.transforms.Compose([
    torchvision.transforms.RandomResizedCrop(224),
    torchvision.transforms.RandomHorizontalFlip(),
    torchvision.transforms.ToTensor(),
    normalize])

test_augs = torchvision.transforms.Compose([
    torchvision.transforms.Resize([256, 256]),
    torchvision.transforms.CenterCrop(224),
    torchvision.transforms.ToTensor(),
    normalize])

初始化模型

使用在ImageNet数据集上预训练的ResNet-18作为源模型。 在这里,我们指定pretrained=True以自动下载预训练的模型参数。 如果首次使用此模型,则需要连接互联网才能下载。

pretrained_net = torchvision.models.resnet18(pretrained=True)
/usr/local/lib/python3.10/dist-packages/torchvision/models/_utils.py:208: UserWarning: The parameter 'pretrained' is deprecated since 0.13 and may be removed in the future, please use 'weights' instead.
  warnings.warn(
/usr/local/lib/python3.10/dist-packages/torchvision/models/_utils.py:223: UserWarning: Arguments other than a weight enum or `None` for 'weights' are deprecated since 0.13 and may be removed in the future. The current behavior is equivalent to passing `weights=ResNet18_Weights.IMAGENET1K_V1`. You can also use `weights=ResNet18_Weights.DEFAULT` to get the most up-to-date weights.
  warnings.warn(msg)
Downloading: "https://download.pytorch.org/models/resnet18-f37072fd.pth" to /root/.cache/torch/hub/checkpoints/resnet18-f37072fd.pth
100%|██████████| 44.7M/44.7M [00:00<00:00, 111MB/s]

fc:预训练的源模型实例包含许多特征层和一个输出层fc。 此划分的主要目的是促进对除输出层以外所有层的模型参数进行微调。

# fc
pretrained_net.fc
Linear(in_features=512, out_features=1000, bias=True)

目标模型finetune_net中成员变量features的参数被初始化为源模型相应层的模型参数。 由于模型参数是在ImageNet数据集上预训练的,并且足够好,因此通常只需要较小的学习率即可微调这些参数。

成员变量output的参数是随机初始化的,通常需要更高的学习率才能从头开始训练。 假设Trainer实例中的学习率为n,我们将成员变量output中参数的学习率设置为10n。

finetune_net = torchvision.models.resnet18(pretrained=True)
finetune_net.fc = nn.Linear(finetune_net.fc.in_features, 2)
nn.init.xavier_uniform_(finetune_net.fc.weight);

训练模型微调

# 如果param_group=True,输出层中的模型参数将使用十倍的学习率
def train_fine_tuning(net, learning_rate, batch_size=128, num_epochs=5,
                      param_group=True):
    train_iter = torch.utils.data.DataLoader(torchvision.datasets.ImageFolder(
        os.path.join(data_dir, 'train'), transform=train_augs),
        batch_size=batch_size, shuffle=True)
    test_iter = torch.utils.data.DataLoader(torchvision.datasets.ImageFolder(
        os.path.join(data_dir, 'test'), transform=test_augs),
        batch_size=batch_size)
    devices = d2l.try_all_gpus()
    loss = nn.CrossEntropyLoss(reduction="none")
    if param_group:
        params_1x = [param for name, param in net.named_parameters()
             if name not in ["fc.weight", "fc.bias"]]
        trainer = torch.optim.SGD([{'params': params_1x},
                                   {'params': net.fc.parameters(),
                                    'lr': learning_rate * 10}],
                                lr=learning_rate, weight_decay=0.001)
    else:
        trainer = torch.optim.SGD(net.parameters(), lr=learning_rate,
                                  weight_decay=0.001)
    d2l.train_ch13(net, train_iter, test_iter, loss, trainer, num_epochs,
                   devices)

使用较小的学习率

train_fine_tuning(finetune_net, 5e-5)
loss 0.238, train acc 0.910, test acc 0.941
360.1 examples/sec on [device(type='cuda', index=0)]

在这里插入图片描述

为了比较,定义相同的模型,但需较大学习率

scratch_net = torchvision.models.resnet18()
scratch_net.fc = nn.Linear(scratch_net.fc.in_features, 2)
train_fine_tuning(scratch_net, 5e-4, param_group=False)
loss 0.403, train acc 0.821, test acc 0.791
366.6 examples/sec on [device(type='cuda', index=0)]

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/547257.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

【保姆级】2024年OnlyFans订阅指南

OnlyFans是一个独特的社交媒体平台&#xff0c;它为创作者和粉丝提供了一个互动交流的空间。通过这个平台&#xff0c;创作者可以分享他们的独家内容&#xff0c;而粉丝则可以通过订阅来支持和享受这些内容。如果你对OnlyFans感兴趣&#xff0c;并希望成为其中的一员&#xff0…

LeetCode 113—— 路径总和 II

阅读目录 1. 题目2. 解题思路3. 代码实现 1. 题目 2. 解题思路 看到树的问题一般我们先考虑一下是否能用递归来做。 假设 root 节点的值为 value&#xff0c;如果根节点的左子树有一个路径总和等于 targetSum - value&#xff0c;那么只需要将根节点的值插入到这个路径列表中…

hbase-2.2.7分布式搭建

一、下载上传解压 1.在官网或者云镜像网站下载jar包 华为云镜像站&#xff1a;Index of apache-local/hbase/2.2.7 2.上传到linux并解压 tar -zxvf hbase-2.2.7-bin.tar.gz -C /usr/locol/soft 二、配置环境变量 1. vim /etc/profile export HBASE_HOME/usr/local/soft/h…

快速探索随机树-RRT

文章目录 简介原理算法运动规划的变体和改进简介 快速探索随机树(RRT)是一种算法,旨在通过随机构建空间填充树来有效搜索非凸高维空间。该树是从搜索空间随机抽取的样本中逐步构建的,并且本质上偏向于向问题的大型未搜索区域生长。RRT 由 Steven M. LaValle 和 James J. K…

Unity 扩展自定义编辑器窗口

在Assets文件夹路径下任意位置创建Editor文件夹&#xff0c;将扩展编辑器的代码放在Editor文件夹下 生成编辑器窗口 代码中首先引用命名空间 using UnityEditor; 然后将创建的类继承自EditorWindow public class MenuEditor : EditorWindow 然后通过扩展编辑器菜单功能调用…

Jackson 2.x 系列【24】Spring Web 集成之 Jackson2ObjectMapperBuilder

有道无术&#xff0c;术尚可求&#xff0c;有术无道&#xff0c;止于术。 本系列Jackson 版本 2.17.0 源码地址&#xff1a;https://gitee.com/pearl-organization/study-jaskson-demo 文章目录 1. 前言2. Spring Web3. Jackson2ObjectMapperBuilder3.1 成员属性3.2 静态方法3…

FMEA分析

目录 1、FMEA的核心目的 2、FMEA的种类 3、FMEA的实施步骤 4、FMEA的SOD等级 5、FMEA的例子 FMEA&#xff08;Failure Modes and Effects Analysis&#xff0c;失效模式与影响分析&#xff09;是一种预防性的可靠性设计分析&#xff0c;用来确定潜在失效模式及其原因。它主…

IDEA使用SCALA

一、在IDEA中下载插件 在设置->插件中找到scala&#xff0c;并下载。 下载完成后重启idea 二、在idea中创建spark的RDD操作项目 新建项目选中Scala。 创建完成后为项目添加java包&#xff0c;这个添加的是spark安装包中jars目录下的所有jar包 然后编写RDD操作 import or…

安全中级-初开始

一、网络基础 重要点&#xff1a;TTL值&#xff08;防环&#xff0c;linux64.Windows128 &#xff09;&#xff0c;IP数据包包头格式字节&#xff08;20&#xff09; 标识标志偏移量起到什么作用&#xff08;数据超过1500会分片&#xff09; wireshack抓包会有一个MSS&#x…

铭飞 MCMS 存在SQL注入漏洞

声明&#xff1a; 本文仅用于技术交流&#xff0c;请勿用于非法用途 由于传播、利用此文所提供的信息而造成的任何直接或者间接的后果及损失&#xff0c;均由使用者本人负责&#xff0c;文章作者不为此承担任何责任。 简介 铭飞&#xff08;MCMS&#xff09;是一种计算机管理…

Opencv3.4+FFMpeg3.4+pkg-config交叉编译arm开发板

Ubuntu16.04 64位 FFmpeg3.4 OpenCv3.4 一、下载FFmpeg https://github.com/FFmpeg/FFmpeg 1.配置 ./configure --prefix/home/zeng/ffmpeg_install --enable-cross-compile --cross-prefixarm-linux-gnueabihf- --ccarm-linux-gnueabihf-gcc --target-oslinux --cpuco…

算法课程笔记——排序

Bool返回真假 为何用const不用define 1.保护被修饰的东西 2.通常不分配存储空间&#xff0c; 效率高 匿名函数只在一处用&#xff0c;其他处用不到 不写&就是拷贝 u相等就u&#xff0c;不等就v 一个字符是空格一个是换行&#xff0c;后面是取下标i那就是1&#xff08;true&…

图解数学:拉格朗日松弛方法的直观理解

昨晚写了拉格朗日松弛方法的原理分析&#xff0c;今天意犹未尽&#xff0c;图解一下&#xff0c;从直观上进一步理解这种方法。 一、一个简单例子 我们先来看一个简单的例子&#xff0c;下面数学规划问题没有约束条件&#xff1a; min ⁡ f ( x ) − x 2 8 x − 10 \begin…

全排列问题

日升时奋斗&#xff0c;日落时自省 目录 1、全排列 2、全排列II 3、子集 4、组合 1、全排列 首先要了解全排列是怎么样的 例如:数组[1,2,3]的全排列&#xff08;全排列就是不同顺序排列方式&#xff09; 例子所有的排列方式如&#xff1a;[1,2,3],[1,3,2],[2,1,3],[2,3…

Leetcode876_链表的中间结点

1.leetcode原题链接&#xff1a;. - 力扣&#xff08;LeetCode&#xff09; 2.题目描述 给你单链表的头结点 head &#xff0c;请你找出并返回链表的中间结点。 如果有两个中间结点&#xff0c;则返回第二个中间结点。 示例 1&#xff1a; 输入&#xff1a;head [1,2,3,4,5…

PTA 编程题(C语言)-- 判断素数

题目标题&#xff1a; 判断素数 题目作者 陈越 浙江大学 本题的目标很简单&#xff0c;就是判断一个给定的正整数是否素数。 输入格式&#xff1a; 输入在第一行给出一个正整数N&#xff08;≤ 10&#xff09;&#xff0c;随后N行&#xff0c;每行给出一个小于…

康耐视visionpro-CoglntersectLineLineTool操作说明工具详细说明

◆CogIntersectLineLineTool功能说明&#xff1a; 创建两条线的交点 备注&#xff1a;在“Geometry-Intersection”选项中的所有工具都是创建两个图形的交点工具&#xff0c;其中包括圆与圆的交点、线与圆的交点、线与线的交点、线与圆的交点等&#xff0c;工具使用的方法相似。…

C++ - set 和 map详解

目录 0. 引言 1. 关联式容器 2. 键值对 3. 树形结构 4. set 4.1 set 的定义 4.2 set 的构造 4.3 set 的常用函数 4.4 set 的特点 5. multiset 5.1 multiset 插入冗余数据 5.2 multiset - count 的使用 6. map 6.1 map 的定义 6.2 map 的构造 6.3 map的常…

dcoker+nginx解决前端本地开发跨域

步骤 docker 拉取nginx镜像跑容器 并配置数据卷nginx.conf nginx.conf文件配置 这里展示server server {listen 80;listen [::]:80;server_name localhost;#access_log /var/log/nginx/host.access.log main;location / {# 当我们访问127.0.0.1:8028就会跳转到ht…

软件2班20240415

快速引用类选择器 <!DOCTYPE html> <html lang"en"> <head><meta charset"UTF-8"><meta http-equiv"X-UA-Compatible" content"IEedge"><meta name"viewport" content"widthdevi…