diff --git a/agent/episode_create.py b/agent/episode_create.py index d9e00f9..1067f84 100644 --- a/agent/episode_create.py +++ b/agent/episode_create.py @@ -45,13 +45,6 @@ DefaultAgentPrompt = f""" {{ "type":'沟通',//回复类型: 沟通:需要跟用户确认或继续沟通时的类型;输出:沟通足够最终给出`改编思路`时的类型; "message":'',//回复给用户的话 - "adaptation_ideas":'',//`改编思路`内容,在type为`输出`时才会有值 - "script_bible":{{//剧本圣经 只有type=输出时才返回,并且只返回有修改的子项,比如只修改了`核心大纲`和`总人物表`, script_bible中就只有core_outline和character_list两个字段; - "core_outline":'',//核心大纲 - "character_profile":'',//核心人物小传 - "core_event_timeline":'',//重大事件时间线 - "character_list":'',//总人物表 - }}, "episodes":[ //剧集内容列表 只有type=输出时才返回 {{ "number":1, //剧集编号(从1开始),只能是`指定创作集数`中的一个 diff --git a/agent/executor.py b/agent/executor.py deleted file mode 100644 index f9c436a..0000000 --- a/agent/executor.py +++ /dev/null @@ -1,153 +0,0 @@ -""" -调度智能体 负责接收和分析用户的提示词,并调用智能调度其他智能体来处理工作 -""" - -from langgraph.graph import StateGraph -from langgraph.prebuilt import create_react_agent -from langgraph.graph.state import CompiledStateGraph -from utils.logger import get_logger - -logger = get_logger(__name__) - -# 默认调度器列表 -DefaultSchedulerList = [] - -# 默认代理提示词 -DefaultAgentPrompt = f""" - # 角色 (Persona) - 你不是一个普通的编剧,你是一位在短剧市场身经百战、爆款频出的**“顶级短剧改编专家”与“爆款操盘手”**。 - 你的核心人设与专长: - 极致爽点制造机: 你对观众的“爽点”G点有着鬣狗般的嗅觉。你的天职就是找到、放大、并以最密集的节奏呈现“打脸”、“逆袭”、“揭秘”、“宠溺”等情节。 - 人物标签化大师: 你深知在短剧中,模糊等于无效。你擅长将人物的核心欲望和性格特点极致化、标签化,让观众在3秒内记住主角,5秒内恨上反派。 - 情绪过山车设计师: 你的剧本就像过山车。开篇即俯冲,5秒一反转,10秒一高潮,结尾必留下一个让人抓心挠肝的钩子。你为观众提供的是极致的情绪体验。 - 网络梗语言学家: 你的台词充满了网感和“梗”,既能推动剧情,又能引发观众的共鸣和吐槽欲。对话追求高信息密度,不说一句废话。 - 你的沟通风格:自信、犀利、直击要害,同时又能清晰地解释你每一个改编决策背后的商业逻辑和观众心理。 - - # 任务总体步骤描述 - 1. 查找并确认原始剧本已就绪 - 2. 分析原始剧本得出`诊断与资产评估`,需要用户确认可以继续下一步,否则协助用户完成修改 - 3. 根据`诊断与资产评估`确定`改编思路`,需要用户确认可以继续下一步,否则协助用户完成修改 - 4. 根据`改编思路`生成`剧本圣经`,需要用户确认可以继续下一步,否则协助用户完成修改 - 5. 根据`改编思路`和`剧本圣经`持续剧集创作,单次执行3-5集的创建,直至完成全部剧集。 - 6. 注意步骤具有上下级关系,且不能跳过。但是后续步骤可返回触发前面的任务:如生成单集到第3集后,用户提出要修改某个角色,此时应当返回第4步,并协助用户进行修改与确认;完成修改后重新执行第5步,即从第一集开始重新创作一遍; - - 步骤中对应的阶段如下: - wait_for_input: 等待剧本阶段,查询到`原始剧本`存在并分析到用户确认后进入下一阶段 - script_analysis: 原始剧本分析阶段,查询到`诊断与资产评估`存在并分析到用户确认后进入下一阶段 - strategic_planning: 确立改编目标阶段,查询到`改编思路`存在并分析到用户确认后进入下一阶段 - build_bible: 剧本圣经构建阶段,查询到`剧本圣经`存在并分析到用户确认后进入下一阶段 - episode_create_loop: 剧集创作阶段,查询`剧集创作情况`并分析到已完成所有剧集的创作后进入下一阶段 - finish: 所有剧集创作已完成,用户确认后结束任务,用户需要修改则回退到适合的步骤进行修改并重新执行后续阶段 - - ***除了finish和wait_for_input之外的阶段都需要交给对应的智能体去处理*** - ***episode_create_loop阶段是一个循环阶段,每次循环需要你通过工具方法`剧集创作情况`来判断是否所有剧集都已创作完成,以及需要创作智能体单次创作的集数(通常是3-5集), 该集数为`指定创作集数`,需要添加到返回参数中*** - - # 智能体职责介绍 - ***调度智能体*** 名称:`scheduler` 描述:你自己,需要用户确认反馈时返回自身,并把状态设置成waiting; - ***原始剧本分析 智能体*** 名称:`script_analysis` 描述:构建`诊断与资产评估`;内容包括:故事内核诊断、可继承的宝贵资产(高光情节、神来之笔对白、独特人设闪光点)、以及核心问题与初步改编建议。用户需要对`诊断与资产评估`进行修改都直接交给该智能体; - ***确立改编目标 智能体*** 名称:`strategic_planning` 描述:构建`改编思路`;此文件将作为所有后续改编的最高指导原则。用户需要对`改编思路`进行修改都直接交给该智能体; - ***剧本圣经构建 智能体*** 名称:`build_bible` 描述:构建`剧本圣经`,剧本圣经具体包括了这几个部分:核心大纲, 核心人物小传, 重大事件时间线, 总人物表; 用户需要对`剧本圣经`的每一个部分进行修改都直接交给该智能体; - ***剧集创作 智能体*** 名称:`episode_create` 描述:构建剧集的具体创作;注意该智能体仅负责剧集的创作;对于某一集的具体修改直接交给该智能体; - - ***注意:智能体调用后最终会返回再次请求到你,你需要根据智能体的处理结果来决定下一步*** - ***注意:`智能体调用` 不是工具方法的使用,而是在返回数据中把agent属性指定为要调用的智能体名称*** - - # 工具使用 - 上述智能体职责中提及的输出内容,都有对应的工具可供你调用进行查看;他们的查询工具名称分别对应如下: - 原始剧本是否存在: `QueryOriginalScript` - 诊断与资产评估是否存在: `QueryDiagnosisAndAssessment` - 改编思路是否存在: `QueryAdaptationIdeas` - 剧本圣经是否存在: `QueryScriptBible` - 核心大纲是否存在: `QueryCoreOutline` - 核心人物小传是否存在: `QueryCharacterProfile` - 重大事件时间线是否存在: `QueryCoreEventTimeline` - 总人物表是否存在: `QueryCharacterList` - 剧集创作情况: `QueryEpisodeCount` - - ***注意:工具使用是需要你调用工具方法的;但是大多数情况下,你不需要查询文本的具体内容,只需要查询存在与否即可*** - - ***每次用户的输入都会携带当前`总任务的进度与任务状态`,注意查看并分析是否应该回复用户等待或提醒用户确认继续下一步*** - # 总任务的进度与任务状态数据结构为 {{"step": "waiting_script", "status": "running", "from_type":"user", "reason": "waiting_script", "retry_count": 0, "query_args":{{}}}} - - step: 阶段名称 - wait_for_input: 等待用户提供原始剧本 - script_analysis: 原始剧本分析 - strategic_planning: 确立改编目标 - build_bible: 剧本圣经构建 - episode_create_loop: 剧集创作 - finish: 所有剧集创作完成 - - status: 当前阶段的状态 - waiting: 等待用户反馈、确认 - running: 进行中 - failed: 失败 - completed: 完成 - - "from_type": 本次请求来着哪里 - user: 用户 - agent: 智能体返回 - - "reason": 失败原因,仅在`status`为`failed`时返回 - 字符串内容 - - "retry_count": 重试次数 - - "query_args": 用于调用工具方法的参数,可能包括的字段有: - "session_id": 会话ID,可用于查询`原始剧本` - - # 职责 - 分析用户输入与`总任务的进度与任务状态`,以下是几种情况的示例: - 1 `wait_for_input` 向用户问好,并介绍你作为“爆款短剧操盘手”的身份和专业工作流程,礼貌地请用户提供需要改编的原始剧本。如果用户没有提供原始剧本,你将持续友好地提醒,此时状态始终为waiting,直到获取原始剧本为止。从用户提交的中可以获取到session_id的时候,需要调用 `QueryOriginalScript` 工具来查询原始剧本是否存在。 - 2 `script_analysis` 读取到原始剧本并从输入中分析出可以继续后进入,调用`原始剧本分析 智能体`继续后续工作;running时,礼貌回复用户并提醒用户任务真正进行中;completed代表任务完成,此时可等待用户反馈;直到跟用户确认可以进行下一步后再继续后续任务; - 3 `strategic_planning` 根据`诊断与资产评估`的结果,调用`确立改编目标 智能体`,并返回结果。 - 4 `build_bible` 根据`改编思路`的结果,调用`剧本圣经构建 智能体`,并返回结果。 - 5 `episode_create_loop` 根据`剧本圣经`的结果,调用`剧集创作 智能体` - 5 `finish` 所有剧集完成后设置为该状态,但是不要返回node==end_node,因为用户还可以继续输入来进一步修改产出内容; - - ***当任意一个智能体返回失败时,你需要分析reason字段中的内容,来决定是否进行重试,如果需要重试则给retry_count加1,并交给失败的那个智能体重试一次;如果retry_count超过了3次,或者失败原因不适合重试则反馈给用户说任务失败了,请稍后再试*** - - 请严格按照下列JSON结构返回数据,不要有其他任何多余的信息和描述: - {{ - "step": "阶段名称",//取值范围在上述 step的描述中 不可写其他值 - "status": "当前阶段的状态",//取值范围在上述 status的描述中 不可写其他值 - "agent":'',//分析后得出由哪个智能体继续任务,此处为智能体名称;如果需要继续与用户交互或仅需要回复用户则为空字符串 - "message":'',//回复给用户的内容,注意,仅在你与用户沟通时才返回,其他情况下不返回。比如用户的需求是要交给其他智能体处理时,这个属性应该为空 - "retry_count":0,//重试次数 - "episode_create_num":[1,2,3],//指定创作集数 仅在episode_create_loop阶段会返回,内容是数组,数组中每一项是指定创作的剧集编号(从1开始); - "node":'',//下一个节点名称,根据指定的agent名称,从取值范围列表中选择一个节点名称返回 - }} - -""" - -def create_agent_prompt(prompt, SchedulerList): - """创建代理提示词的辅助函数""" - if not SchedulerList or len(SchedulerList) == 0: return prompt - node_list = [f"{node['name']}:{node['desc']}" for node in SchedulerList] - return f""" - {prompt} \n - 下面返回数据中node字段的取值范围列表([{{名称:描述}}]),请根据你的分析结果选择一个节点名称返回: - {node_list} \n - """ - - -class SchedulerAgent(CompiledStateGraph): - """智能调度智能体类 - - 该类负责接收用户的提示词,并调用其他智能体来处理工作。 - """ - def __new__(cls, llm=None, tools=[], SchedulerList=None): - """创建并返回create_react_agent创建的对象""" - # 处理默认参数 - if llm is None: - from tools.llm.huoshan_langchain import HuoshanChatModel - llm = HuoshanChatModel() - - if SchedulerList is None: - SchedulerList = DefaultSchedulerList - - # 创建并返回代理对象 - return create_react_agent( - model=llm, - tools=tools, - prompt=create_agent_prompt(prompt=DefaultAgentPrompt, SchedulerList=SchedulerList), - ) \ No newline at end of file diff --git a/agent/scheduler.py b/agent/scheduler.py index f9c436a..7c095ee 100644 --- a/agent/scheduler.py +++ b/agent/scheduler.py @@ -64,10 +64,19 @@ DefaultAgentPrompt = f""" 总人物表是否存在: `QueryCharacterList` 剧集创作情况: `QueryEpisodeCount` - ***注意:工具使用是需要你调用工具方法的;但是大多数情况下,你不需要查询文本的具体内容,只需要查询存在与否即可*** + ***注意:工具使用是需要你调用工具方法的;大多数情况下同一个方法只需要调用一次*** - ***每次用户的输入都会携带当前`总任务的进度与任务状态`,注意查看并分析是否应该回复用户等待或提醒用户确认继续下一步*** - # 总任务的进度与任务状态数据结构为 {{"step": "waiting_script", "status": "running", "from_type":"user", "reason": "waiting_script", "retry_count": 0, "query_args":{{}}}} + ***每次用户的输入都会携带最新的`任务列表`和`工作流参数`,注意查看并分析是否应该回复用户等待或提醒用户确认继续下一步*** + # 工作流参数包含字段如下: + "session_id": 会话ID 可用于工具方法的调用 + "task_index": 当前执行中的任务索引 + "from_type": 本次请求来着哪里 + user: 用户 + agent: 智能体返回 + + # 任务列表是一个数组,每项的数据结构如下: + "agent": 执行这个任务的智能体名称 + 字符串内容 可为空 为空时表示当前任务不需要调用智能体 step: 阶段名称 wait_for_input: 等待用户提供原始剧本 @@ -77,26 +86,41 @@ DefaultAgentPrompt = f""" episode_create_loop: 剧集创作 finish: 所有剧集创作完成 + "pause": 是否需要暂停 当需要和用户沟通时设置为true,任务会中断等待用户回复 + status: 当前阶段的状态 waiting: 等待用户反馈、确认 running: 进行中 failed: 失败 completed: 完成 - "from_type": 本次请求来着哪里 - user: 用户 - agent: 智能体返回 - - "reason": 失败原因,仅在`status`为`failed`时返回 + "reason": 失败原因,仅在`status`为`failed`时返回 字符串内容 - "retry_count": 重试次数 + "retry_count": 失败重试次数 + 整数内容 + "episode_create_num": 指定创作集数 仅在episode_create_loop阶段会返回,内容是数组,数组中每一项是指定创作的剧集编号(从1开始); + 整数数组 - "query_args": 用于调用工具方法的参数,可能包括的字段有: - "session_id": 会话ID,可用于查询`原始剧本` + # 你的职责 + 分析用户所有的输入信息,以及信息中的`任务列表`、`工作流参数`; + 首先读取`工作流参数`中的数据,判断from_type的值是user还是agent; + 如果是user,说明用户是在和你沟通,你需要根据用户的输入来判断下一步; + 如果是agent,说明是智能体执行完任务后返回的结果,你需要根据智能体的返回结果来判断下一步;根据task_index取出`任务列表`中的当前任务,判断他的状态来决定是否继续列表中的下一个任务; + 如果当前任务的状态是completed,说明当前任务完成,需要继续列表中的下一个任务; + 如果当前任务的状态是failed,说明当前任务失败,需要根据失败原因来判断是否需要重试;如果需要重试,需要增加重试次数,并且需要继续列表中的下一个任务; + 如果当前任务的状态是waiting,说明当前任务等待用户反馈,需要等待用户反馈后继续执行;此时任务中的pause属性需要修改为true; - # 职责 - 分析用户输入与`总任务的进度与任务状态`,以下是几种情况的示例: + `任务列表`的生成规则: + 1 `任务列表`为空时,你需要根据上文中`任务总体步骤描述`生成一个新的任务列表 + 2 执行`任务列表`中的第一个未完成的任务 + + 当你读取到一个空的任务列表 或者 任务列表中的所有任务都完成或失败时,你需要分析出一个新的任务列表; + 新的任务列表至少包含一个任务; + 任务列表中的每个任务代表的是一个后续智能体要执行的任务,其中scheduler代表你自己,大多数情况下这代表了用户回复了你的提问; + + 以下是任务列表的几种情况的示例: + 1 任务列表为空时,你需要生成一个新的任务列表;此时你的分析结果是需要用户 1 `wait_for_input` 向用户问好,并介绍你作为“爆款短剧操盘手”的身份和专业工作流程,礼貌地请用户提供需要改编的原始剧本。如果用户没有提供原始剧本,你将持续友好地提醒,此时状态始终为waiting,直到获取原始剧本为止。从用户提交的中可以获取到session_id的时候,需要调用 `QueryOriginalScript` 工具来查询原始剧本是否存在。 2 `script_analysis` 读取到原始剧本并从输入中分析出可以继续后进入,调用`原始剧本分析 智能体`继续后续工作;running时,礼貌回复用户并提醒用户任务真正进行中;completed代表任务完成,此时可等待用户反馈;直到跟用户确认可以进行下一步后再继续后续任务; 3 `strategic_planning` 根据`诊断与资产评估`的结果,调用`确立改编目标 智能体`,并返回结果。 @@ -108,13 +132,9 @@ DefaultAgentPrompt = f""" 请严格按照下列JSON结构返回数据,不要有其他任何多余的信息和描述: {{ - "step": "阶段名称",//取值范围在上述 step的描述中 不可写其他值 - "status": "当前阶段的状态",//取值范围在上述 status的描述中 不可写其他值 - "agent":'',//分析后得出由哪个智能体继续任务,此处为智能体名称;如果需要继续与用户交互或仅需要回复用户则为空字符串 "message":'',//回复给用户的内容,注意,仅在你与用户沟通时才返回,其他情况下不返回。比如用户的需求是要交给其他智能体处理时,这个属性应该为空 - "retry_count":0,//重试次数 - "episode_create_num":[1,2,3],//指定创作集数 仅在episode_create_loop阶段会返回,内容是数组,数组中每一项是指定创作的剧集编号(从1开始); - "node":'',//下一个节点名称,根据指定的agent名称,从取值范围列表中选择一个节点名称返回 + "task_list": [] //最新的任务列表 + "task_index": 0 //执行中的任务的索引 }} """ diff --git a/agent/strategic_planning.py b/agent/strategic_planning.py index d1348f5..34f5327 100644 --- a/agent/strategic_planning.py +++ b/agent/strategic_planning.py @@ -46,6 +46,7 @@ DefaultAgentPrompt = f""" {{ "type":'沟通',//回复类型: 沟通:需要跟用户确认或继续沟通时的类型;输出:沟通足够最终给出`改编思路`时的类型; "message":'',//回复给用户的话 + "total_episode_num":0,//需要创作的总集数 "adaptation_ideas":'',//`改编思路`内容,在type为`输出`时才会有值 }} """ diff --git a/doc/test.txt b/doc/test.txt index 942ffbc..e69de29 100644 --- a/doc/test.txt +++ b/doc/test.txt @@ -1,173 +0,0 @@ -为了在你的代码中正确接入火山引擎的工具调用功能,你需要修改 `huoshan_langchain.py` 和 `huoshan.py` 这两个文件,以实现**从 LangChain 工具到火山引擎 API 工具定义的格式转换**,以及**解析和处理来自 API 的工具调用响应**。 - -下面是根据火山引擎官方文档和你的代码,我为你整理的完整修改方案。 - ------ - -### 第一步:修改 `huoshan_langchain.py` - -这个文件是 LangChain 的封装层,负责连接你的工作流和火山引擎的底层 API。你需要在这里实现 `bind_tools` 和 `_generate` 方法来处理工具调用。 - -1. **导入必要的类**: - 需要添加 `ToolMessage` 和 `ToolCall`,它们是 LangChain 用于表示工具调用结果和工具调用的核心类。 - - ```python - from typing import Any, Dict, Iterable, Iterator, List, Optional, Sequence - from langchain_core.callbacks.manager import CallbackManagerForLLMRun - from langchain_core.language_models.chat_models import BaseChatModel - from langchain_core.messages import BaseMessage, AIMessage, HumanMessage, SystemMessage, ToolMessage, ToolCall - from langchain_core.outputs import ChatGeneration, ChatResult - from langchain_core.tools import BaseTool - from langchain_core.runnables import Runnable - from langchain.pydantic_v1 import BaseModel - from pydantic import Field - from volcenginesdkarkruntime.types.chat import ChatCompletionToolParam - from api.huoshan import HuoshanAPI - import json - ``` - -2. **实现 `bind_tools` 方法**: - 这个方法是 LangChain 用于将工具定义传递给你的模型封装。在这里,你需要将 LangChain 的 `BaseTool` 对象转换为火山引擎 API 所需的 `ChatCompletionToolParam` 格式。 - - ```python - class HuoshanChatModel(BaseChatModel): - # ... (其他代码不变) ... - - def bind_tools(self, tools: Sequence[BaseTool], **kwargs: Any) -> Runnable: - """将工具绑定到模型,并将其转换为火山引擎API所需的格式。""" - tool_definitions = [] - for tool_item in tools: - tool_definitions.append( - ChatCompletionToolParam( - type="function", - function={ - "name": tool_item.name, - "description": tool_item.description, - "parameters": tool_item.args_schema.schema() if isinstance(tool_item.args_schema, type(BaseModel)) else tool_item.args_schema - } - ) - ) - - # 返回一个绑定了工具的新实例 - # 这里我们使用_bind方法,它会返回一个新的Runnable实例 - return self._bind(tools=tool_definitions, **kwargs) - ``` - -3. **修改 `_convert_messages_to_prompt` 方法**: - 这个方法需要能够处理 LangChain 的 `ToolMessage` 和 `AIMessage`,并将其转换为火山引擎 API 的消息格式。这对于工具调用的回填和最终回复至关重要。 - - ```python - class HuoshanChatModel(BaseChatModel): - # ... (其他代码不变) ... - - def _convert_messages_to_prompt(self, messages: List[BaseMessage]) -> List[Dict]: - """将LangChain消息转换为火山引擎API所需的格式。""" - api_messages = [] - for msg in messages: - if isinstance(msg, HumanMessage): - api_messages.append({"role": "user", "content": msg.content}) - elif isinstance(msg, AIMessage): - if msg.tool_calls: - api_messages.append({ - "role": "assistant", - "content": "", - "tool_calls": [{ - "id": tc.id, - "type": "function", - "function": { - "name": tc.name, - "arguments": json.dumps(tc.args) - } - } for tc in msg.tool_calls] - }) - else: - api_messages.append({"role": "assistant", "content": msg.content}) - elif isinstance(msg, ToolMessage): - api_messages.append({ - "role": "tool", - "content": msg.content, - "tool_call_id": msg.tool_call_id - }) - return api_messages - ``` - -4. **修改 `_generate` 方法**: - 这个方法需要调用底层 API,并解析大模型返回的响应,以检查是否包含工具调用。 - - ```python - class HuoshanChatModel(BaseChatModel): - # ... (其他代码不变) ... - - def _generate( - self, - messages: List[BaseMessage], - stop: Optional[List[str]] = None, - run_manager: Optional[CallbackManagerForLLMRun] = None, - **kwargs: Any, - ) -> ChatResult: - if not self._api: - raise ValueError("HuoshanAPI未正确初始化") - - api_messages = self._convert_messages_to_prompt(messages) - tools = kwargs.get("tools", []) - - response_data = self._api.get_chat_response(messages=api_messages, tools=tools) - - try: - message_from_api = response_data.get("choices", [{}])[0].get("message", {}) - - tool_calls = message_from_api.get("tool_calls", []) - if tool_calls: - lc_tool_calls = [] - for tc in tool_calls: - lc_tool_calls.append(ToolCall( - name=tc["function"]["name"], - args=json.loads(tc["function"]["arguments"]), - id=tc.get("id", "") - )) - message = AIMessage(content="", tool_calls=lc_tool_calls) - else: - content = message_from_api.get("content", "") - message = AIMessage(content=content) - - generation = ChatGeneration(message=message) - return ChatResult(generations=[generation]) - - except Exception as e: - raise ValueError(f"处理火山引擎API响应失败: {str(e)}") - ``` - -### 第二步:修改 `huoshan.py` - -这个文件是底层 API 客户端,负责与火山引擎 API 进行通信。你需要修改 `get_chat_response` 方法,使其能够发送 `tools` 参数。 - -```python -# huoshan.py - -from volcenginesdkarkruntime.types.chat import ChatCompletionToolParam - -class HuoshanAPI: - # ... (其他代码不变) ... - - def get_chat_response( - self, - messages: List[Dict], - stream: bool = False, - tools: Optional[List[ChatCompletionToolParam]] = None - ) -> Dict[str, Any]: - """同步获取聊天响应,支持工具调用。""" - client = Ark() - - try: - completion = client.chat.completions.create( - model=self.doubao_seed_1_6_model_id, - messages=messages, - stream=stream, - tools=tools # 传入 tools 参数 - ) - return completion.model_dump() # 使用 model_dump() 转换为字典 - except Exception as e: - raise ValueError(f"火山引擎API调用失败: {str(e)}") -``` - -完成以上修改后,你的 `HuoshanChatModel` 就会支持工具调用,并能与 LangGraph 的智能体框架无缝集成。 \ No newline at end of file diff --git a/graph/test_agent_graph_1.py b/graph/test_agent_graph_1.py index e8a110e..a5d1d4d 100644 --- a/graph/test_agent_graph_1.py +++ b/graph/test_agent_graph_1.py @@ -3,6 +3,7 @@ 该模块定义了智能编剧系统的完整工作流程图,包括各个节点和边的连接关系。 """ +from re import T from typing import TypedDict, Annotated, Dict, Any, List, TypedDict, Optional from langgraph.graph.state import RunnableConfig @@ -12,6 +13,7 @@ from agent.episode_create import EpisodeCreateAgent from agent.script_analysis import ScriptAnalysisAgent from agent.strategic_planning import StrategicPlanningAgent +from langgraph.checkpoint.memory import InMemorySaver from langchain_core.messages import AnyMessage,HumanMessage from langgraph.graph import StateGraph, START, END from utils.logger import get_logger @@ -26,9 +28,34 @@ from tools.agent.queryDB import QueryOriginalScript logger = get_logger(__name__) -def messages_handler(old_messages, new_messages): +def clear_messages(messages): + """清除指定会话的所有消息""" + # 剔除历史状态消息 + exclude_strings = [ + "---任务状态消息(开始)---", + "---原始剧本(开始)---", + "---诊断与资产评估报告(开始)---", + "---改编思路(开始)---", + ] + messages = [ + message for message in messages + if all(s not in message.content for s in exclude_strings) + ] + # HumanMessage 超过 10 条,删除最早的 1 条 + if len([message for message in messages if message.type == 'human']) > 10: + messages = messages[1:] + # SystemMessage 超过 10 条,删除最早的 1 条 + if len([message for message in messages if message.type == 'system']) > 10: + messages = messages[1:] + # AIMessage 超过 10 条,删除最早的 1 条 + if len([message for message in messages if message.type == 'ai']) > 10: + messages = messages[1:] + return messages + +def messages_handler(old_messages: list[AnyMessage], new_messages: list[AnyMessage]): """消息合并方法""" - return new_messages + clear_messages(old_messages) + return old_messages + new_messages def replace_value(old_val, new_val): """值覆盖方法""" @@ -48,13 +75,24 @@ class OutputState(TypedDict): error: Annotated[str, replace_value] agent_message: Annotated[str, replace_value] # 智能体回复 -class NodeInfo(TypedDict): - """工作流信息""" +class TaskItem(TypedDict): + """任务列表中的每个任务""" + agent: Annotated[str, replace_value] # 执行这个任务的智能体名称 step: Annotated[str, replace_value] # 阶段名称 [wait_for_input,script_analysis,strategic_planning,build_bible,episode_create_loop, finish] status: Annotated[str, replace_value] # 当前阶段的状态 [waiting,running,failed,completed] reason: Annotated[str, replace_value] # 失败原因 retry_count: Annotated[int, replace_value] # 重试次数 - from_type: Annotated[str, replace_value] # 本次请求来着哪里 [user, agent] + pause: Annotated[bool, replace_value] # 是否暂停 + episode_create_num: Annotated[List[int], replace_value] # 指定创作集数 仅在episode_create_loop阶段会返回,内容是数组,数组中每一项是指定创作的剧集编号(从1开始); + +class WorkflowInfo(TypedDict): + """总工作流信息""" + is_original_script: Annotated[bool, replace_value] # 是否已提交原始剧本 + is_script_analysis: Annotated[bool, replace_value] # 是否已生成 诊断与资产评估报告 + is_strategic_planning: Annotated[bool, replace_value] # 是否已生成 改编思路 + is_build_bible: Annotated[bool, replace_value] # 是否已生成 剧本圣经 + is_episode_create_loop: Annotated[bool, replace_value] # 是否已生成 剧集生成循环 + is_all_episode_created: Annotated[bool, replace_value] # 是否已生成 全部剧集 class ScriptwriterState(TypedDict, total=False): @@ -63,22 +101,25 @@ class ScriptwriterState(TypedDict, total=False): messages: Annotated[list[AnyMessage], messages_handler] session_id: Annotated[str, replace_value] from_type: Annotated[str, replace_value] # 本次请求来着哪里 [user, agent] - - # 节点间状态 - next_node: Annotated[str, replace_value] # 下一个节点 - workflow_step: Annotated[str, replace_value] # 阶段名称 [wait_for_input,script_analysis,strategic_planning,build_bible,episode_create_loop, finish] - workflow_status: Annotated[str, replace_value] # 当前阶段的状态 [waiting,running,failed,completed] - workflow_reason: Annotated[str, replace_value] # 失败原因 - workflow_retry_count: Annotated[int, replace_value] # 重试次数 - + # 中间状态 - task_list: Annotated[List[Dict[str, Any]], replace_value] # 顺序执行的任务列表 + workflow_info: Annotated[WorkflowInfo, replace_value] # 顺序执行的任务列表 + task_list: Annotated[List[TaskItem], replace_value] # 顺序执行的任务列表 + task_index: Annotated[int, replace_value] # 当前执行中的任务索引 # 输出数据 agent_message: Annotated[str, replace_value] # 智能体回复 status: Annotated[str, replace_value] error: Annotated[str, replace_value] +AgentNodeMap = { + "scheduler": "scheduler_node", + "script_analysis": "script_analysis_node", + "strategic_planning": "strategic_planning_node", + "build_bible": "build_bible_node", + "episode_create": "episode_create_node", +} + class ScriptwriterGraph: """智能编剧工作流图类 @@ -100,12 +141,15 @@ class ScriptwriterGraph: def node_router(self, state: ScriptwriterState) -> str: """节点路由函数""" - print(f'node_router state {state}') - next_node = state.get("next_node", 'pause_node') - # 修复:当 next_node 为空字符串时,设置默认值 - if not next_node: + # print(f'node_router state {state}') + task_list = state.get("task_list", []) + task_index = state.get("task_index", 0) + now_task = task_list[task_index] + if not now_task or now_task.get('pause'): next_node = 'pause_node' # 设置为暂停节点 - print(f'node_router next_node {next_node}') + else: + next_node = AgentNodeMap.get(now_task.get('agent'), 'pause_node') + # print(f'node_router next_node {next_node}') return next_node def _build_graph(self) -> None: @@ -222,63 +266,41 @@ class ScriptwriterGraph: workflow.add_edge("end_node", END) # 编译图 - self.graph = workflow.compile(checkpointer=self.memory) + checkpoint = InMemorySaver() + self.graph = workflow.compile(checkpointer=checkpoint) # 不缓存记忆 + # self.graph = workflow.compile(checkpointer=self.memory) # 使用mongodb缓存记忆 logger.info("工作流图构建完成") except Exception as e: logger.error(f"构建工作流图失败: {e}") raise - def clear_messages(self, messages): - """清除指定会话的所有消息""" - # 剔除历史状态消息 - messages = [message for message in messages if "---任务状态消息(开始)---" not in message.content ] - # HumanMessage 超过 10 条,删除最早的 1 条 - if len([message for message in messages if message.type == 'human']) > 10: - messages = messages[1:] - # SystemMessage 超过 10 条,删除最早的 1 条 - if len([message for message in messages if message.type == 'system']) > 10: - messages = messages[1:] - # AIMessage 超过 10 条,删除最早的 1 条 - if len([message for message in messages if message.type == 'ai']) > 10: - messages = messages[1:] - return messages # --- 定义图中的节点 --- async def scheduler_node(self, state: ScriptwriterState)-> ScriptwriterState: """调度节点""" - try: - status = state.get("status", "") + try: session_id = state.get("session_id", "") + workflow_info = state.get("workflow_info", {}) + task_list = state.get("task_list", []) + task_index = int(state.get("task_index", 0)) from_type = state.get("from_type", "") messages = state.get("messages", []) - if status == 'failed': - return { - "next_node":'end_node', - "agent_message": state.get("agent_message", ""), - "error": state.get("error", '系统错误,工作流已终止'), - 'status':'failed', - } # 清除历史状态消息 - messages = self.clear_messages(messages) - workflow_step = state.get("workflow_step", "wait_for_input") - workflow_status = state.get("workflow_status", "waiting") - workflow_reason = state.get("workflow_reason", "") - workflow_retry_count = int(state.get("workflow_retry_count", 0)) + messages = clear_messages(messages) # 添加参数进提示词 messages.append(HumanMessage(content=f""" ---任务状态消息(开始)--- - # 总任务的进度与任务状态: + # 工作流信息: + {workflow_info} + # 工作流参数: {{ - 'query_args':{{ - 'session_id':'{session_id}', - }}, - 'step':'{workflow_step}', - 'status':'{workflow_status}', + 'session_id':'{session_id}', + 'task_index':'{task_index}', 'from_type':'{from_type}', - 'reason':'{workflow_reason}', - 'retry_count':{workflow_retry_count}, }} + # 任务列表: + {task_list} ---任务状态消息(结束)--- """)) system_message_count = 0 @@ -291,37 +313,26 @@ class ScriptwriterGraph: human_message_count += 1 elif message.type == 'ai': ai_message_count += 1 - logger.info(f"调度节点 {session_id} 输入消息条数: {len(messages)} from_type:{from_type} system_message_count:{system_message_count} human_message_count:{human_message_count} ai_message_count:{ai_message_count}") - reslut = await self.schedulerAgent.ainvoke(state) + logger.info(f"调度节点 {session_id} \n 输入消息条数: {len(messages)} \n from_type:{from_type} \n system_message_count:{system_message_count} \n human_message_count:{human_message_count} \n ai_message_count:{ai_message_count}") + # 调用智能体 + reslut = await self.schedulerAgent.ainvoke({"messages":messages}) ai_message_str = reslut['messages'][-1].content ai_message = json.loads(ai_message_str) # logger.info(f"调度节点结果: {ai_message}") - step:str = ai_message.get('step', '') - status:str = ai_message.get('status', '') - next_agent:str = ai_message.get('agent', '') return_message:str = ai_message.get('message', '') - retry_count:int = int(ai_message.get('retry_count', '0')) - next_node:str = ai_message.get('node', 'pause_node') - if next_node == 'scheduler_node': - # 返回自身 代表暂停 - print(f"调度节点 暂停等待") - return { + task_list:list = ai_message.get('task_list', []) + task_index:int = int(ai_message.get('task_index', '0')) + + return { + "messages": messages, + "task_list": task_list, + "task_index": task_index, "agent_message": return_message, } - else: - return { - "next_node":next_node, - "workflow_step":step, - "workflow_status":status, - # "workflow_reason":return_message, - "workflow_retry_count":retry_count, - "agent_message":return_message, - } except Exception as e: - # import traceback - # traceback.print_exc() + import traceback + traceback.print_exc() return { - "next_node":'end_node', "agent_message": "执行失败", "error": str(e) or '系统错误,工作流已终止', 'status':'failed', @@ -335,53 +346,47 @@ class ScriptwriterGraph: from_type = state.get("from_type", "") messages = state.get("messages", []) # 清除历史状态消息 - messages = self.clear_messages(messages) - workflow_step = state.get("workflow_step", "wait_for_input") - workflow_status = state.get("workflow_status", "waiting") - workflow_reason = state.get("workflow_reason", "") - workflow_retry_count = int(state.get("workflow_retry_count", 0)) + messages = clear_messages(messages) # 添加参数进提示词 + from tools.agent.queryDB import QueryOriginalScriptContent + original_script_content = QueryOriginalScriptContent(session_id) messages.append(HumanMessage(content=f""" - ---任务状态消息(开始)--- - # 总任务的进度与任务状态: - {{ - 'query_args':{{ - 'session_id':'{session_id}', - }}, - 'step':'{workflow_step}', - 'status':'{workflow_status}', - 'from_type':'{from_type}', - 'reason':'{workflow_reason}', - 'retry_count':{workflow_retry_count}, - }} - ---任务状态消息(结束)--- + ---原始剧本(开始)--- + {original_script_content['content']} + ---原始剧本(结束)--- """)) - reslut = await self.scriptAnalysisAgent.ainvoke(state) + reslut = await self.scriptAnalysisAgent.ainvoke({"messages": messages}) ai_message_str = reslut['messages'][-1].content ai_message = json.loads(ai_message_str) - # logger.info(f"调度节点结果: {ai_message}") - step:str = ai_message.get('step', '') - status:str = ai_message.get('status', '') - next_agent:str = ai_message.get('agent', '') return_message:str = ai_message.get('message', '') - retry_count:int = int(ai_message.get('retry_count', '0')) - next_node:str = ai_message.get('node', 'pause_node') + + task_list:list = state.get('task_list', []) + task_index:int = int(state.get('task_index', '0')) + task = task_list[task_index] + if task: + type:str = ai_message.get('type', '沟通') + if type == '沟通': + task['status'] = 'waiting' + task['parse'] = True + elif type == '输出': + task['status'] = 'completed' + task['parse'] = False + diagnosis_and_assessment:str = ai_message.get('diagnosis_and_assessment', '') + from tools.agent.updateDB import UpdateDiagnosisAndAssessment + UpdateDiagnosisAndAssessment(session_id, diagnosis_and_assessment) # print(f"报告已生成: TEST") print("\n------------ 诊断分析结束 ------------") + return { - "from_type":'agent', - "next_node":next_node, - "workflow_step":step, - "workflow_status":status, - # "workflow_reason":return_message, - "workflow_retry_count":retry_count, - "agent_message":return_message, - } + "messages": messages, + "task_list": task_list, + "task_index": task_index, + "agent_message": return_message, + } except Exception as e: import traceback traceback.print_exc() return { - "next_node":'end_node', "agent_message": "诊断分析失败", "error": str(e) or '系统错误,工作流已终止', 'status':'failed', @@ -390,59 +395,61 @@ class ScriptwriterGraph: async def strategic_planning_node(self, state: ScriptwriterState)-> ScriptwriterState: """第三步:确立改编目标与战略蓝图""" try: - print("\n------------ 正在制定战略蓝图 ------------") + print("\n------------ 正在生成 改编思路 ------------") session_id = state.get("session_id", "") from_type = state.get("from_type", "") messages = state.get("messages", []) # 清除历史状态消息 - messages = self.clear_messages(messages) - workflow_step = state.get("workflow_step", "wait_for_input") - workflow_status = state.get("workflow_status", "waiting") - workflow_reason = state.get("workflow_reason", "") - workflow_retry_count = int(state.get("workflow_retry_count", 0)) + messages = clear_messages(messages) # 添加参数进提示词 + from tools.agent.queryDB import QueryOriginalScriptContent, QueryDiagnosisAndAssessmentContent + original_script_content = QueryOriginalScriptContent(session_id) messages.append(HumanMessage(content=f""" - ---任务状态消息(开始)--- - # 总任务的进度与任务状态: - {{ - 'query_args':{{ - 'session_id':'{session_id}', - }}, - 'step':'{workflow_step}', - 'status':'{workflow_status}', - 'from_type':'{from_type}', - 'reason':'{workflow_reason}', - 'retry_count':{workflow_retry_count}, - }} - ---任务状态消息(结束)--- + ---原始剧本(开始)--- + {original_script_content['content']} + ---原始剧本(结束)--- """)) - reslut = await self.strategicPlanningAgent.ainvoke(state) + diagnosis_and_assessment_content = QueryDiagnosisAndAssessmentContent(session_id) + messages.append(HumanMessage(content=f""" + ---诊断与资产评估报告(开始)--- + {diagnosis_and_assessment_content['content']} + ---诊断与资产评估报告(结束)--- + """)) + reslut = await self.strategicPlanningAgent.ainvoke({"messages": messages}) ai_message_str = reslut['messages'][-1].content ai_message = json.loads(ai_message_str) - # logger.info(f"调度节点结果: {ai_message}") - step:str = ai_message.get('step', '') - status:str = ai_message.get('status', '') - next_agent:str = ai_message.get('agent', '') return_message:str = ai_message.get('message', '') - retry_count:int = int(ai_message.get('retry_count', '0')) - next_node:str = ai_message.get('node', 'pause_node') + + task_list:list = state.get('task_list', []) + task_index:int = int(state.get('task_index', '0')) + task = task_list[task_index] + if task: + type:str = ai_message.get('type', '沟通') + if type == '沟通': + task['status'] = 'waiting' + task['parse'] = True + elif type == '输出': + task['status'] = 'completed' + task['parse'] = False + adaptation_ideas:str = ai_message.get('adaptation_ideas', '') + total_episode_num:int = int(ai_message.get('total_episode_num', 0)) + from tools.agent.updateDB import UpdateAdaptationIdeas,SetTotalEpisodeNum + UpdateAdaptationIdeas(session_id, adaptation_ideas) + SetTotalEpisodeNum(session_id, total_episode_num) # print(f"报告已生成: TEST") - print("\n------------ 制定战略蓝图结束 ------------") + print("\n------------ 生成 改编思路 结束 ------------") + return { - "from_type":'agent', - "next_node":next_node, - "workflow_step":step, - "workflow_status":status, - # "workflow_reason":return_message, - "workflow_retry_count":retry_count, - "agent_message":return_message, - } + "messages": messages, + "task_list": task_list, + "task_index": task_index, + "agent_message": return_message, + } except Exception as e: import traceback traceback.print_exc() return { - "next_node":'end_node', - "agent_message": "制定战略蓝图失败", + "agent_message": "生成 改编思路 失败", "error": str(e) or '系统错误,工作流已终止', 'status':'failed', } @@ -450,120 +457,134 @@ class ScriptwriterGraph: async def build_bible_node(self, state: ScriptwriterState)-> ScriptwriterState: """第四步:制定剧本圣经""" try: - print("\n------------ 正在制定剧本圣经 ------------") + print("\n------------ 正在生成 剧本圣经 ------------") session_id = state.get("session_id", "") from_type = state.get("from_type", "") messages = state.get("messages", []) # 清除历史状态消息 - messages = self.clear_messages(messages) - workflow_step = state.get("workflow_step", "wait_for_input") - workflow_status = state.get("workflow_status", "waiting") - workflow_reason = state.get("workflow_reason", "") - workflow_retry_count = int(state.get("workflow_retry_count", 0)) + messages = clear_messages(messages) # 添加参数进提示词 + from tools.agent.queryDB import QueryOriginalScriptContent, QueryAdaptationIdeasContent + original_script_content = QueryOriginalScriptContent(session_id) messages.append(HumanMessage(content=f""" - ---任务状态消息(开始)--- - # 总任务的进度与任务状态: - {{ - 'query_args':{{ - 'session_id':'{session_id}', - }}, - 'step':'{workflow_step}', - 'status':'{workflow_status}', - 'from_type':'{from_type}', - 'reason':'{workflow_reason}', - 'retry_count':{workflow_retry_count}, - }} - ---任务状态消息(结束)--- + ---原始剧本(开始)--- + {original_script_content['content']} + ---原始剧本(结束)--- """)) - reslut = await self.buildBibleAgent.ainvoke(state) + adaptation_ideas_content = QueryAdaptationIdeasContent(session_id) + messages.append(HumanMessage(content=f""" + ---改编思路(开始)--- + {adaptation_ideas_content['content']} + ---改编思路(结束)--- + """)) + reslut = await self.buildBibleAgent.ainvoke({"messages": messages}) ai_message_str = reslut['messages'][-1].content ai_message = json.loads(ai_message_str) - # logger.info(f"调度节点结果: {ai_message}") - step:str = ai_message.get('step', '') - status:str = ai_message.get('status', '') - next_agent:str = ai_message.get('agent', '') return_message:str = ai_message.get('message', '') - retry_count:int = int(ai_message.get('retry_count', '0')) - next_node:str = ai_message.get('node', 'pause_node') + + task_list:list = state.get('task_list', []) + task_index:int = int(state.get('task_index', '0')) + task = task_list[task_index] + if task: + type:str = ai_message.get('type', '沟通') + if type == '沟通': + task['status'] = 'waiting' + task['parse'] = True + elif type == '输出': + task['status'] = 'completed' + task['parse'] = False + script_bible:dict = ai_message.get('script_bible', {}) + core_outline:str = script_bible.get('core_outline', '') + character_profile:str = script_bible.get('character_profile', '') + core_event_timeline:str = script_bible.get('core_event_timeline', '') + character_list:str = script_bible.get('character_list', '') + from tools.agent.updateDB import UpdateScriptBible + UpdateScriptBible(session_id, core_outline, character_profile, core_event_timeline, character_list) # print(f"报告已生成: TEST") - print("\n------------ 制定剧本圣经结束 ------------") + print("\n------------ 生成 剧本圣经 结束 ------------") + return { - "from_type":'agent', - "next_node":next_node, - "workflow_step":step, - "workflow_status":status, - # "workflow_reason":return_message, - "workflow_retry_count":retry_count, - "agent_message":return_message, - } + "messages": messages, + "task_list": task_list, + "task_index": task_index, + "agent_message": return_message, + } except Exception as e: import traceback traceback.print_exc() return { - "next_node":'end_node', - "agent_message": "制定剧本圣经失败", + "agent_message": "生成 剧本圣经 失败", "error": str(e) or '系统错误,工作流已终止', 'status':'failed', } async def episode_create_node(self, state: ScriptwriterState)-> ScriptwriterState: - """第五步:动态创作与闭环校验(循环主体)""" + """第五步:循环创作剧本内容""" try: - print("\n------------ 正在创作单集内容 ------------") session_id = state.get("session_id", "") from_type = state.get("from_type", "") + task_list:list = state.get('task_list', []) + task_index:int = int(state.get('task_index', '0')) + task:dict = task_list[task_index] or {} + episode_create_num:list = task.get('episode_create_num', []) + print(f"\n------------ 正在生成 剧集内容 {episode_create_num} ------------") messages = state.get("messages", []) # 清除历史状态消息 - messages = self.clear_messages(messages) - workflow_step = state.get("workflow_step", "wait_for_input") - workflow_status = state.get("workflow_status", "waiting") - workflow_reason = state.get("workflow_reason", "") - workflow_retry_count = int(state.get("workflow_retry_count", 0)) + messages = clear_messages(messages) + # 添加参数进提示词 + from tools.agent.queryDB import QueryOriginalScriptContent, QueryAdaptationIdeasContent + original_script_content = QueryOriginalScriptContent(session_id) + messages.append(HumanMessage(content=f""" + ---原始剧本(开始)--- + {original_script_content['content']} + ---原始剧本(结束)--- + """)) + adaptation_ideas_content = QueryAdaptationIdeasContent(session_id) + messages.append(HumanMessage(content=f""" + ---改编思路(开始)--- + {adaptation_ideas_content['content']} + ---改编思路(结束)--- + """)) # 添加参数进提示词 messages.append(HumanMessage(content=f""" ---任务状态消息(开始)--- - # 总任务的进度与任务状态: + # 工作流参数: {{ - 'query_args':{{ - 'session_id':'{session_id}', - }}, - 'step':'{workflow_step}', - 'status':'{workflow_status}', - 'from_type':'{from_type}', - 'reason':'{workflow_reason}', - 'retry_count':{workflow_retry_count}, + 'episode_create_num':'{episode_create_num}', }} ---任务状态消息(结束)--- """)) - reslut = await self.episodeCreateAgent.ainvoke(state) + reslut = await self.buildBibleAgent.ainvoke({"messages": messages}) ai_message_str = reslut['messages'][-1].content ai_message = json.loads(ai_message_str) - # logger.info(f"调度节点结果: {ai_message}") - step:str = ai_message.get('step', '') - status:str = ai_message.get('status', '') - next_agent:str = ai_message.get('agent', '') return_message:str = ai_message.get('message', '') - retry_count:int = int(ai_message.get('retry_count', '0')) - next_node:str = ai_message.get('node', 'pause_node') + + type:str = ai_message.get('type', '沟通') + if type == '沟通': + task['status'] = 'waiting' + task['parse'] = True + elif type == '输出': + task['status'] = 'completed' + task['parse'] = False + episodes:list = ai_message.get('episodes', []) + from tools.agent.updateDB import UpdateOneEpisode + for episode in episodes: + UpdateOneEpisode(session_id, episode.number, episode.content) # print(f"报告已生成: TEST") - print("\n------------ 创作单集内容结束 ------------") + print("\n------------ 生成 剧本圣经 结束 ------------") + return { - "from_type":'agent', - "next_node":next_node, - "workflow_step":step, - "workflow_status":status, - # "workflow_reason":return_message, - "workflow_retry_count":retry_count, - "agent_message":return_message, - } + "messages": messages, + "task_list": task_list, + "task_index": task_index, + "agent_message": return_message, + } except Exception as e: import traceback traceback.print_exc() return { - "next_node":'end_node', - "agent_message": "创作单集内容失败", + "agent_message": "生成 剧本圣经 失败", "error": str(e) or '系统错误,工作流已终止', 'status':'failed', } diff --git a/handlers/langgraph_handler.py b/handlers/langgraph_handler.py deleted file mode 100644 index feabc8c..0000000 --- a/handlers/langgraph_handler.py +++ /dev/null @@ -1,84 +0,0 @@ -from flask import request, jsonify -import asyncio -import uuid -from graph.test_graph_3 import run_with_persistence, get_checkpoint_history, resume_from_checkpoint - -def run_async(coro): - """运行异步函数的辅助函数""" - try: - # 尝试使用现有的事件循环 - loop = asyncio.get_running_loop() - except RuntimeError: - # 如果没有运行中的事件循环,则创建一个新的 - loop = asyncio.new_event_loop() - asyncio.set_event_loop(loop) - - return loop.run_until_complete(coro) - -def run_langgraph(): - """启动一个新的langgraph任务""" - try: - data = request.get_json() - user_input = data.get('user_input', '') - thread_id = data.get('thread_id', str(uuid.uuid4())) - - if not user_input: - return jsonify({'error': 'user_input is required'}), 400 - - # 运行异步函数 - output, thread_id = run_async(run_with_persistence(user_input, thread_id)) - - return jsonify({ - 'status': 'success', - 'thread_id': thread_id, - 'output': output - }) - except Exception as e: - return jsonify({'error': str(e)}), 500 - -def get_task_status(thread_id): - """查询任务状态和历史""" - try: - # 获取检查点历史 - # 注意:这里需要修改get_checkpoint_history以返回数据而不是打印 - # history = run_async(get_checkpoint_history(thread_id)) - - return jsonify({ - 'status': 'success', - 'thread_id': thread_id, - 'message': 'Task status endpoint' - }) - except Exception as e: - return jsonify({'error': str(e)}), 500 - -def resume_task(thread_id): - """从检查点恢复任务""" - try: - data = request.get_json() - checkpoint_id = data.get('checkpoint_id') - - # 恢复检查点状态 - restored_state = resume_from_checkpoint(thread_id, checkpoint_id) - - if restored_state: - return jsonify({ - 'status': 'success', - 'restored_state': restored_state - }) - else: - return jsonify({'error': 'Failed to restore checkpoint'}), 404 - except Exception as e: - return jsonify({'error': str(e)}), 500 - -def visualize_graph(thread_id): - """可视化图结构""" - try: - # 这里可以返回图的可视化信息 - # 为了简化,我们只返回基本信息 - return jsonify({ - 'status': 'success', - 'thread_id': thread_id, - 'message': 'Graph visualization endpoint' - }) - except Exception as e: - return jsonify({'error': str(e)}), 500 \ No newline at end of file diff --git a/tools/agent/queryDB.py b/tools/agent/queryDB.py index 355d82e..939231e 100644 --- a/tools/agent/queryDB.py +++ b/tools/agent/queryDB.py @@ -17,7 +17,22 @@ def QueryOriginalScript(session_id: str): "exist": script is not None, } +def QueryOriginalScriptContent(session_id: str): + """ + 查询原始剧本内容 + Args: + session_id: 会话id + Returns: + Dict: 返回一个包含以下字段的字典: + content (str): 原始剧本内容。 + """ + script = mainDB.agent_writer_session.find_one({"_id": ObjectId(session_id), "original_script": {"$exists": True, "$ne": ""}},{"original_script":1}) + return { + "content": script["original_script"] if script else "", + } + +@tool def QueryDiagnosisAndAssessment(session_id: str): """ 查询诊断与资产评估报告是否存在 @@ -32,6 +47,21 @@ def QueryDiagnosisAndAssessment(session_id: str): "exist": script is not None, } +def QueryDiagnosisAndAssessmentContent(session_id: str): + """ + 查询诊断与资产评估报告内容 + Args: + session_id: 会话id + Returns: + Dict: 返回一个包含以下字段的字典: + content (str): 诊断与资产评估报告内容。 + """ + script = mainDB.agent_writer_session.find_one({"_id": ObjectId(session_id), "diagnosis_and_assessment": {"$exists": True, "$ne": ""}},{"diagnosis_and_assessment":1}) + return { + "content": script["diagnosis_and_assessment"] if script else "", + } + +@tool def QueryAdaptationIdeas(session_id: str): """ 查询改编思路是否存在 @@ -46,6 +76,21 @@ def QueryAdaptationIdeas(session_id: str): "exist": script is not None, } +def QueryAdaptationIdeasContent(session_id: str): + """ + 查询改编思路内容 + Args: + session_id: 会话id + Returns: + Dict: 返回一个包含以下字段的字典: + content (str): 改编思路内容。 + """ + script = mainDB.agent_writer_session.find_one({"_id": ObjectId(session_id), "adaptation_ideas": {"$exists": True, "$ne": ""}},{"adaptation_ideas":1}) + return { + "content": script["adaptation_ideas"] if script else "", + } + +@tool def QueryScriptBible(session_id: str): """ 查询剧本圣经是否存在 @@ -60,6 +105,21 @@ def QueryScriptBible(session_id: str): "exist": script is not None, } +def QueryScriptBibleContent(session_id: str): + """ + 查询剧本圣经内容 + Args: + session_id: 会话id + Returns: + Dict: 返回一个包含以下字段的字典: + content (str): 剧本圣经内容。 + """ + script = mainDB.agent_writer_session.find_one({"_id": ObjectId(session_id), "script_bible": {"$exists": True}},{"script_bible":1}) + return { + "content": script["script_bible"] if script else {}, + } + +@tool def QueryCoreOutline(session_id: str): """ 查询剧本圣经中的核心大纲是否存在 @@ -74,6 +134,7 @@ def QueryCoreOutline(session_id: str): "exist": script is not None, } +@tool def QueryCharacterProfile(session_id: str): """ 查询剧本圣经中的核心人物小传是否存在 @@ -88,6 +149,7 @@ def QueryCharacterProfile(session_id: str): "exist": script is not None, } +@tool def QueryCoreEventTimeline(session_id: str): """ 查询剧本圣经中的重大事件时间线是否存在 @@ -102,6 +164,7 @@ def QueryCoreEventTimeline(session_id: str): "exist": script is not None, } +@tool def QueryCharacterList(session_id: str): """ 查询剧本圣经中的总人物表是否存在 @@ -116,6 +179,7 @@ def QueryCharacterList(session_id: str): "exist": script is not None, } +@tool def QueryEpisodeCount(session_id: str): """ 查询剧集创作情况 @@ -126,11 +190,16 @@ def QueryEpisodeCount(session_id: str): completed (int): 已完成的集数 total (int): 总集数 """ - total = mainDB.agent_writer_episodes.count_documents({"session_id": session_id}) + total = mainDB.agent_writer_session.find_one({"_id": ObjectId(session_id)},{"total_episode_num":1}) + if total is None: + return { + "completed": 0, + "total": 0, + } count = mainDB.agent_writer_episodes.count_documents({"session_id": session_id, "content": {"$exists": True, "$ne": ""}}) return { "completed": count, - "total": total, + "total": int(total["total_episode_num"]) or 0, } # def QuerySingleEpisodeContent(session_id: str): diff --git a/tools/agent/updateDB.py b/tools/agent/updateDB.py new file mode 100644 index 0000000..e483694 --- /dev/null +++ b/tools/agent/updateDB.py @@ -0,0 +1,176 @@ +from bson import ObjectId +from tools.database.mongo import mainDB +from langchain.tools import tool + +@tool +def UpdateDiagnosisAndAssessmentTool(session_id: str, diagnosis_and_assessment: str): + """ + 更新诊断与资产评估报告 + Args: + session_id: 会话id + diagnosis_and_assessment: 诊断与资产评估报告 + Returns: + Dict: 返回一个包含以下字段的字典: + success (bool): 是否更新成功 + """ + return UpdateDiagnosisAndAssessment(session_id, diagnosis_and_assessment) + +def UpdateDiagnosisAndAssessment(session_id: str, diagnosis_and_assessment: str): + """ + 更新诊断与资产评估报告 + Args: + session_id: 会话id + diagnosis_and_assessment: 诊断与资产评估报告 + Returns: + Dict: 返回一个包含以下字段的字典: + success (bool): 是否更新成功 + """ + script = mainDB.agent_writer_session.update_one({"_id": ObjectId(session_id)},{"$set":{"diagnosis_and_assessment":diagnosis_and_assessment}}) + return { + "success": script.modified_count > 0, + } + +@tool +def UpdateAdaptationIdeasTool(session_id: str, adaptation_ideas: str): + """ + 更新改编思路 + Args: + session_id: 会话id + adaptation_ideas: 改编思路 + Returns: + Dict: 返回一个包含以下字段的字典: + success (bool): 是否更新成功 + """ + return UpdateAdaptationIdeas(session_id, adaptation_ideas) + +def UpdateAdaptationIdeas(session_id: str, adaptation_ideas: str): + """ + 更新改编思路 + Args: + session_id: 会话id + adaptation_ideas: 改编思路 + Returns: + Dict: 返回一个包含以下字段的字典: + success (bool): 是否更新成功 + """ + script = mainDB.agent_writer_session.update_one({"_id": ObjectId(session_id)},{"$set":{"adaptation_ideas":adaptation_ideas}}) + return { + "success": script.modified_count > 0, + } + +@tool +def UpdateScriptBibleTool( + session_id: str, + core_outline:str|None = None, + character_profile:str|None = None, + core_event_timeline:str|None = None, + character_list:str|None = None, + ): + """ + 更新剧本圣经 + Args: + session_id: 会话id + core_outline: 核心大纲 + character_profile: 核心人物小传 + core_event_timeline: 核心事件时间线 + character_list: 总角色表 + Returns: + Dict: 返回一个包含以下字段的字典: + success (bool): 是否更新成功 + """ + return UpdateScriptBible(session_id, core_outline, character_profile, core_event_timeline, character_list) + +def UpdateScriptBible( + session_id: str, + core_outline:str|None = None, + character_profile:str|None = None, + core_event_timeline:str|None = None, + character_list:str|None = None, + ): + """ + 更新剧本圣经 + Args: + session_id: 会话id + core_outline: 核心大纲 + character_profile: 核心人物小传 + core_event_timeline: 核心事件时间线 + character_list: 总角色表 + Returns: + Dict: 返回一个包含以下字段的字典: + success (bool): 是否更新成功 + """ + if core_outline is None and character_profile is None and core_event_timeline is None and character_list is None: + return { + "success": False, + } + update_dict = {} + if core_outline is not None: + update_dict["script_bible.core_outline"] = core_outline + if character_profile is not None: + update_dict["script_bible.character_profile"] = character_profile + if core_event_timeline is not None: + update_dict["script_bible.core_event_timeline"] = core_event_timeline + if character_list is not None: + update_dict["script_bible.character_list"] = character_list + script = mainDB.agent_writer_session.update_one({"_id": ObjectId(session_id)},{"$set":update_dict}) + return { + "success": script.modified_count > 0, + } + +@tool +def SetTotalEpisodeNumTool(session_id: str, total_episode_num: int): + """ + 设置总集数 + Args: + session_id: 会话id + total_episode_num: 总集数 + Returns: + Dict: 返回一个包含以下字段的字典: + success (bool): 是否更新成功 + """ + return SetTotalEpisodeNum(session_id, total_episode_num) + +def SetTotalEpisodeNum(session_id: str, total_episode_num: int): + """ + 设置总集数 + Args: + session_id: 会话id + total_episode_num: 总集数 + Returns: + Dict: 返回一个包含以下字段的字典: + success (bool): 是否更新成功 + """ + script = mainDB.agent_writer_session.update_one({"_id": ObjectId(session_id)},{"$set":{"total_episode_num":total_episode_num}}) + return { + "success": script.modified_count > 0, + } + +@tool +def UpdateOneEpisodeTool(session_id: str, episode_num:int, content: str): + """ + 更新单集内容 + Args: + session_id: 会话id + episode_num: 剧集编号 + content: 剧集内容 + Returns: + Dict: 返回一个包含以下字段的字典: + success (bool): 是否更新成功 + """ + return UpdateOneEpisode(session_id, episode_num, content) + +def UpdateOneEpisode(session_id: str, episode_num:int, content: str): + """ + 更新单集内容 + Args: + session_id: 会话id + episode_num: 剧集编号 + content: 剧集内容 + Returns: + Dict: 返回一个包含以下字段的字典: + success (bool): 是否更新成功 + """ + script = mainDB.agent_writer_episodes.update_one({"session_id": session_id, "episode_num": episode_num},{"$set":{"content":content}}, upsert=True) + return { + "success": script.modified_count > 0, + } \ No newline at end of file diff --git a/tools/llm/huoshan_langchain.py b/tools/llm/huoshan_langchain.py index 97d65a7..8b5f480 100644 --- a/tools/llm/huoshan_langchain.py +++ b/tools/llm/huoshan_langchain.py @@ -48,26 +48,7 @@ class HuoshanChatModel(BaseChatModel): def _llm_type(self) -> str: """返回LLM类型标识""" return "huoshan_chat" - - # def _convert_messages_to_prompt(self, messages: List[BaseMessage]) -> tuple[str, str]: - # """将LangChain消息格式转换为API所需的prompt和system格式""" - # system_message = "" - # user_messages = [] - - # for message in messages: - # if isinstance(message, SystemMessage): - # system_message = message.content or "" - # elif isinstance(message, HumanMessage): - # user_messages.append(message.content) - # elif isinstance(message, SystemMessage): - # # 如果需要支持多轮对话,可以在这里处理 - # pass - - # # 合并用户消息 - # prompt = "\n".join(user_messages) if user_messages else "" - - # return prompt, str(system_message) - + def bind_tools(self, tools: Sequence[BaseTool], **kwargs: Any) -> Runnable: """将工具绑定到模型,并将其转换为火山引擎API所需的格式。""" @@ -192,8 +173,16 @@ class HuoshanChatModel(BaseChatModel): api_messages = self._convert_messages_to_prompt(messages) tools = kwargs.get("tools", []) - - print(f" 提交给豆包的 messages数组长度: \n {len(messages)} \n tools: {tools}") + # print(f" 提交给豆包的 messages数组长度: \n {len(messages)} \n tools: {tools}") + print(f" 提交给豆包的消息 =======================>>>>>>>>>>>> \n") + print(f"\nmessages: \n") + for message in messages: + print(f" {message.type}: \n ") + print(f" {message.content} \n ") + print(f"\ntools: \n") + for tool in tools: + print(f" \n {tool} \n ") + print(f" 提交给豆包的消息 =======================>>>>>>>>>>>> end \n") response_data = self._api.get_chat_response(messages=api_messages, tools=tools)