把组合损失中的权重设置为可学习参数

目前的需求是:有一个模型,准备使用组合损失,其中有2个或者多个损失函数。准备对其进行加权并线性叠加。但想让这些权重进行自我学习,更新迭代成最优加权组合。

目录

1、构建组合损失类

2、调用组合损失类

3、为其构建优化器

4、梯度归零

5、跟新优化器参数

6、结果展示


1、构建组合损失类

每项损失函数可以定义在init里面,这样的话就只需要模型的输出和训练目标。我这里没有这样设置,选择把每项损失值传过来进行线性加权叠加。

# 定义组合损失函数---------------------------------------START
class CombinedLoss(nn.Module):
    def __init__(self):
        super(CombinedLoss, self).__init__()
        # 定义损失函数权重作为可训练参数
        self.w_adv = nn.Parameter(torch.ones(1, requires_grad=True))  # 对抗损失的权重,初始值为0.2 
        self.w_con = nn.Parameter(torch.ones(1, requires_grad=True))  # 内容感知损失的权重,初始值为0.2
        self.w_mse = nn.Parameter(torch.ones(1, requires_grad=True))  # 均方误差损失的权重,初始值为0.2
        self.w_s3im = nn.Parameter(torch.ones(1, requires_grad=True))  # 随机结构相似性损失的权重,初始值为0.2
        self.w_gui = nn.Parameter(torch.ones(1, requires_grad=True))  # 边缘引导损失的权重,初始值为0.2


    def forward(self, loss_adv, loss_con, loss_mse, loss_s3im, loss_gui):
        return self.w_adv*loss_adv + self.w_con*loss_con + self.w_mse*loss_mse + self.w_s3im*loss_s3im + self.w_gui*loss_gui

2、调用组合损失类

在计算组合损失之前,需要初始化类对象。

combinedloss = Loss.CombinedLoss()

unet_loss = self.combinedloss(
                            loss_adv = unet_gan_loss, 
                            loss_con = gen_content_loss, 
                            loss_mse = unet_criterion, 
                            loss_s3im = s3im_loss, 
                            loss_gui = guid_loss)

3、为其构建优化器

最好单独构建优化器,这样我们可以设置与总损失不用的学习率。避免学习率过大导致梯度消失。

self.lr_weight_optimizer = optim.Adam(
            self.combinedloss.parameters(),
            lr = 1e-4,
            betas=(0.9, 0.999)
        )

4、梯度归零

在每次计算总损失之前,需要把每个优化器的梯度归零

self.lr_weight_optimizer.zero_grad()

5、跟新优化器参数

在总损失反向传播之后,需要对优化器的参数进行更新

self.lr_weight_optimizer.step()

6、结果展示

每个权重都会自动更新。 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/509716.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

filebeat日常使用.

先决条件(已经安装好filebeatelasticsearch) 详情可以参考官网,或者 虚拟机部署elasticsearch集群-CSDN博客 centos7 使用rpm包部署filebeat-CSDN博客 filebeat配置文件 标准输入console输出 (1)编写测试的配置⽂件 mkdir /etc/filebeat/config [rootelk101 /tmp]$vim /et…

练习 15 Web [极客大挑战 2019]Knife

因为没有按顺序刷题,这道送分题 要知道两个点 1.什么是“一句话木马” 2.什么是“菜刀” or“蚁剑” 蚁剑连接就行 flag{1b894a2d-8df1-4b06-9054-e90d73257006}

春季养生从睡眠开始

睡眠可以缓解疲劳,恢复精神状态。特别是在春季,调整好睡眠对于养生是很有好处的,HUAWEI WATCH GT 4 系列用更科学的睡眠监测、管理与改善,守护你的好梦。

js 基础知识 forEach 和 map 的区别,及 map 不加 return 返回什么

问题一:forEach 和 map 之间的区别: 1、forEach 不返回新数组,map 返回新数组,其中包含回调函数的返回值。 2、用途:如果只想要遍历数组并对每个元素执行某些操作而不产生新数组,那么应该使用 forEach&am…

Selinux安全策略文件

在Selinux框架中,安全策略都是写在te文件中,以adb.te 文件为例 allow adbd shell_data_file:dir create_dir_perms;策略的基本格式是: rule_name source_type target_type :object_class perm_setrule_name 规则名。常见的规则名有allow,ne…

异常,Lambda表达式

文章目录 异常介绍存在形式程序中异常发生后的第一反应体系JVM的默认处理方案处理方式声明 throws概述格式抛出 throw格式注意意义 throws和throw的区别 捕获 try,catch介绍格式执行方式多异常捕获处理意义 如何选择用哪个 Throwable类介绍常用方法 自定义异常概述实现步骤范例…

论文阅读RangeDet: In Defense of Range View for LiDAR-based 3D Object Detection

文章目录 RangeDet: In Defense of Range View for LiDAR-based 3D Object Detection问题笛卡尔坐标结构图Meta-Kernel Convolution RangeDet: In Defense of Range View for LiDAR-based 3D Object Detection 论文:https://arxiv.org/pdf/2103.10039.pdf 代码&…

docker--部署 (超详版) (五)

环境准备:docker,mysql,redis,镜像,nginx 把虚拟机打开,连接xshell,参考博客: https://blog.csdn.net/m0_74229802/article/details/136965820?spm1001.2014.3001.5501 一&#x…

什么是过载

宇航员相关知识会涉及到过载,导弹相关知识也会涉及到过载,如导弹的过载加速度,什么是过载呢?博主从B站上看到一UP主讲的很好, 该up主视频链接: 过载是什么_哔哩哔哩_bilibili 内容截图如下:

vue+element ui实现表单组件的封装

效果图: 主要是使用vue elmentUi 的from的基础上进行封装 使用改组件,是需要自定义从父组件传入一下字段表单字段 export const topicTypeMainTaskEdit: any new Map([// 主任务可编辑状态[feasibleInstructions, // 督办件[{value: documentNum…

B端:别看分页你天天用,你未必能摸清分页门道,一文告诉你.

在B端系统中,分页的作用是在处理大量数据时,将数据分成多个页面进行展示,以提高系统的性能和用户体验。 分页可以有效减少单个页面加载的数据量,加快页面加载速度,减少服务器和客户端的压力,并且方便用户浏…

vue使用i18n

🎈个人主页:靓仔很忙i 💻B 站主页:👉B站👈 🎉欢迎 👍点赞✍评论⭐收藏 🤗收录专栏:前端 🤝希望本文对您有所裨益,如有不足之处&#x…

如何使用cobbler定制安装

cobbler批量安装 客户端安装 新建虚拟机从pxe启动,若出现以下界面则表示成功: uos9安装 rockylinux9安装: 切换镜像 [rootcjy ~]# umount /mnt [rootcjy ~]# mount /dev/cdrom /mnt 升级 [rootcjy ~]# cobbler signature update task sta…

【御控物联】JavaScript JSON结构转换(16):对象To数组——综合应用

文章目录 一、JSON结构转换是什么?二、术语解释三、案例之《JSON对象 To JSON数组》四、代码实现五、在线转换工具六、技术资料 一、JSON结构转换是什么? JSON结构转换指的是将一个JSON对象或JSON数组按照一定规则进行重组、筛选、映射或转换&#xff0…

Redis数据库常用命令和数据类型

文章目录 一、Redis数据库常用命令1、set/get2、keys3、exists4、del5、type6、rename6.1 重命名6.2 覆盖 7、renamenx8、dbsize9、密码设置10、密码验证11、查看密码12、取消密码13、Redis多数据库常用命令13.1 多数据库间切换13.2 多数据库间移动数据13.3 清除数据库数据 二、…

可视化大屏的行业应用(1):智慧生产和智能制造

可视化大屏在生产制造行业中具有重要的应用价值,主要表现在以下几个方面: 生产监控:可视化大屏可以实时显示生产线的运行状态和生产数据,包括设备运行状态、生产效率、产品质量等,帮助管理人员及时监控生产情况&#…

【新手适用】手把手教你从零开始实现一个基于Pytorch的卷积神经网络CNN三: 如何验证和测试模型

【新手适用】手把手教你从零开始实现一个基于Pytorch的卷积神经网络CNN二: 如何训练模型,内附详细损失、准确率、均值计算-CSDN博客 从零开始实现一个基于Pytorch的卷积神经网络 - 知乎 (zhihu.com) 1 初始化、导入模型和数据集 新建一个test.py文件&a…

Android的图片加载框架

Android的图片加载框架 为什么要使用图片加载框架?图片加载框架1. Universal Image Loader [https://github.com/nostra13/Android-Universal-Image-Loader](https://github.com/nostra13/Android-Universal-Image-Loader)2. Glide [https://muyangmin.github.io/gl…

SWM341系列SDRAM应用

SWM341系列SDRAM应用 1、不同的时钟频率下,SDRAM的初始化参数设置 现象:驱屏应用,显示一段时间后出现卡住的现象 分析:SDRAM的初始 化参数优化 主频150Mhz,建议配置CASL 3,TRFC ≥8。 主频100Mhz,ClkDiv可配置为1…

低压500kW 发电机组测试交流干式负载箱的工作方式

低压500kW发电机组测试交流干式负载箱是一种专门用于测试发电机组性能的设备,其主要功能是在发电机组运行过程中,通过模拟实际负载情况,对发电机组的输出功率、电压、电流等参数进行实时监测和调整,以确保发电机组在各种工况下的稳…