Skip to content

Commit

Permalink
feat: support dependency list for hybrid jobs
Browse files Browse the repository at this point in the history
  • Loading branch information
ajberdy committed Oct 26, 2023
1 parent 0e19e30 commit ef5e883
Show file tree
Hide file tree
Showing 2 changed files with 64 additions and 7 deletions.
24 changes: 17 additions & 7 deletions src/braket/jobs/hybrid_job.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
from logging import Logger, getLogger
from pathlib import Path
from types import ModuleType
from typing import Any, Dict, List
from typing import Any

import cloudpickle

Expand All @@ -47,7 +47,7 @@ def hybrid_job(
*,
device: str,
include_modules: str | ModuleType | Iterable[str | ModuleType] = None,
dependencies: str | Path = None,
dependencies: str | Path | list[str] = None,
local: bool = False,
job_name: str = None,
image_uri: str = None,
Expand Down Expand Up @@ -85,7 +85,7 @@ def hybrid_job(
modules to be included. Any references to members of these modules in the hybrid job
algorithm code will be serialized as part of the algorithm code. Default value `[]`
dependencies (str | Path): Path (absolute or relative) to a requirements.txt
dependencies (str | Path | list[str]): Path (absolute or relative) to a requirements.txt
file to be used for the hybrid job.
local (bool): Whether to use local mode for the hybrid job. Default `False`
Expand Down Expand Up @@ -178,7 +178,7 @@ def job_wrapper(*args, **kwargs) -> Callable:
entry_point_file.write(template)

if dependencies:
shutil.copy(Path(dependencies).resolve(), temp_dir_path / "requirements.txt")
_process_dependencies(dependencies, temp_dir_path)

job_args = {
"device": device or "local:none/none",
Expand Down Expand Up @@ -241,6 +241,16 @@ def _validate_python_version(image_uri: str | None, aws_session: AwsSession | No
)


def _process_dependencies(dependencies: str | Path | list[str], temp_dir: Path) -> None:
if isinstance(dependencies, (str, Path)):
# requirements file
shutil.copy(Path(dependencies).resolve(), temp_dir / "requirements.txt")
else:
# list of packages
with open(temp_dir / "requirements.txt", "w") as f:
f.write("\n".join(dependencies))


class _IncludeModules:
def __init__(self, modules: str | ModuleType | Iterable[str | ModuleType] = None):
modules = modules or []
Expand Down Expand Up @@ -285,7 +295,7 @@ def wrapped_entry_point() -> Any:
)


def _log_hyperparameters(entry_point: Callable, args: tuple, kwargs: dict) -> Dict:
def _log_hyperparameters(entry_point: Callable, args: tuple, kwargs: dict) -> dict:
"""Capture function arguments as hyperparameters"""
signature = inspect.signature(entry_point)
bound_args = signature.bind(*args, **kwargs)
Expand Down Expand Up @@ -330,7 +340,7 @@ def _sanitize(hyperparameter: Any) -> str:
return sanitized


def _process_input_data(input_data: Dict) -> List[str]:
def _process_input_data(input_data: dict) -> list[str]:
"""
Create symlinks to data
Expand All @@ -344,7 +354,7 @@ def _process_input_data(input_data: Dict) -> List[str]:
if not isinstance(input_data, dict):
input_data = {"input": input_data}

def matches(prefix: str) -> List[str]:
def matches(prefix: str) -> list[str]:
return [
str(path) for path in Path(prefix).parent.iterdir() if str(path).startswith(str(prefix))
]
Expand Down
47 changes: 47 additions & 0 deletions test/unit_tests/braket/jobs/test_hybrid_job.py
Original file line number Diff line number Diff line change
Expand Up @@ -228,6 +228,53 @@ def my_entry():
assert mock_tempdir.return_value.__exit__.called


@patch.object(sys.modules["braket.jobs.hybrid_job"], "retrieve_image")
@patch("time.time", return_value=123.0)
@patch("builtins.open")
@patch("tempfile.TemporaryDirectory")
@patch.object(AwsQuantumJob, "create")
def test_decorator_list_dependencies(
mock_create, mock_tempdir, _mock_open, mock_time, mock_retrieve, aws_session
):
mock_retrieve.return_value = "00000000.dkr.ecr.us-west-2.amazonaws.com/latest"
dependency_list = ["dep_1", "dep_2", "dep_3"]

@hybrid_job(
device=None,
aws_session=aws_session,
dependencies=dependency_list,
)
def my_entry(c=0, d: float = 1.0, **extras):
return "my entry return value"

mock_tempdir_name = "job_temp_dir_00000"
mock_tempdir.return_value.__enter__.return_value = mock_tempdir_name

source_module = mock_tempdir_name
entry_point = f"{mock_tempdir_name}.entry_point:my_entry"
wait_until_complete = False

device = "local:none/none"

my_entry()

mock_create.assert_called_with(
device=device,
source_module=source_module,
entry_point=entry_point,
wait_until_complete=wait_until_complete,
job_name="my-entry-123000",
hyperparameters={"c": "0", "d": "1.0"},
logger=getLogger("braket.jobs.hybrid_job"),
aws_session=aws_session,
)
assert mock_tempdir.return_value.__exit__.called
_mock_open.assert_called_with(Path(mock_tempdir_name) / "requirements.txt", "w")
_mock_open.return_value.__enter__.return_value.write.assert_called_with(
"\n".join(dependency_list)
)


@patch.object(sys.modules["braket.jobs.hybrid_job"], "retrieve_image")
@patch("time.time", return_value=123.0)
@patch("builtins.open")
Expand Down

0 comments on commit ef5e883

Please sign in to comment.