Skip to content

Commit

Permalink
Improvements from PR comments
Browse files Browse the repository at this point in the history
  • Loading branch information
rmshaffer committed Oct 30, 2023
1 parent 828f399 commit d0b848a
Show file tree
Hide file tree
Showing 6 changed files with 142 additions and 80 deletions.
23 changes: 19 additions & 4 deletions src/braket/experimental/autoqasm/operators/assignments.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,22 +76,37 @@ def assign_stmt(target_name: str, value: Any) -> Any:

if is_target_name_used:
target = _get_oqpy_program_variable(target_name)
_validate_variables_type_size(target, value)
_validate_assignment_types(target, value)
else:
target = copy.copy(value)
target.init_expression = None
target.name = target_name

oqpy_program = program_conversion_context.get_oqpy_program()
if is_value_name_used or value.init_expression is None:
# Directly assign the value to the target.
# For example:
# a = b;
# where `b` is previously declared.
oqpy_program.set(target, value)
elif target.name not in oqpy_program.declared_vars and program_conversion_context.at_root_scope:
elif (
target.name not in oqpy_program.declared_vars
and program_conversion_context.at_function_root_scope
):
# Explicitly declare and initialize the variable at the root scope.
# For example:
# int[32] a = 10;
# where `a` is at the root scope of the function (not inside any if/for/while block).
target.init_expression = value.init_expression
oqpy_program.declare(target)
else:
# Set to `value.init_expression` to avoid declaring an unnecessary variable.
# The variable will be set in the current scope and auto-declared at the root scope.
# For example, the `a = 1` and `a = 0` statements in the following:
# int[32] a;
# if (b == True) { a = 1; }
# else { a = 0; }
# where `b` is previously declared.
oqpy_program.set(target, value.init_expression)

return target
Expand All @@ -111,8 +126,8 @@ def _get_oqpy_program_variable(var_name: str) -> oqpy.base.Var:
return variables[var_name]


def _validate_variables_type_size(var1: oqpy.base.Var, var2: oqpy.base.Var) -> None:
"""Raise error when the size or type of the two variables do not match.
def _validate_assignment_types(var1: oqpy.base.Var, var2: oqpy.base.Var) -> None:
"""Validates that the size and type of the variables are compatible for assignment.
Args:
var1 (oqpy.base.Var): Variable to validate.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,8 @@
"""Operators for other data structures (e.g. list)."""

import collections
from typing import Any, Iterable, Optional
from collections.abc import Iterable
from typing import Any, Optional


class ListPopOpts(collections.namedtuple("ListPopOpts", ("element_dtype", "element_shape"))):
Expand Down
10 changes: 9 additions & 1 deletion src/braket/experimental/autoqasm/operators/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,9 @@

"""Operators for exception handling."""

from typing import Callable
from collections.abc import Callable

from braket.experimental.autoqasm.types import is_qasm_type


def assert_stmt(test: bool, message: Callable) -> None:
Expand All @@ -25,4 +27,10 @@ def assert_stmt(test: bool, message: Callable) -> None:
message (Callable): A function which returns the message to be used
if the assertion fails.
"""
if is_qasm_type(test):
raise NotImplementedError(
"Assertions are not supported for values that depend on "
"measurement results or AutoQASM variables."
)

assert test, message()
63 changes: 28 additions & 35 deletions src/braket/experimental/autoqasm/program/program.py
Original file line number Diff line number Diff line change
Expand Up @@ -199,7 +199,7 @@ def __init__(self, user_config: Optional[UserConfig] = None):
self.user_config = user_config or UserConfig()
self.return_variable = None
self.in_verbatim_block = False
self.at_root_scope = True
self.at_function_root_scope = True # whether we are at the root scope of main or subroutine
self._oqpy_program_stack = [oqpy.Program(simplify_constants=False)]
self._gate_definitions_processing = []
self._calibration_definitions_processing = []
Expand Down Expand Up @@ -449,36 +449,39 @@ def push_oqpy_program(self, oqpy_program: oqpy.Program) -> None:
self._oqpy_program_stack.pop()

@contextlib.contextmanager
def if_block(self, condition: Any) -> None:
def _control_flow_block(
self, _context_manager: contextlib._GeneratorContextManager
) -> contextlib._GeneratorContextManager:
original = self.at_function_root_scope
try:
self.at_function_root_scope = False
with _context_manager as _cm:
yield _cm
finally:
self.at_function_root_scope = original

def if_block(self, condition: Any) -> contextlib._GeneratorContextManager:
"""Sets the program conversion context into an if block context.
Args:
condition (Any): The condition of the if block.
Yields:
_GeneratorContextManager: The context manager of the oqpy.If block.
"""
oqpy_program = self.get_oqpy_program()
current_in_global_scope = self.at_root_scope
try:
self.at_root_scope = False
with oqpy.If(oqpy_program, condition):
yield
finally:
self.at_root_scope = current_in_global_scope
return self._control_flow_block(oqpy.If(oqpy_program, condition))

@contextlib.contextmanager
def else_block(self) -> None:
def else_block(self) -> contextlib._GeneratorContextManager:
"""Sets the program conversion context into an else block context.
Must be immediately preceded by an if block.
Yields:
_GeneratorContextManager: The context manager of the oqpy.Else block.
"""
oqpy_program = self.get_oqpy_program()
current_in_global_scope = self.at_root_scope
try:
self.at_root_scope = False
with oqpy.Else(oqpy_program):
yield
finally:
self.at_root_scope = current_in_global_scope
return self._control_flow_block(oqpy.Else(oqpy_program))

@contextlib.contextmanager
def for_in(
self, iterator: oqpy.Range, iterator_name: Optional[str]
) -> contextlib._GeneratorContextManager:
Expand All @@ -492,29 +495,19 @@ def for_in(
_GeneratorContextManager: The context manager of the oqpy.ForIn block.
"""
oqpy_program = self.get_oqpy_program()
current_in_global_scope = self.at_root_scope
try:
self.at_root_scope = False
with oqpy.ForIn(oqpy_program, iterator, iterator_name) as f:
yield f
finally:
self.at_root_scope = current_in_global_scope
return self._control_flow_block(oqpy.ForIn(oqpy_program, iterator, iterator_name))

@contextlib.contextmanager
def while_loop(self, condition: Any) -> None:
def while_loop(self, condition: Any) -> contextlib._GeneratorContextManager:
"""Sets the program conversion context into a while loop context.
Args:
condition (Any): The condition of the while loop.
Yields:
_GeneratorContextManager: The context manager of the oqpy.While block.
"""
oqpy_program = self.get_oqpy_program()
current_in_global_scope = self.at_root_scope
try:
self.at_root_scope = False
with oqpy.While(oqpy_program, condition):
yield
finally:
self.at_root_scope = current_in_global_scope
return self._control_flow_block(oqpy.While(oqpy_program, condition))

@contextlib.contextmanager
def gate_definition(self, gate_name: str, gate_args: GateArgs) -> None:
Expand Down
2 changes: 1 addition & 1 deletion src/braket/experimental/autoqasm/types/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ class ArrayVar(oqpy.ArrayVar):
def __init__(self, *args, **kwargs):
if (
program.get_program_conversion_context().subroutines_processing
or not program.get_program_conversion_context().at_root_scope
or not program.get_program_conversion_context().at_function_root_scope
):
raise errors.InvalidArrayDeclaration(
"Arrays may only be declared at the root scope of an AutoQASM main function."
Expand Down
121 changes: 83 additions & 38 deletions test/unit_tests/braket/experimental/autoqasm/test_operators.py
Original file line number Diff line number Diff line change
Expand Up @@ -363,73 +363,107 @@ def test_logical_eq_qasm_cond() -> None:
assert "==" in qasm


def test_logical_ops_qasm() -> None:
"""Tests the logical aq.operators for QASM expressions."""

def test_logical_op_and() -> None:
@aq.subroutine
def do_and(a: bool, b: bool):
return a and b

@aq.main
def prog():
do_and(True, False)

expected = """OPENQASM 3.0;
def do_and(bool a, bool b) -> bool {
bool __bool_0__;
__bool_0__ = a && b;
return __bool_0__;
}
bool __bool_1__;
__bool_1__ = do_and(true, false);"""

assert prog().to_ir() == expected


def test_logical_op_or() -> None:
@aq.subroutine
def do_or(a: bool, b: bool):
return a or b

@aq.main
def prog():
do_or(True, False)

expected = """OPENQASM 3.0;
def do_or(bool a, bool b) -> bool {
bool __bool_0__;
__bool_0__ = a || b;
return __bool_0__;
}
bool __bool_1__;
__bool_1__ = do_or(true, false);"""

assert prog().to_ir() == expected


def test_logical_op_not() -> None:
@aq.subroutine
def do_not(a: bool):
return not a

@aq.main
def prog():
do_not(True)

expected = """OPENQASM 3.0;
def do_not(bool a) -> bool {
bool __bool_0__;
__bool_0__ = !a;
return __bool_0__;
}
bool __bool_1__;
__bool_1__ = do_not(true);"""

assert prog().to_ir() == expected


def test_logical_op_eq() -> None:
@aq.subroutine
def do_eq(a: int, b: int):
return a == b

@aq.main
def prog():
do_eq(1, 2)

expected = """OPENQASM 3.0;
def do_eq(int[32] a, int[32] b) -> bool {
bool __bool_0__;
__bool_0__ = a == b;
return __bool_0__;
}
bool __bool_1__;
__bool_1__ = do_eq(1, 2);"""

assert prog().to_ir() == expected


def test_logical_op_not_eq() -> None:
@aq.subroutine
def do_not_eq(a: int, b: int):
return a != b

@aq.main
def prog():
do_and(True, False)
do_or(True, False)
do_not(True)
do_eq(1, 2)
do_not_eq(1, 2)

expected = """OPENQASM 3.0;
def do_and(bool a, bool b) -> bool {
def do_not_eq(int[32] a, int[32] b) -> bool {
bool __bool_0__;
__bool_0__ = a && b;
__bool_0__ = a != b;
return __bool_0__;
}
def do_or(bool a, bool b) -> bool {
bool __bool_2__;
__bool_2__ = a || b;
return __bool_2__;
}
def do_not(bool a) -> bool {
bool __bool_4__;
__bool_4__ = !a;
return __bool_4__;
}
def do_eq(int[32] a, int[32] b) -> bool {
bool __bool_6__;
__bool_6__ = a == b;
return __bool_6__;
}
def do_not_eq(int[32] a, int[32] b) -> bool {
bool __bool_8__;
__bool_8__ = a != b;
return __bool_8__;
}
bool __bool_1__;
__bool_1__ = do_and(true, false);
bool __bool_3__;
__bool_3__ = do_or(true, false);
bool __bool_5__;
__bool_5__ = do_not(true);
bool __bool_7__;
__bool_7__ = do_eq(1, 2);
bool __bool_9__;
__bool_9__ = do_not_eq(1, 2);"""
__bool_1__ = do_not_eq(1, 2);"""

assert prog().to_ir() == expected

Expand Down Expand Up @@ -741,6 +775,17 @@ def test_assert(value: bool):
test_assert(False)


def test_measurement_assert() -> None:
"""Test assertions on measurement results inside an AutoQASM program."""

@aq.main
def test_assert():
assert measure(0)

with pytest.raises(NotImplementedError):
test_assert()


def test_py_list_ops() -> None:
"""Test Python list operations inside an AutoQASM program."""

Expand Down

0 comments on commit d0b848a

Please sign in to comment.