Python图像处理【22】基于卷积神经网络的图像去雾

基于卷积神经网络的图像去雾

    • 0. 前言
    • 1. 渐进特征融合网络
    • 2. 图像去雾
      • 2.1 网络构建
      • 2.2 模型测试
    • 小结
    • 系列链接

0. 前言

单图像去雾 (dehazing) 是一个具有挑战性的图像恢复问题。为了解决这个问题,大多数算法都采用经典的大气散射模型,该模型是一种基于单一散射和均匀大气介质假设的简化物理模型,但现实环境中的雾霾表述更加复杂。

1. 渐进特征融合网络

在本节中,我们将学习如何使用输入自适应端到端深度学习预训练去雾模型,即渐进特征融合网络 (Progressive Feature Fusion Network, PFFNet),并通过使用 Pytorch 来执行模糊图像的去雾操作。渐进特征融合所采用的 U-Net 架构编码器 - 解码器网络,可直接学习从模糊图像到清晰图像的高度非线性转换函数。深度神经网架构如下图所示:

PFFNet
从以上体系结构图可以看出:

  • 编码器由五个卷积层组成,每个卷积层之后都有非线性 ReLU 激活函数;第一层用于从原始模糊图像中相对较大的局部感受野上的提取特征,然后,依次执行四次下采样卷积操作,以获取图像金字塔
  • 特征转换模块由基于残差的模块组成,深层网络可以表示非常复杂的特征,也可以学习到许多不同尺度的特征,但同时,在使用反向传播进行训练时,经常会遇到消失的梯度问题,而残差网络就是为了解决这一问题而被提出的,可以用于训练更深的网络
  • 解码器由四个反卷积层和一个卷积层组成,与编码器相反,解码器的反卷积层顺序堆叠以恢复图像结构细节

2. 图像去雾

2.1 网络构建

(1) 首先下载预训练网络模型,并导入所需的库,模块和函数:

import torch
import torch.nn as nn
import torch.nn.functional as F
from PIL import Image
from torch.autograd import Variable
from torchvision.transforms import ToTensor, ToPILImage, Normalize, Resize
#from torchviz import make_dot
import matplotlib.pylab as plt 

(2) 定义与深神经网络中不同层相对应的 ConvLayerUpsampleConvLayer 类,所有网络层都继承自 Pytorchnn.module 类;每个层都需要实现自己的 init() (用于初始化参数/成员变量/层)和 forward() 方法(定义前向传播过程中的计算):

class ConvLayer(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size, stride):
        super(ConvLayer, self).__init__()
        reflection_padding = kernel_size // 2
        self.reflection_pad = nn.ReflectionPad2d(reflection_padding)
        self.conv2d = nn.Conv2d(in_channels, out_channels, kernel_size, stride)

    def forward(self, x):
        out = self.reflection_pad(x)
        out = self.conv2d(out)
        return out

class UpsampleConvLayer(torch.nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size, stride):
      super(UpsampleConvLayer, self).__init__()
      reflection_padding = kernel_size // 2
      self.reflection_pad = torch.nn.ReflectionPad2d(reflection_padding)
      self.conv2d = nn.ConvTranspose2d(in_channels, out_channels, kernel_size, stride=stride)

    def forward(self, x):
        out = self.reflection_pad(x)
        out = self.conv2d(out)
        return out

(3) 接下来,我们用两个 ConvLayer 类实例定义类 ResidualBlock,在 ConvLayer 类实例之间使用 PReLU 激活函数,该类同样继承自 nn.module,并定义 forward() 方法用于前向传播:

class ResidualBlock(nn.Module):
    def __init__(self, channels):
        super(ResidualBlock, self).__init__()
        self.conv1 = ConvLayer(channels, channels, kernel_size=3, stride=1)
        self.conv2 = ConvLayer(channels, channels, kernel_size=3, stride=1)
        self.relu = nn.PReLU()

    def forward(self, x):
        residual = x
        out = self.relu(self.conv1(x))
        out = self.conv2(out) * 0.1
        out = torch.add(out, residual)
        return out 

(4) 定义继承自 nn.conv2d 类的 MeanShift 类,通过将 requires_grad 的参数设置为 False,冻结 MeanShift 层:

class MeanShift(nn.Conv2d):
    def __init__(self, rgb_range, rgb_mean, sign):
        super(MeanShift, self).__init__(3, 3, kernel_size=1)
        self.weight.data = torch.eye(3).view(3, 3, 1, 1)
        self.bias.data = float(sign) * torch.Tensor(rgb_mean) * rgb_range

        # Freeze the MeanShift layer
        for params in self.parameters():
            params.requires_grad = False

(5) 最后,根据所定义的神经网络层定义深度神经网络类 Net,该类同样需要定义 init() 方法。网络使用了五个 ConvLayer,然后使用四个 UPSampleconvLayer,最后通过 ConvLayer 层后输出,网络使用 LeakyReLU 作为激活函数。
同样,需要定义向前传播方法 forward(),并在每个激活函数后使用双线性上采样:

class Net(nn.Module):
    def __init__(self, res_blocks=18):
        super(Net, self).__init__()

        rgb_mean = (0.5204, 0.5167, 0.5129)
        self.sub_mean = MeanShift(1., rgb_mean, -1)
        self.add_mean = MeanShift(1., rgb_mean, 1)

        self.conv_input = ConvLayer(3, 16, kernel_size=11, stride=1)
        self.conv2x = ConvLayer(16, 32, kernel_size=3, stride=2)
        self.conv4x = ConvLayer(32, 64, kernel_size=3, stride=2)
        self.conv8x = ConvLayer(64, 128, kernel_size=3, stride=2)
        self.conv16x = ConvLayer(128, 256, kernel_size=3, stride=2)

        self.dehaze = nn.Sequential()
        for i in range(1, res_blocks):
            self.dehaze.add_module('res%d' % i, ResidualBlock(256))

        self.convd16x = UpsampleConvLayer(256, 128, kernel_size=3, stride=2)
        self.convd8x = UpsampleConvLayer(128, 64, kernel_size=3, stride=2)
        self.convd4x = UpsampleConvLayer(64, 32, kernel_size=3, stride=2)
        self.convd2x = UpsampleConvLayer(32, 16, kernel_size=3, stride=2)

        self.conv_output = ConvLayer(16, 3, kernel_size=3, stride=1)
()
        self.relu = nn.LeakyReLU(0.2)

    def forward(self, x):
        x = self.relu(self.conv_input(x))
        res2x = self.relu(self.conv2x(x))
        res4x = self.relu(self.conv4x(res2x))

        res8x = self.relu(self.conv8x(res4x))
        res16x = self.relu(self.conv16x(res8x))

        res_dehaze = res16x
        res16x = self.dehaze(res16x)
        res16x = torch.add(res_dehaze, res16x)

        res16x = self.relu(self.convd16x(res16x))
        res16x = F.upsample(res16x, res8x.size()[2:], mode='bilinear')
        res8x = torch.add(res16x, res8x)

        res8x = self.relu(self.convd8x(res8x))
        res8x = F.upsample(res8x, res4x.size()[2:], mode='bilinear')
        res4x = torch.add(res8x, res4x)

        res4x = self.relu(self.convd4x(res4x))
        res4x = F.upsample(res4x, res2x.size()[2:], mode='bilinear')
        res2x = torch.add(res4x, res2x)

        res2x = self.relu(self.convd2x(res2x))
        res2x = F.upsample(res2x, x.size()[2:], mode='bilinear')
        x = torch.add(res2x, x)

        x = self.conv_output(x)

        return x

(6) 定义预训练模型参数位置以及模型使用的残差块数量:

rb = 13
checkpoint = "I-HAZE_O-HAZE.pth"

(7) 实例化 Net() 类并使用 load_state_dict() 方法从检查点加载预训练权重。由于我们不需要训练模型,因此使用测试模式:

net = Net(rb)
net.load_state_dict(torch.load(checkpoint)['state_dict'])
net.eval()

2.2 模型测试

(1) 接下来,使用 open() 函数读取输入图像:

im_path = "pic.png"
im = Image.open(im_path)
h, w = im.size
print(h, w)

(2) 使用 torchvision.transforms 模块中的 ToTensor() 将图像转换为张量对象以输入网络,然后使用输入图像在模型上运行正向传递过程计算输出,最后将输出转换为图像:

imt = ToTensor()(im)
imt = Variable(imt).view(1, -1, w, h)
#im = im.cuda()
with torch.no_grad():
    imt = net(imt)
out = torch.clamp(imt, 0., 1.)
out = out.cpu()
out = out.data[0]
out = ToPILImage()(out)

def plot_image(image, title=None, sz=10):
    plt.imshow(image)
    plt.title(title, size=sz)
    plt.axis('off')
plt.figure(figsize=(20,10))
plt.subplot(121), plot_image(im, 'hazed input')
plt.subplot(122), plot_image(out, 'de-hazed output')
plt.tight_layout()
plt.show() 

去雾结果

小结

图像去雾已成为计算机视觉的重要研究方向,在雾、霾等恶劣天气下拍摄的的图像通常由于大气散射的作用,图像质量严重下降使颜色偏灰白色,对比度降低,物体特征难以辨认,还会影响图像的分析与处理。因此,需要使用图像去雾技术来增强或修复图像,以改善视觉效果并便于图像的后续处理。在本节中,我们学习了一种基于卷积神经网络的图像去雾模型,通过使用训练后的模型可以显著改善图像视觉效果。

系列链接

Python图像处理【1】图像与视频处理基础
Python图像处理【2】探索Python图像处理库
Python图像处理【3】Python图像处理库应用
Python图像处理【4】图像线性变换
Python图像处理【5】图像扭曲/逆扭曲
Python图像处理【6】通过哈希查找重复和类似的图像
Python图像处理【7】采样、卷积与离散傅里叶变换
Python图像处理【8】使用低通滤波器模糊图像
Python图像处理【9】使用高通滤波器执行边缘检测
Python图像处理【10】基于离散余弦变换的图像压缩
Python图像处理【11】利用反卷积执行图像去模糊
Python图像处理【12】基于小波变换执行图像去噪
Python图像处理【13】使用PIL执行图像降噪
Python图像处理【14】基于非线性滤波器的图像去噪
Python图像处理【15】基于非锐化掩码锐化图像
Python图像处理【16】OpenCV直方图均衡化
Python图像处理【17】指纹增强和细节提取
Python图像处理【18】边缘检测详解
Python图像处理【19】基于霍夫变换的目标检测
Python图像处理【20】图像金字塔
Python图像处理【21】基于卷积神经网络增强微光图像

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/447625.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

Git的基本操作(安装Git,创建本地仓库,配置Git,添加、修改、回退、撤销修改、删除文件)

文章目录 一、Git安装二、创建本地仓库三、配置Git四、认识工作区、暂存区、本地库五、添加文件六、修改文件七、版本回退八、撤销修改1.对于⼯作区的代码,还没有add2.已经add,但没有commit3.已经add,并且已经commit 九、删除⽂件 一、Git安装…

使用 ReclaiMe Pro 恢复任意文件系统(Win/Linux/MacOS)

天津鸿萌科贸发展有限公司是 ReclaiMe Pro 数据恢复软件授权代理商。 ReclaiMe Pro 是一个通用工具包,几乎可以用于从所有文件系统(从 Windows 系列文件系统、Linux 和 MacOS)中恢复数据。此外,考虑到数据恢复工作的具体情况&…

【组合递归回溯】【StringBuilder】Leetcode 17. 电话号码的字母组合

【组合递归回溯】【StringBuilde】Leetcode 17. 电话号码的字母组合 StringBulider常用方法!!!!!!!!!!!!!!17…

【webpack】和【vite】中获取本地文件夹目录下的所有图片

1. webpack&#xff08;require.context&#xff09; const systemUrls ref<{ url: string; name: string }[]>([]);// 获取该目录下的所有svg文件const files require.context(public/icon, false, /\.svg$/);systemUrls.value files.keys().map((key) > {const f…

ROS 2基础概念#4:消息(Message)| ROS 2学习笔记

ROS 2消息简介 ROS程序使用三种不同的接口来进行沟通&#xff1a;消息&#xff08;message&#xff09;&#xff0c;服务&#xff08;service&#xff09;和动作&#xff08;action&#xff09;。ROS 2使用一种简化的描述语言&#xff1a;IDL&#xff08;interface definition…

【深入理解LRU Cache】:缓存算法的经典之作

目录 一、什么是LRU Cache&#xff1f; 二、LRU Cache的实现 1.JDK中类似LRUCahe的数据结构LinkedHashMap 2.自己实现双向链表 三、LRU Cache的OJ 一、什么是LRU Cache&#xff1f; LRU Cache&#xff08;Least Recently Used的缩写&#xff0c;即最近最少使用&#xff0…

游戏行业需要堡垒机吗?用哪款堡垒机好?

相信大家对于游戏都不陌生&#xff0c;上到老&#xff0c;下到小&#xff0c;越来越多的小伙伴开始玩游戏。随着游戏用户的增加&#xff0c;如何保障用户资料安全&#xff0c;如何确保游戏公司数据安全等是一个不容忽视的问题。因此不少人在问&#xff0c;游戏行业需要堡垒机吗…

数据结构----完全二叉树的时间复杂度讲解,堆排序

目录 一.建堆的时间复杂度 1.向上调整算法建堆 2.向下调整算法建堆 二.堆排序 1.概念 2.代码思路 3.代码实现 一.建堆的时间复杂度 1.向上调整算法建堆 我们就以极端情况考虑时间复杂度(满二叉树遍历所有层) 假设所有节点个数为N,树的高度为h N 2^02^12^2......2^(h-…

表的连接【MySQL】

文章目录 什么是连接测试表内连接外连接左外连接右外连接全外连接 自然连接交叉连接参考资料 什么是连接 数据库的连接是指在数据库系统中&#xff0c;两个或多个数据表之间建立的关联关系&#xff0c;使它们可以进行数据的交互和操作。连接通常基于某种共同的字段或条件&…

2.1_2 数据通信基础知识

文章目录 2.1_2 数据通信基础知识&#xff08;一&#xff09;典型的数据通信模型&#xff08;二&#xff09;数据通信相关术语&#xff08;三&#xff09;设计数据通信系统要考虑的3个问题&#xff08;1&#xff09;三种通信方式&#xff08;2&#xff09;串行传输 & 并行传…

开源的python 游戏开发库介绍

本文将为您详细讲解开源的 Python 游戏开发库&#xff0c;以及它们的特点、区别和应用场景。Python 社区提供了多种游戏开发库&#xff0c;这些库可以帮助您在 Python 应用程序中实现游戏逻辑、图形渲染、声音处理等功能。 1. Pygame 特点 - 基于 Python 的游戏开发库。…

C语言分析基础排序算法——交换排序

目录 交换排序 冒泡排序 快速排序 Hoare版本快速排序 挖坑法快速排序 前后指针法快速排序 快速排序优化 快速排序非递归版 交换排序 冒泡排序 见C语言基础知识指针部分博客C语言指针-CSDN博客 快速排序 Hoare版本快速排序 Hoare版本快速排序的过程类似于二叉树前序…

程序员常用小工具推荐

前言 工作或者学习时&#xff0c;常常有一些工具能帮到我们很多&#xff0c;本次简单列举和说明&#xff0c;如果有更多更好用的&#xff0c;欢迎讨论补充。 工具大全 网络分析工具 Wireshark,可以很清晰的解析和过滤网络包&#xff0c;也有助于分析网络的的传输原理。linux环…

基于FPGA的HyperRam接口设计与实现

一 HyperRAM 针对一些低功耗、低带宽应用&#xff08;物联网、消费产品、汽车和工业应用等&#xff09;&#xff0c;涉及到外部存储&#xff0c;HyperRAM提供了更简洁的内存解决方案。 HyperRAM具有以下特性&#xff1a; 1、超低功耗&#xff1a;200MHz工作频率下读写不到50mW…

新书速览|Vue.js 3.x+Element Plus从入门到精通(视频教学版)

详解Vue.jsElement Plus框架各组件的用法&#xff0c;实战网上商城系统和图书借阅系统开发 本书内容 《Vue.js 3.xElement Plus从入门到精通&#xff1a;视频教学版》通过对Vue.js&#xff08;简称Vue&#xff09;的示例和综合案例的介绍与演练&#xff0c;使读者快速掌握Vue.j…

计算机网络—eNSP搭建基础 IP网络

目录 1.下载eNSP 2.启动eNSP 3.建立拓扑 4.建立一条物理连接 5.进入终端系统配置界面 6.配置终端系统 7.启动终端系统设备 8.捕获接口报文 9.生成接口流量 10.观察捕获的报文 1.下载eNSP 网上有许多下载eNSP的方式&#xff0c;记得还要下其它三个Virtual Box、Winpa…

HSCCTF 3th 2024 Web方向 题解wp

WEB-CHECKIN【*没出】 直接给了源码 <?php highlight_file(__FILE__); error_reporting(0); $a$_POST[1]; $b"php://filter/$a/resource/dev/null"; if(file_get_contents($b)"2024"){echo file_get_contents(/flag); }else{echo $b; }咋这么像 WEB…

python文件组织:包(package)、模块(module)、文件(file)

包&#xff1a; 模块所在的包&#xff0c;创建一个包用于组织多个模块&#xff0c;包文件夹中必须创建一个名为’__init__.py’的文件&#xff0c;以将其识别为包&#xff0c;否则只能算作是一个普通的目录。在使用该包时&#xff0c;init自动执行。 包可以多层嵌套&#xff…

使用 ReclaiMe Pro 进行 RAIDZ 数据恢复

天津鸿萌科贸发展有限公司是 ReclaiMe Pro 数据恢复软件授权代理商。 ZFS 是一个开源文件系统&#xff0c;主要用于 FreeNAS 和 NAS4Free 存储系统。在开发 ZFS 时&#xff0c;主要目标是可靠性&#xff0c;这是通过写时复制、冗余元数据、日志等不同功能来实现的。ZFS 使用自…

Redis核心数据结构之跳跃表

跳跃表 概述 跳跃表(skiplist)是一种有序数据结构&#xff0c;它通过在每个节点中维持多个指向其他节点的指针&#xff0c;从而达到快速访问节点的目的。跳跃表支持平均O(logN)、最坏O(N)复杂度的节点查找&#xff0c;还可以通过顺序性操作来批量处理节点。在大部分情况下&am…