def get_full_image_tag(self, image_uri: str) -> str:
    """
    Get verbose image tag from image uri.

    Resolves the digest behind the tag named in ``image_uri`` and returns the first
    tag attached to that digest that carries Python version info (``py`` followed by
    at least two digits). Successful lookups are memoized on this session instance.

    Args:
        image_uri (str): Image uri to get tag for, of the form
            ``<registry>.<host>/<repository>:<tag>``.

    Returns:
        str: Verbose image tag for given image.

    Raises:
        ValueError: If the uri does not end in a single ``<repository>:<tag>`` pair,
            if ECR returns no image for that tag, or if none of the tags sharing the
            image digest contains Python version info.
    """
    # Memoize per instance. A functools.cache decorator on the method would key on
    # `self` and keep every session alive for the lifetime of the process.
    tag_cache = self.__dict__.setdefault("_full_image_tag_cache", {})
    if image_uri in tag_cache:
        return tag_cache[image_uri]

    registry = image_uri.split(".")[0]
    name_and_tag = image_uri.split("/")[-1].split(":")
    if len(name_and_tag) != 2:
        raise ValueError(f"Image uri must end in '<repository>:<tag>': {image_uri}")
    repository, tag = name_and_tag

    # get image digest of latest image
    tagged_images = self.ecr_client.batch_get_image(
        registryId=registry,
        repositoryName=repository,
        imageIds=[{"imageTag": tag}],
    )["images"]
    if not tagged_images:
        raise ValueError(f"No image found for tag '{tag}' in repository '{repository}'.")
    digest = tagged_images[0]["imageId"]["imageDigest"]

    # get all images matching digest (same image, different tags)
    images = self.ecr_client.batch_get_image(
        registryId=registry,
        repositoryName=repository,
        imageIds=[{"imageDigest": digest}],
    )["images"]

    # find the tag with the python version info; entries without a tag are skipped
    for image in images:
        candidate = image["imageId"].get("imageTag", "")
        if re.search(r"py\d\d+", candidate):
            tag_cache[image_uri] = candidate
            return candidate

    raise ValueError("Full image tag missing.")
def _validate_python_version(aws_session: AwsSession, image_uri: str | None):
    """Validate python version at job definition time.

    For a custom ``image_uri`` (one that is not among the built-in images for the
    session's region) validation is skipped and a notice is printed. Otherwise the
    full tag of the image (the base image when ``image_uri`` is not given) is
    fetched from the session and its ``-py<major><minor>-`` component is compared
    with the local interpreter version.

    Args:
        aws_session (AwsSession): Session used to look up the region and the full image tag.
        image_uri (str | None): Image uri the job will run in, or None for the base image.

    Raises:
        ValueError: If the image tag does not contain a ``-py<major><minor>-`` component.
        RuntimeError: If the local Python major/minor version differs from the container's.
    """
    # user provides a custom image_uri
    if image_uri and image_uri not in built_in_images(aws_session.region):
        print(
            "Skipping python version validation, make sure versions match "
            "between local environment and container."
        )
        return

    # set default image_uri to base
    image_uri = image_uri or retrieve_image(Framework.BASE, aws_session.region)
    tag = aws_session.get_full_image_tag(image_uri)

    # The tag lookup only guarantees "py" followed by digits; the dash-delimited form
    # is stricter, so a missing match must be reported instead of dereferencing None.
    version_match = re.search(r"-py(\d)(\d+)-", tag)
    if version_match is None:
        raise ValueError(f"Unable to determine container Python version from image tag '{tag}'.")
    major_version, minor_version = version_match.groups()

    local_version = (sys.version_info.major, sys.version_info.minor)
    if local_version != (int(major_version), int(minor_version)):
        raise RuntimeError(
            "Python version must match between local environment and container. "
            f"Client is running Python {sys.version_info.major}.{sys.version_info.minor} "
            f"locally, but container uses Python {major_version}.{minor_version}."
        )
def test_get_full_image_tag(aws_session):
    """The py-versioned tag sharing the digest is returned, and the second call is served from cache."""
    digest_lookup = {"images": [{"imageId": {"imageDigest": "my-digest"}}]}
    tags_for_digest = {
        "images": [
            {"imageId": {"imageTag": candidate}}
            for candidate in ("my-tag", "my-tag-py3", "my-tag-py310", "latest")
        ]
    }
    # A third ECR call would mean the result was not cached.
    aws_session.ecr_client.batch_get_image.side_effect = (
        digest_lookup,
        tags_for_digest,
        AssertionError("Image tag not cached"),
    )
    image_uri = "123456.image_uri/repo-name:my-tag"
    for _ in range(2):
        assert aws_session.get_full_image_tag(image_uri) == "my-tag-py310"
@pytest.fixture
def aws_session():
    """Mock session in us-west-2 whose full image tag matches the local Python version."""
    python_version_str = f"py{sys.version_info.major}{sys.version_info.minor}"
    session = MagicMock()
    session.region = "us-west-2"
    # Tag mirrors the running interpreter so version validation passes by default.
    session.get_full_image_tag.return_value = f"1.0-cpu-{python_version_str}-ubuntu22.04"
    return session
test_decorator_non_defaults( mock_unregister, mock_time, mock_stdout, + mock_retrieve, include_modules, ): + mock_retrieve.return_value = "should-not-be-used" dependencies = "my_requirements.txt" image_uri = "my_image.uri" default_instance = InstanceConfig() @@ -162,14 +181,18 @@ def my_entry(a, b: int, c=0, d: float = 1.0, **extras) -> str: mock_stdout.write.assert_any_call(s3_not_linked) +@patch.object(sys.modules["braket.jobs.hybrid_job"], "retrieve_image") @patch("time.time", return_value=123.0) @patch("builtins.open") @patch("tempfile.TemporaryDirectory") @patch.object(AwsQuantumJob, "create") -def test_decorator_non_dict_input(mock_create, mock_tempdir, _mock_open, mock_time): +def test_decorator_non_dict_input( + mock_create, mock_tempdir, _mock_open, mock_time, mock_retrieve, aws_session +): + mock_retrieve.return_value = "00000000.dkr.ecr.us-west-2.amazonaws.com/latest" input_prefix = "my_input" - @hybrid_job(device=None, input_data=input_prefix) + @hybrid_job(device=None, input_data=input_prefix, aws_session=aws_session) def my_entry(): return "my entry return value" @@ -193,16 +216,22 @@ def my_entry(): hyperparameters={}, logger=getLogger("braket.jobs.hybrid_job"), input_data=input_prefix, + aws_session=aws_session, ) assert mock_tempdir.return_value.__exit__.called +@patch.object(sys.modules["braket.jobs.hybrid_job"], "retrieve_image") @patch("time.time", return_value=123.0) @patch("builtins.open") @patch("tempfile.TemporaryDirectory") @patch.object(LocalQuantumJob, "create") -def test_decorator_local(mock_create, mock_tempdir, _mock_open, mock_time): - @hybrid_job(device=Devices.Amazon.SV1, local=True) +def test_decorator_local( + mock_create, mock_tempdir, _mock_open, mock_time, mock_retrieve, aws_session +): + mock_retrieve.return_value = "00000000.dkr.ecr.us-west-2.amazonaws.com/latest" + + @hybrid_job(device=Devices.Amazon.SV1, local=True, aws_session=aws_session) def my_entry(): return "my entry return value" @@ -221,15 +250,21 @@ def my_entry(): 
entry_point=entry_point, job_name="my-entry-123000", hyperparameters={}, + aws_session=aws_session, ) assert mock_tempdir.return_value.__exit__.called +@patch.object(sys.modules["braket.jobs.hybrid_job"], "retrieve_image") @patch("time.time", return_value=123.0) @patch("builtins.open") @patch("tempfile.TemporaryDirectory") @patch.object(LocalQuantumJob, "create") -def test_decorator_local_unsupported_args(mock_create, mock_tempdir, _mock_open, mock_time): +def test_decorator_local_unsupported_args( + mock_create, mock_tempdir, _mock_open, mock_time, mock_retrieve, aws_session +): + mock_retrieve.return_value = "00000000.dkr.ecr.us-west-2.amazonaws.com/latest" + @hybrid_job( device=Devices.Amazon.SV1, local=True, @@ -240,6 +275,7 @@ def test_decorator_local_unsupported_args(mock_create, mock_tempdir, _mock_open, stopping_condition=StoppingCondition(), tags={"my_tag": "my_value"}, logger=getLogger(__name__), + aws_session=aws_session, ) def my_entry(): return "my entry return value" @@ -259,16 +295,22 @@ def my_entry(): entry_point=entry_point, job_name="my-entry-123000", hyperparameters={}, + aws_session=aws_session, ) assert mock_tempdir.return_value.__exit__.called +@patch.object(sys.modules["braket.jobs.hybrid_job"], "retrieve_image") @patch("time.time", return_value=123.0) @patch("builtins.open") @patch("tempfile.TemporaryDirectory") @patch.object(AwsQuantumJob, "create") -def test_job_name_too_long(mock_create, mock_tempdir, _mock_open, mock_time): - @hybrid_job(device="local:braket/default") +def test_job_name_too_long( + mock_create, mock_tempdir, _mock_open, mock_time, mock_retrieve, aws_session +): + mock_retrieve.return_value = "00000000.dkr.ecr.us-west-2.amazonaws.com/latest" + + @hybrid_job(device="local:braket/default", aws_session=aws_session) def this_is_a_50_character_func_name_for_testing_names(): return "my entry return value" @@ -295,17 +337,23 @@ def this_is_a_50_character_func_name_for_testing_names(): job_name=expected_job_name, 
hyperparameters={}, logger=getLogger("braket.jobs.hybrid_job"), + aws_session=aws_session, ) assert len(expected_job_name) == 50 assert mock_tempdir.return_value.__exit__.called +@patch.object(sys.modules["braket.jobs.hybrid_job"], "retrieve_image") @patch("time.time", return_value=123.0) @patch("builtins.open") @patch("tempfile.TemporaryDirectory") @patch.object(AwsQuantumJob, "create") -def test_decorator_pos_only_slash(mock_create, mock_tempdir, _mock_open, mock_time): - @hybrid_job(device="local:braket/default") +def test_decorator_pos_only_slash( + mock_create, mock_tempdir, _mock_open, mock_time, mock_retrieve, aws_session +): + mock_retrieve.return_value = "00000000.dkr.ecr.us-west-2.amazonaws.com/latest" + + @hybrid_job(device="local:braket/default", aws_session=aws_session) def my_entry(pos_only, /): return "my entry return value" @@ -329,16 +377,22 @@ def my_entry(pos_only, /): job_name="my-entry-123000", hyperparameters={}, logger=getLogger("braket.jobs.hybrid_job"), + aws_session=aws_session, ) assert mock_tempdir.return_value.__exit__.called +@patch.object(sys.modules["braket.jobs.hybrid_job"], "retrieve_image") @patch("time.time", return_value=123.0) @patch("builtins.open") @patch("tempfile.TemporaryDirectory") @patch.object(AwsQuantumJob, "create") -def test_decorator_pos_only_args(mock_create, mock_tempdir, _mock_open, mock_time): - @hybrid_job(device="local:braket/default") +def test_decorator_pos_only_args( + mock_create, mock_tempdir, _mock_open, mock_time, mock_retrieve, aws_session +): + mock_retrieve.return_value = "00000000.dkr.ecr.us-west-2.amazonaws.com/latest" + + @hybrid_job(device="local:braket/default", aws_session=aws_session) def my_entry(*args): return "my entry return value" @@ -362,14 +416,15 @@ def my_entry(*args): job_name="my-entry-123000", hyperparameters={}, logger=getLogger("braket.jobs.hybrid_job"), + aws_session=aws_session, ) assert mock_tempdir.return_value.__exit__.called -def test_serialization_error(): +def 
def test_python_validation(aws_session):
    """Decorating a job raises when the container tag reports a different Python version."""
    aws_session.get_full_image_tag.return_value = "1.0-cpu-py38-ubuntu22.04"

    bad_version = (
        "Python version must match between local environment and container. "
        f"Client is running Python {sys.version_info.major}.{sys.version_info.minor} "
        "locally, but container uses Python 3.8."
    )
    # `match` is applied as a regular expression, so escape the literal message;
    # otherwise each "." would match any character.
    with pytest.raises(RuntimeError, match=re.escape(bad_version)):

        @hybrid_job(device=None, aws_session=aws_session)
        def my_job():
            pass