YOLOv8改进 | 图像去雾 | 利用图像去雾网络AOD-PONO-Net网络增改进图像物体检测

一、本文介绍 

本文给大家带来的改进机制是利用AODNet图像去雾网络结合PONO机制实现二次增强,我将该网络结合YOLOv8针对图像进行去雾检测(也适用于一些模糊场景,图片不清晰的检测)同时本文的内容不影响其它的模块改进可以作为工作量凑近大家的论文里,非常的适用,图像去雾检测为群友最近提出的需要的改进,在开始之前给大家推荐一下我的专栏,本专栏每周更新3-10篇最新前沿机制 | 包括二次创新全网无重复,以及融合改进(大家拿到之后添加另外一个改进机制在你的数据集上实现涨点即可撰写论文),还有各种前沿顶会改进机制 |,更有包含我所有附赠的文件(文件内集成我所有的改进机制全部注册完毕可以直接运行)和交流群和视频讲解提供给大家。  

👑欢迎大家订阅我的专栏一起学习YOLO👑 

专栏目录:YOLOv8改进有效系列目录 | 包含卷积、主干、检测头、注意力机制、Neck上百种创新机制

专栏回顾:YOLOv8改进系列专栏——本专栏持续复习各种顶会内容——科研必备    

目录

一、本文介绍 

二、原理介绍 

三、核心代码

四、添加教程

4.1 修改一

4.2 修改二 

4.3 修改三 

五、AODNet-PONO-Net的yaml文件和运行记录

5.1 AODNet-PONO-Net的yaml文件

5.2 训练代码 

5.3 训练过程截图 

五、本文总结


二、原理介绍 

官方论文地址: 官方论文地址点击即可跳转

官方代码地址: 官方代码地址点击即可跳转

摘要:这篇论文提出了一种名为全能去雾网络(AOD-Net)的图像去雾模型,该模型是基于重新制定的大气散射模型并利用卷积神经网络(CNN)构建的。与大多数先前的模型不同,AOD-Net不是分别估计传输矩阵和大气光,而是直接通过一个轻量级的CNN生成清晰图像。这种新颖的端到端设计使得将AOD-Net嵌入到其他深度模型中变得简单,例如,用于提升雾霾图像上高级任务性能的Faster R-CNN。在合成和自然雾霾图像数据集上的实验结果证明了我们在峰值信噪比(PSNR)、结构相似性指数(SSIM)和主观视觉质量方面超越了最先进技术的性能。此外,当将AOD-Net与Faster R-CNN结合并从头到尾进行联合训练时,我们见证了雾霾图像上对象检测性能的显著提升。


AOD-Net是一个端到端的可训练去雾模型,直接从有雾图像产生清晰图像,而不是依赖于任何单独和中间参数估计步骤。基于重新公式化的大气散射模型设计,与现有工作共享相同的物理基础,但以一种“更端到端”的方式将其所有参数估计在一个统一模型中完成

主要创新点

  1. 端到端去雾模型:首次提出一个端到端训练的去雾模型,直接从雾图像生成清晰图像,避免了传统方法中估计传输矩阵和大气光的独立步骤​​。
  2. 与高级视觉任务的结合:首次量化研究去雾质量如何影响后续高级视觉任务的性能,为比较去雾结果提供了一种新的客观标准。此外,AOD-Net可以无缝地与其他深度模型嵌入,形成一个在有雾图像上执行高级任务的流水线,通过端到端的联合调优进一步提升性能​​。

个人总结:AOD-Net能够一步到位地把雾气重的照片变清晰,而不像以前的方法那样需要分好几步小心翼翼地处理。简单来说,AOD-Net就是通过学习雾中的图片和清晰图片之间的差别,找到一种直接去除雾气的捷径,使得图片恢复清晰,同时也帮助计算机更好地理解图片内容。 


三、核心代码

核心代码的使用方式看章节四!

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch
torch.autograd.set_detect_anomaly(True)
__all__ = ['AOD_pono_net']
class AODnet(nn.Module):
    def __init__(self):
        super(AODnet, self).__init__()
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=3, kernel_size=1, stride=1, padding=0)
        self.conv2 = nn.Conv2d(in_channels=3, out_channels=3, kernel_size=3, stride=1, padding=1)
        self.conv3 = nn.Conv2d(in_channels=6, out_channels=3, kernel_size=5, stride=1, padding=2)
        self.conv4 = nn.Conv2d(in_channels=6, out_channels=3, kernel_size=7, stride=1, padding=3)
        self.conv5 = nn.Conv2d(in_channels=12, out_channels=3, kernel_size=3, stride=1, padding=1)
        self.b = 1

    def forward(self, x):
        x1 = F.relu(self.conv1(x))
        x2 = F.relu(self.conv2(x1))
        cat1 = torch.cat((x1, x2), 1)
        x3 = F.relu(self.conv3(cat1))
        cat2 = torch.cat((x2, x3), 1)
        x4 = F.relu(self.conv4(cat2))
        cat3 = torch.cat((x1, x2, x3, x4), 1)
        k = F.relu(self.conv5(cat3))

        if k.size() != x.size():
            raise Exception("k, haze image are different size!")

        output = k * x - k + self.b
        return F.relu(output)

class AOD_pono_net(nn.Module):
    def __init__(self):
        super(AOD_pono_net, self).__init__()
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=3, kernel_size=1, stride=1, padding=0)
        self.conv2 = nn.Conv2d(in_channels=3, out_channels=3, kernel_size=3, stride=1, padding=1)
        self.conv3 = nn.Conv2d(in_channels=6, out_channels=3, kernel_size=5, stride=1, padding=2)
        self.conv4 = nn.Conv2d(in_channels=6, out_channels=3, kernel_size=7, stride=1, padding=3)
        self.conv5 = nn.Conv2d(in_channels=12, out_channels=3, kernel_size=3, stride=1, padding=1)
        self.b = 1

        self.pono = PONO(affine=False)
        self.ms = MS()

    def forward(self, x):
        x1 = F.relu(self.conv1(x))
        x2 = F.relu(self.conv2(x1))
        cat1 = torch.cat((x1, x2), 1)
        x1, mean1, std1 = self.pono(x1)
        x2, mean2, std2 = self.pono(x2)
        x3 = F.relu(self.conv3(cat1))
        cat2 = torch.cat((x2, x3), 1)
        x3 = self.ms(x3, mean1, std1)
        x4 = F.relu(self.conv4(cat2))
        x4 = self.ms(x4, mean2, std2)
        cat3 = torch.cat((x1, x2, x3, x4), 1)
        k = F.relu(self.conv5(cat3))

        if k.size() != x.size():
            raise Exception("k, haze image are different size!")

        output = k * x - k + self.b
        output = F.relu(output)
        return output

class PONO(nn.Module):
    def __init__(self, input_size=None, return_stats=False, affine=True, eps=1e-5):
        super(PONO, self).__init__()
        self.return_stats = return_stats
        self.input_size = input_size
        self.eps = eps
        self.affine = affine

        if affine:
            self.beta = nn.Parameter(torch.zeros(1, 1, *input_size))
            self.gamma = nn.Parameter(torch.ones(1, 1, *input_size))
        else:
            self.beta, self.gamma = None, None

    def forward(self, x):
        mean = x.mean(dim=1, keepdim=True)
        std = (x.var(dim=1, keepdim=True) + self.eps).sqrt()
        x = (x - mean) / std
        if self.affine:
            x = x * self.gamma + self.beta
        return x, mean, std

class MS(nn.Module):
    def __init__(self, beta=None, gamma=None):
        super(MS, self).__init__()
        self.gamma, self.beta = gamma, beta

    def forward(self, x, beta=None, gamma=None):
        beta = self.beta if beta is None else beta
        gamma = self.gamma if gamma is None else gamma
        if gamma is not None:
            y = x.mul(gamma)  # 使用非原地操作mul
        else:
            y = x  # 如果不乘gamma,保持y不变
        if beta is not None:
            y = y.add(beta)  # 使用非原地操作add

        return y



if __name__ == "__main__":
    # Generating Sample image
    image_size = (1, 3, 640, 640)
    image = torch.rand(*image_size)
    out = AOD_pono_net()
    out = out(image)
    print(out.size())


四、添加教程

本文的网络结构为无参数网络结构所以我们的注册方式很简单。

4.1 修改一

第一还是建立文件,我们找到如下ultralytics/nn/modules文件夹下建立一个目录名字呢就是'Addmodules'文件夹(用群内的文件的话已经有了无需新建)!然后在其内部建立一个新的py文件将核心代码复制粘贴进去即可。


4.2 修改二 

第二步我们在该目录下创建一个新的py文件名字为'__init__.py'(用群内的文件的话已经有了无需新建),然后在其内部导入我们的检测头如下图所示。


4.3 修改三 

第三步我门中到如下文件'ultralytics/nn/tasks.py'进行导入和注册我们的模块(用群内的文件的话已经有了无需重新导入直接开始第四步即可)

从今天开始以后的教程就都统一成这个样子了,因为我默认大家用了我群内的文件来进行修改!!

到此就修改完成了,大家可以复制下面的yaml文件运行,无需修改parse_model方法。。


五、AODNet-PONO-Net的yaml文件和运行记录

5.1 AODNet-PONO-Net的yaml文件

# Ultralytics YOLO 🚀, AGPL-3.0 license
# YOLOv8 object detection model with P3-P5 outputs. For Usage examples see https://docs.ultralytics.com/tasks/detect

# Parameters
nc: 80  # number of classes
scales: # model compound scaling constants, i.e. 'model=yolov8n.yaml' will call yolov8.yaml with scale 'n'
  # [depth, width, max_channels]
  n: [0.33, 0.25, 1024]  # YOLOv8n summary: 225 layers,  3157200 parameters,  3157184 gradients,   8.9 GFLOPs
  s: [0.33, 0.50, 1024]  # YOLOv8s summary: 225 layers, 11166560 parameters, 11166544 gradients,  28.8 GFLOPs
  m: [0.67, 0.75, 768]   # YOLOv8m summary: 295 layers, 25902640 parameters, 25902624 gradients,  79.3 GFLOPs
  l: [1.00, 1.00, 512]   # YOLOv8l summary: 365 layers, 43691520 parameters, 43691504 gradients, 165.7 GFLOPs
  x: [1.00, 1.25, 512]   # YOLOv8x summary: 365 layers, 68229648 parameters, 68229632 gradients, 258.5 GFLOPs

# YOLOv8.0n backbone
backbone:
  # [from, repeats, module, args]
  - [-1, 1, AOD_pono_net, []]  # 0-P1/2
  - [-1, 1, Conv, [64, 3, 2]]  # 1-P1/2
  - [-1, 1, Conv, [128, 3, 2]]  # 2-P2/4
  - [-1, 3, C2f, [128, True]]
  - [-1, 1, Conv, [256, 3, 2]]  # 4-P3/8
  - [-1, 6, C2f, [256, True]]
  - [-1, 1, Conv, [512, 3, 2]]  # 6-P4/16
  - [-1, 6, C2f, [512, True]]
  - [-1, 1, Conv, [1024, 3, 2]]  # 8-P5/32
  - [-1, 3, C2f, [1024, True]]
  - [-1, 1, SPPF, [1024, 5]]  # 10

# YOLOv8.0n head
head:
  - [-1, 1, nn.Upsample, [None, 2, 'nearest']]
  - [[-1, 7], 1, Concat, [1]]  # cat backbone P4
  - [-1, 3, C2f, [512]]  # 13

  - [-1, 1, nn.Upsample, [None, 2, 'nearest']]
  - [[-1, 5], 1, Concat, [1]]  # cat backbone P3
  - [-1, 3, C2f, [256]]  # 16 (P3/8-small)

  - [-1, 1, Conv, [256, 3, 2]]
  - [[-1, 13], 1, Concat, [1]]  # cat head P4
  - [-1, 3, C2f, [512]]  # 19 (P4/16-medium)

  - [-1, 1, Conv, [512, 3, 2]]
  - [[-1, 10], 1, Concat, [1]]  # cat head P5
  - [-1, 3, C2f, [1024]]  # 22 (P5/32-large)

  - [[16, 19, 22], 1, Detect, [nc]]  # Detect(P3, P4, P5)


5.2 训练代码 

大家可以创建一个py文件将我给的代码复制粘贴进去,配置好自己的文件路径即可运行。

import warnings
warnings.filterwarnings('ignore')
from ultralytics import YOLO

if __name__ == '__main__':
    model = YOLO('ultralytics/cfg/models/v8/yolov8-C2f-FasterBlock.yaml')
    # model.load('yolov8n.pt') # loading pretrain weights
    model.train(data=r'替换数据集yaml文件地址',
                # 如果大家任务是其它的'ultralytics/cfg/default.yaml'找到这里修改task可以改成detect, segment, classify, pose
                cache=False,
                imgsz=640,
                epochs=150,
                single_cls=False,  # 是否是单类别检测
                batch=4,
                close_mosaic=10,
                workers=0,
                device='0',
                optimizer='SGD', # using SGD
                # resume='', # 如过想续训就设置last.pt的地址
                amp=False,  # 如果出现训练损失为Nan可以关闭amp
                project='runs/train',
                name='exp',
                )


5.3 训练过程截图 


五、本文总结

到此本文的正式分享内容就结束了,在这里给大家推荐我的YOLOv8改进有效涨点专栏,本专栏目前为新开的平均质量分98分,后期我会根据各种最新的前沿顶会进行论文复现,也会对一些老的改进机制进行补充,如果大家觉得本文帮助到你了,订阅本专栏,关注后续更多的更新~

 专栏回顾:YOLOv8改进系列专栏——本专栏持续复习各种顶会内容——科研必备 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/452023.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

网康科技 NS-ASG 应用安全网关 SQL注入漏洞复现(CVE-2024-2330)

0x01 产品简介 网康科技的NS-ASG应用安全网关是一款软硬件一体化的产品,集成了SSL和IPSec,旨在保障业务访问的安全性,适配所有移动终端,提供多种链路均衡和选择技术,支持多种认证方式灵活组合,以及内置短信认证、LDAP令牌、USB KEY等多达13种认证方式。 0x02 漏洞概述 …

Redis底层数据结构之String

文章目录 1. 前提回顾2. RedisObject三大数据类型简介3. SDS字符串4. SDS字符串源码分析5. 总结 1. 前提回顾 前面我们说到redis的String数据结构在底层有多种编码方式。例如我们执行下面两条语句 set k1 v1 set age 17我们查看类型,发现这类型都是String类型 我们…

【力扣精选算法100道】——二进制求和

LCR 002. 二进制求和 - 力扣(LeetCode) 目录 🎈了解题意 🎈算法分析 🚩cur1>0 🚩cur2>0 🚩t 🎈实现代码 🎈了解题意 遵循二进制加法法则,如果俩…

工具篇--分布式定时任务springBoot 整合 elasticjob使用(3)

文章目录 前言一、Springboot 整合:1.1 引入jar:1.2 配置zookeeper 注册中心:1.3 定义job 业务类:1.4 job 注册到zookeeper:1.5 项目启动:1.5.1 zookeeper 注册中心实例:1.5.2 任务执行日志输出…

【数据挖掘】练习1:R入门

课后作业1:R入门 一:习题内容 1.要与R交互必须安装Rstudio,这种说法对不对? 不对。虽然RStudio是一个流行的R交互集成开发环境,但并不是与R交互的唯一方式。 与R交互可以采用以下几种方法: 使用R Conso…

AHU 汇编 实验六

一、实验名称:实验6 输入一个16进制数,把它转换为10进制数输出 实验目的: 培养汇编中设计子程序的能力 实验过程: 源代码: data segmentbuff1 db Please input a number(H):$buff2 db 30,?,30 dup(?),13,10buff3 …

社交革命的引领者:探索Facebook如何改变我们的生活方式

1.数字社交的兴起 随着互联网的普及,社交媒体成为我们日常生活的重要组成部分。Facebook作为其中的先驱,从最初的社交网络演变成了一个拥有数十亿用户的全球化平台。它不仅改变了我们与世界互动的方式,还深刻影响了我们的社交习惯、人际关系以…

数据结构:树和二叉树

树的概念 1.树是一种非线性的数据结构。它是由n个有限节点的集合。 2.树分为根节点和子树。根节点没有前驱节点。 3.树的子树是由一个个子树组成,它们可以看作一个个集合。每个集合下面又有集合。 因此,树是递归定义的。 树形结构中,子树…

搜索引擎SEO策略介绍

baidu搜索:如何联系八爪鱼SEO baidu搜索:如何联系八爪鱼SEO baidu搜索:如何联系八爪鱼SEO 第一、 关键词的选择策略: 1、门户类的网站关键词选择策略: 网站每个页面本身基本都包含有关键词:网站拥有上百…

嵌入式数据库SQlite3-进阶篇

嵌入式数据库sqlite3 - HQ 文章目录 嵌入式数据库sqlite3 - HQ[toc] 嵌入式数据库sqlite3【进阶篇】数据库准备order子句Where 子句与逻辑运算符语法实例 group by子句having子句举例 函数SQLite COUNT 函数SQLite MAX 函数SQLite MIN 函数SQLite AVG 函数SQLite SUM 函数SQLit…

Qt 使用RAW INPUT获取HID触摸屏,笔设备,鼠标的原始数据,最低受支持的客户端:Windows XP [仅限桌面应用]

在开发绘图应用程序时,经常会需要读取笔设备的数据,通过对笔数据的解析,来判断笔的坐标,粗细。如果仅仅只是读取鼠标的坐标,就需要人为在应用程序端去修改笔的粗细,并且使用体验不好,如果可以实…

【C++】STL(五) Stack Queue容器

5、 stack容器 5.1 简介 ① stack是一种先进后出的容器,它只有一个出口。 ② 栈中只有顶端的元素才可以被外界使用,因此栈不允许有遍历行为。 ③ 栈中进入数据称为:入栈 push ④ 栈中弹出数据称为:出栈 pop 5.2 常用接口 …

Fair Data Exchange:区块链实现的原子式公平数据交换

1. 引言 2024年斯坦福大学和a16z crypto research团队 论文 Atomic and Fair Data Exchange via Blockchain 中,概述了一种构建(包含过期EIP-4844 blobs的)fair data-markets的协议。该论文源自a16z crypto的暑期实习计划,与四名…

R语言tidycmprsk包分析竞争风险模型

竞争风险模型就是指在临床事件中出现和它竞争的结局事件,这是事件会导致原有结局的改变,因此叫做竞争风险模型。比如我们想观察患者肿瘤的复发情况,但是患者在观察期突然车祸死亡,或者因其他疾病死亡,这样我们就观察不…

KAFKA入门教程

目录 1.安装kafka 2.安装kafkamanager可视化工具 3.springboot整合kafka 1.pom导包 2.启动类和yml配置 3.代码演示 编写生产者: 消费者: 1.安装kafka 进入kafka官网下载对应版本kafka kafka官网地址:Apache Kafka kafka是使用Scal…

Kotlin 数据解析(Gson)

一、添加依赖 build.gradle.kts(:app) // gson数据解析implementation("com.google.code.gson:gson:2.8.6") 对象类: // 对象类 class Account {var uid:String "00001"var userName:String "Freeman"var password:String &quo…

Midjourney能让角色保持一致了

Midjourney发布新功能,网友直呼“不可思议”! 现在你可以让生成的图像几乎保持角色一致,belike: 所有超级英雄长一个模样盯着你。 甚至动漫风、写实风等跨风格生成也同样适用: 保持同一风格,感jio配上文字…

【python】自动化工具Selenium与playwright去除webdriver检测

对这个世界如果你有太多的抱怨 跌倒了就不敢继续往前走 为什么人要这么的脆弱 堕落 请你打开电视看看 多少人为生命在努力勇敢的走下去 我们是不是该知足 珍惜一切 就算没有拥有 🎵 周杰伦《稻香》 # -*- coding:utf-8 -*- import timefrom s…

【C语言】如何规避野指针

✨✨ 欢迎大家来到莉莉的博文✨✨ 🎈🎈养成好习惯,先赞后看哦~🎈🎈 目录 一、概念: 二、野指针成因: 2.1. 指针未初始化 2.2 指针越界访问 3. 指针指向的空间释放 三、如何规避野指针 3.…

Manacher 算法——Leetcode 5.最长回文子串

在了解之前,我们先要了解什么是回文串,什么是回文子串。 回文串和回文子串: 回文串是指一个字符串正序遍历和反向遍历结果相同的字符串。如 ABBA,正着读反着读结果是一样的。 有了回文串的概念,回文子串的概念也就显…