Skip to content

Commit

Permalink
Merge pull request #189 from Xilinx/matthias.backport_avgpool
Browse files Browse the repository at this point in the history
Backport: [ONNX] Fixes Issue with Dynamic Dims in GlobalAveragePool -> Torch Conversion
  • Loading branch information
mgehre-amd authored Jul 4, 2024
2 parents 3cbdae8 + 7cdea15 commit e6c7403
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 6 deletions.
14 changes: 11 additions & 3 deletions lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -791,9 +791,17 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
Value cstOne = rewriter.create<Torch::ConstantIntOp>(
binder.getLoc(), rewriter.getI64IntegerAttr(1));
for (unsigned i = 2; i < inputRank; i++) {
int64_t kernelSize = inputShape[i] - resultShape[i] + 1;
cstKernel.push_back(rewriter.create<Torch::ConstantIntOp>(
binder.getLoc(), rewriter.getI64IntegerAttr(kernelSize)));
if (inputShape[i] == Torch::kUnknownSize) {
Value dim = rewriter.create<Torch::ConstantIntOp>(
binder.getLoc(), rewriter.getI64IntegerAttr(i));
Value inputDimSize = rewriter.create<Torch::AtenSizeIntOp>(
binder.getLoc(), operand, dim);
cstKernel.push_back(inputDimSize);
} else {
int64_t kernelSize = inputShape[i] - resultShape[i] + 1;
cstKernel.push_back(rewriter.create<Torch::ConstantIntOp>(
binder.getLoc(), rewriter.getI64IntegerAttr(kernelSize)));
}
cstPadding.push_back(cstZero);
cstStrides.push_back(cstOne);
}
Expand Down
3 changes: 0 additions & 3 deletions projects/pt1/e2e_testing/xfail_sets.py
Original file line number Diff line number Diff line change
Expand Up @@ -1661,8 +1661,6 @@
"PermuteNegativeIndexModule_basic",

# Failure - incorrect numerics
"AdaptiveAvgPool1dUnitOutputSizeDynamicModule_basic",
"AdaptiveAvgPool2dUnitOutputSizeDynamicModule_basic",
"ElementwiseAtan2TensorIntModule_basic",
"ElementwiseLog10IntModule_basic",
"ElementwiseLog2IntModule_basic",
Expand All @@ -1672,7 +1670,6 @@
"HardsigmoidModule_basic",
"HardsigmoidRandomModule_basic",
"PixelShuffleModuleStaticRank4Float32_basic",
"ResNet18Module_basic",
"SliceCopyEndGreaterThanDimSize_Module_basic",
"SliceCopyNegative_Module_basic",
"SliceCopyNonZeroDim_Module_basic",
Expand Down

0 comments on commit e6c7403

Please sign in to comment.