Skip to content

Commit

Permalink
make filter more convenient
Browse files Browse the repository at this point in the history
  • Loading branch information
jcjaskula-aws committed Oct 29, 2023
1 parent 10f4995 commit 4acc53f
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 8 deletions.
14 changes: 10 additions & 4 deletions src/braket/circuits/gate_calibrations.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from __future__ import annotations

from copy import deepcopy
from typing import Any, Optional
from typing import Any, Optional, Union

from braket.circuits.gate import Gate
from braket.circuits.serialization import (
Expand Down Expand Up @@ -91,24 +91,30 @@ def __len__(self):
return len(self._pulse_sequences)

def filter(
self, gates: Optional[list[Gate]] = None, qubits: Optional[QubitSet] = None
self,
gates: Optional[list[Gate]] = None,
qubits: Optional[Union[QubitSet, list[QubitSet]]] = None,
) -> Optional[GateCalibrations]:
"""
Filters the data based on optional lists of gates and QubitSets.
Args:
gates (Optional[list[Gate]]): An optional list of gates to filter on.
qubits (Optional[QubitSet]): An optional `QubitSet` to filter on.
qubits (Optional[Union[QubitSet, list[QubitSet]]]): An optional `QubitSet` or
list of `QubitSet` to filter on.
Returns:
Optional[GateCalibrations]: A filtered GateCalibrations object. Otherwise, returns
none if no matches are found.
""" # noqa: E501
keys = self.pulse_sequences.keys()
if isinstance(qubits, QubitSet):
qubits = [qubits]
filtered_calibration_keys = [
tup
for tup in keys
if (gates is None or tup[0] in gates) and (qubits is None or qubits.issubset(tup[1]))
if (gates is None or tup[0] in gates)
and (qubits is None or any(qset.issubset(tup[1]) for qset in qubits))
]
return GateCalibrations(
{k: v for (k, v) in self.pulse_sequences.items() if k in filtered_calibration_keys},
Expand Down
19 changes: 15 additions & 4 deletions test/unit_tests/braket/circuits/test_gate_calibration.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,19 +57,30 @@ def test_gc_copy(pulse_sequence):


def test_filter(pulse_sequence):
calibration_key = (Gate.Z(), QubitSet([0, 1]))
calibration_key_2 = (Gate.H(), QubitSet([0, 1]))
calibration_key = (Gate.Z(), QubitSet([0]))
calibration_key_2 = (Gate.H(), QubitSet([1]))
calibration_key_3 = (Gate.CZ(), QubitSet([0, 1]))
calibration = GateCalibrations(
{calibration_key: pulse_sequence, calibration_key_2: pulse_sequence}
{
calibration_key: pulse_sequence,
calibration_key_2: pulse_sequence,
calibration_key_3: pulse_sequence,
}
)
expected_calibration_1 = GateCalibrations({calibration_key: pulse_sequence})
expected_calibration_2 = GateCalibrations(
{calibration_key: pulse_sequence, calibration_key_2: pulse_sequence}
{calibration_key: pulse_sequence, calibration_key_3: pulse_sequence}
)
expected_calibration_3 = GateCalibrations({calibration_key_2: pulse_sequence})
expected_calibration_4 = GateCalibrations({})
expected_calibration_5 = calibration
expected_calibration_6 = GateCalibrations({calibration_key_3: pulse_sequence})
assert expected_calibration_1 == calibration.filter(gates=[Gate.Z()])
assert expected_calibration_2 == calibration.filter(qubits=QubitSet(0))
assert expected_calibration_3 == calibration.filter(gates=[Gate.H()], qubits=QubitSet(1))
assert expected_calibration_4 == calibration.filter(gates=[Gate.Z()], qubits=QubitSet(1))
assert expected_calibration_5 == calibration.filter(qubits=[QubitSet(0), QubitSet(1)])
assert expected_calibration_6 == calibration.filter(qubits=QubitSet([0, 1]))


def test_to_ir(pulse_sequence):
Expand Down

0 comments on commit 4acc53f

Please sign in to comment.