diff --git a/decompiler/pipeline/dataflowanalysis/common_subexpression_elimination.py b/decompiler/pipeline/dataflowanalysis/common_subexpression_elimination.py index 733365090..0184bf74f 100644 --- a/decompiler/pipeline/dataflowanalysis/common_subexpression_elimination.py +++ b/decompiler/pipeline/dataflowanalysis/common_subexpression_elimination.py @@ -2,7 +2,6 @@ from __future__ import annotations -import dataclasses from collections import Counter, defaultdict, deque from dataclasses import dataclass from itertools import chain @@ -19,20 +18,28 @@ from networkx import dfs_postorder_nodes -@dataclass(frozen=True) -class CfgInstructionLocation: +@dataclass(frozen=True, eq=False) +class CfgInstruction: """ dataclass in charge of tracking the location of Instruction objects in the cfg -> The considered instruction, where block is the basic block where it is contained and index the position in the basic block. + + Note: Two instances with the same data will not be equal (because of eq=False). + This way, eq and hash are way more performant, because at the time of writing this, eq and hash are very + expensive on big instructions. + + eq=True would probably be nicer to use, but we don't actually create instances with the same data + multiple times. (Rationale: initially just one instance is created per (block, index) pair. + All further instances with the same (block, index) will have a less complex instruction than before) """ + instruction: Instruction block: BasicBlock - index: int @property - def instruction(self): - return self.block.instructions[self.index] + def index(self): + return next(index for index, instruction in enumerate(self.block.instructions) if id(instruction) == id(self.instruction)) @dataclass() @@ -204,7 +211,7 @@ class DefinitionGenerator: def __init__( self, - expression_usages: DefaultDict[Expression, Counter[CfgInstructionLocation]], + expression_usages: DefaultDict[Expression, Counter[CfgInstruction]], dominator_tree: NetworkXGraph, ): """Generate a new instance based on data parsed from a cfg.""" @@ -214,16 +221,16 @@ def __init__( @classmethod def from_cfg(cls, cfg: ControlFlowGraph) -> DefinitionGenerator: """Initialize a DefinitionGenerator utilizing the data of the given cfg.""" - usages: DefaultDict[Expression, Counter[CfgInstructionLocation]] = defaultdict(Counter) + usages: DefaultDict[Expression, Counter[CfgInstruction]] = defaultdict(Counter) for basic_block in cfg: for index, instruction in enumerate(basic_block.instructions): - instruction_with_position = CfgInstructionLocation(basic_block, index) + instruction_with_position = CfgInstruction(instruction, basic_block) for subexpression in _subexpression_dfs(instruction): usages[subexpression][instruction_with_position] += 1 return cls(usages, cfg.dominator_tree) @property - def usages(self) -> DefaultDict[Expression, Counter[CfgInstructionLocation]]: + def usages(self) -> DefaultDict[Expression, Counter[CfgInstruction]]: """Return a mapping from expressions to a set of instructions using them.""" return self._usages @@ -264,22 +271,16 @@ def _is_invalid_dominator(self, basic_block: BasicBlock, expression: Expression) def _insert_definition(self, instruction: Instruction, block: BasicBlock, index: int): """Insert a new intermediate definition for the given expression at the given location.""" block.instructions.insert(index, instruction) - - # update positions of expression usages - for occurrences in self._usages.values(): - for location in list(occurrences): - if location.block == block and location.index >= index: - occurrences[dataclasses.replace(location, index=location.index + 1)] = occurrences.pop(location) - + cfg_instruction = CfgInstruction(instruction, block) for subexpression in _subexpression_dfs(instruction): - self._usages[subexpression][CfgInstructionLocation(block, index)] += 1 + self._usages[subexpression][cfg_instruction] += 1 @staticmethod - def _find_insertion_index(basic_block: BasicBlock, usages: Iterable[CfgInstructionLocation]) -> int: + def _find_insertion_index(basic_block: BasicBlock, usages: Iterable[CfgInstruction]) -> int: """Find the first index in the given basic block where a definition could be inserted.""" - usage = min((usage for usage in usages if usage.block == basic_block), default=None, key=lambda x: x.index) - if usage: - return usage.index + first_usage_index = min((usage.index for usage in usages if usage.block == basic_block), default=None) + if first_usage_index is not None: + return first_usage_index if not basic_block.instructions: return 0 if isinstance(basic_block.instructions[-1], GenericBranch): @@ -327,7 +328,7 @@ def eliminate_common_subexpressions(self, definition_generator: DefinitionGenera except StopIteration: warning(f"[{self.name}] No dominating basic block could be found for {replacee}") - def _find_elimination_candidates(self, usages: DefaultDict[Expression, Counter[CfgInstructionLocation]]) -> Iterator[Expression]: + def _find_elimination_candidates(self, usages: DefaultDict[Expression, Counter[CfgInstruction]]) -> Iterator[Expression]: """ Iterate all expressions, yielding the expressions which should be eliminated. @@ -341,7 +342,7 @@ def _find_elimination_candidates(self, usages: DefaultDict[Expression, Counter[C usages[subexpression].subtract(expression_usage) yield expression - def _is_cse_candidate(self, expression: Expression, usages: DefaultDict[Expression, Counter[CfgInstructionLocation]]): + def _is_cse_candidate(self, expression: Expression, usages: DefaultDict[Expression, Counter[CfgInstruction]]): """Checks that we can add a common subexpression for the given expression.""" return ( self._is_elimination_candidate(expression, usages[expression]) @@ -359,14 +360,14 @@ def _is_complex_string(self, expression: Expression) -> bool: return isinstance(expression.value, str) and len(expression.value) >= self._min_string_length return False - def _check_inter_instruction(self, expression: Expression, instructions: Counter[CfgInstructionLocation]) -> bool: + def _check_inter_instruction(self, expression: Expression, instructions: Counter[CfgInstruction]) -> bool: """Check if the given expressions should be eliminated based on its global occurrences.""" referencing_instructions_count = sum(1 for _, count in instructions.items() if count > 0) return (expression.complexity >= 2 and referencing_instructions_count >= self._threshold) or ( self._is_complex_string(expression) and referencing_instructions_count >= self._string_threshold ) - def _check_intra_instruction(self, expression: Expression, instructions: Counter[CfgInstructionLocation]) -> bool: + def _check_intra_instruction(self, expression: Expression, instructions: Counter[CfgInstruction]) -> bool: """Check if this expression should be eliminated based on the amount of unique instructions utilizing it.""" referencing_count = instructions.total() return (expression.complexity >= 2 and referencing_count >= self._threshold) or (