Skip to content

Commit

Permalink
feat: sanitize hp strings (#742)
Browse files Browse the repository at this point in the history
  • Loading branch information
ajberdy authored Oct 13, 2023
1 parent b0df320 commit 0efbd82
Show file tree
Hide file tree
Showing 3 changed files with 56 additions and 5 deletions.
25 changes: 24 additions & 1 deletion src/braket/jobs/hybrid_job.py
Original file line number Diff line number Diff line change
Expand Up @@ -297,7 +297,30 @@ def _log_hyperparameters(entry_point: Callable, args: tuple, kwargs: dict):
warnings.warn(
"Positional only arguments will not be logged to the hyperparameters file."
)
return hyperparameters
return {name: _sanitize(value) for name, value in hyperparameters.items()}


def _sanitize(hyperparameter: Any) -> str:
"""Sanitize forbidden characters from hp strings"""
string_hp = str(hyperparameter)

sanitized = (
string_hp
# replace forbidden characters with close matches
.replace("\n", " ")
.replace("$", "?")
.replace("(", "{")
.replace("&", "+")
.replace("`", "'")
# not technically forbidden, but to avoid mismatched parens
.replace(")", "}")
)

# max allowed length for a hyperparameter is 2500
if len(sanitized) > 2500:
# show as much as possible, including the final 20 characters
return f"{sanitized[:2500-23]}...{sanitized[-20:]}"
return sanitized


def _process_input_data(input_data):
Expand Down
2 changes: 1 addition & 1 deletion test/integ_tests/test_create_quantum_job.py
Original file line number Diff line number Diff line change
Expand Up @@ -220,7 +220,7 @@ def decorator_job(a, b: int, c=0, d: float = 1.0, **extras):
with open(hp_file, "r") as f:
hyperparameters = json.load(f)
assert hyperparameters == {
"a": "MyClass(value)",
"a": "MyClass{value}",
"b": "2",
"c": "0",
"d": "5",
Expand Down
34 changes: 31 additions & 3 deletions test/unit_tests/braket/jobs/test_hybrid_job.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from braket.devices import Devices
from braket.jobs import hybrid_job
from braket.jobs.config import CheckpointConfig, InstanceConfig, OutputDataConfig, StoppingCondition
from braket.jobs.hybrid_job import _serialize_entry_point
from braket.jobs.hybrid_job import _sanitize, _serialize_entry_point
from braket.jobs.local import LocalQuantumJob


Expand Down Expand Up @@ -60,7 +60,7 @@ def my_entry(c=0, d: float = 1.0, **extras):
entry_point=entry_point,
wait_until_complete=wait_until_complete,
job_name="my-entry-123000",
hyperparameters={"c": 0, "d": 1.0},
hyperparameters={"c": "0", "d": "1.0"},
logger=getLogger("braket.jobs.hybrid_job"),
aws_session=aws_session,
)
Expand Down Expand Up @@ -161,7 +161,14 @@ def my_entry(a, b: int, c=0, d: float = 1.0, **extras) -> str:
job_name="my-entry-123000",
instance_config=default_instance,
distribution=distribution,
hyperparameters={"a": "a", "b": 2, "c": 3, "d": 4, "extra_param": "value", "another": 6},
hyperparameters={
"a": "a",
"b": "2",
"c": "3",
"d": "4",
"extra_param": "value",
"another": "6",
},
checkpoint_config=checkpoint_config,
copy_checkpoints_from_job=copy_checkpoints_from_job,
role_arn=role_arn,
Expand Down Expand Up @@ -463,3 +470,24 @@ def test_python_validation(aws_session):
@hybrid_job(device=None, aws_session=aws_session)
def my_job():
pass


@pytest.mark.parametrize(
"hyperparameter, expected",
(
(
"with\nnewline",
"with newline",
),
(
"with weird chars: (&$`)",
"with weird chars: {+?'}",
),
(
"?" * 2600,
f"{'?'*2477}...{'?'*20}",
),
),
)
def test_sanitize_hyperparameters(hyperparameter, expected):
assert _sanitize(hyperparameter) == expected

0 comments on commit 0efbd82

Please sign in to comment.