240712_昇思学习打卡-Day24-LSTM+CRF序列标注(3)

240712_昇思学习打卡-Day24-LSTM+CRF序列标注(3)

今天做LSTM+CRF序列标注第三部分,同样,仅作简单记录及注释,最近确实太忙了。

Viterbi算法

在完成前向训练部分后,需要实现解码部分。这里我们选择适合求解序列最优路径的Viterbi算法。与计算Normalizer类似,使用动态规划求解所有可能的预测序列得分。不同的是在解码时同时需要将第𝑖个Token对应的score取值最大的标签保存,供后续使用Viterbi算法求解最优预测序列使用。

取得最大概率得分ScoreScore,以及每个Token对应的标签历史HistoryHistory后,根据Viterbi算法可以得到公式:

请添加图片描述

从第0个至第𝑖个Token对应概率最大的序列,只需要考虑从第0个至第𝑖−1个Token对应概率最大的序列,以及从第𝑖个至第𝑖−1个概率最大的标签即可。因此我们逆序求解每一个概率最大的标签,构成最佳的预测序列。

由于静态图语法限制,我们将Viterbi算法求解最佳预测序列的部分作为后处理函数,不纳入后续CRF层的实现。

# 定义维特比解码算法,用于找出具有最大概率的标签序列
def viterbi_decode(emissions, mask, trans, start_trans, end_trans):
    # emissions: (seq_length, batch_size, num_tags) 发射概率矩阵
    # mask: (seq_length, batch_size) 序列掩码,用于标记有效序列长度
    # trans: 转移概率矩阵
    # start_trans: 初始状态转移概率向量
    # end_trans: 终止状态转移概率向量

    seq_length = mask.shape[0]  # 获取序列长度

    # 初始化分数矩阵,等于初始状态转移概率加上第一个发射概率
    score = start_trans + emissions[0]
    history = ()  # 初始化历史路径记录

    # 遍历序列中的每个时间步
    for i in range(1, seq_length):
        # 扩展维度以便广播运算
        broadcast_score = score.expand_dims(2)
        broadcast_emission = emissions[i].expand_dims(1)
        
        # 计算所有可能的转移分数
        next_score = broadcast_score + trans + broadcast_emission

        # 找出当前Token对应的最大分数标签,并保存
        indices = next_score.argmax(axis=1)
        history += (indices,)  # 保存历史路径信息

        # 取出最大分数
        next_score = next_score.max(axis=1)
        
        # 更新分数矩阵,只更新mask为True的部分
        score = mnp.where(mask[i].expand_dims(1), next_score, score)

    # 加上终止状态转移概率
    score += end_trans

    # 返回最终的分数矩阵和历史路径信息
    return score, history


# 根据解码过程中的得分和历史路径信息,重构最优标签序列
def post_decode(score, history, seq_length):
    # score: 最终得分矩阵
    # history: 历史路径信息
    # seq_length: 每个样本的实际序列长度

    batch_size = seq_length.shape[0]  # 获取批次大小
    seq_ends = seq_length - 1  # 计算每个样本的最后一个Token位置
    
    # 初始化最佳标签序列列表
    best_tags_list = []

    # 对批次中的每个样本进行解码
    for idx in range(batch_size):
        # 找出使最后一个Token对应的预测概率最大的标签
        best_last_tag = score[idx].argmax(axis=0)
        best_tags = [int(best_last_tag.asnumpy())]  # 添加最佳标签到序列

        # 从历史路径信息中反向追踪,找到每个Token的最佳标签
        for hist in reversed(history[:seq_ends[idx]]):
            best_last_tag = hist[idx][best_tags[-1]]
            best_tags.append(int(best_last_tag.asnumpy()))

        # 将逆序的标签序列反转,得到正序的最优标签序列
        best_tags.reverse()
        best_tags_list.append(best_tags)  # 添加到结果列表

    # 返回最优标签序列列表
    return best_tags_list

CRF层

完成上述前向训练和解码部分的代码后,将其组装完整的CRF层。考虑到输入序列可能存在Padding的情况,CRF的输入需要考虑输入序列的真实长度,因此除发射矩阵和标签外,加入seq_length参数传入序列Padding前的长度,并实现生成mask矩阵的sequence_mask方法。

综合上述代码,使用nn.Cell进行封装,最后实现完整的CRF层如下:

# 导入MindSpore相关模块
import mindspore as ms
import mindspore.nn as nn
import mindspore.ops as ops
import mindspore.numpy as mnp
from mindspore.common.initializer import initializer, Uniform

# 定义序列掩码生成函数
def sequence_mask(seq_length, max_length, batch_first=False):
    """
    根据序列的实际长度和最大长度生成mask矩阵。
    
    参数:
    seq_length: 实际序列长度张量。
    max_length: 序列的最大长度。
    batch_first: 是否将批次放在第一维度。
    
    返回:
    mask矩阵,形状为(batch_size, max_length),其中True表示有效位置,False表示填充位置。
    """
    # 生成从0到max_length的范围向量
    range_vector = mnp.arange(0, max_length, 1, seq_length.dtype)
    # 创建mask矩阵,shape为(seq_length.shape + (1,))
    result = range_vector < seq_length.view(seq_length.shape + (1,))
    # 转换数据类型并根据batch_first参数调整维度顺序
    if batch_first:
        return result.astype(ms.int64)
    return result.astype(ms.int64).swapaxes(0, 1)


# 定义条件随机场(CRF)模型类
class CRF(nn.Cell):
    def __init__(self, num_tags: int, batch_first: bool = False, reduction: str = 'sum') -> None:
        """
        初始化CRF模型。
        
        参数:
        num_tags: 标签数量。
        batch_first: 是否将批次放在第一维度。
        reduction: 损失函数的缩减方式。
        """
        # 检查标签数量是否有效
        if num_tags <= 0:
            raise ValueError(f'无效的标签数量: {num_tags}')
        super().__init__()
        # 检查reduction参数是否有效
        if reduction not in ('none', 'sum', 'mean', 'token_mean'):
            raise ValueError(f'无效的缩减方式: {reduction}')
        self.num_tags = num_tags  # 标签数量
        self.batch_first = batch_first  # 批次是否在第一维度
        self.reduction = reduction  # 损失函数缩减方式
        # 初始化起始和结束状态转移权重
        self.start_transitions = ms.Parameter(initializer(Uniform(0.1), (num_tags,)), name='start_transitions')
        self.end_transitions = ms.Parameter(initializer(Uniform(0.1), (num_tags,)), name='end_transitions')
        # 初始化状态间转移权重
        self.transitions = ms.Parameter(initializer(Uniform(0.1), (num_tags, num_tags)), name='transitions')

    def construct(self, emissions, tags=None, seq_length=None):
        """
        CRF模型的前向传播方法。
        
        参数:
        emissions: 发射概率张量。
        tags: 真实标签张量。
        seq_length: 序列长度张量。
        
        返回:
        如果tags为None,则返回解码结果;否则返回损失值。
        """
        if tags is None:
            return self._decode(emissions, seq_length)
        return self._forward(emissions, tags, seq_length)

    def _forward(self, emissions, tags=None, seq_length=None):
        """
        计算损失值。
        
        参数:
        emissions: 发射概率张量。
        tags: 真实标签张量。
        seq_length: 序列长度张量。
        
        返回:
        损失值。
        """
        # 根据batch_first参数调整emissions和tags的维度顺序
        if self.batch_first:
            batch_size, max_length = tags.shape
            emissions = emissions.swapaxes(0, 1)
            tags = tags.swapaxes(0, 1)
        else:
            max_length, batch_size = tags.shape
        
        # 如果seq_length未给出,则假设所有序列都是最大长度
        if seq_length is None:
            seq_length = mnp.full((batch_size,), max_length, ms.int64)
        
        # 生成mask矩阵
        mask = sequence_mask(seq_length, max_length)
        
        # 计算分子部分(真实路径的得分)
        numerator = compute_score(emissions, tags, seq_length-1, mask, self.transitions, self.start_transitions, self.end_transitions)
        # 计算分母部分(所有可能路径的得分总和)
        denominator = compute_normalizer(emissions, mask, self.transitions, self.start_transitions, self.end_transitions)
        # 计算对数似然比
        llh = denominator - numerator
        
        # 根据reduction参数选择损失值的缩减方式
        if self.reduction == 'none':
            return llh
        elif self.reduction == 'sum':
            return llh.sum()
        elif self.reduction == 'mean':
            return llh.mean()
        return llh.sum() / mask.astype(emissions.dtype).sum()

    def _decode(self, emissions, seq_length=None):
        """
        解码方法,用于预测最优标签序列。
        
        参数:
        emissions: 发射概率张量。
        seq_length: 序列长度张量。
        
        返回:
        最优标签序列。
        """
        # 根据batch_first参数调整emissions的维度顺序
        if self.batch_first:
            batch_size, max_length = emissions.shape[:2]
            emissions = emissions.swapaxes(0, 1)
        else:
            batch_size, max_length = emissions.shape[:2]
        
        # 如果seq_length未给出,则假设所有序列都是最大长度
        if seq_length is None:
            seq_length = mnp.full((batch_size,), max_length, ms.int64)
        
        # 生成mask矩阵
        mask = sequence_mask(seq_length, max_length)
        
        # 使用维特比算法解码最优路径
        return viterbi_decode(emissions, mask, self.transitions, self.start_transitions, self.end_transitions)

打卡图片:

请添加图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mfbz.cn/a/796913.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

为Linux设置GRUB密码

正文共&#xff1a;999 字 11 图&#xff0c;预估阅读时间&#xff1a;1 分钟 我们前面介绍了如何恢复root密码&#xff08;CentOS 7.9遗忘了root密码怎么办&#xff1f;&#xff09;&#xff0c;虽然简单好用&#xff0c;但是可能会被不法分子利用&#xff0c;造成root密码以及…

Android ListView

ListView ListView是以列表的形式展示具体内容的控件&#xff0c;ListView能够根据数据的长度自适应显示&#xff0c;如手机通讯录、短消息列表等都可以使用ListView实现。如图1所示是两个ListView&#xff0c;上半部分是数组形式的ListView&#xff0c;下半部分是简单列表Lis…

【STM32F407ZET6】图文

STM32F407ZET6 是一款由意法半导体&#xff08;STMicroelectronics&#xff09;推出的ARM Cortex-M4 基于微控制器&#xff08;MCU&#xff09;。这款MCU是STM32系列中的高性能型号&#xff0c;专为需要快速数字信号处理&#xff08;DSP&#xff09;、实时控制和丰富外设功能的…

【信息系统项目管理师】高项常见知识点与公式

绩效域、合同、配置、变更、招投标、安全、立项论文考到的话大致业是按下面相关知识点开写 八大绩效域及其要点 团干部策划开公交 合同管理 合同的签订->合同的履行管理->合同的变更管理->合同的档案管理->合同的违约\索赔管理 配置管理 制定配置管理计划配置识…

当您消费权益受损时怎么办?您回答我帮您!

当您消费权益受损时怎么办&#xff1f;您回答我帮您&#xff01; 亲爱的消费者&#xff1a; 您好&#xff01;为了更好地了解消费者在购买商品或接受服务过程中遇到的问题&#xff0c;李秘书讲写作特此开展此次问卷调查。您的回答将对您我非常有帮助&#xff0c;我将根据您的回…

关于思维和智能体模型的思考(2)

在关于思维和智能体模型的思考&#xff08;1&#xff09;一文中&#xff0c;我们提出了思维和Agent 模型&#xff0c;提出了使用确定连接的智能体构建的思维模型。本文我们继续讨论思维与智能体&#xff0c;重点探讨另一种智能体-自主智能体&#xff0c;并且提出了自主智能体的…

面向企业中高层、业务决策人员的数据分析培训

✅作者简介&#xff1a;《数据运营&#xff1a;数据分析模型撬动新零售实战》作者、《数据实践之美》作者、数据科技公司创始人、多次参加国家级大数据行业标准研讨及制定、高端企培合作讲师。 是全社会都关注的复杂难题&#xff0c;数据应用的能力影响着你职场的高度。 是的&a…

【目录】全博文、专栏大纲

首先要和大家说一下&#xff0c;博主的文章并不是想到哪里写到哪里&#xff0c;而是以整个大后端为主题&#xff0c;成体系的在写专栏&#xff0c;从和后端紧相关的计算机核心课程开始、到JAVA SE、JAVA EE、到数据库、MQ等各类中间件、再到业务场景、性能优化。当然也会涉及一…

小众好玩的赛车游戏:环道巨星 CIRCUIT SUPERSTARS中文安装包

《环道巨星》&#xff08;Circuit Superstars&#xff09;是一款由赛车迷亲手为其他赛车迷打造的俯视角赛车游戏。荟集史上各类赛车运动&#xff0c;旨在提供刺激好玩的驾驶体验&#xff1b;而游戏自带的高技术难度将促使玩家长时间磨砺技巧&#xff0c;以达成完美的一圈。 游戏…

Cypress UI自动化之安装环境

注&#xff1a;macOS系统 一、git环境 略 二、node环境 1、安装nvm 前提&#xff1a;有装过Homebrew&#xff0c;参考adb使用方法文档 1、安装nvm&#xff1a;首先要保证之前没有安装过node&#xff0c;如果之前安装过&#xff0c;先 brew uninstall node brew install n…

paddlepaddle2.6,paddleorc2.8,cuda12,cudnn,nccl,python10环境

1.安装英伟达显卡驱动 首先需要到NAVIDIA官网去查自己的电脑是不是支持GPU运算。 网址是&#xff1a;CUDA GPUs | NVIDIA Developer。打开后的界面大致如下&#xff0c;只要里边有对应的型号就可以用GPU运算&#xff0c;并且每一款设备都列出来相关的计算能力&#xff08;Compu…

基于Java的飞机大战游戏的设计与实现论文

点击下载源码 基于Java的飞机大战游戏的设计与实现 摘 要 现如今&#xff0c;随着智能手机的兴起与普及&#xff0c;加上4G&#xff08;the 4th Generation mobile communication &#xff0c;第四代移动通信技术&#xff09;网络的深入&#xff0c;越来越多的IT行业开始向手机…

计算机组成原理:408考研|王道|学习笔记II

系列目录 计算机组成原理 学习笔记I 计算机组成原理 学习笔记II 目录 系列目录第四章 指令系统4.1 指令系统4.1.1 指令格式4.1.2 扩展操作码指令格式 4.2 指令的寻址方式4.2_1 指令寻址4.2_2 数据寻址 4.3 程序的机器级代码表示4.3.1 高级语言与机器级代码之间的对应4.3.2 常用…

C++从入门到起飞之——缺省参数/函数重载/引用全方位剖析!

目录 1.缺省参数 2. 函数重载 3.引⽤ 3.1 引⽤的概念和定义 3.2 引⽤的特性 3.3 引⽤的使⽤ 3.4 const引⽤ 3.5 指针和引⽤的关系 4.完结散花 个人主页&#xff1a;秋风起&#xff0c;再归来~ C从入门到起飞 个人格言&#xff1a;悟已往之不谏…

RocketMQ~架构了解

简介 RocketMQ 具有高性能、高可靠、高实时、分布式 的特点。它是一个采用 Java 语言开发的分布式的消息系统&#xff0c;由阿里巴巴团队开发&#xff0c;在 2016 年底贡献给 Apache&#xff0c;成为了 Apache 的一个顶级项目。 在阿里内部&#xff0c;RocketMQ 很好地服务了集…

Base64文件流查看下载PDF方法-CSDN

问题描述 数票通等接口返回的PDF类型发票是以Base64文件流的方式返回的&#xff0c;无法直接查看预览PDF发票&#xff0c; 处理方法 使用第三方在线工具&#xff1a;https://www.jyshare.com/front-end/61/ 在Html代码框中粘贴如下代码 <embed type"application/pd…

计网(1.1~1.4)

1.1计算机网络在信息时代的作用 21世纪的重要特征数字化、网络化和信息化 有三类网络&#xff1a;电信网络、有线电视网络和计算机网络 互联网两个重要基本特点&#xff0c;即连通性和共享 1.2因特网概述 &#xff08;1&#xff09;网络、互联网和互连网 网络:由若干结点和连接…

Docker 部署 ShardingSphere-Proxy 数据库中间件

文章目录 Github官网文档ShardingSphere-Proxymysql-connector-java 驱动下载conf 配置global.yamldatabase-sharding.yamldatabase-readwrite-splitting.yamldockerdocker-compose.yml Apache ShardingSphere 是一款分布式的数据库生态系统&#xff0c; 可以将任意数据库转换为…

初学编程不知道怎么选?推荐学习的三种热门编程语言

在当今的社会需求下&#xff0c;市场上最常见、最受欢迎、最广泛应用的编程语言主要有三种&#xff1a;C语言、Java语言和Python语言。 既然要做出选择&#xff0c;我们就需要明白这三种编程语言各自有何特点和区别。 一、特点 C语言 高效与灵活&#xff1a;C语言生成的机器…

防火墙组网与安全策略实验

实验要求&#xff1a; 实现&#xff1a; 防火墙接口配置&#xff1a; 所有接口均配置为三层接口 由于G1/0/3口下为vlan环境&#xff0c;所以防火墙需要配置子接口 &#xff1a; 交换机划分vlan分开生产区和办公区、配置trunk干道 &#xff1a; 安全策略&#xff1a; 生产区访…