【图像分割】【深度学习】PFNet官方Pytorch代码-PFNet网络损失函数模块解析

【图像分割】【深度学习】PFNet官方Pytorch代码-PFNet网络损失函数模块解析

文章目录

  • 【图像分割】【深度学习】PFNet官方Pytorch代码-PFNet网络损失函数模块解析
  • 前言
  • PM定位模块损失函数
  • FM聚焦模块损失函数
  • 总结


前言

在详细解析PFNet代码之前,首要任务是成功运行PFNet代码【win10下参考教程】,后续学习才有意义。本博客讲解PFNet神经网络模块的损失函数模块代码,不涉及其他功能模块代码。

PFNet中有四个输出预测,一个来自定位模块(PM),三个来自聚焦模块(FM),整体的损失函数为:
ℓ o v e r a l l = ℓ p m + ∑ i = 1 3 2 ( 3 − i ) ℓ f m i {\ell _{overall}}{\rm{ }} = {\rm{ }}{\ell _{pm}} + \sum\limits_{i = 1}^3 {{2^{(3 - i)}}} \ell _{fm}^i overall=pm+i=132(3i)fmi
其中 ℓ f m i \ell _{fm}^i fmi表示在PFNet网络中至上往下第 i i i个的聚焦模块的预测的损失。

博主将各功能模块的代码在不同的博文中进行了详细的解析,点击【win10下参考教程】,博文的目录链接放在前言部分。


PM定位模块损失函数

对于PM模块,使用二值交叉熵损失(Binary CrossEntropy Loss,BCE)损失 ℓ b c e \ell _{{\rm{bce}}} bce和IoU损失 ℓ i o u \ell _{{\rm{iou}}} iou的输出,即 ℓ p m = ℓ b c e + ℓ i o u {\ell _{{\rm{pm}}}} = {\ell _{{\rm{bce}}}} + {\ell _{{\rm{iou}}}} pm=bce+iou,以引导PM探索目标对象的初始位置。
二值交叉熵损失 ℓ i o u \ell _{{\rm{iou}}} iou是常见用法,因此不再具体讲解,本小节主要介绍 ℓ i o u \ell _{{\rm{iou}}} iou,因为它不同于目标检测中用于衡量预测边界框与真实边界框之间的重叠程度,而在论文中对此并没有详细解释,因此博主根据论文源码绘制以下示意图具体讲解 ℓ i o u \ell _{{\rm{iou}}} iou的作用:

ℓ i o u = 1 − i o u {\ell _{{\rm{iou}}}} = 1 - iou iou=1iou i o u iou iou重合度越高, ℓ i o u \ell _{{\rm{iou}}} iou损失越小, i o u = i n t e r u n i o n − i n t e r iou = \frac{{{\rm{inter}}}}{{{\rm{union - inter}}}} iou=unioninterinter。那么 i n t e r inter inter u n i o n − i n t e r union - inter unioninter分别表示什么含义呢?博主将根据所绘制的示意图详细说明其中的含义,如上图所示, m a s k mask mask只有前景为1背景为0俩种值, p r e d pred pred的取值范围则在(0~1)之间,为了方便理解博主也是暴力的拆解成前景为0.8背景为0.2俩种值。

  1. i n t e r inter inter表示真实标签 m a s k mask mask和预测标签 p r e d pred pred对应像素相乘后再对像素值求和的值,如上图的inter所示(只表示到对应元素相乘), i n t e r inter inter的含义可以理解成真实标签的前景部分在预测标签上的预测结果,简单来说就是只考虑预测标签针对真实前景的预测效果,默认背景部分完全预测正确,屏蔽了背景不作考虑,因此 i n t e r = T b + P f inter=T_b+P_f inter=Tb+Pf
  2. u n i o n union union表示真实标签 m a s k mask mask和预测标签 p r e d pred pred对应像素相加后再对像素值求和的值,如上图的union所示(只表示到对应元素相加),那么 u n i o n − i n t e r union-inter unioninter的含义可以理解成真实标签的背景部分在预测标签上的预测结果,如上图的union-inter所示,简单来说就是只考虑预测标签针对真实背景的预测效果,默认前景部分完全预测正确,屏蔽了前景不作考虑,因此 u n i o n − i n t e r = T f + P b union-inter=T_f+P_b unioninter=Tf+Pb

T b T_b Tb表示背景位置真实像素求和值(也就是0), P f P_f Pf表示前景位置预测像素求和值, T f T_f Tf表示前景位置真实像素求和值, P b P_b Pb表示背景位置预测像素求和值。
注意!!!!区分背景位置预测像素和预测背景像素俩个概念!!!前者是真实背景像素位置可能真确预测为背景,也可能错误预测成前景;后者则是对预测一个像素位置为背景。

解释了 i n t e r inter inter u n i o n − i n t e r union - inter unioninter的含义, i o u iou iou也可以表示成 i o u = T b + P f T f + P p iou = \frac{{{T_b} + {P_{\rm{f}}}}}{{{T_f} + {P_p}}} iou=Tf+PpTb+Pf T b T_b Tb T f T_f Tf是固定不变的,那么 ℓ i o u \ell _{{\rm{iou}}} iou的优化目标就是 P f P_f Pf越来越大且 P b P_b Pb越来越小。
代码位置:train.py

# PM loss function
bce_loss = nn.BCEWithLogitsLoss().cuda(device_ids[0])
iou_loss = loss.IOU().cuda(device_ids[0])
def bce_iou_loss(pred, target):
    bce_out = bce_loss(pred, target)
    iou_out = iou_loss(pred, target)
    loss = bce_out + iou_out
    return loss

代码位置:loss.py

博主为了方便大家理解,小改了下源码,但是没有丝毫影响源码的原始目的。

class IOU(torch.nn.Module):
    def __init__(self):
        super(IOU, self).__init__()
    def _iou(self, pred, target):
        pred = torch.sigmoid(pred)
        # 交集区域
        inter = (pred * target).sum(dim=(2, 3))
        # 并集区域
        union = (pred + target).sum(dim=(2, 3))
        # iou损失
        iou = 1 - (inter / (union- inter))
        return iou.mean()
    def forward(self, pred, target):
        return self._iou(pred, target)

FM聚焦模块损失函数

对于FM模块,希望更多地关注对象的边界、细长区域或孔处等分散注意力区域。因此,使用加权二值交叉熵损失(Binary CrossEntropy Loss,BCE)损失 ℓ w b c e \ell _{{\rm{wbce}}} wbce和加权IoU损失 ℓ w i o u \ell _{{\rm{wiou}}} wiou的输出,即 ℓ f m = ℓ w b c e + ℓ w i o u {\ell _{{\rm{fm}}}} = {\ell _{{\rm{wbce}}}} + {\ell _{{\rm{wiou}}}} fm=wbce+wiou,以迫使FM更加关注可能的分散注意力区域。
ℓ i o u \ell _{{\rm{iou}}} iou在上个章节就进行了说明, ℓ w i o u \ell _{{\rm{wiou}}} wiou大同小异,因此不再具体讲解,本小节主要介绍 ℓ w b c e \ell _{{\rm{wbce}}} wbce ℓ w i o u \ell _{{\rm{wiou}}} wiou中的 w w w权重的产生,在论文中对此并没有详细解释,因此博主根据论文源码绘制以下示意图具体讲解 w w w的作用:

w w w权重是通过对标签 m a s k mask mask进行平均池化操作,再减去 m a s k mask mask,最后取绝对值:
w = 1 + 5 × ∣ A v g P o o l ( m a s k ) − m a s k ∣ w = 1 + 5 \times \left| {\left. {AvgPool(mask) - mask} \right|} \right. w=1+5×AvgPool(mask)mask
为什么这么简单的操作就能让 w w w更加关注可能的分散注意力区域?博主分以下几种情况讨论:

  • 第一种情况:如上图1所示位置,该前景像素位于前景目标的内部,因此不是对象的边界、细长区域或孔处等分散注意力区域,其 w w w权重计算为1,不需要对其做额外加强;
  • 第二种情况:如上图2所示位置,该前景像素是对象的边界,属于分散注意力区域,其 w w w权重计算为4.9,可谓是剧烈加强;
  • 第三种情况:如上图3所示位置,该背景像素是模糊边界,也属于分散注意力区域,其 w w w权重计算为4.3,也是剧烈加强;
  • 第四种情况:如上图4所示位置,该像素是背景,其 w w w权重计算为1,不需要对其做额外加强;

博主绘制的示意图只是为了方便理解,真实的池化核大小不可能只有3×3那么小,源码中使用的池化核大小是31×31。
代码位置:train.py

# FM loss function
structure_loss = loss.structure_loss().cuda(device_ids[0])

代码位置:loss.py

class structure_loss(torch.nn.Module):
    def __init__(self):
        super(structure_loss, self).__init__()

    def _structure_loss(self, pred, mask):
        print(pred.shape)
        # 根据mask标签生成关于mask的权重
        # 根据公式可以知道,越是靠近前景目标边缘的像素,权重可能就越高,而越靠近前景目标的中心的像素权重越低,最低为1
        weit = 1 + 5 * torch.abs(F.avg_pool2d(mask, kernel_size=31, stride=1, padding=15) - mask)
        # 因为预测标签还要进行加权,暂时需要保留结构,所以损失在每个元素上计算,reduce选择none
        wbce = F.binary_cross_entropy_with_logits(pred, mask, reduce='none')
        # 加权的bce
        wbce = (weit * wbce).sum(dim=(2, 3)) / weit.sum(dim=(2, 3))
        pred = torch.sigmoid(pred)
        # 交集区域
        inter = ((pred * mask) * weit).sum(dim=(2, 3))
        # 并集区域
        union = ((pred + mask) * weit).sum(dim=(2, 3))
        # 加权的iou损失
        wiou = 1 - (inter) / (union - inter)
        return (wbce + wiou).mean()
    def forward(self, pred, mask):
        return self._structure_loss(pred, mask)

总结

尽可能简单、详细的介绍PFNet网络中的损失函数模块的结构和代码。


本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/196981.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

会声会影2024旗舰版系统配置要求及格式支持

会声会影2024旗舰版是一款广受欢迎的视频编辑软件,它的最新版本,会声会影2023,已经发布。在这篇文章中,我们将探讨会声会影2024旗舰版系统配置要求及格式支持 会声会影2024是一款专业的视频剪辑软件,能够帮助用户制作高…

vue+uniapp校园寻物失物招领平台 微信小程序1f6z5

系统中的核心用户是管理员,管理员登录后,通过管理员菜单来管理后台系统。主要功能有:首页、个人中心、用户管理、物品分类管理、物品信息管理、物品归还管理、留言板管理、系统管理等功能。管理员用例如图3-7所示。 对于本网上失物招领小程序…

unity3d地图、地面跟着NPC跑

清除烘焙后,再 将地图、地面的设置为非静态。只设置NPC的寻路路面为静态,再烘焙

03、K-means聚类实现步骤与基于K-means聚类的图像压缩(1)

03、K-means聚类实现步骤与基于K-means聚类的图像压缩(1) 03、K-means聚类实现步骤与基于K-means聚类的图像压缩(1) 03、K-means聚类实现步骤与基于K-means聚类的图像压缩(2) 开始学习机器学习啦&#xf…

电力智能化系统(智能电力综合监控系统)

电力智能化系统是一个综合性的系统,它利用物联网、云计算、大数据、人工智能等技术,依托电易云-智慧电力物联网,采用智能采集终端和物联网关,将电力设备、用电负荷、电力市场等各个环节有机地联系起来,实现了对电力配送…

一篇学会cron表达式

1、定义 Cron表达式是一种用于定义定时任务的格式化字符串。它被广泛用于Unix、Linux和类Unix系统中,用于在指定的时间执行预定的任务。Cron表达式由6个字段组成,每个字段通过空格分隔开。 在本文中,我们将学习如何理解和编写Cron表达式。 C…

Java高级技术(单元测试)

一,概括 二,junit 三,案例 (1),实验类 package com.bilibili;public class Name {public static void main(String name) {if (name null){System.out.println("0");return;}System.out.print…

电子学会C/C++编程等级考试2022年06月(三级)真题解析

C/C++等级考试(1~8级)全部真题・点这里 第1题:制作蛋糕 小A擅长制作香蕉蛋糕和巧克力蛋糕。制作一个香蕉蛋糕需要2个单位的香蕉,250个单位的面粉,75个单位的糖,100个单位的黄油。制作一个巧克力蛋糕需要75个单位的可可粉,200个单位的面粉,150个单位的糖,150个单位的黄…

MyBatis-Plus条件构造器

说明 Wrapper:条件构造抽象类,最顶端父类AbstractWrapper:用于查询条件封装,生成sql的where条件QueryWrapper:查询条件封装UpdateWrapper:更新条件封装AbstractLambdaWrapper:使用Lambda语法La…

基本数据结构二叉树(3)

目录 4.二叉树链式结构的操作 4.1 前置说明 4.2二叉树的遍历 4.2.1 前序、中序以及后序遍历 4.3 节点个数以及高度等 4.二叉树链式结构的操作 4.1 前置说明 由于博主对二叉树的结果掌握还不够深入,因此在讲解相关操作前将手动创建一颗简单的二叉树&#xff0c…

【传智杯】儒略历、评委打分、萝卜数据库题解

🍎 博客主页:🌙披星戴月的贾维斯 🍎 欢迎关注:👍点赞🍃收藏🔥留言 🍇系列专栏:🌙 蓝桥杯 🌙请不要相信胜利就像山坡上的蒲公英一样唾手…

地大与明道云的实践:零代码产教融合与协同育人

摘要 中国地质大学(武汉)与明道云合作,通过建设数字学院的方式,塑造教育数字化新动能。具体实践包括: 联合建设数字学院:选择经济管理学院作为试点,通过统筹规划、统一标准、分步实施的方式&a…

centos 显卡驱动安装(chatglm2大模型安装步骤一)

1.服务器配置 服务器系统:Centos7.9 x64 显卡:RTX3090 (24G) 2.安装环境 2.1 检查显卡驱动是否安装 输入命令:nvidia-smi(显示显卡信息) 如果有以下显示说明,已经有显卡驱动。否则需要重装。 2.2 下载显卡驱动 第一步:浏览器输入https://www.nvidia.cn/Downloa…

vue+elementUI的tabs与table表格联动固定与滚动位置

有个变态的需求,要求tabs左侧固定,右侧是表格,点击左侧tab,右侧表格滚动到指定位置,同时,右侧滚动的时候,左侧tab高亮相应的item 上图 右侧的高度非常高,内容非常多 常规的瞄点不适…

高级JVM

一、Java内存模型 1. 我们开发人员编写的Java代码是怎么让电脑认识的 首先先了解电脑是二进制的系统,他只认识 01010101比如我们经常要编写 HelloWord.java 电脑是怎么认识运行的HelloWord.java是我们程序员编写的,我们人可以认识,但是电脑不…

C语言——数字金字塔

实现函数输出n行数字金字塔 #define _CRT_SECURE_NO_WARNINGS 1#include <stdio.h>void pyramid(int n) {int i,j,k;for (i1; i<n; i){//输出左边空格&#xff0c;空格数为n-i for (j1; j<n-i; j){printf(" "); } //每一行左边空格输完后输出数字&#…

C++初阶模板

介绍&#xff1a; 我们先认识以下C中的模板。模板是一种编程技术&#xff0c;允许程序员编写与数据类型无关的代码&#xff0c;它是一种泛型编程的方式&#xff0c;可以用于创建可处理多种数据类型的函数或类&#xff0c;也就是说泛型编程就是编写与类型无关的通用代码&#xf…

第一百八十二回 自定义一个可以滑动的刻度尺

文章目录 1. 概念介绍2. 思路与方法2.1 实现思路2.2 实现方法3. 示例代码4. 内容总结我们在上一章回中介绍了"如何绘制阴影效果"相关的内容,本章回中将介绍 如何自定义一个可以滑动的刻度尺.闲话休提,让我们一起Talk Flutter吧。 1. 概念介绍 任何优美的文字在图…

1128. 等价多米诺骨牌对的数量

力扣&#xff08;LeetCode&#xff09;官网 - 全球极客挚爱的技术成长平台备战技术面试&#xff1f;力扣提供海量技术面试资源&#xff0c;帮助你高效提升编程技能&#xff0c;轻松拿下世界 IT 名企 Dream Offer。https://leetcode.cn/problems/number-of-equivalent-domino-pa…

模拟退火算法应用——求解TSP问题

仅作自己学习使用 一、问题 旅行商问题(TSP) 是要求从一个城市出发&#xff0c;依次访问研究区所有的城市&#xff0c;并且只访问一次不能走回头路&#xff0c;最后回到起点&#xff0c;求一个使得总的周游路径最短的城市访问顺序。 采用模拟退火算法求解TSP问题&#x…