hs-video-api/task_queue_manager.py
import threading
import time
from datetime import datetime, timedelta
from collections import deque
from typing import Dict, List, Any, Optional
import json
import os
from video_service import get_video_service
import logging
logger = logging.getLogger(__name__)

class TaskQueueManager:
    """Task queue manager."""

    def __init__(self, max_running_tasks: int = None, update_interval: int = None, persistence_file: str = None):
        """
        Initialize the task queue manager.

        Args:
            max_running_tasks: maximum number of concurrently running tasks (defaults to the environment variable)
            update_interval: status update interval in seconds (defaults to the environment variable)
            persistence_file: path of the persistence file (defaults to the environment variable)
        """
        # Read configuration from environment variables, falling back to defaults
        self.max_running_tasks = max_running_tasks or int(os.environ.get('QUEUE_MAX_RUNNING_TASKS', '5'))
        self.update_interval = update_interval or int(os.environ.get('QUEUE_UPDATE_INTERVAL', '5'))
        self.persistence_file = persistence_file or os.environ.get('QUEUE_PERSISTENCE_FILE', 'task_queue_persistence.json')
        # Cache of running tasks (task_id -> task_data)
        self.running_tasks_cache: Dict[str, Dict[str, Any]] = {}
        # Cache of completed tasks (task_id -> task_data) - only the most recent tasks are kept
        self.completed_tasks_cache: Dict[str, Dict[str, Any]] = {}
        # Waiting queue (FIFO) - holds task requests that are waiting to run
        self.waiting_queue: deque = deque()
        # Thread lock
        self._lock = threading.Lock()
        # Update thread
        self._update_thread = None
        self._stop_event = threading.Event()
        # Video service
        self.video_service = get_video_service()
        # Cache cleanup configuration (read from environment variables)
        self.max_completed_cache_size = int(os.environ.get('QUEUE_MAX_COMPLETED_CACHE_SIZE', '100'))  # maximum number of completed tasks to keep
        self.completed_cache_ttl_hours = int(os.environ.get('QUEUE_COMPLETED_CACHE_TTL_HOURS', '24'))  # hours to keep completed tasks in the cache
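
    # Illustrative configuration sketch (not part of the original module): the
    # environment variables read above, shown with their built-in defaults.
    #
    #   export QUEUE_MAX_RUNNING_TASKS=5
    #   export QUEUE_UPDATE_INTERVAL=5
    #   export QUEUE_PERSISTENCE_FILE=task_queue_persistence.json
    #   export QUEUE_MAX_COMPLETED_CACHE_SIZE=100
    #   export QUEUE_COMPLETED_CACHE_TTL_HOURS=24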
    def start(self):
        """Start the queue manager."""
        logger.info("Starting task queue manager")
        # Restore the waiting queue from the persistence file
        self._load_persistence_data()
        # Restore the caches from the SDK
        self._load_initial_tasks()
        # Start the update thread
        self._start_update_thread()

    def stop(self):
        """Stop the queue manager."""
        logger.info("Stopping task queue manager")
        # Save the waiting queue to the persistence file
        self._save_persistence_data()
        self._stop_event.set()
        if self._update_thread and self._update_thread.is_alive():
            self._update_thread.join()
    def _load_persistence_data(self):
        """Load the waiting queue from the persistence file."""
        try:
            if os.path.exists(self.persistence_file):
                with open(self.persistence_file, 'r', encoding='utf-8') as f:
                    data = json.load(f)
                with self._lock:
                    # Restore the waiting queue
                    waiting_tasks = data.get('waiting_queue', [])
                    self.waiting_queue = deque(waiting_tasks)
                logger.info(f"Restored {len(self.waiting_queue)} waiting tasks from the persistence file")
            else:
                logger.info("Persistence file does not exist, skipping restore")
        except Exception as e:
            logger.error(f"Failed to load persistence data: {str(e)}")

    def _save_persistence_data(self):
        """Save the waiting queue to the persistence file."""
        try:
            with self._lock:
                data = {
                    'waiting_queue': list(self.waiting_queue),
                    'timestamp': datetime.now().isoformat()
                }
                with open(self.persistence_file, 'w', encoding='utf-8') as f:
                    json.dump(data, f, ensure_ascii=False, indent=2)
                logger.info(f"Saved {len(self.waiting_queue)} waiting tasks to the persistence file")
        except Exception as e:
            logger.error(f"Failed to save persistence data: {str(e)}")
    def _load_initial_tasks(self):
        """Restore the task caches from the SDK."""
        try:
            logger.info("Restoring task caches from the SDK...")
            # Load tasks page by page
            offset = 0
            limit = 50
            running_tasks_loaded = 0
            completed_tasks_loaded = 0
            page_count = 0
            while True:
                result = self.video_service.get_task_list(limit=limit, offset=offset)
                if not result['success']:
                    logger.error(f"Failed to load tasks (offset={offset}): {result['error']}")
                    break
                tasks = result['data']['tasks']
                if not tasks:
                    break
                with self._lock:
                    for task in tasks:
                        task_id = task['task_id']
                        task['cache_time'] = datetime.now().isoformat()
                        if task['status'] in ['running', 'queued']:
                            # Running-task cache
                            self.running_tasks_cache[task_id] = task
                            running_tasks_loaded += 1
                            logger.debug(f"Restored running task: {task_id}")
                        else:
                            # Completed-task cache (keep only the most recent ones)
                            if completed_tasks_loaded < self.max_completed_cache_size:
                                self.completed_tasks_cache[task_id] = task
                                completed_tasks_loaded += 1
                                logger.debug(f"Restored completed task: {task_id}")
                # Stop paging once enough tasks have been loaded
                if completed_tasks_loaded >= self.max_completed_cache_size and running_tasks_loaded > 0:
                    break
                offset += limit
                page_count += 1
                # Guard against an endless loop: load at most 10 pages
                if page_count >= 10:
                    break
            logger.info(f"Cache restore finished: {running_tasks_loaded} running tasks, {completed_tasks_loaded} completed tasks")
        except Exception as e:
            logger.error(f"Failed to restore task caches: {str(e)}")
    def _start_update_thread(self):
        """Start the update thread."""
        self._update_thread = threading.Thread(target=self._update_loop, daemon=True)
        self._update_thread.start()

    def _update_loop(self):
        """Update loop."""
        logger.info(f"Task status update thread started, update interval: {self.update_interval}s")
        while not self._stop_event.is_set():
            try:
                # Log the start of the update pass
                running_count = len(self.running_tasks_cache)
                if running_count > 0:
                    logger.info(f"Updating task statuses, currently running tasks: {running_count}")
                self._update_task_statuses()
                self._process_waiting_queue()
                # Log the end of the update pass
                if running_count > 0:
                    new_running_count = len(self.running_tasks_cache)
                    logger.info(f"Task status update finished, running tasks: {new_running_count}")
            except Exception as e:
                logger.error(f"Failed to update task statuses: {str(e)}")
            # Wait for the configured interval
            self._stop_event.wait(self.update_interval)
    def _update_task_statuses(self):
        """Update the status of running tasks."""
        try:
            # Snapshot the running tasks that need to be refreshed
            with self._lock:
                tasks_to_update = list(self.running_tasks_cache.values())
            if len(tasks_to_update) == 0:
                return
            logger.info(f"Querying the status of {len(tasks_to_update)} running tasks")
            # Query the latest status of every running task
            for task in tasks_to_update:
                task_id = task.get('task_id') or task.get('id')
                if not task_id:
                    continue
                try:
                    logger.info(f"Querying status of task {task_id}")
                    status_result = self.video_service.get_task_status(task_id)
                    if status_result['success']:
                        updated_task = status_result['data']
                        task_status = updated_task.get('status', 'running')
                        logger.info(f"Task {task_id} current status: {task_status}")
                        # Update the cached task data
                        with self._lock:
                            if task_id in self.running_tasks_cache:
                                self.running_tasks_cache[task_id].update(updated_task)
                                # Possible statuses: succeeded, failed, running, cancelled, queued.
                                # Only 'queued' and 'running' count as active; everything else is
                                # terminal and moves the task to the completed cache.
                                if task_status not in ['queued', 'running']:
                                    completed_task = self.running_tasks_cache.pop(task_id)
                                    completed_task['completed_at'] = datetime.now().isoformat()
                                    self.completed_tasks_cache[task_id] = completed_task
                                    logger.info(f"Task {task_id} status changed to {task_status}, moved to the completed cache")
                                else:
                                    logger.info(f"Task {task_id} is still running, status: {task_status}")
                    else:
                        logger.warning(f"Failed to query status of task {task_id}: {status_result.get('error', 'unknown error')}")
                except Exception as e:
                    logger.error(f"Error while updating task {task_id}: {str(e)}")
                    continue
            # Evict stale entries from the completed cache
            self._cleanup_completed_tasks()
        except Exception as e:
            logger.error(f"Error while updating task statuses: {str(e)}")
    def _cleanup_completed_tasks(self):
        """Clean up the completed-task cache."""
        try:
            with self._lock:
                # Quick check whether any cleanup is needed
                cache_size = len(self.completed_tasks_cache)
                if cache_size == 0:
                    return
                current_time = datetime.now()
                cutoff_time = current_time - timedelta(hours=self.completed_cache_ttl_hours)
                # Collect the IDs to remove in a set for fast membership checks
                tasks_to_remove = set()
                valid_tasks = []
                # 1. One pass to find expired entries and collect the valid ones
                for task_id, task_data in self.completed_tasks_cache.items():
                    cache_time_str = task_data.get('cache_time')
                    if cache_time_str:
                        try:
                            cache_time = datetime.fromisoformat(cache_time_str)
                            if cache_time < cutoff_time:
                                tasks_to_remove.add(task_id)
                            else:
                                valid_tasks.append((task_id, task_data, cache_time))
                        except (ValueError, TypeError):
                            # Malformed timestamp, mark for removal
                            tasks_to_remove.add(task_id)
                    else:
                        # No cache time recorded, mark for removal
                        tasks_to_remove.add(task_id)
                # 2. Enforce the size limit (only the valid entries need sorting)
                if len(valid_tasks) > self.max_completed_cache_size:
                    # Sort by cache time, newest first
                    valid_tasks.sort(key=lambda x: x[2], reverse=True)
                    # Keep the newest entries, mark the rest for removal
                    for task_id, _, _ in valid_tasks[self.max_completed_cache_size:]:
                        tasks_to_remove.add(task_id)
                # 3. Remove everything that was marked
                if tasks_to_remove:
                    for task_id in tasks_to_remove:
                        self.completed_tasks_cache.pop(task_id, None)
                    logger.info(f"Removed {len(tasks_to_remove)} completed tasks, current cache size: {len(self.completed_tasks_cache)}")
        except Exception as e:
            logger.error(f"Error while cleaning up the completed-task cache: {str(e)}")
            # Do not interrupt the service on cleanup errors
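
    # Worked example of the retention rules above, using the defaults from
    # __init__ (QUEUE_COMPLETED_CACHE_TTL_HOURS=24, QUEUE_MAX_COMPLETED_CACHE_SIZE=100):
    # a task cached at 2025-06-06T00:00 is evicted on the first cleanup pass
    # after 2025-06-07T00:00, and even before that only the 100 newest valid
    # entries survive a size-limit sweep.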
    def _process_waiting_queue(self):
        """Promote waiting tasks into the running cache."""
        try:
            tasks_moved = 0
            with self._lock:
                # Fill any free slots from the waiting queue
                while len(self.running_tasks_cache) < self.max_running_tasks and self.waiting_queue:
                    waiting_task = self.waiting_queue.popleft()
                    task_id = waiting_task['task_id']
                    # Move the waiting task into the running cache
                    waiting_task['cache_time'] = datetime.now().isoformat()
                    self.running_tasks_cache[task_id] = waiting_task
                    tasks_moved += 1
                    logger.info(f"Waiting task started running: {task_id}")
            # Persist outside the lock, and only when the queue actually changed
            # (_save_persistence_data acquires the same non-reentrant lock)
            if tasks_moved > 0:
                self._save_persistence_data()
        except Exception as e:
            logger.error(f"Error while processing the waiting queue: {str(e)}")
            # Make sure exceptions do not stop the service
    def can_create_task(self) -> bool:
        """Check whether a new task may be created."""
        with self._lock:
            # Cap the total number of tasks (running + waiting)
            total_tasks = len(self.running_tasks_cache) + len(self.waiting_queue)
            max_total_tasks = self.max_running_tasks + 5  # allow at most 5 waiting tasks
            return total_tasks < max_total_tasks
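
    # Worked example of the cap above, assuming the default max_running_tasks=5:
    # at most 5 + 5 = 10 tasks (running + waiting) are accepted, so the 11th
    # concurrent request sees can_create_task() return False.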
    def add_task_to_queue(self, task_data: Dict[str, Any]) -> bool:
        """Add a task to the queue.

        Args:
            task_data: task data

        Returns:
            True: the task went straight into the running cache
            False: the task was placed in the waiting queue
        """
        task_id = task_data['task_id']
        task_data['cache_time'] = datetime.now().isoformat()
        with self._lock:
            if len(self.running_tasks_cache) < self.max_running_tasks:
                # Add directly to the running cache
                self.running_tasks_cache[task_id] = task_data
                logger.info(f"Task added directly to the running queue: {task_id}")
                return True
            # Otherwise append to the waiting queue
            self.waiting_queue.append(task_data)
            logger.info(f"Task added to the waiting queue: {task_id}, waiting queue length: {len(self.waiting_queue)}")
        # Persist outside the lock because the waiting queue changed
        # (_save_persistence_data acquires the same non-reentrant lock)
        self._save_persistence_data()
        return False
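
    # Illustrative caller flow (a sketch only; the real API handlers of this
    # project are not shown here, and the task dict shape is an assumption):
    #
    #   qm = get_queue_manager()
    #   if not qm.can_create_task():
    #       ...reject the request...
    #   task = {"task_id": "abc123", "status": "queued"}   # as returned by the video SDK
    #   started = qm.add_task_to_queue(task)               # True = running now, False = waiting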
    def get_task_from_cache(self, task_id: str) -> Optional[Dict[str, Any]]:
        """Get task data from the caches."""
        with self._lock:
            # Check the running cache first
            if task_id in self.running_tasks_cache:
                return self.running_tasks_cache[task_id]
            # Then the completed cache
            if task_id in self.completed_tasks_cache:
                return self.completed_tasks_cache[task_id]
            # Finally the waiting queue
            for task in self.waiting_queue:
                if task['task_id'] == task_id:
                    return task
            return None

    def get_task_by_id(self, task_id: str) -> Optional[Dict[str, Any]]:
        """Get task data by task ID (compatibility wrapper)."""
        return self.get_task_from_cache(task_id)
    def remove_task_from_cache(self, task_id: str):
        """Remove a task from all caches."""
        removed = False
        queue_changed = False
        with self._lock:
            # Remove from the running cache
            if task_id in self.running_tasks_cache:
                del self.running_tasks_cache[task_id]
                logger.info(f"Removed task from the running cache: {task_id}")
                removed = True
            # Remove from the completed cache
            if task_id in self.completed_tasks_cache:
                del self.completed_tasks_cache[task_id]
                logger.info(f"Removed task from the completed cache: {task_id}")
                removed = True
            # Remove from the waiting queue
            original_length = len(self.waiting_queue)
            self.waiting_queue = deque([task for task in self.waiting_queue if task['task_id'] != task_id])
            if len(self.waiting_queue) < original_length:
                logger.info(f"Removed task from the waiting queue: {task_id}")
                removed = True
                queue_changed = True
        # Persist outside the lock when the waiting queue changed
        if queue_changed:
            self._save_persistence_data()
        if not removed:
            logger.warning(f"Task {task_id} was not found in any cache")
    def get_queue_status(self) -> Dict[str, Any]:
        """Get the current queue status."""
        with self._lock:
            # Count tasks per status
            status_counts = {}
            # Running tasks
            for task in self.running_tasks_cache.values():
                status = task['status']
                status_counts[status] = status_counts.get(status, 0) + 1
            # Completed tasks
            for task in self.completed_tasks_cache.values():
                status = task['status']
                status_counts[status] = status_counts.get(status, 0) + 1
            # Waiting tasks
            for task in self.waiting_queue:
                status = task.get('status', 'waiting')
                status_counts[status] = status_counts.get(status, 0) + 1
            total_cache_count = len(self.running_tasks_cache) + len(self.completed_tasks_cache) + len(self.waiting_queue)
            return {
                'running_tasks_count': len(self.running_tasks_cache),
                'completed_tasks_count': len(self.completed_tasks_cache),
                'waiting_queue_count': len(self.waiting_queue),
                'total_cache_count': total_cache_count,
                'status_counts': status_counts,
                'max_running_tasks': self.max_running_tasks,
                'max_completed_cache_size': self.max_completed_cache_size,
                'completed_cache_ttl_hours': self.completed_cache_ttl_hours,
                'running_task_ids': list(self.running_tasks_cache.keys()),
                'completed_task_ids': list(self.completed_tasks_cache.keys()),
                'waiting_task_ids': [task['task_id'] for task in self.waiting_queue],
                'persistence_file': self.persistence_file
            }
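
    # Illustrative return value (counts and IDs are made up; the keys mirror
    # the dict built above):
    #
    #   {
    #     "running_tasks_count": 2,
    #     "completed_tasks_count": 40,
    #     "waiting_queue_count": 1,
    #     "total_cache_count": 43,
    #     "status_counts": {"running": 2, "succeeded": 38, "failed": 2, "waiting": 1},
    #     "max_running_tasks": 5,
    #     ...
    #   }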

# Global queue manager instance
_queue_manager = None


def get_queue_manager() -> TaskQueueManager:
    """Get the queue manager instance (singleton)."""
    global _queue_manager
    if _queue_manager is None:
        _queue_manager = TaskQueueManager()
    return _queue_manager


def init_queue_manager():
    """Create and start the queue manager."""
    queue_manager = get_queue_manager()
    queue_manager.start()
    return queue_manager
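

# Minimal standalone sketch (not part of the original module): one way an
# application might start the manager at boot and shut it down cleanly.
# It assumes video_service is importable and configured via the environment.
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)
    manager = init_queue_manager()
    try:
        while True:
            logger.info("queue status: %s", manager.get_queue_status())
            time.sleep(30)
    except KeyboardInterrupt:
        manager.stop()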