Skip to content

Commit

Permalink
[mlir] Add apply_patterns.linalg.generalize_pack_unpack TD Op (llvm#1…
Browse files Browse the repository at this point in the history
…16373)

This PR introduces populateGeneralizePatterns, which collects the
following patterns:

  * `GeneralizeOuterUnitDimsPackOpPattern`,
  * `GeneralizeOuterUnitDimsUnPackOpPattern` (currently a TODO).

These patterns are wrapped in a new Transform Dialect Op:
`apply_patterns.linalg.generalize_pack_unpack`. This Op facilitates
creating more involved end-to-end compilation pipelines for
`tensor.pack` and `tensor.unpack` operations. It will be required in an
upcoming PR building on top of llvm#115698.

No new tests are added in this PR. Instead, existing tests from:

  * "generalize-tensor-pack.mlir"

are reused. To achieve this:

  * I've updated the test to use
    `transform.apply_patterns.linalg.generalize_pack_unpack` instead of
    the flag
    `--test-linalg-transform-patterns="test-generalize-tensor-pack"`,
    avoiding artificial tests solely for the TD Op.
  * The TD sequence is saved to a new file, "generalize_pack.mlir", and
    pre-loaded using the option:

`--transform-preload-library='transform-library-paths=%p/td/generalize_pack.mlir'`
    This avoids duplicating the sequence for every "split" in the input
    file.
  * Added "lit.local.cfg" to exclude the "test/Dialect/Linalg/td"
    directory from test discovery, ensuring "generalize_pack.mlir" is
    not treated as a test file.
  • Loading branch information
banach-space authored Nov 18, 2024
1 parent 1dcb3db commit 63b926a
Show file tree
Hide file tree
Showing 7 changed files with 44 additions and 4 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,18 @@ def ApplyEraseUnnecessaryInputsPatternsOp : Op<Transform_Dialect,
let assemblyFormat = "attr-dict";
}

def ApplyGeneralizeTensorPackUnpackPatternsOp
: Op<Transform_Dialect, "apply_patterns.linalg.generalize_pack_unpack",
[DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> {
let description = [{
Collect patterns to generalize tensor.pack and tensor.unpack (i.e. to
decompose it into e.g. tensor::PadOp, linalg::transposeOp etc). Requires
all outer dims to be unit.
}];

let assemblyFormat = "attr-dict";
}

def ApplyFoldUnitExtentDimsViaReshapesPatternsOp : Op<Transform_Dialect,
"apply_patterns.linalg.fold_unit_extent_dims_via_reshapes",
[DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> {
Expand Down
9 changes: 7 additions & 2 deletions mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
Original file line number Diff line number Diff line change
Expand Up @@ -1516,8 +1516,8 @@ struct GeneralizePadOpPattern : public OpRewritePattern<tensor::PadOp> {
};

/// Rewrites a tensor::PackOp into a sequence of:
/// * tensor::PadOp + linalg::TransposeOp +
/// tensor::EmptyOp + tensor::InsertSliceOp ops.
/// * tensor::PadOp + linalg::TransposeOp + tensor::EmptyOp +
/// tensor::InsertSliceOp ops.
///
/// Required that all the outer dims of the input tensor::PackOp are 1.
///
Expand Down Expand Up @@ -1683,6 +1683,11 @@ void populateLinalgGenericOpsSpecializationPatterns(
void populateDecomposeConvolutionPatterns(RewritePatternSet &patterns,
PatternBenefit benefit = 1);

/// Populates patterns to decompose tensor.pack and tensor.unpack Ops into e.g.
/// tensor.pad, linalg.transpose, tensor.{insert|extract}_slice. Require all
/// outer dims to be unit.
void populateGeneralizePatterns(RewritePatternSet &patterns);

/// Populates patterns to transform linalg.conv_2d_xxx operations into
/// linalg.generic (for img2col packing) and linalg.matmul.
/// \see rewriteInIm2Col for more details.
Expand Down
5 changes: 5 additions & 0 deletions mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -229,6 +229,11 @@ void transform::ApplyEraseUnnecessaryInputsPatternsOp::populatePatterns(
linalg::populateEraseUnnecessaryInputsPatterns(patterns);
}

void transform::ApplyGeneralizeTensorPackUnpackPatternsOp::populatePatterns(
RewritePatternSet &patterns) {
linalg::populateGeneralizePatterns(patterns);
}

void transform::ApplyFoldUnitExtentDimsViaReshapesPatternsOp::populatePatterns(
RewritePatternSet &patterns) {
linalg::ControlDropUnitDims options;
Expand Down
5 changes: 5 additions & 0 deletions mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1618,3 +1618,8 @@ void linalg::populateDecomposeConvolutionPatterns(RewritePatternSet &patterns,
DownscaleSizeOneWindowed2DConvolution<PoolingNchwMaxOp, PoolingNcwMaxOp>>(
patterns.getContext(), benefit);
}

void linalg::populateGeneralizePatterns(RewritePatternSet &patterns) {
// TODO: Add and test patterns for tensor.unpack
patterns.add<GeneralizeOuterUnitDimsPackOpPattern>(patterns.getContext());
}
3 changes: 1 addition & 2 deletions mlir/test/Dialect/Linalg/generalize-tensor-pack.mlir
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
// RUN: mlir-opt -split-input-file --test-linalg-transform-patterns="test-generalize-tensor-pack" %s | FileCheck %s

// RUN: mlir-opt --transform-preload-library='transform-library-paths=%p/td/generalize-pack.mlir' -split-input-file --transform-interpreter %s | FileCheck %s

func.func @simple_KCRS_to_KCRSsr(%arg0: tensor<?x?xi32>, %arg1: tensor<1x1x?x1xi32>) -> tensor<1x1x?x1xi32> {
%c8 = arith.constant 8 : index
Expand Down
2 changes: 2 additions & 0 deletions mlir/test/Dialect/Linalg/lit.local.cfg
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
# Skip the directory with input TD sequences
config.excludes = ["td"]
12 changes: 12 additions & 0 deletions mlir/test/Dialect/Linalg/td/generalize-pack.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
module @transforms attributes { transform.with_named_sequence } {
transform.named_sequence @__transform_main(%module: !transform.any_op {transform.readonly}) {
%pack = transform.structured.match ops{["tensor.pack"]} in %module : (!transform.any_op) -> !transform.any_op

%1 = transform.get_parent_op %pack {isolated_from_above} : (!transform.any_op) -> !transform.any_op
transform.apply_patterns to %1 {
transform.apply_patterns.linalg.generalize_pack_unpack
} : !transform.any_op

transform.yield
}
}

0 comments on commit 63b926a

Please sign in to comment.