-
Notifications
You must be signed in to change notification settings - Fork 135
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add streaming fileobj support to CRTTransferManager #277
base: develop
Are you sure you want to change the base?
Changes from all commits
f62f38e
4ecc662
7353943
c2f14b0
fa6e6d8
5cf54e2
755605f
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -25,15 +25,24 @@ | |
EventLoopGroup, | ||
TlsContextOptions, | ||
) | ||
from awscrt.s3 import S3Client, S3RequestTlsMode, S3RequestType | ||
from awscrt.s3 import ( | ||
S3ChecksumAlgorithm, | ||
S3ChecksumConfig, | ||
S3ChecksumLocation, | ||
S3Client, | ||
S3RequestTlsMode, | ||
S3RequestType, | ||
) | ||
from botocore import UNSIGNED | ||
from botocore.compat import urlsplit | ||
from botocore.config import Config | ||
from botocore.exceptions import NoCredentialsError | ||
|
||
from s3transfer.compat import seekable | ||
from s3transfer.constants import GB, MB | ||
from s3transfer.exceptions import TransferNotDoneError | ||
from s3transfer.futures import BaseTransferFuture, BaseTransferMeta | ||
from s3transfer.manager import TransferManager | ||
from s3transfer.utils import CallArgs, OSUtils, get_callbacks | ||
|
||
logger = logging.getLogger(__name__) | ||
|
@@ -67,7 +76,7 @@ def create_s3_crt_client( | |
region, | ||
botocore_credential_provider=None, | ||
num_threads=None, | ||
target_throughput=5 * GB / 8, | ||
target_throughput=5_000_000_000.0 / 8, | ||
part_size=8 * MB, | ||
use_ssl=True, | ||
verify=None, | ||
|
@@ -86,8 +95,8 @@ def create_s3_crt_client( | |
is the number of processors in the machine. | ||
|
||
:type target_throughput: Optional[int] | ||
:param target_throughput: Throughput target in Bytes. | ||
Default is 0.625 GB/s (which translates to 5 Gb/s). | ||
:param target_throughput: Throughput target in bytes per second. | ||
Default translates to 5.0 Gb/s or 0.582 GiB/s. | ||
|
||
:type part_size: Optional[int] | ||
:param part_size: Size, in Bytes, of parts that files will be downloaded | ||
|
@@ -137,19 +146,24 @@ def create_s3_crt_client( | |
credentails_provider_adapter | ||
) | ||
|
||
target_gbps = target_throughput * 8 / GB | ||
target_gigabits = target_throughput * 8 / 1_000_000_000.0 | ||
return S3Client( | ||
bootstrap=bootstrap, | ||
region=region, | ||
credential_provider=provider, | ||
part_size=part_size, | ||
tls_mode=tls_mode, | ||
tls_connection_options=tls_connection_options, | ||
throughput_target_gbps=target_gbps, | ||
throughput_target_gbps=target_gigabits, | ||
) | ||
|
||
|
||
class CRTTransferManager: | ||
|
||
ALLOWED_DOWNLOAD_ARGS = TransferManager.ALLOWED_DOWNLOAD_ARGS | ||
ALLOWED_UPLOAD_ARGS = TransferManager.ALLOWED_UPLOAD_ARGS | ||
ALLOWED_DELETE_ARGS = TransferManager.ALLOWED_DELETE_ARGS | ||
|
||
def __init__(self, crt_s3_client, crt_request_serializer, osutil=None): | ||
"""A transfer manager interface for Amazon S3 on CRT s3 client. | ||
|
||
|
@@ -192,6 +206,8 @@ def download( | |
extra_args = {} | ||
if subscribers is None: | ||
subscribers = {} | ||
self._validate_all_known_args(extra_args, TransferManager.ALLOWED_DOWNLOAD_ARGS) | ||
# TODO: _validate_if_bucket_supported() ??? | ||
callargs = CallArgs( | ||
bucket=bucket, | ||
key=key, | ||
|
@@ -206,6 +222,7 @@ def upload(self, fileobj, bucket, key, extra_args=None, subscribers=None): | |
extra_args = {} | ||
if subscribers is None: | ||
subscribers = {} | ||
self._validate_all_known_args(extra_args, TransferManager.ALLOWED_UPLOAD_ARGS) | ||
callargs = CallArgs( | ||
bucket=bucket, | ||
key=key, | ||
|
@@ -220,6 +237,7 @@ def delete(self, bucket, key, extra_args=None, subscribers=None): | |
extra_args = {} | ||
if subscribers is None: | ||
subscribers = {} | ||
self._validate_all_known_args(extra_args, TransferManager.ALLOWED_DELETE_ARGS) | ||
callargs = CallArgs( | ||
bucket=bucket, | ||
key=key, | ||
|
@@ -260,6 +278,14 @@ def _shutdown(self, cancel=False): | |
def _release_semaphore(self, **kwargs): | ||
self._semaphore.release() | ||
|
||
def _validate_all_known_args(self, actual, allowed): | ||
for kwarg in actual: | ||
if kwarg not in allowed: | ||
raise ValueError( | ||
"Invalid extra_args key '%s', " | ||
"must be one of: %s" % (kwarg, ', '.join(allowed)) | ||
) | ||
|
||
def _submit_transfer(self, request_type, call_args): | ||
on_done_after_calls = [self._release_semaphore] | ||
coordinator = CRTTransferCoordinator(transfer_id=self._id_counter) | ||
|
@@ -359,7 +385,7 @@ def set_exception(self, exception): | |
|
||
|
||
class BaseCRTRequestSerializer: | ||
def serialize_http_request(self, transfer_type, future): | ||
def serialize_http_request(self, transfer_type, future, fileobj): | ||
"""Serialize CRT HTTP requests. | ||
|
||
:type transfer_type: string | ||
|
@@ -428,19 +454,12 @@ def _crt_request_from_aws_request(self, aws_request): | |
headers_list.append((name, str(value, 'utf-8'))) | ||
|
||
crt_headers = awscrt.http.HttpHeaders(headers_list) | ||
# CRT requires body (if it exists) to be an I/O stream. | ||
crt_body_stream = None | ||
if aws_request.body: | ||
if hasattr(aws_request.body, 'seek'): | ||
crt_body_stream = aws_request.body | ||
else: | ||
crt_body_stream = BytesIO(aws_request.body) | ||
|
||
crt_request = awscrt.http.HttpRequest( | ||
method=aws_request.method, | ||
path=crt_path, | ||
headers=crt_headers, | ||
body_stream=crt_body_stream, | ||
body_stream=aws_request.body, | ||
) | ||
return crt_request | ||
|
||
|
@@ -451,8 +470,24 @@ def _convert_to_crt_http_request(self, botocore_http_request): | |
# If host is not set, set it for the request before using CRT s3 | ||
url_parts = urlsplit(botocore_http_request.url) | ||
crt_request.headers.set("host", url_parts.netloc) | ||
|
||
# Remove bogus Content-MD5 value (see comment elsewhere in file) | ||
if crt_request.headers.get('Content-MD5') is not None: | ||
crt_request.headers.remove("Content-MD5") | ||
|
||
# Explicitly set "Content-Length: 0" when there's no body. | ||
# Botocore doesn't bother setting this, but CRT likes to know. | ||
# Note that Content-Length SHOULD be absent if body is nonseekable. | ||
if crt_request.headers.get('Content-Length') is None: | ||
if botocore_http_request.body is None: | ||
crt_request.headers.add('Content-Length', "0") | ||
|
||
# Remove "Transfer-Encoding: chunked". | ||
# Botocore sets this on nonseekable streams, | ||
# but CRT currently chokes on this header (TODO: fix this in CRT) | ||
if crt_request.headers.get('Transfer-Encoding') is not None: | ||
crt_request.headers.remove('Transfer-Encoding') | ||
|
||
return crt_request | ||
|
||
def _capture_http_request(self, request, **kwargs): | ||
|
@@ -556,22 +591,57 @@ def get_make_request_args( | |
self, request_type, call_args, coordinator, future, on_done_after_calls | ||
): | ||
recv_filepath = None | ||
on_body = None | ||
send_filepath = None | ||
s3_meta_request_type = getattr( | ||
S3RequestType, request_type.upper(), S3RequestType.DEFAULT | ||
) | ||
on_done_before_calls = [] | ||
checksum_config = S3ChecksumConfig() | ||
|
||
if s3_meta_request_type == S3RequestType.GET_OBJECT: | ||
final_filepath = call_args.fileobj | ||
recv_filepath = self._os_utils.get_temp_filename(final_filepath) | ||
file_ondone_call = RenameTempFileHandler( | ||
coordinator, final_filepath, recv_filepath, self._os_utils | ||
) | ||
on_done_before_calls.append(file_ondone_call) | ||
if isinstance(call_args.fileobj, str): | ||
# fileobj is a filepath | ||
final_filepath = call_args.fileobj | ||
recv_filepath = self._os_utils.get_temp_filename(final_filepath) | ||
file_ondone_call = RenameTempFileHandler( | ||
coordinator, final_filepath, recv_filepath, self._os_utils | ||
) | ||
on_done_before_calls.append(file_ondone_call) | ||
|
||
elif call_args.fileobj is not None: | ||
# fileobj is a file-like object | ||
response_handler = _FileobjResponseHandler(call_args.fileobj) | ||
on_body = response_handler.on_body | ||
|
||
# Only validate response checksums when downloading. | ||
# (upload responses also have checksum headers, but they're just an | ||
# echo of what was in the request, an upload response's body is empty) | ||
checksum_config.validate_response = True | ||
|
||
elif s3_meta_request_type == S3RequestType.PUT_OBJECT: | ||
send_filepath = call_args.fileobj | ||
data_len = self._os_utils.get_file_size(send_filepath) | ||
call_args.extra_args["ContentLength"] = data_len | ||
if isinstance(call_args.fileobj, str): | ||
# fileobj is a filepath | ||
send_filepath = call_args.fileobj | ||
data_len = self._os_utils.get_file_size(send_filepath) | ||
call_args.extra_args["ContentLength"] = data_len | ||
|
||
elif call_args.fileobj is not None: | ||
# fileobj is a file-like object | ||
call_args.extra_args["Body"] = call_args.fileobj | ||
|
||
# We want the CRT S3Client to calculate checksums, not botocore. | ||
# Default to CRC32. | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This is different than the default TransferManager. With this change CRT is defaulting to CRC32 instead of Content-MD5. The CRT team got grumpy about how Content-MD5 currently works in our code and pushed me to do it this way, but we can still make Content-MD5 work if we need to. |
||
if call_args.extra_args.get('ChecksumAlgorithm') is not None: | ||
algorithm_name = call_args.extra_args.pop('ChecksumAlgorithm') | ||
checksum_config.algorithm = S3ChecksumAlgorithm[algorithm_name] | ||
else: | ||
checksum_config.algorithm = S3ChecksumAlgorithm.CRC32 | ||
checksum_config.location = S3ChecksumLocation.TRAILER | ||
|
||
# Suppress botocore's MD5 calculation by setting a bogus value. | ||
# (this header gets removed before the request is passed to CRT) | ||
call_args.extra_args["ContentMD5"] = "bogus value deleted later" | ||
|
||
crt_request = self._request_serializer.serialize_http_request( | ||
request_type, future | ||
|
@@ -582,6 +652,8 @@ def get_make_request_args( | |
'type': s3_meta_request_type, | ||
'recv_filepath': recv_filepath, | ||
'send_filepath': send_filepath, | ||
'checksum_config': checksum_config, | ||
'on_body': on_body, | ||
'on_done': self.get_crt_callback( | ||
future, 'done', on_done_before_calls, on_done_after_calls | ||
), | ||
|
@@ -642,3 +714,11 @@ def __init__(self, coordinator): | |
|
||
def __call__(self, **kwargs): | ||
self._coordinator.set_done_callbacks_complete() | ||
|
||
|
||
class _FileobjResponseHandler: | ||
def __init__(self, fileobj): | ||
self._fileobj = fileobj | ||
|
||
def on_body(self, chunk: bytes, offset: int, **kwargs): | ||
self._fileobj.write(chunk) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Should I make a common base class for
TransferManager
andCRTTransferManager
, where I can put stuff like these lists, anddef _validate_all_known_args()
anddef _validate_if_bucket_supported()
?Or move these to be standalone lists/functions that both classes can use?
Or just copy/paste the functions, and reference the lists, like I'm doing here?