动手学深度学习(Pytorch版)代码实践 -计算机视觉-49风格迁移

49风格迁移

在这里插入图片描述
在这里插入图片描述

读入内容图像:

import torch
import torchvision
from torch import nn
import matplotlib.pylab as plt
import liliPytorch as lp
from d2l import torch as d2l


# 读取内容图像
content_img = d2l.Image.open('../limuPytorch/images/rainier.jpg')
plt.imshow(content_img)
plt.show()

在这里插入图片描述

读取风格图像:

# 读取风格图像
style_img = d2l.Image.open('../limuPytorch/images/autumn-oak.jpg')
plt.imshow(style_img)
plt.show()

在这里插入图片描述

import torch
import torchvision
from torch import nn
import matplotlib.pylab as plt
import liliPytorch as lp
from d2l import torch as d2l


# 读取内容图像
content_img = d2l.Image.open('../limuPytorch/images/rainier.jpg')
# plt.imshow(content_img)
# plt.show()

# 读取风格图像
style_img = d2l.Image.open('../limuPytorch/images/autumn-oak.jpg')
# plt.imshow(style_img)
# plt.show()


# 预处理和后处理
rgb_mean = torch.tensor([0.485, 0.456, 0.406])
rgb_std = torch.tensor([0.229, 0.224, 0.225])

# 函数preprocess对输入图像在RGB三个通道分别做标准化,
# 并将结果变换成卷积神经网络接受的输入格式
def preprocess(img, image_shape):
    transforms = torchvision.transforms.Compose([
        torchvision.transforms.Resize(image_shape),
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize(mean=rgb_mean, std=rgb_std)])
    return transforms(img).unsqueeze(0) # 增加一个通道

# 后处理函数postprocess则将输出图像中的像素值还原回标准化之前的值。 
# 由于图像打印函数要求每个像素的浮点数值在0~1之间,我们对小于0和大于1的值分别取0和1。
def postprocess(img):
    # img[0] 表示移除批次维度,从批次中提取出第一个图像
    img = img[0].to(rgb_std.device) # 移除批次维度,并将图像张量移动到与 rgb_std 相同的设备
    img = torch.clamp(img.permute(1, 2, 0) * rgb_std + rgb_mean, 0, 1) # 反转标准化过程
    return torchvision.transforms.ToPILImage()(img.permute(2, 0, 1))
    # ToPILImage() 期望的输入是 [C, H, W] 形式,因此需要再次将张量的通道维度移动到第一个位置。

# 抽取图像特征
# 使用基于ImageNet数据集预训练的VGG-19模型
# VGG19包含了19个隐藏层(16个卷积层和3个全连接层)
pretrained_net = torchvision.models.vgg19(pretrained=True)

"""
 一般来说,越靠近输入层,越容易抽取图像的细节信息;反之,则越容易抽取图像的全局信息。 
 为了避免合成图像过多保留内容图像的细节,我们选择VGG较靠近输出的层,即内容层,来输出图像的内容特征。 
 我们还从VGG中选择不同层的输出来匹配局部和全局的风格,这些图层也称为风格层。
"""
style_layers, content_layers = [0, 5, 10, 19, 28], [25]
# net 模型包含了 VGG-19 从第 0 层到第 28 层的所有层
net = nn.Sequential(*[pretrained_net.features[i] for i in
                      range(max(content_layers + style_layers) + 1)])

# 由于我们还需要中间层的输出,
# 因此这里我们逐层计算,并保留内容层和风格层的输出
def extract_features(X, content_layers, style_layers):
    contents = []
    styles = []
    for i in range(len(net)):
        X = net[i](X)
        if i in style_layers:
            styles.append(X)
        if i in content_layers:
            contents.append(X)
    return contents, styles

# 对内容图像抽取内容特征
def get_contents(image_shape, device):
    content_X = preprocess(content_img, image_shape).to(device)
    contents_Y, _ = extract_features(content_X, content_layers, style_layers)
    return content_X, contents_Y

# 对风格图像抽取风格特征
def get_styles(image_shape, device):
    style_X = preprocess(style_img, image_shape).to(device)
    _, styles_Y = extract_features(style_X, content_layers, style_layers)
    return style_X, styles_Y

# 定义损失函数
# 由内容损失、风格损失和全变分损失3部分组成

# 内容损失
# 内容损失通过平方误差函数衡量合成图像与内容图像在内容特征上的差异
# 平方误差函数的两个输入均为extract_features函数计算所得到的内容层的输出。
def content_loss(Y_hat, Y):
    # 我们从动态计算梯度的树中分离目标:
    # 这是一个规定的值,而不是一个变量。
    return torch.square(Y_hat - Y.detach()).mean()

# 风格损失
def gram(X): # 基于风格图像的格拉姆矩阵
    num_channels, n = X.shape[1], X.numel() // X.shape[1]
    X = X.reshape((num_channels, n))
    return torch.matmul(X, X.T) / (num_channels * n)

def style_loss(Y_hat, gram_Y):
    return torch.square(gram(Y_hat) - gram_Y.detach()).mean()

# 全变分损失
# 合成图像里面有大量高频噪点,即有特别亮或者特别暗的颗粒像素。 
# 一种常见的去噪方法是全变分去噪total variation denoising
def tv_loss(Y_hat):
    return 0.5 * (torch.abs(Y_hat[:, :, 1:, :] - Y_hat[:, :, :-1, :]).mean() +
                  torch.abs(Y_hat[:, :, :, 1:] - Y_hat[:, :, :, :-1]).mean())

"""
风格转移的损失函数是内容损失、风格损失和总变化损失的加权和。
通过调节这些权重超参数,我们可以权衡合成图像在保留内容、迁移风格以及去噪三方面的相对重要性。
"""
content_weight, style_weight, tv_weight = 1, 1e3, 10

def compute_loss(X, contents_Y_hat, styles_Y_hat, contents_Y, styles_Y_gram):
    # 分别计算内容损失、风格损失和全变分损失
    contents_l = [content_loss(Y_hat, Y) * content_weight for Y_hat, Y in zip(
        contents_Y_hat, contents_Y)]
    styles_l = [style_loss(Y_hat, Y) * style_weight for Y_hat, Y in zip(
        styles_Y_hat, styles_Y_gram)]
    tv_l = tv_loss(X) * tv_weight
    # 对所有损失求和
    l = sum(10 * styles_l + contents_l + [tv_l])
    return contents_l, styles_l, tv_l, l


# 初始化合成图像
class SynthesizedImage(nn.Module):
    def __init__(self, img_shape, **kwargs):
        super(SynthesizedImage, self).__init__(**kwargs)
        self.weight = nn.Parameter(torch.rand(*img_shape))

    def forward(self):
        return self.weight
    
# 函数创建了合成图像的模型实例,并将其初始化为图像X
def get_inits(X, device, lr, styles_Y):
    gen_img = SynthesizedImage(X.shape).to(device)
    gen_img.weight.data.copy_(X.data)
    trainer = torch.optim.Adam(gen_img.parameters(), lr=lr)
    styles_Y_gram = [gram(Y) for Y in styles_Y]
    return gen_img(), styles_Y_gram, trainer

# 训练模型
def train(X, contents_Y, styles_Y, device, lr, num_epochs, lr_decay_epoch):
    X, styles_Y_gram, trainer = get_inits(X, device, lr, styles_Y)  # 初始化合成图像和优化器
    scheduler = torch.optim.lr_scheduler.StepLR(trainer, lr_decay_epoch, 0.8)
    animator = lp.Animator(xlabel='epoch', ylabel='loss',
                            xlim=[10, num_epochs],
                            legend=['content', 'style', 'TV'],
                            ncols=2, figsize=(7, 2.5))
    for epoch in range(num_epochs):
        trainer.zero_grad()  # 梯度清零
        contents_Y_hat, styles_Y_hat = extract_features(
            X, content_layers, style_layers)  # 提取特征
        contents_l, styles_l, tv_l, l = compute_loss(
            X, contents_Y_hat, styles_Y_hat, contents_Y, styles_Y_gram)  # 计算损失
        l.backward()  # 反向传播计算梯度
        trainer.step()  # 更新模型参数
        scheduler.step()  # 更新学习率
        if (epoch + 1) % 10 == 0:
            animator.axes[1].imshow(postprocess(X))
            animator.add(epoch + 1, [float(sum(contents_l)),
                                     float(sum(styles_l)), float(tv_l)])
    return X

device, image_shape = d2l.try_gpu(), (300, 450)
net = net.to(device)
content_X, contents_Y = get_contents(image_shape, device)
_, styles_Y = get_styles(image_shape, device)
output = train(content_X, contents_Y, styles_Y, device, 0.3, 500, 50)
plt.show()

运行结果:
在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mfbz.cn/a/761428.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

J019_选择排序

一、排序算法 排序过程和排序原理如下图所示&#xff1a; 二、代码实现 package com.itheima.sort;import java.util.Arrays;public class SelectSort {public static void main(String[] args) {int[] arr {5, 4, 3, 1, 2};//选择排序for (int i 0; i < arr.length - 1…

django西餐厅管理系统-计算机毕业设计源码10873

摘要 在现代餐饮行业中&#xff0c;高效的管理系统对于西餐厅的成功运营至关重要。为了满足西餐厅日益增长的管理需求&#xff0c;设计并实现了一款基于 Python 的西餐厅管理系统。 Python作为一种简洁而易读的编程语言&#xff0c;具有广泛的应用领域&#xff0c;包括Web开发。…

MySQL5.7安装初始化错误解决方案

问题背景 今天在给公司配数据库环境时,第一次报initializing database 数据库初始化错误? 起初没管以为是安装软件原因,然后就出现以下错误:如下图 点开log,我们观察日志会发现 无法识别的参数 ‘mysqlx_port=0.0’,???,官方的安装程序还能出这问题?

docker的安装与基本使用

一.docker的安装卸载 1.先安装所需软件包 yum install -y yum-utils2.设置stable镜像仓库 yum-config-manager --add-repo http://mirrors.aliyun.com/docker-ce/linux/centos/docker-ce.repo 3.安装DOCKER CE yum -y install docker-ce docker-ce-cli containerd.io 4.验…

【SpringCloud】Eureka源码解析 上

Eureka是一个服务发现与注册组件&#xff0c;它包含服务端和客户端&#xff0c;服务端管理服务的注册信息&#xff0c;客户端简化服务实例与服务端的交互。我们结合源码来分析下eureka组件的实现原理&#xff0c;内容分为上下两章&#xff0c;第一章分析eureka的服务注册&#…

【每日一练】python if选择判断结构应用

应用类 1. 计算面积 编写一个Python程序&#xff0c;计算矩形的面积。要求用户输入矩形的宽和高&#xff0c;然后计算并打印面积。 width float(input("请输入矩形的宽&#xff1a;")) height float(input("请输入矩形的高&#xff1a;")) if width &…

《数字图像处理与机器视觉》案例三 (基于数字图像处理的物料堆积角快速测量)

一、前言 物料堆积角是反映物料特性的重要参数&#xff0c;传统的测量方法将物料自然堆积&#xff0c;测量物料形成的圆锥表面与水平面的夹角即可&#xff0c;该方法检测效率低。随着数字成像设备的推广和应用&#xff0c;应用数字图像处理可以更准确更迅速地进行堆积角测量。 …

Visual Studio 设置回车代码补全

工具 -> 选项 -> 文本编辑器 -> C/C -> 高级 -> 主动提交成员列表 设置为TRUE

萨科微slkor金航标kinghelm的品牌海外布局

萨科微slkor&#xff08;www.slkormicro.com&#xff09;金航标kinghelm宋仕强在介绍品牌的海外布局时说&#xff0c; 萨科微目前在土耳其、印度、新加坡均有代理商&#xff0c;海外代理商还在不断的发展和筛选中&#xff01;欢迎公司有资质&#xff0c;在当地有一定客户关系的…

Axure 中继器 实现选取表格行、列交互

Axure 中继器 实现选取表格行、列交互 引言 在办公软件或富文本编辑器中插入表格的时候我们经常可以通过在表格上移动鼠标&#xff0c;可以选取想要插入的表格行、列数。 本文分享如何通过 Axure 实现这个交互。 放入中继器 这个交互的用到了中继器&#xff0c;所以首先在…

WPF/C#:BusinessLayerValidation

BusinessLayerValidation介绍 BusinessLayerValidation&#xff0c;即业务层验证&#xff0c;是指在软件应用程序的业务逻辑层&#xff08;Business Layer&#xff09;中执行的验证过程。业务逻辑层是应用程序架构中的一个关键部分&#xff0c;负责处理与业务规则和逻辑相关的…

【C++】继承(详解)

前言&#xff1a;今天我们正式的步入C进阶内容的学习了&#xff0c;当然了既然是进阶意味着学习难度的不断提升&#xff0c;各位一起努力呐。 &#x1f496; 博主CSDN主页:卫卫卫的个人主页 &#x1f49e; &#x1f449; 专栏分类:高质量&#xff23;学习 &#x1f448; &#…

2065. 最大化一张图中的路径价值 Hard

给你一张 无向 图&#xff0c;图中有 n 个节点&#xff0c;节点编号从 0 到 n - 1 &#xff08;都包括&#xff09;。同时给你一个下标从 0 开始的整数数组 values &#xff0c;其中 values[i] 是第 i 个节点的 价值 。同时给你一个下标从 0 开始的二维整数数组 edges &#xf…

电子技术基础(模电部分)笔记

终于整理出来了&#xff0c;可以安心复习大物线代了&#xff01;&#xff01; 数电部分预计7.10出

007-GeoGebra基础篇-构建等边三角形

今天继续来一篇尺规作图&#xff0c;可以跟着操作一波&#xff0c;刚开始我写的比较细一点&#xff0c;每步都有截图&#xff0c;后续内容逐渐复杂后我就只放置算式咯。 目录 一、先看看一下最终效果二、本次涉及的内容三、开始尺规画图1. 绘制定点A和B2. 绘制线段AB3. 以点A为…

【日记】度过了一个堕落的周末……(184 字)

正文 昨天睡了一天觉&#xff0c;今天看了一天《三体》电视剧。真是堕落到没边了呢&#xff08;笑。本来想写代码完成年度计划&#xff0c;或者多写几篇文章&#xff0c;但实在不想写&#xff0c;也不想动笔。 感觉这个周末什么都没做呢&#xff0c;休息倒是休息好了。 今天 30…

基于x86/ARM+FPGA+AI工业相机的智能工艺缺陷检测,可以检测点状,线状,面状的缺陷

应用场景 缺陷检测 在产品的制造生产环节中发挥着极其重要作用。智能工业缺陷检测能够替代传统的人工检测&#xff0c;降低人为判断漏失&#xff0c;使得产品质量大幅提升的同时降低了工厂的人力成本。智能工艺缺陷检测技术可以检测点状&#xff0c;线状&#xff0c;面状的缺陷…

UnityUGUI之四 Mask

会将上级物体遮盖 注&#xff1a; 尽量不使用Mask&#xff0c;因为其会过度消耗运行资源&#xff0c;可以使用Rect 2DMask&#xff0c;但容易造成bug&#xff0c;因此最好实现遮罩效果的方式为自己写一个mask物体

用易查分下发《致家长一封信》,支持在线手写签名,一键导出PDF!

暑假来临之际&#xff0c;学校通常需要下发致家长信&#xff0c;以正式、书面的形式向家长传达重要的通知或建议。传统的发放方式如家长签字后学生将回执单上交&#xff0c;容易存在丢失、遗忘的问题。 那么如何更高效、便捷、安全地将致家长一封信送达给每位家长呢&#xff1f…

项目方案:社会视频资源整合接入汇聚系统解决方案(八)---视频监控汇聚应用案例

目录 一、概述 1.1 应用背景 1.2 总体目标 1.3 设计原则 1.4 设计依据 1.5 术语解释 二、需求分析 2.1 政策分析 2.2 业务分析 2.3 系统需求 三、系统总体设计 3.1设计思路 3.2总体架构 3.3联网技术要求 四、视频整合及汇聚接入 4.1设计概述 4.2社会视频资源分…