563 lines
19 KiB
Python
563 lines
19 KiB
Python
from typing import Any, Optional
|
||
import mimetypes
|
||
from io import StringIO
|
||
import os
|
||
import tos
|
||
import urllib3
|
||
from urllib3.exceptions import InsecureRequestWarning
|
||
from config import API_CONFIG
|
||
# 火山对象存储
|
||
class TOSClient:
    """Client wrapper for Volcano Engine object storage (TOS), bound to one bucket.

    Every SDK call goes through ``self.client`` (a ``tos.TosClientV2``), and
    unsigned URLs are built from the custom domain ``self_domain``.
    """

    def __init__(
        self,
        access_key_id: str,
        access_key_secret: str,
        endpoint: str,
        region: str,
        bucket_name: str,
        self_domain: str,
        disable_ssl_warnings: bool = True
    ):
        """
        Initialise the storage client.

        Args:
            access_key_id: Access key (AK).
            access_key_secret: Secret key (SK).
            endpoint: Storage endpoint. It is only stored on the instance; the
                SDK client itself is pointed at ``self_domain``.
            region: Region passed to the SDK client.
            bucket_name: Bucket used by every operation of this instance.
            self_domain: Custom domain, used both as the SDK endpoint
                (``is_custom_domain=True``) and as the host of returned URLs.
            disable_ssl_warnings: When true, silence urllib3's
                ``InsecureRequestWarning``.
        """
        if disable_ssl_warnings:
            urllib3.disable_warnings(InsecureRequestWarning)

        self.bucket_name = bucket_name
        self.self_domain = self_domain
        self.endpoint = endpoint
        self.client = tos.TosClientV2(
            ak=access_key_id,
            sk=access_key_secret,
            endpoint=self_domain,
            region=region,
            is_custom_domain=True,
            connection_time=30,
            socket_timeout=60,
            max_retry_count=3,
        )

    def get_base_url(self, object_key: str) -> str:
        """Return the unsigned ``https://<self_domain>/<object_key>`` URL."""
        return f"https://{self.self_domain}/{object_key}"

    def generate_url(self, object_key: str, expires: int = 3600) -> str:
        """
        Return a pre-signed GET URL for ``object_key``.

        Args:
            object_key: Object key inside the bucket.
            expires: Forwarded to the SDK as ``expires`` (presumably a
                lifetime in seconds - confirm against the SDK documentation).
        """
        signed = self.client.pre_signed_url(
            tos.HttpMethodType.Http_Method_Get,
            bucket=self.bucket_name,
            key=object_key,
            expires=expires,
        )
        return signed.signed_url

    def _resolve_upload_result(self, result, object_key: str, return_url: bool) -> str:
        """Check an upload response and pick the value handed back to the caller."""
        if result.status_code != 200:
            raise Exception(f"上传失败,HTTP状态码: {result.status_code}")
        return self.get_base_url(object_key) if return_url else object_key

    def upload_string(
        self,
        content_str: str,
        object_key: str,
        headers: Optional[dict] = None,
        return_url: bool = True,
    ) -> str:
        """
        Upload a text string as a ``text/plain`` object.

        Args:
            content_str: Text to upload.
            object_key: Destination object key.
            headers: Accepted for interface compatibility; currently unused.
            return_url: Return the unsigned URL when true, else the object key.

        Returns:
            str: The unsigned URL of the object, or ``object_key``.

        Raises:
            Exception: If the SDK call fails or the status code is not 200.
        """
        try:
            # NOTE(review): the text is handed to the SDK as a StringIO stream;
            # confirm the SDK handles text-mode streams with non-ASCII content.
            result = self.client.put_object(
                bucket=self.bucket_name,
                key=object_key,
                content_type='text/plain',
                content=StringIO(content_str),
            )
            print('upload_string http status code:{}'.format(result.status_code))
            return self._resolve_upload_result(result, object_key, return_url)
        except Exception as e:
            raise Exception(f"上传文件到OSS失败: {str(e)}") from e

    def upload_file(
        self,
        local_file_path: str,
        object_key: Optional[str] = None,
        headers: Optional[dict] = None,
        return_url: bool = True,
        expires: int = 3600
    ) -> str:
        """
        Upload a local file.

        Args:
            local_file_path: Path of the file to upload.
            object_key: Destination object key; defaults to the file's basename.
            headers: Accepted for interface compatibility; currently unused.
            return_url: Return the unsigned URL when true, else the object key.
            expires: Accepted for interface compatibility; currently unused.

        Returns:
            str: The unsigned URL of the object, or the object key.

        Raises:
            FileNotFoundError: If ``local_file_path`` does not exist.
            Exception: If the SDK call fails or the status code is not 200.
        """
        if not os.path.exists(local_file_path):
            raise FileNotFoundError(f"本地文件不存在: {local_file_path}")

        if object_key is None:
            object_key = os.path.basename(local_file_path)

        # Guess the Content-Type from the file name; an empty string is sent
        # when the type cannot be guessed.
        content_type, _ = mimetypes.guess_type(local_file_path)

        try:
            result = self.client.put_object_from_file(
                bucket=self.bucket_name,
                key=object_key,
                content_type=content_type or '',
                file_path=local_file_path,
            )
            return self._resolve_upload_result(result, object_key, return_url)
        except Exception as e:
            raise Exception(f"上传文件到OSS失败: {str(e)}") from e

    def upload_bytes(
        self,
        data: bytes,
        object_key: str,
        content_type: Optional[str] = None,
        headers: Optional[dict] = None,
        return_url: bool = True,
        expires: int = 3600
    ) -> str:
        """
        Upload in-memory bytes.

        Args:
            data: Bytes to upload.
            object_key: Destination object key.
            content_type: Content type (e.g. ``image/jpeg``); defaults to
                ``application/octet-stream``.
            headers: Accepted for interface compatibility; currently unused.
            return_url: Return the unsigned URL when true, else the object key.
            expires: Accepted for interface compatibility; currently unused.

        Returns:
            str: The unsigned URL of the object, or ``object_key``.

        Raises:
            Exception: If the SDK call fails or the status code is not 200.
        """
        try:
            result = self.client.put_object(
                bucket=self.bucket_name,
                key=object_key,
                content_type=content_type or 'application/octet-stream',
                content=data,
            )
            return self._resolve_upload_result(result, object_key, return_url)
        except Exception as e:
            raise Exception(f"上传字节数据到OSS失败: {str(e)}") from e

    def upload_from_url(
        self,
        url: str,
        object_key: str,
        headers: Optional[dict] = None,
        timeout: int = 30,
        return_url: bool = True,
        expires: int = 3600
    ) -> str:
        """
        Download a file over HTTP(S) and upload its body to the bucket.

        Args:
            url: Source URL; must start with ``http://`` or ``https://``.
            object_key: Destination object key.
            headers: Forwarded to ``upload_bytes`` (not used for the download).
            timeout: Download timeout passed to ``requests.get``.
            return_url: Return the unsigned URL when true, else the object key.
            expires: Forwarded to ``upload_bytes``.

        Returns:
            str: The unsigned URL of the object, or ``object_key``.

        Raises:
            ValueError: If ``url`` does not use an http/https scheme.
            Exception: If the download or the upload fails.
        """
        # Validate before importing requests so bad input fails fast.
        if not url.startswith(('http://', 'https://')):
            raise ValueError("URL必须以http://或https://开头")

        import requests

        try:
            # The context manager releases the streamed connection even when
            # raise_for_status() or the body read fails.
            with requests.get(url, stream=True, timeout=timeout) as response:
                response.raise_for_status()
                content_type = response.headers.get('Content-Type', '')
                payload = response.content

            if not content_type:
                content_type = mimetypes.guess_type(url)[0] or 'application/octet-stream'

            return self.upload_bytes(
                data=payload,
                object_key=object_key,
                content_type=content_type,
                headers=headers,
                return_url=return_url,
                expires=expires,
            )
        except requests.exceptions.RequestException as e:
            raise Exception(f"下载网络文件失败: {str(e)}") from e
        except Exception as e:
            raise Exception(f"上传网络文件到OSS失败: {str(e)}") from e

    def _format_object_key(self, object_key: str) -> str:
        """
        Normalise an object key that may have been given as a full URL.

        When ``object_key`` contains ``self_domain``, everything up to and
        including its first occurrence is dropped and leading slashes are
        stripped; otherwise the key is returned unchanged.
        """
        if self.self_domain and self.self_domain in object_key:
            _, _, tail = object_key.partition(self.self_domain)
            return tail.lstrip('/')
        return object_key

    def delete_file(self, object_key: str) -> bool:
        """
        Delete an object from the bucket.

        Args:
            object_key: Object key, or a URL containing ``self_domain``.

        Returns:
            bool: True on success; False if the SDK call raised (the error is
            printed rather than propagated).
        """
        try:
            self.client.delete_object(
                bucket=self.bucket_name,
                key=self._format_object_key(object_key),
            )
            return True
        except Exception as e:
            print(f"删除文件失败: {str(e)}")
            return False

    def download_file(self, object_key: str) -> bytes:
        """
        Download an object and return its content.

        Args:
            object_key: Object key, or a URL containing ``self_domain``.

        Returns:
            bytes: The object's content.

        Raises:
            Exception: If the SDK call fails or the object is empty.
        """
        try:
            object_key = self._format_object_key(object_key)

            object_stream = self.client.get_object(
                bucket=self.bucket_name,
                key=object_key,
            )
            content = object_stream.read() or b''
            # An empty object is treated as a failure.
            if not content:
                raise Exception(f"文件内容为空: {object_key}")
            return content
        except tos.exceptions.TosClientError as e:
            # Client-side failure.
            print('TOS下载 fail with client error, message:{}, cause: {}'.format(e.message, e.cause))
            raise Exception(f"下载异常: {object_key} {e.message}") from e
        except tos.exceptions.TosServerError as e:
            # Server-side failure; the request id is printed for later lookup.
            print('TOS下载 fail with server error, code: {}'.format(e.code))
            print('TOS下载 error with request id: {}'.format(e.request_id))
            print('TOS下载 error with message: {}'.format(e.message))
            print('TOS下载 error with http code: {}'.format(e.status_code))
            print('TOS下载 error with ec: {}'.format(e.ec))
            print('TOS下载 error with request url: {}'.format(e.request_url))
            raise Exception(f"下载异常: {object_key} {e.message}") from e
        except Exception as e:
            raise Exception(f"下载文件失败: {str(e)}") from e
|
||
|
||
|
||
class TOSChunkUploader:
|
||
"""TOS分片上传类"""
|
||
|
||
def __init__(self, tos_client: TOSClient):
|
||
"""
|
||
初始化分片上传器
|
||
|
||
Args:
|
||
tos_client: TOS客户端实例
|
||
"""
|
||
self.client = tos_client.client
|
||
self.bucket_name = tos_client.bucket_name
|
||
self.self_domain = tos_client.self_domain
|
||
|
||
def init_multipart_upload(self, object_key: str, content_type: Optional[str] = None) -> str | None:
|
||
"""
|
||
初始化分片上传
|
||
|
||
Args:
|
||
object_key: 对象键
|
||
content_type: 内容类型
|
||
|
||
Returns:
|
||
str: 上传ID
|
||
|
||
Raises:
|
||
Exception: 如果初始化失败
|
||
"""
|
||
try:
|
||
# 设置默认内容类型
|
||
if not content_type:
|
||
content_type = mimetypes.guess_type(object_key)[0] or 'application/octet-stream'
|
||
|
||
# 初始化分片上传
|
||
result = self.client.create_multipart_upload(
|
||
bucket=self.bucket_name,
|
||
key=object_key,
|
||
content_type=content_type
|
||
)
|
||
|
||
return result.upload_id
|
||
|
||
except tos.exceptions.TosClientError as e:
|
||
raise Exception(f"初始化分片上传失败(客户端错误): {e.message}")
|
||
except tos.exceptions.TosServerError as e:
|
||
raise Exception(f"初始化分片上传失败(服务端错误): {e.message}")
|
||
except Exception as e:
|
||
raise Exception(f"初始化分片上传失败: {str(e)}")
|
||
|
||
def upload_part(self, object_key: str, upload_id: str, part_number: int, data: bytes) -> dict:
|
||
"""
|
||
上传分片
|
||
|
||
Args:
|
||
object_key: 对象键
|
||
upload_id: 上传ID
|
||
part_number: 分片号(从1开始)
|
||
data: 分片数据
|
||
|
||
Returns:
|
||
dict: 包含完整分片信息的字典
|
||
|
||
Raises:
|
||
Exception: 如果上传失败
|
||
"""
|
||
try:
|
||
from io import BytesIO
|
||
import hashlib
|
||
|
||
# 计算分片大小
|
||
part_size = len(data)
|
||
|
||
# 计算CRC64(如果需要的话,这里先设为None)
|
||
hash_crc64_ecma = None
|
||
|
||
# 上传分片
|
||
result = self.client.upload_part(
|
||
bucket=self.bucket_name,
|
||
key=object_key,
|
||
upload_id=upload_id,
|
||
part_number=part_number,
|
||
content=BytesIO(data)
|
||
)
|
||
|
||
return {
|
||
'part_number': part_number,
|
||
'etag': result.etag,
|
||
'part_size': part_size,
|
||
'hash_crc64_ecma': hash_crc64_ecma,
|
||
'is_completed': True
|
||
}
|
||
|
||
except tos.exceptions.TosClientError as e:
|
||
raise Exception(f"上传分片失败(客户端错误): {e.message}")
|
||
except tos.exceptions.TosServerError as e:
|
||
raise Exception(f"上传分片失败(服务端错误): {e.message}")
|
||
except Exception as e:
|
||
raise Exception(f"上传分片失败: {str(e)}")
|
||
|
||
def complete_multipart_upload(self, object_key: str, upload_id: str, parts: list) -> str:
|
||
"""
|
||
完成分片上传
|
||
|
||
Args:
|
||
object_key: 对象键
|
||
upload_id: 上传ID
|
||
parts: 分片信息列表,每个元素包含part_number和etag
|
||
|
||
Returns:
|
||
str: 文件的完整URL
|
||
|
||
Raises:
|
||
Exception: 如果完成上传失败
|
||
"""
|
||
try:
|
||
# 按分片号排序
|
||
sorted_parts = sorted(parts, key=lambda x: x['part_number'])
|
||
|
||
# 构建分片列表并计算偏移量
|
||
part_list = []
|
||
current_offset = 0
|
||
|
||
for part in sorted_parts:
|
||
part_list.append(tos.models2.PartInfo(
|
||
part_number=part['part_number'],
|
||
etag=part['etag'],
|
||
part_size=part.get('part_size'),
|
||
offset=current_offset,
|
||
hash_crc64_ecma=part.get('hash_crc64_ecma'),
|
||
is_completed=part.get('is_completed', True)
|
||
))
|
||
|
||
# 更新偏移量
|
||
if part.get('part_size'):
|
||
current_offset += part['part_size']
|
||
|
||
# 完成分片上传
|
||
result = self.client.complete_multipart_upload(
|
||
bucket=self.bucket_name,
|
||
key=object_key,
|
||
upload_id=upload_id,
|
||
parts=part_list
|
||
)
|
||
|
||
# 返回完整URL
|
||
return f"https://{self.self_domain}/{object_key}"
|
||
|
||
except tos.exceptions.TosClientError as e:
|
||
raise Exception(f"完成分片上传失败(客户端错误): {e.message}")
|
||
except tos.exceptions.TosServerError as e:
|
||
raise Exception(f"完成分片上传失败(服务端错误): {e.message}")
|
||
except Exception as e:
|
||
raise Exception(f"完成分片上传失败: {str(e)}")
|
||
|
||
def abort_multipart_upload(self, object_key: str, upload_id: str) -> bool:
|
||
"""
|
||
取消分片上传
|
||
|
||
Args:
|
||
object_key: 对象键
|
||
upload_id: 上传ID
|
||
|
||
Returns:
|
||
bool: 是否取消成功
|
||
"""
|
||
try:
|
||
self.client.abort_multipart_upload(
|
||
bucket=self.bucket_name,
|
||
key=object_key,
|
||
upload_id=upload_id
|
||
)
|
||
return True
|
||
|
||
except tos.exceptions.TosClientError as e:
|
||
print(f"取消分片上传失败(客户端错误): {e.message}")
|
||
return False
|
||
except tos.exceptions.TosServerError as e:
|
||
print(f"取消分片上传失败(服务端错误): {e.message}")
|
||
return False
|
||
except Exception as e:
|
||
print(f"取消分片上传失败: {str(e)}")
|
||
return False
|
||
|
||
def list_parts(self, object_key: str, upload_id: str) -> list:
|
||
"""
|
||
列出已上传的分片
|
||
|
||
Args:
|
||
object_key: 对象键
|
||
upload_id: 上传ID
|
||
|
||
Returns:
|
||
list: 已上传的分片列表
|
||
"""
|
||
try:
|
||
result = self.client.list_parts(
|
||
bucket=self.bucket_name,
|
||
key=object_key,
|
||
upload_id=upload_id
|
||
)
|
||
|
||
parts = []
|
||
for part in result.parts:
|
||
parts.append({
|
||
'part_number': part.part_number,
|
||
'etag': part.etag,
|
||
'size': part.size,
|
||
'last_modified': part.last_modified
|
||
})
|
||
|
||
return parts
|
||
|
||
except Exception as e:
|
||
print(f"列出分片失败: {str(e)}")
|
||
return []
|
||
|
||
|
||
# Build the shared storage client from the project configuration.
from config import TOS_CONFIG

# Each constructor argument is read from the TOS_CONFIG entry of the same name.
oss_client = TOSClient(**{
    option: TOS_CONFIG[option]
    for option in (
        'access_key_id',
        'access_key_secret',
        'endpoint',
        'region',
        'bucket_name',
        'self_domain',
        'disable_ssl_warnings',
    )
})

# Build the shared multipart uploader on top of that client.
chunk_uploader = TOSChunkUploader(oss_client)
|