"""
WebSocket Streaming API.

Exposes WebSocket endpoints that push real-time execution progress updates
for projects and batches, plus ``broadcast_*`` helpers that other modules call
to publish those updates to the subscribed clients.
"""

import asyncio
import json
from datetime import datetime
from typing import Any, Dict, Optional, Set

from fastapi import APIRouter, Depends, Query, WebSocket, WebSocketDisconnect

from app.utils.logger import get_logger

logger = get_logger(__name__)

router = APIRouter()


def _now_iso() -> str:
    """Return the current local time as an ISO 8601 string (naive, no UTC offset)."""
    return datetime.now().isoformat()


# ============================================
# WebSocket connection management
# ============================================

class ConnectionManager:
    """Tracks open WebSocket connections grouped by project ID and by batch ID."""

    def __init__(self):
        # project ID -> set of WebSocket connections subscribed to that project
        self.project_connections: Dict[str, Set[WebSocket]] = {}
        # batch ID -> set of WebSocket connections subscribed to that batch
        self.batch_connections: Dict[str, Set[WebSocket]] = {}

    async def connect_to_project(self, websocket: WebSocket, project_id: str):
        """Accept *websocket* and subscribe it to the execution stream of *project_id*."""
        await websocket.accept()
        self.project_connections.setdefault(project_id, set()).add(websocket)
        logger.info(f"WebSocket 已连接到项目: {project_id}")

    async def connect_to_batch(self, websocket: WebSocket, batch_id: str):
        """Accept *websocket* and subscribe it to the execution stream of *batch_id*."""
        await websocket.accept()
        self.batch_connections.setdefault(batch_id, set()).add(websocket)
        logger.info(f"WebSocket 已连接到批次: {batch_id}")

    def disconnect(self, websocket: WebSocket):
        """Unsubscribe *websocket* from every project and batch it belongs to.

        Groups left with no connections are removed. Calling this more than
        once for the same socket is harmless.
        """
        self._remove_from(self.project_connections, websocket, "WebSocket 已从项目断开")
        self._remove_from(self.batch_connections, websocket, "WebSocket 已从批次断开")

    @staticmethod
    def _remove_from(
        registry: Dict[str, Set[WebSocket]],
        websocket: WebSocket,
        log_prefix: str,
    ) -> None:
        """Discard *websocket* from every set in *registry*, deleting keys it empties."""
        # Iterate over a snapshot: deleting a key while iterating the live
        # items() view raises "dictionary changed size during iteration".
        for key, connections in list(registry.items()):
            if websocket in connections:
                connections.remove(websocket)
                logger.info(f"{log_prefix}: {key}")
                if not connections:
                    del registry[key]

    async def _send_to_group(
        self,
        registry: Dict[str, Set[WebSocket]],
        key: str,
        message: Dict[str, Any],
        exclude: Optional[WebSocket],
    ) -> None:
        """Send *message* to every connection in ``registry[key]`` except *exclude*.

        Connections whose send raises are logged and then unsubscribed.
        """
        # Snapshot the set: each await yields to the event loop, where other
        # coroutines may connect/disconnect sockets and mutate the live set,
        # which would raise "Set changed size during iteration".
        targets = list(registry.get(key, ()))
        disconnected = []
        for connection in targets:
            if connection is exclude:
                continue
            try:
                await connection.send_json(message)
            except Exception as e:
                logger.error(f"发送消息失败: {str(e)}")
                disconnected.append(connection)

        # Clean up connections that failed to receive the message.
        for connection in disconnected:
            self.disconnect(connection)

    async def send_to_project(
        self,
        project_id: str,
        message: Dict[str, Any],
        exclude: Optional[WebSocket] = None
    ):
        """Send *message* to all connections subscribed to *project_id* (except *exclude*)."""
        await self._send_to_group(self.project_connections, project_id, message, exclude)

    async def send_to_batch(
        self,
        batch_id: str,
        message: Dict[str, Any],
        exclude: Optional[WebSocket] = None
    ):
        """Send *message* to all connections subscribed to *batch_id* (except *exclude*)."""
        await self._send_to_group(self.batch_connections, batch_id, message, exclude)

    def get_project_connections_count(self, project_id: str) -> int:
        """Return the number of connections subscribed to *project_id*."""
        return len(self.project_connections.get(project_id, set()))

    def get_batch_connections_count(self, batch_id: str) -> int:
        """Return the number of connections subscribed to *batch_id*."""
        return len(self.batch_connections.get(batch_id, set()))


# Global connection manager shared by the endpoints and broadcast helpers.
manager = ConnectionManager()


# ============================================
# WebSocket endpoints
# ============================================

@router.websocket("/ws/projects/{project_id}/execute")
async def websocket_project_execution(
    websocket: WebSocket,
    project_id: str
):
    """
    Project execution WebSocket endpoint.

    Streams real-time execution progress for one project: stage updates,
    episode completion (with quality score) and errors.

    On connect the server sends
    ``{"type": "connected", "data": {"project_id", "message", "timestamp"}}``.
    Messages published through this module's helpers have the shape::

        {
            "type": "stage_update|episode_complete|error",
            "data": {...},
            "timestamp": "ISO 8601"
        }

    Other publishers using ``manager.send_to_project`` may send other types.
    Clients may send JSON objects with ``"type": "ping"`` or ``"get_status"``.
    """
    await manager.connect_to_project(websocket, project_id)

    try:
        # Connection acknowledgement.
        await websocket.send_json({
            "type": "connected",
            "data": {
                "project_id": project_id,
                "message": "已连接到项目执行流",
                "timestamp": _now_iso()
            }
        })
        await _serve_client_messages(websocket, project_id)
    except Exception as e:
        logger.error(f"WebSocket 连接错误: {str(e)}")
    finally:
        manager.disconnect(websocket)


@router.websocket("/ws/batches/{batch_id}/execute")
async def websocket_batch_execution(
    websocket: WebSocket,
    batch_id: str
):
    """
    Batch execution WebSocket endpoint.

    Streams real-time progress for a batch run: episode counts, overall
    progress percentage and the final summary.

    On connect the server sends
    ``{"type": "connected", "data": {"batch_id", "message", "timestamp"}}``.
    Messages published through this module's helpers have the shape::

        {
            "type": "batch_progress|batch_complete",
            "data": {...},
            "timestamp": "ISO 8601"
        }

    Other publishers using ``manager.send_to_batch`` may send other types.
    Clients may send JSON objects with ``"type": "ping"`` or ``"get_status"``.
    """
    await manager.connect_to_batch(websocket, batch_id)

    try:
        # Connection acknowledgement.
        await websocket.send_json({
            "type": "connected",
            "data": {
                "batch_id": batch_id,
                "message": "已连接到批量执行流",
                "timestamp": _now_iso()
            }
        })
        await _serve_client_messages(websocket, batch_id)
    except Exception as e:
        logger.error(f"WebSocket 连接错误: {str(e)}")
    finally:
        manager.disconnect(websocket)


# ============================================
# Client message handling
# ============================================

async def _serve_client_messages(websocket: WebSocket, target_id: str) -> None:
    """
    Receive and dispatch client messages until the client disconnects.

    Invalid JSON and handler failures are reported back to the client as
    ``error`` messages and the loop continues. Returns once the client
    disconnects; exceptions raised while sending an error reply propagate.

    Args:
        websocket: An accepted WebSocket connection.
        target_id: The project ID or batch ID the connection is subscribed to.
    """
    while True:
        try:
            # Client messages carry heartbeats or control commands.
            data = await websocket.receive_text()
            message = json.loads(data)
            await _handle_client_message(websocket, target_id, message)
        except WebSocketDisconnect:
            logger.info(f"WebSocket 客户端主动断开: {target_id}")
            return
        except json.JSONDecodeError:
            await websocket.send_json({
                "type": "error",
                "data": {
                    "message": "无效的 JSON 格式"
                }
            })
        except Exception as e:
            logger.error(f"处理 WebSocket 消息错误: {str(e)}")
            await websocket.send_json({
                "type": "error",
                "data": {
                    "message": f"处理消息错误: {str(e)}"
                }
            })


async def _handle_client_message(
    websocket: WebSocket,
    target_id: str,
    message: Dict[str, Any]
):
    """
    Handle one decoded message sent by a client.

    Supported ``type`` values: ``"ping"`` (answered with ``"pong"``) and
    ``"get_status"`` (answered with ``"status"``). Anything else, including a
    non-object JSON payload, is answered with an ``"error"`` message.

    Args:
        websocket: The WebSocket connection to reply on.
        target_id: Project ID or batch ID the connection is subscribed to.
        message: The decoded client message.
    """
    # json.loads may yield a list, string or number; .get() would then fail.
    if not isinstance(message, dict):
        await websocket.send_json({
            "type": "error",
            "data": {
                "message": "消息必须是 JSON 对象"
            }
        })
        return

    message_type = message.get("type")

    if message_type == "ping":
        # Heartbeat reply.
        await websocket.send_json({
            "type": "pong",
            "data": {
                "timestamp": _now_iso()
            }
        })
    elif message_type == "get_status":
        # Imported lazily; presumably to avoid an import cycle or startup
        # cost — TODO confirm.
        from app.core.execution.batch_executor import get_batch_executor

        executor = get_batch_executor()
        # NOTE(review): project connections are routed here too, so a project
        # ID is looked up via get_batch_status — confirm that is intended.
        status = executor.get_batch_status(target_id)

        await websocket.send_json({
            "type": "status",
            "data": status or {"message": "未找到执行状态"}
        })
    else:
        await websocket.send_json({
            "type": "error",
            "data": {
                "message": f"未知消息类型: {message_type}"
            }
        })


# ============================================
# Helpers - publish messages from other modules
# ============================================

async def broadcast_stage_update(
    project_id: str,
    episode_number: int,
    stage: str,
    data: Optional[Dict[str, Any]] = None
):
    """
    Broadcast a ``stage_update`` message to all connections of a project.

    Args:
        project_id: Project ID.
        episode_number: Episode number.
        stage: Stage name.
        data: Extra stage fields merged into the payload (may override the
            fixed keys above); ``None`` means no extra fields.
    """
    message = {
        "type": "stage_update",
        "data": {
            "project_id": project_id,
            "episode_number": episode_number,
            "stage": stage,
            **(data or {})
        },
        "timestamp": _now_iso()
    }
    await manager.send_to_project(project_id, message)


async def broadcast_episode_complete(
    project_id: str,
    episode_number: int,
    success: bool,
    quality_score: float,
    data: Optional[Dict[str, Any]] = None
):
    """
    Broadcast an ``episode_complete`` message to all connections of a project.

    Args:
        project_id: Project ID.
        episode_number: Episode number.
        success: Whether the episode succeeded.
        quality_score: Quality score of the episode.
        data: Extra fields merged into the payload; ``None`` means none.
    """
    message = {
        "type": "episode_complete",
        "data": {
            "project_id": project_id,
            "episode_number": episode_number,
            "success": success,
            "quality_score": quality_score,
            **(data or {})
        },
        "timestamp": _now_iso()
    }
    await manager.send_to_project(project_id, message)


async def broadcast_batch_progress(
    batch_id: str,
    current_episode: int,
    total_episodes: int,
    completed: int,
    failed: int,
    data: Optional[Dict[str, Any]] = None
):
    """
    Broadcast a ``batch_progress`` message to all connections of a batch.

    ``progress_percentage`` is ``completed / total_episodes * 100`` (failed
    episodes are not counted), or 0 when ``total_episodes`` is not positive.

    Args:
        batch_id: Batch ID.
        current_episode: Episode currently being processed.
        total_episodes: Total number of episodes.
        completed: Number of completed episodes.
        failed: Number of failed episodes.
        data: Extra fields merged into the payload; ``None`` means none.
    """
    message = {
        "type": "batch_progress",
        "data": {
            "batch_id": batch_id,
            "current_episode": current_episode,
            "total_episodes": total_episodes,
            "completed_episodes": completed,
            "failed_episodes": failed,
            "progress_percentage": (completed / total_episodes * 100) if total_episodes > 0 else 0,
            **(data or {})
        },
        "timestamp": _now_iso()
    }
    await manager.send_to_batch(batch_id, message)


async def broadcast_error(
    project_id: str,
    episode_number: Optional[int],
    error: str,
    error_type: str = "execution_error"
):
    """
    Broadcast an ``error`` message to all connections of a project.

    Args:
        project_id: Project ID.
        episode_number: Episode number, or ``None`` if not episode-specific.
        error: Error description.
        error_type: Error category.
    """
    message = {
        "type": "error",
        "data": {
            "project_id": project_id,
            "episode_number": episode_number,
            "error": error,
            "error_type": error_type
        },
        "timestamp": _now_iso()
    }
    await manager.send_to_project(project_id, message)


async def broadcast_batch_complete(
    batch_id: str,
    summary: Dict[str, Any]
):
    """
    Broadcast a ``batch_complete`` message to all connections of a batch.

    Args:
        batch_id: Batch ID.
        summary: Execution summary merged into the payload.
    """
    message = {
        "type": "batch_complete",
        "data": {
            "batch_id": batch_id,
            **summary
        },
        "timestamp": _now_iso()
    }
    await manager.send_to_batch(batch_id, message)


# Public API: the connection manager and the broadcast helpers.
__all__ = [
    "manager",
    "broadcast_stage_update",
    "broadcast_episode_complete",
    "broadcast_batch_progress",
    "broadcast_error",
    "broadcast_batch_complete"
]