Skip to content

Commit

Permalink
feature: AutoQASM types can accept a name when initialized
Browse files Browse the repository at this point in the history
  • Loading branch information
laurencap committed Nov 10, 2023
1 parent 9bfd34c commit 5410556
Show file tree
Hide file tree
Showing 2 changed files with 77 additions and 16 deletions.
69 changes: 54 additions & 15 deletions src/braket/experimental/autoqasm/types/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,40 +57,79 @@ def qasm_range(start: int, stop: Optional[int] = None, step: Optional[int] = 1)


class ArrayVar(oqpy.ArrayVar):
def __init__(self, *args, **kwargs):
def __init__(self, name: Optional[str] = None):
"""Initialize an array variable.
Args:
name (Optional[str]): Name of the variable. If None is provided,
a name will be autogenerated.
Raises:
InvalidArrayDeclaration: If the array is not declared at the root scope of
an AutoQASM main function.
"""
if (
program.get_program_conversion_context().subroutines_processing
or not program.get_program_conversion_context().at_function_root_scope
):
raise errors.InvalidArrayDeclaration(
"Arrays may only be declared at the root scope of an AutoQASM main function."
)
super(ArrayVar, self).__init__(*args, **kwargs)
self.name = program.get_program_conversion_context().next_var_name(oqpy.ArrayVar)
if not name:
name = program.get_program_conversion_context().next_var_name(oqpy.ArrayVar)
super(ArrayVar, self).__init__(name)


class BitVar(oqpy.BitVar):
def __init__(self, *args, **kwargs):
super(BitVar, self).__init__(*args, **kwargs)
self.name = program.get_program_conversion_context().next_var_name(oqpy.BitVar)
def __init__(self, name: Optional[str] = None):
"""Initialize a bit variable.
Args:
name (Optional[str]): Name of the variable. If None is provided,
a name will be autogenerated.
"""
if not name:
name = program.get_program_conversion_context().next_var_name(oqpy.BitVar)
super(BitVar, self).__init__(name)
if self.size:
value = self.init_expression or 0
self.init_expression = ast.BitstringLiteral(value, self.size)


class BoolVar(oqpy.BoolVar):
def __init__(self, *args, **kwargs):
super(BoolVar, self).__init__(*args, **kwargs)
self.name = program.get_program_conversion_context().next_var_name(oqpy.BoolVar)
def __init__(self, name: Optional[str] = None):
"""Initialize a boolean variable.
Args:
name (Optional[str]): Name of the variable. If None is provided,
a name will be autogenerated.
"""
if not name:
name = program.get_program_conversion_context().next_var_name(oqpy.BoolVar)
super(BoolVar, self).__init__(name)


class FloatVar(oqpy.FloatVar):
def __init__(self, *args, **kwargs):
super(FloatVar, self).__init__(*args, **kwargs)
self.name = program.get_program_conversion_context().next_var_name(oqpy.FloatVar)
def __init__(self, name: Optional[str] = None):
"""Initialize a float variable.
Args:
name (Optional[str]): Name of the variable. If None is provided,
a name will be autogenerated.
"""
if not name:
name = program.get_program_conversion_context().next_var_name(oqpy.FloatVar)
super(FloatVar, self).__init__(name)


class IntVar(oqpy.IntVar):
def __init__(self, *args, **kwargs):
super(IntVar, self).__init__(*args, **kwargs)
self.name = program.get_program_conversion_context().next_var_name(oqpy.IntVar)
def __init__(self, name: Optional[str] = None):
"""Initialize an integer variable.
Args:
name (Optional[str]): Name of the variable. If None is provided,
a name will be autogenerated.
"""
if not name:
name = program.get_program_conversion_context().next_var_name(oqpy.IntVar)
super(IntVar, self).__init__(name)
24 changes: 23 additions & 1 deletion test/unit_tests/braket/experimental/autoqasm/test_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,14 @@
import pytest

import braket.experimental.autoqasm as aq
from braket.experimental.autoqasm.types.types import qasm_range
from braket.experimental.autoqasm.types.types import (
ArrayVar,
BitVar,
BoolVar,
FloatVar,
IntVar,
qasm_range,
)


@pytest.mark.parametrize(
Expand All @@ -42,6 +49,21 @@ def test_qasm_range(
assert (qrange.start, qrange.stop, qrange.step) == expected_range_params


@pytest.mark.parametrize("type_", [FloatVar, IntVar, BitVar, BoolVar])
def test_manual_name_for_aq_types(type_):
"""Test that types can accept a given name."""
with aq.build_program():
var = type_(name="test")
assert var.name == "test"


def test_manual_name_for_arrayvar():
"""Test that types can accept a given name."""
with aq.build_program():
var = ArrayVar(name="test", dimensions=[1])
assert var.name == "test"


def test_return_bit():
"""Test type discovery of bit return values."""

Expand Down

0 comments on commit 5410556

Please sign in to comment.