#!/usr/bin/env python3 # -*- coding: utf-8 -*- """ 会话数据模型 定义会话的数据结构和验证逻辑 """ from datetime import datetime, timedelta from typing import Dict, List, Optional, Any from dataclasses import dataclass, field from enum import Enum class SessionStatus(Enum): """会话状态枚举""" QUEUED = "queued" # 排队中 PENDING = "pending" # 等待中 ACTIVE = "active" # 活跃状态 PAUSED = "paused" # 暂停状态 COMPLETED = "completed" # 已完成 CANCELLED = "cancelled" # 已取消 ERROR = "error" # 错误状态 class SessionType(Enum): """会话类型枚举""" SCRIPTWRITING = "scriptwriting" # 编剧创作 CONSULTATION = "consultation" # 咨询对话 REVIEW = "review" # 剧本评审 COLLABORATION = "collaboration" # 协作创作 @dataclass class SessionModel: """ 会话数据模型 定义会话的完整数据结构 """ # 基础信息 user_id: str session_type: str = SessionType.SCRIPTWRITING.value title: str = "" # 原始剧本内容 original_script: str = "" # 状态信息 status: str = SessionStatus.ACTIVE.value current_step: int = 1 total_steps: int = 6 # 配置信息 settings: Dict[str, Any] = field(default_factory=dict) metadata: Dict[str, Any] = field(default_factory=dict) # 时间信息 created_at: datetime = field(default_factory=datetime.now) updated_at: datetime = field(default_factory=datetime.now) expires_at: datetime = field(default_factory=lambda: datetime.now() + timedelta(days=7)) # 可选字段 session_id: Optional[str] = None def __post_init__(self): """初始化后处理""" # 验证数据 self.validate() # 设置默认值 if not self.title: self.title = f"编剧会话 - {self.created_at.strftime('%Y-%m-%d %H:%M')}" # 确保设置字典存在必要的键 default_settings = { 'auto_save': True, 'step_timeout': 1800, # 30分钟 'max_retries': 3, 'language': 'zh-CN', 'genre': 'drama', 'style': 'modern' } for key, value in default_settings.items(): if key not in self.settings: self.settings[key] = value # 确保元数据字典存在 if 'tags' not in self.metadata: self.metadata['tags'] = [] if 'source' not in self.metadata: self.metadata['source'] = 'web_interface' def validate(self) -> bool: """ 验证数据有效性 Returns: bool: 是否有效 Raises: ValueError: 数据无效时抛出异常 """ # 验证用户ID if not self.user_id or not isinstance(self.user_id, str): raise ValueError("用户ID不能为空且必须是字符串") # 验证会话类型 valid_types = [t.value for t in SessionType] if self.session_type not in valid_types: raise ValueError(f"无效的会话类型: {self.session_type},有效值: {valid_types}") # 验证状态 valid_statuses = [s.value for s in SessionStatus] if self.status not in valid_statuses: raise ValueError(f"无效的会话状态: {self.status},有效值: {valid_statuses}") # 验证步骤 if not isinstance(self.current_step, int) or self.current_step < 1: raise ValueError("当前步骤必须是大于0的整数") if not isinstance(self.total_steps, int) or self.total_steps < 1: raise ValueError("总步骤数必须是大于0的整数") if self.current_step > self.total_steps: raise ValueError("当前步骤不能超过总步骤数") # 验证时间 if self.updated_at < self.created_at: raise ValueError("更新时间不能早于创建时间") if self.expires_at <= self.created_at: raise ValueError("过期时间必须晚于创建时间") return True def to_dict(self) -> Dict[str, Any]: """ 转换为字典格式 Returns: Dict: 字典数据 """ data = { 'user_id': self.user_id, 'session_type': self.session_type, 'title': self.title, 'original_script': self.original_script, 'status': self.status, 'current_step': self.current_step, 'total_steps': self.total_steps, 'settings': self.settings.copy(), 'metadata': self.metadata.copy(), 'created_at': self.created_at, 'updated_at': self.updated_at, 'expires_at': self.expires_at } if self.session_id: data['session_id'] = self.session_id return data @classmethod def from_dict(cls, data: Dict[str, Any]) -> 'SessionModel': """ 从字典创建模型实例 Args: data: 字典数据 Returns: SessionModel: 模型实例 """ # 提取必需字段 user_id = data.get('user_id') if not user_id: raise ValueError("缺少必需字段: user_id") # 创建实例 instance = cls( user_id=user_id, session_type=data.get('session_type', SessionType.SCRIPTWRITING.value), title=data.get('title', ''), status=data.get('status', SessionStatus.ACTIVE.value), current_step=data.get('current_step', 1), total_steps=data.get('total_steps', 6), settings=data.get('settings', {}), metadata=data.get('metadata', {}) ) # 设置时间字段 if 'created_at' in data: instance.created_at = data['created_at'] if 'updated_at' in data: instance.updated_at = data['updated_at'] if 'expires_at' in data: instance.expires_at = data['expires_at'] # 设置可选字段 if 'session_id' in data: instance.session_id = data['session_id'] return instance def update_status(self, new_status: str) -> None: """ 更新会话状态 Args: new_status: 新状态 """ valid_statuses = [s.value for s in SessionStatus] if new_status not in valid_statuses: raise ValueError(f"无效的会话状态: {new_status}") self.status = new_status self.updated_at = datetime.now() def advance_step(self) -> bool: """ 推进到下一步 Returns: bool: 是否成功推进 """ if self.current_step < self.total_steps: self.current_step += 1 self.updated_at = datetime.now() # 如果到达最后一步,标记为完成 if self.current_step == self.total_steps: self.status = SessionStatus.COMPLETED.value return True return False def reset_step(self, step_number: int) -> bool: """ 重置到指定步骤 Args: step_number: 目标步骤号 Returns: bool: 是否成功重置 """ if 1 <= step_number <= self.total_steps: self.current_step = step_number self.updated_at = datetime.now() # 如果从完成状态回退,更新状态 if self.status == SessionStatus.COMPLETED.value and step_number < self.total_steps: self.status = SessionStatus.ACTIVE.value return True return False def update_setting(self, key: str, value: Any) -> None: """ 更新设置项 Args: key: 设置键 value: 设置值 """ self.settings[key] = value self.updated_at = datetime.now() def add_metadata(self, key: str, value: Any) -> None: """ 添加元数据 Args: key: 元数据键 value: 元数据值 """ self.metadata[key] = value self.updated_at = datetime.now() def add_tag(self, tag: str) -> None: """ 添加标签 Args: tag: 标签名称 """ if 'tags' not in self.metadata: self.metadata['tags'] = [] if tag not in self.metadata['tags']: self.metadata['tags'].append(tag) self.updated_at = datetime.now() def remove_tag(self, tag: str) -> bool: """ 移除标签 Args: tag: 标签名称 Returns: bool: 是否成功移除 """ if 'tags' in self.metadata and tag in self.metadata['tags']: self.metadata['tags'].remove(tag) self.updated_at = datetime.now() return True return False def extend_expiry(self, days: int = 7) -> None: """ 延长过期时间 Args: days: 延长天数 """ self.expires_at = datetime.now() + timedelta(days=days) self.updated_at = datetime.now() def is_expired(self) -> bool: """ 检查是否已过期 Returns: bool: 是否已过期 """ return datetime.now() > self.expires_at def is_active(self) -> bool: """ 检查是否处于活跃状态 Returns: bool: 是否活跃 """ return self.status == SessionStatus.ACTIVE.value and not self.is_expired() def is_completed(self) -> bool: """ 检查是否已完成 Returns: bool: 是否已完成 """ return self.status == SessionStatus.COMPLETED.value or self.current_step >= self.total_steps def get_progress_percentage(self) -> float: """ 获取进度百分比 Returns: float: 进度百分比 (0-100) """ if self.total_steps <= 0: return 0.0 return min(100.0, (self.current_step / self.total_steps) * 100.0) def get_remaining_steps(self) -> int: """ 获取剩余步骤数 Returns: int: 剩余步骤数 """ return max(0, self.total_steps - self.current_step) def get_session_duration(self) -> timedelta: """ 获取会话持续时间 Returns: timedelta: 持续时间 """ return self.updated_at - self.created_at def __str__(self) -> str: """ 字符串表示 Returns: str: 字符串描述 """ return f"SessionModel(id={self.session_id}, user={self.user_id}, step={self.current_step}/{self.total_steps}, status={self.status})" def __repr__(self) -> str: """ 详细字符串表示 Returns: str: 详细描述 """ return (f"SessionModel(session_id={self.session_id}, user_id={self.user_id}, " f"type={self.session_type}, title='{self.title}', status={self.status}, " f"step={self.current_step}/{self.total_steps}, created={self.created_at})")