diff --git a/src/dynapyt/instrument/CodeInstrumenter.py b/src/dynapyt/instrument/CodeInstrumenter.py index e7847ae..61407fb 100644 --- a/src/dynapyt/instrument/CodeInstrumenter.py +++ b/src/dynapyt/instrument/CodeInstrumenter.py @@ -1029,6 +1029,43 @@ def leave_Assign(self, original_node, updated_node): else: return updated_node.with_changes(value=call) + def leave_AnnAssign(self, original_node, updated_node): + # Keep track of nodes of blacklisted attributes + if self.file_path.endswith("__init__.py") and m.matches( + original_node, m.AnnAssign(target=m.OneOf(*self.blacklist_name_objs)) + ): + self.blacklist_nodes.append(cst.SimpleStatementLine(body=[original_node])) + self.blacklist_nodes.append(cst.Newline(value="\n")) + if "write" not in self.selected_hooks or original_node.value is None: + return updated_node + callee_name = cst.Attribute( + value=cst.Name(value="_rt"), attr=cst.Name(value="_write_") + ) + self.to_import.add("_write_") + iid = self.__create_iid(original_node) + ast_arg = cst.Arg(value=cst.Name("_dynapyt_ast_")) + iid_arg = cst.Arg(value=cst.Integer(value=str(iid))) + if m.matches(updated_node.value, m.Yield()): + val_arg = cst.Arg(value=updated_node.value.value) + else: + val_arg = cst.Arg(value=updated_node.value) + left_arg = cst.Arg( + value=cst.List( + elements=[ + cst.Element( + self.__wrap_in_lambda(original_node.target, updated_node.target) + ) + ] + ) + ) + call = cst.Call(func=callee_name, args=[ast_arg, iid_arg, val_arg, left_arg]) + if m.matches(updated_node.value, m.Yield()): + return updated_node.with_changes( + value=updated_node.value.with_changes(value=call) + ) + else: + return updated_node.with_changes(value=call) + def leave_AugAssign(self, original_node, updated_node): if ("write" not in self.selected_hooks) and ( snake(type(original_node.operator).__name__) not in self.selected_hooks