Skip to content

Commit

Permalink
Fixed issue with inline breaks and added loop iid
Browse files Browse the repository at this point in the history
  • Loading branch information
AryazE committed Oct 4, 2024
1 parent 1b249b7 commit b78f8bb
Show file tree
Hide file tree
Showing 10 changed files with 230 additions and 96 deletions.
53 changes: 45 additions & 8 deletions src/dynapyt/analyses/TraceAll.py
Original file line number Diff line number Diff line change
Expand Up @@ -1123,7 +1123,7 @@ def exit_control_flow(self, dyn_ast: str, iid: int) -> None:
The path to the original code. Can be used to extract the syntax tree.
iid : int
Unique ID of the syntax tree node.
Unique ID of the syntax tree node for the control flow statement.
"""
self.log(iid, "Control-flow exit")
Expand Down Expand Up @@ -1217,6 +1217,19 @@ def exit_for(self, dyn_ast, iid):
"""
self.log(iid, "For exit")

def normal_exit_for(self, dyn_ast: str, iid: int) -> None:
"""Hook for exiting a for loop without a break or continue statement.
Parameters
----------
dyn_ast : str
The path to the original code. Can be used to extract the syntax tree.
iid : int
Unique ID of the syntax tree node.
"""
self.log(iid, "For exit normally")

def enter_while(self, dyn_ast: str, iid: int, cond_value: bool) -> Optional[bool]:
"""Hook for entering the next iteration of a while loop.
Expand Down Expand Up @@ -1257,8 +1270,8 @@ def exit_while(self, dyn_ast, iid):
"""
self.log(iid, "While exit")

def _break(self, dyn_ast: str, iid: int) -> Optional[bool]:
"""Hook for break statement.
def normal_exit_while(self, dyn_ast: str, iid: int) -> None:
"""Hook for exiting a while loop without a break or continue statement.
Parameters
Expand All @@ -1270,6 +1283,25 @@ def _break(self, dyn_ast: str, iid: int) -> Optional[bool]:
Unique ID of the syntax tree node.
"""
self.log(iid, "While exit normally")

def _break(self, dyn_ast: str, iid: int, loop_iid: int) -> Optional[bool]:
"""Hook for break statement.
Parameters
----------
dyn_ast : str
The path to the original code. Can be used to extract the syntax tree.
iid : int
Unique ID of the syntax tree node at break.
loop_iid : int
Unique ID of the syntax tree node for the loop statement.
Returns
-------
bool
Expand All @@ -1278,7 +1310,7 @@ def _break(self, dyn_ast: str, iid: int) -> Optional[bool]:
"""
self.log(iid, "Break")

def _continue(self, dyn_ast: str, iid: int) -> Optional[bool]:
def _continue(self, dyn_ast: str, iid: int, loop_iid: int) -> Optional[bool]:
"""Hook for continue statement.
Expand All @@ -1288,7 +1320,10 @@ def _continue(self, dyn_ast: str, iid: int) -> Optional[bool]:
The path to the original code. Can be used to extract the syntax tree.
iid : int
Unique ID of the syntax tree node.
Unique ID of the syntax tree node of continue.
loop_iid : int
Unique ID of the syntax tree node for the loop statement.
Returns
Expand Down Expand Up @@ -1420,7 +1455,9 @@ def enter_decorator(self, dyn_ast: str, iid: int, decorator_name, args, kwargs):
"""
self.log(iid, "Entered decorator", decorator_name)

def exit_decorator(self, dyn_ast: str, iid: int, decorator_name, result, args, kwargs) -> Any:
def exit_decorator(
self, dyn_ast: str, iid: int, decorator_name, result, args, kwargs
) -> Any:
"""Hook for exiting a decorator.
Expand All @@ -1444,11 +1481,11 @@ def exit_decorator(self, dyn_ast: str, iid: int, decorator_name, result, args, k
kwargs : Dict
The keyword arguments passed to the decorator.
Returns
-------
Any
If provided, overwrites the result returned by the function
If provided, overwrites the result returned by the function
"""
self.log(iid, "Exited decorator", decorator_name)
Expand Down
145 changes: 83 additions & 62 deletions src/dynapyt/instrument/CodeInstrumenter.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ def __init__(self, src, file_path, iids: IIDs, selected_hooks):
self.file_path = str(Path(file_path).resolve())
self.iids = iids
self.name_stack = []
self.current_loop = []
self.current_try = []
self.current_class = []
self.current_function = []
Expand Down Expand Up @@ -1674,53 +1675,6 @@ def leave_ExceptHandler(self, original_node, updated_node):
return updated_node.with_changes(body=new_body)

# Control flow
def leave_IndentedBlock(self, original_node, updated_node):
if ("_break" not in self.selected_hooks) and (
"_continue" not in self.selected_hooks
):
return updated_node
new_body = []
for i in updated_node.body:
if ("_break" in self.selected_hooks) and (
m.matches(i, m.SimpleStatementLine(body=[m.Break()]))
):
callee_name = cst.Attribute(
value=cst.Name(value="_rt"), attr=cst.Name(value="_break_")
)
self.to_import.add("_break_")
iid = self.__create_iid(original_node)
ast_arg = cst.Arg(value=cst.Name("_dynapyt_ast_"))
iid_arg = cst.Arg(value=cst.Integer(value=str(iid)))
call = cst.Call(func=callee_name, args=[ast_arg, iid_arg])
condition = cst.If(
test=call,
body=cst.IndentedBlock(
body=[cst.SimpleStatementLine(body=[cst.Break()])]
),
)
new_body.append(condition)
elif ("_continue" in self.selected_hooks) and (
m.matches(i, m.SimpleStatementLine(body=[m.Continue()]))
):
callee_name = cst.Attribute(
value=cst.Name(value="_rt"), attr=cst.Name(value="_continue_")
)
self.to_import.add("_continue_")
iid = self.__create_iid(original_node)
ast_arg = cst.Arg(value=cst.Name("_dynapyt_ast_"))
iid_arg = cst.Arg(value=cst.Integer(value=str(iid)))
call = cst.Call(func=callee_name, args=[ast_arg, iid_arg])
condition = cst.If(
test=call,
body=cst.IndentedBlock(
body=[cst.SimpleStatementLine(body=[cst.Continue()])]
),
)
new_body.append(condition)
else:
new_body.append(i)

return updated_node.with_changes(body=new_body)

def leave_If(self, original_node, updated_node):
if ("enter_if" not in self.selected_hooks) and (
Expand Down Expand Up @@ -1809,12 +1763,79 @@ def leave_IfExp(self, original_node, updated_node):
)
return call

def leave_SimpleStatementLine(self, original_node, updated_node):
if "_break" in self.selected_hooks and m.matches(
updated_node, m.SimpleStatementLine(body=[m.Break()])
):
return self.instrument_Break(original_node.body[0], updated_node.body[0])
elif "_continue" in self.selected_hooks and m.matches(
updated_node, m.SimpleStatementLine(body=[m.Continue()])
):
return self.instrument_Continue(original_node.body[0], updated_node.body[0])
return updated_node

def instrument_Break(self, original_node, updated_node):
if "_break" not in self.selected_hooks:
return updated_node
self.to_import.add("_break_")
callee_name = cst.Attribute(
value=cst.Name(value="_rt"), attr=cst.Name(value="_break_")
)
iid = self.__create_iid(original_node)
ast_arg = cst.Arg(value=cst.Name("_dynapyt_ast_"))
iid_arg = cst.Arg(value=cst.Integer(value=str(iid)))
ctrl_flow_iid_arg = cst.Arg(
value=cst.Integer(value=str(self.current_loop[-1][0]))
)
ctrl_flow_type_arg = cst.Arg(
value=cst.Integer(value=str(self.current_loop[-1][1]))
)
call = cst.Call(
func=callee_name,
args=[ast_arg, iid_arg, ctrl_flow_iid_arg, ctrl_flow_type_arg],
)
return cst.If(
test=call,
body=cst.IndentedBlock(body=[cst.SimpleStatementLine(body=[cst.Break()])]),
)

def instrument_Continue(self, original_node, updated_node):
if "_continue" not in self.selected_hooks:
return updated_node
self.to_import.add("_continue_")
callee_name = cst.Attribute(
value=cst.Name(value="_rt"), attr=cst.Name(value="_continue_")
)
iid = self.__create_iid(original_node)
ast_arg = cst.Arg(value=cst.Name("_dynapyt_ast_"))
iid_arg = cst.Arg(value=cst.Integer(value=str(iid)))
ctrl_flow_iid_arg = cst.Arg(
value=cst.Integer(value=str(self.current_loop[-1][0]))
)
ctrl_flow_type_arg = cst.Arg(
value=cst.Integer(value=str(self.current_loop[-1][1]))
)
call = cst.Call(
func=callee_name,
args=[ast_arg, iid_arg, ctrl_flow_iid_arg, ctrl_flow_type_arg],
)
return cst.If(
test=call,
body=cst.IndentedBlock(
body=[cst.SimpleStatementLine(body=[cst.Continue()])]
),
)

def visit_While(self, node):
iid = self.__create_iid(node)
self.current_loop.append((iid, 0)) # 0 for while loop, 1 for for loop

def leave_While(self, original_node, updated_node):
iid = self.current_loop.pop()[0]
if ("enter_while" not in self.selected_hooks) and (
"exit_while" not in self.selected_hooks
"normal_exit_while" not in self.selected_hooks
):
return updated_node
iid = self.__create_iid(original_node)
ast_arg = cst.Arg(value=cst.Name("_dynapyt_ast_"))
iid_arg = cst.Arg(value=cst.Integer(value=str(iid)))
if "enter_while" in self.selected_hooks:
Expand All @@ -1826,7 +1847,7 @@ def leave_While(self, original_node, updated_node):
enter_call = cst.Call(func=enter_name, args=[ast_arg, iid_arg, enter_arg])
else:
enter_call = updated_node.test
if "exit_while" in self.selected_hooks:
if "normal_exit_while" in self.selected_hooks:
end_name = cst.Attribute(
value=cst.Name(value="_rt"), attr=cst.Name(value="_exit_while_")
)
Expand All @@ -1850,13 +1871,17 @@ def leave_While(self, original_node, updated_node):
orelse=else_part,
)

def visit_For(self, node):
iid = self.__create_iid(node)
self.current_loop.append((iid, 1)) # 0 for while loop, 1 for for loop

def leave_For(self, original_node, updated_node):
iid = self.current_loop.pop()[0]
if (
("enter_for" not in self.selected_hooks)
and ("exit_for" not in self.selected_hooks)
and ("normal_exit_for" not in self.selected_hooks)
) or original_node.asynchronous is not None: # TODO: Handle async for loops
return updated_node
iid = self.__create_iid(original_node)
ast_arg = cst.Arg(value=cst.Name("_dynapyt_ast_"))
iid_arg = cst.Arg(value=cst.Integer(value=str(iid)))
if "enter_for" in self.selected_hooks:
Expand All @@ -1869,7 +1894,7 @@ def leave_For(self, original_node, updated_node):
func=generator_name, args=[ast_arg, iid_arg, iter_arg]
)
else_part = updated_node.orelse
elif "exit_for" in self.selected_hooks:
elif "normal_exit_for" in self.selected_hooks:
end_name = cst.Attribute(
value=cst.Name(value="_rt"), attr=cst.Name(value="_exit_for_")
)
Expand All @@ -1893,7 +1918,7 @@ def leave_For(self, original_node, updated_node):
def leave_CompFor(self, original_node, updated_node):
if (
"enter_for" not in self.selected_hooks
and "exit_for" not in self.selected_hooks
and "normal_exit_for" not in self.selected_hooks
) or original_node.asynchronous is not None: # TODO: Handle async for loops
return updated_node
generator_name = cst.Attribute(
Expand Down Expand Up @@ -1931,30 +1956,26 @@ def leave_WithItem(self, original_node, updated_node):
call = cst.Call(func=callee_name, args=[ast_arg, iid_arg, ctx_manager_arg])
return updated_node.with_changes(item=call)


def leave_Decorator(self, original_node, updated_node):
print("decorator node: ", original_node)
if ("enter_decorator" not in self.selected_hooks) and (
"exit_decorator" not in self.selected_hooks
):
return updated_node

iid = self.__create_iid(original_node)
ast_arg = cst.Arg(value=cst.Name("_dynapyt_ast_"))
iid_arg = cst.Arg(value=cst.Integer(value=str(iid)))
dynapyt_decorator_attr = cst.Attribute(
value=cst.Name("_rt"), attr=cst.Name("dynapyt_decorator"),
value=cst.Name("_rt"),
attr=cst.Name("dynapyt_decorator"),
)
dynapyt_decorator_call = cst.Call(
func=dynapyt_decorator_attr,
args=[ast_arg, iid_arg],
)
dynapyt_decorator = cst.Decorator(
decorator=dynapyt_decorator_call
)
dynapyt_decorator = cst.Decorator(decorator=dynapyt_decorator_call)

self.to_import.add("dynapyt_decorator")

return cst.FlattenSentinel([dynapyt_decorator, updated_node])


32 changes: 31 additions & 1 deletion src/dynapyt/instrument/instrument.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import argparse
from multiprocessing import Pool
import traceback
import libcst as cst
from libcst._exceptions import ParserSyntaxError
from .CodeInstrumenter import CodeInstrumenter
Expand All @@ -23,9 +24,37 @@ def gather_files(files_arg):
return files


def canonical_ifs(node, child_dict):
new_body = (
cst.Break()
if cst.matchers.matches(node.body.body[0], cst.matchers.Break())
else cst.Continue()
)
return cst.If(
test=node.test,
body=cst.IndentedBlock(body=[cst.SimpleStatementLine(body=[new_body])]),
orelse=node.orelse,
leading_lines=node.leading_lines,
whitespace_before_test=node.whitespace_before_test,
whitespace_after_test=node.whitespace_after_test,
)


def instrument_code(src, file_path, iids, selected_hooks):
try:
ast = cst.parse_module(src)
print("Before:")
print(src)
ast = cst.matchers.replace(
cst.parse_module(src),
cst.matchers.If(
body=cst.matchers.SimpleStatementSuite(
body=[(cst.matchers.Break() | cst.matchers.Continue())]
)
),
canonical_ifs,
)
print("After:")
print(ast.code)
ast_wrapper = cst.metadata.MetadataWrapper(ast)

instrumented_code = CodeInstrumenter(src, file_path, iids, selected_hooks)
Expand All @@ -38,6 +67,7 @@ def instrument_code(src, file_path, iids, selected_hooks):
except Exception as e:
print(f"Error in {file_path} -- skipping it")
print(e)
print(traceback.format_exc())
return None


Expand Down
Loading

0 comments on commit b78f8bb

Please sign in to comment.