From dd23bc9d1852e3cc56963c83f5052a628cd97d8d Mon Sep 17 00:00:00 2001 From: Tim Date: Mon, 13 Nov 2023 01:02:20 -0500 Subject: [PATCH] AwsDevice.run supports AutoQASM program --- src/braket/aws/aws_quantum_task.py | 41 +++++++++++++++++++ .../experimental/autoqasm/test_devices.py | 39 +++++++++++++++++- 2 files changed, 79 insertions(+), 1 deletion(-) diff --git a/src/braket/aws/aws_quantum_task.py b/src/braket/aws/aws_quantum_task.py index fa58d911d..c2cc42b03 100644 --- a/src/braket/aws/aws_quantum_task.py +++ b/src/braket/aws/aws_quantum_task.py @@ -34,6 +34,7 @@ IRType, OpenQASMSerializationProperties, QubitReferenceType, + SerializableProgram, ) from braket.device_schema import GateModelParameters from braket.device_schema.dwave import ( @@ -583,6 +584,46 @@ def _( return AwsQuantumTask(task_arn, aws_session, *args, **kwargs) +@_create_internal.register +def _( + serializable_program: SerializableProgram, + aws_session: AwsSession, + create_task_kwargs: dict[str, Any], + device_arn: str, + device_parameters: Union[dict, BraketSchemaBase], + _disable_qubit_rewiring: bool, + inputs: dict[str, float], + gate_definitions: Optional[dict[tuple[Gate, QubitSet], PulseSequence]], + *args, + **kwargs, +) -> AwsQuantumTask: + openqasm_program = OpenQASMProgram(source=serializable_program.to_ir(ir_type=IRType.OPENQASM)) + if inputs: + inputs_copy = openqasm_program.inputs.copy() if openqasm_program.inputs is not None else {} + inputs_copy.update(inputs) + openqasm_program = OpenQASMProgram( + source=openqasm_program.source, + inputs=inputs_copy, + ) + create_task_kwargs.update({"action": openqasm_program.json()}) + if device_parameters: + final_device_parameters = ( + _circuit_device_params_from_dict( + device_parameters, + device_arn, + GateModelParameters(qubitCount=0), # qubitCount unused + ) + if type(device_parameters) is dict + else device_parameters + ) + create_task_kwargs.update( + {"deviceParameters": final_device_parameters.json(exclude_none=True)} + ) + + task_arn = aws_session.create_quantum_task(**create_task_kwargs) + return AwsQuantumTask(task_arn, aws_session, *args, **kwargs) + + @_create_internal.register def _( blackbird_program: BlackbirdProgram, diff --git a/test/unit_tests/braket/experimental/autoqasm/test_devices.py b/test/unit_tests/braket/experimental/autoqasm/test_devices.py index 4313ffb83..ec7d869b7 100644 --- a/test/unit_tests/braket/experimental/autoqasm/test_devices.py +++ b/test/unit_tests/braket/experimental/autoqasm/test_devices.py @@ -14,6 +14,7 @@ """AutoQASM tests exercising device-specific targeting functionality. """ +import json from unittest.mock import Mock, patch import pytest @@ -24,7 +25,8 @@ from braket.device_schema.simulators import GateModelSimulatorDeviceCapabilities from braket.devices import Devices from braket.experimental.autoqasm import errors -from braket.experimental.autoqasm.instructions import cnot, cphaseshift00, h, x +from braket.experimental.autoqasm.instructions import cnot, cphaseshift00, h, rx, x +from braket.parametric import FreeParameter RIGETTI_REGION = "us-west-1" @@ -253,3 +255,38 @@ def my_program(): cnot("$5", "$2") assert my_program().to_ir() + + +@patch("braket.aws.aws_device.AwsSession.copy_session") +@patch("braket.aws.aws_device.AwsSession") +def test_aws_device_run( + aws_session_init: Mock, + mock_copy_session: Mock, + aws_session: Mock, +) -> None: + """Tests AwsDevice.run with AutoQASM program.""" + aws_session_init.return_value = aws_session + mock_copy_session.return_value = aws_session + + @aq.main + def my_program(): + h(0) + rx(0, FreeParameter("angle")) + + program = my_program() + aws_device = AwsDevice(Devices.Amazon.SV1.value) + _ = aws_device.run(program, shots=10, inputs={"angle": 0.123}) + + run_call_args = aws_session.create_quantum_task.mock_calls[0].kwargs + run_call_args_action = json.loads(run_call_args["action"]) + + expected_run_call_args = { + "deviceArn": "arn:aws:braket:::device/quantum-simulator/amazon/sv1", + "outputS3Bucket": "amazon-braket-us-test-1-00000000", + "outputS3KeyPrefix": "tasks", + "shots": 10, + } + aws_session.create_quantum_task.assert_called_once() + assert expected_run_call_args.items() <= run_call_args.items() + assert run_call_args_action["source"] == program.to_ir() + assert run_call_args_action["inputs"] == {"angle": 0.123}