from typing import Any, Dict, Iterator, List, Optional, Sequence

from langchain_core.callbacks.manager import CallbackManagerForLLMRun
from langchain_core.language_models.chat_models import BaseChatModel
from langchain_core.messages import (
    AIMessage,
    AIMessageChunk,
    BaseMessage,
    HumanMessage,
    SystemMessage,
    ToolCall,
    ToolMessage,
)
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
from langchain_core.runnables import Runnable
from langchain_core.tools import BaseTool
from pydantic import Field
import json

from volcenginesdkarkruntime.types.chat import ChatCompletionToolParam
from volcenginesdkarkruntime.types.shared_params.function_definition import FunctionDefinition
from volcenginesdkarkruntime.types.shared_params.function_parameters import FunctionParameters

from api.huoshan import HuoshanAPI  # import your existing API class


def _convert_dict_type(data: Any) -> Any:
    """Recursively replace the string "Dict" with "dict" inside a schema."""
    if isinstance(data, dict):
        new_data = {}
        for k, v in data.items():
            if isinstance(v, str) and v == "Dict":
                new_data[k] = "dict"
            else:
                new_data[k] = _convert_dict_type(v)
        return new_data
    elif isinstance(data, list):
        return [_convert_dict_type(item) for item in data]
    else:
        return data


class HuoshanChatModel(BaseChatModel):
    """LangChain wrapper around the Volcengine (Huoshan) Ark chat model."""

    # Model configuration
    model_name: str = Field(default="doubao-seed-1.6-250615", description="model name")
    temperature: float = Field(default=0.6, description="sampling temperature")
    max_tokens: int = Field(default=16384, description="maximum number of tokens")

    # Internal API client instance
    _api: Optional[HuoshanAPI] = None

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        # Initialize the Volcengine API client
        self._api = HuoshanAPI()

    @property
    def _llm_type(self) -> str:
        """Return the LLM type identifier."""
        return "huoshan_chat"

    def bind_tools(self, tools: Sequence[BaseTool], **kwargs: Any) -> Runnable:
        """Bind tools to the model, converting them to the format the Volcengine API expects."""
        tool_definitions = []
        for tool_item in tools:
            # 1. Obtain the tool's parameter JSON Schema.
            if hasattr(tool_item.args_schema, "model_json_schema"):
                # Pydantic model: ask it for its JSON Schema directly.
                parameters_schema = tool_item.args_schema.model_json_schema()  # type: ignore
            else:
                # Plain dict schema: normalize any stray "Dict" type strings.
                parameters_schema = _convert_dict_type(tool_item.args_schema or {})

            # 2. Extract and normalize the schema fields: the SDK requires
            #    'type' to be 'object' and 'properties'/'required' to be present.
            schema_type = parameters_schema.get("type", "object")
            schema_properties = parameters_schema.get("properties", {})
            schema_required = parameters_schema.get("required", [])

            # 3. Build a FunctionParameters dict from the normalized fields.
            function_parameters: FunctionParameters = {
                "type": schema_type,
                "properties": schema_properties,
                "required": schema_required,
            }

            # 4. Wrap the parameters in a FunctionDefinition.
            function_def: FunctionDefinition = {
                "name": tool_item.name,
                "description": tool_item.description,
                "parameters": function_parameters,
            }

            # 5. Wrap the FunctionDefinition in a ChatCompletionToolParam.
            tool_definitions.append(
                ChatCompletionToolParam(type="function", function=function_def)
            )

        # bind() returns a new Runnable with the tool definitions attached;
        # they come back into _generate/_stream via kwargs["tools"].
        return super().bind(tools=tool_definitions, **kwargs)
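    # For illustration, a tool whose args_schema declares a single string field
    # "city" would, under the conversion above, serialize to (a sketch; the
    # "get_weather"/"city" names are hypothetical, not part of this module):
    #
    #   {"type": "function",
    #    "function": {"name": "get_weather",
    #                 "description": "Look up the weather for a city.",
    #                 "parameters": {"type": "object",
    #                                "properties": {"city": {"type": "string"}},
    #                                "required": ["city"]}}}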
    def _convert_messages_to_prompt(self, messages: List[BaseMessage]) -> List[Dict]:
        """Convert LangChain messages into the format the Volcengine API expects."""
        # 1. Deduplicate ToolMessages: walk the history backwards and keep only
        #    the most recent message for each tool_call_id.
        deduplicated_messages = []
        seen_tool_call_ids = set()
        for msg in reversed(messages):
            if isinstance(msg, ToolMessage):
                if msg.tool_call_id not in seen_tool_call_ids:
                    seen_tool_call_ids.add(msg.tool_call_id)
                    deduplicated_messages.append(msg)
            else:
                deduplicated_messages.append(msg)
        # Restore the original chronological order.
        messages_to_convert = list(reversed(deduplicated_messages))

        # 2. Convert the deduplicated messages into API dicts.
        api_messages: List[Dict] = []
        for msg in messages_to_convert:
            if isinstance(msg, HumanMessage):
                api_messages.append({"role": "user", "content": msg.content})
            elif isinstance(msg, SystemMessage):
                api_messages.append({"role": "system", "content": msg.content})
            elif isinstance(msg, AIMessage):
                # Assistant turns must be echoed back, including any tool calls,
                # so that later "tool" messages have a matching tool_call_id.
                api_msg: Dict[str, Any] = {"role": "assistant", "content": msg.content}
                if msg.tool_calls:
                    api_msg["tool_calls"] = [
                        {
                            "id": tc["id"],
                            "type": "function",
                            "function": {
                                "name": tc["name"],
                                "arguments": json.dumps(tc["args"], ensure_ascii=False),
                            },
                        }
                        for tc in msg.tool_calls
                    ]
                api_messages.append(api_msg)
            elif isinstance(msg, ToolMessage):
                api_messages.append({
                    "role": "tool",
                    "content": msg.content,
                    "tool_call_id": msg.tool_call_id,
                })
        return api_messages

    def _generate(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> ChatResult:
        if not self._api:
            raise ValueError("HuoshanAPI is not initialized")

        api_messages = self._convert_messages_to_prompt(messages)
        tools = kwargs.get("tools", [])
        print(f"Submitting to Doubao: {len(messages)} messages\ntools: {tools}")
        response_data = self._api.get_chat_response(messages=api_messages, tools=tools)
        try:
            res_choices = response_data.get("choices", [{}])[0]
            finish_reason = res_choices.get("finish_reason", "")
            message_from_api = res_choices.get("message", {})
            tool_calls = message_from_api.get("tool_calls", [])
            print(f"Doubao finish_reason: {finish_reason}\ntool_calls: {tool_calls}")
            print(f"Doubao message: {message_from_api.get('content', '')}")

            if finish_reason == "tool_calls" and tool_calls:
                # Convert the API tool calls into LangChain ToolCall objects.
                lc_tool_calls = []
                for tc in tool_calls:
                    function_dict = tc.get("function", {})
                    lc_tool_calls.append(ToolCall(
                        name=function_dict.get("name", ""),
                        args=json.loads(function_dict.get("arguments", "{}")),
                        id=tc.get("id", ""),
                    ))
                message = AIMessage(content="", tool_calls=lc_tool_calls)
            else:
                message = AIMessage(content=message_from_api.get("content", ""))

            generation = ChatGeneration(message=message)
            return ChatResult(generations=[generation])
        except Exception as e:
            import traceback
            traceback.print_exc()
            raise ValueError(f"Failed to process the Volcengine API response: {e}") from e
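    # For reference, _generate assumes the OpenAI-compatible response shape
    # below (a sketch; all field values are hypothetical):
    #
    #   {"choices": [{"finish_reason": "tool_calls",
    #                 "message": {"content": "",
    #                             "tool_calls": [{"id": "call_0",
    #                                             "type": "function",
    #                                             "function": {"name": "get_weather",
    #                                                          "arguments": "{\"city\": \"Beijing\"}"}}]}}]}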
    def _stream(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> Iterator[ChatGenerationChunk]:
        """Stream the chat completion chunk by chunk."""
        if not self._api:
            raise ValueError("HuoshanAPI is not initialized")

        # Convert the message history to the API format.
        api_messages = self._convert_messages_to_prompt(messages)
        # Pop 'tools' so it is not passed twice through **kwargs below.
        tools = kwargs.pop("tools", [])
        try:
            # Call the streaming API.
            for chunk in self._api.get_chat_response_stream(
                messages=api_messages, tools=tools, **kwargs
            ):
                if chunk:
                    # Each chunk is an incremental piece of the assistant reply.
                    generation = ChatGenerationChunk(message=AIMessageChunk(content=chunk))
                    # Notify the callback manager of the new token, if present.
                    if run_manager:
                        run_manager.on_llm_new_token(chunk, chunk=generation)
                    yield generation
        except Exception as e:
            raise ValueError(f"Volcengine streaming API call failed: {e}") from e

    @property
    def _identifying_params(self) -> Dict[str, Any]:
        """Return the parameters that identify this model."""
        return {
            "model_name": self.model_name,
            "temperature": self.temperature,
            "max_tokens": self.max_tokens,
        }
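
if __name__ == "__main__":
    # Minimal usage sketch, assuming HuoshanAPI reads its credentials from the
    # environment. The @tool below is hypothetical and exists only to exercise
    # bind_tools; it is not part of this module's API.
    from langchain_core.tools import tool

    @tool
    def get_weather(city: str) -> str:
        """Look up the weather for a city (stub)."""
        return f"Sunny in {city}"

    llm = HuoshanChatModel(temperature=0.3)
    llm_with_tools = llm.bind_tools([get_weather])
    result = llm_with_tools.invoke("What's the weather in Beijing?")
    # If the model decided to call the tool, result.tool_calls is populated;
    # otherwise result.content holds a plain text reply.
    print(getattr(result, "tool_calls", None) or result.content)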