深度学习_微调_7

目标

  • 微调的原理
  • 利用微调模型来完成图像的分类任务

微调的原理

微调(Fine-tuning)是一种在深度学习中广泛应用的技术,特别是在预训练模型(Pretrained-Models)的基础上进行定制化训练的过程。微调的基本原理和步骤如下:

  1. 预训练阶段

    • 微调通常始于一个已经在大规模数据集上预训练过的模型,例如预训练的神经网络模型(如BERT、GPT系列、Vision Transformer等)。
    • 预训练模型在诸如自然语言处理(NLP)或计算机视觉(CV)等任务上进行了自我监督学习或有监督学习,从而习得了大量的通用特征表示。
  2. 适应新任务

    • 当需要解决与预训练任务相关的但更加具体的新任务时,就可以利用微调技术。
    • 对于NLP任务,可能是训练模型去回答问题、生成文本或者分类;对于CV任务,可能是对特定种类的物体进行识别或定位。
  3. 模型结构调整

    • 根据新任务的需求,可能需要对预训练模型的部分或全部层进行微调。
    • 通常在迁移学习中,预训练模型的底层可以捕获非常通用的特征,因此通常会被保留并微调,而顶层(尤其是分类层)会替换为新的、与新任务适配的输出层。
  4. 参数更新

    • 微调过程中,模型在新任务的数据集上重新开始训练,不过并非从随机初始化参数开始,而是从预训练模型的参数开始。
    • 学习率通常会设置得相对较低,以免破坏预训练模型学到的良好特征表示。
    • 其他优化参数(如权重衰减、批归一化层的状态等)也可能根据新任务的特点进行调整。
  5. 学习过程

    • 在目标数据集上训练模型时,不仅会更新新增的输出层参数,也会对预训练模型的某些层参数进行微调,使其更好地适应新任务的数据分布和特征。
    • 由于预训练模型已经具备良好的初始化,因此在相对较小的数据集上进行微调时,模型往往能够更快收敛到较好的解。

总结起来,微调的原理是利用预训练模型中的已学知识作为初始状态,通过对新任务数据的训练,对模型参数进行针对性的更新和优化,从而使模型能够适应新的应用场景。相较于从头训练,微调大大减少了所需的训练时间和数据量,提高了模型在特定任务上的性能和泛化能力。

1.微调

如何在只有6万张图像的MNIST训练数据集上训练模型。学术界当下使用最广泛的大规模图像数据集ImageNet,它有超过1,000万的图像和1,000类的物体。然而,我们平常接触到数据集的规模通常在这两者之间。假设我们想从图像中识别出不同种类的椅子,然后将购买链接推荐给用户。一种可能的方法是先找出100种常见的椅子,为每种椅子拍摄1,000张不同角度的图像,然后在收集到的图像数据集上训练一个分类模型。另外一种解决办法是应用迁移学习(transfer learning),将从源数据集学到的知识迁移到目标数据集上。例如,虽然ImageNet数据集的图像大多跟椅子无关,但在该数据集上训练的模型可以抽取较通用的图像特征,从而能够帮助识别边缘、纹理、形状和物体组成等。这些类似的特征对于识别椅子也可能同样有效。

微调由以下4步构成。

  1. 在源数据集(如ImageNet数据集)上预训练一个神经网络模型,即源模型。
  2. 创建一个新的神经网络模型,即目标模型。它复制了源模型上除了输出层外的所有模型设计及其参数。我们假设这些模型参数包含了源数据集上学习到的知识,且这些知识同样适用于目标数据集。我们还假设源模型的输出层跟源数据集的标签紧密相关,因此在目标模型中不予采用。
  3. 为目标模型添加一个输出大小为目标数据集类别个数的输出层,并随机初始化该层的模型参数。
  4. 在目标数据集(如椅子数据集)上训练目标模型。我们将从头训练输出层,而其余层的参数都是基于源模型的参数微调得到的。

当目标数据集远小于源数据集时,微调有助于提升模型的泛化能力。

2.热狗识别

接下来我们来实践一个具体的例子:热狗识别。将基于一个小数据集对在ImageNet数据集上训练好的ResNet模型进行微调。该小数据集含有数千张热狗或者其他事物的图像。我们将使用微调得到的模型来识别一张图像中是否包含热狗。

首先,导入实验所需的工具包。

import tensorflow as tf
import numpy as np

2.1 获取数据集

我们首先将数据集放在路径hotdog/data之下:

每个类别文件夹里面是图像文件。

上一节中我们介绍了ImageDataGenerator进行图像增强,我们可以通过以下方法读取图像文件,该方法以文件夹路径为参数,生成经过图像增强后的结果,并产生batch数据:

flow_from_directory(self, directory,
                            target_size=(256, 256), color_mode='rgb',
                            classes=None, class_mode='categorical',
                            batch_size=32, shuffle=True, seed=None,
                            save_to_dir=None)

主要参数:

  • directory: 目标文件夹路径,对于每一个类对应一个子文件夹,该子文件夹中任何JPG、PNG、BNP、PPM的图片都可以读取。
  • target_size: 默认为(256, 256),图像将被resize成该尺寸。
  • batch_size: batch数据的大小,默认32。
  • shuffle: 是否打乱数据,默认为True。

我们创建两个tf.keras.preprocessing.image.ImageDataGenerator实例来分别读取训练数据集和测试数据集中的所有图像文件。将训练集图片全部处理为高和宽均为224像素的输入。此外,我们对RGB(红、绿、蓝)三个颜色通道的数值做标准化。

# 获取数据集
import pathlib
train_dir = 'transferdata/train'
test_dir = 'transferdata/test'
# 获取训练集数据
train_dir = pathlib.Path(train_dir)
train_count = len(list(train_dir.glob('*/*.jpg')))
# 获取测试集数据
test_dir = pathlib.Path(test_dir)
test_count = len(list(test_dir.glob('*/*.jpg')))
# 创建imageDataGenerator进行图像处理
image_generator = tf.keras.preprocessing.image.ImageDataGenerator(rescale=1./255)
# 设置参数
BATCH_SIZE = 32
IMG_HEIGHT = 224
IMG_WIDTH = 224
# 获取训练数据
train_data_gen = image_generator.flow_from_directory(directory=str(train_dir),
                                                    batch_size=BATCH_SIZE,
                                                    target_size=(IMG_HEIGHT, IMG_WIDTH),
                                                    shuffle=True)
# 获取测试数据
test_data_gen = image_generator.flow_from_directory(directory=str(test_dir),
                                                    batch_size=BATCH_SIZE,
                                                    target_size=(IMG_HEIGHT, IMG_WIDTH),
                                                    shuffle=True)

下面我们随机取1个batch的图片然后绘制出来。

import matplotlib.pyplot as plt
# 显示图像
def show_batch(image_batch, label_batch):
    plt.figure(figsize=(10,10))
    for n in range(15):
        ax = plt.subplot(5,5,n+1)
        plt.imshow(image_batch[n])
        plt.axis('off')
# 随机选择一个batch的图像        
image_batch, label_batch = next(train_data_gen)
# 图像显示
show_batch(image_batch, label_batch)

2.2 模型构建与训练

我们使用在ImageNet数据集上预训练的ResNet-50作为源模型。这里指定weights='imagenet'来自动下载并加载预训练的模型参数。在第一次使用时需要联网下载模型参数。

Keras应用程序(keras.applications)是具有预先训练权值的固定架构,该类封装了很多重量级的网络架构,如下图所示:

实现时实例化模型架构:

tf.keras.applications.ResNet50(
    include_top=True, weights='imagenet', input_tensor=None, input_shape=None,
    pooling=None, classes=1000, **kwargs
)

主要参数:

  • include_top: 是否包括顶层的全连接层。
  • weights: None 代表随机初始化, 'imagenet' 代表加载在 ImageNet 上预训练的权值。
  • input_shape: 可选,输入尺寸元组,仅当 include_top=False 时有效,否则输入形状必须是 (224, 224, 3)(channels_last 格式)或 (3, 224, 224)(channels_first 格式)。它必须为 3 个输入通道,且宽高必须不小于 32,比如 (200, 200, 3) 是一个合法的输入尺寸。

在该案例中我们使用resNet50预训练模型构建模型:

# 加载预训练模型
ResNet50 = tf.keras.applications.ResNet50(weights='imagenet', input_shape=(224,224,3))
# 设置所有层不可训练
for layer in ResNet50.layers:
    layer.trainable = False
# 设置模型
net = tf.keras.models.Sequential()
# 预训练模型
net.add(ResNet50)
# 展开
net.add(tf.keras.layers.Flatten())
# 二分类的全连接层
net.add(tf.keras.layers.Dense(2, activation='softmax'))

接下来我们使用之前定义好的ImageGenerator将训练集图片送入ResNet50进行训练。

# 模型编译:指定优化器,损失函数和评价指标
net.compile(optimizer='adam',
            loss='categorical_crossentropy',
            metrics=['accuracy'])
# 模型训练:指定数据,每一个epoch中只运行10个迭代,指定验证数据集
history = net.fit(
                    train_data_gen,
                    steps_per_epoch=10,
                    epochs=3,
                    validation_data=test_data_gen,
                    validation_steps=10
                    )

Epoch 1/3
10/10 [==============================] - 28s 3s/step - loss: 0.6931 - accuracy: 0.5031 - val_loss: 0.6930 - val_accuracy: 0.5094
Epoch 2/3
10/10 [==============================] - 29s 3s/step - loss: 0.6932 - accuracy: 0.5094 - val_loss: 0.6935 - val_accuracy: 0.4812
Epoch 3/3
10/10 [==============================] - 31s 3s/step - loss: 0.6935 - accuracy: 0.4844 - val_loss: 0.6933 - val_accuracy: 0.4875

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/481563.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

CRM软件推荐2024:五款顶级产品解析,助您找到最佳选项!

一天之计在于晨,一年之计在于春。 2024年,民营经济发展继续壮大,这对于各行各业而言都是一种机遇挑战。企业想要规范化客户管理,实现销售增长,CRM软件仍然是一个不错的选择。在数字化时代,企业数字化转型已…

预防颈椎病,从职场健康做起

随着现代社会工作方式的转变,职场人士长时间伏案工作,颈椎病的发病率逐渐上升。本文将介绍一些实用的预防颈椎病的方法,帮助职场人士保持健康,提高工作效率。 一、了解颈椎病 颈椎病是指颈椎间盘退行性变及其继发性椎间关节病理性…

基于Python实现高德地图找房系统-爬虫分析

概要 针对大学毕业生对于工作地周边交通出行情况不了解、租房困难等问题,本文主要研究了厦门市的租房信息及地铁公交出行路线,利用Python爬虫爬取58同城上厦门市的租房信息,并进行处理分析,再通过高德地图API将房源信息展示在地图上,实现了基于高德地图API的租房地图。 关键词&…

基于Spring Boot技术的幼儿园管理系统

摘 要 随着信息时代的来临,过去的传统管理方式缺点逐渐暴露,对过去的传统管理方式的缺点进行分析,采取计算机方式构建幼儿园管理系统。本文通过课题背景、课题目的及意义相关技术,提出了一种活动信息、课程信息、菜谱信息、通知公…

Angular入门问题小本本

1、console.log打印object对象显示[object object] 解决方案:使用JSON.stringify console.log(JSON.stringify($rootScope.MaintainDeviceInfo));2、 State ‘goDiskManagement’’ is already defined 解决方案:同一个项目中,不能定义相同…

基于骨骼的动作识别的行动结构图卷积网络

提示:文章写完后,目录可以自动生成,如何生成可参考右边的帮助文档 文章目录 摘要Abstract文献阅读:基于骨骼的动作识别的行动结构图卷积网络1、研究背景2、方法提出3、关键结构3.1、A-links inference module (AIM)3.2、Structura…

MyBatis3源码深度解析(二十)动态SQL实现原理(一)动态SQL的核心组件

文章目录 前言第八章 动态SQL实现原理8.1 动态SQL的使用8.1.1 \<if>8.1.2 <where|trim>8.1.3 <choose|when|otherwise>8.1.4 \<foreach>8.1.5 \<set> 8.2 SqlSource组件&BoundSql组件8.3 LanguageDriver组件8.3.1 XMLLanguageDriver8.3.2 Ra…

leetcode 20.有效的括号 JAVA

题目 思路 括号的匹配&#xff0c;这是一道经典的栈的应用问题。 给我们一个字符串&#xff0c;当我们遍历到左括号时&#xff0c;让其入栈。当我们遍历到右括号时&#xff0c;让栈顶元素出栈&#xff0c;看看栈顶的元素是否和遍历到的右括号匹配。不匹配的话直接false,匹配的…

vue2 脚手架

安装 文档&#xff1a;https://cli.vuejs.org/zh/ 第一步&#xff1a;全局安装&#xff08;仅第一次执行&#xff09; npm install -g vue/cli 或 yarn global add vue/cli 备注&#xff1a;如果出现下载缓慢&#xff1a;请配置npm 淘宝镜像&#xff1a; npm config set regis…

java算法第32天 | ● 122.买卖股票的最佳时机II ● 55. 跳跃游戏 ● 45.跳跃游戏II

122.买卖股票的最佳时机II 本题中理解利润拆分是关键点&#xff01; 不要整块的去看&#xff0c;而是把整体利润拆为每天的利润。假如第 0 天买入&#xff0c;第 3 天卖出&#xff0c;那么利润为&#xff1a;prices[3] - prices[0]。 相当于(prices[3] - prices[2]) (prices[…

XSS-labs详解

xss-labs下载地址https://github.com/do0dl3/xss-labs 进入靶场点击图片&#xff0c;开始我们的XSS之旅&#xff01; Less-1 查看源码 代码从 URL 的 GET 参数中取得 "name" 的值&#xff0c;然后输出一个居中的标题&#xff0c;内容是 "欢迎用户" 后面…

手撕算法-买卖股票的最佳时机 II(买卖多次)

描述 分析 使用动态规划。dp[i][0] 代表 第i天没有股票的最大利润dp[i][1] 代表 第i天持有股票的最大利润 状态转移方程为&#xff1a;dp[i][0] max(dp[i-1][0], dp[i-1][1] prices[i]); // 前一天没有股票&#xff0c;和前一天有股票今天卖掉的最大值dp[i][1] max(dp[i-1…

智能财务新选择!Zoho Books入选福布斯榜单,助力中小企业!

放眼全球&#xff0c;中小企业始终是经济发展的重要组成部分。然而&#xff0c;由于中小企业的规模、流程规范和资源等方面受限较多&#xff0c;从而导致其在管理及运营上存在着诸多问题。其中包括财务管理不规范、成本控制不到位、运营效率低下等&#xff0c;这些问题则直接影…

如何在CentOS安装SQL Server数据库并实现无公网IP远程连接内网数据库

文章目录 前言1. 安装sql server2. 局域网测试连接3. 安装cpolar内网穿透4. 将sqlserver映射到公网5. 公网远程连接6.固定连接公网地址7.使用固定公网地址连接 前言 简单几步实现在Linux centos环境下安装部署sql server数据库&#xff0c;并结合cpolar内网穿透工具&#xff0…

基于GD32E230C8T6的数字示波器

基于GD32E230C8T6的数字示波器 文章目录 基于GD32E230C8T6的数字示波器基于GD32E230C8T6的数字示波器实物演示电路原理**模拟前端处理电路**image.png**交直流耦合电路****输入信号衰减电路****信号调理电路****虚断:****虚短:****电压跟随器****反相比例放大器****同相比例放…

深度剖析GNSS高精度定位原理

一、背景 目前室外使用最广泛的定位手段是GNSS定位&#xff0c;常规的GNSS定位精度约5、10米左右&#xff0c;无法满足高精度场景的应用&#xff0c;如何提升GNSS定位性能是亟待解决的问题。本文由浅入深剖析GNSS定位原理并介绍如何实现厘米级高度定位。 二、GNSS定位原理 1、…

MySQL下载安装和本地连接

1、下载MySQL 从MySQL官网下载MySQL Community Server版本&#xff1a; 下载地址&#xff1a;MySQL官网 1、进入官网&#xff0c;点击DOWNLOADS 2、点击MySQL Community(GPL)Downloads 3、点击MySQL Installer for Windows 4、这个会直接跳转到最新的版本 如果想下载以往的…

面试算法-83-不同路径 II

题目 一个机器人位于一个 m x n 网格的左上角 &#xff08;起始点在下图中标记为 “Start” &#xff09;。 机器人每次只能向下或者向右移动一步。机器人试图达到网格的右下角&#xff08;在下图中标记为 “Finish”&#xff09;。 现在考虑网格中有障碍物。那么从左上角到…

【进程概念】启动进程 | 查看进程 | 创建进程

目录 启动进程 查看进程 方法1&#xff1a;/proc 方法2&#xff1a;查看脚本 ​方法3&#xff1a;系统调用获取进程标示符❗❗ 终止进程 创建进程&#xff08;主fork) &#x1f642;查看父子进程的pid &#x1f642;进程创建/执行/终止 &#x1f642;多次重新启动进…

java的IO之NIO

NIO是一种同步非阻塞的I/O模型&#xff0c;在Java 1.4中引入了NIO框架&#xff0c;对应java.nio包&#xff0c;提供了channel、selector、buffer等。 NIO中的N可以理解为Non-blocking不在单纯是New&#xff0c;它支持面向缓冲的&#xff0c;基于通道的I/O操作方法。NIO提供了与…