-
Notifications
You must be signed in to change notification settings - Fork 0
/
model_download.py
83 lines (71 loc) · 3.28 KB
/
model_download.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
import logging
import os
import tarfile
from urllib.parse import urlparse
import requests
from s3_util import S3Store, parse_s3_uri, validate_s3_uri
from base_util import get_asset_info, validate_http_uri
from config import s3_endpoint_url
logger = logging.getLogger(__name__)
def extract_model(destination: str, extension: str) -> str:
    """Unpack the archive ``{destination}.{extension}`` into ``destination``.

    e.g. {base_dir}/modelx.tar.gz will be extracted in {base_dir}/modelx

    The archive file is deleted once it has been unpacked. The extraction only
    counts as successful when ``destination`` ends up containing a
    ``model.bin`` file.

    Args:
        destination: Directory to extract into; also the archive path without
            its extension.
        extension: Archive extension without the leading dot (e.g. "tar.gz").

    Returns:
        ``destination`` on success, or an empty string when the archive could
        not be read or does not contain ``model.bin``.
    """
    import shutil  # function-scope import: only needed for failure cleanup

    tar_path = f"{destination}.{extension}"
    logger.info(f"Extracting {tar_path} into {destination}")

    # Remember whether the directory is ours, so that cleanup never deletes
    # data that was already there before this call.
    created_destination = not os.path.exists(destination)
    os.makedirs(destination, exist_ok=True)

    extracted = False
    try:
        with tarfile.open(tar_path) as tar:
            if hasattr(tarfile, "data_filter"):
                # The archive was downloaded, so refuse members that would
                # escape `destination` (absolute paths, "..", unsafe links).
                tar.extractall(path=destination, filter="data")
            else:
                # Older interpreters without extraction filters.
                tar.extractall(path=destination)
        # cleanup: delete the tar file
        os.remove(tar_path)
        if os.path.exists(os.path.join(destination, "model.bin")):
            extracted = True
            logger.info(
                f"model.bin found in {destination}. Model extracted successfully!"
            )
            return destination
        logger.error(f"{destination} does not contain a model.bin file. Exiting...")
        return ""
    except tarfile.TarError:
        logger.error("Could not extract the model")
        return ""
    finally:
        # A leftover directory would be mistaken for an already-downloaded
        # model by the callers' os.path.exists(destination) check.
        if created_destination and not extracted:
            shutil.rmtree(destination, ignore_errors=True)
def get_model_location(base_dir: str, whisper_model: str) -> str:
    """Resolve `whisper_model` to something that can be loaded.

    S3 and HTTP URIs are handed to their respective download helpers, which
    work inside `base_dir`; any other value is returned unchanged.

    Args:
        base_dir: Directory passed on to the download helpers.
        whisper_model: An S3 URI, an HTTP URI, or any other model reference.

    Returns:
        The result of the matching download helper, or `whisper_model` itself
        when it is neither an S3 nor an HTTP URI.
    """
    logger.info(f"Checking w_model: {whisper_model} and download if necessary")

    # Checked in order: S3 first, then HTTP.
    handlers = (
        (validate_s3_uri, check_s3_location),
        (validate_http_uri, check_http_location),
    )
    for matches, fetch in handlers:
        if matches(whisper_model):
            return fetch(base_dir, whisper_model)

    # The faster-whisper API can auto-detect if the version exists locally,
    # so no extra checks are needed here.
    logger.info(f"{whisper_model} is not an S3/HTTP URI. Using HuggingFace instead")
    return whisper_model
def check_s3_location(base_dir: str, whisper_model: str) -> str:
    """Download and extract a model referenced by an S3 URI.

    Nothing is downloaded when the target directory under `base_dir` already
    exists.

    Args:
        base_dir: Directory the object is downloaded into.
        whisper_model: S3 URI of the model archive.

    Returns:
        The model directory when it already exists, the result of
        `extract_model` after a successful download, or an empty string when
        the download fails.
    """
    logger.info(f"{whisper_model} is an S3 URI. Attempting to download")

    bucket_name, object_key = parse_s3_uri(whisper_model)
    model_name, archive_ext = get_asset_info(object_key)
    model_dir = os.path.join(base_dir, model_name)

    # An existing directory is treated as an already-downloaded model.
    if os.path.exists(model_dir):
        logger.info("Model already exists")
        return model_dir

    if S3Store(s3_endpoint_url).download_file(bucket_name, object_key, base_dir):
        return extract_model(model_dir, archive_ext)

    logger.error(f"Could not download {whisper_model} into {base_dir}")
    return ""
def check_http_location(base_dir: str, whisper_model: str) -> str:
    """Download and extract a model referenced by an HTTP URI.

    Nothing is downloaded when the target directory under `base_dir` already
    exists. Otherwise the archive is streamed to
    ``{base_dir}/{asset_id}.{extension}`` and handed to `extract_model`.

    Args:
        base_dir: Directory the archive is downloaded into.
        whisper_model: HTTP URI of the model archive.

    Returns:
        The model directory when it already exists, the result of
        `extract_model` after a successful download, or an empty string when
        the download fails (HTTP status >= 400 or a connection error).
    """
    logger.info(f"{whisper_model} is an HTTP URI. Attempting to download")
    asset_id, extension = get_asset_info(urlparse(whisper_model).path)
    destination = os.path.join(base_dir, asset_id)
    if os.path.exists(destination):
        logger.info("Model already exists")
        return destination

    archive_path = f"{destination}.{extension}"
    try:
        # stream=True keeps the (potentially very large) archive out of
        # memory. The timeout is (connect, read); the read timeout applies to
        # each wait for data, not to the whole download.
        with requests.get(whisper_model, stream=True, timeout=(10, 60)) as response:
            # Check the status before creating the file, so a failed request
            # does not leave an empty archive behind.
            if response.status_code >= 400:
                logger.error(f"Could not download {whisper_model} into {base_dir}")
                return ""
            with open(archive_path, "wb") as file:
                for chunk in response.iter_content(chunk_size=1024 * 1024):
                    file.write(chunk)
    except requests.RequestException:
        logger.exception(f"Could not download {whisper_model} into {base_dir}")
        # Drop a partially written archive from an interrupted transfer.
        if os.path.exists(archive_path):
            os.remove(archive_path)
        return ""

    return extract_model(destination, extension)