diff --git a/mlir/include/mlir-c/Pass.h b/mlir/include/mlir-c/Pass.h index 35db138305d1e2..2218ec0f47d199 100644 --- a/mlir/include/mlir-c/Pass.h +++ b/mlir/include/mlir-c/Pass.h @@ -74,9 +74,11 @@ mlirPassManagerGetAsOpPassManager(MlirPassManager passManager); MLIR_CAPI_EXPORTED MlirLogicalResult mlirPassManagerRunOnOp(MlirPassManager passManager, MlirOperation op); -/// Enable mlir-print-ir-after-all. -MLIR_CAPI_EXPORTED void -mlirPassManagerEnableIRPrinting(MlirPassManager passManager); +/// Enable IR printing. +MLIR_CAPI_EXPORTED void mlirPassManagerEnableIRPrinting( + MlirPassManager passManager, bool printBeforeAll, bool printAfterAll, + bool printModuleScope, bool printAfterOnlyOnChange, + bool printAfterOnlyOnFailure); /// Enable / disable verify-each. MLIR_CAPI_EXPORTED void diff --git a/mlir/lib/Bindings/Python/Pass.cpp b/mlir/lib/Bindings/Python/Pass.cpp index a68421b61641f6..1d0e5ce2115a0a 100644 --- a/mlir/lib/Bindings/Python/Pass.cpp +++ b/mlir/lib/Bindings/Python/Pass.cpp @@ -74,10 +74,17 @@ void mlir::python::populatePassManagerSubmodule(py::module &m) { "Releases (leaks) the backing pass manager (testing)") .def( "enable_ir_printing", - [](PyPassManager &passManager) { - mlirPassManagerEnableIRPrinting(passManager.get()); + [](PyPassManager &passManager, bool printBeforeAll, + bool printAfterAll, bool printModuleScope, bool printAfterChange, + bool printAfterFailure) { + mlirPassManagerEnableIRPrinting( + passManager.get(), printBeforeAll, printAfterAll, + printModuleScope, printAfterChange, printAfterFailure); }, - "Enable mlir-print-ir-after-all.") + "print_before_all"_a = false, "print_after_all"_a = true, + "print_module_scope"_a = false, "print_after_change"_a = false, + "print_after_failure"_a = false, + "Enable IR printing, default as mlir-print-ir-after-all.") .def( "enable_verifier", [](PyPassManager &passManager, bool enable) { diff --git a/mlir/lib/CAPI/IR/Pass.cpp b/mlir/lib/CAPI/IR/Pass.cpp index d242baae99c086..a6c9fbd08d45a6 100644 --- a/mlir/lib/CAPI/IR/Pass.cpp +++ b/mlir/lib/CAPI/IR/Pass.cpp @@ -44,8 +44,21 @@ MlirLogicalResult mlirPassManagerRunOnOp(MlirPassManager passManager, return wrap(unwrap(passManager)->run(unwrap(op))); } -void mlirPassManagerEnableIRPrinting(MlirPassManager passManager) { - return unwrap(passManager)->enableIRPrinting(); +void mlirPassManagerEnableIRPrinting(MlirPassManager passManager, + bool printBeforeAll, bool printAfterAll, + bool printModuleScope, + bool printAfterOnlyOnChange, + bool printAfterOnlyOnFailure) { + auto shouldPrintBeforePass = [printBeforeAll](Pass *, Operation *) { + return printBeforeAll; + }; + auto shouldPrintAfterPass = [printAfterAll](Pass *, Operation *) { + return printAfterAll; + }; + return unwrap(passManager) + ->enableIRPrinting(shouldPrintBeforePass, shouldPrintAfterPass, + printModuleScope, printAfterOnlyOnChange, + printAfterOnlyOnFailure); } void mlirPassManagerEnableVerifier(MlirPassManager passManager, bool enable) { diff --git a/mlir/python/mlir/_mlir_libs/_mlir/passmanager.pyi b/mlir/python/mlir/_mlir_libs/_mlir/passmanager.pyi index c072d5e0fb86f3..5d115e8222d730 100644 --- a/mlir/python/mlir/_mlir_libs/_mlir/passmanager.pyi +++ b/mlir/python/mlir/_mlir_libs/_mlir/passmanager.pyi @@ -16,7 +16,14 @@ class PassManager: def __init__(self, context: Optional[_ir.Context] = None) -> None: ... def _CAPICreate(self) -> object: ... def _testing_release(self) -> None: ... - def enable_ir_printing(self) -> None: ... + def enable_ir_printing( + self, + print_before_all: bool = False, + print_after_all: bool = True, + print_module_scope: bool = False, + print_after_change: bool = False, + print_after_failure: bool = False, + ) -> None: ... def enable_verifier(self, enable: bool) -> None: ... @staticmethod def parse(pipeline: str, context: Optional[_ir.Context] = None) -> PassManager: ... diff --git a/mlir/test/python/pass_manager.py b/mlir/test/python/pass_manager.py index 43af80b53166cc..74967032562351 100644 --- a/mlir/test/python/pass_manager.py +++ b/mlir/test/python/pass_manager.py @@ -300,14 +300,40 @@ def testPrintIrAfterAll(): pm = PassManager.parse("builtin.module(canonicalize)") ctx.enable_multithreading(False) pm.enable_ir_printing() - # CHECK: // -----// IR Dump Before Canonicalizer (canonicalize) ('builtin.module' operation) //----- // + # CHECK: // -----// IR Dump After Canonicalizer (canonicalize) //----- // + # CHECK: module { + # CHECK: func.func @main() { + # CHECK: return + # CHECK: } + # CHECK: } + pm.run(module) + + +# CHECK-LABEL: TEST: testPrintIrBeforeAndAfterAll +@run +def testPrintIrBeforeAndAfterAll(): + with Context() as ctx: + module = ModuleOp.parse( + """ + module { + func.func @main() { + %0 = arith.constant 10 + return + } + } + """ + ) + pm = PassManager.parse("builtin.module(canonicalize)") + ctx.enable_multithreading(False) + pm.enable_ir_printing(print_before_all=True, print_after_all=True) + # CHECK: // -----// IR Dump Before Canonicalizer (canonicalize) //----- // # CHECK: module { # CHECK: func.func @main() { # CHECK: %[[C10:.*]] = arith.constant 10 : i64 # CHECK: return # CHECK: } # CHECK: } - # CHECK: // -----// IR Dump After Canonicalizer (canonicalize) ('builtin.module' operation) //----- // + # CHECK: // -----// IR Dump After Canonicalizer (canonicalize) //----- // # CHECK: module { # CHECK: func.func @main() { # CHECK: return