"""JSON-file-backed repositories for projects, episodes and chat messages."""

import ast
import json
import os
import uuid
from datetime import datetime
from pathlib import Path
from typing import Any, Dict, List, Optional

from app.models.project import (
    SeriesProject,
    SeriesProjectCreate,
    Episode
)
from app.utils.logger import get_logger

logger = get_logger(__name__)

# Data storage location.
# Intended approach: an absolute path derived from this file, so the data
# directory is found regardless of the current working directory:
# BASE_DIR = Path(__file__).resolve().parent.parent.parent
# DATA_DIR = BASE_DIR / "data"
# NOTE(review): the path below is a temporary hard-coded absolute path used
# for debugging; it only resolves as intended on a machine with this D: layout.
DATA_DIR = Path("d:/platform/creative_studio/backend/data")
PROJECTS_FILE = DATA_DIR / "projects.json"
EPISODES_FILE = DATA_DIR / "episodes.json"
MESSAGES_FILE = DATA_DIR / "messages.json"

# Exceptions ast.literal_eval is documented to raise on malformed input.
_LITERAL_EVAL_ERRORS = (
    ValueError,
    TypeError,
    SyntaxError,
    MemoryError,
    RecursionError,
)

# Make sure the data directory exists. A failure is only logged here; it will
# surface again as an error on the first save.
if not DATA_DIR.exists():
    try:
        DATA_DIR.mkdir(parents=True, exist_ok=True)
        logger.info(f"Created data directory: {DATA_DIR}")
    except Exception as e:
        logger.error(f"Failed to create data directory {DATA_DIR}: {e}")

logger.info(f"Data directory: {DATA_DIR}")
logger.info(f"Projects file: {PROJECTS_FILE}")


class JsonRepository:
    """Base class that persists a dictionary to a single JSON file.

    The raw, JSON-compatible content lives in ``self._data``; subclasses
    decide what the keys and values mean.
    """

    def __init__(self, file_path: Path):
        """Remember *file_path* and load its current content, if any."""
        self.file_path = file_path
        self._data: Dict[str, Any] = {}
        self._load()

    def _load(self):
        """Load ``self._data`` from the backing file.

        A missing file, unreadable file, invalid JSON, or a JSON document
        whose root is not an object all result in an empty store; the
        failure cases are logged and never raised.
        """
        self._data = {}
        if not self.file_path.exists():
            return

        try:
            content = self.file_path.read_text(encoding="utf-8")
            loaded = json.loads(content)
        except Exception as e:
            logger.error(f"Failed to load data from {self.file_path}: {e}")
            return

        # Every subclass iterates ``self._data.items()``, so a list/null/scalar
        # root would crash later; reject it here instead.
        if not isinstance(loaded, dict):
            logger.error(
                f"Failed to load data from {self.file_path}: "
                f"expected a JSON object, got {type(loaded).__name__}"
            )
            return

        self._data = loaded

    def _save(self):
        """Write ``self._data`` to the backing file.

        Each top-level value is serialized as follows: objects exposing a
        ``dict`` attribute go through their ``.json()`` method, dicts and
        lists are written as-is, and anything else is stored as ``str(v)``.

        The JSON is first written to a sibling ``.tmp`` file and then moved
        over the target with ``os.replace``, so an interrupted write does not
        leave a truncated data file behind.

        Raises:
            Exception: any serialization or I/O error, after it is logged.
        """
        try:
            # Convert values into JSON-serializable structures.
            serialized_data = {}
            for k, v in self._data.items():
                if hasattr(v, "dict"):
                    serialized_data[k] = json.loads(v.json())
                elif isinstance(v, (dict, list)):
                    serialized_data[k] = v
                else:
                    # Lossy fallback: the value is stored as its string form.
                    serialized_data[k] = str(v)

            payload = json.dumps(serialized_data, ensure_ascii=False, indent=2)
            tmp_path = self.file_path.with_name(self.file_path.name + ".tmp")
            tmp_path.write_text(payload, encoding="utf-8")
            os.replace(tmp_path, self.file_path)
        except Exception as e:
            logger.error(f"Failed to save data to {self.file_path}: {e}")
            # Re-raise so the failure is visible to the caller.
            raise


class ProjectRepository(JsonRepository):
    """Project repository (persistent version)."""

    def __init__(self):
        """Load projects and repair entries that were stored as strings.

        Entries (or nested fields such as memory / globalContext) that were
        previously written as Python-repr strings are parsed back with
        ``ast.literal_eval``; if anything was repaired the file is rewritten.
        Entries that cannot be parsed into a ``SeriesProject`` are logged and
        skipped, but stay in the raw data.
        """
        super().__init__(PROJECTS_FILE)
        # Parsed model objects, keyed by project id.
        self._objects: Dict[str, SeriesProject] = {}

        is_dirty = False
        for k, v in list(self._data.items()):
            try:
                # Repair a whole entry that was stored as a string.
                if isinstance(v, str):
                    try:
                        v = ast.literal_eval(v)
                        self._data[k] = v
                        is_dirty = True
                    except _LITERAL_EVAL_ERRORS:
                        # Not a Python literal; leave the value untouched and
                        # let parse_obj below decide what to do with it.
                        pass

                # Repair nested fields that were stored as strings.
                if isinstance(v, dict):
                    for sub_k, sub_v in v.items():
                        if isinstance(sub_v, str) and sub_v.startswith(("{", "[")):
                            try:
                                v[sub_k] = ast.literal_eval(sub_v)
                                is_dirty = True
                            except _LITERAL_EVAL_ERRORS:
                                # Ordinary text that merely starts with a
                                # bracket; keep it as a string.
                                pass

                self._objects[k] = SeriesProject.parse_obj(v)
            except Exception as e:
                logger.error(f"Failed to parse project {k}: {e}")

        if is_dirty:
            self._save()

    async def create(self, project_data: SeriesProjectCreate) -> SeriesProject:
        """Create a project with a fresh UUID, persist it and return it."""
        project_id = str(uuid.uuid4())
        # One timestamp for both fields so a new record has
        # createdAt == updatedAt.
        now = datetime.now()
        project = SeriesProject(
            id=project_id,
            name=project_data.name,
            totalEpisodes=project_data.totalEpisodes,
            agentId=project_data.agentId,
            mode=project_data.mode,
            globalContext=project_data.globalContext,
            skillSettings=project_data.skillSettings,
            createdAt=now,
            updatedAt=now
        )
        self._objects[project_id] = project
        self._data[project_id] = json.loads(project.json())
        self._save()

        logger.info(f"创建项目: {project_id} - {project.name}")
        return project

    async def get(self, project_id: str) -> Optional[SeriesProject]:
        """Return the project with *project_id*, or ``None`` if not loaded."""
        return self._objects.get(project_id)

    async def list(self, skip: int = 0, limit: int = 100) -> List[SeriesProject]:
        """Return projects ordered by creation time, newest first.

        Projects without ``createdAt`` are ordered as ``datetime.min``.
        *skip* and *limit* select a slice of the ordered result.
        """
        projects = sorted(
            self._objects.values(),
            key=lambda p: p.createdAt or datetime.min,
            reverse=True
        )
        return projects[skip:skip + limit]

    async def update(
        self,
        project_id: str,
        project_data: dict
    ) -> Optional[SeriesProject]:
        """Apply *project_data* to an existing project and persist it.

        Only keys that are already attributes of the project are set; other
        keys are ignored. ``updatedAt`` is refreshed.

        Returns:
            The updated project, or ``None`` when *project_id* is unknown.
        """
        project = self._objects.get(project_id)
        if not project:
            return None

        for key, value in project_data.items():
            if hasattr(project, key):
                setattr(project, key, value)

        project.updatedAt = datetime.now()

        # Refresh the raw stored form.
        self._data[project_id] = json.loads(project.json())
        self._save()

        return project

    async def delete(self, project_id: str) -> bool:
        """Delete a loaded project and persist the change.

        Returns:
            ``True`` if the project was loaded and has been removed,
            ``False`` otherwise.
        """
        if project_id in self._objects:
            del self._objects[project_id]
            if project_id in self._data:
                del self._data[project_id]
            self._save()
            return True
        return False


class EpisodeRepository(JsonRepository):
    """Episode repository (persistent version)."""

    def __init__(self):
        """Load episodes, dropping duplicates of (projectId, number).

        The first entry seen for a given (projectId, number) pair is kept;
        later ones are removed from the raw data and the file is rewritten.
        Entries that fail to parse are logged and skipped.
        """
        super().__init__(EPISODES_FILE)
        self._objects: Dict[str, Episode] = {}

        # Tracks (projectId, number) pairs already seen, to detect duplicates.
        seen_episodes = set()
        duplicates_to_remove = []

        for k, v in list(self._data.items()):
            try:
                episode = Episode.parse_obj(v)
                key = (episode.projectId, episode.number)

                if key in seen_episodes:
                    logger.warning(f"Found duplicate episode: Project {episode.projectId}, EP{episode.number}. Removing ID {k}")
                    duplicates_to_remove.append(k)
                    continue

                seen_episodes.add(key)
                self._objects[k] = episode
            except Exception as e:
                logger.error(f"Failed to parse episode {k}: {e}")

        # Purge the duplicates from the raw data and persist.
        if duplicates_to_remove:
            for k in duplicates_to_remove:
                del self._data[k]
            self._save()

    async def create(self, episode: Episode) -> Episode:
        """Store *episode*, assigning a UUID when it has no id, and return it."""
        if not episode.id:
            episode.id = str(uuid.uuid4())

        self._objects[episode.id] = episode
        self._data[episode.id] = json.loads(episode.json())
        self._save()

        logger.info(f"创建剧集: {episode.id} - EP{episode.number}")
        return episode

    async def get(self, episode_id: str) -> Optional[Episode]:
        """Return the episode with *episode_id*, or ``None`` if not loaded."""
        return self._objects.get(episode_id)

    async def list_by_project(
        self,
        project_id: str,
        skip: int = 0,
        limit: int = 100
    ) -> List[Episode]:
        """Return the episodes of *project_id* ordered by episode number.

        *skip* and *limit* select a slice of the ordered result.
        """
        episodes = [
            ep for ep in self._objects.values()
            if ep.projectId == project_id
        ]
        episodes.sort(key=lambda x: x.number)
        return episodes[skip:skip + limit]

    async def update(self, episode: Episode) -> Episode:
        """Store *episode* under its id, replacing any previous version."""
        self._objects[episode.id] = episode
        self._data[episode.id] = json.loads(episode.json())
        self._save()
        return episode


# ============================================
# Global repository instances
# ============================================

project_repo = ProjectRepository()
episode_repo = EpisodeRepository()


class MessageRepository(JsonRepository):
    """Chat message history repository, keyed by project id."""

    def __init__(self):
        """Load message histories and repair ones stored as strings.

        A history that was written as a Python-repr string is parsed back
        with ``ast.literal_eval``; if that fails the history is reset to an
        empty list. The file is rewritten when anything changed.
        """
        super().__init__(MESSAGES_FILE)

        is_dirty = False
        for project_id in list(self._data.keys()):
            messages = self._data[project_id]
            if isinstance(messages, str):
                try:
                    # Try to parse the Python repr-style string.
                    self._data[project_id] = ast.literal_eval(messages)
                    is_dirty = True
                    logger.info(f"Fixed corrupted message history for project {project_id}")
                except Exception as e:
                    logger.warning(f"Failed to fix message history for {project_id}: {e}")
                    self._data[project_id] = []
                    is_dirty = True

        if is_dirty:
            self._save()

    async def add_message(self, project_id: str, role: str, content: str):
        """Append a message with the current ISO timestamp and persist it."""
        if project_id not in self._data:
            self._data[project_id] = []

        message = {
            "role": role,
            "content": content,
            "timestamp": datetime.now().isoformat()
        }
        self._data[project_id].append(message)
        self._save()

    async def get_history(self, project_id: str) -> List[Dict]:
        """Return the stored chat history of *project_id* (empty if none)."""
        return self._data.get(project_id, [])


message_repo = MessageRepository()