Yolov8有效涨点,添加多种注意力机制,修改损失函数提高目标检测准确率

目录

简介

CBAM注意力机制原理及代码实现

原理

 代码实现

 GAM注意力机制

原理

代码实现

修改损失函数

YAML文件

完整代码


🚀🚀🚀订阅专栏,更新及时查看不迷路🚀🚀🚀

http://t.csdnimg.cn/sVHxvicon-default.png?t=N7T8http://t.csdnimg.cn/sVHxv

简介

Ultralytics 推出了最新版本的 YOLO 模型。注意力机制是提高模型性能最热门的方法之一。

本次将介绍几种常见的注意力机制,这些注意力机制在大多数的数据集上均能有效的提升目标检测的精度/召回率/准确率。

CBAM注意力机制原理及代码实现
原理
CBAM注意力机制结构图

CBAM(Convolutional Block Attention Module)是一种用于卷积神经网络(CNN)的注意力机制,它能够增强网络对输入特征的关注度,提高网络性能。CBAM 主要包含两个子模块:通道注意力模块(Channel Attention Module)和空间注意力模块(Spatial Attention Module)。

以下是CBAM注意力机制的基本原理:

1. 通道注意力模块(Channel Attention Module):
输入:经过卷积层的特征图。
处理步骤:
对每个通道进行全局平均池化,得到通道的全局平均值。
通过两个全连接层,将全局平均值映射为两个权重向量(一个用于缩放,一个用于偏置)。
将这两个权重向量与原始特征图相乘,以加权调整每个通道的重要性。

2. 空间注意力模块(Spatial Attention Module):**
输入:通道注意力模块的输出。
处理步骤:
     对每个通道的特征图进行分别的最大池化和平均池化,得到两个空间特征图。
     将这两个空间特征图相加,通过一个卷积层产生一个权重图。
     将原始特征图与权重图相乘,以加权调整每个空间位置的重要性。

3. 整合:
   将通道注意力模块和空间注意力模块的输出相乘,得到最终的注意力增强特征图。
   将这个注意力增强的特征图传递给网络的下一层进行进一步处理。

CBAM的关键优势在于它能够同时考虑通道和空间信息,有助于网络更好地理解和利用输入特征。这种注意力机制有助于提高网络在视觉任务上的性能,使其能够更有针对性地关注重要的特征。

 代码实现

路径:"./ultralytics/nn/modules/conv.py"

class ChannelAttention(nn.Module):
    """Channel-attention module https://github.com/open-mmlab/mmdetection/tree/v3.0.0rc1/configs/rtmdet."""

    def __init__(self, channels: int) -> None:
        super().__init__()
        self.pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Conv2d(channels, channels, 1, 1, 0, bias=True)
        self.act = nn.Sigmoid()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x * self.act(self.fc(self.pool(x)))


class SpatialAttention(nn.Module):
    """Spatial-attention module."""

    def __init__(self, kernel_size=7):
        """Initialize Spatial-attention module with kernel size argument."""
        super().__init__()
        assert kernel_size in (3, 7), 'kernel size must be 3 or 7'
        padding = 3 if kernel_size == 7 else 1
        self.cv1 = nn.Conv2d(2, 1, kernel_size, padding=padding, bias=False)
        self.act = nn.Sigmoid()

    def forward(self, x):
        """Apply channel and spatial attention on input for feature recalibration."""
        return x * self.act(self.cv1(torch.cat([torch.mean(x, 1, keepdim=True), torch.max(x, 1, keepdim=True)[0]], 1)))


class CBAM(nn.Module):
    """Convolutional Block Attention Module."""

    def __init__(self, c1, kernel_size=7):  # ch_in, kernels
        super().__init__()
        self.channel_attention = ChannelAttention(c1)
        self.spatial_attention = SpatialAttention(kernel_size)

    def forward(self, x):
        """Applies the forward pass through C1 module."""
        return self.spatial_attention(self.channel_attention(x))

添加完代码以后需要在"./ultralytics/nn/tasks.py"进行注册

 GAM注意力机制
原理

目标的设计是一种减少信息缩减并放大全局维度交互特征的机制。我们采用 CBAM 的顺序通道空间注意力机制并重新设计子模块。整个过程如图所示。

GAM结构图


通道注意力机制
通道注意力子模块使用 3D 排列来保留三个维度的信息。然后,它使用两层 MLP(多层感知器)放大跨维度通道空间依赖性。 (MLP是一种编码器-解码器结构,其缩减比为r,与BAM相同。)通道注意子模块如图所示。 

通道注意力子模块


空间注意力机制
在空间注意力子模块中,为了关注空间信息,我们使用两个卷积层进行空间信息融合。我们还使用与 BAM 相同的通道注意子模块的缩减率 r。同时,最大池化会减少信息并产生负面影响。我们删除池化以进一步保留特征图。因此,空间注意力模块有时会显着增加参数的数量。为了防止参数显着增加,我们在 ResNet50 中采用带有通道洗牌的组卷积。没有组卷积的空间注意力子模块如图所示。 

空间注意力子模块
代码实现

代码添加在 ./ultralytics/nn/modules/conv.py 中,同样需要在task.py中注册

class GAM_Attention(nn.Module):
    def __init__(self, c1, c2, group=True, rate=4):
        super(GAM_Attention, self).__init__()
 
        self.channel_attention = nn.Sequential(
            nn.Linear(c1, int(c1 / rate)),
            nn.ReLU(inplace=True),
            nn.Linear(int(c1 / rate), c1)
        )
 
        self.spatial_attention = nn.Sequential(
 
            nn.Conv2d(c1, c1 // rate, kernel_size=7, padding=3, groups=rate) if group else nn.Conv2d(c1, int(c1 / rate),
                                                                                                     kernel_size=7,
                                                                                                     padding=3),
            nn.BatchNorm2d(int(c1 / rate)),
            nn.ReLU(inplace=True),
            nn.Conv2d(c1 // rate, c2, kernel_size=7, padding=3, groups=rate) if group else nn.Conv2d(int(c1 / rate), c2,
                                                                                                     kernel_size=7,
                                                                                                     padding=3),
            nn.BatchNorm2d(c2)
        )
 
    def forward(self, x):
        b, c, h, w = x.shape
        x_permute = x.permute(0, 2, 3, 1).view(b, -1, c)
        x_att_permute = self.channel_attention(x_permute).view(b, h, w, c)
        x_channel_att = x_att_permute.permute(0, 3, 1, 2)
        # x_channel_att=channel_shuffle(x_channel_att,4) #last shuffle
        x = x * x_channel_att
 
        x_spatial_att = self.spatial_attention(x).sigmoid()
        x_spatial_att = channel_shuffle(x_spatial_att, 4)  # last shuffle
        out = x * x_spatial_att
        # out=channel_shuffle(out,4) #last shuffle
        return out
修改损失函数

WIoU是一种新型的损失函数,代码实现

def WIoU(cls, pred, target, self=None):
        self = self if self else cls(pred, target)
        dist = torch.exp(self.l2_center / self.l2_box.detach())
        return self._scaled_loss(dist * self.iou)

 这个其实就是修改了loss.py中的BboxLoss,在本段代码的第十二行,将type改成了WIoU

class BboxLoss(nn.Module):
 
    def __init__(self, reg_max, use_dfl=False):
        """Initialize the BboxLoss module with regularization maximum and DFL settings."""
        super().__init__()
        self.reg_max = reg_max
        self.use_dfl = use_dfl
 
    def forward(self, pred_dist, pred_bboxes, anchor_points, target_bboxes, target_scores, target_scores_sum, fg_mask):
        """IoU loss."""
        weight = target_scores.sum(-1)[fg_mask].unsqueeze(-1)
        loss,iou = bbox_iou(pred_bboxes[fg_mask], target_bboxes[fg_mask], xywh=False,type_='WIoU')
        loss_iou=loss.sum()/target_scores_sum
 
        # DFL loss
        if self.use_dfl:
            target_ltrb = bbox2dist(anchor_points, target_bboxes, self.reg_max)
            loss_dfl = self._df_loss(pred_dist[fg_mask].view(-1, self.reg_max + 1), target_ltrb[fg_mask]) * weight
            loss_dfl = loss_dfl.sum() / target_scores_sum
        else:
            loss_dfl = torch.tensor(0.0).to(pred_dist.device)
 
        return loss_iou, loss_dfl
YAML文件
# Ultralytics YOLO 🚀, AGPL-3.0 license
# YOLOv8 object detection model with P3-P5 outputs. For Usage examples see https://docs.ultralytics.com/tasks/detect

# Parameters
nc: 9  # number of classes
scales: # model compound scaling constants, i.e. 'model=yolov8n.yaml' will call yolov8.yaml with scale 'n'
  # [depth, width, max_channels]
  n: [0.33, 0.25, 1024]  # YOLOv8n summary: 225 layers,  3157200 parameters,  3157184 gradients,   8.9 GFLOPs
  s: [0.33, 0.50, 1024]  # YOLOv8s summary: 225 layers, 11166560 parameters, 11166544 gradients,  28.8 GFLOPs
  m: [0.67, 0.75, 768]   # YOLOv8m summary: 295 layers, 25902640 parameters, 25902624 gradients,  79.3 GFLOPs
  l: [1.00, 1.00, 512]   # YOLOv8l summary: 365 layers, 43691520 parameters, 43691504 gradients, 165.7 GFLOPs
  x: [1.00, 1.25, 512]   # YOLOv8x summary: 365 layers, 68229648 parameters, 68229632 gradients, 258.5 GFLOPs
 
# YOLOv8.0n backbone
backbone:
  # [from, repeats, module, args]
  - [-1, 1, Conv, [64, 3, 2]]  # 0-P1/2
  - [-1, 1, Conv, [128, 3, 2]]  # 1-P2/4
  - [-1, 3, C2f, [128, True]]
  - [-1, 1, Conv, [256, 3, 2]]  # 3-P3/8
  - [-1, 6, C2f, [256, True]]
  - [-1, 1, Conv, [512, 3, 2]]  # 5-P4/16
  - [-1, 6, C2f, [512, True]]
  - [-1, 1, Conv, [1024, 3, 2]]  # 7-P5/32
  - [-1, 3, C2f, [1024, True]]
  - [-1, 1, SPPF, [1024, 5]]  # 9
 
# YOLOv8.0n head
head:
  - [-1, 1, nn.Upsample, [None, 2, 'nearest']]
  - [[-1, 6], 1, Concat, [1]]  # cat backbone P4
  - [-1, 3, C2f, [512]]  # 12
  - [-1, 1, GAM_Attention, [512,512]]
 
  - [-1, 1, nn.Upsample, [None, 2, 'nearest']]
  - [[-1, 4], 1, Concat, [1]]  # cat backbone P3
  - [-1, 3, C2f, [256]]  # 16 (P3/8-small)
  - [-1, 1, GAM_Attention, [256,256]]
 
  - [-1, 1, Conv, [256, 3, 2]]
  - [[-1, 12], 1, Concat, [1]]  # cat head P4
  - [-1, 3, C2f, [512]]  # 20 (P4/16-medium)
  - [-1, 1, GAM_Attention, [512,512]]
 
  - [-1, 1, Conv, [512, 3, 2]]
  - [[-1, 9], 1, Concat, [1]]  # cat head P5
  - [-1, 3, C2f, [1024]]  # 24 (P5/32-large)
  - [-1, 1, GAM_Attention, [1024,1024]]
 
  - [[17, 21, 25], 1, Detect, [nc]]  # Detect(P3, P4, P5)

在head部分,可以将GAM_attention改成不同的注意力机制,来改变网络结构,从而提升目标检测 的精度

完整代码

链接: https://pan.baidu.com/s/1IDnEZxpcaEgBowlTxX2iNA?pwd=vdrs 提取码: vdrs 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/442597.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

数据结构:Heap(二叉树)的基本操作

目录 1.有关二叉树必须知道的几个基本概念 2.有关二叉树的基本操作 2.0有关元素的定义以及要进行的操作 2.1初始化和销毁操作 2.2插入操作以及上调操作 2.2.1插入操作以及上调操作的图解 2.2.2插入操作以及上调操作的代码 2.3删除根元素及其下调操作 2.3.2删除根元素及…

Pandas 之 merge

merge的作用: merge函数在Python的pandas库中的作用是用来合并两个或多个DataFrame数据表,依据指定的一个或多个键(通常是列名)进行连接操作[1]。 merge函数可以有多种连接类型(如内连接inner、左连接left、右连接ri…

【科研基础|课程】信息论

信息论课程 -上海交大 - 2020春季学期 F:\B\2.sources\CS258信息论 文章目录 1- 信息熵 | 联合熵 | 条件熵 | 链式法则1.1-Entropy1- 信息熵 | 联合熵 | 条件熵 | 链式法则 P3 信息熵 | 联合熵 | 条件熵 | 链式法则 1.1-Entropy Entropy: Brief History 熵的由来,由热力学…

【sw网络监控】通过snmp协议相关的snmp-exporter(收集交换机网络监控数据)+ promethus + grafana

本站以分享各种运维经验和运维所需要的技能为主 《python零基础入门》:python零基础入门学习 《python运维脚本》: python运维脚本实践 《shell》:shell学习 《terraform》持续更新中:terraform_Aws学习零基础入门到最佳实战 《k8…

● 309.最佳买卖股票时机含冷冻期 ● 714.买卖股票的最佳时机含手续费

● 309.最佳买卖股票时机含冷冻期 多加条件:卖出之后有一天冷冻期不能买入,即卖出之后至少隔一天才能再买入。 要搞清楚每一天有什么状态:持有股票(已买入)、不持有股票(已卖出)。不持有股票…

代码随想录 回溯算法-棋盘问题

目录 51.N皇后 37.解数独 51.N皇后 51. N 皇后 困难 按照国际象棋的规则,皇后可以攻击与之处在同一行或同一列或同一斜线上的棋子。 n 皇后问题 研究的是如何将 n 个皇后放置在 nn 的棋盘上,并且使皇后彼此之间不能相互攻击。 给你一个整数 n &…

【论文精读】融合知识图谱和语义匹配的医疗问答系统

💗💗💗欢迎来到我的博客,你将找到有关如何使用技术解决问题的文章,也会找到某个技术的学习路线。无论你是何种职业,我都希望我的博客对你有所帮助。最后不要忘记订阅我的博客以获取最新文章,也欢…

LLM 推理优化探微 (2) :Transformer 模型 KV 缓存技术详解

编者按:随着 LLM 赋能越来越多需要实时决策和响应的应用场景,以及用户体验不佳、成本过高、资源受限等问题的出现,大模型高效推理已成为一个重要的研究课题。为此,Baihai IDP 推出 Pierre Lienhart 的系列文章,从多个维…

世界上最伟大的商业模式是“让利”,总结10套消费返利玩转市场!

文丨微三云营销总监胡佳东,点击上方“关注”,为你分享市场商业模式电商干货。 - 引言:很多企业家朋友说,生意越来越难做了,市场太卷、同行价格低、难招商、难融资、难推广,其实你是不懂“人心”&#xff…

118.龙芯2k1000-pmon(17)-制作ramdisk

目前手上这个设备装系统不容易,总是需要借助虚拟机才能实现。 对生产就不太那么友好,能否不用虚拟机就能装Linux系统呢? 主要是文件系统的问题需要解决,平时我们一般是用nfs挂载后,然后对硬盘格式化,之后…

博特激光——激光打标机工作原理介绍

激光打标机,作为现代标识技术的杰出代表,其工作原理的高效与精确性使得它在众多行业中占据了举足轻重的地位。今天,我们将深入探讨激光打标机的工作原理及其背后的科技魅力。 激光打标机的工作原理主要基于激光的高能量和聚焦特性。首先&…

AI新工具 百分50%算力确达到了GPT-4水平;将音乐轨道中的人声、鼓声、贝斯等音源分离出来等

1: Pi 百分50%算力确达到了GPT-4水平 Pi 刚刚得到了巨大的升级!它现在由最新的 LLMInflection-2.5 提供支持,它在所有基准测试中都与 GPT-4 并驾齐驱,并且使用不到一半的计算来训练。 地址:https://pi.ai/ 2: Moseca 能将音乐…

【REST2SQL】12 REST2SQL增加Token生成和验证

【REST2SQL】01RDB关系型数据库REST初设计 【REST2SQL】02 GO连接Oracle数据库 【REST2SQL】03 GO读取JSON文件 【REST2SQL】04 REST2SQL第一版Oracle版实现 【REST2SQL】05 GO 操作 达梦 数据库 【REST2SQL】06 GO 跨包接口重构代码 【REST2SQL】07 GO 操作 Mysql 数据库 【RE…

阿里云有免费服务器吗?在哪领?

阿里云服务器免费试用申请链接入口:aliyunfuwuqi.com/go/free 阿里云个人用户和企业用户均可申请免费试用,最高可以免费使用3个月,阿里云服务器网分享阿里云服务器免费试用申请入口链接及云服务器配置: 阿里云免费服务器领取 阿里…

数据结构之deque双端队列

一、概念: 众所周知,数据结构是用来存储数据,deque也不例外,他是集结了队列和栈的性质而成的结构,他几乎拥有所有数据结构能有的操作,看似已经大杀四方,可实际情况如何呢,那就带者这…

ssm+vue的农业信息管理系统(有报告)。Javaee项目,ssm vue前后端分离项目。

演示视频: ssmvue的农业信息管理系统(有报告)。Javaee项目,ssm vue前后端分离项目。 项目介绍: 采用M(model)V(view)C(controller)三层体系结构&…

C语言第三十七弹---文件操作(下)

✨个人主页: 熬夜学编程的小林 💗系列专栏: 【C语言详解】 【数据结构详解】 文件操作 1、文件的随机读写 1.1、fseek 1.2、ftell 1.3、rewind 2、文件读取结束的判定 2.1、被错误使用的 feof 3、文件缓冲区 总结 1、文件的随机读写…

“轻松入门Electron:一步步构建梦想中的桌面软件

在数字化的浪潮中,桌面应用依旧占据着其独特而重要的位置,不论是在企业解决方案、专业工具软件还是个性化应用领域中都是如此。随着技术的演进,创建这些应用的过程已经变得更为简单和可行,尤其是随着Electron等框架的出现。Electr…

启动查看工具总结

启动目标:2s内优秀,2-5s普通,之后的都需要优化,热启动则是1.5s-2s内 1 看下大致串联启动流程: App 进程在 Fork 之后,需要首先执行 bindApplication Application 的环境创建好之后,就开始activ…

Dynamo——常用几何形体的创建与编辑(二)

上一次,我们简单整理了一些创建几何形体的节点用法,今天我们接着整理一些,几何形体的编辑方法。 一、坐标点的平移复制 [Point.Add] 使用节点 “Vector.ByCoordinates” 生成一个向量,将该向量连接到 “Point.Add” 节点的输入端 …