405 lines
14 KiB
TypeScript
405 lines
14 KiB
TypeScript
import GenerationTask, { IGenerationTask } from '@/lib/database/models/GenerationTask.js';
|
||
import GenerationResult, { IGenerationResult } from '@/lib/database/models/GenerationResult.js';
|
||
import NeDBManager from '@/lib/database/nedb.js';
|
||
import Environment from '@/lib/environment.js';
|
||
import logger from '@/lib/logger.js';
|
||
|
||
/**
 * Database-backed generation service.
 *
 * The `*V2` methods do not generate anything themselves: they persist a `pending`
 * task record through `GenerationTask`, which the polling service processes
 * asynchronously. Clients then poll {@link DatabaseGenerationService.queryTaskResult},
 * which serves finished output from `GenerationResult`.
 *
 * Status codes in `queryTaskResult(...).data.status`:
 * - `-1`: a result record exists and its status is `'success'`
 * - `-2`: the result record (or, without one, the task itself) is failed
 * - `0`:  task unknown, still pending/processing/polling, or the lookup itself errored
 */
export class DatabaseGenerationService {
  private static instance: DatabaseGenerationService;

  /**
   * Client-facing status code per task status, used when no result record exists yet.
   * A `Map` avoids accidental hits on `Object.prototype` keys for unexpected statuses.
   */
  private static readonly taskStatusCodes: ReadonlyMap<string, number> = new Map<string, number>([
    ['pending', 0],
    ['processing', 0],
    ['polling', 0],
    ['failed', -2],
    // Normally unreachable: completing a task writes a result record, which is served first.
    ['completed', -1],
  ]);

  /** Server identifier stamped on tasks created by this process (`SERVICE_ID`, default `'jimeng-free-api'`). */
  private readonly currentServerId: string;

  private constructor() {
    this.currentServerId = process.env.SERVICE_ID || 'jimeng-free-api';
  }

  /** Returns the lazily created, process-wide instance. */
  public static getInstance(): DatabaseGenerationService {
    if (!DatabaseGenerationService.instance) {
      DatabaseGenerationService.instance = new DatabaseGenerationService();
    }
    return DatabaseGenerationService.instance;
  }

  /** Current Unix time in whole seconds — the unit used by every timestamp field here. */
  private static nowSeconds(): number {
    return Math.floor(Date.now() / 1000);
  }

  /**
   * Reads a non-negative base-10 integer from an environment variable.
   *
   * Returns `fallback` when the variable is unset, blank, or does not parse to a
   * non-negative integer; a bare `parseInt(...)` would otherwise persist `NaN`
   * (or a negative number) into timeouts / expiry times on a misconfigured host.
   */
  private static readIntEnv(name: string, fallback: number): number {
    const raw = process.env[name];
    if (raw === undefined || raw.trim() === '') {
      return fallback;
    }
    const parsed = Number.parseInt(raw, 10);
    return Number.isFinite(parsed) && parsed >= 0 ? parsed : fallback;
  }

  /** Extracts a printable message from a thrown value of unknown type. */
  private static describeError(error: unknown): string {
    return error instanceof Error ? error.message : String(error);
  }

  /** Builds the `{ created, data }` envelope returned to clients by `queryTaskResult`. */
  private static buildQueryResponse(
    taskId: string,
    url: string,
    status: number
  ): { created: number; data: { task_id: string; url: string; status: number } } {
    return {
      created: DatabaseGenerationService.nowSeconds(),
      data: { task_id: taskId, url, status },
    };
  }

  /**
   * Image generation (database-backed).
   *
   * Records a `pending` image task for asynchronous processing and returns without
   * waiting for the image. If a task with `taskId` already exists, nothing is written.
   *
   * @param model - Model name, stored verbatim in `original_params`.
   * @param taskId - Caller-chosen task id; also the key used by `queryTaskResult`.
   * @param prompt - Generation prompt.
   * @param params - Options: `width`/`height` default to 1024 when missing or 0,
   *   `sampleStrength` defaults to 0.5 only when missing (an explicit 0 is kept),
   *   `negativePrompt` defaults to `""`.
   * @param refreshToken - Stored in `internal_params.refresh_token` for the worker.
   * @throws Re-throws any database error after logging it.
   */
  async generateImagesV2(
    model: string,
    taskId: string,
    prompt: string,
    params: {
      width?: number;
      height?: number;
      sampleStrength?: number;
      negativePrompt?: string;
      response_format?: string;
    },
    refreshToken: string
  ): Promise<void> {
    try {
      await NeDBManager.initialize();

      const currentServerId = this.currentServerId;
      const imageTimeout = DatabaseGenerationService.readIntEnv('IMAGE_TASK_TIMEOUT', 3600);

      // Idempotent create: an existing task with this id is left untouched.
      // NOTE(review): check-then-create is not atomic; two concurrent requests with the
      // same taskId can both pass this check unless task_id is uniquely indexed.
      const existingTask = await GenerationTask.findOne({ task_id: taskId });
      if (existingTask) {
        logger.warn(`Task ${taskId} already exists, skipping creation`);
        return;
      }

      // One timestamp so created_at and updated_at cannot straddle a second boundary.
      const now = DatabaseGenerationService.nowSeconds();
      await GenerationTask.create({
        task_id: taskId,
        task_type: 'image',
        server_id: currentServerId,
        original_params: {
          model,
          prompt,
          // `||` on purpose: a 0 dimension is not a usable size, so it falls back too.
          width: params.width || 1024,
          height: params.height || 1024,
          // `??` so an explicit strength of 0 is not replaced by the default.
          sample_strength: params.sampleStrength ?? 0.5,
          negative_prompt: params.negativePrompt ?? '',
          response_format: params.response_format,
        },
        internal_params: {
          refresh_token: refreshToken,
        },
        status: 'pending',
        retry_count: 0,
        max_retries: 3,
        poll_interval: 10,
        task_timeout: imageTimeout,
        created_at: now,
        updated_at: now,
      });

      logger.info(`Image task created: ${taskId} for server: ${currentServerId}`);
    } catch (error) {
      logger.error(`Failed to create image task ${taskId}:`, error);
      throw error;
    }
  }

  /**
   * Video generation (database-backed).
   *
   * Records a `pending` video task for asynchronous processing and returns without
   * waiting for the video. If a task with `taskId` already exists, nothing is written.
   *
   * @param taskId - Caller-chosen task id; also the key used by `queryTaskResult`.
   * @param prompt - Generation prompt.
   * @param params - Options: `images` defaults to `[]`, `isPro` to `false`,
   *   `duration` to 5000 (also when 0), `ratio` to `'9:16'` (also when empty).
   * @param refreshToken - Stored in `internal_params.refresh_token` for the worker.
   * @throws Re-throws any database error after logging it.
   */
  async generateVideoV2(
    taskId: string,
    prompt: string,
    params: {
      images?: Array<{
        url: string;
        width: number;
        height: number;
      }>;
      isPro?: boolean;
      duration?: number;
      ratio?: string;
    },
    refreshToken: string
  ): Promise<void> {
    try {
      await NeDBManager.initialize();

      const currentServerId = this.currentServerId;
      const videoTimeout = DatabaseGenerationService.readIntEnv('VIDEO_TASK_TIMEOUT', 86400);

      // Idempotent create: an existing task with this id is left untouched.
      // NOTE(review): check-then-create is not atomic (see generateImagesV2).
      const existingTask = await GenerationTask.findOne({ task_id: taskId });
      if (existingTask) {
        logger.warn(`Task ${taskId} already exists, skipping creation`);
        return;
      }

      const now = DatabaseGenerationService.nowSeconds();
      await GenerationTask.create({
        task_id: taskId,
        task_type: 'video',
        server_id: currentServerId,
        original_params: {
          prompt,
          images: params.images ?? [],
          is_pro: params.isPro ?? false,
          // `||` on purpose: a 0 duration / empty ratio is not usable, so it falls back too.
          duration: params.duration || 5000,
          ratio: params.ratio || '9:16',
        },
        internal_params: {
          refresh_token: refreshToken,
        },
        status: 'pending',
        retry_count: 0,
        max_retries: 3,
        poll_interval: 15, // videos are polled less often than images
        task_timeout: videoTimeout,
        created_at: now,
        updated_at: now,
      });

      logger.info(`Video task created: ${taskId} for server: ${currentServerId}`);
    } catch (error) {
      logger.error(`Failed to create video task ${taskId}:`, error);
      throw error;
    }
  }

  /**
   * Looks up a task's outcome for a polling client.
   *
   * If a result record exists it is returned (comma-joined `tos_urls`, status -1 for
   * `'success'`, otherwise -2) and its read bookkeeping (`read_count`, `is_read`,
   * `first_read_at`) is updated; expired results are left to TTL/cleanup. Otherwise
   * the task's own status is mapped to a code (see class docs). Never throws: on
   * error it returns the "unknown task" shape plus a `message`.
   */
  async queryTaskResult(taskId: string): Promise<any> {
    try {
      await NeDBManager.initialize();

      // 1. A result record means the task has finished; serve it and track the read.
      const result = await GenerationResult.findOne({ task_id: taskId });
      if (result) {
        const readCount = (result.read_count ?? 0) + 1;
        const updateData = result.is_read
          ? { read_count: readCount }
          : { read_count: readCount, is_read: true, first_read_at: DatabaseGenerationService.nowSeconds() };

        // Read tracking is bookkeeping only: if it fails, still hand back the finished
        // result instead of falling into the outer catch (which answers status 0).
        try {
          await GenerationResult.updateOne({ task_id: taskId }, { $set: updateData });
        } catch (updateError) {
          logger.warn(
            `Failed to update read status for task ${taskId}: ${DatabaseGenerationService.describeError(updateError)}`
          );
        }

        // Tolerate a result record without `tos_urls` instead of throwing.
        const url = Array.isArray(result.tos_urls) ? result.tos_urls.join(',') : '';
        const response = DatabaseGenerationService.buildQueryResponse(
          taskId,
          url,
          result.status === 'success' ? -1 : -2
        );

        logger.info(`Task result retrieved: ${taskId}, status: ${result.status}, read_count: ${readCount}`);
        return response;
      }

      // 2. No result yet: fall back to the task record.
      const task = await GenerationTask.findOne({ task_id: taskId });
      if (!task) {
        // Unknown task.
        return DatabaseGenerationService.buildQueryResponse(taskId, '', 0);
      }

      // 3. Map the task status onto the client-facing status code.
      const responseStatus = DatabaseGenerationService.taskStatusCodes.get(task.status) ?? 0;

      logger.debug(`Task status queried: ${taskId}, status: ${task.status} -> ${responseStatus}`);
      return DatabaseGenerationService.buildQueryResponse(taskId, '', responseStatus);
    } catch (error) {
      logger.error(`Failed to query task result ${taskId}:`, error);
      // On error, answer with the "unknown task" shape plus a diagnostic message.
      return {
        ...DatabaseGenerationService.buildQueryResponse(taskId, '', 0),
        message: `接口报错 taskId: ${taskId} error:${DatabaseGenerationService.describeError(error)}`,
      };
    }
  }

  /**
   * Counts tasks per status, optionally restricted to one server.
   *
   * When `serverId` is given, a `server_load` entry (see `getServerLoad`) is added.
   * NeDB has no `aggregate()`, so statuses are counted client-side. Returns `{}` on error.
   */
  async getTaskStats(serverId?: string): Promise<any> {
    try {
      await NeDBManager.initialize();

      const filter = serverId ? { server_id: serverId } : {};
      const allTasks = await GenerationTask.find(filter);

      const stats: Record<string, number> = {};
      for (const task of allTasks) {
        stats[task.status] = (stats[task.status] ?? 0) + 1;
      }

      if (serverId) {
        stats['server_load'] = await this.getServerLoad(serverId);
      }

      return stats;
    } catch (error) {
      logger.error('Failed to get task stats:', error);
      return {};
    }
  }

  /** Number of `processing` / `polling` tasks owned by `serverId`; 0 on error. */
  async getServerLoad(serverId: string): Promise<number> {
    try {
      await NeDBManager.initialize();

      const activeTasks = await GenerationTask.find({
        server_id: serverId,
        status: { $in: ['processing', 'polling'] },
      });
      return activeTasks.length;
    } catch (error) {
      logger.error(`Failed to get server load for ${serverId}:`, error);
      return 0;
    }
  }

  /** Returns the raw task record for `taskId`, or `null` if absent or on error. */
  async getTaskDetail(taskId: string): Promise<IGenerationTask | null> {
    try {
      await NeDBManager.initialize();
      return await GenerationTask.findOne({ task_id: taskId });
    } catch (error) {
      logger.error(`Failed to get task detail ${taskId}:`, error);
      return null;
    }
  }

  /**
   * Cancels a task that is still `pending`, `processing` or `polling`.
   *
   * The task is marked `failed` with `error_message: 'Task cancelled by user'`, then a
   * matching `failed` result record is written on a best-effort basis.
   *
   * @returns `true` if the task was cancelled by this call, `false` if it was not
   *   cancellable (unknown / already finished) or the cancellation update failed.
   */
  async cancelTask(taskId: string): Promise<boolean> {
    try {
      await NeDBManager.initialize();

      const now = DatabaseGenerationService.nowSeconds();
      // The status filter makes this a conditional update; its return value is
      // treated as the number of tasks modified.
      const modified = await GenerationTask.updateOne(
        {
          task_id: taskId,
          status: { $in: ['pending', 'processing', 'polling'] },
        },
        {
          $set: {
            status: 'failed',
            error_message: 'Task cancelled by user',
            completed_at: now,
            updated_at: now,
          },
        }
      );

      const cancelled = modified > 0;
      if (!cancelled) {
        return false;
      }
    } catch (error) {
      logger.error(`Failed to cancel task ${taskId}:`, error);
      return false;
    }

    logger.info(`Task cancelled: ${taskId}`);
    // The cancellation is already committed, so this step must not change the outcome.
    await this.recordCancellationResult(taskId);
    return true;
  }

  /**
   * Best-effort: writes a `failed` result record for a just-cancelled task so the
   * cancellation is also visible in the result store. Errors are logged, never thrown.
   */
  private async recordCancellationResult(taskId: string): Promise<void> {
    try {
      const task = await GenerationTask.findOne({ task_id: taskId });
      if (!task) {
        return;
      }

      const createdAt = DatabaseGenerationService.nowSeconds();
      const expireTime = createdAt + DatabaseGenerationService.readIntEnv('RESULT_EXPIRE_TIME', 86400);

      await GenerationResult.create({
        task_id: taskId,
        task_type: task.task_type,
        server_id: task.server_id,
        status: 'failed',
        original_urls: [],
        tos_urls: [],
        metadata: {
          total_files: 0,
          successful_uploads: 0,
          fail_reason: 'Task cancelled by user',
        },
        created_at: createdAt,
        expires_at: expireTime,
        is_read: false,
        read_count: 0,
      });
    } catch (error) {
      logger.error(`Failed to record cancellation result for task ${taskId}:`, error);
    }
  }

  /**
   * Purges finished tasks and expired results.
   *
   * Deletes `completed`/`failed` tasks whose `completed_at` is older than twice
   * `IMAGE_TASK_TIMEOUT` (default 3600 s; the shorter image timeout is used for all
   * task types), and results whose `expires_at` has passed. The latter is a manual
   * backstop for the TTL-based expiry the original design relies on.
   *
   * @returns The counts reported by `deleteMany`; `{ tasks: 0, results: 0 }` on error.
   */
  async cleanupExpiredData(): Promise<{ tasks: number; results: number }> {
    try {
      await NeDBManager.initialize();

      const now = DatabaseGenerationService.nowSeconds();
      const taskTimeout = DatabaseGenerationService.readIntEnv('IMAGE_TASK_TIMEOUT', 3600);
      const cutoffTime = now - taskTimeout * 2;

      const deletedTasks = await GenerationTask.deleteMany({
        status: { $in: ['completed', 'failed'] },
        completed_at: { $lt: cutoffTime },
      });

      const deletedResults = await GenerationResult.deleteMany({
        expires_at: { $lt: now },
      });

      const cleanupStats = {
        tasks: deletedTasks || 0,
        results: deletedResults || 0,
      };

      if (cleanupStats.tasks > 0 || cleanupStats.results > 0) {
        logger.info(`Cleanup completed - tasks: ${cleanupStats.tasks}, results: ${cleanupStats.results}`);
      }

      return cleanupStats;
    } catch (error) {
      logger.error('Failed to cleanup expired data:', error);
      return { tasks: 0, results: 0 };
    }
  }

  /** Static identification of this service instance. */
  getServiceInfo() {
    return {
      serverId: this.currentServerId,
      type: 'DatabaseGenerationService',
      version: '1.0.0',
    };
  }
}
|
||
|
||
/** Process-wide singleton; every importer of the default export receives this same instance. */
const databaseGenerationService = DatabaseGenerationService.getInstance();
export default databaseGenerationService;