Skip to content

Commit

Permalink
[MLIR] [Python] align python ir printing with mlir-print-ir-after-all (
Browse files Browse the repository at this point in the history
…llvm#107522)

When using the `enable_ir_printing` API from Python, it invokes IR
printing with default args, printing the IR before each pass and
printing IR after pass only if there have been changes. This PR attempts
to align the `enable_ir_printing` API with the documentation
  • Loading branch information
xurui1995 authored Sep 18, 2024
1 parent 8280651 commit f8eceb4
Show file tree
Hide file tree
Showing 5 changed files with 66 additions and 11 deletions.
8 changes: 5 additions & 3 deletions mlir/include/mlir-c/Pass.h
Original file line number Diff line number Diff line change
Expand Up @@ -74,9 +74,11 @@ mlirPassManagerGetAsOpPassManager(MlirPassManager passManager);
MLIR_CAPI_EXPORTED MlirLogicalResult
mlirPassManagerRunOnOp(MlirPassManager passManager, MlirOperation op);

/// Enable mlir-print-ir-after-all.
MLIR_CAPI_EXPORTED void
mlirPassManagerEnableIRPrinting(MlirPassManager passManager);
/// Enable IR printing.
MLIR_CAPI_EXPORTED void mlirPassManagerEnableIRPrinting(
MlirPassManager passManager, bool printBeforeAll, bool printAfterAll,
bool printModuleScope, bool printAfterOnlyOnChange,
bool printAfterOnlyOnFailure);

/// Enable / disable verify-each.
MLIR_CAPI_EXPORTED void
Expand Down
13 changes: 10 additions & 3 deletions mlir/lib/Bindings/Python/Pass.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -74,10 +74,17 @@ void mlir::python::populatePassManagerSubmodule(py::module &m) {
"Releases (leaks) the backing pass manager (testing)")
.def(
"enable_ir_printing",
[](PyPassManager &passManager) {
mlirPassManagerEnableIRPrinting(passManager.get());
[](PyPassManager &passManager, bool printBeforeAll,
bool printAfterAll, bool printModuleScope, bool printAfterChange,
bool printAfterFailure) {
mlirPassManagerEnableIRPrinting(
passManager.get(), printBeforeAll, printAfterAll,
printModuleScope, printAfterChange, printAfterFailure);
},
"Enable mlir-print-ir-after-all.")
"print_before_all"_a = false, "print_after_all"_a = true,
"print_module_scope"_a = false, "print_after_change"_a = false,
"print_after_failure"_a = false,
"Enable IR printing, default as mlir-print-ir-after-all.")
.def(
"enable_verifier",
[](PyPassManager &passManager, bool enable) {
Expand Down
17 changes: 15 additions & 2 deletions mlir/lib/CAPI/IR/Pass.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -44,8 +44,21 @@ MlirLogicalResult mlirPassManagerRunOnOp(MlirPassManager passManager,
return wrap(unwrap(passManager)->run(unwrap(op)));
}

void mlirPassManagerEnableIRPrinting(MlirPassManager passManager) {
return unwrap(passManager)->enableIRPrinting();
void mlirPassManagerEnableIRPrinting(MlirPassManager passManager,
bool printBeforeAll, bool printAfterAll,
bool printModuleScope,
bool printAfterOnlyOnChange,
bool printAfterOnlyOnFailure) {
auto shouldPrintBeforePass = [printBeforeAll](Pass *, Operation *) {
return printBeforeAll;
};
auto shouldPrintAfterPass = [printAfterAll](Pass *, Operation *) {
return printAfterAll;
};
return unwrap(passManager)
->enableIRPrinting(shouldPrintBeforePass, shouldPrintAfterPass,
printModuleScope, printAfterOnlyOnChange,
printAfterOnlyOnFailure);
}

void mlirPassManagerEnableVerifier(MlirPassManager passManager, bool enable) {
Expand Down
9 changes: 8 additions & 1 deletion mlir/python/mlir/_mlir_libs/_mlir/passmanager.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,14 @@ class PassManager:
def __init__(self, context: Optional[_ir.Context] = None) -> None: ...
def _CAPICreate(self) -> object: ...
def _testing_release(self) -> None: ...
def enable_ir_printing(self) -> None: ...
def enable_ir_printing(
self,
print_before_all: bool = False,
print_after_all: bool = True,
print_module_scope: bool = False,
print_after_change: bool = False,
print_after_failure: bool = False,
) -> None: ...
def enable_verifier(self, enable: bool) -> None: ...
@staticmethod
def parse(pipeline: str, context: Optional[_ir.Context] = None) -> PassManager: ...
Expand Down
30 changes: 28 additions & 2 deletions mlir/test/python/pass_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -300,14 +300,40 @@ def testPrintIrAfterAll():
pm = PassManager.parse("builtin.module(canonicalize)")
ctx.enable_multithreading(False)
pm.enable_ir_printing()
# CHECK: // -----// IR Dump Before Canonicalizer (canonicalize) ('builtin.module' operation) //----- //
# CHECK: // -----// IR Dump After Canonicalizer (canonicalize) //----- //
# CHECK: module {
# CHECK: func.func @main() {
# CHECK: return
# CHECK: }
# CHECK: }
pm.run(module)


# CHECK-LABEL: TEST: testPrintIrBeforeAndAfterAll
@run
def testPrintIrBeforeAndAfterAll():
with Context() as ctx:
module = ModuleOp.parse(
"""
module {
func.func @main() {
%0 = arith.constant 10
return
}
}
"""
)
pm = PassManager.parse("builtin.module(canonicalize)")
ctx.enable_multithreading(False)
pm.enable_ir_printing(print_before_all=True, print_after_all=True)
# CHECK: // -----// IR Dump Before Canonicalizer (canonicalize) //----- //
# CHECK: module {
# CHECK: func.func @main() {
# CHECK: %[[C10:.*]] = arith.constant 10 : i64
# CHECK: return
# CHECK: }
# CHECK: }
# CHECK: // -----// IR Dump After Canonicalizer (canonicalize) ('builtin.module' operation) //----- //
# CHECK: // -----// IR Dump After Canonicalizer (canonicalize) //----- //
# CHECK: module {
# CHECK: func.func @main() {
# CHECK: return
Expand Down

0 comments on commit f8eceb4

Please sign in to comment.