IB 公式解析

公式

3.2. Influence Function

影响函数允许我们在移除样本时估计模型参数的变化,而无需实际移除数据并重新训练模型。

 3.3 影响平衡加权因子

 3.4 影响平衡损失

 3.5 类内重加权

m代表一个批次(batch)的大小,这意味着公式对一个批次中的所有样本进行计算,然后去平均值。

 代码

criterion_ib = IBLoss(weight=per_cls_weights, alpha=1000).cuda()
def ib_loss(input_values, ib):
    """Computes the focal loss"""
    loss = input_values * ib
    return loss.mean()
class IBLoss(nn.Module):
    def __init__(self, weight=None, alpha=10000.):
        super(IBLoss, self).__init__()
        assert alpha > 0
        self.alpha = alpha
        self.epsilon = 0.001
        self.weight = weight

    def forward(self, input, target, features):
        grads = torch.sum(torch.abs(F.softmax(input, dim=1) - F.one_hot(target, num_classes)),1) # N * 1
        ib = grads * features.reshape(-1)
        ib = self.alpha / (ib + self.epsilon)
        return ib_loss(F.cross_entropy(input, target, reduction='none', weight=self.weight), ib)

1.计算梯度 grads

grads = torch.sum(torch.abs(F.softmax(input, dim=1) - F.one_hot(target, num_classes)), 1) # N * 1
  • 计算 softmax 概率分布F.softmax(input, dim=1) 将模型的输出转换为概率分布。
  • 计算 one-hot 编码F.one_hot(target, num_classes) 将目标标签转换为 one-hot 编码。
  • 计算绝对差值:通过计算 softmax 输出与 one-hot 编码之间的绝对差值,得到每个样本的梯度,表示样本对模型的损失贡献。

 2. 计算影响平衡因子(IB Factor)

ib = grads * features.reshape(-1)
ib = self.alpha / (ib + self.epsilon)

影响平衡因子(IB Factor)确实与梯度成反比。梯度越大,IB因子越小,分配给该样本的权重越小;梯度越小,IB因子越大,分配给该样本的权重越大。这一机制确保了模型在处理不平衡数据时,能够更有效地减小对多数类样本的过拟合,提升对少数类样本的泛化能力。

 3. 计算最终损失

return ib_loss(F.cross_entropy(input, target, reduction='none', weight=self.weight), ib)

将论文中的公式与代码对应起来

论文中的公式:

对应代码

首先,我们来看影响平衡损失 IBLoss 的代码实现:

class IBLoss(nn.Module):
    def __init__(self, weight=None, alpha=10000.):
        super(IBLoss, self).__init__()
        assert alpha > 0
        self.alpha = alpha
        self.epsilon = 0.001
        self.weight = weight

    def forward(self, input, target, features):
        grads = torch.sum(torch.abs(F.softmax(input, dim=1) - F.one_hot(target, num_classes)), 1) # N * 1
        ib = grads * features.reshape(-1)
        ib = self.alpha / (ib + self.epsilon)
        return ib_loss(F.cross_entropy(input, target, reduction='none', weight=self.weight), ib)

对应关系

  1. 批次大小 m

    在代码中,批次大小由 train_loadertest_loader 的批次大小参数决定。

  2. 数据集 𝐷𝑚

    代码中的 train_loadertest_loader 提供了批次数据。

  3. 类别权重因子 𝜆𝑘

    在代码中,通过 per_cls_weights 来实现:

per_cls_weights = 1.0 / np.array(cls_num_list)
per_cls_weights = per_cls_weights / np.sum(per_cls_weights) * len(cls_num_list)
per_cls_weights = torch.FloatTensor(per_cls_weights).cuda()

4.基本损失函数 𝐿(𝑦,𝑓(𝑥,𝑤))

代码中使用 torch.nn.CrossEntropyLoss 计算交叉熵损失:

base_loss = F.cross_entropy(input, target, reduction='none', weight=self.weight)

5.模型输出 𝑓(𝑥,𝑤)f(x,w)

在代码中,模型的输出为 input

output, features = model(images)

6.模型输出与真实标签的 L1 范数 ∥𝑓(𝑥,𝑤)−𝑦∥1

在代码中,通过以下方式计算:

grads = torch.sum(torch.abs(F.softmax(input, dim=1) - F.one_hot(target, num_classes)), 1) # N * 1

7.隐藏层特征向量 ℎ 和其 L1 范数 ∥ℎ∥1**

在代码中,通过以下方式计算隐藏层特征向量的 L1 范数:

features = torch.sum(torch.abs(feats), 1).reshape(-1, 1)

8.最终影响平衡因子 IB

在代码中,通过以下方式计算:

ib = grads * features.reshape(-1)
ib = self.alpha / (ib + self.epsilon)

 9.最终影响平衡损失 𝐿IB(𝑤)

通过自定义的 ib_loss 函数计算:

return ib_loss(F.cross_entropy(input, target, reduction='none', weight=self.weight), ib)

 为什么类别权重因子要这样实现

per_cls_weights = 1.0 / np.array(cls_num_list)
per_cls_weights = per_cls_weights / np.sum(per_cls_weights) * len(cls_num_list)
per_cls_weights = torch.FloatTensor(per_cls_weights).cuda()

类别权重因子的实现旨在通过加权样本来处理类别不平衡问题。以下是详细解释为什么要这样实现 per_cls_weights 以及每一步的作用:

代码实现

per_cls_weights = 1.0 / np.array(cls_num_list)
per_cls_weights = per_cls_weights / np.sum(per_cls_weights) * len(cls_num_list)
per_cls_weights = torch.FloatTensor(per_cls_weights).cuda()

每一步的解释

计算每个类别的逆频率

per_cls_weights = 1.0 / np.array(cls_num_list)
  • cls_num_list 是每个类别的样本数量列表。例如,如果有三个类别,且每个类别的样本数量为 [100, 200, 50],则 cls_num_list = [100, 200, 50]
  • 通过取倒数 1.0 / np.array(cls_num_list),我们得到了每个类别的逆频率。例如,结果将是 [0.01, 0.005, 0.02]
  • 逆频率反映了类别数量的稀少程度,样本数量少的类别(少数类)将得到更高的权重。

归一化权重

per_cls_weights = per_cls_weights / np.sum(per_cls_weights) * len(cls_num_list)
  • 首先,计算权重的总和 np.sum(per_cls_weights)。根据前面的例子,总和为 0.01 + 0.005 + 0.02 = 0.035
  • 然后,将每个类别的权重除以总和,使得所有权重的和为 1。这是标准化步骤,使得权重变为 [0.01/0.035, 0.005/0.035, 0.02/0.035],即 [0.2857, 0.1429, 0.5714]
  • 接下来,将这些标准化权重乘以类别的数量 len(cls_num_list)。在这个例子中,类别数量是 3。因此,最终的权重变为 [0.2857*3, 0.1429*3, 0.5714*3],即 [0.8571, 0.4286, 1.7143]

这一步的作用是确保每个类别的权重和类别数量成正比,同时保持权重的总和为类别数量。

转换为 PyTorch 张量

per_cls_weights = torch.FloatTensor(per_cls_weights).cuda()
  • 将 NumPy 数组转换为 PyTorch 张量,以便在 PyTorch 中使用这些权重。
  • 将权重张量移动到 GPU(如果可用),以加速计算。

归一化步骤的原因

归一化权重的目的是确保类别权重的相对比例合理,并且所有权重的总和与类别数量一致。这有助于避免某些类别被赋予过高或过低的权重,从而确保训练过程中的稳定性和效果。

处理类别不平衡的原因

类别不平衡问题是指在数据集中,不同类别的样本数量差异很大。在这种情况下,传统的损失函数往往会被多数类主导,导致模型在少数类上的性能较差。通过加权样本,特别是对少数类样本赋予更高的权重,可以平衡各类样本对损失的贡献,从而改善模型在少数类上的表现。

总结

  • 逆频率权重:通过取样本数量的倒数,使得样本数量少的类别得到更高的权重。
  • 归一化:将权重标准化,并确保权重的总和与类别数量一致,保持权重比例的合理性。
  • 转换为张量:将权重转换为 PyTorch 张量,以便在训练过程中使用。

这种权重计算方法确保了在处理类别不平衡问题时,少数类样本对损失函数的贡献增加,从而提高模型对少数类的识别能力。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mfbz.cn/a/618603.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

阮怀俊参与五龙乡黄沙村村企联办“强村公司”

为走好海岛县高质量发展共同富裕特色之路,探索村级集体经济发展新路径、扶持新模式、运行新机制,嵊泗县五龙乡黄沙村股份经济合作社与杭州山舍乡建乡村产业发展有限责任公司联办成“强村公司”。 创始人阮怀俊表示,双方就融合乡域发展和文旅产…

科林Linux_4 信号

#include <signal.h> 信号signal&#xff1a;Linux或Unix系统支持的经典的消息机制&#xff0c;用于处置进程&#xff0c;挂起进程或杀死进程 kill -l #查看系统支持的信号 1~31 Unix经典信号&#xff08;软件开发工程师&#xff09; 32、33信号被系统隐藏&#xf…

虚拟化数据恢复—误还原虚拟机快照怎么办?怎么恢复最新虚拟机数据?

虚拟化技术原理是将硬件虚拟化给不同的虚拟机使用&#xff0c;利用虚拟化技术可以在一台物理机上安装多台虚拟机。误操作或者物理机器出现故障都会导致虚拟机不可用&#xff0c;虚拟机中的数据丢失。 虚拟化数据恢复环境&#xff1a; 有一台虚拟机是由物理机迁移到ESXI上面的&a…

数据库管理-第188期 23ai:怎么用PGQL创建图(20240511)

数据库管理188期 2024-05-10 数据库管理-第188期 23ai:怎么用PGQL创建图&#xff08;20240511&#xff09;1 PGQL创建属性图1.1 PGQL属性图的元数据表1.2 创建一个PGQL属性图1.3 获取PGQL属性图的元数据 2 PGQL属性图3 官方示例演示3.1 插入数据3.2 创建PGQL属性图3.3 通过PGQL…

SpringBoot:SpringBoot原理

SpringBoot高级 SpringBoot配置 配置文件优先级 按照yaml>yml>properties的顺序加载 存在相同配置项,后加载的会覆盖先加载的 加载顺序越靠后,优先级越高 SpringBoot存在其他的多种方式进行配置,如下所示,越靠下优先级越高 1. Default properties (specified by s…

vm虚拟机安装网络适配器驱动卡死/无响应/无限等待状态

大部分原因都是以前的vm没有卸载干净所导致的&#xff0c;只需要使用CCleaner清楚干净就好 使用控制面板里的卸载把VM卸载干净 使用CCleaner软件删除干净注册表&#xff0c;这个软件百度很容易找到&#xff0c;只有十兆左右 打开下载好的软件&#xff0c;不需要注册码&#xff…

长安汽车:基于云器 Lakehouse 的车联网大数据平台建设

近年来随着智能汽车行业的迅速发展&#xff0c;数据也在呈爆炸式增长。长安大数据平台承接了长安在生产上大部分流量应用及日常生产业务应用。本文将分享长安汽车在车联网场景下大数据平台建设面临的一些挑战和具体落地的实践。 主要内容如下&#xff1a; 1. 背景介绍 2. 长…

Java数组:三种初始化

一.静态初始化 代码演示&#xff1a; //静态初始化:创建 赋值int[] a {1,2,3,4,5,6};System.out.println(a[0]); 二.动态初始化 代码演示&#xff1a; //动态初始化:包含默认初始化int[] b new int[10];b[0] 10;System.out.println(b[0]); //10System.out.println(b[1])…

25计算机考研院校数据分析 | 中南大学

中南大学&#xff08;Central South University&#xff09;&#xff0c;位于湖南省长沙市&#xff0c;是中华人民共和国教育部直属的全国重点大学 &#xff0c;中央直管副部级建制&#xff0c;位列国家“双一流”、“985工程”、“211工程”&#xff0c;入选国家“2011计划”牵…

MySQL前缀索引、脏页和干净页、COUNT(*)讨论、表删除内存问题

文章目录 如何加索引如何给身份证号添加索引 SQL语句变慢脏页 (Dirty Pages)干净页 (Clean Pages)为何区分脏页和干净页处理脏页管理策略 flush如何控制 为什么删除表数据后表文件大小不变问题背景核心原因数据存储方式参数影响 解决方案1. 调整innodb_file_per_table设置2. 使…

vs2019 cpp20 规范的线程头文件 <thread> 注释并探讨两个问题

&#xff08;1&#xff09;学习线程&#xff0c;与学习其它容器一样&#xff0c;要多读 STL 库的源码。很多知识就显然而然的明白了。也不用死记硬背一些结论。上面上传了一份注释了一下的 源码。主要是补充泛型推导与函数调用链。基于注释后的源码探讨几个知识点。 STL 库的多…

【SpringBoot】 什么是springboot(三)?springboot使用ajax、springboot使用reids

文章目录 SpringBoot第五章第六章1、springboot使用ajax2、springboot使用reids1、单机版**使用步骤**1-5步67-9步RedisTemplate使用RedisTemplate2、集群版开启集群项目配置1234-5第七章1、springboot文件上传使用步骤1-234-52、springboot邮件发送步骤1-23453、springboot拦截…

【智能算法】最优捕食算法(OFA)原理及实现

目录 1.背景2.算法原理2.1算法思想2.2算法过程 3.结果展示4.参考文献5.代码获取 1.背景 2017年&#xff0c;GY Zhu受到动物行为生态学理论启发&#xff0c;提出了最优捕食算法&#xff08;Optimal Foraging Algorithm, OFA&#xff09;。 2.算法原理 2.1算法思想 OFA灵感来源…

网络编程学习笔记1

文章目录 一、socket1、创建socket2、网络通信流程3、accept()函数4、signal()函数5、recv()函数6、connect()函数 二、I/O多路复用1.select模型2.poll模型3.epoll模型 注 一、socket 1、创建socket int socket(int domain,int type,int protocol); //返回值&#xff1a;一个…

微信小程序的Vant Weapp组件库(WeUI组件库)

一、定义&#xff1a; 是一套开源的微信小程序UI组件库。提供了一整套UI基础组件和业务组件&#xff0c;能够快速地搭配出一套风格统一的页面 二、使用&#xff1a; &#xff08;1&#xff09;&#xff08;找到.eslintrc.js 右键&#xff0c;在内件终端打开&#xff09;打开命…

|Python新手小白中级教程|第二十八章:面向对象编程(类定义语法私有属性类的继承与多态)(4)

文章目录 前言一、类定义语法二、私有方法和私有属性1.私有属性2.私有方法 三、类“继承”1.初识继承2.使用super函数调用父类中构造的东西 四、类“多态”1.多态基础2.子类不同形态3.使用isinstance函数与多态结合判断类型 总结 前言 大家好&#xff0c;我是BoBo仔吖&#xf…

RocketMQ学习笔记(一)

一、基本概念 生产者&#xff08;Producer&#xff09;&#xff1a;也称为消息发布者&#xff0c;是RocketMQ中用来构建并传输消息到服务端的运行实体&#xff0c;举例&#xff1a;发信者主题&#xff08;Topic&#xff09;&#xff1a;Topic是RocketMQ中消息传输和存储的顶层…

【全开源】Java知识付费教育付费资源付费平台公众号小程序源码

特色功能&#xff1a; 多样化的内容呈现&#xff1a;资源付费平台小程序支持图文、音视频、直播等多种形式的内容呈现&#xff0c;为用户提供了丰富的学习体验。直播课程&#xff1a;专家或讲师可以通过小程序进行在线授课&#xff0c;与用户实时互动&#xff0c;增强了学习的…

再有人说数字孪生大屏没有用,用这8条怼回去。

数字孪生大屏之所以受到欢迎&#xff0c;主要有以下几个原因&#xff1a; 实时数据可视化 数字孪生大屏可以将实时数据以直观的可视化形式展示出来&#xff0c;让用户能够一目了然地了解数据的状态和趋势。这样可以帮助用户更好地理解和分析数据&#xff0c;及时做出决策和调…

动态规划算法练习——计数问题

题目描述 给定两个整数 a 和 b&#xff0c;求 a 和 b 之间的所有数字中 0∼9 的出现次数。 例如&#xff0c;a1024&#xff0c;b1032&#xff0c;则 a 和 b 之间共有 9 个数如下&#xff1a; 1024 1025 1026 1027 1028 1029 1030 1031 1032 其中 0 出现 10 次&#xff0c;1 出现…