【GPT-SOVITS-03】SOVITS 模块-生成模型解析

说明:该系列文章从本人知乎账号迁入,主要原因是知乎图片附件过于模糊。

知乎专栏地址:
语音生成专栏

系列文章地址:
【GPT-SOVITS-01】源码梳理
【GPT-SOVITS-02】GPT模块解析
【GPT-SOVITS-03】SOVITS 模块-生成模型解析
【GPT-SOVITS-04】SOVITS 模块-鉴别模型解析
【GPT-SOVITS-05】SOVITS 模块-残差量化解析
【GPT-SOVITS-06】特征工程-HuBert原理

1.概述

SOVIT 模块的主要功能是生成最终的音频文件。

GPT-SOVITS的核心与SOVITS差别不大,仍然是分了两个部分:

  • 基于 VAE + FLOW 的生成器,源代码为 SynthesizerTrn
  • 基于多尺度分类器的鉴别器,源代码为 SynthesizerTrn

针对鉴别器相较于SOVITS5做了一些简化,主要的差异是在在生成模型处引入了残差量化层。

在训练时进入先验编码器的是经过残差量化层的 quatized 数据。

在推理时,用的是AR模块推理出的 code,然后用code直接生成 quatized 数据,再进入先验编码器。

训练所涉及特征包括:
在这里插入图片描述

2.训练流程

在这里插入图片描述

  • 如概述所注,在训练时SSL特征经过残差量化层中会产生量化编码 code 和数据 quatized。
  • 这个 code 也会作为 AR,即GPT模块训练的特征
  • 在推理时,这个code 就由 GPT 模块生成
  • 损失函数如下:
y_d_hat_r, y_d_hat_g, fmap_r, fmap_g = net_d(y, y_hat)
with autocast(enabled=False):
    loss_mel = F.l1_loss(y_mel, y_hat_mel) * hps.train.c_mel
    loss_kl = kl_loss(z_p, logs_q, m_p, logs_p, z_mask) * hps.train.c_kl

    loss_fm = feature_loss(fmap_r, fmap_g)
    loss_gen, losses_gen = generator_loss(y_d_hat_g)
    loss_gen_all = loss_gen + loss_fm + loss_mel + kl_ssl * 1 + loss_kl

3.推理流程

在这里插入图片描述
推理时直接通过先验编码器,通过FLOW的逆,进入解码器后输出推理音频

4.调试代码参考

import os,sys
import json
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from torch.utils.data import DataLoader

from vof.vits.data_utils import (
    TextAudioSpeakerLoader,
    TextAudioSpeakerCollate,
    DistributedBucketSampler,
)
from vof.vits.models import SynthesizerTrn
from vof.script.utils import HParams

now_dir   = os.getcwd()
root_dir  = os.path.dirname(now_dir)
prj_name  = 'project01'               # 项目名称
prj_dir   = root_dir + '/res/' + prj_name + '/'

with open(root_dir + '/res/configs/s2.json') as f:
    data = f.read()
    data = json.loads(data)

# 新增其他参数
s2_dir = prj_dir + 'logs'  # gpt 训练用目录
os.makedirs("%s/logs_s2" % (s2_dir), exist_ok=True)

data["train"]["batch_size"]             = 3
data["train"]["epochs"]                 = 15
data["train"]["text_low_lr_rate"]       = 0.4
data["train"]["pretrained_s2G"]         = root_dir + '/res/pretrained_models/s2G488k.pth'
data["train"]["pretrained_s2D"]         = root_dir + '/res/pretrained_models/s2D488k.pth'
data["train"]["if_save_latest"]         = True
data["train"]["if_save_every_weights"]  = True
data["train"]["save_every_epoch"]       = 5
data["train"]["gpu_numbers"]            = 0
data["data"]["exp_dir"]                 = data["s2_ckpt_dir"] = s2_dir
data["save_weight_dir"]                 = root_dir + '/res/weight/sovits'
data["name"]                            = prj_name
data['exp_dir']                         = s2_dir

hps = HParams(**data)
print(hps)
"""
self.path2 = "%s/2-name2text-0.txt" % exp_dir
self.path4 = "%s/4-cnhubert" % exp_dir
self.path5 = "%s/5-wav32k" % exp_dir
"""
train_dataset = TextAudioSpeakerLoader(hps.data)
"""
ssl  hubert 特征 [1,768,195]
spec [1025,195]
wav  [1,124800]
text [14,]
"""
train_sampler = DistributedBucketSampler(
    train_dataset,
    hps.train.batch_size,
    [
        32,
        300,
        400,
        500,
        600,
        700,
        800,
        900,
        1000,
        1100,
        1200,
        1300,
        1400,
        1500,
        1600,
        1700,
        1800,
        1900,
    ],
    num_replicas=1,
    rank=0,
    shuffle=True,
)
collate_fn = TextAudioSpeakerCollate()
train_loader = DataLoader(
    train_dataset,
    batch_size=1,
    shuffle=False,
    pin_memory=True,
    collate_fn=collate_fn,
    batch_sampler=train_sampler
)

def _model_forward(ssl, y, y_lengths, text, text_lengths):

    net_g = SynthesizerTrn(
        hps.data.filter_length // 2 + 1,
        hps.train.segment_size // hps.data.hop_length,
        n_speakers=hps.data.n_speakers,
        **hps.model,
    )
    net_g.forward(ssl, y, y_lengths, text, text_lengths)

for data in train_loader:

    ssl_padded   = data[0]
    ssl_lengths  = data[1]
    spec_padded  = data[2]
    spec_lengths = data[3]
    wav_padded   = data[4]
    wav_lengths  = data[5]
    text_padded  = data[6]
    text_lengths = data[7]

    _model_forward(ssl_padded, spec_padded, spec_lengths, text_padded, text_lengths)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/465017.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

【PyTorch】进阶学习:一文详细介绍 torch.load() 的应用场景、实战代码示例

【PyTorch】进阶学习:一文详细介绍 torch.load() 的应用场景、实战代码示例 🌈 个人主页:高斯小哥 🔥 高质量专栏:Matplotlib之旅:零基础精通数据可视化、Python基础【高质量合集】、PyTorch零基础入门教程…

栈和队列(Java实现)

栈和队列(Java实现) 栈 栈(Stack):栈是先进后出(FILO, First In Last Out)的数据结构。Java中实现栈有以下两种方式: stack类LinkedList实现(继承了Deque接口) (1&am…

Python基础算法解析:支持向量机(SVM)

支持向量机(Support Vector Machine,SVM)是一种用于分类和回归分析的机器学习算法,它通过在特征空间中找到一个最优的超平面来进行分类。本文将详细介绍支持向量机的原理、实现步骤以及如何使用Python进行编程实践。 什么是支持向…

【Java刷题篇】串联所有单词的子串

这里写目录标题 📃1.题目📜2.分析题目📜3.算法原理🧠4.思路叙述✍1.进窗口✍2.判断有效个数✍3.维护窗口✍4.出窗口 💥5.完整代码 📃1.题目 力扣链接: 串联所有单词的子串 📜2.分析题目 阅…

2.vscode 配置python开发环境

vscode用着习惯了,也不想再装别的ide 1.安装vscode 这一步默认已完成 2.安装插件 搜索插件安装 3.选择调试器 Ctrl Shift P(或F1),在打开的输入框中输入 Python: Select Interpreter 搜索,选择 Python 解析器 选择自己安…

vulhub中GitLab 远程命令执行漏洞复现(CVE-2021-22205)

GitLab是一款Ruby开发的Git项目管理平台。在11.9以后的GitLab中,因为使用了图片处理工具ExifTool而受到漏洞CVE-2021-22204的影响,攻击者可以通过一个未授权的接口上传一张恶意构造的图片,进而在GitLab服务器上执行任意命令。 环境启动后&am…

深度学习1650ti在win10安装pytorch复盘

深度学习1650ti在win10安装pytorch复盘 前言1. 安装anaconda2. 检查更新显卡驱动3. 根据pytorch选择CUDA版本4. 安装CUDA5. 安装cuDNN6. conda安装pytorch结语 前言 建议有条件的,可以在安装过程中,开启梯子。例如cuDNN安装时登录 or 注册,会…

安卓国产百度网盘与国外云盘软件onedrive对比

我更愿意使用国外软件公司的产品,而不是使用国内百度等制作的流氓软件。使用这些国产软件让我不放心,他们占用我的设备大量空间,在我的设备上推送运行各种无用的垃圾功能。瞒着我,做一些我不知道的事情。 百度网盘安装包大小&…

鸿蒙Next 支持数据双向绑定的组件:Checkbox--Search--TextInput

Checkbox $$语法,$$绑定的变量发生变化时,会触发UI的刷新 Entry Component struct MvvmCase { State isMarry:boolean falseStatesearchText:string build() {Grid(){GridItem(){Column(){Text("checkbox 的双向绑定")Checkbox().select($$…

【PyTorch】基础学习:一文详细介绍 torch.save() 的用法和应用

【PyTorch】基础学习:一文详细介绍 torch.save() 的用法和应用 🌈 个人主页:高斯小哥 🔥 高质量专栏:Matplotlib之旅:零基础精通数据可视化、Python基础【高质量合集】、PyTorch零基础入门教程&#x1f44…

ioDraw:与 GitHub、gitee、gitlab、OneDrive 无缝对接,绘图文件永不丢失!

🌟 绘图神器 ioDraw 重磅更新,文件保存再无忧!🎉 无需注册,即刻畅绘!✨ ioDraw 让你告别繁琐注册,尽情挥洒灵感! 新增文件在线实时保存功能,支持将绘图文件保存到 GitHu…

【HarmonyOS】ArkUI - 向左/向右滑动删除

核心知识点:List容器 -> ListItem -> swipeAction 先看效果图: 代码实现: // 任务类 class Task {static id: number 1// 任务名称name: string 任务${Task.id}// 任务状态finished: boolean false }// 统一的卡片样式 Styles func…

机电公司管理小程序|基于微信小程序的机电公司管理小程序设计与实现(源码+数据库+文档)

机电公司管理小程序目录 目录 基于微信小程序的机电公司管理小程序设计与实现 一、前言 二、系统设计 三、系统功能设计 1、机电设备管理 2、机电零件管理 3、公告管理 4、公告类型管理 四、数据库设计 五、核心代码 六、论文参考 七、最新计算机毕设选题推荐 八…

【LabVIEW FPGA入门】定时

在本节学习使用循环计时器来设置FPGA循环速率,等待来添加事件之间的延迟,以及Tick Count来对FPGA代码进行基准测试。 1.定时快捷VI函数 在FPGA VI中放置的每个VI或函数都需要一定的时间来执行。您可以允许操作以数据流确定的速率发生,而无需额…

FFmpeg分析视频信息输出到指定格式(csv/flat/ini/json/xml)文件中

1.查看ffprobe帮助 输出格式参数说明: 本例将演示输出csv,flat,ini,json,xml格式 输出所使用的参数如下: 1.输出csv格式: ffprobe -i 4K.mp4 -select_streams v -show_frames -of csv -o 4K.csv 输出: 2.输出flat格式: ffprobe -i 4K.mp4 -select_streams v -show_frames …

深度学习pytorch——Tensor维度变换(持续更新)

view()打平函数 需要注意的是打平之后的tensor是需要有物理意义的,根据需要进行打平,并且打平后总体的大小是不发生改变的。 并且一定要谨记打平会导致维度的丢失,造成数据污染,如果想要恢复到原来的数据形式,是需要…

在github下载的神经网络项目,如何运行?

github网页上可获取的信息 在github上面,有一个requirements.txt文件,该文件说明了项目要求的python解释器的模块。 - 此外,还有一个README.md文件,用来说明项目的运行环境以及其他的信息。例如python解释器的版本是3.7、PyTorc…

理财第一课:炒股词典

文章目录 基础代码规则委比委差量比换手率市盈率市净率 散户亏钱的原因庄家分析炒股战法波浪理论其它 钱者,人生之大事,死生存亡之地,不可不察也。耕田之利,十倍;珠玉之赢,百倍;闹革命&#xff…

STM32使用TIM2+DMA产生PWM波形异常分析

1、问题描述 使用 STM32F4 的 TIM2 结合 DMA,产生的 PWM 波形不符合预期,但是相同的配置使用在 IM3 上,得到的 PWM 波形就是符合预期的。其代码和配置都是从 F1 移植过来的,在 F1 上使用 TIM2 是没有问题的,对于 F4 的…

蓝桥杯并查集|路径压缩|合并优化|按秩合并|合根植物(C++)

并查集 并查集是大量的树(单个节点也算是树)经过合并生成一系列家族森林的过程。 可以合并可以查询的集合的一种算法 可以查询哪个元素属于哪个集合 每个集合也就是每棵树都是由根节点确定,也可以理解为每个家族的族长就是根节点。 元素集合…