python基于VGG19实现图像风格迁移

目录

1、原理

2、代码实现


1、原理

图像风格迁移是一种将一张图片的内容与另一张图片的风格进行合成的技术。

风格(style)是指图像中不同空间尺度的纹理、颜色和视觉图案,内容(content)是指图像的高级宏观结构。

实现风格迁移背后的关键概念与所有深度学习算法的核心思想是一样的:定义一个损失函数来指定想要实现的目标,然后将这个损失最小化。你知道想要实现的目标是什么,就是保存原始图像的内容,同时采用参考图像的风格。

在Python中,我们可以使用基于深度学习的模型来实现这一技术。​神经风格迁移可以用任何预训练卷积神经网络来实现。我们这里将使用    Gatys等人所使用的 VGG19网络。

2、代码实现

​以下是一个基于VGG19模型的简单图像风格迁移的实现过程:

(1)创建一个网络,它能够同时计算风格参考图像、目标图像和生成图像的 VGG19层激活。

(2)使用这三张图像上计算的层激活来定义之前所述的损失函数,为了实现风格迁移,需要将这个损失函数最小化。

(3)设置梯度下降过程来将这个损失函数最小化

from PIL import Image
import matplotlib.pyplot as plt
import numpy as np
​
import torch
import torch.optim as optim
from torchvision import transforms, models
​
vgg = models.vgg19(pretrained=True).features
​
for param in vgg.parameters():
    param.requires_grad_(False)
​
​
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
​
vgg.to(device)
​
​
def load_image(img_path, max_size=400):
​
    image = Image.open(img_path)
    
    if max(image.size) > max_size:
        size = max_size
    else:
        size = max(image.size)
        
    image_transform = transforms.Compose([
                        transforms.Resize(size),
                        transforms.ToTensor(),
                        transforms.Normalize((0.485, 0.456, 0.406), 
                                             (0.229, 0.224, 0.225))])
​
    image = image_transform(image).unsqueeze(0)
    
    return image
​
​
content = load_image('dogs_and_cats.jpg').to(device)
style = load_image('picasso.jpg').to(device)
​
​
assert style.size() == content.size(), "输入的风格图片和内容图片大小需要一致"
​
​
plt.ion()
def imshow(tensor,title=None):
    
    image = tensor.cpu().clone().detach()
    image = image.numpy().squeeze()
    image = image.transpose(1,2,0)
    image = image * np.array((0.229, 0.224, 0.225)) + np.array((0.485, 0.456, 0.406))
    plt.imshow(image)
    if title is not None:
        plt.title(title)
    plt.pause(0.1)
​
plt.figure()
imshow(style, title='Style Image')
​
plt.figure()
imshow(content, title='Content Image')
​
​
​
​
def get_features(image, model, layers=None):
   
    if layers is None:
        layers = {'0': 'conv1_1',
                  '5': 'conv2_1', 
                  '10': 'conv3_1', 
                  '19': 'conv4_1',
                  '21': 'conv4_2',  
                  '28': 'conv5_1'}
        
    features = {}
    x = image
    for name, layer in model._modules.items():
        x = layer(x)
        if name in layers:
            features[layers[name]] = x
            
    return features
​
content_features = get_features(content, vgg)
style_features = get_features(style, vgg)
​
def gram_matrix(tensor):
    
    _, d, h, w = tensor.size() 
    tensor = tensor.view(d, h * w)
    gram = torch.mm(tensor, tensor.t())
    return gram
​
style_grams={}
for layer in style_features:
  style_grams[layer] = gram_matrix(style_features[layer])
​
import torch.nn.functional as F
​
def ContentLoss(target_features,content_features):
  content_loss = F.mse_loss(target_features['conv4_2'],content_features['conv4_2'])
  return content_loss
​
def StyleLoss(target_features,style_grams,style_weights):
    style_loss = 0
    for layer in style_weights:
        target_feature = target_features[layer]
        target_gram = gram_matrix(target_feature)
        _, d, h, w = target_feature.shape
        style_gram = style_grams[layer]
        layer_style_loss = style_weights[layer] * F.mse_loss(target_gram,style_gram)
        style_loss += layer_style_loss / (d * h * w)
​
    return style_loss
​
​
style_weights = {'conv1_1': 1.,
                 'conv2_1': 0.75,
                 'conv3_1': 0.2,
                 'conv4_1': 0.2,
                 'conv5_1': 0.2}
​
alpha = 1  # alpha
beta = 1e6  # beta
​
​
show_every = 100
steps = 2000 
​
target = content.clone().requires_grad_(True).to(device)
optimizer = optim.Adam([target], lr=0.003)
​
​
for ii in range(1, steps+1):
    
    target_features = get_features(target, vgg)
    
    content_loss = ContentLoss(target_features,content_features)
    
    style_loss = StyleLoss(target_features,style_grams,style_weights)
        
    total_loss = alpha * content_loss + beta * style_loss
    
    optimizer.zero_grad()
    total_loss.backward()
    optimizer.step()
    #print(ii)
    
    if  ii % show_every == 0:
        print('Total loss: ', total_loss.item())
        plt.figure()      
        imshow(target)
​
plt.figure()
imshow(target,"Target Image")
plt.ioff()
plt.show()
​

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/111239.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

计算机网络重点概念整理-第四章 网络层【期末复习|考研复习】

计算机网络复习系列文章传送门: 第一章 计算机网络概述 第二章 物理层 第三章 数据链路层 第四章 网络层 第五章 传输层 第六章 应用层 第七章 网络安全 计算机网络整理-简称&缩写 文章目录 前言四、网络层4.1 网络层功能4.1.1 电路交换、报文交换与分组交换4.1…

电商数据采集抓取封装数据、淘宝、天猫、京东等平台商品详情API接口参数详解

电商数据采集抓取数据、淘宝、天猫、京东等平台的电商数据抓取,网页爬虫、采集网站数据、网页数据采集软件、python爬虫、HTM网页提取、APP数据抓包、APP数据采集、一站式网站采集技术、BI数据的数据分析、数据标注等成为大数据发展中的热门技术关键词。那么电商数据…

【git】git拉取代码报错,fatal: refusing to merge unrelated histories问题解决

大家好,我是好学的小师弟。今天准备将之前写的代码,拉到新的工程文件夹(仓库)下面,用了pull命令,结果报错了,报错截图如下 $ git pull https://gitee.com/* #仓库地址 fatal: refusing to merge unrelated histor…

12、SpringCloud -- redis库存和redis预库存保持一致、优化后的压测效果

目录 redis库存和redis预库存保持一致问题的产生需求:代码:测试:优化后的压测效果之前的测试数据优化后的测试数据redis库存和redis预库存保持一致 redis库存是指初始化是从数据库中获取最新的秒杀商品列表数据存到redis中 redis的预库存是指每个秒杀商品每次成功秒杀之后…

Spring循环依赖处理

循环依赖是指两个或多个组件之间相互依赖,形成一个闭环,从而导致这些组件无法正确地被初始化或加载。这种情况可能会在软件开发中引起问题,因为循环依赖会导致初始化顺序混乱,组件之间的关系变得复杂,甚至可能引发死锁…

汉诺塔问题

作者本文的目标是利用递归求解汉诺塔的具体步骤 目录 汉诺塔是什么游戏思路1.最简单的情况——一个圆盘依次类推,增加盘子个数2.两个圆盘3.三个圆盘解释递归过程 完整代码 汉诺塔是什么 汉诺塔(Tower of Hanoi),又称河内塔&#x…

Kubernetes - Ingress HTTP 升级 HTTPS 配置解决方案(新版本v1.21+)

之前我们讲解过 Kubernetes - Ingress HTTP 搭建解决方案,并分别提供了旧版本和新版本。如果连 HTTP 都没搞明白的可以先去过一下这两篇 Kubernetes - Ingress HTTP 负载搭建部署解决方案_放羊的牧码的博客-CSDN博客Kubernetes - Ingress HTTP 负载搭建部署解决方案…

学习笔记---更进一步的双向链表专题~~

目录 1. 双向链表的结构🦊 2. 实现双向链表🐝 2.1 要实现的目标🎯 2.2 创建初始化🦋 2.2.1 List.h 2.2.2 List.c 2.2.3 test.c 2.2.4 代码测试运行 2.3 尾插打印头插🪼 思路分析 2.3.1 List.h 2.3.2 List.…

企业电子招标采购系统源码Spring Boot + Mybatis + Redis + Layui + 前后端分离 构建企业电子招采平台之立项流程图

项目说明 随着公司的快速发展,企业人员和经营规模不断壮大,公司对内部招采管理的提升提出了更高的要求。在企业里建立一个公平、公开、公正的采购环境,最大限度控制采购成本至关重要。符合国家电子招投标法律法规及相关规范,以及审…

Ceph入门到精通-bluestore IO流程及导入导出

bluestore 直接管理裸设备,实现在用户态下使用linux aio直接对裸设备进行I/O操作 写IO流程: 一个I/O在bluestore里经历了多个线程和队列才最终完成,对于非WAL的写,比如对齐写、写到新的blob里等,I/O先写到块设备上&am…

0003net程序设计-net旅游景点推荐系统

文章目录 摘 要目录系统设计开发环境 摘 要 随着信息技术和网络技术的飞速发展,人类已进入全新信息化时代,传统管理技术已无法高效,便捷地管理信息。为了迎合时代需求,优化管理效率,各种各样的管理系统应运而生&#…

多模态 多引擎 超融合 新生态!2023亚信科技AntDB数据库8.0产品发布

9月20日,以“多模态 多引擎 超融合 新生态”为主题的亚信科技AntDB数据库8.0产品发布会成功举办,从技术和生态两个角度全方位展示了AntDB数据库第8次大型能力升级和生态建设成果。浙江移动、用友、麒麟软件、华录高诚、金云智联等行业伙伴及业界专家共同…

Goland连接服务器/虚拟机远程编译开发

创建SSH连接 SSH用于与远程服务器建立连接 Settings -> Tools -> SSH Configurations 添加新的ssh连接,Host为ip地址,Username为用户名,认证方式这里选择密码验证 全部填完后可以点击Test Connection测试连接是否成功 创建Deployment…

nginx 转发数据流文件

1.问题描述 后端服务,从数据库中查询日志,并生成表格文件返回静态文件。当数据量几兆时,返回正常,但是超过几十兆,几百兆,就会超过网关的连接超时时间30秒。 时序图 这里面主要花费时间的地方在&#xff…

SPSS单样本t检验

前言: 本专栏参考教材为《SPSS22.0从入门到精通》,由于软件版本原因,部分内容有所改变,为适应软件版本的变化,特此创作此专栏便于大家学习。本专栏使用软件为:SPSS25.0 本专栏所有的数据文件请点击此链接下…

在excel中如何打出上标、下标

例如,想把A2的2变为下标。 在单元中输入内容: 选中2: 右键单击,然后点击“设置单元格格式”: 在特殊效果的下面勾选“下标”,然后点击下面的“确定”按钮: 就将2变为下标了:…

线扫相机DALSA--采集卡Base模式设置

采集卡默认加载“1 X Full Camera Link”固件,Base模式首先要将固件更新为“2 X Base Camera Link”。 右键SCI图标,选择“打开文件所在的位置”,找到并打开SciDalsaConfig的Demo,如上图所示: 左键单击“获取相机”&a…

【错误解决方案】ModuleNotFoundError: No module named ‘xgboost‘

1. 错误提示 在尝试导入名为xgboost的模块时出现了ModuleNotFoundError。 错误提示:ModuleNotFoundError: No module named xgboost 这个错误通常意味着Python环境中没有安装你试图导入的模块。 2. 解决方案 安装xgboost模块即可解决上述问题。 可以通过Python…

对于SOCKET套接字问题的若干认识

1. 首先大家应该知道Socket 编程吧 Socket套接字 分为 应用层套接字 数据链路层套接字(也就是原始socket) 1.流套接字(SOCK_STREAM) 流套接字用于提供面向连接、可靠的数据传输服务。该服务将保证数据能够实现无差错、无重复送,并按顺序接…

智能运维第一步:HDD磁盘故障预测

当今数字化时代,信息技术扮演着企业和组织运营的关键角色。然而,随着IT环境不断复杂化和数据量激增,传统的运维管理方法已经无法满足日益增长的需求。为应对这一挑战,智能运维(Artificial intelligence for IT operati…