259 lines
6.7 KiB
Python
259 lines
6.7 KiB
Python
"""
异步任务管理器

负责管理异步 AI 生成任务的创建、执行和状态跟踪

Async task manager: creates, executes and tracks the status of
asynchronous AI generation tasks.
"""
|
||
import asyncio
|
||
import uuid
|
||
from typing import Dict, Optional, List, Callable, Any
|
||
from datetime import datetime
|
||
from collections import defaultdict
|
||
|
||
from app.models.task import AsyncTask, TaskStatus, TaskType, TaskProgress
|
||
from app.utils.logger import get_logger
|
||
|
||
# Module-level logger from the project's logging helper, keyed by this module's __name__.
logger = get_logger(__name__)
|
||
|
||
|
||
class TaskManager:
    """
    Asynchronous task manager.

    Responsibilities:
    1. Create asynchronous tasks.
    2. Execute tasks (directly, or in the background via asyncio).
    3. Track task status and progress.
    4. Provide task query methods.

    All state lives in process memory (plain dicts); it is not persisted and
    is not shared between processes.
    """

    # Statuses that mark a task as finished: completed_at is stamped when a
    # task enters one of them, and cancel_task() refuses to act on them.
    # A tuple (not a set) keeps membership based on ==, as the original list did.
    _TERMINAL_STATUSES = (
        TaskStatus.COMPLETED,
        TaskStatus.FAILED,
        TaskStatus.CANCELLED,
    )

    def __init__(self):
        # In-memory storage (production should use Redis or a database).
        self._tasks: Dict[str, AsyncTask] = {}
        # Task IDs indexed by project ID.
        self._project_tasks: Dict[str, List[str]] = defaultdict(list)
        # Task IDs indexed by task type.
        self._type_tasks: Dict[TaskType, List[str]] = defaultdict(list)
        # Background runs that have not finished yet, keyed by task ID.
        # The event loop holds only weak references to tasks, so this strong
        # reference keeps fire-and-forget runs from being garbage-collected
        # mid-execution, and lets cancel_task() stop them.
        self._running: Dict[str, asyncio.Task] = {}

    def create_task(
        self,
        task_type: TaskType,
        params: Dict[str, Any],
        project_id: Optional[str] = None,
    ) -> AsyncTask:
        """
        Create and register a new PENDING task.

        Args:
            task_type: Task type.
            params: Task parameters, later passed to the executor.
            project_id: Optional ID of the project the task belongs to.

        Returns:
            The created task.
        """
        task_id = str(uuid.uuid4())

        task = AsyncTask(
            id=task_id,
            type=task_type,
            params=params,
            project_id=project_id,
            status=TaskStatus.PENDING,
        )

        self._tasks[task_id] = task
        if project_id:
            self._project_tasks[project_id].append(task_id)
        self._type_tasks[task_type].append(task_id)

        logger.info(f"创建任务: {task_id} ({task_type.value})")
        return task

    def get_task(self, task_id: str) -> Optional[AsyncTask]:
        """Return the task with this ID, or None if it does not exist."""
        return self._tasks.get(task_id)

    def get_tasks_by_project(self, project_id: str) -> List[AsyncTask]:
        """Return all stored tasks belonging to a project."""
        return self._resolve(self._project_tasks.get(project_id, []))

    def get_tasks_by_type(self, task_type: TaskType) -> List[AsyncTask]:
        """Return all stored tasks of the given type."""
        return self._resolve(self._type_tasks.get(task_type, []))

    def _resolve(self, task_ids: List[str]) -> List[AsyncTask]:
        """Map task IDs to task objects, skipping IDs that are no longer stored."""
        return [self._tasks[tid] for tid in task_ids if tid in self._tasks]

    def update_task_progress(
        self,
        task_id: str,
        current: int,
        total: int = 100,
        message: str = "",
        stage: Optional[str] = None,
    ) -> bool:
        """
        Replace a task's progress record.

        Args:
            task_id: Task ID.
            current: Current progress value.
            total: Progress value that represents completion.
            message: Progress message.
            stage: Name of the current stage, if any.

        Returns:
            True if the task exists and was updated, False otherwise.
        """
        task = self._tasks.get(task_id)
        if not task:
            return False

        task.progress = TaskProgress(
            current=current,
            total=total,
            message=message,
            stage=stage,
        )
        task.updated_at = datetime.now()

        logger.debug(f"任务进度更新: {task_id} - {current}/{total} - {message}")
        return True

    def update_task_status(
        self,
        task_id: str,
        status: TaskStatus,
        result: Optional[Dict[str, Any]] = None,
        error: Optional[str] = None,
    ) -> bool:
        """
        Set a task's status and, optionally, its result or error.

        Stamps started_at on a PENDING -> RUNNING transition and completed_at
        when entering a terminal status.

        Args:
            task_id: Task ID.
            status: New status.
            result: Task result to store (ignored when None).
            error: Error message to store (ignored when None).

        Returns:
            True if the task exists and was updated, False otherwise.
        """
        task = self._tasks.get(task_id)
        if not task:
            return False

        old_status = task.status
        # One timestamp per update so updated_at/started_at/completed_at agree.
        now = datetime.now()
        task.status = status
        task.updated_at = now

        # Compare with None rather than truthiness so falsy-but-real values
        # (e.g. an empty dict result) are still recorded.
        if result is not None:
            task.result = result
        if error is not None:
            task.error = error

        if status == TaskStatus.RUNNING and old_status == TaskStatus.PENDING:
            task.started_at = now
        elif status in self._TERMINAL_STATUSES:
            task.completed_at = now

        logger.info(f"任务状态更新: {task_id} - {old_status.value} -> {status.value}")
        return True

    def cancel_task(self, task_id: str) -> bool:
        """
        Cancel a task that has not finished yet.

        Marks the task CANCELLED and, if it was started with
        execute_task_in_background(), also cancels the underlying asyncio task.

        Returns:
            True if the task was cancelled; False if it does not exist or has
            already completed, failed or been cancelled.
        """
        task = self._tasks.get(task_id)
        if not task or task.status in self._TERMINAL_STATUSES:
            # Never rewrite the outcome of a task that already finished.
            return False

        # Set the status before cancelling the asyncio task so that
        # execute_task_async sees CANCELLED and does not overwrite it.
        self.update_task_status(task_id, TaskStatus.CANCELLED)

        background = self._running.get(task_id)
        if background is not None and not background.done():
            background.cancel()
        return True

    def delete_task(self, task_id: str) -> bool:
        """
        Remove a task and its index entries.

        A background run of the task, if one is in progress, is not stopped.

        Returns:
            True if the task existed and was removed, False otherwise.
        """
        task = self._tasks.get(task_id)
        if not task:
            return False

        if task.project_id:
            self._unindex(self._project_tasks, task.project_id, task_id)
        self._unindex(self._type_tasks, task.type, task_id)

        del self._tasks[task_id]

        logger.info(f"删除任务: {task_id}")
        return True

    @staticmethod
    def _unindex(index: Dict[Any, List[str]], key: Any, task_id: str) -> None:
        """Remove task_id from index[key], dropping the key once it is empty."""
        remaining = [tid for tid in index.get(key, []) if tid != task_id]
        if remaining:
            index[key] = remaining
        else:
            # Avoid accumulating empty buckets for projects/types with no tasks.
            index.pop(key, None)

    async def execute_task_async(
        self,
        task_id: str,
        executor: Callable,
    ) -> Dict[str, Any]:
        """
        Run a task to completion and record its outcome.

        Args:
            task_id: Task ID.
            executor: Async callable that receives the task params and
                returns the result.

        Returns:
            The executor's result.

        Raises:
            ValueError: If the task does not exist.
            asyncio.CancelledError: If the run is cancelled (task is marked
                CANCELLED).
            Exception: Whatever the executor raised (task is marked FAILED).
        """
        task = self._tasks.get(task_id)
        if not task:
            raise ValueError(f"任务不存在: {task_id}")

        self.update_task_status(task_id, TaskStatus.RUNNING)

        try:
            result = await executor(task.params)

            # cancel_task() may have marked the task CANCELLED while the
            # executor was still running (e.g. one that ignores cancellation);
            # don't overwrite that with COMPLETED.
            if task.status != TaskStatus.CANCELLED:
                self.update_task_status(task_id, TaskStatus.COMPLETED, result=result)
            return result

        except asyncio.CancelledError:
            # CancelledError derives from BaseException (Python 3.8+), so the
            # generic handler below never sees it; without this branch a
            # cancelled task would stay RUNNING forever.
            if task.status != TaskStatus.CANCELLED:
                self.update_task_status(task_id, TaskStatus.CANCELLED)
            raise

        except Exception as e:
            # Exceptions raised without a message would otherwise record no
            # error at all; fall back to the exception type name.
            error_message = str(e) or type(e).__name__
            logger.error(f"任务执行失败: {task_id} - {error_message}")
            if task.status != TaskStatus.CANCELLED:
                self.update_task_status(task_id, TaskStatus.FAILED, error=error_message)
            raise

    def execute_task_in_background(
        self,
        task_id: str,
        executor: Callable,
    ) -> asyncio.Task:
        """
        Schedule execute_task_async on the running event loop.

        The manager keeps a strong reference to the scheduled asyncio task
        until it finishes, so callers may safely ignore the return value.
        Must be called while an event loop is running.

        Args:
            task_id: Task ID.
            executor: Async callable, as for execute_task_async.

        Returns:
            The scheduled asyncio Task.
        """
        background = asyncio.create_task(self.execute_task_async(task_id, executor))
        self._running[task_id] = background
        background.add_done_callback(
            lambda done, tid=task_id: self._on_background_done(tid, done)
        )
        return background

    def _on_background_done(self, task_id: str, background: asyncio.Task) -> None:
        """Drop the strong reference to a finished background run."""
        # Only remove the entry if it still refers to this run; a later call
        # for the same task ID may have replaced it.
        if self._running.get(task_id) is background:
            del self._running[task_id]
|
||
|
||
|
||
# Process-wide singleton instance; populated lazily by get_task_manager().
_task_manager: Optional[TaskManager] = None


def get_task_manager() -> TaskManager:
    """
    Return the shared TaskManager, constructing it on first use.

    Every call after the first returns the same instance, so all callers
    share one task store.
    """
    # NOTE(review): no lock guards the first construction; concurrent first
    # calls from multiple threads could each build an instance.
    global _task_manager
    if _task_manager is not None:
        return _task_manager
    _task_manager = TaskManager()
    return _task_manager
|