161 lines
4.9 KiB
Python
161 lines
4.9 KiB
Python
import asyncio
import json
import os
from typing import Annotated

from typing_extensions import NotRequired, TypedDict

from langchain.chat_models import init_chat_model
from langchain.tools import tool
from langchain_core.messages import HumanMessage, human
from langgraph.graph import StateGraph, START, END
from langgraph.graph.message import add_messages
from langgraph.graph.state import RunnableConfig
from langgraph.prebuilt import create_react_agent

from tools.llm.huoshan_langchain import HuoshanChatModel
|
||
@tool
def QueryLocation():
    # NOTE: the docstring below doubles as the tool description that the LLM
    # sees at runtime, so it is intentionally kept in Chinese to match the
    # Chinese prompts elsewhere in this file.
    """
    查询当前所在位置

    Returns:
        Dict: 返回一个包含以下字段的字典:
            location (str): 当前所在位置。
    """
    # Hard-coded stub location; plain literal (the original used a needless
    # f-string with no placeholders).
    return {"location": "杭州"}
||
class State(TypedDict):
    # Graph state shared by every node.
    # Conversation history; the add_messages reducer appends/merges incoming
    # message lists instead of overwriting the field.
    messages: Annotated[list, add_messages]
    # Name of the agent chosen by the router ("taskA"/"taskB"/"taskC").
    # NOTE(review): optional and never written by any node visible in this
    # file — confirm it is actually used, or remove it.
    agent: NotRequired[str]
||
graph_builder = StateGraph(State)

# SECURITY: the Ark API key was previously hardcoded here. It now comes from
# the ARK_API_KEY environment variable; the old key is in VCS history and
# must be rotated.
llm = init_chat_model(
    model="openai:doubao-seed-1.6-250615",
    temperature=0,
    base_url="https://ark.cn-beijing.volces.com/api/v3/",
    api_key=os.environ.get("ARK_API_KEY", ""),
)

# llm = HuoshanChatModel()
# Router agent: classifies the user's request and must reply with a JSON
# object naming the worker agent that should handle it.
# FIX: the original prompt's JSON example was itself invalid JSON (trailing
# comma plus an inline "#" comment inside the object), which teaches the
# model to emit output that json.loads in node_router cannot parse.
routerAgent = create_react_agent(
    llm,
    tools=[],
    prompt="""
你的任务是根据用户的输入,判断使用哪个智能体来继续执行任务
下面是智能体名称与任务描述:
taskA:用于回答位置相关的问题
taskB:用于给用户讲一个笑话
taskC:用于回复用户的其他问题

# 输出规范
请严格按照下列JSON结构返回数据,不要有其他任何多余的信息和描述:
{"agent": "taskA"}
其中 agent 为智能体名称(taskA/taskB/taskC)。
""",
)

# Worker agent for location questions; may call the QueryLocation tool.
taskAgentA = create_react_agent(
    llm,
    tools=[QueryLocation],
    prompt="""
你的任务是回答位置相关的问题
""",
)

# Worker agent that tells the user a joke.
taskAgentB = create_react_agent(
    llm,
    tools=[],
    prompt="""
你的任务是给用户讲一个笑话
""",
)

# Fallback worker agent for any other request.
taskAgentC = create_react_agent(
    llm,
    tools=[],
    prompt="""
你的任务是礼貌的回复用户
""",
)
||
def node_router(state: "State") -> str:
    """Pick the next graph node from the router agent's JSON reply.

    Reads the last message's content, which is expected to be a JSON object
    like {"agent": "taskA"}. Falls back to "taskC" whenever the reply is not
    valid JSON, is not a JSON object, or names an agent that is not wired
    into the graph (an unknown name would otherwise raise KeyError in the
    add_conditional_edges mapping).

    Returns:
        One of "taskA", "taskB", "taskC".
    """
    raw = state["messages"][-1].content
    text = raw.strip()
    # Models frequently wrap JSON in ```...``` / ```json ... ``` fences
    # despite instructions; strip them before parsing.
    if text.startswith("```"):
        text = text.strip("`")
        if text.startswith("json"):
            text = text[4:]
    try:
        data = json.loads(text)
    except json.JSONDecodeError:
        return "taskC"
    # json.loads can return a list/str/number; only an object carries "agent".
    if not isinstance(data, dict):
        return "taskC"
    next_agent = data.get("agent", "taskC")
    # Route only to nodes that actually exist in the graph.
    return next_agent if next_agent in ("taskA", "taskB", "taskC") else "taskC"
||
def _delegate(state: State, agent):
    """Forward the human messages in *state* to *agent*; append its reply.

    Shared body of the three task nodes. Only messages of type "human" are
    forwarded, so the router's JSON reply (an AI message) is filtered out of
    what the worker agent sees; the agent's final message is appended to the
    conversation.
    """
    human_msgs = [msg for msg in state["messages"] if msg.type == "human"]
    result = agent.invoke({"messages": human_msgs})
    state["messages"].append(result["messages"][-1])
    return state


def taskAgentB_node(state: State):
    """Graph node "taskB": tell the user a joke via taskAgentB."""
    return _delegate(state, taskAgentB)


def taskAgentA_node(state: State):
    """Graph node "taskA": answer location questions via taskAgentA."""
    return _delegate(state, taskAgentA)


def taskAgentC_node(state: State):
    """Graph node "taskC": politely answer anything else via taskAgentC."""
    return _delegate(state, taskAgentC)
|
||
# Register the router agent and the three worker nodes.
graph_builder.add_node("router", routerAgent)
graph_builder.add_node("taskA", taskAgentA_node)
graph_builder.add_node("taskB", taskAgentB_node)
graph_builder.add_node("taskC", taskAgentC_node)

# Every run starts at the router.
graph_builder.add_edge(START, "router")

# node_router inspects the router's JSON reply and returns one of the keys
# below; the mapping dispatches control to the matching worker node.
graph_builder.add_conditional_edges(
    "router",
    node_router,
    {
        "taskA": "taskA",
        "taskB": "taskB",
        "taskC": "taskC",
    }
)

# Each worker terminates the run.
graph_builder.add_edge("taskA", END)
graph_builder.add_edge("taskB", END)
graph_builder.add_edge("taskC", END)

graph = graph_builder.compile()
|
||
# Multi-agent node test: tool use / function calling
#
async def run():
    """Send one hard-coded query through the compiled graph and stream output.

    Uses a fixed thread_id ("1"); this only has an effect if the graph is
    compiled with a checkpointer — NOTE(review): none is configured here.
    """
    config:RunnableConfig = {"configurable": {"thread_id": "1"}}
    # Alternative test inputs (kept verbatim for quick switching):
    # inputs:State = {"messages": [HumanMessage(content="我现在在哪啊?")]}
    # inputs:State = {"messages": [HumanMessage(content="给我讲个笑话吧")]}
    inputs:State = {"messages": [HumanMessage(content="现在是几点?")]}

    # # Non-streaming invocation:
    # result = await graph.ainvoke(inputs, config=config)
    # print(result)

    # Streaming invocation: stream_mode="messages" yields
    # (message_chunk, metadata) pairs as the LLM produces tokens.
    async for message_chunk, metadata in graph.astream(
        input=inputs,
        config=config,
        stream_mode="messages"
    ):
        # Print only chunks that carry text content.
        if message_chunk.content: # type: ignore
            print(f"type:{message_chunk.type} \n msg:{message_chunk.content}", end="\n", flush=True) # type: ignore


if __name__ == "__main__":
    asyncio.run(run())