From de998c01cd256d81a0050b0b56306d09b88e1716 Mon Sep 17 00:00:00 2001 From: Tina Jung <126699487+TinaAMD@users.noreply.github.com> Date: Fri, 8 Sep 2023 10:16:23 +0200 Subject: [PATCH] Implement folder for unused empty.memory_formats (#139) --- .../Dialect/Torch/IR/GeneratedTorchOps.td | 1 + lib/Dialect/Torch/IR/TorchOps.cpp | 15 +++++++++++++++ .../importer/jit_ir/build_tools/torch_ods_gen.py | 2 +- 3 files changed, 17 insertions(+), 1 deletion(-) diff --git a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td index 9327f3d363a1..412291292872 100644 --- a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td +++ b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td @@ -7348,6 +7348,7 @@ def Torch_AtenEmptyMemoryFormatOp : Torch_Op<"aten.empty.memory_format", [ printDefaultTorchOp(printer, *this, 6, 1); } }]; + let hasCanonicalizer = 1; } def Torch_AtenExpandOp : Torch_Op<"aten.expand", [ diff --git a/lib/Dialect/Torch/IR/TorchOps.cpp b/lib/Dialect/Torch/IR/TorchOps.cpp index ef02d86629ea..35f1a753b46b 100644 --- a/lib/Dialect/Torch/IR/TorchOps.cpp +++ b/lib/Dialect/Torch/IR/TorchOps.cpp @@ -1337,6 +1337,21 @@ static OpFoldResult intComparatorFoldHelper(OpTy op, OpFoldResult AtenDetachOp::fold(FoldAdaptor adaptor) { return getSelf(); } +//===----------------------------------------------------------------------===// +// AtenEmptyMemoryFormatOp +//===----------------------------------------------------------------------===// + +void AtenEmptyMemoryFormatOp::getCanonicalizationPatterns( + RewritePatternSet &patterns, MLIRContext *context) { + patterns.add(+[](AtenEmptyMemoryFormatOp op, PatternRewriter &rewriter) { + if (!op->use_empty()) { + return failure(); + } + rewriter.eraseOp(op); + return success(); + }); +} + //===----------------------------------------------------------------------===// // AtenNeIntOp //===----------------------------------------------------------------------===// diff --git a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py index 5a9670bc9844..007df85d11eb 100644 --- a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py +++ b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py @@ -494,7 +494,7 @@ def emit_with_mutating_variants(key, **kwargs): emit("aten::new_empty_strided : (Tensor, int[], int[], int?, int?, Device?, bool?) -> (Tensor)") emit("aten::zeros_like : (Tensor, int?, int?, Device?, bool?, int?) -> (Tensor)") emit("aten::ones_like : (Tensor, int?, int?, Device?, bool?, int?) -> (Tensor)") - emit("aten::empty.memory_format : (int[], int?, int?, Device?, bool?, int?) -> (Tensor)") + emit("aten::empty.memory_format : (int[], int?, int?, Device?, bool?, int?) -> (Tensor)", has_canonicalizer=True) emit("aten::expand : (Tensor, int[], bool) -> (Tensor)") emit("aten::expand_as : (Tensor, Tensor) -> (Tensor)") emit("aten::broadcast_to : (Tensor, int[]) -> (Tensor)", has_canonicalizer=True)