creative_studio/backend/app/api/v1/skills_async.py
2026-02-03 01:12:39 +08:00

310 lines
9.5 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
异步 Skill 生成 API
将同步的 Skill 生成改为异步任务模式
特性:
1. 异步生成,后台执行
2. 生成完成后自动保存到 skill-storage
3. 支持关闭弹窗后继续生成
4. 任务持久化,重启后可恢复
"""
from fastapi import APIRouter, HTTPException, Depends, BackgroundTasks
from typing import Dict, Any, Optional, List
from pydantic import BaseModel
import asyncio
import json
import re
import uuid
from app.core.llm.glm_client import get_glm_client
from app.core.skills.skill_manager import get_skill_manager
from app.core.task_manager import get_task_manager, TaskManager
from app.models.task import TaskType, TaskStatus
from app.models.skill import SkillCreate
from app.utils.logger import get_logger
logger = get_logger(__name__)
router = APIRouter(prefix="/skills/async", tags=["Skill 异步生成"])
# ============================================================================
# 请求模型
# ============================================================================
class GenerateSkillRequest(BaseModel):
"""生成 Skill 请求(异步)"""
description: str # 用户需求描述
category: Optional[str] = None
tags: Optional[List[str]] = None
temperature: float = 0.7
# ============================================================================
# 任务执行器
# ============================================================================
async def execute_generate_skill_with_id(
task_id: str,
params: Dict[str, Any]
) -> Dict[str, Any]:
"""执行 Skill 生成任务(生成完成后自动保存)
Args:
task_id: 任务ID
params: 任务参数,包含 description, category, tags, temperature 等
Returns:
生成的 Skill 数据
"""
try:
task_manager = get_task_manager()
glm_client = get_glm_client()
skill_manager = get_skill_manager()
# 更新进度
task_manager.update_task_progress(
task_id, 10, 100,
"正在加载 skill-creator 标准..."
)
# 1. 加载 skill-creator 的行为指导
skill_creator = await skill_manager.load_skill("skill-creator")
if not skill_creator:
raise Exception("skill-creator 未找到,请确保内置 Skills 正确安装")
task_manager.update_task_progress(
task_id, 30, 100,
"正在构建提示词..."
)
# 2. 构建提示词
user_requirements = f"""用户想要创建的 Skill 描述:
{params['description']}
"""
if params.get('category'):
user_requirements += f"\n指定分类:{params['category']}"
system_prompt = f"""你是一个专业的 Skill 创建专家。
以下是 skill-creator 的行为指导(关于如何创建有效 Skill 的指南):
{'' * 60}
{skill_creator.behavior_guide}
{'' * 60}
你的任务是根据用户的需求,创建一个符合上述标准的 Skill。
**重要要求**
1. SKILL.md 必须以 YAML frontmatter 开始,包含 name 和 description 字段
2. description 应该清晰说明此 Skill 的用途和使用场景
3. 行为指导部分应该简洁、具体,避免冗余的解释
4. 使用 markdown 格式
5. 返回完整的 SKILL.md 内容
请以 JSON 格式返回结果,包含以下字段:
- suggested_id: 建议 Skill IDkebab-case如 dialogue-writer-ancient
- suggested_name: 建议 Skill 名称(简短中文)
- skill_content: 完整的 SKILL.md 内容
- category: 分类(如"编剧""审核""通用"等)
- suggested_tags: 建议标签数组
- explanation: 对生成的 Skill 的简要说明(中文)
"""
task_manager.update_task_progress(
task_id, 50, 100,
"正在调用 AI 生成..."
)
# 3. 调用 GLM 生成
response = await glm_client.chat(
messages=[
{"role": "system", "content": system_prompt},
{"role": "user", "content": user_requirements}
],
temperature=params.get('temperature', 0.7)
)
task_manager.update_task_progress(
task_id, 75, 100,
"正在解析结果..."
)
# 4. 解析响应
response_text = response["choices"][0]["message"]["content"]
json_match = re.search(r'\{[\s\S]*\}', response_text)
if json_match:
result = json.loads(json_match.group())
else:
result = {
"suggested_id": "custom-skill",
"suggested_name": "自定义 Skill",
"skill_content": response_text,
"category": params.get('category', '通用'),
"suggested_tags": ["自定义"],
"explanation": "AI 生成的 Skill 内容"
}
skill_content = result.get("skill_content", "")
suggested_id = result.get("suggested_id", "custom-skill")
suggested_name = result.get("suggested_name", "自定义 Skill")
category = result.get("category", params.get('category', '通用'))
suggested_tags = result.get("suggested_tags", [])
task_manager.update_task_progress(
task_id, 90, 100,
"正在自动保存 Skill..."
)
# 5. 自动保存到 storage
# 确保 ID 唯一(如果已存在,添加随机后缀)
final_skill_id = suggested_id
counter = 1
while True:
try:
# 检查 skill 是否已存在
existing_skill = await skill_manager.load_skill(final_skill_id)
if existing_skill:
# ID 已存在,添加后缀
final_skill_id = f"{suggested_id}-{counter}"
counter += 1
else:
break
except:
# 加载失败说明不存在,可以使用这个 ID
break
# 创建并保存 skill
skill_data = SkillCreate(
id=final_skill_id,
name=suggested_name,
content=skill_content,
category=category,
tags=suggested_tags
)
try:
saved_skill = await skill_manager.create_user_skill(skill_data)
logger.info(f"自动保存 Skill 成功: {final_skill_id}")
# 更新结果中的 ID 为实际保存的 ID
result["suggested_id"] = final_skill_id
result["saved_skill_id"] = final_skill_id
result["auto_saved"] = True
except Exception as save_error:
logger.error(f"自动保存 Skill 失败: {str(save_error)}")
result["auto_saved"] = False
result["save_error"] = str(save_error)
# 即使保存失败,也返回生成的结果
result["suggested_id"] = final_skill_id
task_manager.update_task_progress(
task_id, 100, 100,
f"生成完成Skill 已保存为: {final_skill_id}"
)
return result
except Exception as e:
logger.error(f"Skill 生成失败: {str(e)}")
raise
async def execute_generate_skill(params: Dict[str, Any]) -> Dict[str, Any]:
"""兼容旧版本的执行函数通过查找当前任务获取task_id"""
task_manager = get_task_manager()
# 获取当前正在运行的任务
running_tasks = task_manager.get_tasks_by_type(TaskType.GENERATE_SKILL)
current_task = None
for task in running_tasks:
if task.status == TaskStatus.RUNNING:
current_task = task
break
if not current_task:
raise Exception("找不到正在运行的任务")
return await execute_generate_skill_with_id(current_task.id, params)
# ============================================================================
# 异步任务创建端点
# ============================================================================
@router.post("/generate")
async def generate_skill_async(
request: GenerateSkillRequest,
task_manager: TaskManager = Depends(get_task_manager)
):
"""
异步生成 Skill
返回任务ID需要通过轮询 /tasks/{task_id} 获取结果
"""
# 创建任务
task = task_manager.create_task(
task_type=TaskType.GENERATE_SKILL,
params=request.dict(),
project_id=None
)
# 在后台执行
async def run_task():
await task_manager.execute_task_async(
task.id,
execute_generate_skill
)
asyncio.create_task(run_task())
return {
"success": True,
"taskId": task.id,
"message": "任务已创建,正在后台执行"
}
@router.get("/task/{task_id}")
async def get_skill_generation_task(
task_id: str,
task_manager: TaskManager = Depends(get_task_manager)
):
"""
获取技能生成任务状态
Args:
task_id: 任务ID
Returns:
任务详细信息
"""
task = task_manager.get_task(task_id)
if not task:
raise HTTPException(
status_code=404,
detail=f"任务不存在: {task_id}"
)
return task.model_dump()
@router.get("/tasks/running")
async def get_running_skill_tasks(
task_manager: TaskManager = Depends(get_task_manager)
):
"""
获取所有正在运行的技能生成任务
Returns:
正在运行的任务列表
"""
tasks = task_manager.get_tasks_by_type(TaskType.GENERATE_SKILL)
running_tasks = [task for task in tasks if task.status == TaskStatus.RUNNING]
return {
"success": True,
"tasks": [task.model_dump() for task in running_tasks],
"count": len(running_tasks)
}