from typing import Any, Dict, Iterator, List, Optional, Sequence

import json

from langchain_core.callbacks.manager import CallbackManagerForLLMRun
from langchain_core.language_models.chat_models import BaseChatModel
from langchain_core.messages import (
    AIMessage,
    AIMessageChunk,
    BaseMessage,
    HumanMessage,
    SystemMessage,
    ToolCall,
    ToolMessage,
)
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
from langchain_core.runnables import Runnable
from langchain_core.tools import BaseTool
from pydantic import Field

from volcenginesdkarkruntime.types.chat import ChatCompletionToolParam
from volcenginesdkarkruntime.types.shared_params.function_definition import FunctionDefinition
from volcenginesdkarkruntime.types.shared_params.function_parameters import FunctionParameters

from api.huoshan import HuoshanAPI  # your existing API wrapper

def _convert_dict_type(data: Any) -> Any:
    """Recursively replace the string value "Dict" with "dict" inside a schema."""
    if isinstance(data, dict):
        new_data = {}
        for k, v in data.items():
            if isinstance(v, str) and v == "Dict":
                new_data[k] = "dict"
            else:
                new_data[k] = _convert_dict_type(v)
        return new_data
    elif isinstance(data, list):
        return [_convert_dict_type(item) for item in data]
    else:
        return data
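
# Behavior sketch (hypothetical schema): the replacement applies to dict
# values; a bare "Dict" string inside a list passes through the final
# else branch unchanged.
#   _convert_dict_type({"type": "Dict", "nested": {"inner": "Dict"}})
#   -> {"type": "dict", "nested": {"inner": "dict"}}
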

class HuoshanChatModel(BaseChatModel):
    """LangChain wrapper around the Volcengine (Huoshan) chat model."""

    # Model configuration parameters
    model_name: str = Field(default="doubao-seed-1.6-250615", description="Model name")
    temperature: float = Field(default=0.6, description="Sampling temperature")
    max_tokens: int = Field(default=16384, description="Maximum number of tokens")

    # Internal API instance
    _api: Optional[HuoshanAPI] = None

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        # Initialize the Volcengine API client
        self._api = HuoshanAPI()

    @property
    def _llm_type(self) -> str:
        """Return the LLM type identifier."""
        return "huoshan_chat"

    def bind_tools(self, tools: Sequence[BaseTool], **kwargs: Any) -> Runnable:
        """Bind tools to the model, converting them to the format the Volcengine API expects."""
        tool_definitions = []
        for tool_item in tools:
            # 1. Obtain the tool's parameter JSON Schema
            if hasattr(tool_item.args_schema, "model_json_schema"):
                # Pydantic model: ask it for its JSON Schema
                parameters_schema = tool_item.args_schema.model_json_schema()  # type: ignore
            else:
                # args_schema is already a plain dict schema (or None);
                # normalize stray "Dict" type strings while we are at it
                parameters_schema = _convert_dict_type(tool_item.args_schema or {})

            # 2. Extract and normalize the fields. The SDK requires 'type' to be
            # 'object' and expects 'properties' and 'required' to be present.
            schema_type = parameters_schema.get("type", "object")
            schema_properties = parameters_schema.get("properties", {})
            schema_required = parameters_schema.get("required", [])

            # 3. Build the FunctionParameters dict from the normalized fields
            function_parameters: FunctionParameters = {
                "type": schema_type,
                "properties": schema_properties,
                "required": schema_required,
            }

            # 4. Wrap the parameters in a FunctionDefinition
            function_def: FunctionDefinition = {
                "name": tool_item.name,
                "description": tool_item.description,
                "parameters": function_parameters,
            }

            # 5. Wrap the FunctionDefinition in a ChatCompletionToolParam
            tool_definitions.append(
                ChatCompletionToolParam(
                    type="function",
                    function=function_def,
                )
            )

        # Return a new Runnable with the converted tools bound as default kwargs
        return super().bind(tools=tool_definitions, **kwargs)
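
    # Shape of one converted entry (sketch; "get_weather" is a hypothetical tool):
    #   {"type": "function",
    #    "function": {"name": "get_weather",
    #                 "description": "Look up the weather for a city",
    #                 "parameters": {"type": "object",
    #                                "properties": {"city": {"type": "string"}},
    #                                "required": ["city"]}}}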

    def _convert_messages_to_prompt(self, messages: List[BaseMessage]) -> List[Dict]:
        """Convert LangChain messages into the format the Volcengine API expects."""
        # 1. Deduplicate ToolMessages: walk backwards so that for each
        # tool_call_id only the most recent message is kept.
        deduplicated_messages = []
        seen_tool_call_ids = set()
        for msg in reversed(messages):
            if isinstance(msg, ToolMessage):
                if msg.tool_call_id not in seen_tool_call_ids:
                    seen_tool_call_ids.add(msg.tool_call_id)
                    deduplicated_messages.append(msg)
            else:
                # Non-tool messages are kept as-is
                deduplicated_messages.append(msg)

        # Restore the original chronological order
        messages_to_convert = list(reversed(deduplicated_messages))

        # 2. Convert the deduplicated messages to the API format
        api_messages = []
        for msg in messages_to_convert:
            if isinstance(msg, HumanMessage):
                api_messages.append({"role": "user", "content": msg.content})
            elif isinstance(msg, SystemMessage):
                api_messages.append({"role": "system", "content": msg.content})
            elif isinstance(msg, AIMessage):
                # Assistant turns are echoed back, including tool calls, so
                # that subsequent "tool" messages have a matching request
                entry: Dict[str, Any] = {"role": "assistant", "content": msg.content}
                if msg.tool_calls:
                    entry["tool_calls"] = [
                        {
                            "id": tc["id"],
                            "type": "function",
                            "function": {
                                "name": tc["name"],
                                "arguments": json.dumps(tc["args"], ensure_ascii=False),
                            },
                        }
                        for tc in msg.tool_calls
                    ]
                api_messages.append(entry)
            elif isinstance(msg, ToolMessage):
                api_messages.append({
                    "role": "tool",
                    "content": msg.content,
                    "tool_call_id": msg.tool_call_id,
                })
        return api_messages
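
    # Dedup sketch (hypothetical ids): given
    #   [HumanMessage, ToolMessage(id="a", "old"), ToolMessage(id="a", "new")]
    # only ToolMessage(id="a", "new") survives, because the reversed walk
    # sees it first; chronological order is then restored.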

    def _generate(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> ChatResult:
        if not self._api:
            raise ValueError("HuoshanAPI is not initialized")

        api_messages = self._convert_messages_to_prompt(messages)
        tools = kwargs.get("tools", [])

        # Debug output: dump what is submitted to Doubao
        print("===== Messages submitted to Doubao =====")
        for message in messages:
            print(f"{message.type}:")
            print(f"{message.content}\n")
        print("tools:")
        for tool in tools:
            print(f"{tool}\n")
        print("===== End of Doubao submission dump =====")

        response_data = self._api.get_chat_response(messages=api_messages, tools=tools)

        try:
            res_choices = response_data.get("choices", [{}])[0]
            finish_reason = res_choices.get("finish_reason", "")
            message_from_api = res_choices.get("message", {})
            tool_calls = message_from_api.get("tool_calls", [])
            print(f"Doubao finish_reason: {finish_reason}\ntool_calls: {tool_calls}")
            print(f"Doubao message content: {message_from_api.get('content', '')}")
            if finish_reason == "tool_calls" and tool_calls:
                # Translate API tool calls into LangChain ToolCall objects
                lc_tool_calls = []
                for tc in tool_calls:
                    function_dict = tc.get("function", {})
                    lc_tool_calls.append(ToolCall(
                        name=function_dict.get("name", ""),
                        args=json.loads(function_dict.get("arguments", "{}")),
                        id=tc.get("id", ""),
                    ))
                message = AIMessage(content="", tool_calls=lc_tool_calls)
            else:
                content = message_from_api.get("content", "")
                message = AIMessage(content=content)

            generation = ChatGeneration(message=message)
            return ChatResult(generations=[generation])

        except Exception as e:
            import traceback
            traceback.print_exc()
            raise ValueError(f"Failed to process the Volcengine API response: {e}")
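
    # _generate expects a response shaped roughly like this sketch (field
    # names match the .get() calls above; the values are hypothetical):
    #   {"choices": [{"finish_reason": "tool_calls",
    #                 "message": {"content": "",
    #                             "tool_calls": [{"id": "call_0",
    #                                             "function": {"name": "multiply",
    #                                                          "arguments": "{\"a\": 6, \"b\": 7}"}}]}}]}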

    def _stream(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> Iterator[ChatGenerationChunk]:
        """Stream the chat completion incrementally."""
        if not self._api:
            raise ValueError("HuoshanAPI is not initialized")

        # Convert the message format; pop "tools" so it is not passed twice
        # through **kwargs below
        api_messages = self._convert_messages_to_prompt(messages)
        tools = kwargs.pop("tools", [])

        try:
            # Call the streaming API
            for chunk in self._api.get_chat_response_stream(
                messages=api_messages,
                tools=tools,
                **kwargs,
            ):
                if chunk:
                    # Build an incremental chunk; _stream must yield
                    # ChatGenerationChunk objects wrapping an AIMessageChunk
                    generation = ChatGenerationChunk(message=AIMessageChunk(content=chunk))

                    # Notify the callback manager of the new token
                    if run_manager:
                        run_manager.on_llm_new_token(chunk, chunk=generation)

                    yield generation

        except Exception as e:
            raise ValueError(f"Volcengine streaming API call failed: {e}")

    @property
    def _identifying_params(self) -> Dict[str, Any]:
        """Return the parameters that identify this model."""
        return {
            "model_name": self.model_name,
            "temperature": self.temperature,
            "max_tokens": self.max_tokens,
        }
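

if __name__ == "__main__":
    # Minimal usage sketch. Assumes HuoshanAPI reads its credentials from the
    # environment; the `multiply` tool below is purely illustrative.
    from langchain_core.tools import tool

    @tool
    def multiply(a: int, b: int) -> int:
        """Multiply two integers."""
        return a * b

    model = HuoshanChatModel()
    model_with_tools = model.bind_tools([multiply])
    result = model_with_tools.invoke([HumanMessage(content="What is 6 times 7?")])
    print(result)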