Skip to content

Commit

Permalink
Add modulo constant folding (#411)
Browse files Browse the repository at this point in the history
* Add modulo constant folding

* black

---------

Co-authored-by: Steffen Enders <steffen.enders@fkie.fraunhofer.de>
  • Loading branch information
rihi and steffenenders authored Jun 19, 2024
1 parent ba2a67a commit 800a3c1
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 2 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,12 @@ def constant_fold(operation: OperationType, constants: list[Constant], result_ty
)


def _constant_fold_arithmetic_binary(constants: list[Constant], fun: Callable[[int, int], int], norm_sign: Optional[bool] = None) -> int:
def _constant_fold_arithmetic_binary(
constants: list[Constant],
fun: Callable[[int, int], int],
norm_sign: Optional[bool] = None,
allow_mismatched_sizes: bool = False,
) -> int:
"""
Fold an arithmetic binary operation with constants as operands.
Expand All @@ -84,7 +89,7 @@ def _constant_fold_arithmetic_binary(constants: list[Constant], fun: Callable[[i

if len(constants) != 2:
raise IncompatibleOperandCount(f"Expected exactly 2 constants to fold, got {len(constants)}.")
if not all(constant.type.size == constants[0].type.size for constant in constants):
if not allow_mismatched_sizes and not all(constant.type.size == constants[0].type.size for constant in constants):
raise UnsupportedMismatchedSizes(f"Can not fold constants with different sizes: {[constant.type for constant in constants]}")

left, right = constants
Expand Down Expand Up @@ -137,13 +142,19 @@ def _constant_fold_shift(constants: list[Constant], fun: Callable[[int, int], in
return fun(normalize_int(left.value, left.type.size, norm_signed), right.value)


def remainder(n, d):
return (-1 if n < 0 else 1) * (n % d)


_OPERATION_TO_FOLD_FUNCTION: dict[OperationType, Callable[[list[Constant]], int]] = {
OperationType.minus: partial(_constant_fold_arithmetic_binary, fun=operator.sub),
OperationType.plus: partial(_constant_fold_arithmetic_binary, fun=operator.add),
OperationType.multiply: partial(_constant_fold_arithmetic_binary, fun=operator.mul, norm_sign=True),
OperationType.multiply_us: partial(_constant_fold_arithmetic_binary, fun=operator.mul, norm_sign=False),
OperationType.divide: partial(_constant_fold_arithmetic_binary, fun=operator.floordiv, norm_sign=True),
OperationType.divide_us: partial(_constant_fold_arithmetic_binary, fun=operator.floordiv, norm_sign=False),
OperationType.modulo: partial(_constant_fold_arithmetic_binary, fun=remainder, norm_sign=True, allow_mismatched_sizes=True),
OperationType.modulo_us: partial(_constant_fold_arithmetic_binary, fun=operator.mod, norm_sign=False, allow_mismatched_sizes=True),
OperationType.negate: partial(_constant_fold_arithmetic_unary, fun=operator.neg),
OperationType.left_shift: partial(_constant_fold_shift, fun=operator.lshift, signed=True),
OperationType.right_shift: partial(_constant_fold_shift, fun=operator.rshift, signed=True),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,18 @@ def test_constant_fold_invalid_value_type(
(OperationType.divide_us, [_c_i32(3), _c_i16(4)], Integer.int32_t(), None, pytest.raises(UnsupportedMismatchedSizes)),
(OperationType.divide_us, [_c_i32(3)], Integer.int32_t(), None, pytest.raises(IncompatibleOperandCount)),
(OperationType.divide_us, [_c_i32(3), _c_i32(3), _c_i32(3)], Integer.int32_t(), None, pytest.raises(IncompatibleOperandCount)),
(OperationType.modulo, [_c_i32(13), _c_i32(4)], Integer.int32_t(), _c_i32(1), nullcontext()),
(OperationType.modulo, [_c_i32(-2147483647), _c_i32(2)], Integer.int32_t(), _c_i32(-1), nullcontext()),
(OperationType.modulo, [_c_u32(4), _c_i32(3)], Integer.int32_t(), _c_i32(1), nullcontext()),
(OperationType.modulo, [_c_i32(4), _c_i16(3)], Integer.int32_t(), _c_i32(1), nullcontext()),
(OperationType.modulo, [_c_i32(3)], Integer.int32_t(), None, pytest.raises(IncompatibleOperandCount)),
(OperationType.modulo, [_c_i32(3), _c_i32(3), _c_i32(3)], Integer.int32_t(), None, pytest.raises(IncompatibleOperandCount)),
(OperationType.modulo_us, [_c_i32(13), _c_i32(4)], Integer.int32_t(), _c_i32(1), nullcontext()),
(OperationType.modulo_us, [_c_i32(-2147483647), _c_i32(2)], Integer.int32_t(), _c_i32(1), nullcontext()),
(OperationType.modulo_us, [_c_u32(4), _c_i32(3)], Integer.int32_t(), _c_i32(1), nullcontext()),
(OperationType.modulo_us, [_c_i32(4), _c_i16(3)], Integer.int32_t(), _c_i32(1), nullcontext()),
(OperationType.modulo_us, [_c_i32(3)], Integer.int32_t(), None, pytest.raises(IncompatibleOperandCount)),
(OperationType.modulo_us, [_c_i32(3), _c_i32(3), _c_i32(3)], Integer.int32_t(), None, pytest.raises(IncompatibleOperandCount)),
(OperationType.negate, [_c_i32(3)], Integer.int32_t(), _c_i32(-3), nullcontext()),
(OperationType.negate, [_c_i32(-2147483648)], Integer.int32_t(), _c_i32(-2147483648), nullcontext()),
(OperationType.negate, [], Integer.int32_t(), None, pytest.raises(IncompatibleOperandCount)),
Expand Down

0 comments on commit 800a3c1

Please sign in to comment.