From 12f1387e0c715b398ed102a996fd436b77f297a5 Mon Sep 17 00:00:00 2001 From: Jean-Christophe Jaskula <99367153+jcjaskula-aws@users.noreply.github.com> Date: Mon, 11 Dec 2023 19:00:57 -0500 Subject: [PATCH] fix: make filter more convenient (#718) * make filter more convenient * changes according to feedback * remove Union type hint --------- Co-authored-by: Abe Coull <85974725+math411@users.noreply.github.com> Co-authored-by: Cody Wang --- src/braket/circuits/gate_calibrations.py | 25 +++++++++++-------- .../braket/circuits/test_gate_calibration.py | 19 +++++++++++--- 2 files changed, 30 insertions(+), 14 deletions(-) diff --git a/src/braket/circuits/gate_calibrations.py b/src/braket/circuits/gate_calibrations.py index 6cbdd97d1..57013df4a 100644 --- a/src/braket/circuits/gate_calibrations.py +++ b/src/braket/circuits/gate_calibrations.py @@ -14,7 +14,7 @@ from __future__ import annotations from copy import deepcopy -from typing import Any, Optional +from typing import Any from braket.circuits.gate import Gate from braket.circuits.serialization import ( @@ -91,35 +91,40 @@ def __len__(self): return len(self._pulse_sequences) def filter( - self, gates: Optional[list[Gate]] = None, qubits: Optional[QubitSet] = None - ) -> Optional[GateCalibrations]: + self, + gates: list[Gate] | None = None, + qubits: QubitSet | list[QubitSet] | None = None, + ) -> GateCalibrations: """ Filters the data based on optional lists of gates and QubitSets. Args: - gates (Optional[list[Gate]]): An optional list of gates to filter on. - qubits (Optional[QubitSet]): An optional `QubitSet` to filter on. + gates (list[Gate] | None): An optional list of gates to filter on. + qubits (QubitSet | list[QubitSet] | None): An optional `QubitSet` or + list of `QubitSet` to filter on. Returns: - Optional[GateCalibrations]: A filtered GateCalibrations object. Otherwise, returns - none if no matches are found. + GateCalibrations: A filtered GateCalibrations object. """ # noqa: E501 keys = self.pulse_sequences.keys() + if isinstance(qubits, QubitSet): + qubits = [qubits] filtered_calibration_keys = [ tup for tup in keys - if (gates is None or tup[0] in gates) and (qubits is None or qubits.issubset(tup[1])) + if (gates is None or tup[0] in gates) + and (qubits is None or any(qset.issubset(tup[1]) for qset in qubits)) ] return GateCalibrations( {k: v for (k, v) in self.pulse_sequences.items() if k in filtered_calibration_keys}, ) - def to_ir(self, calibration_key: Optional[tuple[Gate, QubitSet]] = None) -> str: + def to_ir(self, calibration_key: tuple[Gate, QubitSet] | None = None) -> str: """ Returns the defcal representation for the `GateCalibrations` object. Args: - calibration_key (Optional[tuple[Gate, QubitSet]]): An optional key to get a specific defcal. + calibration_key (tuple[Gate, QubitSet] | None): An optional key to get a specific defcal. Default: None Returns: diff --git a/test/unit_tests/braket/circuits/test_gate_calibration.py b/test/unit_tests/braket/circuits/test_gate_calibration.py index 31c2384db..c95ce74a3 100644 --- a/test/unit_tests/braket/circuits/test_gate_calibration.py +++ b/test/unit_tests/braket/circuits/test_gate_calibration.py @@ -57,19 +57,30 @@ def test_gc_copy(pulse_sequence): def test_filter(pulse_sequence): - calibration_key = (Gate.Z(), QubitSet([0, 1])) - calibration_key_2 = (Gate.H(), QubitSet([0, 1])) + calibration_key = (Gate.Z(), QubitSet([0])) + calibration_key_2 = (Gate.H(), QubitSet([1])) + calibration_key_3 = (Gate.CZ(), QubitSet([0, 1])) calibration = GateCalibrations( - {calibration_key: pulse_sequence, calibration_key_2: pulse_sequence} + { + calibration_key: pulse_sequence, + calibration_key_2: pulse_sequence, + calibration_key_3: pulse_sequence, + } ) expected_calibration_1 = GateCalibrations({calibration_key: pulse_sequence}) expected_calibration_2 = GateCalibrations( - {calibration_key: pulse_sequence, calibration_key_2: pulse_sequence} + {calibration_key: pulse_sequence, calibration_key_3: pulse_sequence} ) expected_calibration_3 = GateCalibrations({calibration_key_2: pulse_sequence}) + expected_calibration_4 = GateCalibrations({}) + expected_calibration_5 = calibration + expected_calibration_6 = GateCalibrations({calibration_key_3: pulse_sequence}) assert expected_calibration_1 == calibration.filter(gates=[Gate.Z()]) assert expected_calibration_2 == calibration.filter(qubits=QubitSet(0)) assert expected_calibration_3 == calibration.filter(gates=[Gate.H()], qubits=QubitSet(1)) + assert expected_calibration_4 == calibration.filter(gates=[Gate.Z()], qubits=QubitSet(1)) + assert expected_calibration_5 == calibration.filter(qubits=[QubitSet(0), QubitSet(1)]) + assert expected_calibration_6 == calibration.filter(qubits=QubitSet([0, 1])) def test_to_ir(pulse_sequence):