diff --git a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td index 9327f3d363a1..412291292872 100644 --- a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td +++ b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td @@ -7348,6 +7348,7 @@ def Torch_AtenEmptyMemoryFormatOp : Torch_Op<"aten.empty.memory_format", [ printDefaultTorchOp(printer, *this, 6, 1); } }]; + let hasCanonicalizer = 1; } def Torch_AtenExpandOp : Torch_Op<"aten.expand", [ diff --git a/lib/Dialect/Torch/IR/TorchOps.cpp b/lib/Dialect/Torch/IR/TorchOps.cpp index ef02d86629ea..35f1a753b46b 100644 --- a/lib/Dialect/Torch/IR/TorchOps.cpp +++ b/lib/Dialect/Torch/IR/TorchOps.cpp @@ -1337,6 +1337,21 @@ static OpFoldResult intComparatorFoldHelper(OpTy op, OpFoldResult AtenDetachOp::fold(FoldAdaptor adaptor) { return getSelf(); } +//===----------------------------------------------------------------------===// +// AtenEmptyMemoryFormatOp +//===----------------------------------------------------------------------===// + +void AtenEmptyMemoryFormatOp::getCanonicalizationPatterns( + RewritePatternSet &patterns, MLIRContext *context) { + patterns.add(+[](AtenEmptyMemoryFormatOp op, PatternRewriter &rewriter) { + if (!op->use_empty()) { + return failure(); + } + rewriter.eraseOp(op); + return success(); + }); +} + //===----------------------------------------------------------------------===// // AtenNeIntOp //===----------------------------------------------------------------------===// diff --git a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py index 5a9670bc9844..007df85d11eb 100644 --- a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py +++ b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py @@ -494,7 +494,7 @@ def emit_with_mutating_variants(key, **kwargs): emit("aten::new_empty_strided : (Tensor, int[], int[], int?, int?, Device?, bool?) -> (Tensor)") emit("aten::zeros_like : (Tensor, int?, int?, Device?, bool?, int?) -> (Tensor)") emit("aten::ones_like : (Tensor, int?, int?, Device?, bool?, int?) -> (Tensor)") - emit("aten::empty.memory_format : (int[], int?, int?, Device?, bool?, int?) -> (Tensor)") + emit("aten::empty.memory_format : (int[], int?, int?, Device?, bool?, int?) -> (Tensor)", has_canonicalizer=True) emit("aten::expand : (Tensor, int[], bool) -> (Tensor)") emit("aten::expand_as : (Tensor, Tensor) -> (Tensor)") emit("aten::broadcast_to : (Tensor, int[]) -> (Tensor)", has_canonicalizer=True)