# Source listing metadata (file-viewer residue): 665 lines, 23 KiB, Python
"""
WebSocket Streaming API

Provides WebSocket endpoints that push real-time execution progress updates.
"""
|
||
from fastapi import APIRouter, WebSocket, WebSocketDisconnect, Depends, Query
|
||
from typing import Dict, Set, Optional, Any
|
||
import json
|
||
import asyncio
|
||
from datetime import datetime
|
||
from pathlib import Path
|
||
|
||
from app.config import settings
|
||
from app.utils.logger import get_logger
|
||
from app.core.agent_runtime.director_agent import DirectorAgent
|
||
from app.db.repositories import message_repo
|
||
|
||
# Module-level logger, obtained from the project's logging helper for this module's name.
logger = get_logger(__name__)
|
||
|
||
# Router on which the WebSocket endpoints defined in this module are registered.
router = APIRouter()
|
||
|
||
|
||
# ============================================
|
||
# WebSocket 连接管理
|
||
# ============================================
|
||
|
||
class ConnectionManager:
    """Registry of live WebSocket connections, grouped by project and by batch.

    The manager also caches one ``DirectorAgent`` per project. The cached
    agent is discarded when the last connection of that project disconnects.
    """

    def __init__(self):
        # project id -> set of WebSocket connections watching that project
        self.project_connections: Dict[str, Set[WebSocket]] = {}
        # batch id -> set of WebSocket connections watching that batch
        self.batch_connections: Dict[str, Set[WebSocket]] = {}
        # project id -> agent instance
        self.project_agents: Dict[str, DirectorAgent] = {}

    async def connect_to_project(self, websocket: WebSocket, project_id: str):
        """Accept *websocket* and register it under *project_id*."""
        await websocket.accept()
        self.project_connections.setdefault(project_id, set()).add(websocket)
        logger.info(f"WebSocket 已连接到项目: {project_id}")

    async def connect_to_batch(self, websocket: WebSocket, batch_id: str):
        """Accept *websocket* and register it under *batch_id*."""
        await websocket.accept()
        self.batch_connections.setdefault(batch_id, set()).add(websocket)
        logger.info(f"WebSocket 已连接到批次: {batch_id}")

    def disconnect(self, websocket: WebSocket):
        """Remove *websocket* from every project and batch it is registered in.

        Empty project/batch entries are deleted; when a project loses its last
        connection its cached agent is dropped as well.
        """
        # Iterate over snapshots: entries are deleted inside the loops, and
        # deleting from a dict while iterating over it raises RuntimeError.
        for project_id, connections in list(self.project_connections.items()):
            if websocket not in connections:
                continue
            connections.remove(websocket)
            logger.info(f"WebSocket 已从项目断开: {project_id}")
            if not connections:
                del self.project_connections[project_id]
                # Last watcher is gone: release the project's agent instance.
                self.project_agents.pop(project_id, None)

        for batch_id, connections in list(self.batch_connections.items()):
            if websocket not in connections:
                continue
            connections.remove(websocket)
            logger.info(f"WebSocket 已从批次断开: {batch_id}")
            if not connections:
                del self.batch_connections[batch_id]

    async def _broadcast(
        self,
        connections: Set[WebSocket],
        message: Dict[str, Any],
        exclude: Optional[WebSocket]
    ):
        """Send *message* to each connection except *exclude*; drop the ones that fail."""
        failed = []
        # Snapshot the set: every await below lets other tasks connect or
        # disconnect, which would otherwise mutate the set mid-iteration.
        for connection in list(connections):
            if connection == exclude:
                continue
            try:
                await connection.send_json(message)
            except Exception as e:
                logger.error(f"发送消息失败: {str(e)}")
                failed.append(connection)

        # Unregister the connections whose send failed.
        for connection in failed:
            self.disconnect(connection)

    async def send_to_project(
        self,
        project_id: str,
        message: Dict[str, Any],
        exclude: Optional[WebSocket] = None
    ):
        """Send *message* to all connections of *project_id*, optionally skipping *exclude*."""
        connections = self.project_connections.get(project_id)
        if connections:
            await self._broadcast(connections, message, exclude)

    async def send_to_batch(
        self,
        batch_id: str,
        message: Dict[str, Any],
        exclude: Optional[WebSocket] = None
    ):
        """Send *message* to all connections of *batch_id*, optionally skipping *exclude*."""
        connections = self.batch_connections.get(batch_id)
        if connections:
            await self._broadcast(connections, message, exclude)

    def get_agent(self, project_id: str, working_dir: Path) -> DirectorAgent:
        """Return the cached agent for *project_id*, creating it on first use.

        On creation the working directory is created if missing and the agent
        is built with the model named in ``settings.zai_model``.
        """
        if project_id not in self.project_agents:
            # Make sure the working directory exists.
            working_dir.mkdir(parents=True, exist_ok=True)

            model_name = settings.zai_model
            # Thinking mode is turned off for GLM models (noted as unsupported).
            enable_thinking = "glm" not in model_name.lower()

            # The project context is not loaded here: this method is
            # synchronous while the project repository lookup is async, so the
            # context is filled in later by the message handler.
            self.project_agents[project_id] = DirectorAgent(
                working_directory=working_dir,
                enable_thinking=enable_thinking,
                model=model_name,
                project_context=None
            )
        return self.project_agents[project_id]

    def get_project_connections_count(self, project_id: str) -> int:
        """Return the number of connections registered for *project_id*."""
        return len(self.project_connections.get(project_id, ()))

    def get_batch_connections_count(self, batch_id: str) -> int:
        """Return the number of connections registered for *batch_id*."""
        return len(self.batch_connections.get(batch_id, ()))
|
||
|
||
|
||
# 全局连接管理器
|
||
# Single shared ConnectionManager used by the endpoints and broadcast helpers in this module.
manager = ConnectionManager()
|
||
|
||
|
||
# ============================================
|
||
# WebSocket 端点
|
||
# ============================================
|
||
|
||
@router.websocket("/ws/projects/{project_id}/execute")
async def websocket_project_execution(
    websocket: WebSocket,
    project_id: str
):
    """
    Project execution WebSocket endpoint.

    Registers the socket for *project_id*, sends a "connected" confirmation
    followed by the stored message history (if any), then processes client
    messages until the client disconnects. The socket is always unregistered
    from the connection manager on exit.
    """
    # project_id comes straight from the URL and is interpolated into a
    # filesystem path below; refuse values that would leave the projects folder.
    if project_id in (".", "..") or "/" in project_id or "\\" in project_id:
        logger.warning(f"Rejected WebSocket with unsafe project id: {project_id!r}")
        await websocket.close(code=1008)
        return

    await manager.connect_to_project(websocket, project_id)

    # Working directory for the project (assumed to be projects/{id}).
    # NOTE(review): hard-coded absolute path; should follow the real configuration.
    project_dir = Path(f"d:/platform/creative_studio/workspace/projects/{project_id}")

    try:
        # Connection confirmation.
        await websocket.send_json({
            "type": "connected",
            "data": {
                "project_id": project_id,
                "message": "已连接到项目执行流",
                "timestamp": datetime.now().isoformat()
            }
        })

        # Replay stored history, if there is any.
        history = await message_repo.get_history(project_id)
        if history:
            await websocket.send_json({
                "type": "history",
                "messages": history
            })

        # Keep the connection open and handle incoming client messages.
        while True:
            try:
                raw = await websocket.receive_text()
                message = json.loads(raw)
                await _handle_client_message(websocket, project_id, message, project_dir)

            except WebSocketDisconnect:
                logger.info(f"WebSocket 客户端主动断开: {project_id}")
                break
            except json.JSONDecodeError:
                await websocket.send_json({
                    "type": "error",
                    "data": {
                        "message": "无效的 JSON 格式"
                    }
                })
            except Exception as e:
                logger.error(f"处理 WebSocket 消息错误: {str(e)}")
                # If this send fails too, the outer handler ends the session.
                await websocket.send_json({
                    "type": "error",
                    "data": {
                        "message": f"处理消息错误: {str(e)}"
                    }
                })

    except Exception as e:
        logger.error(f"WebSocket 连接错误: {str(e)}")

    finally:
        manager.disconnect(websocket)
|
||
|
||
|
||
@router.websocket("/ws/batches/{batch_id}/execute")
async def websocket_batch_execution(
    websocket: WebSocket,
    batch_id: str
):
    """
    Batch execution WebSocket endpoint.

    Registers the socket for *batch_id*, sends a "connected" confirmation and
    then keeps the connection open. Clients cannot control batch execution
    through this socket; the only message acted upon is "ping", which is
    answered with "pong". The socket is always unregistered on exit.
    """
    await manager.connect_to_batch(websocket, batch_id)

    try:
        # Connection confirmation.
        await websocket.send_json({
            "type": "connected",
            "data": {
                "batch_id": batch_id,
                "message": "已连接到批量执行流",
                "timestamp": datetime.now().isoformat()
            }
        })

        # Keep the connection alive.
        while True:
            try:
                raw = await websocket.receive_text()
                # Batch execution only broadcasts; "ping" is handled so that
                # clients can keep the connection active.
                message = json.loads(raw)
                if message.get("type") == "ping":
                    await websocket.send_json({
                        "type": "pong",
                        "data": {"timestamp": datetime.now().isoformat()}
                    })

            except WebSocketDisconnect:
                logger.info(f"WebSocket 客户端主动断开: {batch_id}")
                break
            except RuntimeError as e:
                # Starlette reports a socket that can no longer be used with
                # RuntimeError, raised before anything is awaited. Retrying
                # would spin this loop forever, so end the session instead.
                logger.error(f"处理 WebSocket 消息错误: {str(e)}")
                break
            except Exception as e:
                # Bad payloads (invalid JSON, non-object JSON, ...) are logged
                # and the connection stays open.
                logger.error(f"处理 WebSocket 消息错误: {str(e)}")

    finally:
        manager.disconnect(websocket)
|
||
|
||
|
||
# ============================================
|
||
# 消息处理
|
||
# ============================================
|
||
|
||
async def _handle_client_message(
    websocket: WebSocket,
    project_id: str,
    message: Dict[str, Any],
    project_dir: Path
):
    """
    Dispatch one decoded client message by its "type" field.

    - "ping": replies "pong" with a timestamp on *websocket*.
    - "inbox_action": turns the user's inbox decision into a feedback sentence
      and runs the project's agent on it.
    - "chat_message": stores the user message, lazily loads the project
      context into the agent, runs the agent and stores its text reply.
    - "get_status": currently a no-op.
    - anything else: replies with an "unknown message type" error.

    Agent events are broadcast to every connection of the project; agent
    failures are logged and broadcast as an "error" message.
    """
    message_type = message.get("type")
    logger.info(f"收到消息: type={message_type}, full_message={message}")

    if message_type == "ping":
        # Heartbeat reply.
        await websocket.send_json({
            "type": "pong",
            "data": {
                "timestamp": datetime.now().isoformat()
            }
        })

    elif message_type == "inbox_action":
        # The user approved or rejected an item in the inbox.
        action = message.get("action")
        item_id = message.get("itemId")
        if not action or item_id is None:
            # Without both fields the feedback sentence would be meaningless.
            await websocket.send_json({
                "type": "error",
                "data": {
                    "message": "inbox_action 缺少 action 或 itemId"
                }
            })
            return

        # Express the action as natural-language feedback for the agent.
        feedback = f"User {_to_past_tense(str(action))} inbox item {item_id}."

        agent = manager.get_agent(project_id, project_dir)
        try:
            await _relay_agent_events(agent, project_id, feedback)
        except Exception as e:
            await _report_agent_error(project_id, e)

    elif message_type == "chat_message":
        # A chat message from the user triggers an agent run.
        content = message.get("content", "")
        if not content:
            return

        # Persist the user message.
        await message_repo.add_message(project_id, "user", content)

        agent = manager.get_agent(project_id, project_dir)

        # Load the project context into the agent if it has not been loaded yet.
        if agent.context and not agent.context.project_id:
            await _load_project_context(agent, project_id)

        try:
            full_response = await _relay_agent_events(agent, project_id, content)
            # Persist the agent's reply.
            if full_response:
                await message_repo.add_message(project_id, "agent", full_response)
        except Exception as e:
            await _report_agent_error(project_id, e)

    elif message_type == "get_status":
        # Status request: nothing is returned yet (agent or executor state
        # could be reported here).
        pass

    else:
        await websocket.send_json({
            "type": "error",
            "data": {
                "message": f"未知消息类型: {message_type}"
            }
        })


def _to_past_tense(verb: str) -> str:
    """Return *verb* in simple past tense ("approve" -> "approved", "reject" -> "rejected")."""
    if verb.endswith("ed"):
        return verb
    if verb.endswith("e"):
        return f"{verb}d"
    return f"{verb}ed"


async def _relay_agent_events(agent: DirectorAgent, project_id: str, prompt: str) -> str:
    """Run the agent on *prompt*, broadcast each event, and return the concatenated text events."""
    text_parts = []
    # NOTE(review): stream_events is consumed with a plain `for`, so it looks
    # like a synchronous generator; if it blocks on model calls the event loop
    # is stalled while it runs — confirm and consider offloading.
    for event in agent.stream_events(prompt, thread_id=project_id):
        event_type = event.get("type")
        if event_type == "tool_call":
            # Tool calls are additionally translated into dedicated messages.
            await _handle_tool_call(project_id, event)
        elif event_type == "text":
            chunk = event.get("content")
            # Ignore missing/null content instead of failing the whole stream.
            if isinstance(chunk, str):
                text_parts.append(chunk)

        await manager.send_to_project(project_id, event)
    return "".join(text_parts)


async def _report_agent_error(project_id: str, error: Exception) -> None:
    """Log an agent failure and broadcast it to the project's connections."""
    logger.error(f"Agent execution failed for project {project_id}: {error}")
    await manager.send_to_project(project_id, {
        "type": "error",
        "data": {"message": str(error)}
    })


def _collect_user_skills(project: Any, skill_manager: Any) -> list:
    """Convert the project's defaultTaskSkills into the agent's user_skills entries."""
    user_skills = []
    for task_config in getattr(project, 'defaultTaskSkills', None) or []:
        for skill_config in task_config.skills:
            try:
                # Look up the skill details through the skill manager.
                skill = skill_manager.get_skill_by_id(skill_config.skill_id)
                if skill:
                    user_skills.append({
                        'id': skill.id,
                        'name': skill.name,
                        'behavior': skill.behavior_guide or skill.description or ''
                    })
            except Exception as e:
                # A single broken skill must not prevent loading the others.
                logger.warning(f"Failed to load skill {skill_config.skill_id}: {e}")
    return user_skills


async def _load_project_context(agent: DirectorAgent, project_id: str) -> None:
    """Best-effort: build the project context, attach it to *agent* and rebuild its system prompt.

    Any failure is logged as a warning and the agent keeps its current context.
    """
    try:
        # Imported lazily, as in the original handler.
        from app.db.repositories import project_repo
        from app.core.agent_runtime.context import SkillAgentContext
        from app.core.skills.skill_manager import skill_manager

        project = await project_repo.get(project_id)
        if not project:
            return

        global_context = project.globalContext
        uploaded_script = global_context.uploadedScript if global_context else None
        # An uploaded script takes precedence over the inspiration text.
        if uploaded_script:
            source_content = uploaded_script
        elif global_context:
            source_content = global_context.inspiration
        else:
            source_content = None

        agent.context = SkillAgentContext(
            skill_loader=agent.context.skill_loader,
            working_directory=agent.context.working_directory,
            project_id=project.id,
            project_name=project.name,
            project_genre=getattr(project, 'genre', '古风'),
            total_episodes=project.totalEpisodes,
            world_setting=global_context.worldSetting if global_context else None,
            # NOTE(review): characters is filled from styleGuide — confirm this mapping is intended.
            characters=global_context.styleGuide if global_context else None,
            overall_outline=global_context.overallOutline if global_context else None,
            creation_mode='script' if uploaded_script else 'inspiration',
            source_content=source_content,
            user_skills=_collect_user_skills(project, skill_manager)
        )
        # The system prompt depends on the context, so rebuild it.
        agent.system_prompt = agent._build_system_prompt()
        logger.info(f"Loaded project context for {project_id}: {project.name}")
    except Exception as e:
        logger.warning(f"Failed to load project context for {project_id}: {e}")
|
||
|
||
|
||
async def _handle_tool_call(project_id: str, event: Dict[str, Any]):
    """
    Translate an agent tool-call event into a dedicated WebSocket message.

    Recognised tool names (update_plan, add_inbox_task, add_annotation,
    update_context, write_to_canvas, write_file, update_memory, save_episode)
    are converted into the message shape used for the frontend and broadcast
    to every connection of *project_id*. Other tool names are ignored.
    """
    # "args" may be absent or explicitly null on the event.
    outgoing = _tool_call_to_message(event.get("name"), event.get("args") or {})
    if outgoing is not None:
        await manager.send_to_project(project_id, outgoing)


def _tool_call_to_message(name: Optional[str], args: Dict[str, Any]) -> Optional[Dict[str, Any]]:
    """Build the WebSocket message for tool *name*, or None when nothing should be sent."""
    now = datetime.now()
    epoch_ms = int(now.timestamp() * 1000)

    if name == "update_plan":
        return {
            "type": "plan_update",
            "plan": args.get("steps", []),
            "status": args.get("status", "planning"),
            "current_step_index": args.get("current_step_index", 0)
        }

    if name == "add_inbox_task":
        return {
            "type": "review_request",
            "id": f"task_{args.get('title', 'unknown')}_{int(now.timestamp())}",
            "title": args.get("title"),
            "description": args.get("description"),
            "options": args.get("options", ["Approve", "Reject"]),
            "timestamp": epoch_ms
        }

    if name == "add_annotation":
        return {
            "type": "annotation_add",
            "annotation": {
                "content": args.get("content"),
                "type": args.get("annotation_type", "review"),
                "suggestion": args.get("suggestion", ""),
                "timestamp": epoch_ms
            }
        }

    if name == "update_context":
        # "data" may arrive as a JSON string or as an already-decoded value.
        data = args.get("data")
        context_type = args.get("context_type", "state")
        try:
            if isinstance(data, str):
                data = json.loads(data)

            if isinstance(data, dict):
                # Mapping -> [{type, value}, ...]
                states = [{"type": key, "value": value} for key, value in data.items()]
            elif isinstance(data, list):
                # Already in list form.
                states = data
            else:
                states = [{"type": context_type, "value": str(data)}]
        except Exception as e:
            # Best effort: a malformed context update must not abort the stream.
            logger.warning(f"Failed to process context update: {e}")
            return None
        return {
            "type": "context_update",
            "states": states
        }

    if name in ("write_to_canvas", "write_file"):
        # write_file is simplified: any written content also refreshes the canvas.
        content = args.get("content")
        if content:
            return {
                "type": "canvas_update",
                "content": content
            }
        return None

    if name == "update_memory":
        return {
            "type": "memory_update",
            "data": _memory_update_payload(args, epoch_ms)
        }

    if name == "save_episode":
        return {
            "type": "episode_saved",
            "episode_number": args.get("episode_number", 0),
            "title": args.get("title", "")
        }

    return None


def _memory_update_payload(args: Dict[str, Any], epoch_ms: int) -> Dict[str, Any]:
    """Format an update_memory tool call according to its memory type."""
    memory_type = args.get("memory_type", "timeline")

    # Like update_context, "data" may be a JSON string; it may also be null.
    data = args.get("data") or {}
    if isinstance(data, str):
        try:
            data = json.loads(data)
        except ValueError as e:
            logger.warning(f"Failed to parse memory update data: {e}")
            data = {}
    if not isinstance(data, dict):
        # Fall back to the per-type defaults rather than failing on .get().
        data = {}

    memory_data = {
        "type": memory_type,
        "memory_type": memory_type,
        "timestamp": epoch_ms
    }

    # Type-specific details.
    if memory_type == "timeline":
        memory_data.update({
            "title": data.get("event", "时间线事件"),
            "description": data.get("description", "")
        })
    elif memory_type == "character_state":
        memory_data.update({
            "title": f"{data.get('character', '角色')}状态变化",
            "description": data.get("state", ""),
            "character": data.get("character", "")
        })
    elif memory_type == "pending_thread":
        memory_data.update({
            "title": "待收线问题",
            "description": data.get("description", "")
        })
    elif memory_type == "foreshadowing":
        memory_data.update({
            "title": "伏笔",
            "description": data.get("description", "")
        })

    return memory_data
|
||
|
||
# ============================================
|
||
# 辅助函数 - 用于从其他模块发送消息
|
||
# ============================================
|
||
|
||
async def broadcast_stage_update(
    project_id: str,
    episode_number: int,
    stage: str,
    data: Dict[str, Any]
):
    """Broadcast a "stage_update" message to every connection of the project.

    The payload holds the project id, episode number and stage; entries of
    *data* are merged in last, so they override those keys on a clash.
    """
    payload = {
        "project_id": project_id,
        "episode_number": episode_number,
        "stage": stage,
        **data
    }
    envelope = {
        "type": "stage_update",
        "data": payload,
        "timestamp": datetime.now().isoformat()
    }
    await manager.send_to_project(project_id, envelope)
|
||
|
||
|
||
async def broadcast_episode_complete(
    project_id: str,
    episode_number: int,
    success: bool,
    quality_score: float,
    data: Dict[str, Any]
):
    """Broadcast an "episode_complete" message to every connection of the project.

    The payload holds the project id, episode number, success flag and quality
    score; entries of *data* are merged in last and win on a key clash.
    """
    payload = {
        "project_id": project_id,
        "episode_number": episode_number,
        "success": success,
        "quality_score": quality_score,
        **data
    }
    envelope = {
        "type": "episode_complete",
        "data": payload,
        "timestamp": datetime.now().isoformat()
    }
    await manager.send_to_project(project_id, envelope)
|
||
|
||
|
||
async def broadcast_batch_progress(
    batch_id: str,
    current_episode: int,
    total_episodes: int,
    completed: int,
    failed: int,
    data: Dict[str, Any]
):
    """Broadcast a "batch_progress" message to every connection of the batch.

    The percentage is ``completed / total_episodes * 100`` and is 0 when
    *total_episodes* is not positive. Entries of *data* are merged in last and
    win on a key clash.
    """
    if total_episodes > 0:
        percentage = completed / total_episodes * 100
    else:
        percentage = 0

    payload = {
        "batch_id": batch_id,
        "current_episode": current_episode,
        "total_episodes": total_episodes,
        "completed_episodes": completed,
        "failed_episodes": failed,
        "progress_percentage": percentage,
        **data
    }
    envelope = {
        "type": "batch_progress",
        "data": payload,
        "timestamp": datetime.now().isoformat()
    }
    await manager.send_to_batch(batch_id, envelope)
|
||
|
||
|
||
async def broadcast_error(
    project_id: str,
    episode_number: Optional[int],
    error: str,
    error_type: str = "execution_error"
):
    """Broadcast an "error" message to every connection of the project.

    The payload carries the project id, the (optional) episode number, the
    error text and the error type.
    """
    details = {
        "project_id": project_id,
        "episode_number": episode_number,
        "error": error,
        "error_type": error_type
    }
    envelope = {
        "type": "error",
        "data": details,
        "timestamp": datetime.now().isoformat()
    }
    await manager.send_to_project(project_id, envelope)
|
||
|
||
|
||
async def broadcast_batch_complete(
    batch_id: str,
    summary: Dict[str, Any]
):
    """Broadcast a "batch_complete" message to every connection of the batch.

    The payload is the batch id followed by the entries of *summary*, which
    win on a key clash.
    """
    payload = {
        "batch_id": batch_id,
        **summary
    }
    envelope = {
        "type": "batch_complete",
        "data": payload,
        "timestamp": datetime.now().isoformat()
    }
    await manager.send_to_batch(batch_id, envelope)
|
||
|
||
|
||
# 导出连接管理器和辅助函数
|
||
# Public names of this module: the shared connection manager and the
# broadcast helpers meant to be called from other modules.
__all__ = [
    "manager",
    "broadcast_stage_update",
    "broadcast_episode_complete",
    "broadcast_batch_progress",
    "broadcast_error",
    "broadcast_batch_complete",
]
|