【低照度图像增强系列(7)】RDDNet算法详解与代码实现(同济大学|ICME)

前言

☀️ 在低照度场景下进行目标检测任务,常存在图像RGB特征信息少提取特征困难目标识别和定位精度低等问题,给检测带来一定的难度。

     🌻使用图像增强模块对原始图像进行画质提升,恢复各类图像信息,再使用目标检测网络对增强图像进行特定目标检测,有效提高检测的精确度。

      ⭐本专栏会介绍传统方法、Retinex、EnlightenGAN、SCI、Zero-DCE、IceNet、RRDNet、URetinex-Net等低照度图像增强算法。

👑完整代码已打包上传至资源→低照度图像增强代码汇总

目录

前言

🚀一、RDDNet介绍 

☀️1.1 RDDNet简介   

研究背景 

算法框架 

损失函数

🚀二、RDDNet核心代码

 ☀️2.1 网络模型—RRDNet.py

 ☀️2.2 损失函数—loss_functions.py

(1)重构损失——reconstruction_loss

(2)光照损失——illumination_smooth_loss

(3)反射损失——reflectance_smooth_loss

(4)噪声损失——noise_loss

  ☀️2.3 Retinex操作—pipline.py

🚀三、RDDNet代码复现

☀️3.1 环境配置

☀️3.2 运行过程

☀️3.3 运行效果

 

🚀一、RDDNet介绍 

学习资料:

  • 论文题目:《ZERO-SHOT RESTORATION OF UNDEREXPOSED IMAGES VIA ROBUST RETINEX DECOMPOS》(通过鲁棒性 Retinex 分解对曝光不足的图像进行零样本恢复)
  • 论文讲解:ICME| RRDNet《ZERO-SHOT RESTORATION OF UNDEREXPOSED IMAGES VIA ROBUST RETINEX DECOMPOS》论文超详细解读(翻译+精读)
  • 原文地址:Zero-Shot Restoration of Underexposed Images via Robust Retinex Decomposition | IEEE Conference Publication | IEEE Xplore
  • 源码地址:代码export.arxiv.org/pdf/2109.05838v2.pdf

☀️1.1 RDDNet简介   

RRDNet同济大学在2020年提出来的一种新的三分支全卷积神经网络,认为图像由三部分构成:光照分量反射分量噪声分量。在没有pair对的情况下实现低光图像增强,通过对loss进行迭代来有效估计出噪声和恢复光照。 

研究背景 

  • 曝光不足的图像由于能见度差和黑暗中的潜在噪声,通常会出现严重的质量下降。
  • 现有的图像增强方法忽略了噪声,因此使用带噪声分量的Retinex模型作为基础。
  • 基于学习(数据驱动)的方法限制了模型的泛化能力,因此提出zero-shot的学习模式。

算法框架 

  1. 通过三分支网络把输入图像分解为反射图、光照图和噪声图三个分量。
  2. 通过Gamma变换调整光照图,再计算得到无噪声的反射图。
  3. 结合光照图和反射图,重构得到最终结果。 

损失函数

1. Retinex重构损失,取最大通道值作为初始光照图,用来约束光照图。在光照图的基础上约束反射图和噪声。

2. 纹理增强损失,通过平滑光照图可以帮助增强反射图的纹理。具体损失公式是带有权重的总变分损失,权重的设计规则是,梯度大的地方权重小,即权重与梯度成负相关即可,这里是将梯度经过高斯滤波放在分母。

3. 光照指导的噪声损失,根据噪声随着光照的变大而变大的假设,可以使用光照图来做权重指导,其次考虑两点:

(1)假定噪声范围限定

(2)通过平滑反射图来得到噪声,本身并没有直接得到噪声的损失,只是通过对反射图做总变分约束来去噪


🚀二、RDDNet核心代码

 代码框架如图所示:

(图片来源:【代码笔记】RRDNet 网络-CSDN博客 谢谢大佬!@chaiky) 

 ☀️2.1 网络模型—RRDNet.py

import torch
import torch.nn as nn

class RRDNet(nn.Module):
    def __init__(self):
        super(RRDNet, self).__init__()

 #----------- 1.illumination(光照估计)---------------------------#
        self.illumination_net = nn.Sequential(
            nn.Conv2d(3, 16, 3, 1, 1),
            nn.ReLU(),
            nn.Conv2d(16, 32, 3, 1, 1),
            nn.ReLU(),
            nn.Conv2d(32, 64, 3, 1, 1),
            nn.ReLU(),
            nn.Conv2d(64, 32, 3, 1, 1),
            nn.ReLU(),
            nn.Conv2d(32, 1, 3, 1, 1),

        )

 #----------- 2.reflectance(反射率估计)---------------------------#
        self.reflectance_net = nn.Sequential(
            nn.Conv2d(3, 16, 3, 1, 1),
            nn.ReLU(),
            nn.Conv2d(16, 32, 3, 1, 1),
            nn.ReLU(),
            nn.Conv2d(32, 64, 3, 1, 1),
            nn.ReLU(),
            nn.Conv2d(64, 32, 3, 1, 1),
            nn.ReLU(),
            nn.Conv2d(32, 3, 3, 1, 1)
        )

 #----------- 3.noise(噪声估计)---------------------------#
        self.noise_net = nn.Sequential(
            nn.Conv2d(3, 16, 3, 1, 1),
            nn.ReLU(),
            nn.Conv2d(16, 32, 3, 1, 1),
            nn.ReLU(),
            nn.Conv2d(32, 64, 3, 1, 1),
            nn.ReLU(),
            nn.Conv2d(64, 32, 3, 1, 1),
            nn.ReLU(),
            nn.Conv2d(32, 3, 3, 1, 1)
        )

    def forward(self, input):
        illumination = torch.sigmoid(self.illumination_net(input))
        reflectance = torch.sigmoid(self.reflectance_net(input))
        noise = torch.tanh(self.noise_net(input))

        return illumination, reflectance, noise

  我们可以对照上图左边的结构来理解代码。

  • illumination_net:  主要是负责对输入图像进行处理以获取光照信息,包括一系列卷积层和ReLU激活函数,最终输出一个通道数为1的图像,表示光照强度

  • reflectance_net:  主要是负责提取输入图像的反射率信息,同样包括一系列卷积层和ReLU激活函数,最终输出一个通道数为3的图像,表示反射率在RGB通道上的分布。

  • noise_net:  主要是则用于估计输入图像的噪声信息,同样由一系列卷积层和ReLU激活函数组成,最终输出一个通道数为3的图像,表示噪声在RGB通道上的分布。

 最后,illumination_netreflectance_net的输出经过sigmoid函数处理,而noise_net的输出则经过tanh函数处理。


 ☀️2.2 损失函数—loss_functions.py

import torch
import torch.nn as nn
import torch.nn.functional as F
import conf

 #----------- 1.reconstruction_loss:计算重构损失---------------------------#
def reconstruction_loss(image, illumination, reflectance, noise):
    reconstructed_image = illumination*reflectance+noise
    return torch.norm(image-reconstructed_image, 1)

 #----------- 2.gradient: 计算输入图像在水平和垂直方向上的梯度--------------------#
def gradient(img):
    height = img.size(2)
    width = img.size(3)
    gradient_h = (img[:,:,2:,:]-img[:,:,:height-2,:]).abs()
    gradient_w = (img[:, :, :, 2:] - img[:, :, :, :width-2]).abs()
    gradient_h = F.pad(gradient_h, [0, 0, 1, 1], 'replicate')
    gradient_w = F.pad(gradient_w, [1, 1, 0, 0], 'replicate')
    gradient2_h = (img[:,:,4:,:]-img[:,:,:height-4,:]).abs()
    gradient2_w = (img[:, :, :, 4:] - img[:, :, :, :width-4]).abs()
    gradient2_h = F.pad(gradient2_h, [0, 0, 2, 2], 'replicate')
    gradient2_w = F.pad(gradient2_w, [2, 2, 0, 0], 'replicate')
    return gradient_h*gradient2_h, gradient_w*gradient2_w

 #----------- 3.normalize01: 将输入图像进行归一化到0到1的范围内---------------------#
def normalize01(img):
    minv = img.min()
    maxv = img.max()
    return (img-minv)/(maxv-minv)

 #----------- 4.gaussianblur3: 3通道的高斯模糊---------------------------#
def gaussianblur3(input):
    slice1 = F.conv2d(input[:,0,:,:].unsqueeze(1), weight=conf.gaussian_kernel, padding=conf.g_padding)
    slice2 = F.conv2d(input[:,1,:,:].unsqueeze(1), weight=conf.gaussian_kernel, padding=conf.g_padding)
    slice3 = F.conv2d(input[:,2,:,:].unsqueeze(1), weight=conf.gaussian_kernel, padding=conf.g_padding)
    x = torch.cat([slice1,slice2, slice3], dim=1)
    return x

 #----------- 5.illumination_smooth_loss: 计算光照平滑损失---------------------------#
def illumination_smooth_loss(image, illumination):
    gray_tensor = 0.299*image[0,0,:,:] + 0.587*image[0,1,:,:] + 0.114*image[0,2,:,:]
    max_rgb, _ = torch.max(image, 1)
    max_rgb = max_rgb.unsqueeze(1)
    gradient_gray_h, gradient_gray_w = gradient(gray_tensor.unsqueeze(0).unsqueeze(0))
    gradient_illu_h, gradient_illu_w = gradient(illumination)
    weight_h = 1/(F.conv2d(gradient_gray_h, weight=conf.gaussian_kernel, padding=conf.g_padding)+0.0001)
    weight_w = 1/(F.conv2d(gradient_gray_w, weight=conf.gaussian_kernel, padding=conf.g_padding)+0.0001)
    weight_h.detach()
    weight_w.detach()
    loss_h = weight_h * gradient_illu_h
    loss_w = weight_w * gradient_illu_w
    max_rgb.detach()
    return loss_h.sum() + loss_w.sum() + torch.norm(illumination-max_rgb, 1)

 #----------- 6.reflectance_smooth_loss:计算反射率平滑损失---------------------------#
def reflectance_smooth_loss(image, illumination, reflectance):
    gray_tensor = 0.299*image[0,0,:,:] + 0.587*image[0,1,:,:] + 0.114*image[0,2,:,:]
    gradient_gray_h, gradient_gray_w = gradient(gray_tensor.unsqueeze(0).unsqueeze(0))
    gradient_reflect_h, gradient_reflect_w = gradient(reflectance)
    weight = 1/(illumination*gradient_gray_h*gradient_gray_w+0.0001)
    weight = normalize01(weight)
    weight.detach()
    loss_h = weight * gradient_reflect_h
    loss_w = weight * gradient_reflect_w
    refrence_reflect = image/illumination
    refrence_reflect.detach()
    return loss_h.sum() + loss_w.sum() + conf.reffac*torch.norm(refrence_reflect - reflectance, 1)

 #----------- 7.noise_loss: 计算噪声损失---------------------------#
def noise_loss(image, illumination, reflectance, noise):
    weight_illu = illumination
    weight_illu.detach()
    loss = weight_illu*noise
    return torch.norm(loss, 2)
(1)重构损失——reconstruction_loss

图像的分解组件必须满足Robust Retinex的公式,将RGB三个通道中最大强度值S的初始值,在此基础上约束反射图和噪声。

(2)光照损失——illumination_smooth_loss

通过平滑的光照图可以增强暗区域的纹理细节,公式中x和y是水平和垂直方向,Wx和Wy是确保图像平滑的权重参数。

权重与梯度呈反比,梯度大的地方权重小,梯度小的地方权重大,因此将高斯滤波G放在分母,这里公式中的I是输入图像转换成的灰度图,Wy的计算方式和Wx的相同。

(3)反射损失——reflectance_smooth_loss

通过平滑反射图来得到噪声,本身并没有直接得到噪声的损失,只是通过对反射图做总变分约束来去噪。

(4)噪声损失——noise_loss

为了增加图像的清晰度增加了图像的对比度,与此同时,图像的噪声也被放大,出于以下两点限制噪声:

  1. 噪声的范围需要被限制。
  2. 噪声可以平滑的反射图限制。


  ☀️2.3 Retinex操作—pipline.py

import os
import numpy as np
import cv2
import torch
import torch.optim as optim
import torch.nn as nn
from PIL import Image
from torchvision import transforms
import torch.nn.init as init

from model.RRDNet import RRDNet
from loss.loss_functions import reconstruction_loss, illumination_smooth_loss, reflectance_smooth_loss, noise_loss, normalize01
import conf

 #----------- retinex图像增强---------------------------#
def pipline_retinex(net, img):
    img_tensor = transforms.ToTensor()(img)  # [c, h, w] #将输入图像转换为张量,并调整形状
    img_tensor = img_tensor.to(conf.device)
    img_tensor = img_tensor.unsqueeze(0)     # [1, c, h, w]

    optimizer = optim.Adam(net.parameters(), lr=conf.lr)

    # iterations:迭代优化过程
    for i in range(conf.iterations+1):
        # forward:通过网络前向传播得到光照、反射率和噪声图像。
        illumination, reflectance, noise = net(img_tensor)  # [1, c, h, w]
        # loss computing:计算总损失,并进行反向传播优化网络参数。
        loss_recons = reconstruction_loss(img_tensor, illumination, reflectance, noise)  # 重构损失
        loss_illu = illumination_smooth_loss(img_tensor, illumination) # 光照损失
        loss_reflect = reflectance_smooth_loss(img_tensor, illumination, reflectance) #反射损失
        loss_noise = noise_loss(img_tensor, illumination, reflectance, noise) # 噪声损失

        loss = loss_recons + conf.illu_factor*loss_illu + conf.reflect_factor*loss_reflect + conf.noise_factor*loss_noise

        # backward
        net.zero_grad()
        loss.backward()
        optimizer.step()

        # log:每隔 100 次迭代打印日志,显示重建损失、光照损失、反射率损失和噪声损失的数值。
        if i%100 == 0:
            print("iter:", i, '  reconstruction loss:', float(loss_recons.data), '  illumination loss:', float(loss_illu.data), '  reflectance loss:', float(loss_reflect.data), '  noise loss:', float(loss_noise.data))


    # adjustment:对增强后的图像进行调整
    adjust_illu = torch.pow(illumination, conf.gamma)
    res_image = adjust_illu*((img_tensor-noise)/illumination)# 对增强后的图像进行调整
    res_image = torch.clamp(res_image, min=0, max=1)# 对调整后的图像进行限幅操作,确保像素值在 0 到 1 之间。

    if conf.device != 'cpu':
        res_image = res_image.cpu()
        illumination = illumination.cpu()
        adjust_illu = adjust_illu.cpu()
        reflectance = reflectance.cpu()
        noise = noise.cpu()
    
    # 将处理后的张量转换为 PIL 图像
    res_img = transforms.ToPILImage()(res_image.squeeze(0))
    illum_img = transforms.ToPILImage()(illumination.squeeze(0))
    adjust_illu_img = transforms.ToPILImage()(adjust_illu.squeeze(0))
    reflect_img = transforms.ToPILImage()(reflectance.squeeze(0))
    noise_img = transforms.ToPILImage()(normalize01(noise.squeeze(0)))

    return res_img, illum_img, adjust_illu_img, reflect_img, noise_img


if __name__ == '__main__':

    # Init Model
    net = RRDNet()
    net = net.to(conf.device)

    # Test
    img = Image.open(conf.test_image_path)

    res_img, illum_img, adjust_illu_img, reflect_img, noise_img = pipline_retinex(net, img)
    res_img.save('./test/result.jpg')
    illum_img.save('./test/illumination.jpg')
    adjust_illu_img.save('./test/adjust_illumination.jpg')
    reflect_img.save('./test/reflectance.jpg')
    noise_img.save('./test/noise_map.jpg')

这段代码基本都注释了,就不再详细讲解了~


🚀三、RDDNet代码复现

☀️3.1 环境配置

  • Python 3
  • PyTorch >= 0.4.1
  • PIL >= 6.1.0
  • Opencv-python>=3.4

☀️3.2 运行过程

这个也是运行比较简单,配好环境就行 。不再过多叙述~


☀️3.3 运行效果

没错,你怎么知道我去看邓紫棋演唱会啦~ 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/637196.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

vue项目中如何使用iconfont

导读:vue项目中引入iconfont的方式 iconfont 的三种使用方法 unicode 不常用Font class 像字体一样使用,默认黑色图标,无法修改颜色Symbol 支持多色图标,更灵活,推荐 一、unicode 略 二、Font class 方式一&#…

完美解决原生小程序点击地图markers上的点获取不到对应的坐标信息

需求:地图上有多个markes点,点击每一个获取对应的数据,再根据当前的坐标信息去调用导航。 出现的问题:每次点击的时候获取不到对应的坐标信息,获取到的信息显然不是想要的 原因: 因为你的id不是number类型&…

阿里云手机adb远程连接出现adb问题unauthorized解决

执行adb shell出现下面错误 adb.exe: device unauthorized. This adb servers $ADB_VENDOR_KEYS is not set Try adb kill-server if that seems wrong. Otherwise check for a confirmation dialog on your device.解决:导入和绑定adb的密钥 重启云手机

[Redis]常见数据和内部编码

相关命令 type (key) type 命令实际返回的就是当前键的数据结构类型,它们分别是:string(字符串)、list(列 表)、hash(哈希)、set(集合)、zset(有…

[36#]私有化部署地图套装(全球版)

私有化部署地图套装(全球版),是由全球高清卫星影像与100%全球水陆覆盖高程数据组成的全球三维地图套装。 私有化部署地图套装(全球版) 我们在《难以置信,谁还会用离线地球》一文中,为大家分享…

7 Series FPGAs Integrated Block for PCI Express IP核 Advanced模式配置详解(三)

1 TL Settings Transaction Layer (TL)设置只在Advanced模式下有效。 Endpoint: Unlock and PME_Turn_Off Messages: 与端点的电源管理相关,允许发送解锁和电源管理事件关闭消息。 Root Port: Error Messages: Error Correctable(错误可纠正&#xff09…

IO游戏设计思路

1、TCP ,UDP ,KCP ,QUIC TCP 协议最常用的协议 UDP协议非常规的协议,因为需要在线广播,貌似运营商会有一些影响 KCP 基于UDP的协议,GitHub - l42111996/java-Kcp: 基于java的netty实现的可靠udp网络库(kcp算法),包含fec实现&am…

增强版 Kimi:AI 驱动的智能创作平台,实现一站式内容生成(图片、PPT、PDF)!

前言 基于扣子 Coze 零代码平台,我们从零到一轻松实现了专属 Bot 机器人的搭建。 AI 大模型(LLM)、智能体(Agent)、知识库、向量数据库、知识图谱,RAG,AGI 的不同形态愈发显现,如何…

GEO数据挖掘-PCA、差异分析

From 生物技能树 GEO数据挖掘第二节 文章目录 探针注释自主注释流程(了解)PCA图、top1000基因热图探针注释查看示例代码 top 1000 sd 热图离散基因热图,top1000表达基因,只是看一下,不用放文章里 差异分析火山图差异基因热图转换id富集分析-K…

安装mpi4py与dlio_profiler_py的总结

安装mpi4py mpi4py是一个Python库,它提供了与MPI(Message Passing Interface)兼容的接口,使得Python程序能够利用MPI实现并行计算。mpi4py 的核心是基于MPI标准的C/C实现,它能够在高性能计算环境下进行高效的并行处理…

网页版收银系统比安装板收银系统的四大优势

在当今竞争激烈的零售市场中,高效的收银系统对于连锁实体店的管理至关重要。随着科技的不断发展,网页版收银系统成为越来越多零售企业的首选。网页版收银系统以其灵活性、可定制性和便利性,成为现代零售业的利器。本文将探讨网页版收银系统相…

pycharm 关闭项目卡死

PyCharm2023.3.4 关闭一直卡在 closing projects 解决办法: 打开PyCharm, 选择 Help -> Find Action -> 输入 Registry -> 禁用ide.await.scope.completion

MYSQL 集群

1.集群目的:负载均衡 解决高并发 高可用HA 服务可用性 远程灾备 数据有效性 类型:M M-S M-S-S M-M M-M-S-S 原理:在主库把数据更改(DDL DML DCL)记录到二进制日志中。 备库I/O线程将主库上的日志复制到自己的中继日志中。 备库SQL线程读取中继日志…

51cto已购买的视频怎么下载到电脑上?

在数字学习的浪潮中,51CTO已成为众多专业人士和爱好者的知识宝库。但购买了视频课程后,如何将其下载到电脑上以便离线学习呢?这不仅是技术问题,更是时间管理和学习效率的关键。本文将为您揭示简单而高效的步骤,无论您使…

前端面试项目细节重难点(已工作|做分享)

面试官提问:需求场景:页面上有一个单选框,有是否两个选项:当用户选择是,出现一个输入框,用户可以输入内容,给后端的保存接口传入参数radio和content这两个字段,值分别是用户选项和输…

西门子WINCC8.0VBS脚本学习讲解

WinCC VBS脚本置位/复位/取反 二进制变量 "TAG1_BOOL1" 进行置位复位取反操作 步骤:按钮-->对象属性-->事件-->单击鼠标VBS动作填入代码如下: 对二进制变量进行复位 对二进制变量进行置位 对二进制变量进行取反 VBS脚本数学运算/读写批处理 …

百度智能云参与信通院多项边缘计算标准编制,「大模型时代下云边端协同 AI 发展研讨会」成功召开

1 中国信通院联合业界制定、发布多项标准化成果,推动产业发展 大模型开启了 AI 原生时代,云边端协同 AI 构建了「集中式大规模训练」、「边缘分布式协同推理」新范式,有效降低推理时延和成本,提升数据安全和隐私性,也…

安卓App封装全攻略:利用小猪APP分发提升应用发布效率

在快速迭代的移动应用市场,高效且安全地分发安卓应用程序是开发者面临的一大挑战。安卓App封装技术,作为这一挑战的解决方案之一,不仅能够提升应用的安全性,还能简化分发流程。本文将深入探讨安卓App封装的核心概念,以…

小型发电机不发电原因和解决方法

小型发电机不发电可能由多种原因造成,以下是一些常见原因及其解决方法: 1.电池电量不足:小型发电机通常需要电池来启动。如果电池电量不足,可能导致发电机无法启动。此时,您可以使用充电设备对电池进行充电&#xff0…