730 lines
20 KiB
Python
730 lines
20 KiB
Python
"""
|
|
记忆系统 API 路由
|
|
|
|
Memory System API - 提供记忆系统的管理接口
|
|
"""
|
|
from fastapi import APIRouter, Depends, HTTPException, status
|
|
from typing import List, Dict
|
|
from datetime import datetime
|
|
import uuid
|
|
|
|
from app.models.memory import (
|
|
EnhancedMemory,
|
|
TimelineEvent,
|
|
PendingThread,
|
|
ThreadCreateRequest,
|
|
ThreadUpdateRequest,
|
|
ThreadResolveRequest,
|
|
MemoryUpdateRequest,
|
|
MemoryExtractionResult,
|
|
CharacterStateChange,
|
|
ResolutionSuggestion,
|
|
ThreadStatus,
|
|
ImportanceLevel
|
|
)
|
|
from app.models.project import SeriesProject, Episode
|
|
from app.core.memory.memory_manager import get_memory_manager, MemoryManager
|
|
from app.db.repositories import project_repo, episode_repo
|
|
from app.utils.logger import get_logger
|
|
|
|
# Module-level logger obtained from the project's logging helper.
logger = get_logger(__name__)

# Every route in this module is mounted under /projects/{project_id}/memory,
# so each endpoint receives `project_id` as a path parameter.
router = APIRouter(
    prefix="/projects/{project_id}/memory",
    tags=["记忆系统"],
)
|
|
|
|
|
|
# ============================================
|
|
# 辅助函数
|
|
# ============================================
|
|
|
|
async def verify_project(project_id: str) -> SeriesProject:
    """Load a project by id, or fail with HTTP 404.

    Args:
        project_id: Identifier of the project to look up.

    Returns:
        The project object returned by ``project_repo.get``.

    Raises:
        HTTPException: 404 when the repository returns a falsy value.
    """
    found = await project_repo.get(project_id)
    if found:
        return found
    raise HTTPException(
        status_code=status.HTTP_404_NOT_FOUND,
        detail=f"项目不存在: {project_id}",
    )
|
|
|
|
|
|
# ============================================
|
|
# 记忆系统主端点
|
|
# ============================================
|
|
|
|
@router.get("/", response_model=EnhancedMemory)
async def get_project_memory(project_id: str):
    """Return the project's complete memory as an ``EnhancedMemory``.

    The converted memory includes the event timeline, pending threads,
    character state history, foreshadowing tracking and character
    relationships.
    """
    project = await verify_project(project_id)
    return get_memory_manager()._convert_to_enhanced_memory(project.memory)
|
|
|
|
|
|
@router.post("/update", response_model=MemoryExtractionResult)
async def update_memory_from_episode(
    project_id: str,
    request: MemoryUpdateRequest
):
    """Update the memory system from one episode and persist the result.

    Delegates to ``MemoryManager.update_memory_from_episode``. Per the
    original documentation, that step extracts key events, detects
    foreshadowing and pending threads, updates character states and checks
    consistency. The project's memory is then saved through ``project_repo``.

    Args:
        project_id: Project identifier (path parameter).
        request: Memory update request; ``episode_number`` selects the episode.

    Returns:
        The ``MemoryExtractionResult`` produced by the memory manager.

    Raises:
        HTTPException: 404 if the project or episode does not exist, 400 if
            the episode has no content, 500 if extraction or saving fails.
    """
    project = await verify_project(project_id)

    # Locate the requested episode among the project's episodes.
    episodes = await episode_repo.list_by_project(project_id)
    episode = next(
        (ep for ep in episodes if ep.number == request.episode_number),
        None,
    )

    if not episode:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail=f"剧集不存在: EP{request.episode_number}"
        )

    if not episode.content:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="剧集内容为空,无法更新记忆"
        )

    try:
        memory_manager = get_memory_manager()
        result = await memory_manager.update_memory_from_episode(project, episode)

        # NOTE(review): this saves ``project.memory`` as-is, so it relies on
        # update_memory_from_episode mutating ``project.memory`` in place.
        # Confirm against MemoryManager.
        await project_repo.update(project_id, {
            "memory": project.memory.dict()
        })

        return result

    except Exception as e:
        logger.error(f"更新记忆失败: {str(e)}")
        # Chain the cause so the original traceback survives the 500 response.
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"更新记忆失败: {str(e)}"
        ) from e
|
|
|
|
|
|
# ============================================
|
|
# 事件时间线
|
|
# ============================================
|
|
|
|
@router.get("/timeline", response_model=List[TimelineEvent])
async def get_event_timeline(
    project_id: str,
    skip: int = 0,
    limit: int = 100,
    episode: int = None,
    importance: ImportanceLevel = None,
    character: str = None
):
    """Return the event timeline, optionally filtered and paginated.

    Filters:
        - episode: keep only events from this episode number
        - importance: keep only events with this importance level
        - character: keep only events whose ``characters_involved`` contains it

    Events are sorted by episode before pagination. ``skip`` and ``limit``
    are clamped to be non-negative.
    """
    project = await verify_project(project_id)
    memory_manager = get_memory_manager()
    enhanced_memory = memory_manager._convert_to_enhanced_memory(project.memory)

    events = enhanced_memory.eventTimeline

    # Apply the optional filters.
    if episode is not None:
        events = [e for e in events if e.episode == episode]

    if importance is not None:
        events = [e for e in events if e.importance == importance]

    if character:
        events = [e for e in events if character in e.characters_involved]

    # Sort by episode number.
    events = sorted(events, key=lambda x: x.episode)

    # Clamp pagination inputs. A negative skip or limit would otherwise turn
    # the slice into count-from-the-end semantics (e.g. skip=-5 returns the
    # last 5 events).
    start = max(skip, 0)
    count = max(limit, 0)
    return events[start:start + count]
|
|
|
|
|
|
@router.get("/timeline/episodes", response_model=List[dict])
async def get_timeline_by_episodes(project_id: str):
    """Return timeline events grouped by episode, in ascending episode order.

    Response shape: ``[{"episode": 1, "events": [...]}, ...]``.
    """
    project = await verify_project(project_id)
    memory_manager = get_memory_manager()
    enhanced_memory = memory_manager._convert_to_enhanced_memory(project.memory)

    # Bucket events by episode number, preserving each bucket's original order.
    grouped = {}
    for event in enhanced_memory.eventTimeline:
        grouped.setdefault(event.episode, []).append(event)

    return [
        {"episode": number, "events": grouped[number]}
        for number in sorted(grouped)
    ]
|
|
|
|
|
|
# ============================================
|
|
# 待收线管理
|
|
# ============================================
|
|
|
|
@router.get("/threads", response_model=List[PendingThread])
async def get_pending_threads(
    project_id: str,
    status: ThreadStatus = None,
    importance: ImportanceLevel = None
):
    """Return pending threads, optionally filtered, sorted by ``introduced_at``.

    Filters:
        - status: keep only threads with this status
        - importance: keep only threads with this importance level

    Note: the ``status`` query parameter shadows the ``fastapi.status``
    module inside this function, so the module is not referenced here.
    """
    project = await verify_project(project_id)
    memory_manager = get_memory_manager()
    enhanced_memory = memory_manager._convert_to_enhanced_memory(project.memory)

    def _matches(thread) -> bool:
        status_ok = status is None or thread.status == status
        importance_ok = importance is None or thread.importance == importance
        return status_ok and importance_ok

    selected = [t for t in enhanced_memory.pendingThreads if _matches(t)]
    return sorted(selected, key=lambda t: t.introduced_at)
|
|
|
|
|
|
@router.post("/threads", response_model=PendingThread, status_code=status.HTTP_201_CREATED)
async def create_thread(
    project_id: str,
    request: ThreadCreateRequest
):
    """Manually add a pending thread and persist it.

    Lets users record important foreshadowing that LLM extraction missed.
    The new thread is stamped with the memory's ``last_episode_processed``
    as its ``introduced_at`` and gets a fresh UUID4 id.

    Returns:
        The newly created ``PendingThread`` (HTTP 201).

    Raises:
        HTTPException: 404 if the project does not exist.
    """
    project = await verify_project(project_id)

    memory_manager = get_memory_manager()
    enhanced_memory = memory_manager._convert_to_enhanced_memory(project.memory)
    current_episode = enhanced_memory.last_episode_processed

    # Take a single timestamp so created_at == updated_at on a fresh thread.
    now = datetime.now()
    thread = PendingThread(
        id=str(uuid.uuid4()),
        description=request.description,
        introduced_at=current_episode,
        importance=request.importance,
        reminder_episode=request.reminder_episode,
        characters_involved=request.characters_involved,
        notes=request.notes,
        created_at=now,
        updated_at=now
    )

    # Append to memory and convert back to the stored representation.
    enhanced_memory.pendingThreads.append(thread)
    project.memory = memory_manager._convert_to_memory(enhanced_memory)

    await project_repo.update(project_id, {
        "memory": project.memory.dict()
    })

    logger.info(f"创建待收线: {thread.id} - {thread.description}")
    return thread
|
|
|
|
|
|
@router.put("/threads/{thread_id}", response_model=PendingThread)
async def update_thread(
    project_id: str,
    thread_id: str,
    request: ThreadUpdateRequest
):
    """Partially update a pending thread and persist it.

    Only fields the request supplies (non-None) are changed: description,
    importance, status, reminder_episode, characters_involved and notes.
    When ``status`` is supplied, the boolean ``resolved`` flag is kept
    consistent with it (True exactly when the status is RESOLVED).
    ``get_memory_stats`` counts threads by ``resolved``, so this keeps the
    stats accurate.

    Raises:
        HTTPException: 404 if the project or thread does not exist.
    """
    project = await verify_project(project_id)
    memory_manager = get_memory_manager()
    enhanced_memory = memory_manager._convert_to_enhanced_memory(project.memory)

    thread = next(
        (t for t in enhanced_memory.pendingThreads if t.id == thread_id),
        None,
    )

    if not thread:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail=f"待收线不存在: {thread_id}"
        )

    # Copy every field the client supplied; None means "leave unchanged".
    for field in (
        "description",
        "importance",
        "status",
        "reminder_episode",
        "characters_involved",
        "notes",
    ):
        value = getattr(request, field)
        if value is not None:
            setattr(thread, field, value)

    # Keep `resolved` in step with `status` (resolve_thread sets both).
    if request.status is not None:
        thread.resolved = request.status == ThreadStatus.RESOLVED

    thread.updated_at = datetime.now()

    project.memory = memory_manager._convert_to_memory(enhanced_memory)
    await project_repo.update(project_id, {
        "memory": project.memory.dict()
    })

    logger.info(f"更新待收线: {thread_id}")
    return thread
|
|
|
|
|
|
@router.post("/resolve-thread", response_model=dict)
async def resolve_thread(
    project_id: str,
    request: ThreadResolveRequest
):
    """Mark a pending thread as resolved and persist it.

    Sets ``resolved``, ``resolved_at`` and ``status`` (RESOLVED) on the
    thread. If a resolution summary is given, it is appended to the
    thread's notes on its own line.

    Returns:
        A dict with ``success``, ``message``, ``thread_id`` and ``resolved_at``.

    Raises:
        HTTPException: 404 if the project or thread does not exist.
    """
    project = await verify_project(project_id)
    memory_manager = get_memory_manager()
    enhanced_memory = memory_manager._convert_to_enhanced_memory(project.memory)

    thread = next(
        (t for t in enhanced_memory.pendingThreads if t.id == request.thread_id),
        None,
    )

    if not thread:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail=f"待收线不存在: {request.thread_id}"
        )

    thread.resolved = True
    thread.resolved_at = request.resolved_at
    thread.status = ThreadStatus.RESOLVED
    thread.updated_at = datetime.now()

    # Append the summary to the notes. `notes` may be None or empty (it is
    # copied straight from the create request). The original `+=` crashed on
    # None and left a leading newline on empty notes.
    if request.resolution_summary:
        summary_line = f"[收线摘要 EP{request.resolved_at}]: {request.resolution_summary}"
        thread.notes = f"{thread.notes}\n{summary_line}" if thread.notes else summary_line

    project.memory = memory_manager._convert_to_memory(enhanced_memory)
    await project_repo.update(project_id, {
        "memory": project.memory.dict()
    })

    logger.info(f"收线完成: {request.thread_id} @ EP{request.resolved_at}")

    return {
        "success": True,
        "message": "待收线已收线",
        "thread_id": request.thread_id,
        "resolved_at": request.resolved_at
    }
|
|
|
|
|
|
@router.delete("/threads/{thread_id}", status_code=status.HTTP_204_NO_CONTENT)
async def delete_thread(
    project_id: str,
    thread_id: str
):
    """Remove a pending thread by id and persist the change (HTTP 204).

    Raises:
        HTTPException: 404 if the project does not exist or no thread
            with ``thread_id`` was removed.
    """
    project = await verify_project(project_id)
    memory_manager = get_memory_manager()
    enhanced_memory = memory_manager._convert_to_enhanced_memory(project.memory)

    count_before = len(enhanced_memory.pendingThreads)
    enhanced_memory.pendingThreads = [
        t for t in enhanced_memory.pendingThreads if t.id != thread_id
    ]
    removed = count_before - len(enhanced_memory.pendingThreads)

    # Nothing filtered out means the id was unknown.
    if removed == 0:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail=f"待收线不存在: {thread_id}"
        )

    project.memory = memory_manager._convert_to_memory(enhanced_memory)
    await project_repo.update(project_id, {
        "memory": project.memory.dict()
    })

    logger.info(f"删除待收线: {thread_id}")
    return None
|
|
|
|
|
|
@router.get("/threads/{thread_id}/suggestion", response_model=ResolutionSuggestion)
async def get_thread_resolution_suggestion(
    project_id: str,
    thread_id: str
):
    """Ask the memory manager for a suggestion on how to resolve a thread.

    The suggestion is produced by ``MemoryManager.suggest_thread_resolution``
    (described in the original docs as LLM-generated). It receives the
    thread, the memory's ``last_episode_processed`` and the full memory.

    Raises:
        HTTPException: 404 if the project or thread does not exist, 500 if
            generating the suggestion fails.
    """
    project = await verify_project(project_id)
    memory_manager = get_memory_manager()
    enhanced_memory = memory_manager._convert_to_enhanced_memory(project.memory)

    thread = next(
        (t for t in enhanced_memory.pendingThreads if t.id == thread_id),
        None,
    )

    if not thread:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail=f"待收线不存在: {thread_id}"
        )

    try:
        suggestion = await memory_manager.suggest_thread_resolution(
            thread=thread,
            current_episode=enhanced_memory.last_episode_processed,
            memory=enhanced_memory
        )

        return suggestion

    except Exception as e:
        logger.error(f"生成收线建议失败: {str(e)}")
        # Chain the cause so the original traceback survives the 500 response.
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"生成收线建议失败: {str(e)}"
        ) from e
|
|
|
|
|
|
# ============================================
|
|
# 角色状态
|
|
# ============================================
|
|
|
|
@router.get("/characters", response_model=Dict[str, List[CharacterStateChange]])
async def get_character_states(
    project_id: str,
    character: str = None
):
    """Return character state histories, keyed by character name.

    Args:
        character: Optional; when given, return only this character's entry.

    Raises:
        HTTPException: 404 if the project or the requested character is unknown.
    """
    project = await verify_project(project_id)
    memory_manager = get_memory_manager()
    enhanced_memory = memory_manager._convert_to_enhanced_memory(project.memory)

    states = enhanced_memory.characterStates
    if not character:
        return states

    if character not in states:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail=f"角色不存在: {character}"
        )
    return {character: states[character]}
|
|
|
|
|
|
@router.get("/characters/{character_name}", response_model=List[CharacterStateChange])
async def get_character_state_history(
    project_id: str,
    character_name: str
):
    """Return the recorded state changes of one character.

    Raises:
        HTTPException: 404 if the project or character is unknown.
    """
    project = await verify_project(project_id)
    memory_manager = get_memory_manager()
    enhanced_memory = memory_manager._convert_to_enhanced_memory(project.memory)

    states = enhanced_memory.characterStates
    if character_name in states:
        return states[character_name]

    raise HTTPException(
        status_code=status.HTTP_404_NOT_FOUND,
        detail=f"角色不存在: {character_name}"
    )
|
|
|
|
|
|
# ============================================
|
|
# 伏笔追踪
|
|
# ============================================
|
|
|
|
@router.get("/foreshadowing")
async def get_foreshadowing(
    project_id: str,
    payoff_status: bool = None
):
    """Return tracked foreshadowing entries.

    Args:
        payoff_status: Optional; when given, keep only entries whose
            ``is_payed_off`` equals this value.
    """
    project = await verify_project(project_id)
    memory_manager = get_memory_manager()
    enhanced_memory = memory_manager._convert_to_enhanced_memory(project.memory)

    entries = enhanced_memory.foreshadowing
    if payoff_status is None:
        return entries
    return [entry for entry in entries if entry.is_payed_off == payoff_status]
|
|
|
|
|
|
# ============================================
|
|
# 一致性检查
|
|
# ============================================
|
|
|
|
@router.post("/check-consistency")
async def run_consistency_check(
    project_id: str,
    episode_number: int
):
    """Check one episode against the existing memory and record any issues.

    ``episode_number`` is a plain scalar parameter, so FastAPI reads it from
    the query string. Issues returned by ``MemoryManager.check_consistency``
    are appended to ``consistencyIssues`` and the memory is saved.

    Returns:
        A dict with ``episode_number``, ``issues_found`` and ``issues``.

    Raises:
        HTTPException: 404 if the project or episode does not exist, 400 if
            the episode has no content, 500 if the check or saving fails.
    """
    project = await verify_project(project_id)

    episodes = await episode_repo.list_by_project(project_id)
    episode = next(
        (ep for ep in episodes if ep.number == episode_number),
        None,
    )

    if not episode:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail=f"剧集不存在: EP{episode_number}"
        )

    if not episode.content:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="剧集内容为空"
        )

    try:
        memory_manager = get_memory_manager()
        enhanced_memory = memory_manager._convert_to_enhanced_memory(project.memory)

        issues = await memory_manager.check_consistency(
            episode.content,
            episode_number,
            enhanced_memory
        )

        # NOTE(review): issues are appended without de-duplication, so
        # re-checking the same episode records them again. Confirm intended.
        enhanced_memory.consistencyIssues.extend(issues)
        project.memory = memory_manager._convert_to_memory(enhanced_memory)

        await project_repo.update(project_id, {
            "memory": project.memory.dict()
        })

        return {
            "episode_number": episode_number,
            "issues_found": len(issues),
            "issues": issues
        }

    except Exception as e:
        logger.error(f"一致性检查失败: {str(e)}")
        # Chain the cause so the original traceback survives the 500 response.
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"一致性检查失败: {str(e)}"
        ) from e
|
|
|
|
|
|
@router.get("/consistency-issues")
async def get_consistency_issues(
    project_id: str,
    severity: ImportanceLevel = None
):
    """Return recorded consistency issues.

    Args:
        severity: Optional; when given, keep only issues with this severity.
    """
    project = await verify_project(project_id)
    memory_manager = get_memory_manager()
    enhanced_memory = memory_manager._convert_to_enhanced_memory(project.memory)

    recorded = enhanced_memory.consistencyIssues
    if severity is None:
        return recorded
    return [issue for issue in recorded if issue.severity == severity]
|
|
|
|
|
|
# ============================================
|
|
# 角色关系
|
|
# ============================================
|
|
|
|
@router.get("/relationships")
async def get_relationships(project_id: str):
    """Return the stored character relationship mapping for the project.

    Reads ``project.memory.relationships`` directly, without converting the
    memory through the memory manager.
    """
    memory = (await verify_project(project_id)).memory
    return memory.relationships
|
|
|
|
|
|
@router.put("/relationships")
async def update_relationships(
    project_id: str,
    relationships: Dict[str, Dict[str, str]]
):
    """Replace the character relationship mapping and persist it.

    The request body fully replaces ``project.memory.relationships``; no
    merge with the existing mapping is performed.

    Returns:
        The relationship mapping now stored on the project.
    """
    project = await verify_project(project_id)
    memory = project.memory

    memory.relationships = relationships
    await project_repo.update(project_id, {"memory": memory.dict()})

    logger.info("更新角色关系网络")
    return memory.relationships
|
|
|
|
|
|
# ============================================
|
|
# 统计信息
|
|
# ============================================
|
|
|
|
@router.get("/stats")
async def get_memory_stats(project_id: str):
    """Return summary counts for the project's memory system.

    Includes event count, pending/resolved/total thread counts, tracked
    character count, foreshadowing totals, consistency issue count, and the
    memory's ``last_updated`` and ``last_episode_processed`` values.
    A thread counts as resolved when its ``resolved`` attribute is truthy.
    """
    project = await verify_project(project_id)
    memory_manager = get_memory_manager()
    enhanced_memory = memory_manager._convert_to_enhanced_memory(project.memory)

    threads = enhanced_memory.pendingThreads
    foreshadowing = enhanced_memory.foreshadowing

    resolved_count = sum(1 for t in threads if t.resolved)
    # Every thread is either resolved (truthy) or not, so pending is the rest.
    pending_count = len(threads) - resolved_count
    foreshadowing_pending = sum(1 for f in foreshadowing if not f.is_payed_off)

    return {
        "total_events": len(enhanced_memory.eventTimeline),
        "pending_threads": pending_count,
        "resolved_threads": resolved_count,
        "total_threads": len(threads),
        "characters_tracked": len(enhanced_memory.characterStates),
        "foreshadowing_total": len(foreshadowing),
        "foreshadowing_pending": foreshadowing_pending,
        "consistency_issues": len(enhanced_memory.consistencyIssues),
        "last_updated": enhanced_memory.last_updated,
        "last_episode_processed": enhanced_memory.last_episode_processed
    }
|