深度学习之超分辨率算法——SRGAN

  • 更新版本

  • 实现了生成对抗网络在超分辨率上的使用

  • 更新了损失函数,增加先验函数
    在这里插入图片描述

  • SRresnet实现

import torch
import torchvision
from torch import nn


class ConvBlock(nn.Module):

	def __init__(self, kernel_size=3, stride=1, n_inchannels=64):
		super(ConvBlock, self).__init__()

		self.sequential = nn.Sequential(
			nn.Conv2d(in_channels=n_inchannels, out_channels=n_inchannels, kernel_size=(kernel_size, kernel_size),
					  stride=(stride, stride), bias=False, padding=(1, 1)),
			nn.BatchNorm2d(n_inchannels),
			nn.PReLU(),
			nn.Conv2d(in_channels=n_inchannels, out_channels=n_inchannels, kernel_size=(kernel_size, kernel_size),
					  stride=(stride, stride), bias=False, padding=(1, 1)),
			nn.BatchNorm2d(n_inchannels),
			nn.PReLU(),
		)

	def forward(self, x):
		redisious = x
		out = self.sequential(x)
		return redisious + out


class Head_Conv(nn.Module):

	def __init__(self):
		super(Head_Conv, self).__init__()
		self.sequential = nn.Sequential(
			nn.Conv2d(in_channels=3, out_channels=64, kernel_size=(9, 9), stride=(1, 1), padding=(9 // 2, 9 // 2)),
			nn.PReLU(),
		)

	def forward(self, x):
		return self.sequential(x)


class PixelShuffle(nn.Module):

	def __init__(self, n_channels=64, upscale_factor=2):
		super(PixelShuffle, self).__init__()
		self.sequential = nn.Sequential(
			nn.Conv2d(in_channels=n_channels, out_channels=n_channels * (upscale_factor ** 2), kernel_size=(3, 3),
					  stride=(1, 1), padding=(3 // 2, 3 // 2)),
			nn.BatchNorm2d(n_channels * (upscale_factor ** 2)),
			nn.PixelShuffle(upscale_factor=upscale_factor)
		)

	def forward(self, x):
		return self.sequential(x)


class Hidden_block(nn.Module):

	def __init__(self):
		super(Hidden_block, self).__init__()
		self.sequential = nn.Sequential(
			nn.Conv2d(in_channels=64, out_channels=64, kernel_size=(3, 3), stride=(1, 1), padding=(3 // 2, 3 // 2)),
			nn.BatchNorm2d(64),
		)

	def forward(self, x):
		return self.sequential(x)


class TailConv(nn.Module):

	def __init__(self):
		super(TailConv, self).__init__()
		self.sequential = nn.Sequential(
			nn.Conv2d(in_channels=64, out_channels=3, kernel_size=(9, 9), stride=(1, 1), padding=(9 // 2, 9 // 2)),
			nn.Tanh(),
		)

	def forward(self, x):
		return self.sequential(x)


class SRResNet(nn.Module):

	def __init__(self, n_blocks=16):
		super(SRResNet, self).__init__()
		self.head = Head_Conv()
		self.resnet = list()
		for _ in range(n_blocks):
			self.resnet.append(ConvBlock(kernel_size=3, stride=1, n_inchannels=64))

		self.resnet = nn.Sequential(*self.resnet)
		self.hidden = Hidden_block()
		self.pixelShuufe = []
		for _ in range(2):
			self.pixelShuufe.append(
				PixelShuffle(n_channels=64, upscale_factor=2)
			)
		self.pixelShuufe = nn.Sequential(*self.pixelShuufe)
		self.tail_conv = TailConv()

	def forward(self, x):
		head_out = self.head(x)
		resnet_out = self.resnet(head_out)
		out = head_out + resnet_out
		result = self.pixelShuufe(out)
		out = self.tail_conv(result)
		return out

class Generator(nn.Module):

	def __init__(self):
		super(Generator, self).__init__()
		self.model = SRResNet()

	def forward(self, x):
		'''
		:param x:lr_img
		:return: 
		'''
		return self.model(x)


class Discriminator(nn.Module):

	def __init__(self):
		super(Discriminator, self).__init__()
		self.hidden = nn.Sequential(
			nn.Conv2d(in_channels=3, out_channels=64, kernel_size=(3, 3), stride=(1, 1), padding=(3 // 2, 3 // 2)),
			nn.LeakyReLU(),
			nn.Conv2d(in_channels=64, out_channels=64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1)),
			nn.BatchNorm2d(64),
			nn.LeakyReLU(),

			nn.Conv2d(in_channels=64, out_channels=128, kernel_size=(3, 3), stride=(1, 1), padding=(0, 0)),
			nn.BatchNorm2d(128),
			nn.LeakyReLU(),
			nn.Conv2d(in_channels=128, out_channels=128, kernel_size=(3, 3), stride=(2, 2), padding=(0, 0)),
			nn.BatchNorm2d(128),
			nn.LeakyReLU(),

			nn.Conv2d(in_channels=128, out_channels=256, kernel_size=(3, 3), stride=(1, 1), padding=(0, 0)),
			nn.BatchNorm2d(256),
			nn.LeakyReLU(),
			nn.Conv2d(in_channels=256, out_channels=256, kernel_size=(3, 3), stride=(2, 2), padding=(0, 0)),
			nn.BatchNorm2d(256),
			nn.LeakyReLU(),

			nn.Conv2d(in_channels=256, out_channels=512, kernel_size=(3, 3), stride=(1, 1), padding=(0, 0)),
			nn.BatchNorm2d(512),
			nn.LeakyReLU(),
			nn.Conv2d(in_channels=512, out_channels=512, kernel_size=(3, 3), stride=(2, 2), padding=(0, 0)),
			nn.BatchNorm2d(512),
			nn.LeakyReLU(),
			nn.AdaptiveAvgPool2d((6, 6))
		)
		self.out_layer = nn.Sequential(
			nn.Linear(512 * 6 * 6, 1024),
			nn.LeakyReLU(negative_slope=0.2, inplace=True),
			nn.Linear(1024, 1),
			nn.Sigmoid()
		)

	def forward(self, x):
		result = self.hidden(x)
		# print(result.shape)
		result = result.reshape(result.shape[0], -1)
		out = self.out_layer(result)
		return out

SRGAN模型的生成器与判别器的实现


class Generator(nn.Module):

	def __init__(self):
		super(Generator, self).__init__()
		self.model = SRResNet()

	def forward(self, x):
		'''
		:param x:lr_img
		:return: 
		'''
		return self.model(x)


class Discriminator(nn.Module):

	def __init__(self):
		super(Discriminator, self).__init__()
		self.hidden = nn.Sequential(
			nn.Conv2d(in_channels=3, out_channels=64, kernel_size=(3, 3), stride=(1, 1), padding=(3 // 2, 3 // 2)),
			nn.LeakyReLU(),
			nn.Conv2d(in_channels=64, out_channels=64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1)),
			nn.BatchNorm2d(64),
			nn.LeakyReLU(),

			nn.Conv2d(in_channels=64, out_channels=128, kernel_size=(3, 3), stride=(1, 1), padding=(0, 0)),
			nn.BatchNorm2d(128),
			nn.LeakyReLU(),
			nn.Conv2d(in_channels=128, out_channels=128, kernel_size=(3, 3), stride=(2, 2), padding=(0, 0)),
			nn.BatchNorm2d(128),
			nn.LeakyReLU(),

			nn.Conv2d(in_channels=128, out_channels=256, kernel_size=(3, 3), stride=(1, 1), padding=(0, 0)),
			nn.BatchNorm2d(256),
			nn.LeakyReLU(),
			nn.Conv2d(in_channels=256, out_channels=256, kernel_size=(3, 3), stride=(2, 2), padding=(0, 0)),
			nn.BatchNorm2d(256),
			nn.LeakyReLU(),

			nn.Conv2d(in_channels=256, out_channels=512, kernel_size=(3, 3), stride=(1, 1), padding=(0, 0)),
			nn.BatchNorm2d(512),
			nn.LeakyReLU(),
			nn.Conv2d(in_channels=512, out_channels=512, kernel_size=(3, 3), stride=(2, 2), padding=(0, 0)),
			nn.BatchNorm2d(512),
			nn.LeakyReLU(),
			nn.AdaptiveAvgPool2d((6, 6))
		)
		self.out_layer = nn.Sequential(
			nn.Linear(512 * 6 * 6, 1024),
			nn.LeakyReLU(negative_slope=0.2, inplace=True),
			nn.Linear(1024, 1),
			nn.Sigmoid()
		)

	def forward(self, x):
		result = self.hidden(x)
		# print(result.shape)
		result = result.reshape(result.shape[0], -1)
		out = self.out_layer(result)
		return out


```
- 针对VGG19 的层数截取
```python
class TruncatedVGG19(nn.Module):
	"""
	truncated VGG19网络,用于计算VGG特征空间的MSE损失
	"""

	def __init__(self, i, j):
		"""
		:参数 i: 第 i 个池化层
		:参数 j: 第 j 个卷积层
		"""
		super(TruncatedVGG19, self).__init__()

		# 加载预训练的VGG模型
		vgg19 = torchvision.models.vgg19(pretrained=True)
		print(vgg19)
		maxpool_counter = 0
		conv_count = 0
		truncate_at = 0
		# 迭代搜索
		for layer in vgg19.features.children():
			truncate_at += 1

			# 统计
			if isinstance(layer, nn.Conv2d):
				conv_count += 1
			if isinstance(layer, nn.MaxPool2d):
				maxpool_counter += 1
				conv_counter = 0

			# 截断位置在第(i-1)个池化层之后(第 i 个池化层之前)的第 j 个卷积层
			if maxpool_counter == i - 1 and conv_count == j:
				break

		# 检查是否满足条件
		assert maxpool_counter == i - 1 and conv_count == j, "当前 i=%d 、 j=%d 不满足 VGG19 模型结构" % (
			i, j)

		# 截取网络
		self.truncated_vgg19 = nn.Sequential(*list(vgg19.features.children())[:truncate_at + 1])

	def forward(self, input):
		output = self.truncated_vgg19(input)  # (N, channels, _w,h)

		return output
```


本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/941010.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

Pytorch | 利用PI-FGSM针对CIFAR10上的ResNet分类器进行对抗攻击

Pytorch | 利用PI-FGSM针对CIFAR10上的ResNet分类器进行对抗攻击 CIFAR数据集PI-FGSM介绍背景和动机算法原理算法流程 PI-FGSM代码实现PI-FGSM算法实现攻击效果 代码汇总pifgsm.pytrain.pyadvtest.py 之前已经针对CIFAR10训练了多种分类器: Pytorch | 从零构建AlexN…

IMX6ULL开发板如何关掉自带的QT的GUI界面和poky的界面的方法

重要说明:其实最后发现根本没必要去关掉自带的QT的GUI界面,直接把屏幕先刷黑就可以看到测试效果了,把屏蔽先刷黑的代码见博文: https://blog.csdn.net/wenhao_ir/article/details/144594705 不过,既然花了时间摸索如何…

【网络安全】逆向工程 练习示例

1. 逆向工程简介 逆向工程 (RE) 是将某物分解以了解其功能的过程。在网络安全中,逆向工程用于分析应用程序(二进制文件)的运行方式。这可用于确定应用程序是否是恶意的或是否存在任何安全漏洞。 例如,网络安全分析师对攻击者分发…

Docker Compose 安装 Harbor

我使用的系统是rocky Linux 9 1. 准备环境 确保你的系统已经安装了以下工具: DockerDocker ComposeOpenSSL(用于生成证书)#如果不需要通过https连接的可以不设置 1.1 安装 Docker 如果尚未安装 Docker,可以参考以下命令安装&…

深入浅出:多功能 Copilot 智能助手如何借助 LLM 实现精准意图识别

阅读原文 1. Copilot中的意图识别 如果要搭建一个 Copilot 智能助手,比如支持 知识问答、数据分析、智能托管、AIGC 等众多场景或能力,那么最核心的就是基于LLM进行意图识别分发能力,意图识别的准确率直接决定了 Copilot 智能助手的能力上限…

ZED-OpenCV项目运行记录

项目地址:GitCode - 全球开发者的开源社区,开源代码托管平台 使用 ZED 立体相机与 OpenCV 进行图像处理和深度感知 • 使用 ZED 相机和 OpenCV 库捕获图像、深度图和点云。 • 提供保存并排图像、深度图和点云的功能。 • 允许在不同格式之间切换保存的深度图和点云…

Linux 常见用例汇总

注:本文为 Linux 常见用例文章合辑。 部分内容已过时,未更新整理。 检查 Linux 上的 glibc 版本 译者:joeren | 2014-11-27 21:33 问:检查 Linux 系统上的 GNU C 库(glibc)的版本? GNU C 库&…

PHP阶段一

PHP 一门编程语言 运行在服务器端 专门用户开发网站的 脚本后缀名.php 与HTML语言进行混编,脚本后缀依然是.php 解释型语言,不要编译直接运行 PHP运行需要环境: Windows phpstudy Linux 单独安装 Web 原理简述 1、打开浏览器 2、输入u…

REMOTE_LISTENER引发的血案

作者:Digital Observer(施嘉伟) Oracle ACE Pro: Database PostgreSQL ACE Partner 11年数据库行业经验,现主要从事数据库服务工作 拥有Oracle OCM、DB2 10.1 Fundamentals、MySQL 8.0 OCP、WebLogic 12c OCA、KCP、PCTP、PCSD、P…

Redis篇--常见问题篇6--缓存一致性1(Mysql和Redis缓存一致,更新数据库删除缓存策略)

1、概述 在使用Redis作为MySQL的缓存层时,缓存一致性问题是指Redis中的缓存数据与MySQL数据库中的实际数据不一致的情况。这可能会导致读取到过期或错误的数据,从而影响系统的正确性和用户体验。 为了减轻数据库的压力,通常读操作都是先读缓…

Phono3py hdf5文件数据读取与处理

Phono3py是一个主要用python写的声子-声子相互作用相关性质的模拟包,可以基于有限位移算法实现三阶力常数和晶格热导率的计算过程,同时输出包括声速,格林奈森常数,声子寿命和累积晶格热导率等参量。 相关介绍和安装请参考往期推荐…

机器学习(四)-回归模型评估指标

文章目录 1. 哪个模型更好?2. 线性回归评估指标3. python 实现线性模型评估指标 1. 哪个模型更好? 我们之前已经对房价预测的问题构建了线性模型,并对测试集进行了预测。 如图所示,横坐标是地区人口,纵坐标是房价&am…

Oracle 适配 OpenGauss 数据库差异语法汇总

背景 国产化进程中,需要将某项目的数据库从 Oracle 转为 OpenGauss ,项目初期也是规划了适配不同数据库的,MyBatis 配置加载路径设计的是根据数据库类型加载指定文件夹的 xml 文件。 后面由于固定了数据库类型为 Oracle 后,只写…

Unity引擎学习总结------动画控件

左侧窗格可以在参数视图和图层视图之间切换。参数视图允许您创建、查看和编辑动画控制器参数。这些是您定义的变量,用作状态机的输入。要添加参数,请单击加号图标并从弹出菜单中选择参数类型。要删除参数,请在列表中选择该参数并按删除键&…

记录:virt-manager配置Ubuntu arm虚拟机

virt-manager(Virtual Machine Manager)是一个图形用户界面应用程序,通过libvirt管理虚拟机(即作为libvirt的图形前端) 因为要在Linux arm环境做测试,记录下virt-manager配置arm虚拟机的过程 先在VMWare中…

VSCode 搭建Python编程环境 2024新版图文安装教程(Python环境搭建+VSCode安装+运行测试+背景图设置)

名人说:一点浩然气,千里快哉风。—— 苏轼《水调歌头》 创作者:Code_流苏(CSDN) 目录 一、Python环境安装二、VScode下载及安装三、VSCode配置Python环境四、运行测试五、背景图设置 很高兴你打开了这篇博客,更多详细的安装教程&…

VBA编程:自定义函数 - 字符串转Hex数据

目录 一、自定义函数二、语法将字符串转换为hex数据MID函数:返回一个字符串中指定位置和长度的子串LEN函数:返回一个字符串的长度(字符数)Asc函数三、定义变量和数据类型变量声明的基本语法常见的数据类型四、For循环基本语法五、&运算符一、自定义函数 定义:用户定义…

jvm字节码中方法的结构

“-Xss”这一名称并没有一个特定的“为什么”来解释其命名,它更多是JVM(Java虚拟机)配置参数中的一个约定俗成的标识。在JVM中,有多个配置参数用于调整和优化Java应用程序的性能,这些参数通常以一个短横线“-”开头&am…

网络架构与IP技术:4K/IP演播室制作的关键支撑

随着科技的不断发展,广播电视行业也在不断迭代更新,其中4K/IP演播室技术的应用成了一个引人注目的焦点。4K超高清技术和IP网络技术的结合,不仅提升了节目制作的画质和效果,还为节目制作带来了更高的效率和灵活性。那么4K超高清技术…

Mac上Stable Diffusion的环境搭建(还算比较简单)

https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Installation-on-Apple-Silicon AI兴起的速度是真的快,感觉不了解点相关的东西都要与时代脱节了,吓得我赶紧找个AIGC看看能不能实现我艺术家的人梦想(绷不住了) 我…