forked from iree-org/iree
[Preprocessing] Add a one-off pattern to fuse attention with transpose. (iree-org#17901)

The attention ops in SDXL models are usually followed by a `tensor.expand_shape` and a `transpose`. It is more natural to fold these into the attention op for code generation. This is done as a one-off pattern for now; ideally, the attention op could be fused with any of its elementwise consumers once attention is handled natively by the backend pass pipelines. More details are in iree-org#17673.

Signed-off-by: MaheshRavishankar <mahesh.ravishankar@gmail.com>
1 parent 4de493a · commit 7ce8c8e
Showing 7 changed files with 321 additions and 0 deletions.
204 changes: 204 additions & 0 deletions
compiler/src/iree/compiler/Preprocessing/Common/FoldAttentionWithTranspose.cpp
// Copyright 2020 The IREE Authors
//
// Licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

#include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.h"
#include "iree/compiler/Preprocessing/Common/Passes.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

namespace mlir::iree_compiler::Preprocessing {

#define GEN_PASS_DEF_FOLDATTENTIONWITHTRANSPOSEPASS
#include "iree/compiler/Preprocessing/Common/Passes.h.inc" // IWYU pragma: export

namespace {

//===----------------------------------------------------------------------===//
// Attention -> Transpose fusion
//===----------------------------------------------------------------------===//

/// Pattern to fold
///
/// ```mlir
/// %0 = iree_linalg_ext.attention {
///     indexing_maps = [affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)>,
///                      affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d2)>,
///                      affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d4)>,
///                      affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d4)>]}
///     ins(%query, %key, %value) ....
/// %1 = tensor.expand_shape %0 into [[0, 1], [2], [3]] ....
/// %2 = linalg.generic {
///     indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>,
///                      affine_map<(d0, d1, d2, d3) -> (d0, d2, d1, d3)>]}
///     ins(%1)
/// ```
///
/// to
///
/// ```mlir
/// %0 = iree_linalg_ext.attention {
///     indexing_maps = [
///         affine_map<(d0, d00, d1, d2, d3, d4) -> (d0, d00, d1, d2)>,
///         affine_map<(d0, d00, d1, d2, d3, d4) -> (d0, d00, d3, d2)>,
///         affine_map<(d0, d00, d1, d2, d3, d4) -> (d0, d00, d3, d4)>,
///         affine_map<(d0, d00, d1, d2, d3, d4) -> (d0, d1, d00, d4)>]}
///     ins(%expanded_query, %expanded_key, %expanded_value) ....
/// ```
///
/// This does a very specific match for now. Eventually it can be generalized
/// to use an analysis similar to the one used by reshape propagation across
/// Linalg ops. TODO(#17673)
struct FoldAttentionAndTranspose
    : public OpRewritePattern<IREE::LinalgExt::AttentionOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(IREE::LinalgExt::AttentionOp attentionOp,
                                PatternRewriter &rewriter) const override {
    // Check for single use attention op.
    if (!attentionOp->hasOneUse()) {
      return rewriter.notifyMatchFailure(attentionOp,
                                         "attention op has multiple uses");
    }
    auto expandShapeOp =
        dyn_cast<tensor::ExpandShapeOp>(*attentionOp->user_begin());
    if (!expandShapeOp) {
      return rewriter.notifyMatchFailure(attentionOp,
                                         "user is not an expand shape op");
    }
    // Check for single use of expand shape op.
    if (!expandShapeOp->hasOneUse()) {
      return rewriter.notifyMatchFailure(attentionOp,
                                         "expand shape op has multiple uses");
    }
    auto transposeLikeOp =
        dyn_cast<linalg::LinalgOp>(*expandShapeOp->user_begin());
    if (!transposeLikeOp) {
      return failure();
    }
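    // A transpose-like op here is a linalg.generic that is a pure
    // permutation: one input, one init, all-parallel loops, and a body whose
    // first op is already the terminator (i.e. it just yields the input).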
    if (!(transposeLikeOp.getNumDpsInputs() == 1 &&
          transposeLikeOp.getNumDpsInits() == 1 &&
          transposeLikeOp.getBlock()
              ->front()
              .hasTrait<OpTrait::IsTerminator>() &&
          transposeLikeOp.getNumLoops() ==
              transposeLikeOp.getNumParallelLoops())) {
      return rewriter.notifyMatchFailure(
          transposeLikeOp, "expand shape user is not a transpose");
    }

    // Check attention op indexing maps.
    AffineExpr d0, d1, d2, d3, d4, d5;
    bindDims(rewriter.getContext(), d0, d1, d2, d3, d4, d5);
    auto getIndexingMap = [&](int n, ArrayRef<AffineExpr> results) {
      return AffineMap::get(n, 0, results, rewriter.getContext());
    };
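    // In the maps below, d0 is the batch dimension, d1 the query sequence
    // length (M), d2 the query/key head dimension (K1), d3 the key/value
    // sequence length (K2), and d4 the value/output head dimension (N).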
    SmallVector<AffineMap> expectedMaps = {
        getIndexingMap(5, {d0, d1, d2}), getIndexingMap(5, {d0, d3, d2}),
        getIndexingMap(5, {d0, d3, d4}), getIndexingMap(5, {d0, d1, d4})};
    if (attentionOp.getIndexingMapsArray() != expectedMaps) {
      return rewriter.notifyMatchFailure(
          attentionOp, "mismatch between expected maps and attention op maps");
    }

    // Check the reassociation indices of the expand_shape.
    SmallVector<ReassociationIndices> reassociation =
        expandShapeOp.getReassociationIndices();
    SmallVector<ReassociationIndices> expectedReassociation = {{0, 1}, {2}, {3}};
    if (reassociation != expectedReassociation) {
      return rewriter.notifyMatchFailure(expandShapeOp,
                                         "unhandled reassociation");
    }

    // Check the permutation maps for the transpose.
    SmallVector<AffineMap> expectedTransposeMaps = {
        getIndexingMap(4, {d0, d1, d2, d3}),
        getIndexingMap(4, {d0, d2, d1, d3})};
    if (transposeLikeOp.getIndexingMapsArray() != expectedTransposeMaps) {
      return rewriter.notifyMatchFailure(transposeLikeOp,
                                         "unhandled transpose op");
    }

    Location loc = attentionOp.getLoc();
    OpBuilder::InsertionGuard g(rewriter);
    rewriter.setInsertionPoint(transposeLikeOp);

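    // Recover the output sizes after the expand_shape: the attention batch
    // dimension is split into dim0_split0 x dim0_split1, while K1 (dim2) and
    // K2 (dim3) are read off the key operand.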
    SmallVector<OpFoldResult> expandedResultShape =
        tensor::getMixedSizes(rewriter, loc, expandShapeOp);
    OpFoldResult dim0_split0 = expandedResultShape[0];
    OpFoldResult dim0_split1 = expandedResultShape[1];
    OpFoldResult dim1 = expandedResultShape[2];
    OpFoldResult dim2 =
        tensor::getMixedSize(rewriter, loc, attentionOp.getKey(), 2);
    OpFoldResult dim3 =
        tensor::getMixedSize(rewriter, loc, attentionOp.getKey(), 1);
    OpFoldResult dim4 = expandedResultShape[3];

    SmallVector<OpFoldResult> newQuerySizes = {};
    SmallVector<Value> tmp;
    SmallVector<int64_t> newQueryShape;
    dispatchIndexOpFoldResults(newQuerySizes, tmp, newQueryShape);

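    // Helper to build a tensor.expand_shape with the given reassociation,
    // deriving the static part of the result type from the mixed output
    // shape.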
    auto getReshape = [&](Value v, ArrayRef<ReassociationIndices> reassociation,
                          ArrayRef<OpFoldResult> outputShape) -> Value {
      SmallVector<int64_t> staticShape;
      SmallVector<Value> dynamicShape;
      dispatchIndexOpFoldResults(outputShape, dynamicShape, staticShape);
      Type resultType = RankedTensorType::get(
          staticShape, cast<RankedTensorType>(v.getType()).getElementType());
      return rewriter
          .create<tensor::ExpandShapeOp>(loc, resultType, v, reassociation,
                                         outputShape)
          .getResult();
    };

    Value expandedQuery = getReshape(attentionOp.getQuery(), {{0, 1}, {2}, {3}},
                                     {dim0_split0, dim0_split1, dim1, dim2});
    Value expandedKey = getReshape(attentionOp.getKey(), {{0, 1}, {2}, {3}},
                                   {dim0_split0, dim0_split1, dim3, dim2});
    Value expandedValue = getReshape(attentionOp.getValue(), {{0, 1}, {2}, {3}},
                                     {dim0_split0, dim0_split1, dim3, dim4});
    Value expandedInit = transposeLikeOp.getDpsInitOperand(0)->get();

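    // New six-dimensional iteration space: d0 and d1 are the two halves of
    // the split batch dimension, d2 = M, d3 = K1, d4 = K2, d5 = N. The result
    // map (d0, d2, d1, d5) swaps the inner batch split with the query
    // sequence dimension, which is exactly the folded transpose.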
    SmallVector<AffineMap> newIndexingMaps = {
        getIndexingMap(6, {d0, d1, d2, d3}),
        getIndexingMap(6, {d0, d1, d4, d3}),
        getIndexingMap(6, {d0, d1, d4, d5}),
        getIndexingMap(6, {d0, d2, d1, d5})};
    ArrayAttr newIndexingMapsAttr =
        rewriter.getAffineMapArrayAttr(newIndexingMaps);
    auto newAttentionOp = rewriter.create<IREE::LinalgExt::AttentionOp>(
        attentionOp.getLoc(), expandedInit.getType(), expandedQuery,
        expandedKey, expandedValue, attentionOp.getScale(), expandedInit,
        newIndexingMapsAttr);
    rewriter.replaceOp(transposeLikeOp, newAttentionOp);
    return success();
  }
};

//===----------------------------------------------------------------------===//
// Pass Implementation
//===----------------------------------------------------------------------===//

struct FoldAttentionWithTransposePass
    : public impl::FoldAttentionWithTransposePassBase<
          FoldAttentionWithTransposePass> {
  void runOnOperation() override {
    MLIRContext *context = &getContext();
    RewritePatternSet patterns(context);
    patterns.insert<FoldAttentionAndTranspose>(context);
    if (failed(applyPatternsAndFoldGreedily(getOperation(),
                                            std::move(patterns)))) {
      return signalPassFailure();
    }
  }
};

} // namespace

} // namespace mlir::iree_compiler::Preprocessing
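For orientation, below is a minimal sketch of driving this pass programmatically. This is hypothetical driver code, not part of the commit; it assumes the tablegen-generated `createFoldAttentionWithTransposePass()` constructor declared via `Passes.h.inc`, and nests the pass on the IREE `Util` dialect's `FuncOp`, mirroring the `util.func(...)` nesting that the lit test below exercises through `iree-opt`.

```cpp
// Hypothetical standalone driver (not part of this commit): run the new
// preprocessing pass over every util.func in a module.
#include "iree/compiler/Dialect/Util/IR/UtilOps.h"
#include "iree/compiler/Preprocessing/Common/Passes.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/Pass/PassManager.h"

static mlir::LogicalResult foldAttentionTransposes(mlir::ModuleOp module) {
  mlir::PassManager pm(module.getContext());
  // Assumes the generated pass constructor from Passes.h.inc; the exact
  // helper name may differ across IREE versions.
  pm.addNestedPass<mlir::iree_compiler::IREE::Util::FuncOp>(
      mlir::iree_compiler::Preprocessing::
          createFoldAttentionWithTransposePass());
  return pm.run(module);
}
```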
106 changes: 106 additions & 0 deletions
compiler/src/iree/compiler/Preprocessing/Common/test/fold_attention_with_transpose.mlir
// RUN: iree-opt --pass-pipeline="builtin.module(util.func(iree-preprocessing-fold-attention-with-transpose, resolve-shaped-type-result-dims))" --split-input-file --mlir-print-local-scope %s | FileCheck %s

util.func public @fuse_attention_expand_transpose(
    %arg0: tensor<?x?x?xf16>, %arg1 : tensor<?x?x?xf16>, %arg2 : tensor<?x?x?xf16>, %arg3 : f16) -> tensor<2x?x?x?xf16> {
  %c0 = arith.constant 0 : index
  %c1 = arith.constant 1 : index
  %c2 = arith.constant 2 : index
  %d0 = tensor.dim %arg0, %c0 : tensor<?x?x?xf16>
  %d1 = tensor.dim %arg0, %c1 : tensor<?x?x?xf16>
  %d2 = tensor.dim %arg0, %c2 : tensor<?x?x?xf16>
  %d3 = tensor.dim %arg1, %c1 : tensor<?x?x?xf16>
  %d4 = tensor.dim %arg2, %c2 : tensor<?x?x?xf16>
  %empty = tensor.empty(%d0, %d1, %d4) : tensor<?x?x?xf16>
  %attention = iree_linalg_ext.attention {
      indexing_maps = [affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)>,
                       affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d2)>,
                       affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d4)>,
                       affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d4)>]}
      ins(%arg0, %arg1, %arg2, %arg3 : tensor<?x?x?xf16>, tensor<?x?x?xf16>, tensor<?x?x?xf16>, f16)
      outs(%empty : tensor<?x?x?xf16>) -> tensor<?x?x?xf16>
  %split = arith.divsi %d0, %c2 : index
  %expanded = tensor.expand_shape %attention [[0, 1], [2], [3]] output_shape [2, %split, %d1, %d4]
      : tensor<?x?x?xf16> into tensor<2x?x?x?xf16>
  %empty2 = tensor.empty(%d1, %split, %d4) : tensor<2x?x?x?xf16>
  %transpose = linalg.generic {
      indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d2, d1, d3)>],
      iterator_types = ["parallel", "parallel", "parallel", "parallel"]}
      ins(%expanded : tensor<2x?x?x?xf16>) outs(%empty2 : tensor<2x?x?x?xf16>) {
    ^bb0(%b0 : f16, %b1 : f16):
      linalg.yield %b0 : f16
  } -> tensor<2x?x?x?xf16>
  util.return %transpose : tensor<2x?x?x?xf16>
}
// CHECK-LABEL: func public @fuse_attention_expand_transpose(
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor<?x?x?xf16>
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: tensor<?x?x?xf16>
// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: tensor<?x?x?xf16>
// CHECK-SAME: %[[ARG3:.+]]: f16)
// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index
// CHECK-DAG: %[[C2:.+]] = arith.constant 2 : index
// CHECK-DAG: %[[D0:.+]] = tensor.dim %[[ARG0]], %[[C0]]
// CHECK-DAG: %[[D1:.+]] = tensor.dim %[[ARG0]], %[[C1]]
// CHECK-DAG: %[[D4:.+]] = tensor.dim %[[ARG2]], %[[C2]]
// CHECK-DAG: %[[D_SPLIT:.+]] = arith.divsi %[[D0]], %[[C2]]
// CHECK-DAG: %[[EMPTY:.+]] = tensor.empty(%[[D1]], %[[D_SPLIT]], %[[D4]]) : tensor<2x?x?x?xf16>
// CHECK-DAG: %[[D_SPLIT2:.+]] = affine.apply affine_map<()[s0] -> (s0 floordiv 2)>()[%[[D0]]]
// CHECK-DAG: %[[D2:.+]] = tensor.dim %[[ARG1]], %[[C2]]
// CHECK-DAG: %[[D3:.+]] = tensor.dim %[[ARG1]], %[[C1]]
// CHECK-DAG: %[[QUERY:.+]] = tensor.expand_shape %[[ARG0]] {{\[}}[0, 1], [2], [3]{{\]}} output_shape [2, %[[D_SPLIT2]], %[[D1]], %[[D2]]{{\]}}
// CHECK-DAG: %[[KEY:.+]] = tensor.expand_shape %[[ARG1]] {{\[}}[0, 1], [2], [3]{{\]}} output_shape [2, %[[D_SPLIT2]], %[[D3]], %[[D2]]{{\]}}
// CHECK-DAG: %[[CACHE:.+]] = tensor.expand_shape %[[ARG2]] {{\[}}[0, 1], [2], [3]{{\]}} output_shape [2, %[[D_SPLIT2]], %[[D3]], %[[D4]]{{\]}}
// CHECK: %[[ATTENTION:.+]] = iree_linalg_ext.attention
// CHECK-SAME: indexing_maps =
// CHECK-SAME: [affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3)>,
// CHECK-SAME: affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d4, d3)>,
// CHECK-SAME: affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d4, d5)>,
// CHECK-SAME: affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d2, d1, d5)>]
// CHECK-SAME: ins(%[[QUERY]], %[[KEY]], %[[CACHE]], %[[ARG3]] :
// CHECK-SAME: outs(%[[EMPTY]] :
// CHECK: util.return %[[ATTENTION]]

// -----

util.func public @fuse_attention_expand_transpose_static(
    %arg0 : tensor<20x4096x16xf16>, %arg1 : tensor<20x1024x16xf16>,
    %arg2 : tensor<20x1024x64xf16>, %arg3 : f16) -> tensor<2x4096x10x64xf16> {
  %empty = tensor.empty() : tensor<20x4096x64xf16>
  %attention = iree_linalg_ext.attention {
      indexing_maps = [affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)>,
                       affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d2)>,
                       affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d4)>,
                       affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d4)>]}
      ins(%arg0, %arg1, %arg2, %arg3 : tensor<20x4096x16xf16>, tensor<20x1024x16xf16>, tensor<20x1024x64xf16>, f16)
      outs(%empty : tensor<20x4096x64xf16>) -> tensor<20x4096x64xf16>
  %expanded = tensor.expand_shape %attention [[0, 1], [2], [3]]
      output_shape [2, 10, 4096, 64] : tensor<20x4096x64xf16> into tensor<2x10x4096x64xf16>
  %empty2 = tensor.empty() : tensor<2x4096x10x64xf16>
  %transpose = linalg.generic {
      indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>,
                       affine_map<(d0, d1, d2, d3) -> (d0, d2, d1, d3)>],
      iterator_types = ["parallel", "parallel", "parallel", "parallel"]}
      ins(%expanded : tensor<2x10x4096x64xf16>) outs(%empty2 : tensor<2x4096x10x64xf16>) {
    ^bb0(%in: f16, %out: f16):
      linalg.yield %in : f16
  } -> tensor<2x4096x10x64xf16>
  util.return %transpose : tensor<2x4096x10x64xf16>
}
// CHECK-LABEL: func public @fuse_attention_expand_transpose_static(
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor<20x4096x16xf16>
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: tensor<20x1024x16xf16>
// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: tensor<20x1024x64xf16>
// CHECK-SAME: %[[ARG3:.+]]: f16)
// CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<2x4096x10x64xf16>
// CHECK-DAG: %[[QUERY:.+]] = tensor.expand_shape %[[ARG0]] {{\[}}[0, 1], [2], [3]{{\]}} output_shape [2, 10, 4096, 16]
// CHECK-DAG: %[[KEY:.+]] = tensor.expand_shape %[[ARG1]] {{\[}}[0, 1], [2], [3]{{\]}} output_shape [2, 10, 1024, 16]
// CHECK-DAG: %[[CACHE:.+]] = tensor.expand_shape %[[ARG2]] {{\[}}[0, 1], [2], [3]{{\]}} output_shape [2, 10, 1024, 64]
// CHECK: %[[ATTENTION:.+]] = iree_linalg_ext.attention
// CHECK-SAME: indexing_maps =
// CHECK-SAME: [affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3)>,
// CHECK-SAME: affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d4, d3)>,
// CHECK-SAME: affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d4, d5)>,
// CHECK-SAME: affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d2, d1, d5)>]
// CHECK-SAME: ins(%[[QUERY]], %[[KEY]], %[[CACHE]], %[[ARG3]] :
// CHECK-SAME: outs(%[[EMPTY]] :
// CHECK: util.return %[[ATTENTION]]