山东大学软件学院ai导论实验之生成对抗网络

目录

实验目的

实验代码

实验内容

实验结果


实验目的

基于Pytorch搭建一个生成对抗网络,使用MNIST数据集。

实验代码

import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import numpy as np
import matplotlib.pyplot as plt
import os

# 设置环境变量
os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"

# 创建保存生成图像的文件夹
output_path = r"xxxxxxxxxxxxxxxxxx"
os.makedirs(output_path, exist_ok=True)


# 生成器网络
class Generator(nn.Module):
    def __init__(self, latent_dim):
        super(Generator, self).__init__()
        self.network = nn.Sequential(
            nn.Linear(latent_dim, 256),
            nn.LeakyReLU(0.2),
            nn.Linear(256, 512),
            nn.LeakyReLU(0.2),
            nn.Linear(512, 1024),
            nn.LeakyReLU(0.2),
            nn.Linear(1024, 784),
            nn.Tanh()
        )

    def forward(self, z):
        img = self.network(z)
        return img.view(img.size(0), 1, 28, 28)


# 判别器网络
class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        self.network = nn.Sequential(
            nn.Linear(784, 512),
            nn.LeakyReLU(0.2),
            nn.Linear(512, 256),
            nn.LeakyReLU(0.2),
            nn.Linear(256, 1),
            nn.Sigmoid()
        )

    def forward(self, img):
        return self.network(img.view(img.size(0), -1))


def generate_and_save_images(generator, test_input, epoch, img_path):
    with torch.no_grad():
        generated_images = generator(test_input).cpu().numpy()

    fig, axes = plt.subplots(4, 4, figsize=(4, 4))
    for i, ax in enumerate(axes.flat):
        # 将图像从形状 (1, 28, 28) 转换为 (28, 28),去除通道维度
        ax.imshow(np.squeeze(generated_images[i]), cmap='gray')
        ax.axis('off')

    img_filename = os.path.join(img_path, f"generated_epoch_{epoch}.png")
    plt.tight_layout()
    plt.savefig(img_filename)
    plt.close()


# 设置设备(使用GPU或CPU)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# 超参数
lr = 0.0001
batch_size = 128
latent_dim = 100
epochs = 2000

# 数据预处理和加载
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])

train_dataset = datasets.MNIST(root='./MNIST_data', train=True, transform=transform, download=True)
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

# 测试数据:随机噪声作为输入
test_data = torch.randn(batch_size, latent_dim).to(device)

# 初始化生成器和判别器,并定义损失函数和优化器
generator = Generator(latent_dim).to(device)
discriminator = Discriminator().to(device)
adversarial_loss = nn.BCELoss()
optimizer_G = optim.Adam(generator.parameters(), lr=lr)
optimizer_D = optim.Adam(discriminator.parameters(), lr=lr)

# 记录损失
D_losses = []
G_losses = []

# 训练过程
for epoch in range(epochs):
    for i, (imgs, _) in enumerate(train_loader):
        real_imgs = imgs.to(device)
        batch_size = real_imgs.size(0)

        # 判别器训练
        z = torch.randn(batch_size, latent_dim).to(device)
        fake_imgs = generator(z)

        real_labels = torch.ones(batch_size, 1).to(device)
        fake_labels = torch.zeros(batch_size, 1).to(device)

        # 计算损失
        real_loss = adversarial_loss(discriminator(real_imgs), real_labels)
        fake_loss = adversarial_loss(discriminator(fake_imgs.detach()), fake_labels)
        d_loss = (real_loss + fake_loss) / 2

        optimizer_D.zero_grad()
        d_loss.backward()
        optimizer_D.step()

        # 生成器训练
        z = torch.randn(batch_size, latent_dim).to(device)
        fake_imgs = generator(z)
        g_loss = adversarial_loss(discriminator(fake_imgs), real_labels)

        optimizer_G.zero_grad()
        g_loss.backward()
        optimizer_G.step()

        # 记录损失
        D_losses.append(d_loss.item())
        G_losses.append(g_loss.item())

        # 打印每2000个步骤的迭代信息
        if (epoch * len(train_loader) + i) % 2000 == 0:
            print(f"Iter: {epoch * len(train_loader) + i}")
            print(f"D_loss: {d_loss.item():.4f}")
            print(f"G_loss: {g_loss.item():.4f}")
    # 每个epoch保存生成的图像
    generate_and_save_images(generator, test_data, epoch, output_path)

    # 保存生成器和判别器的模型
    torch.save(generator.state_dict(), "Generator_mnist.pth")
    torch.save(discriminator.state_dict(), "Discriminator_mnist.pth")

# 绘制损失曲线
plt.figure(figsize=(10, 5))
plt.plot(D_losses, label='Discriminator Loss')
plt.plot(G_losses, label='Generator Loss')
plt.xlabel('Iterations')
plt.ylabel('Loss')
plt.legend()
plt.title('Loss Curve')
plt.savefig('loss_curve.png')  # 保存图像
plt.show()  # 显示图像

实验内容

1. 数据集加载

与前几次实验一样,本实验仍然使用MNIST数据集作为输入数据集通过torchvision库进行加载并标准化处理,使得图像像素值在[-1, 1]范围内,以适应生成对抗网络的训练要求。

2. 生成器与判别器网络

生成器:生成器网络的任务是生成伪造的图像,以欺骗判别器。输入是一个随机噪声向量(latent vector),输出是一个28x28像素的图像。生成器使用多个全连接层,每个层后面都跟着一个LeakyReLU激活函数,最终输出通过Tanh激活函数确保生成的图像像素值在[-1, 1]范围内。

判别器:判别器网络的任务是区分输入的图像是“真实的”还是“伪造的”。它将图像输入后,通过多个全连接层,最后输出一个介于0和1之间的值,表示图像的真实性。

3. 训练过程

判别器训练:判别器的目标是最大化其准确性,即正确分类真实和伪造的图像。在每次训练中,先计算真实图像的损失,然后计算生成图像的损失,最后将两个损失加权平均得到判别器的总损失。

生成器训练:生成器的目标是最小化判别器对其生成图像的判断错误率。即通过调整其权重,使得生成的图像越来越像真实图像,以此欺骗判别器。生成器的损失函数是判别器对生成图像的输出,标签为“真实”(即1)。

模型优化:使用Adam优化器分别优化生成器和判别器的参数。学习率为0.0001。

  1. 改变隐藏层数

生成器的结构由原来的4个隐藏层缩减为2个隐藏层:

5.生成图像并保存

在每个epoch结束时,使用生成器生成一些图像,并将图像保存为PNG格式文件。每个epoch的图像被保存到指定的文件夹中,以便可视化生成图像的变化。

6. 绘制损失曲线

训练过程中记录并绘制判别器和生成器的损失曲线,以便观察模型的训练进展。

实验结果

迭代得到的训练结果为:

改变隐藏层数得到的部分结果为:

刚开始生成的初始图像为:

运行一段时间后,得到的图像为:

可以明显的看到,随着迭代不断增加,数字越来越清晰,数字识别成功

损失曲线为:

初始:

慢慢的趋于平稳:

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/978434.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

【Java 面试 八股文】JVM 虚拟机篇

JVM 虚拟机篇 1. JVM组成1.1 JVM由那些部分组成,运行流程是什么?1.2 什么是程序计数器?1.3 你能给我详细的介绍Java堆吗?1.4 Java 虚拟机栈1.4.1 Java Virtual machine Stacks (java 虚拟机栈)1.4.2 栈和堆的区别1.4.3 垃圾回收是否涉及栈内…

钉钉MAKE AI生态大会思考

1. 核心特性 1.1 底层模型开放 除原有模型通义千问外,新接入猎户星空、智普、MinMax、月之暗面、百川智能、零一万物。 1.2 AI搜索 AI搜索贯通企业和个人散落在各地的知识(聊天记录、文档、会议、日程、知识库、项目等),通过大模型对知识逻辑化,直接生成搜索的答案,并…

flex布局自定义一行几栏,靠左对齐===grid布局

模板 <div class"content"><div class"item">1222</div><div class"item">1222</div><div class"item">1222</div><div class"item">1222</div><div class"…

当下弹幕互动游戏源码开发教程及功能逻辑分析

当下很多游戏开发者或者想学习游戏开发的人&#xff0c;想要了解如何制作弹幕互动游戏&#xff0c;比如直播平台上常见的那种&#xff0c;观众通过发送弹幕来影响游戏进程。需要涵盖教程的步骤和功能逻辑的分析。 首先&#xff0c;弹幕互动游戏源码开发教程部分应该分步骤&…

jdk21下载、安装(Windows、Linux、macOS)

Windows 系统 1. 下载安装 访问 Oracle 官方 JDK 下载页面 或 OpenJDK 下载页面&#xff0c;根据自己的系统选择合适的 Windows 版本进行下载&#xff08;通常选择 .msi 安装包&#xff09;。 2. 配置环境变量 右键点击 “此电脑”&#xff0c;选择 “属性”。 在左侧导航栏…

2.部署kafka:9092

官方文档&#xff1a;http://kafka.apache.org/documentation.html (虽然kafka中集成了zookeeper,但还是建议使用独立的zk集群) Kafka3台集群搭建环境&#xff1a; 操作系统: centos7 防火墙&#xff1a;全关 3台zookeeper集群内的机器&#xff0c;1台logstash 软件版本: …

VUE向外暴露文件,并通过本地接口调用获取,前端自己生成接口获取public目录里面的文件

VUE中&#xff0c;如果我们想对外暴露一个文件&#xff0c;可以在打包之后也能事实对其进行替换&#xff0c;我们只需要把相关文件放置在public目录下即可&#xff0c;可以放置JSON&#xff0c;Excel等文件 比如我在这里放置一个other文件 我们可以直接在VUE中使用axios去获取…

ubuntu-24.04.1-desktop 中安装 QT6.7

ubuntu-24.04.1-desktop 中安装 QT6.7 1 环境准备1.1 安装 GCC 和必要的开发包:1.2 Xshell 连接 Ubuntu2 安装 Qt 和 Qt Creator:2.1 下载在线安装器2.2 在虚拟机中为文件添加可执行权限2.3 配置镜像地址运行安装器2.4 错误:libxcb-xinerama.so.0: cannot open shared objec…

应用的负载均衡

概述 负载均衡&#xff08;Load Balancing&#xff09; 调度后方的多台机器&#xff0c;以统一的接口对外提供服务&#xff0c;承担此职责的技术组件被称为“负载均衡”。 负载均衡器将传入的请求分发到应用服务器和数据库等计算资源。负载均衡是计算机网络中一种用于优化资源利…

Linux: 已占用接口

Linux: 已占用接口 1. netstat&#xff08;适用于旧系统&#xff09;1.1 书中对该命令的介绍 2. ss&#xff08;适用于新系统&#xff0c;替代 netstat&#xff09;3. lsof&#xff08;查看详细进程信息&#xff09;4. fuser&#xff08;快速查找占用端口的进程&#xff09;5. …

组件注册方式、传递数据

组件注册 一个vue组件要先被注册&#xff0c;这样vue才能在渲染模版时找到其对应的实现。有两种注册方式&#xff1a;全局注册和局部注册。&#xff08;组件的引入方式&#xff09; 以下这种属于局部引用。 组件传递数据 注意&#xff1a;props传递数据&#xff0c;只能从父…

使用DeepSeek/chatgpt等AI工具辅助网络协议流量数据包分析

随着deepseek,chatgpt等大模型的能力越来越强大&#xff0c;本文将介绍一下deepseek等LLM在分数流量数据包这方面的能力。为需要借助LLM等大模型辅助分析流量数据包的同学提供参考&#xff0c;也了解一下目前是否有必要继续学习wireshark工具以及复杂的协议知识。 pcap格式 目…

蓝桥杯嵌入式客观题以及解释

第十一届省赛&#xff08;大学组&#xff09; 1.稳压二极管时利用PN节的反向击穿特性制作而成 2.STM32嵌套向量终端控制器NVIC具有可编程的优先等级 16 个 3.一个功能简单但是需要频繁调用的函数&#xff0c;比较适用内联函数 4.模拟/数字转换器的分辨率可以通过输出二进制…

《Mycat核心技术》第17章:实现MySQL的读写分离

作者&#xff1a;冰河 星球&#xff1a;http://m6z.cn/6aeFbs 博客&#xff1a;https://binghe.gitcode.host 文章汇总&#xff1a;https://binghe.gitcode.host/md/all/all.html 星球项目地址&#xff1a;https://binghe.gitcode.host/md/zsxq/introduce.html 沉淀&#xff0c…

虚拟机 | Ubuntu 安装流程以及界面太小问题解决

文章目录 前言一、Ubuntu初识二、使用步骤1.下载ubuntu镜像2.创建虚拟机1、使用典型&#xff08;节省空间&#xff09;2、稍后安装方便配置3、优选Linux版本符合4、浏览位置&#xff0c;选择空间大的磁盘 6、 配置信息&#xff0c;选择镜像7、 启动虚拟机&#xff0c;执行以下步…

2025系统架构师(一考就过):案例之三:架构风格总结

软件架构风格是描述某一特定应用领域中系统组织方式的惯用模式&#xff0c;按照软件架构风格&#xff0c;物联网系统属于&#xff08; &#xff09;软件架构风格。 A:层次型 B:事件系统 C:数据线 D:C2 答案&#xff1a;A 解析&#xff1a; 物联网分为多个层次&#xff0…

ubuntu离线安装Ollama并部署Llama3.1 70B INT4

文章目录 1.下载Ollama2. 下载安装Ollama的安装命令文件install.sh3.安装并验证Ollama4.下载所需要的大模型文件4.1 加载.GGUF文件&#xff08;推荐、更容易&#xff09;4.2 加载.Safetensors文件&#xff08;不建议使用&#xff09; 5.配置大模型文件 参考&#xff1a; 1、 如…

算法-数据结构(图)-DFS深度优先遍历

深度优先遍历&#xff08;DFS&#xff09;是一种用于遍历或搜索图的算法。以下是对它的详细介绍&#xff1a; 1. 定义 基本思想&#xff1a;从图中某个起始顶点出发&#xff0c;沿着一条路径尽可能深地访问图中的顶点&#xff0c;直到无法继续前进&#xff08;即到达一个没…

uni-app集成sqlite

Sqlite SQLite 是一种轻量级的关系型数据库管理系统&#xff08;RDBMS&#xff09;&#xff0c;广泛应用于各种应用程序中&#xff0c;特别是那些需要嵌入式数据库解决方案的场景。它不需要单独的服务器进程或系统配置&#xff0c;所有数据都存储在一个单一的普通磁盘文件中&am…

python文件的基本操作,文件读写

1.文件 1.1文件就是存储在某种长期存储设备上的一段数据 1.2文件操作 打开文件-->读写文件-->关闭文件 注意&#xff1a;可以只打开和关闭文件不进行任何操作 1.3文件对象的方法 1.open():创建一个file对象&#xff0c;默认以只读模式打开 2.read(n):n表示从文件中…