YOLOv8改进教程|加入可改变核卷积AKConv模块,效果远超DSConv!


⭐⭐ YOLOv8改进专栏|包含主干、模块、注意力机制、检测头等前沿创新 ​ ⭐⭐


 一、 论文介绍

        论文链接:https://arxiv.org/abs/2311.11587

        代码链接:GitHub - CV-ZhangXin/AKConv

论文速览::AKConv是2023年11月发表的一种可变卷积核,赋予卷积核任意数量的参数和任意采样形状,以解决具有固定样本形状和正方形的卷积核不能很好地适应不断变化的目标的问题点可以为网络开销和性能之间的权衡提供更丰富的选择。 AKConv的核心思想在于它为卷积核提供了任意数量的参数和任意采样形状,能够使用任意数量的参数(如1,2,3,4,5,6,7等)来提取姝征,这在标准卷积和可变形卷积中并未实现。AKConv能够根据硬件环境,使卷积参数的数星呈线性增减((非常适用于轻量化模型)。

总结:AKConv是一种具有任意数量的参数和任意采样形状的可变卷积核,对不规则特征有更好的提取效果。


二、 加入到RT-DETR中

2.1 复制代码

        复制代码粘到ultralytics->nn->modules->conv.py文件中,在顶部导入torch.nn.functional包,(torch.nn.functional as F),将代码粘贴于下方,并在__all__中声明,如下图所示:

# Ultralytics YOLO 🚀, AGPL-3.0 license
"""Convolution modules."""

import math

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange


__all__ = (
    "Conv",
    "Conv2",
    "LightConv",
    "DWConv",
    "DWConvTranspose2d",
    "ConvTranspose",
    "Focus",
    "GhostConv",
    "ChannelAttention",
    "SpatialAttention",
    "CBAM",
    "Concat",
    "RepConv",
    "AKConv",
)


class AKConv(nn.Module):
    def __init__(self, inc, outc, num_param, stride=1, bias=None):
        super(AKConv, self).__init__()
        self.num_param = num_param
        self.stride = stride
        self.conv = nn.Sequential(nn.Conv2d(inc, outc, kernel_size=(num_param, 1), stride=(num_param, 1), bias=bias),
                                  nn.BatchNorm2d(outc),
                                  nn.SiLU())  # the conv adds the BN and SiLU to compare original Conv in YOLOv5.
        self.p_conv = nn.Conv2d(inc, 2 * num_param, kernel_size=3, padding=1, stride=stride)
        nn.init.constant_(self.p_conv.weight, 0)
        self.p_conv.register_full_backward_hook(self._set_lr)

    @staticmethod
    def _set_lr(module, grad_input, grad_output):
        grad_input = (grad_input[i] * 0.1 for i in range(len(grad_input)))
        grad_output = (grad_output[i] * 0.1 for i in range(len(grad_output)))

    def forward(self, x):
        # N is num_param.
        offset = self.p_conv(x)
        dtype = offset.data.type()
        N = offset.size(1) // 2
        # (b, 2N, h, w)
        p = self._get_p(offset, dtype)

        # (b, h, w, 2N)
        p = p.contiguous().permute(0, 2, 3, 1)
        q_lt = p.detach().floor()
        q_rb = q_lt + 1

        q_lt = torch.cat([torch.clamp(q_lt[..., :N], 0, x.size(2) - 1), torch.clamp(q_lt[..., N:], 0, x.size(3) - 1)],
                         dim=-1).long()
        q_rb = torch.cat([torch.clamp(q_rb[..., :N], 0, x.size(2) - 1), torch.clamp(q_rb[..., N:], 0, x.size(3) - 1)],
                         dim=-1).long()
        q_lb = torch.cat([q_lt[..., :N], q_rb[..., N:]], dim=-1)
        q_rt = torch.cat([q_rb[..., :N], q_lt[..., N:]], dim=-1)

        # clip p
        p = torch.cat([torch.clamp(p[..., :N], 0, x.size(2) - 1), torch.clamp(p[..., N:], 0, x.size(3) - 1)], dim=-1)

        # bilinear kernel (b, h, w, N)
        g_lt = (1 + (q_lt[..., :N].type_as(p) - p[..., :N])) * (1 + (q_lt[..., N:].type_as(p) - p[..., N:]))
        g_rb = (1 - (q_rb[..., :N].type_as(p) - p[..., :N])) * (1 - (q_rb[..., N:].type_as(p) - p[..., N:]))
        g_lb = (1 + (q_lb[..., :N].type_as(p) - p[..., :N])) * (1 - (q_lb[..., N:].type_as(p) - p[..., N:]))
        g_rt = (1 - (q_rt[..., :N].type_as(p) - p[..., :N])) * (1 + (q_rt[..., N:].type_as(p) - p[..., N:]))

        # resampling the features based on the modified coordinates.
        x_q_lt = self._get_x_q(x, q_lt, N)
        x_q_rb = self._get_x_q(x, q_rb, N)
        x_q_lb = self._get_x_q(x, q_lb, N)
        x_q_rt = self._get_x_q(x, q_rt, N)

        # bilinear
        x_offset = g_lt.unsqueeze(dim=1) * x_q_lt + \
                   g_rb.unsqueeze(dim=1) * x_q_rb + \
                   g_lb.unsqueeze(dim=1) * x_q_lb + \
                   g_rt.unsqueeze(dim=1) * x_q_rt

        x_offset = self._reshape_x_offset(x_offset, self.num_param)
        out = self.conv(x_offset)

        return out

    # generating the inital sampled shapes for the AKConv with different sizes.
    def _get_p_n(self, N, dtype):
        base_int = round(math.sqrt(self.num_param))
        row_number = self.num_param // base_int
        mod_number = self.num_param % base_int
        p_n_x, p_n_y = torch.meshgrid(
            torch.arange(0, row_number),
            torch.arange(0, base_int))
        p_n_x = torch.flatten(p_n_x)
        p_n_y = torch.flatten(p_n_y)
        if mod_number > 0:
            mod_p_n_x, mod_p_n_y = torch.meshgrid(
                torch.arange(row_number, row_number + 1),
                torch.arange(0, mod_number))

            mod_p_n_x = torch.flatten(mod_p_n_x)
            mod_p_n_y = torch.flatten(mod_p_n_y)
            p_n_x, p_n_y = torch.cat((p_n_x, mod_p_n_x)), torch.cat((p_n_y, mod_p_n_y))
        p_n = torch.cat([p_n_x, p_n_y], 0)
        p_n = p_n.view(1, 2 * N, 1, 1).type(dtype)
        return p_n

    # no zero-padding
    def _get_p_0(self, h, w, N, dtype):
        p_0_x, p_0_y = torch.meshgrid(
            torch.arange(0, h * self.stride, self.stride),
            torch.arange(0, w * self.stride, self.stride))

        p_0_x = torch.flatten(p_0_x).view(1, 1, h, w).repeat(1, N, 1, 1)
        p_0_y = torch.flatten(p_0_y).view(1, 1, h, w).repeat(1, N, 1, 1)
        p_0 = torch.cat([p_0_x, p_0_y], 1).type(dtype)

        return p_0

    def _get_p(self, offset, dtype):
        N, h, w = offset.size(1) // 2, offset.size(2), offset.size(3)

        # (1, 2N, 1, 1)
        p_n = self._get_p_n(N, dtype)
        # (1, 2N, h, w)
        p_0 = self._get_p_0(h, w, N, dtype)
        p = p_0 + p_n + offset
        return p

    def _get_x_q(self, x, q, N):
        b, h, w, _ = q.size()
        padded_w = x.size(3)
        c = x.size(1)
        # (b, c, h*w)
        x = x.contiguous().view(b, c, -1)

        # (b, h, w, N)
        index = q[..., :N] * padded_w + q[..., N:]  # offset_x*w + offset_y
        # (b, c, h*w*N)
        index = index.contiguous().unsqueeze(dim=1).expand(-1, c, -1, -1, -1).contiguous().view(b, c, -1)

        x_offset = x.gather(dim=-1, index=index).contiguous().view(b, c, h, w, N)

        return x_offset

    #  Stacking resampled features in the row direction.
    @staticmethod
    def _reshape_x_offset(x_offset, num_param):
        b, c, h, w, n = x_offset.size()
        # using Conv3d
        # x_offset = x_offset.permute(0,1,4,2,3), then Conv3d(c,c_out, kernel_size =(num_param,1,1),stride=(num_param,1,1),bias= False)
        # using 1 × 1 Conv
        # x_offset = x_offset.permute(0,1,4,2,3), then, x_offset.view(b,c×num_param,h,w)  finally, Conv2d(c×num_param,c_out, kernel_size =1,stride=1,bias= False)
        # using the column conv as follow, then, Conv2d(inc, outc, kernel_size=(num_param, 1), stride=(num_param, 1), bias=bias)

        x_offset = rearrange(x_offset, 'b c h w n -> b c (h n) w')
        return x_offset

2.2 更改modules.__init__.py文件 

       打开ultralytics->nn->modules->__init__.py,在第64行与81行加入AKConv进行声明。

​2.3 更改task.py文件 

        打开ultralytics->nn路径下的tasks.py文件,首先在第51行加入AKConv导入模块,然后在第928行(或其他合适的位置)加入下方代码:

        elif m is AKConv:
            c2 = args[0]
            c1 = ch[f]
            args = [c1, c2, *args[1:]]

 2.4 更改yaml文件 

        创建yaml文件,使用AKConv替换yaml文件中原有的Conv模块。

# Ultralytics YOLO 🚀, AGPL-3.0 license
# YOLOv8 object detection model with P3-P5 outputs. For Usage examples see https://docs.ultralytics.com/tasks/detect

# Parameters
nc: 80 # number of classes
scales: # model compound scaling constants, i.e. 'model=yolov8n.yaml' will call yolov8.yaml with scale 'n'
  # [depth, width, max_channels]
  n: [0.33, 0.25, 1024] # YOLOv8n summary: 225 layers,  3157200 parameters,  3157184 gradients,   8.9 GFLOPs
  s: [0.33, 0.50, 1024] # YOLOv8s summary: 225 layers, 11166560 parameters, 11166544 gradients,  28.8 GFLOPs
  m: [0.67, 0.75, 768] # YOLOv8m summary: 295 layers, 25902640 parameters, 25902624 gradients,  79.3 GFLOPs
  l: [1.00, 1.00, 512] # YOLOv8l summary: 365 layers, 43691520 parameters, 43691504 gradients, 165.7 GFLOPs
  x: [1.00, 1.25, 512] # YOLOv8x summary: 365 layers, 68229648 parameters, 68229632 gradients, 258.5 GFLOPs

# YOLOv8.0n backbone
backbone:
  # [from, repeats, module, args]
  - [-1, 1, Conv, [64, 3, 2]] # 0-P1/2
  - [-1, 1, Conv, [128, 3, 2]] # 1-P2/4
  - [-1, 3, C2f, [128, True]]
  - [-1, 1, Conv, [256, 3, 2]] # 3-P3/8
  - [-1, 6, C2f, [256, True]]
  - [-1, 1, Conv, [512, 3, 2]] # 5-P4/16
  - [-1, 6, C2f, [512, True]]
  - [-1, 1, Conv, [1024, 3, 2]] # 7-P5/32
  - [-1, 1, AKConv, [256, 3]]
  - [-1, 1, SPPF, [1024, 5]] # 9

# YOLOv8.0n head
head:
  - [-1, 1, nn.Upsample, [None, 2, "nearest"]]
  - [[-1, 6], 1, Concat, [1]] # cat backbone P4
  - [-1, 3, C2f, [512]] # 12

  - [-1, 1, nn.Upsample, [None, 2, "nearest"]]
  - [[-1, 4], 1, Concat, [1]] # cat backbone P3
  - [-1, 3, C2f, [256]] # 15 (P3/8-small)

  - [-1, 1, Conv, [256, 3, 2]]
  - [[-1, 12], 1, Concat, [1]] # cat head P4
  - [-1, 3, C2f, [512]] # 18 (P4/16-medium)

  - [-1, 1, Conv, [512, 3, 2]]
  - [[-1, 9], 1, Concat, [1]] # cat head P5
  - [-1, 3, C2f, [1024]] # 21 (P5/32-large)

  - [[15, 18, 21], 1, Detect, [nc]] # Detect(P3, P4, P5)

 2.5 修改train.py文件

        在train.py脚本中填入创建好的yaml路径,运行即可训练。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mfbz.cn/a/629455.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

详细分析Vue3中的reactive(附Demo)

目录 1. 基本知识2. 用法3. Demo 1. 基本知识 reactive 是一个函数,用于将一个普通的 JavaScript 对象转换为响应式对象 当对象的属性发生变化时,Vue 会自动追踪这些变化,并触发相应的更新 Vue2没有,而Vue3中有,为啥…

C++入门系列-赋值运算符重载

🌈个人主页:羽晨同学 💫个人格言:“成为自己未来的主人~” 赋值运算符重载 运算符重载 C为了增强代码的可读性引入了运算符重载,运算符重载是具有特殊函数名的函数,也具有其返回值类型,函数名字以及参…

评价决策类-层次分析法

师从江北 问题引出 归一化处理:指标的数组[a b c]归一化处理得到[a/(abc),b/(abc),c/(abc)] 因为每个指标的重要性不同,所以要加上一个权重 如何科学的确定权重,就要用到层次分析法(AHP) 模型原理 建立递阶层次结构模…

百度云防护如何开启CC攻击防护

百度云防护的最重要的功能是可以CC攻击防护,针对CC攻击,百度云防护有被动的CC攻击拦截规则,也有主动自定义访问策略拦截。 今天百度云来教大家如何开启百度云防护的CC攻击防御功能。 1.进入防护模板功能-创建模板 2.开启CC攻击防御功能&…

ubuntu20.04 ROS 环境下使用速腾80线激光雷达

1.相关系统环境 系统版本:ubuntu 20.04 ROS版本:ROS1 - noetic 激光雷达型号:RoboSense Ruby (更新于2024.5.14) 2.网口配置: 将PC/工控机的网口配置为: ipv4,方式设置为手动 ip地址、掩码以…

半小时搞懂STM32知识点——UART

1.UART 1.1为什么要使用UART这种协议?介绍一下UART及其特点 成本低,硬件简单,数据格式灵活; 低速全双工异步串行通信 1.2 UART数据帧格式? 起始位(1)+数据位(5-8) 校验位…

保研机试之【execve函数】

execve 参考:fork()函数两次返回_fork是如何返回两次的-CSDN博客 setjmp/longjmp 还有E:

解决kali Linux2024无法获取动态IPv4地址(DHCP)解决方案

用root用户启动终端 进入根目录,选择配置文件 cd到根目录下/../etc/network找到interfaces文件 编辑interfaces文件 vi interfaces,编辑interfaces文件 输入如下命令 打开虚拟网络编辑器 选择虚拟机选项卡,编辑,打开虚拟网络编…

数据结构——二叉树知识点详解!

引言:本篇博客将详细介绍到数据结构中的又一位大将——二叉树。它也是我们目前学到的第一个非线性的数据结构。并且本章将学到的概念居多,希望大家可以理解并牢记。 更多有关C语言和数据结构知识详解可前往个人主页:计信猫 目录 一&#xff0…

C++语法|对象的浅拷贝和深拷贝

背景&#xff1a; 我们手写一个顺序栈&#xff0c;展开接下来的实验&#xff1a; ⭐️ this指针指向的是类在内存中的起始位置 class SeqStack { public:SqeStack(int size 10) {cout << this << "SeqStack()" << endl;pstack_ new int[size_];t…

Gemini 5.14日更新 - 推出Gemini Advance服务

收到Gemini Advance试用邀请 今天和往常一样&#xff0c;打开Gemini&#xff0c;惊喜的发现右小角一行小字&#xff1a;试用Gemini Advance。好家伙&#xff0c;OpenAI 刚推出ChatGPT 4o&#xff0c;Google立马推出Gemini Advance&#xff0c;说明国外高科技企业也是很拼的。 …

STC8增强型单片机开发【热敏电阻】

目录 一、引言 二、热敏电阻概述 三、STC8增强型单片机简介 四、基于STC8单片机的热敏电阻测温系统 五、热敏电阻测温系统的优化与扩展 提高测量精度 扩展系统功能 六、 温度计算步骤 通过ADC采样计算出热敏电阻位置的电压 通过欧姆定律计算热敏电阻的阻值 通过阻值…

【Linux】常用指令、热键与权限管理

一、常用指令 &#xff08;1&#xff09;ls 功能&#xff1a;列出指定目录下的所有子目录与文件 用法&#xff1a;ls &#xff08;选项&#xff09; &#xff08;目录或文件名&#xff09; 常用选项&#xff1a; -a&#xff1a;列出目录下的所有文件&#xff0c;包括隐藏…

fastjson1.2.68对于文件操作的分析最全

fastjson1.2.68对于文件操作的分析 前言分析复制文件写入文件读取文件分析poc拓宽场景极限环境poc优化修改再次优化poc的分析 前言 这次分析也是分析了很久&#xff0c;因为每个链子都是自己去跟着分析了的&#xff0c;然后主要是去学习了一下怎么去挖链子 分析 前面漏洞复现…

TypeScript中的泛型(Generics)

TypeScript中的泛型&#xff08;Generics&#xff09; 在前面的几篇文章中&#xff0c;我们了解了TypeScript的类、接口和基本的数据类型系统。本文将重点介绍TypeScript中的泛型&#xff0c;这是一种强大的工具&#xff0c;它允许我们创建可重用的组件&#xff0c;同时保持类…

【Java学习笔记10 Java Web 应用——JSP

JSP(Java Script Pages)技术是一种网站开发技术&#xff0c;可以让Web开发人员快速、高效的开发出易于维护的动态网页。使用JSP技术开发的Web应用程序具有跨平台性&#xff0c;不需要修改程序&#xff0c;发布后即可在Windows、Linux等不同的操作系统中运行。 10.1 JSP技术概述…

申请免费通配符证书

通配符证书是一种 SSL/TLS 证书&#xff0c;可用于保护多个域&#xff08;主机&#xff09;&#xff0c;由域名字段中的通配符 (*) 指示。这种证书主要用于具有许多子域的组织。通配符证书对主域及其所有次级子域有效。 通配符证书和多域名证书的区别&#xff1a; 要注意的是…

webpack优化构建速度示例-合理配置loader的include exclude:

实际上&#xff0c;babel-loader 在 Webpack 配置中默认并不包含 exclude 和 include 选项的默认值&#xff0c;通常&#xff0c;为了优化构建性能&#xff0c;开发者会显式地设置 exclude 和 include 选项&#xff0c;以便 babel-loader 只处理必要的文件。 src/index.js impo…

Zookeeper and RPC dubbo

javaguide zookeeper面试题 Zookeeper 啥是Zookeeper干啥的 ZooKeeper 可以被用作注册中心、分布式锁&#xff1b; ZooKeeper 是 Hadoop 生态系统的一员&#xff1b; 构建 ZooKeeper 集群的时候&#xff0c;使用的服务器最好是奇数台。 启动ZK 下载安装解压 不过多赘述 我的…

Spring:了解@Import注解的三种用法

一、前言 在 Spring 框架中&#xff0c;Import 注解用于导入配置类&#xff0c;使得你可以在一个配置类中引入另一个或多个配置类&#xff0c;从而实现配置的模块化。这对于组织大型应用程序的配置非常有用&#xff0c;因为它允许你将配置分散到多个类中&#xff0c;然后再将它…