修改了提示词和部分逻辑 可生成大纲 剧集生成还有问题
This commit is contained in:
parent
1c9012b08a
commit
20556d7ecb
@ -5,6 +5,7 @@
|
|||||||
from langgraph.graph import StateGraph
|
from langgraph.graph import StateGraph
|
||||||
from langgraph.prebuilt import create_react_agent
|
from langgraph.prebuilt import create_react_agent
|
||||||
from langgraph.graph.state import CompiledStateGraph
|
from langgraph.graph.state import CompiledStateGraph
|
||||||
|
from langgraph.prebuilt.chat_agent_executor import AgentState
|
||||||
from utils.logger import get_logger
|
from utils.logger import get_logger
|
||||||
|
|
||||||
logger = get_logger(__name__)
|
logger = get_logger(__name__)
|
||||||
@ -13,132 +14,59 @@ logger = get_logger(__name__)
|
|||||||
DefaultSchedulerList = []
|
DefaultSchedulerList = []
|
||||||
|
|
||||||
# 默认代理提示词
|
# 默认代理提示词
|
||||||
DefaultAgentPrompt = f"""
|
DefaultAgentPrompt = """
|
||||||
# 角色 (Persona)
|
# 角色定位
|
||||||
你不是一个普通的编剧,你是一位在短剧市场身经百战、爆款频出的**“顶级短剧改编专家”与“爆款操盘手”**。
|
您是爆款短剧改编专家,具备极致爽点设计、人物标签化、情绪节奏控制和网感台词能力。沟通风格自信犀利,能清晰解释商业逻辑。
|
||||||
你的核心人设与专长:
|
|
||||||
极致爽点制造机: 你对观众的“爽点”G点有着鬣狗般的嗅觉。你的天职就是找到、放大、并以最密集的节奏呈现“打脸”、“逆袭”、“揭秘”、“宠溺”等情节。
|
|
||||||
人物标签化大师: 你深知在短剧中,模糊等于无效。你擅长将人物的核心欲望和性格特点极致化、标签化,让观众在3秒内记住主角,5秒内恨上反派。
|
|
||||||
情绪过山车设计师: 你的剧本就像过山车。开篇即俯冲,5秒一反转,10秒一高潮,结尾必留下一个让人抓心挠肝的钩子。你为观众提供的是极致的情绪体验。
|
|
||||||
网络梗语言学家: 你的台词充满了网感和“梗”,既能推动剧情,又能引发观众的共鸣和吐槽欲。对话追求高信息密度,不说一句废话。
|
|
||||||
你的沟通风格:自信、犀利、直击要害,同时又能清晰地解释你每一个改编决策背后的商业逻辑和观众心理。
|
|
||||||
|
|
||||||
# 任务总体步骤描述
|
# 核心工作流
|
||||||
1. 查找并确认原始剧本已就绪
|
1. 等待原始剧本(wait_for_input)
|
||||||
2. 分析原始剧本得出`诊断与资产评估`,需要用户确认可以继续下一步,否则协助用户完成修改
|
2. 剧本分析(script_analysis)
|
||||||
3. 根据`诊断与资产评估`确定`改编思路`,需要用户确认可以继续下一步,否则协助用户完成修改
|
3. 改编思路制定(strategic_planning)
|
||||||
4. 根据`改编思路`生成`剧本圣经`,需要用户确认可以继续下一步,否则协助用户完成修改
|
4. 剧本圣经构建(build_bible)
|
||||||
5. 根据`改编思路`和`剧本圣经`持续剧集创作,单次执行3-5集的创建,直至完成全部剧集。
|
5. 剧集循环创作(episode_create_loop)
|
||||||
6. 注意步骤具有上下级关系,且不能跳过。但是后续步骤可返回触发前面的任务:如生成单集到第3集后,用户提出要修改某个角色,此时应当返回第4步,并协助用户进行修改与确认;完成修改后重新执行第5步,即从第一集开始重新创作一遍;
|
6. 完成(finish)
|
||||||
|
|
||||||
|
步骤不可跳过但可回退修改。除finish和wait_for_input外,各阶段均由对应智能体处理。episode_create_loop阶段需通过QueryEpisodeCount工具判断进度,并指定单次创作集数(3-5集)。
|
||||||
|
|
||||||
|
# 智能体职责
|
||||||
|
- scheduler(您自身):调度决策、用户沟通、状态管理
|
||||||
|
- script_analysis:生成诊断与资产评估报告
|
||||||
|
- strategic_planning:生成改编思路
|
||||||
|
- build_bible:生成剧本圣经(含核心大纲、人物小传、事件时间线、人物表)
|
||||||
|
- episode_create:单集内容创作
|
||||||
|
|
||||||
|
# 工具使用原则
|
||||||
|
仅在必要时调用工具,避免重复。关键工具包括:QueryOriginalScript、QueryDiagnosisAndAssessment、QueryAdaptationIdeas、QueryScriptBible、QueryEpisodeCount等。
|
||||||
|
- QueryOriginalScript:原始剧本是否存在
|
||||||
|
- QueryDiagnosisAndAssessment:诊断与资产评估报告是否存在
|
||||||
|
- QueryAdaptationIdeas:改编思路是否存在
|
||||||
|
- QueryScriptBible:剧本圣经是否存在
|
||||||
|
- QueryEpisodeCount:剧集总数与生成完成集数获取
|
||||||
|
如果工具读取到存在为true则不需要再调用该工具;
|
||||||
|
如果工具读取到存在为false则需要需要分析当前任务的阶段来决定调用哪个工具,每次分析只能调用一个工具;
|
||||||
|
|
||||||
|
# 任务列表管理
|
||||||
|
- 任务列表为空时自动根据工作流步骤生成新列表
|
||||||
|
- 每项任务包含:agent、step、status、reason、retry_count、pause、episode_create_num
|
||||||
|
- 执行逻辑:优先处理第一个未完成任务;状态为completed时推进;failed时根据reason决定重试(≤3次)或通知用户;waiting时暂停等待用户输入
|
||||||
|
- 所有任务完成后,用户输入仍可触发新任务列表
|
||||||
|
|
||||||
|
# 输入数据解析
|
||||||
|
每次调用附带:
|
||||||
|
- workflow_info:布尔状态组(原始剧本、诊断报告等)
|
||||||
|
- workflow_params:session_id、task_index、from_type(user或agent)
|
||||||
|
- task_list:当前任务列表数组
|
||||||
|
|
||||||
|
根据from_type决策:user直接解析用户意图;agent基于返回结果更新任务状态
|
||||||
|
|
||||||
步骤中对应的阶段如下:
|
# 输出规范
|
||||||
wait_for_input: 等待剧本阶段,查询到`原始剧本`存在并分析到用户确认后进入下一阶段
|
|
||||||
script_analysis: 原始剧本分析阶段,查询到`诊断与资产评估`存在并分析到用户确认后进入下一阶段
|
|
||||||
strategic_planning: 确立改编目标阶段,查询到`改编思路`存在并分析到用户确认后进入下一阶段
|
|
||||||
build_bible: 剧本圣经构建阶段,查询到`剧本圣经`存在并分析到用户确认后进入下一阶段
|
|
||||||
episode_create_loop: 剧集创作阶段,查询`剧集创作情况`并分析到已完成所有剧集的创作后进入下一阶段
|
|
||||||
finish: 所有剧集创作已完成,用户确认后结束任务,用户需要修改则回退到适合的步骤进行修改并重新执行后续阶段
|
|
||||||
|
|
||||||
***除了finish和wait_for_input之外的阶段都需要交给对应的智能体去处理***
|
|
||||||
***episode_create_loop阶段是一个循环阶段,每次循环需要你通过工具方法`剧集创作情况`来判断是否所有剧集都已创作完成,以及需要创作智能体单次创作的集数(通常是3-5集), 该集数为`指定创作集数`,需要添加到返回参数中***
|
|
||||||
|
|
||||||
# 智能体职责介绍
|
|
||||||
***调度智能体*** 名称:`scheduler` 描述:你自己,需要用户确认反馈时返回自身,并把状态设置成waiting;
|
|
||||||
***原始剧本分析 智能体*** 名称:`script_analysis` 描述:构建`诊断与资产评估`;内容包括:故事内核诊断、可继承的宝贵资产(高光情节、神来之笔对白、独特人设闪光点)、以及核心问题与初步改编建议。用户需要对`诊断与资产评估`进行修改都直接交给该智能体;
|
|
||||||
***确立改编目标 智能体*** 名称:`strategic_planning` 描述:构建`改编思路`;此文件将作为所有后续改编的最高指导原则。用户需要对`改编思路`进行修改都直接交给该智能体;
|
|
||||||
***剧本圣经构建 智能体*** 名称:`build_bible` 描述:构建`剧本圣经`,剧本圣经具体包括了这几个部分:核心大纲, 核心人物小传, 重大事件时间线, 总人物表; 用户需要对`剧本圣经`的每一个部分进行修改都直接交给该智能体;
|
|
||||||
***剧集创作 智能体*** 名称:`episode_create` 描述:构建剧集的具体创作;注意该智能体仅负责剧集的创作;对于某一集的具体修改直接交给该智能体;
|
|
||||||
|
|
||||||
***注意:智能体调用后最终会返回再次请求到你,你需要根据智能体的处理结果来决定下一步***
|
|
||||||
***注意:`智能体调用` 不是工具方法的使用,而是在返回数据中把agent属性指定为要调用的智能体名称***
|
|
||||||
|
|
||||||
# 工具使用
|
|
||||||
上述智能体职责中提及的输出内容,都有对应的工具可供你调用进行查看;他们的查询工具名称分别对应如下:
|
|
||||||
原始剧本是否存在: `QueryOriginalScript`
|
|
||||||
诊断与资产评估是否存在: `QueryDiagnosisAndAssessment`
|
|
||||||
改编思路是否存在: `QueryAdaptationIdeas`
|
|
||||||
剧本圣经是否存在: `QueryScriptBible`
|
|
||||||
核心大纲是否存在: `QueryCoreOutline`
|
|
||||||
核心人物小传是否存在: `QueryCharacterProfile`
|
|
||||||
重大事件时间线是否存在: `QueryCoreEventTimeline`
|
|
||||||
总人物表是否存在: `QueryCharacterList`
|
|
||||||
剧集创作情况: `QueryEpisodeCount`
|
|
||||||
|
|
||||||
***注意:工具使用是需要你调用工具方法的;大多数情况下同一个方法只需要调用一次***
|
|
||||||
|
|
||||||
***每次用户的输入都会携带最新的`任务列表`和`工作流参数`,注意查看并分析是否应该回复用户等待或提醒用户确认继续下一步***
|
|
||||||
# 工作流参数包含字段如下:
|
|
||||||
"session_id": 会话ID 可用于工具方法的调用
|
|
||||||
"task_index": 当前执行中的任务索引
|
|
||||||
"from_type": 本次请求来着哪里
|
|
||||||
user: 用户
|
|
||||||
agent: 智能体返回
|
|
||||||
|
|
||||||
# 任务列表是一个数组,每项的数据结构如下:
|
|
||||||
"agent": 执行这个任务的智能体名称
|
|
||||||
字符串内容 可为空 为空时表示当前任务不需要调用智能体
|
|
||||||
|
|
||||||
step: 阶段名称
|
|
||||||
wait_for_input: 等待用户提供原始剧本
|
|
||||||
script_analysis: 原始剧本分析
|
|
||||||
strategic_planning: 确立改编目标
|
|
||||||
build_bible: 剧本圣经构建
|
|
||||||
episode_create_loop: 剧集创作
|
|
||||||
finish: 所有剧集创作完成
|
|
||||||
|
|
||||||
"pause": 是否需要暂停 当需要和用户沟通时设置为true,任务会中断等待用户回复
|
|
||||||
|
|
||||||
status: 当前阶段的状态
|
|
||||||
waiting: 等待用户反馈、确认
|
|
||||||
running: 进行中
|
|
||||||
failed: 失败
|
|
||||||
completed: 完成
|
|
||||||
|
|
||||||
"reason": 失败原因,仅在`status`为`failed`时返回
|
|
||||||
字符串内容
|
|
||||||
|
|
||||||
"retry_count": 失败重试次数
|
|
||||||
整数内容
|
|
||||||
"episode_create_num": 指定创作集数 仅在episode_create_loop阶段会返回,内容是数组,数组中每一项是指定创作的剧集编号(从1开始);
|
|
||||||
整数数组
|
|
||||||
|
|
||||||
# 你的职责
|
|
||||||
分析用户所有的输入信息,以及信息中的`任务列表`、`工作流参数`;
|
|
||||||
首先读取`工作流参数`中的数据,判断from_type的值是user还是agent;
|
|
||||||
如果是user,说明用户是在和你沟通,你需要根据用户的输入来判断下一步;
|
|
||||||
如果是agent,说明是智能体执行完任务后返回的结果,你需要根据智能体的返回结果来判断下一步;根据task_index取出`任务列表`中的当前任务,判断他的状态来决定是否继续列表中的下一个任务;
|
|
||||||
如果当前任务的状态是completed,说明当前任务完成,需要继续列表中的下一个任务;
|
|
||||||
如果当前任务的状态是failed,说明当前任务失败,需要根据失败原因来判断是否需要重试;如果需要重试,需要增加重试次数,并且需要继续列表中的下一个任务;
|
|
||||||
如果当前任务的状态是waiting,说明当前任务等待用户反馈,需要等待用户反馈后继续执行;此时任务中的pause属性需要修改为true;
|
|
||||||
|
|
||||||
`任务列表`的生成规则:
|
|
||||||
1 `任务列表`为空时,你需要根据上文中`任务总体步骤描述`生成一个新的任务列表
|
|
||||||
2 执行`任务列表`中的第一个未完成的任务
|
|
||||||
|
|
||||||
当你读取到一个空的任务列表 或者 任务列表中的所有任务都完成或失败时,你需要分析出一个新的任务列表;
|
|
||||||
新的任务列表至少包含一个任务;
|
|
||||||
任务列表中的每个任务代表的是一个后续智能体要执行的任务,其中scheduler代表你自己,大多数情况下这代表了用户回复了你的提问;
|
|
||||||
|
|
||||||
以下是任务列表的几种情况的示例:
|
|
||||||
1 任务列表为空时,你需要生成一个新的任务列表;此时你的分析结果是需要用户
|
|
||||||
1 `wait_for_input` 向用户问好,并介绍你作为“爆款短剧操盘手”的身份和专业工作流程,礼貌地请用户提供需要改编的原始剧本。如果用户没有提供原始剧本,你将持续友好地提醒,此时状态始终为waiting,直到获取原始剧本为止。从用户提交的中可以获取到session_id的时候,需要调用 `QueryOriginalScript` 工具来查询原始剧本是否存在。
|
|
||||||
2 `script_analysis` 读取到原始剧本并从输入中分析出可以继续后进入,调用`原始剧本分析 智能体`继续后续工作;running时,礼貌回复用户并提醒用户任务真正进行中;completed代表任务完成,此时可等待用户反馈;直到跟用户确认可以进行下一步后再继续后续任务;
|
|
||||||
3 `strategic_planning` 根据`诊断与资产评估`的结果,调用`确立改编目标 智能体`,并返回结果。
|
|
||||||
4 `build_bible` 根据`改编思路`的结果,调用`剧本圣经构建 智能体`,并返回结果。
|
|
||||||
5 `episode_create_loop` 根据`剧本圣经`的结果,调用`剧集创作 智能体`
|
|
||||||
5 `finish` 所有剧集完成后设置为该状态,但是不要返回node==end_node,因为用户还可以继续输入来进一步修改产出内容;
|
|
||||||
|
|
||||||
***当任意一个智能体返回失败时,你需要分析reason字段中的内容,来决定是否进行重试,如果需要重试则给retry_count加1,并交给失败的那个智能体重试一次;如果retry_count超过了3次,或者失败原因不适合重试则反馈给用户说任务失败了,请稍后再试***
|
|
||||||
|
|
||||||
请严格按照下列JSON结构返回数据,不要有其他任何多余的信息和描述:
|
请严格按照下列JSON结构返回数据,不要有其他任何多余的信息和描述:
|
||||||
{{
|
{{
|
||||||
"message":'',//回复给用户的内容,注意,仅在你与用户沟通时才返回,其他情况下不返回。比如用户的需求是要交给其他智能体处理时,这个属性应该为空
|
"message": "回复用户的内容(仅需用户沟通时填充)",
|
||||||
"task_list": [] //最新的任务列表
|
"task_list": [更新后的任务列表数组],
|
||||||
"task_index": 0 //执行中的任务的索引
|
"task_index": 当前执行任务索引,
|
||||||
}}
|
}}
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def create_agent_prompt(prompt, SchedulerList):
|
def create_agent_prompt(prompt, SchedulerList):
|
||||||
"""创建代理提示词的辅助函数"""
|
"""创建代理提示词的辅助函数"""
|
||||||
if not SchedulerList or len(SchedulerList) == 0: return prompt
|
if not SchedulerList or len(SchedulerList) == 0: return prompt
|
||||||
@ -149,13 +77,56 @@ def create_agent_prompt(prompt, SchedulerList):
|
|||||||
{node_list} \n
|
{node_list} \n
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
class SchedulerAgentState(AgentState):
|
||||||
|
"""调度智能体中的上下文对象"""
|
||||||
|
is_original_script: bool # 是否已提交原始剧本
|
||||||
|
is_script_analysis: bool # 是否已生成 诊断与资产评估报告
|
||||||
|
is_strategic_planning: bool # 是否已生成 改编思路
|
||||||
|
is_build_bible: bool # 是否已生成 剧本圣经
|
||||||
|
is_episode_create_loop: bool # 是否已生成 剧集生成循环
|
||||||
|
is_all_episode_created: bool # 是否已生成 全部剧集
|
||||||
|
|
||||||
|
def pre_scheduler_hook(state:SchedulerAgentState):
|
||||||
|
"""模型调用前的钩子函数"""
|
||||||
|
logger.info(f"!!!!!!!!!调度节点输入!!!!!!!!!!!")
|
||||||
|
session_id = state.get("session_id", "")
|
||||||
|
workflow_info = state.get("workflow_info", {})
|
||||||
|
task_list = state.get("task_list", [])
|
||||||
|
task_index = int(state.get("task_index", 0))
|
||||||
|
from_type = state.get("from_type", "")
|
||||||
|
messages = state.get("messages", [])
|
||||||
|
logger.info(f"调度节点输入: {state}")
|
||||||
|
logger.info(f"!!!!!!!!!调度节点输入!!!!!!!!!!!")
|
||||||
|
# 清除历史状态消息
|
||||||
|
# messages = clear_messages(messages)
|
||||||
|
# # 添加参数进提示词
|
||||||
|
# messages.append(HumanMessage(content=f"""
|
||||||
|
# ---任务状态消息(开始)---
|
||||||
|
# # 工作流信息:
|
||||||
|
# {workflow_info}
|
||||||
|
# # 工作流参数:
|
||||||
|
# {{
|
||||||
|
# 'session_id':'{session_id}',
|
||||||
|
# 'task_index':'{task_index}',
|
||||||
|
# 'from_type':'{from_type}',
|
||||||
|
# }}
|
||||||
|
# # 任务列表:
|
||||||
|
# {task_list}
|
||||||
|
# ---任务状态消息(结束)---
|
||||||
|
# """))
|
||||||
|
# return state
|
||||||
|
|
||||||
class SchedulerAgent(CompiledStateGraph):
|
class SchedulerAgent(CompiledStateGraph):
|
||||||
"""智能调度智能体类
|
"""智能调度智能体类
|
||||||
|
|
||||||
该类负责接收用户的提示词,并调用其他智能体来处理工作。
|
该类负责接收用户的提示词,并调用其他智能体来处理工作。
|
||||||
"""
|
"""
|
||||||
def __new__(cls, llm=None, tools=[], SchedulerList=None):
|
def __new__(cls,
|
||||||
|
llm=None,
|
||||||
|
tools=[],
|
||||||
|
SchedulerList=None,
|
||||||
|
post_model_hook=None,
|
||||||
|
):
|
||||||
"""创建并返回create_react_agent创建的对象"""
|
"""创建并返回create_react_agent创建的对象"""
|
||||||
# 处理默认参数
|
# 处理默认参数
|
||||||
if llm is None:
|
if llm is None:
|
||||||
@ -169,5 +140,9 @@ class SchedulerAgent(CompiledStateGraph):
|
|||||||
return create_react_agent(
|
return create_react_agent(
|
||||||
model=llm,
|
model=llm,
|
||||||
tools=tools,
|
tools=tools,
|
||||||
prompt=create_agent_prompt(prompt=DefaultAgentPrompt, SchedulerList=SchedulerList),
|
prompt=DefaultAgentPrompt,
|
||||||
)
|
# pre_model_hook=pre_scheduler_hook,
|
||||||
|
# post_model_hook=post_model_hook,
|
||||||
|
# state_schema=SchedulerAgentState,
|
||||||
|
# prompt=create_agent_prompt(prompt=DefaultAgentPrompt, SchedulerList=SchedulerList),
|
||||||
|
)
|
||||||
|
|||||||
@ -7,6 +7,7 @@ from re import T
|
|||||||
from typing import TypedDict, Annotated, Dict, Any, List, TypedDict, Optional
|
from typing import TypedDict, Annotated, Dict, Any, List, TypedDict, Optional
|
||||||
|
|
||||||
from langgraph.graph.state import RunnableConfig
|
from langgraph.graph.state import RunnableConfig
|
||||||
|
from langgraph.prebuilt.chat_agent_executor import AgentState
|
||||||
from agent.scheduler import SchedulerAgent
|
from agent.scheduler import SchedulerAgent
|
||||||
from agent.build_bible import BuildBibleAgent
|
from agent.build_bible import BuildBibleAgent
|
||||||
from agent.episode_create import EpisodeCreateAgent
|
from agent.episode_create import EpisodeCreateAgent
|
||||||
@ -24,7 +25,8 @@ import config
|
|||||||
from tools.database.mongo import client # type: ignore
|
from tools.database.mongo import client # type: ignore
|
||||||
from langgraph.checkpoint.mongodb import MongoDBSaver
|
from langgraph.checkpoint.mongodb import MongoDBSaver
|
||||||
# 工具方法
|
# 工具方法
|
||||||
from tools.agent.queryDB import QueryOriginalScript
|
from tools.agent.queryDB import QueryOriginalScript,QueryDiagnosisAndAssessment,QueryAdaptationIdeas,QueryScriptBible,QueryEpisodeCount
|
||||||
|
from tools.agent.updateDB import UpdateAdaptationIdeasTool,UpdateScriptBibleTool,UpdateDiagnosisAndAssessmentTool,UpdateOneEpisodeTool
|
||||||
|
|
||||||
logger = get_logger(__name__)
|
logger = get_logger(__name__)
|
||||||
|
|
||||||
@ -55,6 +57,7 @@ def clear_messages(messages):
|
|||||||
def messages_handler(old_messages: list[AnyMessage], new_messages: list[AnyMessage]):
|
def messages_handler(old_messages: list[AnyMessage], new_messages: list[AnyMessage]):
|
||||||
"""消息合并方法"""
|
"""消息合并方法"""
|
||||||
clear_messages(old_messages)
|
clear_messages(old_messages)
|
||||||
|
new_messages = [message for message in new_messages if message.type != 'ai' or message.content]
|
||||||
return old_messages + new_messages
|
return old_messages + new_messages
|
||||||
|
|
||||||
def replace_value(old_val, new_val):
|
def replace_value(old_val, new_val):
|
||||||
@ -141,89 +144,82 @@ class ScriptwriterGraph:
|
|||||||
|
|
||||||
def node_router(self, state: ScriptwriterState) -> str:
|
def node_router(self, state: ScriptwriterState) -> str:
|
||||||
"""节点路由函数"""
|
"""节点路由函数"""
|
||||||
|
# return 'end_node'
|
||||||
# print(f'node_router state {state}')
|
# print(f'node_router state {state}')
|
||||||
task_list = state.get("task_list", [])
|
task_list = state.get("task_list", [])
|
||||||
task_index = state.get("task_index", 0)
|
task_index = state.get("task_index", 0)
|
||||||
now_task = task_list[task_index]
|
now_task = task_list[task_index]
|
||||||
|
print(f'node_router now_task {now_task}')
|
||||||
if not now_task or now_task.get('pause'):
|
if not now_task or now_task.get('pause'):
|
||||||
next_node = 'pause_node' # 设置为暂停节点
|
next_node = 'pause_node' # 设置为暂停节点
|
||||||
else:
|
else:
|
||||||
next_node = AgentNodeMap.get(now_task.get('agent'), 'pause_node')
|
next_node = AgentNodeMap.get(now_task.get('agent'), 'pause_node')
|
||||||
# print(f'node_router next_node {next_node}')
|
print(f'node_router next_node {next_node}')
|
||||||
return next_node
|
return next_node
|
||||||
|
|
||||||
|
|
||||||
|
def post_scheduler_hook(self, state: ScriptwriterState)-> ScriptwriterState:
|
||||||
|
"""模型调用后的钩子函数"""
|
||||||
|
ai_message_str = state.get("messages", [])[-1].content
|
||||||
|
logger.info(f"!!!!!!!!!调度节点输出!!!!!!!!!!!")
|
||||||
|
logger.info(f"调度节点结果: {ai_message_str}")
|
||||||
|
# logger.info(f"调度节点结果 end")
|
||||||
|
# ai_message = json.loads(ai_message_str)
|
||||||
|
return state
|
||||||
|
|
||||||
def _build_graph(self) -> None:
|
def _build_graph(self) -> None:
|
||||||
"""构建工作流图"""
|
"""构建工作流图"""
|
||||||
try:
|
try:
|
||||||
# 创建智能体
|
# 创建智能体
|
||||||
print("创建智能体")
|
print("创建智能体")
|
||||||
|
from tools.llm.deepseek_langchain import DeepseekChatModel
|
||||||
|
llm = DeepseekChatModel(api_key='sk-571923fdcb0e493d8def7e2d78c02cb8')
|
||||||
# 调度智能体
|
# 调度智能体
|
||||||
self.schedulerAgent = SchedulerAgent(
|
self.schedulerAgent = SchedulerAgent(
|
||||||
|
# llm=llm,
|
||||||
tools=[
|
tools=[
|
||||||
QueryOriginalScript,
|
QueryOriginalScript,
|
||||||
|
QueryDiagnosisAndAssessment,
|
||||||
|
QueryAdaptationIdeas,
|
||||||
|
QueryScriptBible,
|
||||||
|
QueryEpisodeCount,
|
||||||
],
|
],
|
||||||
SchedulerList=[
|
|
||||||
{
|
|
||||||
"name": "scheduler_node",
|
|
||||||
"desc": "调度智能体节点",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"name": "script_analysis_node",
|
|
||||||
"desc": "原始剧本分析节点",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"name": "strategic_planning_node",
|
|
||||||
"desc": "确立改编目标节点",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"name": "build_bible_node",
|
|
||||||
"desc": "剧本圣经构建节点",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"name": "episode_create_node",
|
|
||||||
"desc": "单集创作节点",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"name": "end_node",
|
|
||||||
"desc": "结束节点,任务失败终止时使用,结束后整个工作流将停止"
|
|
||||||
}
|
|
||||||
]
|
|
||||||
)
|
)
|
||||||
self.scriptAnalysisAgent = ScriptAnalysisAgent(
|
self.scriptAnalysisAgent = ScriptAnalysisAgent(
|
||||||
tools=[],
|
tools=[],
|
||||||
SchedulerList=[
|
# SchedulerList=[
|
||||||
{
|
# {
|
||||||
"name": "scheduler_node",
|
# "name": "scheduler_node",
|
||||||
"desc": "调度智能体节点",
|
# "desc": "调度智能体节点",
|
||||||
}
|
# }
|
||||||
]
|
# ]
|
||||||
)
|
)
|
||||||
self.strategicPlanningAgent = StrategicPlanningAgent(
|
self.strategicPlanningAgent = StrategicPlanningAgent(
|
||||||
tools=[],
|
tools=[],
|
||||||
SchedulerList=[
|
# SchedulerList=[
|
||||||
{
|
# {
|
||||||
"name": "scheduler_node",
|
# "name": "scheduler_node",
|
||||||
"desc": "调度智能体节点",
|
# "desc": "调度智能体节点",
|
||||||
}
|
# }
|
||||||
]
|
# ]
|
||||||
)
|
)
|
||||||
self.buildBibleAgent = BuildBibleAgent(
|
self.buildBibleAgent = BuildBibleAgent(
|
||||||
tools=[],
|
tools=[],
|
||||||
SchedulerList=[
|
# SchedulerList=[
|
||||||
{
|
# {
|
||||||
"name": "scheduler_node",
|
# "name": "scheduler_node",
|
||||||
"desc": "调度智能体节点",
|
# "desc": "调度智能体节点",
|
||||||
}
|
# }
|
||||||
]
|
# ]
|
||||||
)
|
)
|
||||||
self.episodeCreateAgent = EpisodeCreateAgent(
|
self.episodeCreateAgent = EpisodeCreateAgent(
|
||||||
tools=[],
|
tools=[],
|
||||||
SchedulerList=[
|
# SchedulerList=[
|
||||||
{
|
# {
|
||||||
"name": "scheduler_node",
|
# "name": "scheduler_node",
|
||||||
"desc": "调度智能体节点",
|
# "desc": "调度智能体节点",
|
||||||
}
|
# }
|
||||||
]
|
# ]
|
||||||
)
|
)
|
||||||
|
|
||||||
# 创建状态图
|
# 创建状态图
|
||||||
@ -272,6 +268,8 @@ class ScriptwriterGraph:
|
|||||||
logger.info("工作流图构建完成")
|
logger.info("工作流图构建完成")
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
import traceback
|
||||||
|
traceback.print_exc()
|
||||||
logger.error(f"构建工作流图失败: {e}")
|
logger.error(f"构建工作流图失败: {e}")
|
||||||
raise
|
raise
|
||||||
|
|
||||||
@ -289,10 +287,10 @@ class ScriptwriterGraph:
|
|||||||
# 清除历史状态消息
|
# 清除历史状态消息
|
||||||
messages = clear_messages(messages)
|
messages = clear_messages(messages)
|
||||||
# 添加参数进提示词
|
# 添加参数进提示词
|
||||||
|
# 工作流信息:
|
||||||
|
# {workflow_info}
|
||||||
messages.append(HumanMessage(content=f"""
|
messages.append(HumanMessage(content=f"""
|
||||||
---任务状态消息(开始)---
|
---任务状态消息(开始)---
|
||||||
# 工作流信息:
|
|
||||||
{workflow_info}
|
|
||||||
# 工作流参数:
|
# 工作流参数:
|
||||||
{{
|
{{
|
||||||
'session_id':'{session_id}',
|
'session_id':'{session_id}',
|
||||||
@ -315,8 +313,13 @@ class ScriptwriterGraph:
|
|||||||
ai_message_count += 1
|
ai_message_count += 1
|
||||||
logger.info(f"调度节点 {session_id} \n 输入消息条数: {len(messages)} \n from_type:{from_type} \n system_message_count:{system_message_count} \n human_message_count:{human_message_count} \n ai_message_count:{ai_message_count}")
|
logger.info(f"调度节点 {session_id} \n 输入消息条数: {len(messages)} \n from_type:{from_type} \n system_message_count:{system_message_count} \n human_message_count:{human_message_count} \n ai_message_count:{ai_message_count}")
|
||||||
# 调用智能体
|
# 调用智能体
|
||||||
reslut = await self.schedulerAgent.ainvoke({"messages":messages})
|
reslut = await self.schedulerAgent.ainvoke({
|
||||||
|
"messages": messages,
|
||||||
|
"session_id": session_id,
|
||||||
|
})
|
||||||
ai_message_str = reslut['messages'][-1].content
|
ai_message_str = reslut['messages'][-1].content
|
||||||
|
logger.info(f"调度节点结果: {ai_message_str}")
|
||||||
|
logger.info(f"调度节点结果 end")
|
||||||
ai_message = json.loads(ai_message_str)
|
ai_message = json.loads(ai_message_str)
|
||||||
# logger.info(f"调度节点结果: {ai_message}")
|
# logger.info(f"调度节点结果: {ai_message}")
|
||||||
return_message:str = ai_message.get('message', '')
|
return_message:str = ai_message.get('message', '')
|
||||||
@ -324,10 +327,9 @@ class ScriptwriterGraph:
|
|||||||
task_index:int = int(ai_message.get('task_index', '0'))
|
task_index:int = int(ai_message.get('task_index', '0'))
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"messages": messages,
|
"agent_message": return_message,
|
||||||
"task_list": task_list,
|
"task_list": task_list,
|
||||||
"task_index": task_index,
|
"task_index": task_index,
|
||||||
"agent_message": return_message,
|
|
||||||
}
|
}
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
import traceback
|
import traceback
|
||||||
@ -620,8 +622,13 @@ class ScriptwriterGraph:
|
|||||||
工作流执行结果
|
工作流执行结果
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
logger.info("开始运行智能编剧工作流")
|
logger.info("开始运行智能编剧工作流")
|
||||||
|
output_result: OutputState = {
|
||||||
|
'session_id': session_id,
|
||||||
|
'status': '',
|
||||||
|
'error': '',
|
||||||
|
'agent_message': '',
|
||||||
|
}
|
||||||
# 配置包含线程 ID
|
# 配置包含线程 ID
|
||||||
config:RunnableConfig = {"configurable": {"thread_id": thread_id}}
|
config:RunnableConfig = {"configurable": {"thread_id": thread_id}}
|
||||||
# 初始化状态
|
# 初始化状态
|
||||||
@ -630,11 +637,10 @@ class ScriptwriterGraph:
|
|||||||
'session_id': session_id,
|
'session_id': session_id,
|
||||||
'from_type': 'user',
|
'from_type': 'user',
|
||||||
}
|
}
|
||||||
|
|
||||||
# 运行工作流
|
# 运行工作流
|
||||||
if self.graph is None:
|
if self.graph is None:
|
||||||
raise RuntimeError("工作流图未正确初始化")
|
raise RuntimeError("工作流图未正确初始化")
|
||||||
|
# 运行工作流 直接返回最终结果
|
||||||
result = await self.graph.ainvoke(
|
result = await self.graph.ainvoke(
|
||||||
initial_state,
|
initial_state,
|
||||||
config,
|
config,
|
||||||
@ -642,15 +648,24 @@ class ScriptwriterGraph:
|
|||||||
)
|
)
|
||||||
# logger.info(f"工作流执行结果: {result}")
|
# logger.info(f"工作流执行结果: {result}")
|
||||||
if not result:
|
if not result:
|
||||||
raise ValueError("工作流执行结果为空")
|
raise ValueError("工作流执行结果为空")
|
||||||
|
|
||||||
# 构造输出状态
|
# 构造输出状态
|
||||||
output_result: OutputState = {
|
output_result = {
|
||||||
'session_id': result.get('session_id', ''),
|
'session_id': result.get('session_id', ''),
|
||||||
'status': result.get('status', ''),
|
'status': result.get('status', ''),
|
||||||
'error': result.get('error', ''),
|
'error': result.get('error', ''),
|
||||||
'agent_message': result.get('agent_message', ''),
|
'agent_message': result.get('agent_message', ''),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
# 流式处理
|
||||||
|
# st = self.graph.stream(
|
||||||
|
# initial_state,
|
||||||
|
# config,
|
||||||
|
# stream_mode='values'
|
||||||
|
# )
|
||||||
|
# from utils.stream import debug_print_stream
|
||||||
|
# debug_print_stream(st)
|
||||||
|
|
||||||
logger.info("智能编剧工作流运行完成")
|
logger.info("智能编剧工作流运行完成")
|
||||||
return output_result
|
return output_result
|
||||||
|
|
||||||
|
|||||||
@ -1,6 +1,11 @@
|
|||||||
|
from typing import Annotated
|
||||||
from bson import ObjectId
|
from bson import ObjectId
|
||||||
|
from langchain_core.messages import ToolMessage
|
||||||
|
from langchain_core.tools import InjectedToolCallId
|
||||||
|
from langgraph.types import Command
|
||||||
from tools.database.mongo import mainDB
|
from tools.database.mongo import mainDB
|
||||||
from langchain.tools import tool
|
from langchain.tools import tool
|
||||||
|
import json
|
||||||
|
|
||||||
@tool
|
@tool
|
||||||
def QueryOriginalScript(session_id: str):
|
def QueryOriginalScript(session_id: str):
|
||||||
@ -13,9 +18,24 @@ def QueryOriginalScript(session_id: str):
|
|||||||
exist (bool): 原始剧本内容是否存在。
|
exist (bool): 原始剧本内容是否存在。
|
||||||
"""
|
"""
|
||||||
script = mainDB.agent_writer_session.find_one({"_id": ObjectId(session_id), "original_script": {"$exists": True, "$ne": ""}},{"_id":1})
|
script = mainDB.agent_writer_session.find_one({"_id": ObjectId(session_id), "original_script": {"$exists": True, "$ne": ""}},{"_id":1})
|
||||||
|
is_original_script = script is not None
|
||||||
|
# print(f"tool_call_id {tool_call_id}")
|
||||||
|
Command(update={
|
||||||
|
"is_original_script": is_original_script
|
||||||
|
})
|
||||||
return {
|
return {
|
||||||
"exist": script is not None,
|
"exist": is_original_script,
|
||||||
}
|
}
|
||||||
|
tool_message_content = json.dumps({"exist": is_original_script})
|
||||||
|
# return Command(update={
|
||||||
|
# "is_original_script": is_original_script,
|
||||||
|
# "messages": [
|
||||||
|
# ToolMessage(
|
||||||
|
# tool_call_id=tool_call_id,
|
||||||
|
# content=tool_message_content
|
||||||
|
# )
|
||||||
|
# ]
|
||||||
|
# })
|
||||||
|
|
||||||
def QueryOriginalScriptContent(session_id: str):
|
def QueryOriginalScriptContent(session_id: str):
|
||||||
"""
|
"""
|
||||||
|
|||||||
85
tools/llm/deepseek_langchain.py
Normal file
85
tools/llm/deepseek_langchain.py
Normal file
@ -0,0 +1,85 @@
|
|||||||
|
from typing import Any, List, Optional
|
||||||
|
from langchain_core.callbacks.manager import CallbackManagerForLLMRun
|
||||||
|
from langchain_core.messages import BaseMessage
|
||||||
|
from langchain_core.outputs import ChatResult
|
||||||
|
from langchain_community.chat_models import ChatOpenAI
|
||||||
|
import os
|
||||||
|
|
||||||
|
# 继承自 ChatOpenAI 而不是 BaseChatModel
|
||||||
|
class DeepseekChatModel(ChatOpenAI):
|
||||||
|
"""
|
||||||
|
对 DeepSeek 聊天模型的 LangChain 封装。
|
||||||
|
这个版本通过继承 ChatOpenAI 来实现,并重写 _generate 方法。
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, model_name: str = "deepseek-chat", **kwargs: Any):
|
||||||
|
"""
|
||||||
|
初始化 DeepseekChatModel
|
||||||
|
"""
|
||||||
|
# 从环境变量或参数中获取 api_key
|
||||||
|
api_key = kwargs.pop("api_key", os.getenv("DEEPSEEK_API_KEY"))
|
||||||
|
if not api_key:
|
||||||
|
raise ValueError(
|
||||||
|
"DeepSeek API key must be provided either as an argument or set as the DEEPSEEK_API_KEY environment variable."
|
||||||
|
)
|
||||||
|
|
||||||
|
# 调用父类(ChatOpenAI)的构造函数,并传入 DeepSeek 的特定配置
|
||||||
|
super().__init__(
|
||||||
|
model=model_name,
|
||||||
|
api_key=api_key,
|
||||||
|
base_url="https://api.deepseek.com/v1", # 这是关键
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
|
def _generate(
|
||||||
|
self,
|
||||||
|
messages: List[BaseMessage],
|
||||||
|
stop: Optional[List[str]] = None,
|
||||||
|
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
||||||
|
**kwargs: Any,
|
||||||
|
) -> ChatResult:
|
||||||
|
"""
|
||||||
|
重写 _generate 方法,这是 LangChain 的核心调用点。
|
||||||
|
"""
|
||||||
|
# 1. 在调用父类方法前,执行你的自定义逻辑(例如,打印日志)
|
||||||
|
print("\n--- [DeepseekChatModel] 调用 _generate ---")
|
||||||
|
print(f"输入消息数量: {len(messages)}")
|
||||||
|
print(f"第一条消息内容: {messages[0].content}")
|
||||||
|
print("-------------------------------------------\n")
|
||||||
|
|
||||||
|
# 2. 使用 super() 调用父类(ChatOpenAI)的原始 _generate 方法
|
||||||
|
# 这样可以复用 ChatOpenAI 中所有复杂的、经过测试的逻辑
|
||||||
|
chat_result = super()._generate(
|
||||||
|
messages=messages, stop=stop, run_manager=run_manager, **kwargs
|
||||||
|
)
|
||||||
|
|
||||||
|
# 3. 在获得结果后,你还可以执行其他自定义逻辑
|
||||||
|
print("\n--- [DeepseekChatModel] _generate 调用完成 ---")
|
||||||
|
print(f"输出消息内容: {chat_result.generations[0].message.content[:80]}...") # 打印部分输出
|
||||||
|
print("--------------------------------------------\n")
|
||||||
|
|
||||||
|
return chat_result
|
||||||
|
|
||||||
|
@property
|
||||||
|
def _llm_type(self) -> str:
|
||||||
|
"""返回 language model 的类型。"""
|
||||||
|
return "deepseek-chat-model-v2" # 改个名以示区别
|
||||||
|
|
||||||
|
# --- 使用示例 ---
|
||||||
|
if __name__ == "__main__":
|
||||||
|
# 确保设置了 API 密钥
|
||||||
|
if not os.getenv("DEEPSEEK_API_KEY"):
|
||||||
|
print("请设置 DEEPSEEK_API_KEY 环境变量。")
|
||||||
|
else:
|
||||||
|
# 初始化模型
|
||||||
|
chat_model = DeepseekChatModel(temperature=0.7)
|
||||||
|
|
||||||
|
# 构建输入消息
|
||||||
|
from langchain_core.messages import HumanMessage
|
||||||
|
messages = [HumanMessage(content="你好,请介绍一下新加坡。")]
|
||||||
|
|
||||||
|
# 调用模型
|
||||||
|
# 当你调用 .invoke() 或 .stream() 时,LangChain 内部会调用我们重写的 _generate 方法
|
||||||
|
response = chat_model.invoke(messages)
|
||||||
|
|
||||||
|
print(f"\n最终得到的回复:\n{response.content}")
|
||||||
@ -8,6 +8,7 @@ from langchain_core.runnables import Runnable
|
|||||||
from pydantic import Field
|
from pydantic import Field
|
||||||
import json
|
import json
|
||||||
import copy
|
import copy
|
||||||
|
from langchain_core.messages import AnyMessage,HumanMessage
|
||||||
from volcenginesdkarkruntime.types.chat import ChatCompletionToolParam
|
from volcenginesdkarkruntime.types.chat import ChatCompletionToolParam
|
||||||
from volcenginesdkarkruntime.types.shared_params.function_definition import FunctionDefinition
|
from volcenginesdkarkruntime.types.shared_params.function_definition import FunctionDefinition
|
||||||
from volcenginesdkarkruntime.types.shared_params.function_parameters import FunctionParameters
|
from volcenginesdkarkruntime.types.shared_params.function_parameters import FunctionParameters
|
||||||
@ -174,15 +175,16 @@ class HuoshanChatModel(BaseChatModel):
|
|||||||
api_messages = self._convert_messages_to_prompt(messages)
|
api_messages = self._convert_messages_to_prompt(messages)
|
||||||
tools = kwargs.get("tools", [])
|
tools = kwargs.get("tools", [])
|
||||||
# print(f" 提交给豆包的 messages数组长度: \n {len(messages)} \n tools: {tools}")
|
# print(f" 提交给豆包的 messages数组长度: \n {len(messages)} \n tools: {tools}")
|
||||||
print(f" 提交给豆包的消息 =======================>>>>>>>>>>>> \n")
|
# print(f" 提交给豆包的消息 =======================>>>>>>>>>>>> \n")
|
||||||
print(f"\nmessages: \n")
|
# print(f"\nmessages: \n")
|
||||||
for message in messages:
|
# for message in messages:
|
||||||
print(f" {message.type}: \n ")
|
# print("--- Message Attributes ---\n")
|
||||||
print(f" {message.content} \n ")
|
# for key, value in message.__dict__.items():
|
||||||
print(f"\ntools: \n")
|
# print(f" {key}: {value} \n")
|
||||||
for tool in tools:
|
# print(f"\ntools: \n")
|
||||||
print(f" \n {tool} \n ")
|
# for tool in tools:
|
||||||
print(f" 提交给豆包的消息 =======================>>>>>>>>>>>> end \n")
|
# print(f" {tool} \n ")
|
||||||
|
# print(f" 提交给豆包的消息 =======================>>>>>>>>>>>> end \n")
|
||||||
|
|
||||||
response_data = self._api.get_chat_response(messages=api_messages, tools=tools)
|
response_data = self._api.get_chat_response(messages=api_messages, tools=tools)
|
||||||
|
|
||||||
@ -192,7 +194,7 @@ class HuoshanChatModel(BaseChatModel):
|
|||||||
message_from_api = res_choices.get("message", {})
|
message_from_api = res_choices.get("message", {})
|
||||||
tool_calls = message_from_api.get("tool_calls", [])
|
tool_calls = message_from_api.get("tool_calls", [])
|
||||||
print(f" 豆包返回的 finish_reason: {finish_reason} \n tool_calls: {tool_calls} \n")
|
print(f" 豆包返回的 finish_reason: {finish_reason} \n tool_calls: {tool_calls} \n")
|
||||||
print(f" 豆包返回的 message: {message_from_api.get('content', '')}")
|
# print(f" 豆包返回的 message: {message_from_api.get('content', '')}")
|
||||||
if finish_reason == "tool_calls" and tool_calls:
|
if finish_reason == "tool_calls" and tool_calls:
|
||||||
lc_tool_calls = []
|
lc_tool_calls = []
|
||||||
for tc in tool_calls:
|
for tc in tool_calls:
|
||||||
|
|||||||
193
tools/llm/openai_langchain.py
Normal file
193
tools/llm/openai_langchain.py
Normal file
@ -0,0 +1,193 @@
|
|||||||
|
import os
|
||||||
|
from typing import Any, List, Optional, Type
|
||||||
|
|
||||||
|
from langchain_community.chat_models import ChatOpenAI
|
||||||
|
from langchain_core.messages import HumanMessage
|
||||||
|
from langchain_core.prompts import ChatPromptTemplate
|
||||||
|
from langchain_core.pydantic_v1 import BaseModel, Field
|
||||||
|
from langchain_core.tools import BaseTool, tool
|
||||||
|
|
||||||
|
# --- 1. 中央模型配置 ---
|
||||||
|
# 在这里添加或修改你想要支持的模型提供商
|
||||||
|
# 'api_key_env': 存放 API Key 的环境变量名称
|
||||||
|
# 'base_url': 模型的 API 端点
|
||||||
|
# 'supports_tools': 该模型是否支持 LangChain 的 .bind_tools() 功能
|
||||||
|
# 'default_model': 如果不指定模型,默认使用的模型名称
|
||||||
|
MODEL_CONFIG = {
|
||||||
|
"openai": {
|
||||||
|
"api_key_env": "OPENAI_API_KEY",
|
||||||
|
"base_url": "https://api.openai.com/v1",
|
||||||
|
"supports_tools": True,
|
||||||
|
"default_model": "gpt-4o",
|
||||||
|
},
|
||||||
|
"deepseek": {
|
||||||
|
"api_key_env": "DEEPSEEK_API_KEY",
|
||||||
|
"base_url": "https://api.deepseek.com/v1",
|
||||||
|
"supports_tools": True,
|
||||||
|
"default_model": "deepseek-chat",
|
||||||
|
},
|
||||||
|
"groq": {
|
||||||
|
"api_key_env": "GROQ_API_KEY",
|
||||||
|
"base_url": "https://api.groq.com/openai/v1",
|
||||||
|
"supports_tools": True,
|
||||||
|
"default_model": "llama3-70b-8192",
|
||||||
|
},
|
||||||
|
"moonshot": {
|
||||||
|
"api_key_env": "MOONSHOT_API_KEY",
|
||||||
|
"base_url": "https://api.moonshot.cn/v1",
|
||||||
|
"supports_tools": True,
|
||||||
|
"default_model": "moonshot-v1-8k",
|
||||||
|
},
|
||||||
|
# 示例:一个不支持工具的老模型
|
||||||
|
"legacy_provider": {
|
||||||
|
"api_key_env": "LEGACY_API_KEY",
|
||||||
|
"base_url": "http://localhost:8080/v1", # 假设是本地服务
|
||||||
|
"supports_tools": False,
|
||||||
|
"default_model": "legacy-model-v1",
|
||||||
|
},
|
||||||
|
|
||||||
|
# 新增火山引擎的配置
|
||||||
|
"huoshan": {
|
||||||
|
"api_key_env": "VOLC_API_KEY", # 请确认您存放密钥的环境变量名
|
||||||
|
"base_url": "https://ark.cn-beijing.volces.com/api/v3", # 这是火山引擎方舟的 OpenAI 兼容端点
|
||||||
|
"supports_tools": True,
|
||||||
|
"default_model": "ep-20240615082149-j225c", # 使用您需要的模型 Endpoint ID
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
# --- 2. 模型创建工厂函数 ---
|
||||||
|
def create_llm_client(
|
||||||
|
provider: str,
|
||||||
|
model_name: Optional[str] = None,
|
||||||
|
tools: Optional[List[Type[BaseModel]]] = None,
|
||||||
|
**kwargs: Any,
|
||||||
|
) -> ChatOpenAI:
|
||||||
|
"""
|
||||||
|
根据提供的 provider 创建并配置一个 LangChain 聊天模型客户端。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
provider: 模型提供商的名称 (必须是 MODEL_CONFIG 中的一个 key)。
|
||||||
|
model_name: 要使用的具体模型名称。如果为 None, 则使用配置中的默认模型。
|
||||||
|
tools: 一个工具列表,用于绑定到模型上以实现 function calling。
|
||||||
|
**kwargs: 其他要传递给 ChatOpenAI 的参数 (例如 temperature, max_tokens)。
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
一个配置好的 ChatOpenAI 实例,可能已经绑定了工具。
|
||||||
|
"""
|
||||||
|
if provider not in MODEL_CONFIG:
|
||||||
|
raise ValueError(f"不支持的模型提供商: {provider}。可用选项: {list(MODEL_CONFIG.keys())}")
|
||||||
|
|
||||||
|
config = MODEL_CONFIG[provider]
|
||||||
|
api_key = os.getenv(config["api_key_env"])
|
||||||
|
|
||||||
|
if not api_key:
|
||||||
|
raise ValueError(f"请设置环境变量 {config['api_key_env']} 以使用 {provider} 模型。")
|
||||||
|
|
||||||
|
# 如果未指定 model_name,则使用配置中的默认值
|
||||||
|
final_model_name = model_name or config["default_model"]
|
||||||
|
|
||||||
|
# 创建基础的 LLM 客户端
|
||||||
|
llm = ChatOpenAI(
|
||||||
|
model=final_model_name,
|
||||||
|
api_key=api_key,
|
||||||
|
base_url=config["base_url"],
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
|
# 根据配置,有条件地绑定工具
|
||||||
|
if tools:
|
||||||
|
if config["supports_tools"]:
|
||||||
|
print(f"[{provider}] 模型支持工具,正在绑定 {len(tools)} 个工具...")
|
||||||
|
return llm.bind_tools(tools)
|
||||||
|
else:
|
||||||
|
print(f"⚠️ 警告: 您为 [{provider}] 提供了工具,但该模型在配置中被标记为不支持工具。将返回未绑定工具的模型。")
|
||||||
|
return llm
|
||||||
|
|
||||||
|
return llm
|
||||||
|
|
||||||
|
|
||||||
|
# --- 3. 定义你的工具 (Function Calling) ---
|
||||||
|
# 使用 Pydantic 模型定义工具的输入参数,确保类型安全和清晰的描述
|
||||||
|
class GetWeatherInput(BaseModel):
|
||||||
|
city: str = Field(description="需要查询天气的城市名称, 例如: Singapore")
|
||||||
|
|
||||||
|
# 使用 @tool 装饰器可以轻松地将任何函数转换为 LangChain 工具
|
||||||
|
@tool(args_schema=GetWeatherInput)
|
||||||
|
def get_current_weather(city: str) -> str:
|
||||||
|
"""
|
||||||
|
当需要查询指定城市的当前天气时,调用此工具。
|
||||||
|
"""
|
||||||
|
# 这是一个模拟实现,实际应用中你会在这里调用真实的天气API
|
||||||
|
print(f"--- 正在调用工具: get_current_weather(city='{city}') ---")
|
||||||
|
if "singapore" in city.lower():
|
||||||
|
return f"新加坡今天的天气是晴朗,温度为 31°C。"
|
||||||
|
elif "beijing" in city.lower():
|
||||||
|
return f"北京今天的天气是多云,温度为 25°C。"
|
||||||
|
else:
|
||||||
|
return f"抱歉,我无法查询到 {city} 的天气信息。"
|
||||||
|
|
||||||
|
|
||||||
|
# --- 4. 主程序:演示如何使用 ---
|
||||||
|
if __name__ == "__main__":
|
||||||
|
# 将你的工具放入一个列表
|
||||||
|
my_tools = [get_current_weather]
|
||||||
|
|
||||||
|
# ---- 示例 1: 使用 DeepSeek 并调用工具 ----
|
||||||
|
print("\n================ 示例 1: 使用 DeepSeek 调用工具 ================")
|
||||||
|
# 确保你已经设置了环境变量: export DEEPSEEK_API_KEY="sk-..."
|
||||||
|
try:
|
||||||
|
deepseek_llm_with_tools = create_llm_client(
|
||||||
|
provider="deepseek",
|
||||||
|
tools=my_tools,
|
||||||
|
temperature=0
|
||||||
|
)
|
||||||
|
prompt = "今天新加坡的天气怎么样?"
|
||||||
|
print(f"用户问题: {prompt}")
|
||||||
|
|
||||||
|
# LangChain 会自动处理:LLM -> Tool Call -> Execute Tool -> LLM -> Final Answer
|
||||||
|
# 为了演示,我们只看第一步的输出
|
||||||
|
ai_msg = deepseek_llm_with_tools.invoke(prompt)
|
||||||
|
print("\nLLM 返回的初步响应 (AIMessage):")
|
||||||
|
print(ai_msg)
|
||||||
|
|
||||||
|
# 检查返回的是否是工具调用
|
||||||
|
if ai_msg.tool_calls:
|
||||||
|
print(f"\n模型请求调用工具: {ai_msg.tool_calls[0]['name']}")
|
||||||
|
# 在实际应用中,你会在这里执行工具并把结果返回给模型
|
||||||
|
else:
|
||||||
|
print("\n模型直接给出了回答:")
|
||||||
|
print(ai_msg.content)
|
||||||
|
|
||||||
|
except ValueError as e:
|
||||||
|
print(e)
|
||||||
|
|
||||||
|
|
||||||
|
# ---- 示例 2: 使用 Groq (Llama3) 且不使用工具 ----
|
||||||
|
print("\n================ 示例 2: 使用 Groq 进行常规聊天 ================")
|
||||||
|
# 确保你已经设置了环境变量: export GROQ_API_KEY="gsk_..."
|
||||||
|
try:
|
||||||
|
groq_llm = create_llm_client(
|
||||||
|
provider="groq",
|
||||||
|
temperature=0.7,
|
||||||
|
# model_name="llama3-8b-8192" # 你也可以覆盖默认模型
|
||||||
|
)
|
||||||
|
prompt = "请给我写一首关于新加坡的五言绝句。"
|
||||||
|
print(f"用户问题: {prompt}")
|
||||||
|
response = groq_llm.invoke(prompt)
|
||||||
|
print("\nGroq (Llama3) 的回答:")
|
||||||
|
print(response.content)
|
||||||
|
except ValueError as e:
|
||||||
|
print(e)
|
||||||
|
|
||||||
|
# ---- 示例 3: 尝试给不支持工具的模型绑定工具 ----
|
||||||
|
print("\n========== 示例 3: 尝试为不支持工具的模型绑定工具 ==========")
|
||||||
|
# 假设你设置了 export LEGACY_API_KEY="some_key"
|
||||||
|
try:
|
||||||
|
legacy_llm = create_llm_client(
|
||||||
|
provider="legacy_provider",
|
||||||
|
tools=my_tools
|
||||||
|
)
|
||||||
|
# 注意,这里会打印出警告信息,并且 legacy_llm 不会绑定任何工具
|
||||||
|
except ValueError as e:
|
||||||
|
print(e)
|
||||||
8
utils/stream.py
Normal file
8
utils/stream.py
Normal file
@ -0,0 +1,8 @@
|
|||||||
|
def debug_print_stream(stream):
|
||||||
|
for chunk in stream:
|
||||||
|
message = chunk['messages'][-1]
|
||||||
|
if isinstance(message, tuple):
|
||||||
|
print(chunk)
|
||||||
|
else:
|
||||||
|
message.pretty_print()
|
||||||
|
print("\n---")
|
||||||
Loading…
x
Reference in New Issue
Block a user