diff --git a/src/braket/jobs/hybrid_job.py b/src/braket/jobs/hybrid_job.py index bca0b58cb..c02eb37ef 100644 --- a/src/braket/jobs/hybrid_job.py +++ b/src/braket/jobs/hybrid_job.py @@ -267,8 +267,6 @@ def _process_input_data(input_data): else (multiple matches, possibly including exact): cwd/prefix_match -> channel/prefix_match, for each match """ - from braket.aws import AwsSession - input_data = input_data or {} if not isinstance(input_data, dict): input_data = {"input": input_data} diff --git a/test/unit_tests/braket/jobs/test_hybrid_job.py b/test/unit_tests/braket/jobs/test_hybrid_job.py index 55a1373b3..458713ff0 100644 --- a/test/unit_tests/braket/jobs/test_hybrid_job.py +++ b/test/unit_tests/braket/jobs/test_hybrid_job.py @@ -1,4 +1,6 @@ +import ast import importlib +import re import tempfile from logging import getLogger from pathlib import Path @@ -7,11 +9,13 @@ import job_module import pytest +from cloudpickle import cloudpickle from braket.aws import AwsQuantumJob from braket.devices import Devices from braket.jobs import hybrid_job from braket.jobs.config import CheckpointConfig, InstanceConfig, OutputDataConfig, StoppingCondition +from braket.jobs.hybrid_job import _serialize_entry_point from braket.jobs.local import LocalQuantumJob @@ -158,6 +162,41 @@ def my_entry(a, b: int, c=0, d: float = 1.0, **extras) -> str: mock_stdout.write.assert_any_call(s3_not_linked) +@patch("time.time", return_value=123.0) +@patch("builtins.open") +@patch("tempfile.TemporaryDirectory") +@patch.object(AwsQuantumJob, "create") +def test_decorator_non_dict_input(mock_create, mock_tempdir, _mock_open, mock_time): + input_prefix = "my_input" + + @hybrid_job(device=None, input_data=input_prefix) + def my_entry(): + return "my entry return value" + + mock_tempdir_name = "job_temp_dir_00000" + mock_tempdir.return_value.__enter__.return_value = mock_tempdir_name + + source_module = mock_tempdir_name + entry_point = f"{mock_tempdir_name}.entry_point:my_entry" + wait_until_complete = False + + device = "local:none/none" + + my_entry() + + mock_create.assert_called_with( + device=device, + source_module=source_module, + entry_point=entry_point, + wait_until_complete=wait_until_complete, + job_name="my-entry-123000", + hyperparameters={}, + logger=getLogger("braket.jobs.hybrid_job"), + input_data=input_prefix, + ) + assert mock_tempdir.return_value.__exit__.called + + @patch("time.time", return_value=123.0) @patch("builtins.open") @patch("tempfile.TemporaryDirectory") @@ -340,3 +379,17 @@ def fails_serialization(): ) with pytest.raises(RuntimeError, match=serialization_failed): fails_serialization() + + +def test_serialization_wrapping(): + def my_entry(*args, **kwargs): + print("something with \" and ' and \n") + return args, kwargs + + args, kwargs = (1, "two"), {"three": 3} + template = _serialize_entry_point(my_entry, args, kwargs) + pickled_str = re.search(r"(?s)cloudpickle.loads\((.*?)\)\ndef my_entry", template).group(1) + byte_str = ast.literal_eval(pickled_str) + + recovered = cloudpickle.loads(byte_str) + assert recovered() == (args, kwargs)