agent-writer/utils/formatters.py
2025-09-11 18:34:03 +08:00

655 lines
20 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
格式化器模块
提供响应格式化、数据转换和输出格式化功能
"""
import json
from datetime import datetime, date
from typing import Dict, List, Optional, Any, Union
from decimal import Decimal
class ResponseFormatter:
    """Build the standard response envelopes used by the API.

    Every helper is a ``@staticmethod`` returning a plain ``dict``, except
    :meth:`stream_response`, which returns one Server-Sent Events text frame.
    """

    @staticmethod
    def success_response(data: Any = None, message: str = "操作成功",
                         code: int = 200,
                         extra: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
        """Create a success response envelope.

        Args:
            data: Payload stored under ``"data"``.
            message: Human-readable message.
            code: Status code stored under ``"code"``.
            extra: Optional additional top-level keys merged into the
                envelope last (so they can overwrite the standard keys).

        Returns:
            Dict: The formatted response.
        """
        response = {
            "success": True,
            "code": code,
            "message": message,
            "timestamp": datetime.now().isoformat(),
            "data": data,
        }
        if extra:
            response.update(extra)
        return response

    @staticmethod
    def error_response(message: str = "操作失败", code: int = 400,
                       error_type: str = "ValidationError",
                       details: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
        """Create an error response envelope.

        Args:
            message: Error message.
            code: Error/status code.
            error_type: Symbolic error category.
            details: Optional extra detail; included under ``"details"`` only
                when non-empty.

        Returns:
            Dict: The formatted error response.
        """
        response = {
            "success": False,
            "code": code,
            "message": message,
            "error_type": error_type,
            "timestamp": datetime.now().isoformat(),
        }
        if details:
            response["details"] = details
        return response

    @staticmethod
    def paginated_response(items: List[Any], total: int, page: int,
                           page_size: int, message: str = "查询成功") -> Dict[str, Any]:
        """Create a paginated success response.

        Args:
            items: Items of the current page.
            total: Total number of items across all pages.
            page: Current page number (1-based, judging by ``has_prev``).
            page_size: Number of items per page.
            message: Response message.

        Returns:
            Dict: A success envelope whose ``data`` holds ``items`` and a
            ``pagination`` summary.
        """
        # Ceiling division; a non-positive page size would otherwise raise
        # ZeroDivisionError (or yield a nonsensical negative page count).
        if page_size > 0:
            total_pages = (total + page_size - 1) // page_size
        else:
            total_pages = 0
        pagination = {
            "current_page": page,
            "page_size": page_size,
            "total_items": total,
            "total_pages": total_pages,
            "has_next": page < total_pages,
            "has_prev": page > 1,
        }
        return ResponseFormatter.success_response(
            data={
                "items": items,
                "pagination": pagination,
            },
            message=message,
        )

    @staticmethod
    def stream_response(data: Any, event_type: str = "data",
                        event_id: str = "") -> str:
        """Create one Server-Sent Events (SSE) frame.

        Args:
            data: Event payload. ``dict``/``list`` values are JSON-encoded
                with ``DateTimeEncoder``; anything else goes through ``str``.
            event_type: Value of the ``event:`` field.
            event_id: Value of the ``id:`` field; omitted when empty.

        Returns:
            str: The SSE frame, terminated by a blank line.
        """
        lines = []
        if event_id:
            lines.append(f"id: {event_id}")
        lines.append(f"event: {event_type}")
        if isinstance(data, (dict, list)):
            data_str = json.dumps(data, ensure_ascii=False, cls=DateTimeEncoder)
        else:
            data_str = str(data)
        # SSE treats CRLF, LF and CR as line terminators, so a multi-line
        # payload must be sent as several "data:" fields; clients re-join
        # them with "\n".
        normalized = data_str.replace("\r\n", "\n").replace("\r", "\n")
        for chunk in normalized.split("\n"):
            lines.append(f"data: {chunk}")
        # A blank line dispatches the event, so the frame must end with two
        # consecutive newlines (joining with a trailing "" only gave one).
        return "\n".join(lines) + "\n\n"

    @staticmethod
    def validation_error_response(errors: List[str],
                                  warnings: Optional[List[str]] = None) -> Dict[str, Any]:
        """Create a 422 validation-error response.

        Args:
            errors: Validation error messages.
            warnings: Optional warnings; included only when non-empty.

        Returns:
            Dict: The formatted validation error response.
        """
        details: Dict[str, Any] = {"errors": errors}
        if warnings:
            details["warnings"] = warnings
        return ResponseFormatter.error_response(
            message="数据验证失败",
            code=422,
            error_type="ValidationError",
            details=details,
        )

    @staticmethod
    def not_found_response(resource: str = "资源") -> Dict[str, Any]:
        """Create a 404 response whose message names *resource*."""
        return ResponseFormatter.error_response(
            message=f"{resource}未找到",
            code=404,
            error_type="NotFoundError",
        )

    @staticmethod
    def unauthorized_response(message: str = "未授权访问") -> Dict[str, Any]:
        """Create a 401 unauthorized response."""
        return ResponseFormatter.error_response(
            message=message,
            code=401,
            error_type="UnauthorizedError",
        )

    @staticmethod
    def forbidden_response(message: str = "禁止访问") -> Dict[str, Any]:
        """Create a 403 forbidden response."""
        return ResponseFormatter.error_response(
            message=message,
            code=403,
            error_type="ForbiddenError",
        )

    @staticmethod
    def rate_limit_response(retry_after: int = 60) -> Dict[str, Any]:
        """Create a 429 rate-limit response.

        Args:
            retry_after: Suggested wait before retrying, in seconds; stored
                under ``details["retry_after"]``.

        Returns:
            Dict: The formatted rate-limit response.
        """
        return ResponseFormatter.error_response(
            message="请求过于频繁,请稍后重试",
            code=429,
            error_type="RateLimitError",
            details={"retry_after": retry_after},
        )

    @staticmethod
    def server_error_response(message: str = "服务器内部错误") -> Dict[str, Any]:
        """Create a 500 internal-server-error response."""
        return ResponseFormatter.error_response(
            message=message,
            code=500,
            error_type="InternalServerError",
        )
class ErrorFormatter:
    """Turn exceptions and validation failures into serialisable dicts."""

    @staticmethod
    def format_exception(exception: Exception,
                         include_traceback: bool = False) -> Dict[str, Any]:
        """Describe an exception as a dict.

        Args:
            exception: The exception to describe.
            include_traceback: When true, add a ``"traceback"`` entry with the
                formatted traceback of *exception* itself.

        Returns:
            Dict: ``type`` (class name), ``message`` (``str(exception)``),
            ``timestamp`` and, optionally, ``traceback``.
        """
        error_info = {
            "type": exception.__class__.__name__,
            "message": str(exception),
            "timestamp": datetime.now().isoformat(),
        }
        if include_traceback:
            import traceback
            # Format the traceback attached to the given exception.
            # traceback.format_exc() would instead describe whatever exception
            # is currently being handled ("NoneType: None" outside an except
            # block), which is not necessarily the one passed in.
            error_info["traceback"] = "".join(
                traceback.format_exception(
                    type(exception), exception, exception.__traceback__
                )
            )
        return error_info

    @staticmethod
    def format_validation_errors(errors: List[str],
                                 field_errors: Optional[Dict[str, List[str]]] = None) -> Dict[str, Any]:
        """Bundle general and per-field validation errors.

        Args:
            errors: General (non-field) error messages.
            field_errors: Optional mapping of field name to its error messages;
                included only when non-empty.

        Returns:
            Dict: ``general_errors``, ``error_count`` (general plus per-field
            messages) and, when given, ``field_errors``.
        """
        formatted: Dict[str, Any] = {
            "general_errors": errors,
            "error_count": len(errors),
        }
        if field_errors:
            formatted["field_errors"] = field_errors
            formatted["error_count"] += sum(len(errs) for errs in field_errors.values())
        return formatted

    @staticmethod
    def format_database_error(error: Exception) -> Dict[str, Any]:
        """Map a database exception to a user-friendly message.

        The message is chosen by case-insensitive substring matching on
        ``str(error)``; the raw text is kept as ``technical_message``.

        Args:
            error: The database exception.

        Returns:
            Dict: ``type``, ``user_message``, ``technical_message`` and
            ``timestamp``.
        """
        error_message = str(error)
        lowered = error_message.lower()
        if "duplicate key" in lowered:
            user_message = "数据已存在,请检查唯一性约束"
        elif "foreign key" in lowered:
            user_message = "关联数据不存在,请检查数据完整性"
        elif "not null" in lowered:
            user_message = "必填字段不能为空"
        elif "timeout" in lowered:
            user_message = "数据库操作超时,请稍后重试"
        else:
            user_message = "数据库操作失败"
        return {
            "type": "DatabaseError",
            "user_message": user_message,
            "technical_message": error_message,
            "timestamp": datetime.now().isoformat(),
        }
class DataFormatter:
    """Normalise raw records (sessions, steps, workflow states, memories)
    into API-facing dicts with string timestamps."""

    @staticmethod
    def format_session_data(session_data: Dict[str, Any]) -> Dict[str, Any]:
        """Format a raw session record.

        Args:
            session_data: Raw session dict.

        Returns:
            Dict: Core session fields (missing ones become ``None``;
            ``current_step`` defaults to 1) plus any of ``title``,
            ``description``, ``session_config`` and ``metadata`` present in
            the input.
        """
        formatted = {
            "session_id": session_data.get("session_id"),
            "user_id": session_data.get("user_id"),
            "session_type": session_data.get("session_type"),
            "status": session_data.get("status"),
            "current_step": session_data.get("current_step", 1),
            "created_at": DataFormatter._format_datetime(session_data.get("created_at")),
            "updated_at": DataFormatter._format_datetime(session_data.get("updated_at")),
            "expires_at": DataFormatter._format_datetime(session_data.get("expires_at")),
        }
        # Optional fields are copied only when the key is present.
        for field in ("title", "description", "session_config", "metadata"):
            if field in session_data:
                formatted[field] = session_data[field]
        return formatted

    @staticmethod
    def format_step_data(step_data: Dict[str, Any]) -> Dict[str, Any]:
        """Format a raw workflow-step record.

        Args:
            step_data: Raw step dict.

        Returns:
            Dict: Core step fields (``execution_mode`` defaults to
            ``"sync"``), ``result_data``/``error_message`` when present, and
            ``execution_duration_seconds`` when both ``started_at`` and
            ``completed_at`` are ``datetime`` objects.
        """
        formatted = {
            "step_id": step_data.get("step_id"),
            "session_id": step_data.get("session_id"),
            "step_number": step_data.get("step_number"),
            "step_type": step_data.get("step_type"),
            "status": step_data.get("status"),
            "execution_mode": step_data.get("execution_mode", "sync"),
            "started_at": DataFormatter._format_datetime(step_data.get("started_at")),
            "completed_at": DataFormatter._format_datetime(step_data.get("completed_at")),
            "created_at": DataFormatter._format_datetime(step_data.get("created_at")),
        }
        if "result_data" in step_data:
            formatted["result_data"] = step_data["result_data"]
        if "error_message" in step_data:
            formatted["error_message"] = step_data["error_message"]
        # Duration is only computed from real datetime objects; string
        # timestamps are passed through above but not parsed here.
        if step_data.get("started_at") and step_data.get("completed_at"):
            start_time = step_data["started_at"]
            end_time = step_data["completed_at"]
            if isinstance(start_time, datetime) and isinstance(end_time, datetime):
                duration = (end_time - start_time).total_seconds()
                formatted["execution_duration_seconds"] = duration
        return formatted

    @staticmethod
    def format_workflow_state_data(state_data: Dict[str, Any]) -> Dict[str, Any]:
        """Format a raw workflow-state record.

        Args:
            state_data: Raw state dict.

        Returns:
            Dict: Core state fields, the nested ``state_data`` when present,
            and any of ``category``, ``related_step``, ``assigned_to``,
            ``processed_by``, ``processed_at`` present in the input
            (``processed_at`` is formatted as a timestamp).
        """
        formatted = {
            "state_id": state_data.get("state_id"),
            "session_id": state_data.get("session_id"),
            "state_type": state_data.get("state_type"),
            "status": state_data.get("status"),
            "priority": state_data.get("priority"),
            "created_at": DataFormatter._format_datetime(state_data.get("created_at")),
            "updated_at": DataFormatter._format_datetime(state_data.get("updated_at")),
            "expires_at": DataFormatter._format_datetime(state_data.get("expires_at")),
        }
        if "state_data" in state_data:
            formatted["state_data"] = state_data["state_data"]
        optional_fields = ("category", "related_step", "assigned_to",
                           "processed_by", "processed_at")
        for field in optional_fields:
            if field in state_data:
                if field == "processed_at":
                    formatted[field] = DataFormatter._format_datetime(state_data[field])
                else:
                    formatted[field] = state_data[field]
        return formatted

    @staticmethod
    def format_memory_data(memory_data: Dict[str, Any]) -> Dict[str, Any]:
        """Format a raw memory record.

        Args:
            memory_data: Raw memory dict.

        Returns:
            Dict: Core memory fields (``access_count`` defaults to 0), the
            nested ``memory_data`` masked according to the record's
            ``sensitivity`` (default ``"low"``), and any of ``description``,
            ``tags``, ``last_accessed``, ``expires_at`` present in the input.
        """
        formatted = {
            "memory_id": memory_data.get("memory_id"),
            "user_id": memory_data.get("user_id"),
            "memory_type": memory_data.get("memory_type"),
            "memory_key": memory_data.get("memory_key"),
            "category": memory_data.get("category"),
            "access_level": memory_data.get("access_level"),
            "importance": memory_data.get("importance"),
            "weight": memory_data.get("weight"),
            "access_count": memory_data.get("access_count", 0),
            "created_at": DataFormatter._format_datetime(memory_data.get("created_at")),
            "updated_at": DataFormatter._format_datetime(memory_data.get("updated_at")),
        }
        if "memory_data" in memory_data:
            formatted["memory_data"] = DataFormatter._sanitize_memory_data(
                memory_data["memory_data"],
                memory_data.get("sensitivity", "low"),
            )
        datetime_fields = ("last_accessed", "expires_at")
        for field in ("description", "tags", "last_accessed", "expires_at"):
            if field in memory_data:
                if field in datetime_fields:
                    formatted[field] = DataFormatter._format_datetime(memory_data[field])
                else:
                    formatted[field] = memory_data[field]
        return formatted

    @staticmethod
    def format_list_response(items: List[Dict[str, Any]],
                             formatter_func,
                             total: int = 0,
                             page: int = 0,
                             page_size: int = 0) -> Dict[str, Any]:
        """Apply *formatter_func* to every item and wrap the result.

        Args:
            items: Raw items.
            formatter_func: Callable applied to each item.
            total: Total item count; reported only when positive.
            page: Current page; pagination is reported only when both
                ``page`` and ``page_size`` are positive.
            page_size: Page size.

        Returns:
            Dict: ``items`` plus optional ``total`` and ``pagination`` keys.
        """
        result: Dict[str, Any] = {"items": [formatter_func(item) for item in items]}
        if total > 0:
            result["total"] = total
        # Pagination does not depend on total being known; with no total the
        # page count falls back to 1.
        if page > 0 and page_size > 0:
            result["pagination"] = {
                "current_page": page,
                "page_size": page_size,
                "total_pages": (total + page_size - 1) // page_size if total else 1,
            }
        return result

    @staticmethod
    def _format_datetime(dt: Union[datetime, str, None]) -> Optional[str]:
        """Return an ISO-8601 string for a datetime, strings unchanged,
        ``None`` for ``None``, and ``str(dt)`` for anything else."""
        if dt is None:
            return None
        if isinstance(dt, datetime):
            return dt.isoformat()
        if isinstance(dt, str):
            return dt
        return str(dt)

    @staticmethod
    def _sanitize_memory_data(memory_data: Optional[Dict[str, Any]],
                              sensitivity: str) -> Optional[Dict[str, Any]]:
        """Mask a memory payload according to its sensitivity level.

        For ``"high"``/``"confidential"`` data, values under a key named
        password/token/secret/private_key (case-insensitive) become ``"***"``.
        Other string values longer than 10 characters keep only their first
        and last 3 characters. Lower sensitivities get a shallow copy.

        Args:
            memory_data: The payload dict, or ``None``.
            sensitivity: Sensitivity label.

        Returns:
            Optional[Dict]: A new (masked or copied) dict, or ``None`` when
            the payload is ``None``.
        """
        # A stored payload may be missing; calling .copy()/.items() on None
        # would raise AttributeError.
        if memory_data is None:
            return None
        if sensitivity in ["high", "confidential"]:
            sensitive_keys = ("password", "token", "secret", "private_key")
            sanitized = {}
            for key, value in memory_data.items():
                # Case-insensitive so "Password"/"TOKEN" are masked as well.
                if isinstance(key, str) and key.lower() in sensitive_keys:
                    sanitized[key] = "***"
                elif isinstance(value, str) and len(value) > 10:
                    sanitized[key] = value[:3] + "***" + value[-3:]
                else:
                    sanitized[key] = value
            return sanitized
        return memory_data.copy()
class DateTimeEncoder(json.JSONEncoder):
    """JSON encoder that additionally handles dates, datetimes and Decimals.

    ``datetime`` and ``date`` values are emitted as ISO-8601 strings and
    ``Decimal`` values as ``float`` (which may lose precision). Any other
    unsupported object is handed to the base encoder, which raises
    ``TypeError``.
    """

    def default(self, obj):
        # datetime subclasses date, so one check covers both; isoformat()
        # dispatches to the concrete type's implementation.
        if isinstance(obj, date):
            return obj.isoformat()
        if isinstance(obj, Decimal):
            return float(obj)
        return super().default(obj)
def format_file_size(size_bytes: int) -> str:
    """Format a byte count as a human-readable size with binary (1024) units.

    The value is divided by 1024 until it drops below 1024 or the largest
    unit is reached, then rounded to two decimals. Negative sizes keep their
    sign; values at or beyond 1024 EB stay in EB rather than overflowing the
    unit table.

    Args:
        size_bytes: Number of bytes.

    Returns:
        str: e.g. ``"0 B"``, ``"500.0 B"``, ``"1.5 KB"``.
    """
    if size_bytes == 0:
        return "0 B"
    size_names = ["B", "KB", "MB", "GB", "TB", "PB", "EB"]
    sign = "-" if size_bytes < 0 else ""
    size = float(abs(size_bytes))
    index = 0
    # Repeated division by 1024 is exact in binary floating point, unlike
    # floor(log(x, 1024)), which can land one unit low at exact powers. It
    # also avoids the math-domain error for negatives and the negative index
    # for fractional byte counts.
    while size >= 1024 and index < len(size_names) - 1:
        size /= 1024
        index += 1
    return f"{sign}{round(size, 2)} {size_names[index]}"
def format_duration(seconds: float) -> str:
    """Format a duration in seconds as seconds, minutes or hours.

    Below 60 s the value is shown in seconds (秒), below 3600 s in minutes
    (分钟), otherwise in hours (小时), always with two decimals.

    Args:
        seconds: Duration in seconds.

    Returns:
        str: e.g. ``"5.00秒"``, ``"1.50分钟"``, ``"2.00小时"``.
    """
    if seconds < 60:
        # The unit suffix was missing here, unlike the other two branches.
        return f"{seconds:.2f}秒"
    if seconds < 3600:
        return f"{seconds / 60:.2f}分钟"
    return f"{seconds / 3600:.2f}小时"
def truncate_text(text: str, max_length: int = 100, suffix: str = "...") -> str:
    """Shorten *text* to at most *max_length* characters.

    Text that already fits is returned unchanged. Otherwise it is cut and
    *suffix* appended so the result is exactly *max_length* long. When
    *max_length* is too small to hold the suffix, the text is hard-cut
    without a suffix.

    Args:
        text: Original text.
        max_length: Maximum length of the result.
        suffix: Marker appended to truncated text.

    Returns:
        str: The possibly truncated text, never longer than
        ``max(max_length, 0)``.
    """
    if len(text) <= max_length:
        return text
    keep = max_length - len(suffix)
    if keep < 0:
        # text[:negative] would keep almost the whole string and exceed
        # max_length; a plain cut is the only way to honour the limit.
        return text[:max(max_length, 0)]
    return text[:keep] + suffix
def mask_sensitive_info(text: str, mask_char: str = "*", visible_chars: int = 3) -> str:
    """Mask the middle of *text*, keeping a few characters at each end.

    If the text is no longer than ``2 * visible_chars``, it is masked
    entirely. A negative *visible_chars* is treated as 0.

    Args:
        text: Original text.
        mask_char: Character used for masking.
        visible_chars: Number of characters left visible at each end.

    Returns:
        str: The masked text, the same length as *text* when *mask_char* is
        a single character.
    """
    visible = max(visible_chars, 0)
    length = len(text)
    if length <= visible * 2:
        return mask_char * length
    head = text[:visible]
    # Slice by absolute index: text[-0:] is the WHOLE string, which leaked
    # the plaintext when visible_chars was 0.
    tail = text[length - visible:]
    return head + mask_char * (length - visible * 2) + tail