173 lines
7.5 KiB
Plaintext
173 lines
7.5 KiB
Plaintext
为了在你的代码中正确接入火山引擎的工具调用功能,你需要修改 `huoshan_langchain.py` 和 `huoshan.py` 这两个文件,以实现**从 LangChain 工具到火山引擎 API 工具定义的格式转换**,以及**解析和处理来自 API 的工具调用响应**。
|
||
|
||
下面是根据火山引擎官方文档和你的代码,我为你整理的完整修改方案。
|
||
|
||
-----
|
||
|
||
### 第一步:修改 `huoshan_langchain.py`
|
||
|
||
这个文件是 LangChain 的封装层,负责连接你的工作流和火山引擎的底层 API。你需要在这里实现 `bind_tools` 和 `_generate` 方法来处理工具调用。
|
||
|
||
1. **导入必要的类**:
|
||
需要添加 `ToolMessage` 和 `ToolCall`,它们是 LangChain 用于表示工具调用结果和工具调用的核心类。
|
||
|
||
```python
|
||
from typing import Any, Dict, Iterable, Iterator, List, Optional, Sequence
|
||
from langchain_core.callbacks.manager import CallbackManagerForLLMRun
|
||
from langchain_core.language_models.chat_models import BaseChatModel
|
||
from langchain_core.messages import BaseMessage, AIMessage, HumanMessage, SystemMessage, ToolMessage, ToolCall
|
||
from langchain_core.outputs import ChatGeneration, ChatResult
|
||
from langchain_core.tools import BaseTool
|
||
from langchain_core.runnables import Runnable
|
||
from langchain.pydantic_v1 import BaseModel
|
||
from pydantic import Field
|
||
from volcenginesdkarkruntime.types.chat import ChatCompletionToolParam
|
||
from api.huoshan import HuoshanAPI
|
||
import json
|
||
```
|
||
|
||
2. **实现 `bind_tools` 方法**:
|
||
这个方法是 LangChain 用于将工具定义传递给你的模型封装。在这里,你需要将 LangChain 的 `BaseTool` 对象转换为火山引擎 API 所需的 `ChatCompletionToolParam` 格式。
|
||
|
||
```python
|
||
class HuoshanChatModel(BaseChatModel):
|
||
# ... (其他代码不变) ...
|
||
|
||
def bind_tools(self, tools: Sequence[BaseTool], **kwargs: Any) -> Runnable:
|
||
"""将工具绑定到模型,并将其转换为火山引擎API所需的格式。"""
|
||
tool_definitions = []
|
||
for tool_item in tools:
|
||
tool_definitions.append(
|
||
ChatCompletionToolParam(
|
||
type="function",
|
||
function={
|
||
"name": tool_item.name,
|
||
"description": tool_item.description,
|
||
"parameters": tool_item.args_schema.schema() if isinstance(tool_item.args_schema, type(BaseModel)) else tool_item.args_schema
|
||
}
|
||
)
|
||
)
|
||
|
||
# 返回一个绑定了工具的新实例
|
||
# 这里我们使用_bind方法,它会返回一个新的Runnable实例
|
||
return self._bind(tools=tool_definitions, **kwargs)
|
||
```
|
||
|
||
3. **修改 `_convert_messages_to_prompt` 方法**:
|
||
这个方法需要能够处理 LangChain 的 `ToolMessage` 和 `AIMessage`,并将其转换为火山引擎 API 的消息格式。这对于工具调用的回填和最终回复至关重要。
|
||
|
||
```python
|
||
class HuoshanChatModel(BaseChatModel):
|
||
# ... (其他代码不变) ...
|
||
|
||
def _convert_messages_to_prompt(self, messages: List[BaseMessage]) -> List[Dict]:
|
||
"""将LangChain消息转换为火山引擎API所需的格式。"""
|
||
api_messages = []
|
||
for msg in messages:
|
||
if isinstance(msg, HumanMessage):
|
||
api_messages.append({"role": "user", "content": msg.content})
|
||
elif isinstance(msg, AIMessage):
|
||
if msg.tool_calls:
|
||
api_messages.append({
|
||
"role": "assistant",
|
||
"content": "",
|
||
"tool_calls": [{
|
||
"id": tc.id,
|
||
"type": "function",
|
||
"function": {
|
||
"name": tc.name,
|
||
"arguments": json.dumps(tc.args)
|
||
}
|
||
} for tc in msg.tool_calls]
|
||
})
|
||
else:
|
||
api_messages.append({"role": "assistant", "content": msg.content})
|
||
elif isinstance(msg, ToolMessage):
|
||
api_messages.append({
|
||
"role": "tool",
|
||
"content": msg.content,
|
||
"tool_call_id": msg.tool_call_id
|
||
})
|
||
return api_messages
|
||
```
|
||
|
||
4. **修改 `_generate` 方法**:
|
||
这个方法需要调用底层 API,并解析大模型返回的响应,以检查是否包含工具调用。
|
||
|
||
```python
|
||
class HuoshanChatModel(BaseChatModel):
|
||
# ... (其他代码不变) ...
|
||
|
||
def _generate(
|
||
self,
|
||
messages: List[BaseMessage],
|
||
stop: Optional[List[str]] = None,
|
||
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
||
**kwargs: Any,
|
||
) -> ChatResult:
|
||
if not self._api:
|
||
raise ValueError("HuoshanAPI未正确初始化")
|
||
|
||
api_messages = self._convert_messages_to_prompt(messages)
|
||
tools = kwargs.get("tools", [])
|
||
|
||
response_data = self._api.get_chat_response(messages=api_messages, tools=tools)
|
||
|
||
try:
|
||
message_from_api = response_data.get("choices", [{}])[0].get("message", {})
|
||
|
||
tool_calls = message_from_api.get("tool_calls", [])
|
||
if tool_calls:
|
||
lc_tool_calls = []
|
||
for tc in tool_calls:
|
||
lc_tool_calls.append(ToolCall(
|
||
name=tc["function"]["name"],
|
||
args=json.loads(tc["function"]["arguments"]),
|
||
id=tc.get("id", "")
|
||
))
|
||
message = AIMessage(content="", tool_calls=lc_tool_calls)
|
||
else:
|
||
content = message_from_api.get("content", "")
|
||
message = AIMessage(content=content)
|
||
|
||
generation = ChatGeneration(message=message)
|
||
return ChatResult(generations=[generation])
|
||
|
||
except Exception as e:
|
||
raise ValueError(f"处理火山引擎API响应失败: {str(e)}")
|
||
```
|
||
|
||
### 第二步:修改 `huoshan.py`
|
||
|
||
这个文件是底层 API 客户端,负责与火山引擎 API 进行通信。你需要修改 `get_chat_response` 方法,使其能够发送 `tools` 参数。
|
||
|
||
```python
|
||
# huoshan.py
|
||
|
||
from volcenginesdkarkruntime.types.chat import ChatCompletionToolParam
|
||
|
||
class HuoshanAPI:
|
||
# ... (其他代码不变) ...
|
||
|
||
def get_chat_response(
|
||
self,
|
||
messages: List[Dict],
|
||
stream: bool = False,
|
||
tools: Optional[List[ChatCompletionToolParam]] = None
|
||
) -> Dict[str, Any]:
|
||
"""同步获取聊天响应,支持工具调用。"""
|
||
client = Ark()
|
||
|
||
try:
|
||
completion = client.chat.completions.create(
|
||
model=self.doubao_seed_1_6_model_id,
|
||
messages=messages,
|
||
stream=stream,
|
||
tools=tools # 传入 tools 参数
|
||
)
|
||
return completion.model_dump() # 使用 model_dump() 转换为字典
|
||
except Exception as e:
|
||
raise ValueError(f"火山引擎API调用失败: {str(e)}")
|
||
```
|
||
|
||
完成以上修改后,你的 `HuoshanChatModel` 就会支持工具调用,并能与 LangGraph 的智能体框架无缝集成。 |