G4 - 可控手势生成 CGAN

  • 🍨 本文为🔗365天深度学习训练营 中的学习记录博客
  • 🍖 原作者:K同学啊

目录

  • 代码
  • 总结与心得


代码

关于CGAN的原理上节已经讲过,这次主要是编写代码加载上节训练后的模型来进行指定条件的生成

图像的生成其实只需要使用Generator模型,判别器模型是在训练过程中才用的。

# 库引入
from torch.autograd import Variable
from torchvision.utils import save_image, make_grid
import matplotlib.pyplot as plt
import torch.nn as nn
import torch.optim as optim
import torch

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# 超参数
latent_dim = 100
n_classes = 3
embedding_dim = 100

# 工具函数
def weights_init(m):
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        torch.nn.init.normal_(m.weight, 0.0, 0.02)
    elif classname.find('BatchNorm') != -1:
        torch.nn.init.normal_(m.weight, 1.0, 0.02)
        torch.nn.init.zeros_(m.bias)


# 模型
class Generator(nn.Module):
    def __init__(self):
        super().__init__()

        self.label_conditioned_generator = nn.Sequential(
            nn.Embedding(n_classes, embedding_dim),
            nn.Linear(embedding_dim, 16)
        )

        self.latent = nn.Sequential(
            nn.Linear(latent_dim, 4*4*512),
            nn.LeakyReLU(0.2, inplace=True)
        )

        self.model = nn.Sequential(
            nn.ConvTranspose2d(513, 64*8, 4, 2, 1, bias=False),
            nn.BatchNorm2d(64*8, momentum=0.1, eps=0.8),
            nn.ReLU(True),
            nn.ConvTranspose2d(64*8, 64*4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(64*4, momentum=0.1, eps=0.8),
            nn.ReLU(True),
            nn.ConvTranspose2d(64*4, 64*2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(64*2, momentum=0.1, eps=0.8),
            nn.ReLU(True),
            nn.ConvTranspose2d(64*2, 64*1, 4, 2, 1, bias=False),
            nn.BatchNorm2d(64*1, momentum=0.1, eps=0.8),
            nn.ReLU(True),
            nn.ConvTranspose2d(64*1, 3, 4, 2, 1, bias=False),
            nn.Tanh()
        )
    def forward(self, inputs):
        noise_vector, label = inputs
        label_output = self.label_conditioned_generator(label)
        label_output = label_output.view(-1, 1, 4, 4)
        latent_output = self.latent(noise_vector)
        latent_output = latent_output.view(-1, 512, 4, 4)

        concat = torch.cat((latent_output, label_output), dim=1)
        image = self.model(concat)
        return image

generator = Generator().to(device)
generator.apply(weights_init)
print(generator)
Generator(
  (label_conditioned_generator): Sequential(
    (0): Embedding(3, 100)
    (1): Linear(in_features=100, out_features=16, bias=True)
  )
  (latent): Sequential(
    (0): Linear(in_features=100, out_features=8192, bias=True)
    (1): LeakyReLU(negative_slope=0.2, inplace=True)
  )
  (model): Sequential(
    (0): ConvTranspose2d(513, 512, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (1): BatchNorm2d(512, eps=0.8, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace=True)
    (3): ConvTranspose2d(512, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (4): BatchNorm2d(256, eps=0.8, momentum=0.1, affine=True, track_running_stats=True)
    (5): ReLU(inplace=True)
    (6): ConvTranspose2d(256, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (7): BatchNorm2d(128, eps=0.8, momentum=0.1, affine=True, track_running_stats=True)
    (8): ReLU(inplace=True)
    (9): ConvTranspose2d(128, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (10): BatchNorm2d(64, eps=0.8, momentum=0.1, affine=True, track_running_stats=True)
    (11): ReLU(inplace=True)
    (12): ConvTranspose2d(64, 3, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (13): Tanh()
  )
)
from numpy.random import randint, randn
from numpy import linspace
from matplotlib import pyplot, gridspec

# 加载训练好的权重
generator.load_state_dict(torch.load('generator_epoch_300.pth'), strict=False)
# 关闭梯度积累
generator.eval()

# 生成随机变量
interpolated = randn(100)
interpolated = torch.tensor(interpolated).to(device).type(torch.float32)

# 生成条件变量
label = 0 # 生成第0个分类的图像
labels = torch.ones(1) * label
labels = labels.to(device).unsqueeze(1).long()

# 执行生成
predictions = generator((interpolated, labels))
predictions = predictions.permute(0, 2, 3, 1).detach().cpu()

# 屏蔽警告
import warnings
warnings.filterwarnings('ignore')


# 设置中文字体
plt.rcParams['font.sans-serif'] = ['SimHei']
# 防止负号无法显示
plt.rcParams['axes.unicode_minus']= False
# 设置图的分辨率
plt.rcParams['figure.dpi'] = 100

# 绘图
plt.figure(figsize=(8, 3))
pred = (predictions[0, :, :, :] + 1) * 127.5
pred = np.array(pred)
plt.imshow(pred.astype(np.uint8))
plt.show()

生成分类0
我们将分类修改为1重新生成一次

生成分类1

总结与心得

在本次实验的过程中,我了解了CGAN模型在训练完成后,后续如何使用的步骤:

  1. 保存训练好的生成器的权重
  2. 使用生成器加载
  3. 生成随机分布变量用于生成图像
  4. 生成指定的标签,并转换成控制向量
  5. 执行生成操作

另外关于警告和matplotlib设置中文字体的方式也是经常会用到的技巧。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mfbz.cn/a/666322.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

艾体宝方案 | redis赋能游戏开发,游戏玩家纵享丝滑

掉线,加载缓慢,反馈无跟进,这些令游戏玩家炸毛的问题,同时也是游戏开发者关注的问题。开发者将目光投向了Redis,一个实时数据平台,告别卡顿延迟! 一、玩家不掉线,游戏更丝滑 在大型…

Redis学习笔记【基础篇】

SQL vs NOSQL SQL(Structured Query Language)和NoSQL(Not Only SQL)是两种不同的数据库处理方式,它们在多个维度上有所差异,主要区别包括: 数据结构: SQL(关系型数据库)…

【OceanBase诊断调优】—— obdiag 工具助力OceanBase数据库诊断调优(DBA 从入门到实践第八期)

1. 前言 昨天给大家分享了【DBA从入门到实践】第八期:OceanBase数据库诊断调优、认证体系和用户实践 中obdiag的部分,今天将其中的内容以博客的形式给大家展开一下,方便大家阅读。 2. 正文 在介绍敏捷诊断工具之前,先说说OceanBa…

移植其他命令行Vivado IDE的工具

移植其他命令行Vivado IDE的工具 介绍 本章介绍如何迁移各种AMD命令行工具以在AMD中使用 Vivado™集成设计环境(IDE)。 迁移ISE Partgen命令行工具 ISE™Design Suite Partgen工具可获得: •系统上安装的所有设备的信息 •详细的包装信息 您可…

Leetcode刷题笔记7

69. x 的平方根 69. x 的平方根 - 力扣(LeetCode) 假设求17的平方根 解法一:暴力解法 从1开始依次尝试 比如1的平方是1,2的平方是4...直到5的平方,25>17,所以一定是4点几的平方,所以等于4…

YAML快速编写示例

一、案例 1.1 自主式创建service关联上方的pod 资源名称my-nginx-kkk命名空间my-kkk容器镜像nginx:1.21容器端口80标签njzb:my-kkk 1.1.1 创建一个demo文件夹 1.1.2 创建并获取模版文件 1.1.3 查看服务并编写yaml文件 1.1.4 编写yaml文件并部署,查看服务是否运行成…

Nginx过滤指定日志不输出

参考Nginx过滤指定日志不输出 | LogDicthttps://www.logdict.com/archives/nginxguo-lu-xin-tiao-ri-zhi

File(文件)

File对象表示一个路径,可以是文件的路径,也可以是文件夹的路径。 这个路径可以存在,也允许不存在。 创建File对象的方法 public class test {public static void main(String [] args) {//根据字符串创建文件String str"C:\\Users\\PC…

FFmpeg 中 Filters 使用文档介绍

描述 这份文档描述了由libavfilter库提供的过滤器Filters、源sources和接收器sinks。 滤镜介绍 FFmpeg通过libavfilter库启用过滤功能。在libavfilter中,一个过滤器可以有多个输入和多个输出。为了说明可能的类型,我们考虑以下过滤器图: 这个过滤器图将输入流分成两个流,然…

Source Insight 变量高亮快捷键F8 失效

SourceInsight4.0,使用的时候,高亮快捷键F8突然不能用了 查半天发现,是用了“有道翻译”的原因,热键冲突,如下,把下面的热键换一个就好了

【第七节】C++的STL基本使用

目录 前言 一、STL简介 1.1 STL基本概念 1.2 STL六大组件 1.3 STL优点 二、STL三大组件 2.1 容器 2.2 算法 2.3 迭代器 三、STL常见的容器 3.1 string容器 3.1.1 string容器基本概念 3.1.2 string容器的常用操作 3.1.2.1 string 构造函数 3.1.2.2 string 基本赋…

【AVL Design Explorer DOE】

AVL Design Explorer DOE 1、关于DOE的个人理解2、DOE参考资料-知乎2.1 DOE发展及基本类型2.2 DOE应用场景2.3 Mintab 中的 DOE工具3、AVL Design Explorer DOE示例 1、关于DOE的个人理解 仿真和试验一样,就像盲人摸象,在不知道大象的全景之前&#xff…

SpringCloud学习笔记(一)

SpringCloud、SpringCloud Alibaba 前置知识: 核心新组件: 所用版本: 学习方法: 1.看理论:官网 2.看源码:github 一、微服务理论知识 二、关于SpringCloud各种组件的停更/升级/替换 主业务逻辑是&#x…

Mysql基础教程(11):DISTINCT

MySQL DISTINCT 用法和实例 当使用 SELECT 查询数据时,我们可能会得到一些重复的行。比如学生表中有很多重复的年龄。如果想得到一个唯一的、没有重复记录的结果集,就需要用到 DISTINCT 关键字。 MySQL DISTINCT用法 在 SELECT 语句中使用 DISTINCT 关…

jsmug:一个针对JSON Smuggling技术的测试PoC环境

关于jsmug jsmug是一个代码简单但功能强大的JSON Smuggling技术环境PoC,该工具可以帮助广大研究人员深入学习和理解JSON Smuggling技术,并辅助提升Web应用程序的安全性。 背景内容 JSON Smuggling技术可以利用目标JSON文档中一些“不重要”的字节数据实…

判断自守数-第13届蓝桥杯选拔赛Python真题精选

[导读]:超平老师的Scratch蓝桥杯真题解读系列在推出之后,受到了广大老师和家长的好评,非常感谢各位的认可和厚爱。作为回馈,超平老师计划推出《Python蓝桥杯真题解析100讲》,这是解读系列的第75讲。 判断自守数&#…

发电机组故障的原因、解决方案及解决措施

发电机组故障的原因、解决方案及解决措施可以总结如下: 一、故障原因 供电中断 原因:电网故障、线路短路或电力负荷过重等。 燃油问题 原因:燃油供应系统问题,如燃油管路堵塞、燃油质量不佳等。 轴承过热 原因:轴承过…

Android 11 Audio音频系统配置文件解析

在AudioPolicyService的启动过程中,会去创建AudioPolicyManager对象,进而去解析配置文件 //frameworks/av/services/audiopolicy/managerdefault/AudioPolicyManager.cpp AudioPolicyManager::AudioPolicyManager(AudioPolicyClientInterface *clientIn…

19.1 简易抽奖

准备一个数组&#xff0c;里面添加10个奖品数据&#xff0c;让奖品数据快速的在盒子中随机显示&#xff0c;通过按钮控制盒子里面的内容停止。 效果图&#xff1a; <!DOCTYPE html> <html lang"en"> <head><meta charset"UTF-8">&…

01-1.2.3 算法的空间复杂度

什么是空间复杂度&#xff1f; 代码在运行之前需要先装入内存&#xff0c;程序代码需要占一定的位置&#xff08;在这边假设是100B&#xff09; 定义的变量和参数i&#xff0c;n都需要占用内存空间 //算法一——逐步递增型 void loveYou(int n) { //n为问题规模int i 1; /…