creative_studio/backend/app/api/v1/skills_async.py

226 lines
6.5 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
异步 Skill 生成 API
将同步的 Skill 生成改为异步任务模式
"""
from fastapi import APIRouter, HTTPException, Depends, BackgroundTasks
from typing import Dict, Any, Optional, List
from pydantic import BaseModel
import asyncio
import json
import re
from app.core.llm.glm_client import get_glm_client
from app.core.skills.skill_manager import get_skill_manager
from app.core.task_manager import get_task_manager, TaskManager
from app.models.task import TaskType, TaskStatus
from app.utils.logger import get_logger
logger = get_logger(__name__)
router = APIRouter(prefix="/skills/async", tags=["Skill 异步生成"])
# ============================================================================
# 请求模型
# ============================================================================
class GenerateSkillRequest(BaseModel):
"""生成 Skill 请求(异步)"""
description: str # 用户需求描述
category: Optional[str] = None
tags: Optional[List[str]] = None
temperature: float = 0.7
# ============================================================================
# 任务执行器
# ============================================================================
async def execute_generate_skill(
task_manager: TaskManager,
task_id: str,
params: Dict[str, Any]
) -> Dict[str, Any]:
"""执行 Skill 生成任务"""
try:
glm_client = get_glm_client()
skill_manager = get_skill_manager()
# 更新进度
task_manager.update_task_progress(
task_id, 10, 100,
"正在加载 skill-creator 标准..."
)
# 1. 加载 skill-creator 的行为指导
skill_creator = await skill_manager.load_skill("skill-creator")
if not skill_creator:
raise Exception("skill-creator 未找到,请确保内置 Skills 正确安装")
task_manager.update_task_progress(
task_id, 30, 100,
"正在构建提示词..."
)
# 2. 构建提示词
user_requirements = f"""用户想要创建的 Skill 描述:
{params['description']}
"""
if params.get('category'):
user_requirements += f"\n指定分类:{params['category']}"
system_prompt = f"""你是一个专业的 Skill 创建专家。
以下是 skill-creator 的行为指导(关于如何创建有效 Skill 的指南):
{'' * 60}
{skill_creator.behavior_guide}
{'' * 60}
你的任务是根据用户的需求,创建一个符合上述标准的 Skill。
**重要要求**
1. SKILL.md 必须以 YAML frontmatter 开始,包含 name 和 description 字段
2. description 应该清晰说明此 Skill 的用途和使用场景
3. 行为指导部分应该简洁、具体,避免冗余的解释
4. 使用 markdown 格式
5. 返回完整的 SKILL.md 内容
请以 JSON 格式返回结果,包含以下字段:
- suggested_id: 建议 Skill IDkebab-case如 dialogue-writer-ancient
- suggested_name: 建议 Skill 名称(简短中文)
- skill_content: 完整的 SKILL.md 内容
- category: 分类(如"编剧""审核""通用"等)
- suggested_tags: 建议标签数组
- explanation: 对生成的 Skill 的简要说明(中文)
"""
task_manager.update_task_progress(
task_id, 50, 100,
"正在调用 AI 生成..."
)
# 3. 调用 GLM 生成
response = await glm_client.chat(
messages=[
{"role": "system", "content": system_prompt},
{"role": "user", "content": user_requirements}
],
temperature=params.get('temperature', 0.7)
)
task_manager.update_task_progress(
task_id, 90, 100,
"正在处理结果..."
)
# 4. 解析响应
import json
import re
response_text = response["choices"][0]["message"]["content"]
json_match = re.search(r'\{[\s\S]*\}', response_text)
if json_match:
result = json.loads(json_match.group())
else:
result = {
"suggested_id": "custom-skill",
"suggested_name": "自定义 Skill",
"skill_content": response_text,
"category": params.get('category', '通用'),
"suggested_tags": ["自定义"],
"explanation": "AI 生成的 Skill 内容"
}
task_manager.update_task_progress(
task_id, 100, 100,
"生成完成!"
)
return result
except Exception as e:
logger.error(f"Skill 生成失败: {str(e)}")
raise
# ============================================================================
# 异步任务创建端点
# ============================================================================
@router.post("/generate")
async def generate_skill_async(
request: GenerateSkillRequest,
task_manager: TaskManager = Depends(get_task_manager)
):
"""
异步生成 Skill
返回任务ID需要通过轮询 /tasks/{task_id} 获取结果
"""
# 创建任务
task = task_manager.create_task(
task_type=TaskType.GENERATE_SKILL,
params=request.dict(),
project_id=None
)
# 在后台执行
async def run_task():
await task_manager.execute_task_async(
task.id,
lambda p: execute_generate_skill(task_manager, task.id, p)
)
asyncio.create_task(run_task())
return {
"success": True,
"taskId": task.id,
"message": "任务已创建,正在后台执行"
}
@router.get("/task/{task_id}")
async def get_skill_generation_task(
task_id: str,
task_manager: TaskManager = Depends(get_task_manager)
):
"""
获取技能生成任务状态
Args:
task_id: 任务ID
Returns:
任务详细信息
"""
task = task_manager.get_task(task_id)
if not task:
raise HTTPException(
status_code=404,
detail=f"任务不存在: {task_id}"
)
return task.model_dump()
@router.get("/tasks/running")
async def get_running_skill_tasks(
task_manager: TaskManager = Depends(get_task_manager)
):
"""
获取所有正在运行的技能生成任务
Returns:
正在运行的任务列表
"""
tasks = task_manager.get_tasks_by_type(TaskType.GENERATE_SKILL)
running_tasks = [task for task in tasks if task.status == TaskStatus.RUNNING]
return {
"success": True,
"tasks": [task.model_dump() for task in running_tasks],
"count": len(running_tasks)
}