使用pytorch构建GAN模型的评估

本文为此系列的第六篇对GAN的评估,上一篇为Controllable GAN。文中使用训练好的分类模型的部分网络提取特征将真实分布与生成分布进行对比来评估模型的好坏,若有不懂的无监督知识点可以看本系列第一篇。

原理

  1. 一般来说,我们评估模型的好坏可以通过对测试集的错误率来体现:比如图像分类我们可以统计几张分错几张分对来量化错误率、目标检测我们可以通过比对每个框得到mAP从而量化错误率…但是我们怎么通过生成的图像来评估GAN的好坏呢?
    在这里插入图片描述
    我们总不能说,生成的某一个像素要更绿色一点比较好,或者某个像素要更黄色一点比较好吧?
    先进行概括一下,全文主要围绕着生成质量(保真度fidelity)、多样性(diversity)进行讲解。
    在这里插入图片描述
  2. 图像对比有两种方法,pixel distance、feature distance。
    第一种像素对比,直接做相减运算。这样做的缺点是尽管两张图片可能非常相似,但是每个像素的像素值会有一些细微的差异,即使我们肉眼看不出来,最终的差值也会非常大,太过于关注细节。
    在这里插入图片描述
    第二种则是特征对比,通俗的说是成片的像素区域进行对比是否相似,这样的对比更符合我们人眼观察标准。
    在这里插入图片描述
    那么,接下来的问题就是如何进行特征提取。
  3. 特征提取的方法
    我们训练好的分类器是一个很好的特征提取器,比如我们训练了一个识别猫狗的分类器,那它必然是学习到了猫狗的特征才会对他们进行分类。
    在这里插入图片描述
    直接将分类部分的最后一层分类层去掉,其余的都是对我们有价值的。我们一般选择的是连接最后一个全连接层的池化层作为输出特征的层,我们成为特征层,输出的特征我们称为embedding。
    选择这个位置并不固定,只是选择的位置越后面,每个单元的感受野越大,所包含的信息就越多,更符合我们的要求。很前面的层获取到的特征可能只是一横或者一竖或者一个弧度等。
  • 我们使用Inception v3作为我们的特征提取器,Inception使用超1400万张图片、2万多类别的ImageNet数据库作为训练集。提取详细流程如图:
    在这里插入图片描述
    对总的概括可以概括为一下流程:
    在这里插入图片描述
    最终我们就是对真实数据提取的特征于生成数据提取的特征进行对比。
  1. Frechet Inception Distance(FID)
    我们使用FID来量化真假特征的差异。
    通俗来说Frechet Distance是用来衡量两条曲线之间的的最小距离,比如人狗同时走所需的最短牵引绳的长度。
    在这里插入图片描述
    严格来说,Frechet Distance是衡量两个分布之间的差异。
    在这里插入图片描述
    ①我们可以使用以下公式来表示两个单维正态分布的Frechet Distance:
    在这里插入图片描述
    分别从真实数据和生成数据里面提取大量的特征,分别作为真实特征分布于生成特征分布,计算出各自的均值和标准差即可计算出真假之间的差值。
    ②两个多变量正态分布的Frechet Distance
    我们可以为每个维度提供一个单变量的正态分布,假设是两个变量的(便于举例),如图:
    在这里插入图片描述

协方差矩阵:
比如(x1,x2)代表第一变量的正态分布的随机变量与第二正态分布的随机变量之间的协方差。非对角线元素代表不同变量之间的协方差,即不同变量之间的相关性。若两个变量变化趋势一致则协方差为正值,反之负值,若没有线性关系则为0。上图就代表两个变量之间相互不影响相互独立,下图代表两变量之间负相关;
比如(x1,x1)代表第一变量的正态分布的方差。对角线元素代表每个变量分布的方差,即每个变量本身的变化程度。
在这里插入图片描述
由此可以计算我们的多变量正态分布之间的Frechet Distance,可以将单维正态分布之间的Frechet Distance公式展开进行对比发现他们之间其实是相似的:
在这里插入图片描述
Tr运算为矩阵的对角线元素之和,例如上面那个负相关的协方差矩阵的Tr运算结果为2+2=4。
将多变量正态分布之间的Frechet Distance应用于真假特征的分布就是FID了:
在这里插入图片描述
FID越小,就代表着真假分布就越接近,那么GAN就越好。

代码

import torch
import numpy as np
from torch import nn
from tqdm.auto import tqdm
from torchvision import transforms
from torchvision.datasets import CelebA
from torchvision.utils import make_grid
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
torch.manual_seed(0) # Set for our testing purposes, please do not change!

class Generator(nn.Module):
    def __init__(self, z_dim=10, im_chan=3, hidden_dim=64):
        super(Generator, self).__init__()
        self.z_dim = z_dim
        # Build the neural network
        self.gen = nn.Sequential(
            self.make_gen_block(z_dim, hidden_dim * 8),
            self.make_gen_block(hidden_dim * 8, hidden_dim * 4),
            self.make_gen_block(hidden_dim * 4, hidden_dim * 2),
            self.make_gen_block(hidden_dim * 2, hidden_dim),
            self.make_gen_block(hidden_dim, im_chan, kernel_size=4, final_layer=True),
        )

    def make_gen_block(self, input_channels, output_channels, kernel_size=3, stride=2, final_layer=False):
        if not final_layer:
            return nn.Sequential(
                nn.ConvTranspose2d(input_channels, output_channels, kernel_size, stride),
                nn.BatchNorm2d(output_channels),
                nn.ReLU(inplace=True),
            )
        else:
            return nn.Sequential(
                nn.ConvTranspose2d(input_channels, output_channels, kernel_size, stride),
                nn.Tanh(),
            )

    def forward(self, noise):
        x = noise.view(len(noise), self.z_dim, 1, 1)
        return self.gen(x)

def get_noise(n_samples, z_dim, device='cpu'):
    return torch.randn(n_samples, z_dim, device=device)

z_dim = 64
image_size = 299
device = 'cuda'

transform = transforms.Compose([
    transforms.Resize(image_size),
    transforms.CenterCrop(image_size),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])

dataset = CelebA(".", download=True, transform=transform)

gen = Generator(z_dim).to(device)
gen.load_state_dict(torch.load(f"pretrained_celeba.pth", map_location=torch.device(device))["gen"])
gen = gen.eval()

from torchvision.models import inception_v3
inception_model = inception_v3(pretrained=False)
inception_model.load_state_dict(torch.load("inception_v3_google-1a9a5a14.pth"))
inception_model.to(device)
inception_model = inception_model.eval() # Evaluation mode

inception_model.fc = torch.nn.Identity()

from torch.distributions import MultivariateNormal
import seaborn as sns # This is for visualization
mean = torch.Tensor([0, 0]) # Center the mean at the origin
covariance = torch.Tensor( # This matrix shows independence - there are only non-zero values on the diagonal
    [[1, 0],
     [0, 1]]
)
independent_dist = MultivariateNormal(mean, covariance)
samples = independent_dist.sample((10000,))
res = sns.jointplot(x=samples[:, 0], y=samples[:, 1], kind="kde")
plt.show()

mean = torch.Tensor([0, 0])
covariance = torch.Tensor(
    [[2, -1],
     [-1, 2]]
)
covariant_dist = MultivariateNormal(mean, covariance)
samples = covariant_dist.sample((10000,))
res = sns.jointplot(x=samples[:, 0], y=samples[:, 1], kind="kde")
plt.show()

import scipy
def matrix_sqrt(x):
    y = x.cpu().detach().numpy()
    y = scipy.linalg.sqrtm(y)
    return torch.Tensor(y.real, device=x.device)
    
def frechet_distance(mu_x, mu_y, sigma_x, sigma_y):
    return (mu_x - mu_y).dot(mu_x - mu_y) + torch.trace(sigma_x) + torch.trace(sigma_y) - 2*torch.trace(matrix_sqrt(sigma_x @ sigma_y))

def preprocess(img):
    img = torch.nn.functional.interpolate(img, size=(299, 299), mode='bilinear', align_corners=False)
    return img

import numpy as np
def get_covariance(features):
    return torch.Tensor(np.cov(features.detach().numpy(), rowvar=False))

fake_features_list = []
real_features_list = []

n_samples = 512 # The total number of samples
batch_size = 4 # Samples per iteration

dataloader = DataLoader(
    dataset,
    batch_size=batch_size,
    shuffle=True)

cur_samples = 0
with torch.no_grad(): # You don't need to calculate gradients here, so you do this to save memory
    try:
        for real_example, _ in tqdm(dataloader, total=n_samples // batch_size): # Go by batch
            real_samples = real_example
            real_features = inception_model(real_samples.to(device)).detach().to('cpu') # Move features to CPU
            real_features_list.append(real_features)

            fake_samples = get_noise(len(real_example), z_dim).to(device)
            fake_samples = preprocess(gen(fake_samples))
            fake_features = inception_model(fake_samples.to(device)).detach().to('cpu')
            fake_features_list.append(fake_features)
            cur_samples += len(real_samples)
            if cur_samples >= n_samples:
                break
    except:
        print("Error in loop")

fake_features_all = torch.cat(fake_features_list)
real_features_all = torch.cat(real_features_list)

mu_fake = fake_features_all.mean(0)
mu_real = real_features_all.mean(0)
sigma_fake = get_covariance(fake_features_all)
sigma_real = get_covariance(real_features_all)

indices = [2, 4, 5]
fake_dist = MultivariateNormal(mu_fake[indices], sigma_fake[indices][:, indices])
fake_samples = fake_dist.sample((5000,))
real_dist = MultivariateNormal(mu_real[indices], sigma_real[indices][:, indices])
real_samples = real_dist.sample((5000,))

import pandas as pd
df_fake = pd.DataFrame(fake_samples.numpy(), columns=indices)
df_real = pd.DataFrame(real_samples.numpy(), columns=indices)
df_fake["is_real"] = "no"
df_real["is_real"] = "yes"
df = pd.concat([df_fake, df_real])
sns.pairplot(df, plot_kws={'alpha': 0.1}, hue='is_real')
plt.show()

with torch.no_grad():
    print(frechet_distance(mu_real, mu_fake, sigma_real, sigma_fake).item())

代码中使用的生成器模型可以从上一篇当中下载,inception_v3_google-1a9a5a14.pth模型可以从这里下载。

代码解析

  • 去掉分类层
inception_model.fc = torch.nn.Identity()

将最后一层的全连接层替换为恒等函数,它将输入的数据不做任何操作、原封不动地输出。
通常Inception模型的全连接层用于图像分类任务,它将提取的特征映射到类别预测上。然而我们不需要进行图像分类,而是想要利用Inception模型的前面部分来提取图像的特征。
这样就将Inception模型从原始的分类任务模型转变为一个特征提取器,从而不再执行图像分类任务,而是将图像转换为特征向量。

  • 可视化多变量正态分布
from torch.distributions import MultivariateNormal
import seaborn as sns # This is for visualization
mean = torch.Tensor([0, 0]) # Center the mean at the origin
covariance = torch.Tensor( # This matrix shows independence - there are only non-zero values on the diagonal
    [[1, 0],
     [0, 1]]
)
independent_dist = MultivariateNormal(mean, covariance)
samples = independent_dist.sample((10000,))
res = sns.jointplot(x=samples[:, 0], y=samples[:, 1], kind="kde")
plt.show()

mean = torch.Tensor([0, 0])
covariance = torch.Tensor(
    [[2, -1],
     [-1, 2]]
)
covariant_dist = MultivariateNormal(mean, covariance)
samples = covariant_dist.sample((10000,))
res = sns.jointplot(x=samples[:, 0], y=samples[:, 1], kind="kde")
plt.show()

首先定义均值和协方差矩阵(原理中举的两个例子),然后使用MultivariateNormal构建一个多变量正态分布对象covariant_dist。然后从这个分布中抽取了10000个样本,每个样本是一个shape为(samples, 2)的二维向量。最后将生成的样本可视化为二维核密度估计图(Kernel Density Estimate,KDE)。
在这里插入图片描述
在这里插入图片描述

  • 计算矩阵的平方根
def matrix_sqrt(x):
    y = x.cpu().detach().numpy()
    y = scipy.linalg.sqrtm(y)
    return torch.Tensor(y.real, device=x.device)

首先将输入矩阵转移到CPU上并将其转换为NumPy数组。这是因为scipy.linalg.sqrtm函数只能接受NumPy数组作为输入,不能接受PyTorch张量,且在CPU上计算更高效。
然后使用scipy.linalg.sqrtm函数计算平方根且返回一个复数矩阵,所以需要取其实部(real)部分,然后再转换为PyTorch张量。同时,函数还会确保新的张量与输入矩阵在相同的设备(device)上。

  • 计算FID
def frechet_distance(mu_x, mu_y, sigma_x, sigma_y):
    return (mu_x - mu_y).dot(mu_x - mu_y) + torch.trace(sigma_x) + torch.trace(sigma_y) - 2*torch.trace(matrix_sqrt(sigma_x @ sigma_y))

给定两个分布的均值和协方差矩阵,利用原理中的公式进行计算。

  • 对生成图像进行处理
def preprocess(img):
    img = torch.nn.functional.interpolate(img, size=(299, 299), mode='bilinear', align_corners=False)
    return img

将输入的图像进行插值操作,插值方法使用双线性插值,参数align_corners=False指示在进行插值操作时不对齐图像的角点,这在图像处理中常用于避免不必要的插值偏差。
在这里插入图片描述

  • 计算协方差矩阵
def get_covariance(features):
    return torch.Tensor(np.cov(features.detach().numpy(), rowvar=False))

使用NumPy的np.cov()函数计算特征向量集合的协方差矩阵,rowvar=False参数表示传递的数据中每一列代表一个特征向量的观测值,而不是每一行代表一个观测样本。

  • 提取特征
for real_example, _ in tqdm(dataloader, total=n_samples // batch_size): # Go by batch
    real_samples = real_example
    real_features = inception_model(real_samples.to(device)).detach().to('cpu') # Move features to CPU
    real_features_list.append(real_features)

    fake_samples = get_noise(len(real_example), z_dim).to(device)
    fake_samples = preprocess(gen(fake_samples))
    fake_features = inception_model(fake_samples.to(device)).detach().to('cpu')
    fake_features_list.append(fake_features)
    cur_samples += len(real_samples)
    if cur_samples >= n_samples:
        break

使用预训练的Inception模型提取真实图像和生成图像的特征,并将这些特征存储在列表中,以备后续计算Fréchet Distance。
在这里需要对生成的图像进行preprocess()处理为299的宽高是因为真实数据的宽高为299,而生成数据的宽高为64。
我们可以将生成数据和preprocess处理后的数据显示出来看效果:

import matplotlib.pyplot as plt

# 选择其中一个样本进行显示
sample_index = 0

# 显示生成图像
fake_image = fake[sample_index].permute(1, 2, 0)  # 将张量形状转换为图像的形状(C, H, W)->(H, W, C)
plt.imshow(fake_image)
plt.axis('off')
plt.show()

# 显示经过处理的图像
fake_image = fake_samples[sample_index].permute(1, 2, 0)  # 将张量形状转换为图像的形状(C, H, W)->(H, W, C)
plt.imshow(fake_image)
plt.axis('off')
plt.show()

在这里插入图片描述
在这里插入图片描述
可以看到插值操作后平滑很多。

  • 可视化真实数据分布与生成数据分布,并计算FID
indices = [2, 4, 5]
import pandas as pd
df_fake = pd.DataFrame(fake_samples.numpy(), columns=indices)
df_real = pd.DataFrame(real_samples.numpy(), columns=indices)
df_fake["is_real"] = "no"
df_real["is_real"] = "yes"
df = pd.concat([df_fake, df_real])
sns.pairplot(df, plot_kws={'alpha': 0.1}, hue='is_real')
plt.show()

with torch.no_grad():
    print(frechet_distance(mu_real, mu_fake, sigma_real, sigma_fake).item())

在这里插入图片描述
在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/553014.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

DataGridView添加行号隔行变色

运行效果 颜色对应关系 类实现代码 using System; using System.Collections.Generic; using System.ComponentModel; using System.Drawing; using System.Linq; using System.Text; using System.Threading.Tasks; using System.Windows.Forms;namespace WindowsFormsApp1 {…

二刷大数据(二)- Spark

目录 SparkHadoop区别核心组件运行架构Master&WorkerApplication (Driver)Executor RDD概念yarn下工作原理算子依赖血缘关系阶段划分广播变量 shuffle流程SparkSQLDataSet、DataFrame、RDD相互转换 SparkStreaming Spark Spark是一种基于内存的快速、通用、可扩展的大数据…

C# Solidworks二次开发:比较两个solidworks文档属性相关API详解

大家好,今天要讲的文章是关于如何比较两个solidworks文档。 下面是API的介绍: (1)第一个为Close,这个API的含义为在比较solidworks文档以后执行必要的清理。下面是官方的具体解释: 其没有输入参数&#x…

MySQL Workbench下载安装、 MySQL Workbench使用

官方下载链接;MySQL :: Download MySQL Workbench 下载好懒人安装,也可自己选择目录 下面是使用: 连接数据库: 填写数据库连接信息: 基本操作部分: 数据导入导出: 导出/备份 导入: 生产er图…

【机器学习】科学库使用第5篇:Matplotlib,学习目标【附代码文档】

机器学习(科学计算库)完整教程(附代码资料)主要内容讲述:机器学习(常用科学计算库的使用)基础定位、目标,机器学习概述定位,目标,学习目标,学习目标,1 人工智能应用场景,2 人工智能小…

react中关于类式组件和函数组件对props、state、ref的使用

文章中有很多蓝色字体为扩展链接&#xff0c;可以补充查看。 常用命令使用规则 组件编写方式: 1.函数式 function MyButton() { //直接return 标签体return (<>……</>); }2.类 class MyButton extends React.Component { //在render方法中&#xff0c;return…

UE5 C++ 射线检测

一.声明四个变量 FVector StartLocation;FVector ForwardVector;FVector EndLocation;FHitResult HitResult;二.起点从摄像机&#xff0c;重点为摄像机前9999m。射线检测 使用LineTraceSingleByChannel 射线直线通道检测&#xff0c;所以 void AMyCharacter::Tick(float Delt…

GPT国内能用吗

2022年11月&#xff0c;Open AI发布ChatGPT&#xff0c;ChatGPT展现了大型语模型在自然语言处理方面的惊人进步&#xff0c;其生成文本的流畅度和连贯性令人印象深刻&#xff0c;为AI应用打开了新的可能性。 ChatGPT的出现推动了AI技术在各个领域的应用&#xff0c;例如&#x…

Python学习教程(Python学习路线+Python学习视频):Python数据结构

数据结构引言&#xff1a; 数据结构是组织数据的方式&#xff0c;以便能够更好的存储和获取数据。数据结构定义数据之间的关系和对这些数据的操作方式。数据结构屏蔽了数据存储和操作的细节&#xff0c;让程序员能更好的处理业务逻辑&#xff0c;同时拥有快速的数据存储和获取方…

.net9 AOT编绎生成标准DLL,输出API函数教程-中国首创

1&#xff0c;安装VS2022预览版&#xff08;Visual Studio Preview&#xff09; https://visualstudio.microsoft.com/zh-hans/vs/preview/#download-preview 2&#xff0c;选择安装组件&#xff1a;使用C的桌面开发 和 .NET桌面开发 ------------------------------------- …

java八股文知识点讲解(个人认为讲的比较好的)

1、解决哈希冲突——链地址法&#xff1a;【第7章查找】19哈希表的查找_链地址法解决哈希冲突_哔哩哔哩_bilibili 2、解决哈希冲突——开放地址法 &#xff1a; 【第7章查找】18哈希表的查找_开放定址法解决哈希冲突_哔哩哔哩_bilibili 3、小根堆大根堆的创建&#xff1a;选择…

【每日刷题】Day17

【每日刷题】Day17 &#x1f955;个人主页&#xff1a;开敲&#x1f349; &#x1f525;所属专栏&#xff1a;每日刷题&#x1f34d; &#x1f33c;文章目录&#x1f33c; 1. 19. 删除链表的倒数第 N 个结点 - 力扣&#xff08;LeetCode&#xff09; 2. 162. 寻找峰值 - 力扣…

1 回归:锂电池温度预测top2 代码部分(一) Tabnet

2024 iFLYTEK A.I.开发者大赛-讯飞开放平台 TabNet&#xff1a; 模型也是我在这个比赛一个意外收获&#xff0c;这个模型在比赛之中可用。但是需要GPU资源&#xff0c;否则运行真的是太慢了。后面针对这个模型我会写出如何使用的方法策略。 比赛结束后有与其他两位选手聊天&am…

《ElementPlus 与 ElementUI 差异集合》el-popconfirm 气泡确认框之插槽写法有差异

ElementUI 直接在 el-button 上配置属性 slot&#xff1b; <el-popconfirm title"确定删除吗&#xff1f;请谨慎操作&#xff01;" confirm"delete"><el-button slot"reference" size"small" type"danger">删…

Word学习笔记之奇偶页的页眉与页码设置

1. 常用格式 在毕业论文中&#xff0c;往往有一下要求&#xff1a; 奇数页右下角显示、偶数页左下角显示奇数页眉为每章标题、偶数页眉为论文标题 2. 问题解决 2.1 前期准备 首先&#xff0c;不论时要求 1、还是要求 2&#xff0c;这里我们都要做一下设置&#xff1a; 鼠…

Adobe Firefly是否将重新定义AI视频编辑领域?|TodayAI

Adobe最近发布了一段令人瞩目的视频&#xff0c;详细展示了其最新推出的Adobe Firefly视频模型。这一模型集成了尖端的生成式人工智能技术&#xff0c;带来了一系列颠覆性的视频编辑功能&#xff0c;引发了业界的广泛关注和讨论。 视频中的旁白充满热情地宣布&#xff1a;“Ad…

【超级简单】vscode进入服务器的docker容器

前提 1、已经运行docker容器 2、已经用vscode链接服务器 在vscode中安装的插件 Dev Containers docker 在容器中安装的依赖 yum install openssh-server yum install openssh-clientsvscode进入服务器的docker容器 找到自己的容器&#xff0c;右键点击&#xff0c;找到…

Jmeter BeanShell调用Java方法加密

1、添加BeanShell前置处理器 由于请求接口时&#xff0c;会传加密参数。加密过程会在请求之前完成&#xff0c;所以需要使用前置处理器中beanshell preprocessor 2、编写BeanShell脚本 ①定义一个beashell变量&#xff1a;phoneNum&#xff0c;在Beanshell中可以直接调用Jmete…

idea运行报错:启动命令过长

JAVA项目&#xff0c;运行的时候报错 Command line is too long. Shorten the command line via JAR manifest or via a classpath file and rerun老问题了&#xff0c;记录一下 解决办法&#xff1a; 1、Edit Configurations 2、点击Modify options设置&#xff0c;勾选S…

janus架构学习

基础介绍 Janus 是由Meetecho设计和开发的开源、通用的基于SFU架构的WebRTC流媒体服务器&#xff0c;它支持在Linux的服务器或MacOS上的机器进行编译和安装。Janus 是使用C语言进行编写的&#xff0c;它的性能十分优秀。 架构 janus为sfu架构 模块结构图 模块说明 core模…