# Python · 303 lines · 9.9 KiB
from datetime import datetime
|
||
from typing import List, Optional, Dict
|
||
import uuid
|
||
import json
|
||
import os
|
||
from pathlib import Path
|
||
|
||
from app.models.project import (
|
||
SeriesProject,
|
||
SeriesProjectCreate,
|
||
Episode
|
||
)
|
||
from app.utils.logger import get_logger
|
||
|
||
logger = get_logger(__name__)
|
||
|
||
# Data storage location.
#
# The directory is resolved once, at import time. Set the
# CREATIVE_STUDIO_DATA_DIR environment variable to relocate it; when the
# variable is unset or empty, the hard-coded absolute path that was introduced
# for debugging is still used, so existing setups keep reading the same files.
#
# Package-relative alternative kept for reference:
# BASE_DIR = Path(__file__).resolve().parent.parent.parent
# DATA_DIR = BASE_DIR / "data"
_DEFAULT_DATA_DIR = "d:/platform/creative_studio/backend/data"
DATA_DIR = Path(os.environ.get("CREATIVE_STUDIO_DATA_DIR") or _DEFAULT_DATA_DIR)
PROJECTS_FILE = DATA_DIR / "projects.json"
EPISODES_FILE = DATA_DIR / "episodes.json"
MESSAGES_FILE = DATA_DIR / "messages.json"

# Make sure the data directory exists. A failure is only logged here; it will
# surface again when a repository first tries to write its file.
if not DATA_DIR.exists():
    try:
        DATA_DIR.mkdir(parents=True, exist_ok=True)
        logger.info(f"Created data directory: {DATA_DIR}")
    except OSError as e:
        logger.error(f"Failed to create data directory {DATA_DIR}: {e}")

logger.info(f"Data directory: {DATA_DIR}")
logger.info(f"Projects file: {PROJECTS_FILE}")
|
||
|
||
class JsonRepository:
    """Base class for repositories persisted as a single JSON file.

    The whole file is read into ``self._data`` on construction and rewritten
    in full by every ``_save()`` call.
    """

    def __init__(self, file_path: Path):
        """Remember *file_path* and load whatever it currently contains.

        Args:
            file_path: Location of the JSON file backing this repository.
        """
        self.file_path = file_path
        self._data = {}
        self._load()

    def _load(self):
        """Populate ``self._data`` from the backing file.

        A missing, unreadable, malformed, or non-object file leaves the store
        empty; problems are logged rather than raised.
        """
        self._data = {}
        if not self.file_path.exists():
            return

        try:
            loaded = json.loads(self.file_path.read_text(encoding="utf-8"))
        except Exception as e:
            logger.error(f"Failed to load data from {self.file_path}: {e}")
            return

        # Everything downstream iterates ``self._data.items()``, so a JSON
        # array or scalar at the root cannot be used as the store.
        if not isinstance(loaded, dict):
            logger.error(
                f"Failed to load data from {self.file_path}: "
                f"expected a JSON object, got {type(loaded).__name__}"
            )
            return

        self._data = loaded

    @staticmethod
    def _to_serializable(value):
        """Return a JSON-compatible form of one top-level store value."""
        # Model objects (anything exposing ``dict``) are round-tripped
        # through their own ``json()`` encoder.
        if hasattr(value, "dict"):
            return json.loads(value.json())
        # JSON-native values are written as they are. Stringifying them
        # (e.g. 5 -> "5") would change their type on the next load.
        if value is None or isinstance(value, (dict, list, str, int, float, bool)):
            return value
        # Last resort for anything else.
        return str(value)

    def _save(self):
        """Write ``self._data`` to the backing file.

        The payload is written to a sibling temporary file that then replaces
        the target, so an interrupted write cannot truncate the existing file.

        Raises:
            Exception: Any serialization or I/O error, after it is logged.
        """
        try:
            serialized_data = {
                k: self._to_serializable(v) for k, v in self._data.items()
            }
            payload = json.dumps(serialized_data, ensure_ascii=False, indent=2)

            # The directory may be absent if its creation failed earlier.
            self.file_path.parent.mkdir(parents=True, exist_ok=True)

            tmp_path = self.file_path.with_name(self.file_path.name + ".tmp")
            tmp_path.write_text(payload, encoding="utf-8")
            os.replace(tmp_path, self.file_path)
        except Exception as e:
            logger.error(f"Failed to save data to {self.file_path}: {e}")
            # Re-raise so the failure stays visible to the caller.
            raise
|
||
|
||
class ProjectRepository(JsonRepository):
    """Project repository backed by the projects JSON file.

    ``self._data`` holds the raw JSON dictionaries and ``self._objects`` the
    parsed ``SeriesProject`` models, both keyed by project id.
    """

    # Exceptions ``ast.literal_eval`` is documented to raise on bad input.
    _LITERAL_ERRORS = (ValueError, TypeError, SyntaxError, MemoryError, RecursionError)

    def __init__(self):
        """Load all projects, repairing stringified records on the way.

        Records (or their direct fields) that were stored as Python-repr
        strings are converted back to structures; if anything was repaired
        the file is rewritten. Records that fail to parse are logged and
        left out of ``self._objects``.
        """
        super().__init__(PROJECTS_FILE)
        self._objects: Dict[str, SeriesProject] = {}

        import ast

        is_dirty = False
        for k, v in list(self._data.items()):
            try:
                # Repair a record that was stored as a string.
                if isinstance(v, str):
                    try:
                        v = ast.literal_eval(v)
                    except self._LITERAL_ERRORS:
                        pass
                    else:
                        self._data[k] = v
                        is_dirty = True

                # Repair direct fields stored as strings (for example
                # memory or globalContext). Only values that look like a
                # dict or list literal are attempted.
                if isinstance(v, dict):
                    for sub_k, sub_v in v.items():
                        if not (isinstance(sub_v, str) and sub_v.startswith(("{", "["))):
                            continue
                        try:
                            v[sub_k] = ast.literal_eval(sub_v)
                        except self._LITERAL_ERRORS:
                            continue
                        is_dirty = True

                self._objects[k] = SeriesProject.parse_obj(v)
            except Exception as e:
                logger.error(f"Failed to parse project {k}: {e}")

        if is_dirty:
            self._save()

    async def create(self, project_data: SeriesProjectCreate) -> SeriesProject:
        """Create and persist a new project.

        Args:
            project_data: Fields for the new project.

        Returns:
            The stored project, with a freshly generated UUID as its id.
        """
        project_id = str(uuid.uuid4())
        # One timestamp so createdAt and updatedAt are identical on creation.
        now = datetime.now()
        project = SeriesProject(
            id=project_id,
            name=project_data.name,
            totalEpisodes=project_data.totalEpisodes,
            agentId=project_data.agentId,
            mode=project_data.mode,
            globalContext=project_data.globalContext,
            skillSettings=project_data.skillSettings,
            createdAt=now,
            updatedAt=now,
        )
        self._objects[project_id] = project
        self._data[project_id] = json.loads(project.json())
        self._save()

        logger.info(f"创建项目: {project_id} - {project.name}")
        return project

    async def get(self, project_id: str) -> Optional[SeriesProject]:
        """Return the project with *project_id*, or ``None`` if unknown."""
        return self._objects.get(project_id)

    async def list(self, skip: int = 0, limit: int = 100) -> List[SeriesProject]:
        """Return one page of projects, newest first.

        Args:
            skip: Number of projects to skip from the start of the ordering.
            limit: Maximum number of projects to return.

        Returns:
            Projects sorted by ``createdAt`` descending; projects without a
            creation time sort last.
        """
        projects = sorted(
            self._objects.values(),
            key=lambda p: p.createdAt or datetime.min,
            reverse=True,
        )
        return projects[skip:skip + limit]

    async def update(
        self,
        project_id: str,
        project_data: dict
    ) -> Optional[SeriesProject]:
        """Apply *project_data* to an existing project and persist it.

        Args:
            project_id: Id of the project to change.
            project_data: Attribute name to new value. Names the project
                does not have are ignored, and so is ``id``.

        Returns:
            The updated project, or ``None`` if *project_id* is unknown.
        """
        project = self._objects.get(project_id)
        if not project:
            return None

        for key, value in project_data.items():
            # The id is also the storage key; letting it change would leave
            # the model and its key out of sync.
            if key == "id":
                continue
            if hasattr(project, key):
                setattr(project, key, value)

        project.updatedAt = datetime.now()

        # Refresh the raw record and write the file.
        self._data[project_id] = json.loads(project.json())
        self._save()

        return project

    async def delete(self, project_id: str) -> bool:
        """Delete a project.

        Returns:
            ``True`` if the project existed and was removed, else ``False``.
        """
        if project_id not in self._objects:
            return False

        del self._objects[project_id]
        self._data.pop(project_id, None)
        self._save()
        return True
|
||
|
||
|
||
class EpisodeRepository(JsonRepository):
    """Episode repository backed by the episodes JSON file.

    ``self._data`` holds the raw JSON dictionaries and ``self._objects`` the
    parsed ``Episode`` models, both keyed by episode id.
    """

    def __init__(self):
        """Load all episodes and drop duplicated (projectId, number) pairs.

        The first record seen for a pair is kept; later ones are removed from
        the store and the file is rewritten. Records that fail to parse are
        logged and left out of ``self._objects``.
        """
        super().__init__(EPISODES_FILE)
        self._objects: Dict[str, Episode] = {}
        # (projectId, number) pairs already loaded, used to detect duplicates.
        seen_episodes = set()
        duplicates_to_remove = []

        for k, v in list(self._data.items()):
            try:
                episode = Episode.parse_obj(v)
                key = (episode.projectId, episode.number)

                if key in seen_episodes:
                    logger.warning(f"Found duplicate episode: Project {episode.projectId}, EP{episode.number}. Removing ID {k}")
                    duplicates_to_remove.append(k)
                    continue

                seen_episodes.add(key)
                self._objects[k] = episode
            except Exception as e:
                logger.error(f"Failed to parse episode {k}: {e}")

        # Purge the duplicates and persist the cleaned store.
        if duplicates_to_remove:
            for k in duplicates_to_remove:
                del self._data[k]
            self._save()

    async def create(self, episode: Episode) -> Episode:
        """Store a new episode, generating an id when it has none.

        Args:
            episode: Episode to store; its ``id`` is filled in if empty.

        Returns:
            The same episode object.
        """
        if not episode.id:
            episode.id = str(uuid.uuid4())

        # Loading keeps only the first episode per (projectId, number), so a
        # clash created here would be dropped on the next start. Make that
        # visible instead of letting it happen silently.
        clash = any(
            other_id != episode.id
            and other.projectId == episode.projectId
            and other.number == episode.number
            for other_id, other in self._objects.items()
        )
        if clash:
            logger.warning(
                f"Episode {episode.id} duplicates Project {episode.projectId}, "
                f"EP{episode.number}; it will be removed on next load"
            )

        self._objects[episode.id] = episode
        self._data[episode.id] = json.loads(episode.json())
        self._save()

        logger.info(f"创建剧集: {episode.id} - EP{episode.number}")
        return episode

    async def get(self, episode_id: str) -> Optional[Episode]:
        """Return the episode with *episode_id*, or ``None`` if unknown."""
        return self._objects.get(episode_id)

    async def list_by_project(
        self,
        project_id: str,
        skip: int = 0,
        limit: int = 100
    ) -> List[Episode]:
        """Return one page of a project's episodes, ordered by number.

        Args:
            project_id: Project whose episodes are wanted.
            skip: Number of episodes to skip from the start of the ordering.
            limit: Maximum number of episodes to return.
        """
        episodes = sorted(
            (ep for ep in self._objects.values() if ep.projectId == project_id),
            key=lambda ep: ep.number,
        )
        return episodes[skip:skip + limit]

    async def update(self, episode: Episode) -> Episode:
        """Replace the stored episode that has ``episode.id`` and persist it.

        An id that is not stored yet is inserted.

        Raises:
            ValueError: If ``episode.id`` is empty; storing it would key the
                record under ``None``.
        """
        if not episode.id:
            raise ValueError("Cannot update an episode without an id")

        self._objects[episode.id] = episode
        self._data[episode.id] = json.loads(episode.json())
        self._save()
        return episode
|
||
|
||
|
||
# ============================================
# Global repository instances
# ============================================
# Created at import time; each constructor loads its JSON file.
project_repo = ProjectRepository()
episode_repo = EpisodeRepository()
|
||
|
||
class MessageRepository(JsonRepository):
    """Chat-message repository backed by the messages JSON file.

    ``self._data`` maps a project id to that project's list of message
    dictionaries.
    """

    def __init__(self):
        """Load the message store and normalise every history to a list.

        Histories that were stored as Python-repr strings are parsed back.
        Anything that cannot be recovered as a list is reset to an empty
        history, because ``add_message`` appends to it. The file is rewritten
        if any entry changed.
        """
        super().__init__(MESSAGES_FILE)

        import ast

        is_dirty = False
        for project_id in list(self._data):
            messages = self._data[project_id]
            if isinstance(messages, list):
                continue

            repaired = []
            if isinstance(messages, str):
                try:
                    # Try to parse a Python-repr formatted string.
                    candidate = ast.literal_eval(messages)
                except Exception as e:
                    logger.warning(f"Failed to fix message history for {project_id}: {e}")
                else:
                    if isinstance(candidate, list):
                        repaired = candidate
                        logger.info(f"Fixed corrupted message history for project {project_id}")
                    else:
                        logger.warning(
                            f"Failed to fix message history for {project_id}: "
                            f"parsed value is {type(candidate).__name__}, not a list"
                        )
            else:
                logger.warning(
                    f"Failed to fix message history for {project_id}: "
                    f"stored value is {type(messages).__name__}, not a list"
                )

            self._data[project_id] = repaired
            is_dirty = True

        if is_dirty:
            self._save()

    async def add_message(self, project_id: str, role: str, content: str):
        """Append a message to a project's history and persist the store.

        Args:
            project_id: Project the message belongs to.
            role: Stored under the ``"role"`` key.
            content: Stored under the ``"content"`` key.

        The message also gets a ``"timestamp"`` key holding the current local
        time in ISO 8601 format.
        """
        message = {
            "role": role,
            "content": content,
            "timestamp": datetime.now().isoformat(),
        }
        self._data.setdefault(project_id, []).append(message)
        self._save()

    async def get_history(self, project_id: str) -> List[Dict]:
        """Return a project's message list, or an empty list if it has none.

        The returned list is the stored one, not a copy.
        """
        return self._data.get(project_id, [])
|
||
|
||
# Global message repository instance, created at import time.
message_repo = MessageRepository()
|