From c21768ed4fe6fe626dd9feabcffe3c1fba615054 Mon Sep 17 00:00:00 2001 From: Jean-Christophe Jaskula Date: Tue, 19 Dec 2023 15:50:57 -0500 Subject: [PATCH] create _InputVarSplitter --- src/braket/pulse/ast/qasm_transformer.py | 28 +++++++++++-------- src/braket/pulse/pulse_sequence.py | 3 +- .../braket/pulse/test_pulse_sequence.py | 18 ++++++------ 3 files changed, 28 insertions(+), 21 deletions(-) diff --git a/src/braket/pulse/ast/qasm_transformer.py b/src/braket/pulse/ast/qasm_transformer.py index 6626bf6b5..904bebc67 100644 --- a/src/braket/pulse/ast/qasm_transformer.py +++ b/src/braket/pulse/ast/qasm_transformer.py @@ -58,6 +58,14 @@ def visit_ExpressionStatement(self, expression_statement: ast.ExpressionStatemen else: return expression_statement + +class _InputVarSplitter(QASMTransformer): + """ + QASMTransformer which walks the AST and makes the necessary modifications needed + for IR generation. Currently, it performs the following operations: + * Bubbles up input variables to the top of the CalibrationStatement block. + """ + def visit_Program(self, program: ast.Program) -> ast.Program: """Visit a Program. Args: @@ -68,9 +76,8 @@ def visit_Program(self, program: ast.Program) -> ast.Program: new_statement_list = [] for statement in program.statements: if isinstance(statement, ast.CalibrationStatement): - input_vars, body = self.split_input_vars(statement.body) - new_statement_list.extend(input_vars) - new_statement_list.append(ast.CalibrationStatement(body)) + reordered_cal_block_statements = self.split_input_vars(statement) + new_statement_list.extend(reordered_cal_block_statements) else: new_statement_list.append(statement) @@ -79,21 +86,20 @@ def visit_Program(self, program: ast.Program) -> ast.Program: def split_input_vars( self, - body: list[ast.Statement], - ) -> tuple[list[ast.IODeclaration], list[ast.Statement]]: - """Split input vars out of the calibrationStatement block + node: ast.CalibrationStatement, + ) -> list[ast.Statement]: + """Split input variables out of the calibrationStatement block. Args: - body (list[Statement]): The list of statement in the CalibrationStatement block + node (CalibrationStatement): The CalibrationStatement block. Returns: - tuple[list[IODeclaration], list[Statement]]: A tuple of input vars and the list - of remaining statements. + list[Statement]: The list of statements with input variables outside and in front. """ input_vars = [] new_body = [] - for child in body: + for child in node.body: if isinstance(child, ast.IODeclaration) and child.io_identifier is ast.IOKeyword.input: input_vars.append(child) else: new_body.append(child) - return input_vars, new_body + return input_vars + [ast.CalibrationStatement(new_body)] diff --git a/src/braket/pulse/pulse_sequence.py b/src/braket/pulse/pulse_sequence.py index 059af35b5..7d820a014 100644 --- a/src/braket/pulse/pulse_sequence.py +++ b/src/braket/pulse/pulse_sequence.py @@ -28,7 +28,7 @@ from braket.pulse.ast.approximation_parser import _ApproximationParser from braket.pulse.ast.free_parameters import _FreeParameterTransformer from braket.pulse.ast.qasm_parser import ast_to_qasm -from braket.pulse.ast.qasm_transformer import _IRQASMTransformer +from braket.pulse.ast.qasm_transformer import _InputVarSplitter, _IRQASMTransformer from braket.pulse.frame import Frame from braket.pulse.pulse_sequence_trace import PulseSequenceTrace from braket.pulse.waveforms import Waveform @@ -326,6 +326,7 @@ def to_ir(self) -> str: tree = _IRQASMTransformer(register_identifier).visit(tree) else: tree = program.to_ast(encal=True, include_externs=False) + tree = _InputVarSplitter().visit(tree) return ast_to_qasm(tree) def _format_parameter_ast( diff --git a/test/unit_tests/braket/pulse/test_pulse_sequence.py b/test/unit_tests/braket/pulse/test_pulse_sequence.py index 555cac1f9..13d835965 100644 --- a/test/unit_tests/braket/pulse/test_pulse_sequence.py +++ b/test/unit_tests/braket/pulse/test_pulse_sequence.py @@ -124,13 +124,13 @@ def test_pulse_sequence_make_bound_pulse_sequence(predefined_frame_1, predefined expected_str_unbound = "\n".join( [ "OPENQASM 3.0;", - "input float[64] b;", - "input float[64] a;", - "input float[64] length_g;", - "input float[64] sigma_g;", - "input float[64] length_dg;", - "input float[64] sigma_dg;", - "input float[64] length_c;", + "input float b;", + "input float a;", + "input float length_g;", + "input float sigma_g;", + "input float length_dg;", + "input float sigma_dg;", + "input float length_c;", "cal {", " waveform gauss_wf = gaussian((length_g) * 1s, (sigma_g) * 1s, 1, false);", " waveform drag_gauss_wf = drag_gaussian((length_dg) * 1s," @@ -175,8 +175,8 @@ def test_pulse_sequence_make_bound_pulse_sequence(predefined_frame_1, predefined expected_str_b_bound = "\n".join( [ "OPENQASM 3.0;", - "input float[64] a;", - "input float[64] sigma_g;", + "input float a;", + "input float sigma_g;", "cal {", " waveform gauss_wf = gaussian(1.0ms, (sigma_g) * 1s, 1, false);", " waveform drag_gauss_wf = drag_gaussian(3.0ms, 400.0ms, 0.2, 1, false);",