diff --git a/third_party/mlir/g3doc/WritingAPass.md b/third_party/mlir/g3doc/WritingAPass.md index f72d41bea40..784757139d3 100644 --- a/third_party/mlir/g3doc/WritingAPass.md +++ b/third_party/mlir/g3doc/WritingAPass.md @@ -624,7 +624,7 @@ pipeline. This display mode is available in mlir-opt via `-pass-timing-display=list`. ```shell -$ mlir-opt foo.mlir -disable-pass-threading -cse -canonicalize -convert-std-to-llvm -pass-timing -pass-timing-display=list +$ mlir-opt foo.mlir -disable-pass-threading -pass-pipeline='func(cse,canonicalize)' -convert-std-to-llvm -pass-timing -pass-timing-display=list ===-------------------------------------------------------------------------=== ... Pass execution timing report ... @@ -649,7 +649,7 @@ the most time, and can also be used to identify when analyses are being invalidated and recomputed. This is the default display mode. ```shell -$ mlir-opt foo.mlir -disable-pass-threading -cse -canonicalize -convert-std-to-llvm -pass-timing +$ mlir-opt foo.mlir -disable-pass-threading -pass-pipeline='func(cse,canonicalize)' -convert-std-to-llvm -pass-timing ===-------------------------------------------------------------------------=== ... Pass execution timing report ... @@ -680,7 +680,7 @@ perceived time, or clock time, whereas the `User Time` will display the total cpu time. ```shell -$ mlir-opt foo.mlir -cse -canonicalize -convert-std-to-llvm -pass-timing +$ mlir-opt foo.mlir -pass-pipeline='func(cse,canonicalize)' -convert-std-to-llvm -pass-timing ===-------------------------------------------------------------------------=== ... Pass execution timing report ... @@ -716,7 +716,7 @@ this instrumentation: * Print the IR before every pass in the pipeline. ```shell -$ mlir-opt foo.mlir -cse -print-ir-before=cse +$ mlir-opt foo.mlir -pass-pipeline='func(cse)' -print-ir-before=cse *** IR Dump Before CSE *** func @simple_constant() -> (i32, i32) { @@ -732,7 +732,28 @@ func @simple_constant() -> (i32, i32) { * Print the IR after every pass in the pipeline. ```shell -$ mlir-opt foo.mlir -cse -print-ir-after=cse +$ mlir-opt foo.mlir -pass-pipeline='func(cse)' -print-ir-after=cse + +*** IR Dump After CSE *** +func @simple_constant() -> (i32, i32) { + %c1_i32 = constant 1 : i32 + return %c1_i32, %c1_i32 : i32, i32 +} +``` + +* `print-ir-after-change` + * Only print the IR after a pass if the pass mutated the IR. This helps to + reduce the number of IR dumps for "uninteresting" passes. + * Note: Changes are detected by comparing a hash of the operation before + and after the pass. This adds additional run-time to compute the hash of + the IR, and in some rare cases may result in false-positives depending + on the collision rate of the hash algorithm used. + * Note: This option should be used in unison with one of the other + 'print-ir-after' options above, as this option alone does not enable + printing. + +```shell +$ mlir-opt foo.mlir -pass-pipeline='func(cse,cse)' -print-ir-after=cse -print-ir-after-change *** IR Dump After CSE *** func @simple_constant() -> (i32, i32) { @@ -748,7 +769,7 @@ func @simple_constant() -> (i32, i32) { is disabled(`-disable-pass-threading`) ```shell -$ mlir-opt foo.mlir -disable-pass-threading -cse -print-ir-after=cse -print-ir-module-scope +$ mlir-opt foo.mlir -disable-pass-threading -pass-pipeline='func(cse)' -print-ir-after=cse -print-ir-module-scope *** IR Dump After CSE *** ('func' operation: @bar) func @bar(%arg0: f32, %arg1: f32) -> f32 { diff --git a/third_party/mlir/include/mlir/Pass/PassManager.h b/third_party/mlir/include/mlir/Pass/PassManager.h index 724ee0a31cd..5762d684b06 100644 --- a/third_party/mlir/include/mlir/Pass/PassManager.h +++ b/third_party/mlir/include/mlir/Pass/PassManager.h @@ -172,7 +172,12 @@ public: /// printed. This should only be set to true when multi-threading is /// disabled, otherwise we may try to print IR that is being modified /// asynchronously. - explicit IRPrinterConfig(bool printModuleScope = false); + /// * 'printAfterOnlyOnChange' signals that when printing the IR after a + /// pass, in the case of a non-failure, we should first check if any + /// potential mutations were made. This allows for reducing the number of + /// logs that don't contain meaningful changes. + explicit IRPrinterConfig(bool printModuleScope = false, + bool printAfterOnlyOnChange = false); virtual ~IRPrinterConfig(); /// A hook that may be overridden by a derived config that checks if the IR @@ -192,9 +197,17 @@ public: /// Returns true if the IR should always be printed at the top-level scope. bool shouldPrintAtModuleScope() const { return printModuleScope; } + /// Returns true if the IR should only printed after a pass if the IR + /// "changed". + bool shouldPrintAfterOnlyOnChange() const { return printAfterOnlyOnChange; } + private: /// A flag that indicates if the IR should be printed at module scope. bool printModuleScope; + + /// A flag that indicates that the IR after a pass should only be printed if + /// a change is detected. + bool printAfterOnlyOnChange; }; /// Add an instrumentation to print the IR before and after pass execution, @@ -208,11 +221,14 @@ public: /// return true if the IR should be printed or not. /// * 'printModuleScope' signals if the module IR should be printed, even /// for non module passes. + /// * 'printAfterOnlyOnChange' signals that when printing the IR after a + /// pass, in the case of a non-failure, we should first check if any + /// potential mutations were made. /// * 'out' corresponds to the stream to output the printed IR to. void enableIRPrinting( std::function shouldPrintBeforePass, std::function shouldPrintAfterPass, - bool printModuleScope, raw_ostream &out); + bool printModuleScope, bool printAfterOnlyOnChange, raw_ostream &out); //===--------------------------------------------------------------------===// // Pass Timing diff --git a/third_party/mlir/lib/Pass/IRPrinting.cpp b/third_party/mlir/lib/Pass/IRPrinting.cpp index 19e69feb5d8..8e172156f05 100644 --- a/third_party/mlir/lib/Pass/IRPrinting.cpp +++ b/third_party/mlir/lib/Pass/IRPrinting.cpp @@ -20,11 +20,70 @@ #include "mlir/Pass/PassManager.h" #include "llvm/Support/Format.h" #include "llvm/Support/FormatVariadic.h" +#include "llvm/Support/SHA1.h" using namespace mlir; using namespace mlir::detail; namespace { +//===----------------------------------------------------------------------===// +// OperationFingerPrint +//===----------------------------------------------------------------------===// + +/// A unique fingerprint for a specific operation, and all of it's internal +/// operations. +class OperationFingerPrint { +public: + OperationFingerPrint(Operation *topOp) { + llvm::SHA1 hasher; + + // Hash each of the operations based upon their mutable bits: + topOp->walk([&](Operation *op) { + // - Operation pointer + addDataToHash(hasher, op); + // - Attributes + addDataToHash(hasher, + op->getAttrList().getDictionary().getAsOpaquePointer()); + // - Blocks in Regions + for (Region ®ion : op->getRegions()) { + for (Block &block : region) { + addDataToHash(hasher, &block); + for (BlockArgument *arg : block.getArguments()) + addDataToHash(hasher, arg); + } + } + // - Location + addDataToHash(hasher, op->getLoc().getAsOpaquePointer()); + // - Operands + for (Value *operand : op->getOperands()) + addDataToHash(hasher, operand); + // - Successors + for (unsigned i = 0, e = op->getNumSuccessors(); i != e; ++i) + addDataToHash(hasher, op->getSuccessor(i)); + }); + hash = hasher.result(); + } + + bool operator==(const OperationFingerPrint &other) const { + return hash == other.hash; + } + bool operator!=(const OperationFingerPrint &other) const { + return !(*this == other); + } + +private: + template void addDataToHash(llvm::SHA1 &hasher, const T &data) { + hasher.update( + ArrayRef(reinterpret_cast(&data), sizeof(T))); + } + + SmallString<20> hash; +}; + +//===----------------------------------------------------------------------===// +// IRPrinter +//===----------------------------------------------------------------------===// + class IRPrinterInstrumentation : public PassInstrumentation { public: IRPrinterInstrumentation(std::unique_ptr config) @@ -38,6 +97,11 @@ private: /// Configuration to use. std::unique_ptr config; + + /// The following is a set of fingerprints for operations that are currently + /// being operated on in a pass. This field is only used when the + /// configuration asked for change detection. + DenseMap beforePassFingerPrints; }; } // end anonymous namespace @@ -81,6 +145,10 @@ static void printIR(Operation *op, bool printModuleScope, raw_ostream &out, void IRPrinterInstrumentation::runBeforePass(Pass *pass, Operation *op) { if (isHiddenPass(pass)) return; + // If the config asked to detect changes, record the current fingerprint. + if (config->shouldPrintAfterOnlyOnChange()) + beforePassFingerPrints.try_emplace(pass, op); + config->printBeforeIfEnabled(pass, op, [&](raw_ostream &out) { out << formatv("*** IR Dump Before {0} ***", pass->getName()); printIR(op, config->shouldPrintAtModuleScope(), out, OpPrintingFlags()); @@ -91,6 +159,20 @@ void IRPrinterInstrumentation::runBeforePass(Pass *pass, Operation *op) { void IRPrinterInstrumentation::runAfterPass(Pass *pass, Operation *op) { if (isHiddenPass(pass)) return; + // If the config asked to detect changes, compare the current fingerprint with + // the previous. + if (config->shouldPrintAfterOnlyOnChange()) { + auto fingerPrintIt = beforePassFingerPrints.find(pass); + assert(fingerPrintIt != beforePassFingerPrints.end() && + "expected valid fingerprint"); + // If the fingerprints are the same, we don't print the IR. + if (fingerPrintIt->second == OperationFingerPrint(op)) { + beforePassFingerPrints.erase(fingerPrintIt); + return; + } + beforePassFingerPrints.erase(fingerPrintIt); + } + config->printAfterIfEnabled(pass, op, [&](raw_ostream &out) { out << formatv("*** IR Dump After {0} ***", pass->getName()); printIR(op, config->shouldPrintAtModuleScope(), out, OpPrintingFlags()); @@ -101,6 +183,9 @@ void IRPrinterInstrumentation::runAfterPass(Pass *pass, Operation *op) { void IRPrinterInstrumentation::runAfterPassFailed(Pass *pass, Operation *op) { if (isAdaptorPass(pass)) return; + if (config->shouldPrintAfterOnlyOnChange()) + beforePassFingerPrints.erase(pass); + config->printAfterIfEnabled(pass, op, [&](raw_ostream &out) { out << formatv("*** IR Dump After {0} Failed ***", pass->getName()); printIR(op, config->shouldPrintAtModuleScope(), out, @@ -114,10 +199,10 @@ void IRPrinterInstrumentation::runAfterPassFailed(Pass *pass, Operation *op) { //===----------------------------------------------------------------------===// /// Initialize the configuration. -/// * 'printModuleScope' signals if the module IR should be printed, even -/// for non module passes. -PassManager::IRPrinterConfig::IRPrinterConfig(bool printModuleScope) - : printModuleScope(printModuleScope) {} +PassManager::IRPrinterConfig::IRPrinterConfig(bool printModuleScope, + bool printAfterOnlyOnChange) + : printModuleScope(printModuleScope), + printAfterOnlyOnChange(printAfterOnlyOnChange) {} PassManager::IRPrinterConfig::~IRPrinterConfig() {} /// A hook that may be overridden by a derived config that checks if the IR @@ -148,8 +233,8 @@ struct BasicIRPrinterConfig : public PassManager::IRPrinterConfig { BasicIRPrinterConfig( std::function shouldPrintBeforePass, std::function shouldPrintAfterPass, - bool printModuleScope, raw_ostream &out) - : IRPrinterConfig(printModuleScope), + bool printModuleScope, bool printAfterOnlyOnChange, raw_ostream &out) + : IRPrinterConfig(printModuleScope, printAfterOnlyOnChange), shouldPrintBeforePass(shouldPrintBeforePass), shouldPrintAfterPass(shouldPrintAfterPass), out(out) { assert((shouldPrintBeforePass || shouldPrintAfterPass) && @@ -188,8 +273,8 @@ void PassManager::enableIRPrinting(std::unique_ptr config) { void PassManager::enableIRPrinting( std::function shouldPrintBeforePass, std::function shouldPrintAfterPass, - bool printModuleScope, raw_ostream &out) { + bool printModuleScope, bool printAfterOnlyOnChange, raw_ostream &out) { enableIRPrinting(std::make_unique( std::move(shouldPrintBeforePass), std::move(shouldPrintAfterPass), - printModuleScope, out)); + printModuleScope, printAfterOnlyOnChange, out)); } diff --git a/third_party/mlir/lib/Pass/PassManagerOptions.cpp b/third_party/mlir/lib/Pass/PassManagerOptions.cpp index 1416dfe3e8c..932bf98f61e 100644 --- a/third_party/mlir/lib/Pass/PassManagerOptions.cpp +++ b/third_party/mlir/lib/Pass/PassManagerOptions.cpp @@ -54,6 +54,11 @@ struct PassManagerOptions { llvm::cl::opt printAfterAll{"print-ir-after-all", llvm::cl::desc("Print IR after each pass"), llvm::cl::init(false)}; + llvm::cl::opt printAfterChange{ + "print-ir-after-change", + llvm::cl::desc( + "When printing the IR after a pass, only print if the IR changed"), + llvm::cl::init(false)}; llvm::cl::opt printModuleScope{ "print-ir-module-scope", llvm::cl::desc("When printing IR for print-ir-[before|after]{-all} " @@ -139,7 +144,7 @@ void PassManagerOptions::addPrinterInstrumentation(PassManager &pm) { // Otherwise, add the IR printing instrumentation. pm.enableIRPrinting(shouldPrintBeforePass, shouldPrintAfterPass, - printModuleScope, llvm::errs()); + printModuleScope, printAfterChange, llvm::errs()); } /// Add a pass timing instrumentation if enabled by 'pass-timing' flags.