creative_studio/backend/app/core/task_manager.py
2026-02-03 01:12:39 +08:00

308 lines
8.4 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
异步任务管理器
负责管理异步 AI 生成任务的创建、执行和状态跟踪
"""
import asyncio
import uuid
from typing import Dict, Optional, List, Callable, Any
from datetime import datetime
from collections import defaultdict
from app.models.task import AsyncTask, TaskStatus, TaskType, TaskProgress
from app.utils.logger import get_logger
from app.core.task_persistence import get_task_persistence
logger = get_logger(__name__)
class TaskManager:
"""
异步任务管理器
功能:
1. 创建异步任务
2. 执行任务(在后台)
3. 跟踪任务状态和进度
4. 提供任务查询接口
"""
def __init__(self):
# 内存存储(生产环境应使用 Redis 或数据库)
self._tasks: Dict[str, AsyncTask] = {}
# 按项目ID索引的任务
self._project_tasks: Dict[str, List[str]] = defaultdict(list)
# 按类型索引的任务
self._type_tasks: Dict[TaskType, List[str]] = defaultdict(list)
# 持久化存储
self._persistence = get_task_persistence()
# 启动时恢复持久化的任务
self._restore_tasks()
def create_task(
self,
task_type: TaskType,
params: Dict[str, Any],
project_id: Optional[str] = None
) -> AsyncTask:
"""
创建新任务
Args:
task_type: 任务类型
params: 任务参数
project_id: 关联的项目ID
Returns:
创建的任务
"""
task_id = str(uuid.uuid4())
task = AsyncTask(
id=task_id,
type=task_type,
params=params,
project_id=project_id,
status=TaskStatus.PENDING
)
self._tasks[task_id] = task
if project_id:
self._project_tasks[project_id].append(task_id)
self._type_tasks[task_type].append(task_id)
# 保存到持久化存储
self._persistence.save_task(task)
logger.info(f"创建任务: {task_id} ({task_type.value})")
return task
def get_task(self, task_id: str) -> Optional[AsyncTask]:
"""获取任务"""
return self._tasks.get(task_id)
def get_tasks_by_project(self, project_id: str) -> List[AsyncTask]:
"""获取项目的所有任务"""
task_ids = self._project_tasks.get(project_id, [])
return [self._tasks[tid] for tid in task_ids if tid in self._tasks]
def get_tasks_by_type(self, task_type: TaskType) -> List[AsyncTask]:
"""获取指定类型的所有任务"""
task_ids = self._type_tasks.get(task_type, [])
return [self._tasks[tid] for tid in task_ids if tid in self._tasks]
def update_task_progress(
self,
task_id: str,
current: int,
total: int = 100,
message: str = "",
stage: Optional[str] = None
) -> bool:
"""
更新任务进度
Args:
task_id: 任务ID
current: 当前进度
total: 总进度
message: 进度消息
stage: 当前阶段
Returns:
是否更新成功
"""
task = self._tasks.get(task_id)
if not task:
return False
task.progress = TaskProgress(
current=current,
total=total,
message=message,
stage=stage
)
task.updated_at = datetime.now()
logger.debug(f"任务进度更新: {task_id} - {current}/{total} - {message}")
return True
def update_task_status(
self,
task_id: str,
status: TaskStatus,
result: Optional[Dict[str, Any]] = None,
error: Optional[str] = None
) -> bool:
"""
更新任务状态
Args:
task_id: 任务ID
status: 新状态
result: 任务结果
error: 错误信息
Returns:
是否更新成功
"""
task = self._tasks.get(task_id)
if not task:
return False
old_status = task.status
task.status = status
task.updated_at = datetime.now()
if result:
task.result = result
if error:
task.error = error
# 状态变更时间戳
if status == TaskStatus.RUNNING and old_status == TaskStatus.PENDING:
task.started_at = datetime.now()
elif status in [TaskStatus.COMPLETED, TaskStatus.FAILED, TaskStatus.CANCELLED]:
task.completed_at = datetime.now()
# 更新持久化存储
self._persistence.update_task_status(task_id, status, result, error)
logger.info(f"任务状态更新: {task_id} - {old_status.value} -> {status.value}")
return True
def cancel_task(self, task_id: str) -> bool:
"""取消任务"""
return self.update_task_status(task_id, TaskStatus.CANCELLED)
def delete_task(self, task_id: str) -> bool:
"""删除任务"""
task = self._tasks.get(task_id)
if not task:
return False
# 从索引中移除
if task.project_id:
self._project_tasks[task.project_id] = [
tid for tid in self._project_tasks[task.project_id] if tid != task_id
]
self._type_tasks[task.type] = [
tid for tid in self._type_tasks[task.type] if tid != task_id
]
del self._tasks[task_id]
# 从持久化存储中删除
self._persistence.delete_task(task_id)
logger.info(f"删除任务: {task_id}")
return True
async def execute_task_async(
self,
task_id: str,
executor: Callable
) -> Dict[str, Any]:
"""
异步执行任务
Args:
task_id: 任务ID
executor: 执行器函数,接收 params返回 result
Returns:
任务执行结果
"""
task = self._tasks.get(task_id)
if not task:
raise ValueError(f"任务不存在: {task_id}")
# 更新状态为运行中
self.update_task_status(task_id, TaskStatus.RUNNING)
try:
# 执行任务
result = await executor(task.params)
# 更新状态为完成
self.update_task_status(task_id, TaskStatus.COMPLETED, result=result)
return result
except Exception as e:
logger.error(f"任务执行失败: {task_id} - {str(e)}")
self.update_task_status(task_id, TaskStatus.FAILED, error=str(e))
raise
def execute_task_in_background(
self,
task_id: str,
executor: Callable
) -> asyncio.Task:
"""
在后台执行任务
Args:
task_id: 任务ID
executor: 执行器函数
Returns:
asyncio Task 对象
"""
return asyncio.create_task(self.execute_task_async(task_id, executor))
def _restore_tasks(self) -> None:
"""从持久化存储恢复任务"""
try:
saved_tasks = self._persistence.get_all_tasks()
restored_count = 0
for task_id, task_data in saved_tasks.items():
# 跳过已完成、失败或取消的任务(太久远的不需要恢复)
if task_data.get('status') in ['completed', 'failed', 'cancelled']:
continue
# 恢复任务到内存
try:
task = AsyncTask(**task_data)
self._tasks[task_id] = task
# 恢复索引
if task.project_id:
self._project_tasks[task.project_id].append(task_id)
self._type_tasks[task.type].append(task_id)
restored_count += 1
except Exception as e:
logger.warning(f"恢复任务失败: {task_id} - {e}")
if restored_count > 0:
logger.info(f"从持久化存储恢复了 {restored_count} 个任务")
# 清理旧任务超过24小时的已完成任务
self._persistence.cleanup_old_tasks(max_age_hours=24)
except Exception as e:
logger.error(f"恢复任务失败: {e}")
# 全局单例
_task_manager: Optional[TaskManager] = None
def get_task_manager() -> TaskManager:
"""获取任务管理器单例"""
global _task_manager
if _task_manager is None:
_task_manager = TaskManager()
return _task_manager