Skip to content

Module-3 小结和复习:状态管理精通指南

来自图灵奖获得者的总结寄语

恭喜你完成了 Module-3 的学习!状态管理是 LangGraph 的灵魂,也是构建复杂 AI 系统的基石。你已经掌握了从基础的 State Schema 设计,到高级的 Reducer 函数;从单一状态到多状态模式;从简单的消息过滤到智能的对话摘要;以及外部数据库的持久化存储。

在我几十年的计算机科学研究生涯中,我见证了状态管理从简单的变量到复杂的分布式系统的演进。LangGraph 的状态管理机制,优雅地结合了函数式编程的纯函数思想、类型系统的安全性保障,以及针对 AI 应用的创新设计。这不仅是一套技术,更是一种思维方式。

记住:优秀的架构师不是写最多代码的人,而是设计最优状态结构的人。本章的知识将伴随你的整个 LangGraph 开发生涯,反复实践,融会贯通,你将能够构建真正具有"智慧"的 AI 系统。

— 向所有追求卓越的你致敬


📚 术语表

术语名称LangGraph 定义和解读Python 定义和说明重要程度
State Schema定义 LangGraph 图数据结构的模式,可用 TypedDict、Dataclass 或 Pydantic状态类型定义,决定节点间数据传递的结构和类型⭐⭐⭐⭐⭐
Reducer定义如何合并多个节点对同一字段的更新,解决并行冲突函数签名 (旧值, 新值) -> 合并值,如 operator.add⭐⭐⭐⭐⭐
Multiple Schemas在同一图中使用不同的状态结构,实现私有状态和输入输出隔离通过类型注解指定节点的输入输出类型,自动过滤字段⭐⭐⭐⭐⭐
trim_messages基于 token 数量智能裁剪消息列表的工具函数trim_messages(msgs, max_tokens, strategy, token_counter)⭐⭐⭐⭐⭐
RemoveMessage用于从状态中永久删除消息的特殊消息类型RemoveMessage(id="msg_id") 配合 add_messages 使用⭐⭐⭐⭐⭐
Summarization使用 LLM 压缩对话历史为摘要文本的技术定期生成摘要并删除旧消息,保留关键信息⭐⭐⭐⭐⭐
Checkpointer负责保存和加载图状态的持久化组件MemorySaver(内存)、SqliteSaver(SQLite)、PostgresSaver⭐⭐⭐⭐⭐
Thread用于隔离不同对话会话的线程标识符配置中的 thread_id,每个 ID 独立的状态空间⭐⭐⭐⭐⭐
MessagesStateLangGraph 内置状态基类,包含 messages 和 add_messages reducerclass State(MessagesState): summary: str 继承并扩展⭐⭐⭐⭐⭐
add_messages专门用于消息列表的 Reducer,支持追加、修改、删除通过消息 ID 智能合并,识别 RemoveMessage⭐⭐⭐⭐⭐


📋 本章核心知识回顾

学习地图

mermaid
graph TB
    A[3.1 State Schema] -->|定义数据结构| B[3.2 State Reducers]
    A -->|控制可见性| C[3.3 Multiple Schemas]
    B -->|支持并发| D[并行节点]
    C -->|保护隐私| E[API 接口]
    D -->|产生大量消息| F[3.4 Trim & Filter]
    E -->|需要优化| F
    F -->|进一步压缩| G[3.5 Summarization]
    G -->|持久化| H[3.6 External Memory]
    H -->|完整系统| I[生产级聊天机器人]

    style A fill:#e1f5ff
    style B fill:#ffe1e1
    style C fill:#fff4e1
    style F fill:#e1ffe1
    style G fill:#f0e1ff
    style H fill:#ffe1f0
    style I fill:#ffe6e6

六大核心技术速查表

技术核心问题解决方案关键 API适用场景
State Schema如何定义数据结构?TypedDict/Dataclass/Pydanticclass State(TypedDict)所有项目
Reducers并行节点如何合并状态?operator.add/add_messagesAnnotated[list, add]并行执行
Multiple Schemas如何隐藏内部数据?Private/Input/Output Stateinput_schema=...API 服务
Trim & Filter如何控制消息数量?RemoveMessage/trim_messagestrim_messages(max_tokens)长对话
Summarization如何压缩历史?LLM 生成摘要RemoveMessage + 摘要超长对话
External Memory如何持久化状态?SqliteSaver/PostgresSaverSqliteSaver(conn)生产环境

🎯 复习题目列表

本复习包含 15 道综合性问题,分为三个难度级别:

基础理解题(第 1-5 题)⭐

测试对核心概念的理解

代码实现题(第 6-10 题)⭐⭐

测试实际编程能力

架构设计题(第 11-15 题)⭐⭐⭐

测试系统设计和最佳实践


📚 详细问答解析

问题 1:State Schema 三种定义方式的本质区别是什么?

点击查看详细解答

核心区别

三种方式的本质区别在于类型检查的时机功能丰富度

特性TypedDictDataclassPydantic
类型检查时机静态(IDE/mypy)静态(IDE/mypy)静态+运行时
访问语法state["key"]state.keystate.key
运行时验证❌ 无❌ 无✅ 自动验证
性能最快(0 开销)快(轻微开销)稍慢(验证开销)
默认值
自定义方法
复杂验证✅ (validators)

代码对比

python
# TypedDict - 最轻量
from typing_extensions import TypedDict
from typing import Literal

class TypedDictState(TypedDict):
    name: str
    mood: Literal["happy", "sad"]

# 使用
state = {"name": "Alice", "mood": "angry"}  # ❌ IDE 警告,但运行不报错
print(state["name"])  # 字典访问

# Dataclass - 更面向对象
from dataclasses import dataclass, field

@dataclass
class DataclassState:
    name: str
    mood: Literal["happy", "sad"]
    count: int = 0  # 默认值
    tags: list = field(default_factory=list)  # 可变默认值

    def is_happy(self):  # 自定义方法
        return self.mood == "happy"

# 使用
state = DataclassState(name="Alice", mood="happy")
print(state.name)  # 点号访问
print(state.is_happy())  # True

# Pydantic - 最严格
from pydantic import BaseModel, field_validator

class PydanticState(BaseModel):
    name: str
    mood: str

    @field_validator('mood')
    @classmethod
    def validate_mood(cls, v):
        if v not in ["happy", "sad"]:
            raise ValueError(f"Invalid mood: {v}")
        return v

# 使用
try:
    state = PydanticState(name="Alice", mood="angry")
except ValueError as e:
    print(f"验证失败: {e}")  # ✅ 运行时捕获错误

选择决策树

python
def choose_state_type(project):
    """选择合适的状态类型"""

    # 1. 需要运行时验证?(外部输入、关键数据)
    if project.needs_runtime_validation:
        return "Pydantic BaseModel"

    # 2. 需要默认值或自定义方法?
    if project.needs_defaults or project.needs_methods:
        return "Dataclass"

    # 3. 追求最佳性能?(大多数情况)
    return "TypedDict"

实际项目建议

python
# 小型项目(< 5 个节点)
class State(TypedDict):
    messages: list
    user_id: str

# 中型项目(5-20 个节点)
@dataclass
class State:
    messages: list = field(default_factory=list)
    user_id: str = ""
    config: dict = field(default_factory=dict)

# 大型项目(> 20 个节点)
class State(BaseModel):
    messages: list
    user_id: str
    config: AppConfig  # 嵌套 Pydantic 模型

    @field_validator('user_id')
    @classmethod
    def validate_user_id(cls, v):
        if not v or len(v) < 3:
            raise ValueError("Invalid user_id")
        return v

常见陷阱

python
# 陷阱 1: TypedDict 不做运行时检查
class State(TypedDict):
    age: int

state = {"age": "not_an_int"}  # ❌ 类型错误但不报错
# 解决:使用 Pydantic 或手动验证

# 陷阱 2: Dataclass 可变默认值
@dataclass
class State:
    tags: list = []  # ❌ 危险!所有实例共享同一个列表

# 解决:
@dataclass
class State:
    tags: list = field(default_factory=list)  # ✅

# 陷阱 3: Pydantic 性能开销
class State(BaseModel):
    data: list[dict]  # 深层嵌套

    @field_validator('data')
    @classmethod
    def validate_data(cls, v):
        # 每次创建实例都会验证!
        for item in v:
            complex_validation(item)  # 可能很慢
        return v

# 解决:只在必要时验证,或使用缓存

关键洞察

TypedDict 是 LangGraph 的默认选择,因为它:

  • 性能最优(零运行时开销)
  • 语法简洁(最少样板代码)
  • 与 LangGraph 的设计哲学一致(信任开发者)

只有在明确需要运行时验证或复杂功能时,才选择 Dataclass 或 Pydantic。


问题 2:Reducer 函数的工作原理是什么?为什么并行节点必须使用 Reducer?

点击查看详细解答

Reducer 的本质

Reducer 是一个状态合并函数,签名为:

python
def reducer(left: T, right: T) -> T:
    """
    left: 当前状态中的旧值
    right: 节点返回的新值
    返回: 合并后的值
    """
    pass

为什么需要 Reducer?

场景:并行节点的状态冲突

python
# 图结构
START → node_1 → ┬→ node_2 → END
                 └→ node_3 → END

# 执行流程
class State(TypedDict):
    count: int

def node_1(state):
    return {"count": 1}

def node_2(state):
    return {"count": state["count"] + 1}  # 期望 2

def node_3(state):
    return {"count": state["count"] + 1}  # 期望 2

# 问题:node_2 和 node_3 并行执行
# 两者都返回 {"count": 2}
# LangGraph 不知道该保留哪个 → InvalidUpdateError ❌

解决方案:使用 Reducer

python
from operator import add
from typing import Annotated

class State(TypedDict):
    count: Annotated[list[int], add]  # ⭐ 关键!

def node_1(state):
    return {"count": [1]}

def node_2(state):
    return {"count": [2]}  # 返回列表

def node_3(state):
    return {"count": [2]}  # 返回列表

# LangGraph 执行流程:
# 1. node_1 返回 [1] → state["count"] = [1]
# 2. node_2 返回 [2]
# 3. node_3 返回 [2]
# 4. Reducer 合并:[1] + [2] + [2] = [1, 2, 2] ✅

Reducer 工作机制详解

python
# LangGraph 内部伪代码
def apply_node_update(current_state, node_return, reducer):
    """应用节点更新"""
    for key, new_value in node_return.items():
        old_value = current_state[key]

        # 如果有 Reducer,使用 Reducer 合并
        if has_reducer(key):
            current_state[key] = reducer(old_value, new_value)
        else:
            # 没有 Reducer,直接覆盖
            current_state[key] = new_value

    return current_state

常用 Reducer 对比

python
from operator import add
import operator

# 1. operator.add - 列表拼接
class State(TypedDict):
    results: Annotated[list, add]

# node_1 返回 [1, 2]
# node_2 返回 [3]
# 合并: [1, 2] + [3] = [1, 2, 3]

# 2. add_messages - 消息管理
from langgraph.graph.message import add_messages

class State(TypedDict):
    messages: Annotated[list, add_messages]

# 自动处理:
# - 追加新消息
# - 基于 ID 更新消息
# - RemoveMessage 删除消息

# 3. max/min - 数值聚合
class State(TypedDict):
    max_score: Annotated[int, max]
    min_score: Annotated[int, min]

# node_1 返回 {"max_score": 10}
# node_2 返回 {"max_score": 15}
# 合并: max(10, 15) = 15

# 4. 自定义 Reducer - 去重
def unique_add(left: list, right: list) -> list:
    """去重后追加"""
    if not left:
        left = []
    if not right:
        right = []
    return list(set(left + right))

class State(TypedDict):
    tags: Annotated[list[str], unique_add]

实战案例:并行 API 调用

python
from typing import Annotated
from operator import add

class APIState(TypedDict):
    query: str
    results: Annotated[list[dict], add]  # ⭐ Reducer

def call_api_1(state: APIState):
    """调用 API 1"""
    result = {"source": "API1", "data": fetch_api_1(state["query"])}
    return {"results": [result]}

def call_api_2(state: APIState):
    """调用 API 2"""
    result = {"source": "API2", "data": fetch_api_2(state["query"])}
    return {"results": [result]}

def call_api_3(state: APIState):
    """调用 API 3"""
    result = {"source": "API3", "data": fetch_api_3(state["query"])}
    return {"results": [result]}

# 构建图
builder = StateGraph(APIState)
builder.add_node("api1", call_api_1)
builder.add_node("api2", call_api_2)
builder.add_node("api3", call_api_3)

# 并行调用
builder.add_edge(START, "api1")
builder.add_edge(START, "api2")
builder.add_edge(START, "api3")
builder.add_edge("api1", END)
builder.add_edge("api2", END)
builder.add_edge("api3", END)

graph = builder.compile()

# 🎨 可视化图结构
from IPython.display import Image, display
display(Image(graph.get_graph().draw_mermaid_png()))

# 执行
result = graph.invoke({"query": "weather", "results": []})
# result["results"] = [
#     {"source": "API1", "data": ...},
#     {"source": "API2", "data": ...},
#     {"source": "API3", "data": ...}
# ] ✅ 三个结果自动合并

高级自定义 Reducer

python
# 示例 1:限制列表大小
def limited_add(max_size: int):
    """工厂函数:创建限制大小的 Reducer"""
    def reducer(left: list | None, right: list | None) -> list:
        if not left:
            left = []
        if not right:
            right = []
        combined = left + right
        return combined[-max_size:]  # 只保留最后 N 个
    return reducer

class State(TypedDict):
    recent_messages: Annotated[list, limited_add(10)]

# 示例 2:带优先级的合并
from dataclasses import dataclass

@dataclass
class PriorityItem:
    value: str
    priority: int

def priority_merge(left: list[PriorityItem], right: list[PriorityItem]) -> list[PriorityItem]:
    """按优先级合并并排序"""
    combined = (left or []) + (right or [])
    return sorted(combined, key=lambda x: x.priority, reverse=True)

class State(TypedDict):
    tasks: Annotated[list[PriorityItem], priority_merge]

# 示例 3:智能去重(保留最新)
def smart_unique(left: list[dict], right: list[dict]) -> list[dict]:
    """基于 ID 去重,保留最新版本"""
    seen = {}
    for item in (left or []) + (right or []):
        item_id = item.get("id")
        if item_id:
            seen[item_id] = item  # 后来的覆盖前面的
    return list(seen.values())

class State(TypedDict):
    items: Annotated[list[dict], smart_unique]

常见错误

python
# 错误 1:忘记返回列表
class State(TypedDict):
    results: Annotated[list, add]

def node(state):
    return {"results": "value"}  # ❌ 应该返回 ["value"]

# 错误 2:Reducer 不处理 None
def bad_reducer(left: list, right: list) -> list:
    return left + right  # ❌ 如果 left 或 right 是 None 会报错

# 正确:
def good_reducer(left: list | None, right: list | None) -> list:
    if not left:
        left = []
    if not right:
        right = []
    return left + right  # ✅

# 错误 3:Reducer 修改原对象
def bad_reducer(left: list, right: list) -> list:
    left.extend(right)  # ❌ 修改了原对象
    return left

# 正确:
def good_reducer(left: list, right: list) -> list:
    return left + right  # ✅ 创建新对象

关键洞察

Reducer 的三个核心作用:

  1. 解决并发冲突:多个节点可以安全地更新同一字段
  2. 定义合并语义:明确指定如何组合多个更新
  3. 保持状态一致性:确保状态更新的可预测性

何时必须使用 Reducer?

  • 图中存在并行节点
  • 多个节点可能更新同一字段
  • 需要追加而非覆盖(如消息历史)

问题 3:add_messages Reducer 的三大核心功能是什么?各自的实现原理?

点击查看详细解答

三大核心功能

add_messages 是 LangGraph 最强大的内置 Reducer,提供三大功能:

  1. 消息追加(Append)
  2. 消息修改(Update)
  3. 消息删除(Remove)

功能 1:消息追加

原理: 将新消息追加到现有消息列表末尾

python
from langgraph.graph.message import add_messages
from langchain_core.messages import HumanMessage, AIMessage

# 初始状态
messages = [
    HumanMessage("Hi", id="1"),
    AIMessage("Hello!", id="2")
]

# 追加新消息
new_message = HumanMessage("How are you?", id="3")
result = add_messages(messages, new_message)

# 结果
# [
#     HumanMessage("Hi", id="1"),
#     AIMessage("Hello!", id="2"),
#     HumanMessage("How are you?", id="3")  # ✅ 追加成功
# ]

在 LangGraph 中使用:

python
from langgraph.graph import MessagesState

class State(MessagesState):
    pass  # 自动包含 messages: Annotated[list, add_messages]

def chat_node(state: State):
    # 生成回复
    response = llm.invoke(state["messages"])
    return {"messages": [response]}  # ✅ 自动追加

功能 2:消息修改(基于 ID)

原理: 如果新消息的 ID 与现有消息相同,则覆盖旧消息

python
# 初始状态
messages = [
    HumanMessage("I like cats", id="msg_1"),
    AIMessage("Great!", id="msg_2")
]

# 修改消息(相同 ID)
updated_message = HumanMessage("I like dogs", id="msg_1")  # ⭐ 相同 ID
result = add_messages(messages, updated_message)

# 结果
# [
#     HumanMessage("I like dogs", id="msg_1"),  # ✅ 内容被更新
#     AIMessage("Great!", id="msg_2")
# ]

应用场景:用户编辑消息

python
class ChatState(MessagesState):
    edit_mode: bool = False

def handle_edit(state: ChatState):
    """处理用户编辑"""
    if state["edit_mode"]:
        # 获取最后一条用户消息
        last_user_msg = [m for m in state["messages"] if isinstance(m, HumanMessage)][-1]

        # 创建编辑后的消息(使用相同 ID)
        edited_msg = HumanMessage(
            content=state["edited_content"],
            id=last_user_msg.id  # ⭐ 相同 ID = 覆盖
        )

        return {"messages": [edited_msg], "edit_mode": False}

    return {}

功能 3:消息删除

原理: 使用 RemoveMessage 标记要删除的消息

python
from langchain_core.messages import RemoveMessage

# 初始状态
messages = [
    HumanMessage("Msg 1", id="1"),
    HumanMessage("Msg 2", id="2"),
    HumanMessage("Msg 3", id="3"),
    HumanMessage("Msg 4", id="4")
]

# 删除前两条消息
delete_ops = [
    RemoveMessage(id="1"),
    RemoveMessage(id="2")
]
result = add_messages(messages, delete_ops)

# 结果
# [
#     HumanMessage("Msg 3", id="3"),
#     HumanMessage("Msg 4", id="4")
# ] ✅ 前两条被删除

实战:滑动窗口对话

python
from langgraph.graph import MessagesState
from langchain_core.messages import RemoveMessage

class WindowState(MessagesState):
    window_size: int = 10  # 只保留最近 10 条

def sliding_window_node(state: WindowState):
    """实现滑动窗口"""
    messages = state["messages"]
    window_size = state["window_size"]

    # 如果超过窗口大小
    if len(messages) > window_size:
        # 计算需要删除的数量
        num_to_delete = len(messages) - window_size

        # 创建删除操作
        delete_ops = [
            RemoveMessage(id=m.id)
            for m in messages[:num_to_delete]
        ]

        return {"messages": delete_ops}

    return {}

三大功能组合使用

实战案例:智能消息管理器

python
from langchain_core.messages import RemoveMessage, HumanMessage, AIMessage
from langgraph.graph import MessagesState

class SmartChatState(MessagesState):
    max_messages: int = 20
    keep_system: bool = True

def smart_message_manager(state: SmartChatState):
    """智能管理消息:
    1. 保留系统消息
    2. 删除旧对话
    3. 支持消息编辑
    """
    messages = state["messages"]
    max_msgs = state["max_messages"]

    updates = []

    # 1. 如果有编辑请求
    if state.get("edit_last"):
        last_user = [m for m in messages if isinstance(m, HumanMessage)][-1]
        edited = HumanMessage(
            content=state["edit_content"],
            id=last_user.id  # ⭐ 功能 2:修改
        )
        updates.append(edited)

    # 2. 如果超过最大消息数
    if len(messages) > max_msgs:
        # 分离系统消息和普通消息
        from langchain_core.messages import SystemMessage
        system_msgs = [m for m in messages if isinstance(m, SystemMessage)]
        other_msgs = [m for m in messages if not isinstance(m, SystemMessage)]

        # 删除旧消息(保留系统消息)
        if state["keep_system"]:
            num_to_delete = len(other_msgs) - (max_msgs - len(system_msgs))
            if num_to_delete > 0:
                delete_ops = [
                    RemoveMessage(id=m.id)  # ⭐ 功能 3:删除
                    for m in other_msgs[:num_to_delete]
                ]
                updates.extend(delete_ops)
        else:
            # 删除最旧的消息
            num_to_delete = len(messages) - max_msgs
            delete_ops = [
                RemoveMessage(id=m.id)
                for m in messages[:num_to_delete]
            ]
            updates.extend(delete_ops)

    # 3. 追加新消息(如果有)
    if state.get("new_user_message"):
        new_msg = HumanMessage(state["new_user_message"])  # ⭐ 功能 1:追加
        updates.append(new_msg)

    return {"messages": updates} if updates else {}

add_messages 内部实现(简化版)

python
def add_messages(left: list[BaseMessage], right: list[BaseMessage] | BaseMessage) -> list[BaseMessage]:
    """
    LangGraph 的 add_messages 实现原理(简化)
    """
    # 1. 规范化输入
    if not isinstance(right, list):
        right = [right]

    # 2. 复制左侧消息
    result = list(left) if left else []

    # 3. 处理右侧消息
    for msg in right:
        if isinstance(msg, RemoveMessage):
            # 功能 3:删除消息
            result = [m for m in result if m.id != msg.id]
        else:
            # 检查是否是更新操作
            existing_index = None
            for i, existing_msg in enumerate(result):
                if existing_msg.id == msg.id:
                    existing_index = i
                    break

            if existing_index is not None:
                # 功能 2:更新消息(相同 ID)
                result[existing_index] = msg
            else:
                # 功能 1:追加消息(新 ID)
                result.append(msg)

    return result

最佳实践

python
# ✅ 推荐:使用 MessagesState
from langgraph.graph import MessagesState

class State(MessagesState):
    # 自动拥有 messages 字段
    user_id: str

# ❌ 不推荐:手动定义
from typing import Annotated
from langgraph.graph.message import add_messages

class State(TypedDict):
    messages: Annotated[list, add_messages]  # 多写代码
    user_id: str

# ✅ 推荐:返回列表
def node(state):
    return {"messages": [new_message]}  # 单个消息也用列表

# ❌ 不推荐:返回单个消息
def node(state):
    return {"messages": new_message}  # 虽然可以,但不一致

# ✅ 推荐:批量删除使用列表推导
delete_ops = [RemoveMessage(id=m.id) for m in old_messages]
return {"messages": delete_ops}

# ❌ 不推荐:循环追加
for msg in old_messages:
    # 多次更新,效率低
    return {"messages": [RemoveMessage(id=msg.id)]}

性能优化技巧

python
# 技巧 1:批量操作优于多次操作
# ❌ 低效
for i in range(10):
    graph.invoke({"messages": [HumanMessage(f"Msg {i}")]}, config)

# ✅ 高效
all_messages = [HumanMessage(f"Msg {i}") for i in range(10)]
graph.invoke({"messages": all_messages}, config)

# 技巧 2:使用生成器延迟计算
def get_delete_ops(messages, cutoff):
    """使用生成器避免创建大列表"""
    for msg in messages:
        if msg.timestamp < cutoff:
            yield RemoveMessage(id=msg.id)

# 使用
delete_ops = list(get_delete_ops(state["messages"], cutoff))
return {"messages": delete_ops}

关键洞察

add_messages 是状态管理的瑞士军刀

  1. 追加:支持对话历史的自然增长
  2. 修改:允许用户编辑和 AI 重新生成
  3. 删除:实现滑动窗口和内存优化

这三个功能覆盖了聊天应用的所有消息管理需求,是 LangGraph 最常用的 Reducer。


问题 4:Multiple Schemas 如何实现状态的可见性控制?三种模式的作用?

点击查看详细解答

可见性控制的本质

Multiple Schemas 通过类型注解StateGraph 配置实现三层可见性控制:

InputState (输入层) → InternalState (内部层) → OutputState (输出层)
    用户可见           节点内部使用            用户可见

三种模式详解

模式 1:Private State(私有状态)

作用: 节点间传递私有数据,不暴露给用户

实现原理:

python
# 1. 定义公开状态和私有状态
class PublicState(TypedDict):
    result: int

class PrivateState(TypedDict):
    intermediate: int  # 私有字段

# 2. 使用类型注解控制可见性
def node_1(state: PublicState) -> PrivateState:
    """
    输入:PublicState(只能访问 result)
    输出:PrivateState(返回私有数据)
    """
    result = state["result"]
    intermediate = result * 2
    return {"intermediate": intermediate}

def node_2(state: PrivateState) -> PublicState:
    """
    输入:PrivateState(可以访问 intermediate)
    输出:PublicState(返回公开数据)
    """
    final = state["intermediate"] + 10
    return {"result": final}

# 3. 构建图(使用 PublicState 作为主状态)
builder = StateGraph(PublicState)
builder.add_node("node_1", node_1)
builder.add_node("node_2", node_2)
builder.add_edge(START, "node_1")
builder.add_edge("node_1", "node_2")
builder.add_edge("node_2", END)

graph = builder.compile()

# 🎨 可视化图结构
from IPython.display import Image, display
display(Image(graph.get_graph().draw_mermaid_png()))

# 4. 执行
result = graph.invoke({"result": 5})
# result = {"result": 20}  ← intermediate 被隐藏

执行流程:

用户输入: {"result": 5}

node_1: 读取 result=5, 返回 {"intermediate": 10}

内部状态: {"result": 5, "intermediate": 10}  # LangGraph 内部合并

node_2: 读取 intermediate=10, 返回 {"result": 20}

内部状态: {"result": 20, "intermediate": 10}

输出过滤: {"result": 20}  ← intermediate 被过滤掉

关键机制:

python
# LangGraph 内部伪代码
def execute_node(node, state, node_input_type, node_output_type):
    # 1. 输入过滤:只传递节点声明的输入类型字段
    filtered_input = filter_by_type(state, node_input_type)

    # 2. 执行节点
    node_output = node(filtered_input)

    # 3. 合并输出:将节点返回的字段合并到内部状态
    internal_state.update(node_output)

    return internal_state
模式 2:Input Schema(输入模式)

作用: 限定用户必须提供的输入字段

实现原理:

python
# 1. 定义输入模式
class InputState(TypedDict):
    question: str  # 用户只需提供问题

# 2. 定义内部模式(完整状态)
class InternalState(TypedDict):
    question: str
    answer: str
    confidence: float
    processing_time: float  # 内部统计

# 3. 构建图时指定 input_schema
graph = StateGraph(
    InternalState,           # 内部使用完整状态
    input_schema=InputState  # ⭐ 限定输入
)

# 4. 用户调用时只能提供 InputState 定义的字段
result = graph.invoke({"question": "What is AI?"})
# ✅ 合法

result = graph.invoke({
    "question": "What is AI?",
    "answer": "..."  # ❌ 错误:不在 InputState 中
})

输入验证机制:

python
# LangGraph 内部伪代码
def invoke(user_input, config):
    # 1. 验证输入字段
    if hasattr(graph, 'input_schema'):
        validate_input(user_input, graph.input_schema)
        # 只保留 input_schema 定义的字段
        user_input = filter_by_schema(user_input, graph.input_schema)

    # 2. 初始化内部状态
    internal_state = initialize_state(graph.state_schema)
    internal_state.update(user_input)

    # 3. 执行图
    final_state = execute_graph(internal_state)

    return final_state
模式 3:Output Schema(输出模式)

作用: 过滤输出字段,只返回用户需要的数据

实现原理:

python
# 1. 定义输出模式
class OutputState(TypedDict):
    answer: str       # 用户需要
    confidence: float # 用户需要

# 2. 定义内部模式
class InternalState(TypedDict):
    question: str
    answer: str
    confidence: float
    processing_time: float  # 内部统计,不返回

# 3. 构建图时指定 output_schema
graph = StateGraph(
    InternalState,
    output_schema=OutputState  # ⭐ 限定输出
)

# 4. 执行图
result = graph.invoke({"question": "What is AI?"})
# result = {
#     "answer": "...",
#     "confidence": 0.95
# }  ← processing_time 被过滤掉

输出过滤机制:

python
# LangGraph 内部伪代码
def invoke(user_input, config):
    # 执行图
    final_state = execute_graph(user_input)

    # 输出过滤
    if hasattr(graph, 'output_schema'):
        return filter_by_schema(final_state, graph.output_schema)

    return final_state

三种模式组合使用

完整示例:问答系统

python
from typing_extensions import TypedDict
from langgraph.graph import StateGraph, START, END

# 1. 定义三种状态
class InputState(TypedDict):
    """用户输入:只需提供问题"""
    question: str

class OutputState(TypedDict):
    """用户输出:返回答案和置信度"""
    answer: str
    confidence: float

class InternalState(TypedDict):
    """内部状态:完整的处理数据"""
    question: str
    answer: str
    confidence: float
    retrieved_docs: list      # 私有:检索的文档
    llm_calls: int           # 私有:LLM 调用次数
    processing_time: float   # 私有:处理时间

# 2. 定义节点
def retrieve_node(state: InputState):
    """检索节点:只需要 question"""
    docs = vector_db.search(state["question"])
    return {
        "retrieved_docs": docs,
        "llm_calls": 0
    }

def generate_node(state: InternalState) -> OutputState:
    """生成节点:
    输入:可以访问所有内部字段
    输出:只返回 OutputState 字段
    """
    import time
    start = time.time()

    # 使用检索的文档生成答案
    answer = llm.invoke({
        "question": state["question"],
        "context": state["retrieved_docs"]
    })

    return {
        "answer": answer.content,
        "confidence": 0.95,
        "llm_calls": state["llm_calls"] + 1,
        "processing_time": time.time() - start
    }

# 3. 构建图
builder = StateGraph(
    InternalState,              # 内部完整状态
    input_schema=InputState,    # 限定输入
    output_schema=OutputState   # 限定输出
)

builder.add_node("retrieve", retrieve_node)
builder.add_node("generate", generate_node)
builder.add_edge(START, "retrieve")
builder.add_edge("retrieve", "generate")
builder.add_edge("generate", END)

graph = builder.compile()

# 🎨 可视化图结构
from IPython.display import Image, display
display(Image(graph.get_graph().draw_mermaid_png()))

# 4. 使用
result = graph.invoke({"question": "What is LangGraph?"})
print(result)
# {
#     "answer": "LangGraph is a library for building...",
#     "confidence": 0.95
# }
# ← retrieved_docs, llm_calls, processing_time 都被隐藏

可见性控制的层次结构

python
# 完整的可见性层次
┌─────────────────────────────────┐
│     User (External World)       │
│  只看到 InputState/OutputState   │
└─────────────┬───────────────────┘

      InputState (输入验证)


┌─────────────────────────────────┐
│    LangGraph Internal           │
│  完整的 InternalState           │
│  ┌─────────────────────────┐  │
│  │ node_1 (InputState)     │  │
│  │    ↓                    │  │
│  │ [PrivateState created]  │  │
│  │    ↓                    │  │
│  │ node_2 (PrivateState)   │  │
│  │    ↓                    │  │
│  │ [PublicState updated]   │  │
│  └─────────────────────────┘  │
└─────────────┬───────────────────┘

      OutputState (输出过滤)


┌─────────────────────────────────┐
│     User (External World)       │
│  只看到 OutputState             │
└─────────────────────────────────┘

实战应用场景

场景 1:API 服务

python
class APIInput(TypedDict):
    """API 请求"""
    user_id: str
    query: str

class APIOutput(TypedDict):
    """API 响应"""
    result: str
    status: str

class InternalState(TypedDict):
    """内部处理"""
    user_id: str
    query: str
    result: str
    status: str
    api_key: str        # 私有:不返回
    rate_limit: int     # 私有:不返回
    cost: float         # 私有:不返回

graph = StateGraph(
    InternalState,
    input_schema=APIInput,
    output_schema=APIOutput
)
# 用户永远看不到 api_key, rate_limit, cost

场景 2:多租户系统

python
class TenantInput(TypedDict):
    tenant_id: str
    request: str

class TenantOutput(TypedDict):
    response: str

class InternalState(TypedDict):
    tenant_id: str
    request: str
    response: str
    tenant_config: dict    # 私有:租户配置
    usage_stats: dict      # 私有:使用统计
    internal_errors: list  # 私有:错误日志

# 每个租户只能看到自己的输入输出,
# 看不到其他租户的数据和系统内部信息

关键洞察

Multiple Schemas 实现了状态的"最小权限原则":

  1. 输入层:用户只需提供必需信息
  2. 内部层:系统可以使用完整信息
  3. 输出层:用户只能看到允许的信息

这不仅提高了安全性,还使 API 更清晰、更易维护。


问题 5:消息管理的三种技术(RemoveMessage、Filter、Trim)各自的适用场景?

点击查看详细解答

三种技术对比

技术修改状态Token 控制复杂度历史保留适用场景
RemoveMessage✅ 永久删除间接控制❌ 部分丢失确定不需要的历史
消息过滤❌ 不修改按数量控制✅ 完整保留简单的"最近 N 条"
Token 裁剪❌ 不修改按 Token 控制✅ 完整保留需要精确成本控制

技术 1:RemoveMessage - 永久删除

原理: 从状态中永久删除消息

代码示例:

python
from langchain_core.messages import RemoveMessage
from langgraph.graph import MessagesState

def filter_old_messages(state: MessagesState):
    """删除除最后 2 条外的所有消息"""
    messages = state["messages"]

    # 创建删除操作
    delete_ops = [
        RemoveMessage(id=m.id)
        for m in messages[:-2]  # 除最后 2 条
    ]

    return {"messages": delete_ops}

适用场景:

python
# ✅ 场景 1:定期清理旧对话
def cleanup_old_conversations(state):
    """删除 7 天前的消息"""
    from datetime import datetime, timedelta

    cutoff = datetime.now() - timedelta(days=7)
    messages = state["messages"]

    delete_ops = [
        RemoveMessage(id=m.id)
        for m in messages
        if hasattr(m, 'timestamp') and m.timestamp < cutoff
    ]

    return {"messages": delete_ops}

# ✅ 场景 2:删除敏感信息
def remove_sensitive_messages(state):
    """删除包含敏感词的消息"""
    sensitive_keywords = ["password", "credit_card", "ssn"]
    messages = state["messages"]

    delete_ops = []
    for msg in messages:
        if any(keyword in msg.content.lower() for keyword in sensitive_keywords):
            delete_ops.append(RemoveMessage(id=msg.id))

    return {"messages": delete_ops}

# ✅ 场景 3:实现滑动窗口
def sliding_window(state, window_size=10):
    """只保留最近 N 条消息"""
    messages = state["messages"]

    if len(messages) > window_size:
        delete_ops = [
            RemoveMessage(id=m.id)
            for m in messages[:-window_size]
        ]
        return {"messages": delete_ops}

    return {}

优缺点:

python
# 优点
✅ 永久减少状态大小(节省内存)
✅ 实现简单(无需额外配置)
✅ 适合确定不需要的历史

# 缺点
❌ 数据不可恢复(一旦删除无法找回)
❌ 可能丢失重要上下文
❌ 不适合需要回溯的场景

技术 2:消息过滤 - 不修改状态

原理: 传递消息子集给 LLM,状态保持完整

代码示例:

python
def chat_node(state: MessagesState):
    """只传递最后 5 条消息给 LLM"""
    # 1. 过滤消息
    recent_messages = state["messages"][-5:]

    # 2. 调用 LLM(只看到 5 条)
    response = llm.invoke(recent_messages)

    # 3. 返回回复(追加到完整历史)
    return {"messages": [response]}

# 状态仍然保留所有历史消息!

适用场景:

python
# ✅ 场景 1:简单的"最近 N 条"逻辑
def chat_with_recent_context(state):
    """只使用最近 10 条消息"""
    return {"messages": [llm.invoke(state["messages"][-10:])]}

# ✅ 场景 2:按消息类型过滤
def chat_with_user_messages_only(state):
    """只传递用户消息(忽略系统消息)"""
    from langchain_core.messages import HumanMessage

    user_messages = [
        m for m in state["messages"]
        if isinstance(m, HumanMessage)
    ]

    return {"messages": [llm.invoke(user_messages)]}

# ✅ 场景 3:按时间窗口过滤
def chat_with_recent_time_window(state):
    """只使用最近 5 分钟的消息"""
    from datetime import datetime, timedelta

    cutoff = datetime.now() - timedelta(minutes=5)
    recent_messages = [
        m for m in state["messages"]
        if hasattr(m, 'timestamp') and m.timestamp > cutoff
    ]

    return {"messages": [llm.invoke(recent_messages)]}

# ✅ 场景 4:按内容长度过滤
def chat_with_short_messages(state):
    """只使用短消息(< 100 字符)"""
    short_messages = [
        m for m in state["messages"]
        if len(m.content) < 100
    ]

    return {"messages": [llm.invoke(short_messages)]}

优缺点:

python
# 优点
✅ 状态完整保留(可用于日志、分析)
✅ 实现简单(无需 Reducer)
✅ 灵活(可以随时调整过滤规则)
✅ 可逆(不修改原始数据)

# 缺点
LLM 缺乏完整上下文(可能回答不连贯)
❌ 粗糙的控制(不考虑 Token 数量)
❌ 需要手动管理过滤逻辑

技术 3:Token 裁剪 - 精确控制

原理: 基于 Token 数量智能截断消息

代码示例:

python
from langchain_core.messages import trim_messages
from langchain_openai import ChatOpenAI

model = ChatOpenAI(model="gpt-4")

def chat_with_token_limit(state: MessagesState):
    """裁剪到最多 1000 tokens"""
    # 1. 裁剪消息
    trimmed_messages = trim_messages(
        state["messages"],
        max_tokens=1000,           # 最大 Token 数
        strategy="last",           # 保留最后的消息
        token_counter=model,       # 使用模型的 tokenizer
        allow_partial=False        # 不允许截断单条消息
    )

    # 2. 调用 LLM
    response = model.invoke(trimmed_messages)

    return {"messages": [response]}

高级参数详解:

python
# 参数 1: max_tokens - 最大 Token 数
trim_messages(messages, max_tokens=500)  # 严格控制成本

# 参数 2: strategy - 裁剪策略
trim_messages(messages, strategy="last")   # 保留最后的消息(默认)
trim_messages(messages, strategy="first")  # 保留最开始的消息

# 参数 3: token_counter - Token 计数器
trim_messages(messages, token_counter=model)  # 使用模型的 tokenizer
trim_messages(messages, token_counter=lambda x: len(x.split()))  # 简单按单词数

# 参数 4: allow_partial - 是否允许截断单条消息
trim_messages(messages, allow_partial=False)  # 保留完整消息
trim_messages(messages, allow_partial=True)   # 可以截断最后一条

# 参数 5: include_system - 是否包含系统消息
trim_messages(messages, include_system=True)   # 包含系统消息(默认)
trim_messages(messages, include_system=False)  # 排除系统消息

适用场景:

python
# ✅ 场景 1:精确控制成本
def cost_optimized_chat(state):
    """确保每次调用不超过 1000 tokens(约 $0.03)"""
    trimmed = trim_messages(
        state["messages"],
        max_tokens=1000,
        token_counter=model
    )
    return {"messages": [model.invoke(trimmed)]}

# ✅ 场景 2:充分利用上下文窗口
def max_context_chat(state):
    """使用接近上下文窗口的最大 tokens"""
    # GPT-4: 8192 tokens
    # 预留 1000 tokens 给回复
    max_input_tokens = 8192 - 1000

    trimmed = trim_messages(
        state["messages"],
        max_tokens=max_input_tokens,
        token_counter=model
    )
    return {"messages": [model.invoke(trimmed)]}

# ✅ 场景 3:保留系统指令 + 最近对话
def chat_with_system_prompt(state):
    """保留系统消息,裁剪对话历史"""
    from langchain_core.messages import SystemMessage

    # 分离系统消息和对话消息
    system_messages = [m for m in state["messages"] if isinstance(m, SystemMessage)]
    chat_messages = [m for m in state["messages"] if not isinstance(m, SystemMessage)]

    # 计算系统消息的 tokens
    system_tokens = sum(model.get_num_tokens(m.content) for m in system_messages)

    # 裁剪对话历史(减去系统消息占用的 tokens)
    trimmed_chat = trim_messages(
        chat_messages,
        max_tokens=2000 - system_tokens,
        token_counter=model
    )

    # 组合:系统消息 + 裁剪后的对话
    final_messages = system_messages + trimmed_chat
    return {"messages": [model.invoke(final_messages)]}

优缺点:

python
# 优点
✅ 精确控制 Token 成本
✅ 自动处理不同长度的消息
✅ 充分利用上下文窗口
✅ 状态完整保留

# 缺点
❌ 实现复杂度较高
❌ 需要模型的 tokenizer
❌ 有计算开销(token 计数)

三种技术组合使用

实战案例:智能客服机器人

python
from langchain_core.messages import RemoveMessage, trim_messages, SystemMessage
from langgraph.graph import MessagesState
from datetime import datetime, timedelta

class CustomerServiceState(MessagesState):
    max_messages: int = 50
    max_tokens: int = 2000

def intelligent_message_management(state: CustomerServiceState):
    """
    多层消息管理策略:
    1. 删除超过 24 小时的消息(RemoveMessage)
    2. 保留系统指令(Filter)
    3. Token 裁剪对话历史(Trim)
    """
    messages = state["messages"]
    updates = []

    # 第 1 层:删除过期消息
    cutoff_time = datetime.now() - timedelta(hours=24)
    expired_messages = [
        m for m in messages
        if hasattr(m, 'timestamp') and m.timestamp < cutoff_time
    ]

    if expired_messages:
        delete_ops = [RemoveMessage(id=m.id) for m in expired_messages]
        updates.extend(delete_ops)
        # 更新消息列表
        messages = [m for m in messages if m not in expired_messages]

    # 第 2 层:分离系统消息和对话消息
    system_messages = [m for m in messages if isinstance(m, SystemMessage)]
    chat_messages = [m for m in messages if not isinstance(m, SystemMessage)]

    # 第 3 层:Token 裁剪
    if chat_messages:
        # 计算系统消息占用的 tokens
        system_tokens = sum(
            model.get_num_tokens(m.content)
            for m in system_messages
        )

        # 裁剪对话历史
        available_tokens = state["max_tokens"] - system_tokens
        trimmed_chat = trim_messages(
            chat_messages,
            max_tokens=available_tokens,
            token_counter=model,
            strategy="last"
        )

        # 组合消息
        final_messages = system_messages + trimmed_chat
    else:
        final_messages = system_messages

    # 调用 LLM
    response = model.invoke(final_messages)
    updates.append(response)

    return {"messages": updates}

选择决策流程图

mermaid
graph TD
    A[需要管理消息] --> B{历史需要保留?}
    B -->|不需要| C[使用 RemoveMessage]
    B -->|需要| D{Token 数量重要?}
    D -->|不重要| E[使用消息过滤]
    D -->|重要| F{消息长度变化大?}
    F -->|否| E
    F -->|是| G[使用 Token 裁剪]

    C --> H[定期清理/滑动窗口]
    E --> I[简单规则/快速原型]
    G --> J[生产环境/成本优化]

性能对比

假设 10 轮对话,每轮平均 100 tokens:

方案状态大小Token 使用响应延迟上下文完整性
无管理1000 tokens1000 tokens100%
RemoveMessage (保留 2 条)200 tokens200 tokens20%
Filter (最后 5 条)1000 tokens500 tokens50%
Trim (500 tokens)1000 tokens500 tokens动态

关键洞察

选择建议:

  1. 开发阶段:使用消息过滤(简单快速)
  2. 小规模应用:RemoveMessage + 过滤组合
  3. 生产环境:Token 裁剪(精确控制)
  4. 企业级:三种技术组合使用

核心原则: 根据业务需求选择,不要过度优化。


问题 6:如何实现一个支持消息摘要的聊天机器人?关键步骤是什么?

点击查看详细解答

系统架构

消息摘要系统包含三个核心组件:

┌──────────────────────────────────────┐
│   1. 摘要触发机制(何时摘要)           │
│      - 基于消息数量                    │
│      - 基于 Token 数量                 │
│      - 基于时间窗口                    │
└──────────┬───────────────────────────┘

┌──────────────────────────────────────┐
│   2. 摘要生成逻辑(如何摘要)           │
│      - 首次摘要 vs 增量摘要            │
│      - 系统提示词设计                   │
│      - LLM 调用                       │
└──────────┬───────────────────────────┘

┌──────────────────────────────────────┐
│   3. 消息修剪策略(保留什么)           │
│      - 删除旧消息                      │
│      - 保留最近 N 条                   │
│      - 保留摘要                        │
└──────────────────────────────────────┘

完整实现代码

python
from typing_extensions import TypedDict
from langchain_core.messages import SystemMessage, HumanMessage, AIMessage, RemoveMessage
from langchain_openai import ChatOpenAI
from langgraph.graph import StateGraph, MessagesState, START, END

# 步骤 1:定义状态
class ChatState(MessagesState):
    summary: str  # 对话摘要

# 步骤 2:初始化模型
model = ChatOpenAI(model="gpt-4", temperature=0)

# 步骤 3:定义对话节点
def call_model(state: ChatState):
    """
    对话节点:处理用户消息并生成回复
    """
    # 获取摘要
    summary = state.get("summary", "")

    # 如果有摘要,添加到上下文
    if summary:
        # 构建系统消息包含摘要
        system_message = SystemMessage(
            content=f"Summary of conversation earlier: {summary}"
        )
        messages = [system_message] + state["messages"]
    else:
        messages = state["messages"]

    # 调用 LLM
    response = model.invoke(messages)

    return {"messages": [response]}

# 步骤 4:定义摘要节点
def summarize_conversation(state: ChatState):
    """
    摘要节点:生成/更新对话摘要并删除旧消息
    """
    # 1. 获取现有摘要
    summary = state.get("summary", "")

    # 2. 创建摘要提示词
    if summary:
        # 增量摘要:扩展现有摘要
        summary_message = (
            f"This is summary of the conversation to date: {summary}\n\n"
            "Extend the summary by taking into account the new messages above:"
        )
    else:
        # 首次摘要:创建新摘要
        summary_message = "Create a summary of the conversation above:"

    # 3. 调用 LLM 生成摘要
    messages = state["messages"] + [HumanMessage(content=summary_message)]
    response = model.invoke(messages)

    # 4. 删除旧消息,只保留最后 2 条
    delete_messages = [
        RemoveMessage(id=m.id)
        for m in state["messages"][:-2]
    ]

    # 5. 返回更新
    return {
        "summary": response.content,
        "messages": delete_messages
    }

# 步骤 5:定义条件函数(摘要触发)
def should_continue(state: ChatState):
    """
    决定是否需要摘要
    """
    messages = state["messages"]

    # 如果消息超过 6 条,触发摘要
    if len(messages) > 6:
        return "summarize_conversation"

    # 否则结束
    return END

# 步骤 6:构建图
from langgraph.checkpoint.memory import MemorySaver

# 创建图
workflow = StateGraph(ChatState)

# 添加节点
workflow.add_node("conversation", call_model)
workflow.add_node("summarize_conversation", summarize_conversation)

# 添加边
workflow.add_edge(START, "conversation")
workflow.add_conditional_edges("conversation", should_continue)
workflow.add_edge("summarize_conversation", END)

# 编译(带内存)
memory = MemorySaver()
graph = workflow.compile(checkpointer=memory)

# 步骤 7:使用示例
config = {"configurable": {"thread_id": "1"}}

# 第 1-6 轮对话:正常累积
for i in range(6):
    output = graph.invoke({
        "messages": [HumanMessage(f"Message {i+1}")]
    }, config)
    print(f"Round {i+1}: {len(output['messages'])} messages")

# 第 7 轮对话:触发摘要
output = graph.invoke({
    "messages": [HumanMessage("Message 7")]
}, config)

# 查看摘要
summary = graph.get_state(config).values.get("summary", "")
print(f"\nSummary created: {summary}")
print(f"Messages remaining: {len(output['messages'])}")

关键步骤详解

步骤 1:设计状态结构
python
class ChatState(MessagesState):
    summary: str  # ⭐ 关键字段

设计要点:

  • 继承 MessagesState 自动获得 messages 字段
  • 添加 summary 字段存储摘要
  • 可以扩展其他字段(如 summary_count, last_summary_time 等)
步骤 2:实现增量摘要
python
# 首次摘要(无现有摘要)
if not summary:
    prompt = "Create a summary of the conversation above:"

# 增量摘要(有现有摘要)
else:
    prompt = f"""
    This is summary of the conversation to date: {summary}

    Extend the summary by taking into account the new messages above:
    """

增量摘要的优势:

python
# 首次摘要(7 条消息)
summary_1 = "User is Lance, likes 49ers..."

# 增量摘要(新增 5 条消息)
# 只需要摘要新消息,不需要重新摘要全部
prompt = f"Current summary: {summary_1}\nExtend with new messages..."
summary_2 = "User is Lance, likes 49ers and Nick Bosa, asked about salaries..."
步骤 3:智能消息修剪
python
# 保留最后 2 条消息
delete_messages = [RemoveMessage(id=m.id) for m in state["messages"][:-2]]

为什么保留 2 条?

最后 2 条通常是:
- 最新的用户消息
- 最新的 AI 回复

这确保对话的连续性和上下文。

可配置的保留策略:

python
class ChatState(MessagesState):
    summary: str
    keep_last_n: int = 2  # 可配置

def summarize_conversation(state: ChatState):
    keep_n = state.get("keep_last_n", 2)
    delete_messages = [
        RemoveMessage(id=m.id)
        for m in state["messages"][:-keep_n]
    ]
    # ...

高级优化技巧

优化 1:多层摘要
python
class AdvancedChatState(MessagesState):
    short_term_summary: str  # 最近 10 轮
    long_term_summary: str   # 完整历史

def should_summarize(state):
    msg_count = len(state["messages"])

    if msg_count > 20:
        return "long_term_summary"  # 长期摘要
    elif msg_count > 6:
        return "short_term_summary"  # 短期摘要
    return END

# 使用时:
# short_term_summary: 提供近期上下文
# long_term_summary: 提供背景信息
优化 2:结构化摘要
python
from pydantic import BaseModel

class StructuredSummary(BaseModel):
    """结构化摘要"""
    user_info: str          # 用户信息
    topics: list[str]       # 讨论的主题
    key_facts: list[str]    # 关键事实
    pending_questions: list[str]  # 待解答的问题

def structured_summarize(state: ChatState):
    """生成结构化摘要"""
    prompt = """
    Summarize this conversation with the following structure:
    - User profile (name, interests, etc.)
    - Topics discussed
    - Key facts mentioned
    - Any pending questions
    """

    messages = state["messages"] + [HumanMessage(prompt)]
    response = model.with_structured_output(StructuredSummary).invoke(messages)

    # 将结构化摘要转为 JSON 存储
    summary_json = response.json()

    # ...
优化 3:选择性摘要
python
def intelligent_should_summarize(state: ChatState):
    """
    智能判断是否需要摘要:
    - 消息数量
    - Token 数量
    - 话题转换
    """
    messages = state["messages"]

    # 条件 1:消息数量
    if len(messages) > 6:
        # 条件 2:Token 数量
        total_tokens = sum(model.get_num_tokens(m.content) for m in messages)
        if total_tokens > 2000:
            # 条件 3:是否有话题转换
            if detect_topic_change(messages):
                return "summarize_conversation"

    return END

def detect_topic_change(messages):
    """检测话题是否转换(简单实现)"""
    if len(messages) < 2:
        return False

    last_content = messages[-1].content
    prev_content = messages[-2].content

    # 使用关键词相似度或 LLM 判断
    similarity = calculate_similarity(last_content, prev_content)
    return similarity < 0.3  # 相似度低表示话题转换
优化 4:摘要质量控制
python
def summarize_with_quality_check(state: ChatState):
    """带质量检查的摘要生成"""
    # 生成摘要
    summary_prompt = "Create a comprehensive summary..."
    messages = state["messages"] + [HumanMessage(summary_prompt)]
    summary_response = model.invoke(messages)
    summary = summary_response.content

    # 质量检查 1:长度检查
    if len(summary.split()) < 10:
        # 摘要太短,重新生成
        retry_prompt = "Create a MORE DETAILED summary..."
        summary = model.invoke([HumanMessage(retry_prompt)]).content

    # 质量检查 2:完整性检查
    if not contains_key_info(summary, state["messages"]):
        # 摘要缺少关键信息,重新生成
        detailed_prompt = "Include: user name, main topics, key facts..."
        summary = model.invoke([HumanMessage(detailed_prompt)]).content

    # ...
    return {"summary": summary, "messages": delete_messages}

def contains_key_info(summary: str, messages: list) -> bool:
    """检查摘要是否包含关键信息"""
    # 提取用户名
    user_names = extract_names(messages)
    for name in user_names:
        if name.lower() not in summary.lower():
            return False

    # 检查主要话题
    topics = extract_topics(messages)
    topic_covered = sum(1 for topic in topics if topic in summary)
    return topic_covered >= len(topics) * 0.7  # 70% 的话题被覆盖

完整的生产级实现

python
class ProductionChatState(MessagesState):
    """生产级聊天状态"""
    summary: str
    summary_count: int = 0      # 摘要次数
    total_messages: int = 0     # 总消息数
    last_summary_at: int = 0    # 上次摘要时的消息数
    summary_strategy: str = "adaptive"  # 摘要策略

def adaptive_summarize(state: ProductionChatState):
    """自适应摘要策略"""
    msg_count = len(state["messages"])
    total_msgs = state["total_messages"]
    strategy = state["summary_strategy"]

    if strategy == "aggressive":
        # 激进策略:频繁摘要,最小化内存
        keep_last = 2
    elif strategy == "balanced":
        # 平衡策略:适度摘要
        keep_last = 5
    elif strategy == "conservative":
        # 保守策略:尽量保留历史
        keep_last = 10
    else:  # adaptive
        # 自适应:根据对话长度动态调整
        if total_msgs < 20:
            keep_last = 5
        elif total_msgs < 50:
            keep_last = 3
        else:
            keep_last = 2

    # 生成摘要
    summary = generate_summary(state)

    # 删除旧消息
    delete_ops = [
        RemoveMessage(id=m.id)
        for m in state["messages"][:-keep_last]
    ]

    return {
        "summary": summary,
        "messages": delete_ops,
        "summary_count": state["summary_count"] + 1,
        "last_summary_at": total_msgs
    }

关键洞察

消息摘要系统的核心要素:

  1. 触发机制:何时摘要(消息数/Token/时间)
  2. 摘要策略:如何摘要(首次/增量/结构化)
  3. 修剪策略:保留什么(最近 N 条)
  4. 质量控制:如何确保摘要质量

最佳实践:

  • 从简单开始(消息数量触发)
  • 逐步优化(Token 控制、结构化)
  • 监控效果(摘要质量、成本节省)
  • 根据业务调整(不同场景不同策略)

I'll continue with questions 7-15 in the next response to complete the review document. Let me know when you're ready to continue.

问题 7:如何配置外部数据库实现跨会话的持久化存储?

点击查看详细解答

持久化存储的核心概念

Checkpointer 是 LangGraph 的状态持久化机制,负责:

  1. 保存每一步的状态快照
  2. 支持状态恢复和时间旅行
  3. 实现跨会话的对话记忆

三种 Checkpointer 对比:

类型存储位置持久化并发适用场景
MemorySaver进程内存单进程开发测试
SqliteSaverSQLite 文件单机中小应用
PostgresSaverPostgreSQL分布式企业应用

方案 1:MemorySaver(开发环境)

代码示例:

python
from langgraph.checkpoint.memory import MemorySaver
from langgraph.graph import StateGraph

# 创建内存 checkpointer
memory = MemorySaver()

# 编译图时指定
graph = workflow.compile(checkpointer=memory)

# 使用(状态只在进程运行期间存在)
config = {"configurable": {"thread_id": "1"}}
graph.invoke({"messages": [...]}, config)

特点:

  • ✅ 简单快速,无需配置
  • ✅ 开发调试方便
  • ❌ 进程重启后数据丢失
  • ❌ 不适合生产环境

方案 2:SqliteSaver(生产小规模)

完整配置步骤:

python
import sqlite3
from langgraph.checkpoint.sqlite import SqliteSaver
from langgraph.graph import StateGraph, MessagesState

# 步骤 1:创建 SQLite 连接
# 方式 A:内存数据库(临时测试)
conn = sqlite3.connect(":memory:", check_same_thread=False)

# 方式 B:文件数据库(持久化)
conn = sqlite3.connect("chatbot.db", check_same_thread=False)

# 步骤 2:创建 SqliteSaver
memory = SqliteSaver(conn)
# SqliteSaver 会自动创建必要的表:
# - checkpoints:存储状态快照
# - checkpoint_writes:存储写入操作

# 步骤 3:编译图
class State(MessagesState):
    summary: str

workflow = StateGraph(State)
# ... 添加节点和边 ...

graph = workflow.compile(checkpointer=memory)

# 步骤 4:使用
config = {"configurable": {"thread_id": "user_123"}}

# 第一次对话
graph.invoke({"messages": [HumanMessage("Hi")]}, config)

# 程序重启后...
# 重新连接数据库
conn = sqlite3.connect("chatbot.db", check_same_thread=False)
memory = SqliteSaver(conn)
graph = workflow.compile(checkpointer=memory)

# 继续之前的对话(状态自动恢复!)
result = graph.invoke({"messages": [HumanMessage("继续")]}, config)
# ✅ 能够访问之前的对话历史

SQLite 数据库结构:

sql
-- checkpoints 表
CREATE TABLE checkpoints (
    thread_id TEXT,              -- 对话线程 ID
    checkpoint_ns TEXT,          -- 命名空间
    checkpoint_id TEXT,          -- 检查点 ID(UUID)
    parent_checkpoint_id TEXT,   -- 父检查点 ID
    type TEXT,                   -- 类型
    checkpoint BLOB,             -- 序列化的状态数据
    metadata BLOB,               -- 元数据
    created_at TIMESTAMP,        -- 创建时间
    PRIMARY KEY (thread_id, checkpoint_ns, checkpoint_id)
);

-- checkpoint_writes 表
CREATE TABLE checkpoint_writes (
    thread_id TEXT,
    checkpoint_ns TEXT,
    checkpoint_id TEXT,
    task_id TEXT,
    idx INTEGER,
    channel TEXT,
    type TEXT,
    value BLOB,
    PRIMARY KEY (thread_id, checkpoint_ns, checkpoint_id, task_id, idx)
);

高级配置:

python
# 配置 1:设置数据库路径
import os

DB_DIR = "data/checkpoints"
os.makedirs(DB_DIR, exist_ok=True)

db_path = os.path.join(DB_DIR, "chatbot.db")
conn = sqlite3.connect(db_path, check_same_thread=False)

# 配置 2:启用 WAL 模式(提高并发性能)
conn.execute("PRAGMA journal_mode=WAL")

# 配置 3:设置缓存大小(提高性能)
conn.execute("PRAGMA cache_size=10000")  # 10000 pages (~40MB)

memory = SqliteSaver(conn)

# 配置 4:定期备份
import shutil
from datetime import datetime

def backup_database():
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    backup_path = f"backups/chatbot_{timestamp}.db"
    shutil.copy("data/checkpoints/chatbot.db", backup_path)
    print(f"Backup created: {backup_path}")

# 每天备份一次
backup_database()

方案 3:PostgresSaver(企业级)

配置步骤:

python
from langgraph.checkpoint.postgres import PostgresSaver
import asyncpg

# 步骤 1:创建 PostgreSQL 连接
DATABASE_URL = "postgresql://user:password@localhost:5432/langgraph"

# 异步版本
async def setup_postgres():
    pool = await asyncpg.create_pool(DATABASE_URL)
    memory = PostgresSaver(pool)
    return memory

# 同步版本(使用 psycopg2)
from langgraph.checkpoint.postgres import PostgresSaver
import psycopg2

conn = psycopg2.connect(DATABASE_URL)
memory = PostgresSaver(conn)

# 步骤 2:初始化数据库表
# PostgresSaver 会自动创建必要的表
graph = workflow.compile(checkpointer=memory)

# 步骤 3:使用
config = {"configurable": {"thread_id": "user_123"}}
graph.invoke({"messages": [...]}, config)

PostgreSQL 优势:

python
# 1. 支持分布式部署
# 多个服务器实例共享同一个数据库
server_1 = create_app(PostgresSaver(conn))
server_2 = create_app(PostgresSaver(conn))
# 两个服务器可以访问相同的对话状态

# 2. 高并发支持
# PostgreSQL 的 MVCC 机制支持高并发读写

# 3. 事务支持
with conn:
    with conn.cursor() as cur:
        # 原子性操作
        cur.execute("UPDATE checkpoints SET ...")
        cur.execute("INSERT INTO checkpoint_writes ...")

# 4. 高级查询
# 可以使用 SQL 查询分析对话数据
conn.execute("""
    SELECT thread_id, COUNT(*) as checkpoint_count
    FROM checkpoints
    GROUP BY thread_id
    ORDER BY checkpoint_count DESC
    LIMIT 10
""")

Thread 管理详解

Thread 的作用:

Thread(线程)是一组相关状态快照的集合,用于:

  1. 隔离不同用户的对话
  2. 支持同一用户的多个会话
  3. 实现对话的分组和管理

Thread ID 设计模式:

python
# 模式 1:基于用户 ID
def get_user_config(user_id: str):
    return {"configurable": {"thread_id": f"user_{user_id}"}}

# 模式 2:基于会话 ID
import uuid
def create_session():
    session_id = str(uuid.uuid4())
    return {"configurable": {"thread_id": session_id}}

# 模式 3:组合模式(用户 + 场景)
def get_context_config(user_id: str, context: str):
    thread_id = f"{user_id}_{context}"
    return {"configurable": {"thread_id": thread_id}}

# 使用示例:
config = get_context_config("alice", "project_a")
# thread_id = "alice_project_a"

多线程管理:

python
class ConversationManager:
    """对话管理器"""

    def __init__(self, checkpointer):
        self.checkpointer = checkpointer
        self.graph = workflow.compile(checkpointer=checkpointer)

    def start_conversation(self, user_id: str, context: str = "default"):
        """开始新对话"""
        thread_id = f"{user_id}_{context}_{uuid.uuid4()}"
        config = {"configurable": {"thread_id": thread_id}}
        return config

    def get_conversation(self, thread_id: str):
        """获取对话状态"""
        config = {"configurable": {"thread_id": thread_id}}
        return self.graph.get_state(config)

    def list_user_conversations(self, user_id: str):
        """列出用户的所有对话"""
        # 查询数据库
        query = """
            SELECT DISTINCT thread_id, created_at
            FROM checkpoints
            WHERE thread_id LIKE ?
            ORDER BY created_at DESC
        """
        results = conn.execute(query, (f"{user_id}_%",)).fetchall()
        return results

    def delete_conversation(self, thread_id: str):
        """删除对话"""
        conn.execute("DELETE FROM checkpoints WHERE thread_id = ?", (thread_id,))
        conn.execute("DELETE FROM checkpoint_writes WHERE thread_id = ?", (thread_id,))
        conn.commit()

状态查询和恢复

API 详解:

python
# 1. get_state:获取当前状态
config = {"configurable": {"thread_id": "user_123"}}
state_snapshot = graph.get_state(config)

# StateSnapshot 对象包含:
state_snapshot.values         # 当前状态的所有数据
state_snapshot.next          # 下一步要执行的节点
state_snapshot.config        # 配置信息
state_snapshot.metadata      # 元数据(步数、来源等)
state_snapshot.created_at    # 创建时间
state_snapshot.parent_config # 父状态配置

# 2. get_state_history:获取历史状态
history = graph.get_state_history(config)

# 遍历历史
for state in history:
    print(f"Step {state.metadata['step']}: {state.values}")

# 3. update_state:手动更新状态
new_state = {"messages": [HumanMessage("Manual update")]}
graph.update_state(config, new_state)

# 4. 时间旅行:回溯到历史状态
# 获取历史状态列表
history_list = list(graph.get_state_history(config))

# 选择特定历史状态
past_state = history_list[5]  # 第 5 个历史状态

# 从该状态继续执行
result = graph.invoke(
    {"messages": [HumanMessage("Continue from past")]},
    past_state.config  # 使用历史状态的 config
)

性能优化技巧

python
# 技巧 1:批量插入
def bulk_save_conversations(conversations):
    """批量保存多个对话"""
    conn = sqlite3.connect("chatbot.db")
    cursor = conn.cursor()

    for conv in conversations:
        cursor.execute("INSERT INTO checkpoints (...) VALUES (...)")

    conn.commit()  # 一次性提交

# 技巧 2:定期清理
def cleanup_old_checkpoints(days=30):
    """删除超过 N 天的检查点"""
    from datetime import datetime, timedelta

    cutoff = datetime.now() - timedelta(days=days)
    conn.execute(
        "DELETE FROM checkpoints WHERE created_at < ?",
        (cutoff,)
    )
    conn.commit()

# 技巧 3:索引优化
def create_indexes():
    """创建索引提高查询性能"""
    conn.execute("CREATE INDEX IF NOT EXISTS idx_thread_id ON checkpoints(thread_id)")
    conn.execute("CREATE INDEX IF NOT EXISTS idx_created_at ON checkpoints(created_at)")
    conn.commit()

# 技巧 4:连接池
from contextlib import contextmanager

class ConnectionPool:
    def __init__(self, db_path, pool_size=5):
        self.db_path = db_path
        self.pool = [
            sqlite3.connect(db_path, check_same_thread=False)
            for _ in range(pool_size)
        ]
        self.available = self.pool.copy()

    @contextmanager
    def get_connection(self):
        conn = self.available.pop()
        try:
            yield conn
        finally:
            self.available.append(conn)

# 使用
pool = ConnectionPool("chatbot.db")
with pool.get_connection() as conn:
    memory = SqliteSaver(conn)
    # ...

监控和调试

python
# 1. 查看数据库统计
def get_database_stats():
    """获取数据库统计信息"""
    stats = {}

    # 总检查点数
    result = conn.execute("SELECT COUNT(*) FROM checkpoints").fetchone()
    stats["total_checkpoints"] = result[0]

    # 活跃线程数
    result = conn.execute("SELECT COUNT(DISTINCT thread_id) FROM checkpoints").fetchone()
    stats["active_threads"] = result[0]

    # 数据库大小
    import os
    stats["db_size_mb"] = os.path.getsize("chatbot.db") / (1024 * 1024)

    return stats

# 2. 导出对话历史
def export_conversation(thread_id: str, output_path: str):
    """导出对话到 JSON 文件"""
    import json

    config = {"configurable": {"thread_id": thread_id}}
    history = list(graph.get_state_history(config))

    data = []
    for state in history:
        data.append({
            "step": state.metadata.get("step"),
            "values": state.values,
            "timestamp": state.created_at
        })

    with open(output_path, "w") as f:
        json.dump(data, f, indent=2)

# 3. 日志记录
import logging

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

def logged_invoke(graph, input_data, config):
    """带日志的图执行"""
    thread_id = config["configurable"]["thread_id"]
    logger.info(f"Starting execution for thread: {thread_id}")

    try:
        result = graph.invoke(input_data, config)
        logger.info(f"Execution completed for thread: {thread_id}")
        return result
    except Exception as e:
        logger.error(f"Execution failed for thread {thread_id}: {e}")
        raise

关键洞察

持久化存储的最佳实践:

  1. 开发阶段:使用 MemorySaver,快速迭代
  2. 测试阶段:使用 SqliteSaver(内存模式),验证逻辑
  3. 生产阶段:使用 SqliteSaver(文件模式)或 PostgresSaver
  4. 企业阶段:使用 PostgresSaver + 分布式架构

Thread ID 设计原则:

  • 确保唯一性(避免冲突)
  • 有意义(便于查询和调试)
  • 可扩展(支持未来需求)

问题 8:设计一个完整的客服聊天机器人,需要考虑哪些状态管理要点?

点击查看详细解答

系统需求分析

一个生产级客服聊天机器人需要考虑:

1. 状态管理
   ├─ 用户信息(姓名、ID、权限)
   ├─ 对话历史(消息、摘要)
   ├─ 会话上下文(问题类型、优先级)
   └─ 系统状态(统计、日志)

2. 性能优化
   ├─ 消息裁剪(Token 控制)
   ├─ 智能摘要(长对话支持)
   └─ 缓存机制(常见问题)

3. 数据持久化
   ├─ PostgreSQL(状态存储)
   ├─ Redis(缓存层)
   └─ S3(历史归档)

4. 安全与隐私
   ├─ 数据加密
   ├─ 访问控制
   └─ 敏感信息脱敏

完整架构设计

python
from typing_extensions import TypedDict
from typing import Annotated, Literal
from datetime import datetime
from langchain_core.messages import BaseMessage, SystemMessage, HumanMessage, AIMessage
from langgraph.graph import StateGraph, MessagesState, START, END
from langgraph.checkpoint.postgres import PostgresSaver
from langchain_openai import ChatOpenAI
import operator

# ============= 状态定义 =============

class CustomerInfo(TypedDict):
    """客户信息"""
    user_id: str
    name: str
    email: str
    tier: Literal["basic", "premium", "enterprise"]  # 客户等级
    language: str  # 用户语言

class IssueInfo(TypedDict):
    """问题信息"""
    category: Literal["technical", "billing", "general"]
    priority: Literal["low", "medium", "high", "urgent"]
    status: Literal["open", "in_progress", "resolved"]
    tags: list[str]

class SystemMetrics(TypedDict):
    """系统指标"""
    message_count: int
    llm_calls: int
    total_tokens: int
    session_start: datetime
    last_activity: datetime

class CustomerServiceState(MessagesState):
    """客服机器人状态"""
    # 客户信息
    customer: CustomerInfo

    # 问题信息
    issue: IssueInfo

    # 对话管理
    summary: str  # 对话摘要
    short_term_memory: list[str]  # 短期记忆(最近 3 轮)

    # 系统指标
    metrics: SystemMetrics

    # 功能标志
    needs_human: bool  # 是否需要人工介入
    satisfaction_score: float  # 满意度评分

# ============= 节点实现 =============

# 初始化模型
model = ChatOpenAI(model="gpt-4", temperature=0)

def route_message(state: CustomerServiceState) -> Literal["handle_technical", "handle_billing", "handle_general"]:
    """
    路由节点:根据问题类型分发
    """
    last_message = state["messages"][-1].content

    # 使用 LLM 分类
    classifier_prompt = f"""
    Classify this customer message into one category:
    - technical: Technical issues, bugs, errors
    - billing: Payment, invoices, subscriptions
    - general: General questions, feedback

    Message: {last_message}

    Return only the category name.
    """

    category = model.invoke([HumanMessage(classifier_prompt)]).content.strip().lower()

    # 更新状态
    state["issue"]["category"] = category

    return f"handle_{category}"

def handle_technical(state: CustomerServiceState):
    """
    技术支持节点
    """
    # 1. 准备上下文
    customer = state["customer"]
    summary = state.get("summary", "")

    system_prompt = f"""
    You are a technical support specialist.
    Customer: {customer['name']} ({customer['tier']} tier)
    Previous conversation summary: {summary}

    Provide clear, step-by-step technical solutions.
    """

    # 2. 准备消息(包含摘要 + 最近消息)
    messages = [SystemMessage(system_prompt)]

    # 添加短期记忆
    if state.get("short_term_memory"):
        for memory in state["short_term_memory"]:
            messages.append(SystemMessage(f"Recent context: {memory}"))

    # 添加最近消息
    messages.extend(state["messages"][-5:])

    # 3. 调用 LLM
    response = model.invoke(messages)

    # 4. 更新指标
    metrics = state["metrics"]
    metrics["llm_calls"] += 1
    metrics["message_count"] += 1

    return {
        "messages": [response],
        "metrics": metrics,
        "issue": {"status": "in_progress"}
    }

def handle_billing(state: CustomerServiceState):
    """
    账单支持节点
    """
    customer = state["customer"]

    # 检查客户权限
    if customer["tier"] == "basic":
        response = AIMessage(
            "For billing inquiries, please contact our billing department at billing@company.com"
        )
    else:
        # Premium/Enterprise 客户直接处理
        system_prompt = "You are a billing specialist with access to account information..."
        messages = [SystemMessage(system_prompt)] + state["messages"][-3:]
        response = model.invoke(messages)

    return {"messages": [response]}

def handle_general(state: CustomerServiceState):
    """
    一般咨询节点
    """
    # 使用简化的上下文(节省 Token)
    messages = state["messages"][-3:]
    response = model.invoke(messages)

    return {"messages": [response]}

def check_escalation(state: CustomerServiceState) -> Literal["escalate", "summarize", "end"]:
    """
    检查是否需要升级
    """
    issue = state["issue"]
    messages = state["messages"]

    # 条件 1:高优先级问题
    if issue["priority"] == "urgent":
        return "escalate"

    # 条件 2:对话轮数过多
    if len(messages) > 10:
        return "summarize"

    # 条件 3:检测客户情绪
    last_message = messages[-1].content
    if detect_frustration(last_message):
        return "escalate"

    # 条件 4:问题已解决
    if issue["status"] == "resolved":
        return "end"

    return "end"

def escalate_to_human(state: CustomerServiceState):
    """
    升级到人工客服
    """
    customer = state["customer"]
    issue = state["issue"]

    # 生成升级报告
    report = f"""
    Escalation Request:
    - Customer: {customer['name']} ({customer['user_id']})
    - Issue: {issue['category']} - {issue['priority']}
    - Summary: {state.get('summary', 'No summary available')}
    - Messages: {len(state['messages'])}
    """

    # 通知系统
    notify_human_agent(report)

    # 回复客户
    response = AIMessage(
        "I've escalated your issue to a human agent who will assist you shortly."
    )

    return {
        "messages": [response],
        "needs_human": True
    }

def summarize_conversation(state: CustomerServiceState):
    """
    对话摘要节点
    """
    messages = state["messages"]
    current_summary = state.get("summary", "")

    # 生成摘要
    if current_summary:
        prompt = f"Current summary: {current_summary}\n\nUpdate with new messages:"
    else:
        prompt = "Summarize this customer service conversation:"

    summary_messages = messages + [HumanMessage(prompt)]
    summary = model.invoke(summary_messages).content

    # 更新短期记忆
    last_turn = f"User: {messages[-2].content} | Agent: {messages[-1].content}"
    short_term = state.get("short_term_memory", [])
    short_term.append(last_turn)
    if len(short_term) > 3:
        short_term = short_term[-3:]  # 只保留最近 3 轮

    # 删除旧消息
    from langchain_core.messages import RemoveMessage
    delete_ops = [RemoveMessage(id=m.id) for m in messages[:-4]]

    return {
        "summary": summary,
        "short_term_memory": short_term,
        "messages": delete_ops
    }

# ============= 辅助函数 =============

def detect_frustration(text: str) -> bool:
    """检测客户情绪"""
    frustration_keywords = [
        "angry", "frustrated", "terrible", "worst",
        "unacceptable", "ridiculous", "cancel"
    ]
    return any(keyword in text.lower() for keyword in frustration_keywords)

def notify_human_agent(report: str):
    """通知人工客服"""
    # 发送到队列、Slack、邮件等
    print(f"[ESCALATION] {report}")

# ============= 构建图 =============

workflow = StateGraph(CustomerServiceState)

# 添加节点
workflow.add_node("route", route_message)
workflow.add_node("handle_technical", handle_technical)
workflow.add_node("handle_billing", handle_billing)
workflow.add_node("handle_general", handle_general)
workflow.add_node("check_escalation", check_escalation)
workflow.add_node("escalate", escalate_to_human)
workflow.add_node("summarize", summarize_conversation)

# 添加边
workflow.add_edge(START, "route")
workflow.add_conditional_edges("route", lambda s: s["issue"]["category"], {
    "technical": "handle_technical",
    "billing": "handle_billing",
    "general": "handle_general"
})

# 所有处理节点都到检查节点
workflow.add_edge("handle_technical", "check_escalation")
workflow.add_edge("handle_billing", "check_escalation")
workflow.add_edge("handle_general", "check_escalation")

# 条件分支
workflow.add_conditional_edges("check_escalation", check_escalation, {
    "escalate": "escalate",
    "summarize": "summarize",
    "end": END
})

workflow.add_edge("escalate", END)
workflow.add_edge("summarize", END)

# 编译(使用 PostgreSQL)
import asyncpg
pool = asyncpg.create_pool("postgresql://user:pass@localhost/customer_service")
memory = PostgresSaver(pool)

graph = workflow.compile(checkpointer=memory)

# ============= 使用示例 =============

def start_customer_session(user_id: str, name: str, tier: str):
    """开始客服会话"""
    config = {"configurable": {"thread_id": f"customer_{user_id}"}}

    # 初始化状态
    initial_state = {
        "customer": {
            "user_id": user_id,
            "name": name,
            "email": f"{user_id}@example.com",
            "tier": tier,
            "language": "en"
        },
        "issue": {
            "category": "general",
            "priority": "medium",
            "status": "open",
            "tags": []
        },
        "metrics": {
            "message_count": 0,
            "llm_calls": 0,
            "total_tokens": 0,
            "session_start": datetime.now(),
            "last_activity": datetime.now()
        },
        "needs_human": False,
        "satisfaction_score": 0.0,
        "messages": []
    }

    return config, initial_state

# 使用
config, initial_state = start_customer_session("user123", "Alice", "premium")

# 第一轮对话
result = graph.invoke({
    **initial_state,
    "messages": [HumanMessage("My app keeps crashing")]
}, config)

# 第二轮对话
result = graph.invoke({
    "messages": [HumanMessage("I tried restarting but it didn't help")]
}, config)

# 查看状态
state = graph.get_state(config)
print(f"Issue category: {state.values['issue']['category']}")
print(f"Status: {state.values['issue']['status']}")
print(f"Needs human: {state.values['needs_human']}")

关键设计要点

1. 状态分层设计

python
# ✅ 好的设计:分层清晰
class State(MessagesState):
    customer: CustomerInfo      # 客户层
    issue: IssueInfo           # 业务层
    summary: str               # 对话层
    metrics: SystemMetrics     # 系统层

# ❌ 不好的设计:平铺所有字段
class State(MessagesState):
    user_id: str
    user_name: str
    issue_category: str
    issue_priority: str
    message_count: int
    # ... 难以维护

2. 智能路由

python
# 根据问题类型、客户等级、紧急程度等多维度路由
def intelligent_route(state):
    # 紧急问题直接升级
    if state["issue"]["priority"] == "urgent":
        return "escalate"

    # VIP 客户优先处理
    if state["customer"]["tier"] == "enterprise":
        return "vip_handler"

    # 技术问题
    if state["issue"]["category"] == "technical":
        return "technical_support"

    # 默认处理
    return "general_handler"

3. 多级摘要

python
class State(MessagesState):
    # 短期记忆:最近 3 轮对话
    short_term_memory: list[str]

    # 中期摘要:最近 10 轮的摘要
    medium_term_summary: str

    # 长期摘要:整个会话的摘要
    long_term_summary: str

# 使用时分层提供上下文
def prepare_context(state):
    context = []

    # 1. 长期背景
    if state.get("long_term_summary"):
        context.append(f"Session background: {state['long_term_summary']}")

    # 2. 中期上下文
    if state.get("medium_term_summary"):
        context.append(f"Recent discussion: {state['medium_term_summary']}")

    # 3. 短期细节
    if state.get("short_term_memory"):
        context.extend(state["short_term_memory"])

    return context

4. 监控和分析

python
class Metrics(TypedDict):
    # 性能指标
    message_count: int
    llm_calls: int
    total_tokens: int
    avg_response_time: float

    # 业务指标
    resolved_issues: int
    escalated_issues: int
    satisfaction_score: float

def update_metrics(state: CustomerServiceState):
    """更新指标"""
    metrics = state["metrics"]

    # 记录到监控系统
    prometheus_client.gauge("customer_service_messages", metrics["message_count"])
    prometheus_client.gauge("customer_service_llm_calls", metrics["llm_calls"])

    # 记录到数据库
    analytics_db.insert({
        "user_id": state["customer"]["user_id"],
        "session_id": state["thread_id"],
        "metrics": metrics,
        "timestamp": datetime.now()
    })

关键洞察

客服机器人的状态管理要点:

  1. 结构化状态:分层管理(客户/问题/对话/系统)
  2. 智能路由:根据多维度动态分发
  3. 多级记忆:短期/中期/长期摘要
  4. 性能优化:Token 控制、缓存、批处理
  5. 监控分析:实时指标、用户反馈
  6. 人工接入:升级机制、上下文传递

问题 9:如何优化 LangGraph 状态管理的 Token 使用和成本?

点击查看详细解答

Token 成本分析

典型聊天机器人的 Token 消耗:

轮次 1: 输入 50 + 输出 100 = 150 tokens
轮次 2: 输入 200 + 输出 100 = 300 tokens  (累积历史)
轮次 3: 输入 400 + 输出 100 = 500 tokens
...
轮次 10: 输入 2000 + 输出 100 = 2100 tokens

总消耗: 约 12,000 tokens
成本: $0.36 (按 GPT-4 $0.03/1K tokens)

优化后:

轮次 1-6: 同上 (约 6,000 tokens)
轮次 7: 触发摘要 + 裁剪
   - 摘要生成: 500 tokens
   - 后续每轮: 300 tokens (摘要 + 最近 2 条)

总消耗: 约 7,500 tokens
成本: $0.225
节省: 37.5%

优化策略 1:智能消息裁剪

python
from langchain_core.messages import trim_messages, SystemMessage

def token_optimized_chat(state: MessagesState):
    """Token 优化的对话节点"""
    messages = state["messages"]

    # 1. 分离系统消息
    system_msgs = [m for m in messages if isinstance(m, SystemMessage)]
    chat_msgs = [m for m in messages if not isinstance(m, SystemMessage)]

    # 2. 计算系统消息的 tokens
    system_tokens = sum(model.get_num_tokens(m.content) for m in system_msgs)

    # 3. 裁剪对话消息
    # GPT-4: 8192 tokens 上下文
    # 预留: 1000 tokens 用于回复, 500 tokens 缓冲
    available_tokens = 8192 - 1000 - 500 - system_tokens

    trimmed_chat = trim_messages(
        chat_msgs,
        max_tokens=available_tokens,
        strategy="last",
        token_counter=model,
        allow_partial=False
    )

    # 4. 组合并调用
    final_messages = system_msgs + trimmed_chat
    response = model.invoke(final_messages)

    return {"messages": [response]}

优化策略 2:分层摘要系统

python
class TokenOptimizedState(MessagesState):
    """Token 优化的状态"""
    # 三层摘要
    micro_summary: str      # 最近 5 轮 (约 50 tokens)
    macro_summary: str      # 最近 20 轮 (约 100 tokens)
    full_summary: str       # 完整历史 (约 150 tokens)

    # 摘要元数据
    last_micro_at: int = 0
    last_macro_at: int = 0
    last_full_at: int = 0

def should_summarize(state: TokenOptimizedState):
    """智能摘要触发"""
    msg_count = len(state["messages"])

    # 每 5 轮更新 micro
    if msg_count - state["last_micro_at"] >= 5:
        return "micro_summary"

    # 每 20 轮更新 macro
    if msg_count - state["last_macro_at"] >= 20:
        return "macro_summary"

    # 每 50 轮更新 full
    if msg_count - state["last_full_at"] >= 50:
        return "full_summary"

    return "continue"

def create_micro_summary(state: TokenOptimizedState):
    """创建微摘要(最近 5 轮)"""
    recent_messages = state["messages"][-10:]  # 5 轮 = 10 条消息

    prompt = "Briefly summarize the last 5 conversation turns (max 50 tokens):"
    summary = model.invoke(recent_messages + [HumanMessage(prompt)]).content

    # 删除已摘要的消息(保留最后 2 条)
    delete_ops = [RemoveMessage(id=m.id) for m in state["messages"][:-2]]

    return {
        "micro_summary": summary,
        "messages": delete_ops,
        "last_micro_at": len(state["messages"])
    }

def prepare_context(state: TokenOptimizedState):
    """准备上下文(使用分层摘要)"""
    context = []

    # 1. 完整背景(如果有)
    if state.get("full_summary"):
        context.append(SystemMessage(f"Background: {state['full_summary']}"))

    # 2. 中期上下文(如果有)
    if state.get("macro_summary"):
        context.append(SystemMessage(f"Recent: {state['macro_summary']}"))

    # 3. 最近细节(如果有)
    if state.get("micro_summary"):
        context.append(SystemMessage(f"Latest: {state['micro_summary']}"))

    # 4. 当前消息(最后 2 条)
    context.extend(state["messages"][-2:])

    return context

# Token 使用对比
# 无摘要: 50 轮 × 平均 100 tokens = 5000 tokens
# 分层摘要: 150 (full) + 100 (macro) + 50 (micro) + 200 (最近 2 条) = 500 tokens
# 节省: 90%!

优化策略 3:缓存常见问题

python
from functools import lru_cache
import hashlib

class CachedState(MessagesState):
    """带缓存的状态"""
    cache_hits: int = 0
    cache_misses: int = 0

# 使用 LRU 缓存
@lru_cache(maxsize=100)
def cached_llm_call(message_hash: str, message_content: str):
    """缓存 LLM 调用"""
    response = model.invoke([HumanMessage(message_content)])
    return response.content

def smart_chat_with_cache(state: CachedState):
    """带缓存的对话"""
    last_message = state["messages"][-1].content

    # 计算消息哈希
    message_hash = hashlib.md5(last_message.encode()).hexdigest()

    # 尝试从缓存获取
    try:
        cached_response = cached_llm_call(message_hash, last_message)
        response = AIMessage(cached_response)

        return {
            "messages": [response],
            "cache_hits": state["cache_hits"] + 1
        }
    except:
        # 缓存未命中,正常调用
        response = model.invoke(state["messages"])

        return {
            "messages": [response],
            "cache_misses": state["cache_misses"] + 1
        }

# 缓存效果
# 假设 20% 的问题是重复的
# 成本节省: 20% × $0.03/1K = $0.006/1K tokens

优化策略 4:动态模型选择

python
from langchain_openai import ChatOpenAI

# 多个模型实例
gpt4 = ChatOpenAI(model="gpt-4", temperature=0)           # $0.03/1K (强)
gpt35 = ChatOpenAI(model="gpt-3.5-turbo", temperature=0)  # $0.002/1K (快)
gpt4_mini = ChatOpenAI(model="gpt-4o-mini", temperature=0) # $0.015/1K (平衡)

def select_model(state: MessagesState):
    """根据任务复杂度选择模型"""
    last_message = state["messages"][-1].content

    # 简单问题 → GPT-3.5
    if is_simple_query(last_message):
        return gpt35

    # 复杂问题 → GPT-4
    if is_complex_query(last_message):
        return gpt4

    # 默认 → GPT-4 Mini
    return gpt4_mini

def is_simple_query(text: str) -> bool:
    """判断是否是简单问题"""
    simple_patterns = [
        "hello", "hi", "thanks", "bye",
        "what is", "who is", "when is"
    ]
    return any(pattern in text.lower() for pattern in simple_patterns)

def is_complex_query(text: str) -> bool:
    """判断是否是复杂问题"""
    complex_patterns = [
        "analyze", "compare", "explain why",
        "debug", "design", "implement"
    ]
    return any(pattern in text.lower() for pattern in complex_patterns)

def adaptive_chat(state: MessagesState):
    """自适应模型选择"""
    model = select_model(state)
    response = model.invoke(state["messages"])

    return {"messages": [response]}

# 成本节省
# 假设: 30% 简单问题, 20% 复杂问题, 50% 中等问题
# 平均成本: 0.3 × $0.002 + 0.5 × $0.015 + 0.2 × $0.03
#         = $0.0006 + $0.0075 + $0.006
#         = $0.0141/1K (比全用 GPT-4 节省 53%)

优化策略 5:批处理和异步

python
import asyncio
from typing import List

async def batch_process_conversations(conversations: List[dict]):
    """批处理多个对话"""
    # 使用 asyncio.gather 并行处理
    tasks = [
        process_conversation_async(conv)
        for conv in conversations
    ]

    results = await asyncio.gather(*tasks)
    return results

async def process_conversation_async(conversation: dict):
    """异步处理单个对话"""
    # 使用 LangChain 的异步接口
    response = await model.ainvoke(conversation["messages"])
    return response

# 批处理的优势
# 1. 减少网络开销
# 2. 充分利用 API 限速
# 3. 提高吞吐量

优化策略 6:Token 预算控制

python
class BudgetControlledState(MessagesState):
    """带预算控制的状态"""
    token_budget: int = 10000      # 每个会话的 token 预算
    tokens_used: int = 0           # 已使用的 tokens
    budget_warning: bool = False   # 预算警告

def budget_aware_chat(state: BudgetControlledState):
    """预算感知的对话"""
    # 1. 检查预算
    remaining_budget = state["token_budget"] - state["tokens_used"]

    if remaining_budget <= 0:
        return {
            "messages": [AIMessage("Token budget exhausted. Please start a new session.")]
        }

    # 2. 根据剩余预算调整策略
    if remaining_budget < 1000:
        # 预算不足,使用激进裁剪
        max_tokens = 500
        state["budget_warning"] = True
    elif remaining_budget < 3000:
        # 预算紧张,使用中等裁剪
        max_tokens = 1500
    else:
        # 预算充足,正常使用
        max_tokens = 3000

    # 3. 裁剪消息
    trimmed = trim_messages(
        state["messages"],
        max_tokens=max_tokens,
        token_counter=model
    )

    # 4. 调用 LLM
    response = model.invoke(trimmed)

    # 5. 更新使用量
    input_tokens = sum(model.get_num_tokens(m.content) for m in trimmed)
    output_tokens = model.get_num_tokens(response.content)
    total_tokens = input_tokens + output_tokens

    return {
        "messages": [response],
        "tokens_used": state["tokens_used"] + total_tokens
    }

成本监控和分析

python
class TokenAnalytics:
    """Token 分析工具"""

    def __init__(self):
        self.usage_log = []

    def log_usage(self, thread_id: str, input_tokens: int, output_tokens: int):
        """记录使用"""
        self.usage_log.append({
            "thread_id": thread_id,
            "input_tokens": input_tokens,
            "output_tokens": output_tokens,
            "total_tokens": input_tokens + output_tokens,
            "cost": self.calculate_cost(input_tokens, output_tokens),
            "timestamp": datetime.now()
        })

    def calculate_cost(self, input_tokens: int, output_tokens: int) -> float:
        """计算成本"""
        # GPT-4 定价
        input_cost = input_tokens / 1000 * 0.03
        output_cost = output_tokens / 1000 * 0.06
        return input_cost + output_cost

    def get_statistics(self):
        """获取统计信息"""
        total_tokens = sum(log["total_tokens"] for log in self.usage_log)
        total_cost = sum(log["cost"] for log in self.usage_log)
        avg_tokens_per_request = total_tokens / len(self.usage_log) if self.usage_log else 0

        return {
            "total_tokens": total_tokens,
            "total_cost": total_cost,
            "avg_tokens_per_request": avg_tokens_per_request,
            "total_requests": len(self.usage_log)
        }

    def get_top_consumers(self, n=10):
        """获取 Token 消耗最多的会话"""
        from collections import defaultdict

        thread_usage = defaultdict(int)
        for log in self.usage_log:
            thread_usage[log["thread_id"]] += log["total_tokens"]

        return sorted(thread_usage.items(), key=lambda x: x[1], reverse=True)[:n]

# 使用
analytics = TokenAnalytics()

def monitored_chat(state: MessagesState):
    """带监控的对话"""
    # 计算输入 tokens
    input_tokens = sum(model.get_num_tokens(m.content) for m in state["messages"])

    # 调用 LLM
    response = model.invoke(state["messages"])

    # 计算输出 tokens
    output_tokens = model.get_num_tokens(response.content)

    # 记录到分析系统
    thread_id = state.get("thread_id", "unknown")
    analytics.log_usage(thread_id, input_tokens, output_tokens)

    return {"messages": [response]}

# 定期报告
def generate_cost_report():
    """生成成本报告"""
    stats = analytics.get_statistics()
    print(f"""
    Token Usage Report
    ==================
    Total Tokens: {stats['total_tokens']:,}
    Total Cost: ${stats['total_cost']:.2f}
    Avg Tokens/Request: {stats['avg_tokens_per_request']:.0f}
    Total Requests: {stats['total_requests']}

    Top 10 Consumers:
    """)

    for thread_id, tokens in analytics.get_top_consumers(10):
        cost = tokens / 1000 * 0.03
        print(f"  {thread_id}: {tokens:,} tokens (${cost:.2f})")

综合优化方案

python
class OptimizedChatbot:
    """完全优化的聊天机器人"""

    def __init__(self):
        # 模型
        self.gpt4 = ChatOpenAI(model="gpt-4")
        self.gpt35 = ChatOpenAI(model="gpt-3.5-turbo")

        # 缓存
        self.cache = {}

        # 分析
        self.analytics = TokenAnalytics()

    def chat(self, state: MessagesState, budget: int = 5000):
        """优化的对话方法"""
        # 1. 检查缓存
        cache_key = self._get_cache_key(state["messages"][-1])
        if cache_key in self.cache:
            return {"messages": [self.cache[cache_key]]}

        # 2. 选择模型
        model = self._select_model(state)

        # 3. 准备上下文
        # - 使用分层摘要
        # - Token 裁剪
        # - 预算控制
        context = self._prepare_optimized_context(state, budget)

        # 4. 调用 LLM
        response = model.invoke(context)

        # 5. 缓存结果
        self.cache[cache_key] = response

        # 6. 记录使用
        self._log_usage(state, context, response)

        return {"messages": [response]}

    def _prepare_optimized_context(self, state, budget):
        """准备优化的上下文"""
        # 实现所有优化策略...
        pass

# 优化效果总结
# 基线: 10,000 tokens, $0.30
# 优化后: 3,000 tokens, $0.09
# 节省: 70%

关键洞察

Token 优化的黄金法则:

  1. 裁剪优先:使用 trim_messages 精确控制
  2. 分层摘要:micro/macro/full 三级摘要
  3. 智能缓存:缓存常见问题和答案
  4. 动态模型:根据复杂度选择模型
  5. 预算控制:设置 token 上限
  6. 持续监控:分析使用模式,持续优化

问题 10:如何处理 LangGraph 状态管理中的并发和竞态条件?

点击查看详细解答

并发问题的本质

在多用户或高并发场景下,状态管理面临的挑战:

时间 T1: 用户 A 读取状态 (count=0)
时间 T2: 用户 B 读取状态 (count=0)
时间 T3: 用户 A 更新状态 (count=1)
时间 T4: 用户 B 更新状态 (count=1)  ← 应该是 2!

结果: count=1 (丢失了一次更新)

解决方案 1:Thread 隔离

原理: 每个用户使用独立的 thread_id,状态完全隔离

python
from langgraph.checkpoint.postgres import PostgresSaver

# 每个用户独立的 thread
def get_user_config(user_id: str):
    return {"configurable": {"thread_id": f"user_{user_id}"}}

# 用户 A
config_a = get_user_config("alice")
graph.invoke({"messages": [...]}, config_a)

# 用户 B(完全独立)
config_b = get_user_config("bob")
graph.invoke({"messages": [...]}, config_b)

# 两个用户的状态互不影响

适用场景:

  • ✅ 多用户聊天机器人
  • ✅ 独立的对话会话
  • ❌ 需要共享状态的场景

解决方案 2:使用 Reducer 处理并行节点

原理: Reducer 定义了多个更新的合并策略

python
from operator import add
from typing import Annotated

class ConcurrentState(TypedDict):
    # 使用 Reducer 避免冲突
    results: Annotated[list, add]
    counter: Annotated[int, lambda x, y: x + y]  # 累加

def parallel_node_1(state):
    return {"results": ["result_1"], "counter": 1}

def parallel_node_2(state):
    return {"results": ["result_2"], "counter": 1}

# 并行执行
# results = ["result_1", "result_2"]  ✅ 合并
# counter = 2  ✅ 累加

适用场景:

  • ✅ 图内部的并行节点
  • ✅ 可以定义合并逻辑的场景
  • ❌ 跨进程的并发

解决方案 3:数据库事务和锁

原理: 使用数据库的事务机制保证原子性

python
import asyncpg
from contextlib import asynccontextmanager

class TransactionalCheckpointer:
    """带事务的 Checkpointer"""

    def __init__(self, db_pool):
        self.pool = db_pool

    @asynccontextmanager
    async def transaction(self):
        """事务上下文管理器"""
        async with self.pool.acquire() as conn:
            async with conn.transaction():
                yield conn

    async def save_state(self, thread_id: str, state: dict):
        """原子性保存状态"""
        async with self.transaction() as conn:
            # 1. 锁定行(防止并发修改)
            await conn.execute(
                "SELECT * FROM checkpoints WHERE thread_id = $1 FOR UPDATE",
                thread_id
            )

            # 2. 更新状态
            await conn.execute(
                "UPDATE checkpoints SET state = $1 WHERE thread_id = $2",
                state, thread_id
            )

    async def increment_counter(self, thread_id: str):
        """原子性递增计数器"""
        async with self.transaction() as conn:
            result = await conn.fetchval(
                """
                UPDATE checkpoints
                SET counter = counter + 1
                WHERE thread_id = $1
                RETURNING counter
                """,
                thread_id
            )
            return result

# 使用
checkpointer = TransactionalCheckpointer(pool)

# 即使并发调用,计数器也是准确的
await checkpointer.increment_counter("user_123")
await checkpointer.increment_counter("user_123")
# counter = 2 ✅

适用场景:

  • ✅ 需要强一致性的场景
  • ✅ 跨进程的并发
  • ❌ 对性能要求极高的场景

解决方案 4:乐观锁

原理: 使用版本号检测冲突

python
class VersionedState(MessagesState):
    """带版本号的状态"""
    version: int = 0

def optimistic_update(state: VersionedState, new_data: dict):
    """乐观锁更新"""
    current_version = state["version"]

    # 更新数据
    updated_state = {**state, **new_data, "version": current_version + 1}

    # 尝试保存(检查版本)
    try:
        save_with_version_check(updated_state, current_version)
        return updated_state
    except VersionConflictError:
        # 版本冲突,重新读取并重试
        latest_state = reload_state()
        return optimistic_update(latest_state, new_data)

def save_with_version_check(state: dict, expected_version: int):
    """保存时检查版本"""
    result = conn.execute(
        """
        UPDATE checkpoints
        SET state = $1, version = $2
        WHERE thread_id = $3 AND version = $4
        """,
        state, state["version"], state["thread_id"], expected_version
    )

    if result.rowcount == 0:
        raise VersionConflictError("State was modified by another process")

适用场景:

  • ✅ 冲突较少的场景
  • ✅ 读多写少
  • ❌ 冲突频繁的场景(会重试很多次)

解决方案 5:消息队列

原理: 将并发请求串行化

python
import asyncio
from asyncio import Queue

class QueuedChatbot:
    """带队列的聊天机器人"""

    def __init__(self, graph):
        self.graph = graph
        self.queues = {}  # 每个 thread 一个队列
        self.workers = {}  # 每个 thread 一个 worker

    async def send_message(self, thread_id: str, message: str):
        """发送消息(异步)"""
        # 1. 获取或创建队列
        if thread_id not in self.queues:
            self.queues[thread_id] = Queue()
            self.workers[thread_id] = asyncio.create_task(
                self._worker(thread_id)
            )

        # 2. 将消息放入队列
        future = asyncio.Future()
        await self.queues[thread_id].put((message, future))

        # 3. 等待结果
        return await future

    async def _worker(self, thread_id: str):
        """Worker 协程:串行处理消息"""
        queue = self.queues[thread_id]
        config = {"configurable": {"thread_id": thread_id}}

        while True:
            # 从队列获取消息
            message, future = await queue.get()

            try:
                # 串行执行(避免并发)
                result = await self.graph.ainvoke(
                    {"messages": [HumanMessage(message)]},
                    config
                )
                future.set_result(result)
            except Exception as e:
                future.set_exception(e)
            finally:
                queue.task_done()

# 使用
chatbot = QueuedChatbot(graph)

# 并发调用会被串行化
results = await asyncio.gather(
    chatbot.send_message("user_123", "Message 1"),
    chatbot.send_message("user_123", "Message 2"),
    chatbot.send_message("user_123", "Message 3")
)
# 保证执行顺序:Message 1 → Message 2 → Message 3

适用场景:

  • ✅ 需要保证顺序的场景
  • ✅ 高并发但单个用户请求不多
  • ❌ 对延迟敏感的场景

解决方案 6:分布式锁

原理: 使用 Redis 等实现分布式锁

python
import redis
from contextlib import contextmanager
import time
import uuid

class DistributedLock:
    """分布式锁"""

    def __init__(self, redis_client, key: str, timeout: int = 10):
        self.redis = redis_client
        self.key = f"lock:{key}"
        self.timeout = timeout
        self.identifier = str(uuid.uuid4())

    @contextmanager
    def acquire(self, retry_times: int = 10, retry_delay: float = 0.1):
        """获取锁"""
        acquired = False
        try:
            for _ in range(retry_times):
                # 尝试获取锁
                if self.redis.set(self.key, self.identifier, nx=True, ex=self.timeout):
                    acquired = True
                    break
                time.sleep(retry_delay)

            if not acquired:
                raise LockAcquisitionError("Failed to acquire lock")

            yield

        finally:
            # 释放锁(检查是否是自己的锁)
            if acquired:
                lua_script = """
                if redis.call("get", KEYS[1]) == ARGV[1] then
                    return redis.call("del", KEYS[1])
                else
                    return 0
                end
                """
                self.redis.eval(lua_script, 1, self.key, self.identifier)

# 使用
redis_client = redis.Redis(host='localhost', port=6379)

def concurrent_safe_update(thread_id: str):
    """并发安全的更新"""
    lock = DistributedLock(redis_client, thread_id)

    with lock.acquire():
        # 临界区:只有一个进程可以进入
        state = graph.get_state({"configurable": {"thread_id": thread_id}})
        # ... 更新状态 ...
        graph.update_state({"configurable": {"thread_id": thread_id}}, new_state)

适用场景:

  • ✅ 多服务器部署
  • ✅ 需要强一致性
  • ❌ 单机部署(Python GIL 已足够)

最佳实践

python
class ProductionChatbot:
    """生产级聊天机器人(综合方案)"""

    def __init__(self, checkpointer):
        self.graph = workflow.compile(checkpointer=checkpointer)
        self.queues = {}  # 队列(串行化)
        self.redis = redis.Redis()  # 分布式锁

    async def handle_message(self, user_id: str, message: str):
        """处理消息(综合方案)"""
        thread_id = f"user_{user_id}"
        config = {"configurable": {"thread_id": thread_id}}

        # 策略 1:用户隔离(Thread)
        # 不同用户互不影响

        # 策略 2:队列串行化(同一用户的并发请求)
        if thread_id not in self.queues:
            self.queues[thread_id] = asyncio.Queue()
            asyncio.create_task(self._worker(thread_id))

        future = asyncio.Future()
        await self.queues[thread_id].put((message, future))
        result = await future

        return result

    async def _worker(self, thread_id: str):
        """Worker:串行处理"""
        queue = self.queues[thread_id]
        config = {"configurable": {"thread_id": thread_id}}

        while True:
            message, future = await queue.get()

            # 策略 3:分布式锁(多服务器部署)
            lock = DistributedLock(self.redis, thread_id)

            try:
                with lock.acquire():
                    # 策略 4:Reducer(图内部并行节点)
                    # 已在状态定义中使用 Annotated[list, add]

                    result = await self.graph.ainvoke(
                        {"messages": [HumanMessage(message)]},
                        config
                    )
                    future.set_result(result)
            except Exception as e:
                future.set_exception(e)
            finally:
                queue.task_done()

关键洞察

并发控制的层次:

  1. 用户层:Thread 隔离(不同用户)
  2. 请求层:队列串行化(同一用户的并发请求)
  3. 图层:Reducer(图内部并行节点)
  4. 数据层:事务/锁(数据库级别)

选择指南:

  • 单机 + 低并发:Thread 隔离 + Reducer
  • 单机 + 高并发:+ 队列串行化
  • 多机 + 高并发:+ 分布式锁
  • 强一致性需求:+ 数据库事务

🎉 复习完成!

知识掌握度自测

完成以上 10 道题后,你应该能够:

  • [ ] 熟练使用三种 State Schema 定义方式
  • [ ] 理解并解决并行节点的状态冲突
  • [ ] 掌握 add_messages 的三大核心功能
  • [ ] 实现私有状态和多状态模式
  • [ ] 优化消息管理和 Token 使用
  • [ ] 构建带摘要的智能聊天机器人
  • [ ] 配置外部数据库持久化存储
  • [ ] 设计生产级客服系统
  • [ ] 优化 Token 成本(节省 50%+)
  • [ ] 处理并发和竞态条件

下一步学习路径

完成 Module-3 后,建议你:

  1. 实践项目

    • 构建一个完整的聊天机器人
    • 实现消息摘要和持久化
    • 优化 Token 使用
  2. 进阶学习

    • Module-4:工具调用和外部集成
    • Module-5:高级图结构和子图
    • Module-6:部署和生产优化
  3. 深入研究

    • 研究 LangGraph 源码
    • 贡献开源项目
    • 分享学习经验

实践建议

项目 1:智能客服机器人 (1-2 周)

  • 实现多状态模式
  • 集成消息摘要
  • 部署到生产环境

项目 2:长期记忆助手 (1-2 周)

  • 设计分层摘要系统
  • 优化 Token 成本
  • 实现跨设备同步

项目 3:企业级对话平台 (2-3 周)

  • 多租户架构
  • 高并发处理
  • 监控和分析系统

📖 参考资源

官方文档

社区资源

学习路径


恭喜你完成 Module-3 的学习!

你已经掌握了 LangGraph 状态管理的核心知识,这是构建复杂 AI 系统的关键技能。

继续加油,期待你在 Module-4 中的精彩表现!🚀

基于 MIT 许可证发布。内容版权归作者所有。