Skip to content

Commit

Permalink
feat: Replace 'EmulationPass' with more general BasePass class
Browse files Browse the repository at this point in the history
  • Loading branch information
ltnln committed Aug 2, 2024
1 parent 9bddfbc commit 4626332
Show file tree
Hide file tree
Showing 9 changed files with 33 additions and 38 deletions.
3 changes: 1 addition & 2 deletions src/braket/aws/aws_device.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,11 +50,11 @@
from braket.devices import Devices
from braket.devices.device import Device
from braket.emulation import Emulator
from braket.emulation.emulation_passes import ProgramType
from braket.ir.blackbird import Program as BlackbirdProgram
from braket.ir.openqasm import Program as OpenQasmProgram
from braket.parametric.free_parameter import FreeParameter
from braket.parametric.free_parameter_expression import _is_float
from braket.passes import ProgramType
from braket.pulse import ArbitraryWaveform, Frame, Port, PulseSequence
from braket.pulse.waveforms import _parse_waveform_from_calibration_schema
from braket.schema_common import BraketSchemaBase
Expand Down Expand Up @@ -930,7 +930,6 @@ def validate(
"""
self.emulator.validate(task_specification)
return

def run_passes(
self, task_specification: ProgramType, apply_noise_model: bool = True
Expand Down
27 changes: 12 additions & 15 deletions src/braket/emulation/base_emulator.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,16 +2,17 @@

from typing import Iterable, Union

from braket.emulation.emulation_passes import EmulationPass, ProgramType, ValidationPass
from braket.emulation.emulation_passes import ValidationPass
from braket.passes import BasePass, ProgramType


class BaseEmulator:
def __init__(self, emulator_passes: Iterable[EmulationPass] = None):
def __init__(self, emulator_passes: Iterable[BasePass] = None):
self._emulator_passes = emulator_passes if emulator_passes is not None else []

def run_passes(self, task_specification: ProgramType) -> ProgramType:
"""
This method passes the input program through the EmulationPasses contained
This method passes the input program through the Passes contained
within this emulator. An emulator pass may simply validate a program or may
modify or entirely transform the program (to an equivalent quantum program).
Expand All @@ -28,7 +29,7 @@ def run_passes(self, task_specification: ProgramType) -> ProgramType:

def validate(self, task_specification: ProgramType) -> None:
"""
This method passes the input program through EmulationPasses that perform
This method passes the input program through Passes that perform
only validation, without modifying the input program.
Args:
Expand All @@ -39,31 +40,27 @@ def validate(self, task_specification: ProgramType) -> None:
if isinstance(emulator_pass, ValidationPass):
emulator_pass(task_specification)

def add_pass(
self, emulator_pass: Union[Iterable[EmulationPass], EmulationPass]
) -> BaseEmulator:
def add_pass(self, emulator_pass: Union[Iterable[BasePass], BasePass]) -> BaseEmulator:
"""
Append a new EmulationPass or a list of EmulationPass objects.
Append a new BasePass or a list of BasePass objects.
Args:
emulator_pass (Union[Iterable[EmulationPass], EmulationPass]): Either a
single EmulationPass object or a list of EmulationPass objects that
emulator_pass (Union[Iterable[BasePass], BasePass]): Either a
single Pass object or a list of Pass objects that
will be used in validation and program compilation passes by this
emulator.
Returns:
BaseEmulator: Returns an updated self.
Raises:
TypeError: If the input is not an iterable or an EmulationPass.
TypeError: If the input is not an iterable or an Pass.
"""
if isinstance(emulator_pass, Iterable):
self._emulator_passes.extend(emulator_pass)
elif isinstance(emulator_pass, EmulationPass):
elif isinstance(emulator_pass, BasePass):
self._emulator_passes.append(emulator_pass)
else:
raise TypeError(
"emulator_pass must be an EmulationPass or an iterable of EmulationPass"
)
raise TypeError("emulator_pass must be an Pass or an iterable of Pass")
return self
4 changes: 0 additions & 4 deletions src/braket/emulation/emulation_passes/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1 @@
from braket.emulation.emulation_passes.emulation_pass import ( # noqa: F401
EmulationPass,
ProgramType,
)
from braket.emulation.emulation_passes.validation_pass import ValidationPass # noqa: F401
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from collections.abc import Iterator
from collections.abc import Iterable
from typing import Optional

from braket.circuits import Circuit
Expand All @@ -11,14 +11,14 @@
class GateValidator(ValidationPass[Circuit]):
def __init__(
self,
supported_gates: Optional[Iterator[str]] = None,
native_gates: Optional[Iterator[str]] = None,
supported_gates: Optional[Iterable[str]] = None,
native_gates: Optional[Iterable[str]] = None,
):
"""
Args:
supported_gates (Optional[Iterator[str]]): A list of gates supported outside of
supported_gates (Optional[Iterable[str]]): A list of gates supported outside of
verbatim modeby the emulator. A gate is a Braket gate name.
native_gates (Optional[Iterator[str]]): A list of gates supported inside of
native_gates (Optional[Iterable[str]]): A list of gates supported inside of
verbatim mode by the emulator.
Raises:
Expand All @@ -30,12 +30,14 @@ def __init__(
raise ValueError("Supported gate set or native gate set must be provided.")

try:
self._supported_gates = set(BRAKET_GATES[gate.lower()] for gate in supported_gates)
self._supported_gates = frozenset(
BRAKET_GATES[gate.lower()] for gate in supported_gates
)
except KeyError as e:
raise ValueError(f"Input {str(e)} in supported_gates is not a valid Braket gate name.")

try:
self._native_gates = set(BRAKET_GATES[gate.lower()] for gate in native_gates)
self._native_gates = frozenset(BRAKET_GATES[gate.lower()] for gate in native_gates)
except KeyError as e:
raise ValueError(f"Input {str(e)} in native_gates is not a valid Braket gate name.")

Expand Down
4 changes: 2 additions & 2 deletions src/braket/emulation/emulation_passes/validation_pass.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,10 @@

from abc import abstractmethod

from braket.emulation.emulation_passes.emulation_pass import EmulationPass, ProgramType
from braket.passes.base_pass import BasePass, ProgramType


class ValidationPass(EmulationPass[ProgramType]):
class ValidationPass(BasePass[ProgramType]):
@abstractmethod
def validate(self, program: ProgramType) -> None:
"""
Expand Down
10 changes: 5 additions & 5 deletions src/braket/emulation/emulator.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,8 @@
from braket.devices import Device
from braket.devices.local_simulator import LocalSimulator
from braket.emulation.base_emulator import BaseEmulator
from braket.emulation.emulation_passes import EmulationPass, ProgramType
from braket.ir.openqasm import Program as OpenQasmProgram
from braket.passes import BasePass, ProgramType
from braket.tasks import QuantumTask
from braket.tasks.quantum_task_batch import QuantumTaskBatch

Expand All @@ -26,7 +26,7 @@ def __init__(
self,
backend: str = "default",
noise_model: Optional[NoiseModel] = None,
emulator_passes: Iterable[EmulationPass] = None,
emulator_passes: Iterable[BasePass] = None,
**kwargs,
):
Device.__init__(self, name=kwargs.get("name", "DeviceEmulator"), status="AVAILABLE")
Expand Down Expand Up @@ -143,13 +143,13 @@ def run_passes(
self, task_specification: ProgramType, apply_noise_model: bool = True
) -> ProgramType:
"""
Passes the input program through all EmulationPass objects contained in this
Passes the input program through all Pass objects contained in this
emulator and applies the emulator's noise model, if it exists, before
returning the compiled program.
Args:
task_specification (ProgramType): The input program to validate and
compile based on this emulator's EmulationPasses
compile based on this emulator's Passes
apply_noise_model (bool): If true, apply this emulator's noise model
to the compiled program before returning the final program.
Expand All @@ -167,7 +167,7 @@ def run_passes(

def validate(self, task_specification: ProgramType) -> None:
"""
Runs only EmulationPasses that are ValidationPass, i.e. all non-modifying
Runs only Passes that are ValidationPass, i.e. all non-modifying
validation passes on the input program.
Args:
Expand Down
1 change: 1 addition & 0 deletions src/braket/passes/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from braket.passes.base_pass import BasePass, ProgramType # noqa: F40
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
ProgramType = TypeVar("ProgramType")


class EmulationPass(ABC, Generic[ProgramType]):
class BasePass(ABC, Generic[ProgramType]):
@abstractmethod
def run(self, program: ProgramType) -> ProgramType:
"""
Expand Down
4 changes: 2 additions & 2 deletions test/unit_tests/braket/emulation/test_emulator.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,11 @@
from braket.default_simulator import DensityMatrixSimulator, StateVectorSimulator
from braket.devices import local_simulator
from braket.emulation import Emulator
from braket.emulation.emulation_passes import EmulationPass, ProgramType
from braket.emulation.emulation_passes.gate_device_passes import GateValidator, QubitCountValidator
from braket.passes import BasePass, ProgramType


class AlwaysFailPass(EmulationPass[ProgramType]):
class AlwaysFailPass(BasePass[ProgramType]):
def run(self, program: ProgramType):
raise ValueError("This pass always raises an error.")

Expand Down

0 comments on commit 4626332

Please sign in to comment.