#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
视频生成服务模块

使用火山引擎豆包视频生成API
"""
import json
import logging
import os
import time
from datetime import datetime
from typing import Any, Dict, List, Optional

from volcenginesdkarkruntime import Ark
class VideoGenerationService:
    """Wrapper around the Volcengine Ark (Doubao) video-generation task API.

    Every public method returns a dict of the form ``{'success': bool, ...}``
    instead of raising, so callers can forward the result directly.
    """

    # Logger for debug output (request payloads, raw task objects).
    _logger = logging.getLogger(__name__)

    # Keys accepted in ``parameters`` mapped to the short flags that are
    # appended to the text prompt (e.g. ``--dur 5``). Unknown keys are ignored.
    _PARAM_FLAGS = {
        'duration': '--dur',
        'ratio': '--rt',
        'resolution': '--rs',
        'framepersecond': '--fps',
        'watermark': '--wm',
        'seed': '--seed',
        'camerafixed': '--cf',
    }

    # Task statuses that mean "not finished yet"; any other status stops polling.
    _PENDING_STATUSES = ('queued', 'running')

    def __init__(self, api_key: Optional[str] = None, model_id: Optional[str] = None):
        """Create the Ark client.

        Args:
            api_key: Ark API key. Falls back to the ``ARK_API_KEY`` environment
                variable when not given (or empty).
            model_id: Video model / endpoint ID. Falls back to the
                ``VIDEO_MODEL`` environment variable when not given.

        Raises:
            ValueError: if no API key is supplied and ``ARK_API_KEY`` is unset
                or empty.
        """
        api_key = api_key or os.environ.get('ARK_API_KEY', '')
        if not api_key:
            raise ValueError("ARK_API_KEY环境变量未设置")

        self.client = Ark(api_key=api_key)

        # May be None; create_video_generation_task reports that as an error.
        self.model_id = model_id if model_id is not None else os.environ.get("VIDEO_MODEL")

    @staticmethod
    def _format_flag_value(value: Any) -> str:
        """Render a flag value. Booleans become lowercase ``true``/``false``."""
        if isinstance(value, bool):
            return 'true' if value else 'false'
        return f"{value}"

    @classmethod
    def _build_param_flags(cls, parameters: Optional[Dict[str, Any]]) -> str:
        """Turn ``parameters`` into a flag string such as ``--dur 5 --rt 16:9``.

        Flags follow the iteration order of ``parameters``. Unknown keys and
        ``None`` values are skipped. Returns ``''`` when nothing applies.
        """
        if not parameters:
            return ''
        return ' '.join(
            f"{cls._PARAM_FLAGS[key]} {cls._format_flag_value(value)}"
            for key, value in parameters.items()
            if key in cls._PARAM_FLAGS and value is not None
        )

    @classmethod
    def _build_api_content(cls, content: Dict[str, Any],
                           parameters: Optional[Dict[str, Any]] = None) -> List[Dict[str, Any]]:
        """Build the ``content`` array passed to ``tasks.create``.

        Emits a ``text`` item (the prompt plus any parameter flags), followed by
        an ``image_url`` item when ``content`` has an ``image_url``. If there is
        no prompt but there are flags, the flags alone form the text item, so
        they are not silently dropped.
        """
        api_content = []
        flags = cls._build_param_flags(parameters)

        if 'prompt' in content:
            prompt_text = content['prompt']
            if flags:
                prompt_text += " " + flags
            api_content.append({"type": "text", "text": prompt_text})
        elif flags:
            api_content.append({"type": "text", "text": flags})

        if 'image_url' in content:
            api_content.append({
                "type": "image_url",
                "image_url": {"url": content['image_url']},
            })
        return api_content

    @staticmethod
    def _task_to_dict(task: Any) -> Dict[str, Any]:
        """Convert an SDK task object into a plain dict.

        Missing attributes fall back to ``''`` (or ``None``/``0`` as noted), so
        a partially populated task never raises here.
        """
        task_id = getattr(task, 'id', '')
        content = getattr(task, 'content', None)
        usage = getattr(task, 'usage', None)
        return {
            'id': task_id,
            'task_id': task_id,  # duplicate key kept for backward compatibility
            'model': getattr(task, 'model', ''),
            'status': getattr(task, 'status', ''),
            # NOTE(review): may be an SDK object rather than a dict; confirm
            # callers do not need it JSON-serializable.
            'error': getattr(task, 'error', None),
            'content': {
                'video_url': getattr(content, 'video_url', ''),
            } if content else None,
            'usage': {
                'completion_tokens': getattr(usage, 'completion_tokens', 0),
                'total_tokens': getattr(usage, 'total_tokens', 0),
            } if usage else None,
            'created_at': getattr(task, 'created_at', ''),
            'updated_at': getattr(task, 'updated_at', ''),
        }

    @staticmethod
    def _is_not_found_error(exc: BaseException) -> bool:
        """Return True if ``exc`` signals that the requested task does not exist.

        Prefers the HTTP ``status_code`` carried by SDK status errors over
        substring matching, because a bare ``'404'`` can appear in request IDs
        or timestamps inside unrelated error messages.
        """
        if getattr(exc, 'status_code', None) == 404:
            return True
        return 'ResourceNotFound' in str(exc)

    def create_video_generation_task(self, content, callback_url=None, parameters=None) -> Dict[str, Any]:
        """Submit a video-generation task.

        Args:
            content: Request content shaped like
                ``{'image_url': str, 'prompt': str}``; either key may be omitted,
                but at least one must be present.
            callback_url: Optional URL forwarded to the API as ``callback_url``.
                It is only sent when non-empty.
            parameters: Optional settings (``duration``, ``ratio``,
                ``resolution``, ``framepersecond``, ``watermark``, ``seed``,
                ``camerafixed``) that are appended to the prompt as flags.

        Returns:
            ``{'success': True, 'data': {'task_id': ...}}`` on success, otherwise
            ``{'success': False, 'error': <message>}``.
        """
        try:
            model = self.model_id
            if not model:
                return {'success': False, 'error': '未配置视频模型ID(VIDEO_MODEL环境变量未设置)'}

            api_content = self._build_api_content(content, parameters)
            if not api_content:
                return {'success': False, 'error': '请求内容为空: 需要提供 prompt 或 image_url'}
            self._logger.debug("api_content: %s", api_content)

            create_kwargs = {'model': model, 'content': api_content}
            if callback_url:
                create_kwargs['callback_url'] = callback_url

            create_result = self.client.content_generation.tasks.create(**create_kwargs)
            return {'success': True, 'data': {'task_id': create_result.id}}

        except Exception as e:
            # Service boundary: report SDK/network/input errors instead of raising.
            return {
                'success': False,
                'error': f'请求异常: {str(e)}'
            }

    def get_task_status(self, task_id: str) -> Dict[str, Any]:
        """Query a single task.

        Args:
            task_id: Task ID returned by ``create_video_generation_task``.

        Returns:
            ``{'success': True, 'data': <task dict>}``. When the task does not
            exist, this is still ``success: True`` with ``status='not_found'``.
            On any other error: ``{'success': False, 'error': <message>}``.
        """
        try:
            result = self.client.content_generation.tasks.get(
                task_id=task_id,
            )
            self._logger.debug("task status result: %s", result)
            return {'success': True, 'data': self._task_to_dict(result)}

        except Exception as e:
            if self._is_not_found_error(e):
                return {
                    'success': True,
                    'data': {
                        'id': task_id,
                        'task_id': task_id,
                        'status': 'not_found',
                        'error': 'ResourceNotFound',
                        'message': '指定的任务资源未找到',
                        'model': None,
                        'content': None,
                        'usage': None,
                        'created_at': None,
                        'updated_at': None
                    }
                }
            return {
                'success': False,
                'error': f'查询异常: {str(e)}'
            }

    def get_task_list(self, limit=20, offset=0) -> Dict[str, Any]:
        """List tasks with limit/offset style paging.

        Args:
            limit: Page size. Values <= 0 request page 1 and pass ``limit``
                through unchanged as ``page_size``.
            offset: Item offset. The API pages by page number, so an offset that
                is not a multiple of ``limit`` is rounded down to the start of
                its page.

        Returns:
            ``{'success': True, 'data': {'tasks': [...], 'total', 'page_num',
            'page_size', 'limit', 'offset'}}`` or
            ``{'success': False, 'error': <message>}``.
        """
        try:
            page_num = (offset // limit) + 1 if limit > 0 else 1
            page_size = limit

            result = self.client.content_generation.tasks.list(
                page_num=page_num,
                page_size=page_size
            )

            items = getattr(result, 'items', None) or []
            tasks_data = [self._task_to_dict(task) for task in items]

            return {'success': True, 'data': {
                'tasks': tasks_data,
                'total': getattr(result, 'total', 0),
                'page_num': page_num,
                'page_size': page_size,
                'limit': limit,
                'offset': offset
            }}

        except Exception as e:
            return {
                'success': False,
                'error': f'获取列表异常: {str(e)}'
            }

    def delete_task(self, task_id: str) -> Dict[str, Any]:
        """Delete a task.

        Args:
            task_id: Task ID.

        Returns:
            ``{'success': True}`` or ``{'success': False, 'error': <message>}``.
        """
        try:
            self.client.content_generation.tasks.delete(
                task_id=task_id
            )
            return {'success': True}

        except Exception as e:
            return {
                'success': False,
                'error': f'删除异常: {str(e)}'
            }

    def wait_for_completion(self, task_id: str, max_wait_time: int = 300, check_interval: int = 5) -> Dict[str, Any]:
        """Poll ``get_task_status`` until the task leaves the pending states.

        The status is checked at least once. Sleeps are clamped so the last
        poll happens at the deadline rather than after it. A monotonic clock is
        used so wall-clock adjustments do not shorten or extend the wait.

        Args:
            task_id: Task ID.
            max_wait_time: Maximum total wait in seconds.
            check_interval: Seconds between polls.

        Returns:
            The last ``get_task_status`` result once the task is no longer
            ``queued``/``running``, or immediately if the query failed.
            Otherwise ``{'success': False, 'error': '任务等待超时',
            'task_id': task_id}``.
        """
        deadline = time.monotonic() + max_wait_time

        while True:
            status_result = self.get_task_status(task_id)
            if not status_result["success"]:
                return status_result

            status = status_result.get("data", {}).get("status")
            if status not in self._PENDING_STATUSES:
                return status_result

            remaining = deadline - time.monotonic()
            if remaining <= 0:
                break
            time.sleep(min(check_interval, remaining))

        return {
            "success": False,
            "error": "任务等待超时",
            "task_id": task_id
        }
# Lazily created, module-wide service instance (populated by get_video_service).
video_service = None


def get_video_service() -> VideoGenerationService:
    """Return the shared VideoGenerationService, creating it on the first call.

    The instance is cached in the module-level ``video_service`` variable, so
    construction happens only on first use.
    NOTE: no locking is done; concurrent first calls may each construct one.
    """
    global video_service
    if video_service is not None:
        return video_service
    video_service = VideoGenerationService()
    return video_service