394 lines
12 KiB
Python
394 lines
12 KiB
Python
#!/usr/bin/env python3
|
|
# -*- coding: utf-8 -*-
|
|
"""
|
|
会话数据模型
|
|
|
|
定义会话的数据结构和验证逻辑
|
|
"""
|
|
|
|
from datetime import datetime, timedelta
|
|
from typing import Dict, List, Optional, Any
|
|
from dataclasses import dataclass, field
|
|
from enum import Enum
|
|
|
|
class SessionStatus(Enum):
|
|
"""会话状态枚举"""
|
|
QUEUED = "queued" # 排队中
|
|
PENDING = "pending" # 等待中
|
|
ACTIVE = "active" # 活跃状态
|
|
PAUSED = "paused" # 暂停状态
|
|
COMPLETED = "completed" # 已完成
|
|
CANCELLED = "cancelled" # 已取消
|
|
ERROR = "error" # 错误状态
|
|
|
|
class SessionType(Enum):
|
|
"""会话类型枚举"""
|
|
SCRIPTWRITING = "scriptwriting" # 编剧创作
|
|
CONSULTATION = "consultation" # 咨询对话
|
|
REVIEW = "review" # 剧本评审
|
|
COLLABORATION = "collaboration" # 协作创作
|
|
|
|
@dataclass
|
|
class SessionModel:
|
|
"""
|
|
会话数据模型
|
|
|
|
定义会话的完整数据结构
|
|
"""
|
|
|
|
# 基础信息
|
|
user_id: str
|
|
session_type: str = SessionType.SCRIPTWRITING.value
|
|
title: str = ""
|
|
# 原始剧本内容
|
|
original_script: str = ""
|
|
|
|
# 状态信息
|
|
status: str = SessionStatus.ACTIVE.value
|
|
current_step: int = 1
|
|
total_steps: int = 6
|
|
|
|
# 配置信息
|
|
settings: Dict[str, Any] = field(default_factory=dict)
|
|
metadata: Dict[str, Any] = field(default_factory=dict)
|
|
|
|
# 时间信息
|
|
created_at: datetime = field(default_factory=datetime.now)
|
|
updated_at: datetime = field(default_factory=datetime.now)
|
|
expires_at: datetime = field(default_factory=lambda: datetime.now() + timedelta(days=7))
|
|
|
|
# 可选字段
|
|
session_id: Optional[str] = None
|
|
|
|
def __post_init__(self):
|
|
"""初始化后处理"""
|
|
# 验证数据
|
|
self.validate()
|
|
|
|
# 设置默认值
|
|
if not self.title:
|
|
self.title = f"编剧会话 - {self.created_at.strftime('%Y-%m-%d %H:%M')}"
|
|
|
|
# 确保设置字典存在必要的键
|
|
default_settings = {
|
|
'auto_save': True,
|
|
'step_timeout': 1800, # 30分钟
|
|
'max_retries': 3,
|
|
'language': 'zh-CN',
|
|
'genre': 'drama',
|
|
'style': 'modern'
|
|
}
|
|
|
|
for key, value in default_settings.items():
|
|
if key not in self.settings:
|
|
self.settings[key] = value
|
|
|
|
# 确保元数据字典存在
|
|
if 'tags' not in self.metadata:
|
|
self.metadata['tags'] = []
|
|
if 'source' not in self.metadata:
|
|
self.metadata['source'] = 'web_interface'
|
|
|
|
def validate(self) -> bool:
|
|
"""
|
|
验证数据有效性
|
|
|
|
Returns:
|
|
bool: 是否有效
|
|
|
|
Raises:
|
|
ValueError: 数据无效时抛出异常
|
|
"""
|
|
# 验证用户ID
|
|
if not self.user_id or not isinstance(self.user_id, str):
|
|
raise ValueError("用户ID不能为空且必须是字符串")
|
|
|
|
# 验证会话类型
|
|
valid_types = [t.value for t in SessionType]
|
|
if self.session_type not in valid_types:
|
|
raise ValueError(f"无效的会话类型: {self.session_type},有效值: {valid_types}")
|
|
|
|
# 验证状态
|
|
valid_statuses = [s.value for s in SessionStatus]
|
|
if self.status not in valid_statuses:
|
|
raise ValueError(f"无效的会话状态: {self.status},有效值: {valid_statuses}")
|
|
|
|
# 验证步骤
|
|
if not isinstance(self.current_step, int) or self.current_step < 1:
|
|
raise ValueError("当前步骤必须是大于0的整数")
|
|
|
|
if not isinstance(self.total_steps, int) or self.total_steps < 1:
|
|
raise ValueError("总步骤数必须是大于0的整数")
|
|
|
|
if self.current_step > self.total_steps:
|
|
raise ValueError("当前步骤不能超过总步骤数")
|
|
|
|
# 验证时间
|
|
if self.updated_at < self.created_at:
|
|
raise ValueError("更新时间不能早于创建时间")
|
|
|
|
if self.expires_at <= self.created_at:
|
|
raise ValueError("过期时间必须晚于创建时间")
|
|
|
|
return True
|
|
|
|
def to_dict(self) -> Dict[str, Any]:
|
|
"""
|
|
转换为字典格式
|
|
|
|
Returns:
|
|
Dict: 字典数据
|
|
"""
|
|
data = {
|
|
'user_id': self.user_id,
|
|
'session_type': self.session_type,
|
|
'title': self.title,
|
|
'original_script': self.original_script,
|
|
'status': self.status,
|
|
'current_step': self.current_step,
|
|
'total_steps': self.total_steps,
|
|
'settings': self.settings.copy(),
|
|
'metadata': self.metadata.copy(),
|
|
'created_at': self.created_at,
|
|
'updated_at': self.updated_at,
|
|
'expires_at': self.expires_at
|
|
}
|
|
|
|
if self.session_id:
|
|
data['session_id'] = self.session_id
|
|
|
|
return data
|
|
|
|
@classmethod
|
|
def from_dict(cls, data: Dict[str, Any]) -> 'SessionModel':
|
|
"""
|
|
从字典创建模型实例
|
|
|
|
Args:
|
|
data: 字典数据
|
|
|
|
Returns:
|
|
SessionModel: 模型实例
|
|
"""
|
|
# 提取必需字段
|
|
user_id = data.get('user_id')
|
|
if not user_id:
|
|
raise ValueError("缺少必需字段: user_id")
|
|
|
|
# 创建实例
|
|
instance = cls(
|
|
user_id=user_id,
|
|
session_type=data.get('session_type', SessionType.SCRIPTWRITING.value),
|
|
title=data.get('title', ''),
|
|
status=data.get('status', SessionStatus.ACTIVE.value),
|
|
current_step=data.get('current_step', 1),
|
|
total_steps=data.get('total_steps', 6),
|
|
settings=data.get('settings', {}),
|
|
metadata=data.get('metadata', {})
|
|
)
|
|
|
|
# 设置时间字段
|
|
if 'created_at' in data:
|
|
instance.created_at = data['created_at']
|
|
if 'updated_at' in data:
|
|
instance.updated_at = data['updated_at']
|
|
if 'expires_at' in data:
|
|
instance.expires_at = data['expires_at']
|
|
|
|
# 设置可选字段
|
|
if 'session_id' in data:
|
|
instance.session_id = data['session_id']
|
|
|
|
return instance
|
|
|
|
def update_status(self, new_status: str) -> None:
|
|
"""
|
|
更新会话状态
|
|
|
|
Args:
|
|
new_status: 新状态
|
|
"""
|
|
valid_statuses = [s.value for s in SessionStatus]
|
|
if new_status not in valid_statuses:
|
|
raise ValueError(f"无效的会话状态: {new_status}")
|
|
|
|
self.status = new_status
|
|
self.updated_at = datetime.now()
|
|
|
|
def advance_step(self) -> bool:
|
|
"""
|
|
推进到下一步
|
|
|
|
Returns:
|
|
bool: 是否成功推进
|
|
"""
|
|
if self.current_step < self.total_steps:
|
|
self.current_step += 1
|
|
self.updated_at = datetime.now()
|
|
|
|
# 如果到达最后一步,标记为完成
|
|
if self.current_step == self.total_steps:
|
|
self.status = SessionStatus.COMPLETED.value
|
|
|
|
return True
|
|
return False
|
|
|
|
def reset_step(self, step_number: int) -> bool:
|
|
"""
|
|
重置到指定步骤
|
|
|
|
Args:
|
|
step_number: 目标步骤号
|
|
|
|
Returns:
|
|
bool: 是否成功重置
|
|
"""
|
|
if 1 <= step_number <= self.total_steps:
|
|
self.current_step = step_number
|
|
self.updated_at = datetime.now()
|
|
|
|
# 如果从完成状态回退,更新状态
|
|
if self.status == SessionStatus.COMPLETED.value and step_number < self.total_steps:
|
|
self.status = SessionStatus.ACTIVE.value
|
|
|
|
return True
|
|
return False
|
|
|
|
def update_setting(self, key: str, value: Any) -> None:
|
|
"""
|
|
更新设置项
|
|
|
|
Args:
|
|
key: 设置键
|
|
value: 设置值
|
|
"""
|
|
self.settings[key] = value
|
|
self.updated_at = datetime.now()
|
|
|
|
def add_metadata(self, key: str, value: Any) -> None:
|
|
"""
|
|
添加元数据
|
|
|
|
Args:
|
|
key: 元数据键
|
|
value: 元数据值
|
|
"""
|
|
self.metadata[key] = value
|
|
self.updated_at = datetime.now()
|
|
|
|
def add_tag(self, tag: str) -> None:
|
|
"""
|
|
添加标签
|
|
|
|
Args:
|
|
tag: 标签名称
|
|
"""
|
|
if 'tags' not in self.metadata:
|
|
self.metadata['tags'] = []
|
|
|
|
if tag not in self.metadata['tags']:
|
|
self.metadata['tags'].append(tag)
|
|
self.updated_at = datetime.now()
|
|
|
|
def remove_tag(self, tag: str) -> bool:
|
|
"""
|
|
移除标签
|
|
|
|
Args:
|
|
tag: 标签名称
|
|
|
|
Returns:
|
|
bool: 是否成功移除
|
|
"""
|
|
if 'tags' in self.metadata and tag in self.metadata['tags']:
|
|
self.metadata['tags'].remove(tag)
|
|
self.updated_at = datetime.now()
|
|
return True
|
|
return False
|
|
|
|
def extend_expiry(self, days: int = 7) -> None:
|
|
"""
|
|
延长过期时间
|
|
|
|
Args:
|
|
days: 延长天数
|
|
"""
|
|
self.expires_at = datetime.now() + timedelta(days=days)
|
|
self.updated_at = datetime.now()
|
|
|
|
def is_expired(self) -> bool:
|
|
"""
|
|
检查是否已过期
|
|
|
|
Returns:
|
|
bool: 是否已过期
|
|
"""
|
|
return datetime.now() > self.expires_at
|
|
|
|
def is_active(self) -> bool:
|
|
"""
|
|
检查是否处于活跃状态
|
|
|
|
Returns:
|
|
bool: 是否活跃
|
|
"""
|
|
return self.status == SessionStatus.ACTIVE.value and not self.is_expired()
|
|
|
|
def is_completed(self) -> bool:
|
|
"""
|
|
检查是否已完成
|
|
|
|
Returns:
|
|
bool: 是否已完成
|
|
"""
|
|
return self.status == SessionStatus.COMPLETED.value or self.current_step >= self.total_steps
|
|
|
|
def get_progress_percentage(self) -> float:
|
|
"""
|
|
获取进度百分比
|
|
|
|
Returns:
|
|
float: 进度百分比 (0-100)
|
|
"""
|
|
if self.total_steps <= 0:
|
|
return 0.0
|
|
|
|
return min(100.0, (self.current_step / self.total_steps) * 100.0)
|
|
|
|
def get_remaining_steps(self) -> int:
|
|
"""
|
|
获取剩余步骤数
|
|
|
|
Returns:
|
|
int: 剩余步骤数
|
|
"""
|
|
return max(0, self.total_steps - self.current_step)
|
|
|
|
def get_session_duration(self) -> timedelta:
|
|
"""
|
|
获取会话持续时间
|
|
|
|
Returns:
|
|
timedelta: 持续时间
|
|
"""
|
|
return self.updated_at - self.created_at
|
|
|
|
def __str__(self) -> str:
|
|
"""
|
|
字符串表示
|
|
|
|
Returns:
|
|
str: 字符串描述
|
|
"""
|
|
return f"SessionModel(id={self.session_id}, user={self.user_id}, step={self.current_step}/{self.total_steps}, status={self.status})"
|
|
|
|
def __repr__(self) -> str:
|
|
"""
|
|
详细字符串表示
|
|
|
|
Returns:
|
|
str: 详细描述
|
|
"""
|
|
return (f"SessionModel(session_id={self.session_id}, user_id={self.user_id}, "
|
|
f"type={self.session_type}, title='{self.title}', status={self.status}, "
|
|
f"step={self.current_step}/{self.total_steps}, created={self.created_at})") |