【目标检测实验系列】YOLOv5模型改进:融合混合注意力机制CBAM,关注通道和空间特征,助力模型高效涨点!(内含源代码,超详细改进代码流程)

       自我介绍:本人硕士期间全程放养,目前成果:一篇北大核心CSCD录用,两篇中科院三区已见刊,一篇中科院四区在投。如何找创新点,如何放养过程厚积薄发,如何写中英论文,找期刊等等。本人后续会以自己实战经验详细写出来,还请大家能够点个关注和赞,收藏一下,谢谢大家。

1. 文章主要内容

       本篇博客主要涉及混合(通道角度与空间角度)注意力机制CBAM融合到YOLOv5模型中。(通读本篇博客需要7分钟左右的时间)

2. 详细代码改进流程

2.1 CBAM源代码(大家自己创建CBAM.py文件)

       注意,博主在CBAM源码当中添加了C3与CBAM结合的代码,还有main函数的测试案例,不影响CBAM的单独使用。

import numpy as np
import torch
from torch import nn
from torch.nn import init

from models.common import Bottleneck, Conv


class ChannelAttention(nn.Module):
    def __init__(self, channel, reduction=16):
        super().__init__()
        self.maxpool = nn.AdaptiveMaxPool2d(1)
        self.avgpool = nn.AdaptiveAvgPool2d(1)
        self.se = nn.Sequential(
            nn.Conv2d(channel, channel // reduction, 1, bias=False),
            nn.ReLU(),
            nn.Conv2d(channel // reduction, channel, 1, bias=False)
        )
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        max_result = self.maxpool(x)
        avg_result = self.avgpool(x)
        max_out = self.se(max_result)
        avg_out = self.se(avg_result)
        output = self.sigmoid(max_out + avg_out)
        return output


class SpatialAttention(nn.Module):
    def __init__(self, kernel_size=7):
        super().__init__()
        self.conv = nn.Conv2d(2, 1, kernel_size=kernel_size, padding=kernel_size // 2)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        max_result, _ = torch.max(x, dim=1, keepdim=True)
        avg_result = torch.mean(x, dim=1, keepdim=True)
        result = torch.cat([max_result, avg_result], 1)
        output = self.conv(result)
        output = self.sigmoid(output)
        return output


class CBAMBlock(nn.Module):

    def __init__(self, channel=512, reduction=16, kernel_size=7):
        super().__init__()
        self.ca = ChannelAttention(channel=channel, reduction=reduction)
        self.sa = SpatialAttention(kernel_size=kernel_size)

    def init_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                init.kaiming_normal_(m.weight, mode='fan_out')
                if m.bias is not None:
                    init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm2d):
                init.constant_(m.weight, 1)
                init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                init.normal_(m.weight, std=0.001)
                if m.bias is not None:
                    init.constant_(m.bias, 0)

    def forward(self, x):
        b, c, _, _ = x.size()
        out = x * self.ca(x)
        out = out * self.sa(out)
        return out

class C3CBAM(nn.Module):
    def __init__(self, c1, c2, n=1, shortcut=True, g=1,
                 e=0.5):  # ch_in, ch_out, number, shortcut, groups, expansion #iscyy

        super(C3CBAM, self).__init__()
        c_ = int(c2 * e)  # hidden channels
        self.cbam = CBAMBlock(c1)
        self.cv1 = Conv(c1, c_, 1, 1)
        self.cv2 = Conv(c1, c_, 1, 1)
        self.cv3 = Conv(2 * c_, c2, 1)  # act=FReLU(c2)
        # self.m = nn.Sequential(*[CB2d(c_) for _ in range(n)])
        self.m = nn.Sequential(*[Bottleneck(c_, c_, shortcut, g, e=1.0) for _ in range(n)])

    def forward(self, x):
        out = torch.cat((self.m(self.cv1(self.cbam(x))), self.cv2(self.cbam(x))), dim=1)
        out = self.cv3(out)
        return out

if __name__ == '__main__':
    input = torch.randn(50, 512, 7, 7)
    cbam = C3CBAM(512, 512)
    output = cbam(input)
    print(output.shape)

       需要注意到: 源代码CBAMBlock类只需要传入一个输入的通道数channel,与YOLOv5的C3结构融合后,则C3CBAM需要传入输入和输出通道数,但大家仔细发现在C3CBAM的这行代码self.cbam =CBAMBlock(c1),实际的CBAM也只是需要传入输入的通道数即可。大家可以通过main函数进行测试。另外,在C3CBAM中,其中cv1和cv2方法里面的参数x都先通过了cbam注意力机制,这里大家可以自定义的设置。

2.2 建立一个yolov5-cbam.yaml文件

       注意到,这里博主直接使用C3CBAM代替Backbone部分的四个C3结构,另外注意nc改为自己数据集的类别数。当然,CBAM结构可以自由的放到网络当中的任何结构,但需要特别注意放了之后层次的更替问题,如有不懂,可以查看我之前写的一篇博客(以及评论区注意点):【目标检测实验系列】通过全局上下文注意力机制Global Context Block(GC)融合到YOLOv5案例,吃透简单即插即用注意力机制代码修改要点,举一反三!(超详细改进代码流程)

# YOLOv5 🚀 by Ultralytics, GPL-3.0 license

# Parameters
nc: 4  # number of classes
depth_multiple: 0.33  # model depth multiple
width_multiple: 0.50  # layer channel multiple
anchors:
  - [10,13, 16,30, 33,23]  # P3/8  小目标
  - [30,61, 62,45, 59,119]  # P4/16 中目标
  - [116,90, 156,198, 373,326]  # P5/32  大目标

# YOLOv5 v6.0 backbone
backbone:
  # [from, number, module, args]
  [[-1, 1, Conv, [64, 6, 2, 2]],  # 0-P1/2  output_channel, kernel_size, stride, padding
   [-1, 1, Conv, [128, 3, 2]],  # 1-P2/4
   [-1, 3, C3CBAM, [128]],
   [-1, 1, Conv, [256, 3, 2]],  # 3-P3/8
   [-1, 6, C3CBAM, [256]],
   [-1, 1, Conv, [512, 3, 2]],  # 5-P4/16
   [-1, 9, C3CBAM, [512]],
   [-1, 1, Conv, [1024, 3, 2]],  # 7-P5/32
   [-1, 3, C3CBAM, [1024]],
   [-1, 1, SPPF, [1024, 5]],  # 9
  ]

# YOLOv5 v6.0 head
head:
  [[-1, 1, Conv, [512, 1, 1]],
   [-1, 1, nn.Upsample, [None, 2, 'nearest']],
   [[-1, 6], 1, Concat, [1]],  # cat backbone P4
   [-1, 3, C3, [512, False]],  # 13

   [-1, 1, Conv, [256, 1, 1]],
   [-1, 1, nn.Upsample, [None, 2, 'nearest']],
   [[-1, 4], 1, Concat, [1]],  # cat backbone P3
   [-1, 3, C3, [256, False]],  # 17 (P3/8-small)

   [-1, 1, Conv, [256, 3, 2]],
   [[-1, 14], 1, Concat, [1]],  # cat head P4
   [-1, 3, C3, [512, False]],  # 20 (P4/16-medium)

   [-1, 1, Conv, [512, 3, 2]],
   [[-1, 10], 1, Concat, [1]],  # cat head P5
   [-1, 3, C3, [1024, False]],  # 23 (P5/32-large)
  
   [[17, 20, 23], 1, Detect, [nc, anchors]],  # Detect(P3, P4, P5)
  ]

3.3 将C3CBAM引入到yolo.py文件中

       在下图的位置处,引入相关的类即可。
在这里插入图片描述

3.4 修改train.py启动文件

       修改配置文件为yolov5-cbam.yaml即可,如下图所示:
在这里插入图片描述

4. 总结

       本篇博客主要介绍了CBAM融合到YOLOv5模型,通过关注通道和空间特征,助力模型高效涨点。另外,在修改过程中,要是有任何问题,评论区交流;如果博客对您有帮助,请帮忙点个赞,收藏一下;后续会持续更新本人实验当中觉得有用的点子,如果很感兴趣的话,可以关注一下,谢谢大家啦!

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/290462.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

ElasticSearch自定义算分排序(Function Score Query)

使用 function score query,可以修改文档的相关性算分 (query score),根据新得到的算分排序。 目录 Function Score Query 案例 Function Score Query 几种默认的计算分值的函数: Weight:为每一个文档设置一个简单而不被规范化的权重Fie…

商业综合体能源管理系统功能有哪些?能够为企业带来哪些价值?

随着IT行业的迅猛发展,仅仅实现智能控制已无法满足当前绿色低碳发展的要求,随着智能设备的增多,同时也产生了越来越多的信息孤岛。通过信息技术构建能源管理系统,可以实现先进、智能、经济、可行的商业综合体能源管理解决方案。 …

木工开料机怎么做到减少加工误差的?

在板式家具生产加工中使用木工开料机加工,不仅仅可以提升板材的加工精度,还可以减少板材加工误差,减少板材浪费,所以整个开料环节非常的重要,所以在选择设备的时候一定要谨慎,选择质量可靠的设备。 木工开料…

ssm基于web的网络游戏交易平台信息管理系统的设计与实现论文

摘 要 传统办法管理信息首先需要花费的时间比较多,其次数据出错率比较高,而且对错误的数据进行更改也比较困难,最后,检索数据费事费力。因此,在计算机上安装网络游戏交易平台软件来发挥其高效地信息处理的作用&#x…

HTML5+CSS3+Vue小实例:彩色圆环溶解加载动画

实例:彩色圆环溶解加载动画 技术栈:HTML+CSS+Vue.js 效果: 源码: 【HTML】 <!DOCTYPE html> <html lang="zh-CN"> <head><meta charset="UTF-8"><meta http-equiv="X-UA-Compatible" content="IE=edge&…

鸿蒙(HarmonyOS)项目方舟框架(ArkUI)之线性布局容器Column组件

鸿蒙&#xff08;HarmonyOS&#xff09;项目方舟框架&#xff08;ArkUI&#xff09;之线性布局容器Column组件 一、操作环境 操作系统: Windows 10 专业版、IDE:DevEco Studio 3.1、SDK:HarmonyOS 3.1 二、Column组件 沿垂直方向布局的容器。 子组件 可以包含子组件。 接…

力扣hot100 二叉树的右视图 递归

&#x1f468;‍&#x1f3eb; 题目地址 时间复杂度&#xff1a; O ( n ) O(n) O(n) &#x1f60b; AC code /*** Definition for a binary tree node.* public class TreeNode {* int val;* TreeNode left;* TreeNode right;* TreeNode() {}* TreeNode(…

计算机毕业设计------SSM邮件收发管理系统

项目介绍 本项目为基于SSM的邮件收发管理系统&#xff1b; 用户角色包含以下功能&#xff1a; 用户登录,写信给好友,查看收件箱,查看已发送的邮件,草稿箱查看,通讯录设置,个人资料管理等功能。 环境需要 1.运行环境&#xff1a;最好是java jdk 1.8&#xff0c;我们在这个平…

Golang-strconv库学习笔记

前言&#xff1a; strconv库是go官方提供的一个标准包&#xff0c;主要用于字符串相关的处理。通过参考官方文档、中文文档和其他工具&#xff0c;进行学习记录。学习重点是其中的内置方法。 本文分为Atoi&#xff0c;Format系列&#xff0c;Parse系列&#xff0c;Append系列,…

SparkSQL基础解析(三)

1、 Spark SQL概述 1.1什么是Spark SQL Spark SQL是Spark用来处理结构化数据的一个模块&#xff0c;它提供了2个编程抽象&#xff1a;DataFrame和 DataSet&#xff0c;并且作为分布式SQL查询引擎的作用。 我们已经学习了Hive&#xff0c;它是将Hive SQL转换成MapReduce然后提…

众和策略安全可靠:如何区分大盘股和小盘股?大盘股指数有哪些?

股票根据流转盘能够划分为大盘股、中盘股、小盘股这三大类&#xff0c;那么&#xff0c;怎么区别大盘股和小盘股&#xff1f;大盘股指数有哪些&#xff1f;下面众和策略为我们准备了相关内容&#xff0c;以供参考。 怎么区别大盘股和小盘股&#xff1f; 根据个股的市值来区别…

希亦、觉飞、小吉三款婴儿洗衣机大比拼!全方位对比测评

由于年龄幼小的婴儿的皮肤都非常的幼嫩&#xff0c;因此婴儿衣物材质的类型大部分都是采用为纯棉&#xff0c;并且婴儿的衣物不能够与大人的衣物一起进行混洗&#xff0c;容易把细菌感染到宝宝的衣物上&#xff0c;因此很多家庭为了保证宝宝衣服的有效清洁&#xff0c;避免交叉…

自动化测试用例设计粒度

自动化测试用例的粒度指的是测试用例的细致程度&#xff0c;即每个测试用例检查的功能点的数量和范围。 通常&#xff0c;根据测试用例的粒度&#xff0c;可以被分为3种不同的层次&#xff0c;从更低层次的细粒度到更高层次的粗粒度。 第一种&#xff1a;单元测试 - 细粒度 …

使用IDEA创建使用 JDK8 的 2.x.x 版本的 Spring Boot 项目以及 Spring Boot 项目如何修改JDK版本

目录 一、在阿里云上官网上创建项目 二、将 IDEA 中创建项目的源地址修改为阿里云官网 三、创建 3.x.x 的项目之后修改配置降低至 2.7.x 版本和使用 JDK8&#xff08;修改 Spring Boot 的 JDK 版本同理&#xff09; 从上面的 Spring Boot 官网的截图中可以发现&#xff0c;自…

【算法挨揍日记】day34——647. 回文子串、5. 最长回文子串

647. 回文子串 647. 回文子串 题目描述&#xff1a; 给你一个字符串 s &#xff0c;请你统计并返回这个字符串中 回文子串 的数目。 回文字符串 是正着读和倒过来读一样的字符串。 子字符串 是字符串中的由连续字符组成的一个序列。 具有不同开始位置或结束位置的子串&am…

1 电科院FTU检测标准学习笔记-外观检查

作者简介&#xff1a; 本人从事电力系统多年&#xff0c;岗位包含研发&#xff0c;测试&#xff0c;工程等&#xff0c;具有丰富的经验 在配电自动化验收测试以及电科院测试中&#xff0c;本人全程参与&#xff0c;积累了不少现场的经验 目录 **前言****检测大纲****外观与结构…

docker (portainer 安装nginx)

一、创建容器 二、配置端口&#xff0c;以及容器卷挂载 挂载目录配置&#xff1a;(下方截图的目录如下&#xff0c;docker 改为 mydocker&#xff0c;用docker作为根目录不行) /mydocker/nginx/html /usr/share/nginx/html /mydocker/nginx/conf/nginx.conf /etc/nginx/ng…

18|CAMEL:通过角色扮演脑暴一个鲜花营销方案

18&#xff5c;CAMEL&#xff1a;通过角色扮演脑暴一个鲜花营销方案 CAMEL 交流式代理框架 下面我们一起来看看 CAMEL——这个多 AI 通过角色扮演进行交互的框架&#xff0c;以及它在 LangChain 中的具体实现。 CAMEL&#xff0c;字面意思是骆驼。这个框架来自于论文《CAMEL:…

企业微信开发:自建应用:应用形态(网页,小程序,默认页面)

概述 问题&#xff1a; 企业微信&#xff0c;自建应用&#xff0c;应该实现成什么样子&#xff1f;应用里是一个网页应用吗&#xff1f; 企业微信自建应用可以实现为多种形态&#xff0c;根据实际需求和功能设计&#xff0c;它可以是一个网页应用、一个小程序或者结合企业微信提…

Elasticsearch8集群部署

转载说明&#xff1a;如果您喜欢这篇文章并打算转载它&#xff0c;请私信作者取得授权。感谢您喜爱本文&#xff0c;请文明转载&#xff0c;谢谢。 本文记录在3台服务器上离线搭建es8.7.1版本集群。 1. 修改系统配置 1.1 hosts配置 在三台es节点服务器加入hostname解析&…