hs-video-api/video_service.py
2025-06-07 00:28:35 +08:00

298 lines
10 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
视频生成服务模块
使用火山引擎豆包视频生成API
"""
import json
import logging
import os
import time
from datetime import datetime
from typing import Dict, List, Optional, Any

from volcenginesdkarkruntime import Ark
class VideoGenerationService:
    """Video generation service built on the Volcengine Ark (Doubao) SDK.

    Public methods do not raise on API failures. They return a dict with a
    ``'success'`` key, plus ``'data'`` on success or an ``'error'`` message
    on failure.
    """

    # Maps keys accepted in ``parameters`` to the text command appended to
    # the prompt, e.g. {'duration': 5} -> "--dur 5". Unknown keys are ignored.
    _PARAM_FLAGS = {
        'duration': '--dur',
        'ratio': '--rt',
        'resolution': '--rs',
        'framepersecond': '--fps',
        'watermark': '--wm',
        'seed': '--seed',
        'camerafixed': '--cf',
    }

    # Task statuses that mean "still in progress" for wait_for_completion.
    _PENDING_STATUSES = ("queued", "running")

    _logger = logging.getLogger(__name__)

    def __init__(self):
        """Create the Ark client from environment configuration.

        Reads ``ARK_API_KEY`` (required) and ``VIDEO_MODEL`` (the model ID).
        ``VIDEO_MODEL`` may be unset here. create_video_generation_task then
        reports that as an error.

        Raises:
            ValueError: if ``ARK_API_KEY`` is unset or empty.
        """
        api_key = os.environ.get('ARK_API_KEY', '')
        if not api_key:
            raise ValueError("ARK_API_KEY环境变量未设置")
        self.client = Ark(api_key=api_key)
        self.model_id = os.environ.get("VIDEO_MODEL")

    @staticmethod
    def _format_param_value(value):
        """Render one parameter value as text for a prompt command.

        Booleans become ``true``/``false``. ``str(True)`` would give ``True``,
        while the Ark docs show lowercase values (e.g. ``--wm true``).
        """
        if isinstance(value, bool):
            return 'true' if value else 'false'
        return str(value)

    @classmethod
    def _build_param_text(cls, parameters):
        """Turn ``parameters`` into space-separated text commands.

        Keys are emitted in the caller's order. Keys missing from _PARAM_FLAGS
        are skipped. Returns '' when ``parameters`` is empty/None or nothing
        matches.
        """
        if not parameters:
            return ''
        return ' '.join(
            f"{cls._PARAM_FLAGS[key]} {cls._format_param_value(value)}"
            for key, value in parameters.items()
            if key in cls._PARAM_FLAGS
        )

    @classmethod
    def _build_api_content(cls, content, parameters=None):
        """Build the ``content`` array for the task-creation request.

        Args:
            content: dict with optional ``'prompt'`` and ``'image_url'`` keys.
            parameters: optional generation parameters (see _PARAM_FLAGS).

        Returns:
            A list holding an optional text item (prompt followed by parameter
            commands) and then an optional image_url item.
        """
        api_content = []
        param_text = cls._build_param_text(parameters)
        if 'prompt' in content:
            prompt_text = content['prompt']
            if param_text:
                prompt_text += " " + param_text
            api_content.append({"type": "text", "text": prompt_text})
        elif param_text:
            # With no prompt (e.g. image-only input) the parameters would
            # otherwise be dropped silently, so send them as a text item alone.
            api_content.append({"type": "text", "text": param_text})
        if 'image_url' in content:
            api_content.append({
                "type": "image_url",
                "image_url": {"url": content['image_url']},
            })
        return api_content

    def create_video_generation_task(self, content, callback_url=None, parameters=None) -> Dict[str, Any]:
        """Create an asynchronous video generation task.

        Args:
            content: request content, e.g. ``{'image_url': str, 'prompt': str}``.
                Both keys are optional.
            callback_url: optional URL, forwarded to the API as
                ``callback_url`` when given.
            parameters: optional dict of generation parameters (duration,
                ratio, resolution, framepersecond, watermark, seed,
                camerafixed). They are appended to the prompt as text commands.

        Returns:
            ``{'success': True, 'data': {'task_id': ...}}`` on success,
            otherwise ``{'success': False, 'error': str}``.
        """
        if not self.model_id:
            # Fail locally rather than sending model=None to the API.
            return {
                'success': False,
                'error': '请求异常: VIDEO_MODEL环境变量未设置'
            }
        try:
            api_content = self._build_api_content(content, parameters)
            self._logger.debug("api_content: %s", api_content)
            request_kwargs = {'model': self.model_id, 'content': api_content}
            if callback_url:
                # Previously accepted but never sent. Only passed when set, so
                # the default call shape is unchanged.
                request_kwargs['callback_url'] = callback_url
            create_result = self.client.content_generation.tasks.create(**request_kwargs)
            return {'success': True, 'data': {'task_id': create_result.id}}
        except Exception as e:
            # Service boundary: callers expect an error dict, not an exception.
            return {
                'success': False,
                'error': f'请求异常: {str(e)}'
            }

    def get_task_status(self, task_id: str) -> Dict[str, Any]:
        """Query the current state of one task.

        Args:
            task_id: ID returned by create_video_generation_task.

        Returns:
            ``{'success': True, 'data': {...}}``. The data holds id/task_id,
            model, status, error, content (``{'video_url': ...}`` or None),
            usage (or None), created_at and updated_at. If the API reports the
            task as missing, ``success`` is still True and ``data['status']``
            is ``'not_found'``. Any other failure yields
            ``{'success': False, 'error': str}``.
        """
        try:
            result = self.client.content_generation.tasks.get(task_id=task_id)
            self._logger.debug("task %s: %s", task_id, result)
            task_data = {
                'id': result.id,
                'task_id': result.id,  # duplicate key kept for backward compatibility
                'model': result.model,
                'status': result.status,
                'error': result.error,
                'content': {
                    'video_url': result.content.video_url
                } if hasattr(result, 'content') and result.content else None,
                'usage': {
                    'completion_tokens': result.usage.completion_tokens,
                    'total_tokens': result.usage.total_tokens
                } if hasattr(result, 'usage') and result.usage else None,
                'created_at': result.created_at,
                'updated_at': result.updated_at
            }
            return {'success': True, 'data': task_data}
        except Exception as e:
            error_str = str(e)
            # Treat "resource not found" as a normal status rather than a
            # failure. Prefer the exception's HTTP status when it carries one.
            # The text match is the fallback; NOTE(review): '404' could also
            # match unrelated message text.
            is_not_found = (
                getattr(e, 'status_code', None) == 404
                or 'ResourceNotFound' in error_str
                or '404' in error_str
            )
            if is_not_found:
                return {
                    'success': True,
                    'data': {
                        'id': task_id,
                        'task_id': task_id,
                        'status': 'not_found',
                        'error': 'ResourceNotFound',
                        'message': '指定的任务资源未找到',
                        'model': None,
                        'content': None,
                        'usage': None,
                        'created_at': None,
                        'updated_at': None
                    }
                }
            return {
                'success': False,
                'error': f'查询异常: {str(e)}'
            }

    @staticmethod
    def _task_summary(task):
        """Convert one task object from the list API into a plain dict.

        Attributes that are missing fall back to defaults ('' / None / 0).
        """
        task_content = getattr(task, 'content', None)
        task_usage = getattr(task, 'usage', None)
        return {
            'id': getattr(task, 'id', ''),
            'task_id': getattr(task, 'id', ''),  # compatibility field
            'status': getattr(task, 'status', ''),
            'model': getattr(task, 'model', ''),
            'created_at': getattr(task, 'created_at', ''),
            'updated_at': getattr(task, 'updated_at', ''),
            'error': getattr(task, 'error', None),
            'content': {
                'video_url': getattr(task_content, 'video_url', '')
            } if task_content else None,
            'usage': {
                'completion_tokens': getattr(task_usage, 'completion_tokens', 0),
                'total_tokens': getattr(task_usage, 'total_tokens', 0)
            } if task_usage else None,
        }

    def get_task_list(self, limit=20, offset=0) -> Dict[str, Any]:
        """List tasks, one page at a time.

        Args:
            limit: page size.
            offset: number of tasks to skip. The API paginates by page number,
                so the offset is rounded down to the start of its page. Negative
                offsets are treated as 0.

        Returns:
            ``{'success': True, 'data': {'tasks': [...], 'total', 'page_num',
            'page_size', 'limit', 'offset'}}``, or
            ``{'success': False, 'error': str}``.
        """
        try:
            # Clamp so a negative offset cannot yield page_num <= 0.
            # NOTE(review): a non-positive limit is still forwarded as
            # page_size; confirm how the API treats it.
            page_num = (max(offset, 0) // limit) + 1 if limit > 0 else 1
            page_size = limit
            result = self.client.content_generation.tasks.list(
                page_num=page_num,
                page_size=page_size
            )
            items = getattr(result, 'items', None) or []
            tasks_data = [self._task_summary(task) for task in items]
            return {'success': True, 'data': {
                'tasks': tasks_data,
                'total': getattr(result, 'total', 0),
                'page_num': page_num,
                'page_size': page_size,
                'limit': limit,
                'offset': offset
            }}
        except Exception as e:
            return {
                'success': False,
                'error': f'获取列表异常: {str(e)}'
            }

    def delete_task(self, task_id: str) -> Dict[str, Any]:
        """Delete a task.

        Args:
            task_id: ID of the task to delete.

        Returns:
            ``{'success': True}`` or ``{'success': False, 'error': str}``.
        """
        try:
            self.client.content_generation.tasks.delete(task_id=task_id)
            return {'success': True}
        except Exception as e:
            return {
                'success': False,
                'error': f'删除异常: {str(e)}'
            }

    def wait_for_completion(self, task_id: str, max_wait_time: int = 300, check_interval: int = 5) -> Dict[str, Any]:
        """Poll get_task_status until the task leaves the queued/running states.

        Args:
            task_id: task to wait for.
            max_wait_time: maximum total wait, in seconds.
            check_interval: delay between polls, in seconds.

        Returns:
            One of:
            - the last get_task_status result, once the status is no longer
              pending;
            - the failed get_task_status result, as soon as a query fails;
            - ``{'success': False, 'error': ..., 'task_id': task_id}`` on
              timeout.
        """
        # Monotonic clock: immune to wall-clock adjustments during the wait.
        deadline = time.monotonic() + max_wait_time
        while time.monotonic() < deadline:
            status_result = self.get_task_status(task_id)
            if not status_result["success"]:
                return status_result
            status = status_result.get("data", {}).get("status")
            if status not in self._PENDING_STATUSES:
                return status_result
            # Never sleep past the deadline.
            time.sleep(max(0.0, min(check_interval, deadline - time.monotonic())))
        return {
            "success": False,
            "error": "任务等待超时",
            "task_id": task_id
        }
# Process-wide service instance, created on first use by get_video_service().
video_service = None


def get_video_service() -> VideoGenerationService:
    """Return the shared VideoGenerationService, constructing it on first call.

    The instance is cached only after construction succeeds. If the
    constructor raises (e.g. ARK_API_KEY is unset), nothing is stored and the
    next call tries again.
    """
    global video_service
    if video_service is not None:
        return video_service
    video_service = VideoGenerationService()
    return video_service