Self-RAG
Self-RAG
概述
Self-RAG(Self-Reflective Retrieval-Augmented Generation)是一种增强型的RAG(检索增强生成)策略,结合了自我反思和自我评分机制,以提高检索文档和生成内容的质量。通过对检索到的文档和生成的回答进行多层次的评估,Self-RAG旨在减少错误信息(如幻觉)并提升回答的相关性和准确性。
-
是否从检索器中检索:
- 输入:x(问题)或 x(问题),y(生成内容)
- 决策:决定何时使用R(检索)来获取D个文档块
- 输出:
yes
、no
、continue
-
检索到的段落D是否与问题x相关:
- 输入:对于D中的每个d(文档块),x(问题)
- 决策:d是否提供了解决x的有用信息
- 输出:
relevant
、irrelevant
-
LLM从D中的每个段落生成的内容是否与段落相关(防止幻觉等):
- 输入:x(问题)、d(段落)、y(生成内容)对于D中的每个d
- 决策:y(生成内容)中的所有需要验证的陈述是否由d支持
- 输出:
fully supported
、partially supported
、no support
-
LLM从D中的每个段落生成的内容是否对x(问题)有用:
- 输入:x(问题)、y(生成内容)对于D中的每个d
- 决策:y(生成内容)是否对x(问题)有用
- 输出:
{5, 4, 3, 2, 1}
(评分)
我们将使用LangGraph从头实现这些理念。作为第一步,我们将跳过知识精炼阶段,并在发现任何不相关的文档时,通过网络搜索补充检索结果。同时,我们将使用查询重写来优化网络搜索的查询。
系统架构图
系统的图形化表示如下所示:
设置环境(Setup)
首先,下载所需的包并设置必要的API密钥。
1. 安装必要的包
在Jupyter Notebook或终端中运行以下命令安装所需的包:
pip install -U langchain_community tiktoken langchain-openai langchainhub chromadb langchain langgraph tavily-python
2. 设置API密钥
接下来,设置OpenAI和Tavily的API密钥。以下代码将提示您输入API密钥并将其存储在环境变量中:
import getpass
import os
def _set_env(key: str):
if key not in os.environ:
os.environ[key] = getpass.getpass(f"{key}:")
_set_env("OPENAI_API_KEY")
_set_env("TAVILY_API_KEY")
3. 设置LangSmith用于LangGraph开发
LangSmith是一个用于调试、测试和监控LangGraph项目的工具。通过注册LangSmith,您可以使用跟踪数据来优化LangGraph应用程序的性能。详细的注册和使用方法请参考LangSmith入门指南。
创建索引(Create Index)
1. 构建索引
我们首先需要构建一个文档索引,以便后续的检索和生成过程。
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.document_loaders import WebBaseLoader
from langchain_community.vectorstores import Chroma
from langchain_openai import OpenAIEmbeddings
# 设置嵌入模型
embd = OpenAIEmbeddings()
# 要索引的文档URL
urls = [
"https://lilianweng.github.io/posts/2023-06-23-agent/",
"https://lilianweng.github.io/posts/2023-03-15-prompt-engineering/",
"https://lilianweng.github.io/posts/2023-10-25-adv-attack-llm/",
]
# 使用WebBaseLoader加载文档
docs = [WebBaseLoader(url).load() for url in urls]
docs_list = [item for sublist in docs for item in sublist]
# 使用RecursiveCharacterTextSplitter拆分文档
text_splitter = RecursiveCharacterTextSplitter.from_tiktoken_encoder(
chunk_size=250, chunk_overlap=0
)
doc_splits = text_splitter.split_documents(docs_list)
# 将拆分后的文档添加到向量存储中
vectorstore = Chroma.from_documents(
documents=doc_splits,
collection_name="rag-chroma",
embedding=embd,
)
retriever = vectorstore.as_retriever()
解释:
- WebBaseLoader:从指定的URL递归加载网页内容。
- RecursiveCharacterTextSplitter:将长文档拆分成较小的块,以便LLM更高效地处理。
- Chroma:使用向量存储(vectorstore)管理文档的嵌入向量,并提供高效的相似度检索。
- retriever:将向量存储作为检索器,供LLM调用以获取相关文档。
LLMs 配置
使用Pydantic与LangChain
此部分使用Pydantic v2的BaseModel
,需要langchain-core >= 0.3
。使用langchain-core < 0.3
将导致因混合使用Pydantic v1和v2而出错。
1. 检索评分器(Retrieval Grader)
检索评分器用于评估检索到的文档是否与用户问题相关。
from langchain_core.prompts import ChatPromptTemplate
from langchain_openai import ChatOpenAI
from pydantic import BaseModel, Field
# 数据模型
class GradeDocuments(BaseModel):
"""评估检索文档相关性的二元评分。"""
binary_score: str = Field(
description="文档是否与问题相关,'yes' 或 'no'"
)
# 配置 LLM 并设置结构化输出
llm = ChatOpenAI(model="gpt-3.5-turbo-0125", temperature=0)
structured_llm_grader = llm.with_structured_output(GradeDocuments)
# 提示语
system = """你是一个评分员,负责评估检索到的文档是否与用户的问题相关。
这不需要是严格的测试,目标是过滤掉错误的检索结果。
如果文档包含与用户问题相关的关键词或语义含义,请将其评分为相关。
请给出二元评分“yes”或“no”,以指示文档是否与问题相关。"""
grade_prompt = ChatPromptTemplate.from_messages(
[
("system", system),
# 两个\n的可读性会更高
("human", "检索到的文档:\n\n {document} \n\n 用户问题:{question}"),
]
)
retrieval_grader = grade_prompt | structured_llm_grader
question = "agent memory"
docs = retriever.invoke(question)
doc_txt = docs[1].page_content
print(retrieval_grader.invoke({"question": question, "document": doc_txt}))
# 输出示例
binary_score='no'
解释:
- GradeDocuments:定义了评分器的输出结构,包括
binary_score
字段,值为"yes"
或"no"
。 - grade_prompt:定义了用于评估文档相关性的提示模板。
- retrieval_grader:结合了提示模板和LLM的评分链。
- 示例调用:评估特定文档是否与用户问题相关。
2. 生成回答节点(Generate)
生成回答节点基于检索到的文档生成最终回答。
from langchain import hub
from langchain_core.output_parsers import StrOutputParser
# 获取提示模板
prompt = hub.pull("rlm/rag-prompt")
# 初始化LLM
llm = ChatOpenAI(model_name="gpt-3.5-turbo", temperature=0)
# 后处理函数:格式化文档
def format_docs(docs):
return "\n\n".join(doc.page_content for doc in docs)
# 构建RAG链
rag_chain = prompt | llm | StrOutputParser()
# 运行RAG链
generation = rag_chain.invoke({"context": docs, "question": question})
print(generation)
# 输出示例
# The design of generative agents combines LLM with memory, planning, and reflection mechanisms to enable agents to behave conditioned on past experience. Memory stream is a long-term memory module that records a comprehensive list of agents' experience in natural language. LLM functions as the agent's brain in an autonomous agent system.
解释:
- hub.pull(“rlm/rag-prompt”):从LangChain Hub拉取预定义的RAG提示模板。
- rag_chain:结合提示模板和LLM,创建一个RAG链。
- generation:基于上下文和用户问题生成的回答。
3. 幻觉评分器(Hallucination Grader)
幻觉评分器用于评估生成的回答是否基于检索到的事实,防止模型生成虚假信息。
# 数据模型
class GradeHallucinations(BaseModel):
"""评估回答中是否存在幻觉的二元评分。"""
binary_score: str = Field(
description="回答是否基于事实支持,'yes' 或 'no'"
)
# 配置 LLM 并设置结构化输出
llm = ChatOpenAI(model="gpt-3.5-turbo-0125", temperature=0)
structured_llm_grader = llm.with_structured_output(GradeHallucinations)
# 提示语
system = """你是一个评分员,负责评估LLM生成的回答是否基于一组检索到的事实。
请给出二元评分“yes”或“no”。“yes”表示回答是基于这些事实支持的。"""
hallucination_prompt = ChatPromptTemplate.from_messages(
[
("system", system),
("human", "事实集:\n\n {documents} \n\n LLM生成的回答:{generation}"),
]
)
hallucination_grader = hallucination_prompt | structured_llm_grader
hallucination_grader.invoke({"documents": docs, "generation": generation})
解释:
- GradeHallucinations:定义了幻觉评分器的输出结构,包括
binary_score
字段,值为"yes"
或"no"
。 - hallucination_prompt:定义了用于评估幻觉的提示模板。
- hallucination_grader:结合了提示模板和LLM的幻觉评分链。
- 示例调用:评估生成的回答是否基于检索到的事实。
4. 答案评分器(Answer Grader)
答案评分器用于评估生成的回答是否有效地回答了用户的问题。
# 数据模型
class GradeAnswer(BaseModel):
"""评估回答是否解决问题的二元评分。"""
binary_score: str = Field(
description="回答是否解决了问题,'yes' 或 'no'"
)
# 配置 LLM 并设置结构化输出
llm = ChatOpenAI(model="gpt-3.5-turbo-0125", temperature=0)
structured_llm_grader = llm.with_structured_output(GradeAnswer)
# 提示语
system = """你是一个评分员,负责评估一个回答是否解决了用户的问题。
请给出二元评分“yes”或“no”。“yes”表示回答解决了问题。"""
answer_prompt = ChatPromptTemplate.from_messages(
[
("system", system),
("human", "用户问题:\n\n {question} \n\n LLM生成的回答:{generation}"),
]
)
answer_grader = answer_prompt | structured_llm_grader
answer_grader.invoke({"question": question, "generation": generation})
解释:
- GradeAnswer:定义了答案评分器的输出结构,包括
binary_score
字段,值为"yes"
或"no"
。 - answer_prompt:定义了用于评估答案的提示模板。
- answer_grader:结合了提示模板和LLM的答案评分链。
- 示例调用:评估生成的回答是否有效地解决了用户的问题。
5. 问题重写器(Question Re-writer)
问题重写器用于优化用户的问题,以提高检索效果。
# LLM 配置
llm = ChatOpenAI(model="gpt-3.5-turbo-0125", temperature=0)
# 提示语
system = """你是一个问题重写器,负责将输入的问题转换为更好的版本,以优化向量存储的检索效果。
请查看输入的问题,并尝试推理其潜在的语义意图或含义。"""
re_write_prompt = ChatPromptTemplate.from_messages(
[
("system", system),
(
"human",
"这是初始问题:\n\n {question} \n 请制定一个改进后的问题。",
),
]
)
question_rewriter = re_write_prompt | llm | StrOutputParser()
question_rewriter.invoke({"question": question})
# 示例输出
"什么是agent记忆在代理功能中的作用?"
解释:
- re_write_prompt:定义了用于重写问题的提示模板。
- question_rewriter:结合了提示模板和LLM的问题重写链。
- 示例调用:优化用户的原始问题,以提高检索效果。
6. 网络搜索工具(Web Search Tool)
网络搜索工具用于处理与近期事件相关的问题,通过网络搜索获取最新信息。
from langchain_community.tools.tavily_search import TavilySearchResults
# 初始化网络搜索工具
web_search_tool = TavilySearchResults(k=3)
解释:
- TavilySearchResults:定义了网络搜索工具,设置返回结果的数量为3。
- web_search_tool:网络搜索工具实例,供后续调用以获取相关信息。
构建图(Construct the Graph)
1. 定义图状态(Define Graph State)
首先,定义图的状态结构,包含问题、生成的回答、是否进行网络搜索以及相关文档列表。
from typing import List
from typing_extensions import TypedDict
class GraphState(TypedDict):
"""
表示图的状态。
属性:
question: 用户问题
generation: LLM生成的回答
web_search: 是否进行网络搜索
documents: 文档列表
"""
question: str
generation: str
web_search: str
documents: List[str]
解释:
- GraphState:定义了图的状态结构,包括用户问题(
question
)、生成的回答(generation
)、是否进行网络搜索(web_search
)和相关文档列表(documents
)。
2. 定义图流程(Define Graph Flow)
构建图的逻辑流程,包括检索、生成、评分和重写等节点。
from langchain.schema import Document
def retrieve(state):
"""
检索文档
Args:
state (dict): 当前图的状态
Returns:
dict: 更新状态,包含检索到的文档
"""
print("---RETRIEVE---")
question = state["question"]
# 调用检索器
documents = retriever.invoke(question)
return {"documents": documents, "question": question}
def generate(state):
"""
生成回答
Args:
state (dict): 当前图的状态
Returns:
dict: 更新状态,包含生成的回答
"""
print("---GENERATE---")
question = state["question"]
documents = state["documents"]
# RAG生成
generation = rag_chain.invoke({"context": documents, "question": question})
return {"documents": documents, "question": question, "generation": generation}
def grade_documents(state):
"""
确定检索到的文档是否与问题相关
Args:
state (dict): 当前图的状态
Returns:
dict: 更新状态,包含筛选后的相关文档
"""
print("---CHECK DOCUMENT RELEVANCE TO QUESTION---")
question = state["question"]
documents = state["documents"]
# 评分每个文档
filtered_docs = []
web_search = "No"
for d in documents:
score = retrieval_grader.invoke(
{"question": question, "document": d.page_content}
)
grade = score.binary_score
if grade == "yes":
print("---GRADE: DOCUMENT RELEVANT---")
filtered_docs.append(d)
else:
print("---GRADE: DOCUMENT NOT RELEVANT---")
web_search = "Yes"
continue
return {"documents": filtered_docs, "question": question, "web_search": web_search}
def transform_query(state):
"""
转换查询,生成更好的问题
Args:
state (dict): 当前图的状态
Returns:
dict: 更新状态,包含重新表述的问题
"""
print("---TRANSFORM QUERY---")
question = state["question"]
documents = state["documents"]
# 重写问题
better_question = question_rewriter.invoke({"question": question})
return {"documents": documents, "question": better_question}
def web_search(state):
"""
基于重新表述的问题进行网络搜索
Args:
state (dict): 当前图的状态
Returns:
dict: 更新状态,包含网络搜索结果
"""
print("---WEB SEARCH---")
question = state["question"]
documents = state["documents"]
# 网络搜索
docs = web_search_tool.invoke({"query": question})
web_results = "\n".join([d["content"] for d in docs])
web_results = Document(page_content=web_results)
documents.append(web_results)
return {"documents": documents, "question": question}
解释:
- retrieve:根据用户问题调用检索器,获取相关文档。
- generate:基于检索到的文档生成回答。
- grade_documents:评估每个检索到的文档是否与用户问题相关,并筛选出相关文档。如果有任何文档不相关,则标记需要进行网络搜索。
- transform_query:优化用户的问题,以提高检索效果。
- web_search:针对不相关的文档,通过网络搜索获取补充信息,并将搜索结果添加到文档列表中。
3. 定义边(Edges)
定义节点之间的连接关系,决定流程的执行顺序。
def decide_to_generate(state):
"""
决定是否生成回答,或重新生成问题
Args:
state (dict): 当前图的状态
Returns:
str: 决策结果,决定下一步调用的节点
"""
print("---ASSESS GRADED DOCUMENTS---")
web_search = state["web_search"]
filtered_documents = state["documents"]
if web_search == "Yes":
# 有不相关的文档,需要进行网络搜索并重新生成问题
print(
"---DECISION: ALL DOCUMENTS ARE NOT RELEVANT TO QUESTION, TRANSFORM QUERY---"
)
return "transform_query"
else:
# 有相关文档,生成回答
print("---DECISION: GENERATE---")
return "generate"
def grade_generation_v_documents_and_question(state):
"""
确定生成的回答是否基于文档且回答了问题
Args:
state (dict): 当前图的状态
Returns:
str: 决策结果,决定下一步调用的节点
"""
print("---CHECK HALLUCINATIONS---")
question = state["question"]
documents = state["documents"]
generation = state["generation"]
score = hallucination_grader.invoke(
{"documents": documents, "generation": generation}
)
grade = score.binary_score
# 检查幻觉
if grade == "yes":
print("---DECISION: GENERATION IS GROUNDED IN DOCUMENTS---")
# 检查回答是否解决了问题
print("---GRADE GENERATION vs QUESTION---")
score = answer_grader.invoke({"question": question, "generation": generation})
grade = score.binary_score
if grade == "yes":
print("---DECISION: GENERATION ADDRESSES QUESTION---")
return "useful"
else:
print("---DECISION: GENERATION DOES NOT ADDRESS QUESTION---")
return "not useful"
else:
print("---DECISION: GENERATION IS NOT GROUNDED IN DOCUMENTS, RE-TRY---")
return "not supported"
解释:
- decide_to_generate:根据文档评分结果,决定是否生成回答或重新转换查询。如果有不相关的文档,则需要进行网络搜索并优化问题。
- grade_generation_v_documents_and_question:评估生成的回答是否基于检索到的文档且有效回答了用户的问题。如果回答不符合要求,则重新生成或优化问题。
4. 编译图(Compile Graph)
使用StateGraph
将所有节点和边连接起来,并编译图。
from langgraph.graph import END, StateGraph, START
workflow = StateGraph(GraphState)
# 定义节点
workflow.add_node("retrieve", retrieve) # 检索节点
workflow.add_node("grade_documents", grade_documents) # 评估文档相关性节点
workflow.add_node("generate", generate) # 生成回答节点
workflow.add_node("transform_query", transform_query) # 转换查询节点
workflow.add_node("web_search_node", web_search) # 网络搜索节点
# 定义边
workflow.add_edge(START, "retrieve")
workflow.add_edge("retrieve", "grade_documents")
workflow.add_conditional_edges(
"grade_documents",
decide_to_generate,
{
"transform_query": "transform_query",
"generate": "generate",
},
)
workflow.add_edge("transform_query", "web_search_node")
workflow.add_edge("web_search_node", "generate")
workflow.add_edge("generate", END)
# 编译图
app = workflow.compile()
解释:
- workflow.add_node:将各个节点添加到图中。
- “retrieve”:检索节点,负责从向量存储中获取相关文档。
- “grade_documents”:评估文档相关性节点,筛选相关文档。
- “generate”:生成回答节点,基于相关文档生成最终回答。
- “transform_query”:转换查询节点,优化用户问题以提高检索效果。
- “web_search_node”:网络搜索节点,处理需要通过网络搜索获取的补充信息。
- workflow.add_edge:定义节点之间的直接连接。
- START -> “retrieve”:流程从检索节点开始。
- “retrieve” -> “grade_documents”:检索后评估文档相关性。
- “grade_documents” -> “transform_query” 或 “generate”:根据评估结果决定下一步。
- “transform_query” -> “web_search_node”:优化问题后进行网络搜索。
- “web_search_node” -> “generate”:获取补充信息后生成回答。
- “generate” -> END:生成回答后结束流程。
使用图(Use the Graph)
1. 导入必要模块
from pprint import pprint
2. 运行
定义输入并通过图进行处理。
# 示例调用1
inputs = {"question": "Explain how the different types of agent memory work?"}
for output in app.stream(inputs):
for key, value in output.items():
# 节点
pprint(f"Node '{key}':")
# 可选:打印每个节点的完整状态
# pprint.pprint(value["keys"], indent=2, width=80, depth=None)
pprint("\n---\n")
# 最终生成的回答
pprint(value["generation"])
# 示例调用2
inputs = {"question": "How does the AlphaCodium paper work?"}
for output in app.stream(inputs):
for key, value in output.items():
# 节点
pprint(f"Node '{key}':")
# 可选:打印每个节点的完整状态
# pprint.pprint(value["keys"], indent=2, width=80, depth=None)
pprint("\n---\n")
# 最终生成的回答
pprint(value["generation"])
解释:
-
示例调用1:
- 用户问题被路由到
retrieve
节点,从向量存储中检索相关文档。 - grade_documents节点评估每个文档的相关性,筛选出相关文档。
- decide_to_generate节点决定生成回答。
- generate节点基于相关文档生成最终回答。
- 最终回答展示。
- 用户问题被路由到
-
示例调用2:
- 用户问题被路由到
retrieve
节点,从向量存储中检索相关文档。 - grade_documents节点评估每个文档的相关性,发现部分文档不相关,标记需要进行网络搜索。
- decide_to_generate节点决定需要优化查询。
- transform_query节点优化用户问题。
- web_search_node节点通过网络搜索获取补充信息。
- generate节点基于补充信息生成最终回答。
- 最终回答展示。
- 用户问题被路由到
汇总
好的,以下是将上述使用 LangGraph 实现的 Self-RAG 相关代码汇总到一个完整的 Python 文件中的示例。该文件包含了环境设置、文档检索、评分器的定义、图的构建以及执行流程。请按照以下步骤操作:
-
确保安装所需的包:在执行脚本之前,请确保已经安装了所有必要的 Python 包。您可以使用以下命令来安装这些包:
pip install -U langchain_community tiktoken langchain-openai langchainhub chromadb langchain langgraph pydantic
-
设置 OpenAI API 密钥:脚本会提示您输入
OPENAI_API_KEY
。您需要确保拥有有效的 OpenAI API 密钥。 -
执行脚本:将以下代码保存为
self_rag.py
,然后在终端中运行:python self_rag.py
以下是完整的 self_rag.py
文件内容:
# self_rag.py
import getpass
import os
from typing import List
from typing_extensions import TypedDict
from pprint import pprint
# LangChain 和 LangGraph 相关导入
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.document_loaders import WebBaseLoader
from langchain_community.vectorstores import Chroma
from langchain_openai import OpenAIEmbeddings, ChatOpenAI
from langchain import hub
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.output_parsers import StrOutputParser
from pydantic import BaseModel, Field
from langgraph.graph import END, StateGraph, START
# 环境设置函数
def _set_env(key: str):
if key not in os.environ:
os.environ[key] = getpass.getpass(f"{key}:")
_set_env("OPENAI_API_KEY")
# 设置 LangSmith(可选)
# 您可以根据需要配置 LangSmith,以便进行调试和监控
# 详细信息请参考 LangSmith 的官方文档
# 定义图的状态
class GraphState(TypedDict):
"""
表示图的状态。
属性:
question: 用户问题
generation: LLM生成的回答
documents: 文档列表
"""
question: str
generation: str
documents: List[str]
# 初始化检索器
def initialize_retriever():
urls = [
"https://lilianweng.github.io/posts/2023-06-23-agent/",
"https://lilianweng.github.io/posts/2023-03-15-prompt-engineering/",
"https://lilianweng.github.io/posts/2023-10-25-adv-attack-llm/",
]
# 加载文档
docs = [WebBaseLoader(url).load() for url in urls]
docs_list = [item for sublist in docs for item in sublist]
# 分割文档
text_splitter = RecursiveCharacterTextSplitter.from_tiktoken_encoder(
chunk_size=250, chunk_overlap=0
)
doc_splits = text_splitter.split_documents(docs_list)
# 添加到向量数据库
vectorstore = Chroma.from_documents(
documents=doc_splits,
collection_name="rag-chroma",
embedding=OpenAIEmbeddings(),
)
retriever = vectorstore.as_retriever()
return retriever
# 定义评分器的数据模型和函数
# 1. 检索评分器(Retrieval Grader)
class GradeDocuments(BaseModel):
"""评估检索文档相关性的二元评分。"""
binary_score: str = Field(
description="文档是否与问题相关,'yes' 或 'no'"
)
def initialize_retrieval_grader():
llm = ChatOpenAI(model="gpt-3.5-turbo-0125", temperature=0)
structured_llm_grader = llm.with_structured_output(GradeDocuments)
system = """你是一个评分员,负责评估检索到的文档是否与用户的问题相关。
这不需要是严格的测试,目标是过滤掉错误的检索结果。
如果文档包含与用户问题相关的关键词或语义含义,请将其评分为相关。
请给出二元评分“yes”或“no”,以指示文档是否与问题相关。"""
grade_prompt = ChatPromptTemplate.from_messages(
[
("system", system),
("human", "检索到的文档:\n\n {document} \n\n 用户问题:{question}"),
]
)
retrieval_grader = grade_prompt | structured_llm_grader
return retrieval_grader
# 2. 幻觉评分器(Hallucination Grader)
class GradeHallucinations(BaseModel):
"""评估回答中是否存在幻觉的二元评分。"""
binary_score: str = Field(
description="回答是否基于事实支持,'yes' 或 'no'"
)
def initialize_hallucination_grader():
llm = ChatOpenAI(model="gpt-3.5-turbo-0125", temperature=0)
structured_llm_grader = llm.with_structured_output(GradeHallucinations)
system = """你是一个评分员,负责评估LLM生成的回答是否基于一组检索到的事实。
请给出二元评分“yes”或“no”。“yes”表示回答是基于这些事实支持的。"""
hallucination_prompt = ChatPromptTemplate.from_messages(
[
("system", system),
("human", "事实集:\n\n {documents} \n\n LLM生成的回答:{generation}"),
]
)
hallucination_grader = hallucination_prompt | structured_llm_grader
return hallucination_grader
# 3. 回答评分器(Answer Grader)
class GradeAnswer(BaseModel):
"""评估回答是否解决问题的二元评分。"""
binary_score: str = Field(
description="回答是否解决了问题,'yes' 或 'no'"
)
def initialize_answer_grader():
llm = ChatOpenAI(model="gpt-3.5-turbo-0125", temperature=0)
structured_llm_grader = llm.with_structured_output(GradeAnswer)
system = """你是一个评分员,负责评估一个回答是否解决了用户的问题。
请给出二元评分“yes”或“no”。“yes”表示回答解决了问题。"""
answer_prompt = ChatPromptTemplate.from_messages(
[
("system", system),
("human", "用户问题:\n\n {question} \n\n LLM生成的回答:{generation}"),
]
)
answer_grader = answer_prompt | structured_llm_grader
return answer_grader
# 4. 问题重写器(Question Re-writer)
def initialize_question_rewriter():
llm = ChatOpenAI(model="gpt-3.5-turbo-0125", temperature=0)
system = """你是一个问题重写器,负责将输入的问题转换为更好的版本,以优化向量存储的检索效果。
请查看输入的问题,并尝试推理其潜在的语义意图或含义。"""
re_write_prompt = ChatPromptTemplate.from_messages(
[
("system", system),
(
"human",
"这是初始问题:\n\n {question} \n 请制定一个改进后的问题。",
),
]
)
question_rewriter = re_write_prompt | llm | StrOutputParser()
return question_rewriter
# 生成回答(Generate)
def initialize_rag_chain():
prompt = hub.pull("rlm/rag-prompt")
llm = ChatOpenAI(model_name="gpt-3.5-turbo", temperature=0)
def format_docs(docs):
return "\n\n".join(doc.page_content for doc in docs)
rag_chain = prompt | llm | StrOutputParser()
return rag_chain
# 定义节点函数
def retrieve(state, retriever):
"""
检索文档
参数:
state (dict): 当前图的状态
retriever: 检索器对象
返回:
dict: 更新后的状态,包含检索到的文档
"""
print("---检索---")
question = state["question"]
# 检索
documents = retriever.invoke(question)
return {"documents": documents, "question": question}
def generate(state, rag_chain):
"""
生成回答
参数:
state (dict): 当前图的状态
rag_chain: RAG生成链对象
返回:
dict: 更新后的状态,包含生成的回答
"""
print("---生成回答---")
question = state["question"]
documents = state["documents"]
# RAG生成
generation = rag_chain.invoke({"context": documents, "question": question})
return {"documents": documents, "question": question, "generation": generation}
def grade_documents(state, retrieval_grader):
"""
评估检索到的文档是否相关
参数:
state (dict): 当前图的状态
retrieval_grader: 检索评分器对象
返回:
dict: 更新后的状态,包含过滤后的相关文档
"""
print("---检查文档与问题的相关性---")
question = state["question"]
documents = state["documents"]
# 评分每个文档
filtered_docs = []
for d in documents:
score = retrieval_grader.invoke(
{"question": question, "document": d.page_content}
)
grade = score.binary_score
if grade == "yes":
print("---评分:文档相关---")
filtered_docs.append(d)
else:
print("---评分:文档不相关---")
continue
return {"documents": filtered_docs, "question": question}
def transform_query(state, question_rewriter):
"""
优化问题
参数:
state (dict): 当前图的状态
question_rewriter: 问题重写器对象
返回:
dict: 更新后的状态,包含优化后的问题
"""
print("---优化问题---")
question = state["question"]
documents = state["documents"]
# 重写问题
better_question = question_rewriter.invoke({"question": question})
print(f"优化后的问题:{better_question}")
return {"documents": documents, "question": better_question}
def decide_to_generate(state):
"""
决定是否生成回答或重新优化问题
参数:
state (dict): 当前图的状态
返回:
str: 下一个节点的名称
"""
print("---评估已评分的文档---")
filtered_documents = state["documents"]
if not filtered_documents:
# 所有文档均不相关,重新优化问题
print("---决策:所有文档与问题不相关,优化问题---")
return "transform_query"
else:
# 有相关文档,生成回答
print("---决策:生成回答---")
return "generate"
def grade_generation_v_documents_and_question(state, hallucination_grader, answer_grader):
"""
评估生成的回答是否基于文档并解决了问题
参数:
state (dict): 当前图的状态
hallucination_grader: 幻觉评分器对象
answer_grader: 回答评分器对象
返回:
str: 下一个节点的名称
"""
print("---检查幻觉---")
question = state["question"]
documents = state["documents"]
generation = state["generation"]
score = hallucination_grader.invoke(
{"documents": documents, "generation": generation}
)
grade = score.binary_score
# 检查是否存在幻觉
if grade == "yes":
print("---决策:生成的回答基于文档---")
# 检查回答是否解决了问题
print("---评分生成的回答是否解决问题---")
score = answer_grader.invoke({"question": question, "generation": generation})
grade = score.binary_score
if grade == "yes":
print("---决策:生成的回答解决了问题---")
return "useful"
else:
print("---决策:生成的回答未解决问题---")
return "not useful"
else:
print("---决策:生成的回答未基于文档,重试---")
return "not supported"
# 构建并编译图
def build_workflow(retrieve_fn, grade_documents_fn, generate_fn, transform_query_fn,
decide_to_generate_fn, grade_generation_fn,
retriever, rag_chain, retrieval_grader, hallucination_grader, answer_grader, question_rewriter):
workflow = StateGraph(GraphState)
# 定义节点
workflow.add_node("retrieve", lambda state: retrieve_fn(state, retriever))
workflow.add_node("grade_documents", lambda state: grade_documents_fn(state, retrieval_grader))
workflow.add_node("generate", lambda state: generate_fn(state, rag_chain))
workflow.add_node("transform_query", lambda state: transform_query_fn(state, question_rewriter))
# 构建边
workflow.add_edge(START, "retrieve")
workflow.add_edge("retrieve", "grade_documents")
workflow.add_conditional_edges(
"grade_documents",
decide_to_generate_fn,
{
"transform_query": "transform_query",
"generate": "generate",
},
)
workflow.add_edge("transform_query", "retrieve")
workflow.add_conditional_edges(
"generate",
lambda state: grade_generation_fn(state, hallucination_grader, answer_grader),
{
"not supported": "generate",
"useful": END,
"not useful": "transform_query",
},
)
# 编译图
app = workflow.compile()
return app
# 运行图
def run_workflow(app, inputs):
for output in app.stream(inputs):
for key, value in output.items():
# 打印每个节点的状态
pprint(f"节点 '{key}':")
# 可选:打印每个节点的详细状态
# pprint.pprint(value["keys"], indent=2, width=80, depth=None)
pprint("\n---\n")
# 打印最终生成的回答
pprint(value.get("generation", "没有生成回答"))
def main():
# 初始化组件
retriever = initialize_retriever()
retrieval_grader = initialize_retrieval_grader()
hallucination_grader = initialize_hallucination_grader()
answer_grader = initialize_answer_grader()
question_rewriter = initialize_question_rewriter()
rag_chain = initialize_rag_chain()
# 构建工作流
app = build_workflow(
retrieve_fn=retrieve,
grade_documents_fn=grade_documents,
generate_fn=generate,
transform_query_fn=transform_query,
decide_to_generate_fn=decide_to_generate,
grade_generation_fn=grade_generation_v_documents_and_question,
retriever=retriever,
rag_chain=rag_chain,
retrieval_grader=retrieval_grader,
hallucination_grader=hallucination_grader,
answer_grader=answer_grader,
question_rewriter=question_rewriter
)
# 运行第一个示例
print("=== 示例 1 ===")
inputs1 = {"question": "Explain how the different types of agent memory work?"}
run_workflow(app, inputs1)
# 运行第二个示例
print("\n=== 示例 2 ===")
inputs2 = {"question": "Explain how chain of thought prompting works?"}
run_workflow(app, inputs2)
if __name__ == "__main__":
main()
代码说明
-
环境设置:
- 使用
getpass
获取OPENAI_API_KEY
并设置为环境变量。 - 请确保在运行脚本时输入有效的 OpenAI API 密钥。
- 使用
-
检索器初始化:
- 从指定的 URL 加载文档。
- 使用
RecursiveCharacterTextSplitter
将文档分割成较小的片段。 - 将文档片段存储到 Chroma 向量数据库中,以便高效检索。
-
评分器初始化:
- 检索评分器(Retrieval Grader):评估检索到的文档是否与问题相关。
- 幻觉评分器(Hallucination Grader):评估生成的回答是否基于检索到的文档,避免幻觉。
- 回答评分器(Answer Grader):评估生成的回答是否解决了用户的问题。
-
问题重写器初始化:
- 优化用户输入的问题,以便更好地从向量存储中检索相关文档。
-
生成链初始化(Generate):
- 使用 LangChain 的
hub.pull("rlm/rag-prompt")
获取预定义的 RAG 提示语。 - 配置 LLM 生成回答。
- 使用 LangChain 的
-
节点函数定义:
- retrieve:检索相关文档。
- generate:基于检索到的文档生成回答。
- grade_documents:评估检索到的文档是否相关。
- transform_query:优化用户的问题。
- decide_to_generate:决定是生成回答还是优化问题。
- grade_generation_v_documents_and_question:评估生成的回答是否基于文档并解决了问题。
-
图的构建与编译:
- 使用 LangGraph 的
StateGraph
定义工作流图。 - 添加节点和边,定义流程控制逻辑。
- 编译图为可执行的应用对象
app
。
- 使用 LangGraph 的
-
运行图:
- 提供输入问题,运行整个工作流,生成并打印最终的回答。
- 示例中包含两个问题进行演示。
注意事项
-
包版本:请确保安装的包版本与代码兼容,尤其是
langchain
和langgraph
。如果遇到兼容性问题,请参考相应包的官方文档进行调整。 -
Prompt 模板:代码中使用了
hub.pull("rlm/rag-prompt")
来获取 RAG 提示语。请确保该提示语存在于 LangChain 的 Hub 中,或者根据需要自定义提示语。 -
错误处理:为了简化代码示例,未添加详细的错误处理逻辑。在实际应用中,建议添加适当的异常处理,以提高代码的鲁棒性。
-
LangSmith:代码中提到了 LangSmith,用于调试和监控 LangGraph 项目。如果需要使用,请参考 LangSmith 官方文档 进行配置。
执行示例
运行脚本后,您将看到如下类似的输出:
=== 示例 1 ===
---检索---
---检查文档与问题的相关性---
---评分:文档不相关---
---评分:文档相关---
---评分:文档不相关---
---评分:文档相关---
---评估已评分的文档---
---决策:生成回答---
节点 'grade_documents':
---
---生成回答---
---检查幻觉---
---决策:生成的回答基于文档---
---评分生成的回答是否解决问题---
---决策:生成的回答解决了问题---
节点 'generate':
---
('短期记忆用于代理的上下文学习,使其能够快速学习。长期记忆使代理能够在较长时间内保留和回忆大量信息。代理还可以利用外部工具如API来访问超出其记忆存储的信息。')
=== 示例 2 ===
---检索---
---检查文档与问题的相关性---
---评分:文档相关---
---评分:文档不相关---
---评分:文档相关---
---评分:文档相关---
---评估已评分的文档---
---决策:生成回答---
节点 'grade_documents':
---
---生成回答---
---检查幻觉---
---决策:生成的回答基于文档---
---评分生成的回答是否解决问题---
---决策:生成的回答解决了问题---
节点 'generate':
---
('链式思维提示通过反复提示模型提出后续问题,以迭代构建思维过程。这种方法可以与查询相关实体和内容的搜索相结合,以将其添加回上下文中。它通过在每一步探索多种推理可能性,创建一个思维树结构,从而扩展了思维过程。')
结论
通过上述完整的 Python 脚本,您可以更好地理解和执行 Self-RAG 的实现流程。该脚本涵盖了从文档检索、评分、生成回答到评估和优化问题的整个过程,确保生成的回答具有高相关性和准确性。如果在执行过程中遇到任何问题,请确保所有依赖包已正确安装,并参考相关包的文档进行调试。
评估(Eval)
在本节中,我们将评估使用LangGraph实现的Self-RAG系统与基线方法(Context Stuffing)的性能对比。
1. 导入必要模块
import langsmith
from langsmith.schemas import Example, Run
from langsmith.evaluation import evaluate
2. 克隆公共数据集
克隆一个公共的LCEL问题数据集,用于评估:
client = langsmith.Client()
# 克隆数据集到您的租户
try:
public_dataset = (
"https://smith.langchain.com/public/326674a6-62bd-462d-88ae-eea49d503f9d/d"
)
client.clone_public_dataset(public_dataset)
except:
print("Please setup LangSmith")
解释:
- clone_public_dataset:将公共数据集克隆到您的LangSmith租户中,以便进行评估。
- public_dataset:指定要克隆的数据集URL。
3. 定义自定义评估器
创建两个评估器,用于检查生成的回答是否正确导入和执行。
def check_import(run: Run, example: Example) -> dict:
"""检查导入语句是否正确"""
imports = run.outputs.get("imports")
try:
exec(imports)
return {"key": "import_check", "score": 1}
except Exception:
return {"key": "import_check", "score": 0}
def check_execution(run: Run, example: Example) -> dict:
"""检查代码块是否能正确执行"""
imports = run.outputs.get("imports")
code = run.outputs.get("code")
try:
exec(imports + "\n" + code)
return {"key": "code_execution_check", "score": 1}
except Exception:
return {"key": "code_execution_check", "score": 0}
解释:
- check_import:尝试执行导入语句,如果成功,返回分数1;否则,返回分数0。
- check_execution:尝试执行导入语句和代码块,如果成功,返回分数1;否则,返回分数0。
4. 定义预测函数
定义两个预测函数,分别用于基线方法和Self-RAG方法。
def predict_base_case(example: dict):
"""基线方法:Context Stuffing"""
solution = code_gen_chain.invoke(
{"context": concatenated_content, "messages": [("user", example["question"])]}
)
return {"imports": solution.imports, "code": solution.code}
def predict_self_rag(example: dict):
"""Self-RAG方法"""
graph = app.invoke(
{
"question": example["question"],
"generation": "",
"web_search": "No",
"documents": []
}
)
solution = graph["generation"]
return {"imports": solution.imports, "code": solution.code}
解释:
- predict_base_case:使用基线方法(Context Stuffing)生成回答。
- predict_self_rag:使用Self-RAG方法生成回答。
5. 运行评估
使用LangSmith的evaluate
函数,分别评估基线方法和Self-RAG方法的性能。
# 评估器列表
code_evaluator = [check_import, check_execution]
# 数据集名称
dataset_name = "lcel-teacher-eval"
# 运行基线方法的评估
try:
experiment_results_ = evaluate(
predict_base_case,
data=dataset_name,
evaluators=code_evaluator,
experiment_prefix=f"test-without-langgraph-{llm.model}",
max_concurrency=2,
metadata={
"llm": llm.model,
},
)
except:
print("Please setup LangSmith")
# 运行Self-RAG方法的评估
try:
experiment_results = evaluate(
predict_self_rag,
data=dataset_name,
evaluators=code_evaluator,
experiment_prefix=f"test-with-langgraph-{llm.model}-{flag}",
max_concurrency=2,
metadata={
"llm": llm.model,
"feedback": flag,
},
)
except:
print("Please setup LangSmith")
解释:
- evaluate:运行评估,传入预测函数、数据集、评估器列表和其他配置参数。
- predict_base_case:基线方法的预测函数。
- predict_self_rag:Self-RAG方法的预测函数。
- code_evaluator:评估器列表,用于检查回答的导入和执行情况。
- experiment_prefix:定义实验的前缀,便于区分不同的实验结果。
- metadata:附加的元数据,用于记录LLM类型和反馈标志。
6. 结果
根据评估结果,Self-RAG方法的表现优于基线方法,特别是在添加重试机制后性能有所提升。然而,反思机制并未带来预期的改进,反而在某些情况下导致性能下降。此外,使用GPT-4模型的性能优于Claude3模型。
结果摘要:
- Self-RAG优于基线方法(Self-RAG outperforms base case):添加重试机制显著提高了性能。
- 反思机制未带来改进(Reflection did not help):在重试前进行反思反而导致性能下降,相比之下,直接将错误反馈给LLM更为有效。
- GPT-4优于Claude3(GPT-4 outperforms Claude3):GPT-4模型在执行工具调用时的错误率较低,表现优于Claude3模型。
您可以通过访问以下链接查看详细的评估结果:
评估结果链接
总结
通过本节的讲解,您已经学习了如何使用LangGraph实现一个Self-RAG系统。这个系统能够通过自我反思和评分机制,优化检索和生成过程,确保生成的回答既相关又准确。具体来说,您已经掌握了以下内容:
- 系统设置:安装必要的包,配置API密钥,并设置LangSmith进行开发和监控。
- 索引创建:使用
WebBaseLoader
和Chroma
创建检索工具,索引并检索相关文档。 - LLM配置:
- 使用OpenAI的GPT-3.5进行检索评分、幻觉评分和生成。
- 定义Pydantic模型来结构化存储生成的回答和评分结果。
- 构建检索评分器、幻觉评分器、答案评分器和生成链。
- 状态管理:定义图的状态结构,包括用户问题、生成的回答、是否进行网络搜索和相关文档列表。
- 图定义:
- 定义检索、生成、评分和重写的节点。
- 定义条件边路由,决定流程的执行顺序。
- 评估:
- 使用LangSmith的评估功能,比较Self-RAG方法与基线方法的性能。
- 通过自定义评估器检查回答的准确性和相关性。
下一步建议:
- 扩展功能:可以进一步扩展系统,如增加更多的单元测试,集成更多的工具或优化重试和反思机制。
- 优化路由逻辑:根据评估结果,优化路由器的决策逻辑,提高系统的鲁棒性和生成回答的质量。
- 多模型集成:结合不同的LLM模型,探索多模型协作的可能性,进一步提升回答的准确性和效率。
- 部署和监控:将系统部署到生产环境中,并使用LangSmith进行持续的监控和优化,确保系统稳定运行。