From 7cdea15db0db6009ef587fdaabb9d605a099c56e Mon Sep 17 00:00:00 2001 From: zjgarvey <47986913+zjgarvey@users.noreply.github.com> Date: Thu, 28 Mar 2024 11:43:09 -0500 Subject: [PATCH] [ONNX] Fixes Issue with Dynamic Dims in GlobalAveragePool -> Torch Conversion (#3053) Two e2e tests (AdaptiveAveragePool1/2dUnitOutputSizeDynamic) were failing due to numerics. This was as a result of passing -1 as the kernel size in the lowering for the corresponding onnx op GlobalAveragePool. --- .../TorchOnnxToTorch/DefaultDomainGtoP.cpp | 14 +++++++++++--- projects/pt1/e2e_testing/xfail_sets.py | 3 --- 2 files changed, 11 insertions(+), 6 deletions(-) diff --git a/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp b/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp index d7367a926de8..60738a579687 100644 --- a/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp +++ b/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp @@ -791,9 +791,17 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP( Value cstOne = rewriter.create( binder.getLoc(), rewriter.getI64IntegerAttr(1)); for (unsigned i = 2; i < inputRank; i++) { - int64_t kernelSize = inputShape[i] - resultShape[i] + 1; - cstKernel.push_back(rewriter.create( - binder.getLoc(), rewriter.getI64IntegerAttr(kernelSize))); + if (inputShape[i] == Torch::kUnknownSize) { + Value dim = rewriter.create( + binder.getLoc(), rewriter.getI64IntegerAttr(i)); + Value inputDimSize = rewriter.create( + binder.getLoc(), operand, dim); + cstKernel.push_back(inputDimSize); + } else { + int64_t kernelSize = inputShape[i] - resultShape[i] + 1; + cstKernel.push_back(rewriter.create( + binder.getLoc(), rewriter.getI64IntegerAttr(kernelSize))); + } cstPadding.push_back(cstZero); cstStrides.push_back(cstOne); } diff --git a/projects/pt1/e2e_testing/xfail_sets.py b/projects/pt1/e2e_testing/xfail_sets.py index a8e4649a96b8..cc88728fa642 100644 --- a/projects/pt1/e2e_testing/xfail_sets.py +++ b/projects/pt1/e2e_testing/xfail_sets.py @@ -1661,8 +1661,6 @@ "PermuteNegativeIndexModule_basic", # Failure - incorrect numerics - "AdaptiveAvgPool1dUnitOutputSizeDynamicModule_basic", - "AdaptiveAvgPool2dUnitOutputSizeDynamicModule_basic", "ElementwiseAtan2TensorIntModule_basic", "ElementwiseLog10IntModule_basic", "ElementwiseLog2IntModule_basic", @@ -1672,7 +1670,6 @@ "HardsigmoidModule_basic", "HardsigmoidRandomModule_basic", "PixelShuffleModuleStaticRank4Float32_basic", - "ResNet18Module_basic", "SliceCopyEndGreaterThanDimSize_Module_basic", "SliceCopyNegative_Module_basic", "SliceCopyNonZeroDim_Module_basic",