Skip to content

Commit

Permalink
feat: python version validation (#737)
Browse files Browse the repository at this point in the history
  • Loading branch information
ajberdy authored Oct 12, 2023
1 parent 32a6d7a commit 6fa0005
Show file tree
Hide file tree
Showing 5 changed files with 193 additions and 17 deletions.
36 changes: 36 additions & 0 deletions src/braket/aws/aws_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import os
import os.path
import re
from functools import cache
from pathlib import Path
from typing import Any, NamedTuple, Optional

Expand Down Expand Up @@ -825,3 +826,38 @@ def copy_session(
# Preserve user_agent information
copied_session._braket_user_agents = self._braket_user_agents
return copied_session

@cache
def get_full_image_tag(self, image_uri: str) -> str:
"""
Get verbose image tag from image uri.
Args:
image_uri (str): Image uri to get tag for.
Returns:
str: Verbose image tag for given image.
"""
registry = image_uri.split(".")[0]
repository, tag = image_uri.split("/")[-1].split(":")

# get image digest of latest image
digest = self.ecr_client.batch_get_image(
registryId=registry,
repositoryName=repository,
imageIds=[{"imageTag": tag}],
)["images"][0]["imageId"]["imageDigest"]

# get all images matching digest (same image, different tags)
images = self.ecr_client.batch_get_image(
registryId=registry,
repositoryName=repository,
imageIds=[{"imageDigest": digest}],
)["images"]

# find the tag with the python version info
for image in images:
if re.search(r"py\d\d+", tag := image["imageId"]["imageTag"]):
return tag

raise ValueError("Full image tag missing.")
33 changes: 31 additions & 2 deletions src/braket/jobs/hybrid_job.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,9 @@
import functools
import importlib.util
import inspect
import re
import shutil
import sys
import tempfile
import warnings
from collections.abc import Callable, Iterable
Expand All @@ -36,6 +38,7 @@
S3DataSourceConfig,
StoppingCondition,
)
from braket.jobs.image_uris import Framework, built_in_images, retrieve_image
from braket.jobs.quantum_job import QuantumJob
from braket.jobs.quantum_job_creation import _generate_default_job_name

Expand Down Expand Up @@ -146,6 +149,8 @@ def hybrid_job(
logger (Logger): Logger object with which to write logs, such as task statuses
while waiting for task to be in a terminal state. Default is `getLogger(__name__)`
"""
aws_session = aws_session or AwsSession()
_validate_python_version(aws_session, image_uri)

def _hybrid_job(entry_point):
@functools.wraps(entry_point)
Expand Down Expand Up @@ -203,6 +208,30 @@ def job_wrapper(*args, **kwargs):
return _hybrid_job


def _validate_python_version(aws_session: AwsSession, image_uri: str | None):
"""Validate python version at job definition time"""
# user provides a custom image_uri
if image_uri and image_uri not in built_in_images(aws_session.region):
print(
"Skipping python version validation, make sure versions match "
"between local environment and container."
)
else:
# set default image_uri to base
image_uri = image_uri or retrieve_image(Framework.BASE, aws_session.region)
tag = aws_session.get_full_image_tag(image_uri)
major_version, minor_version = re.search(r"-py(\d)(\d+)-", tag).groups()
if not (sys.version_info.major, sys.version_info.minor) == (
int(major_version),
int(minor_version),
):
raise RuntimeError(
"Python version must match between local environment and container. "
f"Client is running Python {sys.version_info.major}.{sys.version_info.minor} "
f"locally, but container uses Python {major_version}.{minor_version}."
)


class _IncludeModules:
def __init__(self, modules: str | ModuleType | Iterable[str | ModuleType] = None):
modules = modules or []
Expand All @@ -224,7 +253,7 @@ def __exit__(self, exc_type, exc_val, exc_tb):
cloudpickle.unregister_pickle_by_value(module)


def _serialize_entry_point(entry_point: Callable, args: list, kwargs: dict) -> str:
def _serialize_entry_point(entry_point: Callable, args: tuple, kwargs: dict) -> str:
"""Create an entry point from a function"""

def wrapped_entry_point():
Expand All @@ -249,7 +278,7 @@ def wrapped_entry_point():
)


def _log_hyperparameters(entry_point: Callable, args: list, kwargs: dict):
def _log_hyperparameters(entry_point: Callable, args: tuple, kwargs: dict):
"""Capture function arguments as hyperparameters"""
signature = inspect.signature(entry_point)
bound_args = signature.bind(*args, **kwargs)
Expand Down
6 changes: 6 additions & 0 deletions src/braket/jobs/image_uris.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
import json
import os
from enum import Enum
from functools import cache
from typing import Dict


Expand All @@ -25,6 +26,11 @@ class Framework(str, Enum):
PL_PYTORCH = "PL_PYTORCH"


def built_in_images(region):
return {retrieve_image(framework, region) for framework in Framework}


@cache
def retrieve_image(framework: Framework, region: str) -> str:
"""Retrieves the ECR URI for the Docker image matching the specified arguments.
Expand Down
35 changes: 35 additions & 0 deletions test/unit_tests/braket/aws/test_aws_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -1346,3 +1346,38 @@ def test_add_braket_user_agent(aws_session):
aws_session.add_braket_user_agent(user_agent)
aws_session.add_braket_user_agent(user_agent)
aws_session._braket_user_agents.count(user_agent) == 1


def test_get_full_image_tag(aws_session):
aws_session.ecr_client.batch_get_image.side_effect = (
{"images": [{"imageId": {"imageDigest": "my-digest"}}]},
{
"images": [
{"imageId": {"imageTag": "my-tag"}},
{"imageId": {"imageTag": "my-tag-py3"}},
{"imageId": {"imageTag": "my-tag-py310"}},
{"imageId": {"imageTag": "latest"}},
]
},
AssertionError("Image tag not cached"),
)
image_uri = "123456.image_uri/repo-name:my-tag"
assert aws_session.get_full_image_tag(image_uri) == "my-tag-py310"
assert aws_session.get_full_image_tag(image_uri) == "my-tag-py310"


def test_get_full_image_tag_no_py_info(aws_session):
aws_session.ecr_client.batch_get_image.side_effect = (
{"images": [{"imageId": {"imageDigest": "my-digest"}}]},
{
"images": [
{"imageId": {"imageTag": "my-tag"}},
{"imageId": {"imageTag": "latest"}},
]
},
)
image_uri = "123456.image_uri/repo-name:my-tag"

no_py_info = "Full image tag missing."
with pytest.raises(ValueError, match=no_py_info):
aws_session.get_full_image_tag(image_uri)
Loading

0 comments on commit 6fa0005

Please sign in to comment.