# test_graph_persistent.py
# (Python source listing: 171 lines, 5.8 KiB)
import asyncio
import uuid
from operator import add
from typing import Annotated, Optional, TypedDict

from IPython.display import Image
from langchain_core.messages import AnyMessage, HumanMessage
# NOTE: this module name is rebound later by the compiled `graph` object.
from langgraph import graph
from langgraph.checkpoint.mongodb import MongoDBSaver
# from mongodb_checkpointer import MongoDBCheckpointSaver
from langgraph.config import get_stream_writer
from langgraph.constants import END, START
from langgraph.graph import StateGraph

import config
# Import the database connection (and, formerly, a custom checkpoint saver).
from tools.database.mongo import mainDB, client

# MongoDB-backed checkpoint saver. It reuses the shared `client` connection and
# stores checkpoints in the database named by config.MONGO_CHECKPOINT_DB_NAME.
memory: MongoDBSaver = MongoDBSaver(
    client,
    db_name=config.MONGO_CHECKPOINT_DB_NAME,
)


class State(TypedDict):
    """Message-oriented state schema (the graph built in this file does not use it).

    ``messages`` carries ``operator.add`` as its ``Annotated`` metadata, which
    LangGraph treats as the reducer: updates are concatenated onto the existing
    list instead of replacing it.
    """

    # Conversation history, merged with ``add`` (list concatenation).
    messages: Annotated[list[AnyMessage], add]
    # Free-form string tag; the key name ``type`` is part of the schema.
    type: str


class InputState(TypedDict):
    """Public input schema of the graph: the payload callers pass to ``ainvoke``."""

    # Raw text supplied by the caller; node_1 reads it.
    user_input: str


class OutputState(TypedDict):
    """Public output schema of the graph: the only key returned to callers."""

    # Final text produced by node_3.
    graph_output: str


class OverallState(TypedDict):
    """Internal state schema shared by the graph's nodes.

    It is a superset of InputState and OutputState: ``user_input`` comes from
    the caller, ``foo`` is written by node_1, and ``graph_output`` (written by
    node_3) is what the graph hands back.
    """

    # Intermediate value written by node_1 and read by node_2.
    foo: str
    # Copied in from the caller's InputState.
    user_input: str
    # Final result, exposed through OutputState.
    graph_output: str


class PrivateState(TypedDict):
    """Schema for the value handed from node_2 to node_3.

    It is neither the graph's input nor its output schema, so ``bar`` is not
    part of what callers receive.
    """

    # Intermediate value written by node_2 and read by node_3.
    bar: str


async def node_1(state: InputState) -> OverallState:
    """First step: append ">学院" to the caller's input and store it as ``foo``."""
    user_text = state["user_input"]
    print(f"Node 1 处理: {user_text}")
    return {"foo": user_text + ">学院"}


async def node_2(state: OverallState) -> PrivateState:
    """Second step: append ">非常" to ``foo`` and pass it on privately as ``bar``."""
    foo_text = state["foo"]
    print(f"Node 2 处理: {foo_text}")
    return {"bar": foo_text + ">非常"}


async def node_3(state: PrivateState) -> OverallState:
    """Final step: append ">靠谱" to ``bar`` and publish it as ``graph_output``."""
    bar_text = state["bar"]
    print(f"Node 3 处理: {bar_text}")
    return {"graph_output": bar_text + ">靠谱"}


# Build the graph. OverallState is the internal state shared by the nodes;
# callers supply InputState fields and receive OutputState fields.
builder = StateGraph(OverallState, input_schema=InputState, output_schema=OutputState)


def _wire_linear(graph_builder, steps):
    """Register every ``(name, node)`` pair, then chain them START -> ... -> END.

    All nodes are added before any edge, and edges follow the order of *steps*.
    """
    for step_name, step_node in steps:
        graph_builder.add_node(step_name, step_node)
    route = [START, *(step_name for step_name, _ in steps), END]
    for source, target in zip(route, route[1:]):
        graph_builder.add_edge(source, target)


_wire_linear(builder, [("node_1", node_1), ("node_2", node_2), ("node_3", node_3)])

# Compile with the MongoDB-backed `memory` saver so runs are checkpointed under
# the thread_id given in each run's config.
graph = builder.compile(checkpointer=memory)


async def run_with_persistence(user_input: str, thread_id: Optional[str] = None):
    """Run the graph once with checkpointing enabled.

    Args:
        user_input: Text placed in the graph's ``user_input`` input field.
        thread_id: Checkpoint thread to run on. When omitted, a fresh UUID4
            string is generated, so the run starts on a new thread.

    Returns:
        A ``(output_state, thread_id)`` tuple: the result of ``graph.ainvoke``
        and the thread id actually used (handy when it was generated here).
    """
    if thread_id is None:
        thread_id = str(uuid.uuid4())

    print(f"使用线程 ID: {thread_id}")

    # The checkpointer files saved state under this thread_id. Named
    # run_config so it does not shadow the module-level `config` import.
    run_config = {"configurable": {"thread_id": thread_id}}

    input_state = {"user_input": user_input}
    output_state = await graph.ainvoke(input_state, run_config)

    print(f"输出: {output_state}")

    return output_state, thread_id


async def get_checkpoint_history(thread_id: str, limit: int = 10):
    """Print, and return, the most recent checkpoints stored for a thread.

    Args:
        thread_id: Thread whose checkpoints should be listed.
        limit: Maximum number of checkpoints to fetch (defaults to 10).

    Returns:
        The list of checkpoint tuples that was printed, or ``None`` if reading
        the history failed. Errors are printed, not raised, because this is a
        best-effort diagnostic helper.
    """
    # Named run_config so it does not shadow the module-level `config` import.
    run_config = {"configurable": {"thread_id": thread_id}}

    try:
        print("正在获取检查点历史...")

        # memory.list() returns a synchronous iterator backed by MongoDB reads.
        # Consume it in a worker thread so this coroutine does not block the
        # event loop while the database is queried.
        history = await asyncio.to_thread(
            lambda: list(memory.list(run_config, limit=limit))
        )

        print(f"找到 {len(history)} 个检查点:")

        for position, checkpoint_tuple in enumerate(history, start=1):
            # Each item exposes .checkpoint (a dict) and .metadata, among others.
            checkpoint_data = checkpoint_tuple.checkpoint
            metadata = checkpoint_tuple.metadata
            print(f"检查点 {position}:")
            print(f" - ID: {checkpoint_data.get('id', 'N/A')}")
            print(f" - 状态: {checkpoint_data.get('channel_values', {})}")
            print(f" - 元数据: {metadata}")
            print("-" * 50)
    except Exception as e:
        print(f"获取历史记录时出错: {e}")
        return None

    return history


def resume_from_checkpoint(thread_id: str, checkpoint_id: Optional[str] = None):
    """Load and print the state stored in a checkpoint (the graph is not re-run).

    Args:
        thread_id: Thread whose checkpoint should be loaded.
        checkpoint_id: Specific checkpoint to load. When omitted (or empty),
            ``get_tuple`` is asked for the thread's latest checkpoint, per the
            checkpoint-saver ``get_tuple`` contract.

    Returns:
        The checkpoint's ``channel_values`` dict (``{}`` if that key is absent),
        or ``None`` when no checkpoint exists or loading failed. Failures are
        printed rather than raised.
    """
    # Named run_config so it does not shadow the module-level `config` import.
    run_config = {"configurable": {"thread_id": thread_id}}
    if checkpoint_id:
        run_config["configurable"]["checkpoint_id"] = checkpoint_id

    try:
        checkpoint_tuple = memory.get_tuple(run_config)
        if not checkpoint_tuple:
            print(f"未找到线程 {thread_id} 的检查点")
            return None

        # Read the CheckpointTuple through its attributes instead of unpacking it.
        checkpoint_data = checkpoint_tuple.checkpoint
        metadata = checkpoint_tuple.metadata
        channel_values = checkpoint_data.get('channel_values', {})

        print("从检查点恢复:")
        print(f" - 检查点 ID: {checkpoint_data.get('id', 'N/A')}")
        print(f" - 状态: {channel_values}")
        print(f" - 元数据: {metadata}")
        return channel_values
    except Exception as e:
        print(f"恢复检查点时出错: {e}")
        return None


if __name__ == "__main__":

    async def _run_demo() -> str:
        """Run demo steps 1-4 inside one event loop and return the thread id used."""
        # 1. First run, on a brand-new thread.
        print("\n1. 第一次运行:")
        _, thread_id = await run_with_persistence("你好")

        # 2. Inspect the checkpoints written by that run.
        print("\n2. 查看检查点历史:")
        await get_checkpoint_history(thread_id)

        # 3. Run again on the same thread.
        print("\n3. 从同一线程继续运行:")
        await run_with_persistence("再见", thread_id)

        # 4. Inspect the history again after the second run.
        print("\n4. 查看更新后的历史:")
        await get_checkpoint_history(thread_id)

        return thread_id

    print("=== 测试持久化 LangGraph ===")

    # asyncio.run() is meant to be called once as the program's entry point,
    # so all async steps share a single event loop instead of one per step.
    demo_thread_id = asyncio.run(_run_demo())

    # 5. Load the latest checkpoint state (synchronous, so it runs outside the loop).
    print("\n5. 恢复最新检查点状态:")
    resume_from_checkpoint(demo_thread_id)

    # Optional: render the graph structure to a PNG file.
    # try:
    #     with open('persistent_graph_visualization.png', 'wb') as f:
    #         f.write(graph.get_graph().draw_mermaid_png())
    #     print("\n图片已保存为 persistent_graph_visualization.png")
    # except Exception as e:
    #     print(f"保存可视化图片失败: {e}")