772 lines
32 KiB
Python
772 lines
32 KiB
Python
"""智能编剧系统工作流图定义
|
||
|
||
该模块定义了智能编剧系统的完整工作流程图,包括各个节点和边的连接关系。
|
||
"""
|
||
|
||
from re import T
|
||
from typing import TypedDict, Annotated, Dict, Any, List, TypedDict, Optional
|
||
|
||
from langgraph.graph.state import RunnableConfig
|
||
from langgraph.prebuilt.chat_agent_executor import AgentState
|
||
from agent.scheduler import SchedulerAgent
|
||
from agent.build_bible import BuildBibleAgent
|
||
from agent.episode_create import EpisodeCreateAgent
|
||
from agent.script_analysis import ScriptAnalysisAgent
|
||
from agent.strategic_planning import StrategicPlanningAgent
|
||
|
||
from langgraph.checkpoint.memory import InMemorySaver
|
||
from langchain_core.messages import AnyMessage,HumanMessage
|
||
from langgraph.graph import StateGraph, START, END
|
||
from utils.logger import get_logger
|
||
import operator
|
||
import json
|
||
import config
|
||
|
||
from tools.database.mongo import client # type: ignore
|
||
from langgraph.checkpoint.mongodb import MongoDBSaver
|
||
# 工具方法
|
||
from tools.agent.queryDB import QueryOriginalScript,QueryDiagnosisAndAssessment,QueryAdaptationIdeas,QueryScriptBible,QueryEpisodeCount
|
||
from tools.agent.updateDB import UpdateAdaptationIdeasTool,UpdateScriptBibleTool,UpdateDiagnosisAndAssessmentTool,UpdateOneEpisodeTool
|
||
|
||
# Module-level logger obtained from the project's utils.logger helper.
logger = get_logger(__name__)
|
||
|
||
def clear_messages(messages, max_per_type: int = 10):
    """Return a cleaned copy of a message history.

    Two passes are applied:

    1. Messages whose content contains one of the injected-context start
       markers (task status, original script, diagnosis report, adaptation
       ideas) are removed. The graph nodes inject fresh copies of those
       sections before each agent call, so old copies are just noise.
    2. For each of the 'human', 'system' and 'ai' message types, the oldest
       messages *of that type* are removed until at most ``max_per_type``
       remain. Messages of other types are never trimmed.

    Args:
        messages: Sequence of message objects exposing ``.type`` and
            ``.content`` (e.g. LangChain messages). ``None`` is treated as empty.
        max_per_type: Maximum number of messages kept per capped type.

    Returns:
        A new list; the input sequence is not mutated.
    """
    # Start markers of the context sections injected by the graph nodes.
    exclude_strings = [
        "---任务状态消息(开始)---",
        "---原始剧本(开始)---",
        "---诊断与资产评估报告(开始)---",
        "---改编思路(开始)---",
    ]

    kept = [
        message for message in (messages or [])
        if all(s not in message.content for s in exclude_strings)
    ]

    for msg_type in ("human", "system", "ai"):
        surplus = sum(1 for message in kept if message.type == msg_type) - max_per_type
        if surplus <= 0:
            continue
        trimmed = []
        for message in kept:
            # Drop only the oldest surplus messages of this type; keep the rest in order.
            if surplus > 0 and message.type == msg_type:
                surplus -= 1
                continue
            trimmed.append(message)
        kept = trimmed

    return kept
|
||
|
||
def messages_handler(old_messages: list[AnyMessage], new_messages: list[AnyMessage]):
    """Reducer for the ``messages`` state key.

    The stored history is first cleaned with ``clear_messages``, which drops
    injected context/status messages and applies its per-type cap. AI
    messages with empty content are then removed from the incoming batch,
    and the batch is appended to the cleaned history.

    Args:
        old_messages: Messages currently stored in the state (``None`` is
            treated as empty).
        new_messages: Messages from a state update. A single message object,
            or ``None``, is also accepted.

    Returns:
        A new list; neither argument is mutated.
    """
    # clear_messages returns a new list, so its result must be used.
    cleaned_old = clear_messages(old_messages or [])

    if new_messages is None:
        incoming = []
    elif isinstance(new_messages, (list, tuple)):
        incoming = list(new_messages)
    else:
        incoming = [new_messages]

    # Skip AI messages that carry no text.
    incoming = [message for message in incoming if message.type != 'ai' or message.content]
    return cleaned_old + incoming
|
||
|
||
def replace_value(old_val, new_val):
    """State reducer that discards the previous value and keeps the incoming one."""
    _ = old_val  # the previous value is intentionally ignored
    return new_val
|
||
|
||
# State type definitions
class InputState(TypedDict):
    """Input schema of the workflow (the keys a caller passes in)."""

    messages: Annotated[list[AnyMessage], messages_handler]  # new conversation messages, merged by messages_handler

    from_type: Annotated[str, replace_value]  # origin of the request; run() passes 'user'

    session_id: Annotated[str, replace_value]  # session identifier; later values overwrite earlier ones
|
||
|
||
class OutputState(TypedDict):
    """Output schema of the workflow (the keys returned to the caller)."""

    session_id: Annotated[str, replace_value]  # session identifier

    status: Annotated[str, replace_value]  # set to 'failed' by nodes that catch an error; otherwise left empty by the visible nodes

    error: Annotated[str, replace_value]  # error text when a node failed

    agent_message: Annotated[str, replace_value] # reply text from the agent
|
||
|
||
class TaskItem(TypedDict):
    """One entry of the ordered task list produced by the scheduler."""

    agent: Annotated[str, replace_value] # name of the agent that executes this task (looked up in AgentNodeMap)

    step: Annotated[str, replace_value] # stage name: [wait_for_input, script_analysis, strategic_planning, build_bible, episode_create_loop, finish]

    status: Annotated[str, replace_value] # stage status: [waiting, running, failed, completed]

    reason: Annotated[str, replace_value] # failure reason

    retry_count: Annotated[int, replace_value] # number of retries so far

    pause: Annotated[bool, replace_value] # whether the workflow should pause at this task (node_router routes to pause_node)

    episode_create_num: Annotated[List[int], replace_value] # episodes to write; only present in the episode_create_loop stage, a list of 1-based episode numbers
|
||
|
||
class WorkflowInfo(TypedDict):
    """Progress flags for the overall workflow."""

    is_original_script: Annotated[bool, replace_value] # whether the original script has been submitted

    is_script_analysis: Annotated[bool, replace_value] # whether the diagnosis & asset assessment report exists

    is_strategic_planning: Annotated[bool, replace_value] # whether the adaptation ideas exist

    is_build_bible: Annotated[bool, replace_value] # whether the script bible exists

    is_episode_create_loop: Annotated[bool, replace_value] # whether the episode-creation loop stage has been reached

    is_all_episode_created: Annotated[bool, replace_value] # whether all episodes have been generated
|
||
|
||
|
||
class ScriptwriterState(TypedDict, total=False):
    """Complete internal state of the scriptwriter workflow (every key is optional)."""

    # Input data
    messages: Annotated[list[AnyMessage], messages_handler]

    session_id: Annotated[str, replace_value]

    from_type: Annotated[str, replace_value] # origin of this request: 'user' or 'agent'

    # Intermediate state
    workflow_info: Annotated[WorkflowInfo, replace_value] # overall workflow progress flags

    task_list: Annotated[List[TaskItem], replace_value] # ordered list of tasks to execute

    task_index: Annotated[int, replace_value] # index of the task currently being executed

    # Output data
    agent_message: Annotated[str, replace_value] # reply text from the agent

    status: Annotated[str, replace_value]

    error: Annotated[str, replace_value]
|
||
|
||
# Maps the ``agent`` name of a task to the graph node that runs it;
# every node is named "<agent>_node".
AgentNodeMap = {
    agent_name: f"{agent_name}_node"
    for agent_name in (
        "scheduler",
        "script_analysis",
        "strategic_planning",
        "build_bible",
        "episode_create",
    )
}
|
||
|
||
class ScriptwriterGraph:
    """Scriptwriter workflow graph.

    Wires a scheduler agent and four worker agents into a LangGraph
    ``StateGraph``:

    - script_analysis_node: diagnosis & asset assessment report
    - strategic_planning_node: adaptation ideas and total episode count
    - build_bible_node: script bible
    - episode_create_node: episode scripts

    ``scheduler_node`` is the entry point. ``node_router`` chooses the next
    node from the scheduler's task list, and every worker hands control back
    to the scheduler. ``pause_node`` has no outgoing edge, so execution stops
    there when the current task is paused or cannot be routed.
    """

    # Every name node_router may return. The same tuple is the path map of
    # the scheduler's conditional edge, so the router can never emit an
    # unmapped node name.
    _ROUTER_TARGETS = (
        "script_analysis_node",
        "strategic_planning_node",
        "build_bible_node",
        "episode_create_node",
        "end_node",
        "pause_node",
    )

    def __init__(self):
        """Create the MongoDB saver and build the compiled graph."""
        self.graph = None
        # MongoDB checkpointer. The graph is currently compiled with an
        # InMemorySaver instead (see _build_graph); this one is kept so
        # persistence can be switched back on.
        self.memory = MongoDBSaver(client, db_name=config.MONGO_CHECKPOINT_DB_NAME)
        # Checkpointer actually attached to the compiled graph (set in _build_graph).
        self.checkpointer = None
        self._build_graph()

    # --- internal helpers ---

    def _active_checkpointer(self):
        """Return the checkpointer the compiled graph writes to (``self.memory`` if none is set)."""
        checkpointer = getattr(self, "checkpointer", None)
        return checkpointer if checkpointer is not None else self.memory

    @staticmethod
    def _section_message(title, body) -> HumanMessage:
        """Wrap ``body`` between ``---<title>(开始)---`` and ``---<title>(结束)---`` markers.

        clear_messages filters on the start markers of these sections, so the
        injected context does not pile up in the conversation history.
        """
        return HumanMessage(content=f"\n---{title}(开始)---\n{body}\n---{title}(结束)---\n")

    @staticmethod
    def _parse_agent_reply(content) -> dict:
        """Decode an agent's final reply as a JSON object.

        A surrounding Markdown code fence (```json ... ```) is stripped first.

        Raises:
            ValueError: the text is not valid JSON (json.JSONDecodeError) or
                the JSON value is not an object.
            TypeError: ``content`` is not text.
        """
        text = content
        if isinstance(text, str):
            text = text.strip()
            if text.startswith("```"):
                # Drop the opening fence line (``` or ```json) and the closing fence.
                text = text.split("\n", 1)[1] if "\n" in text else ""
                text = text.strip()
                if text.endswith("```"):
                    text = text[:-3].strip()
        reply = json.loads(text)
        if not isinstance(reply, dict):
            raise ValueError(f"智能体回复不是 JSON 对象: {type(reply).__name__}")
        return reply

    @staticmethod
    def _current_task(state: ScriptwriterState):
        """Return ``(task_list, task_index, task)`` for the task being executed.

        ``task_list`` is a shallow copy whose current entry is replaced by a
        copy of the task dict, so updating ``task`` never mutates the list
        held in the graph state.

        Raises:
            IndexError: ``task_index`` is out of range or the entry is not a dict.
        """
        task_list = list(state.get("task_list") or [])
        task_index = int(state.get("task_index", 0))
        if not 0 <= task_index < len(task_list) or not isinstance(task_list[task_index], dict):
            raise IndexError(f"当前任务不存在: task_index={task_index}, 任务数={len(task_list)}")
        task = dict(task_list[task_index])
        task_list[task_index] = task
        return task_list, task_index, task

    @staticmethod
    def _apply_reply_type(task: dict, ai_message: dict) -> str:
        """Update ``task`` according to the reply ``type`` and return that type.

        '沟通' (the agent needs user input, the default) marks the task as
        waiting and paused. '输出' (final output) marks it completed and not
        paused. Other types leave the task untouched.
        """
        reply_type = ai_message.get("type", "沟通")
        if reply_type == "沟通":
            task["status"] = "waiting"
            task["pause"] = True
        elif reply_type == "输出":
            task["status"] = "completed"
            task["pause"] = False
        return reply_type

    @staticmethod
    def _failure(agent_message: str, exc: Exception) -> dict:
        """Build the state update that nodes return when they fail."""
        return {
            "agent_message": agent_message,
            "error": str(exc) or '系统错误,工作流已终止',
            "status": 'failed',
        }

    async def _invoke_worker(self, agent, state: ScriptwriterState, sections) -> dict:
        """Run a worker agent and return its parsed JSON reply.

        The agent receives the cleaned conversation history followed by one
        marked context message per ``(title, body)`` pair in ``sections``.
        """
        messages = clear_messages(state.get("messages", []))
        messages.extend(self._section_message(title, body) for title, body in sections)
        result = await agent.ainvoke({"messages": messages})
        return self._parse_agent_reply(result["messages"][-1].content)

    # --- routing ---

    def node_router(self, state: ScriptwriterState) -> str:
        """Pick the node to run after the scheduler.

        Routes to the node that AgentNodeMap gives for the current task's
        ``agent``. Falls back to ``pause_node`` when:

        - there is no valid current task (empty list, index out of range,
          non-dict entry);
        - the task is paused;
        - the mapped node is not a routing target (e.g. 'scheduler', which
          would loop the scheduler onto itself).
        """
        task_list = state.get("task_list") or []
        task_index = state.get("task_index", 0)
        now_task = None
        if isinstance(task_index, int) and 0 <= task_index < len(task_list):
            now_task = task_list[task_index]
        logger.info(f'node_router now_task {now_task}')

        if not now_task or not isinstance(now_task, dict) or now_task.get('pause'):
            next_node = 'pause_node'
        else:
            next_node = AgentNodeMap.get(now_task.get('agent'), 'pause_node')
            if next_node not in self._ROUTER_TARGETS:
                next_node = 'pause_node'

        logger.info(f'node_router next_node {next_node}')
        return next_node

    def post_scheduler_hook(self, state: ScriptwriterState) -> ScriptwriterState:
        """Log the content of the latest message and return ``state`` unchanged.

        Not registered on the graph in this module.
        """
        messages = state.get("messages") or []
        last_content = messages[-1].content if messages else ""
        logger.info("!!!!!!!!!调度节点输出!!!!!!!!!!!")
        logger.info(f"调度节点结果: {last_content}")
        return state

    # --- graph construction ---

    def _build_graph(self) -> None:
        """Create the agents, assemble the StateGraph and compile it into ``self.graph``.

        Raises:
            Exception: any construction error is logged and re-raised.
        """
        try:
            logger.info("创建智能体")
            # NOTE: a DeepseekChatModel was previously built here with a
            # hard-coded API key and never passed to any agent. It has been
            # removed; the leaked key should be revoked. Load credentials from
            # config/environment if an explicit LLM is needed.

            # Scheduler agent: may query the stored artefacts through its tools.
            self.schedulerAgent = SchedulerAgent(
                tools=[
                    QueryOriginalScript,
                    QueryDiagnosisAndAssessment,
                    QueryAdaptationIdeas,
                    QueryScriptBible,
                    QueryEpisodeCount,
                ],
            )
            self.scriptAnalysisAgent = ScriptAnalysisAgent(tools=[])
            self.strategicPlanningAgent = StrategicPlanningAgent(tools=[])
            self.buildBibleAgent = BuildBibleAgent(tools=[])
            self.episodeCreateAgent = EpisodeCreateAgent(tools=[])

            logger.info("创建状态图")
            workflow = StateGraph(ScriptwriterState, input_schema=InputState, output_schema=OutputState)

            logger.info("添加节点")
            workflow.add_node("scheduler_node", self.scheduler_node)
            workflow.add_node("script_analysis_node", self.script_analysis_node)
            workflow.add_node("strategic_planning_node", self.strategic_planning_node)
            workflow.add_node("build_bible_node", self.build_bible_node)
            workflow.add_node("episode_create_node", self.episode_create_node)
            workflow.add_node("end_node", self.end_node)
            workflow.add_node("pause_node", self.pause_node)

            workflow.set_entry_point("scheduler_node")
            # Every worker hands control back to the scheduler.
            for worker_node in (
                "script_analysis_node",
                "strategic_planning_node",
                "build_bible_node",
                "episode_create_node",
            ):
                workflow.add_edge(worker_node, "scheduler_node")

            # The scheduler decides the next node.
            workflow.add_conditional_edges(
                "scheduler_node",
                self.node_router,
                {name: name for name in self._ROUTER_TARGETS},
            )
            workflow.add_edge("end_node", END)

            # In-memory checkpoints: state is not persisted across processes.
            # Pass self.memory instead to persist checkpoints in MongoDB.
            self.checkpointer = InMemorySaver()
            self.graph = workflow.compile(checkpointer=self.checkpointer)
            logger.info("工作流图构建完成")

        except Exception as e:
            logger.exception(f"构建工作流图失败: {e}")
            raise

    # --- graph nodes ---

    async def scheduler_node(self, state: ScriptwriterState) -> ScriptwriterState:
        """Scheduler node: ask the scheduler agent for the task list and next task index.

        The agent sees the cleaned history plus a task-status message holding
        the session id, task index, request origin and current task list. Its
        JSON reply supplies ``message``, ``task_list`` and ``task_index``.
        On any error, ``status`` 'failed' and ``error`` are returned.
        """
        try:
            session_id = state.get("session_id", "")
            task_list = state.get("task_list", [])
            task_index = int(state.get("task_index", 0))
            from_type = state.get("from_type", "")

            # Drop old status/context messages before adding a fresh status message.
            messages = clear_messages(state.get("messages", []))
            status_body = "\n".join([
                "# 工作流参数:",
                "{",
                f"'session_id':'{session_id}',",
                f"'task_index':'{task_index}',",
                f"'from_type':'{from_type}',",
                "}",
                "# 任务列表:",
                f"{task_list}",
            ])
            messages.append(self._section_message("任务状态消息", status_body))

            type_counts = {
                msg_type: sum(1 for message in messages if message.type == msg_type)
                for msg_type in ("system", "human", "ai")
            }
            logger.info(
                f"调度节点 {session_id} \n 输入消息条数: {len(messages)} \n from_type:{from_type} \n "
                f"system_message_count:{type_counts['system']} \n "
                f"human_message_count:{type_counts['human']} \n "
                f"ai_message_count:{type_counts['ai']}"
            )

            result = await self.schedulerAgent.ainvoke({
                "messages": messages,
                "session_id": session_id,
            })
            ai_message_str = result['messages'][-1].content
            logger.info(f"调度节点结果: {ai_message_str}")
            ai_message = self._parse_agent_reply(ai_message_str)

            return {
                "agent_message": ai_message.get('message', ''),
                "task_list": ai_message.get('task_list', []),
                "task_index": int(ai_message.get('task_index', '0')),
            }

        except Exception as e:
            logger.exception("调度节点执行失败")
            return self._failure("执行失败", e)

    # Worker nodes deliberately do NOT return "messages": the messages
    # reducer appends, so returning the (already stored) history would
    # duplicate it on every step.

    async def script_analysis_node(self, state: ScriptwriterState) -> ScriptwriterState:
        """Step 2: diagnosis & asset assessment.

        Feeds the original script to the analysis agent. On an '输出' reply,
        stores ``diagnosis_and_assessment``. Returns the updated task list.
        """
        try:
            logger.info("------------ 正在进行诊断分析 ------------")
            session_id = state.get("session_id", "")
            task_list, task_index, task = self._current_task(state)

            from tools.agent.queryDB import QueryOriginalScriptContent
            original_script_content = QueryOriginalScriptContent(session_id)

            ai_message = await self._invoke_worker(self.scriptAnalysisAgent, state, [
                ("原始剧本", original_script_content['content']),
            ])

            if self._apply_reply_type(task, ai_message) == '输出':
                from tools.agent.updateDB import UpdateDiagnosisAndAssessment
                UpdateDiagnosisAndAssessment(session_id, ai_message.get('diagnosis_and_assessment', ''))

            logger.info("------------ 诊断分析结束 ------------")
            return {
                "task_list": task_list,
                "task_index": task_index,
                "agent_message": ai_message.get('message', ''),
            }

        except Exception as e:
            logger.exception("诊断分析失败")
            return self._failure("诊断分析失败", e)

    async def strategic_planning_node(self, state: ScriptwriterState) -> ScriptwriterState:
        """Step 3: adaptation goals and strategic blueprint.

        Feeds the original script and the diagnosis report to the planning
        agent. On an '输出' reply, stores the adaptation ideas and the total
        episode count.
        """
        try:
            logger.info("------------ 正在生成 改编思路 ------------")
            session_id = state.get("session_id", "")
            task_list, task_index, task = self._current_task(state)

            from tools.agent.queryDB import QueryOriginalScriptContent, QueryDiagnosisAndAssessmentContent
            original_script_content = QueryOriginalScriptContent(session_id)
            diagnosis_and_assessment_content = QueryDiagnosisAndAssessmentContent(session_id)

            ai_message = await self._invoke_worker(self.strategicPlanningAgent, state, [
                ("原始剧本", original_script_content['content']),
                ("诊断与资产评估报告", diagnosis_and_assessment_content['content']),
            ])

            if self._apply_reply_type(task, ai_message) == '输出':
                adaptation_ideas = ai_message.get('adaptation_ideas', '')
                # Convert before writing anything so a bad value cannot cause a partial update.
                total_episode_num = int(ai_message.get('total_episode_num', 0))
                from tools.agent.updateDB import UpdateAdaptationIdeas, SetTotalEpisodeNum
                UpdateAdaptationIdeas(session_id, adaptation_ideas)
                SetTotalEpisodeNum(session_id, total_episode_num)

            logger.info("------------ 生成 改编思路 结束 ------------")
            return {
                "task_list": task_list,
                "task_index": task_index,
                "agent_message": ai_message.get('message', ''),
            }

        except Exception as e:
            logger.exception("生成 改编思路 失败")
            return self._failure("生成 改编思路 失败", e)

    async def build_bible_node(self, state: ScriptwriterState) -> ScriptwriterState:
        """Step 4: build the script bible.

        Feeds the original script and the adaptation ideas to the bible
        agent. On an '输出' reply, stores the four ``script_bible`` sections.
        """
        try:
            logger.info("------------ 正在生成 剧本圣经 ------------")
            session_id = state.get("session_id", "")
            task_list, task_index, task = self._current_task(state)

            from tools.agent.queryDB import QueryOriginalScriptContent, QueryAdaptationIdeasContent
            original_script_content = QueryOriginalScriptContent(session_id)
            adaptation_ideas_content = QueryAdaptationIdeasContent(session_id)

            ai_message = await self._invoke_worker(self.buildBibleAgent, state, [
                ("原始剧本", original_script_content['content']),
                ("改编思路", adaptation_ideas_content['content']),
            ])

            if self._apply_reply_type(task, ai_message) == '输出':
                script_bible = ai_message.get('script_bible') or {}
                from tools.agent.updateDB import UpdateScriptBible
                UpdateScriptBible(
                    session_id,
                    script_bible.get('core_outline', ''),
                    script_bible.get('character_profile', ''),
                    script_bible.get('core_event_timeline', ''),
                    script_bible.get('character_list', ''),
                )

            logger.info("------------ 生成 剧本圣经 结束 ------------")
            return {
                "task_list": task_list,
                "task_index": task_index,
                "agent_message": ai_message.get('message', ''),
            }

        except Exception as e:
            logger.exception("生成 剧本圣经 失败")
            return self._failure("生成 剧本圣经 失败", e)

    async def episode_create_node(self, state: ScriptwriterState) -> ScriptwriterState:
        """Step 5: write episode scripts.

        Feeds the original script, the adaptation ideas and the requested
        episode numbers (``episode_create_num`` of the current task) to the
        episode agent. On an '输出' reply, stores every returned episode.
        Each episode is a JSON object with ``number`` and ``content``.
        """
        try:
            session_id = state.get("session_id", "")
            task_list, task_index, task = self._current_task(state)
            episode_create_num = task.get('episode_create_num', [])
            logger.info(f"------------ 正在生成 剧集内容 {episode_create_num} ------------")

            from tools.agent.queryDB import QueryOriginalScriptContent, QueryAdaptationIdeasContent
            original_script_content = QueryOriginalScriptContent(session_id)
            adaptation_ideas_content = QueryAdaptationIdeasContent(session_id)

            task_body = "\n".join([
                "# 工作流参数:",
                "{",
                f"'episode_create_num':'{episode_create_num}',",
                "}",
            ])
            ai_message = await self._invoke_worker(self.episodeCreateAgent, state, [
                ("原始剧本", original_script_content['content']),
                ("改编思路", adaptation_ideas_content['content']),
                ("任务状态消息", task_body),
            ])

            if self._apply_reply_type(task, ai_message) == '输出':
                from tools.agent.updateDB import UpdateOneEpisode
                # Episodes are decoded JSON objects (dicts), not attribute objects.
                for episode in ai_message.get('episodes') or []:
                    UpdateOneEpisode(session_id, episode['number'], episode.get('content', ''))

            logger.info("------------ 生成 剧集内容 结束 ------------")
            return {
                "task_list": task_list,
                "task_index": task_index,
                "agent_message": ai_message.get('message', ''),
            }

        except Exception as e:
            logger.exception("生成 剧集内容 失败")
            return self._failure("生成 剧集内容 失败", e)

    async def pause_node(self, state: ScriptwriterState) -> ScriptwriterState:
        """Pause node: expose the output keys and stop (no outgoing edge)."""
        logger.info("langgraph 暂停等待")
        return {
            "session_id": state.get("session_id", ""),
            "status": state.get('status', ''),
            "error": state.get('error', ''),
            "agent_message": state.get('agent_message', ''),
        }

    async def end_node(self, state: ScriptwriterState) -> OutputState:
        """End node: expose the output keys before the graph reaches END."""
        logger.info("langgraph 所有任务完成")
        return {
            "session_id": state.get("session_id", ""),
            "status": state.get('status', ''),
            "error": state.get('error', ''),
            "agent_message": state.get('agent_message', ''),
        }

    # --- public API ---

    async def run(self, session_id: str, messages: list[AnyMessage], thread_id: str | None = None) -> OutputState:
        """Run the workflow for one user turn.

        Args:
            session_id: Session ID.
            messages: New input messages for this turn.
            thread_id: Checkpointer thread ID. The graph is compiled with a
                checkpointer, which requires a thread ID, so this defaults to
                ``session_id``.

        Returns:
            OutputState with ``session_id``, ``status``, ``error`` and ``agent_message``.

        Raises:
            RuntimeError: the graph was not built.
            ValueError: the graph returned an empty result.
            Any exception raised while running the graph (logged, then re-raised).
        """
        try:
            logger.info("开始运行智能编剧工作流")

            if self.graph is None:
                raise RuntimeError("工作流图未正确初始化")

            run_config: RunnableConfig = {"configurable": {"thread_id": thread_id or session_id}}
            initial_state: InputState = {
                'messages': messages,
                'session_id': session_id,
                'from_type': 'user',
            }

            result = await self.graph.ainvoke(initial_state, run_config)
            if not result:
                raise ValueError("工作流执行结果为空")

            output_result: OutputState = {
                'session_id': result.get('session_id', ''),
                'status': result.get('status', ''),
                'error': result.get('error', ''),
                'agent_message': result.get('agent_message', ''),
            }
            logger.info("智能编剧工作流运行完成")
            return output_result

        except Exception as e:
            logger.exception(f"运行工作流失败: {e}")
            raise

    async def get_checkpoint_history(self, thread_id: str):
        """Print up to 10 checkpoints of ``thread_id`` and return them.

        Reads from the checkpointer the compiled graph actually uses.

        Returns:
            The list of checkpoint tuples, or an empty list on error.
        """
        run_config: RunnableConfig = {"configurable": {"thread_id": thread_id}}
        try:
            print("正在获取检查点历史...")
            history = list(self._active_checkpointer().list(run_config, limit=10))
            print(f"找到 {len(history)} 个检查点:")
            for position, checkpoint_tuple in enumerate(history, start=1):
                checkpoint_data = checkpoint_tuple.checkpoint
                print(f"检查点 {position}:")
                print(f" - ID: {checkpoint_data.get('id', 'N/A')}")
                print(f" - 状态: {checkpoint_data.get('channel_values', {})}")
                print(f" - 元数据: {checkpoint_tuple.metadata}")
                print("-" * 50)
            return history
        except Exception as e:
            print(f"获取历史记录时出错: {e}")
            return []

    def resume_from_checkpoint(self, thread_id: str, checkpoint_id: str):
        """Return the channel values of a checkpoint (the latest one if ``checkpoint_id`` is empty).

        Reads from the checkpointer the compiled graph actually uses. Returns
        None when no checkpoint is found or on error.
        """
        run_config: RunnableConfig = {"configurable": {"thread_id": thread_id}}
        if checkpoint_id:
            run_config["configurable"]["checkpoint_id"] = checkpoint_id

        try:
            checkpoint_tuple = self._active_checkpointer().get_tuple(run_config)
            if not checkpoint_tuple:
                print(f"未找到线程 {thread_id} 的检查点")
                return None

            checkpoint_data = checkpoint_tuple.checkpoint
            print("从检查点恢复:")
            print(f" - 检查点 ID: {checkpoint_data.get('id', 'N/A')}")
            print(f" - 状态: {checkpoint_data.get('channel_values', {})}")
            print(f" - 元数据: {checkpoint_tuple.metadata}")
            return checkpoint_data.get('channel_values', {})

        except Exception as e:
            print(f"恢复检查点时出错: {e}")
            return None

    def get_graph_visualization(self) -> str:
        """Save a PNG render of the graph and return its Mermaid source.

        The PNG is written to ``graph_visualization.png``. NOTE(review):
        ``draw_mermaid_png`` renders via a remote service by default, so this
        may need network access; confirm.

        Returns:
            The Mermaid text; "工作流图未初始化" if the graph is not built; or an
            error description on failure.
        """
        try:
            if self.graph is None:
                return "工作流图未初始化"
            drawable = self.graph.get_graph()
            with open('graph_visualization.png', 'wb') as f:
                f.write(drawable.draw_mermaid_png())
            print("图片已保存为 graph_visualization.png")
            return drawable.draw_mermaid()
        except Exception as e:
            logger.error(f"获取图可视化失败: {e}")
            return f"获取图可视化失败: {e}"
|
||
|
||
|
||
if __name__ == "__main__":
    import asyncio

    async def main():
        """Smoke run: build the graph and execute one user turn for a fixed session."""
        print("测试")
        scriptwriter = ScriptwriterGraph()
        print("创建完成")
        # To render the graph: scriptwriter.get_graph_visualization()

        demo_session = "68c2c2915e5746343301ef71"
        opening_message = HumanMessage(content="老师,我写好剧本了,您看看!帮我分析分析把!")
        # The session id doubles as the checkpointer thread id.
        outcome = await scriptwriter.run(demo_session, [opening_message], demo_session)
        print(f"最终结果: {outcome}")

    asyncio.run(main())
|