AI绘画训练一个扩散模型-上集

介绍

AI绘画,其中最常见方案基于扩散模型,Stable Diffusion 在此基础上,增加了 VAE 模块和 CLIP 模块,本文搞了一个测试Demo,分为上下两集,第一集是denoising_diffusion_pytorch ,第二集是diffusers。
对于专业的算法同学而言,我更推荐使用 diffusers 来训练。原因是 diffusers 工具包在实际的 AI 绘画项目中用得更多,并且也更易于我们修改代码逻辑,实现定制化功能。
https://arxiv.org/abs/2112.10752

基础模块

  • 创建UNet模型和高斯扩散模型(Gaussian Diffusion)。

UNet是一个编码器-解码器结构的全卷积神经网络。Gaussian Diffusion用于定义噪声过程和损失函数。

  • 将模型加载到GPU上(如果有GPU)。

  • 使用随机初始化的图片进行一次训练,计算损失并反向传播。

这一步的目的是对模型进行一次预热,更新权重。

  • 使用diffusion模型采样生成图片。

这里采样1000步,也就是将噪声逐步减少,每步用UNet预测下一步的图像,最终输出生成的图片。

  • 如果图片在GPU上,将其移回到CPU。

  • 可视化第一张生成图片。

plt.imshow显示图片。

这样通过DDPM框架,可以从随机噪声生成符合数据分布的新图片。每次训练会使模型逐步逼近真实数据分布,从而产生更高质量的图片。

# 创建UNet和扩散模型

from denoising_diffusion_pytorch import Unet, GaussianDiffusion
import torch

model = Unet(
    dim = 64,
    dim_mults = (1, 2, 4, 8)
).cuda()

diffusion = GaussianDiffusion(
    model,
    image_size = 128,
    timesteps = 1000   # number of steps
).cuda()

# 使用随机初始化的图片进行一次训练
training_images = torch.randn(8, 3, 128, 128)
loss = diffusion(training_images.cuda())
loss.backward()


# 采样1000步生成一张图片
sampled_images = diffusion.sample(batch_size = 4)
import torch
import matplotlib.pyplot as plt
from torchvision.utils import make_grid
import torchvision.transforms as transforms


# 如果张量在 GPU上,需要移动到 CPU上
if sampled_images.is_cuda:
    sampled_images = sampled_images.cpu()

# 检查我们生成的一张图
img = sampled_images[0].clone().detach().permute(1, 2, 0)

plt.imshow(img)

数据集

  • 导入所需的库:PIL、io、datasets等。

  • 使用datasets库中的load_dataset方法加载Oxford Flowers数据集。

  • 创建一个目录来保存图片。

  • 遍历数据集的训练、验证、测试split,逐个图像获取图片bytes数据,并保存为PNG格式图片。

  • 使用PIL库的Image对象将bytes数据加载并保存为图片文件。

  • 使用tqdm显示循环进度。

# 数据集下载
from PIL import Image
from io import BytesIO
from datasets import load_dataset
import os
from tqdm import tqdm

dataset = load_dataset("nelorth/oxford-flowers")

# 创建一个用于保存图片的文件夹
images_dir = "./oxford-datasets/raw-images"
os.makedirs(images_dir, exist_ok=True)

# 遍历所有图片并保存,针对oxford-flowers,整个过程要持续15分钟左右
for split in dataset.keys():
    for index, item in enumerate(tqdm(dataset[split])):
        image = item['image']
        image.save(os.path.join(images_dir, f"{split}_image_{index}.png"))

模型训练

  • 定义Unet模型架构和Gaussian Diffusion过程。

  • 加载数据,指定训练参数:

    • 训练总步数20000
    • batch size 16
    • 学习率2e-5
    • 梯度累积步数2
    • EMA指数衰减参数0.995
    • 使用混合精度训练
    • 每2000步保存一次模型
    • 创建Trainer进行模型训练。Trainer封装了训练循环逻辑。
  • 调用trainer.train()进行训练。

# 模型训练
import torch
from denoising_diffusion_pytorch import Unet, GaussianDiffusion, Trainer

model = Unet(
    dim = 64,
    dim_mults = (1, 2, 4, 8)
).cuda()

diffusion = GaussianDiffusion(
    model,
    image_size = 128,
    timesteps = 1000   # 加噪总步数
).cuda()

trainer = Trainer(
    diffusion,
    './oxford-datasets/raw-images',
    train_batch_size = 16,
    train_lr = 2e-5,
    train_num_steps = 20000,          # 总共训练20000步
    gradient_accumulate_every = 2,    # 梯度累积步数
    ema_decay = 0.995,                # 指数滑动平均decay参数
    amp = True,                       # 使用混合精度训练加速
    calculate_fid = False,            # 我们关闭FID评测指标计算(比较耗时)。FID用于评测生成质量。
    save_and_sample_every = 2000      # 每隔2000步保存一次模型
)

trainer.train()
# 你可以等待上面的模型训练完成后,查看生成结果

from glob import glob

result_images = glob(r"./results/*.png")
print(result_images)
# 可视化图像看看
from PIL import Image

img = Image.open("./results/sample-1.png")
img

测试

https://colab.research.google.com/github/NightWalker888/ai_painting_journey/blob/main/lesson12/train_diffusion_v2.ipynb#scrollTo=8BVjfBPI93Ar

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/268845.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

TomcatHttp协议

1 javaWEB 1.1 Web概述 Web在英文中的含义是网状物,网络。在计算机领域,它通常指的是后者,即网络。像我们前面接触的WWW,它是由3个单词组成的,即:World Wide Web,中文含义是万维网。而我们前…

读书笔记--构建数据湖仓阅读有感

企业为什么要开展数据治理?为什么在数据治理过程中提出数据湖仓构建?数据湖如果没有分析基础设施的建设,就会形成数据沼泽或臭水沟,因为没有人用,也不知道数据之间的关系。我们知道数据因业务运行而产生,后…

盒子 Box

UVa1587 思路&#xff1a; 1.输入每个面的长宽并将每个面较长的一边放在前面 2.判断是否存在三对面分别相等 3.判断是否存在三组四棱相等 #include <stdio.h> #include <stdlib.h> #define maxn 100int cmp(const void* e1, const void* e2) {return (int)(*(d…

Maya python清除命名空间

问题描述&#xff1a; Maya命名空间可能存在嵌套。 如上&#xff0c;直接删除 :female_actor02会出现异常。 因此需要先删除子命名空间&#xff0c;再删除父命名空间。 解决方法&#xff1a; def remove_namespace_node(namespace_name, ns_parent":"):""…

RabbitMQ入门指南(八):MQ可靠性

专栏导航 RabbitMQ入门指南 从零开始了解大数据 目录 专栏导航 前言 一、MQ数据持久化 1.交换机持久化 2.队列持久化 3.消息持久化 4.生产者确认机制 二、LazyQueue 1.LazyQueue模式介绍 2.管理控制台配置Lazy模式 3.代码配置Lazy模式 4.更新已有队列为lazy模式 总…

【计算机四级(网络工程师)笔记】操作系统运行机制

目录 一、中央处理器&#xff08;CPU&#xff09; 1.1CPU的状态 1.2指令分类 二、寄存器 2.1寄存器分类 2.2程序状态字&#xff08;PSW&#xff09; 三、系统调用 3.1系统调用与一般过程调用的区别 3.2系统调用的分类 四、中断与异常 4.1中断 4.2异常 &#x1f308;嗨&#xff…

RPC 实战与原理

文章目录 什么是 RPC&#xff1f;RPC 有什么作用&#xff1f;RPC 步骤为什么需要序列化&#xff1f;零拷贝什么是零拷贝&#xff1f;为什么需要零拷贝&#xff1f;如何实现零拷贝&#xff1f;Netty 的零拷贝有何不同&#xff1f; 动态代理实现HTTP/2 特性为什么需要服务发现&am…

DDD领域驱动设计系列-原理篇-战术设计

概述 上篇战略设计产出了领域及问题域领域模型&#xff1b;详见&#xff1a;DDD领域驱动设计系列-原理篇-战略设计-CSDN博客 战术设计篇聚焦如何落地&#xff0c;包含实际解决方案模型落地&#xff0c;架构分层&#xff08;Clean&#xff0c;CQRS&#xff09;&#xff0c;Rep…

04-C++ 类和对象-02

类和对象-02 1. this 指针 1.1 概念&#xff1a; 谁调用this所在的函数&#xff0c;this就存储谁的地址&#xff0c;即指向谁 。 1.2 特点&#xff1a; 在当前类的非静态成员函数中调用本类非静态成员时&#xff0c;默认有this关键字静态成员函数&#xff0c;没有this指针…

Neovim+ctag浏览、编辑源代码

Neovimctag浏览、编辑源代码 一 配置安装vim及 ctags vim应该可以不用装&#xff0c;直接装neovim&#xff0c;这里我是先装了vim再装的neovim Ctags必须装&#xff0c;后面用neovim telescope索引函数时才有效 vim复制系统粘贴板&#xff1a;vim输入模式下&#xff0c;按shi…

低功耗16位MCU:R7F100GLL3CFA、R7F100GLN2DLA、R7F100GLN3CFA、R7F100GLN2DFA是新一代RL78微控制器

产品介绍&#xff1a; RL78/G23低功耗MCU可在41μA/MHz CPU运行频率下工作&#xff0c;功耗低&#xff0c;停止4KB SRAM保持时为210nA。该MCU设有snooze模式排序器&#xff0c;可显著降低间歇工作时的功耗。RL78/G23组具有1.6V至5.5V宽工作电压范围&#xff0c;频率高达32MHz。…

bean生命周期源码(三)

书接上文 文章目录 一、Bean的销毁逻辑1. 简介2. Bean销毁逻辑的注册3. Bean的销毁过程 一、Bean的销毁逻辑 1. 简介 前面我们已经分析完了Spring创建Bean的整个过程的源码&#xff0c;在创建bean的核心方法中doCreateBean这一个核心方法中&#xff0c;在方法的最后面有这么…

SpringCloudAlibaba Seata在Openfeign跨节点环境出现全局事务Xid失效原因底层探究

原创/朱季谦 曾经在SpringCloudAlibaba的Seata分布式事务搭建过程中&#xff0c;跨节点通过openfeign调用不同服务时&#xff0c;发现全局事务XID在当前节点也就是TM处&#xff0c;是正常能通过RootContext.getXID()获取到分布式全局事务XID的&#xff0c;但在下游节点就出现获…

ros2+gazebo+urdf:ros2机器人使用gazebo的urdf文件中的<gazebo>部分官网资料

原文链接SDFormat extensions to URDF (the gazebo tag) — Documentation 注意了ros2的gazebo部分已经跟ros1的gazebo部分不一样了&#xff1a; Toggle navigation SpecificationAPIDocumentationDownload Back Edit Version: 1.6 Table of C…

我在代码随想录|写代码Day-Day之总结篇

我是用笔手写的我觉得这样可以对个人记忆会更好,而且理解更深解释也更清楚 下面是手写笔记 总结部分----- 第一章 二分 二分模版 图片可能反了下不过没有关系 图形打印模版题 第二章 链表 链表基本操作和疑问 链表代码操作和解析----5大操作 删除部分 对结点的操作 反了反了…

关于个人Git学习记录及相关

前言 可以看一下猴子都能懂的git入门&#xff0c;图文并茂不枯燥 猴子都能懂的git入门 学习东西还是建议尽可能的去看官方文档 权威且详细 官方文档 强烈建议看一下GitHub漫游指南及开源指北&#xff0c;可以对开源深入了解一下&#xff0c;打开新世界的大门&#xff01; …

若依SQL Server开发使用教程

1. sys_menu表中的将菜单ID修改为自动ID,解决不能增加菜单的问题&#xff0c;操作流程如下&#xff1a; 解决方案如下 菜单栏->工具->选项 点击设计器&#xff0c;去掉阻止保存要求更新创建表的更改选项&#xff0c;点确认既可以保存了 2 自动生成代码找不表的解决方案…

C语言--直接插入排序【排序算法|图文详解】

一.直接插入排序介绍&#x1f357; 直接插入排序又叫简单插入排序&#xff0c;是一种简单直观的排序算法&#xff0c;它通过构建有序序列&#xff0c;对于未排序的数据&#xff0c;在已排序序列中从后向前扫描&#xff0c;找到相应位置并插入。 算法描述&#xff1a; 假设要排序…

idea 找不到 Free MyBatis plugin

idea 找不到 free mybatis plugin 可以使用mybatisX替换&#xff1a; 插件安装成功后&#xff0c;重启idea。

【English】水果单词小小汇总~~

废物研究生&#xff0c;只要不搞科研干啥都是开心的&#xff0c;啊啊啊啊啊科研要命。作为一个水果怪&#xff08;每天不吃水果就要命的那种哈哈哈哈&#xff09;突然发现竟然就知道什么apple、banana、orange&#xff01;惭愧惭愧&#xff0c;正好兴致正浓&#xff0c;来整理一…