From 6ca656791fecd7b249550b1fa02f3601a6d698f5 Mon Sep 17 00:00:00 2001 From: jonathang4 Date: Mon, 22 Sep 2025 23:28:55 +0800 Subject: [PATCH] base test --- api/huoshan.py | 7 ++ graph/test/t1.py | 46 ++++++++++ graph/test/t2.py | 56 ++++++++++++ graph/test/t3.py | 68 ++++++++++++++ graph/test/t4.py | 161 +++++++++++++++++++++++++++++++++ requirements.txt | 1 + tools/agent/queryDB.py | 18 ++-- tools/agent/updateDB.py | 10 +- tools/llm/huoshan_langchain.py | 1 - 9 files changed, 353 insertions(+), 15 deletions(-) create mode 100644 graph/test/t1.py create mode 100644 graph/test/t2.py create mode 100644 graph/test/t3.py create mode 100644 graph/test/t4.py diff --git a/api/huoshan.py b/api/huoshan.py index bb83a03..516f674 100644 --- a/api/huoshan.py +++ b/api/huoshan.py @@ -560,12 +560,19 @@ class HuoshanAPI: max_tokens=16384, # 16K timeout=600, stream=True, + thinking={ + "type": "disabled", # 不使用深度思考能力 + # "type": "enabled", # 使用深度思考能力 + # "type": "auto", # 模型自行判断是否使用深度思考能力 + }, tools=tools # 传入 tools 参数 ) for chunk in completion: if chunk.choices and chunk.choices[0].delta.content is not None: # type: ignore yield chunk.choices[0].delta.content # type: ignore except Exception as e: + import traceback + traceback.print_exc() raise ValueError(f"火山引擎API流式调用失败: {str(e)}") def analyze_image(self, image_url: str, prompt: str = "请描述这张图片的内容", model: Optional[str] = None, detail: str = "high") -> Dict[str, Any]: diff --git a/graph/test/t1.py b/graph/test/t1.py new file mode 100644 index 0000000..c09f130 --- /dev/null +++ b/graph/test/t1.py @@ -0,0 +1,46 @@ +import os +import asyncio +from typing import Annotated + +from langchain.chat_models import init_chat_model +from langchain_community.chat_models import ChatOpenAI +from langchain_core.messages import HumanMessage +from langgraph.graph.state import RunnableConfig +from typing_extensions import TypedDict + +from langgraph.graph import StateGraph, START, END +from langgraph.graph.message import add_messages + + +class State(TypedDict): + messages: Annotated[list, add_messages] + + +graph_builder = StateGraph(State) + +llm = init_chat_model( + model="openai:doubao-seed-1.6-250615", + temperature=0, + base_url="https://ark.cn-beijing.volces.com/api/v3/", + api_key="0d5189b3-9a03-4393-81be-8c1ba1e97cbb", +) + +def chatbot(state: State): + return {"messages": [llm.invoke(state["messages"])]} + + +graph_builder.add_node("chatbot", chatbot) +graph_builder.add_edge(START, "chatbot") +graph_builder.add_edge("chatbot", END) +graph = graph_builder.compile() + + +# 单对话节点测试 +async def run(): + config:RunnableConfig = {"configurable": {"thread_id": "1"}} + inputs:State = {"messages": [HumanMessage(content="你好")]} + result = await graph.ainvoke(inputs, config=config) + print(result) + +if __name__ == "__main__": + asyncio.run(run()) \ No newline at end of file diff --git a/graph/test/t2.py b/graph/test/t2.py new file mode 100644 index 0000000..7829e4f --- /dev/null +++ b/graph/test/t2.py @@ -0,0 +1,56 @@ +import os +import asyncio +from typing import Annotated + +from langchain.chat_models import init_chat_model +from langgraph.prebuilt import create_react_agent +from langchain_core.messages import HumanMessage +from langgraph.graph.state import RunnableConfig +from typing_extensions import TypedDict + +from langgraph.graph import StateGraph, START, END +from langgraph.graph.message import add_messages + + +class State(TypedDict): + messages: Annotated[list, add_messages] + + +graph_builder = StateGraph(State) + +llm = init_chat_model( + model="openai:doubao-seed-1.6-250615", + temperature=0, + base_url="https://ark.cn-beijing.volces.com/api/v3/", + api_key="0d5189b3-9a03-4393-81be-8c1ba1e97cbb", +) + +chatAgent = create_react_agent(llm, tools=[], prompt="你是一个智能体,你的任务是回答用户的问题。如果用户的输入不是问题,你可以询问用户是否有其他问题。") + +graph_builder.add_node("chatbot", chatAgent) +graph_builder.add_edge(START, "chatbot") +graph_builder.add_edge("chatbot", END) +graph = graph_builder.compile() + +# 单智能体节点测试 +async def run(): + config:RunnableConfig = {"configurable": {"thread_id": "1"}} + inputs:State = {"messages": [HumanMessage(content="你好")]} + + # # 非流式处理查询 + # result = await graph.ainvoke(inputs, config=config) + # print(result) + + # 流式处理查询 + async for message_chunk, metadata in graph.astream( + input=inputs, + config=config, + stream_mode="messages" + ): + # 输出最终结果 + if message_chunk.content: # type: ignore + print(message_chunk.content, end="|", flush=True) # type: ignore + + +if __name__ == "__main__": + asyncio.run(run()) \ No newline at end of file diff --git a/graph/test/t3.py b/graph/test/t3.py new file mode 100644 index 0000000..2905b5c --- /dev/null +++ b/graph/test/t3.py @@ -0,0 +1,68 @@ +import os +import asyncio +from typing import Annotated + +from langchain.chat_models import init_chat_model +from langgraph.prebuilt import create_react_agent +from langchain_core.messages import HumanMessage +from langgraph.graph.state import RunnableConfig +from typing_extensions import TypedDict +from langchain.tools import tool + +from langgraph.graph import StateGraph, START, END +from langgraph.graph.message import add_messages +from tools.llm.huoshan_langchain import HuoshanChatModel + +@tool +def QueryLocation(): + """ + 查询当前所在位置 + Returns: + Dict: 返回一个包含以下字段的字典: + location (str): 当前所在位置。 + """ + return {"location": f"杭州"} + +class State(TypedDict): + messages: Annotated[list, add_messages] + +graph_builder = StateGraph(State) + +llm = init_chat_model( + model="openai:doubao-seed-1.6-250615", + temperature=0, + base_url="https://ark.cn-beijing.volces.com/api/v3/", + api_key="0d5189b3-9a03-4393-81be-8c1ba1e97cbb", +) + +# llm = HuoshanChatModel() + +chatAgent = create_react_agent(llm, tools=[QueryLocation], prompt="你是一个智能体,你的任务是回答用户的问题。如果用户的输入不是问题,你可以询问用户是否有其他问题。") + +graph_builder.add_node("chatbot", chatAgent) +graph_builder.add_edge(START, "chatbot") +graph_builder.add_edge("chatbot", END) +graph = graph_builder.compile() + +# 单智能体节点测试 工具使用 function calling +async def run(): + config:RunnableConfig = {"configurable": {"thread_id": "1"}} + inputs:State = {"messages": [HumanMessage(content="我现在在哪啊?")]} + + # # 非流式处理查询 + # result = await graph.ainvoke(inputs, config=config) + # print(result) + + # 流式处理查询 + async for message_chunk, metadata in graph.astream( + input=inputs, + config=config, + stream_mode="messages" + ): + # 输出最终结果 + if message_chunk.content: # type: ignore + print(f"type:{message_chunk.type} \n msg:{message_chunk.content}", end="\n", flush=True) # type: ignore + + +if __name__ == "__main__": + asyncio.run(run()) \ No newline at end of file diff --git a/graph/test/t4.py b/graph/test/t4.py new file mode 100644 index 0000000..c2494b7 --- /dev/null +++ b/graph/test/t4.py @@ -0,0 +1,161 @@ +import os +import asyncio +import json + +from typing import Annotated +from typing_extensions import NotRequired, TypedDict + +from langchain.chat_models import init_chat_model +from langgraph.prebuilt import create_react_agent +from langchain_core.messages import HumanMessage, human +from langgraph.graph.state import RunnableConfig +from typing_extensions import TypedDict +from langchain.tools import tool + +from langgraph.graph import StateGraph, START, END +from langgraph.graph.message import add_messages +from tools.llm.huoshan_langchain import HuoshanChatModel + +@tool +def QueryLocation(): + """ + 查询当前所在位置 + Returns: + Dict: 返回一个包含以下字段的字典: + location (str): 当前所在位置。 + """ + return {"location": f"杭州"} + +class State(TypedDict): + messages: Annotated[list, add_messages] + agent: NotRequired[str] + +graph_builder = StateGraph(State) + +llm = init_chat_model( + model="openai:doubao-seed-1.6-250615", + temperature=0, + base_url="https://ark.cn-beijing.volces.com/api/v3/", + api_key="0d5189b3-9a03-4393-81be-8c1ba1e97cbb", +) + +# llm = HuoshanChatModel() + +routerAgent = create_react_agent(llm, tools=[], + prompt=f""" + 你的任务是根据用户的输入,判断使用哪个智能体来继续执行任务 + 下面是智能体名称与任务描述: + taskA:用于回答位置相关的问题 + taskB:用于给用户讲一个笑话 + taskC:用于回复用户的其他问题 + + # 输出规范 + 请严格按照下列JSON结构返回数据,不要有其他任何多余的信息和描述: + {{ + "agent": "taskA", # 智能体名称 + }} + """) +taskAgentA = create_react_agent(llm, tools=[QueryLocation], + prompt=f""" + 你的任务是回答位置相关的问题 + """) +taskAgentB = create_react_agent(llm, tools=[], + prompt=f""" + 你的任务是给用户讲一个笑话 + """) +taskAgentC = create_react_agent(llm, tools=[], + prompt=f""" + 你的任务是礼貌的回复用户 + """) +def node_router(state: State): + """ + 路由函数,根据输入的状态判断下一个节点 + """ + + last_msg = state["messages"][-1].content + # print(f"last_msg:{last_msg}") + try: + data = json.loads(last_msg) + nextAgent = data.get("agent", "taskC") + except json.JSONDecodeError: + nextAgent = "taskC" + return nextAgent + +def taskAgentB_node(state: State): + """ + taskAgentB节点,用于给用户讲一个笑话 + """ + human_msgs = [msg for msg in state["messages"] if msg.type == "human"] + # print(f"taskAgentB_node state:{state}") + res = taskAgentB.invoke({"messages": human_msgs}) + state["messages"].append(res["messages"][-1]) + return state + +def taskAgentA_node(state: State): + """ + taskAgentA节点,用于回答位置相关的问题 + """ + human_msgs = [msg for msg in state["messages"] if msg.type == "human"] + # print(f"taskAgentA_node state:{state}") + res = taskAgentA.invoke({"messages": human_msgs}) + state["messages"].append(res["messages"][-1]) + return state + +def taskAgentC_node(state: State): + """ + taskAgentC节点,用于回复用户的其他问题 + """ + human_msgs = [msg for msg in state["messages"] if msg.type == "human"] + # print(f"taskAgentC_node state:{state}") + res = taskAgentC.invoke({"messages": human_msgs}) + state["messages"].append(res["messages"][-1]) + return state + +graph_builder.add_node("router", routerAgent) +graph_builder.add_node("taskA", taskAgentA_node) +graph_builder.add_node("taskB", taskAgentB_node) +graph_builder.add_node("taskC", taskAgentC_node) + +graph_builder.add_edge(START, "router") + +graph_builder.add_conditional_edges( + "router", + node_router, + { + "taskA": "taskA", + "taskB": "taskB", + "taskC": "taskC", + } +) + +graph_builder.add_edge("taskA", END) +graph_builder.add_edge("taskB", END) +graph_builder.add_edge("taskC", END) + +graph = graph_builder.compile() + +# 多智能体节点测试 工具使用 function calling +# +async def run(): + config:RunnableConfig = {"configurable": {"thread_id": "1"}} + # inputs:State = {"messages": [HumanMessage(content="我现在在哪啊?")]} + # inputs:State = {"messages": [HumanMessage(content="给我讲个笑话吧")]} + inputs:State = {"messages": [HumanMessage(content="现在是几点?")]} + + # # 非流式处理查询 + # result = await graph.ainvoke(inputs, config=config) + # print(result) + + # 流式处理查询 + async for message_chunk, metadata in graph.astream( + input=inputs, + config=config, + stream_mode="messages" + ): + # 输出最终结果 + if message_chunk.content: # type: ignore + print(f"type:{message_chunk.type} \n msg:{message_chunk.content}", end="\n", flush=True) # type: ignore + + +if __name__ == "__main__": + asyncio.run(run()) \ No newline at end of file diff --git a/requirements.txt b/requirements.txt index aa01572..8897746 100644 --- a/requirements.txt +++ b/requirements.txt @@ -19,5 +19,6 @@ PyMySQL DBUtils PyMuPDF langchain-community +langchain-openai langgraph langgraph-checkpoint-mongodb \ No newline at end of file diff --git a/tools/agent/queryDB.py b/tools/agent/queryDB.py index 71c257b..668ca3c 100644 --- a/tools/agent/queryDB.py +++ b/tools/agent/queryDB.py @@ -7,7 +7,7 @@ from tools.database.mongo import mainDB from langchain.tools import tool import json -@tool +@tool(return_direct=True) def QueryOriginalScript(session_id: str): """ 查询原始剧本内容是否存在 @@ -52,7 +52,7 @@ def QueryOriginalScriptContent(session_id: str): } -@tool +@tool(return_direct=True) def QueryDiagnosisAndAssessment(session_id: str): """ 查询诊断与资产评估报告是否存在 @@ -81,7 +81,7 @@ def QueryDiagnosisAndAssessmentContent(session_id: str): "content": script["diagnosis_and_assessment"] if script else "", } -@tool +@tool(return_direct=True) def QueryAdaptationIdeas(session_id: str): """ 查询改编思路是否存在 @@ -110,7 +110,7 @@ def QueryAdaptationIdeasContent(session_id: str): "content": script["adaptation_ideas"] if script else "", } -@tool +@tool(return_direct=True) def QueryScriptBible(session_id: str): """ 查询剧本圣经是否存在 @@ -139,7 +139,7 @@ def QueryScriptBibleContent(session_id: str): "content": script["script_bible"] if script else {}, } -@tool +@tool(return_direct=True) def QueryCoreOutline(session_id: str): """ 查询剧本圣经中的核心大纲是否存在 @@ -154,7 +154,7 @@ def QueryCoreOutline(session_id: str): "exist": script is not None, } -@tool +@tool(return_direct=True) def QueryCharacterProfile(session_id: str): """ 查询剧本圣经中的核心人物小传是否存在 @@ -169,7 +169,7 @@ def QueryCharacterProfile(session_id: str): "exist": script is not None, } -@tool +@tool(return_direct=True) def QueryCoreEventTimeline(session_id: str): """ 查询剧本圣经中的重大事件时间线是否存在 @@ -184,7 +184,7 @@ def QueryCoreEventTimeline(session_id: str): "exist": script is not None, } -@tool +@tool(return_direct=True) def QueryCharacterList(session_id: str): """ 查询剧本圣经中的总人物表是否存在 @@ -199,7 +199,7 @@ def QueryCharacterList(session_id: str): "exist": script is not None, } -@tool +@tool(return_direct=True) def QueryEpisodeCount(session_id: str): """ 查询剧集创作情况 diff --git a/tools/agent/updateDB.py b/tools/agent/updateDB.py index e483694..107e041 100644 --- a/tools/agent/updateDB.py +++ b/tools/agent/updateDB.py @@ -2,7 +2,7 @@ from bson import ObjectId from tools.database.mongo import mainDB from langchain.tools import tool -@tool +@tool(return_direct=True) def UpdateDiagnosisAndAssessmentTool(session_id: str, diagnosis_and_assessment: str): """ 更新诊断与资产评估报告 @@ -30,7 +30,7 @@ def UpdateDiagnosisAndAssessment(session_id: str, diagnosis_and_assessment: str) "success": script.modified_count > 0, } -@tool +@tool(return_direct=True) def UpdateAdaptationIdeasTool(session_id: str, adaptation_ideas: str): """ 更新改编思路 @@ -58,7 +58,7 @@ def UpdateAdaptationIdeas(session_id: str, adaptation_ideas: str): "success": script.modified_count > 0, } -@tool +@tool(return_direct=True) def UpdateScriptBibleTool( session_id: str, core_outline:str|None = None, @@ -117,7 +117,7 @@ def UpdateScriptBible( "success": script.modified_count > 0, } -@tool +@tool(return_direct=True) def SetTotalEpisodeNumTool(session_id: str, total_episode_num: int): """ 设置总集数 @@ -145,7 +145,7 @@ def SetTotalEpisodeNum(session_id: str, total_episode_num: int): "success": script.modified_count > 0, } -@tool +@tool(return_direct=True) def UpdateOneEpisodeTool(session_id: str, episode_num:int, content: str): """ 更新单集内容 diff --git a/tools/llm/huoshan_langchain.py b/tools/llm/huoshan_langchain.py index ade11e4..ec7fdda 100644 --- a/tools/llm/huoshan_langchain.py +++ b/tools/llm/huoshan_langchain.py @@ -237,7 +237,6 @@ class HuoshanChatModel(BaseChatModel): for chunk in self._api.get_chat_response_stream( messages=api_messages, tools=tools, - **kwargs ): if chunk: # 创建增量消息