agent-writer/tools/llm/deepseek_langchain.py

85 lines
3.3 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

from typing import Any, List, Optional
from langchain_core.callbacks.manager import CallbackManagerForLLMRun
from langchain_core.messages import BaseMessage
from langchain_core.outputs import ChatResult
from langchain_community.chat_models import ChatOpenAI
import os
# 继承自 ChatOpenAI 而不是 BaseChatModel
class DeepseekChatModel(ChatOpenAI):
"""
对 DeepSeek 聊天模型的 LangChain 封装。
这个版本通过继承 ChatOpenAI 来实现,并重写 _generate 方法。
"""
def __init__(self, model_name: str = "deepseek-chat", **kwargs: Any):
"""
初始化 DeepseekChatModel
"""
# 从环境变量或参数中获取 api_key
api_key = kwargs.pop("api_key", os.getenv("DEEPSEEK_API_KEY"))
if not api_key:
raise ValueError(
"DeepSeek API key must be provided either as an argument or set as the DEEPSEEK_API_KEY environment variable."
)
# 调用父类ChatOpenAI的构造函数并传入 DeepSeek 的特定配置
super().__init__(
model=model_name,
api_key=api_key,
base_url="https://api.deepseek.com/v1", # 这是关键
**kwargs,
)
def _generate(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> ChatResult:
"""
重写 _generate 方法,这是 LangChain 的核心调用点。
"""
# 1. 在调用父类方法前,执行你的自定义逻辑(例如,打印日志)
print("\n--- [DeepseekChatModel] 调用 _generate ---")
print(f"输入消息数量: {len(messages)}")
print(f"第一条消息内容: {messages[0].content}")
print("-------------------------------------------\n")
# 2. 使用 super() 调用父类ChatOpenAI的原始 _generate 方法
# 这样可以复用 ChatOpenAI 中所有复杂的、经过测试的逻辑
chat_result = super()._generate(
messages=messages, stop=stop, run_manager=run_manager, **kwargs
)
# 3. 在获得结果后,你还可以执行其他自定义逻辑
print("\n--- [DeepseekChatModel] _generate 调用完成 ---")
print(f"输出消息内容: {chat_result.generations[0].message.content[:80]}...") # 打印部分输出
print("--------------------------------------------\n")
return chat_result
@property
def _llm_type(self) -> str:
"""返回 language model 的类型。"""
return "deepseek-chat-model-v2" # 改个名以示区别
# --- 使用示例 ---
if __name__ == "__main__":
# 确保设置了 API 密钥
if not os.getenv("DEEPSEEK_API_KEY"):
print("请设置 DEEPSEEK_API_KEY 环境变量。")
else:
# 初始化模型
chat_model = DeepseekChatModel(temperature=0.7)
# 构建输入消息
from langchain_core.messages import HumanMessage
messages = [HumanMessage(content="你好,请介绍一下新加坡。")]
# 调用模型
# 当你调用 .invoke() 或 .stream() 时LangChain 内部会调用我们重写的 _generate 方法
response = chat_model.invoke(messages)
print(f"\n最终得到的回复:\n{response.content}")