From ad1083dce4f664265c5489ecd2e46649cd978683 Mon Sep 17 00:00:00 2001
From: Peiming Liu
Date: Mon, 13 May 2024 17:29:01 -0700
Subject: [PATCH] [mlir][sparse] introduce new pass to propagate sparse encodings. (#92052)

---
 .../Dialect/SparseTensor/Transforms/Passes.h  |  6 ++++
 .../Dialect/SparseTensor/Transforms/Passes.td | 36 +++++++++++++++++++
 .../Transforms/SparseTensorPasses.cpp         | 13 +++++++
 3 files changed, 55 insertions(+)

diff --git a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h
index d6d038ef65bdf4..bb49d6c256f21b 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h
@@ -65,6 +65,12 @@ void populateSparseAssembler(RewritePatternSet &patterns, bool directOut);
 std::unique_ptr<Pass> createSparseAssembler();
 std::unique_ptr<Pass> createSparseAssembler(bool directOut);
 
+//===----------------------------------------------------------------------===//
+// The SparseEncodingPropagation pass.
+//===----------------------------------------------------------------------===//
+
+std::unique_ptr<Pass> createSparseEncodingPropagationPass();
+
 //===----------------------------------------------------------------------===//
 // The SparseReinterpretMap pass.
 //===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td
index 2f844cee5ff528..94c3ca60030eeb 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td
@@ -40,6 +40,42 @@ def SparseAssembler : Pass<"sparse-assembler", "ModuleOp"> {
   ];
 }
 
+def SparseEncodingPropagation : Pass<"sparse-encoding-propagation", "func::FuncOp"> {
+  let summary = "Propagate sparse tensor encodings";
+  let description = [{
+    A pass that propagates sparse tensor encodings.
+
+    Background: To avoid introducing repetitive operations, sparse tensors
+    in MLIR try to reuse tensor operations whenever available. However, most
+    tensor operations are canonicalized/transformed without the knowledge
+    of sparsity. The pass tries to propagate missing sparse encodings.
+
+    For example:
+    ```mlir
+    %s = tensor.extract_slice %input[0, 0,] [2, 1] [1, 1]
+       : tensor<2x3xf32, #sparse> to tensor<2x1xf32, #sparse>
+
+    // After rank reducing (by tensor dialect transformation)
+    %t = tensor.extract_slice %input[0, 0,] [2, 1] [1, 1]
+       : tensor<2x3xf32, #sparse> to tensor<2xf32>
+    %s = tensor.expand_shape [[0, 1]] %t
+       : tensor<2xf32> to tensor<2x1xf32, #sparse>
+
+    // After sparsity propagation
+    %t = tensor.extract_slice %input[0, 0,] [2, 1] [1, 1]
+       : tensor<2x3xf32, #sparse> to tensor<2xf32, #sparse1>
+    %s = tensor.expand_shape [[0, 1]] %t
+       : tensor<2xf32, #sparse1> to tensor<2x1xf32, #sparse>
+    ```
+  }];
+
+  let constructor = "mlir::createSparseEncodingPropagationPass()";
+  let dependentDialects = [
+    "sparse_tensor::SparseTensorDialect",
+    "tensor::TensorDialect",
+  ];
+}
+
 def SparseReinterpretMap : Pass<"sparse-reinterpret-map", "ModuleOp"> {
   let summary = "Reinterprets sparse tensor type mappings";
   let description = [{
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp
index b42d58634a36c4..f57353b5892b5a 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp
@@ -23,6 +23,7 @@
 
 namespace mlir {
 #define GEN_PASS_DEF_SPARSEASSEMBLER
+#define GEN_PASS_DEF_SPARSEENCODINGPROPAGATION
 #define GEN_PASS_DEF_SPARSEREINTERPRETMAP
 #define GEN_PASS_DEF_PRESPARSIFICATIONREWRITE
 #define GEN_PASS_DEF_SPARSIFICATIONPASS
@@ -60,6 +61,14 @@ struct SparseAssembler : public impl::SparseAssemblerBase<SparseAssembler> {
   }
 };
 
+struct SparseEncodingPropagation
+    : public impl::SparseEncodingPropagationBase<SparseEncodingPropagation> {
+  SparseEncodingPropagation() = default;
+  SparseEncodingPropagation(const SparseEncodingPropagation &pass) = default;
+
+  void runOnOperation() override {}
+};
+
 struct SparseReinterpretMap
     : public impl::SparseReinterpretMapBase<SparseReinterpretMap> {
   SparseReinterpretMap() = default;
@@ -398,6 +407,10 @@ std::unique_ptr<Pass> mlir::createSparseAssembler() {
   return std::make_unique<SparseAssembler>();
 }
 
+std::unique_ptr<Pass> mlir::createSparseEncodingPropagationPass() {
+  return std::make_unique<SparseEncodingPropagation>();
+}
+
 std::unique_ptr<Pass> mlir::createSparseReinterpretMapPass() {
   return std::make_unique<SparseReinterpretMap>();
 }