agent-writer/graph/test_agent_graph_1.py

772 lines
32 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""智能编剧系统工作流图定义
该模块定义了智能编剧系统的完整工作流程图,包括各个节点和边的连接关系。
"""
from re import T
from typing import TypedDict, Annotated, Dict, Any, List, TypedDict, Optional
from langgraph.graph.state import RunnableConfig
from langgraph.prebuilt.chat_agent_executor import AgentState
from agent.scheduler import SchedulerAgent
from agent.build_bible import BuildBibleAgent
from agent.episode_create import EpisodeCreateAgent
from agent.script_analysis import ScriptAnalysisAgent
from agent.strategic_planning import StrategicPlanningAgent
from langgraph.checkpoint.memory import InMemorySaver
from langchain_core.messages import AnyMessage,HumanMessage
from langgraph.graph import StateGraph, START, END
from utils.logger import get_logger
import operator
import json
import config
from tools.database.mongo import client # type: ignore
from langgraph.checkpoint.mongodb import MongoDBSaver
# 工具方法
from tools.agent.queryDB import QueryOriginalScript,QueryDiagnosisAndAssessment,QueryAdaptationIdeas,QueryScriptBible,QueryEpisodeCount
from tools.agent.updateDB import UpdateAdaptationIdeasTool,UpdateScriptBibleTool,UpdateDiagnosisAndAssessmentTool,UpdateOneEpisodeTool
logger = get_logger(__name__)
def clear_messages(messages):
    """Return a cleaned copy of a conversation history.

    Two passes are applied:

    1. Every message whose content contains one of the transient context
       markers (task status, original script, diagnosis report, adaptation
       ideas) is dropped.
    2. For each of the ``human``, ``system`` and ``ai`` message types, when
       more than 10 messages of that type remain, the earliest message *of
       that type* is removed (one message per type per call).

    Args:
        messages: Sequence of message objects exposing ``type`` and
            ``content``. ``None`` is treated as an empty history.

    Returns:
        A new list; the input sequence is never mutated.
    """
    # Markers that open the context blocks injected by the workflow nodes.
    exclude_strings = (
        "---任务状态消息(开始)---",
        "---原始剧本(开始)---",
        "---诊断与资产评估报告(开始)---",
        "---改编思路(开始)---",
    )
    max_per_type = 10

    def _is_transient(message):
        # Empty / None content can never contain a marker.
        content = getattr(message, "content", None) or ""
        return any(marker in content for marker in exclude_strings)

    kept = [message for message in (messages or []) if not _is_transient(message)]

    for message_type in ("human", "system", "ai"):
        positions = [i for i, message in enumerate(kept) if message.type == message_type]
        if len(positions) > max_per_type:
            # Remove the oldest message of the overflowing type, not simply
            # the first message of the list (which may be of another type).
            del kept[positions[0]]
    return kept
def messages_handler(old_messages: list[AnyMessage], new_messages: list[AnyMessage]):
    """Reducer that merges newly produced messages into the stored history.

    The stored history is first passed through ``clear_messages``; the
    cleaned result is what gets kept (previously the return value was thrown
    away, so the cleanup had no effect). AI messages with empty content are
    dropped from the incoming batch.

    Args:
        old_messages: Messages currently held in the state.
        new_messages: Messages returned by the node that just ran.

    Returns:
        A new list: cleaned history followed by the accepted new messages.
    """
    cleaned_history = clear_messages(old_messages)
    accepted = [
        message for message in new_messages
        if message.type != 'ai' or message.content
    ]
    return cleaned_history + accepted
def replace_value(old_val, new_val):
    """Reducer that overwrites: the incoming value always wins.

    Args:
        old_val: Value currently stored (ignored).
        new_val: Value produced by the latest update.

    Returns:
        ``new_val`` unchanged.
    """
    latest = new_val
    return latest
# 状态类型定义
# State type definitions
class InputState(TypedDict):
    """Input schema of the workflow."""
    messages: Annotated[list[AnyMessage], messages_handler]  # conversation messages, merged by messages_handler
    from_type: Annotated[str, replace_value]  # origin of the request; overwritten on every update
    session_id: Annotated[str, replace_value]  # session identifier; overwritten on every update
class OutputState(TypedDict):
    """Output schema of the workflow."""
    session_id: Annotated[str, replace_value]  # session identifier
    status: Annotated[str, replace_value]  # status string reported by the nodes (e.g. 'failed')
    error: Annotated[str, replace_value]  # error description, empty when none was recorded
    agent_message: Annotated[str, replace_value]  # reply produced by the agent
class TaskItem(TypedDict):
    """A single entry of the task list."""
    agent: Annotated[str, replace_value]  # name of the agent that executes this task
    step: Annotated[str, replace_value]  # stage name, one of [wait_for_input, script_analysis, strategic_planning, build_bible, episode_create_loop, finish]
    status: Annotated[str, replace_value]  # status of the current stage, one of [waiting, running, failed, completed]
    reason: Annotated[str, replace_value]  # failure reason
    retry_count: Annotated[int, replace_value]  # number of retries
    pause: Annotated[bool, replace_value]  # whether the workflow should pause on this task
    episode_create_num: Annotated[List[int], replace_value]  # episodes to create; only returned in the episode_create_loop stage; each item is an episode number starting from 1
class WorkflowInfo(TypedDict):
    """Overall workflow information."""
    is_original_script: Annotated[bool, replace_value]  # whether the original script has been submitted
    is_script_analysis: Annotated[bool, replace_value]  # whether the diagnosis and asset assessment report has been generated
    is_strategic_planning: Annotated[bool, replace_value]  # whether the adaptation ideas have been generated
    is_build_bible: Annotated[bool, replace_value]  # whether the script bible has been generated
    is_episode_create_loop: Annotated[bool, replace_value]  # whether the episode creation loop has produced output
    is_all_episode_created: Annotated[bool, replace_value]  # whether all episodes have been generated
class ScriptwriterState(TypedDict, total=False):
    """Overall state of the scriptwriting workflow (all keys optional)."""
    # Input data
    messages: Annotated[list[AnyMessage], messages_handler]  # conversation messages, merged by messages_handler
    session_id: Annotated[str, replace_value]  # session identifier
    from_type: Annotated[str, replace_value]  # where this request comes from [user, agent]
    # Intermediate state
    workflow_info: Annotated[WorkflowInfo, replace_value]  # overall workflow information (see WorkflowInfo)
    task_list: Annotated[List[TaskItem], replace_value]  # task list, executed in order
    task_index: Annotated[int, replace_value]  # index of the task currently being executed
    # Output data
    agent_message: Annotated[str, replace_value]  # reply produced by the agent
    status: Annotated[str, replace_value]  # status string reported by the nodes
    error: Annotated[str, replace_value]  # error description
# Maps the agent name carried by a task to the graph node that runs it.
AgentNodeMap = dict((
    ("scheduler", "scheduler_node"),
    ("script_analysis", "script_analysis_node"),
    ("strategic_planning", "strategic_planning_node"),
    ("build_bible", "build_bible_node"),
    ("episode_create", "episode_create_node"),
))
class ScriptwriterGraph:
    """Workflow graph of the intelligent scriptwriting system.

    Manages the complete workflow of the system, including:
    - receiving the script
    - diagnostic analysis
    - strategy formulation
    - building the script bible
    - script creation
    - iterative adjustment
    """
def __init__(self) -> None:
    """Initialise the workflow graph.

    Creates a MongoDB checkpoint saver in ``self.memory`` and then builds
    the graph through ``_build_graph``, which re-raises on failure.
    """
    # Assigned before _build_graph runs; _build_graph replaces it with the
    # compiled graph.
    self.graph = None
    self.memory = MongoDBSaver(client, db_name=config.MONGO_CHECKPOINT_DB_NAME)
    self._build_graph()
def node_router(self, state: ScriptwriterState) -> str:
    """Pick the node that runs after the scheduler.

    The current task is ``task_list[task_index]``. The router returns
    ``'pause_node'`` when there is no current task (missing or empty task
    list, index out of range), when the task is empty, when it is flagged
    ``pause``, or when its agent name is not in ``AgentNodeMap``; otherwise
    it returns the node mapped to the task's agent.

    Args:
        state: Current workflow state.

    Returns:
        Name of the next node.
    """
    task_list = state.get("task_list") or []
    task_index = int(state.get("task_index") or 0)

    # Without this guard an empty task list (e.g. after the scheduler failed
    # before producing one) raised IndexError and aborted the whole run.
    if not 0 <= task_index < len(task_list):
        print(f'node_router next_node pause_node')
        return 'pause_node'

    now_task = task_list[task_index]
    print(f'node_router now_task {now_task}')
    if not now_task or now_task.get('pause'):
        next_node = 'pause_node'  # route to the pause node
    else:
        next_node = AgentNodeMap.get(now_task.get('agent'), 'pause_node')
    print(f'node_router next_node {next_node}')
    return next_node
def post_scheduler_hook(self, state: ScriptwriterState)-> ScriptwriterState:
    """Hook run after the model call: log the last message and pass the state on.

    Args:
        state: Current workflow state.

    Returns:
        The same ``state`` object, unmodified.
    """
    messages = state.get("messages") or []
    if not messages:
        # Nothing to log; indexing [-1] on an empty history used to raise.
        return state
    ai_message_str = messages[-1].content
    logger.info("!!!!!!!!!调度节点输出!!!!!!!!!!!")
    logger.info(f"调度节点结果: {ai_message_str}")
    return state
def _build_graph(self) -> None:
    """Create the agents, wire the state graph and compile it into ``self.graph``.

    Topology: ``scheduler_node`` is the entry point; every worker node
    returns to the scheduler; after the scheduler, ``node_router`` selects a
    worker node, ``end_node`` or ``pause_node``; both ``end_node`` and
    ``pause_node`` finish the run.

    Raises:
        Exception: Any error raised while building is logged and re-raised.
    """
    try:
        # Create the agents. No API key is embedded here: the previous
        # revision built an unused LLM client with a hard-coded secret, which
        # has been removed. Supply credentials through configuration or the
        # environment if a custom ``llm`` is ever passed to an agent.
        logger.info("创建智能体")
        # Scheduler agent: the only one that receives query tools.
        self.schedulerAgent = SchedulerAgent(
            tools=[
                QueryOriginalScript,
                QueryDiagnosisAndAssessment,
                QueryAdaptationIdeas,
                QueryScriptBible,
                QueryEpisodeCount,
            ],
        )
        self.scriptAnalysisAgent = ScriptAnalysisAgent(tools=[])
        self.strategicPlanningAgent = StrategicPlanningAgent(tools=[])
        self.buildBibleAgent = BuildBibleAgent(tools=[])
        self.episodeCreateAgent = EpisodeCreateAgent(tools=[])

        # Create the state graph
        logger.info("创建状态图")
        workflow = StateGraph(ScriptwriterState, input_schema=InputState, output_schema=OutputState)

        # Register the nodes
        logger.info("添加节点")
        workflow.add_node("scheduler_node", self.scheduler_node)
        workflow.add_node("script_analysis_node", self.script_analysis_node)
        workflow.add_node("strategic_planning_node", self.strategic_planning_node)
        workflow.add_node("build_bible_node", self.build_bible_node)
        workflow.add_node("episode_create_node", self.episode_create_node)
        workflow.add_node("end_node", self.end_node)
        workflow.add_node("pause_node", self.pause_node)

        # Edges
        workflow.set_entry_point("scheduler_node")
        # Every worker node hands control back to the scheduler.
        for worker_node in (
            "script_analysis_node",
            "strategic_planning_node",
            "build_bible_node",
            "episode_create_node",
        ):
            workflow.add_edge(worker_node, "scheduler_node")

        # Conditional edge: the scheduler's output decides the next node.
        workflow.add_conditional_edges(
            "scheduler_node",
            self.node_router,
            {
                "script_analysis_node": "script_analysis_node",
                "strategic_planning_node": "strategic_planning_node",
                "build_bible_node": "build_bible_node",
                "episode_create_node": "episode_create_node",
                # User confirmation and pausing are handled by these targets.
                "end_node": "end_node",
                "pause_node": "pause_node",
            }
        )
        workflow.add_edge("end_node", END)
        # pause_node is terminal as well; state the edge explicitly so it is
        # wired the same way as end_node.
        workflow.add_edge("pause_node", END)

        # Compile the graph.
        # NOTE(review): the graph is compiled with a fresh InMemorySaver, so
        # checkpoints are not persisted, while get_checkpoint_history and
        # resume_from_checkpoint read from self.memory (MongoDB). Compile
        # with ``checkpointer=self.memory`` to persist to MongoDB.
        checkpoint = InMemorySaver()
        self.graph = workflow.compile(checkpointer=checkpoint)
        logger.info("工作流图构建完成")
    except Exception as e:
        import traceback
        traceback.print_exc()
        logger.error(f"构建工作流图失败: {e}")
        raise
# --- 定义图中的节点 ---
# --- Graph node definitions ---
async def scheduler_node(self, state: ScriptwriterState)-> ScriptwriterState:
    """Scheduler node: ask the scheduler agent which task to run next.

    Builds the prompt from the cleaned history plus a task-status block,
    invokes the scheduler agent and parses its last message as JSON.

    Returns:
        On success ``agent_message``, ``task_list`` and ``task_index`` taken
        from the agent's JSON reply; on any exception a failure update with
        ``agent_message``, ``error`` and ``status='failed'``.
    """
    try:
        session_id = state.get("session_id", "")
        current_tasks = state.get("task_list", [])
        current_index = int(state.get("task_index", 0))
        from_type = state.get("from_type", "")

        # Strip stale context blocks from the history before prompting.
        prompt_messages = clear_messages(state.get("messages", []))

        # Inject the workflow parameters into the prompt.
        status_block = f"""
---任务状态消息(开始)---
# 工作流参数:
{{
'session_id':'{session_id}',
'task_index':'{current_index}',
'from_type':'{from_type}',
}}
# 任务列表:
{current_tasks}
---任务状态消息(结束)---
"""
        prompt_messages.append(HumanMessage(content=status_block))

        # Tally the messages per type for the log line.
        counts = {'system': 0, 'human': 0, 'ai': 0}
        for item in prompt_messages:
            if item.type in counts:
                counts[item.type] += 1
        logger.info(f"调度节点 {session_id} \n 输入消息条数: {len(prompt_messages)} \n from_type:{from_type} \n system_message_count:{counts['system']} \n human_message_count:{counts['human']} \n ai_message_count:{counts['ai']}")

        # Invoke the agent.
        result = await self.schedulerAgent.ainvoke({
            "messages": prompt_messages,
            "session_id": session_id,
        })
        raw_reply = result['messages'][-1].content
        logger.info(f"调度节点结果: {raw_reply}")
        logger.info("调度节点结果 end")

        reply = json.loads(raw_reply)
        return {
            "agent_message": reply.get('message', ''),
            "task_list": reply.get('task_list', []),
            "task_index": int(reply.get('task_index', '0')),
        }
    except Exception as exc:
        import traceback
        traceback.print_exc()
        failure = {
            "agent_message": "执行失败",
            "error": str(exc) or '系统错误,工作流已终止',
            'status': 'failed',
        }
        return failure
async def script_analysis_node(self, state: ScriptwriterState)-> ScriptwriterState:
    """Step 2: diagnostic analysis and asset assessment.

    Sends the cleaned history plus the original script to the script
    analysis agent and parses its last message as JSON. Depending on the
    reply ``type`` the current task is either put on hold ('沟通': status
    ``waiting``, ``pause`` set) or completed ('输出': status ``completed``,
    ``pause`` cleared, report stored through ``UpdateDiagnosisAndAssessment``).

    Returns:
        On success the context messages added by this node, the task list,
        the task index and the agent's message; on any exception a failure
        update with ``status='failed'``.
    """
    try:
        print("\n------------ 正在进行诊断分析 ------------")
        session_id = state.get("session_id", "")
        # Strip stale context blocks from the history.
        history = clear_messages(state.get("messages", []))

        # Context injected into the prompt for this call.
        from tools.agent.queryDB import QueryOriginalScriptContent
        original_script_content = QueryOriginalScriptContent(session_id)
        context_messages = [HumanMessage(content=f"""
---原始剧本(开始)---
{original_script_content['content']}
---原始剧本(结束)---
""")]

        result = await self.scriptAnalysisAgent.ainvoke({"messages": history + context_messages})
        ai_message = json.loads(result['messages'][-1].content)
        return_message: str = ai_message.get('message', '')

        task_list: list = state.get('task_list', [])
        task_index: int = int(state.get('task_index', '0'))
        task = task_list[task_index]
        if task:
            reply_type: str = ai_message.get('type', '沟通')
            if reply_type == '沟通':
                task['status'] = 'waiting'
                # The router reads 'pause'; the former key 'parse' was a typo
                # that left the flag without effect.
                task['pause'] = True
            elif reply_type == '输出':
                task['status'] = 'completed'
                task['pause'] = False
                diagnosis_and_assessment: str = ai_message.get('diagnosis_and_assessment', '')
                from tools.agent.updateDB import UpdateDiagnosisAndAssessment
                UpdateDiagnosisAndAssessment(session_id, diagnosis_and_assessment)
        print("\n------------ 诊断分析结束 ------------")
        return {
            # The `messages` reducer appends to the stored history, so only
            # the messages added here are returned; returning the whole list
            # duplicated every earlier message.
            "messages": context_messages,
            "task_list": task_list,
            "task_index": task_index,
            "agent_message": return_message,
        }
    except Exception as e:
        import traceback
        traceback.print_exc()
        return {
            "agent_message": "诊断分析失败",
            "error": str(e) or '系统错误,工作流已终止',
            'status': 'failed',
        }
async def strategic_planning_node(self, state: ScriptwriterState)-> ScriptwriterState:
    """Step 3: establish the adaptation goals and strategic blueprint.

    Sends the cleaned history plus the original script and the diagnosis
    report to the strategic planning agent and parses its last message as
    JSON. Depending on the reply ``type`` the current task is either put on
    hold ('沟通') or completed ('输出'); on completion the adaptation ideas
    and the total episode count are stored.

    Returns:
        On success the context messages added by this node, the task list,
        the task index and the agent's message; on any exception a failure
        update with ``status='failed'``.
    """
    try:
        print("\n------------ 正在生成 改编思路 ------------")
        session_id = state.get("session_id", "")
        # Strip stale context blocks from the history.
        history = clear_messages(state.get("messages", []))

        # Context injected into the prompt for this call.
        from tools.agent.queryDB import QueryOriginalScriptContent, QueryDiagnosisAndAssessmentContent
        original_script_content = QueryOriginalScriptContent(session_id)
        context_messages = [HumanMessage(content=f"""
---原始剧本(开始)---
{original_script_content['content']}
---原始剧本(结束)---
""")]
        diagnosis_and_assessment_content = QueryDiagnosisAndAssessmentContent(session_id)
        context_messages.append(HumanMessage(content=f"""
---诊断与资产评估报告(开始)---
{diagnosis_and_assessment_content['content']}
---诊断与资产评估报告(结束)---
"""))

        result = await self.strategicPlanningAgent.ainvoke({"messages": history + context_messages})
        ai_message = json.loads(result['messages'][-1].content)
        return_message: str = ai_message.get('message', '')

        task_list: list = state.get('task_list', [])
        task_index: int = int(state.get('task_index', '0'))
        task = task_list[task_index]
        if task:
            reply_type: str = ai_message.get('type', '沟通')
            if reply_type == '沟通':
                task['status'] = 'waiting'
                # The router reads 'pause'; the former key 'parse' was a typo
                # that left the flag without effect.
                task['pause'] = True
            elif reply_type == '输出':
                task['status'] = 'completed'
                task['pause'] = False
                adaptation_ideas: str = ai_message.get('adaptation_ideas', '')
                total_episode_num: int = int(ai_message.get('total_episode_num', 0))
                from tools.agent.updateDB import UpdateAdaptationIdeas, SetTotalEpisodeNum
                UpdateAdaptationIdeas(session_id, adaptation_ideas)
                SetTotalEpisodeNum(session_id, total_episode_num)
        print("\n------------ 生成 改编思路 结束 ------------")
        return {
            # The `messages` reducer appends to the stored history, so only
            # the messages added here are returned; returning the whole list
            # duplicated every earlier message.
            "messages": context_messages,
            "task_list": task_list,
            "task_index": task_index,
            "agent_message": return_message,
        }
    except Exception as e:
        import traceback
        traceback.print_exc()
        return {
            "agent_message": "生成 改编思路 失败",
            "error": str(e) or '系统错误,工作流已终止',
            'status': 'failed',
        }
async def build_bible_node(self, state: ScriptwriterState)-> ScriptwriterState:
    """Step 4: draw up the script bible.

    Sends the cleaned history plus the original script and the adaptation
    ideas to the build-bible agent and parses its last message as JSON.
    Depending on the reply ``type`` the current task is either put on hold
    ('沟通') or completed ('输出'); on completion the four parts of
    ``script_bible`` are stored through ``UpdateScriptBible``.

    Returns:
        On success the context messages added by this node, the task list,
        the task index and the agent's message; on any exception a failure
        update with ``status='failed'``.
    """
    try:
        print("\n------------ 正在生成 剧本圣经 ------------")
        session_id = state.get("session_id", "")
        # Strip stale context blocks from the history.
        history = clear_messages(state.get("messages", []))

        # Context injected into the prompt for this call.
        from tools.agent.queryDB import QueryOriginalScriptContent, QueryAdaptationIdeasContent
        original_script_content = QueryOriginalScriptContent(session_id)
        context_messages = [HumanMessage(content=f"""
---原始剧本(开始)---
{original_script_content['content']}
---原始剧本(结束)---
""")]
        adaptation_ideas_content = QueryAdaptationIdeasContent(session_id)
        context_messages.append(HumanMessage(content=f"""
---改编思路(开始)---
{adaptation_ideas_content['content']}
---改编思路(结束)---
"""))

        result = await self.buildBibleAgent.ainvoke({"messages": history + context_messages})
        ai_message = json.loads(result['messages'][-1].content)
        return_message: str = ai_message.get('message', '')

        task_list: list = state.get('task_list', [])
        task_index: int = int(state.get('task_index', '0'))
        task = task_list[task_index]
        if task:
            reply_type: str = ai_message.get('type', '沟通')
            if reply_type == '沟通':
                task['status'] = 'waiting'
                # The router reads 'pause'; the former key 'parse' was a typo
                # that left the flag without effect.
                task['pause'] = True
            elif reply_type == '输出':
                task['status'] = 'completed'
                task['pause'] = False
                script_bible: dict = ai_message.get('script_bible', {})
                core_outline: str = script_bible.get('core_outline', '')
                character_profile: str = script_bible.get('character_profile', '')
                core_event_timeline: str = script_bible.get('core_event_timeline', '')
                character_list: str = script_bible.get('character_list', '')
                from tools.agent.updateDB import UpdateScriptBible
                UpdateScriptBible(session_id, core_outline, character_profile, core_event_timeline, character_list)
        print("\n------------ 生成 剧本圣经 结束 ------------")
        return {
            # The `messages` reducer appends to the stored history, so only
            # the messages added here are returned; returning the whole list
            # duplicated every earlier message.
            "messages": context_messages,
            "task_list": task_list,
            "task_index": task_index,
            "agent_message": return_message,
        }
    except Exception as e:
        import traceback
        traceback.print_exc()
        return {
            "agent_message": "生成 剧本圣经 失败",
            "error": str(e) or '系统错误,工作流已终止',
            'status': 'failed',
        }
async def episode_create_node(self, state: ScriptwriterState)-> ScriptwriterState:
    """Step 5: create the episode scripts requested by the current task.

    Sends the cleaned history plus the original script, the adaptation
    ideas and the requested episode numbers to the episode-creation agent
    and parses its last message as JSON. Depending on the reply ``type`` the
    current task is either put on hold ('沟通') or completed ('输出'); on
    completion every returned episode is stored through ``UpdateOneEpisode``.

    Returns:
        On success the context messages added by this node, the task list,
        the task index and the agent's message; on any exception a failure
        update with ``status='failed'``.
    """
    try:
        session_id = state.get("session_id", "")
        task_list: list = state.get('task_list', [])
        task_index: int = int(state.get('task_index', '0'))
        task: dict = task_list[task_index] or {}
        episode_create_num: list = task.get('episode_create_num', [])
        print(f"\n------------ 正在生成 剧集内容 {episode_create_num} ------------")
        # Strip stale context blocks from the history.
        history = clear_messages(state.get("messages", []))

        # Context injected into the prompt for this call.
        from tools.agent.queryDB import QueryOriginalScriptContent, QueryAdaptationIdeasContent
        original_script_content = QueryOriginalScriptContent(session_id)
        context_messages = [HumanMessage(content=f"""
---原始剧本(开始)---
{original_script_content['content']}
---原始剧本(结束)---
""")]
        adaptation_ideas_content = QueryAdaptationIdeasContent(session_id)
        context_messages.append(HumanMessage(content=f"""
---改编思路(开始)---
{adaptation_ideas_content['content']}
---改编思路(结束)---
"""))
        # Workflow parameters for this call.
        context_messages.append(HumanMessage(content=f"""
---任务状态消息(开始)---
# 工作流参数:
{{
'episode_create_num':'{episode_create_num}',
}}
---任务状态消息(结束)---
"""))

        # Use the episode-creation agent; the build-bible agent was invoked
        # here before (copy-paste error).
        result = await self.episodeCreateAgent.ainvoke({"messages": history + context_messages})
        ai_message = json.loads(result['messages'][-1].content)
        return_message: str = ai_message.get('message', '')

        reply_type: str = ai_message.get('type', '沟通')
        if reply_type == '沟通':
            task['status'] = 'waiting'
            # The router reads 'pause'; the former key 'parse' was a typo
            # that left the flag without effect.
            task['pause'] = True
        elif reply_type == '输出':
            task['status'] = 'completed'
            task['pause'] = False
            episodes: list = ai_message.get('episodes', [])
            from tools.agent.updateDB import UpdateOneEpisode
            for episode in episodes:
                # Episodes come out of json.loads, i.e. they are dicts;
                # attribute access (episode.number) always raised.
                UpdateOneEpisode(session_id, episode['number'], episode['content'])
        print("\n------------ 生成 剧集内容 结束 ------------")
        return {
            # The `messages` reducer appends to the stored history, so only
            # the messages added here are returned; returning the whole list
            # duplicated every earlier message.
            "messages": context_messages,
            "task_list": task_list,
            "task_index": task_index,
            "agent_message": return_message,
        }
    except Exception as e:
        import traceback
        traceback.print_exc()
        return {
            "agent_message": "生成 剧集内容 失败",
            "error": str(e) or '系统错误,工作流已终止',
            'status': 'failed',
        }
async def pause_node(self, state: ScriptwriterState)-> ScriptwriterState:
    """Pause node: echo the output fields of the current state.

    Returns:
        A dict with ``session_id``, ``status``, ``error`` and
        ``agent_message``; each falls back to an empty string when absent.
    """
    print("langgraph 暂停等待")
    output_keys = ("session_id", "status", "error", "agent_message")
    return {key: state.get(key, '') for key in output_keys}
async def end_node(self, state: ScriptwriterState)-> OutputState:
    """End node: collect the output fields once all tasks are done.

    Returns:
        A dict with ``session_id``, ``status``, ``error`` and
        ``agent_message``; each falls back to an empty string when absent.
    """
    print("langgraph 所有任务完成")
    final_output = {}
    final_output["session_id"] = state.get("session_id", "")
    final_output["status"] = state.get('status', '')
    final_output["error"] = state.get('error', '')
    final_output["agent_message"] = state.get('agent_message', '')
    return final_output
async def run(self, session_id: str, messages: list[AnyMessage], thread_id: str|None = None) -> OutputState:
    """Run the workflow once and return its output fields.

    Args:
        session_id: Session identifier.
        messages: Input messages for this turn.
        thread_id: Checkpoint thread identifier. When omitted, the session
            id is used so that unrelated sessions never share one
            checkpoint thread keyed by ``None``.

    Returns:
        The workflow result reduced to ``session_id``, ``status``, ``error``
        and ``agent_message``.

    Raises:
        RuntimeError: If the graph has not been built.
        ValueError: If the workflow returns an empty result.
        Exception: Any other error is logged and re-raised.
    """
    try:
        logger.info("开始运行智能编剧工作流")
        if self.graph is None:
            raise RuntimeError("工作流图未正确初始化")

        # Named run_config so the imported `config` module is not shadowed.
        run_config: RunnableConfig = {
            "configurable": {"thread_id": thread_id or session_id},
        }
        # Initial state
        initial_state: InputState = {
            'messages': messages,
            'session_id': session_id,
            'from_type': 'user',
        }

        # Run the workflow and take the final result directly.
        result = await self.graph.ainvoke(initial_state, run_config)
        if not result:
            raise ValueError("工作流执行结果为空")

        # Build the output state.
        output_result: OutputState = {
            'session_id': result.get('session_id', ''),
            'status': result.get('status', ''),
            'error': result.get('error', ''),
            'agent_message': result.get('agent_message', ''),
        }
        logger.info("智能编剧工作流运行完成")
        return output_result
    except Exception as e:
        logger.error(f"运行工作流失败: {e}")
        import traceback
        traceback.print_exc()
        raise
async def get_checkpoint_history(self, thread_id: str):
    """Print up to 10 checkpoints stored in ``self.memory`` for a thread.

    Args:
        thread_id: Thread whose checkpoints are listed.

    Returns:
        None. The history is only printed; errors are printed as well and
        never propagated.
    """
    lookup: RunnableConfig = {"configurable": {"thread_id": thread_id}}
    try:
        pending = self.memory.list(lookup, limit=10)
        print("正在获取检查点历史...")
        # Materialise the iterator so the entries can be counted.
        entries = list(pending)
        print(f"找到 {len(entries)} 个检查点:")
        for position, entry in enumerate(entries, start=1):
            # Each entry exposes `checkpoint` and `metadata` attributes.
            snapshot = entry.checkpoint
            meta = entry.metadata
            print(f"检查点 {position}:")
            print(f" - ID: {snapshot.get('id', 'N/A')}")
            print(f" - 状态: {snapshot.get('channel_values', {})}")
            print(f" - 元数据: {meta}")
            print("-" * 50)
    except Exception as exc:
        print(f"获取历史记录时出错: {exc}")
def resume_from_checkpoint(self, thread_id: str, checkpoint_id: str):
    """Look up a checkpoint in ``self.memory`` and return its channel values.

    Args:
        thread_id: Thread the checkpoint belongs to.
        checkpoint_id: Specific checkpoint; when falsy only the thread id
            is put into the lookup config.

    Returns:
        The checkpoint's ``channel_values`` (``{}`` when that key is
        missing), or None when no checkpoint is found or an error occurs.
    """
    configurable = {"thread_id": thread_id}
    if checkpoint_id:
        configurable["checkpoint_id"] = checkpoint_id
    lookup: RunnableConfig = {"configurable": configurable}
    try:
        # get_tuple returns an object with attributes, not a plain tuple.
        found = self.memory.get_tuple(lookup)
        if not found:
            print(f"未找到线程 {thread_id} 的检查点")
            return None
        snapshot = found.checkpoint
        meta = found.metadata
        print("从检查点恢复:")
        print(f" - 检查点 ID: {snapshot.get('id', 'N/A')}")
        print(f" - 状态: {snapshot.get('channel_values', {})}")
        print(f" - 元数据: {meta}")
        return snapshot.get('channel_values', {})
    except Exception as exc:
        print(f"恢复检查点时出错: {exc}")
        return None
def get_graph_visualization(self) -> str:
    """Render the workflow graph to ``graph_visualization.png``.

    Returns:
        A status message: the "saved" message on success, the "not
        initialised" message when no graph has been built, or the error
        description when rendering or writing fails.
    """
    try:
        if not self.graph:
            return "工作流图未初始化"
        # Render before opening the file so that a rendering failure does
        # not leave an empty image behind.
        png_bytes = self.graph.get_graph().draw_mermaid_png()
        with open('graph_visualization.png', 'wb') as f:
            f.write(png_bytes)
        saved_message = "图片已保存为 graph_visualization.png"
        print(saved_message)
        # Previously execution fell through and reported "not initialised"
        # even after a successful save.
        return saved_message
    except Exception as e:
        logger.error(f"获取图可视化失败: {e}")
        return f"获取图可视化失败: {e}"
if __name__ == "__main__":
    import asyncio

    async def main():
        """Build the graph and run one demo conversation turn."""
        print("测试")
        graph = ScriptwriterGraph()
        print("创建完成")
        # Run the workflow; the session id doubles as the thread id.
        demo_session = "68c2c2915e5746343301ef71"
        opening_messages = [
            HumanMessage(content="老师,我写好剧本了,您看看!帮我分析分析把!"),
        ]
        outcome = await graph.run(demo_session, opening_messages, demo_session)
        print(f"最终结果: {outcome}")

    asyncio.run(main())