人工智能(pytorch)搭建模型23-pytorch搭建生成对抗网络(GAN):手写数字生成的项目应用

大家好,我是微学AI,今天给大家介绍一下人工智能(pytorch)搭建模型23-pytorch搭建生成对抗网络(GAN):手写数字生成的项目应用。生成对抗网络(GAN)是一种强大的生成模型,在手写数字生成方面具有广泛的应用前景。通过生成逼真的手写数字图像,GAN可以用于数据增强、图像修复、风格迁移等任务,提高模型的性能和泛化能力。生成对抗网络在手写数字生成领域具有广泛的应用前景。主要应用场景包括数据增强、图像修复、风格迁移和跨领域生成。数据增强可以通过生成逼真的手写数字图像,为训练数据集提供更多的样本,提高模型的泛化能力。

一、项目背景

随着深度学习技术的不断发展,生成模型在计算机视觉、自然语言处理等领域取得了显著的成果。生成对抗网络(GAN)作为一种新兴的生成模型,近年来备受关注。在手写数字生成方面,GAN可以生成逼真的手写数字图像,为数据增强、图像修复等任务提供有力支持。

二、生成对抗网络原理

生成对抗网络(GAN)由Goodfellow等人于2014年提出,它由两个神经网络——生成器(Generator)和判别器(Discriminator)——组成。生成器的目标是生成逼真的假样本,而判别器的目标是区分真实样本和生成器生成的假样本。在训练过程中,生成器和判别器相互竞争,不断调整参数,以达到纳什均衡。
GAN的目标是最小化以下价值函数:
min ⁡ G max ⁡ D V ( D , G ) = E x ∼ p data ( x ) [ log ⁡ D ( x ) ] + E z ∼ p z ( z ) [ log ⁡ ( 1 − D ( G ( z ) ) ) ] \min_G \max_D V(D, G) = \mathbb{E}_{x \sim p_{\text{data}}(x)}[\log D(x)] + \mathbb{E}_{z \sim p_z(z)}[\log (1 - D(G(z)))] GminDmaxV(D,G)=Expdata(x)[logD(x)]+Ezpz(z)[log(1D(G(z)))]
其中, G G G表示生成器, D D D表示判别器, x x x表示真实样本, z z z表示生成器的输入噪声, p data p_{\text{data}} pdata表示真实数据分布, p z p_z pz表示噪声分布。
在这里插入图片描述

三、生成对抗网络应用场景

生成对抗网络(GAN)在手写数字生成领域的应用具有广泛的前景。以下是几个主要的应用场景:
1.数据增强:通过生成逼真的手写数字图像,GAN可以为训练数据集提供更多的样本,提高模型的泛化能力。
2. 图像修复:GAN可以用于修复损坏或缺失的手写数字图像,提高图像的质量和可读性。
3. 风格迁移:GAN可以将一种手写风格转换为另一种风格,为个性化手写数字生成提供可能。
4. 跨领域生成:GAN可以实现不同手写数字数据集之间的转换,为多任务学习提供支持。

四、生成对抗网络实现手写数字生成

下面我将利用pytorch深度学习框架构建生成对抗网络的生成器模型Generator、判别器模型Discriminator。

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from torchvision.utils import save_image

# 超参数设置
batch_size = 128
learning_rate = 0.0002
num_epochs = 80

# 数据预处理
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])

# 下载并加载训练数据
train_data = datasets.MNIST(root='./data', train=True, transform=transform, download=True)
train_loader = DataLoader(dataset=train_data, batch_size=batch_size, shuffle=True)

# 定义生成器模型
class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()
        self.model = nn.Sequential(
            nn.Linear(100, 256),
            nn.LeakyReLU(0.2),
            nn.Linear(256, 512),
            nn.LeakyReLU(0.2),
            nn.Linear(512, 1024),
            nn.LeakyReLU(0.2),
            nn.Linear(1024, 28*28),
            nn.Tanh()
        )

    def forward(self, x):
        return self.model(x).view(x.size(0), 1, 28, 28)

# 定义判别器模型
class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        self.model = nn.Sequential(
            nn.Linear(28*28, 1024),
            nn.LeakyReLU(0.2),
            nn.Dropout(0.3),
            nn.Linear(1024, 512),
            nn.LeakyReLU(0.2),
            nn.Dropout(0.3),
            nn.Linear(512, 256),
            nn.LeakyReLU(0.2),
            nn.Dropout(0.3),
            nn.Linear(256, 1),
            nn.Sigmoid()
        )

    def forward(self, x):
        x = x.view(x.size(0), -1)
        return self.model(x)

# 初始化模型
generator = Generator()
discriminator = Discriminator()

# 损失函数和优化器
criterion = nn.BCELoss()
optimizerG = optim.Adam(generator.parameters(), lr=learning_rate)
optimizerD = optim.Adam(discriminator.parameters(), lr=learning_rate)

# 训练模型
for epoch in range(num_epochs):
    for i, (images, _) in enumerate(train_loader):
        # 确保标签的大小与当前批次的数据大小一致
        real_labels = torch.ones(images.size(0), 1)
        fake_labels = torch.zeros(images.size(0), 1)

        # 训练判别器
        optimizerD.zero_grad()
        real_outputs = discriminator(images)
        d_loss_real = criterion(real_outputs, real_labels)
        z = torch.randn(images.size(0), 100)
        fake_images = generator(z)
        fake_outputs = discriminator(fake_images.detach())
        d_loss_fake = criterion(fake_outputs, fake_labels)
        d_loss = d_loss_real + d_loss_fake
        d_loss.backward()
        optimizerD.step()

        # 训练生成器
        optimizerG.zero_grad()
        fake_images = generator(z)
        fake_outputs = discriminator(fake_images)
        g_loss = criterion(fake_outputs, real_labels)
        g_loss.backward()
        optimizerG.step()

        if (i+1) % 100 == 0:
            print(f'Epoch [{epoch+1}/{num_epochs}], Step [{i+1}/{len(train_loader)}], d_loss: {d_loss.item()}, g_loss: {g_loss.item()}')

    # 保存生成器生成的图片
    save_image(fake_images.data[:25], './fake_images/fake_images-{}.png'.format(epoch+1), nrow=5, normalize=True)

# 保存模型
torch.save(generator.state_dict(), 'generator.pth')
torch.save(discriminator.state_dict(), 'discriminator.pth')

最后我们打开fake_images/文件夹,可以看到生成手写图片的过程:
在这里插入图片描述

五、总结

本项目利用生成对抗网络(GAN)实现了手写数字的生成。通过训练生成器和判别器,我们成功生成了逼真的手写数字图像。这些生成的图像可以应用于数据增强、图像修复、风格迁移等领域,为手写数字识别等相关任务提供有力支持。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/359486.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

【RT-DETR有效改进】Bi-FPN高效的双向特征金字塔网络(附yaml文件+完整代码)

👑欢迎大家订阅本专栏,一起学习RT-DETR👑 一、本文介绍 本文给大家带来的改进机制是BiFPN双向特征金字塔网络,其是一种特征融合层的结构,也就是我们本文改进RT-DETR模型中的Neck部分,它的主要思想是通过多层级的特征金字塔和双向信息传递来提高精度。本文给大家带…

走迷宫-bfs

package Test;import java.io.BufferedReader; import java.io.IOException; import java.io.InputStreamReader;public class Main {static int N 110,hh 0,tt -1,n,m;static int[][] g new int[N][N]; //用来存储迷宫static int[][] d new int[N][N]; //用来存储d[i…

yarn 现代的包管理工具 介绍

一、前言 yarn 是一个现代的包管理工具,它是 npm(Node Package Manager)的一个替代品。yarn 由 Facebook 开发,并在 2016 年发布。它解决了当时 npm 的一些问题,尤其是在性能和安全性方面。 yarn 主要用于以下几个方面…

bat脚本:批量生成创建数据库的SQL语句

需求来源:使用 Navicat等数据库工具点击“转储SQL文件”会生成一个 xxx.sql 的文件,xxx是导出的数据库名。导出的数据库多了,就会一次性生成很多这样的SQL文件,所以需要写个脚本根据这些SQL脚本文件来批量生成创建数据库的SQL语句…

DX-11A DC0.075A 型信号继电器 柜内安装,板前接线

DX-11信号继电器; DX-11A信号继电器; DX-11B信号继电器; DX-11C信号继电器; DX-11Q信号继电器; DX-11A/Q信号继电器; DX-11B/Q信号继电器; DX-11C/Q信号继电器; 一. 用途 DX-11/0.…

React16源码: React中LegacyContext的源码实现

LegacyContext 老的 contextAPI 也就是我们使用 childContextTypes 这种声明方式来从父节点为它的子树提供 context 内容的这么一种方式遗留的contextAPI 在 react 17 被彻底移除了,就无法使用了那么为什么要彻底移除这个contextAPI的使用方式呢?因为它…

自建DNS劫持服务器,纯内网劫持PS5,屏蔽更新,自动hen

背景:目前PS5首次折腾必须要连外网,还要改DNS,除非使用ESP8266/32, 本文的方法是完全不改DNS,不使用ESP8266,不连接外网的情况下自动折腾 能实现什么: 1.折腾全程不连接外网 2.完全自建hen服务器&#xff…

时间序列表征之SAX(Symbolic Aggregate approXimation)实战python讲解

一、前言 sax理论篇:时间序列表征之SAX(Symbolic Aggregate approXimation)算法 二、sax实现 2.1 过程 标准化(将数据转换为高斯分布)paadiscretization 2.2 标准化 因为原文中采用的breakpoints为 前提假设为&#xf…

Redis五种数据类型及应用场景

1、数据类型 String(字符串,整数,浮点数):做简单的键值对缓存 List(列表):储存一些列表类型的数据结构 Hash(哈希):包含键值对的无序散列表,结构化的数据 Set(无序集合):交集,并集…

Java多线程--同步机制解决线程安全问题方式二:同步方法

文章目录 一、同步方法(1)同步方法--案例11、案例12、案例1之同步监视器 (2)同步方法--案例21、案例2之同步监视器的问题2、案例2的补充说明 二、代码及重要说明(1)代码(2)重要说明 …

云计算HCIE备考经验分享

大家好,我是来自深圳信息职业技术学院22级鲲鹏3-1班的刘同学,在2023年9月19日成功通过了华为云计算HCIE认证,并且取得了A的成绩。下面把我的考证经验分享给大家。 转专业进鲲鹏班考HCIE 大一上学期的时候,在上Linux课程的时候&…

代码随想录 Leetcode222.完全二叉树的节点个数

题目&#xff1a; 代码&#xff08;首刷自解 2024年1月30日&#xff09;&#xff1a; class Solution { public:int countNodes(TreeNode* root) {int res 0;if (root nullptr) return res;queue<TreeNode*> deque;TreeNode* cur root;deque.push(cur);int size 0;w…

注册亚马逊店铺用动态IP可以吗?

注册亚马逊店铺可以用动态IP&#xff0c;只要是独立且干净的网线就没问题&#xff0c;亚马逊规则要求一个IP地址只能出现一个亚马逊店铺&#xff0c;若使用不当会导致关联账户。 固定ip可以给我们的账户带来更多的安全&#xff0c;要知道关联问题是亚马逊上的一个大问题&#…

DBCO-PEG8-Amine,二苯并环辛炔 PEG8 氨基,具有良好反应活性

您好&#xff0c;欢迎来到新研之家 文章关键词&#xff1a;二苯并环辛炔-八聚乙二醇-氨基&#xff0c;二苯并环辛炔 PEG8 氨基&#xff0c;DBCO-PEG8-NH2&#xff0c;DBCO-PEG8-Amine 一、基本信息 产品简介&#xff1a;DBCO-PEG8-NH2 is a compound with good reactivity. …

ubuntu20.04安装sumo

提示&#xff1a;文章写完后&#xff0c;目录可以自动生成&#xff0c;如何生成可参考右边的帮助文档 文章目录 有问题&#xff0c;请大家指出&#xff0c;争取使方法更完善。这只是ubuntu安装sumo的一种方法。一、注意事项1、首先明确你的ubuntu的用户名是什么 二、sumo安装1.…

Python爬虫实践指南:利用cpr库爬取技巧

引言 在信息时代&#xff0c;数据是无价之宝。为了获取网络上的丰富数据&#xff0c;网络爬虫成为了不可或缺的工具。在Python这个强大的编程语言中&#xff0c;cpr库崭露头角&#xff0c;为网络爬虫提供了便捷而高效的解决方案。本文将深入探讨如何利用cpr库实现数据爬取的各…

Ruff应用:打破传统,IoT技术赋能工业制造数字化转型之路

近年来&#xff0c;随着物联网、大数据、云计算、5G等数字技术的快速应用&#xff0c;工业制造领域正在经历着前所未有的变革。工业4.0时代&#xff0c;各种数字技术与工业制造的结合&#xff0c;不仅提高了工业生产效率、降低运营成本&#xff0c;更是极大地推动了传统工业数字…

智能小程序事件系统——SJS响应事件实现方案

背景信息 如有频繁用户交互&#xff0c;在小程序上表现是比较卡顿的。例如&#xff0c;页面有 2 个元素 A 和 B&#xff0c;用户在 A 上做 touchmove 手势&#xff0c;要求 B 也跟随移动&#xff0c;movable-view 就是一个典型的例子。一次 touchmove 事件的响应过程为&#x…

GPT-4 Vision调试任何应用,即使缺少文本日志 升级Streamlit七

GPT-4 Vision 系列: 翻译: GPT-4 with Vision 升级 Streamlit 应用程序的 7 种方式一翻译: GPT-4 with Vision 升级 Streamlit 应用程序的 7 种方式二翻译: GPT-4 Vision静态图表转换为动态数据可视化 升级Streamlit 三翻译: GPT-4 Vision从图像转换为完全可编辑的表格 升级St…

springboot 个人网盘系统 java web网盘文件分享系统 web在线云盘

springboot 个人网盘系统 java web网盘文件分享系统 web在线云盘 开发工具&#xff1a;Eclipse/idea Java开发环境&#xff1a;JDK8.0 Web服务器:Tomcate9.0。 数据库&#xff1a;MySQL数据库。 技术框架&#xff1a;Struts2SpringHibernate和JSP 有详细的源码&#xff0…