Stable Diffusion (version x.x) 文生图模型实践指南

前言:本篇博客记录使用Stable Diffusion模型进行推断时借鉴的相关资料和操作流程。

相关博客:
超详细!DALL · E 文生图模型实践指南
DALL·E 2 文生图模型实践指南

目录

  • 1. 环境搭建和预训练模型准备
    • 环境搭建
    • 预训练模型下载
  • 2. 代码


1. 环境搭建和预训练模型准备

环境搭建

pip install diffusers transformers accelerate scipy safetensors

预训练模型下载

关于 huggingface 网站总是崩溃的情况,找到一个解决办法,就是可以通过脚本来下载

第一步:安装 huggingface_hub,使用命令 pip install huggingface_hub
第二步:下载具体模型,使用命令 python model_download.py --repo_id model_id,其中,model_id 为要下载的模型,比如SD v2.1 版本的model_id可以是 stabilityai/stable-diffusion-2-1;SD v1.5 版本的model_id可以是 runwayml/stable-diffusion-v1-5. model_id 的查找方式是在huggingface 网站直接搜索需要的模型(如下图),得到的「模型来源/版本」的组合即为所需。

在这里插入图片描述

model_download.py文件来自这个链接。

# usage     : python model_download.py --repo_id repo_id
# example   : python model_download.py --repo_id facebook/opt-350m
import argparse
import time
import requests
import json
import os
from huggingface_hub import snapshot_download
import platform
from tqdm import tqdm
from urllib.request import urlretrieve


def _log(_repo_id, _type, _msg):
    date1 = time.strftime('%Y-%m-%d %H:%M:%S')
    print(date1 + " " + _repo_id + " " + _type + " :" + _msg)


def _download_model(_repo_id, _repo_type):
    if _repo_type == "model":
        _local_dir = 'dataroot/models/' + _repo_id
    else:
        _local_dir = 'dataroot/datasets/' + _repo_id
    try:
        if _check_Completed(_repo_id, _local_dir):
            return True, "check_Completed ok"
    except Exception as e:
        return False, "check_Complete exception," + str(e)
    _cache_dir = 'caches/' + _repo_id

    _local_dir_use_symlinks = True
    if platform.system().lower() == 'windows':
        _local_dir_use_symlinks = False
    try:
        if _repo_type == "model":
            snapshot_download(repo_id=_repo_id, cache_dir=_cache_dir, local_dir=_local_dir, local_dir_use_symlinks=_local_dir_use_symlinks,
                              resume_download=True, max_workers=4)
        else:
            snapshot_download(repo_id=_repo_id, cache_dir=_cache_dir, local_dir=_local_dir, local_dir_use_symlinks=_local_dir_use_symlinks,
                              resume_download=True, max_workers=4, repo_type="dataset")
    except Exception as e:
        error_msg = str(e)
        if ("401 Client Error" in error_msg):
            return True, error_msg
        else:
            return False, error_msg
    _removeHintFile(_local_dir)
    return True, ""


def _writeHintFile(_local_dir):
    file_path = _local_dir + '/~incomplete.txt'
    if not os.path.exists(file_path):
        if not os.path.exists(_local_dir):
            os.makedirs(_local_dir)
        open(file_path, 'w').close()


def _removeHintFile(_local_dir):
    file_path = _local_dir + '/~incomplete.txt'
    if os.path.exists(file_path):
        os.remove(file_path)


def _check_Completed(_repo_id, _local_dir):
    _writeHintFile(_local_dir)
    url = 'https://huggingface.co/api/models/' + _repo_id
    response = requests.get(url)
    if response.status_code == 200:
        data = json.loads(response.text)
    else:
        return False
    for sibling in data["siblings"]:
        if not os.path.exists(_local_dir + "/" + sibling["rfilename"]):
            return False
    _removeHintFile(_local_dir)
    return True


def download_model_retry(_repo_id, _repo_type):
    i = 0
    flag = False
    msg = ""
    while True:
        flag, msg = _download_model(_repo_id, _repo_type)
        if flag:
            _log(_repo_id, "success", msg)
            break
        else:
            _log(_repo_id, "fail", msg)
            if i > 1440:
                msg = "retry over one day"
                _log(_repo_id, "fail", msg)
                break
            timeout = 60
            time.sleep(timeout)
            i = i + 1
            _log(_repo_id, "retry", str(i))
    return flag, msg


def _fetchFileList(files):
    _files = []
    for file in files:
        if file['type'] == 'dir':
            filesUrl = 'https://e.aliendao.cn/' + file['path'] + '?json=true'
            response = requests.get(filesUrl)
            if response.status_code == 200:
                data = json.loads(response.text)
                for file1 in data['data']['files']:
                    if file1['type'] == 'dir':
                        filesUrl = 'https://e.aliendao.cn/' + \
                            file1['path'] + '?json=true'
                        response = requests.get(filesUrl)
                        if response.status_code == 200:
                            data = json.loads(response.text)
                            for file2 in data['data']['files']:
                                _files.append(file2)
                    else:
                        _files.append(file1)
        else:
            if file['name'] != '.gitattributes':
                _files.append(file)
    return _files


def _download_file_resumable(url, save_path, i, j, chunk_size=1024*1024):
    headers = {}
    r = requests.get(url, headers=headers, stream=True, timeout=(20, 60))
    if r.status_code == 403:
        _log(url, "download", '下载资源发生了错误,请使用正确的token')
        return False
    bar_format = '{desc}{percentage:3.0f}%|{bar}|{n_fmt}M/{total_fmt}M [{elapsed}<{remaining}, {rate_fmt}]'
    _desc = str(i) + ' of ' + str(j) + '(' + save_path.split('/')[-1] + ')'
    total_length = int(r.headers.get('content-length'))
    if os.path.exists(save_path):
        temp_size = os.path.getsize(save_path)
    else:
        temp_size = 0
    retries = 0
    if temp_size >= total_length:
        return True
    # 小文件显示
    if total_length < chunk_size:
        with open(save_path, 'wb') as f:
            for chunk in r.iter_content(chunk_size=chunk_size):
                if chunk:
                    f.write(chunk)
        with tqdm(total=1, desc=_desc, unit='MB', bar_format=bar_format) as pbar:
            pbar.update(1)
    else:
        headers['Range'] = f'bytes={temp_size}-{total_length}'
        r = requests.get(url, headers=headers, stream=True,
                         verify=False, timeout=(20, 60))
        data_size = round(total_length / 1024 / 1024)
        with open(save_path, 'ab') as fd:
            fd.seek(temp_size)
            initial = temp_size//chunk_size
            for chunk in tqdm(iterable=r.iter_content(chunk_size=chunk_size), initial=initial, total=data_size, desc=_desc, unit='MB', bar_format=bar_format):
                if chunk:
                    temp_size += len(chunk)
                    fd.write(chunk)
                    fd.flush()
    return True


def _download_model_from_mirror(_repo_id, _repo_type, _token, _e):
    if _repo_type == "model":
        filesUrl = 'https://e.aliendao.cn/models/' + _repo_id + '?json=true'
    else:
        filesUrl = 'https://e.aliendao.cn/datasets/' + _repo_id + '?json=true'
    response = requests.get(filesUrl)
    if response.status_code != 200:
        _log(_repo_id, "mirror", str(response.status_code))
        return False
    data = json.loads(response.text)
    files = data['data']['files']
    for file in files:
        if file['name'] == '~incomplete.txt':
            _log(_repo_id, "mirror", 'downloading')
            return False
    files = _fetchFileList(files)
    i = 1
    for file in files:
        url = 'http://61.133.217.142:20800/download' + file['path']
        if _e:
            url = 'http://61.133.217.139:20800/download' + \
                file['path'] + "?token=" + _token
        file_name = 'dataroot/' + file['path']
        if not os.path.exists(os.path.dirname(file_name)):
            os.makedirs(os.path.dirname(file_name))
        i = i + 1
        if not _download_file_resumable(url, file_name, i, len(files)):
            return False
    return True


def download_model_from_mirror(_repo_id, _repo_type, _token, _e):
    if _download_model_from_mirror(_repo_id, _repo_type, _token, _e):
        return
    else:
        #return download_model_retry(_repo_id, _repo_type)
        _log(_repo_id, "download", '下载资源发生了错误,请使用正确的token')


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--repo_id', default=None, type=str, required=True)
    parser.add_argument('--repo_type', default="model",
                        type=str, required=False)  # models,dataset
    # --mirror为从aliendao.cn镜像下载,如果aliendao.cn没有镜像,则会转到hf
    # 默认为True
    parser.add_argument('--mirror', action='store_true',
                        default=True, required=False)
    parser.add_argument('--token', default="", type=str, required=False)
    # --e为企业付费版
    parser.add_argument('--e', action='store_true',
                        default=False, required=False)
    args = parser.parse_args()
    if args.mirror:
        download_model_from_mirror(
            args.repo_id, args.repo_type, args.token, args.e)
    else:
        download_model_retry(args.repo_id, args.repo_type)

2. 代码

Stable Diffusion 完整推断流程如下(from https://huggingface.co/stabilityai/stable-diffusion-2-1):

import torch
from diffusers import StableDiffusionPipeline, DPMSolverMultistepScheduler

model_id = "/dataroot/models/stabilityai/stable-diffusion-2-1"  # 预训练模型的下载路径

# Use the DPMSolverMultistepScheduler (DPM-Solver++) scheduler here instead
pipe = StableDiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float16)
pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)
pipe = pipe.to("cuda")

prompt = "a photo of an astronaut riding a horse on mars"
image = pipe(prompt).images[0]
    
image.save("astronaut_rides_horse.png")

参考文献

  1. https://aliendao.cn/model_download.py
  2. https://github.com/Stability-AI/stablediffusion

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/149389.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

【Rust】快速教程——从hola,mundo到所有权

前言 学习rust的前提如下&#xff1a; &#xff08;1&#xff09;先把Rust环境装好 &#xff08;2&#xff09;把VScode中关于Rust的插件装好 \;\\\;\\\; 目录 前言先写一个程序看看Rust的基础mut可变变量let重定义覆盖变量基本数据类型复合类型&#xff08;&#xff09;和 [ …

[Linux] 网络文件共享服务

一、存储类型 存储类型可分为三类&#xff1a;DAS&#xff08;直连式存储&#xff09;,NAS&#xff08;网络附加存储&#xff09;,SAN&#xff08;存储区域网络&#xff09;。 1.1 DAS 定义&#xff1a; DAS是指直连存储&#xff0c;即直连存储&#xff0c;可以理解为本地文…

【软考篇】中级软件设计师 第三部分(二)

中级软件设计师 第三部分&#xff08;二&#xff09; 二十四. 概念设计阶段24.1 E-R模式24.2 E-R图 二十五. 网络和多媒体25.1 计算机网络分类25.2 OSI/RM参考模型25.3 网络互联硬件25.4 TCP/IP分层模型 二十六. IP地址26.1 子网划分26.2 特殊IP26.3 IPv626.4 冲突与和广播域26…

电池故障估计:Realistic fault detection of li-ion battery via dynamical deep learning

昇科能源、清华大学欧阳明高院士团队等的最新研究成果《动态深度学习实现锂离子电池异常检测》&#xff0c;用已经处理的整车充电段数据&#xff0c;分析车辆当前或近期是否存在故障。 思想步骤&#xff1a; 用正常电池的充电片段数据构造训练集&#xff0c;用如下的方式构造…

2023亚太杯数学建模思路 - 复盘:光照强度计算的优化模型

文章目录 0 赛题思路1 问题要求2 假设约定3 符号约定4 建立模型5 模型求解6 实现代码 建模资料 0 赛题思路 &#xff08;赛题出来以后第一时间在CSDN分享&#xff09; https://blog.csdn.net/dc_sinor?typeblog 1 问题要求 现在已知一个教室长为15米&#xff0c;宽为12米&…

文件包含学习笔记总结

文件包含概述 ​ 程序开发人员通常会把可重复使用函数或语句写到单个文件中&#xff0c;形成“封装”。在使用某个功能的时候&#xff0c;直接调用此文件&#xff0c;无需再次编写&#xff0c;提高代码重用性&#xff0c;减少代码量。这种调用文件的过程通常称为包含。 ​ 程…

修改ubuntu终端目录背景颜色

Ubuntu终端上有部分目录是黄绿色底色&#xff0c;看着很不舒服。如下图所示&#xff1a; 这是由于修改用户权限导致的问题。 通过下面指令可以看到 echo $LS_COLORS | grep "ow" ​ 可以看到ow的默认参数是34:42ow:OTHER_WRITABLE&#xff0c;即其他用户可写权限 …

【C++11】线程库

文章目录 thread 线程库mutex 锁atomic 原子性操作condition_variable 条件变量实现两个线程交替打印1-100 thread 线程库 在C11之前&#xff0c;涉及到多线程问题&#xff0c;都是和平台相关的&#xff0c;比如Windows和Linux下各有自己的接口&#xff0c;这使得代码的可移植…

Spring Framework 简介与起源

Spring是用于企业Java应用程序开发的最流行的应用程序开发框架。全球数百万开发人员使用Spring Framework创建高性能、易于测试和可重用的代码。 Spring Framework是一个开源的Java平台。它最初由Rod Johnson编写&#xff0c;并于2003年6月在Apache 2.0许可下首次发布。 Spri…

7 款最好的 Android 手机数据恢复软件榜单(持续更新列表)

数据丢失会干扰您的个人生活和业务&#xff0c;如果手动完成&#xff0c;可能很难恢复丢失的数据。 Android数据恢复软件是克服此问题的完美解决方案。 这些工具可以帮助您快速轻松地从Android设备恢复丢失的数据。 它可以帮助您恢复照片、视频、笔记、联系人等。 7 款最好的An…

双十一买高画质投影仪,当贝F6还是极米H6?

如果你想购买一台4K画质的投影仪,那么在各大平台搜索“双十一最值得买的4K投影仪”时,一定会注意到当贝F6和极米H6这两个型号投影仪。个人认为当贝F6和极米H6都分别是当贝和极米两个品牌非常具有性价比的4K投影仪。那么到底哪一台更适合你。 首先放一张参数对比图,方便参数控研…

有效数字(表示数值的字符串),剑指offer,力扣

目录 题目地址&#xff1a; 我们直接看题解吧&#xff1a; 难度分析&#xff1a; 解题方法&#xff1a; 审题目事例提示&#xff1a; 解题思路&#xff1a; 代码实现&#xff1a; 题目地址&#xff1a; LCR 138. 有效数字 - 力扣&#xff08;LeetCode&#xff09; 难度&#xf…

Win11系统安装或执行程序时提示:文件系统错误(-1073740771)解决方案

有用户反映&#xff0c;exe文件无法执行或者无法安装&#xff0c;报错如图所示&#xff1a; 解决方法&#xff1a; 方法一&#xff1a; 1.打开控制面板&#xff0c;可以采用”搜索“→”控制面板“的方式 2.控制面板选择“用户账户”&#xff0c;再选择“更改用户账户控制设…

Java-绘图

文章目录 Java绘图Java绘图类绘图颜色与画笔属性设置颜色设置画笔 绘制文本显示图片图像处理1、放大与缩小2、图像翻转3、图像旋转4、图像倾斜 End Java绘图 Java绘图是指在Java程序中创建和显示图形的过程。Java提供了许多类和方法来支持绘图。 Java绘图类 Java中主要的绘图类…

DevExpress WinForms HeatMap组件,一个高度可自定义热图控件!

通过DevExpress WinForms可以为Windows Forms桌面平台提供的高度可定制的热图UI组件&#xff0c;体验DevExpress的不同之处。 DevExpress WinForms有180组件和UI库&#xff0c;能为Windows Forms平台创建具有影响力的业务解决方案。同时能完美构建流畅、美观且易于使用的应用程…

云课五分钟-04一段代码学习-大模型分析C++

前篇&#xff1a; 云课五分钟-03第一个开源游戏复现-贪吃蛇 经过01-03&#xff0c;基本了解云课最大的优势之一就是快速复现&#xff08;部署&#xff09;。 视频&#xff1a; 云课五分钟-04一段代码学习-大模型分析C AIGC大模型时代&#xff0c;学习编程语言的方式&#xf…

知虾数据分析软件:了解知虾数据分析软件提升Shopee店铺运营效果

在如今电商竞争激烈的市场中&#xff0c;了解市场趋势和产品数据是成功经营一家Shopee店铺的重要因素之一。而知虾——Shopee生意参谋作为一款功能强大的数据分析软件&#xff0c;可以帮助店主深入了解行业概况、产品潜力以及市场趋势&#xff0c;从而制定最优的运营策略。本文…

NI USRP RIO软件无线电

NI USRP RIO软件无线电 NI USRP RIO是SDR游戏规则的改变者&#xff0c;它为无线通信设计人员提供了经济实惠的SDR和前所不高的性能&#xff0c;可帮助开发下一代5G无线通信系统。“USRP RIO”是一个术语&#xff0c;用于描述包含FPGA的USRP软件定义无线电设备&#xff0c;例如…

PC端微信@所有人逻辑漏洞

&#xff08;一&#xff09;过程 这个漏洞是PC端微信&#xff0c;可以越权让非管理员艾特所有人&#xff0c;具体步骤如下 第一步&#xff1a;找一个自己的群&#xff08;要有艾特所有人的权限&#xff09;“123”是我随便输入的内容&#xff0c;可以更改&#xff0c;然后按c…

技巧篇:Mac 环境PyCharm 配置 python Anaconda

Mac 中 PyCharm 配置 python Anaconda环境 在 python 开发中我们最常用的IDE就是PyCharm&#xff0c;有关PyCharm的优点这里就不在赘述。在项目开发中我们经常用到许多第三方库&#xff0c;用的最多的命令就是pip install 第三方库名 进行安装。现在你可以使用一个工具来帮你解…