whisper 模型源码解读

在这里插入图片描述

whisper官方源码

whisper 模型官方代码:https://github.com/openai/whisper/blob/main/whisper/model.py ;注释如下

import base64
import gzip
from dataclasses import dataclass
from typing import Dict, Iterable, Optional

import numpy as np
import torch
import torch.nn.functional as F
from torch import Tensor, nn

# 从其他模块导入必要的函数
from .decoding import decode as decode_function
from .decoding import detect_language as detect_language_function
from .transcribe import transcribe as transcribe_function

@dataclass
class ModelDimensions:
    """
    该类用于存储模型的各项参数
    """
    n_mels: int  # Mel谱图的频带数量
    n_audio_ctx: int  # 音频上下文窗口大小
    n_audio_state: int  # 音频状态维度
    n_audio_head: int  # 音频注意力头数量
    n_audio_layer: int  # 音频层数量
    n_vocab: int  # 词汇表大小
    n_text_ctx: int  # 文本上下文窗口大小
    n_text_state: int  # 文本状态维度
    n_text_head: int  # 文本注意力头数量
    n_text_layer: int  # 文本层数量

class LayerNorm(nn.LayerNorm):
    def forward(self, x: Tensor) -> Tensor:
        """
        重写 forward 方法,确保输入张量的类型在归一化前后保持一致
        """
        return super().forward(x.float()).type(x.dtype)

class Linear(nn.Linear):
    def forward(self, x: Tensor) -> Tensor:
        """
        重写 forward 方法,确保权重和偏置与输入张量的类型一致
        """
        return F.linear(
            x,
            self.weight.to(x.dtype),
            None if self.bias is None else self.bias.to(x.dtype),
        )

class Conv1d(nn.Conv1d):
    def _conv_forward(
        self, x: Tensor, weight: Tensor, bias: Optional[Tensor]
    ) -> Tensor:
        """
        重写 _conv_forward 方法,确保卷积操作中的权重和偏置与输入张量的类型一致
        """
        return super()._conv_forward(
            x, weight.to(x.dtype), None if bias is None else bias.to(x.dtype)
        )

def sinusoids(length, channels, max_timescale=10000):
    """
    生成用于位置嵌入的正弦曲线
    """
    assert channels % 2 == 0
    log_timescale_increment = np.log(max_timescale) / (channels // 2 - 1)
    inv_timescales = torch.exp(-log_timescale_increment * torch.arange(channels // 2))
    scaled_time = torch.arange(length)[:, np.newaxis] * inv_timescales[np.newaxis, :]
    return torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], dim=1)

class MultiHeadAttention(nn.Module):
    def __init__(self, n_state: int, n_head: int):
        """
        初始化多头注意力层
        """
        super().__init__()
        self.n_head = n_head
        self.query = Linear(n_state, n_state)
        self.key = Linear(n_state, n_state, bias=False)
        self.value = Linear(n_state, n_state)
        self.out = Linear(n_state, n_state)

    def forward(
        self,
        x: Tensor,
        xa: Optional[Tensor] = None,
        mask: Optional[Tensor] = None,
        kv_cache: Optional[dict] = None,
    ):
        """
        多头注意力的前向传播
        """
        q = self.query(x)

        if kv_cache is None or xa is None or self.key not in kv_cache:
            # 如果没有缓存键和值,则正常计算
            k = self.key(x if xa is None else xa)
            v = self.value(x if xa is None else xa)
        else:
            # 如果有缓存,则使用缓存的键和值
            k = kv_cache[self.key]
            v = kv_cache[self.value]

        wv, qk = self.qkv_attention(q, k, v, mask)
        return self.out(wv), qk

    def qkv_attention(
        self, q: Tensor, k: Tensor, v: Tensor, mask: Optional[Tensor] = None
    ):
        """
        计算 QKV 注意力
        """
        n_batch, n_ctx, n_state = q.shape
        scale = (n_state // self.n_head) ** -0.25
        q = q.view(*q.shape[:2], self.n_head, -1).permute(0, 2, 1, 3) * scale
        k = k.view(*k.shape[:2], self.n_head, -1).permute(0, 2, 3, 1) * scale
        v = v.view(*v.shape[:2], self.n_head, -1).permute(0, 2, 1, 3)

        qk = q @ k
        if mask is not None:
            qk = qk + mask[:n_ctx, :n_ctx]
        qk = qk.float()

        w = F.softmax(qk, dim=-1).to(q.dtype)
        return (w @ v).permute(0, 2, 1, 3).flatten(start_dim=2), qk.detach()

class ResidualAttentionBlock(nn.Module):
    def __init__(self, n_state: int, n_head: int, cross_attention: bool = False):
        """
        初始化残差注意力块
        """
        super().__init__()

        self.attn = MultiHeadAttention(n_state, n_head)
        self.attn_ln = LayerNorm(n_state)

        self.cross_attn = (
            MultiHeadAttention(n_state, n_head) if cross_attention else None
        )
        self.cross_attn_ln = LayerNorm(n_state) if cross_attention else None

        n_mlp = n_state * 4
        self.mlp = nn.Sequential(
            Linear(n_state, n_mlp), nn.GELU(), Linear(n_mlp, n_state)
        )
        self.mlp_ln = LayerNorm(n_state)

    def forward(
        self,
        x: Tensor,
        xa: Optional[Tensor] = None,
        mask: Optional[Tensor] = None,
        kv_cache: Optional[dict] = None,
    ):
        """
        残差注意力块的前向传播
        """
        x = x + self.attn(self.attn_ln(x), mask=mask, kv_cache=kv_cache)[0]
        if self.cross_attn:
            x = x + self.cross_attn(self.cross_attn_ln(x), xa, kv_cache=kv_cache)[0]
        x = x + self.mlp(self.mlp_ln(x))
        return x

class AudioEncoder(nn.Module):
    def __init__(
        self, n_mels: int, n_ctx: int, n_state: int, n_head: int, n_layer: int
    ):
        """
        初始化音频编码器
        """
        super().__init__()
        self.conv1 = Conv1d(n_mels, n_state, kernel_size=3, padding=1)
        self.conv2 = Conv1d(n_state, n_state, kernel_size=3, stride=2, padding=1)
        self.register_buffer("positional_embedding", sinusoids(n_ctx, n_state))

        self.blocks: Iterable[ResidualAttentionBlock] = nn.ModuleList(
            [ResidualAttentionBlock(n_state, n_head) for _ in range(n_layer)]
        )
        self.ln_post = LayerNorm(n_state)

    def forward(self, x: Tensor):
        """
        前向传播,处理音频输入

        x : torch.Tensor, shape = (batch_size, n_mels, n_ctx)
            音频的Mel谱图
        """
        x = F.gelu(self.conv1(x))
        x = F.gelu(self.conv2(x))
        x = x.permute(0, 2, 1)

        assert x.shape[1:] == self.positional_embedding.shape, "音频形状不正确"
        x = (x + self.positional_embedding).to(x.dtype)

        for block in self.blocks:
            x = block(x)

        x = self.ln_post(x)
        return x

class TextDecoder(nn.Module):
    def __init__(
        self, n_vocab: int, n_ctx: int, n_state: int, n_head: int, n_layer: int
    ):
        """
        初始化文本解码器
        """
        super().__init__()

        self.token_embedding = nn.Embedding(n_vocab, n_state)
        self.positional_embedding = nn.Parameter(torch.empty(n_ctx, n_state))

        self.blocks: Iterable[ResidualAttentionBlock] = nn.ModuleList(
            [
                ResidualAttentionBlock(n_state, n_head, cross_attention=True)
                for _ in range(n_layer)
            ]
        )
        self.ln = LayerNorm(n_state)

        mask = torch.empty(n_ctx, n_ctx).fill_(-np.inf).triu_(1)
        self.register_buffer("mask", mask, persistent=False)

    def forward(self, x: Tensor, xa: Tensor, kv_cache: Optional[dict] = None):
        """
        前向传播,处理文本输入并结合音频特征

        x : torch.LongTensor, shape = (batch_size, <= n_ctx)
            文本的标

记序列
        xa : torch.Tensor, shape = (batch_size, n_audio_ctx, n_audio_state)
            编码后的音频特征
        """
        offset = next(iter(kv_cache.values())).shape[1] if kv_cache else 0
        x = (
            self.token_embedding(x)
            + self.positional_embedding[offset : offset + x.shape[-1]]
        )
        x = x.to(xa.dtype)

        for block in self.blocks:
            x = block(x, xa, mask=self.mask, kv_cache=kv_cache)

        x = self.ln(x)
        logits = (
            x @ torch.transpose(self.token_embedding.weight.to(x.dtype), 0, 1)
        ).float()

        return logits

class Whisper(nn.Module):
    def __init__(self, dims: ModelDimensions):
        """
        初始化 Whisper 模型
        """
        super().__init__()
        self.dims = dims
        self.encoder = AudioEncoder(
            self.dims.n_mels,
            self.dims.n_audio_ctx,
            self.dims.n_audio_state,
            self.dims.n_audio_head,
            self.dims.n_audio_layer,
        )
        self.decoder = TextDecoder(
            self.dims.n_vocab,
            self.dims.n_text_ctx,
            self.dims.n_text_state,
            self.dims.n_text_head,
            self.dims.n_text_layer,
        )
        # 默认情况下,使用解码器层的后一半进行时间对齐;
        # 若要使用特定的注意力头,可以使用 `set_alignment_heads()` 方法。
        all_heads = torch.zeros(
            self.dims.n_text_layer, self.dims.n_text_head, dtype=torch.bool
        )
        all_heads[self.dims.n_text_layer // 2 :] = True
        self.register_buffer("alignment_heads", all_heads.to_sparse(), persistent=False)

    def set_alignment_heads(self, dump: bytes):
        """
        设置对齐的注意力头
        """
        array = np.frombuffer(
            gzip.decompress(base64.b85decode(dump)), dtype=bool
        ).copy()
        mask = torch.from_numpy(array).reshape(
            self.dims.n_text_layer, self.dims.n_text_head
        )
        self.register_buffer("alignment_heads", mask.to_sparse(), persistent=False)

    def embed_audio(self, mel: torch.Tensor):
        """
        编码音频特征
        """
        return self.encoder(mel)

    def logits(self, tokens: torch.Tensor, audio_features: torch.Tensor):
        """
        获取预测的logits
        """
        return self.decoder(tokens, audio_features)

    def forward(
        self, mel: torch.Tensor, tokens: torch.Tensor
    ) -> Dict[str, torch.Tensor]:
        """
        前向传播
        """
        return self.decoder(tokens, self.encoder(mel))

    @property
    def device(self):
        """
        获取模型所在的设备
        """
        return next(self.parameters()).device

    @property
    def is_multilingual(self):
        """
        判断模型是否支持多语言
        """
        return self.dims.n_vocab >= 51865

    @property
    def num_languages(self):
        """
        获取模型支持的语言数量
        """
        return self.dims.n_vocab - 51765 - int(self.is_multilingual)

    def install_kv_cache_hooks(self, cache: Optional[dict] = None):
        """
        为键和值的投影模块安装缓存钩子

        返回
        -------
        cache : Dict[nn.Module, torch.Tensor]
            映射键/值投影模块到其缓存的字典对象
        hooks : List[RemovableHandle]
            用于停止调用钩子的 PyTorch RemovableHandle 对象列表
        """
        cache = {**cache} if cache is not None else {}
        hooks = []

        def save_to_cache(module, _, output):
            if module not in cache or output.shape[1] > self.dims.n_text_ctx:
                # 第一次标记或交叉注意时保存原始值
                cache[module] = output
            else:
                cache[module] = torch.cat([cache[module], output], dim=1).detach()
            return cache[module]

        def install_hooks(layer: nn.Module):
            if isinstance(layer, MultiHeadAttention):
                hooks.append(layer.key.register_forward_hook(save_to_cache))
                hooks.append(layer.value.register_forward_hook(save_to_cache))

        self.decoder.apply(install_hooks)
        return cache, hooks

    detect_language = detect_language_function  # 语言检测函数
    transcribe = transcribe_function  # 转录函数
    decode = decode_function  # 解码函数

语音识别自回归解码过程分析和举例说明

分析

语音识别自回归解码过程通常涉及以下步骤:

  1. 音频预处理:首先将输入的音频信号转换为Mel谱图。这一步骤在实际应用中通常由音频前端处理模块完成。

  2. 音频编码:将预处理后的Mel谱图输入到音频编码器中,生成音频特征表示。这些特征表示将作为后续文本解码器的输入。

  3. 文本解码:文本解码器通过自回归方式生成文本序列。具体来说,文本解码器在每个时间步上根据前一步生成的文本标记以及音频特征生成下一个文本标记。

  4. 语言检测和转录:在生成的文本序列基础上,可以进行语言检测,确认文本所使用的语言。此外,转录过程将生成的文本序列转换为最终的文本输出。

具体步骤

以下代码展示了上述过程的具体实现:

import torch

# 初始化模型参数
dims = ModelDimensions(
    n_mels=80,
    n_audio_ctx=1500,
    n_audio_state=512,
    n_audio_head=8,
    n_audio_layer=6,
    n_vocab=51865,
    n_text_ctx=448,
    n_text_state=512,
    n_text_head=8,
    n_text_layer=6,
)

# 创建模型实例
model = Whisper(dims)

# 假设我们有一个Mel谱图输入
mel_spectrogram = torch.randn(1, 80, 1500)  # (batch_size, n_mels, n_audio_ctx)

# 编码音频特征
audio_features = model.embed_audio(mel_spectrogram)

# 假设我们有一个初始的文本标记序列
initial_tokens = torch.tensor([[1, 2, 3]])  # (batch_size, seq_len)

# 自回归解码过程
for _ in range(10):  # 假设生成长度为10的序列
    logits = model.logits(initial_tokens, audio_features)
    next_token = torch.argmax(logits[:, -1, :], dim=-1, keepdim=True)
    initial_tokens = torch.cat([initial_tokens, next_token], dim=-1)

# 最终生成的文本标记序列
final_tokens = initial_tokens

# 打印生成的文本标记序列
print("Generated tokens:", final_tokens)

举例说明

假设我们有一段音频,其Mel谱图表示如下:

mel_spectrogram = torch.randn(1, 80, 1500)

我们希望通过自回归解码生成对应的文本表示。首先,我们将Mel谱图输入到音频编码器中,得到音频特征表示:

audio_features = model.embed_audio(mel_spectrogram)

然后,我们使用一个初始的文本标记序列(例如,序列开始标记)开始自回归解码过程:

initial_tokens = torch.tensor([[1]])  # 序列开始标记

在每个时间步,我们根据当前的文本标记序列和音频特征生成下一个文本标记:

logits = model.logits(initial_tokens, audio_features)
next_token = torch.argmax(logits[:, -1, :], dim=-1, keepdim=True)
initial_tokens = torch.cat([initial_tokens, next_token], dim=-1)

这个过程重复若干次(例如10次)直到生成完整的文本序列:

for _ in range(10):
    logits = model.logits(initial_tokens, audio_features)
    next_token = torch.argmax(logits[:, -1, :], dim=-1, keepdim=True)
    initial_tokens = torch.cat([initial_tokens, next_token], dim=-1)

最终得到的文本标记序列为:

final_tokens = initial_tokens
print("Generated tokens:", final_tokens)

以上示例展示了从音频输入到文本输出的完整自回归解码过程。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/735669.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

【动态规划】1130. 叶值的最小代价生成树

1130. 叶值的最小代价生成树 难度&#xff1a;中等 力扣地址&#xff1a;https://leetcode.cn/problems/minimum-cost-tree-from-leaf-values/description/ 题目内容 给你一个正整数数组 arr&#xff0c;考虑所有满足以下条件的二叉树&#xff1a; 每个节点都有 0 个或是 2 个…

基于Java的学生成绩管理系统

你好呀&#xff0c;我是计算机学姐码农小野&#xff01;如果有相关需求&#xff0c;可以私信联系我。 开发语言&#xff1a;Java 数据库&#xff1a;MySQL 技术&#xff1a;Java技术&#xff0c;B/S结构 工具&#xff1a;MyEclipse&#xff0c;MySQL 系统展示 首页 个人中…

软件测试——稳定性测试:adb Monkey

Monkey 1. Monkey1.1 Monkey 是什么1.2 Monkey 测试场景1.3 Monkey 特点1.4 Monkey 在哪里1.5 测试准备事项1.6 Monkey 参数列表 2. 基本命令3. 常用参数4. 事件类型5. 调试参数6. 日志管理7. 日志错误定位8. Monkey测试可以发现的问题 1. Monkey 1.1 Monkey 是什么 Monkey是一…

计算机网络期末

1、IP 地址为:192.168.0.254,它的子网掩码应该为( ) A.255.255.255.0 B.255.255.254.0 C.255.255.252.0 D.255.255.0.0 2、最容易产生网络可靠性瓶颈问题的拓扑构型是&#xff08; &#xff09;。 A 总线型 B 星型 C 环型 D 网状型 3、HTTP 就是电子邮件阅读协议&#xff0…

使用Vue+Antv-X6实现一个输送线可视化编辑器(支持拖拽、自定义连线、自定义节点等)

最近公司有这样的业务&#xff0c;要实现一个类似流程图的编辑器&#xff0c;可以拖拉拽之类的&#xff0c;网上寻找了一番&#xff0c;最终决定使用Antv-X6这个图形引擎&#xff0c;非常强大&#xff0c;文档多看几遍也就能上手使用了。感觉还不错就写个使用心得期望能帮助到同…

访问网站时IP被屏蔽是什么原因?

在互联网使用中&#xff0c;有时我们可能会遇到访问某个网站时IP地址被屏蔽的情况。IP地址被网站屏蔽是一个相对常见的现象&#xff0c;而导致这种情况的原因多种多样&#xff0c;包括恶意行为、违规访问等。本文将解释IP地址被网站屏蔽的常见原因&#xff0c;同时&#xff0c;…

如何理解广角镜头和长焦镜头的区别。

为什么广角镜头的视野会比长焦镜头的视野大呢&#xff1f; 我之前用等光程解释了景深&#xff0c;也解释了为什么焦距越远&#xff0c;成像越大&#xff0c;但是从来没有提到过视野范围这个概念。实际上在我之前建立的数学模型中&#xff0c;物曲面S是无限大的&#xff0c;像曲…

Python,PyCharm,Anaconda安装及使用教程

一、Python下载及安装 Python官网Welcome to Python.org 当前最新版是3.11.0版本&#xff1a;python-3.11.0-amd64.exe 下一步&#xff0c;下一步进行安装即可 选择&#xff1a;“Customize installation”,出现下图&#xff1a; 点击“Next”下一步&#xff0c;出现如下图…

HTTP网络协议

1.HTTP &#xff08;1&#xff09;概念&#xff1a; Hyper Text Transfer Protocol&#xff0c;超文本传输协议规定了浏览器和服务器之间数据传输的规则。 &#xff08;2&#xff09;特点 基于TCP协议:面向连接&#xff0c;安全基于请求-响应模型的:一次请求对应一次响应HTTP协…

Redis源码学习:quicklist的设计与实现

为什么需要quicklist 假设你已经知道了ziplist的缺陷&#xff1a; 虽然节省空间&#xff0c;但是申请内存必须是连续的&#xff0c;如果内存占用比较多&#xff0c;申请效率低要存储大量数据&#xff0c;超过了ziplist的最佳上限后&#xff0c;性能有影响 借鉴分片思想&…

说说 SSL 的错误认识和不足之处

最近明月在学习折腾 LNMP 期间无意中建了一个 Typecho 的博客小站&#xff0c;近一周的折腾下来&#xff0c;收获真的不少&#xff0c;致使兴趣也越来越浓了&#xff0c;在升级 LNMP 的时候捎带手的给这个 Typecho 博客也启用了 SSL。并且开启了 memcached 和 OPcache 优化加速…

Spring Cache常见问题解决

目录 一 报错:Null key returned for cache operation 二 报错&#xff1a;类型转换异常 三 取出的数据为null 一 报错:Null key returned for cache operation 这里报错有两种情况&#xff1a; 第一&#xff0c;如果你在新增的方法上使用Cacheable注解&#xff0c;那么肯定是…

终极解决方案,传统极速方案,下载软件的双雄对决!

在数字资源日益丰富的今天&#xff0c;下载管理器成为了我们日常生活中不可或缺的工具。市场上两款备受欢迎的下载管理软件——Internet Download Manager&#xff08;IDM&#xff09;和迅雷11&#xff0c;它们以各自的特色和优势&#xff0c;满足了不同用户群体的需求。 软件…

5.3 Python len()函数:获取字符串长度或字节数

Python len()函数详解&#xff1a;获取字符串长度或字节数 Python 中&#xff0c;要想知道一个字符串有多少个字符&#xff08;获得字符串长度&#xff09;&#xff0c;或者一个字符串占用多少个字节&#xff0c;可以使用 len 函数。 len 函数的基本语法格式为&#xff1a; …

python-今年第几天

[题目描述] 定义一个结构体变量&#xff08;包括年、月、日&#xff09;。 计算该日在本年中是第几天&#xff0c;注意闰年问题。输入格式&#xff1a; 年 月 日。输出格式&#xff1a; 当年第几天。样例输入 2000 12 31样例输出 366 数据范围 对于100%的数据&#xff0c;保…

解决MNIST数据集下载慢,或者Http连接失败问题

下载MNIST数据集时遇到速度慢的问题 解决&#xff1a;手动从MNIST数据集的官方网站直接使用下载好的数据文件&#xff0c;放到指定目录下&#xff0c;再进行调取即可。 手动下载地址&#xff1a;MNIST官网 http://yann.lecun.com/exdb/mnist/ 【仍需要连接外网】 这里我提供…

ATA-4052C高压功率放大器在新能源汽车安全测试中的应用

新能源汽车的崛起已经改变了汽车行业的格局&#xff0c;为环境友好型交通方式提供了更多的选择。为了确保这些新型汽车的安全性和可靠性&#xff0c;进行全面的安全测试是至关重要的。高压功率放大器在新能源汽车的安全测试中发挥着重要的作用&#xff0c;本文将介绍其应用以及…

已解决:Vector析构异常Opencv Assert _CrtIsValidHeapPointer

已解决&#xff1a;Vector析构异常Opencv Assert _CrtIsValidHeapPointer 欢迎来到英杰社区https://bbs.csdn.net/topics/617804998 欢迎来到我的主页&#xff0c;我是博主英杰&#xff0c;211科班出身&#xff0c;就职于医疗科技公司&#xff0c;热衷分享知识&#xff0c;武汉…

PathDecider 详细解读

目录 PathDecider的主要功能 PathDecider代码分析 Process() MakeObjectDecision() MakeStaticObstacleDecision() MakeStaticObstacleDecision()的流程图​编辑 MakeStaticObstacleDecision()的代码解析 GenerateObjectStopDecision() PathDecider里用到的其他函数 …

C语言入门4-函数和程序结构

函数举例 读取字符串&#xff0c;如果字符串中含有ould则输出该字符串&#xff0c;否则不输出。 #include <stdio.h>// 函数声明 int getLine(char s[], int lim); int strindex(char s[], char t[]);int main() {char t[] "ould"; // 要查找的目标子字符串…