240713_昇思学习打卡-Day25-LSTM+CRF序列标注(4)

240713_昇思学习打卡-Day25-LSTM+CRF序列标注(4)

最后一天咯,做第四部分。

BiLSTM+CRF模型

在实现CRF后,我们设计一个双向LSTM+CRF的模型来进行命名实体识别任务的训练。模型结构如下:

nn.Embedding -> nn.LSTM -> nn.Dense -> CRF

其中LSTM提取序列特征,经过Dense层变换获得发射概率矩阵,最后送入CRF层。具体实现如下:

# 定义双向LSTM结合CRF的序列标注模型
class BiLSTM_CRF(nn.Cell):
    def __init__(self, vocab_size, embedding_dim, hidden_dim, num_tags, padding_idx=0):
        """
        初始化BiLSTM_CRF模型。

        参数:
        vocab_size: 词汇表大小。
        embedding_dim: 词嵌入维度。
        hidden_dim: LSTM隐藏层维度。
        num_tags: 标签种类数量。
        padding_idx: 填充索引,默认为0。
        """
        super().__init__()
        # 初始化词嵌入层
        self.embedding = nn.Embedding(vocab_size, embedding_dim, padding_idx=padding_idx)
        # 初始化双向LSTM层
        self.lstm = nn.LSTM(embedding_dim, hidden_dim // 2, bidirectional=True, batch_first=True)
        # 初始化从LSTM输出到标签的全连接层
        self.hidden2tag = nn.Dense(hidden_dim, num_tags, 'he_uniform')
        # 初始化条件随机场层
        self.crf = CRF(num_tags, batch_first=True)

    def construct(self, inputs, seq_length, tags=None):
        """
        模型的前向传播方法。

        参数:
        inputs: 输入序列,形状为(batch_size, seq_length)。
        seq_length: 序列长度,形状为(batch_size,)。
        tags: 真实标签,形状为(batch_size, seq_length),可选。

        返回:
        crf_outs: CRF层的输出,如果输入了真实标签则为损失值,否则为解码后的标签序列。
        """
        # 通过词嵌入层获取词向量表示
        embeds = self.embedding(inputs)
        # 通过双向LSTM层获取序列特征
        outputs, _ = self.lstm(embeds, seq_length=seq_length)
        # 通过全连接层转换LSTM输出到标签空间
        feats = self.hidden2tag(outputs)

        # 通过CRF层计算损失或解码
        crf_outs = self.crf(feats, tags, seq_length)
        return crf_outs

完成模型设计后,我们生成两句例子和对应的标签,并构造词表和标签表。

# 设置词嵌入维度和隐藏层维度
embedding_dim = 16
hidden_dim = 32

# 定义训练数据集,每条数据包含一个分词后的句子和相应的实体标签
training_data = [
    (
        "清 华 大 学 坐 落 于 首 都 北 京".split(),  # 分词后的句子
        "B I I I O O O O O B I".split()  # 相应的实体标签
    ),
    (
        "重 庆 是 一 个 魔 幻 城 市".split(),  # 分词后的句子
        "B I O O O O O O O".split()  # 相应的实体标签
    )
]

# 初始化词典,用于映射词到索引
word_to_idx = {}
# 添加特殊填充词到词典
word_to_idx['<pad>'] = 0
# 遍历训练数据,构建词到索引的映射
for sentence, tags in training_data:
    for word in sentence:
        # 如果词不在词典中,则添加到词典
        if word not in word_to_idx:
            word_to_idx[word] = len(word_to_idx)

# 初始化标签到索引的映射
tag_to_idx = {"B": 0, "I": 1, "O": 2}

len(word_to_idx)

接下来实例化模型,选择优化器并将模型和优化器送入Wrapper。

由于CRF层已经进行了NLLLoss的计算,因此不需要再设置Loss。

# 实例化BiLSTM_CRF模型,传入词汇表大小、词嵌入维度、隐藏层维度以及标签种类数量
model = BiLSTM_CRF(len(word_to_idx), embedding_dim, hidden_dim, len(tag_to_idx))

# 初始化随机梯度下降优化器,设置学习率为0.01,权重衰减为1e-4
optimizer = nn.SGD(model.trainable_params(), learning_rate=0.01, weight_decay=1e-4)

# 使用MindSpore的value_and_grad函数创建一个函数,它会同时计算模型的损失值和梯度
# 第二个参数设置为None表示不保留反向图,第三个参数是优化器的参数列表
grad_fn = ms.value_and_grad(model, None, optimizer.parameters)

def train_step(data, seq_length, label):
    """
    训练步骤函数,执行一次前向传播和反向传播更新模型参数。

    参数:
    data: 输入数据,形状为(batch_size, seq_length)。
    seq_length: 序列长度,形状为(batch_size,)。
    label: 真实标签,形状为(batch_size, seq_length)。

    返回:
    loss: 当前批次的损失值。
    """
    # 使用grad_fn计算损失值和梯度
    loss, grads = grad_fn(data, seq_length, label)
    # 使用优化器更新模型参数
    optimizer(grads)
    # 返回损失值
    return loss

将生成的数据打包成Batch,按照序列最大长度,对长度不足的序列进行填充,分别返回输入序列、输出标签和序列长度构成的Tensor。

def prepare_sequence(seqs, word_to_idx, tag_to_idx):
    """
    准备序列数据,包括填充和转换成张量。

    参数:
    seqs: 一个包含句子和对应标签的元组列表。
    word_to_idx: 词到索引的映射字典。
    tag_to_idx: 标签到索引的映射字典。

    返回:
    seq_outputs: 填充后的序列数据张量。
    label_outputs: 填充后的标签数据张量。
    seq_length: 序列的真实长度列表。
    """
    seq_outputs, label_outputs, seq_length = [], [], []
    # 获取最长序列长度
    max_len = max([len(i[0]) for i in seqs])

    for seq, tag in seqs:
        # 记录序列的真实长度
        seq_length.append(len(seq))
        # 将序列中的词转换为索引
        idxs = [word_to_idx[w] for w in seq]
        # 将标签转换为索引
        labels = [tag_to_idx[t] for t in tag]
        # 对序列进行填充
        idxs.extend([word_to_idx['<pad>'] for i in range(max_len - len(seq))])
        # 对标签进行填充,用'O'的索引填充
        labels.extend([tag_to_idx['O'] for i in range(max_len - len(seq))])
        # 添加填充后的序列和标签到列表
        seq_outputs.append(idxs)
        label_outputs.append(labels)

    # 将列表转换为MindSpore张量
    return ms.Tensor(seq_outputs, ms.int64), \
           ms.Tensor(label_outputs, ms.int64), \
           ms.Tensor(seq_length, ms.int64)

# 调用prepare_sequence函数处理训练数据,并获取处理后的数据、标签和序列长度
data, label, seq_length = prepare_sequence(training_data, word_to_idx, tag_to_idx)

# 打印处理后的数据、标签和序列长度的形状,以确认数据转换是否正确
print(data.shape, label.shape, seq_length.shape)

对模型进行预编译后,训练500个step。

训练流程可视化依赖tqdm库,可使用pip install tqdm命令安装。

from tqdm import tqdm

# 定义训练步骤的总数,用于进度条的设置
steps = 500

# 使用tqdm创建一个进度条,总进度为steps
with tqdm(total=steps) as t:
    for i in range(steps):
        # 执行单步训练,这里假设train_step是一个已定义的训练函数
        # 参数data为训练数据,seq_length为序列长度,label为标签
        loss = train_step(data, seq_length, label)
        
        # 更新进度条的附带信息,显示当前的损失值
        t.set_postfix(loss=loss)
        
        # 更新进度条,表示完成了一步训练
        t.update(1)

最后我们来观察训练500个step后的模型效果,首先使用模型预测可能的路径得分以及候选序列。

# 调用模型进行预测或评估,返回得分和历史记录
score, history = model(data, seq_length)

# 输出得分,用于查看模型的表现或决策
score

使用后处理函数进行预测得分的后处理。

predict = post_decode(score, history, seq_length)
predict

最后将预测的index序列转换为标签序列,打印输出结果,查看效果。

# 通过索引和标签的映射关系,构建标签到索引的反向映射
idx_to_tag = {idx: tag for tag, idx in tag_to_idx.items()}

def sequence_to_tag(sequences, idx_to_tag):
    """
    将序列中的索引转换为对应的标签。
    
    参数:
    sequences: 一个包含标签索引的序列列表。
    idx_to_tag: 一个字典,用于将索引映射到对应的标签。
    
    返回:
    一个列表,其中每个元素是输入序列中索引转换为标签后的结果。
    """
    # 初始化一个空列表,用于存储转换后的标签序列
    outputs = []
    # 遍历输入的序列列表
    for seq in sequences:
        # 对每个序列,将索引转换为标签,并添加到输出列表中
        outputs.append([idx_to_tag[i] for i in seq])
    # 返回转换后的标签序列列表
    return outputs

sequence_to_tag(predict, idx_to_tag)

打卡照片:
在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mfbz.cn/a/796878.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

java各种锁介绍

在 Java 中&#xff0c;锁是用来控制多个线程对共享资源进行访问的机制。主要有以下几种类型的锁&#xff1a; 1.互斥锁&#xff08;Mutex Lock)&#xff1a;最简单的锁&#xff0c;一次只允许一个线程访问共享资源。如果一个线程获得了锁&#xff0c;其他线程必须等待锁被释放…

深度解读李彦宏的“不要卷模型,要卷应用”

深度解读李彦宏的“不要卷模型&#xff0c;要卷应用” —— AI技术的应用之道 引言 在2024世界人工智能大会的舞台上&#xff0c;李彦宏的“不要卷模型&#xff0c;要卷应用”言论犹如一石激起千层浪&#xff0c;引发了业界对AI技术发展路径的深思。本文将深入探讨这一观点&a…

JAVA设计模式>>结构型>>适配器模式

本文介绍23种设计模式中结构型模式的适配器模式 目录 1. 适配器模式 1.1 基本介绍 1.2 工作原理 1.3 适配器模式的注意事项和细节 1.4 类适配器模式 1.4.1 类适配器模式介绍 1.4.2 应用实例 1.4.3 注意事项和细节 1.5 对象适配器模式 1.5.1 基本介绍 1.5.2 …

解答|服务器只能开22端口可以申请IP地址SSL证书吗?

IP地址SSL证书&#xff0c;是一种专门颁发给公网IP地址的SSL证书&#xff0c;而不是常见的基于域名的SSL证书。SSL证书主要用于保障数据在客户端&#xff08;如用户的浏览器&#xff09;和服务器之间传输时的加密性和安全性&#xff0c;以防止数据被截取或篡改。 服务器只能开…

Python + OpenCV 简单车辆统计

目录 1 源码 2 运行结果 Python OpenCV 简单车辆统计 IDE : PyChram 1 源码 函数 car_count() 简单车辆统计 # 这是一个示例 Python 脚本。# 按 ShiftF10 执行或将其替换为您的代码。 # 按 双击 Shift 在所有地方搜索类、文件、工具窗口、操作和设置。 import cv2 impor…

windows远程桌面到 Linux系统(Ubuntu:22.04)—— 安装xrdp软件

1、在Linux系统上安装xrdp软件 sudo apt update sudo apt install xrdp2、安装完成后&#xff0c;需要开启xrdp服务 sudo systemctl start xrdp sudo systemctl enable xrdp打印返回 Synchronizing state of xrdp.service with SysV service script with /lib/systemd/system…

计算机网络通信

1、最原始的hub结构 2、局域网的交换机&#xff1a;mac和交换机端口路由表-数据链路层 mac地址 3、不同局域网之间进行通信&#xff0c;主要是路由器-网络层-ip 源ip到目标ip的不变化&#xff0c;但是mac地址在一直变化

【qt】TCP客户端如何断开连接?

disconnectFromHost() 来关闭套接字,断开连接. 当我们关闭窗口时,也需要断开连接. 需要重写关闭事件 如果当前的套接字状态是连接上的,我们就可以来断开连接. 运行结果:

C++ //练习 15.6 将Quote和Bulk_quote的对象传给15.2.1节(第529页)练习中的print_total函数,检查该函数是否正确。

C Primer&#xff08;第5版&#xff09; 练习 15.6 练习 15.6 将Quote和Bulk_quote的对象传给15.2.1节&#xff08;第529页&#xff09;练习中的print_total函数&#xff0c;检查该函数是否正确。 环境&#xff1a;Linux Ubuntu&#xff08;云服务器&#xff09; ## 工具&am…

FastAPI 学习之路(三十七)元数据和文档 URL

实现前的效果 那么如何实现呢&#xff0c;第一种方式如下&#xff1a; from routers.items import item_router from routers.users import user_router""" 自定义FastApi应用中的元数据配置Title&#xff1a;在 OpenAPI 和自动 API 文档用户界面中作为 API 的…

百日筑基第二十天-一头扎进消息队列3-RabbitMQ

百日筑基第二十天-一头扎进消息队列3-RabbitMQ 如上图所示&#xff0c;RabbitMQ 由 Producer、Broker、Consumer 三个大模块组成。生产者将数据发送到 Broker&#xff0c;Broker 接收到数据后&#xff0c;将数据存储到对应的 Queue 里面&#xff0c;消费者从不同的 Queue 消费数…

一个极简的 Vue 示例

https://andi.cn/page/621516.html

HSP_15章 Python_模板设计模式和oop进阶总结

P136 模板设计模式 1. 设计模式简介 设计模式是在大量的实践中总结和理论化之后优选的代码结构、编程风格、以及解决问题的思考方式 设计模式就像是经典的棋谱&#xff0c;不同的棋局&#xff0c;我们用不同的棋谱&#xff0c;免去我们自己再思考和摸索 2. 模板设计模式 基本…

linux查看目录下的文件夹命令,find 查找某个目录,但是不包括这个目录本身?

linux查看目录下的文件夹命令&#xff0c;find 查找某个目录&#xff0c;但是不包括这个目录本身&#xff1f; Linux中查看目录下的文件夹的命令是使用ls命令。ls命令用于列出指定目录中的文件和文件夹。通过不同的选项可以实现显示详细信息、按照不同的排序方式以及使用不同的…

Python爬虫之路(2):爬天气情况

hello hello~ &#xff0c;这里是绝命Coding——老白~&#x1f496;&#x1f496; &#xff0c;欢迎大家点赞&#x1f973;&#x1f973;关注&#x1f4a5;&#x1f4a5;收藏&#x1f339;&#x1f339;&#x1f339; &#x1f4a5;个人主页&#xff1a;绝命Coding-CSDN博客 &a…

卷积神经网络可视化的探索

文章目录 训练LeNet模型下载FashionMNIST数据训练保存模型 卷积神经网络可视化加载模型一个测试图像不同层对图像处理的可视化第一个卷积层的处理第二个卷积层的处理 卷积神经网络是利用图像空间结构的一种深度学习网络架构&#xff0c;图像在经过卷积层、激活层、池化层、全连…

Android liveData 监听异常,fragment可见时才收到回调记录

背景&#xff1a;在app的fragment不可见的情况下使用&#xff0c;发现注册了&#xff0c;但是没有回调导致数据一直未更新&#xff0c;只有在fragment可见的时候才收到回调 // 观察通用信息mLightNaviTopViewModel.getUpdateCommonInfo().observe(this, new Observer<Common…

13--memcache与redis

前言&#xff1a;数据库读取速度较慢一直是无法解决的问题&#xff0c;大型网站应对的方式主要是使用缓存服务器来缓解这种情况&#xff0c;减少数据库访问次数&#xff0c;以提高动态Web等应用的速度、提高可扩展性。 1、简介 Memcached/redis是高性能的分布式内存缓存服务器…

JVM:字节码文件

文章目录 一、Java虚拟机的组成二、字节码文件的组成1、基本信息2、常量池3、字段4、方法5、属性 三、常用的字节码工具1、javap -v 命令2、jclasslib插件3、阿里arthas 一、Java虚拟机的组成 二、字节码文件的组成 1、基本信息 魔数、字节码文件对应的Java版本号访问标识&am…

走进linux

1、为什么要使用linux 稳定性和可靠性&#xff1a; Linux内核以其稳定性而闻名&#xff0c;能够持续运行数月甚至数年而不需要重新启动。这对于服务器来说至关重要&#xff0c;因为它们需要保持长时间的稳定运行&#xff0c;以提供持续的服务 安全性&#xff1a; Linux系统…