creative_studio/backend/app/db/repositories.py

303 lines
9.9 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

from datetime import datetime
from typing import List, Optional, Dict
import uuid
import json
import os
from pathlib import Path
from app.models.project import (
SeriesProject,
SeriesProjectCreate,
Episode
)
from app.utils.logger import get_logger
logger = get_logger(__name__)

# Data storage location.
# Resolved relative to this file (<backend>/app/db/repositories.py ->
# <backend>/data) so the module finds its data regardless of the current
# working directory or the machine it runs on. A hard-coded absolute path
# was previously used here for debugging only.
# Set the CREATIVE_STUDIO_DATA_DIR environment variable to store the JSON
# files somewhere else.
BASE_DIR = Path(__file__).resolve().parent.parent.parent
DATA_DIR = Path(os.environ.get("CREATIVE_STUDIO_DATA_DIR") or BASE_DIR / "data")
PROJECTS_FILE = DATA_DIR / "projects.json"
EPISODES_FILE = DATA_DIR / "episodes.json"
MESSAGES_FILE = DATA_DIR / "messages.json"

# Make sure the data directory exists. Failure is logged rather than raised
# so that importing this module never crashes; the first save will surface
# the underlying problem.
if not DATA_DIR.exists():
    try:
        DATA_DIR.mkdir(parents=True, exist_ok=True)
        logger.info(f"Created data directory: {DATA_DIR}")
    except OSError as e:
        logger.error(f"Failed to create data directory {DATA_DIR}: {e}")

logger.info(f"Data directory: {DATA_DIR}")
logger.info(f"Projects file: {PROJECTS_FILE}")
class JsonRepository:
    """Base class for repositories persisted as a single JSON file.

    The file is read once on construction into ``self._data`` (a dict) and
    rewritten in full on every ``_save`` call.
    """

    # Types json.dumps can emit as-is; they must not be stringified on save.
    _JSON_NATIVE = (dict, list, str, int, float, bool, type(None))

    def __init__(self, file_path: Path):
        """Remember *file_path* and load its current contents, if any."""
        self.file_path = file_path
        self._data = {}
        self._load()

    def _load(self):
        """Load ``self._data`` from the file.

        A missing, unreadable, malformed or non-object file leaves
        ``self._data`` as an empty dict; the problem is logged, not raised.
        """
        self._data = {}
        if not self.file_path.exists():
            return
        try:
            loaded = json.loads(self.file_path.read_text(encoding="utf-8"))
        except (OSError, ValueError) as e:
            # ValueError covers both JSONDecodeError and UnicodeDecodeError.
            logger.error(f"Failed to load data from {self.file_path}: {e}")
            return
        if not isinstance(loaded, dict):
            # Subclasses iterate self._data.items(); anything but an object
            # at the top level would crash them later.
            logger.error(
                f"Failed to load data from {self.file_path}: "
                f"expected a JSON object, got {type(loaded).__name__}"
            )
            return
        self._data = loaded

    @classmethod
    def _to_jsonable(cls, value):
        """Convert one stored value into something ``json.dumps`` accepts."""
        if hasattr(value, "dict"):
            # Model objects exposing .dict()/.json(): round-trip through
            # their own JSON encoder so datetimes etc. are handled.
            return json.loads(value.json())
        if isinstance(value, cls._JSON_NATIVE):
            return value
        # Last resort for unknown objects; lossy, so it is only a fallback.
        return str(value)

    def _save(self):
        """Write ``self._data`` to the file.

        The payload is written to a sibling ``.tmp`` file and then moved
        over the target with ``os.replace`` so a crash mid-write cannot
        leave a truncated store behind.

        Raises:
            Exception: any serialization or I/O error, after logging it.
        """
        try:
            serialized_data = {
                key: self._to_jsonable(value)
                for key, value in self._data.items()
            }
            payload = json.dumps(serialized_data, ensure_ascii=False, indent=2)
            tmp_path = self.file_path.with_name(self.file_path.name + ".tmp")
            tmp_path.write_text(payload, encoding="utf-8")
            os.replace(tmp_path, self.file_path)
        except Exception as e:
            logger.error(f"Failed to save data to {self.file_path}: {e}")
            # Re-raise so the caller sees that persistence failed.
            raise
class ProjectRepository(JsonRepository):
    """Project repository (persistent).

    Keeps two parallel maps keyed by project id: ``self._data`` holds the
    JSON-ready dicts that are written to disk, ``self._objects`` holds the
    parsed ``SeriesProject`` models served to callers.
    """

    def __init__(self):
        """Load projects from disk, repairing stringified records on the way."""
        super().__init__(PROJECTS_FILE)
        # Parsed model objects built from the loaded dicts.
        self._objects: Dict[str, SeriesProject] = {}
        is_dirty = False
        for project_id, raw in list(self._data.items()):
            try:
                raw, repaired = self._repair_stringified(raw)
                if repaired:
                    self._data[project_id] = raw
                    is_dirty = True
                self._objects[project_id] = SeriesProject.parse_obj(raw)
            except Exception as e:
                # One bad record must not prevent the others from loading.
                logger.error(f"Failed to parse project {project_id}: {e}")
        if is_dirty:
            # Persist the repaired form so the fix-up only happens once.
            self._save()

    @staticmethod
    def _repair_stringified(raw):
        """Undo Python-repr stringification of a record.

        Handles a record that is itself a string, and string fields one
        level down that start with ``{`` or ``[`` (e.g. memory or
        globalContext). Dicts are repaired in place.

        Returns:
            A ``(value, changed)`` tuple; ``changed`` is True if anything
            was converted.
        """
        import ast

        # Everything ast.literal_eval is documented to raise on bad input.
        parse_errors = (
            ValueError, SyntaxError, TypeError, MemoryError, RecursionError
        )
        changed = False
        if isinstance(raw, str):
            try:
                raw = ast.literal_eval(raw)
                changed = True
            except parse_errors:
                pass  # not a Python literal; model validation will report it
        if isinstance(raw, dict):
            for field, value in list(raw.items()):
                if isinstance(value, str) and value.startswith(("{", "[")):
                    try:
                        raw[field] = ast.literal_eval(value)
                        changed = True
                    except parse_errors:
                        pass  # ordinary text that merely starts with a bracket
        return raw, changed

    async def create(self, project_data: SeriesProjectCreate) -> SeriesProject:
        """Create a project with a fresh UUID, persist it and return it."""
        project_id = str(uuid.uuid4())
        # One timestamp so createdAt and updatedAt are equal on creation.
        now = datetime.now()
        project = SeriesProject(
            id=project_id,
            name=project_data.name,
            totalEpisodes=project_data.totalEpisodes,
            agentId=project_data.agentId,
            mode=project_data.mode,
            globalContext=project_data.globalContext,
            skillSettings=project_data.skillSettings,
            createdAt=now,
            updatedAt=now,
        )
        self._objects[project_id] = project
        self._data[project_id] = json.loads(project.json())
        self._save()
        logger.info(f"创建项目: {project_id} - {project.name}")
        return project

    async def get(self, project_id: str) -> Optional[SeriesProject]:
        """Return the project with *project_id*, or None if unknown."""
        return self._objects.get(project_id)

    async def list(self, skip: int = 0, limit: int = 100) -> List[SeriesProject]:
        """Return a page of projects, newest first by ``createdAt``.

        Projects without a creation time sort last.
        """
        projects = sorted(
            self._objects.values(),
            key=lambda p: p.createdAt or datetime.min,
            reverse=True,
        )
        return projects[skip:skip + limit]

    async def update(
        self,
        project_id: str,
        project_data: dict
    ) -> Optional[SeriesProject]:
        """Apply *project_data* to an existing project and persist it.

        Keys that are not attributes of the project are ignored, and ``id``
        is never overwritten. ``updatedAt`` is refreshed.

        Returns:
            The updated project, or None if *project_id* is unknown.
        """
        project = self._objects.get(project_id)
        if not project:
            return None
        for key, value in project_data.items():
            if key == "id":
                # The id is the storage key; changing it on the object
                # would desynchronise it from self._objects / self._data.
                continue
            if hasattr(project, key):
                setattr(project, key, value)
        project.updatedAt = datetime.now()
        # Refresh the JSON-ready copy that gets written to disk.
        self._data[project_id] = json.loads(project.json())
        self._save()
        return project

    async def delete(self, project_id: str) -> bool:
        """Delete a project; return True if it existed, False otherwise."""
        if project_id not in self._objects:
            return False
        del self._objects[project_id]
        self._data.pop(project_id, None)
        self._save()
        return True
class EpisodeRepository(JsonRepository):
    """Episode repository (persistent).

    ``self._data`` mirrors the JSON file; ``self._objects`` holds the
    parsed ``Episode`` models, both keyed by episode id.
    """

    def __init__(self):
        """Load episodes, dropping later duplicates of (projectId, number)."""
        super().__init__(EPISODES_FILE)
        self._objects: Dict[str, Episode] = {}

        # (projectId, number) pairs already accepted; first occurrence wins.
        known_pairs = set()
        stale_ids = []

        for episode_id, raw in list(self._data.items()):
            try:
                parsed = Episode.parse_obj(raw)
                pair = (parsed.projectId, parsed.number)
                if pair not in known_pairs:
                    known_pairs.add(pair)
                    self._objects[episode_id] = parsed
                else:
                    logger.warning(f"Found duplicate episode: Project {parsed.projectId}, EP{parsed.number}. Removing ID {episode_id}")
                    stale_ids.append(episode_id)
            except Exception as e:
                logger.error(f"Failed to parse episode {episode_id}: {e}")

        # Purge the duplicates from the stored data and rewrite the file.
        if stale_ids:
            for episode_id in stale_ids:
                del self._data[episode_id]
            self._save()

    async def create(self, episode: Episode) -> Episode:
        """Store a new episode, assigning a UUID if it has no id yet."""
        if not episode.id:
            episode.id = str(uuid.uuid4())
        episode_id = episode.id
        self._objects[episode_id] = episode
        self._data[episode_id] = json.loads(episode.json())
        self._save()
        logger.info(f"创建剧集: {episode.id} - EP{episode.number}")
        return episode

    async def get(self, episode_id: str) -> Optional[Episode]:
        """Return the episode with *episode_id*, or None if unknown."""
        return self._objects.get(episode_id)

    async def list_by_project(
        self,
        project_id: str,
        skip: int = 0,
        limit: int = 100
    ) -> List[Episode]:
        """Return a page of one project's episodes, ordered by number."""
        matching = sorted(
            (ep for ep in self._objects.values() if ep.projectId == project_id),
            key=lambda ep: ep.number,
        )
        return matching[skip:skip + limit]

    async def update(self, episode: Episode) -> Episode:
        """Replace the stored episode with *episode* and persist it."""
        episode_id = episode.id
        self._objects[episode_id] = episode
        self._data[episode_id] = json.loads(episode.json())
        self._save()
        return episode
# ============================================
# Global repository instances
# ============================================
# Constructed at import time; each constructor reads its JSON file from disk.
project_repo = ProjectRepository()
episode_repo = EpisodeRepository()
class MessageRepository(JsonRepository):
    """Chat message repository.

    ``self._data`` maps a project id to that project's list of message
    dicts (``role``, ``content``, ``timestamp``).
    """

    def __init__(self):
        """Load message histories, repairing stringified ones."""
        super().__init__(MESSAGES_FILE)
        # Repair histories that were wrongly serialized as a Python repr
        # string instead of a JSON list.
        import ast

        is_dirty = False
        for project_id, messages in list(self._data.items()):
            if not isinstance(messages, str):
                continue
            try:
                repaired = ast.literal_eval(messages)
                if not isinstance(repaired, list):
                    # add_message appends to this value, so anything but a
                    # list would crash later; treat it as unrecoverable.
                    raise ValueError(
                        f"expected a list, got {type(repaired).__name__}"
                    )
            except (
                ValueError, SyntaxError, TypeError, MemoryError, RecursionError
            ) as e:
                logger.warning(f"Failed to fix message history for {project_id}: {e}")
                self._data[project_id] = []
            else:
                self._data[project_id] = repaired
                logger.info(f"Fixed corrupted message history for project {project_id}")
            is_dirty = True
        if is_dirty:
            self._save()

    async def add_message(self, project_id: str, role: str, content: str):
        """Append a timestamped message to the project's history and persist."""
        message = {
            "role": role,
            "content": content,
            "timestamp": datetime.now().isoformat()
        }
        self._data.setdefault(project_id, []).append(message)
        self._save()

    async def get_history(self, project_id: str) -> List[Dict]:
        """Return the project's chat history, or an empty list if none."""
        return self._data.get(project_id, [])
# Module-level message repository instance, constructed at import time.
message_repo = MessageRepository()