【论文阅读】Long-Tailed Recognition via Weight Balancing(CVPR2022)附MaxNorm的代码

目录

  • 论文
  • 使用方法
    • weight decay
    • MaxNorm
  • 如果使用原来的代码报错的可以看下面这个

论文

问题:真实世界中普遍存在长尾识别问题,朴素训练产生的模型在更高准确率方面偏向于普通类,导致稀有的类别准确率偏低。
key:解决LTR的关键是平衡各方面,包括数据分布、训练损失和学习中的梯度。
文章主要讨论了三种方法: L2normalization, weight decay, and MaxNorm
本文提出了一个两阶段训练的范式
a. 利用调节权重衰减的交叉熵损失学习特征。
b. 通过调节权重衰减和Max Norm使用类平衡损失学习分类器。
一些有用的看法

  1. 研究表明,与联合训练特征学习和分类器学习的模型相比,解耦特征学习和分类器学习导致了显著的改进。
  2. 根据基准测试结果,通过集成专家模型或采用主动数据增强技术的自监督预训练来实现最好精度。
  3. 研究发现,SGD动量导致LTR出现问题,阻碍了进一步改善。
  4. 最近,Kang等人令人信服地证明了阶段性训练对LTR很重要。
  5. 权重衰减有助于学习隐藏层的平衡权重。
  6. 重要的是,我们的探索发现,虽然在分类器上使用L2规范化约束进行训练比简单训练有所改进,但它的表现不如下面描述的其他两个正则化。
  7. 与严格将所有滤波器权重的范数值设置为1的L2归一化不同,MaxNorm放松了这一约束,允许权重在训练期间在范数球内移动。
  8. 权重衰减中,不同数据集的最优λ各不相同——较大的数据集需要较小的权重衰减,直观地说,因为在更多数据上学习有助于泛化,因此需要较少的正则化。
    单阶段使用不平衡损失训练效果不好的原因:虽然他们没有解释为什么具有类平衡损失的单阶段训练表现不佳,但直观地说,这是因为类平衡损失人为地放大了从罕见的类训练数据计算的梯度,这损害了特征表示学习,从而损害了最终的LTR性能。
    本文作者使用了weight decay和max norm两种方法结合,因为发现两个结合效果更好。让模型不同类之间权重相差不会很大的同时,还能让这些权重缓慢增加。
    下面这幅图就是解释了这些方法的特点。
    在这里插入图片描述
    第一个就是普通方法训练的,它常见的类别权重增长快。
    第二个是L2 normalization,它把所有类别的权重都限定在一个常数。
    第三个是权重衰减,它的所有类的权重小,而且权重在增长。
    第四个是MaxNorm,它限制最大的权重。
    第五个是权重衰减和MaxNorm,会导致范数中的权重较小且平衡。

使用方法

weight decay

先定义好权重衰减的值。

weight_decay = 0.1 #weight decay value

然后在优化器中调用。Adam还有其他的都有weight_decay。

optimizer = optim.SGD([{'params': active_layers, 'lr': base_lr}], lr=base_lr, momentum=0.9, weight_decay=weight_decay)

MaxNorm

就是这个论文中的regularizers.py中的代码。只要会使用就好。就是要是不是作者代码中的模型的话,model.encoder.fc还需要根据自己的代码修改。

#使用前先定义好初始化好
pgdFunc = MaxNorm_via_PGD(thresh=thresh)
pgdFunc.setPerLayerThresh(model) # set per-layer thresholds这个是计算模型每一层的权重的阈值,这篇论文中只计算最后线性层的权重,并对最后线性层的权重进行限制

当模型训练一个epoch结束后,对已经更新完毕的模型权重进行限制,如果超过阈值就进行更新,让权重在最大范数的约束下。

 if pgdFunc:# Projected Gradient Descent
     pgdFunc.PGD(model)#对权重进行限制
import torch
import torch.nn as nn
import math
# The classes below wrap core functions to impose weight regurlarization constraints in training or finetuning a network.

class MaxNorm_via_PGD():
    def __init__(self, thresh=1.0, LpNorm=1, tau=1):
        self.thresh = thresh
        self.LpNorm = LpNorm
        self.tau = tau
        self.perLayerThresh = []

    def setPerLayerThresh(self, model):#根据指定的模型设置每层的阈值
        #set pre-layer thresholds
        self.perLayerThresh = []

        for curLayer in [model.encoder.fc.weight, model.encoder.fc.bias]:#遍历模型的最后两层
            curparam = curLayer.data#获取当前层的数据
            if len(curparam.shape) <= 1:#如果层只有一个维度,是一个偏置或者是一个1D的向量,则设置这一层的阈值为无穷大,继续下一层
                self.perLayerThresh.append(float('inf'))
                continue
            curparam_vec = curparam.reshape((curparam.shape[0], -1))#如果不是,把权重张量展开
            neuronNorm_curparam = torch.linalg.norm(curparam_vec, ord=self.LpNorm, dim=1).detach().unsqueeze(-1)#沿着第一维计算P番薯,结果存储
            curLayerThresh = neuronNorm_curparam.min() + self.thresh*(neuronNorm_curparam.max() - neuronNorm_curparam.min())#计算每一层的阈值及神经元范数的最小值加上最大值和最小值之间的缩放差
            self.perLayerThresh.append(curLayerThresh)#每层阈值存储

    def PGD(self, model):#定义PGD函数,用于在模型的参数上执行投影梯度下降,试试最大范数约束
        if len(self.perLayerThresh) == 0:#如果每层的阈值是空,用setPerLayerThresh方法初始化
            self.setPerLayerThresh(model)
        for i, curLayer in enumerate([model.encoder.fc.weight, model.encoder.fc.bias]):#遍历模型的最后两层
            curparam = curLayer.data#获取当前层的数据张量值
            curparam_vec = curparam.reshape((curparam.shape[0], -1))#变成一维
            neuronNorm_curparam = (torch.linalg.norm(curparam_vec, ord=self.LpNorm, dim=1)**self.tau).detach().unsqueeze(-1)#在最后加一维
            #计算权重张量中每行神经元番薯的tau次方
            scalingVect = torch.ones_like(curparam)#创建一个形状与当前层数据相同的张量,用1初始化
            curLayerThresh = self.perLayerThresh[i]#获取阈值

            idx = neuronNorm_curparam > curLayerThresh#创建bool保存超过阈值的神经元
            idx = idx.squeeze()#
            tmp = curLayerThresh / (neuronNorm_curparam[idx].squeeze())**(self.tau)#根据每层的阈值和超过阈值的神经元番薯计算缩放因子
            for _ in range(len(scalingVect.shape)-1):#扩展缩放因子以匹配当前层数据的维度
                tmp = tmp.unsqueeze(-1)

            scalingVect[idx] = torch.mul(scalingVect[idx],tmp)
            curparam[idx] = scalingVect[idx] * curparam[idx]
            curparam[idx] = scalingVect[idx] * curparam[idx]#通过缩放值更新当前层的数据,以便对超过阈值的神经元进行缩放。完成权重更新


如果使用原来的代码报错的可以看下面这个

我的网络只有一层是线性层idx = idx.squeeze(),idx是(1,1)形状的,squeeze就没了,所以报错,如果有这个原因的可以改成idx = idx.squeeze(1)。maxnorm只改最后两层/一层权重所以,定义了一个列表存储线性层只取最后两层或者一层。

class MaxNorm_via_PGD():
    # learning a max-norm constrainted network via projected gradient descent (PGD)
    def __init__(self, thresh=1.0, LpNorm=2, tau=1):
        self.thresh = thresh
        self.LpNorm = LpNorm
        self.tau = tau
        self.perLayerThresh = []

    def setPerLayerThresh(self, model):
        # set per-layer thresholds
        self.perLayerThresh = []#存储每一层的阈值
        self.last_two_linear_layers = []#提取线性层
        for name, module in model.named_children():
            if isinstance(module, nn.Linear):
                self.last_two_linear_layers.append(module)

        for linear_layer in self.last_two_linear_layers[-min(2, len(self.last_two_linear_layers)):]:  # here we only apply MaxNorm over the last two layers
            curparam = linear_layer.weight.data
            if len(curparam.shape) <= 1:
                self.perLayerThresh.append(float('inf'))
                continue
            curparam_vec = curparam.reshape((curparam.shape[0], -1))
            neuronNorm_curparam = torch.linalg.norm(curparam_vec, ord=self.LpNorm, dim=1).detach().unsqueeze(-1)
            curLayerThresh = neuronNorm_curparam.min() + self.thresh * (
                        neuronNorm_curparam.max() - neuronNorm_curparam.min())
            self.perLayerThresh.append(curLayerThresh)

    def PGD(self, model):
        if len(self.perLayerThresh) == 0:
            self.setPerLayerThresh(model)
        for i, curLayer in enumerate([self.last_two_linear_layers[-min(2,
                                                             len(self.last_two_linear_layers))]]):  # here we only apply MaxNorm over the last two layers

            curparam = curLayer.weight.data

            curparam_vec = curparam.reshape((curparam.shape[0], -1))
            neuronNorm_curparam = (
                        torch.linalg.norm(curparam_vec, ord=self.LpNorm, dim=1) ** self.tau).detach().unsqueeze(-1)
            scalingVect = torch.ones_like(curparam)
            curLayerThresh = self.perLayerThresh[i]

            idx = neuronNorm_curparam > curLayerThresh
            idx = idx.squeeze(1)
            tmp = curLayerThresh / (neuronNorm_curparam[idx].squeeze()) ** (self.tau)
            for _ in range(len(scalingVect.shape) - 1):
                tmp = tmp.unsqueeze(-1)

            scalingVect[idx] = torch.mul(scalingVect[idx], tmp)
            curparam[idx] = scalingVect[idx] * curparam[idx]

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/358409.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

网络原理-TCP/IP(1)

应用层 我们之前编写完了基本的java socket, 要知道,我们之前所写的所有代码都在应用层中,都是为了完成某项业务,如翻译等.关于应用层,后面会有专门的讲解,在此处先讲一下基础知识. 应用层对应着应用程序,是程序员打交道最多的一层,调用系统提供的网络api写出的代码都是应用层…

如何安装配置HFS并实现无公网ip远程访问本地电脑共享文件

文章目录 前言1.下载安装cpolar1.1 设置HFS访客1.2 虚拟文件系统 2. 使用cpolar建立一条内网穿透数据隧道2.1 保留隧道2.2 隧道名称2.3 成功使用cpolar创建二级子域名访问本地hfs 总结 前言 在大厂的云存储产品热度下降后&#xff0c;私人的NAS热度快速上升&#xff0c;其中最…

【webrtc】m98 : vs2019 直接构建webrtc及moduletest工程 2

字数有限制,我们继续 【webrtc】m98 : vs2019 直接构建webrtc及unitest工程 1modules_unittests 构建 Build started... 1>------ Build started: Project: modules_unittests, Configuration: GN Win32 ------ 1>ninja: Entering directory `G:\CDN\rtcCli\m98\src\o…

导出Mysql数据库表名和字段并合并成一个word

参考链接&#xff1a; 导出MySQL数据库所有库和字段注释及相关信息为word文档——工具类 java - Apache POI - How to copy tables from one docx to another docx - Stack Overflow 领导让我研究下一个低代码平台的代码&#xff0c;我就想着做一个把数据库字段直接导出来的…

STM32F407移植OpenHarmony笔记4

上一篇写到make menuconfig报错&#xff0c;继续开整。 make menuconfig需要/device/soc/*下面有对应的Kconfig文件。 直接去gitee下载stm32的配置文件拿来参考用。 先提取Kconfig文件&#xff0c;后面再添加其它文件。https://gitee.com/openharmony/device_soc_st/tree/Open…

【React教程】(1) React简介、React核心概念、React初始化

目录 ReactReact 介绍React 特点React 的发展历史React 与 Vue 的对比技术层面开发团队社区Native APP 开发 相关资源链接 EcmaScript 6 补充React 核心概念组件化虚拟 DOM 起步初始化及安装依赖Hello World React React 介绍 React 是一个用于构建用户界面的渐进式 JavaScrip…

day33WEB 攻防-通用漏洞文件上传中间件解析漏洞编辑器安全

目录 一&#xff0c;中间件文件解析漏洞-IIS&Apache&Nginx -IIS 6 7 文件名 目录名 -Apache 换行解析 配置不当 1、换行解析-CVE-2017-15715 2、配置不当-.htaccess 配置不当 -Nginx 文件名逻辑 解析漏洞 1、文件名逻辑-CVE-2013-4547 2、解析漏洞-nginx.conf …

LLM漫谈(四)| ChatDOC:超越ChatPDF性能并支持更多功能的阅读聊天工具

在过去的一年里&#xff0c;ChatGPT的兴起催生了许多基于GPT的人工智能工具&#xff0c;其中Chat PDF工具得到了广泛关注。这些工具对知识密集型专业人员来说尤其有价值&#xff0c;大大提高了生产力。随着Chat PDF工具的激增&#xff0c;选择正确的工具变得至关重要。 接下来&…

视频转GIF动图实践, 支持长视频转GIF

背景 找了很多GIF动图制作的工具&#xff0c;比如将视频转成GIF, 或者将一系列图片转成GIF, 增加背景文案等等功能。很多收费或者用的一些三方库有点点卡顿&#xff0c;或者需要安装一个软件&#xff0c;所以就自己做一款纯前端页面级别的 视频转 GIF 动图工具。 最开始找到一…

网站地址怎么改成HTTPS?

现在&#xff0c;所有类型的网站都需要通过 HTTPS 协议进行安全连接&#xff0c;而实现这一目标的唯一方法是使用 SSL 证书。如果您不将 HTTP 转换为 HTTPS&#xff0c;浏览器和应用程序会将您网站的连接标记为不安全。 但用户询问如何将我的网站从 HTTP 更改为 HTTPS。在此页…

米贸搜|如何“有效利用”Facebook动态广告?!

什么是动态广告&#xff1f; 在我们的日常在线购物经验中&#xff0c;我们经常遇到这样的情况&#xff1a;在某宝平台上浏览或搜索一款产品后&#xff0c;不久你就会发现&#xff0c;平台开始向你推荐与你之前浏览过的商品相似的其他商品。这种商品推荐通常非常符合你的口味&a…

用可视化表单设计轻松实现流程化办公!

在现代化职场办公中&#xff0c;低代码技术平台的利用程度越来越高&#xff0c;成为大家实现流程化办公和数字化转型的友好合作伙伴。作为专业的服务商&#xff0c;流辰信息潜心研发低代码技术平台产品及可视化表单设计工具&#xff0c;轻松助力越来越多行业的客户朋友提高办公…

C语言应用实例——贪吃蛇

&#xff08;图片由AI生成&#xff09; 0.贪吃蛇游戏背景 贪吃蛇游戏&#xff0c;最早可以追溯到1976年的“Blockade”游戏&#xff0c;是电子游戏历史上的一个经典。在这款游戏中&#xff0c;玩家操作一个不断增长的蛇&#xff0c;目标是吃掉出现在屏幕上的食物&#xff0c…

多线程事务如何回滚?

背景介绍 1&#xff0c;最近有一个大数据量插入的操作入库的业务场景&#xff0c;需要先做一些其他修改操作&#xff0c;然后在执行插入操作&#xff0c;由于插入数据可能会很多&#xff0c;用到多线程去拆分数据并行处理来提高响应时间&#xff0c;如果有一个线程执行失败&am…

Seata部署

1、下载seata wget https://github.com/apache/incubator-seata/releases/download/v1.4.2/seata-server-1.4.2.zip 2、执行服务端SQL -- -------------------------------- The script used when storeMode is db -------------------------------- -- the table to store Gl…

12306 真的很拉跨吗?春运是对它最大的误解!

春节降至&#xff0c;大家都抢到火车票了吗&#xff1f;马上就要迎来春节&#xff0c;是不是都在吐槽 12306 的种种不好&#xff0c;它真的有这么拉跨吗&#xff1f; 其实不然&#xff0c;每到各种节假日&#xff0c;都是对 12306 最大的误解&#xff01; 特别是春运&#xf…

多个微信号难管理?

很多客户都遇到了这几个问题&#xff1a; 1、公司分配给员工有很多个微信账号&#xff0c;但是管理起来很难&#xff1f; 2、多个微信号要不来回切换登录&#xff0c;要不需要挂在多台手机了&#xff1f; 3、每个号每天都需要发朋友圈&#xff0c;工作量大&#xff1f;有时还…

GP232RL国产USB串口如何兼容FT232RL开发资料

GP232RL是最新加入 ftdi 系列 usb 接口集成电路设备的设备。 232r是一个 usb 到串行 uart 接口&#xff0c;带有可选的时钟发生器输出&#xff0c;以及新的 ftdichip-idTM 安全加密器特性。此外&#xff0c;还提供了异步和同步位崩接口模式。 通过将外部 eeprom、时钟电路和 …

Java API 操作 HDFS

Java API 操作HDFS一般有两种方式&#xff1a; 使用HDFS客户端配置文件自动配置 Java 代码中配置 一 使用HDFS客户端配置 1.1 下载HDFS客户端配置 1.2 创建Maven项目 创建Maven项目&#xff0c;将下载的客户端配置文件 core-site.xml、hdfs-site.xml 放入resources目录下&…

用AI工具一键生成原创文案的方法

一键生成原创文案对于文案工作者来说它是一种高效率创作文案内容的方法。文案工作者知道创作文案是一件消耗精力和时间的事情&#xff0c;遇到没有创作灵感&#xff0c;想要写一篇高质量的文案内容简直难上加难&#xff0c;因此&#xff0c;互联网上出现了一键生成原创文案的方…