From 7ba54cedaa300f55dd6031a47f66d3fce7eec17f Mon Sep 17 00:00:00 2001 From: Ryan Shaffer <3620100+rmshaffer@users.noreply.github.com> Date: Tue, 21 May 2024 12:54:59 -0400 Subject: [PATCH] feature: Add support for SerializableProgram abstraction to Device interface (#976) --- src/braket/aws/aws_quantum_task.py | 29 +++++++++ src/braket/circuits/serialization.py | 21 +++++++ src/braket/devices/local_simulator.py | 60 +++++++++++++++---- .../braket/aws/test_aws_quantum_task.py | 28 +++++++++ .../braket/devices/test_local_simulator.py | 38 ++++++++++++ 5 files changed, 164 insertions(+), 12 deletions(-) diff --git a/src/braket/aws/aws_quantum_task.py b/src/braket/aws/aws_quantum_task.py index 17ae29252..a21c7782c 100644 --- a/src/braket/aws/aws_quantum_task.py +++ b/src/braket/aws/aws_quantum_task.py @@ -34,6 +34,7 @@ IRType, OpenQASMSerializationProperties, QubitReferenceType, + SerializableProgram, ) from braket.device_schema import GateModelParameters from braket.device_schema.dwave import ( @@ -623,6 +624,34 @@ def _( return AwsQuantumTask(task_arn, aws_session, *args, **kwargs) +@_create_internal.register +def _( + serializable_program: SerializableProgram, + aws_session: AwsSession, + create_task_kwargs: dict[str, Any], + device_arn: str, + device_parameters: Union[dict, BraketSchemaBase], + _disable_qubit_rewiring: bool, + inputs: dict[str, float], + gate_definitions: Optional[dict[tuple[Gate, QubitSet], PulseSequence]], + *args, + **kwargs, +) -> AwsQuantumTask: + openqasm_program = OpenQASMProgram(source=serializable_program.to_ir(ir_type=IRType.OPENQASM)) + return _create_internal( + openqasm_program, + aws_session, + create_task_kwargs, + device_arn, + device_parameters, + _disable_qubit_rewiring, + inputs, + gate_definitions, + *args, + **kwargs, + ) + + @_create_internal.register def _( blackbird_program: BlackbirdProgram, diff --git a/src/braket/circuits/serialization.py b/src/braket/circuits/serialization.py index afcb5d118..fdee7d144 100644 --- a/src/braket/circuits/serialization.py +++ b/src/braket/circuits/serialization.py @@ -11,6 +11,7 @@ # ANY KIND, either express or implied. See the License for the specific # language governing permissions and limitations under the License. +from abc import ABC, abstractmethod from dataclasses import dataclass from enum import Enum @@ -32,6 +33,26 @@ class QubitReferenceType(str, Enum): PHYSICAL = "PHYSICAL" +class SerializableProgram(ABC): + @abstractmethod + def to_ir( + self, + ir_type: IRType = IRType.OPENQASM, + ) -> str: + """Serializes the program into an intermediate representation. + + Args: + ir_type (IRType): The IRType to use for converting the program to its + IR representation. Defaults to IRType.OPENQASM. + + Raises: + ValueError: Raised if the supplied `ir_type` is not supported. + + Returns: + str: A representation of the program in the `ir_type` format. + """ + + @dataclass class OpenQASMSerializationProperties: """Properties for serializing a circuit to OpenQASM. diff --git a/src/braket/devices/local_simulator.py b/src/braket/devices/local_simulator.py index 1dec56d37..15ec904de 100644 --- a/src/braket/devices/local_simulator.py +++ b/src/braket/devices/local_simulator.py @@ -25,11 +25,11 @@ from braket.circuits import Circuit from braket.circuits.circuit_helpers import validate_circuit_and_shots from braket.circuits.noise_model import NoiseModel -from braket.circuits.serialization import IRType +from braket.circuits.serialization import IRType, SerializableProgram from braket.device_schema import DeviceActionType, DeviceCapabilities from braket.devices.device import Device from braket.ir.ahs import Program as AHSProgram -from braket.ir.openqasm import Program +from braket.ir.openqasm import Program as OpenQASMProgram from braket.simulator import BraketSimulator from braket.tasks import AnnealingQuantumTaskResult, GateModelQuantumTaskResult from braket.tasks.analog_hamiltonian_simulation_quantum_task_result import ( @@ -80,7 +80,9 @@ def __init__( def run( self, - task_specification: Union[Circuit, Problem, Program, AnalogHamiltonianSimulation], + task_specification: Union[ + Circuit, Problem, OpenQASMProgram, AnalogHamiltonianSimulation, SerializableProgram + ], shots: int = 0, inputs: Optional[dict[str, float]] = None, *args: Any, @@ -89,7 +91,7 @@ def run( """Runs the given task with the wrapped local simulator. Args: - task_specification (Union[Circuit, Problem, Program, AnalogHamiltonianSimulation]): + task_specification (Union[Circuit, Problem, OpenQASMProgram, AnalogHamiltonianSimulation, SerializableProgram]): # noqa E501 The quantum task specification. shots (int): The number of times to run the circuit or annealing problem. Default is 0, which means that the simulator will compute the exact @@ -122,8 +124,18 @@ def run( def run_batch( # noqa: C901 self, task_specifications: Union[ - Union[Circuit, Problem, Program, AnalogHamiltonianSimulation], - list[Union[Circuit, Problem, Program, AnalogHamiltonianSimulation]], + Union[ + Circuit, Problem, OpenQASMProgram, AnalogHamiltonianSimulation, SerializableProgram + ], + list[ + Union[ + Circuit, + Problem, + OpenQASMProgram, + AnalogHamiltonianSimulation, + SerializableProgram, + ] + ], ], shots: Optional[int] = 0, max_parallel: Optional[int] = None, @@ -134,7 +146,7 @@ def run_batch( # noqa: C901 """Executes a batch of quantum tasks in parallel Args: - task_specifications (Union[Union[Circuit, Problem, Program, AnalogHamiltonianSimulation], list[Union[Circuit, Problem, Program, AnalogHamiltonianSimulation]]]): + task_specifications (Union[Union[Circuit, Problem, OpenQASMProgram, AnalogHamiltonianSimulation, SerializableProgram], list[Union[Circuit, Problem, OpenQASMProgram, AnalogHamiltonianSimulation, SerializableProgram]]]): # noqa Single instance or list of quantum task specification. shots (Optional[int]): The number of times to run the quantum task. Default: 0. @@ -163,7 +175,7 @@ def run_batch( # noqa: C901 single_task = isinstance( task_specifications, - (Circuit, Program, Problem, AnalogHamiltonianSimulation), + (Circuit, OpenQASMProgram, Problem, AnalogHamiltonianSimulation), ) single_input = isinstance(inputs, dict) @@ -220,7 +232,9 @@ def registered_backends() -> set[str]: def _run_internal_wrap( self, - task_specification: Union[Circuit, Problem, Program, AnalogHamiltonianSimulation], + task_specification: Union[ + Circuit, Problem, OpenQASMProgram, AnalogHamiltonianSimulation, SerializableProgram + ], shots: Optional[int] = None, inputs: Optional[dict[str, float]] = None, *args, @@ -250,7 +264,12 @@ def _(self, backend_impl: BraketSimulator): def _run_internal( self, task_specification: Union[ - Circuit, Problem, Program, AnalogHamiltonianSimulation, AHSProgram + Circuit, + Problem, + OpenQASMProgram, + AnalogHamiltonianSimulation, + AHSProgram, + SerializableProgram, ], shots: Optional[int] = None, *args, @@ -296,7 +315,7 @@ def _(self, problem: Problem, shots: Optional[int] = None, *args, **kwargs): @_run_internal.register def _( self, - program: Program, + program: OpenQASMProgram, shots: Optional[int] = None, inputs: Optional[dict[str, float]] = None, *args, @@ -308,13 +327,30 @@ def _( if inputs: inputs_copy = program.inputs.copy() if program.inputs is not None else {} inputs_copy.update(inputs) - program = Program( + program = OpenQASMProgram( source=program.source, inputs=inputs_copy, ) + results = simulator.run(program, shots, *args, **kwargs) + + if isinstance(results, GateModelQuantumTaskResult): + return results + return GateModelQuantumTaskResult.from_object(results) + @_run_internal.register + def _( + self, + program: SerializableProgram, + shots: Optional[int] = None, + inputs: Optional[dict[str, float]] = None, + *args, + **kwargs, + ): + program = OpenQASMProgram(source=program.to_ir(ir_type=IRType.OPENQASM)) + return self._run_internal(program, shots, inputs, *args, **kwargs) + @_run_internal.register def _( self, diff --git a/test/unit_tests/braket/aws/test_aws_quantum_task.py b/test/unit_tests/braket/aws/test_aws_quantum_task.py index 16a72da7a..28032d943 100644 --- a/test/unit_tests/braket/aws/test_aws_quantum_task.py +++ b/test/unit_tests/braket/aws/test_aws_quantum_task.py @@ -33,6 +33,7 @@ IRType, OpenQASMSerializationProperties, QubitReferenceType, + SerializableProgram, ) from braket.device_schema import GateModelParameters, error_mitigation from braket.device_schema.dwave import ( @@ -123,6 +124,19 @@ def openqasm_program(): return OpenQASMProgram(source="OPENQASM 3.0; h $0;") +class DummySerializableProgram(SerializableProgram): + def __init__(self, source: str): + self.source = source + + def to_ir(self, ir_type: IRType = IRType.OPENQASM) -> str: + return self.source + + +@pytest.fixture +def serializable_program(): + return DummySerializableProgram(source="OPENQASM 3.0; h $0;") + + @pytest.fixture def blackbird_program(): return BlackbirdProgram(source="Vac | q[0]") @@ -614,6 +628,20 @@ def test_create_openqasm_program_em_serialized(aws_session, arn, openqasm_progra ) +def test_create_serializable_program(aws_session, arn, serializable_program): + aws_session.create_quantum_task.return_value = arn + shots = 21 + AwsQuantumTask.create(aws_session, SIMULATOR_ARN, serializable_program, S3_TARGET, shots) + + _assert_create_quantum_task_called_with( + aws_session, + SIMULATOR_ARN, + OpenQASMProgram(source=serializable_program.to_ir()).json(), + S3_TARGET, + shots, + ) + + def test_create_blackbird_program(aws_session, arn, blackbird_program): aws_session.create_quantum_task.return_value = arn shots = 21 diff --git a/test/unit_tests/braket/devices/test_local_simulator.py b/test/unit_tests/braket/devices/test_local_simulator.py index a7e8bfe17..216d161c7 100644 --- a/test/unit_tests/braket/devices/test_local_simulator.py +++ b/test/unit_tests/braket/devices/test_local_simulator.py @@ -27,6 +27,7 @@ from braket.annealing import Problem, ProblemType from braket.circuits import Circuit, FreeParameter, Gate, Noise from braket.circuits.noise_model import GateCriteria, NoiseModel, NoiseModelInstruction +from braket.circuits.serialization import IRType, SerializableProgram from braket.device_schema import DeviceActionType, DeviceCapabilities from braket.device_schema.openqasm_device_action_properties import OpenQASMDeviceActionProperties from braket.devices import LocalSimulator, local_simulator @@ -250,6 +251,24 @@ def properties(self) -> DeviceCapabilities: return device_properties +class DummySerializableProgram(SerializableProgram): + def __init__(self, source: str): + self.source = source + + def to_ir(self, ir_type: IRType = IRType.OPENQASM) -> str: + return self.source + + +class DummySerializableProgramSimulator(DummyProgramSimulator): + def run( + self, + program: SerializableProgram, + shots: int = 0, + batch_size: int = 1, + ) -> GateModelQuantumTaskResult: + return GateModelQuantumTaskResult.from_object(GATE_MODEL_RESULT) + + class DummyProgramDensityMatrixSimulator(BraketSimulator): def run( self, program: ir.openqasm.Program, shots: Optional[int], *args, **kwargs @@ -556,6 +575,25 @@ def test_run_program_model(): assert task.result() == GateModelQuantumTaskResult.from_object(GATE_MODEL_RESULT) +def test_run_serializable_program_model(): + dummy = DummySerializableProgramSimulator() + sim = LocalSimulator(dummy) + task = sim.run( + DummySerializableProgram( + source=""" +qubit[2] q; +bit[2] c; + +h q[0]; +cnot q[0], q[1]; + +c = measure q; +""" + ) + ) + assert task.result() == GateModelQuantumTaskResult.from_object(GATE_MODEL_RESULT) + + @pytest.mark.xfail(raises=ValueError) def test_run_gate_model_value_error(): dummy = DummyCircuitSimulator()