diff --git a/src/braket/experimental/autoqasm/operators/utils.py b/src/braket/experimental/autoqasm/operators/utils.py index a24b502ed..2960585d2 100644 --- a/src/braket/experimental/autoqasm/operators/utils.py +++ b/src/braket/experimental/autoqasm/operators/utils.py @@ -16,7 +16,7 @@ from typing import Any, Union -from braket.circuits import FreeParameter +from braket.circuits import FreeParameterExpression from braket.experimental.autoqasm import program from braket.experimental.autoqasm import types as aq_types @@ -39,8 +39,8 @@ def _register_and_convert_parameters( program_conversion_context.register_args(args) result = [] for arg in args: - if isinstance(arg, FreeParameter): - var = program.get_program_conversion_context().get_parameter(arg.name) + if isinstance(arg, FreeParameterExpression): + var = program.get_program_conversion_context().get_expression_var(arg) result.append(var) else: result.append(arg) diff --git a/src/braket/experimental/autoqasm/program/program.py b/src/braket/experimental/autoqasm/program/program.py index 9f2a241b8..1a79a202c 100644 --- a/src/braket/experimental/autoqasm/program/program.py +++ b/src/braket/experimental/autoqasm/program/program.py @@ -23,7 +23,9 @@ from typing import Any, Optional, Union import oqpy.base +from sympy import Symbol +import braket.experimental.autoqasm.types as aq_types from braket.circuits.free_parameter import FreeParameter from braket.circuits.free_parameter_expression import FreeParameterExpression from braket.circuits.serialization import IRType, SerializableProgram @@ -321,40 +323,60 @@ def register_args(self, args: list[Any]) -> None: args (list[Any]): Arguments passed to the main program or a subroutine. """ for arg in args: - if isinstance(arg, FreeParameter): - self.register_parameter(arg) - elif isinstance(arg, FreeParameterExpression): - # TODO laurecap: Support for expressions - raise NotImplementedError( - "Expressions of FreeParameters will be supported shortly!" - ) + if isinstance(arg, FreeParameterExpression): + for free_symbol_name in self._free_symbol_names(arg): + self.register_parameter(free_symbol_name) + + @staticmethod + def _free_symbol_names(expr: FreeParameterExpression) -> Iterable[str]: + """Return the names of any free symbols found in the provided expression + which are Symbol objects. + + Args: + expr (FreeParameterExpression): The expression in which to look for free symbols. + + Returns: + Iterable[str]: The list of free symbol names in sorted order (sorted to ensure + that the order is deterministic). + """ + return sorted([str(s) for s in expr._expression.free_symbols if isinstance(s, Symbol)]) - def register_parameter(self, parameter: FreeParameter) -> None: + def register_parameter(self, parameter_name: str) -> None: """Register an input parameter if it has not already been registered. Only floats are currently supported. Args: - parameter (FreeParameter): The parameter to register with the program. + parameter_name (str): The name of the parameter to register with the program. """ - if parameter.name not in self._free_parameters: - self._free_parameters[parameter.name] = oqpy.FloatVar("input", name=parameter.name) + if parameter_name not in self._free_parameters: + self._free_parameters[parameter_name] = oqpy.FloatVar("input", name=parameter_name) - def get_parameter(self, name: str) -> oqpy.FloatVar: - """Return a named oqpy.FloatVar that is used as a free parameter in the program. + def get_expression_var(self, expression: FreeParameterExpression) -> oqpy.FloatVar: + """Return an oqpy.FloatVar that represents the provided expression. Args: - name (str): The name of the parameter. + expression (FreeParameterExpression): The expression to represent. Raises: - ParameterNotFoundError: If there is no parameter with the given name registered - with the program. + ParameterNotFoundError: If the expression contains any free parameter which has + not already been registered with the program. Returns: - FloatVar: The associated variable. + FloatVar: The variable representing the expression. """ - if name not in self._free_parameters: - raise errors.ParameterNotFoundError(f"Free parameter '{name}' was not found.") - return self._free_parameters[name] + # Validate that all of the free symbols are registered as free parameters. + for name in self._free_symbol_names(expression): + if name not in self._free_parameters: + raise errors.ParameterNotFoundError(f"Free parameter '{name}' was not found.") + + # If the expression is just a standalone parameter, return the registered variable. + if isinstance(expression, FreeParameter): + return self._free_parameters[expression.name] + + # Otherwise, create a new variable and declare it here + var = aq_types.FloatVar(init_expression=expression) + self.get_oqpy_program().declare(var) + return var def get_free_parameters(self) -> list[oqpy.FloatVar]: """Return a list of named oqpy.Vars that are used as free parameters in the program.""" diff --git a/test/unit_tests/braket/experimental/autoqasm/test_parameters.py b/test/unit_tests/braket/experimental/autoqasm/test_parameters.py index 6633a36ad..5a9aff162 100644 --- a/test/unit_tests/braket/experimental/autoqasm/test_parameters.py +++ b/test/unit_tests/braket/experimental/autoqasm/test_parameters.py @@ -21,7 +21,17 @@ from braket.default_simulator import StateVectorSimulator from braket.devices.local_simulator import LocalSimulator from braket.experimental.autoqasm import pulse -from braket.experimental.autoqasm.instructions import cnot, cphaseshift, h, measure, ms, rx, rz, x +from braket.experimental.autoqasm.instructions import ( + cnot, + cphaseshift, + gpi, + h, + measure, + ms, + rx, + rz, + x, +) from braket.tasks.local_quantum_task import LocalQuantumTask @@ -775,3 +785,139 @@ def parametric(val: float): assert bound_prog.to_ir() == template.format(0) bound_prog = parametric(FreeParameter("val")).make_bound_program({"val": 1}) assert bound_prog.to_ir() == template.format(1) + + +def test_parameter_expressions(): + """Test expressions of free parameters with numeric literals.""" + + @aq.main + def parametric(): + expr = 2 * FreeParameter("theta") + gpi(0, expr) + + expected = """OPENQASM 3.0; +input float[64] theta; +qubit[1] __qubits__; +gpi(2*theta) __qubits__[0];""" + assert parametric().to_ir() == expected + + +def test_sim_expressions(): + @aq.main + def parametric(): + rx(0, 2 * FreeParameter("phi")) + measure(0) + + measurements = _test_parametric_on_local_sim(parametric(), {"phi": np.pi / 2}) + assert 0 not in measurements["__bit_0__"] + + +def test_multi_parameter_expressions(): + """Test expressions of multiple free parameters.""" + + @aq.main + def parametric(): + expr = FreeParameter("alpha") * FreeParameter("theta") + gpi(0, expr) + + expected = """OPENQASM 3.0; +input float[64] alpha; +input float[64] theta; +qubit[1] __qubits__; +gpi(alpha*theta) __qubits__[0];""" + assert parametric().to_ir() == expected + + +def test_bound_parameter_expressions(): + """Test expressions of free parameters bound to specific values.""" + + @aq.main + def parametric(): + rx(0, 2 * FreeParameter("phi")) + + expected = """OPENQASM 3.0; +float[64] phi = 1.5707963267948966; +qubit[1] __qubits__; +rx(2*phi) __qubits__[0];""" + assert parametric().make_bound_program({"phi": np.pi / 2}).to_ir() == expected + + +def test_partially_bound_parameter_expressions(): + """Test expressions of free parameters partially bound to specific values.""" + + @aq.main + def parametric(): + expr = FreeParameter("prefactor") * FreeParameter("theta") + gpi(0, expr) + + expected = """OPENQASM 3.0; +float[64] prefactor = 3; +input float[64] theta; +qubit[1] __qubits__; +gpi(prefactor*theta) __qubits__[0];""" + assert parametric().make_bound_program({"prefactor": 3}).to_ir() == expected + + +def test_subroutine_parameter_expressions(): + """Test expressions of free parameters passed to subroutines.""" + + @aq.subroutine + def rotate(theta: float): + rx(0, 3 * theta) + + @aq.main + def parametric(): + rotate(2 * FreeParameter("alpha")) + + expected = """OPENQASM 3.0; +def rotate(float[64] theta) { + rx(3 * theta) __qubits__[0]; +} +input float[64] alpha; +qubit[1] __qubits__; +rotate(2*alpha);""" + assert parametric().to_ir() == expected + + +def test_gate_parameter_expressions(): + """Test expressions of free parameters passed to custom gates.""" + + @aq.gate + def rotate(q: aq.Qubit, theta: float): + rx(q, 3 * theta) + + @aq.main + def parametric(): + rotate(0, 2 * FreeParameter("alpha")) + + expected = """OPENQASM 3.0; +gate rotate(theta) q { + rx(3 * theta) q; +} +input float[64] alpha; +qubit[1] __qubits__; +rotate(2*alpha) __qubits__[0];""" + assert parametric().to_ir() == expected + + +def test_conditional_parameter_expressions(): + """Test expressions of free parameters contained in conditional statements.""" + + @aq.main + def parametric(): + if 2 * FreeParameter("phi") > np.pi: + h(0) + measure(0) + + expected = """OPENQASM 3.0; +input float[64] phi; +qubit[1] __qubits__; +float[64] __float_0__ = 2*phi; +bool __bool_1__; +__bool_1__ = __float_0__ > 3.141592653589793; +if (__bool_1__) { + h __qubits__[0]; +} +bit __bit_2__; +__bit_2__ = measure __qubits__[0];""" + assert parametric().to_ir() == expected diff --git a/test/unit_tests/braket/experimental/autoqasm/test_program.py b/test/unit_tests/braket/experimental/autoqasm/test_program.py index 04dde2601..7d809ed7b 100644 --- a/test/unit_tests/braket/experimental/autoqasm/test_program.py +++ b/test/unit_tests/braket/experimental/autoqasm/test_program.py @@ -40,12 +40,14 @@ def test_program_conversion_context() -> None: assert len(prog._oqpy_program_stack) == 1 -def test_get_parameter_invalid_name(): - """Tests the get_parameter function.""" +def test_get_expression_var_invalid_name(): + """Tests the get_expression_var function.""" prog = aq.program.ProgramConversionContext() - prog.register_parameter(FreeParameter("alpha")) + prog.register_parameter("alpha") with pytest.raises(aq.errors.ParameterNotFoundError): - prog.get_parameter("not_a_parameter") + prog.get_expression_var(FreeParameter("not_a_parameter")) + with pytest.raises(aq.errors.ParameterNotFoundError): + prog.get_expression_var(3 * FreeParameter("also_not_a_parameter")) def test_build_program() -> None: