笔记:Few-Shot Learning小样本分类问题 + 孪生网络 + 预训练与微调

内容摘自王老师的B站视频,大家还是尽量去看视频,老师讲的特别好,不到一小时的时间就缕清了小样本学习的基础知识点~Few-Shot Learning (1/3): 基本概念_哔哩哔哩_bilibili

Few-Shot Learning(小样本分类)

假设现在每类只有一两个样本,计算机能否做到像人一样的正确分类?

  • 这个例子Support Set有两类,每类只有一两个样本,靠这些样本,难以训练出一个深度神经网络,这个集合只能提供一些参考信息。对于小样本问题,不能用传统的分类方法。

小样本分类与传统的监督学习有所不同,小样本学习的目标不是让机器通过学习训练集中图片,知道哪类是什么样子;当我拿一个很大的训练集来训练神经网络后进行小样本分类,预训练模型的目的是让机器自己学会学习-----也就是学习事物的异同,学会区分不同的事物。

现在训练集有五类,其中并没有松鼠这个类别

训练完成之后,可以问模型这两张图片是否是相同的东西呢?这时候模型已经学会分辨了事物的异同,比如给出两张松鼠图片,模型知道这两个动物之间长得很像,模型能够告诉你两张图片很可能是相同的东西。

支持集

给出一张图片,神经网络不知道这是什么。

这时候就需要支持集(Support Set),每类给出少样本(1~2)张,神经网络将Query图片和支持集中的每个类别依次对比,找出最相似的。

训练集和支持集的区别
  • 训练集规模很大,每类有很多张图片,可以训练一个深度神经网络

  • 支持集每类只有一张或几张图片,不足以训练一个大的神经网络,只能在做预测时候提供一些额外信息。

  • 用足够大的训练集训练的目的不是让模型识别训练集中的大象、老虎,而是知道事物的异同。对于训练的模型,只要提供含有该类别的小样本信息,模型就能区分类别,尽管训练集中没有这个类别。

小样本分类:Learn To Learn

带小朋友去动物园,小朋友不知道这个动物是什么,但是小朋友只需要翻一遍卡片(将目标与卡片上动物对应),就知道看到的动物是什么,这个卡片就是支持集,前提是小朋友有读卡片的能力,也就是得先经过训练学习。

如果卡片中每类只有一张,那就是One-Shot Learning(单样本学习)

传统监督学习 和 小样本学习 步骤的区别

  • 传统监督学习:测试图片虽然不是训练集中图片,但包含在训练集类别,模型已经见过上千张该类别图片,能够判断出是哪类。

  • 小样本学习:测试图片不但不包含在训练集中,也不是训练集中的类别。所以小样本学习比传统监督学习更难。因为不是训练集中的类别,所以要提供支持集,提供更多信息(给模型看小卡片,每张卡片有一个图片和一个标签,模型发现测试图片和某张卡片相似度高,就知道测试图片属于哪个标签)

小样本学习两个术语
  • k-way :支持集含有的种类数

  • n-shot : 支持集中每个种类有多少张图片

小样本学习预测准确率

  • 横轴是支持集类别数量。随着类别数量增加,分类准确率会降低。

  • 比如从三选一变成六选一

  • 每类样本越多,做预测越容易

相似度函数

sim(x, x'), x,x'为两个input

理想情况:sim(x1,x2) = 1 , sim(x1,x3) = 0, sim(x2,x3) = 0

从一个很大的训练集上学习一个相似度函数,它可以判断两张图片的相似度有多高。

孪生神经网络就可以作为相似度函数,可以拿大规模数据集做训练,训练结束之后,可以拿得到的相似度函数做预测。给一个测试图片,可以拿他跟支持集中的图片逐一对比,计算相似度,找到相似度最高作为预测结果。

  • Omniglot 特点:小样本(20个,105*105)

孪生网络(Siamese Network)

孪生网络要解决的问题
  • 第一类,分类数量较少,每一类的数据量较多,比如ImageNet、VOC等。这种分类问题可以使用神经网络或者SVM解决,只要事先知道了所有的类。

  • 第二类,分类数量较多(或者说无法确认具体数量),每一类的数据量较少,比如人脸识别、人脸验证任务。(少样本问题)

孪生网络的优点
  • 这个网络主要的优点是淡化了标签,使得网络具有很好的扩展性,可以对那些没有训练过的小样本类别进行分类,这点是优于很多算法的。

第一种训练孪生网络方法:每次取2样本,比较相似度 。
  • 训练这个神经网络要用一个大的数据集,每类有标注,每类下面都有很多个样本。

  • 我们需要用训练集来构造正样本和负样本

    • 正样本告诉神经网络什么东西是同一类。

    • 负样本告诉神经网络事物之间的区别。

  • 正样本获取

    • 每次从训练集中抽取一张图片(老虎),然后从同一类中随机抽取另一张图片(老虎),标签设置为1 (tiger, tiger, 1),意思是相似度满分。

  • 负样本获取

    • 每次从训练集中抽取一张图片(汽车),排除汽车这个类别,再从数据集中随机抽样(大象),标签设计为(car, elephant, 0),意思是相似度为0

  • 搭建一个卷积神经网络CNN用来提取特征,这个神经网络有很多卷积层,Pooling层,以及一个flatten层。输入是一张图片x,输出是提取的特征向量 f(x)

  • 现在开始训练神经网络,输入为(x1, x2 , 0或1),把这两张图片输入神经网络,把刚才搭建的卷积神经网络记作函数f。

  • 对于提取的特征向量,第一张图片特征向量记作h1 = f(x1),第二张图片特征向量记作h2 = f(x2),如果都是用CNN,这两个f需要是相同的卷积神经网络,共享相同的权值W(之所以叫孪生,就是因为共享特征提取的部分)。也可以不同权值,则不同场景,允许不同神经网络。

  • 然后拿h1 - h2 得到一个向量,再对这个向量所有元素求绝对值,记作z = ||h1 - h2||,表示两个特征向量之间的区别,再用一些全连接层来处理z向量,输出一些标量。

  • 最后用Sigmoid激活函数,得到输出是一个介于0~1之间的实数,可以衡量两个图片之间的相似度。如果两张图片是同一个类别,输出应该接近1,如果两张图片不同类别,输出应该接近0(希望神经网络的训练输出接近1),把标签与预测之间的差别作为损失函数

  • 损失函数可以是标签与预测的交叉熵损失函数cross-entropy loss function,可以衡量标签与预测的差别

  • 有了损失函数可以用反向传播计算梯度,用梯度下降来更新模型参数。

  • 模型主要有两部分,一个是卷积神经网络f用来从图片提取特征,一个是全连接层预测相似度,训练部分就是更新这两个的参数

  • 做反向传播,梯度从损失函数传回到向量z以及全连接层的参数,有了损失函数关于全连接层的梯度,就可以更新全连接层的参数了。

  • 然后梯度进一步从向量z传回到卷积神经网络,更新卷积神经网络参数,这样就完成了一轮训练

  • 做训练时候,我们要准备同样数量正样本和负样本。负样本标签设置为0,希望神经网络预测接近0,意思是这两张图片不同。还是用同样方法做反向传播,更新参数。

训练好模型之后,可以做One-Shot Prediction

  • 六个类别,每个类别一张图片,这六个类别可以都不在训练集中

  • 将Query与Support Set支持集中图片作对比:

    • 将Query图片与支持集中某一类一张图片作为input1 和 input2 ,输入到孪生网络中,孪生网络会输出一个0~1之间的值。用同样方法算出Query与所有图片相似度,查找相似度最高的。

孪生网络第二种训练方法:Triplet Loss
准备数据
  • 有这样一个训练集,每次选出三张图片

  • 首先从训练集随机选一张图片,作为anchor(锚点),记录这个锚点,然后从同类中随机抽取一张图片作为正样本Positive;排除该类别,从数据集中作随机抽样,得到不同类别的负样本Negative。

  • 现在有锚点x^a,正样本x+,负样本x-,把三张图片分别输入卷积神经网络f来提取特征(f指的是同一个卷积神经网络),得到三个特征向量

  • 计算正样本和锚点再特征空间上的距离,将特征向量 f(x+)与f(xa)求差,然后算二范数的平方,得到距离d+

  • 类似操作得到d-

  • 我们希望得到的神经网络有这样性质,像同类别特征向量聚在一起,不同类别的特征向量能够被分开,所以d+应该很小,d-应该很大

  • 这个坐标系是特征空间,卷积神经网络可以把图片映射到这个特征空间

  • d-应该比d+大很多,否则模型分辨不了同类和不同类

  • 所以鼓励正样本在特征空间接近锚点(d+尽量小),鼓励负样本在特征空间远离锚点(d-尽量大)

  • 指定一个margin :α,α>0。如果d- >= d+ + α,我们就认为没有损失loss=0,分类正确。假如条件不满足,则会有loss = d+ + α - d- , 我们希望loss越小越好

  • 有了损失函数,就可以求损失函数关于神经网络的梯度,作梯度下降来更新模型参数

测试模型
  • 给一个query,一个支持集,用神经网络提取特征,把所有这些图片变为特征向量,比较特征向量之间的距离。找出距离最小的。

总结

我们使用了Siamese Network解决了少样本学习

基本思路:

  • 用一个比较大的训练集来训练孪生网络,让孪生网络知道事物之间的异同

  • 训练结束之后拿孪生网络作预测,解决少样本问题。少样本的问题是少样本的类别不在训练集中。比如query是松鼠,但训练集中没有松鼠这个类别,需要额外的信息来识别query的图片,这个额外的信息就是少样本支持集。

  • 支持集称为k-way, n-shot,k个类别,类别越多,预测越困难,n个样本,样本越少,预测越困难,one-shot learning单样本预测最困难。

  • 有了训练好的孪生网络,我们就可以将query与support set中的样本逐一对比,选出距离最小或相似度最高作为分类结果。

  • 两种训练孪生网络方法:1.两个input,标签0或1,输出0~1之间数值,与标签差值作为loss,目标是让预测尽量接近标签。 2.另一种是Triplet Loss,xa,x+,x-,用CNN提取得到三个特征向量,输出d+,d-,目标是让d+尽量小,d-尽量大。有了这样一个神经网络就可以用它提取特征,比较两张图片在特征空间距离,作出few-shot分类

Fine Tuning

基本思路

在大规模数据上预训练模型,然后再小规模的support set上做fine-tuning。方法简单,准确率高。

  • 看个例子,余弦相似度consine similarity,衡量两个向量之间相似度,现在两个向量长度都是1,即他们的二范数都为1。

  • 把向量x和w的夹角记作θ,由于向量x和w长度都是1,cosθ就是x和w的内积,表示两个向量的相似度

  • 可以理解,把向量x投影到w方向上,投影长度就是-1到+1之间

  • 如果向量x和w的长度不是1,则需要做归一化把他们程度变为1,然后求得的内积才是余弦相似度

微调主要用到Softmax Function
  • 它是一个常用的激活函数,可以把一个k维向量映射成一个概率分布

  • 输入为Φ,它是任意的k维向量。把Φ的每一个元素做指数变换,得到k个大于0的数;然后对其作归一化,让得到的k个数相加等于1,把得到的k个数记为向量p

  • 向量p就是softmax函数的输出

  • 性质

    • 输入Φ和输出p都是k维向量

    • 向量p的元素都是正数,而且相加等于1

    • 所以p是个概率分布

  • softmax通常用于分类器的输出层,如果有k个类别,那么softmax的输出就是k个概率值,每个概率值表示对一个类别的confidence

  • softmax会让最大的值变大,其余的值变小。softmax比max函数要温柔一些

Softmax分类器
  • 是一个全连接层加一个Softmax函数

  • 分类器的输入是特征向量x,表示输入的测试图片的特征向量,把x乘到参数矩阵w上,再加上向量b,得到一个向量

  • 对得到的向量做softmax变换,得到输出向量p

  • 假如类别数量为k,那么向量p就是k维的

  • 矩阵W和b是这一层的参数,可以从训练数据中学习。W有K行,k是类别数量,所以W每一行对应一个类别,d是每个类别的特征数量

使用预训练好的神经网络,在query和support set上做fine-tuning的过程
  • 把query和support set中的图片都映射成特征向量,这样可以比较query和support set在特征空间上的相似度,比如可以计算两两之间的cosine similarity。最后选择相似度最高的作为query的分类结果

  • 预训练

    • 搭一个卷积神经网络用来提取特征,有很多卷积层、Pooling层以及一个Flatten层,也可以有全连接层

    • 神经网络输入是一张图片x,输出一个特征向量f(x)

    • 可以用传统的监督学习,预训练好后把全连接层都去掉;也可以用孪生网络训练

  • Few-Shot分类方法

    • 3-way 2-shot,三类别,每类别两样本

    • 拿预训练的神经网络提取特征,每张图片变成一个特征向量,每个类别两个特征向量

    • 平均每个类别特征向量作平均,得到一个同样大小的向量,也就是均值向量

    • 有三个类别,一共得到三个均值向量

    • 均值向量归一化,得到三个向量μ1,μ2,μ3,它们的二范数都等于一,μ1,μ2,μ3就是对三个类别的表征

    • 做分类的时候,要拿query的特征向量对μ1,μ2,μ3作对比

  • 对query作分类

    • 给一张query图片,需要判断是三个类别中的哪一个

    • 拿预训练的神经网络f来提取特征,得到一个特征向量

    • 对特征向量作归一化,得到向量q,它的二范数等于1

    • 与刚才从support set中提取的三个向量μ1,μ2,μ3,它们的二范数也是1,每个μ向量表征一个类别

    • 可以把三个μ向量堆叠起来,作为矩阵M的三个行向量

  • 做few-shot预测

    • query的特征向量q乘到矩阵M上,再做Softmax变换,得到p = Softmax(Mq),p是个概率分布,这个例子里,p是三维向量,表示对三个类别的confidence

    • 三个元素分别是q与μ1,μ2,μ3的内积

    • 很显然,在向量p中,第一个元素最大,分类结果是第一类

Fine-tuning可以大幅提高预测准确率
  • 基本都是先做预训练,后做Fine-Tuning

  • 刚才我们用了固定的W和b,没有学习这两个参数

  • 可以在Support Set上学习W和b,这叫做fine tuning

    • Cross Entropy来衡量yj与pj的差别有多大,yj是真实标签,pj是分类器做出的预测,损失函数就是Cross Entropy Loss

    • Support set中有几个或者几十个有标注的样本,每个样本都对应一个Cross Entropy Loss,把这些Cross entropy loss加起来,作为损失函数

    • 也就是说我们用support set中所有的图片和标签来学习这个分类器

    • CrossEntropyLoss做最小化Minimization,让预测pj尽量接近真实标签yj

    • Minimization是对分类器参数W和b求的,希望学习W和b;当然也可以让梯度传播到卷积神经网络,更新神经网络参数,让提取的特征向量更有效

    • support通常很小几十个到几百个样本,最好加个regularization来防止过拟合。有一篇文章建议用Entropy Regularization

  • 有一篇ICLR2020的论文说 对于5-way 1-shot,做fine tuning可以提到2%~7%的准确率;对5-way 5-shot,提高1.5%~4%准确率

  • 尽管support set很小,但用support set来训练分类器有助于提高准确率,预训练+fine tuning比只用预训练好很多

  • W,b默认值

  • Entropy Regularization防止过拟合

    • 希望Entropy Regularization越小越好

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mfbz.cn/a/842885.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

UniVue@v1.5.0版本发布:里程碑版本

前言 以后使用UniVue都推荐使用1.5.0以后的版本,这个版本之后,更新的速度将会放缓。 希望这个框架能够切实的帮助大家更好的开发游戏,做出一款好游戏!本开源项目采用的开源协议为MIT协议,完全开源化,以后也…

数据结构——线性表(循环链表)

一、循环链表定义 将单链表中终端结点的指针端由空指针改为指向头结点,就使整个单链表形成一 个环,这种头尾相接的单链表称为单循环链表,简称循环链表(circular linked list)。 循环链表解决了一个很麻烦的问题。如何从当中一 个结点出发&am…

二叉树的前、中、后序遍历(递归法、迭代法)leetcode144/94/145

leetcode144、二叉树的前序遍历 给你二叉树的根节点 root ,返回它节点值的 前序 遍历。 示例 1: 输入:root [1,null,2,3] 输出:[1,2,3] 示例 2: 输入:root [] 输出:[] 示例 3:…

【Linux】Linux环境设置环境变量操作步骤

Linux环境设置环境变量操作步骤 在一些开发过程中本地调试经常需要依赖环境变量的参数,但是怎么设置对小白来说有点困难,今天就介绍下具体的操作步骤,跟着实战去学习,更好的检验自己的技术水平,做技术还是那句话&…

三字棋游戏(C语言详细解释)

hello,小伙伴们大家好,算是失踪人口回归了哈,主要原因是期末考试完学校组织实训,做了俄罗斯方块,后续也会更新,不过今天先从简单的三字棋说起 话不多说,开始今天的内容 一、大体思路 我们都知…

pytest常用命令行参数解析

简介:pytest作为一个成熟的测试框架,它提供了许多命令行参数来控制测试的运行方式,以配合适用于不同的测试场景。例如 -x 可以用于希望出现错误就停止,以便定位和分析问题。–rerunsnum适用于希望进行失败重跑等个性化测试策略。 …

用ComfyUI安装可图Kolors大模型做手机壁纸

一、Kolors简介 国内科技公司快手在人工智能领域取得了显著进展,特别推出了「可图 Kolors」这一开源模型,它在图像生成质量上超越了SD3,与Midjourney v6模型相媲美,并支持中文提示词识别与生成中文字符,成为国产AI绘画…

经典神经网络(14)T5模型原理详解及其微调(文本摘要)

经典神经网络(14)T5模型原理详解及其微调(文本摘要) 2018 年,谷歌发布基于双向 Transformer 的大规模预训练语言模型 BERT,而后一系列基于 BERT 的研究工作如春笋般涌现,预训练模型也成为了业内解决 NLP 问题的标配。 2019年,谷歌…

Qt开发网络嗅探器03

数据包分析 想要知道如何解析IP数据包,就要知道不同的IP数据包的包头结构,于是我们上⽹查查资料: 以太网数据包 ARP数据包 IPv4 IPv6 TCP UDP ICMP ICMPv6 根据以上数据包头结构,我们就有了我们的protocol.h文件,声明…

node解析Excel中的考试题并实现在线做题功能

1、背景 最近公司安排业务技能考试,下发excel文件的题库,在excel里查看并不是很方便,就想着像学习驾考题目一样,一边看一边做,做完之后可以查看正确答案。 2、开始分析需求 题目格式如下图 需求比较简单,…

配置RIPv2的认证

目录 一、配置IP地址、默认网关、启用端口 1. 路由器R1 2. 路由器R2 3. 路由器R3 4. Server1 5. Server2 二、搭建RIPv2网络 1. R1配置RIPv2 2. R2配置RIPv2 3. Server1 ping Server2 4. Server2 ping Server1 三、模拟网络攻击,为R3配置RIPv2 四、在R…

ExoPlayer架构详解与源码分析(15)——Renderer

系列文章目录 ExoPlayer架构详解与源码分析(1)——前言 ExoPlayer架构详解与源码分析(2)——Player ExoPlayer架构详解与源码分析(3)——Timeline ExoPlayer架构详解与源码分析(4)—…

拖拽上传(预览图片)

需求 点击上传图片&#xff0c;或直接拖拽图片到红色方框里面也可上传图片&#xff0c;上传后预览图片 效果 实现 <!DOCTYPE html> <html lang"zh-cn"><head><meta charset"UTF-8"><meta name"viewport" content&…

【safari】react在safari浏览器中,遇到异步时间差的问题,导致状态没有及时更新到state,引起传参错误。如何解决

在safari浏览器中&#xff0c;可能会遇到异步时间差的问题&#xff0c;导致状态没有及时更新到state&#xff0c;引起传参错误。 PS&#xff1a;由于useState是一个普通的函数&#xff0c; 定义为() > void;因此此处不能用await/async替代setTimeout&#xff0c;只能用在返…

价格较低,功能最强?OpenAI 推出 GPT-4o mini,一个更小、更便宜的人工智能模型

OpenAI美东时间周四推出“GPT-4o mini”&#xff0c;入局“小而精”AI模型竞争&#xff0c;称这款新模型是“功能最强、成本偏低的模型”&#xff0c;计划今后整合图像、视频、音频到这个模型中。 OpenAI表示&#xff0c;GPT-4o mini 相较于 OpenAI 目前最先进的 AI 模型更加便…

51单片机(STC8H8K64U/STC8051U34K64)_RA8889驱动TFT大屏_I2C_HW参考代码(v1.3) 硬件I2C方式

本篇介绍单片机使用硬件I2C方式控制RA8889驱动彩屏。 提供STC8H8K64U和STC8051U34K64的参考代码。 【硬件部份】STC8H8K64U/STC8051U34K64 RA8889开发板 7寸TFT 800x480 1. 实物连接图&#xff1a;STC8H8K64URA8889开发板&#xff0c;使用P2口I2C接口&#xff1a; 2.实物连…

ISP代理和双ISP代理:区别和优势

随着互联网技术的不断发展和普及&#xff0c;网络代理服务成为众多用户保护隐私、提高网络性能、增强安全性的重要工具。其中&#xff0c;ISP代理和双ISP代理是两种常见的网络代理服务形式。本文将详细探讨ISP代理和双ISP代理的区别和优势&#xff0c;以便用户更好地了解并选择…

【LeetCode】填充每个节点的下一个右侧节点指针 II

目录 一、题目二、解法完整代码 一、题目 给定一个二叉树&#xff1a; struct Node { int val; Node *left; Node *right; Node *next; } 填充它的每个 next 指针&#xff0c;让这个指针指向其下一个右侧节点。如果找不到下一个右侧节点&#xff0c;则将 next 指针设置为 NUL…

MySQL学习作业二

作业描述 SQL语言 建库&#xff0c;使用库 mysql> create database mydb8_worker;#新建库mysql> use mydb8_worker; 建表&#xff0c;查看表 #建表 mysql> create table t_worker(department_id int(11) not null comment部门号,worker_id int(11) primary key no…

Flink History Server配置

目录 问题复现 History Server配置 HADOOP_CLASSPATH配置 History Server配置 问题修复 启动flink集群 启动Histroty Server 问题复现 在bigdata111上执行如下命令开启socket&#xff1a; nc -lk 9999 如图&#xff1a; 在bigdata111上执行如下命令运行flink应用程序 …