Skip to content

Commit

Permalink
declare input parameters with pulse sequences
Browse files Browse the repository at this point in the history
  • Loading branch information
jcjaskula-aws committed Dec 19, 2023
1 parent c6b6f40 commit ec2a559
Show file tree
Hide file tree
Showing 3 changed files with 66 additions and 2 deletions.
40 changes: 40 additions & 0 deletions src/braket/pulse/ast/qasm_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,3 +57,43 @@ def visit_ExpressionStatement(self, expression_statement: ast.ExpressionStatemen
return new_val
else:
return expression_statement

def visit_Program(self, program: ast.Program) -> ast.Program:
"""Visit a Program.
Args:
program (Program): The program.
Returns:
Program: the modified program.
"""
new_statement_list = []
for statement in program.statements:
if isinstance(statement, ast.CalibrationStatement):
input_vars, body = self.split_input_vars(statement.body)
new_statement_list.extend(input_vars)
new_statement_list.append(ast.CalibrationStatement(body))
else:
new_statement_list.append(statement)

Check warning on line 75 in src/braket/pulse/ast/qasm_transformer.py

View check run for this annotation

Codecov / codecov/patch

src/braket/pulse/ast/qasm_transformer.py#L75

Added line #L75 was not covered by tests

program.statements = new_statement_list
return self.generic_visit(program)

def split_input_vars(
self,
body: list[ast.Statement],
) -> tuple[list[ast.IODeclaration], list[ast.Statement]]:
"""Split input vars out of the calibrationStatement block
Args:
body (list[Statement]): The list of statement in the CalibrationStatement block
Returns:
tuple[list[IODeclaration], list[Statement]]: A tuple of input vars and the list
of remaining statements.
"""
input_vars = []
new_body = []
for child in body:
if isinstance(child, ast.IODeclaration) and child.io_identifier is ast.IOKeyword.input:
input_vars.append(child)
else:
new_body.append(child)
return input_vars, new_body
19 changes: 17 additions & 2 deletions src/braket/pulse/pulse_sequence.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
from typing import Any, Union

from openpulse import ast
from oqpy import BitVar, PhysicalQubits, Program
from oqpy import BitVar, FloatVar, PhysicalQubits, Program
from oqpy.timing import OQDurationLiteral

from braket.parametric.free_parameter import FreeParameter
Expand Down Expand Up @@ -228,12 +228,20 @@ def play(self, frame: Frame, waveform: Waveform) -> PulseSequence:
"""
_validate_uniqueness(self._frames, frame)
_validate_uniqueness(self._waveforms, waveform)
self._program.play(frame=frame, waveform=waveform)
if isinstance(waveform, Parameterizable):
for param in waveform.parameters:
if isinstance(param, FreeParameterExpression):
for p in param.expression.free_symbols:
self._program._add_var(
FloatVar(
name=p.name,
size=None,
init_expression="input",
needs_declaration=True,
)
)
self._free_parameters.add(FreeParameter(p.name))
self._program.play(frame=frame, waveform=waveform)
self._frames[frame.id] = frame
self._waveforms[waveform.id] = waveform
return self
Expand Down Expand Up @@ -280,6 +288,8 @@ def make_bound_pulse_sequence(self, param_values: dict[str, float]) -> PulseSequ

new_pulse_sequence = PulseSequence()
new_pulse_sequence._program = new_program
for param_name in param_values:
new_pulse_sequence._program.undeclared_vars.pop(param_name, None)
new_pulse_sequence._frames = deepcopy(self._frames)
new_pulse_sequence._waveforms = {
wf.id: wf.bind_values(**param_values) if isinstance(wf, Parameterizable) else wf
Expand Down Expand Up @@ -325,6 +335,11 @@ def _format_parameter_ast(
) -> Union[float, FreeParameterExpression]:
if isinstance(parameter, FreeParameterExpression):
for p in parameter.expression.free_symbols:
self._program._add_var(
FloatVar(
name=p.name, size=None, init_expression="input", needs_declaration=True
)
)
self._free_parameters.add(FreeParameter(p.name))
return (
FreeParameterExpression(parameter, _type)
Expand Down
9 changes: 9 additions & 0 deletions test/unit_tests/braket/pulse/test_pulse_sequence.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,13 @@ def test_pulse_sequence_make_bound_pulse_sequence(predefined_frame_1, predefined
expected_str_unbound = "\n".join(
[
"OPENQASM 3.0;",
"input float[64] b;",
"input float[64] a;",
"input float[64] length_g;",
"input float[64] sigma_g;",
"input float[64] length_dg;",
"input float[64] sigma_dg;",
"input float[64] length_c;",
"cal {",
" waveform gauss_wf = gaussian((length_g) * 1s, (sigma_g) * 1s, 1, false);",
" waveform drag_gauss_wf = drag_gaussian((length_dg) * 1s,"
Expand Down Expand Up @@ -168,6 +175,8 @@ def test_pulse_sequence_make_bound_pulse_sequence(predefined_frame_1, predefined
expected_str_b_bound = "\n".join(
[
"OPENQASM 3.0;",
"input float[64] a;",
"input float[64] sigma_g;",
"cal {",
" waveform gauss_wf = gaussian(1.0ms, (sigma_g) * 1s, 1, false);",
" waveform drag_gauss_wf = drag_gaussian(3.0ms, 400.0ms, 0.2, 1, false);",
Expand Down

0 comments on commit ec2a559

Please sign in to comment.