From daabcf5f04bbd13ac53f76ca3cc43b0d1ef64f5a Mon Sep 17 00:00:00 2001 From: Uday Bondhugula Date: Sat, 16 Apr 2022 09:24:20 +0530 Subject: [PATCH] [MLIR] Provide a way to print ops in custom form on pass failure The generic form of the op is too verbose and in some cases not readable. On pass failure, ops have been so far printed in generic form to provide a (stronger) guarantee that the IR print succeeds. However, in a large number of pass failure cases, the IR is still valid and the custom printers for the ops will succeed. In fact, readability is highly desirable post pass failure. This revision provides an option to print ops in their custom/pretty-printed form on IR failure -- this option is unsafe and there is no guarantee it will succeed. It's disabled by default and can be turned on only if needed. Differential Revision: https://reviews.llvm.org/D123893 --- mlir/include/mlir/Pass/PassManager.h | 21 ++++++++++++++++++++- mlir/lib/Pass/IRPrinting.cpp | 19 ++++++++++++------- mlir/lib/Pass/PassManagerOptions.cpp | 9 ++++++++- mlir/test/Pass/ir-printing.mlir | 3 +++ 4 files changed, 43 insertions(+), 9 deletions(-) diff --git a/mlir/include/mlir/Pass/PassManager.h b/mlir/include/mlir/Pass/PassManager.h index 13b127c360f1..a6df1bb69aa1 100644 --- a/mlir/include/mlir/Pass/PassManager.h +++ b/mlir/include/mlir/Pass/PassManager.h @@ -253,11 +253,15 @@ public: /// pass, we only print in the case of a failure. /// - This option should *not* be used with the other `printAfter` flags /// above. + /// * 'printCustomFormOnFailure' signals that when printing the IR after a + /// pass failure, the custom form should be used (unsafe) instead of the + /// generic form. /// * 'opPrintingFlags' sets up the printing flags to use when printing the /// IR. explicit IRPrinterConfig( bool printModuleScope = false, bool printAfterOnlyOnChange = false, bool printAfterOnlyOnFailure = false, + bool printCustomFormOnFailure = false, OpPrintingFlags opPrintingFlags = OpPrintingFlags()); virtual ~IRPrinterConfig(); @@ -288,6 +292,13 @@ public: return printAfterOnlyOnFailure; } + /// Returns true if the IR should be printed in custom form even on failure. + /// This is unsafe and there is no guaranee that the custom form printer + /// will not crash or print valid IR. + bool shouldPrintCustomFormOnFailure() const { + return printCustomFormOnFailure; + } + /// Returns the printing flags to be used to print the IR. OpPrintingFlags getOpPrintingFlags() const { return opPrintingFlags; } @@ -303,6 +314,10 @@ public: /// the pass failed. bool printAfterOnlyOnFailure; + /// A flag that indicates that the IR should be printed (or attempted to be + /// printed) in custom form even after a pass failure. + bool printCustomFormOnFailure; + /// Flags to control printing behavior. OpPrintingFlags opPrintingFlags; }; @@ -325,6 +340,9 @@ public: /// pass, we only print in the case of a failure. /// - This option should *not* be used with the other `printAfter` flags /// above. + /// * 'printCustomFormOnFailure' signals that when printing the IR after a + /// pass failure, the custom form should be used (unsafe) instead of the + /// generic form. /// * 'out' corresponds to the stream to output the printed IR to. /// * 'opPrintingFlags' sets up the printing flags to use when printing the /// IR. @@ -334,7 +352,8 @@ public: std::function shouldPrintAfterPass = [](Pass *, Operation *) { return true; }, bool printModuleScope = true, bool printAfterOnlyOnChange = true, - bool printAfterOnlyOnFailure = false, raw_ostream &out = llvm::errs(), + bool printAfterOnlyOnFailure = false, + bool printCustomFormOnFailure = false, raw_ostream &out = llvm::errs(), OpPrintingFlags opPrintingFlags = OpPrintingFlags()); //===--------------------------------------------------------------------===// diff --git a/mlir/lib/Pass/IRPrinting.cpp b/mlir/lib/Pass/IRPrinting.cpp index 9e1e56ba4cb6..437a564732a1 100644 --- a/mlir/lib/Pass/IRPrinting.cpp +++ b/mlir/lib/Pass/IRPrinting.cpp @@ -171,7 +171,9 @@ void IRPrinterInstrumentation::runAfterPassFailed(Pass *pass, Operation *op) { config->printAfterIfEnabled(pass, op, [&](raw_ostream &out) { out << formatv("// -----// IR Dump After {0} Failed", pass->getName()); printIR(op, config->shouldPrintAtModuleScope(), out, - OpPrintingFlags().printGenericOpForm()); + config->shouldPrintCustomFormOnFailure() + ? OpPrintingFlags() + : OpPrintingFlags().printGenericOpForm()); out << "\n\n"; }); } @@ -184,10 +186,12 @@ void IRPrinterInstrumentation::runAfterPassFailed(Pass *pass, Operation *op) { PassManager::IRPrinterConfig::IRPrinterConfig(bool printModuleScope, bool printAfterOnlyOnChange, bool printAfterOnlyOnFailure, + bool printCustomFormOnFailure, OpPrintingFlags opPrintingFlags) : printModuleScope(printModuleScope), printAfterOnlyOnChange(printAfterOnlyOnChange), printAfterOnlyOnFailure(printAfterOnlyOnFailure), + printCustomFormOnFailure(printCustomFormOnFailure), opPrintingFlags(opPrintingFlags) {} PassManager::IRPrinterConfig::~IRPrinterConfig() = default; @@ -220,10 +224,11 @@ struct BasicIRPrinterConfig : public PassManager::IRPrinterConfig { std::function shouldPrintBeforePass, std::function shouldPrintAfterPass, bool printModuleScope, bool printAfterOnlyOnChange, - bool printAfterOnlyOnFailure, OpPrintingFlags opPrintingFlags, - raw_ostream &out) + bool printAfterOnlyOnFailure, bool printCustomFormOnFailure, + OpPrintingFlags opPrintingFlags, raw_ostream &out) : IRPrinterConfig(printModuleScope, printAfterOnlyOnChange, - printAfterOnlyOnFailure, opPrintingFlags), + printAfterOnlyOnFailure, printCustomFormOnFailure, + opPrintingFlags), shouldPrintBeforePass(std::move(shouldPrintBeforePass)), shouldPrintAfterPass(std::move(shouldPrintAfterPass)), out(out) { assert((this->shouldPrintBeforePass || this->shouldPrintAfterPass) && @@ -267,10 +272,10 @@ void PassManager::enableIRPrinting( std::function shouldPrintBeforePass, std::function shouldPrintAfterPass, bool printModuleScope, bool printAfterOnlyOnChange, - bool printAfterOnlyOnFailure, raw_ostream &out, - OpPrintingFlags opPrintingFlags) { + bool printAfterOnlyOnFailure, bool printCustomFormOnFailure, + raw_ostream &out, OpPrintingFlags opPrintingFlags) { enableIRPrinting(std::make_unique( std::move(shouldPrintBeforePass), std::move(shouldPrintAfterPass), printModuleScope, printAfterOnlyOnChange, printAfterOnlyOnFailure, - opPrintingFlags, out)); + printCustomFormOnFailure, opPrintingFlags, out)); } diff --git a/mlir/lib/Pass/PassManagerOptions.cpp b/mlir/lib/Pass/PassManagerOptions.cpp index 7b725b2904b1..566fc3079b69 100644 --- a/mlir/lib/Pass/PassManagerOptions.cpp +++ b/mlir/lib/Pass/PassManagerOptions.cpp @@ -53,6 +53,13 @@ struct PassManagerOptions { llvm::cl::desc( "When printing the IR after a pass, only print if the pass failed"), llvm::cl::init(false)}; + llvm::cl::opt printCustomAssemblyAfterFailure{ + "mlir-print-custom-assembly-after-failure", + llvm::cl::desc( + "When printing the IR after a pass failure, print in custom form " + "instead of generic (WARNING: this is unsafe and there is no " + "guarantee of a crash-free or valid print"), + llvm::cl::init(false)}; llvm::cl::opt printModuleScope{ "mlir-print-ir-module-scope", llvm::cl::desc("When printing IR for print-ir-[before|after]{-all} " @@ -122,7 +129,7 @@ void PassManagerOptions::addPrinterInstrumentation(PassManager &pm) { // Otherwise, add the IR printing instrumentation. pm.enableIRPrinting(shouldPrintBeforePass, shouldPrintAfterPass, printModuleScope, printAfterChange, printAfterFailure, - llvm::errs()); + printCustomAssemblyAfterFailure, llvm::errs()); } void mlir::registerPassManagerCLOptions() { diff --git a/mlir/test/Pass/ir-printing.mlir b/mlir/test/Pass/ir-printing.mlir index e60b0265f622..ac2b31f75d0d 100644 --- a/mlir/test/Pass/ir-printing.mlir +++ b/mlir/test/Pass/ir-printing.mlir @@ -5,6 +5,7 @@ // RUN: mlir-opt %s -mlir-disable-threading=true -pass-pipeline='func.func(cse,canonicalize)' -mlir-print-ir-before=cse -mlir-print-ir-module-scope -o /dev/null 2>&1 | FileCheck -check-prefix=BEFORE_MODULE %s // RUN: mlir-opt %s -mlir-disable-threading=true -pass-pipeline='func.func(cse,cse)' -mlir-print-ir-after-all -mlir-print-ir-after-change -o /dev/null 2>&1 | FileCheck -check-prefix=AFTER_ALL_CHANGE %s // RUN: not mlir-opt %s -mlir-disable-threading=true -pass-pipeline='func.func(cse,test-pass-failure)' -mlir-print-ir-after-failure -o /dev/null 2>&1 | FileCheck -check-prefix=AFTER_FAILURE %s +// RUN: not mlir-opt %s -mlir-disable-threading=true -pass-pipeline='func.func(cse,test-pass-failure)' -mlir-print-ir-after-failure -mlir-print-custom-assembly-after-failure -o /dev/null 2>&1 | FileCheck -check-prefix=AFTER_FAILURE_CUSTOM %s func @foo() { %0 = arith.constant 0 : i32 @@ -64,3 +65,5 @@ func @bar() { // AFTER_FAILURE-NOT: // -----// IR Dump After{{.*}}CSE // AFTER_FAILURE: // -----// IR Dump After{{.*}}TestFailurePass Failed //----- // +// AFTER_FAILURE_CUSTOM: // -----// IR Dump After{{.*}}TestFailurePass Failed //----- // +// AFTER_FAILURE_CUSTOM: func @foo()