MaxPool1d lowering to linalg (llvm#3295)
Co-authored-by: root <root@i32b01216.sqa.eu95>
NeverRaR and root authored May 10, 2024
1 parent 261074f commit 1d48596
Showing 6 changed files with 187 additions and 1 deletion.
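For orientation, here is a minimal sketch of the op this commit lowers, assuming standard PyTorch semantics for aten.max_pool1d; the operand order mirrors the dtype function and the MLIR test added below, and the snippet itself is illustrative rather than part of the change:

import torch

x = torch.rand(1, 1, 20)
# aten.max_pool1d(self, kernel_size, stride, padding, dilation, ceil_mode):
# a sliding max over the innermost (W) dimension of an NCW tensor.
y = torch.ops.aten.max_pool1d(x, [6], [2], [3], [2], False)
print(y.shape)  # torch.Size([1, 1, 8]) for this configuration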
27 changes: 26 additions & 1 deletion lib/Conversion/TorchToLinalg/Pooling.cpp
@@ -185,6 +185,12 @@ namespace {

template <typename T> struct DimensionTraits {};

template <> struct DimensionTraits<AtenMaxPool1dOp> {
static constexpr int64_t Dim = 1;
// unused const variable warning suppression:
static_assert(Dim == Dim);
};

template <> struct DimensionTraits<AtenMaxPool2dOp> {
static constexpr int64_t Dim = 2;
// unused const variable warning suppression:
@@ -328,7 +334,24 @@ class ConvertAtenMaxPoolOp : public OpConversionPattern<OpTy> {

Type elementType = cast<RankedTensorType>(self.getType()).getElementType();

if constexpr (Dim == 2) {
if constexpr (Dim == 1) {
SmallVector<Value, 4> outTensorShape;
Value maxPool1d, paddedInput;
TypedAttr smallestFPValueAttr = rewriter.getFloatAttr(
elementType,
APFloat::getInf(
cast<mlir::FloatType>(elementType).getFloatSemantics(),
/*Negative=*/true));
if (failed(createPoolingOp<linalg::PoolingNcwMaxOp>(
op, rewriter, self, /*supportNonFPInput=*/true, ceilMode,
/*dimensionality=*/1, kernelSizeIntValues, strideInts,
paddingInts, dilationInts, smallestFPValueAttr, outTensorShape,
paddedInput, maxPool1d)))
return rewriter.notifyMatchFailure(op, "unable to compute maxpool1d");
Type newResultType = this->getTypeConverter()->convertType(op.getType());
rewriter.replaceOpWithNewOp<tensor::CastOp>(op, newResultType, maxPool1d);
return success();
} else if constexpr (Dim == 2) {
SmallVector<Value, 4> outTensorShape;
// `maxpool2d` contains the result of maxpool2d operation over the input.
Value maxPool2d, paddedInput;
@@ -1090,8 +1113,10 @@ void mlir::torch::torch_to_linalg::populatePoolingPatternsAndLegality(
TypeConverter &typeConverter, RewritePatternSet &patterns,
ConversionTarget &target) {
MLIRContext *context = patterns.getContext();
target.addIllegalOp<AtenMaxPool1dOp>();
target.addIllegalOp<AtenMaxPool2dOp>();
target.addIllegalOp<AtenMaxPool3dOp>();
patterns.add<ConvertAtenMaxPoolOp<AtenMaxPool1dOp>>(typeConverter, context);
patterns.add<ConvertAtenMaxPoolOp<AtenMaxPool2dOp>>(typeConverter, context);
patterns.add<ConvertAtenMaxPoolOp<AtenMaxPool3dOp>>(typeConverter, context);

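The new Dim == 1 branch follows the same recipe as the existing 2-D case: pad the pooled dimension with -inf (smallestFPValueAttr), fill the output with that neutral value, and compute windowed maxima via linalg.pooling_ncw_max. Below is a rough NumPy illustration of that recipe, assuming floor-mode output sizing only (ceil_mode and the final tensor.cast to the converted result type are handled by createPoolingOp); the helper is hypothetical and not part of the commit:

import numpy as np

def maxpool1d_lowering_sketch(x, kernel, stride, padding, dilation):
    neg_inf = -np.inf
    # tensor.pad on the W dimension with the smallest FP value.
    x = np.pad(x, ((0, 0), (0, 0), (padding, padding)), constant_values=neg_inf)
    n, c, w = x.shape
    span = dilation * (kernel - 1) + 1            # extent of one dilated window
    out_w = (w - span) // stride + 1              # floor-mode output length
    # linalg.fill of the output with the neutral value for max.
    out = np.full((n, c, out_w), neg_inf, dtype=x.dtype)
    for i in range(out_w):                        # linalg.pooling_ncw_max, spelled out
        window = x[:, :, i * stride : i * stride + span : dilation]
        out[:, :, i] = np.maximum(out[:, :, i], window.max(axis=-1))
    return out

x = np.random.rand(1, 1, 20).astype(np.float32)
print(maxpool1d_lowering_sketch(x, kernel=6, stride=2, padding=3, dilation=2).shape)  # (1, 1, 8)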
4 changes: 4 additions & 0 deletions lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp
@@ -10481,6 +10481,10 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
" return %0#1 : !torch.int\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.max_pool1d\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.list<int>, %arg2: !torch.list<int>, %arg3: !torch.list<int>, %arg4: !torch.list<int>, %arg5: !torch.bool) -> !torch.int {\n"
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
" return %0#1 : !torch.int\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.max_pool2d\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.list<int>, %arg2: !torch.list<int>, %arg3: !torch.list<int>, %arg4: !torch.list<int>, %arg5: !torch.bool) -> !torch.int {\n"
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
" return %0#1 : !torch.int\n"
10 changes: 10 additions & 0 deletions projects/pt1/e2e_testing/xfail_sets.py
@@ -1078,6 +1078,8 @@
"Matmul_matvec",
"Matmul_vecmat",
"MatmulStaticBroadcast_basic",
"MaxPool1dStaticModule_basic",
"MaxPool1dEmptyStrideStaticModule_basic",
"MaxPool2dStaticModule_basic",
"MaxPool2dEmptyStrideStaticModule_basic",
"MaxPool3dStaticModule_basic",
@@ -1905,6 +1907,9 @@
TOSA_PASS_SET
| {
### Tests additionally passing in make_fx_tosa
"MaxPool1dEmptyStrideStaticModule_basic",
"MaxPool1dStaticCeilModeTrueModule_basic",
"MaxPool1dStaticModule_basic",
"AdaptiveAvgPool1dNonUnitOutputSizeStaticModule_basic",
"AdaptiveAvgPool1dUnitOutputSizeStaticModule_basic",
"AdaptiveAvgPool1dStaticEvenMultiple_basic",
@@ -2361,6 +2366,11 @@
"LinalgNormKeepDimComplexModule_basic",
"LinalgVectorNormComplexModule_basic",
"LogSoftmaxBackwardModule_basic",
"MaxPool1dCeilModeTrueModule_basic",
"MaxPool1dEmptyStrideStaticModule_basic",
"MaxPool1dModule_basic",
"MaxPool1dStaticCeilModeTrueModule_basic",
"MaxPool1dStaticModule_basic",
"MaxPool2dCeilModeTrueModule_basic",
"MaxPool2dModule_basic",
"MaxPool2dWithIndicesAllOnesModule_basic",
5 changes: 5 additions & 0 deletions projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py
@@ -2612,6 +2612,11 @@ def aten〇masked_select〡dtype(self_rank_dtype: Tuple[int, int], mask_rank_dty
self_rank, self_dtype = self_rank_dtype
return self_dtype

@check_dtype_function(_check_tensors_with_the_same_dtype(tensor_shapes=[(2, 3, 5)], kernel_size=[2]))
def aten〇max_pool1d〡dtype(self_rank_dtype: Tuple[int, int], kernel_size: List[int], stride: List[int] = (), padding: List[int] = (0,), dilation: List[int] = (1,), ceil_mode: bool = False) -> int:
self_rank, self_dtype = self_rank_dtype
return self_dtype

@check_dtype_function(_check_tensors_with_the_same_dtype(tensor_shapes=[(2, 3, 5, 7)], kernel_size=[2, 2]))
def aten〇max_pool2d〡dtype(self_rank_dtype: Tuple[int, int], kernel_size: List[int], stride: List[int] = (), padding: List[int] = (0, 0,), dilation: List[int] = (1, 1,), ceil_mode: bool = False) -> int:
self_rank, self_dtype = self_rank_dtype
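The dtype function above, which corresponds to the entries added to AbstractInterpLibrary.cpp earlier in this diff, simply propagates the input dtype. A quick illustrative check of that property against eager PyTorch, not part of the commit:

import torch

for dtype in (torch.float32, torch.float64):
    x = torch.randn(2, 3, 5, dtype=dtype)
    # aten.max_pool1d(self, kernel_size, stride, padding, dilation, ceil_mode)
    y = torch.ops.aten.max_pool1d(x, [2], [1], [0], [1], False)
    assert y.dtype == dtype  # the result dtype follows the input dtype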
120 changes: 120 additions & 0 deletions projects/pt1/python/torch_mlir_e2e_test/test_suite/pooling.py
@@ -157,6 +157,126 @@ def AdaptiveAvgPool2dUnitOutputSizeDynamicModule_basic(module, tu: TestUtils):
# ==============================================================================


class MaxPool1dModule(torch.nn.Module):

def __init__(self):
super().__init__()
self.mp1d = torch.nn.MaxPool1d(
kernel_size=[6], stride=[2], padding=[3], dilation=2
)

@export
@annotate_args(
[
None,
([-1, -1, -1], torch.float32, True),
]
)
def forward(self, x):
return self.mp1d(x)


@register_test_case(module_factory=lambda: MaxPool1dModule())
def MaxPool1dModule_basic(module, tu: TestUtils):
module.forward(tu.rand(1, 1, 20, low=-1))


class MaxPool1dEmptyStrideStaticModule(torch.nn.Module):

def __init__(self):
super().__init__()

@export
@annotate_args(
[
None,
([1, 1, 20], torch.float32, True),
]
)
def forward(self, x):
return torch.ops.aten.max_pool1d(x, kernel_size=2, stride=[])


@register_test_case(module_factory=lambda: MaxPool1dEmptyStrideStaticModule())
def MaxPool1dEmptyStrideStaticModule_basic(module, tu: TestUtils):
module.forward(tu.rand(1, 1, 20, low=-1))


class MaxPool1dStaticModule(torch.nn.Module):

def __init__(self):
super().__init__()
self.mp1d = torch.nn.MaxPool1d(
kernel_size=[3], stride=[2], padding=[1], dilation=[1]
)

@export
@annotate_args(
[
None,
([1, 64, 112], torch.float32, True),
]
)
def forward(self, x):
return self.mp1d(x)


@register_test_case(module_factory=lambda: MaxPool1dStaticModule())
def MaxPool1dStaticModule_basic(module, tu: TestUtils):
module.forward(tu.rand(1, 64, 112))


class MaxPool1dStaticCeilModeTrueModule(torch.nn.Module):

def __init__(self):
super().__init__()
self.mp1d = torch.nn.MaxPool1d(
kernel_size=[3], stride=[2], padding=[1], dilation=[1], ceil_mode=True
)

@export
@annotate_args(
[
None,
([1, 64, 112], torch.float32, True),
]
)
def forward(self, x):
return self.mp1d(x)


@register_test_case(module_factory=lambda: MaxPool1dStaticCeilModeTrueModule())
def MaxPool1dStaticCeilModeTrueModule_basic(module, tu: TestUtils):
module.forward(tu.rand(1, 64, 112))


class MaxPool1dCeilModeTrueModule(torch.nn.Module):

def __init__(self):
super().__init__()
self.mp1d = torch.nn.MaxPool1d(
kernel_size=[6], stride=[2], padding=[3], dilation=2, ceil_mode=True
)

@export
@annotate_args(
[
None,
([-1, -1, -1], torch.float32, True),
]
)
def forward(self, x):
return self.mp1d(x)


@register_test_case(module_factory=lambda: MaxPool1dCeilModeTrueModule())
def MaxPool1dCeilModeTrueModule_basic(module, tu: TestUtils):
module.forward(tu.rand(1, 1, 20, low=0.5, high=1.0))


# ==============================================================================


class MaxPool2dModule(torch.nn.Module):
def __init__(self):
super().__init__()
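The static-shape tests above fix the expected output sizes. As a worked example against eager PyTorch (illustrative, not part of the commit), the MaxPool1dStaticModule and MaxPool1dStaticCeilModeTrueModule configurations give:

import torch

x = torch.rand(1, 64, 112)
floor_pool = torch.nn.MaxPool1d(kernel_size=3, stride=2, padding=1, dilation=1)
ceil_pool = torch.nn.MaxPool1d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=True)
# L_out = floor((112 + 2*1 - 1*(3-1) - 1) / 2 + 1) = 56
assert floor_pool(x).shape == (1, 64, 56)
# ceil_mode rounds the same expression up, giving 57
assert ceil_pool(x).shape == (1, 64, 57)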
22 changes: 22 additions & 0 deletions test/Conversion/TorchToLinalg/pooling.mlir
@@ -1,5 +1,27 @@
// RUN: torch-mlir-opt <%s -convert-torch-to-linalg -split-input-file -verify-diagnostics | FileCheck %s

// CHECK-LABEL: func @forward_max_pool1d
func.func @forward_max_pool1d(%arg0: !torch.vtensor<[?,?,?],f32>) -> !torch.vtensor<[?,?,?],f32> {
%int1 = torch.constant.int 1
%int2 = torch.constant.int 2
%int3 = torch.constant.int 3
%int4 = torch.constant.int 4
%false = torch.constant.bool false
// CHECK: %[[C1:.*]] = torch_c.to_i64 %int1
// CHECK: %[[NEUTRAL:.*]] = arith.constant 0xFF800000 : f32
// CHECK: %[[PADDED:.*]] = tensor.pad %{{.*}} low[0, 0, 3] high[0, 0, 3]
// CHECK: %[[OUT:.*]] = linalg.fill ins(%[[NEUTRAL]] : f32) outs(%{{.*}} : tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
// CHECK: %[[T1:.*]] = arith.index_cast %[[C1]] : i64 to index
// CHECK: %[[INIT:.*]] = tensor.empty(%[[T1]]) : tensor<?xf32>
// CHECK: linalg.pooling_ncw_max {dilations = dense<4> : vector<1xi64>, strides = dense<2> : vector<1xi64>} ins(%[[PADDED]], %[[INIT]] : tensor<?x?x?xf32>, tensor<?xf32>) outs(%[[OUT]] : tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
%kernel_size = torch.prim.ListConstruct %int1 : (!torch.int) -> !torch.list<int>
%stride = torch.prim.ListConstruct %int2 : (!torch.int) -> !torch.list<int>
%padding = torch.prim.ListConstruct %int3 : (!torch.int) -> !torch.list<int>
%dilation = torch.prim.ListConstruct %int4 : (!torch.int) -> !torch.list<int>
%4 = torch.aten.max_pool1d %arg0, %kernel_size, %stride, %padding, %dilation, %false : !torch.vtensor<[?,?,?],f32>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool -> !torch.vtensor<[?,?,?],f32>
return %4 : !torch.vtensor<[?,?,?],f32>
}

// CHECK-LABEL: func @forward_max_pool2d
func.func @forward_max_pool2d(%arg0: !torch.vtensor<[?,?,?,?],f32>) -> !torch.vtensor<[?,?,?,?],f32> {
%int1 = torch.constant.int 1
