478 lines
21 KiB
Python
478 lines
21 KiB
Python
import threading
|
||
import time
|
||
from datetime import datetime, timedelta
|
||
from collections import deque
|
||
from typing import Dict, List, Any, Optional
|
||
import json
|
||
import os
|
||
from video_service import get_video_service
|
||
import logging
|
||
|
||
# Module-level logger named after this module; used by the queue manager below.
logger = logging.getLogger(__name__)
|
||
|
||
class TaskQueueManager:
    """In-memory task queue manager backed by a polling thread.

    State (all guarded by ``self._lock``):

    * ``running_tasks_cache`` -- task_id -> task dict for tasks occupying one
      of the ``max_running_tasks`` slots;
    * ``completed_tasks_cache`` -- task_id -> task dict for finished tasks,
      trimmed by age (``completed_cache_ttl_hours``) and by count
      (``max_completed_cache_size``);
    * ``waiting_queue`` -- FIFO of task dicts waiting for a free slot; it is
      mirrored to ``persistence_file`` as JSON so it can be restored on start.

    A daemon thread polls the video service every ``update_interval`` seconds,
    moves finished tasks to the completed cache and promotes waiting tasks
    into free running slots.
    """

    # Statuses that keep a task in the running cache; every other status
    # (succeeded / failed / cancelled / ...) counts as finished.
    _ACTIVE_STATUSES = frozenset({'queued', 'running'})

    def __init__(self, max_running_tasks: Optional[int] = None, update_interval: Optional[int] = None,
                 persistence_file: Optional[str] = None):
        """Create the manager; nothing is loaded or started until ``start()``.

        Args:
            max_running_tasks: Maximum number of tasks in the running cache.
                Falsy values fall back to env ``QUEUE_MAX_RUNNING_TASKS`` (default 5).
            update_interval: Seconds between polling rounds. Falsy values fall
                back to env ``QUEUE_UPDATE_INTERVAL`` (default 5).
            persistence_file: JSON file holding the waiting queue. Falsy values
                fall back to env ``QUEUE_PERSISTENCE_FILE``
                (default ``task_queue_persistence.json``).
        """
        self.max_running_tasks = max_running_tasks or int(os.environ.get('QUEUE_MAX_RUNNING_TASKS', '5'))
        self.update_interval = update_interval or int(os.environ.get('QUEUE_UPDATE_INTERVAL', '5'))
        self.persistence_file = persistence_file or os.environ.get('QUEUE_PERSISTENCE_FILE', 'task_queue_persistence.json')

        # Running tasks (task_id -> task data).
        self.running_tasks_cache: Dict[str, Dict[str, Any]] = {}
        # Recently completed tasks (task_id -> task data).
        self.completed_tasks_cache: Dict[str, Dict[str, Any]] = {}
        # FIFO of task requests waiting for a free running slot.
        self.waiting_queue: deque = deque()

        # Re-entrant on purpose: _save_persistence_data() takes the lock itself,
        # and a plain Lock would deadlock if it were called while already held.
        self._lock = threading.RLock()

        # Background polling thread and the event used to stop it.
        self._update_thread: Optional[threading.Thread] = None
        self._stop_event = threading.Event()

        self.video_service = get_video_service()

        # Completed-cache limits (from the environment).
        self.max_completed_cache_size = int(os.environ.get('QUEUE_MAX_COMPLETED_CACHE_SIZE', '100'))
        self.completed_cache_ttl_hours = int(os.environ.get('QUEUE_COMPLETED_CACHE_TTL_HOURS', '24'))
        # How many tasks may wait beyond the running slots (see can_create_task).
        self.max_waiting_tasks = int(os.environ.get('QUEUE_MAX_WAITING_TASKS', '5'))

    def start(self):
        """Restore persisted/remote state and launch the polling thread.

        Calling ``start`` while the polling thread is alive does nothing, so the
        manager never runs two pollers. A stopped manager can be started again.
        """
        if self._update_thread is not None and self._update_thread.is_alive():
            logger.warning("任务队列管理器已在运行,忽略重复启动")
            return

        logger.info("启动任务队列管理器")

        # A previous stop() leaves the event set; clear it or the new loop exits at once.
        self._stop_event.clear()

        # Restore the waiting queue from the persistence file.
        self._load_persistence_data()

        # Restore running/completed caches from the video service.
        self._load_initial_tasks()

        self._start_update_thread()

    def stop(self):
        """Stop the polling thread, then persist the final waiting queue."""
        logger.info("停止任务队列管理器")

        self._stop_event.set()
        if self._update_thread and self._update_thread.is_alive():
            self._update_thread.join()

        # Save only after the worker has exited, so the file reflects the last
        # state of the queue rather than a snapshot the worker may still change.
        self._save_persistence_data()

    def _load_persistence_data(self):
        """Replace the waiting queue with the one stored in ``persistence_file``.

        Entries that are not dicts carrying a ``task_id`` are dropped, because
        every consumer of the queue indexes ``task['task_id']``. Errors are logged,
        never raised.
        """
        try:
            if not os.path.exists(self.persistence_file):
                logger.info("持久化文件不存在,跳过恢复")
                return

            with open(self.persistence_file, 'r', encoding='utf-8') as f:
                data = json.load(f)

            entries = (data.get('waiting_queue') if isinstance(data, dict) else None) or []
            valid_entries = [entry for entry in entries if isinstance(entry, dict) and 'task_id' in entry]
            skipped = len(entries) - len(valid_entries)
            if skipped:
                logger.warning(f"持久化文件中有 {skipped} 个无效等待任务已忽略")

            with self._lock:
                self.waiting_queue = deque(valid_entries)
                restored = len(self.waiting_queue)

            logger.info(f"从持久化文件恢复了 {restored} 个等待任务")

        except Exception as e:
            logger.error(f"加载持久化数据异常: {str(e)}")

    def _save_persistence_data(self):
        """Write the waiting queue to ``persistence_file`` atomically.

        The payload is serialized before any file is touched and written to a
        sibling ``.tmp`` file that then replaces the target. A serialization
        error or crash therefore never leaves a truncated file behind. The whole
        operation runs under the (re-entrant) lock so concurrent saves cannot
        interleave or overwrite a newer snapshot with an older one. Errors are
        logged, never raised.
        """
        try:
            with self._lock:
                snapshot = list(self.waiting_queue)
                payload = json.dumps(
                    {
                        'waiting_queue': snapshot,
                        'timestamp': datetime.now().isoformat(),
                    },
                    ensure_ascii=False,
                    indent=2,
                )

                tmp_path = f"{self.persistence_file}.tmp"
                with open(tmp_path, 'w', encoding='utf-8') as f:
                    f.write(payload)
                os.replace(tmp_path, self.persistence_file)

            logger.info(f"保存了 {len(snapshot)} 个等待任务到持久化文件")

        except Exception as e:
            logger.error(f"保存持久化数据异常: {str(e)}")

    def _load_initial_tasks(self):
        """Rebuild the running/completed caches from the video service's task list.

        Pages through ``get_task_list`` (50 tasks per page, at most 10 pages).
        Tasks in an active status go to the running cache. Other tasks fill the
        completed cache up to ``max_completed_cache_size``. Paging stops early once
        the completed cache is full and at least one running task was found.
        Errors are logged, never raised.
        """
        try:
            logger.info("开始从SDK恢复任务缓存...")

            page_size = 50
            max_pages = 10  # hard cap so a misbehaving backend cannot loop forever
            running_tasks_loaded = 0
            completed_tasks_loaded = 0

            for page in range(max_pages):
                offset = page * page_size
                result = self.video_service.get_task_list(limit=page_size, offset=offset)
                if not result['success']:
                    logger.error(f"加载任务失败(offset={offset}): {result.get('error', '未知错误')}")
                    break

                tasks = result['data']['tasks']
                if not tasks:
                    break

                with self._lock:
                    for task in tasks:
                        task_id = task['task_id']
                        task['cache_time'] = datetime.now().isoformat()

                        if task['status'] in self._ACTIVE_STATUSES:
                            self.running_tasks_cache[task_id] = task
                            running_tasks_loaded += 1
                            logger.debug(f"恢复运行中任务: {task_id}")
                        elif completed_tasks_loaded < self.max_completed_cache_size:
                            # Keep only the first max_completed_cache_size finished tasks.
                            self.completed_tasks_cache[task_id] = task
                            completed_tasks_loaded += 1
                            logger.debug(f"恢复已完成任务: {task_id}")

                if completed_tasks_loaded >= self.max_completed_cache_size and running_tasks_loaded > 0:
                    break

            logger.info(f"缓存恢复完成: {running_tasks_loaded} 个运行中任务, {completed_tasks_loaded} 个已完成任务")

        except Exception as e:
            logger.error(f"恢复任务缓存异常: {str(e)}")

    def _start_update_thread(self):
        """Spawn the daemon thread that runs ``_update_loop``."""
        self._update_thread = threading.Thread(
            target=self._update_loop, name='TaskQueueUpdater', daemon=True
        )
        self._update_thread.start()

    def _update_loop(self):
        """Poll task statuses and promote waiting tasks until stopped."""
        logger.info(f"任务状态更新线程已启动,更新间隔: {self.update_interval}秒")

        while not self._stop_event.is_set():
            try:
                with self._lock:
                    running_count = len(self.running_tasks_cache)
                if running_count > 0:
                    logger.info(f"开始更新任务状态,当前运行中任务数: {running_count}")

                self._update_task_statuses()
                self._process_waiting_queue()

                if running_count > 0:
                    with self._lock:
                        new_running_count = len(self.running_tasks_cache)
                    logger.info(f"任务状态更新完成,运行中任务数: {new_running_count}")

            except Exception as e:
                logger.error(f"更新任务状态异常: {str(e)}")

            # Sleep for the interval, waking immediately if stop() is called.
            self._stop_event.wait(self.update_interval)

    def _update_task_statuses(self):
        """Query the status of every running task, then trim the completed cache.

        Network calls happen without holding the lock. A snapshot of the running
        tasks is taken first, and each result is applied under the lock.
        """
        try:
            with self._lock:
                tasks_to_update = list(self.running_tasks_cache.values())

            if not tasks_to_update:
                return

            logger.info(f"开始查询 {len(tasks_to_update)} 个运行中任务的状态")

            for task in tasks_to_update:
                if self._stop_event.is_set():
                    # Let stop() return promptly instead of finishing the whole round.
                    break
                self._refresh_task(task)

            self._cleanup_completed_tasks()

        except Exception as e:
            logger.error(f"更新任务状态时发生错误: {str(e)}")

    def _refresh_task(self, task: Dict[str, Any]):
        """Poll one running task and move it to the completed cache if it finished."""
        task_id = None  # bound up-front so the error log below never hits NameError
        try:
            task_id = task.get('task_id') or task.get('id')
            if not task_id:
                return

            logger.info(f"查询任务 {task_id} 的状态")
            status_result = self.video_service.get_task_status(task_id)

            if not status_result['success']:
                logger.warning(f"查询任务 {task_id} 状态失败: {status_result.get('error', '未知错误')}")
                return

            updated_task = status_result['data']
            task_status = updated_task.get('status', 'running')
            logger.info(f"任务 {task_id} 当前状态: {task_status}")

            with self._lock:
                cached = self.running_tasks_cache.get(task_id)
                if cached is None:
                    # Removed concurrently (e.g. remove_task_from_cache); nothing to do.
                    return

                cached.update(updated_task)

                if task_status in self._ACTIVE_STATUSES:
                    logger.info(f"任务 {task_id} 仍在运行中,状态: {task_status}")
                    return

                completed_task = self.running_tasks_cache.pop(task_id)
                now_iso = datetime.now().isoformat()
                completed_task['completed_at'] = now_iso
                # Restart the TTL clock at completion; otherwise a task that ran
                # longer than the TTL would be purged on the very next cleanup.
                completed_task['cache_time'] = now_iso
                self.completed_tasks_cache[task_id] = completed_task

            logger.info(f"任务 {task_id} 状态更新为: {task_status},已移动到完成缓存")

        except Exception as e:
            logger.error(f"更新任务 {task_id} 状态时发生错误: {str(e)}")

    def _cleanup_completed_tasks(self):
        """Evict completed tasks that are expired, undated, or beyond the size cap.

        Expiry uses each task's ``cache_time`` against ``completed_cache_ttl_hours``.
        Tasks with a missing or unparseable/incomparable timestamp are evicted.
        Of the rest, only the newest ``max_completed_cache_size`` are kept.
        Errors are logged, never raised.
        """
        try:
            with self._lock:
                if not self.completed_tasks_cache:
                    return

                cutoff_time = datetime.now() - timedelta(hours=self.completed_cache_ttl_hours)
                tasks_to_remove = set()
                valid_tasks = []  # (cache_time, task_id) for still-fresh tasks

                for task_id, task_data in self.completed_tasks_cache.items():
                    cache_time_str = task_data.get('cache_time')
                    if not cache_time_str:
                        tasks_to_remove.add(task_id)
                        continue
                    try:
                        cache_time = datetime.fromisoformat(cache_time_str)
                        if cache_time < cutoff_time:
                            tasks_to_remove.add(task_id)
                        else:
                            valid_tasks.append((cache_time, task_id))
                    except (ValueError, TypeError):
                        # Bad format, or an aware/naive mismatch in the comparison.
                        tasks_to_remove.add(task_id)

                if len(valid_tasks) > self.max_completed_cache_size:
                    # Newest first; everything past the cap is evicted.
                    valid_tasks.sort(key=lambda item: item[0], reverse=True)
                    tasks_to_remove.update(
                        task_id for _, task_id in valid_tasks[self.max_completed_cache_size:]
                    )

                for task_id in tasks_to_remove:
                    self.completed_tasks_cache.pop(task_id, None)
                remaining = len(self.completed_tasks_cache)

            if tasks_to_remove:
                logger.info(f"清理了 {len(tasks_to_remove)} 个已完成任务,当前缓存大小: {remaining}")

        except Exception as e:
            # Best-effort: a failed cleanup must not stop the polling loop.
            logger.error(f"清理已完成任务缓存时发生错误: {str(e)}")

    def _process_waiting_queue(self):
        """Promote waiting tasks (FIFO) into free running slots and persist the queue."""
        try:
            promoted = []
            with self._lock:
                while self.waiting_queue and len(self.running_tasks_cache) < self.max_running_tasks:
                    waiting_task = self.waiting_queue.popleft()
                    task_id = waiting_task['task_id']
                    waiting_task['cache_time'] = datetime.now().isoformat()
                    self.running_tasks_cache[task_id] = waiting_task
                    promoted.append(task_id)

            for task_id in promoted:
                logger.info(f"等待任务开始执行: {task_id}")

            # Persist once per batch, and only when the queue actually changed.
            if promoted:
                self._save_persistence_data()

        except Exception as e:
            # Best-effort: errors here must not stop the polling loop.
            logger.error(f"处理等待队列时发生错误: {str(e)}")

    def can_create_task(self) -> bool:
        """Return True while running + waiting tasks stay below the total capacity.

        Capacity is ``max_running_tasks + max_waiting_tasks`` (the waiting part
        defaults to 5 via ``QUEUE_MAX_WAITING_TASKS``).
        """
        with self._lock:
            total_tasks = len(self.running_tasks_cache) + len(self.waiting_queue)
        return total_tasks < self.max_running_tasks + self.max_waiting_tasks

    def add_task_to_queue(self, task_data: Dict[str, Any]) -> bool:
        """Add a task, running it immediately if a slot is free.

        ``task_data`` must contain ``task_id``; its ``cache_time`` is set in place.

        Args:
            task_data: Task data.

        Returns:
            True if the task went straight into the running cache, False if it
            was appended to the waiting queue (which is then persisted).
        """
        task_id = task_data['task_id']
        task_data['cache_time'] = datetime.now().isoformat()

        with self._lock:
            if len(self.running_tasks_cache) < self.max_running_tasks:
                self.running_tasks_cache[task_id] = task_data
                started = True
            else:
                self.waiting_queue.append(task_data)
                queue_length = len(self.waiting_queue)
                started = False

        if started:
            logger.info(f"任务直接加入运行队列: {task_id}")
            return True

        logger.info(f"任务加入等待队列: {task_id}, 等待队列长度: {queue_length}")
        # The waiting queue changed; persist it after releasing the lock.
        self._save_persistence_data()
        return False

    def get_task_from_cache(self, task_id: str) -> Optional[Dict[str, Any]]:
        """Look a task up in the running cache, completed cache, then waiting queue.

        Returns the stored dict itself (not a copy), or None if not found.
        """
        with self._lock:
            task = self.running_tasks_cache.get(task_id)
            if task is None:
                task = self.completed_tasks_cache.get(task_id)
            if task is None:
                task = next((t for t in self.waiting_queue if t['task_id'] == task_id), None)
            return task

    def get_task_by_id(self, task_id: str) -> Optional[Dict[str, Any]]:
        """Alias of ``get_task_from_cache`` kept for compatibility."""
        return self.get_task_from_cache(task_id)

    def remove_task_from_cache(self, task_id: str):
        """Remove a task from every cache and the waiting queue.

        The waiting queue is persisted if it changed. A warning is logged when
        the task was not found anywhere.
        """
        with self._lock:
            removed = False

            if task_id in self.running_tasks_cache:
                del self.running_tasks_cache[task_id]
                logger.info(f"从运行中缓存删除任务: {task_id}")
                removed = True

            if task_id in self.completed_tasks_cache:
                del self.completed_tasks_cache[task_id]
                logger.info(f"从已完成缓存删除任务: {task_id}")
                removed = True

            original_length = len(self.waiting_queue)
            self.waiting_queue = deque(task for task in self.waiting_queue if task['task_id'] != task_id)
            waiting_changed = len(self.waiting_queue) < original_length
            if waiting_changed:
                logger.info(f"从等待队列删除任务: {task_id}")
                removed = True

        # Persist outside the lock (the waiting queue changed).
        if waiting_changed:
            self._save_persistence_data()

        if not removed:
            logger.warning(f"任务 {task_id} 不在任何缓存中")

    def get_queue_status(self) -> Dict[str, Any]:
        """Return counts, per-status tallies, limits and ids of all cached tasks.

        Tasks without a ``status`` are tallied as ``'unknown'`` (running and
        completed caches) or ``'waiting'`` (waiting queue) instead of raising.
        """
        with self._lock:
            status_counts: Dict[str, int] = {}

            def tally(tasks, default_status):
                for task in tasks:
                    status = task.get('status', default_status)
                    status_counts[status] = status_counts.get(status, 0) + 1

            tally(self.running_tasks_cache.values(), 'unknown')
            tally(self.completed_tasks_cache.values(), 'unknown')
            tally(self.waiting_queue, 'waiting')

            total_cache_count = len(self.running_tasks_cache) + len(self.completed_tasks_cache) + len(self.waiting_queue)

            return {
                'running_tasks_count': len(self.running_tasks_cache),
                'completed_tasks_count': len(self.completed_tasks_cache),
                'waiting_queue_count': len(self.waiting_queue),
                'total_cache_count': total_cache_count,
                'status_counts': status_counts,
                'max_running_tasks': self.max_running_tasks,
                'max_waiting_tasks': self.max_waiting_tasks,
                'max_completed_cache_size': self.max_completed_cache_size,
                'completed_cache_ttl_hours': self.completed_cache_ttl_hours,
                'running_task_ids': list(self.running_tasks_cache.keys()),
                'completed_task_ids': list(self.completed_tasks_cache.keys()),
                'waiting_task_ids': [task['task_id'] for task in self.waiting_queue],
                'persistence_file': self.persistence_file,
            }
|
||
|
||
# Process-wide queue manager instance, created lazily by get_queue_manager().
_queue_manager = None
# Serializes the lazy creation so concurrent first calls build only one manager.
_queue_manager_lock = threading.Lock()


def get_queue_manager() -> TaskQueueManager:
    """Return the process-wide TaskQueueManager, creating it on first use.

    Creation is guarded by a lock and the check is repeated under it, so two
    threads racing on the first call cannot each construct a manager.
    """
    global _queue_manager
    if _queue_manager is None:
        with _queue_manager_lock:
            if _queue_manager is None:
                _queue_manager = TaskQueueManager()
    return _queue_manager


def init_queue_manager():
    """Get the shared queue manager, start it, and return it."""
    queue_manager = get_queue_manager()
    queue_manager.start()
    return queue_manager