agent-writer/graph/test_agent_graph_1.py

465 lines
18 KiB
Python

"""智能编剧系统工作流图定义
该模块定义了智能编剧系统的完整工作流程图,包括各个节点和边的连接关系。
"""
from typing import TypedDict, Annotated, Dict, Any, List, TypedDict, Optional
from langgraph.graph.state import RunnableConfig
from agent.scheduler import SchedulerAgent
from agent.build_bible import BuildBibleAgent
from agent.episode_create import EpisodeCreateAgent
from agent.script_analysis import ScriptAnalysisAgent
from agent.strategic_planning import StrategicPlanningAgent
from langchain_core.messages import AnyMessage,HumanMessage
from langgraph.graph import StateGraph, START, END
from utils.logger import get_logger
import operator
import json
import config
from tools.database.mongo import client # type: ignore
from langgraph.checkpoint.mongodb import MongoDBSaver
logger = get_logger(__name__)
# 定义一个简单的替换函数
def replace_value(old_val, new_val):
"""一个简单的合并函数,用于替换旧值"""
return new_val
# 状态类型定义
class InputState(TypedDict):
"""工作流输入状态"""
input_data: Annotated[list[AnyMessage], operator.add]
from_type: Annotated[str, replace_value]
session_id: Annotated[str, replace_value]
class OutputState(TypedDict):
"""工作流输出状态"""
session_id: Annotated[str, replace_value]
status: Annotated[str, replace_value]
error: Annotated[str, replace_value]
agent_message: Annotated[str, replace_value] # 智能体回复
class NodeInfo(TypedDict):
"""工作流信息"""
step: Annotated[str, replace_value] # 阶段名称 [wait_for_input,script_analysis,strategic_planning,build_bible,episode_create_loop, finish]
status: Annotated[str, replace_value] # 当前阶段的状态 [waiting,running,failed,completed]
reason: Annotated[str, replace_value] # 失败原因
retry_count: Annotated[int, replace_value] # 重试次数
from_type: Annotated[str, replace_value] # 本次请求来着哪里 [user, agent]
class ScriptwriterState(TypedDict, total=False):
"""智能编剧工作流整体状态"""
# 输入数据
input_data: Annotated[list[HumanMessage], operator.add]
session_id: Annotated[str, replace_value]
from_type: Annotated[str, replace_value] # 本次请求来着哪里 [user, agent]
# 节点间状态
next_node: Annotated[str, replace_value] # 下一个节点
workflow_step: Annotated[str, replace_value] # 阶段名称 [wait_for_input,script_analysis,strategic_planning,build_bible,episode_create_loop, finish]
workflow_status: Annotated[str, replace_value] # 当前阶段的状态 [waiting,running,failed,completed]
workflow_reason: Annotated[str, replace_value] # 失败原因
workflow_retry_count: Annotated[int, replace_value] # 重试次数
# 中间状态
agent_script_id: Annotated[str, replace_value] # 剧本ID 包括原文
agent_plan: Annotated[Dict[str, Any], replace_value] #剧本计划
script_bible: Annotated[Dict[str, Any], replace_value] #剧本圣经
episode_list: Annotated[List, replace_value] # 章节列表 完成状态、产出章节id
# 输出数据
agent_message: Annotated[str, replace_value] # 智能体回复
status: Annotated[str, replace_value]
error: Annotated[str, replace_value]
class ScriptwriterGraph:
"""智能编剧工作流图类
管理智能编剧系统的完整工作流程,包括:
- 剧本接收
- 诊断分析
- 策略制定
- 剧本圣经构建
- 剧本创作
- 迭代调整
"""
def __init__(self):
"""初始化工作流图"""
self.graph = None
self.memory = MongoDBSaver(client, db_name=config.MONGO_CHECKPOINT_DB_NAME)
self._build_graph()
def node_router(self, state: ScriptwriterState) -> str:
"""节点路由函数"""
print(f'node_router state {state}')
next_node = state.get("next_node", 'pause_node')
# 修复:当 next_node 为空字符串时,设置默认值
if not next_node:
next_node = 'pause_node' # 设置为暂停节点
print(f'node_router next_node {next_node}')
return next_node
def _build_graph(self) -> None:
"""构建工作流图"""
try:
# 创建智能体
print("创建智能体")
# 调度智能体
self.schedulerAgent = SchedulerAgent(
tools=[],
SchedulerList=[
{
"name": "scheduler_node",
"desc": "调度智能体节点",
},
{
"name": "script_analysis_node",
"desc": "原始剧本分析节点",
},
{
"name": "strategic_planning_node",
"desc": "确立改编目标节点",
},
{
"name": "build_bible_node",
"desc": "剧本圣经构建节点",
},
{
"name": "episode_create_node",
"desc": "单集创作节点",
},
{
"name": "end_node",
"desc": "结束节点,任务失败终止时使用,结束后整个工作流将停止"
}
]
)
self.scriptAnalysisAgent = ScriptAnalysisAgent(
tools=[],
SchedulerList=[
{
"name": "scheduler_node",
"desc": "调度智能体节点",
}
]
)
self.strategicPlanningAgent = StrategicPlanningAgent(
tools=[],
SchedulerList=[
{
"name": "scheduler_node",
"desc": "调度智能体节点",
}
]
)
self.buildBibleAgent = BuildBibleAgent(
tools=[],
SchedulerList=[
{
"name": "scheduler_node",
"desc": "调度智能体节点",
}
]
)
self.episodeCreate = EpisodeCreateAgent(
tools=[],
SchedulerList=[
{
"name": "scheduler_node",
"desc": "调度智能体节点",
}
]
)
# 创建状态图
logger.info("创建状态图")
workflow = StateGraph(ScriptwriterState, input_schema=InputState, output_schema=OutputState)
# 添加节点
logger.info("添加节点")
workflow.add_node("scheduler_node", self.scheduler_node)
workflow.add_node("script_analysis_node", self.script_analysis_node)
workflow.add_node("strategic_planning_node", self.strategic_planning_node)
workflow.add_node("build_bible_node", self.build_bible_node)
workflow.add_node("episode_create_node", self.episode_create_node)
workflow.add_node("end_node", self.end_node)
workflow.add_node("pause_node", self.pause_node)
# 添加边
workflow.set_entry_point("scheduler_node")
# 所有功能节点执行完成后,都返回给调度节点
workflow.add_edge("script_analysis_node", "scheduler_node")
workflow.add_edge("strategic_planning_node", "scheduler_node")
workflow.add_edge("build_bible_node", "scheduler_node")
workflow.add_edge("episode_create_node", "scheduler_node")
# 添加条件边:由调度节点决定下一个路由
workflow.add_conditional_edges(
"scheduler_node",
self.node_router,
{
"script_analysis_node": "script_analysis_node",
"strategic_planning_node": "strategic_planning_node",
"build_bible_node": "build_bible_node",
"episode_create_node": "episode_create_node",
# 用户确认和暂停逻辑在这里处理,不需要单独的边
"end_node": "end_node",
"pause_node": "pause_node",
}
)
workflow.add_edge("end_node", END)
# 编译图
self.graph = workflow.compile(checkpointer=self.memory)
logger.info("工作流图构建完成")
except Exception as e:
logger.error(f"构建工作流图失败: {e}")
raise
# --- 定义图中的节点 ---
async def scheduler_node(self, state: ScriptwriterState)-> ScriptwriterState:
"""调度节点"""
try:
session_id = state.get("session_id", "")
from_type = state.get("from_type", "")
input_data = state.get("input_data", [])
logger.info(f"调度节点 {session_id} 输入参数: {input_data} from_type:{from_type}")
reslut = await self.schedulerAgent.ainvoke(state)
ai_message_str = reslut['messages'][-1].content
ai_message = json.loads(ai_message_str)
logger.info(f"调度节点结果: {ai_message}")
step:str = ai_message.get('step', '')
status:str = ai_message.get('status', '')
next_agent:str = ai_message.get('agent', '')
return_message:str = ai_message.get('message', '')
retry_count:int = int(ai_message.get('retry_count', '0'))
next_node:str = ai_message.get('node', 'pause_node')
if next_node == 'scheduler_node':
# 返回自身 代表暂停
print(f"调度节点 暂停等待")
return {
"agent_message": return_message,
}
else:
return {
"next_node":next_node,
"workflow_step":step,
"workflow_status":status,
# "workflow_reason":return_message,
"workflow_retry_count":retry_count,
"agent_message":return_message,
}
except Exception as e:
return {
"next_node":'end_node',
"agent_message": "执行失败",
"error": str(e) or '系统错误,工作流已终止',
'status':'failed',
}
async def script_analysis_node(self, state: ScriptwriterState)-> ScriptwriterState:
"""第二步:诊断分析与资产评估"""
print("\n--- 正在进行诊断分析 ---")
session_id = state.get("session_id", "")
print(f"报告已生成: TEST")
return {}
async def confirm_analysis_node(self, state: ScriptwriterState)-> ScriptwriterState:
"""用户确认分析报告节点"""
print("\n等待用户确认分析报告...")
return {}
async def strategic_planning_node(self, state: ScriptwriterState)-> ScriptwriterState:
"""第三步:确立改编目标与战略蓝图"""
print("\n--- 正在制定战略蓝图 ---")
print(f"战略蓝图已生成: TEST")
return {}
async def build_bible_node(self, state: ScriptwriterState)-> ScriptwriterState:
"""第四步:确立改编目标与战略蓝图"""
print("\n--- 正在制定战略蓝图 ---")
print(f"战略蓝图已生成: TEST")
return {}
async def episode_create_node(self, state: ScriptwriterState)-> ScriptwriterState:
"""第五步:动态创作与闭环校验(循环主体)"""
num_episodes = 3 # 假设每次创作3集
episode_list = []
return {"episode_list": episode_list}
async def pause_node(self, state: ScriptwriterState)-> ScriptwriterState:
""" 暂停节点 处理并完成所有数据状态 """
print(f"langgraph 暂停等待")
return {
"session_id": state.get("session_id", ""),
"status": state.get('status', ''),
"error": state.get('error', ''),
"agent_message": state.get('agent_message', '')
}
async def end_node(self, state: ScriptwriterState)-> OutputState:
""" 结束节点 处理并完成所有数据状态 """
print(f"langgraph 所有任务完成")
return {
"session_id": state.get("session_id", ""),
"status": state.get('status', ''),
"error": state.get('error', ''),
"agent_message": state.get('agent_message', ''),
}
async def run(self, session_id: str, input_data: list[AnyMessage], thread_id: str|None = None) -> OutputState:
"""运行工作流
Args:
session_id: 会话ID
input_data: 输入数据
thread_id: 线程ID
Returns:
工作流执行结果
"""
try:
logger.info("开始运行智能编剧工作流")
# 配置包含线程 ID
config:RunnableConfig = {"configurable": {"thread_id": thread_id}}
# 初始化状态
initial_state: InputState = {
'input_data': input_data,
'session_id': session_id,
'from_type': 'user',
}
# 运行工作流
if self.graph is None:
raise RuntimeError("工作流图未正确初始化")
result = await self.graph.ainvoke(
initial_state,
config,
# stream_mode='values'
)
# logger.info(f"工作流执行结果: {result}")
if not result:
raise ValueError("工作流执行结果为空")
# 构造输出状态
output_result: OutputState = {
'session_id': result.get('session_id', ''),
'status': result.get('status', ''),
'error': result.get('error', ''),
'agent_message': result.get('agent_message', ''),
}
logger.info("智能编剧工作流运行完成")
return output_result
except Exception as e:
logger.error(f"运行工作流失败: {e}")
import traceback
traceback.print_exc()
raise
async def get_checkpoint_history(self, thread_id: str):
"""获取检查点历史"""
config:RunnableConfig = {"configurable": {"thread_id": thread_id}}
try:
history_generator = self.memory.list(config, limit=10)
print("正在获取检查点历史...")
# 使用列表推导式或 for 循环来收集所有检查点
history = list(history_generator)
print(f"找到 {len(history)} 个检查点:")
for i, checkpoint_tuple in enumerate(history):
# checkpoint_tuple 包含 config, checkpoint, metadata 等属性
# print(f" - ID: {checkpoint_tuple}")
checkpoint_data = checkpoint_tuple.checkpoint
metadata = checkpoint_tuple.metadata
print(f"检查点 {i+1}:")
print(f" - ID: {checkpoint_data.get('id', 'N/A')}")
print(f" - 状态: {checkpoint_data.get('channel_values', {})}")
print(f" - 元数据: {metadata}")
print("-" * 50)
except Exception as e:
print(f"获取历史记录时出错: {e}")
def resume_from_checkpoint(self, thread_id: str, checkpoint_id: str):
"""从检查点恢复执行"""
config:RunnableConfig = {"configurable": {"thread_id": thread_id}}
if checkpoint_id:
config["configurable"]["checkpoint_id"] = checkpoint_id
try:
# 获取 CheckpointTuple 对象
checkpoint_tuple = self.memory.get_tuple(config)
if checkpoint_tuple:
# 直接通过属性访问,而不是解包
checkpoint_data = checkpoint_tuple.checkpoint
metadata = checkpoint_tuple.metadata
print(f"从检查点恢复:")
print(f" - 检查点 ID: {checkpoint_data.get('id', 'N/A')}")
print(f" - 状态: {checkpoint_data.get('channel_values', {})}")
print(f" - 元数据: {metadata}")
return checkpoint_data.get('channel_values', {})
else:
print(f"未找到线程 {thread_id} 的检查点")
return None
except Exception as e:
print(f"恢复检查点时出错: {e}")
return None
def get_graph_visualization(self) -> str:
"""获取工作流图的可视化表示
Returns:
图的文本表示
"""
try:
if self.graph:
with open('graph_visualization.png', 'wb') as f:
f.write(self.graph.get_graph().draw_mermaid_png())
print("图片已保存为 graph_visualization.png")
return "工作流图未初始化"
except Exception as e:
logger.error(f"获取图可视化失败: {e}")
return f"获取图可视化失败: {e}"
if __name__ == "__main__":
import asyncio
async def main():
print("测试")
graph = ScriptwriterGraph()
print("创建完成")
# graph.get_graph_visualization()
# print("可视化完成")
# 运行工作流
session_id = "68c2c2915e5746343301ef71"
result = await graph.run(
session_id,
[HumanMessage(content="你好编剧,我想写小说!")],
session_id
)
print(f"最终结果: {result}")
asyncio.run(main())