智能体行为修改 任务列表模式
This commit is contained in:
parent
b28af68a52
commit
1c9012b08a
@ -45,13 +45,6 @@ DefaultAgentPrompt = f"""
|
|||||||
{{
|
{{
|
||||||
"type":'沟通',//回复类型: 沟通:需要跟用户确认或继续沟通时的类型;输出:沟通足够最终给出`改编思路`时的类型;
|
"type":'沟通',//回复类型: 沟通:需要跟用户确认或继续沟通时的类型;输出:沟通足够最终给出`改编思路`时的类型;
|
||||||
"message":'',//回复给用户的话
|
"message":'',//回复给用户的话
|
||||||
"adaptation_ideas":'',//`改编思路`内容,在type为`输出`时才会有值
|
|
||||||
"script_bible":{{//剧本圣经 只有type=输出时才返回,并且只返回有修改的子项,比如只修改了`核心大纲`和`总人物表`, script_bible中就只有core_outline和character_list两个字段;
|
|
||||||
"core_outline":'',//核心大纲
|
|
||||||
"character_profile":'',//核心人物小传
|
|
||||||
"core_event_timeline":'',//重大事件时间线
|
|
||||||
"character_list":'',//总人物表
|
|
||||||
}},
|
|
||||||
"episodes":[ //剧集内容列表 只有type=输出时才返回
|
"episodes":[ //剧集内容列表 只有type=输出时才返回
|
||||||
{{
|
{{
|
||||||
"number":1, //剧集编号(从1开始),只能是`指定创作集数`中的一个
|
"number":1, //剧集编号(从1开始),只能是`指定创作集数`中的一个
|
||||||
|
|||||||
@ -1,153 +0,0 @@
|
|||||||
"""
|
|
||||||
调度智能体 负责接收和分析用户的提示词,并调用智能调度其他智能体来处理工作
|
|
||||||
"""
|
|
||||||
|
|
||||||
from langgraph.graph import StateGraph
|
|
||||||
from langgraph.prebuilt import create_react_agent
|
|
||||||
from langgraph.graph.state import CompiledStateGraph
|
|
||||||
from utils.logger import get_logger
|
|
||||||
|
|
||||||
logger = get_logger(__name__)
|
|
||||||
|
|
||||||
# 默认调度器列表
|
|
||||||
DefaultSchedulerList = []
|
|
||||||
|
|
||||||
# 默认代理提示词
|
|
||||||
DefaultAgentPrompt = f"""
|
|
||||||
# 角色 (Persona)
|
|
||||||
你不是一个普通的编剧,你是一位在短剧市场身经百战、爆款频出的**“顶级短剧改编专家”与“爆款操盘手”**。
|
|
||||||
你的核心人设与专长:
|
|
||||||
极致爽点制造机: 你对观众的“爽点”G点有着鬣狗般的嗅觉。你的天职就是找到、放大、并以最密集的节奏呈现“打脸”、“逆袭”、“揭秘”、“宠溺”等情节。
|
|
||||||
人物标签化大师: 你深知在短剧中,模糊等于无效。你擅长将人物的核心欲望和性格特点极致化、标签化,让观众在3秒内记住主角,5秒内恨上反派。
|
|
||||||
情绪过山车设计师: 你的剧本就像过山车。开篇即俯冲,5秒一反转,10秒一高潮,结尾必留下一个让人抓心挠肝的钩子。你为观众提供的是极致的情绪体验。
|
|
||||||
网络梗语言学家: 你的台词充满了网感和“梗”,既能推动剧情,又能引发观众的共鸣和吐槽欲。对话追求高信息密度,不说一句废话。
|
|
||||||
你的沟通风格:自信、犀利、直击要害,同时又能清晰地解释你每一个改编决策背后的商业逻辑和观众心理。
|
|
||||||
|
|
||||||
# 任务总体步骤描述
|
|
||||||
1. 查找并确认原始剧本已就绪
|
|
||||||
2. 分析原始剧本得出`诊断与资产评估`,需要用户确认可以继续下一步,否则协助用户完成修改
|
|
||||||
3. 根据`诊断与资产评估`确定`改编思路`,需要用户确认可以继续下一步,否则协助用户完成修改
|
|
||||||
4. 根据`改编思路`生成`剧本圣经`,需要用户确认可以继续下一步,否则协助用户完成修改
|
|
||||||
5. 根据`改编思路`和`剧本圣经`持续剧集创作,单次执行3-5集的创建,直至完成全部剧集。
|
|
||||||
6. 注意步骤具有上下级关系,且不能跳过。但是后续步骤可返回触发前面的任务:如生成单集到第3集后,用户提出要修改某个角色,此时应当返回第4步,并协助用户进行修改与确认;完成修改后重新执行第5步,即从第一集开始重新创作一遍;
|
|
||||||
|
|
||||||
步骤中对应的阶段如下:
|
|
||||||
wait_for_input: 等待剧本阶段,查询到`原始剧本`存在并分析到用户确认后进入下一阶段
|
|
||||||
script_analysis: 原始剧本分析阶段,查询到`诊断与资产评估`存在并分析到用户确认后进入下一阶段
|
|
||||||
strategic_planning: 确立改编目标阶段,查询到`改编思路`存在并分析到用户确认后进入下一阶段
|
|
||||||
build_bible: 剧本圣经构建阶段,查询到`剧本圣经`存在并分析到用户确认后进入下一阶段
|
|
||||||
episode_create_loop: 剧集创作阶段,查询`剧集创作情况`并分析到已完成所有剧集的创作后进入下一阶段
|
|
||||||
finish: 所有剧集创作已完成,用户确认后结束任务,用户需要修改则回退到适合的步骤进行修改并重新执行后续阶段
|
|
||||||
|
|
||||||
***除了finish和wait_for_input之外的阶段都需要交给对应的智能体去处理***
|
|
||||||
***episode_create_loop阶段是一个循环阶段,每次循环需要你通过工具方法`剧集创作情况`来判断是否所有剧集都已创作完成,以及需要创作智能体单次创作的集数(通常是3-5集), 该集数为`指定创作集数`,需要添加到返回参数中***
|
|
||||||
|
|
||||||
# 智能体职责介绍
|
|
||||||
***调度智能体*** 名称:`scheduler` 描述:你自己,需要用户确认反馈时返回自身,并把状态设置成waiting;
|
|
||||||
***原始剧本分析 智能体*** 名称:`script_analysis` 描述:构建`诊断与资产评估`;内容包括:故事内核诊断、可继承的宝贵资产(高光情节、神来之笔对白、独特人设闪光点)、以及核心问题与初步改编建议。用户需要对`诊断与资产评估`进行修改都直接交给该智能体;
|
|
||||||
***确立改编目标 智能体*** 名称:`strategic_planning` 描述:构建`改编思路`;此文件将作为所有后续改编的最高指导原则。用户需要对`改编思路`进行修改都直接交给该智能体;
|
|
||||||
***剧本圣经构建 智能体*** 名称:`build_bible` 描述:构建`剧本圣经`,剧本圣经具体包括了这几个部分:核心大纲, 核心人物小传, 重大事件时间线, 总人物表; 用户需要对`剧本圣经`的每一个部分进行修改都直接交给该智能体;
|
|
||||||
***剧集创作 智能体*** 名称:`episode_create` 描述:构建剧集的具体创作;注意该智能体仅负责剧集的创作;对于某一集的具体修改直接交给该智能体;
|
|
||||||
|
|
||||||
***注意:智能体调用后最终会返回再次请求到你,你需要根据智能体的处理结果来决定下一步***
|
|
||||||
***注意:`智能体调用` 不是工具方法的使用,而是在返回数据中把agent属性指定为要调用的智能体名称***
|
|
||||||
|
|
||||||
# 工具使用
|
|
||||||
上述智能体职责中提及的输出内容,都有对应的工具可供你调用进行查看;他们的查询工具名称分别对应如下:
|
|
||||||
原始剧本是否存在: `QueryOriginalScript`
|
|
||||||
诊断与资产评估是否存在: `QueryDiagnosisAndAssessment`
|
|
||||||
改编思路是否存在: `QueryAdaptationIdeas`
|
|
||||||
剧本圣经是否存在: `QueryScriptBible`
|
|
||||||
核心大纲是否存在: `QueryCoreOutline`
|
|
||||||
核心人物小传是否存在: `QueryCharacterProfile`
|
|
||||||
重大事件时间线是否存在: `QueryCoreEventTimeline`
|
|
||||||
总人物表是否存在: `QueryCharacterList`
|
|
||||||
剧集创作情况: `QueryEpisodeCount`
|
|
||||||
|
|
||||||
***注意:工具使用是需要你调用工具方法的;但是大多数情况下,你不需要查询文本的具体内容,只需要查询存在与否即可***
|
|
||||||
|
|
||||||
***每次用户的输入都会携带当前`总任务的进度与任务状态`,注意查看并分析是否应该回复用户等待或提醒用户确认继续下一步***
|
|
||||||
# 总任务的进度与任务状态数据结构为 {{"step": "waiting_script", "status": "running", "from_type":"user", "reason": "waiting_script", "retry_count": 0, "query_args":{{}}}}
|
|
||||||
|
|
||||||
step: 阶段名称
|
|
||||||
wait_for_input: 等待用户提供原始剧本
|
|
||||||
script_analysis: 原始剧本分析
|
|
||||||
strategic_planning: 确立改编目标
|
|
||||||
build_bible: 剧本圣经构建
|
|
||||||
episode_create_loop: 剧集创作
|
|
||||||
finish: 所有剧集创作完成
|
|
||||||
|
|
||||||
status: 当前阶段的状态
|
|
||||||
waiting: 等待用户反馈、确认
|
|
||||||
running: 进行中
|
|
||||||
failed: 失败
|
|
||||||
completed: 完成
|
|
||||||
|
|
||||||
"from_type": 本次请求来着哪里
|
|
||||||
user: 用户
|
|
||||||
agent: 智能体返回
|
|
||||||
|
|
||||||
"reason": 失败原因,仅在`status`为`failed`时返回
|
|
||||||
字符串内容
|
|
||||||
|
|
||||||
"retry_count": 重试次数
|
|
||||||
|
|
||||||
"query_args": 用于调用工具方法的参数,可能包括的字段有:
|
|
||||||
"session_id": 会话ID,可用于查询`原始剧本`
|
|
||||||
|
|
||||||
# 职责
|
|
||||||
分析用户输入与`总任务的进度与任务状态`,以下是几种情况的示例:
|
|
||||||
1 `wait_for_input` 向用户问好,并介绍你作为“爆款短剧操盘手”的身份和专业工作流程,礼貌地请用户提供需要改编的原始剧本。如果用户没有提供原始剧本,你将持续友好地提醒,此时状态始终为waiting,直到获取原始剧本为止。从用户提交的中可以获取到session_id的时候,需要调用 `QueryOriginalScript` 工具来查询原始剧本是否存在。
|
|
||||||
2 `script_analysis` 读取到原始剧本并从输入中分析出可以继续后进入,调用`原始剧本分析 智能体`继续后续工作;running时,礼貌回复用户并提醒用户任务真正进行中;completed代表任务完成,此时可等待用户反馈;直到跟用户确认可以进行下一步后再继续后续任务;
|
|
||||||
3 `strategic_planning` 根据`诊断与资产评估`的结果,调用`确立改编目标 智能体`,并返回结果。
|
|
||||||
4 `build_bible` 根据`改编思路`的结果,调用`剧本圣经构建 智能体`,并返回结果。
|
|
||||||
5 `episode_create_loop` 根据`剧本圣经`的结果,调用`剧集创作 智能体`
|
|
||||||
5 `finish` 所有剧集完成后设置为该状态,但是不要返回node==end_node,因为用户还可以继续输入来进一步修改产出内容;
|
|
||||||
|
|
||||||
***当任意一个智能体返回失败时,你需要分析reason字段中的内容,来决定是否进行重试,如果需要重试则给retry_count加1,并交给失败的那个智能体重试一次;如果retry_count超过了3次,或者失败原因不适合重试则反馈给用户说任务失败了,请稍后再试***
|
|
||||||
|
|
||||||
请严格按照下列JSON结构返回数据,不要有其他任何多余的信息和描述:
|
|
||||||
{{
|
|
||||||
"step": "阶段名称",//取值范围在上述 step的描述中 不可写其他值
|
|
||||||
"status": "当前阶段的状态",//取值范围在上述 status的描述中 不可写其他值
|
|
||||||
"agent":'',//分析后得出由哪个智能体继续任务,此处为智能体名称;如果需要继续与用户交互或仅需要回复用户则为空字符串
|
|
||||||
"message":'',//回复给用户的内容,注意,仅在你与用户沟通时才返回,其他情况下不返回。比如用户的需求是要交给其他智能体处理时,这个属性应该为空
|
|
||||||
"retry_count":0,//重试次数
|
|
||||||
"episode_create_num":[1,2,3],//指定创作集数 仅在episode_create_loop阶段会返回,内容是数组,数组中每一项是指定创作的剧集编号(从1开始);
|
|
||||||
"node":'',//下一个节点名称,根据指定的agent名称,从取值范围列表中选择一个节点名称返回
|
|
||||||
}}
|
|
||||||
|
|
||||||
"""
|
|
||||||
|
|
||||||
def create_agent_prompt(prompt, SchedulerList):
|
|
||||||
"""创建代理提示词的辅助函数"""
|
|
||||||
if not SchedulerList or len(SchedulerList) == 0: return prompt
|
|
||||||
node_list = [f"{node['name']}:{node['desc']}" for node in SchedulerList]
|
|
||||||
return f"""
|
|
||||||
{prompt} \n
|
|
||||||
下面返回数据中node字段的取值范围列表([{{名称:描述}}]),请根据你的分析结果选择一个节点名称返回:
|
|
||||||
{node_list} \n
|
|
||||||
"""
|
|
||||||
|
|
||||||
|
|
||||||
class SchedulerAgent(CompiledStateGraph):
|
|
||||||
"""智能调度智能体类
|
|
||||||
|
|
||||||
该类负责接收用户的提示词,并调用其他智能体来处理工作。
|
|
||||||
"""
|
|
||||||
def __new__(cls, llm=None, tools=[], SchedulerList=None):
|
|
||||||
"""创建并返回create_react_agent创建的对象"""
|
|
||||||
# 处理默认参数
|
|
||||||
if llm is None:
|
|
||||||
from tools.llm.huoshan_langchain import HuoshanChatModel
|
|
||||||
llm = HuoshanChatModel()
|
|
||||||
|
|
||||||
if SchedulerList is None:
|
|
||||||
SchedulerList = DefaultSchedulerList
|
|
||||||
|
|
||||||
# 创建并返回代理对象
|
|
||||||
return create_react_agent(
|
|
||||||
model=llm,
|
|
||||||
tools=tools,
|
|
||||||
prompt=create_agent_prompt(prompt=DefaultAgentPrompt, SchedulerList=SchedulerList),
|
|
||||||
)
|
|
||||||
@ -64,10 +64,19 @@ DefaultAgentPrompt = f"""
|
|||||||
总人物表是否存在: `QueryCharacterList`
|
总人物表是否存在: `QueryCharacterList`
|
||||||
剧集创作情况: `QueryEpisodeCount`
|
剧集创作情况: `QueryEpisodeCount`
|
||||||
|
|
||||||
***注意:工具使用是需要你调用工具方法的;但是大多数情况下,你不需要查询文本的具体内容,只需要查询存在与否即可***
|
***注意:工具使用是需要你调用工具方法的;大多数情况下同一个方法只需要调用一次***
|
||||||
|
|
||||||
***每次用户的输入都会携带当前`总任务的进度与任务状态`,注意查看并分析是否应该回复用户等待或提醒用户确认继续下一步***
|
***每次用户的输入都会携带最新的`任务列表`和`工作流参数`,注意查看并分析是否应该回复用户等待或提醒用户确认继续下一步***
|
||||||
# 总任务的进度与任务状态数据结构为 {{"step": "waiting_script", "status": "running", "from_type":"user", "reason": "waiting_script", "retry_count": 0, "query_args":{{}}}}
|
# 工作流参数包含字段如下:
|
||||||
|
"session_id": 会话ID 可用于工具方法的调用
|
||||||
|
"task_index": 当前执行中的任务索引
|
||||||
|
"from_type": 本次请求来着哪里
|
||||||
|
user: 用户
|
||||||
|
agent: 智能体返回
|
||||||
|
|
||||||
|
# 任务列表是一个数组,每项的数据结构如下:
|
||||||
|
"agent": 执行这个任务的智能体名称
|
||||||
|
字符串内容 可为空 为空时表示当前任务不需要调用智能体
|
||||||
|
|
||||||
step: 阶段名称
|
step: 阶段名称
|
||||||
wait_for_input: 等待用户提供原始剧本
|
wait_for_input: 等待用户提供原始剧本
|
||||||
@ -77,26 +86,41 @@ DefaultAgentPrompt = f"""
|
|||||||
episode_create_loop: 剧集创作
|
episode_create_loop: 剧集创作
|
||||||
finish: 所有剧集创作完成
|
finish: 所有剧集创作完成
|
||||||
|
|
||||||
|
"pause": 是否需要暂停 当需要和用户沟通时设置为true,任务会中断等待用户回复
|
||||||
|
|
||||||
status: 当前阶段的状态
|
status: 当前阶段的状态
|
||||||
waiting: 等待用户反馈、确认
|
waiting: 等待用户反馈、确认
|
||||||
running: 进行中
|
running: 进行中
|
||||||
failed: 失败
|
failed: 失败
|
||||||
completed: 完成
|
completed: 完成
|
||||||
|
|
||||||
"from_type": 本次请求来着哪里
|
|
||||||
user: 用户
|
|
||||||
agent: 智能体返回
|
|
||||||
|
|
||||||
"reason": 失败原因,仅在`status`为`failed`时返回
|
"reason": 失败原因,仅在`status`为`failed`时返回
|
||||||
字符串内容
|
字符串内容
|
||||||
|
|
||||||
"retry_count": 重试次数
|
"retry_count": 失败重试次数
|
||||||
|
整数内容
|
||||||
|
"episode_create_num": 指定创作集数 仅在episode_create_loop阶段会返回,内容是数组,数组中每一项是指定创作的剧集编号(从1开始);
|
||||||
|
整数数组
|
||||||
|
|
||||||
"query_args": 用于调用工具方法的参数,可能包括的字段有:
|
# 你的职责
|
||||||
"session_id": 会话ID,可用于查询`原始剧本`
|
分析用户所有的输入信息,以及信息中的`任务列表`、`工作流参数`;
|
||||||
|
首先读取`工作流参数`中的数据,判断from_type的值是user还是agent;
|
||||||
|
如果是user,说明用户是在和你沟通,你需要根据用户的输入来判断下一步;
|
||||||
|
如果是agent,说明是智能体执行完任务后返回的结果,你需要根据智能体的返回结果来判断下一步;根据task_index取出`任务列表`中的当前任务,判断他的状态来决定是否继续列表中的下一个任务;
|
||||||
|
如果当前任务的状态是completed,说明当前任务完成,需要继续列表中的下一个任务;
|
||||||
|
如果当前任务的状态是failed,说明当前任务失败,需要根据失败原因来判断是否需要重试;如果需要重试,需要增加重试次数,并且需要继续列表中的下一个任务;
|
||||||
|
如果当前任务的状态是waiting,说明当前任务等待用户反馈,需要等待用户反馈后继续执行;此时任务中的pause属性需要修改为true;
|
||||||
|
|
||||||
# 职责
|
`任务列表`的生成规则:
|
||||||
分析用户输入与`总任务的进度与任务状态`,以下是几种情况的示例:
|
1 `任务列表`为空时,你需要根据上文中`任务总体步骤描述`生成一个新的任务列表
|
||||||
|
2 执行`任务列表`中的第一个未完成的任务
|
||||||
|
|
||||||
|
当你读取到一个空的任务列表 或者 任务列表中的所有任务都完成或失败时,你需要分析出一个新的任务列表;
|
||||||
|
新的任务列表至少包含一个任务;
|
||||||
|
任务列表中的每个任务代表的是一个后续智能体要执行的任务,其中scheduler代表你自己,大多数情况下这代表了用户回复了你的提问;
|
||||||
|
|
||||||
|
以下是任务列表的几种情况的示例:
|
||||||
|
1 任务列表为空时,你需要生成一个新的任务列表;此时你的分析结果是需要用户
|
||||||
1 `wait_for_input` 向用户问好,并介绍你作为“爆款短剧操盘手”的身份和专业工作流程,礼貌地请用户提供需要改编的原始剧本。如果用户没有提供原始剧本,你将持续友好地提醒,此时状态始终为waiting,直到获取原始剧本为止。从用户提交的中可以获取到session_id的时候,需要调用 `QueryOriginalScript` 工具来查询原始剧本是否存在。
|
1 `wait_for_input` 向用户问好,并介绍你作为“爆款短剧操盘手”的身份和专业工作流程,礼貌地请用户提供需要改编的原始剧本。如果用户没有提供原始剧本,你将持续友好地提醒,此时状态始终为waiting,直到获取原始剧本为止。从用户提交的中可以获取到session_id的时候,需要调用 `QueryOriginalScript` 工具来查询原始剧本是否存在。
|
||||||
2 `script_analysis` 读取到原始剧本并从输入中分析出可以继续后进入,调用`原始剧本分析 智能体`继续后续工作;running时,礼貌回复用户并提醒用户任务真正进行中;completed代表任务完成,此时可等待用户反馈;直到跟用户确认可以进行下一步后再继续后续任务;
|
2 `script_analysis` 读取到原始剧本并从输入中分析出可以继续后进入,调用`原始剧本分析 智能体`继续后续工作;running时,礼貌回复用户并提醒用户任务真正进行中;completed代表任务完成,此时可等待用户反馈;直到跟用户确认可以进行下一步后再继续后续任务;
|
||||||
3 `strategic_planning` 根据`诊断与资产评估`的结果,调用`确立改编目标 智能体`,并返回结果。
|
3 `strategic_planning` 根据`诊断与资产评估`的结果,调用`确立改编目标 智能体`,并返回结果。
|
||||||
@ -108,13 +132,9 @@ DefaultAgentPrompt = f"""
|
|||||||
|
|
||||||
请严格按照下列JSON结构返回数据,不要有其他任何多余的信息和描述:
|
请严格按照下列JSON结构返回数据,不要有其他任何多余的信息和描述:
|
||||||
{{
|
{{
|
||||||
"step": "阶段名称",//取值范围在上述 step的描述中 不可写其他值
|
|
||||||
"status": "当前阶段的状态",//取值范围在上述 status的描述中 不可写其他值
|
|
||||||
"agent":'',//分析后得出由哪个智能体继续任务,此处为智能体名称;如果需要继续与用户交互或仅需要回复用户则为空字符串
|
|
||||||
"message":'',//回复给用户的内容,注意,仅在你与用户沟通时才返回,其他情况下不返回。比如用户的需求是要交给其他智能体处理时,这个属性应该为空
|
"message":'',//回复给用户的内容,注意,仅在你与用户沟通时才返回,其他情况下不返回。比如用户的需求是要交给其他智能体处理时,这个属性应该为空
|
||||||
"retry_count":0,//重试次数
|
"task_list": [] //最新的任务列表
|
||||||
"episode_create_num":[1,2,3],//指定创作集数 仅在episode_create_loop阶段会返回,内容是数组,数组中每一项是指定创作的剧集编号(从1开始);
|
"task_index": 0 //执行中的任务的索引
|
||||||
"node":'',//下一个节点名称,根据指定的agent名称,从取值范围列表中选择一个节点名称返回
|
|
||||||
}}
|
}}
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
|||||||
@ -46,6 +46,7 @@ DefaultAgentPrompt = f"""
|
|||||||
{{
|
{{
|
||||||
"type":'沟通',//回复类型: 沟通:需要跟用户确认或继续沟通时的类型;输出:沟通足够最终给出`改编思路`时的类型;
|
"type":'沟通',//回复类型: 沟通:需要跟用户确认或继续沟通时的类型;输出:沟通足够最终给出`改编思路`时的类型;
|
||||||
"message":'',//回复给用户的话
|
"message":'',//回复给用户的话
|
||||||
|
"total_episode_num":0,//需要创作的总集数
|
||||||
"adaptation_ideas":'',//`改编思路`内容,在type为`输出`时才会有值
|
"adaptation_ideas":'',//`改编思路`内容,在type为`输出`时才会有值
|
||||||
}}
|
}}
|
||||||
"""
|
"""
|
||||||
|
|||||||
173
doc/test.txt
173
doc/test.txt
@ -1,173 +0,0 @@
|
|||||||
为了在你的代码中正确接入火山引擎的工具调用功能,你需要修改 `huoshan_langchain.py` 和 `huoshan.py` 这两个文件,以实现**从 LangChain 工具到火山引擎 API 工具定义的格式转换**,以及**解析和处理来自 API 的工具调用响应**。
|
|
||||||
|
|
||||||
下面是根据火山引擎官方文档和你的代码,我为你整理的完整修改方案。
|
|
||||||
|
|
||||||
-----
|
|
||||||
|
|
||||||
### 第一步:修改 `huoshan_langchain.py`
|
|
||||||
|
|
||||||
这个文件是 LangChain 的封装层,负责连接你的工作流和火山引擎的底层 API。你需要在这里实现 `bind_tools` 和 `_generate` 方法来处理工具调用。
|
|
||||||
|
|
||||||
1. **导入必要的类**:
|
|
||||||
需要添加 `ToolMessage` 和 `ToolCall`,它们是 LangChain 用于表示工具调用结果和工具调用的核心类。
|
|
||||||
|
|
||||||
```python
|
|
||||||
from typing import Any, Dict, Iterable, Iterator, List, Optional, Sequence
|
|
||||||
from langchain_core.callbacks.manager import CallbackManagerForLLMRun
|
|
||||||
from langchain_core.language_models.chat_models import BaseChatModel
|
|
||||||
from langchain_core.messages import BaseMessage, AIMessage, HumanMessage, SystemMessage, ToolMessage, ToolCall
|
|
||||||
from langchain_core.outputs import ChatGeneration, ChatResult
|
|
||||||
from langchain_core.tools import BaseTool
|
|
||||||
from langchain_core.runnables import Runnable
|
|
||||||
from langchain.pydantic_v1 import BaseModel
|
|
||||||
from pydantic import Field
|
|
||||||
from volcenginesdkarkruntime.types.chat import ChatCompletionToolParam
|
|
||||||
from api.huoshan import HuoshanAPI
|
|
||||||
import json
|
|
||||||
```
|
|
||||||
|
|
||||||
2. **实现 `bind_tools` 方法**:
|
|
||||||
这个方法是 LangChain 用于将工具定义传递给你的模型封装。在这里,你需要将 LangChain 的 `BaseTool` 对象转换为火山引擎 API 所需的 `ChatCompletionToolParam` 格式。
|
|
||||||
|
|
||||||
```python
|
|
||||||
class HuoshanChatModel(BaseChatModel):
|
|
||||||
# ... (其他代码不变) ...
|
|
||||||
|
|
||||||
def bind_tools(self, tools: Sequence[BaseTool], **kwargs: Any) -> Runnable:
|
|
||||||
"""将工具绑定到模型,并将其转换为火山引擎API所需的格式。"""
|
|
||||||
tool_definitions = []
|
|
||||||
for tool_item in tools:
|
|
||||||
tool_definitions.append(
|
|
||||||
ChatCompletionToolParam(
|
|
||||||
type="function",
|
|
||||||
function={
|
|
||||||
"name": tool_item.name,
|
|
||||||
"description": tool_item.description,
|
|
||||||
"parameters": tool_item.args_schema.schema() if isinstance(tool_item.args_schema, type(BaseModel)) else tool_item.args_schema
|
|
||||||
}
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
# 返回一个绑定了工具的新实例
|
|
||||||
# 这里我们使用_bind方法,它会返回一个新的Runnable实例
|
|
||||||
return self._bind(tools=tool_definitions, **kwargs)
|
|
||||||
```
|
|
||||||
|
|
||||||
3. **修改 `_convert_messages_to_prompt` 方法**:
|
|
||||||
这个方法需要能够处理 LangChain 的 `ToolMessage` 和 `AIMessage`,并将其转换为火山引擎 API 的消息格式。这对于工具调用的回填和最终回复至关重要。
|
|
||||||
|
|
||||||
```python
|
|
||||||
class HuoshanChatModel(BaseChatModel):
|
|
||||||
# ... (其他代码不变) ...
|
|
||||||
|
|
||||||
def _convert_messages_to_prompt(self, messages: List[BaseMessage]) -> List[Dict]:
|
|
||||||
"""将LangChain消息转换为火山引擎API所需的格式。"""
|
|
||||||
api_messages = []
|
|
||||||
for msg in messages:
|
|
||||||
if isinstance(msg, HumanMessage):
|
|
||||||
api_messages.append({"role": "user", "content": msg.content})
|
|
||||||
elif isinstance(msg, AIMessage):
|
|
||||||
if msg.tool_calls:
|
|
||||||
api_messages.append({
|
|
||||||
"role": "assistant",
|
|
||||||
"content": "",
|
|
||||||
"tool_calls": [{
|
|
||||||
"id": tc.id,
|
|
||||||
"type": "function",
|
|
||||||
"function": {
|
|
||||||
"name": tc.name,
|
|
||||||
"arguments": json.dumps(tc.args)
|
|
||||||
}
|
|
||||||
} for tc in msg.tool_calls]
|
|
||||||
})
|
|
||||||
else:
|
|
||||||
api_messages.append({"role": "assistant", "content": msg.content})
|
|
||||||
elif isinstance(msg, ToolMessage):
|
|
||||||
api_messages.append({
|
|
||||||
"role": "tool",
|
|
||||||
"content": msg.content,
|
|
||||||
"tool_call_id": msg.tool_call_id
|
|
||||||
})
|
|
||||||
return api_messages
|
|
||||||
```
|
|
||||||
|
|
||||||
4. **修改 `_generate` 方法**:
|
|
||||||
这个方法需要调用底层 API,并解析大模型返回的响应,以检查是否包含工具调用。
|
|
||||||
|
|
||||||
```python
|
|
||||||
class HuoshanChatModel(BaseChatModel):
|
|
||||||
# ... (其他代码不变) ...
|
|
||||||
|
|
||||||
def _generate(
|
|
||||||
self,
|
|
||||||
messages: List[BaseMessage],
|
|
||||||
stop: Optional[List[str]] = None,
|
|
||||||
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
|
||||||
**kwargs: Any,
|
|
||||||
) -> ChatResult:
|
|
||||||
if not self._api:
|
|
||||||
raise ValueError("HuoshanAPI未正确初始化")
|
|
||||||
|
|
||||||
api_messages = self._convert_messages_to_prompt(messages)
|
|
||||||
tools = kwargs.get("tools", [])
|
|
||||||
|
|
||||||
response_data = self._api.get_chat_response(messages=api_messages, tools=tools)
|
|
||||||
|
|
||||||
try:
|
|
||||||
message_from_api = response_data.get("choices", [{}])[0].get("message", {})
|
|
||||||
|
|
||||||
tool_calls = message_from_api.get("tool_calls", [])
|
|
||||||
if tool_calls:
|
|
||||||
lc_tool_calls = []
|
|
||||||
for tc in tool_calls:
|
|
||||||
lc_tool_calls.append(ToolCall(
|
|
||||||
name=tc["function"]["name"],
|
|
||||||
args=json.loads(tc["function"]["arguments"]),
|
|
||||||
id=tc.get("id", "")
|
|
||||||
))
|
|
||||||
message = AIMessage(content="", tool_calls=lc_tool_calls)
|
|
||||||
else:
|
|
||||||
content = message_from_api.get("content", "")
|
|
||||||
message = AIMessage(content=content)
|
|
||||||
|
|
||||||
generation = ChatGeneration(message=message)
|
|
||||||
return ChatResult(generations=[generation])
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
raise ValueError(f"处理火山引擎API响应失败: {str(e)}")
|
|
||||||
```
|
|
||||||
|
|
||||||
### 第二步:修改 `huoshan.py`
|
|
||||||
|
|
||||||
这个文件是底层 API 客户端,负责与火山引擎 API 进行通信。你需要修改 `get_chat_response` 方法,使其能够发送 `tools` 参数。
|
|
||||||
|
|
||||||
```python
|
|
||||||
# huoshan.py
|
|
||||||
|
|
||||||
from volcenginesdkarkruntime.types.chat import ChatCompletionToolParam
|
|
||||||
|
|
||||||
class HuoshanAPI:
|
|
||||||
# ... (其他代码不变) ...
|
|
||||||
|
|
||||||
def get_chat_response(
|
|
||||||
self,
|
|
||||||
messages: List[Dict],
|
|
||||||
stream: bool = False,
|
|
||||||
tools: Optional[List[ChatCompletionToolParam]] = None
|
|
||||||
) -> Dict[str, Any]:
|
|
||||||
"""同步获取聊天响应,支持工具调用。"""
|
|
||||||
client = Ark()
|
|
||||||
|
|
||||||
try:
|
|
||||||
completion = client.chat.completions.create(
|
|
||||||
model=self.doubao_seed_1_6_model_id,
|
|
||||||
messages=messages,
|
|
||||||
stream=stream,
|
|
||||||
tools=tools # 传入 tools 参数
|
|
||||||
)
|
|
||||||
return completion.model_dump() # 使用 model_dump() 转换为字典
|
|
||||||
except Exception as e:
|
|
||||||
raise ValueError(f"火山引擎API调用失败: {str(e)}")
|
|
||||||
```
|
|
||||||
|
|
||||||
完成以上修改后,你的 `HuoshanChatModel` 就会支持工具调用,并能与 LangGraph 的智能体框架无缝集成。
|
|
||||||
@ -3,6 +3,7 @@
|
|||||||
该模块定义了智能编剧系统的完整工作流程图,包括各个节点和边的连接关系。
|
该模块定义了智能编剧系统的完整工作流程图,包括各个节点和边的连接关系。
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
from re import T
|
||||||
from typing import TypedDict, Annotated, Dict, Any, List, TypedDict, Optional
|
from typing import TypedDict, Annotated, Dict, Any, List, TypedDict, Optional
|
||||||
|
|
||||||
from langgraph.graph.state import RunnableConfig
|
from langgraph.graph.state import RunnableConfig
|
||||||
@ -12,6 +13,7 @@ from agent.episode_create import EpisodeCreateAgent
|
|||||||
from agent.script_analysis import ScriptAnalysisAgent
|
from agent.script_analysis import ScriptAnalysisAgent
|
||||||
from agent.strategic_planning import StrategicPlanningAgent
|
from agent.strategic_planning import StrategicPlanningAgent
|
||||||
|
|
||||||
|
from langgraph.checkpoint.memory import InMemorySaver
|
||||||
from langchain_core.messages import AnyMessage,HumanMessage
|
from langchain_core.messages import AnyMessage,HumanMessage
|
||||||
from langgraph.graph import StateGraph, START, END
|
from langgraph.graph import StateGraph, START, END
|
||||||
from utils.logger import get_logger
|
from utils.logger import get_logger
|
||||||
@ -26,9 +28,34 @@ from tools.agent.queryDB import QueryOriginalScript
|
|||||||
|
|
||||||
logger = get_logger(__name__)
|
logger = get_logger(__name__)
|
||||||
|
|
||||||
def messages_handler(old_messages, new_messages):
|
def clear_messages(messages):
|
||||||
|
"""清除指定会话的所有消息"""
|
||||||
|
# 剔除历史状态消息
|
||||||
|
exclude_strings = [
|
||||||
|
"---任务状态消息(开始)---",
|
||||||
|
"---原始剧本(开始)---",
|
||||||
|
"---诊断与资产评估报告(开始)---",
|
||||||
|
"---改编思路(开始)---",
|
||||||
|
]
|
||||||
|
messages = [
|
||||||
|
message for message in messages
|
||||||
|
if all(s not in message.content for s in exclude_strings)
|
||||||
|
]
|
||||||
|
# HumanMessage 超过 10 条,删除最早的 1 条
|
||||||
|
if len([message for message in messages if message.type == 'human']) > 10:
|
||||||
|
messages = messages[1:]
|
||||||
|
# SystemMessage 超过 10 条,删除最早的 1 条
|
||||||
|
if len([message for message in messages if message.type == 'system']) > 10:
|
||||||
|
messages = messages[1:]
|
||||||
|
# AIMessage 超过 10 条,删除最早的 1 条
|
||||||
|
if len([message for message in messages if message.type == 'ai']) > 10:
|
||||||
|
messages = messages[1:]
|
||||||
|
return messages
|
||||||
|
|
||||||
|
def messages_handler(old_messages: list[AnyMessage], new_messages: list[AnyMessage]):
|
||||||
"""消息合并方法"""
|
"""消息合并方法"""
|
||||||
return new_messages
|
clear_messages(old_messages)
|
||||||
|
return old_messages + new_messages
|
||||||
|
|
||||||
def replace_value(old_val, new_val):
|
def replace_value(old_val, new_val):
|
||||||
"""值覆盖方法"""
|
"""值覆盖方法"""
|
||||||
@ -48,13 +75,24 @@ class OutputState(TypedDict):
|
|||||||
error: Annotated[str, replace_value]
|
error: Annotated[str, replace_value]
|
||||||
agent_message: Annotated[str, replace_value] # 智能体回复
|
agent_message: Annotated[str, replace_value] # 智能体回复
|
||||||
|
|
||||||
class NodeInfo(TypedDict):
|
class TaskItem(TypedDict):
|
||||||
"""工作流信息"""
|
"""任务列表中的每个任务"""
|
||||||
|
agent: Annotated[str, replace_value] # 执行这个任务的智能体名称
|
||||||
step: Annotated[str, replace_value] # 阶段名称 [wait_for_input,script_analysis,strategic_planning,build_bible,episode_create_loop, finish]
|
step: Annotated[str, replace_value] # 阶段名称 [wait_for_input,script_analysis,strategic_planning,build_bible,episode_create_loop, finish]
|
||||||
status: Annotated[str, replace_value] # 当前阶段的状态 [waiting,running,failed,completed]
|
status: Annotated[str, replace_value] # 当前阶段的状态 [waiting,running,failed,completed]
|
||||||
reason: Annotated[str, replace_value] # 失败原因
|
reason: Annotated[str, replace_value] # 失败原因
|
||||||
retry_count: Annotated[int, replace_value] # 重试次数
|
retry_count: Annotated[int, replace_value] # 重试次数
|
||||||
from_type: Annotated[str, replace_value] # 本次请求来着哪里 [user, agent]
|
pause: Annotated[bool, replace_value] # 是否暂停
|
||||||
|
episode_create_num: Annotated[List[int], replace_value] # 指定创作集数 仅在episode_create_loop阶段会返回,内容是数组,数组中每一项是指定创作的剧集编号(从1开始);
|
||||||
|
|
||||||
|
class WorkflowInfo(TypedDict):
|
||||||
|
"""总工作流信息"""
|
||||||
|
is_original_script: Annotated[bool, replace_value] # 是否已提交原始剧本
|
||||||
|
is_script_analysis: Annotated[bool, replace_value] # 是否已生成 诊断与资产评估报告
|
||||||
|
is_strategic_planning: Annotated[bool, replace_value] # 是否已生成 改编思路
|
||||||
|
is_build_bible: Annotated[bool, replace_value] # 是否已生成 剧本圣经
|
||||||
|
is_episode_create_loop: Annotated[bool, replace_value] # 是否已生成 剧集生成循环
|
||||||
|
is_all_episode_created: Annotated[bool, replace_value] # 是否已生成 全部剧集
|
||||||
|
|
||||||
|
|
||||||
class ScriptwriterState(TypedDict, total=False):
|
class ScriptwriterState(TypedDict, total=False):
|
||||||
@ -64,21 +102,24 @@ class ScriptwriterState(TypedDict, total=False):
|
|||||||
session_id: Annotated[str, replace_value]
|
session_id: Annotated[str, replace_value]
|
||||||
from_type: Annotated[str, replace_value] # 本次请求来着哪里 [user, agent]
|
from_type: Annotated[str, replace_value] # 本次请求来着哪里 [user, agent]
|
||||||
|
|
||||||
# 节点间状态
|
|
||||||
next_node: Annotated[str, replace_value] # 下一个节点
|
|
||||||
workflow_step: Annotated[str, replace_value] # 阶段名称 [wait_for_input,script_analysis,strategic_planning,build_bible,episode_create_loop, finish]
|
|
||||||
workflow_status: Annotated[str, replace_value] # 当前阶段的状态 [waiting,running,failed,completed]
|
|
||||||
workflow_reason: Annotated[str, replace_value] # 失败原因
|
|
||||||
workflow_retry_count: Annotated[int, replace_value] # 重试次数
|
|
||||||
|
|
||||||
# 中间状态
|
# 中间状态
|
||||||
task_list: Annotated[List[Dict[str, Any]], replace_value] # 顺序执行的任务列表
|
workflow_info: Annotated[WorkflowInfo, replace_value] # 顺序执行的任务列表
|
||||||
|
task_list: Annotated[List[TaskItem], replace_value] # 顺序执行的任务列表
|
||||||
|
task_index: Annotated[int, replace_value] # 当前执行中的任务索引
|
||||||
|
|
||||||
# 输出数据
|
# 输出数据
|
||||||
agent_message: Annotated[str, replace_value] # 智能体回复
|
agent_message: Annotated[str, replace_value] # 智能体回复
|
||||||
status: Annotated[str, replace_value]
|
status: Annotated[str, replace_value]
|
||||||
error: Annotated[str, replace_value]
|
error: Annotated[str, replace_value]
|
||||||
|
|
||||||
|
AgentNodeMap = {
|
||||||
|
"scheduler": "scheduler_node",
|
||||||
|
"script_analysis": "script_analysis_node",
|
||||||
|
"strategic_planning": "strategic_planning_node",
|
||||||
|
"build_bible": "build_bible_node",
|
||||||
|
"episode_create": "episode_create_node",
|
||||||
|
}
|
||||||
|
|
||||||
class ScriptwriterGraph:
|
class ScriptwriterGraph:
|
||||||
"""智能编剧工作流图类
|
"""智能编剧工作流图类
|
||||||
|
|
||||||
@ -100,12 +141,15 @@ class ScriptwriterGraph:
|
|||||||
|
|
||||||
def node_router(self, state: ScriptwriterState) -> str:
|
def node_router(self, state: ScriptwriterState) -> str:
|
||||||
"""节点路由函数"""
|
"""节点路由函数"""
|
||||||
print(f'node_router state {state}')
|
# print(f'node_router state {state}')
|
||||||
next_node = state.get("next_node", 'pause_node')
|
task_list = state.get("task_list", [])
|
||||||
# 修复:当 next_node 为空字符串时,设置默认值
|
task_index = state.get("task_index", 0)
|
||||||
if not next_node:
|
now_task = task_list[task_index]
|
||||||
|
if not now_task or now_task.get('pause'):
|
||||||
next_node = 'pause_node' # 设置为暂停节点
|
next_node = 'pause_node' # 设置为暂停节点
|
||||||
print(f'node_router next_node {next_node}')
|
else:
|
||||||
|
next_node = AgentNodeMap.get(now_task.get('agent'), 'pause_node')
|
||||||
|
# print(f'node_router next_node {next_node}')
|
||||||
return next_node
|
return next_node
|
||||||
|
|
||||||
def _build_graph(self) -> None:
|
def _build_graph(self) -> None:
|
||||||
@ -222,63 +266,41 @@ class ScriptwriterGraph:
|
|||||||
workflow.add_edge("end_node", END)
|
workflow.add_edge("end_node", END)
|
||||||
|
|
||||||
# 编译图
|
# 编译图
|
||||||
self.graph = workflow.compile(checkpointer=self.memory)
|
checkpoint = InMemorySaver()
|
||||||
|
self.graph = workflow.compile(checkpointer=checkpoint) # 不缓存记忆
|
||||||
|
# self.graph = workflow.compile(checkpointer=self.memory) # 使用mongodb缓存记忆
|
||||||
logger.info("工作流图构建完成")
|
logger.info("工作流图构建完成")
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"构建工作流图失败: {e}")
|
logger.error(f"构建工作流图失败: {e}")
|
||||||
raise
|
raise
|
||||||
|
|
||||||
def clear_messages(self, messages):
|
|
||||||
"""清除指定会话的所有消息"""
|
|
||||||
# 剔除历史状态消息
|
|
||||||
messages = [message for message in messages if "---任务状态消息(开始)---" not in message.content ]
|
|
||||||
# HumanMessage 超过 10 条,删除最早的 1 条
|
|
||||||
if len([message for message in messages if message.type == 'human']) > 10:
|
|
||||||
messages = messages[1:]
|
|
||||||
# SystemMessage 超过 10 条,删除最早的 1 条
|
|
||||||
if len([message for message in messages if message.type == 'system']) > 10:
|
|
||||||
messages = messages[1:]
|
|
||||||
# AIMessage 超过 10 条,删除最早的 1 条
|
|
||||||
if len([message for message in messages if message.type == 'ai']) > 10:
|
|
||||||
messages = messages[1:]
|
|
||||||
return messages
|
|
||||||
|
|
||||||
# --- 定义图中的节点 ---
|
# --- 定义图中的节点 ---
|
||||||
async def scheduler_node(self, state: ScriptwriterState)-> ScriptwriterState:
|
async def scheduler_node(self, state: ScriptwriterState)-> ScriptwriterState:
|
||||||
"""调度节点"""
|
"""调度节点"""
|
||||||
try:
|
try:
|
||||||
status = state.get("status", "")
|
|
||||||
session_id = state.get("session_id", "")
|
session_id = state.get("session_id", "")
|
||||||
|
workflow_info = state.get("workflow_info", {})
|
||||||
|
task_list = state.get("task_list", [])
|
||||||
|
task_index = int(state.get("task_index", 0))
|
||||||
from_type = state.get("from_type", "")
|
from_type = state.get("from_type", "")
|
||||||
messages = state.get("messages", [])
|
messages = state.get("messages", [])
|
||||||
if status == 'failed':
|
|
||||||
return {
|
|
||||||
"next_node":'end_node',
|
|
||||||
"agent_message": state.get("agent_message", ""),
|
|
||||||
"error": state.get("error", '系统错误,工作流已终止'),
|
|
||||||
'status':'failed',
|
|
||||||
}
|
|
||||||
# 清除历史状态消息
|
# 清除历史状态消息
|
||||||
messages = self.clear_messages(messages)
|
messages = clear_messages(messages)
|
||||||
workflow_step = state.get("workflow_step", "wait_for_input")
|
|
||||||
workflow_status = state.get("workflow_status", "waiting")
|
|
||||||
workflow_reason = state.get("workflow_reason", "")
|
|
||||||
workflow_retry_count = int(state.get("workflow_retry_count", 0))
|
|
||||||
# 添加参数进提示词
|
# 添加参数进提示词
|
||||||
messages.append(HumanMessage(content=f"""
|
messages.append(HumanMessage(content=f"""
|
||||||
---任务状态消息(开始)---
|
---任务状态消息(开始)---
|
||||||
# 总任务的进度与任务状态:
|
# 工作流信息:
|
||||||
|
{workflow_info}
|
||||||
|
# 工作流参数:
|
||||||
{{
|
{{
|
||||||
'query_args':{{
|
|
||||||
'session_id':'{session_id}',
|
'session_id':'{session_id}',
|
||||||
}},
|
'task_index':'{task_index}',
|
||||||
'step':'{workflow_step}',
|
|
||||||
'status':'{workflow_status}',
|
|
||||||
'from_type':'{from_type}',
|
'from_type':'{from_type}',
|
||||||
'reason':'{workflow_reason}',
|
|
||||||
'retry_count':{workflow_retry_count},
|
|
||||||
}}
|
}}
|
||||||
|
# 任务列表:
|
||||||
|
{task_list}
|
||||||
---任务状态消息(结束)---
|
---任务状态消息(结束)---
|
||||||
"""))
|
"""))
|
||||||
system_message_count = 0
|
system_message_count = 0
|
||||||
@ -291,37 +313,26 @@ class ScriptwriterGraph:
|
|||||||
human_message_count += 1
|
human_message_count += 1
|
||||||
elif message.type == 'ai':
|
elif message.type == 'ai':
|
||||||
ai_message_count += 1
|
ai_message_count += 1
|
||||||
logger.info(f"调度节点 {session_id} 输入消息条数: {len(messages)} from_type:{from_type} system_message_count:{system_message_count} human_message_count:{human_message_count} ai_message_count:{ai_message_count}")
|
logger.info(f"调度节点 {session_id} \n 输入消息条数: {len(messages)} \n from_type:{from_type} \n system_message_count:{system_message_count} \n human_message_count:{human_message_count} \n ai_message_count:{ai_message_count}")
|
||||||
reslut = await self.schedulerAgent.ainvoke(state)
|
# 调用智能体
|
||||||
|
reslut = await self.schedulerAgent.ainvoke({"messages":messages})
|
||||||
ai_message_str = reslut['messages'][-1].content
|
ai_message_str = reslut['messages'][-1].content
|
||||||
ai_message = json.loads(ai_message_str)
|
ai_message = json.loads(ai_message_str)
|
||||||
# logger.info(f"调度节点结果: {ai_message}")
|
# logger.info(f"调度节点结果: {ai_message}")
|
||||||
step:str = ai_message.get('step', '')
|
|
||||||
status:str = ai_message.get('status', '')
|
|
||||||
next_agent:str = ai_message.get('agent', '')
|
|
||||||
return_message:str = ai_message.get('message', '')
|
return_message:str = ai_message.get('message', '')
|
||||||
retry_count:int = int(ai_message.get('retry_count', '0'))
|
task_list:list = ai_message.get('task_list', [])
|
||||||
next_node:str = ai_message.get('node', 'pause_node')
|
task_index:int = int(ai_message.get('task_index', '0'))
|
||||||
if next_node == 'scheduler_node':
|
|
||||||
# 返回自身 代表暂停
|
|
||||||
print(f"调度节点 暂停等待")
|
|
||||||
return {
|
return {
|
||||||
|
"messages": messages,
|
||||||
|
"task_list": task_list,
|
||||||
|
"task_index": task_index,
|
||||||
"agent_message": return_message,
|
"agent_message": return_message,
|
||||||
}
|
}
|
||||||
else:
|
|
||||||
return {
|
|
||||||
"next_node":next_node,
|
|
||||||
"workflow_step":step,
|
|
||||||
"workflow_status":status,
|
|
||||||
# "workflow_reason":return_message,
|
|
||||||
"workflow_retry_count":retry_count,
|
|
||||||
"agent_message":return_message,
|
|
||||||
}
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
# import traceback
|
import traceback
|
||||||
# traceback.print_exc()
|
traceback.print_exc()
|
||||||
return {
|
return {
|
||||||
"next_node":'end_node',
|
|
||||||
"agent_message": "执行失败",
|
"agent_message": "执行失败",
|
||||||
"error": str(e) or '系统错误,工作流已终止',
|
"error": str(e) or '系统错误,工作流已终止',
|
||||||
'status':'failed',
|
'status':'failed',
|
||||||
@ -335,53 +346,47 @@ class ScriptwriterGraph:
|
|||||||
from_type = state.get("from_type", "")
|
from_type = state.get("from_type", "")
|
||||||
messages = state.get("messages", [])
|
messages = state.get("messages", [])
|
||||||
# 清除历史状态消息
|
# 清除历史状态消息
|
||||||
messages = self.clear_messages(messages)
|
messages = clear_messages(messages)
|
||||||
workflow_step = state.get("workflow_step", "wait_for_input")
|
|
||||||
workflow_status = state.get("workflow_status", "waiting")
|
|
||||||
workflow_reason = state.get("workflow_reason", "")
|
|
||||||
workflow_retry_count = int(state.get("workflow_retry_count", 0))
|
|
||||||
# 添加参数进提示词
|
# 添加参数进提示词
|
||||||
|
from tools.agent.queryDB import QueryOriginalScriptContent
|
||||||
|
original_script_content = QueryOriginalScriptContent(session_id)
|
||||||
messages.append(HumanMessage(content=f"""
|
messages.append(HumanMessage(content=f"""
|
||||||
---任务状态消息(开始)---
|
---原始剧本(开始)---
|
||||||
# 总任务的进度与任务状态:
|
{original_script_content['content']}
|
||||||
{{
|
---原始剧本(结束)---
|
||||||
'query_args':{{
|
|
||||||
'session_id':'{session_id}',
|
|
||||||
}},
|
|
||||||
'step':'{workflow_step}',
|
|
||||||
'status':'{workflow_status}',
|
|
||||||
'from_type':'{from_type}',
|
|
||||||
'reason':'{workflow_reason}',
|
|
||||||
'retry_count':{workflow_retry_count},
|
|
||||||
}}
|
|
||||||
---任务状态消息(结束)---
|
|
||||||
"""))
|
"""))
|
||||||
reslut = await self.scriptAnalysisAgent.ainvoke(state)
|
reslut = await self.scriptAnalysisAgent.ainvoke({"messages": messages})
|
||||||
ai_message_str = reslut['messages'][-1].content
|
ai_message_str = reslut['messages'][-1].content
|
||||||
ai_message = json.loads(ai_message_str)
|
ai_message = json.loads(ai_message_str)
|
||||||
# logger.info(f"调度节点结果: {ai_message}")
|
|
||||||
step:str = ai_message.get('step', '')
|
|
||||||
status:str = ai_message.get('status', '')
|
|
||||||
next_agent:str = ai_message.get('agent', '')
|
|
||||||
return_message:str = ai_message.get('message', '')
|
return_message:str = ai_message.get('message', '')
|
||||||
retry_count:int = int(ai_message.get('retry_count', '0'))
|
|
||||||
next_node:str = ai_message.get('node', 'pause_node')
|
task_list:list = state.get('task_list', [])
|
||||||
|
task_index:int = int(state.get('task_index', '0'))
|
||||||
|
task = task_list[task_index]
|
||||||
|
if task:
|
||||||
|
type:str = ai_message.get('type', '沟通')
|
||||||
|
if type == '沟通':
|
||||||
|
task['status'] = 'waiting'
|
||||||
|
task['parse'] = True
|
||||||
|
elif type == '输出':
|
||||||
|
task['status'] = 'completed'
|
||||||
|
task['parse'] = False
|
||||||
|
diagnosis_and_assessment:str = ai_message.get('diagnosis_and_assessment', '')
|
||||||
|
from tools.agent.updateDB import UpdateDiagnosisAndAssessment
|
||||||
|
UpdateDiagnosisAndAssessment(session_id, diagnosis_and_assessment)
|
||||||
# print(f"报告已生成: TEST")
|
# print(f"报告已生成: TEST")
|
||||||
print("\n------------ 诊断分析结束 ------------")
|
print("\n------------ 诊断分析结束 ------------")
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"from_type":'agent',
|
"messages": messages,
|
||||||
"next_node":next_node,
|
"task_list": task_list,
|
||||||
"workflow_step":step,
|
"task_index": task_index,
|
||||||
"workflow_status":status,
|
"agent_message": return_message,
|
||||||
# "workflow_reason":return_message,
|
|
||||||
"workflow_retry_count":retry_count,
|
|
||||||
"agent_message":return_message,
|
|
||||||
}
|
}
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
import traceback
|
import traceback
|
||||||
traceback.print_exc()
|
traceback.print_exc()
|
||||||
return {
|
return {
|
||||||
"next_node":'end_node',
|
|
||||||
"agent_message": "诊断分析失败",
|
"agent_message": "诊断分析失败",
|
||||||
"error": str(e) or '系统错误,工作流已终止',
|
"error": str(e) or '系统错误,工作流已终止',
|
||||||
'status':'failed',
|
'status':'failed',
|
||||||
@ -390,59 +395,61 @@ class ScriptwriterGraph:
|
|||||||
async def strategic_planning_node(self, state: ScriptwriterState)-> ScriptwriterState:
|
async def strategic_planning_node(self, state: ScriptwriterState)-> ScriptwriterState:
|
||||||
"""第三步:确立改编目标与战略蓝图"""
|
"""第三步:确立改编目标与战略蓝图"""
|
||||||
try:
|
try:
|
||||||
print("\n------------ 正在制定战略蓝图 ------------")
|
print("\n------------ 正在生成 改编思路 ------------")
|
||||||
session_id = state.get("session_id", "")
|
session_id = state.get("session_id", "")
|
||||||
from_type = state.get("from_type", "")
|
from_type = state.get("from_type", "")
|
||||||
messages = state.get("messages", [])
|
messages = state.get("messages", [])
|
||||||
# 清除历史状态消息
|
# 清除历史状态消息
|
||||||
messages = self.clear_messages(messages)
|
messages = clear_messages(messages)
|
||||||
workflow_step = state.get("workflow_step", "wait_for_input")
|
|
||||||
workflow_status = state.get("workflow_status", "waiting")
|
|
||||||
workflow_reason = state.get("workflow_reason", "")
|
|
||||||
workflow_retry_count = int(state.get("workflow_retry_count", 0))
|
|
||||||
# 添加参数进提示词
|
# 添加参数进提示词
|
||||||
|
from tools.agent.queryDB import QueryOriginalScriptContent, QueryDiagnosisAndAssessmentContent
|
||||||
|
original_script_content = QueryOriginalScriptContent(session_id)
|
||||||
messages.append(HumanMessage(content=f"""
|
messages.append(HumanMessage(content=f"""
|
||||||
---任务状态消息(开始)---
|
---原始剧本(开始)---
|
||||||
# 总任务的进度与任务状态:
|
{original_script_content['content']}
|
||||||
{{
|
---原始剧本(结束)---
|
||||||
'query_args':{{
|
|
||||||
'session_id':'{session_id}',
|
|
||||||
}},
|
|
||||||
'step':'{workflow_step}',
|
|
||||||
'status':'{workflow_status}',
|
|
||||||
'from_type':'{from_type}',
|
|
||||||
'reason':'{workflow_reason}',
|
|
||||||
'retry_count':{workflow_retry_count},
|
|
||||||
}}
|
|
||||||
---任务状态消息(结束)---
|
|
||||||
"""))
|
"""))
|
||||||
reslut = await self.strategicPlanningAgent.ainvoke(state)
|
diagnosis_and_assessment_content = QueryDiagnosisAndAssessmentContent(session_id)
|
||||||
|
messages.append(HumanMessage(content=f"""
|
||||||
|
---诊断与资产评估报告(开始)---
|
||||||
|
{diagnosis_and_assessment_content['content']}
|
||||||
|
---诊断与资产评估报告(结束)---
|
||||||
|
"""))
|
||||||
|
reslut = await self.strategicPlanningAgent.ainvoke({"messages": messages})
|
||||||
ai_message_str = reslut['messages'][-1].content
|
ai_message_str = reslut['messages'][-1].content
|
||||||
ai_message = json.loads(ai_message_str)
|
ai_message = json.loads(ai_message_str)
|
||||||
# logger.info(f"调度节点结果: {ai_message}")
|
|
||||||
step:str = ai_message.get('step', '')
|
|
||||||
status:str = ai_message.get('status', '')
|
|
||||||
next_agent:str = ai_message.get('agent', '')
|
|
||||||
return_message:str = ai_message.get('message', '')
|
return_message:str = ai_message.get('message', '')
|
||||||
retry_count:int = int(ai_message.get('retry_count', '0'))
|
|
||||||
next_node:str = ai_message.get('node', 'pause_node')
|
task_list:list = state.get('task_list', [])
|
||||||
|
task_index:int = int(state.get('task_index', '0'))
|
||||||
|
task = task_list[task_index]
|
||||||
|
if task:
|
||||||
|
type:str = ai_message.get('type', '沟通')
|
||||||
|
if type == '沟通':
|
||||||
|
task['status'] = 'waiting'
|
||||||
|
task['parse'] = True
|
||||||
|
elif type == '输出':
|
||||||
|
task['status'] = 'completed'
|
||||||
|
task['parse'] = False
|
||||||
|
adaptation_ideas:str = ai_message.get('adaptation_ideas', '')
|
||||||
|
total_episode_num:int = int(ai_message.get('total_episode_num', 0))
|
||||||
|
from tools.agent.updateDB import UpdateAdaptationIdeas,SetTotalEpisodeNum
|
||||||
|
UpdateAdaptationIdeas(session_id, adaptation_ideas)
|
||||||
|
SetTotalEpisodeNum(session_id, total_episode_num)
|
||||||
# print(f"报告已生成: TEST")
|
# print(f"报告已生成: TEST")
|
||||||
print("\n------------ 制定战略蓝图结束 ------------")
|
print("\n------------ 生成 改编思路 结束 ------------")
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"from_type":'agent',
|
"messages": messages,
|
||||||
"next_node":next_node,
|
"task_list": task_list,
|
||||||
"workflow_step":step,
|
"task_index": task_index,
|
||||||
"workflow_status":status,
|
"agent_message": return_message,
|
||||||
# "workflow_reason":return_message,
|
|
||||||
"workflow_retry_count":retry_count,
|
|
||||||
"agent_message":return_message,
|
|
||||||
}
|
}
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
import traceback
|
import traceback
|
||||||
traceback.print_exc()
|
traceback.print_exc()
|
||||||
return {
|
return {
|
||||||
"next_node":'end_node',
|
"agent_message": "生成 改编思路 失败",
|
||||||
"agent_message": "制定战略蓝图失败",
|
|
||||||
"error": str(e) or '系统错误,工作流已终止',
|
"error": str(e) or '系统错误,工作流已终止',
|
||||||
'status':'failed',
|
'status':'failed',
|
||||||
}
|
}
|
||||||
@ -450,120 +457,134 @@ class ScriptwriterGraph:
|
|||||||
async def build_bible_node(self, state: ScriptwriterState)-> ScriptwriterState:
|
async def build_bible_node(self, state: ScriptwriterState)-> ScriptwriterState:
|
||||||
"""第四步:制定剧本圣经"""
|
"""第四步:制定剧本圣经"""
|
||||||
try:
|
try:
|
||||||
print("\n------------ 正在制定剧本圣经 ------------")
|
print("\n------------ 正在生成 剧本圣经 ------------")
|
||||||
session_id = state.get("session_id", "")
|
session_id = state.get("session_id", "")
|
||||||
from_type = state.get("from_type", "")
|
from_type = state.get("from_type", "")
|
||||||
messages = state.get("messages", [])
|
messages = state.get("messages", [])
|
||||||
# 清除历史状态消息
|
# 清除历史状态消息
|
||||||
messages = self.clear_messages(messages)
|
messages = clear_messages(messages)
|
||||||
workflow_step = state.get("workflow_step", "wait_for_input")
|
|
||||||
workflow_status = state.get("workflow_status", "waiting")
|
|
||||||
workflow_reason = state.get("workflow_reason", "")
|
|
||||||
workflow_retry_count = int(state.get("workflow_retry_count", 0))
|
|
||||||
# 添加参数进提示词
|
# 添加参数进提示词
|
||||||
|
from tools.agent.queryDB import QueryOriginalScriptContent, QueryAdaptationIdeasContent
|
||||||
|
original_script_content = QueryOriginalScriptContent(session_id)
|
||||||
messages.append(HumanMessage(content=f"""
|
messages.append(HumanMessage(content=f"""
|
||||||
---任务状态消息(开始)---
|
---原始剧本(开始)---
|
||||||
# 总任务的进度与任务状态:
|
{original_script_content['content']}
|
||||||
{{
|
---原始剧本(结束)---
|
||||||
'query_args':{{
|
|
||||||
'session_id':'{session_id}',
|
|
||||||
}},
|
|
||||||
'step':'{workflow_step}',
|
|
||||||
'status':'{workflow_status}',
|
|
||||||
'from_type':'{from_type}',
|
|
||||||
'reason':'{workflow_reason}',
|
|
||||||
'retry_count':{workflow_retry_count},
|
|
||||||
}}
|
|
||||||
---任务状态消息(结束)---
|
|
||||||
"""))
|
"""))
|
||||||
reslut = await self.buildBibleAgent.ainvoke(state)
|
adaptation_ideas_content = QueryAdaptationIdeasContent(session_id)
|
||||||
|
messages.append(HumanMessage(content=f"""
|
||||||
|
---改编思路(开始)---
|
||||||
|
{adaptation_ideas_content['content']}
|
||||||
|
---改编思路(结束)---
|
||||||
|
"""))
|
||||||
|
reslut = await self.buildBibleAgent.ainvoke({"messages": messages})
|
||||||
ai_message_str = reslut['messages'][-1].content
|
ai_message_str = reslut['messages'][-1].content
|
||||||
ai_message = json.loads(ai_message_str)
|
ai_message = json.loads(ai_message_str)
|
||||||
# logger.info(f"调度节点结果: {ai_message}")
|
|
||||||
step:str = ai_message.get('step', '')
|
|
||||||
status:str = ai_message.get('status', '')
|
|
||||||
next_agent:str = ai_message.get('agent', '')
|
|
||||||
return_message:str = ai_message.get('message', '')
|
return_message:str = ai_message.get('message', '')
|
||||||
retry_count:int = int(ai_message.get('retry_count', '0'))
|
|
||||||
next_node:str = ai_message.get('node', 'pause_node')
|
task_list:list = state.get('task_list', [])
|
||||||
|
task_index:int = int(state.get('task_index', '0'))
|
||||||
|
task = task_list[task_index]
|
||||||
|
if task:
|
||||||
|
type:str = ai_message.get('type', '沟通')
|
||||||
|
if type == '沟通':
|
||||||
|
task['status'] = 'waiting'
|
||||||
|
task['parse'] = True
|
||||||
|
elif type == '输出':
|
||||||
|
task['status'] = 'completed'
|
||||||
|
task['parse'] = False
|
||||||
|
script_bible:dict = ai_message.get('script_bible', {})
|
||||||
|
core_outline:str = script_bible.get('core_outline', '')
|
||||||
|
character_profile:str = script_bible.get('character_profile', '')
|
||||||
|
core_event_timeline:str = script_bible.get('core_event_timeline', '')
|
||||||
|
character_list:str = script_bible.get('character_list', '')
|
||||||
|
from tools.agent.updateDB import UpdateScriptBible
|
||||||
|
UpdateScriptBible(session_id, core_outline, character_profile, core_event_timeline, character_list)
|
||||||
# print(f"报告已生成: TEST")
|
# print(f"报告已生成: TEST")
|
||||||
print("\n------------ 制定剧本圣经结束 ------------")
|
print("\n------------ 生成 剧本圣经 结束 ------------")
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"from_type":'agent',
|
"messages": messages,
|
||||||
"next_node":next_node,
|
"task_list": task_list,
|
||||||
"workflow_step":step,
|
"task_index": task_index,
|
||||||
"workflow_status":status,
|
"agent_message": return_message,
|
||||||
# "workflow_reason":return_message,
|
|
||||||
"workflow_retry_count":retry_count,
|
|
||||||
"agent_message":return_message,
|
|
||||||
}
|
}
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
import traceback
|
import traceback
|
||||||
traceback.print_exc()
|
traceback.print_exc()
|
||||||
return {
|
return {
|
||||||
"next_node":'end_node',
|
"agent_message": "生成 剧本圣经 失败",
|
||||||
"agent_message": "制定剧本圣经失败",
|
|
||||||
"error": str(e) or '系统错误,工作流已终止',
|
"error": str(e) or '系统错误,工作流已终止',
|
||||||
'status':'failed',
|
'status':'failed',
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
async def episode_create_node(self, state: ScriptwriterState)-> ScriptwriterState:
|
async def episode_create_node(self, state: ScriptwriterState)-> ScriptwriterState:
|
||||||
"""第五步:动态创作与闭环校验(循环主体)"""
|
"""第五步:循环创作剧本内容"""
|
||||||
try:
|
try:
|
||||||
print("\n------------ 正在创作单集内容 ------------")
|
|
||||||
session_id = state.get("session_id", "")
|
session_id = state.get("session_id", "")
|
||||||
from_type = state.get("from_type", "")
|
from_type = state.get("from_type", "")
|
||||||
|
task_list:list = state.get('task_list', [])
|
||||||
|
task_index:int = int(state.get('task_index', '0'))
|
||||||
|
task:dict = task_list[task_index] or {}
|
||||||
|
episode_create_num:list = task.get('episode_create_num', [])
|
||||||
|
print(f"\n------------ 正在生成 剧集内容 {episode_create_num} ------------")
|
||||||
messages = state.get("messages", [])
|
messages = state.get("messages", [])
|
||||||
# 清除历史状态消息
|
# 清除历史状态消息
|
||||||
messages = self.clear_messages(messages)
|
messages = clear_messages(messages)
|
||||||
workflow_step = state.get("workflow_step", "wait_for_input")
|
# 添加参数进提示词
|
||||||
workflow_status = state.get("workflow_status", "waiting")
|
from tools.agent.queryDB import QueryOriginalScriptContent, QueryAdaptationIdeasContent
|
||||||
workflow_reason = state.get("workflow_reason", "")
|
original_script_content = QueryOriginalScriptContent(session_id)
|
||||||
workflow_retry_count = int(state.get("workflow_retry_count", 0))
|
messages.append(HumanMessage(content=f"""
|
||||||
|
---原始剧本(开始)---
|
||||||
|
{original_script_content['content']}
|
||||||
|
---原始剧本(结束)---
|
||||||
|
"""))
|
||||||
|
adaptation_ideas_content = QueryAdaptationIdeasContent(session_id)
|
||||||
|
messages.append(HumanMessage(content=f"""
|
||||||
|
---改编思路(开始)---
|
||||||
|
{adaptation_ideas_content['content']}
|
||||||
|
---改编思路(结束)---
|
||||||
|
"""))
|
||||||
# 添加参数进提示词
|
# 添加参数进提示词
|
||||||
messages.append(HumanMessage(content=f"""
|
messages.append(HumanMessage(content=f"""
|
||||||
---任务状态消息(开始)---
|
---任务状态消息(开始)---
|
||||||
# 总任务的进度与任务状态:
|
# 工作流参数:
|
||||||
{{
|
{{
|
||||||
'query_args':{{
|
'episode_create_num':'{episode_create_num}',
|
||||||
'session_id':'{session_id}',
|
|
||||||
}},
|
|
||||||
'step':'{workflow_step}',
|
|
||||||
'status':'{workflow_status}',
|
|
||||||
'from_type':'{from_type}',
|
|
||||||
'reason':'{workflow_reason}',
|
|
||||||
'retry_count':{workflow_retry_count},
|
|
||||||
}}
|
}}
|
||||||
---任务状态消息(结束)---
|
---任务状态消息(结束)---
|
||||||
"""))
|
"""))
|
||||||
reslut = await self.episodeCreateAgent.ainvoke(state)
|
reslut = await self.buildBibleAgent.ainvoke({"messages": messages})
|
||||||
ai_message_str = reslut['messages'][-1].content
|
ai_message_str = reslut['messages'][-1].content
|
||||||
ai_message = json.loads(ai_message_str)
|
ai_message = json.loads(ai_message_str)
|
||||||
# logger.info(f"调度节点结果: {ai_message}")
|
|
||||||
step:str = ai_message.get('step', '')
|
|
||||||
status:str = ai_message.get('status', '')
|
|
||||||
next_agent:str = ai_message.get('agent', '')
|
|
||||||
return_message:str = ai_message.get('message', '')
|
return_message:str = ai_message.get('message', '')
|
||||||
retry_count:int = int(ai_message.get('retry_count', '0'))
|
|
||||||
next_node:str = ai_message.get('node', 'pause_node')
|
type:str = ai_message.get('type', '沟通')
|
||||||
|
if type == '沟通':
|
||||||
|
task['status'] = 'waiting'
|
||||||
|
task['parse'] = True
|
||||||
|
elif type == '输出':
|
||||||
|
task['status'] = 'completed'
|
||||||
|
task['parse'] = False
|
||||||
|
episodes:list = ai_message.get('episodes', [])
|
||||||
|
from tools.agent.updateDB import UpdateOneEpisode
|
||||||
|
for episode in episodes:
|
||||||
|
UpdateOneEpisode(session_id, episode.number, episode.content)
|
||||||
# print(f"报告已生成: TEST")
|
# print(f"报告已生成: TEST")
|
||||||
print("\n------------ 创作单集内容结束 ------------")
|
print("\n------------ 生成 剧本圣经 结束 ------------")
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"from_type":'agent',
|
"messages": messages,
|
||||||
"next_node":next_node,
|
"task_list": task_list,
|
||||||
"workflow_step":step,
|
"task_index": task_index,
|
||||||
"workflow_status":status,
|
"agent_message": return_message,
|
||||||
# "workflow_reason":return_message,
|
|
||||||
"workflow_retry_count":retry_count,
|
|
||||||
"agent_message":return_message,
|
|
||||||
}
|
}
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
import traceback
|
import traceback
|
||||||
traceback.print_exc()
|
traceback.print_exc()
|
||||||
return {
|
return {
|
||||||
"next_node":'end_node',
|
"agent_message": "生成 剧本圣经 失败",
|
||||||
"agent_message": "创作单集内容失败",
|
|
||||||
"error": str(e) or '系统错误,工作流已终止',
|
"error": str(e) or '系统错误,工作流已终止',
|
||||||
'status':'failed',
|
'status':'failed',
|
||||||
}
|
}
|
||||||
|
|||||||
@ -1,84 +0,0 @@
|
|||||||
from flask import request, jsonify
|
|
||||||
import asyncio
|
|
||||||
import uuid
|
|
||||||
from graph.test_graph_3 import run_with_persistence, get_checkpoint_history, resume_from_checkpoint
|
|
||||||
|
|
||||||
def run_async(coro):
|
|
||||||
"""运行异步函数的辅助函数"""
|
|
||||||
try:
|
|
||||||
# 尝试使用现有的事件循环
|
|
||||||
loop = asyncio.get_running_loop()
|
|
||||||
except RuntimeError:
|
|
||||||
# 如果没有运行中的事件循环,则创建一个新的
|
|
||||||
loop = asyncio.new_event_loop()
|
|
||||||
asyncio.set_event_loop(loop)
|
|
||||||
|
|
||||||
return loop.run_until_complete(coro)
|
|
||||||
|
|
||||||
def run_langgraph():
|
|
||||||
"""启动一个新的langgraph任务"""
|
|
||||||
try:
|
|
||||||
data = request.get_json()
|
|
||||||
user_input = data.get('user_input', '')
|
|
||||||
thread_id = data.get('thread_id', str(uuid.uuid4()))
|
|
||||||
|
|
||||||
if not user_input:
|
|
||||||
return jsonify({'error': 'user_input is required'}), 400
|
|
||||||
|
|
||||||
# 运行异步函数
|
|
||||||
output, thread_id = run_async(run_with_persistence(user_input, thread_id))
|
|
||||||
|
|
||||||
return jsonify({
|
|
||||||
'status': 'success',
|
|
||||||
'thread_id': thread_id,
|
|
||||||
'output': output
|
|
||||||
})
|
|
||||||
except Exception as e:
|
|
||||||
return jsonify({'error': str(e)}), 500
|
|
||||||
|
|
||||||
def get_task_status(thread_id):
|
|
||||||
"""查询任务状态和历史"""
|
|
||||||
try:
|
|
||||||
# 获取检查点历史
|
|
||||||
# 注意:这里需要修改get_checkpoint_history以返回数据而不是打印
|
|
||||||
# history = run_async(get_checkpoint_history(thread_id))
|
|
||||||
|
|
||||||
return jsonify({
|
|
||||||
'status': 'success',
|
|
||||||
'thread_id': thread_id,
|
|
||||||
'message': 'Task status endpoint'
|
|
||||||
})
|
|
||||||
except Exception as e:
|
|
||||||
return jsonify({'error': str(e)}), 500
|
|
||||||
|
|
||||||
def resume_task(thread_id):
|
|
||||||
"""从检查点恢复任务"""
|
|
||||||
try:
|
|
||||||
data = request.get_json()
|
|
||||||
checkpoint_id = data.get('checkpoint_id')
|
|
||||||
|
|
||||||
# 恢复检查点状态
|
|
||||||
restored_state = resume_from_checkpoint(thread_id, checkpoint_id)
|
|
||||||
|
|
||||||
if restored_state:
|
|
||||||
return jsonify({
|
|
||||||
'status': 'success',
|
|
||||||
'restored_state': restored_state
|
|
||||||
})
|
|
||||||
else:
|
|
||||||
return jsonify({'error': 'Failed to restore checkpoint'}), 404
|
|
||||||
except Exception as e:
|
|
||||||
return jsonify({'error': str(e)}), 500
|
|
||||||
|
|
||||||
def visualize_graph(thread_id):
|
|
||||||
"""可视化图结构"""
|
|
||||||
try:
|
|
||||||
# 这里可以返回图的可视化信息
|
|
||||||
# 为了简化,我们只返回基本信息
|
|
||||||
return jsonify({
|
|
||||||
'status': 'success',
|
|
||||||
'thread_id': thread_id,
|
|
||||||
'message': 'Graph visualization endpoint'
|
|
||||||
})
|
|
||||||
except Exception as e:
|
|
||||||
return jsonify({'error': str(e)}), 500
|
|
||||||
@ -17,7 +17,22 @@ def QueryOriginalScript(session_id: str):
|
|||||||
"exist": script is not None,
|
"exist": script is not None,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
def QueryOriginalScriptContent(session_id: str):
|
||||||
|
"""
|
||||||
|
查询原始剧本内容
|
||||||
|
Args:
|
||||||
|
session_id: 会话id
|
||||||
|
Returns:
|
||||||
|
Dict: 返回一个包含以下字段的字典:
|
||||||
|
content (str): 原始剧本内容。
|
||||||
|
"""
|
||||||
|
script = mainDB.agent_writer_session.find_one({"_id": ObjectId(session_id), "original_script": {"$exists": True, "$ne": ""}},{"original_script":1})
|
||||||
|
return {
|
||||||
|
"content": script["original_script"] if script else "",
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@tool
|
||||||
def QueryDiagnosisAndAssessment(session_id: str):
|
def QueryDiagnosisAndAssessment(session_id: str):
|
||||||
"""
|
"""
|
||||||
查询诊断与资产评估报告是否存在
|
查询诊断与资产评估报告是否存在
|
||||||
@ -32,6 +47,21 @@ def QueryDiagnosisAndAssessment(session_id: str):
|
|||||||
"exist": script is not None,
|
"exist": script is not None,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
def QueryDiagnosisAndAssessmentContent(session_id: str):
|
||||||
|
"""
|
||||||
|
查询诊断与资产评估报告内容
|
||||||
|
Args:
|
||||||
|
session_id: 会话id
|
||||||
|
Returns:
|
||||||
|
Dict: 返回一个包含以下字段的字典:
|
||||||
|
content (str): 诊断与资产评估报告内容。
|
||||||
|
"""
|
||||||
|
script = mainDB.agent_writer_session.find_one({"_id": ObjectId(session_id), "diagnosis_and_assessment": {"$exists": True, "$ne": ""}},{"diagnosis_and_assessment":1})
|
||||||
|
return {
|
||||||
|
"content": script["diagnosis_and_assessment"] if script else "",
|
||||||
|
}
|
||||||
|
|
||||||
|
@tool
|
||||||
def QueryAdaptationIdeas(session_id: str):
|
def QueryAdaptationIdeas(session_id: str):
|
||||||
"""
|
"""
|
||||||
查询改编思路是否存在
|
查询改编思路是否存在
|
||||||
@ -46,6 +76,21 @@ def QueryAdaptationIdeas(session_id: str):
|
|||||||
"exist": script is not None,
|
"exist": script is not None,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
def QueryAdaptationIdeasContent(session_id: str):
|
||||||
|
"""
|
||||||
|
查询改编思路内容
|
||||||
|
Args:
|
||||||
|
session_id: 会话id
|
||||||
|
Returns:
|
||||||
|
Dict: 返回一个包含以下字段的字典:
|
||||||
|
content (str): 改编思路内容。
|
||||||
|
"""
|
||||||
|
script = mainDB.agent_writer_session.find_one({"_id": ObjectId(session_id), "adaptation_ideas": {"$exists": True, "$ne": ""}},{"adaptation_ideas":1})
|
||||||
|
return {
|
||||||
|
"content": script["adaptation_ideas"] if script else "",
|
||||||
|
}
|
||||||
|
|
||||||
|
@tool
|
||||||
def QueryScriptBible(session_id: str):
|
def QueryScriptBible(session_id: str):
|
||||||
"""
|
"""
|
||||||
查询剧本圣经是否存在
|
查询剧本圣经是否存在
|
||||||
@ -60,6 +105,21 @@ def QueryScriptBible(session_id: str):
|
|||||||
"exist": script is not None,
|
"exist": script is not None,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
def QueryScriptBibleContent(session_id: str):
|
||||||
|
"""
|
||||||
|
查询剧本圣经内容
|
||||||
|
Args:
|
||||||
|
session_id: 会话id
|
||||||
|
Returns:
|
||||||
|
Dict: 返回一个包含以下字段的字典:
|
||||||
|
content (str): 剧本圣经内容。
|
||||||
|
"""
|
||||||
|
script = mainDB.agent_writer_session.find_one({"_id": ObjectId(session_id), "script_bible": {"$exists": True}},{"script_bible":1})
|
||||||
|
return {
|
||||||
|
"content": script["script_bible"] if script else {},
|
||||||
|
}
|
||||||
|
|
||||||
|
@tool
|
||||||
def QueryCoreOutline(session_id: str):
|
def QueryCoreOutline(session_id: str):
|
||||||
"""
|
"""
|
||||||
查询剧本圣经中的核心大纲是否存在
|
查询剧本圣经中的核心大纲是否存在
|
||||||
@ -74,6 +134,7 @@ def QueryCoreOutline(session_id: str):
|
|||||||
"exist": script is not None,
|
"exist": script is not None,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@tool
|
||||||
def QueryCharacterProfile(session_id: str):
|
def QueryCharacterProfile(session_id: str):
|
||||||
"""
|
"""
|
||||||
查询剧本圣经中的核心人物小传是否存在
|
查询剧本圣经中的核心人物小传是否存在
|
||||||
@ -88,6 +149,7 @@ def QueryCharacterProfile(session_id: str):
|
|||||||
"exist": script is not None,
|
"exist": script is not None,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@tool
|
||||||
def QueryCoreEventTimeline(session_id: str):
|
def QueryCoreEventTimeline(session_id: str):
|
||||||
"""
|
"""
|
||||||
查询剧本圣经中的重大事件时间线是否存在
|
查询剧本圣经中的重大事件时间线是否存在
|
||||||
@ -102,6 +164,7 @@ def QueryCoreEventTimeline(session_id: str):
|
|||||||
"exist": script is not None,
|
"exist": script is not None,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@tool
|
||||||
def QueryCharacterList(session_id: str):
|
def QueryCharacterList(session_id: str):
|
||||||
"""
|
"""
|
||||||
查询剧本圣经中的总人物表是否存在
|
查询剧本圣经中的总人物表是否存在
|
||||||
@ -116,6 +179,7 @@ def QueryCharacterList(session_id: str):
|
|||||||
"exist": script is not None,
|
"exist": script is not None,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@tool
|
||||||
def QueryEpisodeCount(session_id: str):
|
def QueryEpisodeCount(session_id: str):
|
||||||
"""
|
"""
|
||||||
查询剧集创作情况
|
查询剧集创作情况
|
||||||
@ -126,11 +190,16 @@ def QueryEpisodeCount(session_id: str):
|
|||||||
completed (int): 已完成的集数
|
completed (int): 已完成的集数
|
||||||
total (int): 总集数
|
total (int): 总集数
|
||||||
"""
|
"""
|
||||||
total = mainDB.agent_writer_episodes.count_documents({"session_id": session_id})
|
total = mainDB.agent_writer_session.find_one({"_id": ObjectId(session_id)},{"total_episode_num":1})
|
||||||
|
if total is None:
|
||||||
|
return {
|
||||||
|
"completed": 0,
|
||||||
|
"total": 0,
|
||||||
|
}
|
||||||
count = mainDB.agent_writer_episodes.count_documents({"session_id": session_id, "content": {"$exists": True, "$ne": ""}})
|
count = mainDB.agent_writer_episodes.count_documents({"session_id": session_id, "content": {"$exists": True, "$ne": ""}})
|
||||||
return {
|
return {
|
||||||
"completed": count,
|
"completed": count,
|
||||||
"total": total,
|
"total": int(total["total_episode_num"]) or 0,
|
||||||
}
|
}
|
||||||
|
|
||||||
# def QuerySingleEpisodeContent(session_id: str):
|
# def QuerySingleEpisodeContent(session_id: str):
|
||||||
|
|||||||
176
tools/agent/updateDB.py
Normal file
176
tools/agent/updateDB.py
Normal file
@ -0,0 +1,176 @@
|
|||||||
|
from bson import ObjectId
|
||||||
|
from tools.database.mongo import mainDB
|
||||||
|
from langchain.tools import tool
|
||||||
|
|
||||||
|
@tool
|
||||||
|
def UpdateDiagnosisAndAssessmentTool(session_id: str, diagnosis_and_assessment: str):
|
||||||
|
"""
|
||||||
|
更新诊断与资产评估报告
|
||||||
|
Args:
|
||||||
|
session_id: 会话id
|
||||||
|
diagnosis_and_assessment: 诊断与资产评估报告
|
||||||
|
Returns:
|
||||||
|
Dict: 返回一个包含以下字段的字典:
|
||||||
|
success (bool): 是否更新成功
|
||||||
|
"""
|
||||||
|
return UpdateDiagnosisAndAssessment(session_id, diagnosis_and_assessment)
|
||||||
|
|
||||||
|
def UpdateDiagnosisAndAssessment(session_id: str, diagnosis_and_assessment: str):
|
||||||
|
"""
|
||||||
|
更新诊断与资产评估报告
|
||||||
|
Args:
|
||||||
|
session_id: 会话id
|
||||||
|
diagnosis_and_assessment: 诊断与资产评估报告
|
||||||
|
Returns:
|
||||||
|
Dict: 返回一个包含以下字段的字典:
|
||||||
|
success (bool): 是否更新成功
|
||||||
|
"""
|
||||||
|
script = mainDB.agent_writer_session.update_one({"_id": ObjectId(session_id)},{"$set":{"diagnosis_and_assessment":diagnosis_and_assessment}})
|
||||||
|
return {
|
||||||
|
"success": script.modified_count > 0,
|
||||||
|
}
|
||||||
|
|
||||||
|
@tool
|
||||||
|
def UpdateAdaptationIdeasTool(session_id: str, adaptation_ideas: str):
|
||||||
|
"""
|
||||||
|
更新改编思路
|
||||||
|
Args:
|
||||||
|
session_id: 会话id
|
||||||
|
adaptation_ideas: 改编思路
|
||||||
|
Returns:
|
||||||
|
Dict: 返回一个包含以下字段的字典:
|
||||||
|
success (bool): 是否更新成功
|
||||||
|
"""
|
||||||
|
return UpdateAdaptationIdeas(session_id, adaptation_ideas)
|
||||||
|
|
||||||
|
def UpdateAdaptationIdeas(session_id: str, adaptation_ideas: str):
|
||||||
|
"""
|
||||||
|
更新改编思路
|
||||||
|
Args:
|
||||||
|
session_id: 会话id
|
||||||
|
adaptation_ideas: 改编思路
|
||||||
|
Returns:
|
||||||
|
Dict: 返回一个包含以下字段的字典:
|
||||||
|
success (bool): 是否更新成功
|
||||||
|
"""
|
||||||
|
script = mainDB.agent_writer_session.update_one({"_id": ObjectId(session_id)},{"$set":{"adaptation_ideas":adaptation_ideas}})
|
||||||
|
return {
|
||||||
|
"success": script.modified_count > 0,
|
||||||
|
}
|
||||||
|
|
||||||
|
@tool
|
||||||
|
def UpdateScriptBibleTool(
|
||||||
|
session_id: str,
|
||||||
|
core_outline:str|None = None,
|
||||||
|
character_profile:str|None = None,
|
||||||
|
core_event_timeline:str|None = None,
|
||||||
|
character_list:str|None = None,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
更新剧本圣经
|
||||||
|
Args:
|
||||||
|
session_id: 会话id
|
||||||
|
core_outline: 核心大纲
|
||||||
|
character_profile: 核心人物小传
|
||||||
|
core_event_timeline: 核心事件时间线
|
||||||
|
character_list: 总角色表
|
||||||
|
Returns:
|
||||||
|
Dict: 返回一个包含以下字段的字典:
|
||||||
|
success (bool): 是否更新成功
|
||||||
|
"""
|
||||||
|
return UpdateScriptBible(session_id, core_outline, character_profile, core_event_timeline, character_list)
|
||||||
|
|
||||||
|
def UpdateScriptBible(
|
||||||
|
session_id: str,
|
||||||
|
core_outline:str|None = None,
|
||||||
|
character_profile:str|None = None,
|
||||||
|
core_event_timeline:str|None = None,
|
||||||
|
character_list:str|None = None,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
更新剧本圣经
|
||||||
|
Args:
|
||||||
|
session_id: 会话id
|
||||||
|
core_outline: 核心大纲
|
||||||
|
character_profile: 核心人物小传
|
||||||
|
core_event_timeline: 核心事件时间线
|
||||||
|
character_list: 总角色表
|
||||||
|
Returns:
|
||||||
|
Dict: 返回一个包含以下字段的字典:
|
||||||
|
success (bool): 是否更新成功
|
||||||
|
"""
|
||||||
|
if core_outline is None and character_profile is None and core_event_timeline is None and character_list is None:
|
||||||
|
return {
|
||||||
|
"success": False,
|
||||||
|
}
|
||||||
|
update_dict = {}
|
||||||
|
if core_outline is not None:
|
||||||
|
update_dict["script_bible.core_outline"] = core_outline
|
||||||
|
if character_profile is not None:
|
||||||
|
update_dict["script_bible.character_profile"] = character_profile
|
||||||
|
if core_event_timeline is not None:
|
||||||
|
update_dict["script_bible.core_event_timeline"] = core_event_timeline
|
||||||
|
if character_list is not None:
|
||||||
|
update_dict["script_bible.character_list"] = character_list
|
||||||
|
script = mainDB.agent_writer_session.update_one({"_id": ObjectId(session_id)},{"$set":update_dict})
|
||||||
|
return {
|
||||||
|
"success": script.modified_count > 0,
|
||||||
|
}
|
||||||
|
|
||||||
|
@tool
|
||||||
|
def SetTotalEpisodeNumTool(session_id: str, total_episode_num: int):
|
||||||
|
"""
|
||||||
|
设置总集数
|
||||||
|
Args:
|
||||||
|
session_id: 会话id
|
||||||
|
total_episode_num: 总集数
|
||||||
|
Returns:
|
||||||
|
Dict: 返回一个包含以下字段的字典:
|
||||||
|
success (bool): 是否更新成功
|
||||||
|
"""
|
||||||
|
return SetTotalEpisodeNum(session_id, total_episode_num)
|
||||||
|
|
||||||
|
def SetTotalEpisodeNum(session_id: str, total_episode_num: int):
|
||||||
|
"""
|
||||||
|
设置总集数
|
||||||
|
Args:
|
||||||
|
session_id: 会话id
|
||||||
|
total_episode_num: 总集数
|
||||||
|
Returns:
|
||||||
|
Dict: 返回一个包含以下字段的字典:
|
||||||
|
success (bool): 是否更新成功
|
||||||
|
"""
|
||||||
|
script = mainDB.agent_writer_session.update_one({"_id": ObjectId(session_id)},{"$set":{"total_episode_num":total_episode_num}})
|
||||||
|
return {
|
||||||
|
"success": script.modified_count > 0,
|
||||||
|
}
|
||||||
|
|
||||||
|
@tool
|
||||||
|
def UpdateOneEpisodeTool(session_id: str, episode_num:int, content: str):
|
||||||
|
"""
|
||||||
|
更新单集内容
|
||||||
|
Args:
|
||||||
|
session_id: 会话id
|
||||||
|
episode_num: 剧集编号
|
||||||
|
content: 剧集内容
|
||||||
|
Returns:
|
||||||
|
Dict: 返回一个包含以下字段的字典:
|
||||||
|
success (bool): 是否更新成功
|
||||||
|
"""
|
||||||
|
return UpdateOneEpisode(session_id, episode_num, content)
|
||||||
|
|
||||||
|
def UpdateOneEpisode(session_id: str, episode_num:int, content: str):
|
||||||
|
"""
|
||||||
|
更新单集内容
|
||||||
|
Args:
|
||||||
|
session_id: 会话id
|
||||||
|
episode_num: 剧集编号
|
||||||
|
content: 剧集内容
|
||||||
|
Returns:
|
||||||
|
Dict: 返回一个包含以下字段的字典:
|
||||||
|
success (bool): 是否更新成功
|
||||||
|
"""
|
||||||
|
script = mainDB.agent_writer_episodes.update_one({"session_id": session_id, "episode_num": episode_num},{"$set":{"content":content}}, upsert=True)
|
||||||
|
return {
|
||||||
|
"success": script.modified_count > 0,
|
||||||
|
}
|
||||||
@ -49,25 +49,6 @@ class HuoshanChatModel(BaseChatModel):
|
|||||||
"""返回LLM类型标识"""
|
"""返回LLM类型标识"""
|
||||||
return "huoshan_chat"
|
return "huoshan_chat"
|
||||||
|
|
||||||
# def _convert_messages_to_prompt(self, messages: List[BaseMessage]) -> tuple[str, str]:
|
|
||||||
# """将LangChain消息格式转换为API所需的prompt和system格式"""
|
|
||||||
# system_message = ""
|
|
||||||
# user_messages = []
|
|
||||||
|
|
||||||
# for message in messages:
|
|
||||||
# if isinstance(message, SystemMessage):
|
|
||||||
# system_message = message.content or ""
|
|
||||||
# elif isinstance(message, HumanMessage):
|
|
||||||
# user_messages.append(message.content)
|
|
||||||
# elif isinstance(message, SystemMessage):
|
|
||||||
# # 如果需要支持多轮对话,可以在这里处理
|
|
||||||
# pass
|
|
||||||
|
|
||||||
# # 合并用户消息
|
|
||||||
# prompt = "\n".join(user_messages) if user_messages else ""
|
|
||||||
|
|
||||||
# return prompt, str(system_message)
|
|
||||||
|
|
||||||
|
|
||||||
def bind_tools(self, tools: Sequence[BaseTool], **kwargs: Any) -> Runnable:
|
def bind_tools(self, tools: Sequence[BaseTool], **kwargs: Any) -> Runnable:
|
||||||
"""将工具绑定到模型,并将其转换为火山引擎API所需的格式。"""
|
"""将工具绑定到模型,并将其转换为火山引擎API所需的格式。"""
|
||||||
@ -192,8 +173,16 @@ class HuoshanChatModel(BaseChatModel):
|
|||||||
|
|
||||||
api_messages = self._convert_messages_to_prompt(messages)
|
api_messages = self._convert_messages_to_prompt(messages)
|
||||||
tools = kwargs.get("tools", [])
|
tools = kwargs.get("tools", [])
|
||||||
|
# print(f" 提交给豆包的 messages数组长度: \n {len(messages)} \n tools: {tools}")
|
||||||
print(f" 提交给豆包的 messages数组长度: \n {len(messages)} \n tools: {tools}")
|
print(f" 提交给豆包的消息 =======================>>>>>>>>>>>> \n")
|
||||||
|
print(f"\nmessages: \n")
|
||||||
|
for message in messages:
|
||||||
|
print(f" {message.type}: \n ")
|
||||||
|
print(f" {message.content} \n ")
|
||||||
|
print(f"\ntools: \n")
|
||||||
|
for tool in tools:
|
||||||
|
print(f" \n {tool} \n ")
|
||||||
|
print(f" 提交给豆包的消息 =======================>>>>>>>>>>>> end \n")
|
||||||
|
|
||||||
response_data = self._api.get_chat_response(messages=api_messages, tools=tools)
|
response_data = self._api.get_chat_response(messages=api_messages, tools=tools)
|
||||||
|
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user