diff --git a/decompiler/pipeline/controlflowanalysis/expression_simplification/constant_folding.py b/decompiler/pipeline/controlflowanalysis/expression_simplification/constant_folding.py index 31656b667..88212abde 100644 --- a/decompiler/pipeline/controlflowanalysis/expression_simplification/constant_folding.py +++ b/decompiler/pipeline/controlflowanalysis/expression_simplification/constant_folding.py @@ -65,7 +65,12 @@ def constant_fold(operation: OperationType, constants: list[Constant], result_ty ) -def _constant_fold_arithmetic_binary(constants: list[Constant], fun: Callable[[int, int], int], norm_sign: Optional[bool] = None) -> int: +def _constant_fold_arithmetic_binary( + constants: list[Constant], + fun: Callable[[int, int], int], + norm_sign: Optional[bool] = None, + allow_mismatched_sizes: bool = False, +) -> int: """ Fold an arithmetic binary operation with constants as operands. @@ -84,7 +89,7 @@ def _constant_fold_arithmetic_binary(constants: list[Constant], fun: Callable[[i if len(constants) != 2: raise IncompatibleOperandCount(f"Expected exactly 2 constants to fold, got {len(constants)}.") - if not all(constant.type.size == constants[0].type.size for constant in constants): + if not allow_mismatched_sizes and not all(constant.type.size == constants[0].type.size for constant in constants): raise UnsupportedMismatchedSizes(f"Can not fold constants with different sizes: {[constant.type for constant in constants]}") left, right = constants @@ -137,6 +142,10 @@ def _constant_fold_shift(constants: list[Constant], fun: Callable[[int, int], in return fun(normalize_int(left.value, left.type.size, norm_signed), right.value) +def remainder(n, d): + return (-1 if n < 0 else 1) * (n % d) + + _OPERATION_TO_FOLD_FUNCTION: dict[OperationType, Callable[[list[Constant]], int]] = { OperationType.minus: partial(_constant_fold_arithmetic_binary, fun=operator.sub), OperationType.plus: partial(_constant_fold_arithmetic_binary, fun=operator.add), @@ -144,6 +153,8 @@ def _constant_fold_shift(constants: list[Constant], fun: Callable[[int, int], in OperationType.multiply_us: partial(_constant_fold_arithmetic_binary, fun=operator.mul, norm_sign=False), OperationType.divide: partial(_constant_fold_arithmetic_binary, fun=operator.floordiv, norm_sign=True), OperationType.divide_us: partial(_constant_fold_arithmetic_binary, fun=operator.floordiv, norm_sign=False), + OperationType.modulo: partial(_constant_fold_arithmetic_binary, fun=remainder, norm_sign=True, allow_mismatched_sizes=True), + OperationType.modulo_us: partial(_constant_fold_arithmetic_binary, fun=operator.mod, norm_sign=False, allow_mismatched_sizes=True), OperationType.negate: partial(_constant_fold_arithmetic_unary, fun=operator.neg), OperationType.left_shift: partial(_constant_fold_shift, fun=operator.lshift, signed=True), OperationType.right_shift: partial(_constant_fold_shift, fun=operator.rshift, signed=True), diff --git a/tests/pipeline/controlflowanalysis/expression_simplification/test_constant_folding.py b/tests/pipeline/controlflowanalysis/expression_simplification/test_constant_folding.py index 44ff99ea7..03ac11327 100644 --- a/tests/pipeline/controlflowanalysis/expression_simplification/test_constant_folding.py +++ b/tests/pipeline/controlflowanalysis/expression_simplification/test_constant_folding.py @@ -93,6 +93,18 @@ def test_constant_fold_invalid_value_type( (OperationType.divide_us, [_c_i32(3), _c_i16(4)], Integer.int32_t(), None, pytest.raises(UnsupportedMismatchedSizes)), (OperationType.divide_us, [_c_i32(3)], Integer.int32_t(), None, pytest.raises(IncompatibleOperandCount)), (OperationType.divide_us, [_c_i32(3), _c_i32(3), _c_i32(3)], Integer.int32_t(), None, pytest.raises(IncompatibleOperandCount)), + (OperationType.modulo, [_c_i32(13), _c_i32(4)], Integer.int32_t(), _c_i32(1), nullcontext()), + (OperationType.modulo, [_c_i32(-2147483647), _c_i32(2)], Integer.int32_t(), _c_i32(-1), nullcontext()), + (OperationType.modulo, [_c_u32(4), _c_i32(3)], Integer.int32_t(), _c_i32(1), nullcontext()), + (OperationType.modulo, [_c_i32(4), _c_i16(3)], Integer.int32_t(), _c_i32(1), nullcontext()), + (OperationType.modulo, [_c_i32(3)], Integer.int32_t(), None, pytest.raises(IncompatibleOperandCount)), + (OperationType.modulo, [_c_i32(3), _c_i32(3), _c_i32(3)], Integer.int32_t(), None, pytest.raises(IncompatibleOperandCount)), + (OperationType.modulo_us, [_c_i32(13), _c_i32(4)], Integer.int32_t(), _c_i32(1), nullcontext()), + (OperationType.modulo_us, [_c_i32(-2147483647), _c_i32(2)], Integer.int32_t(), _c_i32(1), nullcontext()), + (OperationType.modulo_us, [_c_u32(4), _c_i32(3)], Integer.int32_t(), _c_i32(1), nullcontext()), + (OperationType.modulo_us, [_c_i32(4), _c_i16(3)], Integer.int32_t(), _c_i32(1), nullcontext()), + (OperationType.modulo_us, [_c_i32(3)], Integer.int32_t(), None, pytest.raises(IncompatibleOperandCount)), + (OperationType.modulo_us, [_c_i32(3), _c_i32(3), _c_i32(3)], Integer.int32_t(), None, pytest.raises(IncompatibleOperandCount)), (OperationType.negate, [_c_i32(3)], Integer.int32_t(), _c_i32(-3), nullcontext()), (OperationType.negate, [_c_i32(-2147483648)], Integer.int32_t(), _c_i32(-2147483648), nullcontext()), (OperationType.negate, [], Integer.int32_t(), None, pytest.raises(IncompatibleOperandCount)),