diff --git a/decompiler/pipeline/controlflowanalysis/restructuring_commons/condition_aware_refinement_commons/base_class_car.py b/decompiler/pipeline/controlflowanalysis/restructuring_commons/condition_aware_refinement_commons/base_class_car.py index b7b9fcae1..477377360 100644 --- a/decompiler/pipeline/controlflowanalysis/restructuring_commons/condition_aware_refinement_commons/base_class_car.py +++ b/decompiler/pipeline/controlflowanalysis/restructuring_commons/condition_aware_refinement_commons/base_class_car.py @@ -3,8 +3,7 @@ from decompiler.pipeline.controlflowanalysis.restructuring_options import LoopBreakOptions, RestructuringOptions from decompiler.structures.ast.ast_nodes import AbstractSyntaxTreeNode, CaseNode, FalseNode, SwitchNode, TrueNode -from decompiler.structures.ast.condition_symbol import ConditionHandler -from decompiler.structures.ast.switch_node_handler import ExpressionUsages +from decompiler.structures.ast.condition_symbol import ConditionHandler, ExpressionUsages from decompiler.structures.ast.syntaxforest import AbstractSyntaxForest from decompiler.structures.logic.logic_condition import LogicCondition, PseudoLogicCondition from decompiler.structures.pseudo import Condition, Constant, Expression, OperationType @@ -110,14 +109,14 @@ def _get_expression_compared_with_constant(self, reaching_condition: LogicCondit Check whether the given reaching condition, which is a literal, i.e., a z3-symbol or its negation is of the form `expr == const`. If this is the case, then we return the expression `expr`. """ - return self.asforest.switch_node_handler.get_potential_switch_expression(reaching_condition) + return self.asforest.condition_handler.get_potential_switch_expression_of(reaching_condition) def _get_constant_compared_with_expression(self, reaching_condition: LogicCondition) -> Optional[Constant]: """ Check whether the given reaching condition, which is a literal, i.e., a z3-symbol or its negation is of the form `expr == const`. If this is the case, then we return the constant `const`. """ - return self.asforest.switch_node_handler.get_potential_switch_constant(reaching_condition) + return self.asforest.condition_handler.get_potential_switch_constant_of(reaching_condition) def _convert_to_z3_condition(self, condition: LogicCondition) -> PseudoLogicCondition: return PseudoLogicCondition.initialize_from_formula(condition, self.condition_handler.get_z3_condition_map()) diff --git a/decompiler/pipeline/controlflowanalysis/restructuring_commons/condition_aware_refinement_commons/initial_switch_node_constructer.py b/decompiler/pipeline/controlflowanalysis/restructuring_commons/condition_aware_refinement_commons/initial_switch_node_constructer.py index 6a45b8a36..fbda16638 100644 --- a/decompiler/pipeline/controlflowanalysis/restructuring_commons/condition_aware_refinement_commons/initial_switch_node_constructer.py +++ b/decompiler/pipeline/controlflowanalysis/restructuring_commons/condition_aware_refinement_commons/initial_switch_node_constructer.py @@ -11,8 +11,8 @@ ) from decompiler.pipeline.controlflowanalysis.restructuring_options import RestructuringOptions from decompiler.structures.ast.ast_nodes import AbstractSyntaxTreeNode, CaseNode, CodeNode, ConditionNode, SeqNode, SwitchNode, TrueNode +from decompiler.structures.ast.condition_symbol import ExpressionUsages from decompiler.structures.ast.reachability_graph import CaseDependencyGraph, LinearOrderDependency, SiblingReachability -from decompiler.structures.ast.switch_node_handler import ExpressionUsages from decompiler.structures.ast.syntaxforest import AbstractSyntaxForest from decompiler.structures.logic.logic_condition import LogicCondition from decompiler.structures.pseudo import Constant, Expression @@ -90,8 +90,8 @@ def _clean_up_reachability(self): """ for candidate_1, candidate_2 in permutations(self.switch_candidate.cases, 2): if self.sibling_reachability.reaches(candidate_1.node, candidate_2.node) and not ( - set(self.asforest.switch_node_handler.get_constants_for(candidate_1.condition)) - & set(self.asforest.switch_node_handler.get_constants_for(candidate_2.condition)) + set(self.asforest.condition_handler.get_constants_of(candidate_1.condition)) + & set(self.asforest.condition_handler.get_constants_of(candidate_2.condition)) ): self.asforest._code_node_reachability_graph.remove_reachability_between([candidate_1.node, candidate_2.node]) self.sibling_reachability.remove_reachability_between([candidate_1.node, candidate_2.node]) @@ -521,7 +521,7 @@ def _add_constants_to_cases_for( case_node.constant = Constant("add_to_previous_case") else: considered_conditions.update( - (c, l) for l, c in self.asforest.switch_node_handler.get_literal_and_constant_for(case_node.reaching_condition) + (c, l) for l, c in self.asforest.condition_handler.get_literal_and_constant_of(case_node.reaching_condition) ) def _update_reaching_condition_of(self, case_node: CaseNode, considered_conditions: Dict[Constant, LogicCondition]) -> None: @@ -537,8 +537,7 @@ def _update_reaching_condition_of(self, case_node: CaseNode, considered_conditio :param considered_conditions: The conditions (literals) that are already fulfilled when we reach the given case node. """ constant_of_case_node_literal = { - const: literal - for literal, const in self.asforest.switch_node_handler.get_literal_and_constant_for(case_node.reaching_condition) + const: literal for literal, const in self.asforest.condition_handler.get_literal_and_constant_of(case_node.reaching_condition) } exception_condition: LogicCondition = self.condition_handler.get_true_value() @@ -578,7 +577,7 @@ def prepend_empty_cases_to_case_with_or_condition(self, case: CaseNode) -> List[ the list of new case nodes. """ condition_for_constant: Dict[Constant, LogicCondition] = dict() - for l, c in self.asforest.switch_node_handler.get_literal_and_constant_for(case.reaching_condition): + for l, c in self.asforest.condition_handler.get_literal_and_constant_of(case.reaching_condition): if c is None: raise ValueError( f"The case node should have a reaching-condition that is a disjunction of literals, but it has the clause {l}." diff --git a/decompiler/pipeline/controlflowanalysis/restructuring_commons/condition_aware_refinement_commons/missing_case_finder_sequence.py b/decompiler/pipeline/controlflowanalysis/restructuring_commons/condition_aware_refinement_commons/missing_case_finder_sequence.py index 598fe8dc7..d4b35ceda 100644 --- a/decompiler/pipeline/controlflowanalysis/restructuring_commons/condition_aware_refinement_commons/missing_case_finder_sequence.py +++ b/decompiler/pipeline/controlflowanalysis/restructuring_commons/condition_aware_refinement_commons/missing_case_finder_sequence.py @@ -12,8 +12,8 @@ ) from decompiler.pipeline.controlflowanalysis.restructuring_options import RestructuringOptions from decompiler.structures.ast.ast_nodes import AbstractSyntaxTreeNode, ConditionNode, FalseNode, SeqNode, SwitchNode, TrueNode +from decompiler.structures.ast.condition_symbol import ExpressionUsages from decompiler.structures.ast.reachability_graph import SiblingReachabilityGraph -from decompiler.structures.ast.switch_node_handler import ExpressionUsages from decompiler.structures.ast.syntaxforest import AbstractSyntaxForest from decompiler.structures.logic.logic_condition import LogicCondition, PseudoLogicCondition from decompiler.structures.pseudo import Condition, Constant, OperationType diff --git a/decompiler/structures/ast/condition_symbol.py b/decompiler/structures/ast/condition_symbol.py index 8c9c2c374..c6e45b6e4 100644 --- a/decompiler/structures/ast/condition_symbol.py +++ b/decompiler/structures/ast/condition_symbol.py @@ -1,29 +1,180 @@ from __future__ import annotations from dataclasses import dataclass -from typing import Dict, Iterable, Optional, Set +from typing import Dict, Iterable, List, Optional, Set, Tuple, Union from decompiler.structures.logic.logic_condition import LogicCondition, PseudoLogicCondition -from decompiler.structures.pseudo import Condition +from decompiler.structures.logic.z3_implementations import Z3Implementation +from decompiler.structures.pseudo import Condition, Constant, Expression, OperationType, Variable, Z3Converter +from z3 import BoolRef + + +def _is_equivalent(cond1: BoolRef, cond2: BoolRef): + """Check whether the given conditions are equivalent.""" + z3_implementation = Z3Implementation(True) + if z3_implementation.is_equal(cond1, cond2): + return True + return z3_implementation.does_imply(cond1, cond2) and z3_implementation.does_imply(cond2, cond1) + + +def _get_ssa_expression(expression_usage: ExpressionUsages) -> Expression: + """Construct SSA-expression of the given expression.""" + if isinstance(expression_usage.expression, Variable): + return expression_usage.expression.ssa_name if expression_usage.expression.ssa_name else expression_usage.expression + ssa_expression = expression_usage.expression.copy() + for variable in [var for var in ssa_expression.requirements if var.ssa_name]: + ssa_expression.substitute(variable, variable.ssa_name) + return ssa_expression + + +@dataclass(frozen=True) +class ExpressionUsages: + """Dataclass maintaining for a condition the used SSA-variables.""" + + expression: Expression + ssa_usages: Tuple[Optional[Variable], ...] + + @classmethod + def from_expression(cls, expression: Expression) -> ExpressionUsages: + return ExpressionUsages(expression, tuple(var.ssa_name for var in expression.requirements)) @dataclass(frozen=True) +class ZeroCaseCondition: + """Possible switch expression together with its zero-case condition.""" + + expression: Expression + ssa_usages: Set[Optional[Variable]] + z3_condition: BoolRef + + def are_equivalent(self, other: Union[ZeroCaseCondition, PotentialZeroCaseCondition]) -> bool: + return self.ssa_usages == other.ssa_usages and _is_equivalent(self.z3_condition, other.z3_condition) + + +@dataclass(frozen=True) +class PotentialZeroCaseCondition: + """Possible zero-case condition with its z3-condition and ssa-usages.""" + + expression: Condition + ssa_usages: Set[Optional[Variable]] + z3_condition: BoolRef + + def are_equivalent(self, other: Union[ZeroCaseCondition, PotentialZeroCaseCondition]) -> bool: + return self.ssa_usages == other.ssa_usages and _is_equivalent(self.z3_condition, other.z3_condition) + + +@dataclass(frozen=True) +class CaseNodeProperties: + """ + Class for mapping possible expression and constant of a symbol for a switch-case. + + -> symbol: symbol that belongs to the expression and constant + -> constant: the compared constant + -> negation: whether the symbol or its negation belongs to a switch-case + -> The condition that the new case node should get. + """ + + symbol: LogicCondition + expression: ExpressionUsages + constant: Constant + negation: bool + + def __eq__(self, other) -> bool: + """ + We want to be able to compare CaseNodeCandidates with AST-nodes, more precisely, + we want that an CaseNodeCandidate 'case_node' is equal to the AST node 'case_node.node'. + """ + if isinstance(other, CaseNodeProperties): + return self.symbol == other.symbol + return False + + def copy(self) -> CaseNodeProperties: + return CaseNodeProperties(self.symbol, self.expression, self.constant, self.negation) + + +@dataclass class ConditionSymbol: """Dataclass that maintains for each symbol the according condition and its transition in a z3-condition.""" - condition: Condition - symbol: LogicCondition + _condition: Condition + _symbol: LogicCondition z3_condition: PseudoLogicCondition + case_node_property: Optional[CaseNodeProperties] = None + + @property + def condition(self) -> Condition: + return self._condition + + @property + def symbol(self) -> LogicCondition: + return self._symbol + + def __hash__(self) -> int: + return hash((self.condition, self.symbol)) def __eq__(self, other): """Check whether two condition-symbols are equal.""" - return ( - isinstance(other, ConditionSymbol) - and self.condition == other.condition - and self.symbol == other.symbol - and self.z3_condition.is_equivalent_to(other.z3_condition) + return isinstance(other, ConditionSymbol) and self.condition == other.condition and self.symbol == other.symbol + + +@dataclass +class SwitchHandler: + z3_converter: Z3Converter + zero_case_of_switch_expression: Dict[ExpressionUsages, ZeroCaseCondition] + potential_zero_cases: Dict[ConditionSymbol, PotentialZeroCaseCondition] + + @classmethod + def initialize(cls, condition_map: Optional[Dict[LogicCondition, ConditionSymbol]]) -> SwitchHandler: + handler = cls(Z3Converter(), {}, {}) + if condition_map is None: + return handler + for cond_symbol in condition_map.values(): + if cond_symbol.case_node_property is not None: + handler.have_new_zero_case_for(cond_symbol.case_node_property.expression) + elif cond_symbol.condition.operation in {OperationType.equal, OperationType.not_equal} and not any( + isinstance(operand, Constant) for operand in cond_symbol.condition.operands + ): + handler.have_new_potential_zero_case_for(cond_symbol) + return handler + + def have_new_zero_case_for(self, expression_usage: ExpressionUsages) -> bool: + """Returns whether we added a new zero-case condition for the given expression.""" + return expression_usage not in self.zero_case_of_switch_expression and self._successfully_compute_zero_case_condition_for( + expression_usage ) + def have_new_potential_zero_case_for(self, condition_symbol: ConditionSymbol) -> bool: + """Returns whether we added a new zero-case condition for the given expression.""" + return self._successfully_compute_potential_zero_case_condition_for(condition_symbol) + + def _successfully_compute_zero_case_condition_for(self, expression_usage: ExpressionUsages) -> bool: + """Return whether the construction of the zero-case condition was successful and add it to the dictionary.""" + ssa_expression = _get_ssa_expression(expression_usage) + try: + z3_condition = self.z3_converter.convert(Condition(OperationType.equal, [ssa_expression, Constant(0, ssa_expression.type)])) + self.zero_case_of_switch_expression[expression_usage] = ZeroCaseCondition( + expression_usage.expression, set(expression_usage.ssa_usages), z3_condition + ) + return True + except ValueError: + return False + + def _successfully_compute_potential_zero_case_condition_for(self, condition_symbol: ConditionSymbol) -> bool: + """Construct the potential zero-case condition.""" + condition = condition_symbol.condition + expression_usage = ExpressionUsages.from_expression(condition) + ssa_condition = _get_ssa_expression(expression_usage) + assert isinstance(ssa_condition, Condition), f"{ssa_condition} must be of type Condition!" + ssa_condition = ssa_condition.negate() if ssa_condition.operation == OperationType.not_equal else ssa_condition + try: + z3_condition = self.z3_converter.convert(ssa_condition) + self.potential_zero_cases[condition_symbol] = PotentialZeroCaseCondition( + condition, set(expression_usage.ssa_usages), z3_condition + ) + return True + except ValueError: + return False + class ConditionHandler: """Class that handles all the conditions of a transition graph and syntax-forest.""" @@ -33,6 +184,7 @@ def __init__(self, condition_map: Optional[Dict[LogicCondition, ConditionSymbol] self._condition_map: Dict[LogicCondition, ConditionSymbol] = dict() if condition_map is None else condition_map self._symbol_counter = 0 self._logic_context = next(iter(self._condition_map)).context if self._condition_map else LogicCondition.generate_new_context() + self._switch_handler: SwitchHandler = SwitchHandler.initialize(condition_map) def __eq__(self, other) -> bool: """Checks whether two condition handlers are equal.""" @@ -58,7 +210,12 @@ def logic_context(self): def copy(self) -> ConditionHandler: """Return a copy of the condition handler""" condition_map = { - symbol: ConditionSymbol(condition_symbol.condition.copy(), condition_symbol.symbol, condition_symbol.z3_condition) + symbol: ConditionSymbol( + condition_symbol.condition.copy(), + condition_symbol.symbol, + condition_symbol.z3_condition, + condition_symbol.case_node_property.copy(), + ) for symbol, condition_symbol in self._condition_map.items() } return ConditionHandler(condition_map) @@ -71,6 +228,10 @@ def get_z3_condition_of(self, symbol: LogicCondition) -> PseudoLogicCondition: """Return the z3-condition to the given symbol""" return self._condition_map[symbol].z3_condition + def get_case_node_property_of(self, symbol: LogicCondition) -> CaseNodeProperties: + """Return the z3-condition to the given symbol""" + return self._condition_map[symbol].case_node_property + def get_all_symbols(self) -> Set[LogicCondition]: """Return all existing symbols""" return set(self._condition_map.keys()) @@ -87,12 +248,33 @@ def get_reverse_z3_condition_map(self) -> Dict[PseudoLogicCondition, LogicCondit """Return the reverse z3-condition map that maps z3-conditions to symbols.""" return dict((condition_symbol.z3_condition, symbol) for symbol, condition_symbol in self._condition_map.items()) - def update_z3_condition_of(self, symbol: LogicCondition, condition: Condition): - """Change the z3-condition of the given symbol according to the given condition.""" - assert symbol.is_symbol, "Input must be a symbol!" - z3_condition = PseudoLogicCondition.initialize_from_condition(condition, self._logic_context) - pseudo_condition = self.get_condition_of(symbol) - self._condition_map[symbol] = ConditionSymbol(pseudo_condition, symbol, z3_condition) + def get_true_value(self) -> LogicCondition: + """Return a true value.""" + return LogicCondition.initialize_true(self._logic_context) + + def get_false_value(self) -> LogicCondition: + """Return a false value.""" + return LogicCondition.initialize_false(self._logic_context) + + def get_literal_and_constant_of(self, condition: LogicCondition) -> Iterable[LogicCondition, Constant]: + """Get the constant for each literal of the given condition.""" + for literal in condition.get_literals(): + yield literal, self.get_potential_switch_constant_of(literal) + + def get_constants_of(self, condition: LogicCondition) -> Iterable[Constant]: + """Get the constant for each literal of the given condition.""" + for literal in condition.get_literals(): + yield self.get_potential_switch_constant_of(literal) + + def get_potential_switch_constant_of(self, condition: LogicCondition) -> Optional[Constant]: + """Check whether the given condition is a potential switch case, and if return the corresponding constant.""" + if (case_node_property := self._get_case_node_property_of(condition)) is not None: + return case_node_property.constant + + def get_potential_switch_expression_of(self, condition: LogicCondition) -> Optional[ExpressionUsages]: + """Check whether the given condition is a potential switch case, and if return the corresponding expression.""" + if (case_node_property := self._get_case_node_property_of(condition)) is not None: + return case_node_property.expression def add_condition(self, condition: Condition) -> LogicCondition: """Adds a new condition to the condition map and returns the corresponding condition_symbol""" @@ -102,6 +284,7 @@ def add_condition(self, condition: Condition) -> LogicCondition: symbol = self._get_next_symbol() condition_symbol = ConditionSymbol(condition, symbol, z3_condition) + self._set_switch_case_property_for_condition(condition_symbol) self._condition_map[symbol] = condition_symbol return symbol @@ -118,10 +301,87 @@ def _get_next_symbol(self) -> LogicCondition: self._symbol_counter += 1 return LogicCondition.initialize_symbol(f"x{self._symbol_counter}", self._logic_context) - def get_true_value(self) -> LogicCondition: - """Return a true value.""" - return LogicCondition.initialize_true(self._logic_context) + def _set_switch_case_property_for_condition(self, condition_symbol: ConditionSymbol) -> None: + """Compute the switch-case property.""" + condition: Condition = condition_symbol.condition + if condition.operation not in {OperationType.equal, OperationType.not_equal}: + return None + constants: List[Constant] = [operand for operand in condition.operands if isinstance(operand, Constant)] + expressions: List[Expression] = [operand for operand in condition.operands if not isinstance(operand, Constant)] - def get_false_value(self) -> LogicCondition: - """Return a false value.""" - return LogicCondition.initialize_false(self._logic_context) + if len(constants) == 1 and len(expressions) == 1: + expression_usage = ExpressionUsages.from_expression(expressions[0]) + condition_symbol.case_node_property = CaseNodeProperties( + condition_symbol.symbol, expression_usage, constants[0], condition.operation == OperationType.not_equal + ) + self._update_potential_zero_cases_for(expression_usage) + elif len(constants) == 0: + if self._switch_handler.have_new_potential_zero_case_for(condition_symbol): + self._add_zero_case_condition_for(condition_symbol) + + def _update_potential_zero_cases_for(self, expression_usage: ExpressionUsages) -> None: + """ + Update the Zero-cases for the given expression. + + If the switch handler adds a new zero-case condition, we check whether one of the potential zero-cases matches this zero-case. + """ + if self._switch_handler.have_new_zero_case_for(expression_usage): + self._add_missing_zero_cases_for(self._switch_handler.zero_case_of_switch_expression[expression_usage]) + + def _add_missing_zero_cases_for(self, zero_case: ZeroCaseCondition) -> None: + """We check for each potential zero-case whether it matches the given zero-case.""" + found_zero_cases = set() + for condition_symbol, potential_zero_case in self._switch_handler.potential_zero_cases.items(): + if zero_case.are_equivalent(potential_zero_case): + self._update_case_property_for( + condition_symbol, potential_zero_case, ExpressionUsages.from_expression(zero_case.expression) + ) + found_zero_cases.add(condition_symbol) + for zero_case_condition_symbol in found_zero_cases: + del self._switch_handler.potential_zero_cases[zero_case_condition_symbol] + + def _add_zero_case_condition_for(self, potential_zero_case_condition_symbol: ConditionSymbol) -> None: + """ + Check whether the condition belongs to a zero-case of a switch expression. + + If this is the case, we return the switch expression and the zero-constant + """ + potential_zero_case: PotentialZeroCaseCondition = self._switch_handler.potential_zero_cases[potential_zero_case_condition_symbol] + for expression_usage, zero_case in self._switch_handler.zero_case_of_switch_expression.items(): + if potential_zero_case.are_equivalent(zero_case): + self._update_case_property_for(potential_zero_case_condition_symbol, potential_zero_case, expression_usage) + del self._switch_handler.potential_zero_cases[potential_zero_case_condition_symbol] + return None + return None + + def _update_case_property_for( + self, condition_symbol: ConditionSymbol, zero_case: PotentialZeroCaseCondition, expression_usage: ExpressionUsages + ): + """ + Update the case_node_property of the given condition-symbol which belongs to the potential zero-case with the given expression. + """ + condition_symbol.z3_condition = PseudoLogicCondition.initialize_from_condition( + Condition( + zero_case.expression.operation, + [expression_usage.expression, (Constant(0, expression_usage.expression.type))], + ), + self._logic_context, + ) + condition_symbol.case_node_property = CaseNodeProperties( + condition_symbol.symbol, + expression_usage, + Constant(0, expression_usage.expression.type), + zero_case.expression.operation == OperationType.not_equal, + ) + + def _get_case_node_property_of(self, condition: LogicCondition) -> Optional[CaseNodeProperties]: + """Return the case-property of a given literal.""" + negation = False + if condition.is_negation: + condition = condition.operands[0] + negation = True + if condition.is_symbol: + case_node_property = self.get_case_node_property_of(condition) + if case_node_property is not None and case_node_property.negation == negation: + return case_node_property + return None diff --git a/decompiler/structures/ast/switch_node_handler.py b/decompiler/structures/ast/switch_node_handler.py deleted file mode 100644 index 96a2d7ecf..000000000 --- a/decompiler/structures/ast/switch_node_handler.py +++ /dev/null @@ -1,205 +0,0 @@ -from __future__ import annotations - -from dataclasses import dataclass -from typing import Dict, Iterable, Iterator, List, Optional, Set, Tuple - -from decompiler.structures.ast.condition_symbol import ConditionHandler -from decompiler.structures.logic.logic_condition import LogicCondition -from decompiler.structures.logic.z3_implementations import Z3Implementation -from decompiler.structures.pseudo import Condition, Constant, Expression, OperationType, Variable, Z3Converter -from z3 import BoolRef - - -@dataclass(frozen=True) -class ExpressionUsages: - """Dataclass maintaining for a condition the used SSA-variables.""" - - expression: Expression - ssa_usages: Tuple[Optional[Variable]] - - -@dataclass -class ZeroCaseCondition: - """Possible switch expression together with its zero-case condition.""" - - expression: Expression - ssa_usages: Set[Optional[Variable]] - z3_condition: BoolRef - - -@dataclass -class CaseNodeProperties: - """ - Class for mapping possible expression and constant of a symbol for a switch-case. - - -> symbol: symbol that belongs to the expression and constant - -> constant: the compared constant - -> The condition that the new case node should get. - """ - - symbol: LogicCondition - expression: ExpressionUsages - constant: Constant - negation: bool - - def __eq__(self, other) -> bool: - """ - We want to be able to compare CaseNodeCandidates with AST-nodes, more precisely, - we want that an CaseNodeCandidate 'case_node' is equal to the AST node 'case_node.node'. - """ - if isinstance(other, CaseNodeProperties): - return self.symbol == other.symbol - return False - - -class SwitchNodeHandler: - """Handler for switch node reconstruction knowing possible constants and expressions for switch-nodes for each symbol.""" - - def __init__(self, condition_handler: ConditionHandler): - """ - Initialize the switch-node constructor. - - self._zero_case_of_switch_expression: maps to each possible switch-expression the possible zero-case condition. - self._case_node_property_of_symbol: maps to each symbol the possible expression and constant for a switch it can belong to. - """ - self._condition_handler: ConditionHandler = condition_handler - self._z3_converter: Z3Converter = Z3Converter() - self._zero_case_of_switch_expression: Dict[ExpressionUsages, ZeroCaseCondition] = dict() - self._get_zero_cases_for_possible_switch_expressions() - self._case_node_properties_of_symbol: Dict[LogicCondition, Optional[CaseNodeProperties]] = dict() - self._initialize_case_node_properties_for_symbols() - - def is_potential_switch_case(self, condition: LogicCondition) -> bool: - """Check whether the given condition is a potential switch case.""" - return self._get_case_node_property_of(condition) is not None - - def get_potential_switch_expression(self, condition: LogicCondition) -> Optional[ExpressionUsages]: - """Check whether the given condition is a potential switch case, and if return the corresponding expression.""" - if (case_node_property := self._get_case_node_property_of(condition)) is not None: - return case_node_property.expression - - def get_potential_switch_constant(self, condition: LogicCondition) -> Optional[Constant]: - """Check whether the given condition is a potential switch case, and if return the corresponding constant.""" - if (case_node_property := self._get_case_node_property_of(condition)) is not None: - return case_node_property.constant - - def get_literal_and_constant_for(self, condition: LogicCondition) -> Iterable[LogicCondition, Constant]: - """Get the constant for each literal of the given condition.""" - for literal in condition.get_literals(): - yield literal, self.get_potential_switch_constant(literal) - - def get_constants_for(self, condition: LogicCondition) -> Iterable[Constant]: - """Get the constant for each literal of the given condition.""" - for literal in condition.get_literals(): - yield self.get_potential_switch_constant(literal) - - def _get_case_node_property_of(self, condition: LogicCondition) -> Optional[CaseNodeProperties]: - """Return the case-property of a given literal.""" - negation = False - if condition.is_negation: - condition = condition.operands[0] - negation = True - if condition.is_symbol: - if condition not in self._case_node_properties_of_symbol: - self._case_node_properties_of_symbol[condition] = self.__get_case_node_property_of_symbol(condition) - if (case_property := self._case_node_properties_of_symbol[condition]) is not None and case_property.negation == negation: - return case_property - return None - - def _get_zero_cases_for_possible_switch_expressions(self) -> None: - """Get all possible switch expressions, i.e., all expression compared with a constant, together with the potential zero case.""" - for symbol in self._condition_handler.get_all_symbols(): - self.__add_switch_expression_and_zero_case_for_symbol(symbol) - - def __add_switch_expression_and_zero_case_for_symbol(self, symbol: LogicCondition) -> None: - """Add possible switch condition for symbol if comparison of expression with constant.""" - assert symbol.is_symbol, f"Each symbol should be a single Literal, but we have {symbol}" - non_constants = [op for op in self._condition_handler.get_condition_of(symbol).operands if not isinstance(op, Constant)] - if len(non_constants) != 1: - return None - expression_usage = ExpressionUsages(non_constants[0], tuple(var.ssa_name for var in non_constants[0].requirements)) - if expression_usage not in self._zero_case_of_switch_expression: - self.__add_switch_expression(expression_usage) - - def __add_switch_expression(self, expression_usage: ExpressionUsages) -> None: - """Construct the zero case condition and add it to the dictionary.""" - ssa_expression = self.__get_ssa_expression(expression_usage) - try: - z3_condition = self._z3_converter.convert(Condition(OperationType.equal, [ssa_expression, Constant(0, ssa_expression.type)])) - except ValueError: - return - self._zero_case_of_switch_expression[expression_usage] = ZeroCaseCondition( - expression_usage.expression, set(expression_usage.ssa_usages), z3_condition - ) - - @staticmethod - def __get_ssa_expression(expression_usage: ExpressionUsages) -> Expression: - """Construct SSA-expression of the given expression.""" - if isinstance(expression_usage.expression, Variable): - return expression_usage.expression.ssa_name if expression_usage.expression.ssa_name else expression_usage.expression - ssa_expression = expression_usage.expression.copy() - for variable in [var for var in ssa_expression.requirements if var.ssa_name]: - ssa_expression.substitute(variable, variable.ssa_name) - return ssa_expression - - def _initialize_case_node_properties_for_symbols(self) -> None: - """Initialize for each symbol the possible switch case properties""" - for symbol in self._condition_handler.get_all_symbols(): - self._case_node_properties_of_symbol[symbol] = self.__get_case_node_property_of_symbol(symbol) - - def __get_case_node_property_of_symbol(self, symbol: LogicCondition) -> Optional[CaseNodeProperties]: - """Return CaseNodeProperty of the given symbol, if it exists.""" - condition = self._condition_handler.get_condition_of(symbol) - if condition.operation not in {OperationType.equal, OperationType.not_equal}: - return None - constants: List[Constant] = [operand for operand in condition.operands if isinstance(operand, Constant)] - expressions: List[Expression] = [operand for operand in condition.operands if not isinstance(operand, Constant)] - - if len(constants) == 1 or len(expressions) == 1: - expression_usage = ExpressionUsages(expressions[0], tuple(var.ssa_name for var in expressions[0].requirements)) - const: Constant = constants[0] - elif len(constants) == 0 and (zero_case_condition := self.__check_for_zero_case_condition(condition)): - expression_usage, const = zero_case_condition - self._condition_handler.update_z3_condition_of(symbol, Condition(condition.operation, [expression_usage.expression, const])) - else: - return None - if expression_usage not in self._zero_case_of_switch_expression: - self.__add_switch_expression(expression_usage) - return CaseNodeProperties(symbol, expression_usage, const, condition.operation == OperationType.not_equal) - - def __check_for_zero_case_condition(self, condition: Condition) -> Optional[Tuple[ExpressionUsages, Constant]]: - """ - Check whether the condition belongs to a zero-case of a switch expression. - - If this is the case, we return the switch expression and the zero-constant - """ - tuple_ssa_usages = tuple(var.ssa_name for var in condition.requirements) - ssa_usages = set(tuple_ssa_usages) - ssa_condition = None - for expression_usage, zero_case_condition in self._zero_case_of_switch_expression.items(): - if zero_case_condition.ssa_usages != ssa_usages: - continue - if ssa_condition is None: - if (ssa_condition := self.__get_z3_condition(ExpressionUsages(condition, tuple_ssa_usages))) is None: - return None - zero_case_z3_condition = zero_case_condition.z3_condition - if self.__is_equivalent(ssa_condition, zero_case_z3_condition): - return expression_usage, Constant(0, expression_usage.expression.type) - - def __get_z3_condition(self, expression_usage: ExpressionUsages) -> Optional[BoolRef]: - """Get z3-condition of the expression usage in SSA-form if there is one""" - ssa_condition = self.__get_ssa_expression(expression_usage) - assert isinstance(ssa_condition, Condition), f"{ssa_condition} must be of type Condition!" - ssa_condition = ssa_condition.negate() if ssa_condition.operation == OperationType.not_equal else ssa_condition - try: - return self._z3_converter.convert(ssa_condition) - except ValueError: - return None - - @staticmethod - def __is_equivalent(cond1: BoolRef, cond2: BoolRef): - """Check whether the given conditions are equivalent.""" - z3_implementation = Z3Implementation(True) - if z3_implementation.is_equal(cond1, cond2): - return True - return z3_implementation.does_imply(cond1, cond2) and z3_implementation.does_imply(cond2, cond1) diff --git a/decompiler/structures/ast/syntaxforest.py b/decompiler/structures/ast/syntaxforest.py index 0fafbde22..4f8fe856f 100644 --- a/decompiler/structures/ast/syntaxforest.py +++ b/decompiler/structures/ast/syntaxforest.py @@ -16,7 +16,6 @@ VirtualRootNode, ) from decompiler.structures.ast.condition_symbol import ConditionHandler -from decompiler.structures.ast.switch_node_handler import SwitchNodeHandler from decompiler.structures.ast.syntaxgraph import AbstractSyntaxInterface from decompiler.structures.graphs.restructuring_graph.transition_cfg import TransitionBlock from decompiler.structures.logic.logic_condition import LogicCondition @@ -37,7 +36,6 @@ def __init__(self, condition_handler: ConditionHandler): self.condition_handler: ConditionHandler = condition_handler self._current_root: VirtualRootNode = self.factory.create_virtual_node() self._add_node(self._current_root) - self.switch_node_handler: SwitchNodeHandler = SwitchNodeHandler(condition_handler) @property def current_root(self) -> Optional[AbstractSyntaxTreeNode]: