433 lines
12 KiB
Python
433 lines
12 KiB
Python
"""
|
||
异步 AI 辅助生成 API
|
||
|
||
将原有的同步 AI 生成改为异步任务模式
|
||
"""
|
||
from fastapi import APIRouter, HTTPException, Depends, BackgroundTasks
|
||
from typing import Dict, Any, Optional, List
|
||
from pydantic import BaseModel
|
||
import asyncio
|
||
|
||
from app.core.llm.glm_client import get_glm_client
|
||
from app.core.skills.skill_manager import get_skill_manager
|
||
from app.core.task_manager import get_task_manager, TaskManager
|
||
from app.models.task import TaskType, TaskStatus
|
||
from app.api.v1.ai_assistant import build_enhanced_system_prompt
|
||
from app.utils.logger import get_logger
|
||
|
||
logger = get_logger(__name__)
|
||
|
||
router = APIRouter(prefix="/ai-assistant/async", tags=["AI 异步生成"])
|
||
|
||
|
||
# ============================================================================
|
||
# 请求模型
|
||
# ============================================================================
|
||
|
||
class SkillInfo(BaseModel):
|
||
"""Skill 信息(由前端传递)"""
|
||
id: str
|
||
name: str
|
||
behavior: str # behavior_guide 内容
|
||
|
||
|
||
# ============================================================================
|
||
# 任务执行器
|
||
# ============================================================================
|
||
|
||
async def execute_generate_characters(
|
||
task_manager: TaskManager,
|
||
task_id: str,
|
||
params: Dict[str, Any]
|
||
) -> Dict[str, Any]:
|
||
"""执行人物生成任务"""
|
||
try:
|
||
glm_client = get_glm_client()
|
||
skill_manager = get_skill_manager()
|
||
|
||
# 更新进度
|
||
task_manager.update_task_progress(
|
||
task_id, 10, 100,
|
||
"正在构建提示词..."
|
||
)
|
||
|
||
# 构建系统提示词
|
||
base_role = "你是专业的剧集创作专家,擅长创作丰富立体的人物角色。"
|
||
skills = params.get("skills", [])
|
||
system_prompt = await build_enhanced_system_prompt(
|
||
base_role=base_role,
|
||
skills=skills,
|
||
skill_manager=skill_manager
|
||
)
|
||
|
||
# 构建用户提示
|
||
extra_info = ""
|
||
if params.get("projectName"):
|
||
extra_info += f"\n项目名称:{params['projectName']}"
|
||
if params.get("totalEpisodes"):
|
||
extra_info += f"\n总集数:{params['totalEpisodes']}"
|
||
if params.get("genre"):
|
||
extra_info += f"\n类型:{params['genre']}"
|
||
|
||
custom_requirements = ""
|
||
if params.get("customPrompt"):
|
||
custom_requirements = f"\n【用户自定义要求】\n{params['customPrompt']}\n"
|
||
|
||
prompt = f"""请根据以下想法生成 3-5 个主要人物设定:
|
||
|
||
用户想法:{params['idea']}{extra_info}{custom_requirements}
|
||
|
||
要求:
|
||
1. 每个人物包含:姓名、身份、性格、说话风格、背景故事
|
||
2. 人物之间要有关系冲突
|
||
3. 每个人物 50-100 字
|
||
4. 格式:姓名:身份 - 性格 - 说话风格 - 背景故事
|
||
5. 严格遵守上面【应用技能指导】中的要求
|
||
|
||
请按以下格式输出:
|
||
【人物1】
|
||
姓名:xxx
|
||
身份:xxx
|
||
性格:xxx
|
||
说话风格:xxx
|
||
背景故事:xxx
|
||
【人物2】
|
||
...
|
||
"""
|
||
|
||
task_manager.update_task_progress(
|
||
task_id, 30, 100,
|
||
"正在调用 AI 生成..."
|
||
)
|
||
|
||
response = await glm_client.chat(
|
||
messages=[
|
||
{"role": "system", "content": system_prompt},
|
||
{"role": "user", "content": prompt}
|
||
],
|
||
temperature=0.9
|
||
)
|
||
|
||
task_manager.update_task_progress(
|
||
task_id, 90, 100,
|
||
"正在处理结果..."
|
||
)
|
||
|
||
content = response["choices"][0]["message"]["content"]
|
||
|
||
task_manager.update_task_progress(
|
||
task_id, 100, 100,
|
||
"生成完成!"
|
||
)
|
||
|
||
return {
|
||
"success": True,
|
||
"characters": content,
|
||
"usage": response.get("usage")
|
||
}
|
||
|
||
except Exception as e:
|
||
logger.error(f"人物生成失败: {str(e)}")
|
||
raise
|
||
|
||
|
||
async def execute_generate_outline(
|
||
task_manager: TaskManager,
|
||
task_id: str,
|
||
params: Dict[str, Any]
|
||
) -> Dict[str, Any]:
|
||
"""执行大纲生成任务"""
|
||
try:
|
||
glm_client = get_glm_client()
|
||
skill_manager = get_skill_manager()
|
||
|
||
task_manager.update_task_progress(
|
||
task_id, 10, 100,
|
||
"正在构建提示词..."
|
||
)
|
||
|
||
base_role = "你是专业的剧集创作专家,擅长构建引人入胜的剧情结构和故事节奏。"
|
||
skills = params.get("skills", [])
|
||
system_prompt = await build_enhanced_system_prompt(
|
||
base_role=base_role,
|
||
skills=skills,
|
||
skill_manager=skill_manager
|
||
)
|
||
|
||
custom_requirements = ""
|
||
if params.get("customPrompt"):
|
||
custom_requirements = f"\n【用户自定义要求】\n{params['customPrompt']}\n"
|
||
|
||
prompt = f"""请根据以下想法生成完整的剧集大纲:
|
||
|
||
用户想法:{params['idea']}
|
||
总集数:{params['totalEpisodes']}
|
||
类型:{params['genre']}
|
||
{f"项目名称:{params['projectName']}" if params.get('projectName') else ''}{custom_requirements}
|
||
|
||
要求:
|
||
1. 将故事分为 4-5 个阶段
|
||
2. 每个阶段包含具体的集数范围
|
||
3. 标注每个阶段的关键事件和转折点
|
||
4. 字数 200-400 字
|
||
5. 严格遵守上面【应用技能指导】中的要求
|
||
|
||
请按以下格式输出:
|
||
【阶段1】EPxx-EPxx:阶段名称
|
||
内容概要...
|
||
|
||
【阶段2】EPxx-EPxx:阶段名称
|
||
...
|
||
"""
|
||
|
||
task_manager.update_task_progress(
|
||
task_id, 30, 100,
|
||
"正在调用 AI 生成..."
|
||
)
|
||
|
||
response = await glm_client.chat(
|
||
messages=[
|
||
{"role": "system", "content": system_prompt},
|
||
{"role": "user", "content": prompt}
|
||
],
|
||
temperature=0.9
|
||
)
|
||
|
||
task_manager.update_task_progress(
|
||
task_id, 90, 100,
|
||
"正在处理结果..."
|
||
)
|
||
|
||
content = response["choices"][0]["message"]["content"]
|
||
|
||
task_manager.update_task_progress(
|
||
task_id, 100, 100,
|
||
"生成完成!"
|
||
)
|
||
|
||
return {
|
||
"success": True,
|
||
"outline": content,
|
||
"usage": response.get("usage")
|
||
}
|
||
|
||
except Exception as e:
|
||
logger.error(f"大纲生成失败: {str(e)}")
|
||
raise
|
||
|
||
|
||
async def execute_generate_world(
|
||
task_manager: TaskManager,
|
||
task_id: str,
|
||
params: Dict[str, Any]
|
||
) -> Dict[str, Any]:
|
||
"""执行世界观生成任务"""
|
||
try:
|
||
glm_client = get_glm_client()
|
||
skill_manager = get_skill_manager()
|
||
|
||
task_manager.update_task_progress(
|
||
task_id, 10, 100,
|
||
"正在构建提示词..."
|
||
)
|
||
|
||
base_role = "你是专业的世界观设定专家,擅长构建架空世界的背景设定。"
|
||
skills = params.get("skills", [])
|
||
system_prompt = await build_enhanced_system_prompt(
|
||
base_role=base_role,
|
||
skills=skills,
|
||
skill_manager=skill_manager
|
||
)
|
||
|
||
custom_requirements = ""
|
||
if params.get("customPrompt"):
|
||
custom_requirements = f"\n【用户自定义要求】\n{params['customPrompt']}\n"
|
||
|
||
prompt = f"""请根据以下想法生成世界观设定:
|
||
|
||
用户想法:{params['idea']}
|
||
类型:{params['genre']}
|
||
{f"项目名称:{params['projectName']}" if params.get('projectName') else ''}{custom_requirements}
|
||
|
||
要求:
|
||
1. 描述时代背景(朝代、架空世界等)
|
||
2. 描述地理环境和主要场景
|
||
3. 描述社会结构(权力体系、阶级关系)
|
||
4. 描述文化特色(习俗、服饰、语言等)
|
||
5. 字数 200-500 字
|
||
6. 严格遵守上面【应用技能指导】中的要求
|
||
|
||
请输出详细的世界观设定:
|
||
"""
|
||
|
||
task_manager.update_task_progress(
|
||
task_id, 30, 100,
|
||
"正在调用 AI 生成..."
|
||
)
|
||
|
||
response = await glm_client.chat(
|
||
messages=[
|
||
{"role": "system", "content": system_prompt},
|
||
{"role": "user", "content": prompt}
|
||
],
|
||
temperature=0.9
|
||
)
|
||
|
||
task_manager.update_task_progress(
|
||
task_id, 90, 100,
|
||
"正在处理结果..."
|
||
)
|
||
|
||
content = response["choices"][0]["message"]["content"]
|
||
|
||
task_manager.update_task_progress(
|
||
task_id, 100, 100,
|
||
"生成完成!"
|
||
)
|
||
|
||
return {
|
||
"success": True,
|
||
"worldSetting": content,
|
||
"usage": response.get("usage")
|
||
}
|
||
|
||
except Exception as e:
|
||
logger.error(f"世界观生成失败: {str(e)}")
|
||
raise
|
||
|
||
|
||
# ============================================================================
|
||
# 异步任务创建端点
|
||
# ============================================================================
|
||
|
||
class GenerateCharactersRequest(BaseModel):
|
||
"""生成人物设定请求(异步)"""
|
||
idea: str
|
||
projectName: Optional[str] = None
|
||
totalEpisodes: Optional[int] = None
|
||
genre: Optional[str] = "古风"
|
||
skills: Optional[List[SkillInfo]] = None
|
||
customPrompt: Optional[str] = None
|
||
projectId: Optional[str] = None # 关联项目ID
|
||
|
||
|
||
class GenerateOutlineRequest(BaseModel):
|
||
"""生成大纲请求(异步)"""
|
||
idea: str
|
||
totalEpisodes: int = 30
|
||
genre: str = "古风"
|
||
projectName: Optional[str] = None
|
||
skills: Optional[List[SkillInfo]] = None
|
||
customPrompt: Optional[str] = None
|
||
projectId: Optional[str] = None
|
||
|
||
|
||
class GenerateWorldRequest(BaseModel):
|
||
"""生成世界观请求(异步)"""
|
||
idea: str
|
||
projectName: Optional[str] = None
|
||
genre: Optional[str] = "古风"
|
||
skills: Optional[List[SkillInfo]] = None
|
||
customPrompt: Optional[str] = None
|
||
projectId: Optional[str] = None
|
||
|
||
|
||
@router.post("/generate/characters")
|
||
async def generate_characters_async(
|
||
request: GenerateCharactersRequest,
|
||
background_tasks: BackgroundTasks,
|
||
task_manager: TaskManager = Depends(get_task_manager)
|
||
):
|
||
"""
|
||
异步生成人物设定
|
||
|
||
返回任务ID,需要通过轮询 /tasks/{task_id} 获取结果
|
||
"""
|
||
# 创建任务
|
||
task = task_manager.create_task(
|
||
task_type=TaskType.GENERATE_CHARACTERS,
|
||
params=request.dict(exclude={"projectId"}),
|
||
project_id=request.projectId
|
||
)
|
||
|
||
# 在后台执行
|
||
async def run_task():
|
||
await task_manager.execute_task_async(
|
||
task.id,
|
||
lambda p: execute_generate_characters(task_manager, task.id, p)
|
||
)
|
||
|
||
asyncio.create_task(run_task())
|
||
|
||
return {
|
||
"success": True,
|
||
"taskId": task.id,
|
||
"message": "任务已创建,正在后台执行"
|
||
}
|
||
|
||
|
||
@router.post("/generate/outline")
|
||
async def generate_outline_async(
|
||
request: GenerateOutlineRequest,
|
||
task_manager: TaskManager = Depends(get_task_manager)
|
||
):
|
||
"""
|
||
异步生成大纲
|
||
|
||
返回任务ID,需要通过轮询 /tasks/{task_id} 获取结果
|
||
"""
|
||
# 创建任务
|
||
task = task_manager.create_task(
|
||
task_type=TaskType.GENERATE_OUTLINE,
|
||
params=request.dict(exclude={"projectId"}),
|
||
project_id=request.projectId
|
||
)
|
||
|
||
# 在后台执行
|
||
async def run_task():
|
||
await task_manager.execute_task_async(
|
||
task.id,
|
||
lambda p: execute_generate_outline(task_manager, task.id, p)
|
||
)
|
||
|
||
asyncio.create_task(run_task())
|
||
|
||
return {
|
||
"success": True,
|
||
"taskId": task.id,
|
||
"message": "任务已创建,正在后台执行"
|
||
}
|
||
|
||
|
||
@router.post("/generate/world")
|
||
async def generate_world_async(
|
||
request: GenerateWorldRequest,
|
||
task_manager: TaskManager = Depends(get_task_manager)
|
||
):
|
||
"""
|
||
异步生成世界观设定
|
||
|
||
返回任务ID,需要通过轮询 /tasks/{task_id} 获取结果
|
||
"""
|
||
# 创建任务
|
||
task = task_manager.create_task(
|
||
task_type=TaskType.GENERATE_WORLD,
|
||
params=request.dict(exclude={"projectId"}),
|
||
project_id=request.projectId
|
||
)
|
||
|
||
# 在后台执行
|
||
async def run_task():
|
||
await task_manager.execute_task_async(
|
||
task.id,
|
||
lambda p: execute_generate_world(task_manager, task.id, p)
|
||
)
|
||
|
||
asyncio.create_task(run_task())
|
||
|
||
return {
|
||
"success": True,
|
||
"taskId": task.id,
|
||
"message": "任务已创建,正在后台执行"
|
||
}
|