Skip to content

Commit

Permalink
Add converter for comparison statements. Add new tests
Browse files Browse the repository at this point in the history
  • Loading branch information
laurencap committed Nov 7, 2023
1 parent 2a84dad commit 2245cbb
Show file tree
Hide file tree
Showing 8 changed files with 285 additions and 33 deletions.
70 changes: 70 additions & 0 deletions src/braket/experimental/autoqasm/converters/comparisons.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
# Copyright Amazon.com Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
# http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.

"""Converters for comparison nodes."""

import ast

import gast

from braket.experimental.autoqasm.autograph.core import ag_ctx, converter
from braket.experimental.autoqasm.autograph.pyct import templates

COMPARISON_OPERATORS = {
gast.Lt: "ag__.lt_",
gast.LtE: "ag__.lteq_",
gast.Gt: "ag__.gt_",
gast.GtE: "ag__.gteq_",
}


class ComparisonTransformer(converter.Base):
"""Transformer for comparison nodes."""

def visit_Compare(self, node: ast.stmt) -> ast.stmt:
"""Transforms a comparison node.
Args:
node (ast.stmt): AST node to transform.
Returns:
ast.stmt: Transformed node.
"""
node = self.generic_visit(node)

op_type = type(node.ops[0])
if op_type not in COMPARISON_OPERATORS:
return node

template = f"{COMPARISON_OPERATORS[op_type]}(lhs_, rhs_)"

return templates.replace(
template,
lhs_=node.left,
rhs_=node.comparators[0],
original=node,
)[0].value


def transform(node: ast.stmt, ctx: ag_ctx.ControlStatusCtx) -> ast.stmt:
"""Transform comparison nodes.
Args:
node (ast.stmt): AST node to transform.
ctx (ag_ctx.ControlStatusCtx): Transformer context.
Returns:
ast.stmt: Transformed node.
"""

return ComparisonTransformer(ctx).visit(node)
1 change: 1 addition & 0 deletions src/braket/experimental/autoqasm/operators/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
)

from .assignments import assign_stmt # noqa: F401
from .comparisons import gt_, gteq_, lt_, lteq_ # noqa: F401
from .conditional_expressions import if_exp # noqa: F401
from .control_flow import for_stmt, if_stmt, while_stmt # noqa: F401
from .data_structures import ListPopOpts # noqa: F401
Expand Down
137 changes: 137 additions & 0 deletions src/braket/experimental/autoqasm/operators/comparisons.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,137 @@
# Copyright Amazon.com Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
# http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.


"""Operators for comparison operators: <, <=, >, and >=."""

from collections.abc import Callable
from typing import Any, Union

import oqpy.base

from braket.experimental.autoqasm import program
from braket.experimental.autoqasm import types as aq_types

from .utils import _convert_parameters


def lt_(a: Any, b: Any) -> Union[bool, aq_types.BoolVar]:
"""Functional form of "<".
Args:
a (Any): Callable that returns the first expression.
b (Any): Callable that returns the second expression.
Returns:
Union[bool, BoolVar]: Whether the first expression is less than the second.
"""
if aq_types.is_qasm_type(a) or aq_types.is_qasm_type(b):
return _aq_lt(a, b)
else:
return a < b


def _aq_lt(a: Any, b: Any) -> aq_types.BoolVar:
program_conversion_context = program.get_program_conversion_context()
program_conversion_context.register_args([a, b])
a, b = _convert_parameters(a, b)

oqpy_program = program_conversion_context.get_oqpy_program()
result = aq_types.BoolVar()
oqpy_program.declare(result)
oqpy_program.set(result, a < b)
return result


def lteq_(a: Any, b: Any) -> Union[bool, aq_types.BoolVar]:
"""Functional form of "<=".
Args:
a (Any): Callable that returns the first expression.
b (Any): Callable that returns the second expression.
Returns:
Union[bool, BoolVar]: Whether the first expression is less than or equal to the second.
"""
if aq_types.is_qasm_type(a) or aq_types.is_qasm_type(b):
return _aq_lteq(a, b)
else:
return a <= b


def _aq_lteq(a: Any, b: Any) -> aq_types.BoolVar:
program_conversion_context = program.get_program_conversion_context()
program_conversion_context.register_args([a, b])
a, b = _convert_parameters(a, b)

oqpy_program = program_conversion_context.get_oqpy_program()
result = aq_types.BoolVar()
oqpy_program.declare(result)
oqpy_program.set(result, a <= b)
return result


def gt_(a: Any, b: Any) -> Union[bool, aq_types.BoolVar]:
"""Functional form of ">".
Args:
a (Any): Callable that returns the first expression.
b (Any): Callable that returns the second expression.
Returns:
Union[bool, BoolVar]: Whether the first expression greater than the second.
"""
if aq_types.is_qasm_type(a) or aq_types.is_qasm_type(b):
return _aq_gt(a, b)
else:
return a > b


def _aq_gt(a: Any, b: Any) -> aq_types.BoolVar:
program_conversion_context = program.get_program_conversion_context()
program_conversion_context.register_args([a, b])
a, b = _convert_parameters(a, b)

oqpy_program = program_conversion_context.get_oqpy_program()
result = aq_types.BoolVar()
oqpy_program.declare(result)
oqpy_program.set(result, a > b)
return result


def gteq_(a: Any, b: Any) -> Union[bool, aq_types.BoolVar]:
"""Functional form of ">=".
Args:
a (Any): Callable that returns the first expression.
b (Any): Callable that returns the second expression.
Returns:
Union[bool, BoolVar]: Whether the first expression greater than or equal to the second.
"""
if aq_types.is_qasm_type(a) or aq_types.is_qasm_type(b):
return _aq_gteq(a, b)
else:
return a >= b


def _aq_gteq(a: Any, b: Any) -> aq_types.BoolVar:
program_conversion_context = program.get_program_conversion_context()
program_conversion_context.register_args([a, b])
a, b = _convert_parameters(a, b)

oqpy_program = program_conversion_context.get_oqpy_program()
result = aq_types.BoolVar()
oqpy_program.declare(result)
oqpy_program.set(result, a >= b)
return result
23 changes: 1 addition & 22 deletions src/braket/experimental/autoqasm/operators/logical.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,31 +19,10 @@
import oqpy.base
from openpulse import ast

from braket.circuits import FreeParameter
from braket.experimental.autoqasm import program
from braket.experimental.autoqasm import types as aq_types


def _convert_parameters(*args: list[FreeParameter]) -> list[aq_types.FloatVar]:
"""Converts FreeParameter objects to FloatVars through the program conversion context
parameter registry.
FloatVars are more compatible with the program conversion operations.
Args:
args (list[FreeParameter]): FreeParameter objects.
Returns:
list[FloatVar]: FloatVars for program conversion.
"""
result = []
for arg in args:
if isinstance(arg, FreeParameter):
var = program.get_program_conversion_context()._free_parameters[arg.name]
result.append(var)
else:
result.append(arg)
return result[0] if len(result) == 1 else result
from .utils import _convert_parameters


def and_(a: Callable[[], Any], b: Callable[[], Any]) -> Union[bool, aq_types.BoolVar]:
Expand Down
41 changes: 41 additions & 0 deletions src/braket/experimental/autoqasm/operators/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
# Copyright Amazon.com Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
# http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.


"Utility methods for operators."

from braket.circuits import FreeParameter
from braket.experimental.autoqasm import program
from braket.experimental.autoqasm import types as aq_types


def _convert_parameters(*args: list[FreeParameter]) -> list[aq_types.FloatVar]:
"""Converts FreeParameter objects to FloatVars through the program conversion context
parameter registry.
FloatVars are more compatible with the program conversion operations.
Args:
args (list[FreeParameter]): FreeParameter objects.
Returns:
list[FloatVar]: FloatVars for program conversion.
"""
result = []
for arg in args:
if isinstance(arg, FreeParameter):
var = program.get_program_conversion_context().get_parameter(arg.name)
result.append(var)
else:
result.append(arg)
return result[0] if len(result) == 1 else result
16 changes: 16 additions & 0 deletions src/braket/experimental/autoqasm/program/program.py
Original file line number Diff line number Diff line change
Expand Up @@ -338,6 +338,22 @@ def register_parameter(self, parameter: FreeParameter) -> None:
if parameter.name not in self._free_parameters:
self._free_parameters[parameter.name] = oqpy.FloatVar("input", name=parameter.name)

def get_parameter(self, name: str) -> oqpy.FloatVar:
"""Return a named oqpy.FloatVar that is used as a free parameter in the program.
Args:
name (str): The name of the parameter.
Raises:
ValueError: If there is no parameter with the given name registered with the program.
Returns:
FloatVar: The associated variable.
"""
if name not in self._free_parameters:
raise ValueError(f"Free parameter '{name}' was not found.")
return self._free_parameters[name]

def get_free_parameters(self) -> list[oqpy.FloatVar]:
"""Return a list of named oqpy.Vars that are used as free parameters in the program."""
return list(self._free_parameters.values())
Expand Down
8 changes: 7 additions & 1 deletion src/braket/experimental/autoqasm/transpiler/transpiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,12 @@
reaching_definitions,
)
from braket.experimental.autoqasm.autograph.tf_utils import tf_stack
from braket.experimental.autoqasm.converters import assignments, break_statements, return_statements
from braket.experimental.autoqasm.converters import (
assignments,
break_statements,
comparisons,
return_statements,
)


class PyToOqpy(transpiler.PyToPy):
Expand Down Expand Up @@ -139,6 +144,7 @@ def transform_ast(
# canonicalization creates.
node = continue_statements.transform(node, ctx)
node = return_statements.transform(node, ctx)
node = comparisons.transform(node, ctx)
node = assignments.transform(node, ctx)
node = lists.transform(node, ctx)
node = slices.transform(node, ctx)
Expand Down
22 changes: 12 additions & 10 deletions test/unit_tests/braket/experimental/autoqasm/test_parameters.py
Original file line number Diff line number Diff line change
Expand Up @@ -488,8 +488,8 @@ def parametric():
parametric().make_bound_program({"beta": 0.5}, strict=True)


def test_parameter_as_condition():
"""Test parameters used in conditional statements."""
def test_parameter_as_condition_gt():
"""Test parameters used in greater than conditional statements."""

@aq.main
def parametric(val: float):
Expand All @@ -500,14 +500,14 @@ def parametric(val: float):

expected = """OPENQASM 3.0;
input float[64] val;
float[64] threshold;
qubit[1] __qubits__;
threshold = 0.9;
if (val > threshold) {
bool __bool_0__;
__bool_0__ = val > 0.9;
if (__bool_0__) {
x __qubits__[0];
}
bit __bit_0__;
__bit_0__ = measure __qubits__[0];"""
bit __bit_1__;
__bit_1__ = measure __qubits__[0];"""
assert parametric(FreeParameter("val")).to_ir() == expected


Expand All @@ -527,15 +527,17 @@ def parametric(val: float):

expected = """OPENQASM 3.0;
def sub(float[64] val) {
if (val > 0.9) {
bool __bool_0__;
__bool_0__ = val > 0.9;
if (__bool_0__) {
x __qubits__[0];
}
}
input float[64] val;
qubit[1] __qubits__;
sub(val);
bit __bit_0__;
__bit_0__ = measure __qubits__[0];"""
bit __bit_1__;
__bit_1__ = measure __qubits__[0];"""
assert parametric(FreeParameter("val")).to_ir() == expected


Expand Down

0 comments on commit 2245cbb

Please sign in to comment.