1168 lines
43 KiB
Python

"""
WebSocket Streaming API
提供实时执行进度更新的 WebSocket 端点
"""
from fastapi import APIRouter, WebSocket, WebSocketDisconnect, Depends, Query
from typing import Dict, Set, Optional, Any
import json
import asyncio
from datetime import datetime
from pathlib import Path
from app.config import settings
from app.utils.logger import get_logger
from app.core.agent_runtime.director_agent import DirectorAgent
from app.db.repositories import message_repo
# Module-level logger.
logger = get_logger(__name__)
# Router that carries this module's WebSocket endpoints.
router = APIRouter()
# ============================================
# WebSocket 连接管理
# ============================================
class ConnectionManager:
    """Track WebSocket subscriptions per project / batch and one DirectorAgent per project.

    Attributes (all plain dicts, mutated only from this class):
        project_connections: project id -> sockets subscribed to that project's stream.
        batch_connections: batch id -> sockets subscribed to that batch's stream.
        project_agents: project id -> cached DirectorAgent; dropped together with
            the project's last connection.
    """

    def __init__(self):
        # project id -> set of WebSocket connections
        self.project_connections: Dict[str, Set[WebSocket]] = {}
        # batch id -> set of WebSocket connections
        self.batch_connections: Dict[str, Set[WebSocket]] = {}
        # project id -> DirectorAgent instance
        self.project_agents: Dict[str, DirectorAgent] = {}

    async def connect_to_project(self, websocket: WebSocket, project_id: str):
        """Accept *websocket* and subscribe it to *project_id*'s execution stream."""
        await websocket.accept()
        self.project_connections.setdefault(project_id, set()).add(websocket)
        logger.info(f"WebSocket 已连接到项目: {project_id}")

    async def connect_to_batch(self, websocket: WebSocket, batch_id: str):
        """Accept *websocket* and subscribe it to *batch_id*'s execution stream."""
        await websocket.accept()
        self.batch_connections.setdefault(batch_id, set()).add(websocket)
        logger.info(f"WebSocket 已连接到批次: {batch_id}")

    def disconnect(self, websocket: WebSocket):
        """Remove *websocket* from every project and batch it is subscribed to.

        Empty subscription sets are deleted, and a project's cached agent is
        discarded together with its last connection. Calling this again for an
        already-removed socket is a no-op.
        """
        # Iterate over snapshots: entries are deleted inside the loops, and
        # deleting from a dict while iterating its live items() view raises
        # "RuntimeError: dictionary changed size during iteration".
        for project_id, connections in list(self.project_connections.items()):
            if websocket not in connections:
                continue
            connections.discard(websocket)
            logger.info(f"WebSocket 已从项目断开: {project_id}")
            if not connections:
                del self.project_connections[project_id]
                # The agent lives only as long as someone is connected to the project.
                self.project_agents.pop(project_id, None)

        for batch_id, connections in list(self.batch_connections.items()):
            if websocket not in connections:
                continue
            connections.discard(websocket)
            logger.info(f"WebSocket 已从批次断开: {batch_id}")
            if not connections:
                del self.batch_connections[batch_id]

    async def _send_to_all(
        self,
        connections,
        message: Dict[str, Any],
        exclude: Optional[WebSocket]
    ):
        """Send *message* to every socket in *connections* except *exclude*.

        Sockets whose send fails are logged and then fully disconnected.
        """
        failed = []
        # Snapshot first: each await lets other coroutines connect/disconnect
        # sockets, which would otherwise change the set while it is iterated.
        for connection in list(connections):
            if connection == exclude:
                continue
            try:
                await connection.send_json(message)
            except Exception as e:
                logger.error(f"发送消息失败: {str(e)}")
                failed.append(connection)
        for connection in failed:
            self.disconnect(connection)

    async def send_to_project(
        self,
        project_id: str,
        message: Dict[str, Any],
        exclude: Optional[WebSocket] = None
    ):
        """Broadcast *message* to all of *project_id*'s sockets (optionally skipping *exclude*)."""
        await self._send_to_all(self.project_connections.get(project_id, ()), message, exclude)

    async def send_to_batch(
        self,
        batch_id: str,
        message: Dict[str, Any],
        exclude: Optional[WebSocket] = None
    ):
        """Broadcast *message* to all of *batch_id*'s sockets (optionally skipping *exclude*)."""
        await self._send_to_all(self.batch_connections.get(batch_id, ()), message, exclude)

    def get_agent(self, project_id: str, working_dir: Path) -> DirectorAgent:
        """Return the cached DirectorAgent for *project_id*, creating it on first use.

        A new agent gets *working_dir* created if missing and uses
        ``settings.zai_model``; thinking mode is disabled for GLM models, which
        do not support it. The project context starts as ``None`` because
        loading it needs an async repository call; message handlers fill it in
        later via ``_ensure_agent_context``.
        """
        if project_id not in self.project_agents:
            working_dir.mkdir(parents=True, exist_ok=True)
            model_name = settings.zai_model
            enable_thinking = "glm" not in model_name.lower()
            self.project_agents[project_id] = DirectorAgent(
                working_directory=working_dir,
                enable_thinking=enable_thinking,
                model=model_name,
                project_context=None  # injected later by the message handlers
            )
        return self.project_agents[project_id]

    def get_project_connections_count(self, project_id: str) -> int:
        """Return how many sockets are subscribed to *project_id*."""
        return len(self.project_connections.get(project_id, set()))

    def get_batch_connections_count(self, batch_id: str) -> int:
        """Return how many sockets are subscribed to *batch_id*."""
        return len(self.batch_connections.get(batch_id, set()))
# Process-wide connection manager shared by the WebSocket endpoints and the
# broadcast helpers in this module.
manager = ConnectionManager()
# ============================================
# WebSocket 端点
# ============================================
@router.websocket("/ws/projects/{project_id}/execute")
async def websocket_project_execution(websocket: WebSocket, project_id: str):
    """
    Project execution WebSocket endpoint.

    Accepts and registers the socket, sends a ``connected`` acknowledgement and
    any stored chat history, then decodes each incoming text frame as JSON and
    hands it to ``_handle_client_message`` until the client disconnects.
    Per-message failures are reported back to this client as ``error`` events;
    the socket is always unregistered on exit.
    """
    await manager.connect_to_project(websocket, project_id)

    # NOTE(review): hard-coded, machine-specific workspace path; it should come
    # from configuration rather than being baked into the endpoint.
    project_dir = Path(f"d:/platform/creative_studio/workspace/projects/{project_id}")

    async def reply_error(text: str) -> None:
        await websocket.send_json({
            "type": "error",
            "data": {
                "message": text
            }
        })

    try:
        await websocket.send_json({
            "type": "connected",
            "data": {
                "project_id": project_id,
                "message": "已连接到项目执行流",
                "timestamp": datetime.now().isoformat()
            }
        })

        # Replay persisted history so a (re)connecting client can rebuild its view.
        history = await message_repo.get_history(project_id)
        if history:
            await websocket.send_json({"type": "history", "messages": history})

        keep_reading = True
        while keep_reading:
            try:
                raw_frame = await websocket.receive_text()
                await _handle_client_message(
                    websocket, project_id, json.loads(raw_frame), project_dir
                )
            except WebSocketDisconnect:
                logger.info(f"WebSocket 客户端主动断开: {project_id}")
                keep_reading = False
            except json.JSONDecodeError:
                await reply_error("无效的 JSON 格式")
            except Exception as exc:
                logger.error(f"处理 WebSocket 消息错误: {str(exc)}")
                await reply_error(f"处理消息错误: {str(exc)}")
    except Exception as exc:
        logger.error(f"WebSocket 连接错误: {str(exc)}")
    finally:
        manager.disconnect(websocket)
@router.websocket("/ws/batches/{batch_id}/execute")
async def websocket_batch_execution(websocket: WebSocket, batch_id: str):
    """
    Batch execution WebSocket endpoint.

    Batch progress is pushed by the ``broadcast_batch_*`` helpers; the only
    client message acted upon is ``ping``, which is answered with ``pong`` to
    keep the connection alive. Other frames (including malformed JSON) are
    ignored or logged. The socket is unregistered on exit.
    """
    await manager.connect_to_batch(websocket, batch_id)
    try:
        await websocket.send_json({
            "type": "connected",
            "data": {
                "batch_id": batch_id,
                "message": "已连接到批量执行流",
                "timestamp": datetime.now().isoformat()
            }
        })

        while True:
            try:
                incoming = json.loads(await websocket.receive_text())
                if incoming.get("type") != "ping":
                    continue
                await websocket.send_json({
                    "type": "pong",
                    "data": {"timestamp": datetime.now().isoformat()}
                })
            except WebSocketDisconnect:
                logger.info(f"WebSocket 客户端主动断开: {batch_id}")
                break
            except Exception as exc:
                logger.error(f"处理 WebSocket 消息错误: {str(exc)}")
    finally:
        manager.disconnect(websocket)
# ============================================
# 消息处理
# ============================================
def _read_field(obj: Any, name: str, default: Any = None) -> Any:
    """Read *name* from *obj* as a dict key when *obj* is a dict, else as an attribute.

    Skill configs are already looked up in both shapes; task configs and
    character profiles presumably can be plain dicts as well, and a bare
    ``getattr`` on a dict would silently return *default*.
    """
    if isinstance(obj, dict):
        return obj.get(name, default)
    return getattr(obj, name, default)


def _collect_user_skills(project: Any, registry: Any) -> list:
    """Resolve the skills referenced in ``project.defaultTaskSkills``.

    Returns a list of ``{'id', 'name', 'behavior'}`` dicts. Unknown skill ids are
    skipped; a failure while resolving one skill is logged and that skill skipped.
    """
    collected = []
    task_configs = getattr(project, 'defaultTaskSkills', [])
    if not isinstance(task_configs, list):
        return collected
    for task_config in task_configs:
        skill_configs = _read_field(task_config, 'skills', [])
        if not isinstance(skill_configs, list):
            continue
        for skill_config in skill_configs:
            try:
                skill_id = _read_field(skill_config, 'skill_id')
                if not skill_id:
                    continue
                skill = registry.get_skill_by_id(skill_id)
                if skill:
                    collected.append({
                        'id': skill.id,
                        'name': skill.name,
                        'behavior': skill.behavior_guide or skill.description or ''
                    })
            except Exception as e:
                logger.warning(f"Failed to load skill: {e}")
    return collected


def _build_characters_text(global_context: Any) -> str:
    """Merge the style guide and character profiles into one prompt section.

    Dict-valued ``characterProfiles`` become a bulleted list (profiles may be
    dicts or objects); any other non-empty value is appended via ``str``.
    """
    text = ""
    if not global_context:
        return text
    if global_context.styleGuide:
        text += global_context.styleGuide + "\n\n"
    profiles = global_context.characterProfiles
    if profiles:
        if isinstance(profiles, dict):
            text += "### Detailed Character Profiles:\n"
            for name, profile in profiles.items():
                text += f"- **{name}**: {_read_field(profile, 'description', '')}\n"
                personality = _read_field(profile, 'personality')
                if personality:
                    text += f" Personality: {personality}\n"
        else:
            text += str(profiles) + "\n\n"
    return text


async def _ensure_agent_context(
    agent: DirectorAgent,
    project_id: str,
    active_episode_number: Optional[int] = None,
    active_episode_title: Optional[str] = None
):
    """
    Make sure *agent* carries the project context for *project_id*.

    Returns immediately when the agent already holds this project's context with
    the same active episode number. Otherwise loads the project, its default task
    skills, character material and episode list, injects them as a new
    ``SkillAgentContext``, rebuilds the system prompt and calls
    ``refresh_agent`` when the agent provides it.

    Errors (including an unknown project) are logged rather than raised.
    """
    current = agent.context
    if (current and
            current.project_id == project_id and
            current.active_episode_number == active_episode_number):
        return

    try:
        from app.db.repositories import project_repo, episode_repo
        from app.core.agent_runtime.context import SkillAgentContext
        from app.core.skills.skill_manager import skill_manager

        project = await project_repo.get(project_id)
        if not project:
            logger.warning(f"Project {project_id} not found in repository")
            return

        global_context = project.globalContext
        user_skills = _collect_user_skills(project, skill_manager)
        characters_text = _build_characters_text(global_context)

        episode_rows = await episode_repo.list_by_project(project_id)
        episodes = [
            {"number": ep.number, "title": ep.title, "status": ep.status}
            for ep in episode_rows
        ]

        # An uploaded script takes precedence over free-form inspiration.
        if global_context and global_context.uploadedScript:
            creation_mode = 'script'
            source_content = global_context.uploadedScript
        else:
            creation_mode = 'inspiration'
            source_content = global_context.inspiration if global_context else None

        project_context = SkillAgentContext(
            skill_loader=agent.context.skill_loader,
            working_directory=agent.context.working_directory,
            project_id=project.id,
            project_name=project.name,
            project_genre=getattr(project, 'genre', '古风'),
            total_episodes=project.totalEpisodes,
            world_setting=global_context.worldSetting if global_context else None,
            characters=characters_text.strip() or None,
            overall_outline=global_context.overallOutline if global_context else None,
            creation_mode=creation_mode,
            source_content=source_content,
            user_skills=user_skills,
            active_episode_number=active_episode_number,
            active_episode_title=active_episode_title,
            episodes=episodes
        )

        agent.context = project_context
        agent.system_prompt = agent._build_system_prompt()
        if hasattr(agent, 'refresh_agent'):
            agent.refresh_agent()

        msg = f"Successfully injected project context for {project_id}: {project.name}"
        if active_episode_number:
            msg += f" (Focusing on Episode {active_episode_number})"
        logger.info(msg)
    except Exception as e:
        logger.error(f"Failed to load project context for {project_id}: {e}", exc_info=True)
def _inbox_feedback(action: Any, item_id: Any) -> str:
    """Render an inbox decision as a one-line English sentence for the agent.

    Blindly appending "ed" produced e.g. "Approveed" for the default "Approve"
    option; verbs ending in "e" only take "d", and past-tense input is kept.
    """
    verb = "" if action is None else str(action)
    if not verb:
        past = "responded to"
    elif verb.endswith("ed"):
        past = verb
    elif verb.endswith("e"):
        past = verb + "d"
    else:
        past = verb + "ed"
    return f"User {past} inbox item {item_id}."


async def _stream_agent_to_project(agent: DirectorAgent, project_id: str, prompt: str) -> str:
    """Run one agent turn for *prompt* and broadcast every event it yields.

    ``agent.stream_events`` is consumed as a synchronous iterator, so each step is
    pulled on the default thread-pool executor. Iterating it directly inside this
    coroutine would block the event loop (and every other connection, including
    heartbeats) for the whole turn. ``tool_call`` events are first translated by
    ``_handle_tool_call``. Exceptions from the agent propagate to the caller.

    Returns:
        The concatenated ``content`` of all ``text`` events.
    """
    loop = asyncio.get_running_loop()
    # Sentinel for exhaustion: asyncio futures refuse StopIteration as a result.
    exhausted = object()
    events = await loop.run_in_executor(
        None, lambda: iter(agent.stream_events(prompt, thread_id=project_id))
    )
    text_parts = []
    while True:
        event = await loop.run_in_executor(None, next, events, exhausted)
        if event is exhausted:
            break
        if event.get("type") == "tool_call":
            await _handle_tool_call(project_id, event)
        if event.get("type") == "text":
            text_parts.append(event.get("content") or "")
        await manager.send_to_project(project_id, event)
    return "".join(text_parts)


async def _handle_client_message(
    websocket: WebSocket,
    project_id: str,
    message: Dict[str, Any],
    project_dir: Path
):
    """
    Dispatch one decoded client message for *project_id*.

    Handled types: ``ping``, ``focus_episode``, ``update_episode_title``,
    ``inbox_action``, ``chat_message`` and ``get_status`` (currently a no-op).
    Replies meant for the sender go to *websocket*. Broadcasts and agent output
    go to every connection on the project. Unknown types get an ``error`` reply.
    """
    message_type = message.get("type")
    logger.info(f"收到消息: type={message_type}, full_message={message}")

    if message_type == "ping":
        # Heartbeat
        await websocket.send_json({
            "type": "pong",
            "data": {
                "timestamp": datetime.now().isoformat()
            }
        })

    elif message_type == "focus_episode":
        # Client switched the episode it is looking at
        episode_number = message.get("episodeNumber")
        episode_title = message.get("episodeTitle")
        agent = manager.get_agent(project_id, project_dir)
        await _ensure_agent_context(agent, project_id, episode_number, episode_title)
        await websocket.send_json({
            "type": "focus_confirmed",
            "data": {
                "episodeNumber": episode_number,
                "episodeTitle": episode_title
            }
        })

    elif message_type == "update_episode_title":
        # User edited the canvas title by hand
        episode_number = message.get("episodeNumber")
        new_title = message.get("title")
        if episode_number and new_title:
            try:
                from app.db.repositories import episode_repo
                episodes = await episode_repo.list_by_project(project_id)
                episode = next((ep for ep in episodes if ep.number == episode_number), None)
                if episode:
                    episode.title = new_title
                    await episode_repo.update(episode)
                    logger.info(f"Manual title update for EP{episode_number}: {new_title}")
                    # Tell every connection about the new title
                    await manager.send_to_project(project_id, {
                        "type": "episode_updated",
                        "data": {
                            "number": episode_number,
                            "title": new_title,
                            "status": episode.status
                        }
                    })
            except Exception as e:
                logger.error(f"Failed to update episode title: {e}")

    elif message_type == "inbox_action":
        # User approved/rejected an inbox item: feed the decision back to the agent
        feedback = _inbox_feedback(message.get("action"), message.get("itemId"))
        agent = manager.get_agent(project_id, project_dir)
        await _ensure_agent_context(agent, project_id)
        try:
            await _stream_agent_to_project(agent, project_id, feedback)
        except Exception as e:
            await manager.send_to_project(project_id, {
                "type": "error",
                "data": {"message": str(e)}
            })

    elif message_type == "chat_message":
        # Chat message from the user -> run the agent
        content = message.get("content", "")
        episode_number = message.get("episodeNumber")
        episode_title = message.get("episodeTitle")
        if not content:
            return

        await message_repo.add_message(project_id, "user", content)

        agent = manager.get_agent(project_id, project_dir)
        # Load the context, including the episode the user is currently viewing
        await _ensure_agent_context(agent, project_id, episode_number, episode_title)

        try:
            full_response = await _stream_agent_to_project(agent, project_id, content)
            if full_response:
                await message_repo.add_message(project_id, "agent", full_response)
        except Exception as e:
            await manager.send_to_project(project_id, {
                "type": "error",
                "data": {"message": str(e)}
            })

    elif message_type == "get_status":
        # Status requests are not implemented yet
        pass

    else:
        await websocket.send_json({
            "type": "error",
            "data": {
                "message": f"未知消息类型: {message_type}"
            }
        })
async def _sync_context_states(
    project_id: str,
    episode_number: int,
    memory: Any
):
    """
    Push a ``context_update`` snapshot derived from *memory* to the project.

    The snapshot lists, in order: a time marker for *episode_number*; the latest
    state of each character in ``memory.characterStates`` (entries may be dicts
    or objects exposing ``state``); and a pending-thread counter when
    ``memory.pendingThreads`` is non-empty. Failures are logged, never raised.
    """
    try:
        snapshot = [{"type": "time", "value": f"EP{episode_number} 完成后"}]

        per_character = getattr(memory, 'characterStates', {})
        if isinstance(per_character, dict):
            for char_name, history in per_character.items():
                # Only a non-empty list has a usable "latest" entry.
                if not history or not isinstance(history, list):
                    continue
                newest = history[-1]
                default_state = f"{char_name}状态"
                if isinstance(newest, dict):
                    state_value = newest.get('state', default_state)
                else:
                    state_value = getattr(newest, 'state', default_state)
                snapshot.append({
                    "type": "character",
                    "value": f"{char_name}: {state_value}",
                    "character": char_name,
                    "state": state_value
                })

        open_threads = getattr(memory, 'pendingThreads', [])
        if open_threads:
            snapshot.append({
                "type": "pending_threads",
                "value": f"待收线: {len(open_threads)}"
            })

        await manager.send_to_project(project_id, {
            "type": "context_update",
            "states": snapshot,
            "episode_number": episode_number
        })
        logger.info(f"已同步上下文状态到项目 {project_id}, {len(snapshot)} 个状态")
    except Exception as e:
        logger.error(f"同步上下文状态失败: {str(e)}")
def _coerce_tool_mapping(value: Any) -> Dict[str, Any]:
    """Return *value* as a dict, decoding a JSON-object string when necessary.

    Tool payloads may arrive JSON-encoded (``update_context`` already copes with
    that for its ``data``). Anything that is not a mapping (``None``, a list,
    malformed JSON) becomes ``{}``, so callers can safely use ``.get``.
    """
    if isinstance(value, dict):
        return value
    if isinstance(value, str):
        try:
            decoded = json.loads(value)
        except ValueError:
            return {}
        return decoded if isinstance(decoded, dict) else {}
    return {}


async def _handle_tool_call(project_id: str, event: Dict[str, Any]):
    """
    Translate a director tool call into WebSocket messages for the frontend.

    Called for each ``tool_call`` event produced while streaming
    ``agent.stream_events()``. ``event["name"]`` selects the tool and
    ``event["args"]`` holds its arguments. Episode tools also persist to the
    repositories and refresh the memory store. ``create_episode`` schedules a
    detached background task. An exception raised here propagates into the
    caller's streaming loop, which is why payloads are normalised defensively.
    """
    name = event.get("name")
    args = _coerce_tool_mapping(event.get("args", {}))

    if name == "update_plan":
        await manager.send_to_project(project_id, {
            "type": "plan_update",
            "plan": args.get("steps", []),
            "status": args.get("status", "planning"),
            "current_step_index": args.get("current_step_index", 0)
        })

    elif name == "add_inbox_task":
        await manager.send_to_project(project_id, {
            "type": "review_request",
            "id": f"task_{args.get('title', 'unknown')}_{int(datetime.now().timestamp())}",
            "title": args.get("title"),
            "description": args.get("description"),
            "options": args.get("options", ["Approve", "Reject"]),
            "timestamp": int(datetime.now().timestamp() * 1000)
        })

    elif name == "add_annotation":
        await manager.send_to_project(project_id, {
            "type": "annotation_add",
            "annotation": {
                "content": args.get("content"),
                "type": args.get("annotation_type", "review"),
                "suggestion": args.get("suggestion", ""),
                "timestamp": int(datetime.now().timestamp() * 1000)
            }
        })

    elif name == "update_context":
        # data may be a JSON string or an already-decoded value
        data = args.get("data")
        context_type = args.get("context_type", "state")
        try:
            if isinstance(data, str):
                data = json.loads(data)
            # Convert to the activeStates shape the frontend expects: [{type, value}, ...]
            if isinstance(data, dict):
                states = [{"type": k, "value": v} for k, v in data.items()]
            elif isinstance(data, list):
                states = data
            else:
                states = [{"type": context_type, "value": str(data)}]
            await manager.send_to_project(project_id, {
                "type": "context_update",
                "states": states
            })
        except Exception as e:
            logger.warning(f"Failed to process context update: {e}")

    elif name == "write_to_canvas":
        content = args.get("content", "")
        if content:
            await manager.send_to_project(project_id, {
                "type": "canvas_update",
                "content": content
            })

    elif name == "write_file":
        # Simplification: any written content is mirrored to the canvas.
        content = args.get("content")
        if content:
            await manager.send_to_project(project_id, {
                "type": "canvas_update",
                "content": content
            })

    elif name == "update_memory":
        memory_type = args.get("memory_type", "timeline")
        # Decode JSON-string payloads instead of failing on str.get, which used to
        # abort the whole agent stream.
        data = _coerce_tool_mapping(args.get("data", {}))
        memory_data = {
            "type": memory_type,
            "memory_type": memory_type,
            "timestamp": int(datetime.now().timestamp() * 1000)
        }
        if memory_type == "timeline":
            memory_data.update({
                "title": data.get("event", "时间线事件"),
                "description": data.get("description", "")
            })
        elif memory_type == "character_state":
            memory_data.update({
                "title": f"{data.get('character', '角色')}状态变化",
                "description": data.get("state", ""),
                "character": data.get("character", "")
            })
        elif memory_type == "pending_thread":
            memory_data.update({
                "title": "待收线问题",
                "description": data.get("description", "")
            })
        elif memory_type == "foreshadowing":
            memory_data.update({
                "title": "伏笔",
                "description": data.get("description", "")
            })
        await manager.send_to_project(project_id, {
            "type": "memory_update",
            "data": memory_data
        })

    elif name == "save_episode":
        # Persist the episode as completed; only truthy fields overwrite stored values.
        episode_number = args.get("episode_number")
        title = args.get("title")
        content = args.get("content")
        outline = args.get("outline")
        if episode_number:
            try:
                from app.db.repositories import episode_repo, project_repo
                from app.core.memory.memory_manager import get_memory_manager
                episodes = await episode_repo.list_by_project(project_id)
                episode = next((ep for ep in episodes if ep.number == episode_number), None)
                if episode:
                    if title:
                        episode.title = title
                    if content:
                        episode.content = content
                    if outline:
                        episode.outline = outline
                    episode.status = "completed"
                    await episode_repo.update(episode)
                    logger.info(f"Persisted episode {episode_number} via save_episode")
                    # Refresh the memory store from the saved content (best effort)
                    try:
                        project = await project_repo.get(project_id)
                        if project and episode.content:
                            memory_manager = get_memory_manager()
                            await memory_manager.update_memory_from_episode(project, episode)
                            logger.info(f"Updated memory after saving episode {episode_number}")
                            await _sync_context_states(project_id, episode_number, project.memory)
                    except Exception as memory_error:
                        logger.warning(f"Failed to update memory for episode {episode_number}: {memory_error}")
                    await manager.send_to_project(project_id, {
                        "type": "episode_updated",
                        "data": {
                            "number": episode_number,
                            "title": episode.title,
                            "status": episode.status
                        }
                    })
            except Exception as e:
                logger.error(f"Failed to persist episode via save_episode: {e}")
        await manager.send_to_project(project_id, {
            "type": "episode_saved",
            "episode_number": episode_number,
            "title": title
        })

    elif name == "update_episode":
        # Partial update; any field explicitly provided (even empty) overwrites.
        episode_number = args.get("episode_number")
        title = args.get("title")
        content = args.get("content")
        status = args.get("status")
        outline = args.get("outline")
        if episode_number:
            try:
                from app.db.repositories import episode_repo, project_repo
                from app.core.memory.memory_manager import get_memory_manager
                episodes = await episode_repo.list_by_project(project_id)
                episode = next((ep for ep in episodes if ep.number == episode_number), None)
                if episode:
                    if title is not None:
                        episode.title = title
                    if content is not None:
                        episode.content = content
                    if status is not None:
                        episode.status = status
                    if outline is not None:
                        episode.outline = outline
                    await episode_repo.update(episode)
                    logger.info(f"Updated episode {episode_number} via update_episode")
                    # Only non-blank new content triggers a memory refresh
                    if content and content.strip():
                        try:
                            project = await project_repo.get(project_id)
                            if project:
                                memory_manager = get_memory_manager()
                                await memory_manager.update_memory_from_episode(project, episode)
                                logger.info(f"Updated memory after updating episode {episode_number}")
                                await _sync_context_states(project_id, episode_number, project.memory)
                        except Exception as memory_error:
                            logger.warning(f"Failed to update memory for episode {episode_number}: {memory_error}")
                    await manager.send_to_project(project_id, {
                        "type": "episode_updated",
                        "data": {
                            "number": episode_number,
                            "title": episode.title,
                            "status": episode.status
                        }
                    })
            except Exception as e:
                logger.error(f"Failed to update episode: {e}")

    elif name == "focus_episode":
        await manager.send_to_project(project_id, {
            "type": "focus_update",
            "episodeNumber": args.get("episode_number"),
            "episodeTitle": args.get("title")
        })

    elif name == "create_episode":
        episode_number = args.get("episode_number")
        analyze_previous_memory = args.get("analyze_previous_memory", True)
        if episode_number:
            # Same key format _execute_episode_creation registers itself under.
            task_key = f"{project_id}_{episode_number}"
            running = _background_tasks.get(task_key)
            if running is not None and not running.done():
                # Two concurrent creations of one episode would race on the DB.
                logger.warning(f"EP{episode_number} 的后台创作任务仍在运行, 忽略重复请求: {task_key}")
                return
            # Detached so it keeps running if the WebSocket goes away; the registry
            # holds a strong reference so the task cannot be garbage-collected.
            task = asyncio.ensure_future(_execute_episode_creation(
                project_id, episode_number, analyze_previous_memory
            ))
            _background_tasks[task_key] = task
            logger.info(f"已启动后台创作任务: EP{episode_number}")
# ============================================
# Background episode creation
# ============================================
# Running background episode-creation tasks, keyed by f"{project_id}_{episode_number}".
# _execute_episode_creation registers its own task here and removes the entry when it finishes.
_background_tasks: Dict[str, asyncio.Task] = {}
async def _execute_episode_creation(
    project_id: str,
    episode_number: int,
    analyze_previous_memory: bool
):
    """
    Create one episode in the background and persist it.

    Scheduled as a detached task by the ``create_episode`` tool handler, so it
    keeps running and still saves to the database after every WebSocket has
    disconnected. Steps:
      1. Optionally fold the previous episode into project memory.
      2. Generate the episode with the series creation agent.
      3. Create or update the episode record.
      4. Push the content to the canvas and persist the project memory.

    Progress is broadcast best-effort; a failed send never aborts the work.
    Exceptions are logged and reported to connected clients, not re-raised.
    """
    task_key = f"{project_id}_{episode_number}"
    current_task = asyncio.current_task()

    async def safe_send(message_type: str, data: Optional[dict] = None):
        """Broadcast best-effort, ignoring send failures (clients may be gone).

        The payload is sent both nested under ``data`` and flattened at the top
        level. Agent events and ``_handle_tool_call`` use top-level fields for
        ``text``/``plan_update``/``canvas_update``, so either reader works.
        """
        payload = {} if data is None else data
        try:
            await manager.send_to_project(project_id, {
                **payload,
                "type": message_type,
                "data": payload
            })
        except Exception as e:
            logger.debug(f"WebSocket 发送失败(可能已断开): {e}")

    def plan(first_step: str) -> list:
        """Return the five-step plan shown in the UI, with *first_step* as the memory step."""
        return [
            first_step,
            f"生成 EP{episode_number} 大纲",
            f"创作 EP{episode_number} 对话内容",
            "执行质量审核",
            "更新记忆系统"
        ]

    try:
        from app.db.repositories import project_repo, episode_repo
        from app.core.agents.series_creation_agent import get_series_agent
        from app.core.memory.memory_manager import get_memory_manager

        if current_task:
            _background_tasks[task_key] = current_task
        logger.info(f"开始后台创作任务: {task_key}")

        project = await project_repo.get(project_id)
        if not project:
            await safe_send("error", {"message": f"项目不存在: {project_id}"})
            return

        analyzes_previous = analyze_previous_memory and episode_number > 1
        memory_step = f"分析 EP{episode_number - 1} 的记忆系统" if analyzes_previous else "跳过记忆分析"

        await safe_send("plan_update", {
            "plan": plan(memory_step if analyzes_previous else "跳过记忆分析(首集)"),
            "status": "planning",
            "current_step_index": 0
        })

        # Step 1: fold the previous episode into memory (if requested)
        if analyzes_previous:
            await safe_send("plan_update", {
                "plan": plan(memory_step),
                "status": "planning",
                "current_step_index": 0
            })
            prev_episodes = await episode_repo.list_by_project(project_id)
            prev_episode = next((ep for ep in prev_episodes if ep.number == episode_number - 1), None)
            if prev_episode and prev_episode.content:
                await safe_send("text", {"content": f"\n\n--- 正在分析 EP{episode_number - 1} 的记忆系统 ---\n"})
                try:
                    memory_manager = get_memory_manager()
                    await memory_manager.update_memory_from_episode(project, prev_episode)
                    logger.info(f"EP{episode_number - 1} 记忆已分析并注入到 EP{episode_number}")
                except Exception as e:
                    logger.warning(f"分析 EP{episode_number - 1} 记忆失败: {e}")

        # Steps 2-5: generate the episode
        agent = get_series_agent()
        await safe_send("plan_update", {
            "plan": plan(memory_step),
            "status": "writing",
            "current_step_index": 1
        })
        await safe_send("text", {"content": f"\n\n--- 开始创作 EP{episode_number} ---\n"})

        episode = await agent.execute_episode(
            project=project,
            episode_number=episode_number,
            title=f"{episode_number}"
        )

        # A needs-review episode without content means generation failed
        if episode.status == "needs-review" and not episode.content:
            await safe_send("error", {
                "message": f"EP{episode_number} 创作失败",
                "episode_number": episode_number
            })
            await safe_send("text", {"content": f"\n\n❌ EP{episode_number} 创作失败。请检查错误日志并重试。\n"})
            logger.error(f"EP{episode_number} 创作失败,无内容生成")
            return

        # Persist regardless of whether any WebSocket is still connected
        existing_episodes = await episode_repo.list_by_project(project_id)
        episode_record = next((ep for ep in existing_episodes if ep.number == episode_number), None)
        episode.projectId = project_id
        if episode_record:
            episode.id = episode_record.id
            await episode_repo.update(episode)
            logger.info(f"更新现有剧集记录: {episode.id}")
        else:
            await episode_repo.create(episode)
            logger.info(f"创建新剧集记录: {episode.id}")

        if episode.content:
            await safe_send("canvas_update", {"content": episode.content})

        # Persist project memory; a project without a memory object must not turn
        # an already-saved episode into a reported failure.
        memory = getattr(project, "memory", None)
        if memory is not None:
            await project_repo.update(project_id, {
                "memory": memory.dict()
            })

        await safe_send("plan_update", {
            "plan": plan(memory_step),
            "status": "idle",
            "current_step_index": 4
        })
        await safe_send("text", {"content": f"\n\n✅ EP{episode_number} 创作完成!质量分数: {episode.qualityScore or 0}\n"})
        await safe_send("episode_updated", {
            "number": episode_number,
            "title": episode.title,
            "status": episode.status
        })
        logger.info(f"EP{episode_number} 后台创作完成,已保存到数据库")

    except Exception as e:
        logger.error(f"执行剧集创作失败: {str(e)}", exc_info=True)
        await safe_send("error", {
            "message": f"EP{episode_number} 创作失败: {str(e)}",
            "episode_number": episode_number
        })
        await safe_send("text", {"content": f"\n\n❌ EP{episode_number} 创作失败: {str(e)}\n"})
    finally:
        # Only drop our own registration; a newer task for the same episode may
        # have replaced it in the meantime.
        if current_task is not None and _background_tasks.get(task_key) is current_task:
            del _background_tasks[task_key]
        logger.info(f"后台创作任务结束: {task_key}")
# ============================================
# 辅助函数 - 用于从其他模块发送消息
# ============================================
async def broadcast_stage_update(
    project_id: str,
    episode_number: int,
    stage: str,
    data: Dict[str, Any]
):
    """Broadcast a ``stage_update`` event for one episode to the project's sockets.

    Keys from *data* are merged after the fixed fields and may override them.
    """
    base = {
        "project_id": project_id,
        "episode_number": episode_number,
        "stage": stage,
    }
    envelope = {
        "type": "stage_update",
        "data": {**base, **data},
        "timestamp": datetime.now().isoformat(),
    }
    await manager.send_to_project(project_id, envelope)
async def broadcast_episode_complete(
    project_id: str,
    episode_number: int,
    success: bool,
    quality_score: float,
    data: Dict[str, Any]
):
    """Broadcast an ``episode_complete`` event (outcome and quality score) to the project.

    Keys from *data* are merged after the fixed fields and may override them.
    """
    outcome = {
        "project_id": project_id,
        "episode_number": episode_number,
        "success": success,
        "quality_score": quality_score,
    }
    envelope = {
        "type": "episode_complete",
        "data": {**outcome, **data},
        "timestamp": datetime.now().isoformat(),
    }
    await manager.send_to_project(project_id, envelope)
async def broadcast_batch_progress(
    batch_id: str,
    current_episode: int,
    total_episodes: int,
    completed: int,
    failed: int,
    data: Dict[str, Any]
):
    """Broadcast a ``batch_progress`` event to the batch's sockets.

    ``progress_percentage`` is completed / total * 100, or 0 when the batch has
    no episodes. Keys from *data* are merged last and may override fields.
    """
    if total_episodes > 0:
        percentage = completed / total_episodes * 100
    else:
        percentage = 0
    progress = {
        "batch_id": batch_id,
        "current_episode": current_episode,
        "total_episodes": total_episodes,
        "completed_episodes": completed,
        "failed_episodes": failed,
        "progress_percentage": percentage,
    }
    envelope = {
        "type": "batch_progress",
        "data": {**progress, **data},
        "timestamp": datetime.now().isoformat(),
    }
    await manager.send_to_batch(batch_id, envelope)
async def broadcast_error(
    project_id: str,
    episode_number: Optional[int],
    error: str,
    error_type: str = "execution_error"
):
    """Broadcast an ``error`` event, optionally tied to *episode_number*, to the project."""
    details = dict(
        project_id=project_id,
        episode_number=episode_number,
        error=error,
        error_type=error_type,
    )
    envelope = {
        "type": "error",
        "data": details,
        "timestamp": datetime.now().isoformat(),
    }
    await manager.send_to_project(project_id, envelope)
async def broadcast_batch_complete(
    batch_id: str,
    summary: Dict[str, Any]
):
    """Broadcast a ``batch_complete`` event carrying *summary* to the batch's sockets."""
    envelope = {
        "type": "batch_complete",
        "data": {"batch_id": batch_id, **summary},
        "timestamp": datetime.now().isoformat(),
    }
    await manager.send_to_batch(batch_id, envelope)
async def broadcast_to_project(
    project_id: str,
    message_type: str,
    data: Dict[str, Any]
):
    """Broadcast an arbitrary ``{type, data, timestamp}`` event to all of the project's sockets."""
    stamped = datetime.now().isoformat()
    await manager.send_to_project(
        project_id,
        {"type": message_type, "data": data, "timestamp": stamped},
    )
# Names exported by `from <this module> import *`: the shared connection manager
# and the broadcast helpers that push events to connected clients.
__all__ = [
    "manager",
    "broadcast_stage_update",
    "broadcast_episode_complete",
    "broadcast_batch_progress",
    "broadcast_error",
    "broadcast_batch_complete",
    "broadcast_to_project"
]