diff --git a/agent/build_bible.py b/agent/build_bible.py index d7a297c..212819d 100644 --- a/agent/build_bible.py +++ b/agent/build_bible.py @@ -35,7 +35,7 @@ DefaultAgentPrompt = f""" def create_agent_prompt(prompt, SchedulerList): """创建代理提示词的辅助函数""" if not SchedulerList or len(SchedulerList) == 0: return prompt - node_list = [f"{node.name}:{node.desc}" for node in SchedulerList] + node_list = [f"{node['name']}:{node['desc']}" for node in SchedulerList] return f""" {prompt} \n 下面返回数据中node字段的取值范围列表([{{名称:描述}}]),请根据你的分析结果选择一个节点名称返回: diff --git a/agent/episode_create.py b/agent/episode_create.py index 21c406c..b01b847 100644 --- a/agent/episode_create.py +++ b/agent/episode_create.py @@ -35,7 +35,7 @@ DefaultAgentPrompt = f""" def create_agent_prompt(prompt, SchedulerList): """创建代理提示词的辅助函数""" if not SchedulerList or len(SchedulerList) == 0: return prompt - node_list = [f"{node.name}:{node.desc}" for node in SchedulerList] + node_list = [f"{node['name']}:{node['desc']}" for node in SchedulerList] return f""" {prompt} \n 下面返回数据中node字段的取值范围列表([{{名称:描述}}]),请根据你的分析结果选择一个节点名称返回: diff --git a/agent/scheduler.py b/agent/scheduler.py index 5e87ddd..f064777 100644 --- a/agent/scheduler.py +++ b/agent/scheduler.py @@ -97,6 +97,7 @@ DefaultAgentPrompt = f""" "status": "当前阶段的状态",//取值范围在上述 status的描述中 不可写其他值 "agent":'',//分析后得出由哪个智能体继续任务,此处为智能体名称;如果需要继续与用户交互或仅需要回复用户则为空字符串 "message":'',//回复给用户的内容 + "retry_count":0,//重试次数 "node":'',//下一个节点名称 }} @@ -105,7 +106,7 @@ DefaultAgentPrompt = f""" def create_agent_prompt(prompt, SchedulerList): """创建代理提示词的辅助函数""" if not SchedulerList or len(SchedulerList) == 0: return prompt - node_list = [f"{node.name}:{node.desc}" for node in SchedulerList] + node_list = [f"{node['name']}:{node['desc']}" for node in SchedulerList] return f""" {prompt} \n 下面返回数据中node字段的取值范围列表([{{名称:描述}}]),请根据你的分析结果选择一个节点名称返回: diff --git a/agent/script_analysis.py b/agent/script_analysis.py index 37cff34..a359f9a 100644 --- a/agent/script_analysis.py +++ b/agent/script_analysis.py @@ -35,7 +35,7 @@ DefaultAgentPrompt = f""" def create_agent_prompt(prompt, SchedulerList): """创建代理提示词的辅助函数""" if not SchedulerList or len(SchedulerList) == 0: return prompt - node_list = [f"{node.name}:{node.desc}" for node in SchedulerList] + node_list = [f"{node['name']}:{node['desc']}" for node in SchedulerList] return f""" {prompt} \n 下面返回数据中node字段的取值范围列表([{{名称:描述}}]),请根据你的分析结果选择一个节点名称返回: diff --git a/agent/strategic_planning.py b/agent/strategic_planning.py index e145e87..a2e29b4 100644 --- a/agent/strategic_planning.py +++ b/agent/strategic_planning.py @@ -35,7 +35,7 @@ DefaultAgentPrompt = f""" def create_agent_prompt(prompt, SchedulerList): """创建代理提示词的辅助函数""" if not SchedulerList or len(SchedulerList) == 0: return prompt - node_list = [f"{node.name}:{node.desc}" for node in SchedulerList] + node_list = [f"{node['name']}:{node['desc']}" for node in SchedulerList] return f""" {prompt} \n 下面返回数据中node字段的取值范围列表([{{名称:描述}}]),请根据你的分析结果选择一个节点名称返回: diff --git a/doc/节点结构图.png b/doc/节点结构图.png new file mode 100644 index 0000000..1727f0b Binary files /dev/null and b/doc/节点结构图.png differ diff --git a/graph/test_agent_graph_1.py b/graph/test_agent_graph_1.py index b29ba02..e5eef46 100644 --- a/graph/test_agent_graph_1.py +++ b/graph/test_agent_graph_1.py @@ -4,15 +4,23 @@ """ from typing import TypedDict, Annotated, Dict, Any, List, TypedDict, Optional + +from langgraph.graph.state import RunnableConfig from agent.scheduler import SchedulerAgent from agent.build_bible import BuildBibleAgent from agent.episode_create import EpisodeCreateAgent from agent.script_analysis import ScriptAnalysisAgent from agent.strategic_planning import StrategicPlanningAgent +from langchain_core.messages import AnyMessage,HumanMessage from langgraph.graph import StateGraph, START, END from utils.logger import get_logger import operator +import json +import config + +from tools.database.mongo import client # type: ignore +from langgraph.checkpoint.mongodb import MongoDBSaver logger = get_logger(__name__) @@ -24,7 +32,8 @@ def replace_value(old_val, new_val): # 状态类型定义 class InputState(TypedDict): """工作流输入状态""" - input_data: Annotated[Dict[str, Any], operator.add] + input_data: Annotated[list[AnyMessage], operator.add] + from_type: Annotated[str, replace_value] session_id: Annotated[str, replace_value] class OutputState(TypedDict): @@ -32,6 +41,7 @@ class OutputState(TypedDict): session_id: Annotated[str, replace_value] status: Annotated[str, replace_value] error: Annotated[str, replace_value] + agent_message: Annotated[str, replace_value] # 智能体回复 class NodeInfo(TypedDict): """工作流信息""" @@ -45,12 +55,17 @@ class NodeInfo(TypedDict): class ScriptwriterState(TypedDict, total=False): """智能编剧工作流整体状态""" # 输入数据 - input_data: Annotated[Dict[str, Any], operator.add] + input_data: Annotated[list[HumanMessage], operator.add] session_id: Annotated[str, replace_value] + from_type: Annotated[str, replace_value] # 本次请求来着哪里 [user, agent] # 节点间状态 - node_info: NodeInfo - + next_node: Annotated[str, replace_value] # 下一个节点 + workflow_step: Annotated[str, replace_value] # 阶段名称 [wait_for_input,script_analysis,strategic_planning,build_bible,episode_create_loop, finish] + workflow_status: Annotated[str, replace_value] # 当前阶段的状态 [waiting,running,failed,completed] + workflow_reason: Annotated[str, replace_value] # 失败原因 + workflow_retry_count: Annotated[int, replace_value] # 重试次数 + # 中间状态 agent_script_id: Annotated[str, replace_value] # 剧本ID 包括原文 agent_plan: Annotated[Dict[str, Any], replace_value] #剧本计划 @@ -58,6 +73,7 @@ class ScriptwriterState(TypedDict, total=False): episode_list: Annotated[List, replace_value] # 章节列表 完成状态、产出章节id # 输出数据 + agent_message: Annotated[str, replace_value] # 智能体回复 status: Annotated[str, replace_value] error: Annotated[str, replace_value] @@ -76,77 +92,105 @@ class ScriptwriterGraph: def __init__(self): """初始化工作流图""" self.graph = None + self.memory = MongoDBSaver(client, db_name=config.MONGO_CHECKPOINT_DB_NAME) self._build_graph() def node_router(self, state: ScriptwriterState) -> str: - next_node = state.get("node", '') - if next_node: - return next_node - else: - return END + """节点路由函数""" + print(f'node_router state {state}') + next_node = state.get("next_node", 'pause_node') + # 修复:当 next_node 为空字符串时,设置默认值 + if not next_node: + next_node = 'pause_node' # 设置为暂停节点 + print(f'node_router next_node {next_node}') + return next_node def _build_graph(self) -> None: """构建工作流图""" try: # 创建智能体 + print("创建智能体") # 调度智能体 - schedulerAgent = SchedulerAgent( + self.schedulerAgent = SchedulerAgent( tools=[], SchedulerList=[ { - "scheduler_node": "调度智能体节点", - "script_analysis_node": "原始剧本分析节点", - "strategic_planning_node": "确立改编目标节点", - "build_bible_node": "剧本圣经构建节点", - "episode_create_node": "单集创作节点", - "end_node": "结束节点,任务失败终止时使用,结束后整个工作流将停止" + "name": "scheduler_node", + "desc": "调度智能体节点", + }, + { + "name": "script_analysis_node", + "desc": "原始剧本分析节点", + }, + { + "name": "strategic_planning_node", + "desc": "确立改编目标节点", + }, + { + "name": "build_bible_node", + "desc": "剧本圣经构建节点", + }, + { + "name": "episode_create_node", + "desc": "单集创作节点", + }, + { + "name": "end_node", + "desc": "结束节点,任务失败终止时使用,结束后整个工作流将停止" } ] ) - scriptAnalysisAgent = ScriptAnalysisAgent( + self.scriptAnalysisAgent = ScriptAnalysisAgent( tools=[], SchedulerList=[ { - "scheduler_node": "调度智能体节点", + "name": "scheduler_node", + "desc": "调度智能体节点", } ] ) - strategicPlanningAgent = StrategicPlanningAgent( + self.strategicPlanningAgent = StrategicPlanningAgent( tools=[], SchedulerList=[ { - "scheduler_node": "调度智能体节点", + "name": "scheduler_node", + "desc": "调度智能体节点", } ] ) - buildBibleAgent = BuildBibleAgent( + self.buildBibleAgent = BuildBibleAgent( tools=[], SchedulerList=[ { - "scheduler_node": "调度智能体节点", + "name": "scheduler_node", + "desc": "调度智能体节点", } ] ) - episodeCreate = EpisodeCreateAgent( + self.episodeCreate = EpisodeCreateAgent( tools=[], SchedulerList=[ { - "scheduler_node": "调度智能体节点", + "name": "scheduler_node", + "desc": "调度智能体节点", } ] ) # 创建状态图 + logger.info("创建状态图") workflow = StateGraph(ScriptwriterState, input_schema=InputState, output_schema=OutputState) # 添加节点 + logger.info("添加节点") workflow.add_node("scheduler_node", self.scheduler_node) workflow.add_node("script_analysis_node", self.script_analysis_node) workflow.add_node("strategic_planning_node", self.strategic_planning_node) workflow.add_node("build_bible_node", self.build_bible_node) workflow.add_node("episode_create_node", self.episode_create_node) workflow.add_node("end_node", self.end_node) + workflow.add_node("pause_node", self.pause_node) # 添加边 workflow.set_entry_point("scheduler_node") @@ -167,13 +211,14 @@ class ScriptwriterGraph: "episode_create_node": "episode_create_node", # 用户确认和暂停逻辑在这里处理,不需要单独的边 "end_node": "end_node", + "pause_node": "pause_node", } ) workflow.add_edge("end_node", END) # 编译图 - self.graph = workflow.compile() + self.graph = workflow.compile(checkpointer=self.memory) logger.info("工作流图构建完成") except Exception as e: @@ -182,10 +227,44 @@ class ScriptwriterGraph: # --- 定义图中的节点 --- async def scheduler_node(self, state: ScriptwriterState)-> ScriptwriterState: - """第一步:初步沟通,请求剧本""" - session_id = state.get("session_id", "") - - return {} + """调度节点""" + try: + session_id = state.get("session_id", "") + from_type = state.get("from_type", "") + input_data = state.get("input_data", []) + logger.info(f"调度节点 {session_id} 输入参数: {input_data} from_type:{from_type}") + reslut = await self.schedulerAgent.ainvoke(state) + ai_message_str = reslut['messages'][-1].content + ai_message = json.loads(ai_message_str) + logger.info(f"调度节点结果: {ai_message}") + step:str = ai_message.get('step', '') + status:str = ai_message.get('status', '') + next_agent:str = ai_message.get('agent', '') + return_message:str = ai_message.get('message', '') + retry_count:int = int(ai_message.get('retry_count', '0')) + next_node:str = ai_message.get('node', 'pause_node') + if next_node == 'scheduler_node': + # 返回自身 代表暂停 + print(f"调度节点 暂停等待") + return { + "agent_message": return_message, + } + else: + return { + "next_node":next_node, + "workflow_step":step, + "workflow_status":status, + # "workflow_reason":return_message, + "workflow_retry_count":retry_count, + "agent_message":return_message, + } + except Exception as e: + return { + "next_node":'end_node', + "agent_message": "执行失败", + "error": str(e) or '系统错误,工作流已终止', + 'status':'failed', + } async def script_analysis_node(self, state: ScriptwriterState)-> ScriptwriterState: """第二步:诊断分析与资产评估""" @@ -218,20 +297,32 @@ class ScriptwriterGraph: episode_list = [] return {"episode_list": episode_list} + async def pause_node(self, state: ScriptwriterState)-> ScriptwriterState: + """ 暂停节点 处理并完成所有数据状态 """ + print(f"langgraph 暂停等待") + return { + "session_id": state.get("session_id", ""), + "status": state.get('status', ''), + "error": state.get('error', ''), + "agent_message": state.get('agent_message', '') + } + async def end_node(self, state: ScriptwriterState)-> OutputState: """ 结束节点 处理并完成所有数据状态 """ print(f"langgraph 所有任务完成") return { "session_id": state.get("session_id", ""), - "status": "", - "error": "", + "status": state.get('status', ''), + "error": state.get('error', ''), + "agent_message": state.get('agent_message', ''), } - async def run(self, input_data: Dict[str, Any]) -> OutputState: + async def run(self, session_id: str, input_data: list[AnyMessage], thread_id: str|None = None) -> OutputState: """运行工作流 - Args: + session_id: 会话ID input_data: 输入数据 + thread_id: 线程ID Returns: 工作流执行结果 @@ -239,44 +330,102 @@ class ScriptwriterGraph: try: logger.info("开始运行智能编剧工作流") - # # 初始化状态 - # initial_state: InputState = { - # 'input_data': input_data, - # 'session_id': input_data.get('session_id', ''), - # 'max_iterations': input_data.get('max_iterations', 3), - # 'batch_info': input_data.get('batch_info', {}) - # } + # 配置包含线程 ID + config:RunnableConfig = {"configurable": {"thread_id": thread_id}} + # 初始化状态 + initial_state: InputState = { + 'input_data': input_data, + 'session_id': session_id, + 'from_type': 'user', + } - # # 运行工作流 - # if self.graph is None: - # raise RuntimeError("工作流图未正确初始化") + # 运行工作流 + if self.graph is None: + raise RuntimeError("工作流图未正确初始化") - # result = await self.graph.ainvoke(initial_state) + result = await self.graph.ainvoke( + initial_state, + config, + # stream_mode='values' + ) # logger.info(f"工作流执行结果: {result}") - # if not result: - # raise ValueError("工作流执行结果为空") - # # 保存到记忆 - # self.memory.save_workflow_result(result) + if not result: + raise ValueError("工作流执行结果为空") - # # 构造输出状态 - # output_result: OutputState = { - # 'script': result.get('script'), - # 'adjustment': result.get('adjustment'), - # 'error': result.get('error'), - # 'iteration_count': result.get('iteration_count', 0) - # } - output_result:OutputState = { - 'session_id': "", - 'status': 'completed', - 'error': '', + # 构造输出状态 + output_result: OutputState = { + 'session_id': result.get('session_id', ''), + 'status': result.get('status', ''), + 'error': result.get('error', ''), + 'agent_message': result.get('agent_message', ''), } logger.info("智能编剧工作流运行完成") return output_result except Exception as e: logger.error(f"运行工作流失败: {e}") + import traceback + traceback.print_exc() raise + + async def get_checkpoint_history(self, thread_id: str): + """获取检查点历史""" + config:RunnableConfig = {"configurable": {"thread_id": thread_id}} + + try: + history_generator = self.memory.list(config, limit=10) + + print("正在获取检查点历史...") + + # 使用列表推导式或 for 循环来收集所有检查点 + history = list(history_generator) + + print(f"找到 {len(history)} 个检查点:") + + for i, checkpoint_tuple in enumerate(history): + # checkpoint_tuple 包含 config, checkpoint, metadata 等属性 + # print(f" - ID: {checkpoint_tuple}") + checkpoint_data = checkpoint_tuple.checkpoint + metadata = checkpoint_tuple.metadata + print(f"检查点 {i+1}:") + print(f" - ID: {checkpoint_data.get('id', 'N/A')}") + print(f" - 状态: {checkpoint_data.get('channel_values', {})}") + print(f" - 元数据: {metadata}") + print("-" * 50) + + except Exception as e: + print(f"获取历史记录时出错: {e}") + + def resume_from_checkpoint(self, thread_id: str, checkpoint_id: str): + """从检查点恢复执行""" + config:RunnableConfig = {"configurable": {"thread_id": thread_id}} + + if checkpoint_id: + config["configurable"]["checkpoint_id"] = checkpoint_id + + try: + # 获取 CheckpointTuple 对象 + checkpoint_tuple = self.memory.get_tuple(config) + + if checkpoint_tuple: + # 直接通过属性访问,而不是解包 + checkpoint_data = checkpoint_tuple.checkpoint + metadata = checkpoint_tuple.metadata + print(f"从检查点恢复:") + print(f" - 检查点 ID: {checkpoint_data.get('id', 'N/A')}") + print(f" - 状态: {checkpoint_data.get('channel_values', {})}") + print(f" - 元数据: {metadata}") + return checkpoint_data.get('channel_values', {}) + else: + print(f"未找到线程 {thread_id} 的检查点") + return None + + except Exception as e: + print(f"恢复检查点时出错: {e}") + return None + + def get_graph_visualization(self) -> str: """获取工作流图的可视化表示 @@ -285,8 +434,31 @@ class ScriptwriterGraph: """ try: if self.graph: - return str(self.graph) + with open('graph_visualization.png', 'wb') as f: + f.write(self.graph.get_graph().draw_mermaid_png()) + print("图片已保存为 graph_visualization.png") return "工作流图未初始化" except Exception as e: logger.error(f"获取图可视化失败: {e}") - return f"获取图可视化失败: {e}" \ No newline at end of file + return f"获取图可视化失败: {e}" + + +if __name__ == "__main__": + import asyncio + + async def main(): + print("测试") + graph = ScriptwriterGraph() + print("创建完成") + # graph.get_graph_visualization() + # print("可视化完成") + # 运行工作流 + session_id = "68c2c2915e5746343301ef71" + result = await graph.run( + session_id, + [HumanMessage(content="你好编剧,我想写小说!")], + session_id + ) + print(f"最终结果: {result}") + + asyncio.run(main()) diff --git a/tools/database/mongodb_memory.py b/tools/database/mongodb_memory.py deleted file mode 100644 index 362f929..0000000 --- a/tools/database/mongodb_memory.py +++ /dev/null @@ -1,30 +0,0 @@ -"""工作流记忆管理模块 - -该模块负责管理智能编剧系统工作流的记忆存储和检索。 -""" - -import sys -import os -from typing import Dict, Any, List, Optional -from datetime import datetime -import json -from database import client # type: ignore -from langgraph.checkpoint.mongodb import MongoDBSaver - -# 添加项目根目录到路径 -sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.dirname(__file__))))) -from agentgraph.utils.logger import get_logger - -logger = get_logger(__name__) - -DB_NAME = "langgraph_memory_db" - -class WorkflowMemory: - """工作流记忆管理类 - - 负责管理工作流执行过程中的状态存储、检索和历史记录。 - """ - - def __init__(self): - """初始化工作流记忆管理器""" - self.memory = MongoDBSaver(client, db_name=DB_NAME)