agent-writer/models/script_model.py
2025-09-11 18:34:03 +08:00

397 lines
12 KiB
Python

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
会话数据模型
定义会话的数据结构和验证逻辑
"""
from datetime import datetime, timedelta
from typing import Dict, List, Optional, Any
from dataclasses import dataclass, field
from enum import Enum
class SessionStatus(Enum):
"""会话状态枚举"""
QUEUED = "queued" # 排队中
PENDING = "pending" # 等待中
ACTIVE = "active" # 活跃状态
PAUSED = "paused" # 暂停状态
COMPLETED = "completed" # 已完成
CANCELLED = "cancelled" # 已取消
ERROR = "error" # 错误状态
class SessionType(Enum):
"""会话类型枚举"""
SCRIPTWRITING = "scriptwriting" # 编剧创作
CONSULTATION = "consultation" # 咨询对话
REVIEW = "review" # 剧本评审
COLLABORATION = "collaboration" # 协作创作
@dataclass
class SessionModel:
"""
会话数据模型
定义会话的完整数据结构
"""
# 基础信息
user_id: str
session_type: str = SessionType.SCRIPTWRITING.value
title: str = ""
description: str = ""
# 原始剧本内容
original_script: str = ""
# 状态信息
status: str = SessionStatus.ACTIVE.value
current_step: int = 1
total_steps: int = 6
# 配置信息
settings: Dict[str, Any] = field(default_factory=dict)
metadata: Dict[str, Any] = field(default_factory=dict)
# 时间信息
created_at: datetime = field(default_factory=datetime.now)
updated_at: datetime = field(default_factory=datetime.now)
expires_at: datetime = field(default_factory=lambda: datetime.now() + timedelta(days=7))
# 可选字段
session_id: Optional[str] = None
def __post_init__(self):
"""初始化后处理"""
# 验证数据
self.validate()
# 设置默认值
if not self.title:
self.title = f"编剧会话 - {self.created_at.strftime('%Y-%m-%d %H:%M')}"
# 确保设置字典存在必要的键
default_settings = {
'auto_save': True,
'step_timeout': 1800, # 30分钟
'max_retries': 3,
'language': 'zh-CN',
'genre': 'drama',
'style': 'modern'
}
for key, value in default_settings.items():
if key not in self.settings:
self.settings[key] = value
# 确保元数据字典存在
if 'tags' not in self.metadata:
self.metadata['tags'] = []
if 'source' not in self.metadata:
self.metadata['source'] = 'web_interface'
def validate(self) -> bool:
"""
验证数据有效性
Returns:
bool: 是否有效
Raises:
ValueError: 数据无效时抛出异常
"""
# 验证用户ID
if not self.user_id or not isinstance(self.user_id, str):
raise ValueError("用户ID不能为空且必须是字符串")
# 验证会话类型
valid_types = [t.value for t in SessionType]
if self.session_type not in valid_types:
raise ValueError(f"无效的会话类型: {self.session_type},有效值: {valid_types}")
# 验证状态
valid_statuses = [s.value for s in SessionStatus]
if self.status not in valid_statuses:
raise ValueError(f"无效的会话状态: {self.status},有效值: {valid_statuses}")
# 验证步骤
if not isinstance(self.current_step, int) or self.current_step < 1:
raise ValueError("当前步骤必须是大于0的整数")
if not isinstance(self.total_steps, int) or self.total_steps < 1:
raise ValueError("总步骤数必须是大于0的整数")
if self.current_step > self.total_steps:
raise ValueError("当前步骤不能超过总步骤数")
# 验证时间
if self.updated_at < self.created_at:
raise ValueError("更新时间不能早于创建时间")
if self.expires_at <= self.created_at:
raise ValueError("过期时间必须晚于创建时间")
return True
def to_dict(self) -> Dict[str, Any]:
"""
转换为字典格式
Returns:
Dict: 字典数据
"""
data = {
'user_id': self.user_id,
'session_type': self.session_type,
'title': self.title,
'description': self.description,
'original_script': self.original_script,
'status': self.status,
'current_step': self.current_step,
'total_steps': self.total_steps,
'settings': self.settings.copy(),
'metadata': self.metadata.copy(),
'created_at': self.created_at,
'updated_at': self.updated_at,
'expires_at': self.expires_at
}
if self.session_id:
data['session_id'] = self.session_id
return data
@classmethod
def from_dict(cls, data: Dict[str, Any]) -> 'SessionModel':
"""
从字典创建模型实例
Args:
data: 字典数据
Returns:
SessionModel: 模型实例
"""
# 提取必需字段
user_id = data.get('user_id')
if not user_id:
raise ValueError("缺少必需字段: user_id")
# 创建实例
instance = cls(
user_id=user_id,
session_type=data.get('session_type', SessionType.SCRIPTWRITING.value),
title=data.get('title', ''),
description=data.get('description', ''),
status=data.get('status', SessionStatus.ACTIVE.value),
current_step=data.get('current_step', 1),
total_steps=data.get('total_steps', 6),
settings=data.get('settings', {}),
metadata=data.get('metadata', {})
)
# 设置时间字段
if 'created_at' in data:
instance.created_at = data['created_at']
if 'updated_at' in data:
instance.updated_at = data['updated_at']
if 'expires_at' in data:
instance.expires_at = data['expires_at']
# 设置可选字段
if 'session_id' in data:
instance.session_id = data['session_id']
return instance
def update_status(self, new_status: str) -> None:
"""
更新会话状态
Args:
new_status: 新状态
"""
valid_statuses = [s.value for s in SessionStatus]
if new_status not in valid_statuses:
raise ValueError(f"无效的会话状态: {new_status}")
self.status = new_status
self.updated_at = datetime.now()
def advance_step(self) -> bool:
"""
推进到下一步
Returns:
bool: 是否成功推进
"""
if self.current_step < self.total_steps:
self.current_step += 1
self.updated_at = datetime.now()
# 如果到达最后一步,标记为完成
if self.current_step == self.total_steps:
self.status = SessionStatus.COMPLETED.value
return True
return False
def reset_step(self, step_number: int) -> bool:
"""
重置到指定步骤
Args:
step_number: 目标步骤号
Returns:
bool: 是否成功重置
"""
if 1 <= step_number <= self.total_steps:
self.current_step = step_number
self.updated_at = datetime.now()
# 如果从完成状态回退,更新状态
if self.status == SessionStatus.COMPLETED.value and step_number < self.total_steps:
self.status = SessionStatus.ACTIVE.value
return True
return False
def update_setting(self, key: str, value: Any) -> None:
"""
更新设置项
Args:
key: 设置键
value: 设置值
"""
self.settings[key] = value
self.updated_at = datetime.now()
def add_metadata(self, key: str, value: Any) -> None:
"""
添加元数据
Args:
key: 元数据键
value: 元数据值
"""
self.metadata[key] = value
self.updated_at = datetime.now()
def add_tag(self, tag: str) -> None:
"""
添加标签
Args:
tag: 标签名称
"""
if 'tags' not in self.metadata:
self.metadata['tags'] = []
if tag not in self.metadata['tags']:
self.metadata['tags'].append(tag)
self.updated_at = datetime.now()
def remove_tag(self, tag: str) -> bool:
"""
移除标签
Args:
tag: 标签名称
Returns:
bool: 是否成功移除
"""
if 'tags' in self.metadata and tag in self.metadata['tags']:
self.metadata['tags'].remove(tag)
self.updated_at = datetime.now()
return True
return False
def extend_expiry(self, days: int = 7) -> None:
"""
延长过期时间
Args:
days: 延长天数
"""
self.expires_at = datetime.now() + timedelta(days=days)
self.updated_at = datetime.now()
def is_expired(self) -> bool:
"""
检查是否已过期
Returns:
bool: 是否已过期
"""
return datetime.now() > self.expires_at
def is_active(self) -> bool:
"""
检查是否处于活跃状态
Returns:
bool: 是否活跃
"""
return self.status == SessionStatus.ACTIVE.value and not self.is_expired()
def is_completed(self) -> bool:
"""
检查是否已完成
Returns:
bool: 是否已完成
"""
return self.status == SessionStatus.COMPLETED.value or self.current_step >= self.total_steps
def get_progress_percentage(self) -> float:
"""
获取进度百分比
Returns:
float: 进度百分比 (0-100)
"""
if self.total_steps <= 0:
return 0.0
return min(100.0, (self.current_step / self.total_steps) * 100.0)
def get_remaining_steps(self) -> int:
"""
获取剩余步骤数
Returns:
int: 剩余步骤数
"""
return max(0, self.total_steps - self.current_step)
def get_session_duration(self) -> timedelta:
"""
获取会话持续时间
Returns:
timedelta: 持续时间
"""
return self.updated_at - self.created_at
def __str__(self) -> str:
"""
字符串表示
Returns:
str: 字符串描述
"""
return f"SessionModel(id={self.session_id}, user={self.user_id}, step={self.current_step}/{self.total_steps}, status={self.status})"
def __repr__(self) -> str:
"""
详细字符串表示
Returns:
str: 详细描述
"""
return (f"SessionModel(session_id={self.session_id}, user_id={self.user_id}, "
f"type={self.session_type}, title='{self.title}', status={self.status}, "
f"step={self.current_step}/{self.total_steps}, created={self.created_at})")