"""
WebSocket Streaming API

WebSocket endpoints that push real-time execution progress updates to clients.
"""

from fastapi import APIRouter, WebSocket, WebSocketDisconnect, Depends, Query
from typing import Dict, Set, Optional, Any
import json
import asyncio
from datetime import datetime
from pathlib import Path

from app.config import settings
from app.utils.logger import get_logger
from app.core.agent_runtime.director_agent import DirectorAgent
from app.db.repositories import message_repo

logger = get_logger(__name__)

router = APIRouter()


# ============================================
# WebSocket connection management
# ============================================

class ConnectionManager:
    """Track WebSocket subscribers per project / batch and cache one agent per project."""

    def __init__(self):
        # project_id -> sockets subscribed to that project's stream
        self.project_connections: Dict[str, Set[WebSocket]] = {}
        # batch_id -> sockets subscribed to that batch's stream
        self.batch_connections: Dict[str, Set[WebSocket]] = {}
        # project_id -> DirectorAgent created lazily by get_agent()
        self.project_agents: Dict[str, DirectorAgent] = {}

    async def connect_to_project(self, websocket: WebSocket, project_id: str):
        """Accept ``websocket`` and subscribe it to ``project_id``'s stream."""
        await websocket.accept()
        self.project_connections.setdefault(project_id, set()).add(websocket)
        logger.info(f"WebSocket 已连接到项目: {project_id}")

    async def connect_to_batch(self, websocket: WebSocket, batch_id: str):
        """Accept ``websocket`` and subscribe it to ``batch_id``'s stream."""
        await websocket.accept()
        self.batch_connections.setdefault(batch_id, set()).add(websocket)
        logger.info(f"WebSocket 已连接到批次: {batch_id}")

    def disconnect(self, websocket: WebSocket):
        """Unsubscribe ``websocket`` from every project and batch stream.

        A project or batch whose last subscriber leaves is dropped. When a project
        is dropped, its cached agent is dropped too. Calling this again for an
        already-removed socket is a no-op.
        """
        # Iterate over snapshots. Entries are deleted inside the loops, and
        # deleting from a dict while iterating it raises RuntimeError.
        for project_id, connections in list(self.project_connections.items()):
            if websocket in connections:
                connections.discard(websocket)
                logger.info(f"WebSocket 已从项目断开: {project_id}")
                if not connections:
                    del self.project_connections[project_id]
                    # No subscribers left: release the project's agent instance.
                    self.project_agents.pop(project_id, None)

        for batch_id, connections in list(self.batch_connections.items()):
            if websocket in connections:
                connections.discard(websocket)
                logger.info(f"WebSocket 已从批次断开: {batch_id}")
                if not connections:
                    del self.batch_connections[batch_id]

    async def _send_to_all(
        self,
        connections: Set[WebSocket],
        message: Dict[str, Any],
        exclude: Optional[WebSocket]
    ) -> Set[WebSocket]:
        """Send ``message`` to every socket in ``connections`` except ``exclude``.

        Returns:
            The sockets whose send raised. Each failure is logged.
        """
        failed: Set[WebSocket] = set()
        # Snapshot: each ``await`` yields to other coroutines, which may connect
        # or disconnect sockets and so mutate the live set mid-iteration.
        for connection in list(connections):
            if connection == exclude:
                continue
            try:
                await connection.send_json(message)
            except Exception as e:
                logger.error(f"发送消息失败: {str(e)}")
                failed.add(connection)
        return failed

    async def send_to_project(
        self,
        project_id: str,
        message: Dict[str, Any],
        exclude: Optional[WebSocket] = None
    ):
        """Broadcast ``message`` to all subscribers of ``project_id``, except ``exclude``.

        Sockets whose send fails are disconnected.
        """
        connections = self.project_connections.get(project_id)
        if not connections:
            return
        for connection in await self._send_to_all(connections, message, exclude):
            self.disconnect(connection)

    async def send_to_batch(
        self,
        batch_id: str,
        message: Dict[str, Any],
        exclude: Optional[WebSocket] = None
    ):
        """Broadcast ``message`` to all subscribers of ``batch_id``, except ``exclude``.

        Sockets whose send fails are disconnected.
        """
        connections = self.batch_connections.get(batch_id)
        if not connections:
            return
        for connection in await self._send_to_all(connections, message, exclude):
            self.disconnect(connection)

    def get_agent(self, project_id: str, working_dir: Path) -> DirectorAgent:
        """Return the cached agent for ``project_id``, creating it on first use.

        ``working_dir`` is created if it does not exist. The agent starts with
        ``project_context=None``. The chat-message handler loads the project
        context later, because ``project_repo.get`` is async and this method is
        sync.
        """
        if project_id not in self.project_agents:
            working_dir.mkdir(parents=True, exist_ok=True)

            model_name = settings.zai_model
            # GLM models do not support thinking mode, so disable it for them.
            enable_thinking = "glm" not in model_name.lower()

            self.project_agents[project_id] = DirectorAgent(
                working_directory=working_dir,
                enable_thinking=enable_thinking,
                model=model_name,
                project_context=None  # filled in later by the message handler
            )
        return self.project_agents[project_id]

    def get_project_connections_count(self, project_id: str) -> int:
        """Return the number of sockets subscribed to ``project_id``."""
        return len(self.project_connections.get(project_id, set()))

    def get_batch_connections_count(self, batch_id: str) -> int:
        """Return the number of sockets subscribed to ``batch_id``."""
        return len(self.batch_connections.get(batch_id, set()))


# Process-wide connection manager shared by the endpoints and broadcast helpers.
manager = ConnectionManager()


# ============================================
# WebSocket endpoints
# ============================================

# NOTE(review): hard-coded, machine-specific (Windows drive) workspace root.
# Consider moving it into ``settings``.
_PROJECTS_ROOT = "d:/platform/creative_studio/workspace/projects"


@router.websocket("/ws/projects/{project_id}/execute")
async def websocket_project_execution(
    websocket: WebSocket,
    project_id: str
):
    """Project execution stream.

    After accepting, sends a ``connected`` message and the stored chat
    ``history`` (if any). Then dispatches client messages to
    ``_handle_client_message`` until the client disconnects. Invalid JSON and
    handler errors are reported back to this socket as ``error`` messages.
    """
    await manager.connect_to_project(websocket, project_id)

    # Agent working directory for this project.
    project_dir = Path(f"{_PROJECTS_ROOT}/{project_id}")

    try:
        await websocket.send_json({
            "type": "connected",
            "data": {
                "project_id": project_id,
                "message": "已连接到项目执行流",
                "timestamp": datetime.now().isoformat()
            }
        })

        # Replay stored conversation history to the newly connected client.
        history = await message_repo.get_history(project_id)
        if history:
            await websocket.send_json({
                "type": "history",
                "messages": history
            })

        while True:
            try:
                data = await websocket.receive_text()
                message = json.loads(data)
                await _handle_client_message(websocket, project_id, message, project_dir)
            except WebSocketDisconnect:
                logger.info(f"WebSocket 客户端主动断开: {project_id}")
                break
            except json.JSONDecodeError:
                await websocket.send_json({
                    "type": "error",
                    "data": {
                        "message": "无效的 JSON 格式"
                    }
                })
            except Exception as e:
                # If this send fails too (dead socket), the outer handler ends the loop.
                logger.error(f"处理 WebSocket 消息错误: {str(e)}")
                await websocket.send_json({
                    "type": "error",
                    "data": {
                        "message": f"处理消息错误: {str(e)}"
                    }
                })
    except Exception as e:
        logger.error(f"WebSocket 连接错误: {str(e)}")
    finally:
        manager.disconnect(websocket)


@router.websocket("/ws/batches/{batch_id}/execute")
async def websocket_batch_execution(
    websocket: WebSocket,
    batch_id: str
):
    """Batch execution stream (server-to-client broadcast only).

    Client messages do not control the batch. Only ``ping`` is answered (with
    ``pong``) so clients can keep the connection alive.
    """
    await manager.connect_to_batch(websocket, batch_id)

    try:
        await websocket.send_json({
            "type": "connected",
            "data": {
                "batch_id": batch_id,
                "message": "已连接到批量执行流",
                "timestamp": datetime.now().isoformat()
            }
        })

        while True:
            try:
                data = await websocket.receive_text()
                message = json.loads(data)
                if isinstance(message, dict) and message.get("type") == "ping":
                    await websocket.send_json({
                        "type": "pong",
                        "data": {"timestamp": datetime.now().isoformat()}
                    })
            except WebSocketDisconnect:
                logger.info(f"WebSocket 客户端主动断开: {batch_id}")
                break
            except json.JSONDecodeError as e:
                # A malformed payload does not affect the stream: keep listening.
                logger.error(f"处理 WebSocket 消息错误: {str(e)}")
            except Exception as e:
                logger.error(f"处理 WebSocket 消息错误: {str(e)}")
                # Any other failure (typically an unusable socket) would repeat
                # on every iteration, so stop instead of spinning forever.
                break
    finally:
        manager.disconnect(websocket)


# ============================================
# Message handling
# ============================================

def _past_tense(verb: Any) -> str:
    """Return a naive English past tense: "approve" -> "approved", "reject" -> "rejected"."""
    text = str(verb)
    return f"{text}d" if text.endswith("e") else f"{text}ed"


async def _run_agent_and_broadcast(
    project_id: str,
    agent: DirectorAgent,
    prompt: str
) -> str:
    """Run ``agent`` on ``prompt`` and broadcast every event to the project's subscribers.

    ``tool_call`` events are also translated via ``_handle_tool_call``.

    Returns:
        The concatenated ``content`` of all ``text`` events.
    """
    text_parts = []
    # NOTE(review): stream_events() is a synchronous iterator, so the agent's
    # work between yields runs on the event-loop thread and can stall other
    # connections. Consider offloading it once the agent's thread-safety is confirmed.
    for event in agent.stream_events(prompt, thread_id=project_id):
        if event.get("type") == "tool_call":
            await _handle_tool_call(project_id, event)
        if event.get("type") == "text":
            text_parts.append(event.get("content") or "")
        await manager.send_to_project(project_id, event)
    return "".join(text_parts)


def _collect_user_skills(project: Any, skill_manager: Any) -> list:
    """Flatten ``project.defaultTaskSkills`` into ``{id, name, behavior}`` dicts.

    Skills that fail to load are logged and skipped.
    """
    user_skills = []
    for task_config in getattr(project, 'defaultTaskSkills', None) or []:
        for skill_config in task_config.skills:
            try:
                skill = skill_manager.get_skill_by_id(skill_config.skill_id)
                if skill:
                    user_skills.append({
                        'id': skill.id,
                        'name': skill.name,
                        'behavior': skill.behavior_guide or skill.description or ''
                    })
            except Exception as e:
                logger.warning(f"Failed to load skill {skill_config.skill_id}: {e}")
    return user_skills


async def _load_project_context(agent: DirectorAgent, project_id: str) -> None:
    """Replace ``agent.context`` with the stored project's context and rebuild its system prompt.

    Best-effort: any failure is logged as a warning, and the agent keeps its
    current context.
    """
    try:
        from app.db.repositories import project_repo
        from app.core.agent_runtime.context import SkillAgentContext
        from app.core.agent_runtime.skill_loader import SkillLoader  # noqa: F401  (kept: import failure skips context loading, as before)
        from app.core.skills.skill_manager import skill_manager

        project = await project_repo.get(project_id)
        if not project:
            return

        global_ctx = project.globalContext
        has_script = bool(global_ctx and global_ctx.uploadedScript)

        agent.context = SkillAgentContext(
            skill_loader=agent.context.skill_loader,
            working_directory=agent.context.working_directory,
            project_id=project.id,
            project_name=project.name,
            project_genre=getattr(project, 'genre', '古风'),
            total_episodes=project.totalEpisodes,
            world_setting=global_ctx.worldSetting if global_ctx else None,
            # NOTE(review): ``characters`` is fed from ``styleGuide``; confirm this is intended.
            characters=global_ctx.styleGuide if global_ctx else None,
            overall_outline=global_ctx.overallOutline if global_ctx else None,
            creation_mode='script' if has_script else 'inspiration',
            # Prefer the uploaded script, otherwise fall back to the inspiration text.
            source_content=(global_ctx.uploadedScript or global_ctx.inspiration) if global_ctx else None,
            user_skills=_collect_user_skills(project, skill_manager)
        )
        # The system prompt embeds the context, so it must be rebuilt.
        agent.system_prompt = agent._build_system_prompt()
        logger.info(f"Loaded project context for {project_id}: {project.name}")
    except Exception as e:
        logger.warning(f"Failed to load project context for {project_id}: {e}")


async def _handle_client_message(
    websocket: WebSocket,
    project_id: str,
    message: Dict[str, Any],
    project_dir: Path
):
    """Dispatch one decoded client message received on a project stream.

    Supported ``type`` values:

    * ``ping``: reply ``pong`` to this socket.
    * ``inbox_action``: tell the agent which inbox item was approved/rejected and
      broadcast its events.
    * ``chat_message``: persist the user message, run the agent, broadcast its
      events, and persist the agent's text reply.
    * ``get_status``: currently a no-op.

    Any other type gets an ``error`` reply on this socket.
    """
    message_type = message.get("type")
    logger.info(f"收到消息: type={message_type}, full_message={message}")

    if message_type == "ping":
        await websocket.send_json({
            "type": "pong",
            "data": {
                "timestamp": datetime.now().isoformat()
            }
        })

    elif message_type == "inbox_action":
        # The user clicked approve/reject in the Inbox; relay it to the agent
        # as natural-language feedback.
        action = message.get("action")
        item_id = message.get("itemId")
        feedback = f"User {_past_tense(action)} inbox item {item_id}."

        agent = manager.get_agent(project_id, project_dir)
        try:
            await _run_agent_and_broadcast(project_id, agent, feedback)
        except Exception as e:
            await manager.send_to_project(project_id, {
                "type": "error",
                "data": {"message": str(e)}
            })

    elif message_type == "chat_message":
        content = message.get("content", "")
        if not content:
            return

        await message_repo.add_message(project_id, "user", content)

        agent = manager.get_agent(project_id, project_dir)

        # Load the project context lazily, the first time the agent is used.
        if agent.context and not agent.context.project_id:
            await _load_project_context(agent, project_id)

        try:
            full_response = await _run_agent_and_broadcast(project_id, agent, content)
            if full_response:
                await message_repo.add_message(project_id, "agent", full_response)
        except Exception as e:
            await manager.send_to_project(project_id, {
                "type": "error",
                "data": {"message": str(e)}
            })

    elif message_type == "get_status":
        # Placeholder: status reporting is not implemented yet.
        pass

    else:
        await websocket.send_json({
            "type": "error",
            "data": {
                "message": f"未知消息类型: {message_type}"
            }
        })


async def _handle_tool_call(project_id: str, event: Dict[str, Any]):
    """Translate a director tool-call event into a frontend WebSocket message.

    Called for each ``tool_call`` event produced by ``agent.stream_events()``.
    The translated message is broadcast to all subscribers of ``project_id``.
    Tool names without a mapping here produce no extra message.
    """
    name = event.get("name")
    # ``or {}`` also covers an explicit ``"args": None``.
    args = event.get("args") or {}

    if name == "update_plan":
        await manager.send_to_project(project_id, {
            "type": "plan_update",
            "plan": args.get("steps", []),
            "status": args.get("status", "planning"),
            "current_step_index": args.get("current_step_index", 0)
        })

    elif name == "add_inbox_task":
        await manager.send_to_project(project_id, {
            "type": "review_request",
            "id": f"task_{args.get('title', 'unknown')}_{int(datetime.now().timestamp())}",
            "title": args.get("title"),
            "description": args.get("description"),
            "options": args.get("options", ["Approve", "Reject"]),
            "timestamp": int(datetime.now().timestamp() * 1000)
        })

    elif name == "add_annotation":
        await manager.send_to_project(project_id, {
            "type": "annotation_add",
            "annotation": {
                "content": args.get("content"),
                "type": args.get("annotation_type", "review"),
                "suggestion": args.get("suggestion", ""),
                "timestamp": int(datetime.now().timestamp() * 1000)
            }
        })

    elif name == "update_context":
        # ``data`` may be a JSON string or an already-decoded object.
        data = args.get("data")
        context_type = args.get("context_type", "state")
        try:
            if isinstance(data, str):
                data = json.loads(data)

            # Convert to the frontend's activeStates shape: [{type, value}, ...].
            if isinstance(data, dict):
                states = [{"type": k, "value": v} for k, v in data.items()]
            elif isinstance(data, list):
                # Already a list; forwarded unchanged.
                states = data
            else:
                states = [{"type": context_type, "value": str(data)}]

            await manager.send_to_project(project_id, {
                "type": "context_update",
                "states": states
            })
        except Exception as e:
            logger.warning(f"Failed to process context update: {e}")

    elif name in ("write_to_canvas", "write_file"):
        # Any non-empty write (including write_file) is mirrored onto the canvas.
        content = args.get("content")
        if content:
            await manager.send_to_project(project_id, {
                "type": "canvas_update",
                "content": content
            })

    elif name == "update_memory":
        memory_type = args.get("memory_type", "timeline")
        data = args.get("data", {})
        # As with update_context, ``data`` may arrive JSON-encoded. The formatting
        # below needs a dict; anything else falls back to the generic fields.
        if isinstance(data, str):
            try:
                data = json.loads(data)
            except json.JSONDecodeError:
                logger.warning(f"Failed to parse memory update data: {data!r}")
                data = {}
        if not isinstance(data, dict):
            data = {}

        memory_data = {
            "type": memory_type,
            "memory_type": memory_type,
            "timestamp": int(datetime.now().timestamp() * 1000)
        }

        if memory_type == "timeline":
            memory_data.update({
                "title": data.get("event", "时间线事件"),
                "description": data.get("description", "")
            })
        elif memory_type == "character_state":
            memory_data.update({
                "title": f"{data.get('character', '角色')}状态变化",
                "description": data.get("state", ""),
                "character": data.get("character", "")
            })
        elif memory_type == "pending_thread":
            memory_data.update({
                "title": "待收线问题",
                "description": data.get("description", "")
            })
        elif memory_type == "foreshadowing":
            memory_data.update({
                "title": "伏笔",
                "description": data.get("description", "")
            })

        await manager.send_to_project(project_id, {
            "type": "memory_update",
            "data": memory_data
        })

    elif name == "save_episode":
        await manager.send_to_project(project_id, {
            "type": "episode_saved",
            "episode_number": args.get("episode_number", 0),
            "title": args.get("title", "")
        })


# ============================================
# Helper functions - for sending messages from other modules
# ============================================


def _stamped(kind: str, payload: Dict[str, Any]) -> Dict[str, Any]:
    """Wrap ``payload`` in the standard ``{type, data, timestamp}`` event envelope."""
    return {
        "type": kind,
        "data": payload,
        "timestamp": datetime.now().isoformat(),
    }


async def broadcast_stage_update(
    project_id: str,
    episode_number: int,
    stage: str,
    data: Dict[str, Any]
):
    """Broadcast a ``stage_update`` event for one episode to the project's subscribers.

    Keys in ``data`` are merged into the payload and override same-named keys.
    """
    payload = {
        "project_id": project_id,
        "episode_number": episode_number,
        "stage": stage,
        **data,
    }
    await manager.send_to_project(project_id, _stamped("stage_update", payload))


async def broadcast_episode_complete(
    project_id: str,
    episode_number: int,
    success: bool,
    quality_score: float,
    data: Dict[str, Any]
):
    """Broadcast an ``episode_complete`` event (success flag and quality score) to the project."""
    payload = {
        "project_id": project_id,
        "episode_number": episode_number,
        "success": success,
        "quality_score": quality_score,
        **data,
    }
    await manager.send_to_project(project_id, _stamped("episode_complete", payload))


async def broadcast_batch_progress(
    batch_id: str,
    current_episode: int,
    total_episodes: int,
    completed: int,
    failed: int,
    data: Dict[str, Any]
):
    """Broadcast a ``batch_progress`` event to the batch's subscribers.

    ``progress_percentage`` is ``completed / total_episodes * 100``, or ``0``
    when ``total_episodes`` is not positive.
    """
    if total_episodes > 0:
        percentage = completed / total_episodes * 100
    else:
        percentage = 0

    payload = {
        "batch_id": batch_id,
        "current_episode": current_episode,
        "total_episodes": total_episodes,
        "completed_episodes": completed,
        "failed_episodes": failed,
        "progress_percentage": percentage,
        **data,
    }
    await manager.send_to_batch(batch_id, _stamped("batch_progress", payload))


async def broadcast_error(
    project_id: str,
    episode_number: Optional[int],
    error: str,
    error_type: str = "execution_error"
):
    """Broadcast an ``error`` event (optionally tied to an episode) to the project."""
    payload = {
        "project_id": project_id,
        "episode_number": episode_number,
        "error": error,
        "error_type": error_type,
    }
    await manager.send_to_project(project_id, _stamped("error", payload))


async def broadcast_batch_complete(
    batch_id: str,
    summary: Dict[str, Any]
):
    """Broadcast a ``batch_complete`` event, with ``summary`` merged into the payload."""
    payload = {"batch_id": batch_id, **summary}
    await manager.send_to_batch(batch_id, _stamped("batch_complete", payload))


# Public API: the connection manager and the broadcast helpers.
__all__ = [
    "manager",
    "broadcast_stage_update",
    "broadcast_episode_complete",
    "broadcast_batch_progress",
    "broadcast_error",
    "broadcast_batch_complete"
]