2026-01-25 19:27:44 +08:00

730 lines
20 KiB
Python

"""
记忆系统 API 路由
Memory System API - 提供记忆系统的管理接口
"""
from fastapi import APIRouter, Depends, HTTPException, status
from typing import List, Dict
from datetime import datetime
import uuid
from app.models.memory import (
EnhancedMemory,
TimelineEvent,
PendingThread,
ThreadCreateRequest,
ThreadUpdateRequest,
ThreadResolveRequest,
MemoryUpdateRequest,
MemoryExtractionResult,
CharacterStateChange,
ResolutionSuggestion,
ThreadStatus,
ImportanceLevel
)
from app.models.project import SeriesProject, Episode
from app.core.memory.memory_manager import get_memory_manager, MemoryManager
from app.db.repositories import project_repo, episode_repo
from app.utils.logger import get_logger
# Module-level logger for this router.
logger = get_logger(__name__)
# Every endpoint below is mounted under /projects/{project_id}/memory and
# carries the "记忆系统" (memory system) OpenAPI tag.
router = APIRouter(prefix="/projects/{project_id}/memory", tags=["记忆系统"])
# ============================================
# 辅助函数
# ============================================
async def verify_project(project_id: str) -> SeriesProject:
    """Fetch a project by id, or fail with HTTP 404.

    Args:
        project_id: Identifier of the project to look up.

    Returns:
        The project object returned by ``project_repo.get``.

    Raises:
        HTTPException: 404 when the repository returns a falsy value.
    """
    found = await project_repo.get(project_id)
    if found:
        return found
    raise HTTPException(
        status_code=status.HTTP_404_NOT_FOUND,
        detail=f"项目不存在: {project_id}",
    )
# ============================================
# 记忆系统主端点
# ============================================
@router.get("/", response_model=EnhancedMemory)
async def get_project_memory(project_id: str):
    """
    Return the project's complete memory in enhanced form.

    The stored ``project.memory`` is converted through the memory manager into
    an ``EnhancedMemory``, which includes the event timeline, pending threads,
    character state history, foreshadowing tracking and character
    relationships.

    Raises:
        HTTPException: 404 if the project does not exist.
    """
    project = await verify_project(project_id)
    return get_memory_manager()._convert_to_enhanced_memory(project.memory)
@router.post("/update", response_model=MemoryExtractionResult)
async def update_memory_from_episode(
    project_id: str,
    request: MemoryUpdateRequest
):
    """
    Update the project memory from the content of one episode.

    Delegates to ``MemoryManager.update_memory_from_episode`` (described as
    extracting key events, detecting foreshadowing and pending threads,
    updating character states and checking consistency), then persists
    ``project.memory`` through ``project_repo.update``.

    Args:
        project_id: Project identifier from the URL path.
        request: Memory update request; ``request.episode_number`` selects the
            episode to process.

    Returns:
        The result object returned by the memory manager.

    Raises:
        HTTPException: 404 if the project or the episode does not exist,
            400 if the episode has no content, 500 if the update or the
            save fails (the original error is chained as ``__cause__``).
    """
    project = await verify_project(project_id)

    # Locate the requested episode among the project's episodes.
    episodes = await episode_repo.list_by_project(project_id)
    episode = next(
        (ep for ep in episodes if ep.number == request.episode_number),
        None,
    )
    if not episode:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail=f"剧集不存在: EP{request.episode_number}"
        )
    if not episode.content:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="剧集内容为空,无法更新记忆"
        )

    try:
        memory_manager = get_memory_manager()
        result = await memory_manager.update_memory_from_episode(project, episode)
        # Persist project.memory after the manager call.
        # NOTE(review): assumes the manager updates project.memory in place — confirm.
        await project_repo.update(project_id, {
            "memory": project.memory.dict()
        })
        return result
    except Exception as e:
        logger.error(f"更新记忆失败: {str(e)}")
        # Chain explicitly so the root cause stays attached to the HTTP error.
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"更新记忆失败: {str(e)}"
        ) from e
# ============================================
# 事件时间线
# ============================================
@router.get("/timeline", response_model=List[TimelineEvent])
async def get_event_timeline(
    project_id: str,
    skip: int = 0,
    limit: int = 100,
    episode: int = None,
    importance: ImportanceLevel = None,
    character: str = None
):
    """
    Return the event timeline, optionally filtered, sorted and paginated.

    Filters (all optional, combined with AND):
    - episode: keep events whose ``episode`` equals this number
    - importance: keep events with this importance level
    - character: keep events whose ``characters_involved`` contains this name

    Events are sorted by episode (stable, so timeline order is kept within an
    episode), then the window ``[skip, skip + limit)`` is returned. Negative
    ``skip``/``limit`` values are treated as 0; previously they turned into
    negative slice indices and returned items from the end of the list.

    Raises:
        HTTPException: 404 if the project does not exist.
    """
    project = await verify_project(project_id)
    memory_manager = get_memory_manager()
    enhanced_memory = memory_manager._convert_to_enhanced_memory(project.memory)

    # Single pass applying every active filter.
    events = [
        e for e in enhanced_memory.eventTimeline
        if (episode is None or e.episode == episode)
        and (importance is None or e.importance == importance)
        and (not character or character in e.characters_involved)
    ]
    events.sort(key=lambda x: x.episode)

    # Clamp pagination inputs so they never act as negative slice indices.
    start = max(skip, 0)
    count = max(limit, 0)
    return events[start:start + count]
@router.get("/timeline/episodes", response_model=List[dict])
async def get_timeline_by_episodes(project_id: str):
    """
    Return the event timeline grouped by episode number.

    Response shape: ``[{"episode": 1, "events": [...]}, ...]``, ordered by
    ascending episode; events inside a group keep their timeline order.

    Raises:
        HTTPException: 404 if the project does not exist.
    """
    project = await verify_project(project_id)
    memory = get_memory_manager()._convert_to_enhanced_memory(project.memory)

    grouped = {}
    for evt in memory.eventTimeline:
        grouped.setdefault(evt.episode, []).append(evt)

    return [
        {"episode": number, "events": bucket}
        for number, bucket in sorted(grouped.items())
    ]
# ============================================
# 待收线管理
# ============================================
@router.get("/threads", response_model=List[PendingThread])
async def get_pending_threads(
    project_id: str,
    status: ThreadStatus = None,
    importance: ImportanceLevel = None
):
    """
    List pending threads, optionally filtered, sorted by introduction episode.

    Filters (optional, combined with AND):
    - status: keep threads in this status
    - importance: keep threads with this importance level

    Note: the ``status`` query parameter shadows the ``fastapi.status`` module
    inside this function; the name is kept because it is the public query key.

    Raises:
        HTTPException: 404 if the project does not exist.
    """
    project = await verify_project(project_id)
    manager = get_memory_manager()
    memory = manager._convert_to_enhanced_memory(project.memory)

    selected = [
        t for t in memory.pendingThreads
        if (status is None or t.status == status)
        and (importance is None or t.importance == importance)
    ]
    selected.sort(key=lambda t: t.introduced_at)
    return selected
@router.post("/threads", response_model=PendingThread, status_code=status.HTTP_201_CREATED)
async def create_thread(
    project_id: str,
    request: ThreadCreateRequest
):
    """
    Manually add a pending thread.

    Lets users record threads (e.g. foreshadowing) that automatic extraction
    missed. The new thread gets a random UUID and is attributed to
    ``last_episode_processed`` of the current memory. The updated memory is
    then persisted.

    Args:
        project_id: Project identifier from the URL path.
        request: Description, importance, reminder episode, involved
            characters and notes for the new thread.

    Returns:
        The created ``PendingThread``.

    Raises:
        HTTPException: 404 if the project does not exist.
    """
    project = await verify_project(project_id)

    # The new thread is attributed to the most recently processed episode.
    memory_manager = get_memory_manager()
    enhanced_memory = memory_manager._convert_to_enhanced_memory(project.memory)
    current_episode = enhanced_memory.last_episode_processed

    # One timestamp so created_at and updated_at are identical at creation.
    now = datetime.now()
    thread = PendingThread(
        id=str(uuid.uuid4()),
        description=request.description,
        introduced_at=current_episode,
        importance=request.importance,
        reminder_episode=request.reminder_episode,
        characters_involved=request.characters_involved,
        notes=request.notes,
        created_at=now,
        updated_at=now
    )

    # Append to memory, convert back to the stored format and persist.
    enhanced_memory.pendingThreads.append(thread)
    project.memory = memory_manager._convert_to_memory(enhanced_memory)
    await project_repo.update(project_id, {
        "memory": project.memory.dict()
    })
    logger.info(f"创建待收线: {thread.id} - {thread.description}")
    return thread
@router.put("/threads/{thread_id}", response_model=PendingThread)
async def update_thread(
    project_id: str,
    thread_id: str,
    request: ThreadUpdateRequest
):
    """
    Partially update a pending thread.

    Any request field that is not None replaces the thread's value:
    description, importance, status, reminder_episode, characters_involved
    and notes. ``updated_at`` is refreshed and the memory is persisted.

    When ``status`` is supplied, the boolean ``resolved`` flag is kept in sync
    with it. ``resolve_thread`` sets both fields and ``get_memory_stats``
    counts threads by ``resolved``, so updating only ``status`` would leave
    the stats inconsistent.

    Raises:
        HTTPException: 404 if the project or thread does not exist.
    """
    project = await verify_project(project_id)
    memory_manager = get_memory_manager()
    enhanced_memory = memory_manager._convert_to_enhanced_memory(project.memory)

    thread = next(
        (t for t in enhanced_memory.pendingThreads if t.id == thread_id),
        None,
    )
    if not thread:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail=f"待收线不存在: {thread_id}"
        )

    # Copy every field the client supplied; None means "leave unchanged".
    for field_name in (
        "description",
        "importance",
        "status",
        "reminder_episode",
        "characters_involved",
        "notes",
    ):
        value = getattr(request, field_name)
        if value is not None:
            setattr(thread, field_name, value)
    if request.status is not None:
        # Keep the resolved flag consistent with the new status.
        # NOTE(review): resolved_at is left untouched when un-resolving — confirm intent.
        thread.resolved = request.status == ThreadStatus.RESOLVED
    thread.updated_at = datetime.now()

    project.memory = memory_manager._convert_to_memory(enhanced_memory)
    await project_repo.update(project_id, {
        "memory": project.memory.dict()
    })
    logger.info(f"更新待收线: {thread_id}")
    return thread
@router.post("/resolve-thread", response_model=dict)
async def resolve_thread(
    project_id: str,
    request: ThreadResolveRequest
):
    """
    Mark a pending thread as resolved.

    Sets ``resolved``, ``resolved_at`` and ``status = ThreadStatus.RESOLVED``
    and refreshes ``updated_at``. If a resolution summary is given, it is
    appended to the thread's notes as ``[收线摘要 EP<n>]: <summary>``. Notes
    that are None or empty are handled without a TypeError and without a
    leading blank line. The memory is then persisted.

    Returns:
        A dict with ``success``, ``message``, ``thread_id`` and ``resolved_at``.

    Raises:
        HTTPException: 404 if the project or thread does not exist.
    """
    project = await verify_project(project_id)
    memory_manager = get_memory_manager()
    enhanced_memory = memory_manager._convert_to_enhanced_memory(project.memory)

    thread = next(
        (t for t in enhanced_memory.pendingThreads if t.id == request.thread_id),
        None,
    )
    if not thread:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail=f"待收线不存在: {request.thread_id}"
        )

    thread.resolved = True
    thread.resolved_at = request.resolved_at
    thread.status = ThreadStatus.RESOLVED
    thread.updated_at = datetime.now()

    if request.resolution_summary:
        summary_line = f"[收线摘要 EP{request.resolved_at}]: {request.resolution_summary}"
        existing = thread.notes or ""  # notes may be None on some threads
        thread.notes = f"{existing}\n{summary_line}" if existing else summary_line

    project.memory = memory_manager._convert_to_memory(enhanced_memory)
    await project_repo.update(project_id, {
        "memory": project.memory.dict()
    })
    logger.info(f"收线完成: {request.thread_id} @ EP{request.resolved_at}")
    return {
        "success": True,
        "message": "待收线已收线",
        "thread_id": request.thread_id,
        "resolved_at": request.resolved_at
    }
@router.delete("/threads/{thread_id}", status_code=status.HTTP_204_NO_CONTENT)
async def delete_thread(
    project_id: str,
    thread_id: str
):
    """
    Delete a pending thread by id and persist the updated memory.

    Raises:
        HTTPException: 404 if the project does not exist or no thread has
            ``thread_id``.
    """
    project = await verify_project(project_id)
    manager = get_memory_manager()
    memory = manager._convert_to_enhanced_memory(project.memory)

    kept = [t for t in memory.pendingThreads if t.id != thread_id]
    if len(kept) == len(memory.pendingThreads):
        # Nothing was filtered out, so the id is unknown.
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail=f"待收线不存在: {thread_id}"
        )

    memory.pendingThreads = kept
    project.memory = manager._convert_to_memory(memory)
    await project_repo.update(project_id, {"memory": project.memory.dict()})
    logger.info(f"删除待收线: {thread_id}")
    return None
@router.get("/threads/{thread_id}/suggestion", response_model=ResolutionSuggestion)
async def get_thread_resolution_suggestion(
    project_id: str,
    thread_id: str
):
    """
    Ask the memory manager for a suggestion on how to resolve a thread.

    Calls ``MemoryManager.suggest_thread_resolution`` with the thread,
    ``last_episode_processed`` as the current episode, and the full enhanced
    memory.

    Raises:
        HTTPException: 404 if the project or thread does not exist; 500 if
            generating the suggestion fails (original error chained as
            ``__cause__``).
    """
    project = await verify_project(project_id)
    memory_manager = get_memory_manager()
    enhanced_memory = memory_manager._convert_to_enhanced_memory(project.memory)

    thread = next(
        (t for t in enhanced_memory.pendingThreads if t.id == thread_id),
        None,
    )
    if not thread:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail=f"待收线不存在: {thread_id}"
        )

    try:
        suggestion = await memory_manager.suggest_thread_resolution(
            thread=thread,
            current_episode=enhanced_memory.last_episode_processed,
            memory=enhanced_memory
        )
        return suggestion
    except Exception as e:
        logger.error(f"生成收线建议失败: {str(e)}")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"生成收线建议失败: {str(e)}"
        ) from e
# ============================================
# 角色状态
# ============================================
@router.get("/characters", response_model=Dict[str, List[CharacterStateChange]])
async def get_character_states(
    project_id: str,
    character: str = None
):
    """
    Return character state histories, keyed by character name.

    Args:
        character: Optional; when given (non-empty), only that character's
            history is returned, as a one-entry dict.

    Raises:
        HTTPException: 404 if the project does not exist, or if ``character``
            is given but has no recorded states.
    """
    project = await verify_project(project_id)
    manager = get_memory_manager()
    states = manager._convert_to_enhanced_memory(project.memory).characterStates

    if not character:
        return states
    if character in states:
        return {character: states[character]}
    raise HTTPException(
        status_code=status.HTTP_404_NOT_FOUND,
        detail=f"角色不存在: {character}"
    )
@router.get("/characters/{character_name}", response_model=List[CharacterStateChange])
async def get_character_state_history(
    project_id: str,
    character_name: str
):
    """
    Return the recorded state changes for a single character.

    Raises:
        HTTPException: 404 if the project does not exist or the character has
            no entry in ``characterStates``.
    """
    project = await verify_project(project_id)
    manager = get_memory_manager()
    states = manager._convert_to_enhanced_memory(project.memory).characterStates

    if character_name in states:
        return states[character_name]
    raise HTTPException(
        status_code=status.HTTP_404_NOT_FOUND,
        detail=f"角色不存在: {character_name}"
    )
# ============================================
# 伏笔追踪
# ============================================
@router.get("/foreshadowing")
async def get_foreshadowing(
    project_id: str,
    payoff_status: bool = None
):
    """
    Return the tracked foreshadowing entries.

    Args:
        payoff_status: Optional; when given, keep only entries whose
            ``is_payed_off`` equals this value.

    Raises:
        HTTPException: 404 if the project does not exist.
    """
    project = await verify_project(project_id)
    manager = get_memory_manager()
    items = manager._convert_to_enhanced_memory(project.memory).foreshadowing

    if payoff_status is None:
        return items
    return [item for item in items if item.is_payed_off == payoff_status]
# ============================================
# 一致性检查
# ============================================
@router.post("/check-consistency")
async def run_consistency_check(
    project_id: str,
    episode_number: int
):
    """
    Check one episode's content for consistency against the existing memory.

    Issues returned by ``MemoryManager.check_consistency`` are appended to the
    memory's ``consistencyIssues`` and persisted.

    Args:
        project_id: Project identifier from the URL path.
        episode_number: Number of the episode to check.

    Returns:
        A dict with ``episode_number``, ``issues_found`` and ``issues``.

    Raises:
        HTTPException: 404 if the project or episode does not exist, 400 if
            the episode has no content, 500 if the check or the save fails
            (original error chained as ``__cause__``).
    """
    project = await verify_project(project_id)

    episodes = await episode_repo.list_by_project(project_id)
    episode = next(
        (ep for ep in episodes if ep.number == episode_number),
        None,
    )
    if not episode:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail=f"剧集不存在: EP{episode_number}"
        )
    if not episode.content:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="剧集内容为空"
        )

    try:
        memory_manager = get_memory_manager()
        enhanced_memory = memory_manager._convert_to_enhanced_memory(project.memory)
        issues = await memory_manager.check_consistency(
            episode.content,
            episode_number,
            enhanced_memory
        )
        # TODO(review): re-running the check on the same episode appends the
        # same issues again; consider de-duplicating.
        enhanced_memory.consistencyIssues.extend(issues)
        project.memory = memory_manager._convert_to_memory(enhanced_memory)
        await project_repo.update(project_id, {
            "memory": project.memory.dict()
        })
        return {
            "episode_number": episode_number,
            "issues_found": len(issues),
            "issues": issues
        }
    except Exception as e:
        logger.error(f"一致性检查失败: {str(e)}")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"一致性检查失败: {str(e)}"
        ) from e
@router.get("/consistency-issues")
async def get_consistency_issues(
    project_id: str,
    severity: ImportanceLevel = None
):
    """
    Return recorded consistency issues.

    Args:
        severity: Optional; when given, keep only issues with this severity.

    Raises:
        HTTPException: 404 if the project does not exist.
    """
    project = await verify_project(project_id)
    manager = get_memory_manager()
    recorded = manager._convert_to_enhanced_memory(project.memory).consistencyIssues

    if severity is None:
        return recorded
    return [issue for issue in recorded if issue.severity == severity]
# ============================================
# 角色关系
# ============================================
@router.get("/relationships")
async def get_relationships(project_id: str):
    """
    Return the character relationship mapping.

    Reads ``relationships`` directly from the stored ``project.memory``; no
    enhanced-memory conversion is involved.

    Raises:
        HTTPException: 404 if the project does not exist.
    """
    found = await verify_project(project_id)
    return found.memory.relationships
@router.put("/relationships")
async def update_relationships(
    project_id: str,
    relationships: Dict[str, Dict[str, str]]
):
    """
    Replace the character relationship mapping and persist it.

    Args:
        relationships: Nested mapping of string keys to string values that
            replaces ``project.memory.relationships`` wholesale.

    Returns:
        The relationships stored on the project memory after the update.

    Raises:
        HTTPException: 404 if the project does not exist.
    """
    project = await verify_project(project_id)
    memory = project.memory
    memory.relationships = relationships
    await project_repo.update(project_id, {"memory": memory.dict()})
    logger.info("更新角色关系网络")
    return memory.relationships
# ============================================
# 统计信息
# ============================================
@router.get("/stats")
async def get_memory_stats(project_id: str):
    """
    Return summary counts for the project's memory.

    Keys: total_events, pending_threads, resolved_threads, total_threads,
    characters_tracked, foreshadowing_total, foreshadowing_pending,
    consistency_issues, last_updated, last_episode_processed.

    A thread counts as pending when ``resolved`` is falsy, otherwise as
    resolved, so pending + resolved always equals total_threads.

    Raises:
        HTTPException: 404 if the project does not exist.
    """
    project = await verify_project(project_id)
    memory = get_memory_manager()._convert_to_enhanced_memory(project.memory)

    threads = memory.pendingThreads
    open_threads = sum(1 for t in threads if not t.resolved)
    unpaid = sum(1 for f in memory.foreshadowing if not f.is_payed_off)

    return {
        "total_events": len(memory.eventTimeline),
        "pending_threads": open_threads,
        "resolved_threads": len(threads) - open_threads,
        "total_threads": len(threads),
        "characters_tracked": len(memory.characterStates),
        "foreshadowing_total": len(memory.foreshadowing),
        "foreshadowing_pending": unpaid,
        "consistency_issues": len(memory.consistencyIssues),
        "last_updated": memory.last_updated,
        "last_episode_processed": memory.last_episode_processed
    }