LangGraph Multiple Schemas 详细解读
📚 概述
本文档详细解读 LangGraph 中的 Multiple Schemas(多状态模式) 概念。这是一种高级状态管理技术,允许我们在图的不同部分使用不同的状态结构,从而实现更灵活、更安全、更清晰的状态管理。
🎯 核心概念
什么是 Multiple Schemas?
在之前的学习中,我们通常使用单一状态模式:整个图的所有节点共享同一个状态结构(Schema)。但在实际开发中,这种方式有时会带来问题:
单一状态的问题:
- 状态污染:内部节点可能产生中间数据,但这些数据不应该暴露给用户
- 输入输出不灵活:用户可能只需要提供部分输入,或只需要部分输出
- 节点耦合:所有节点都依赖同一个庞大的状态结构,增加复杂度
Multiple Schemas 的解决方案:
- Private State(私有状态):节点间可以传递私有数据,不暴露给外部
- Input Schema(输入模式):限定用户输入的字段
- Output Schema(输出模式):限定返回给用户的字段
- Internal State(内部状态):图内部使用的完整状态
🎭 实战案例 1:私有状态(Private State)
需求场景
假设我们有一个计算流程:
node_1
接收用户输入foo
,计算出中间结果baz
node_2
使用baz
进行进一步计算,得到最终结果foo
- 关键:
baz
是中间计算值,不应该暴露给用户
系统架构图
用户输入: {foo: 1}
↓
[node_1]
输入: OverallState {foo: 1}
计算: baz = foo + 1 = 2
输出: PrivateState {baz: 2}
↓
[node_2]
输入: PrivateState {baz: 2}
计算: foo = baz + 1 = 3
输出: OverallState {foo: 3}
↓
用户输出: {foo: 3} ← baz 被隐藏了!
🔧 代码实现详解:私有状态
1. 定义状态
from typing_extensions import TypedDict
# 全局状态(对外可见)
class OverallState(TypedDict):
foo: int
# 私有状态(内部使用)
class PrivateState(TypedDict):
baz: int
设计思路:
OverallState
:图的输入和输出状态,用户可见PrivateState
:内部节点之间传递的私有数据- 两者完全独立,互不干扰
2. 定义节点
def node_1(state: OverallState) -> PrivateState:
print("---Node 1---")
return {"baz": state['foo'] + 1}
def node_2(state: PrivateState) -> OverallState:
print("---Node 2---")
return {"foo": state['baz'] + 1}
关键点解析:
node_1 的类型签名
def node_1(state: OverallState) -> PrivateState:
# ^^^^^^^^^^^^^^^^ ^^^^^^^^^^^^^^
# 输入类型 输出类型
- 输入:
OverallState
- 从用户输入或图状态读取foo
- 输出:
PrivateState
- 返回私有字段baz
- 重要:虽然返回类型是
PrivateState
,但返回的数据会被合并到图的总体状态中
node_2 的类型签名
def node_2(state: PrivateState) -> OverallState:
# ^^^^^^^^^^^^^^^ ^^^^^^^^^^^^^^
# 从私有状态读取 写回全局状态
- 输入:
PrivateState
- 读取node_1
产生的私有字段baz
- 输出:
OverallState
- 更新公开字段foo
Python 知识点:类型注解(Type Hints)
Python 的类型注解用于:
- 静态类型检查:帮助 IDE 和工具检测类型错误
- 文档作用:清晰表达函数的输入输出
- LangGraph 使用:告诉 LangGraph 节点使用哪个状态结构
# 类型注解语法
def function(param: InputType) -> OutputType:
return value
# LangGraph 中的特殊用途
# LangGraph 会根据类型注解自动进行状态过滤和映射
def node(state: PrivateState) -> OverallState:
# LangGraph 自动:
# 1. 从总状态中提取 PrivateState 字段传给 node
# 2. 将 node 返回的 OverallState 字段合并回总状态
pass
3. 构建图
from langgraph.graph import StateGraph, START, END
# 使用 OverallState 作为图的主状态
builder = StateGraph(OverallState)
# 添加节点
builder.add_node("node_1", node_1)
builder.add_node("node_2", node_2)
# 添加边
builder.add_edge(START, "node_1")
builder.add_edge("node_1", "node_2")
builder.add_edge("node_2", END)
# 编译
graph = builder.compile()
LangGraph 知识点:状态图初始化
StateGraph(OverallState)
# ^^^^^^^^^^^^^
# 图的主状态类型
- 尽管节点可能使用不同的状态类型(如
PrivateState
),但图本身有一个主状态类型 - 这个主状态类型定义了图的输入和输出接口
- 节点的输入输出类型可以是主状态的子集或超集
4. 执行与观察
result = graph.invoke({"foo": 1})
print(result)
# 输出: {'foo': 3}
执行流程详解:
步骤 | 当前状态 | 节点 | 输入 | 输出 | 状态更新后 |
---|---|---|---|---|---|
1 | {foo: 1} | node_1 | {foo: 1} | {baz: 2} | {foo: 1, baz: 2} |
2 | {foo: 1, baz: 2} | node_2 | {baz: 2} | {foo: 3} | {foo: 3, baz: 2} |
3 | 最终输出过滤 | - | - | - | {foo: 3} ← baz 被过滤掉 |
关键观察:
baz
在图的内部状态中存在(步骤 1-2)- 但最终输出时,只返回
OverallState
中定义的字段(foo
) baz
成功隐藏,实现了私有状态的效果
🎭 实战案例 2:输入/输出模式(Input/Output Schema)
需求场景
构建一个问答系统:
- 用户输入:只需要提供
question
- 内部处理:生成
answer
和notes
(内部思考过程) - 用户输出:只返回
answer
,隐藏notes
系统架构图
用户输入: {question: "hi"} ← 符合 InputState
↓
[thinking_node]
生成: answer="bye", notes="... his name is Lance"
↓
[answer_node]
生成: answer="bye Lance"
↓
用户输出: {answer: "bye Lance"} ← 符合 OutputState,notes 被隐藏
🔧 代码实现详解:输入/输出模式
1. 定义三种状态
from typing_extensions import TypedDict
# 输入状态:用户必须提供的字段
class InputState(TypedDict):
question: str
# 输出状态:返回给用户的字段
class OutputState(TypedDict):
answer: str
# 内部状态:图内部使用的完整状态
class OverallState(TypedDict):
question: str
answer: str
notes: str
设计模式:状态分层
InputState (最小输入)
↓ 包含于
OverallState (完整内部状态)
↓ 过滤为
OutputState (最小输出)
Python 知识点:TypedDict 继承关系
虽然这里没有使用继承,但 TypedDict 支持继承:
# 方式 1:独立定义(如上面代码)
class InputState(TypedDict):
question: str
# 方式 2:继承扩展
class OverallState(InputState): # 继承 InputState
answer: str
notes: str
# 两种方式效果类似,但继承方式确保字段一致性
2. 定义节点(版本 1:不使用输入/输出模式)
先看没有 Input/Output Schema 的版本:
class OverallState(TypedDict):
question: str
answer: str
notes: str
def thinking_node(state: OverallState):
return {"answer": "bye", "notes": "... his name is Lance"}
def answer_node(state: OverallState):
return {"answer": "bye Lance"}
# 构建图
graph = StateGraph(OverallState)
graph.add_node("answer_node", answer_node)
graph.add_node("thinking_node", thinking_node)
graph.add_edge(START, "thinking_node")
graph.add_edge("thinking_node", "answer_node")
graph.add_edge(answer_node, END)
graph = graph.compile()
# 执行
result = graph.invoke({"question": "hi"})
print(result)
# 输出: {'question': 'hi', 'answer': 'bye Lance', 'notes': '... his name is Lance'}
问题: 所有内部字段(包括 notes
)都暴露给用户了!
3. 定义节点(版本 2:使用输入/输出模式)✅
class InputState(TypedDict):
question: str
class OutputState(TypedDict):
answer: str
class OverallState(TypedDict):
question: str
answer: str
notes: str
# 节点 1:只需要 InputState
def thinking_node(state: InputState):
return {"answer": "bye", "notes": "... his name is Lance"}
# 节点 2:使用 OverallState,但声明输出为 OutputState
def answer_node(state: OverallState) -> OutputState:
return {"answer": "bye Lance"}
# 构建图:指定输入和输出模式
graph = StateGraph(
OverallState,
input_schema=InputState, # 限定输入
output_schema=OutputState # 限定输出
)
graph.add_node("answer_node", answer_node)
graph.add_node("thinking_node", thinking_node)
graph.add_edge(START, "thinking_node")
graph.add_edge("thinking_node", "answer_node")
graph.add_edge(answer_node, END)
graph = graph.compile()
# 执行
result = graph.invoke({"question": "hi"})
print(result)
# 输出: {'answer': 'bye Lance'} ← notes 被过滤掉了!
关键改进:
StateGraph
构造时指定input_schema
和output_schema
thinking_node
的输入类型注解为InputState
answer_node
的输出类型注解为OutputState
4. 类型注解的作用
thinking_node 的类型注解
def thinking_node(state: InputState):
# ^^^^^^^^^^
# 只能访问 question 字段
return {"answer": "bye", "notes": "... his name is Lance"}
作用:
- 输入过滤:LangGraph 只传递
InputState
中定义的字段(question
) - 安全性:节点无法意外访问其他字段
- 清晰性:明确表达节点的输入依赖
answer_node 的类型注解
def answer_node(state: OverallState) -> OutputState:
# ^^^^^^^^^^^^^ ^^^^^^^^^^^
# 可以访问所有字段 输出会被过滤
return {"answer": "bye Lance"}
作用:
- 输入:可以访问完整的
OverallState
(包括notes
) - 输出:返回值会被过滤,只保留
OutputState
定义的字段
注意:即使 answer_node
返回了 notes
,也会被过滤掉:
def answer_node(state: OverallState) -> OutputState:
return {
"answer": "bye Lance",
"notes": "some internal note" # 这个会被过滤掉
}
🎓 核心知识点总结
LangGraph 特有概念
1. 私有状态(Private State)
定义: 节点间传递但不对外暴露的状态
实现方式:
# 定义私有状态
class PrivateState(TypedDict):
internal_data: str
# 节点输出私有状态
def node1(state: PublicState) -> PrivateState:
return {"internal_data": "secret"}
# 节点消费私有状态
def node2(state: PrivateState) -> PublicState:
# 使用 state["internal_data"]
return {"result": "processed"}
关键点:
- 私有字段不会出现在图的最终输出中
- 只在图的内部状态中存在
- 通过类型注解控制可见性
2. 输入/输出模式(Input/Output Schema)
定义: 限定图的输入和输出接口
实现方式:
graph = StateGraph(
InternalState, # 图内部的完整状态
input_schema=InputState, # 用户输入的状态结构
output_schema=OutputState # 返回给用户的状态结构
)
三种状态的关系:
状态类型 | 用途 | 字段数量 | 作用域 |
---|---|---|---|
InputState | 用户输入 | 最少(必需字段) | 图入口 |
InternalState | 内部处理 | 最多(所有字段) | 图内部 |
OutputState | 用户输出 | 适中(有用字段) | 图出口 |
典型模式:
InputState ⊆ InternalState ⊇ OutputState
3. 状态过滤机制
LangGraph 自动进行状态过滤:
输入过滤:
# 用户调用
graph.invoke({"question": "hi", "extra": "ignored"})
# ^^^^^^^^^^^^^^^^
# 不在 InputState 中,被忽略
# 如果 InputState 只定义了 question,
# 则只有 question 会被传入图
输出过滤:
# 图的内部状态
internal_state = {
"question": "hi",
"answer": "bye",
"notes": "internal"
}
# 返回给用户(只保留 OutputState 定义的字段)
output = {"answer": "bye"} # question 和 notes 被过滤
4. 节点类型注解的完整语法
def node(state: InputType) -> OutputType:
# ^^^^^^^^^^^^^ ^^^^^^^^^^
# 从图状态读取 写回图状态
# (过滤输入) (过滤输出)
pass
类型注解组合:
输入类型 | 输出类型 | 含义 |
---|---|---|
InputState | OutputState | 读取最小输入,写入特定输出 |
OverallState | PrivateState | 读取所有字段,写入私有字段 |
PrivateState | OverallState | 读取私有字段,写入公开字段 |
OverallState | OverallState | 读写完整状态(默认) |
Python 特有知识点
1. TypedDict 详解
基础用法:
from typing_extensions import TypedDict
class Person(TypedDict):
name: str
age: int
# 使用
person: Person = {"name": "Alice", "age": 30}
特性:
- 类型检查:IDE 会提示类型错误
- 字典本质:仍然是普通字典,只是添加了类型信息
- 运行时:不会进行严格验证(与 Pydantic 不同)
可选字段:
from typing import NotRequired
class OptionalState(TypedDict):
required: str
optional: NotRequired[int] # 可选字段
# 合法
state1: OptionalState = {"required": "value"}
state2: OptionalState = {"required": "value", "optional": 42}
2. 类型注解(Type Hints)
函数注解:
def greet(name: str) -> str:
# ^^^ ^^^
# 参数类型 返回类型
return f"Hello, {name}"
变量注解:
count: int = 0
names: list[str] = ["Alice", "Bob"]
data: dict[str, int] = {"a": 1, "b": 2}
泛型注解:
from typing import List, Dict, Optional, Union
# 旧式(Python 3.8)
names: List[str] = ["Alice"]
mapping: Dict[str, int] = {"a": 1}
optional_value: Optional[int] = None # int 或 None
union_value: Union[int, str] = 42 # int 或 str
# 新式(Python 3.9+)
names: list[str] = ["Alice"]
mapping: dict[str, int] = {"a": 1}
optional_value: int | None = None
union_value: int | str = 42
3. Annotated 类型
在 LangGraph 中常见的 Annotated
用法:
from typing import Annotated
import operator
class State(TypedDict):
# 基础类型
name: str
# 带 reducer 的类型
messages: Annotated[list, operator.add]
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
# 类型: list, 合并策略: operator.add
Annotated 的结构:
Annotated[Type, metadata1, metadata2, ...]
# ^^^^ ^^^^^^^^^^^^^^^^^^^^^^^^^
# 实际类型 元数据(附加信息)
LangGraph 中的用途:
- 第一个参数:实际的数据类型
- 第二个参数:Reducer 函数(如何合并多个更新)
# 例子:列表追加
messages: Annotated[list, operator.add]
# node1 返回 {"messages": ["a"]}
# node2 返回 {"messages": ["b"]}
# 最终状态: {"messages": ["a", "b"]} ← operator.add 自动合并
💡 最佳实践
1. 何时使用私有状态?
✅ 适用场景:
- 内部计算的中间结果(如临时变量、缓存)
- 调试信息(如执行日志、性能指标)
- 敏感信息(不希望暴露给最终用户)
- 节点间的协调数据(如计数器、标志位)
❌ 不适用场景:
- 用户需要的核心数据
- 需要持久化的信息
- 需要在多个公开节点间共享的数据
示例:
# ✅ 好的私有状态使用
class PrivateState(TypedDict):
processing_time: float # 性能统计
api_call_count: int # 调试信息
intermediate_result: str # 中间计算
# ❌ 不应该私有的数据
class PrivateState(TypedDict):
user_id: str # 用户需要知道的
final_result: str # 核心输出数据
2. 何时使用输入/输出模式?
✅ 适用场景:
- API 服务(需要清晰的输入输出接口)
- 多租户系统(不同用户看到不同字段)
- 安全敏感应用(隐藏内部实现细节)
- 大型项目(强制接口规范)
❌ 不适用场景:
- 简单的内部工具
- 原型开发阶段
- 所有字段都需要输入输出的情况
示例:
# ✅ 好的输入/输出设计
class InputState(TypedDict):
user_query: str # 用户只需输入查询
class OutputState(TypedDict):
answer: str # 用户只需要答案
confidence: float # 和置信度
class InternalState(TypedDict):
user_query: str
answer: str
confidence: float
api_calls: int # 内部统计
cached: bool # 内部标志
raw_response: dict # 原始数据
3. 状态设计原则
原则 1:最小化原则
每个状态只包含必需的字段:
# ✅ 好的设计
class InputState(TypedDict):
question: str # 只要必需字段
# ❌ 过度设计
class InputState(TypedDict):
question: str
answer: str # 输入不需要 answer
notes: str # 输入不需要 notes
metadata: dict # 输入不需要 metadata
原则 2:单一职责
每个状态服务于特定目的:
# ✅ 好的分离
class InputState(TypedDict): # 职责:接收用户输入
question: str
class ProcessingState(TypedDict): # 职责:内部处理
question: str
intermediate: str
class OutputState(TypedDict): # 职责:返回结果
answer: str
# ❌ 混淆职责
class MixedState(TypedDict): # 混合了所有职责
question: str # 输入
intermediate: str # 处理
answer: str # 输出
原则 3:清晰的类型注解
始终明确指定节点的输入输出类型:
# ✅ 清晰的类型注解
def process(state: InputState) -> OutputState:
# 一眼就知道输入输出
pass
# ❌ 不清晰的注解
def process(state): # 无法知道使用哪个状态
pass
🚀 进阶技巧
1. 部分状态更新
节点可以只更新状态的部分字段:
class State(TypedDict):
a: str
b: str
c: str
def node(state: State) -> State:
# 只更新 a,不影响 b 和 c
return {"a": "new_value"}
# 执行前: {a: "old", b: "b", c: "c"}
# 执行后: {a: "new_value", b: "b", c: "c"}
2. 状态验证
使用 Pydantic 进行运行时验证:
from pydantic import BaseModel, validator
class ValidatedInput(BaseModel):
question: str
@validator('question')
def question_not_empty(cls, v):
if not v.strip():
raise ValueError('Question cannot be empty')
return v
# 在节点中使用
def node(state):
# 验证输入
validated = ValidatedInput(**state)
# 使用 validated.question
3. 动态状态选择
根据条件选择不同的状态结构:
def conditional_node(state: OverallState):
if state.get("use_detailed"):
# 返回详细状态
return DetailedState(...)
else:
# 返回简单状态
return SimpleState(...)
4. 状态继承
复用状态定义:
class BaseState(TypedDict):
timestamp: float
user_id: str
class ProcessingState(BaseState): # 继承 BaseState
data: str
status: str
# ProcessingState 自动包含 timestamp 和 user_id
📊 状态模式对比
特性 | 单一状态 | 私有状态 | 输入/输出模式 |
---|---|---|---|
复杂度 | 低 | 中 | 高 |
灵活性 | 低 | 中 | 高 |
封装性 | 差 | 好 | 最好 |
适用场景 | 简单图 | 内部计算 | API、服务 |
学习曲线 | 平缓 | 适中 | 陡峭 |
选择建议:
# 简单原型 → 单一状态
graph = StateGraph(SimpleState)
# 需要隐藏中间数据 → 私有状态
def node1(state: PublicState) -> PrivateState: ...
def node2(state: PrivateState) -> PublicState: ...
# 对外服务 → 输入/输出模式
graph = StateGraph(
InternalState,
input_schema=InputState,
output_schema=OutputState
)
🎯 实际应用案例
案例 1:智能客服系统
# 输入:用户只需提供问题
class InputState(TypedDict):
user_question: str
# 输出:返回答案和置信度
class OutputState(TypedDict):
answer: str
confidence: float
# 内部:包含所有处理信息
class InternalState(TypedDict):
user_question: str
answer: str
confidence: float
retrieved_docs: list # 检索的文档(私有)
api_response: dict # API 原始响应(私有)
processing_time: float # 处理时间(私有)
def retrieve_docs(state: InputState) -> InternalState:
# 检索相关文档
docs = vector_db.search(state["user_question"])
return {"retrieved_docs": docs}
def generate_answer(state: InternalState) -> OutputState:
# 生成答案
response = llm.invoke({
"question": state["user_question"],
"context": state["retrieved_docs"]
})
return {
"answer": response.answer,
"confidence": response.confidence
}
graph = StateGraph(
InternalState,
input_schema=InputState,
output_schema=OutputState
)
案例 2:数据处理管道
# 私有状态:中间处理步骤
class PrivateState(TypedDict):
cleaned_data: str
validation_result: bool
# 公开状态:输入和输出
class PublicState(TypedDict):
raw_data: str
processed_data: str
def clean_data(state: PublicState) -> PrivateState:
cleaned = state["raw_data"].strip().lower()
is_valid = len(cleaned) > 0
return {
"cleaned_data": cleaned,
"validation_result": is_valid
}
def process_data(state: PrivateState) -> PublicState:
if state["validation_result"]:
processed = state["cleaned_data"].upper()
return {"processed_data": processed}
else:
return {"processed_data": "ERROR"}
# cleaned_data 和 validation_result 不会出现在最终输出
案例 3:多阶段认证系统
class LoginInput(TypedDict):
username: str
password: str
class SessionState(TypedDict): # 私有
user_id: int
permissions: list[str]
session_token: str
class LoginOutput(TypedDict):
success: bool
session_id: str # 不暴露完整 token
class InternalState(TypedDict):
username: str
password: str
user_id: int
permissions: list[str]
session_token: str
success: bool
session_id: str
def authenticate(state: LoginInput) -> SessionState:
# 验证用户并创建会话
user = db.authenticate(state["username"], state["password"])
return {
"user_id": user.id,
"permissions": user.permissions,
"session_token": generate_token(user.id)
}
def create_response(state: InternalState) -> LoginOutput:
# 创建安全的响应(隐藏敏感信息)
return {
"success": True,
"session_id": hash(state["session_token"])[:16]
}
# session_token 和 user_id 不会暴露给客户端
🔍 常见问题
Q1: 私有状态的数据会被完全删除吗?
答: 不会。私有状态的数据在图的内部状态中仍然存在,只是在最终输出时被过滤掉。
# 内部状态(执行过程中)
internal = {
"foo": 3,
"baz": 2 # ← 仍然存在
}
# 输出状态(返回给用户)
output = {
"foo": 3 # ← baz 被过滤
}
访问内部状态:
# 使用 stream 查看完整状态
for state in graph.stream({"foo": 1}):
print(state) # 可以看到 baz
# 使用 invoke 只能看到过滤后的状态
result = graph.invoke({"foo": 1})
print(result) # 看不到 baz
Q2: 节点的返回值会覆盖还是合并状态?
答: 默认是合并(merge),不是覆盖(replace)。
# 初始状态
state = {"a": 1, "b": 2, "c": 3}
# 节点返回
def node(state):
return {"a": 99} # 只更新 a
# 执行后的状态
state = {"a": 99, "b": 2, "c": 3} # b 和 c 保持不变
除非使用 Reducer:
# 使用自定义 reducer
class State(TypedDict):
data: Annotated[list, my_custom_reducer]
# 可以实现覆盖、追加、合并等任何逻辑
Q3: 可以在不同节点间切换状态类型吗?
答: 可以!这正是 Multiple Schemas 的核心功能。
def node1(state: StateA) -> StateB:
# 输入 StateA,输出 StateB
pass
def node2(state: StateB) -> StateC:
# 输入 StateB,输出 StateC
pass
def node3(state: StateC) -> StateA:
# 输入 StateC,输出 StateA(回到起点)
pass
限制: 所有状态类型的字段必须能够合并到图的主状态类型中。
Q4: InputState 可以包含 OverallState 没有的字段吗?
答: 不建议这样做,因为这些字段无法被存储到图状态中。
# ❌ 不推荐
class InputState(TypedDict):
question: str
extra_field: str # OverallState 没有这个字段
class OverallState(TypedDict):
question: str
answer: str
# 缺少 extra_field
# 当用户输入 {"question": "hi", "extra_field": "value"}
# extra_field 会被忽略或导致错误
正确做法: InputState 应该是 OverallState 的子集。
📖 扩展阅读
- LangGraph Private State 官方文档
- LangGraph Input/Output Schema 官方文档
- Python TypedDict 文档
- Python Type Hints PEP 484
🎁 总结
Multiple Schemas(多状态模式) 是 LangGraph 中的高级特性,提供了三种强大的能力:
私有状态(Private State)
- 节点间传递私有数据
- 隐藏内部实现细节
- 保持输出简洁
输入模式(Input Schema)
- 限定用户输入
- 提高安全性
- 强制接口规范
输出模式(Output Schema)
- 过滤输出字段
- 保护敏感信息
- 提供清晰的 API
核心价值:
- 更好的封装:隐藏内部细节
- 更清晰的接口:明确的输入输出
- 更高的安全性:防止数据泄露
- 更强的可维护性:代码结构清晰
掌握 Multiple Schemas,你就能构建出更加专业、安全、易用的 LangGraph 应用!