Skip to content

Commit

Permalink
feat: add support for the ARN region (#977)
Browse files Browse the repository at this point in the history
Co-authored-by: Coull <accoull@amazon.com>
Co-authored-by: Tim (Yi-Ting) <yitchen@amazon.com>
  • Loading branch information
3 people authored May 22, 2024
1 parent 7ba54ce commit 035409a
Show file tree
Hide file tree
Showing 9 changed files with 89 additions and 43 deletions.
2 changes: 1 addition & 1 deletion src/braket/aws/aws_device.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ class AwsDevice(Device):
device.
"""

REGIONS = ("us-east-1", "us-west-1", "us-west-2", "eu-west-2")
REGIONS = ("us-east-1", "us-west-1", "us-west-2", "eu-west-2", "eu-north-1")

DEFAULT_SHOTS_QPU = 1000
DEFAULT_SHOTS_SIMULATOR = 0
Expand Down
4 changes: 4 additions & 0 deletions src/braket/devices/devices.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,9 @@ class _DWave(str, Enum):
_Advantage6 = "arn:aws:braket:us-west-2::device/qpu/d-wave/Advantage_system6"
_DW2000Q6 = "arn:aws:braket:::device/qpu/d-wave/DW_2000Q_6"

class _IQM(str, Enum):
Garnet = "arn:aws:braket:eu-north-1::device/qpu/iqm/Garnet"

class _IonQ(str, Enum):
Harmony = "arn:aws:braket:us-east-1::device/qpu/ionq/Harmony"
Aria1 = "arn:aws:braket:us-east-1::device/qpu/ionq/Aria-1"
Expand Down Expand Up @@ -54,6 +57,7 @@ class _Xanadu(str, Enum):
Amazon = _Amazon
# DWave = _DWave
IonQ = _IonQ
IQM = _IQM
OQC = _OQC
QuEra = _QuEra
Rigetti = _Rigetti
Expand Down
3 changes: 2 additions & 1 deletion src/braket/jobs/image_uri_config/base.json
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
"us-east-1",
"us-west-1",
"us-west-2",
"eu-west-2"
"eu-west-2",
"eu-north-1"
]
}
3 changes: 2 additions & 1 deletion src/braket/jobs/image_uri_config/pl_pytorch.json
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
"us-east-1",
"us-west-1",
"us-west-2",
"eu-west-2"
"eu-west-2",
"eu-north-1"
]
}
3 changes: 2 additions & 1 deletion src/braket/jobs/image_uri_config/pl_tensorflow.json
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
"us-east-1",
"us-west-1",
"us-west-2",
"eu-west-2"
"eu-west-2",
"eu-north-1"
]
}
55 changes: 35 additions & 20 deletions test/integ_tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ def pytest_configure_node(node):
node.workerinput["JOB_FAILED_NAME"] = job_fail_name
if endpoint := os.getenv("BRAKET_ENDPOINT"):
node.workerinput["BRAKET_ENDPOINT"] = endpoint
node.workerinput["AWS_REGION"] = os.getenv("AWS_REGION")


def pytest_xdist_node_collection_finished(ids):
Expand All @@ -48,8 +49,11 @@ def pytest_xdist_node_collection_finished(ids):
"""
run_jobs = any("job" in test for test in ids)
profile_name = os.environ["AWS_PROFILE"]
aws_session = AwsSession(boto3.session.Session(profile_name=profile_name))
if run_jobs and os.getenv("JOBS_STARTED") is None:
region_name = os.getenv("AWS_REGION")
aws_session = AwsSession(
boto3.session.Session(profile_name=profile_name, region_name=region_name)
)
if run_jobs and os.getenv("JOBS_STARTED") is None and region_name != "eu-north-1":
AwsQuantumJob.create(
"arn:aws:braket:::device/quantum-simulator/amazon/sv1",
job_name=job_fail_name,
Expand All @@ -72,9 +76,10 @@ def pytest_xdist_node_collection_finished(ids):


@pytest.fixture(scope="session")
def boto_session():
def boto_session(request):
profile_name = os.environ["AWS_PROFILE"]
return boto3.session.Session(profile_name=profile_name)
region_name = request.config.workerinput["AWS_REGION"]
return boto3.session.Session(profile_name=profile_name, region_name=region_name)


@pytest.fixture(scope="session")
Expand Down Expand Up @@ -137,9 +142,11 @@ def s3_destination_folder(s3_bucket, s3_prefix):

@pytest.fixture(scope="session")
def braket_simulators(aws_session):
return {
simulator_arn: AwsDevice(simulator_arn, aws_session) for simulator_arn in SIMULATOR_ARNS
}
return (
{simulator_arn: AwsDevice(simulator_arn, aws_session) for simulator_arn in SIMULATOR_ARNS}
if aws_session.region != "eu-north-1"
else None
)


@pytest.fixture(scope="session")
Expand All @@ -164,21 +171,29 @@ def job_failed_name(request):

@pytest.fixture(scope="session", autouse=True)
def completed_quantum_job(job_completed_name):
job_arn = [
job["jobArn"]
for job in boto3.client("braket").search_jobs(filters=[])["jobs"]
if job["jobName"] == job_completed_name
][0]
job_arn = (
[
job["jobArn"]
for job in boto3.client("braket").search_jobs(filters=[])["jobs"]
if job["jobName"] == job_completed_name
][0]
if os.getenv("JOBS_STARTED")
else None
)

return AwsQuantumJob(arn=job_arn)
return AwsQuantumJob(arn=job_arn) if os.getenv("JOBS_STARTED") else None


@pytest.fixture(scope="session", autouse=True)
def failed_quantum_job(job_failed_name):
job_arn = [
job["jobArn"]
for job in boto3.client("braket").search_jobs(filters=[])["jobs"]
if job["jobName"] == job_failed_name
][0]

return AwsQuantumJob(arn=job_arn)
job_arn = (
[
job["jobArn"]
for job in boto3.client("braket").search_jobs(filters=[])["jobs"]
if job["jobName"] == job_failed_name
][0]
if os.getenv("JOBS_STARTED")
else None
)

return AwsQuantumJob(arn=job_arn) if os.getenv("JOBS_STARTED") else None
37 changes: 20 additions & 17 deletions test/integ_tests/test_cost_tracking.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
import pytest
from botocore.exceptions import ClientError

from braket.aws import AwsDevice, AwsSession
from braket.aws import AwsDevice, AwsDeviceType, AwsSession
from braket.circuits import Circuit
from braket.tracking import Tracker
from braket.tracking.tracker import MIN_SIMULATOR_DURATION
Expand Down Expand Up @@ -93,23 +93,26 @@ def test_all_devices_price_search():
s = AwsSession(boto3.Session(region_name=region))
# Skip devices with empty execution windows
for device in [device for device in devices if device.properties.service.executionWindows]:
try:
s.get_device(device.arn)

# If we are here, device can create tasks in region
details = {
"shots": 100,
"device": device.arn,
"billed_duration": MIN_SIMULATOR_DURATION,
"job_task": False,
"status": "COMPLETED",
}
tasks[f"task:for:{device.name}:{region}"] = details.copy()
details["job_task"] = True
tasks[f"jobtask:for:{device.name}:{region}"] = details
except s.braket_client.exceptions.ResourceNotFoundException:
# device does not exist in region, so nothing to test
if region == "eu-north-1" and device.type == AwsDeviceType.SIMULATOR:
pass
else:
try:
s.get_device(device.arn)

# If we are here, device can create tasks in region
details = {
"shots": 100,
"device": device.arn,
"billed_duration": MIN_SIMULATOR_DURATION,
"job_task": False,
"status": "COMPLETED",
}
tasks[f"task:for:{device.name}:{region}"] = details.copy()
details["job_task"] = True
tasks[f"jobtask:for:{device.name}:{region}"] = details
except s.braket_client.exceptions.ResourceNotFoundException:
# device does not exist in region, so nothing to test
pass

t = Tracker()
t._resources = tasks
Expand Down
24 changes: 22 additions & 2 deletions test/unit_tests/braket/aws/test_aws_device.py
Original file line number Diff line number Diff line change
Expand Up @@ -1753,6 +1753,16 @@ def test_get_devices(mock_copy_session, aws_session):
"providerName": "OQC",
}
],
# eu-north-1
[
{
"deviceArn": SV1_ARN,
"deviceName": "SV1",
"deviceType": "SIMULATOR",
"deviceStatus": "ONLINE",
"providerName": "Amazon Braket",
},
],
# Only two regions to search outside of current
ValueError("should not be reachable"),
]
Expand All @@ -1763,7 +1773,7 @@ def test_get_devices(mock_copy_session, aws_session):
ValueError("should not be reachable"),
]
mock_copy_session.return_value = session_for_region
# Search order: us-east-1, us-west-1, us-west-2, eu-west-2
# Search order: us-east-1, us-west-1, us-west-2, eu-west-2, eu-north-1
results = AwsDevice.get_devices(
arns=[SV1_ARN, DWAVE_ARN, IONQ_ARN, OQC_ARN],
provider_names=["Amazon Braket", "D-Wave", "IonQ", "OQC"],
Expand Down Expand Up @@ -1858,6 +1868,16 @@ def test_get_devices_with_error_in_region(mock_copy_session, aws_session):
"providerName": "OQC",
}
],
# eu-north-1
[
{
"deviceArn": SV1_ARN,
"deviceName": "SV1",
"deviceType": "SIMULATOR",
"deviceStatus": "ONLINE",
"providerName": "Amazon Braket",
},
],
# Only two regions to search outside of current
ValueError("should not be reachable"),
]
Expand All @@ -1867,7 +1887,7 @@ def test_get_devices_with_error_in_region(mock_copy_session, aws_session):
ValueError("should not be reachable"),
]
mock_copy_session.return_value = session_for_region
# Search order: us-east-1, us-west-1, us-west-2, eu-west-2
# Search order: us-east-1, us-west-1, us-west-2, eu-west-2, eu-north-1
results = AwsDevice.get_devices(
statuses=["ONLINE"],
aws_session=aws_session,
Expand Down
1 change: 1 addition & 0 deletions tox.ini
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ deps =
{[test-deps]deps}
passenv =
AWS_PROFILE
AWS_REGION
BRAKET_ENDPOINT
commands =
pytest test/integ_tests {posargs}
Expand Down

0 comments on commit 035409a

Please sign in to comment.