base test
This commit is contained in:
parent
20556d7ecb
commit
6ca656791f
@ -560,12 +560,19 @@ class HuoshanAPI:
|
||||
max_tokens=16384, # 16K
|
||||
timeout=600,
|
||||
stream=True,
|
||||
thinking={
|
||||
"type": "disabled", # 不使用深度思考能力
|
||||
# "type": "enabled", # 使用深度思考能力
|
||||
# "type": "auto", # 模型自行判断是否使用深度思考能力
|
||||
},
|
||||
tools=tools # 传入 tools 参数
|
||||
)
|
||||
for chunk in completion:
|
||||
if chunk.choices and chunk.choices[0].delta.content is not None: # type: ignore
|
||||
yield chunk.choices[0].delta.content # type: ignore
|
||||
except Exception as e:
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
raise ValueError(f"火山引擎API流式调用失败: {str(e)}")
|
||||
|
||||
def analyze_image(self, image_url: str, prompt: str = "请描述这张图片的内容", model: Optional[str] = None, detail: str = "high") -> Dict[str, Any]:
|
||||
|
||||
46
graph/test/t1.py
Normal file
46
graph/test/t1.py
Normal file
@ -0,0 +1,46 @@
|
||||
import os
|
||||
import asyncio
|
||||
from typing import Annotated
|
||||
|
||||
from langchain.chat_models import init_chat_model
|
||||
from langchain_community.chat_models import ChatOpenAI
|
||||
from langchain_core.messages import HumanMessage
|
||||
from langgraph.graph.state import RunnableConfig
|
||||
from typing_extensions import TypedDict
|
||||
|
||||
from langgraph.graph import StateGraph, START, END
|
||||
from langgraph.graph.message import add_messages
|
||||
|
||||
|
||||
class State(TypedDict):
    """Graph state: `messages` accumulates chat turns via the
    `add_messages` reducer instead of being overwritten on each update."""
    # Conversation history; add_messages merges new messages by id.
    messages: Annotated[list, add_messages]


graph_builder = StateGraph(State)

# Chat model backed by the Volcengine Ark OpenAI-compatible endpoint.
# SECURITY: the API key was hard-coded in source; prefer the ARK_API_KEY
# environment variable. The literal is kept only as a backward-compatible
# fallback and should be rotated and removed.
llm = init_chat_model(
    model="openai:doubao-seed-1.6-250615",
    temperature=0,
    base_url="https://ark.cn-beijing.volces.com/api/v3/",
    api_key=os.environ.get("ARK_API_KEY", "0d5189b3-9a03-4393-81be-8c1ba1e97cbb"),
)
|
||||
|
||||
def chatbot(state: State):
    """Single LLM turn: run the model over the accumulated messages and
    emit the reply as a one-element state delta."""
    reply = llm.invoke(state["messages"])
    return {"messages": [reply]}
|
||||
|
||||
|
||||
# Wire a minimal linear graph: START -> chatbot -> END, then compile it.
graph_builder.add_node("chatbot", chatbot)
graph_builder.add_edge(START, "chatbot")
graph_builder.add_edge("chatbot", END)

graph = graph_builder.compile()
|
||||
|
||||
|
||||
# Single chat-node test
async def run():
    """Invoke the compiled graph once with a greeting and print the result."""
    config: RunnableConfig = {"configurable": {"thread_id": "1"}}
    inputs: State = {"messages": [HumanMessage(content="你好")]}
    result = await graph.ainvoke(inputs, config=config)
    print(result)


if __name__ == "__main__":
    asyncio.run(run())
|
||||
56
graph/test/t2.py
Normal file
56
graph/test/t2.py
Normal file
@ -0,0 +1,56 @@
|
||||
import os
|
||||
import asyncio
|
||||
from typing import Annotated
|
||||
|
||||
from langchain.chat_models import init_chat_model
|
||||
from langgraph.prebuilt import create_react_agent
|
||||
from langchain_core.messages import HumanMessage
|
||||
from langgraph.graph.state import RunnableConfig
|
||||
from typing_extensions import TypedDict
|
||||
|
||||
from langgraph.graph import StateGraph, START, END
|
||||
from langgraph.graph.message import add_messages
|
||||
|
||||
|
||||
class State(TypedDict):
    """Graph state: `messages` accumulates chat turns via the
    `add_messages` reducer instead of being overwritten on each update."""
    # Conversation history; add_messages merges new messages by id.
    messages: Annotated[list, add_messages]


graph_builder = StateGraph(State)

# Chat model backed by the Volcengine Ark OpenAI-compatible endpoint.
# SECURITY: the API key was hard-coded in source; prefer the ARK_API_KEY
# environment variable. The literal is kept only as a backward-compatible
# fallback and should be rotated and removed.
llm = init_chat_model(
    model="openai:doubao-seed-1.6-250615",
    temperature=0,
    base_url="https://ark.cn-beijing.volces.com/api/v3/",
    api_key=os.environ.get("ARK_API_KEY", "0d5189b3-9a03-4393-81be-8c1ba1e97cbb"),
)
|
||||
|
||||
# A single ReAct agent (no tools) wrapped as the only graph node.
chatAgent = create_react_agent(
    llm,
    tools=[],
    prompt="你是一个智能体,你的任务是回答用户的问题。如果用户的输入不是问题,你可以询问用户是否有其他问题。",
)

# Linear wiring: START -> chatbot -> END.
graph_builder.add_node("chatbot", chatAgent)
graph_builder.add_edge(START, "chatbot")
graph_builder.add_edge("chatbot", END)

graph = graph_builder.compile()
|
||||
|
||||
# Single agent-node test
async def run():
    """Stream token chunks from the compiled graph and print them."""
    config: RunnableConfig = {"configurable": {"thread_id": "1"}}
    inputs: State = {"messages": [HumanMessage(content="你好")]}

    # # Non-streaming invocation
    # result = await graph.ainvoke(inputs, config=config)
    # print(result)

    # Streaming invocation; stream_mode="messages" yields (chunk, metadata).
    async for message_chunk, _metadata in graph.astream(
        input=inputs,
        config=config,
        stream_mode="messages",
    ):
        # Print only non-empty content chunks.
        if message_chunk.content:  # type: ignore
            print(message_chunk.content, end="|", flush=True)  # type: ignore


if __name__ == "__main__":
    asyncio.run(run())
|
||||
68
graph/test/t3.py
Normal file
68
graph/test/t3.py
Normal file
@ -0,0 +1,68 @@
|
||||
import os
|
||||
import asyncio
|
||||
from typing import Annotated
|
||||
|
||||
from langchain.chat_models import init_chat_model
|
||||
from langgraph.prebuilt import create_react_agent
|
||||
from langchain_core.messages import HumanMessage
|
||||
from langgraph.graph.state import RunnableConfig
|
||||
from typing_extensions import TypedDict
|
||||
from langchain.tools import tool
|
||||
|
||||
from langgraph.graph import StateGraph, START, END
|
||||
from langgraph.graph.message import add_messages
|
||||
from tools.llm.huoshan_langchain import HuoshanChatModel
|
||||
|
||||
@tool
def QueryLocation():
    """
    查询当前所在位置
    Returns:
        Dict: 返回一个包含以下字段的字典:
            location (str): 当前所在位置。
    """
    # Stub implementation: always reports a fixed city.
    # Dropped the spurious f-prefix: the literal has no placeholders,
    # so the runtime value is byte-identical.
    return {"location": "杭州"}
|
||||
|
||||
class State(TypedDict):
    """Graph state: `messages` accumulates chat turns via the
    `add_messages` reducer instead of being overwritten on each update."""
    # Conversation history; add_messages merges new messages by id.
    messages: Annotated[list, add_messages]


graph_builder = StateGraph(State)

# Chat model backed by the Volcengine Ark OpenAI-compatible endpoint.
# SECURITY: the API key was hard-coded in source; prefer the ARK_API_KEY
# environment variable. The literal is kept only as a backward-compatible
# fallback and should be rotated and removed.
llm = init_chat_model(
    model="openai:doubao-seed-1.6-250615",
    temperature=0,
    base_url="https://ark.cn-beijing.volces.com/api/v3/",
    api_key=os.environ.get("ARK_API_KEY", "0d5189b3-9a03-4393-81be-8c1ba1e97cbb"),
)

# Alternative local wrapper model, kept for manual switching during tests.
# llm = HuoshanChatModel()
|
||||
|
||||
# ReAct agent with the QueryLocation tool; sole worker node of the graph.
chatAgent = create_react_agent(
    llm,
    tools=[QueryLocation],
    prompt="你是一个智能体,你的任务是回答用户的问题。如果用户的输入不是问题,你可以询问用户是否有其他问题。",
)

# Linear wiring: START -> chatbot -> END.
graph_builder.add_node("chatbot", chatAgent)
graph_builder.add_edge(START, "chatbot")
graph_builder.add_edge("chatbot", END)

graph = graph_builder.compile()
|
||||
|
||||
# Single agent-node test with tool use (function calling)
async def run():
    """Stream the agent's answer to a location question, printing each chunk."""
    config: RunnableConfig = {"configurable": {"thread_id": "1"}}
    inputs: State = {"messages": [HumanMessage(content="我现在在哪啊?")]}

    # # Non-streaming invocation
    # result = await graph.ainvoke(inputs, config=config)
    # print(result)

    # Streaming invocation; stream_mode="messages" yields (chunk, metadata).
    async for message_chunk, _metadata in graph.astream(
        input=inputs,
        config=config,
        stream_mode="messages",
    ):
        # Print only non-empty content chunks, tagged with the message type.
        if message_chunk.content:  # type: ignore
            print(f"type:{message_chunk.type} \n msg:{message_chunk.content}", end="\n", flush=True)  # type: ignore


if __name__ == "__main__":
    asyncio.run(run())
|
||||
161
graph/test/t4.py
Normal file
161
graph/test/t4.py
Normal file
@ -0,0 +1,161 @@
|
||||
import os
|
||||
import asyncio
|
||||
import json
|
||||
|
||||
from typing import Annotated
|
||||
from typing_extensions import NotRequired, TypedDict
|
||||
|
||||
from langchain.chat_models import init_chat_model
|
||||
from langgraph.prebuilt import create_react_agent
|
||||
from langchain_core.messages import HumanMessage, human
|
||||
from langgraph.graph.state import RunnableConfig
|
||||
from typing_extensions import TypedDict
|
||||
from langchain.tools import tool
|
||||
|
||||
from langgraph.graph import StateGraph, START, END
|
||||
from langgraph.graph.message import add_messages
|
||||
from tools.llm.huoshan_langchain import HuoshanChatModel
|
||||
|
||||
@tool
def QueryLocation():
    """
    查询当前所在位置
    Returns:
        Dict: 返回一个包含以下字段的字典:
            location (str): 当前所在位置。
    """
    # Stub implementation: always reports a fixed city.
    # Dropped the spurious f-prefix: the literal has no placeholders,
    # so the runtime value is byte-identical.
    return {"location": "杭州"}
|
||||
|
||||
class State(TypedDict):
    """Graph state shared by the router and worker nodes."""
    # Conversation history; add_messages merges new messages by id.
    messages: Annotated[list, add_messages]
    # Name of the worker agent chosen by the router (optional).
    agent: NotRequired[str]


graph_builder = StateGraph(State)

# Chat model backed by the Volcengine Ark OpenAI-compatible endpoint.
# SECURITY: the API key was hard-coded in source; prefer the ARK_API_KEY
# environment variable. The literal is kept only as a backward-compatible
# fallback and should be rotated and removed.
llm = init_chat_model(
    model="openai:doubao-seed-1.6-250615",
    temperature=0,
    base_url="https://ark.cn-beijing.volces.com/api/v3/",
    api_key=os.environ.get("ARK_API_KEY", "0d5189b3-9a03-4393-81be-8c1ba1e97cbb"),
)

# Alternative local wrapper model, kept for manual switching during tests.
# llm = HuoshanChatModel()
|
||||
|
||||
# Router agent: classifies the user's input and replies with a JSON object
# naming the worker agent to run next. The f-string is kept so the escaped
# {{ }} render as literal braces in the prompt.
routerAgent = create_react_agent(
    llm,
    tools=[],
    prompt=f"""
你的任务是根据用户的输入,判断使用哪个智能体来继续执行任务
下面是智能体名称与任务描述:
taskA:用于回答位置相关的问题
taskB:用于给用户讲一个笑话
taskC:用于回复用户的其他问题

# 输出规范
请严格按照下列JSON结构返回数据,不要有其他任何多余的信息和描述:
{{
"agent": "taskA", # 智能体名称
}}
""",
)

# Worker agent A: answers location questions via the QueryLocation tool.
taskAgentA = create_react_agent(
    llm,
    tools=[QueryLocation],
    prompt="""
你的任务是回答位置相关的问题
""",
)

# Worker agent B: tells the user a joke.
taskAgentB = create_react_agent(
    llm,
    tools=[],
    prompt="""
你的任务是给用户讲一个笑话
""",
)

# Worker agent C: polite fallback replies for everything else.
taskAgentC = create_react_agent(
    llm,
    tools=[],
    prompt="""
你的任务是礼貌的回复用户
""",
)
|
||||
# Agents the conditional edge can route to; anything else falls back to taskC.
_ALLOWED_AGENTS = {"taskA", "taskB", "taskC"}


def node_router(state: "State"):
    """
    Routing function: parse the router agent's JSON reply and return the
    name of the next node.

    The router agent is prompted to answer with {"agent": "<name>"}. Any
    malformed reply falls back to "taskC":
      - non-JSON text (json.JSONDecodeError)
      - non-string content, e.g. a multimodal parts list (TypeError)
      - JSON that is not an object, or an unknown agent name — an
        unvalidated name would raise KeyError inside add_conditional_edges.
    """
    last_msg = state["messages"][-1].content
    # print(f"last_msg:{last_msg}")
    try:
        data = json.loads(last_msg)
    except (json.JSONDecodeError, TypeError):
        return "taskC"
    next_agent = data.get("agent", "taskC") if isinstance(data, dict) else "taskC"
    return next_agent if next_agent in _ALLOWED_AGENTS else "taskC"
|
||||
|
||||
def taskAgentB_node(state: "State"):
    """
    taskAgentB node: tell the user a joke.

    Only human messages are forwarded to the sub-agent so its context is not
    polluted by the router's JSON reply. Returns the agent's final message as
    a state delta (add_messages appends it) instead of mutating the input
    state in place — in-place mutation of graph state is an anti-pattern.
    """
    human_msgs = [msg for msg in state["messages"] if msg.type == "human"]
    # print(f"taskAgentB_node state:{state}")
    res = taskAgentB.invoke({"messages": human_msgs})
    return {"messages": [res["messages"][-1]]}
|
||||
|
||||
def taskAgentA_node(state: "State"):
    """
    taskAgentA node: answer location-related questions.

    Only human messages are forwarded to the sub-agent so its context is not
    polluted by the router's JSON reply. Returns the agent's final message as
    a state delta (add_messages appends it) instead of mutating the input
    state in place — in-place mutation of graph state is an anti-pattern.
    """
    human_msgs = [msg for msg in state["messages"] if msg.type == "human"]
    # print(f"taskAgentA_node state:{state}")
    res = taskAgentA.invoke({"messages": human_msgs})
    return {"messages": [res["messages"][-1]]}
|
||||
|
||||
def taskAgentC_node(state: "State"):
    """
    taskAgentC node: polite fallback reply for any other user input.

    Only human messages are forwarded to the sub-agent so its context is not
    polluted by the router's JSON reply. Returns the agent's final message as
    a state delta (add_messages appends it) instead of mutating the input
    state in place — in-place mutation of graph state is an anti-pattern.
    """
    human_msgs = [msg for msg in state["messages"] if msg.type == "human"]
    # print(f"taskAgentC_node state:{state}")
    res = taskAgentC.invoke({"messages": human_msgs})
    return {"messages": [res["messages"][-1]]}
|
||||
|
||||
# Register the router and the three worker nodes.
graph_builder.add_node("router", routerAgent)
graph_builder.add_node("taskA", taskAgentA_node)
graph_builder.add_node("taskB", taskAgentB_node)
graph_builder.add_node("taskC", taskAgentC_node)

graph_builder.add_edge(START, "router")

# node_router's string result selects the next node via this mapping.
graph_builder.add_conditional_edges(
    "router",
    node_router,
    {name: name for name in ("taskA", "taskB", "taskC")},
)

# Every worker terminates the run.
for _worker in ("taskA", "taskB", "taskC"):
    graph_builder.add_edge(_worker, END)

graph = graph_builder.compile()
|
||||
|
||||
# Multi-agent node test with tool use (function calling)
#
async def run():
    """Route one user question through the graph, streaming every chunk."""
    config: RunnableConfig = {"configurable": {"thread_id": "1"}}
    # Alternative inputs for exercising the other routes:
    # inputs:State = {"messages": [HumanMessage(content="我现在在哪啊?")]}
    # inputs:State = {"messages": [HumanMessage(content="给我讲个笑话吧")]}
    inputs: State = {"messages": [HumanMessage(content="现在是几点?")]}

    # # Non-streaming invocation
    # result = await graph.ainvoke(inputs, config=config)
    # print(result)

    # Streaming invocation; stream_mode="messages" yields (chunk, metadata).
    async for message_chunk, _metadata in graph.astream(
        input=inputs,
        config=config,
        stream_mode="messages",
    ):
        # Print only non-empty content chunks, tagged with the message type.
        if message_chunk.content:  # type: ignore
            print(f"type:{message_chunk.type} \n msg:{message_chunk.content}", end="\n", flush=True)  # type: ignore


if __name__ == "__main__":
    asyncio.run(run())
|
||||
@ -19,5 +19,6 @@ PyMySQL
|
||||
DBUtils
|
||||
PyMuPDF
|
||||
langchain-community
|
||||
langchain-openai
|
||||
langgraph
|
||||
langgraph-checkpoint-mongodb
|
||||
@ -7,7 +7,7 @@ from tools.database.mongo import mainDB
|
||||
from langchain.tools import tool
|
||||
import json
|
||||
|
||||
@tool
|
||||
@tool(return_direct=True)
|
||||
def QueryOriginalScript(session_id: str):
|
||||
"""
|
||||
查询原始剧本内容是否存在
|
||||
@ -52,7 +52,7 @@ def QueryOriginalScriptContent(session_id: str):
|
||||
}
|
||||
|
||||
|
||||
@tool
|
||||
@tool(return_direct=True)
|
||||
def QueryDiagnosisAndAssessment(session_id: str):
|
||||
"""
|
||||
查询诊断与资产评估报告是否存在
|
||||
@ -81,7 +81,7 @@ def QueryDiagnosisAndAssessmentContent(session_id: str):
|
||||
"content": script["diagnosis_and_assessment"] if script else "",
|
||||
}
|
||||
|
||||
@tool
|
||||
@tool(return_direct=True)
|
||||
def QueryAdaptationIdeas(session_id: str):
|
||||
"""
|
||||
查询改编思路是否存在
|
||||
@ -110,7 +110,7 @@ def QueryAdaptationIdeasContent(session_id: str):
|
||||
"content": script["adaptation_ideas"] if script else "",
|
||||
}
|
||||
|
||||
@tool
|
||||
@tool(return_direct=True)
|
||||
def QueryScriptBible(session_id: str):
|
||||
"""
|
||||
查询剧本圣经是否存在
|
||||
@ -139,7 +139,7 @@ def QueryScriptBibleContent(session_id: str):
|
||||
"content": script["script_bible"] if script else {},
|
||||
}
|
||||
|
||||
@tool
|
||||
@tool(return_direct=True)
|
||||
def QueryCoreOutline(session_id: str):
|
||||
"""
|
||||
查询剧本圣经中的核心大纲是否存在
|
||||
@ -154,7 +154,7 @@ def QueryCoreOutline(session_id: str):
|
||||
"exist": script is not None,
|
||||
}
|
||||
|
||||
@tool
|
||||
@tool(return_direct=True)
|
||||
def QueryCharacterProfile(session_id: str):
|
||||
"""
|
||||
查询剧本圣经中的核心人物小传是否存在
|
||||
@ -169,7 +169,7 @@ def QueryCharacterProfile(session_id: str):
|
||||
"exist": script is not None,
|
||||
}
|
||||
|
||||
@tool
|
||||
@tool(return_direct=True)
|
||||
def QueryCoreEventTimeline(session_id: str):
|
||||
"""
|
||||
查询剧本圣经中的重大事件时间线是否存在
|
||||
@ -184,7 +184,7 @@ def QueryCoreEventTimeline(session_id: str):
|
||||
"exist": script is not None,
|
||||
}
|
||||
|
||||
@tool
|
||||
@tool(return_direct=True)
|
||||
def QueryCharacterList(session_id: str):
|
||||
"""
|
||||
查询剧本圣经中的总人物表是否存在
|
||||
@ -199,7 +199,7 @@ def QueryCharacterList(session_id: str):
|
||||
"exist": script is not None,
|
||||
}
|
||||
|
||||
@tool
|
||||
@tool(return_direct=True)
|
||||
def QueryEpisodeCount(session_id: str):
|
||||
"""
|
||||
查询剧集创作情况
|
||||
|
||||
@ -2,7 +2,7 @@ from bson import ObjectId
|
||||
from tools.database.mongo import mainDB
|
||||
from langchain.tools import tool
|
||||
|
||||
@tool
|
||||
@tool(return_direct=True)
|
||||
def UpdateDiagnosisAndAssessmentTool(session_id: str, diagnosis_and_assessment: str):
|
||||
"""
|
||||
更新诊断与资产评估报告
|
||||
@ -30,7 +30,7 @@ def UpdateDiagnosisAndAssessment(session_id: str, diagnosis_and_assessment: str)
|
||||
"success": script.modified_count > 0,
|
||||
}
|
||||
|
||||
@tool
|
||||
@tool(return_direct=True)
|
||||
def UpdateAdaptationIdeasTool(session_id: str, adaptation_ideas: str):
|
||||
"""
|
||||
更新改编思路
|
||||
@ -58,7 +58,7 @@ def UpdateAdaptationIdeas(session_id: str, adaptation_ideas: str):
|
||||
"success": script.modified_count > 0,
|
||||
}
|
||||
|
||||
@tool
|
||||
@tool(return_direct=True)
|
||||
def UpdateScriptBibleTool(
|
||||
session_id: str,
|
||||
core_outline:str|None = None,
|
||||
@ -117,7 +117,7 @@ def UpdateScriptBible(
|
||||
"success": script.modified_count > 0,
|
||||
}
|
||||
|
||||
@tool
|
||||
@tool(return_direct=True)
|
||||
def SetTotalEpisodeNumTool(session_id: str, total_episode_num: int):
|
||||
"""
|
||||
设置总集数
|
||||
@ -145,7 +145,7 @@ def SetTotalEpisodeNum(session_id: str, total_episode_num: int):
|
||||
"success": script.modified_count > 0,
|
||||
}
|
||||
|
||||
@tool
|
||||
@tool(return_direct=True)
|
||||
def UpdateOneEpisodeTool(session_id: str, episode_num:int, content: str):
|
||||
"""
|
||||
更新单集内容
|
||||
|
||||
@ -237,7 +237,6 @@ class HuoshanChatModel(BaseChatModel):
|
||||
for chunk in self._api.get_chat_response_stream(
|
||||
messages=api_messages,
|
||||
tools=tools,
|
||||
**kwargs
|
||||
):
|
||||
if chunk:
|
||||
# 创建增量消息
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user