agent-writer/tools/llm/huoshan_langchain.py
2025-09-12 00:32:55 +08:00

273 lines
11 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

from typing import Any, Dict, Iterable, Iterator, List, Optional, Sequence
from langchain_core.callbacks.manager import CallbackManagerForLLMRun
from langchain_core.language_models.chat_models import BaseChatModel
from langchain_core.messages import BaseMessage, AIMessage, HumanMessage, SystemMessage, ToolMessage, ToolCall
from langchain_core.outputs import ChatGeneration, ChatResult
from langchain_core.tools import BaseTool
from langchain_core.runnables import Runnable
from pydantic import Field
import json
import copy
from volcenginesdkarkruntime.types.chat import ChatCompletionToolParam
from volcenginesdkarkruntime.types.shared_params.function_definition import FunctionDefinition
from volcenginesdkarkruntime.types.shared_params.function_parameters import FunctionParameters
from api.huoshan import HuoshanAPI # 导入你现有的API类
def _convert_dict_type(data: Any) -> Any:
"""递归地将字典中的"Dict"字符串类型转换为"dict"."""
if isinstance(data, dict):
new_data = {}
for k, v in data.items():
if isinstance(v, str) and v == "Dict":
new_data[k] = "dict"
else:
new_data[k] = _convert_dict_type(v)
return new_data
elif isinstance(data, list):
return [_convert_dict_type(item) for item in data]
else:
return data
class HuoshanChatModel(BaseChatModel):
    """LangChain ``BaseChatModel`` wrapper around the Volcengine (Huoshan) chat API.

    Network calls are delegated to the project-local :class:`HuoshanAPI`
    client; this class only converts between LangChain message/tool objects
    and the plain-dict format the Ark SDK expects.
    """

    # Model configuration (pydantic fields; descriptions kept verbatim).
    model_name: str = Field(default="doubao-seed-1.6-250615", description="模型名称")
    temperature: float = Field(default=0.6, description="温度参数")
    max_tokens: int = Field(default=16384, description="最大token数")
    # Internal API client (pydantic private attribute, populated in __init__).
    _api: Optional[HuoshanAPI] = None

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        # Instantiate the underlying Volcengine API client.
        self._api = HuoshanAPI()

    @property
    def _llm_type(self) -> str:
        """Identifier LangChain uses for this model type."""
        return "huoshan_chat"

    def bind_tools(self, tools: Sequence[BaseTool], **kwargs: Any) -> Runnable:
        """Bind *tools* to the model in the format the Volcengine SDK requires.

        Each LangChain tool is converted into a ``ChatCompletionToolParam``
        (``{"type": "function", "function": {...}}``) and attached via
        ``Runnable.bind`` so ``_generate``/``_stream`` receive it as
        ``tools=`` in their kwargs.
        """
        tool_definitions = []
        for tool_item in tools:
            # 1. Obtain the tool's parameter JSON Schema.
            if hasattr(tool_item.args_schema, 'model_json_schema'):
                # Pydantic model: ask it for its JSON Schema directly.
                parameters_schema = tool_item.args_schema.model_json_schema()  # type: ignore
            else:
                # Plain dict schema (or None): normalize literal "Dict" type
                # strings to "dict".
                # FIX: the original referenced `parameters_schema` before
                # assignment here, raising NameError for non-Pydantic schemas.
                parameters_schema = _convert_dict_type(tool_item.args_schema or {})
            # 2. Extract/normalize the fields the SDK treats as mandatory:
            #    'type' must be 'object'; 'properties' and 'required' must exist.
            function_parameters: FunctionParameters = {
                "type": parameters_schema.get("type", "object"),
                "properties": parameters_schema.get("properties", {}),
                "required": parameters_schema.get("required", []),
            }
            # 3. Wrap into FunctionDefinition, then ChatCompletionToolParam.
            function_def: FunctionDefinition = {
                "name": tool_item.name,
                "description": tool_item.description,
                "parameters": function_parameters,
            }
            tool_definitions.append(
                ChatCompletionToolParam(
                    type="function",
                    function=function_def,
                )
            )
        # Return a new Runnable with the converted tools bound as kwargs.
        return super().bind(tools=tool_definitions, **kwargs)

    def _convert_messages_to_prompt(self, messages: List[BaseMessage]) -> List[Dict]:
        """Convert LangChain messages to the plain-dict list the API expects.

        ToolMessages that share a ``tool_call_id`` are deduplicated, keeping
        only the most recent one: iterate from the end, keep the first
        occurrence of each id, then restore chronological order.

        NOTE(review): AIMessage instances are dropped entirely here, so prior
        assistant turns (including tool_call requests) are never replayed to
        the API — confirm this is intentional.
        """
        # Walk backwards so the newest ToolMessage per tool_call_id wins.
        deduplicated_messages = []
        seen_tool_call_ids = set()
        for msg in reversed(messages):
            if isinstance(msg, ToolMessage):
                if msg.tool_call_id not in seen_tool_call_ids:
                    seen_tool_call_ids.add(msg.tool_call_id)
                    deduplicated_messages.append(msg)
            else:
                deduplicated_messages.append(msg)
        # Restore original chronological order.
        messages_to_convert = list(reversed(deduplicated_messages))

        # Translate each message into the role/content dict the API expects.
        api_messages = []
        for msg in messages_to_convert:
            if isinstance(msg, HumanMessage):
                api_messages.append({"role": "user", "content": msg.content})
            elif isinstance(msg, SystemMessage):
                api_messages.append({"role": "system", "content": msg.content})
            elif isinstance(msg, ToolMessage):
                api_messages.append({
                    "role": "tool",
                    "content": msg.content,
                    "tool_call_id": msg.tool_call_id,
                })
        return api_messages

    def _generate(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> ChatResult:
        """Run one synchronous completion and wrap it in a ``ChatResult``.

        When the API finishes with ``tool_calls``, the tool calls are surfaced
        as ``AIMessage.tool_calls``; otherwise the text content is returned.

        Raises:
            ValueError: if the API client is missing or the response cannot
                be parsed.
        """
        if not self._api:
            raise ValueError("HuoshanAPI未正确初始化")
        api_messages = self._convert_messages_to_prompt(messages)
        tools = kwargs.get("tools", [])
        print(f" 提交给豆包的 messages数组长度: \n {len(messages)} \n tools: {tools}")
        response_data = self._api.get_chat_response(messages=api_messages, tools=tools)
        try:
            res_choices = response_data.get("choices", [{}])[0]
            finish_reason = res_choices.get("finish_reason", "")
            message_from_api = res_choices.get("message", {})
            tool_calls = message_from_api.get("tool_calls", [])
            print(f" 豆包返回的 finish_reason: {finish_reason} \n tool_calls: {tool_calls} \n")
            print(f" 豆包返回的 message: {message_from_api.get('content', '')}")
            if finish_reason == "tool_calls" and tool_calls:
                # Translate SDK tool-call dicts into LangChain ToolCall objects
                # (arguments arrive as a JSON string and must be decoded).
                lc_tool_calls = []
                for tc in tool_calls:
                    function_dict = tc.get("function", {})
                    lc_tool_calls.append(ToolCall(
                        name=function_dict.get("name", ""),
                        args=json.loads(function_dict.get("arguments", "{}")),
                        id=tc.get("id", ""),
                    ))
                message = AIMessage(content="", tool_calls=lc_tool_calls)
            else:
                message = AIMessage(content=message_from_api.get("content", ""))
            return ChatResult(generations=[ChatGeneration(message=message)])
        except Exception as e:
            import traceback
            traceback.print_exc()
            # Chain the original exception for easier debugging.
            raise ValueError(f"处理火山引擎API响应失败: {str(e)}") from e

    def _stream(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> Iterator[ChatGeneration]:
        """Stream chat completions as incremental ``ChatGeneration`` objects.

        Raises:
            ValueError: if the API client is missing or the streaming call fails.
        """
        if not self._api:
            raise ValueError("HuoshanAPI未正确初始化")
        # Convert the message history to the API format.
        api_messages = self._convert_messages_to_prompt(messages)
        # FIX: pop 'tools' so it is not passed twice (once explicitly and
        # again via **kwargs), which raised TypeError when tools were bound.
        tools = kwargs.pop("tools", [])
        try:
            for chunk in self._api.get_chat_response_stream(
                messages=api_messages,
                tools=tools,
                **kwargs,
            ):
                if chunk:
                    # Wrap each incremental text chunk in a generation.
                    message = AIMessage(content=chunk)
                    generation = ChatGeneration(message=message)
                    # Notify the callback manager of the new token, if any.
                    if run_manager:
                        run_manager.on_llm_new_token(chunk)
                    yield generation
        except Exception as e:
            raise ValueError(f"调用火山引擎流式API失败: {str(e)}") from e

    @property
    def _identifying_params(self) -> Dict[str, Any]:
        """Parameters that uniquely identify this model configuration."""
        return {
            "model_name": self.model_name,
            "temperature": self.temperature,
            "max_tokens": self.max_tokens,
        }