agent-writer/doc/test.txt
2025-09-12 00:32:55 +08:00

173 lines
7.5 KiB
Plaintext
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

为了在你的代码中正确接入火山引擎的工具调用功能,你需要修改 `huoshan_langchain.py` 和 `huoshan.py` 这两个文件,以实现**从 LangChain 工具到火山引擎 API 工具定义的格式转换**,以及**解析和处理来自 API 的工具调用响应**。
下面是根据火山引擎官方文档和你的代码,我为你整理的完整修改方案。
-----
### 第一步:修改 `huoshan_langchain.py`
这个文件是 LangChain 的封装层,负责连接你的工作流和火山引擎的底层 API。你需要在这里实现 `bind_tools` 和 `_generate` 方法来处理工具调用。
1. **导入必要的类**
需要添加 `ToolMessage` 和 `ToolCall`,它们是 LangChain 用于表示工具调用结果和工具调用的核心类。
```python
from typing import Any, Dict, Iterable, Iterator, List, Optional, Sequence
from langchain_core.callbacks.manager import CallbackManagerForLLMRun
from langchain_core.language_models.chat_models import BaseChatModel
from langchain_core.messages import BaseMessage, AIMessage, HumanMessage, SystemMessage, ToolMessage, ToolCall
from langchain_core.outputs import ChatGeneration, ChatResult
from langchain_core.tools import BaseTool
from langchain_core.runnables import Runnable
from langchain.pydantic_v1 import BaseModel
from pydantic import Field
from volcenginesdkarkruntime.types.chat import ChatCompletionToolParam
from api.huoshan import HuoshanAPI
import json
```
2. **实现 `bind_tools` 方法**
这个方法是 LangChain 用于将工具定义传递给你的模型封装。在这里,你需要将 LangChain 的 `BaseTool` 对象转换为火山引擎 API 所需的 `ChatCompletionToolParam` 格式。
```python
class HuoshanChatModel(BaseChatModel):
# ... (其他代码不变) ...
def bind_tools(self, tools: Sequence[BaseTool], **kwargs: Any) -> Runnable:
"""将工具绑定到模型并将其转换为火山引擎API所需的格式。"""
tool_definitions = []
for tool_item in tools:
tool_definitions.append(
ChatCompletionToolParam(
type="function",
function={
"name": tool_item.name,
"description": tool_item.description,
"parameters": tool_item.args_schema.schema() if isinstance(tool_item.args_schema, type(BaseModel)) else tool_item.args_schema
}
)
)
# 返回一个绑定了工具的新实例
# 这里我们使用_bind方法它会返回一个新的Runnable实例
return self._bind(tools=tool_definitions, **kwargs)
```
3. **修改 `_convert_messages_to_prompt` 方法**
这个方法需要能够处理 LangChain 的 `ToolMessage` 和 `AIMessage`,并将其转换为火山引擎 API 的消息格式。这对于工具调用的回填和最终回复至关重要。
```python
class HuoshanChatModel(BaseChatModel):
# ... (其他代码不变) ...
def _convert_messages_to_prompt(self, messages: List[BaseMessage]) -> List[Dict]:
"""将LangChain消息转换为火山引擎API所需的格式。"""
api_messages = []
for msg in messages:
if isinstance(msg, HumanMessage):
api_messages.append({"role": "user", "content": msg.content})
elif isinstance(msg, AIMessage):
if msg.tool_calls:
api_messages.append({
"role": "assistant",
"content": "",
"tool_calls": [{
"id": tc.id,
"type": "function",
"function": {
"name": tc.name,
"arguments": json.dumps(tc.args)
}
} for tc in msg.tool_calls]
})
else:
api_messages.append({"role": "assistant", "content": msg.content})
elif isinstance(msg, ToolMessage):
api_messages.append({
"role": "tool",
"content": msg.content,
"tool_call_id": msg.tool_call_id
})
return api_messages
```
4. **修改 `_generate` 方法**
这个方法需要调用底层 API并解析大模型返回的响应以检查是否包含工具调用。
```python
class HuoshanChatModel(BaseChatModel):
# ... (其他代码不变) ...
def _generate(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> ChatResult:
if not self._api:
raise ValueError("HuoshanAPI未正确初始化")
api_messages = self._convert_messages_to_prompt(messages)
tools = kwargs.get("tools", [])
response_data = self._api.get_chat_response(messages=api_messages, tools=tools)
try:
message_from_api = response_data.get("choices", [{}])[0].get("message", {})
tool_calls = message_from_api.get("tool_calls", [])
if tool_calls:
lc_tool_calls = []
for tc in tool_calls:
lc_tool_calls.append(ToolCall(
name=tc["function"]["name"],
args=json.loads(tc["function"]["arguments"]),
id=tc.get("id", "")
))
message = AIMessage(content="", tool_calls=lc_tool_calls)
else:
content = message_from_api.get("content", "")
message = AIMessage(content=content)
generation = ChatGeneration(message=message)
return ChatResult(generations=[generation])
except Exception as e:
raise ValueError(f"处理火山引擎API响应失败: {str(e)}")
```
### 第二步:修改 `huoshan.py`
这个文件是底层 API 客户端,负责与火山引擎 API 进行通信。你需要修改 `get_chat_response` 方法,使其能够发送 `tools` 参数。
```python
# huoshan.py
from volcenginesdkarkruntime.types.chat import ChatCompletionToolParam
class HuoshanAPI:
# ... (其他代码不变) ...
def get_chat_response(
self,
messages: List[Dict],
stream: bool = False,
tools: Optional[List[ChatCompletionToolParam]] = None
) -> Dict[str, Any]:
"""同步获取聊天响应,支持工具调用。"""
client = Ark()
try:
completion = client.chat.completions.create(
model=self.doubao_seed_1_6_model_id,
messages=messages,
stream=stream,
tools=tools # 传入 tools 参数
)
return completion.model_dump() # 使用 model_dump() 转换为字典
except Exception as e:
raise ValueError(f"火山引擎API调用失败: {str(e)}")
```
完成以上修改后,你的 `HuoshanChatModel` 就会支持工具调用,并能与 LangGraph 的智能体框架无缝集成。