Skip to content

Commit

Permalink
Degeneralize transpose (#873)
Browse files Browse the repository at this point in the history
  • Loading branch information
chelini authored Jan 31, 2024
1 parent 07d5cb3 commit a8a1cab
Show file tree
Hide file tree
Showing 3 changed files with 160 additions and 2 deletions.
7 changes: 7 additions & 0 deletions include/TPP/IR/StructuredOpMatcher.h
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,13 @@ struct BroadcastableProjectedPermutation {
}
};

// Callable object matching affine maps that are projected permutations.
struct ProjectedPermutation {
  ProjectedPermutation() = default;

  bool operator()(AffineMap map) const {
    return map.isProjectedPermutation();
  }
};

// Callable object to verify if `map` is an identity map.
struct Identity {
Identity() = default;
Expand Down
66 changes: 65 additions & 1 deletion lib/TPP/Transforms/LinalgDeGeneralize.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,69 @@ struct BatchReduceOpDeGeneralizationPattern
}
};

// From linalg.generic to linalg.transpose.
//
// Matches a single-input/single-output generic whose iterators are all
// parallel, whose output map is the identity, and whose input map is a
// (non-identity) full permutation of the loops, and rewrites it into the
// named linalg.transpose op.
struct TransposeOpPattern : public OpRewritePattern<linalg::GenericOp> {
  using OpRewritePattern<linalg::GenericOp>::OpRewritePattern;

  // Returns true if `permutation` maps every position to itself.
  bool isIdentityPermutation(ArrayRef<int64_t> permutation) const {
    for (auto [idx, dim] : llvm::enumerate(permutation))
      if (static_cast<int64_t>(idx) != dim)
        return false;
    return true;
  }

  // Converts the input `map` into the permutation vector expected by
  // linalg.transpose: output dimension `d` reads input dimension `perm[d]`.
  // Fails if `map` does not cover all `numLoops` dimensions, or if the
  // resulting permutation is the identity (nothing to transpose).
  FailureOr<SmallVector<int64_t>>
  getPermutationFromMap(AffineMap map, int64_t numLoops) const {
    assert(map.isProjectedPermutation());
    // Cast avoids a signed/unsigned comparison; a projected permutation
    // with as many results as loops is a full permutation of the loops.
    if (numLoops != static_cast<int64_t>(map.getNumResults()))
      return failure();

    // Result `idx` of the map reads loop dimension `map.getDimPosition(idx)`
    // (every result of a full projected permutation is a dim expression), so
    // the inverse permutation is built in a single O(n) pass instead of a
    // quadratic scan over the results.
    SmallVector<int64_t> perm(numLoops);
    for (auto idx : llvm::seq<int64_t>(0, numLoops))
      perm[map.getDimPosition(idx)] = idx;

    if (isIdentityPermutation(perm))
      return failure();
    return perm;
  }

  // Matches `linalgOp` against the transpose shape and, on success, returns
  // the permutation to attach to the named op.
  FailureOr<SmallVector<int64_t>>
  isTransposeOp(linalg::GenericOp linalgOp) const {
    using namespace mlir::structured_match;
    AffineMap inputMap;
    auto transposeMatcher =
        StructuredOpMatcher::make<linalg::GenericOp>()
            .operation(NumDpsInits(EqualsTo(1)))
            .operation(NumDpsInputs(EqualsTo(1)))
            .operation(NumRegions(EqualsTo(1)))
            .dim(MatchAll(), mlir::utils::IteratorType::parallel)
            .input(MatchOne(0), HasMap(ProjectedPermutation(), &inputMap))
            .output(MatchOne(0), HasMap(Identity()))
            // NOTE(review): assumes WithSingleOp<YieldOp> also verifies the
            // yielded value is the input block argument — confirm; otherwise
            // a generic yielding %out would wrongly match as a transpose.
            .region(MatchOne(0),
                    WithSingleOp<linalg::YieldOp>(/*captures=*/nullptr));
    if (!transposeMatcher.match(linalgOp))
      return failure();
    return getPermutationFromMap(inputMap, linalgOp.getNumLoops());
  }

  LogicalResult matchAndRewrite(linalg::GenericOp linalgOp,
                                PatternRewriter &rewriter) const override {
    auto maybePerm = isTransposeOp(linalgOp);
    if (failed(maybePerm))
      return failure();
    Value inputOperand = linalgOp.getDpsInputs()[0];
    Value outputOperand = linalgOp.getDpsInits()[0];
    rewriter.replaceOpWithNewOp<linalg::TransposeOp>(linalgOp, inputOperand,
                                                     outputOperand, *maybePerm);
    return success();
  }
};

// From linalg.generic to linalg.fillOp.
struct FillOpDeGeneralizationPattern
: public OpRewritePattern<linalg::GenericOp> {
Expand Down Expand Up @@ -156,5 +219,6 @@ struct FillOpDeGeneralizationPattern
void mlir::linalg::populateLinalgDeGeneralizationPatterns(
    RewritePatternSet &patterns) {
  // Register all degeneralization patterns: each raises a matching
  // linalg.generic back to its named-op form.
  auto *ctx = patterns.getContext();
  patterns.add<FillOpDeGeneralizationPattern, MatmulOpDeGeneralizationPattern,
               BatchReduceOpDeGeneralizationPattern, TransposeOpPattern>(ctx);
}
89 changes: 88 additions & 1 deletion test/Passes/pass-degeneralize.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -74,10 +74,97 @@ func.func @degeneralize(%arg0: tensor<3x3x3xf32>, %arg1: tensor<3x3x3xf32>) -> t
return %2 : tensor<3x3xf32>
}

// CHECK: degeneralize
// CHECK-LABEL: degeneralize
// CHECK-SAME: %[[ARG0:.+]]: tensor<3x3x3xf32>, %[[ARG1:.+]]: tensor<3x3x3xf32>
// CHECK: %[[CST:.+]] = arith.constant 0.000000e+00 : f32
// CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<3x3xf32>
// CHECK: %[[FILL:.+]] = linalg.fill ins(%[[CST]] : f32) outs(%[[EMPTY]] : tensor<3x3xf32>) -> tensor<3x3xf32>
// CHECK: %{{.+}} = linalg.batch_reduce_matmul ins(%[[ARG0]], %[[ARG1]] : tensor<3x3x3xf32>, tensor<3x3x3xf32>)
// CHECK-SAME: outs(%[[FILL]] : tensor<3x3xf32>) -> tensor<3x3xf32>

// -----

// 2-D transpose on tensors: the input map swaps d0/d1 while the output map
// is the identity, so the generic degeneralizes to linalg.transpose with
// permutation [1, 0].
#map = affine_map<(d0, d1) -> (d1, d0)>
#map1 = affine_map<(d0, d1) -> (d0, d1)>

func.func @transpose_degeneralize(%arg0 : tensor<128x262144xf32>, %arg1: tensor<262144x128xf32>)
  -> tensor<262144x128xf32> {
  %0 = linalg.generic {
    indexing_maps = [#map, #map1],
    iterator_types = ["parallel", "parallel"]}
    ins(%arg0 : tensor<128x262144xf32>) outs(%arg1 : tensor<262144x128xf32>) {
      ^bb0(%in: f32, %out: f32):
        linalg.yield %in : f32
  } -> tensor<262144x128xf32>
  return %0 : tensor<262144x128xf32>
}

// CHECK-LABEL: transpose_degeneralize
// CHECK-SAME: %[[ARG0:.+]]: tensor<128x262144xf32>, %[[ARG1:.+]]: tensor<262144x128xf32>
// CHECK: %[[T:.+]] = linalg.transpose ins(%[[ARG0]] : tensor<128x262144xf32>) outs(%[[ARG1]] : tensor<262144x128xf32>)
// CHECK-SAME: permutation = [1, 0]
// CHECK: return %[[T]] : tensor<262144x128xf32>

// -----

// 4-D transpose: the input map swaps d1/d2, so the expected named-op
// permutation is [0, 2, 1, 3].
#map = affine_map<(d0, d1, d2, d3) -> (d0, d2, d1, d3)>
#map1 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>

func.func @transpose_degeneralize_1(%arg0 : tensor<1x2x3x4xf32>, %arg1 : tensor<1x3x2x4xf32>)
  -> tensor<1x3x2x4xf32> {
  %0 = linalg.generic {
    indexing_maps = [#map, #map1],
    iterator_types = ["parallel", "parallel", "parallel", "parallel"]}
    ins(%arg0 : tensor<1x2x3x4xf32>) outs(%arg1 : tensor<1x3x2x4xf32>) {
      ^bb0(%in: f32, %out: f32):
        linalg.yield %in : f32
  } -> tensor<1x3x2x4xf32>
  return %0 : tensor<1x3x2x4xf32>
}

// CHECK-LABEL: transpose_degeneralize_1
// CHECK-SAME: %[[ARG0:.+]]: tensor<1x2x3x4xf32>, %[[ARG1:.+]]: tensor<1x3x2x4xf32>
// CHECK: %[[T:.+]] = linalg.transpose ins(%[[ARG0]] : tensor<1x2x3x4xf32>) outs(%[[ARG1]] : tensor<1x3x2x4xf32>)
// CHECK-SAME: permutation = [0, 2, 1, 3]
// CHECK: return %[[T]] : tensor<1x3x2x4xf32>

// -----

// Same 2-D transpose but on memrefs (no results): the pattern must also
// fire for the buffer form of linalg.generic.
#map = affine_map<(d0, d1) -> (d1, d0)>
#map1 = affine_map<(d0, d1) -> (d0, d1)>

func.func @transpose_degeneralize_memref(%arg0 : memref<128x262144xf32>, %arg1: memref<262144x128xf32>) {
  linalg.generic {
    indexing_maps = [#map, #map1],
    iterator_types = ["parallel", "parallel"]}
    ins(%arg0 : memref<128x262144xf32>) outs(%arg1 : memref<262144x128xf32>) {
      ^bb0(%in: f32, %out: f32):
        linalg.yield %in : f32
  }
  return
}

// CHECK-LABEL: transpose_degeneralize_memref
// CHECK-SAME: %[[ARG0:.+]]: memref<128x262144xf32>, %[[ARG1:.+]]: memref<262144x128xf32>
// CHECK: linalg.transpose ins(%[[ARG0]] : memref<128x262144xf32>) outs(%[[ARG1]] : memref<262144x128xf32>)
// CHECK-SAME: permutation = [1, 0]

// -----

// Negative test: both maps are the identity (a plain copy), so the
// permutation would be the identity and the generic must NOT be rewritten
// into linalg.transpose.
#map = affine_map<(d0, d1) -> (d0, d1)>

func.func @transpose_degeneralize_copy(%arg0 : tensor<128x262144xf32>, %arg1: tensor<128x262144xf32>)
  -> tensor<128x262144xf32> {
  %0 = linalg.generic {
    indexing_maps = [#map, #map],
    iterator_types = ["parallel", "parallel"]}
    ins(%arg0 : tensor<128x262144xf32>) outs(%arg1 : tensor<128x262144xf32>) {
      ^bb0(%in: f32, %out: f32):
        linalg.yield %in : f32
  } -> tensor<128x262144xf32>
  return %0 : tensor<128x262144xf32>
}

// CHECK-LABEL: transpose_degeneralize_copy
// CHECK-NOT: linalg.transpose
// CHECK: linalg.generic

0 comments on commit a8a1cab

Please sign in to comment.