
Adaptive RAG

Based on the official LangGraph tutorial: an adaptive retrieval strategy that combines query analysis with self-corrective RAG


Overview

Adaptive RAG is a RAG strategy that combines (1) query analysis with (2) active / self-corrective RAG.

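According to the paper, query analysis can route a query to one of three modes:

  • No Retrieval: skip retrieval and answer from the LLM's own knowledge
  • Single-shot RAG: retrieve once, then generate
  • Iterative RAG: alternate retrieval and generation over several rounds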
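To make the three modes concrete, here is a toy dispatcher. It assumes some upstream classifier has already labeled the query as `"none"`, `"single"`, or `"iterative"` (hypothetical labels), and it uses stub `llm` and `retrieve` functions in place of real components, so it shows only the control flow, not the paper's implementation:

```python
from typing import Callable, Dict, List

# Hypothetical stand-ins for an LLM and a retriever, so the control flow runs on its own
def llm(prompt: str) -> str:
    return f"answer({prompt})"

def retrieve(query: str) -> List[str]:
    return [f"doc-for:{query}"]

def no_retrieval(question: str) -> str:
    # Answer directly from the model's own knowledge
    return llm(question)

def single_shot(question: str) -> str:
    # One retrieval round, then one generation
    context = " ".join(retrieve(question))
    return llm(f"{context} | {question}")

def iterative(question: str, rounds: int = 2) -> str:
    # Each round retrieves with the latest intermediate answer as the query
    query, context = question, []
    for _ in range(rounds):
        context += retrieve(query)
        query = llm(f"{' '.join(context)} | {question}")
    return query

STRATEGIES: Dict[str, Callable[[str], str]] = {
    "none": no_retrieval,
    "single": single_shot,
    "iterative": iterative,
}

def adaptive_answer(question: str, complexity: str) -> str:
    """Dispatch on a complexity label produced by an (assumed) upstream classifier."""
    return STRATEGIES[complexity](question)
```

The iterative branch feeds each intermediate answer back in as the next query; a real system would also decide when to stop instead of using a fixed `rounds` count.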

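This tutorial uses LangGraph to build a modified version whose router chooses between:

  • Web Search: for time-sensitive questions (e.g., recent events)
  • Self-corrective RAG: for questions about the indexed documents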

Figure: Adaptive RAG system architecture


Setup

Install dependencies

bash
pip install -U langchain_community tiktoken langchain-openai langchain-cohere \
    langchainhub chromadb langchain langgraph tavily-python langchain-text-splitters

Set API keys

python
import getpass
import os

def _set_env(var: str):
    if not os.environ.get(var):
        os.environ[var] = getpass.getpass(f"{var}: ")

_set_env("OPENAI_API_KEY")
_set_env("TAVILY_API_KEY")
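`getpass` blocks while waiting for keyboard input, which can hang unattended runs such as CI jobs. A minimal fail-fast alternative, using hypothetical helper names `missing_env_vars` and `require_env`:

```python
import os
from typing import Iterable, List, Mapping, Optional

def missing_env_vars(names: Iterable[str], env: Optional[Mapping[str, str]] = None) -> List[str]:
    """Return the names that are unset or empty in `env` (defaults to os.environ)."""
    env = os.environ if env is None else env
    return [name for name in names if not env.get(name)]

def require_env(names: Iterable[str]) -> None:
    # Fail fast instead of prompting, so unattended runs don't hang on getpass
    missing = missing_env_vars(names)
    if missing:
        raise RuntimeError(f"Missing environment variables: {', '.join(missing)}")
```

When no terminal is available, `require_env(["OPENAI_API_KEY", "TAVILY_API_KEY"])` can stand in for the two `_set_env` calls.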

1. Create the Vector Index

Index three blog posts using OpenAI embeddings and a Chroma vector store:

python
from langchain_text_splitters import RecursiveCharacterTextSplitter
from langchain_community.document_loaders import WebBaseLoader
from langchain_community.vectorstores import Chroma
from langchain_openai import OpenAIEmbeddings

# Set up the embedding model
embd = OpenAIEmbeddings()

# URLs of the documents to index
urls = [
    "https://lilianweng.github.io/posts/2023-06-23-agent/",
    "https://lilianweng.github.io/posts/2023-03-15-prompt-engineering/",
    "https://lilianweng.github.io/posts/2023-10-25-adv-attack-llm/",
]

# Load the documents and flatten the per-URL lists into one list
docs = [WebBaseLoader(url).load() for url in urls]
docs_list = [item for sublist in docs for item in sublist]

# Split into chunks of up to 500 tokens (counted with tiktoken), no overlap
text_splitter = RecursiveCharacterTextSplitter.from_tiktoken_encoder(
    chunk_size=500, chunk_overlap=0
)
doc_splits = text_splitter.split_documents(docs_list)

# Add the chunks to the vector store
vectorstore = Chroma.from_documents(
    documents=doc_splits,
    collection_name="rag-chroma",
    embedding=embd,
)
retriever = vectorstore.as_retriever()
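`RecursiveCharacterTextSplitter.from_tiktoken_encoder` measures `chunk_size` and `chunk_overlap` in tiktoken tokens and prefers to split on natural separators such as paragraph breaks. As a simplified sketch of only the size/overlap arithmetic (whitespace-separated words stand in for tokens, and there is no separator logic), windows advance by `chunk_size - chunk_overlap`:

```python
from typing import List

def chunk_tokens(text: str, chunk_size: int = 500, chunk_overlap: int = 0) -> List[str]:
    """Split text into windows of at most `chunk_size` whitespace tokens.

    Simplified stand-in: words approximate tiktoken tokens, and there is
    no recursive separator handling.
    """
    if chunk_overlap >= chunk_size:
        raise ValueError("chunk_overlap must be smaller than chunk_size")
    tokens = text.split()
    step = chunk_size - chunk_overlap
    chunks = []
    for start in range(0, len(tokens), step):
        chunks.append(" ".join(tokens[start:start + chunk_size]))
        if start + chunk_size >= len(tokens):
            break  # this window already reaches the end of the text
    return chunks
```

With `chunk_overlap=0`, as in the tutorial, consecutive chunks share nothing; a small overlap repeats boundary text so that a fact split across two chunks can still be retrieved intact.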

2. LLM Components

2.1 Query Router

Define a RouteQuery data model and let the LLM decide whether to route a query to the vectorstore or to web search:

python
from typing import Literal
from langchain_core.prompts import ChatPromptTemplate
from langchain_openai import ChatOpenAI
from pydantic import BaseModel, Field

# Data model
class RouteQuery(BaseModel):
    """Route a user query to the most relevant datasource."""
    datasource: Literal["vectorstore", "web_search"] = Field(
        ...,
        description="Given a user question choose to route it to web search or a vectorstore.",
    )

# LLM with structured output
llm = ChatOpenAI(model="gpt-4o-mini", temperature=0)
structured_llm_router = llm.with_structured_output(RouteQuery)

# Prompt
system = """You are an expert at routing a user question to a vectorstore or web search.
The vectorstore contains documents related to agents, prompt engineering, and adversarial attacks.
Use the vectorstore for questions on these topics. Otherwise, use web-search."""

route_prompt = ChatPromptTemplate.from_messages([
    ("system", system),
    ("human", "{question}"),
])

question_router = route_prompt | structured_llm_router

# Quick test
print(question_router.invoke(
    {"question": "Who will the Bears draft first in the NFL draft?"}
))  # -> web_search

print(question_router.invoke(
    {"question": "What are the types of agent memory?"}
))  # -> vectorstore
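Because `RouteQuery.datasource` is a `Literal`, the router can only return `vectorstore` or `web_search`. The sketch below reproduces that contract with the standard library, assuming the model's reply arrives as a JSON string; `parse_route` and its fallback to `web_search` are hypothetical choices (the tutorial's pydantic model would raise a validation error instead):

```python
import json
from typing import Literal, get_args

Datasource = Literal["vectorstore", "web_search"]
ALLOWED = get_args(Datasource)  # ("vectorstore", "web_search")

def parse_route(raw: str, default: str = "web_search") -> str:
    """Parse a reply like {"datasource": "vectorstore"} and validate the value.

    Falls back to `default` on malformed JSON, a non-object reply,
    or a datasource outside the allowed set.
    """
    try:
        value = json.loads(raw).get("datasource")
    except (json.JSONDecodeError, AttributeError):
        return default
    return value if value in ALLOWED else default
```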

2.2 Retrieval Grader

Assess whether a retrieved document is relevant to the query:

python
class GradeDocuments(BaseModel):
    """Binary score for relevance check on retrieved documents."""
    binary_score: str = Field(
        description="Documents are relevant to the question, 'yes' or 'no'"
    )

llm = ChatOpenAI(model="gpt-4o-mini", temperature=0)
structured_llm_grader = llm.with_structured_output(GradeDocuments)

system = """You are a grader assessing relevance of a retrieved document to a user question.
If the document contains keyword(s) or semantic meaning related to the user question, grade it as relevant.
It does not need to be a stringent test. The goal is to filter out erroneous retrievals.
Give a binary score 'yes' or 'no' score to indicate whether the document is relevant to the question."""

grade_prompt = ChatPromptTemplate.from_messages([
    ("system", system),
    ("human", "Retrieved document: \n\n {document} \n\n User question: {question}"),
])

retrieval_grader = grade_prompt | structured_llm_grader
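`grade_documents` keeps a document only when `binary_score == "yes"`, an exact string match. Since `binary_score` is declared as a plain `str`, a defensive variant normalizes case and whitespace first. In this sketch `keyword_grade` is a toy stand-in for `retrieval_grader` (word overlap, not an LLM judgment), and all names are hypothetical:

```python
from typing import Callable, List

def is_yes(score: str) -> bool:
    # Tolerate variations such as "Yes" or " yes\n"
    return score.strip().lower() == "yes"

def filter_relevant(question: str, docs: List[str], grade: Callable[[str, str], str]) -> List[str]:
    """Keep the documents whose grade is 'yes'; `grade(question, doc)` stands in for the LLM grader."""
    return [d for d in docs if is_yes(grade(question, d))]

def keyword_grade(question: str, doc: str) -> str:
    # Toy grader: relevant if the doc shares any question word longer than 4 characters
    words = {w.lower().strip("?.,") for w in question.split() if len(w.strip("?.,")) > 4}
    return "Yes" if words & set(doc.lower().split()) else "no"
```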

2.3 Generator

Generate an answer with a RAG prompt:

python
from langchain import hub
from langchain_core.output_parsers import StrOutputParser

# Pull the RAG prompt from LangChain Hub
prompt = hub.pull("rlm/rag-prompt")

llm = ChatOpenAI(model="gpt-4o-mini", temperature=0)

def format_docs(docs):
    return "\n\n".join(doc.page_content for doc in docs)

# RAG chain: fill the prompt with {context} and {question}, call the LLM, return a string
rag_chain = prompt | llm | StrOutputParser()
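`rag_chain` expects two variables, `context` and `question`, and `format_docs` turns the document list into the `context` string. The stdlib sketch below mimics that hand-off with a simplified stand-in template (an assumption; it is not the text of `rlm/rag-prompt`) and a hypothetical `Doc` dataclass in place of LangChain's `Document`:

```python
from dataclasses import dataclass
from typing import List

@dataclass
class Doc:
    # Minimal stand-in for langchain_core.documents.Document
    page_content: str

def format_docs(docs: List[Doc]) -> str:
    # Same joining rule as the tutorial: a blank line between chunks
    return "\n\n".join(doc.page_content for doc in docs)

# Simplified stand-in template, not the hub prompt's exact wording
TEMPLATE = "Use the context to answer.\nQuestion: {question}\nContext: {context}\nAnswer:"

def build_prompt(question: str, docs: List[Doc]) -> str:
    return TEMPLATE.format(question=question, context=format_docs(docs))
```

Because `format_docs` iterates over its argument, every node that feeds `generate` needs to store `documents` as a list rather than a single document.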

2.4 Hallucination Grader

Check whether the generated answer is grounded in the retrieved facts:

python
class GradeHallucinations(BaseModel):
    """Binary score for hallucination present in generation answer."""
    binary_score: str = Field(
        description="Answer is grounded in the facts, 'yes' or 'no'"
    )

llm = ChatOpenAI(model="gpt-4o-mini", temperature=0)
structured_llm_grader = llm.with_structured_output(GradeHallucinations)

system = """You are a grader assessing whether an LLM generation is grounded in / supported by a set of retrieved facts.
Give a binary score 'yes' or 'no'. 'Yes' means that the answer is grounded in / supported by the set of facts."""

hallucination_prompt = ChatPromptTemplate.from_messages([
    ("system", system),
    ("human", "Set of facts: \n\n {documents} \n\n LLM generation: {generation}"),
])

hallucination_grader = hallucination_prompt | structured_llm_grader
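The hallucination grader relies on an LLM's semantic judgment. If you want a cheap, deterministic pre-check before calling it, one option (a swapped-in word-overlap heuristic, not part of the tutorial) is to measure how many of the generation's content words appear in the retrieved facts:

```python
import re
from typing import List, Set

def _content_words(text: str) -> Set[str]:
    # Lowercased alphanumeric words longer than 3 characters (a crude stopword filter)
    return {w for w in re.findall(r"[a-z0-9']+", text.lower()) if len(w) > 3}

def support_ratio(generation: str, facts: List[str]) -> float:
    """Fraction of the generation's content words that also appear in the facts."""
    gen_words = _content_words(generation)
    if not gen_words:
        return 1.0  # nothing to check
    fact_words = set().union(*(_content_words(f) for f in facts))
    return len(gen_words & fact_words) / len(gen_words)

def looks_grounded(generation: str, facts: List[str], threshold: float = 0.6) -> bool:
    # The threshold is an arbitrary illustrative value, not a tuned one
    return support_ratio(generation, facts) >= threshold
```

Word overlap misses paraphrases (for example `agents` vs. `agent`) and cannot detect contradictions, so at best it complements the LLM grader.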

2.5 Answer Grader

Assess whether the answer actually addresses the question:

python
class GradeAnswer(BaseModel):
    """Binary score to assess answer addresses question."""
    binary_score: str = Field(
        description="Answer addresses the question, 'yes' or 'no'"
    )

llm = ChatOpenAI(model="gpt-4o-mini", temperature=0)
structured_llm_grader = llm.with_structured_output(GradeAnswer)

system = """You are a grader assessing whether an answer addresses / resolves a question.
Give a binary score 'yes' or 'no'. 'Yes' means that the answer resolves the question."""

answer_prompt = ChatPromptTemplate.from_messages([
    ("system", system),
    ("human", "User question: \n\n {question} \n\n LLM generation: {generation}"),
])

answer_grader = answer_prompt | structured_llm_grader

2.6 Question Rewriter

Rewrite the query to improve retrieval:

python
llm = ChatOpenAI(model="gpt-4o-mini", temperature=0)

system = """You are a question re-writer that converts an input question to a better version that is optimized
for vectorstore retrieval. Look at the input and try to reason about the underlying semantic intent / meaning."""

re_write_prompt = ChatPromptTemplate.from_messages([
    ("system", system),
    ("human", "Here is the initial question: \n\n {question} \n Formulate an improved question."),
])

question_rewriter = re_write_prompt | llm | StrOutputParser()

3. Web Search Tool

Use Tavily Search to fetch information from the web:

python
from langchain_community.tools.tavily_search import TavilySearchResults

web_search_tool = TavilySearchResults(max_results=3)
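The `web_search` node in section 4.2 concatenates the `content` field of each search result into a single `Document`. A stdlib sketch of that merge, assuming results arrive as dicts with `url` and `content` keys (the sample data in the usage is made up); the optional `with_sources` flag is a hypothetical extension that keeps URLs in the context:

```python
from typing import Dict, List

def merge_results(results: List[Dict[str, str]], with_sources: bool = False) -> str:
    """Join search results into one context string, one result per line."""
    lines = []
    for r in results:
        content = r.get("content", "")
        if with_sources and r.get("url"):
            # Keep provenance so the answer (or a reviewer) can cite it
            content = f"{content} (source: {r['url']})"
        lines.append(content)
    return "\n".join(lines)

sample = [
    {"url": "https://example.com/a", "content": "A"},
    {"url": "https://example.com/b", "content": "B"},
]
print(merge_results(sample))
```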

4. Build the LangGraph Workflow

4.1 Define the State

python
from typing import List
from typing_extensions import TypedDict
from langchain_core.documents import Document

class GraphState(TypedDict):
    """
    Represents the state of our graph.

    Attributes:
        question: question
        generation: LLM generation
        documents: list of documents
    """
    question: str
    generation: str
    documents: List[Document]
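Each node returns a dict containing only the keys it wants to change. For keys declared without a reducer, as in `GraphState`, LangGraph replaces the old value with the new one and leaves the other keys untouched. A stdlib sketch of that merge rule (a simplification; `apply_update` is a hypothetical name):

```python
from typing import Any, Dict

def apply_update(state: Dict[str, Any], update: Dict[str, Any]) -> Dict[str, Any]:
    """Overwrite the keys a node returned; keep every other key unchanged."""
    merged = dict(state)
    merged.update(update)
    return merged

# Walk through two node updates from the workflow
state = {"question": "q", "generation": "", "documents": []}
state = apply_update(state, {"documents": ["d1", "d2"], "question": "q"})  # after retrieve
state = apply_update(state, {"documents": ["d1"], "question": "q"})        # after grade_documents
print(state)
```

Overwriting is what lets `grade_documents` shrink the list. If `documents` were declared with an appending reducer (for example `Annotated[list, operator.add]`), the filtered list would be appended to the old one instead.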

4.2 Define the Node Functions

python
from pprint import pprint
from langchain_core.documents import Document

def retrieve(state):
    """Retrieve documents from the vectorstore."""
    print("---RETRIEVE---")
    question = state["question"]
    documents = retriever.invoke(question)
    return {"documents": documents, "question": question}


def generate(state):
    """Generate an answer from the documents."""
    print("---GENERATE---")
    question = state["question"]
    documents = state["documents"]
    docs_txt = format_docs(documents)
    generation = rag_chain.invoke({"context": docs_txt, "question": question})
    return {"documents": documents, "question": question, "generation": generation}


def grade_documents(state):
    """Keep only the documents graded relevant to the question."""
    print("---CHECK DOCUMENT RELEVANCE TO QUESTION---")
    question = state["question"]
    documents = state["documents"]

    filtered_docs = []
    for d in documents:
        score = retrieval_grader.invoke(
            {"question": question, "document": d.page_content}
        )
        grade = score.binary_score
        if grade == "yes":
            print("---GRADE: DOCUMENT RELEVANT---")
            filtered_docs.append(d)
        else:
            print("---GRADE: DOCUMENT NOT RELEVANT---")
            continue
    return {"documents": filtered_docs, "question": question}


def transform_query(state):
    """Rewrite the query for better retrieval."""
    print("---TRANSFORM QUERY---")
    question = state["question"]
    documents = state["documents"]
    better_question = question_rewriter.invoke({"question": question})
    return {"documents": documents, "question": better_question}


def web_search(state):
    """Search the web and merge the results into one document."""
    print("---WEB SEARCH---")
    question = state["question"]
    docs = web_search_tool.invoke({"query": question})
    web_results = "\n".join([d["content"] for d in docs])
    web_results = Document(page_content=web_results)
    # Wrap in a list: format_docs and the graders iterate over state["documents"]
    return {"documents": [web_results], "question": question}

4.3 Define the Routing Functions (Edges)

python
def route_question(state):
    """Route the question to web search or RAG."""
    print("---ROUTE QUESTION---")
    question = state["question"]
    source = question_router.invoke({"question": question})
    if source.datasource == "web_search":
        print("---ROUTE QUESTION TO WEB SEARCH---")
        return "web_search"
    elif source.datasource == "vectorstore":
        print("---ROUTE QUESTION TO RAG---")
        return "vectorstore"


def decide_to_generate(state):
    """Decide whether to generate an answer or rewrite the query."""
    print("---ASSESS GRADED DOCUMENTS---")
    filtered_documents = state["documents"]

    if not filtered_documents:
        print("---DECISION: ALL DOCUMENTS ARE NOT RELEVANT TO QUESTION, TRANSFORM QUERY---")
        return "transform_query"
    else:
        print("---DECISION: GENERATE---")
        return "generate"


def grade_generation_v_documents_and_question(state):
    """Grade the generation: is it grounded in the documents, and does it answer the question?"""
    print("---CHECK HALLUCINATIONS---")
    question = state["question"]
    documents = state["documents"]
    generation = state["generation"]

    score = hallucination_grader.invoke(
        {"documents": format_docs(documents), "generation": generation}
    )
    grade = score.binary_score

    if grade == "yes":
        print("---DECISION: GENERATION IS GROUNDED IN DOCUMENTS---")
        print("---GRADE GENERATION vs QUESTION---")
        score = answer_grader.invoke({"question": question, "generation": generation})
        grade = score.binary_score
        if grade == "yes":
            print("---DECISION: GENERATION ADDRESSES QUESTION---")
            return "useful"
        else:
            print("---DECISION: GENERATION DOES NOT ADDRESS QUESTION---")
            return "not useful"
    else:
        print("---DECISION: GENERATION IS NOT GROUNDED IN DOCUMENTS, RE-TRY---")
        return "not supported"
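Stripped of the LLM calls, `grade_generation_v_documents_and_question` reduces to a mapping from two verdicts to three edge labels. Writing it as a pure function (hypothetical name `decide`) makes the routing logic easy to unit-test on its own:

```python
def decide(grounded: bool, answers_question: bool) -> str:
    """Map the two grader verdicts to the edge labels used by the graph."""
    if not grounded:
        return "not supported"  # regenerate from the same documents
    if answers_question:
        return "useful"         # finish
    return "not useful"         # rewrite the query and retrieve again
```

Note that the answer grader is only consulted once the hallucination check passes, so an ungrounded answer is regenerated even if it happens to address the question.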

4.4 Compile the Graph

python
from langgraph.graph import END, StateGraph, START

workflow = StateGraph(GraphState)

# Add nodes
workflow.add_node("web_search", web_search)
workflow.add_node("retrieve", retrieve)
workflow.add_node("grade_documents", grade_documents)
workflow.add_node("generate", generate)
workflow.add_node("transform_query", transform_query)

# Build the edges
workflow.add_conditional_edges(
    START,
    route_question,
    {
        "web_search": "web_search",
        "vectorstore": "retrieve",
    },
)
workflow.add_edge("web_search", "generate")
workflow.add_edge("retrieve", "grade_documents")
workflow.add_conditional_edges(
    "grade_documents",
    decide_to_generate,
    {
        "transform_query": "transform_query",
        "generate": "generate",
    },
)
workflow.add_edge("transform_query", "retrieve")
workflow.add_conditional_edges(
    "generate",
    grade_generation_v_documents_and_question,
    {
        "not supported": "generate",
        "useful": END,
        "not useful": "transform_query",
    },
)

# Compile
app = workflow.compile()
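The compiled graph has loops that end only when a grader passes: `generate` → `generate` on "not supported", and `transform_query` → `retrieve` → `grade_documents` when no document is relevant or the answer is "not useful". LangGraph stops a run that exceeds the `recursion_limit` in its config (for example `app.invoke(inputs, {"recursion_limit": 10})`). The toy runner below, with hypothetical names throughout, imitates that safeguard with a fixed step budget:

```python
from typing import Any, Callable, Dict, Tuple

State = Dict[str, Any]
# A node returns (updated state, name of the next node); "END" stops the run
Node = Callable[[State], Tuple[State, str]]

class StepLimitExceeded(RuntimeError):
    """Raised when the toy runner exhausts its step budget."""

def run(nodes: Dict[str, Node], start: str, state: State, max_steps: int = 10) -> State:
    current = start
    for _ in range(max_steps):
        state, current = nodes[current](state)
        if current == "END":
            return state
    raise StepLimitExceeded(f"no END after {max_steps} steps")

def always_ungrounded(state: State) -> Tuple[State, str]:
    # Stub "generate" whose output never passes the hallucination check
    attempts = state.get("attempts", 0) + 1
    return {**state, "attempts": attempts}, "generate"
```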

5. Running the Examples

Example 1: Routing to web search

python
inputs = {
    "question": "Which player are the Bears expected to draft first in the 2024 NFL draft?"
}
for output in app.stream(inputs):
    for key, value in output.items():
        pprint(f"Node '{key}':")
    pprint("\n---\n")

pprint(value["generation"])

Execution trace

---ROUTE QUESTION---
---ROUTE QUESTION TO WEB SEARCH---
---WEB SEARCH---
---GENERATE---
---CHECK HALLUCINATIONS---
---DECISION: GENERATION IS GROUNDED IN DOCUMENTS---
---GRADE GENERATION vs QUESTION---
---DECISION: GENERATION ADDRESSES QUESTION---

Example 2: Routing to the vectorstore

python
inputs = {"question": "What are the types of agent memory?"}
for output in app.stream(inputs):
    for key, value in output.items():
        pprint(f"Node '{key}':")
    pprint("\n---\n")

pprint(value["generation"])

Execution trace

---ROUTE QUESTION---
---ROUTE QUESTION TO RAG---
---RETRIEVE---
---CHECK DOCUMENT RELEVANCE TO QUESTION---
---GRADE: DOCUMENT RELEVANT---
---GRADE: DOCUMENT RELEVANT---
---GRADE: DOCUMENT RELEVANT---
---GRADE: DOCUMENT RELEVANT---
---DECISION: GENERATE---
---GENERATE---
---CHECK HALLUCINATIONS---
---DECISION: GENERATION IS GROUNDED IN DOCUMENTS---
---GRADE GENERATION vs QUESTION---
---DECISION: GENERATION ADDRESSES QUESTION---

6. Architecture Summary

The core flow of Adaptive RAG:

User query
    │
    ▼
┌──────────────────────────────────┐
│ Query Router                     │
│ Analyze query, pick datasource   │
└──────────────────────────────────┘
    │
    ├─→ Web search ───────────────────────────────┐
    │                                             │
    └─→ Vectorstore retrieval                     │
            │                                     │
            ▼                                     │
    ┌──────────────────────────────────┐          │
    │ Grade Documents                  │          │
    │ Are retrieved docs relevant?     │          │
    └──────────────────────────────────┘          │
            │                                     │
            ├─→ none relevant → rewrite → retry   │
            │                                     │
            └─→ relevant ─────────────────────────┤
                                                  ▼
                                ┌──────────────────────────────────┐
                                │ Generate Answer                  │
                                └──────────────────────────────────┘
                                                  │
                                                  ▼
                                ┌──────────────────────────────────┐
                                │ Hallucination Check              │
                                │ Grounded in the documents?       │
                                └──────────────────────────────────┘
                                                  │
                                                  ├─→ not grounded → regenerate
                                                  ▼
                                ┌──────────────────────────────────┐
                                │ Answer Grader                    │
                                │ Does it answer the question?     │
                                └──────────────────────────────────┘
                                                  │
                                                  ├─→ no → rewrite query
                                                  │
                                                  └─→ yes → return result

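7. Key Design Points

| Component | Role | Output |
|---|---|---|
| Query Router | Analyzes the query and routes it to a suitable datasource | vectorstore / web_search |
| Retrieval Grader | Assesses whether each retrieved document is relevant | yes / no |
| Generator | Generates an answer from the retrieved context | Answer text |
| Hallucination Grader | Checks whether the answer is grounded in the documents | yes (grounded) / no |
| Answer Grader | Assesses whether the answer addresses the question | yes / no |
| Question Rewriter | Rewrites the query to improve retrieval | Rewritten query |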

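Questions to Consider

  1. How would you extend the router to support more datasources (e.g., a knowledge graph or a SQL database)?
  2. In what situations should you add a maximum retry limit?
  3. How would you monitor and debug an Adaptive RAG workflow in production?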
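For question 2, one common pattern is to keep a retry counter in the state and let the conditional edge bail out once it reaches a cap. In this sketch the `retries` field, `MAX_RETRIES`, and the `"give_up"` edge label are hypothetical additions, not part of the tutorial's graph:

```python
from typing import List, TypedDict

MAX_RETRIES = 2  # hypothetical cap on query rewrites

class State(TypedDict):
    question: str
    documents: List[str]
    retries: int  # hypothetical field, not in the tutorial's GraphState

def transform_query_counted(state: State) -> State:
    # Stand-in rewrite; a real node would call question_rewriter here
    return {
        **state,
        "question": state["question"] + " (rephrased)",
        "retries": state["retries"] + 1,
    }

def decide_to_generate_capped(state: State) -> str:
    if state["documents"]:
        return "generate"
    if state["retries"] >= MAX_RETRIES:
        return "give_up"  # hypothetical edge, e.g. to web_search or a fallback answer
    return "transform_query"
```

The same counter idea applies to the regenerate loop: increment a separate field in `generate` and stop routing back to it once the cap is reached.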


Released under the MIT License. Content copyright belongs to the author.