From 3ca857a589722169f5c40bb2a22ed88a5facacfe Mon Sep 17 00:00:00 2001 From: jonathang4 Date: Fri, 12 Sep 2025 00:32:55 +0800 Subject: [PATCH] =?UTF-8?q?=E5=B7=A5=E5=85=B7=E6=96=B9=E6=B3=95=E8=B0=83?= =?UTF-8?q?=E7=94=A8=201?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- agent/scheduler.py | 12 +- api/huoshan.py | 99 ++++++-------- doc/test.txt | 173 +++++++++++++++++++++++ graph/test_agent_graph_1.py | 45 ++++-- tools/agent/__init__.py | 3 + tools/agent/queryDB.py | 38 ++++++ tools/llm/huoshan_langchain.py | 241 +++++++++++++++++++++++++-------- 7 files changed, 484 insertions(+), 127 deletions(-) create mode 100644 doc/test.txt create mode 100644 tools/agent/__init__.py create mode 100644 tools/agent/queryDB.py diff --git a/agent/scheduler.py b/agent/scheduler.py index f064777..14637bb 100644 --- a/agent/scheduler.py +++ b/agent/scheduler.py @@ -39,6 +39,7 @@ DefaultAgentPrompt = f""" ***单集创作 智能体*** 名称:`episode_create` 描述:用户确认`剧本圣经`后,交给该智能体来构建某一集的具体创作;注意该智能体仅负责单集的创作,因此该智能体的调度需要有你根据`剧本圣经`中的`核心大纲`来多次调用,逐步完成所有剧集的创作;对于某一集的具体修改直接交给该智能体; ***注意:智能体调用后最终会返回再次请求到你,你需要根据智能体的处理结果来决定下一步*** + ***注意:`智能体调用` 不是工具方法的使用,而是在返回数据中把agent属性指定为要调用的智能体名称*** # 工具使用 上述智能体职责中提及的输出内容,都有对应的工具可供你调用进行查看;他们的查询工具名称分别对应如下: @@ -54,8 +55,10 @@ DefaultAgentPrompt = f""" 未完成的集数: `QueryUnfinishedEpisodeCount` 已完成的集数: `QueryCompletedEpisodeCount` + ***注意:工具使用是需要你调用工具方法的;但是大多数情况下,你不需要查询文本的具体内容,只需要查询存在与否即可*** + ***每次用户的输入都会携带当前`总任务的进度与任务状态`,注意查看并分析是否应该回复用户等待或提醒用户确认继续下一步*** - # 总任务的进度与任务状态数据结构为 {{"step": "waiting_script", "status": "running", "from_type":"user", "reason": "waiting_script", "retry_count": 0}} + # 总任务的进度与任务状态数据结构为 {{"step": "waiting_script", "status": "running", "from_type":"user", "reason": "waiting_script", "retry_count": 0, "query_args":{{}}}} step: 阶段名称 wait_for_input: 等待用户提供原始剧本 @@ -80,9 +83,12 @@ DefaultAgentPrompt = f""" "retry_count": 重试次数 + "query_args": 用于调用工具方法的参数,可能包括的字段有: + "session_id": 会话ID,可用于查询`原始剧本` + # 职责 分析用户输入与`总任务的进度与任务状态`,以下是几种情况的示例: - 1 `wait_for_input` 向用户问好,并介绍你作为“爆款短剧操盘手”的身份和专业工作流程,礼貌地请用户提供需要改编的原始剧本。如果用户没有提供原始剧本,你将持续友好地提醒,此时状态始终为waiting,直到获取原始剧本为止。 + 1 `wait_for_input` 向用户问好,并介绍你作为“爆款短剧操盘手”的身份和专业工作流程,礼貌地请用户提供需要改编的原始剧本。如果用户没有提供原始剧本,你将持续友好地提醒,此时状态始终为waiting,直到获取原始剧本为止。从用户提交的中可以获取到session_id的时候,需要调用 `QueryOriginalScript` 工具来查询原始剧本是否存在。 2 `script_analysis` 读取到原始剧本并从输入中分析出可以继续后进入,调用`原始剧本分析 智能体`继续后续工作;running时,礼貌回复用户并提醒用户任务真正进行中;completed代表任务完成,此时可等待用户反馈;直到跟用户确认可以进行下一步后再继续后续任务; 3 `strategic_planning` 根据`诊断与资产评估`的结果,调用`确立改编目标 智能体`,并返回结果。 4 `build_bible` 根据`改编思路`的结果,调用`剧本圣经构建 智能体`,并返回结果。 @@ -98,7 +104,7 @@ DefaultAgentPrompt = f""" "agent":'',//分析后得出由哪个智能体继续任务,此处为智能体名称;如果需要继续与用户交互或仅需要回复用户则为空字符串 "message":'',//回复给用户的内容 "retry_count":0,//重试次数 - "node":'',//下一个节点名称 + "node":'',//下一个节点名称,根据指定的agent名称,从取值范围列表中选择一个节点名称返回 }} """ diff --git a/api/huoshan.py b/api/huoshan.py index 0741e8e..bb83a03 100644 --- a/api/huoshan.py +++ b/api/huoshan.py @@ -2,7 +2,7 @@ import os import json import random import time -from typing import Dict, List, Optional, Any, cast +from typing import Dict, Iterable, List, Optional, Any, cast from datetime import datetime from volcenginesdkarkruntime import Ark from volcengine.visual.VisualService import VisualService @@ -12,6 +12,8 @@ from volcengine.vod.models.request.request_vod_pb2 import VodUploadMediaRequest from volcengine.vod.models.request.request_vod_pb2 import VodUrlUploadRequest import base64 +from volcenginesdkarkruntime.types.chat import ChatCompletionToolParam + from config import API_CONFIG class HuoshanAPI: @@ -507,32 +509,23 @@ class HuoshanAPI: 'error': f'删除异常: {str(e)}' } - def get_chat_response(self, prompt: str, model: Optional[str] = None, system: Optional[str] = None, temperature: float = 0.6) -> str: - """ - 获取聊天机器人回复 - 使用doubao_seed_1.6模型 - :param prompt: 用户输入的文本 - :param model: 模型名称,默认使用doubao_seed_1.6 - :param system: 系统提示词 - :param temperature: 温度参数 - :return: 机器人回复的文本 - """ + def get_chat_response( + self, + messages: List[Dict], + model: Optional[str] = None, + tools: Optional[List[ChatCompletionToolParam]] = None, + temperature: float = 0.6 + ) -> Dict[str, Any]: + """同步获取聊天响应,支持工具调用。""" + try: if model is None: model = self.doubao_seed_1_6_model_id - messages:Any = [] - if system: - messages = [ - {"role": "system", "content": system}, - {"role": "user", "content": prompt} - ] - else: - messages = [ - {"role": "user", "content": prompt} - ] - - response = self.client.chat.completions.create( + if not messages or len(messages)==0: + raise ValueError(f"火山引擎API调用失败: 消息不能为空") + completion = self.client.chat.completions.create( model=model, - messages=messages, + messages=messages, # type: ignore max_tokens=16384, # 16K temperature=temperature, timeout=600, @@ -541,55 +534,39 @@ class HuoshanAPI: # "type": "enabled", # 使用深度思考能力 # "type": "auto", # 模型自行判断是否使用深度思考能力 }, + tools=tools # 传入 tools 参数 ) - - return response.choices[0].message.content # pyright: ignore + return completion.model_dump() # type: ignore 使用 model_dump() 转换为字典 except Exception as e: - raise Exception(f'Huoshan chat API调用异常: {str(e)}') - - def get_chat_response_stream(self, prompt: str, model: Optional[str] = None, system: Optional[str] = None, temperature: float = 0.6): - """ - 获取聊天机器人的流式回复 - 使用doubao_seed_1.6模型 - :param prompt: 用户输入的文本 - :param model: 模型名称,默认使用doubao_seed_1.6 - :param system: 系统提示词 - :param temperature: 温度参数 - :return: 生成器,逐步返回机器人回复的文本 - """ + raise ValueError(f"火山引擎API调用失败: {str(e)}") + + def get_chat_response_stream( + self, + messages: List[Dict], + model: Optional[str] = None, + tools: Optional[List[ChatCompletionToolParam]] = None, + temperature: float = 0.6 + ) -> Iterable[str]: + """流式获取聊天响应,支持工具调用。""" + try: if model is None: model = self.doubao_seed_1_6_model_id - - messages:Any = [] - if system: - messages = [ - {"role": "system", "content": system}, - {"role": "user", "content": prompt} - ] - else: - messages = [ - {"role": "user", "content": prompt} - ] - - response = self.client.chat.completions.create( + + completion = self.client.chat.completions.create( model=model, - messages=messages, + messages=messages, # type: ignore temperature=temperature, max_tokens=16384, # 16K timeout=600, - stream=True + stream=True, + tools=tools # 传入 tools 参数 ) - - - for chunk in response: - chunk_obj = cast(Any, chunk) - if hasattr(chunk_obj, 'choices') and chunk_obj.choices and len(chunk_obj.choices) > 0: - delta = chunk_obj.choices[0].delta - if hasattr(delta, 'content') and delta.content is not None: - yield delta.content - + for chunk in completion: + if chunk.choices and chunk.choices[0].delta.content is not None: # type: ignore + yield chunk.choices[0].delta.content # type: ignore except Exception as e: - raise Exception(f'Huoshan chat stream API调用异常: {str(e)}') + raise ValueError(f"火山引擎API流式调用失败: {str(e)}") def analyze_image(self, image_url: str, prompt: str = "请描述这张图片的内容", model: Optional[str] = None, detail: str = "high") -> Dict[str, Any]: """ diff --git a/doc/test.txt b/doc/test.txt new file mode 100644 index 0000000..942ffbc --- /dev/null +++ b/doc/test.txt @@ -0,0 +1,173 @@ +为了在你的代码中正确接入火山引擎的工具调用功能,你需要修改 `huoshan_langchain.py` 和 `huoshan.py` 这两个文件,以实现**从 LangChain 工具到火山引擎 API 工具定义的格式转换**,以及**解析和处理来自 API 的工具调用响应**。 + +下面是根据火山引擎官方文档和你的代码,我为你整理的完整修改方案。 + +----- + +### 第一步:修改 `huoshan_langchain.py` + +这个文件是 LangChain 的封装层,负责连接你的工作流和火山引擎的底层 API。你需要在这里实现 `bind_tools` 和 `_generate` 方法来处理工具调用。 + +1. **导入必要的类**: + 需要添加 `ToolMessage` 和 `ToolCall`,它们是 LangChain 用于表示工具调用结果和工具调用的核心类。 + + ```python + from typing import Any, Dict, Iterable, Iterator, List, Optional, Sequence + from langchain_core.callbacks.manager import CallbackManagerForLLMRun + from langchain_core.language_models.chat_models import BaseChatModel + from langchain_core.messages import BaseMessage, AIMessage, HumanMessage, SystemMessage, ToolMessage, ToolCall + from langchain_core.outputs import ChatGeneration, ChatResult + from langchain_core.tools import BaseTool + from langchain_core.runnables import Runnable + from langchain.pydantic_v1 import BaseModel + from pydantic import Field + from volcenginesdkarkruntime.types.chat import ChatCompletionToolParam + from api.huoshan import HuoshanAPI + import json + ``` + +2. **实现 `bind_tools` 方法**: + 这个方法是 LangChain 用于将工具定义传递给你的模型封装。在这里,你需要将 LangChain 的 `BaseTool` 对象转换为火山引擎 API 所需的 `ChatCompletionToolParam` 格式。 + + ```python + class HuoshanChatModel(BaseChatModel): + # ... (其他代码不变) ... + + def bind_tools(self, tools: Sequence[BaseTool], **kwargs: Any) -> Runnable: + """将工具绑定到模型,并将其转换为火山引擎API所需的格式。""" + tool_definitions = [] + for tool_item in tools: + tool_definitions.append( + ChatCompletionToolParam( + type="function", + function={ + "name": tool_item.name, + "description": tool_item.description, + "parameters": tool_item.args_schema.schema() if isinstance(tool_item.args_schema, type(BaseModel)) else tool_item.args_schema + } + ) + ) + + # 返回一个绑定了工具的新实例 + # 这里我们使用_bind方法,它会返回一个新的Runnable实例 + return self._bind(tools=tool_definitions, **kwargs) + ``` + +3. **修改 `_convert_messages_to_prompt` 方法**: + 这个方法需要能够处理 LangChain 的 `ToolMessage` 和 `AIMessage`,并将其转换为火山引擎 API 的消息格式。这对于工具调用的回填和最终回复至关重要。 + + ```python + class HuoshanChatModel(BaseChatModel): + # ... (其他代码不变) ... + + def _convert_messages_to_prompt(self, messages: List[BaseMessage]) -> List[Dict]: + """将LangChain消息转换为火山引擎API所需的格式。""" + api_messages = [] + for msg in messages: + if isinstance(msg, HumanMessage): + api_messages.append({"role": "user", "content": msg.content}) + elif isinstance(msg, AIMessage): + if msg.tool_calls: + api_messages.append({ + "role": "assistant", + "content": "", + "tool_calls": [{ + "id": tc.id, + "type": "function", + "function": { + "name": tc.name, + "arguments": json.dumps(tc.args) + } + } for tc in msg.tool_calls] + }) + else: + api_messages.append({"role": "assistant", "content": msg.content}) + elif isinstance(msg, ToolMessage): + api_messages.append({ + "role": "tool", + "content": msg.content, + "tool_call_id": msg.tool_call_id + }) + return api_messages + ``` + +4. **修改 `_generate` 方法**: + 这个方法需要调用底层 API,并解析大模型返回的响应,以检查是否包含工具调用。 + + ```python + class HuoshanChatModel(BaseChatModel): + # ... (其他代码不变) ... + + def _generate( + self, + messages: List[BaseMessage], + stop: Optional[List[str]] = None, + run_manager: Optional[CallbackManagerForLLMRun] = None, + **kwargs: Any, + ) -> ChatResult: + if not self._api: + raise ValueError("HuoshanAPI未正确初始化") + + api_messages = self._convert_messages_to_prompt(messages) + tools = kwargs.get("tools", []) + + response_data = self._api.get_chat_response(messages=api_messages, tools=tools) + + try: + message_from_api = response_data.get("choices", [{}])[0].get("message", {}) + + tool_calls = message_from_api.get("tool_calls", []) + if tool_calls: + lc_tool_calls = [] + for tc in tool_calls: + lc_tool_calls.append(ToolCall( + name=tc["function"]["name"], + args=json.loads(tc["function"]["arguments"]), + id=tc.get("id", "") + )) + message = AIMessage(content="", tool_calls=lc_tool_calls) + else: + content = message_from_api.get("content", "") + message = AIMessage(content=content) + + generation = ChatGeneration(message=message) + return ChatResult(generations=[generation]) + + except Exception as e: + raise ValueError(f"处理火山引擎API响应失败: {str(e)}") + ``` + +### 第二步:修改 `huoshan.py` + +这个文件是底层 API 客户端,负责与火山引擎 API 进行通信。你需要修改 `get_chat_response` 方法,使其能够发送 `tools` 参数。 + +```python +# huoshan.py + +from volcenginesdkarkruntime.types.chat import ChatCompletionToolParam + +class HuoshanAPI: + # ... (其他代码不变) ... + + def get_chat_response( + self, + messages: List[Dict], + stream: bool = False, + tools: Optional[List[ChatCompletionToolParam]] = None + ) -> Dict[str, Any]: + """同步获取聊天响应,支持工具调用。""" + client = Ark() + + try: + completion = client.chat.completions.create( + model=self.doubao_seed_1_6_model_id, + messages=messages, + stream=stream, + tools=tools # 传入 tools 参数 + ) + return completion.model_dump() # 使用 model_dump() 转换为字典 + except Exception as e: + raise ValueError(f"火山引擎API调用失败: {str(e)}") +``` + +完成以上修改后,你的 `HuoshanChatModel` 就会支持工具调用,并能与 LangGraph 的智能体框架无缝集成。 \ No newline at end of file diff --git a/graph/test_agent_graph_1.py b/graph/test_agent_graph_1.py index e5eef46..b34ae36 100644 --- a/graph/test_agent_graph_1.py +++ b/graph/test_agent_graph_1.py @@ -21,6 +21,8 @@ import config from tools.database.mongo import client # type: ignore from langgraph.checkpoint.mongodb import MongoDBSaver +# 工具方法 +from tools.agent.queryDB import QueryOriginalScript logger = get_logger(__name__) @@ -32,7 +34,7 @@ def replace_value(old_val, new_val): # 状态类型定义 class InputState(TypedDict): """工作流输入状态""" - input_data: Annotated[list[AnyMessage], operator.add] + messages: Annotated[list[AnyMessage], operator.add] from_type: Annotated[str, replace_value] session_id: Annotated[str, replace_value] @@ -55,7 +57,7 @@ class NodeInfo(TypedDict): class ScriptwriterState(TypedDict, total=False): """智能编剧工作流整体状态""" # 输入数据 - input_data: Annotated[list[HumanMessage], operator.add] + messages: Annotated[list[AnyMessage], operator.add] session_id: Annotated[str, replace_value] from_type: Annotated[str, replace_value] # 本次请求来着哪里 [user, agent] @@ -113,7 +115,9 @@ class ScriptwriterGraph: print("创建智能体") # 调度智能体 self.schedulerAgent = SchedulerAgent( - tools=[], + tools=[ + QueryOriginalScript, + ], SchedulerList=[ { "name": "scheduler_node", @@ -225,14 +229,33 @@ class ScriptwriterGraph: logger.error(f"构建工作流图失败: {e}") raise - # --- 定义图中的节点 --- + # --- 定义图中的节点 --- async def scheduler_node(self, state: ScriptwriterState)-> ScriptwriterState: """调度节点""" try: session_id = state.get("session_id", "") from_type = state.get("from_type", "") - input_data = state.get("input_data", []) - logger.info(f"调度节点 {session_id} 输入参数: {input_data} from_type:{from_type}") + messages = state.get("messages", []) + + workflow_step = state.get("workflow_step", "wait_for_input") + workflow_status = state.get("workflow_status", "waiting") + workflow_reason = state.get("workflow_reason", "") + workflow_retry_count = int(state.get("workflow_retry_count", 0)) + # 添加参数进提示词 + messages.append(HumanMessage(content=f""" + # 总任务的进度与任务状态: + {{ + 'query_args':{{ + 'session_id':'{session_id}', + }}, + 'step':'{workflow_step}', + 'status':'{workflow_status}', + 'from_type':'{from_type}', + 'reason':'{workflow_reason}', + 'retry_count':{workflow_retry_count}, + }} + """)) + logger.info(f"调度节点 {session_id} 输入参数: {messages} from_type:{from_type}") reslut = await self.schedulerAgent.ainvoke(state) ai_message_str = reslut['messages'][-1].content ai_message = json.loads(ai_message_str) @@ -259,6 +282,8 @@ class ScriptwriterGraph: "agent_message":return_message, } except Exception as e: + import traceback + traceback.print_exc() return { "next_node":'end_node', "agent_message": "执行失败", @@ -317,11 +342,11 @@ class ScriptwriterGraph: "agent_message": state.get('agent_message', ''), } - async def run(self, session_id: str, input_data: list[AnyMessage], thread_id: str|None = None) -> OutputState: + async def run(self, session_id: str, messages: list[AnyMessage], thread_id: str|None = None) -> OutputState: """运行工作流 Args: session_id: 会话ID - input_data: 输入数据 + messages: 输入数据 thread_id: 线程ID Returns: @@ -334,7 +359,7 @@ class ScriptwriterGraph: config:RunnableConfig = {"configurable": {"thread_id": thread_id}} # 初始化状态 initial_state: InputState = { - 'input_data': input_data, + 'messages': messages, 'session_id': session_id, 'from_type': 'user', } @@ -456,7 +481,7 @@ if __name__ == "__main__": session_id = "68c2c2915e5746343301ef71" result = await graph.run( session_id, - [HumanMessage(content="你好编剧,我想写小说!")], + [HumanMessage(content="老师,我写好剧本了,您看看!帮我分析分析把!")], session_id ) print(f"最终结果: {result}") diff --git a/tools/agent/__init__.py b/tools/agent/__init__.py new file mode 100644 index 0000000..f75794c --- /dev/null +++ b/tools/agent/__init__.py @@ -0,0 +1,3 @@ +""" +智能体工具模块 +""" \ No newline at end of file diff --git a/tools/agent/queryDB.py b/tools/agent/queryDB.py new file mode 100644 index 0000000..db6e7c2 --- /dev/null +++ b/tools/agent/queryDB.py @@ -0,0 +1,38 @@ +from bson import ObjectId +from tools.database.mongo import mainDB +from langchain.tools import tool + +@tool +def QueryOriginalScript(session_id: str, only_exist: bool = False): + """ + 查询原始剧本内容或是否存在 + Args: + session_id: 会话id + only_exist: 是否只查询存在的剧本 + + Returns: + Dict: 返回一个包含以下字段的字典: + original_script (str): 原始剧本内容。仅当 only_exist 为 False 时返回该字段。 + exist (bool): 原始剧本内容是否存在。 + + """ + # c = mainDB.agent_writer_session.count_documents({}) + # print(f"查询到的原始剧本session_id: {session_id}, only_exist: {only_exist} count:{c}") + if only_exist: + script = mainDB.agent_writer_session.find_one({"_id": ObjectId(session_id), "original_script": {"$exists": True, "$ne": ""}}) + # print(f"exist: {script}") + return { + "original_script": "", + "exist": script is not None, + } + else: + script = mainDB.agent_writer_session.find_one({"_id": ObjectId(session_id)}, {"original_script": 1}) + original_script = "" + if script: + original_script = script["original_script"] or '' + print(f"查询到的原始剧本字符长度: {len(original_script)}") + return { + "original_script": original_script, + "exist": original_script != '', + } + diff --git a/tools/llm/huoshan_langchain.py b/tools/llm/huoshan_langchain.py index 17d40d7..97d65a7 100644 --- a/tools/llm/huoshan_langchain.py +++ b/tools/llm/huoshan_langchain.py @@ -1,11 +1,32 @@ -from typing import Any, Dict, Iterator, List, Optional +from typing import Any, Dict, Iterable, Iterator, List, Optional, Sequence from langchain_core.callbacks.manager import CallbackManagerForLLMRun from langchain_core.language_models.chat_models import BaseChatModel -from langchain_core.messages import BaseMessage, AIMessage, HumanMessage, SystemMessage +from langchain_core.messages import BaseMessage, AIMessage, HumanMessage, SystemMessage, ToolMessage, ToolCall from langchain_core.outputs import ChatGeneration, ChatResult +from langchain_core.tools import BaseTool +from langchain_core.runnables import Runnable from pydantic import Field +import json +import copy +from volcenginesdkarkruntime.types.chat import ChatCompletionToolParam +from volcenginesdkarkruntime.types.shared_params.function_definition import FunctionDefinition +from volcenginesdkarkruntime.types.shared_params.function_parameters import FunctionParameters from api.huoshan import HuoshanAPI # 导入你现有的API类 +def _convert_dict_type(data: Any) -> Any: + """递归地将字典中的"Dict"字符串类型转换为"dict".""" + if isinstance(data, dict): + new_data = {} + for k, v in data.items(): + if isinstance(v, str) and v == "Dict": + new_data[k] = "dict" + else: + new_data[k] = _convert_dict_type(v) + return new_data + elif isinstance(data, list): + return [_convert_dict_type(item) for item in data] + else: + return data class HuoshanChatModel(BaseChatModel): """火山引擎聊天模型的LangChain封装""" @@ -28,25 +49,137 @@ class HuoshanChatModel(BaseChatModel): """返回LLM类型标识""" return "huoshan_chat" - def _convert_messages_to_prompt(self, messages: List[BaseMessage]) -> tuple[str, str]: - """将LangChain消息格式转换为API所需的prompt和system格式""" - system_message = "" - user_messages = [] + # def _convert_messages_to_prompt(self, messages: List[BaseMessage]) -> tuple[str, str]: + # """将LangChain消息格式转换为API所需的prompt和system格式""" + # system_message = "" + # user_messages = [] - for message in messages: - if isinstance(message, SystemMessage): - system_message = message.content or "" - elif isinstance(message, HumanMessage): - user_messages.append(message.content) - elif isinstance(message, AIMessage): - # 如果需要支持多轮对话,可以在这里处理 - pass + # for message in messages: + # if isinstance(message, SystemMessage): + # system_message = message.content or "" + # elif isinstance(message, HumanMessage): + # user_messages.append(message.content) + # elif isinstance(message, SystemMessage): + # # 如果需要支持多轮对话,可以在这里处理 + # pass - # 合并用户消息 - prompt = "\n".join(user_messages) if user_messages else "" + # # 合并用户消息 + # prompt = "\n".join(user_messages) if user_messages else "" - return prompt, str(system_message) + # return prompt, str(system_message) + + def bind_tools(self, tools: Sequence[BaseTool], **kwargs: Any) -> Runnable: + """将工具绑定到模型,并将其转换为火山引擎API所需的格式。""" + # 1. 创建一个当前模型的副本 + # new_model = copy.copy(self) + tool_definitions = [] + for tool_item in tools: + # 1. 获取工具的参数 JSON Schema + if hasattr(tool_item.args_schema, 'model_json_schema'): + # 如果是 Pydantic 模型,调用其方法获取 JSON Schema + parameters_schema = tool_item.args_schema.model_json_schema() # type: ignore + else: + parameters_schema = _convert_dict_type(parameters_schema) + + # 2. 从字典中提取并修正参数 + # SDK要求 'type' 为 'object','properties' 和 'required' 字段为必填 + schema_type = parameters_schema.get("type", "object") + schema_properties = parameters_schema.get("properties", {}) + schema_required = parameters_schema.get("required", []) + + # 3. 将修正后的字典作为参数,直接创建 FunctionParameters 对象 + # 这是最安全的方式,因为它绕过了之前可能存在的类型转换问题 + function_parameters:FunctionParameters = { + "type":schema_type, + "properties":schema_properties, + "required":schema_required + } + + # 4. 将 FunctionParameters 对象封装成 FunctionDefinition + function_def:FunctionDefinition = { + "name":tool_item.name, + "description":tool_item.description, + "parameters":function_parameters + } + + # 5. 将 FunctionDefinition 对象封装成 ChatCompletionToolParam + tool_definitions.append( + ChatCompletionToolParam( + type="function", + function=function_def + ) + ) + + # 返回一个绑定了工具的新实例 + # 这里我们使用_bind方法,它会返回一个新的Runnable实例 + # return self._bind(tools=tool_definitions, **kwargs) # type: ignore + # return self.bind(tools=tool_definitions, **kwargs) + return super().bind(tools=tool_definitions, **kwargs) + + # new_model._bound_tools = tool_definitions + + # 3. 返回新的、已绑定工具的模型实例 + # return new_model + + def _convert_messages_to_prompt(self, messages: List[BaseMessage]) -> List[Dict]: + """将LangChain消息转换为火山引擎API所需的格式。""" + # print(f" 原始 messages: {messages}") + + # 使用字典来存储每个tool_call_id的最新消息 + last_tool_messages = {} + # 临时存储其他消息(非ToolMessage) + other_messages = [] + for msg in messages: + if isinstance(msg, ToolMessage): + # 将最新的ToolMessage存入字典,这会覆盖之前的旧消息 + last_tool_messages[msg.tool_call_id] = msg + else: + other_messages.append(msg) + # 从后向前遍历,只保留每个tool_call_id的第一个(即最新的)消息 + deduplicated_messages = [] + seen_tool_call_ids = set() + + for msg in reversed(messages): + # 如果是ToolMessage,并且tool_call_id已经存在,则跳过 + if isinstance(msg, ToolMessage): + if msg.tool_call_id not in seen_tool_call_ids: + seen_tool_call_ids.add(msg.tool_call_id) + deduplicated_messages.append(msg) + # 其他类型的消息直接添加 + else: + deduplicated_messages.append(msg) + + # 将消息列表还原为原始时间顺序 + messages_to_convert = list(reversed(deduplicated_messages)) + + # 2. 将去重后的消息列表转换为API所需的格式 + api_messages = [] + for msg in messages_to_convert: + if isinstance(msg, HumanMessage): + api_messages.append({"role": "user", "content": msg.content}) + elif isinstance(msg, SystemMessage): + api_messages.append({"role": "system", "content": msg.content}) + elif isinstance(msg, ToolMessage): + api_messages.append({ + "role": "tool", + "content": msg.content, + "tool_call_id": msg.tool_call_id + }) + # api_messages = [] + # for msg in messages: + # if isinstance(msg, HumanMessage): + # api_messages.append({"role": "user", "content": msg.content}) + # elif isinstance(msg, SystemMessage): + # api_messages.append({"role": "system", "content": msg.content}) + # elif isinstance(msg, ToolMessage): + # api_messages.append({ + # "role": "tool", + # "content": msg.content, + # "tool_call_id": msg.tool_call_id + # }) + return api_messages + def _generate( self, messages: List[BaseMessage], @@ -54,38 +187,45 @@ class HuoshanChatModel(BaseChatModel): run_manager: Optional[CallbackManagerForLLMRun] = None, **kwargs: Any, ) -> ChatResult: - """生成聊天回复""" if not self._api: raise ValueError("HuoshanAPI未正确初始化") - # 转换消息格式 - prompt, system = self._convert_messages_to_prompt(messages) - - # 合并参数 - generation_kwargs = { - "model": kwargs.get("model", self.model_name), - "temperature": kwargs.get("temperature", self.temperature), - "system": system - } + api_messages = self._convert_messages_to_prompt(messages) + tools = kwargs.get("tools", []) + + print(f" 提交给豆包的 messages数组长度: \n {len(messages)} \n tools: {tools}") + response_data = self._api.get_chat_response(messages=api_messages, tools=tools) + try: - # 调用你的API - response_text = self._api.get_chat_response( - prompt=prompt, - **generation_kwargs - ) - - # 创建AI消息 - message = AIMessage(content=response_text) - - # 创建生成结果 + res_choices = response_data.get("choices", [{}])[0] + finish_reason = res_choices.get("finish_reason", "") + message_from_api = res_choices.get("message", {}) + tool_calls = message_from_api.get("tool_calls", []) + print(f" 豆包返回的 finish_reason: {finish_reason} \n tool_calls: {tool_calls} \n") + print(f" 豆包返回的 message: {message_from_api.get('content', '')}") + if finish_reason == "tool_calls" and tool_calls: + lc_tool_calls = [] + for tc in tool_calls: + function_dict = tc.get("function", {}) + lc_tool_calls.append(ToolCall( + name=function_dict.get("name", ""), + args=json.loads(function_dict.get("arguments", "{}")), + id=tc.get("id", "") + )) + message = AIMessage(content="", tool_calls=lc_tool_calls) + else: + content = message_from_api.get("content", "") + message = AIMessage(content=content) + generation = ChatGeneration(message=message) - return ChatResult(generations=[generation]) - + except Exception as e: - raise ValueError(f"调用火山引擎API失败: {str(e)}") - + import traceback + traceback.print_exc() + raise ValueError(f"处理火山引擎API响应失败: {str(e)}") + def _stream( self, messages: List[BaseMessage], @@ -98,20 +238,15 @@ class HuoshanChatModel(BaseChatModel): raise ValueError("HuoshanAPI未正确初始化") # 转换消息格式 - prompt, system = self._convert_messages_to_prompt(messages) - - # 合并参数 - generation_kwargs = { - "model": kwargs.get("model", self.model_name), - "temperature": kwargs.get("temperature", self.temperature), - "system": system - } - + api_messages = self._convert_messages_to_prompt(messages) + tools = kwargs.get("tools", []) + try: # 调用流式API for chunk in self._api.get_chat_response_stream( - prompt=prompt, - **generation_kwargs + messages=api_messages, + tools=tools, + **kwargs ): if chunk: # 创建增量消息 @@ -134,4 +269,4 @@ class HuoshanChatModel(BaseChatModel): "model_name": self.model_name, "temperature": self.temperature, "max_tokens": self.max_tokens, - } \ No newline at end of file + }