From 7ce8c8e9f8f8a0234f97bbd576072b1cdf756f92 Mon Sep 17 00:00:00 2001
From: MaheshRavishankar <1663364+MaheshRavishankar@users.noreply.github.com>
Date: Wed, 17 Jul 2024 11:27:14 -0700
Subject: [PATCH] [Preprocessing] Add a one-off pattern to fuse attention with
 transpose. (#17901)

The attention ops in SDXL models are usually followed by a
`tensor.expand_shape` and a `transpose`. It is more natural to fold these
into the attention op for code generation. This is done as a one-off
pattern for now. Ideally the attention op could be fused with any of its
elementwise consumers once attention is handled natively by the backend
pass pipelines. More details are in
https://github.com/iree-org/iree/issues/17673.
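
For reference, the pass can be exercised standalone through `iree-opt`,
mirroring the RUN line of the lit test added below; `input.mlir` is a
placeholder for a file containing an attention + expand_shape + transpose
sequence like the ones in that test:

  iree-opt \
    --pass-pipeline="builtin.module(util.func(iree-preprocessing-fold-attention-with-transpose))" \
    input.mlir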
Signed-off-by: MaheshRavishankar
---
 .../compiler/Preprocessing/Common/BUILD.bazel |   2 +
 .../Preprocessing/Common/CMakeLists.txt       |   2 +
 .../Common/FoldAttentionWithTranspose.cpp     | 204 ++++++++++++++++++
 .../compiler/Preprocessing/Common/Passes.td   |   5 +
 .../Preprocessing/Common/test/BUILD.bazel     |   1 +
 .../Preprocessing/Common/test/CMakeLists.txt  |   1 +
 .../test/fold_attention_with_transpose.mlir   | 106 +++++++++
 7 files changed, 321 insertions(+)
 create mode 100644 compiler/src/iree/compiler/Preprocessing/Common/FoldAttentionWithTranspose.cpp
 create mode 100644 compiler/src/iree/compiler/Preprocessing/Common/test/fold_attention_with_transpose.mlir

diff --git a/compiler/src/iree/compiler/Preprocessing/Common/BUILD.bazel b/compiler/src/iree/compiler/Preprocessing/Common/BUILD.bazel
index 1692c78bf800..a48188885592 100644
--- a/compiler/src/iree/compiler/Preprocessing/Common/BUILD.bazel
+++ b/compiler/src/iree/compiler/Preprocessing/Common/BUILD.bazel
@@ -33,6 +33,7 @@ iree_compiler_cc_library(
         "ApplyPDLPatterns.cpp",
         "ConvertConv2DToImg2Col.cpp",
         "ConvertConvToChannelsLast.cpp",
+        "FoldAttentionWithTranspose.cpp",
         "GeneralizeLinalgMatMul.cpp",
         "InterpreterPass.cpp",
         "MakeSingleDispatchForFunction.cpp",
@@ -54,6 +55,7 @@ iree_compiler_cc_library(
         "//compiler/src/iree/compiler/Dialect/Flow/IR",
         "//compiler/src/iree/compiler/Dialect/Flow/Transforms",
         "//compiler/src/iree/compiler/Dialect/HAL/IR",
+        "//compiler/src/iree/compiler/Dialect/LinalgExt/IR",
         "//compiler/src/iree/compiler/Dialect/Stream/IR",
         "//compiler/src/iree/compiler/Dialect/Util/IR",
         "@llvm-project//llvm:Support",
diff --git a/compiler/src/iree/compiler/Preprocessing/Common/CMakeLists.txt b/compiler/src/iree/compiler/Preprocessing/Common/CMakeLists.txt
index 4613d4bb404b..1bc4c1e85972 100644
--- a/compiler/src/iree/compiler/Preprocessing/Common/CMakeLists.txt
+++ b/compiler/src/iree/compiler/Preprocessing/Common/CMakeLists.txt
@@ -29,6 +29,7 @@ iree_cc_library(
     "ApplyPDLPatterns.cpp"
     "ConvertConv2DToImg2Col.cpp"
    "ConvertConvToChannelsLast.cpp"
+    "FoldAttentionWithTranspose.cpp"
     "GeneralizeLinalgMatMul.cpp"
     "InterpreterPass.cpp"
     "MakeSingleDispatchForFunction.cpp"
@@ -65,6 +66,7 @@ iree_cc_library(
     iree::compiler::Dialect::Flow::IR
     iree::compiler::Dialect::Flow::Transforms
     iree::compiler::Dialect::HAL::IR
+    iree::compiler::Dialect::LinalgExt::IR
     iree::compiler::Dialect::Stream::IR
     iree::compiler::Dialect::Util::IR
   PUBLIC
diff --git a/compiler/src/iree/compiler/Preprocessing/Common/FoldAttentionWithTranspose.cpp b/compiler/src/iree/compiler/Preprocessing/Common/FoldAttentionWithTranspose.cpp
new file mode 100644
index 000000000000..2c84abbba06f
--- /dev/null
+++ b/compiler/src/iree/compiler/Preprocessing/Common/FoldAttentionWithTranspose.cpp
@@ -0,0 +1,204 @@
// Copyright 2024 The IREE Authors
//
// Licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

#include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.h"
#include "iree/compiler/Preprocessing/Common/Passes.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

namespace mlir::iree_compiler::Preprocessing {

#define GEN_PASS_DEF_FOLDATTENTIONWITHTRANSPOSEPASS
#include "iree/compiler/Preprocessing/Common/Passes.h.inc" // IWYU pragma: export

namespace {

//===----------------------------------------------------------------------===//
// Attention -> Transpose fusion
//===----------------------------------------------------------------------===//

/// Pattern to fold
///
/// ```mlir
/// %0 = iree_linalg_ext.attention {
///     indexing_maps = [affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)>,
///                      affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d2)>,
///                      affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d4)>,
///                      affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d4)>]}
///     ins(%query, %key, %value) ....
/// %1 = tensor.expand_shape %0 into [[0, 1], [2], [3]] ....
/// %2 = linalg.generic {
///     indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>,
///                      affine_map<(d0, d1, d2, d3) -> (d0, d2, d1, d3)>]}
///     ins(%1)
/// ```
///
/// to
///
/// ```mlir
/// %0 = iree_linalg_ext.attention {
///     indexing_maps = [
///         affine_map<(d0, d00, d1, d2, d3, d4) -> (d0, d00, d1, d2)>,
///         affine_map<(d0, d00, d1, d2, d3, d4) -> (d0, d00, d3, d2)>,
///         affine_map<(d0, d00, d1, d2, d3, d4) -> (d0, d00, d3, d4)>,
///         affine_map<(d0, d00, d1, d2, d3, d4) -> (d0, d1, d00, d4)>]}
///     ins(%expanded_query, %expanded_key, %expanded_value) ....
/// ```
///
/// Do a very specific match for now. Eventually this can be generalized to
/// use an analysis similar to what reshape propagation across Linalg ops
/// does. TODO(#17673)
struct FoldAttentionAndTranspose
    : public OpRewritePattern<IREE::LinalgExt::AttentionOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(IREE::LinalgExt::AttentionOp attentionOp,
                                PatternRewriter &rewriter) const override {
    // Check for single use attention op.
    if (!attentionOp->hasOneUse()) {
      return rewriter.notifyMatchFailure(attentionOp,
                                         "attention op has multiple uses");
    }
    auto expandShapeOp =
        dyn_cast<tensor::ExpandShapeOp>(*attentionOp->user_begin());
    if (!expandShapeOp) {
      return rewriter.notifyMatchFailure(attentionOp,
                                         "user is not an expand shape op");
    }
    // Check for single use of expand shape op.
    if (!expandShapeOp->hasOneUse()) {
      return rewriter.notifyMatchFailure(attentionOp,
                                         "expand shape op has multiple uses");
    }
    auto transposeLikeOp =
        dyn_cast<linalg::GenericOp>(*expandShapeOp->user_begin());
    if (!transposeLikeOp) {
      return failure();
    }
    if (!(transposeLikeOp.getNumDpsInputs() == 1 &&
          transposeLikeOp.getNumDpsInits() == 1 &&
          transposeLikeOp.getBlock()
              ->front()
              .hasTrait<OpTrait::IsTerminator>() &&
          transposeLikeOp.getNumLoops() ==
              transposeLikeOp.getNumParallelLoops())) {
      return rewriter.notifyMatchFailure(
          transposeLikeOp, "expand shape user is not a transpose");
    }

    // Check attention op indexing maps.
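    // The expected maps encode plain batched attention over the iteration
    // space (d0, d1, d2, d3, d4) = (batch, M, K1, K2, N): the query is
    // indexed as (batch, M, K1), the key as (batch, K2, K1), the value as
    // (batch, K2, N), and the result as (batch, M, N).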
    AffineExpr d0, d1, d2, d3, d4, d5;
    bindDims(rewriter.getContext(), d0, d1, d2, d3, d4, d5);
    auto getIndexingMap = [&](int n, ArrayRef<AffineExpr> results) {
      return AffineMap::get(n, 0, results, rewriter.getContext());
    };
    SmallVector<AffineMap> expectedMaps = {
        getIndexingMap(5, {d0, d1, d2}), getIndexingMap(5, {d0, d3, d2}),
        getIndexingMap(5, {d0, d3, d4}), getIndexingMap(5, {d0, d1, d4})};
    if (attentionOp.getIndexingMapsArray() != expectedMaps) {
      return rewriter.notifyMatchFailure(
          attentionOp,
          "mismatch between expected maps and maps on attention op");
    }

    // Check reassociation indexing map.
    SmallVector<ReassociationIndices> reassociation =
        expandShapeOp.getReassociationIndices();
    SmallVector<ReassociationIndices> expectedReassociation = {
        {0, 1}, {2}, {3}};
    if (reassociation != expectedReassociation) {
      return rewriter.notifyMatchFailure(expandShapeOp,
                                         "unhandled reassociation");
    }

    // Check the permutation maps for the transpose.
    SmallVector<AffineMap> expectedTransposeMaps = {
        getIndexingMap(4, {d0, d1, d2, d3}),
        getIndexingMap(4, {d0, d2, d1, d3})};
    if (transposeLikeOp.getIndexingMapsArray() != expectedTransposeMaps) {
      return rewriter.notifyMatchFailure(transposeLikeOp,
                                         "unhandled transpose op");
    }

    Location loc = attentionOp.getLoc();
    OpBuilder::InsertionGuard g(rewriter);
    rewriter.setInsertionPoint(transposeLikeOp);

    SmallVector<OpFoldResult> expandedResultShape =
        tensor::getMixedSizes(rewriter, loc, expandShapeOp);
    OpFoldResult dim0_split0 = expandedResultShape[0];
    OpFoldResult dim0_split1 = expandedResultShape[1];
    OpFoldResult dim1 = expandedResultShape[2];
    OpFoldResult dim2 =
        tensor::getMixedSize(rewriter, loc, attentionOp.getKey(), 2);
    OpFoldResult dim3 =
        tensor::getMixedSize(rewriter, loc, attentionOp.getKey(), 1);
    OpFoldResult dim4 = expandedResultShape[3];

    auto getReshape = [&](Value v,
                          ArrayRef<ReassociationIndices> reassociation,
                          ArrayRef<OpFoldResult> outputShape) -> Value {
      SmallVector<int64_t> staticShape;
      SmallVector<Value> dynamicShape;
      dispatchIndexOpFoldResults(outputShape, dynamicShape, staticShape);
      Type resultType = RankedTensorType::get(
          staticShape, cast<RankedTensorType>(v.getType()).getElementType());
      return rewriter
          .create<tensor::ExpandShapeOp>(loc, resultType, v, reassociation,
                                         outputShape)
          .getResult();
    };

    Value expandedQuery =
        getReshape(attentionOp.getQuery(), {{0, 1}, {2}, {3}},
                   {dim0_split0, dim0_split1, dim1, dim2});
    Value expandedKey =
        getReshape(attentionOp.getKey(), {{0, 1}, {2}, {3}},
                   {dim0_split0, dim0_split1, dim3, dim2});
    Value expandedValue =
        getReshape(attentionOp.getValue(), {{0, 1}, {2}, {3}},
                   {dim0_split0, dim0_split1, dim3, dim4});
    Value expandedInit = transposeLikeOp.getDpsInitOperand(0)->get();

    SmallVector<AffineMap> newIndexingMaps = {
        getIndexingMap(6, {d0, d1, d2, d3}),
        getIndexingMap(6, {d0, d1, d4, d3}),
        getIndexingMap(6, {d0, d1, d4, d5}),
        getIndexingMap(6, {d0, d2, d1, d5})};
    ArrayAttr newIndexingMapsAttr =
        rewriter.getAffineMapArrayAttr(newIndexingMaps);
    auto newAttentionOp = rewriter.create<IREE::LinalgExt::AttentionOp>(
        attentionOp.getLoc(), expandedInit.getType(), expandedQuery,
        expandedKey, expandedValue, attentionOp.getScale(), expandedInit,
        newIndexingMapsAttr);
    rewriter.replaceOp(transposeLikeOp, newAttentionOp);
    return success();
  }
};
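
// Worked example, using the static shapes from the lit test that accompanies
// this pattern: the input IR
//
//   %a = iree_linalg_ext.attention ...            -> tensor<20x4096x64xf16>
//   %e = tensor.expand_shape %a [[0, 1], [2], [3]] -> tensor<2x10x4096x64xf16>
//   %t = linalg.generic ...                        -> tensor<2x4096x10x64xf16>
//        (output permutation (d0, d1, d2, d3) -> (d0, d2, d1, d3))
//
// becomes three expand_shapes on the operands (e.g. tensor<20x4096x16xf16>
// to tensor<2x10x4096x16xf16> for the query) feeding a single rank-4
// attention op whose result map (d0, d2, d1, d5) writes the transposed
// layout directly, so the expand_shape + transpose pair disappears.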

//===----------------------------------------------------------------------===//
// Pass Implementation
//===----------------------------------------------------------------------===//

struct FoldAttentionWithTransposePass
    : public impl::FoldAttentionWithTransposePassBase<
          FoldAttentionWithTransposePass> {
  void runOnOperation() override {
    MLIRContext *context = &getContext();
    RewritePatternSet patterns(context);
    patterns.insert<FoldAttentionAndTranspose>(context);
    if (failed(applyPatternsAndFoldGreedily(getOperation(),
                                            std::move(patterns)))) {
      return signalPassFailure();
    }
  }
};

} // namespace

} // namespace mlir::iree_compiler::Preprocessing
diff --git a/compiler/src/iree/compiler/Preprocessing/Common/Passes.td b/compiler/src/iree/compiler/Preprocessing/Common/Passes.td
index e4921b81fe88..edc17057b55b 100644
--- a/compiler/src/iree/compiler/Preprocessing/Common/Passes.td
+++ b/compiler/src/iree/compiler/Preprocessing/Common/Passes.td
@@ -51,6 +51,11 @@ def ConvertConvToChannelsLastPass :
   ];
 }
 
+def FoldAttentionWithTransposePass :
+    Pass<"iree-preprocessing-fold-attention-with-transpose", ""> {
+  let summary = "Fold attention operation with transpose";
+}
+
 def InterpreterPass : Pass<"iree-preprocessing-transform-interpreter"> {
   let summary = "transform dialect interpreter";
   let description = [{
diff --git a/compiler/src/iree/compiler/Preprocessing/Common/test/BUILD.bazel b/compiler/src/iree/compiler/Preprocessing/Common/test/BUILD.bazel
index 54ebb1176caa..e484b956de10 100644
--- a/compiler/src/iree/compiler/Preprocessing/Common/test/BUILD.bazel
+++ b/compiler/src/iree/compiler/Preprocessing/Common/test/BUILD.bazel
@@ -18,6 +18,7 @@ iree_lit_test_suite(
         [
             "conv2d_to_img2col.mlir",
             "conv_to_channels_last.mlir",
+            "fold_attention_with_transpose.mlir",
             "generalize_linalg_matmul.mlir",
             "make_single_dispatch_for_function.mlir",
             "pad_linalg_ops.mlir",
diff --git a/compiler/src/iree/compiler/Preprocessing/Common/test/CMakeLists.txt b/compiler/src/iree/compiler/Preprocessing/Common/test/CMakeLists.txt
index 03c92b7423bc..2c105a2b0636 100644
--- a/compiler/src/iree/compiler/Preprocessing/Common/test/CMakeLists.txt
+++ b/compiler/src/iree/compiler/Preprocessing/Common/test/CMakeLists.txt
@@ -16,6 +16,7 @@ iree_lit_test_suite(
   SRCS
     "conv2d_to_img2col.mlir"
     "conv_to_channels_last.mlir"
+    "fold_attention_with_transpose.mlir"
     "generalize_linalg_matmul.mlir"
     "make_single_dispatch_for_function.mlir"
     "pad_linalg_ops.mlir"
diff --git a/compiler/src/iree/compiler/Preprocessing/Common/test/fold_attention_with_transpose.mlir b/compiler/src/iree/compiler/Preprocessing/Common/test/fold_attention_with_transpose.mlir
new file mode 100644
index 000000000000..cbb6ebd7835c
--- /dev/null
+++ b/compiler/src/iree/compiler/Preprocessing/Common/test/fold_attention_with_transpose.mlir
@@ -0,0 +1,106 @@
// RUN: iree-opt --pass-pipeline="builtin.module(util.func(iree-preprocessing-fold-attention-with-transpose, resolve-shaped-type-result-dims))" --split-input-file --mlir-print-local-scope %s | FileCheck %s
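// Note: `resolve-shaped-type-result-dims` runs after the fold so that
// tensor.dim ops taken off the results of the newly created ops are rewritten
// in terms of the function arguments; this is why the dynamic case below
// expects an `affine.apply (s0 floordiv 2)` of the original batch size rather
// than a dim of the fused op's result.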

util.func public @fuse_attention_expand_transpose(
    %arg0: tensor<?x?x?xf16>, %arg1 : tensor<?x?x?xf16>,
    %arg2 : tensor<?x?x?xf16>, %arg3 : f16) -> tensor<2x?x?x?xf16> {
  %c0 = arith.constant 0 : index
  %c1 = arith.constant 1 : index
  %c2 = arith.constant 2 : index
  %d0 = tensor.dim %arg0, %c0 : tensor<?x?x?xf16>
  %d1 = tensor.dim %arg0, %c1 : tensor<?x?x?xf16>
  %d2 = tensor.dim %arg0, %c2 : tensor<?x?x?xf16>
  %d3 = tensor.dim %arg1, %c1 : tensor<?x?x?xf16>
  %d4 = tensor.dim %arg2, %c2 : tensor<?x?x?xf16>
  %empty = tensor.empty(%d0, %d1, %d4) : tensor<?x?x?xf16>
  %attention = iree_linalg_ext.attention {
      indexing_maps = [affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)>,
                       affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d2)>,
                       affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d4)>,
                       affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d4)>]}
      ins(%arg0, %arg1, %arg2, %arg3 : tensor<?x?x?xf16>, tensor<?x?x?xf16>, tensor<?x?x?xf16>, f16)
      outs(%empty : tensor<?x?x?xf16>) -> tensor<?x?x?xf16>
  %split = arith.divsi %d0, %c2 : index
  %expanded = tensor.expand_shape %attention [[0, 1], [2], [3]] output_shape [2, %split, %d1, %d4]
      : tensor<?x?x?xf16> into tensor<2x?x?x?xf16>
  %empty2 = tensor.empty(%d1, %split, %d4) : tensor<2x?x?x?xf16>
  %transpose = linalg.generic {
      indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>,
                       affine_map<(d0, d1, d2, d3) -> (d0, d2, d1, d3)>],
      iterator_types = ["parallel", "parallel", "parallel", "parallel"]}
      ins(%expanded : tensor<2x?x?x?xf16>) outs(%empty2 : tensor<2x?x?x?xf16>) {
    ^bb0(%b0 : f16, %b1 : f16):
      linalg.yield %b0 : f16
  } -> tensor<2x?x?x?xf16>
  util.return %transpose : tensor<2x?x?x?xf16>
}
// CHECK-LABEL: func public @fuse_attention_expand_transpose(
// CHECK-SAME:     %[[ARG0:[a-zA-Z0-9]+]]: tensor<?x?x?xf16>
// CHECK-SAME:     %[[ARG1:[a-zA-Z0-9]+]]: tensor<?x?x?xf16>
// CHECK-SAME:     %[[ARG2:[a-zA-Z0-9]+]]: tensor<?x?x?xf16>
// CHECK-SAME:     %[[ARG3:.+]]: f16)
// CHECK-DAG:   %[[C0:.+]] = arith.constant 0 : index
// CHECK-DAG:   %[[C1:.+]] = arith.constant 1 : index
// CHECK-DAG:   %[[C2:.+]] = arith.constant 2 : index
// CHECK-DAG:   %[[D0:.+]] = tensor.dim %[[ARG0]], %[[C0]]
// CHECK-DAG:   %[[D1:.+]] = tensor.dim %[[ARG0]], %[[C1]]
// CHECK-DAG:   %[[D4:.+]] = tensor.dim %[[ARG2]], %[[C2]]
// CHECK-DAG:   %[[D_SPLIT:.+]] = arith.divsi %[[D0]], %[[C2]]
// CHECK-DAG:   %[[EMPTY:.+]] = tensor.empty(%[[D1]], %[[D_SPLIT]], %[[D4]]) : tensor<2x?x?x?xf16>
// CHECK-DAG:   %[[D_SPLIT2:.+]] = affine.apply affine_map<()[s0] -> (s0 floordiv 2)>()[%[[D0]]]
// CHECK-DAG:   %[[D2:.+]] = tensor.dim %[[ARG1]], %[[C2]]
// CHECK-DAG:   %[[D3:.+]] = tensor.dim %[[ARG1]], %[[C1]]
// CHECK-DAG:   %[[QUERY:.+]] = tensor.expand_shape %[[ARG0]] {{\[}}[0, 1], [2], [3]{{\]}} output_shape [2, %[[D_SPLIT2]], %[[D1]], %[[D2]]{{\]}}
// CHECK-DAG:   %[[KEY:.+]] = tensor.expand_shape %[[ARG1]] {{\[}}[0, 1], [2], [3]{{\]}} output_shape [2, %[[D_SPLIT2]], %[[D3]], %[[D2]]{{\]}}
// CHECK-DAG:   %[[CACHE:.+]] = tensor.expand_shape %[[ARG2]] {{\[}}[0, 1], [2], [3]{{\]}} output_shape [2, %[[D_SPLIT2]], %[[D3]], %[[D4]]{{\]}}
// CHECK:       %[[ATTENTION:.+]] = iree_linalg_ext.attention
// CHECK-SAME:      indexing_maps =
// CHECK-SAME:          [affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3)>,
// CHECK-SAME:           affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d4, d3)>,
// CHECK-SAME:           affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d4, d5)>,
// CHECK-SAME:           affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d2, d1, d5)>]
// CHECK-SAME:      ins(%[[QUERY]], %[[KEY]], %[[CACHE]], %[[ARG3]] :
// CHECK-SAME:      outs(%[[EMPTY]] :
// CHECK:       util.return %[[ATTENTION]]

// -----
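
// Same rewrite with fully static shapes: the batched head dimension 20 splits
// as 2x10, with M = 4096, K2 = 1024, K1 = 16, and N = 64.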
util.func public @fuse_attention_expand_transpose_static(
    %arg0 : tensor<20x4096x16xf16>, %arg1 : tensor<20x1024x16xf16>,
    %arg2 : tensor<20x1024x64xf16>, %arg3 : f16) -> tensor<2x4096x10x64xf16> {
  %empty = tensor.empty() : tensor<20x4096x64xf16>
  %attention = iree_linalg_ext.attention {
      indexing_maps = [affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)>,
                       affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d2)>,
                       affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d4)>,
                       affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d4)>]}
      ins(%arg0, %arg1, %arg2, %arg3 : tensor<20x4096x16xf16>, tensor<20x1024x16xf16>, tensor<20x1024x64xf16>, f16)
      outs(%empty: tensor<20x4096x64xf16>) -> tensor<20x4096x64xf16>
  %expanded = tensor.expand_shape %attention [[0, 1], [2], [3]]
      output_shape [2, 10, 4096, 64] : tensor<20x4096x64xf16> into tensor<2x10x4096x64xf16>
  %empty2 = tensor.empty() : tensor<2x4096x10x64xf16>
  %transpose = linalg.generic {
      indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>,
                       affine_map<(d0, d1, d2, d3) -> (d0, d2, d1, d3)>],
      iterator_types = ["parallel", "parallel", "parallel", "parallel"]}
      ins(%expanded : tensor<2x10x4096x64xf16>) outs(%empty2 : tensor<2x4096x10x64xf16>) {
    ^bb0(%in: f16, %out: f16):
      linalg.yield %in : f16
  } -> tensor<2x4096x10x64xf16>
  util.return %transpose : tensor<2x4096x10x64xf16>
}
// CHECK-LABEL: func public @fuse_attention_expand_transpose_static(
// CHECK-SAME:     %[[ARG0:[a-zA-Z0-9]+]]: tensor<20x4096x16xf16>
// CHECK-SAME:     %[[ARG1:[a-zA-Z0-9]+]]: tensor<20x1024x16xf16>
// CHECK-SAME:     %[[ARG2:[a-zA-Z0-9]+]]: tensor<20x1024x64xf16>
// CHECK-SAME:     %[[ARG3:.+]]: f16)
// CHECK:       %[[EMPTY:.+]] = tensor.empty() : tensor<2x4096x10x64xf16>
// CHECK-DAG:   %[[QUERY:.+]] = tensor.expand_shape %[[ARG0]] {{\[}}[0, 1], [2], [3]{{\]}} output_shape [2, 10, 4096, 16]
// CHECK-DAG:   %[[KEY:.+]] = tensor.expand_shape %[[ARG1]] {{\[}}[0, 1], [2], [3]{{\]}} output_shape [2, 10, 1024, 16]
// CHECK-DAG:   %[[CACHE:.+]] = tensor.expand_shape %[[ARG2]] {{\[}}[0, 1], [2], [3]{{\]}} output_shape [2, 10, 1024, 64]
// CHECK:       %[[ATTENTION:.+]] = iree_linalg_ext.attention
// CHECK-SAME:      indexing_maps =
// CHECK-SAME:          [affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3)>,
// CHECK-SAME:           affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d4, d3)>,
// CHECK-SAME:           affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d4, d5)>,
// CHECK-SAME:           affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d2, d1, d5)>]
// CHECK-SAME:      ins(%[[QUERY]], %[[KEY]], %[[CACHE]], %[[ARG3]] :
// CHECK-SAME:      outs(%[[EMPTY]] :
// CHECK:       util.return %[[ATTENTION]]