YOLOv5改进(七)--改进损失函数EIoU、Alpha-IoU、SIoU、Focal-EIOU

文章目录

  • 1、前言
  • 2、损失函数代码实现
    • 2.1、修改metrics.py
    • 2.2、修改loss.py
  • 3、替换EIOU
  • 4、替换SIoU
  • 5、替换Alpha-IoU
  • 6、替换Focal-EIOU
  • 7、目标检测系列文章

1、前言

YOLOv5默认使用损失函数为CIoU,本文主要针对损失函数进行修改,主要将bbox_iou函数进行修改,添加 EIoU、Alpha-IoU、SIoU、Focal-IOU等边界框回归损失。

2、损失函数代码实现

2.1、修改metrics.py

(1)首先找到utils/metrics.py文件,然后找到该python文件下的bbox_iou函数,其实在yolov5源码中设置是有GIoU, DIoU, CIoU这些边界框iou损失,但是默认值都为False

在这里插入图片描述

(2)将原始的bbox_iou函数代码注释掉,替换成如下代码,这段代码是将EIoU、Alpha-IoU、SIoU、Focal-EIOU这几个功能集中在一起,如果想要使用不同的Iou计算边界框损失,只需要修改utils/loss.py下的iou方法即可。

# 优化后的代码
def bbox_iou(box1, box2, xywh=True, GIoU=False, DIoU=False, CIoU=False, SIoU=False, EIoU=False, Focal=False, alpha=1,
             gamma=0.5, eps=1e-7):
    # Returns Intersection over Union (IoU) of box1(1,4) to box2(n,4)

    # Get the coordinates of bounding boxes
    if xywh:  # transform from xywh to xyxy
        (x1, y1, w1, h1), (x2, y2, w2, h2) = box1.chunk(4, -1), box2.chunk(4, -1)
        w1_, h1_, w2_, h2_ = w1 / 2, h1 / 2, w2 / 2, h2 / 2
        b1_x1, b1_x2, b1_y1, b1_y2 = x1 - w1_, x1 + w1_, y1 - h1_, y1 + h1_
        b2_x1, b2_x2, b2_y1, b2_y2 = x2 - w2_, x2 + w2_, y2 - h2_, y2 + h2_
    else:  # x1, y1, x2, y2 = box1
        b1_x1, b1_y1, b1_x2, b1_y2 = box1.chunk(4, -1)
        b2_x1, b2_y1, b2_x2, b2_y2 = box2.chunk(4, -1)
        w1, h1 = b1_x2 - b1_x1, (b1_y2 - b1_y1).clamp(eps)
        w2, h2 = b2_x2 - b2_x1, (b2_y2 - b2_y1).clamp(eps)

    # Intersection area
    inter = (b1_x2.minimum(b2_x2) - b1_x1.maximum(b2_x1)).clamp(0) * \
            (b1_y2.minimum(b2_y2) - b1_y1.maximum(b2_y1)).clamp(0)

    # Union Area
    union = w1 * h1 + w2 * h2 - inter + eps

    # IoU
    # iou = inter / union # ori iou
    iou = torch.pow(inter / (union + eps), alpha)  # alpha iou
    if CIoU or DIoU or GIoU or EIoU or SIoU:
        cw = b1_x2.maximum(b2_x2) - b1_x1.minimum(b2_x1)  # convex (smallest enclosing box) width
        ch = b1_y2.maximum(b2_y2) - b1_y1.minimum(b2_y1)  # convex height
        if CIoU or DIoU or EIoU or SIoU:  # Distance or Complete IoU https://arxiv.org/abs/1911.08287v1
            c2 = (cw ** 2 + ch ** 2) ** alpha + eps  # convex diagonal squared
            rho2 = (((b2_x1 + b2_x2 - b1_x1 - b1_x2) ** 2 + (
                        b2_y1 + b2_y2 - b1_y1 - b1_y2) ** 2) / 4) ** alpha  # center dist ** 2
            if CIoU:  # https://github.com/Zzh-tju/DIoU-SSD-pytorch/blob/master/utils/box/box_utils.py#L47
                v = (4 / math.pi ** 2) * (torch.atan(w2 / h2) - torch.atan(w1 / h1)).pow(2)
                with torch.no_grad():
                    alpha_ciou = v / (v - iou + (1 + eps))
                if Focal:
                    return iou - (rho2 / c2 + torch.pow(v * alpha_ciou + eps, alpha)), torch.pow(inter / (union + eps),
                                                                                                 gamma)  # Focal_CIoU
                else:
                    return iou - (rho2 / c2 + torch.pow(v * alpha_ciou + eps, alpha))  # CIoU
            elif EIoU:
                rho_w2 = ((b2_x2 - b2_x1) - (b1_x2 - b1_x1)) ** 2
                rho_h2 = ((b2_y2 - b2_y1) - (b1_y2 - b1_y1)) ** 2
                cw2 = torch.pow(cw ** 2 + eps, alpha)
                ch2 = torch.pow(ch ** 2 + eps, alpha)
                if Focal:
                    return iou - (rho2 / c2 + rho_w2 / cw2 + rho_h2 / ch2), torch.pow(inter / (union + eps),
                                                                                      gamma)  # Focal_EIou
                else:
                    return iou - (rho2 / c2 + rho_w2 / cw2 + rho_h2 / ch2)  # EIou
            elif SIoU:
                # SIoU Loss https://arxiv.org/pdf/2205.12740.pdf
                s_cw = (b2_x1 + b2_x2 - b1_x1 - b1_x2) * 0.5 + eps
                s_ch = (b2_y1 + b2_y2 - b1_y1 - b1_y2) * 0.5 + eps
                sigma = torch.pow(s_cw ** 2 + s_ch ** 2, 0.5)
                sin_alpha_1 = torch.abs(s_cw) / sigma
                sin_alpha_2 = torch.abs(s_ch) / sigma
                threshold = pow(2, 0.5) / 2
                sin_alpha = torch.where(sin_alpha_1 > threshold, sin_alpha_2, sin_alpha_1)
                angle_cost = torch.cos(torch.arcsin(sin_alpha) * 2 - math.pi / 2)
                rho_x = (s_cw / cw) ** 2
                rho_y = (s_ch / ch) ** 2
                gamma = angle_cost - 2
                distance_cost = 2 - torch.exp(gamma * rho_x) - torch.exp(gamma * rho_y)
                omiga_w = torch.abs(w1 - w2) / torch.max(w1, w2)
                omiga_h = torch.abs(h1 - h2) / torch.max(h1, h2)
                shape_cost = torch.pow(1 - torch.exp(-1 * omiga_w), 4) + torch.pow(1 - torch.exp(-1 * omiga_h), 4)
                if Focal:
                    return iou - torch.pow(0.5 * (distance_cost + shape_cost) + eps, alpha), torch.pow(
                        inter / (union + eps), gamma)  # Focal_SIou
                else:
                    return iou - torch.pow(0.5 * (distance_cost + shape_cost) + eps, alpha)  # SIou
            if Focal:
                return iou - rho2 / c2, torch.pow(inter / (union + eps), gamma)  # Focal_DIoU
            else:
                return iou - rho2 / c2  # DIoU
        c_area = cw * ch + eps  # convex area
        if Focal:
            return iou - torch.pow((c_area - union) / c_area + eps, alpha), torch.pow(inter / (union + eps),
                                                                                      gamma)  # Focal_GIoU https://arxiv.org/pdf/1902.09630.pdf
        else:
            return iou - torch.pow((c_area - union) / c_area + eps, alpha)  # GIoU https://arxiv.org/pdf/1902.09630.pdf
    if Focal:
        return iou, torch.pow(inter / (union + eps), gamma)  # Focal_IoU
    else:
        return iou  # IoU

要点

  1. gamma参数: 是Focal EloU中的gamma参数,一般就是为0.5,有需要可以自行更改。
  2. alpha参数:为Alpha-IOU中的alpha参数,默认为1,即使用原始I0U。若需要使用Alpha-IOU,只需将其设置为任意值(论文中默认设置为3)

2.2、修改loss.py

找到utils/loss.py损失函数计算文件,修改ComputeLoss类下面的__call__函数,通过修改iou = bbox_iou(pbox, tbox[i],x1y1x2y2=False, CIoU=True)里面第4个参数实现不同的损失函数。

在这里插入图片描述

将红框内容替换成如下代码:

iou = bbox_iou(pbox, tbox[i], CIoU=True)  # iou(prediction, target)
if type(iou) is tuple:
    lbox += (iou[1].detach().squeeze() * (1 - iou[0].squeeze())).mean()
    iou = iou[0].squeeze()
else:
    lbox += (1.0 - iou.squeeze()).mean()  # iou loss
    iou = iou.squeeze()

在这里插入图片描述

3、替换EIOU

如果想要使用EIOU,只需要将CIoU替换成EIOU:

iou = bbox_iou(pbox, tbox[i], EIoU=True) 

4、替换SIoU

如果想要使用SIoU,只需要将CIoU替换成SIoU:

iou = bbox_iou(pbox, tbox[i], SIoU=True) 

5、替换Alpha-IoU

如果想要使用Alpha-IoU,只需要添加alpha=3这个参数项开启Alpha,如果不设置该参数,alpha默认为1:

iou = bbox_iou(pbox, tbox[i], CIoU=True, alpha=3) 

6、替换Focal-EIOU

Focal-EIOU相对于EIOU只多了一个Focal项,这两个iou损失都是出自同一篇论文,只需要设置Focal=True即可。

iou = bbox_iou(pbox, tbox[i], EIoU=True, Focal=True) 

当然Focal项也可以用于CIoU、SIoU,至于效果需要根据不同数据集进行测试,修改如下:

Focal-CIoU

iou = bbox_iou(pbox, tbox[i], CIOU=True, Focal=True) 

Focal-SIoU

iou = bbox_iou(pbox, tbox[i], SIOU=True, Focal=True) 

7、目标检测系列文章

  1. YOLOv5s网络模型讲解(一看就会)
  2. 生活垃圾数据集(YOLO版)
  3. YOLOv5如何训练自己的数据集
  4. 双向控制舵机(树莓派版)
  5. 树莓派部署YOLOv5目标检测(详细篇)
  6. YOLO_Tracking 实践 (环境搭建 & 案例测试)
  7. 目标检测:数据集划分 & XML数据集转YOLO标签
  8. DeepSort行人车辆识别系统(实现目标检测+跟踪+统计)
  9. YOLOv5参数大全(parse_opt篇)
  10. YOLOv5改进(一)-- 轻量化YOLOv5s模型
  11. YOLOv5改进(二)-- 目标检测优化点(添加小目标头检测)
  12. YOLOv5改进(三)-- 引进Focaler-IoU损失函数
  13. YOLOv5改进(四)–轻量化模型ShuffleNetv2
  14. YOLOv5改进(五)-- 轻量化模型MobileNetv3
  15. YOLOv5改进(六)–引入YOLOv8中C2F模块

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/743988.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

昇思25天学习打卡营第2天|onereal》

今天学习内容是了解华为昇思平台。虽然打了卡,但是我的jupyter里面并没有播放按钮,所以还是无法运行代码。反映给昇思吴彦祖小哥了,他说需要专家帮我解决。 我还是要自我表扬一下,不懂就问,切莫不懂装懂,那…

充电宝哪个牌子最好最耐用?耐用西圣、罗马仕、绿联充电宝实测

目前充电宝是我们出行必备的“能量伴侣”。然而,市面上充电宝品牌繁多,让人眼花缭乱,究竟哪个牌子最好最耐用呢?为了给大家找到答案,我们精心挑选了西圣、罗马仕和绿联这三个备受关注的品牌,并对它们的充电…

clion调试opencv程序时,查看内存中图像的方法

很久之前在windows上使用 visual studio中 Image Watch插件查看内存中图像感觉很方便,后来转到liunx上开发后,在clion上一直没有找到对应的插件。之前看到过qtOpenImageDebugger好像有类似功能,vs code好像也可以,但是没有配置。 …

文华6幅图指标公式大全-多空精准买卖点提示指标源码

文华6幅图指标公式大全-多空精准买卖点提示指标源码: HH: HHV ( HIGH ,1)/5 HHV ( HIGH ,2)/5 HHV ( HIGH ,2)/5 HHV ( HIGH ,5)/5 HHV ( HIGH ,8)/5; LL: LLV ( LOW ,1)/5 LLV ( LOW ,2)/5 LLV ( LOW ,2)/5 LLV ( LOW ,5)/5 LLV ( LOW ,8)/5; H1: IFELSE ( H &l…

18.枚举

学习知识:枚举类型、相关的使用方法 Main.java: public class Main {public static void main(String[] args) {myenum[] colorlist myenum.values();//获取枚举中所有对象的引用数组for (myenum one : colorlist){System.out.println(one.toString(…

git提交新仓库代码,提示无权限,但用户名已修改

目录 1 用户名无权限 2 删除用户凭据 2.1 打开控制面板 2.2 找到“凭据管理器” 2.3 删除git历史 3 npm工具库添加git仓库指引 1 用户名无权限 之前因为时间的原因,js-tool-big-box工具库没有提交到github上去,今天想着往上提交一下,但…

如何打造安全DNS以保障业务可用?一文解读

DNS自1987年被实施以来,已成为网络通信中最重要的核心基础设施,同时也是企业对外提供数字服务的关键。没有正常安全的DNS服务,企业经营也就无从谈起。在DNS攻击逐年上升且容易被忽略的现代应用时代,如何打造安全DNS?本…

Mybatis操作数据库(二)

动态SQL 在 MyBatis 中&#xff0c;动态 SQL 是一种强大的特性&#xff0c;它允许根据不同的条件和场景在 SQL 语句中动态地组合和构建部分语句 < if>标签 我们在注册用户的时候,可能会有一个问题注册有两个字段 必填字段和非必填字段,这时我们需要使用动态标签来判定 例…

数据库管理与数据库语句

数据库用户管理及高级sql语句 数据库管理 数据库用户管理 mysql权限表 在mysql中mysql库中的user表是最重要的权限表&#xff0c;记录允许连接到服务器的账号信息以及全局权限&#xff0c; 在mysql库中db和host表也是重要的权限表 db表中存储了用户对某个数据库的操作权限&…

CS-流量通讯特征修改-端口store证书流量通讯规则

免责声明:本文仅做技术交流与学习... 目录 1.修改默认端口&#xff1a; 2.去除store证书特征&#xff1a; 查看证书指纹&#xff1a; 生成证书指纹&#xff1a; 应用证书指纹&#xff1a; 3.去除流量通讯特征&#xff1a; 规则资源 http流量特征修改: https流量特征修改:…

Web APIs-DOM-事件相关整理(完成网页交互)

目录 1.事件监听 2.事件监听绑定 3.事件类型 4.实例注意 5.事件对象 6.环境对象 7.回调函数 1.事件监听 &#xff08;绑定事件/注册事件&#xff09;: 程序检测有没有事件产生&#xff08;事件&#xff1a;比如单机一个按钮&#xff08;编程时系统发生的动作或者事情&a…

C++ 14新特性个人总结

variable templates 变量模板。这个特性允许模板被用于定义变量&#xff0c;就像之前模板可以用于定义函数或类型一样。变量模板为模板编程带来了新的灵活性&#xff0c;特别是在定义泛化的常量和元编程时非常有用。 变量模板的基本语法 变量模板的声明遵循以下基本语法&am…

云计算基础知识

前言&#xff1a; 随着ICT技术的高速发展&#xff0c;企业架构对计算、存储、网络资源的需求更高&#xff0c;急需一种新的架构来承载业务&#xff0c;以获得持续&#xff0c;高速&#xff0c;高效的发展&#xff0c;云计算应运而生。 云计算背景 信息大爆炸时代&#xff1a…

HttpServletRequest・getContentLeng・getContentType区别

getContentLength()&#xff1a; 获取客户端发送到服务器的HTTP请求主体内容的字节数&#xff08;长度&#xff09; 如果请求没有正文内容&#xff08;如GET&#xff09;&#xff0c;或者请求头中没有包含Content-Length字段&#xff0c;则该方法返回 -1 getContentType()&am…

【昇思初学入门】第七天打卡-模型训练

训练模型 学习心得 构建数据集。这通常包括训练集、验证集&#xff08;可选&#xff09;和测试集。训练集用于训练模型&#xff0c;验证集用于调整超参数和监控过拟合&#xff0c;测试集用于评估模型的泛化能力。 &#xff08;mindspore提供数据集https://www.mindspore.cn/d…

绘制全球各大洲典型流域的时间序列图

流量世界第一、长度第二的亚马逊流域&#xff08;Amazon&#xff09;、南美洲第四大、整条河流位于巴西的圣弗朗西斯科流域&#xff08;Sao Francisco&#xff09;、世界第四长、北美洲最长的密西西比流域&#xff08;Mississippi&#xff09;、欧洲最长的伏尔加流域&#xff0…

GitLab 不小心提交了master/develop版本如何回退

1. 找寻最近的版本&#xff0c;使用git reset --hard 回退到具体的提交版本号 2. git push origin master --force 这个会遇到gitlab默认拦截&#xff0c;处理版本 版本仓库页面&#xff0c;选择Setting——Repository&#xff0c;找到Protected branches 3. 再回到master分支…

python爬虫-爬虫的基础知识储备

爬虫就是一个不断的去抓去网页的程序&#xff0c;根据我们的需要得到我们想要的结果&#xff01;但我们又要让服务器感觉是我们人在通过浏览器浏览不是程序所为&#xff01;归根到底就是我们通过程序访问网站得到html代码&#xff0c;然后分析html代码获取有效内容的过程。下面…

开源/标准版 首页 logo大小修改

这个是diy的&#xff1a; 文件地址&#xff1a;template/uni-app/pages/index/diy/components/headerSerch.vue 这个是页面设计的&#xff1a; 文件地址&#xff1a;template/uni-app/pages/index/visualization/components/headerSerch.vue 先删除这三个 然后改下图的地方

C++ 模板:全特化和偏特化

目录 全特化&#xff08;Full Specialization&#xff09; 偏特化&#xff08;Partial Specialization&#xff09; 特点和使用场景 注意事项 在C中&#xff0c;模板特化&#xff08;template specialization&#xff09;是一种强大的功能&#xff0c;允许对模板进行特定情…