工具方法调用 1
This commit is contained in:
parent
750af43ff3
commit
3ca857a589
@ -39,6 +39,7 @@ DefaultAgentPrompt = f"""
|
|||||||
***单集创作 智能体*** 名称:`episode_create` 描述:用户确认`剧本圣经`后,交给该智能体来构建某一集的具体创作;注意该智能体仅负责单集的创作,因此该智能体的调度需要有你根据`剧本圣经`中的`核心大纲`来多次调用,逐步完成所有剧集的创作;对于某一集的具体修改直接交给该智能体;
|
***单集创作 智能体*** 名称:`episode_create` 描述:用户确认`剧本圣经`后,交给该智能体来构建某一集的具体创作;注意该智能体仅负责单集的创作,因此该智能体的调度需要有你根据`剧本圣经`中的`核心大纲`来多次调用,逐步完成所有剧集的创作;对于某一集的具体修改直接交给该智能体;
|
||||||
|
|
||||||
***注意:智能体调用后最终会返回再次请求到你,你需要根据智能体的处理结果来决定下一步***
|
***注意:智能体调用后最终会返回再次请求到你,你需要根据智能体的处理结果来决定下一步***
|
||||||
|
***注意:`智能体调用` 不是工具方法的使用,而是在返回数据中把agent属性指定为要调用的智能体名称***
|
||||||
|
|
||||||
# 工具使用
|
# 工具使用
|
||||||
上述智能体职责中提及的输出内容,都有对应的工具可供你调用进行查看;他们的查询工具名称分别对应如下:
|
上述智能体职责中提及的输出内容,都有对应的工具可供你调用进行查看;他们的查询工具名称分别对应如下:
|
||||||
@ -54,8 +55,10 @@ DefaultAgentPrompt = f"""
|
|||||||
未完成的集数: `QueryUnfinishedEpisodeCount`
|
未完成的集数: `QueryUnfinishedEpisodeCount`
|
||||||
已完成的集数: `QueryCompletedEpisodeCount`
|
已完成的集数: `QueryCompletedEpisodeCount`
|
||||||
|
|
||||||
|
***注意:工具使用是需要你调用工具方法的;但是大多数情况下,你不需要查询文本的具体内容,只需要查询存在与否即可***
|
||||||
|
|
||||||
***每次用户的输入都会携带当前`总任务的进度与任务状态`,注意查看并分析是否应该回复用户等待或提醒用户确认继续下一步***
|
***每次用户的输入都会携带当前`总任务的进度与任务状态`,注意查看并分析是否应该回复用户等待或提醒用户确认继续下一步***
|
||||||
# 总任务的进度与任务状态数据结构为 {{"step": "waiting_script", "status": "running", "from_type":"user", "reason": "waiting_script", "retry_count": 0}}
|
# 总任务的进度与任务状态数据结构为 {{"step": "waiting_script", "status": "running", "from_type":"user", "reason": "waiting_script", "retry_count": 0, "query_args":{{}}}}
|
||||||
|
|
||||||
step: 阶段名称
|
step: 阶段名称
|
||||||
wait_for_input: 等待用户提供原始剧本
|
wait_for_input: 等待用户提供原始剧本
|
||||||
@ -80,9 +83,12 @@ DefaultAgentPrompt = f"""
|
|||||||
|
|
||||||
"retry_count": 重试次数
|
"retry_count": 重试次数
|
||||||
|
|
||||||
|
"query_args": 用于调用工具方法的参数,可能包括的字段有:
|
||||||
|
"session_id": 会话ID,可用于查询`原始剧本`
|
||||||
|
|
||||||
# 职责
|
# 职责
|
||||||
分析用户输入与`总任务的进度与任务状态`,以下是几种情况的示例:
|
分析用户输入与`总任务的进度与任务状态`,以下是几种情况的示例:
|
||||||
1 `wait_for_input` 向用户问好,并介绍你作为“爆款短剧操盘手”的身份和专业工作流程,礼貌地请用户提供需要改编的原始剧本。如果用户没有提供原始剧本,你将持续友好地提醒,此时状态始终为waiting,直到获取原始剧本为止。
|
1 `wait_for_input` 向用户问好,并介绍你作为“爆款短剧操盘手”的身份和专业工作流程,礼貌地请用户提供需要改编的原始剧本。如果用户没有提供原始剧本,你将持续友好地提醒,此时状态始终为waiting,直到获取原始剧本为止。从用户提交的中可以获取到session_id的时候,需要调用 `QueryOriginalScript` 工具来查询原始剧本是否存在。
|
||||||
2 `script_analysis` 读取到原始剧本并从输入中分析出可以继续后进入,调用`原始剧本分析 智能体`继续后续工作;running时,礼貌回复用户并提醒用户任务真正进行中;completed代表任务完成,此时可等待用户反馈;直到跟用户确认可以进行下一步后再继续后续任务;
|
2 `script_analysis` 读取到原始剧本并从输入中分析出可以继续后进入,调用`原始剧本分析 智能体`继续后续工作;running时,礼貌回复用户并提醒用户任务真正进行中;completed代表任务完成,此时可等待用户反馈;直到跟用户确认可以进行下一步后再继续后续任务;
|
||||||
3 `strategic_planning` 根据`诊断与资产评估`的结果,调用`确立改编目标 智能体`,并返回结果。
|
3 `strategic_planning` 根据`诊断与资产评估`的结果,调用`确立改编目标 智能体`,并返回结果。
|
||||||
4 `build_bible` 根据`改编思路`的结果,调用`剧本圣经构建 智能体`,并返回结果。
|
4 `build_bible` 根据`改编思路`的结果,调用`剧本圣经构建 智能体`,并返回结果。
|
||||||
@ -98,7 +104,7 @@ DefaultAgentPrompt = f"""
|
|||||||
"agent":'',//分析后得出由哪个智能体继续任务,此处为智能体名称;如果需要继续与用户交互或仅需要回复用户则为空字符串
|
"agent":'',//分析后得出由哪个智能体继续任务,此处为智能体名称;如果需要继续与用户交互或仅需要回复用户则为空字符串
|
||||||
"message":'',//回复给用户的内容
|
"message":'',//回复给用户的内容
|
||||||
"retry_count":0,//重试次数
|
"retry_count":0,//重试次数
|
||||||
"node":'',//下一个节点名称
|
"node":'',//下一个节点名称,根据指定的agent名称,从取值范围列表中选择一个节点名称返回
|
||||||
}}
|
}}
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
|||||||
@ -2,7 +2,7 @@ import os
|
|||||||
import json
|
import json
|
||||||
import random
|
import random
|
||||||
import time
|
import time
|
||||||
from typing import Dict, List, Optional, Any, cast
|
from typing import Dict, Iterable, List, Optional, Any, cast
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from volcenginesdkarkruntime import Ark
|
from volcenginesdkarkruntime import Ark
|
||||||
from volcengine.visual.VisualService import VisualService
|
from volcengine.visual.VisualService import VisualService
|
||||||
@ -12,6 +12,8 @@ from volcengine.vod.models.request.request_vod_pb2 import VodUploadMediaRequest
|
|||||||
from volcengine.vod.models.request.request_vod_pb2 import VodUrlUploadRequest
|
from volcengine.vod.models.request.request_vod_pb2 import VodUrlUploadRequest
|
||||||
import base64
|
import base64
|
||||||
|
|
||||||
|
from volcenginesdkarkruntime.types.chat import ChatCompletionToolParam
|
||||||
|
|
||||||
from config import API_CONFIG
|
from config import API_CONFIG
|
||||||
|
|
||||||
class HuoshanAPI:
|
class HuoshanAPI:
|
||||||
@ -507,32 +509,23 @@ class HuoshanAPI:
|
|||||||
'error': f'删除异常: {str(e)}'
|
'error': f'删除异常: {str(e)}'
|
||||||
}
|
}
|
||||||
|
|
||||||
def get_chat_response(self, prompt: str, model: Optional[str] = None, system: Optional[str] = None, temperature: float = 0.6) -> str:
|
def get_chat_response(
|
||||||
"""
|
self,
|
||||||
获取聊天机器人回复 - 使用doubao_seed_1.6模型
|
messages: List[Dict],
|
||||||
:param prompt: 用户输入的文本
|
model: Optional[str] = None,
|
||||||
:param model: 模型名称,默认使用doubao_seed_1.6
|
tools: Optional[List[ChatCompletionToolParam]] = None,
|
||||||
:param system: 系统提示词
|
temperature: float = 0.6
|
||||||
:param temperature: 温度参数
|
) -> Dict[str, Any]:
|
||||||
:return: 机器人回复的文本
|
"""同步获取聊天响应,支持工具调用。"""
|
||||||
"""
|
|
||||||
try:
|
try:
|
||||||
if model is None:
|
if model is None:
|
||||||
model = self.doubao_seed_1_6_model_id
|
model = self.doubao_seed_1_6_model_id
|
||||||
messages:Any = []
|
if not messages or len(messages)==0:
|
||||||
if system:
|
raise ValueError(f"火山引擎API调用失败: 消息不能为空")
|
||||||
messages = [
|
completion = self.client.chat.completions.create(
|
||||||
{"role": "system", "content": system},
|
|
||||||
{"role": "user", "content": prompt}
|
|
||||||
]
|
|
||||||
else:
|
|
||||||
messages = [
|
|
||||||
{"role": "user", "content": prompt}
|
|
||||||
]
|
|
||||||
|
|
||||||
response = self.client.chat.completions.create(
|
|
||||||
model=model,
|
model=model,
|
||||||
messages=messages,
|
messages=messages, # type: ignore
|
||||||
max_tokens=16384, # 16K
|
max_tokens=16384, # 16K
|
||||||
temperature=temperature,
|
temperature=temperature,
|
||||||
timeout=600,
|
timeout=600,
|
||||||
@ -541,55 +534,39 @@ class HuoshanAPI:
|
|||||||
# "type": "enabled", # 使用深度思考能力
|
# "type": "enabled", # 使用深度思考能力
|
||||||
# "type": "auto", # 模型自行判断是否使用深度思考能力
|
# "type": "auto", # 模型自行判断是否使用深度思考能力
|
||||||
},
|
},
|
||||||
|
tools=tools # 传入 tools 参数
|
||||||
)
|
)
|
||||||
|
return completion.model_dump() # type: ignore 使用 model_dump() 转换为字典
|
||||||
return response.choices[0].message.content # pyright: ignore
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise Exception(f'Huoshan chat API调用异常: {str(e)}')
|
raise ValueError(f"火山引擎API调用失败: {str(e)}")
|
||||||
|
|
||||||
def get_chat_response_stream(self, prompt: str, model: Optional[str] = None, system: Optional[str] = None, temperature: float = 0.6):
|
def get_chat_response_stream(
|
||||||
"""
|
self,
|
||||||
获取聊天机器人的流式回复 - 使用doubao_seed_1.6模型
|
messages: List[Dict],
|
||||||
:param prompt: 用户输入的文本
|
model: Optional[str] = None,
|
||||||
:param model: 模型名称,默认使用doubao_seed_1.6
|
tools: Optional[List[ChatCompletionToolParam]] = None,
|
||||||
:param system: 系统提示词
|
temperature: float = 0.6
|
||||||
:param temperature: 温度参数
|
) -> Iterable[str]:
|
||||||
:return: 生成器,逐步返回机器人回复的文本
|
"""流式获取聊天响应,支持工具调用。"""
|
||||||
"""
|
|
||||||
try:
|
try:
|
||||||
if model is None:
|
if model is None:
|
||||||
model = self.doubao_seed_1_6_model_id
|
model = self.doubao_seed_1_6_model_id
|
||||||
|
|
||||||
messages:Any = []
|
completion = self.client.chat.completions.create(
|
||||||
if system:
|
|
||||||
messages = [
|
|
||||||
{"role": "system", "content": system},
|
|
||||||
{"role": "user", "content": prompt}
|
|
||||||
]
|
|
||||||
else:
|
|
||||||
messages = [
|
|
||||||
{"role": "user", "content": prompt}
|
|
||||||
]
|
|
||||||
|
|
||||||
response = self.client.chat.completions.create(
|
|
||||||
model=model,
|
model=model,
|
||||||
messages=messages,
|
messages=messages, # type: ignore
|
||||||
temperature=temperature,
|
temperature=temperature,
|
||||||
max_tokens=16384, # 16K
|
max_tokens=16384, # 16K
|
||||||
timeout=600,
|
timeout=600,
|
||||||
stream=True
|
stream=True,
|
||||||
|
tools=tools # 传入 tools 参数
|
||||||
)
|
)
|
||||||
|
for chunk in completion:
|
||||||
|
if chunk.choices and chunk.choices[0].delta.content is not None: # type: ignore
|
||||||
for chunk in response:
|
yield chunk.choices[0].delta.content # type: ignore
|
||||||
chunk_obj = cast(Any, chunk)
|
|
||||||
if hasattr(chunk_obj, 'choices') and chunk_obj.choices and len(chunk_obj.choices) > 0:
|
|
||||||
delta = chunk_obj.choices[0].delta
|
|
||||||
if hasattr(delta, 'content') and delta.content is not None:
|
|
||||||
yield delta.content
|
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise Exception(f'Huoshan chat stream API调用异常: {str(e)}')
|
raise ValueError(f"火山引擎API流式调用失败: {str(e)}")
|
||||||
|
|
||||||
def analyze_image(self, image_url: str, prompt: str = "请描述这张图片的内容", model: Optional[str] = None, detail: str = "high") -> Dict[str, Any]:
|
def analyze_image(self, image_url: str, prompt: str = "请描述这张图片的内容", model: Optional[str] = None, detail: str = "high") -> Dict[str, Any]:
|
||||||
"""
|
"""
|
||||||
|
|||||||
173
doc/test.txt
Normal file
173
doc/test.txt
Normal file
@ -0,0 +1,173 @@
|
|||||||
|
为了在你的代码中正确接入火山引擎的工具调用功能,你需要修改 `huoshan_langchain.py` 和 `huoshan.py` 这两个文件,以实现**从 LangChain 工具到火山引擎 API 工具定义的格式转换**,以及**解析和处理来自 API 的工具调用响应**。
|
||||||
|
|
||||||
|
下面是根据火山引擎官方文档和你的代码,我为你整理的完整修改方案。
|
||||||
|
|
||||||
|
-----
|
||||||
|
|
||||||
|
### 第一步:修改 `huoshan_langchain.py`
|
||||||
|
|
||||||
|
这个文件是 LangChain 的封装层,负责连接你的工作流和火山引擎的底层 API。你需要在这里实现 `bind_tools` 和 `_generate` 方法来处理工具调用。
|
||||||
|
|
||||||
|
1. **导入必要的类**:
|
||||||
|
需要添加 `ToolMessage` 和 `ToolCall`,它们是 LangChain 用于表示工具调用结果和工具调用的核心类。
|
||||||
|
|
||||||
|
```python
|
||||||
|
from typing import Any, Dict, Iterable, Iterator, List, Optional, Sequence
|
||||||
|
from langchain_core.callbacks.manager import CallbackManagerForLLMRun
|
||||||
|
from langchain_core.language_models.chat_models import BaseChatModel
|
||||||
|
from langchain_core.messages import BaseMessage, AIMessage, HumanMessage, SystemMessage, ToolMessage, ToolCall
|
||||||
|
from langchain_core.outputs import ChatGeneration, ChatResult
|
||||||
|
from langchain_core.tools import BaseTool
|
||||||
|
from langchain_core.runnables import Runnable
|
||||||
|
from langchain.pydantic_v1 import BaseModel
|
||||||
|
from pydantic import Field
|
||||||
|
from volcenginesdkarkruntime.types.chat import ChatCompletionToolParam
|
||||||
|
from api.huoshan import HuoshanAPI
|
||||||
|
import json
|
||||||
|
```
|
||||||
|
|
||||||
|
2. **实现 `bind_tools` 方法**:
|
||||||
|
这个方法是 LangChain 用于将工具定义传递给你的模型封装。在这里,你需要将 LangChain 的 `BaseTool` 对象转换为火山引擎 API 所需的 `ChatCompletionToolParam` 格式。
|
||||||
|
|
||||||
|
```python
|
||||||
|
class HuoshanChatModel(BaseChatModel):
|
||||||
|
# ... (其他代码不变) ...
|
||||||
|
|
||||||
|
def bind_tools(self, tools: Sequence[BaseTool], **kwargs: Any) -> Runnable:
|
||||||
|
"""将工具绑定到模型,并将其转换为火山引擎API所需的格式。"""
|
||||||
|
tool_definitions = []
|
||||||
|
for tool_item in tools:
|
||||||
|
tool_definitions.append(
|
||||||
|
ChatCompletionToolParam(
|
||||||
|
type="function",
|
||||||
|
function={
|
||||||
|
"name": tool_item.name,
|
||||||
|
"description": tool_item.description,
|
||||||
|
"parameters": tool_item.args_schema.schema() if isinstance(tool_item.args_schema, type(BaseModel)) else tool_item.args_schema
|
||||||
|
}
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
# 返回一个绑定了工具的新实例
|
||||||
|
# 这里我们使用_bind方法,它会返回一个新的Runnable实例
|
||||||
|
return self._bind(tools=tool_definitions, **kwargs)
|
||||||
|
```
|
||||||
|
|
||||||
|
3. **修改 `_convert_messages_to_prompt` 方法**:
|
||||||
|
这个方法需要能够处理 LangChain 的 `ToolMessage` 和 `AIMessage`,并将其转换为火山引擎 API 的消息格式。这对于工具调用的回填和最终回复至关重要。
|
||||||
|
|
||||||
|
```python
|
||||||
|
class HuoshanChatModel(BaseChatModel):
|
||||||
|
# ... (其他代码不变) ...
|
||||||
|
|
||||||
|
def _convert_messages_to_prompt(self, messages: List[BaseMessage]) -> List[Dict]:
|
||||||
|
"""将LangChain消息转换为火山引擎API所需的格式。"""
|
||||||
|
api_messages = []
|
||||||
|
for msg in messages:
|
||||||
|
if isinstance(msg, HumanMessage):
|
||||||
|
api_messages.append({"role": "user", "content": msg.content})
|
||||||
|
elif isinstance(msg, AIMessage):
|
||||||
|
if msg.tool_calls:
|
||||||
|
api_messages.append({
|
||||||
|
"role": "assistant",
|
||||||
|
"content": "",
|
||||||
|
"tool_calls": [{
|
||||||
|
"id": tc.id,
|
||||||
|
"type": "function",
|
||||||
|
"function": {
|
||||||
|
"name": tc.name,
|
||||||
|
"arguments": json.dumps(tc.args)
|
||||||
|
}
|
||||||
|
} for tc in msg.tool_calls]
|
||||||
|
})
|
||||||
|
else:
|
||||||
|
api_messages.append({"role": "assistant", "content": msg.content})
|
||||||
|
elif isinstance(msg, ToolMessage):
|
||||||
|
api_messages.append({
|
||||||
|
"role": "tool",
|
||||||
|
"content": msg.content,
|
||||||
|
"tool_call_id": msg.tool_call_id
|
||||||
|
})
|
||||||
|
return api_messages
|
||||||
|
```
|
||||||
|
|
||||||
|
4. **修改 `_generate` 方法**:
|
||||||
|
这个方法需要调用底层 API,并解析大模型返回的响应,以检查是否包含工具调用。
|
||||||
|
|
||||||
|
```python
|
||||||
|
class HuoshanChatModel(BaseChatModel):
|
||||||
|
# ... (其他代码不变) ...
|
||||||
|
|
||||||
|
def _generate(
|
||||||
|
self,
|
||||||
|
messages: List[BaseMessage],
|
||||||
|
stop: Optional[List[str]] = None,
|
||||||
|
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
||||||
|
**kwargs: Any,
|
||||||
|
) -> ChatResult:
|
||||||
|
if not self._api:
|
||||||
|
raise ValueError("HuoshanAPI未正确初始化")
|
||||||
|
|
||||||
|
api_messages = self._convert_messages_to_prompt(messages)
|
||||||
|
tools = kwargs.get("tools", [])
|
||||||
|
|
||||||
|
response_data = self._api.get_chat_response(messages=api_messages, tools=tools)
|
||||||
|
|
||||||
|
try:
|
||||||
|
message_from_api = response_data.get("choices", [{}])[0].get("message", {})
|
||||||
|
|
||||||
|
tool_calls = message_from_api.get("tool_calls", [])
|
||||||
|
if tool_calls:
|
||||||
|
lc_tool_calls = []
|
||||||
|
for tc in tool_calls:
|
||||||
|
lc_tool_calls.append(ToolCall(
|
||||||
|
name=tc["function"]["name"],
|
||||||
|
args=json.loads(tc["function"]["arguments"]),
|
||||||
|
id=tc.get("id", "")
|
||||||
|
))
|
||||||
|
message = AIMessage(content="", tool_calls=lc_tool_calls)
|
||||||
|
else:
|
||||||
|
content = message_from_api.get("content", "")
|
||||||
|
message = AIMessage(content=content)
|
||||||
|
|
||||||
|
generation = ChatGeneration(message=message)
|
||||||
|
return ChatResult(generations=[generation])
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
raise ValueError(f"处理火山引擎API响应失败: {str(e)}")
|
||||||
|
```
|
||||||
|
|
||||||
|
### 第二步:修改 `huoshan.py`
|
||||||
|
|
||||||
|
这个文件是底层 API 客户端,负责与火山引擎 API 进行通信。你需要修改 `get_chat_response` 方法,使其能够发送 `tools` 参数。
|
||||||
|
|
||||||
|
```python
|
||||||
|
# huoshan.py
|
||||||
|
|
||||||
|
from volcenginesdkarkruntime.types.chat import ChatCompletionToolParam
|
||||||
|
|
||||||
|
class HuoshanAPI:
|
||||||
|
# ... (其他代码不变) ...
|
||||||
|
|
||||||
|
def get_chat_response(
|
||||||
|
self,
|
||||||
|
messages: List[Dict],
|
||||||
|
stream: bool = False,
|
||||||
|
tools: Optional[List[ChatCompletionToolParam]] = None
|
||||||
|
) -> Dict[str, Any]:
|
||||||
|
"""同步获取聊天响应,支持工具调用。"""
|
||||||
|
client = Ark()
|
||||||
|
|
||||||
|
try:
|
||||||
|
completion = client.chat.completions.create(
|
||||||
|
model=self.doubao_seed_1_6_model_id,
|
||||||
|
messages=messages,
|
||||||
|
stream=stream,
|
||||||
|
tools=tools # 传入 tools 参数
|
||||||
|
)
|
||||||
|
return completion.model_dump() # 使用 model_dump() 转换为字典
|
||||||
|
except Exception as e:
|
||||||
|
raise ValueError(f"火山引擎API调用失败: {str(e)}")
|
||||||
|
```
|
||||||
|
|
||||||
|
完成以上修改后,你的 `HuoshanChatModel` 就会支持工具调用,并能与 LangGraph 的智能体框架无缝集成。
|
||||||
@ -21,6 +21,8 @@ import config
|
|||||||
|
|
||||||
from tools.database.mongo import client # type: ignore
|
from tools.database.mongo import client # type: ignore
|
||||||
from langgraph.checkpoint.mongodb import MongoDBSaver
|
from langgraph.checkpoint.mongodb import MongoDBSaver
|
||||||
|
# 工具方法
|
||||||
|
from tools.agent.queryDB import QueryOriginalScript
|
||||||
|
|
||||||
logger = get_logger(__name__)
|
logger = get_logger(__name__)
|
||||||
|
|
||||||
@ -32,7 +34,7 @@ def replace_value(old_val, new_val):
|
|||||||
# 状态类型定义
|
# 状态类型定义
|
||||||
class InputState(TypedDict):
|
class InputState(TypedDict):
|
||||||
"""工作流输入状态"""
|
"""工作流输入状态"""
|
||||||
input_data: Annotated[list[AnyMessage], operator.add]
|
messages: Annotated[list[AnyMessage], operator.add]
|
||||||
from_type: Annotated[str, replace_value]
|
from_type: Annotated[str, replace_value]
|
||||||
session_id: Annotated[str, replace_value]
|
session_id: Annotated[str, replace_value]
|
||||||
|
|
||||||
@ -55,7 +57,7 @@ class NodeInfo(TypedDict):
|
|||||||
class ScriptwriterState(TypedDict, total=False):
|
class ScriptwriterState(TypedDict, total=False):
|
||||||
"""智能编剧工作流整体状态"""
|
"""智能编剧工作流整体状态"""
|
||||||
# 输入数据
|
# 输入数据
|
||||||
input_data: Annotated[list[HumanMessage], operator.add]
|
messages: Annotated[list[AnyMessage], operator.add]
|
||||||
session_id: Annotated[str, replace_value]
|
session_id: Annotated[str, replace_value]
|
||||||
from_type: Annotated[str, replace_value] # 本次请求来着哪里 [user, agent]
|
from_type: Annotated[str, replace_value] # 本次请求来着哪里 [user, agent]
|
||||||
|
|
||||||
@ -113,7 +115,9 @@ class ScriptwriterGraph:
|
|||||||
print("创建智能体")
|
print("创建智能体")
|
||||||
# 调度智能体
|
# 调度智能体
|
||||||
self.schedulerAgent = SchedulerAgent(
|
self.schedulerAgent = SchedulerAgent(
|
||||||
tools=[],
|
tools=[
|
||||||
|
QueryOriginalScript,
|
||||||
|
],
|
||||||
SchedulerList=[
|
SchedulerList=[
|
||||||
{
|
{
|
||||||
"name": "scheduler_node",
|
"name": "scheduler_node",
|
||||||
@ -225,14 +229,33 @@ class ScriptwriterGraph:
|
|||||||
logger.error(f"构建工作流图失败: {e}")
|
logger.error(f"构建工作流图失败: {e}")
|
||||||
raise
|
raise
|
||||||
|
|
||||||
# --- 定义图中的节点 ---
|
# --- 定义图中的节点 ---
|
||||||
async def scheduler_node(self, state: ScriptwriterState)-> ScriptwriterState:
|
async def scheduler_node(self, state: ScriptwriterState)-> ScriptwriterState:
|
||||||
"""调度节点"""
|
"""调度节点"""
|
||||||
try:
|
try:
|
||||||
session_id = state.get("session_id", "")
|
session_id = state.get("session_id", "")
|
||||||
from_type = state.get("from_type", "")
|
from_type = state.get("from_type", "")
|
||||||
input_data = state.get("input_data", [])
|
messages = state.get("messages", [])
|
||||||
logger.info(f"调度节点 {session_id} 输入参数: {input_data} from_type:{from_type}")
|
|
||||||
|
workflow_step = state.get("workflow_step", "wait_for_input")
|
||||||
|
workflow_status = state.get("workflow_status", "waiting")
|
||||||
|
workflow_reason = state.get("workflow_reason", "")
|
||||||
|
workflow_retry_count = int(state.get("workflow_retry_count", 0))
|
||||||
|
# 添加参数进提示词
|
||||||
|
messages.append(HumanMessage(content=f"""
|
||||||
|
# 总任务的进度与任务状态:
|
||||||
|
{{
|
||||||
|
'query_args':{{
|
||||||
|
'session_id':'{session_id}',
|
||||||
|
}},
|
||||||
|
'step':'{workflow_step}',
|
||||||
|
'status':'{workflow_status}',
|
||||||
|
'from_type':'{from_type}',
|
||||||
|
'reason':'{workflow_reason}',
|
||||||
|
'retry_count':{workflow_retry_count},
|
||||||
|
}}
|
||||||
|
"""))
|
||||||
|
logger.info(f"调度节点 {session_id} 输入参数: {messages} from_type:{from_type}")
|
||||||
reslut = await self.schedulerAgent.ainvoke(state)
|
reslut = await self.schedulerAgent.ainvoke(state)
|
||||||
ai_message_str = reslut['messages'][-1].content
|
ai_message_str = reslut['messages'][-1].content
|
||||||
ai_message = json.loads(ai_message_str)
|
ai_message = json.loads(ai_message_str)
|
||||||
@ -259,6 +282,8 @@ class ScriptwriterGraph:
|
|||||||
"agent_message":return_message,
|
"agent_message":return_message,
|
||||||
}
|
}
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
import traceback
|
||||||
|
traceback.print_exc()
|
||||||
return {
|
return {
|
||||||
"next_node":'end_node',
|
"next_node":'end_node',
|
||||||
"agent_message": "执行失败",
|
"agent_message": "执行失败",
|
||||||
@ -317,11 +342,11 @@ class ScriptwriterGraph:
|
|||||||
"agent_message": state.get('agent_message', ''),
|
"agent_message": state.get('agent_message', ''),
|
||||||
}
|
}
|
||||||
|
|
||||||
async def run(self, session_id: str, input_data: list[AnyMessage], thread_id: str|None = None) -> OutputState:
|
async def run(self, session_id: str, messages: list[AnyMessage], thread_id: str|None = None) -> OutputState:
|
||||||
"""运行工作流
|
"""运行工作流
|
||||||
Args:
|
Args:
|
||||||
session_id: 会话ID
|
session_id: 会话ID
|
||||||
input_data: 输入数据
|
messages: 输入数据
|
||||||
thread_id: 线程ID
|
thread_id: 线程ID
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
@ -334,7 +359,7 @@ class ScriptwriterGraph:
|
|||||||
config:RunnableConfig = {"configurable": {"thread_id": thread_id}}
|
config:RunnableConfig = {"configurable": {"thread_id": thread_id}}
|
||||||
# 初始化状态
|
# 初始化状态
|
||||||
initial_state: InputState = {
|
initial_state: InputState = {
|
||||||
'input_data': input_data,
|
'messages': messages,
|
||||||
'session_id': session_id,
|
'session_id': session_id,
|
||||||
'from_type': 'user',
|
'from_type': 'user',
|
||||||
}
|
}
|
||||||
@ -456,7 +481,7 @@ if __name__ == "__main__":
|
|||||||
session_id = "68c2c2915e5746343301ef71"
|
session_id = "68c2c2915e5746343301ef71"
|
||||||
result = await graph.run(
|
result = await graph.run(
|
||||||
session_id,
|
session_id,
|
||||||
[HumanMessage(content="你好编剧,我想写小说!")],
|
[HumanMessage(content="老师,我写好剧本了,您看看!帮我分析分析把!")],
|
||||||
session_id
|
session_id
|
||||||
)
|
)
|
||||||
print(f"最终结果: {result}")
|
print(f"最终结果: {result}")
|
||||||
|
|||||||
3
tools/agent/__init__.py
Normal file
3
tools/agent/__init__.py
Normal file
@ -0,0 +1,3 @@
|
|||||||
|
"""
|
||||||
|
智能体工具模块
|
||||||
|
"""
|
||||||
38
tools/agent/queryDB.py
Normal file
38
tools/agent/queryDB.py
Normal file
@ -0,0 +1,38 @@
|
|||||||
|
from bson import ObjectId
|
||||||
|
from tools.database.mongo import mainDB
|
||||||
|
from langchain.tools import tool
|
||||||
|
|
||||||
|
@tool
|
||||||
|
def QueryOriginalScript(session_id: str, only_exist: bool = False):
|
||||||
|
"""
|
||||||
|
查询原始剧本内容或是否存在
|
||||||
|
Args:
|
||||||
|
session_id: 会话id
|
||||||
|
only_exist: 是否只查询存在的剧本
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dict: 返回一个包含以下字段的字典:
|
||||||
|
original_script (str): 原始剧本内容。仅当 only_exist 为 False 时返回该字段。
|
||||||
|
exist (bool): 原始剧本内容是否存在。
|
||||||
|
|
||||||
|
"""
|
||||||
|
# c = mainDB.agent_writer_session.count_documents({})
|
||||||
|
# print(f"查询到的原始剧本session_id: {session_id}, only_exist: {only_exist} count:{c}")
|
||||||
|
if only_exist:
|
||||||
|
script = mainDB.agent_writer_session.find_one({"_id": ObjectId(session_id), "original_script": {"$exists": True, "$ne": ""}})
|
||||||
|
# print(f"exist: {script}")
|
||||||
|
return {
|
||||||
|
"original_script": "",
|
||||||
|
"exist": script is not None,
|
||||||
|
}
|
||||||
|
else:
|
||||||
|
script = mainDB.agent_writer_session.find_one({"_id": ObjectId(session_id)}, {"original_script": 1})
|
||||||
|
original_script = ""
|
||||||
|
if script:
|
||||||
|
original_script = script["original_script"] or ''
|
||||||
|
print(f"查询到的原始剧本字符长度: {len(original_script)}")
|
||||||
|
return {
|
||||||
|
"original_script": original_script,
|
||||||
|
"exist": original_script != '',
|
||||||
|
}
|
||||||
|
|
||||||
@ -1,11 +1,32 @@
|
|||||||
from typing import Any, Dict, Iterator, List, Optional
|
from typing import Any, Dict, Iterable, Iterator, List, Optional, Sequence
|
||||||
from langchain_core.callbacks.manager import CallbackManagerForLLMRun
|
from langchain_core.callbacks.manager import CallbackManagerForLLMRun
|
||||||
from langchain_core.language_models.chat_models import BaseChatModel
|
from langchain_core.language_models.chat_models import BaseChatModel
|
||||||
from langchain_core.messages import BaseMessage, AIMessage, HumanMessage, SystemMessage
|
from langchain_core.messages import BaseMessage, AIMessage, HumanMessage, SystemMessage, ToolMessage, ToolCall
|
||||||
from langchain_core.outputs import ChatGeneration, ChatResult
|
from langchain_core.outputs import ChatGeneration, ChatResult
|
||||||
|
from langchain_core.tools import BaseTool
|
||||||
|
from langchain_core.runnables import Runnable
|
||||||
from pydantic import Field
|
from pydantic import Field
|
||||||
|
import json
|
||||||
|
import copy
|
||||||
|
from volcenginesdkarkruntime.types.chat import ChatCompletionToolParam
|
||||||
|
from volcenginesdkarkruntime.types.shared_params.function_definition import FunctionDefinition
|
||||||
|
from volcenginesdkarkruntime.types.shared_params.function_parameters import FunctionParameters
|
||||||
from api.huoshan import HuoshanAPI # 导入你现有的API类
|
from api.huoshan import HuoshanAPI # 导入你现有的API类
|
||||||
|
|
||||||
|
def _convert_dict_type(data: Any) -> Any:
|
||||||
|
"""递归地将字典中的"Dict"字符串类型转换为"dict"."""
|
||||||
|
if isinstance(data, dict):
|
||||||
|
new_data = {}
|
||||||
|
for k, v in data.items():
|
||||||
|
if isinstance(v, str) and v == "Dict":
|
||||||
|
new_data[k] = "dict"
|
||||||
|
else:
|
||||||
|
new_data[k] = _convert_dict_type(v)
|
||||||
|
return new_data
|
||||||
|
elif isinstance(data, list):
|
||||||
|
return [_convert_dict_type(item) for item in data]
|
||||||
|
else:
|
||||||
|
return data
|
||||||
|
|
||||||
class HuoshanChatModel(BaseChatModel):
|
class HuoshanChatModel(BaseChatModel):
|
||||||
"""火山引擎聊天模型的LangChain封装"""
|
"""火山引擎聊天模型的LangChain封装"""
|
||||||
@ -28,25 +49,137 @@ class HuoshanChatModel(BaseChatModel):
|
|||||||
"""返回LLM类型标识"""
|
"""返回LLM类型标识"""
|
||||||
return "huoshan_chat"
|
return "huoshan_chat"
|
||||||
|
|
||||||
def _convert_messages_to_prompt(self, messages: List[BaseMessage]) -> tuple[str, str]:
|
# def _convert_messages_to_prompt(self, messages: List[BaseMessage]) -> tuple[str, str]:
|
||||||
"""将LangChain消息格式转换为API所需的prompt和system格式"""
|
# """将LangChain消息格式转换为API所需的prompt和system格式"""
|
||||||
system_message = ""
|
# system_message = ""
|
||||||
user_messages = []
|
# user_messages = []
|
||||||
|
|
||||||
for message in messages:
|
# for message in messages:
|
||||||
if isinstance(message, SystemMessage):
|
# if isinstance(message, SystemMessage):
|
||||||
system_message = message.content or ""
|
# system_message = message.content or ""
|
||||||
elif isinstance(message, HumanMessage):
|
# elif isinstance(message, HumanMessage):
|
||||||
user_messages.append(message.content)
|
# user_messages.append(message.content)
|
||||||
elif isinstance(message, AIMessage):
|
# elif isinstance(message, SystemMessage):
|
||||||
# 如果需要支持多轮对话,可以在这里处理
|
# # 如果需要支持多轮对话,可以在这里处理
|
||||||
pass
|
# pass
|
||||||
|
|
||||||
# 合并用户消息
|
# # 合并用户消息
|
||||||
prompt = "\n".join(user_messages) if user_messages else ""
|
# prompt = "\n".join(user_messages) if user_messages else ""
|
||||||
|
|
||||||
return prompt, str(system_message)
|
# return prompt, str(system_message)
|
||||||
|
|
||||||
|
|
||||||
|
def bind_tools(self, tools: Sequence[BaseTool], **kwargs: Any) -> Runnable:
|
||||||
|
"""将工具绑定到模型,并将其转换为火山引擎API所需的格式。"""
|
||||||
|
# 1. 创建一个当前模型的副本
|
||||||
|
# new_model = copy.copy(self)
|
||||||
|
tool_definitions = []
|
||||||
|
for tool_item in tools:
|
||||||
|
# 1. 获取工具的参数 JSON Schema
|
||||||
|
if hasattr(tool_item.args_schema, 'model_json_schema'):
|
||||||
|
# 如果是 Pydantic 模型,调用其方法获取 JSON Schema
|
||||||
|
parameters_schema = tool_item.args_schema.model_json_schema() # type: ignore
|
||||||
|
else:
|
||||||
|
parameters_schema = _convert_dict_type(parameters_schema)
|
||||||
|
|
||||||
|
# 2. 从字典中提取并修正参数
|
||||||
|
# SDK要求 'type' 为 'object','properties' 和 'required' 字段为必填
|
||||||
|
schema_type = parameters_schema.get("type", "object")
|
||||||
|
schema_properties = parameters_schema.get("properties", {})
|
||||||
|
schema_required = parameters_schema.get("required", [])
|
||||||
|
|
||||||
|
# 3. 将修正后的字典作为参数,直接创建 FunctionParameters 对象
|
||||||
|
# 这是最安全的方式,因为它绕过了之前可能存在的类型转换问题
|
||||||
|
function_parameters:FunctionParameters = {
|
||||||
|
"type":schema_type,
|
||||||
|
"properties":schema_properties,
|
||||||
|
"required":schema_required
|
||||||
|
}
|
||||||
|
|
||||||
|
# 4. 将 FunctionParameters 对象封装成 FunctionDefinition
|
||||||
|
function_def:FunctionDefinition = {
|
||||||
|
"name":tool_item.name,
|
||||||
|
"description":tool_item.description,
|
||||||
|
"parameters":function_parameters
|
||||||
|
}
|
||||||
|
|
||||||
|
# 5. 将 FunctionDefinition 对象封装成 ChatCompletionToolParam
|
||||||
|
tool_definitions.append(
|
||||||
|
ChatCompletionToolParam(
|
||||||
|
type="function",
|
||||||
|
function=function_def
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
# 返回一个绑定了工具的新实例
|
||||||
|
# 这里我们使用_bind方法,它会返回一个新的Runnable实例
|
||||||
|
# return self._bind(tools=tool_definitions, **kwargs) # type: ignore
|
||||||
|
# return self.bind(tools=tool_definitions, **kwargs)
|
||||||
|
return super().bind(tools=tool_definitions, **kwargs)
|
||||||
|
|
||||||
|
# new_model._bound_tools = tool_definitions
|
||||||
|
|
||||||
|
# 3. 返回新的、已绑定工具的模型实例
|
||||||
|
# return new_model
|
||||||
|
|
||||||
|
def _convert_messages_to_prompt(self, messages: List[BaseMessage]) -> List[Dict]:
|
||||||
|
"""将LangChain消息转换为火山引擎API所需的格式。"""
|
||||||
|
# print(f" 原始 messages: {messages}")
|
||||||
|
|
||||||
|
# 使用字典来存储每个tool_call_id的最新消息
|
||||||
|
last_tool_messages = {}
|
||||||
|
# 临时存储其他消息(非ToolMessage)
|
||||||
|
other_messages = []
|
||||||
|
for msg in messages:
|
||||||
|
if isinstance(msg, ToolMessage):
|
||||||
|
# 将最新的ToolMessage存入字典,这会覆盖之前的旧消息
|
||||||
|
last_tool_messages[msg.tool_call_id] = msg
|
||||||
|
else:
|
||||||
|
other_messages.append(msg)
|
||||||
|
# 从后向前遍历,只保留每个tool_call_id的第一个(即最新的)消息
|
||||||
|
deduplicated_messages = []
|
||||||
|
seen_tool_call_ids = set()
|
||||||
|
|
||||||
|
for msg in reversed(messages):
|
||||||
|
# 如果是ToolMessage,并且tool_call_id已经存在,则跳过
|
||||||
|
if isinstance(msg, ToolMessage):
|
||||||
|
if msg.tool_call_id not in seen_tool_call_ids:
|
||||||
|
seen_tool_call_ids.add(msg.tool_call_id)
|
||||||
|
deduplicated_messages.append(msg)
|
||||||
|
# 其他类型的消息直接添加
|
||||||
|
else:
|
||||||
|
deduplicated_messages.append(msg)
|
||||||
|
|
||||||
|
# 将消息列表还原为原始时间顺序
|
||||||
|
messages_to_convert = list(reversed(deduplicated_messages))
|
||||||
|
|
||||||
|
# 2. 将去重后的消息列表转换为API所需的格式
|
||||||
|
api_messages = []
|
||||||
|
for msg in messages_to_convert:
|
||||||
|
if isinstance(msg, HumanMessage):
|
||||||
|
api_messages.append({"role": "user", "content": msg.content})
|
||||||
|
elif isinstance(msg, SystemMessage):
|
||||||
|
api_messages.append({"role": "system", "content": msg.content})
|
||||||
|
elif isinstance(msg, ToolMessage):
|
||||||
|
api_messages.append({
|
||||||
|
"role": "tool",
|
||||||
|
"content": msg.content,
|
||||||
|
"tool_call_id": msg.tool_call_id
|
||||||
|
})
|
||||||
|
# api_messages = []
|
||||||
|
# for msg in messages:
|
||||||
|
# if isinstance(msg, HumanMessage):
|
||||||
|
# api_messages.append({"role": "user", "content": msg.content})
|
||||||
|
# elif isinstance(msg, SystemMessage):
|
||||||
|
# api_messages.append({"role": "system", "content": msg.content})
|
||||||
|
# elif isinstance(msg, ToolMessage):
|
||||||
|
# api_messages.append({
|
||||||
|
# "role": "tool",
|
||||||
|
# "content": msg.content,
|
||||||
|
# "tool_call_id": msg.tool_call_id
|
||||||
|
# })
|
||||||
|
return api_messages
|
||||||
|
|
||||||
def _generate(
|
def _generate(
|
||||||
self,
|
self,
|
||||||
messages: List[BaseMessage],
|
messages: List[BaseMessage],
|
||||||
@ -54,38 +187,45 @@ class HuoshanChatModel(BaseChatModel):
|
|||||||
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> ChatResult:
|
) -> ChatResult:
|
||||||
"""生成聊天回复"""
|
|
||||||
if not self._api:
|
if not self._api:
|
||||||
raise ValueError("HuoshanAPI未正确初始化")
|
raise ValueError("HuoshanAPI未正确初始化")
|
||||||
|
|
||||||
# 转换消息格式
|
api_messages = self._convert_messages_to_prompt(messages)
|
||||||
prompt, system = self._convert_messages_to_prompt(messages)
|
tools = kwargs.get("tools", [])
|
||||||
|
|
||||||
# 合并参数
|
print(f" 提交给豆包的 messages数组长度: \n {len(messages)} \n tools: {tools}")
|
||||||
generation_kwargs = {
|
|
||||||
"model": kwargs.get("model", self.model_name),
|
|
||||||
"temperature": kwargs.get("temperature", self.temperature),
|
|
||||||
"system": system
|
|
||||||
}
|
|
||||||
|
|
||||||
|
response_data = self._api.get_chat_response(messages=api_messages, tools=tools)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# 调用你的API
|
res_choices = response_data.get("choices", [{}])[0]
|
||||||
response_text = self._api.get_chat_response(
|
finish_reason = res_choices.get("finish_reason", "")
|
||||||
prompt=prompt,
|
message_from_api = res_choices.get("message", {})
|
||||||
**generation_kwargs
|
tool_calls = message_from_api.get("tool_calls", [])
|
||||||
)
|
print(f" 豆包返回的 finish_reason: {finish_reason} \n tool_calls: {tool_calls} \n")
|
||||||
|
print(f" 豆包返回的 message: {message_from_api.get('content', '')}")
|
||||||
# 创建AI消息
|
if finish_reason == "tool_calls" and tool_calls:
|
||||||
message = AIMessage(content=response_text)
|
lc_tool_calls = []
|
||||||
|
for tc in tool_calls:
|
||||||
# 创建生成结果
|
function_dict = tc.get("function", {})
|
||||||
|
lc_tool_calls.append(ToolCall(
|
||||||
|
name=function_dict.get("name", ""),
|
||||||
|
args=json.loads(function_dict.get("arguments", "{}")),
|
||||||
|
id=tc.get("id", "")
|
||||||
|
))
|
||||||
|
message = AIMessage(content="", tool_calls=lc_tool_calls)
|
||||||
|
else:
|
||||||
|
content = message_from_api.get("content", "")
|
||||||
|
message = AIMessage(content=content)
|
||||||
|
|
||||||
generation = ChatGeneration(message=message)
|
generation = ChatGeneration(message=message)
|
||||||
|
|
||||||
return ChatResult(generations=[generation])
|
return ChatResult(generations=[generation])
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise ValueError(f"调用火山引擎API失败: {str(e)}")
|
import traceback
|
||||||
|
traceback.print_exc()
|
||||||
|
raise ValueError(f"处理火山引擎API响应失败: {str(e)}")
|
||||||
|
|
||||||
def _stream(
|
def _stream(
|
||||||
self,
|
self,
|
||||||
messages: List[BaseMessage],
|
messages: List[BaseMessage],
|
||||||
@ -98,20 +238,15 @@ class HuoshanChatModel(BaseChatModel):
|
|||||||
raise ValueError("HuoshanAPI未正确初始化")
|
raise ValueError("HuoshanAPI未正确初始化")
|
||||||
|
|
||||||
# 转换消息格式
|
# 转换消息格式
|
||||||
prompt, system = self._convert_messages_to_prompt(messages)
|
api_messages = self._convert_messages_to_prompt(messages)
|
||||||
|
tools = kwargs.get("tools", [])
|
||||||
# 合并参数
|
|
||||||
generation_kwargs = {
|
|
||||||
"model": kwargs.get("model", self.model_name),
|
|
||||||
"temperature": kwargs.get("temperature", self.temperature),
|
|
||||||
"system": system
|
|
||||||
}
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# 调用流式API
|
# 调用流式API
|
||||||
for chunk in self._api.get_chat_response_stream(
|
for chunk in self._api.get_chat_response_stream(
|
||||||
prompt=prompt,
|
messages=api_messages,
|
||||||
**generation_kwargs
|
tools=tools,
|
||||||
|
**kwargs
|
||||||
):
|
):
|
||||||
if chunk:
|
if chunk:
|
||||||
# 创建增量消息
|
# 创建增量消息
|
||||||
@ -134,4 +269,4 @@ class HuoshanChatModel(BaseChatModel):
|
|||||||
"model_name": self.model_name,
|
"model_name": self.model_name,
|
||||||
"temperature": self.temperature,
|
"temperature": self.temperature,
|
||||||
"max_tokens": self.max_tokens,
|
"max_tokens": self.max_tokens,
|
||||||
}
|
}
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user