diff --git a/lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp b/lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp index 2e3f3e8b8053..f998240b3472 100644 --- a/lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp +++ b/lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp @@ -306,7 +306,7 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF( return success(); }); patterns.onOp( - "AveragePool", 19, + "AveragePool", 11, [](OpBinder binder, ConversionPatternRewriter &rewriter) { std::string autoPad; SmallVector dilation; @@ -357,7 +357,8 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF( binder.op, "padding list size does not match twice the number of axes"); } - if (binder.s64IntegerArrayAttr(strides, "strides", {1})) { + if (binder.s64IntegerArrayAttr( + strides, "strides", llvm::SmallVector(rank - 2, 1))) { return failure(); } if (strides.size() != 1 && strides.size() != rank - 2) { diff --git a/lib/Conversion/TorchToLinalg/Pooling.cpp b/lib/Conversion/TorchToLinalg/Pooling.cpp index e795d2ea9fb8..283ac42ca6c5 100644 --- a/lib/Conversion/TorchToLinalg/Pooling.cpp +++ b/lib/Conversion/TorchToLinalg/Pooling.cpp @@ -114,8 +114,22 @@ static Value padInputTensor(Operation *op, ConversionPatternRewriter &rewriter, SmallVectorImpl &paddingInts, Value initValue) { SmallVector lowPaddingIncludingNC = {0, 0}; - lowPaddingIncludingNC.append(paddingInts); - SmallVector highPaddingIncludingNC = lowPaddingIncludingNC; + SmallVector highPaddingIncludingNC = {0, 0}; + + unsigned selfRank = self.getType().cast().getRank(); + unsigned paddingIntsSize = paddingInts.size(); + + if (paddingIntsSize == 2 * (selfRank - 2)) { + // This condition being true means that the `paddingInts` contain seperate + // values for low padding and high padding. + for (unsigned i = 0; i < paddingIntsSize / 2; i++) + lowPaddingIncludingNC.push_back(paddingInts[i]); + for (unsigned i = paddingIntsSize / 2; i < paddingIntsSize; i++) + highPaddingIncludingNC.push_back(paddingInts[i]); + } else { + lowPaddingIncludingNC.append(paddingInts); + highPaddingIncludingNC = lowPaddingIncludingNC; + } if (ceilMode) { for (int64_t i = 0; i < dimensionality; ++i) { diff --git a/projects/pt1/e2e_testing/xfail_sets.py b/projects/pt1/e2e_testing/xfail_sets.py index 671df14b3d34..5d9db7558d80 100644 --- a/projects/pt1/e2e_testing/xfail_sets.py +++ b/projects/pt1/e2e_testing/xfail_sets.py @@ -2044,17 +2044,7 @@ "LinalgNormModule_basic", # Failure - onnx_lowering: onnx.AveragePool - "AdaptiveAvgPool1dNonUnitOutputSizeStaticModule_basic", - "AdaptiveAvgPool1dStaticEvenMultiple_basic", - "AdaptiveAvgPool2dNonUnitOutputSizeStaticModule_basic", - "AvgPool1dFloatModule_basic", - "AvgPool1dIntModule_basic", - "AvgPool1dStaticModule_basic", - "AvgPool2dCeilModeTrueModule_basic", "AvgPool2dDivisorOverrideModule_basic", - "AvgPool2dFloatModule_basic", - "AvgPool2dIntModule_basic", - "AvgPool2dStaticModule_basic", # Failure - onnx_lowering: onnx.Cast "BucketizeTensorOutInt32RightModule_basic",