大模型生成RAG评估数据集并计算hit_rate 和 mrr

文章目录

    • 背景
    • 简介
    • 代码实现
    • 公开
    • 参考资料

背景

最近在做RAG评估的实验,需要一个RAG问答对的评估数据集。在网上没有找到好用的,于是便打算自己构建一个数据集。

简介

本文使用大模型自动生成RAG 问答数据集。使用BM25关键词作为检索器,然后在问答数据集上评估该检索器的效果。
输入是一篇文本,使用llamaindex加载该文本,使用prompt让大模型针对输入的文本生成提问。
步骤如下:

  1. llamaindex 加载数据;
  2. 利用 chatglm3-6B 构建CustomLLM;
  3. 使用prompt和chatglm,结合文本生成对应的问题,构建RAG问答数据集;
  4. 使用BM25Retriever,构建基于关键词的检索器;
  5. 评估BM25Retriever在数据集上的hite_ratemrr结果;

由于在构建问答对时,让大模型结合文本生成对应的问题。笔者在测试时,发现关键词检索比向量检索效果要好

代码实现

导入包

from typing import List, Any

from llama_index.core import SimpleDirectoryReader

from llama_index.core.node_parser import SentenceWindowNodeParser
from llama_index.legacy.llms import (
    CustomLLM, CompletionResponse, CompletionResponseGen, LLMMetadata)
from llama_index.legacy.schema import NodeWithScore, QueryBundle, Node
from llama_index.core.base.base_retriever import BaseRetriever
from llama_index.legacy.retrievers import BM25Retriever
from llama_index.core.evaluation import RetrieverEvaluator
from llama_index.core.evaluation import (
    generate_question_context_pairs,
    EmbeddingQAFinetuneDataset,
)

加载数据,使用llamaindex网站的paul_graham_essay.txt

# Load data
documents = SimpleDirectoryReader(
    input_files=["data/paul_graham_essay.txt"]
).load_data()

# create the sentence window node parser w/ default settings
node_parser = SentenceWindowNodeParser.from_defaults(
    window_size=3,
    window_metadata_key="window",
    original_text_metadata_key="original_text",
)

# Extract nodes from documents
nodes = node_parser.get_nodes_from_documents(documents)

# by default, the node ids are set to random uuids. To ensure same id's per run, we manually set them.
for idx, node in enumerate(nodes):
    node.id_ = f"node_{idx}"

大模型加载
chatglm3-6B 使用half,显存占用12G

from modelscope import snapshot_download
from modelscope import AutoTokenizer, AutoModel

model_name = "chatglm3-6b"
model_path = snapshot_download('ZhipuAI/chatglm3-6b')
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
model = AutoModel.from_pretrained(model_path, trust_remote_code=True).half().cuda()
model = model.eval()

本地自定义大模型

# set context window size
context_window = 2048
# set number of output tokens
num_output = 256


class ChatGML(CustomLLM):
    @property
    def metadata(self) -> LLMMetadata:
        """Get LLM metadata."""
        return LLMMetadata(
            context_window=context_window,
            num_output=num_output,
            model_name=model_name,
        )

    def complete(self, prompt: str, **kwargs: Any) -> CompletionResponse:
        prompt_length = len(prompt)

        # only return newly generated tokens
        text, _ = model.chat(tokenizer, prompt, history=[])
        return CompletionResponse(text=text)

    def stream_complete(
            self, prompt: str, **kwargs: Any
    ) -> CompletionResponseGen:
        raise NotImplementedError()


llm_model = ChatGML()

生成RAG测试数据集

# Prompt to generate questions
qa_generate_prompt_tmpl = """\
Context information is below.

---------------------
{context_str}
---------------------

Given the context information and not prior knowledge.
generate only questions based on the below query.

You are a Professor. Your task is to setup \
{num_questions_per_chunk} questions for an upcoming \
quiz/examination. The questions should be diverse in nature \
across the document. The questions should not contain options, not start with Q1/ Q2. \
Restrict the questions to the context information provided.\
"""
# The questions should be solely based on the provided context information, and please pose them in Chinese.\
qa_dataset = generate_question_context_pairs(
    nodes,
    llm=llm_model,
    num_questions_per_chunk=2,
    qa_generate_prompt_tmpl=qa_generate_prompt_tmpl
)
qa_dataset.save_json("pg_eval_dataset.json")
# qa_dataset = EmbeddingQAFinetuneDataset.from_json("pg_eval_dataset.json")
import pandas as pd

def display_results(eval_results):
    """
        计算hit_rate和mrr的平均值
    """

    metric_dicts = []
    for eval_result in eval_results:
        metric_dict = eval_result.metric_vals_dict
        metric_dicts.append(metric_dict)

    full_df = pd.DataFrame(metric_dicts)

    hit_rate = full_df["hit_rate"].mean()
    mrr = full_df["mrr"].mean()

    metric_df = pd.DataFrame(
        {"hit_rate": [hit_rate], "mrr": [mrr]}
    )
    return metric_df
class JieRetriever(BM25Retriever, BaseRetriever):
    def _get_scored_nodes(self, query: str):
        tokenized_query = self._tokenizer(query)
        doc_scores = self.bm25.get_scores(tokenized_query)
        nodes = []
        for i, node in enumerate(self._nodes):
            node_new = Node.from_dict(node.to_dict())
            node_with_score = NodeWithScore(node=node_new, score=doc_scores[i])
            nodes.append(node_with_score)
        return nodes
    
    def _retrieve(self, query_bundle: QueryBundle) -> List[NodeWithScore]:
        if query_bundle.custom_embedding_strs or query_bundle.embedding:
            logger.warning("BM25Retriever does not support embeddings, skipping...")

        scored_nodes = self._get_scored_nodes(query_bundle.query_str)

        # Sort and get top_k nodes, score range => 0..1, closer to 1 means more relevant
        nodes = sorted(scored_nodes, key=lambda x: x.score or 0.0, reverse=True)
        return nodes[:self._similarity_top_k]
retriever = JieRetriever.from_defaults(
# retriever = BM25Retrieve r.from_defaults(
                            nodes=nodes,
                            similarity_top_k=10
                        )

现在llamaindex在使用BM25Retrieve会报错,故笔者创建了JieRetriever,具体请点击查看链接

from llama_index.core.base.base_retriever import BaseRetriever

retriever_evaluator = RetrieverEvaluator.from_metric_names(
            ["mrr", "hit_rate"], retriever=retriever
        )
eval_results = await retriever_evaluator.aevaluate_dataset(qa_dataset)
for idx, item in enumerate(eval_results):
    if idx == 15:
        break
    d = item.metric_vals_dict
    mrr, hit_rate = d['mrr'], d['hit_rate']
    if mrr != 1 or hit_rate != 1:
        print(mrr, hit_rate, item.expected_ids, item.retrieved_ids)

下图展示了hit_rate 和 mrr 的计算:
在这里插入图片描述

结合下述结果,分析一下 hit_rate 和 mrr:

0.5 1.0 ['node_2'] ['node_71', 'node_2', 'node_0', 'node_199', 'node_126', 'node_419', 'node_446', 'node_218', 'node_1', 'node_70']
  • ['node_2'] 是 label
  • ['node_71', 'node_2', 'node_0', 'node_199', 'node_126', 'node_419', 'node_446', 'node_218', 'node_1', 'node_70'] 是检索器召回的候选列表;
  • mrr : 0.5;'node_2' 在候选列表的第二个位置,故mrr为 二分之一。在第几位就是几分之一;
  • hit_rate:代表label是否在候选集中,在就是1,不在就是0;
def display_results(eval_results):
    """
    	计算平均 hit_rate 和 mrr
    """

    metric_dicts = []
    for eval_result in eval_results:
        metric_dict = eval_result.metric_vals_dict
        metric_dicts.append(metric_dict)

    full_df = pd.DataFrame(metric_dicts)

    hit_rate = full_df["hit_rate"].mean()
    mrr = full_df["mrr"].mean()

    metric_df = pd.DataFrame(
        {"hit_rate": [hit_rate], "mrr": [mrr]}
    )
    return metric_df
display_results(eval_results)

在这里插入图片描述

公开

生成的评估数据集和相应示例代码,已上传到modelscope平台;

https://www.modelscope.cn/datasets/jieshenai/paul_graham_essay_rag/files

在这里插入图片描述

参考资料

  • https://www.llamaindex.ai/blog/boosting-rag-picking-the-best-embedding-reranker-models-42d079022e83

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/518508.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

AI图片智能选区抠像解决方案

高质量的图片处理往往依赖于繁琐的手动操作,耗费大量时间与精力。美摄科技推出了一款革命性的AI图片智能选区抠像解决方案,旨在帮助企业轻松实现图片的高效处理,提升内容创作效率与质量。 美摄科技的AI图片智能选区抠像解决方案,…

An Aspect-Based Engine

GPU Pro 译: By 王钰涵 2024 4.14 10.1 Introduction(简介) 引擎的定义在整个行业中有所不同。在最基本的层面上,该术语描述了一个代码库,它在多个项目中提供共同的功能。其目的是分享开发这些功能所需的资源成本…

知网参考文献引用格式转latex中BibTex-Python操作

处理思路 参考 处理步骤: (单条处理:) 1、选知网NoteExpress格式的2-7行复制信息 2、新建一个文本文件,命名为cite.txt,把知网所复制信息粘贴进来 (txt文件保存编码ANSI可行) 3、…

GD32F470_TTP224 4路 电容式 触摸开关 数字触摸传感器模块移植

2.8 TTP224触摸传感器 该模块是一个基于触摸检测IC(TTP223B)的电容式点动型触摸开关模块。常态下,模块输出低电平,模式为低功耗模式;当用手指触摸相应位置时,模块会输出高电平,模式切换为快速模式;当持续12秒没有触摸时…

C#智慧手麻系统源码 医院手术麻醉系统源码 支持三甲医院评级需求 可提供演示

C#智慧手麻系统源码 医院手术麻醉系统源码 支持三甲医院评级需求 可提供演示 手术麻醉管理系统是应用于医院手术室、麻醉科室的计算机软件系统。该系统针对整个围术期,对病人进行全程跟踪与信息管理,自动集成病人HIS、LIS、RIS、PACS信息,采…

吃豆豆 经典的区间DP 好题典题

这里很巧妙的注意一点是,你最后要把所有的豆子都吃掉,所以你只要看你多增加的尽量的少就好了 然后维护一段区间,表示的是吃掉这段区间里面的所有豆子的最小代价,然后发现最后一个是左端点或者右端点 你吃一段新的区间的同时会把…

c++的学习之路:11、string(3)

昨天写string的时候没有说全,这里就开始接着讲。 目录 一、resize 二、insert 三、erase 一、resize 昨天说这个的时候没有考虑到缩小范围时咋处理,然后发现报错了,接着我调试发现缩小就不能正常执行了,因为用的是strcap所以…

有关字符串算法

例题一 解法: 算法思路(两两⽐较): 我们可以先找出前两个的最⻓公共前缀,然后拿这个最⻓公共前缀依次与后⾯的字符串⽐较,这样就可以找出所有字符串的最⻓公共前缀。 例题二 解法(中⼼扩散&am…

UNIAPP(小程序)每十个文章中间一个广告

三十秒刷新一次广告 ad-intervals"30" <template><view style"margin: 30rpx;"><view class"" v-for"(item,index) in 100"><!-- 广告 --><view style"margin-bottom: 20rpx;" v-if"(inde…

win10电脑无线网卡优化

近期win10会频繁断网&#xff0c;无任何规律。目前整理搜索后使用以下两种方法优化网卡&#xff0c;更改配置后断网问题得到有效改善。 方法一&#xff1a;在【电源管理】中取消勾选【允许计算机关闭此设备以节约电源】 方法二&#xff1a;【Preferred enable】修改为prefer 5…

R语言数据操纵:常用函数

这篇文章主要介绍R语言中处理循环&#xff0c;排序&#xff0c;总结重要信息的常用函数。 处理循环的函数 lapply函数 这个函数就是俗称的一句话循环函数&#xff0c;不同于while循环或者for循环&#xff0c;这个函数可以实现一句话就是一个循环的效果。 具体格式为lapply(…

C语言数据结构专题--顺序表(1基础)

前言 我们在对C语言有一定的了解之后&#xff0c;我们就可以开始数据结构的学习了&#xff0c;数据结构多用指针、结构体、动态内存开辟等知识&#xff0c;若对这些知识还不太了解的朋友&#xff0c;就需要加深其理解了&#xff0c;那么废话不多说&#xff0c;我们正式开始本节…

36.基于CAS实现的java类

JUC, java.util.concurrent并发工具包下。 1.原子整数 AtomicInteger AtomicLong AtomicBoolean 底层用的CAS来实现。 AtomicInteger类的incrementAndGet方法&#xff0c;addAndGet方法 public static void main(String[] args) {AtomicInteger atomicInteger new Atom…

一文搞懂 ThreadLocal

简介 ThreadLocal存取的数据&#xff0c;总是与当前线程相关&#xff0c;也就是说&#xff0c;JVM 为每个运行的线程&#xff0c;绑定了私有的本地实例存取空间&#xff0c;从而为多线程环境常出现的并发访问问题提供了一种隔离机制。 ThreadLocal的作用是提供线程内的局部变…

未授权访问-api接口

特别注意api接口的一些命名规则 常见的是这种&#xff0c;具体要看开发人员怎么命名的 而确认api路径的最好办法还是去多出发几个功能点&#xff0c;看他的路径&#xff0c;比如下面触发多个功能点 对比得知两个路径都有pyr/user/这时候可能就会觉得这就是api路径&#xff0c;但…

Azure的VFP和虚拟IP地址

Azure 的Virtual filtering platform (VFP) 是Azure 网络地址转换,端口转换和端口分配的基础。 下面我们来深入介绍一下VFP的工作方式。 VFP的出站动作。 对于客户端地址作为虚拟IP的出站目的地址的时候,VFP 驱动会负责做以下两个动作。 源地址转换。端口地址转换。VFP 和 S…

一分钟了解mos管选型

在选择MOS管时&#xff0c;需要考虑多个关键参数以确保选用的MOS管能够满足特定应用的需求。下面是一些主要参数的介绍 额定电压&#xff08;Vds&#xff09; 也称为漏源电压&#xff0c;通常我们所说的耐压&#xff0c;是指MOS管能够承受的最大电压差。 在选择MOS管时&#xf…

数据湖概述:大数据演进阶段-数据湖

文章目录 一. 大数据发展过程1. 离线大数据平台2. Lambda架构&#xff1a;速度层批层3. Kappa架构&#xff1a;流批一体4. 大数据架构痛点总结 二. 数据湖助力于解决数据仓库痛点问题1. 数据湖特点2. 开源数据湖的架构 三. 数据湖和数据仓库理念的对比1. 数据湖和数据仓库对比2…

c++的STL(7) -- stack

stack容器概述 stack容器其实是实现了一种和栈相同结构的容器。 如图&#xff0c;栈这种结构有两端: 栈底和栈顶。 特殊之处在于&#xff0c;这种结构&#xff0c;我们对数据的操作(删除数据&#xff0c;修改数据&#xff0c;查询数据&#xff0c;添加数据)只能在一端进行(栈…