【Python】科研代码学习:十 evaluate (metrics,Evaluator)

【Python】科研代码学习:十 evaluate

  • Evaluate
    • 评估类型
    • 简单使用教程
    • 如何寻找想要的 `metric`
    • 使用 `Evaluator`
    • 与 `transformers.trainer` 配合使用
    • 疑问与下节预告

Evaluate

  • 【HF官网-Doc-Evaluate:API】
    看名字就可以知道,EvaluateHF 提供的便捷的库,方便输入 (模型,数据集,指标) 三元组就能快速得到该指标来。
    甚至可以和 transformers / keras / scikit-learn 等库配合使用
  • 【安装】
    (1)安装合适的虚拟环境,比如 anaconda 环境,要求 python>3.7
    (2)pip install evaluate 直接安装即可
    (3)想要看 github repo 的话:
git clone https://github.com/huggingface/evaluate.git
cd evaluate
pip install -e .

评估类型

  • 这是一个前置知识
    Metric(指标)【HF官网-metric列表】:使用一个指标计算模型的表现,一般需要使用一些真值标签。如列表中所示,常见的有 ROUGE / BLEU / Perplexity / BERT Score / Accuracy / F1 / Precision / Recall / MSE / MASE
    Comparison(对照):比较两个模型之间的表现
    Measurement(测量):模型与数据集都很重要,一般都是使用特定的数据集进行评估。
  • 对于 Metric (指标),又有三个分类:
    Generic metric(通用指标):比如 precision, accuracy 等,在各种数据集下都可以直接用
    Task-specific metric(特定任务指标):与任务相关,比如机器翻译常用的 BLEU / ROUGE,以及命名实体检测常用的 seqeval
    Data-specific metric(特定数据集指标):用某个特定的 benchmark 来评估模型的表现,比如 GLUE benchmark 等。
  • 画了张分类图
    在这里插入图片描述

简单使用教程

  • 第一步:加载一个指标
    也可以提供 module_type,特别是有重名的时候。
import evaluate
accuracy = evaluate.load("accuracy")

word_length = evaluate.load("word_length", module_type="measurement")
  • 第二步(可选):展示指标的额外信息
    主要就是 description 查看模组介绍,citation 获取 BibTexfeatures 获取输入格式
accuracy = evaluate.load("accuracy")
accuracy.description
accuracy.citation
accuracy.features
  • 第三步:计算
    使用 compute 方法计算。输入格式记得查看上面的 features 打印查看。
    我们有两种输入方式,一种是一口气式(All-in-one),一种是增量式(Incremental
  • 一口气式:
accuracy.compute(references=[0,1,0,1], predictions=[1,0,0,1])
  • 一口气式,使用 .add() 方法
for ref, pred in zip([0,1,0,1], [1,0,0,1]):
    accuracy.add(references=ref, predictions=pred)
accuracy.compute()
  • 增量式,使用 .add_batch
for refs, preds in zip([[0,1],[0,1]], [[1,0],[0,1]]):
    accuracy.add_batch(references=refs, predictions=preds)
accuracy.compute()
  • 增量式,在调用模型时很常用,因为一般数据是一条一条跑出来的,除非提前存好。
for model_inputs, gold_standards in evaluation_dataset:
    predictions = model(model_inputs)
    metric.add_batch(references=gold_standards, predictions=predictions)
metric.compute()
  • 第三步(可选):复合多种指标,使用 combine 方法
clf_metrics = evaluate.combine(["accuracy", "f1", "precision", "recall"])
clf_metrics.compute(predictions=[0, 1, 0], references=[0, 1, 1])
  • 第三步(可选):使用 Evaluator ,方便 (模型,数据集,指标) 三元组评估,在后文讲
  • 第四步(可选):使用官方提供的一个可视化,但看了下只支持 ComplexRadarradar_plot 类,就不赘述了。

如何寻找想要的 metric

  • 对于 Generic metric,可以直接在这个 【HF官网-metric列表】 里面查看,根据里面的 card 指引操作,
  • 对于 Task-specific metric,首先在 【HF官网-Task索引】 里面找到自己在做的任务,比如 QA
    在这里插入图片描述
  • 进入后,在右下角可以找到 Datasets for QA 以及 Metrics for QA,看到这里有 exact-match & F1
    在这里插入图片描述
  • 对于 Dataset-specific metric,我们在 HF 寻找相关的 dataset 后,在 Dataset Preview 或者 dataset card 中都可以看到具体的操作方式。

使用 Evaluator

  • 使用 Evaluator,可以方便评估 (模型,数据集,指标) 三元组
    目前官方提供支持的任务有:
"text-classification": will use the TextClassificationEvaluator.
"token-classification": will use the TokenClassificationEvaluator.
"question-answering": will use the QuestionAnsweringEvaluator.
"image-classification": will use the ImageClassificationEvaluator.
"text-generation": will use the TextGenerationEvaluator.
"text2text-generation": will use the Text2TextGenerationEvaluator.
"summarization": will use the SummarizationEvaluator.
"translation": will use the TranslationEvaluator.
"automatic-speech-recognition": will use the AutomaticSpeechRecognitionEvaluator.
"audio-classification": will use the AudioClassificationEvaluator.
  • 可以看到,使用不同的任务,就需要使用不同的 Evaluator 类。
    假设我们做 QA 任务,我们进到 QuestionAnsweringEvaluator 类中查询使用方法
  • 可以从源码或者API中看到,QuestionAnsweringEvaluator 类其实就是继承了 Evaluator 类,在初始化中设置了 task='question-answering'
    在这里插入图片描述
  • 并且重载了 compute 方法,这个肯定是唯一重要的方法,稍微介绍一下参数
    model_or_pipeline :也就是说,我们可以提供 PretrainedModel,也可以提供 Pipeline
    data:可以是 Dataset 类型,也可以提供字符串,表示数据集的名字(貌似这里是不支持提供本地路径的?)
    subset :如果数据集有子数据集的话,在这里提供它的名字
    split(str):怎么划分,比如可以提供 split="validation[:2]"
    metric (str or EvaluationModule):指标名
    tokenizer:如果我们提供的是 model 而不是 pipeline 的话,就需要提供 tokenizer
    strategy :如果设置成 "bootstrap" 的话,就可以设置 confidence_level 置信区间;否则默认为 "simple"
    device :显卡号。
    在这里插入图片描述
  • 例子:
from evaluate import evaluator
from datasets import load_dataset
task_evaluator = evaluator("question-answering")
data = load_dataset("squad", split="validation[:2]")
results = task_evaluator.compute(
    model_or_pipeline="sshleifer/tiny-distilbert-base-cased-distilled-squad",
    data=data,
    metric="squad",
)

transformers.trainer 配合使用

  • 前面的功能感觉还是有点鸡肋?在训练时和 trainer 配合使用,不就可以在训练的时候显示我需要的指标了嘛
  • 看下面的例子
    重点是,我们实例化了 metric = evaluate.load("accuracy"),表示要计算 accu
    然后,我们自定义了一个方法 compute_metrics,对于输入的预测 eval_pred,我们获取它的 logits, labels, predictions 值,然后用 metric.compute 方法,计算出指标。
    trainer 我们直接提供给它 compute_metrics 即可。
from datasets import load_dataset
from transformers import AutoTokenizer, AutoModelForSequenceClassification, TrainingArguments, Trainer
import numpy as np
import evaluate

# Prepare and tokenize dataset
dataset = load_dataset("yelp_review_full")
tokenizer = AutoTokenizer.from_pretrained("bert-base-cased")

def tokenize_function(examples):
    return tokenizer(examples["text"], padding="max_length", truncation=True)

tokenized_datasets = dataset.map(tokenize_function, batched=True)

small_train_dataset = tokenized_datasets["train"].shuffle(seed=42).select(range(200))
small_eval_dataset = tokenized_datasets["test"].shuffle(seed=42).select(range(200))

# Setup evaluation 
metric = evaluate.load("accuracy")

def compute_metrics(eval_pred):
    logits, labels = eval_pred
    predictions = np.argmax(logits, axis=-1)
    return metric.compute(predictions=predictions, references=labels)

# Load pretrained model and evaluate model after each epoch
model = AutoModelForSequenceClassification.from_pretrained("bert-base-cased", num_labels=5)
training_args = TrainingArguments(output_dir="test_trainer", evaluation_strategy="epoch")

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=small_train_dataset,
    eval_dataset=small_eval_dataset,
    compute_metrics=compute_metrics,
)

trainer.train()

疑问与下节预告

  • 看了上面这么多代码,感觉只能单卡跑啊?感觉意义突然不大了起来。
    HF 支持使用 accelerate 来多卡运行。在 trainer 里只要设置单卡的 batch 即可。
    另外可以学习 GPU Inference
    常用的加速工具还有 deepspeedHF 也是封装的比较好了。
  • 虽然能中间显示指标了,我怎么看到呢?
    最常用的在训练时展示指标的工具就是 wandb 了。
  • 还有微调常用的内容,PEFToptimization(optimizer, scheduler)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/455142.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

中国湿地沼泽分类分布数据集

数据下载链接:百度云下载链接 引言 随着经济社会的快速发展和城市化进程的加速推进,农业发生功能性转变,从单一生产功能向生产、生活、生态多功能服务首都经济社会发展转变。湿地与农田、草地、森林三大生态系统整合形成完整的现代农业生态服…

Linux环境(Ubuntu)上搭建MQTT服务器(EMQX )

目录 概述 1 认识EMQX 1.1 EMQX 简介 1.2 EMQX 版本类型 2 Ubuntu搭建EMQX 平台 2.1 下载和安装 2.1.1 下载 2.1.2 安装 2.2 查看运行端口 3 运行Dashboard 管理控制台 3.1 查看Ubuntu上的防火墙 3.2 运行Dashboard 管理控制台 概述 本文主要介绍EMQX 的一些内容&a…

云计算 3月12号 (PEX)

什么是PXE? PXE,全名Pre-boot Execution Environment,预启动执行环境; 通过网络接口启动计算机,不依赖本地存储设备(如硬盘)或本地已安装的操作系统; 由Intel和Systemsoft公司于199…

游戏数据处理

游戏行业关键数据指标 ~ 总激活码发放量、总激活量、总登录账号数 激活率、激活登录率 激活率 激活量 / 安装量 激活率 激活量 / 激活码发放量 激活且登录率 激活且登录量 / 激活码激活量 激活且登录率应用场景 激活且登录率是非常常用的转化率指标之一,广泛…

今天我们来学习一下关于MySQL数据库

目录 前言: 1.MySQL定义: 1.1基础概念: 1.1.1数据库(Database): 1.1.2表(Table): 1.1.3记录(Record)与字段(Field): …

C语言strcmp函数讲解

strcmp函数介绍 在cplusplus官网上是这样介绍strcmp函数的 这里的意思是假如我们输入两个字符串一个是abcdef另一个也是abcdef他们两个字符的每个元素的ascii码值进行比较如果两个元素的ascii码值都相等就移动到下一个元素a与a进行比较b与b进行比较直到遇到\0为止&#xff0c…

数据结构:7、队列

一、队列的概念与结构 队列:只允许在一端进行插入数据操作,在另一端进行删除数据操作的特殊线性表,队列具有先进先出FIFO(First In First Out) 入队列:进行插入操作的一端称为队尾 出队列:进行删除操作的一端称为队头…

功能测试--APP性能测试

功能测试--APP性能测试 内存数据查看内存测试 CPU数据查看CPU测试 流量和电量的消耗流量测试流量优化方法电量测试电量测试场景(大) 获取启动时间启动测试--安卓 流畅度流畅度测试 稳定性稳定性测试 内存数据查看 内存泄露:内存的曲线持续增长(增的远比减…

码头船只出行和货柜管理系统的设计与实现

针对于码头船只货柜信息管理方面的不规范,容错率低,管理人员处理数据费工费时,采用新开发的码头船只货柜管理系统可以从根源上规范整个数据处理流程。 码头船只货柜管理系统能够实现货柜管理,路线管理,新闻管理&#…

【MMDetection3D实战(3)】: KITTI 数据集介绍

文章目录 1. 数据集介绍2 数据下载及准备2.1 下载并整理数据集2.2 传感器及坐标定义2.3 数据的标注3 MMDet3D 中的坐标系规范4 数据的处理及可视化4.1 数据处理4.2 点云读取和可视化4.2.1 点云的读取4.2.2 点云的可视化1. 数据集介绍 KITTI数据集是3D目标检测中比较基础和常用…

【LeetCode】升级打怪之路 Day 17:二叉树题型 —— 二叉树的序列化与反序列化

今日题目: 297. 二叉树的序列化与反序列化652. 寻找重复的子树 目录 LC 297. 二叉树的序列化与反序列化 【classic】 ⭐⭐⭐⭐⭐1)序列化逻辑2)反序列化逻辑 LC 652. 寻找重复的子树 【稍有难度】 今天主要学习了二叉树的序列化和反序列化相关…

数字逻辑-时序逻辑电路一

一、实验目的 (1)熟悉触发器的逻辑功能及特性。 (2)掌握集成D和JK触发器的应用。 (3)掌握时序逻辑电路的分析和设计方法。 二、实验仪器及材料 三、实验内容及步骤 1、用D触发器(74LS74&am…

使用Docker在windows上安装IBM MQ

第一步、安装wsl 详见我另一篇安装wsl文章。 第二步、安装centos 这里推荐两种方式,一种是从微软商城安装,一种是使用提前准备好的镜像安装,详见我另一篇windos下安装centos教程。 第三步、安装windows下的Docker desktop 详见我另一篇wind…

TQ15EG开发板教程:运行MPSOC+AD9361

目录 1,下载工程需要使用的文件 2,编译以及修改工程 3,获取生成BOOT.BIN所需要的3个文件 3.1生成bit文件 3.2生成elf文件 3.3生成fsbl文件 4,生成boot.bin文件 5,上板测试 6,切换FMC接口 7&#…

自适应窗口图片轮播HTML代码

自适应窗口图片轮播HTML代码,源码由HTMLCSSJS组成,记事本打开源码文件可以进行内容文字之类的修改,双击html文件可以本地运行效果,也可以上传到服务器里面,重定向这个界面 代码下载地址 自适应窗口图片轮播HTML代码

分享 | 计算机组成与设计学习资料+CPU设计源码+实验报告

1.引言 百度网盘资源链接: 链接:https://pan.baidu.com/s/1Ww6u_l1L6DMXofC2HxfETw?pwdyqd6 提取码:yqd6 2.学习资源预览 2.1 包含学习手册四本: - 计算机原理与设计:Verilog HDL版 - 计算机组成与设…

开源分子对接程序rDock使用方法(2)-高通量虚拟筛选HTVS

欢迎浏览我的CSND博客! Blockbuater_drug …点击进入 文章目录 前言一、rDock用于高通量虚拟筛选HTVSMulti-Step Protocol HTVS步骤及注意事项 二、rDock中Multi-Step Protocol用于HTVS的用法Step 1. Exhaustive dockingStep 2. sdreport summaryStep 3. 运行rbhtfi…

Linux之NFS网络文件系统详解

华子目录 简介NFS背景介绍注意 生产应用场景NFS工作原理示例图流程 NFS的使用安装配置文件主配置文件分析权限参数/etc/exports文件内容示例 实验1nfs账户映射实验2实验3 autofs自动挂载服务产生原因安装配置文件分析挂载参数 实验4实验5:本机自动挂载光驱 简介 NF…

专升本 C语言笔记-08 goto语句

goto语句 无条件跳转运算符(凡是执行到goto语句会直接跳转到 定义的标签) 缺点&#xff1a;滥用goto语句将会导致逻辑混乱&#xff0c;导致系统崩溃等问题! ! ! 代码演示 int i 0; //定义标签 jump(名字随便起哦) jump:printf("%d ",i); i; if(i < 10)goto j…

如何处理Android悬浮弹窗双击返回事件?

目录 1 前言 1.1 准备知识 1.2 问题概述 2 解决方案 3 代码部分 3.1 动态更新窗口焦点 3.2 窗口监听返回事件 3.3 判断焦点是否在窗口内部 3.4 窗口监听焦点移入/移出 1 前言 1.1 准备知识 1&#xff09;开发环境&#xff1a; 2D开发环境&#xff1a;所有界面或弹窗…