# creative_studio/backend/app/db/repositories.py
# JSON-file-backed repositories for projects, episodes and chat messages.

from datetime import datetime
from typing import List, Optional, Dict
import uuid
import json
import os
from pathlib import Path
from app.models.project import (
SeriesProject,
SeriesProjectCreate,
Episode
)
from app.utils.logger import get_logger
logger = get_logger(__name__)

# Data storage paths.
# Resolve the data directory relative to this file
# (backend/app/db/repositories.py -> backend/) so it is found regardless of the
# current working directory and on any machine. This replaces a hard-coded
# debugging path ("d:/platform/creative_studio/backend/data"), which is what
# this expression evaluates to on that machine anyway.
# Set CREATIVE_STUDIO_DATA_DIR to put the data somewhere else explicitly.
BASE_DIR = Path(__file__).resolve().parent.parent.parent
DATA_DIR = Path(os.environ.get("CREATIVE_STUDIO_DATA_DIR") or BASE_DIR / "data")
PROJECTS_FILE = DATA_DIR / "projects.json"
EPISODES_FILE = DATA_DIR / "episodes.json"
MESSAGES_FILE = DATA_DIR / "messages.json"

# Make sure the data directory exists. A failure is logged rather than raised
# so the module can still be imported; the first read/write reports the error.
if not DATA_DIR.exists():
    try:
        DATA_DIR.mkdir(parents=True, exist_ok=True)
        logger.info(f"Created data directory: {DATA_DIR}")
    except OSError as e:
        logger.error(f"Failed to create data directory {DATA_DIR}: {e}")
logger.info(f"Data directory: {DATA_DIR}")
logger.info(f"Projects file: {PROJECTS_FILE}")
class JsonRepository:
    """Base class for repositories persisted as one JSON object on disk.

    Subclasses keep their records in ``self._data`` (a dict keyed by id) and
    call ``_save()`` after every mutation; the whole dict is rewritten each
    time.
    """

    def __init__(self, file_path: Path):
        """Bind the repository to *file_path* and load its current contents.

        Args:
            file_path: JSON file backing this repository. It need not exist
                yet; a missing file yields an empty repository.
        """
        self.file_path = file_path
        self._data = {}
        self._load()

    def _load(self):
        """Replace ``self._data`` with the contents of ``self.file_path``.

        * Missing file -> empty dict.
        * File that cannot be read (OS error) -> empty dict, error logged.
        * File whose content is not a JSON object (bad encoding, invalid JSON,
          or a non-object top-level value) -> empty dict, error logged, and
          the file is first renamed to ``<name>.corrupt-<timestamp>`` so the
          next ``_save()`` cannot silently overwrite the only copy of the data.
        """
        self._data = {}
        try:
            data = json.loads(self.file_path.read_text(encoding="utf-8"))
            if not isinstance(data, dict):
                raise ValueError(
                    f"top-level JSON value is {type(data).__name__}, expected object"
                )
        except FileNotFoundError:
            return
        except OSError as e:
            logger.error(f"Failed to load data from {self.file_path}: {e}")
            return
        except ValueError as e:  # UnicodeDecodeError / JSONDecodeError / wrong type
            logger.error(f"Failed to load data from {self.file_path}: {e}")
            self._quarantine_bad_file()
            return
        self._data = data

    def _quarantine_bad_file(self):
        """Rename an unparseable data file aside so it survives the next save."""
        stamp = datetime.now().strftime("%Y%m%d-%H%M%S")
        backup = self.file_path.with_name(f"{self.file_path.name}.corrupt-{stamp}")
        try:
            self.file_path.replace(backup)
        except OSError as e:
            logger.error(f"Failed to move unreadable file {self.file_path} aside: {e}")
        else:
            logger.error(f"Moved unreadable file {self.file_path} to {backup}")

    @staticmethod
    def _to_jsonable(value):
        """Convert one stored value into something ``json.dumps`` accepts.

        Objects exposing a ``dict`` attribute (pydantic models) are serialized
        through their own ``json()`` so datetimes etc. become strings.
        JSON-native values (dict, list, tuple, str, int, float, bool, None)
        are kept as-is so they round-trip with their type intact -- e.g. the
        per-project message lists must come back as lists, not as their
        Python ``repr``. Anything else falls back to ``str``.
        """
        if hasattr(value, "dict"):
            return json.loads(value.json())
        if value is None or isinstance(
            value, (dict, list, tuple, str, int, float, bool)
        ):
            return value
        return str(value)

    def _save(self):
        """Write ``self._data`` to ``self.file_path`` atomically.

        The JSON is written to a sibling ``<name>.tmp`` file which is then
        moved over the target with ``Path.replace``, so a crash mid-write
        cannot leave a truncated data file behind.

        Raises:
            Exception: any serialization or I/O error, after logging it.
        """
        tmp_path = self.file_path.with_name(f"{self.file_path.name}.tmp")
        try:
            serialized_data = {
                k: self._to_jsonable(v) for k, v in self._data.items()
            }
            tmp_path.write_text(
                json.dumps(serialized_data, ensure_ascii=False, indent=2),
                encoding="utf-8",
            )
            tmp_path.replace(self.file_path)
        except Exception as e:
            logger.error(f"Failed to save data to {self.file_path}: {e}")
            # Re-raise so callers see the failure instead of losing data silently.
            raise
class ProjectRepository(JsonRepository):
    """Project repository persisted to ``PROJECTS_FILE``.

    Keeps two parallel views keyed by project id: ``self._objects`` holds the
    parsed ``SeriesProject`` models handed to callers, and the inherited
    ``self._data`` holds their JSON-ready dicts that ``_save()`` writes out.
    """

    # Keys that update() must never apply: changing ``id`` would make the
    # object disagree with the key it is stored under.
    _IMMUTABLE_FIELDS = frozenset({"id"})

    def __init__(self):
        super().__init__(PROJECTS_FILE)
        # Parse the loaded dicts into model objects. Entries that fail to parse
        # are logged and skipped here but stay in self._data, so they are
        # written back unchanged on the next save.
        self._objects: Dict[str, SeriesProject] = {}
        for k, v in self._data.items():
            try:
                self._objects[k] = SeriesProject.parse_obj(v)
            except Exception as e:
                logger.error(f"Failed to parse project {k}: {e}")

    async def create(self, project_data: SeriesProjectCreate) -> SeriesProject:
        """Create, persist and return a new project with a fresh UUID.

        ``createdAt`` and ``updatedAt`` share one timestamp so a brand-new
        project never appears to have been modified after creation.
        """
        project_id = str(uuid.uuid4())
        now = datetime.now()
        project = SeriesProject(
            id=project_id,
            name=project_data.name,
            totalEpisodes=project_data.totalEpisodes,
            agentId=project_data.agentId,
            mode=project_data.mode,
            globalContext=project_data.globalContext,
            skillSettings=project_data.skillSettings,
            createdAt=now,
            updatedAt=now,
        )
        self._objects[project_id] = project
        self._data[project_id] = json.loads(project.json())
        self._save()
        logger.info(f"创建项目: {project_id} - {project.name}")
        return project

    async def get(self, project_id: str) -> Optional[SeriesProject]:
        """Return the project with *project_id*, or ``None`` if unknown."""
        return self._objects.get(project_id)

    async def list(self, skip: int = 0, limit: int = 100) -> List[SeriesProject]:
        """Return up to *limit* projects after skipping *skip*, newest first.

        Projects without ``createdAt`` sort as oldest.
        """
        projects = sorted(
            self._objects.values(),
            key=lambda p: p.createdAt or datetime.min,
            reverse=True,
        )
        return projects[skip:skip + limit]

    async def update(
        self,
        project_id: str,
        project_data: dict
    ) -> Optional[SeriesProject]:
        """Apply a partial update to a project, bump ``updatedAt``, persist it.

        Keys the project object has no attribute for are ignored, as is ``id``
        (see ``_IMMUTABLE_FIELDS``).

        Returns:
            The updated project, or ``None`` if *project_id* is unknown.
        """
        project = self._objects.get(project_id)
        if project is None:
            return None
        for key, value in project_data.items():
            if key in self._IMMUTABLE_FIELDS:
                continue
            if hasattr(project, key):
                setattr(project, key, value)
        project.updatedAt = datetime.now()
        # Keep the serialized view in sync with the mutated model.
        self._data[project_id] = json.loads(project.json())
        self._save()
        return project

    async def delete(self, project_id: str) -> bool:
        """Delete a project; return ``True`` if it existed, else ``False``."""
        if project_id not in self._objects:
            return False
        del self._objects[project_id]
        self._data.pop(project_id, None)
        self._save()
        return True
class EpisodeRepository(JsonRepository):
    """Episode repository persisted to ``EPISODES_FILE``.

    Parsed ``Episode`` models live in ``self._objects``; their JSON-ready
    dicts live in the inherited ``self._data``. Both are keyed by episode id.
    """

    def __init__(self):
        super().__init__(EPISODES_FILE)
        self._objects: Dict[str, Episode] = {}
        for episode_id, raw in self._data.items():
            try:
                parsed = Episode.parse_obj(raw)
            except Exception as e:
                logger.error(f"Failed to parse episode {episode_id}: {e}")
                continue
            self._objects[episode_id] = parsed

    def _persist(self, episode: Episode) -> Episode:
        """Index *episode* under its id in both views, flush to disk, return it."""
        self._objects[episode.id] = episode
        self._data[episode.id] = json.loads(episode.json())
        self._save()
        return episode

    async def create(self, episode: Episode) -> Episode:
        """Store a new episode, assigning a random UUID when it has no id."""
        if not episode.id:
            episode.id = str(uuid.uuid4())
        self._persist(episode)
        logger.info(f"创建剧集: {episode.id} - EP{episode.number}")
        return episode

    async def get(self, episode_id: str) -> Optional[Episode]:
        """Look up one episode by id; ``None`` when absent."""
        return self._objects.get(episode_id)

    async def list_by_project(
        self,
        project_id: str,
        skip: int = 0,
        limit: int = 100
    ) -> List[Episode]:
        """Page through a project's episodes ordered by episode number."""
        ordered = sorted(
            (ep for ep in self._objects.values() if ep.projectId == project_id),
            key=lambda ep: ep.number,
        )
        return ordered[skip:skip + limit]

    async def update(self, episode: Episode) -> Episode:
        """Overwrite (or insert) the episode stored under ``episode.id``."""
        return self._persist(episode)
# --------------------------------------------------------------------------
# Module-level repository instances.
# Created at import time, so importing this module reads the projects and
# episodes JSON files once; importers share these same objects.
# --------------------------------------------------------------------------
project_repo: ProjectRepository = ProjectRepository()
episode_repo: EpisodeRepository = EpisodeRepository()
class MessageRepository(JsonRepository):
    """Chat-history repository persisted to ``MESSAGES_FILE``.

    ``self._data`` maps ``project_id`` to a list of message dicts, each with
    ``"role"``, ``"content"`` and ``"timestamp"`` (ISO-8601) keys, in the
    order they were added.
    """

    def __init__(self):
        super().__init__(MESSAGES_FILE)
        self._normalize_histories()

    def _normalize_histories(self):
        """Ensure every stored history is a list so it can be appended to.

        A ``_save`` that passes only dicts through unchanged writes other
        values via ``str()``, which leaves the Python ``repr`` of a history
        list (e.g. ``"[{'role': 'user', ...}]"``) in the file. Such strings
        are parsed back with ``ast.literal_eval``, which only evaluates Python
        literals and is therefore safe on file content. Any value that still
        is not a list is logged and replaced with an empty history; otherwise
        ``add_message`` would fail on it.
        """
        import ast

        for project_id, history in list(self._data.items()):
            if isinstance(history, list):
                continue
            if isinstance(history, str):
                try:
                    recovered = ast.literal_eval(history)
                except (ValueError, SyntaxError, MemoryError, RecursionError):
                    recovered = None
                if isinstance(recovered, list):
                    self._data[project_id] = recovered
                    continue
            logger.error(
                f"Discarding unreadable message history for project {project_id}"
            )
            self._data[project_id] = []

    async def add_message(self, project_id: str, role: str, content: str):
        """Append a timestamped message to a project's history and persist it."""
        message = {
            "role": role,
            "content": content,
            "timestamp": datetime.now().isoformat(),
        }
        self._data.setdefault(project_id, []).append(message)
        self._save()

    async def get_history(self, project_id: str) -> List[Dict]:
        """Return a copy of the project's chat history in insertion order.

        A copy is returned so callers cannot mutate the stored history behind
        the repository's back (such edits would be persisted by the next,
        unrelated save). Unknown projects yield an empty list.
        """
        return list(self._data.get(project_id, []))


message_repo = MessageRepository()