超分辨率重建

意义

客观世界的场景含有丰富多彩的信息,但是由于受到硬件设备的成像条件和成像方式的限制,难以获得原始场景中的所有信息。而且,硬件设备分辨率的限制会不可避免地使图像丢失某些高频细节信息。在当今信息迅猛发展的时代,在卫星遥感、医学影像、多媒体视频等领域中对图像质量的要求越来越高,人们不断寻求更高质量和更高分辨率的图像,来满足日益增长的需求。

空间分辨率的大小是衡量图像质量的一个重要指标,也是将图像应用到实际生活中重要的参数之一。分辨率越高的图像含有的细节信息越多,图像清晰度越高,在实际应用中对各种目标的识别和判断也更加准确。

但是通过提高硬件性能从而提高图像的分辨率的成本高昂。因此,为了满足对图像分辨率的需求,又不增加硬件成本的前提下,依靠软件方法的图像超分辨率重建应运而生。

超分辨率图像重建是指从一系列有噪声、模糊及欠采样的低分辨率图像序列中恢复出一幅高分辨率图像的过程。可以针对现有成像系统普遍存在分辨率低的缺陷,运用某些算法,提高所获得低分辨率图像的质量。因此,超分辨率重建算法的研究具有广阔的发展空间。

方法的具体细节

评价指标
峰值信噪比

峰值信噪比(Peak Signal-to-Noise Ratio), 是信号的最大功率和信号噪声功率之比,来测量被压缩的重构图像的质量,通常以分贝来表示。PSNR指标值越高,说明图像质量越好。

SSIM计算公式如下:

PSNR=10\ast lg\frac{MAX_I^2}{MSE}

MSE表示两个图像之间对应像素之间差值平方的均值。

MAX^2_I表示图像中像素的最大值。对于8bit图像,一般取255。

MSE=\frac{1}{M\ast N} \displaystyle \sum_{i=1}^{N} \sum_{j=1}^{M}(f_{ij}-f'_{ij})^2

f_{ij} 表示图像X在 ij 处的像素值

f'_{ij} 表示图像Y在 ij 处的像素值

结构相似性评价

结构相似性评价(Structural Similarity Index), 是衡量两幅图像相似度的指标,取值范围为0到1。SSIM指标值越大,说明图像失真程度越小,图像质量越好。

SSIM计算公式如下:

L(X,Y)=\frac{2\mu X\mu Y +C_1}{\mu ^2_X + \mu ^2_Y + C_1}

C(X,Y)=\frac{2\sigma X\sigma Y +C_2}{\sigma ^2_X + \sigma ^2_Y + C_2}

S(X,Y)=\frac{\sigma _{XY} + C_3}{\sigma _X \sigma _Y + C_3}

SSIM(X,Y)=L(X,Y) \ast C(X,Y) \ast S(X,Y)

 这两种方式,一般情况下能较为准确地评价重建效果。但是毕竟人眼的感受是复杂丰富的,所以有时也会出现一定的偏差。

EDSR

img

SRResNet在SR的工作中引入了残差块,取得了更深层的网络,而EDSR是对SRResNet的一种提升,其最有意义的模型性能提升是去除掉了SRResNet多余的模块(BN层)

image-20211229150541634

EDSR把批规范化处理(batch normalization, BN)操作给去掉了。

论文中说,原始的ResNet最一开始是被提出来解决高层的计算机视觉问题,比如分类和检测,直接把ResNet的结构应用到像超分辨率这样的低层计算机视觉问题,显然不是最优的。由于批规范化层消耗了与它前面的卷积层相同大小的内存,在去掉这一步操作后,相同的计算资源下,EDSR就可以堆叠更多的网络层或者使每层提取更多的特征,从而得到更好的性能表现。EDSR用L1损失函数来优化网络模型。

1.解压数据集

因为训练时间可能不是很长,所以这里用了BSD100,可以自行更换为DIV2K或者coco

#  !unzip -o /home/aistudio/data/data121380/DIV2K_train_HR.zip -d train
# !unzip -o  /home/aistudio/data/data121283/Set5.zip -d test
 2.定义dataset
import os
from paddle.io import Dataset
from paddle.vision import transforms
from PIL import Image
import random
import paddle
import PIL
import numbers
import numpy as np
from PIL import Image
from paddle.vision.transforms import BaseTransform
from paddle.vision.transforms import functional as F
import matplotlib.pyplot as plt


class SRDataset(Dataset):

    def __init__(self, data_path, crop_size, scaling_factor):
        """
        :参数 data_path: 图片文件夹路径
        :参数 crop_size: 高分辨率图像裁剪尺寸  (实际训练时不会用原图进行放大,而是截取原图的一个子块进行放大)
        :参数 scaling_factor: 放大比例
        """

        self.data_path=data_path
        self.crop_size = int(crop_size)
        self.scaling_factor = int(scaling_factor)
        self.images_path=[]

        # 如果是训练,则所有图像必须保持固定的分辨率以此保证能够整除放大比例
        # 如果是测试,则不需要对图像的长宽作限定

        # 读取图像路径
        for name in os.listdir(self.data_path):
            self.images_path.append(os.path.join(self.data_path,name))

        # 数据处理方式
        self.pre_trans=transforms.Compose([
                                # transforms.CenterCrop(self.crop_size),
                                transforms.RandomCrop(self.crop_size),
                                transforms.RandomHorizontalFlip(0.5),
                                transforms.RandomVerticalFlip(0.5),
                                # transforms.ColorJitter(brightness=0.3, contrast=0.3, hue=0.3),
                                ])

        self.input_transform = transforms.Compose([
                                transforms.Resize(self.crop_size//self.scaling_factor),
                                transforms.ToTensor(),
                                transforms.Normalize(mean=[0.5],std=[0.5]),
                                ])

        self.target_transform = transforms.Compose([
                                transforms.ToTensor(),
                                transforms.Normalize(mean=[0.5],std=[0.5]),
                                ])


    def __getitem__(self, i):
        # 读取图像
        img = Image.open(self.images_path[i], mode='r')
        img = img.convert('RGB')
        img=self.pre_trans(img)

        lr_img = self.input_transform(img)
        hr_img = self.target_transform(img.copy())
        
        return lr_img, hr_img


    def __len__(self):
        return len(self.images_path)

测试dataset

# 单元测试

train_path='train/DIV2K_train_HR'
test_path='test'
ds=SRDataset(train_path,96,2)
l,h=ds[1]

# print(type(l))
print(l.shape)
print(h.shape)

l=np.array(l)
h=np.array(h)
print(type(l))
l=l.transpose(2,1,0)
h=h.transpose(2,1,0)
print(l.shape)
print(h.shape)

plt.subplot(1, 2, 1)
plt.imshow(((l+1)/2))
plt.title('l')
plt.subplot(1, 2, 2)
plt.imshow(((h+1)/2))
plt.title('h')
plt.show()

定义网络结构

较rsresnet少了归一化层,以及更深的残差块

from paddle.nn import Layer
from paddle import nn
import math


n_feat = 256
kernel_size = 3

# 残差块 尺寸不变
class _Res_Block(nn.Layer):
    def __init__(self):
        super(_Res_Block, self).__init__()
        self.res_conv = nn.Conv2D(n_feat, n_feat, kernel_size, padding=1)
        self.relu = nn.ReLU()

    def forward(self, x):
        y = self.relu(self.res_conv(x))
        y = self.res_conv(y)
        y *= 0.1
        # 残差加入
        y = paddle.add(y, x)
        return y


class EDSR(nn.Layer):
    def __init__(self):
        super(EDSR, self).__init__()

        in_ch = 3
        num_blocks = 32

        self.conv1 = nn.Conv2D(in_ch, n_feat, kernel_size, padding=1)
        # 扩大
        self.conv_up = nn.Conv2D(n_feat, n_feat * 4, kernel_size, padding=1)
        self.conv_out = nn.Conv2D(n_feat, in_ch, kernel_size, padding=1)

        self.body = self.make_layer(_Res_Block, num_blocks)
        # 上采样
        self.upsample = nn.Sequential(self.conv_up, nn.PixelShuffle(2))

    # 32个残差块
    def make_layer(self, block, layers):
        res_block = []
        for _ in range(layers):
            res_block.append(block())
        return nn.Sequential(*res_block)

    def forward(self, x):

        out = self.conv1(x)
        out = self.body(out)
        out = self.upsample(out)
        out = self.conv_out(out)

        return out

看paddle能不能用gpu

import paddle
print(paddle.device.get_device())


paddle.device.set_device('gpu:0')

训练,一般4个小时就可以达到一个不错的效果,set5中psnr可以达到27左右,当然这时间还是太少了

import os
from math import log10
from paddle.io import DataLoader
import paddle.fluid as fluid
import warnings
from paddle.static import InputSpec

if __name__ == '__main__':
    warnings.filterwarnings("ignore", category=Warning)  # 过滤报警信息

    train_path='train/DIV2K_train_HR'
    test_path='test'

    crop_size = 96      # 高分辨率图像裁剪尺寸
    scaling_factor = 2  # 放大比例

    # 学习参数
    checkpoint = './work/edsr_paddle'   # 预训练模型路径,如果不存在则为None
    batch_size = 30    # 批大小
    start_epoch = 0     # 轮数起始位置
    epochs = 10000        # 迭代轮数
    workers = 4         # 工作线程数
    lr = 1e-4           # 学习率

    # 先前的psnr
    pre_psnr=32.35

    try:
        model = paddle.jit.load(checkpoint)
        print('加载先前模型成功')
    except:
        print('未加载原有模型训练')
        model = EDSR()

    # 初始化优化器
    scheduler = paddle.optimizer.lr.StepDecay(learning_rate=lr, step_size=1, gamma=0.99, verbose=True)
    optimizer = paddle.optimizer.Adam(learning_rate=scheduler,
                                    parameters=model.parameters())

    criterion = nn.MSELoss()

    train_dataset = SRDataset(train_path, crop_size, scaling_factor)
    test_dataset = SRDataset(test_path, crop_size, scaling_factor)

    train_loader = DataLoader(train_dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=workers,
        )

    test_loader = DataLoader(test_dataset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=workers,
        )

    for epoch in range(start_epoch, epochs+1):

        model.train()  # 训练模式:允许使用批样本归一化
        train_loss=0
        n_iter_train = len(train_loader)
        train_psnr=0
        # 按批处理
        for i, (lr_imgs, hr_imgs) in enumerate(train_loader):
            lr_imgs = lr_imgs
            hr_imgs = hr_imgs

            sr_imgs = model(lr_imgs)
            loss = criterion(sr_imgs, hr_imgs)  
            optimizer.clear_grad()
            loss.backward()
            optimizer.step()
            train_loss+=loss.item()
            psnr = 10 * log10(1 / loss.item())
            train_psnr+=psnr

        epoch_loss_train=train_loss / n_iter_train
        train_psnr=train_psnr/n_iter_train

        print(f"Epoch {epoch}. Training loss: {epoch_loss_train} Train psnr {train_psnr}DB")


        model.eval()  # 测试模式
        test_loss=0
        all_psnr = 0
        n_iter_test = len(test_loader)

        with paddle.no_grad():
            for i, (lr_imgs, hr_imgs) in enumerate(test_loader):
                lr_imgs = lr_imgs
                hr_imgs = hr_imgs

                sr_imgs = model(lr_imgs)
                loss = criterion(sr_imgs, hr_imgs)

                psnr = 10 * log10(1 / loss.item())
                all_psnr+=psnr
                test_loss+=loss.item()
        
        epoch_loss_test=test_loss/n_iter_test
        epoch_psnr=all_psnr / n_iter_test

        print(f"Epoch {epoch}. Testing loss: {epoch_loss_test} Test psnr{epoch_psnr} dB")

        if epoch_psnr>pre_psnr:
            paddle.jit.save(model, checkpoint,input_spec=[InputSpec(shape=[1,3,48,48], dtype='float32')])
            pre_psnr=epoch_psnr
            print('模型更新成功')

        scheduler.step()

测试,需要自己上传一张低分辨率的图片

import paddle
from paddle.vision import transforms
import PIL.Image as Image
import numpy as np


imgO=Image.open('img_003_SRF_2_LR.png',mode="r")  #选择自己图片的路径
img=transforms.ToTensor()(imgO).unsqueeze(0)

#导入模型
net=paddle.jit.load("./work/edsr_paddle")

source = net(img)[0, :, :, :]
source = source.cpu().detach().numpy()  # 转为numpy
source = source.transpose((1, 2, 0))  # 切换形状
source = np.clip(source, 0, 1)  # 修正图片
img = Image.fromarray(np.uint8(source * 255))

plt.figure(figsize=(9,9))
plt.subplot(1, 2, 1)
plt.imshow(imgO)
plt.title('input')
plt.subplot(1, 2, 2)
plt.imshow(img)
plt.title('output')
plt.show()

img.save('./sr.png')

EDSR_X2效果

双线性插值放大效果

 EDSR_X2放大效果

 双线性插值放大效果

EDSR_X2放大效果

原文: EDSR图像超分重构

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/186800.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

数据结构与算法编程题21

判别两棵树是否相等。 #define _CRT_SECURE_NO_WARNINGS#include <iostream> using namespace std;typedef char ElemType; #define ERROR 0 #define OK 1typedef struct BiNode {ElemType data;BiNode* lchild, * rchild; }BiNode, * BiTree;bool Create_tree(BiTree&a…

python之pyqt专栏3-QT Designer

从前面两篇文章python之pyqt专栏1-环境搭建与python之pyqt专栏2-项目文件解析&#xff0c;我们对QT Designer有基础的认识。 QT Designer用来创建UI界面&#xff0c;保存的文件是"xxx.ui"文件&#xff0c;"xxx.ui"可以被pyuic转换为"xxx.py",而&…

html table样式的设计 表格边框修饰

<!DOCTYPE html> <html> <head> <meta http-equiv"Content-Type" content"text/html; charsetutf-8" /> <title>今日小说排行榜</title> <style> table {border-collapse: collapse;border: 4px double red; /*…

VC++彻底理解链接器:四,重定位

重定位 程序的运行过程就是CPU不断的从内存中取出指令然后执行执行的过程&#xff0c;对于函数调用来说比如我们在C/C语言中调用简单的加法函数add&#xff0c;其对应的汇编指令可能是这样的: call 0x4004fd 其中0x4004fd即为函数add在内存中的地址&#xff0c;当CPU执行这条…

2023大模型安全解决方案白皮书

今天分享的是大模型系列深度研究报告&#xff1a;《2023大模型安全解决方案白皮书》。 &#xff08;报告出品方&#xff1a;百度安全&#xff09; 报告共计&#xff1a;60页 前言 在当今迅速发展的数字化时代&#xff0c;人工智能技术正引领着科技创新的浪潮而其中的大模型…

Linux(7):Vim 程序编辑器

vi 基本上 vi 共分为三种模式&#xff0c;分别是【一般指令模式】、【编辑模式】与【指令列命令模式】。 这三种模式的作用分别是&#xff1a; 一般指令模式(command mode) 以 vi 打开一个文件就直接进入一般指令模式了(这是默认的模式&#xff0c;也简称为一般模式)。在这个模…

使用 HTML、CSS 和 JavaScript 创建图像滑块

使用 HTML、CSS 和 JavaScript 创建轮播图 在本文中&#xff0c;我们将讨论如何使用 HTML、CSS 和 JavaScript 构建轮播图。我们将演示两种不同的创建滑块的方法&#xff0c;一种是基于opacity的滑块&#xff0c;另一种是基于transform的。 创建 HTML 我们首先从 HTML 代码开…

WPF实战项目十七(客户端):数据等待加载弹框动画

1、在Common文件夹下新建文件夹Events&#xff0c;新建扩展类UpdateLoadingEvent public class UpdateModel {public bool IsOpen { get; set; }}internal class UpdateLoadingEvent : PubSubEvent<UpdateModel>{} 2、新建一个静态扩展类DialogExtensions来编写注册和推…

JSP EL 通过 三元运算符 控制界面 标签 标签属性内容

然后 我们来说说 EL配合三元运算符的妙用 我们先这样写 <% page contentType"text/html; charsetUTF-8" pageEncoding"UTF-8" %> <%request.setCharacterEncoding("UTF-8");%> <!DOCTYPE html> <html> <head>&l…

分布式篇---第六篇

系列文章目录 文章目录 系列文章目录前言一、说说什么是漏桶算法二、说说什么是令牌桶算法三、数据库如何处理海量数据?前言 前些天发现了一个巨牛的人工智能学习网站,通俗易懂,风趣幽默,忍不住分享一下给大家。点击跳转到网站,这篇文章男女通用,看懂了就去分享给你的码…

css三角,鼠标样式,溢出文字

目录 css三角 鼠标样式 例子&#xff1a;页码模块 溢出文字表示方式 margin负值运用 css三角强化 css三角 css三角中&#xff1a;line-height&#xff1a;0和font-size&#xff1a;0是防止兼容性的问题 jd {position: relative;width: 120px;height: 249px;background-…

【matlab程序】matlab利用工具包nctool读取grib2、nc、opendaf、hdf5、hdf4等格式数据

【matlab程序】matlab利用工具包nctool读取grib2、nc、opendaf、hdf5、hdf4等格式数据 引用&#xff1a; B. Schlining, R. Signell, A. Crosby, nctoolbox (2009), Github repository, https://github.com/nctoolbox/nctoolbox Brief summary: nctoolbox is a Matlab toolbox…

「Verilog学习笔记」数据串转并电路

专栏前言 本专栏的内容主要是记录本人学习Verilog过程中的一些知识点&#xff0c;刷题网站用的是牛客网 关于什么是Valid/Ready握手机制&#xff1a; 深入 AXI4 总线&#xff08;一&#xff09;握手机制 - 知乎 时序图含有的信息较多&#xff0c;观察时序图需要注意&#xff1a…

YOLOv8改进 | 2023 | MPDIoU、InnerMPDIoU助力细节涨点

论文地址&#xff1a;官方论文地址点击即可跳转 代码地址&#xff1a;官方并没有开源的该损失的函数的代码&#xff0c;我根据文章内容进行了复现代码块在第三章 一、本文介绍 本文为读者详细介绍了YOLOv8模型的最新改进&#xff0c;带来的改进机制是最新的损失函数MPDIoU和融…

Django必备知识点(图文详解)

目录 day02 django必备知识点 1.回顾 2.今日概要 3.路由系统 3.1 传统的路由 3.2 正则表达式路由 3.3 路由分发 小结 3.4 name 3.5 namespace 3.4 最后的 / 如何解决&#xff1f; 3.5 当前匹配对象 小结 4.视图 4.1 文件or文件夹 4.2 相对和绝对导入urls​编辑…

脏页刷新机制总结

1、Buffer Cache和Page Cache 一句话解释&#xff1a;Page Cache用于缓存文件的页数据&#xff0c;Buffer Cache用于缓存块设备&#xff08;磁盘&#xff09;的块数据。但由于磁盘都是由文件系统管理的&#xff0c;所以会导致数据会被缓存两次&#xff0c;因此现在Linux已经不再…

Python Web开发基础知识篇

一&#xff0c;基础知识篇 本片文章会简单地说一些python开发web中所必须的一些基础知识。主要包括HTML和css基础、JavaScript基础、网络编程基础、MySQL数据库基础、Web框架基础等知识。 1,Web简介 Web&#xff0c;全称为World Wide Web&#xff0c;也就是WWW&#xff0c;万…

mysql索引分为哪几类,聚簇索引和非聚簇索引的区别,MySQL索引失效的情况有哪几种情况,MySQL索引优化的手段,MySQL回表

文章目录 索引分为哪几类&#xff1f;聚簇索引和非聚簇索引的区别什么是[聚簇索引](https://so.csdn.net/so/search?q聚簇索引&spm1001.2101.3001.7020)&#xff1f;&#xff08;重点&#xff09;非聚簇索引 聚簇索引和非聚簇索引的区别主要有以下几个&#xff1a;什么叫回…

Leetcode103 二叉树的锯齿形层序遍历

二叉树的锯齿形层序遍历 题解1 层序遍历双向队列 给你二叉树的根节点 root &#xff0c;返回其节点值的 锯齿形层序遍历 。&#xff08;即先从左往右&#xff0c;再从右往左进行下一层遍历&#xff0c;以此类推&#xff0c;层与层之间交替进行&#xff09;。 提示&#xff1a…

激光塑料透光率检测仪进行材料质量监控

焊接质量检测是对焊接成果的检测&#xff0c;目的是保证焊接结构的完整性、可靠性、安全性和使用性。焊接质量检测通常包括外观检验、内部检查、无损检测以及试件制作与送检等步骤。通过这些检测方法&#xff0c;可以全面评估焊接质量&#xff0c;确保其符合设计要求和规范标准…