[Preprocessing] Add a one-off pattern to fuse attention with transpose. (iree-org#17901)

The attention ops in SDXL models are usually followed by a
`tensor.expand_shape` and a `transpose`. It is more natural to fold
these into the attention op for code generation. This is done as a
one-off pattern for now. Ideally, attention ops can be fused with any
of their elementwise consumers once attention is handled natively by the
backend pass pipelines.
More details are in iree-org#17673.
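
For concreteness, this is roughly what the rewrite does on the static shapes
exercised by the new lit test (abridged sketch; value names are illustrative,
and indexing maps, iterator types, and the transpose body are elided; see
fold_attention_with_transpose.mlir below for the full IR):

```mlir
// Before: attention, then expand_shape + transpose of its result.
%attn = iree_linalg_ext.attention {...}
          ins(%q, %k, %v, %scale : tensor<20x4096x16xf16>, tensor<20x1024x16xf16>, tensor<20x1024x64xf16>, f16)
          outs(%e0 : tensor<20x4096x64xf16>) -> tensor<20x4096x64xf16>
%expanded = tensor.expand_shape %attn [[0, 1], [2], [3]] output_shape [2, 10, 4096, 64]
    : tensor<20x4096x64xf16> into tensor<2x10x4096x64xf16>
// Transpose with permutation (d0, d1, d2, d3) -> (d0, d2, d1, d3).
%transposed = linalg.generic {...} ins(%expanded : tensor<2x10x4096x64xf16>)
    outs(%e1 : tensor<2x4096x10x64xf16>) {...} -> tensor<2x4096x10x64xf16>

// After: the operands are expanded instead, and the attention op produces the
// transposed layout directly through its result indexing map.
%q2 = tensor.expand_shape %q [[0, 1], [2], [3]] output_shape [2, 10, 4096, 16]
    : tensor<20x4096x16xf16> into tensor<2x10x4096x16xf16>
%k2 = tensor.expand_shape %k [[0, 1], [2], [3]] output_shape [2, 10, 1024, 16]
    : tensor<20x1024x16xf16> into tensor<2x10x1024x16xf16>
%v2 = tensor.expand_shape %v [[0, 1], [2], [3]] output_shape [2, 10, 1024, 64]
    : tensor<20x1024x64xf16> into tensor<2x10x1024x64xf16>
%attn2 = iree_linalg_ext.attention {...}
           ins(%q2, %k2, %v2, %scale : tensor<2x10x4096x16xf16>, tensor<2x10x1024x16xf16>, tensor<2x10x1024x64xf16>, f16)
           outs(%e1 : tensor<2x4096x10x64xf16>) -> tensor<2x4096x10x64xf16>
```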

Signed-off-by: MaheshRavishankar <mahesh.ravishankar@gmail.com>
MaheshRavishankar authored Jul 17, 2024
1 parent 4de493a commit 7ce8c8e
Showing 7 changed files with 321 additions and 0 deletions.
2 changes: 2 additions & 0 deletions compiler/src/iree/compiler/Preprocessing/Common/BUILD.bazel
@@ -33,6 +33,7 @@ iree_compiler_cc_library(
"ApplyPDLPatterns.cpp",
"ConvertConv2DToImg2Col.cpp",
"ConvertConvToChannelsLast.cpp",
"FoldAttentionWithTranspose.cpp",
"GeneralizeLinalgMatMul.cpp",
"InterpreterPass.cpp",
"MakeSingleDispatchForFunction.cpp",
@@ -54,6 +55,7 @@ iree_compiler_cc_library(
"//compiler/src/iree/compiler/Dialect/Flow/IR",
"//compiler/src/iree/compiler/Dialect/Flow/Transforms",
"//compiler/src/iree/compiler/Dialect/HAL/IR",
"//compiler/src/iree/compiler/Dialect/LinalgExt/IR",
"//compiler/src/iree/compiler/Dialect/Stream/IR",
"//compiler/src/iree/compiler/Dialect/Util/IR",
"@llvm-project//llvm:Support",
compiler/src/iree/compiler/Preprocessing/Common/CMakeLists.txt
@@ -29,6 +29,7 @@ iree_cc_library(
"ApplyPDLPatterns.cpp"
"ConvertConv2DToImg2Col.cpp"
"ConvertConvToChannelsLast.cpp"
"FoldAttentionWithTranspose.cpp"
"GeneralizeLinalgMatMul.cpp"
"InterpreterPass.cpp"
"MakeSingleDispatchForFunction.cpp"
@@ -65,6 +66,7 @@ iree_cc_library(
iree::compiler::Dialect::Flow::IR
iree::compiler::Dialect::Flow::Transforms
iree::compiler::Dialect::HAL::IR
iree::compiler::Dialect::LinalgExt::IR
iree::compiler::Dialect::Stream::IR
iree::compiler::Dialect::Util::IR
PUBLIC
compiler/src/iree/compiler/Preprocessing/Common/FoldAttentionWithTranspose.cpp
@@ -0,0 +1,204 @@
// Copyright 2020 The IREE Authors
//
// Licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

#include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.h"
#include "iree/compiler/Preprocessing/Common/Passes.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

namespace mlir::iree_compiler::Preprocessing {

#define GEN_PASS_DEF_FOLDATTENTIONWITHTRANSPOSEPASS
#include "iree/compiler/Preprocessing/Common/Passes.h.inc" // IWYU pragma: export

namespace {

//===----------------------------------------------------------------------===//
// Attention -> Transpose fusion
//===----------------------------------------------------------------------===//

/// Pattern to fold
///
/// ```mlir
/// %0 = iree_linalg_ext.attention {
/// indexing_maps = [affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)>,
/// affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d2)>,
/// affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d4)>,
/// affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d4)>]}
/// ins(%query, %key, %value) ....
/// %1 = tensor.expand_shape %0 [[0, 1], [2], [3]] ....
/// %2 = linalg.generic {
/// indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>,
/// affine_map<(d0, d1, d2, d3) -> (d0, d2, d1, d3)>]}
/// ins(%1)
/// ```
///
/// to
///
/// ```mlir
/// %0 = iree_linalg_ext.attention {
/// indexing_maps = [affine_map<(d0, d00, d1, d2, d3, d4) -> (d0, d00, d1, d2)>,
///                  affine_map<(d0, d00, d1, d2, d3, d4) -> (d0, d00, d3, d2)>,
///                  affine_map<(d0, d00, d1, d2, d3, d4) -> (d0, d00, d3, d4)>,
///                  affine_map<(d0, d00, d1, d2, d3, d4) -> (d0, d1, d00, d4)>]}
/// ins(%expanded_query, %expanded_key, %expanded_value) ....
/// ```
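///
/// In the fused form, the original batch dimension `d0` is split into the
/// pair `(d0, d00)` to match the `tensor.expand_shape`, and the result
/// indexing map `(d0, d1, d00, d4)` already produces the transposed layout,
/// so the separate expand_shape and transpose ops are no longer needed.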
///
/// Do a very specific match for now. Eventually this can be generalized to
/// use an analysis similar to what reshape propagation across Linalg ops
/// does. TODO(#17673)
///
struct FoldAttentionAndTranspose
: public OpRewritePattern<IREE::LinalgExt::AttentionOp> {
using OpRewritePattern::OpRewritePattern;

LogicalResult matchAndRewrite(IREE::LinalgExt::AttentionOp attentionOp,
PatternRewriter &rewriter) const override {
// Check for single use attention op.
if (!attentionOp->hasOneUse()) {
return rewriter.notifyMatchFailure(attentionOp,
"attention op has multiple uses");
}
auto expandShapeOp =
dyn_cast<tensor::ExpandShapeOp>(*attentionOp->user_begin());
if (!expandShapeOp) {
return rewriter.notifyMatchFailure(attentionOp,
"user is not an expand shape op.");
}
// Check for single use of expand shape op.
if (!expandShapeOp->hasOneUse()) {
return rewriter.notifyMatchFailure(attentionOp,
"expand shape op has multiple uses");
}
auto transposeLikeOp =
dyn_cast<linalg::LinalgOp>(*expandShapeOp->user_begin());
if (!transposeLikeOp) {
return failure();
}
if (!(transposeLikeOp.getNumDpsInputs() == 1 &&
transposeLikeOp.getNumDpsInits() == 1 &&
transposeLikeOp.getBlock()
->front()
.hasTrait<OpTrait::IsTerminator>() &&
transposeLikeOp.getNumLoops() ==
transposeLikeOp.getNumParallelLoops())) {
return rewriter.notifyMatchFailure(
transposeLikeOp, "expand shape user is not a transpose");
}

// Check attention op indexing maps.
AffineExpr d0, d1, d2, d3, d4, d5;
bindDims(rewriter.getContext(), d0, d1, d2, d3, d4, d5);
auto getIndexingMap = [&](int n, ArrayRef<AffineExpr> results) {
return AffineMap::get(n, 0, results, rewriter.getContext());
};
SmallVector<AffineMap> expectedMaps = {
getIndexingMap(5, {d0, d1, d2}), getIndexingMap(5, {d0, d3, d2}),
getIndexingMap(5, {d0, d3, d4}), getIndexingMap(5, {d0, d1, d4})};
if (attentionOp.getIndexingMapsArray() != expectedMaps) {
return rewriter.notifyMatchFailure(
attentionOp, "mismatch between expected maps and maps on attention op");
}

// Check reassociation indexing map.
SmallVector<ReassociationIndices> reassociation =
expandShapeOp.getReassociationIndices();
SmallVector<ReassociationIndices> expectedReassociation = {{0, 1}, {2}, {3}};
if (reassociation != expectedReassociation) {
return rewriter.notifyMatchFailure(expandShapeOp,
"unhandled reassociation");
}

// Check the permutation maps for the transpose.
SmallVector<AffineMap> expectedTransposeMaps = {
getIndexingMap(4, {d0, d1, d2, d3}),
getIndexingMap(4, {d0, d2, d1, d3})};
if (transposeLikeOp.getIndexingMapsArray() != expectedTransposeMaps) {
return rewriter.notifyMatchFailure(transposeLikeOp,
"unhandled transpose op");
}

Location loc = attentionOp.getLoc();
OpBuilder::InsertionGuard g(rewriter);
rewriter.setInsertionPoint(transposeLikeOp);

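// Gather the sizes needed to expand the query/key/value operands: the split
// batch dimensions and the M/N sizes come from the expand_shape result, while
// the d2/d3 sizes (from the original indexing maps) are read off the key.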
SmallVector<OpFoldResult> expandedResultShape =
tensor::getMixedSizes(rewriter, loc, expandShapeOp);
OpFoldResult dim0_split0 = expandedResultShape[0];
OpFoldResult dim0_split1 = expandedResultShape[1];
OpFoldResult dim1 = expandedResultShape[2];
OpFoldResult dim2 =
tensor::getMixedSize(rewriter, loc, attentionOp.getKey(), 2);
OpFoldResult dim3 =
tensor::getMixedSize(rewriter, loc, attentionOp.getKey(), 1);
OpFoldResult dim4 = expandedResultShape[3];

SmallVector<OpFoldResult> newQuerySizes = {};
SmallVector<Value> tmp;
SmallVector<int64_t> newQueryShape;
dispatchIndexOpFoldResults(newQuerySizes, tmp, newQueryShape);

auto getReshape = [&](Value v, ArrayRef<ReassociationIndices> reassociation,
ArrayRef<OpFoldResult> outputShape) -> Value {
SmallVector<int64_t> staticShape;
SmallVector<Value> dynamicShape;
dispatchIndexOpFoldResults(outputShape, dynamicShape, staticShape);
Type resultType = RankedTensorType::get(
staticShape, cast<RankedTensorType>(v.getType()).getElementType());
return rewriter
.create<tensor::ExpandShapeOp>(loc, resultType, v, reassociation,
outputShape)
.getResult();
};

Value expandedQuery = getReshape(attentionOp.getQuery(), {{0, 1}, {2}, {3}},
{dim0_split0, dim0_split1, dim1, dim2});
Value expandedKey = getReshape(attentionOp.getKey(), {{0, 1}, {2}, {3}},
{dim0_split0, dim0_split1, dim3, dim2});
Value expandedValue = getReshape(attentionOp.getValue(), {{0, 1}, {2}, {3}},
{dim0_split0, dim0_split1, dim3, dim4});
Value expandedInit = transposeLikeOp.getDpsInitOperand(0)->get();

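// Indexing maps on the expanded (6-d) iteration space: d0/d1 are the two
// halves of the split batch dimension, and the result map (d0, d2, d1, d5)
// writes the output directly in the transposed layout, folding away the
// expand_shape and transpose ops.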
SmallVector<AffineMap> newIndexingMaps = {
getIndexingMap(6, {d0, d1, d2, d3}),
getIndexingMap(6, {d0, d1, d4, d3}),
getIndexingMap(6, {d0, d1, d4, d5}),
getIndexingMap(6, {d0, d2, d1, d5})};
ArrayAttr newIndexingMapsAttr =
rewriter.getAffineMapArrayAttr(newIndexingMaps);
auto newAttentionOp = rewriter.create<IREE::LinalgExt::AttentionOp>(
attentionOp.getLoc(), expandedInit.getType(), expandedQuery,
expandedKey, expandedValue, attentionOp.getScale(), expandedInit,
newIndexingMapsAttr);
rewriter.replaceOp(transposeLikeOp, newAttentionOp);
return success();
}
};

//===----------------------------------------------------------------------===//
// Pass Implementation
//===----------------------------------------------------------------------===//

struct FoldAttentionWithTransposePass
: public impl::FoldAttentionWithTransposePassBase<
FoldAttentionWithTransposePass> {
void runOnOperation() override {
MLIRContext *context = &getContext();
RewritePatternSet patterns(context);
patterns.insert<FoldAttentionAndTranspose>(context);
if (failed(applyPatternsAndFoldGreedily(getOperation(),
std::move(patterns)))) {
return signalPassFailure();
}
}
};

} // namespace

} // namespace mlir::iree_compiler::Preprocessing
5 changes: 5 additions & 0 deletions compiler/src/iree/compiler/Preprocessing/Common/Passes.td
@@ -51,6 +51,11 @@ def ConvertConvToChannelsLastPass :
];
}

def FoldAttentionWithTransposePass :
Pass<"iree-preprocessing-fold-attention-with-transpose", ""> {
let summary = "Fold attention operation with transpose";
}

def InterpreterPass : Pass<"iree-preprocessing-transform-interpreter"> {
let summary = "transform dialect interpreter";
let description = [{
compiler/src/iree/compiler/Preprocessing/Common/test/BUILD.bazel
@@ -18,6 +18,7 @@ iree_lit_test_suite(
[
"conv2d_to_img2col.mlir",
"conv_to_channels_last.mlir",
"fold_attention_with_transpose.mlir",
"generalize_linalg_matmul.mlir",
"make_single_dispatch_for_function.mlir",
"pad_linalg_ops.mlir",
compiler/src/iree/compiler/Preprocessing/Common/test/CMakeLists.txt
@@ -16,6 +16,7 @@ iree_lit_test_suite(
SRCS
"conv2d_to_img2col.mlir"
"conv_to_channels_last.mlir"
"fold_attention_with_transpose.mlir"
"generalize_linalg_matmul.mlir"
"make_single_dispatch_for_function.mlir"
"pad_linalg_ops.mlir"
compiler/src/iree/compiler/Preprocessing/Common/test/fold_attention_with_transpose.mlir
@@ -0,0 +1,106 @@
// RUN: iree-opt --pass-pipeline="builtin.module(util.func(iree-preprocessing-fold-attention-with-transpose, resolve-shaped-type-result-dims))" --split-input-file --mlir-print-local-scope %s | FileCheck %s

util.func public @fuse_attention_expand_transpose(
%arg0: tensor<?x?x?xf16>, %arg1 : tensor<?x?x?xf16>, %arg2 : tensor<?x?x?xf16>, %arg3 : f16) -> tensor<2x?x?x?xf16> {
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
%c2 = arith.constant 2 : index
%d0 = tensor.dim %arg0, %c0 : tensor<?x?x?xf16>
%d1 = tensor.dim %arg0, %c1 : tensor<?x?x?xf16>
%d2 = tensor.dim %arg0, %c2 : tensor<?x?x?xf16>
%d3 = tensor.dim %arg1, %c1 : tensor<?x?x?xf16>
%d4 = tensor.dim %arg2, %c2 : tensor<?x?x?xf16>
%empty = tensor.empty(%d0, %d1, %d4) : tensor<?x?x?xf16>
%attention = iree_linalg_ext.attention {
indexing_maps = [affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)>,
affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d2)>,
affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d4)>,
affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d4)>]}
ins(%arg0, %arg1, %arg2, %arg3 : tensor<?x?x?xf16>, tensor<?x?x?xf16>, tensor<?x?x?xf16>, f16)
outs(%empty : tensor<?x?x?xf16>) -> tensor<?x?x?xf16>
%split = arith.divsi %d0, %c2 : index
%expanded = tensor.expand_shape %attention [[0, 1], [2], [3]] output_shape[2, %split, %d1, %d4]
: tensor<?x?x?xf16> into tensor<2x?x?x?xf16>
%empty2 = tensor.empty(%d1, %split, %d4) : tensor<2x?x?x?xf16>
%transpose = linalg.generic {
indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d2, d1, d3)>],
iterator_types = ["parallel", "parallel", "parallel", "parallel"]}
ins(%expanded : tensor<2x?x?x?xf16>) outs(%empty2 : tensor<2x?x?x?xf16>) {
^bb0(%b0 : f16, %b1 : f16):
linalg.yield %b0 : f16
} -> tensor<2x?x?x?xf16>
util.return %transpose : tensor<2x?x?x?xf16>
}
// CHECK-LABEL: func public @fuse_attention_expand_transpose(
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor<?x?x?xf16>
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: tensor<?x?x?xf16>
// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: tensor<?x?x?xf16>
// CHECK-SAME: %[[ARG3:.+]]: f16)
// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index
// CHECK-DAG: %[[C2:.+]] = arith.constant 2 : index
// CHECK-DAG: %[[D0:.+]] = tensor.dim %[[ARG0]], %[[C0]]
// CHECK-DAG: %[[D1:.+]] = tensor.dim %[[ARG0]], %[[C1]]
// CHECK-DAG: %[[D4:.+]] = tensor.dim %[[ARG2]], %[[C2]]
// CHECK-DAG: %[[D_SPLIT:.+]] = arith.divsi %[[D0]], %[[C2]]
// CHECK-DAG: %[[EMPTY:.+]] = tensor.empty(%[[D1]], %[[D_SPLIT]], %[[D4]]) : tensor<2x?x?x?xf16>
// CHECK-DAG: %[[D_SPLIT2:.+]] = affine.apply affine_map<()[s0] -> (s0 floordiv 2)>()[%[[D0]]]
// CHECK-DAG: %[[D2:.+]] = tensor.dim %[[ARG1]], %[[C2]]
// CHECK-DAG: %[[D3:.+]] = tensor.dim %[[ARG1]], %[[C1]]
// CHECK-DAG: %[[QUERY:.+]] = tensor.expand_shape %[[ARG0]] {{\[}}[0, 1], [2], [3]{{\]}} output_shape [2, %[[D_SPLIT2]], %[[D1]], %[[D2]]{{\]}}
// CHECK-DAG: %[[KEY:.+]] = tensor.expand_shape %[[ARG1]] {{\[}}[0, 1], [2], [3]{{\]}} output_shape [2, %[[D_SPLIT2]], %[[D3]], %[[D2]]{{\]}}
// CHECK-DAG: %[[CACHE:.+]] = tensor.expand_shape %[[ARG2]] {{\[}}[0, 1], [2], [3]{{\]}} output_shape [2, %[[D_SPLIT2]], %[[D3]], %[[D4]]{{\]}}
// CHECK: %[[ATTENTION:.+]] = iree_linalg_ext.attention
// CHECK-SAME: indexing_maps =
// CHECK-SAME: [affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3)>,
// CHECK-SAME: affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d4, d3)>,
// CHECK-SAME: affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d4, d5)>,
// CHECK-SAME: affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d2, d1, d5)>]
// CHECK-SAME: ins(%[[QUERY]], %[[KEY]], %[[CACHE]], %[[ARG3]] :
// CHECK-SAME: outs(%[[EMPTY]] :
// CHECK: util.return %[[ATTENTION]]

// -----

util.func public @fuse_attention_expand_transpose_static(
%arg0 : tensor<20x4096x16xf16>, %arg1 : tensor<20x1024x16xf16>,
%arg2 : tensor<20x1024x64xf16>, %arg3 : f16) -> tensor<2x4096x10x64xf16> {
%empty = tensor.empty() : tensor<20x4096x64xf16>
%attention = iree_linalg_ext.attention {
indexing_maps = [affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)>,
affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d2)>,
affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d4)>,
affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d4)>]}
ins(%arg0, %arg1, %arg2, %arg3 : tensor<20x4096x16xf16>, tensor<20x1024x16xf16>, tensor<20x1024x64xf16>, f16)
outs(%empty: tensor<20x4096x64xf16>) -> tensor<20x4096x64xf16>
%expanded = tensor.expand_shape %attention [[0, 1], [2], [3]]
output_shape [2, 10, 4096, 64] : tensor<20x4096x64xf16> into tensor<2x10x4096x64xf16>
%empty2 = tensor.empty() : tensor<2x4096x10x64xf16>
%transpose = linalg.generic {
indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>,
affine_map<(d0, d1, d2, d3) -> (d0, d2, d1, d3)>],
iterator_types = ["parallel", "parallel", "parallel", "parallel"]}
ins(%expanded : tensor<2x10x4096x64xf16>) outs(%empty2 : tensor<2x4096x10x64xf16>) {
^bb0(%in: f16, %out: f16):
linalg.yield %in : f16
} -> tensor<2x4096x10x64xf16>
util.return %transpose : tensor<2x4096x10x64xf16>
}
// CHECK-LABEL: func public @fuse_attention_expand_transpose_static(
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor<20x4096x16xf16>
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: tensor<20x1024x16xf16>
// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: tensor<20x1024x64xf16>
// CHECK-SAME: %[[ARG3:.+]]: f16)
// CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<2x4096x10x64xf16>
// CHECK-DAG: %[[QUERY:.+]] = tensor.expand_shape %[[ARG0]] {{\[}}[0, 1], [2], [3]{{\]}} output_shape [2, 10, 4096, 16]
// CHECK-DAG: %[[KEY:.+]] = tensor.expand_shape %[[ARG1]] {{\[}}[0, 1], [2], [3]{{\]}} output_shape [2, 10, 1024, 16]
// CHECK-DAG: %[[CACHE:.+]] = tensor.expand_shape %[[ARG2]] {{\[}}[0, 1], [2], [3]{{\]}} output_shape [2, 10, 1024, 64]
// CHECK: %[[ATTENTION:.+]] = iree_linalg_ext.attention
// CHECK-SAME: indexing_maps =
// CHECK-SAME: [affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3)>,
// CHECK-SAME: affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d4, d3)>,
// CHECK-SAME: affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d4, d5)>,
// CHECK-SAME: affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d2, d1, d5)>]
// CHECK-SAME: ins(%[[QUERY]], %[[KEY]], %[[CACHE]], %[[ARG3]] :
// CHECK-SAME: outs(%[[EMPTY]] :
// CHECK: util.return %[[ATTENTION]]
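
The RUN line above exercises the new pattern standalone via
`iree-opt --pass-pipeline="builtin.module(util.func(iree-preprocessing-fold-attention-with-transpose, resolve-shaped-type-result-dims))"`;
the trailing `resolve-shaped-type-result-dims` pass resolves `tensor.dim`
queries on intermediate results back to the function arguments, which is where
the `affine.apply` and `tensor.dim` ops in the CHECK lines come from.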
