import os
import asyncio
from typing import Annotated

from langchain.chat_models import init_chat_model
from langgraph.prebuilt import create_react_agent
from langchain_core.messages import HumanMessage
from langchain_core.runnables import RunnableConfig
from typing_extensions import TypedDict

from langgraph.graph import StateGraph, START, END
from langgraph.graph.message import add_messages


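# Graph state: a running list of chat messages. The add_messages reducer appends
# new messages to the existing list instead of replacing it.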
class State(TypedDict):
    messages: Annotated[list, add_messages]


graph_builder = StateGraph(State)


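# The Doubao model is served through Volcengine Ark's OpenAI-compatible endpoint,
# so init_chat_model uses the "openai:" provider prefix together with a custom
# base_url pointing at Ark.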
llm = init_chat_model(
    model="openai:doubao-seed-1.6-250615",
    temperature=0,
    base_url="https://ark.cn-beijing.volces.com/api/v3/",
    # Read the key from the environment instead of hardcoding it in source.
    # The variable name ARK_API_KEY is an assumption; use whatever your setup exports.
    api_key=os.environ.get("ARK_API_KEY"),
)


chatAgent = create_react_agent(
    llm,
    tools=[],
    prompt="You are an agent. Your task is to answer the user's questions. "
    "If the user's input is not a question, you may ask whether they have any other questions.",
)


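# Wire the single agent node into a minimal START -> chatbot -> END graph.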
graph_builder.add_node("chatbot", chatAgent)
graph_builder.add_edge(START, "chatbot")
graph_builder.add_edge("chatbot", END)
graph = graph_builder.compile()


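# Note: the thread_id passed in the config below only persists conversation state
# across calls if the graph is compiled with a checkpointer. A minimal sketch,
# assuming in-memory checkpointing is enough for local testing:
#
#   from langgraph.checkpoint.memory import MemorySaver
#   graph = graph_builder.compile(checkpointer=MemorySaver())
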
# Single-agent node test
async def run():
    config: RunnableConfig = {"configurable": {"thread_id": "1"}}
    inputs: State = {"messages": [HumanMessage(content="Hello")]}

    # Non-streaming query
    # result = await graph.ainvoke(inputs, config=config)
    # print(result)

    # Streaming query: stream_mode="messages" yields (message_chunk, metadata)
    # pairs as the LLM produces tokens.
    async for message_chunk, metadata in graph.astream(
        input=inputs,
        config=config,
        stream_mode="messages",
    ):
        # Print each streamed chunk as it arrives.
        if message_chunk.content:  # type: ignore
            print(message_chunk.content, end="|", flush=True)  # type: ignore

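    # A possible alternative (not used here): stream_mode="updates" yields the state
    # update produced by each node instead of token-level message chunks, e.g.:
    #
    #   async for update in graph.astream(input=inputs, config=config, stream_mode="updates"):
    #       print(update)

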
if __name__ == "__main__":
    asyncio.run(run())