243 lines
7.5 KiB
Python
243 lines
7.5 KiB
Python
from datetime import datetime
|
|
from typing import List, Optional, Dict
|
|
import uuid
|
|
import json
|
|
import os
|
|
from pathlib import Path
|
|
|
|
from app.models.project import (
|
|
SeriesProject,
|
|
SeriesProjectCreate,
|
|
Episode
|
|
)
|
|
from app.utils.logger import get_logger
|
|
|
|
logger = get_logger(__name__)

# Data storage location.
# Resolved relative to this file so it works regardless of the current working
# directory. The three ``.parent`` hops assume this module lives three levels
# below the backend root (e.g. backend/app/<pkg>/<this_file>.py) — confirm if
# the file is moved. Set CREATIVE_STUDIO_DATA_DIR to store data elsewhere
# (other machines, deployments, tests) without editing code.
BASE_DIR = Path(__file__).resolve().parent.parent.parent
_env_data_dir = os.environ.get("CREATIVE_STUDIO_DATA_DIR")
DATA_DIR = (
    Path(_env_data_dir).expanduser().resolve()
    if _env_data_dir
    else BASE_DIR / "data"
)

PROJECTS_FILE = DATA_DIR / "projects.json"
EPISODES_FILE = DATA_DIR / "episodes.json"
MESSAGES_FILE = DATA_DIR / "messages.json"

# Make sure the data directory exists. A failure is logged rather than raised,
# so importing this module does not fail here; the first save will surface it.
if not DATA_DIR.exists():
    try:
        DATA_DIR.mkdir(parents=True, exist_ok=True)
        logger.info(f"Created data directory: {DATA_DIR}")
    except OSError as e:
        logger.error(f"Failed to create data directory {DATA_DIR}: {e}")

logger.info(f"Data directory: {DATA_DIR}")
logger.info(f"Projects file: {PROJECTS_FILE}")
|
|
|
|
class JsonRepository:
    """Base class for repositories persisted as one JSON object file.

    ``self._data`` maps string keys to stored values. ``_save`` serializes the
    whole mapping as a single top-level JSON object at ``self.file_path``, and
    ``_load`` reads it back.
    """

    # Value types that are already JSON-serializable containers/scalars and
    # are written through unchanged by _save.
    _JSON_NATIVE = (dict, list, str, int, float, bool, type(None))

    def __init__(self, file_path: Path):
        self.file_path = file_path
        self._data = {}
        self._load()

    def _load(self):
        """Load ``self._data`` from ``self.file_path``.

        A missing file yields an empty store. A file that cannot be read or
        parsed, or whose top-level value is not a JSON object, is logged and
        also yields an empty store.
        """
        if not self.file_path.exists():
            self._data = {}
            return

        try:
            content = self.file_path.read_text(encoding="utf-8")
            data = json.loads(content)
        except Exception as e:
            # Best-effort at startup: log and continue with an empty store.
            logger.error(f"Failed to load data from {self.file_path}: {e}")
            self._data = {}
            return

        if not isinstance(data, dict):
            # Subclasses iterate self._data.items(); a top-level list or scalar
            # would make them crash, so treat it like any other bad file.
            logger.error(
                f"Failed to load data from {self.file_path}: "
                f"expected a JSON object, got {type(data).__name__}"
            )
            self._data = {}
            return

        self._data = data

    @classmethod
    def _to_jsonable(cls, value):
        """Convert one stored value into something ``json.dumps`` can write.

        Model objects (anything with a ``dict`` attribute — presumably pydantic
        models) are round-tripped through their own ``.json()``. Dicts, lists
        and JSON scalars pass through unchanged; anything else falls back to
        ``str()``. Lists used to hit that fallback and were saved as Python
        repr strings instead of JSON arrays.
        """
        if hasattr(value, "dict"):
            return json.loads(value.json())
        if isinstance(value, cls._JSON_NATIVE):
            return value
        return str(value)

    def _save(self):
        """Serialize ``self._data`` and replace ``self.file_path`` with it.

        The JSON is first written to a sibling ``<name>.tmp`` file and then
        moved over the target with ``os.replace``. An interrupted write
        therefore no longer truncates the existing file. A truncated file would
        be read by ``_load`` as an empty store, and the next save would then
        overwrite all data.

        Raises:
            Any serialization or I/O error, re-raised after being logged.
        """
        tmp_path = self.file_path.with_name(self.file_path.name + ".tmp")
        try:
            serialized_data = {
                key: self._to_jsonable(value) for key, value in self._data.items()
            }
            tmp_path.write_text(
                json.dumps(serialized_data, ensure_ascii=False, indent=2),
                encoding="utf-8",
            )
            os.replace(tmp_path, self.file_path)
        except Exception as e:
            logger.error(f"Failed to save data to {self.file_path}: {e}")
            # Best-effort cleanup of a half-written temp file; the original
            # error below is the one that matters.
            try:
                if tmp_path.exists():
                    tmp_path.unlink()
            except OSError:
                pass
            # Re-raise so callers see the failure.
            raise
|
|
|
|
class ProjectRepository(JsonRepository):
    """Series-project repository persisted to PROJECTS_FILE.

    Keeps two views keyed by project id: ``self._data`` holds the JSON-ready
    dicts that ``_save`` writes to disk, and ``self._objects`` holds the parsed
    ``SeriesProject`` models returned by the async methods.
    """

    # Keys update() never copies from the payload:
    # - ``id`` is the storage key; changing it would desync the object from its
    #   ``self._objects`` key.
    # - ``createdAt`` is fixed at creation.
    # - ``updatedAt`` is stamped by update() itself.
    # setattr() below may not re-parse values (unless the model validates
    # assignment — TODO confirm). A client that echoes the whole project back
    # could then store e.g. a string createdAt, which list() cannot sort
    # against datetime values.
    _READ_ONLY_FIELDS = frozenset({"id", "createdAt", "updatedAt"})

    def __init__(self):
        super().__init__(PROJECTS_FILE)
        # Parse the raw dicts loaded by the base class into model objects.
        self._objects: Dict[str, SeriesProject] = {}
        for k, v in self._data.items():
            try:
                self._objects[k] = SeriesProject.parse_obj(v)
            except Exception as e:
                # The raw entry stays in self._data, so it is written back
                # unchanged on the next save rather than being dropped.
                logger.error(f"Failed to parse project {k}: {e}")

    async def create(self, project_data: SeriesProjectCreate) -> SeriesProject:
        """Create a project with a fresh UUID, persist it, and return it."""
        project_id = str(uuid.uuid4())
        # One timestamp for both fields so a new project has createdAt == updatedAt.
        now = datetime.now()
        project = SeriesProject(
            id=project_id,
            name=project_data.name,
            totalEpisodes=project_data.totalEpisodes,
            agentId=project_data.agentId,
            mode=project_data.mode,
            globalContext=project_data.globalContext,
            skillSettings=project_data.skillSettings,
            createdAt=now,
            updatedAt=now,
        )
        self._objects[project_id] = project
        self._data[project_id] = json.loads(project.json())
        self._save()

        logger.info(f"创建项目: {project_id} - {project.name}")
        return project

    async def get(self, project_id: str) -> Optional[SeriesProject]:
        """Return the project with ``project_id``, or None if unknown."""
        return self._objects.get(project_id)

    async def list(self, skip: int = 0, limit: int = 100) -> List[SeriesProject]:
        """Return projects newest-first (by createdAt), paged by skip/limit."""
        # Projects without createdAt sort last.
        projects = sorted(
            self._objects.values(),
            key=lambda p: p.createdAt or datetime.min,
            reverse=True,
        )
        return projects[skip:skip + limit]

    async def update(
        self,
        project_id: str,
        project_data: dict
    ) -> Optional[SeriesProject]:
        """Apply ``project_data`` to an existing project and persist it.

        Only keys that exist as attributes on the project are applied; keys in
        ``_READ_ONLY_FIELDS`` are ignored. ``updatedAt`` is set to now.

        Returns:
            The updated project, or None if ``project_id`` is unknown.
        """
        project = self._objects.get(project_id)
        if not project:
            return None

        for key, value in project_data.items():
            if key in self._READ_ONLY_FIELDS:
                continue
            if hasattr(project, key):
                setattr(project, key, value)

        project.updatedAt = datetime.now()

        # Refresh the serialized copy and write it out.
        self._data[project_id] = json.loads(project.json())
        self._save()

        return project

    async def delete(self, project_id: str) -> bool:
        """Remove a project and persist; return True if it existed.

        NOTE(review): only this repository's record is removed; episodes or
        messages referring to the project are not touched here.
        """
        if project_id not in self._objects:
            return False
        del self._objects[project_id]
        self._data.pop(project_id, None)
        self._save()
        return True
|
|
|
|
|
|
class EpisodeRepository(JsonRepository):
    """Episode repository persisted to EPISODES_FILE.

    Keeps raw JSON dicts in ``self._data`` (what gets saved) and parsed
    ``Episode`` models in ``self._objects`` (what the async API returns), both
    keyed by episode id.
    """

    def __init__(self):
        super().__init__(EPISODES_FILE)
        self._objects: Dict[str, Episode] = {}
        for k, v in self._data.items():
            try:
                self._objects[k] = Episode.parse_obj(v)
            except Exception as e:
                # The raw entry stays in self._data and is re-saved unchanged.
                logger.error(f"Failed to parse episode {k}: {e}")

    def _ensure_id(self, episode: Episode) -> None:
        """Give ``episode`` a fresh UUID (in place) if it has no id yet."""
        if not episode.id:
            episode.id = str(uuid.uuid4())

    async def create(self, episode: Episode) -> Episode:
        """Store a new episode, persist, and return it.

        The passed object is mutated in place when it has no id.
        """
        self._ensure_id(episode)

        self._objects[episode.id] = episode
        self._data[episode.id] = json.loads(episode.json())
        self._save()

        logger.info(f"创建剧集: {episode.id} - EP{episode.number}")
        return episode

    async def get(self, episode_id: str) -> Optional[Episode]:
        """Return the episode with ``episode_id``, or None if unknown."""
        return self._objects.get(episode_id)

    async def list_by_project(
        self,
        project_id: str,
        skip: int = 0,
        limit: int = 100
    ) -> List[Episode]:
        """Return a project's episodes sorted by ``number``, paged by skip/limit."""
        episodes = [
            ep for ep in self._objects.values()
            if ep.projectId == project_id
        ]
        episodes.sort(key=lambda x: x.number)
        return episodes[skip:skip + limit]

    async def update(self, episode: Episode) -> Episode:
        """Store ``episode`` under its id (insert or overwrite) and persist.

        An episode without an id is given one, as in ``create``. Previously a
        ``None`` id became the dict key and was saved as the JSON key "null",
        which no later lookup could reach.
        """
        self._ensure_id(episode)

        self._objects[episode.id] = episode
        self._data[episode.id] = json.loads(episode.json())
        self._save()
        return episode
|
|
|
|
|
|
# ============================================
# Module-level repository singletons
# ============================================
# Built at import time, project store first; each constructor loads its JSON
# file from disk when that file exists.
project_repo: ProjectRepository = ProjectRepository()
episode_repo: EpisodeRepository = EpisodeRepository()
|
|
|
|
class MessageRepository(JsonRepository):
    """Per-project chat history persisted to MESSAGES_FILE.

    ``self._data`` maps a project id to a list of message dicts shaped
    ``{"role": ..., "content": ..., "timestamp": <ISO-8601 string>}``.
    """

    def __init__(self):
        super().__init__(MESSAGES_FILE)
        # Structure: {project_id: [{role, content, timestamp}, ...]}
        # Existing files may hold a history saved as the Python repr of the
        # list (a string) instead of a JSON array: the base class's save path
        # has stringified list values with str(). Calling .append() on such a
        # value raised AttributeError, so normalize every entry on load.
        for project_id, history in list(self._data.items()):
            self._data[project_id] = self._coerce_history(project_id, history)

    @staticmethod
    def _coerce_history(project_id, history):
        """Return ``history`` as a list, recovering repr-string histories.

        Strings are parsed with ``ast.literal_eval``, which only accepts Python
        literals and never executes code. Anything that still is not a list is
        logged and replaced with an empty history.
        """
        import ast

        if isinstance(history, list):
            return history
        if isinstance(history, str):
            try:
                parsed = ast.literal_eval(history)
            except (ValueError, TypeError, SyntaxError, MemoryError, RecursionError):
                parsed = None
            if isinstance(parsed, list):
                return parsed
        logger.error(
            f"Discarding unreadable message history for project {project_id}"
        )
        return []

    async def add_message(self, project_id: str, role: str, content: str):
        """Append one message, stamped with the current local time, and persist."""
        message = {
            "role": role,
            "content": content,
            "timestamp": datetime.now().isoformat()
        }
        self._data.setdefault(project_id, []).append(message)
        self._save()

    async def get_history(self, project_id: str) -> List[Dict]:
        """Return a shallow copy of a project's messages (empty if none).

        A copy is returned so callers cannot alter the stored history without
        going through add_message.
        """
        return list(self._data.get(project_id, []))


message_repo = MessageRepository()
|