292 lines
11 KiB
Python
292 lines
11 KiB
Python
"""智能编剧系统工作流图定义
|
|
|
|
该模块定义了智能编剧系统的完整工作流程图,包括各个节点和边的连接关系。
|
|
"""
|
|
|
|
from typing import TypedDict, Annotated, Dict, Any, List, TypedDict, Optional
|
|
from agent.scheduler import SchedulerAgent
|
|
from agent.build_bible import BuildBibleAgent
|
|
from agent.episode_create import EpisodeCreateAgent
|
|
from agent.script_analysis import ScriptAnalysisAgent
|
|
from agent.strategic_planning import StrategicPlanningAgent
|
|
|
|
from langgraph.graph import StateGraph, START, END
|
|
from utils.logger import get_logger
|
|
import operator
|
|
|
|
logger = get_logger(__name__)
|
|
|
|
# 定义一个简单的替换函数
|
|
def replace_value(old_val, new_val):
|
|
"""一个简单的合并函数,用于替换旧值"""
|
|
return new_val
|
|
|
|
# 状态类型定义
|
|
class InputState(TypedDict):
|
|
"""工作流输入状态"""
|
|
input_data: Annotated[Dict[str, Any], operator.add]
|
|
session_id: Annotated[str, replace_value]
|
|
|
|
class OutputState(TypedDict):
|
|
"""工作流输出状态"""
|
|
session_id: Annotated[str, replace_value]
|
|
status: Annotated[str, replace_value]
|
|
error: Annotated[str, replace_value]
|
|
|
|
class NodeInfo(TypedDict):
|
|
"""工作流信息"""
|
|
step: Annotated[str, replace_value] # 阶段名称 [wait_for_input,script_analysis,strategic_planning,build_bible,episode_create_loop, finish]
|
|
status: Annotated[str, replace_value] # 当前阶段的状态 [waiting,running,failed,completed]
|
|
reason: Annotated[str, replace_value] # 失败原因
|
|
retry_count: Annotated[int, replace_value] # 重试次数
|
|
from_type: Annotated[str, replace_value] # 本次请求来着哪里 [user, agent]
|
|
|
|
|
|
class ScriptwriterState(TypedDict, total=False):
|
|
"""智能编剧工作流整体状态"""
|
|
# 输入数据
|
|
input_data: Annotated[Dict[str, Any], operator.add]
|
|
session_id: Annotated[str, replace_value]
|
|
|
|
# 节点间状态
|
|
node_info: NodeInfo
|
|
|
|
# 中间状态
|
|
agent_script_id: Annotated[str, replace_value] # 剧本ID 包括原文
|
|
agent_plan: Annotated[Dict[str, Any], replace_value] #剧本计划
|
|
script_bible: Annotated[Dict[str, Any], replace_value] #剧本圣经
|
|
episode_list: Annotated[List, replace_value] # 章节列表 完成状态、产出章节id
|
|
|
|
# 输出数据
|
|
status: Annotated[str, replace_value]
|
|
error: Annotated[str, replace_value]
|
|
|
|
class ScriptwriterGraph:
|
|
"""智能编剧工作流图类
|
|
|
|
管理智能编剧系统的完整工作流程,包括:
|
|
- 剧本接收
|
|
- 诊断分析
|
|
- 策略制定
|
|
- 剧本圣经构建
|
|
- 剧本创作
|
|
- 迭代调整
|
|
"""
|
|
|
|
def __init__(self):
|
|
"""初始化工作流图"""
|
|
self.graph = None
|
|
self._build_graph()
|
|
|
|
|
|
def node_router(self, state: ScriptwriterState) -> str:
|
|
next_node = state.get("node", '')
|
|
if next_node:
|
|
return next_node
|
|
else:
|
|
return END
|
|
|
|
def _build_graph(self) -> None:
|
|
"""构建工作流图"""
|
|
try:
|
|
# 创建智能体
|
|
# 调度智能体
|
|
schedulerAgent = SchedulerAgent(
|
|
tools=[],
|
|
SchedulerList=[
|
|
{
|
|
"scheduler_node": "调度智能体节点",
|
|
"script_analysis_node": "原始剧本分析节点",
|
|
"strategic_planning_node": "确立改编目标节点",
|
|
"build_bible_node": "剧本圣经构建节点",
|
|
"episode_create_node": "单集创作节点",
|
|
"end_node": "结束节点,任务失败终止时使用,结束后整个工作流将停止"
|
|
}
|
|
]
|
|
)
|
|
scriptAnalysisAgent = ScriptAnalysisAgent(
|
|
tools=[],
|
|
SchedulerList=[
|
|
{
|
|
"scheduler_node": "调度智能体节点",
|
|
}
|
|
]
|
|
)
|
|
strategicPlanningAgent = StrategicPlanningAgent(
|
|
tools=[],
|
|
SchedulerList=[
|
|
{
|
|
"scheduler_node": "调度智能体节点",
|
|
}
|
|
]
|
|
)
|
|
buildBibleAgent = BuildBibleAgent(
|
|
tools=[],
|
|
SchedulerList=[
|
|
{
|
|
"scheduler_node": "调度智能体节点",
|
|
}
|
|
]
|
|
)
|
|
episodeCreate = EpisodeCreateAgent(
|
|
tools=[],
|
|
SchedulerList=[
|
|
{
|
|
"scheduler_node": "调度智能体节点",
|
|
}
|
|
]
|
|
)
|
|
|
|
# 创建状态图
|
|
workflow = StateGraph(ScriptwriterState, input_schema=InputState, output_schema=OutputState)
|
|
|
|
# 添加节点
|
|
workflow.add_node("scheduler_node", self.scheduler_node)
|
|
workflow.add_node("script_analysis_node", self.script_analysis_node)
|
|
workflow.add_node("strategic_planning_node", self.strategic_planning_node)
|
|
workflow.add_node("build_bible_node", self.build_bible_node)
|
|
workflow.add_node("episode_create_node", self.episode_create_node)
|
|
workflow.add_node("end_node", self.end_node)
|
|
|
|
# 添加边
|
|
workflow.set_entry_point("scheduler_node")
|
|
# 所有功能节点执行完成后,都返回给调度节点
|
|
workflow.add_edge("script_analysis_node", "scheduler_node")
|
|
workflow.add_edge("strategic_planning_node", "scheduler_node")
|
|
workflow.add_edge("build_bible_node", "scheduler_node")
|
|
workflow.add_edge("episode_create_node", "scheduler_node")
|
|
|
|
# 添加条件边:由调度节点决定下一个路由
|
|
workflow.add_conditional_edges(
|
|
"scheduler_node",
|
|
self.node_router,
|
|
{
|
|
"script_analysis_node": "script_analysis_node",
|
|
"strategic_planning_node": "strategic_planning_node",
|
|
"build_bible_node": "build_bible_node",
|
|
"episode_create_node": "episode_create_node",
|
|
# 用户确认和暂停逻辑在这里处理,不需要单独的边
|
|
"end_node": "end_node",
|
|
}
|
|
)
|
|
|
|
workflow.add_edge("end_node", END)
|
|
|
|
# 编译图
|
|
self.graph = workflow.compile()
|
|
logger.info("工作流图构建完成")
|
|
|
|
except Exception as e:
|
|
logger.error(f"构建工作流图失败: {e}")
|
|
raise
|
|
|
|
# --- 定义图中的节点 ---
|
|
async def scheduler_node(self, state: ScriptwriterState)-> ScriptwriterState:
|
|
"""第一步:初步沟通,请求剧本"""
|
|
session_id = state.get("session_id", "")
|
|
|
|
return {}
|
|
|
|
async def script_analysis_node(self, state: ScriptwriterState)-> ScriptwriterState:
|
|
"""第二步:诊断分析与资产评估"""
|
|
print("\n--- 正在进行诊断分析 ---")
|
|
session_id = state.get("session_id", "")
|
|
print(f"报告已生成: TEST")
|
|
return {}
|
|
|
|
async def confirm_analysis_node(self, state: ScriptwriterState)-> ScriptwriterState:
|
|
"""用户确认分析报告节点"""
|
|
print("\n等待用户确认分析报告...")
|
|
return {}
|
|
|
|
async def strategic_planning_node(self, state: ScriptwriterState)-> ScriptwriterState:
|
|
"""第三步:确立改编目标与战略蓝图"""
|
|
print("\n--- 正在制定战略蓝图 ---")
|
|
print(f"战略蓝图已生成: TEST")
|
|
return {}
|
|
|
|
async def build_bible_node(self, state: ScriptwriterState)-> ScriptwriterState:
|
|
"""第四步:确立改编目标与战略蓝图"""
|
|
print("\n--- 正在制定战略蓝图 ---")
|
|
print(f"战略蓝图已生成: TEST")
|
|
return {}
|
|
|
|
|
|
async def episode_create_node(self, state: ScriptwriterState)-> ScriptwriterState:
|
|
"""第五步:动态创作与闭环校验(循环主体)"""
|
|
num_episodes = 3 # 假设每次创作3集
|
|
episode_list = []
|
|
return {"episode_list": episode_list}
|
|
|
|
async def end_node(self, state: ScriptwriterState)-> OutputState:
|
|
""" 结束节点 处理并完成所有数据状态 """
|
|
print(f"langgraph 所有任务完成")
|
|
return {
|
|
"session_id": state.get("session_id", ""),
|
|
"status": "",
|
|
"error": "",
|
|
}
|
|
|
|
async def run(self, input_data: Dict[str, Any]) -> OutputState:
|
|
"""运行工作流
|
|
|
|
Args:
|
|
input_data: 输入数据
|
|
|
|
Returns:
|
|
工作流执行结果
|
|
"""
|
|
try:
|
|
logger.info("开始运行智能编剧工作流")
|
|
|
|
# # 初始化状态
|
|
# initial_state: InputState = {
|
|
# 'input_data': input_data,
|
|
# 'session_id': input_data.get('session_id', ''),
|
|
# 'max_iterations': input_data.get('max_iterations', 3),
|
|
# 'batch_info': input_data.get('batch_info', {})
|
|
# }
|
|
|
|
# # 运行工作流
|
|
# if self.graph is None:
|
|
# raise RuntimeError("工作流图未正确初始化")
|
|
|
|
# result = await self.graph.ainvoke(initial_state)
|
|
# logger.info(f"工作流执行结果: {result}")
|
|
# if not result:
|
|
# raise ValueError("工作流执行结果为空")
|
|
# # 保存到记忆
|
|
# self.memory.save_workflow_result(result)
|
|
|
|
# # 构造输出状态
|
|
# output_result: OutputState = {
|
|
# 'script': result.get('script'),
|
|
# 'adjustment': result.get('adjustment'),
|
|
# 'error': result.get('error'),
|
|
# 'iteration_count': result.get('iteration_count', 0)
|
|
# }
|
|
output_result:OutputState = {
|
|
'session_id': "",
|
|
'status': 'completed',
|
|
'error': '',
|
|
}
|
|
logger.info("智能编剧工作流运行完成")
|
|
return output_result
|
|
|
|
except Exception as e:
|
|
logger.error(f"运行工作流失败: {e}")
|
|
raise
|
|
|
|
def get_graph_visualization(self) -> str:
|
|
"""获取工作流图的可视化表示
|
|
|
|
Returns:
|
|
图的文本表示
|
|
"""
|
|
try:
|
|
if self.graph:
|
|
return str(self.graph)
|
|
return "工作流图未初始化"
|
|
except Exception as e:
|
|
logger.error(f"获取图可视化失败: {e}")
|
|
return f"获取图可视化失败: {e}" |