调度智能体 记忆存储
This commit is contained in:
parent
e5f66e0d24
commit
750af43ff3
@ -35,7 +35,7 @@ DefaultAgentPrompt = f"""
|
|||||||
def create_agent_prompt(prompt, SchedulerList):
|
def create_agent_prompt(prompt, SchedulerList):
|
||||||
"""创建代理提示词的辅助函数"""
|
"""创建代理提示词的辅助函数"""
|
||||||
if not SchedulerList or len(SchedulerList) == 0: return prompt
|
if not SchedulerList or len(SchedulerList) == 0: return prompt
|
||||||
node_list = [f"{node.name}:{node.desc}" for node in SchedulerList]
|
node_list = [f"{node['name']}:{node['desc']}" for node in SchedulerList]
|
||||||
return f"""
|
return f"""
|
||||||
{prompt} \n
|
{prompt} \n
|
||||||
下面返回数据中node字段的取值范围列表([{{名称:描述}}]),请根据你的分析结果选择一个节点名称返回:
|
下面返回数据中node字段的取值范围列表([{{名称:描述}}]),请根据你的分析结果选择一个节点名称返回:
|
||||||
|
|||||||
@ -35,7 +35,7 @@ DefaultAgentPrompt = f"""
|
|||||||
def create_agent_prompt(prompt, SchedulerList):
|
def create_agent_prompt(prompt, SchedulerList):
|
||||||
"""创建代理提示词的辅助函数"""
|
"""创建代理提示词的辅助函数"""
|
||||||
if not SchedulerList or len(SchedulerList) == 0: return prompt
|
if not SchedulerList or len(SchedulerList) == 0: return prompt
|
||||||
node_list = [f"{node.name}:{node.desc}" for node in SchedulerList]
|
node_list = [f"{node['name']}:{node['desc']}" for node in SchedulerList]
|
||||||
return f"""
|
return f"""
|
||||||
{prompt} \n
|
{prompt} \n
|
||||||
下面返回数据中node字段的取值范围列表([{{名称:描述}}]),请根据你的分析结果选择一个节点名称返回:
|
下面返回数据中node字段的取值范围列表([{{名称:描述}}]),请根据你的分析结果选择一个节点名称返回:
|
||||||
|
|||||||
@ -97,6 +97,7 @@ DefaultAgentPrompt = f"""
|
|||||||
"status": "当前阶段的状态",//取值范围在上述 status的描述中 不可写其他值
|
"status": "当前阶段的状态",//取值范围在上述 status的描述中 不可写其他值
|
||||||
"agent":'',//分析后得出由哪个智能体继续任务,此处为智能体名称;如果需要继续与用户交互或仅需要回复用户则为空字符串
|
"agent":'',//分析后得出由哪个智能体继续任务,此处为智能体名称;如果需要继续与用户交互或仅需要回复用户则为空字符串
|
||||||
"message":'',//回复给用户的内容
|
"message":'',//回复给用户的内容
|
||||||
|
"retry_count":0,//重试次数
|
||||||
"node":'',//下一个节点名称
|
"node":'',//下一个节点名称
|
||||||
}}
|
}}
|
||||||
|
|
||||||
@ -105,7 +106,7 @@ DefaultAgentPrompt = f"""
|
|||||||
def create_agent_prompt(prompt, SchedulerList):
|
def create_agent_prompt(prompt, SchedulerList):
|
||||||
"""创建代理提示词的辅助函数"""
|
"""创建代理提示词的辅助函数"""
|
||||||
if not SchedulerList or len(SchedulerList) == 0: return prompt
|
if not SchedulerList or len(SchedulerList) == 0: return prompt
|
||||||
node_list = [f"{node.name}:{node.desc}" for node in SchedulerList]
|
node_list = [f"{node['name']}:{node['desc']}" for node in SchedulerList]
|
||||||
return f"""
|
return f"""
|
||||||
{prompt} \n
|
{prompt} \n
|
||||||
下面返回数据中node字段的取值范围列表([{{名称:描述}}]),请根据你的分析结果选择一个节点名称返回:
|
下面返回数据中node字段的取值范围列表([{{名称:描述}}]),请根据你的分析结果选择一个节点名称返回:
|
||||||
|
|||||||
@ -35,7 +35,7 @@ DefaultAgentPrompt = f"""
|
|||||||
def create_agent_prompt(prompt, SchedulerList):
|
def create_agent_prompt(prompt, SchedulerList):
|
||||||
"""创建代理提示词的辅助函数"""
|
"""创建代理提示词的辅助函数"""
|
||||||
if not SchedulerList or len(SchedulerList) == 0: return prompt
|
if not SchedulerList or len(SchedulerList) == 0: return prompt
|
||||||
node_list = [f"{node.name}:{node.desc}" for node in SchedulerList]
|
node_list = [f"{node['name']}:{node['desc']}" for node in SchedulerList]
|
||||||
return f"""
|
return f"""
|
||||||
{prompt} \n
|
{prompt} \n
|
||||||
下面返回数据中node字段的取值范围列表([{{名称:描述}}]),请根据你的分析结果选择一个节点名称返回:
|
下面返回数据中node字段的取值范围列表([{{名称:描述}}]),请根据你的分析结果选择一个节点名称返回:
|
||||||
|
|||||||
@ -35,7 +35,7 @@ DefaultAgentPrompt = f"""
|
|||||||
def create_agent_prompt(prompt, SchedulerList):
|
def create_agent_prompt(prompt, SchedulerList):
|
||||||
"""创建代理提示词的辅助函数"""
|
"""创建代理提示词的辅助函数"""
|
||||||
if not SchedulerList or len(SchedulerList) == 0: return prompt
|
if not SchedulerList or len(SchedulerList) == 0: return prompt
|
||||||
node_list = [f"{node.name}:{node.desc}" for node in SchedulerList]
|
node_list = [f"{node['name']}:{node['desc']}" for node in SchedulerList]
|
||||||
return f"""
|
return f"""
|
||||||
{prompt} \n
|
{prompt} \n
|
||||||
下面返回数据中node字段的取值范围列表([{{名称:描述}}]),请根据你的分析结果选择一个节点名称返回:
|
下面返回数据中node字段的取值范围列表([{{名称:描述}}]),请根据你的分析结果选择一个节点名称返回:
|
||||||
|
|||||||
BIN
doc/节点结构图.png
Normal file
BIN
doc/节点结构图.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 30 KiB |
@ -4,15 +4,23 @@
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
from typing import TypedDict, Annotated, Dict, Any, List, TypedDict, Optional
|
from typing import TypedDict, Annotated, Dict, Any, List, TypedDict, Optional
|
||||||
|
|
||||||
|
from langgraph.graph.state import RunnableConfig
|
||||||
from agent.scheduler import SchedulerAgent
|
from agent.scheduler import SchedulerAgent
|
||||||
from agent.build_bible import BuildBibleAgent
|
from agent.build_bible import BuildBibleAgent
|
||||||
from agent.episode_create import EpisodeCreateAgent
|
from agent.episode_create import EpisodeCreateAgent
|
||||||
from agent.script_analysis import ScriptAnalysisAgent
|
from agent.script_analysis import ScriptAnalysisAgent
|
||||||
from agent.strategic_planning import StrategicPlanningAgent
|
from agent.strategic_planning import StrategicPlanningAgent
|
||||||
|
|
||||||
|
from langchain_core.messages import AnyMessage,HumanMessage
|
||||||
from langgraph.graph import StateGraph, START, END
|
from langgraph.graph import StateGraph, START, END
|
||||||
from utils.logger import get_logger
|
from utils.logger import get_logger
|
||||||
import operator
|
import operator
|
||||||
|
import json
|
||||||
|
import config
|
||||||
|
|
||||||
|
from tools.database.mongo import client # type: ignore
|
||||||
|
from langgraph.checkpoint.mongodb import MongoDBSaver
|
||||||
|
|
||||||
logger = get_logger(__name__)
|
logger = get_logger(__name__)
|
||||||
|
|
||||||
@ -24,7 +32,8 @@ def replace_value(old_val, new_val):
|
|||||||
# 状态类型定义
|
# 状态类型定义
|
||||||
class InputState(TypedDict):
|
class InputState(TypedDict):
|
||||||
"""工作流输入状态"""
|
"""工作流输入状态"""
|
||||||
input_data: Annotated[Dict[str, Any], operator.add]
|
input_data: Annotated[list[AnyMessage], operator.add]
|
||||||
|
from_type: Annotated[str, replace_value]
|
||||||
session_id: Annotated[str, replace_value]
|
session_id: Annotated[str, replace_value]
|
||||||
|
|
||||||
class OutputState(TypedDict):
|
class OutputState(TypedDict):
|
||||||
@ -32,6 +41,7 @@ class OutputState(TypedDict):
|
|||||||
session_id: Annotated[str, replace_value]
|
session_id: Annotated[str, replace_value]
|
||||||
status: Annotated[str, replace_value]
|
status: Annotated[str, replace_value]
|
||||||
error: Annotated[str, replace_value]
|
error: Annotated[str, replace_value]
|
||||||
|
agent_message: Annotated[str, replace_value] # 智能体回复
|
||||||
|
|
||||||
class NodeInfo(TypedDict):
|
class NodeInfo(TypedDict):
|
||||||
"""工作流信息"""
|
"""工作流信息"""
|
||||||
@ -45,12 +55,17 @@ class NodeInfo(TypedDict):
|
|||||||
class ScriptwriterState(TypedDict, total=False):
|
class ScriptwriterState(TypedDict, total=False):
|
||||||
"""智能编剧工作流整体状态"""
|
"""智能编剧工作流整体状态"""
|
||||||
# 输入数据
|
# 输入数据
|
||||||
input_data: Annotated[Dict[str, Any], operator.add]
|
input_data: Annotated[list[HumanMessage], operator.add]
|
||||||
session_id: Annotated[str, replace_value]
|
session_id: Annotated[str, replace_value]
|
||||||
|
from_type: Annotated[str, replace_value] # 本次请求来着哪里 [user, agent]
|
||||||
|
|
||||||
# 节点间状态
|
# 节点间状态
|
||||||
node_info: NodeInfo
|
next_node: Annotated[str, replace_value] # 下一个节点
|
||||||
|
workflow_step: Annotated[str, replace_value] # 阶段名称 [wait_for_input,script_analysis,strategic_planning,build_bible,episode_create_loop, finish]
|
||||||
|
workflow_status: Annotated[str, replace_value] # 当前阶段的状态 [waiting,running,failed,completed]
|
||||||
|
workflow_reason: Annotated[str, replace_value] # 失败原因
|
||||||
|
workflow_retry_count: Annotated[int, replace_value] # 重试次数
|
||||||
|
|
||||||
# 中间状态
|
# 中间状态
|
||||||
agent_script_id: Annotated[str, replace_value] # 剧本ID 包括原文
|
agent_script_id: Annotated[str, replace_value] # 剧本ID 包括原文
|
||||||
agent_plan: Annotated[Dict[str, Any], replace_value] #剧本计划
|
agent_plan: Annotated[Dict[str, Any], replace_value] #剧本计划
|
||||||
@ -58,6 +73,7 @@ class ScriptwriterState(TypedDict, total=False):
|
|||||||
episode_list: Annotated[List, replace_value] # 章节列表 完成状态、产出章节id
|
episode_list: Annotated[List, replace_value] # 章节列表 完成状态、产出章节id
|
||||||
|
|
||||||
# 输出数据
|
# 输出数据
|
||||||
|
agent_message: Annotated[str, replace_value] # 智能体回复
|
||||||
status: Annotated[str, replace_value]
|
status: Annotated[str, replace_value]
|
||||||
error: Annotated[str, replace_value]
|
error: Annotated[str, replace_value]
|
||||||
|
|
||||||
@ -76,77 +92,105 @@ class ScriptwriterGraph:
|
|||||||
def __init__(self):
|
def __init__(self):
|
||||||
"""初始化工作流图"""
|
"""初始化工作流图"""
|
||||||
self.graph = None
|
self.graph = None
|
||||||
|
self.memory = MongoDBSaver(client, db_name=config.MONGO_CHECKPOINT_DB_NAME)
|
||||||
self._build_graph()
|
self._build_graph()
|
||||||
|
|
||||||
|
|
||||||
def node_router(self, state: ScriptwriterState) -> str:
|
def node_router(self, state: ScriptwriterState) -> str:
|
||||||
next_node = state.get("node", '')
|
"""节点路由函数"""
|
||||||
if next_node:
|
print(f'node_router state {state}')
|
||||||
return next_node
|
next_node = state.get("next_node", 'pause_node')
|
||||||
else:
|
# 修复:当 next_node 为空字符串时,设置默认值
|
||||||
return END
|
if not next_node:
|
||||||
|
next_node = 'pause_node' # 设置为暂停节点
|
||||||
|
print(f'node_router next_node {next_node}')
|
||||||
|
return next_node
|
||||||
|
|
||||||
def _build_graph(self) -> None:
|
def _build_graph(self) -> None:
|
||||||
"""构建工作流图"""
|
"""构建工作流图"""
|
||||||
try:
|
try:
|
||||||
# 创建智能体
|
# 创建智能体
|
||||||
|
print("创建智能体")
|
||||||
# 调度智能体
|
# 调度智能体
|
||||||
schedulerAgent = SchedulerAgent(
|
self.schedulerAgent = SchedulerAgent(
|
||||||
tools=[],
|
tools=[],
|
||||||
SchedulerList=[
|
SchedulerList=[
|
||||||
{
|
{
|
||||||
"scheduler_node": "调度智能体节点",
|
"name": "scheduler_node",
|
||||||
"script_analysis_node": "原始剧本分析节点",
|
"desc": "调度智能体节点",
|
||||||
"strategic_planning_node": "确立改编目标节点",
|
},
|
||||||
"build_bible_node": "剧本圣经构建节点",
|
{
|
||||||
"episode_create_node": "单集创作节点",
|
"name": "script_analysis_node",
|
||||||
"end_node": "结束节点,任务失败终止时使用,结束后整个工作流将停止"
|
"desc": "原始剧本分析节点",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "strategic_planning_node",
|
||||||
|
"desc": "确立改编目标节点",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "build_bible_node",
|
||||||
|
"desc": "剧本圣经构建节点",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "episode_create_node",
|
||||||
|
"desc": "单集创作节点",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "end_node",
|
||||||
|
"desc": "结束节点,任务失败终止时使用,结束后整个工作流将停止"
|
||||||
}
|
}
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
scriptAnalysisAgent = ScriptAnalysisAgent(
|
self.scriptAnalysisAgent = ScriptAnalysisAgent(
|
||||||
tools=[],
|
tools=[],
|
||||||
SchedulerList=[
|
SchedulerList=[
|
||||||
{
|
{
|
||||||
"scheduler_node": "调度智能体节点",
|
"name": "scheduler_node",
|
||||||
|
"desc": "调度智能体节点",
|
||||||
}
|
}
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
strategicPlanningAgent = StrategicPlanningAgent(
|
self.strategicPlanningAgent = StrategicPlanningAgent(
|
||||||
tools=[],
|
tools=[],
|
||||||
SchedulerList=[
|
SchedulerList=[
|
||||||
{
|
{
|
||||||
"scheduler_node": "调度智能体节点",
|
"name": "scheduler_node",
|
||||||
|
"desc": "调度智能体节点",
|
||||||
}
|
}
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
buildBibleAgent = BuildBibleAgent(
|
self.buildBibleAgent = BuildBibleAgent(
|
||||||
tools=[],
|
tools=[],
|
||||||
SchedulerList=[
|
SchedulerList=[
|
||||||
{
|
{
|
||||||
"scheduler_node": "调度智能体节点",
|
"name": "scheduler_node",
|
||||||
|
"desc": "调度智能体节点",
|
||||||
}
|
}
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
episodeCreate = EpisodeCreateAgent(
|
self.episodeCreate = EpisodeCreateAgent(
|
||||||
tools=[],
|
tools=[],
|
||||||
SchedulerList=[
|
SchedulerList=[
|
||||||
{
|
{
|
||||||
"scheduler_node": "调度智能体节点",
|
"name": "scheduler_node",
|
||||||
|
"desc": "调度智能体节点",
|
||||||
}
|
}
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
|
||||||
# 创建状态图
|
# 创建状态图
|
||||||
|
logger.info("创建状态图")
|
||||||
workflow = StateGraph(ScriptwriterState, input_schema=InputState, output_schema=OutputState)
|
workflow = StateGraph(ScriptwriterState, input_schema=InputState, output_schema=OutputState)
|
||||||
|
|
||||||
# 添加节点
|
# 添加节点
|
||||||
|
logger.info("添加节点")
|
||||||
workflow.add_node("scheduler_node", self.scheduler_node)
|
workflow.add_node("scheduler_node", self.scheduler_node)
|
||||||
workflow.add_node("script_analysis_node", self.script_analysis_node)
|
workflow.add_node("script_analysis_node", self.script_analysis_node)
|
||||||
workflow.add_node("strategic_planning_node", self.strategic_planning_node)
|
workflow.add_node("strategic_planning_node", self.strategic_planning_node)
|
||||||
workflow.add_node("build_bible_node", self.build_bible_node)
|
workflow.add_node("build_bible_node", self.build_bible_node)
|
||||||
workflow.add_node("episode_create_node", self.episode_create_node)
|
workflow.add_node("episode_create_node", self.episode_create_node)
|
||||||
workflow.add_node("end_node", self.end_node)
|
workflow.add_node("end_node", self.end_node)
|
||||||
|
workflow.add_node("pause_node", self.pause_node)
|
||||||
|
|
||||||
# 添加边
|
# 添加边
|
||||||
workflow.set_entry_point("scheduler_node")
|
workflow.set_entry_point("scheduler_node")
|
||||||
@ -167,13 +211,14 @@ class ScriptwriterGraph:
|
|||||||
"episode_create_node": "episode_create_node",
|
"episode_create_node": "episode_create_node",
|
||||||
# 用户确认和暂停逻辑在这里处理,不需要单独的边
|
# 用户确认和暂停逻辑在这里处理,不需要单独的边
|
||||||
"end_node": "end_node",
|
"end_node": "end_node",
|
||||||
|
"pause_node": "pause_node",
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
workflow.add_edge("end_node", END)
|
workflow.add_edge("end_node", END)
|
||||||
|
|
||||||
# 编译图
|
# 编译图
|
||||||
self.graph = workflow.compile()
|
self.graph = workflow.compile(checkpointer=self.memory)
|
||||||
logger.info("工作流图构建完成")
|
logger.info("工作流图构建完成")
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
@ -182,10 +227,44 @@ class ScriptwriterGraph:
|
|||||||
|
|
||||||
# --- 定义图中的节点 ---
|
# --- 定义图中的节点 ---
|
||||||
async def scheduler_node(self, state: ScriptwriterState)-> ScriptwriterState:
|
async def scheduler_node(self, state: ScriptwriterState)-> ScriptwriterState:
|
||||||
"""第一步:初步沟通,请求剧本"""
|
"""调度节点"""
|
||||||
session_id = state.get("session_id", "")
|
try:
|
||||||
|
session_id = state.get("session_id", "")
|
||||||
return {}
|
from_type = state.get("from_type", "")
|
||||||
|
input_data = state.get("input_data", [])
|
||||||
|
logger.info(f"调度节点 {session_id} 输入参数: {input_data} from_type:{from_type}")
|
||||||
|
reslut = await self.schedulerAgent.ainvoke(state)
|
||||||
|
ai_message_str = reslut['messages'][-1].content
|
||||||
|
ai_message = json.loads(ai_message_str)
|
||||||
|
logger.info(f"调度节点结果: {ai_message}")
|
||||||
|
step:str = ai_message.get('step', '')
|
||||||
|
status:str = ai_message.get('status', '')
|
||||||
|
next_agent:str = ai_message.get('agent', '')
|
||||||
|
return_message:str = ai_message.get('message', '')
|
||||||
|
retry_count:int = int(ai_message.get('retry_count', '0'))
|
||||||
|
next_node:str = ai_message.get('node', 'pause_node')
|
||||||
|
if next_node == 'scheduler_node':
|
||||||
|
# 返回自身 代表暂停
|
||||||
|
print(f"调度节点 暂停等待")
|
||||||
|
return {
|
||||||
|
"agent_message": return_message,
|
||||||
|
}
|
||||||
|
else:
|
||||||
|
return {
|
||||||
|
"next_node":next_node,
|
||||||
|
"workflow_step":step,
|
||||||
|
"workflow_status":status,
|
||||||
|
# "workflow_reason":return_message,
|
||||||
|
"workflow_retry_count":retry_count,
|
||||||
|
"agent_message":return_message,
|
||||||
|
}
|
||||||
|
except Exception as e:
|
||||||
|
return {
|
||||||
|
"next_node":'end_node',
|
||||||
|
"agent_message": "执行失败",
|
||||||
|
"error": str(e) or '系统错误,工作流已终止',
|
||||||
|
'status':'failed',
|
||||||
|
}
|
||||||
|
|
||||||
async def script_analysis_node(self, state: ScriptwriterState)-> ScriptwriterState:
|
async def script_analysis_node(self, state: ScriptwriterState)-> ScriptwriterState:
|
||||||
"""第二步:诊断分析与资产评估"""
|
"""第二步:诊断分析与资产评估"""
|
||||||
@ -218,20 +297,32 @@ class ScriptwriterGraph:
|
|||||||
episode_list = []
|
episode_list = []
|
||||||
return {"episode_list": episode_list}
|
return {"episode_list": episode_list}
|
||||||
|
|
||||||
|
async def pause_node(self, state: ScriptwriterState)-> ScriptwriterState:
|
||||||
|
""" 暂停节点 处理并完成所有数据状态 """
|
||||||
|
print(f"langgraph 暂停等待")
|
||||||
|
return {
|
||||||
|
"session_id": state.get("session_id", ""),
|
||||||
|
"status": state.get('status', ''),
|
||||||
|
"error": state.get('error', ''),
|
||||||
|
"agent_message": state.get('agent_message', '')
|
||||||
|
}
|
||||||
|
|
||||||
async def end_node(self, state: ScriptwriterState)-> OutputState:
|
async def end_node(self, state: ScriptwriterState)-> OutputState:
|
||||||
""" 结束节点 处理并完成所有数据状态 """
|
""" 结束节点 处理并完成所有数据状态 """
|
||||||
print(f"langgraph 所有任务完成")
|
print(f"langgraph 所有任务完成")
|
||||||
return {
|
return {
|
||||||
"session_id": state.get("session_id", ""),
|
"session_id": state.get("session_id", ""),
|
||||||
"status": "",
|
"status": state.get('status', ''),
|
||||||
"error": "",
|
"error": state.get('error', ''),
|
||||||
|
"agent_message": state.get('agent_message', ''),
|
||||||
}
|
}
|
||||||
|
|
||||||
async def run(self, input_data: Dict[str, Any]) -> OutputState:
|
async def run(self, session_id: str, input_data: list[AnyMessage], thread_id: str|None = None) -> OutputState:
|
||||||
"""运行工作流
|
"""运行工作流
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
session_id: 会话ID
|
||||||
input_data: 输入数据
|
input_data: 输入数据
|
||||||
|
thread_id: 线程ID
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
工作流执行结果
|
工作流执行结果
|
||||||
@ -239,44 +330,102 @@ class ScriptwriterGraph:
|
|||||||
try:
|
try:
|
||||||
logger.info("开始运行智能编剧工作流")
|
logger.info("开始运行智能编剧工作流")
|
||||||
|
|
||||||
# # 初始化状态
|
# 配置包含线程 ID
|
||||||
# initial_state: InputState = {
|
config:RunnableConfig = {"configurable": {"thread_id": thread_id}}
|
||||||
# 'input_data': input_data,
|
# 初始化状态
|
||||||
# 'session_id': input_data.get('session_id', ''),
|
initial_state: InputState = {
|
||||||
# 'max_iterations': input_data.get('max_iterations', 3),
|
'input_data': input_data,
|
||||||
# 'batch_info': input_data.get('batch_info', {})
|
'session_id': session_id,
|
||||||
# }
|
'from_type': 'user',
|
||||||
|
}
|
||||||
|
|
||||||
# # 运行工作流
|
# 运行工作流
|
||||||
# if self.graph is None:
|
if self.graph is None:
|
||||||
# raise RuntimeError("工作流图未正确初始化")
|
raise RuntimeError("工作流图未正确初始化")
|
||||||
|
|
||||||
# result = await self.graph.ainvoke(initial_state)
|
result = await self.graph.ainvoke(
|
||||||
|
initial_state,
|
||||||
|
config,
|
||||||
|
# stream_mode='values'
|
||||||
|
)
|
||||||
# logger.info(f"工作流执行结果: {result}")
|
# logger.info(f"工作流执行结果: {result}")
|
||||||
# if not result:
|
if not result:
|
||||||
# raise ValueError("工作流执行结果为空")
|
raise ValueError("工作流执行结果为空")
|
||||||
# # 保存到记忆
|
|
||||||
# self.memory.save_workflow_result(result)
|
|
||||||
|
|
||||||
# # 构造输出状态
|
# 构造输出状态
|
||||||
# output_result: OutputState = {
|
output_result: OutputState = {
|
||||||
# 'script': result.get('script'),
|
'session_id': result.get('session_id', ''),
|
||||||
# 'adjustment': result.get('adjustment'),
|
'status': result.get('status', ''),
|
||||||
# 'error': result.get('error'),
|
'error': result.get('error', ''),
|
||||||
# 'iteration_count': result.get('iteration_count', 0)
|
'agent_message': result.get('agent_message', ''),
|
||||||
# }
|
|
||||||
output_result:OutputState = {
|
|
||||||
'session_id': "",
|
|
||||||
'status': 'completed',
|
|
||||||
'error': '',
|
|
||||||
}
|
}
|
||||||
logger.info("智能编剧工作流运行完成")
|
logger.info("智能编剧工作流运行完成")
|
||||||
return output_result
|
return output_result
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"运行工作流失败: {e}")
|
logger.error(f"运行工作流失败: {e}")
|
||||||
|
import traceback
|
||||||
|
traceback.print_exc()
|
||||||
raise
|
raise
|
||||||
|
|
||||||
|
|
||||||
|
async def get_checkpoint_history(self, thread_id: str):
|
||||||
|
"""获取检查点历史"""
|
||||||
|
config:RunnableConfig = {"configurable": {"thread_id": thread_id}}
|
||||||
|
|
||||||
|
try:
|
||||||
|
history_generator = self.memory.list(config, limit=10)
|
||||||
|
|
||||||
|
print("正在获取检查点历史...")
|
||||||
|
|
||||||
|
# 使用列表推导式或 for 循环来收集所有检查点
|
||||||
|
history = list(history_generator)
|
||||||
|
|
||||||
|
print(f"找到 {len(history)} 个检查点:")
|
||||||
|
|
||||||
|
for i, checkpoint_tuple in enumerate(history):
|
||||||
|
# checkpoint_tuple 包含 config, checkpoint, metadata 等属性
|
||||||
|
# print(f" - ID: {checkpoint_tuple}")
|
||||||
|
checkpoint_data = checkpoint_tuple.checkpoint
|
||||||
|
metadata = checkpoint_tuple.metadata
|
||||||
|
print(f"检查点 {i+1}:")
|
||||||
|
print(f" - ID: {checkpoint_data.get('id', 'N/A')}")
|
||||||
|
print(f" - 状态: {checkpoint_data.get('channel_values', {})}")
|
||||||
|
print(f" - 元数据: {metadata}")
|
||||||
|
print("-" * 50)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
print(f"获取历史记录时出错: {e}")
|
||||||
|
|
||||||
|
def resume_from_checkpoint(self, thread_id: str, checkpoint_id: str):
|
||||||
|
"""从检查点恢复执行"""
|
||||||
|
config:RunnableConfig = {"configurable": {"thread_id": thread_id}}
|
||||||
|
|
||||||
|
if checkpoint_id:
|
||||||
|
config["configurable"]["checkpoint_id"] = checkpoint_id
|
||||||
|
|
||||||
|
try:
|
||||||
|
# 获取 CheckpointTuple 对象
|
||||||
|
checkpoint_tuple = self.memory.get_tuple(config)
|
||||||
|
|
||||||
|
if checkpoint_tuple:
|
||||||
|
# 直接通过属性访问,而不是解包
|
||||||
|
checkpoint_data = checkpoint_tuple.checkpoint
|
||||||
|
metadata = checkpoint_tuple.metadata
|
||||||
|
print(f"从检查点恢复:")
|
||||||
|
print(f" - 检查点 ID: {checkpoint_data.get('id', 'N/A')}")
|
||||||
|
print(f" - 状态: {checkpoint_data.get('channel_values', {})}")
|
||||||
|
print(f" - 元数据: {metadata}")
|
||||||
|
return checkpoint_data.get('channel_values', {})
|
||||||
|
else:
|
||||||
|
print(f"未找到线程 {thread_id} 的检查点")
|
||||||
|
return None
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
print(f"恢复检查点时出错: {e}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
def get_graph_visualization(self) -> str:
|
def get_graph_visualization(self) -> str:
|
||||||
"""获取工作流图的可视化表示
|
"""获取工作流图的可视化表示
|
||||||
|
|
||||||
@ -285,8 +434,31 @@ class ScriptwriterGraph:
|
|||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
if self.graph:
|
if self.graph:
|
||||||
return str(self.graph)
|
with open('graph_visualization.png', 'wb') as f:
|
||||||
|
f.write(self.graph.get_graph().draw_mermaid_png())
|
||||||
|
print("图片已保存为 graph_visualization.png")
|
||||||
return "工作流图未初始化"
|
return "工作流图未初始化"
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"获取图可视化失败: {e}")
|
logger.error(f"获取图可视化失败: {e}")
|
||||||
return f"获取图可视化失败: {e}"
|
return f"获取图可视化失败: {e}"
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
import asyncio
|
||||||
|
|
||||||
|
async def main():
|
||||||
|
print("测试")
|
||||||
|
graph = ScriptwriterGraph()
|
||||||
|
print("创建完成")
|
||||||
|
# graph.get_graph_visualization()
|
||||||
|
# print("可视化完成")
|
||||||
|
# 运行工作流
|
||||||
|
session_id = "68c2c2915e5746343301ef71"
|
||||||
|
result = await graph.run(
|
||||||
|
session_id,
|
||||||
|
[HumanMessage(content="你好编剧,我想写小说!")],
|
||||||
|
session_id
|
||||||
|
)
|
||||||
|
print(f"最终结果: {result}")
|
||||||
|
|
||||||
|
asyncio.run(main())
|
||||||
|
|||||||
@ -1,30 +0,0 @@
|
|||||||
"""工作流记忆管理模块
|
|
||||||
|
|
||||||
该模块负责管理智能编剧系统工作流的记忆存储和检索。
|
|
||||||
"""
|
|
||||||
|
|
||||||
import sys
|
|
||||||
import os
|
|
||||||
from typing import Dict, Any, List, Optional
|
|
||||||
from datetime import datetime
|
|
||||||
import json
|
|
||||||
from database import client # type: ignore
|
|
||||||
from langgraph.checkpoint.mongodb import MongoDBSaver
|
|
||||||
|
|
||||||
# 添加项目根目录到路径
|
|
||||||
sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.dirname(__file__)))))
|
|
||||||
from agentgraph.utils.logger import get_logger
|
|
||||||
|
|
||||||
logger = get_logger(__name__)
|
|
||||||
|
|
||||||
DB_NAME = "langgraph_memory_db"
|
|
||||||
|
|
||||||
class WorkflowMemory:
|
|
||||||
"""工作流记忆管理类
|
|
||||||
|
|
||||||
负责管理工作流执行过程中的状态存储、检索和历史记录。
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self):
|
|
||||||
"""初始化工作流记忆管理器"""
|
|
||||||
self.memory = MongoDBSaver(client, db_name=DB_NAME)
|
|
||||||
Loading…
x
Reference in New Issue
Block a user