用你的手机/电脑运行文生图方案

f8b0fc16a231ae0cd7d8fcf298dd8fde.gif

随着ChatGPT和Stable Diffusion的发布,最近一两年,生成式AI已经火爆全球,已然成为移动互联网后一个重要的“风口”。就图片/视频生成领域来说,Stable Diffusion模型发挥着极其重要的作用。由于Stable Diffusion模型参数量是10亿参数的大模型,通常业界都是运行部署在显卡上。

但是随着量化、剪枝等模型压缩技术的进步,以及手机等终端设备的算力、带宽、内存持续增大。使得大模型在终端设备部署也成为的可能。大模型在终端部署可以有效保护用户隐私,而且终端设备日常广泛使用、用户可以随时随地生成想要的内容。

7ec7e6256a7667e08014fbdff9b54cea.png

MNN-Diffusion使用

本文是深度学习推理引擎MNN团队,做的Stable Diffusion端侧部署应用,代码开源,用户可以自行DIY各种好玩的Stable Diffusion应用。

MNN开源地址:

https://github.com/alibaba/MNN/tree/master

欢迎大家试用,使用教程如下:

https://mnn-docs.readthedocs.io/en/latest/transformers/diffusion.html



下面是在个人手机/电脑上生成的图片:

bebe439c12b0e97687df89db324f0bad.png

技术要点

业界加速Stable Diffusion部署通常有两个方向,一是算法层面的优化,包括优化网络结构、减少计算量或者降低推理迭代步数;二是工程部署优化,通过量化/算子高效实现等方式提高硬件计算效率、提高访存效率。MNN作为推理引擎,主要聚焦在工程部署优化上,下面分享下MNN Diffusion GPU在性能/内存方面做了优化工作。

  Self-Attention优化

Transformer结构中Self-Attention是一个基础结构,也是性能耗时的关键。如下结构是一个典型的Attention结构:

30df654d025272d5fc22fc5fd4eb287e.png

一个共有节点,分别经过三个Linear层,得到Query/Key/Value,Query/Key经过形状变换进行BatchMatMul操作,再进行Scale,取Softmax操作;该结果和Value经过形状变换做BatchMatMul;之后把结果进行形状变换,得到最终的输出。可以看到上述总共有19个算子,包括12个形状变化算子,7个计算型算子。



大量的形状变化会带来很多的访存耗时,对于GPU高算力的硬件来说,访存耗时往往容易成为热点。因此,将上述结构,融合成2个算子,第一个是将三个Linear层权重融合在一起,只做一个Linear,这样形成更大的矩阵乘尺寸,更容易打满GPU算力,带来性能收益;第二个算子是将Attention算子融合成一个算子Fused-MultiHead-Attention,融合之后在该新算子内部仅需5个Kernel就可以实现整个Attention功能。消除了大量额外的形状变换算子,降低了访存压力,同时可以更容易基于Attention算子特性做进一步优化工作。

500e2baca3a6db86eb4d5d2fd9550b6f.png

  GroupNorm/SplitGeLU融合

在Stable Diffusion中,有一个通用的结构ResnetBlock,其中包含了BroadCast Binary + GroupNorm + SiLU结构,在onnx模型图结构中包含了如下13个算子:

333736b66abbeff889e57a0fad37192b.png

可以看到GroupNorm采用InstanceNorm+形变算子实现,gamma/beta被单独拆解为mul/add算子,细碎的算子会增加全局内存的访存次数、以及Kernel launch的压力。因此将上述通用结构合并成一个GroupNorm算子,该算子把前面的BroadCast Binary和后续的SiLU激活函数,融合在一起。高效的只需一个Kernel就可以实现上述计算需求。



同样的图融合原理,在Transformer激活函数中,Stable Diffusion Feed-Forward模块中采用GEGLU结构,对应onnx图结构如下。将该8个onnx图算子,融合为通用的SplitGeLU算子。

c5251af434249b000dafca534c294940.png

  conv-winograd算法实现

在Stable Diffusion中有大量3x3卷积,在深度学习中,Winograd算法已经大量应用在加速3x3卷积实现。

Winograd F(m, r)算法,其中m代表一个计算tile的大小,r对应filter的尺寸,d=m+r-1 代表对应input tile大小。

4f6d41a9cea75a5c83330159ce22669b.png

下表是3x3 Winograd不同tile对应计算量的节省比例和中间内存占用的增大比例。

m

r

d

计算量前后比例

input中间内存

weight中间内存

2

3

4

9 : 4 = 2.25x

4x

1.78x

4

3

6

4 : 1 = 4x

2.25x

4x

6

3

8

81 : 16 = 5.06x

1.78x

7.11x

目前,我们使用的是F(2, 3) Winograd,控制内存增大量,同时带来一倍的性能提升效果。

  高性能Gemm/BatchGemm

上述分析可以看出,Attention/卷积3x3,核心计算量在BatchGemm上,Linear层实际上就是Gemm运算。实际上,Stable Diffusion中,核心的计算量或者说耗时的热点,归根溯源,都集中在Gemm/BatchGemm上。如何高效实现矩阵乘法 成为最核心的关键。

矩阵乘在各个维度上的分块策略,可以有效提升数据的复用度和数据cache命中率;合理的分块可以为矩阵乘法带来大幅度的性能提升。

a6c30dcdb2a2f826e6023ee302c4c2ab.png

上图展示了,矩阵乘在各个维度上面的分块变量,包括在并发M/N维度,单次数据访存向量化位宽、每个线程存取矩阵的尺寸、每个工作组存取矩阵的尺寸,以及如果使用local memory缓存的话每个线程/工作组的缓存量。

这些参量都决定了数据访存的效率、并发量的大小、计算访存比的大小。不同的设备有不同的寄存器资源、共享内存资源、访存带宽、计算核心数,这些参量都决定着矩阵乘法的性能效率。



对于特定的矩阵乘的尺寸M/N/K,针对特定设备采取Auto-Tuning的获取最佳的运行参数(OPWM/OPWN/OPTM/OPTN/VEC_M/VEC_N等),Tuning候选集数量是M的N次方(N是参数的个数、M是每个参数候选集个数)。如果暴力循环每个参数候选集,由于候选集数量巨大、并且大尺寸矩阵乘本身单次运行耗时较大,必然会导致要花费大量时间去Tuning完所有候选集。因此,根据经验和实际试跑,选出部分高频参数候选集进行Tuning,在控制好Tuning时间的同时,也可以带来极大的性能收益。

  Gemm Strassen探索

由于矩阵乘法是Stable Diffusion耗时的核心,因此进行了矩阵乘快速算法的研究探索。Strassen算法是利用矩阵拆解,通过引入矩阵加减法,来减少矩阵乘法次数的方式。最简单的方法,将M/N/K维度各对拆1/2的方法,朴素的矩阵拆解如下:

0b94ad48464fc4675377f555b66e0b0e.png

Strassen算法,通过15次子矩阵加减法,来减少一次子矩阵乘法。矩阵拆解如下:

46a06996401968f0e405de7c014f6a98.png

当N足够大时,矩阵加减法耗时会远低于矩阵乘法耗时,带来12.5%的计算量降低。当N较小时,受限于15次 子矩阵加减的 耗时,以及拆解子矩阵乘法算力打不满等损耗原因,将引起负优化。具体某个形状的矩阵乘法适不适合使用Strassen算法?



对于矩阵A形状为[M, K], 矩阵B形状为[N, K],输出矩阵C形状为[M, N]。15次子矩阵加减,数据访存量为:(3*M*K + 3*N*K + 3.5*M*N) * sizeof(DataType) Bytes。1次子矩阵乘法,数据计算量为:1/8 * M*N*K * 2 = 1/4 * M*N*K FLOPS。我们默认矩阵加减是带宽瓶颈,矩阵乘法是算力瓶颈。假设设备的内存带宽为X GB/s,算力是Y GFLOPS。

子矩阵加减耗时:(6*M*K + 6*N*K + 3.5*M*N)*sizeof(DataType) / X (ns)

子矩阵乘节省耗时:(1/4 * M * N * K) / Y (ns)



当节省的耗时大于损耗耗时,即可有性能收益。根据上述公式,计算访存比越低的设备,Strassen算法越容易有收益。对于手机设备来说,1024x1024x1024的子矩阵,通常可以获得约10%的性能收益。

  内存占用优化

在Attention优化中,Q/K做BatchMatMul得到中间数据QK时,张量维度为[Batch, HeadNum, SeqLen, SeqLen]。对于Stable Diffusion来说,会遇到Batch=2,HeadNum=16,SeqLen=4096。对于float16的数据类型,单个张量的存储就需要1GB的内存大小,这对于内存资源紧缺的端侧设备是不可接受的。

876e1c1eb1ae66fb68659378e93eced9.png

因此,将Attention操作进行分块处理,类似Paged Attention的思路,将整个Attention分成SeqNum次执行,这样每次仅需原先1/SeqNum中间内存大小,可以非常有效的控制内存的大小。

性能测评

MNN Stable Diffusion应用,生成512x512图片,在骁龙8Gen3上使用GPU float16精度达到2s/iter (20次迭代,手机上40s可以生成完一幅图),在Apple Mac M3上GPU float32精度达到1.1s/iter (20次迭代,Mac上22s可以生成完一幅图)。MNN CPU/GPU性能均较大幅度快于如下Stable Diffusion开源框架,例如:

  • stable-diffusion.cpp

    https://github.com/leejet/stable-diffusion.cpp/issues/15

  • Android OnnxRuntime Stable Diffusion应用

    https://github.com/ZTMIDGO/Android-Stable-diffusion-ONNX

9e9496859a04ea0ac02b2af87f69bd83.png

后续研究

后续在性能优化和内存优化上面仍然有空间可以挖掘。

性能优化方面:

  • Conv Winograd采用更大的分块,获取更高的计算量降低收益。

  • 矩阵乘尝试Image存储内存访问模式,提高访存效率。

  • Attention进一步采用Flash Attention等思路优化。

内存占用优化方面:

  • 采用低比特权重(int8/int4量化)。

  • 在线转换动态内存可复用,Conv Winograd权重尝试采用在线转换。

  • Attention 采用Flash Attention优化节省中间内存使用。

8c56100ad51dd1b6565cc7b815d90b01.png

参考资料

  • https://blog.csdn.net/xian0710830114/article/details/129194419

  • https://github.com/NVIDIA/TensorRT/tree/release/8.6/demo/Diffusion

  • https://arxiv.org/abs/0707.2347

  • https://courses.cs.cornell.edu/cs6810/2023fa/Matrix.pdf

  • https://github.com/CNugteren/CLBlast/tree/master

  • https://arxiv.org/pdf/1703.06503

  • https://github.com/leejet/stable-diffusion.cpp/

  • https://github.com/ZTMIDGO/Android-Stable-diffusion-ONNX

e2e86c44f7745eb2a3c1103c68e973db.png

团队介绍

我们是大淘宝技术Meta Team,负责面向消费场景的3D/XR基础技术建设和创新应用探索,通过技术和应用创新找到以手机及XR 新设备为载体的消费购物3D/XR新体验。团队在端智能、商品三维重建、3D引擎、XR引擎等方面有深厚的技术积累。团队在OSDI、MLSys、CVPR、ICCV、NeurIPS、TPAMI等顶级学术会议和期刊上发表多篇论文。

¤ 拓展阅读 ¤

3DXR技术 | 终端技术 | 音视频技术

服务端技术 | 技术质量 | 数据算法

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/895985.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

读者写者问题与读写锁

读者写者问题 读者写者 vs 生产消费 重点是有什么区别 读者写者问题如何理解 重点理解读者和写者如何完成同步 下面是一段伪代码:公共部分 uint32_t reader_count 0; lock_t count_lock; lock_t writer_lock; Reader // 加锁 lock(count_lock); if(reader_c…

Java | Leetcode Java题解之第492题构造矩形

题目: 题解: class Solution {public int[] constructRectangle(int area) {int w (int) Math.sqrt(area);while (area % w ! 0) {--w;}return new int[]{area / w, w};} }

5G物联网主机引领企业数字化转型

在当今这个信息化高度发展的时代,企业的竞争力很大程度上取决于其能否快速适应市场变化并高效地进行内部管理。郑州龙兴物联科技有限公司凭借其先进的5G物联网技术,推出了为企业量身定制的5G物联网主机,该设备充分利用其多协议、多接口的特点…

ESP32-C3 入门笔记04:gpio_key 按键 (ESP-IDF + VSCode)

1.GPIO简介 ESP32-C3是QFN32封装,GPIO引脚一共有22个,从GPIO0到GPIO21。 理论上,所有的IO都可以复用为任何外设功能,但有些引脚用作连接芯片内部FLASH或者外部FLASH功能时,官方不建议用作其它用途。 通过开发板的原…

【Vue】Vue3.0 (十二)、watchEffect 和watch的区别及使用

上篇文章: 【Vue】Vue3.0 (十二)、watch对ref定义的基本类型、对象类型;reactive定义的对象类型的监视使用 🏡作者主页:点击! 🤖Vue专栏:点击! ⏰️创作时间&…

数据仓库基础概念

数据仓库 概念 数据仓库(Data Warehouse, DW)是一个面向主题的、集成的、相对稳定的、反映历史变化的数据集合。它是为满足企业决策分析需求而设计的。 面向主题:数据仓库围绕特定的主题组织数据,例如“销售”或“人力资源”&am…

线上交友小程序源码系统 一元盲盒小程序在线开好友 带完整的安装代码包以及搭建部署教程

系统概述 线上交友小程序源码系统是基于先进的技术架构开发的一套完整的解决方案,旨在为用户提供一个便捷、有趣的线上交友平台。该系统通过一元盲盒的形式,让用户在未知中寻找惊喜,增加了交友的趣味性和神秘感。 该系统采用了先进的编程技…

UE5蓝图中忽略触发区域进行碰撞

Event Hit :只会在碰撞到实体的时候产生碰撞。如果是触发区域则会忽略。 Destroy Actor:销毁自身。

openrtp 音视频时间戳问题

解决音视频发送的rtp问题 openrtp增加了音频aac的发送,地址 OpenRTP Gitee开源地址 同时使用两个rtp ,来发送音频和视频 使用以下音频rtp,是可以发送和接收的,音频端口在视频端口上2 v0 o- 0 0 IN IP4 127.0.0.1 sMy Stream cI…

sentinel dashboard分布式改造落地设计实现解释(二)-分布式discovery组件

discovery discovery负责维护app/机器资料库,transport健康检测, transport上下线处理。discovery关键是分布式存储,后续研究一下raft,其复制,状态机,快照技术,但个人觉得,discover…

【网络安全】护网蓝队之应急响应

蓝队技术栈 Linux入侵排查 系统排查 一、查看历史命令 在Linux系统中,检查历史命令记录是安全审计的重要步骤之一,它可以帮助您了解系统上用户(包括潜在的黑客)的活动。以下是对您描述的重新表述和补充: 检查历史命…

webpack自定义插件 ChangeScriptSrcPlugin

插件文件 class ChangeScriptSrcPlugin {apply(compiler) {const pluginName "ChangeScriptSrcPlugin";compiler.hooks.compilation.tap(pluginName, (compilation, callback) > {compilation.hooks.htmlWebpackPluginAlterAssetTags.tapAsync(pluginName,(html…

LabVIEW提高开发效率技巧----节省内存

在LabVIEW开发过程中,内存管理是保障程序稳定性和性能的关键。本文将详细介绍如何通过队列处理来节省内存,尤其是如何通过解耦释放不再需要的数据,防止内存泄漏。通过多个实际例子,从不同角度探讨队列处理在大数据量或长时间运行的…

使用Airtest自动化某云音乐爬取歌曲名称

简介 本文将介绍如何使用Airtest自动化工具来模拟用户操作,从某云音乐中爬取与特定关键词相关的歌曲名称。我们将以搜索“文字”相关的歌曲为例,并将结果保存到本地文件。 准备工作 安装Airtest并配置好Android设备或模拟器。确保你的设备上已安装某云…

C0027.在Clion中解决CPU和内存过高的问题

解决办法 最新版的 clion 在 advance setting里,可以勾选 Use the Resharper C language engine (CLion Nova)。 有显著的性能提升。

深入探索JavaCV:功能强大的Java计算机视觉库

🧑 博主简介:历代文学网(PC端可以访问:https://literature.sinhy.com/#/literature?__c1000,移动端可微信小程序搜索“历代文学”)总架构师,15年工作经验,精通Java编程,…

积木报表静态资源不生效,界面乱码 解决方法

目录 前言1. 问题所示2. 原理分析3. 解决方法前言 从实战中分析问题,解决问题,以下笔记学习为主 关于JimuReport的网站:文档中心 1. 问题所示 引入积木报表之后,界面静态文件不生效,最终截图如下: 大致浏览器终端报错如下: 基本信息如下: Uncaught SyntaxError: U…

项目管理的坎坷之路与 MBTI 的启示录

项目管理这一路走来,经历了无数的坎坷、不顺和阻碍。幸运的是,遇见 MBTI 之后,我仿佛看到了新的希望,终于我也看到了花团锦簇,也看到了灯彩佳话。那一夜,我也曾梦见百万雄兵。 什么是 MBTI ? M…

AI大模型学习路线路径,巨详细!

大模型技术已经成为推动人工智能发展的关键力量。无论你是初学者还是有经验的开发者,想要掌握大模型应用,都需要遵循一定的学习路线。 从核心技术解析到模型微调与私有化部署,逐步深入大模型应用的世界。 这份学习路线图详细的介绍了那年每…

gitee建立/取消关联仓库

目录 一、常用指令总结 二、建立关联具体操作 三、取消关联具体操作 一、常用指令总结 首先要选中要关联的文件,右击,选择Git Bash Here。 git remote -v //查看自己的文件有几个关联的仓库git init //初始化文件夹为git可远程建立链接的文件夹…