agent-writer/tools/llm/openai_langchain.py

193 lines
7.3 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import os
from typing import Any, List, Optional, Type
from langchain_community.chat_models import ChatOpenAI
from langchain_core.messages import HumanMessage
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.pydantic_v1 import BaseModel, Field
from langchain_core.tools import BaseTool, tool
# --- 1. 中央模型配置 ---
# 在这里添加或修改你想要支持的模型提供商
# 'api_key_env': 存放 API Key 的环境变量名称
# 'base_url': 模型的 API 端点
# 'supports_tools': 该模型是否支持 LangChain 的 .bind_tools() 功能
# 'default_model': 如果不指定模型,默认使用的模型名称
MODEL_CONFIG = {
"openai": {
"api_key_env": "OPENAI_API_KEY",
"base_url": "https://api.openai.com/v1",
"supports_tools": True,
"default_model": "gpt-4o",
},
"deepseek": {
"api_key_env": "DEEPSEEK_API_KEY",
"base_url": "https://api.deepseek.com/v1",
"supports_tools": True,
"default_model": "deepseek-chat",
},
"groq": {
"api_key_env": "GROQ_API_KEY",
"base_url": "https://api.groq.com/openai/v1",
"supports_tools": True,
"default_model": "llama3-70b-8192",
},
"moonshot": {
"api_key_env": "MOONSHOT_API_KEY",
"base_url": "https://api.moonshot.cn/v1",
"supports_tools": True,
"default_model": "moonshot-v1-8k",
},
# 示例:一个不支持工具的老模型
"legacy_provider": {
"api_key_env": "LEGACY_API_KEY",
"base_url": "http://localhost:8080/v1", # 假设是本地服务
"supports_tools": False,
"default_model": "legacy-model-v1",
},
# 新增火山引擎的配置
"huoshan": {
"api_key_env": "VOLC_API_KEY", # 请确认您存放密钥的环境变量名
"base_url": "https://ark.cn-beijing.volces.com/api/v3", # 这是火山引擎方舟的 OpenAI 兼容端点
"supports_tools": True,
"default_model": "ep-20240615082149-j225c", # 使用您需要的模型 Endpoint ID
},
}
# --- 2. 模型创建工厂函数 ---
def create_llm_client(
provider: str,
model_name: Optional[str] = None,
tools: Optional[List[Type[BaseModel]]] = None,
**kwargs: Any,
) -> ChatOpenAI:
"""
根据提供的 provider 创建并配置一个 LangChain 聊天模型客户端。
Args:
provider: 模型提供商的名称 (必须是 MODEL_CONFIG 中的一个 key)。
model_name: 要使用的具体模型名称。如果为 None, 则使用配置中的默认模型。
tools: 一个工具列表,用于绑定到模型上以实现 function calling。
**kwargs: 其他要传递给 ChatOpenAI 的参数 (例如 temperature, max_tokens)。
Returns:
一个配置好的 ChatOpenAI 实例,可能已经绑定了工具。
"""
if provider not in MODEL_CONFIG:
raise ValueError(f"不支持的模型提供商: {provider}。可用选项: {list(MODEL_CONFIG.keys())}")
config = MODEL_CONFIG[provider]
api_key = os.getenv(config["api_key_env"])
if not api_key:
raise ValueError(f"请设置环境变量 {config['api_key_env']} 以使用 {provider} 模型。")
# 如果未指定 model_name则使用配置中的默认值
final_model_name = model_name or config["default_model"]
# 创建基础的 LLM 客户端
llm = ChatOpenAI(
model=final_model_name,
api_key=api_key,
base_url=config["base_url"],
**kwargs,
)
# 根据配置,有条件地绑定工具
if tools:
if config["supports_tools"]:
print(f"[{provider}] 模型支持工具,正在绑定 {len(tools)} 个工具...")
return llm.bind_tools(tools)
else:
print(f"⚠️ 警告: 您为 [{provider}] 提供了工具,但该模型在配置中被标记为不支持工具。将返回未绑定工具的模型。")
return llm
return llm
# --- 3. 定义你的工具 (Function Calling) ---
# 使用 Pydantic 模型定义工具的输入参数,确保类型安全和清晰的描述
class GetWeatherInput(BaseModel):
city: str = Field(description="需要查询天气的城市名称, 例如: Singapore")
# 使用 @tool 装饰器可以轻松地将任何函数转换为 LangChain 工具
@tool(args_schema=GetWeatherInput)
def get_current_weather(city: str) -> str:
"""
当需要查询指定城市的当前天气时,调用此工具。
"""
# 这是一个模拟实现实际应用中你会在这里调用真实的天气API
print(f"--- 正在调用工具: get_current_weather(city='{city}') ---")
if "singapore" in city.lower():
return f"新加坡今天的天气是晴朗,温度为 31°C。"
elif "beijing" in city.lower():
return f"北京今天的天气是多云,温度为 25°C。"
else:
return f"抱歉,我无法查询到 {city} 的天气信息。"
# --- 4. 主程序:演示如何使用 ---
if __name__ == "__main__":
# 将你的工具放入一个列表
my_tools = [get_current_weather]
# ---- 示例 1: 使用 DeepSeek 并调用工具 ----
print("\n================ 示例 1: 使用 DeepSeek 调用工具 ================")
# 确保你已经设置了环境变量: export DEEPSEEK_API_KEY="sk-..."
try:
deepseek_llm_with_tools = create_llm_client(
provider="deepseek",
tools=my_tools,
temperature=0
)
prompt = "今天新加坡的天气怎么样?"
print(f"用户问题: {prompt}")
# LangChain 会自动处理LLM -> Tool Call -> Execute Tool -> LLM -> Final Answer
# 为了演示,我们只看第一步的输出
ai_msg = deepseek_llm_with_tools.invoke(prompt)
print("\nLLM 返回的初步响应 (AIMessage):")
print(ai_msg)
# 检查返回的是否是工具调用
if ai_msg.tool_calls:
print(f"\n模型请求调用工具: {ai_msg.tool_calls[0]['name']}")
# 在实际应用中,你会在这里执行工具并把结果返回给模型
else:
print("\n模型直接给出了回答:")
print(ai_msg.content)
except ValueError as e:
print(e)
# ---- 示例 2: 使用 Groq (Llama3) 且不使用工具 ----
print("\n================ 示例 2: 使用 Groq 进行常规聊天 ================")
# 确保你已经设置了环境变量: export GROQ_API_KEY="gsk_..."
try:
groq_llm = create_llm_client(
provider="groq",
temperature=0.7,
# model_name="llama3-8b-8192" # 你也可以覆盖默认模型
)
prompt = "请给我写一首关于新加坡的五言绝句。"
print(f"用户问题: {prompt}")
response = groq_llm.invoke(prompt)
print("\nGroq (Llama3) 的回答:")
print(response.content)
except ValueError as e:
print(e)
# ---- 示例 3: 尝试给不支持工具的模型绑定工具 ----
print("\n========== 示例 3: 尝试为不支持工具的模型绑定工具 ==========")
# 假设你设置了 export LEGACY_API_KEY="some_key"
try:
legacy_llm = create_llm_client(
provider="legacy_provider",
tools=my_tools
)
# 注意,这里会打印出警告信息,并且 legacy_llm 不会绑定任何工具
except ValueError as e:
print(e)