From f8eceb45d0bbca092164efffc92f2e9d66b304a5 Mon Sep 17 00:00:00 2001 From: Bimo Date: Wed, 18 Sep 2024 11:54:16 +0800 Subject: [PATCH] [MLIR] [Python] align python ir printing with mlir-print-ir-after-all (#107522) When using the `enable_ir_printing` API from Python, it invokes IR printing with default args, printing the IR before each pass and printing IR after pass only if there have been changes. This PR attempts to align the `enable_ir_printing` API with the documentation --- mlir/include/mlir-c/Pass.h | 8 +++-- mlir/lib/Bindings/Python/Pass.cpp | 13 ++++++-- mlir/lib/CAPI/IR/Pass.cpp | 17 +++++++++-- .../mlir/_mlir_libs/_mlir/passmanager.pyi | 9 +++++- mlir/test/python/pass_manager.py | 30 +++++++++++++++++-- 5 files changed, 66 insertions(+), 11 deletions(-) diff --git a/mlir/include/mlir-c/Pass.h b/mlir/include/mlir-c/Pass.h index 35db138305d1e2..2218ec0f47d199 100644 --- a/mlir/include/mlir-c/Pass.h +++ b/mlir/include/mlir-c/Pass.h @@ -74,9 +74,11 @@ mlirPassManagerGetAsOpPassManager(MlirPassManager passManager); MLIR_CAPI_EXPORTED MlirLogicalResult mlirPassManagerRunOnOp(MlirPassManager passManager, MlirOperation op); -/// Enable mlir-print-ir-after-all. -MLIR_CAPI_EXPORTED void -mlirPassManagerEnableIRPrinting(MlirPassManager passManager); +/// Enable IR printing. +MLIR_CAPI_EXPORTED void mlirPassManagerEnableIRPrinting( + MlirPassManager passManager, bool printBeforeAll, bool printAfterAll, + bool printModuleScope, bool printAfterOnlyOnChange, + bool printAfterOnlyOnFailure); /// Enable / disable verify-each. MLIR_CAPI_EXPORTED void diff --git a/mlir/lib/Bindings/Python/Pass.cpp b/mlir/lib/Bindings/Python/Pass.cpp index a68421b61641f6..1d0e5ce2115a0a 100644 --- a/mlir/lib/Bindings/Python/Pass.cpp +++ b/mlir/lib/Bindings/Python/Pass.cpp @@ -74,10 +74,17 @@ void mlir::python::populatePassManagerSubmodule(py::module &m) { "Releases (leaks) the backing pass manager (testing)") .def( "enable_ir_printing", - [](PyPassManager &passManager) { - mlirPassManagerEnableIRPrinting(passManager.get()); + [](PyPassManager &passManager, bool printBeforeAll, + bool printAfterAll, bool printModuleScope, bool printAfterChange, + bool printAfterFailure) { + mlirPassManagerEnableIRPrinting( + passManager.get(), printBeforeAll, printAfterAll, + printModuleScope, printAfterChange, printAfterFailure); }, - "Enable mlir-print-ir-after-all.") + "print_before_all"_a = false, "print_after_all"_a = true, + "print_module_scope"_a = false, "print_after_change"_a = false, + "print_after_failure"_a = false, + "Enable IR printing, default as mlir-print-ir-after-all.") .def( "enable_verifier", [](PyPassManager &passManager, bool enable) { diff --git a/mlir/lib/CAPI/IR/Pass.cpp b/mlir/lib/CAPI/IR/Pass.cpp index d242baae99c086..a6c9fbd08d45a6 100644 --- a/mlir/lib/CAPI/IR/Pass.cpp +++ b/mlir/lib/CAPI/IR/Pass.cpp @@ -44,8 +44,21 @@ MlirLogicalResult mlirPassManagerRunOnOp(MlirPassManager passManager, return wrap(unwrap(passManager)->run(unwrap(op))); } -void mlirPassManagerEnableIRPrinting(MlirPassManager passManager) { - return unwrap(passManager)->enableIRPrinting(); +void mlirPassManagerEnableIRPrinting(MlirPassManager passManager, + bool printBeforeAll, bool printAfterAll, + bool printModuleScope, + bool printAfterOnlyOnChange, + bool printAfterOnlyOnFailure) { + auto shouldPrintBeforePass = [printBeforeAll](Pass *, Operation *) { + return printBeforeAll; + }; + auto shouldPrintAfterPass = [printAfterAll](Pass *, Operation *) { + return printAfterAll; + }; + return unwrap(passManager) + ->enableIRPrinting(shouldPrintBeforePass, shouldPrintAfterPass, + printModuleScope, printAfterOnlyOnChange, + printAfterOnlyOnFailure); } void mlirPassManagerEnableVerifier(MlirPassManager passManager, bool enable) { diff --git a/mlir/python/mlir/_mlir_libs/_mlir/passmanager.pyi b/mlir/python/mlir/_mlir_libs/_mlir/passmanager.pyi index c072d5e0fb86f3..5d115e8222d730 100644 --- a/mlir/python/mlir/_mlir_libs/_mlir/passmanager.pyi +++ b/mlir/python/mlir/_mlir_libs/_mlir/passmanager.pyi @@ -16,7 +16,14 @@ class PassManager: def __init__(self, context: Optional[_ir.Context] = None) -> None: ... def _CAPICreate(self) -> object: ... def _testing_release(self) -> None: ... - def enable_ir_printing(self) -> None: ... + def enable_ir_printing( + self, + print_before_all: bool = False, + print_after_all: bool = True, + print_module_scope: bool = False, + print_after_change: bool = False, + print_after_failure: bool = False, + ) -> None: ... def enable_verifier(self, enable: bool) -> None: ... @staticmethod def parse(pipeline: str, context: Optional[_ir.Context] = None) -> PassManager: ... diff --git a/mlir/test/python/pass_manager.py b/mlir/test/python/pass_manager.py index 43af80b53166cc..74967032562351 100644 --- a/mlir/test/python/pass_manager.py +++ b/mlir/test/python/pass_manager.py @@ -300,14 +300,40 @@ def testPrintIrAfterAll(): pm = PassManager.parse("builtin.module(canonicalize)") ctx.enable_multithreading(False) pm.enable_ir_printing() - # CHECK: // -----// IR Dump Before Canonicalizer (canonicalize) ('builtin.module' operation) //----- // + # CHECK: // -----// IR Dump After Canonicalizer (canonicalize) //----- // + # CHECK: module { + # CHECK: func.func @main() { + # CHECK: return + # CHECK: } + # CHECK: } + pm.run(module) + + +# CHECK-LABEL: TEST: testPrintIrBeforeAndAfterAll +@run +def testPrintIrBeforeAndAfterAll(): + with Context() as ctx: + module = ModuleOp.parse( + """ + module { + func.func @main() { + %0 = arith.constant 10 + return + } + } + """ + ) + pm = PassManager.parse("builtin.module(canonicalize)") + ctx.enable_multithreading(False) + pm.enable_ir_printing(print_before_all=True, print_after_all=True) + # CHECK: // -----// IR Dump Before Canonicalizer (canonicalize) //----- // # CHECK: module { # CHECK: func.func @main() { # CHECK: %[[C10:.*]] = arith.constant 10 : i64 # CHECK: return # CHECK: } # CHECK: } - # CHECK: // -----// IR Dump After Canonicalizer (canonicalize) ('builtin.module' operation) //----- // + # CHECK: // -----// IR Dump After Canonicalizer (canonicalize) //----- // # CHECK: module { # CHECK: func.func @main() { # CHECK: return