854 lines
28 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
项目管理 API 路由
提供项目的 CRUD 操作和剧集执行功能
"""
from fastapi import APIRouter, Depends, HTTPException, status, BackgroundTasks
from typing import List, Optional
from pydantic import BaseModel, Field
from app.models.project import (
SeriesProject,
SeriesProjectCreate,
Episode,
EpisodeExecuteRequest,
EpisodeExecuteResponse
)
from app.models.review import ReviewConfig
from app.models.skill_config import (
ProjectSkillConfigUpdate,
SkillConfigResponse,
EpisodeSkillConfigUpdate
)
from app.core.agents.series_creation_agent import get_series_agent
from app.core.execution.batch_executor import get_batch_executor
from app.core.execution.retry_manager import get_retry_manager, RetryConfig
from app.db.repositories import project_repo, episode_repo
from app.utils.logger import get_logger
# Logger scoped to this module's import path.
logger = get_logger(__name__)

# Every route in this module is mounted under /projects and tagged "项目管理".
router = APIRouter(
    tags=["项目管理"],
    prefix="/projects",
)
# ============================================
# 项目管理
# ============================================
@router.post("/", response_model=SeriesProject, status_code=status.HTTP_201_CREATED)
async def create_project(project_data: SeriesProjectCreate):
    """Create a new project and seed one pending, empty episode per planned episode.

    Args:
        project_data: Payload describing the project to create.

    Returns:
        The persisted project.

    Raises:
        HTTPException: 500 if the project or any of its episode records could
            not be created.
    """
    import uuid

    try:
        project = await project_repo.create(project_data)

        # Seed pending episode records. Treat a missing total as zero episodes
        # so an unset total cannot raise *after* the project was persisted.
        total = project.totalEpisodes or 0
        for episode_num in range(1, total + 1):
            episode = Episode(
                id=str(uuid.uuid4()),
                projectId=project.id,
                number=episode_num,
                title=f"{episode_num}集内容创作",
                status="pending",
                content=""  # start blank
            )
            await episode_repo.create(episode)
            logger.info(f"自动创建剧集: {episode.id} - EP{episode_num}")
        return project
    except Exception as e:
        # NOTE(review): there is no rollback here; if an episode insert fails
        # part-way, the project and the episodes created so far stay persisted.
        logger.error(f"创建项目失败: {str(e)}", exc_info=True)
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"创建项目失败: {str(e)}"
        ) from e
@router.get("/", response_model=List[SeriesProject])
async def list_projects(skip: int = 0, limit: int = 100):
    """Return a page of projects.

    Args:
        skip: Offset passed through to ``project_repo.list``.
        limit: Page size passed through to ``project_repo.list``.
    """
    return await project_repo.list(skip=skip, limit=limit)
@router.get("/{project_id}", response_model=SeriesProject)
async def get_project(project_id: str):
    """Fetch one project by id; respond 404 when the repository has no match."""
    found = await project_repo.get(project_id)
    if found:
        return found
    raise HTTPException(
        status_code=status.HTTP_404_NOT_FOUND,
        detail=f"项目不存在: {project_id}"
    )
@router.put("/{project_id}", response_model=SeriesProject)
async def update_project(project_id: str, project_data: dict):
    """Update a project with the given fields.

    ``project_data`` is handed to ``project_repo.update`` unchanged; no field
    filtering or validation happens here. Responds 404 when the repository
    returns nothing.
    """
    updated = await project_repo.update(project_id, project_data)
    if updated:
        return updated
    raise HTTPException(
        status_code=status.HTTP_404_NOT_FOUND,
        detail=f"项目不存在: {project_id}"
    )
@router.delete("/{project_id}", status_code=status.HTTP_204_NO_CONTENT)
async def delete_project(project_id: str):
    """Delete a project; respond 404 if the repository reports nothing was deleted."""
    deleted = await project_repo.delete(project_id)
    if deleted:
        return None
    raise HTTPException(
        status_code=status.HTTP_404_NOT_FOUND,
        detail=f"项目不存在: {project_id}"
    )
# ============================================
# 剧集管理
# ============================================
@router.get("/{project_id}/episodes", response_model=List[Episode])
async def list_episodes(project_id: str):
    """List a project's episodes, lazily seeding pending records when none exist.

    Args:
        project_id: Project whose episodes are listed.

    Returns:
        The project's episodes as returned by ``episode_repo.list_by_project``.

    Raises:
        HTTPException: 404 if the project does not exist.
    """
    project = await project_repo.get(project_id)
    if not project:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail=f"项目不存在: {project_id}"
        )

    episodes = await episode_repo.list_by_project(project_id)
    if episodes or not project.totalEpisodes:
        return episodes

    import uuid

    logger.info(f"项目 {project_id} 暂无剧集记录,自动初始化 {project.totalEpisodes} 集...")
    # `episodes` is known to be empty here, so checking it for duplicates would
    # never match. Re-read right before writing so that episodes created by a
    # concurrent request in the meantime are skipped. This narrows the race
    # window but does not eliminate it.
    existing_numbers = {
        ep.number for ep in await episode_repo.list_by_project(project_id)
    }
    for episode_num in range(1, project.totalEpisodes + 1):
        if episode_num in existing_numbers:
            continue
        episode = Episode(
            id=str(uuid.uuid4()),
            projectId=project_id,
            number=episode_num,
            title=f"{episode_num}集内容创作",
            status="pending",
            content=""  # kept blank so the frontend's outline-generation button is not triggered
        )
        await episode_repo.create(episode)

    return await episode_repo.list_by_project(project_id)
@router.get("/{project_id}/episodes/{episode_number}", response_model=Episode)
async def get_episode(project_id: str, episode_number: int):
    """Return the first episode of the project with the given number, or respond 404.

    The project's existence is not verified separately here.
    """
    all_episodes = await episode_repo.list_by_project(project_id)
    found = next((item for item in all_episodes if item.number == episode_number), None)
    if found is None:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail=f"剧集不存在: EP{episode_number}"
        )
    return found
@router.put("/{project_id}/episodes/{episode_number}", response_model=Episode)
async def update_episode(project_id: str, episode_number: int, update_data: dict):
    """Update whitelisted fields of one episode and persist it.

    Only ``title``, ``content``, ``outline``, ``summary`` and ``status`` are read
    from ``update_data``; any other key is ignored. When ``status`` is supplied
    and the episode ends up ``"completed"`` with no ``completedAt``, the
    completion time is stamped with the current local (naive) time.

    Raises:
        HTTPException: 404 if the project has no episode with that number.
    """
    candidates = await episode_repo.list_by_project(project_id)
    target = None
    for candidate in candidates:
        if candidate.number == episode_number:
            target = candidate
            break
    if not target:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail=f"剧集不存在: EP{episode_number}"
        )

    # Copy plain content fields over in a fixed order.
    for field_name in ("title", "content", "outline", "summary"):
        if field_name in update_data:
            setattr(target, field_name, update_data[field_name])

    if "status" in update_data:
        target.status = update_data["status"]
        # Stamp the completion time the first time the episode is completed.
        if target.status == "completed" and not target.completedAt:
            from datetime import datetime
            target.completedAt = datetime.now()

    await episode_repo.update(target)
    logger.info(f"更新剧集: {target.id} - EP{episode_number}")
    return target
# ============================================
# 剧集执行(核心功能)
# ============================================
@router.post("/{project_id}/execute", response_model=EpisodeExecuteResponse)
async def execute_episode(
    project_id: str,
    request: EpisodeExecuteRequest,
    background_tasks: BackgroundTasks
):
    """Schedule creation of a single episode in the background (deprecated).

    Prefer connecting to the WebSocket endpoint
    ``/ws/projects/{project_id}/execute`` and using the DirectorAgent there.
    This handler returns immediately with a placeholder episode in
    ``"writing"`` status; the actual work runs via ``BackgroundTasks``.

    Raises:
        HTTPException: 404 if the project does not exist.
    """
    if not await project_repo.get(project_id):
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail=f"项目不存在: {project_id}"
        )

    ep_no = request.episodeNumber
    logger.info(f"接收创作请求(后台执行模式): 项目 {project_id}, EP{ep_no}")

    # Hand the work off; the response below does not wait for it.
    background_tasks.add_task(
        _execute_episode_in_background, project_id, ep_no, request.title
    )

    placeholder = Episode(
        projectId=project_id,
        number=ep_no,
        status="writing",
        title=request.title or f"{ep_no}"
    )
    return EpisodeExecuteResponse(
        episode=placeholder,
        success=True,
        message=f"EP{ep_no} 已开始在后台创作"
    )
async def _execute_episode_in_background(
    project_id: str,
    episode_number: int,
    title: str
):
    """Create one episode via the WebSocket module's creation routine.

    Meant to run as a background task. Any exception is logged with its
    traceback and then swallowed, so nothing propagates out of the task.
    ``title`` is currently unused.
    """
    try:
        # Imported lazily inside the task rather than at module import time.
        from app.api.v1.websocket import _execute_episode_creation

        logger.info(f"后台创作任务开始: EP{episode_number}")
        await _execute_episode_creation(
            project_id=project_id,
            episode_number=episode_number,
            analyze_previous_memory=True
        )
        logger.info(f"后台创作任务完成: EP{episode_number}")
    except Exception as exc:
        logger.error(f"后台创作任务失败: {str(exc)}", exc_info=True)
# ============================================
# 请求/响应模型
# ============================================
class BatchExecuteRequest(BaseModel):
    """Request body for ``POST /{project_id}/execute-batch``.

    Selects an episode range (both ends >= 1) and review/retry options.
    """

    # First episode of the range; required.
    start_episode: int = Field(default=..., ge=1, description="起始集数")
    # Last episode of the range; required.
    end_episode: int = Field(default=..., ge=1, description="结束集数")
    # Whether to run the quality review on each episode.
    enable_review: bool = Field(default=True, description="是否启用质量检查")
    # Whether to automatically retry failed episodes.
    enable_retry: bool = Field(default=True, description="是否启用自动重试")
    # Retry budget per episode, 0..5.
    max_retries: int = Field(default=2, ge=0, le=5, description="最大重试次数")
    # Review pass threshold, 0..100.
    quality_threshold: float = Field(default=75.0, ge=0, le=100, description="质量阈值")
class AutoExecuteRequest(BaseModel):
    """Request body for ``POST /{project_id}/execute-auto``.

    Starts at ``start_episode`` and runs up to ``episode_count`` episodes, with
    the same review/retry options as a batch plus an optional stop-on-failure.
    """

    # First episode to run (>= 1, defaults to 1).
    start_episode: int = Field(default=1, ge=1, description="起始集数")
    # Optional number of episodes to run (>= 1).
    episode_count: Optional[int] = Field(default=None, ge=1, description="执行集数(不指定则执行到项目总集数)")
    # Whether to run the quality review on each episode.
    enable_review: bool = Field(default=True, description="是否启用质量检查")
    # Whether to automatically retry failed episodes.
    enable_retry: bool = Field(default=True, description="是否启用自动重试")
    # Retry budget per episode, 0..5.
    max_retries: int = Field(default=2, ge=0, le=5, description="最大重试次数")
    # Review pass threshold, 0..100.
    quality_threshold: float = Field(default=75.0, ge=0, le=100, description="质量阈值")
    # Whether an episode error should stop the whole run.
    stop_on_failure: bool = Field(default=False, description="失败时是否停止")
class StopExecutionRequest(BaseModel):
    """Request body for ``POST /{project_id}/stop-execution``."""

    # Id of the batch to stop; required.
    batch_id: str = Field(default=..., description="批次ID")
# ============================================
# 增强的批量执行
# ============================================
@router.post("/{project_id}/execute-batch", response_model=dict)
async def execute_batch_enhanced(
    project_id: str,
    request: BatchExecuteRequest,
    background_tasks: BackgroundTasks
):
    """Start an enhanced batch run over episodes ``start_episode..end_episode``.

    The run integrates:
    - quality review,
    - automatic retries of failed episodes,
    - WebSocket progress broadcasts,
    - an execution summary.

    The batch is scheduled as a background task, and this handler returns
    immediately.

    Args:
        project_id: Project to run the batch for.
        request: Episode range plus review/retry options.
        background_tasks: FastAPI background-task queue that runs the batch.

    Returns:
        A dict acknowledging that the batch was started.

    Raises:
        HTTPException: 404 if the project does not exist; 400 if the range is
            reversed or ends past the project's total episode count.
    """
    import uuid

    project = await project_repo.get(project_id)
    if not project:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail=f"项目不存在: {project_id}"
        )
    if request.end_episode < request.start_episode:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="结束集数必须大于或等于起始集数"
        )
    # Reject ranges that run past the planned episode count; execute-auto
    # clamps to the same bound. Skipped when the project has no total set.
    if project.totalEpisodes and request.end_episode > project.totalEpisodes:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=f"结束集数不能超过项目总集数: {project.totalEpisodes}"
        )

    batch_executor = get_batch_executor()

    # Forward executor progress events to WebSocket subscribers.
    async def on_progress(progress_data):
        from app.api.v1.websocket import broadcast_batch_progress
        await broadcast_batch_progress(
            batch_id=progress_data["batch_id"],
            current_episode=progress_data["current_episode"],
            total_episodes=progress_data["total_episodes"],
            completed=progress_data["completed_episodes"],
            failed=progress_data["failed_episodes"],
            data=progress_data.get("current_episode_result", {})
        )

    # Broadcast each finished episode.
    async def on_episode_complete(episode_result):
        from app.api.v1.websocket import broadcast_episode_complete
        await broadcast_episode_complete(
            project_id=project_id,
            episode_number=episode_result["episode_number"],
            success=episode_result["success"],
            quality_score=episode_result.get("quality_score", 0),
            data=episode_result
        )

    # Broadcast per-episode errors reported by the executor.
    async def on_error(error_data):
        from app.api.v1.websocket import broadcast_error
        await broadcast_error(
            project_id=project_id,
            episode_number=error_data.get("episode_number"),
            error=error_data.get("error", "未知错误"),
            error_type="batch_execution_error"
        )

    # Review config: consistency checker plus three equally weighted dimensions.
    review_config = None
    if request.enable_review:
        from app.models.review import ReviewConfig, DimensionConfig, DimensionType
        review_config = ReviewConfig(
            enabled_review_skills=["consistency_checker"],
            overall_strictness=0.7,
            pass_threshold=request.quality_threshold
        )
        for dim_type in [DimensionType.consistency, DimensionType.quality, DimensionType.dialogue]:
            review_config.dimension_settings[dim_type] = DimensionConfig(
                enabled=True,
                strictness=0.7,
                weight=1.0
            )

    async def run_batch():
        try:
            summary = await batch_executor.execute_batch(
                project=project,
                start_episode=request.start_episode,
                end_episode=request.end_episode,
                review_config=review_config,
                enable_retry=request.enable_retry,
                max_retries=request.max_retries,
                on_progress=on_progress,
                on_episode_complete=on_episode_complete,
                on_error=on_error
            )
            from app.api.v1.websocket import broadcast_batch_complete
            await broadcast_batch_complete(summary["batch_id"], summary)
        except Exception as e:
            # This runs outside the request, so the log is the only place the
            # traceback is kept.
            logger.error(f"批量执行后台任务失败: {str(e)}", exc_info=True)
            from app.api.v1.websocket import broadcast_error
            await broadcast_error(
                project_id=project_id,
                episode_number=None,
                error=f"批量执行失败: {str(e)}",
                error_type="batch_error"
            )

    background_tasks.add_task(run_batch)

    # NOTE(review): this id is generated here, independently of the id the
    # executor assigns (the real one only shows up as summary["batch_id"] and in
    # progress events). It therefore presumably cannot be used with
    # execution-status / stop-execution. Consider passing it into the executor
    # if execute_batch supports that.
    batch_id = str(uuid.uuid4())
    return {
        "batch_id": batch_id,
        "project_id": project_id,
        "start_episode": request.start_episode,
        "end_episode": request.end_episode,
        "status": "started",
        "message": "批量执行已启动,请通过 WebSocket 或状态端点追踪进度",
        "websocket_url": f"/ws/batches/{batch_id}/execute"
    }
@router.post("/{project_id}/execute-auto", response_model=dict)
async def execute_auto(
    project_id: str,
    request: AutoExecuteRequest,
    background_tasks: BackgroundTasks
):
    """Automatic execution mode (the recommended way to run episodes).

    Runs episodes from ``start_episode``, with:
    - quality review and retries,
    - an optional stop on the first failure,
    - live progress notifications,
    - a final summary.

    When ``episode_count`` is omitted, the run goes through the project's last
    episode, as the request field's description states. The end is always
    clamped to ``project.totalEpisodes``.

    Raises:
        HTTPException: 404 if the project does not exist; 400 if
            ``start_episode`` is past the project's last episode.
    """
    import uuid

    project = await project_repo.get(project_id)
    if not project:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail=f"项目不存在: {project_id}"
        )

    # Determine the inclusive end episode.
    if request.episode_count is None:
        end_episode = project.totalEpisodes
    else:
        end_episode = min(
            request.start_episode + request.episode_count - 1,
            project.totalEpisodes
        )
    # Otherwise a start past the end would launch a batch with a
    # non-positive episode count.
    if end_episode < request.start_episode:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=f"起始集数超出项目总集数: {project.totalEpisodes}"
        )

    batch_executor = get_batch_executor()

    # Forward executor progress events to WebSocket subscribers.
    async def on_progress(progress_data):
        from app.api.v1.websocket import broadcast_batch_progress
        await broadcast_batch_progress(
            batch_id=progress_data["batch_id"],
            current_episode=progress_data["current_episode"],
            total_episodes=progress_data["total_episodes"],
            completed=progress_data["completed_episodes"],
            failed=progress_data["failed_episodes"],
            data=progress_data.get("current_episode_result", {})
        )

    # Broadcast each finished episode.
    async def on_episode_complete(episode_result):
        from app.api.v1.websocket import broadcast_episode_complete
        await broadcast_episode_complete(
            project_id=project_id,
            episode_number=episode_result["episode_number"],
            success=episode_result["success"],
            quality_score=episode_result.get("quality_score", 0),
            data=episode_result
        )

    # Broadcast errors and, if requested, ask the executor to stop the batch.
    async def on_error(error_data):
        from app.api.v1.websocket import broadcast_error
        await broadcast_error(
            project_id=project_id,
            episode_number=error_data.get("episode_number"),
            error=error_data.get("error", "未知错误"),
            error_type="auto_execution_error"
        )
        if request.stop_on_failure:
            failing_batch_id = error_data.get("batch_id")
            if failing_batch_id:
                await batch_executor.stop_batch(failing_batch_id)

    # Review config: consistency checker plus three equally weighted dimensions.
    review_config = None
    if request.enable_review:
        from app.models.review import ReviewConfig, DimensionConfig, DimensionType
        review_config = ReviewConfig(
            enabled_review_skills=["consistency_checker"],
            overall_strictness=0.7,
            pass_threshold=request.quality_threshold
        )
        for dim_type in [DimensionType.consistency, DimensionType.quality, DimensionType.dialogue]:
            review_config.dimension_settings[dim_type] = DimensionConfig(
                enabled=True,
                strictness=0.7,
                weight=1.0
            )

    async def run_auto():
        try:
            summary = await batch_executor.execute_batch(
                project=project,
                start_episode=request.start_episode,
                end_episode=end_episode,
                review_config=review_config,
                enable_retry=request.enable_retry,
                max_retries=request.max_retries,
                on_progress=on_progress,
                on_episode_complete=on_episode_complete,
                on_error=on_error
            )
            from app.api.v1.websocket import broadcast_batch_complete
            await broadcast_batch_complete(summary["batch_id"], summary)
        except Exception as e:
            # Background task: keep the traceback in the log.
            logger.error(f"自动执行失败: {str(e)}", exc_info=True)
            from app.api.v1.websocket import broadcast_error
            await broadcast_error(
                project_id=project_id,
                episode_number=None,
                error=f"自动执行失败: {str(e)}",
                error_type="auto_error"
            )

    background_tasks.add_task(run_auto)

    # NOTE(review): this id is generated independently of the executor's own
    # batch id (reported as summary["batch_id"]), so it presumably cannot be
    # used with execution-status / stop-execution.
    batch_id = str(uuid.uuid4())
    return {
        "batch_id": batch_id,
        "project_id": project_id,
        "start_episode": request.start_episode,
        "end_episode": end_episode,
        "total_episodes": end_episode - request.start_episode + 1,
        "status": "started",
        "stop_on_failure": request.stop_on_failure,
        "message": "自动执行已启动",
        "websocket_url": f"/ws/batches/{batch_id}/execute"
    }
@router.get("/{project_id}/execution-status")
async def get_execution_status(
    project_id: str,
    batch_id: Optional[str] = None
):
    """Report the execution status of a batch belonging to this project.

    Args:
        project_id: Project id; must exist.
        batch_id: Batch to look up. When omitted, a placeholder response with an
            empty ``active_batches`` list is returned, because no
            project-to-batch mapping is maintained yet.

    Returns:
        Whatever ``get_batch_status`` returns for the batch, or the placeholder
        dict described above.

    Raises:
        HTTPException: 404 if the project or the requested batch does not exist.
    """
    project = await project_repo.get(project_id)
    if not project:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail=f"项目不存在: {project_id}"
        )

    batch_executor = get_batch_executor()

    if not batch_id:
        # Simplified: listing active batches per project is not implemented.
        return {
            "project_id": project_id,
            "active_batches": [],
            "message": "请提供具体的批次ID"
        }

    # Do not name this `status`. Any assignment to that name inside the function
    # makes it local for the whole body and shadows fastapi's `status` module,
    # which breaks every `status.HTTP_*` reference here.
    batch_status = batch_executor.get_batch_status(batch_id)
    if not batch_status:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail=f"批次不存在: {batch_id}"
        )
    return batch_status
@router.post("/{project_id}/stop-execution")
async def stop_execution(project_id: str, request: StopExecutionRequest):
    """Send a stop signal to a running batch.

    Args:
        project_id: Project id; must exist. The batch is not checked against
            this project.
        request: Carries the ``batch_id`` to stop.

    Returns:
        An acknowledgement dict whose ``status`` is ``"stopping"``.

    Raises:
        HTTPException: 404 if the project does not exist, or if ``stop_batch``
            returns a falsy result (unknown or already finished batch).
    """
    if not await project_repo.get(project_id):
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail=f"项目不存在: {project_id}"
        )

    stopped = await get_batch_executor().stop_batch(request.batch_id)
    if not stopped:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail=f"批次不存在或已完成: {request.batch_id}"
        )
    return {
        "batch_id": request.batch_id,
        "project_id": project_id,
        "status": "stopping",
        "message": "已发送停止信号"
    }
# ============================================
# 旧的批量执行端点(保留兼容性)
# ============================================
@router.post("/{project_id}/execute-batch-legacy")
async def execute_batch_legacy(
    project_id: str,
    start_episode: int = 1,
    end_episode: int = 3
):
    """Legacy synchronous batch endpoint, kept for backward compatibility.

    Creates episodes ``start_episode..end_episode`` (inclusive) one at a time
    with the series agent and persists each result. A failure on one episode is
    logged and recorded in ``results``, and the loop moves on to the next one.

    Returns:
        ``{"projectId", "results", "total", "success"}``, where ``success`` is
        the number of episodes that succeeded.

    Raises:
        HTTPException: 404 if the project does not exist.
    """
    project = await project_repo.get(project_id)
    if not project:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail=f"项目不存在: {project_id}"
        )

    agent = get_series_agent()

    # Pending placeholder episodes are pre-created (create_project /
    # list_episodes). Overwrite the existing record for a number instead of
    # inserting a duplicate that number-based lookups (first match wins) would
    # not return. Keep the first record per number to match get_episode.
    existing_by_number = {}
    for existing in await episode_repo.list_by_project(project_id):
        existing_by_number.setdefault(existing.number, existing)

    results = []
    for ep_num in range(start_episode, end_episode + 1):
        try:
            episode = await agent.execute_episode(project, ep_num)
            episode.projectId = project_id
            placeholder = existing_by_number.get(ep_num)
            if placeholder is not None:
                # NOTE(review): assumes episode_repo.update keys on episode.id.
                episode.id = placeholder.id
                await episode_repo.update(episode)
            else:
                await episode_repo.create(episode)
            results.append({
                "episode": ep_num,
                "success": True,
                "qualityScore": episode.qualityScore
            })
        except Exception as e:
            # Best-effort per episode: record the failure and keep going, but
            # do not lose the traceback.
            logger.error(f"批量创作失败: EP{ep_num} - {str(e)}", exc_info=True)
            results.append({
                "episode": ep_num,
                "success": False,
                "error": str(e)
            })

    return {
        "projectId": project_id,
        "results": results,
        "total": len(results),
        "success": sum(1 for r in results if r["success"])
    }
# ============================================
# 记忆系统
# ============================================
@router.get("/{project_id}/memory")
async def get_project_memory(project_id: str):
    """Return the project's ``memory`` attribute, or respond 404 if the project is missing."""
    owner = await project_repo.get(project_id)
    if owner:
        return owner.memory
    raise HTTPException(
        status_code=status.HTTP_404_NOT_FOUND,
        detail=f"项目不存在: {project_id}"
    )
# ============================================
# Skills 配置管理
# ============================================
@router.get("/{project_id}/skill-config", response_model=SkillConfigResponse)
async def get_project_skill_config(project_id: str):
    """Return the project's default task skills and per-episode skill overrides.

    Raises:
        HTTPException: 404 if the project does not exist.
    """
    proj = await project_repo.get(project_id)
    if not proj:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail=f"项目不存在: {project_id}"
        )
    payload = dict(
        defaultTaskSkills=proj.defaultTaskSkills,
        episodeSkillOverrides=proj.episodeSkillOverrides,
    )
    return payload
@router.put("/{project_id}/skill-config")
async def update_project_skill_config(
    project_id: str,
    config: ProjectSkillConfigUpdate
):
    """Replace the project's default task skills and per-episode overrides.

    The config models are converted with ``.dict()`` before being persisted.

    Raises:
        HTTPException: 404 if the project does not exist; 500 if conversion or
            persistence fails.
    """
    if not await project_repo.get(project_id):
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail=f"项目不存在: {project_id}"
        )

    try:
        serialized_defaults = [task.dict() for task in config.defaultTaskSkills]
        serialized_overrides = {
            key: value.dict()
            for key, value in config.episodeSkillOverrides.items()
        }
        await project_repo.update(project_id, {
            "defaultTaskSkills": serialized_defaults,
            "episodeSkillOverrides": serialized_overrides,
        })
        return {
            "success": True,
            "message": "Skills 配置已更新",
            "defaultTaskSkills": config.defaultTaskSkills,
            "episodeSkillOverrides": config.episodeSkillOverrides
        }
    except Exception as e:
        logger.error(f"更新 Skills 配置失败: {str(e)}")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"更新失败: {str(e)}"
        )
@router.put("/{project_id}/episodes/{episode_number}/skill-config")
async def update_episode_skill_config(
    project_id: str,
    episode_number: int,
    config: EpisodeSkillConfigUpdate
):
    """Create or replace the skill override for a single episode.

    The override is stored under the path's ``episode_number`` key in a copy of
    ``project.episodeSkillOverrides``, and the copy is written back.

    Raises:
        HTTPException: 404 if the project does not exist; 500 if persistence fails.
    """
    project = await project_repo.get(project_id)
    if not project:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail=f"项目不存在: {project_id}"
        )

    try:
        overrides = dict(project.episodeSkillOverrides)
        entry = {
            "episode_number": config.episode_number,
            "task_configs": [task_cfg.dict() for task_cfg in config.task_configs],
            "use_project_default": config.use_project_default
        }
        overrides[episode_number] = entry
        await project_repo.update(project_id, {"episodeSkillOverrides": overrides})
        return {
            "success": True,
            "message": f"EP{episode_number} 的 Skills 配置已更新",
            "config": config
        }
    except Exception as e:
        logger.error(f"更新单集 Skills 配置失败: {str(e)}")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"更新失败: {str(e)}"
        )
@router.delete("/{project_id}/episodes/{episode_number}/skill-config")
async def delete_episode_skill_config(
    project_id: str,
    episode_number: int
):
    """Remove a single episode's skill override so it falls back to the project default.

    Raises:
        HTTPException: 404 if the project does not exist or the episode has no
            override; 500 if persistence fails.
    """
    project = await project_repo.get(project_id)
    if not project:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail=f"项目不存在: {project_id}"
        )

    try:
        remaining = dict(project.episodeSkillOverrides)
        if episode_number not in remaining:
            raise HTTPException(
                status_code=status.HTTP_404_NOT_FOUND,
                detail=f"EP{episode_number} 没有自定义配置"
            )
        remaining.pop(episode_number)
        await project_repo.update(project_id, {"episodeSkillOverrides": remaining})
        return {
            "success": True,
            "message": f"EP{episode_number} 的配置已删除,将使用项目默认配置"
        }
    except HTTPException:
        # Let the deliberate 404 above pass through unchanged.
        raise
    except Exception as e:
        logger.error(f"删除单集 Skills 配置失败: {str(e)}")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"删除失败: {str(e)}"
        )