#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Formatter module.

Provides response formatting, data conversion and output formatting helpers.
"""

import json
import traceback
from datetime import datetime, date
from typing import Dict, List, Optional, Any, Union
from decimal import Decimal


class ResponseFormatter:
    """Builders for the standard API response envelopes."""

    @staticmethod
    def success_response(data: Any = None, message: str = "操作成功", code: int = 200,
                         extra: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
        """
        Create a success response.

        Args:
            data: Response payload, stored under "data".
            message: Response message.
            code: Status code.
            extra: Additional top-level keys merged into the response. They are
                applied last, so they can overwrite the standard keys.

        Returns:
            Dict: The formatted response.
        """
        response = {
            "success": True,
            "code": code,
            "message": message,
            "timestamp": datetime.now().isoformat(),
            "data": data
        }
        if extra:
            response.update(extra)
        return response

    @staticmethod
    def error_response(message: str = "操作失败", code: int = 400,
                       error_type: str = "ValidationError",
                       details: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
        """
        Create an error response.

        Args:
            message: Error message.
            code: Error code.
            error_type: Error type name.
            details: Optional error details. Included only when non-empty.

        Returns:
            Dict: The formatted error response.
        """
        response = {
            "success": False,
            "code": code,
            "message": message,
            "error_type": error_type,
            "timestamp": datetime.now().isoformat()
        }
        if details:
            response["details"] = details
        return response

    @staticmethod
    def paginated_response(items: List[Any], total: int, page: int, page_size: int,
                           message: str = "查询成功") -> Dict[str, Any]:
        """
        Create a paginated success response.

        Args:
            items: Items on the current page.
            total: Total number of items across all pages.
            page: Current page number (1-based).
            page_size: Number of items per page. A non-positive value yields
                total_pages == 0 instead of raising ZeroDivisionError.
            message: Response message.

        Returns:
            Dict: A success response whose data holds "items" and "pagination".
        """
        # Ceiling division. Guard against page_size <= 0, which would divide by
        # zero (or produce a meaningless negative page count).
        total_pages = (total + page_size - 1) // page_size if page_size > 0 else 0

        pagination = {
            "current_page": page,
            "page_size": page_size,
            "total_items": total,
            "total_pages": total_pages,
            "has_next": page < total_pages,
            "has_prev": page > 1
        }

        return ResponseFormatter.success_response(
            data={
                "items": items,
                "pagination": pagination
            },
            message=message
        )

    @staticmethod
    def stream_response(data: Any, event_type: str = "data", event_id: str = "") -> str:
        """
        Create a Server-Sent Events (SSE) message.

        Args:
            data: Event payload. Dicts and lists are JSON-encoded; anything else
                is converted with str().
            event_type: Value of the "event:" field.
            event_id: Value of the "id:" field; omitted when empty.

        Returns:
            str: One complete SSE event, terminated by a blank line.
        """
        lines = []
        if event_id:
            lines.append(f"id: {event_id}")
        lines.append(f"event: {event_type}")

        # Convert the payload to text.
        if isinstance(data, (dict, list)):
            data_str = json.dumps(data, ensure_ascii=False, cls=DateTimeEncoder)
        else:
            data_str = str(data)

        # SSE ends a field at CR, LF or CRLF, so a raw newline in the payload
        # would cut the "data:" field short (and an empty line would end the
        # event). Emit one "data:" line per payload line; clients rejoin
        # consecutive data lines with "\n".
        normalized = data_str.replace("\r\n", "\n").replace("\r", "\n")
        lines.extend(f"data: {part}" for part in normalized.split("\n"))

        # An event is dispatched only after an empty line, so the message must
        # end with "\n\n" (joining with a trailing "" yields only one "\n").
        return "\n".join(lines) + "\n\n"

    @staticmethod
    def validation_error_response(errors: List[str],
                                  warnings: Optional[List[str]] = None) -> Dict[str, Any]:
        """
        Create a validation-error (422) response.

        Args:
            errors: Error messages.
            warnings: Optional warning messages. Included only when non-empty.

        Returns:
            Dict: The formatted validation-error response.
        """
        details: Dict[str, Any] = {"errors": errors}
        if warnings:
            details["warnings"] = warnings

        return ResponseFormatter.error_response(
            message="数据验证失败",
            code=422,
            error_type="ValidationError",
            details=details
        )

    @staticmethod
    def not_found_response(resource: str = "资源") -> Dict[str, Any]:
        """
        Create a not-found (404) response.

        Args:
            resource: Resource name, interpolated into the message.

        Returns:
            Dict: The formatted not-found response.
        """
        return ResponseFormatter.error_response(
            message=f"{resource}未找到",
            code=404,
            error_type="NotFoundError"
        )

    @staticmethod
    def unauthorized_response(message: str = "未授权访问") -> Dict[str, Any]:
        """
        Create an unauthorized (401) response.

        Args:
            message: Error message.

        Returns:
            Dict: The formatted unauthorized response.
        """
        return ResponseFormatter.error_response(
            message=message,
            code=401,
            error_type="UnauthorizedError"
        )

    @staticmethod
    def forbidden_response(message: str = "禁止访问") -> Dict[str, Any]:
        """
        Create a forbidden (403) response.

        Args:
            message: Error message.

        Returns:
            Dict: The formatted forbidden response.
        """
        return ResponseFormatter.error_response(
            message=message,
            code=403,
            error_type="ForbiddenError"
        )

    @staticmethod
    def rate_limit_response(retry_after: int = 60) -> Dict[str, Any]:
        """
        Create a rate-limit (429) response.

        Args:
            retry_after: Suggested wait before retrying, in seconds.

        Returns:
            Dict: The formatted rate-limit response.
        """
        return ResponseFormatter.error_response(
            message="请求过于频繁,请稍后重试",
            code=429,
            error_type="RateLimitError",
            details={"retry_after": retry_after}
        )

    @staticmethod
    def server_error_response(message: str = "服务器内部错误") -> Dict[str, Any]:
        """
        Create an internal-server-error (500) response.

        Args:
            message: Error message.

        Returns:
            Dict: The formatted server-error response.
        """
        return ResponseFormatter.error_response(
            message=message,
            code=500,
            error_type="InternalServerError"
        )


class ErrorFormatter:
    """Helpers that turn exceptions and validation failures into dicts."""

    @staticmethod
    def format_exception(exception: Exception, include_traceback: bool = False) -> Dict[str, Any]:
        """
        Format an exception as a dict.

        Args:
            exception: The exception to format.
            include_traceback: Whether to include the formatted traceback.

        Returns:
            Dict: The exception's type name, message, a timestamp, and
            optionally its traceback.
        """
        error_info = {
            "type": exception.__class__.__name__,
            "message": str(exception),
            "timestamp": datetime.now().isoformat()
        }

        if include_traceback:
            # Format the traceback attached to *this* exception.
            # traceback.format_exc() reports only the exception currently being
            # handled: it yields "NoneType: None" outside an except block, and
            # the wrong traceback if another exception is in flight.
            error_info["traceback"] = "".join(
                traceback.format_exception(
                    type(exception), exception, exception.__traceback__
                )
            )

        return error_info

    @staticmethod
    def format_validation_errors(errors: List[str],
                                 field_errors: Optional[Dict[str, List[str]]] = None) -> Dict[str, Any]:
        """
        Format validation errors.

        Args:
            errors: General (non-field) error messages.
            field_errors: Mapping of field name to that field's error messages.

        Returns:
            Dict: The general errors, optional field errors, and the total
            error count across both.
        """
        formatted: Dict[str, Any] = {
            "general_errors": errors,
            "error_count": len(errors)
        }

        if field_errors:
            formatted["field_errors"] = field_errors
            formatted["error_count"] += sum(len(errs) for errs in field_errors.values())

        return formatted

    @staticmethod
    def format_database_error(error: Exception) -> Dict[str, Any]:
        """
        Format a database error with a user-friendly message.

        Args:
            error: The database exception.

        Returns:
            Dict: A user message picked by case-insensitive keyword matching on
            the error text, plus the raw text as "technical_message".
        """
        error_message = str(error)
        lowered = error_message.lower()

        # Pick a friendlier message from the error type; the first match wins.
        if "duplicate key" in lowered:
            user_message = "数据已存在,请检查唯一性约束"
        elif "foreign key" in lowered:
            user_message = "关联数据不存在,请检查数据完整性"
        elif "not null" in lowered:
            user_message = "必填字段不能为空"
        elif "timeout" in lowered:
            user_message = "数据库操作超时,请稍后重试"
        else:
            user_message = "数据库操作失败"

        return {
            "type": "DatabaseError",
            "user_message": user_message,
            "technical_message": error_message,
            "timestamp": datetime.now().isoformat()
        }


class DataFormatter:
    """Helpers that normalise stored records into API-facing dicts."""

    # Keys whose values are always fully masked in high-sensitivity memory
    # data. They are compared case-insensitively.
    _SECRET_KEYS = frozenset({"password", "token", "secret", "private_key"})

    @staticmethod
    def format_session_data(session_data: Dict[str, Any]) -> Dict[str, Any]:
        """
        Format session data.

        Args:
            session_data: Raw session record.

        Returns:
            Dict: The session fields with timestamps rendered as strings.
            Optional fields are copied only when present.
        """
        formatted = {
            "session_id": session_data.get("session_id"),
            "user_id": session_data.get("user_id"),
            "session_type": session_data.get("session_type"),
            "status": session_data.get("status"),
            "current_step": session_data.get("current_step", 1),
            "created_at": DataFormatter._format_datetime(session_data.get("created_at")),
            "updated_at": DataFormatter._format_datetime(session_data.get("updated_at")),
            "expires_at": DataFormatter._format_datetime(session_data.get("expires_at"))
        }

        # Copy optional fields verbatim when present.
        optional_fields = ["title", "description", "session_config", "metadata"]
        for field in optional_fields:
            if field in session_data:
                formatted[field] = session_data[field]

        return formatted

    @staticmethod
    def format_step_data(step_data: Dict[str, Any]) -> Dict[str, Any]:
        """
        Format step data.

        Args:
            step_data: Raw step record.

        Returns:
            Dict: The step fields with timestamps rendered as strings.
            "execution_duration_seconds" is added when both started_at and
            completed_at are datetime objects.
        """
        formatted = {
            "step_id": step_data.get("step_id"),
            "session_id": step_data.get("session_id"),
            "step_number": step_data.get("step_number"),
            "step_type": step_data.get("step_type"),
            "status": step_data.get("status"),
            "execution_mode": step_data.get("execution_mode", "sync"),
            "started_at": DataFormatter._format_datetime(step_data.get("started_at")),
            "completed_at": DataFormatter._format_datetime(step_data.get("completed_at")),
            "created_at": DataFormatter._format_datetime(step_data.get("created_at"))
        }

        # Execution result.
        if "result_data" in step_data:
            formatted["result_data"] = step_data["result_data"]

        # Error information.
        if "error_message" in step_data:
            formatted["error_message"] = step_data["error_message"]

        # Execution duration. Only computed from datetime objects; string
        # timestamps are not parsed.
        if step_data.get("started_at") and step_data.get("completed_at"):
            start_time = step_data["started_at"]
            end_time = step_data["completed_at"]
            if isinstance(start_time, datetime) and isinstance(end_time, datetime):
                duration = (end_time - start_time).total_seconds()
                formatted["execution_duration_seconds"] = duration

        return formatted

    @staticmethod
    def format_workflow_state_data(state_data: Dict[str, Any]) -> Dict[str, Any]:
        """
        Format workflow state data.

        Args:
            state_data: Raw workflow state record.

        Returns:
            Dict: The state fields with timestamps (including processed_at)
            rendered as strings.
        """
        formatted = {
            "state_id": state_data.get("state_id"),
            "session_id": state_data.get("session_id"),
            "state_type": state_data.get("state_type"),
            "status": state_data.get("status"),
            "priority": state_data.get("priority"),
            "created_at": DataFormatter._format_datetime(state_data.get("created_at")),
            "updated_at": DataFormatter._format_datetime(state_data.get("updated_at")),
            "expires_at": DataFormatter._format_datetime(state_data.get("expires_at"))
        }

        # State-specific payload.
        if "state_data" in state_data:
            formatted["state_data"] = state_data["state_data"]

        # Optional fields; processed_at is a timestamp and gets formatted.
        optional_fields = ["category", "related_step", "assigned_to", "processed_by", "processed_at"]
        for field in optional_fields:
            if field in state_data:
                if field == "processed_at":
                    formatted[field] = DataFormatter._format_datetime(state_data[field])
                else:
                    formatted[field] = state_data[field]

        return formatted

    @staticmethod
    def format_memory_data(memory_data: Dict[str, Any]) -> Dict[str, Any]:
        """
        Format memory data.

        Args:
            memory_data: Raw memory record.

        Returns:
            Dict: The memory fields with timestamps rendered as strings. The
            "memory_data" payload is sanitized according to "sensitivity"
            (default "low").
        """
        formatted = {
            "memory_id": memory_data.get("memory_id"),
            "user_id": memory_data.get("user_id"),
            "memory_type": memory_data.get("memory_type"),
            "memory_key": memory_data.get("memory_key"),
            "category": memory_data.get("category"),
            "access_level": memory_data.get("access_level"),
            "importance": memory_data.get("importance"),
            "weight": memory_data.get("weight"),
            "access_count": memory_data.get("access_count", 0),
            "created_at": DataFormatter._format_datetime(memory_data.get("created_at")),
            "updated_at": DataFormatter._format_datetime(memory_data.get("updated_at"))
        }

        # Memory payload, masked for high-sensitivity records.
        if "memory_data" in memory_data:
            formatted["memory_data"] = DataFormatter._sanitize_memory_data(
                memory_data["memory_data"],
                memory_data.get("sensitivity", "low")
            )

        # Optional fields; last_accessed and expires_at are timestamps.
        optional_fields = ["description", "tags", "last_accessed", "expires_at"]
        for field in optional_fields:
            if field in memory_data:
                if field in ["last_accessed", "expires_at"]:
                    formatted[field] = DataFormatter._format_datetime(memory_data[field])
                else:
                    formatted[field] = memory_data[field]

        return formatted

    @staticmethod
    def format_list_response(items: List[Dict[str, Any]], formatter_func,
                             total: int = 0, page: int = 0, page_size: int = 0) -> Dict[str, Any]:
        """
        Format a list of records.

        Args:
            items: Raw records.
            formatter_func: Callable applied to each record.
            total: Total item count. Included only when > 0.
            page: Page number. Pagination is included only when page and
                page_size are both > 0.
            page_size: Page size.

        Returns:
            Dict: "items", plus "total" and "pagination" when applicable.
        """
        formatted_items = [formatter_func(item) for item in items]

        result: Dict[str, Any] = {"items": formatted_items}

        if total > 0:
            result["total"] = total

        if page > 0 and page_size > 0:
            result["pagination"] = {
                "current_page": page,
                "page_size": page_size,
                "total_pages": (total + page_size - 1) // page_size if total else 1
            }

        return result

    @staticmethod
    def _format_datetime(dt: Union[datetime, str, None]) -> Optional[str]:
        """
        Render a timestamp as a string.

        Args:
            dt: A datetime, a string (returned unchanged), None, or any other
                value (converted with str()).

        Returns:
            Optional[str]: The ISO-8601 string, or None when dt is None.
        """
        if dt is None:
            return None

        if isinstance(dt, datetime):
            return dt.isoformat()

        if isinstance(dt, str):
            return dt

        return str(dt)

    @staticmethod
    def _sanitize_memory_data(memory_data: Dict[str, Any], sensitivity: str) -> Dict[str, Any]:
        """
        Mask memory data according to its sensitivity level.

        Args:
            memory_data: Memory payload.
            sensitivity: Sensitivity level. "high" and "confidential" trigger
                masking; anything else returns a shallow copy.

        Returns:
            Dict: The sanitized (or copied) payload.
        """
        if sensitivity in ["high", "confidential"]:
            # High-sensitivity data must be masked.
            sanitized = {}
            for key, value in memory_data.items():
                # Match secret keys case-insensitively so "Password"/"TOKEN"
                # do not leak through the partial mask below.
                if isinstance(key, str) and key.lower() in DataFormatter._SECRET_KEYS:
                    sanitized[key] = "***"
                elif isinstance(value, str) and len(value) > 10:
                    # Long strings are partially masked.
                    sanitized[key] = value[:3] + "***" + value[-3:]
                else:
                    sanitized[key] = value
            return sanitized

        return memory_data.copy()


class DateTimeEncoder(json.JSONEncoder):
    """JSON encoder that serializes datetime/date as ISO-8601 and Decimal as float."""

    def default(self, obj):
        # datetime is a subclass of date, so one check covers both.
        if isinstance(obj, date):
            return obj.isoformat()
        if isinstance(obj, Decimal):
            return float(obj)
        return super().default(obj)


def format_file_size(size_bytes: int) -> str:
    """
    Format a byte count as a human-readable size using 1024-based units.

    Args:
        size_bytes: Number of bytes. Negative values keep their sign. Values
            beyond the largest unit are expressed in that unit.

    Returns:
        str: E.g. "0 B", "500.0 B", "1.5 KB".
    """
    if size_bytes == 0:
        return "0 B"

    size_names = ["B", "KB", "MB", "GB", "TB", "PB", "EB", "ZB", "YB"]
    sign = "-" if size_bytes < 0 else ""
    magnitude = abs(size_bytes)

    # Pick the unit with integer arithmetic. This avoids float log()
    # imprecision, negative-input domain errors, and running off the end of
    # the unit table (the last unit is clamped).
    i = 0
    while i < len(size_names) - 1 and magnitude >= 1024 ** (i + 1):
        i += 1

    s = round(magnitude / (1024 ** i), 2)
    return f"{sign}{s} {size_names[i]}"


def format_duration(seconds: float) -> str:
    """
    Format a duration in seconds, minutes, or hours.

    Args:
        seconds: Duration in seconds.

    Returns:
        str: The duration with two decimals in the largest unit not exceeding
        it: seconds (< 60), minutes (< 3600), otherwise hours.
    """
    if seconds < 60:
        return f"{seconds:.2f}秒"
    elif seconds < 3600:
        minutes = seconds / 60
        return f"{minutes:.2f}分钟"
    else:
        hours = seconds / 3600
        return f"{hours:.2f}小时"


def truncate_text(text: str, max_length: int = 100, suffix: str = "...") -> str:
    """
    Truncate text to at most max_length characters.

    Args:
        text: Original text.
        max_length: Maximum length of the result.
        suffix: Appended to truncated text when it fits within max_length.

    Returns:
        str: The text unchanged if short enough, otherwise a truncated version.
    """
    if len(text) <= max_length:
        return text

    # With no room for the suffix, the old slice went negative and returned
    # text longer than max_length. Fall back to a plain cut.
    if max_length < len(suffix):
        return text[:max(max_length, 0)]

    return text[:max_length - len(suffix)] + suffix


def mask_sensitive_info(text: str, mask_char: str = "*", visible_chars: int = 3) -> str:
    """
    Mask the middle of a sensitive string.

    Args:
        text: Original text.
        mask_char: Character used for masking.
        visible_chars: Characters left visible at each end. Negative values
            are treated as 0.

    Returns:
        str: The masked text. Fully masked if too short to reveal both ends.
    """
    visible_chars = max(visible_chars, 0)

    if len(text) <= visible_chars * 2:
        return mask_char * len(text)

    hidden = len(text) - visible_chars * 2
    # Slice the tail with an explicit start index: text[-0:] would return the
    # whole string and leak it when visible_chars == 0.
    return text[:visible_chars] + mask_char * hidden + text[len(text) - visible_chars:]