Skip to content

Commit

Permalink
feat: Support FreeParameter expressions in AutoQASM programs (#798)
Browse files Browse the repository at this point in the history
  • Loading branch information
rmshaffer authored Nov 15, 2023
1 parent 45c051a commit 87d48ff
Show file tree
Hide file tree
Showing 4 changed files with 198 additions and 28 deletions.
6 changes: 3 additions & 3 deletions src/braket/experimental/autoqasm/operators/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@

from typing import Any, Union

from braket.circuits import FreeParameter
from braket.circuits import FreeParameterExpression
from braket.experimental.autoqasm import program
from braket.experimental.autoqasm import types as aq_types

Expand All @@ -39,8 +39,8 @@ def _register_and_convert_parameters(
program_conversion_context.register_args(args)
result = []
for arg in args:
if isinstance(arg, FreeParameter):
var = program.get_program_conversion_context().get_parameter(arg.name)
if isinstance(arg, FreeParameterExpression):
var = program.get_program_conversion_context().get_expression_var(arg)
result.append(var)
else:
result.append(arg)
Expand Down
62 changes: 42 additions & 20 deletions src/braket/experimental/autoqasm/program/program.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,9 @@
from typing import Any, Optional, Union

import oqpy.base
from sympy import Symbol

import braket.experimental.autoqasm.types as aq_types
from braket.circuits.free_parameter import FreeParameter
from braket.circuits.free_parameter_expression import FreeParameterExpression
from braket.circuits.serialization import IRType, SerializableProgram
Expand Down Expand Up @@ -321,40 +323,60 @@ def register_args(self, args: list[Any]) -> None:
args (list[Any]): Arguments passed to the main program or a subroutine.
"""
for arg in args:
if isinstance(arg, FreeParameter):
self.register_parameter(arg)
elif isinstance(arg, FreeParameterExpression):
# TODO laurecap: Support for expressions
raise NotImplementedError(
"Expressions of FreeParameters will be supported shortly!"
)
if isinstance(arg, FreeParameterExpression):
for free_symbol_name in self._free_symbol_names(arg):
self.register_parameter(free_symbol_name)

@staticmethod
def _free_symbol_names(expr: FreeParameterExpression) -> Iterable[str]:
"""Return the names of any free symbols found in the provided expression
which are Symbol objects.
Args:
expr (FreeParameterExpression): The expression in which to look for free symbols.
Returns:
Iterable[str]: The list of free symbol names in sorted order (sorted to ensure
that the order is deterministic).
"""
return sorted([str(s) for s in expr._expression.free_symbols if isinstance(s, Symbol)])

def register_parameter(self, parameter: FreeParameter) -> None:
def register_parameter(self, parameter_name: str) -> None:
"""Register an input parameter if it has not already been registered.
Only floats are currently supported.
Args:
parameter (FreeParameter): The parameter to register with the program.
parameter_name (str): The name of the parameter to register with the program.
"""
if parameter.name not in self._free_parameters:
self._free_parameters[parameter.name] = oqpy.FloatVar("input", name=parameter.name)
if parameter_name not in self._free_parameters:
self._free_parameters[parameter_name] = oqpy.FloatVar("input", name=parameter_name)

def get_parameter(self, name: str) -> oqpy.FloatVar:
"""Return a named oqpy.FloatVar that is used as a free parameter in the program.
def get_expression_var(self, expression: FreeParameterExpression) -> oqpy.FloatVar:
"""Return an oqpy.FloatVar that represents the provided expression.
Args:
name (str): The name of the parameter.
expression (FreeParameterExpression): The expression to represent.
Raises:
ParameterNotFoundError: If there is no parameter with the given name registered
with the program.
ParameterNotFoundError: If the expression contains any free parameter which has
not already been registered with the program.
Returns:
FloatVar: The associated variable.
FloatVar: The variable representing the expression.
"""
if name not in self._free_parameters:
raise errors.ParameterNotFoundError(f"Free parameter '{name}' was not found.")
return self._free_parameters[name]
# Validate that all of the free symbols are registered as free parameters.
for name in self._free_symbol_names(expression):
if name not in self._free_parameters:
raise errors.ParameterNotFoundError(f"Free parameter '{name}' was not found.")

# If the expression is just a standalone parameter, return the registered variable.
if isinstance(expression, FreeParameter):
return self._free_parameters[expression.name]

# Otherwise, create a new variable and declare it here
var = aq_types.FloatVar(init_expression=expression)
self.get_oqpy_program().declare(var)
return var

def get_free_parameters(self) -> list[oqpy.FloatVar]:
"""Return a list of named oqpy.Vars that are used as free parameters in the program."""
Expand Down
148 changes: 147 additions & 1 deletion test/unit_tests/braket/experimental/autoqasm/test_parameters.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,17 @@
from braket.default_simulator import StateVectorSimulator
from braket.devices.local_simulator import LocalSimulator
from braket.experimental.autoqasm import pulse
from braket.experimental.autoqasm.instructions import cnot, cphaseshift, h, measure, ms, rx, rz, x
from braket.experimental.autoqasm.instructions import (
cnot,
cphaseshift,
gpi,
h,
measure,
ms,
rx,
rz,
x,
)
from braket.tasks.local_quantum_task import LocalQuantumTask


Expand Down Expand Up @@ -775,3 +785,139 @@ def parametric(val: float):
assert bound_prog.to_ir() == template.format(0)
bound_prog = parametric(FreeParameter("val")).make_bound_program({"val": 1})
assert bound_prog.to_ir() == template.format(1)


def test_parameter_expressions():
"""Test expressions of free parameters with numeric literals."""

@aq.main
def parametric():
expr = 2 * FreeParameter("theta")
gpi(0, expr)

expected = """OPENQASM 3.0;
input float[64] theta;
qubit[1] __qubits__;
gpi(2*theta) __qubits__[0];"""
assert parametric().to_ir() == expected


def test_sim_expressions():
@aq.main
def parametric():
rx(0, 2 * FreeParameter("phi"))
measure(0)

measurements = _test_parametric_on_local_sim(parametric(), {"phi": np.pi / 2})
assert 0 not in measurements["__bit_0__"]


def test_multi_parameter_expressions():
"""Test expressions of multiple free parameters."""

@aq.main
def parametric():
expr = FreeParameter("alpha") * FreeParameter("theta")
gpi(0, expr)

expected = """OPENQASM 3.0;
input float[64] alpha;
input float[64] theta;
qubit[1] __qubits__;
gpi(alpha*theta) __qubits__[0];"""
assert parametric().to_ir() == expected


def test_bound_parameter_expressions():
"""Test expressions of free parameters bound to specific values."""

@aq.main
def parametric():
rx(0, 2 * FreeParameter("phi"))

expected = """OPENQASM 3.0;
float[64] phi = 1.5707963267948966;
qubit[1] __qubits__;
rx(2*phi) __qubits__[0];"""
assert parametric().make_bound_program({"phi": np.pi / 2}).to_ir() == expected


def test_partially_bound_parameter_expressions():
"""Test expressions of free parameters partially bound to specific values."""

@aq.main
def parametric():
expr = FreeParameter("prefactor") * FreeParameter("theta")
gpi(0, expr)

expected = """OPENQASM 3.0;
float[64] prefactor = 3;
input float[64] theta;
qubit[1] __qubits__;
gpi(prefactor*theta) __qubits__[0];"""
assert parametric().make_bound_program({"prefactor": 3}).to_ir() == expected


def test_subroutine_parameter_expressions():
"""Test expressions of free parameters passed to subroutines."""

@aq.subroutine
def rotate(theta: float):
rx(0, 3 * theta)

@aq.main
def parametric():
rotate(2 * FreeParameter("alpha"))

expected = """OPENQASM 3.0;
def rotate(float[64] theta) {
rx(3 * theta) __qubits__[0];
}
input float[64] alpha;
qubit[1] __qubits__;
rotate(2*alpha);"""
assert parametric().to_ir() == expected


def test_gate_parameter_expressions():
"""Test expressions of free parameters passed to custom gates."""

@aq.gate
def rotate(q: aq.Qubit, theta: float):
rx(q, 3 * theta)

@aq.main
def parametric():
rotate(0, 2 * FreeParameter("alpha"))

expected = """OPENQASM 3.0;
gate rotate(theta) q {
rx(3 * theta) q;
}
input float[64] alpha;
qubit[1] __qubits__;
rotate(2*alpha) __qubits__[0];"""
assert parametric().to_ir() == expected


def test_conditional_parameter_expressions():
"""Test expressions of free parameters contained in conditional statements."""

@aq.main
def parametric():
if 2 * FreeParameter("phi") > np.pi:
h(0)
measure(0)

expected = """OPENQASM 3.0;
input float[64] phi;
qubit[1] __qubits__;
float[64] __float_0__ = 2*phi;
bool __bool_1__;
__bool_1__ = __float_0__ > 3.141592653589793;
if (__bool_1__) {
h __qubits__[0];
}
bit __bit_2__;
__bit_2__ = measure __qubits__[0];"""
assert parametric().to_ir() == expected
10 changes: 6 additions & 4 deletions test/unit_tests/braket/experimental/autoqasm/test_program.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,12 +40,14 @@ def test_program_conversion_context() -> None:
assert len(prog._oqpy_program_stack) == 1


def test_get_parameter_invalid_name():
"""Tests the get_parameter function."""
def test_get_expression_var_invalid_name():
"""Tests the get_expression_var function."""
prog = aq.program.ProgramConversionContext()
prog.register_parameter(FreeParameter("alpha"))
prog.register_parameter("alpha")
with pytest.raises(aq.errors.ParameterNotFoundError):
prog.get_parameter("not_a_parameter")
prog.get_expression_var(FreeParameter("not_a_parameter"))
with pytest.raises(aq.errors.ParameterNotFoundError):
prog.get_expression_var(3 * FreeParameter("also_not_a_parameter"))


def test_build_program() -> None:
Expand Down

0 comments on commit 87d48ff

Please sign in to comment.