46 lines
1.2 KiB
Python
46 lines
1.2 KiB
Python
import os
|
|
import asyncio
|
|
from typing import Annotated
|
|
|
|
from langchain.chat_models import init_chat_model
|
|
from langchain_community.chat_models import ChatOpenAI
|
|
from langchain_core.messages import HumanMessage
|
|
from langgraph.graph.state import RunnableConfig
|
|
from typing_extensions import TypedDict
|
|
|
|
from langgraph.graph import StateGraph, START, END
|
|
from langgraph.graph.message import add_messages
|
|
|
|
|
|
class State(TypedDict):
    """Shared state passed between the nodes of the graph.

    ``messages`` carries the conversation. It is annotated with LangGraph's
    ``add_messages`` reducer, so a node's returned messages are merged into
    the existing list rather than replacing it.
    """

    messages: Annotated[list, add_messages]


# Builder whose channels are derived from the State schema defined above.
graph_builder = StateGraph(State)

# Chat model used by the `chatbot` node.
#
# SECURITY: the API key must come from the environment, never from source
# code. A key was previously hard-coded here; treat it as compromised and
# revoke/rotate it. The model id and endpoint may also be overridden via the
# environment; the defaults are the values this script originally used.
_ARK_API_KEY = os.environ.get("ARK_API_KEY")
if not _ARK_API_KEY:
    raise RuntimeError(
        "The ARK_API_KEY environment variable is not set; export the API key "
        "for the chat endpoint before running this script."
    )

llm = init_chat_model(
    model=os.environ.get("ARK_MODEL", "openai:doubao-seed-1.6-250615"),
    temperature=0,  # deterministic-leaning output for repeatable tests
    base_url=os.environ.get(
        "ARK_BASE_URL", "https://ark.cn-beijing.volces.com/api/v3/"
    ),
    api_key=_ARK_API_KEY,
)

def chatbot(state: State):
    """Graph node: pass the conversation so far to the LLM and return its reply.

    Args:
        state: current graph state; only ``state["messages"]`` is read.

    Returns:
        A partial state update ``{"messages": [reply]}``. The ``add_messages``
        reducer on ``State`` merges the single reply into the history.
    """
    history = state["messages"]
    reply = llm.invoke(history)
    return {"messages": [reply]}

# Topology: START -> chatbot -> END, i.e. exactly one LLM call per invocation.
graph_builder.add_node("chatbot", chatbot)
for _edge_src, _edge_dst in ((START, "chatbot"), ("chatbot", END)):
    graph_builder.add_edge(_edge_src, _edge_dst)

# Compiled without a checkpointer, so no state is persisted between calls.
graph = graph_builder.compile()

# Smoke test for the single chatbot-node graph.
async def run():
    """Send one greeting through the compiled graph and print the final state."""
    inputs: State = {"messages": [HumanMessage(content="你好")]}
    # NOTE(review): `graph` is compiled without a checkpointer, so this
    # thread_id presumably has no effect on the run — confirm before relying
    # on it for conversation memory.
    config: RunnableConfig = {"configurable": {"thread_id": "1"}}
    final_state = await graph.ainvoke(inputs, config=config)
    print(final_state)


if __name__ == "__main__":
    # Script entry point: drive the async smoke test on a fresh event loop.
    asyncio.run(run())