CGAN|生成手势图像|可控制生成

  •     🍨 本文为🔗365天深度学习训练营 中的学习记录博客
  • 🍦 参考文章:TensorFlow入门实战|第3周:天气识别
  • 🍖 原作者:K同学啊|接辅导、项目定制

CGAN(条件生成对抗网络)的原理是在原始GAN的基础上,为生成器和判别器提供 额外的条件信息。

CGAN通过将条件信息(如类别标签或其他辅助信息)加入生成器和判别器的输入中,使得生成器能够根据这些条件信息生成特定类型的数据,而判别器则负责区分真实数据和生成数据是否符合这些条件。这种方式让生成器在生成数据时有了明确的方向,从而提高了生成数据的质量与相关性。

CGAN的特点包括有监督学习、联合隐层表征、可控性、使用卷积结构等,其具体内容为:

有监督学习:CGAN通过额外信息的使用,将原本无监督的GAN转变为一种有监督的学习模式,这使得网络的训练更加目标明确,生成结果更加符合预期。
联合隐层表征:在生成模型中,噪声输入和条件信息共同构成了联合隐层表征,这有助于生成更多样化且具有特定属性的数据。
可控性:CGAN的一个关键特点是提高了生成过程的可控性,即可以通过调整条件信息来指导模型生成特定类型的数据。
使用卷积结构:CGAN可以采用卷积神经网络作为其内部结构,这在图像相关的任务中尤其有效,因为它能够捕捉到局部特征,并提高模型对细节的处理能力。

一、前期工作

import torch
import numpy as np
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.autograd import Variable
from torchvision.utils import save_image, make_grid
from torchsummary import summary
import matplotlib.pyplot as plt

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

batch_size = 128
train_transform = transforms.Compose([
    transforms.Resize(128),
    transforms.ToTensor(),
    transforms.Normalize([0.5,0.5,0.5], [0.5,0.5,0.5])])

train_dataset = datasets.ImageFolder(root="H:/G3/rps/rps", transform=train_transform)
train_loader = torch.utils.data.DataLoader(dataset=train_dataset, 
                                           batch_size=batch_size, 
                                           shuffle=True,
                                           num_workers=6)
def show_images(dl):
    for images, _ in dl:
        fig, ax = plt.subplots(figsize=(10, 10))
        ax.set_xticks([]); ax.set_yticks([])
        ax.imshow(make_grid(images.detach(), nrow=16).permute(1, 2, 0))
        break

show_images(train_loader)

 

二、构建模型 

latent_dim = 100
n_classes = 3
embedding_dim = 100

def weights_init(m):
    classname = m.__class__.__name__

    if classname.find('Conv') != -1:
        torch.nn.init.normal_(m.weight, 0.0, 0.02)
    
    elif classname.find('BatchNorm') != -1:
        torch.nn.init.normal_(m.weight, 1.0, 0.02)
        torch.nn.init.zeros_(m.bias)

class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()
        self.label_conditioned_generator = nn.Sequential(
            nn.Embedding(n_classes, embedding_dim), 
            nn.Linear(embedding_dim, 16)            
        )
        self.latent = nn.Sequential(
            nn.Linear(latent_dim, 4*4*512),  
            nn.LeakyReLU(0.2, inplace=True)  
        )

        self.model = nn.Sequential( 
            nn.ConvTranspose2d(513, 64*8, 4, 2, 1, bias=False),
            nn.BatchNorm2d(64*8, momentum=0.1, eps=0.8),  
            nn.ReLU(True),            
            nn.ConvTranspose2d(64*8, 64*4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(64*4, momentum=0.1, eps=0.8),
            nn.ReLU(True),     
            nn.ConvTranspose2d(64*4, 64*2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(64*2, momentum=0.1, eps=0.8),
            nn.ReLU(True),       
            nn.ConvTranspose2d(64*2, 64*1, 4, 2, 1, bias=False),
            nn.BatchNorm2d(64*1, momentum=0.1, eps=0.8),
            nn.ReLU(True),       
            nn.ConvTranspose2d(64*1, 3, 4, 2, 1, bias=False),
            nn.Tanh()  
        )

    def forward(self, inputs):
        noise_vector, label = inputs  
        label_output = self.label_conditioned_generator(label)     
        label_output = label_output.view(-1, 1, 4, 4)        
        latent_output = self.latent(noise_vector)     
        latent_output = latent_output.view(-1, 512, 4, 4) 
        concat = torch.cat((latent_output, label_output), dim=1)
        image = self.model(concat)
        return image

generator = Generator().to(device)
generator.apply(weights_init)
a = torch.ones(100)
b = torch.ones(1)
b = b.long()
a = a.to(device)
b = b.to(device)

import torch
import torch.nn as nn

class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()

        self.label_condition_disc = nn.Sequential(
            nn.Embedding(n_classes, embedding_dim),     
            nn.Linear(embedding_dim, 3*128*128)         
        )
        
        self.model = nn.Sequential(
            nn.Conv2d(6, 64, 4, 2, 1, bias=False),      
            nn.LeakyReLU(0.2, inplace=True),             
            nn.Conv2d(64, 64*2, 4, 3, 2, bias=False),    
            nn.BatchNorm2d(64*2, momentum=0.1, eps=0.8),  
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(64*2, 64*4, 4, 3, 2, bias=False),  
            nn.BatchNorm2d(64*4, momentum=0.1, eps=0.8),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(64*4, 64*8, 4, 3, 2, bias=False),  
            nn.BatchNorm2d(64*8, momentum=0.1, eps=0.8),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Flatten(),                               
            nn.Dropout(0.4),                            
            nn.Linear(4608, 1),                         
            nn.Sigmoid()                                
        )

    def forward(self, inputs):
        img, label = inputs
        
        label_output = self.label_condition_disc(label)
        label_output = label_output.view(-1, 3, 128, 128)
        
        concat = torch.cat((img, label_output), dim=1)
        
        output = self.model(concat)
        return output

a = torch.ones(2,3,128,128)
b = torch.ones(2,1)
b = b.long()
a = a.to(device)
b = b.to(device)

c = discriminator((a,b))

 三、训练模型及可视化

 这一部分主要定义初始化权重,构建鉴别器和生成器。

# 定义损失函数
adversarial_loss = nn.BCELoss() 
 
def generator_loss(fake_output, label):
    gen_loss = adversarial_loss(fake_output, label)
    return gen_loss
 
def discriminator_loss(output, label):
    disc_loss = adversarial_loss(output, label)
    return disc_loss
learning_rate = 0.0002
 
G_optimizer = optim.Adam(generator.parameters(),     lr = learning_rate, betas=(0.5, 0.999))
D_optimizer = optim.Adam(discriminator.parameters(), lr = learning_rate, betas=(0.5, 0.999))

# 设置训练的总轮数
num_epochs = 100
# 初始化用于存储每轮训练中判别器和生成器损失的列表
D_loss_plot, G_loss_plot = [], []
 
# 循环进行训练
for epoch in range(1, num_epochs + 1):
    
    # 初始化每轮训练中判别器和生成器损失的临时列表
    D_loss_list, G_loss_list = [], []
    
    # 遍历训练数据加载器中的数据
    for index, (real_images, labels) in enumerate(train_loader):
        # 清空判别器的梯度缓存
        D_optimizer.zero_grad()
        # 将真实图像数据和标签转移到GPU(如果可用)
        real_images = real_images.to(device)
        labels      = labels.to(device)
        
        # 将标签的形状从一维向量转换为二维张量(用于后续计算)
        labels = labels.unsqueeze(1).long()
        # 创建真实目标和虚假目标的张量(用于判别器损失函数)
        real_target = Variable(torch.ones(real_images.size(0), 1).to(device))
        fake_target = Variable(torch.zeros(real_images.size(0), 1).to(device))
 
        # 计算判别器对真实图像的损失
        D_real_loss = discriminator_loss(discriminator((real_images, labels)), real_target)
        
        # 从噪声向量中生成假图像(生成器的输入)
        noise_vector = torch.randn(real_images.size(0), latent_dim, device=device)
        noise_vector = noise_vector.to(device)
        generated_image = generator((noise_vector, labels))
        
        # 计算判别器对假图像的损失(注意detach()函数用于分离生成器梯度计算图)
        output = discriminator((generated_image.detach(), labels))
        D_fake_loss = discriminator_loss(output, fake_target)
 
        # 计算判别器总体损失(真实图像损失和假图像损失的平均值)
        D_total_loss = (D_real_loss + D_fake_loss) / 2
        D_loss_list.append(D_total_loss)
 
        # 反向传播更新判别器的参数
        D_total_loss.backward()
        D_optimizer.step()
 
        # 清空生成器的梯度缓存
        G_optimizer.zero_grad()
        # 计算生成器的损失
        G_loss = generator_loss(discriminator((generated_image, labels)), real_target)
        G_loss_list.append(G_loss)
 
        # 反向传播更新生成器的参数
        G_loss.backward()
        G_optimizer.step()
 
    # 打印当前轮次的判别器和生成器的平均损失
    print('Epoch: [%d/%d]: D_loss: %.3f, G_loss: %.3f' % (
            (epoch), num_epochs, torch.mean(torch.FloatTensor(D_loss_list)), 
            torch.mean(torch.FloatTensor(G_loss_list))))
    
    # 将当前轮次的判别器和生成器的平均损失保存到列表中
    D_loss_plot.append(torch.mean(torch.FloatTensor(D_loss_list)))
    G_loss_plot.append(torch.mean(torch.FloatTensor(G_loss_list)))
 
    if epoch%10 == 0:
        # 将生成的假图像保存为图片文件
        save_image(generated_image.data[:50], './sample_%d' % epoch + '.png', nrow=5, normalize=True)
        # 将当前轮次的生成器和判别器的权重保存到文件
        torch.save(generator.state_dict(), './generator_epoch_%d.pth' % (epoch))
        torch.save(discriminator.state_dict(), './discriminator_epoch_%d.pth' % (epoch))

 
 
from numpy.random import randint, randn
from numpy import linspace
from matplotlib import pyplot as plt, gridspec
import numpy as np
 
# Assuming 'generator' and 'device' are defined earlier in your code
 
generator.load_state_dict(torch.load('./generator_epoch_100.pth'), strict=False)
generator.eval()
 
interpolated = randn(100)
interpolated = torch.tensor(interpolated).to(device).type(torch.float32)
 
label = 0
labels = torch.ones(1) * label
labels = labels.to(device).unsqueeze(1).long()
 
predictions = generator((interpolated, labels))
predictions = predictions.permute(0, 2, 3, 1).detach().cpu()
 
import warnings
warnings.filterwarnings("ignore")
plt.rcParams['font.sans-serif'] = ['SimHei']
plt.rcParams['axes.unicode_minus'] = False
plt.rcParams['figure.dpi'] = 100
 
plt.figure(figsize=(8, 3))
 
pred = (predictions[0, :, :, :] + 1) * 127.5
pred = np.array(pred)
plt.imshow(pred.astype(np.uint8))
plt.show()

 代码中的操作将预测结果的值加1(这样所有的值都变为非负数),然后乘以127.5,最终得到的值就在0到255之间。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mfbz.cn/a/638163.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

影视解说5.0版零基础视频课程

课程简介 现在还能做解说吗、不会写解说文案怎么解决、不会配音怎么解决、如何找到合适的素材资源、如何变现…这是很多想做解说的伙伴最关心的几大问题。比如文案,我们推荐一个网站,10分钟搞定一篇文案,配音可以真人配音也可以软件配音。5.…

代码随想录算法训练营第三天| 203.移除链表元素、 707.设计链表、 206.反转链表

203.移除链表元素 题目链接: 203.移除链表元素 文档讲解:代码随想录 状态:没做出来,做题的时候定义了一个cur指针跳过了目标val遍历了一遍链表,实际上并没有删除该删的节点。 错误代码: public ListNode re…

Leecode热题100---45:跳跃游戏②

题目: 给定一个长度为 n 的 0 索引整数数组 nums。初始位置为 nums[0]。 每个元素 nums[i] 表示从索引 i 向前跳转的最大长度。 返回到达 nums[n - 1] 的最小跳跃次数。 思路: 如果某一个作为 起跳点 的格子可以跳跃的距离是 3,那么表示后面…

127.数据异构方案

文章目录 前言一、数据异构的常用方法1. 完整克隆2. MQ方式3. binlog方式 二、MQ与Binlog方案实现MQ方式binlog方式注意点 三、总结 前言 何谓数据异构:把数据按需(数据结构、存取方式、存取形式)异地构建存储。比如我们将DB里面的数据持久化…

【源码分享】简单的404 HTML页面示例,该页面在加载时会等待2秒钟,然后自动重定向到首页

展示效果 源码 html <!DOCTYPE html> <html lang"zh"> <head><meta charset"UTF-8"><title>404 页面未找到</title><meta http-equiv"refresh" content"2;url/"> <!-- 设置2秒后跳转到首…

适合小白入门的AI扩图(创成式填充)工具

近期&#xff0c;发现许多人对AI扩图工具的需求比较大&#xff0c;为了满足大家的需求&#xff0c;本期天祺为大家整理了一些好用的AI扩图工具&#xff0c;各个设配的扩图工具都有介绍哦&#xff0c;电脑&#xff0c;手机端都能用&#xff0c;大家可以根据自己的喜好和需求进行…

1075: 求最小生成树(Prim算法)

解法&#xff1a; 总结起来&#xff0c;Prim算法的核心思想是从一个顶点开始&#xff0c;一步一步地选择与当前最小生成树相邻的且权值最小的边&#xff0c;直到覆盖所有的顶点&#xff0c;形成一个最小生成树。 #include<iostream> #include<vector> using names…

Kubernetes 应用滚动更新

Kubernetes 应用版本号 在 Kubernetes 里&#xff0c;版本更新使用的不是 API 对象&#xff0c;而是两个命令&#xff1a;kubectl apply 和 kubectl rollout&#xff0c;当然它们也要搭配部署应用所需要的 Deployment、DaemonSet 等 YAML 文件。 在 Kubernetes 里应用都是以 …

力扣HOT100 - 169. 多数元素

解题思路&#xff1a; 有点类似于Boyer-Moore 投票算法&#xff0c;但更加形象。 class Solution {public int majorityElement(int[] nums) {int winner nums[0];int cnt 1;for (int i 1; i < nums.length; i) {if (winner nums[i]){cnt;} else if (cn…

Redis每月运维

为防止redis自动aof缩放失败 每月手动执行一次重写命令 bgrewriteaof 方式一&#xff1a; redis-cli 连接到每个服务器 认证后执行bgrewriteaof 示例 方式二&#xff1a; 通过工具连接到redis 执行命令 方式三: 定时任务系统 在定时任务系统里每天自动执行gocron - 定时任务…

基于transformers框架实践Bert系列5-阅读理解(文本摘要)

本系列用于Bert模型实践实际场景&#xff0c;分别包括分类器、命名实体识别、选择题、文本摘要等等。&#xff08;关于Bert的结构和详细这里就不做讲解&#xff0c;但了解Bert的基本结构是做实践的基础&#xff0c;因此看本系列之前&#xff0c;最好了解一下transformers和Bert…

基于SpringBoot和Hutool工具包实现的验证码案例

目录 验证码案例 1. 需求 2. 准备工作 3. 约定前后端交互接口 需求分析 接口定义 4. Hutool 工具介绍 5. 实现验证码 后端代码 前端代码 6. 运行测试 验证码案例 随着安全性的要求越来越高&#xff0c;目前项目中很多都会使用验证码&#xff0c;只要涉及到登录&…

一个用Java编写的屏幕测距工具,包括游戏地图测量功能

该程序提供了一个简单便捷的方式&#xff0c;在屏幕上测量距离&#xff0c;包括游戏地图分析在内。它允许用户准确确定屏幕上两点之间的距离&#xff0c;帮助游戏过程中的战略规划、资源管理和决策制定。 特点&#xff1a; 简单易用的界面&#xff1a;直观的控制使测量距离变得…

Marin说PCB之POC电路layout设计仿真案例---03

今天天中午午休的时候&#xff0c;我刚要打开手机的准备刷抖音看无忧传媒的学生们的“学习资料”的时候&#xff0c;看到CSDN -APP上有提醒&#xff0c;一看原来是一位道友发的一个问题&#xff1a; 本来小编最近由于刚刚从国外回来&#xff0c;手上的项目都已经结束了&#xf…

MQTT到串口的转发(node.js)

本文针对以下应用场景&#xff1a;已有通过串口通信的设备或软件&#xff0c;想要实现跨网的远程控制。 node.js安装 从 Node.js — Run JavaScript Everywhere下载LTS版本安装包&#xff0c;运行安装程序。&#xff08;傻瓜安装&#xff0c;按提示点击即可&#xff09; 设置环…

忍の摸头之术游戏娱乐源码

本资源提供给大家学习及参考研究借鉴美工之用&#xff0c;请勿用于商业和非法用途&#xff0c;无任何技术支持&#xff01; 忍の摸头之术游戏娱乐源码&#xff0c;抖音上面非常火的摸头杀画面,看得我眼花缭乱,源码拿去玩吧&#xff1b; 目录说明 忍の摸头之术&#xff1a;域…

idea新建项目/模块找不到Spring Initializr

idea创建项目找不到spring intellij&#xff0c;如下图解决 可能是没有下载spring的相应插件&#xff0c;或者没有启用对应的插件 我这里就是没有启用插件&#xff0c;导致的创建项目时找不到按件。 全部启用后&#xff0c;重启idea即可。 重启后可以看到出现了“Spring Initi…

【Andoird开发】android获取蓝牙权限,beacon,android-beacon-library

iBeacon 最先是苹果的技术&#xff0c;使用android-beacon-library包可以在android上开发iBeacon 技术。 iBeacon的发明意义重大。它是一种基于蓝牙低功耗&#xff08;Bluetooth Low Energy, BLE&#xff09;技术的定位系统&#xff0c;通过向周围发送信号来标识其位置。这项技…

Docker 容器间通讯

1、虚拟ip/访问 同一网络 安装docker时&#xff0c;docker会默认创建一个内部的桥接网络docker0&#xff0c;每创建一个容器分配一个虚拟网卡&#xff0c;容器之间(包括宿主机)可以根据分配的ip互相访问(ps:其他主机(包括其他主机的容器)无法ping通docker容器ip无法访问&#…

22个C语言小白常见问题总结

一.语言使用错误 在打代码的过程中&#xff0c;经常需要在中文与英文中进行转换&#xff0c;因此常出现一些符号一不小心就用错&#xff0c;用成中文。例如&#xff1a;“&#xff1b;”中文中的分号占用了两个字节&#xff0c;而英文中“;”分号只占用一个字节。编译器只能识…