655 lines
20 KiB
Python
655 lines
20 KiB
Python
#!/usr/bin/env python3
|
||
# -*- coding: utf-8 -*-
|
||
"""
|
||
格式化器模块
|
||
|
||
提供响应格式化、数据转换和输出格式化功能
|
||
"""
|
||
|
||
import json
|
||
from datetime import datetime, date
|
||
from typing import Dict, List, Optional, Any, Union
|
||
from decimal import Decimal
|
||
|
||
class ResponseFormatter:
|
||
"""响应格式化器"""
|
||
|
||
@staticmethod
|
||
def success_response(data: Any = None, message: str = "操作成功",
|
||
code: int = 200, extra: Dict[str, Any] = {}) -> Dict[str, Any]:
|
||
"""
|
||
创建成功响应
|
||
|
||
Args:
|
||
data: 响应数据
|
||
message: 响应消息
|
||
code: 状态码
|
||
extra: 额外信息
|
||
|
||
Returns:
|
||
Dict: 格式化的响应
|
||
"""
|
||
response = {
|
||
"success": True,
|
||
"code": code,
|
||
"message": message,
|
||
"timestamp": datetime.now().isoformat(),
|
||
"data": data
|
||
}
|
||
|
||
if extra:
|
||
response.update(extra)
|
||
|
||
return response
|
||
|
||
@staticmethod
|
||
def error_response(message: str = "操作失败", code: int = 400,
|
||
error_type: str = "ValidationError",
|
||
details: Dict[str, Any] = {}) -> Dict[str, Any]:
|
||
"""
|
||
创建错误响应
|
||
|
||
Args:
|
||
message: 错误消息
|
||
code: 错误码
|
||
error_type: 错误类型
|
||
details: 错误详情
|
||
|
||
Returns:
|
||
Dict: 格式化的错误响应
|
||
"""
|
||
response = {
|
||
"success": False,
|
||
"code": code,
|
||
"message": message,
|
||
"error_type": error_type,
|
||
"timestamp": datetime.now().isoformat()
|
||
}
|
||
|
||
if details:
|
||
response["details"] = details
|
||
|
||
return response
|
||
|
||
@staticmethod
|
||
def paginated_response(items: List[Any], total: int, page: int,
|
||
page_size: int, message: str = "查询成功") -> Dict[str, Any]:
|
||
"""
|
||
创建分页响应
|
||
|
||
Args:
|
||
items: 数据项列表
|
||
total: 总数量
|
||
page: 当前页码
|
||
page_size: 页大小
|
||
message: 响应消息
|
||
|
||
Returns:
|
||
Dict: 格式化的分页响应
|
||
"""
|
||
total_pages = (total + page_size - 1) // page_size
|
||
|
||
pagination = {
|
||
"current_page": page,
|
||
"page_size": page_size,
|
||
"total_items": total,
|
||
"total_pages": total_pages,
|
||
"has_next": page < total_pages,
|
||
"has_prev": page > 1
|
||
}
|
||
|
||
return ResponseFormatter.success_response(
|
||
data={
|
||
"items": items,
|
||
"pagination": pagination
|
||
},
|
||
message=message
|
||
)
|
||
|
||
@staticmethod
|
||
def stream_response(data: Any, event_type: str = "data",
|
||
event_id: str = "") -> str:
|
||
"""
|
||
创建流式响应
|
||
|
||
Args:
|
||
data: 响应数据
|
||
event_type: 事件类型
|
||
event_id: 事件ID
|
||
|
||
Returns:
|
||
str: SSE格式的响应
|
||
"""
|
||
lines = []
|
||
|
||
if event_id:
|
||
lines.append(f"id: {event_id}")
|
||
|
||
lines.append(f"event: {event_type}")
|
||
|
||
# 将数据转换为JSON字符串
|
||
if isinstance(data, (dict, list)):
|
||
data_str = json.dumps(data, ensure_ascii=False, cls=DateTimeEncoder)
|
||
else:
|
||
data_str = str(data)
|
||
|
||
lines.append(f"data: {data_str}")
|
||
lines.append("") # 空行表示事件结束
|
||
|
||
return "\n".join(lines)
|
||
|
||
@staticmethod
|
||
def validation_error_response(errors: List[str],
|
||
warnings: List[str] = []) -> Dict[str, Any]:
|
||
"""
|
||
创建验证错误响应
|
||
|
||
Args:
|
||
errors: 错误列表
|
||
warnings: 警告列表
|
||
|
||
Returns:
|
||
Dict: 格式化的验证错误响应
|
||
"""
|
||
details = {"errors": errors}
|
||
|
||
if warnings:
|
||
details["warnings"] = warnings
|
||
|
||
return ResponseFormatter.error_response(
|
||
message="数据验证失败",
|
||
code=422,
|
||
error_type="ValidationError",
|
||
details=details
|
||
)
|
||
|
||
@staticmethod
|
||
def not_found_response(resource: str = "资源") -> Dict[str, Any]:
|
||
"""
|
||
创建资源未找到响应
|
||
|
||
Args:
|
||
resource: 资源名称
|
||
|
||
Returns:
|
||
Dict: 格式化的未找到响应
|
||
"""
|
||
return ResponseFormatter.error_response(
|
||
message=f"{resource}未找到",
|
||
code=404,
|
||
error_type="NotFoundError"
|
||
)
|
||
|
||
@staticmethod
|
||
def unauthorized_response(message: str = "未授权访问") -> Dict[str, Any]:
|
||
"""
|
||
创建未授权响应
|
||
|
||
Args:
|
||
message: 错误消息
|
||
|
||
Returns:
|
||
Dict: 格式化的未授权响应
|
||
"""
|
||
return ResponseFormatter.error_response(
|
||
message=message,
|
||
code=401,
|
||
error_type="UnauthorizedError"
|
||
)
|
||
|
||
@staticmethod
|
||
def forbidden_response(message: str = "禁止访问") -> Dict[str, Any]:
|
||
"""
|
||
创建禁止访问响应
|
||
|
||
Args:
|
||
message: 错误消息
|
||
|
||
Returns:
|
||
Dict: 格式化的禁止访问响应
|
||
"""
|
||
return ResponseFormatter.error_response(
|
||
message=message,
|
||
code=403,
|
||
error_type="ForbiddenError"
|
||
)
|
||
|
||
@staticmethod
|
||
def rate_limit_response(retry_after: int = 60) -> Dict[str, Any]:
|
||
"""
|
||
创建限流响应
|
||
|
||
Args:
|
||
retry_after: 重试等待时间(秒)
|
||
|
||
Returns:
|
||
Dict: 格式化的限流响应
|
||
"""
|
||
return ResponseFormatter.error_response(
|
||
message="请求过于频繁,请稍后重试",
|
||
code=429,
|
||
error_type="RateLimitError",
|
||
details={"retry_after": retry_after}
|
||
)
|
||
|
||
@staticmethod
|
||
def server_error_response(message: str = "服务器内部错误") -> Dict[str, Any]:
|
||
"""
|
||
创建服务器错误响应
|
||
|
||
Args:
|
||
message: 错误消息
|
||
|
||
Returns:
|
||
Dict: 格式化的服务器错误响应
|
||
"""
|
||
return ResponseFormatter.error_response(
|
||
message=message,
|
||
code=500,
|
||
error_type="InternalServerError"
|
||
)
|
||
|
||
class ErrorFormatter:
|
||
"""错误格式化器"""
|
||
|
||
@staticmethod
|
||
def format_exception(exception: Exception, include_traceback: bool = False) -> Dict[str, Any]:
|
||
"""
|
||
格式化异常信息
|
||
|
||
Args:
|
||
exception: 异常对象
|
||
include_traceback: 是否包含堆栈跟踪
|
||
|
||
Returns:
|
||
Dict: 格式化的异常信息
|
||
"""
|
||
error_info = {
|
||
"type": exception.__class__.__name__,
|
||
"message": str(exception),
|
||
"timestamp": datetime.now().isoformat()
|
||
}
|
||
|
||
if include_traceback:
|
||
import traceback
|
||
error_info["traceback"] = traceback.format_exc()
|
||
|
||
return error_info
|
||
|
||
@staticmethod
|
||
def format_validation_errors(errors: List[str], field_errors: Dict[str, List[str]] = {}) -> Dict[str, Any]:
|
||
"""
|
||
格式化验证错误
|
||
|
||
Args:
|
||
errors: 通用错误列表
|
||
field_errors: 字段特定错误字典
|
||
|
||
Returns:
|
||
Dict: 格式化的验证错误
|
||
"""
|
||
formatted = {
|
||
"general_errors": errors,
|
||
"error_count": len(errors)
|
||
}
|
||
|
||
if field_errors:
|
||
formatted["field_errors"] = field_errors
|
||
formatted["error_count"] += sum(len(errs) for errs in field_errors.values())
|
||
|
||
return formatted
|
||
|
||
@staticmethod
|
||
def format_database_error(error: Exception) -> Dict[str, Any]:
|
||
"""
|
||
格式化数据库错误
|
||
|
||
Args:
|
||
error: 数据库异常
|
||
|
||
Returns:
|
||
Dict: 格式化的数据库错误
|
||
"""
|
||
error_message = str(error)
|
||
|
||
# 根据错误类型提供更友好的消息
|
||
if "duplicate key" in error_message.lower():
|
||
user_message = "数据已存在,请检查唯一性约束"
|
||
elif "foreign key" in error_message.lower():
|
||
user_message = "关联数据不存在,请检查数据完整性"
|
||
elif "not null" in error_message.lower():
|
||
user_message = "必填字段不能为空"
|
||
elif "timeout" in error_message.lower():
|
||
user_message = "数据库操作超时,请稍后重试"
|
||
else:
|
||
user_message = "数据库操作失败"
|
||
|
||
return {
|
||
"type": "DatabaseError",
|
||
"user_message": user_message,
|
||
"technical_message": error_message,
|
||
"timestamp": datetime.now().isoformat()
|
||
}
|
||
|
||
class DataFormatter:
|
||
"""数据格式化器"""
|
||
|
||
@staticmethod
|
||
def format_session_data(session_data: Dict[str, Any]) -> Dict[str, Any]:
|
||
"""
|
||
格式化会话数据
|
||
|
||
Args:
|
||
session_data: 原始会话数据
|
||
|
||
Returns:
|
||
Dict: 格式化的会话数据
|
||
"""
|
||
formatted = {
|
||
"session_id": session_data.get("session_id"),
|
||
"user_id": session_data.get("user_id"),
|
||
"session_type": session_data.get("session_type"),
|
||
"status": session_data.get("status"),
|
||
"current_step": session_data.get("current_step", 1),
|
||
"created_at": DataFormatter._format_datetime(session_data.get("created_at")),
|
||
"updated_at": DataFormatter._format_datetime(session_data.get("updated_at")),
|
||
"expires_at": DataFormatter._format_datetime(session_data.get("expires_at"))
|
||
}
|
||
|
||
# 添加可选字段
|
||
optional_fields = ["title", "description", "session_config", "metadata"]
|
||
for field in optional_fields:
|
||
if field in session_data:
|
||
formatted[field] = session_data[field]
|
||
|
||
return formatted
|
||
|
||
@staticmethod
|
||
def format_step_data(step_data: Dict[str, Any]) -> Dict[str, Any]:
|
||
"""
|
||
格式化步骤数据
|
||
|
||
Args:
|
||
step_data: 原始步骤数据
|
||
|
||
Returns:
|
||
Dict: 格式化的步骤数据
|
||
"""
|
||
formatted = {
|
||
"step_id": step_data.get("step_id"),
|
||
"session_id": step_data.get("session_id"),
|
||
"step_number": step_data.get("step_number"),
|
||
"step_type": step_data.get("step_type"),
|
||
"status": step_data.get("status"),
|
||
"execution_mode": step_data.get("execution_mode", "sync"),
|
||
"started_at": DataFormatter._format_datetime(step_data.get("started_at")),
|
||
"completed_at": DataFormatter._format_datetime(step_data.get("completed_at")),
|
||
"created_at": DataFormatter._format_datetime(step_data.get("created_at"))
|
||
}
|
||
|
||
# 添加执行结果
|
||
if "result_data" in step_data:
|
||
formatted["result_data"] = step_data["result_data"]
|
||
|
||
# 添加错误信息
|
||
if "error_message" in step_data:
|
||
formatted["error_message"] = step_data["error_message"]
|
||
|
||
# 计算执行时长
|
||
if step_data.get("started_at") and step_data.get("completed_at"):
|
||
start_time = step_data["started_at"]
|
||
end_time = step_data["completed_at"]
|
||
if isinstance(start_time, datetime) and isinstance(end_time, datetime):
|
||
duration = (end_time - start_time).total_seconds()
|
||
formatted["execution_duration_seconds"] = duration
|
||
|
||
return formatted
|
||
|
||
@staticmethod
|
||
def format_workflow_state_data(state_data: Dict[str, Any]) -> Dict[str, Any]:
|
||
"""
|
||
格式化工作流状态数据
|
||
|
||
Args:
|
||
state_data: 原始状态数据
|
||
|
||
Returns:
|
||
Dict: 格式化的状态数据
|
||
"""
|
||
formatted = {
|
||
"state_id": state_data.get("state_id"),
|
||
"session_id": state_data.get("session_id"),
|
||
"state_type": state_data.get("state_type"),
|
||
"status": state_data.get("status"),
|
||
"priority": state_data.get("priority"),
|
||
"created_at": DataFormatter._format_datetime(state_data.get("created_at")),
|
||
"updated_at": DataFormatter._format_datetime(state_data.get("updated_at")),
|
||
"expires_at": DataFormatter._format_datetime(state_data.get("expires_at"))
|
||
}
|
||
|
||
# 添加状态特定数据
|
||
if "state_data" in state_data:
|
||
formatted["state_data"] = state_data["state_data"]
|
||
|
||
# 添加可选字段
|
||
optional_fields = ["category", "related_step", "assigned_to", "processed_by", "processed_at"]
|
||
for field in optional_fields:
|
||
if field in state_data:
|
||
if field == "processed_at":
|
||
formatted[field] = DataFormatter._format_datetime(state_data[field])
|
||
else:
|
||
formatted[field] = state_data[field]
|
||
|
||
return formatted
|
||
|
||
@staticmethod
|
||
def format_memory_data(memory_data: Dict[str, Any]) -> Dict[str, Any]:
|
||
"""
|
||
格式化记忆数据
|
||
|
||
Args:
|
||
memory_data: 原始记忆数据
|
||
|
||
Returns:
|
||
Dict: 格式化的记忆数据
|
||
"""
|
||
formatted = {
|
||
"memory_id": memory_data.get("memory_id"),
|
||
"user_id": memory_data.get("user_id"),
|
||
"memory_type": memory_data.get("memory_type"),
|
||
"memory_key": memory_data.get("memory_key"),
|
||
"category": memory_data.get("category"),
|
||
"access_level": memory_data.get("access_level"),
|
||
"importance": memory_data.get("importance"),
|
||
"weight": memory_data.get("weight"),
|
||
"access_count": memory_data.get("access_count", 0),
|
||
"created_at": DataFormatter._format_datetime(memory_data.get("created_at")),
|
||
"updated_at": DataFormatter._format_datetime(memory_data.get("updated_at"))
|
||
}
|
||
|
||
# 添加记忆数据(可能需要脱敏)
|
||
if "memory_data" in memory_data:
|
||
formatted["memory_data"] = DataFormatter._sanitize_memory_data(
|
||
memory_data["memory_data"],
|
||
memory_data.get("sensitivity", "low")
|
||
)
|
||
|
||
# 添加可选字段
|
||
optional_fields = ["description", "tags", "last_accessed", "expires_at"]
|
||
for field in optional_fields:
|
||
if field in memory_data:
|
||
if field in ["last_accessed", "expires_at"]:
|
||
formatted[field] = DataFormatter._format_datetime(memory_data[field])
|
||
else:
|
||
formatted[field] = memory_data[field]
|
||
|
||
return formatted
|
||
|
||
@staticmethod
|
||
def format_list_response(items: List[Dict[str, Any]],
|
||
formatter_func,
|
||
total: int = 0,
|
||
page: int = 0,
|
||
page_size: int = 0) -> Dict[str, Any]:
|
||
"""
|
||
格式化列表响应
|
||
|
||
Args:
|
||
items: 数据项列表
|
||
formatter_func: 格式化函数
|
||
total: 总数量
|
||
page: 页码
|
||
page_size: 页大小
|
||
|
||
Returns:
|
||
Dict: 格式化的列表响应
|
||
"""
|
||
formatted_items = [formatter_func(item) for item in items]
|
||
|
||
result:Any = {"items": formatted_items}
|
||
|
||
if total > 0:
|
||
result["total"] = total
|
||
|
||
if page > 0 and page_size > 0:
|
||
result["pagination"] = {
|
||
"current_page": page,
|
||
"page_size": page_size,
|
||
"total_pages": (total + page_size - 1) // page_size if total else 1
|
||
}
|
||
|
||
return result
|
||
|
||
@staticmethod
|
||
def _format_datetime(dt: Union[datetime, str, None]) -> Optional[str]:
|
||
"""
|
||
格式化日期时间
|
||
|
||
Args:
|
||
dt: 日期时间对象或字符串
|
||
|
||
Returns:
|
||
Optional[str]: 格式化的日期时间字符串
|
||
"""
|
||
if dt is None:
|
||
return None
|
||
|
||
if isinstance(dt, datetime):
|
||
return dt.isoformat()
|
||
|
||
if isinstance(dt, str):
|
||
return dt
|
||
|
||
return str(dt)
|
||
|
||
@staticmethod
|
||
def _sanitize_memory_data(memory_data: Dict[str, Any], sensitivity: str) -> Dict[str, Any]:
|
||
"""
|
||
根据敏感度脱敏记忆数据
|
||
|
||
Args:
|
||
memory_data: 记忆数据
|
||
sensitivity: 敏感度级别
|
||
|
||
Returns:
|
||
Dict: 脱敏后的数据
|
||
"""
|
||
if sensitivity in ["high", "confidential"]:
|
||
# 高敏感度数据需要脱敏
|
||
sanitized = {}
|
||
for key, value in memory_data.items():
|
||
if key in ["password", "token", "secret", "private_key"]:
|
||
sanitized[key] = "***"
|
||
elif isinstance(value, str) and len(value) > 10:
|
||
# 长字符串部分脱敏
|
||
sanitized[key] = value[:3] + "***" + value[-3:]
|
||
else:
|
||
sanitized[key] = value
|
||
return sanitized
|
||
|
||
return memory_data.copy()
|
||
|
||
class DateTimeEncoder(json.JSONEncoder):
|
||
"""自定义JSON编码器,处理日期时间对象"""
|
||
|
||
def default(self, obj):
|
||
if isinstance(obj, datetime):
|
||
return obj.isoformat()
|
||
elif isinstance(obj, date):
|
||
return obj.isoformat()
|
||
elif isinstance(obj, Decimal):
|
||
return float(obj)
|
||
return super().default(obj)
|
||
|
||
def format_file_size(size_bytes: int) -> str:
|
||
"""
|
||
格式化文件大小
|
||
|
||
Args:
|
||
size_bytes: 字节数
|
||
|
||
Returns:
|
||
str: 格式化的文件大小
|
||
"""
|
||
if size_bytes == 0:
|
||
return "0 B"
|
||
|
||
size_names = ["B", "KB", "MB", "GB", "TB"]
|
||
import math
|
||
i = int(math.floor(math.log(size_bytes, 1024)))
|
||
p = math.pow(1024, i)
|
||
s = round(size_bytes / p, 2)
|
||
return f"{s} {size_names[i]}"
|
||
|
||
def format_duration(seconds: float) -> str:
|
||
"""
|
||
格式化持续时间
|
||
|
||
Args:
|
||
seconds: 秒数
|
||
|
||
Returns:
|
||
str: 格式化的持续时间
|
||
"""
|
||
if seconds < 60:
|
||
return f"{seconds:.2f}秒"
|
||
elif seconds < 3600:
|
||
minutes = seconds / 60
|
||
return f"{minutes:.2f}分钟"
|
||
else:
|
||
hours = seconds / 3600
|
||
return f"{hours:.2f}小时"
|
||
|
||
def truncate_text(text: str, max_length: int = 100, suffix: str = "...") -> str:
|
||
"""
|
||
截断文本
|
||
|
||
Args:
|
||
text: 原始文本
|
||
max_length: 最大长度
|
||
suffix: 后缀
|
||
|
||
Returns:
|
||
str: 截断后的文本
|
||
"""
|
||
if len(text) <= max_length:
|
||
return text
|
||
|
||
return text[:max_length - len(suffix)] + suffix
|
||
|
||
def mask_sensitive_info(text: str, mask_char: str = "*", visible_chars: int = 3) -> str:
|
||
"""
|
||
遮蔽敏感信息
|
||
|
||
Args:
|
||
text: 原始文本
|
||
mask_char: 遮蔽字符
|
||
visible_chars: 可见字符数
|
||
|
||
Returns:
|
||
str: 遮蔽后的文本
|
||
"""
|
||
if len(text) <= visible_chars * 2:
|
||
return mask_char * len(text)
|
||
|
||
return text[:visible_chars] + mask_char * (len(text) - visible_chars * 2) + text[-visible_chars:] |