diff --git a/src/braket/circuits/circuit.py b/src/braket/circuits/circuit.py index 9ce7df7f8..02d419cff 100644 --- a/src/braket/circuits/circuit.py +++ b/src/braket/circuits/circuit.py @@ -319,8 +319,8 @@ def add_result_type( observable = Circuit._extract_observable(result_type_to_add) # We can skip this for now for AdjointGradient (the only subtype of this # type) because AdjointGradient can only be used when `shots=0`, and the - # qubit_observable_mapping is used to generate basis rotation instrunctions - # and make sure the observables are simultaneously commuting for `shots>0` mode. + # qubit_observable_mapping is used to generate basis rotation instructions + # and make sure the observables mutually commute for `shots>0` mode. supports_basis_rotation_instructions = not isinstance( result_type_to_add, ObservableParameterResultType ) diff --git a/src/braket/circuits/observable.py b/src/braket/circuits/observable.py index 11ed02eb6..e0572e3d9 100644 --- a/src/braket/circuits/observable.py +++ b/src/braket/circuits/observable.py @@ -40,20 +40,18 @@ def __init__( self, qubit_count: int, ascii_symbols: Sequence[str], targets: QubitSetInput | None = None ): super().__init__(qubit_count=qubit_count, ascii_symbols=ascii_symbols) - if targets is not None: - targets = QubitSet(targets) + targets = QubitSet(targets) + if targets: if (num_targets := len(targets)) != qubit_count: raise ValueError( f"Length of target {num_targets} does not match qubit count {qubit_count}" ) - self._targets = targets - else: - self._targets = None + self._targets = targets self._coef = 1 def _unscaled(self) -> Observable: return Observable( - qubit_count=self.qubit_count, ascii_symbols=self.ascii_symbols, targets=self.targets + qubit_count=self.qubit_count, ascii_symbols=self.ascii_symbols, targets=self._targets ) def to_ir( @@ -207,7 +205,7 @@ def __sub__(self, other: Observable): def __repr__(self) -> str: return ( f"{self.name}('qubit_count': {self._qubit_count})" - if self._targets is None + if not self._targets else f"{self.name}('qubit_count': {self._qubit_count}, 'target': {self._targets})" ) diff --git a/src/braket/circuits/observables.py b/src/braket/circuits/observables.py index 9346435a2..a9039ee4c 100644 --- a/src/braket/circuits/observables.py +++ b/src/braket/circuits/observables.py @@ -508,9 +508,9 @@ def __init__(self, observables: list[Observable], display_name: str = "Hamiltoni self._summands = tuple(flattened_observables) qubit_count = max(flattened_observables, key=lambda obs: obs.qubit_count).qubit_count all_targets = [observable.targets for observable in flattened_observables] - if all(targets is None for targets in all_targets): - targets = None - elif all(targets is not None for targets in all_targets): + if not any(all_targets): + targets = QubitSet() + elif all(all_targets): targets = all_targets else: raise ValueError("Cannot mix terms with and without targets") diff --git a/src/braket/circuits/result_type.py b/src/braket/circuits/result_type.py index a7ce2436c..8343429e0 100644 --- a/src/braket/circuits/result_type.py +++ b/src/braket/circuits/result_type.py @@ -243,7 +243,7 @@ def observable(self) -> Observable: @property def target(self) -> QubitSet: - return self._target + return self._target or self._observable.targets @target.setter def target(self, target: QubitSetInput) -> None: diff --git a/src/braket/circuits/result_types.py b/src/braket/circuits/result_types.py index 8af36ab4a..6c1232b06 100644 --- a/src/braket/circuits/result_types.py +++ b/src/braket/circuits/result_types.py @@ -218,7 +218,7 @@ def __init__( def _to_openqasm(self, serialization_properties: OpenQASMSerializationProperties) -> str: observable_ir = self.observable.to_ir( - target=self.target, + target=self._target, ir_type=IRType.OPENQASM, serialization_properties=serialization_properties, ) @@ -477,7 +477,7 @@ def _to_jaqcd(self) -> ir.Expectation: def _to_openqasm(self, serialization_properties: OpenQASMSerializationProperties) -> str: observable_ir = self.observable.to_ir( - target=self.target, + target=self._target, ir_type=IRType.OPENQASM, serialization_properties=serialization_properties, ) @@ -552,7 +552,7 @@ def _to_jaqcd(self) -> ir.Sample: def _to_openqasm(self, serialization_properties: OpenQASMSerializationProperties) -> str: observable_ir = self.observable.to_ir( - target=self.target, + target=self._target, ir_type=IRType.OPENQASM, serialization_properties=serialization_properties, ) @@ -632,7 +632,7 @@ def _to_jaqcd(self) -> ir.Variance: def _to_openqasm(self, serialization_properties: OpenQASMSerializationProperties) -> str: observable_ir = self.observable.to_ir( - target=self.target, + target=self._target, ir_type=IRType.OPENQASM, serialization_properties=serialization_properties, ) diff --git a/test/unit_tests/braket/circuits/test_result_type.py b/test/unit_tests/braket/circuits/test_result_type.py index bc7cc0909..bb543b7c0 100644 --- a/test/unit_tests/braket/circuits/test_result_type.py +++ b/test/unit_tests/braket/circuits/test_result_type.py @@ -18,6 +18,7 @@ from braket.circuits.free_parameter import FreeParameter from braket.circuits.result_type import ObservableParameterResultType from braket.circuits.serialization import IRType +from braket.registers import QubitSet @pytest.fixture @@ -168,6 +169,15 @@ def test_obs_rt_repr(): ) +def test_obs_rt_target(): + assert ObservableResultType( + ascii_symbols=["Obs"], observable=Observable.X(), target=1 + ).target == QubitSet(1) + assert ObservableResultType( + ascii_symbols=["Obs"], observable=Observable.X(1) + ).target == QubitSet(1) + + @pytest.mark.parametrize( "ir_type, serialization_properties, expected_exception, expected_message", [