Add a flag to the IRPrinter instrumentation to only print after a pass if there is a change to the IR.

This adds an additional filtering mode for printing after a pass that checks to see if the pass actually changed the IR before printing it. This "change" detection is implemented using a SHA1 hash of the current operation and its children.

PiperOrigin-RevId: 284291089
Change-Id: I65d1b05ccab6fc4e50be754d04215cac6f056edb
This commit is contained in:
River Riddle 2019-12-06 17:04:24 -08:00 committed by TensorFlower Gardener
parent e5ec18d997
commit 4a45a4f987
4 changed files with 144 additions and 17 deletions

View File

@ -624,7 +624,7 @@ pipeline. This display mode is available in mlir-opt via
`-pass-timing-display=list`. `-pass-timing-display=list`.
```shell ```shell
$ mlir-opt foo.mlir -disable-pass-threading -cse -canonicalize -convert-std-to-llvm -pass-timing -pass-timing-display=list $ mlir-opt foo.mlir -disable-pass-threading -pass-pipeline='func(cse,canonicalize)' -convert-std-to-llvm -pass-timing -pass-timing-display=list
===-------------------------------------------------------------------------=== ===-------------------------------------------------------------------------===
... Pass execution timing report ... ... Pass execution timing report ...
@ -649,7 +649,7 @@ the most time, and can also be used to identify when analyses are being
invalidated and recomputed. This is the default display mode. invalidated and recomputed. This is the default display mode.
```shell ```shell
$ mlir-opt foo.mlir -disable-pass-threading -cse -canonicalize -convert-std-to-llvm -pass-timing $ mlir-opt foo.mlir -disable-pass-threading -pass-pipeline='func(cse,canonicalize)' -convert-std-to-llvm -pass-timing
===-------------------------------------------------------------------------=== ===-------------------------------------------------------------------------===
... Pass execution timing report ... ... Pass execution timing report ...
@ -680,7 +680,7 @@ perceived time, or clock time, whereas the `User Time` will display the total
cpu time. cpu time.
```shell ```shell
$ mlir-opt foo.mlir -cse -canonicalize -convert-std-to-llvm -pass-timing $ mlir-opt foo.mlir -pass-pipeline='func(cse,canonicalize)' -convert-std-to-llvm -pass-timing
===-------------------------------------------------------------------------=== ===-------------------------------------------------------------------------===
... Pass execution timing report ... ... Pass execution timing report ...
@ -716,7 +716,7 @@ this instrumentation:
* Print the IR before every pass in the pipeline. * Print the IR before every pass in the pipeline.
```shell ```shell
$ mlir-opt foo.mlir -cse -print-ir-before=cse $ mlir-opt foo.mlir -pass-pipeline='func(cse)' -print-ir-before=cse
*** IR Dump Before CSE *** *** IR Dump Before CSE ***
func @simple_constant() -> (i32, i32) { func @simple_constant() -> (i32, i32) {
@ -732,7 +732,28 @@ func @simple_constant() -> (i32, i32) {
* Print the IR after every pass in the pipeline. * Print the IR after every pass in the pipeline.
```shell ```shell
$ mlir-opt foo.mlir -cse -print-ir-after=cse $ mlir-opt foo.mlir -pass-pipeline='func(cse)' -print-ir-after=cse
*** IR Dump After CSE ***
func @simple_constant() -> (i32, i32) {
%c1_i32 = constant 1 : i32
return %c1_i32, %c1_i32 : i32, i32
}
```
* `print-ir-after-change`
* Only print the IR after a pass if the pass mutated the IR. This helps to
reduce the number of IR dumps for "uninteresting" passes.
* Note: Changes are detected by comparing a hash of the operation before
and after the pass. This adds additional run-time to compute the hash of
the IR, and in some rare cases may result in false-positives depending
on the collision rate of the hash algorithm used.
* Note: This option should be used in unison with one of the other
'print-ir-after' options above, as this option alone does not enable
printing.
```shell
$ mlir-opt foo.mlir -pass-pipeline='func(cse,cse)' -print-ir-after=cse -print-ir-after-change
*** IR Dump After CSE *** *** IR Dump After CSE ***
func @simple_constant() -> (i32, i32) { func @simple_constant() -> (i32, i32) {
@ -748,7 +769,7 @@ func @simple_constant() -> (i32, i32) {
is disabled(`-disable-pass-threading`) is disabled(`-disable-pass-threading`)
```shell ```shell
$ mlir-opt foo.mlir -disable-pass-threading -cse -print-ir-after=cse -print-ir-module-scope $ mlir-opt foo.mlir -disable-pass-threading -pass-pipeline='func(cse)' -print-ir-after=cse -print-ir-module-scope
*** IR Dump After CSE *** ('func' operation: @bar) *** IR Dump After CSE *** ('func' operation: @bar)
func @bar(%arg0: f32, %arg1: f32) -> f32 { func @bar(%arg0: f32, %arg1: f32) -> f32 {

View File

@ -172,7 +172,12 @@ public:
/// printed. This should only be set to true when multi-threading is /// printed. This should only be set to true when multi-threading is
/// disabled, otherwise we may try to print IR that is being modified /// disabled, otherwise we may try to print IR that is being modified
/// asynchronously. /// asynchronously.
explicit IRPrinterConfig(bool printModuleScope = false); /// * 'printAfterOnlyOnChange' signals that when printing the IR after a
/// pass, in the case of a non-failure, we should first check if any
/// potential mutations were made. This allows for reducing the number of
/// logs that don't contain meaningful changes.
explicit IRPrinterConfig(bool printModuleScope = false,
bool printAfterOnlyOnChange = false);
virtual ~IRPrinterConfig(); virtual ~IRPrinterConfig();
/// A hook that may be overridden by a derived config that checks if the IR /// A hook that may be overridden by a derived config that checks if the IR
@ -192,9 +197,17 @@ public:
/// Returns true if the IR should always be printed at the top-level scope. /// Returns true if the IR should always be printed at the top-level scope.
bool shouldPrintAtModuleScope() const { return printModuleScope; } bool shouldPrintAtModuleScope() const { return printModuleScope; }
/// Returns true if the IR should only printed after a pass if the IR
/// "changed".
bool shouldPrintAfterOnlyOnChange() const { return printAfterOnlyOnChange; }
private: private:
/// A flag that indicates if the IR should be printed at module scope. /// A flag that indicates if the IR should be printed at module scope.
bool printModuleScope; bool printModuleScope;
/// A flag that indicates that the IR after a pass should only be printed if
/// a change is detected.
bool printAfterOnlyOnChange;
}; };
/// Add an instrumentation to print the IR before and after pass execution, /// Add an instrumentation to print the IR before and after pass execution,
@ -208,11 +221,14 @@ public:
/// return true if the IR should be printed or not. /// return true if the IR should be printed or not.
/// * 'printModuleScope' signals if the module IR should be printed, even /// * 'printModuleScope' signals if the module IR should be printed, even
/// for non module passes. /// for non module passes.
/// * 'printAfterOnlyOnChange' signals that when printing the IR after a
/// pass, in the case of a non-failure, we should first check if any
/// potential mutations were made.
/// * 'out' corresponds to the stream to output the printed IR to. /// * 'out' corresponds to the stream to output the printed IR to.
void enableIRPrinting( void enableIRPrinting(
std::function<bool(Pass *, Operation *)> shouldPrintBeforePass, std::function<bool(Pass *, Operation *)> shouldPrintBeforePass,
std::function<bool(Pass *, Operation *)> shouldPrintAfterPass, std::function<bool(Pass *, Operation *)> shouldPrintAfterPass,
bool printModuleScope, raw_ostream &out); bool printModuleScope, bool printAfterOnlyOnChange, raw_ostream &out);
//===--------------------------------------------------------------------===// //===--------------------------------------------------------------------===//
// Pass Timing // Pass Timing

View File

@ -20,11 +20,70 @@
#include "mlir/Pass/PassManager.h" #include "mlir/Pass/PassManager.h"
#include "llvm/Support/Format.h" #include "llvm/Support/Format.h"
#include "llvm/Support/FormatVariadic.h" #include "llvm/Support/FormatVariadic.h"
#include "llvm/Support/SHA1.h"
using namespace mlir; using namespace mlir;
using namespace mlir::detail; using namespace mlir::detail;
namespace { namespace {
//===----------------------------------------------------------------------===//
// OperationFingerPrint
//===----------------------------------------------------------------------===//
/// A unique fingerprint for a specific operation, and all of it's internal
/// operations.
class OperationFingerPrint {
public:
OperationFingerPrint(Operation *topOp) {
llvm::SHA1 hasher;
// Hash each of the operations based upon their mutable bits:
topOp->walk([&](Operation *op) {
// - Operation pointer
addDataToHash(hasher, op);
// - Attributes
addDataToHash(hasher,
op->getAttrList().getDictionary().getAsOpaquePointer());
// - Blocks in Regions
for (Region &region : op->getRegions()) {
for (Block &block : region) {
addDataToHash(hasher, &block);
for (BlockArgument *arg : block.getArguments())
addDataToHash(hasher, arg);
}
}
// - Location
addDataToHash(hasher, op->getLoc().getAsOpaquePointer());
// - Operands
for (Value *operand : op->getOperands())
addDataToHash(hasher, operand);
// - Successors
for (unsigned i = 0, e = op->getNumSuccessors(); i != e; ++i)
addDataToHash(hasher, op->getSuccessor(i));
});
hash = hasher.result();
}
bool operator==(const OperationFingerPrint &other) const {
return hash == other.hash;
}
bool operator!=(const OperationFingerPrint &other) const {
return !(*this == other);
}
private:
template <typename T> void addDataToHash(llvm::SHA1 &hasher, const T &data) {
hasher.update(
ArrayRef<uint8_t>(reinterpret_cast<const uint8_t *>(&data), sizeof(T)));
}
SmallString<20> hash;
};
//===----------------------------------------------------------------------===//
// IRPrinter
//===----------------------------------------------------------------------===//
class IRPrinterInstrumentation : public PassInstrumentation { class IRPrinterInstrumentation : public PassInstrumentation {
public: public:
IRPrinterInstrumentation(std::unique_ptr<PassManager::IRPrinterConfig> config) IRPrinterInstrumentation(std::unique_ptr<PassManager::IRPrinterConfig> config)
@ -38,6 +97,11 @@ private:
/// Configuration to use. /// Configuration to use.
std::unique_ptr<PassManager::IRPrinterConfig> config; std::unique_ptr<PassManager::IRPrinterConfig> config;
/// The following is a set of fingerprints for operations that are currently
/// being operated on in a pass. This field is only used when the
/// configuration asked for change detection.
DenseMap<Pass *, OperationFingerPrint> beforePassFingerPrints;
}; };
} // end anonymous namespace } // end anonymous namespace
@ -81,6 +145,10 @@ static void printIR(Operation *op, bool printModuleScope, raw_ostream &out,
void IRPrinterInstrumentation::runBeforePass(Pass *pass, Operation *op) { void IRPrinterInstrumentation::runBeforePass(Pass *pass, Operation *op) {
if (isHiddenPass(pass)) if (isHiddenPass(pass))
return; return;
// If the config asked to detect changes, record the current fingerprint.
if (config->shouldPrintAfterOnlyOnChange())
beforePassFingerPrints.try_emplace(pass, op);
config->printBeforeIfEnabled(pass, op, [&](raw_ostream &out) { config->printBeforeIfEnabled(pass, op, [&](raw_ostream &out) {
out << formatv("*** IR Dump Before {0} ***", pass->getName()); out << formatv("*** IR Dump Before {0} ***", pass->getName());
printIR(op, config->shouldPrintAtModuleScope(), out, OpPrintingFlags()); printIR(op, config->shouldPrintAtModuleScope(), out, OpPrintingFlags());
@ -91,6 +159,20 @@ void IRPrinterInstrumentation::runBeforePass(Pass *pass, Operation *op) {
void IRPrinterInstrumentation::runAfterPass(Pass *pass, Operation *op) { void IRPrinterInstrumentation::runAfterPass(Pass *pass, Operation *op) {
if (isHiddenPass(pass)) if (isHiddenPass(pass))
return; return;
// If the config asked to detect changes, compare the current fingerprint with
// the previous.
if (config->shouldPrintAfterOnlyOnChange()) {
auto fingerPrintIt = beforePassFingerPrints.find(pass);
assert(fingerPrintIt != beforePassFingerPrints.end() &&
"expected valid fingerprint");
// If the fingerprints are the same, we don't print the IR.
if (fingerPrintIt->second == OperationFingerPrint(op)) {
beforePassFingerPrints.erase(fingerPrintIt);
return;
}
beforePassFingerPrints.erase(fingerPrintIt);
}
config->printAfterIfEnabled(pass, op, [&](raw_ostream &out) { config->printAfterIfEnabled(pass, op, [&](raw_ostream &out) {
out << formatv("*** IR Dump After {0} ***", pass->getName()); out << formatv("*** IR Dump After {0} ***", pass->getName());
printIR(op, config->shouldPrintAtModuleScope(), out, OpPrintingFlags()); printIR(op, config->shouldPrintAtModuleScope(), out, OpPrintingFlags());
@ -101,6 +183,9 @@ void IRPrinterInstrumentation::runAfterPass(Pass *pass, Operation *op) {
void IRPrinterInstrumentation::runAfterPassFailed(Pass *pass, Operation *op) { void IRPrinterInstrumentation::runAfterPassFailed(Pass *pass, Operation *op) {
if (isAdaptorPass(pass)) if (isAdaptorPass(pass))
return; return;
if (config->shouldPrintAfterOnlyOnChange())
beforePassFingerPrints.erase(pass);
config->printAfterIfEnabled(pass, op, [&](raw_ostream &out) { config->printAfterIfEnabled(pass, op, [&](raw_ostream &out) {
out << formatv("*** IR Dump After {0} Failed ***", pass->getName()); out << formatv("*** IR Dump After {0} Failed ***", pass->getName());
printIR(op, config->shouldPrintAtModuleScope(), out, printIR(op, config->shouldPrintAtModuleScope(), out,
@ -114,10 +199,10 @@ void IRPrinterInstrumentation::runAfterPassFailed(Pass *pass, Operation *op) {
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
/// Initialize the configuration. /// Initialize the configuration.
/// * 'printModuleScope' signals if the module IR should be printed, even PassManager::IRPrinterConfig::IRPrinterConfig(bool printModuleScope,
/// for non module passes. bool printAfterOnlyOnChange)
PassManager::IRPrinterConfig::IRPrinterConfig(bool printModuleScope) : printModuleScope(printModuleScope),
: printModuleScope(printModuleScope) {} printAfterOnlyOnChange(printAfterOnlyOnChange) {}
PassManager::IRPrinterConfig::~IRPrinterConfig() {} PassManager::IRPrinterConfig::~IRPrinterConfig() {}
/// A hook that may be overridden by a derived config that checks if the IR /// A hook that may be overridden by a derived config that checks if the IR
@ -148,8 +233,8 @@ struct BasicIRPrinterConfig : public PassManager::IRPrinterConfig {
BasicIRPrinterConfig( BasicIRPrinterConfig(
std::function<bool(Pass *, Operation *)> shouldPrintBeforePass, std::function<bool(Pass *, Operation *)> shouldPrintBeforePass,
std::function<bool(Pass *, Operation *)> shouldPrintAfterPass, std::function<bool(Pass *, Operation *)> shouldPrintAfterPass,
bool printModuleScope, raw_ostream &out) bool printModuleScope, bool printAfterOnlyOnChange, raw_ostream &out)
: IRPrinterConfig(printModuleScope), : IRPrinterConfig(printModuleScope, printAfterOnlyOnChange),
shouldPrintBeforePass(shouldPrintBeforePass), shouldPrintBeforePass(shouldPrintBeforePass),
shouldPrintAfterPass(shouldPrintAfterPass), out(out) { shouldPrintAfterPass(shouldPrintAfterPass), out(out) {
assert((shouldPrintBeforePass || shouldPrintAfterPass) && assert((shouldPrintBeforePass || shouldPrintAfterPass) &&
@ -188,8 +273,8 @@ void PassManager::enableIRPrinting(std::unique_ptr<IRPrinterConfig> config) {
void PassManager::enableIRPrinting( void PassManager::enableIRPrinting(
std::function<bool(Pass *, Operation *)> shouldPrintBeforePass, std::function<bool(Pass *, Operation *)> shouldPrintBeforePass,
std::function<bool(Pass *, Operation *)> shouldPrintAfterPass, std::function<bool(Pass *, Operation *)> shouldPrintAfterPass,
bool printModuleScope, raw_ostream &out) { bool printModuleScope, bool printAfterOnlyOnChange, raw_ostream &out) {
enableIRPrinting(std::make_unique<BasicIRPrinterConfig>( enableIRPrinting(std::make_unique<BasicIRPrinterConfig>(
std::move(shouldPrintBeforePass), std::move(shouldPrintAfterPass), std::move(shouldPrintBeforePass), std::move(shouldPrintAfterPass),
printModuleScope, out)); printModuleScope, printAfterOnlyOnChange, out));
} }

View File

@ -54,6 +54,11 @@ struct PassManagerOptions {
llvm::cl::opt<bool> printAfterAll{"print-ir-after-all", llvm::cl::opt<bool> printAfterAll{"print-ir-after-all",
llvm::cl::desc("Print IR after each pass"), llvm::cl::desc("Print IR after each pass"),
llvm::cl::init(false)}; llvm::cl::init(false)};
llvm::cl::opt<bool> printAfterChange{
"print-ir-after-change",
llvm::cl::desc(
"When printing the IR after a pass, only print if the IR changed"),
llvm::cl::init(false)};
llvm::cl::opt<bool> printModuleScope{ llvm::cl::opt<bool> printModuleScope{
"print-ir-module-scope", "print-ir-module-scope",
llvm::cl::desc("When printing IR for print-ir-[before|after]{-all} " llvm::cl::desc("When printing IR for print-ir-[before|after]{-all} "
@ -139,7 +144,7 @@ void PassManagerOptions::addPrinterInstrumentation(PassManager &pm) {
// Otherwise, add the IR printing instrumentation. // Otherwise, add the IR printing instrumentation.
pm.enableIRPrinting(shouldPrintBeforePass, shouldPrintAfterPass, pm.enableIRPrinting(shouldPrintBeforePass, shouldPrintAfterPass,
printModuleScope, llvm::errs()); printModuleScope, printAfterChange, llvm::errs());
} }
/// Add a pass timing instrumentation if enabled by 'pass-timing' flags. /// Add a pass timing instrumentation if enabled by 'pass-timing' flags.