1168 lines
43 KiB
Python
1168 lines
43 KiB
Python
"""
|
|
WebSocket Streaming API
|
|
|
|
提供实时执行进度更新的 WebSocket 端点
|
|
"""
|
|
from fastapi import APIRouter, WebSocket, WebSocketDisconnect, Depends, Query
|
|
from typing import Dict, Set, Optional, Any
|
|
import json
|
|
import asyncio
|
|
from datetime import datetime
|
|
from pathlib import Path
|
|
|
|
from app.config import settings
|
|
from app.utils.logger import get_logger
|
|
from app.core.agent_runtime.director_agent import DirectorAgent
|
|
from app.db.repositories import message_repo
|
|
|
|
# Module-level logger shared by every endpoint and helper in this file.
logger = get_logger(__name__)

# Router holding the WebSocket endpoints declared below (presumably included into
# the FastAPI app elsewhere — confirm at the app factory).
router = APIRouter()
|
|
|
|
|
|
# ============================================
|
|
# WebSocket 连接管理
|
|
# ============================================
|
|
|
|
class ConnectionManager:
    """Registry of live WebSocket connections and per-project DirectorAgent instances.

    Connections are grouped into sets keyed by project ID or batch ID so that an
    event can be broadcast to every client watching the same execution stream.
    When a project loses its last connection, its cached agent is discarded.
    """

    def __init__(self):
        # project ID -> WebSocket connections watching that project
        self.project_connections: Dict[str, Set[WebSocket]] = {}
        # batch ID -> WebSocket connections watching that batch
        self.batch_connections: Dict[str, Set[WebSocket]] = {}
        # project ID -> cached DirectorAgent (created lazily by get_agent)
        self.project_agents: Dict[str, DirectorAgent] = {}

    async def connect_to_project(self, websocket: WebSocket, project_id: str):
        """Accept the handshake and subscribe ``websocket`` to ``project_id``'s stream."""
        await websocket.accept()
        self.project_connections.setdefault(project_id, set()).add(websocket)
        logger.info(f"WebSocket 已连接到项目: {project_id}")

    async def connect_to_batch(self, websocket: WebSocket, batch_id: str):
        """Accept the handshake and subscribe ``websocket`` to ``batch_id``'s stream."""
        await websocket.accept()
        self.batch_connections.setdefault(batch_id, set()).add(websocket)
        logger.info(f"WebSocket 已连接到批次: {batch_id}")

    def disconnect(self, websocket: WebSocket):
        """Unsubscribe ``websocket`` from every project and batch it belongs to.

        Groups left empty are removed; an emptied project group also releases
        the project's cached agent. Calling this for an unknown socket (or twice
        for the same socket) is a no-op.
        """
        # Iterate over snapshots (``list(...)``): entries are deleted from these
        # dicts inside the loops, and deleting from a dict while iterating it
        # raises "RuntimeError: dictionary changed size during iteration".
        for project_id, connections in list(self.project_connections.items()):
            if websocket not in connections:
                continue
            connections.discard(websocket)
            logger.info(f"WebSocket 已从项目断开: {project_id}")
            if not connections:
                del self.project_connections[project_id]
                # Nobody is watching this project any more: drop its agent.
                self.project_agents.pop(project_id, None)

        for batch_id, connections in list(self.batch_connections.items()):
            if websocket not in connections:
                continue
            connections.discard(websocket)
            logger.info(f"WebSocket 已从批次断开: {batch_id}")
            if not connections:
                del self.batch_connections[batch_id]

    async def _broadcast(
        self,
        connections: Set[WebSocket],
        message: Dict[str, Any],
        exclude: Optional[WebSocket]
    ):
        """Send ``message`` to each socket in ``connections`` except ``exclude``.

        A socket whose send raises is logged and disconnected afterwards; the
        remaining sockets still receive the message.
        """
        failed = []
        # Snapshot the set first: every ``await`` yields to the event loop, where
        # other coroutines may connect/disconnect and resize the live set, which
        # would raise "RuntimeError: Set changed size during iteration".
        for connection in list(connections):
            if connection is exclude:
                continue
            try:
                await connection.send_json(message)
            except Exception as e:
                logger.error(f"发送消息失败: {str(e)}")
                failed.append(connection)

        # Clean up dead sockets once iteration is finished
        for connection in failed:
            self.disconnect(connection)

    async def send_to_project(
        self,
        project_id: str,
        message: Dict[str, Any],
        exclude: Optional[WebSocket] = None
    ):
        """Broadcast ``message`` to all connections of ``project_id`` except ``exclude``."""
        connections = self.project_connections.get(project_id)
        if connections:
            await self._broadcast(connections, message, exclude)

    async def send_to_batch(
        self,
        batch_id: str,
        message: Dict[str, Any],
        exclude: Optional[WebSocket] = None
    ):
        """Broadcast ``message`` to all connections of ``batch_id`` except ``exclude``."""
        connections = self.batch_connections.get(batch_id)
        if connections:
            await self._broadcast(connections, message, exclude)

    def get_agent(self, project_id: str, working_dir: Path) -> DirectorAgent:
        """Return the cached DirectorAgent for ``project_id``, creating it on first use.

        ``working_dir`` is created (with parents) before the agent is built. The
        project context is deliberately left empty here: this method is
        synchronous while ``project_repo.get`` is async, so the message handlers
        inject the context later.
        """
        agent = self.project_agents.get(project_id)
        if agent is None:
            working_dir.mkdir(parents=True, exist_ok=True)

            model_name = settings.zai_model
            # GLM models do not support thinking mode, so disable it for them.
            enable_thinking = "glm" not in model_name.lower()

            agent = DirectorAgent(
                working_directory=working_dir,
                enable_thinking=enable_thinking,
                model=model_name,
                project_context=None  # filled in later during message handling
            )
            self.project_agents[project_id] = agent
        return agent

    def get_project_connections_count(self, project_id: str) -> int:
        """Return how many sockets are subscribed to ``project_id``."""
        return len(self.project_connections.get(project_id, ()))

    def get_batch_connections_count(self, batch_id: str) -> int:
        """Return how many sockets are subscribed to ``batch_id``."""
        return len(self.batch_connections.get(batch_id, ()))
|
|
|
|
|
|
# Global connection manager shared by all endpoints and broadcast helpers in this module.
# NOTE(review): the state lives in this process's memory only; if the app runs with
# several worker processes, each has its own manager (deployment assumption — confirm).
manager = ConnectionManager()
|
|
|
|
|
|
# ============================================
|
|
# WebSocket 端点
|
|
# ============================================
|
|
|
|
@router.websocket("/ws/projects/{project_id}/execute")
async def websocket_project_execution(
    websocket: WebSocket,
    project_id: str
):
    """Project execution stream.

    Lifecycle: register the socket with ``manager``, greet the client with a
    ``connected`` event, replay the stored chat history (if any), then feed each
    client JSON message to ``_handle_client_message`` until the client
    disconnects. Per-message failures are reported back to the sender as
    ``error`` events; the socket is always unregistered on exit.
    """
    await manager.connect_to_project(websocket, project_id)

    # NOTE(review): hard-coded, machine-specific workspace root (projects/{id});
    # it must be adapted to the real deployment configuration.
    workspace_dir = Path(f"d:/platform/creative_studio/workspace/projects/{project_id}")

    async def reply_error(text: str):
        # Error events go only to this client, not to the whole project.
        await websocket.send_json({"type": "error", "data": {"message": text}})

    try:
        greeting = {
            "project_id": project_id,
            "message": "已连接到项目执行流",
            "timestamp": datetime.now().isoformat()
        }
        await websocket.send_json({"type": "connected", "data": greeting})

        past_messages = await message_repo.get_history(project_id)
        if past_messages:
            await websocket.send_json({"type": "history", "messages": past_messages})

        while True:
            try:
                raw = await websocket.receive_text()
                payload = json.loads(raw)
                await _handle_client_message(websocket, project_id, payload, workspace_dir)
            except WebSocketDisconnect:
                logger.info(f"WebSocket 客户端主动断开: {project_id}")
                return
            except json.JSONDecodeError:
                await reply_error("无效的 JSON 格式")
            except Exception as exc:
                logger.error(f"处理 WebSocket 消息错误: {str(exc)}")
                await reply_error(f"处理消息错误: {str(exc)}")

    except Exception as exc:
        logger.error(f"WebSocket 连接错误: {str(exc)}")

    finally:
        manager.disconnect(websocket)
|
|
|
|
|
|
@router.websocket("/ws/batches/{batch_id}/execute")
async def websocket_batch_execution(
    websocket: WebSocket,
    batch_id: str
):
    """Batch execution stream (server-to-client broadcast only).

    After a ``connected`` greeting, the only client message acted upon is
    ``{"type": "ping"}``, answered with ``pong`` to keep the connection alive.
    Other messages are ignored, and processing errors (e.g. malformed JSON) are
    only logged. The socket is unregistered when the client leaves.
    """
    await manager.connect_to_batch(websocket, batch_id)

    try:
        await websocket.send_json({
            "type": "connected",
            "data": {
                "batch_id": batch_id,
                "message": "已连接到批量执行流",
                "timestamp": datetime.now().isoformat()
            }
        })

        keep_listening = True
        while keep_listening:
            try:
                incoming = json.loads(await websocket.receive_text())
                # Control messages are not supported here; only answer heartbeats.
                if incoming.get("type") != "ping":
                    continue
                await websocket.send_json({
                    "type": "pong",
                    "data": {"timestamp": datetime.now().isoformat()}
                })
            except WebSocketDisconnect:
                logger.info(f"WebSocket 客户端主动断开: {batch_id}")
                keep_listening = False
            except Exception as exc:
                logger.error(f"处理 WebSocket 消息错误: {str(exc)}")

    finally:
        manager.disconnect(websocket)
|
|
|
|
|
|
# ============================================
|
|
# 消息处理
|
|
# ============================================
|
|
|
|
def _read_field(obj: Any, name: str, default: Any = None) -> Any:
    """Read ``name`` from an attribute-style object or from a plain dict.

    Project sub-objects (task configs, skill configs, character profiles) may
    arrive either as model instances or as raw dicts, so every lookup below
    goes through this helper instead of a bare ``getattr``.
    """
    if isinstance(obj, dict):
        return obj.get(name, default)
    return getattr(obj, name, default)


def _collect_user_skills(project: Any) -> list:
    """Resolve the skills referenced by ``project.defaultTaskSkills`` into prompt entries.

    Returns a list of ``{'id', 'name', 'behavior'}`` dicts, where ``behavior``
    falls back to the skill description and then to ``''``. Unknown skill IDs
    are skipped; a skill whose lookup raises is logged and skipped, so one bad
    entry never blocks the others.
    """
    from app.core.skills.skill_manager import skill_manager

    user_skills = []
    task_configs = _read_field(project, 'defaultTaskSkills', [])
    if not isinstance(task_configs, list):
        return user_skills

    for task_config in task_configs:
        task_skills = _read_field(task_config, 'skills', [])
        if not isinstance(task_skills, list):
            continue
        for skill_config in task_skills:
            try:
                skill_id = _read_field(skill_config, 'skill_id')
                if not skill_id:
                    continue
                skill = skill_manager.get_skill_by_id(skill_id)
                if skill:
                    user_skills.append({
                        'id': skill.id,
                        'name': skill.name,
                        'behavior': skill.behavior_guide or skill.description or ''
                    })
            except Exception as e:
                logger.warning(f"Failed to load skill: {e}")
    return user_skills


def _build_characters_text(global_context: Any) -> str:
    """Merge the style guide and character profiles into one prompt text block.

    Dict-shaped ``characterProfiles`` are rendered as a markdown list (name,
    description, optional personality line); any other non-empty value is
    appended verbatim via ``str()``. Returns ``""`` when there is nothing to add.
    """
    if not global_context:
        return ""

    text = ""
    if global_context.styleGuide:
        text += global_context.styleGuide + "\n\n"

    profiles = global_context.characterProfiles
    if profiles:
        if isinstance(profiles, dict):
            text += "### Detailed Character Profiles:\n"
            for name, profile in profiles.items():
                description = _read_field(profile, 'description', '') or ''
                text += f"- **{name}**: {description}\n"
                personality = _read_field(profile, 'personality')
                if personality:
                    text += f" Personality: {personality}\n"
        else:
            # Free-form profile text (or another type): include it as-is
            text += str(profiles) + "\n\n"
    return text


async def _ensure_agent_context(
    agent: DirectorAgent,
    project_id: str,
    active_episode_number: Optional[int] = None,
    active_episode_title: Optional[str] = None
):
    """Load ``project_id``'s data into ``agent``'s context unless it is already loaded.

    The reload is skipped when the agent's current context already belongs to
    ``project_id`` and focuses on the same ``active_episode_number``. Otherwise
    a fresh ``SkillAgentContext`` is built from the project record (world
    setting, style guide + character profiles, outline, source material, user
    skills, episode list), installed on the agent, and the agent's system
    prompt is rebuilt.

    Failures (missing project, repository or import errors) are logged and
    swallowed; the agent then keeps its previous context.

    Args:
        agent: Agent whose ``context`` and ``system_prompt`` are replaced.
        project_id: Project to load.
        active_episode_number: Episode the user is focused on, if any.
        active_episode_title: Title of that episode, if known.
    """
    current = agent.context
    if (current
            and current.project_id == project_id
            and current.active_episode_number == active_episode_number):
        return

    try:
        from app.db.repositories import project_repo, episode_repo
        from app.core.agent_runtime.context import SkillAgentContext

        project = await project_repo.get(project_id)
        if not project:
            logger.warning(f"Project {project_id} not found in repository")
            return

        # 1. User-configured skills
        user_skills = _collect_user_skills(project)

        # 2. Style guide + character profiles
        global_context = project.globalContext
        characters_text = _build_characters_text(global_context)

        # 3. Episode overview
        episode_records = await episode_repo.list_by_project(project_id)
        episodes = [
            {"number": ep.number, "title": ep.title, "status": ep.status}
            for ep in episode_records
        ]

        # An uploaded script takes precedence over free-form inspiration text
        if global_context and global_context.uploadedScript:
            creation_mode = 'script'
            source_content = global_context.uploadedScript
        else:
            creation_mode = 'inspiration'
            source_content = global_context.inspiration if global_context else None

        # 4. Build the new context, reusing the agent's skill loader and working dir
        project_context = SkillAgentContext(
            skill_loader=agent.context.skill_loader,
            working_directory=agent.context.working_directory,
            project_id=project.id,
            project_name=project.name,
            project_genre=getattr(project, 'genre', '古风'),
            total_episodes=project.totalEpisodes,
            world_setting=global_context.worldSetting if global_context else None,
            characters=characters_text.strip() or None,
            overall_outline=global_context.overallOutline if global_context else None,
            creation_mode=creation_mode,
            source_content=source_content,
            user_skills=user_skills,
            active_episode_number=active_episode_number,
            active_episode_title=active_episode_title,
            episodes=episodes
        )

        # 5. Install it and rebuild the prompt
        agent.context = project_context
        agent.system_prompt = agent._build_system_prompt()
        if hasattr(agent, 'refresh_agent'):
            agent.refresh_agent()

        msg = f"Successfully injected project context for {project_id}: {project.name}"
        if active_episode_number:
            msg += f" (Focusing on Episode {active_episode_number})"
        logger.info(msg)

    except Exception as e:
        logger.error(f"Failed to load project context for {project_id}: {e}", exc_info=True)
|
|
|
|
|
|
def _inbox_feedback(action: Any, item_id: Any) -> str:
    """Describe an inbox decision as an English sentence for the agent.

    ``action`` is the verb sent by the client (e.g. ``"approve"`` / ``"reject"``).
    Its past tense is built correctly: ``"approve"`` -> ``"approved"``,
    ``"reject"`` -> ``"rejected"``; a verb already ending in ``"ed"`` is kept.
    A missing or blank action yields a neutral sentence instead of ``"Noneed"``.
    """
    verb = str(action).strip() if action is not None else ""
    if not verb:
        return f"User acted on inbox item {item_id}."
    if verb.endswith("ed"):
        past_tense = verb
    elif verb.endswith("e"):
        past_tense = verb + "d"
    else:
        past_tense = verb + "ed"
    return f"User {past_tense} inbox item {item_id}."


async def _stream_agent_reply(agent: DirectorAgent, project_id: str, prompt: str) -> str:
    """Run ``agent`` on ``prompt`` and broadcast every event to the project's clients.

    ``tool_call`` events are first translated by ``_handle_tool_call``; every
    event is then forwarded unchanged. Returns the concatenated content of the
    ``text`` events. Exceptions raised by the agent propagate to the caller.
    """
    chunks = []
    for event in agent.stream_events(prompt, thread_id=project_id):
        event_type = event.get("type")
        if event_type == "tool_call":
            await _handle_tool_call(project_id, event)
        elif event_type == "text":
            chunks.append(event.get("content", ""))
        await manager.send_to_project(project_id, event)
    return "".join(chunks)


async def _handle_client_message(
    websocket: WebSocket,
    project_id: str,
    message: Dict[str, Any],
    project_dir: Path
):
    """Dispatch one JSON message received from a project WebSocket client.

    Supported ``type`` values:
      - ``ping``: reply ``pong`` to the sender only.
      - ``focus_episode``: refocus the agent context on an episode; confirm to the sender.
      - ``update_episode_title``: persist a manual title edit and broadcast ``episode_updated``.
      - ``inbox_action``: feed an approve/reject decision back to the agent and stream its reply.
      - ``chat_message``: store the user message, run the agent, stream and store its reply.
      - ``get_status``: currently ignored.
    Any other type gets an ``error`` reply. Agent failures are logged and
    broadcast to all of the project's clients as ``error`` events.
    """
    message_type = message.get("type")
    logger.info(f"收到消息: type={message_type}, full_message={message}")

    async def broadcast_agent_error(exc: Exception):
        # Must be called from inside an ``except`` block so exc_info is populated.
        logger.error(f"Agent run failed for project {project_id}: {exc}", exc_info=True)
        await manager.send_to_project(project_id, {
            "type": "error",
            "data": {"message": str(exc)}
        })

    if message_type == "ping":
        # Heartbeat
        await websocket.send_json({
            "type": "pong",
            "data": {
                "timestamp": datetime.now().isoformat()
            }
        })

    elif message_type == "focus_episode":
        # The client switched the episode it is working on
        episode_number = message.get("episodeNumber")
        episode_title = message.get("episodeTitle")

        agent = manager.get_agent(project_id, project_dir)
        await _ensure_agent_context(agent, project_id, episode_number, episode_title)

        await websocket.send_json({
            "type": "focus_confirmed",
            "data": {
                "episodeNumber": episode_number,
                "episodeTitle": episode_title
            }
        })

    elif message_type == "update_episode_title":
        # The user edited the canvas title by hand
        episode_number = message.get("episodeNumber")
        new_title = message.get("title")

        if episode_number and new_title:
            try:
                from app.db.repositories import episode_repo
                episodes = await episode_repo.list_by_project(project_id)
                episode = next((ep for ep in episodes if ep.number == episode_number), None)
                if episode:
                    episode.title = new_title
                    await episode_repo.update(episode)
                    logger.info(f"Manual title update for EP{episode_number}: {new_title}")

                    # Let every client of the project see the new title
                    await manager.send_to_project(project_id, {
                        "type": "episode_updated",
                        "data": {
                            "number": episode_number,
                            "title": new_title,
                            "status": episode.status
                        }
                    })
            except Exception as e:
                logger.error(f"Failed to update episode title: {e}")

    elif message_type == "inbox_action":
        # The user approved/rejected an inbox item: tell the agent in plain language
        feedback = _inbox_feedback(message.get("action"), message.get("itemId"))

        agent = manager.get_agent(project_id, project_dir)
        await _ensure_agent_context(agent, project_id)

        try:
            await _stream_agent_reply(agent, project_id, feedback)
        except Exception as e:
            await broadcast_agent_error(e)

    elif message_type == "chat_message":
        # A user chat message triggers an agent run
        content = message.get("content", "")
        episode_number = message.get("episodeNumber")
        episode_title = message.get("episodeTitle")

        if not content:
            return

        await message_repo.add_message(project_id, "user", content)

        agent = manager.get_agent(project_id, project_dir)
        # Make sure the context (including the focused episode) is loaded
        await _ensure_agent_context(agent, project_id, episode_number, episode_title)

        try:
            full_response = await _stream_agent_reply(agent, project_id, content)
            # Persist the agent's reply only when the run completed
            if full_response:
                await message_repo.add_message(project_id, "agent", full_response)
        except Exception as e:
            await broadcast_agent_error(e)

    elif message_type == "get_status":
        # Status requests are accepted but not answered yet
        pass

    else:
        await websocket.send_json({
            "type": "error",
            "data": {
                "message": f"未知消息类型: {message_type}"
            }
        })
|
|
|
|
|
|
async def _sync_context_states(
    project_id: str,
    episode_number: int,
    memory: Any
):
    """Broadcast a ``context_update`` summarising ``memory`` after an episode.

    The state list always starts with a time marker for ``episode_number``,
    followed by the latest state of each character in
    ``memory.characterStates`` (dict entries or objects with a ``state``
    attribute) and, when present, the number of pending plot threads.
    Errors are logged, never raised.
    """
    try:
        states = [{
            "type": "time",
            "value": f"EP{episode_number} 完成后"
        }]

        character_states = getattr(memory, 'characterStates', {})
        if isinstance(character_states, dict):
            for char_name, history in character_states.items():
                if not (history and isinstance(history, list)):
                    continue
                latest = history[-1]
                fallback = f"{char_name}状态"
                if isinstance(latest, dict):
                    state_value = latest.get('state', fallback)
                else:
                    state_value = getattr(latest, 'state', fallback)
                states.append({
                    "type": "character",
                    "value": f"{char_name}: {state_value}",
                    "character": char_name,
                    "state": state_value
                })

        pending_threads = getattr(memory, 'pendingThreads', [])
        if pending_threads:
            states.append({
                "type": "pending_threads",
                "value": f"待收线: {len(pending_threads)} 条"
            })

        update = {
            "type": "context_update",
            "states": states,
            "episode_number": episode_number
        }
        await manager.send_to_project(project_id, update)
        logger.info(f"已同步上下文状态到项目 {project_id}, {len(states)} 个状态")

    except Exception as exc:
        logger.error(f"同步上下文状态失败: {str(exc)}")
|
|
|
|
|
|
def _coerce_to_dict(value: Any) -> Dict[str, Any]:
    """Return ``value`` as a dict.

    Tool arguments produced by the model sometimes arrive as JSON-encoded
    strings (``update_context`` already had to handle this), so strings are
    decoded first. ``None``, undecodable strings and non-dict values yield ``{}``.
    """
    if isinstance(value, str):
        try:
            value = json.loads(value)
        except json.JSONDecodeError:
            return {}
    return value if isinstance(value, dict) else {}


async def _find_project_episode(project_id: str, episode_number: Any):
    """Return the stored episode of ``project_id`` numbered ``episode_number``, or ``None``."""
    from app.db.repositories import episode_repo

    episodes = await episode_repo.list_by_project(project_id)
    return next((ep for ep in episodes if ep.number == episode_number), None)


async def _refresh_memory_after_episode(
    project_id: str,
    episode_number: Any,
    episode: Any,
    action: str
):
    """Re-derive the project memory from ``episode`` and push the new states to clients.

    Best effort: any failure is logged as a warning and swallowed so that the
    episode write that preceded this call is not reported as failed.
    ``action`` ("saving" / "updating") only appears in the log message.
    """
    try:
        from app.db.repositories import project_repo
        from app.core.memory.memory_manager import get_memory_manager

        project = await project_repo.get(project_id)
        if not project:
            return
        memory_manager = get_memory_manager()
        await memory_manager.update_memory_from_episode(project, episode)
        logger.info(f"Updated memory after {action} episode {episode_number}")

        # Push the refreshed character/thread state to the frontend
        await _sync_context_states(project_id, episode_number, project.memory)
    except Exception as memory_error:
        logger.warning(f"Failed to update memory for episode {episode_number}: {memory_error}")


async def _handle_tool_call(project_id: str, event: Dict[str, Any]):
    """Translate a DirectorAgent tool call into WebSocket events (and DB writes).

    Called by the client-message handlers for every ``tool_call`` event yielded
    by ``agent.stream_events()``, before the raw event itself is forwarded. It
    emits frontend-specific messages (plan, inbox, annotation, context, canvas,
    memory, episode and focus updates), persists episode edits, and starts
    background episode creation. Unknown tool names are ignored.

    Args:
        project_id: Project whose clients receive the derived events.
        event: Tool-call event with ``name`` and ``args`` (a dict, a JSON
            string, or missing).
    """
    name = event.get("name")
    args = _coerce_to_dict(event.get("args"))

    if name == "update_plan":
        await manager.send_to_project(project_id, {
            "type": "plan_update",
            "plan": args.get("steps", []),
            "status": args.get("status", "planning"),
            "current_step_index": args.get("current_step_index", 0)
        })

    elif name == "add_inbox_task":
        now = datetime.now()
        await manager.send_to_project(project_id, {
            "type": "review_request",
            "id": f"task_{args.get('title', 'unknown')}_{int(now.timestamp())}",
            "title": args.get("title"),
            "description": args.get("description"),
            "options": args.get("options", ["Approve", "Reject"]),
            "timestamp": int(now.timestamp() * 1000)
        })

    elif name == "add_annotation":
        await manager.send_to_project(project_id, {
            "type": "annotation_add",
            "annotation": {
                "content": args.get("content"),
                "type": args.get("annotation_type", "review"),
                "suggestion": args.get("suggestion", ""),
                "timestamp": int(datetime.now().timestamp() * 1000)
            }
        })

    elif name == "update_context":
        # ``data`` may be a JSON string, a dict, a list, or a scalar
        data = args.get("data")
        context_type = args.get("context_type", "state")

        try:
            if isinstance(data, str):
                data = json.loads(data)

            # Convert to the frontend's activeStates shape: [{type, value}, ...]
            if isinstance(data, dict):
                states = [{"type": k, "value": v} for k, v in data.items()]
            elif isinstance(data, list):
                states = data
            else:
                states = [{"type": context_type, "value": str(data)}]

            await manager.send_to_project(project_id, {
                "type": "context_update",
                "states": states
            })
        except Exception as e:
            logger.warning(f"Failed to process context update: {e}")

    elif name in ("write_to_canvas", "write_file"):
        # Simplification: any non-empty written content is mirrored to the canvas
        content = args.get("content")
        if content:
            await manager.send_to_project(project_id, {
                "type": "canvas_update",
                "content": content
            })

    elif name == "update_memory":
        memory_type = args.get("memory_type", "timeline")
        # Like update_context, tolerate JSON strings; anything unusable becomes {}
        data = _coerce_to_dict(args.get("data"))

        memory_data = {
            "type": memory_type,
            "memory_type": memory_type,
            "timestamp": int(datetime.now().timestamp() * 1000)
        }

        if memory_type == "timeline":
            memory_data.update({
                "title": data.get("event", "时间线事件"),
                "description": data.get("description", "")
            })
        elif memory_type == "character_state":
            memory_data.update({
                "title": f"{data.get('character', '角色')}状态变化",
                "description": data.get("state", ""),
                "character": data.get("character", "")
            })
        elif memory_type == "pending_thread":
            memory_data.update({
                "title": "待收线问题",
                "description": data.get("description", "")
            })
        elif memory_type == "foreshadowing":
            memory_data.update({
                "title": "伏笔",
                "description": data.get("description", "")
            })

        await manager.send_to_project(project_id, {
            "type": "memory_update",
            "data": memory_data
        })

    elif name == "save_episode":
        # Persist a finished episode; empty arguments leave stored fields untouched
        episode_number = args.get("episode_number")
        title = args.get("title")
        content = args.get("content")
        outline = args.get("outline")

        if episode_number:
            try:
                from app.db.repositories import episode_repo

                episode = await _find_project_episode(project_id, episode_number)
                if episode:
                    if title:
                        episode.title = title
                    if content:
                        episode.content = content
                    if outline:
                        episode.outline = outline
                    episode.status = "completed"
                    await episode_repo.update(episode)
                    logger.info(f"Persisted episode {episode_number} via save_episode")

                    if episode.content:
                        await _refresh_memory_after_episode(project_id, episode_number, episode, "saving")

                    await manager.send_to_project(project_id, {
                        "type": "episode_updated",
                        "data": {
                            "number": episode_number,
                            "title": episode.title,
                            "status": episode.status
                        }
                    })
            except Exception as e:
                logger.error(f"Failed to persist episode via save_episode: {e}")

        await manager.send_to_project(project_id, {
            "type": "episode_saved",
            "episode_number": episode_number,
            "title": title
        })

    elif name == "update_episode":
        # Partial update: only arguments that are present (not None) are applied
        episode_number = args.get("episode_number")
        title = args.get("title")
        content = args.get("content")
        status = args.get("status")
        outline = args.get("outline")

        if episode_number:
            try:
                from app.db.repositories import episode_repo

                episode = await _find_project_episode(project_id, episode_number)
                if episode:
                    if title is not None:
                        episode.title = title
                    if content is not None:
                        episode.content = content
                    if status is not None:
                        episode.status = status
                    if outline is not None:
                        episode.outline = outline

                    await episode_repo.update(episode)
                    logger.info(f"Updated episode {episode_number} via update_episode")

                    # Only new, non-blank content is worth re-analysing
                    if content and content.strip():
                        await _refresh_memory_after_episode(project_id, episode_number, episode, "updating")

                    await manager.send_to_project(project_id, {
                        "type": "episode_updated",
                        "data": {
                            "number": episode_number,
                            "title": episode.title,
                            "status": episode.status
                        }
                    })
            except Exception as e:
                logger.error(f"Failed to update episode: {e}")

    elif name == "focus_episode":
        await manager.send_to_project(project_id, {
            "type": "focus_update",
            "episodeNumber": args.get("episode_number"),
            "episodeTitle": args.get("title")
        })

    elif name == "create_episode":
        episode_number = args.get("episode_number")
        analyze_previous_memory = args.get("analyze_previous_memory", True)

        if episode_number:
            # Fire-and-forget so creation continues even if the WebSocket drops.
            # The task registers itself in ``_background_tasks`` during its first
            # step, which keeps it strongly referenced while it runs.
            asyncio.ensure_future(_execute_episode_creation(
                project_id, episode_number, analyze_previous_memory
            ))
            logger.info(f"已启动后台创作任务: EP{episode_number}")
|
|
|
|
# ============================================
# 后台任务 - 剧集异步创作 (background episode-creation task)
# ============================================
|
|
|
|
|
|
# Global tracking of background episode-creation tasks, keyed by
# f"{project_id}_{episode_number}". Each task adds itself when it starts and
# removes its entry in its ``finally`` block; holding the Task object here also
# keeps a strong reference to it while it runs.
_background_tasks: Dict[str, asyncio.Task] = {}
|
|
|
|
|
|
def _episode_creation_plan(episode_number: int, analyze_previous_memory: bool) -> list:
    """Return the five plan steps shown to clients while EP``episode_number`` is created.

    The first step names the previous episode when its memory is analysed;
    otherwise it states that analysis is skipped, adding "(首集)" ("first
    episode") only when this really is the first episode.
    """
    if analyze_previous_memory and episode_number > 1:
        first_step = f"分析 EP{episode_number - 1} 的记忆系统"
    elif episode_number > 1:
        first_step = "跳过记忆分析"
    else:
        first_step = "跳过记忆分析(首集)"
    return [
        first_step,
        f"生成 EP{episode_number} 大纲",
        f"创作 EP{episode_number} 对话内容",
        "执行质量审核",
        "更新记忆系统"
    ]


async def _execute_episode_creation(
    project_id: str,
    episode_number: int,
    analyze_previous_memory: bool
):
    """Create one episode in the background, report progress, and persist the result.

    Started fire-and-forget by ``_handle_tool_call``, so it keeps running even
    if every WebSocket disconnects. Progress events go to the project's clients
    when any are connected (send failures are ignored); the generated episode is
    written to the database regardless.

    Steps:
      1. Optionally re-analyse the previous episode's memory.
      2. Generate the episode with the series-creation agent.
      3. Upsert the episode record and push its content to the canvas.
      4. Persist the project memory and announce completion.

    Errors are logged and reported to clients as ``error`` / ``text`` events.
    """
    task_key = f"{project_id}_{episode_number}"

    # Register this task so it stays referenced while it runs
    current_task = asyncio.current_task()
    if current_task:
        _background_tasks[task_key] = current_task

    async def safe_send(message_type: str, data: Optional[dict] = None):
        """Broadcast ``{"type", "data"}`` to the project; never raises."""
        try:
            await manager.send_to_project(project_id, {
                "type": message_type,
                "data": {} if data is None else data
            })
        except Exception as e:
            # The WebSocket may already be gone; creation continues regardless
            logger.debug(f"WebSocket 发送失败(可能已断开): {e}")

    try:
        from app.db.repositories import project_repo, episode_repo
        from app.core.agents.series_creation_agent import get_series_agent
        from app.core.memory.memory_manager import get_memory_manager

        logger.info(f"开始后台创作任务: {task_key}")

        project = await project_repo.get(project_id)
        if not project:
            await safe_send("error", {"message": f"项目不存在: {project_id}"})
            return

        plan = _episode_creation_plan(episode_number, analyze_previous_memory)
        await safe_send("plan_update", {
            "plan": plan,
            "status": "planning",
            "current_step_index": 0
        })

        # Step 1: re-analyse the previous episode so its memory feeds this one
        if analyze_previous_memory and episode_number > 1:
            prev_episodes = await episode_repo.list_by_project(project_id)
            prev_episode = next((ep for ep in prev_episodes if ep.number == episode_number - 1), None)

            if prev_episode and prev_episode.content:
                await safe_send("text", {"content": f"\n\n--- 正在分析 EP{episode_number - 1} 的记忆系统 ---\n"})
                try:
                    memory_manager = get_memory_manager()
                    await memory_manager.update_memory_from_episode(project, prev_episode)
                    logger.info(f"EP{episode_number - 1} 记忆已分析并注入到 EP{episode_number}")
                except Exception as e:
                    logger.warning(f"分析 EP{episode_number - 1} 记忆失败: {e}")

        # Steps 2-5: generate the episode
        agent = get_series_agent()
        await safe_send("plan_update", {
            "plan": plan,
            "status": "writing",
            "current_step_index": 1
        })
        await safe_send("text", {"content": f"\n\n--- 开始创作 EP{episode_number} ---\n"})

        episode = await agent.execute_episode(
            project=project,
            episode_number=episode_number,
            title=f"第{episode_number}集"
        )

        # A "needs-review" episode without content means generation failed
        if episode.status == "needs-review" and not episode.content:
            await safe_send("error", {
                "message": f"EP{episode_number} 创作失败",
                "episode_number": episode_number
            })
            await safe_send("text", {"content": f"\n\n❌ EP{episode_number} 创作失败。请检查错误日志并重试。\n"})
            logger.error(f"EP{episode_number} 创作失败,无内容生成")
            return

        # Persist whether or not any client is still connected (upsert by number)
        existing_episodes = await episode_repo.list_by_project(project_id)
        episode_record = next((ep for ep in existing_episodes if ep.number == episode_number), None)
        episode.projectId = project_id
        if episode_record:
            episode.id = episode_record.id
            await episode_repo.update(episode)
            logger.info(f"更新现有剧集记录: {episode.id}")
        else:
            await episode_repo.create(episode)
            logger.info(f"创建新剧集记录: {episode.id}")

        if episode.content:
            await safe_send("canvas_update", {"content": episode.content})

        # Persist the (possibly updated) project memory. A missing memory must
        # not turn an already-saved episode into a reported failure.
        if project.memory is not None:
            await project_repo.update(project_id, {
                "memory": project.memory.dict()
            })
        else:
            logger.warning(f"Project {project_id} has no memory to persist after EP{episode_number}")

        await safe_send("plan_update", {
            "plan": plan,
            "status": "idle",
            "current_step_index": 4
        })
        await safe_send("text", {"content": f"\n\n✅ EP{episode_number} 创作完成!质量分数: {episode.qualityScore or 0}\n"})
        await safe_send("episode_updated", {
            "number": episode_number,
            "title": episode.title,
            "status": episode.status
        })

        logger.info(f"EP{episode_number} 后台创作完成,已保存到数据库")

    except Exception as e:
        logger.error(f"执行剧集创作失败: {str(e)}", exc_info=True)
        await safe_send("error", {
            "message": f"EP{episode_number} 创作失败: {str(e)}",
            "episode_number": episode_number
        })
        await safe_send("text", {"content": f"\n\n❌ EP{episode_number} 创作失败: {str(e)}\n"})

    finally:
        # Remove only our own registration: a newer task for the same episode
        # may have replaced the entry in the meantime.
        if current_task is not None and _background_tasks.get(task_key) is current_task:
            del _background_tasks[task_key]
        logger.info(f"后台创作任务结束: {task_key}")
|
|
|
|
|
|
# ============================================
|
|
# 辅助函数 - 用于从其他模块发送消息
|
|
# ============================================
|
|
|
|
async def broadcast_stage_update(
    project_id: str,
    episode_number: int,
    stage: str,
    data: Dict[str, Any]
):
    """Broadcast a ``stage_update`` event for one episode to ``project_id``'s clients.

    Keys in ``data`` are merged after the fixed fields, so they may override
    ``project_id``, ``episode_number`` or ``stage``.
    """
    details = {
        "project_id": project_id,
        "episode_number": episode_number,
        "stage": stage,
        **data
    }
    envelope = {
        "type": "stage_update",
        "data": details,
        "timestamp": datetime.now().isoformat()
    }
    await manager.send_to_project(project_id, envelope)
|
|
|
|
|
|
async def broadcast_episode_complete(
    project_id: str,
    episode_number: int,
    success: bool,
    quality_score: float,
    data: Dict[str, Any]
):
    """Tell ``project_id``'s clients that an episode finished (``episode_complete``).

    ``data`` is merged after the fixed fields and can override them.
    """
    details = {
        "project_id": project_id,
        "episode_number": episode_number,
        "success": success,
        "quality_score": quality_score,
        **data
    }
    await manager.send_to_project(project_id, {
        "type": "episode_complete",
        "data": details,
        "timestamp": datetime.now().isoformat()
    })
|
|
|
|
|
|
async def broadcast_batch_progress(
    batch_id: str,
    current_episode: int,
    total_episodes: int,
    completed: int,
    failed: int,
    data: Dict[str, Any]
):
    """Broadcast a ``batch_progress`` event to every client of ``batch_id``.

    ``progress_percentage`` is ``completed / total_episodes * 100`` (failed
    episodes are not counted), or ``0`` when ``total_episodes`` is not positive.
    Extra keys in ``data`` are merged last and may override the computed fields.
    """
    if total_episodes > 0:
        percentage = completed / total_episodes * 100
    else:
        percentage = 0

    details = {
        "batch_id": batch_id,
        "current_episode": current_episode,
        "total_episodes": total_episodes,
        "completed_episodes": completed,
        "failed_episodes": failed,
        "progress_percentage": percentage,
        **data
    }
    await manager.send_to_batch(batch_id, {
        "type": "batch_progress",
        "data": details,
        "timestamp": datetime.now().isoformat()
    })
|
|
|
|
|
|
async def broadcast_error(
    project_id: str,
    episode_number: Optional[int],
    error: str,
    error_type: str = "execution_error"
):
    """Broadcast an ``error`` event (optionally tied to an episode) to ``project_id``'s clients."""
    timestamp = datetime.now().isoformat()
    details = {
        "project_id": project_id,
        "episode_number": episode_number,
        "error": error,
        "error_type": error_type
    }
    await manager.send_to_project(
        project_id,
        {"type": "error", "data": details, "timestamp": timestamp}
    )
|
|
|
|
|
|
async def broadcast_batch_complete(
    batch_id: str,
    summary: Dict[str, Any]
):
    """Broadcast a ``batch_complete`` event carrying ``summary`` to ``batch_id``'s clients.

    ``summary`` keys are merged after ``batch_id`` and may override it.
    """
    details = {"batch_id": batch_id, **summary}
    envelope = {
        "type": "batch_complete",
        "data": details,
        "timestamp": datetime.now().isoformat()
    }
    await manager.send_to_batch(batch_id, envelope)
|
|
|
|
|
|
async def broadcast_to_project(
    project_id: str,
    message_type: str,
    data: Dict[str, Any]
):
    """Broadcast an arbitrary ``{type, data, timestamp}`` event to all of ``project_id``'s clients."""
    await manager.send_to_project(project_id, {
        "type": message_type,
        "data": data,
        "timestamp": datetime.now().isoformat()
    })
|
|
|
|
|
|
# Names exported by ``from <this module> import *``: the shared connection manager
# and the broadcast helpers that other modules use to push events to clients.
# NOTE(review): ``router`` is not listed here, so it must be imported explicitly.
__all__ = [
    "manager",
    "broadcast_stage_update",
    "broadcast_episode_complete",
    "broadcast_batch_progress",
    "broadcast_error",
    "broadcast_batch_complete",
    "broadcast_to_project"
]
|