Skip to content

Commit

Permalink
fix: make filter more convenient (#718)
Browse files Browse the repository at this point in the history
* make filter more convenient

* changes according to feedback

* remove Union type hint

---------

Co-authored-by: Abe Coull <85974725+math411@users.noreply.github.com>
Co-authored-by: Cody Wang <speller26@gmail.com>
  • Loading branch information
3 people authored Dec 12, 2023
1 parent 7787d27 commit 12f1387
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 14 deletions.
25 changes: 15 additions & 10 deletions src/braket/circuits/gate_calibrations.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from __future__ import annotations

from copy import deepcopy
from typing import Any, Optional
from typing import Any

from braket.circuits.gate import Gate
from braket.circuits.serialization import (
Expand Down Expand Up @@ -91,35 +91,40 @@ def __len__(self):
return len(self._pulse_sequences)

def filter(
self, gates: Optional[list[Gate]] = None, qubits: Optional[QubitSet] = None
) -> Optional[GateCalibrations]:
self,
gates: list[Gate] | None = None,
qubits: QubitSet | list[QubitSet] | None = None,
) -> GateCalibrations:
"""
Filters the data based on optional lists of gates and QubitSets.
Args:
gates (Optional[list[Gate]]): An optional list of gates to filter on.
qubits (Optional[QubitSet]): An optional `QubitSet` to filter on.
gates (list[Gate] | None): An optional list of gates to filter on.
qubits (QubitSet | list[QubitSet] | None): An optional `QubitSet` or
list of `QubitSet` to filter on.
Returns:
Optional[GateCalibrations]: A filtered GateCalibrations object. Otherwise, returns
none if no matches are found.
GateCalibrations: A filtered GateCalibrations object.
""" # noqa: E501
keys = self.pulse_sequences.keys()
if isinstance(qubits, QubitSet):
qubits = [qubits]
filtered_calibration_keys = [
tup
for tup in keys
if (gates is None or tup[0] in gates) and (qubits is None or qubits.issubset(tup[1]))
if (gates is None or tup[0] in gates)
and (qubits is None or any(qset.issubset(tup[1]) for qset in qubits))
]
return GateCalibrations(
{k: v for (k, v) in self.pulse_sequences.items() if k in filtered_calibration_keys},
)

def to_ir(self, calibration_key: Optional[tuple[Gate, QubitSet]] = None) -> str:
def to_ir(self, calibration_key: tuple[Gate, QubitSet] | None = None) -> str:
"""
Returns the defcal representation for the `GateCalibrations` object.
Args:
calibration_key (Optional[tuple[Gate, QubitSet]]): An optional key to get a specific defcal.
calibration_key (tuple[Gate, QubitSet] | None): An optional key to get a specific defcal.
Default: None
Returns:
Expand Down
19 changes: 15 additions & 4 deletions test/unit_tests/braket/circuits/test_gate_calibration.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,19 +57,30 @@ def test_gc_copy(pulse_sequence):


def test_filter(pulse_sequence):
calibration_key = (Gate.Z(), QubitSet([0, 1]))
calibration_key_2 = (Gate.H(), QubitSet([0, 1]))
calibration_key = (Gate.Z(), QubitSet([0]))
calibration_key_2 = (Gate.H(), QubitSet([1]))
calibration_key_3 = (Gate.CZ(), QubitSet([0, 1]))
calibration = GateCalibrations(
{calibration_key: pulse_sequence, calibration_key_2: pulse_sequence}
{
calibration_key: pulse_sequence,
calibration_key_2: pulse_sequence,
calibration_key_3: pulse_sequence,
}
)
expected_calibration_1 = GateCalibrations({calibration_key: pulse_sequence})
expected_calibration_2 = GateCalibrations(
{calibration_key: pulse_sequence, calibration_key_2: pulse_sequence}
{calibration_key: pulse_sequence, calibration_key_3: pulse_sequence}
)
expected_calibration_3 = GateCalibrations({calibration_key_2: pulse_sequence})
expected_calibration_4 = GateCalibrations({})
expected_calibration_5 = calibration
expected_calibration_6 = GateCalibrations({calibration_key_3: pulse_sequence})
assert expected_calibration_1 == calibration.filter(gates=[Gate.Z()])
assert expected_calibration_2 == calibration.filter(qubits=QubitSet(0))
assert expected_calibration_3 == calibration.filter(gates=[Gate.H()], qubits=QubitSet(1))
assert expected_calibration_4 == calibration.filter(gates=[Gate.Z()], qubits=QubitSet(1))
assert expected_calibration_5 == calibration.filter(qubits=[QubitSet(0), QubitSet(1)])
assert expected_calibration_6 == calibration.filter(qubits=QubitSet([0, 1]))


def test_to_ir(pulse_sequence):
Expand Down

0 comments on commit 12f1387

Please sign in to comment.