困难样本挖掘:Hard Sample Mining(原理及实现)

Hard Sample Mining

Hard Sample Mining,即困难样本挖掘,是目标检测中的一种常用方法。其主要思想是针对训练过程中损失较高的样本(即那些难以被正确分类的样本)进行挖掘,并将其补充到数据集中重新训练,以提高模型的识别能力。这种方法可以有效地解决样本不平衡和简单样本过多的问题。

分类中的一些基础概念:

在Hard Sample Mining中,正样本和负样本的概念是基础。正样本是包含我们想要识别类别的样本,而负样本则是不包含这些类别的样本。其中,难分正样本(Hard Positives)是指那些容易被错误地分类为负样本的正样本,而难分负样本(Hard Negatives)则是那些容易被错误地分类为正样本的负样本。这些难分样本通常具有较大的损失值,即它们的预测结果与真实标签相差较大。

  • 正样本:包含我们想要识别类别的样本,例如,我们在做猫狗分类,那么在训练的时候,包含猫或者狗的图片就是正样本
  • 负样本:在上面的例子中,不包含猫或者狗的其他所有的图片都是负样本
  • 难分正样本(hard positives):易错分成负样本的正样本,对应在训练过程中损失最高的正样本,loss比较大(label与prediction相差较大)。
  • 难分负样本(hard negatives):易错分成正样本的负样本,对应在训练过程中损失最高的负样本
  • 易分正样本(easy positive):容易正确分类的正样本,该类的概率最高。对应在训练过程中损失最低的正样本
  • 易分负样本(easy negatives):容易正确分类的负样本,该类的概率最高。对应在练过程中损失最低的负样本。

举例补充: 如果ROI(Region of Interest,感兴趣区域)里没有物体,全是背景,这时候分类器很容易正确分类成背景,这个就叫easy negative;如果roi里有二分之一个物体,标签仍是负样本,这时候分类器就容易把他看成正样本,这时候就是had negative。

为什么要进行困难样本挖掘?

在区域提议(Region Proposal) 的目标检测算法中,负样本的数量会远高于正样本的数量,并且这些负样本中大部分是对网络训练作用相对较小的易分负样本(easy negatives)。这种样本不均衡问题会导致模型在训练过程中过度关注简单样本,而忽视了那些真正具有挑战性的难分样本(hard examples),从而影响了模型的收敛和精度。

根据Focal Loss论文的统计,简单例(基本都是负例)与有用信息的难例(正例+难负例)之比高达100000:100,一个非常悬殊的比例。这意味着在训练过程中,简单例的损失函数值会远高于难例的损失函数值,甚至可能达到40倍之多。这种损失函数值的不均衡会导致模型在优化过程中偏离正确的方向,无法有效地学习到难例中的有用信息。

为解决这个问题,提出了困难样本挖掘(Hard Sample Mining)等方法。这些方法的核心思想是在训练过程中动态地选择一些具有高损失的样本(即难例)进行重点训练,从而提高模型的识别能力。通过这种方法,我们可以有效地降低简单例对模型训练的影响,使模型更加关注那些真正具有挑战性的难例。

解决目标检测中的样本不均衡问题对于提高模型的性能至关重要。通过合理处理简单例和难例之间的关系,可以使模型在训练过程中更加高效地学习有用信息,从而实现更高的检测精度和更好的泛化能力。
在这里插入图片描述
由于正样本数量一般较少,所有对于困难样本挖掘(hard example mining)一般是指难负例挖掘(Hard Negative Mining)。

为了让模型正常训练,我们必须要通过某种方法抑制大量的简单负例,挖掘所有难例的信息,这就是难例挖掘的初衷。即在训练时,尽量多挖掘些难负例(hard negative)加入负样本集参与模型的训练,这样会比easy negative组成的负样本集效果更好。

关于hard negative mining,比较生动的例子是高中时期你准备的错题集。错题集不会是每次所有的题目你都往上放。放上去的都是你最没有掌握的那些知识点(错的最厉害的),而这一部分是对你学习最有帮助的。

挖掘方法

离线

1、TopK Loss方法

在训练过程中,选择损失值(loss)最大的前K个样本进行反向传播(back propagate),而损失值较小的样本(即easy samples)则认为已经分类正确,不再进行反向传播。这里的K可以是一个百分比,例如前70%的困难样本。这种方法有助于模型更加关注那些难以分类的样本。

2、基于IOU的离线方法:Hard Negative Mining (HNM)

这是最早期和最简单的困难样本挖掘方法之一,它专注于那些模型错误分类为正样本的负样本(即难分负样本或hard negatives)。

在训练过程中,通过计算区域提议(Region Proposal)与真实标注(Ground Truth)之间的交并比(IoU),并设定一个阈值,将那些IoU低于某个阈值的负样本选为困难样本。结果超过阈值的被认为是正样本,低于阈值的则认为是负样本。这些样本随后被加入到负样本集中,以便在后续的训练迭代中被模型重点学习。

然而,随着训练的进行,可能会出现正样本数量远小于负样本的问题,导致数据分布不平衡。为了解决这个问题,有些研究者提出了对称的模型来处理这种不平衡的数据。
在这里插入图片描述
就是类似上图,将Hard Posiotive也重新赋给正样本。

在线

Online Hard Example Mining(OHEM)是困难样本挖掘的一种实现方式。它的核心思想是在训练过程中动态地选择一些具有高损失的样本作为训练样本,从而改善网络参数的效果。这种方法不需要通过设置正负样本比例来解决数据的类别不平衡问题,而是采用在线选择的方式,更具针对性。

代表性论文:Training Region-Based Object Detectors with Online Hard Example Mining 【CVPR2016】

上述论文将难分样本挖掘(hard example mining)机制嵌入到SGD算法中,使得Fast R-CNN在训练的过程中根据region proposal的损失自动选取合适的Region Proposal作为正负例训练。

实验结果表明使用OHEM(Online Hard Example Mining)机制可以使得Fast R-CNN算法在VOC2007和VOC2012上mAP提高 4%左右。

实现

在mmdetection中的实现

在正负样本的挑选过程中,采用困难样例挖掘的方法进行筛选而不是简单的随机挑选;

    def _sample_pos(self,
                    assign_result,
                    num_expected,
                    bboxes=None,
                    feats=None,
                    **kwargs):
        # Sample some hard positive samples
        pos_inds = torch.nonzero(assign_result.gt_inds > 0)
        if pos_inds.numel() != 0:
            pos_inds = pos_inds.squeeze(1)
        if pos_inds.numel() <= num_expected: #如果样本量本身少于期望值不进行困难样本挖掘
            return pos_inds
        else:
            return self.hard_mining(pos_inds, num_expected, bboxes[pos_inds],
                                    assign_result.labels[pos_inds], feats)

    def _sample_neg(self,
                    assign_result,
                    num_expected,
                    bboxes=None,
                    feats=None,
                    **kwargs):
        # Sample some hard negative samples
        neg_inds = torch.nonzero(assign_result.gt_inds == 0)
        if neg_inds.numel() != 0:
            neg_inds = neg_inds.squeeze(1)
        if len(neg_inds) <= num_expected:
            return neg_inds
        else: #如果样本量多于期望值,进行困难样本挖掘进行筛选
            return self.hard_mining(neg_inds, num_expected, bboxes[neg_inds],
                                    assign_result.labels[neg_inds], feats)


计算样本的损失值,根据类别的loss挑选出损失比较大的样本【困难样本】

def hard_mining(self, inds, num_expected, bboxes, labels, feats):
#inds: 正负样本的索引; num_expected:期望的正负样本数量; bboxes:正负样本【anchor】, labels:类别标签  feats:特征图
    with torch.no_grad():  #不参与梯度的计算
        rois = bbox2roi([bboxes])
        bbox_feats = self.bbox_roi_extractor(
            feats[:self.bbox_roi_extractor.num_inputs], rois)
        cls_score, _ = self.bbox_head(bbox_feats)
        loss = self.bbox_head.loss(
            cls_score=cls_score,
            bbox_pred=None,
            labels=labels,
            label_weights=cls_score.new_ones(cls_score.size(0)),
            bbox_targets=None,
            bbox_weights=None,
            reduction_override='none')['loss_cls']
        _, topk_loss_inds = loss.topk(num_expected)
    return inds[topk_loss_inds]

2、基于Yolov5/Yolov7的困难样本挖掘—LRM loss,提升难样本检测精度

论文地址:Improved Hard Example Mining Approach for Single Shot Object Detectors

代码地址:https://github.com/aybora/yolov5Loss

简介

困难例挖掘方法通常可以提高目标检测器的性能,因为它受到不平衡训练集的影响。在这项工作中,两种现有的困难例子挖掘方法(LRM和焦点损失,FL)被调整并结合到最先进的实时目标检测器YOLOv5中。广泛地评估了所提出的方法对于提高困难例性能的有效性。在2021年Anti-UAV挑战数据集上,与使用原始损失函数相比,所提出的方法使mAP提高了3%,与单独使用困难挖掘方法(LRM或FL)相比,提高了约1%。
在这里插入图片描述
在基于Yolov5/Yolov7的目标检测模型中,实现困难样本挖掘(Hard Example Mining)通常涉及到对损失函数的修改,以便模型在训练过程中更加关注那些难以正确分类的样本。LRM(Loss Rank Mining)loss是一种专门为困难样本挖掘设计的方法,它通过考虑样本在整个数据集中的相对损失排名来选择困难样本。

基于OHEM的思想,Yu等人引入了Loss Rank Mining(LRM)[18]:该方法适用于单次检测,它通过在检测阶段之前过滤掉特征图上的一些容易的例子,使目标检测器专注于困难的例子。在训练过程中,作为第一步,输入要经过模型主干得到特征图。然后,对于每个检测,都要计算损失值。在非最大抑制(NMS)阶段之后,这些检测的损失值按降序排序,前K个检测结果被选中并过滤。其余的检测值在训练过程中不被使用。这个想法可能适用于目前的目标检测器,并对其有利,如果它能被实施到其结构中。

代码示例:

import torch
import torch.nn as nn
import torch.nn.functional as F

class LRMLoss(nn.Module):
    def __init__(self, gamma=0.1, beta=0.05):
        super(LRMLoss, self).__init__()
        self.gamma = gamma
        self.beta = beta

    def forward(self, predictions, targets, indices):
        # 计算原始损失(例如,交叉熵损失)
        loss = F.cross_entropy(predictions, targets, reduction='none')

        # 获取所有样本的损失值
        loss_values = loss.view(-1)

        # 计算每个样本的排名(从1开始)
        ranks = torch.argsort(loss_values, descending=True)
        ranks = ranks + 1  # 转换为从1开始的排名

        # 计算排名损失
        rank_loss = self.gamma * torch.log(1 + 1 / (ranks + self.beta))

        # 选择困难样本进行更新
        _, hard_indices = torch.topk(rank_loss, k=len(hard_indices))
        hard_losses = loss[hard_indices]

        # 计算最终的LRM损失
        lrm_loss = hard_losses.mean()

        return lrm_loss

# 假设我们有以下输入
# predictions: 模型的预测输出,尺寸为 [batch_size, num_classes, num_boxes]
# targets: 真实的标签,尺寸为 [batch_size, num_boxes]
# indices: 每个样本的索引,用于确定哪些样本是困难样本

# 初始化LRM损失函数
lrm_loss = LRMLoss(gamma=0.1, beta=0.05)

# 计算LRM损失
loss = lrm_loss(predictions, targets, indices)

# 反向传播和优化
loss.backward()
optimizer.step()

首先计算了标准的交叉熵损失,然后根据损失值计算每个样本的排名。接着使用排名来计算排名损失,并通过选择排名靠前的样本(即困难样本)来计算最终的LRM损失。

参考文献

  1. 深度学习难分样本挖掘(Hard Mining)
  2. 改进的one-shot目标检测的困难样本挖掘方法
  3. 涨点技巧:基于Yolov5/Yolov7的困难样本挖掘—LRM loss,提升难样本检测精度
  4. YOLOv9改进策略:loss优化 | LRM loss困难样本挖掘,提升难样本、遮挡物、低对比度等检测精度
  5. http://www.hqwc.cn/news/577750.html
  6. https://blog.csdn.net/qq_44804542/article/details/115276930

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/510713.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

【Qt 学习笔记】Qt 背景介绍

博客主页&#xff1a;Duck Bro 博客主页系列专栏&#xff1a;Qt 专栏关注博主&#xff0c;后期持续更新系列文章如果有错误感谢请大家批评指出&#xff0c;及时修改感谢大家点赞&#x1f44d;收藏⭐评论✍ Qt 背景介绍 文章编号&#xff1a;Qt 学习笔记 / 01 文章目录 Qt 背景…

配置plsql链接Oracle数据库(新手)

配置plsql链接Oracle数据库 安装Oracle客户端 、安装plsql客户端并激活 配置tnsnames.ora文件&#xff08;路径D:\app\peter\Oracle\InstantClient\network\admin根据你的实际路径设置&#xff09; 配置文件如下 # tnsnames.ora Network Configuration File: D:\app\peter\O…

【CKA模拟题】一文教你用StorageClass轻松创建PV

题干 For this question, please set this context (In exam, diff cluster name) kubectl config use-context kubernetes-adminkubernetesYour task involves setting up storage components in a Kubernetes cluster. Follow these steps: Step 1: Create a Storage Class…

卡尔曼滤波笔记

资料&#xff1a;https://www.zhihu.com/question/47559783/answer/2988744371 https://www.zhihu.com/question/47559783 https://blog.csdn.net/seek97/article/details/120012667 一、基本思想 在对一个状态值进行估计的时候&#xff0c;如果想测量值更准&#xff0c;很自然…

“探秘数据结构:栈的奇妙魔力“

每日一言 兰有秀兮菊有芳&#xff0c;怀佳人兮不能忘。 —刘彻- 栈 栈的概念及结构 栈(Stack) &#xff1a;一种特殊的线性表&#xff0c;其只允许在固定的一端进行插入和删除元素操作。进行数据插入和删除操作的一端称为栈顶&#xff0c;另一端称为栈底。栈中的数据元素遵守…

vue3+vite 模板vue3-element-admin框架如何关闭当前页面跳转 tabs

使用模版: 有来开源组织 / vue3-element-admin 需要关闭的.vue 页面增加以下方法 //setup 里import {LocationQuery, useRoute, useRouter} from "vue-router"; const router useRouter(); function close() {console.log(|--router.currentRoute.value, router.cur…

【MySQL系列】使用 ALTER TABLE 语句修改表结构的方法

&#x1f49d;&#x1f49d;&#x1f49d;欢迎来到我的博客&#xff0c;很高兴能够在这里和您见面&#xff01;希望您在这里可以感受到一份轻松愉快的氛围&#xff0c;不仅可以获得有趣的内容和知识&#xff0c;也可以畅所欲言、分享您的想法和见解。 推荐:kwan 的首页,持续学…

每日一题 第六十三期 洛谷 树状数组模板

【模板】树状数组 1 题目描述 如题&#xff0c;已知一个数列&#xff0c;你需要进行下面两种操作&#xff1a; 将某一个数加上 x x x 求出某区间每一个数的和 输入格式 第一行包含两个正整数 n , m n,m n,m&#xff0c;分别表示该数列数字的个数和操作的总个数。 第二…

4.2 JavaWeb Day05分层解耦

三层架构功能 controller层接收请求&#xff0c;响应数据&#xff0c;层内调用了service层的方法&#xff0c;service层仅负责业务逻辑处理&#xff0c;其中要获取数据&#xff0c;就要去调用dao层&#xff0c;由dao层进行数据访问操作去查询数据&#xff08;进行增删改查&…

YOLOv8结合SCI低光照图像增强算法!让夜晚目标无处遁形!【含端到端推理脚本】

这里的"SCI"代表的并不是论文等级,而是论文采用的方法 — “自校准光照学习” ~ 左侧为SCI模型增强后图片的检测效果,右侧为原始v8n检测效果 这篇文章的主要内容是通过使用SCI模型和YOLOv8进行算法联调,最终实现了如上所示的效果:在增强图像可见度的同时,对图像…

Scala中如何使用Jsoup库处理HTML文档?

在当今互联网时代&#xff0c;数据是互联网应用程序的核心。对于开发者来说&#xff0c;获取并处理数据是日常工作中的重要一环。本文将介绍如何利用Scala中强大的Jsoup库进行网络请求和HTML解析&#xff0c;从而实现爬取京东网站的数据&#xff0c;让我们一起来探索吧&#xf…

【容易不简单】love 2d Lua 俄罗斯方块超详细教程

源码已经更新在CSDN的码库里&#xff1a; git clone https://gitcode.com/funsion/love2d-game.git 一直在找Lua 能快速便捷实现图形界面的软件&#xff0c;找了一堆&#xff0c;终于发现love2d是小而美的原生lua图形界面实现的方式。 并参考相关教程做了一个更详细的&#x…

C++算法——滑动窗口

一、长度最小的子数组 1.链接 209. 长度最小的子数组 - 力扣&#xff08;LeetCode&#xff09; 2.描述 3.思路 本题从暴力求解的方式去切入&#xff0c;逐步优化成“滑动窗口”&#xff0c;首先&#xff0c;暴力枚举出各种组合的话&#xff0c;我们先让一个指针指向第一个&…

【Qt学习笔记 01】Qt 背景介绍

博客主页&#xff1a;Duck Bro 博客主页系列专栏&#xff1a;Qt 专栏关注博主&#xff0c;后期持续更新系列文章如果有错误感谢请大家批评指出&#xff0c;及时修改感谢大家点赞&#x1f44d;收藏⭐评论✍ Qt 背景介绍 文章编号&#xff1a;Qt 学习笔记 / 01 文章目录 Qt 背景…

AttributeError: ‘Namespace‘ object has no attribute ‘EarlyStopping‘

报错原因 这个报错信息表明在Python脚本train.py中尝试访问命令行参数args.EarlyStopping时出错&#xff0c;具体错误是AttributeError: Namespace对象没有名为EarlyStopping的属性。 在Python的argparse模块中&#xff0c;当我们通过命令行传递参数并解析时&#xff0c;解析…

UGUI 进阶

UI事件监听接口 目前所有的控件都只提供了常用的事件监听列表 如果想做一些类似长按&#xff0c;双击&#xff0c;拖拽等功能是无法制作的 或者想让Image和Text&#xff0c;RawImage三大基础控件能够响应玩家输入也是无法制作的 而事件接口就是用来处理类似问题 让所有控件都…

10秒钟用python接入讯飞星火API(保姆级)

正文&#xff1a; 科大讯飞是中国领先的人工智能公众公司&#xff0c;其讯飞星火API为开发者提供了丰富的接口和服务&#xff0c;以支持各种语音和语言技术的应用。 步骤一&#xff1a;注册账号并创建应用 首先&#xff0c;您需要访问科大讯飞开放平台官网&#xff0c;注册一个…

dcoker 下redis设置密码

修改Docker里面Redis密码 Redis是一个开源的内存数据结构存储系统&#xff0c;常用于缓存、消息队列和数据持久化等场景。在使用Docker部署Redis时&#xff0c;默认情况下是没有设置密码的&#xff0c;这可能会导致安全隐患。因此&#xff0c;为了保证数据的安全性&…

2024环境,资源与绿色能源国际会议(ICERGE2024)

2024环境&#xff0c;资源与绿色能源国际会议(ICERGE2024) 会议简介 2024环境、资源与绿色能源国际会议(ICERGE2024)将于2024年在三亚举行。该会议是一个围绕环境、资源与绿色能源研究领域的国际学术交流活动。 会议主题包括但不限于环境科学、环境工程、资源利用、绿色能源开…

微信小程序上传代码到远程仓库

个人介绍 hello hello~ &#xff0c;这里是 code袁~&#x1f496;&#x1f496; &#xff0c;欢迎大家点赞&#x1f973;&#x1f973;关注&#x1f4a5;&#x1f4a5;收藏&#x1f339;&#x1f339;&#x1f339; &#x1f981;作者简介&#xff1a;一名喜欢分享和记录学习的…