from typing import Any, Dict, Iterator, List, Optional

from langchain_core.callbacks.manager import CallbackManagerForLLMRun
from langchain_core.language_models.chat_models import BaseChatModel
from langchain_core.messages import (
    AIMessage,
    AIMessageChunk,
    BaseMessage,
    HumanMessage,
    SystemMessage,
)
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
from pydantic import Field, PrivateAttr

from api.huoshan import HuoshanAPI  # import your existing API class


class HuoshanChatModel(BaseChatModel):
    """LangChain wrapper around the Volcano Engine (Huoshan) chat model."""

    # Model configuration parameters
    model_name: str = Field(default="doubao-seed-1.6-250615", description="Model name")
    temperature: float = Field(default=0.6, description="Sampling temperature")
    max_tokens: int = Field(default=16384, description="Maximum number of tokens")

    # Internal API instance, declared as a pydantic private attribute so that
    # assigning it in __init__ is allowed on the pydantic model
    _api: Optional[HuoshanAPI] = PrivateAttr(default=None)

    def __init__(self, **kwargs: Any) -> None:
        super().__init__(**kwargs)
        # Initialize the Volcano Engine API instance
        self._api = HuoshanAPI()

    @property
    def _llm_type(self) -> str:
        """Return the LLM type identifier."""
        return "huoshan_chat"

    def _convert_messages_to_prompt(self, messages: List[BaseMessage]) -> tuple[str, str]:
        """Convert LangChain messages into the prompt/system strings the API expects."""
        system_message = ""
        user_messages = []

        for message in messages:
            if isinstance(message, SystemMessage):
                system_message = message.content or ""
            elif isinstance(message, HumanMessage):
                user_messages.append(message.content)
            elif isinstance(message, AIMessage):
                # Handle assistant turns here if multi-turn history is needed
                pass

        # Join the user messages into a single prompt
        prompt = "\n".join(user_messages) if user_messages else ""

        return prompt, str(system_message)

    def _generate(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> ChatResult:
        """Generate a chat completion (stop sequences are currently ignored)."""
        if not self._api:
            raise ValueError("HuoshanAPI is not initialized")

        # Convert the message list into prompt/system strings
        prompt, system = self._convert_messages_to_prompt(messages)

        # Merge per-call overrides with the configured defaults; forward
        # max_tokens here as well if HuoshanAPI supports it
        generation_kwargs = {
            "model": kwargs.get("model", self.model_name),
            "temperature": kwargs.get("temperature", self.temperature),
            "system": system,
        }

        try:
            # Call the underlying API
            response_text = self._api.get_chat_response(
                prompt=prompt,
                **generation_kwargs,
            )

            # Wrap the response text in an AI message
            message = AIMessage(content=response_text)

            # Build the generation result
            generation = ChatGeneration(message=message)

            return ChatResult(generations=[generation])

        except Exception as e:
            raise ValueError(f"Volcano Engine API call failed: {e}") from e

    def _stream(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> Iterator[ChatGenerationChunk]:
        """Stream a chat completion."""
        if not self._api:
            raise ValueError("HuoshanAPI is not initialized")

        # Convert the message list into prompt/system strings
        prompt, system = self._convert_messages_to_prompt(messages)

        # Merge per-call overrides with the configured defaults
        generation_kwargs = {
            "model": kwargs.get("model", self.model_name),
            "temperature": kwargs.get("temperature", self.temperature),
            "system": system,
        }

        try:
            # Call the streaming API
            for chunk in self._api.get_chat_response_stream(
                prompt=prompt,
                **generation_kwargs,
            ):
                if chunk:
                    # Streaming must yield ChatGenerationChunk wrapping an
                    # AIMessageChunk, not the non-chunk types used in _generate
                    generation = ChatGenerationChunk(
                        message=AIMessageChunk(content=chunk)
                    )

                    # Notify the callback manager of the new token
                    if run_manager:
                        run_manager.on_llm_new_token(chunk)

                    yield generation

        except Exception as e:
            raise ValueError(f"Volcano Engine streaming API call failed: {e}") from e

    @property
    def _identifying_params(self) -> Dict[str, Any]:
        """Return the parameters that identify this model."""
        return {
            "model_name": self.model_name,
            "temperature": self.temperature,
            "max_tokens": self.max_tokens,
        }
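

if __name__ == "__main__":
    # Minimal usage sketch, assuming HuoshanAPI picks up its credentials on
    # its own (e.g. from environment variables) and that the default model
    # name above is available to your account; illustrative only, not part
    # of the wrapper itself.
    chat = HuoshanChatModel(temperature=0.3)

    # Blocking call: invoke() routes through _generate
    result = chat.invoke(
        [
            SystemMessage(content="You are a helpful assistant."),
            HumanMessage(content="Briefly introduce the Volcano Engine."),
        ]
    )
    print(result.content)

    # Streaming call: stream() routes through _stream and yields message chunks
    for part in chat.stream([HumanMessage(content="Count from 1 to 5.")]):
        print(part.content, end="", flush=True)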