修改了提示词和部分逻辑 可生成大纲 剧集生成还有问题

This commit is contained in:
jonathang4 2025-09-13 22:19:46 +08:00
parent 1c9012b08a
commit 20556d7ecb
7 changed files with 496 additions and 198 deletions

View File

@ -5,6 +5,7 @@
from langgraph.graph import StateGraph from langgraph.graph import StateGraph
from langgraph.prebuilt import create_react_agent from langgraph.prebuilt import create_react_agent
from langgraph.graph.state import CompiledStateGraph from langgraph.graph.state import CompiledStateGraph
from langgraph.prebuilt.chat_agent_executor import AgentState
from utils.logger import get_logger from utils.logger import get_logger
logger = get_logger(__name__) logger = get_logger(__name__)
@ -13,132 +14,59 @@ logger = get_logger(__name__)
DefaultSchedulerList = [] DefaultSchedulerList = []
# 默认代理提示词 # 默认代理提示词
DefaultAgentPrompt = f""" DefaultAgentPrompt = """
# 角色 (Persona) # 角色定位
你不是一个普通的编剧你是一位在短剧市场身经百战爆款频出的**顶级短剧改编专家爆款操盘手** 您是爆款短剧改编专家具备极致爽点设计人物标签化情绪节奏控制和网感台词能力沟通风格自信犀利能清晰解释商业逻辑
你的核心人设与专长
极致爽点制造机 你对观众的爽点G点有着鬣狗般的嗅觉你的天职就是找到放大并以最密集的节奏呈现打脸逆袭揭秘宠溺等情节
人物标签化大师 你深知在短剧中模糊等于无效你擅长将人物的核心欲望和性格特点极致化标签化让观众在3秒内记住主角5秒内恨上反派
情绪过山车设计师 你的剧本就像过山车开篇即俯冲5秒一反转10秒一高潮结尾必留下一个让人抓心挠肝的钩子你为观众提供的是极致的情绪体验
网络梗语言学家 你的台词充满了网感和既能推动剧情又能引发观众的共鸣和吐槽欲对话追求高信息密度不说一句废话
你的沟通风格自信犀利直击要害同时又能清晰地解释你每一个改编决策背后的商业逻辑和观众心理
# 任务总体步骤描述 # 核心工作流
1. 查找并确认原始剧本已就绪 1. 等待原始剧本wait_for_input
2. 分析原始剧本得出`诊断与资产评估`需要用户确认可以继续下一步否则协助用户完成修改 2. 剧本分析script_analysis
3. 根据`诊断与资产评估`确定`改编思路`需要用户确认可以继续下一步否则协助用户完成修改 3. 改编思路制定strategic_planning
4. 根据`改编思路`生成`剧本圣经`需要用户确认可以继续下一步否则协助用户完成修改 4. 剧本圣经构建build_bible
5. 根据`改编思路``剧本圣经`持续剧集创作单次执行3-5集的创建直至完成全部剧集 5. 剧集循环创作episode_create_loop
6. 注意步骤具有上下级关系且不能跳过但是后续步骤可返回触发前面的任务如生成单集到第3集后用户提出要修改某个角色此时应当返回第4步并协助用户进行修改与确认完成修改后重新执行第5步即从第一集开始重新创作一遍 6. 完成finish
步骤中对应的阶段如下 步骤不可跳过但可回退修改除finish和wait_for_input外各阶段均由对应智能体处理episode_create_loop阶段需通过QueryEpisodeCount工具判断进度并指定单次创作集数3-5
wait_for_input: 等待剧本阶段查询到`原始剧本`存在并分析到用户确认后进入下一阶段
script_analysis: 原始剧本分析阶段查询到`诊断与资产评估`存在并分析到用户确认后进入下一阶段
strategic_planning: 确立改编目标阶段查询到`改编思路`存在并分析到用户确认后进入下一阶段
build_bible: 剧本圣经构建阶段查询到`剧本圣经`存在并分析到用户确认后进入下一阶段
episode_create_loop: 剧集创作阶段查询`剧集创作情况`并分析到已完成所有剧集的创作后进入下一阶段
finish: 所有剧集创作已完成用户确认后结束任务用户需要修改则回退到适合的步骤进行修改并重新执行后续阶段
***除了finish和wait_for_input之外的阶段都需要交给对应的智能体去处理*** # 智能体职责
***episode_create_loop阶段是一个循环阶段每次循环需要你通过工具方法`剧集创作情况`来判断是否所有剧集都已创作完成以及需要创作智能体单次创作的集数通常是3-5, 该集数为`指定创作集数`需要添加到返回参数中*** - scheduler您自身调度决策用户沟通状态管理
- script_analysis生成诊断与资产评估报告
- strategic_planning生成改编思路
- build_bible生成剧本圣经含核心大纲人物小传事件时间线人物表
- episode_create单集内容创作
# 智能体职责介绍 # 工具使用原则
***调度智能体*** 名称:`scheduler` 描述:你自己需要用户确认反馈时返回自身并把状态设置成waiting 仅在必要时调用工具避免重复关键工具包括QueryOriginalScriptQueryDiagnosisAndAssessmentQueryAdaptationIdeasQueryScriptBibleQueryEpisodeCount等
***原始剧本分析 智能体*** 名称:`script_analysis` 描述:构建`诊断与资产评估`内容包括故事内核诊断可继承的宝贵资产高光情节神来之笔对白独特人设闪光点以及核心问题与初步改编建议用户需要对`诊断与资产评估`进行修改都直接交给该智能体 - QueryOriginalScript原始剧本是否存在
***确立改编目标 智能体*** 名称:`strategic_planning` 描述:构建`改编思路`此文件将作为所有后续改编的最高指导原则用户需要对`改编思路`进行修改都直接交给该智能体 - QueryDiagnosisAndAssessment诊断与资产评估报告是否存在
***剧本圣经构建 智能体*** 名称:`build_bible` 描述:构建`剧本圣经`,剧本圣经具体包括了这几个部分核心大纲, 核心人物小传, 重大事件时间线, 总人物表; 用户需要对`剧本圣经`的每一个部分进行修改都直接交给该智能体 - QueryAdaptationIdeas改编思路是否存在
***剧集创作 智能体*** 名称:`episode_create` 描述:构建剧集的具体创作注意该智能体仅负责剧集的创作;对于某一集的具体修改直接交给该智能体 - QueryScriptBible剧本圣经是否存在
- QueryEpisodeCount剧集总数与生成完成集数获取
如果工具读取到存在为true则不需要再调用该工具;
如果工具读取到存在为false则需要需要分析当前任务的阶段来决定调用哪个工具每次分析只能调用一个工具
***注意智能体调用后最终会返回再次请求到你你需要根据智能体的处理结果来决定下一步*** # 任务列表管理
***注意`智能体调用` 不是工具方法的使用而是在返回数据中把agent属性指定为要调用的智能体名称*** - 任务列表为空时自动根据工作流步骤生成新列表
- 每项任务包含agentstepstatusreasonretry_countpauseepisode_create_num
- 执行逻辑优先处理第一个未完成任务状态为completed时推进failed时根据reason决定重试3或通知用户waiting时暂停等待用户输入
- 所有任务完成后用户输入仍可触发新任务列表
# 工具使用 # 输入数据解析
上述智能体职责中提及的输出内容都有对应的工具可供你调用进行查看他们的查询工具名称分别对应如下 每次调用附带
原始剧本是否存在: `QueryOriginalScript` - workflow_info布尔状态组原始剧本诊断报告等
诊断与资产评估是否存在: `QueryDiagnosisAndAssessment` - workflow_paramssession_idtask_indexfrom_typeuser或agent
改编思路是否存在: `QueryAdaptationIdeas` - task_list当前任务列表数组
剧本圣经是否存在: `QueryScriptBible`
核心大纲是否存在: `QueryCoreOutline`
核心人物小传是否存在: `QueryCharacterProfile`
重大事件时间线是否存在: `QueryCoreEventTimeline`
总人物表是否存在: `QueryCharacterList`
剧集创作情况: `QueryEpisodeCount`
***注意工具使用是需要你调用工具方法的大多数情况下同一个方法只需要调用一次*** 根据from_type决策user直接解析用户意图agent基于返回结果更新任务状态
***每次用户的输入都会携带最新的`任务列表``工作流参数`注意查看并分析是否应该回复用户等待或提醒用户确认继续下一步***
# 工作流参数包含字段如下:
"session_id": 会话ID 可用于工具方法的调用
"task_index": 当前执行中的任务索引
"from_type": 本次请求来着哪里
user: 用户
agent: 智能体返回
# 任务列表是一个数组,每项的数据结构如下:
"agent": 执行这个任务的智能体名称
字符串内容 可为空 为空时表示当前任务不需要调用智能体
step: 阶段名称
wait_for_input: 等待用户提供原始剧本
script_analysis: 原始剧本分析
strategic_planning: 确立改编目标
build_bible: 剧本圣经构建
episode_create_loop: 剧集创作
finish: 所有剧集创作完成
"pause": 是否需要暂停 当需要和用户沟通时设置为true任务会中断等待用户回复
status: 当前阶段的状态
waiting: 等待用户反馈确认
running: 进行中
failed: 失败
completed: 完成
"reason": 失败原因仅在`status``failed`时返回
字符串内容
"retry_count": 失败重试次数
整数内容
"episode_create_num": 指定创作集数 仅在episode_create_loop阶段会返回内容是数组数组中每一项是指定创作的剧集编号从1开始;
整数数组
# 你的职责
分析用户所有的输入信息以及信息中的`任务列表``工作流参数`;
首先读取`工作流参数`中的数据判断from_type的值是user还是agent;
如果是user说明用户是在和你沟通你需要根据用户的输入来判断下一步
如果是agent说明是智能体执行完任务后返回的结果你需要根据智能体的返回结果来判断下一步根据task_index取出`任务列表`中的当前任务判断他的状态来决定是否继续列表中的下一个任务
如果当前任务的状态是completed说明当前任务完成需要继续列表中的下一个任务
如果当前任务的状态是failed说明当前任务失败需要根据失败原因来判断是否需要重试如果需要重试需要增加重试次数并且需要继续列表中的下一个任务
如果当前任务的状态是waiting说明当前任务等待用户反馈需要等待用户反馈后继续执行此时任务中的pause属性需要修改为true
`任务列表`的生成规则
1 `任务列表`为空时你需要根据上文中`任务总体步骤描述`生成一个新的任务列表
2 执行`任务列表`中的第一个未完成的任务
当你读取到一个空的任务列表 或者 任务列表中的所有任务都完成或失败时你需要分析出一个新的任务列表
新的任务列表至少包含一个任务
任务列表中的每个任务代表的是一个后续智能体要执行的任务其中scheduler代表你自己大多数情况下这代表了用户回复了你的提问
以下是任务列表的几种情况的示例
1 任务列表为空时你需要生成一个新的任务列表此时你的分析结果是需要用户
1 `wait_for_input` 向用户问好并介绍你作为爆款短剧操盘手的身份和专业工作流程礼貌地请用户提供需要改编的原始剧本如果用户没有提供原始剧本你将持续友好地提醒此时状态始终为waiting直到获取原始剧本为止从用户提交的中可以获取到session_id的时候需要调用 `QueryOriginalScript` 工具来查询原始剧本是否存在
2 `script_analysis` 读取到原始剧本并从输入中分析出可以继续后进入调用`原始剧本分析 智能体`继续后续工作running时礼貌回复用户并提醒用户任务真正进行中completed代表任务完成此时可等待用户反馈直到跟用户确认可以进行下一步后再继续后续任务
3 `strategic_planning` 根据`诊断与资产评估`的结果调用`确立改编目标 智能体`并返回结果
4 `build_bible` 根据`改编思路`的结果调用`剧本圣经构建 智能体`并返回结果
5 `episode_create_loop` 根据`剧本圣经`的结果调用`剧集创作 智能体`
5 `finish` 所有剧集完成后设置为该状态但是不要返回node==end_node因为用户还可以继续输入来进一步修改产出内容
***当任意一个智能体返回失败时你需要分析reason字段中的内容来决定是否进行重试如果需要重试则给retry_count加1并交给失败的那个智能体重试一次如果retry_count超过了3次或者失败原因不适合重试则反馈给用户说任务失败了请稍后再试***
# 输出规范
请严格按照下列JSON结构返回数据不要有其他任何多余的信息和描述 请严格按照下列JSON结构返回数据不要有其他任何多余的信息和描述
{{ {{
"message":'',//回复给用户的内容注意仅在你与用户沟通时才返回其他情况下不返回比如用户的需求是要交给其他智能体处理时这个属性应该为空 "message": "回复用户的内容(仅需用户沟通时填充)",
"task_list": [] //最新的任务列表 "task_list": [更新后的任务列表数组],
"task_index": 0 //执行中的任务索引 "task_index": 当前执行任务索引,
}} }}
""" """
def create_agent_prompt(prompt, SchedulerList): def create_agent_prompt(prompt, SchedulerList):
"""创建代理提示词的辅助函数""" """创建代理提示词的辅助函数"""
if not SchedulerList or len(SchedulerList) == 0: return prompt if not SchedulerList or len(SchedulerList) == 0: return prompt
@ -149,13 +77,56 @@ def create_agent_prompt(prompt, SchedulerList):
{node_list} \n {node_list} \n
""" """
class SchedulerAgentState(AgentState):
"""调度智能体中的上下文对象"""
is_original_script: bool # 是否已提交原始剧本
is_script_analysis: bool # 是否已生成 诊断与资产评估报告
is_strategic_planning: bool # 是否已生成 改编思路
is_build_bible: bool # 是否已生成 剧本圣经
is_episode_create_loop: bool # 是否已生成 剧集生成循环
is_all_episode_created: bool # 是否已生成 全部剧集
def pre_scheduler_hook(state:SchedulerAgentState):
"""模型调用前的钩子函数"""
logger.info(f"!!!!!!!!!调度节点输入!!!!!!!!!!!")
session_id = state.get("session_id", "")
workflow_info = state.get("workflow_info", {})
task_list = state.get("task_list", [])
task_index = int(state.get("task_index", 0))
from_type = state.get("from_type", "")
messages = state.get("messages", [])
logger.info(f"调度节点输入: {state}")
logger.info(f"!!!!!!!!!调度节点输入!!!!!!!!!!!")
# 清除历史状态消息
# messages = clear_messages(messages)
# # 添加参数进提示词
# messages.append(HumanMessage(content=f"""
# ---任务状态消息(开始)---
# # 工作流信息:
# {workflow_info}
# # 工作流参数:
# {{
# 'session_id':'{session_id}',
# 'task_index':'{task_index}',
# 'from_type':'{from_type}',
# }}
# # 任务列表:
# {task_list}
# ---任务状态消息(结束)---
# """))
# return state
class SchedulerAgent(CompiledStateGraph): class SchedulerAgent(CompiledStateGraph):
"""智能调度智能体类 """智能调度智能体类
该类负责接收用户的提示词并调用其他智能体来处理工作 该类负责接收用户的提示词并调用其他智能体来处理工作
""" """
def __new__(cls, llm=None, tools=[], SchedulerList=None): def __new__(cls,
llm=None,
tools=[],
SchedulerList=None,
post_model_hook=None,
):
"""创建并返回create_react_agent创建的对象""" """创建并返回create_react_agent创建的对象"""
# 处理默认参数 # 处理默认参数
if llm is None: if llm is None:
@ -169,5 +140,9 @@ class SchedulerAgent(CompiledStateGraph):
return create_react_agent( return create_react_agent(
model=llm, model=llm,
tools=tools, tools=tools,
prompt=create_agent_prompt(prompt=DefaultAgentPrompt, SchedulerList=SchedulerList), prompt=DefaultAgentPrompt,
# pre_model_hook=pre_scheduler_hook,
# post_model_hook=post_model_hook,
# state_schema=SchedulerAgentState,
# prompt=create_agent_prompt(prompt=DefaultAgentPrompt, SchedulerList=SchedulerList),
) )

View File

@ -7,6 +7,7 @@ from re import T
from typing import TypedDict, Annotated, Dict, Any, List, TypedDict, Optional from typing import TypedDict, Annotated, Dict, Any, List, TypedDict, Optional
from langgraph.graph.state import RunnableConfig from langgraph.graph.state import RunnableConfig
from langgraph.prebuilt.chat_agent_executor import AgentState
from agent.scheduler import SchedulerAgent from agent.scheduler import SchedulerAgent
from agent.build_bible import BuildBibleAgent from agent.build_bible import BuildBibleAgent
from agent.episode_create import EpisodeCreateAgent from agent.episode_create import EpisodeCreateAgent
@ -24,7 +25,8 @@ import config
from tools.database.mongo import client # type: ignore from tools.database.mongo import client # type: ignore
from langgraph.checkpoint.mongodb import MongoDBSaver from langgraph.checkpoint.mongodb import MongoDBSaver
# 工具方法 # 工具方法
from tools.agent.queryDB import QueryOriginalScript from tools.agent.queryDB import QueryOriginalScript,QueryDiagnosisAndAssessment,QueryAdaptationIdeas,QueryScriptBible,QueryEpisodeCount
from tools.agent.updateDB import UpdateAdaptationIdeasTool,UpdateScriptBibleTool,UpdateDiagnosisAndAssessmentTool,UpdateOneEpisodeTool
logger = get_logger(__name__) logger = get_logger(__name__)
@ -55,6 +57,7 @@ def clear_messages(messages):
def messages_handler(old_messages: list[AnyMessage], new_messages: list[AnyMessage]): def messages_handler(old_messages: list[AnyMessage], new_messages: list[AnyMessage]):
"""消息合并方法""" """消息合并方法"""
clear_messages(old_messages) clear_messages(old_messages)
new_messages = [message for message in new_messages if message.type != 'ai' or message.content]
return old_messages + new_messages return old_messages + new_messages
def replace_value(old_val, new_val): def replace_value(old_val, new_val):
@ -141,89 +144,82 @@ class ScriptwriterGraph:
def node_router(self, state: ScriptwriterState) -> str: def node_router(self, state: ScriptwriterState) -> str:
"""节点路由函数""" """节点路由函数"""
# return 'end_node'
# print(f'node_router state {state}') # print(f'node_router state {state}')
task_list = state.get("task_list", []) task_list = state.get("task_list", [])
task_index = state.get("task_index", 0) task_index = state.get("task_index", 0)
now_task = task_list[task_index] now_task = task_list[task_index]
print(f'node_router now_task {now_task}')
if not now_task or now_task.get('pause'): if not now_task or now_task.get('pause'):
next_node = 'pause_node' # 设置为暂停节点 next_node = 'pause_node' # 设置为暂停节点
else: else:
next_node = AgentNodeMap.get(now_task.get('agent'), 'pause_node') next_node = AgentNodeMap.get(now_task.get('agent'), 'pause_node')
# print(f'node_router next_node {next_node}') print(f'node_router next_node {next_node}')
return next_node return next_node
def post_scheduler_hook(self, state: ScriptwriterState)-> ScriptwriterState:
"""模型调用后的钩子函数"""
ai_message_str = state.get("messages", [])[-1].content
logger.info(f"!!!!!!!!!调度节点输出!!!!!!!!!!!")
logger.info(f"调度节点结果: {ai_message_str}")
# logger.info(f"调度节点结果 end")
# ai_message = json.loads(ai_message_str)
return state
def _build_graph(self) -> None: def _build_graph(self) -> None:
"""构建工作流图""" """构建工作流图"""
try: try:
# 创建智能体 # 创建智能体
print("创建智能体") print("创建智能体")
from tools.llm.deepseek_langchain import DeepseekChatModel
llm = DeepseekChatModel(api_key='sk-571923fdcb0e493d8def7e2d78c02cb8')
# 调度智能体 # 调度智能体
self.schedulerAgent = SchedulerAgent( self.schedulerAgent = SchedulerAgent(
# llm=llm,
tools=[ tools=[
QueryOriginalScript, QueryOriginalScript,
QueryDiagnosisAndAssessment,
QueryAdaptationIdeas,
QueryScriptBible,
QueryEpisodeCount,
], ],
SchedulerList=[
{
"name": "scheduler_node",
"desc": "调度智能体节点",
},
{
"name": "script_analysis_node",
"desc": "原始剧本分析节点",
},
{
"name": "strategic_planning_node",
"desc": "确立改编目标节点",
},
{
"name": "build_bible_node",
"desc": "剧本圣经构建节点",
},
{
"name": "episode_create_node",
"desc": "单集创作节点",
},
{
"name": "end_node",
"desc": "结束节点,任务失败终止时使用,结束后整个工作流将停止"
}
]
) )
self.scriptAnalysisAgent = ScriptAnalysisAgent( self.scriptAnalysisAgent = ScriptAnalysisAgent(
tools=[], tools=[],
SchedulerList=[ # SchedulerList=[
{ # {
"name": "scheduler_node", # "name": "scheduler_node",
"desc": "调度智能体节点", # "desc": "调度智能体节点",
} # }
] # ]
) )
self.strategicPlanningAgent = StrategicPlanningAgent( self.strategicPlanningAgent = StrategicPlanningAgent(
tools=[], tools=[],
SchedulerList=[ # SchedulerList=[
{ # {
"name": "scheduler_node", # "name": "scheduler_node",
"desc": "调度智能体节点", # "desc": "调度智能体节点",
} # }
] # ]
) )
self.buildBibleAgent = BuildBibleAgent( self.buildBibleAgent = BuildBibleAgent(
tools=[], tools=[],
SchedulerList=[ # SchedulerList=[
{ # {
"name": "scheduler_node", # "name": "scheduler_node",
"desc": "调度智能体节点", # "desc": "调度智能体节点",
} # }
] # ]
) )
self.episodeCreateAgent = EpisodeCreateAgent( self.episodeCreateAgent = EpisodeCreateAgent(
tools=[], tools=[],
SchedulerList=[ # SchedulerList=[
{ # {
"name": "scheduler_node", # "name": "scheduler_node",
"desc": "调度智能体节点", # "desc": "调度智能体节点",
} # }
] # ]
) )
# 创建状态图 # 创建状态图
@ -272,6 +268,8 @@ class ScriptwriterGraph:
logger.info("工作流图构建完成") logger.info("工作流图构建完成")
except Exception as e: except Exception as e:
import traceback
traceback.print_exc()
logger.error(f"构建工作流图失败: {e}") logger.error(f"构建工作流图失败: {e}")
raise raise
@ -289,10 +287,10 @@ class ScriptwriterGraph:
# 清除历史状态消息 # 清除历史状态消息
messages = clear_messages(messages) messages = clear_messages(messages)
# 添加参数进提示词 # 添加参数进提示词
# 工作流信息:
# {workflow_info}
messages.append(HumanMessage(content=f""" messages.append(HumanMessage(content=f"""
---任务状态消息(开始)--- ---任务状态消息(开始)---
# 工作流信息:
{workflow_info}
# 工作流参数: # 工作流参数:
{{ {{
'session_id':'{session_id}', 'session_id':'{session_id}',
@ -315,8 +313,13 @@ class ScriptwriterGraph:
ai_message_count += 1 ai_message_count += 1
logger.info(f"调度节点 {session_id} \n 输入消息条数: {len(messages)} \n from_type:{from_type} \n system_message_count:{system_message_count} \n human_message_count:{human_message_count} \n ai_message_count:{ai_message_count}") logger.info(f"调度节点 {session_id} \n 输入消息条数: {len(messages)} \n from_type:{from_type} \n system_message_count:{system_message_count} \n human_message_count:{human_message_count} \n ai_message_count:{ai_message_count}")
# 调用智能体 # 调用智能体
reslut = await self.schedulerAgent.ainvoke({"messages":messages}) reslut = await self.schedulerAgent.ainvoke({
"messages": messages,
"session_id": session_id,
})
ai_message_str = reslut['messages'][-1].content ai_message_str = reslut['messages'][-1].content
logger.info(f"调度节点结果: {ai_message_str}")
logger.info(f"调度节点结果 end")
ai_message = json.loads(ai_message_str) ai_message = json.loads(ai_message_str)
# logger.info(f"调度节点结果: {ai_message}") # logger.info(f"调度节点结果: {ai_message}")
return_message:str = ai_message.get('message', '') return_message:str = ai_message.get('message', '')
@ -324,10 +327,9 @@ class ScriptwriterGraph:
task_index:int = int(ai_message.get('task_index', '0')) task_index:int = int(ai_message.get('task_index', '0'))
return { return {
"messages": messages, "agent_message": return_message,
"task_list": task_list, "task_list": task_list,
"task_index": task_index, "task_index": task_index,
"agent_message": return_message,
} }
except Exception as e: except Exception as e:
import traceback import traceback
@ -621,7 +623,12 @@ class ScriptwriterGraph:
""" """
try: try:
logger.info("开始运行智能编剧工作流") logger.info("开始运行智能编剧工作流")
output_result: OutputState = {
'session_id': session_id,
'status': '',
'error': '',
'agent_message': '',
}
# 配置包含线程 ID # 配置包含线程 ID
config:RunnableConfig = {"configurable": {"thread_id": thread_id}} config:RunnableConfig = {"configurable": {"thread_id": thread_id}}
# 初始化状态 # 初始化状态
@ -630,11 +637,10 @@ class ScriptwriterGraph:
'session_id': session_id, 'session_id': session_id,
'from_type': 'user', 'from_type': 'user',
} }
# 运行工作流 # 运行工作流
if self.graph is None: if self.graph is None:
raise RuntimeError("工作流图未正确初始化") raise RuntimeError("工作流图未正确初始化")
# 运行工作流 直接返回最终结果
result = await self.graph.ainvoke( result = await self.graph.ainvoke(
initial_state, initial_state,
config, config,
@ -643,14 +649,23 @@ class ScriptwriterGraph:
# logger.info(f"工作流执行结果: {result}") # logger.info(f"工作流执行结果: {result}")
if not result: if not result:
raise ValueError("工作流执行结果为空") raise ValueError("工作流执行结果为空")
# 构造输出状态 # 构造输出状态
output_result: OutputState = { output_result = {
'session_id': result.get('session_id', ''), 'session_id': result.get('session_id', ''),
'status': result.get('status', ''), 'status': result.get('status', ''),
'error': result.get('error', ''), 'error': result.get('error', ''),
'agent_message': result.get('agent_message', ''), 'agent_message': result.get('agent_message', ''),
} }
# 流式处理
# st = self.graph.stream(
# initial_state,
# config,
# stream_mode='values'
# )
# from utils.stream import debug_print_stream
# debug_print_stream(st)
logger.info("智能编剧工作流运行完成") logger.info("智能编剧工作流运行完成")
return output_result return output_result

View File

@ -1,6 +1,11 @@
from typing import Annotated
from bson import ObjectId from bson import ObjectId
from langchain_core.messages import ToolMessage
from langchain_core.tools import InjectedToolCallId
from langgraph.types import Command
from tools.database.mongo import mainDB from tools.database.mongo import mainDB
from langchain.tools import tool from langchain.tools import tool
import json
@tool @tool
def QueryOriginalScript(session_id: str): def QueryOriginalScript(session_id: str):
@ -13,9 +18,24 @@ def QueryOriginalScript(session_id: str):
exist (bool): 原始剧本内容是否存在 exist (bool): 原始剧本内容是否存在
""" """
script = mainDB.agent_writer_session.find_one({"_id": ObjectId(session_id), "original_script": {"$exists": True, "$ne": ""}},{"_id":1}) script = mainDB.agent_writer_session.find_one({"_id": ObjectId(session_id), "original_script": {"$exists": True, "$ne": ""}},{"_id":1})
is_original_script = script is not None
# print(f"tool_call_id {tool_call_id}")
Command(update={
"is_original_script": is_original_script
})
return { return {
"exist": script is not None, "exist": is_original_script,
} }
tool_message_content = json.dumps({"exist": is_original_script})
# return Command(update={
# "is_original_script": is_original_script,
# "messages": [
# ToolMessage(
# tool_call_id=tool_call_id,
# content=tool_message_content
# )
# ]
# })
def QueryOriginalScriptContent(session_id: str): def QueryOriginalScriptContent(session_id: str):
""" """

View File

@ -0,0 +1,85 @@
from typing import Any, List, Optional
from langchain_core.callbacks.manager import CallbackManagerForLLMRun
from langchain_core.messages import BaseMessage
from langchain_core.outputs import ChatResult
from langchain_community.chat_models import ChatOpenAI
import os
# 继承自 ChatOpenAI 而不是 BaseChatModel
class DeepseekChatModel(ChatOpenAI):
"""
DeepSeek 聊天模型的 LangChain 封装
这个版本通过继承 ChatOpenAI 来实现并重写 _generate 方法
"""
def __init__(self, model_name: str = "deepseek-chat", **kwargs: Any):
"""
初始化 DeepseekChatModel
"""
# 从环境变量或参数中获取 api_key
api_key = kwargs.pop("api_key", os.getenv("DEEPSEEK_API_KEY"))
if not api_key:
raise ValueError(
"DeepSeek API key must be provided either as an argument or set as the DEEPSEEK_API_KEY environment variable."
)
# 调用父类ChatOpenAI的构造函数并传入 DeepSeek 的特定配置
super().__init__(
model=model_name,
api_key=api_key,
base_url="https://api.deepseek.com/v1", # 这是关键
**kwargs,
)
def _generate(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> ChatResult:
"""
重写 _generate 方法这是 LangChain 的核心调用点
"""
# 1. 在调用父类方法前,执行你的自定义逻辑(例如,打印日志)
print("\n--- [DeepseekChatModel] 调用 _generate ---")
print(f"输入消息数量: {len(messages)}")
print(f"第一条消息内容: {messages[0].content}")
print("-------------------------------------------\n")
# 2. 使用 super() 调用父类ChatOpenAI的原始 _generate 方法
# 这样可以复用 ChatOpenAI 中所有复杂的、经过测试的逻辑
chat_result = super()._generate(
messages=messages, stop=stop, run_manager=run_manager, **kwargs
)
# 3. 在获得结果后,你还可以执行其他自定义逻辑
print("\n--- [DeepseekChatModel] _generate 调用完成 ---")
print(f"输出消息内容: {chat_result.generations[0].message.content[:80]}...") # 打印部分输出
print("--------------------------------------------\n")
return chat_result
@property
def _llm_type(self) -> str:
"""返回 language model 的类型。"""
return "deepseek-chat-model-v2" # 改个名以示区别
# --- 使用示例 ---
if __name__ == "__main__":
# 确保设置了 API 密钥
if not os.getenv("DEEPSEEK_API_KEY"):
print("请设置 DEEPSEEK_API_KEY 环境变量。")
else:
# 初始化模型
chat_model = DeepseekChatModel(temperature=0.7)
# 构建输入消息
from langchain_core.messages import HumanMessage
messages = [HumanMessage(content="你好,请介绍一下新加坡。")]
# 调用模型
# 当你调用 .invoke() 或 .stream() 时LangChain 内部会调用我们重写的 _generate 方法
response = chat_model.invoke(messages)
print(f"\n最终得到的回复:\n{response.content}")

View File

@ -8,6 +8,7 @@ from langchain_core.runnables import Runnable
from pydantic import Field from pydantic import Field
import json import json
import copy import copy
from langchain_core.messages import AnyMessage,HumanMessage
from volcenginesdkarkruntime.types.chat import ChatCompletionToolParam from volcenginesdkarkruntime.types.chat import ChatCompletionToolParam
from volcenginesdkarkruntime.types.shared_params.function_definition import FunctionDefinition from volcenginesdkarkruntime.types.shared_params.function_definition import FunctionDefinition
from volcenginesdkarkruntime.types.shared_params.function_parameters import FunctionParameters from volcenginesdkarkruntime.types.shared_params.function_parameters import FunctionParameters
@ -174,15 +175,16 @@ class HuoshanChatModel(BaseChatModel):
api_messages = self._convert_messages_to_prompt(messages) api_messages = self._convert_messages_to_prompt(messages)
tools = kwargs.get("tools", []) tools = kwargs.get("tools", [])
# print(f" 提交给豆包的 messages数组长度: \n {len(messages)} \n tools: {tools}") # print(f" 提交给豆包的 messages数组长度: \n {len(messages)} \n tools: {tools}")
print(f" 提交给豆包的消息 =======================>>>>>>>>>>>> \n") # print(f" 提交给豆包的消息 =======================>>>>>>>>>>>> \n")
print(f"\nmessages: \n") # print(f"\nmessages: \n")
for message in messages: # for message in messages:
print(f" {message.type}: \n ") # print("--- Message Attributes ---\n")
print(f" {message.content} \n ") # for key, value in message.__dict__.items():
print(f"\ntools: \n") # print(f" {key}: {value} \n")
for tool in tools: # print(f"\ntools: \n")
print(f" \n {tool} \n ") # for tool in tools:
print(f" 提交给豆包的消息 =======================>>>>>>>>>>>> end \n") # print(f" {tool} \n ")
# print(f" 提交给豆包的消息 =======================>>>>>>>>>>>> end \n")
response_data = self._api.get_chat_response(messages=api_messages, tools=tools) response_data = self._api.get_chat_response(messages=api_messages, tools=tools)
@ -192,7 +194,7 @@ class HuoshanChatModel(BaseChatModel):
message_from_api = res_choices.get("message", {}) message_from_api = res_choices.get("message", {})
tool_calls = message_from_api.get("tool_calls", []) tool_calls = message_from_api.get("tool_calls", [])
print(f" 豆包返回的 finish_reason: {finish_reason} \n tool_calls: {tool_calls} \n") print(f" 豆包返回的 finish_reason: {finish_reason} \n tool_calls: {tool_calls} \n")
print(f" 豆包返回的 message: {message_from_api.get('content', '')}") # print(f" 豆包返回的 message: {message_from_api.get('content', '')}")
if finish_reason == "tool_calls" and tool_calls: if finish_reason == "tool_calls" and tool_calls:
lc_tool_calls = [] lc_tool_calls = []
for tc in tool_calls: for tc in tool_calls:

View File

@ -0,0 +1,193 @@
import os
from typing import Any, List, Optional, Type
from langchain_community.chat_models import ChatOpenAI
from langchain_core.messages import HumanMessage
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.pydantic_v1 import BaseModel, Field
from langchain_core.tools import BaseTool, tool
# --- 1. 中央模型配置 ---
# 在这里添加或修改你想要支持的模型提供商
# 'api_key_env': 存放 API Key 的环境变量名称
# 'base_url': 模型的 API 端点
# 'supports_tools': 该模型是否支持 LangChain 的 .bind_tools() 功能
# 'default_model': 如果不指定模型,默认使用的模型名称
MODEL_CONFIG = {
"openai": {
"api_key_env": "OPENAI_API_KEY",
"base_url": "https://api.openai.com/v1",
"supports_tools": True,
"default_model": "gpt-4o",
},
"deepseek": {
"api_key_env": "DEEPSEEK_API_KEY",
"base_url": "https://api.deepseek.com/v1",
"supports_tools": True,
"default_model": "deepseek-chat",
},
"groq": {
"api_key_env": "GROQ_API_KEY",
"base_url": "https://api.groq.com/openai/v1",
"supports_tools": True,
"default_model": "llama3-70b-8192",
},
"moonshot": {
"api_key_env": "MOONSHOT_API_KEY",
"base_url": "https://api.moonshot.cn/v1",
"supports_tools": True,
"default_model": "moonshot-v1-8k",
},
# 示例:一个不支持工具的老模型
"legacy_provider": {
"api_key_env": "LEGACY_API_KEY",
"base_url": "http://localhost:8080/v1", # 假设是本地服务
"supports_tools": False,
"default_model": "legacy-model-v1",
},
# 新增火山引擎的配置
"huoshan": {
"api_key_env": "VOLC_API_KEY", # 请确认您存放密钥的环境变量名
"base_url": "https://ark.cn-beijing.volces.com/api/v3", # 这是火山引擎方舟的 OpenAI 兼容端点
"supports_tools": True,
"default_model": "ep-20240615082149-j225c", # 使用您需要的模型 Endpoint ID
},
}
# --- 2. 模型创建工厂函数 ---
def create_llm_client(
provider: str,
model_name: Optional[str] = None,
tools: Optional[List[Type[BaseModel]]] = None,
**kwargs: Any,
) -> ChatOpenAI:
"""
根据提供的 provider 创建并配置一个 LangChain 聊天模型客户端
Args:
provider: 模型提供商的名称 (必须是 MODEL_CONFIG 中的一个 key)
model_name: 要使用的具体模型名称如果为 None, 则使用配置中的默认模型
tools: 一个工具列表用于绑定到模型上以实现 function calling
**kwargs: 其他要传递给 ChatOpenAI 的参数 (例如 temperature, max_tokens)
Returns:
一个配置好的 ChatOpenAI 实例可能已经绑定了工具
"""
if provider not in MODEL_CONFIG:
raise ValueError(f"不支持的模型提供商: {provider}。可用选项: {list(MODEL_CONFIG.keys())}")
config = MODEL_CONFIG[provider]
api_key = os.getenv(config["api_key_env"])
if not api_key:
raise ValueError(f"请设置环境变量 {config['api_key_env']} 以使用 {provider} 模型。")
# 如果未指定 model_name则使用配置中的默认值
final_model_name = model_name or config["default_model"]
# 创建基础的 LLM 客户端
llm = ChatOpenAI(
model=final_model_name,
api_key=api_key,
base_url=config["base_url"],
**kwargs,
)
# 根据配置,有条件地绑定工具
if tools:
if config["supports_tools"]:
print(f"[{provider}] 模型支持工具,正在绑定 {len(tools)} 个工具...")
return llm.bind_tools(tools)
else:
print(f"⚠️ 警告: 您为 [{provider}] 提供了工具,但该模型在配置中被标记为不支持工具。将返回未绑定工具的模型。")
return llm
return llm
# --- 3. 定义你的工具 (Function Calling) ---
# 使用 Pydantic 模型定义工具的输入参数,确保类型安全和清晰的描述
class GetWeatherInput(BaseModel):
city: str = Field(description="需要查询天气的城市名称, 例如: Singapore")
# 使用 @tool 装饰器可以轻松地将任何函数转换为 LangChain 工具
@tool(args_schema=GetWeatherInput)
def get_current_weather(city: str) -> str:
"""
当需要查询指定城市的当前天气时调用此工具
"""
# 这是一个模拟实现实际应用中你会在这里调用真实的天气API
print(f"--- 正在调用工具: get_current_weather(city='{city}') ---")
if "singapore" in city.lower():
return f"新加坡今天的天气是晴朗,温度为 31°C。"
elif "beijing" in city.lower():
return f"北京今天的天气是多云,温度为 25°C。"
else:
return f"抱歉,我无法查询到 {city} 的天气信息。"
# --- 4. 主程序:演示如何使用 ---
if __name__ == "__main__":
# 将你的工具放入一个列表
my_tools = [get_current_weather]
# ---- 示例 1: 使用 DeepSeek 并调用工具 ----
print("\n================ 示例 1: 使用 DeepSeek 调用工具 ================")
# 确保你已经设置了环境变量: export DEEPSEEK_API_KEY="sk-..."
try:
deepseek_llm_with_tools = create_llm_client(
provider="deepseek",
tools=my_tools,
temperature=0
)
prompt = "今天新加坡的天气怎么样?"
print(f"用户问题: {prompt}")
# LangChain 会自动处理LLM -> Tool Call -> Execute Tool -> LLM -> Final Answer
# 为了演示,我们只看第一步的输出
ai_msg = deepseek_llm_with_tools.invoke(prompt)
print("\nLLM 返回的初步响应 (AIMessage):")
print(ai_msg)
# 检查返回的是否是工具调用
if ai_msg.tool_calls:
print(f"\n模型请求调用工具: {ai_msg.tool_calls[0]['name']}")
# 在实际应用中,你会在这里执行工具并把结果返回给模型
else:
print("\n模型直接给出了回答:")
print(ai_msg.content)
except ValueError as e:
print(e)
# ---- 示例 2: 使用 Groq (Llama3) 且不使用工具 ----
print("\n================ 示例 2: 使用 Groq 进行常规聊天 ================")
# 确保你已经设置了环境变量: export GROQ_API_KEY="gsk_..."
try:
groq_llm = create_llm_client(
provider="groq",
temperature=0.7,
# model_name="llama3-8b-8192" # 你也可以覆盖默认模型
)
prompt = "请给我写一首关于新加坡的五言绝句。"
print(f"用户问题: {prompt}")
response = groq_llm.invoke(prompt)
print("\nGroq (Llama3) 的回答:")
print(response.content)
except ValueError as e:
print(e)
# ---- 示例 3: 尝试给不支持工具的模型绑定工具 ----
print("\n========== 示例 3: 尝试为不支持工具的模型绑定工具 ==========")
# 假设你设置了 export LEGACY_API_KEY="some_key"
try:
legacy_llm = create_llm_client(
provider="legacy_provider",
tools=my_tools
)
# 注意,这里会打印出警告信息,并且 legacy_llm 不会绑定任何工具
except ValueError as e:
print(e)

8
utils/stream.py Normal file
View File

@ -0,0 +1,8 @@
def debug_print_stream(stream):
for chunk in stream:
message = chunk['messages'][-1]
if isinstance(message, tuple):
print(chunk)
else:
message.pretty_print()
print("\n---")