Skip to content

Commit

Permalink
fix: uploading large files saving to disk instead of memory (#4935)
Browse files Browse the repository at this point in the history
* fix: uploading large files saving to disk instead of memory

Signed-off-by: Frost Ming <me@frostming.com>

* fix: context managers

Signed-off-by: Frost Ming <me@frostming.com>
  • Loading branch information
frostming authored Aug 23, 2024
1 parent 737d402 commit ed13a97
Show file tree
Hide file tree
Showing 3 changed files with 88 additions and 72 deletions.
68 changes: 44 additions & 24 deletions src/bentoml/_internal/cloud/base.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
from __future__ import annotations

import io
import typing as t
from abc import ABC
from abc import abstractmethod
from contextlib import contextmanager

import attrs
from rich.console import Group
from rich.live import Live
from rich.panel import Panel
Expand Down Expand Up @@ -33,26 +33,40 @@
FILE_CHUNK_SIZE = 100 * 1024 * 1024 # 100Mb


class CallbackIOWrapper(io.BytesIO):
read_cb: t.Callable[[int], None] | None
write_cb: t.Callable[[int], None] | None
@attrs.define
class CallbackIOWrapper(t.IO[bytes]):
file: t.IO[bytes]
read_cb: t.Callable[[int], None] | None = None
write_cb: t.Callable[[int], None] | None = None
start: int | None = None
end: int | None = None

def __init__(
self,
buffer: t.Any = None,
*,
read_cb: t.Callable[[int], None] | None = None,
write_cb: t.Callable[[int], None] | None = None,
):
self.read_cb = read_cb
self.write_cb = write_cb
super().__init__(buffer)
def __attrs_post_init__(self) -> None:
self.file.seek(self.start or 0, 0)

def read(self, size: int | None = None) -> bytes:
if size is not None:
res = super().read(size)
def seek(self, offset: int, whence: int = 0) -> int:
if whence == 2 and self.end is not None:
length = self.file.seek(self.end, 0)
else:
res = super().read()
length = self.file.seek(offset, whence)
return length - (self.start or 0)

def tell(self) -> int:
return self.file.tell()

def fileno(self) -> int:
# Raise OSError to prevent access to the underlying file descriptor
raise OSError("fileno")

def __getattr__(self, name: str) -> t.Any:
return getattr(self.file, name)

def read(self, size: int = -1) -> bytes:
pos = self.tell()
if self.end is not None:
if size < 0 or size > self.end - pos:
size = self.end - pos
res = self.file.read(size)
if self.read_cb is not None:
self.read_cb(len(res))
return res
Expand All @@ -64,6 +78,9 @@ def write(self, data: bytes) -> t.Any: # type: ignore # python buffer types ar
self.write_cb(len(data))
return res

def __iter__(self) -> t.Iterator[bytes]:
return iter(self.file)


class Spinner:
"""A UI component that renders as follows:
Expand Down Expand Up @@ -109,20 +126,23 @@ def console(self) -> "Console":
def spin(self, text: str) -> t.Generator[TaskID, None, None]:
"""Create a spinner as a context manager."""
try:
task_id = self.update(text)
task_id = self.update(text, new=True)
yield task_id
finally:
self._spinner_task_id = None
self._spinner_progress.stop_task(task_id)
self._spinner_progress.update(task_id, visible=False)

def update(self, text: str) -> TaskID:
def update(self, text: str, new: bool = False) -> TaskID:
"""Update the spin text."""
if self._spinner_task_id is None:
self._spinner_task_id = self._spinner_progress.add_task(text)
if self._spinner_task_id is None or new:
task_id = self._spinner_progress.add_task(text)
if self._spinner_task_id is None:
self._spinner_task_id = task_id
else:
self._spinner_progress.update(self._spinner_task_id, description=text)
return self._spinner_task_id
task_id = self._spinner_task_id
self._spinner_progress.update(task_id, description=text)
return task_id

def __rich_console__(
self, console: "Console", options: "ConsoleOptions"
Expand Down
88 changes: 42 additions & 46 deletions src/bentoml/_internal/cloud/bentocloud.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
from __future__ import annotations

import math
import tarfile
import tempfile
import threading
import typing as t
import warnings
from concurrent.futures import ThreadPoolExecutor
Expand Down Expand Up @@ -84,6 +84,7 @@ def _do_push_bento(
threads: int = 10,
rest_client: RestApiClient = Provide[BentoMLContainer.rest_api_client],
model_store: ModelStore = Provide[BentoMLContainer.model_store],
bentoml_tmp_dir: str = Provide[BentoMLContainer.tmp_bento_store_dir],
):
name = bento.tag.name
version = bento.tag.version
Expand Down Expand Up @@ -213,10 +214,11 @@ def push_model(model: Model) -> None:
presigned_upload_url = remote_bento.presigned_upload_url

def io_cb(x: int):
with io_mutex:
self.spinner.transmission_progress.update(upload_task_id, advance=x)
self.spinner.transmission_progress.update(upload_task_id, advance=x)

with CallbackIOWrapper(read_cb=io_cb) as tar_io:
with NamedTemporaryFile(
prefix="bentoml-bento-", suffix=".tar", dir=bentoml_tmp_dir
) as tar_io:
with self.spinner.spin(
text=f'Creating tar archive for bento "{bento.tag}"..'
):
Expand All @@ -232,42 +234,38 @@ def filter_(
return tar_info

tar.add(bento.path, arcname="./", filter=filter_)
tar_io.seek(0, 0)

with self.spinner.spin(text=f'Start uploading bento "{bento.tag}"..'):
rest_client.v1.start_upload_bento(
bento_repository_name=bento_repository.name, version=version
)

file_size = tar_io.getbuffer().nbytes
file_size = tar_io.tell()
io_with_cb = CallbackIOWrapper(tar_io, read_cb=io_cb)

self.spinner.transmission_progress.update(
upload_task_id, completed=0, total=file_size, visible=True
)
self.spinner.transmission_progress.start_task(upload_task_id)

io_mutex = threading.Lock()

if transmission_strategy == "proxy":
try:
rest_client.v1.upload_bento(
bento_repository_name=bento_repository.name,
version=version,
data=tar_io,
data=io_with_cb,
)
except Exception as e: # pylint: disable=broad-except
self.spinner.log(f'[bold red]Failed to upload bento "{bento.tag}"')
raise e
self.spinner.log(f'[bold green]Successfully pushed bento "{bento.tag}"')
return
finish_req = FinishUploadBentoSchema(
status=BentoUploadStatus.SUCCESS.value,
reason="",
status=BentoUploadStatus.SUCCESS.value, reason=""
)
try:
if presigned_upload_url is not None:
resp = httpx.put(
presigned_upload_url, content=tar_io, timeout=36000
presigned_upload_url, content=io_with_cb, timeout=36000
)
if resp.status_code != 200:
finish_req = FinishUploadBentoSchema(
Expand All @@ -289,7 +287,8 @@ def filter_(

upload_id: str = remote_bento.upload_id

chunks_count = file_size // FILE_CHUNK_SIZE + 1
chunks_count = math.ceil(file_size / FILE_CHUNK_SIZE)
tar_io.file.close()

def chunk_upload(
upload_id: str, chunk_number: int
Expand All @@ -310,18 +309,16 @@ def chunk_upload(
with self.spinner.spin(
text=f'({chunk_number}/{chunks_count}) Uploading chunk of Bento "{bento.tag}"...'
):
chunk = (
tar_io.getbuffer()[
(chunk_number - 1) * FILE_CHUNK_SIZE : chunk_number
* FILE_CHUNK_SIZE
]
if chunk_number < chunks_count
else tar_io.getbuffer()[
(chunk_number - 1) * FILE_CHUNK_SIZE :
]
)
with open(tar_io.name, "rb") as f:
chunk_io = CallbackIOWrapper(
f,
read_cb=io_cb,
start=(chunk_number - 1) * FILE_CHUNK_SIZE,
end=chunk_number * FILE_CHUNK_SIZE
if chunk_number < chunks_count
else None,
)

with CallbackIOWrapper(chunk, read_cb=io_cb) as chunk_io:
resp = httpx.put(
remote_bento.presigned_upload_url,
content=chunk_io,
Expand Down Expand Up @@ -588,6 +585,7 @@ def _do_push_model(
force: bool = False,
threads: int = 10,
rest_client: RestApiClient = Provide[BentoMLContainer.rest_api_client],
bentoml_tmp_dir: str = Provide[BentoMLContainer.tmp_bento_store_dir],
):
name = model.tag.name
version = model.tag.version
Expand Down Expand Up @@ -663,38 +661,37 @@ def _do_push_model(
transmission_strategy = "presigned_url"
presigned_upload_url = remote_model.presigned_upload_url

io_mutex = threading.Lock()

def io_cb(x: int):
with io_mutex:
self.spinner.transmission_progress.update(upload_task_id, advance=x)
self.spinner.transmission_progress.update(upload_task_id, advance=x)

with CallbackIOWrapper(read_cb=io_cb) as tar_io:
with NamedTemporaryFile(
prefix="bentoml-model-", suffix=".tar", dir=bentoml_tmp_dir
) as tar_io:
with self.spinner.spin(
text=f'Creating tar archive for model "{model.tag}"..'
):
with tarfile.open(fileobj=tar_io, mode="w:") as tar:
tar.add(model.path, arcname="./")
tar_io.seek(0, 0)
with self.spinner.spin(text=f'Start uploading model "{model.tag}"..'):
rest_client.v1.start_upload_model(
model_repository_name=model_repository.name, version=version
)
file_size = tar_io.getbuffer().nbytes
file_size = tar_io.tell()
self.spinner.transmission_progress.update(
upload_task_id,
description=f'Uploading model "{model.tag}"',
total=file_size,
visible=True,
)
self.spinner.transmission_progress.start_task(upload_task_id)
io_with_cb = CallbackIOWrapper(tar_io, read_cb=io_cb)

if transmission_strategy == "proxy":
try:
rest_client.v1.upload_model(
model_repository_name=model_repository.name,
version=version,
data=tar_io,
data=io_with_cb,
)
except Exception as e: # pylint: disable=broad-except
self.spinner.log(f'[bold red]Failed to upload model "{model.tag}"')
Expand All @@ -708,7 +705,7 @@ def io_cb(x: int):
try:
if presigned_upload_url is not None:
resp = httpx.put(
presigned_upload_url, content=tar_io, timeout=36000
presigned_upload_url, content=io_with_cb, timeout=36000
)
if resp.status_code != 200:
finish_req = FinishUploadModelSchema(
Expand All @@ -730,7 +727,8 @@ def io_cb(x: int):

upload_id: str = remote_model.upload_id

chunks_count = file_size // FILE_CHUNK_SIZE + 1
chunks_count = math.ceil(file_size / FILE_CHUNK_SIZE)
tar_io.file.close()

def chunk_upload(
upload_id: str, chunk_number: int
Expand All @@ -752,18 +750,16 @@ def chunk_upload(
with self.spinner.spin(
text=f'({chunk_number}/{chunks_count}) Uploading chunk of model "{model.tag}"...'
):
chunk = (
tar_io.getbuffer()[
(chunk_number - 1) * FILE_CHUNK_SIZE : chunk_number
* FILE_CHUNK_SIZE
]
if chunk_number < chunks_count
else tar_io.getbuffer()[
(chunk_number - 1) * FILE_CHUNK_SIZE :
]
)
with open(tar_io.name, "rb") as f:
chunk_io = CallbackIOWrapper(
f,
read_cb=io_cb,
start=(chunk_number - 1) * FILE_CHUNK_SIZE,
end=chunk_number * FILE_CHUNK_SIZE
if chunk_number < chunks_count
else None,
)

with CallbackIOWrapper(chunk, read_cb=io_cb) as chunk_io:
resp = httpx.put(
remote_model.presigned_upload_url,
content=chunk_io,
Expand Down
4 changes: 2 additions & 2 deletions src/bentoml/_internal/cloud/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -263,7 +263,7 @@ def finish_upload_bento(
return schema_from_json(resp.text, BentoSchema)

def upload_bento(
self, bento_repository_name: str, version: str, data: t.BinaryIO
self, bento_repository_name: str, version: str, data: t.IO[bytes]
) -> None:
url = urljoin(
self.endpoint,
Expand Down Expand Up @@ -416,7 +416,7 @@ def finish_upload_model(
return schema_from_json(resp.text, ModelSchema)

def upload_model(
self, model_repository_name: str, version: str, data: t.BinaryIO
self, model_repository_name: str, version: str, data: t.IO[bytes]
) -> None:
url = urljoin(
self.endpoint,
Expand Down

0 comments on commit ed13a97

Please sign in to comment.