-
Notifications
You must be signed in to change notification settings - Fork 0
/
s3_util.py
148 lines (125 loc) · 5.18 KB
/
s3_util.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
import logging
import os
import posixpath
import tarfile
from pathlib import Path
from typing import List, Tuple, Optional
from urllib.parse import urlparse

import boto3
logger = logging.getLogger(__name__)
COMPRESSED_TAR_EXTENSION = ".tar.gz"
# the file name without extension is used as an asset ID by the ASR container to save the results
def generate_asset_id_from_input_file(
    input_file: str, with_extension: bool = False
) -> str:
    """Derive an asset ID from a file path.

    Returns the path's base name, either unchanged (with_extension=True)
    or with its final extension stripped (the default).
    """
    logger.info(f"generating asset ID for {input_file}")
    base_name = os.path.basename(input_file)
    if with_extension:
        return base_name
    # drop the (last) extension to obtain the asset ID
    return os.path.splitext(base_name)[0]
def is_valid_tar_path(archive_path: str) -> bool:
    """Check that archive_path is usable as a .tar.gz output location.

    Valid means: the parent directory exists and the path ends with
    COMPRESSED_TAR_EXTENSION. Logs the reason and returns False otherwise.
    """
    logger.info(f"Validating {archive_path}")
    if not os.path.exists(Path(archive_path).parent):
        logger.error(f"Parent dir does not exist: {archive_path}")
        return False
    # endswith instead of the old hard-coded [-7:] slice, so the check
    # stays correct if COMPRESSED_TAR_EXTENSION ever changes length
    if not archive_path.endswith(COMPRESSED_TAR_EXTENSION):
        logger.error(
            f"Archive file should have the correct extension: {COMPRESSED_TAR_EXTENSION}"
        )
        return False
    return True
def tar_list_of_files(archive_path: str, file_list: List[str]) -> bool:
    """Create a gzipped tar at archive_path containing the given files.

    Each file is stored under its base name (directory structure is
    flattened). Returns True on success, False on any failure (logged).
    """
    logger.info(f"Tarring {len(file_list)} files into {archive_path}")
    if not is_valid_tar_path(archive_path):
        return False
    try:
        with tarfile.open(archive_path, "w:gz") as tar:
            for item in file_list:
                logger.info(os.path.basename(item))
                # flatten: archive members keep only their base name
                tar.add(item, arcname=os.path.basename(item))
        logger.info(f"Successfully created {archive_path}")
        return True
    except tarfile.TarError:
        logger.exception(f"Failed to create archive: {archive_path}")
    except FileNotFoundError:
        logger.exception("File in file list not found")
    except Exception:
        # logger.exception already records the traceback; the old extra
        # logger.error("Unknown error") double-logged and was removed
        logger.exception("Unhandled error")
    return False
def validate_s3_uri(s3_uri: str) -> bool:
    """Return True when s3_uri uses the s3:// scheme and has a non-empty path."""
    parsed = urlparse(s3_uri, allow_fragments=False)
    if parsed.scheme != "s3":
        logger.error(f"Invalid protocol in {s3_uri}")
        return False
    if not parsed.path:
        logger.error(f"No object_name specified in {s3_uri}")
        return False
    return True
# e.g. "s3://beng-daan-visxp/jaap-dane-test/dane-test.tar.gz"
def parse_s3_uri(s3_uri: str) -> Tuple[str, str]:
    """Split an s3:// URI into (bucket, object_name)."""
    logger.info(f"Parsing s3 URI {s3_uri}")
    parsed = urlparse(s3_uri, allow_fragments=False)
    # netloc is the bucket (beng-daan-visxp); the path minus its leading
    # slash is the object name (jaap-dane-test/dane-test.tar.gz)
    return parsed.netloc, parsed.path.lstrip("/")
def download_s3_uri(s3_uri: str, output_folder: str) -> bool:
    """Download the object referenced by s3_uri into output_folder.

    Returns False for an invalid URI or a failed download, True otherwise.
    """
    if not validate_s3_uri(s3_uri):
        logger.error("Invalid S3 URI")
        return False
    bucket, object_name = parse_s3_uri(s3_uri)
    return S3Store().download_file(bucket, object_name, output_folder)
class S3Store:
    """
    Thin wrapper around a boto3 S3 client for uploading and downloading files.

    requires environment:
    - "AWS_ACCESS_KEY_ID=your-key"
    - "AWS_SECRET_ACCESS_KEY=your-secret"
    TODO read from .aws/config, so boto3 can assume an IAM role
    see: https://boto3.amazonaws.com/v1/documentation/api/latest/guide/credentials.html
    """

    def __init__(self, s3_endpoint_url: Optional[str] = None, unit_testing=False):
        # NOTE(review): unit_testing is currently unused; kept so existing
        # callers passing it keep working — TODO wire up or remove
        self.client = boto3.client("s3", endpoint_url=s3_endpoint_url)

    def transfer_to_s3(
        self, bucket: str, path: str, file_list: List[str], tar_archive_path: str = ""
    ) -> bool:
        """Upload file_list into bucket under the given key prefix (path).

        When tar_archive_path is given, the files are first bundled into a
        single .tar.gz archive and only that archive is uploaded.
        Returns True when every upload succeeded, False on the first failure.
        """
        # first check if the file_list needs to be compressed (into tar)
        if tar_archive_path:
            if not tar_list_of_files(tar_archive_path, file_list):
                logger.error(
                    "Could not archive the file list before transferring to S3"
                )
                return False
            file_list = [tar_archive_path]  # now the file_list just has the tar
        # now go ahead and upload whatever is in the file list
        for f in file_list:
            try:
                # posixpath.join guarantees "/"-separated S3 keys on every
                # OS (os.path.join would emit "\" separators on Windows,
                # producing broken object keys)
                key = posixpath.join(
                    path,
                    # file name with extension
                    generate_asset_id_from_input_file(f, True),
                )
                self.client.upload_file(Filename=f, Bucket=bucket, Key=key)
            except Exception:  # TODO figure out which Exception to catch specifically
                logger.exception(f"Failed to upload {f}")
                return False
        return True

    def download_file(self, bucket: str, object_name: str, output_folder: str) -> bool:
        """Download bucket:object_name into output_folder (created if absent).

        The local file keeps the object's base name.
        Returns True on success, False on a failed download (logged).
        """
        logger.info(f"Downloading {bucket}:{object_name} into {output_folder}")
        if not os.path.exists(output_folder):
            logger.info("Output folder does not exist, creating it...")
            # exist_ok guards against a concurrent create between the
            # exists() check and makedirs()
            os.makedirs(output_folder, exist_ok=True)
        output_file = os.path.join(output_folder, os.path.basename(object_name))
        try:
            with open(output_file, "wb") as f:
                self.client.download_fileobj(bucket, object_name, f)
        except Exception:
            logger.exception(f"Failed to download {object_name}")
            return False
        return True