# test_graph_persistent.py
"""Demo: run a three-node LangGraph with MongoDB-backed checkpoints.

The script runs the graph twice on the same thread and prints that thread's
checkpoint history after each run. It then reads back the latest checkpoint's
channel values.
"""
import asyncio
import uuid
from operator import add
from typing import Annotated, Optional, TypedDict

from langchain_core.messages import AnyMessage, HumanMessage
from langgraph import graph  # NOTE(review): unused; the name is rebound to the compiled graph below
from langgraph.config import get_stream_writer
from langgraph.constants import END, START
from langgraph.graph import StateGraph
from IPython.display import Image

import config
# Shared database connection used by the checkpoint store.
from tools.database.mongo import mainDB, client
from langgraph.checkpoint.mongodb import MongoDBSaver

# Checkpointer that stores graph state in the database named by
# config.MONGO_CHECKPOINT_DB_NAME, reusing the project's shared client.
memory: MongoDBSaver = MongoDBSaver(client, db_name=config.MONGO_CHECKPOINT_DB_NAME)


class State(TypedDict):
    """Message-list state schema (not referenced elsewhere in this script)."""

    # ``add`` is the reducer: updates to ``messages`` are concatenated.
    messages: Annotated[list[AnyMessage], add]
    type: str


class InputState(TypedDict):
    """Input schema: the keys callers pass to ``graph.ainvoke``."""

    user_input: str


class OutputState(TypedDict):
    """Output schema: the keys ``graph.ainvoke`` returns."""

    graph_output: str


class OverallState(TypedDict):
    """Internal state schema the graph is built on."""

    foo: str
    user_input: str
    graph_output: str


class PrivateState(TypedDict):
    """Intermediate state written by node_2 and read by node_3."""

    bar: str


async def node_1(state: InputState) -> OverallState:
    """Append a fixed suffix to ``user_input`` and store it as ``foo``."""
    print(f"Node 1 处理: {state['user_input']}")
    return {"foo": state["user_input"] + ">学院"}


async def node_2(state: OverallState) -> PrivateState:
    """Append a fixed suffix to ``foo`` and store it as ``bar``."""
    print(f"Node 2 处理: {state['foo']}")
    return {"bar": state["foo"] + ">非常"}


async def node_3(state: PrivateState) -> OverallState:
    """Append a fixed suffix to ``bar`` and store it as ``graph_output``."""
    print(f"Node 3 处理: {state['bar']}")
    return {"graph_output": state["bar"] + ">靠谱"}


# Build the linear graph START -> node_1 -> node_2 -> node_3 -> END.
builder = StateGraph(OverallState, input_schema=InputState, output_schema=OutputState)
builder.add_node("node_1", node_1)
builder.add_node("node_2", node_2)
builder.add_node("node_3", node_3)
builder.add_edge(START, "node_1")
builder.add_edge("node_1", "node_2")
builder.add_edge("node_2", "node_3")
builder.add_edge("node_3", END)

# Compile with the MongoDB checkpointer so runs are checkpointed under the
# thread_id supplied in each call's config.
graph = builder.compile(checkpointer=memory)


async def run_with_persistence(
    user_input: str, thread_id: Optional[str] = None
) -> tuple[dict, str]:
    """Run the graph once with checkpointing enabled.

    Args:
        user_input: Value for the graph's ``user_input`` input key.
        thread_id: Checkpoint thread to run under. When omitted, a fresh
            random UUID is generated, which starts a new thread.

    Returns:
        ``(output_state, thread_id)``: the graph output and the thread ID
        actually used, so callers can reuse it for follow-up runs.
    """
    if thread_id is None:
        thread_id = str(uuid.uuid4())
    print(f"使用线程 ID: {thread_id}")

    # Named run_config (not config) so the imported config module is not shadowed.
    run_config = {"configurable": {"thread_id": thread_id}}
    output_state = await graph.ainvoke({"user_input": user_input}, run_config)
    print(f"输出: {output_state}")
    return output_state, thread_id


async def get_checkpoint_history(thread_id: str, limit: int = 10) -> None:
    """Print up to ``limit`` stored checkpoints for ``thread_id``.

    This is a diagnostic helper. Any error is printed instead of raised.

    Args:
        thread_id: Thread whose checkpoints are listed.
        limit: Maximum number of checkpoints to fetch (defaults to 10).
    """
    run_config = {"configurable": {"thread_id": thread_id}}
    try:
        print("正在获取检查点历史...")
        # memory.list is a synchronous MongoDB query. Run it in a worker thread
        # so this coroutine does not block the event loop.
        history = await asyncio.to_thread(
            lambda: list(memory.list(run_config, limit=limit))
        )
        print(f"找到 {len(history)} 个检查点:")
        for index, checkpoint_tuple in enumerate(history, start=1):
            # Each item exposes .checkpoint (a dict) and .metadata attributes.
            checkpoint_data = checkpoint_tuple.checkpoint
            print(f"检查点 {index}:")
            print(f" - ID: {checkpoint_data.get('id', 'N/A')}")
            print(f" - 状态: {checkpoint_data.get('channel_values', {})}")
            print(f" - 元数据: {checkpoint_tuple.metadata}")
            print("-" * 50)
    except Exception as e:
        print(f"获取历史记录时出错: {e}")


def resume_from_checkpoint(
    thread_id: str, checkpoint_id: Optional[str] = None
) -> Optional[dict]:
    """Load a stored checkpoint and return its channel values.

    This only reads the checkpoint. It does not re-run the graph.

    Args:
        thread_id: Thread to read from.
        checkpoint_id: Specific checkpoint to load. When omitted, the
            checkpointer is asked for the thread's latest checkpoint.

    Returns:
        The checkpoint's ``channel_values`` dict, or None when no checkpoint
        exists or the lookup fails. Failures are printed, not raised.
    """
    run_config = {"configurable": {"thread_id": thread_id}}
    if checkpoint_id:
        run_config["configurable"]["checkpoint_id"] = checkpoint_id

    try:
        checkpoint_tuple = memory.get_tuple(run_config)
    except Exception as e:
        print(f"恢复检查点时出错: {e}")
        return None

    if checkpoint_tuple is None:
        print(f"未找到线程 {thread_id} 的检查点")
        return None

    # CheckpointTuple fields are accessed by attribute, not unpacked.
    checkpoint_data = checkpoint_tuple.checkpoint
    channel_values = checkpoint_data.get("channel_values", {})
    print("从检查点恢复:")
    print(f" - 检查点 ID: {checkpoint_data.get('id', 'N/A')}")
    print(f" - 状态: {channel_values}")
    print(f" - 元数据: {checkpoint_tuple.metadata}")
    return channel_values


async def _main() -> None:
    """Run two passes on one thread, print history after each, then read back state."""
    print("=== 测试持久化 LangGraph ===")

    print("\n1. 第一次运行:")
    _, thread_id = await run_with_persistence("你好")

    print("\n2. 查看检查点历史:")
    await get_checkpoint_history(thread_id)

    print("\n3. 从同一线程继续运行:")
    await run_with_persistence("再见", thread_id)

    print("\n4. 查看更新后的历史:")
    await get_checkpoint_history(thread_id)

    print("\n5. 恢复最新检查点状态:")
    # resume_from_checkpoint is synchronous (blocking MongoDB call); keep it off the loop.
    await asyncio.to_thread(resume_from_checkpoint, thread_id)

    # Optional: render the graph structure to a PNG.
    # try:
    #     with open('persistent_graph_visualization.png', 'wb') as f:
    #         f.write(graph.get_graph().draw_mermaid_png())
    #     print("\n图片已保存为 persistent_graph_visualization.png")
    # except Exception as e:
    #     print(f"保存可视化图片失败: {e}")


if __name__ == "__main__":
    # One event loop for the whole demo instead of a separate asyncio.run per step.
    asyncio.run(_main())