"""
WebSocket Streaming API

WebSocket endpoints that push real-time execution progress updates to clients.
"""
|
|
from fastapi import APIRouter, WebSocket, WebSocketDisconnect, Depends, Query
|
|
from typing import Dict, Set, Optional, Any
|
|
import json
|
|
import asyncio
|
|
from datetime import datetime
|
|
|
|
from app.utils.logger import get_logger
|
|
|
|
# Module-level logger obtained from the project's logging helper.
logger = get_logger(__name__)
|
|
|
|
# Router that the WebSocket endpoints of this module are registered on.
# NOTE(review): presumably included into the FastAPI app elsewhere — confirm.
router = APIRouter()
|
|
|
|
|
|
# ============================================
# WebSocket connection management
# ============================================
|
|
|
|
class ConnectionManager:
    """WebSocket connection manager.

    Keeps two independent registries of open WebSocket connections, one
    keyed by project ID and one keyed by batch ID. It broadcasts JSON messages
    to every connection registered under a given key.
    """

    def __init__(self):
        # Project ID -> set of WebSocket connections subscribed to that project.
        self.project_connections: Dict[str, Set[WebSocket]] = {}
        # Batch ID -> set of WebSocket connections subscribed to that batch.
        self.batch_connections: Dict[str, Set[WebSocket]] = {}

    async def connect_to_project(self, websocket: WebSocket, project_id: str):
        """Accept the WebSocket handshake and register it under *project_id*."""
        await websocket.accept()
        self.project_connections.setdefault(project_id, set()).add(websocket)
        logger.info(f"WebSocket 已连接到项目: {project_id}")

    async def connect_to_batch(self, websocket: WebSocket, batch_id: str):
        """Accept the WebSocket handshake and register it under *batch_id*."""
        await websocket.accept()
        self.batch_connections.setdefault(batch_id, set()).add(websocket)
        logger.info(f"WebSocket 已连接到批次: {batch_id}")

    def disconnect(self, websocket: WebSocket):
        """Unregister *websocket* from every project and batch it belongs to.

        Keys whose connection set becomes empty are removed. Calling this for
        a connection that is not registered is a no-op.
        """
        # Iterate over snapshots of the registries. Deleting a key from a dict
        # while iterating it directly raises
        # "RuntimeError: dictionary changed size during iteration".
        for project_id, connections in list(self.project_connections.items()):
            if websocket in connections:
                connections.remove(websocket)
                logger.info(f"WebSocket 已从项目断开: {project_id}")
                if not connections:
                    del self.project_connections[project_id]

        for batch_id, connections in list(self.batch_connections.items()):
            if websocket in connections:
                connections.remove(websocket)
                logger.info(f"WebSocket 已从批次断开: {batch_id}")
                if not connections:
                    del self.batch_connections[batch_id]

    async def send_to_project(
        self,
        project_id: str,
        message: Dict[str, Any],
        exclude: Optional[WebSocket] = None
    ):
        """Send *message* as JSON to every connection registered under *project_id*.

        Args:
            project_id: Target project; IDs with no connections are ignored.
            message: JSON-serializable payload.
            exclude: Optional connection to skip.
        """
        await self._send_to_group(
            self.project_connections.get(project_id), message, exclude
        )

    async def send_to_batch(
        self,
        batch_id: str,
        message: Dict[str, Any],
        exclude: Optional[WebSocket] = None
    ):
        """Send *message* as JSON to every connection registered under *batch_id*.

        Args:
            batch_id: Target batch; IDs with no connections are ignored.
            message: JSON-serializable payload.
            exclude: Optional connection to skip.
        """
        await self._send_to_group(
            self.batch_connections.get(batch_id), message, exclude
        )

    async def _send_to_group(
        self,
        connections: Optional[Set[WebSocket]],
        message: Dict[str, Any],
        exclude: Optional[WebSocket],
    ):
        """Send *message* to each connection in *connections* except *exclude*.

        Connections whose send raises are logged and unregistered via
        ``disconnect`` once the fan-out has finished.
        """
        if not connections:
            return

        failed = set()
        # Iterate over a snapshot. While this coroutine is suspended in
        # ``send_json``, other coroutines may connect or disconnect and mutate
        # the live set, which would raise "Set changed size during iteration".
        for connection in list(connections):
            if connection == exclude:
                continue
            try:
                await connection.send_json(message)
            except Exception as e:
                logger.error(f"发送消息失败: {str(e)}")
                failed.add(connection)

        for connection in failed:
            self.disconnect(connection)

    def get_project_connections_count(self, project_id: str) -> int:
        """Return the number of connections registered under *project_id* (0 if none)."""
        return len(self.project_connections.get(project_id, set()))

    def get_batch_connections_count(self, batch_id: str) -> int:
        """Return the number of connections registered under *batch_id* (0 if none)."""
        return len(self.batch_connections.get(batch_id, set()))
|
|
|
|
|
|
# 全局连接管理器
|
|
# Single module-level ConnectionManager shared by the WebSocket endpoints and
# the broadcast_* helper functions defined in this module.
manager = ConnectionManager()
|
|
|
|
|
|
# ============================================
# WebSocket endpoints
# ============================================
|
|
|
|
@router.websocket("/ws/projects/{project_id}/execute")
async def websocket_project_execution(
    websocket: WebSocket,
    project_id: str
):
    """Project execution WebSocket endpoint.

    Streams real-time execution progress for *project_id*, including:

    - execution start / completion events
    - per-stage progress (structure analysis, outline generation,
      dialogue writing, ...)
    - quality-check results
    - error messages

    Message format::

        {
            "type": "stage_start|stage_progress|stage_complete|error|complete",
            "data": {...},
            "timestamp": "ISO 8601"
        }

    NOTE(review): the broadcast helpers in this module emit ``stage_update``,
    ``episode_complete`` and ``error`` types; the list above may be stale.

    A ``connected`` frame is sent first, with its timestamp inside ``data``.
    Client frames must be JSON text and are dispatched to
    ``_handle_client_message``. Malformed JSON and handler failures are
    answered with an ``error`` frame, and the session continues.
    """
    await manager.connect_to_project(websocket, project_id)

    async def reply_error(text: str) -> None:
        # Shared shape of every error frame sent back to this client.
        await websocket.send_json({"type": "error", "data": {"message": text}})

    try:
        greeting = {
            "type": "connected",
            "data": {
                "project_id": project_id,
                "message": "已连接到项目执行流",
                "timestamp": datetime.now().isoformat(),
            },
        }
        await websocket.send_json(greeting)

        # Serve client frames (heartbeats, control commands) until disconnect.
        while True:
            try:
                raw = await websocket.receive_text()
                payload = json.loads(raw)
                await _handle_client_message(websocket, project_id, payload)
            except WebSocketDisconnect:
                logger.info(f"WebSocket 客户端主动断开: {project_id}")
                break
            except json.JSONDecodeError:
                await reply_error("无效的 JSON 格式")
            except Exception as exc:
                logger.error(f"处理 WebSocket 消息错误: {str(exc)}")
                await reply_error(f"处理消息错误: {str(exc)}")
    except Exception as exc:
        # Reached when the greeting or an error reply cannot be delivered.
        logger.error(f"WebSocket 连接错误: {str(exc)}")
    finally:
        manager.disconnect(websocket)
|
|
|
|
|
|
@router.websocket("/ws/batches/{batch_id}/execute")
async def websocket_batch_execution(
    websocket: WebSocket,
    batch_id: str
):
    """Batch execution WebSocket endpoint.

    Streams real-time progress for *batch_id*, including:

    - batch start / completion events
    - per-episode execution progress
    - overall progress percentage
    - quality statistics
    - error messages

    Message format::

        {
            "type": "batch_start|episode_start|episode_complete|progress|batch_complete|error",
            "data": {...},
            "timestamp": "ISO 8601"
        }

    NOTE(review): the broadcast helpers in this module emit ``batch_progress``
    and ``batch_complete``; the type list above may be stale.

    A ``connected`` frame is sent first, with its timestamp inside ``data``.
    Client frames must be JSON text and are dispatched to
    ``_handle_client_message``. Malformed JSON and handler failures are
    answered with an ``error`` frame. If that reply cannot be delivered, the
    loop ends and the connection is unregistered.
    """
    await manager.connect_to_batch(websocket, batch_id)

    try:
        # Connection acknowledgement.
        await websocket.send_json({
            "type": "connected",
            "data": {
                "batch_id": batch_id,
                "message": "已连接到批量执行流",
                "timestamp": datetime.now().isoformat()
            }
        })

        # Serve client frames until the client disconnects.
        while True:
            try:
                data = await websocket.receive_text()
                message = json.loads(data)
                await _handle_client_message(websocket, batch_id, message)
            except WebSocketDisconnect:
                logger.info(f"WebSocket 客户端主动断开: {batch_id}")
                break
            except json.JSONDecodeError:
                await websocket.send_json({
                    "type": "error",
                    "data": {"message": "无效的 JSON 格式"}
                })
            except Exception as e:
                logger.error(f"处理 WebSocket 消息错误: {str(e)}")
                # Report the failure to the client, as the project endpoint
                # does. Previously this branch only logged, so the client got
                # no feedback, and a persistently failing receive looped
                # forever. If the socket is no longer usable, this send raises
                # and the outer handler ends the session.
                await websocket.send_json({
                    "type": "error",
                    "data": {"message": f"处理消息错误: {str(e)}"}
                })
    except Exception as e:
        # Reached when the greeting or an error reply cannot be delivered.
        logger.error(f"WebSocket 连接错误: {str(e)}")
    finally:
        manager.disconnect(websocket)
|
|
|
|
|
|
# ============================================
# Message handling
# ============================================
|
|
|
|
async def _handle_client_message(
|
|
websocket: WebSocket,
|
|
id: str,
|
|
message: Dict[str, Any]
|
|
):
|
|
"""
|
|
处理客户端发送的消息
|
|
|
|
Args:
|
|
websocket: WebSocket 连接
|
|
id: 项目ID或批次ID
|
|
message: 客户端消息
|
|
"""
|
|
message_type = message.get("type")
|
|
|
|
if message_type == "ping":
|
|
# 心跳响应
|
|
await websocket.send_json({
|
|
"type": "pong",
|
|
"data": {
|
|
"timestamp": datetime.now().isoformat()
|
|
}
|
|
})
|
|
|
|
elif message_type == "get_status":
|
|
# 请求状态
|
|
from app.core.execution.batch_executor import get_batch_executor
|
|
executor = get_batch_executor()
|
|
status = executor.get_batch_status(id)
|
|
|
|
await websocket.send_json({
|
|
"type": "status",
|
|
"data": status or {"message": "未找到执行状态"}
|
|
})
|
|
|
|
else:
|
|
await websocket.send_json({
|
|
"type": "error",
|
|
"data": {
|
|
"message": f"未知消息类型: {message_type}"
|
|
}
|
|
})
|
|
|
|
|
|
# ============================================
# Helper functions - for sending messages from other modules
# ============================================
|
|
|
|
async def broadcast_stage_update(
    project_id: str,
    episode_number: int,
    stage: str,
    data: Dict[str, Any]
):
    """Push a ``stage_update`` event to every client watching *project_id*.

    Args:
        project_id: Project whose subscribers receive the event.
        episode_number: Episode number the stage belongs to.
        stage: Stage name.
        data: Extra fields merged into the event's ``data`` payload; on a key
            clash they override the fixed fields.
    """
    payload = {
        "project_id": project_id,
        "episode_number": episode_number,
        "stage": stage,
        **data,
    }
    event = {
        "type": "stage_update",
        "data": payload,
        "timestamp": datetime.now().isoformat(),
    }
    await manager.send_to_project(project_id, event)
|
|
|
|
|
|
async def broadcast_episode_complete(
    project_id: str,
    episode_number: int,
    success: bool,
    quality_score: float,
    data: Dict[str, Any]
):
    """Push an ``episode_complete`` event to every client watching *project_id*.

    Args:
        project_id: Project whose subscribers receive the event.
        episode_number: Episode number that finished.
        success: Whether the episode succeeded.
        quality_score: Quality score reported for the episode.
        data: Extra fields merged into the event's ``data`` payload; on a key
            clash they override the fixed fields.
    """
    outcome = {
        "project_id": project_id,
        "episode_number": episode_number,
        "success": success,
        "quality_score": quality_score,
        **data,
    }
    await manager.send_to_project(
        project_id,
        {
            "type": "episode_complete",
            "data": outcome,
            "timestamp": datetime.now().isoformat(),
        },
    )
|
|
|
|
|
|
async def broadcast_batch_progress(
    batch_id: str,
    current_episode: int,
    total_episodes: int,
    completed: int,
    failed: int,
    data: Dict[str, Any]
):
    """Push a ``batch_progress`` event to every client watching *batch_id*.

    Args:
        batch_id: Batch whose subscribers receive the event.
        current_episode: Episode currently being processed.
        total_episodes: Total episodes in the batch; when it is not positive,
            the reported percentage is 0.
        completed: Number of completed episodes.
        failed: Number of failed episodes.
        data: Extra fields merged into the event's ``data`` payload; on a key
            clash they override the fixed fields.
    """
    # NOTE(review): only completed episodes count toward the percentage, so
    # it stays below 100 whenever any episode fails. Confirm this is intended.
    if total_episodes > 0:
        percent = completed / total_episodes * 100
    else:
        percent = 0

    progress = {
        "batch_id": batch_id,
        "current_episode": current_episode,
        "total_episodes": total_episodes,
        "completed_episodes": completed,
        "failed_episodes": failed,
        "progress_percentage": percent,
        **data,
    }
    event = {
        "type": "batch_progress",
        "data": progress,
        "timestamp": datetime.now().isoformat(),
    }
    await manager.send_to_batch(batch_id, event)
|
|
|
|
|
|
async def broadcast_error(
    project_id: str,
    episode_number: Optional[int],
    error: str,
    error_type: str = "execution_error"
):
    """Push an ``error`` event to every client watching *project_id*.

    Args:
        project_id: Project whose subscribers receive the event.
        episode_number: Episode the error relates to, or ``None``.
        error: Error description.
        error_type: Error category; defaults to ``"execution_error"``.
    """
    details = {
        "project_id": project_id,
        "episode_number": episode_number,
        "error": error,
        "error_type": error_type,
    }
    await manager.send_to_project(
        project_id,
        {
            "type": "error",
            "data": details,
            "timestamp": datetime.now().isoformat(),
        },
    )
|
|
|
|
|
|
async def broadcast_batch_complete(
    batch_id: str,
    summary: Dict[str, Any]
):
    """Push a ``batch_complete`` event to every client watching *batch_id*.

    Args:
        batch_id: Batch whose subscribers receive the event.
        summary: Execution summary merged into the event's ``data`` payload;
            a ``batch_id`` key in it overrides the fixed one.
    """
    body = {"batch_id": batch_id, **summary}
    finished = {
        "type": "batch_complete",
        "data": body,
        "timestamp": datetime.now().isoformat(),
    }
    await manager.send_to_batch(batch_id, finished)
|
|
|
|
|
|
# Public names exported by star-imports: the shared connection manager and
# the broadcast helpers that other modules call to push events to clients.
__all__ = [
    "manager",
    "broadcast_stage_update",
    "broadcast_episode_complete",
    "broadcast_batch_progress",
    "broadcast_error",
    "broadcast_batch_complete",
]
|