Skip to content

Commit

Permalink
Merge pull request #304 from Xilinx/bump_to_ad1083dc
Browse files Browse the repository at this point in the history
[AutoBump] Merge with fixes of ad1083d (May 14) (45)
  • Loading branch information
cferry-AMD authored Sep 3, 2024
2 parents 5520c5c + aae518b commit 4283845
Show file tree
Hide file tree
Showing 3 changed files with 55 additions and 0 deletions.
6 changes: 6 additions & 0 deletions mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,12 @@ void populateSparseAssembler(RewritePatternSet &patterns, bool directOut);
std::unique_ptr<Pass> createSparseAssembler();
std::unique_ptr<Pass> createSparseAssembler(bool directOut);

//===----------------------------------------------------------------------===//
// The SparseEncodingPropagation pass.
//===----------------------------------------------------------------------===//

std::unique_ptr<Pass> createSparseEncodingPropagationPass();

//===----------------------------------------------------------------------===//
// The SparseReinterpretMap pass.
//===----------------------------------------------------------------------===//
Expand Down
36 changes: 36 additions & 0 deletions mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,42 @@ def SparseAssembler : Pass<"sparse-assembler", "ModuleOp"> {
];
}

def SparseEncodingPropagation : Pass<"sparse-encoding-propagation", "func::FuncOp"> {
let summary = "Propagate sparse tensor encodings";
let description = [{
A pass that propagates sparse tensor encodings.

Background: To avoid introducing repetitive operations, sparse tensors
in MLIR try to reuse tensor operations whenever available. However, most
tensor operations are canonicalized/transformed without the knowledge
of sparsity. The pass tries to propagate missing sparse encodings.

For example:
```mlir
%s = tensor.extract_slice %input[0, 0,] [2, 1] [1, 1]
: tensor<2x3xf32, #sparse> to tensor<2x1xf32, #sparse>

// After rank reducing (by tensor dialect transformation)
%t = tensor.extract_slice %input[0, 0,] [2, 1] [1, 1]
: tensor<2x3xf32, #sparse> to tensor<2xf32>
%s = tensor.expand_shape [[0, 1]] %t
: tensor<2xf32> to tensor<2x1xf32, #sparse>

// After sparsity propagation
%t = tensor.extract_slice %input[0, 0,] [2, 1] [1, 1]
: tensor<2x3xf32, #sparse> to tensor<2xf32, #sparse1>
%s = tensor.expand_shape [[0, 1]] %t
: tensor<2xf32, #sparse1> to tensor<2x1xf32, #sparse>
```
}];

let constructor = "mlir::createSparseEncodingPropagationPass()";
let dependentDialects = [
"sparse_tensor::SparseTensorDialect",
"tensor::TensorDialect",
];
}

def SparseReinterpretMap : Pass<"sparse-reinterpret-map", "ModuleOp"> {
let summary = "Reinterprets sparse tensor type mappings";
let description = [{
Expand Down
13 changes: 13 additions & 0 deletions mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@

namespace mlir {
#define GEN_PASS_DEF_SPARSEASSEMBLER
#define GEN_PASS_DEF_SPARSEENCODINGPROPAGATION
#define GEN_PASS_DEF_SPARSEREINTERPRETMAP
#define GEN_PASS_DEF_PRESPARSIFICATIONREWRITE
#define GEN_PASS_DEF_SPARSIFICATIONPASS
Expand Down Expand Up @@ -60,6 +61,14 @@ struct SparseAssembler : public impl::SparseAssemblerBase<SparseAssembler> {
}
};

struct SparseEncodingPropagation
: public impl::SparseEncodingPropagationBase<SparseEncodingPropagation> {
SparseEncodingPropagation() = default;
SparseEncodingPropagation(const SparseEncodingPropagation &pass) = default;

void runOnOperation() override {}
};

struct SparseReinterpretMap
: public impl::SparseReinterpretMapBase<SparseReinterpretMap> {
SparseReinterpretMap() = default;
Expand Down Expand Up @@ -398,6 +407,10 @@ std::unique_ptr<Pass> mlir::createSparseAssembler() {
return std::make_unique<SparseAssembler>();
}

std::unique_ptr<Pass> mlir::createSparseEncodingPropagationPass() {
return std::make_unique<SparseEncodingPropagation>();
}

std::unique_ptr<Pass> mlir::createSparseReinterpretMapPass() {
return std::make_unique<SparseReinterpretMap>();
}
Expand Down

0 comments on commit 4283845

Please sign in to comment.