85 lines
3.3 KiB
Python
85 lines
3.3 KiB
Python
from typing import Any, List, Optional
|
||
from langchain_core.callbacks.manager import CallbackManagerForLLMRun
|
||
from langchain_core.messages import BaseMessage
|
||
from langchain_core.outputs import ChatResult
|
||
from langchain_community.chat_models import ChatOpenAI
|
||
import os
|
||
|
||
# 继承自 ChatOpenAI 而不是 BaseChatModel
|
||
class DeepseekChatModel(ChatOpenAI):
|
||
"""
|
||
对 DeepSeek 聊天模型的 LangChain 封装。
|
||
这个版本通过继承 ChatOpenAI 来实现,并重写 _generate 方法。
|
||
"""
|
||
|
||
def __init__(self, model_name: str = "deepseek-chat", **kwargs: Any):
|
||
"""
|
||
初始化 DeepseekChatModel
|
||
"""
|
||
# 从环境变量或参数中获取 api_key
|
||
api_key = kwargs.pop("api_key", os.getenv("DEEPSEEK_API_KEY"))
|
||
if not api_key:
|
||
raise ValueError(
|
||
"DeepSeek API key must be provided either as an argument or set as the DEEPSEEK_API_KEY environment variable."
|
||
)
|
||
|
||
# 调用父类(ChatOpenAI)的构造函数,并传入 DeepSeek 的特定配置
|
||
super().__init__(
|
||
model=model_name,
|
||
api_key=api_key,
|
||
base_url="https://api.deepseek.com/v1", # 这是关键
|
||
**kwargs,
|
||
)
|
||
|
||
def _generate(
|
||
self,
|
||
messages: List[BaseMessage],
|
||
stop: Optional[List[str]] = None,
|
||
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
||
**kwargs: Any,
|
||
) -> ChatResult:
|
||
"""
|
||
重写 _generate 方法,这是 LangChain 的核心调用点。
|
||
"""
|
||
# 1. 在调用父类方法前,执行你的自定义逻辑(例如,打印日志)
|
||
print("\n--- [DeepseekChatModel] 调用 _generate ---")
|
||
print(f"输入消息数量: {len(messages)}")
|
||
print(f"第一条消息内容: {messages[0].content}")
|
||
print("-------------------------------------------\n")
|
||
|
||
# 2. 使用 super() 调用父类(ChatOpenAI)的原始 _generate 方法
|
||
# 这样可以复用 ChatOpenAI 中所有复杂的、经过测试的逻辑
|
||
chat_result = super()._generate(
|
||
messages=messages, stop=stop, run_manager=run_manager, **kwargs
|
||
)
|
||
|
||
# 3. 在获得结果后,你还可以执行其他自定义逻辑
|
||
print("\n--- [DeepseekChatModel] _generate 调用完成 ---")
|
||
print(f"输出消息内容: {chat_result.generations[0].message.content[:80]}...") # 打印部分输出
|
||
print("--------------------------------------------\n")
|
||
|
||
return chat_result
|
||
|
||
@property
|
||
def _llm_type(self) -> str:
|
||
"""返回 language model 的类型。"""
|
||
return "deepseek-chat-model-v2" # 改个名以示区别
|
||
|
||
# --- 使用示例 ---
|
||
if __name__ == "__main__":
|
||
# 确保设置了 API 密钥
|
||
if not os.getenv("DEEPSEEK_API_KEY"):
|
||
print("请设置 DEEPSEEK_API_KEY 环境变量。")
|
||
else:
|
||
# 初始化模型
|
||
chat_model = DeepseekChatModel(temperature=0.7)
|
||
|
||
# 构建输入消息
|
||
from langchain_core.messages import HumanMessage
|
||
messages = [HumanMessage(content="你好,请介绍一下新加坡。")]
|
||
|
||
# 调用模型
|
||
# 当你调用 .invoke() 或 .stream() 时,LangChain 内部会调用我们重写的 _generate 方法
|
||
response = chat_model.invoke(messages)
|
||
|
||
print(f"\n最终得到的回复:\n{response.content}") |