代码解读:使用Stable Diffusion完成相似图像生成任务

Diffusion models代码解读:入门与实战

前言:作为内容生产重要的一部分,生成相似图像是一项有意义的工作,例如很多内容创作分享平台单纯依赖用户贡献的图片已经不够了,最省力的方法就是利用已有的图片生成相似的图片作为补充。这篇博客详细解读基于Stable Diffusion生成相似图片的原理和代码。

目录

原理详解

代码实战

环境安装

代码运行

参数解读

快速体验

效果展示

代码解读

模型加载

图像加噪

分类器引导


原理详解

首先读者需要熟悉Stable Diffusion的原理,这部分的可以参考我之前的博客:

Diffusion models代码实战:从零搭建自己的扩散模型_diffusion model 自己-CSDN博客

在Stable Diffusion 中,latent域注入了text condition;在相似图像生成的任务中,把这个text condition替换成Image conditon:

推理的时候,会先在图片中加入噪声,这个添加噪声的程度会用noise_level参数控制。然后把这个加噪过的图片embedding输入到模型中。

原理就这么简单……

代码实战

环境安装

pip install git+https://github.com/lllcho/image_variation.git
pip install modelscope

代码运行

from modelscope.pipelines import pipeline
from modelscope.outputs import OutputKeys
from PIL import Image
from image_variation import modelscope_warpper

model = 'damo/cv_image_variation_sd'
pipe = pipeline('image_variation_task', model=model, device='gpu',auto_collate=False,model_revision='v1.1.0')
out=pipe('https://vision-poster.oss-cn-shanghai.aliyuncs.com/lllcho.lc/data/test_data/sunset-landscape-sky-colorful-preview.jpg')
imgs=out[OutputKeys.OUTPUT_IMGS]
imgs[0].save(f'result.jpg')

参数解读

pipeline调用时的可调参数:

num_inference_steps: int, 默认为20
guidance_scale:float, 默认5.0
num_images_per_prompt:默认为1,每次调用返回几张图,可根据显存大小调整
seed:默认为None,int类型,取值范围[0, 2^32-1]
height::默认值512
width:默认值512
noise_level: int,默认值为0, 取值范围[0,999],表示像输入图像中加入噪声,值越大噪声越多,生成结果与输入图像的相似度越低

快速体验

https://modelscope.cn/studios/iic/image_variation/summary

效果展示

输入的图像:

输出的图像

代码解读

代码地址:https://modelscope.cn/studios/iic/image_variation/summary

模型加载

    scheduler = UniPCMultistepScheduler(beta_start=0.00085,beta_end=0.012,beta_schedule='scaled_linear')
    vae = AutoencoderKL.from_pretrained(ckpt_dir, subfolder='vae')
    vae.eval()
    unet = UNet2DConditionModel.from_pretrained(ckpt_dir, subfolder='unet')
    cond_model, preprocess_train, preprocess_val = open_clip.create_model_and_transforms('ViT-H-14',
        pretrained=osp.join(ckpt_dir,'CLIP-ViT-H-14-laion2B-s32B-b79K/open_clip_pytorch_model.bin')     

图像加噪

对输入的图片embedding加噪后输入到模型中:

    def noise_image_embeddings(self,image_embeds,noise_level,generator=None):
        noise = randn_tensor(image_embeds.shape, generator=generator, device=image_embeds.device, dtype=image_embeds.dtype)
        noise_level = torch.tensor([noise_level] * image_embeds.shape[0], device=image_embeds.device)
        
        meanstd=torch.from_numpy(np.load(self.norm_file)[None]).to(device=image_embeds.device, dtype=image_embeds.dtype)
        mean,std=torch.chunk(meanstd,2,dim=1)
        #scale
        image_embeds-=mean
        image_embeds/=std
        
        image_embeds=self.scheduler.add_noise(image_embeds,noise,noise_level)
        
        #unscale
        image_embeds*=std
        image_embeds+=mean
    
        return image_embeds

分类器引导

和text condition 一样,这里也有“negative prompt”的分类器引导。

图像的“negative prompt”是用masked image替代的,maked image用了一个图片值全0的图片表示:

        mask=torch.ones(2,1,height//8,width//8,device=self.device,dtype=self.dtype)
        masked_img=torch.zeros(2,3,height,width,device=self.device,dtype=self.dtype)
        masked_image_latents = self.vae.encode(masked_img).latent_dist.sample()
        masked_image_latents = masked_image_latents * self.vae.config.scaling_factor
        
        mask=mask.repeat(num_images_per_prompt,1,1,1)
        masked_image_latents=masked_image_latents.repeat(num_images_per_prompt,1,1,1)

完整的推理代码如下:

   @torch.no_grad()
    def __call__(self,
                 image,
                 num_inference_steps=20,
                 guidance_scale=5.0,
                 num_images_per_prompt=1,
                 seed=None,
                 height=512,
                 width=512,
                 noise_level: int=0,
                 ):
        if seed is None:
            seed=random.randint(0,2**32-1)
        set_seed(seed)
        self.scheduler.set_timesteps(num_inference_steps, device=self.device)
        timesteps = self.scheduler.timesteps
        
        mask=torch.ones(2,1,height//8,width//8,device=self.device,dtype=self.dtype)
        masked_img=torch.zeros(2,3,height,width,device=self.device,dtype=self.dtype)
        masked_image_latents = self.vae.encode(masked_img).latent_dist.sample()
        masked_image_latents = masked_image_latents * self.vae.config.scaling_factor
        
        mask=mask.repeat(num_images_per_prompt,1,1,1)
        masked_image_latents=masked_image_latents.repeat(num_images_per_prompt,1,1,1)
         
        latents = randn_tensor((num_images_per_prompt,4,height//8,width//8),device=self.device, generator=None,dtype=self.dtype)
        latents = latents * self.scheduler.init_noise_sigma

        cond_image=image
        clip_img=self.preprocess(read_img(cond_image).convert('RGB')).unsqueeze(0)
        cond_embedding=self.cond_model.encode_image(clip_img.to(self.device,self.dtype)).to(self.dtype)
        cond_embedding=cond_embedding.repeat(num_images_per_prompt,1,1)
        if noise_level>0:
            cond_embedding=self.noise_image_embeddings(cond_embedding,noise_level)
        cond_embedding=torch.cat([cond_embedding*0,cond_embedding])
        
        for i, t in enumerate(timesteps):
            latent_model_input = torch.cat([latents] * 2)
            latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
            latent_model_input = torch.cat([latent_model_input, mask, masked_image_latents], dim=1)
            
            noise_pred = self.unet(latent_model_input,t,encoder_hidden_states=cond_embedding,cross_attention_kwargs={}).sample
            
            noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
            noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
            latents = self.scheduler.step(noise_pred, t, latents, **{}).prev_sample
        latents = 1 / self.vae.config.scaling_factor * latents
        image = self.vae.decode(latents.to(self.dtype)).sample
        image = (image / 2 + 0.5).clamp(0, 1)
        image = image.cpu().permute(0, 2, 3, 1).float().numpy()
        image = (image * 255).round().astype("uint8")
        imgs=[Image.fromarray(img) for img in image]
        return imgs

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/535249.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

C中自定义类型——结构体

一.前言 在C语言中,不仅有int、char、short、long等内置类型,C语言还有一种特殊的类型——自定义类型。该类型可以由使用者自己定义,可以解决一些复杂的个体。 二.结构体 2.1结构体的声明 我们在利用结构体的时候一般是用于描述一些有多种…

代码随想录算法训练营第三十一天| 455.分发饼干、376.摆动序列、53.最大子序和

系列文章目录 目录 系列文章目录455.分发饼干贪心算法大饼干喂胃口大的(先遍历胃口)胃口大的先吃大饼干(先遍历饼干)小饼干先喂胃口小的(先遍历胃口)胃口小的先吃小饼干(先遍历饼干) 376. 摆动序…

Claude使用教程

claude 3 opus面世后,网上盛传吊打了GPT-4。网上这几天也已经有了许多应用,但竟然还有很多小伙伴不知道国内怎么用gpt,也不知道怎么去用这个据说已经吊打了gpt-4的claude3。 今天我们想要进行的一项尝试就是—— 用claude3和gpt4&#xff0c…

2024-04-11最新dubbo+zookeeper下载安装,DEMO展示

dubbozookeeper下载安装 下载zookeeper: 下载地址 解压,并进入bin目录,启动 如果闪退可以编辑脚本,在指定位置加上暂停脚本 报错内容说没有conf/zoo.cfg,就复制zoo_sample.cfg重命名为zoo.cfg 再次启动脚本&#x…

HarmonyOS开发实例:【手势截屏】

介绍 本篇Codelab基于手势处理和截屏能力,介绍了手势截屏的实现过程。样例主要包括以下功能: 根据下滑手势调用全屏截图功能。全屏截图,同时右下角有弹窗提示截图成功。根据双击手势调用区域截图功能。区域截图,通过调整选择框大…

Excel 记录单 快速录入数据

一. 调出记录单 ⏹记录单功能默认是隐藏的,通过如下如图所示的方式,将记录单功能显示出来。 二. 录入数据 ⏹先在表格中录入一行数据,给记录单一个参考 ⏹将光标至于表格右上角,然后点击记录单按钮,调出记录单 然后点…

百元不入耳运动耳机哪个品牌好?五款业内顶尖品牌推荐

在追求舒适与健康的运动中,不入耳式(开放式耳机)运动耳机逐渐成为了许多运动爱好者的首选,它们不仅避免了长时间佩戴耳机带来的不适,还能在享受音乐的同时保持对环境的警觉,确保运动安全,市场上…

Python中同时调用多个列表

如果你有多个列表,想要同时迭代它们,可以使用zip()函数。zip()函数可以将多个可迭代对象合并成一个元组的迭代器,然后你可以在循环中使用它。 问题背景 当需要在Python脚本中避免重复相同任务时,可以使用for循环来遍历列表。但是…

Volatility-内存取证案例1-writeup--xx大赛

题目提示:flag{中文} 按部就班 (1)获取内存镜像版本信息 volatility -f 文件名 imageinfo 通过上述可知,镜像版本为Win7SP1X64。 (2)获取进程信息: volatility -f 镜像名 --profile第一步获取…

面壁智能完成新一轮数亿元融资,继续面向AGI的高效大模型征程

近日,面壁智能完成新一轮数亿元融资,由春华创投、华为哈勃领投,北京市人工智能产业投资基金等跟投,知乎作为战略股东持续跟投支持。本轮融资完成后,面壁智能将进一步推进优秀人才引入,加固大模型发展的底层…

6.12物联网RK3399项目开发实录-驱动开发之UART 串口的使用(wulianjishu666)

嵌入式实战开发例程【珍贵收藏,开发必备】: 链接:https://pan.baidu.com/s/1tkDBNH9R3iAaHOG1Zj9q1Q?pwdt41u UART 使用 简介 AIO-3399J 支持 SPI 桥接/扩展 4 个增强功能串口(UART)的功能,分别为 UA…

LeetCode 面试题 02.07.链表相交(判断两个结点是否相同)

给你两个单链表的头节点 headA 和 headB ,请你找出并返回两个单链表相交的起始节点。如果两个链表没有交点,返回 null 。 图示两个链表在节点 c1 开始相交: 题目数据 保证 整个链式结构中不存在环。 注意,函数返回结果后&#x…

Qt中的网络通信

C没有封装专门的网络套接字的类,因此C只能调用C对应的API,而在Linux和Windows环境下的API都是不一样的 Qt作为一个C框架提供了相关封装好的套接字通信类 在Qt中需要用到两个类,两个类都属于network且都是属于IO操作,只不过这两个类…

ArcGIS Desktop使用入门(三)图层右键工具——缩放至图层、缩放至可见

系列文章目录 ArcGIS Desktop使用入门(一)软件初认识 ArcGIS Desktop使用入门(二)常用工具条——标准工具 ArcGIS Desktop使用入门(二)常用工具条——编辑器 ArcGIS Desktop使用入门(二&#x…

Java 怎么捕捉 Windows 中前台窗口的改变?

在Java中捕捉Windows中前台窗口的改变通常需要使用JNI(Java Native Interface)来调用Windows API。Windows API提供了一系列函数来获取有关窗口和进程的信息,通过使用这些函数,我们可以实现在Java程序中监视和捕捉Windows前台窗口…

抖音爬虫——点赞量

该爬虫模拟了一个get请求来得到返回json里面的点赞量信息 下面介绍如何使用: 首先,我们找一个浏览器打开抖音搜索具体的关键词 接着我们点击键盘的F12建 就会出现如下的界面,接着我们点击网络(可能再一些浏览器是叫network&…

JavaSE:this关键字(代码和内存图讲解)

this的含义 this代表当前对象,谁调用this所在的方法,this就代表谁 这句话非常重要 demo 以这段代码为例,setNum方法内部的this,setStr方法内部的this,还有构造方法ThisKeyword(int num, String str)内部的两个this…

软件库V1.2版本开源-首页UI优化

iAppV3源码,首页的分类更换成了标签布局,各位可以参考学习,界面名称已经中文标注! 老版本和现在的版本还是有较大的区别的,建议更新一下! 新版本改动界面如下: 1、首页.iyu:分类按…

基于javassm实现的幼儿教育管理系统

开发语言:Java 框架:ssm 技术:JSP JDK版本:JDK1.8 服务器:tomcat7 数据库:mysql 5.7(一定要5.7版本) 数据库工具:Navicat11 开发软件:eclipse/myeclip…

晶核职业选择:六大角色技能揭秘,成为战斗高手!

在晶核的世界中,每一位玩家都扮演着不同角色,组成多样的团队,共同踏上探索未知的征程。而每个角色都有其独特的技能和特点,下面将为你详细介绍每个角色的技能搭配和操作技巧,让你在战斗中游刃有余,一展自己…