工具方法调用 1

This commit is contained in:
jonathang4 2025-09-12 00:32:55 +08:00
parent 750af43ff3
commit 3ca857a589
7 changed files with 484 additions and 127 deletions

View File

@ -39,6 +39,7 @@ DefaultAgentPrompt = f"""
***单集创作 智能体*** 名称:`episode_create` 描述:用户确认`剧本圣经`,交给该智能体来构建某一集的具体创作注意该智能体仅负责单集的创作因此该智能体的调度需要有你根据`剧本圣经`中的`核心大纲`来多次调用逐步完成所有剧集的创作对于某一集的具体修改直接交给该智能体 ***单集创作 智能体*** 名称:`episode_create` 描述:用户确认`剧本圣经`,交给该智能体来构建某一集的具体创作注意该智能体仅负责单集的创作因此该智能体的调度需要有你根据`剧本圣经`中的`核心大纲`来多次调用逐步完成所有剧集的创作对于某一集的具体修改直接交给该智能体
***注意智能体调用后最终会返回再次请求到你你需要根据智能体的处理结果来决定下一步*** ***注意智能体调用后最终会返回再次请求到你你需要根据智能体的处理结果来决定下一步***
***注意`智能体调用` 不是工具方法的使用而是在返回数据中把agent属性指定为要调用的智能体名称***
# 工具使用 # 工具使用
上述智能体职责中提及的输出内容都有对应的工具可供你调用进行查看他们的查询工具名称分别对应如下 上述智能体职责中提及的输出内容都有对应的工具可供你调用进行查看他们的查询工具名称分别对应如下
@ -54,8 +55,10 @@ DefaultAgentPrompt = f"""
未完成的集数: `QueryUnfinishedEpisodeCount` 未完成的集数: `QueryUnfinishedEpisodeCount`
已完成的集数: `QueryCompletedEpisodeCount` 已完成的集数: `QueryCompletedEpisodeCount`
***注意工具使用是需要你调用工具方法的但是大多数情况下你不需要查询文本的具体内容只需要查询存在与否即可***
***每次用户的输入都会携带当前`总任务的进度与任务状态`注意查看并分析是否应该回复用户等待或提醒用户确认继续下一步*** ***每次用户的输入都会携带当前`总任务的进度与任务状态`注意查看并分析是否应该回复用户等待或提醒用户确认继续下一步***
# 总任务的进度与任务状态数据结构为 {{"step": "waiting_script", "status": "running", "from_type":"user", "reason": "waiting_script", "retry_count": 0}} # 总任务的进度与任务状态数据结构为 {{"step": "waiting_script", "status": "running", "from_type":"user", "reason": "waiting_script", "retry_count": 0, "query_args":{{}}}}
step: 阶段名称 step: 阶段名称
wait_for_input: 等待用户提供原始剧本 wait_for_input: 等待用户提供原始剧本
@ -80,9 +83,12 @@ DefaultAgentPrompt = f"""
"retry_count": 重试次数 "retry_count": 重试次数
"query_args": 用于调用工具方法的参数可能包括的字段有
"session_id": 会话ID可用于查询`原始剧本`
# 职责 # 职责
分析用户输入与`总任务的进度与任务状态`以下是几种情况的示例 分析用户输入与`总任务的进度与任务状态`以下是几种情况的示例
1 `wait_for_input` 向用户问好并介绍你作为爆款短剧操盘手的身份和专业工作流程礼貌地请用户提供需要改编的原始剧本如果用户没有提供原始剧本你将持续友好地提醒此时状态始终为waiting直到获取原始剧本为止 1 `wait_for_input` 向用户问好并介绍你作为爆款短剧操盘手的身份和专业工作流程礼貌地请用户提供需要改编的原始剧本如果用户没有提供原始剧本你将持续友好地提醒此时状态始终为waiting直到获取原始剧本为止从用户提交的中可以获取到session_id的时候需要调用 `QueryOriginalScript` 工具来查询原始剧本是否存在
2 `script_analysis` 读取到原始剧本并从输入中分析出可以继续后进入调用`原始剧本分析 智能体`继续后续工作running时礼貌回复用户并提醒用户任务真正进行中completed代表任务完成此时可等待用户反馈直到跟用户确认可以进行下一步后再继续后续任务 2 `script_analysis` 读取到原始剧本并从输入中分析出可以继续后进入调用`原始剧本分析 智能体`继续后续工作running时礼貌回复用户并提醒用户任务真正进行中completed代表任务完成此时可等待用户反馈直到跟用户确认可以进行下一步后再继续后续任务
3 `strategic_planning` 根据`诊断与资产评估`的结果调用`确立改编目标 智能体`并返回结果 3 `strategic_planning` 根据`诊断与资产评估`的结果调用`确立改编目标 智能体`并返回结果
4 `build_bible` 根据`改编思路`的结果调用`剧本圣经构建 智能体`并返回结果 4 `build_bible` 根据`改编思路`的结果调用`剧本圣经构建 智能体`并返回结果
@ -98,7 +104,7 @@ DefaultAgentPrompt = f"""
"agent":'',//分析后得出由哪个智能体继续任务此处为智能体名称如果需要继续与用户交互或仅需要回复用户则为空字符串 "agent":'',//分析后得出由哪个智能体继续任务此处为智能体名称如果需要继续与用户交互或仅需要回复用户则为空字符串
"message":'',//回复给用户的内容 "message":'',//回复给用户的内容
"retry_count":0,//重试次数 "retry_count":0,//重试次数
"node":'',//下一个节点名称 "node":'',//下一个节点名称根据指定的agent名称从取值范围列表中选择一个节点名称返回
}} }}
""" """

View File

@ -2,7 +2,7 @@ import os
import json import json
import random import random
import time import time
from typing import Dict, List, Optional, Any, cast from typing import Dict, Iterable, List, Optional, Any, cast
from datetime import datetime from datetime import datetime
from volcenginesdkarkruntime import Ark from volcenginesdkarkruntime import Ark
from volcengine.visual.VisualService import VisualService from volcengine.visual.VisualService import VisualService
@ -12,6 +12,8 @@ from volcengine.vod.models.request.request_vod_pb2 import VodUploadMediaRequest
from volcengine.vod.models.request.request_vod_pb2 import VodUrlUploadRequest from volcengine.vod.models.request.request_vod_pb2 import VodUrlUploadRequest
import base64 import base64
from volcenginesdkarkruntime.types.chat import ChatCompletionToolParam
from config import API_CONFIG from config import API_CONFIG
class HuoshanAPI: class HuoshanAPI:
@ -507,32 +509,23 @@ class HuoshanAPI:
'error': f'删除异常: {str(e)}' 'error': f'删除异常: {str(e)}'
} }
def get_chat_response(self, prompt: str, model: Optional[str] = None, system: Optional[str] = None, temperature: float = 0.6) -> str: def get_chat_response(
""" self,
获取聊天机器人回复 - 使用doubao_seed_1.6模型 messages: List[Dict],
:param prompt: 用户输入的文本 model: Optional[str] = None,
:param model: 模型名称默认使用doubao_seed_1.6 tools: Optional[List[ChatCompletionToolParam]] = None,
:param system: 系统提示词 temperature: float = 0.6
:param temperature: 温度参数 ) -> Dict[str, Any]:
:return: 机器人回复的文本 """同步获取聊天响应,支持工具调用。"""
"""
try: try:
if model is None: if model is None:
model = self.doubao_seed_1_6_model_id model = self.doubao_seed_1_6_model_id
messages:Any = [] if not messages or len(messages)==0:
if system: raise ValueError(f"火山引擎API调用失败: 消息不能为空")
messages = [ completion = self.client.chat.completions.create(
{"role": "system", "content": system},
{"role": "user", "content": prompt}
]
else:
messages = [
{"role": "user", "content": prompt}
]
response = self.client.chat.completions.create(
model=model, model=model,
messages=messages, messages=messages, # type: ignore
max_tokens=16384, # 16K max_tokens=16384, # 16K
temperature=temperature, temperature=temperature,
timeout=600, timeout=600,
@ -541,55 +534,39 @@ class HuoshanAPI:
# "type": "enabled", # 使用深度思考能力 # "type": "enabled", # 使用深度思考能力
# "type": "auto", # 模型自行判断是否使用深度思考能力 # "type": "auto", # 模型自行判断是否使用深度思考能力
}, },
tools=tools # 传入 tools 参数
) )
return completion.model_dump() # type: ignore 使用 model_dump() 转换为字典
return response.choices[0].message.content # pyright: ignore
except Exception as e: except Exception as e:
raise Exception(f'Huoshan chat API调用异常: {str(e)}') raise ValueError(f"火山引擎API调用失败: {str(e)}")
def get_chat_response_stream(
self,
messages: List[Dict],
model: Optional[str] = None,
tools: Optional[List[ChatCompletionToolParam]] = None,
temperature: float = 0.6
) -> Iterable[str]:
"""流式获取聊天响应,支持工具调用。"""
def get_chat_response_stream(self, prompt: str, model: Optional[str] = None, system: Optional[str] = None, temperature: float = 0.6):
"""
获取聊天机器人的流式回复 - 使用doubao_seed_1.6模型
:param prompt: 用户输入的文本
:param model: 模型名称默认使用doubao_seed_1.6
:param system: 系统提示词
:param temperature: 温度参数
:return: 生成器逐步返回机器人回复的文本
"""
try: try:
if model is None: if model is None:
model = self.doubao_seed_1_6_model_id model = self.doubao_seed_1_6_model_id
messages:Any = [] completion = self.client.chat.completions.create(
if system:
messages = [
{"role": "system", "content": system},
{"role": "user", "content": prompt}
]
else:
messages = [
{"role": "user", "content": prompt}
]
response = self.client.chat.completions.create(
model=model, model=model,
messages=messages, messages=messages, # type: ignore
temperature=temperature, temperature=temperature,
max_tokens=16384, # 16K max_tokens=16384, # 16K
timeout=600, timeout=600,
stream=True stream=True,
tools=tools # 传入 tools 参数
) )
for chunk in completion:
if chunk.choices and chunk.choices[0].delta.content is not None: # type: ignore
for chunk in response: yield chunk.choices[0].delta.content # type: ignore
chunk_obj = cast(Any, chunk)
if hasattr(chunk_obj, 'choices') and chunk_obj.choices and len(chunk_obj.choices) > 0:
delta = chunk_obj.choices[0].delta
if hasattr(delta, 'content') and delta.content is not None:
yield delta.content
except Exception as e: except Exception as e:
raise Exception(f'Huoshan chat stream API调用异常: {str(e)}') raise ValueError(f"火山引擎API流式调用失败: {str(e)}")
def analyze_image(self, image_url: str, prompt: str = "请描述这张图片的内容", model: Optional[str] = None, detail: str = "high") -> Dict[str, Any]: def analyze_image(self, image_url: str, prompt: str = "请描述这张图片的内容", model: Optional[str] = None, detail: str = "high") -> Dict[str, Any]:
""" """

173
doc/test.txt Normal file
View File

@ -0,0 +1,173 @@
为了在你的代码中正确接入火山引擎的工具调用功能,你需要修改 `huoshan_langchain.py` 和 `huoshan.py` 这两个文件,以实现**从 LangChain 工具到火山引擎 API 工具定义的格式转换**,以及**解析和处理来自 API 的工具调用响应**。
下面是根据火山引擎官方文档和你的代码,我为你整理的完整修改方案。
-----
### 第一步:修改 `huoshan_langchain.py`
这个文件是 LangChain 的封装层,负责连接你的工作流和火山引擎的底层 API。你需要在这里实现 `bind_tools` 和 `_generate` 方法来处理工具调用。
1. **导入必要的类**
需要添加 `ToolMessage` 和 `ToolCall`,它们是 LangChain 用于表示工具调用结果和工具调用的核心类。
```python
from typing import Any, Dict, Iterable, Iterator, List, Optional, Sequence
from langchain_core.callbacks.manager import CallbackManagerForLLMRun
from langchain_core.language_models.chat_models import BaseChatModel
from langchain_core.messages import BaseMessage, AIMessage, HumanMessage, SystemMessage, ToolMessage, ToolCall
from langchain_core.outputs import ChatGeneration, ChatResult
from langchain_core.tools import BaseTool
from langchain_core.runnables import Runnable
from langchain.pydantic_v1 import BaseModel
from pydantic import Field
from volcenginesdkarkruntime.types.chat import ChatCompletionToolParam
from api.huoshan import HuoshanAPI
import json
```
2. **实现 `bind_tools` 方法**
这个方法是 LangChain 用于将工具定义传递给你的模型封装。在这里,你需要将 LangChain 的 `BaseTool` 对象转换为火山引擎 API 所需的 `ChatCompletionToolParam` 格式。
```python
class HuoshanChatModel(BaseChatModel):
# ... (其他代码不变) ...
def bind_tools(self, tools: Sequence[BaseTool], **kwargs: Any) -> Runnable:
"""将工具绑定到模型并将其转换为火山引擎API所需的格式。"""
tool_definitions = []
for tool_item in tools:
tool_definitions.append(
ChatCompletionToolParam(
type="function",
function={
"name": tool_item.name,
"description": tool_item.description,
"parameters": tool_item.args_schema.schema() if isinstance(tool_item.args_schema, type(BaseModel)) else tool_item.args_schema
}
)
)
# 返回一个绑定了工具的新实例
# 这里我们使用_bind方法它会返回一个新的Runnable实例
return self._bind(tools=tool_definitions, **kwargs)
```
3. **修改 `_convert_messages_to_prompt` 方法**
这个方法需要能够处理 LangChain 的 `ToolMessage` 和 `AIMessage`,并将其转换为火山引擎 API 的消息格式。这对于工具调用的回填和最终回复至关重要。
```python
class HuoshanChatModel(BaseChatModel):
# ... (其他代码不变) ...
def _convert_messages_to_prompt(self, messages: List[BaseMessage]) -> List[Dict]:
"""将LangChain消息转换为火山引擎API所需的格式。"""
api_messages = []
for msg in messages:
if isinstance(msg, HumanMessage):
api_messages.append({"role": "user", "content": msg.content})
elif isinstance(msg, AIMessage):
if msg.tool_calls:
api_messages.append({
"role": "assistant",
"content": "",
"tool_calls": [{
"id": tc.id,
"type": "function",
"function": {
"name": tc.name,
"arguments": json.dumps(tc.args)
}
} for tc in msg.tool_calls]
})
else:
api_messages.append({"role": "assistant", "content": msg.content})
elif isinstance(msg, ToolMessage):
api_messages.append({
"role": "tool",
"content": msg.content,
"tool_call_id": msg.tool_call_id
})
return api_messages
```
4. **修改 `_generate` 方法**
这个方法需要调用底层 API并解析大模型返回的响应以检查是否包含工具调用。
```python
class HuoshanChatModel(BaseChatModel):
# ... (其他代码不变) ...
def _generate(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> ChatResult:
if not self._api:
raise ValueError("HuoshanAPI未正确初始化")
api_messages = self._convert_messages_to_prompt(messages)
tools = kwargs.get("tools", [])
response_data = self._api.get_chat_response(messages=api_messages, tools=tools)
try:
message_from_api = response_data.get("choices", [{}])[0].get("message", {})
tool_calls = message_from_api.get("tool_calls", [])
if tool_calls:
lc_tool_calls = []
for tc in tool_calls:
lc_tool_calls.append(ToolCall(
name=tc["function"]["name"],
args=json.loads(tc["function"]["arguments"]),
id=tc.get("id", "")
))
message = AIMessage(content="", tool_calls=lc_tool_calls)
else:
content = message_from_api.get("content", "")
message = AIMessage(content=content)
generation = ChatGeneration(message=message)
return ChatResult(generations=[generation])
except Exception as e:
raise ValueError(f"处理火山引擎API响应失败: {str(e)}")
```
### 第二步:修改 `huoshan.py`
这个文件是底层 API 客户端,负责与火山引擎 API 进行通信。你需要修改 `get_chat_response` 方法,使其能够发送 `tools` 参数。
```python
# huoshan.py
from volcenginesdkarkruntime.types.chat import ChatCompletionToolParam
class HuoshanAPI:
# ... (其他代码不变) ...
def get_chat_response(
self,
messages: List[Dict],
stream: bool = False,
tools: Optional[List[ChatCompletionToolParam]] = None
) -> Dict[str, Any]:
"""同步获取聊天响应,支持工具调用。"""
client = Ark()
try:
completion = client.chat.completions.create(
model=self.doubao_seed_1_6_model_id,
messages=messages,
stream=stream,
tools=tools # 传入 tools 参数
)
return completion.model_dump() # 使用 model_dump() 转换为字典
except Exception as e:
raise ValueError(f"火山引擎API调用失败: {str(e)}")
```
完成以上修改后,你的 `HuoshanChatModel` 就会支持工具调用,并能与 LangGraph 的智能体框架无缝集成。

View File

@ -21,6 +21,8 @@ import config
from tools.database.mongo import client # type: ignore from tools.database.mongo import client # type: ignore
from langgraph.checkpoint.mongodb import MongoDBSaver from langgraph.checkpoint.mongodb import MongoDBSaver
# 工具方法
from tools.agent.queryDB import QueryOriginalScript
logger = get_logger(__name__) logger = get_logger(__name__)
@ -32,7 +34,7 @@ def replace_value(old_val, new_val):
# 状态类型定义 # 状态类型定义
class InputState(TypedDict): class InputState(TypedDict):
"""工作流输入状态""" """工作流输入状态"""
input_data: Annotated[list[AnyMessage], operator.add] messages: Annotated[list[AnyMessage], operator.add]
from_type: Annotated[str, replace_value] from_type: Annotated[str, replace_value]
session_id: Annotated[str, replace_value] session_id: Annotated[str, replace_value]
@ -55,7 +57,7 @@ class NodeInfo(TypedDict):
class ScriptwriterState(TypedDict, total=False): class ScriptwriterState(TypedDict, total=False):
"""智能编剧工作流整体状态""" """智能编剧工作流整体状态"""
# 输入数据 # 输入数据
input_data: Annotated[list[HumanMessage], operator.add] messages: Annotated[list[AnyMessage], operator.add]
session_id: Annotated[str, replace_value] session_id: Annotated[str, replace_value]
from_type: Annotated[str, replace_value] # 本次请求来着哪里 [user, agent] from_type: Annotated[str, replace_value] # 本次请求来着哪里 [user, agent]
@ -113,7 +115,9 @@ class ScriptwriterGraph:
print("创建智能体") print("创建智能体")
# 调度智能体 # 调度智能体
self.schedulerAgent = SchedulerAgent( self.schedulerAgent = SchedulerAgent(
tools=[], tools=[
QueryOriginalScript,
],
SchedulerList=[ SchedulerList=[
{ {
"name": "scheduler_node", "name": "scheduler_node",
@ -225,14 +229,33 @@ class ScriptwriterGraph:
logger.error(f"构建工作流图失败: {e}") logger.error(f"构建工作流图失败: {e}")
raise raise
# --- 定义图中的节点 --- # --- 定义图中的节点 ---
async def scheduler_node(self, state: ScriptwriterState)-> ScriptwriterState: async def scheduler_node(self, state: ScriptwriterState)-> ScriptwriterState:
"""调度节点""" """调度节点"""
try: try:
session_id = state.get("session_id", "") session_id = state.get("session_id", "")
from_type = state.get("from_type", "") from_type = state.get("from_type", "")
input_data = state.get("input_data", []) messages = state.get("messages", [])
logger.info(f"调度节点 {session_id} 输入参数: {input_data} from_type:{from_type}")
workflow_step = state.get("workflow_step", "wait_for_input")
workflow_status = state.get("workflow_status", "waiting")
workflow_reason = state.get("workflow_reason", "")
workflow_retry_count = int(state.get("workflow_retry_count", 0))
# 添加参数进提示词
messages.append(HumanMessage(content=f"""
# 总任务的进度与任务状态:
{{
'query_args':{{
'session_id':'{session_id}',
}},
'step':'{workflow_step}',
'status':'{workflow_status}',
'from_type':'{from_type}',
'reason':'{workflow_reason}',
'retry_count':{workflow_retry_count},
}}
"""))
logger.info(f"调度节点 {session_id} 输入参数: {messages} from_type:{from_type}")
reslut = await self.schedulerAgent.ainvoke(state) reslut = await self.schedulerAgent.ainvoke(state)
ai_message_str = reslut['messages'][-1].content ai_message_str = reslut['messages'][-1].content
ai_message = json.loads(ai_message_str) ai_message = json.loads(ai_message_str)
@ -259,6 +282,8 @@ class ScriptwriterGraph:
"agent_message":return_message, "agent_message":return_message,
} }
except Exception as e: except Exception as e:
import traceback
traceback.print_exc()
return { return {
"next_node":'end_node', "next_node":'end_node',
"agent_message": "执行失败", "agent_message": "执行失败",
@ -317,11 +342,11 @@ class ScriptwriterGraph:
"agent_message": state.get('agent_message', ''), "agent_message": state.get('agent_message', ''),
} }
async def run(self, session_id: str, input_data: list[AnyMessage], thread_id: str|None = None) -> OutputState: async def run(self, session_id: str, messages: list[AnyMessage], thread_id: str|None = None) -> OutputState:
"""运行工作流 """运行工作流
Args: Args:
session_id: 会话ID session_id: 会话ID
input_data: 输入数据 messages: 输入数据
thread_id: 线程ID thread_id: 线程ID
Returns: Returns:
@ -334,7 +359,7 @@ class ScriptwriterGraph:
config:RunnableConfig = {"configurable": {"thread_id": thread_id}} config:RunnableConfig = {"configurable": {"thread_id": thread_id}}
# 初始化状态 # 初始化状态
initial_state: InputState = { initial_state: InputState = {
'input_data': input_data, 'messages': messages,
'session_id': session_id, 'session_id': session_id,
'from_type': 'user', 'from_type': 'user',
} }
@ -456,7 +481,7 @@ if __name__ == "__main__":
session_id = "68c2c2915e5746343301ef71" session_id = "68c2c2915e5746343301ef71"
result = await graph.run( result = await graph.run(
session_id, session_id,
[HumanMessage(content="你好编剧,我想写小说!")], [HumanMessage(content="老师,我写好剧本了,您看看!帮我分析分析把!")],
session_id session_id
) )
print(f"最终结果: {result}") print(f"最终结果: {result}")

3
tools/agent/__init__.py Normal file
View File

@ -0,0 +1,3 @@
"""
智能体工具模块
"""

38
tools/agent/queryDB.py Normal file
View File

@ -0,0 +1,38 @@
from bson import ObjectId
from tools.database.mongo import mainDB
from langchain.tools import tool
@tool
def QueryOriginalScript(session_id: str, only_exist: bool = False):
    """
    Query the original script content, or check whether it exists.

    Args:
        session_id: Session id (a Mongo ObjectId hex string).
        only_exist: When True, only check existence and do not return the text.

    Returns:
        Dict: A dict with the fields:
            original_script (str): The original script text
                (always empty when only_exist is True).
            exist (bool): Whether a non-empty original script exists.
    """
    if only_exist:
        # Existence check only: match documents whose original_script
        # field is present and non-empty.
        doc = mainDB.agent_writer_session.find_one(
            {"_id": ObjectId(session_id), "original_script": {"$exists": True, "$ne": ""}}
        )
        return {
            "original_script": "",
            "exist": doc is not None,
        }

    # Fetch only the original_script field via a projection.
    doc = mainDB.agent_writer_session.find_one(
        {"_id": ObjectId(session_id)}, {"original_script": 1}
    )
    # Bug fix: use .get() — the projected document may lack the field
    # entirely, in which case doc["original_script"] raised KeyError.
    original_script = (doc.get("original_script") or "") if doc else ""
    return {
        "original_script": original_script,
        "exist": original_script != "",
    }

View File

@ -1,11 +1,32 @@
from typing import Any, Dict, Iterator, List, Optional from typing import Any, Dict, Iterable, Iterator, List, Optional, Sequence
from langchain_core.callbacks.manager import CallbackManagerForLLMRun from langchain_core.callbacks.manager import CallbackManagerForLLMRun
from langchain_core.language_models.chat_models import BaseChatModel from langchain_core.language_models.chat_models import BaseChatModel
from langchain_core.messages import BaseMessage, AIMessage, HumanMessage, SystemMessage from langchain_core.messages import BaseMessage, AIMessage, HumanMessage, SystemMessage, ToolMessage, ToolCall
from langchain_core.outputs import ChatGeneration, ChatResult from langchain_core.outputs import ChatGeneration, ChatResult
from langchain_core.tools import BaseTool
from langchain_core.runnables import Runnable
from pydantic import Field from pydantic import Field
import json
import copy
from volcenginesdkarkruntime.types.chat import ChatCompletionToolParam
from volcenginesdkarkruntime.types.shared_params.function_definition import FunctionDefinition
from volcenginesdkarkruntime.types.shared_params.function_parameters import FunctionParameters
from api.huoshan import HuoshanAPI # 导入你现有的API类 from api.huoshan import HuoshanAPI # 导入你现有的API类
def _convert_dict_type(data: Any) -> Any:
"""递归地将字典中的"Dict"字符串类型转换为"dict"."""
if isinstance(data, dict):
new_data = {}
for k, v in data.items():
if isinstance(v, str) and v == "Dict":
new_data[k] = "dict"
else:
new_data[k] = _convert_dict_type(v)
return new_data
elif isinstance(data, list):
return [_convert_dict_type(item) for item in data]
else:
return data
class HuoshanChatModel(BaseChatModel): class HuoshanChatModel(BaseChatModel):
"""火山引擎聊天模型的LangChain封装""" """火山引擎聊天模型的LangChain封装"""
@ -28,24 +49,136 @@ class HuoshanChatModel(BaseChatModel):
"""返回LLM类型标识""" """返回LLM类型标识"""
return "huoshan_chat" return "huoshan_chat"
def _convert_messages_to_prompt(self, messages: List[BaseMessage]) -> tuple[str, str]: # def _convert_messages_to_prompt(self, messages: List[BaseMessage]) -> tuple[str, str]:
"""将LangChain消息格式转换为API所需的prompt和system格式""" # """将LangChain消息格式转换为API所需的prompt和system格式"""
system_message = "" # system_message = ""
user_messages = [] # user_messages = []
for message in messages: # for message in messages:
if isinstance(message, SystemMessage): # if isinstance(message, SystemMessage):
system_message = message.content or "" # system_message = message.content or ""
elif isinstance(message, HumanMessage): # elif isinstance(message, HumanMessage):
user_messages.append(message.content) # user_messages.append(message.content)
elif isinstance(message, AIMessage): # elif isinstance(message, SystemMessage):
# 如果需要支持多轮对话,可以在这里处理 # # 如果需要支持多轮对话,可以在这里处理
pass # pass
# 合并用户消息 # # 合并用户消息
prompt = "\n".join(user_messages) if user_messages else "" # prompt = "\n".join(user_messages) if user_messages else ""
return prompt, str(system_message) # return prompt, str(system_message)
def bind_tools(self, tools: Sequence[BaseTool], **kwargs: Any) -> Runnable:
    """Bind tools to the model, converting them to the Volcengine Ark API format.

    Each LangChain ``BaseTool`` is translated into a ``ChatCompletionToolParam``
    whose JSON-Schema parameters are derived from the tool's ``args_schema``.

    Args:
        tools: LangChain tools to expose to the model.
        **kwargs: Extra bind options forwarded to ``Runnable.bind``.

    Returns:
        A new Runnable with the converted tool definitions bound.
    """
    tool_definitions = []
    for tool_item in tools:
        # 1. Obtain the JSON Schema describing the tool's arguments.
        if hasattr(tool_item.args_schema, 'model_json_schema'):
            # Pydantic v2 model: ask it for its JSON Schema directly.
            parameters_schema = tool_item.args_schema.model_json_schema()  # type: ignore
        else:
            # Plain dict schema: normalize any "Dict" type strings to "dict".
            # Bug fix: this branch previously read `parameters_schema` before
            # assignment, raising UnboundLocalError whenever args_schema was
            # not a Pydantic model.
            parameters_schema = _convert_dict_type(tool_item.args_schema)

        # 2. The SDK requires 'type', 'properties' and 'required' to be present.
        function_parameters: FunctionParameters = {
            "type": parameters_schema.get("type", "object"),
            "properties": parameters_schema.get("properties", {}),
            "required": parameters_schema.get("required", []),
        }

        # 3. Wrap the parameters into a FunctionDefinition.
        function_def: FunctionDefinition = {
            "name": tool_item.name,
            "description": tool_item.description,
            "parameters": function_parameters,
        }

        # 4. Wrap the FunctionDefinition into a ChatCompletionToolParam.
        tool_definitions.append(
            ChatCompletionToolParam(type="function", function=function_def)
        )

    # Bind the converted definitions; BaseChatModel.bind returns a new
    # Runnable instance carrying the tools in its kwargs.
    return super().bind(tools=tool_definitions, **kwargs)
def _convert_messages_to_prompt(self, messages: List[BaseMessage]) -> List[Dict]:
    """Convert LangChain messages into the Volcengine API message format.

    Duplicate ToolMessages sharing a tool_call_id are collapsed, keeping only
    the most recent one, while the original ordering of all messages is
    preserved.
    """
    # 1. Walk backwards so the first ToolMessage we encounter per
    #    tool_call_id is the newest one; older duplicates are skipped.
    deduplicated_messages: List[BaseMessage] = []
    seen_tool_call_ids = set()
    for msg in reversed(messages):
        if isinstance(msg, ToolMessage):
            if msg.tool_call_id in seen_tool_call_ids:
                continue
            seen_tool_call_ids.add(msg.tool_call_id)
        deduplicated_messages.append(msg)
    # Restore the original chronological order.
    messages_to_convert = list(reversed(deduplicated_messages))

    # 2. Convert the deduplicated messages into the API's dict format.
    api_messages: List[Dict] = []
    for msg in messages_to_convert:
        if isinstance(msg, HumanMessage):
            api_messages.append({"role": "user", "content": msg.content})
        elif isinstance(msg, SystemMessage):
            api_messages.append({"role": "system", "content": msg.content})
        elif isinstance(msg, AIMessage):
            # Bug fix: assistant turns were previously dropped entirely, so
            # the API could receive "tool" messages with no preceding
            # assistant tool_calls message to answer.
            if msg.tool_calls:
                api_messages.append({
                    "role": "assistant",
                    "content": "",
                    "tool_calls": [{
                        "id": tc["id"],
                        "type": "function",
                        "function": {
                            "name": tc["name"],
                            # Ark expects arguments as a JSON string.
                            "arguments": json.dumps(tc["args"]),
                        },
                    } for tc in msg.tool_calls],
                })
            else:
                api_messages.append({"role": "assistant", "content": msg.content})
        elif isinstance(msg, ToolMessage):
            api_messages.append({
                "role": "tool",
                "content": msg.content,
                "tool_call_id": msg.tool_call_id,
            })
    return api_messages
def _generate( def _generate(
self, self,
@ -54,37 +187,44 @@ class HuoshanChatModel(BaseChatModel):
run_manager: Optional[CallbackManagerForLLMRun] = None, run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any, **kwargs: Any,
) -> ChatResult: ) -> ChatResult:
"""生成聊天回复"""
if not self._api: if not self._api:
raise ValueError("HuoshanAPI未正确初始化") raise ValueError("HuoshanAPI未正确初始化")
# 转换消息格式 api_messages = self._convert_messages_to_prompt(messages)
prompt, system = self._convert_messages_to_prompt(messages) tools = kwargs.get("tools", [])
# 合并参数 print(f" 提交给豆包的 messages数组长度: \n {len(messages)} \n tools: {tools}")
generation_kwargs = {
"model": kwargs.get("model", self.model_name), response_data = self._api.get_chat_response(messages=api_messages, tools=tools)
"temperature": kwargs.get("temperature", self.temperature),
"system": system
}
try: try:
# 调用你的API res_choices = response_data.get("choices", [{}])[0]
response_text = self._api.get_chat_response( finish_reason = res_choices.get("finish_reason", "")
prompt=prompt, message_from_api = res_choices.get("message", {})
**generation_kwargs tool_calls = message_from_api.get("tool_calls", [])
) print(f" 豆包返回的 finish_reason: {finish_reason} \n tool_calls: {tool_calls} \n")
print(f" 豆包返回的 message: {message_from_api.get('content', '')}")
if finish_reason == "tool_calls" and tool_calls:
lc_tool_calls = []
for tc in tool_calls:
function_dict = tc.get("function", {})
lc_tool_calls.append(ToolCall(
name=function_dict.get("name", ""),
args=json.loads(function_dict.get("arguments", "{}")),
id=tc.get("id", "")
))
message = AIMessage(content="", tool_calls=lc_tool_calls)
else:
content = message_from_api.get("content", "")
message = AIMessage(content=content)
# 创建AI消息
message = AIMessage(content=response_text)
# 创建生成结果
generation = ChatGeneration(message=message) generation = ChatGeneration(message=message)
return ChatResult(generations=[generation]) return ChatResult(generations=[generation])
except Exception as e: except Exception as e:
raise ValueError(f"调用火山引擎API失败: {str(e)}") import traceback
traceback.print_exc()
raise ValueError(f"处理火山引擎API响应失败: {str(e)}")
def _stream( def _stream(
self, self,
@ -98,20 +238,15 @@ class HuoshanChatModel(BaseChatModel):
raise ValueError("HuoshanAPI未正确初始化") raise ValueError("HuoshanAPI未正确初始化")
# 转换消息格式 # 转换消息格式
prompt, system = self._convert_messages_to_prompt(messages) api_messages = self._convert_messages_to_prompt(messages)
tools = kwargs.get("tools", [])
# 合并参数
generation_kwargs = {
"model": kwargs.get("model", self.model_name),
"temperature": kwargs.get("temperature", self.temperature),
"system": system
}
try: try:
# 调用流式API # 调用流式API
for chunk in self._api.get_chat_response_stream( for chunk in self._api.get_chat_response_stream(
prompt=prompt, messages=api_messages,
**generation_kwargs tools=tools,
**kwargs
): ):
if chunk: if chunk:
# 创建增量消息 # 创建增量消息