Skip to content

Commit

Permalink
feat: add dec- prefix to decorator jobs (#739)
Browse files Browse the repository at this point in the history
  • Loading branch information
ajberdy authored Oct 12, 2023
1 parent 53839b8 commit 0344d62
Show file tree
Hide file tree
Showing 4 changed files with 30 additions and 10 deletions.
3 changes: 2 additions & 1 deletion src/braket/jobs/hybrid_job.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,7 +174,8 @@ def job_wrapper(*args, **kwargs):
f"{temp_dir}.{entry_point_file_path.stem}:{entry_point.__name__}"
),
"wait_until_complete": wait_until_complete,
"job_name": job_name or _generate_default_job_name(func=entry_point),
"job_name": job_name
or _generate_default_job_name(func=entry_point, decorator=True),
"hyperparameters": _log_hyperparameters(entry_point, args, kwargs),
"logger": logger,
}
Expand Down
7 changes: 6 additions & 1 deletion src/braket/jobs/quantum_job_creation.py
Original file line number Diff line number Diff line change
Expand Up @@ -232,13 +232,16 @@ def prepare_quantum_job(
return create_job_kwargs


def _generate_default_job_name(image_uri: str | None = None, func: Callable | None = None) -> str:
def _generate_default_job_name(
image_uri: str | None = None, func: Callable | None = None, decorator: bool = False
) -> str:
"""
Generate default job name using the image uri and entrypoint function.
Args:
image_uri (str | None): URI for the image container.
func (Callable | None): The entry point function.
decorator (bool): Whether the job is a decorator job. Default: False.
Returns:
str: Hybrid job name.
Expand All @@ -248,6 +251,8 @@ def _generate_default_job_name(image_uri: str | None = None, func: Callable | No

if func:
name = func.__name__.replace("_", "-")
if decorator:
name = f"dec-{name}"
if len(name) + len(timestamp) > max_length:
name = name[: max_length - len(timestamp) - 1]
warnings.warn(
Expand Down
16 changes: 8 additions & 8 deletions test/unit_tests/braket/jobs/test_hybrid_job.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ def my_entry(c=0, d: float = 1.0, **extras):
source_module=source_module,
entry_point=entry_point,
wait_until_complete=wait_until_complete,
job_name="my-entry-123000",
job_name="dec-my-entry-123000",
hyperparameters={"c": 0, "d": 1.0},
logger=getLogger("braket.jobs.hybrid_job"),
)
Expand Down Expand Up @@ -139,7 +139,7 @@ def my_entry(a, b: int, c=0, d: float = 1.0, **extras) -> str:
image_uri=image_uri,
input_data=input_data,
wait_until_complete=wait_until_complete,
job_name="my-entry-123000",
job_name="dec-my-entry-123000",
instance_config=default_instance,
distribution=distribution,
hyperparameters={"a": "a", "b": 2, "c": 3, "d": 4, "extra_param": "value", "another": 6},
Expand Down Expand Up @@ -189,7 +189,7 @@ def my_entry():
source_module=source_module,
entry_point=entry_point,
wait_until_complete=wait_until_complete,
job_name="my-entry-123000",
job_name="dec-my-entry-123000",
hyperparameters={},
logger=getLogger("braket.jobs.hybrid_job"),
input_data=input_prefix,
Expand Down Expand Up @@ -219,7 +219,7 @@ def my_entry():
device=device,
source_module=source_module,
entry_point=entry_point,
job_name="my-entry-123000",
job_name="dec-my-entry-123000",
hyperparameters={},
)
assert mock_tempdir.return_value.__exit__.called
Expand Down Expand Up @@ -257,7 +257,7 @@ def my_entry():
device=device,
source_module=source_module,
entry_point=entry_point,
job_name="my-entry-123000",
job_name="dec-my-entry-123000",
hyperparameters={},
)
assert mock_tempdir.return_value.__exit__.called
Expand Down Expand Up @@ -285,7 +285,7 @@ def this_is_a_50_character_func_name_for_testing_names():
with pytest.warns(UserWarning):
this_is_a_50_character_func_name_for_testing_names()

expected_job_name = "this-is-a-50-character-func-name-for-testin-123000"
expected_job_name = "dec-this-is-a-50-character-func-name-for-te-123000"

mock_create.assert_called_with(
device=device,
Expand Down Expand Up @@ -326,7 +326,7 @@ def my_entry(pos_only, /):
source_module=source_module,
entry_point=entry_point,
wait_until_complete=wait_until_complete,
job_name="my-entry-123000",
job_name="dec-my-entry-123000",
hyperparameters={},
logger=getLogger("braket.jobs.hybrid_job"),
)
Expand Down Expand Up @@ -359,7 +359,7 @@ def my_entry(*args):
source_module=source_module,
entry_point=entry_point,
wait_until_complete=wait_until_complete,
job_name="my-entry-123000",
job_name="dec-my-entry-123000",
hyperparameters={},
logger=getLogger("braket.jobs.hybrid_job"),
)
Expand Down
14 changes: 14 additions & 0 deletions test/unit_tests/braket/jobs/test_quantum_job_creation.py
Original file line number Diff line number Diff line change
Expand Up @@ -385,6 +385,20 @@ def test_generate_default_job_name(mock_time, image_uri):
assert _generate_default_job_name(image_uri) == f"braket-job{job_type}-{timestamp}"


@patch("time.time")
def test_generate_default_job_name_func(mock_time):
mock_time.return_value = 123.45678

def my_func():
pass

assert _generate_default_job_name("image_uri", func=my_func) == "my-func-123456"
assert (
_generate_default_job_name("image_uri", func=my_func, decorator=True)
== "dec-my-func-123456"
)


@pytest.mark.parametrize(
"source_module",
(
Expand Down

0 comments on commit 0344d62

Please sign in to comment.