creative_studio/backend/app/db/repositories.py
2026-01-25 19:27:44 +08:00

117 lines
3.3 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

from datetime import datetime
from typing import List, Optional
import uuid
from app.models.project import (
SeriesProject,
SeriesProjectCreate,
Episode,
EpisodeExecuteRequest,
EpisodeExecuteResponse
)
from app.core.agents.series_creation_agent import get_series_agent
from app.utils.logger import get_logger
logger = get_logger(__name__)

# ============================================
# In-memory storage for the MVP stage.
# NOTE(review): an earlier comment here claimed the MVP uses *file* storage,
# but these are plain module-level dicts and nothing here persists them, so
# all data is lost when the process exits and is not shared between
# processes/workers.
# ============================================
# Project id (str) -> SeriesProject; written by ProjectRepository.
_projects: dict = {}
# Episode id -> Episode; written by EpisodeRepository.
_episodes: dict = {}
class ProjectRepository:
    """Project repository (simplified MVP version).

    Stores ``SeriesProject`` instances in the module-level ``_projects`` dict,
    keyed by project id. This class itself performs no persistence.
    """

    # Fields that identify a project or record when it was created. ``update``
    # never overwrites them: changing ``id`` would make the stored object
    # disagree with its key in ``_projects``.
    _IMMUTABLE_FIELDS = frozenset({"id", "createdAt"})

    async def create(self, project_data: SeriesProjectCreate) -> SeriesProject:
        """Create a new project, store it, and return it.

        Args:
            project_data: Creation payload; its ``name``, ``totalEpisodes``,
                ``agentId``, ``mode``, ``globalContext`` and ``skillSettings``
                are copied onto the new project.

        Returns:
            The stored project, with a freshly generated UUID4 string ``id``.
        """
        project_id = str(uuid.uuid4())
        # Take a single timestamp so a brand-new project has
        # createdAt == updatedAt exactly (two now() calls can differ).
        now = datetime.now()
        project = SeriesProject(
            id=project_id,
            name=project_data.name,
            totalEpisodes=project_data.totalEpisodes,
            agentId=project_data.agentId,
            mode=project_data.mode,
            globalContext=project_data.globalContext,
            skillSettings=project_data.skillSettings,
            createdAt=now,
            updatedAt=now,
        )
        _projects[project_id] = project
        logger.info(f"创建项目: {project_id} - {project.name}")
        return project

    async def get(self, project_id: str) -> Optional[SeriesProject]:
        """Return the project with ``project_id``, or None if it is not stored."""
        return _projects.get(project_id)

    async def list(self, skip: int = 0, limit: int = 100) -> List[SeriesProject]:
        """Return stored projects in creation (insertion) order, paginated.

        Args:
            skip: Number of projects to skip from the start.
            limit: Maximum number of projects to return.
        """
        # ``list`` below is the builtin: method bodies do not see class-level
        # names, so this method does not shadow it here.
        return list(_projects.values())[skip:skip + limit]

    async def update(
        self,
        project_id: str,
        project_data: dict
    ) -> Optional[SeriesProject]:
        """Apply a partial update to a stored project, in place.

        Args:
            project_id: Id of the project to update.
            project_data: Mapping of attribute name -> new value. Keys that are
                not attributes of the project are ignored, as are the protected
                keys ``id`` and ``createdAt``. ``updatedAt`` is always reset to
                the current time afterwards.

        Returns:
            The updated project, or None if no project has ``project_id``.
        """
        project = _projects.get(project_id)
        if project is None:
            return None
        for key, value in project_data.items():
            # NOTE(review): if SeriesProject is a pydantic model, hasattr() is
            # also True for methods such as ``dict``/``json`` and setattr() on
            # those names would raise -- confirm whether keys should be
            # restricted to declared fields.
            if key in self._IMMUTABLE_FIELDS or not hasattr(project, key):
                continue
            setattr(project, key, value)
        project.updatedAt = datetime.now()
        return project

    async def delete(self, project_id: str) -> bool:
        """Remove a project.

        Returns:
            True if the project existed and was removed, False otherwise.

        Note: only ``_projects`` is touched; episodes whose ``projectId``
        matches remain in ``_episodes``.
        """
        if project_id not in _projects:
            return False
        del _projects[project_id]
        return True
class EpisodeRepository:
    """Episode repository (simplified MVP version).

    Keeps ``Episode`` objects in the module-level ``_episodes`` dict, keyed by
    episode id.
    """

    async def create(self, episode: Episode) -> Episode:
        """Store ``episode`` and return it.

        When the episode has a falsy ``id``, a UUID4 string is assigned to it
        first (the passed-in object is mutated). An existing entry with the
        same id is replaced.
        """
        needs_id = not episode.id
        if needs_id:
            episode.id = str(uuid.uuid4())
        key = episode.id
        _episodes[key] = episode
        logger.info(f"创建剧集: {key} - EP{episode.number}")
        return episode

    async def get(self, episode_id: str) -> Optional[Episode]:
        """Look up a single episode by id; None when it is not stored."""
        found = _episodes.get(episode_id)
        return found

    async def list_by_project(
        self,
        project_id: str,
        skip: int = 0,
        limit: int = 100
    ) -> List[Episode]:
        """Return the episodes belonging to ``project_id``, paginated.

        Episodes are returned in insertion order of ``_episodes``; ``skip`` and
        ``limit`` are applied after filtering by project.
        """
        belongs = filter(
            lambda candidate: candidate.projectId == project_id,
            _episodes.values(),
        )
        matching = list(belongs)
        end = skip + limit
        return matching[skip:end]

    async def update(self, episode: Episode) -> Episode:
        """Store ``episode`` under its id and return it.

        No existence check is made: a missing id is simply inserted (upsert).
        """
        key = episode.id
        _episodes[key] = episode
        return episode
# --------------------------------------------
# Module-level repository instances, one per repository class
# --------------------------------------------
episode_repo = EpisodeRepository()
project_repo = ProjectRepository()