From 77690e63284cadbf84efbb2bb9fdf0bd6b6746ae Mon Sep 17 00:00:00 2001 From: Lauren Capelluto Date: Fri, 10 Nov 2023 15:55:02 -0500 Subject: [PATCH] feature: Add support for FreeParameters in AutoQASM conditional statements (#789) * feature: FreeParameters in conditional statements * Update type check to include FreeParameter as a qasm type * Add support for free parameters in logical operations * Add support for comparison statements --------- Co-authored-by: Ryan Shaffer <3620100+rmshaffer@users.noreply.github.com> --- .../autoqasm/converters/comparisons.py | 70 +++++ src/braket/experimental/autoqasm/errors.py | 4 + .../autoqasm/operators/__init__.py | 1 + .../autoqasm/operators/comparisons.py | 126 ++++++++ .../autoqasm/operators/logical.py | 12 + .../experimental/autoqasm/operators/utils.py | 47 +++ .../experimental/autoqasm/program/program.py | 36 ++- .../autoqasm/transpiler/transpiler.py | 8 +- .../experimental/autoqasm/types/types.py | 6 +- .../braket/experimental/autoqasm/test_api.py | 8 +- .../experimental/autoqasm/test_operators.py | 67 ++++ .../experimental/autoqasm/test_parameters.py | 297 +++++++++++++++++- .../experimental/autoqasm/test_program.py | 9 + 13 files changed, 673 insertions(+), 18 deletions(-) create mode 100644 src/braket/experimental/autoqasm/converters/comparisons.py create mode 100644 src/braket/experimental/autoqasm/operators/comparisons.py create mode 100644 src/braket/experimental/autoqasm/operators/utils.py diff --git a/src/braket/experimental/autoqasm/converters/comparisons.py b/src/braket/experimental/autoqasm/converters/comparisons.py new file mode 100644 index 000000000..7b8d53a5e --- /dev/null +++ b/src/braket/experimental/autoqasm/converters/comparisons.py @@ -0,0 +1,70 @@ +# Copyright Amazon.com Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. + +"""Converters for comparison nodes.""" + +import ast + +import gast + +from braket.experimental.autoqasm.autograph.core import ag_ctx, converter +from braket.experimental.autoqasm.autograph.pyct import templates + +COMPARISON_OPERATORS = { + gast.Lt: "ag__.lt_", + gast.LtE: "ag__.lteq_", + gast.Gt: "ag__.gt_", + gast.GtE: "ag__.gteq_", +} + + +class ComparisonTransformer(converter.Base): + """Transformer for comparison nodes.""" + + def visit_Compare(self, node: ast.stmt) -> ast.stmt: + """Transforms a comparison node. + + Args: + node (ast.stmt): AST node to transform. + + Returns: + ast.stmt: Transformed node. + """ + node = self.generic_visit(node) + + op_type = type(node.ops[0]) + if op_type not in COMPARISON_OPERATORS: + return node + + template = f"{COMPARISON_OPERATORS[op_type]}(lhs_, rhs_)" + + return templates.replace( + template, + lhs_=node.left, + rhs_=node.comparators[0], + original=node, + )[0].value + + +def transform(node: ast.stmt, ctx: ag_ctx.ControlStatusCtx) -> ast.stmt: + """Transform comparison nodes. + + Args: + node (ast.stmt): AST node to transform. + ctx (ag_ctx.ControlStatusCtx): Transformer context. + + Returns: + ast.stmt: Transformed node. + """ + + return ComparisonTransformer(ctx).visit(node) diff --git a/src/braket/experimental/autoqasm/errors.py b/src/braket/experimental/autoqasm/errors.py index 5a6b1edef..33bd23c4a 100644 --- a/src/braket/experimental/autoqasm/errors.py +++ b/src/braket/experimental/autoqasm/errors.py @@ -37,6 +37,10 @@ class MissingParameterTypeError(AutoQasmError): """AutoQASM requires type hints for subroutine parameters.""" +class ParameterNotFoundError(AutoQasmError): + """A FreeParameter could not be found in the program.""" + + class InvalidGateDefinition(AutoQasmError): """Gate definition does not meet the necessary requirements.""" diff --git a/src/braket/experimental/autoqasm/operators/__init__.py b/src/braket/experimental/autoqasm/operators/__init__.py index 853b8ce07..75b44c879 100644 --- a/src/braket/experimental/autoqasm/operators/__init__.py +++ b/src/braket/experimental/autoqasm/operators/__init__.py @@ -27,6 +27,7 @@ ) from .assignments import assign_stmt # noqa: F401 +from .comparisons import gt_, gteq_, lt_, lteq_ # noqa: F401 from .conditional_expressions import if_exp # noqa: F401 from .control_flow import for_stmt, if_stmt, while_stmt # noqa: F401 from .data_structures import ListPopOpts # noqa: F401 diff --git a/src/braket/experimental/autoqasm/operators/comparisons.py b/src/braket/experimental/autoqasm/operators/comparisons.py new file mode 100644 index 000000000..6f08f29e9 --- /dev/null +++ b/src/braket/experimental/autoqasm/operators/comparisons.py @@ -0,0 +1,126 @@ +# Copyright Amazon.com Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. + + +"""Operators for comparison operators: <, <=, >, and >=.""" + +from typing import Any, Union + +from braket.experimental.autoqasm import program +from braket.experimental.autoqasm import types as aq_types + +from .utils import _register_and_convert_parameters + + +def lt_(a: Any, b: Any) -> Union[bool, aq_types.BoolVar]: + """Functional form of "<". + + Args: + a (Any): The first expression. + b (Any): The second expression. + + Returns: + Union[bool, BoolVar]: Whether the first expression is less than the second. + """ + if aq_types.is_qasm_type(a) or aq_types.is_qasm_type(b): + return _aq_lt(a, b) + else: + return a < b + + +def _aq_lt(a: Any, b: Any) -> aq_types.BoolVar: + a, b = _register_and_convert_parameters(a, b) + + oqpy_program = program.get_program_conversion_context().get_oqpy_program() + result = aq_types.BoolVar() + oqpy_program.declare(result) + oqpy_program.set(result, a < b) + return result + + +def lteq_(a: Any, b: Any) -> Union[bool, aq_types.BoolVar]: + """Functional form of "<=". + + Args: + a (Any): The first expression. + b (Any): The second expression. + + Returns: + Union[bool, BoolVar]: Whether the first expression is less than or equal to the second. + """ + if aq_types.is_qasm_type(a) or aq_types.is_qasm_type(b): + return _aq_lteq(a, b) + else: + return a <= b + + +def _aq_lteq(a: Any, b: Any) -> aq_types.BoolVar: + a, b = _register_and_convert_parameters(a, b) + + oqpy_program = program.get_program_conversion_context().get_oqpy_program() + result = aq_types.BoolVar() + oqpy_program.declare(result) + oqpy_program.set(result, a <= b) + return result + + +def gt_(a: Any, b: Any) -> Union[bool, aq_types.BoolVar]: + """Functional form of ">". + + Args: + a (Any): The first expression. + b (Any): The second expression. + + Returns: + Union[bool, BoolVar]: Whether the first expression is greater than the second. + """ + if aq_types.is_qasm_type(a) or aq_types.is_qasm_type(b): + return _aq_gt(a, b) + else: + return a > b + + +def _aq_gt(a: Any, b: Any) -> aq_types.BoolVar: + a, b = _register_and_convert_parameters(a, b) + + oqpy_program = program.get_program_conversion_context().get_oqpy_program() + result = aq_types.BoolVar() + oqpy_program.declare(result) + oqpy_program.set(result, a > b) + return result + + +def gteq_(a: Any, b: Any) -> Union[bool, aq_types.BoolVar]: + """Functional form of ">=". + + Args: + a (Any): The first expression. + b (Any): The second expression. + + Returns: + Union[bool, BoolVar]: Whether the first expression is greater than or equal to the second. + """ + if aq_types.is_qasm_type(a) or aq_types.is_qasm_type(b): + return _aq_gteq(a, b) + else: + return a >= b + + +def _aq_gteq(a: Any, b: Any) -> aq_types.BoolVar: + a, b = _register_and_convert_parameters(a, b) + + oqpy_program = program.get_program_conversion_context().get_oqpy_program() + result = aq_types.BoolVar() + oqpy_program.declare(result) + oqpy_program.set(result, a >= b) + return result diff --git a/src/braket/experimental/autoqasm/operators/logical.py b/src/braket/experimental/autoqasm/operators/logical.py index 6813e090a..c7bcf7c57 100644 --- a/src/braket/experimental/autoqasm/operators/logical.py +++ b/src/braket/experimental/autoqasm/operators/logical.py @@ -22,6 +22,8 @@ from braket.experimental.autoqasm import program from braket.experimental.autoqasm import types as aq_types +from .utils import _register_and_convert_parameters + def and_(a: Callable[[], Any], b: Callable[[], Any]) -> Union[bool, aq_types.BoolVar]: """Functional form of "and". @@ -42,6 +44,8 @@ def and_(a: Callable[[], Any], b: Callable[[], Any]) -> Union[bool, aq_types.Boo def _oqpy_and(a: Any, b: Any) -> aq_types.BoolVar: + a, b = _register_and_convert_parameters(a, b) + oqpy_program = program.get_program_conversion_context().get_oqpy_program() result = aq_types.BoolVar() oqpy_program.declare(result) @@ -72,6 +76,8 @@ def or_(a: Callable[[], Any], b: Callable[[], Any]) -> Union[bool, aq_types.Bool def _oqpy_or(a: Any, b: Any) -> aq_types.BoolVar: + a, b = _register_and_convert_parameters(a, b) + oqpy_program = program.get_program_conversion_context().get_oqpy_program() result = aq_types.BoolVar() oqpy_program.declare(result) @@ -99,6 +105,8 @@ def not_(a: Any) -> Union[bool, aq_types.BoolVar]: def _oqpy_not(a: Any) -> aq_types.BoolVar: + a = _register_and_convert_parameters(a) + oqpy_program = program.get_program_conversion_context().get_oqpy_program() result = aq_types.BoolVar() oqpy_program.declare(result) @@ -127,6 +135,8 @@ def eq(a: Any, b: Any) -> Union[bool, aq_types.BoolVar]: def _oqpy_eq(a: Any, b: Any) -> aq_types.BoolVar: + a, b = _register_and_convert_parameters(a, b) + oqpy_program = program.get_program_conversion_context().get_oqpy_program() is_equal = aq_types.BoolVar() oqpy_program.declare(is_equal) @@ -155,6 +165,8 @@ def not_eq(a: Any, b: Any) -> Union[bool, aq_types.BoolVar]: def _oqpy_not_eq(a: Any, b: Any) -> aq_types.BoolVar: + a, b = _register_and_convert_parameters(a, b) + oqpy_program = program.get_program_conversion_context().get_oqpy_program() is_not_equal = aq_types.BoolVar() oqpy_program.declare(is_not_equal) diff --git a/src/braket/experimental/autoqasm/operators/utils.py b/src/braket/experimental/autoqasm/operators/utils.py new file mode 100644 index 000000000..a24b502ed --- /dev/null +++ b/src/braket/experimental/autoqasm/operators/utils.py @@ -0,0 +1,47 @@ +# Copyright Amazon.com Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. + + +"Utility methods for operators." + +from typing import Any, Union + +from braket.circuits import FreeParameter +from braket.experimental.autoqasm import program +from braket.experimental.autoqasm import types as aq_types + + +def _register_and_convert_parameters( + *args: tuple[Any], +) -> Union[list[aq_types.FloatVar], aq_types.FloatVar]: + """Adds FreeParameters to the program conversion context parameter registry, and + returns the associated FloatVar objects. + + Notes: Adding a parameter to the registry twice is safe. Conversion is a pass through + for non-FreeParameter inputs. Input and output arity is the same. + + FloatVars are more compatible with the program conversion operations. + + Returns: + Union[list[FloatVar], FloatVar]: FloatVars for program conversion. + """ + program_conversion_context = program.get_program_conversion_context() + program_conversion_context.register_args(args) + result = [] + for arg in args: + if isinstance(arg, FreeParameter): + var = program.get_program_conversion_context().get_parameter(arg.name) + result.append(var) + else: + result.append(arg) + return result[0] if len(result) == 1 else result diff --git a/src/braket/experimental/autoqasm/program/program.py b/src/braket/experimental/autoqasm/program/program.py index bd7416bc2..9f2a241b8 100644 --- a/src/braket/experimental/autoqasm/program/program.py +++ b/src/braket/experimental/autoqasm/program/program.py @@ -130,11 +130,12 @@ def make_bound_program(self, param_values: dict[str, float], strict: bool = Fals Args: param_values (dict[str, float]): A mapping of FreeParameter names to a value to assign to them. - strict (bool): If True, raises a ValueError if any of the FreeParameters + strict (bool): If True, raises a ParameterNotFoundError if any of the FreeParameters in param_values do not appear in the program. False by default. Raises: - ValueError: If a parameter name is given which does not appear in the program. + ParameterNotFoundError: If a parameter name is given which does not appear in + the program. Returns: Program: Returns a program with all present parameters fixed to their respective @@ -148,7 +149,7 @@ def make_bound_program(self, param_values: dict[str, float], strict: bool = Fals assert target.init_expression == "input", "Only free parameters can be bound." target.init_expression = value elif strict: - raise ValueError(f"No parameter in the program named: {name}") + raise errors.ParameterNotFoundError(f"No parameter in the program named: {name}") return Program(bound_oqpy_program, self._has_pulse_control) @@ -321,22 +322,39 @@ def register_args(self, args: list[Any]) -> None: """ for arg in args: if isinstance(arg, FreeParameter): - self.register_parameter(arg.name) + self.register_parameter(arg) elif isinstance(arg, FreeParameterExpression): # TODO laurecap: Support for expressions raise NotImplementedError( "Expressions of FreeParameters will be supported shortly!" ) - def register_parameter(self, name: str) -> None: - """Register an input parameter with the given name, if it has not already been - registered. Only floats are currently supported. + def register_parameter(self, parameter: FreeParameter) -> None: + """Register an input parameter if it has not already been registered. + Only floats are currently supported. Args: - name (str): The identifier for the parameter. + parameter (FreeParameter): The parameter to register with the program. + """ + if parameter.name not in self._free_parameters: + self._free_parameters[parameter.name] = oqpy.FloatVar("input", name=parameter.name) + + def get_parameter(self, name: str) -> oqpy.FloatVar: + """Return a named oqpy.FloatVar that is used as a free parameter in the program. + + Args: + name (str): The name of the parameter. + + Raises: + ParameterNotFoundError: If there is no parameter with the given name registered + with the program. + + Returns: + FloatVar: The associated variable. """ if name not in self._free_parameters: - self._free_parameters[name] = oqpy.FloatVar("input", name=name) + raise errors.ParameterNotFoundError(f"Free parameter '{name}' was not found.") + return self._free_parameters[name] def get_free_parameters(self) -> list[oqpy.FloatVar]: """Return a list of named oqpy.Vars that are used as free parameters in the program.""" diff --git a/src/braket/experimental/autoqasm/transpiler/transpiler.py b/src/braket/experimental/autoqasm/transpiler/transpiler.py index 2dc0dbfc8..46cea08de 100644 --- a/src/braket/experimental/autoqasm/transpiler/transpiler.py +++ b/src/braket/experimental/autoqasm/transpiler/transpiler.py @@ -59,7 +59,12 @@ reaching_definitions, ) from braket.experimental.autoqasm.autograph.tf_utils import tf_stack -from braket.experimental.autoqasm.converters import assignments, break_statements, return_statements +from braket.experimental.autoqasm.converters import ( + assignments, + break_statements, + comparisons, + return_statements, +) class PyToOqpy(transpiler.PyToPy): @@ -145,6 +150,7 @@ def transform_ast( node = call_trees.transform(node, ctx) node = control_flow.transform(node, ctx) node = conditional_expressions.transform(node, ctx) + node = comparisons.transform(node, ctx) node = logical_expressions.transform(node, ctx) node = variables.transform(node, ctx) diff --git a/src/braket/experimental/autoqasm/types/types.py b/src/braket/experimental/autoqasm/types/types.py index fcfaea88f..740d76b25 100644 --- a/src/braket/experimental/autoqasm/types/types.py +++ b/src/braket/experimental/autoqasm/types/types.py @@ -19,6 +19,7 @@ import oqpy.base from openpulse import ast +from braket.circuits import FreeParameterExpression from braket.experimental.autoqasm import errors, program @@ -32,11 +33,12 @@ def is_qasm_type(val: Any) -> bool: Returns: bool: Whether the object is a QASM type. """ + qasm_types = (oqpy.Range, oqpy._ClassicalVar, oqpy.base.OQPyExpression, FreeParameterExpression) # The input can either be a class, like oqpy.Range ... if type(val) is type: - return issubclass(val, (oqpy.Range, oqpy._ClassicalVar, oqpy.base.OQPyExpression)) + return issubclass(val, qasm_types) # ... or an instance of a class, like oqpy.Range(10) - return isinstance(val, (oqpy.Range, oqpy._ClassicalVar, oqpy.base.OQPyExpression)) + return isinstance(val, qasm_types) def qasm_range(start: int, stop: Optional[int] = None, step: Optional[int] = 1) -> oqpy.Range: diff --git a/test/unit_tests/braket/experimental/autoqasm/test_api.py b/test/unit_tests/braket/experimental/autoqasm/test_api.py index f5f2d19c5..b757033b5 100644 --- a/test/unit_tests/braket/experimental/autoqasm/test_api.py +++ b/test/unit_tests/braket/experimental/autoqasm/test_api.py @@ -131,7 +131,9 @@ def do_h(int[32] q) { } def recursive_h(int[32] q) { do_h(q); - if (q > 0) { + bool __bool_0__; + __bool_0__ = q > 0; + if (__bool_0__) { recursive_h(q - 1); } } @@ -155,7 +157,9 @@ def do_h(int[32] q) { } def recursive_h(int[32] q) { do_h(q); - if (q > 0) { + bool __bool_0__; + __bool_0__ = q > 0; + if (__bool_0__) { recursive_h(q - 1); } } diff --git a/test/unit_tests/braket/experimental/autoqasm/test_operators.py b/test/unit_tests/braket/experimental/autoqasm/test_operators.py index dba42099c..ca85fec34 100644 --- a/test/unit_tests/braket/experimental/autoqasm/test_operators.py +++ b/test/unit_tests/braket/experimental/autoqasm/test_operators.py @@ -487,6 +487,73 @@ def prog(): assert prog().to_ir() == expected +def test_comparison_lt() -> None: + """Tests less than operator handling.""" + + @aq.main + def prog(): + a = measure(0) + if a < 1: + h(0) + + expected = """OPENQASM 3.0; +bit a; +qubit[1] __qubits__; +bit __bit_0__; +__bit_0__ = measure __qubits__[0]; +a = __bit_0__; +bool __bool_1__; +__bool_1__ = a < 1; +if (__bool_1__) { + h __qubits__[0]; +}""" + qasm = prog().to_ir() + assert qasm == expected + + +def test_comparison_gt() -> None: + """Tests greater than operator handling.""" + + @aq.main + def prog(): + a = measure(0) + if a > 1: + h(0) + + expected = """OPENQASM 3.0; +bit a; +qubit[1] __qubits__; +bit __bit_0__; +__bit_0__ = measure __qubits__[0]; +a = __bit_0__; +bool __bool_1__; +__bool_1__ = a > 1; +if (__bool_1__) { + h __qubits__[0]; +}""" + qasm = prog().to_ir() + assert qasm == expected + + +def test_comparison_ops_py() -> None: + """Tests the comparison aq.operators for Python expressions.""" + + @aq.main + def prog(): + a = 1.2 + b = 12 + c = a < b + d = a <= b + e = a > b + f = a >= b + g = 1.2 + h = a <= g + assert all([c, d, not e, not f, h]) + + expected = """OPENQASM 3.0;""" + assert prog().to_ir() == expected + + @pytest.mark.parametrize( "target", [oqpy.ArrayVar(dimensions=[3], name="arr"), oqpy.BitVar(size=3, name="arr")] ) diff --git a/test/unit_tests/braket/experimental/autoqasm/test_parameters.py b/test/unit_tests/braket/experimental/autoqasm/test_parameters.py index 0a43b32e6..6633a36ad 100644 --- a/test/unit_tests/braket/experimental/autoqasm/test_parameters.py +++ b/test/unit_tests/braket/experimental/autoqasm/test_parameters.py @@ -21,13 +21,13 @@ from braket.default_simulator import StateVectorSimulator from braket.devices.local_simulator import LocalSimulator from braket.experimental.autoqasm import pulse -from braket.experimental.autoqasm.instructions import cnot, cphaseshift, measure, ms, rx, rz +from braket.experimental.autoqasm.instructions import cnot, cphaseshift, h, measure, ms, rx, rz, x from braket.tasks.local_quantum_task import LocalQuantumTask def _test_parametric_on_local_sim(program: aq.Program, inputs: dict[str, float]) -> None: device = LocalSimulator(backend=StateVectorSimulator()) - task = device.run(program, shots=10, inputs=inputs) + task = device.run(program, shots=100, inputs=inputs) assert isinstance(task, LocalQuantumTask) assert isinstance(task.result().measurements, dict) return task.result().measurements @@ -461,7 +461,9 @@ def parametric(theta: float): measure(0) prog = parametric(FreeParameter("alpha")) - with pytest.raises(ValueError, match="No parameter in the program named: beta"): + with pytest.raises( + aq.errors.ParameterNotFoundError, match="No parameter in the program named: beta" + ): prog.make_bound_program({"beta": 0.5}, strict=True) @@ -484,5 +486,292 @@ def test_binding_variable_fails(): def parametric(): alpha = aq.FloatVar(1.2) # noqa: F841 - with pytest.raises(ValueError, match="No parameter in the program named: beta"): + with pytest.raises( + aq.errors.ParameterNotFoundError, match="No parameter in the program named: beta" + ): parametric().make_bound_program({"beta": 0.5}, strict=True) + + +def test_compound_condition(): + """Test parameters used in greater than conditional statements.""" + + @aq.main + def parametric(val: float): + threshold = 0.9 + if val > threshold or val >= 1.2: + x(0) + measure(0) + + expected = """OPENQASM 3.0; +input float[64] val; +qubit[1] __qubits__; +bool __bool_0__; +__bool_0__ = val > 0.9; +bool __bool_1__; +__bool_1__ = val >= 1.2; +bool __bool_2__; +__bool_2__ = __bool_0__ || __bool_1__; +if (__bool_2__) { + x __qubits__[0]; +} +bit __bit_3__; +__bit_3__ = measure __qubits__[0];""" + assert parametric(FreeParameter("val")).to_ir() == expected + + +def test_lt_condition(): + """Test parameters used in less than conditional statements.""" + + @aq.main + def parametric(val: float): + if val < 0.9: + x(0) + if val <= 0.9: + h(0) + measure(0) + + expected = """OPENQASM 3.0; +input float[64] val; +qubit[1] __qubits__; +bool __bool_0__; +__bool_0__ = val < 0.9; +if (__bool_0__) { + x __qubits__[0]; +} +bool __bool_1__; +__bool_1__ = val <= 0.9; +if (__bool_1__) { + h __qubits__[0]; +} +bit __bit_2__; +__bit_2__ = measure __qubits__[0];""" + assert parametric(FreeParameter("val")).to_ir() == expected + + +def test_parameter_in_predicate_in_subroutine(): + """Test parameters used in conditional statements.""" + + @aq.subroutine + def sub(val: float): + threshold = 0.9 + if val > threshold: + x(0) + + @aq.main + def parametric(val: float): + sub(val) + measure(0) + + expected = """OPENQASM 3.0; +def sub(float[64] val) { + bool __bool_0__; + __bool_0__ = val > 0.9; + if (__bool_0__) { + x __qubits__[0]; + } +} +input float[64] val; +qubit[1] __qubits__; +sub(val); +bit __bit_1__; +__bit_1__ = measure __qubits__[0];""" + assert parametric(FreeParameter("val")).to_ir() == expected + + +def test_eq_condition(): + """Test parameters used in conditional equals statements.""" + + @aq.main + def parametric(basis: int): + if basis == 1: + h(0) + elif basis == 2: + x(0) + else: + pass + measure(0) + + expected = """OPENQASM 3.0; +input float[64] basis; +qubit[1] __qubits__; +bool __bool_0__; +__bool_0__ = basis == 1; +if (__bool_0__) { + h __qubits__[0]; +} else { + bool __bool_1__; + __bool_1__ = basis == 2; + if (__bool_1__) { + x __qubits__[0]; + } +} +bit __bit_2__; +__bit_2__ = measure __qubits__[0];""" + assert parametric(FreeParameter("basis")).to_ir() == expected + + +def test_sim_conditional_stmts(): + @aq.main + def main(basis: int): + if basis == 1: + h(0) + else: + x(0) + c = measure(0) # noqa: F841 + + measurements = _test_parametric_on_local_sim(main(FreeParameter("basis")), {"basis": 0}) + assert all(val == 1 for val in measurements["c"]) + measurements = _test_parametric_on_local_sim(main(FreeParameter("basis")), {"basis": 1}) + assert 1 in measurements["c"] and 0 in measurements["c"] + + +def test_sim_comparison_stmts(): + @aq.main + def main(basis: int): + if basis > 0.5: + x(0) + c = measure(0) # noqa: F841 + + measurements = _test_parametric_on_local_sim(main(FreeParameter("basis")), {"basis": 0.5}) + assert all(val == 0 for val in measurements["c"]) + measurements = _test_parametric_on_local_sim(main(FreeParameter("basis")), {"basis": 0.55}) + assert all(val == 1 for val in measurements["c"]) + + +def test_param_neq(): + """Test parameters used in conditional not equals statements.""" + + @aq.main + def parametric(val: int): + if val != 1: + h(0) + measure(0) + + expected = """OPENQASM 3.0; +input float[64] val; +qubit[1] __qubits__; +bool __bool_0__; +__bool_0__ = val != 1; +if (__bool_0__) { + h __qubits__[0]; +} +bit __bit_1__; +__bit_1__ = measure __qubits__[0];""" + assert parametric(FreeParameter("val")).to_ir() == expected + + +def test_param_or(): + """Test parameters used in conditional `or` statements.""" + + @aq.main + def parametric(alpha: float, beta: float): + if alpha or beta: + rx(0, alpha) + rx(0, beta) + measure(0) + + expected = """OPENQASM 3.0; +input float[64] alpha; +input float[64] beta; +qubit[1] __qubits__; +bool __bool_0__; +__bool_0__ = alpha || beta; +if (__bool_0__) { + rx(alpha) __qubits__[0]; + rx(beta) __qubits__[0]; +} +bit __bit_1__; +__bit_1__ = measure __qubits__[0];""" + assert parametric(FreeParameter("alpha"), FreeParameter("beta")).to_ir() == expected + + +def test_param_and(): + """Test parameters used in conditional `and` statements.""" + + @aq.main + def parametric(alpha: float, beta: float): + if alpha and beta: + rx(0, alpha) + measure(0) + + expected = """OPENQASM 3.0; +input float[64] alpha; +input float[64] beta; +qubit[1] __qubits__; +bool __bool_0__; +__bool_0__ = alpha && beta; +if (__bool_0__) { + rx(alpha) __qubits__[0]; +} +bit __bit_1__; +__bit_1__ = measure __qubits__[0];""" + assert parametric(FreeParameter("alpha"), FreeParameter("beta")).to_ir() == expected + + +def test_param_and_float(): + """Test parameters used in conditional `and` statements.""" + + @aq.main + def parametric(alpha: float, beta: float): + if alpha and beta: + rx(0, alpha) + measure(0) + + expected = """OPENQASM 3.0; +input float[64] alpha; +qubit[1] __qubits__; +bool __bool_0__; +__bool_0__ = alpha && 1.5; +if (__bool_0__) { + rx(alpha) __qubits__[0]; +} +bit __bit_1__; +__bit_1__ = measure __qubits__[0];""" + assert parametric(FreeParameter("alpha"), 1.5).to_ir() == expected + + +def test_param_not(): + """Test parameters used in conditional `not` statements.""" + + @aq.main + def parametric(val: int): + if not val: + h(0) + measure(0) + + expected = """OPENQASM 3.0; +input float[64] val; +qubit[1] __qubits__; +bool __bool_0__; +__bool_0__ = !val; +if (__bool_0__) { + h __qubits__[0]; +} +bit __bit_1__; +__bit_1__ = measure __qubits__[0];""" + assert parametric(FreeParameter("val")).to_ir() == expected + + +def test_parameter_binding_conditions(): + """Test that parameters can be used in conditions and then bound.""" + + @aq.main + def parametric(val: float): + if val == 1: + x(0) + measure(0) + + template = """OPENQASM 3.0; +float[64] val = {}; +qubit[1] __qubits__; +bool __bool_0__; +__bool_0__ = val == 1; +if (__bool_0__) {{ + x __qubits__[0]; +}} +bit __bit_1__; +__bit_1__ = measure __qubits__[0];""" + bound_prog = parametric(FreeParameter("val")).make_bound_program({"val": 0}) + assert bound_prog.to_ir() == template.format(0) + bound_prog = parametric(FreeParameter("val")).make_bound_program({"val": 1}) + assert bound_prog.to_ir() == template.format(1) diff --git a/test/unit_tests/braket/experimental/autoqasm/test_program.py b/test/unit_tests/braket/experimental/autoqasm/test_program.py index 274e70d6c..04dde2601 100644 --- a/test/unit_tests/braket/experimental/autoqasm/test_program.py +++ b/test/unit_tests/braket/experimental/autoqasm/test_program.py @@ -20,6 +20,7 @@ import pytest import braket.experimental.autoqasm as aq +from braket.circuits import FreeParameter from braket.circuits.serialization import IRType from braket.experimental.autoqasm.instructions import cnot, measure, rx @@ -39,6 +40,14 @@ def test_program_conversion_context() -> None: assert len(prog._oqpy_program_stack) == 1 +def test_get_parameter_invalid_name(): + """Tests the get_parameter function.""" + prog = aq.program.ProgramConversionContext() + prog.register_parameter(FreeParameter("alpha")) + with pytest.raises(aq.errors.ParameterNotFoundError): + prog.get_parameter("not_a_parameter") + + def test_build_program() -> None: """Tests the aq.build_program function.""" with pytest.raises(AssertionError):