From 2c9c763191a5b040309e7cc5ab059c72ac2e7253 Mon Sep 17 00:00:00 2001 From: Angel Zhang <68571948+angelz913@users.noreply.github.com> Date: Fri, 10 May 2024 10:39:13 -0400 Subject: [PATCH 1/6] Update development.md (#3314) Add a command for installing the `python-dev` package --------- Co-authored-by: Jakub Kuderski --- docs/development.md | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/docs/development.md b/docs/development.md index e1b575b54b1f..fe997447c319 100644 --- a/docs/development.md +++ b/docs/development.md @@ -13,9 +13,6 @@ While this is running, you can already setup the Python venv and dependencies in ## Setup your Python VirtualEnvironment and Dependencies -Also, ensure that you have the appropriate `python-dev` package installed -to access the Python development libraries / headers. - ```shell python -m venv mlir_venv source mlir_venv/bin/activate @@ -26,6 +23,15 @@ python -m pip install -r requirements.txt python -m pip install -r torchvision-requirements.txt ``` +Also, ensure that you have the appropriate `python-dev` package installed +to access the Python development libraries / headers. For example, you can install +it with the following `apt` command on Ubuntu/Debian. + +```shell +sudo apt install python3-dev +``` + + ## (Optional) Set up pre-commit This project uses [pre-commit](https://pre-commit.com/) in its CI. You can From 7c289d95222f0297b7f1b36dc123663b95341136 Mon Sep 17 00:00:00 2001 From: Angel Zhang <68571948+angelz913@users.noreply.github.com> Date: Fri, 10 May 2024 11:58:46 -0400 Subject: [PATCH 2/6] [ONNX] Handle one-input case for `onnx.Max` operator (#3325) This commit handles the one-input case for the "Max" ONNX operator. A new unit test has also been added. --- lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp | 2 +- test/Conversion/TorchOnnxToTorch/simple_ops_g_to_p.mlir | 9 +++++++++ 2 files changed, 10 insertions(+), 1 deletion(-) diff --git a/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp b/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp index 1f1e2e5d7f0c..d419b0b5b74b 100644 --- a/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp +++ b/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp @@ -651,7 +651,7 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP( result = rewriter.create( binder.getLoc(), resultType, result, operands[i]); } - rewriter.replaceOp(binder.op, result.getDefiningOp()); + rewriter.replaceOp(binder.op, result); return success(); }); patterns.onOp( diff --git a/test/Conversion/TorchOnnxToTorch/simple_ops_g_to_p.mlir b/test/Conversion/TorchOnnxToTorch/simple_ops_g_to_p.mlir index d280d5f6b495..6041bae1cd09 100644 --- a/test/Conversion/TorchOnnxToTorch/simple_ops_g_to_p.mlir +++ b/test/Conversion/TorchOnnxToTorch/simple_ops_g_to_p.mlir @@ -740,6 +740,15 @@ func.func @test_globalaveragepool_precomputed(%arg0: !torch.vtensor<[1,1,3,3],f3 // ----- +// CHECK-LABEL: func.func @test_max_one_input_example + func.func @test_max_one_input_example(%arg0: !torch.vtensor<[3],f32>) -> !torch.vtensor<[3],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: return %arg0 : !torch.vtensor<[3],f32> + %0 = torch.operator "onnx.Max"(%arg0) : (!torch.vtensor<[3],f32>) -> !torch.vtensor<[3],f32> + return %0 : !torch.vtensor<[3],f32> + } + +// ----- + // CHECK-LABEL: func.func @test_min_example func.func @test_min_example(%arg0: !torch.vtensor<[3],f32>, %arg1: !torch.vtensor<[3],f32>, %arg2: !torch.vtensor<[3],f32>) -> !torch.vtensor<[3],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { // CHECK: torch.aten.minimum %arg0, %arg1 : !torch.vtensor<[3],f32>, !torch.vtensor<[3],f32> -> !torch.vtensor<[3],f32> From 10db31046028e26142e93e53b0eba8f38ffd91d1 Mon Sep 17 00:00:00 2001 From: Vivek Khandelwal Date: Fri, 10 May 2024 21:45:06 +0530 Subject: [PATCH 3/6] build: manually update PyTorch version (#3291) Set PyTorch and TorchVision version to nightly release 2024-05-05. Signed-Off By: Vivek Khandelwal --- projects/pt1/e2e_testing/xfail_sets.py | 6 ------ pytorch-hash.txt | 2 +- pytorch-requirements.txt | 2 +- torchvision-requirements.txt | 2 +- 4 files changed, 3 insertions(+), 9 deletions(-) diff --git a/projects/pt1/e2e_testing/xfail_sets.py b/projects/pt1/e2e_testing/xfail_sets.py index 2c7d392d3300..3eee6b2d3727 100644 --- a/projects/pt1/e2e_testing/xfail_sets.py +++ b/projects/pt1/e2e_testing/xfail_sets.py @@ -2660,12 +2660,6 @@ "ReduceMinAlongDimUnsignedInt_basic", } -if torch_version_for_comparison() >= version.parse("2.4.0.dev"): - ONNX_XFAIL_SET = ONNX_XFAIL_SET | { - # ERROR: Found dtype (torch.float64) but expected (torch.float32) - "ReduceL1NormWithDTypeModule_basic", - } - if torch_version_for_comparison() < version.parse("2.3.0.dev"): ONNX_XFAIL_SET = ONNX_XFAIL_SET | { # ERROR: shape (torch.Size([6, 4, 5])) is not equal to golden shape (torch.Size([120])) diff --git a/pytorch-hash.txt b/pytorch-hash.txt index 400586976392..3424cb46aad1 100644 --- a/pytorch-hash.txt +++ b/pytorch-hash.txt @@ -1 +1 @@ -34ade3521ca41f20af3469bba276c2b0499c3892 +1b7523fbe9d0a0c81930673f4374c6e69fa293b6 diff --git a/pytorch-requirements.txt b/pytorch-requirements.txt index 7cd8d44e5425..7b73c61f4e13 100644 --- a/pytorch-requirements.txt +++ b/pytorch-requirements.txt @@ -1,3 +1,3 @@ -f https://download.pytorch.org/whl/nightly/cpu/torch_nightly.html --pre -torch==2.4.0.dev20240428 +torch==2.4.0.dev20240505 diff --git a/torchvision-requirements.txt b/torchvision-requirements.txt index 148f66152b88..a7da638bc2bf 100644 --- a/torchvision-requirements.txt +++ b/torchvision-requirements.txt @@ -1,3 +1,3 @@ -f https://download.pytorch.org/whl/nightly/cpu/torch_nightly.html --pre -torchvision==0.19.0.dev20240428 +torchvision==0.19.0.dev20240505 From be20db0a0eaec40bdd1d9b197e7a0e139bb08e66 Mon Sep 17 00:00:00 2001 From: penguin_wwy <940375606@qq.com> Date: Sat, 11 May 2024 00:28:58 +0800 Subject: [PATCH 4/6] [NFC] Delete the deprecated example cases (#3323) --- .../pt1/examples/example-requirements.txt | 4 + projects/pt1/examples/torchdynamo_resnet18.py | 106 ------------------ .../torchscript_stablehlo_backend_resnet.py | 2 +- .../torchscript_stablehlo_backend_tinybert.py | 2 +- 4 files changed, 6 insertions(+), 108 deletions(-) create mode 100644 projects/pt1/examples/example-requirements.txt delete mode 100644 projects/pt1/examples/torchdynamo_resnet18.py diff --git a/projects/pt1/examples/example-requirements.txt b/projects/pt1/examples/example-requirements.txt new file mode 100644 index 000000000000..c443f7a3bcd0 --- /dev/null +++ b/projects/pt1/examples/example-requirements.txt @@ -0,0 +1,4 @@ +datasets +transformers +requests +pillow diff --git a/projects/pt1/examples/torchdynamo_resnet18.py b/projects/pt1/examples/torchdynamo_resnet18.py deleted file mode 100644 index 76602d4bae59..000000000000 --- a/projects/pt1/examples/torchdynamo_resnet18.py +++ /dev/null @@ -1,106 +0,0 @@ -# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -# See https://llvm.org/LICENSE.txt for license information. -# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -# Also available under a BSD-style license. See LICENSE. - -import sys -from typing import List - -from PIL import Image -import requests - -import torch -import torch._dynamo as dynamo -import torchvision.models as models -from torchvision import transforms - -from torch_mlir import torchscript -from torch_mlir.dynamo import make_simple_dynamo_backend -from torch_mlir_e2e_test.linalg_on_tensors_backends import refbackend - - -def load_and_preprocess_image(url: str): - headers = { - "User-Agent": "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_11_5) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/50.0.2661.102 Safari/537.36" - } - img = Image.open(requests.get(url, headers=headers, stream=True).raw).convert("RGB") - # preprocessing pipeline - preprocess = transforms.Compose( - [ - transforms.Resize(256), - transforms.CenterCrop(224), - transforms.ToTensor(), - transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), - ] - ) - img_preprocessed = preprocess(img) - return torch.unsqueeze(img_preprocessed, 0) - - -def load_labels(): - classes_text = requests.get( - "https://raw.githubusercontent.com/cathyzhyi/ml-data/main/imagenet-classes.txt", - stream=True, - ).text - labels = [line.strip() for line in classes_text.splitlines()] - return labels - - -def top3_possibilities(res): - _, indexes = torch.sort(res, descending=True) - percentage = torch.nn.functional.softmax(res, dim=1)[0] * 100 - top3 = [(labels[idx], percentage[idx].item()) for idx in indexes[0][:3]] - return top3 - - -def predictions(torch_func, jit_func, img, labels): - golden_prediction = top3_possibilities(torch_func(img)) - print("PyTorch prediction") - print(golden_prediction) - prediction = top3_possibilities(torch.from_numpy(jit_func(img.numpy()))) - print("torch-mlir prediction") - print(prediction) - - -image_url = ( - "https://upload.wikimedia.org/wikipedia/commons/2/26/YellowLabradorLooking_new.jpg" -) - -print("load image from " + image_url, file=sys.stderr) -img = load_and_preprocess_image(image_url) -labels = load_labels() - - -@make_simple_dynamo_backend -def refbackend_torchdynamo_backend( - fx_graph: torch.fx.GraphModule, example_inputs: List[torch.Tensor] -): - mlir_module = torchscript.compile( - fx_graph, example_inputs, output_type="linalg-on-tensors" - ) - backend = refbackend.RefBackendLinalgOnTensorsBackend() - compiled = backend.compile(mlir_module) - loaded = backend.load(compiled) - - def compiled_callable(*inputs): - inputs = [x.numpy() for x in inputs] - result = loaded.forward(*inputs) - if not isinstance(result, tuple): - result = torch.from_numpy(result) - else: - result = tuple(torch.from_numpy(x) for x in result) - return result - - return compiled_callable - - -resnet18 = models.resnet18(pretrained=True) -resnet18.train(False) -dynamo_callable = dynamo.optimize(refbackend_torchdynamo_backend)(resnet18) - -predictions( - resnet18.forward, - lambda x: dynamo_callable(torch.from_numpy(x)).detach().numpy(), - img, - labels, -) diff --git a/projects/pt1/examples/torchscript_stablehlo_backend_resnet.py b/projects/pt1/examples/torchscript_stablehlo_backend_resnet.py index db281fc8e748..526c4a72c24e 100644 --- a/projects/pt1/examples/torchscript_stablehlo_backend_resnet.py +++ b/projects/pt1/examples/torchscript_stablehlo_backend_resnet.py @@ -11,6 +11,6 @@ model, data, output_type=torchscript.OutputType.STABLEHLO, use_tracing=False ) with open(out_stablehlo_mlir_path, "w", encoding="utf-8") as outf: - outf.write(str(module)) + outf.write(module.operation.get_asm()) print(f"StableHLO IR of resent18 successfully written into {out_stablehlo_mlir_path}") diff --git a/projects/pt1/examples/torchscript_stablehlo_backend_tinybert.py b/projects/pt1/examples/torchscript_stablehlo_backend_tinybert.py index 840ec519d5c8..f68d0cdbcf1d 100644 --- a/projects/pt1/examples/torchscript_stablehlo_backend_tinybert.py +++ b/projects/pt1/examples/torchscript_stablehlo_backend_tinybert.py @@ -25,6 +25,6 @@ def forward(self, data): model, data, output_type=torchscript.OutputType.STABLEHLO, use_tracing=True ) with open(out_stablehlo_mlir_path, "w", encoding="utf-8") as outf: - outf.write(str(module)) + outf.write(module.operation.get_asm()) print(f"StableHLO IR of tiny bert successfully written into {out_stablehlo_mlir_path}") From 261074f5948fb30b68981f736aec8f10871bb98c Mon Sep 17 00:00:00 2001 From: Angel Zhang <68571948+angelz913@users.noreply.github.com> Date: Fri, 10 May 2024 12:34:03 -0400 Subject: [PATCH 5/6] [ONNX] Handle one-input case for Min ONNX operator (#3326) This commit handles the one-input case for the "Min" ONNX operator. A new unit test has also been added. --- lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp | 2 +- test/Conversion/TorchOnnxToTorch/simple_ops_g_to_p.mlir | 9 +++++++++ 2 files changed, 10 insertions(+), 1 deletion(-) diff --git a/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp b/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp index d419b0b5b74b..cd6f92d8094d 100644 --- a/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp +++ b/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp @@ -667,7 +667,7 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP( result = rewriter.create( binder.getLoc(), resultType, result, operands[i]); } - rewriter.replaceOp(binder.op, result.getDefiningOp()); + rewriter.replaceOp(binder.op, result); return success(); }); patterns.onOp("Neg", 1, diff --git a/test/Conversion/TorchOnnxToTorch/simple_ops_g_to_p.mlir b/test/Conversion/TorchOnnxToTorch/simple_ops_g_to_p.mlir index 6041bae1cd09..991a7075c863 100644 --- a/test/Conversion/TorchOnnxToTorch/simple_ops_g_to_p.mlir +++ b/test/Conversion/TorchOnnxToTorch/simple_ops_g_to_p.mlir @@ -758,6 +758,15 @@ func.func @test_globalaveragepool_precomputed(%arg0: !torch.vtensor<[1,1,3,3],f3 // ----- +// CHECK-LABEL: func.func @test_min_one_input_example + func.func @test_min_one_input_example(%arg0: !torch.vtensor<[3],f32>) -> !torch.vtensor<[3],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: return %arg0 : !torch.vtensor<[3],f32> + %0 = torch.operator "onnx.Min"(%arg0) : (!torch.vtensor<[3],f32>) -> !torch.vtensor<[3],f32> + return %0 : !torch.vtensor<[3],f32> + } + +// ----- + // CHECK-LABEL: func.func @test_mod_int64_fmod func.func @test_mod_int64_fmod(%arg0: !torch.vtensor<[6],si64>, %arg1: !torch.vtensor<[6],si64>) -> !torch.vtensor<[6],si64> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { // CHECK: torch.aten.fmod.Tensor %arg0, %arg1 : !torch.vtensor<[6],si64>, !torch.vtensor<[6],si64> -> !torch.vtensor<[6],si64> From 1d4859699b576b438d362b857bbb16922c04ee07 Mon Sep 17 00:00:00 2001 From: NeverRaR <44917563+NeverRaR@users.noreply.github.com> Date: Sat, 11 May 2024 00:35:26 +0800 Subject: [PATCH 6/6] MaxPool1d lowering to linalg (#3295) Co-authored-by: root --- lib/Conversion/TorchToLinalg/Pooling.cpp | 27 +++- .../Transforms/AbstractInterpLibrary.cpp | 4 + projects/pt1/e2e_testing/xfail_sets.py | 10 ++ .../build_tools/abstract_interp_lib_gen.py | 5 + .../torch_mlir_e2e_test/test_suite/pooling.py | 120 ++++++++++++++++++ test/Conversion/TorchToLinalg/pooling.mlir | 22 ++++ 6 files changed, 187 insertions(+), 1 deletion(-) diff --git a/lib/Conversion/TorchToLinalg/Pooling.cpp b/lib/Conversion/TorchToLinalg/Pooling.cpp index 4157ef285888..70b27fd84f24 100644 --- a/lib/Conversion/TorchToLinalg/Pooling.cpp +++ b/lib/Conversion/TorchToLinalg/Pooling.cpp @@ -185,6 +185,12 @@ namespace { template struct DimensionTraits {}; +template <> struct DimensionTraits { + static constexpr int64_t Dim = 1; + // unused const variable warning suppression: + static_assert(Dim == Dim); +}; + template <> struct DimensionTraits { static constexpr int64_t Dim = 2; // unused const variable warning suppression: @@ -328,7 +334,24 @@ class ConvertAtenMaxPoolOp : public OpConversionPattern { Type elementType = cast(self.getType()).getElementType(); - if constexpr (Dim == 2) { + if constexpr (Dim == 1) { + SmallVector outTensorShape; + Value maxPool1d, paddedInput; + TypedAttr smallestFPValueAttr = rewriter.getFloatAttr( + elementType, + APFloat::getInf( + cast(elementType).getFloatSemantics(), + /*Negative=*/true)); + if (failed(createPoolingOp( + op, rewriter, self, /*supportNonFPInput=*/true, ceilMode, + /*dimensionality=*/1, kernelSizeIntValues, strideInts, + paddingInts, dilationInts, smallestFPValueAttr, outTensorShape, + paddedInput, maxPool1d))) + return rewriter.notifyMatchFailure(op, "unable to compute maxpool1d"); + Type newResultType = this->getTypeConverter()->convertType(op.getType()); + rewriter.replaceOpWithNewOp(op, newResultType, maxPool1d); + return success(); + } else if constexpr (Dim == 2) { SmallVector outTensorShape; // `maxpool2d` contains the result of maxpool2d operation over the input. Value maxPool2d, paddedInput; @@ -1090,8 +1113,10 @@ void mlir::torch::torch_to_linalg::populatePoolingPatternsAndLegality( TypeConverter &typeConverter, RewritePatternSet &patterns, ConversionTarget &target) { MLIRContext *context = patterns.getContext(); + target.addIllegalOp(); target.addIllegalOp(); target.addIllegalOp(); + patterns.add>(typeConverter, context); patterns.add>(typeConverter, context); patterns.add>(typeConverter, context); diff --git a/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp b/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp index d1ecc6ed797c..43bcc3acc0eb 100644 --- a/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp +++ b/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp @@ -10481,6 +10481,10 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" " return %0#1 : !torch.int\n" " }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.max_pool1d\"(%arg0: !torch.tuple, %arg1: !torch.list, %arg2: !torch.list, %arg3: !torch.list, %arg4: !torch.list, %arg5: !torch.bool) -> !torch.int {\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" return %0#1 : !torch.int\n" +" }\n" " func.func @\"__torch_mlir_dtype_fn.aten.max_pool2d\"(%arg0: !torch.tuple, %arg1: !torch.list, %arg2: !torch.list, %arg3: !torch.list, %arg4: !torch.list, %arg5: !torch.bool) -> !torch.int {\n" " %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" " return %0#1 : !torch.int\n" diff --git a/projects/pt1/e2e_testing/xfail_sets.py b/projects/pt1/e2e_testing/xfail_sets.py index 3eee6b2d3727..3fcb272f423e 100644 --- a/projects/pt1/e2e_testing/xfail_sets.py +++ b/projects/pt1/e2e_testing/xfail_sets.py @@ -1078,6 +1078,8 @@ "Matmul_matvec", "Matmul_vecmat", "MatmulStaticBroadcast_basic", + "MaxPool1dStaticModule_basic", + "MaxPool1dEmptyStrideStaticModule_basic", "MaxPool2dStaticModule_basic", "MaxPool2dEmptyStrideStaticModule_basic", "MaxPool3dStaticModule_basic", @@ -1905,6 +1907,9 @@ TOSA_PASS_SET | { ### Tests additionally passing in make_fx_tosa + "MaxPool1dEmptyStrideStaticModule_basic", + "MaxPool1dStaticCeilModeTrueModule_basic", + "MaxPool1dStaticModule_basic", "AdaptiveAvgPool1dNonUnitOutputSizeStaticModule_basic", "AdaptiveAvgPool1dUnitOutputSizeStaticModule_basic", "AdaptiveAvgPool1dStaticEvenMultiple_basic", @@ -2361,6 +2366,11 @@ "LinalgNormKeepDimComplexModule_basic", "LinalgVectorNormComplexModule_basic", "LogSoftmaxBackwardModule_basic", + "MaxPool1dCeilModeTrueModule_basic", + "MaxPool1dEmptyStrideStaticModule_basic", + "MaxPool1dModule_basic", + "MaxPool1dStaticCeilModeTrueModule_basic", + "MaxPool1dStaticModule_basic", "MaxPool2dCeilModeTrueModule_basic", "MaxPool2dModule_basic", "MaxPool2dWithIndicesAllOnesModule_basic", diff --git a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py index 60ffaa439421..1cf0c2c7696a 100644 --- a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py +++ b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py @@ -2612,6 +2612,11 @@ def aten〇masked_select〡dtype(self_rank_dtype: Tuple[int, int], mask_rank_dty self_rank, self_dtype = self_rank_dtype return self_dtype +@check_dtype_function(_check_tensors_with_the_same_dtype(tensor_shapes=[(2, 3, 5)], kernel_size=[2])) +def aten〇max_pool1d〡dtype(self_rank_dtype: Tuple[int, int], kernel_size: List[int], stride: List[int] = (), padding: List[int] = (0,), dilation: List[int] = (1,), ceil_mode: bool = False) -> int: + self_rank, self_dtype = self_rank_dtype + return self_dtype + @check_dtype_function(_check_tensors_with_the_same_dtype(tensor_shapes=[(2, 3, 5, 7)], kernel_size=[2, 2])) def aten〇max_pool2d〡dtype(self_rank_dtype: Tuple[int, int], kernel_size: List[int], stride: List[int] = (), padding: List[int] = (0, 0,), dilation: List[int] = (1, 1,), ceil_mode: bool = False) -> int: self_rank, self_dtype = self_rank_dtype diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/pooling.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/pooling.py index 50711afed3ef..69d813c917f0 100644 --- a/projects/pt1/python/torch_mlir_e2e_test/test_suite/pooling.py +++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/pooling.py @@ -157,6 +157,126 @@ def AdaptiveAvgPool2dUnitOutputSizeDynamicModule_basic(module, tu: TestUtils): # ============================================================================== +class MaxPool1dModule(torch.nn.Module): + + def __init__(self): + super().__init__() + self.mp1d = torch.nn.MaxPool1d( + kernel_size=[6], stride=[2], padding=[3], dilation=2 + ) + + @export + @annotate_args( + [ + None, + ([-1, -1, -1], torch.float32, True), + ] + ) + def forward(self, x): + return self.mp1d(x) + + +@register_test_case(module_factory=lambda: MaxPool1dModule()) +def MaxPool1dModule_basic(module, tu: TestUtils): + module.forward(tu.rand(1, 1, 20, low=-1)) + + +class MaxPool1dEmptyStrideStaticModule(torch.nn.Module): + + def __init__(self): + super().__init__() + + @export + @annotate_args( + [ + None, + ([1, 1, 20], torch.float32, True), + ] + ) + def forward(self, x): + return torch.ops.aten.max_pool1d(x, kernel_size=2, stride=[]) + + +@register_test_case(module_factory=lambda: MaxPool1dEmptyStrideStaticModule()) +def MaxPool1dEmptyStrideStaticModule_basic(module, tu: TestUtils): + module.forward(tu.rand(1, 1, 20, low=-1)) + + +class MaxPool1dStaticModule(torch.nn.Module): + + def __init__(self): + super().__init__() + self.mp1d = torch.nn.MaxPool1d( + kernel_size=[3], stride=[2], padding=[1], dilation=[1] + ) + + @export + @annotate_args( + [ + None, + ([1, 64, 112], torch.float32, True), + ] + ) + def forward(self, x): + return self.mp1d(x) + + +@register_test_case(module_factory=lambda: MaxPool1dStaticModule()) +def MaxPool1dStaticModule_basic(module, tu: TestUtils): + module.forward(tu.rand(1, 64, 112)) + + +class MaxPool1dStaticCeilModeTrueModule(torch.nn.Module): + + def __init__(self): + super().__init__() + self.mp1d = torch.nn.MaxPool1d( + kernel_size=[3], stride=[2], padding=[1], dilation=[1], ceil_mode=True + ) + + @export + @annotate_args( + [ + None, + ([1, 64, 112], torch.float32, True), + ] + ) + def forward(self, x): + return self.mp1d(x) + + +@register_test_case(module_factory=lambda: MaxPool1dStaticCeilModeTrueModule()) +def MaxPool1dStaticCeilModeTrueModule_basic(module, tu: TestUtils): + module.forward(tu.rand(1, 64, 112)) + + +class MaxPool1dCeilModeTrueModule(torch.nn.Module): + + def __init__(self): + super().__init__() + self.mp1d = torch.nn.MaxPool1d( + kernel_size=[6], stride=[2], padding=[3], dilation=2, ceil_mode=True + ) + + @export + @annotate_args( + [ + None, + ([-1, -1, -1], torch.float32, True), + ] + ) + def forward(self, x): + return self.mp1d(x) + + +@register_test_case(module_factory=lambda: MaxPool1dCeilModeTrueModule()) +def MaxPool1dCeilModeTrueModule_basic(module, tu: TestUtils): + module.forward(tu.rand(1, 1, 20, low=0.5, high=1.0)) + + +# ============================================================================== + + class MaxPool2dModule(torch.nn.Module): def __init__(self): super().__init__() diff --git a/test/Conversion/TorchToLinalg/pooling.mlir b/test/Conversion/TorchToLinalg/pooling.mlir index 8a359ed5627d..494f603c296e 100644 --- a/test/Conversion/TorchToLinalg/pooling.mlir +++ b/test/Conversion/TorchToLinalg/pooling.mlir @@ -1,5 +1,27 @@ // RUN: torch-mlir-opt <%s -convert-torch-to-linalg -split-input-file -verify-diagnostics | FileCheck %s +// CHECK-LABEL: func @forward_max_pool1d +func.func @forward_max_pool1d(%arg0: !torch.vtensor<[?,?,?],f32>) -> !torch.vtensor<[?,?,?],f32> { + %int1 = torch.constant.int 1 + %int2 = torch.constant.int 2 + %int3 = torch.constant.int 3 + %int4 = torch.constant.int 4 + %false = torch.constant.bool false + // CHECK: %[[C1:.*]] = torch_c.to_i64 %int1 + // CHECK: %[[NEUTRAL:.*]] = arith.constant 0xFF800000 : f32 + // CHECK: %[[PADDED:.*]] = tensor.pad %{{.*}} low[0, 0, 3] high[0, 0, 3] + // CHECK: %[[OUT:.*]] = linalg.fill ins(%[[NEUTRAL]] : f32) outs(%{{.*}} : tensor) -> tensor + // CHECK: %[[T1:.*]] = arith.index_cast %[[C1]] : i64 to index + // CHECK: %[[INIT:.*]] = tensor.empty(%[[T1]]) : tensor + // CHECK: linalg.pooling_ncw_max {dilations = dense<4> : vector<1xi64>, strides = dense<2> : vector<1xi64>} ins(%[[PADDED]], %[[INIT]] : tensor, tensor) outs(%[[OUT]] : tensor) -> tensor + %kernel_size = torch.prim.ListConstruct %int1 : (!torch.int) -> !torch.list + %stride = torch.prim.ListConstruct %int2 : (!torch.int) -> !torch.list + %padding = torch.prim.ListConstruct %int3 : (!torch.int) -> !torch.list + %dilation = torch.prim.ListConstruct %int4 : (!torch.int) -> !torch.list + %4 = torch.aten.max_pool1d %arg0, %kernel_size, %stride, %padding, %dilation, %false : !torch.vtensor<[?,?,?],f32>, !torch.list, !torch.list, !torch.list, !torch.list, !torch.bool -> !torch.vtensor<[?,?,?],f32> + return %4 : !torch.vtensor<[?,?,?],f32> +} + // CHECK-LABEL: func @forward_max_pool2d func.func @forward_max_pool2d(%arg0: !torch.vtensor<[?,?,?,?],f32>) -> !torch.vtensor<[?,?,?,?],f32> { %int1 = torch.constant.int 1