Skip to content

Commit

Permalink
feat: add U and GPhase gates (#799)
Browse files Browse the repository at this point in the history
* Add U gate

* modification according to feedback

* fix linters

* clean commented code

* first version of gphase

* handle drawing and control global phase

* Adding a global phase attribute

* add global phase to circuit unitary

* first draft tests to check coverage

* add more tests

* add test with parametric global phase

* add test for neg control qubit printing

* clean up

* simplify ctrl-gphase transform

* feat: add str, repr and getitem to BasisState

* add repr coverage

* add index

* add pop

* fix phase target qubit

* fix typing

* add index and pop tests

* fix code coverage

* move unitary matrices

* use a subindex in MomentKey

* print global phase integration

* fix docstrings

* fix circuits zero total global phase

* fix edge cases

* fix to_unitary

* temporary fix that checks classname

* clean up test conditions

* change logic according to feedback

* update docstring

* clean tests

* update tests

* replace control symbols

* use box drawing characters

* Revert "use box drawing characters"

This reverts commit ccb81fa.

* Revert "replace control symbols"

This reverts commit 4efb8bc.

* simplify all gphase case

* change preprare_y_axis function name

* create an helper function to compute the global phase

* make control_basis_state more explicit

* add comment and clean grouping

* add test_only_neg_control_qubits

* parametrize test_one_gate_with_global_phase

* reformat

* change to printing with fixed precision

* fix docstring

---------

Co-authored-by: Aaron Berdy <berdy@amazon.com>
  • Loading branch information
jcjaskula-aws and ajberdy authored Dec 20, 2023
1 parent d45c9c8 commit 5d8d04c
Show file tree
Hide file tree
Showing 11 changed files with 639 additions and 30 deletions.
148 changes: 131 additions & 17 deletions src/braket/circuits/ascii_circuit_diagram.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,10 @@
from braket.circuits.compiler_directive import CompilerDirective
from braket.circuits.gate import Gate
from braket.circuits.instruction import Instruction
from braket.circuits.moments import MomentType
from braket.circuits.noise import Noise
from braket.circuits.result_type import ResultType
from braket.registers.qubit import Qubit
from braket.registers.qubit_set import QubitSet


Expand All @@ -44,23 +46,26 @@ def build_diagram(circuit: cir.Circuit) -> str:
if not circuit.instructions:
return ""

if all(m.moment_type == MomentType.GLOBAL_PHASE for m in circuit._moments):
return f"Global phase: {circuit.global_phase}"

circuit_qubits = circuit.qubits
circuit_qubits.sort()

# Y Axis Column
y_axis_width = len(str(int(max(circuit_qubits))))
y_axis_str = "{0:{width}} : |\n".format("T", width=y_axis_width + 1)
for qubit in circuit_qubits:
y_axis_str += "{0:{width}}\n".format(" ", width=y_axis_width + 5)
y_axis_str += "q{0:{width}} : -\n".format(str(int(qubit)), width=y_axis_width)
y_axis_str, global_phase = AsciiCircuitDiagram._prepare_diagram_vars(
circuit, circuit_qubits
)

time_slices = circuit.moments.time_slices()
column_strs = []

# Moment columns
for time, instructions in time_slices.items():
global_phase = AsciiCircuitDiagram._compute_moment_global_phase(
global_phase, instructions
)
moment_str = AsciiCircuitDiagram._ascii_diagram_column_set(
str(time), circuit_qubits, instructions
str(time), circuit_qubits, instructions, global_phase
)
column_strs.append(moment_str)

Expand All @@ -71,7 +76,7 @@ def build_diagram(circuit: cir.Circuit) -> str:
if target_result_types:
column_strs.append(
AsciiCircuitDiagram._ascii_diagram_column_set(
"Result Types", circuit_qubits, target_result_types
"Result Types", circuit_qubits, target_result_types, global_phase
)
)

Expand All @@ -84,6 +89,9 @@ def build_diagram(circuit: cir.Circuit) -> str:
# Time on top and bottom
lines.append(lines[0])

if global_phase:
lines.append(f"\nGlobal phase: {global_phase}")

# Additional result types line on bottom
if additional_result_types:
lines.append(f"\nAdditional result types: {', '.join(additional_result_types)}")
Expand All @@ -97,6 +105,49 @@ def build_diagram(circuit: cir.Circuit) -> str:

return "\n".join(lines)

@staticmethod
def _prepare_diagram_vars(
circuit: cir.Circuit, circuit_qubits: QubitSet
) -> tuple[str, float | None]:
# Y Axis Column
y_axis_width = len(str(int(max(circuit_qubits))))
y_axis_str = "{0:{width}} : |\n".format("T", width=y_axis_width + 1)

global_phase = None
if any(m.moment_type == MomentType.GLOBAL_PHASE for m in circuit._moments):
y_axis_str += "{0:{width}} : |\n".format("GP", width=y_axis_width)
global_phase = 0

for qubit in circuit_qubits:
y_axis_str += "{0:{width}}\n".format(" ", width=y_axis_width + 5)
y_axis_str += "q{0:{width}} : -\n".format(str(int(qubit)), width=y_axis_width)

return y_axis_str, global_phase

@staticmethod
def _compute_moment_global_phase(
global_phase: float | None, items: list[Instruction]
) -> float | None:
"""
Compute the integrated phase at a certain moment.
Args:
global_phase (float | None): The integrated phase up to the computed moment
items (list[Instruction]): list of instructions
Returns:
float | None: The updated integrated phase.
"""
moment_phase = 0
for item in items:
if (
isinstance(item, Instruction)
and isinstance(item.operator, Gate)
and item.operator.name == "GPhase"
):
moment_phase += item.operator.angle
return global_phase + moment_phase if global_phase is not None else None

@staticmethod
def _ascii_group_items(
circuit_qubits: QubitSet,
Expand All @@ -120,7 +171,15 @@ def _ascii_group_items(
):
continue

if (isinstance(item, ResultType) and not item.target) or (
# As a zero-qubit gate, GPhase can be grouped with anything. We set qubit_range
# to an empty list and we just add it to the first group below.
if (
isinstance(item, Instruction)
and isinstance(item.operator, Gate)
and item.operator.name == "GPhase"
):
qubit_range = QubitSet()
elif (isinstance(item, ResultType) and not item.target) or (
isinstance(item, Instruction) and isinstance(item.operator, CompilerDirective)
):
qubit_range = circuit_qubits
Expand Down Expand Up @@ -175,7 +234,10 @@ def _categorize_result_types(

@staticmethod
def _ascii_diagram_column_set(
col_title: str, circuit_qubits: QubitSet, items: list[Union[Instruction, ResultType]]
col_title: str,
circuit_qubits: QubitSet,
items: list[Union[Instruction, ResultType]],
global_phase: float | None,
) -> str:
"""
Return a set of columns in the ASCII string diagram of the circuit for a list of items.
Expand All @@ -184,6 +246,7 @@ def _ascii_diagram_column_set(
col_title (str): title of column set
circuit_qubits (QubitSet): qubits in circuit
items (list[Union[Instruction, ResultType]]): list of instructions or result types
global_phase (float | None): the integrated global phase up to this set
Returns:
str: An ASCII string diagram for the column set.
Expand All @@ -193,7 +256,7 @@ def _ascii_diagram_column_set(
groupings = AsciiCircuitDiagram._ascii_group_items(circuit_qubits, items)

column_strs = [
AsciiCircuitDiagram._ascii_diagram_column(circuit_qubits, grouping[1])
AsciiCircuitDiagram._ascii_diagram_column(circuit_qubits, grouping[1], global_phase)
for grouping in groupings
]

Expand All @@ -220,17 +283,20 @@ def _ascii_diagram_column_set(

@staticmethod
def _ascii_diagram_column(
circuit_qubits: QubitSet, items: list[Union[Instruction, ResultType]]
circuit_qubits: QubitSet,
items: list[Union[Instruction, ResultType]],
global_phase: float | None = None,
) -> str:
"""
Return a column in the ASCII string diagram of the circuit for a given list of items.
Args:
circuit_qubits (QubitSet): qubits in circuit
items (list[Union[Instruction, ResultType]]): list of instructions or result types
global_phase (float | None): the integrated global phase up to this column
Returns:
str: An ASCII string diagram for the specified moment in time for a column.
str: an ASCII string diagram for the specified moment in time for a column.
"""
symbols = {qubit: "-" for qubit in circuit_qubits}
margins = {qubit: " " for qubit in circuit_qubits}
Expand All @@ -252,12 +318,26 @@ def _ascii_diagram_column(
num_after = len(circuit_qubits) - 1
after = ["|"] * (num_after - 1) + ([marker] if num_after else [])
ascii_symbols = [ascii_symbol] + after
elif (
isinstance(item, Instruction)
and isinstance(item.operator, Gate)
and item.operator.name == "GPhase"
):
target_qubits = circuit_qubits
control_qubits = QubitSet()
target_and_control = QubitSet()
qubits = circuit_qubits
ascii_symbols = "-" * len(circuit_qubits)
else:
if isinstance(item.target, list):
target_qubits = reduce(QubitSet.union, map(QubitSet, item.target), QubitSet())
else:
target_qubits = item.target
control_qubits = getattr(item, "control", QubitSet())
map_control_qubit_states = AsciiCircuitDiagram._build_map_control_qubits(
item, control_qubits
)

target_and_control = target_qubits.union(control_qubits)
qubits = QubitSet(range(min(target_and_control), max(target_and_control) + 1))

Expand Down Expand Up @@ -288,20 +368,54 @@ def _ascii_diagram_column(
else ascii_symbols[item_qubit_index]
)
elif qubit in control_qubits:
symbols[qubit] = "C"
symbols[qubit] = "C" if map_control_qubit_states[qubit] else "N"
else:
symbols[qubit] = "|"

# Set the margin to be a connector if not on the first qubit
if qubit != min(target_and_control):
if target_and_control and qubit != min(target_and_control):
margins[qubit] = "|"

symbols_width = max([len(symbol) for symbol in symbols.values()])
output = AsciiCircuitDiagram._create_output(symbols, margins, circuit_qubits, global_phase)
return output

@staticmethod
def _create_output(
symbols: dict[Qubit, str],
margins: dict[Qubit, str],
qubits: QubitSet,
global_phase: float | None,
) -> str:
symbols_width = max([len(symbol) for symbol in symbols.values()])
output = ""
for qubit in circuit_qubits:

if global_phase is not None:
global_phase_str = (
f"{global_phase:.2f}" if isinstance(global_phase, float) else str(global_phase)
)
symbols_width = max([symbols_width, len(global_phase_str)])
output += "{0:{fill}{align}{width}}|\n".format(
global_phase_str,
fill=" ",
align="^",
width=symbols_width,
)

for qubit in qubits:
output += "{0:{width}}\n".format(margins[qubit], width=symbols_width + 1)
output += "{0:{fill}{align}{width}}\n".format(
symbols[qubit], fill="-", align="<", width=symbols_width + 1
)
return output

@staticmethod
def _build_map_control_qubits(item: Instruction, control_qubits: QubitSet) -> dict(Qubit, int):
control_state = getattr(item, "control_state", None)
if control_state is not None:
map_control_qubit_states = {
qubit: state for qubit, state in zip(control_qubits, control_state)
}
else:
map_control_qubit_states = {qubit: 1 for qubit in control_qubits}

return map_control_qubit_states
11 changes: 9 additions & 2 deletions src/braket/circuits/braket_program_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,8 +56,15 @@ def is_builtin_gate(self, name: str) -> bool:
user_defined_gate = self.is_user_defined_gate(name)
return name in BRAKET_GATES and not user_defined_gate

def add_phase_instruction(self, target: tuple[int], phase_value: int) -> None:
raise NotImplementedError
def add_phase_instruction(self, target: tuple[int], phase_value: float) -> None:
"""Add a global phase to the circuit.
Args:
target (tuple[int]): Unused
phase_value (float): The phase value to be applied
"""
instruction = Instruction(BRAKET_GATES["gphase"](phase_value))
self._circuit.add_instruction(instruction)

def add_gate_instruction(
self, gate_name: str, target: tuple[int], *params, ctrl_modifiers: list[int], power: float
Expand Down
13 changes: 12 additions & 1 deletion src/braket/circuits/circuit.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
from braket.circuits.free_parameter_expression import FreeParameterExpression
from braket.circuits.gate import Gate
from braket.circuits.instruction import Instruction
from braket.circuits.moments import Moments
from braket.circuits.moments import Moments, MomentType
from braket.circuits.noise import Noise
from braket.circuits.noise_helpers import (
apply_noise_to_gates,
Expand Down Expand Up @@ -156,6 +156,17 @@ def depth(self) -> int:
"""int: Get the circuit depth."""
return self._moments.depth

@property
def global_phase(self) -> float:
"""float: Get the global phase of the circuit."""
return sum(
[
instr.operator.angle
for moment, instr in self._moments.items()
if moment.moment_type == MomentType.GLOBAL_PHASE
]
)

@property
def instructions(self) -> list[Instruction]:
"""Iterable[Instruction]: Get an `iterable` of instructions in the circuit."""
Expand Down
2 changes: 1 addition & 1 deletion src/braket/circuits/gate.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,7 +198,7 @@ def _to_openqasm(

return (
f"{inv_prefix}{power_prefix}{control_prefix}"
f"{self._qasm_name}{param_string} {', '.join(qubits)};"
f"{self._qasm_name}{param_string}{','.join([f' {qubit}' for qubit in qubits])};"
)

@property
Expand Down
Loading

0 comments on commit 5d8d04c

Please sign in to comment.