#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Video generation service module.

Thin wrapper around the Volcengine Ark (Doubao) video generation API, accessed
through the ``volcenginesdkarkruntime`` SDK.
"""

import os
import json
import logging
import time
from typing import Dict, List, Optional, Any
from datetime import datetime

from volcenginesdkarkruntime import Ark

logger = logging.getLogger(__name__)


class VideoGenerationService:
    """Video generation service backed by the Ark SDK client.

    Every public method returns a dict shaped like ``{'success': bool, ...}``
    instead of raising, so callers can forward the result directly.
    """

    # Keys accepted in ``parameters`` mapped to the prompt flags the API
    # understands. Keys not listed here are ignored.
    _PARAMETER_FLAGS = {
        'duration': '--dur',
        'ratio': '--rt',
        'resolution': '--rs',
        'framepersecond': '--fps',
        'watermark': '--wm',
        'seed': '--seed',
        'camerafixed': '--cf',
    }

    def __init__(self):
        """Create the Ark client from environment configuration.

        Reads ``ARK_API_KEY`` (required) and ``VIDEO_MODEL`` (the model ID
        used when creating tasks; may be unset).

        Raises:
            ValueError: if ``ARK_API_KEY`` is missing or empty.
        """
        api_key = os.environ.get('ARK_API_KEY', '')
        if not api_key:
            raise ValueError("ARK_API_KEY环境变量未设置")
        self.client = Ark(api_key=api_key)
        # Model ID comes from the environment; validated at task-creation time.
        self.model_id = os.environ.get("VIDEO_MODEL")

    @staticmethod
    def _format_flag_value(value) -> str:
        """Render a parameter value for a prompt flag (booleans as true/false)."""
        # str(True) == "True", but the flag syntax documents lowercase values.
        if isinstance(value, bool):
            return 'true' if value else 'false'
        return str(value)

    def _build_parameter_flags(self, parameters) -> List[str]:
        """Translate ``parameters`` into ``"--flag value"`` strings, in input order."""
        if not parameters:
            return []
        flags = []
        for key, value in parameters.items():
            flag = self._PARAMETER_FLAGS.get(key)
            if flag is not None:
                flags.append(f"{flag} {self._format_flag_value(value)}")
        return flags

    @staticmethod
    def _serialize_task(task) -> Dict[str, Any]:
        """Convert an SDK task object into a plain dict, tolerating missing fields."""
        content = getattr(task, 'content', None)
        usage = getattr(task, 'usage', None)
        task_id = getattr(task, 'id', '')
        return {
            'id': task_id,
            'task_id': task_id,  # kept for backward compatibility
            'model': getattr(task, 'model', ''),
            'status': getattr(task, 'status', ''),
            'error': getattr(task, 'error', None),
            'content': {
                'video_url': getattr(content, 'video_url', '')
            } if content else None,
            'usage': {
                'completion_tokens': getattr(usage, 'completion_tokens', 0),
                'total_tokens': getattr(usage, 'total_tokens', 0),
            } if usage else None,
            'created_at': getattr(task, 'created_at', ''),
            'updated_at': getattr(task, 'updated_at', ''),
        }

    @staticmethod
    def _is_not_found_error(exc: Exception) -> bool:
        """Return True when ``exc`` looks like a "task does not exist" API error."""
        # Prefer the structured HTTP status when the SDK exception carries one;
        # fall back to the original message-based detection otherwise.
        if getattr(exc, 'status_code', None) == 404:
            return True
        message = str(exc)
        return 'ResourceNotFound' in message or '404' in message

    def create_video_generation_task(self, content, callback_url=None,
                                     parameters=None) -> Dict[str, Any]:
        """Create a video generation task.

        Args:
            content: request content, shaped like
                ``{'image_url': str, 'prompt': str}``; both keys are optional.
            callback_url: optional URL forwarded to the API for result
                notification.
            parameters: optional dict of generation options (see
                ``_PARAMETER_FLAGS``); they are appended to the text prompt as
                flags.

        Returns:
            ``{'success': True, 'data': {'task_id': ...}}`` on success, or
            ``{'success': False, 'error': str}`` on failure.
        """
        try:
            model = self.model_id
            if not model:
                return {
                    'success': False,
                    'error': '请求异常: VIDEO_MODEL环境变量未设置'
                }

            # Build the content array in the shape the official API expects.
            api_content = []
            flags = self._build_parameter_flags(parameters)

            if 'prompt' in content:
                prompt_text = content['prompt']
                if flags:
                    prompt_text += " " + " ".join(flags)
                api_content.append({"type": "text", "text": prompt_text})
            elif flags:
                # Parameters can only be expressed as prompt flags; send them
                # in their own text entry rather than silently dropping them.
                api_content.append({"type": "text", "text": " ".join(flags)})

            if 'image_url' in content:
                api_content.append({
                    "type": "image_url",
                    "image_url": {"url": content['image_url']}
                })

            logger.debug("api_content: %s", api_content)

            create_kwargs = {'model': model, 'content': api_content}
            # Only forward callback_url when given, so the default request is
            # unchanged.
            if callback_url:
                create_kwargs['callback_url'] = callback_url
            create_result = self.client.content_generation.tasks.create(
                **create_kwargs
            )
            return {'success': True, 'data': {'task_id': create_result.id}}
        except Exception as e:
            return {
                'success': False,
                'error': f'请求异常: {str(e)}'
            }

    def get_task_status(self, task_id: str) -> Dict[str, Any]:
        """Query the status of a task.

        Args:
            task_id: task ID.

        Returns:
            ``{'success': True, 'data': task_dict}`` on success. A missing task
            is reported as success with ``status == 'not_found'``. Other
            failures return ``{'success': False, 'error': str}``.
        """
        try:
            result = self.client.content_generation.tasks.get(task_id=task_id)
            logger.debug("task status result: %s", result)
            return {'success': True, 'data': self._serialize_task(result)}
        except Exception as e:
            if self._is_not_found_error(e):
                return {
                    'success': True,
                    'data': {
                        'id': task_id,
                        'task_id': task_id,
                        'status': 'not_found',
                        'error': 'ResourceNotFound',
                        'message': '指定的任务资源未找到',
                        'model': None,
                        'content': None,
                        'usage': None,
                        'created_at': None,
                        'updated_at': None
                    }
                }
            return {
                'success': False,
                'error': f'查询异常: {str(e)}'
            }

    def get_task_list(self, limit=20, offset=0) -> Dict[str, Any]:
        """List tasks.

        The API pages by ``page_num``/``page_size``, so ``offset`` is rounded
        down to the page boundary that contains it (``offset // limit``).

        Args:
            limit: items per page.
            offset: item offset.

        Returns:
            ``{'success': True, 'data': {...}}`` with the serialized tasks and
            paging info, or ``{'success': False, 'error': str}``.
        """
        try:
            # Convert limit/offset to page_num/page_size; never below page 1.
            page_num = max((offset // limit) + 1, 1) if limit > 0 else 1
            page_size = limit

            result = self.client.content_generation.tasks.list(
                page_num=page_num,
                page_size=page_size
            )

            items = getattr(result, 'items', None) or []
            tasks_data = [self._serialize_task(task) for task in items]

            return {'success': True, 'data': {
                'tasks': tasks_data,
                'total': getattr(result, 'total', 0),
                'page_num': page_num,
                'page_size': page_size,
                'limit': limit,
                'offset': offset
            }}
        except Exception as e:
            return {
                'success': False,
                'error': f'获取列表异常: {str(e)}'
            }

    def delete_task(self, task_id: str) -> Dict[str, Any]:
        """Delete a task.

        Args:
            task_id: task ID.

        Returns:
            ``{'success': True}`` or ``{'success': False, 'error': str}``.
        """
        try:
            self.client.content_generation.tasks.delete(task_id=task_id)
            return {'success': True}
        except Exception as e:
            return {
                'success': False,
                'error': f'删除异常: {str(e)}'
            }

    def wait_for_completion(self, task_id: str, max_wait_time: int = 300,
                            check_interval: int = 5) -> Dict[str, Any]:
        """Poll a task until it leaves the queued/running states.

        Args:
            task_id: task ID.
            max_wait_time: maximum total wait, in seconds.
            check_interval: delay between polls, in seconds.

        Returns:
            The last ``get_task_status`` result once the task is no longer
            queued/running (or a failed status query), otherwise a timeout
            error dict.
        """
        # Monotonic clock: immune to wall-clock adjustments during the wait.
        deadline = time.monotonic() + max_wait_time
        while time.monotonic() < deadline:
            status_result = self.get_task_status(task_id)
            if not status_result["success"]:
                return status_result

            status = status_result.get("data", {}).get("status")
            if status not in ("queued", "running"):
                return status_result

            remaining = deadline - time.monotonic()
            if remaining <= 0:
                break
            # Never sleep past the deadline.
            time.sleep(min(check_interval, remaining))

        return {
            "success": False,
            "error": "任务等待超时",
            "task_id": task_id
        }


# Lazily-created module-wide service instance.
video_service = None


def get_video_service() -> VideoGenerationService:
    """Return the shared VideoGenerationService, creating it on first use."""
    global video_service
    if video_service is None:
        video_service = VideoGenerationService()
    return video_service