Merge pull request #140 from Xilinx/tiagot.improve_broadcastto_torch_to_tosa

feat(TorchToTosa): improve support for AtenBroadcastTo ops on different rank scenarios.
ttjost authored Sep 29, 2023
2 parents de998c0 + ec34970 commit 0953522
Showing 3 changed files with 94 additions and 9 deletions.
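For context, a minimal example (not part of this commit) of the different-rank broadcasts the change targets; the shapes mirror the new e2e tests added below:

import torch

x = torch.rand(2, 8)
# Rank-2 input broadcast to a rank-3 result: the new leading dimension is
# tiled, as in BroadcastToDifferentRankNotOneStaticModule below.
y = torch.broadcast_to(x, (10, 2, 8))
assert y.shape == (10, 2, 8)

# -1 in the target shape keeps the corresponding input size; with a rank
# difference the input dimensions align against the trailing result
# dimensions, as in BroadcastDifferentRankWithMinusOneModule below.
z = torch.broadcast_to(torch.rand(3, 1, 8), (10, -1, -1, -1))
assert z.shape == (10, 3, 1, 8)
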
7 changes: 7 additions & 0 deletions e2e_testing/xfail_sets.py
@@ -409,6 +409,9 @@
"BroadcastToDifferentRankStaticModule_basic",
"BroadcastZeroRankInputStaticModule_basic",
"BroadcastListConstructWithMinusOneModule_basic",
"BroadcastDifferentRankSameFinalShapeModule_basic",
"BroadcastDifferentRankWithMinusOneModule_basic",
"BroadcastToDifferentRankNotOneStaticModule_basic",
"BucketizeTensorStaticFloatModule_basic",
"BucketizeTensorStaticModule_basic",
"CumsumStaticModule_basic",
@@ -1133,9 +1136,12 @@
"ReduceSumDtypeFloatModule_basic",
"ReduceSumDtypeIntModule_basic",
"BroadcastToDifferentRankStaticModule_basic",
"BroadcastToDifferentRankNotOneStaticModule_basic",
"BroadcastToSameRankStaticModule_basic",
"BroadcastZeroRankInputStaticModule_basic",
"BroadcastListConstructWithMinusOneModule_basic",
"BroadcastDifferentRankWithMinusOneModule_basic",
"BroadcastDifferentRankSameFinalShapeModule_basic",
"SliceStaticModule_basic",
"SliceSizeTwoStepDivisibleStaticModule_basic",
"SliceOutOfLowerBoundStartIndexStaticModule_basic",
@@ -1257,6 +1263,7 @@
"IndexSelectStaticModule_basic",
"LinalgVectorNormModule_basic",
"LinalgVectorNormKeepDimModule_basic",
"MatmulStaticBroadcast_basic",
"NormScalarOptDimKeepDimModule_basic",
"NormScalarOptDimModule_basic",
"NormalizeModule_basic",
36 changes: 27 additions & 9 deletions lib/Conversion/TorchToTosa/TorchToTosa.cpp
@@ -3437,26 +3437,44 @@ LogicalResult ConvertAtenOp<AtenBroadcastToOp>::matchAndRewrite(
// Get the result type
auto resultType = getTypeConverter()->convertType(op.getType());

int64_t numBroadcastedDims = resultShape.size() - selfType.getRank();
assert(numBroadcastedDims >= 0 &&
"numBroadcastedDims must be positive or zero.");

// Result dimension -1 means not changing the size of that dimension.
// Adjust it by assigning its inputShape according to the rank difference
// between input and result.
SmallVector<int64_t> inputShape(
makeShapeTorchCompatible(selfType.getShape()));
// Result dimension -1 means not changing the size of that dimension.
// Adjust it by assigning its inputShape.
for (auto shape : llvm::enumerate(makeShapeTorchCompatible(inputShape))) {
auto index = shape.index();
for (auto shape : llvm::enumerate(inputShape)) {
auto index = shape.index() + numBroadcastedDims;
if (resultShape[index] == -1)
resultShape[index] = shape.value();
}

// If there are still unknown dimensions, nothing can be done.
if (llvm::any_of(resultShape, [&](auto dim) { return dim == -1; })) {
return rewriter.notifyMatchFailure(
op, "cannot propagate unknown (-1) dimension "
"as it is not presented in the input.");
}

// Add 1 to each broadcasted dimension in the input.
// Broadcasted dimensions are the outermost ones.
SmallVector<int64_t> broadcastedDims(numBroadcastedDims, 1);
inputShape.insert(inputShape.begin(), broadcastedDims.begin(),
broadcastedDims.end());

// Check for identity case i.e, for ex: [a, b, c] -> [a, b, c]. If this is
// true then we can replace the op result with the input operand directly.
if (llvm::equal(inputShape, resultShape)) {
if (llvm::equal(inputShape, resultShape) && !numBroadcastedDims) {
// If we reach here, then it means that the broadcasting is not required
// since the input and result are of same shape.
op.replaceAllUsesWith(op.getSelf());
rewriter.eraseOp(op);
return success();
} else if (selfType.hasRank() &&
(selfType.getRank() == (int64_t)resultShape.size() ||
selfType.getRank() == 0)) {
} else if (selfType.hasRank() && (inputShape.size() == resultShape.size() ||
selfType.getRank() == 0)) {
// Right now to support limited cases where input and result shape are not
// equal, we can put a constraint that either the input should be of rank
// 0 or the rank of input tensor and result should be equal. And then we
@@ -3469,7 +3487,7 @@ LogicalResult ConvertAtenOp<AtenBroadcastToOp>::matchAndRewrite(
resultShape[i] != 1) {
return rewriter.notifyMatchFailure(
op, "unimplemented: either the shape of input and result should "
"be equal at each dimenion or one of them should be 1.");
"be equal at each dimension or one of them should be 1.");
}
}
}
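To make the shape bookkeeping above easier to follow, here is a rough Python sketch (not part of the commit; resolve_broadcast_shape is a hypothetical helper used only for illustration) of what the new code computes:

def resolve_broadcast_shape(input_shape, result_shape):
    # Number of outermost dimensions added by the broadcast; must not be
    # negative, mirroring the assert in the C++ code.
    num_broadcasted_dims = len(result_shape) - len(input_shape)
    assert num_broadcasted_dims >= 0

    # A -1 in the result keeps the corresponding input size; the offset
    # aligns input dimensions with the trailing result dimensions.
    resolved = list(result_shape)
    for i, size in enumerate(input_shape):
        if resolved[i + num_broadcasted_dims] == -1:
            resolved[i + num_broadcasted_dims] = size

    # A leftover -1 has no matching input dimension, so the conversion
    # gives up (notifyMatchFailure above).
    if any(d == -1 for d in resolved):
        raise ValueError("cannot propagate unknown (-1) dimension")

    # Broadcasted (outermost) dimensions become size-1 dims on the input,
    # so the existing same-rank lowering path can handle the rest.
    padded_input = [1] * num_broadcasted_dims + list(input_shape)
    return padded_input, resolved

# Example: input [3, 1, 8] broadcast to [10, -1, -1, -1] gives
# padded input [1, 3, 1, 8] and resolved result [10, 3, 1, 8].
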
60 changes: 60 additions & 0 deletions python/torch_mlir_e2e_test/test_suite/basic.py
@@ -1348,6 +1348,26 @@ def forward(self, x):
def BroadcastToDifferentRankStaticModule_basic(module, tu: TestUtils):
module.forward(tu.rand(2, 8))

# ==============================================================================

class BroadcastToDifferentRankNotOneStaticModule(torch.nn.Module):

def __init__(self):
super().__init__()

@export
@annotate_args([
None,
([2, 8], torch.float32, True),
])
def forward(self, x):
return torch.broadcast_to(x, [10, 2, 8])


@register_test_case(module_factory=lambda: BroadcastToDifferentRankNotOneStaticModule())
def BroadcastToDifferentRankNotOneStaticModule_basic(module, tu: TestUtils):
module.forward(tu.rand(2, 8))


# ==============================================================================

@@ -1420,6 +1440,46 @@ def BroadcastListConstructWithMinusOneModule_basic(module, tu: TestUtils):

# ==============================================================================

class BroadcastDifferentRankWithMinusOneModule(torch.nn.Module):

def __init__(self):
super().__init__()

@export
@annotate_args([
None,
([3, 1, 8], torch.float32, True),
])
def forward(self, x):
return torch.broadcast_to(x, [10, -1, -1, -1])


@register_test_case(module_factory=lambda: BroadcastDifferentRankWithMinusOneModule())
def BroadcastDifferentRankWithMinusOneModule_basic(module, tu: TestUtils):
module.forward(tu.rand(3, 1, 8))

# ==============================================================================

class BroadcastDifferentRankSameFinalShapeModule(torch.nn.Module):

def __init__(self):
super().__init__()

@export
@annotate_args([
None,
([3, 1, 8], torch.float32, True),
])
def forward(self, x):
return torch.broadcast_to(x, [1, -1, -1, -1])


@register_test_case(module_factory=lambda: BroadcastDifferentRankSameFinalShapeModule())
def BroadcastDifferentRankSameFinalShapeModule_basic(module, tu: TestUtils):
module.forward(tu.rand(3, 1, 8))

# ==============================================================================

class RollModule(torch.nn.Module):

def __init__(self):
