diff --git a/agent/scheduler.py b/agent/scheduler.py index 7c095ee..53aaadc 100644 --- a/agent/scheduler.py +++ b/agent/scheduler.py @@ -5,6 +5,7 @@ from langgraph.graph import StateGraph from langgraph.prebuilt import create_react_agent from langgraph.graph.state import CompiledStateGraph +from langgraph.prebuilt.chat_agent_executor import AgentState from utils.logger import get_logger logger = get_logger(__name__) @@ -13,132 +14,59 @@ logger = get_logger(__name__) DefaultSchedulerList = [] # 默认代理提示词 -DefaultAgentPrompt = f""" - # 角色 (Persona) - 你不是一个普通的编剧,你是一位在短剧市场身经百战、爆款频出的**“顶级短剧改编专家”与“爆款操盘手”**。 - 你的核心人设与专长: - 极致爽点制造机: 你对观众的“爽点”G点有着鬣狗般的嗅觉。你的天职就是找到、放大、并以最密集的节奏呈现“打脸”、“逆袭”、“揭秘”、“宠溺”等情节。 - 人物标签化大师: 你深知在短剧中,模糊等于无效。你擅长将人物的核心欲望和性格特点极致化、标签化,让观众在3秒内记住主角,5秒内恨上反派。 - 情绪过山车设计师: 你的剧本就像过山车。开篇即俯冲,5秒一反转,10秒一高潮,结尾必留下一个让人抓心挠肝的钩子。你为观众提供的是极致的情绪体验。 - 网络梗语言学家: 你的台词充满了网感和“梗”,既能推动剧情,又能引发观众的共鸣和吐槽欲。对话追求高信息密度,不说一句废话。 - 你的沟通风格:自信、犀利、直击要害,同时又能清晰地解释你每一个改编决策背后的商业逻辑和观众心理。 +DefaultAgentPrompt = """ + # 角色定位 + 您是爆款短剧改编专家,具备极致爽点设计、人物标签化、情绪节奏控制和网感台词能力。沟通风格自信犀利,能清晰解释商业逻辑。 - # 任务总体步骤描述 - 1. 查找并确认原始剧本已就绪 - 2. 分析原始剧本得出`诊断与资产评估`,需要用户确认可以继续下一步,否则协助用户完成修改 - 3. 根据`诊断与资产评估`确定`改编思路`,需要用户确认可以继续下一步,否则协助用户完成修改 - 4. 根据`改编思路`生成`剧本圣经`,需要用户确认可以继续下一步,否则协助用户完成修改 - 5. 根据`改编思路`和`剧本圣经`持续剧集创作,单次执行3-5集的创建,直至完成全部剧集。 - 6. 注意步骤具有上下级关系,且不能跳过。但是后续步骤可返回触发前面的任务:如生成单集到第3集后,用户提出要修改某个角色,此时应当返回第4步,并协助用户进行修改与确认;完成修改后重新执行第5步,即从第一集开始重新创作一遍; + # 核心工作流 + 1. 等待原始剧本(wait_for_input) + 2. 剧本分析(script_analysis) + 3. 改编思路制定(strategic_planning) + 4. 剧本圣经构建(build_bible) + 5. 剧集循环创作(episode_create_loop) + 6. 完成(finish) + + 步骤不可跳过但可回退修改。除finish和wait_for_input外,各阶段均由对应智能体处理。episode_create_loop阶段需通过QueryEpisodeCount工具判断进度,并指定单次创作集数(3-5集)。 + + # 智能体职责 + - scheduler(您自身):调度决策、用户沟通、状态管理 + - script_analysis:生成诊断与资产评估报告 + - strategic_planning:生成改编思路 + - build_bible:生成剧本圣经(含核心大纲、人物小传、事件时间线、人物表) + - episode_create:单集内容创作 + + # 工具使用原则 + 仅在必要时调用工具,避免重复。关键工具包括:QueryOriginalScript、QueryDiagnosisAndAssessment、QueryAdaptationIdeas、QueryScriptBible、QueryEpisodeCount等。 + - QueryOriginalScript:原始剧本是否存在 + - QueryDiagnosisAndAssessment:诊断与资产评估报告是否存在 + - QueryAdaptationIdeas:改编思路是否存在 + - QueryScriptBible:剧本圣经是否存在 + - QueryEpisodeCount:剧集总数与生成完成集数获取 + 如果工具读取到存在为true则不需要再调用该工具; + 如果工具读取到存在为false则需要需要分析当前任务的阶段来决定调用哪个工具,每次分析只能调用一个工具; + + # 任务列表管理 + - 任务列表为空时自动根据工作流步骤生成新列表 + - 每项任务包含:agent、step、status、reason、retry_count、pause、episode_create_num + - 执行逻辑:优先处理第一个未完成任务;状态为completed时推进;failed时根据reason决定重试(≤3次)或通知用户;waiting时暂停等待用户输入 + - 所有任务完成后,用户输入仍可触发新任务列表 + + # 输入数据解析 + 每次调用附带: + - workflow_info:布尔状态组(原始剧本、诊断报告等) + - workflow_params:session_id、task_index、from_type(user或agent) + - task_list:当前任务列表数组 + + 根据from_type决策:user直接解析用户意图;agent基于返回结果更新任务状态 - 步骤中对应的阶段如下: - wait_for_input: 等待剧本阶段,查询到`原始剧本`存在并分析到用户确认后进入下一阶段 - script_analysis: 原始剧本分析阶段,查询到`诊断与资产评估`存在并分析到用户确认后进入下一阶段 - strategic_planning: 确立改编目标阶段,查询到`改编思路`存在并分析到用户确认后进入下一阶段 - build_bible: 剧本圣经构建阶段,查询到`剧本圣经`存在并分析到用户确认后进入下一阶段 - episode_create_loop: 剧集创作阶段,查询`剧集创作情况`并分析到已完成所有剧集的创作后进入下一阶段 - finish: 所有剧集创作已完成,用户确认后结束任务,用户需要修改则回退到适合的步骤进行修改并重新执行后续阶段 - - ***除了finish和wait_for_input之外的阶段都需要交给对应的智能体去处理*** - ***episode_create_loop阶段是一个循环阶段,每次循环需要你通过工具方法`剧集创作情况`来判断是否所有剧集都已创作完成,以及需要创作智能体单次创作的集数(通常是3-5集), 该集数为`指定创作集数`,需要添加到返回参数中*** - - # 智能体职责介绍 - ***调度智能体*** 名称:`scheduler` 描述:你自己,需要用户确认反馈时返回自身,并把状态设置成waiting; - ***原始剧本分析 智能体*** 名称:`script_analysis` 描述:构建`诊断与资产评估`;内容包括:故事内核诊断、可继承的宝贵资产(高光情节、神来之笔对白、独特人设闪光点)、以及核心问题与初步改编建议。用户需要对`诊断与资产评估`进行修改都直接交给该智能体; - ***确立改编目标 智能体*** 名称:`strategic_planning` 描述:构建`改编思路`;此文件将作为所有后续改编的最高指导原则。用户需要对`改编思路`进行修改都直接交给该智能体; - ***剧本圣经构建 智能体*** 名称:`build_bible` 描述:构建`剧本圣经`,剧本圣经具体包括了这几个部分:核心大纲, 核心人物小传, 重大事件时间线, 总人物表; 用户需要对`剧本圣经`的每一个部分进行修改都直接交给该智能体; - ***剧集创作 智能体*** 名称:`episode_create` 描述:构建剧集的具体创作;注意该智能体仅负责剧集的创作;对于某一集的具体修改直接交给该智能体; - - ***注意:智能体调用后最终会返回再次请求到你,你需要根据智能体的处理结果来决定下一步*** - ***注意:`智能体调用` 不是工具方法的使用,而是在返回数据中把agent属性指定为要调用的智能体名称*** - - # 工具使用 - 上述智能体职责中提及的输出内容,都有对应的工具可供你调用进行查看;他们的查询工具名称分别对应如下: - 原始剧本是否存在: `QueryOriginalScript` - 诊断与资产评估是否存在: `QueryDiagnosisAndAssessment` - 改编思路是否存在: `QueryAdaptationIdeas` - 剧本圣经是否存在: `QueryScriptBible` - 核心大纲是否存在: `QueryCoreOutline` - 核心人物小传是否存在: `QueryCharacterProfile` - 重大事件时间线是否存在: `QueryCoreEventTimeline` - 总人物表是否存在: `QueryCharacterList` - 剧集创作情况: `QueryEpisodeCount` - - ***注意:工具使用是需要你调用工具方法的;大多数情况下同一个方法只需要调用一次*** - - ***每次用户的输入都会携带最新的`任务列表`和`工作流参数`,注意查看并分析是否应该回复用户等待或提醒用户确认继续下一步*** - # 工作流参数包含字段如下: - "session_id": 会话ID 可用于工具方法的调用 - "task_index": 当前执行中的任务索引 - "from_type": 本次请求来着哪里 - user: 用户 - agent: 智能体返回 - - # 任务列表是一个数组,每项的数据结构如下: - "agent": 执行这个任务的智能体名称 - 字符串内容 可为空 为空时表示当前任务不需要调用智能体 - - step: 阶段名称 - wait_for_input: 等待用户提供原始剧本 - script_analysis: 原始剧本分析 - strategic_planning: 确立改编目标 - build_bible: 剧本圣经构建 - episode_create_loop: 剧集创作 - finish: 所有剧集创作完成 - - "pause": 是否需要暂停 当需要和用户沟通时设置为true,任务会中断等待用户回复 - - status: 当前阶段的状态 - waiting: 等待用户反馈、确认 - running: 进行中 - failed: 失败 - completed: 完成 - - "reason": 失败原因,仅在`status`为`failed`时返回 - 字符串内容 - - "retry_count": 失败重试次数 - 整数内容 - "episode_create_num": 指定创作集数 仅在episode_create_loop阶段会返回,内容是数组,数组中每一项是指定创作的剧集编号(从1开始); - 整数数组 - - # 你的职责 - 分析用户所有的输入信息,以及信息中的`任务列表`、`工作流参数`; - 首先读取`工作流参数`中的数据,判断from_type的值是user还是agent; - 如果是user,说明用户是在和你沟通,你需要根据用户的输入来判断下一步; - 如果是agent,说明是智能体执行完任务后返回的结果,你需要根据智能体的返回结果来判断下一步;根据task_index取出`任务列表`中的当前任务,判断他的状态来决定是否继续列表中的下一个任务; - 如果当前任务的状态是completed,说明当前任务完成,需要继续列表中的下一个任务; - 如果当前任务的状态是failed,说明当前任务失败,需要根据失败原因来判断是否需要重试;如果需要重试,需要增加重试次数,并且需要继续列表中的下一个任务; - 如果当前任务的状态是waiting,说明当前任务等待用户反馈,需要等待用户反馈后继续执行;此时任务中的pause属性需要修改为true; - - `任务列表`的生成规则: - 1 `任务列表`为空时,你需要根据上文中`任务总体步骤描述`生成一个新的任务列表 - 2 执行`任务列表`中的第一个未完成的任务 - - 当你读取到一个空的任务列表 或者 任务列表中的所有任务都完成或失败时,你需要分析出一个新的任务列表; - 新的任务列表至少包含一个任务; - 任务列表中的每个任务代表的是一个后续智能体要执行的任务,其中scheduler代表你自己,大多数情况下这代表了用户回复了你的提问; - - 以下是任务列表的几种情况的示例: - 1 任务列表为空时,你需要生成一个新的任务列表;此时你的分析结果是需要用户 - 1 `wait_for_input` 向用户问好,并介绍你作为“爆款短剧操盘手”的身份和专业工作流程,礼貌地请用户提供需要改编的原始剧本。如果用户没有提供原始剧本,你将持续友好地提醒,此时状态始终为waiting,直到获取原始剧本为止。从用户提交的中可以获取到session_id的时候,需要调用 `QueryOriginalScript` 工具来查询原始剧本是否存在。 - 2 `script_analysis` 读取到原始剧本并从输入中分析出可以继续后进入,调用`原始剧本分析 智能体`继续后续工作;running时,礼貌回复用户并提醒用户任务真正进行中;completed代表任务完成,此时可等待用户反馈;直到跟用户确认可以进行下一步后再继续后续任务; - 3 `strategic_planning` 根据`诊断与资产评估`的结果,调用`确立改编目标 智能体`,并返回结果。 - 4 `build_bible` 根据`改编思路`的结果,调用`剧本圣经构建 智能体`,并返回结果。 - 5 `episode_create_loop` 根据`剧本圣经`的结果,调用`剧集创作 智能体` - 5 `finish` 所有剧集完成后设置为该状态,但是不要返回node==end_node,因为用户还可以继续输入来进一步修改产出内容; - - ***当任意一个智能体返回失败时,你需要分析reason字段中的内容,来决定是否进行重试,如果需要重试则给retry_count加1,并交给失败的那个智能体重试一次;如果retry_count超过了3次,或者失败原因不适合重试则反馈给用户说任务失败了,请稍后再试*** - + # 输出规范 请严格按照下列JSON结构返回数据,不要有其他任何多余的信息和描述: {{ - "message":'',//回复给用户的内容,注意,仅在你与用户沟通时才返回,其他情况下不返回。比如用户的需求是要交给其他智能体处理时,这个属性应该为空 - "task_list": [] //最新的任务列表 - "task_index": 0 //执行中的任务的索引 + "message": "回复用户的内容(仅需用户沟通时填充)", + "task_list": [更新后的任务列表数组], + "task_index": 当前执行任务索引, }} - """ - def create_agent_prompt(prompt, SchedulerList): """创建代理提示词的辅助函数""" if not SchedulerList or len(SchedulerList) == 0: return prompt @@ -149,13 +77,56 @@ def create_agent_prompt(prompt, SchedulerList): {node_list} \n """ +class SchedulerAgentState(AgentState): + """调度智能体中的上下文对象""" + is_original_script: bool # 是否已提交原始剧本 + is_script_analysis: bool # 是否已生成 诊断与资产评估报告 + is_strategic_planning: bool # 是否已生成 改编思路 + is_build_bible: bool # 是否已生成 剧本圣经 + is_episode_create_loop: bool # 是否已生成 剧集生成循环 + is_all_episode_created: bool # 是否已生成 全部剧集 + +def pre_scheduler_hook(state:SchedulerAgentState): + """模型调用前的钩子函数""" + logger.info(f"!!!!!!!!!调度节点输入!!!!!!!!!!!") + session_id = state.get("session_id", "") + workflow_info = state.get("workflow_info", {}) + task_list = state.get("task_list", []) + task_index = int(state.get("task_index", 0)) + from_type = state.get("from_type", "") + messages = state.get("messages", []) + logger.info(f"调度节点输入: {state}") + logger.info(f"!!!!!!!!!调度节点输入!!!!!!!!!!!") + # 清除历史状态消息 + # messages = clear_messages(messages) + # # 添加参数进提示词 + # messages.append(HumanMessage(content=f""" + # ---任务状态消息(开始)--- + # # 工作流信息: + # {workflow_info} + # # 工作流参数: + # {{ + # 'session_id':'{session_id}', + # 'task_index':'{task_index}', + # 'from_type':'{from_type}', + # }} + # # 任务列表: + # {task_list} + # ---任务状态消息(结束)--- + # """)) + # return state class SchedulerAgent(CompiledStateGraph): """智能调度智能体类 该类负责接收用户的提示词,并调用其他智能体来处理工作。 """ - def __new__(cls, llm=None, tools=[], SchedulerList=None): + def __new__(cls, + llm=None, + tools=[], + SchedulerList=None, + post_model_hook=None, + ): """创建并返回create_react_agent创建的对象""" # 处理默认参数 if llm is None: @@ -169,5 +140,9 @@ class SchedulerAgent(CompiledStateGraph): return create_react_agent( model=llm, tools=tools, - prompt=create_agent_prompt(prompt=DefaultAgentPrompt, SchedulerList=SchedulerList), - ) \ No newline at end of file + prompt=DefaultAgentPrompt, + # pre_model_hook=pre_scheduler_hook, + # post_model_hook=post_model_hook, + # state_schema=SchedulerAgentState, + # prompt=create_agent_prompt(prompt=DefaultAgentPrompt, SchedulerList=SchedulerList), + ) diff --git a/graph/test_agent_graph_1.py b/graph/test_agent_graph_1.py index a5d1d4d..0091036 100644 --- a/graph/test_agent_graph_1.py +++ b/graph/test_agent_graph_1.py @@ -7,6 +7,7 @@ from re import T from typing import TypedDict, Annotated, Dict, Any, List, TypedDict, Optional from langgraph.graph.state import RunnableConfig +from langgraph.prebuilt.chat_agent_executor import AgentState from agent.scheduler import SchedulerAgent from agent.build_bible import BuildBibleAgent from agent.episode_create import EpisodeCreateAgent @@ -24,7 +25,8 @@ import config from tools.database.mongo import client # type: ignore from langgraph.checkpoint.mongodb import MongoDBSaver # 工具方法 -from tools.agent.queryDB import QueryOriginalScript +from tools.agent.queryDB import QueryOriginalScript,QueryDiagnosisAndAssessment,QueryAdaptationIdeas,QueryScriptBible,QueryEpisodeCount +from tools.agent.updateDB import UpdateAdaptationIdeasTool,UpdateScriptBibleTool,UpdateDiagnosisAndAssessmentTool,UpdateOneEpisodeTool logger = get_logger(__name__) @@ -55,6 +57,7 @@ def clear_messages(messages): def messages_handler(old_messages: list[AnyMessage], new_messages: list[AnyMessage]): """消息合并方法""" clear_messages(old_messages) + new_messages = [message for message in new_messages if message.type != 'ai' or message.content] return old_messages + new_messages def replace_value(old_val, new_val): @@ -141,89 +144,82 @@ class ScriptwriterGraph: def node_router(self, state: ScriptwriterState) -> str: """节点路由函数""" + # return 'end_node' # print(f'node_router state {state}') task_list = state.get("task_list", []) task_index = state.get("task_index", 0) now_task = task_list[task_index] + print(f'node_router now_task {now_task}') if not now_task or now_task.get('pause'): next_node = 'pause_node' # 设置为暂停节点 else: next_node = AgentNodeMap.get(now_task.get('agent'), 'pause_node') - # print(f'node_router next_node {next_node}') + print(f'node_router next_node {next_node}') return next_node + + + def post_scheduler_hook(self, state: ScriptwriterState)-> ScriptwriterState: + """模型调用后的钩子函数""" + ai_message_str = state.get("messages", [])[-1].content + logger.info(f"!!!!!!!!!调度节点输出!!!!!!!!!!!") + logger.info(f"调度节点结果: {ai_message_str}") + # logger.info(f"调度节点结果 end") + # ai_message = json.loads(ai_message_str) + return state def _build_graph(self) -> None: """构建工作流图""" try: # 创建智能体 print("创建智能体") + from tools.llm.deepseek_langchain import DeepseekChatModel + llm = DeepseekChatModel(api_key='sk-571923fdcb0e493d8def7e2d78c02cb8') # 调度智能体 self.schedulerAgent = SchedulerAgent( + # llm=llm, tools=[ QueryOriginalScript, + QueryDiagnosisAndAssessment, + QueryAdaptationIdeas, + QueryScriptBible, + QueryEpisodeCount, ], - SchedulerList=[ - { - "name": "scheduler_node", - "desc": "调度智能体节点", - }, - { - "name": "script_analysis_node", - "desc": "原始剧本分析节点", - }, - { - "name": "strategic_planning_node", - "desc": "确立改编目标节点", - }, - { - "name": "build_bible_node", - "desc": "剧本圣经构建节点", - }, - { - "name": "episode_create_node", - "desc": "单集创作节点", - }, - { - "name": "end_node", - "desc": "结束节点,任务失败终止时使用,结束后整个工作流将停止" - } - ] ) self.scriptAnalysisAgent = ScriptAnalysisAgent( tools=[], - SchedulerList=[ - { - "name": "scheduler_node", - "desc": "调度智能体节点", - } - ] + # SchedulerList=[ + # { + # "name": "scheduler_node", + # "desc": "调度智能体节点", + # } + # ] ) self.strategicPlanningAgent = StrategicPlanningAgent( tools=[], - SchedulerList=[ - { - "name": "scheduler_node", - "desc": "调度智能体节点", - } - ] + # SchedulerList=[ + # { + # "name": "scheduler_node", + # "desc": "调度智能体节点", + # } + # ] ) self.buildBibleAgent = BuildBibleAgent( tools=[], - SchedulerList=[ - { - "name": "scheduler_node", - "desc": "调度智能体节点", - } - ] + # SchedulerList=[ + # { + # "name": "scheduler_node", + # "desc": "调度智能体节点", + # } + # ] ) self.episodeCreateAgent = EpisodeCreateAgent( tools=[], - SchedulerList=[ - { - "name": "scheduler_node", - "desc": "调度智能体节点", - } - ] + # SchedulerList=[ + # { + # "name": "scheduler_node", + # "desc": "调度智能体节点", + # } + # ] ) # 创建状态图 @@ -272,6 +268,8 @@ class ScriptwriterGraph: logger.info("工作流图构建完成") except Exception as e: + import traceback + traceback.print_exc() logger.error(f"构建工作流图失败: {e}") raise @@ -289,10 +287,10 @@ class ScriptwriterGraph: # 清除历史状态消息 messages = clear_messages(messages) # 添加参数进提示词 + # 工作流信息: + # {workflow_info} messages.append(HumanMessage(content=f""" ---任务状态消息(开始)--- - # 工作流信息: - {workflow_info} # 工作流参数: {{ 'session_id':'{session_id}', @@ -315,8 +313,13 @@ class ScriptwriterGraph: ai_message_count += 1 logger.info(f"调度节点 {session_id} \n 输入消息条数: {len(messages)} \n from_type:{from_type} \n system_message_count:{system_message_count} \n human_message_count:{human_message_count} \n ai_message_count:{ai_message_count}") # 调用智能体 - reslut = await self.schedulerAgent.ainvoke({"messages":messages}) + reslut = await self.schedulerAgent.ainvoke({ + "messages": messages, + "session_id": session_id, + }) ai_message_str = reslut['messages'][-1].content + logger.info(f"调度节点结果: {ai_message_str}") + logger.info(f"调度节点结果 end") ai_message = json.loads(ai_message_str) # logger.info(f"调度节点结果: {ai_message}") return_message:str = ai_message.get('message', '') @@ -324,10 +327,9 @@ class ScriptwriterGraph: task_index:int = int(ai_message.get('task_index', '0')) return { - "messages": messages, + "agent_message": return_message, "task_list": task_list, "task_index": task_index, - "agent_message": return_message, } except Exception as e: import traceback @@ -620,8 +622,13 @@ class ScriptwriterGraph: 工作流执行结果 """ try: - logger.info("开始运行智能编剧工作流") - + logger.info("开始运行智能编剧工作流") + output_result: OutputState = { + 'session_id': session_id, + 'status': '', + 'error': '', + 'agent_message': '', + } # 配置包含线程 ID config:RunnableConfig = {"configurable": {"thread_id": thread_id}} # 初始化状态 @@ -630,11 +637,10 @@ class ScriptwriterGraph: 'session_id': session_id, 'from_type': 'user', } - # 运行工作流 if self.graph is None: raise RuntimeError("工作流图未正确初始化") - + # 运行工作流 直接返回最终结果 result = await self.graph.ainvoke( initial_state, config, @@ -642,15 +648,24 @@ class ScriptwriterGraph: ) # logger.info(f"工作流执行结果: {result}") if not result: - raise ValueError("工作流执行结果为空") - + raise ValueError("工作流执行结果为空") # 构造输出状态 - output_result: OutputState = { + output_result = { 'session_id': result.get('session_id', ''), 'status': result.get('status', ''), 'error': result.get('error', ''), 'agent_message': result.get('agent_message', ''), } + + # 流式处理 + # st = self.graph.stream( + # initial_state, + # config, + # stream_mode='values' + # ) + # from utils.stream import debug_print_stream + # debug_print_stream(st) + logger.info("智能编剧工作流运行完成") return output_result diff --git a/tools/agent/queryDB.py b/tools/agent/queryDB.py index 939231e..71c257b 100644 --- a/tools/agent/queryDB.py +++ b/tools/agent/queryDB.py @@ -1,6 +1,11 @@ +from typing import Annotated from bson import ObjectId +from langchain_core.messages import ToolMessage +from langchain_core.tools import InjectedToolCallId +from langgraph.types import Command from tools.database.mongo import mainDB from langchain.tools import tool +import json @tool def QueryOriginalScript(session_id: str): @@ -13,9 +18,24 @@ def QueryOriginalScript(session_id: str): exist (bool): 原始剧本内容是否存在。 """ script = mainDB.agent_writer_session.find_one({"_id": ObjectId(session_id), "original_script": {"$exists": True, "$ne": ""}},{"_id":1}) + is_original_script = script is not None + # print(f"tool_call_id {tool_call_id}") + Command(update={ + "is_original_script": is_original_script + }) return { - "exist": script is not None, + "exist": is_original_script, } + tool_message_content = json.dumps({"exist": is_original_script}) + # return Command(update={ + # "is_original_script": is_original_script, + # "messages": [ + # ToolMessage( + # tool_call_id=tool_call_id, + # content=tool_message_content + # ) + # ] + # }) def QueryOriginalScriptContent(session_id: str): """ diff --git a/tools/llm/deepseek_langchain.py b/tools/llm/deepseek_langchain.py new file mode 100644 index 0000000..d000a79 --- /dev/null +++ b/tools/llm/deepseek_langchain.py @@ -0,0 +1,85 @@ +from typing import Any, List, Optional +from langchain_core.callbacks.manager import CallbackManagerForLLMRun +from langchain_core.messages import BaseMessage +from langchain_core.outputs import ChatResult +from langchain_community.chat_models import ChatOpenAI +import os + +# 继承自 ChatOpenAI 而不是 BaseChatModel +class DeepseekChatModel(ChatOpenAI): + """ + 对 DeepSeek 聊天模型的 LangChain 封装。 + 这个版本通过继承 ChatOpenAI 来实现,并重写 _generate 方法。 + """ + + def __init__(self, model_name: str = "deepseek-chat", **kwargs: Any): + """ + 初始化 DeepseekChatModel + """ + # 从环境变量或参数中获取 api_key + api_key = kwargs.pop("api_key", os.getenv("DEEPSEEK_API_KEY")) + if not api_key: + raise ValueError( + "DeepSeek API key must be provided either as an argument or set as the DEEPSEEK_API_KEY environment variable." + ) + + # 调用父类(ChatOpenAI)的构造函数,并传入 DeepSeek 的特定配置 + super().__init__( + model=model_name, + api_key=api_key, + base_url="https://api.deepseek.com/v1", # 这是关键 + **kwargs, + ) + + def _generate( + self, + messages: List[BaseMessage], + stop: Optional[List[str]] = None, + run_manager: Optional[CallbackManagerForLLMRun] = None, + **kwargs: Any, + ) -> ChatResult: + """ + 重写 _generate 方法,这是 LangChain 的核心调用点。 + """ + # 1. 在调用父类方法前,执行你的自定义逻辑(例如,打印日志) + print("\n--- [DeepseekChatModel] 调用 _generate ---") + print(f"输入消息数量: {len(messages)}") + print(f"第一条消息内容: {messages[0].content}") + print("-------------------------------------------\n") + + # 2. 使用 super() 调用父类(ChatOpenAI)的原始 _generate 方法 + # 这样可以复用 ChatOpenAI 中所有复杂的、经过测试的逻辑 + chat_result = super()._generate( + messages=messages, stop=stop, run_manager=run_manager, **kwargs + ) + + # 3. 在获得结果后,你还可以执行其他自定义逻辑 + print("\n--- [DeepseekChatModel] _generate 调用完成 ---") + print(f"输出消息内容: {chat_result.generations[0].message.content[:80]}...") # 打印部分输出 + print("--------------------------------------------\n") + + return chat_result + + @property + def _llm_type(self) -> str: + """返回 language model 的类型。""" + return "deepseek-chat-model-v2" # 改个名以示区别 + +# --- 使用示例 --- +if __name__ == "__main__": + # 确保设置了 API 密钥 + if not os.getenv("DEEPSEEK_API_KEY"): + print("请设置 DEEPSEEK_API_KEY 环境变量。") + else: + # 初始化模型 + chat_model = DeepseekChatModel(temperature=0.7) + + # 构建输入消息 + from langchain_core.messages import HumanMessage + messages = [HumanMessage(content="你好,请介绍一下新加坡。")] + + # 调用模型 + # 当你调用 .invoke() 或 .stream() 时,LangChain 内部会调用我们重写的 _generate 方法 + response = chat_model.invoke(messages) + + print(f"\n最终得到的回复:\n{response.content}") \ No newline at end of file diff --git a/tools/llm/huoshan_langchain.py b/tools/llm/huoshan_langchain.py index 8b5f480..ade11e4 100644 --- a/tools/llm/huoshan_langchain.py +++ b/tools/llm/huoshan_langchain.py @@ -8,6 +8,7 @@ from langchain_core.runnables import Runnable from pydantic import Field import json import copy +from langchain_core.messages import AnyMessage,HumanMessage from volcenginesdkarkruntime.types.chat import ChatCompletionToolParam from volcenginesdkarkruntime.types.shared_params.function_definition import FunctionDefinition from volcenginesdkarkruntime.types.shared_params.function_parameters import FunctionParameters @@ -174,15 +175,16 @@ class HuoshanChatModel(BaseChatModel): api_messages = self._convert_messages_to_prompt(messages) tools = kwargs.get("tools", []) # print(f" 提交给豆包的 messages数组长度: \n {len(messages)} \n tools: {tools}") - print(f" 提交给豆包的消息 =======================>>>>>>>>>>>> \n") - print(f"\nmessages: \n") - for message in messages: - print(f" {message.type}: \n ") - print(f" {message.content} \n ") - print(f"\ntools: \n") - for tool in tools: - print(f" \n {tool} \n ") - print(f" 提交给豆包的消息 =======================>>>>>>>>>>>> end \n") + # print(f" 提交给豆包的消息 =======================>>>>>>>>>>>> \n") + # print(f"\nmessages: \n") + # for message in messages: + # print("--- Message Attributes ---\n") + # for key, value in message.__dict__.items(): + # print(f" {key}: {value} \n") + # print(f"\ntools: \n") + # for tool in tools: + # print(f" {tool} \n ") + # print(f" 提交给豆包的消息 =======================>>>>>>>>>>>> end \n") response_data = self._api.get_chat_response(messages=api_messages, tools=tools) @@ -192,7 +194,7 @@ class HuoshanChatModel(BaseChatModel): message_from_api = res_choices.get("message", {}) tool_calls = message_from_api.get("tool_calls", []) print(f" 豆包返回的 finish_reason: {finish_reason} \n tool_calls: {tool_calls} \n") - print(f" 豆包返回的 message: {message_from_api.get('content', '')}") + # print(f" 豆包返回的 message: {message_from_api.get('content', '')}") if finish_reason == "tool_calls" and tool_calls: lc_tool_calls = [] for tc in tool_calls: diff --git a/tools/llm/openai_langchain.py b/tools/llm/openai_langchain.py new file mode 100644 index 0000000..9c20516 --- /dev/null +++ b/tools/llm/openai_langchain.py @@ -0,0 +1,193 @@ +import os +from typing import Any, List, Optional, Type + +from langchain_community.chat_models import ChatOpenAI +from langchain_core.messages import HumanMessage +from langchain_core.prompts import ChatPromptTemplate +from langchain_core.pydantic_v1 import BaseModel, Field +from langchain_core.tools import BaseTool, tool + +# --- 1. 中央模型配置 --- +# 在这里添加或修改你想要支持的模型提供商 +# 'api_key_env': 存放 API Key 的环境变量名称 +# 'base_url': 模型的 API 端点 +# 'supports_tools': 该模型是否支持 LangChain 的 .bind_tools() 功能 +# 'default_model': 如果不指定模型,默认使用的模型名称 +MODEL_CONFIG = { + "openai": { + "api_key_env": "OPENAI_API_KEY", + "base_url": "https://api.openai.com/v1", + "supports_tools": True, + "default_model": "gpt-4o", + }, + "deepseek": { + "api_key_env": "DEEPSEEK_API_KEY", + "base_url": "https://api.deepseek.com/v1", + "supports_tools": True, + "default_model": "deepseek-chat", + }, + "groq": { + "api_key_env": "GROQ_API_KEY", + "base_url": "https://api.groq.com/openai/v1", + "supports_tools": True, + "default_model": "llama3-70b-8192", + }, + "moonshot": { + "api_key_env": "MOONSHOT_API_KEY", + "base_url": "https://api.moonshot.cn/v1", + "supports_tools": True, + "default_model": "moonshot-v1-8k", + }, + # 示例:一个不支持工具的老模型 + "legacy_provider": { + "api_key_env": "LEGACY_API_KEY", + "base_url": "http://localhost:8080/v1", # 假设是本地服务 + "supports_tools": False, + "default_model": "legacy-model-v1", + }, + + # 新增火山引擎的配置 + "huoshan": { + "api_key_env": "VOLC_API_KEY", # 请确认您存放密钥的环境变量名 + "base_url": "https://ark.cn-beijing.volces.com/api/v3", # 这是火山引擎方舟的 OpenAI 兼容端点 + "supports_tools": True, + "default_model": "ep-20240615082149-j225c", # 使用您需要的模型 Endpoint ID + }, +} + + +# --- 2. 模型创建工厂函数 --- +def create_llm_client( + provider: str, + model_name: Optional[str] = None, + tools: Optional[List[Type[BaseModel]]] = None, + **kwargs: Any, +) -> ChatOpenAI: + """ + 根据提供的 provider 创建并配置一个 LangChain 聊天模型客户端。 + + Args: + provider: 模型提供商的名称 (必须是 MODEL_CONFIG 中的一个 key)。 + model_name: 要使用的具体模型名称。如果为 None, 则使用配置中的默认模型。 + tools: 一个工具列表,用于绑定到模型上以实现 function calling。 + **kwargs: 其他要传递给 ChatOpenAI 的参数 (例如 temperature, max_tokens)。 + + Returns: + 一个配置好的 ChatOpenAI 实例,可能已经绑定了工具。 + """ + if provider not in MODEL_CONFIG: + raise ValueError(f"不支持的模型提供商: {provider}。可用选项: {list(MODEL_CONFIG.keys())}") + + config = MODEL_CONFIG[provider] + api_key = os.getenv(config["api_key_env"]) + + if not api_key: + raise ValueError(f"请设置环境变量 {config['api_key_env']} 以使用 {provider} 模型。") + + # 如果未指定 model_name,则使用配置中的默认值 + final_model_name = model_name or config["default_model"] + + # 创建基础的 LLM 客户端 + llm = ChatOpenAI( + model=final_model_name, + api_key=api_key, + base_url=config["base_url"], + **kwargs, + ) + + # 根据配置,有条件地绑定工具 + if tools: + if config["supports_tools"]: + print(f"[{provider}] 模型支持工具,正在绑定 {len(tools)} 个工具...") + return llm.bind_tools(tools) + else: + print(f"⚠️ 警告: 您为 [{provider}] 提供了工具,但该模型在配置中被标记为不支持工具。将返回未绑定工具的模型。") + return llm + + return llm + + +# --- 3. 定义你的工具 (Function Calling) --- +# 使用 Pydantic 模型定义工具的输入参数,确保类型安全和清晰的描述 +class GetWeatherInput(BaseModel): + city: str = Field(description="需要查询天气的城市名称, 例如: Singapore") + +# 使用 @tool 装饰器可以轻松地将任何函数转换为 LangChain 工具 +@tool(args_schema=GetWeatherInput) +def get_current_weather(city: str) -> str: + """ + 当需要查询指定城市的当前天气时,调用此工具。 + """ + # 这是一个模拟实现,实际应用中你会在这里调用真实的天气API + print(f"--- 正在调用工具: get_current_weather(city='{city}') ---") + if "singapore" in city.lower(): + return f"新加坡今天的天气是晴朗,温度为 31°C。" + elif "beijing" in city.lower(): + return f"北京今天的天气是多云,温度为 25°C。" + else: + return f"抱歉,我无法查询到 {city} 的天气信息。" + + +# --- 4. 主程序:演示如何使用 --- +if __name__ == "__main__": + # 将你的工具放入一个列表 + my_tools = [get_current_weather] + + # ---- 示例 1: 使用 DeepSeek 并调用工具 ---- + print("\n================ 示例 1: 使用 DeepSeek 调用工具 ================") + # 确保你已经设置了环境变量: export DEEPSEEK_API_KEY="sk-..." + try: + deepseek_llm_with_tools = create_llm_client( + provider="deepseek", + tools=my_tools, + temperature=0 + ) + prompt = "今天新加坡的天气怎么样?" + print(f"用户问题: {prompt}") + + # LangChain 会自动处理:LLM -> Tool Call -> Execute Tool -> LLM -> Final Answer + # 为了演示,我们只看第一步的输出 + ai_msg = deepseek_llm_with_tools.invoke(prompt) + print("\nLLM 返回的初步响应 (AIMessage):") + print(ai_msg) + + # 检查返回的是否是工具调用 + if ai_msg.tool_calls: + print(f"\n模型请求调用工具: {ai_msg.tool_calls[0]['name']}") + # 在实际应用中,你会在这里执行工具并把结果返回给模型 + else: + print("\n模型直接给出了回答:") + print(ai_msg.content) + + except ValueError as e: + print(e) + + + # ---- 示例 2: 使用 Groq (Llama3) 且不使用工具 ---- + print("\n================ 示例 2: 使用 Groq 进行常规聊天 ================") + # 确保你已经设置了环境变量: export GROQ_API_KEY="gsk_..." + try: + groq_llm = create_llm_client( + provider="groq", + temperature=0.7, + # model_name="llama3-8b-8192" # 你也可以覆盖默认模型 + ) + prompt = "请给我写一首关于新加坡的五言绝句。" + print(f"用户问题: {prompt}") + response = groq_llm.invoke(prompt) + print("\nGroq (Llama3) 的回答:") + print(response.content) + except ValueError as e: + print(e) + + # ---- 示例 3: 尝试给不支持工具的模型绑定工具 ---- + print("\n========== 示例 3: 尝试为不支持工具的模型绑定工具 ==========") + # 假设你设置了 export LEGACY_API_KEY="some_key" + try: + legacy_llm = create_llm_client( + provider="legacy_provider", + tools=my_tools + ) + # 注意,这里会打印出警告信息,并且 legacy_llm 不会绑定任何工具 + except ValueError as e: + print(e) \ No newline at end of file diff --git a/utils/stream.py b/utils/stream.py new file mode 100644 index 0000000..7b2796e --- /dev/null +++ b/utils/stream.py @@ -0,0 +1,8 @@ +def debug_print_stream(stream): + for chunk in stream: + message = chunk['messages'][-1] + if isinstance(message, tuple): + print(chunk) + else: + message.pretty_print() + print("\n---")