模型的权值平均的原理和Pytorch的实现

一、前言

模型权值平均是一种用于改善深度神经网络泛化性能的技术。通过对训练过程中不同时间步的模型权值进行平均,可以得到更宽的极值点(optima)并提高模型的泛化能力。 在PyTorch中,官方提供了实现模型权值平均的方法。

这里我们首先介绍指数移动平均(EMA)方法,它使用一个衰减系数来平衡当前权值和先前平均权值。其次,介绍了随机加权平均(SWA)方法,它通过将当前权值与先前平均权值进行加权平均来更新权值。最后,介绍了Tanh自适应指数移动EMA算法(T_ADEMA),它使用Tanh函数来调整衰减系数,以更好地适应训练过程中的不同阶段。

为了方便使用这些权值平均方法,我将官方的代码写成了一个基类AveragingBaseModel,以此引出EMAModel、SWAModel和T_ADEMAModel等方法。这些类可以用于包装原始模型,并在训练过程中更新平均权值。 为了验证这些权值平均方法的效果,我还在ResNet18模型上进行了简单的实验。实验结果表明,使用权值平均方法可以提高模型的准确率,尤其是在训练后期。

但请注意,博客中所提供的代码示例仅用于演示权值平均的原理和PyTorch的实现方式,并不能保证在所有情况下都能取得理想的效果。实际应用中,还需要根据具体任务和数据集来选择适合的权值平均方法和参数设置。

二、算法介绍

基类实现

这里我们的基类完全是参照于torch源码部分,仅仅进行了一点细微的修改。

它首先通过de_parallel函数将原始模型转换为单个GPU模型。de_parallel函数用于处理并行模型,将其转换为单个GPU模型。然后,它将转换后的模型复制到适当的设备(CPU或GPU)上(这一步很重要,问题大多数就是因为计算不匹配),并注册一个名为n_averaged的缓冲区,用于跟踪已平均的次数。

在forward方法中,它简单地将调用传递给转换后的模型。update方法首先获取当前模型和新模型的参数,并将它们转换为可迭代对象,用于更新平均权值。它接受一个新的模型作为参数,并将其与当前模型(已平均的权值)进行比较。

from copy import deepcopy
from pyzjr.core.general import is_parallel
import itertools
from torch.nn import Module

def de_parallel(model):
    """
    将并行模型(DataParallel 或 DistributedDataParallel)转换为单 GPU 模型。
    """
    return model.module if is_parallel(model) else model

class AveragingBaseModel(Module):
    def __init__(self, model, cuda=False, avg_fn=None, use_buffers=False):
        super(AveragingBaseModel, self).__init__()
        device = 'cuda' if cuda and torch.cuda.is_available() else 'cpu'
        self.module = deepcopy(de_parallel(model))
        self.module = self.module.to(device)
        self.register_buffer('n_averaged',
                             torch.tensor(0, dtype=torch.long, device=device))
        self.avg_fn = avg_fn
        self.use_buffers = use_buffers

    def forward(self, *args, **kwargs):
        return self.module(*args, **kwargs)

    def update(self, model):
        self_param = itertools.chain(self.module.parameters(), self.module.buffers() if self.use_buffers else [])
        model_param = itertools.chain(model.parameters(), model.buffers() if self.use_buffers else [])

        self_param_detached = [p.detach() for p in self_param]
        model_param_detached = [p.detach().to(p_averaged.device) for p, p_averaged in zip(model_param, self_param_detached)]

        if self.n_averaged == 0:
            for p_averaged, p_model in zip(self_param_detached, model_param_detached):
                p_averaged.copy_(p_model)

        if self.n_averaged > 0:
            for p_averaged, p_model in zip(self_param_detached, model_param_detached):
                n_averaged = self.n_averaged.to(p_averaged.device)
                p_averaged.copy_(self.avg_fn(p_averaged, p_model, n_averaged))

        if not self.use_buffers:
            for b_swa, b_model in zip(self.module.buffers(), model.buffers()):
                b_swa.copy_(b_model.to(b_swa.device).detach())

        self.n_averaged += 1

若当前模型尚未进行过平均(即n_averaged为0),则直接将新模型的参数复制到当前模型中。若当前模型已经进行过平均,则通过avg_fn函数计算当前模型和新模型的加权平均,并将结果复制到当前模型中。如果use_buffers为True,则会将缓冲区从新模型复制到当前模型。最后,n_averaged增加1,表示已进行一次平均。

指数移动平均(EMA)

EMA被用于根据当前参数和之前的平均参数来更新平均参数。其计算公式如下所示:

EMA_{param} = decay * EMA_{param} + (1 - decay) * current_{param}

这里的EMA param是当前的平均参数,current param是当前的参数,decay是一个介于0和1之间的衰减因子,它用于控制当前参数对平均参数的贡献程度。decay越接近1,平均参数对当前参数的影响就越小,反之亦是。

def get_ema_avg_fn(decay=0.999):
    @torch.no_grad()
    def ema_update(ema_param, current_param, num_averaged):
        return decay * ema_param + (1 - decay) * current_param
    return ema_update

class EMAModel(AveragingBaseModel):
    def __init__(self, model, cuda = False, decay=0.9, use_buffers=False):
        super().__init__(model=model, cuda=cuda, avg_fn=get_ema_avg_fn(decay), use_buffers=use_buffers)

随机加权平均(SWA)

SWA通过对神经网络的权重进行平均来改善模型的泛化能力。其计算公式如下所示:

SWA_{param} = avg_{param} + (current_{param} - avg_{param}) / (num_{avg} + 1)

SWA param是新的平均参数,averaged param是之前的平均参数,current param是当前的参数,num avg是已经平均的参数数量。

def get_swa_avg_fn():
    @torch.no_grad()
    def swa_update(averaged_param, current_param, num_averaged):
        return averaged_param + (current_param - averaged_param) / (num_averaged + 1)
    return swa_update

class SWAModel(AveragingBaseModel):
    def __init__(self, model, cuda = False,use_buffers=False):
        super().__init__(model=model, cuda=cuda, avg_fn=get_swa_avg_fn(), use_buffers=use_buffers)

Tanh自适应指数移动EMA算法(T_ADEMA)

这一个是在查询资料的时候,找到的一篇论文描述的,是否有效,还得经过实验才对。

全文阅读--XML全文阅读--中国知网 (cnki.net)

论文表示是为了在神经网络训练过程中根据不同的训练阶段更有效地过滤噪声,所提出的公式:

decay = alpha * tanh(num_{avg})

T_ADEMA_{param} = decay * avg_{param} + (1 - decay) * current_{param}

T_ADEMA param是新的平均参数,avg param是之前的平均参数,current param是当前的参数,num avg是已经平均的参数数量。alpha是一个控制衰减速率的超参数。通过将参数数量作为输入传递给切线函数的参数,动态地计算衰减因子。切线函数(tanh)的输出范围为[-1, 1],随着参数数量的增加,衰减因子会逐渐趋近于1。由于切线函数的特性,当参数数量较小时,衰减因子接近于0;当参数数量较大时,衰减因子接近于1。

def get_t_adema(alpha=0.9):
    num_averaged = [0]  # 使用列表包装可变对象,以在闭包中引用
    @torch.no_grad()
    def t_adema_update(averaged_param, current_param, num_averageds):
        num_averaged[0] += 1
        decay = alpha * torch.tanh(torch.tensor(num_averaged[0], dtype=torch.float32))
        tadea_update = decay * averaged_param + (1 - decay) * current_param
        return tadea_update
    return t_adema_update

class T_ADEMAModel(AveragingBaseModel):
    def __init__(self, model, cuda=False, alpha=0.9, use_buffers=False):
        super().__init__(model=model, cuda=cuda, avg_fn=get_t_adema(alpha), use_buffers=use_buffers)

三、构建一个简单的实验测试

这一部分我正在做实验,下面是调用了一个简单的resnet18网络,看看逻辑上面是否有错。

if __name__=="__main__":
    # 创建 ResNet18 模型
    import torch
    import torchvision.models as models
    from torch.utils.data import DataLoader
    from tqdm import tqdm
    from torch.optim.swa_utils import AveragedModel

    class RandomDataset(torch.utils.data.Dataset):
        def __init__(self, size=224):
            self.data = torch.randn(size, 3, 224, 224)
            self.labels = torch.randint(0, 2, (size,))

        def __getitem__(self, index):
            return self.data[index], self.labels[index]

        def __len__(self):
            return len(self.data)


    model = models.resnet18(pretrained=False)
    # model = model.to('cuda')
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
    criterion = torch.nn.CrossEntropyLoss()

    # 创建数据加载器
    train_dataset = RandomDataset()
    train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)

    # 定义权重平均模型
    swa_model = SWAModel(model, cuda=True)
    ema_model = EMAModel(model, cuda=True)
    t_adema_model = T_ADEMAModel(model, cuda=True)

    for epoch in range(5):
        for inputs, labels in tqdm(train_loader, desc=f"Epoch {epoch + 1}/{5}"):
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            # 更新权重平均模型
            ema_model.update(model)
            swa_model.update(model)
            t_adema_model.update(model)

    # 测试模型
    test_dataset = RandomDataset(size=100)
    test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)


    def evaluate(model):
        model.eval()  # 切换到评估模式
        correct = 0
        total = 0

        with torch.no_grad():
            for inputs, labels in test_loader:
                inputs, labels = inputs.to('cuda'), labels.to('cuda')
                outputs = model(inputs)
                _, predicted = torch.max(outputs.data, 1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()

        accuracy = correct / total
        print(f"模型准确率:{accuracy * 100:.2f}%")

    # 原模型测试
    print("Model Evaluation:")
    evaluate(model.to('cuda'))   #
    # 测试权重平均模型
    print("SWAModel Evaluation:")
    evaluate(swa_model.to('cuda'))

    print("EMAModel Evaluation:")
    evaluate(ema_model.to('cuda'))

    print("T-ADEMAModel Evaluation:")
    evaluate(t_adema_model.to('cuda'))

运行效果:

Model Evaluation:
模型准确率:46.00%
SWAModel Evaluation:
模型准确率:54.00%
EMAModel Evaluation:
模型准确率:58.00%
T - ADEMAModel Evaluation:
模型准确率:58.00%

仅仅是测试是否能够跑通,过程中也有比原模型要低的时候,而且权值平均主要是用于训练中后期,所以有没有效果应该需要自己去做实验。

当前你可以下载pip install pyzjr==1.2.9,调用from pyzjr.nn import EMAModel运行。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/308024.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

数据结构(2023-2024)

一、判断题 1.队列是一种插入和删除操作分别在表的两端进行的线性表,是一种先进后出的结构。(F) 队列先进先出,在表的一端插入元素,在表的另一端删除元素。允许插入的一端称为队尾(rear)&#…

搭建Docker私有镜像服务器

一、前言 1、本文主要内容 基于Decker Desktop&Docker Registry构建Docker私有镜像服务器测试在CentOS 7上基于Docker Registry搭建公共Docker镜像服务器修改Docker Engine配置以HTTP协议访问Docker Registry修改Docker Engine配置通过域名访问Docker Registry配置SSL证书…

了解不同方式导入导出的速度之快

目录 一、用工具导出导入 Navicat(速度慢) 1.1、导入: 共耗时: 1.2、导出表 共耗时: 二、用命令语句导出导入 2.1、mysqldump速度快 导出表数据和表结构 共耗时: 只导出表结构 导入 共耗时&…

C#,字符串匹配算法(模式搜索)Z算法的源代码与数据可视化

Z算法也是模式搜索(Pattern Search Algorithm)的常用算法。 本文代码的运算效果: 一、Z 算法 线性时间模式搜索算法的Z算法,在线性时间内查找文本中模式的所有出现。 假设文本长度为 n,模式长度为 m,那么…

__init__中的__getattr__方法

结论: 在 __init__.py 文件中定义的 __getattr__ 方法,如果存在的话,通常用于处理包级别的属性访问。在包级别,__getattr__ 方法在导入模块时被调用,而不是在实例化包时。当你尝试访问包中不存在的属性时,__getattr__ 方法会被调用,给你一个机会来处理这个属性访问。 …

Linux第24步_安装windows下的VisualStudioCode软件

Windows下的VSCode安装后,还需要安装gcc编译器和g编译器。 gcc:编译C语言程序的编译器; g:编译C代码的编译器; 1、在Windows下安装VSCode; 双击“VSCodeUserSetup-x64-1.50.1.exe”,直到安装完成。 2、…

ride无法使用open Browser关键字

一般是版本兼容性问题。将robotframework版本降级为:3.1.2 pip install robotframework3.1.2 2、仍然没有得到解决时,查看robotframework-selenium2library版本 pip list 将robotframework-seleniumlibrary也改成3.XX的版本就可以了 pip unstall robotfr…

Git远端删除的分支,本地依然能看到 git remote prune origin

在远端已经删除ylwang_dev_786等三四个分支,本地git branch -a 时 依然显示存在。 执行 git remote show origin 会展示被删除的那些分支 当你在Git远程仓库(如GitLab)上删除一个分支后,这个变更不会自动同步到每个开发者的本地…

2024年第九届计算机与通信系统国际会议(ICCCS2024) ,邀您相约西安!

会议官网: ICCCS2024 | Xian China 时间: 2024年4月19-22日 地点: 中国西安 会议简介: 近年来,信息通信在不断发展,为计算机网络的进步与发展提供了先进可靠的技术支持。随着计算机网络与通信技术的深入发展,计算机通信技术、数…

排队免单?买东西花了钱还能拿回来?——工会排队模式

随着互联网和电子商务的迅猛发展,消费者的购物需求和期望也在不断升级。为了满足这一需求,工会排队模式作为一种创新消费体验模式应运而生。 工会排队模式是一种颠覆传统的电商模式,它通过向消费者返还现金的方式,重新定义了购物体…

《路由与交换技术》---练习题(无答案纯享版)

注意!!!这篇blog是无答案纯享版的 选择填空的答案我会放评论区 简答题可以看这里 计算题可以发私信问我(当然WeChat也成)but回讯息很慢 一、选择题 1.以下不会在路由表里出现的是: ( ) A.下一跳地址 B.网络地址 C…

Java线程池最全详解

1. 引言 在当今高度并发的软件开发环境中,有效地管理线程是确保程序性能和稳定性的关键因素之一。Java线程池作为一种强大的并发工具,不仅能够提高任务执行的效率,还能有效地控制系统资源的使用。 本文将深入探讨Java线程池的原理、参数配置…

PHP Web应用程序中常见漏洞

一淘模板(56admin.com)发现PHP 是一种流行的服务器端脚本语言,用于开发动态 Web 应用程序。但是,与任何其他软件一样,PHP Web 应用程序也可能遭受安全攻击。 在本文中,我们将讨论 PHP Web 应用程序中一些最常见的漏洞…

linux异常情况,排查处理中

登录客户环境后,发现一个奇怪情况如下图,之前也遇到过,直接fuser -ck /backup操作的话,主机将会重启,因数据库运行中,等待停机维护时间,同时也在想办法不重启的情况下解决该问题 [rootdb ~]# f…

使用西瓜视频官网来创造一个上一集,下一集的按钮,进行视频的切换操作

需求: 仿照西瓜视频写一个视频播放和上一集下一集的按钮功能 回答: 先访问官网: 西瓜播放器 这是西瓜视频的官网, 点击官网的示例按钮,可以看到相关的视频示例以及相关的代码, 我们复制下来代码,然后添加按钮和切换视频的方法, 完整代码: <!DOCTYPE html> <ht…

Hotspot源码解析-第十七章-虚拟机万物创建(三)

17.4 Java堆空间内存分配 分配Java堆内存前&#xff0c;我们先通过两图来了解下C堆、Java堆、内核空间、native本地空间的关系。 1、从图17-1来看&#xff0c;Java堆的分配其实就是从Java进程运行时堆中选中一块内存区域来映射 2、从图17-2&#xff0c;可以看中各内存空间的…

Springboot3(一、lambda、::的应用)

文章目录 一、使用lambda简化实例创建1.语法&#xff1a;2.示例&#xff1a;3.Function包3.1 有入参&#xff0c;有返回值【多功能函数】3.2 有入参&#xff0c;无返回值【消费者】3.3 无入参&#xff0c;有返回值【提供者】3.4 无入参&#xff0c;无返回值 二、类::方法的使用…

基于ssm运动会管理系统的设计与实现 【附源码】

基于ssm运动会管理系统的设计与实现 【附源码】 &#x1f345; 作者主页 央顺技术团队 &#x1f345; 欢迎点赞 &#x1f44d; 收藏 ⭐留言 &#x1f4dd; &#x1f345; 文末获取源码联系方式 &#x1f4dd; 项目运行 环境配置&#xff1a; Jdk1.8 Tomcat7.0 Mysql HBuil…

C2-3.3.2 机器学习/深度学习——数据增强

C2-3.3.2 数据增强 参考链接 1、为什么要使用数据增强&#xff1f; ※总结最经典的一句话&#xff1a;希望模型学习的更稳健 当数据量不足时候&#xff1a; 人工智能三要素之一为数据&#xff0c;但获取大量数据成本高&#xff0c;但数据又是提高模型精度和泛化效果的重要因…

工业智能网关:HiWoo Box远程采集设备数据

工业智能网关&#xff1a;HiWoo Box远程采集设备数据 在工业4.0和智能制造的浪潮下&#xff0c;工业互联网已成为推动产业升级、提升生产效率的关键。而在这其中&#xff0c;工业智能网关扮演着至关重要的角色。今天&#xff0c;我们就来深入探讨一下工业智能网关。 一、什么…