Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feature: upgrade AutoQASM to use oqpy 0.3.3 #715

Merged
merged 19 commits into from
Oct 16, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 1 addition & 4 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,10 +33,7 @@
# simulation of mid-circuit measurement, which AutoQASM requires.
# NOTE: This change should remain in the feature/autoqasm branch; do not merge to main.
"amazon-braket-default-simulator @ git+https://github.com/aws/amazon-braket-default-simulator-python.git@46aea776976ad7f958d847c06f29f3a7976f5cf5#egg=amazon-braket-default-simulator", # noqa E501
# Pin the latest commit of the qubit-array branch of ajberdy/oqpy.git to get the version of
# oqpy which contains changes that AutoQASM relies on, including the QubitArray type.
# NOTE: This change should remain in the feature/autoqasm branch; do not merge to main.
"oqpy @ git+https://github.com/ajberdy/oqpy.git@26cf4f9089c3b381370917734d2d964c45c4458d#egg=oqpy", # noqa E501
"oqpy~=0.3.3",
"setuptools",
"backoff",
"boltons",
Expand Down
2 changes: 1 addition & 1 deletion src/braket/experimental/autoqasm/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -252,7 +252,7 @@ def _add_qubit_declaration(program_conversion_context: aq_program.ProgramConvers
scope=aq_program.ProgramScope.MAIN
)
root_oqpy_program.declare(
[oqpy.QubitArray(aq_constants.QUBIT_REGISTER, num_qubits)],
[oqpy.Qubit(aq_constants.QUBIT_REGISTER, num_qubits)],
to_beginning=True,
)

Expand Down
6 changes: 3 additions & 3 deletions src/braket/experimental/autoqasm/instructions/qubits.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,9 +59,9 @@ def _get_physical_qubit_indices(qids: List[str]) -> List[int]:


def _global_qubit_register(qubit_idx_expr: Union[int, str]) -> str:
# TODO: We should index into a oqpy.QubitArray rather
# TODO: We should index into a oqpy.Qubit register rather
# than manually generating the string to index into
# a hard-coded global qubit array.
# a hard-coded global qubit register.
return f"{constants.QUBIT_REGISTER}[{qubit_idx_expr}]"


Expand Down Expand Up @@ -120,7 +120,7 @@ def _(qid: str) -> oqpy.Qubit:
if qid.startswith("$"):
qubit_idx = qid[1:]
try:
int(qubit_idx)
qubit_idx = int(qubit_idx)
except ValueError:
raise ValueError(f"invalid physical qubit label: '{qid}'")
return oqpy.PhysicalQubits[qubit_idx]
Expand Down
4 changes: 3 additions & 1 deletion src/braket/experimental/autoqasm/program/program.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
OpenQASMSerializationProperties,
SerializationProperties,
)
from braket.pulse.ast.qasm_parser import ast_to_qasm

# Create the thread-local object for the program conversion context.
_local = threading.local()
Expand Down Expand Up @@ -139,10 +140,11 @@ def to_ir(
str: A representation of the program in the `ir_type` format.
"""
if ir_type == IRType.OPENQASM:
openqasm_ir = self._oqpy_program.to_qasm(
openqasm_ast = self._oqpy_program.to_ast(
encal_declarations=self._has_pulse_control,
include_externs=serialization_properties.include_externs,
)
openqasm_ir = ast_to_qasm(openqasm_ast)
if self._has_pulse_control and not serialization_properties.auto_defcalgrammar:
openqasm_ir = openqasm_ir.replace('defcalgrammar "openpulse";\n', "")
return openqasm_ir
Expand Down
7 changes: 5 additions & 2 deletions src/braket/experimental/autoqasm/pulse/pulse.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
is_qubit_identifier_type,
)
from braket.experimental.autoqasm.types import BitVar
from braket.parametric.free_parameter import FreeParameter
from braket.pulse import PulseSequence
from braket.pulse.frame import Frame
from braket.pulse.pulse_sequence import _validate_uniqueness
Expand Down Expand Up @@ -127,19 +128,21 @@ def capture_v0(frame: Frame) -> None:

def delay(
qubits_or_frames: Union[Frame, List[Frame], QubitIdentifierType, List[QubitIdentifierType]],
duration: float,
duration: Union[float, oqpy.FloatVar],
jcjaskula-aws marked this conversation as resolved.
Show resolved Hide resolved
) -> None:
"""Adds an instruction to advance the frame clock by the specified `duration` value.

Args:
qubits_or_frames (Union[Frame, List[Frame], QubitIdentifierType, List[QubitIdentifierType]]):
Qubits or frame(s) on which the delay needs to be introduced.
duration (float): Value (in seconds) defining the duration of the delay.
duration (Union[float, FloatVar]): Value (in seconds) defining the duration of the delay.
""" # noqa: E501
if not isinstance(qubits_or_frames, List):
qubits_or_frames = [qubits_or_frames]
if all(is_qubit_identifier_type(q) for q in qubits_or_frames):
qubits_or_frames = QubitSet(_get_physical_qubit_indices(qubits_or_frames))
if isinstance(duration, oqpy.FloatVar):
duration = FreeParameter(duration.name)
jcjaskula-aws marked this conversation as resolved.
Show resolved Hide resolved
_pulse_instruction("delay", qubits_or_frames, duration)


Expand Down
1 change: 1 addition & 0 deletions src/braket/experimental/autoqasm/types/conversions.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ def map_type(python_type: type) -> type:
raise errors.ParameterTypeError(
f"Unsupported array type: {item_type}. AutoQASM arrays only support ints."
)

# TODO: Update array length to match the input rather than hardcoding
# OQPY and QASM require arrays have a set length. python doesn't require this,
# so the length of the array is indeterminate.
Expand Down
52 changes: 49 additions & 3 deletions src/braket/parametric/free_parameter_expression.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,18 @@

import ast
from numbers import Number
from typing import Any, Union

from typing import Any, Optional, Union

from openpulse.ast import (
ClassicalType,
DurationLiteral,
DurationType,
Expression,
FloatType,
Identifier,
TimeUnit,
)
from oqpy import Program
from sympy import Expr, Float, Symbol, sympify


Expand All @@ -30,7 +40,11 @@ class FreeParameterExpression:
present will NOT run. Values must be substituted prior to execution.
"""

def __init__(self, expression: Union[FreeParameterExpression, Number, Expr, str]):
def __init__(
self,
expression: Union[FreeParameterExpression, Number, Expr, str],
_type: Optional[ClassicalType] = None,
):
"""
Initializes a FreeParameterExpression. Best practice is to initialize using
FreeParameters and Numbers. Not meant to be initialized directly.
Expand All @@ -39,6 +53,10 @@ def __init__(self, expression: Union[FreeParameterExpression, Number, Expr, str]

Args:
expression (Union[FreeParameterExpression, Number, Expr, str]): The expression to use.
_type (Optional[ClassicalType]): The OpenQASM3 type associated with the expression.
Subtypes of openqasm3.ast.ClassicalType are used to specify how to express the
expression in the OpenQASM3 IR. Any type other than DurationType is considered
as FloatType.

Examples:
>>> expression_1 = FreeParameter("theta") * FreeParameter("alpha")
Expand All @@ -51,8 +69,11 @@ def __init__(self, expression: Union[FreeParameterExpression, Number, Expr, str]
ast.Pow: self.__pow__,
ast.USub: self.__neg__,
}
self._type = _type if _type is not None else FloatType()
if isinstance(expression, FreeParameterExpression):
self._expression = expression.expression
if _type is None:
self._type = expression._type
elif isinstance(expression, (Number, Expr)):
self._expression = expression
elif isinstance(expression, str):
Expand Down Expand Up @@ -170,6 +191,31 @@ def __repr__(self) -> str:
"""
return repr(self.expression)

def to_ast(self, program: Program) -> Expression:
"""Creates an AST node for the :class:'FreeParameterExpression'.

Args:
program (Program): Unused.

Returns:
Expression: The AST node.
"""
if isinstance(self._type, DurationType):
return DurationLiteral(_FreeParameterExpressionIdentifier(self), TimeUnit.s)
return _FreeParameterExpressionIdentifier(self)


class _FreeParameterExpressionIdentifier(Identifier):
"""Dummy AST node with FreeParameterExpression instance attached"""

def __init__(self, expression: FreeParameterExpression):
super().__init__(name=f"FreeParameterExpression({expression})")
self._expression = expression

@property
def expression(self) -> FreeParameterExpression:
return self._expression


def subs_if_free_parameter(parameter: Any, **kwargs) -> Any:
"""Substitute a free parameter with the given kwargs, if any.
Expand Down
2 changes: 1 addition & 1 deletion src/braket/pulse/ast/approximation_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ def __init__(self, program: Program, frames: dict[str, Frame]):
self.amplitudes = defaultdict(TimeSeries)
self.frequencies = defaultdict(TimeSeries)
self.phases = defaultdict(TimeSeries)
context = _ParseState(variables=dict(), frame_data=_init_frame_data(frames))
context = _ParseState(variables={"pi": np.pi}, frame_data=_init_frame_data(frames))
self._qubit_frames_mapping: dict[str, list[str]] = _init_qubit_frame_mapping(frames)
self.visit(program.to_ast(include_externs=False), context)

Expand Down
38 changes: 18 additions & 20 deletions src/braket/pulse/ast/free_parameters.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,37 +14,30 @@
from typing import Union

from openpulse import ast
from openqasm3.ast import DurationLiteral
from openqasm3.visitor import QASMTransformer
from oqpy.program import Program
from oqpy.timing import OQDurationLiteral

from braket.parametric.free_parameter_expression import FreeParameterExpression


class _FreeParameterExpressionIdentifier(ast.Identifier):
"""Dummy AST node with FreeParameterExpression instance attached"""

def __init__(self, expression: FreeParameterExpression):
super().__init__(name=f"FreeParameterExpression({expression})")
self._expression = expression

@property
def expression(self) -> FreeParameterExpression:
return self._expression
from braket.parametric.free_parameter_expression import (
FreeParameterExpression,
_FreeParameterExpressionIdentifier,
)


class _FreeParameterTransformer(QASMTransformer):
"""Walk the AST and evaluate FreeParameterExpressions."""

def __init__(self, param_values: dict[str, float]):
def __init__(self, param_values: dict[str, float], program: Program):
self.param_values = param_values
self.program = program
super().__init__()

def visit__FreeParameterExpressionIdentifier(
self, identifier: ast.Identifier
self, identifier: _FreeParameterExpressionIdentifier
) -> Union[_FreeParameterExpressionIdentifier, ast.FloatLiteral]:
"""Visit a FreeParameterExpressionIdentifier.
Args:
identifier (Identifier): The identifier.
identifier (_FreeParameterExpressionIdentifier): The identifier.

Returns:
Union[_FreeParameterExpressionIdentifier, FloatLiteral]: The transformed expression.
Expand All @@ -55,7 +48,7 @@ def visit__FreeParameterExpressionIdentifier(
else:
return ast.FloatLiteral(new_value)

def visit_DurationLiteral(self, duration_literal: DurationLiteral) -> DurationLiteral:
def visit_DurationLiteral(self, duration_literal: ast.DurationLiteral) -> ast.DurationLiteral:
"""Visit Duration Literal.
node.value, node.unit (node.unit.name, node.unit.value)
1
Expand All @@ -65,6 +58,11 @@ def visit_DurationLiteral(self, duration_literal: DurationLiteral) -> DurationLi
DurationLiteral: The transformed duration literal.
"""
duration = duration_literal.value
if not isinstance(duration, FreeParameterExpression):
if not isinstance(duration, _FreeParameterExpressionIdentifier):
return duration_literal
return DurationLiteral(duration.subs(self.param_values), duration_literal.unit)
new_duration = duration.expression.subs(self.param_values)
if isinstance(new_duration, FreeParameterExpression):
return ast.DurationLiteral(
_FreeParameterExpressionIdentifier(new_duration), duration_literal.unit
)
return OQDurationLiteral(new_duration).to_ast(self.program)
6 changes: 3 additions & 3 deletions src/braket/pulse/ast/qasm_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from openqasm3.ast import DurationLiteral
from openqasm3.printer import PrinterState

from braket.parametric.free_parameter_expression import FreeParameterExpression
from braket.pulse.ast.free_parameters import _FreeParameterExpressionIdentifier


class _PulsePrinter(Printer):
Expand Down Expand Up @@ -46,8 +46,8 @@ def visit_DurationLiteral(self, node: DurationLiteral, context: PrinterState) ->
context (PrinterState): The printer state context.
"""
duration = node.value
if isinstance(duration, FreeParameterExpression):
self.stream.write(f"({duration.expression}){node.unit.name}")
if isinstance(duration, _FreeParameterExpressionIdentifier):
self.stream.write(f"({duration.expression}) * 1{node.unit.name}")
else:
super().visit_DurationLiteral(node, context)

Expand Down
19 changes: 11 additions & 8 deletions src/braket/pulse/pulse_sequence.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,10 +183,7 @@ def delay(
Returns:
PulseSequence: self, with the instruction added.
"""
if isinstance(duration, FreeParameterExpression):
for p in duration.expression.free_symbols:
self._free_parameters.add(FreeParameter(p.name))
duration = OQDurationLiteral(duration)
duration = self._format_parameter_ast(duration, _type=ast.DurationType())
if not isinstance(qubits_or_frames, QubitSet):
if not isinstance(qubits_or_frames, list):
qubits_or_frames = [qubits_or_frames]
Expand Down Expand Up @@ -276,7 +273,7 @@ def make_bound_pulse_sequence(self, param_values: dict[str, float]) -> PulseSequ
"""
program = deepcopy(self._program)
tree: ast.Program = program.to_ast(include_externs=False, ignore_needs_declaration=True)
new_tree: ast.Program = _FreeParameterTransformer(param_values).visit(tree)
new_tree: ast.Program = _FreeParameterTransformer(param_values, program).visit(tree)

new_program = Program(simplify_constants=False)
new_program.declared_vars = program.declared_vars
Expand Down Expand Up @@ -325,13 +322,19 @@ def to_ir(self) -> str:
return ast_to_qasm(tree)

def _format_parameter_ast(
self, parameter: Union[float, FreeParameterExpression]
self,
parameter: Union[float, FreeParameterExpression],
_type: ast.ClassicalType = ast.FloatType(),
) -> Union[float, _FreeParameterExpressionIdentifier]:
if isinstance(parameter, FreeParameterExpression):
for p in parameter.expression.free_symbols:
self._free_parameters.add(FreeParameter(p.name))
return _FreeParameterExpressionIdentifier(parameter)
return parameter
return (
FreeParameterExpression(parameter, _type)
if isinstance(_type, ast.DurationType)
else parameter
)
return OQDurationLiteral(parameter) if isinstance(_type, ast.DurationType) else parameter

def _parse_arg_from_calibration_schema(
self, argument: dict, waveforms: dict[Waveform], frames: dict[Frame]
Expand Down
16 changes: 6 additions & 10 deletions src/braket/pulse/waveforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,15 +21,13 @@
import numpy as np
from oqpy import WaveformVar, bool_, complex128, declare_waveform_generator, duration, float64
from oqpy.base import OQPyExpression
from oqpy.timing import OQDurationLiteral

from braket.parametric.free_parameter import FreeParameter
from braket.parametric.free_parameter_expression import (
FreeParameterExpression,
subs_if_free_parameter,
)
from braket.parametric.parameterizable import Parameterizable
from braket.pulse.ast.free_parameters import _FreeParameterExpressionIdentifier


class Waveform(ABC):
Expand Down Expand Up @@ -454,14 +452,12 @@ def _make_identifier_name() -> str:

def _map_to_oqpy_type(
parameter: Union[FreeParameterExpression, float], is_duration_type: bool = False
) -> Union[_FreeParameterExpressionIdentifier, OQPyExpression]:
if isinstance(parameter, FreeParameterExpression):
return (
OQDurationLiteral(parameter)
if is_duration_type
else _FreeParameterExpressionIdentifier(parameter)
)
return parameter
) -> Union[FreeParameterExpression, OQPyExpression]:
return (
FreeParameterExpression(parameter, duration)
if isinstance(parameter, FreeParameterExpression) and is_duration_type
else parameter
)


def _parse_waveform_from_calibration_schema(waveform: dict) -> Waveform:
Expand Down
2 changes: 1 addition & 1 deletion test/unit_tests/braket/circuits/test_gates.py
Original file line number Diff line number Diff line change
Expand Up @@ -962,7 +962,7 @@ def to_ir(pulse_gate):
[
"cal {",
" set_frequency(user_frame, b + 3);",
" delay[(1000000000.0*c)ns] user_frame;",
" delay[(c) * 1s] user_frame;",
"}",
]
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ def my_program():
"""
OPENQASM 3.0;
defcal rx(angle[32] angle) $1 {
delay[angle * 1s] $1;
delay[(angle) * 1s] $1;
}
rx(1.0) $1;
"""
Expand Down Expand Up @@ -222,7 +222,7 @@ def my_program():
}
defcal my_gate(angle[32] a) $0 {
barrier $0;
delay[a * 1s] $0;
delay[(a) * 1s] $0;
}
qubit[3] __qubits__;
my_gate(0.123) __qubits__[2];
Expand Down
Loading