""" 异步任务管理器 负责管理异步 AI 生成任务的创建、执行和状态跟踪 """ import asyncio import uuid from typing import Dict, Optional, List, Callable, Any from datetime import datetime from collections import defaultdict from app.models.task import AsyncTask, TaskStatus, TaskType, TaskProgress from app.utils.logger import get_logger from app.core.task_persistence import get_task_persistence logger = get_logger(__name__) class TaskManager: """ 异步任务管理器 功能: 1. 创建异步任务 2. 执行任务(在后台) 3. 跟踪任务状态和进度 4. 提供任务查询接口 """ def __init__(self): # 内存存储(生产环境应使用 Redis 或数据库) self._tasks: Dict[str, AsyncTask] = {} # 按项目ID索引的任务 self._project_tasks: Dict[str, List[str]] = defaultdict(list) # 按类型索引的任务 self._type_tasks: Dict[TaskType, List[str]] = defaultdict(list) # 持久化存储 self._persistence = get_task_persistence() # 启动时恢复持久化的任务 self._restore_tasks() def create_task( self, task_type: TaskType, params: Dict[str, Any], project_id: Optional[str] = None ) -> AsyncTask: """ 创建新任务 Args: task_type: 任务类型 params: 任务参数 project_id: 关联的项目ID Returns: 创建的任务 """ task_id = str(uuid.uuid4()) task = AsyncTask( id=task_id, type=task_type, params=params, project_id=project_id, status=TaskStatus.PENDING ) self._tasks[task_id] = task if project_id: self._project_tasks[project_id].append(task_id) self._type_tasks[task_type].append(task_id) # 保存到持久化存储 self._persistence.save_task(task) logger.info(f"创建任务: {task_id} ({task_type.value})") return task def get_task(self, task_id: str) -> Optional[AsyncTask]: """获取任务""" return self._tasks.get(task_id) def get_tasks_by_project(self, project_id: str) -> List[AsyncTask]: """获取项目的所有任务""" task_ids = self._project_tasks.get(project_id, []) return [self._tasks[tid] for tid in task_ids if tid in self._tasks] def get_tasks_by_type(self, task_type: TaskType) -> List[AsyncTask]: """获取指定类型的所有任务""" task_ids = self._type_tasks.get(task_type, []) return [self._tasks[tid] for tid in task_ids if tid in self._tasks] def update_task_progress( self, task_id: str, current: int, total: int = 100, message: str = "", stage: Optional[str] = None ) -> bool: """ 更新任务进度 Args: task_id: 任务ID current: 当前进度 total: 总进度 message: 进度消息 stage: 当前阶段 Returns: 是否更新成功 """ task = self._tasks.get(task_id) if not task: return False task.progress = TaskProgress( current=current, total=total, message=message, stage=stage ) task.updated_at = datetime.now() logger.debug(f"任务进度更新: {task_id} - {current}/{total} - {message}") return True def update_task_status( self, task_id: str, status: TaskStatus, result: Optional[Dict[str, Any]] = None, error: Optional[str] = None ) -> bool: """ 更新任务状态 Args: task_id: 任务ID status: 新状态 result: 任务结果 error: 错误信息 Returns: 是否更新成功 """ task = self._tasks.get(task_id) if not task: return False old_status = task.status task.status = status task.updated_at = datetime.now() if result: task.result = result if error: task.error = error # 状态变更时间戳 if status == TaskStatus.RUNNING and old_status == TaskStatus.PENDING: task.started_at = datetime.now() elif status in [TaskStatus.COMPLETED, TaskStatus.FAILED, TaskStatus.CANCELLED]: task.completed_at = datetime.now() # 更新持久化存储 self._persistence.update_task_status(task_id, status, result, error) logger.info(f"任务状态更新: {task_id} - {old_status.value} -> {status.value}") return True def cancel_task(self, task_id: str) -> bool: """取消任务""" return self.update_task_status(task_id, TaskStatus.CANCELLED) def delete_task(self, task_id: str) -> bool: """删除任务""" task = self._tasks.get(task_id) if not task: return False # 从索引中移除 if task.project_id: self._project_tasks[task.project_id] = [ tid for tid in self._project_tasks[task.project_id] if tid != task_id ] self._type_tasks[task.type] = [ tid for tid in self._type_tasks[task.type] if tid != task_id ] del self._tasks[task_id] # 从持久化存储中删除 self._persistence.delete_task(task_id) logger.info(f"删除任务: {task_id}") return True async def execute_task_async( self, task_id: str, executor: Callable ) -> Dict[str, Any]: """ 异步执行任务 Args: task_id: 任务ID executor: 执行器函数,接收 params,返回 result Returns: 任务执行结果 """ task = self._tasks.get(task_id) if not task: raise ValueError(f"任务不存在: {task_id}") # 更新状态为运行中 self.update_task_status(task_id, TaskStatus.RUNNING) try: # 执行任务 result = await executor(task.params) # 更新状态为完成 self.update_task_status(task_id, TaskStatus.COMPLETED, result=result) return result except Exception as e: logger.error(f"任务执行失败: {task_id} - {str(e)}") self.update_task_status(task_id, TaskStatus.FAILED, error=str(e)) raise def execute_task_in_background( self, task_id: str, executor: Callable ) -> asyncio.Task: """ 在后台执行任务 Args: task_id: 任务ID executor: 执行器函数 Returns: asyncio Task 对象 """ return asyncio.create_task(self.execute_task_async(task_id, executor)) def _restore_tasks(self) -> None: """从持久化存储恢复任务""" try: saved_tasks = self._persistence.get_all_tasks() restored_count = 0 for task_id, task_data in saved_tasks.items(): # 跳过已完成、失败或取消的任务(太久远的不需要恢复) if task_data.get('status') in ['completed', 'failed', 'cancelled']: continue # 恢复任务到内存 try: task = AsyncTask(**task_data) self._tasks[task_id] = task # 恢复索引 if task.project_id: self._project_tasks[task.project_id].append(task_id) self._type_tasks[task.type].append(task_id) restored_count += 1 except Exception as e: logger.warning(f"恢复任务失败: {task_id} - {e}") if restored_count > 0: logger.info(f"从持久化存储恢复了 {restored_count} 个任务") # 清理旧任务(超过24小时的已完成任务) self._persistence.cleanup_old_tasks(max_age_hours=24) except Exception as e: logger.error(f"恢复任务失败: {e}") # 全局单例 _task_manager: Optional[TaskManager] = None def get_task_manager() -> TaskManager: """获取任务管理器单例""" global _task_manager if _task_manager is None: _task_manager = TaskManager() return _task_manager