Skip to content

Commit

Permalink
cache and assert all builtin images
Browse files Browse the repository at this point in the history
  • Loading branch information
ajberdy committed Oct 12, 2023
1 parent 415ca95 commit 88f832d
Show file tree
Hide file tree
Showing 6 changed files with 51 additions and 10 deletions.
2 changes: 2 additions & 0 deletions src/braket/aws/aws_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import os
import os.path
import re
from functools import cache
from pathlib import Path
from typing import Any, NamedTuple, Optional

Expand Down Expand Up @@ -826,6 +827,7 @@ def copy_session(
copied_session._braket_user_agents = self._braket_user_agents
return copied_session

@cache
def get_full_image_tag(self, image_uri: str) -> str:
"""
Get verbose image tag from image uri.
Expand Down
17 changes: 10 additions & 7 deletions src/braket/jobs/hybrid_job.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@
S3DataSourceConfig,
StoppingCondition,
)
from braket.jobs.image_uris import Framework, retrieve_image
from braket.jobs.image_uris import Framework, built_in_images, retrieve_image
from braket.jobs.quantum_job import QuantumJob
from braket.jobs.quantum_job_creation import _generate_default_job_name

Expand Down Expand Up @@ -179,7 +179,8 @@ def job_wrapper(*args, **kwargs):
f"{temp_dir}.{entry_point_file_path.stem}:{entry_point.__name__}"
),
"wait_until_complete": wait_until_complete,
"job_name": job_name or _generate_default_job_name(func=entry_point),
"job_name": job_name
or _generate_default_job_name(func=entry_point, decorator=True),
"hyperparameters": _log_hyperparameters(entry_point, args, kwargs),
"logger": logger,
}
Expand Down Expand Up @@ -209,18 +210,20 @@ def job_wrapper(*args, **kwargs):


def _validate_python_version(aws_session: AwsSession, image_uri: str | None):
if image_uri:
# user provides a custom image_uri
if image_uri and image_uri not in built_in_images(aws_session.region):
print(
"Skipping python version validation, make sure versions match "
"between local environment and container."
)
else:
image_uri = retrieve_image(Framework.BASE, aws_session.region)
# set default image_uri to base
image_uri = image_uri or retrieve_image(Framework.BASE, aws_session.region)
tag = aws_session.get_full_image_tag(image_uri)
major_version, minor_version = re.search(r"-py(\d)(\d+)-", tag).groups()
if not (
sys.version_info.major == int(major_version)
and sys.version_info.minor == int(minor_version)
if not (sys.version_info.major, sys.version_info.minor) == (
int(major_version),
int(minor_version),
):
raise RuntimeError(
"Python version must match between local environment and container. "
Expand Down
6 changes: 6 additions & 0 deletions src/braket/jobs/image_uris.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
import json
import os
from enum import Enum
from functools import cache
from typing import Dict


Expand All @@ -25,6 +26,11 @@ class Framework(str, Enum):
PL_PYTORCH = "PL_PYTORCH"


@cache
def built_in_images(region):
return {retrieve_image(framework, region) for framework in Framework}


def retrieve_image(framework: Framework, region: str) -> str:
"""Retrieves the ECR URI for the Docker image matching the specified arguments.
Expand Down
7 changes: 6 additions & 1 deletion src/braket/jobs/quantum_job_creation.py
Original file line number Diff line number Diff line change
Expand Up @@ -232,13 +232,16 @@ def prepare_quantum_job(
return create_job_kwargs


def _generate_default_job_name(image_uri: str | None = None, func: Callable | None = None) -> str:
def _generate_default_job_name(
image_uri: str | None = None, func: Callable | None = None, decorator: bool = False
) -> str:
"""
Generate default job name using the image uri and entrypoint function.
Args:
image_uri (str | None): URI for the image container.
func (Callable | None): The entry point function.
decorator (bool): Whether the job is a decorator job. Default: False.
Returns:
str: Hybrid job name.
Expand All @@ -248,6 +251,8 @@ def _generate_default_job_name(image_uri: str | None = None, func: Callable | No

if func:
name = func.__name__.replace("_", "-")
if decorator:
name = "dec-" + name
if len(name) + len(timestamp) > max_length:
name = name[: max_length - len(timestamp) - 1]
warnings.warn(
Expand Down
11 changes: 11 additions & 0 deletions test/unit_tests/braket/aws/test_aws_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -1346,3 +1346,14 @@ def test_add_braket_user_agent(aws_session):
aws_session.add_braket_user_agent(user_agent)
aws_session.add_braket_user_agent(user_agent)
aws_session._braket_user_agents.count(user_agent) == 1


def test_get_full_image_tag(aws_session):
aws_session.ecr_client.batch_get_image.side_effect = (
{"images": [{"imageId": {"imageDigest": "my-digest"}}]},
{"images": [{"imageId": {"imageTag": "my-tag"}}]},
AssertionError("Image tag not cached"),
)
image_uri = "123456.image_uri/repo-name:my-tag"
assert aws_session.get_full_image_tag(image_uri) == "my-tag"
assert aws_session.get_full_image_tag(image_uri) == "my-tag"
18 changes: 16 additions & 2 deletions test/unit_tests/braket/jobs/test_hybrid_job.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,25 @@
from braket.jobs.local import LocalQuantumJob


@pytest.fixture
def aws_session():
aws_session = MagicMock()
aws_session.get_full_image_tag.return_value = "1.0-cpu-py310-ubuntu22.04"
aws_session.region = "us-west-2"
return aws_session


@patch("braket.jobs.image_uris.retrieve_image")
@patch("time.time", return_value=123.0)
@patch("builtins.open")
@patch("tempfile.TemporaryDirectory")
@patch.object(AwsQuantumJob, "create")
def test_decorator_defaults(mock_create, mock_tempdir, _mock_open, mock_time):
@hybrid_job(device=None)
def test_decorator_defaults(
mock_create, mock_tempdir, _mock_open, mock_time, mock_retrieve, aws_session
):
mock_retrieve.return_value = "00000000.dkr.ecr.us-west-2.amazonaws.com/latest"

@hybrid_job(device=None, aws_session=aws_session)
def my_entry(c=0, d: float = 1.0, **extras):
return "my entry return value"

Expand All @@ -47,6 +60,7 @@ def my_entry(c=0, d: float = 1.0, **extras):
job_name="my-entry-123000",
hyperparameters={"c": 0, "d": 1.0},
logger=getLogger("braket.jobs.hybrid_job"),
aws_session=aws_session,
)
assert mock_tempdir.return_value.__exit__.called

Expand Down

0 comments on commit 88f832d

Please sign in to comment.