""" 异步任务管理器 负责管理异步 AI 生成任务的创建、执行和状态跟踪 """ import asyncio import uuid from typing import Dict, Optional, List, Callable, Any from datetime import datetime from collections import defaultdict from app.models.task import AsyncTask, TaskStatus, TaskType, TaskProgress from app.utils.logger import get_logger logger = get_logger(__name__) class TaskManager: """ 异步任务管理器 功能: 1. 创建异步任务 2. 执行任务(在后台) 3. 跟踪任务状态和进度 4. 提供任务查询接口 """ def __init__(self): # 内存存储(生产环境应使用 Redis 或数据库) self._tasks: Dict[str, AsyncTask] = {} # 按项目ID索引的任务 self._project_tasks: Dict[str, List[str]] = defaultdict(list) # 按类型索引的任务 self._type_tasks: Dict[TaskType, List[str]] = defaultdict(list) def create_task( self, task_type: TaskType, params: Dict[str, Any], project_id: Optional[str] = None ) -> AsyncTask: """ 创建新任务 Args: task_type: 任务类型 params: 任务参数 project_id: 关联的项目ID Returns: 创建的任务 """ task_id = str(uuid.uuid4()) task = AsyncTask( id=task_id, type=task_type, params=params, project_id=project_id, status=TaskStatus.PENDING ) self._tasks[task_id] = task if project_id: self._project_tasks[project_id].append(task_id) self._type_tasks[task_type].append(task_id) logger.info(f"创建任务: {task_id} ({task_type.value})") return task def get_task(self, task_id: str) -> Optional[AsyncTask]: """获取任务""" return self._tasks.get(task_id) def get_tasks_by_project(self, project_id: str) -> List[AsyncTask]: """获取项目的所有任务""" task_ids = self._project_tasks.get(project_id, []) return [self._tasks[tid] for tid in task_ids if tid in self._tasks] def get_tasks_by_type(self, task_type: TaskType) -> List[AsyncTask]: """获取指定类型的所有任务""" task_ids = self._type_tasks.get(task_type, []) return [self._tasks[tid] for tid in task_ids if tid in self._tasks] def update_task_progress( self, task_id: str, current: int, total: int = 100, message: str = "", stage: Optional[str] = None ) -> bool: """ 更新任务进度 Args: task_id: 任务ID current: 当前进度 total: 总进度 message: 进度消息 stage: 当前阶段 Returns: 是否更新成功 """ task = self._tasks.get(task_id) if not task: return False task.progress = TaskProgress( current=current, total=total, message=message, stage=stage ) task.updated_at = datetime.now() logger.debug(f"任务进度更新: {task_id} - {current}/{total} - {message}") return True def update_task_status( self, task_id: str, status: TaskStatus, result: Optional[Dict[str, Any]] = None, error: Optional[str] = None ) -> bool: """ 更新任务状态 Args: task_id: 任务ID status: 新状态 result: 任务结果 error: 错误信息 Returns: 是否更新成功 """ task = self._tasks.get(task_id) if not task: return False old_status = task.status task.status = status task.updated_at = datetime.now() if result: task.result = result if error: task.error = error # 状态变更时间戳 if status == TaskStatus.RUNNING and old_status == TaskStatus.PENDING: task.started_at = datetime.now() elif status in [TaskStatus.COMPLETED, TaskStatus.FAILED, TaskStatus.CANCELLED]: task.completed_at = datetime.now() logger.info(f"任务状态更新: {task_id} - {old_status.value} -> {status.value}") return True def cancel_task(self, task_id: str) -> bool: """取消任务""" return self.update_task_status(task_id, TaskStatus.CANCELLED) def delete_task(self, task_id: str) -> bool: """删除任务""" task = self._tasks.get(task_id) if not task: return False # 从索引中移除 if task.project_id: self._project_tasks[task.project_id] = [ tid for tid in self._project_tasks[task.project_id] if tid != task_id ] self._type_tasks[task.type] = [ tid for tid in self._type_tasks[task.type] if tid != task_id ] del self._tasks[task_id] logger.info(f"删除任务: {task_id}") return True async def execute_task_async( self, task_id: str, executor: Callable ) -> Dict[str, Any]: """ 异步执行任务 Args: task_id: 任务ID executor: 执行器函数,接收 params,返回 result Returns: 任务执行结果 """ task = self._tasks.get(task_id) if not task: raise ValueError(f"任务不存在: {task_id}") # 更新状态为运行中 self.update_task_status(task_id, TaskStatus.RUNNING) try: # 执行任务 result = await executor(task.params) # 更新状态为完成 self.update_task_status(task_id, TaskStatus.COMPLETED, result=result) return result except Exception as e: logger.error(f"任务执行失败: {task_id} - {str(e)}") self.update_task_status(task_id, TaskStatus.FAILED, error=str(e)) raise def execute_task_in_background( self, task_id: str, executor: Callable ) -> asyncio.Task: """ 在后台执行任务 Args: task_id: 任务ID executor: 执行器函数 Returns: asyncio Task 对象 """ return asyncio.create_task(self.execute_task_async(task_id, executor)) # 全局单例 _task_manager: Optional[TaskManager] = None def get_task_manager() -> TaskManager: """获取任务管理器单例""" global _task_manager if _task_manager is None: _task_manager = TaskManager() return _task_manager