# 2025-09-22 23:28:55 +08:00
#
# 68 lines
# 2.1 KiB
# Python

import os
import asyncio
from typing import Annotated
from langchain.chat_models import init_chat_model
from langgraph.prebuilt import create_react_agent
from langchain_core.messages import HumanMessage
from langgraph.graph.state import RunnableConfig
from typing_extensions import TypedDict
from langchain.tools import tool
from langgraph.graph import StateGraph, START, END
from langgraph.graph.message import add_messages
from tools.llm.huoshan_langchain import HuoshanChatModel
@tool
def QueryLocation() -> dict[str, str]:
    """
    查询当前所在位置
    Returns:
    Dict: 返回一个包含以下字段的字典:
    location (str): 当前所在位置。
    """
    # NOTE: @tool forwards this function's docstring to the model as the tool
    # description, so the Chinese text above is runtime behavior and is kept
    # verbatim. English gloss of it:
    #   "Query the current location.
    #    Returns: a dict with one field:
    #    location (str): the current location."
    # Stub implementation: always reports Hangzhou ("杭州") as the location.
    # A plain string literal is used; the former f-string had no placeholders.
    return {"location": "杭州"}
# Graph state shared by the nodes of the StateGraph built below.
class State(TypedDict):
    # Conversation history. The `add_messages` reducer attached via Annotated
    # tells LangGraph to merge messages returned by a node into this list
    # instead of overwriting the whole list.
    messages: Annotated[list, add_messages]
graph_builder = StateGraph(State)

# Volcengine Ark serves an OpenAI-compatible API, so the model is created via
# the generic "openai:" provider pointed at Ark's base URL.
# The API key is read from the environment rather than hard-coded: a secret
# committed to source is exposed to everyone with read access. (The key that
# was previously embedded here should be considered leaked and revoked.)
_ark_api_key = os.environ.get("ARK_API_KEY")
if not _ark_api_key:
    raise RuntimeError(
        "ARK_API_KEY is not set; export your Volcengine Ark API key before running."
    )

llm = init_chat_model(
    model="openai:doubao-seed-1.6-250615",
    temperature=0,
    base_url="https://ark.cn-beijing.volces.com/api/v3/",
    api_key=_ark_api_key,
)
# Alternative backend, kept for manual switching:
# llm = HuoshanChatModel()

# ReAct agent that can call QueryLocation (function calling) before answering.
# Prompt (English gloss): "You are an agent; your task is to answer the user's
# question. If the user's input is not a question, you may ask whether they
# have other questions."
chatAgent = create_react_agent(
    llm,
    tools=[QueryLocation],
    prompt="你是一个智能体,你的任务是回答用户的问题。如果用户的输入不是问题,你可以询问用户是否有其他问题。",
)

# Single-node graph: START -> chatbot (the agent) -> END.
graph_builder.add_node("chatbot", chatAgent)
graph_builder.add_edge(START, "chatbot")
graph_builder.add_edge("chatbot", END)
# Compiled without a checkpointer.
graph = graph_builder.compile()
# Single-agent node test: tool use via function calling.
async def run():
    """Send one question through the compiled graph and print streamed output.

    Streams with ``stream_mode="messages"``, which yields
    ``(message_chunk, metadata)`` pairs. Every chunk whose content is non-empty
    is printed along with its message type. There is no filtering by type, so
    intermediate messages (e.g. tool output) may be printed as well, not only
    the final answer.
    """
    question = HumanMessage(content="我现在在哪啊?")  # "Where am I right now?"
    inputs: State = {"messages": [question]}
    # NOTE(review): the graph is compiled without a checkpointer, so this
    # thread_id presumably has no persistence effect — confirm it is intended.
    config: RunnableConfig = {"configurable": {"thread_id": "1"}}

    # Non-streaming alternative:
    #     result = await graph.ainvoke(inputs, config=config)
    #     print(result)
    stream = graph.astream(input=inputs, config=config, stream_mode="messages")
    async for chunk, _metadata in stream:
        text = chunk.content  # type: ignore
        if not text:
            continue
        print(f"type:{chunk.type} \n msg:{text}", flush=True)  # type: ignore


if __name__ == "__main__":
    asyncio.run(run())