From ef5e883c01962a422a1db77b0ccbafa93ff26ef8 Mon Sep 17 00:00:00 2001 From: Aaron Berdy Date: Thu, 26 Oct 2023 13:51:19 -0700 Subject: [PATCH] feat: support dependency list for hybrid jobs --- src/braket/jobs/hybrid_job.py | 24 +++++++--- .../unit_tests/braket/jobs/test_hybrid_job.py | 47 +++++++++++++++++++ 2 files changed, 64 insertions(+), 7 deletions(-) diff --git a/src/braket/jobs/hybrid_job.py b/src/braket/jobs/hybrid_job.py index 8de496f5e..394bfad59 100644 --- a/src/braket/jobs/hybrid_job.py +++ b/src/braket/jobs/hybrid_job.py @@ -25,7 +25,7 @@ from logging import Logger, getLogger from pathlib import Path from types import ModuleType -from typing import Any, Dict, List +from typing import Any import cloudpickle @@ -47,7 +47,7 @@ def hybrid_job( *, device: str, include_modules: str | ModuleType | Iterable[str | ModuleType] = None, - dependencies: str | Path = None, + dependencies: str | Path | list[str] = None, local: bool = False, job_name: str = None, image_uri: str = None, @@ -85,7 +85,7 @@ def hybrid_job( modules to be included. Any references to members of these modules in the hybrid job algorithm code will be serialized as part of the algorithm code. Default value `[]` - dependencies (str | Path): Path (absolute or relative) to a requirements.txt + dependencies (str | Path | list[str]): Path (absolute or relative) to a requirements.txt file to be used for the hybrid job. local (bool): Whether to use local mode for the hybrid job. Default `False` @@ -178,7 +178,7 @@ def job_wrapper(*args, **kwargs) -> Callable: entry_point_file.write(template) if dependencies: - shutil.copy(Path(dependencies).resolve(), temp_dir_path / "requirements.txt") + _process_dependencies(dependencies, temp_dir_path) job_args = { "device": device or "local:none/none", @@ -241,6 +241,16 @@ def _validate_python_version(image_uri: str | None, aws_session: AwsSession | No ) +def _process_dependencies(dependencies: str | Path | list[str], temp_dir: Path) -> None: + if isinstance(dependencies, (str, Path)): + # requirements file + shutil.copy(Path(dependencies).resolve(), temp_dir / "requirements.txt") + else: + # list of packages + with open(temp_dir / "requirements.txt", "w") as f: + f.write("\n".join(dependencies)) + + class _IncludeModules: def __init__(self, modules: str | ModuleType | Iterable[str | ModuleType] = None): modules = modules or [] @@ -285,7 +295,7 @@ def wrapped_entry_point() -> Any: ) -def _log_hyperparameters(entry_point: Callable, args: tuple, kwargs: dict) -> Dict: +def _log_hyperparameters(entry_point: Callable, args: tuple, kwargs: dict) -> dict: """Capture function arguments as hyperparameters""" signature = inspect.signature(entry_point) bound_args = signature.bind(*args, **kwargs) @@ -330,7 +340,7 @@ def _sanitize(hyperparameter: Any) -> str: return sanitized -def _process_input_data(input_data: Dict) -> List[str]: +def _process_input_data(input_data: dict) -> list[str]: """ Create symlinks to data @@ -344,7 +354,7 @@ def _process_input_data(input_data: Dict) -> List[str]: if not isinstance(input_data, dict): input_data = {"input": input_data} - def matches(prefix: str) -> List[str]: + def matches(prefix: str) -> list[str]: return [ str(path) for path in Path(prefix).parent.iterdir() if str(path).startswith(str(prefix)) ] diff --git a/test/unit_tests/braket/jobs/test_hybrid_job.py b/test/unit_tests/braket/jobs/test_hybrid_job.py index 2877063d9..f47980083 100644 --- a/test/unit_tests/braket/jobs/test_hybrid_job.py +++ b/test/unit_tests/braket/jobs/test_hybrid_job.py @@ -228,6 +228,53 @@ def my_entry(): assert mock_tempdir.return_value.__exit__.called +@patch.object(sys.modules["braket.jobs.hybrid_job"], "retrieve_image") +@patch("time.time", return_value=123.0) +@patch("builtins.open") +@patch("tempfile.TemporaryDirectory") +@patch.object(AwsQuantumJob, "create") +def test_decorator_list_dependencies( + mock_create, mock_tempdir, _mock_open, mock_time, mock_retrieve, aws_session +): + mock_retrieve.return_value = "00000000.dkr.ecr.us-west-2.amazonaws.com/latest" + dependency_list = ["dep_1", "dep_2", "dep_3"] + + @hybrid_job( + device=None, + aws_session=aws_session, + dependencies=dependency_list, + ) + def my_entry(c=0, d: float = 1.0, **extras): + return "my entry return value" + + mock_tempdir_name = "job_temp_dir_00000" + mock_tempdir.return_value.__enter__.return_value = mock_tempdir_name + + source_module = mock_tempdir_name + entry_point = f"{mock_tempdir_name}.entry_point:my_entry" + wait_until_complete = False + + device = "local:none/none" + + my_entry() + + mock_create.assert_called_with( + device=device, + source_module=source_module, + entry_point=entry_point, + wait_until_complete=wait_until_complete, + job_name="my-entry-123000", + hyperparameters={"c": "0", "d": "1.0"}, + logger=getLogger("braket.jobs.hybrid_job"), + aws_session=aws_session, + ) + assert mock_tempdir.return_value.__exit__.called + _mock_open.assert_called_with(Path(mock_tempdir_name) / "requirements.txt", "w") + _mock_open.return_value.__enter__.return_value.write.assert_called_with( + "\n".join(dependency_list) + ) + + @patch.object(sys.modules["braket.jobs.hybrid_job"], "retrieve_image") @patch("time.time", return_value=123.0) @patch("builtins.open")