agent-writer/tools/llm/huoshan_langchain.py
2025-09-11 18:34:03 +08:00

137 lines
4.8 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

from typing import Any, Dict, Iterator, List, Optional

from langchain_core.callbacks.manager import CallbackManagerForLLMRun
from langchain_core.language_models.chat_models import BaseChatModel
from langchain_core.messages import (
    AIMessage,
    AIMessageChunk,
    BaseMessage,
    HumanMessage,
    SystemMessage,
)
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
from pydantic import Field

from api.huoshan import HuoshanAPI  # the project's existing Volcengine API client


class HuoshanChatModel(BaseChatModel):
    """LangChain chat-model wrapper around the Volcengine (Huoshan) ``HuoshanAPI`` client.

    LangChain messages are flattened into the ``prompt`` / ``system`` strings taken by
    ``HuoshanAPI.get_chat_response`` and ``HuoshanAPI.get_chat_response_stream``; the
    text those calls return is wrapped in LangChain output objects. Stop sequences
    supplied by LangChain are applied client-side by truncating the returned text.
    """

    # Model configuration
    model_name: str = Field(default="doubao-seed-1.6-250615", description="模型名称")
    temperature: float = Field(default=0.6, description="温度参数")
    # NOTE(review): max_tokens is reported by _identifying_params but is never passed
    # to HuoshanAPI below; confirm whether the client accepts it and forward it.
    max_tokens: int = Field(default=16384, description="最大token数")

    # Underlying API client; assigned in __init__.
    _api: Optional[HuoshanAPI] = None

    def __init__(self, **kwargs: Any) -> None:
        """Initialise the configured fields via BaseChatModel, then create the API client."""
        super().__init__(**kwargs)
        self._api = HuoshanAPI()

    @property
    def _llm_type(self) -> str:
        """Return the model-type identifier reported to LangChain."""
        return "huoshan_chat"

    @staticmethod
    def _content_to_text(content: Any) -> str:
        """Flatten a message ``content`` value into plain text.

        Strings are returned unchanged. For a list of content blocks, string blocks
        and ``{"type": "text", "text": ...}`` dict blocks are concatenated; other
        blocks (e.g. images) are skipped because the API call takes plain text only.
        ``None`` becomes ``""``; any other value is converted with ``str``.
        """
        if content is None:
            return ""
        if isinstance(content, str):
            return content
        if isinstance(content, list):
            parts: List[str] = []
            for block in content:
                if isinstance(block, str):
                    parts.append(block)
                elif isinstance(block, dict) and block.get("type") == "text":
                    parts.append(str(block.get("text", "")))
            return "".join(parts)
        return str(content)

    def _convert_messages_to_prompt(self, messages: List[BaseMessage]) -> tuple[str, str]:
        """Convert LangChain messages into the ``(prompt, system)`` pair HuoshanAPI takes.

        Args:
            messages: The conversation to send.

        Returns:
            ``prompt``: the text of every HumanMessage, joined with newlines.
            ``system``: the non-empty text of every SystemMessage, joined with newlines.
            Either is ``""`` when there are no messages of that kind.
        """
        system_parts: List[str] = []
        user_parts: List[str] = []
        for message in messages:
            if isinstance(message, SystemMessage):
                text = self._content_to_text(message.content)
                if text:
                    system_parts.append(text)
            elif isinstance(message, HumanMessage):
                user_parts.append(self._content_to_text(message.content))
            elif isinstance(message, AIMessage):
                # NOTE(review): earlier assistant turns are dropped because the API call
                # only receives one prompt string, so multi-turn history is lost here.
                pass
        return "\n".join(user_parts), "\n".join(system_parts)

    def _build_generation_kwargs(self, overrides: Dict[str, Any], system: str) -> Dict[str, Any]:
        """Merge per-call ``model`` / ``temperature`` overrides with this model's defaults."""
        return {
            "model": overrides.get("model", self.model_name),
            "temperature": overrides.get("temperature", self.temperature),
            "system": system,
        }

    @staticmethod
    def _find_stop(text: Any, stop: Optional[List[str]]) -> Optional[int]:
        """Return the index of the earliest stop-sequence occurrence in ``text``, or None.

        Empty stop strings are ignored (they would otherwise match at index 0), and a
        non-string ``text`` is never truncated.
        """
        if not stop or not isinstance(text, str):
            return None
        found = [pos for pos in (text.find(seq) for seq in stop if seq) if pos != -1]
        return min(found) if found else None

    @staticmethod
    def _emit_chunk(
        text: str, run_manager: Optional[CallbackManagerForLLMRun]
    ) -> ChatGenerationChunk:
        """Wrap ``text`` in a streaming chunk and notify ``run_manager`` of the new token."""
        chunk = ChatGenerationChunk(message=AIMessageChunk(content=text))
        if run_manager:
            run_manager.on_llm_new_token(text, chunk=chunk)
        return chunk

    def _generate(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> ChatResult:
        """Generate a single chat reply.

        Args:
            messages: The conversation to send.
            stop: Optional stop sequences; the reply is cut at the first occurrence.
            run_manager: LangChain callback manager (unused for non-streaming calls).
            **kwargs: Optional ``model`` / ``temperature`` overrides.

        Returns:
            A ChatResult holding one AIMessage generation.

        Raises:
            ValueError: if the API client is missing or the API call fails
                (the original exception is chained as ``__cause__``).
        """
        if not self._api:
            raise ValueError("HuoshanAPI未正确初始化")

        prompt, system = self._convert_messages_to_prompt(messages)
        generation_kwargs = self._build_generation_kwargs(kwargs, system)

        try:
            response_text = self._api.get_chat_response(prompt=prompt, **generation_kwargs)
            cut = self._find_stop(response_text, stop)
            if cut is not None:
                response_text = response_text[:cut]
            message = AIMessage(content=response_text)
        except Exception as e:
            raise ValueError(f"调用火山引擎API失败: {str(e)}") from e

        return ChatResult(generations=[ChatGeneration(message=message)])

    def _stream(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> Iterator[ChatGenerationChunk]:
        """Stream a chat reply as ChatGenerationChunk objects.

        LangChain concatenates the yielded chunks with ``+``, so each one must be a
        ChatGenerationChunk wrapping an AIMessageChunk.

        Args:
            messages: The conversation to send.
            stop: Optional stop sequences; the stream ends before the first occurrence.
            run_manager: If given, notified of each emitted piece of text.
            **kwargs: Optional ``model`` / ``temperature`` overrides.

        Yields:
            One chunk per piece of text that is safe to emit.

        Raises:
            ValueError: if the API client is missing or the streaming call fails
                (the original exception is chained as ``__cause__``).
        """
        if not self._api:
            raise ValueError("HuoshanAPI未正确初始化")

        prompt, system = self._convert_messages_to_prompt(messages)
        generation_kwargs = self._build_generation_kwargs(kwargs, system)

        # A stop sequence may be split across two API chunks, so when stops are given
        # the last (longest_stop - 1) characters are held back until more text arrives.
        # Without stop sequences the hold-back is 0 and chunks are forwarded unchanged.
        holdback = max((len(seq) for seq in (stop or []) if seq), default=1) - 1
        pending = ""
        try:
            for piece in self._api.get_chat_response_stream(
                prompt=prompt,
                **generation_kwargs,
            ):
                if not piece:
                    continue
                pending += piece
                cut = self._find_stop(pending, stop)
                if cut is not None:
                    # Emit the text before the stop sequence, then end the stream.
                    if cut > 0:
                        yield self._emit_chunk(pending[:cut], run_manager)
                    return
                ready = len(pending) - holdback
                if ready > 0:
                    yield self._emit_chunk(pending[:ready], run_manager)
                    pending = pending[ready:]
            # Flush the held-back tail; it was already checked for stop sequences.
            if pending:
                yield self._emit_chunk(pending, run_manager)
        except Exception as e:
            raise ValueError(f"调用火山引擎流式API失败: {str(e)}") from e

    @property
    def _identifying_params(self) -> Dict[str, Any]:
        """Return the parameters that identify this model configuration."""
        return {
            "model_name": self.model_name,
            "temperature": self.temperature,
            "max_tokens": self.max_tokens,
        }