creative_studio/backend/app/core/task_persistence.py
2026-02-03 01:12:39 +08:00

170 lines
5.2 KiB
Python

"""
任务持久化存储
使用文件系统持久化任务状态,支持:
1. 服务重启后恢复任务
2. 前端刷新后可继续查看任务
"""
import json
import os
from datetime import datetime
from typing import Dict, List, Optional
from pathlib import Path
from app.models.task import AsyncTask, TaskStatus
from app.utils.logger import get_logger
# Module-level logger used by every method in this file.
logger = get_logger(__name__)
class TaskPersistence:
    """File-backed store for task state.

    All tasks live in a single JSON object (``task_id -> task dict``) in
    ``<storage_dir>/tasks.json``. Every operation re-reads the whole file, and
    mutating operations rewrite it atomically (temp file + replace).

    The task-access methods never raise: failures are logged and reported via
    the return value (``False`` / ``None`` / ``{}`` / ``0``).
    """

    # Statuses of tasks that are still in flight; cleanup never removes these.
    _ACTIVE_STATUSES = ('running', 'pending')

    def __init__(self, storage_dir: str = "data/tasks"):
        """Create the storage directory if it does not exist yet.

        Args:
            storage_dir: Directory that holds ``tasks.json``. A relative path is
                resolved against the process working directory.
        """
        self.storage_dir = Path(storage_dir)
        self.storage_dir.mkdir(parents=True, exist_ok=True)
        self.tasks_file = self.storage_dir / "tasks.json"

    def save_task(self, task: AsyncTask) -> bool:
        """Insert or overwrite a task, keyed by ``task.id``.

        Returns:
            True on success; False if the store could not be read or written
            (the error is logged).
        """
        try:
            tasks = self._load_all_tasks()
            tasks[task.id] = task.model_dump(mode='json')
            self._save_all_tasks(tasks)
            return True
        except Exception as e:
            logger.error(f"保存任务失败: {task.id} - {e}")
            return False

    def get_task(self, task_id: str) -> Optional[Dict]:
        """Return the stored dict for ``task_id``.

        Returns:
            The task dict, or None if it is absent or the store cannot be read.
        """
        try:
            tasks = self._load_all_tasks()
            return tasks.get(task_id)
        except Exception as e:
            logger.error(f"获取任务失败: {task_id} - {e}")
            return None

    def get_all_tasks(self) -> Dict[str, Dict]:
        """Return every stored task as ``{task_id: task_dict}``.

        Returns:
            The task map, or ``{}`` if the store cannot be read.
        """
        try:
            return self._load_all_tasks()
        except Exception as e:
            logger.error(f"获取所有任务失败: {e}")
            return {}

    def delete_task(self, task_id: str) -> bool:
        """Remove a task from the store.

        Returns:
            True if the task was removed or was not present; False on error.
        """
        try:
            tasks = self._load_all_tasks()
            if task_id in tasks:
                del tasks[task_id]
                self._save_all_tasks(tasks)
            return True
        except Exception as e:
            logger.error(f"删除任务失败: {task_id} - {e}")
            return False

    def update_task_status(
        self,
        task_id: str,
        status: TaskStatus,
        result: Optional[Dict] = None,
        error: Optional[str] = None
    ) -> bool:
        """Update a stored task's status and stamp ``updated_at``.

        Args:
            task_id: Task to update.
            status: New status; its ``.value`` is stored.
            result: If not None, replaces the stored ``result``.
            error: If not None, replaces the stored ``error``.

        Returns:
            True if the task existed and was saved; False if it is unknown or
            the store could not be read or written.
        """
        try:
            tasks = self._load_all_tasks()
            if task_id not in tasks:
                return False
            task_data = tasks[task_id]
            task_data['status'] = status.value
            # Naive local time, matching datetime.now() used by cleanup.
            task_data['updated_at'] = datetime.now().isoformat()
            if result is not None:
                task_data['result'] = result
            if error is not None:
                task_data['error'] = error
            self._save_all_tasks(tasks)
            return True
        except Exception as e:
            logger.error(f"更新任务状态失败: {task_id} - {e}")
            return False

    def cleanup_old_tasks(self, max_age_hours: int = 24) -> int:
        """Delete non-active tasks whose ``updated_at`` is older than the cutoff.

        Tasks whose status is ``running`` or ``pending`` are always kept, as are
        tasks with a missing or unparseable ``updated_at``.

        Args:
            max_age_hours: Minimum age, in hours, for a task to be deleted.

        Returns:
            Number of tasks deleted (0 on error).
        """
        try:
            tasks = self._load_all_tasks()
            now = datetime.now()
            to_delete = []
            for task_id, task_data in tasks.items():
                # Never remove tasks that are still in flight.
                if task_data.get('status') in self._ACTIVE_STATUSES:
                    continue
                update_time = self._parse_timestamp(task_data.get('updated_at'))
                if update_time is None:
                    continue
                age_hours = (now - update_time).total_seconds() / 3600
                if age_hours > max_age_hours:
                    to_delete.append(task_id)
            for task_id in to_delete:
                del tasks[task_id]
            if to_delete:
                self._save_all_tasks(tasks)
                logger.info(f"清理了 {len(to_delete)} 个旧任务")
            return len(to_delete)
        except Exception as e:
            logger.error(f"清理旧任务失败: {e}")
            return 0

    @staticmethod
    def _parse_timestamp(value) -> Optional[datetime]:
        """Parse an ISO-8601 string into a naive local datetime.

        Timezone-aware values (an explicit offset or a trailing ``Z``) are
        converted to local time and made naive, so they can be subtracted from
        the naive ``datetime.now()``. Without this, the subtraction raises
        TypeError.

        Returns:
            The parsed datetime, or None if ``value`` is not a non-empty string
            or cannot be parsed.
        """
        if not isinstance(value, str) or not value:
            return None
        if value.endswith('Z'):
            # datetime.fromisoformat only accepts a 'Z' suffix from Python 3.11.
            value = value[:-1] + '+00:00'
        try:
            parsed = datetime.fromisoformat(value)
            if parsed.tzinfo is not None:
                parsed = parsed.astimezone().replace(tzinfo=None)
        except (ValueError, OverflowError, OSError):
            return None
        return parsed

    def _load_all_tasks(self) -> Dict[str, Dict]:
        """Read the whole task map from disk.

        Returns ``{}`` if the file does not exist. If the file exists but does
        not hold a valid JSON object, it is renamed aside and ``{}`` is
        returned. This way, the next save starts a fresh file instead of
        overwriting data that could still be recovered by hand.

        Raises:
            OSError: If the file cannot be read, or the corrupt file cannot be
                renamed. Callers treat this as a failed operation; returning
                ``{}`` here would make the next save drop every other task.
        """
        try:
            with open(self.tasks_file, 'r', encoding='utf-8') as f:
                data = json.load(f)
        except FileNotFoundError:
            return {}
        except (json.JSONDecodeError, UnicodeDecodeError) as e:
            self._quarantine_corrupt_file(e)
            return {}
        except OSError as e:
            logger.error(f"加载任务文件失败: {e}")
            raise
        if not isinstance(data, dict):
            self._quarantine_corrupt_file(
                f"expected a JSON object, got {type(data).__name__}"
            )
            return {}
        return data

    def _quarantine_corrupt_file(self, reason) -> None:
        """Rename an undecodable tasks file to ``tasks.json.corrupt-<timestamp>``.

        The original bytes are kept for manual recovery. An OSError from the
        rename propagates, so the caller fails instead of overwriting the file.
        """
        stamp = datetime.now().strftime('%Y%m%d%H%M%S%f')
        backup = self.tasks_file.with_name(f"{self.tasks_file.name}.corrupt-{stamp}")
        self.tasks_file.replace(backup)
        logger.error(f"加载任务文件失败: {reason} - 已备份损坏文件到 {backup}")

    def _save_all_tasks(self, tasks: Dict[str, Dict]) -> None:
        """Atomically replace the tasks file with ``tasks``.

        The data is written to a sibling temp file and flushed to disk. The
        temp file is then renamed over ``tasks.json``, so a crash mid-write
        never leaves a truncated tasks file.

        Raises:
            Exception: Any serialization or I/O error. It is logged first, and
                the temp file is removed.
        """
        temp_file = self.tasks_file.with_suffix('.tmp')
        try:
            with open(temp_file, 'w', encoding='utf-8') as f:
                json.dump(tasks, f, ensure_ascii=False, indent=2)
                # Make sure the bytes are on disk before the rename publishes them.
                f.flush()
                os.fsync(f.fileno())
            temp_file.replace(self.tasks_file)
        except Exception as e:
            logger.error(f"保存任务文件失败: {e}")
            # Best-effort cleanup of a partially written temp file; the original
            # error is re-raised below either way.
            try:
                temp_file.unlink()
            except OSError:
                pass
            raise
# Process-wide TaskPersistence instance; created lazily by get_task_persistence().
_persistence: Optional[TaskPersistence] = None


def get_task_persistence() -> TaskPersistence:
    """Return the shared TaskPersistence, constructing it on the first call.

    The instance uses TaskPersistence's default storage directory.
    """
    global _persistence
    if _persistence is not None:
        return _persistence
    _persistence = TaskPersistence()
    return _persistence