from typing import Any, Dict, Iterator, List, Optional

from langchain_core.callbacks.manager import CallbackManagerForLLMRun
from langchain_core.language_models.chat_models import BaseChatModel
from langchain_core.messages import (
    AIMessage,
    AIMessageChunk,
    BaseMessage,
    HumanMessage,
    SystemMessage,
)
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
from pydantic import Field

from api.huoshan import HuoshanAPI  # project-local Volcano Engine (Huoshan) API client


class HuoshanChatModel(BaseChatModel):
    """LangChain chat-model wrapper around the project's ``HuoshanAPI`` client.

    Incoming LangChain messages are flattened into a single ``prompt`` string
    plus a ``system`` string before each call (see
    ``_convert_messages_to_prompt``). Prior ``AIMessage`` turns are not
    forwarded, so multi-turn history is currently lost.
    """

    # Model configuration
    model_name: str = Field(default="doubao-seed-1.6-250615", description="模型名称")
    temperature: float = Field(default=0.6, description="温度参数")
    # NOTE(review): max_tokens is reported in _identifying_params but is never
    # passed to HuoshanAPI; confirm whether the client accepts it before wiring it in.
    max_tokens: int = Field(default=16384, description="最大token数")

    # Underlying API client. The leading underscore makes pydantic treat it as a
    # private attribute rather than a validated model field.
    _api: Optional[HuoshanAPI] = None

    def __init__(self, **kwargs: Any) -> None:
        super().__init__(**kwargs)
        # One API client per model instance.
        self._api = HuoshanAPI()

    @property
    def _llm_type(self) -> str:
        """Return the identifier LangChain uses for this model type."""
        return "huoshan_chat"

    @staticmethod
    def _content_to_text(content: Any) -> str:
        """Flatten LangChain message content into plain text.

        ``content`` may be a string or a list of parts, each part either a
        string or a dict such as ``{"type": "text", "text": ...}``. Text parts
        are concatenated; non-text parts (images, etc.) are skipped. Empty or
        ``None`` content yields ``""``.
        """
        if isinstance(content, str):
            return content
        if not content:
            return ""
        if isinstance(content, list):
            pieces: List[str] = []
            for part in content:
                if isinstance(part, str):
                    pieces.append(part)
                elif isinstance(part, dict) and part.get("type") == "text":
                    pieces.append(str(part.get("text", "")))
            return "".join(pieces)
        return str(content)

    def _convert_messages_to_prompt(self, messages: List[BaseMessage]) -> tuple[str, str]:
        """Convert LangChain messages into the ``(prompt, system)`` pair HuoshanAPI takes.

        Args:
            messages: Conversation in LangChain message form.

        Returns:
            A tuple ``(prompt, system)``:
            - ``prompt``: all HumanMessage texts joined with newlines.
            - ``system``: text of the last SystemMessage seen ("" if none).
            AIMessage and any other message types are ignored.
        """
        system_message = ""
        user_messages: List[str] = []

        for message in messages:
            if isinstance(message, SystemMessage):
                # Later system messages overwrite earlier ones.
                system_message = self._content_to_text(message.content)
            elif isinstance(message, HumanMessage):
                user_messages.append(self._content_to_text(message.content))
            # AIMessage / others: not forwarded; multi-turn support would need
            # a conversation-aware HuoshanAPI call.

        return "\n".join(user_messages), system_message

    def _build_generation_kwargs(self, system: str, kwargs: Dict[str, Any]) -> Dict[str, Any]:
        """Merge per-call overrides (``model``, ``temperature``) with instance defaults."""
        return {
            "model": kwargs.get("model", self.model_name),
            "temperature": kwargs.get("temperature", self.temperature),
            "system": system,
        }

    def _generate(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> ChatResult:
        """Produce one non-streaming chat reply.

        Args:
            messages: Conversation to send.
            stop: Accepted for LangChain interface compatibility; not forwarded
                to HuoshanAPI.
            run_manager: LangChain callback manager (unused for non-streaming).
            **kwargs: Optional ``model`` / ``temperature`` overrides.

        Returns:
            A ChatResult holding a single AIMessage with the reply text.

        Raises:
            ValueError: If the API client is missing or the API call fails
                (the original exception is chained as ``__cause__``).
        """
        if not self._api:
            raise ValueError("HuoshanAPI未正确初始化")

        prompt, system = self._convert_messages_to_prompt(messages)
        generation_kwargs = self._build_generation_kwargs(system, kwargs)

        try:
            response_text = self._api.get_chat_response(prompt=prompt, **generation_kwargs)
        except Exception as e:
            raise ValueError(f"调用火山引擎API失败: {str(e)}") from e

        message = AIMessage(content=response_text)
        return ChatResult(generations=[ChatGeneration(message=message)])

    def _stream(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> Iterator[ChatGenerationChunk]:
        """Stream the chat reply piece by piece.

        Yields ``ChatGenerationChunk`` objects wrapping ``AIMessageChunk``.
        LangChain's streaming machinery merges chunks with ``+``, which only
        the chunk types support. Empty pieces returned by the API are skipped.

        Args:
            messages: Conversation to send.
            stop: Accepted for LangChain interface compatibility; not forwarded.
            run_manager: If given, notified of each new token.
            **kwargs: Optional ``model`` / ``temperature`` overrides.

        Raises:
            ValueError: If the API client is missing or the streaming call
                fails (the original exception is chained as ``__cause__``).
        """
        if not self._api:
            raise ValueError("HuoshanAPI未正确初始化")

        prompt, system = self._convert_messages_to_prompt(messages)
        generation_kwargs = self._build_generation_kwargs(system, kwargs)

        try:
            for token in self._api.get_chat_response_stream(prompt=prompt, **generation_kwargs):
                if not token:
                    continue
                chunk = ChatGenerationChunk(message=AIMessageChunk(content=token))
                if run_manager:
                    run_manager.on_llm_new_token(token, chunk=chunk)
                yield chunk
        except Exception as e:
            raise ValueError(f"调用火山引擎流式API失败: {str(e)}") from e

    @property
    def _identifying_params(self) -> Dict[str, Any]:
        """Return the parameters that identify this model configuration."""
        return {
            "model_name": self.model_name,
            "temperature": self.temperature,
            "max_tokens": self.max_tokens,
        }