Skip to content

Commit

Permalink
Support declaration and assignment of array variables
Browse files Browse the repository at this point in the history
  • Loading branch information
rmshaffer committed Oct 26, 2023
1 parent 09ef788 commit 498dc56
Show file tree
Hide file tree
Showing 11 changed files with 207 additions and 101 deletions.
4 changes: 4 additions & 0 deletions src/braket/experimental/autoqasm/errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,10 @@ def __str__(self):
return self.message


class InvalidAssignmentStatement(AutoQasmError):
"""Invalid assignment statement for an AutoQASM variable."""


class InvalidArrayDeclaration(AutoQasmError):
"""Invalid declaration of an AutoQASM array variable."""

Expand Down
27 changes: 21 additions & 6 deletions src/braket/experimental/autoqasm/operators/assignments.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,8 +85,13 @@ def assign_stmt(target_name: str, value: Any) -> Any:
oqpy_program = program_conversion_context.get_oqpy_program()
if is_value_name_used or value.init_expression is None:
oqpy_program.set(target, value)
elif target.name not in oqpy_program.declared_vars and program_conversion_context.at_root_scope:
# Explicitly declare and initialize the variable at the root scope.
target.init_expression = value.init_expression
oqpy_program.declare(target)
else:
# Set to `value.init_expression` to avoid declaring an unnecessary variable.
# The variable will be set in the current scope and auto-declared at the root scope.
oqpy_program.set(target, value.init_expression)

return target
Expand All @@ -113,10 +118,20 @@ def _validate_variables_type_size(var1: oqpy.base.Var, var2: oqpy.base.Var) -> N
var1 (oqpy.base.Var): Variable to validate.
var2 (oqpy.base.Var): Variable to validate.
"""
var1_size = var1.size or 1
var2_size = var2.size or 1

if var_type_from_oqpy(var1) != var_type_from_oqpy(var2):
raise ValueError("Variables in assignment statements must have the same type")
if var1_size != var2_size:
raise ValueError("Variables in assignment statements must have the same size")
raise errors.InvalidAssignmentStatement(
"Variables in assignment statements must have the same type"
)

if isinstance(var1, oqpy.ArrayVar):
if var1.dimensions != var2.dimensions:
raise errors.InvalidAssignmentStatement(
"Arrays in assignment statements must have the same dimensions"
)
else:
var1_size = var1.size or 1
var2_size = var2.size or 1
if var1_size != var2_size:
raise errors.InvalidAssignmentStatement(
"Variables in assignment statements must have the same size"
)
Original file line number Diff line number Diff line change
Expand Up @@ -52,20 +52,20 @@ def _oqpy_if_exp(
) -> Optional[oqpy.base.Var]:
"""Overload of if_exp that stages an oqpy conditional."""
result_var = None
oqpy_program = aq_program.get_program_conversion_context().get_oqpy_program()
with oqpy.If(oqpy_program, cond):
program_conversion_context = aq_program.get_program_conversion_context()
with program_conversion_context.if_block(cond):
true_result = aq_types.wrap_value(if_true())
true_result_type = aq_types.var_type_from_oqpy(true_result)
if true_result is not None:
result_var = true_result_type()
oqpy_program.set(result_var, true_result)
with oqpy.Else(oqpy_program):
program_conversion_context.get_oqpy_program().set(result_var, true_result)
with program_conversion_context.else_block():
false_result = aq_types.wrap_value(if_false())
false_result_type = aq_types.var_type_from_oqpy(false_result)
if false_result_type != true_result_type:
raise UnsupportedConditionalExpressionError(true_result_type, false_result_type)
if false_result is not None:
oqpy_program.set(result_var, false_result)
program_conversion_context.get_oqpy_program().set(result_var, false_result)

return result_var

Expand Down
32 changes: 14 additions & 18 deletions src/braket/experimental/autoqasm/operators/control_flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,37 +44,33 @@ def for_stmt(
opts (dict): Options of the for loop.
"""
del get_state, set_state, symbol_names
if extra_test is not None:
raise NotImplementedError("break and return statements are not supported in for loops.")

if is_qasm_type(iter):
_oqpy_for_stmt(iter, extra_test, body, opts)
_oqpy_for_stmt(iter, body, opts)
else:
_py_for_stmt(iter, extra_test, body)
_py_for_stmt(iter, body)


def _oqpy_for_stmt(
iter: oqpy.Range,
extra_test: Callable[[], Any],
body: Callable[[Any], None],
opts: dict,
) -> None:
"""Overload of for_stmt that produces an oqpy for loop."""
oqpy_program = program.get_program_conversion_context().get_oqpy_program()
# TODO: Should check extra_test() on each iteration and break if False,
# but oqpy doesn't currently support break statements at the moment.
with oqpy.ForIn(oqpy_program, iter, opts["iterate_names"]) as f:
program_conversion_context = program.get_program_conversion_context()
with program_conversion_context.for_in(iter, opts["iterate_names"]) as f:
body(f)


def _py_for_stmt(
iter: Iterable,
extra_test: Callable[[], Any],
body: Callable[[Any], None],
) -> None:
"""Overload of for_stmt that executes a Python for loop."""
if extra_test is not None:
raise NotImplementedError("break and return statements are not supported in for loops.")
else:
for target in iter:
body(target)
for target in iter:
body(target)


def while_stmt(
Expand Down Expand Up @@ -107,8 +103,8 @@ def _oqpy_while_stmt(
body: Callable[[], None],
) -> None:
"""Overload of while_stmt that produces an oqpy while loop."""
oqpy_program = program.get_program_conversion_context().get_oqpy_program()
with oqpy.While(oqpy_program, test()):
program_conversion_context = program.get_program_conversion_context()
with program_conversion_context.while_loop(test()):
body()


Expand Down Expand Up @@ -154,10 +150,10 @@ def _oqpy_if_stmt(
orelse: Callable[[], Any],
) -> None:
"""Overload of if_stmt that stages an oqpy cond."""
oqpy_program = program.get_program_conversion_context().get_oqpy_program()
with oqpy.If(oqpy_program, cond):
program_conversion_context = program.get_program_conversion_context()
with program_conversion_context.if_block(cond):
body()
with oqpy.Else(oqpy_program):
with program_conversion_context.else_block():
orelse()


Expand Down
7 changes: 4 additions & 3 deletions src/braket/experimental/autoqasm/operators/slices.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
import oqpy.base

from braket.experimental.autoqasm import program
from braket.experimental.autoqasm.types import is_qasm_type
from braket.experimental.autoqasm.types import is_qasm_type, wrap_value


class GetItemOpts(collections.namedtuple("GetItemOpts", ("element_dtype",))):
Expand Down Expand Up @@ -86,9 +86,10 @@ def set_item(target: Any, i: Any, x: Any) -> Any:

def _oqpy_set_item(target: Any, i: Any, x: Any) -> Any:
"""Overload of set_item that produces an oqpy list modification."""
if not isinstance(target, oqpy.BitVar):
raise NotImplementedError("Slice assignment is not supported.")
if not isinstance(target, (oqpy.BitVar, oqpy.ArrayVar)):
raise NotImplementedError("Only BitVar and ArrayVar types support slice assignment.")

x = wrap_value(x)
is_var_name_used = program.get_program_conversion_context().is_var_name_used(x.name)
oqpy_program = program.get_program_conversion_context().get_oqpy_program()
if is_var_name_used or x.init_expression is None:
Expand Down
69 changes: 69 additions & 0 deletions src/braket/experimental/autoqasm/program/program.py
Original file line number Diff line number Diff line change
Expand Up @@ -199,6 +199,7 @@ def __init__(self, user_config: Optional[UserConfig] = None):
self.user_config = user_config or UserConfig()
self.return_variable = None
self.in_verbatim_block = False
self.at_root_scope = True
self._oqpy_program_stack = [oqpy.Program(simplify_constants=False)]
self._gate_definitions_processing = []
self._calibration_definitions_processing = []
Expand Down Expand Up @@ -447,6 +448,74 @@ def push_oqpy_program(self, oqpy_program: oqpy.Program) -> None:
finally:
self._oqpy_program_stack.pop()

@contextlib.contextmanager
def if_block(self, condition: Any) -> None:
"""Sets the program conversion context into an if block context.
Args:
condition (Any): The condition of the if block.
"""
oqpy_program = self.get_oqpy_program()
current_in_global_scope = self.at_root_scope
try:
self.at_root_scope = False
with oqpy.If(oqpy_program, condition):
yield
finally:
self.at_root_scope = current_in_global_scope

@contextlib.contextmanager
def else_block(self) -> None:
"""Sets the program conversion context into an else block context.
Must be immediately preceded by an if block.
"""
oqpy_program = self.get_oqpy_program()
current_in_global_scope = self.at_root_scope
try:
self.at_root_scope = False
with oqpy.Else(oqpy_program):
yield
finally:
self.at_root_scope = current_in_global_scope

@contextlib.contextmanager
def for_in(
self, iterator: oqpy.Range, iterator_name: Optional[str]
) -> contextlib._GeneratorContextManager:
"""Sets the program conversion context into a for loop context.
Args:
iterator (oqpy.Range): The iterator of the for loop.
iterator_name (Optional[str]): The symbol to use as the name of the iterator.
Yields:
_GeneratorContextManager: The context manager of the oqpy.ForIn block.
"""
oqpy_program = self.get_oqpy_program()
current_in_global_scope = self.at_root_scope
try:
self.at_root_scope = False
with oqpy.ForIn(oqpy_program, iterator, iterator_name) as f:
yield f
finally:
self.at_root_scope = current_in_global_scope

@contextlib.contextmanager
def while_loop(self, condition: Any) -> None:
"""Sets the program conversion context into a while loop context.
Args:
condition (Any): The condition of the while loop.
"""
oqpy_program = self.get_oqpy_program()
current_in_global_scope = self.at_root_scope
try:
self.at_root_scope = False
with oqpy.While(oqpy_program, condition):
yield
finally:
self.at_root_scope = current_in_global_scope

@contextlib.contextmanager
def gate_definition(self, gate_name: str, gate_args: GateArgs) -> None:
"""Sets the program conversion context into a gate definition context.
Expand Down
9 changes: 7 additions & 2 deletions src/braket/experimental/autoqasm/types/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,8 +58,13 @@ def qasm_range(start: int, stop: Optional[int] = None, step: Optional[int] = 1)

class ArrayVar(oqpy.ArrayVar):
def __init__(self, *args, **kwargs):
if program.get_program_conversion_context().subroutines_processing:
raise errors.InvalidArrayDeclaration("ArrayVar cannot be declared inside a subroutine.")
if (
program.get_program_conversion_context().subroutines_processing
or not program.get_program_conversion_context().at_root_scope
):
raise errors.InvalidArrayDeclaration(
"Arrays may only be declared at the root scope of an AutoQASM main function."
)
super(ArrayVar, self).__init__(*args, **kwargs)
self.name = program.get_program_conversion_context().next_var_name(oqpy.ArrayVar)

Expand Down
25 changes: 9 additions & 16 deletions test/unit_tests/braket/experimental/autoqasm/test_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -229,9 +229,8 @@ def bell_measurement_declared() -> None:

def test_bell_measurement_declared() -> None:
expected = """OPENQASM 3.0;
bit[2] c;
qubit[2] __qubits__;
c = "00";
bit[2] c = "00";
h __qubits__[0];
cnot __qubits__[0], __qubits__[1];
bit[2] __bit_1__ = "00";
Expand Down Expand Up @@ -279,7 +278,7 @@ def bell_measurement_invalid_declared_type() -> None:
def test_bell_measurement_invalid_declared_type() -> None:
"""Test measurement with reuslt stored in an variable with invalid type."""
expected_error_message = "Variables in assignment statements must have the same type"
with pytest.raises(ValueError) as exc_info:
with pytest.raises(errors.InvalidAssignmentStatement) as exc_info:
bell_measurement_invalid_declared_type()
assert expected_error_message in str(exc_info.value)

Expand All @@ -298,7 +297,7 @@ def bell_measurement_invalid_declared_size() -> None:
def test_bell_measurement_invalid_declared_size() -> None:
"""Test measurement with reuslt stored in an variable with invalid size."""
expected_error_message = "Variables in assignment statements must have the same size"
with pytest.raises(ValueError) as exc_info:
with pytest.raises(errors.InvalidAssignmentStatement) as exc_info:
bell_measurement_invalid_declared_size()
assert expected_error_message in str(exc_info.value)

Expand Down Expand Up @@ -755,20 +754,15 @@ def classical_variables_types() -> None:

def test_classical_variables_types():
expected = """OPENQASM 3.0;
bit a;
int[32] i;
bit[2] a_array;
int[32] b;
float[64] c;
a = 0;
bit a = 0;
a = 1;
i = 1;
a_array = "00";
int[32] i = 1;
bit[2] a_array = "00";
a_array[0] = 0;
a_array[i] = 1;
b = 10;
int[32] b = 10;
b = 15;
c = 1.2;
float[64] c = 1.2;
c = 3.4;"""
assert classical_variables_types().to_ir() == expected

Expand All @@ -786,9 +780,8 @@ def prog() -> None:
a = b # declared target, declared value # noqa: F841

expected = """OPENQASM 3.0;
int[32] a;
int[32] b;
a = 1;
int[32] a = 1;
a = 2;
b = a;
a = b;"""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -51,11 +51,9 @@ def fn() -> None:

qasm = program_conversion_context.make_program().to_ir()
expected_qasm = """OPENQASM 3.0;
int[32] a;
float[64] b;
int[32] e;
a = 5;
b = 1.2;
int[32] a = 5;
float[64] b = 1.2;
a = 1;
e = a;"""
assert qasm == expected_qasm
Expand Down
Loading

0 comments on commit 498dc56

Please sign in to comment.