NLP(16)--生成式任务

前言

仅记录学习过程,有问题欢迎讨论

  输入输出均为不定长序列(seq2seq)

自回归语言模型:

  • x 为 str[start : end ]; y为 [start+1 : end +1] 同时训练多个字,逐字计算交叉熵

encode-decode结构:

  • Encoder将输入转化为向量或矩阵,其中包含了输入中的信息
  • Decoder将Encoder的输出转化为输出

attention机制

  • 输入和输出应该和重点句子强相关,给输入加权(所以维度应该和输入的size一致)
  • 在这里插入图片描述

Teacher forcing

  • 使用真实标签作为下一个输入(自回归语言模型就是使用的teacher forcing)

Transform结构

  • Query来自Decode ,KV来自Encode
  • 在这里插入图片描述

使用Mask Attation 来避免对output做计算时,获取了所有的信息。只使用当前的位置对应的output信息。(自回归模型,先mask,然后在softmax)
在这里插入图片描述

评价指标:

  • BLEU:按照输出的字符计算一系列的数学(惩罚机制,Ngrim)计算来评价相似性

采样:

  • Beam size:
    保留概率最大的n条路径

  • Temperature Sampling
    根据概率分布生成下一个词,通过参数T,T越大,结果越随机,分布更均匀

  • TOP-P/K
    采样先按概率从大到小排序,累加概率不超过P的范围中选
    采样从TOP-K中采样下一个词

代码

使用bert实现自回归训练模型,
添加mask attention 来实现

# coding:utf8

import torch
import torch.nn as nn
import numpy as np
import math
import random
import os
import re

from transformers import BertModel, BertTokenizer

"""
基于pytorch的LSTM语言模型
"""


class LanguageModel(nn.Module):
    def __init__(self, input_dim, vocab_size):
        super(LanguageModel, self).__init__()
        # self.embedding = nn.Embedding(len(vocab), input_dim)
        # self.layer = nn.LSTM(input_dim, input_dim, num_layers=1, batch_first=True)
        self.bert = BertModel.from_pretrained(r"D:\NLP\video\第六周\bert-base-chinese", return_dict=False)
        self.classify = nn.Linear(input_dim, vocab_size)
        # self.dropout = nn.Dropout(0.1)
        self.loss = nn.functional.cross_entropy

    # 当输入真实标签,返回loss值;无真实标签,返回预测值
    def forward(self, x, y=None):
        # x = self.embedding(x)  # output shape:(batch_size, sen_len, input_dim)
        # 使用mask来防止提前预知结果
        if y is not None:
            # 构建一个下三角的mask
            # bert的mask attention 为(batch_size, vocab_size, vocab_size) L*L
            mask = torch.tril(torch.ones(x.shape[0], x.shape[1], x.shape[1]))
            print(mask)
            x, _ = self.bert(x, attention_mask=mask)
            y_pred = self.classify(x)
            return self.loss(y_pred.view(-1, y_pred.shape[-1]), y.view(-1))
        else:
            x = self.bert(x)[0]
            y_pred = self.classify(x)
            return torch.softmax(y_pred, dim=-1)


# 加载字表
def build_vocab(vocab_path):
    vocab = {"<pad>": 0}
    with open(vocab_path, encoding="utf8") as f:
        for index, line in enumerate(f):
            char = line[:-1]  # 去掉结尾换行符
            vocab[char] = index + 1  # 留出0位给pad token
    return vocab


# 加载语料
def load_corpus(path):
    corpus = ""
    with open(path, encoding="utf8") as f:
        for line in f:
            corpus += line.strip()
    return corpus


# 随机生成一个样本
# 从文本中截取随机窗口,前n个字作为输入,最后一个字作为输出
def build_sample(tokenizer, window_size, corpus):
    start = random.randint(0, len(corpus) - 1 - window_size)
    end = start + window_size
    window = corpus[start:end]
    target = corpus[start + 1:end + 1]  # 输入输出错开一位
    # print(window, target)
    # 中文的文本转化为tokenizer的id
    input_ids_x = tokenizer.encode(window, add_special_tokens=False, padding='max_length', truncation=True,
                                   max_length=10)
    input_ids_y = tokenizer.encode(target, add_special_tokens=False, padding='max_length', truncation=True,
                                   max_length=10)
    return input_ids_x, input_ids_y


# 建立数据集
# sample_length 输入需要的样本数量。需要多少生成多少
# vocab 词表
# window_size 样本长度
# corpus 语料字符串
def build_dataset(sample_length, tokenizer, window_size, corpus):
    dataset_x = []
    dataset_y = []
    for i in range(sample_length):
        x, y = build_sample(tokenizer, window_size, corpus)
        dataset_x.append(x)
        dataset_y.append(y)
    return torch.LongTensor(dataset_x), torch.LongTensor(dataset_y)


# 建立模型
def build_model(vocab_size, char_dim):
    model = LanguageModel(char_dim, vocab_size)
    return model


# 文本生成测试代码
def generate_sentence(openings, model, tokenizer, window_size):
    # reverse_vocab = dict((y, x) for x, y in vocab.items())
    model.eval()
    with torch.no_grad():
        pred_char = ""
        # 生成文本超过30字终止
        while len(openings) <= 30:
            openings += pred_char
            x = tokenizer.encode(openings, add_special_tokens=False, padding='max_length', truncation=True,
                                 max_length=10)
            x = torch.LongTensor([x])
            if torch.cuda.is_available():
                x = x.cuda()
            # batch_size = 1 最后一个字符的概率
            y = model(x)[0][-1]
            index = sampling_strategy(y)
            # 转化为中文 只有一个字符
            pred_char = tokenizer.decode(index)
    return openings


# 采样方式
def sampling_strategy(prob_distribution):
    if random.random() > 0.1:
        strategy = "greedy"
    else:
        strategy = "sampling"
    if strategy == "greedy":
        return int(torch.argmax(prob_distribution))
    elif strategy == "sampling":
        prob_distribution = prob_distribution.cpu().numpy()
        return np.random.choice(list(range(len(prob_distribution))), p=prob_distribution)


# 计算文本ppl
def calc_perplexity(sentence, model, vocab, window_size):
    prob = 0
    model.eval()
    with torch.no_grad():
        for i in range(1, len(sentence)):
            start = max(0, i - window_size)
            window = sentence[start:i]
            x = [vocab.get(char, vocab["<UNK>"]) for char in window]
            x = torch.LongTensor([x])
            target = sentence[i]
            target_index = vocab.get(target, vocab["<UNK>"])
            if torch.cuda.is_available():
                x = x.cuda()
            pred_prob_distribute = model(x)[0][-1]
            target_prob = pred_prob_distribute[target_index]
            prob += math.log(target_prob, 10)
    return 2 ** (prob * (-1 / len(sentence)))


def train(corpus_path, save_weight=True):
    epoch_num = 15  # 训练轮数
    batch_size = 64  # 每次训练样本个数
    train_sample = 10000  # 每轮训练总共训练的样本总数
    char_dim = 768  # 每个字的维度
    window_size = 10  # 样本文本长度
    # vocab = build_vocab(r"vocab.txt")  # 建立字表
    tokenizer = BertTokenizer.from_pretrained(r"D:\NLP\video\第六周\bert-base-chinese")
    vocab_size = 21128
    corpus = load_corpus(corpus_path)  # 加载语料
    model = build_model(vocab_size, char_dim)  # 建立模型
    if torch.cuda.is_available():
        model = model.cuda()
    optim = torch.optim.Adam(model.parameters(), lr=0.001)  # 建立优化器
    print("文本词表模型加载完毕,开始训练")
    for epoch in range(epoch_num):
        model.train()
        watch_loss = []
        for batch in range(int(train_sample / batch_size)):
            x, y = build_dataset(batch_size, tokenizer, window_size, corpus)  # 构建一组训练样本
            if torch.cuda.is_available():
                x, y = x.cuda(), y.cuda()
            optim.zero_grad()  # 梯度归零
            loss = model(x, y)  # 计算loss
            loss.backward()  # 计算梯度
            optim.step()  # 更新权重
            watch_loss.append(loss.item())
        print("=========\n第%d轮平均loss:%f" % (epoch + 1, np.mean(watch_loss)))
        print(generate_sentence("忽然一阵狂风吹过,他直接", model, tokenizer, window_size))
        print(generate_sentence("天青色等烟雨,而我在", model, tokenizer, window_size))
    if not save_weight:
        return
    else:
        base_name = os.path.basename(corpus_path).replace("txt", "pth")
        model_path = os.path.join("model", base_name)
        torch.save(model.state_dict(), model_path)
        return


if __name__ == "__main__":
    train("corpus.txt", False)

    # mask = torch.tril(torch.ones(4, 4)).unsqueeze(0).unsqueeze(0)
    # print(mask)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/637895.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

设计模式——概述

1.设计模式定义 ​ 设计模式是软件设计中常见问题的典型解决方案,可用于解决代码中反复出现的设计问题。设计模式的出现可以让我们站在前人的肩膀上&#xff0c;通过一些成熟的设计方案来指导新项目的开发和设计&#xff0c;以便于我们开发出具有更好的灵活性和可扩展性&#…

如何在window中快速建立多个文件夹?

时隔多年&#xff0c;再次开始撰写文章&#xff0c;但是这次却是以设计师的身份 1. 几个基础快捷键先记一下&#xff0c;要不更高级的玩儿不转&#xff08;1&#xff09;快速打开资源管理器&#xff08;2&#xff09;快速建立新文件夹&#xff08;3&#xff09;快速修改文件文件…

【openlayers系统学习】1.1渲染GeoJSON,添加link交互

一、渲染GeoJSON 在进入编辑之前&#xff0c;我们将看一下使用矢量源和图层进行基本要素渲染。Workshop在 data​ 目录中包含一个 countries.json​ GeoJSON文件。我们首先加载该数据并将其渲染在地图上。 首先&#xff0c;编辑 index.html​ 以便向地图添加深色背景&#xf…

树洞陪聊系统源码/陪聊/陪玩/树洞/陪陪/公众号开发/源码交付/树洞系统源码

独立版本源码交付&#xff0c;自研UI和前后端代码 平台自带店员&#xff0c;无需自主招募&#xff0c;搭建直接运营 支持三方登录&#xff0c;官方支付、虎皮椒、易支付/码支付 支持首单体验、盲盒订单、指定下单等多个模式 支持钱包预充值、店员收藏、订单评价等功能 支持…

网页加载时,大图片文件如何分片加载,有示例代码。

浏览网页时候&#xff0c;碰到大图片半天加载不出来&#xff0c;急死人&#xff0c;本问分享一种分片加载的方式&#xff0c;其实还有其他方式&#xff0c;比如先模糊后清晰等。 一、为什么要分片加载 大图片文件可以通过分片加载来提高加载性能和用户体验。分片加载的基本思…

智能禁区监控:计算机视觉在人员禁区闯入检测中的应用

基于视觉分析的人员禁区闯入行为检测算法主要依赖于计算机视觉技术和深度学习算法。这些技术结合高性能的摄像头和图像处理硬件&#xff0c;实现了对监控区域内人员行为的自动识别和分析。具体来说&#xff0c;这种检测算法利用摄像头捕捉的视频数据&#xff0c;通过深度学习模…

科技前沿:IDEA插件Translation v3.6 带来革命性更新,翻译和发音更智能!

博主猫头虎的技术世界 &#x1f31f; 欢迎来到猫头虎的博客 — 探索技术的无限可能&#xff01; 专栏链接&#xff1a; &#x1f517; 精选专栏&#xff1a; 《面试题大全》 — 面试准备的宝典&#xff01;《IDEA开发秘籍》 — 提升你的IDEA技能&#xff01;《100天精通鸿蒙》 …

leetcode124 二叉树中的最大路径和-dp

题目 二叉树中的 路径 被定义为一条节点序列&#xff0c;序列中每对相邻节点之间都存在一条边。同一个节点在一条路径序列中 至多出现一次 。该路径 至少包含一个 节点&#xff0c;且不一定经过根节点。 路径和 是路径中各节点值的总和。 给你一个二叉树的根节点 root &…

50.WEB渗透测试-信息收集-CDN识别绕过(3)

免责声明&#xff1a;内容仅供学习参考&#xff0c;请合法利用知识&#xff0c;禁止进行违法犯罪活动&#xff01; 内容参考于&#xff1a; 易锦网校会员专享课 上一个内容&#xff1a;49.WEB渗透测试-信息收集-CDN识别绕过&#xff08;2&#xff09; 关于cdn的识别方法内容…

基于SpringBoot的社区医院管理系统

基于SpringBootVue的社区医院管理系统的设计与实现~ 开发语言&#xff1a;Java数据库&#xff1a;MySQL技术&#xff1a;SpringBootMyBatis工具&#xff1a;IDEA/Ecilpse、Navicat、Maven 系统展示 首页 医生预约 管理员界面 医生界面 摘要 基于Spring Boot的社区医院管理系…

linux命令日常使用思考

linux命令日常使用思考 复制的相关问题scp和cp的区别root192.168.5.229-r的理解 更新版本的相关问题svn info 根目录和家目录的区别根目录家目录 复制的相关问题 scp和cp的区别 安全性&#xff1a;SCP 是基于 SSH 的加密传输协议&#xff0c;可以保证数据在传输过程中的安全性…

揭秘网红老阳的选品师项目:从选品到赚钱的全方位解析

在快节奏的互联网时代&#xff0c;网红隋总以其独特的洞察力和前瞻性&#xff0c;为我们揭示了人力RPO(招聘流程外包)项目背后的变革与机遇。这次&#xff0c;我们不再单纯地从市场或企业的角度来探讨这个项目&#xff0c;而是从更宏观的视角&#xff0c;看看它如何推动了人力资…

Python 调整PDF文件的页面大小

在处理PDF文件时&#xff0c;我们可能会遇到这样的情况&#xff1a;原始PDF文档不符合我们的阅读习惯&#xff0c;或者需要适配不同显示设备等。这时&#xff0c;我们就需要及时调整PDF文档中的页面尺寸&#xff0c;以满足不同应用场景的需求。 利用Python语言的高效性和灵活性…

关于redis设置的密码不生效问题

今天申请了阿里云使用3个月的服务器&#xff0c;于是想在服务器上部署一下自己的项目&#xff0c;但是吸取了上次的教训&#xff0c;再也不敢随便开放redis的端口号了&#xff0c;就算要开放redis的端口&#xff0c;也要设置密码&#xff0c;保证不会被挖矿病毒通过redis入侵服…

自用升级centos7.2的默认Python 2.7.5为python3.8

wget https://www.python.org/ftp/python/3.8.8/Python-3.8.8.tgztar zxvf Python-3.8.8.tgz 进入刚刚解压后的目录 ./configure --prefix/data/soft/python3按照上面截图所属&#xff0c;需要安装gcc 安装报错需要安装 sudo yum install zlib1g-dev make -j4 make install -…

VBA语言専攻每周通知20240524

通知20240524 各位学员∶本周MF系列VBA技术资料增加611-615讲&#xff0c;T3学员看到通知后请免费领取,领取时间5月24日晚上18:00-5月26日晚上18:00。本次增加内容&#xff1a; MF611:用InputBox录入日期 MF612:信息提示10秒后关自动关闭 MF613:只是信息提示10秒 MF614:显…

VUE2 tab切换导航 展示页面内容(父级子级独立)

VUE2 tab切换导航 展示页面内容 父级子级独立 图片示例代码 图片示例 代码 <template><div class"center"><!-- 一级导航 --><div class"menu"><div class"menu_list"><div v-for"item of List" :k…

新定义RD8T36P48使用USCI0的TWI功能点亮OLED

时间不多&#xff0c;因此先只给出工程&#xff0c;等有时间再添加详细说明 现象 这是从之前的一个51单片机的程序移植过来的&#xff0c;主要修改了IIC启动和停止&#xff0c;以及数据发送的代码&#xff0c;我现在还不是很满意的一点是发送过程中要等待上一个字节发送完才能…

CDH6.3.2集成Flink1.17

直接运行脚本即可&#xff0c;一键输出相关依赖包 运行步骤已给到文档 下载地址

如何灵活运用keil工具进行问题分析(1)— 解决日常程序卡死问题

前言 &#xff08;1&#xff09;如果有嵌入式企业需要招聘湖南区域日常实习生&#xff0c;任何区域的暑假Linux驱动实习岗位&#xff0c;可C站直接私聊&#xff0c;或者邮件&#xff1a;zhangyixu02gmail.com&#xff0c;此消息至2025年1月1日前均有效 &#xff08;2&#xff0…