[跑代码]BK-SDM: A Lightweight, Fast, and Cheap Version of Stable Diffusion

Installation(下载代码-装环境)

conda create -n bk-sdm python=3.8
conda activate bk-sdm
git clone https://github.com/Nota-NetsPresso/BK-SDM.git
cd BK-SDM
pip install -r requirements.txt
Note on the torch versions we've used
  • torch 1.13.1 for MS-COCO evaluation & DreamBooth finetuning on a single 24GB RTX3090
     

  • torch 2.0.1 for KD pretraining on a single 80GB A10
    火炬2.0.1在单个80GB A100上进行KD预训练

    • 如果A100上总批大小为256的预训练导致gpu内存不足,请检查torch版本并考虑升级到torch>2.0.0。
      我的版本也是torch2.0.1 单个A100(80G)理论上吃的下256batch

小的例子

PNDM采样器 50步去噪声

等效代码(仅修改SD-v1.4的U-Net,同时保留其文本编码器和图像解码器):

Distillation Pretraining

Our code was based on train_text_to_image.py of Diffusers 0.15.0.dev0. To access the latest version, use this link.
BK-SDM的diffusers版本0.15
我的diffusers版本比较高0.24.0

检测是否能够训练(先下载数据集get_laion_data.sh再运行代码kd_train_toy.sh)

1 一个玩具数据集(11K的img-txt对)下载到。

bash scripts/get_laion_data.sh preprocessed_11k

/data/laion_aes/preprocessed_11k (1.7GB in tar.gz;1.8GB数据文件夹)。
get_laion_data.sh

需要修改,实际就是下载这三个数据集,我自行下载

# https://netspresso-research-code-release.s3.us-east-2.amazonaws.com/data/improved_aesthetics_6.5plus/preprocessed_11k.tar.gz
# https://netspresso-research-code-release.s3.us-east-2.amazonaws.com/data/improved_aesthetics_6.5plus/preprocessed_212k.tar.gz
# https://netspresso-research-code-release.s3.us-east-2.amazonaws.com/data/improved_aesthetics_6.5plus/preprocessed_2256k.tar.gz

我修改后下载文件名 https://... .../preprocessed_11k.tar.gz直接粘贴到网址里面也可以下载
wget $S3_URL -0 $FILe_PATH
$S3_URL 就是这个网址
$FILe_PATH 就是下载路径./data/laion_aes/preprocessed_11k

DATA_TYPE=$"preprocessed_11k"  # {preprocessed_11k, preprocessed_212k, preprocessed_2256k}
FILE_NAME="${DATA_TYPE}.tar.gz"
 

DATA_DIR="./data/laion_aes/"
FILE_UNZIP_DIR="${DATA_DIR}${DATA_TYPE}"
FILE_PATH="${DATA_DIR}${FILE_NAME}"

if [ "$DATA_TYPE" = "preprocessed_11k" ] || [ "$DATA_TYPE" = "preprocessed_212k" ]; then
    echo "-> preprocessed_11k or 212k"
    S3_URL="https://netspresso-research-code-release.s3.us-east-2.amazonaws.com/data/improved_aesthetics_6.5plus/${FILE_NAME}"
elif [ "$DATA_TYPE" = "preprocessed_2256k" ]; then
    S3_URL="https://netspresso-research-code-release.s3.us-east-2.amazonaws.com/data/improved_aesthetics_6.25plus/${FILE_NAME}"
else
    echo "Something wrong in data folder name"
    exit
fi

wget $S3_URL -O $FILE_PATH
tar -xvzf $FILE_PATH -C $DATA_DIR
echo "downloaded to ${FILE_UNZIP_DIR}"

2 一个小脚本可以用来验证代码的可执行性,并找到与你的GPU匹配的批处理大小。
批量大小为8 (=4×2),训练BK-SDM-Base 20次迭代大约需要5分钟和22GB的GPU内存。

bash scripts/kd_train_toy.sh
MODEL_NAME="CompVis/stable-diffusion-v1-4"
TRAIN_DATA_DIR="./data/laion_aes/preprocessed_11k" # please adjust it if needed
UNET_CONFIG_PATH="./src/unet_config"

UNET_NAME="bk_small" # option: ["bk_base", "bk_small", "bk_tiny"]
OUTPUT_DIR="./results/toy_"$UNET_NAME # please adjust it if needed

BATCH_SIZE=2
GRAD_ACCUMULATION=4

StartTime=$(date +%s)

CUDA_VISIBLE_DEVICES=1 accelerate launch src/kd_train_text_to_image.py \
  --pretrained_model_name_or_path $MODEL_NAME \
  --train_data_dir $TRAIN_DATA_DIR\
  --use_ema \
  --resolution 512 --center_crop --random_flip \
  --train_batch_size $BATCH_SIZE \
  --gradient_checkpointing \
  --mixed_precision="fp16" \
  --learning_rate 5e-05 \
  --max_grad_norm 1 \
  --lr_scheduler="constant" --lr_warmup_steps=0 \
  --report_to="all" \
  --max_train_steps=20 \
  --seed 1234 \
  --gradient_accumulation_steps $GRAD_ACCUMULATION \
  --checkpointing_steps 5 \
  --valid_steps 5 \
  --lambda_sd 1.0 --lambda_kd_output 1.0 --lambda_kd_feat 1.0 \
  --use_copy_weight_from_teacher \
  --unet_config_path $UNET_CONFIG_PATH --unet_config_name $UNET_NAME \
  --output_dir $OUTPUT_DIR


EndTime=$(date +%s)
echo "** KD training takes $(($EndTime - $StartTime)) seconds."

单GPU训练BK-SDM{Base, Small, Tiny}-0.22M数据训练
 

bash scripts/get_laion_data.sh preprocessed_212k
bash scripts/kd_train.sh

1 下载数据集preprocessed_212k
2 训练kd_train.sh
(256batch 训练BD-SM-Base 50K轮次需要300hours/53G单卡)
(64batch 训练BD-SM-Base 50K轮次需要60hours/28G单卡) 不理解?
 

单GPU训练BK-SDM{Base, Small, Tiny}-2.3M数据训练

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/205348.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

简单好用!日常写给 ChatGPT 的几个提示词技巧

ChatGPT 很强,但是有时候又显得很蠢,下面是使用 GPT4 的一个实例: 技巧一:三重冒号 """ 引用内容使用三重冒号 """,让 ChatGPT 清晰引用的内容: 技巧二:角色设定…

C++中的map和set的使用

C中的map详解 关联式容器键值对树形结构的关联式容器set的使用1. set的模板参数列表2. set的构造3. set的迭代器4. set的容量5. set修改操作6. set的使用举例 map1. map的简介2. map的模板参数说明3. map的构造4. map的迭代器5. map的容量与元素访问6. map的元素修改 multimap和…

centos8 下载

下载网址 Download 直接下载地址 https://mirrors.cqu.edu.cn/CentOS/8-stream/isos/x86_64/CentOS-Stream-8-20231127.0-x86_64-dvd1.iso 这个版本安装的时候方便

增强静态数据的安全性

静态数据是数字数据的三种状态之一,它是指任何静止并包含在永久存储设备(如硬盘驱动器和磁带)或信息库(如异地备份、数据库、档案等)中的数字信息。 静态数据是指被动存储在数据库、文件服务器、端点、可移动存储设备…

一篇文章带你掌握MongoDB

文章目录 1. 前言2. MongoDB简介3. MongoDB与关系型数据库的对比4. MongoDB的安装5. Compass的使用6. MongoDB的常用语句7. 总结 1. 前言 本文旨在帮助大家快速了解MongoDB,快速了解和掌握MongoDB的干货内容. 2. MongoDB简介 MongoDB是一种NoSQL数据库,采用了文档…

基于SpringBoot的在线视频教育平台的设计与实现

摘 要 随着科学技术的飞速发展,各行各业都在努力与现代先进技术接轨,通过科技手段提高自身的优势;对于在线视频教育平台当然也不能排除在外,随着网络技术的不断成熟,带动了在线视频教育平台,它彻底改变了过…

C++ -- 每日选择题 -- Day2

第一题 1. 下面代码中sizeof(A)结果为() #pragma pack(2) class A {int i;union U{char str[13];int i;}u;void func() {};typedef char* cp;enum{red,green,blue}color; }; A:20 B:21 C:22 D:24 答案及解析…

Python基础学习之包与模块详解

文章目录 前言什么是 Python 的包与模块包的身份证如何创建包创建包的小练习 包的导入 - import模块的导入 - from…import导入子包及子包函数的调用导入主包及主包的函数调用导入的包与子包模块之间过长如何优化 强大的第三方包什么是第三方包如何安装第三方包 总结关于Python…

[pyqt5]PyQt5之如何设置QWidget窗口背景图片问题

目录 PyQt5设置QWidget窗口背景图片 QWidget 添加背景图片问题QSS 背景图样式区别PyQt设置窗口背景图像,以及图像自适应窗口大小变化 总结 PyQt5设置QWidget窗口背景图片 QWidget 添加背景图片问题 QWidget 创建的窗口有时并不能直接用 setStyleSheet 设置窗口部分…

2023.11.27 使用anoconda搭建tensorflow环境

2023.11.27 使用anoconda搭建tensorflow环境 提供一个简便安装tensorflow的方法 1. 首先安装anoconda,安装过程略,注意安装的时候勾选安装anoconda prompt 2. 进入anoconda prompt 3. 建立python版本 conda create -n tensorflow1 python3.84. 激活t…

Java开发规范(简洁明了)

本篇规范基于阿里巴巴开发手册,总结了一些主要的开发规范,希望对大家有帮助。 目录 1. 命名规范: 2. 缩进和空格: 3. 花括号: 4. 注释: 5. 空行: 6. 导入语句: 7. 异常处理&a…

源 “MySQL 8.0 Community Server“ 的 GPG 密钥已安装,但是不适用于此软件包。请检查源的公钥 URL 是否配置正确。

源 “MySQL 8.0 Community Server“ 的 GPG 密钥已安装,但是不适用于此软件包。请检查源的公钥 URL 是否配置正确。yum install mysql-server --nogpgcheck

电子商务网站的技术 SEO:完整指南

技术 SEO 涉及对您的网站进行更改,这些更改由搜索引擎蜘蛛读取并由搜索结果索引。它简化了搜索引擎读取您网站的方式。如果不包括技术 SEO 服务,你就不可能有一个成功的电子商务 SEO 计划。主要目标是增强网站的框架。 在线商店的技术搜索引擎优化建议 …

熟悉SVN基本操作-(SVN相关介绍使用以及冲突解决)

一、SVN相关介绍 1、SVN是什么? 代码版本管理工具它能记住你每次的修改查看所有的修改记录恢复到任何历史版本恢复已经删除的文件 2、SVN跟Git比,有什么优势 使用简单,上手快目录级权限控制,企业安全必备子目录checkout,减少…

Redis应用的16个场景

常见的16种应用场景: 缓存、数据共享分布式、分布式锁、全局 ID、计数器、限流、位统计、购物车、用户消息时间线 timeline、消息队列、抽奖、点赞、签到、打卡、商品标签、商品筛选、用户关注、推荐模型、排行榜. 1、缓存 String类型 例如:热点数据缓存&#x…

【驱动】SPI驱动分析(四)-关键API解析

关键API 设备树 设备树解析 我们以Firefly 的SPI demo 分析下dts中对spi的描述&#xff1a; /* Firefly SPI demo */ &spi1 {spi_demo: spi-demo00{status "okay";compatible "firefly,rk3399-spi";reg <0x00>;spi-max-frequency <48…

C#工程中Form_xx.cs不能在设计器中查看

环境&#xff1a;VS2022 直接上图&#xff1a; 原因&#xff1a; 写了个类在Form_xx.cs中从For继承的部分类之前&#xff0c;移动到之后&#xff0c;保证窗体类是代码中的首个类即可&#xff0c;如图&#xff1a;

python pytorch实现RNN,LSTM,GRU,文本情感分类

python pytorch实现RNN,LSTM&#xff0c;GRU&#xff0c;文本情感分类 数据集格式&#xff1a; 有需要的可以联系我 实现步骤就是&#xff1a; 1.先对句子进行分词并构建词表 2.生成word2id 3.构建模型 4.训练模型 5.测试模型 代码如下&#xff1a; import pandas as pd im…

【强迫症患者必备】SpringBoot项目中Mybatis使用mybatis-redis开启三级缓存必须创建redis.properties优化方案

springboot项目中mybatis使用mybatis-redis开启三级缓存需要创建redis.properties优化方案 前言下载mybatis-redis源码分析RedisCache 代码RedisConfigurationBuilder的parseConfiguration方法 优化改造1.创建JedisConfig类2.复制RedisCache代码创建自定义的MyRedisCache3.指定…

INA219电流感应芯片_程序代码

详细跳转借鉴链接INA219例程此处进行总结 简单介绍一下 INA219&#xff1a; 1、 输入脚电压可以从 0V~26V,INA219 采用 3.3V/5V 供电. 2、 能够检测电流&#xff0c;电压和功率&#xff0c;INA219 内置基准器和乘法器使之能够直接以 A 为单位 读出电流值。 3、 16 位可编程地…