代码解读 | Hybrid Transformers for Music Source Separation[05]

一、背景

        0、Hybrid Transformer 论文解读

        1、代码复现|Demucs Music Source Separation_demucs架构原理-CSDN博客

        2、Hybrid Transformer 各个模块对应的代码具体在工程的哪个地方

        3、Hybrid Transformer 各个模块的底层到底是个啥(初步感受)?

        4、Hybrid Transformer 各个模块处理后,数据的维度大小是咋变换的?

        5、Hybrid Transformer 拆解STFT模块


        从模块上划分,Hybrid Transformer Demucs 共包含 (STFT模块、时域编码模块、频域编码模块、Cross-Domain Transformer Encoder模块、时域解码模块、频域解码模块、ISTFT模块)7个模块。

        本篇目标:拆解频域编码模块的底层

        时域编码和频域编码原理类似(后续不再拆解时域编码模块)。

二、频域编码模块


class HEncLayer(nn.Module):
    def __init__(self, chin, chout, kernel_size=8, stride=4, norm_groups=1, empty=False,
                 freq=True, dconv=True, norm=True, context=0, dconv_kw={}, pad=True,
                 rewrite=True):
        """Encoder layer. This used both by the time and the frequency branch.

        Args:
            chin: number of input channels.
            chout: number of output channels.
            norm_groups: number of groups for group norm.
            empty: used to make a layer with just the first conv. this is used
                before merging the time and freq. branches.
            freq: this is acting on frequencies.
            dconv: insert DConv residual branches.
            norm: use GroupNorm.
            context: context size for the 1x1 conv.
            dconv_kw: list of kwargs for the DConv class.
            pad: pad the input. Padding is done so that the output size is
                always the input size / stride.
            rewrite: add 1x1 conv at the end of the layer.
        """
        super().__init__()
        norm_fn = lambda d: nn.Identity()  # noqa
        if norm:
            norm_fn = lambda d: nn.GroupNorm(norm_groups, d)  # noqa
        if pad:
            pad = kernel_size // 4
        else:
            pad = 0
        klass = nn.Conv1d
        self.freq = freq
        self.kernel_size = kernel_size
        self.stride = stride
        self.empty = empty
        self.norm = norm
        self.pad = pad
        if freq:
            kernel_size = [kernel_size, 1]
            stride = [stride, 1]
            pad = [pad, 0]
            klass = nn.Conv2d
        self.conv = klass(chin, chout, kernel_size, stride, pad)
        if self.empty:
            return
        self.norm1 = norm_fn(chout)
        self.rewrite = None
        if rewrite:
            self.rewrite = klass(chout, 2 * chout, 1 + 2 * context, 1, context)
            self.norm2 = norm_fn(2 * chout)

        self.dconv = None
        if dconv:
            self.dconv = DConv(chout, **dconv_kw)

    def forward(self, x, inject=None):
        """
        `inject` is used to inject the result from the time branch into the frequency branch,
        when both have the same stride.
        """
        if not self.freq and x.dim() == 4:
            B, C, Fr, T = x.shape
            x = x.view(B, -1, T)

        if not self.freq:
            le = x.shape[-1]
            if not le % self.stride == 0:
                x = F.pad(x, (0, self.stride - (le % self.stride)))
        y = self.conv(x)
        if self.empty:
            return y
        if inject is not None:
            assert inject.shape[-1] == y.shape[-1], (inject.shape, y.shape)
            if inject.dim() == 3 and y.dim() == 4:
                inject = inject[:, :, None]
            y = y + inject
        y = F.gelu(self.norm1(y))
        if self.dconv:
            if self.freq:
                B, C, Fr, T = y.shape
                y = y.permute(0, 2, 1, 3).reshape(-1, C, T)
            y = self.dconv(y)
            if self.freq:
                y = y.view(B, Fr, C, T).permute(0, 2, 1, 3)
        if self.rewrite:
            z = self.norm2(self.rewrite(y))
            z = F.glu(z, dim=1)
        else:
            z = y
        return z

        核心代码如上所示。

        使用print函数打印出各个关键节点的信息,可以得到频域编解码模块的全景图。

        编码层:Conv2d+Norm1+GELU,  Norm1:Identity()

        残差连接:(Conv1d+GroupNorm+GELU +Conv1d+GroupNorm+GLU+LayerScale())

        +(Conv2d+Norm2+GLU),Norm2:Identity() ,备注:Identity可以理解成直通

#上图均是自己读完代码绘制的。相信自己也可以。
#具体的,编码层1-4的Conv2d分别是:
Conv2d(4, 48, kernel_size=(8, 1), stride=(4, 1), padding=(2, 0))
Conv2d(48, 96, kernel_size=(8, 1), stride=(4, 1), padding=(2, 0))
Conv2d(96, 192, kernel_size=(8, 1), stride=(4, 1), padding=(2, 0))
Conv2d(192, 384, kernel_size=(8, 1), stride=(4, 1), padding=(2, 0))
#残差连接1
DConv(
  (layers): ModuleList(
    (0): Sequential(
      (0): Conv1d(48, 6, kernel_size=(3,), stride=(1,), padding=(1,))
      (1): GroupNorm(1, 6, eps=1e-05, affine=True)
      (2): GELU(approximate=none)
      (3): Conv1d(6, 96, kernel_size=(1,), stride=(1,))
      (4): GroupNorm(1, 96, eps=1e-05, affine=True)
      (5): GLU(dim=1)
      (6): LayerScale()
    )
    (1): Sequential(
      (0): Conv1d(48, 6, kernel_size=(3,), stride=(1,), padding=(2,), dilation=(2,))
      (1): GroupNorm(1, 6, eps=1e-05, affine=True)
      (2): GELU(approximate=none)
      (3): Conv1d(6, 96, kernel_size=(1,), stride=(1,))
      (4): GroupNorm(1, 96, eps=1e-05, affine=True)
      (5): GLU(dim=1)
      (6): LayerScale()
    )
  )
)
Conv2d(48, 96, kernel_size=(1, 1), stride=(1, 1))

#残差连接2
DConv(
  (layers): ModuleList(
    (0): Sequential(
      (0): Conv1d(96, 12, kernel_size=(3,), stride=(1,), padding=(1,))
      (1): GroupNorm(1, 12, eps=1e-05, affine=True)
      (2): GELU(approximate=none)
      (3): Conv1d(12, 192, kernel_size=(1,), stride=(1,))
      (4): GroupNorm(1, 192, eps=1e-05, affine=True)
      (5): GLU(dim=1)
      (6): LayerScale()
    )
    (1): Sequential(
      (0): Conv1d(96, 12, kernel_size=(3,), stride=(1,), padding=(2,), dilation=(2,))
      (1): GroupNorm(1, 12, eps=1e-05, affine=True)
      (2): GELU(approximate=none)
      (3): Conv1d(12, 192, kernel_size=(1,), stride=(1,))
      (4): GroupNorm(1, 192, eps=1e-05, affine=True)
      (5): GLU(dim=1)
      (6): LayerScale()
    )
  )
)
Conv2d(96, 192, kernel_size=(1, 1), stride=(1, 1))

#残差连接3
DConv(
  (layers): ModuleList(
    (0): Sequential(
      (0): Conv1d(192, 24, kernel_size=(3,), stride=(1,), padding=(1,))
      (1): GroupNorm(1, 24, eps=1e-05, affine=True)
      (2): GELU(approximate=none)
      (3): Conv1d(24, 384, kernel_size=(1,), stride=(1,))
      (4): GroupNorm(1, 384, eps=1e-05, affine=True)
      (5): GLU(dim=1)
      (6): LayerScale()
    )
    (1): Sequential(
      (0): Conv1d(192, 24, kernel_size=(3,), stride=(1,), padding=(2,), dilation=(2,))
      (1): GroupNorm(1, 24, eps=1e-05, affine=True)
      (2): GELU(approximate=none)
      (3): Conv1d(24, 384, kernel_size=(1,), stride=(1,))
      (4): GroupNorm(1, 384, eps=1e-05, affine=True)
      (5): GLU(dim=1)
      (6): LayerScale()
    )
  )
)
Conv2d(192, 384, kernel_size=(1, 1), stride=(1, 1))

#残差连接4
DConv(
  (layers): ModuleList(
    (0): Sequential(
      (0): Conv1d(384, 48, kernel_size=(3,), stride=(1,), padding=(1,))
      (1): GroupNorm(1, 48, eps=1e-05, affine=True)
      (2): GELU(approximate=none)
      (3): Conv1d(48, 768, kernel_size=(1,), stride=(1,))
      (4): GroupNorm(1, 768, eps=1e-05, affine=True)
      (5): GLU(dim=1)
      (6): LayerScale()
    )
    (1): Sequential(
      (0): Conv1d(384, 48, kernel_size=(3,), stride=(1,), padding=(2,), dilation=(2,))
      (1): GroupNorm(1, 48, eps=1e-05, affine=True)
      (2): GELU(approximate=none)
      (3): Conv1d(48, 768, kernel_size=(1,), stride=(1,))
      (4): GroupNorm(1, 768, eps=1e-05, affine=True)
      (5): GLU(dim=1)
      (6): LayerScale()
    )
  )
)
Conv2d(384, 768, kernel_size=(1, 1), stride=(1, 1))

        关于,各个卷积模块输出数据的shape计算,可以读这篇文章。

        没有所谓天生的大佬,如果有那么我愿称他/她为圣人。我相信,能读到这儿的都会成为大佬~。Believe yourself,one day,you will be somebody.


         感谢阅读,最近开始写公众号(分享好用的AI工具),欢迎大家一起见证我的成长(桂圆学AI)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/705616.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

Tomcat部署以及虚拟主机

概论 Tomcat 是 Java 语言开发的,Tomcat 服务器是一个免费的开放源代码的 Web 应用服务器,是 Apache 软件基金会的 Jakarta 项目中的一个核心项目,由 Apache、Sun 和其他一些公司及个人共同开发而成。 Tomcat的组成 Tomcat 由一系列的组件构…

黑苹果睡眠总是自动唤醒(RTC)

黑苹果睡眠总是自动唤醒【RTC】 1. 问题2. 解决方案2.1. 查看重启日志2.2. 配置Disable RTC wake scheduling补丁 3. 后续4. 参考 1. 问题 黑苹果EFI 更换后,总是在手动 睡眠后,间歇性重启,然后再次睡眠,然后再重启。原因归结为&…

界面控件DevExpress WinForms垂直属性网格组件 - 拥有更灵活的UI选择(一)

DevExpress WinForms垂直&属性网格组件旨在提供UI灵活性,它允许用户显示数据集中的单个行或在其90度倒置网格容器中显示多行数据集。另外,用户可以把它用作一个属性网格,就像在Visual Studio IDE中那样。 P.S:DevExpress Win…

【软件测试】遇到bug怎么分析,这篇文章值得一看

为什么定位问题如此重要? 可以明确一个问题是不是真的“bug” 很多时候,我们找到了问题的原因,结果发现这根本不是bug。原因明确,误报就会降低 多个系统交互,可以明确指出是哪个系统的缺陷,防止“踢皮球…

OneNet创建产品和设备

onenet平台网址 https://open.iot.10086.cn/console/device/manage/devs?pidn5Yw89el5t 产品创建二号设备创建在下文中具有详细讲解 选择设备管理后,点击蓝色的添加设备按钮来添加设备 点击添加设备后,进入如下界面。设备所属产品和设备名称如下图设置…

RK3568技术笔记 Ubuntu 安装VMware Tools

安装 VMware Tools 后可以直接使用复制粘贴功能拷贝 Ubuntu 系统和 windows 主机内的文件,非常方便。 开启虚拟机,必须要进入ubuntu系统后才能进行下面的步骤。 单击 VMware 软件中的标签“虚拟机”,在下拉的菜单中单击“安装VMware Tools &…

技术革新,智绘未来丨悦数图数据库 v5.0 重磅亮相 WAIC 2024

本次 WAIC(世界人工智能大会)2024 将于7 月 4 日- 7 日在上海世博展览馆**举行,本次 WAIC 2024 围绕“以共商促共享 以善治促善智”为主题,杭州悦数科技有限公司将携最新的悦数图数据库 v5.0 亮相 E805 展位。作为国内领先的图数据…

使用GPT/文心实现诗词作画

在教育领域中,古诗词一直是培养学生文化素养和审美能力的重要载体。选择合适的古诗词进行学习和欣赏,不仅能够增强他们的语言表达能力,还能促进他们对中国传统文化的理解和热爱。本文将结合AI技术,将古诗词转换为图画。 1、选择适…

WWDC 2024 回顾:Apple Intelligence 的发布与解析

一年一度的苹果全球开发者大会(WWDC)如期而至,2024 年的 WWDC 再次成为科技界的焦点。本次发布会中,苹果正式推出了他们在 AI 领域的全新战略——Apple Intelligence。这一全新概念旨在为用户打造“强大、易用、全面、个性化、注重…

setOptMode -holdTargetSlack与-holdSlackFixingThreshod

我正在「拾陆楼」和朋友们讨论有趣的话题,你⼀起来吧? 拾陆楼知识星球入口 -holdTargetSlack与-holdSlackFixingThreshod这两个option都是针对hold slack的,前者限制slack的目标,默认是0,也就是说工具尽可能会收敛时序…

查分易怎么生成二维码

现在,家长和学生对于成绩查询的需求不断增长。教给各位新手教师一个简单又高效的查询工具——查分易小程序。可以为繁杂的工作做减法,也让学生和家长随时查看自己的学习情况。 查分易因为安全、便捷、高效,成为了众多学校和老师的首选。能够快…

【云服务器介绍】选择指南 腾讯云 阿里云全配置对比 搭建web 个人开发 app 游戏服务器

​省流目录:适用于博客建站(2-4G)、个人开发/小型游戏[传奇/我的世界/饥荒](4-8G)、数据分析/大型游戏[幻兽帕鲁/雾锁王国]服务器(16-64G) 1.京东云-618专属活动 官方采购季专属活动地址&#x…

Python写UI自动化--playwright(元素定位)

本篇详细分享playwright如何进行打断点、元素定位、填写输入框、点击等操作 目录 一、PyCharm打断点进行调试 二、浏览器开发者模式检查元素 三、通过CSS或XPath进行定位 四、输入框输入文本操作 五、点击操作 总结 一、PyCharm打断点进行调试 如图所示,我们…

深度学习之激活函数

激活函数(Activation Function)是一种添加到人工神经网络中的函数,旨在帮助网络学习数据中的复杂模式。在神经元中,输入的input经过一系列加权求和后作用于另一个函数,这个函数就是这里的激活函数。 1. 为什么需要激活…

雷神电脑怎么找文件所在位置?四个方法让你轻松上手

在数字化时代,电脑文件的管理与存储显得尤为重要。对于使用雷神电脑的用户而言,了解如何快速定位文件所在位置,以及在文件丢失时采取有效的应对措施,是提升工作效率、保障数据安全的关键。本文将围绕这两个核心问题展开&#xff0…

怎么更快捷的修改图片大小?压缩图片jpg、png、gif的快捷方法

jpg作为最常用的一种图片格式,在遇到图片太大问题时,该如何操作能够快速在压缩图片jpg的大小呢?图片太大无法上传时目前常见的一个使用问题,只有将图片处理到合适的大小才可以正常在平台上传使用,一般情况下想要快速解…

Android帧绘制流程深度解析 (二)

书接上回:Android帧绘制流程深度解析 (一) 5、 dispatchVsync: 在请求Vsync以后,choreographer会等待Vsync的到来,在Vsync信号到来后,会触发dispatchVsync函数,从而调用onVsync方法…

自监督分类网络:创新的端到端学习方法

现代人工智能的快速发展中,分类任务的高效解决方案一直备受关注。今天,我们向大家介绍一种名为Self-Classifier的全新自监督端到端分类学习方法。由Elad Amrani、Leonid Karlinsky和Alex Bronstein团队开发,Self-Classifier通过优化同一样本的…

Web应用安全测试-认证功能缺陷

Web应用安全测试-认证功能缺陷 存在空口令 漏洞描述:认证登录环节允许空口令 测试方法: 找到网站登录页面,尝试输入用用户名,密码为空进行登录。 风险分析:攻击者可利用该漏洞登录网站后台,操作敏感数…