base test

This commit is contained in:
jonathang4 2025-09-22 23:28:55 +08:00
parent 20556d7ecb
commit 6ca656791f
9 changed files with 353 additions and 15 deletions

View File

@ -560,12 +560,19 @@ class HuoshanAPI:
max_tokens=16384, # 16K
timeout=600,
stream=True,
thinking={
"type": "disabled", # 不使用深度思考能力
# "type": "enabled", # 使用深度思考能力
# "type": "auto", # 模型自行判断是否使用深度思考能力
},
tools=tools # 传入 tools 参数
)
for chunk in completion:
if chunk.choices and chunk.choices[0].delta.content is not None: # type: ignore
yield chunk.choices[0].delta.content # type: ignore
except Exception as e:
import traceback
traceback.print_exc()
raise ValueError(f"火山引擎API流式调用失败: {str(e)}")
def analyze_image(self, image_url: str, prompt: str = "请描述这张图片的内容", model: Optional[str] = None, detail: str = "high") -> Dict[str, Any]:

46
graph/test/t1.py Normal file
View File

@ -0,0 +1,46 @@
import os
import asyncio
from typing import Annotated
from langchain.chat_models import init_chat_model
from langchain_community.chat_models import ChatOpenAI
from langchain_core.messages import HumanMessage
from langgraph.graph.state import RunnableConfig
from typing_extensions import TypedDict
from langgraph.graph import StateGraph, START, END
from langgraph.graph.message import add_messages
class State(TypedDict):
messages: Annotated[list, add_messages]
graph_builder = StateGraph(State)
llm = init_chat_model(
model="openai:doubao-seed-1.6-250615",
temperature=0,
base_url="https://ark.cn-beijing.volces.com/api/v3/",
api_key="0d5189b3-9a03-4393-81be-8c1ba1e97cbb",
)
def chatbot(state: State):
return {"messages": [llm.invoke(state["messages"])]}
graph_builder.add_node("chatbot", chatbot)
graph_builder.add_edge(START, "chatbot")
graph_builder.add_edge("chatbot", END)
graph = graph_builder.compile()
# 单对话节点测试
async def run():
config:RunnableConfig = {"configurable": {"thread_id": "1"}}
inputs:State = {"messages": [HumanMessage(content="你好")]}
result = await graph.ainvoke(inputs, config=config)
print(result)
if __name__ == "__main__":
asyncio.run(run())

56
graph/test/t2.py Normal file
View File

@ -0,0 +1,56 @@
import os
import asyncio
from typing import Annotated
from langchain.chat_models import init_chat_model
from langgraph.prebuilt import create_react_agent
from langchain_core.messages import HumanMessage
from langgraph.graph.state import RunnableConfig
from typing_extensions import TypedDict
from langgraph.graph import StateGraph, START, END
from langgraph.graph.message import add_messages
class State(TypedDict):
messages: Annotated[list, add_messages]
graph_builder = StateGraph(State)
llm = init_chat_model(
model="openai:doubao-seed-1.6-250615",
temperature=0,
base_url="https://ark.cn-beijing.volces.com/api/v3/",
api_key="0d5189b3-9a03-4393-81be-8c1ba1e97cbb",
)
chatAgent = create_react_agent(llm, tools=[], prompt="你是一个智能体,你的任务是回答用户的问题。如果用户的输入不是问题,你可以询问用户是否有其他问题。")
graph_builder.add_node("chatbot", chatAgent)
graph_builder.add_edge(START, "chatbot")
graph_builder.add_edge("chatbot", END)
graph = graph_builder.compile()
# 单智能体节点测试
async def run():
config:RunnableConfig = {"configurable": {"thread_id": "1"}}
inputs:State = {"messages": [HumanMessage(content="你好")]}
# # 非流式处理查询
# result = await graph.ainvoke(inputs, config=config)
# print(result)
# 流式处理查询
async for message_chunk, metadata in graph.astream(
input=inputs,
config=config,
stream_mode="messages"
):
# 输出最终结果
if message_chunk.content: # type: ignore
print(message_chunk.content, end="|", flush=True) # type: ignore
if __name__ == "__main__":
asyncio.run(run())

68
graph/test/t3.py Normal file
View File

@ -0,0 +1,68 @@
import os
import asyncio
from typing import Annotated
from langchain.chat_models import init_chat_model
from langgraph.prebuilt import create_react_agent
from langchain_core.messages import HumanMessage
from langgraph.graph.state import RunnableConfig
from typing_extensions import TypedDict
from langchain.tools import tool
from langgraph.graph import StateGraph, START, END
from langgraph.graph.message import add_messages
from tools.llm.huoshan_langchain import HuoshanChatModel
@tool
def QueryLocation():
"""
查询当前所在位置
Returns:
Dict: 返回一个包含以下字段的字典
location (str): 当前所在位置
"""
return {"location": f"杭州"}
class State(TypedDict):
messages: Annotated[list, add_messages]
graph_builder = StateGraph(State)
llm = init_chat_model(
model="openai:doubao-seed-1.6-250615",
temperature=0,
base_url="https://ark.cn-beijing.volces.com/api/v3/",
api_key="0d5189b3-9a03-4393-81be-8c1ba1e97cbb",
)
# llm = HuoshanChatModel()
chatAgent = create_react_agent(llm, tools=[QueryLocation], prompt="你是一个智能体,你的任务是回答用户的问题。如果用户的输入不是问题,你可以询问用户是否有其他问题。")
graph_builder.add_node("chatbot", chatAgent)
graph_builder.add_edge(START, "chatbot")
graph_builder.add_edge("chatbot", END)
graph = graph_builder.compile()
# 单智能体节点测试 工具使用 function calling
async def run():
config:RunnableConfig = {"configurable": {"thread_id": "1"}}
inputs:State = {"messages": [HumanMessage(content="我现在在哪啊?")]}
# # 非流式处理查询
# result = await graph.ainvoke(inputs, config=config)
# print(result)
# 流式处理查询
async for message_chunk, metadata in graph.astream(
input=inputs,
config=config,
stream_mode="messages"
):
# 输出最终结果
if message_chunk.content: # type: ignore
print(f"type:{message_chunk.type} \n msg:{message_chunk.content}", end="\n", flush=True) # type: ignore
if __name__ == "__main__":
asyncio.run(run())

161
graph/test/t4.py Normal file
View File

@ -0,0 +1,161 @@
import os
import asyncio
import json
from typing import Annotated
from typing_extensions import NotRequired, TypedDict
from langchain.chat_models import init_chat_model
from langgraph.prebuilt import create_react_agent
from langchain_core.messages import HumanMessage, human
from langgraph.graph.state import RunnableConfig
from typing_extensions import TypedDict
from langchain.tools import tool
from langgraph.graph import StateGraph, START, END
from langgraph.graph.message import add_messages
from tools.llm.huoshan_langchain import HuoshanChatModel
@tool
def QueryLocation():
"""
查询当前所在位置
Returns:
Dict: 返回一个包含以下字段的字典
location (str): 当前所在位置
"""
return {"location": f"杭州"}
class State(TypedDict):
messages: Annotated[list, add_messages]
agent: NotRequired[str]
graph_builder = StateGraph(State)
llm = init_chat_model(
model="openai:doubao-seed-1.6-250615",
temperature=0,
base_url="https://ark.cn-beijing.volces.com/api/v3/",
api_key="0d5189b3-9a03-4393-81be-8c1ba1e97cbb",
)
# llm = HuoshanChatModel()
routerAgent = create_react_agent(llm, tools=[],
prompt=f"""
你的任务是根据用户的输入判断使用哪个智能体来继续执行任务
下面是智能体名称与任务描述:
taskA用于回答位置相关的问题
taskB用于给用户讲一个笑话
taskC用于回复用户的其他问题
# 输出规范
请严格按照下列JSON结构返回数据不要有其他任何多余的信息和描述
{{
"agent": "taskA", # 智能体名称
}}
""")
taskAgentA = create_react_agent(llm, tools=[QueryLocation],
prompt=f"""
你的任务是回答位置相关的问题
""")
taskAgentB = create_react_agent(llm, tools=[],
prompt=f"""
你的任务是给用户讲一个笑话
""")
taskAgentC = create_react_agent(llm, tools=[],
prompt=f"""
你的任务是礼貌的回复用户
""")
def node_router(state: State):
"""
路由函数根据输入的状态判断下一个节点
"""
last_msg = state["messages"][-1].content
# print(f"last_msg:{last_msg}")
try:
data = json.loads(last_msg)
nextAgent = data.get("agent", "taskC")
except json.JSONDecodeError:
nextAgent = "taskC"
return nextAgent
def taskAgentB_node(state: State):
"""
taskAgentB节点用于给用户讲一个笑话
"""
human_msgs = [msg for msg in state["messages"] if msg.type == "human"]
# print(f"taskAgentB_node state:{state}")
res = taskAgentB.invoke({"messages": human_msgs})
state["messages"].append(res["messages"][-1])
return state
def taskAgentA_node(state: State):
"""
taskAgentA节点用于回答位置相关的问题
"""
human_msgs = [msg for msg in state["messages"] if msg.type == "human"]
# print(f"taskAgentA_node state:{state}")
res = taskAgentA.invoke({"messages": human_msgs})
state["messages"].append(res["messages"][-1])
return state
def taskAgentC_node(state: State):
"""
taskAgentC节点用于回复用户的其他问题
"""
human_msgs = [msg for msg in state["messages"] if msg.type == "human"]
# print(f"taskAgentC_node state:{state}")
res = taskAgentC.invoke({"messages": human_msgs})
state["messages"].append(res["messages"][-1])
return state
graph_builder.add_node("router", routerAgent)
graph_builder.add_node("taskA", taskAgentA_node)
graph_builder.add_node("taskB", taskAgentB_node)
graph_builder.add_node("taskC", taskAgentC_node)
graph_builder.add_edge(START, "router")
graph_builder.add_conditional_edges(
"router",
node_router,
{
"taskA": "taskA",
"taskB": "taskB",
"taskC": "taskC",
}
)
graph_builder.add_edge("taskA", END)
graph_builder.add_edge("taskB", END)
graph_builder.add_edge("taskC", END)
graph = graph_builder.compile()
# 多智能体节点测试 工具使用 function calling
#
async def run():
config:RunnableConfig = {"configurable": {"thread_id": "1"}}
# inputs:State = {"messages": [HumanMessage(content="我现在在哪啊?")]}
# inputs:State = {"messages": [HumanMessage(content="给我讲个笑话吧")]}
inputs:State = {"messages": [HumanMessage(content="现在是几点?")]}
# # 非流式处理查询
# result = await graph.ainvoke(inputs, config=config)
# print(result)
# 流式处理查询
async for message_chunk, metadata in graph.astream(
input=inputs,
config=config,
stream_mode="messages"
):
# 输出最终结果
if message_chunk.content: # type: ignore
print(f"type:{message_chunk.type} \n msg:{message_chunk.content}", end="\n", flush=True) # type: ignore
if __name__ == "__main__":
asyncio.run(run())

View File

@ -19,5 +19,6 @@ PyMySQL
DBUtils
PyMuPDF
langchain-community
langchain-openai
langgraph
langgraph-checkpoint-mongodb

View File

@ -7,7 +7,7 @@ from tools.database.mongo import mainDB
from langchain.tools import tool
import json
@tool
@tool(return_direct=True)
def QueryOriginalScript(session_id: str):
"""
查询原始剧本内容是否存在
@ -52,7 +52,7 @@ def QueryOriginalScriptContent(session_id: str):
}
@tool
@tool(return_direct=True)
def QueryDiagnosisAndAssessment(session_id: str):
"""
查询诊断与资产评估报告是否存在
@ -81,7 +81,7 @@ def QueryDiagnosisAndAssessmentContent(session_id: str):
"content": script["diagnosis_and_assessment"] if script else "",
}
@tool
@tool(return_direct=True)
def QueryAdaptationIdeas(session_id: str):
"""
查询改编思路是否存在
@ -110,7 +110,7 @@ def QueryAdaptationIdeasContent(session_id: str):
"content": script["adaptation_ideas"] if script else "",
}
@tool
@tool(return_direct=True)
def QueryScriptBible(session_id: str):
"""
查询剧本圣经是否存在
@ -139,7 +139,7 @@ def QueryScriptBibleContent(session_id: str):
"content": script["script_bible"] if script else {},
}
@tool
@tool(return_direct=True)
def QueryCoreOutline(session_id: str):
"""
查询剧本圣经中的核心大纲是否存在
@ -154,7 +154,7 @@ def QueryCoreOutline(session_id: str):
"exist": script is not None,
}
@tool
@tool(return_direct=True)
def QueryCharacterProfile(session_id: str):
"""
查询剧本圣经中的核心人物小传是否存在
@ -169,7 +169,7 @@ def QueryCharacterProfile(session_id: str):
"exist": script is not None,
}
@tool
@tool(return_direct=True)
def QueryCoreEventTimeline(session_id: str):
"""
查询剧本圣经中的重大事件时间线是否存在
@ -184,7 +184,7 @@ def QueryCoreEventTimeline(session_id: str):
"exist": script is not None,
}
@tool
@tool(return_direct=True)
def QueryCharacterList(session_id: str):
"""
查询剧本圣经中的总人物表是否存在
@ -199,7 +199,7 @@ def QueryCharacterList(session_id: str):
"exist": script is not None,
}
@tool
@tool(return_direct=True)
def QueryEpisodeCount(session_id: str):
"""
查询剧集创作情况

View File

@ -2,7 +2,7 @@ from bson import ObjectId
from tools.database.mongo import mainDB
from langchain.tools import tool
@tool
@tool(return_direct=True)
def UpdateDiagnosisAndAssessmentTool(session_id: str, diagnosis_and_assessment: str):
"""
更新诊断与资产评估报告
@ -30,7 +30,7 @@ def UpdateDiagnosisAndAssessment(session_id: str, diagnosis_and_assessment: str)
"success": script.modified_count > 0,
}
@tool
@tool(return_direct=True)
def UpdateAdaptationIdeasTool(session_id: str, adaptation_ideas: str):
"""
更新改编思路
@ -58,7 +58,7 @@ def UpdateAdaptationIdeas(session_id: str, adaptation_ideas: str):
"success": script.modified_count > 0,
}
@tool
@tool(return_direct=True)
def UpdateScriptBibleTool(
session_id: str,
core_outline:str|None = None,
@ -117,7 +117,7 @@ def UpdateScriptBible(
"success": script.modified_count > 0,
}
@tool
@tool(return_direct=True)
def SetTotalEpisodeNumTool(session_id: str, total_episode_num: int):
"""
设置总集数
@ -145,7 +145,7 @@ def SetTotalEpisodeNum(session_id: str, total_episode_num: int):
"success": script.modified_count > 0,
}
@tool
@tool(return_direct=True)
def UpdateOneEpisodeTool(session_id: str, episode_num:int, content: str):
"""
更新单集内容

View File

@ -237,7 +237,6 @@ class HuoshanChatModel(BaseChatModel):
for chunk in self._api.get_chat_response_stream(
messages=api_messages,
tools=tools,
**kwargs
):
if chunk:
# 创建增量消息