EMA训练微调

就是取前几个epoch的weight的平均值,可以缓解微调时的灾难性遗忘(因为新数据引导,模型权重逐渐,偏离训练时学到的数据分布,忘记之前学好的先验知识)
在这里插入图片描述

class EMA():
    def __init__(self, model, decay):
        self.model = model
        self.decay = decay  # decay rate
        self.shadow = {}  # old weight
        self.backup = {}  # new weight
 
    def register(self):  # deep copy weight for init
        for name, param in self.model.named_parameters():
            if param.requires_grad:
                self.shadow[name] = param.data.clone()
 
    def update(self):  # ema:average weight for train
        for name, param in self.model.named_parameters():
            if param.requires_grad:
                assert name in self.shadow
                new_average = (1.0 - self.decay) * param.data + self.decay * self.shadow[name]
                self.shadow[name] = new_average.clone()
 
    def apply_shadow(self):  # load old weight for eval begin
        for name, param in self.model.named_parameters():
            if param.requires_grad:
                assert name in self.shadow
                self.backup[name] = param.data
                param.data = self.shadow[name]
 
    def restore(self):  # load new weight for eval end
        for name, param in self.model.named_parameters():
            if param.requires_grad:
                assert name in self.backup
                param.data = self.backup[name]
        self.backup = {}
 
# 初始化
ema = EMA(model, 0.999)
ema.register()
 
# 训练过程中,更新完参数后,同步update shadow weights
def train():
    optimizer.step()
    ema.update()
 
# eval前,apply shadow weights;eval之后,恢复原来模型的参数
def evaluate():
    ema.apply_shadow()
    # evaluate
    ema.restore()

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/201448.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

RabbitMQ消息模型之Sample

Hello World Hello World是官网给出的第一个模型,使用的交换机类型是直连direct,也是默认的交换机类型。 在上图的模型中,有以下概念: P:生产者,也就是要发送消息的程序C:消费者:消…

机器学习:领域自适应学习

训练一个分类器是小问题 上难度 训练数据和测试数据不一致,比如训练数据是黑白的,测试时彩色的,结果准确率非常低。 训练数据和测试数据有点差距的时候,能不能效果也能好呢?这就用到了领域自使用domain adptation 用一…

Windows 11的新功能不适用于所有人,但对将要使用的人来说非常酷

正如一个新的预览版本所示,Windows 11即将为那些使用手写笔的人添加一些智能功能,以及其他改进。 这是预览版22635.2776(也称为KB5032292),已推出Beta频道,这是发布预览版之前的最后一个测试方法&#xff…

速速报名!请查收 2023 龙蜥操作系统大会超全指南

亲爱的小伙伴们,大家好!我是大家的老朋友小龙!自 2023 龙蜥操作系统大会宣布启动以来,小龙收到了来自四面八方的诸多期待和小心心。首届龙蜥大会正如火如荼地进行中,为表示对关注社区的每一位小伙伴由衷的感谢&#xf…

Ubuntu安装ssh

Ubuntu安装ssh服务器 一、ssh ssh:安全外壳协议(secure shell)的缩写,安全外壳协议(安全的shell),是一个计算机网络协议(默认端口号为22)。通过ssh协议可以在客户端安全(提供身份认…

k8s中Pod控制器简介,ReplicaSet、Deployment、HPA三种处理无状态pod应用的控制器介绍

目录 一.Pod控制器简介 二.ReplicaSet(简写rs) 1.简介 (1)主要功能 (2)rs较完整参数解释 2.创建和删除 (1)创建 (2)删除 3.扩容和缩容 &#xff08…

【Java SE】带你在String类世界中遨游!!!

🌹🌹🌹我的主页🌹🌹🌹 🌹🌹🌹【Java SE 专栏】🌹🌹🌹 🌹🌹🌹上一篇文章:带你走近Java的…

【C 语言经典100例】C 练习实例9

题目&#xff1a;要求输出国际象棋棋盘。 程序分析&#xff1a;国际象棋棋盘由64个黑白相间的格子组成&#xff0c;分为8行*8列。用i控制行&#xff0c;j来控制列&#xff0c;根据ij的和的变化来控制输出黑方格&#xff0c;还是白方格。 #include<stdio.h>int main() {…

从 Elasticsearch 到 SelectDB,观测云实现日志存储与分析的 10 倍性价比提升

作者&#xff1a;观测云 CEO 蒋烁淼 & 飞轮科技技术团队 在云计算逐渐成熟的当下&#xff0c;越来越多的企业开始将业务迁移到云端&#xff0c;传统的监控和故障排查方法已经无法满足企业的需求。在可观测理念逐渐深入人心的当下&#xff0c;人们越来越意识到通过多层次、…

YOLOv5小目标检测层

目录 一、原理 二、yaml配置文件 一、原理 小目标检测层,就是增加一个检测头,增加一层锚框,用来检测输入图像中像素较小的目标 二、yaml配置文件 # YOLOv5 🚀 by Ultralytics, GPL-3.0 license# Parameters nc: 3 # number of classes depth_multiple: 0.33 # model…

案例,linux环境下OpenCV+Java,实现证件照在线更换背景色

先看效果&#xff08;图片来自网络&#xff0c;如有侵权&#xff0c;请联系作者删除&#xff09; 主要是通过java实现的&#xff0c;linux环境编译安装opencv及证件照背景色更换的核心算法在前面一篇文章中有写到。 目前算法还有瞎呲&#xff0c;当照片光线不均的时候会出现误…

低调使用。推荐一个 GPT4 Turbo、Vision、GPTs、DELL·E3 等所有最新功能同步可用国内网站

在 11 月 6 日&#xff0c;万众期待的 OpenAI DevDay&#xff0c;ChatGPT 发布了一系列新的产品&#xff0c;其中推出了 GPT4 Turbo&#xff0c;并且将GPT4 Vision&#xff0c;DELLE3 等等能力全部集合到一起&#xff0c;不需要再分开使用&#xff0c;原来的局限的文本聊天也进…

创业公司or大厂怎么选?不是凡尔赛,一个技巧让你涨薪10W!

最近总有一些特别“凡尔赛”的发几个 offer 问我选择哪个&#xff1f;其中比较典型的一个问题就是&#xff1a; “一个是处于上升期的创业型公司 &#xff0c;一个行业大厂&#xff0c;薪资待遇差不多&#xff0c;到底该如何进行选择和取舍呢&#xff1f;“ 这个问题不是个别…

Spring---对象的存储和读取

文章目录 Spring对象的存储创建Bean对象将Bean对象存储到spring中添加配置文件存储Bean对象 Spring对象的读取得到Spring上下文对象从Spring中取出Bean对象使用Bean对象 Spring对象的存储 创建Bean对象 Bean对象其实就是一个普通的Java对象。我们按照创建Java对象的方式来创建…

48个代码大模型汇总,涵盖原始、改进、专用、微调4大类

代码大模型具有强大的表达能力和复杂性&#xff0c;可以处理各种自然语言任务&#xff0c;包括文本分类、问答、对话等。这些模型通常基于深度学习架构&#xff0c;如Transformer&#xff0c;并使用预训练目标&#xff08;如语言建模&#xff09;进行训练。 在对大量代码数据的…

配电网重构单时段+多时段(附带matlab代码)

配电网重构单时段多时段 对于《主动配电网最优潮流研究及其应用实例》的基本复现 简介&#xff1a;最优潮流研究在配电网规划运行中不可或缺&#xff0c;且在大量分布式能源接入的主动配电网环境下尤为重要。传统的启发式算法在全局最优解和求解速度上均无法满足主动配电网运行…

什么是计算机病毒?

计算机病毒 1. 定义2. 计算机病毒的特点3. 计算机病毒的常见类型和攻击方式4. 如何防御计算机病毒 1. 定义 计算机病毒是计算机程序编制者在计算机程序中插入的破坏计算机功能或者破坏数据&#xff0c;影响计算机使用并且能够自我复制的一组计算机指令或程序代码。因其特点与生…

【面试】typescript

目录 为什么用TypeScript&#xff1f; TS和JS的区别 控制类成员可见性的访问关键字&#xff1f; public protected&#xff09;&#xff0c;该类及其子类都可以访问它们。 但是该类的实例无法访问。 私有&#xff08;private&#xff09;&#xff0c;只有类的成员可以访问…

什么是媒体发布?媒体发布平台有哪些?

传媒如春雨&#xff0c;润物细无声&#xff0c;大家好&#xff0c;我是51媒体网胡老师。 一&#xff0c;什么是媒体发布&#xff1f; 媒体发布是指利用互联网媒体及移动端媒体和传统媒体发布关于人物、品牌、商业公司等的新闻及推广软文和传记等的行为。媒体平台可以是电视、…

C51--LCD1602显示屏

LCD602显示&#xff1a; 1、概述 LCD602是一种工业字符型液晶&#xff0c;能够同时显示16x02&#xff0c;即32字符&#xff08;16列&#xff0c;2行&#xff09; 2、引脚&#xff1a; VSS&#xff1a;电源地VDD&#xff1a;电源正极——5V电源VO&#xff1a; 液晶显示偏压 …