Skip to content

Commit

Permalink
Merge pull request #448 from Xilinx/bump_to_2665ed34
Browse files Browse the repository at this point in the history
[AutoBump] Merge with 2665ed3 (Oct 10) (78)
  • Loading branch information
mgehre-amd authored Jan 10, 2025
2 parents 88ae56b + 53c50dd commit 998f4c8
Show file tree
Hide file tree
Showing 2 changed files with 234 additions and 12 deletions.
158 changes: 146 additions & 12 deletions lib/Dialect/Torch/Transforms/ScalarizeShapes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -530,11 +530,139 @@ class FoldAtenUnsqueezePattern : public OpRewritePattern<AtenUnsqueezeOp> {
none, none, none, none);
return success();
}
auto squeezeOp = op.getSelf().getDefiningOp<AtenSqueezeDimOp>();
if (squeezeOp && resultTy.getSizes().size() == 1) {
rewriter.replaceOp(op, squeezeOp.getSelf());
return success();
}

return failure();
}
};
} // namespace

namespace {
// Canonicalizes `aten.view` ops whose requested sizes can be proven, dim by
// dim, to agree with the input shape up to a single trailing mismatch:
//   * [?,...,?,lastDim]         -> [?,...,?,factor0,factor1] becomes
//     `aten.unflatten.int` on the last input dim,
//   * [?,...,?,factor0,factor1] -> [?,...,?,lastDim] becomes
//     `aten.flatten.using_ints` over the trailing input dims,
//   * a view whose sizes all match the input (equal ranks) is folded away.
// A leading dim "matches" when either the input dim is static and the view
// size is the same constant, or the input dim is dynamic and the view size is
// `aten.size.int(self, i)` for that same position `i`.
// TODO: move this to an actual canonicalizer for view after deleting the
// conflicting decompositions for flatten/unflatten -> view.
class CanonicalizeAtenViewPattern : public OpRewritePattern<AtenViewOp> {
public:
  using OpRewritePattern<AtenViewOp>::OpRewritePattern;

  // Returns success() after replacing `op` with a fold / flatten / unflatten;
  // returns a match failure (leaving the IR untouched) in every other case.
  LogicalResult matchAndRewrite(AtenViewOp op,
                                PatternRewriter &rewriter) const override {
    SmallVector<Value> viewSizes;
    if (failed(getListOperands(op.getSize(), viewSizes)))
      return rewriter.notifyMatchFailure(
          op, "view size must be from a list construct");
    auto selfTy = dyn_cast<Torch::ValueTensorType>(op.getSelf().getType());
    if (!selfTy || !selfTy.hasSizes())
      return rewriter.notifyMatchFailure(op, "missing input type or sizes");
    auto resultTy = dyn_cast<Torch::ValueTensorType>(op.getType());
    if (!resultTy || !resultTy.hasSizes() ||
        resultTy.getSizes().size() != viewSizes.size())
      return rewriter.notifyMatchFailure(op, "missing result type or sizes");
    const int64_t inRank = selfTy.getSizes().size();
    const int64_t outRank = resultTy.getSizes().size();

    // The sizes are only read, so view them in place instead of copying.
    ArrayRef<int64_t> sizes = selfTy.getSizes();
    // Index of the first position where the view sizes stop mirroring the
    // input sizes; stays -1 when every compared position matches.
    int64_t endMatchingDim = -1;
    const int64_t comparedRank = std::min(outRank, inRank);
    // input sizes vs. provided view sizes comparison loop
    for (int64_t i = 0; i < comparedRank; i++) {
      int64_t providedSize;
      bool providedStatic =
          matchPattern(viewSizes[i], m_TorchConstantInt(&providedSize));
      // if sizes[i] is static, it must match a constant in viewSizes[i]
      if (sizes[i] != Torch::kUnknownSize) {
        if (!providedStatic)
          return rewriter.notifyMatchFailure(
              op, "unsupported: found static input dim, but unable to match "
                  "provided view size on a constant. See position : " +
                      std::to_string(i));
        if (providedSize != sizes[i]) {
          endMatchingDim = i;
          break;
        }
        continue;
      }
      // the remaining assumes sizes[i] is dynamic
      // if provided dim is static, we can't verify it is a flatten/unflatten
      // unless -1
      if (i == outRank - 1 && providedStatic && providedSize == -1) {
        endMatchingDim = i;
        break;
      }
      if (providedStatic)
        return rewriter.notifyMatchFailure(
            op, "unexpected static view dim corresponding to dynamic input dim "
                "at position : " +
                    std::to_string(i));
      auto sizeIntOp = viewSizes[i].getDefiningOp<AtenSizeIntOp>();
      // if we don't have a size int op on self, fail
      if (!sizeIntOp || sizeIntOp.getSelf() != op.getSelf())
        return rewriter.notifyMatchFailure(
            op, "expected dynamic view dim to come from a corresponding "
                "size.int op. See position : " +
                    std::to_string(i));
      int64_t dim;
      // if the dim of the size int op doesn't match, fail
      if (!matchPattern(sizeIntOp.getDim(), m_TorchConstantInt(&dim)) ||
          dim != i)
        return rewriter.notifyMatchFailure(
            op,
            "size int op dim cannot be matched to current dim at position : " +
                std::to_string(i));
      // passing the previous checks means viewSizes[i] = aten.size.int(self,
      // i), so continue
    }
    // if all dims match and the ranks are equal, fold
    if (endMatchingDim == -1 && inRank == outRank) {
      // The loop above only establishes that the shapes agree; it never
      // inspects the declared result type. Replacing the op with a value of a
      // different type would silently re-type every use, so only forward
      // `self` directly when the types are identical and otherwise bridge the
      // static-info difference with an explicit cast.
      if (resultTy == selfTy)
        rewriter.replaceOp(op, op.getSelf());
      else
        rewriter.replaceOpWithNewOp<Torch::TensorStaticInfoCastOp>(
            op, resultTy, op.getSelf());
      return success();
    }
    if (endMatchingDim > -1 && inRank > outRank) {
      // only support flattening last dim
      if (endMatchingDim != outRank - 1)
        return rewriter.notifyMatchFailure(
            op, "unimplemented: output has more than back dim mismatching");
      // flatten input dims [endMatchingDim, inRank - 1] into one output dim
      Value start =
          rewriter.create<Torch::ConstantIntOp>(op.getLoc(), endMatchingDim);
      Value end =
          rewriter.create<Torch::ConstantIntOp>(op.getLoc(), inRank - 1);
      rewriter.replaceOpWithNewOp<AtenFlattenUsingIntsOp>(
          op, resultTy, op.getSelf(), start, end);
      return success();
    }
    if (endMatchingDim > -1 && inRank < outRank) {
      // only support unflattening last dim
      if (endMatchingDim != inRank - 1)
        return rewriter.notifyMatchFailure(
            op, "unimplemented: input has more than back dim mismatching");
      // unflatten the last input dim into the trailing provided view sizes
      Value dim =
          rewriter.create<Torch::ConstantIntOp>(op.getLoc(), endMatchingDim);
      Value primList = rewriter.create<Torch::PrimListConstructOp>(
          op.getLoc(), op.getSize().getType(),
          ArrayRef<Value>(viewSizes.begin() + endMatchingDim, viewSizes.end()));
      rewriter.replaceOpWithNewOp<AtenUnflattenIntOp>(
          op, resultTy, op.getSelf(), dim, primList);
      return success();
    }
    // examples that might reach this:
    // input shape = [10, 5]; view sizes = [5, 10] (or dynamic variants)
    // input shape = [dim0, dim1]; view sizes = [dim0, dim1, 1, 1] (unsqueezes)
    // input shape = [dim0, dim1, 1, 1] view sizes = [dim0, dim1] (squeezes)
    return rewriter.notifyMatchFailure(
        op, "unhandled case: endMatchingDim=" + std::to_string(endMatchingDim) +
                ", inRank=" + std::to_string(inRank) +
                ", outRank=" + std::to_string(outRank));
  }
};
} // namespace

namespace {
template <typename T> class RemoveUnusedPattern : public OpRewritePattern<T> {
public:
Expand All @@ -561,18 +689,24 @@ class ScalarizeShapesPass : public ScalarizeShapesBase<ScalarizeShapesPass> {
void runOnOperation() override {
MLIRContext *context = &getContext();
RewritePatternSet patterns(context);
patterns
.insert<PropagateAtenCatPattern, PropagateAtenIndexSelectPattern,
PropagateAtenItemPattern, PropagateAtenShapeToTensorPattern,
PropagateAtenSliceTensorPattern, FoldAtenTensorSplatPattern,
FoldAtenSqueezePattern, FoldAtenUnsqueezePattern,
FoldAtenWhereSelf, RemoveUnusedPattern<Torch::AtenSizeIntOp>,
RemoveUnusedPattern<Torch::AtenSliceTensorOp>,
RemoveUnusedPattern<Torch::AtenTensorOp>,
RemoveUnusedPattern<Torch::ConstantBoolOp>,
RemoveUnusedPattern<Torch::ConstantIntOp>,
RemoveUnusedPattern<Torch::ConstantNoneOp>,
RemoveUnusedPattern<Torch::PrimListConstructOp>>(context);
patterns.insert<PropagateAtenCatPattern, PropagateAtenIndexSelectPattern,
PropagateAtenItemPattern, PropagateAtenShapeToTensorPattern,
PropagateAtenSliceTensorPattern, FoldAtenTensorSplatPattern,
FoldAtenSqueezePattern, FoldAtenUnsqueezePattern,
FoldAtenWhereSelf, CanonicalizeAtenViewPattern,
RemoveUnusedPattern<Torch::AtenIntBoolOp>,
RemoveUnusedPattern<Torch::AtenEqIntOp>,
RemoveUnusedPattern<Torch::PrimNumToTensorScalarOp>,
RemoveUnusedPattern<Torch::AtenFullOp>,
RemoveUnusedPattern<Torch::AtenUnsqueezeOp>,
RemoveUnusedPattern<Torch::AtenSqueezeDimOp>,
RemoveUnusedPattern<Torch::AtenSizeIntOp>,
RemoveUnusedPattern<Torch::AtenSliceTensorOp>,
RemoveUnusedPattern<Torch::AtenTensorOp>,
RemoveUnusedPattern<Torch::ConstantBoolOp>,
RemoveUnusedPattern<Torch::ConstantIntOp>,
RemoveUnusedPattern<Torch::ConstantNoneOp>,
RemoveUnusedPattern<Torch::PrimListConstructOp>>(context);

context->getLoadedDialect<mlir::arith::ArithDialect>()
->getCanonicalizationPatterns(patterns);
Expand Down
88 changes: 88 additions & 0 deletions test/Dialect/Torch/scalarize-shapes.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -72,3 +72,91 @@ func.func @shape_as_tensor_slice(%arg0 : !torch.vtensor<[5,?,?,?],f32>) -> !torc
%slice = torch.aten.slice.Tensor %shape, %dim, %start, %end, %step : !torch.vtensor<[4], si32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[2], si32>
return %slice : !torch.vtensor<[2],si32>
}


// -----

// CHECK-LABEL: @view_as_flatten_static
// Flatten with static trailing dims: the view keeps the two leading dynamic
// dims (each taken from torch.aten.size.int on %arg0) and requests a constant
// 1024 in place of the trailing [16, 64]. The expectations below require the
// view to be rewritten into torch.aten.flatten.using_ints over dims 2..3.
func.func @view_as_flatten_static(%arg0: !torch.vtensor<[?,?,16,64],f32>) -> !torch.vtensor<[?,?,1024],f32> {
// CHECK-DAG: %[[TWO:.*]] = torch.constant.int 2
// CHECK-DAG: %[[THREE:.*]] = torch.constant.int 3
// CHECK-DAG: %[[FLAT:.*]] = torch.aten.flatten.using_ints %arg0, %[[TWO]], %[[THREE]] : !torch.vtensor<[?,?,16,64],f32>, !torch.int, !torch.int -> !torch.vtensor<[?,?,1024],f32>
// CHECK: return %[[FLAT]] : !torch.vtensor<[?,?,1024],f32>
%int1024 = torch.constant.int 1024
%int1 = torch.constant.int 1
%int0 = torch.constant.int 0
// Leading view sizes are read back from the input itself (dims 0 and 1).
%0 = torch.aten.size.int %arg0, %int0 : !torch.vtensor<[?,?,16,64],f32>, !torch.int -> !torch.int
%1 = torch.aten.size.int %arg0, %int1 : !torch.vtensor<[?,?,16,64],f32>, !torch.int -> !torch.int
%2 = torch.prim.ListConstruct %0, %1, %int1024 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
%3 = torch.aten.view %arg0, %2 : !torch.vtensor<[?,?,16,64],f32>, !torch.list<int> -> !torch.vtensor<[?,?,1024],f32>
return %3 : !torch.vtensor<[?,?,1024],f32>
}


// -----

// CHECK-LABEL: @view_as_unflatten_static
// Unflatten with static factors: the view keeps the two leading dynamic dims
// (each taken from torch.aten.size.int on %arg0) and splits the trailing 1024
// into the constants [16, 64]. The expectations below require the view to be
// rewritten into torch.aten.unflatten.int on dim 2 with the list [16, 64].
func.func @view_as_unflatten_static(%arg0: !torch.vtensor<[?,?,1024],f32>) -> !torch.vtensor<[?,?,16,64],f32> {
// CHECK-DAG: %[[TWO:.*]] = torch.constant.int 2
// CHECK-DAG: %[[CST16:.*]] = torch.constant.int 16
// CHECK-DAG: %[[CST64:.*]] = torch.constant.int 64
// CHECK: %[[LIST:.*]] = torch.prim.ListConstruct %[[CST16]], %[[CST64]] : (!torch.int, !torch.int) -> !torch.list<int>
// CHECK: %[[FLAT:.*]] = torch.aten.unflatten.int %arg0, %[[TWO]], %[[LIST]] : !torch.vtensor<[?,?,1024],f32>, !torch.int, !torch.list<int> -> !torch.vtensor<[?,?,16,64],f32>
// CHECK: return %[[FLAT]] : !torch.vtensor<[?,?,16,64],f32>
%int16 = torch.constant.int 16
%int64 = torch.constant.int 64
%int1 = torch.constant.int 1
%int0 = torch.constant.int 0
// Leading view sizes are read back from the input itself (dims 0 and 1).
%0 = torch.aten.size.int %arg0, %int0 : !torch.vtensor<[?,?,1024],f32>, !torch.int -> !torch.int
%1 = torch.aten.size.int %arg0, %int1 : !torch.vtensor<[?,?,1024],f32>, !torch.int -> !torch.int
%2 = torch.prim.ListConstruct %0, %1, %int16, %int64 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
%3 = torch.aten.view %arg0, %2 : !torch.vtensor<[?,?,1024],f32>, !torch.list<int> -> !torch.vtensor<[?,?,16,64],f32>
return %3 : !torch.vtensor<[?,?,16,64],f32>
}


// -----

// CHECK-LABEL: @view_as_flatten_dynamic
// Flatten with fully dynamic dims: the view keeps the two leading dims (each
// taken from torch.aten.size.int on %arg0) and passes -1 as the last size, so
// the trailing two dynamic input dims collapse into one inferred dim. The
// expectations below require torch.aten.flatten.using_ints over dims 2..3.
func.func @view_as_flatten_dynamic(%arg0: !torch.vtensor<[?,?,?,?],f32>) -> !torch.vtensor<[?,?,?],f32> {
// CHECK-DAG: %[[TWO:.*]] = torch.constant.int 2
// CHECK-DAG: %[[THREE:.*]] = torch.constant.int 3
// CHECK-DAG: %[[FLAT:.*]] = torch.aten.flatten.using_ints %arg0, %[[TWO]], %[[THREE]] : !torch.vtensor<[?,?,?,?],f32>, !torch.int, !torch.int -> !torch.vtensor<[?,?,?],f32>
// CHECK: return %[[FLAT]] : !torch.vtensor<[?,?,?],f32>
%int-1 = torch.constant.int -1
%int1 = torch.constant.int 1
%int0 = torch.constant.int 0
// Leading view sizes are read back from the input itself (dims 0 and 1).
%0 = torch.aten.size.int %arg0, %int0 : !torch.vtensor<[?,?,?,?],f32>, !torch.int -> !torch.int
%1 = torch.aten.size.int %arg0, %int1 : !torch.vtensor<[?,?,?,?],f32>, !torch.int -> !torch.int
%2 = torch.prim.ListConstruct %0, %1, %int-1 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
%3 = torch.aten.view %arg0, %2 : !torch.vtensor<[?,?,?,?],f32>, !torch.list<int> -> !torch.vtensor<[?,?,?],f32>
return %3 : !torch.vtensor<[?,?,?],f32>
}


// -----

// CHECK-LABEL: @unsqueeze_squeeze_combo
// Shape-arithmetic chain that should scalarize completely: two entries of the
// shape tensor of %arg0 are selected, squeezed to rank 0, unsqueezed back to
// rank 1, concatenated with a literal 1024, and the first element of that
// concatenation is extracted as an int. The expectations below require the
// whole chain to reduce to a single torch.aten.size.int on dim 0 of %arg0.
// NOTE(review): the expectations match the literal printed SSA names (%int0,
// %0) instead of FileCheck captures, so they depend on the printer's value
// naming staying the same.
func.func @unsqueeze_squeeze_combo(%arg0: !torch.vtensor<[?,?,16,64],f32>) -> !torch.int {
// CHECK: %int0 = torch.constant.int 0
// CHECK: %0 = torch.aten.size.int %arg0, %int0 : !torch.vtensor<[?,?,16,64],f32>, !torch.int -> !torch.int
// CHECK: return %0 : !torch.int
// Index tensors used by the index_select ops below: %0 holds 1, %1 holds 0.
%0 = torch.vtensor.literal(dense<1> : tensor<1xsi64>) : !torch.vtensor<[1],si64>
%1 = torch.vtensor.literal(dense<0> : tensor<1xsi64>) : !torch.vtensor<[1],si64>
%2 = torch.vtensor.literal(dense<1024> : tensor<1xsi64>) : !torch.vtensor<[1],si64>
%int1 = torch.constant.int 1
%int0 = torch.constant.int 0
// Select shape entry 0 (index tensor %1) and squeeze it to a rank-0 tensor.
%3 = torch.aten._shape_as_tensor %arg0 : !torch.vtensor<[?,?,16,64],f32> -> !torch.vtensor<[4],si64>
%4 = torch.aten.index_select %3, %int0, %1 : !torch.vtensor<[4],si64>, !torch.int, !torch.vtensor<[1],si64> -> !torch.vtensor<[1],si64>
%5 = torch.aten.squeeze.dim %4, %int0 : !torch.vtensor<[1],si64>, !torch.int -> !torch.vtensor<[],si64>
// Select shape entry 1 (index tensor %0) and squeeze it to a rank-0 tensor.
%6 = torch.aten._shape_as_tensor %arg0 : !torch.vtensor<[?,?,16,64],f32> -> !torch.vtensor<[4],si64>
%7 = torch.aten.index_select %6, %int0, %0 : !torch.vtensor<[4],si64>, !torch.int, !torch.vtensor<[1],si64> -> !torch.vtensor<[1],si64>
%8 = torch.aten.squeeze.dim %7, %int0 : !torch.vtensor<[1],si64>, !torch.int -> !torch.vtensor<[],si64>
// Undo the squeezes, then concatenate [entry 0, entry 1, 1024] along dim 0.
%9 = torch.aten.unsqueeze %5, %int0 : !torch.vtensor<[],si64>, !torch.int -> !torch.vtensor<[1],si64>
%10 = torch.aten.unsqueeze %8, %int0 : !torch.vtensor<[],si64>, !torch.int -> !torch.vtensor<[1],si64>
%11 = torch.prim.ListConstruct %9, %10, %2 : (!torch.vtensor<[1],si64>, !torch.vtensor<[1],si64>, !torch.vtensor<[1],si64>) -> !torch.list<vtensor>
%12 = torch.aten.cat %11, %int0 : !torch.list<vtensor>, !torch.int -> !torch.vtensor<[3],si64>
// Slice [0:1] with step 1 keeps only the first concatenated element.
%13 = torch.aten.slice.Tensor %12, %int0, %int0, %int1, %int1 : !torch.vtensor<[3],si64>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
%14 = torch.aten.item %13 : !torch.vtensor<[1],si64> -> !torch.int
return %14 : !torch.int
}

0 comments on commit 998f4c8

Please sign in to comment.