117 lines
3.3 KiB
Python
117 lines
3.3 KiB
Python
import uuid
from datetime import datetime
from typing import Dict, List, Optional

from app.models.project import (
    SeriesProject,
    SeriesProjectCreate,
    Episode,
    EpisodeExecuteRequest,
    EpisodeExecuteResponse,
)
from app.core.agents.series_creation_agent import get_series_agent
from app.utils.logger import get_logger


logger = get_logger(__name__)


# ============================================
# In-memory storage (MVP stage)
#
# Nothing in this module writes these dicts to disk: their contents live
# only for the lifetime of the process.
# NOTE(review): file-backed persistence may be intended later — confirm.
# ============================================
# project id (str(uuid4)) -> SeriesProject
_projects: Dict[str, SeriesProject] = {}
# episode id -> Episode (ids assigned by EpisodeRepository.create are str(uuid4))
_episodes: Dict[str, Episode] = {}


class ProjectRepository:
    """Project repository (simplified MVP version).

    Stores ``SeriesProject`` objects in the module-level ``_projects`` dict,
    keyed by project id. The methods are declared ``async`` but perform no
    awaits themselves.
    """

    # Fields that ``update`` must never overwrite. ``id`` doubles as the key
    # under which the project is stored in ``_projects``: changing it would
    # desynchronize the object from its key. ``createdAt`` is fixed at
    # creation time.
    _IMMUTABLE_FIELDS = frozenset({"id", "createdAt"})

    async def create(self, project_data: SeriesProjectCreate) -> SeriesProject:
        """Create a new project with a fresh UUID4 id and store it.

        Args:
            project_data: Creation payload whose fields (name, totalEpisodes,
                agentId, mode, globalContext, skillSettings) are copied onto
                the new project.

        Returns:
            The newly created and stored project.
        """
        project_id = str(uuid.uuid4())
        # Take a single timestamp so that createdAt == updatedAt for a freshly
        # created project; two separate now() calls can differ.
        now = datetime.now()
        project = SeriesProject(
            id=project_id,
            name=project_data.name,
            totalEpisodes=project_data.totalEpisodes,
            agentId=project_data.agentId,
            mode=project_data.mode,
            globalContext=project_data.globalContext,
            skillSettings=project_data.skillSettings,
            createdAt=now,
            updatedAt=now,
        )
        _projects[project_id] = project
        logger.info(f"创建项目: {project_id} - {project.name}")
        return project

    async def get(self, project_id: str) -> Optional[SeriesProject]:
        """Return the project with ``project_id``, or ``None`` if unknown."""
        return _projects.get(project_id)

    async def list(self, skip: int = 0, limit: int = 100) -> List[SeriesProject]:
        """Return up to ``limit`` projects after skipping the first ``skip``.

        Projects come back in insertion (creation) order, since ``_projects``
        is a dict that is only ever inserted into by ``create``.
        """
        # This method's name shadows the builtin ``list`` only in the class
        # namespace; inside the function body ``list`` is still the builtin.
        return list(_projects.values())[skip:skip + limit]

    async def update(
        self,
        project_id: str,
        project_data: dict
    ) -> Optional[SeriesProject]:
        """Apply a partial update to an existing project.

        Each key in ``project_data`` that names an existing attribute of the
        project is assigned its value. Unknown keys and the immutable fields
        ``id`` / ``createdAt`` are ignored. ``updatedAt`` is always refreshed.

        Args:
            project_id: Id of the project to update.
            project_data: Mapping of attribute name -> new value.

        Returns:
            The updated project, or ``None`` if ``project_id`` is unknown.
        """
        project = _projects.get(project_id)
        if project is None:
            return None

        for key, value in project_data.items():
            if key in self._IMMUTABLE_FIELDS:
                continue
            # NOTE(review): hasattr() is also true for methods, not just data
            # fields; restricting to declared model fields may be safer.
            if hasattr(project, key):
                setattr(project, key, value)

        project.updatedAt = datetime.now()
        return project

    async def delete(self, project_id: str) -> bool:
        """Delete a project by id.

        Returns:
            ``True`` if the project existed and was removed, ``False`` otherwise.

        NOTE(review): episodes belonging to the project are not removed from
        ``_episodes`` here; confirm whether callers clean them up.
        """
        return _projects.pop(project_id, None) is not None


class EpisodeRepository:
    """Episode repository (simplified MVP version).

    Stores ``Episode`` objects in the module-level ``_episodes`` dict, keyed
    by ``episode.id``.
    """

    async def create(self, episode: Episode) -> Episode:
        """Store ``episode``, assigning it a UUID4 id if its id is falsy.

        The passed-in object is mutated when an id is assigned, and any
        existing entry under the same id is replaced.

        Returns:
            The same ``episode`` object that was stored.
        """
        if not episode.id:
            generated_id = str(uuid.uuid4())
            episode.id = generated_id
        _episodes[episode.id] = episode
        logger.info(f"创建剧集: {episode.id} - EP{episode.number}")
        return episode

    async def get(self, episode_id: str) -> Optional[Episode]:
        """Look up an episode by id; ``None`` when no such episode exists."""
        found = _episodes.get(episode_id)
        return found

    async def list_by_project(
        self,
        project_id: str,
        skip: int = 0,
        limit: int = 100
    ) -> List[Episode]:
        """Return a page of the episodes whose ``projectId`` matches.

        Matching episodes keep the store's iteration order; the window is
        ``[skip, skip + limit)`` of that filtered sequence.
        """
        matching: List[Episode] = []
        for candidate in _episodes.values():
            if candidate.projectId == project_id:
                matching.append(candidate)
        window_end = skip + limit
        return matching[skip:window_end]

    async def update(self, episode: Episode) -> Episode:
        """Write ``episode`` back to the store under its own id (upsert).

        NOTE(review): no existence check is made, so this also inserts
        episodes that were never created.
        """
        storage_key = episode.id
        _episodes[storage_key] = episode
        return episode


# ============================================
# Global repository instances
#
# Module-level objects: every importer of this module shares the same two
# repositories (and thus the same in-memory stores).
# ============================================
episode_repo: EpisodeRepository = EpisodeRepository()
project_repo: ProjectRepository = ProjectRepository()