diff --git a/src/braket/jobs/hybrid_job.py b/src/braket/jobs/hybrid_job.py index 3e5b9078d..a45b094d4 100644 --- a/src/braket/jobs/hybrid_job.py +++ b/src/braket/jobs/hybrid_job.py @@ -297,7 +297,31 @@ def _log_hyperparameters(entry_point: Callable, args: tuple, kwargs: dict): warnings.warn( "Positional only arguments will not be logged to the hyperparameters file." ) - return hyperparameters + return {name: _sanitize(value) for name, value in hyperparameters.items()} + + +def _sanitize(hyperparameter: Any) -> str: + """Sanitize forbidden characters from hp strings""" + string_hp = str(hyperparameter) + + sanitized = ( + string_hp + # replace newlines with spaces + .replace("\n", " ") + # replace forbidden characters with "?" + .replace("$", "?") + .replace("(", "?") + .replace("&", "?") + .replace("`", "?") + # not technically forbidden, but to avoid mismatched parens + .replace(")", "?") + ) + + # max allowed length for a hyperparameter is 2500 + if len(sanitized) > 2500: + # show as much as possible, including the final 20 characters + return f"{sanitized[:2500-23]}...{sanitized[-20:]}" + return sanitized def _process_input_data(input_data): diff --git a/test/unit_tests/braket/jobs/test_hybrid_job.py b/test/unit_tests/braket/jobs/test_hybrid_job.py index 063623383..f1159f9ce 100644 --- a/test/unit_tests/braket/jobs/test_hybrid_job.py +++ b/test/unit_tests/braket/jobs/test_hybrid_job.py @@ -16,7 +16,7 @@ from braket.devices import Devices from braket.jobs import hybrid_job from braket.jobs.config import CheckpointConfig, InstanceConfig, OutputDataConfig, StoppingCondition -from braket.jobs.hybrid_job import _serialize_entry_point +from braket.jobs.hybrid_job import _sanitize, _serialize_entry_point from braket.jobs.local import LocalQuantumJob @@ -60,7 +60,7 @@ def my_entry(c=0, d: float = 1.0, **extras): entry_point=entry_point, wait_until_complete=wait_until_complete, job_name="my-entry-123000", - hyperparameters={"c": 0, "d": 1.0}, + hyperparameters={"c": "0", "d": "1.0"}, logger=getLogger("braket.jobs.hybrid_job"), aws_session=aws_session, ) @@ -161,7 +161,14 @@ def my_entry(a, b: int, c=0, d: float = 1.0, **extras) -> str: job_name="my-entry-123000", instance_config=default_instance, distribution=distribution, - hyperparameters={"a": "a", "b": 2, "c": 3, "d": 4, "extra_param": "value", "another": 6}, + hyperparameters={ + "a": "a", + "b": "2", + "c": "3", + "d": "4", + "extra_param": "value", + "another": "6", + }, checkpoint_config=checkpoint_config, copy_checkpoints_from_job=copy_checkpoints_from_job, role_arn=role_arn, @@ -463,3 +470,24 @@ def test_python_validation(aws_session): @hybrid_job(device=None, aws_session=aws_session) def my_job(): pass + + +@pytest.mark.parametrize( + "hyperparameter, expected", + ( + ( + "with\nnewline", + "with newline", + ), + ( + "with weird chars: (&$`)", + "with weird chars: ?????", + ), + ( + "?" * 2600, + f"{'?'*2477}...{'?'*20}", + ), + ), +) +def test_sanitize_hyperparameters(hyperparameter, expected): + assert _sanitize(hyperparameter) == expected