Skip to content

Commit

Permalink
fix: new env var names
Browse files Browse the repository at this point in the history
  • Loading branch information
mbeach-aws committed Apr 26, 2024
1 parent 5bc7a1c commit e8720f5
Show file tree
Hide file tree
Showing 4 changed files with 46 additions and 36 deletions.
6 changes: 2 additions & 4 deletions src/braket/aws/aws_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -236,22 +236,20 @@ def create_quantum_task(self, **boto3_kwargs) -> str:
str: The ARN of the quantum task.
"""
# Add reservation arn if available and device is correct.
device_arn = os.getenv("AMZN_BRAKET_DEVICE_ARN_TEMP")
reservation_arn = os.getenv("AMZN_BRAKET_RESERVATION_ARN")
device_arn = os.getenv("AMZN_BRAKET_RESERVATION_DEVICE_ARN")
reservation_arn = os.getenv("AMZN_BRAKET_RESERVATION_TIME_WINDOW_ARN")
if device_arn == boto3_kwargs["deviceArn"] and reservation_arn:
boto3_kwargs["associations"] = [
{
"arn": reservation_arn,
"type": "RESERVATION_TIME_WINDOW_ARN",
}
]
print(boto3_kwargs["associations"])

# Add job token to request, if available.
job_token = os.getenv("AMZN_BRAKET_JOB_TOKEN")
if job_token:
boto3_kwargs["jobToken"] = job_token
print("ARGS are", boto3_kwargs)
response = self.braket_client.create_quantum_task(**boto3_kwargs)
broadcast_event(
_TaskCreationEvent(
Expand Down
26 changes: 15 additions & 11 deletions src/braket/reservations/reservations.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,19 +17,20 @@
from contextlib import AbstractContextManager

from braket.aws import AwsDevice
from braket.devices import Device


class DirectReservation(AbstractContextManager):
"""
Modify AwsQuantumTasks created within this context to run on a device with a reservation
ARN.This is useful for ensuring that all quantum task
Context manager that modifies AwsQuantumTasks created within the context to use a reservation
ARN for all tasks targetting the specified device.
Reservations are AWS account and device specific. Only the AWS account that created the
reservation can use your reservation ARN. Additionally, the reservation ARN is only valid on the
reserved device at the chosen start and end times.
Args:
device (AwsDevice | str): The Braket device for which you have a reservation ARN, or
device (Device | str): The Braket device for which you have a reservation ARN, or
optionally the device ARN.
reservation_arn (str | None): The Braket Direct reservation ARN to be applied to all
quantum tasks run within the context.
Expand All @@ -50,11 +51,13 @@ class DirectReservation(AbstractContextManager):
[1] https://docs.aws.amazon.com/braket/latest/developerguide/braket-reservations.html
"""

def __init__(self, device: AwsDevice | str, reservation_arn: str | None):
def __init__(self, device: Device | str, reservation_arn: str | None):
if isinstance(device, AwsDevice):
self.device_arn = device.arn
elif isinstance(device, str):
self.device_arn = device
elif isinstance(device, Device): # LocalSimulator
self.device_arn = "" # instead of None, use empty string
else:
raise ValueError("device must be an AwsDevice or its ARN.")

Expand All @@ -71,16 +74,17 @@ def __exit__(self, exc_type, exc_val, exc_tb) -> None:
def start(self) -> None:
"""Start the reservation context."""
if self.context_active:
raise RuntimeError("Context is already active")
os.environ["AMZN_BRAKET_DEVICE_ARN_TEMP"] = self.device_arn
if self.reservation_arn is not None:
os.environ["AMZN_BRAKET_RESERVATION_ARN"] = self.reservation_arn
raise RuntimeError("Reservation context is already active.")

os.environ["AMZN_BRAKET_RESERVATION_DEVICE_ARN"] = self.device_arn
if self.reservation_arn:
os.environ["AMZN_BRAKET_RESERVATION_TIME_WINDOW_ARN"] = self.reservation_arn
self.context_active = True

def stop(self) -> None:
"""Stop the reservation context."""
if not self.context_active:
raise RuntimeError("Context is not active")
os.environ.pop("AMZN_BRAKET_DEVICE_ARN_TEMP", None)
os.environ.pop("AMZN_BRAKET_RESERVATION_ARN", None)
raise RuntimeError("Reservation context is not active.")
os.environ.pop("AMZN_BRAKET_RESERVATION_DEVICE_ARN", None)
os.environ.pop("AMZN_BRAKET_RESERVATION_TIME_WINDOW_ARN", None)
self.context_active = False
21 changes: 11 additions & 10 deletions test/integ_tests/test_reservation_arn.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from braket.circuits import Circuit
from braket.devices import Devices
from braket.jobs import get_job_device_arn, hybrid_job
from braket.reservations import DirectReservation


@pytest.fixture
Expand All @@ -36,23 +37,23 @@ def test_create_task_via_invalid_reservation_arn_on_qpu(reservation_arn):
device = AwsDevice(Devices.IonQ.Harmony)

with pytest.raises(ClientError, match="Reservation arn is invalid"):
device.run(
circuit,
shots=10,
reservation_arn=reservation_arn,
)
device.run(circuit, shots=10, reservation_arn=reservation_arn)

with pytest.raises(ClientError, match="Reservation arn is invalid"):
with DirectReservation(device, reservation_arn=reservation_arn):
device.run(circuit, shots=10)


def test_create_task_via_reservation_arn_on_simulator(reservation_arn):
circuit = Circuit().h(0)
device = AwsDevice(Devices.Amazon.SV1)

with pytest.raises(ClientError, match="Braket Direct is not supported for"):
device.run(
circuit,
shots=10,
reservation_arn=reservation_arn,
)
device.run(circuit, shots=10, reservation_arn=reservation_arn)

with pytest.raises(ClientError, match="Braket Direct is not supported for"):
with DirectReservation(device, reservation_arn=reservation_arn):
device.run(circuit, shots=10)


@pytest.mark.xfail(
Expand Down
29 changes: 18 additions & 11 deletions test/unit_tests/braket/reservations/test_reservations.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,8 @@
import pytest

from braket.aws import AwsDevice, AwsSession
from braket.reservations.reservations import DirectReservation
from braket.devices import LocalSimulator
from braket.reservations import DirectReservation


@pytest.fixture
Expand All @@ -39,34 +40,40 @@ def test_direct_reservation_with_string():
assert reservation.reservation_arn == "reservation_arn_example"


def test_reservation_local_device():
mock_device = MagicMock(spec=LocalSimulator)
with DirectReservation(mock_device, "reservation_arn_example") as reservation:
os.environ["AMZN_BRAKET_RESERVATION_DEVICE_ARN"] = ""


def test_direct_reservation_with_invalid_type():
"""Test initialization with an invalid type should raise ValueError."""
with pytest.raises(ValueError):
with pytest.raises(ValueError, match="device must be an AwsDevice or its ARN."):
DirectReservation(123, "reservation_arn_example")


def test_context_management(aws_device):
"""Test the context manager functionality."""
with DirectReservation(aws_device, "reservation_arn_example"):
assert os.getenv("AMZN_BRAKET_DEVICE_ARN_TEMP") == "device_arn_example"
assert os.getenv("AMZN_BRAKET_RESERVATION_ARN") == "reservation_arn_example"
assert os.getenv("AMZN_BRAKET_DEVICE_ARN_TEMP") is None
assert os.getenv("AMZN_BRAKET_RESERVATION_ARN") is None
assert os.getenv("AMZN_BRAKET_RESERVATION_DEVICE_ARN") == "device_arn_example"
assert os.getenv("AMZN_BRAKET_RESERVATION_TIME_WINDOW_ARN") == "reservation_arn_example"
assert os.getenv("AMZN_BRAKET_RESERVATION_DEVICE_ARN") is None
assert os.getenv("AMZN_BRAKET_RESERVATION_TIME_WINDOW_ARN") is None


def test_start_reservation_already_active(aws_device):
"""Test starting an already active context raises RuntimeError."""
reservation = DirectReservation(aws_device, "reservation_arn_example")
reservation.start()
with pytest.raises(RuntimeError):
with pytest.raises(RuntimeError, match="Reservation context is already active."):
reservation.start()
reservation.stop()


def test_stop_reservation_not_active(aws_device):
"""Test stopping a non-active context raises RuntimeError."""
reservation = DirectReservation(aws_device, "reservation_arn_example")
with pytest.raises(RuntimeError):
with pytest.raises(RuntimeError, match="Reservation context is not active."):
reservation.stop()


Expand All @@ -81,8 +88,8 @@ def test_multiple_start_stop_cycles(aws_device):
reservation.stop()
reservation.start()
reservation.stop()
assert os.getenv("AMZN_BRAKET_DEVICE_ARN_TEMP") is None
assert os.getenv("AMZN_BRAKET_RESERVATION_ARN") is None
assert os.getenv("AMZN_BRAKET_RESERVATION_DEVICE_ARN") is None
assert os.getenv("AMZN_BRAKET_RESERVATION_TIME_WINDOW_ARN") is None


def test_stop(aws_device):
Expand Down Expand Up @@ -117,7 +124,7 @@ def test_create_quantum_task_with_correct_device_and_reservation():
kwargs["associations"] = [
{
"arn": reservation_arn,
"type": "AMZN_BRAKET_RESERVATION_ARN",
"type": "RESERVATION_TIME_WINDOW_ARN",
}
]
mock_client.create_quantum_task.assert_called_once_with(**kwargs)

0 comments on commit e8720f5

Please sign in to comment.