agent-writer/graph/test_graph_3.py
2025-09-11 18:34:03 +08:00

171 lines
5.8 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

# test_graph_persistent.py
from operator import add
from typing import TypedDict, Annotated
from langchain_core.messages import AnyMessage, HumanMessage
from langgraph import graph
from langgraph.config import get_stream_writer
from langgraph.constants import END, START
from langgraph.graph import StateGraph
from IPython.display import Image
import uuid
import config
# 导入数据库连接和自定义检查点存储器
from tools.database.mongo import mainDB, client
# from mongodb_checkpointer import MongoDBCheckpointSaver
from langgraph.checkpoint.mongodb import MongoDBSaver
import asyncio
# collection = db.langgraph_checkpoints
# Shared checkpointer used when compiling the graph below. It is built on the
# MongoDB `client` imported from tools.database.mongo and stores its data in
# the database named by config.MONGO_CHECKPOINT_DB_NAME.
memory: MongoDBSaver = MongoDBSaver(
    client,
    db_name=config.MONGO_CHECKPOINT_DB_NAME,
)
class State(TypedDict):
    """Message-carrying state schema.

    NOTE(review): nothing in the visible part of this file references this
    schema; the graph below is built from OverallState instead.
    """

    # List of messages, annotated with `operator.add` as its reducer.
    messages: Annotated[list[AnyMessage], add]
    # Free-form string tag; its meaning is not established in this file.
    type: str
class InputState(TypedDict):
    """Input schema of the graph: the only key a caller supplies."""

    # Raw text handed to the first node.
    user_input: str
class OutputState(TypedDict):
    """Output schema of the graph: the only key exposed as the result."""

    # Final string produced by the last node.
    graph_output: str
class OverallState(TypedDict):
    """Full state schema the graph is built on.

    It carries the input key, the output key and the intermediate ``foo``
    value written by node_1.
    """

    # Intermediate value written by node_1 and read by node_2.
    foo: str
    # Same key as in InputState.
    user_input: str
    # Same key as in OutputState; written by node_3.
    graph_output: str
class PrivateState(TypedDict):
    """Schema for the value passed from node_2 to node_3."""

    # Intermediate value written by node_2 and read by node_3.
    bar: str
async def node_1(state: InputState) -> OverallState:
    """Log the user input and derive ``foo`` from it.

    Args:
        state: Mapping that must contain ``user_input``.

    Returns:
        A partial state update ``{"foo": <user_input> + ">学院"}``.
    """
    text = state["user_input"]
    print(f"Node 1 处理: {text}")
    return {"foo": text + ">学院"}
async def node_2(state: OverallState) -> PrivateState:
    """Log ``foo`` and derive ``bar`` from it.

    Args:
        state: Mapping that must contain ``foo``.

    Returns:
        A partial state update ``{"bar": <foo> + ">非常"}``.
    """
    foo_value = state["foo"]
    print(f"Node 2 处理: {foo_value}")
    return {"bar": foo_value + ">非常"}
async def node_3(state: PrivateState) -> OverallState:
    """Log ``bar`` and derive the final ``graph_output`` from it.

    Args:
        state: Mapping that must contain ``bar``.

    Returns:
        A partial state update ``{"graph_output": <bar> + ">靠谱"}``.
    """
    bar_value = state["bar"]
    print(f"Node 3 处理: {bar_value}")
    return {"graph_output": bar_value + ">靠谱"}
# Assemble a linear pipeline: START -> node_1 -> node_2 -> node_3 -> END.
# OverallState is the state schema; InputState and OutputState are passed as
# the graph's input and output schemas.
builder = StateGraph(
    OverallState,
    input_schema=InputState,
    output_schema=OutputState,
)

# Register the three processing nodes under their function names.
builder.add_node("node_1", node_1)
builder.add_node("node_2", node_2)
builder.add_node("node_3", node_3)

# Chain the nodes in order between the START and END markers.
builder.add_edge(START, "node_1")
builder.add_edge("node_1", "node_2")
builder.add_edge("node_2", "node_3")
builder.add_edge("node_3", END)

# Compile with the module-level MongoDB checkpointer attached.
# NOTE: this assignment rebinds the name `graph`, which the import section
# also binds via `from langgraph import graph`; from here on `graph` refers
# to the compiled graph, not the imported module.
graph = builder.compile(checkpointer=memory)
async def run_with_persistence(user_input: str, thread_id: str = None):
    """Run the compiled graph once under a checkpoint thread.

    Args:
        user_input: Text placed under the ``user_input`` key of the graph
            input.
        thread_id: Thread ID to run under. When ``None``, a fresh UUID4
            string is generated.

    Returns:
        A ``(output_state, thread_id)`` tuple: the value returned by
        ``graph.ainvoke`` and the thread ID that was actually used.
    """
    if thread_id is None:
        thread_id = str(uuid.uuid4())
    print(f"使用线程 ID: {thread_id}")

    # Named `run_config` so the imported `config` module is not shadowed
    # inside this function.
    run_config = {"configurable": {"thread_id": thread_id}}
    graph_input = {"user_input": user_input}

    output_state = await graph.ainvoke(graph_input, run_config)
    print(f"输出: {output_state}")
    return output_state, thread_id
async def get_checkpoint_history(thread_id: str):
    """Print up to ten stored checkpoints for a thread.

    For each checkpoint the ID, the ``channel_values`` mapping and the
    metadata are written to stdout. Any exception raised while listing or
    printing is caught and reported on stdout; nothing is returned.

    Args:
        thread_id: Thread ID whose checkpoints are listed.
    """
    thread_config = {"configurable": {"thread_id": thread_id}}
    try:
        checkpoint_iter = memory.list(thread_config, limit=10)
        print("正在获取检查点历史...")

        # Materialise first: the total count is printed before the entries.
        history = list(checkpoint_iter)
        print(f"找到 {len(history)} 个检查点:")

        for position, entry in enumerate(history, start=1):
            # Each entry exposes its parts as attributes.
            snapshot = entry.checkpoint
            meta = entry.metadata
            print(f"检查点 {position}:")
            print(f" - ID: {snapshot.get('id', 'N/A')}")
            print(f" - 状态: {snapshot.get('channel_values', {})}")
            print(f" - 元数据: {meta}")
            print("-" * 50)
    except Exception as e:
        print(f"获取历史记录时出错: {e}")
def resume_from_checkpoint(thread_id: str, checkpoint_id: str = None):
    """Look up a stored checkpoint and return its channel values.

    Despite the name, this does not re-run the graph: it only reads the
    checkpoint, prints a summary and returns the saved state.

    Args:
        thread_id: Thread ID to look up.
        checkpoint_id: Optional checkpoint ID. When falsy, the lookup is
            made with the thread ID only.

    Returns:
        The checkpoint's ``channel_values`` (``{}`` if that key is absent),
        or ``None`` when no checkpoint is found or any exception is raised.
    """
    configurable = {"thread_id": thread_id}
    if checkpoint_id:
        configurable["checkpoint_id"] = checkpoint_id
    lookup_config = {"configurable": configurable}

    try:
        found = memory.get_tuple(lookup_config)
        if not found:
            print(f"未找到线程 {thread_id} 的检查点")
            return None

        # The result exposes its parts as attributes; no tuple unpacking.
        snapshot = found.checkpoint
        meta = found.metadata
        print("从检查点恢复:")
        print(f" - 检查点 ID: {snapshot.get('id', 'N/A')}")
        print(f" - 状态: {snapshot.get('channel_values', {})}")
        print(f" - 元数据: {meta}")
        return snapshot.get("channel_values", {})
    except Exception as e:
        print(f"恢复检查点时出错: {e}")
        return None
if __name__ == "__main__":
    print("=== 测试持久化 LangGraph ===")

    # Step 1: first run. No thread ID is passed, so a new one is generated.
    # The coroutines are driven with asyncio.run because this is sync code.
    print("\n1. 第一次运行:")
    first_output, thread_id = asyncio.run(run_with_persistence("你好"))

    # Step 2: show the checkpoints stored for that thread.
    print("\n2. 查看检查点历史:")
    asyncio.run(get_checkpoint_history(thread_id))

    # Step 3: run again on the same thread with a different input.
    print("\n3. 从同一线程继续运行:")
    second_output, _ = asyncio.run(run_with_persistence("再见", thread_id))

    # Step 4: show the history again after the second run.
    print("\n4. 查看更新后的历史:")
    asyncio.run(get_checkpoint_history(thread_id))

    # Step 5: read back the state of the latest checkpoint.
    print("\n5. 恢复最新检查点状态:")
    restored_state = resume_from_checkpoint(thread_id)

    # Optional: render the graph structure to a PNG (left disabled).
    # try:
    #     with open('persistent_graph_visualization.png', 'wb') as f:
    #         f.write(graph.get_graph().draw_mermaid_png())
    #     print("\n图片已保存为 persistent_graph_visualization.png")
    # except Exception as e:
    #     print(f"保存可视化图片失败: {e}")