creative_studio/backend/app/core/task_manager.py

259 lines
6.7 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
异步任务管理器
负责管理异步 AI 生成任务的创建、执行和状态跟踪
"""
import asyncio
import uuid
from typing import Dict, Optional, List, Callable, Any
from datetime import datetime
from collections import defaultdict
from app.models.task import AsyncTask, TaskStatus, TaskType, TaskProgress
from app.utils.logger import get_logger
# Module-level logger from the project's logging helper, named after this module.
logger = get_logger(__name__)
class TaskManager:
    """
    Asynchronous task manager.

    Responsibilities:
    1. Create asynchronous tasks.
    2. Execute tasks (awaited directly or in the background).
    3. Track task status and progress.
    4. Provide task query interfaces.

    All state is kept in process memory (plain dicts), so it is not persisted
    and is not shared between processes.
    """

    # Statuses after which a task is considered finished.
    _TERMINAL_STATUSES = (
        TaskStatus.COMPLETED,
        TaskStatus.FAILED,
        TaskStatus.CANCELLED,
    )

    def __init__(self):
        # In-memory storage (production should use Redis or a database).
        self._tasks: Dict[str, AsyncTask] = {}
        # Task IDs indexed by project ID.
        self._project_tasks: Dict[str, List[str]] = defaultdict(list)
        # Task IDs indexed by task type.
        self._type_tasks: Dict[TaskType, List[str]] = defaultdict(list)
        # asyncio tasks started by execute_task_in_background(), keyed by task
        # ID. The event loop keeps only weak references to tasks, so holding
        # them here stops a running job from being garbage-collected
        # mid-execution; it also lets cancel_task() interrupt the job.
        self._running: Dict[str, asyncio.Task] = {}

    def create_task(
        self,
        task_type: TaskType,
        params: Dict[str, Any],
        project_id: Optional[str] = None,
    ) -> AsyncTask:
        """
        Create a new task in PENDING status and add it to the indexes.

        Args:
            task_type: Task type.
            params: Task parameters; passed to the executor when the task runs.
            project_id: ID of the associated project, if any.

        Returns:
            The created task.
        """
        task_id = str(uuid.uuid4())
        task = AsyncTask(
            id=task_id,
            type=task_type,
            params=params,
            project_id=project_id,
            status=TaskStatus.PENDING,
        )
        self._tasks[task_id] = task
        if project_id:
            self._project_tasks[project_id].append(task_id)
        self._type_tasks[task_type].append(task_id)
        logger.info(f"创建任务: {task_id} ({task_type.value})")
        return task

    def get_task(self, task_id: str) -> Optional[AsyncTask]:
        """Return the task with the given ID, or None if it does not exist."""
        return self._tasks.get(task_id)

    def get_tasks_by_project(self, project_id: str) -> List[AsyncTask]:
        """Return all existing tasks of a project, in creation order."""
        task_ids = self._project_tasks.get(project_id, [])
        return [self._tasks[tid] for tid in task_ids if tid in self._tasks]

    def get_tasks_by_type(self, task_type: TaskType) -> List[AsyncTask]:
        """Return all existing tasks of the given type, in creation order."""
        task_ids = self._type_tasks.get(task_type, [])
        return [self._tasks[tid] for tid in task_ids if tid in self._tasks]

    def update_task_progress(
        self,
        task_id: str,
        current: int,
        total: int = 100,
        message: str = "",
        stage: Optional[str] = None,
    ) -> bool:
        """
        Replace the task's progress record.

        Args:
            task_id: Task ID.
            current: Current progress value.
            total: Total progress value.
            message: Progress message.
            stage: Current stage name.

        Returns:
            True if the task exists and was updated, False otherwise.
        """
        task = self._tasks.get(task_id)
        if task is None:
            return False
        task.progress = TaskProgress(
            current=current,
            total=total,
            message=message,
            stage=stage,
        )
        task.updated_at = datetime.now()
        logger.debug(f"任务进度更新: {task_id} - {current}/{total} - {message}")
        return True

    def update_task_status(
        self,
        task_id: str,
        status: TaskStatus,
        result: Optional[Dict[str, Any]] = None,
        error: Optional[str] = None,
    ) -> bool:
        """
        Set a task's status and update its lifecycle timestamps.

        Args:
            task_id: Task ID.
            status: New status.
            result: Task result; stored when not None.
            error: Error message; stored when not None.

        Returns:
            True if the task exists and was updated, False otherwise.
        """
        task = self._tasks.get(task_id)
        if task is None:
            return False
        old_status = task.status
        now = datetime.now()
        task.status = status
        task.updated_at = now
        # Compare against None so falsy-but-valid values (e.g. an executor
        # returning an empty dict) are still recorded.
        if result is not None:
            task.result = result
        if error is not None:
            task.error = error
        # Lifecycle timestamps.
        if status == TaskStatus.RUNNING and old_status == TaskStatus.PENDING:
            task.started_at = now
        elif status in self._TERMINAL_STATUSES:
            task.completed_at = now
        logger.info(f"任务状态更新: {task_id} - {old_status.value} -> {status.value}")
        return True

    def cancel_task(self, task_id: str) -> bool:
        """
        Cancel a task.

        Marks the task CANCELLED and, if it was started with
        execute_task_in_background() and is still running, cancels the
        underlying asyncio task so the executor is interrupted.

        Returns:
            True if the task was cancelled; False if it does not exist or has
            already finished (COMPLETED / FAILED / CANCELLED), in which case its
            final status is left untouched.
        """
        task = self._tasks.get(task_id)
        if task is None or task.status in self._TERMINAL_STATUSES:
            return False
        # Record the status first so execute_task_async() sees CANCELLED when
        # the CancelledError reaches it and does not report it again.
        updated = self.update_task_status(task_id, TaskStatus.CANCELLED)
        running = self._running.get(task_id)
        if running is not None and not running.done():
            running.cancel()
        return updated

    def delete_task(self, task_id: str) -> bool:
        """
        Delete a task and remove it from the project/type indexes.

        Returns:
            True if the task existed and was deleted, False otherwise.
        """
        task = self._tasks.pop(task_id, None)
        if task is None:
            return False
        if task.project_id:
            self._remove_from_index(self._project_tasks, task.project_id, task_id)
        self._remove_from_index(self._type_tasks, task.type, task_id)
        logger.info(f"删除任务: {task_id}")
        return True

    @staticmethod
    def _remove_from_index(index: Dict[Any, List[str]], key: Any, task_id: str) -> None:
        """Remove task_id from index[key], dropping the key once its list is empty."""
        remaining = [tid for tid in index.get(key, []) if tid != task_id]
        if remaining:
            index[key] = remaining
        else:
            index.pop(key, None)

    async def execute_task_async(
        self,
        task_id: str,
        executor: Callable,
    ) -> Dict[str, Any]:
        """
        Run a task's executor and record the outcome on the task.

        Args:
            task_id: Task ID.
            executor: Async callable; awaited as ``await executor(task.params)``
                and expected to return the result dict.

        Returns:
            The value returned by the executor.

        Raises:
            ValueError: If the task does not exist.
            asyncio.CancelledError: If execution is cancelled; the task is
                marked CANCELLED.
            Exception: Whatever the executor raises; the task is marked FAILED.
        """
        task = self._tasks.get(task_id)
        if task is None:
            raise ValueError(f"任务不存在: {task_id}")
        # Mark the task as running.
        self.update_task_status(task_id, TaskStatus.RUNNING)
        try:
            result = await executor(task.params)
        except asyncio.CancelledError:
            # CancelledError derives from BaseException (Python 3.8+), so the
            # generic handler below would not record it.
            if task.status != TaskStatus.CANCELLED:
                self.update_task_status(task_id, TaskStatus.CANCELLED)
            raise
        except Exception as e:
            logger.error(f"任务执行失败: {task_id} - {str(e)}")
            # Fall back to the exception type so a failed task never carries
            # an empty error message.
            error_message = str(e) or type(e).__name__
            self.update_task_status(task_id, TaskStatus.FAILED, error=error_message)
            raise
        if task.status == TaskStatus.CANCELLED:
            # cancel_task() ran while the executor was working (e.g. a directly
            # awaited task that could not be interrupted); keep CANCELLED
            # instead of overwriting it with COMPLETED.
            logger.info(f"任务已取消, 忽略执行结果: {task_id}")
            return result
        # Mark the task as completed.
        self.update_task_status(task_id, TaskStatus.COMPLETED, result=result)
        return result

    def execute_task_in_background(
        self,
        task_id: str,
        executor: Callable,
    ) -> asyncio.Task:
        """
        Schedule execute_task_async() on the running event loop.

        The manager keeps a reference to the returned asyncio task until it
        finishes, so callers may discard it safely, and cancel_task() can
        interrupt it.

        Args:
            task_id: Task ID.
            executor: Async callable; see execute_task_async().

        Returns:
            The scheduled asyncio Task.
        """
        job = asyncio.create_task(self.execute_task_async(task_id, executor))
        self._running[task_id] = job

        def _forget(finished: asyncio.Task) -> None:
            # Only drop the entry if it still refers to this job; the same
            # task ID may have been re-submitted in the meantime.
            if self._running.get(task_id) is finished:
                del self._running[task_id]

        job.add_done_callback(_forget)
        return job
# Process-wide TaskManager instance, created lazily by get_task_manager().
_task_manager: Optional[TaskManager] = None


def get_task_manager() -> TaskManager:
    """Return the shared TaskManager, constructing it on the first call."""
    global _task_manager
    if _task_manager is not None:
        return _task_manager
    _task_manager = TaskManager()
    return _task_manager