Skip to content

Commit

Permalink
create _InputVarSplitter
Browse files Browse the repository at this point in the history
  • Loading branch information
jcjaskula-aws committed Dec 19, 2023
1 parent 7d7b24f commit c21768e
Show file tree
Hide file tree
Showing 3 changed files with 28 additions and 21 deletions.
28 changes: 17 additions & 11 deletions src/braket/pulse/ast/qasm_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,14 @@ def visit_ExpressionStatement(self, expression_statement: ast.ExpressionStatemen
else:
return expression_statement


class _InputVarSplitter(QASMTransformer):
"""
QASMTransformer which walks the AST and makes the necessary modifications needed
for IR generation. Currently, it performs the following operations:
* Bubbles up input variables to the top of the CalibrationStatement block.
"""

def visit_Program(self, program: ast.Program) -> ast.Program:
"""Visit a Program.
Args:
Expand All @@ -68,9 +76,8 @@ def visit_Program(self, program: ast.Program) -> ast.Program:
new_statement_list = []
for statement in program.statements:
if isinstance(statement, ast.CalibrationStatement):
input_vars, body = self.split_input_vars(statement.body)
new_statement_list.extend(input_vars)
new_statement_list.append(ast.CalibrationStatement(body))
reordered_cal_block_statements = self.split_input_vars(statement)
new_statement_list.extend(reordered_cal_block_statements)
else:
new_statement_list.append(statement)

Expand All @@ -79,21 +86,20 @@ def visit_Program(self, program: ast.Program) -> ast.Program:

def split_input_vars(
self,
body: list[ast.Statement],
) -> tuple[list[ast.IODeclaration], list[ast.Statement]]:
"""Split input vars out of the calibrationStatement block
node: ast.CalibrationStatement,
) -> list[ast.Statement]:
"""Split input variables out of the calibrationStatement block.
Args:
body (list[Statement]): The list of statement in the CalibrationStatement block
node (CalibrationStatement): The CalibrationStatement block.
Returns:
tuple[list[IODeclaration], list[Statement]]: A tuple of input vars and the list
of remaining statements.
list[Statement]: The list of statements with input variables outside and in front.
"""
input_vars = []
new_body = []
for child in body:
for child in node.body:
if isinstance(child, ast.IODeclaration) and child.io_identifier is ast.IOKeyword.input:
input_vars.append(child)
else:
new_body.append(child)
return input_vars, new_body
return input_vars + [ast.CalibrationStatement(new_body)]
3 changes: 2 additions & 1 deletion src/braket/pulse/pulse_sequence.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
from braket.pulse.ast.approximation_parser import _ApproximationParser
from braket.pulse.ast.free_parameters import _FreeParameterTransformer
from braket.pulse.ast.qasm_parser import ast_to_qasm
from braket.pulse.ast.qasm_transformer import _IRQASMTransformer
from braket.pulse.ast.qasm_transformer import _InputVarSplitter, _IRQASMTransformer
from braket.pulse.frame import Frame
from braket.pulse.pulse_sequence_trace import PulseSequenceTrace
from braket.pulse.waveforms import Waveform
Expand Down Expand Up @@ -326,6 +326,7 @@ def to_ir(self) -> str:
tree = _IRQASMTransformer(register_identifier).visit(tree)
else:
tree = program.to_ast(encal=True, include_externs=False)
tree = _InputVarSplitter().visit(tree)
return ast_to_qasm(tree)

def _format_parameter_ast(
Expand Down
18 changes: 9 additions & 9 deletions test/unit_tests/braket/pulse/test_pulse_sequence.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,13 +124,13 @@ def test_pulse_sequence_make_bound_pulse_sequence(predefined_frame_1, predefined
expected_str_unbound = "\n".join(
[
"OPENQASM 3.0;",
"input float[64] b;",
"input float[64] a;",
"input float[64] length_g;",
"input float[64] sigma_g;",
"input float[64] length_dg;",
"input float[64] sigma_dg;",
"input float[64] length_c;",
"input float b;",
"input float a;",
"input float length_g;",
"input float sigma_g;",
"input float length_dg;",
"input float sigma_dg;",
"input float length_c;",
"cal {",
" waveform gauss_wf = gaussian((length_g) * 1s, (sigma_g) * 1s, 1, false);",
" waveform drag_gauss_wf = drag_gaussian((length_dg) * 1s,"
Expand Down Expand Up @@ -175,8 +175,8 @@ def test_pulse_sequence_make_bound_pulse_sequence(predefined_frame_1, predefined
expected_str_b_bound = "\n".join(
[
"OPENQASM 3.0;",
"input float[64] a;",
"input float[64] sigma_g;",
"input float a;",
"input float sigma_g;",
"cal {",
" waveform gauss_wf = gaussian(1.0ms, (sigma_g) * 1s, 1, false);",
" waveform drag_gauss_wf = drag_gaussian(3.0ms, 400.0ms, 0.2, 1, false);",
Expand Down

0 comments on commit c21768e

Please sign in to comment.