From ec349707c70db2c5005d6db364041987fdcf04de Mon Sep 17 00:00:00 2001 From: Tiago Trevisan Jost Date: Thu, 28 Sep 2023 16:05:16 +0000 Subject: [PATCH] feat(TorchToTosa): improve support for AtenBroadcastTo ops on different rank scenarios. --- e2e_testing/xfail_sets.py | 7 +++ lib/Conversion/TorchToTosa/TorchToTosa.cpp | 36 ++++++++--- .../torch_mlir_e2e_test/test_suite/basic.py | 60 +++++++++++++++++++ 3 files changed, 94 insertions(+), 9 deletions(-) diff --git a/e2e_testing/xfail_sets.py b/e2e_testing/xfail_sets.py index 204632619deb..d6ddb62fbc2b 100644 --- a/e2e_testing/xfail_sets.py +++ b/e2e_testing/xfail_sets.py @@ -409,6 +409,9 @@ "BroadcastToDifferentRankStaticModule_basic", "BroadcastZeroRankInputStaticModule_basic", "BroadcastListConstructWithMinusOneModule_basic", + "BroadcastDifferentRankSameFinalShapeModule_basic", + "BroadcastDifferentRankWithMinusOneModule_basic", + "BroadcastToDifferentRankNotOneStaticModule_basic", "BucketizeTensorStaticFloatModule_basic", "BucketizeTensorStaticModule_basic", "CumsumStaticModule_basic", @@ -1133,9 +1136,12 @@ "ReduceSumDtypeFloatModule_basic", "ReduceSumDtypeIntModule_basic", "BroadcastToDifferentRankStaticModule_basic", + "BroadcastToDifferentRankNotOneStaticModule_basic", "BroadcastToSameRankStaticModule_basic", "BroadcastZeroRankInputStaticModule_basic", "BroadcastListConstructWithMinusOneModule_basic", + "BroadcastDifferentRankWithMinusOneModule_basic", + "BroadcastDifferentRankSameFinalShapeModule_basic", "SliceStaticModule_basic", "SliceSizeTwoStepDivisibleStaticModule_basic", "SliceOutOfLowerBoundStartIndexStaticModule_basic", @@ -1257,6 +1263,7 @@ "IndexSelectStaticModule_basic", "LinalgVectorNormModule_basic", "LinalgVectorNormKeepDimModule_basic", + "MatmulStaticBroadcast_basic", "NormScalarOptDimKeepDimModule_basic", "NormScalarOptDimModule_basic", "NormalizeModule_basic", diff --git a/lib/Conversion/TorchToTosa/TorchToTosa.cpp b/lib/Conversion/TorchToTosa/TorchToTosa.cpp index 
4e19c700482b..a8498a83bba2 100644 --- a/lib/Conversion/TorchToTosa/TorchToTosa.cpp +++ b/lib/Conversion/TorchToTosa/TorchToTosa.cpp @@ -3437,26 +3437,44 @@ LogicalResult ConvertAtenOp::matchAndRewrite( // Get the result type auto resultType = getTypeConverter()->convertType(op.getType()); + int64_t numBroadcastedDims = resultShape.size() - selfType.getRank(); + assert(numBroadcastedDims >= 0 && + "numBroadcastedDims must be positive or zero."); + + // A result dimension of -1 means the size of that dimension is unchanged. + // Replace it with the matching input dimension, offsetting the index by the + // rank difference between input and result. SmallVector inputShape( makeShapeTorchCompatible(selfType.getShape())); - // Result dimension -1 means not changing the size of that dimension. - // Adjust it by assigning its inputShape. - for (auto shape : llvm::enumerate(makeShapeTorchCompatible(inputShape))) { - auto index = shape.index(); + for (auto shape : llvm::enumerate(inputShape)) { + auto index = shape.index() + numBroadcastedDims; if (resultShape[index] == -1) resultShape[index] = shape.value(); } + + // If there are still unknown dimensions, nothing can be done. + if (llvm::any_of(resultShape, [&](auto dim) { return dim == -1; })) { + return rewriter.notifyMatchFailure( + op, "cannot propagate unknown (-1) dimension " + "as it is not presented in the input."); + } + + // Prepend a size-1 dimension to the input shape for each broadcasted + // dimension; broadcasted dimensions are the outermost ones. + SmallVector broadcastedDims(numBroadcastedDims, 1); + inputShape.insert(inputShape.begin(), broadcastedDims.begin(), + broadcastedDims.end()); + // Check for identity case i.e, for ex: [a, b, c] -> [a, b, c]. If this is // true then we can replace the op result with the input operand directly. 
- if (llvm::equal(inputShape, resultShape)) { + if (llvm::equal(inputShape, resultShape) && !numBroadcastedDims) { // If we reach here, then it means that the broadcasting is not required // since the input and result are of same shape. op.replaceAllUsesWith(op.getSelf()); rewriter.eraseOp(op); return success(); - } else if (selfType.hasRank() && - (selfType.getRank() == (int64_t)resultShape.size() || - selfType.getRank() == 0)) { + } else if (selfType.hasRank() && (inputShape.size() == resultShape.size() || + selfType.getRank() == 0)) { // Right now to support limited cases where input and result shape are not // equal, we can put a constraint that either the input should be of rank // 0 or the rank of input tensor and result should be equal. And then we @@ -3469,7 +3487,7 @@ LogicalResult ConvertAtenOp::matchAndRewrite( resultShape[i] != 1) { return rewriter.notifyMatchFailure( op, "unimplemented: either the shape of input and result should " - "be equal at each dimenion or one of them should be 1."); + "be equal at each dimension or one of them should be 1."); } } } diff --git a/python/torch_mlir_e2e_test/test_suite/basic.py b/python/torch_mlir_e2e_test/test_suite/basic.py index c8ddc655932c..43992573e8fc 100644 --- a/python/torch_mlir_e2e_test/test_suite/basic.py +++ b/python/torch_mlir_e2e_test/test_suite/basic.py @@ -1348,6 +1348,26 @@ def forward(self, x): def BroadcastToDifferentRankStaticModule_basic(module, tu: TestUtils): module.forward(tu.rand(2, 8)) +# ============================================================================== + +class BroadcastToDifferentRankNotOneStaticModule(torch.nn.Module): + + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([2, 8], torch.float32, True), + ]) + def forward(self, x): + return torch.broadcast_to(x, [10, 2, 8]) + + +@register_test_case(module_factory=lambda: BroadcastToDifferentRankNotOneStaticModule()) +def BroadcastToDifferentRankNotOneStaticModule_basic(module, tu: TestUtils): 
+ module.forward(tu.rand(2, 8)) + # ============================================================================== @@ -1420,6 +1440,46 @@ def BroadcastListConstructWithMinusOneModule_basic(module, tu: TestUtils): # ============================================================================== +class BroadcastDifferentRankWithMinusOneModule(torch.nn.Module): + + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([3, 1, 8], torch.float32, True), + ]) + def forward(self, x): + return torch.broadcast_to(x, [10, -1, -1, -1]) + + +@register_test_case(module_factory=lambda: BroadcastDifferentRankWithMinusOneModule()) +def BroadcastDifferentRankWithMinusOneModule_basic(module, tu: TestUtils): + module.forward(tu.rand(3, 1, 8)) + +# ============================================================================== + +class BroadcastDifferentRankSameFinalShapeModule(torch.nn.Module): + + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([3, 1, 8], torch.float32, True), + ]) + def forward(self, x): + return torch.broadcast_to(x, [1, -1, -1, -1]) + + +@register_test_case(module_factory=lambda: BroadcastDifferentRankSameFinalShapeModule()) +def BroadcastDifferentRankSameFinalShapeModule_basic(module, tu: TestUtils): + module.forward(tu.rand(3, 1, 8)) + +# ============================================================================== + class RollModule(torch.nn.Module): def __init__(self):