import ast
import json
import os
import uuid
from datetime import datetime
from pathlib import Path
from typing import Any, Dict, List, Optional

from app.models.project import (
    SeriesProject,
    SeriesProjectCreate,
    Episode,
)
from app.utils.logger import get_logger

logger = get_logger(__name__)

# ---------------------------------------------------------------------------
# Data storage location
# ---------------------------------------------------------------------------
# Resolved from this file's own (absolute) location so the data is found no
# matter which working directory the server is started from. Set the
# CREATIVE_STUDIO_DATA_DIR environment variable to store the data elsewhere.
# NOTE(review): assumes this module lives at <backend>/app/<package>/<module>.py,
# which makes BASE_DIR the backend root -- the same directory the previous
# hard-coded debugging path ("d:/platform/creative_studio/backend/data") used.
# Confirm if the module is moved.
BASE_DIR = Path(__file__).resolve().parent.parent.parent
DATA_DIR = Path(os.environ.get("CREATIVE_STUDIO_DATA_DIR", BASE_DIR / "data"))
PROJECTS_FILE = DATA_DIR / "projects.json"
EPISODES_FILE = DATA_DIR / "episodes.json"
MESSAGES_FILE = DATA_DIR / "messages.json"

# Make sure the data directory exists before any repository reads or writes it.
if not DATA_DIR.exists():
    try:
        DATA_DIR.mkdir(parents=True, exist_ok=True)
        logger.info(f"Created data directory: {DATA_DIR}")
    except Exception as e:
        logger.error(f"Failed to create data directory {DATA_DIR}: {e}")

logger.info(f"Data directory: {DATA_DIR}")
logger.info(f"Projects file: {PROJECTS_FILE}")


class JsonRepository:
    """Base class that persists one top-level JSON object (a dict) to a file.

    The whole mapping is loaded into ``self._data`` on construction, and the
    whole file is rewritten on every ``_save()``.
    """

    def __init__(self, file_path: Path):
        self.file_path = file_path
        self._data: Dict[str, Any] = {}
        self._load()

    def _load(self) -> None:
        """Load ``self._data`` from disk.

        Starts empty if the file is missing, unreadable, or its root is not a
        JSON object. An unreadable or malformed file is moved aside first
        (``_quarantine_file``). Otherwise the next ``_save()`` would overwrite
        it with an empty mapping and silently destroy its contents.
        """
        self._data = {}
        if not self.file_path.exists():
            return
        try:
            loaded = json.loads(self.file_path.read_text(encoding="utf-8"))
        except Exception as e:
            logger.error(f"Failed to load data from {self.file_path}: {e}")
            self._quarantine_file()
            return
        if not isinstance(loaded, dict):
            # Subclasses iterate self._data.items(); a list/scalar root would crash later.
            logger.error(
                f"Failed to load data from {self.file_path}: "
                f"expected a JSON object, got {type(loaded).__name__}"
            )
            self._quarantine_file()
            return
        self._data = loaded

    def _quarantine_file(self) -> None:
        """Rename the data file to ``<name>.corrupt-<timestamp>``.

        This is best effort: failure is logged, not raised.
        """
        backup = self.file_path.with_name(
            f"{self.file_path.name}.corrupt-{datetime.now():%Y%m%d%H%M%S}"
        )
        try:
            os.replace(self.file_path, backup)
            logger.error(f"Moved unreadable data file {self.file_path} to {backup}")
        except OSError as e:
            logger.error(f"Failed to back up unreadable data file {self.file_path}: {e}")

    @staticmethod
    def _to_jsonable(value: Any) -> Any:
        """Convert one stored value into something ``json.dumps`` accepts.

        - Objects exposing ``.dict`` (pydantic models, as used in this file) are
          round-tripped through their own ``.json()`` encoder.
        - JSON-native values (dict, list, tuple, str, numbers, bool, None) pass
          through unchanged.
        - Anything else falls back to ``str(value)``.
        """
        if hasattr(value, "dict"):
            return json.loads(value.json())
        if value is None or isinstance(value, (dict, list, tuple, str, int, float, bool)):
            # Lists matter here: MessageRepository stores one list per project.
            # They used to fall through to str() and be saved as Python repr text.
            return value
        return str(value)

    def _save(self) -> None:
        """Write ``self._data`` to disk atomically.

        The JSON is written to a sibling ``.tmp`` file, which ``os.replace``
        then swaps into place. A crash mid-write therefore leaves the previous
        file intact rather than a truncated one.

        Raises:
            Exception: whatever serialization or I/O raised, after logging it.
        """
        tmp_path = self.file_path.with_name(self.file_path.name + ".tmp")
        try:
            serialized_data = {k: self._to_jsonable(v) for k, v in self._data.items()}
            tmp_path.write_text(
                json.dumps(serialized_data, ensure_ascii=False, indent=2),
                encoding="utf-8",
            )
            os.replace(tmp_path, self.file_path)
        except Exception as e:
            logger.error(f"Failed to save data to {self.file_path}: {e}")
            try:
                tmp_path.unlink()
            except OSError:
                pass  # temp file may never have been created; cleanup is best effort
            # Re-raise so the caller sees the failure.
            raise


class ProjectRepository(JsonRepository):
    """Repository of series projects, persisted to ``PROJECTS_FILE``."""

    def __init__(self):
        super().__init__(PROJECTS_FILE)
        # Parsed model objects keyed by project id; self._data keeps the raw JSON dicts.
        self._objects: Dict[str, SeriesProject] = {}
        for k, v in self._data.items():
            try:
                self._objects[k] = SeriesProject.parse_obj(v)
            except Exception as e:
                logger.error(f"Failed to parse project {k}: {e}")

    async def create(self, project_data: SeriesProjectCreate) -> SeriesProject:
        """Create, persist and return a new project with a freshly generated UUID id."""
        project_id = str(uuid.uuid4())
        now = datetime.now()  # one timestamp so createdAt == updatedAt on creation
        project = SeriesProject(
            id=project_id,
            name=project_data.name,
            totalEpisodes=project_data.totalEpisodes,
            agentId=project_data.agentId,
            mode=project_data.mode,
            globalContext=project_data.globalContext,
            skillSettings=project_data.skillSettings,
            createdAt=now,
            updatedAt=now,
        )
        self._objects[project_id] = project
        self._data[project_id] = json.loads(project.json())
        self._save()
        logger.info(f"创建项目: {project_id} - {project.name}")
        return project

    async def get(self, project_id: str) -> Optional[SeriesProject]:
        """Return the project with ``project_id``, or None if it does not exist."""
        return self._objects.get(project_id)

    async def list(self, skip: int = 0, limit: int = 100) -> List[SeriesProject]:
        """Return projects newest-first, paginated by ``skip``/``limit``.

        Projects without ``createdAt`` sort as ``datetime.min``, i.e. last.
        """
        projects = sorted(
            self._objects.values(),
            key=lambda p: p.createdAt or datetime.min,
            reverse=True,
        )
        return projects[skip:skip + limit]

    async def update(
        self,
        project_id: str,
        project_data: dict
    ) -> Optional[SeriesProject]:
        """Apply ``project_data`` attributes to a project and persist it.

        Keys that are not attributes of the project are ignored. ``id`` is also
        ignored, because it doubles as the storage key. ``updatedAt`` is
        refreshed.

        Returns:
            The updated project, or None if ``project_id`` does not exist.
        """
        project = self._objects.get(project_id)
        if not project:
            return None

        for key, value in project_data.items():
            if key == "id":
                continue
            if hasattr(project, key):
                setattr(project, key, value)

        project.updatedAt = datetime.now()

        self._data[project_id] = json.loads(project.json())
        self._save()
        return project

    async def delete(self, project_id: str) -> bool:
        """Delete a project; return True if it existed, False otherwise.

        Only the project record is removed. Episodes and messages of the
        project are left untouched.
        """
        if project_id in self._objects:
            del self._objects[project_id]
            if project_id in self._data:
                del self._data[project_id]
            self._save()
            return True
        return False


class EpisodeRepository(JsonRepository):
    """Repository of episodes, persisted to ``EPISODES_FILE``."""

    def __init__(self):
        super().__init__(EPISODES_FILE)
        # Parsed model objects keyed by episode id; self._data keeps the raw JSON dicts.
        self._objects: Dict[str, Episode] = {}
        for k, v in self._data.items():
            try:
                self._objects[k] = Episode.parse_obj(v)
            except Exception as e:
                logger.error(f"Failed to parse episode {k}: {e}")

    async def create(self, episode: Episode) -> Episode:
        """Persist ``episode`` and return it; assigns a UUID id if it has none."""
        if not episode.id:
            episode.id = str(uuid.uuid4())

        self._objects[episode.id] = episode
        self._data[episode.id] = json.loads(episode.json())
        self._save()
        logger.info(f"创建剧集: {episode.id} - EP{episode.number}")
        return episode

    async def get(self, episode_id: str) -> Optional[Episode]:
        """Return the episode with ``episode_id``, or None if it does not exist."""
        return self._objects.get(episode_id)

    async def list_by_project(
        self,
        project_id: str,
        skip: int = 0,
        limit: int = 100
    ) -> List[Episode]:
        """Return a project's episodes ordered by ``number``, paginated by ``skip``/``limit``."""
        episodes = [
            ep for ep in self._objects.values()
            if ep.projectId == project_id
        ]
        episodes.sort(key=lambda x: x.number)
        return episodes[skip:skip + limit]

    async def update(self, episode: Episode) -> Episode:
        """Store ``episode`` under its id, replacing any existing entry, and persist it."""
        self._objects[episode.id] = episode
        self._data[episode.id] = json.loads(episode.json())
        self._save()
        return episode


# ============================================
# Global repository instances
# ============================================
project_repo = ProjectRepository()
episode_repo = EpisodeRepository()


class MessageRepository(JsonRepository):
    """Per-project chat history, persisted to ``MESSAGES_FILE``.

    ``self._data`` maps project_id to a list of
    ``{"role": ..., "content": ..., "timestamp": ...}`` dicts.
    """

    def __init__(self):
        super().__init__(MESSAGES_FILE)
        self._repair_stringified_histories()

    def _repair_stringified_histories(self) -> None:
        """Recover histories that older versions of ``_save`` wrote as ``str(list)``.

        That bug stored each history as its Python repr string, which then broke
        ``add_message``'s ``append``. ``ast.literal_eval`` parses such reprs back
        into lists; it evaluates literals only and never executes code.
        """
        for project_id, history in list(self._data.items()):
            if not isinstance(history, str):
                continue
            try:
                recovered = ast.literal_eval(history)
            except (ValueError, SyntaxError):
                recovered = None
            if isinstance(recovered, list):
                self._data[project_id] = recovered
            else:
                logger.error(
                    f"Discarding unrecoverable message history for project {project_id}"
                )
                self._data[project_id] = []

    async def add_message(self, project_id: str, role: str, content: str):
        """Append a timestamped message to the project's history and persist it."""
        message = {
            "role": role,
            "content": content,
            "timestamp": datetime.now().isoformat()
        }
        self._data.setdefault(project_id, []).append(message)
        self._save()

    async def get_history(self, project_id: str) -> List[Dict]:
        """Return a copy of the project's message list (empty if it has none).

        The copy is shallow: the list is new, but the message dicts inside it
        are the stored ones.
        """
        return list(self._data.get(project_id, []))


message_repo = MessageRepository()