Pytorch训练RCAN QAT超分模型

Pytorch训练RCAN QAT超分模型

  • 版本信息
  • 测试步骤
    • 准备数据集
    • 创建容器
    • 生成文件列表
      • 创建文件列表的代码
      • 执行脚本,生成文件列表
    • 训练RCAN模型
      • 准备工作
      • 修改开源代码
      • 编写训练代码
      • 执行训练脚本
    • 可视化

本文以RCAN超分模型为例,演示了QAT的训练过程,步骤如下:

  • 先训练FP32模型
  • 再加载FP32训练的权值,进行QAT训练
  • 连续5次loss没有下降则停止训练
  • 为了加快演示,当psnr大于33.0时就停止训练
  • 采用tensorboard观察Loss曲线

版本信息

属性
训练环境 搭建步骤
GPU型号 NVIDIA GeForce RTX 3080 12GB
数据集下载链接 http://data.vision.ee.ethz.ch/cvl/DIV2K/DIV2K_train_HR.zip
http://data.vision.ee.ethz.ch/cvl/DIV2K/DIV2K_valid_HR.zip
http://data.vision.ee.ethz.ch/cvl/DIV2K/DIV2K_train_LR_bicubic_X2.zip
http://data.vision.ee.ethz.ch/cvl/DIV2K/DIV2K_valid_LR_bicubic_X2.zip
开源模型结构 https://raw.githubusercontent.com/RussellEven/Multi-frame-RCAN/master/code/model/rcan.py
https://raw.githubusercontent.com/RussellEven/Multi-frame-RCAN/master/code/option.py
https://raw.githubusercontent.com/RussellEven/Multi-frame-RCAN/master/code/model/common.py
https://raw.githubusercontent.com/RussellEven/Multi-frame-RCAN/master/code/template.py

测试步骤

准备数据集

wget http://data.vision.ee.ethz.ch/cvl/DIV2K/DIV2K_train_HR.zip
wget http://data.vision.ee.ethz.ch/cvl/DIV2K/DIV2K_valid_HR.zip
wget http://data.vision.ee.ethz.ch/cvl/DIV2K/DIV2K_train_LR_bicubic_X2.zip
wget http://data.vision.ee.ethz.ch/cvl/DIV2K/DIV2K_valid_LR_bicubic_X2.zip

创建容器

按https://editor.csdn.net/md/?articleId=136176989的步骤构建镜像

docker stop rcan_dev
docker rm rcan_dev
nvidia-docker run -ti -e NVIDIA_VISIBLE_DEVICES=all --privileged \
				--net=host -p 6006:6006 -v $PWD:/home -w /home  \
				-v /mnt/disk/RCAN/:/RCAN --name rcan_dev  cuda_dev_image:v1.0 /bin/bash
conda activate ai_dev

生成文件列表

创建文件列表的代码

# generate_datalist.py

import os
import cv2
import numpy as np
from PIL import Image
from tqdm import tqdm

train_HR_path = './DIV2K_train_HR'
train_LR_path = './DIV2K_train_LR_bicubic/X2'
valid_HR_path = './DIV2K_valid_HR'
valid_LR_path = './DIV2K_valid_LR_bicubic/X2'

train_file = 'datalist_div2k_train.txt'
valid_file = 'datalist_div2k_valid.txt'

def get_images(input_path, format='png'):
    names = [os.path.splitext(fname)[0]
            for fname in os.listdir(input_path)
            if fname.endswith(format)]
    names.sort()
    return names

def get_folders(input_path):
    names = [directory 
            for directory in os.listdir(input_path)
            if os.path.isdir(os.path.join(input_path, directory))]
    names.sort()
    return names

the_train_file = open(train_file, 'w')
image_names = get_images(train_HR_path)
for image_name in image_names:
    the_train_file.write('DIV2K_train_LR_bicubic/X2/' + image_name + 
            'x2.png' + ' ' + 'DIV2K_train_HR/' + image_name + '.png' + '\n')
the_train_file.close()

the_valid_file = open(valid_file, 'w')
image_names = get_images(valid_HR_path)
for image_name in image_names: 
    the_valid_file.write('DIV2K_valid_LR_bicubic/X2/' + image_name + 
            'x2.png' + ' ' + 'DIV2K_valid_HR/' + image_name + '.png' + '\n')
the_valid_file.close()

执行脚本,生成文件列表

cd /RCAN/
unzip DIV2K_train_HR.zip
unzip DIV2K_valid_HR.zip
unzip DIV2K_train_LR_bicubic_X2.zip
unzip DIV2K_valid_LR_bicubic_X2.zip
python generate_datalist.py

训练RCAN模型

准备工作

# 安装依赖
pip install tensorboard -i https://pypi.tuna.tsinghua.edu.cn/simple
pip install scikit-image -i https://pypi.tuna.tsinghua.edu.cn/simple

# 设置环境变量
export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python

# 下载开源模型源码
cd /RCAN/
mkdir model
curl -L -o model/rcan.py https://raw.githubusercontent.com/RussellEven/Multi-frame-RCAN/master/code/model/rcan.py
curl -L -o model/option.py https://raw.githubusercontent.com/RussellEven/Multi-frame-RCAN/master/code/option.py
curl -L -o model/common.py https://raw.githubusercontent.com/RussellEven/Multi-frame-RCAN/master/code/model/common.py
curl -L -o template.py https://raw.githubusercontent.com/RussellEven/Multi-frame-RCAN/master/code/template.py

修改开源代码

  • model/rcan.py

image-20240220142852491

image-20240220144639916

  • model/common.py

    image-20240220143210588

编写训练代码

# train.py

import os
import torch
import torch.nn as nn
import torch.optim as optim
import json
import copy
import time
from torch.utils.data import DataLoader
from tqdm import tqdm
from torch.quantization.quantize_fx import prepare_qat_fx,convert_fx
from torch.ao.quantization import qconfig
from torch.ao.quantization.fake_quantize import *
from torch.ao.quantization.observer import *
from torch.utils import tensorboard
from torch.autograd import Variable
from torch.utils.data import Dataset
from skimage.color import rgb2hsv, hsv2rgb
import imageio
import random
import numpy as np

def _apply(func, x):

    if isinstance(x, (list, tuple)):
        return [_apply(func, x_i) for x_i in x]
    elif isinstance(x, dict):
        y = {
   }
        for key, value in x.items():
            y[key] = _apply(func, value)
        return y
    else:
        return func(x)

def get_patch(*args, patch_size=96, scale=2, input_large=False):
    ih, iw = args[0].shape[:2]

    if not input_large:
        p = scale
        tp = p * patch_size
        ip = tp // scale
    else:
        tp = patch_size
        ip = patch_size

    ix = random.randrange(0

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/411530.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

Qt QWidget 简约美观的加载动画 第二季

&#x1f603; 第二季来啦 &#x1f603; 简约的加载动画,用于网络查询等耗时操作时给用户的提示. 这是最终效果: 一共只有三个文件,可以直接编译运行 //main.cpp #include "LoadingAnimWidget.h" #include <QApplication> #include <QVBoxLayout> #i…

LeetCode704. 二分查找(C++)

LeetCode704. 二分查找 题目链接代码 题目链接 https://leetcode.cn/problems/binary-search/description/ 代码 class Solution { public:int search(vector<int>& nums, int target) {int left 0;int right nums.size() - 1;while(left < right){int midd…

主机字节序与网络字节序

大端序和小端序 大端序&#xff08;Big Endian&#xff09;和小端序&#xff08;Little Endian&#xff09;是两种计算机存储数据的方式。 大端序指的是将数据的高位字节存储在内存的低地址处&#xff0c;而将低位字节存储在内存的高地址处。这类似于我们阅读多位数时从左往右…

1.0 RK3399项目开发实录-Ubuntu环境搭建(物联技术666)

1.下载Ubuntu所需的版本&#xff1a;Index of /releases 2.安装vmplayer:Download VMware Workstation Player | VMware 3.安装Ubuntu时&#xff0c;磁盘空间尽量大些&#xff0c;开发板系统包都比较大&#xff0c;避免存不下&#xff0c;建议空间100G。 关闭Ubuntu自动更新…

卡玛网● 46. 携带研究材料 ● 01背包问题,你该了解这些! 滚动数组 力扣● 416. 分割等和子集

开始背包问题&#xff0c;掌握0-1背包和完全背包即可&#xff0c;注&#xff1a;0-1背包是完全背包的基础。 0-1背包问题&#xff1a;有n件物品和一个最多能背重量为w 的背包。第i件物品的重量是weight[i]&#xff0c;得到的价值是value[i] 。每件物品只能用一次&#xff0c;求…

【C进阶】顺序表详解

文章目录 &#x1f4dd;线性表的概念&#x1f320; 顺序表&#x1f309;顺序表的概念 &#x1f320;声明--接口&#x1f309;启动&#x1f320;初始化&#x1f309;扩容&#x1f320;尾插&#x1f309; 打印&#x1f320;销毁&#x1f309; 尾删&#x1f320;头插&#x1f309;…

内存函数(C语言进阶)

目录 前言 1、memcpy 2、memmove 3、memcmp 4、memset 结语 前言 本篇介绍了C语言中的内存函数&#xff0c;内存函数&#xff0c;顾名思义就是处理内存的函数。 1、memcpy memcpy&#xff1a;内存拷贝函数。 相对于strcpy只能拷贝字符串来讲&#xff0c;memcpy能拷…

Mysql学习之事务日志undolog深入剖析

Undo log redo log 是事务持久性的保证&#xff0c;undo log是事务原子性的保证。在事务中更新数据的前置操作其实是要先写入一个undo log。 如何理解undo 日志&#xff1f; 事务需要保证原子性&#xff0c;也就是事务中的操作要么全部完成&#xff0c;要么什么也不做。但有时…

kitti数据显示

画出track_id publish_utils.py中 def publish_3dbox(box3d_pub, corners_3d_velos, types, track_ids):marker_array MarkerArray()for i, corners_3d_velo in enumerate(corners_3d_velos):marker Marker()marker.header.frame_id FRAME_IDmarker.header.stamp rospy.T…

kubernetes的网络flannel与caclio

flannel网络 跨主机通信的一个解决方案是Flannel&#xff0c;由CoreOS推出&#xff0c;支持3种实现&#xff1a;UDP、VXLAN、host-gw udp模式&#xff1a;使用设备flannel.0进行封包解包&#xff0c;不是内核原生支持&#xff0c;上下文切换较大&#xff0c;性能非常差 vxlan模…

golang学习3,golang 项目中配置gin的web框架

1.go 初始化 mod文件 go mod init gin-ranking 2.gin的crm框架 go get -u github.com/gin-gonic/gin 3.go.mod爆红解决

五种多目标优化算法(MOFA、NSWOA、MOJS、MOAHA、MOPSO)性能对比(提供MATLAB代码)

一、5种多目标优化算法简介 多目标优化算法是用于解决具有多个目标函数的优化问题的一类算法。其求解流程通常包括以下几个步骤&#xff1a; 1. 定义问题&#xff1a;首先需要明确问题的目标函数和约束条件。多目标优化问题通常涉及多个目标函数&#xff0c;这些目标函数可能…

LeetCode刷题---从中序与后序遍历序列构造二叉树

解题思路: 首先还是定义哈希表将中序遍历的数插入进去&#xff0c;方便后序查阅 创建递归方法buildTreeNew(中序遍历数组&#xff0c;后序遍历数组&#xff0c;左或右子树在中序遍历数组中的起始和终止节点索引&#xff0c;左或右子树在后序遍历数组中的起始和终止节点索引) 后…

ThreadLocal从使用到实现原理与源码详解

ThreadLocal概述 ThreadLocal是多线程中对于解决线程安全的一个操作类&#xff0c;它会为每个线程都分配一个独立的线程副本从而解决了变量并发访问冲突的问题。ThreadLocal 同时实现了线程内的资源共享。 案例&#xff1a;使用JDBC操作数据库时&#xff0c;会将每一个线程的…

数据库系统概论(超详解!!!) 第一节 绪论

1.四个基本概念 1.数据&#xff08;Data&#xff09; 数据&#xff08;Data&#xff09;是数据库中存储的基本对象 数据的定义&#xff1a;描述事物的符号记录 数据的种类&#xff1a;数字、文字、图形、图像、音频、视频、学生的档案记录等 数据的含义称为数据的语义&…

第三百六十七回

文章目录 1. 概念介绍2. 方法与细节2.1 获取方法2.2 使用细节 3. 示例代码4. 内容总结 我们在上一章回中介绍了"如何获取当前系统语言"相关的内容&#xff0c;本章回中将介绍如何获取时间戳.闲话休提&#xff0c;让我们一起Talk Flutter吧。 1. 概念介绍 我们在本章…

HCIA-Datacom实验指导手册:5.1 实验一:FTP SFTP TFTP 基础配置实验

HCIA-Datacom实验指导手册&#xff1a;5.1 实验一&#xff1a;FTP 基础配置实验 一、实验介绍&#xff1a;二、实验拓扑&#xff1a;三、实验目的&#xff1a;四、配置步骤&#xff1a;步骤 1 设备基础配置步骤 2 在 Router 上配置 FTP 和SFTP服务器功能及参数步骤 3 配置本地 …

MySQL数据库基础(十五):PyMySQL使用介绍

文章目录 PyMySQL使用介绍 一、为什么要学习PyMySQL 二、安装PyMySQL模块 三、PyMySQL的使用 1、导入 pymysql 包 2、创建连接对象 3、获取游标对象 4、pymysql完成数据的查询操作 5、pymysql完成对数据的增删改 PyMySQL使用介绍 提前安装MySQL数据库&#xff08;可以…

js里面有引用传递吗?

一&#xff1a;什么是引用传递 引用传递是相对于值传递的。那什么是值传递呢&#xff1f;值传递就是在传递过程中再复制一份&#xff0c;然后再赋值给变量&#xff0c;例如&#xff1a; let a 2; let b a;在这个代码中&#xff0c;let b a; 就是一个值传递&#xff0c;首先…

js中浏览器渲染原理

JavaScript&#xff08;JS&#xff09;是一种广泛使用的编程语言&#xff0c;特别是在Web开发中。在浏览器中&#xff0c;JS被用于实现动态网页效果、交互性和用户体验的提升。然而&#xff0c;要理解JS在浏览器中的工作原理&#xff0c;我们首先需要了解浏览器的渲染过程。 浏…