diff --git a/src/braket/experimental/autoqasm/errors.py b/src/braket/experimental/autoqasm/errors.py index 68c7e809f..5a6b1edef 100644 --- a/src/braket/experimental/autoqasm/errors.py +++ b/src/braket/experimental/autoqasm/errors.py @@ -100,6 +100,10 @@ def __str__(self): return self.message +class InvalidAssignmentStatement(AutoQasmError): + """Invalid assignment statement for an AutoQASM variable.""" + + class InvalidArrayDeclaration(AutoQasmError): """Invalid declaration of an AutoQASM array variable.""" diff --git a/src/braket/experimental/autoqasm/operators/assignments.py b/src/braket/experimental/autoqasm/operators/assignments.py index 8cc78858c..e78c47844 100644 --- a/src/braket/experimental/autoqasm/operators/assignments.py +++ b/src/braket/experimental/autoqasm/operators/assignments.py @@ -85,8 +85,13 @@ def assign_stmt(target_name: str, value: Any) -> Any: oqpy_program = program_conversion_context.get_oqpy_program() if is_value_name_used or value.init_expression is None: oqpy_program.set(target, value) + elif target.name not in oqpy_program.declared_vars and program_conversion_context.at_root_scope: + # Explicitly declare and initialize the variable at the root scope. + target.init_expression = value.init_expression + oqpy_program.declare(target) else: # Set to `value.init_expression` to avoid declaring an unnecessary variable. + # The variable will be set in the current scope and auto-declared at the root scope. oqpy_program.set(target, value.init_expression) return target @@ -113,10 +118,20 @@ def _validate_variables_type_size(var1: oqpy.base.Var, var2: oqpy.base.Var) -> N var1 (oqpy.base.Var): Variable to validate. var2 (oqpy.base.Var): Variable to validate. """ - var1_size = var1.size or 1 - var2_size = var2.size or 1 - if var_type_from_oqpy(var1) != var_type_from_oqpy(var2): - raise ValueError("Variables in assignment statements must have the same type") - if var1_size != var2_size: - raise ValueError("Variables in assignment statements must have the same size") + raise errors.InvalidAssignmentStatement( + "Variables in assignment statements must have the same type" + ) + + if isinstance(var1, oqpy.ArrayVar): + if var1.dimensions != var2.dimensions: + raise errors.InvalidAssignmentStatement( + "Arrays in assignment statements must have the same dimensions" + ) + else: + var1_size = var1.size or 1 + var2_size = var2.size or 1 + if var1_size != var2_size: + raise errors.InvalidAssignmentStatement( + "Variables in assignment statements must have the same size" + ) diff --git a/src/braket/experimental/autoqasm/operators/conditional_expressions.py b/src/braket/experimental/autoqasm/operators/conditional_expressions.py index 94a9e0f2b..27bfb7d21 100644 --- a/src/braket/experimental/autoqasm/operators/conditional_expressions.py +++ b/src/braket/experimental/autoqasm/operators/conditional_expressions.py @@ -52,20 +52,20 @@ def _oqpy_if_exp( ) -> Optional[oqpy.base.Var]: """Overload of if_exp that stages an oqpy conditional.""" result_var = None - oqpy_program = aq_program.get_program_conversion_context().get_oqpy_program() - with oqpy.If(oqpy_program, cond): + program_conversion_context = aq_program.get_program_conversion_context() + with program_conversion_context.if_block(cond): true_result = aq_types.wrap_value(if_true()) true_result_type = aq_types.var_type_from_oqpy(true_result) if true_result is not None: result_var = true_result_type() - oqpy_program.set(result_var, true_result) - with oqpy.Else(oqpy_program): + program_conversion_context.get_oqpy_program().set(result_var, true_result) + with program_conversion_context.else_block(): false_result = aq_types.wrap_value(if_false()) false_result_type = aq_types.var_type_from_oqpy(false_result) if false_result_type != true_result_type: raise UnsupportedConditionalExpressionError(true_result_type, false_result_type) if false_result is not None: - oqpy_program.set(result_var, false_result) + program_conversion_context.get_oqpy_program().set(result_var, false_result) return result_var diff --git a/src/braket/experimental/autoqasm/operators/control_flow.py b/src/braket/experimental/autoqasm/operators/control_flow.py index 33fb10965..8f454fd0b 100644 --- a/src/braket/experimental/autoqasm/operators/control_flow.py +++ b/src/braket/experimental/autoqasm/operators/control_flow.py @@ -44,37 +44,33 @@ def for_stmt( opts (dict): Options of the for loop. """ del get_state, set_state, symbol_names + if extra_test is not None: + raise NotImplementedError("break and return statements are not supported in for loops.") + if is_qasm_type(iter): - _oqpy_for_stmt(iter, extra_test, body, opts) + _oqpy_for_stmt(iter, body, opts) else: - _py_for_stmt(iter, extra_test, body) + _py_for_stmt(iter, body) def _oqpy_for_stmt( iter: oqpy.Range, - extra_test: Callable[[], Any], body: Callable[[Any], None], opts: dict, ) -> None: """Overload of for_stmt that produces an oqpy for loop.""" - oqpy_program = program.get_program_conversion_context().get_oqpy_program() - # TODO: Should check extra_test() on each iteration and break if False, - # but oqpy doesn't currently support break statements at the moment. - with oqpy.ForIn(oqpy_program, iter, opts["iterate_names"]) as f: + program_conversion_context = program.get_program_conversion_context() + with program_conversion_context.for_in(iter, opts["iterate_names"]) as f: body(f) def _py_for_stmt( iter: Iterable, - extra_test: Callable[[], Any], body: Callable[[Any], None], ) -> None: """Overload of for_stmt that executes a Python for loop.""" - if extra_test is not None: - raise NotImplementedError("break and return statements are not supported in for loops.") - else: - for target in iter: - body(target) + for target in iter: + body(target) def while_stmt( @@ -107,8 +103,8 @@ def _oqpy_while_stmt( body: Callable[[], None], ) -> None: """Overload of while_stmt that produces an oqpy while loop.""" - oqpy_program = program.get_program_conversion_context().get_oqpy_program() - with oqpy.While(oqpy_program, test()): + program_conversion_context = program.get_program_conversion_context() + with program_conversion_context.while_loop(test()): body() @@ -154,10 +150,10 @@ def _oqpy_if_stmt( orelse: Callable[[], Any], ) -> None: """Overload of if_stmt that stages an oqpy cond.""" - oqpy_program = program.get_program_conversion_context().get_oqpy_program() - with oqpy.If(oqpy_program, cond): + program_conversion_context = program.get_program_conversion_context() + with program_conversion_context.if_block(cond): body() - with oqpy.Else(oqpy_program): + with program_conversion_context.else_block(): orelse() diff --git a/src/braket/experimental/autoqasm/operators/slices.py b/src/braket/experimental/autoqasm/operators/slices.py index 0ed2ea47f..726abea84 100644 --- a/src/braket/experimental/autoqasm/operators/slices.py +++ b/src/braket/experimental/autoqasm/operators/slices.py @@ -20,7 +20,7 @@ import oqpy.base from braket.experimental.autoqasm import program -from braket.experimental.autoqasm.types import is_qasm_type +from braket.experimental.autoqasm.types import is_qasm_type, wrap_value class GetItemOpts(collections.namedtuple("GetItemOpts", ("element_dtype",))): @@ -86,9 +86,10 @@ def set_item(target: Any, i: Any, x: Any) -> Any: def _oqpy_set_item(target: Any, i: Any, x: Any) -> Any: """Overload of set_item that produces an oqpy list modification.""" - if not isinstance(target, oqpy.BitVar): - raise NotImplementedError("Slice assignment is not supported.") + if not isinstance(target, (oqpy.BitVar, oqpy.ArrayVar)): + raise NotImplementedError("Only BitVar and ArrayVar types support slice assignment.") + x = wrap_value(x) is_var_name_used = program.get_program_conversion_context().is_var_name_used(x.name) oqpy_program = program.get_program_conversion_context().get_oqpy_program() if is_var_name_used or x.init_expression is None: diff --git a/src/braket/experimental/autoqasm/program/program.py b/src/braket/experimental/autoqasm/program/program.py index b6eaed22e..32e1aee4f 100644 --- a/src/braket/experimental/autoqasm/program/program.py +++ b/src/braket/experimental/autoqasm/program/program.py @@ -199,6 +199,7 @@ def __init__(self, user_config: Optional[UserConfig] = None): self.user_config = user_config or UserConfig() self.return_variable = None self.in_verbatim_block = False + self.at_root_scope = True self._oqpy_program_stack = [oqpy.Program(simplify_constants=False)] self._gate_definitions_processing = [] self._calibration_definitions_processing = [] @@ -447,6 +448,74 @@ def push_oqpy_program(self, oqpy_program: oqpy.Program) -> None: finally: self._oqpy_program_stack.pop() + @contextlib.contextmanager + def if_block(self, condition: Any) -> None: + """Sets the program conversion context into an if block context. + + Args: + condition (Any): The condition of the if block. + """ + oqpy_program = self.get_oqpy_program() + current_in_global_scope = self.at_root_scope + try: + self.at_root_scope = False + with oqpy.If(oqpy_program, condition): + yield + finally: + self.at_root_scope = current_in_global_scope + + @contextlib.contextmanager + def else_block(self) -> None: + """Sets the program conversion context into an else block context. + Must be immediately preceded by an if block. + """ + oqpy_program = self.get_oqpy_program() + current_in_global_scope = self.at_root_scope + try: + self.at_root_scope = False + with oqpy.Else(oqpy_program): + yield + finally: + self.at_root_scope = current_in_global_scope + + @contextlib.contextmanager + def for_in( + self, iterator: oqpy.Range, iterator_name: Optional[str] + ) -> contextlib._GeneratorContextManager: + """Sets the program conversion context into a for loop context. + + Args: + iterator (oqpy.Range): The iterator of the for loop. + iterator_name (Optional[str]): The symbol to use as the name of the iterator. + + Yields: + _GeneratorContextManager: The context manager of the oqpy.ForIn block. + """ + oqpy_program = self.get_oqpy_program() + current_in_global_scope = self.at_root_scope + try: + self.at_root_scope = False + with oqpy.ForIn(oqpy_program, iterator, iterator_name) as f: + yield f + finally: + self.at_root_scope = current_in_global_scope + + @contextlib.contextmanager + def while_loop(self, condition: Any) -> None: + """Sets the program conversion context into a while loop context. + + Args: + condition (Any): The condition of the while loop. + """ + oqpy_program = self.get_oqpy_program() + current_in_global_scope = self.at_root_scope + try: + self.at_root_scope = False + with oqpy.While(oqpy_program, condition): + yield + finally: + self.at_root_scope = current_in_global_scope + @contextlib.contextmanager def gate_definition(self, gate_name: str, gate_args: GateArgs) -> None: """Sets the program conversion context into a gate definition context. diff --git a/src/braket/experimental/autoqasm/types/types.py b/src/braket/experimental/autoqasm/types/types.py index 68097df64..37f5c86fd 100644 --- a/src/braket/experimental/autoqasm/types/types.py +++ b/src/braket/experimental/autoqasm/types/types.py @@ -58,8 +58,13 @@ def qasm_range(start: int, stop: Optional[int] = None, step: Optional[int] = 1) class ArrayVar(oqpy.ArrayVar): def __init__(self, *args, **kwargs): - if program.get_program_conversion_context().subroutines_processing: - raise errors.InvalidArrayDeclaration("ArrayVar cannot be declared inside a subroutine.") + if ( + program.get_program_conversion_context().subroutines_processing + or not program.get_program_conversion_context().at_root_scope + ): + raise errors.InvalidArrayDeclaration( + "Arrays may only be declared at the root scope of an AutoQASM main function." + ) super(ArrayVar, self).__init__(*args, **kwargs) self.name = program.get_program_conversion_context().next_var_name(oqpy.ArrayVar) diff --git a/test/unit_tests/braket/experimental/autoqasm/test_api.py b/test/unit_tests/braket/experimental/autoqasm/test_api.py index 5d0973d40..057cc0cec 100644 --- a/test/unit_tests/braket/experimental/autoqasm/test_api.py +++ b/test/unit_tests/braket/experimental/autoqasm/test_api.py @@ -229,9 +229,8 @@ def bell_measurement_declared() -> None: def test_bell_measurement_declared() -> None: expected = """OPENQASM 3.0; -bit[2] c; qubit[2] __qubits__; -c = "00"; +bit[2] c = "00"; h __qubits__[0]; cnot __qubits__[0], __qubits__[1]; bit[2] __bit_1__ = "00"; @@ -279,7 +278,7 @@ def bell_measurement_invalid_declared_type() -> None: def test_bell_measurement_invalid_declared_type() -> None: """Test measurement with reuslt stored in an variable with invalid type.""" expected_error_message = "Variables in assignment statements must have the same type" - with pytest.raises(ValueError) as exc_info: + with pytest.raises(errors.InvalidAssignmentStatement) as exc_info: bell_measurement_invalid_declared_type() assert expected_error_message in str(exc_info.value) @@ -298,7 +297,7 @@ def bell_measurement_invalid_declared_size() -> None: def test_bell_measurement_invalid_declared_size() -> None: """Test measurement with reuslt stored in an variable with invalid size.""" expected_error_message = "Variables in assignment statements must have the same size" - with pytest.raises(ValueError) as exc_info: + with pytest.raises(errors.InvalidAssignmentStatement) as exc_info: bell_measurement_invalid_declared_size() assert expected_error_message in str(exc_info.value) @@ -755,20 +754,15 @@ def classical_variables_types() -> None: def test_classical_variables_types(): expected = """OPENQASM 3.0; -bit a; -int[32] i; -bit[2] a_array; -int[32] b; -float[64] c; -a = 0; +bit a = 0; a = 1; -i = 1; -a_array = "00"; +int[32] i = 1; +bit[2] a_array = "00"; a_array[0] = 0; a_array[i] = 1; -b = 10; +int[32] b = 10; b = 15; -c = 1.2; +float[64] c = 1.2; c = 3.4;""" assert classical_variables_types().to_ir() == expected @@ -786,9 +780,8 @@ def prog() -> None: a = b # declared target, declared value # noqa: F841 expected = """OPENQASM 3.0; -int[32] a; int[32] b; -a = 1; +int[32] a = 1; a = 2; b = a; a = b;""" diff --git a/test/unit_tests/braket/experimental/autoqasm/test_converters.py b/test/unit_tests/braket/experimental/autoqasm/test_converters.py index 6e169e532..c22af68b5 100644 --- a/test/unit_tests/braket/experimental/autoqasm/test_converters.py +++ b/test/unit_tests/braket/experimental/autoqasm/test_converters.py @@ -51,11 +51,9 @@ def fn() -> None: qasm = program_conversion_context.make_program().to_ir() expected_qasm = """OPENQASM 3.0; -int[32] a; -float[64] b; int[32] e; -a = 5; -b = 1.2; +int[32] a = 5; +float[64] b = 1.2; a = 1; e = a;""" assert qasm == expected_qasm diff --git a/test/unit_tests/braket/experimental/autoqasm/test_operators.py b/test/unit_tests/braket/experimental/autoqasm/test_operators.py index 66e066c5d..4f70a0ba9 100644 --- a/test/unit_tests/braket/experimental/autoqasm/test_operators.py +++ b/test/unit_tests/braket/experimental/autoqasm/test_operators.py @@ -21,6 +21,7 @@ import pytest import braket.experimental.autoqasm as aq +from braket.experimental.autoqasm import errors from braket.experimental.autoqasm.errors import UnsupportedConditionalExpressionError from braket.experimental.autoqasm.instructions import cnot, h, measure, x @@ -114,7 +115,6 @@ def cond_exp_assignment(): lambda: aq.FloatVar(2), lambda: aq.BoolVar(False), lambda: aq.BitVar(0), - lambda: aq.ArrayVar(dimensions=[3]), ], ) def test_unsupported_inline_conditional_assignment(else_value) -> None: @@ -162,9 +162,8 @@ def branch_assignment_declared(): a = aq.IntVar(7) # noqa: F841 expected = """OPENQASM 3.0; -int[32] a; bool __bool_1__ = true; -a = 5; +int[32] a = 5; if (__bool_1__) { a = 6; } else { @@ -540,10 +539,8 @@ def slice(): a[3] = b expected = """OPENQASM 3.0; -bit[6] a; -bit b; -a = "000000"; -b = 1; +bit[6] a = "000000"; +bit b = 1; a[3] = b;""" assert slice().to_ir() == expected @@ -559,10 +556,9 @@ def measure_to_slice(): b0[3] = c expected = """OPENQASM 3.0; -bit[10] b0; bit c; qubit[1] __qubits__; -b0 = "0000000000"; +bit[10] b0 = "0000000000"; bit __bit_1__; __bit_1__ = measure __qubits__[0]; c = __bit_1__; @@ -574,9 +570,9 @@ def measure_to_slice(): @pytest.mark.parametrize( "target_name,value,expected_qasm", [ - ("foo", oqpy.IntVar(5), "\nint[32] foo;\nfoo = 5;"), - ("bar", oqpy.FloatVar(1.2), "\nfloat[64] bar;\nbar = 1.2;"), - ("baz", oqpy.BitVar(0), "\nbit baz;\nbaz = 0;"), + ("foo", oqpy.IntVar(5), "\nint[32] foo = 5;"), + ("bar", oqpy.FloatVar(1.2), "\nfloat[64] bar = 1.2;"), + ("baz", oqpy.BitVar(0), "\nbit baz = 0;"), ], ) def test_assignment_qasm_undeclared_target( @@ -658,7 +654,7 @@ def test_assignment_qasm_invalid_size_type( oqpy_program = program_conversion_context.get_oqpy_program() oqpy_program.declare(declared_var) - with pytest.raises(ValueError) as exc_info: + with pytest.raises(errors.InvalidAssignmentStatement) as exc_info: _ = aq.operators.assign_stmt( target_name=target_name, value=value, diff --git a/test/unit_tests/braket/experimental/autoqasm/test_types.py b/test/unit_tests/braket/experimental/autoqasm/test_types.py index ecb632a4b..cc6f1f92a 100644 --- a/test/unit_tests/braket/experimental/autoqasm/test_types.py +++ b/test/unit_tests/braket/experimental/autoqasm/test_types.py @@ -56,8 +56,7 @@ def main() -> aq.BitVar: expected = """OPENQASM 3.0; def ret_test() -> bit { - bit res; - res = 1; + bit res = 1; return res; } bit __bit_1__; @@ -80,8 +79,7 @@ def main() -> int: expected = """OPENQASM 3.0; def ret_test() -> int[32] { - int[32] res; - res = 1; + int[32] res = 1; return res; } int[32] __int_1__; @@ -104,8 +102,7 @@ def main() -> float: expected = """OPENQASM 3.0; def ret_test() -> float[64] { - float[64] res; - res = 1.0; + float[64] res = 1.0; return res; } float[64] __float_1__; @@ -128,8 +125,7 @@ def main() -> bool: expected = """OPENQASM 3.0; def ret_test() -> bool { - bool res; - res = true; + bool res = true; return res; } bool __bool_1__; @@ -155,10 +151,8 @@ def ret_test() -> int: def add(int[32] a, int[32] b) -> int[32] { return a + b; } -int[32] a; -int[32] b; -a = 5; -b = 6; +int[32] a = 5; +int[32] b = 6; int[32] __int_2__; __int_2__ = add(a, b);""" @@ -177,6 +171,52 @@ def ret_test() -> None: assert ret_test().to_ir() == expected +def test_declare_array(): + """Test declaring and assigning an array.""" + + @aq.main + def declare_array(): + a = aq.ArrayVar([1, 2, 3], base_type=aq.IntVar, dimensions=[3]) + a[0] = 11 + b = aq.ArrayVar([4, 5, 6], base_type=aq.IntVar, dimensions=[3]) + b[2] = 14 + b = a + + expected = """OPENQASM 3.0; +array[int[32], 3] a = {1, 2, 3}; +a[0] = 11; +array[int[32], 3] b = {4, 5, 6}; +b[2] = 14; +b = a;""" + + assert declare_array().to_ir() == expected + + +def test_invalid_array_assignment(): + """Test invalid array assignment.""" + + @aq.main + def invalid(): + a = aq.ArrayVar([1, 2, 3], base_type=aq.IntVar, dimensions=[3]) + b = aq.ArrayVar([4, 5], base_type=aq.IntVar, dimensions=[2]) + a = b # noqa: F841 + + with pytest.raises(aq.errors.InvalidAssignmentStatement): + invalid() + + +def test_declare_array_in_local_scope(): + """Test declaring an array inside a local scope.""" + + @aq.main + def declare_array(): + if aq.BoolVar(True): + _ = aq.ArrayVar([1, 2, 3], base_type=aq.IntVar, dimensions=[3]) + + with pytest.raises(aq.errors.InvalidArrayDeclaration): + declare_array() + + def test_declare_array_in_subroutine(): """Test declaring an array inside a subroutine.""" @@ -236,8 +276,7 @@ def ret_test() -> int: expected = """OPENQASM 3.0; def helper() -> int[32] { - int[32] res; - res = 1; + int[32] res = 1; return res; } int[32] __int_1__; @@ -334,8 +373,7 @@ def main(): expected = """OPENQASM 3.0; def annotation_test(bit input) { } -bit a; -a = 1; +bit a = 1; annotation_test(a);""" assert main().to_ir() == expected @@ -376,12 +414,10 @@ def main(): expected = """OPENQASM 3.0; def assign_param(int[32] c) { - int[32] d; - d = 4; + int[32] d = 4; c = d; } -int[32] c; -c = 0; +int[32] c = 0; assign_param(c);""" assert main().to_ir() == expected @@ -400,8 +436,7 @@ def caller() -> int: expected_qasm = """OPENQASM 3.0; def retval_test() -> int[32] { - int[32] retval_; - retval_ = 1; + int[32] retval_ = 1; return retval_; } int[32] __int_1__; @@ -423,8 +458,7 @@ def caller() -> aq.BitVar: expected_qasm = """OPENQASM 3.0; def retval_test() -> bit { - bit retval_; - retval_ = 1; + bit retval_ = 1; return retval_; } bit __bit_1__; @@ -447,10 +481,9 @@ def main(): expected_qasm = """OPENQASM 3.0; def retval_recursive() -> int[32] { - int[32] retval_; int[32] __int_1__; __int_1__ = retval_recursive(); - retval_ = 1; + int[32] retval_ = 1; return retval_; } int[32] __int_3__; @@ -474,11 +507,10 @@ def main(): expected_qasm = """OPENQASM 3.0; def retval_recursive() -> int[32] { int[32] a; - int[32] retval_; int[32] __int_1__; __int_1__ = retval_recursive(); a = __int_1__; - retval_ = 1; + int[32] retval_ = 1; return retval_; } int[32] __int_3__; @@ -512,8 +544,7 @@ def retval_recursive() -> float[64] { return 2 * __float_1__ + (__int_3__ + 2) / 3; } def retval_constant() -> int[32] { - int[32] retval_; - retval_ = 3; + int[32] retval_ = 3; return retval_; } float[64] __float_4__; @@ -580,8 +611,7 @@ def main() -> bool: expected = """OPENQASM 3.0; def ret_test() -> bool { - bool retval_; - retval_ = true; + bool retval_ = true; return retval_; } bool __bool_1__; @@ -633,8 +663,7 @@ def main(): expected = """OPENQASM 3.0; def ret_test() -> float[64] { - float[64] retval_; - retval_ = 1.2; + float[64] retval_ = 1.2; return retval_; } qubit[4] __qubits__;