【目标检测】yolov8结构及代码分析

yolov8代码:https://github.com/ultralytics/ultralytics

yolov8的整体结构如下图(来自mmyolo):

yolov8的配置文件:

# Ultralytics YOLO 🚀, AGPL-3.0 license
# YOLOv8 object detection model with P3-P5 outputs. For Usage examples see https://docs.ultralytics.com/tasks/detect

# Parameters
nc: 80  # number of classes
scales: # model compound scaling constants, i.e. 'model=yolov8n.yaml' will call yolov8.yaml with scale 'n'
  # [depth, width, max_channels]
  n: [0.33, 0.25, 1024]  # YOLOv8n summary: 225 layers,  3157200 parameters,  3157184 gradients,   8.9 GFLOPs
  s: [0.33, 0.50, 1024]  # YOLOv8s summary: 225 layers, 11166560 parameters, 11166544 gradients,  28.8 GFLOPs
  m: [0.67, 0.75, 768]   # YOLOv8m summary: 295 layers, 25902640 parameters, 25902624 gradients,  79.3 GFLOPs
  l: [1.00, 1.00, 512]   # YOLOv8l summary: 365 layers, 43691520 parameters, 43691504 gradients, 165.7 GFLOPs
  x: [1.00, 1.25, 512]   # YOLOv8x summary: 365 layers, 68229648 parameters, 68229632 gradients, 258.5 GFLOPs

# YOLOv8.0n backbone
backbone:
  # [from, repeats, module, args]
  - [-1, 1, Conv, [64, 3, 2]]  # 0-P1/2
  - [-1, 1, Conv, [128, 3, 2]]  # 1-P2/4
  - [-1, 3, C2f, [128, True]]
  - [-1, 1, Conv, [256, 3, 2]]  # 3-P3/8
  - [-1, 6, C2f, [256, True]]
  - [-1, 1, Conv, [512, 3, 2]]  # 5-P4/16
  - [-1, 6, C2f, [512, True]]
  - [-1, 1, Conv, [1024, 3, 2]]  # 7-P5/32
  - [-1, 3, C2f, [1024, True]]
  - [-1, 1, SPPF, [1024, 5]]  # 9

# YOLOv8.0n head
head:
  - [-1, 1, nn.Upsample, [None, 2, 'nearest']]
  - [[-1, 6], 1, Concat, [1]]  # cat backbone P4
  - [-1, 3, C2f, [512]]  # 12

  - [-1, 1, nn.Upsample, [None, 2, 'nearest']]
  - [[-1, 4], 1, Concat, [1]]  # cat backbone P3
  - [-1, 3, C2f, [256]]  # 15 (P3/8-small)

  - [-1, 1, Conv, [256, 3, 2]]
  - [[-1, 12], 1, Concat, [1]]  # cat head P4
  - [-1, 3, C2f, [512]]  # 18 (P4/16-medium)

  - [-1, 1, Conv, [512, 3, 2]]
  - [[-1, 9], 1, Concat, [1]]  # cat head P5
  - [-1, 3, C2f, [1024]]  # 21 (P5/32-large)

  - [[15, 18, 21], 1, Detect, [nc]]  # Detect(P3, P4, P5)

可以看出,主要包含Conv,C2f,SPPF,Concat,Detect几个模块。

一、Conv

Conv模块包含卷积层、BN层和激活函数层。

代码如下:

class Conv(nn.Module):
    """Standard convolution with args(ch_in, ch_out, kernel, stride, padding, groups, dilation, activation)."""
    default_act = nn.SiLU()  # default activation

    def __init__(self, c1, c2, k=1, s=1, p=None, g=1, d=1, act=True):
        """Initialize Conv layer with given arguments including activation."""
        super().__init__()
        self.conv = nn.Conv2d(c1, c2, k, s, autopad(k, p, d), groups=g, dilation=d, bias=False)
        self.bn = nn.BatchNorm2d(c2)
        self.act = self.default_act if act is True else act if isinstance(act, nn.Module) else nn.Identity()

    def forward(self, x):
        """Apply convolution, batch normalization and activation to input tensor."""
        return self.act(self.bn(self.conv(x)))

    def forward_fuse(self, x):
        """Perform transposed convolution of 2D data."""
        return self.act(self.conv(x))

二、C2f

C2f就是模型结构图中的CSPLayer_2Conv,包含多个DarkNetBottleNeck。代码如下:

class C2f(nn.Module):
    """Faster Implementation of CSP Bottleneck with 2 convolutions."""

    def __init__(self, c1, c2, n=1, shortcut=False, g=1, e=0.5):
        """Initialize CSP bottleneck layer with two convolutions with arguments ch_in, ch_out, number, shortcut, groups,
        expansion.
        """
        super().__init__()
        self.c = int(c2 * e)  # hidden channels
        self.cv1 = Conv(c1, 2 * self.c, 1, 1)
        self.cv2 = Conv((2 + n) * self.c, c2, 1)  # optional act=FReLU(c2)
        self.m = nn.ModuleList(Bottleneck(self.c, self.c, shortcut, g, k=((3, 3), (3, 3)), e=1.0) for _ in range(n))

    def forward(self, x):
        """Forward pass through C2f layer."""
        y = list(self.cv1(x).chunk(2, 1))
        y.extend(m(y[-1]) for m in self.m)
        return self.cv2(torch.cat(y, 1))

    def forward_split(self, x):
        """Forward pass using split() instead of chunk()."""
        y = list(self.cv1(x).split((self.c, self.c), 1))
        y.extend(m(y[-1]) for m in self.m)
        return self.cv2(torch.cat(y, 1))


class Bottleneck(nn.Module):
    """Standard bottleneck."""

    def __init__(self, c1, c2, shortcut=True, g=1, k=(3, 3), e=0.5):
        """Initializes a bottleneck module with given input/output channels, shortcut option, group, kernels, and
        expansion.
        """
        super().__init__()
        c_ = int(c2 * e)  # hidden channels
        self.cv1 = Conv(c1, c_, k[0], 1)
        self.cv2 = Conv(c_, c2, k[1], 1, g=g)
        self.add = shortcut and c1 == c2

    def forward(self, x):
        """'forward()' applies the YOLO FPN to input data."""
        return x + self.cv2(self.cv1(x)) if self.add else self.cv2(self.cv1(x))

三、SPPF

yolov8的SPPF实现是通过多次池化实现不同大小的池化窗口运算,代码如下:

class SPPF(nn.Module):
    """Spatial Pyramid Pooling - Fast (SPPF) layer for YOLOv5 by Glenn Jocher."""

    def __init__(self, c1, c2, k=5):
        """
        Initializes the SPPF layer with given input/output channels and kernel size.

        This module is equivalent to SPP(k=(5, 9, 13)).
        """
        super().__init__()
        c_ = c1 // 2  # hidden channels
        self.cv1 = Conv(c1, c_, 1, 1)
        self.cv2 = Conv(c_ * 4, c2, 1, 1)
        self.m = nn.MaxPool2d(kernel_size=k, stride=1, padding=k // 2)

    def forward(self, x):
        """Forward pass through Ghost Convolution block."""
        x = self.cv1(x)
        y1 = self.m(x)
        y2 = self.m(y1)
        return self.cv2(torch.cat((x, y1, y2, self.m(y2)), 1))

四、Concat

和torch.cat功能几乎一致:

class Concat(nn.Module):
    """Concatenate a list of tensors along dimension."""

    def __init__(self, dimension=1):
        """Concatenates a list of tensors along a specified dimension."""
        super().__init__()
        self.d = dimension

    def forward(self, x):
        """Forward pass for the YOLOv8 mask Proto module."""
        return torch.cat(x, self.d)

五、Detect

yolov8的检测头:

class Detect(nn.Module):
    """YOLOv8 Detect head for detection models."""
    dynamic = False  # force grid reconstruction
    export = False  # export mode
    shape = None
    anchors = torch.empty(0)  # init
    strides = torch.empty(0)  # init

    def __init__(self, nc=80, ch=()):
        """Initializes the YOLOv8 detection layer with specified number of classes and channels."""
        super().__init__()
        self.nc = nc  # number of classes
        self.nl = len(ch)  # number of detection layers
        self.reg_max = 16  # DFL channels (ch[0] // 16 to scale 4/8/12/16/20 for n/s/m/l/x)
        self.no = nc + self.reg_max * 4  # number of outputs per anchor
        self.stride = torch.zeros(self.nl)  # strides computed during build
        c2, c3 = max((16, ch[0] // 4, self.reg_max * 4)), max(ch[0], min(self.nc, 100))  # channels
        self.cv2 = nn.ModuleList(
            nn.Sequential(Conv(x, c2, 3), Conv(c2, c2, 3), nn.Conv2d(c2, 4 * self.reg_max, 1)) for x in ch)
        self.cv3 = nn.ModuleList(nn.Sequential(Conv(x, c3, 3), Conv(c3, c3, 3), nn.Conv2d(c3, self.nc, 1)) for x in ch)
        self.dfl = DFL(self.reg_max) if self.reg_max > 1 else nn.Identity()

    def forward(self, x):
        """Concatenates and returns predicted bounding boxes and class probabilities."""
        shape = x[0].shape  # BCHW
        for i in range(self.nl):
            x[i] = torch.cat((self.cv2[i](x[i]), self.cv3[i](x[i])), 1)
        if self.training:
            return x
        elif self.dynamic or self.shape != shape:
            self.anchors, self.strides = (x.transpose(0, 1) for x in make_anchors(x, self.stride, 0.5))
            self.shape = shape

        x_cat = torch.cat([xi.view(shape[0], self.no, -1) for xi in x], 2)
        if self.export and self.format in ('saved_model', 'pb', 'tflite', 'edgetpu', 'tfjs'):  # avoid TF FlexSplitV ops
            box = x_cat[:, :self.reg_max * 4]
            cls = x_cat[:, self.reg_max * 4:]
        else:
            box, cls = x_cat.split((self.reg_max * 4, self.nc), 1)
        dbox = dist2bbox(self.dfl(box), self.anchors.unsqueeze(0), xywh=True, dim=1) * self.strides

        if self.export and self.format in ('tflite', 'edgetpu'):
            # Normalize xywh with image size to mitigate quantization error of TFLite integer models as done in YOLOv5:
            # https://github.com/ultralytics/yolov5/blob/0c8de3fca4a702f8ff5c435e67f378d1fce70243/models/tf.py#L307-L309
            # See this PR for details: https://github.com/ultralytics/ultralytics/pull/1695
            img_h = shape[2] * self.stride[0]
            img_w = shape[3] * self.stride[0]
            img_size = torch.tensor([img_w, img_h, img_w, img_h], device=dbox.device).reshape(1, 4, 1)
            dbox /= img_size

        y = torch.cat((dbox, cls.sigmoid()), 1)
        return y if self.export else (y, x)

    def bias_init(self):
        """Initialize Detect() biases, WARNING: requires stride availability."""
        m = self  # self.model[-1]  # Detect() module
        # cf = torch.bincount(torch.tensor(np.concatenate(dataset.labels, 0)[:, 0]).long(), minlength=nc) + 1
        # ncf = math.log(0.6 / (m.nc - 0.999999)) if cf is None else torch.log(cf / cf.sum())  # nominal class frequency
        for a, b, s in zip(m.cv2, m.cv3, m.stride):  # from
            a[-1].bias.data[:] = 1.0  # box
            b[-1].bias.data[:m.nc] = math.log(5 / m.nc / (640 / s) ** 2)  # cls (.01 objects, 80 classes, 640 img)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/280508.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

基于Python的电商手机数据可视化分析和推荐系统

1. 项目简介 本项目旨在通过Python技术栈对京东平台上的手机数据进行抓取、分析并构建一个简单的手机推荐系统。主要功能包括: 网络爬虫:从京东获取手机数据;数据分析:统计各厂商手机销售分布、市场占有率、价格区间和好评率&am…

WPF+Halcon 培训项目实战(12):WPF导出匹配模板

文章目录 前言相关链接项目专栏运行环境匹配图片WPF导出匹配模板如何了解Halcon和C#代码的对应关系逻辑分析:添加截取ROI功能基类矩形圆形 生成导出模板运行结果:可能的报错你的文件路径不存在你选择的区域的内容有效信息过少 前言 为了更好地去学习WPF…

大创项目推荐 深度学习二维码识别

文章目录 0 前言2 二维码基础概念2.1 二维码介绍2.2 QRCode2.3 QRCode 特点 3 机器视觉二维码识别技术3.1 二维码的识别流程3.2 二维码定位3.3 常用的扫描方法 4 深度学习二维码识别4.1 部分关键代码 5 测试结果6 最后 0 前言 🔥 优质竞赛项目系列,今天…

【PyQt】(自定义类)QIcon派生,更易用的纯色Icon

嫌Qt自带的icon太丑,自己写了一个,主要用于纯色图标的自由改色。 当然,图标素材得网上找。 Qt原生图标与现代图标对比: 没有对比就没有伤害 Qt图标 网络素材图标 自定义类XJQ_Icon: from PyQt5.QtGui import QIc…

java go c++ 开源全文搜索引擎

Apache Lucene Java 全文搜索框架 许可证:Apache-2.0 开发语言:Java 官网:https://lucene.apache.org/ Apache Lucene 是完全用 Java 编写的高性能、功能齐全的全文检索引擎架构,提供了完整的查询引擎和索引引擎、部分文本分析引…

超维空间S2无人机使用说明书——52、初级版——使用PID算法进行基于yolo的目标跟踪

引言:在实际工程项目中,为了提高系统的响应速度和稳定性,往往需要采用一定的控制算法进行目标跟踪。这里抛砖引玉,仅采用简单的PID算法进行目标的跟随控制,目标的识别依然采用yolo。对系统要求更高的,可以对…

使用SecoClient软件连接L2TP

secoclient软件是华为防火墙与友商设备进行微屁恩对接的一款软件,运行在windows下可以替代掉win系统自带的连接功能,因为win系统自带的连接功能总是不可用而且我照着网上查到的各种方法调试了很久都调不好,导致我一度怀疑是我的服务没搭建好,浪费了大把时间去研究其他搭建方案 …

Kubernetes技术与架构-集群管理

Kubernetes技术与架构提供支撑工具支持集群的规划、安装、创建以及管理。 数字证书 用户可以使用easyrsa、openssl、cfssl工具生成数字证书,在kubernetes集群的api server中部署数字证书用于访问鉴权 资源管理 如上所示,定义一个服务类service用于负…

Flask笔记

一:模板渲染 一般的话都序列化成字符串 二:项目拆分 2.1 项目拆分 app.py init.py views.py models.py 模型数据 2.2 蓝图 三:路由参数 3.1 String 重点 3.2 int 3.3 path 3.4 UUID 3.5 any 四:请求方式 五:Requ…

将网页变身移动应用:网址封装成App的完全指南

什么是网址封装? 网址封装是一个将你的网站或网页直接嵌入到一个原生应用容器中的过程。用户可以通过下载你的App来访问网站,而无需通过浏览器。这种方式不仅提升了用户体验,还可利用移动设备的功能,如推送通知和硬件集成。 小猪…

介绍一款PDF在线工具

PDF是我们日常工作中的一种常见格式,其处理也是我们工作的重要基础性环节,一款可靠的处理工具显得十分重要。 完全免费、易于使用、丰富的PDF处理工具,包括:合并、拆分、压缩、转换、旋转和解锁PDF文件,以及给PDF文件…

linux实用技巧:ubuntu18.04安装samba服务器实现局域网文件共享

Ubuntu安装配置Samba服务与Win10共享文件 Chapter1 Ubuntu18.04安装配置Samba服务与Win10共享文件一、什么是Samba二、安装Samba1、查看是否有安装samba2、安装samba 三、配置Samba服务1、创建共享目录(以samba_workspaces为例)2、为samba设置登录用户3、…

SSM房屋租赁系统----计算机毕业设计

项目介绍 房屋租赁系统,基于 Spring5.x 的实战项目,此项目非Maven项目。 前台系统主要功能包括房源列表展示、房源详细信息展示、根据房源特征进行搜索,包括:房型、小区名;以及房源的预订功能。 后台管理: 用户信息管…

1.2.0 IGP高级特性之FRR

理论部分参考文档:Segment Routing TI-LFA FRR保护技术 - 华为 一、快速重路由技术 FRR(Fast Reroute)快速重路由 实现备份链路的快速切换,也可以与BFD联动实现对故障的快速感知。 随着网络的不断发展,VoIP和在线视频等业务对实时性的要求越…

【AIGC-图片生成视频系列-3】AI视频随心而动:MotionCtrl的相机运动控制和物体运动控制

最近,「单张图片生成视频」相关工作很多,但运动控制的准确性依旧是个挑战,包括相机运动的控制以及物体运动控制。 然,MotionCtrl 横空出世。 一. 项目简介 MotionCtrl——一个相机运动控制、物体运动控制的视频工具&#xff0c…

[C#]opencvsharp进行图像拼接普通拼接stitch算法拼接

介绍: opencvsharp进行图像拼一般有2种方式:一种是传统方法将2个图片上下或者左右拼接,还有一个方法就是融合拼接,stitch拼接就是一种非常好的算法。opencv里面已经有stitch拼接算法因此我们很容易进行拼接。 效果: …

二分查找(非朴素)--在排序数组中查找元素的第一个和最后一个位置

个人主页:Lei宝啊 愿所有美好如期而遇 目录 本题链接 输入描述 输出描述 算法分析 1.算法一:暴力求解 2.算法二:朴素二分算法 3.算法三:二分查找左右端点 3.1查找左端点 3.1.1细节一:循环条件 3.1.2细节二…

【前端基础】——原型与原型链详解,看一篇即可【图文版】

前言 本文旨在通过图文的方式,一步步回顾原型链的整个流程是如何运作的,如果你刚好在电脑旁边,不妨跟着我的思路,一起走一遍敲一遍代码流程,你会发现原型链并没有你想的那么复杂。 new关键字 我们先看这一个代码&am…

华锐三维云展平台 | VR在线展览云平台提供定制化虚拟展厅制作工具

随着科技的飞速发展,互联网技术的不断革新,广州华锐互动顺应时代需求,开发了VR在线展览云平台,用户可以在平台上自主创建属于自己的3D展厅。VR在线展览云平台正改变着传统展览行业的模式,为参展者提供更高效、更便捷、…

vue3中怎么巧妙的去运用jsx?

文章目录 概要JSX / TSX?安装配置封装JsxRender.vue使用JsxRender.vue怎么巧妙的去使用它?Demo下载 概要 我们都知道vue3是支持用jsx/tsx,但是好像从来都没有人告诉我们应该怎么运用到项目当中,下面是我觉得不错的一种使用方式,分享给大家…