深度学习之超分辨率算法——FRCNN

– 对之前SRCNN算法的改进

    1. 输出层采用转置卷积层放大尺寸,这样可以直接将低分辨率图片输入模型中,解决了输入尺度问题。
    2. 改变特征维数,使用更小的卷积核和使用更多的映射层。卷积核更小,加入了更多的激活层。
    3. 共享其中的映射层,如果需要训练不同上采样倍率的模型,只需要修改最后的反卷积层大小,就可以训练出不同尺寸的图片。
  • 模型实现
  • 在这里插入图片描述
import math
from torch import nn


class FSRCNN(nn.Module):
    def __init__(self, scale_factor, num_channels=1, d=56, s=12, m=4):
        super(FSRCNN, self).__init__()
        self.first_part = nn.Sequential(
            nn.Conv2d(num_channels, d, kernel_size=5, padding=5//2),
            nn.PReLU(d)
        )
        # 添加入多个激活层和小卷积核
        self.mid_part = [nn.Conv2d(d, s, kernel_size=1), nn.PReLU(s)]
        for _ in range(m):
            self.mid_part.extend([nn.Conv2d(s, s, kernel_size=3, padding=3//2), nn.PReLU(s)])
        self.mid_part.extend([nn.Conv2d(s, d, kernel_size=1), nn.PReLU(d)])
        self.mid_part = nn.Sequential(*self.mid_part)
        # 最后输出
        self.last_part = nn.ConvTranspose2d(d, num_channels, kernel_size=9, stride=scale_factor, padding=9//2,
                                            output_padding=scale_factor-1)

        self._initialize_weights()

    def _initialize_weights(self):
        # 初始化
        for m in self.first_part:
            if isinstance(m, nn.Conv2d):
                nn.init.normal_(m.weight.data, mean=0.0, std=math.sqrt(2/(m.out_channels*m.weight.data[0][0].numel())))
                nn.init.zeros_(m.bias.data)
        for m in self.mid_part:
            if isinstance(m, nn.Conv2d):
                nn.init.normal_(m.weight.data, mean=0.0, std=math.sqrt(2/(m.out_channels*m.weight.data[0][0].numel())))
                nn.init.zeros_(m.bias.data)
        nn.init.normal_(self.last_part.weight.data, mean=0.0, std=0.001)
        nn.init.zeros_(self.last_part.bias.data)

    def forward(self, x):
        x = self.first_part(x)
        x = self.mid_part(x)
        x = self.last_part(x)
        return x

以上代码中,如起初所说,将SRCNN中给的输出修改为转置卷积,并且在中间添加了多个11卷积核和多个线性激活层。且应用了权重初始化,解决协变量偏移问题。
备注:1
1卷积核虽然在通道的像素层面上,针对一个像素进行卷积,貌似没有什么作用,但是卷积神经网络的特性,我们在利用多个卷积核对特征图进行扫描时,单个卷积核扫描后的为sum©,那么就是尽管在像素层面上无用,但是在通道层面上进行了融合,并且进一步加深了层数,使网络层数增加,网络能力增强。

  • 上代码
  • train.py

训练脚本

import argparse
import os
import copy

import torch
from torch import nn
import torch.optim as optim
import torch.backends.cudnn as cudnn
from torch.utils.data.dataloader import DataLoader
from tqdm import tqdm

from models import FSRCNN
from datasets import TrainDataset, EvalDataset
from utils import AverageMeter, calc_psnr


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    # 训练文件
    parser.add_argument('--train-file', type=str,help="the dir of train data",default="./Train/91-image_x4.h5")
    # 测试集文件
    parser.add_argument('--eval-file', type=str,help="thr dir of test data ",default="./Test/Set5_x4.h5")
    # 输出的文件夹
    parser.add_argument('--outputs-dir',help="the output dir", type=str,default="./outputs")
    parser.add_argument('--weights-file', type=str)
    parser.add_argument('--scale', type=int, default=2)
    parser.add_argument('--lr', type=float, default=1e-3)
    parser.add_argument('--batch-size', type=int, default=16)
    parser.add_argument('--num-epochs', type=int, default=20)
    parser.add_argument('--num-workers', type=int, default=8)
    parser.add_argument('--seed', type=int, default=123)
    args = parser.parse_args()

    args.outputs_dir = os.path.join(args.outputs_dir, 'x{}'.format(args.scale))

    if not os.path.exists(args.outputs_dir):
        os.makedirs(args.outputs_dir)

    cudnn.benchmark = True
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

    torch.manual_seed(args.seed)

    model = FSRCNN(scale_factor=args.scale).to(device)
    criterion = nn.MSELoss()
    optimizer = optim.Adam([
        {'params': model.first_part.parameters()},
        {'params': model.mid_part.parameters()},
        {'params': model.last_part.parameters(), 'lr': args.lr * 0.1}
    ], lr=args.lr)

    train_dataset = TrainDataset(args.train_file)
    train_dataloader = DataLoader(dataset=train_dataset,
                                  batch_size=args.batch_size,
                                  shuffle=True,
                                  num_workers=args.num_workers,
                                  pin_memory=True)
    eval_dataset = EvalDataset(args.eval_file)
    eval_dataloader = DataLoader(dataset=eval_dataset, batch_size=1)

    best_weights = copy.deepcopy(model.state_dict())
    best_epoch = 0
    best_psnr = 0.0

    for epoch in range(args.num_epochs):
        model.train()
        epoch_losses = AverageMeter()

        with tqdm(total=(len(train_dataset) - len(train_dataset) % args.batch_size), ncols=80) as t:
            t.set_description('epoch: {}/{}'.format(epoch, args.num_epochs - 1))

            for data in train_dataloader:
                inputs, labels = data

                inputs = inputs.to(device)
                labels = labels.to(device)

                preds = model(inputs)

                loss = criterion(preds, labels)

                epoch_losses.update(loss.item(), len(inputs))

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                t.set_postfix(loss='{:.6f}'.format(epoch_losses.avg))
                t.update(len(inputs))

        torch.save(model.state_dict(), os.path.join(args.outputs_dir, 'epoch_{}.pth'.format(epoch)))

        model.eval()
        epoch_psnr = AverageMeter()

        for data in eval_dataloader:
            inputs, labels = data

            inputs = inputs.to(device)
            labels = labels.to(device)

            with torch.no_grad():
                preds = model(inputs).clamp(0.0, 1.0)

            epoch_psnr.update(calc_psnr(preds, labels), len(inputs))

        print('eval psnr: {:.2f}'.format(epoch_psnr.avg))

        if epoch_psnr.avg > best_psnr:
            best_epoch = epoch
            best_psnr = epoch_psnr.avg
            best_weights = copy.deepcopy(model.state_dict())

    print('best epoch: {}, psnr: {:.2f}'.format(best_epoch, best_psnr))
    torch.save(best_weights, os.path.join(args.outputs_dir, 'best.pth'))

test.py 测试脚本

import argparse

import torch
import torch.backends.cudnn as cudnn
import numpy as np
import PIL.Image as pil_image

from models import FSRCNN
from utils import convert_ycbcr_to_rgb, preprocess, calc_psnr


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--weights-file', type=str, required=True)
    parser.add_argument('--image-file', type=str, required=True)
    parser.add_argument('--scale', type=int, default=3)
    args = parser.parse_args()

    cudnn.benchmark = True
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

    model = FSRCNN(scale_factor=args.scale).to(device)

    state_dict = model.state_dict()
    for n, p in torch.load(args.weights_file, map_location=lambda storage, loc: storage).items():
        if n in state_dict.keys():
            state_dict[n].copy_(p)
        else:
            raise KeyError(n)

    model.eval()

    image = pil_image.open(args.image_file).convert('RGB')

    image_width = (image.width // args.scale) * args.scale
    image_height = (image.height // args.scale) * args.scale

    hr = image.resize((image_width, image_height), resample=pil_image.BICUBIC)
    lr = hr.resize((hr.width // args.scale, hr.height // args.scale), resample=pil_image.BICUBIC)
    bicubic = lr.resize((lr.width * args.scale, lr.height * args.scale), resample=pil_image.BICUBIC)
    bicubic.save(args.image_file.replace('.', '_bicubic_x{}.'.format(args.scale)))

    lr, _ = preprocess(lr, device)
    hr, _ = preprocess(hr, device)
    _, ycbcr = preprocess(bicubic, device)

    with torch.no_grad():
        preds = model(lr).clamp(0.0, 1.0)

    psnr = calc_psnr(hr, preds)
    print('PSNR: {:.2f}'.format(psnr))

    preds = preds.mul(255.0).cpu().numpy().squeeze(0).squeeze(0)

    output = np.array([preds, ycbcr[..., 1], ycbcr[..., 2]]).transpose([1, 2, 0])
    output = np.clip(convert_ycbcr_to_rgb(output), 0.0, 255.0).astype(np.uint8)
    output = pil_image.fromarray(output)
    # 保存图片
    output.save(args.image_file.replace('.', '_fsrcnn_x{}.'.format(args.scale)))

datasets.py

数据集的读取

import h5py
import numpy as np
from torch.utils.data import Dataset


class TrainDataset(Dataset):
    def __init__(self, h5_file):
        super(TrainDataset, self).__init__()
        self.h5_file = h5_file

    def __getitem__(self, idx):
        with h5py.File(self.h5_file, 'r') as f:
            return np.expand_dims(f['lr'][idx] / 255., 0), np.expand_dims(f['hr'][idx] / 255., 0)

    def __len__(self):
        with h5py.File(self.h5_file, 'r') as f:
            return len(f['lr'])


class EvalDataset(Dataset):
    def __init__(self, h5_file):
        super(EvalDataset, self).__init__()
        self.h5_file = h5_file

    def __getitem__(self, idx):
        with h5py.File(self.h5_file, 'r') as f:
            return np.expand_dims(f['lr'][str(idx)][:, :] / 255., 0), np.expand_dims(f['hr'][str(idx)][:, :] / 255., 0)

    def __len__(self):
        with h5py.File(self.h5_file, 'r') as f:
            return len(f['lr'])

工具文件utils.py

  • 主要用来测试psnr指数,图片的格式转换(悄悄说一句,opencv有直接实现~~~)
import torch
import numpy as np


def calc_patch_size(func):
    def wrapper(args):
        if args.scale == 2:
            args.patch_size = 10
        elif args.scale == 3:
            args.patch_size = 7
        elif args.scale == 4:
            args.patch_size = 6
        else:
            raise Exception('Scale Error', args.scale)
        return func(args)
    return wrapper


def convert_rgb_to_y(img, dim_order='hwc'):
    if dim_order == 'hwc':
        return 16. + (64.738 * img[..., 0] + 129.057 * img[..., 1] + 25.064 * img[..., 2]) / 256.
    else:
        return 16. + (64.738 * img[0] + 129.057 * img[1] + 25.064 * img[2]) / 256.


def convert_rgb_to_ycbcr(img, dim_order='hwc'):
    if dim_order == 'hwc':
        y = 16. + (64.738 * img[..., 0] + 129.057 * img[..., 1] + 25.064 * img[..., 2]) / 256.
        cb = 128. + (-37.945 * img[..., 0] - 74.494 * img[..., 1] + 112.439 * img[..., 2]) / 256.
        cr = 128. + (112.439 * img[..., 0] - 94.154 * img[..., 1] - 18.285 * img[..., 2]) / 256.
    else:
        y = 16. + (64.738 * img[0] + 129.057 * img[1] + 25.064 * img[2]) / 256.
        cb = 128. + (-37.945 * img[0] - 74.494 * img[1] + 112.439 * img[2]) / 256.
        cr = 128. + (112.439 * img[0] - 94.154 * img[1] - 18.285 * img[2]) / 256.
    return np.array([y, cb, cr]).transpose([1, 2, 0])


def convert_ycbcr_to_rgb(img, dim_order='hwc'):
    if dim_order == 'hwc':
        r = 298.082 * img[..., 0] / 256. + 408.583 * img[..., 2] / 256. - 222.921
        g = 298.082 * img[..., 0] / 256. - 100.291 * img[..., 1] / 256. - 208.120 * img[..., 2] / 256. + 135.576
        b = 298.082 * img[..., 0] / 256. + 516.412 * img[..., 1] / 256. - 276.836
    else:
        r = 298.082 * img[0] / 256. + 408.583 * img[2] / 256. - 222.921
        g = 298.082 * img[0] / 256. - 100.291 * img[1] / 256. - 208.120 * img[2] / 256. + 135.576
        b = 298.082 * img[0] / 256. + 516.412 * img[1] / 256. - 276.836
    return np.array([r, g, b]).transpose([1, 2, 0])


def preprocess(img, device):
    img = np.array(img).astype(np.float32)
    ycbcr = convert_rgb_to_ycbcr(img)
    x = ycbcr[..., 0]
    x /= 255.
    x = torch.from_numpy(x).to(device)
    x = x.unsqueeze(0).unsqueeze(0)
    return x, ycbcr


def calc_psnr(img1, img2):
    return 10. * torch.log10(1. / torch.mean((img1 - img2) ** 2))


class AverageMeter(object):
    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count


先跑他个几十轮~
在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/939543.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

深度学习从入门到精通——图像分割实战DeeplabV3

DeeplabV3算法 参数配置关于数据集的配置训练集参数 数据预处理模块DataSet构建模块测试一下数据集去正则化模型加载模块DeepLABV3 参数配置 关于数据集的配置 parser argparse.ArgumentParser()# Datset Optionsparser.add_argument("--data_root", typestr, defa…

大数据操作实验一

1.Postgresql 1.1 数据库的对象创建 1.1.1 创建数据库(Database) 鼠标右键database进行创建 1.1.2 创建图(Schema) 鼠标右键schema,然后创建schema图纸 1.1.3 创建表(Table) 鼠标右键Table,创建表 1.2数据库实列化 1.2.1 实列化静态数据 提…

IDEA2024如何创建一个普通的Java Web项目工程(JSP)

本章教程,主要介绍如何在IDEA2024 专业版本中,创建一个普通的Java Web项目。 一、新建项目 二、配置项目 依次点击File——Project Structure——Modules 修改路径中的web为webapp,然后点击Create Artifact默认保存。 至此,一个基础的Java web就创建完成了。

Linux下mysql 8.0安装教程

本文介绍了如何在Linux下安装MySQL8.0,供大家参考,具体内容如下 准备工作: mysql8.0 rpm文件 测试工具(比如 idea的database工具) 安装步骤: 1. 下载mysql的repo源,下载地址: 进入Linux系统,输入指令: 1 wgethttps://dev.mysql.com/get/mysql80-community-rele…

libaom 源码分析:熵编码模块介绍

AV1 熵编码原理介绍 关于AV1 熵编码原理介绍可以参考:AV1 编码标准熵编码技术概述libaom 熵编码相关源码介绍 函数流程图 核心函数介绍 av1_pack_bitstream 函数:该函数负责将编码后的数据打包成符合 AV1 标准的比特流格式;包括写入序列头 OBU 的函数 av1_write_obu_header…

[数据结构#1] 并查集 | FindRoot | Union | 优化 | 应用

目录 1. 并查集原理 问题背景 名称与编号映射 数据结构设计 2. 并查集基本操作 (1) 初始化 (2) 查询根节点 (FindRoot) (3) 合并集合 (Union) (4) 集合操作总结 并查集优化 (1) 路径压缩 (2) 按秩合并 3. 并查集的应用 (1) 统计省份数量 (2) 判断等式方程是否成…

Centos创建共享文件夹拉取文件

1.打开VMware程序,鼠标右检你的虚拟机,打开设置 2.点击选项——共享文件夹——总是启用 点击添加,设置你想要共享的文件夹在pc上的路径(我这里已经添加过了就不加了) 注意不要中文,建议用share&#xff0c…

Element@2.15.14-tree checkStrictly 状态实现父项联动子项,实现节点自定义编辑、新增、删除功能

背景:现在有一个新需求,需要借助树结构来实现词库的分类管理,树的节点是不同的分类,不同的分类可以有自己的词库,所以父子节点是互不影响的;同样为了选择的方便性,提出了新需求,选择…

java版电子招投标采购|投标|评标|竞标|邀标|评审招投标系统源码

招投标管理系统是一款适用于招标代理、政府采购、企业采购和工程交易等领域的企业级应用平台。该平台以项目为主线,从项目立项到项目归档,实现了全流程的高效沟通和协作。通过该平台,用户可以实时共享项目数据信息,实现规范化管理…

【Verilog HDL 入门教程】 —— 学长带你学Verilog(基础篇)

文章目录 一、Verilog HDL 概述1、Verilog HDL 是什么2、Verilog HDL产生的背景3、Verilog HDL 和 VHDL的区别 二、Verilog HDL 基础知识1、Verilog HDL 语言要素1.1、命名规则1.2、注释符1.3、关键字1.4、数值1.4.1、整数及其表示1.4.2、实数及其表示1.4.3、字符串及其表示 2、…

龙迅#LT7911E适用于EDP/DP/TPYE-C转MIPIDSI应用,支持图像处理功能,内置I2C,主应用副屏显示,投屏领域!

1. 描述 LT7911E 是一款高性能 eDP 转 MIPI D-PHY 转换器,旨在将 eDP 源连接到 MIPI 显示面板。 LT7911E 集成了一个符合 eDP1.4 标准的接收器,支持 1.62Gbps 至 5.67Gbps 的输入数据,以 270Mbps 的递增步长,以及一个 2 端口 D…

《算法SM9》题目

判断题 SM9密码算法系统参数由KGC选择。 A.正确 B.错误 正确答案A 多项选择题 SM9密码算法KGC是负责( )的可信机构。 A.选择系统参数 B.生成主密钥 C.生成用户标识 D.生成用户私钥 正确答案ABD 判断题 SM9密钥封装机制封装的秘密密钥是根据…

C语言——实现求出最大值

问题描述&#xff1a;利用C语言自定义函数求出一维数组里边最大的数字 //利用函数找最大数#include<stdio.h>int search(int s[9]) //查找函数 {int i , max s[0] , max_xia 0;for(i0;i<9;i){if(s[i] > max){max_xia i;max s[max_xia];}}return max; } in…

【尚硅谷 - SSM+SpringBoot+SpringSecurity框架整合项目 】项目打包并且本地部署

前后端分离开发&#xff1a;把一个项目拆成两部分进行开发&#xff0c;所以在打包的时候&#xff0c;需要使用不同的打包方式。 后端 – SpringBoot – jar包 前端 – Vue: 因为使用了vue-admin-template框架&#xff1a;所以先使用框架进行打包使用Nginx部署&#xff0c;通…

【SH】Ubuntu Server 24服务器搭建MySQL数据库研发笔记

文章目录 搭建服务器在线安装1. 更新软件包列表2. 安装MySQL3. 检查MySQL状态4. 修改密码5. 新增用户6. 设置局域网访问 离线安装下载安装包 常用命令参考文档在线安装日志 搭建服务器 作者羊大侠搭建的是 Ubuntu Server 24.04 LTS 服务器环境 搭建参考文档&#xff1a;【SH】…

容器化技术全面解析:Docker 与 Containerd 的深入解读

目录 Docker 简介 1. 什么是 Docker&#xff1f; 2. Docker 的核心组件 3. Docker 的主要功能 4. Docker 的优点 5. Docker 的使用场景 Containerd 简介 1. 什么是 Containerd&#xff1f; 2. Containerd 的核心特性 3. Containerd 的架构 4. Containerd 与 Docker 的…

华为数通最新题库 H12-821 HCIP稳定过人中

以下是成绩单和考试人员 HCIP H12-831 HCIP H12-725 安全中级

Webots控制器编程

本文主要内容是如何编写Webots控制器&#xff0c;使用语言为Python。 文章目录 1. 新增控制器2. Hello World Example3. 读取传感器4. 使用执行器5. 理解step和robot.step函数6. 同时使用传感器和执行器7. 控制器参数 1. 新增控制器 对机器人Robot新增控制器的方式&#x…

[SAP ABAP] 将内表数据转换为HTML格式

从sflight数据库表中检索航班信息&#xff0c;并将这些信息转换成HTML格式&#xff0c;然后下载或显示在前端 开发步骤 ① 自定义一个数据类型 ty_sflight 来存储航班信息 ② 声明内表和工作区变量&#xff0c;用于存储表头、字段、HTML内容和航班详细信息以及创建字段目录lt…

《算法SM4》题目

单项选择题 我国商用密码算法SM4迭代结构是&#xff08;&#xff09;。 A.平衡Fesitel网络结构 B.非平衡Fesitel网络结构 C.SP结构 D.MD结构 正确答案B 多项选择题 SM4分组密码算法轮函数中的T置换&#xff0c;包括的运算有&#xff08;&#xff09;。 A.非线性变换 …