From df9f2957877da022e543bcd2c08840c30f49d660 Mon Sep 17 00:00:00 2001 From: Matt Beach Date: Thu, 25 Apr 2024 15:19:41 -0400 Subject: [PATCH] fix: address comments --- examples/reservation.py | 10 +++- src/braket/aws/aws_session.py | 11 +++- src/braket/reservations/reservations.py | 25 ++++---- test/integ_tests/test_reservation_arn.py | 21 +++---- .../braket/reservations/test_reservations.py | 59 +++++++++++++------ 5 files changed, 81 insertions(+), 45 deletions(-) diff --git a/examples/reservation.py b/examples/reservation.py index 26449fcb8..31e34cf6a 100644 --- a/examples/reservation.py +++ b/examples/reservation.py @@ -21,11 +21,15 @@ # To run a task in a device reservation, change the device to the one you reserved # and fill in your reservation ARN. -with DirectReservation(device, reservation_arn="reservation ARN"): +with DirectReservation(device, reservation_arn=""): task = device.run(bell, shots=100) print(task.result().measurement_counts) # Alternatively, you may start the reservation globally -DirectReservation(device, reservation_arn="reservation ARN").start() -task = device.run(bell, shots=100, reservation_arn="reservation ARN") +DirectReservation(device, reservation_arn="").start() +task = device.run(bell, shots=100) +print(task.result().measurement_counts) + +# Lastly, you may pass the reservation ARN directly to a quantum task +task = device.run(bell, shots=100, reservation_arn="") print(task.result().measurement_counts) diff --git a/src/braket/aws/aws_session.py b/src/braket/aws/aws_session.py index 732abef7d..205e1b0af 100644 --- a/src/braket/aws/aws_session.py +++ b/src/braket/aws/aws_session.py @@ -236,10 +236,15 @@ def create_quantum_task(self, **boto3_kwargs) -> str: str: The ARN of the quantum task. """ # Add reservation arn if available and device is correct. - device_arn = os.getenv("AMZN_BRAKET_DEVICE_ARN_TEMP") - reservation_arn = os.getenv("AMZN_BRAKET_RESERVATION_ARN_TEMP") + device_arn = os.getenv("AMZN_BRAKET_RESERVATION_DEVICE_ARN") + reservation_arn = os.getenv("AMZN_BRAKET_RESERVATION_TIME_WINDOW_ARN") if device_arn == boto3_kwargs["deviceArn"] and reservation_arn: - boto3_kwargs["reservation_arn"] = reservation_arn + boto3_kwargs["associations"] = [ + { + "arn": reservation_arn, + "type": "RESERVATION_TIME_WINDOW_ARN", + } + ] # Add job token to request, if available. job_token = os.getenv("AMZN_BRAKET_JOB_TOKEN") diff --git a/src/braket/reservations/reservations.py b/src/braket/reservations/reservations.py index 74d9d00fd..9ce1e2806 100644 --- a/src/braket/reservations/reservations.py +++ b/src/braket/reservations/reservations.py @@ -17,19 +17,20 @@ from contextlib import AbstractContextManager from braket.aws import AwsDevice +from braket.devices import Device class DirectReservation(AbstractContextManager): """ - Modify AwsQuantumTasks created within this context to run on a device with a reservation - ARN.This is useful for ensuring that all quantum task + Context manager that modifies AwsQuantumTasks created within the context to use a reservation + ARN for all tasks targeting the specified device. Reservations are AWS account and device specific. Only the AWS account that created the reservation can use your reservation ARN. Additionally, the reservation ARN is only valid on the reserved device at the chosen start and end times. Args: - device (AwsDevice | str): The Braket device for which you have a reservation ARN, or + device (Device | str): The Braket device for which you have a reservation ARN, or optionally the device ARN. reservation_arn (str | None): The Braket Direct reservation ARN to be applied to all quantum tasks run within the context. @@ -50,11 +51,13 @@ class DirectReservation(AbstractContextManager): [1] https://docs.aws.amazon.com/braket/latest/developerguide/braket-reservations.html """ - def __init__(self, device: AwsDevice | str, reservation_arn: str | None): + def __init__(self, device: Device | str, reservation_arn: str | None): if isinstance(device, AwsDevice): self.device_arn = device.arn elif isinstance(device, str): self.device_arn = device + elif isinstance(device, Device): # LocalSimulator + self.device_arn = "" # instead of None, use empty string else: raise ValueError("device must be an AwsDevice or its ARN.") @@ -71,15 +74,17 @@ def __exit__(self, exc_type, exc_val, exc_tb) -> None: def start(self) -> None: """Start the reservation context.""" if self.context_active: - raise RuntimeError("Context is already active") - os.environ["AMZN_BRAKET_DEVICE_ARN_TEMP"] = self.device_arn - os.environ["AMZN_BRAKET_RESERVATION_ARN_TEMP"] = self.reservation_arn + raise RuntimeError("Reservation context is already active.") + + os.environ["AMZN_BRAKET_RESERVATION_DEVICE_ARN"] = self.device_arn + if self.reservation_arn: + os.environ["AMZN_BRAKET_RESERVATION_TIME_WINDOW_ARN"] = self.reservation_arn self.context_active = True def stop(self) -> None: """Stop the reservation context.""" if not self.context_active: - raise RuntimeError("Context is not active") - os.environ.pop("AMZN_BRAKET_DEVICE_ARN_TEMP", None) - os.environ.pop("AMZN_BRAKET_RESERVATION_ARN_TEMP", None) + raise RuntimeError("Reservation context is not active.") + os.environ.pop("AMZN_BRAKET_RESERVATION_DEVICE_ARN", None) + os.environ.pop("AMZN_BRAKET_RESERVATION_TIME_WINDOW_ARN", None) self.context_active = False diff --git a/test/integ_tests/test_reservation_arn.py b/test/integ_tests/test_reservation_arn.py index 98b87f075..7d6d5b9ce 100644 --- a/test/integ_tests/test_reservation_arn.py +++ b/test/integ_tests/test_reservation_arn.py @@ -21,6 +21,7 @@ from braket.circuits import Circuit from braket.devices import Devices from braket.jobs import get_job_device_arn, hybrid_job +from braket.reservations import DirectReservation @pytest.fixture @@ -36,11 +37,11 @@ def test_create_task_via_invalid_reservation_arn_on_qpu(reservation_arn): device = AwsDevice(Devices.IonQ.Harmony) with pytest.raises(ClientError, match="Reservation arn is invalid"): - device.run( - circuit, - shots=10, - reservation_arn=reservation_arn, - ) + device.run(circuit, shots=10, reservation_arn=reservation_arn) + + with pytest.raises(ClientError, match="Reservation arn is invalid"): + with DirectReservation(device, reservation_arn=reservation_arn): + device.run(circuit, shots=10) def test_create_task_via_reservation_arn_on_simulator(reservation_arn): @@ -48,11 +49,11 @@ def test_create_task_via_reservation_arn_on_simulator(reservation_arn): device = AwsDevice(Devices.Amazon.SV1) with pytest.raises(ClientError, match="Braket Direct is not supported for"): - device.run( - circuit, - shots=10, - reservation_arn=reservation_arn, - ) + device.run(circuit, shots=10, reservation_arn=reservation_arn) + + with pytest.raises(ClientError, match="Braket Direct is not supported for"): + with DirectReservation(device, reservation_arn=reservation_arn): + device.run(circuit, shots=10) @pytest.mark.xfail( diff --git a/test/unit_tests/braket/reservations/test_reservations.py b/test/unit_tests/braket/reservations/test_reservations.py index 5b6ecf56d..7abbf3dce 100644 --- a/test/unit_tests/braket/reservations/test_reservations.py +++ b/test/unit_tests/braket/reservations/test_reservations.py @@ -17,48 +17,55 @@ import pytest from braket.aws import AwsDevice, AwsSession -from braket.reservations.reservations import DirectReservation +from braket.devices import LocalSimulator +from braket.reservations import DirectReservation @pytest.fixture def aws_device(): mock_device = MagicMock(spec=AwsDevice) - mock_device._arn = "device_arn_example" + mock_device.arn = "device_arn_example" return mock_device def test_direct_reservation_with_device_object(aws_device): - reservation = DirectReservation(aws_device, "reservation_arn_example") - assert reservation.device_arn == "device_arn_example" - assert reservation.reservation_arn == "reservation_arn_example" + with DirectReservation(aws_device, "reservation_arn_example") as reservation: + assert reservation.device_arn == "device_arn_example" + assert reservation.reservation_arn == "reservation_arn_example" def test_direct_reservation_with_string(): - reservation = DirectReservation("my:string:arn", "reservation_arn_example") - assert reservation.device_arn == "my:string:arn" - assert reservation.reservation_arn == "reservation_arn_example" + with DirectReservation("my:string:arn", "reservation_arn_example") as reservation: + assert reservation.device_arn == "my:string:arn" + assert reservation.reservation_arn == "reservation_arn_example" + + +def test_reservation_local_device(): + mock_device = MagicMock(spec=LocalSimulator) + with DirectReservation(mock_device, "reservation_arn_example"): + os.environ["AMZN_BRAKET_RESERVATION_DEVICE_ARN"] = "" def test_direct_reservation_with_invalid_type(): """Test initialization with an invalid type should raise ValueError.""" - with pytest.raises(ValueError): + with pytest.raises(ValueError, match="device must be an AwsDevice or its ARN."): DirectReservation(123, "reservation_arn_example") def test_context_management(aws_device): """Test the context manager functionality.""" with DirectReservation(aws_device, "reservation_arn_example"): - assert os.getenv("AMZN_BRAKET_DEVICE_ARN_TEMP") == "device_arn_example" - assert os.getenv("AMZN_BRAKET_RESERVATION_ARN_TEMP") == "reservation_arn_example" - assert os.getenv("AMZN_BRAKET_DEVICE_ARN_TEMP") is None - assert os.getenv("AMZN_BRAKET_RESERVATION_ARN_TEMP") is None + assert os.getenv("AMZN_BRAKET_RESERVATION_DEVICE_ARN") == "device_arn_example" + assert os.getenv("AMZN_BRAKET_RESERVATION_TIME_WINDOW_ARN") == "reservation_arn_example" + assert os.getenv("AMZN_BRAKET_RESERVATION_DEVICE_ARN") is None + assert os.getenv("AMZN_BRAKET_RESERVATION_TIME_WINDOW_ARN") is None def test_start_reservation_already_active(aws_device): """Test starting an already active context raises RuntimeError.""" reservation = DirectReservation(aws_device, "reservation_arn_example") reservation.start() - with pytest.raises(RuntimeError): + with pytest.raises(RuntimeError, match="Reservation context is already active."): reservation.start() reservation.stop() @@ -66,12 +73,12 @@ def test_start_reservation_already_active(aws_device): def test_stop_reservation_not_active(aws_device): """Test stopping a non-active context raises RuntimeError.""" reservation = DirectReservation(aws_device, "reservation_arn_example") - with pytest.raises(RuntimeError): + with pytest.raises(RuntimeError, match="Reservation context is not active."): reservation.stop() def test_start_without_device_arn(): - with pytest.raises(ValueError, match="Device ARN must be an AwsDevice or string."): + with pytest.raises(ValueError, match="device must be an AwsDevice or its ARN."): DirectReservation(None, "reservation_arn_example") @@ -81,8 +88,17 @@ def test_multiple_start_stop_cycles(aws_device): reservation.stop() reservation.start() reservation.stop() - assert os.getenv("AMZN_BRAKET_DEVICE_ARN_TEMP") is None - assert os.getenv("AMZN_BRAKET_RESERVATION_ARN_TEMP") is None + assert os.getenv("AMZN_BRAKET_RESERVATION_DEVICE_ARN") is None + assert os.getenv("AMZN_BRAKET_RESERVATION_TIME_WINDOW_ARN") is None + + +def test_stop(aws_device): + with pytest.raises(RuntimeError, match="Reservation context is not active."): + DirectReservation(aws_device, "reservation_arn_example").stop() + + +def test_reservation_none(aws_device): + DirectReservation(aws_device, reservation_arn=None).start() def test_create_quantum_task_with_correct_device_and_reservation(): @@ -105,5 +121,10 @@ def test_create_quantum_task_with_correct_device_and_reservation(): with DirectReservation(device_arn, reservation_arn): aws_session.create_quantum_task(**kwargs) - kwargs["reservation_arn"] = reservation_arn + kwargs["associations"] = [ + { + "arn": reservation_arn, + "type": "RESERVATION_TIME_WINDOW_ARN", + } + ] mock_client.create_quantum_task.assert_called_once_with(**kwargs)