Skip to content

Commit

Permalink
fix: Allow identities in PauliString observable (#867)
Browse files Browse the repository at this point in the history
  • Loading branch information
speller26 authored Feb 2, 2024
1 parent 458249d commit e07672e
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 9 deletions.
20 changes: 18 additions & 2 deletions src/braket/quantum_information/pauli_string.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from typing import Optional, Union

from braket.circuits.circuit import Circuit
from braket.circuits.observables import TensorProduct, X, Y, Z
from braket.circuits.observables import I, TensorProduct, X, Y, Z

_IDENTITY = "I"
_PAULI_X = "X"
Expand All @@ -29,6 +29,7 @@
"Y": {"X": ["Z", -1j], "Z": ["X", 1j]},
"Z": {"X": ["Y", 1j], "Y": ["X", -1j]},
}
_ID_OBS = I()
_PAULI_OBSERVABLES = {_PAULI_X: X(), _PAULI_Y: Y(), _PAULI_Z: Z()}
_SIGN_MAP = {"+": 1, "-": -1}

Expand Down Expand Up @@ -74,14 +75,29 @@ def qubit_count(self) -> int:
"""int: The number of qubits this Pauli string acts on."""
return self._qubit_count

def to_unsigned_observable(self) -> TensorProduct:
def to_unsigned_observable(self, include_trivial: bool = False) -> TensorProduct:
"""Returns the observable corresponding to the unsigned part of the Pauli string.
For example, for a Pauli string -XYZ, the corresponding observable is X ⊗ Y ⊗ Z.
Args:
include_trivial (bool): Whether to include explicit identity factors in the observable.
Default: False.
Returns:
TensorProduct: The tensor product of the unsigned factors in the Pauli string.
"""
if include_trivial:
return TensorProduct(
[
(
_PAULI_OBSERVABLES[self._nontrivial[qubit]]
if qubit in self._nontrivial
else _ID_OBS
)
for qubit in range(self._qubit_count)
]
)
return TensorProduct(
[_PAULI_OBSERVABLES[self._nontrivial[qubit]] for qubit in sorted(self._nontrivial)]
)
Expand Down
16 changes: 9 additions & 7 deletions test/unit_tests/braket/quantum_information/test_pauli_string.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@

from braket.circuits import gates
from braket.circuits.circuit import Circuit
from braket.circuits.observables import X, Y, Z
from braket.circuits.observables import I, X, Y, Z
from braket.quantum_information import PauliString

ORDER = ["I", "X", "Y", "Z"]
Expand All @@ -34,15 +34,16 @@


@pytest.mark.parametrize(
"pauli_string, string, phase, observable",
"pauli_string, string, phase, observable, obs_with_id",
[
("+XZ", "+XZ", 1, X() @ Z()),
("-ZXY", "-ZXY", -1, Z() @ X() @ Y()),
("YIX", "+YIX", 1, Y() @ X()),
(PauliString("-ZYXI"), "-ZYXI", -1, Z() @ Y() @ X()),
("+XZ", "+XZ", 1, X() @ Z(), X() @ Z()),
("-ZXY", "-ZXY", -1, Z() @ X() @ Y(), Z() @ X() @ Y()),
("YIX", "+YIX", 1, Y() @ X(), Y() @ I() @ X()),
(PauliString("-ZYXI"), "-ZYXI", -1, Z() @ Y() @ X(), Z() @ Y() @ X() @ I()),
("IIXIIIYI", "+IIXIIIYI", 1, X() @ Y(), I() @ I() @ X() @ I() @ I() @ I() @ Y() @ I()),
],
)
def test_happy_case(pauli_string, string, phase, observable):
def test_happy_case(pauli_string, string, phase, observable, obs_with_id):
instance = PauliString(pauli_string)
assert str(instance) == string
assert instance.phase == phase
Expand All @@ -57,6 +58,7 @@ def test_happy_case(pauli_string, string, phase, observable):
assert instance == PauliString(pauli_string)
assert instance == PauliString(instance)
assert instance.to_unsigned_observable() == observable
assert instance.to_unsigned_observable(include_trivial=True) == obs_with_id


@pytest.mark.parametrize(
Expand Down

0 comments on commit e07672e

Please sign in to comment.