LangGraph自适应RAG
- 介绍
- 索引
- LLMs
- web 搜索工具
- graph
- graph state
- graph flow
- build graph
- 执行
介绍
自适应 RAG 是一种 RAG 策略,它将 (1) 查询分析与 (2) 主动/自校正 RAG 结合起来。
在该论文中,作者通过查询分析将问题路由到以下几种检索策略:
- No Retrieval
- Single-shot RAG
- Iterative RAG
让我们使用 LangGraph 在此基础上进行构建。
在我们的实现中,我们将在以下两种方式之间进行路由:
- 网络搜索:用于与近期事件相关的问题
- 自校正 RAG:用于与我们的索引内容相关的问题
索引
from typing import List
import requests
from langchain_community.document_loaders import WebBaseLoader
from langchain_community.vectorstores import Vearch
from langchain_core.embeddings import Embeddings
from langchain_core.pydantic_v1 import (
BaseModel
)
from langchain_text_splitters import RecursiveCharacterTextSplitter
from common.constant import VEARCH_ROUTE_URL, BGE_M3_EMB_URL
class Bgem3Embeddings(BaseModel, Embeddings):
    """LangChain ``Embeddings`` adapter backed by the bge-m3 HTTP service.

    Every text is embedded through :func:`cop_embeddings`, which pads the
    service vector to 1536 dimensions by default.
    """

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """Embed every text in *texts*.

        The previous implementation only printed the texts and returned ``[]``,
        so the vector store received no vectors for the chunks it indexed.

        Args:
            texts: The document chunks to embed.

        Returns:
            One vector per input text, in input order. An empty text, or one the
            service fails to embed, yields ``[]`` in its slot.
        """
        return [self.embed_query(text) for text in texts]

    def embed_query(self, text: str) -> List[float]:
        """Embed a single query string; return ``[]`` for empty input."""
        if not text:
            return []
        return cop_embeddings(text)


# bge-m3: convert text into an embedding vector.
def cop_embeddings(input: str, dim: int = 1536, timeout: float = 60.0) -> list:
    """Convert *input* into a dense bge-m3 embedding of length *dim*.

    Posts the text to the bge-m3 service at ``BGE_M3_EMB_URL`` and adapts the
    returned vector (1024 values for bge-m3) to the 1536-wide OpenAI embedding
    interface by truncating or zero-padding it.

    Args:
        input: Text to embed. The parameter name is kept for keyword callers even
            though it shadows the ``input`` builtin.
        dim: Length of the returned vector. Longer service vectors are truncated
            and shorter ones are padded with ``0.0``.
        timeout: Seconds to wait for the HTTP response. Previously no timeout was
            set, so a stalled service could block indexing indefinitely.

    Returns:
        A list of exactly *dim* values, or ``[]`` when *input* is blank, the
        service answers with a non-200 status, or the response carries no
        embeddings.

    Raises:
        requests.RequestException: On connection failures or when *timeout*
            elapses.
    """
    if not input.strip():
        return []
    payload = {"sentences": [input], "type": "dense"}
    response = requests.post(
        BGE_M3_EMB_URL,
        headers={"Content-Type": "application/json"},
        json=payload,
        timeout=timeout,
    )
    if response.status_code != 200:
        print(f"cop_embeddings error: {response.text}")
        return []
    result = response.json()
    if not result or "embeddings" not in result or not result["embeddings"]:
        return []
    vector = result["embeddings"][0]
    # Truncate to ``dim`` and zero-pad the tail (e.g. 1024 -> 1536).
    adapted = list(vector[:dim])
    adapted.extend([0.0] * (dim - len(adapted)))
    return adapted
# Source documents to index (blog posts from lilianweng.github.io).
urls = [
    "https://lilianweng.github.io/posts/2023-06-23-agent/",
    "https://lilianweng.github.io/posts/2023-03-15-prompt-engineering/",
    "https://lilianweng.github.io/posts/2023-10-25-adv-attack-llm/",
]
# Load documents: one list per URL, flattened into a single list below.
docs = [WebBaseLoader(url).load() for url in urls]
docs_list = [item for sublist in docs for item in sublist]
# Split into chunks; chunk size 500 is measured with a tiktoken encoder, no overlap.
text_splitter = RecursiveCharacterTextSplitter.from_tiktoken_encoder(
    chunk_size=500, chunk_overlap=0
)
doc_splits = text_splitter.split_documents(docs_list)
# Embed the chunks with bge-m3 and store them in Vearch.
embeddings_model = Bgem3Embeddings()
# NOTE(review): "lanchain" in the table/db names looks like a typo of
# "langchain"; left as-is because renaming would target a different table.
# NOTE(review): the meaning of flag=3 is Vearch-specific — confirm in the Vearch docs.
vectorstore = Vearch.from_documents(
    documents=doc_splits,
    embedding=embeddings_model,
    path_or_url=VEARCH_ROUTE_URL,
    table_name="lanchain_autogpt",
    db_name="lanchain_autogpt_db",
    flag=3
)
# Retriever imported by the LLM and graph-flow modules.
retriever = vectorstore.as_retriever()
LLMs
### Router
from typing import Literal
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.pydantic_v1 import BaseModel, Field
from langchain_openai import ChatOpenAI
from common.common import PROXY_URL, API_KEY
from index1 import retriever
# Data model
class RouteQuery(BaseModel):
    """将用户查询路由到最相关的数据源。"""

    # Data source chosen for the question; ``Literal`` restricts it to the two
    # values that route_question branches on.
    # NOTE(review): the Chinese docstring and ``description`` are left untranslated
    # because with_structured_output presumably forwards them to the model as the
    # tool/field description — changing them would change the prompt.
    datasource: Literal["vectorstore", "web_search"] = Field(
        ...,
        description="给定一个用户问题,选择将其发送到web_search或vectorstore。",
    )
# Shared chat model (gpt-4o via the configured proxy base URL, temperature=0).
llm = ChatOpenAI(model_name="gpt-4o", api_key=API_KEY, base_url=PROXY_URL, temperature=0)
# Router: the model must reply with a RouteQuery object.
structured_llm_router = llm.with_structured_output(RouteQuery)
# Router system prompt (Chinese): questions about agents, prompt engineering and
# adversarial attacks go to the vectorstore; everything else goes to web_search.
system = """你是将用户问题传送到vectorstore或web_search的专家。
vectorstore包含与agents、prompt engineering和adversarial attacks相关的文档。
使用向量库回答有关这些主题的问题。否则,请使用web_search。"""
route_prompt = ChatPromptTemplate.from_messages(
    [
        ("system", system),
        ("human", "{question}"),
    ]
)
# question -> RouteQuery
question_router = route_prompt | structured_llm_router
# Example 1: data-source selection.
# print(
#     question_router.invoke(
#         {"question": "谁将成为NFL选秀的第一人?"}
#     )
# )
# print(question_router.invoke({"question": "Agent memory有哪些类型?"}))
# Data model
class GradeDocuments(BaseModel):
    """Binary score for relevance check on retrieved documents."""

    # Constrained to "yes"/"no": the graph compares this value with the exact
    # string "yes", so a free-form ``str`` let answers such as "Yes" silently count
    # as "not relevant". With ``Literal`` the schema presumably becomes an enum.
    # The English docstring and Chinese description are prompt text and are left as-is.
    binary_score: Literal["yes", "no"] = Field(
        description="文档与问题相关, 'yes' or 'no'"
    )
# Relevance grader: the model must reply with a GradeDocuments object.
structured_llm_grader = llm.with_structured_output(GradeDocuments)
# Prompt (Chinese): grade whether a retrieved document is relevant to the question.
system = """你是一个评估检索到的文档和用户问题的相关性的分级员。 \n
如果文档包含与用户问题相关的关键字或语义,则将其评为相关。 \n
它不需要是一个严格的测试。目标是过滤掉错误的检索。 \n
给出二进制分数 'yes' or 'no' 表示文档是否与问题相关。"""
grade_prompt = ChatPromptTemplate.from_messages(
    [
        ("system", system),
        ("human", "检索到的文档: \n\n {document} \n\n 用户问题: {question}"),
    ]
)
# (question, document) -> GradeDocuments
retrieval_grader = grade_prompt | structured_llm_grader

# Sample data for the commented-out examples below (this runs at import time).
question = "agent memory"
# ``invoke`` is the current retriever API (``get_relevant_documents`` is deprecated)
# and matches how the graph-flow retrieve node queries the retriever.
docs = retriever.invoke(question)
# The demo uses the second hit; fall back gracefully so importing this module
# cannot fail with IndexError when fewer hits are returned.
if len(docs) > 1:
    doc_txt = docs[1].page_content
elif docs:
    doc_txt = docs[0].page_content
else:
    doc_txt = ""
# Example 2: retrieval grading.
# print(retrieval_grader.invoke({"question": question, "document": doc_txt}))
from langchain import hub
from langchain_core.output_parsers import StrOutputParser
# Generation prompt pulled from LangChain Hub ("rlm/rag-prompt"); this is fetched
# when the module is imported. Its inputs are presumably ``context`` and
# ``question``, matching how rag_chain is invoked.
prompt = hub.pull("rlm/rag-prompt")
# Post-processing helper.
def format_docs(docs):
    """Join the ``page_content`` of every document with a blank line between them.

    Args:
        docs: Iterable of objects exposing a ``page_content`` string.

    Returns:
        The concatenated text; ``""`` when *docs* is empty.
    """
    separator = "\n\n"
    contents = [document.page_content for document in docs]
    return separator.join(contents)
# Generation chain: prompt -> LLM -> plain string.
rag_chain = prompt | llm | StrOutputParser()
# Run
# NOTE(review): the example passes ``docs`` (Document objects) as context;
# format_docs(docs) would hand the prompt plain text instead of object reprs.
# generation = rag_chain.invoke({"context": docs, "question": question})
# Example 3: generation with prompt = hub.pull("rlm/rag-prompt").
# print(generation)
# Data model
class GradeHallucinations(BaseModel):
    """Binary score for hallucination present in generation answer."""

    # Constrained to "yes"/"no" because the graph compares the value with the
    # exact string "yes"; a free-form ``str`` let "Yes" be read as "not grounded".
    # The docstring and description are prompt text and are left unchanged.
    binary_score: Literal["yes", "no"] = Field(
        description="答案以事实为基础, 'yes' 或 'no'"
    )
# Hallucination grader: rebinds structured_llm_grader to GradeHallucinations output.
structured_llm_grader = llm.with_structured_output(GradeHallucinations)
# Prompt (Chinese): is the LLM generation grounded in / supported by the retrieved
# facts? Answer 'yes' or 'no'.
# NOTE(review): "设置事实" in the human message looks like a mistranslation of
# "Set of facts" (e.g. 事实集合); it is prompt text, so it is left unchanged here.
system = """你是一名评估LLM生成是否以一组检索到的事实为基础/支持的分级员。 \n
给一个二进制分数 'yes' 或 'no'. 'Yes' 意味着答案以一系列事实为基础。"""
hallucination_prompt = ChatPromptTemplate.from_messages(
    [
        ("system", system),
        ("human", "设置事实: \n\n {documents} \n\n 大模型生成: {generation}"),
    ]
)
# (documents, generation) -> GradeHallucinations
hallucination_grader = hallucination_prompt | structured_llm_grader
# hgr = hallucination_grader.invoke({"documents": docs, "generation": generation})
# Example 4: hallucination grading.
# print(hgr)
### Answer Grader
# Data model
class GradeAnswer(BaseModel):
    """Binary score to assess answer addresses question."""

    # Constrained to "yes"/"no" because the graph compares the value with the
    # exact string "yes"; a free-form ``str`` let "Yes" be read as "not useful".
    # The docstring and description are prompt text and are left unchanged.
    binary_score: Literal["yes", "no"] = Field(
        description="答案是否解决了问题, 'yes' 或 'no'"
    )
# Answer grader: rebinds structured_llm_grader to GradeAnswer output.
structured_llm_grader = llm.with_structured_output(GradeAnswer)
# Prompt (Chinese): does the answer resolve the question? Answer 'yes' or 'no'.
system = """你是一名评估答案是否能解决问题的评分员 \n
给一个二进制分数 'yes' 或 'no'。 'Yes' 意味着答案解决了问题。"""
answer_prompt = ChatPromptTemplate.from_messages(
    [
        ("system", system),
        ("human", "用户问题: \n\n {question} \n\n 大模型生成: {generation}"),
    ]
)
# (question, generation) -> GradeAnswer
answer_grader = answer_prompt | structured_llm_grader
# agr = answer_grader.invoke({"question": question, "generation": generation})
# Example 5: answer grading.
# print(agr)
# Question re-writer prompt (Chinese): rewrite the input question into a version
# optimised for vector-store retrieval by reasoning about its underlying intent.
system = """你是一个问题重写器,可以将输入问题转换为更好的版本,该版本针对向量库检索进行了优化。
查看输入并尝试推理潜在的语义意图/含义。"""
re_write_prompt = ChatPromptTemplate.from_messages(
    [
        ("system", system),
        (
            "human",
            "下面是原始问题: \n\n {question} \n 提出一个改进的问题。",
        ),
    ]
)
# question -> rewritten question as a plain string.
question_rewriter = re_write_prompt | llm | StrOutputParser()
# qrr = question_rewriter.invoke({"question": question})
# Example 6: question rewriting.
# print(qrr)
web 搜索工具
import os
from langchain_community.tools.tavily_search import TavilySearchResults
from common.common import TAVILY_API_KEY
# Tavily API key — obtain one beforehand at https://app.tavily.com/home.
os.environ["TAVILY_API_KEY"] = TAVILY_API_KEY
# Exported as ``web_search_tool`` because that is the name the graph-flow module
# imports; only ``tavily_tool`` existed before, so that import failed. The old
# name stays as an alias. ``max_results`` is the tool's result-count option;
# the previous ``k=3`` is not a documented TavilySearchResults parameter.
web_search_tool = TavilySearchResults(max_results=3)
tavily_tool = web_search_tool
graph
将上述流程表示为一张图(graph)。
graph state
from typing import List
from typing_extensions import TypedDict
class GraphState(TypedDict):
    """State shared by every node of the adaptive-RAG graph.

    Attributes:
        question: The current user question (the transform_query node may
            replace it with a rewritten version).
        generation: The latest LLM answer produced by the generate node.
        documents: Documents gathered for the question.
    """

    question: str
    generation: str
    # The nodes store document objects exposing ``page_content`` here (not plain
    # strings), so the element type is left unspecified instead of ``str``.
    documents: List
graph flow
from langchain.schema import Document
from index1 import retriever
from llm2 import rag_chain, retrieval_grader, question_rewriter, question_router, hallucination_grader, answer_grader
from webstool3 import web_search_tool
# Retrieve documents from the vector store.
def retrieve(state):
    """Fetch documents for the current question from the vector store.

    Args:
        state (dict): Graph state; only ``state["question"]`` is read.

    Returns:
        dict: ``documents`` holding the retriever hits and ``question`` unchanged.
    """
    print("---RETRIEVE---")
    query = state["question"]
    hits = retriever.invoke(query)
    return {"documents": hits, "question": query}
# LLM answer generation.
def generate(state):
    """Generate an answer from the question and the current documents.

    The documents are flattened into their ``page_content`` joined by blank lines
    before being used as the prompt ``context``. Previously the raw document
    objects were passed, so the prompt received their Python reprs (metadata,
    escaped newlines) instead of readable text.

    Args:
        state (dict): Graph state; reads ``"question"`` and ``"documents"``.
            ``"documents"`` may be a list of documents or a single document
            (the web-search node has stored a bare ``Document``).

    Returns:
        dict: ``documents`` and ``question`` passed through unchanged, plus the
        ``generation`` produced by ``rag_chain``.
    """
    print("---GENERATE---")
    question = state["question"]
    documents = state["documents"]
    # Accept both a list of documents and a single document object.
    doc_items = documents if isinstance(documents, (list, tuple)) else [documents]
    context = "\n\n".join(doc.page_content for doc in doc_items)
    generation = rag_chain.invoke({"context": context, "question": question})
    return {"documents": documents, "question": question, "generation": generation}
# Document relevance grading.
def grade_documents(state):
    """Keep only the documents that the retrieval grader judges relevant.

    Args:
        state (dict): Graph state; reads ``"question"`` and ``"documents"``
            (objects exposing ``page_content``).

    Returns:
        dict: ``documents`` replaced by the relevant subset (original order kept)
        and ``question`` unchanged.
    """
    print("---CHECK DOCUMENT RELEVANCE TO QUESTION---")
    question = state["question"]
    filtered_docs = []
    for doc in state["documents"]:
        score = retrieval_grader.invoke(
            {"question": question, "document": doc.page_content}
        )
        # Normalise case/whitespace so "Yes" or " yes" is not silently treated
        # as irrelevant.
        if str(score.binary_score).strip().lower() == "yes":
            print("---GRADE: DOCUMENT RELEVANT---")
            filtered_docs.append(doc)
        else:
            print("---GRADE: DOCUMENT NOT RELEVANT---")
    return {"documents": filtered_docs, "question": question}
# Query rewriting.
def transform_query(state):
    """Rewrite the question into a version better suited to vector retrieval.

    Args:
        state (dict): Graph state; reads ``"question"`` and ``"documents"``.

    Returns:
        dict: The rewritten ``question`` and the ``documents`` passed through.
    """
    print("---TRANSFORM QUERY---")
    original = state["question"]
    kept_documents = state["documents"]
    rewritten = question_rewriter.invoke({"question": original})
    return {"documents": kept_documents, "question": rewritten}
# Web search.
def web_search(state):
    """Search the web for the current question and wrap the hits as one document.

    This node is reached straight from question routing, so the question may be
    the user's original wording rather than a rewritten one.

    Args:
        state (dict): Graph state; only ``state["question"]`` is read.

    Returns:
        dict: ``question`` unchanged and ``documents`` set to a one-element list
        holding a ``Document`` whose text is every hit's ``content`` joined by
        newlines. A list is stored (previously a bare ``Document``) so this node
        fills ``documents`` with the same shape as the retrieve node.
    """
    print("---WEB SEARCH---")
    question = state["question"]
    results = web_search_tool.invoke({"query": question})
    if isinstance(results, str):
        # NOTE(review): presumably an error message from the tool rather than a
        # list of hits — keep it as the content instead of failing on hit["content"].
        content = results
    else:
        content = "\n".join(hit["content"] for hit in results)
    return {"documents": [Document(page_content=content)], "question": question}
### Edges ###
def route_question(state):
    """Choose the first node of the graph for the incoming question.

    Args:
        state (dict): Graph state; only ``state["question"]`` is read.

    Returns:
        str: ``"web_search"`` or ``"vectorstore"`` as picked by
        ``question_router``; ``None`` if it yields any other datasource.
    """
    print("---ROUTE QUESTION---")
    route = question_router.invoke({"question": state["question"]})
    datasource = route.datasource
    if datasource == "web_search":
        print("---ROUTE QUESTION TO WEB SEARCH---")
        return "web_search"
    if datasource == "vectorstore":
        print("---ROUTE QUESTION TO RAG---")
        return "vectorstore"
    return None
# Decide between generating an answer and rewriting the question.
def decide_to_generate(state):
    """Route to generation when relevant documents remain, else to a query rewrite.

    Args:
        state (dict): Graph state; only ``state["documents"]`` (the documents
            kept by the relevance grader) is read.

    Returns:
        str: ``"transform_query"`` when no documents are left, otherwise
        ``"generate"``.
    """
    print("---ASSESS GRADED DOCUMENTS---")
    # The former bare ``state["question"]`` statement discarded its value; it
    # did nothing except fail on states without a question, so it was removed.
    if not state["documents"]:
        # Every document was filtered out as irrelevant: rewrite the query.
        print(
            "---DECISION: ALL DOCUMENTS ARE NOT RELEVANT TO QUESTION, TRANSFORM QUERY---"
        )
        return "transform_query"
    # Relevant documents remain, so generate an answer.
    print("---DECISION: GENERATE---")
    return "generate"
# Check that the generation is grounded in the documents and answers the question.
def grade_generation_v_documents_and_question(state):
    """Grade the latest generation for grounding first, then for usefulness.

    Both grader answers are compared case-insensitively, so "Yes" counts as
    "yes" instead of silently failing the check. The documents are given to the
    hallucination grader as plain text (``page_content`` joined by blank lines)
    rather than as object reprs.

    Args:
        state (dict): Graph state; reads ``"question"``, ``"documents"`` (a list
            of documents or a single document) and ``"generation"``.

    Returns:
        str: ``"not supported"`` when the generation is not grounded in the
        documents, ``"useful"`` when it is grounded and addresses the question,
        ``"not useful"`` when it is grounded but misses the question.
    """
    print("---CHECK HALLUCINATIONS---")
    question = state["question"]
    documents = state["documents"]
    generation = state["generation"]
    doc_items = documents if isinstance(documents, (list, tuple)) else [documents]
    facts = "\n\n".join(doc.page_content for doc in doc_items)

    score = hallucination_grader.invoke({"documents": facts, "generation": generation})
    if str(score.binary_score).strip().lower() != "yes":
        print("---DECISION: GENERATION IS NOT GROUNDED IN DOCUMENTS, RE-TRY---")
        return "not supported"

    print("---DECISION: GENERATION IS GROUNDED IN DOCUMENTS---")
    # Grounded: now check that it actually answers the question.
    print("---GRADE GENERATION vs QUESTION---")
    score = answer_grader.invoke({"question": question, "generation": generation})
    if str(score.binary_score).strip().lower() == "yes":
        print("---DECISION: GENERATION ADDRESSES QUESTION---")
        return "useful"
    print("---DECISION: GENERATION DOES NOT ADDRESS QUESTION---")
    return "not useful"
build graph
from langgraph.graph import END, StateGraph
from pprint import pprint
from common.common import show_img
from gflow5 import web_search, retrieve, grade_documents, generate, transform_query, route_question, decide_to_generate, \
grade_generation_v_documents_and_question
from gstate4 import GraphState
def _build_workflow():
    """Assemble the adaptive-RAG state graph: nodes, routed entry point, edges."""
    graph = StateGraph(GraphState)

    # Nodes, registered in a fixed order.
    nodes = (
        ("web_search", web_search),
        ("retrieve", retrieve),
        ("grade_documents", grade_documents),
        ("generate", generate),
        ("transform_query", transform_query),
    )
    for node_name, node_fn in nodes:
        graph.add_node(node_name, node_fn)

    # Entry: route_question picks web search or vector-store retrieval.
    graph.set_conditional_entry_point(
        route_question,
        {"web_search": "web_search", "vectorstore": "retrieve"},
    )
    graph.add_edge("web_search", "generate")
    graph.add_edge("retrieve", "grade_documents")
    # After grading: rewrite the query if nothing relevant survived, else generate.
    graph.add_conditional_edges(
        "grade_documents",
        decide_to_generate,
        {"transform_query": "transform_query", "generate": "generate"},
    )
    graph.add_edge("transform_query", "retrieve")
    # After generation: retry ungrounded answers, stop on useful ones, and
    # rewrite the query when the answer misses the question.
    # NOTE(review): "not supported" loops back to generate with no retry cap here;
    # presumably only LangGraph's recursion limit ends a persistent loop.
    graph.add_conditional_edges(
        "generate",
        grade_generation_v_documents_and_question,
        {
            "not supported": "generate",
            "useful": END,
            "not useful": "transform_query",
        },
    )
    return graph


# Define the workflow and compile it into a runnable app.
workflow = _build_workflow()
app = workflow.compile()
执行
# Run the graph on a sample question, printing each node as it finishes and,
# at the end, the generated answer (previously the answer was never shown).
inputs = {
    "question": "熊队的哪位球员有望在2024年的NFL选秀中获得第一名?"
}
last_update = None
for output in app.stream(inputs):
    for key, value in output.items():
        # Node
        pprint(f"Node '{key}':")
        # Optional: print the state update produced by this node.
        # pprint(value, indent=2, width=80, depth=None)
        last_update = value
    pprint("\n---\n")

# The last update comes from the final node that ran; print its answer if present.
if last_update and "generation" in last_update:
    pprint(last_update["generation"])