"""智能编剧系统工作流图定义 该模块定义了智能编剧系统的完整工作流程图,包括各个节点和边的连接关系。 """ from re import T from typing import TypedDict, Annotated, Dict, Any, List, TypedDict, Optional from langgraph.graph.state import RunnableConfig from langgraph.prebuilt.chat_agent_executor import AgentState from agent.scheduler import SchedulerAgent from agent.build_bible import BuildBibleAgent from agent.episode_create import EpisodeCreateAgent from agent.script_analysis import ScriptAnalysisAgent from agent.strategic_planning import StrategicPlanningAgent from langgraph.checkpoint.memory import InMemorySaver from langchain_core.messages import AnyMessage,HumanMessage from langgraph.graph import StateGraph, START, END from utils.logger import get_logger import operator import json import config from tools.database.mongo import client # type: ignore from langgraph.checkpoint.mongodb import MongoDBSaver # 工具方法 from tools.agent.queryDB import QueryOriginalScript,QueryDiagnosisAndAssessment,QueryAdaptationIdeas,QueryScriptBible,QueryEpisodeCount from tools.agent.updateDB import UpdateAdaptationIdeasTool,UpdateScriptBibleTool,UpdateDiagnosisAndAssessmentTool,UpdateOneEpisodeTool logger = get_logger(__name__) def clear_messages(messages): """清除指定会话的所有消息""" # 剔除历史状态消息 exclude_strings = [ "---任务状态消息(开始)---", "---原始剧本(开始)---", "---诊断与资产评估报告(开始)---", "---改编思路(开始)---", ] messages = [ message for message in messages if all(s not in message.content for s in exclude_strings) ] # HumanMessage 超过 10 条,删除最早的 1 条 if len([message for message in messages if message.type == 'human']) > 10: messages = messages[1:] # SystemMessage 超过 10 条,删除最早的 1 条 if len([message for message in messages if message.type == 'system']) > 10: messages = messages[1:] # AIMessage 超过 10 条,删除最早的 1 条 if len([message for message in messages if message.type == 'ai']) > 10: messages = messages[1:] return messages def messages_handler(old_messages: list[AnyMessage], new_messages: list[AnyMessage]): """消息合并方法""" clear_messages(old_messages) new_messages = [message for message in new_messages if message.type != 'ai' or message.content] return old_messages + new_messages def replace_value(old_val, new_val): """值覆盖方法""" return new_val # 状态类型定义 class InputState(TypedDict): """工作流输入状态""" messages: Annotated[list[AnyMessage], messages_handler] from_type: Annotated[str, replace_value] session_id: Annotated[str, replace_value] class OutputState(TypedDict): """工作流输出状态""" session_id: Annotated[str, replace_value] status: Annotated[str, replace_value] error: Annotated[str, replace_value] agent_message: Annotated[str, replace_value] # 智能体回复 class TaskItem(TypedDict): """任务列表中的每个任务""" agent: Annotated[str, replace_value] # 执行这个任务的智能体名称 step: Annotated[str, replace_value] # 阶段名称 [wait_for_input,script_analysis,strategic_planning,build_bible,episode_create_loop, finish] status: Annotated[str, replace_value] # 当前阶段的状态 [waiting,running,failed,completed] reason: Annotated[str, replace_value] # 失败原因 retry_count: Annotated[int, replace_value] # 重试次数 pause: Annotated[bool, replace_value] # 是否暂停 episode_create_num: Annotated[List[int], replace_value] # 指定创作集数 仅在episode_create_loop阶段会返回,内容是数组,数组中每一项是指定创作的剧集编号(从1开始); class WorkflowInfo(TypedDict): """总工作流信息""" is_original_script: Annotated[bool, replace_value] # 是否已提交原始剧本 is_script_analysis: Annotated[bool, replace_value] # 是否已生成 诊断与资产评估报告 is_strategic_planning: Annotated[bool, replace_value] # 是否已生成 改编思路 is_build_bible: Annotated[bool, replace_value] # 是否已生成 剧本圣经 is_episode_create_loop: Annotated[bool, replace_value] # 是否已生成 剧集生成循环 is_all_episode_created: Annotated[bool, replace_value] # 是否已生成 全部剧集 class ScriptwriterState(TypedDict, total=False): """智能编剧工作流整体状态""" # 输入数据 messages: Annotated[list[AnyMessage], messages_handler] session_id: Annotated[str, replace_value] from_type: Annotated[str, replace_value] # 本次请求来着哪里 [user, agent] # 中间状态 workflow_info: Annotated[WorkflowInfo, replace_value] # 顺序执行的任务列表 task_list: Annotated[List[TaskItem], replace_value] # 顺序执行的任务列表 task_index: Annotated[int, replace_value] # 当前执行中的任务索引 # 输出数据 agent_message: Annotated[str, replace_value] # 智能体回复 status: Annotated[str, replace_value] error: Annotated[str, replace_value] AgentNodeMap = { "scheduler": "scheduler_node", "script_analysis": "script_analysis_node", "strategic_planning": "strategic_planning_node", "build_bible": "build_bible_node", "episode_create": "episode_create_node", } class ScriptwriterGraph: """智能编剧工作流图类 管理智能编剧系统的完整工作流程,包括: - 剧本接收 - 诊断分析 - 策略制定 - 剧本圣经构建 - 剧本创作 - 迭代调整 """ def __init__(self): """初始化工作流图""" self.graph = None self.memory = MongoDBSaver(client, db_name=config.MONGO_CHECKPOINT_DB_NAME) self._build_graph() def node_router(self, state: ScriptwriterState) -> str: """节点路由函数""" # return 'end_node' # print(f'node_router state {state}') task_list = state.get("task_list", []) task_index = state.get("task_index", 0) now_task = task_list[task_index] print(f'node_router now_task {now_task}') if not now_task or now_task.get('pause'): next_node = 'pause_node' # 设置为暂停节点 else: next_node = AgentNodeMap.get(now_task.get('agent'), 'pause_node') print(f'node_router next_node {next_node}') return next_node def post_scheduler_hook(self, state: ScriptwriterState)-> ScriptwriterState: """模型调用后的钩子函数""" ai_message_str = state.get("messages", [])[-1].content logger.info(f"!!!!!!!!!调度节点输出!!!!!!!!!!!") logger.info(f"调度节点结果: {ai_message_str}") # logger.info(f"调度节点结果 end") # ai_message = json.loads(ai_message_str) return state def _build_graph(self) -> None: """构建工作流图""" try: # 创建智能体 print("创建智能体") from tools.llm.deepseek_langchain import DeepseekChatModel llm = DeepseekChatModel(api_key='sk-571923fdcb0e493d8def7e2d78c02cb8') # 调度智能体 self.schedulerAgent = SchedulerAgent( # llm=llm, tools=[ QueryOriginalScript, QueryDiagnosisAndAssessment, QueryAdaptationIdeas, QueryScriptBible, QueryEpisodeCount, ], ) self.scriptAnalysisAgent = ScriptAnalysisAgent( tools=[], # SchedulerList=[ # { # "name": "scheduler_node", # "desc": "调度智能体节点", # } # ] ) self.strategicPlanningAgent = StrategicPlanningAgent( tools=[], # SchedulerList=[ # { # "name": "scheduler_node", # "desc": "调度智能体节点", # } # ] ) self.buildBibleAgent = BuildBibleAgent( tools=[], # SchedulerList=[ # { # "name": "scheduler_node", # "desc": "调度智能体节点", # } # ] ) self.episodeCreateAgent = EpisodeCreateAgent( tools=[], # SchedulerList=[ # { # "name": "scheduler_node", # "desc": "调度智能体节点", # } # ] ) # 创建状态图 logger.info("创建状态图") workflow = StateGraph(ScriptwriterState, input_schema=InputState, output_schema=OutputState) # 添加节点 logger.info("添加节点") workflow.add_node("scheduler_node", self.scheduler_node) workflow.add_node("script_analysis_node", self.script_analysis_node) workflow.add_node("strategic_planning_node", self.strategic_planning_node) workflow.add_node("build_bible_node", self.build_bible_node) workflow.add_node("episode_create_node", self.episode_create_node) workflow.add_node("end_node", self.end_node) workflow.add_node("pause_node", self.pause_node) # 添加边 workflow.set_entry_point("scheduler_node") # 所有功能节点执行完成后,都返回给调度节点 workflow.add_edge("script_analysis_node", "scheduler_node") workflow.add_edge("strategic_planning_node", "scheduler_node") workflow.add_edge("build_bible_node", "scheduler_node") workflow.add_edge("episode_create_node", "scheduler_node") # 添加条件边:由调度节点决定下一个路由 workflow.add_conditional_edges( "scheduler_node", self.node_router, { "script_analysis_node": "script_analysis_node", "strategic_planning_node": "strategic_planning_node", "build_bible_node": "build_bible_node", "episode_create_node": "episode_create_node", # 用户确认和暂停逻辑在这里处理,不需要单独的边 "end_node": "end_node", "pause_node": "pause_node", } ) workflow.add_edge("end_node", END) # 编译图 checkpoint = InMemorySaver() self.graph = workflow.compile(checkpointer=checkpoint) # 不缓存记忆 # self.graph = workflow.compile(checkpointer=self.memory) # 使用mongodb缓存记忆 logger.info("工作流图构建完成") except Exception as e: import traceback traceback.print_exc() logger.error(f"构建工作流图失败: {e}") raise # --- 定义图中的节点 --- async def scheduler_node(self, state: ScriptwriterState)-> ScriptwriterState: """调度节点""" try: session_id = state.get("session_id", "") workflow_info = state.get("workflow_info", {}) task_list = state.get("task_list", []) task_index = int(state.get("task_index", 0)) from_type = state.get("from_type", "") messages = state.get("messages", []) # 清除历史状态消息 messages = clear_messages(messages) # 添加参数进提示词 # 工作流信息: # {workflow_info} messages.append(HumanMessage(content=f""" ---任务状态消息(开始)--- # 工作流参数: {{ 'session_id':'{session_id}', 'task_index':'{task_index}', 'from_type':'{from_type}', }} # 任务列表: {task_list} ---任务状态消息(结束)--- """)) system_message_count = 0 human_message_count = 0 ai_message_count = 0 for message in messages: if message.type == 'system': system_message_count += 1 elif message.type == 'human': human_message_count += 1 elif message.type == 'ai': ai_message_count += 1 logger.info(f"调度节点 {session_id} \n 输入消息条数: {len(messages)} \n from_type:{from_type} \n system_message_count:{system_message_count} \n human_message_count:{human_message_count} \n ai_message_count:{ai_message_count}") # 调用智能体 reslut = await self.schedulerAgent.ainvoke({ "messages": messages, "session_id": session_id, }) ai_message_str = reslut['messages'][-1].content logger.info(f"调度节点结果: {ai_message_str}") logger.info(f"调度节点结果 end") ai_message = json.loads(ai_message_str) # logger.info(f"调度节点结果: {ai_message}") return_message:str = ai_message.get('message', '') task_list:list = ai_message.get('task_list', []) task_index:int = int(ai_message.get('task_index', '0')) return { "agent_message": return_message, "task_list": task_list, "task_index": task_index, } except Exception as e: import traceback traceback.print_exc() return { "agent_message": "执行失败", "error": str(e) or '系统错误,工作流已终止', 'status':'failed', } async def script_analysis_node(self, state: ScriptwriterState)-> ScriptwriterState: """第二步:诊断分析与资产评估""" try: print("\n------------ 正在进行诊断分析 ------------") session_id = state.get("session_id", "") from_type = state.get("from_type", "") messages = state.get("messages", []) # 清除历史状态消息 messages = clear_messages(messages) # 添加参数进提示词 from tools.agent.queryDB import QueryOriginalScriptContent original_script_content = QueryOriginalScriptContent(session_id) messages.append(HumanMessage(content=f""" ---原始剧本(开始)--- {original_script_content['content']} ---原始剧本(结束)--- """)) reslut = await self.scriptAnalysisAgent.ainvoke({"messages": messages}) ai_message_str = reslut['messages'][-1].content ai_message = json.loads(ai_message_str) return_message:str = ai_message.get('message', '') task_list:list = state.get('task_list', []) task_index:int = int(state.get('task_index', '0')) task = task_list[task_index] if task: type:str = ai_message.get('type', '沟通') if type == '沟通': task['status'] = 'waiting' task['parse'] = True elif type == '输出': task['status'] = 'completed' task['parse'] = False diagnosis_and_assessment:str = ai_message.get('diagnosis_and_assessment', '') from tools.agent.updateDB import UpdateDiagnosisAndAssessment UpdateDiagnosisAndAssessment(session_id, diagnosis_and_assessment) # print(f"报告已生成: TEST") print("\n------------ 诊断分析结束 ------------") return { "messages": messages, "task_list": task_list, "task_index": task_index, "agent_message": return_message, } except Exception as e: import traceback traceback.print_exc() return { "agent_message": "诊断分析失败", "error": str(e) or '系统错误,工作流已终止', 'status':'failed', } async def strategic_planning_node(self, state: ScriptwriterState)-> ScriptwriterState: """第三步:确立改编目标与战略蓝图""" try: print("\n------------ 正在生成 改编思路 ------------") session_id = state.get("session_id", "") from_type = state.get("from_type", "") messages = state.get("messages", []) # 清除历史状态消息 messages = clear_messages(messages) # 添加参数进提示词 from tools.agent.queryDB import QueryOriginalScriptContent, QueryDiagnosisAndAssessmentContent original_script_content = QueryOriginalScriptContent(session_id) messages.append(HumanMessage(content=f""" ---原始剧本(开始)--- {original_script_content['content']} ---原始剧本(结束)--- """)) diagnosis_and_assessment_content = QueryDiagnosisAndAssessmentContent(session_id) messages.append(HumanMessage(content=f""" ---诊断与资产评估报告(开始)--- {diagnosis_and_assessment_content['content']} ---诊断与资产评估报告(结束)--- """)) reslut = await self.strategicPlanningAgent.ainvoke({"messages": messages}) ai_message_str = reslut['messages'][-1].content ai_message = json.loads(ai_message_str) return_message:str = ai_message.get('message', '') task_list:list = state.get('task_list', []) task_index:int = int(state.get('task_index', '0')) task = task_list[task_index] if task: type:str = ai_message.get('type', '沟通') if type == '沟通': task['status'] = 'waiting' task['parse'] = True elif type == '输出': task['status'] = 'completed' task['parse'] = False adaptation_ideas:str = ai_message.get('adaptation_ideas', '') total_episode_num:int = int(ai_message.get('total_episode_num', 0)) from tools.agent.updateDB import UpdateAdaptationIdeas,SetTotalEpisodeNum UpdateAdaptationIdeas(session_id, adaptation_ideas) SetTotalEpisodeNum(session_id, total_episode_num) # print(f"报告已生成: TEST") print("\n------------ 生成 改编思路 结束 ------------") return { "messages": messages, "task_list": task_list, "task_index": task_index, "agent_message": return_message, } except Exception as e: import traceback traceback.print_exc() return { "agent_message": "生成 改编思路 失败", "error": str(e) or '系统错误,工作流已终止', 'status':'failed', } async def build_bible_node(self, state: ScriptwriterState)-> ScriptwriterState: """第四步:制定剧本圣经""" try: print("\n------------ 正在生成 剧本圣经 ------------") session_id = state.get("session_id", "") from_type = state.get("from_type", "") messages = state.get("messages", []) # 清除历史状态消息 messages = clear_messages(messages) # 添加参数进提示词 from tools.agent.queryDB import QueryOriginalScriptContent, QueryAdaptationIdeasContent original_script_content = QueryOriginalScriptContent(session_id) messages.append(HumanMessage(content=f""" ---原始剧本(开始)--- {original_script_content['content']} ---原始剧本(结束)--- """)) adaptation_ideas_content = QueryAdaptationIdeasContent(session_id) messages.append(HumanMessage(content=f""" ---改编思路(开始)--- {adaptation_ideas_content['content']} ---改编思路(结束)--- """)) reslut = await self.buildBibleAgent.ainvoke({"messages": messages}) ai_message_str = reslut['messages'][-1].content ai_message = json.loads(ai_message_str) return_message:str = ai_message.get('message', '') task_list:list = state.get('task_list', []) task_index:int = int(state.get('task_index', '0')) task = task_list[task_index] if task: type:str = ai_message.get('type', '沟通') if type == '沟通': task['status'] = 'waiting' task['parse'] = True elif type == '输出': task['status'] = 'completed' task['parse'] = False script_bible:dict = ai_message.get('script_bible', {}) core_outline:str = script_bible.get('core_outline', '') character_profile:str = script_bible.get('character_profile', '') core_event_timeline:str = script_bible.get('core_event_timeline', '') character_list:str = script_bible.get('character_list', '') from tools.agent.updateDB import UpdateScriptBible UpdateScriptBible(session_id, core_outline, character_profile, core_event_timeline, character_list) # print(f"报告已生成: TEST") print("\n------------ 生成 剧本圣经 结束 ------------") return { "messages": messages, "task_list": task_list, "task_index": task_index, "agent_message": return_message, } except Exception as e: import traceback traceback.print_exc() return { "agent_message": "生成 剧本圣经 失败", "error": str(e) or '系统错误,工作流已终止', 'status':'failed', } async def episode_create_node(self, state: ScriptwriterState)-> ScriptwriterState: """第五步:循环创作剧本内容""" try: session_id = state.get("session_id", "") from_type = state.get("from_type", "") task_list:list = state.get('task_list', []) task_index:int = int(state.get('task_index', '0')) task:dict = task_list[task_index] or {} episode_create_num:list = task.get('episode_create_num', []) print(f"\n------------ 正在生成 剧集内容 {episode_create_num} ------------") messages = state.get("messages", []) # 清除历史状态消息 messages = clear_messages(messages) # 添加参数进提示词 from tools.agent.queryDB import QueryOriginalScriptContent, QueryAdaptationIdeasContent original_script_content = QueryOriginalScriptContent(session_id) messages.append(HumanMessage(content=f""" ---原始剧本(开始)--- {original_script_content['content']} ---原始剧本(结束)--- """)) adaptation_ideas_content = QueryAdaptationIdeasContent(session_id) messages.append(HumanMessage(content=f""" ---改编思路(开始)--- {adaptation_ideas_content['content']} ---改编思路(结束)--- """)) # 添加参数进提示词 messages.append(HumanMessage(content=f""" ---任务状态消息(开始)--- # 工作流参数: {{ 'episode_create_num':'{episode_create_num}', }} ---任务状态消息(结束)--- """)) reslut = await self.buildBibleAgent.ainvoke({"messages": messages}) ai_message_str = reslut['messages'][-1].content ai_message = json.loads(ai_message_str) return_message:str = ai_message.get('message', '') type:str = ai_message.get('type', '沟通') if type == '沟通': task['status'] = 'waiting' task['parse'] = True elif type == '输出': task['status'] = 'completed' task['parse'] = False episodes:list = ai_message.get('episodes', []) from tools.agent.updateDB import UpdateOneEpisode for episode in episodes: UpdateOneEpisode(session_id, episode.number, episode.content) # print(f"报告已生成: TEST") print("\n------------ 生成 剧本圣经 结束 ------------") return { "messages": messages, "task_list": task_list, "task_index": task_index, "agent_message": return_message, } except Exception as e: import traceback traceback.print_exc() return { "agent_message": "生成 剧本圣经 失败", "error": str(e) or '系统错误,工作流已终止', 'status':'failed', } async def pause_node(self, state: ScriptwriterState)-> ScriptwriterState: """ 暂停节点 处理并完成所有数据状态 """ print(f"langgraph 暂停等待") return { "session_id": state.get("session_id", ""), "status": state.get('status', ''), "error": state.get('error', ''), "agent_message": state.get('agent_message', '') } async def end_node(self, state: ScriptwriterState)-> OutputState: """ 结束节点 处理并完成所有数据状态 """ print(f"langgraph 所有任务完成") return { "session_id": state.get("session_id", ""), "status": state.get('status', ''), "error": state.get('error', ''), "agent_message": state.get('agent_message', ''), } async def run(self, session_id: str, messages: list[AnyMessage], thread_id: str|None = None) -> OutputState: """运行工作流 Args: session_id: 会话ID messages: 输入数据 thread_id: 线程ID Returns: 工作流执行结果 """ try: logger.info("开始运行智能编剧工作流") output_result: OutputState = { 'session_id': session_id, 'status': '', 'error': '', 'agent_message': '', } # 配置包含线程 ID config:RunnableConfig = {"configurable": {"thread_id": thread_id}} # 初始化状态 initial_state: InputState = { 'messages': messages, 'session_id': session_id, 'from_type': 'user', } # 运行工作流 if self.graph is None: raise RuntimeError("工作流图未正确初始化") # 运行工作流 直接返回最终结果 result = await self.graph.ainvoke( initial_state, config, # stream_mode='values' ) # logger.info(f"工作流执行结果: {result}") if not result: raise ValueError("工作流执行结果为空") # 构造输出状态 output_result = { 'session_id': result.get('session_id', ''), 'status': result.get('status', ''), 'error': result.get('error', ''), 'agent_message': result.get('agent_message', ''), } # 流式处理 # st = self.graph.stream( # initial_state, # config, # stream_mode='values' # ) # from utils.stream import debug_print_stream # debug_print_stream(st) logger.info("智能编剧工作流运行完成") return output_result except Exception as e: logger.error(f"运行工作流失败: {e}") import traceback traceback.print_exc() raise async def get_checkpoint_history(self, thread_id: str): """获取检查点历史""" config:RunnableConfig = {"configurable": {"thread_id": thread_id}} try: history_generator = self.memory.list(config, limit=10) print("正在获取检查点历史...") # 使用列表推导式或 for 循环来收集所有检查点 history = list(history_generator) print(f"找到 {len(history)} 个检查点:") for i, checkpoint_tuple in enumerate(history): # checkpoint_tuple 包含 config, checkpoint, metadata 等属性 # print(f" - ID: {checkpoint_tuple}") checkpoint_data = checkpoint_tuple.checkpoint metadata = checkpoint_tuple.metadata print(f"检查点 {i+1}:") print(f" - ID: {checkpoint_data.get('id', 'N/A')}") print(f" - 状态: {checkpoint_data.get('channel_values', {})}") print(f" - 元数据: {metadata}") print("-" * 50) except Exception as e: print(f"获取历史记录时出错: {e}") def resume_from_checkpoint(self, thread_id: str, checkpoint_id: str): """从检查点恢复执行""" config:RunnableConfig = {"configurable": {"thread_id": thread_id}} if checkpoint_id: config["configurable"]["checkpoint_id"] = checkpoint_id try: # 获取 CheckpointTuple 对象 checkpoint_tuple = self.memory.get_tuple(config) if checkpoint_tuple: # 直接通过属性访问,而不是解包 checkpoint_data = checkpoint_tuple.checkpoint metadata = checkpoint_tuple.metadata print(f"从检查点恢复:") print(f" - 检查点 ID: {checkpoint_data.get('id', 'N/A')}") print(f" - 状态: {checkpoint_data.get('channel_values', {})}") print(f" - 元数据: {metadata}") return checkpoint_data.get('channel_values', {}) else: print(f"未找到线程 {thread_id} 的检查点") return None except Exception as e: print(f"恢复检查点时出错: {e}") return None def get_graph_visualization(self) -> str: """获取工作流图的可视化表示 Returns: 图的文本表示 """ try: if self.graph: with open('graph_visualization.png', 'wb') as f: f.write(self.graph.get_graph().draw_mermaid_png()) print("图片已保存为 graph_visualization.png") return "工作流图未初始化" except Exception as e: logger.error(f"获取图可视化失败: {e}") return f"获取图可视化失败: {e}" if __name__ == "__main__": import asyncio async def main(): print("测试") graph = ScriptwriterGraph() print("创建完成") # graph.get_graph_visualization() # print("可视化完成") # 运行工作流 session_id = "68c2c2915e5746343301ef71" result = await graph.run( session_id, [HumanMessage(content="老师,我写好剧本了,您看看!帮我分析分析把!")], session_id ) print(f"最终结果: {result}") asyncio.run(main())