yolov5增加AFPN-全新特征融合模块AFPN,效果完胜PAFPN

论文学习:AFPN: Asymptotic Feature Pyramid Network for Object Detection-全新特征融合模块AFPN,完胜PAFPN_athrunsunny的博客-CSDN博客

先上配置文件yolov5s-AFPN.yaml

# YOLOv5 🚀 by Ultralytics, AGPL-3.0 license

# Parameters
nc: 80  # number of classes
depth_multiple: 0.33  # model depth multiple
width_multiple: 0.50  # layer channel multiple
anchors:
  - [10,13, 16,30, 33,23]  # P3/8
  - [30,61, 62,45, 59,119]  # P4/16
  - [116,90, 156,198, 373,326]  # P5/32

# YOLOv5 v6.0 backbone
backbone:
  # [from, number, module, args]
  [[-1, 1, Conv, [64, 6, 2, 2]],  # 0-P1/2
   [-1, 1, Conv, [128, 3, 2]],  # 1-P2/4
   [-1, 3, C3, [128]],
   [-1, 1, Conv, [256, 3, 2]],  # 3-P3/8
   [-1, 6, C3, [256]],
   [-1, 1, Conv, [512, 3, 2]],  # 5-P4/16
   [-1, 9, C3, [512]],
   [-1, 1, Conv, [1024, 3, 2]],  # 7-P5/32
   [-1, 3, C3, [1024]],
   [-1, 1, SPPF, [1024, 5]],  # 9
  ]

# YOLOv5 v6.0 head
head:
  [[4, 1, Conv, [64, 1, 1]], 
   [6, 1, Conv, [128, 1, 1]], 

   [[10, 11], 1, ASFF_2, [64, 0]], 
   [[10, 11], 1, ASFF_2, [128, 1]], 

   [-2, 1, C3, [64, False]], 
   [-2, 1, C3, [128, False]], 

   [9, 1, Conv, [256, 1, 1]],

   [[14, 15, 16], 1, ASFF_3, [64, 0]],
   [[14, 15, 16], 1, ASFF_3, [128, 1]],
   [[14, 15, 16], 1, ASFF_3, [256, 2]],

   [17, 1, C3, [64, False]],
   [18, 1, C3, [128, False]],
   [19, 1, C3, [256, False]],
   [[20, 21, 22], 1, Detect, [nc, anchors]]
]

在models/common.py增加

class Upsample(nn.Module):
    def __init__(self, in_channels, out_channels, scale_factor=2):
        super(Upsample, self).__init__()

        self.upsample = nn.Sequential(
            Conv(in_channels, out_channels, 1),
            nn.Upsample(scale_factor=scale_factor, mode='bilinear')
        )

        # carafe
        # from mmcv.ops import CARAFEPack
        # self.upsample = nn.Sequential(
        #     BasicConv(in_channels, out_channels, 1),
        #     CARAFEPack(out_channels, scale_factor=scale_factor)
        # )

    def forward(self, x):
        x = self.upsample(x)

        return x

class Downsample(nn.Module):
    def __init__(self, in_channels, out_channels,scale_factor=2):
        super(Downsample, self).__init__()

        self.downsample = nn.Sequential(
            Conv(in_channels, out_channels, scale_factor, scale_factor, 0)
        )

    def forward(self, x):
        x = self.downsample(x)

        return x

class ASFF_2(nn.Module):
    def __init__(self, inter_dim=512,level=0,channel=[64,128]):
        super(ASFF_2, self).__init__()

        self.inter_dim = inter_dim
        compress_c = 8

        self.weight_level_1 = Conv(self.inter_dim, compress_c, 1, 1)
        self.weight_level_2 = Conv(self.inter_dim, compress_c, 1, 1)

        self.weight_levels = nn.Conv2d(compress_c * 2, 2, kernel_size=1, stride=1, padding=0)

        self.conv = Conv(self.inter_dim, self.inter_dim, 3, 1)
        self.upsample = Upsample(channel[1],channel[0])
        self.downsample = Downsample(channel[0],channel[1])
        self.level = level


    def forward(self, x):
        input1, input2 = x
        if self.level == 0:
            input2 = self.upsample(input2)
        elif self.level == 1:
            input1 = self.downsample(input1)

        level_1_weight_v = self.weight_level_1(input1)
        level_2_weight_v = self.weight_level_2(input2)

        levels_weight_v = torch.cat((level_1_weight_v, level_2_weight_v), 1)
        levels_weight = self.weight_levels(levels_weight_v)
        levels_weight = F.softmax(levels_weight, dim=1)

        fused_out_reduced = input1 * levels_weight[:, 0:1, :, :] + \
                            input2 * levels_weight[:, 1:2, :, :]

        out = self.conv(fused_out_reduced)

        return out


class ASFF_3(nn.Module):
    def __init__(self, inter_dim=512,level=0,channel=[64,128,256]):
        super(ASFF_3, self).__init__()

        self.inter_dim = inter_dim
        compress_c = 8

        self.weight_level_1 = Conv(self.inter_dim, compress_c, 1, 1)
        self.weight_level_2 = Conv(self.inter_dim, compress_c, 1, 1)
        self.weight_level_3 = Conv(self.inter_dim, compress_c, 1, 1)

        self.weight_levels = nn.Conv2d(compress_c * 3, 3, kernel_size=1, stride=1, padding=0)

        self.conv = Conv(self.inter_dim, self.inter_dim, 3, 1)


        self.level = level
        if self.level == 0:
            self.upsample4x = Upsample(channel[2],channel[0], scale_factor=4)
            self.upsample2x = Upsample(channel[1], channel[0], scale_factor=2)
        elif self.level == 1:
            self.upsample2x1 = Upsample(channel[2], channel[1], scale_factor=2)
            self.downsample2x1 = Downsample(channel[0],channel[1], scale_factor=2)
        elif self.level == 2:
            self.downsample2x = Downsample(channel[1], channel[2], scale_factor=2)
            self.downsample4x = Downsample(channel[0], channel[2], scale_factor=4)

    def forward(self, x):
        input1, input2, input3 = x
        if self.level == 0:
            input2 = self.upsample2x(input2)
            input3= self.upsample4x(input3)
        elif self.level == 1:
            input3 = self.upsample2x1(input3)
            input1 = self.downsample2x1(input1)
        elif self.level == 2:
            input1 = self.downsample4x(input1)
            input2 = self.downsample2x(input2)
        level_1_weight_v = self.weight_level_1(input1)
        level_2_weight_v = self.weight_level_2(input2)
        level_3_weight_v = self.weight_level_3(input3)

        levels_weight_v = torch.cat((level_1_weight_v, level_2_weight_v, level_3_weight_v), 1)
        levels_weight = self.weight_levels(levels_weight_v)
        levels_weight = F.softmax(levels_weight, dim=1)

        fused_out_reduced = input1 * levels_weight[:, 0:1, :, :] + \
                            input2 * levels_weight[:, 1:2, :, :] + \
                            input3 * levels_weight[:, 2:, :, :]

        out = self.conv(fused_out_reduced)

        return out

 在models/yolo.py中修改:

def parse_model(d, ch):  # model_dict, input_channels(3)
    # Parse a YOLOv5 model.yaml dictionary
    LOGGER.info(f"\n{'':>3}{'from':>18}{'n':>3}{'params':>10}  {'module':<40}{'arguments':<30}")
    anchors, nc, gd, gw, act = d['anchors'], d['nc'], d['depth_multiple'], d['width_multiple'], d.get('activation')
    if act:
        Conv.default_act = eval(act)  # redefine default activation, i.e. Conv.default_act = nn.SiLU()
        LOGGER.info(f"{colorstr('activation:')} {act}")  # print
    na = (len(anchors[0]) // 2) if isinstance(anchors, list) else anchors  # number of anchors
    no = na * (nc + 5)  # number of outputs = anchors * (classes + 5)

    layers, save, c2 = [], [], ch[-1]  # layers, savelist, ch out
    for i, (f, n, m, args) in enumerate(d['backbone'] + d['head']):  # from, number, module, args
        m = eval(m) if isinstance(m, str) else m  # eval strings
        for j, a in enumerate(args):
            with contextlib.suppress(NameError):
                args[j] = eval(a) if isinstance(a, str) else a  # eval strings

        n = n_ = max(round(n * gd), 1) if n > 1 else n  # depth gain
        if m in {
                Conv, GhostConv, Bottleneck, GhostBottleneck, SPP, SPPF, DWConv, MixConv2d, Focus, CrossConv,
                BottleneckCSP, C3, C3TR, C3SPP, C3Ghost, nn.ConvTranspose2d, DWConvTranspose2d, C3x}:
            c1, c2 = ch[f], args[0]
            if c2 != no:  # if not output
                c2 = make_divisible(c2 * gw, 8)

            args = [c1, c2, *args[1:]]
            if m in {BottleneckCSP, C3, C3TR, C3Ghost, C3x}:
                args.insert(2, n)  # number of repeats
                n = 1
        elif m is nn.BatchNorm2d:
            args = [ch[f]]
        elif m is Concat:
            c2 = sum(ch[x] for x in f)
        # TODO: channel, gw, gd
        elif m in {Detect, Segment}:
            args.append([ch[x] for x in f])
            if isinstance(args[1], int):  # number of anchors
                args[1] = [list(range(args[1] * 2))] * len(f)
            if m is Segment:
                args[3] = make_divisible(args[3] * gw, 8)
        elif m is Contract:
            c2 = ch[f] * args[0] ** 2
        elif m is Expand:
            c2 = ch[f] // args[0] ** 2
        elif m in {ASFF_2, ASFF_3}:
            c2 = args[0]
            if c2 != no:  # if not output
                c2 = make_divisible(c2 * gw, 8)
            args[0] = c2
            args.append([ch[x] for x in f])
        else:
            c2 = ch[f]
        a = [*args]
        m_ = nn.Sequential(*(m(*args) for _ in range(n))) if n > 1 else m(*args)  # module
        t = str(m)[8:-2].replace('__main__.', '')  # module type
        np = sum(x.numel() for x in m_.parameters())  # number params
        m_.i, m_.f, m_.type, m_.np = i, f, t, np  # attach index, 'from' index, type, number params
        LOGGER.info(f'{i:>3}{str(f):>18}{n_:>3}{np:10.0f}  {t:<40}{str(args):<30}')  # print
        save.extend(x % i for x in ([f] if isinstance(f, int) else f) if x != -1)  # append to savelist
        layers.append(m_)
        if i == 0:
            ch = []
        ch.append(c2)
    return nn.Sequential(*layers), sorted(save)

 在yolo.py中配置--cfg为yolov5s-AFPN.yaml,点击运行,可见下图:

        论文中提到使用AFPN的效果要比PAN的好,暂时还没有验证,先肝代码,这是最初版,后续会优化。可以看最上面的图,参数确实是少了。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/34814.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

OpenCV:深入Feature2D组件——角点检测

角点检测 1 Harris角点检测1.1 兴趣点与角点1.2 角点检测1.3 harris角点检测1.4 实现harris角点检测&#xff1a;cornerHarris()函数1.5 综合案例&#xff1a;harris角点检测与测绘 2. Shi—Tomasi角点检测2.1Shi—Tomasi角点检测概述2.2 确定图像强角点&#xff1a;goodFeatur…

实时包裹信息同步:WebSocket 在 Mendix 中的应用

场景介绍 在现代物流中&#xff0c;能够实时跟踪包裹信息&#xff0c;尤其是包裹重量&#xff0c;是非常重要的。在这种场景中&#xff0c;我们可以使用称重设备获取包裹的信息&#xff0c;然后实时将这些信息同步给 Mendix 开发的 App&#xff0c;并在 App 的页面上实时显示包…

用git下载gitee上的项目资源

目录 用git下载gitee上的项目资源 用git 的clone 命令 然后到gitee上复制相关的下载地址&#xff1a; 粘贴到clone后面即可&#xff08;注意地址与clone之间有空格&#xff01;&#xff01;&#xff01;&#xff09; 运行结果&#xff1a; 用git下载gitee上的项目资源 用git…

MySQL安装与部署

第一种方法&#xff1a;在线安装 配置一个安装yum源 Adding the MySQL Yum Repository 可以手动配置yum源&#xff0c;baseurl指向国内镜像源地址&#xff0c;比如清华、中科大。 Installing MySQL Starting the MySQL Server&#xff1a; 查询临时登录密码 修改数据库密码…

golang 结构体struct转map实践

1、反射 type sign struct { Name string json:"name,omitempty" Age int json:"age,omitempty" } var s sign s.Name "csdn" s.Age 18 //方式1 反射 var data make(map[string]interface{}) t : reflect.TypeOf(s) v : …

Spring Bean的实例化过程

一、前言 对于写Java的程序员来说&#xff0c;Spring已经成为了目前最流行的第三方开源框架之一&#xff0c;在我们充分享受Spring IOC容器带来的红利的同时&#xff0c;我们也应该考虑一下Spring这个大工厂是如何将一个个的Bean生产出来的&#xff0c;本期我们就一起来讨论一…

2023年第三届工业自动化、机器人与控制工程国际会议

会议简介 Brief Introduction 2023年第三届工业自动化、机器人与控制工程国际会议&#xff08;IARCE 2023&#xff09; 会议时间&#xff1a;2023年10月27 -30日 召开地点&#xff1a;中国成都 大会官网&#xff1a;www.iarce.org 2023年第三届工业自动化、机器人与控制工程国际…

Redis通信协议

RESP协议 Redis是一个CS架构的软件&#xff0c;通信一般分两步&#xff08;不包括pipeline和PubSub&#xff09;&#xff1a; ① 客户端&#xff08;client&#xff09;向服务端&#xff08;server&#xff09;发送一条命令 ② 服务端解析并执行命令&#xff0c;返回响应结果…

Spring MVC各种参数进行封装

目录 一、简单数据类型 1.1 控制器方法 1.2 测试结果 二、对象类型 2.1 单个对象 2.1.1 控制器方法 2.1.2 测试结果 2.2 关联对象 2.2.1 控制器方法 2.2.2 测试结果 三、集合类型 3.1 简单数据类型集合 3.1.1 控制方法 3.1.2 测试结果 3.2 对象数据类型集合 3.…

使用MQL4编写自己的交易策略:技巧与经验分享

随着技术的发展&#xff0c;越来越多的投资者开始使用程序化交易系统进行交易&#xff0c;其中MQL4语言是广泛应用于MetaTrader 4平台上编写交易策略的一种语言。本文将分享一些技巧和经验&#xff0c;帮助读者利用MQL4编写自己的交易策略。 策略开发流程 首先&#xff0c;我…

传输控制协议 TCP

文章目录 一、TCP报文格式1.报头格式2.TCP最大段长度 MSS 二、TCP连接建立与释放1.连接建立&#xff1a;三次握手2.报文传输3.连接释放&#xff1a;四次挥手4.保持定时器与时间等待定时器 三、TCP差错重传1.字节流状态分类与滑动窗口&#xff08;发送&#xff09;① 滑动窗口两…

Android Studio实现内容丰富的安卓博客发布平台

如需源码可以添加q-------3290510686&#xff0c;也有演示视频演示具体功能&#xff0c;源码不免费&#xff0c;尊重创作&#xff0c;尊重劳动。 项目编号078 1.开发环境 android stuido jdk1.8 eclipse mysql tomcat 2.功能介绍 安卓端&#xff1a; 1.注册登录 2.查看博客列表…

[AJAX]原生AJAX——自定义请求头

客户端 <script>// 1、创建对象const xhr new XMLHttpRequest();// 2、初始化&#xff1a;设置请求类型和urlxhr.open(POST, http://127.0.0.1:8000/server);// 设置请求头// Content-Type&#xff1a;设置请求体内容类型// application/x-www-form-urlencoded&#xf…

2022(二等奖)C2464植物保护管理系统

作品介绍 一、需求分析 1. 应用背景 森林是陆地生态系统的主体&#xff0c;是人类生存与发展的物质基础。以森林为主要经营对象的林业&#xff0c;不仅承担着生态建设的主要任务&#xff0c;而且承担着提供多种林产品的重大使命。进入21世纪&#xff0c;人类正在继农业文明和…

二进制、十进制相互转换

二进制转十进制&#xff1a; 1100 0000转为十进制的数值为&#xff1a;12864192 十进制转二进制&#xff1a; 列如&#xff1a;十进制数为202 1286432168421二进制11001010 解析&#xff1a; 202>128&#xff0c;第一个二进制数为&#xff1a;1 202-128>64&#xf…

Spring 事务管理方案和事务管理器及事务控制的API

目录 一、事务管理方案 1. 修改业务层代码 2. 测试 二、事务管理器 1. 简介 2. 在配置文件中引入约束 3. 进行事务配置 三、事务控制的API 1. PlatformTransactionManager接口 2. TransactionDefinition接口 3. TransactionStatus接口 往期专栏&文章相关导读 …

【Lua】ZeroBrane Studio免费专业IDE使用详解

▒ 目录 ▒ &#x1f6eb; 问题描述环境 1️⃣ IDE界面说明项目目录编辑器控制台窗口输出窗口选择解释器堆栈窗口监视窗口大纲窗口 2️⃣ 调试程序3️⃣ 自定义lua解释器编译自己的lua解释器增加interpreters配置文件重启IDE 4️⃣ 其它IDE比较Lua EditorVSCode &#x1f6ec; …

Redis:redis基于各大实战场景下的基本使用

文章目录 前言String 命令实战1.业务缓存对应redis中的指令伪代码 2.分布式锁对应redis中的指令伪代码 3.限流对应redis中的指令伪代码 List 命令实战1.提醒功能对应Redis中的指令伪代码 2.热点列表对应Redis中的指令伪代码 Hash 命令实战1.用户资料缓存对应redis中的指令伪代码…

算法设计与分析 课程期末复习简记

目录 网络流 线性规划 回溯算法 分支限界 贪心算法 动态规划 分治算法 算法复杂度分析 相关概念 网络流 下面是本章需要掌握的知识 • 流量⽹络的相关概念 • 最⼤流的概念 • 最⼩割集合的概念 • Dinic有效算法的步骤 • 会⼿推⼀个流量⽹络的最⼤流 下面对此依次进行复…

数据结构--串的定义和基本操作

数据结构–串的定义和基本操作 注:数据结构三要素――逻辑结构、数据的运算、存储结构&#xff08;物理结构) 存储结构不同&#xff0c;运算的实现方式不同 \color{pink}存储结构不同&#xff0c;运算的实现方式不同 存储结构不同&#xff0c;运算的实现方式不同 串的定义 串 …