GiantPandaCV | 提升分类模型acc(一):BatchSizeLARS

本文来源公众号“GiantPandaCV”,仅用于学术分享,侵权删,干货满满。

原文链接:提升分类模型acc(一):BatchSize&LARS

在使用大的bs训练情况下,会对精度有一定程度的损失,本文探讨了训练的bs大小对精度的影响,同时探究Layer-wise Adaptive Rate Scaling(LARS)是否可以有效的提升精度。

论文链接:https://arxiv.org/abs/1708.03888

论文代码: https://github.com/NVIDIA/apex/blob/master/apex/parallel/LARC.py
知乎专栏: https://zhuanlan.zhihu.com/p/406882110

1 引言

如何提升业务分类模型的性能,一直是个难题,毕竟没有99.999%的性能都会带来一定程度的风险,所以很多时候我们只能通过控制阈值来调整准召以达到想要的效果。本系列主要探究哪些模型trick和数据的方法可以大幅度让你的分类性能更上一层楼,不过要注意一点的是,tirck不一定是适用于不同的数据场景的,但是数据处理方法是普适的。本篇文章主要是对于大的bs下训练分类模型的情况,如果bs比较小的可以忽略,直接看最后的结论就好了(这个系列以后的文章讲述的方法是通用的,无论bs大小都可以用)。

2 实验配置

  • 模型:ResNet50

  • 数据:ImageNet1k

  • 环境:8xV100

3 BatchSize对精度的影响

我这里设计了4组对照实验,256, 1024, 2048和4096的batchsize,开了FP16也只能跑到了4096了。采用的是分布式训练,所以单张卡的bs就是bs = total_bs / ngpus_per_node。这里我没有使用跨卡bn,对于bs 64单卡来说理论上已经很大了,bn的作用是约束数据分布,64的bs已经可以表达一个分布的subset了,再大的bs还是同分布的,意义不大,跨卡bn的速度也更慢,所以大的bs基本可以忽略这个问题。但是对于检测的任务,跨卡bn还是有价值的,毕竟输入的分辨率大,单卡的bs比较小,一般4,8,16,这时候统计更大的bn会对模型收敛更好。

很明显可以看出来,当bs增加到4k的时候,acc下降了将近0.8%个点,1k的时候,下降了0.2%个点,所以,通常我们用大的bs训练的时候,是没办法达到最优的精度的。个人建议,使用1k的bs和0.4的学习率最优。

4 LARS(Layer-wise Adaptive Rate Scaling)

4.1. 理论分析

由于bs的增加,在同样的epoch的情况下,会使网络的weights更新迭代的次数变少,所以需要对LR随着bs的增加而线性增加,但是这样会导致上面我们看到的问题,过大的lr会导致最终的收敛不稳定,精度有所下降。

LARS代码如下:

class LARC(object):
    def __init__(self, optimizer, trust_coefficient=0.02, clip=True, eps=1e-8):
        self.optim = optimizer
        self.trust_coefficient = trust_coefficient
        self.eps = eps
        self.clip = clip

    def step(self):
        with torch.no_grad():
            weight_decays = []
            for group in self.optim.param_groups:
                # absorb weight decay control from optimizer
                weight_decay = group['weight_decay'] if 'weight_decay' in group else 0
                weight_decays.append(weight_decay)
                group['weight_decay'] = 0
                for p in group['params']:
                    if p.grad is None:
                        continue
                    param_norm = torch.norm(p.data)
                    grad_norm = torch.norm(p.grad.data)

                    if param_norm != 0 and grad_norm != 0:
                        # calculate adaptive lr + weight decay
                        adaptive_lr = self.trust_coefficient * (param_norm) / (
                                    grad_norm + param_norm * weight_decay + self.eps)

                        # clip learning rate for LARC
                        if self.clip:
                            # calculation of adaptive_lr so that when multiplied by lr it equals `min(adaptive_lr, lr)`
                            adaptive_lr = min(adaptive_lr / group['lr'], 1)

                        p.grad.data += weight_decay * p.data
                        p.grad.data *= adaptive_lr

        self.optim.step()
        # return weight decay control to optimizer
        for i, group in enumerate(self.optim.param_groups):
            group['weight_decay'] = weight_decays[i]

这里有一个超参数,trust_coefficient,也就是公式里面所提到的, 这个参数对精度的影响比较大,实验部分我们会给出结论。

4.2. 实验结论

可以很明显发现,使用了LARS,设置turst_confidence为1e-3的情况下,有着明显的掉点,设置为2e-2的时候,在1k和4k的情况下,有着明显的提升,但是2k的情况下有所下降。

LARS一定程度上可以提升精度,但是强依赖超参,还是需要细致的调参训练。

5 结论

  • 8卡进行分布式训练,使用1k的bs可以很好的平衡acc&speed。

  • LARS一定程度上可以提升精度,但是需要调参,做业务可以不用考虑,刷点的话要好好训练。

6 结束语

本文是提升分类模型acc系列的第一篇,后续会讲解一些通用的trick和数据处理的方法,敬请关注。

THE END !

文章结束,感谢阅读。您的点赞,收藏,评论是我继续更新的动力。大家有推荐的公众号可以评论区留言,共同学习,一起进步。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mfbz.cn/a/691194.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

【CS.SE】端午节特辑:Docker容器化技术详解与实战

端午节, 先祝愿大家端午安康,阖家幸福, 哈哈!这篇讲下Docker这一现代软件开发中不可或缺的技术。软件工程涉及软件开发的整个生命周期,包括需求分析、设计、构建、测试、部署和维护。Docker作为一种容器化技术,直接关联到软件部署…

WWDC 2024前瞻:苹果如何用AI技术重塑iOS 18和Siri

苹果下周的全球开发者大会有望成为这家 iPhone 制造商历史上的关键时刻。在 WWDC 上,这家库比蒂诺科技巨头将展示如何选择将人工智能技术集成到其设备和软件中,包括通过与 OpenAI 的历史性合作伙伴关系。随着重大事件的临近,有关 iOS 18 及其…

uniapp引入uview无代码提示

前提安装正确: 无论是基于npm和Hbuilder X方式安装,一定要配置正确。 解决办法 以前在pages.json里面的写法: "easycom": {"^u-(.*)": "uview-ui/components/u-$1/u-$1.vue" }但是现在hbuilderx要求规范ea…

【小白专用24.6.8】C# 异步任务Task和异步方法async/await详解

一、什么是异步 同步和异步主要用于修饰方法。当一个方法被调用时,调用者需要等待该方法执行完毕并返回才能继续执行,我们称这个方法是同步方法;当一个方法被调用时立即返回,并获取一个线程执行该方法内部的业务,调用…

springboot+minio+kkfileview实现文件的在线预览

在原来的文章中已经讲述过springbootminio的开发过程,这里不做讲述。 原文章地址: https://blog.csdn.net/qq_39990869/article/details/131598884?spm1001.2014.3001.5501 如果你的项目只是需要在线预览图片或者视频那么可以使用minio自己的预览地址进…

C++期末复习总结(2)

目录 1.运算符重载 2.四种运算符重载 (1)关系运算符的重载 (2) 左移运算符的重载 (3)下标运算符的重载 (4)赋值运算符的重载 3.继承的方式 4.继承的对象模型 5.基类的构造 6…

中缀表达式和前缀后缀

在中缀表达式中,操作数可能与两个操作符相结合 但是,想要不带括号无歧义,且不需要考虑运算符优先级和结合性 所以考虑 前缀表达式,波兰表达式 后缀表达式 逆波兰表达式 对于人来说,中缀表达式是最容易读懂的。但是对于…

C语言之字符函数总结(全部!),一篇记住所有的字符函数

前言 还在担心关于字符的库函数记不住吗?不用担心,这篇文章将为你全面整理所有的字符函数的用法。不用记忆,一次看完,随查随用。用多了自然就记住了 字符分类函数和字符转换函数 C语言中有一系列的函数是专门做字符分类和字符转换…

无人机、机器人10公里WiFi远距离图传模块,实时高清视频传输,飞睿CV5200模组方案,支持mesh自组网模块

在快速发展的物联网时代,远距离无线通信技术已成为连接各种智能设备的关键。无人机、安防监控、机器人等领域对数据传输的距离和速度要求越来越高。 公里级远距离WiFi模组方案可以通过多种技术和策略的结合来实现无人机和机器人之间的高效通信传输。 飞睿智能CV52…

【Java毕业设计】基于JavaWeb企业违规信息综合管理系统

文章目录 摘 要ABSTRACT目 录1 概述1.1 研究背景及意义1.2 国内外研究现状1.3 拟研究内容1.4 系统开发技术1.4.1 Java编程语言1.4.2 SpringBoot框架1.4.3 MySQL数据库1.4.4 B/S结构1.4.5 MVC模式 2 系统需求分析2.1 可行性分析2.2 功能需求分析 3 系统设计3.1 功能模块设计3.2 …

Python应用开发——30天学习Streamlit Python包进行APP的构建(5)

上几次我们已经将一些必备的内容进行了快速的梳理,让我们掌握了streanlit的凯快速上手,接下来我们将其它的一些基础函数再做简单的梳理,以顺便回顾我们未来可能用到的更丰富的函数来实现应用的制作。 st.write_stream 将生成器、迭代器或类似流的序列串流到应用程序中。 …

数据总线、位扩展、字长

数据总线(Data Bus) 定义 数据总线是计算机系统中的一组并行信号线,用于在计算机内部传输数据。这些数据可以在中央处理器(CPU)、内存和输入/输出设备之间传输。 作用 数据传输:数据总线负责在计算机各…

Elasticsearch 认证模拟题 - 15

一、题目 原索引 task1 的字段 title 字段包含单词 The,查询 the 可以查出 1200 篇文档。重建 task1 索引为 task1_new,重建后的索引, title 字段查询 the 单词,不能匹配到任何文档。 PUT task1 {"mappings": {"…

Jar包部署为linux系统服务

文章目录 引言I 以系统服务的方式部署(推荐)1.1 创建systemd服务1.2 SSH上传jar包,并重启服务1.3 收集自定义systemd服务的日志【可选】II 脚本部署方式(不推荐)2.1 启动脚本2.2 关闭脚本2.3 SSH上传jar包,并重启服务III 打包3.1 build中的plugins中标签的含义3.2 jar中没…

代码随想录算法训练营第36期DAY50

DAY50 如果写累了就去写套磁信吧。 198打家劫舍 class Solution {public: int rob(vector<int>& nums) { vector<int> dp(nums.size()); dp[0]nums[0]; if(nums.size()1) return nums[0]; dp[1]max(nums[0],nums[1]); …

【Unity UGUI】Screen.safeArea获取异形屏数据失败

Screen.safeArea获取不到异形屏的尺寸位置等数据 检查AndroidManifest.xml文件是否有设置&#xff1a;android:theme"style/UnityThemeSelector"&#xff0c;没有加上即可 android:theme"style/UnityThemeSelector"

基于大模型 Gemma-7B 和 llama_index,轻松实现 NL2SQL

节前&#xff0c;我们星球组织了一场算法岗技术&面试讨论会&#xff0c;邀请了一些互联网大厂朋友、参加社招和校招面试的同学. 针对算法岗技术趋势、大模型落地项目经验分享、新手如何入门算法岗、该如何准备、面试常考点分享等热门话题进行了深入的讨论。 汇总合集&…

时光正好保剑锋的抱治百病与成年人的世界

《时光正好》&#xff1a;保剑锋的“抱治百病”与成年人的世界在繁忙的都市里&#xff0c;每个角落上演着各自的人生戏码。而在这些戏码中&#xff0c;由保剑锋主演的《时光正好》无疑成为了近期引人注目的焦点。这部电视剧以其真实而深刻的剧情&#xff0c;让我们看到了成年人…

用于认知负荷评估的集成时空深度聚类(ISTDC)

Integrated Spatio-Temporal Deep Clustering (ISTDC) for cognitive workload assessment 摘要&#xff1a; 本文提出了一种新型的集成时空深度聚类&#xff08;ISTDC&#xff09;模型&#xff0c;用于评估认知负荷。该模型首先利用深度表示学习&#xff08;DRL&#xff09;…

Debug-014-nginx代理路径的一条规则

直接上图&#xff1a; 今天看禹神的前端视频&#xff0c;讲到在nginx中代理路径的时候&#xff0c;有一个规则&#xff1a; 如果/dev和下面的proxy_pass路径最后都带‘/’,那么就是匹配到dev之后要删除dev,然后再带着后面的路径&#xff1b;如果/dev和下面的proxy_pass路径最后…