2025-09-22 23:28:55 +08:00

161 lines
4.9 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import os
import asyncio
import json
from typing import Annotated
from typing_extensions import NotRequired, TypedDict
from langchain.chat_models import init_chat_model
from langgraph.prebuilt import create_react_agent
from langchain_core.messages import HumanMessage, human
from langgraph.graph.state import RunnableConfig
from typing_extensions import TypedDict
from langchain.tools import tool
from langgraph.graph import StateGraph, START, END
from langgraph.graph.message import add_messages
from tools.llm.huoshan_langchain import HuoshanChatModel
@tool
def QueryLocation():
    """
    查询当前所在位置
    Returns:
    Dict: 返回一个包含以下字段的字典:
    location (str): 当前所在位置。
    """
    # NOTE: @tool turns the docstring above into the tool description that is
    # shown to the model, so its wording is kept unchanged on purpose.
    # Stub implementation: always reports the same fixed city.
    current_city = "杭州"
    return dict(location=current_city)
class State(TypedDict):
    """State shared by every node of the routing graph."""

    # Conversation history. The add_messages reducer merges each node's
    # returned messages into this list instead of overwriting it.
    messages: Annotated[list, add_messages]
    # Optional agent name. NOTE(review): no node visible in this file writes
    # this key; routing is decided by node_router from the last message.
    agent: NotRequired[str]
# Builder for the top-level graph over State; its nodes and edges are
# registered later in this module and then compiled into `graph`.
graph_builder = StateGraph(State)
# Chat model shared by all agents: an OpenAI-compatible endpoint on Volcengine Ark.
# Credentials come from the environment rather than source control. The key that
# used to be hard-coded here was committed in plain text and should be rotated.
# NOTE(review): if ARK_API_KEY is unset, None is passed and the OpenAI client
# presumably falls back to its own OPENAI_API_KEY lookup; confirm.
# ARK_MODEL / ARK_BASE_URL optionally override the previous hard-coded defaults.
llm = init_chat_model(
    model=os.environ.get("ARK_MODEL", "openai:doubao-seed-1.6-250615"),
    temperature=0,
    base_url=os.environ.get(
        "ARK_BASE_URL", "https://ark.cn-beijing.volces.com/api/v3/"
    ),
    api_key=os.environ.get("ARK_API_KEY"),
)
# Alternative backend: llm = HuoshanChatModel()
# Router agent: classifies the user's request and replies with a JSON object
# naming the task node to run next; node_router parses that reply.
# The example in the output spec must itself be valid JSON. The previous
# example had a trailing comma and a "#" comment, so a model that copied it
# produced unparsable output and every request fell through to taskC.
routerAgent = create_react_agent(
    llm,
    tools=[],
    prompt="""
你的任务是根据用户的输入,判断使用哪个智能体来继续执行任务
下面是智能体名称与任务描述:
taskA用于回答位置相关的问题
taskB用于给用户讲一个笑话
taskC用于回复用户的其他问题
# 输出规范
请严格按照下列JSON结构返回数据不要有其他任何多余的信息和描述
其中agent字段的值只能是taskA、taskB、taskC之一:
{"agent": "taskA"}
""",
)
# Task agent for location questions; QueryLocation is its only tool.
taskAgentA = create_react_agent(
    llm,
    tools=[QueryLocation],
    prompt="""
你的任务是回答位置相关的问题
""",
)
# Task agent that tells the user a joke; it has no tools.
taskAgentB = create_react_agent(
    llm,
    tools=[],
    prompt="""
你的任务是给用户讲一个笑话
""",
)
# Fallback task agent that politely answers everything else; it has no tools.
taskAgentC = create_react_agent(
    llm,
    tools=[],
    prompt="""
你的任务是礼貌的回复用户
""",
)
def node_router(state: "State") -> str:
    """Choose the next task node from the router agent's reply.

    The router agent is prompted to answer with JSON such as
    ``{"agent": "taskA"}``. This function reads the content of the last
    message in ``state["messages"]``, parses it, and returns the ``agent``
    value when it names a known task node.

    Models often wrap JSON in a Markdown code fence or surround it with
    extra text, so the outermost ``{...}`` span is extracted before parsing.

    Args:
        state: Graph state whose ``messages`` list is expected to end with
            the router agent's reply.

    Returns:
        ``"taskA"``, ``"taskB"`` or ``"taskC"``. The fallback ``"taskC"`` is
        returned when there is no message, the reply holds no parsable JSON
        object, or it names an unknown agent. Any other string would have no
        entry in the conditional edge's path map.
    """
    # The annotation is a string so State is resolved lazily; typing's
    # get_type_hints still resolves it against this module's globals.
    fallback = "taskC"
    known_agents = {"taskA", "taskB", "taskC"}

    messages = state.get("messages") or []
    if not messages:
        return fallback

    content = getattr(messages[-1], "content", "")
    # Message content may be a list of parts (plain strings or dicts with a
    # "text" entry) rather than a single string; join the textual parts.
    if isinstance(content, list):
        content = "".join(
            part if isinstance(part, str) else str(part.get("text", ""))
            for part in content
            if isinstance(part, (str, dict))
        )
    if not isinstance(content, str):
        return fallback

    start, end = content.find("{"), content.rfind("}")
    if start == -1 or end < start:
        return fallback
    try:
        data = json.loads(content[start : end + 1])
    except json.JSONDecodeError:
        return fallback

    # Guard against valid JSON that is not an object or whose "agent" value
    # is not a string (an unhashable list would also break the set lookup).
    agent = data.get("agent") if isinstance(data, dict) else None
    if isinstance(agent, str) and agent.strip() in known_agents:
        return agent.strip()
    return fallback
def taskAgentB_node(state: State):
    """Tell the user a joke via taskAgentB and record its reply.

    Only the human messages in the state are forwarded to the sub-agent, so
    the router's JSON reply (and any other AI message) is not shown to it.

    Args:
        state: Current graph state.

    Returns:
        A partial update ``{"messages": [reply]}``, where ``reply`` is the last
        message produced by taskAgentB. The ``add_messages`` reducer declared
        on ``State.messages`` appends it to the conversation.
    """
    human_msgs = [msg for msg in state["messages"] if msg.type == "human"]
    res = taskAgentB.invoke({"messages": human_msgs})
    # Return only the new message rather than appending to state["messages"]
    # and returning the whole state. The input state is not mutated in place,
    # and the reducer already performs the append.
    return {"messages": [res["messages"][-1]]}
def taskAgentA_node(state: State):
    """Answer a location question via taskAgentA and record its reply.

    Only the human messages in the state are forwarded to the sub-agent, so
    the router's JSON reply (and any other AI message) is not shown to it.

    Args:
        state: Current graph state.

    Returns:
        A partial update ``{"messages": [reply]}``, where ``reply`` is the last
        message produced by taskAgentA. The ``add_messages`` reducer declared
        on ``State.messages`` appends it to the conversation.
    """
    human_msgs = [msg for msg in state["messages"] if msg.type == "human"]
    res = taskAgentA.invoke({"messages": human_msgs})
    # Return only the new message rather than appending to state["messages"]
    # and returning the whole state. The input state is not mutated in place,
    # and the reducer already performs the append.
    return {"messages": [res["messages"][-1]]}
def taskAgentC_node(state: State):
    """Reply to any other request via taskAgentC and record its reply.

    Only the human messages in the state are forwarded to the sub-agent, so
    the router's JSON reply (and any other AI message) is not shown to it.

    Args:
        state: Current graph state.

    Returns:
        A partial update ``{"messages": [reply]}``, where ``reply`` is the last
        message produced by taskAgentC. The ``add_messages`` reducer declared
        on ``State.messages`` appends it to the conversation.
    """
    human_msgs = [msg for msg in state["messages"] if msg.type == "human"]
    res = taskAgentC.invoke({"messages": human_msgs})
    # Return only the new message rather than appending to state["messages"]
    # and returning the whole state. The input state is not mutated in place,
    # and the reducer already performs the append.
    return {"messages": [res["messages"][-1]]}
def _wire_graph():
    """Register all nodes and edges on graph_builder and compile it.

    Topology: START -> router -> one of taskA / taskB / taskC -> END. The
    conditional edge maps every name node_router can return onto the node
    registered under that same name.

    Returns:
        The compiled graph.
    """
    task_nodes = {
        "taskA": taskAgentA_node,
        "taskB": taskAgentB_node,
        "taskC": taskAgentC_node,
    }
    graph_builder.add_node("router", routerAgent)
    for node_name, node_fn in task_nodes.items():
        graph_builder.add_node(node_name, node_fn)
    graph_builder.add_edge(START, "router")
    graph_builder.add_conditional_edges(
        "router",
        node_router,
        {node_name: node_name for node_name in task_nodes},
    )
    for node_name in task_nodes:
        graph_builder.add_edge(node_name, END)
    return graph_builder.compile()


graph = _wire_graph()
# Multi-agent demo (多智能体节点测试): a router node picks one task agent;
# taskA additionally exercises tool use via function calling.
async def run():
    """Send one demo question through the graph and print streamed chunks.

    Other sample questions used during development:
        "我现在在哪啊?"   (where am I?  - intended for taskA)
        "给我讲个笑话吧"   (tell me a joke - intended for taskB)

    Non-streaming alternative:
        result = await graph.ainvoke(inputs, config=config)
        print(result)
    """
    # NOTE(review): the graph is compiled without a checkpointer, so this
    # thread_id presumably has no effect; confirm before relying on it.
    config: RunnableConfig = {"configurable": {"thread_id": "1"}}
    question = HumanMessage(content="现在是几点?")
    inputs: State = {"messages": [question]}

    # stream_mode="messages" yields (message_chunk, metadata) pairs.
    stream = graph.astream(input=inputs, config=config, stream_mode="messages")
    async for chunk, _metadata in stream:
        text = chunk.content  # type: ignore
        if not text:
            continue  # skip chunks that carry no content
        print(f"type:{chunk.type} \n msg:{text}", end="\n", flush=True)  # type: ignore


if __name__ == "__main__":
    # Run the streaming demo when this file is executed as a script.
    asyncio.run(run())