深入浅出 diffusion(3):pytorch 实现 diffusion 中的 U-Net

导入python包

import math

import torch
import torch.nn as nn
import torch.nn.functional as F

 silu激活函数

class SiLU(nn.Module):  
    # SiLU激活函数
    @staticmethod
    def forward(x):
        return x * torch.sigmoid(x)

归一化设置

def get_norm(norm, num_channels, num_groups):
    if norm == "in":
        return nn.InstanceNorm2d(num_channels, affine=True)
    elif norm == "bn":
        return nn.BatchNorm2d(num_channels)
    elif norm == "gn":
        return nn.GroupNorm(num_groups, num_channels)
    elif norm is None:
        return nn.Identity()
    else:
        raise ValueError("unknown normalization type")

 计算时间步长的位置嵌入,一半为sin,一半为cos

class PositionalEmbedding(nn.Module):
    def __init__(self, dim, scale=1.0):
        super().__init__()
        assert dim % 2 == 0
        self.dim = dim
        self.scale = scale

    def forward(self, x):
        device      = x.device
        half_dim    = self.dim // 2
        emb = math.log(10000) / half_dim
        emb = torch.exp(torch.arange(half_dim, device=device) * -emb)
        # x * self.scale和emb外积
        emb = torch.outer(x * self.scale, emb)
        emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
        return emb

 上下采样层设置

class Downsample(nn.Module):
    def __init__(self, in_channels):
        super().__init__()

        self.downsample = nn.Conv2d(in_channels, in_channels, 3, stride=2, padding=1)
    
    def forward(self, x, time_emb, y):
        if x.shape[2] % 2 == 1:
            raise ValueError("downsampling tensor height should be even")
        if x.shape[3] % 2 == 1:
            raise ValueError("downsampling tensor width should be even")

        return self.downsample(x)

class Upsample(nn.Module):
    def __init__(self, in_channels):
        super().__init__()
        self.upsample = nn.Sequential(
            nn.Upsample(scale_factor=2, mode="nearest"),
            nn.Conv2d(in_channels, in_channels, 3, padding=1),
        )
        
    def forward(self, x, time_emb, y):
        return self.upsample(x)

 使用Self-Attention注意力机制,做一个全局的Self-Attention

class AttentionBlock(nn.Module):
    def __init__(self, in_channels, norm="gn", num_groups=32):
        super().__init__()
        
        self.in_channels = in_channels
        self.norm = get_norm(norm, in_channels, num_groups)
        self.to_qkv = nn.Conv2d(in_channels, in_channels * 3, 1)
        self.to_out = nn.Conv2d(in_channels, in_channels, 1)

    def forward(self, x):
        b, c, h, w  = x.shape
        q, k, v     = torch.split(self.to_qkv(self.norm(x)), self.in_channels, dim=1)

        q = q.permute(0, 2, 3, 1).view(b, h * w, c)
        k = k.view(b, c, h * w)
        v = v.permute(0, 2, 3, 1).view(b, h * w, c)

        dot_products = torch.bmm(q, k) * (c ** (-0.5))
        assert dot_products.shape == (b, h * w, h * w)

        attention   = torch.softmax(dot_products, dim=-1)
        out         = torch.bmm(attention, v)
        assert out.shape == (b, h * w, c)
        out         = out.view(b, h, w, c).permute(0, 3, 1, 2)

        return self.to_out(out) + x

 用于特征提取的残差结构

class ResidualBlock(nn.Module):
    def __init__(
        self, in_channels, out_channels, dropout, time_emb_dim=None, num_classes=None, activation=F.relu,
        norm="gn", num_groups=32, use_attention=False,
    ):
        super().__init__()

        self.activation = activation

        self.norm_1 = get_norm(norm, in_channels, num_groups)
        self.conv_1 = nn.Conv2d(in_channels, out_channels, 3, padding=1)

        self.norm_2 = get_norm(norm, out_channels, num_groups)
        self.conv_2 = nn.Sequential(
            nn.Dropout(p=dropout), 
            nn.Conv2d(out_channels, out_channels, 3, padding=1),
        )

        self.time_bias  = nn.Linear(time_emb_dim, out_channels) if time_emb_dim is not None else None
        self.class_bias = nn.Embedding(num_classes, out_channels) if num_classes is not None else None

        self.residual_connection    = nn.Conv2d(in_channels, out_channels, 1) if in_channels != out_channels else nn.Identity()
        self.attention              = nn.Identity() if not use_attention else AttentionBlock(out_channels, norm, num_groups)
    
    def forward(self, x, time_emb=None, y=None):
        out = self.activation(self.norm_1(x))
        # 第一个卷积
        out = self.conv_1(out)
        
        # 对时间time_emb做一个全连接,施加在通道上
        if self.time_bias is not None:
            if time_emb is None:
                raise ValueError("time conditioning was specified but time_emb is not passed")
            out += self.time_bias(self.activation(time_emb))[:, :, None, None]

        # 对种类y_emb做一个全连接,施加在通道上
        if self.class_bias is not None:
            if y is None:
                raise ValueError("class conditioning was specified but y is not passed")

            out += self.class_bias(y)[:, :, None, None]

        out = self.activation(self.norm_2(out))
        # 第二个卷积+残差边
        out = self.conv_2(out) + self.residual_connection(x)
        # 最后做个Attention
        out = self.attention(out)
        return out

 U-Net模型设计

class UNet(nn.Module):
    def __init__(
        self, img_channels, base_channels=128, channel_mults=(1, 2, 2, 2),
        num_res_blocks=2, time_emb_dim=128 * 4, time_emb_scale=1.0, num_classes=None, activation=F.silu,
        dropout=0.1, attention_resolutions=(1,), norm="gn", num_groups=32, initial_pad=0,
    ):
        super().__init__()
        # 使用到的激活函数,一般为SILU
        self.activation = activation
        # 是否对输入进行padding
        self.initial_pad = initial_pad
        # 需要去区分的类别数
        self.num_classes = num_classes
        
        # 对时间轴输入的全连接层
        self.time_mlp = nn.Sequential(
            PositionalEmbedding(base_channels, time_emb_scale),
            nn.Linear(base_channels, time_emb_dim),
            nn.SiLU(),
            nn.Linear(time_emb_dim, time_emb_dim),
        ) if time_emb_dim is not None else None
    
        # 对输入图片的第一个卷积
        self.init_conv  = nn.Conv2d(img_channels, base_channels, 3, padding=1)

        # self.downs用于存储下采样用到的层,首先利用ResidualBlock提取特征
        # 然后利用Downsample降低特征图的高宽
        self.downs      = nn.ModuleList()
        self.ups        = nn.ModuleList()
        
        # channels指的是每一个模块处理后的通道数
        # now_channels是一个中间变量,代表中间的通道数
        channels        = [base_channels]
        now_channels    = base_channels
        for i, mult in enumerate(channel_mults):
            out_channels = base_channels * mult
            for _ in range(num_res_blocks):
                self.downs.append(
                    ResidualBlock(
                        now_channels, out_channels, dropout,
                        time_emb_dim=time_emb_dim, num_classes=num_classes, activation=activation,
                        norm=norm, num_groups=num_groups, use_attention=i in attention_resolutions,
                    )
                )
                now_channels = out_channels
                channels.append(now_channels)
            
            if i != len(channel_mults) - 1:
                self.downs.append(Downsample(now_channels))
                channels.append(now_channels)

        # 可以看作是特征整合,中间的一个特征提取模块
        self.mid = nn.ModuleList(
            [
                ResidualBlock(
                    now_channels, now_channels, dropout,
                    time_emb_dim=time_emb_dim, num_classes=num_classes, activation=activation,
                    norm=norm, num_groups=num_groups, use_attention=True,
                ),
                ResidualBlock(
                    now_channels, now_channels, dropout,
                    time_emb_dim=time_emb_dim, num_classes=num_classes, activation=activation, 
                    norm=norm, num_groups=num_groups, use_attention=False,
                ),
            ]
        )

        # 进行上采样,进行特征融合
        for i, mult in reversed(list(enumerate(channel_mults))):
            out_channels = base_channels * mult

            for _ in range(num_res_blocks + 1):
                self.ups.append(ResidualBlock(
                    channels.pop() + now_channels, out_channels, dropout, 
                    time_emb_dim=time_emb_dim, num_classes=num_classes, activation=activation, 
                    norm=norm, num_groups=num_groups, use_attention=i in attention_resolutions,
                ))
                now_channels = out_channels
            
            if i != 0:
                self.ups.append(Upsample(now_channels))
        
        assert len(channels) == 0
        
        self.out_norm = get_norm(norm, base_channels, num_groups)
        self.out_conv = nn.Conv2d(base_channels, img_channels, 3, padding=1)
    
    def forward(self, x, time=None, y=None):
        # 是否对输入进行padding
        ip = self.initial_pad
        if ip != 0:
            x = F.pad(x, (ip,) * 4)

        # 对时间轴输入的全连接层
        if self.time_mlp is not None:
            if time is None:
                raise ValueError("time conditioning was specified but tim is not passed")
            time_emb = self.time_mlp(time)
        else:
            time_emb = None
        
        if self.num_classes is not None and y is None:
            raise ValueError("class conditioning was specified but y is not passed")
        
        # 对输入图片的第一个卷积
        x = self.init_conv(x)

        # skips用于存放下采样的中间层
        skips = [x]
        for layer in self.downs:
            x = layer(x, time_emb, y)
            skips.append(x)
        
        # 特征整合与提取
        for layer in self.mid:
            x = layer(x, time_emb, y)
        
        # 上采样并进行特征融合
        for layer in self.ups:
            if isinstance(layer, ResidualBlock):
                x = torch.cat([x, skips.pop()], dim=1)
            x = layer(x, time_emb, y)

        # 上采样并进行特征融合
        x = self.activation(self.out_norm(x))
        x = self.out_conv(x)
        
        if self.initial_pad != 0:
            return x[:, :, ip:-ip, ip:-ip]
        else:
            return x

参考链接:GitCode - 开发者的代码家园icon-default.png?t=N7T8https://gitcode.com/bubbliiiing/ddpm-pytorch/tree/master?utm_source=csdn_github_accelerator&isLogin=1

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/350999.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

Java集合-Map接口(key-value)

Map接口的特点:①KV键值对方式存储②Key键唯一,Value允许重复③无序。 Map有四个实现类:1.HashMap类2.LinkedHashMap类3.TreeMap类4.Hashtable类 1.HashMap类: 存储结构:哈希表 数组Node[ ] 链表(红黑…

暴力破解

暴力破解工具使用汇总 1.查看密码加密方式 在线网站:https://cmd5.com/ http://www.158566.com/ https://encode.chahuo.com/kali:hash-identifier2.hydra 用于各种服务的账号密码爆破:FTP/Mysql/SSH/RDP...常用参数 -l name 指定破解登录…

BUUCTF--xor1

这题考察的是亦或。查壳: 无壳。看下IDA的流程: 我们看到将用户输入做一个异或操作,然后和一个变量做比较。如果相同则输出Success。这里的知识点就是两次异或会输出原文。因此我们只需要把global再做一次异或就能解出flag。在IDA中按住shift…

Excel 2019 for Mac/Win:商务数据分析与处理的终极工具

在当今快节奏的商业环境中,数据分析已经成为一项至关重要的技能。从市场趋势预测到财务报告,再到项目管理,数据无处不在。而作为数据分析的基石,Microsoft Excel 2019 for Mac/Win正是一个强大的工具,帮助用户高效地处…

鸿蒙HarmonyOS获取GPS精确位置信息

参考官方文档 #1.初始化时获取经纬度信息 aboutToAppear() {this.getLocation() } async getLocation () {try {const result await geoLocationManager.getCurrentLocation()AlertDialog.show({message: JSON.stringify(result)})}catch (error) {AlertDialog.show({message…

RabbitMQ 笔记二

1.Spring 整合RabbitMQ 生产者消费者 创建生产者工程添加依赖配置整合编写代码发送消息 创建消费者工程添加依赖配置整合编写消息监听器 2.创建工程RabbitMQ Producers spring-rabbitmq-producers <?xml version"1.0" encoding"UTF-8"?> <pr…

[TCP协议]基于TCP协议的字典服务器

目录 1.TCP协议简介: 2.TCP协议在Java中封装的类以及方法 3.字典服务器 3.1服务器代码: 3.2客户端代码: 1.TCP协议简介: TCP协议是一种有连接,面向字节流,全双工,可靠的网络通信协议.它相对于UDP协议来说有以下几点好处: 1.它是可靠传输,相比于UDP协议,传输的数据更加可靠…

两个近期的计算机领域国际学术会议(软件工程、计算机安全):欢迎投稿

近期&#xff0c;受邀担任两个国际学术会议的Special session共同主席及程序委员会成员&#xff08;TPC member&#xff09;&#xff0c;欢迎广大学界同行踊跃投稿&#xff0c;分享最新研究成果。期待这个夏天能够在夏威夷檀香山或者加利福尼亚圣荷西与各位学者深入交流。 SERA…

深度学习-搭建Colab环境

Google Colab(Colaboratory) 是一个免费的云端环境&#xff0c;旨在帮助开发者和研究人员轻松进行机器学习和数据科学工作。它提供了许多优势&#xff0c;使得编写、执行和共享代码变得更加简单和高效。Colab 在云端提供了预配置的环境&#xff0c;可以直接开始编写代码&#x…

React中使用LazyBuilder实现页面懒加载方法二

前言&#xff1a; 在一个表格中&#xff0c;需要展示100条数据&#xff0c;当每条数据里面需要承载的内容很多&#xff0c;需要渲染的元素也很多的时候&#xff0c;容易造成页面加载的速度很慢&#xff0c;不能给用户提供很好的体验时&#xff0c;懒加载是优化页面加载速度的方…

【网络协议测试】畸形数据包——圣诞树攻击(DOS攻击)

简介 TCP所有标志位被设置为1的数据包被称为圣诞树数据包&#xff08;XMas Tree packet&#xff09;&#xff0c;之所以叫这个名是因为这些标志位就像圣诞树上灯一样全部被点亮。 标志位介绍 TCP报文格式&#xff1a; 控制标志&#xff08;Control Bits&#xff09;共6个bi…

【虚拟机数据恢复】异常断电导致虚拟机无法启动的数据恢复案例

虚拟机数据恢复环境&#xff1a; 某品牌R710服务器MD3200存储&#xff0c;上层是ESXI虚拟机和虚拟机文件&#xff0c;虚拟机中存放有SQL Server数据库。 虚拟机故障&#xff1a; 机房非正常断电导致虚拟机无法启动。服务器管理员检查后发现虚拟机配置文件丢失&#xff0c;所幸…

JPDA框架和JDWP协议

前言 在逆向开发中,一般都需要对目标App进行代码注入。主流的代码注入工具是Frida,这个工具能稳定高效实现java代码hook和native代码hook,不过缺点是需要使用Root设备,而且用js开发,入门门槛较高。最近发现一种非Root环境下对Debug App进行代码注入的方案,原理是利用Jav…

Unity MonoBehaviour 生成dll

dllllllllllllll&#x1f953; &#x1f959;vs创建类库项目&#x1f9c0;添加UnityEngine、UnityEditor引用&#x1f355;添加MonoBehaviour类&#x1f9aa;设置dll生成路径&#x1f37f;生成dll&#x1f354;使用dll中的Mono类 &#x1f959;vs创建类库项目 &#x1f9c0;添加…

qiankun子应用静态资源404问题有效解决(涉及 css文件引用图片、svg图片无法转换成 base64等问题)

在&#x1f449;&#x1f3fb; qiankun微前端部署&#x1f448;&#x1f3fb;这个部署方式的前提下&#xff0c;遇到的问题并解决问题的过程 最开始的问题现象 通过http请求本地的静态json文件404css中部分引入的图片无法显示 最开始的解决方式 在&#x1f449;&#x1f3…

YOLO系列(YOLO1-YOLO5)技术规格、应用场景、特点及性能对比分析

文章目录 前言一、YOLOv1-YOLOv5技术规格对比&#xff1a;二、主要应用场景和特点&#xff1a;三、性能对比分析&#xff1a;四、市场应用前景及对不同用户群体的潜在影响&#xff1a;总结 前言 YOLO&#xff08;You Only Look Once&#xff09;系列模型作为一种实时目标检测算…

OpenAI 降低价格并修复拒绝工作的“懒惰”GPT-4,另外ChatGPT 新增了两个小功能

OpenAI降低了GPT-3.5 Turbo模型的API访问价格&#xff0c;输入和输出价格分别降低了50%和25%。这对于使用API进行文本密集型应用程序的用户来说是一个好消息。 OpenAI官网&#xff1a;OpenAI AIGC专区&#xff1a;aigc 教程专区&#xff1a;AI绘画&#xff0c;AI视频&#x…

npm,cnpm install报:Error: certificate has expired at TLSSocket.onConnectSecure

问题描述 最近发现前端项目 CI/CD 时失败&#xff0c;报下面的错误。npm淘宝镜像源证书过期导致的。 [npminstall:get] retry GET https://registry.npm.taobao.org/vue-router after 400ms, retry left 1, error: ResponseError: certificate has expired, GET https://reg…

【Unity小技巧】一个脚本实现控制3D远程/近战敌人AI

最终效果 文章目录 最终效果烘培导航地图配置敌人导航数据简单配置敌人动画敌人AI脚本完结 想了解导航的其他内容可以看我这篇文章&#xff1a;【Unity游戏开发教程】零基础带你从小白到超神29——导航系统 烘培导航地图 选中地面&#xff0c;设置为静态导航 点击烘培&#xf…

《动手学深度学习(PyTorch版)》笔记4.4

注&#xff1a;书中对代码的讲解并不详细&#xff0c;本文对很多细节做了详细注释。另外&#xff0c;书上的源代码是在Jupyter Notebook上运行的&#xff0c;较为分散&#xff0c;本文将代码集中起来&#xff0c;并加以完善&#xff0c;全部用vscode在python 3.9.18下测试通过。…