基于prefix tuning + Bert的标题党分类器

文章目录

  • 背景
  • 一、Prefix-Tuning介绍
  • 二、分类
  • 三、效果
  • 四、参阅

背景

近期, CSDN博客推荐流的标题党博客又多了起来, 先前的基于TextCNN版本的分类模型在语义理解上能力有限, 于是, 便使用的更大的模型来优化, 最终准确率达到了93.7%, 还不错吧.

一、Prefix-Tuning介绍

传统的fine-tuning是在大规模预训练语言模型(如Bert、GPT2等)上完成的, 针对不同的下游任务, 需要保存不同的模型参数, 代价比较高,

解决这个问题的一种自然方法是轻量微调(lightweight fine-tunning),它冻结了大部分预训练参数,并用小的可训练模块来增强模型,比如在预先训练的语言模型层之间插入额外的特定任务层。适配器微调(Adapter-tunning)在自然语言理解和生成基准测试上具有很好的性能,通过微调,仅添加约2-4%的任务特定参数,就可以获得类似的性能。

受到prompt的启发,提出一种prefix-tuning, 只需要保存prefix部分的参数即可.

相关论文: Prefix-Tuning: Optimizing Continuous Prompts for Generation

请添加图片描述

论文中, 作者使用Prefix-tuning做生成任务,它根据不同的模型结构定义了不同的Prompt拼接方式.

请添加图片描述

对于自回归模型,加入前缀后的模型输入表示:

z = [PREFIX; x; y]

对于编解码器结构的模型,加入前缀后的模型输入表示:

z = [PREFIX; x; PREFIX; y]

原理部分不过多赘述, 对于Prompt不太熟悉的同学, 一定要看看王嘉宁老师写的综述: Prompt-Tuning——深度解读一种新的微调范式

与Prefix-Tuning类似的方法还有P-tuning V2,不同之处在于Prefix-Tuning是面向文本生成领域的,P-tuning V2面向自然语言理解.

二、与Bert对比

Bert

BertForSequenceClassification(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(21128, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0-11): 12 x BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
          )
          (intermediate): BertIntermediate(
            (dense): Linear(in_features=768, out_features=3072, bias=True)
            (intermediate_act_fn): GELUActivation()
          )
          (output): BertOutput(
            (dense): Linear(in_features=3072, out_features=768, bias=True)
            (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
            (dropout): Dropout(p=0.1, inplace=False)
          )
        )
      )
    )
    (pooler): BertPooler(
      (dense): Linear(in_features=768, out_features=768, bias=True)
      (activation): Tanh()
    )
  )
  (dropout): Dropout(p=0.1, inplace=False)
  (classifier): Linear(in_features=768, out_features=2, bias=True)
)

因为我这里是二分类, 所以最后一层多了个Linear层

Prefix Tuning Bert

PeftModelForSequenceClassification(
  (base_model): BertForSequenceClassification(
    (bert): BertModel(
      (embeddings): BertEmbeddings(
        (word_embeddings): Embedding(21128, 768, padding_idx=0)
        (position_embeddings): Embedding(512, 768)
        (token_type_embeddings): Embedding(2, 768)
        (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
        (dropout): Dropout(p=0.1, inplace=False)
      )
      (encoder): BertEncoder(
        (layer): ModuleList(
          (0-11): 12 x BertLayer(
            (attention): BertAttention(
              (self): BertSelfAttention(
                (query): Linear(in_features=768, out_features=768, bias=True)
                (key): Linear(in_features=768, out_features=768, bias=True)
                (value): Linear(in_features=768, out_features=768, bias=True)
                (dropout): Dropout(p=0.1, inplace=False)
              )
              (output): BertSelfOutput(
                (dense): Linear(in_features=768, out_features=768, bias=True)
                (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
                (dropout): Dropout(p=0.1, inplace=False)
              )
            )
            (intermediate): BertIntermediate(
              (dense): Linear(in_features=768, out_features=3072, bias=True)
              (intermediate_act_fn): GELUActivation()
            )
            (output): BertOutput(
              (dense): Linear(in_features=3072, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
          )
        )
      )
      (pooler): BertPooler(
        (dense): Linear(in_features=768, out_features=768, bias=True)
        (activation): Tanh()
      )
    )
    (dropout): Dropout(p=0.1, inplace=False)
    (classifier): ModulesToSaveWrapper(
      (original_module): Linear(in_features=768, out_features=2, bias=True)
      (modules_to_save): ModuleDict(
        (default): Linear(in_features=768, out_features=2, bias=True)
      )
    )
  )
  (prompt_encoder): ModuleDict(
    (default): PrefixEncoder(
      (embedding): Embedding(20, 18432)
    )
  )
  (word_embeddings): Embedding(21128, 768, padding_idx=0)
)

比Bert多了prompt_encoder, 查阅peft源码可以看到PrefixEncoder的实现:

class PrefixEncoder(torch.nn.Module):
    def __init__(self, config):
        super().__init__()
        self.prefix_projection = config.prefix_projection
        token_dim = config.token_dim
        num_layers = config.num_layers
        encoder_hidden_size = config.encoder_hidden_size
        num_virtual_tokens = config.num_virtual_tokens
        if self.prefix_projection and not config.inference_mode:
            # Use a two-layer MLP to encode the prefix
            self.embedding = torch.nn.Embedding(num_virtual_tokens, token_dim)
            self.transform = torch.nn.Sequential(
                torch.nn.Linear(token_dim, encoder_hidden_size),
                torch.nn.Tanh(),
                torch.nn.Linear(encoder_hidden_size, num_layers * 2 * token_dim),
            )
        else:
            self.embedding = torch.nn.Embedding(num_virtual_tokens, num_layers * 2 * token_dim)

    def forward(self, prefix: torch.Tensor):
        if self.prefix_projection:
            prefix_tokens = self.embedding(prefix)
            past_key_values = self.transform(prefix_tokens)
        else:
            past_key_values = self.embedding(prefix)
        return past_key_values

其实就是多了个embedding层

二、分类

好在huggingface的peft库对几个经典的prompt tuning都有封装, 实现起来并不难:



import logging
from peft import (
    get_peft_config,
    get_peft_model,
    get_peft_model_state_dict,
    set_peft_model_state_dict,
    PeftType,
    PrefixTuningConfig,
    PromptEncoderConfig,
    PromptTuningConfig,
    LoraConfig,
)
from torch.optim import AdamW
from transformers import AutoModelForSequenceClassification, AutoTokenizer, get_linear_schedule_with_warmup, set_seed
from common.path.model.bert import get_chinese_roberta_wwm_ext

import os
import time
import torch
import math
from torch.utils.data import DataLoader
from datasets import load_dataset

import evaluate
from tqdm import tqdm
from server.tag.common.data_helper import TagClassifyDataHelper

logger = logging.getLogger(__name__)



class BlogTitleAttractorBertConfig(object):
    max_length = 32
    batch_size = 128
    p_type = "prefix-tuning"
    model_name_or_path = get_chinese_roberta_wwm_ext()
    lr = 3e-4
    num_epochs = 50
    num_labels = 2
    device = "cuda"   # cuda/cpu
    evaluate_every = 100
    print_every = 10


def get_tokenizer(model_config):
    if any(k in model_config.model_name_or_path for k in ("gpt", "opt", "bloom")):
        padding_side = "left"
    else:
        padding_side = "right"
    tokenizer = AutoTokenizer.from_pretrained(model_config.model_name_or_path, padding_side=padding_side)
    if getattr(tokenizer, "pad_token_id") is None:
        tokenizer.pad_token_id = tokenizer.eos_token_id
    return tokenizer


def get_model(model_config, peft_config):
    model = AutoModelForSequenceClassification.from_pretrained(model_config.model_name_or_path, num_labels=model_config.num_labels)
    if model_config.p_type is not None:
        model = get_peft_model(model, peft_config)
        model.print_trainable_parameters()
    return model


def get_peft_config(model_config):
    p_type = model_config.p_type
    if p_type == "prefix-tuning":
        peft_type = PeftType.PREFIX_TUNING
        peft_config = PrefixTuningConfig(task_type="SEQ_CLS", num_virtual_tokens=20)
    elif p_type == "prompt-tuning":
        peft_type = PeftType.PROMPT_TUNING
        peft_config = PromptTuningConfig(task_type="SEQ_CLS", num_virtual_tokens=20)
    elif p_type == "p-tuning":
        peft_type = PeftType.P_TUNING
        peft_config = PromptEncoderConfig(task_type="SEQ_CLS", num_virtual_tokens=20, encoder_hidden_size=128)
    elif p_type == "lora":
        peft_type = PeftType.LORA
        peft_config = LoraConfig(task_type="SEQ_CLS", inference_mode=False, r=8, lora_alpha=16, lora_dropout=0.1)
    else:
        print(f"p_type:{p_type} is not supported.")
        return None, None
    logger.info(f"训练: {p_type} bert 模型")
    return peft_type, peft_config


def get_lr_scheduler(model_config, model, data_size):
    optimizer = AdamW(params=model.parameters(), lr=model_config.lr)

    # Instantiate scheduler
    lr_scheduler = get_linear_schedule_with_warmup(
        optimizer=optimizer,
        num_warmup_steps=0.06 * (data_size * model_config.num_epochs),
        num_training_steps=(data_size * model_config.num_epochs),
    )
    return optimizer, lr_scheduler



class BlogTtileAttractorBertTrain(object):
    def __init__(self, config, options) -> None:
        self.config = config
        self.options = options
        self.model_path = "./data/models/blog_title_attractor/prefix-tuning"
        self.data_file = './data/datasets/blogs/title_attractor/train.csv'
        self.tag_path = "./data/models/blog_title_attractor/prefix-tuning/tag.txt"
        self.model_config = BlogTitleAttractorBertConfig()
        self.tag_dict, self.id2tag_dict = TagClassifyDataHelper.load_tag(self.tag_path)
        self.model_config.num_labels = len(self.tag_dict)

        _, peft_config = get_peft_config(self.model_config)
        self.tokenizer = get_tokenizer(self.model_config)
        self.model = get_model(self.model_config, peft_config)
        self.model.to(self.model_config.device)

    def replace_none(self, example):
        example["title"] = example["title"] if example["title"] is not None else ""
        example["label"] = self.tag_dict[example["label"].lower()]
        return example

    def load_data(self):
        # 加载数据集
        dataset = load_dataset("csv", data_files=self.data_file)
        dataset = dataset.filter(lambda x: x["title"] is not None)
        dataset = dataset["train"].train_test_split(0.2, seed=123)
        dataset = dataset.map(self.replace_none)
        return dataset
    
    def process_function(self, examples):
        tokenized_examples = self.tokenizer(examples["title"], truncation=True, max_length=self.model_config.max_length)
        return tokenized_examples
    
    def collate_fn(self, examples):
        return self.tokenizer.pad(examples, padding="longest", return_tensors="pt")

    def __call__(self):
        datasets = self.load_data()
        tokenized_datasets = datasets.map(self.process_function, batched=True, remove_columns=['title'])
        tokenized_datasets = tokenized_datasets.rename_column("label", "labels")
        train_dataloader = DataLoader(tokenized_datasets["train"], shuffle=True, collate_fn=self.collate_fn, batch_size=self.model_config.batch_size)
        eval_dataloader = DataLoader(
            tokenized_datasets["test"], shuffle=False, collate_fn=self.collate_fn, batch_size=self.model_config.batch_size
        )
        optimizer, lr_scheduler = get_lr_scheduler(self.model_config, self.model, len(train_dataloader))
        metric = evaluate.load("accuracy")
        start_time = time.time()
        best_loss = math.inf
        for epoch in range(self.model_config.num_epochs):
            self.model.train()
            for step, batch in enumerate(train_dataloader):
                batch.to(self.model_config.device)
                outputs = self.model(**batch)
                loss = outputs.loss
                loss.backward()
                optimizer.step()
                lr_scheduler.step()
                optimizer.zero_grad()
                if step > 0 and step % self.model_config.print_every == 0:
                    now_time = time.time()
                    predictions = outputs.logits.argmax(dim=-1)
                    eval_metric = metric.compute(predictions=predictions, references=batch["labels"])
                    if loss < best_loss:
                        best_loss = loss
                        self.model.save_pretrained(self.model_path)

                        print("***TRAIN, steps:{%d}, loss:{%.3f}, accuracy:{%.3f}, cost_time:{%.3f}h" % (
                            step, loss, eval_metric["accuracy"], (now_time - start_time) / 3600))
                    else:
                        print("TRAIN, steps:{%d}, loss:{%.3f}, accuracy:{%.3f}, cost_time:{%.3f}h" % (
                            step, loss, eval_metric["accuracy"], (now_time - start_time) / 3600))
                if step > 0 and step % self.model_config.evaluate_every == 0:
                    self.model.eval()
                    total_loss = 0.
                    for step, batch in enumerate(tqdm(eval_dataloader)):
                        batch.to(self.model_config.device)
                        with torch.no_grad():
                            outputs = self.model(**batch)
                            loss = outputs.loss
                            total_loss += loss
                        predictions = outputs.logits.argmax(dim=-1)
                        predictions, references = predictions, batch["labels"]
                        metric.add_batch(
                            predictions=predictions,
                            references=references,
                        )

                    eval_metric = metric.compute()
                    print("epoch:%d, VAL, loss:%.3f, accuracy:%.3f" % (epoch, total_loss, eval_metric["accuracy"]))

三、效果

Prefix Tuning Bert

在标题党上的准确率为: 0.9368932038834952
在非标题党上的准确率为: 0.937015503875969

Bert

在标题党上的准确率为: 0.7006472491909385
在非标题党上的准确率为: 0.9728682170542635

直接微调的效果与Prefix Tuning 相比差距有点大, 理论上不应该有这么大差距才对, 应该和数据集有关系

四、参阅

  • Prompt-Tuning——深度解读一种新的微调范式
  • Continuous Optimization:从Prefix-tuning到更强大的P-Tuning V2
  • [代码学习]Huggingface的peft库学习-part 1- prefix tuning

温馨提示: 标题党文章在推荐流会被降权哦

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/29664.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

搭建TiDB负载均衡环境-LVS+KeepAlived实践

作者&#xff1a; 我是咖啡哥 原文来源&#xff1a; https://tidb.net/blog/f614b200 昨天&#xff0c;发了一篇使用HAproxyKP搭建TiDB负载均衡环境的文章&#xff0c;今天我们再用LVSKP来做个实验。 环境信息 TiDB版本&#xff1a;V7.1.0 haproxy版本&#xff1a;2.6.2 …

Linux环境下的工具(yum,gdb,vim)

一&#xff0c;yum yum其实是linux环境下的一种应用商店&#xff0c;主要用centos等版本。它也有三板斧&#xff1a;yum list,yum remove,yum install。当然不是说他只有这三个命令&#xff0c;还有yum search等等。在这直说以上三个。 yum list其实是查看你所能安装的软件包…

13.常用类|Java学习笔记

文章目录 包装类包装类型和String类型的相互转换Integer类和Character类的常用方法Integer创建机制&面试题 String类创建String对象的两种方式和区别字符串的特性String类的常用方法 StringBuffer类String和StringBuffer相互转换StringBuffer常用方法 StringBuilder类Strin…

安卓手机使用Termux搭建Hexo个人博客网站【内网穿透公网访问】

文章目录 1. 安装 Hexo2. 安装cpolar内网穿透3. 公网远程访问4. 固定公网地址 转载自cpolar极点云的文章&#xff1a;安卓手机使用Termux搭建Hexo个人博客网站【内网穿透公网访问】 Hexo 是一个用 Nodejs 编写的快速、简洁且高效的博客框架。Hexo 使用 Markdown 解析文章&#…

R 语言学习笔记

1. 基础语法 赋值 a 10; b <- 10;# 表示流向&#xff0c;数据流向变量&#xff0c;也可以写成10 -> b创建不规则向量 不用纠结什么是向量&#xff0c;就当作一个容器&#xff0c;数据类型要相同 a c("我","爱","沛")创建一定规则的向…

意向共享锁和意向排他锁

InnoDB表级锁 在绝大部分情况下都应该使用行锁&#xff0c;因为事务和行锁往往是选择InnoDB的理由&#xff0c;但个别情况下也使用表级锁&#xff1a; 1&#xff09;事务需要更新大部分或全部数据&#xff0c;表又比较大&#xff0c;如果使用默认的行锁&#xff0c;不仅这个事…

前端web入门-CSS-day06

(创作不易&#xff0c;感谢有你&#xff0c;你的支持&#xff0c;就是我前行的最大动力&#xff0c;如果看完对你有帮助&#xff0c;请留下您的足迹&#xff09; 目录 一、标准流 二、Flex 布局 组成 主轴对齐方式 侧轴对齐方式 修改主轴方向 弹性伸缩比 弹性盒子换行…

uniapp中使用mixins(混入)使用

mixins 选项接收一个混入对象的数组。这些混入对象可以像正常的实例对象一样包含实例选项&#xff0c;这些选项将会被合并到最终的选项中&#xff0c;使用的是和 Vue.extend() 一样的选项合并逻辑。也就是说&#xff0c;如果你的混入包含一个 created 钩子&#xff0c;而创建组…

设计用户模块的schema

schema 在计算机科学中&#xff0c;schema通常指的是 数据结构的定义和约束。 关系型数据库 在关系型数据库中&#xff0c;schema指的是数据库中所有表格的定义和表格之间的关系约束&#xff0c;包括每个表格的列名、数据类型、主键、外键等等。 如果要对一个关系型数据库进行…

Leetcode-6425. 找到最长的半重复子字符串

题目描述 给你一个下标从 0 开始的字符串 s &#xff0c;这个字符串只包含 0 到 9 的数字字符。 如果一个字符串 t 中至多有一对相邻字符是相等的&#xff0c;那么称这个字符串是 半重复的 。 请你返回 s 中最长 半重复 子字符串的长度。 一个 子字符串 是一个字符串中一段…

力扣日记2481

1. 题目 LeetCode 2481. 分割圆的最少切割次数 1.1 题意 可以使用直接或半径切分&#xff0c;管他叫一次切分&#xff0c;求切分圆为n等份的最少次数。 1.2 分析 可以想到&#xff0c;对圆做n等分&#xff0c;然后每个半径看出一次切分&#xff0c;这是最多次数&#xff0c;…

Python3 列表与元组 | 菜鸟教程(六)

目录 一、Python3 列表 &#xff08;一&#xff09;简介相关 1、序列是 Python 中最基本的数据结构。 2、序列中的每个值都有对应的位置值&#xff0c;称之为索引&#xff0c;第一个索引是 0&#xff0c;第二个索引是 1&#xff0c;依此类推。 3、Python 有 6 个序列的内置…

算法刷题-字符串-替换空格

题目&#xff1a;剑指Offer 05.替换空格 力扣题目链接 请实现一个函数&#xff0c;把字符串 s 中的每个空格替换成"%20"。 示例 1&#xff1a; 输入&#xff1a;s “We are happy.” 输出&#xff1a;“We%20are%20happy.” 思路 如果想把这道题目做到极致&…

webpack提升开发体验SourceMap

一、开发场景介绍 开发中我们不可避免的会写一些bug出来&#xff0c;这时候要调试&#xff0c;快速定位到bug到底出现在哪尤为关键。 例如我故意在sum函数中写一个错误代码如下&#xff1a; 这时我们用前面章节已经写好的开发模式的webpack.dev.js运行&#xff0c;控制台会出…

【总结笔记】Spring

1 Spring容器加载配置文件进行初始化。 Spring容器加载配置文件进行初始化主要有两种形式&#xff1a; 加载配置文件进行初始化&#xff1a; ClassPathXmlApplicationContext ctx new ClassPathXmlApplicationContext(“ApplicationContext.xml”); 加载配置类进行初始化&…

业务流程自动化:ThinkAutomation Professional Crack

ThinkAutomation 助力您的业务流程自动化。自动执行本地和基于云的业务流程&#xff0c;以降低成本并节省时间。 自动化传入的通信渠道&#xff0c;监控数据库&#xff0c;对传入的Webhook&#xff0c;Web表单和聊天机器人做出反应。处理文档、附件、本地文件和其他邮件源。 …

基于Spark的气象数据分析

研究背景与方案 1.1.研究背景 在大数据时代背景下&#xff0c;各行业数据的规模大幅度增加&#xff0c;数据类别日益复杂&#xff0c;给数据分析工作带来极大挑战。气象行业和人们的生活息息相关&#xff0c;随着信息时代的发展&#xff0c;大数据技术的出现为气象数据的发展…

模板匹配笔记

模板匹配是一种最基本、最原始的模式识别的方法。通过对比某一特定物体的图案位于图像的什么地方&#xff0c;进而识别出物体。它是图像处理中最基本、最常用的匹配方法。它的局限性主要是它只能进行平行移动&#xff0c;若原图像中的匹配目标发生旋转或大小变化&#xff0c;该…

前端vue入门(纯代码)09

【09.vue中组件的自定义事件】 自定义组件链接 在vue中用的click【点击】、keyup【按键】……等事件&#xff0c;这些属于内置事件&#xff0c;也就是js自带的事件。 问题一&#xff1a;什么是组件自定义事件呢&#xff1f; 【内置事件】:是给html元素用的&#xff0c;比如s…

014、数据库管理之配置管理

配置管理 TiDB配置系统配置集群配置配置的存储位置区分TiDB的系统参数和集群参数 系统参数系统参数的作用域系统参数的修改 集群参数集群参数的修改配置参数的查看 实验一&#xff1a; 在不同作用域下对数据库的系统参数进行修改session级别global级别 实验二&#xff1a; 修改…