DP-GAN-生成器代码

在train文件中,对生成器和判别器分别进行更新,根据loss的不同,分别计算对于的损失:

loss_G, losses_G_list = model(image, label, "losses_G", losses_computer)
loss_D, losses_D_list = model(image, label, "losses_D", losses_computer)

在model中:

from models.sync_batchnorm import DataParallelWithCallback
import models.generator as generators
import models.discriminator as discriminators
import os
import copy
import torch
import torch.nn as nn
from torch.nn import init
import models.losses as losses
class DP_GAN_model(nn.Module):
    def __init__(self, opt):
        super(DP_GAN_model, self).__init__()
        self.opt = opt
        #--- generator and discriminator ---
        self.netG = generators.DP_GAN_Generator(opt).cuda()
        if opt.phase == "train" or opt.phase == "eval":
            self.netD = discriminators.DP_GAN_Discriminator(opt)
        self.print_parameter_count()
        self.init_networks()
        #--- EMA of generator weights ---
        with torch.no_grad():
            self.netEMA = copy.deepcopy(self.netG) if not opt.no_EMA else None
        #--- load previous checkpoints if needed ---
        self.load_checkpoints()
        #--- perceptual loss ---#
        if opt.phase == "train":
            if opt.add_vgg_loss:
                self.VGG_loss = losses.VGGLoss(self.opt.gpu_ids)
        self.GAN_loss = losses.GANLoss()
        self.MSELoss = nn.MSELoss(reduction='mean')

    def align_loss(self, feats, feats_ref):
        loss_align = 0
        for f, fr in zip(feats, feats_ref):
            loss_align += self.MSELoss(f, fr)
        return loss_align

    def forward(self, image, label, mode, losses_computer):
        # Branching is applied to be compatible with DataParallel
        if mode == "losses_G":
            loss_G = 0
            fake = self.netG(label)
            output_D, scores, feats = self.netD(fake)
            _, _, feats_ref = self.netD(image)
            loss_G_adv = losses_computer.loss(output_D, label, for_real=True)
            loss_G += loss_G_adv
            loss_ms = self.GAN_loss(scores, True, for_discriminator=False)
            loss_G += loss_ms.item()
            loss_align = self.align_loss(feats, feats_ref)
            loss_G += loss_align
            if self.opt.add_vgg_loss:
                loss_G_vgg = self.opt.lambda_vgg * self.VGG_loss(fake, image)
                loss_G += loss_G_vgg
            else:
                loss_G_vgg = None
            return loss_G, [loss_G_adv, loss_G_vgg]

        if mode == "losses_D":
            loss_D = 0
            with torch.no_grad():
                fake = self.netG(label)
            output_D_fake, scores_fake, _ = self.netD(fake)
            loss_D_fake = losses_computer.loss(output_D_fake, label, for_real=False)
            loss_ms_fake = self.GAN_loss(scores_fake, False, for_discriminator=True)
            loss_D += loss_D_fake + loss_ms_fake.item()
            output_D_real, scores_real, _ = self.netD(image)
            loss_D_real = losses_computer.loss(output_D_real, label, for_real=True)
            loss_ms_real = self.GAN_loss(scores_real, True, for_discriminator=True)
            loss_D += loss_D_real + loss_ms_real.item()
            if not self.opt.no_labelmix:
                mixed_inp, mask = generate_labelmix(label, fake, image)
                output_D_mixed, _, _ = self.netD(mixed_inp)
                loss_D_lm = self.opt.lambda_labelmix * losses_computer.loss_labelmix(mask, output_D_mixed, output_D_fake,
                                                                                output_D_real)
                loss_D += loss_D_lm
            else:
                loss_D_lm = None
            return loss_D, [loss_D_fake, loss_D_real, loss_D_lm]

        if mode == "generate":
            with torch.no_grad():
                if self.opt.no_EMA:
                    fake = self.netG(label)
                else:
                    fake = self.netEMA(label)
            return fake

        if mode == "eval":
            with torch.no_grad():
                pred, _, _ = self.netD(image)
            return pred

    def load_checkpoints(self):
        if self.opt.phase == "test":
            which_iter = self.opt.ckpt_iter
            path = os.path.join(self.opt.checkpoints_dir, self.opt.name, "models", str(which_iter) + "_")
            if self.opt.no_EMA:
                self.netG.load_state_dict(torch.load(path + "G.pth"))
            else:
                self.netEMA.load_state_dict(torch.load(path + "EMA.pth"))
        elif self.opt.phase == "eval":
            which_iter = self.opt.ckpt_iter
            path = os.path.join(self.opt.checkpoints_dir, self.opt.name, "models", str(which_iter) + "_")
            self.netD.load_state_dict(torch.load(path + "D.pth"))
        elif self.opt.continue_train:
            which_iter = self.opt.which_iter
            path = os.path.join(self.opt.checkpoints_dir, self.opt.name, "models", str(which_iter) + "_")
            self.netG.load_state_dict(torch.load(path + "G.pth"))
            self.netD.load_state_dict(torch.load(path + "D.pth"))
            if not self.opt.no_EMA:
                self.netEMA.load_state_dict(torch.load(path + "EMA.pth"))

    def print_parameter_count(self):
        if self.opt.phase == "train":
            networks = [self.netG, self.netD]
        else:
            networks = [self.netG]
        for network in networks:
            param_count = 0
            for name, module in network.named_modules():
                if (isinstance(module, nn.Conv2d)
                        or isinstance(module, nn.Linear)
                        or isinstance(module, nn.Embedding)):
                    param_count += sum([p.data.nelement() for p in module.parameters()])
            print('Created', network.__class__.__name__, "with %d parameters" % param_count)

    def init_networks(self):
        def init_weights(m, gain=0.02):
            classname = m.__class__.__name__
            if classname.find('BatchNorm2d') != -1:
                if hasattr(m, 'weight') and m.weight is not None:
                    init.normal_(m.weight.data, 1.0, gain)
                if hasattr(m, 'bias') and m.bias is not None:
                    init.constant_(m.bias.data, 0.0)
            elif hasattr(m, 'weight') and (classname.find('Conv') != -1 or classname.find('Linear') != -1):
                init.xavier_normal_(m.weight.data, gain=gain)
                if hasattr(m, 'bias') and m.bias is not None:
                    init.constant_(m.bias.data, 0.0)

        if self.opt.phase == "train":
            networks = [self.netG, self.netD]
        else:
            networks = [self.netG]
        for net in networks:
            net.apply(init_weights)


def put_on_multi_gpus(model, opt):
    if opt.gpu_ids != "-1":
        gpus = list(map(int, opt.gpu_ids.split(",")))
        model = DataParallelWithCallback(model, device_ids=gpus).cuda()
    else:
        model.module = model
    assert len(opt.gpu_ids.split(",")) == 0 or opt.batch_size % len(opt.gpu_ids.split(",")) == 0
    return model


def preprocess_input(opt, data):
    data['label'] = data['label'].long()
    if opt.gpu_ids != "-1":
        data['label'] = data['label'].cuda()
        data['image'] = data['image'].cuda()
    label_map = data['label']
    bs, _, h, w = label_map.size()
    nc = opt.semantic_nc
    if opt.gpu_ids != "-1":
        input_label = torch.cuda.FloatTensor(bs, nc, h, w).zero_()
    else:
        input_label = torch.FloatTensor(bs, nc, h, w).zero_()
    input_semantics = input_label.scatter_(1, label_map, 1.0)
    return data['image'], input_semantics


def generate_labelmix(label, fake_image, real_image):
    target_map = torch.argmax(label, dim = 1, keepdim = True)
    all_classes = torch.unique(target_map)
    for c in all_classes:
        target_map[target_map == c] = torch.randint(0,2,(1,)).cuda()
    target_map = target_map.float()
    mixed_image = target_map*real_image+(1-target_map)*fake_image
    return mixed_image, target_map

首先看生成器流程:
标签输入到生成器中得到fake image,fake image 和 real image 共同输入到判别器中得到中间变量输出,接着分别计算四个损失。我们需要明白生成器和辨别器模型的搭建,损失计算过程。
在这里插入图片描述
首先是生成器的组成:
在这里插入图片描述
在这里插入图片描述
输入标签大小是(b,c,h,w),首先z等于一个正态分布的随机数,大小为(b,64),接着view为(b,64,1,1),再扩张到(b,64,h,w)和(b,c,h,w)沿着通道维度拼接起来。将拼接的结果上采样到W和H大小。
在这里插入图片描述
其中在CityscapesDataset指定了:
在这里插入图片描述
则w=512//2^5=16,h=16/2=8.
在这里插入图片描述
令s等于input label,输入到pyrmid中,生成结果添加到列表中。

self.seg_pyrmid = nn.ModuleList([])
        if not self.opt.no_3dnoise:
            self.fc = nn.Conv2d(self.opt.semantic_nc + self.opt.z_dim, 16 * ch, 3, padding=1)
            self.seg_pyrmid.append(nn.Sequential(nn.Conv2d(self.opt.semantic_nc + self.opt.z_dim, 32, 3, stride=1, padding=1), nn.BatchNorm2d(32), nn.ReLU(inplace=True)))
        else:
            self.fc = nn.Conv2d(self.opt.semantic_nc, 16 * ch, 3, padding=1)
            self.seg_pyrmid.append(nn.Sequential(nn.Conv2d(self.opt.semantic_nc, 32, 3, stride=1, padding=1), nn.BatchNorm2d(64), nn.ReLU(inplace=True)))

        self.seg_pyrmid.append(nn.Sequential(nn.Conv2d(32, 64, 3, stride=1, padding=1), nn.BatchNorm2d(64), nn.ReLU(inplace=True)))
        for i in range(len(self.channels)-2):
            self.seg_pyrmid.append(nn.Sequential(nn.Conv2d(64, 64, 3, stride=2, padding=1), nn.BatchNorm2d(64), nn.ReLU(inplace=True)))         

而pyrmid是一个modulist,便利添加的每一个module,生成一个结果:
首先将标签图和噪声拼接起来经过一个3x3卷积,输出通道变为32,再经过一个1x1卷积,输出通道变为64.再经过经过5个步长为2的3x3卷积,下采样32倍。这样pyrmid列表中就有7个结果。
接着将已经采样的x输入到Fc中,输出通道是1024.这里需要清楚两个变量x,和pyrmid.
1:x是输入下采样到(H,W)大小的label+noise.
2:pyrmid是储存经过七次(五次下采样)卷积之后的label+noise。
接着将pyrmid最后一个值采样到x的大小。然后和pyrmid的第i个值拼接在一起。
在这里插入图片描述
对应于:
在这里插入图片描述
每拼接一次生成的值和经过Fc之后的label+noise共同作为输入:
在这里插入图片描述
输入到SPADE块中:
首先要判断SPAD的两个参数即输入通道是否相等。
在这里插入图片描述
在这里插入图片描述
如果相等就输入到SPADE模块,如果不等令变量等于输入值。
在这里插入图片描述
其中最后一个参数是类别值:在Cityscape数据集设定语义标签是34类。有一类是未知,加上噪声的64个通道。
在这里插入图片描述
SPADE:

class SPADE(nn.Module):
    def __init__(self, opt, norm_nc, label_nc):
        super().__init__()
        self.first_norm = get_norm_layer(opt, norm_nc)
        ks = opt.spade_ks
        nhidden = 128
        pw = ks // 2
        #self.mlp_shared = nn.Sequential(
        #    nn.Conv2d(label_nc, nhidden, kernel_size=ks, padding=pw),
        #    nn.ReLU()
        #)
        self.mlp_gamma = nn.Conv2d(nhidden, norm_nc, kernel_size=ks, padding=pw)
        self.mlp_beta = nn.Conv2d(nhidden, norm_nc, kernel_size=ks, padding=pw)

    def forward(self, x, segmap):
        normalized = self.first_norm(x)
        #segmap = F.interpolate(segmap, size=x.size()[2:], mode='nearest')
        #actv = self.mlp_shared(segmap)
        actv = segmap
        gamma = self.mlp_gamma(actv)
        beta = self.mlp_beta(actv)
        out = normalized * (1 + gamma) + beta
        return out

公式:
在这里插入图片描述
首先X经过一个norm层,即为分布式BN。
在这里插入图片描述
在这里插入图片描述
接着使用卷积学习β和γ。
在这里插入图片描述
在这里插入图片描述
卷积核大小都为3,padding为1。
接着经过bn之后的变量和γ相乘在和β相加,再和经过归一化之后的x相加。
在这里插入图片描述
接着:x和seg经过相同的norm操作。再进过一个LeakyReLU,再进行一个卷积层。中间有个midlayer过渡。
在这里插入图片描述
在这里插入图片描述
输出的结果经过一个跳连接得到最后输出。
在这里插入图片描述
经过SPADE之后的输出上采样两倍作为输入输入到下一个SPADE中。
最终输出一个通道为3的RGB图片。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/53490.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

给初学嵌入式的菜鸟一点建议.学习嵌入式linux

学习嵌入式,我认为两个重点,cpu和操作系统,目前市场是比较流行arm,所以推荐大家学习arm。操作系统很多,我个人对开始学习的人,特别不是计算机专业的,推荐学习ucos。那是开源的,同时很…

ALLEGRO之Place

本文主要讲述了ALLEGRO的Place菜单。 (1)Manually:手动放置,常用元器件放置方法; (2)Quickplace:快速放置; (3)Autoplace:自动放置&a…

Linux6.16 Docker consul的容器服务更新与发现

文章目录 计算机系统5G云计算第四章 LINUX Docker consul的容器服务更新与发现一、consul 概述1.什么是服务注册与发现2.什么是consul 二、consul 部署1.consul服务器2.registrator服务器3.consul-template4.consul 多节点 计算机系统 5G云计算 第四章 LINUX Docker consul的…

Linux虚拟机安装tomcat(图文详解)

目录 第一章、xshell工具和xftp的使用1.1)xshell下载与安装1.2)xshell连接1.3)xftp下载安装和连接 第二章、安装tomcat1.1)关闭防火墙,传输tomcat压缩包到Linux虚拟机12)启动tomcat 第一章、xshell工具和xf…

Git 版本管理使用-介绍-示例

文章目录 Git是一种版本控制工具,它可以帮助程序员组织和管理代码的变更历史Git的使用方式:常见命令安装Git软件第一次上传分支删除分支 Git是一种版本控制工具,它可以帮助程序员组织和管理代码的变更历史 以下是Git的基本概念和使用方式&am…

【Git系列】分支操作

🐳分支操作 🧊1. 什么是分支🧊2. 分支的好处🧊3. 分支操作🪟3.1 查看分支🪟3.2 创建分支🪟3.3 切换分支 🧊4. 分支冲突🪟4.1 环境准备🪟4.2 分支冲突演示 &am…

01 Excel常用高频快捷键汇总

目录 一、简介二、快捷键介绍2.1 常用基本快捷键1 复制:CtrlC2 粘贴:CtrlV3 剪切:CtrlX4 撤销:CtrlZ5 全选:CtrlA 2.2 常用高级快捷键1 单元格内强制换行:AltEnter2 批量输入相同的内容:CtrlEnt…

机器学习-Basic Concept

机器学习(Basic Concept) videopptblog Where does the error come from? 在前面我们讨论误差的时候,我们提到了Average Error On Testing Data是最重要的 A more complex model does not lead to better performance on test data Bias And Variance Bias(偏差) …

排序算法(冒泡排序、选择排序、插入排序、希尔排序、堆排序、快速排序、归并排序、计数排序)

🍕博客主页:️自信不孤单 🍬文章专栏:数据结构与算法 🍚代码仓库:破浪晓梦 🍭欢迎关注:欢迎大家点赞收藏关注 文章目录 🍓冒泡排序概念算法步骤动图演示代码 &#x1f34…

数学建模学习(7):Matlab绘图

一、二维图像绘制 1.绘制曲线图 最基础的二维图形绘制方法:plot -plot命令自动打开一个图形窗口Figure; 用直线连接相邻两数据点来绘制图形 -根据图形坐标大小自动缩扩坐标轴,将数据标尺及单位标注自动加到两个坐标轴上,可自定…

【Linux】sed修改文件指定内容

sed修改文件指定内容: 参考:(5条消息) Linux系列讲解 —— 【cat echo sed】操作读写文件内容_shell命令修改文件内容_星际工程师的博客-CSDN博客

理解构建LLM驱动的聊天机器人时的向量数据库检索的局限性 - (第1/3部分)

本博客是一系列文章中的第一篇,解释了为什么使用大型语言模型(LLM)部署专用领域聊天机器人的主流管道成本太高且效率低下。在第一篇文章中,我们将讨论为什么矢量数据库尽管最近流行起来,但在实际生产管道中部署时从根本…

【编译】gcc make cmake Makefile CMakeList.txt 区别

文章目录 一 关系二 gcc2.1 编译过程2.2 编译参数2.3 静态库和动态库1 后缀名2 联系与区别 2.4 GDB 调试器1 常用命令 三 make、makefile四 cmake、cmakelist4.1 语法特性4.2 重要命令4.2 重要变量4.3 编译流程4.4 两种构建方式 五 Vscode5.0 常用快捷键5.1 界面5.2 插件5.3 .v…

点播播放器如何自定义额外信息(统计信息传值)

Web播放器支持设置观众信息参数&#xff0c;设置后在播放器上报的观看日志中会附带观众信息&#xff0c;这样用户就可以通过管理后台的统计页面或服务端API来查看特定观众的视频观看情况了。 播放器设置观众信息参数的代码示例如下&#xff1a; <div id"player"…

加利福尼亚大学|3D-LLM:将3D世界于大规模语言模型结合

来自加利福尼亚大学的3D-LLM项目团队提到&#xff1a;大型语言模型 (LLM) 和视觉语言模型 (VLM) 已被证明在多项任务上表现出色&#xff0c;例如常识推理。尽管这些模型非常强大&#xff0c;但它们并不以 3D 物理世界为基础&#xff0c;而 3D 物理世界涉及更丰富的概念&#xf…

【100天精通python】Day20:文件及目录操作_os模块和os.psth模块,文件权限修改

目录 专栏导读 1 文件的目录操作 os模块的一些操作目录函数​编辑 os.path 模块的操作目录函数 2 相对路径和绝对路径 3 路径拼接 4 判断目录是否存在 5 创建目录、删除目录、遍历目录 专栏导读 专栏订阅地址&#xff1a;https://blog.csdn.net/qq_35831906/category_12…

Java中的代理模式

Java中的代理模式 1. 静态代理JDK动态代理CGLib动态代理 1. 静态代理 接口 public interface ICeo {void meeting(String name) throws InterruptedException; }目标类 public class Ceo implements ICeo{public void meeting(String name) throws InterruptedException {Th…

【信号去噪和正交采样】流水线过程的一部分,用于对L波段次级雷达中接收的信号进行降噪(Matlab代码实现)

&#x1f4a5;&#x1f4a5;&#x1f49e;&#x1f49e;欢迎来到本博客❤️❤️&#x1f4a5;&#x1f4a5; &#x1f3c6;博主优势&#xff1a;&#x1f31e;&#x1f31e;&#x1f31e;博客内容尽量做到思维缜密&#xff0c;逻辑清晰&#xff0c;为了方便读者。 ⛳️座右铭&a…

计算机网络——应用层

文章目录 **1 网络应用模型****2 域名系统DNS****3 文件传输协议FTP****4 电子邮件****4.1 电子邮件系统的组成结构****4.2 电子邮件格式与MIME****4.3 SMTP和POP3** **5 万维网WWW****5.1 HTTP** 1 网络应用模型 客户/服务器模型 C/S 服务器服务于许多来自其他称为客户机的主…