Introduce MlirOptimizationPassState::FallbackEnabled state to the MlirOptimizationPass.

This allows running the pass in the MlirFunctionOptimizationPass pipeline so that it only commits its transformations to the MLIR module when execution succeeds. A failure in the pass execution must not affect the MLIR module. Failures in execution are tracked by the /tensorflow/core/mlir_pass_failed_fallback metric.

PiperOrigin-RevId: 356316824
Change-Id: Ida9b761c3dd39202982ad6a5d5804fb0e5433c4d
This commit is contained in:
Roman Dzhabarov 2021-02-08 11:48:34 -08:00 committed by TensorFlower Gardener
parent ed365f7817
commit f798728d37
5 changed files with 189 additions and 46 deletions

View File

@ -220,6 +220,7 @@ tf_cc_test(
":mlir_graph_optimization_pass", ":mlir_graph_optimization_pass",
"//tensorflow/core:test", "//tensorflow/core:test",
"//tensorflow/core:test_main", "//tensorflow/core:test_main",
"@llvm-project//mlir:IR",
], ],
) )

View File

@ -39,13 +39,17 @@ limitations under the License.
namespace tensorflow { namespace tensorflow {
auto* shadow_run_success = auto* mlir_function_optimization_pass_success =
monitoring::Counter<0>::New("/tensorflow/core/mlir_shadow_run_success", monitoring::Counter<0>::New("/tensorflow/core/mlir_shadow_run_success",
"Success count of MLIR shadow runs"); "Success count of MLIR pass runs");
auto* shadow_run_failure = monitoring::Counter<2>::New( auto* mlir_function_optimization_pass_failure = monitoring::Counter<2>::New(
"/tensorflow/core/mlir_shadow_run_failure", "/tensorflow/core/mlir_shadow_run_failure",
"Failure count of MLIR shadow runs", "kind", "name"); "Failure count of MLIR pass runs", "kind", "name");
auto* mlir_function_pass_failed_fallback = monitoring::Counter<0>::New(
"/tensorflow/core/mlir_pass_failed_fallback",
"Failure count of MLIR pass runs when fallback used");
static inline absl::string_view StringRefToView(llvm::StringRef ref) { static inline absl::string_view StringRefToView(llvm::StringRef ref) {
return {ref.data(), ref.size()}; return {ref.data(), ref.size()};
@ -117,23 +121,46 @@ Status MlirFunctionOptimizationPass::Run(
std::unique_ptr<Graph>* graph, FunctionLibraryDefinition* flib_def, std::unique_ptr<Graph>* graph, FunctionLibraryDefinition* flib_def,
std::vector<std::string>* control_ret_node_names, std::vector<std::string>* control_ret_node_names,
bool* control_rets_updated) { bool* control_rets_updated) {
// This tracks whether at least one pass is enabled, all passes are disabled, // overall_state equals to:
// or there is a mix of disabled and shadow enabled passes. // Enabled if at least one pass is Enabled.
// Disabled if all passes are Disabled.
// ShadowEnabled if all non Disabled passes are ShadowEnabled.
// FallbackEnabled if there are no Enabled passes and there is at least one
// FallbackEnabled pass.
MlirOptimizationPassState overall_state = MlirOptimizationPassState::Disabled; MlirOptimizationPassState overall_state = MlirOptimizationPassState::Disabled;
// Cache per pass state and reuse it during pass execution. // Cache per pass state and reuse it during pass execution.
std::vector<MlirOptimizationPassState> per_pass_state; std::vector<MlirOptimizationPassState> per_pass_state;
per_pass_state.reserve(registry_->passes().size()); per_pass_state.reserve(registry_->passes().size());
int num_passes_enabled = 0, num_passes_disabled = 0,
num_passes_shadow_enabled = 0, num_passes_fallback_enabled = 0;
for (const auto& pass_registration : registry_->passes()) { for (const auto& pass_registration : registry_->passes()) {
MlirOptimizationPassState pass_state = pass_registration.pass->GetPassState( MlirOptimizationPassState pass_state = pass_registration.pass->GetPassState(
&device_set, config_proto, **graph); &device_set, config_proto, **graph);
per_pass_state.push_back(pass_state); per_pass_state.push_back(pass_state);
if (pass_state == MlirOptimizationPassState::ShadowEnabled && switch (pass_state) {
overall_state == MlirOptimizationPassState::Disabled) { case MlirOptimizationPassState::ShadowEnabled: {
if (overall_state == MlirOptimizationPassState::Disabled)
overall_state = MlirOptimizationPassState::ShadowEnabled; overall_state = MlirOptimizationPassState::ShadowEnabled;
} else if (pass_state == MlirOptimizationPassState::Enabled) { ++num_passes_shadow_enabled;
break;
}
case MlirOptimizationPassState::FallbackEnabled: {
if (overall_state != MlirOptimizationPassState::Enabled)
overall_state = MlirOptimizationPassState::FallbackEnabled;
++num_passes_fallback_enabled;
break;
}
case MlirOptimizationPassState::Enabled: {
overall_state = MlirOptimizationPassState::Enabled; overall_state = MlirOptimizationPassState::Enabled;
++num_passes_enabled;
break;
}
case MlirOptimizationPassState::Disabled: {
++num_passes_disabled;
break;
}
} }
} }
@ -147,15 +174,12 @@ Status MlirFunctionOptimizationPass::Run(
return Status::OK(); return Status::OK();
} }
if (overall_state == MlirOptimizationPassState::Enabled) { LOG_FIRST_N(INFO, 1) << "MLIR Graph Optimization Passes."
LOG_FIRST_N(INFO, 1) << "At least one MLIR Graph Optimization Pass enabled" << " Enabled: " << num_passes_enabled
<< "(registered " << registry_->passes().size() << ", Disabled: " << num_passes_disabled
<< " passes)"; << ", ShadowEnabled: " << num_passes_shadow_enabled
} else if (overall_state == MlirOptimizationPassState::ShadowEnabled) { << ", FallbackEnabled: " << num_passes_fallback_enabled
LOG_FIRST_N(INFO, 1) << ", Total: " << registry_->passes().size();
<< "All MLIR Graph Optimization Passes are shadow enabled"
<< "(registered " << registry_->passes().size() << " passes)";
}
GraphDebugInfo debug_info; GraphDebugInfo debug_info;
mlir::MLIRContext context; mlir::MLIRContext context;
@ -180,12 +204,13 @@ Status MlirFunctionOptimizationPass::Run(
return module_ref_status.status(); return module_ref_status.status();
} }
shadow_run_failure->GetCell("graph_to_mlir", "")->IncrementBy(1); mlir_function_optimization_pass_failure->GetCell("graph_to_mlir", "")
->IncrementBy(1);
// Do not fail, just keep the original TF graph unchanged in shadow mode. // Do not fail, just keep the original TF graph unchanged in shadow mode.
return Status::OK(); return Status::OK();
} }
auto module_ref = std::move(module_ref_status.ValueOrDie()); mlir::OwningModuleRef module_ref = std::move(module_ref_status.ValueOrDie());
AddDevicesToOp(*module_ref, &device_set); AddDevicesToOp(*module_ref, &device_set);
int per_pass_state_index = 0; int per_pass_state_index = 0;
@ -207,22 +232,36 @@ Status MlirFunctionOptimizationPass::Run(
overall_state == MlirOptimizationPassState::ShadowEnabled)) { overall_state == MlirOptimizationPassState::ShadowEnabled)) {
pass_status = pass_status =
pass_registration.pass->Run(config_proto, *module_ref, **graph); pass_registration.pass->Run(config_proto, *module_ref, **graph);
} else if (pass_state == MlirOptimizationPassState::ShadowEnabled) { } else if (pass_state == MlirOptimizationPassState::ShadowEnabled ||
// Make sure that the pass does not modify MLIR module if it's shadow pass_state == MlirOptimizationPassState::FallbackEnabled) {
// enabled. // Make sure when the pass is:
// ShadowEnabled, it does not modify the MLIR module.
// FallbackEnabled, it only modifies the MLIR module in case of
// no failures.
auto module_ref_clone = module_ref->clone(); auto module_ref_clone = module_ref->clone();
pass_status = pass_status =
pass_registration.pass->Run(config_proto, module_ref_clone, **graph); pass_registration.pass->Run(config_proto, module_ref_clone, **graph);
if (pass_state == MlirOptimizationPassState::FallbackEnabled &&
pass_status.ok()) {
module_ref = module_ref_clone;
} else {
module_ref_clone->destroy(); module_ref_clone->destroy();
} }
}
if (!pass_status.ok()) { if (!pass_status.ok()) {
// If pass failed and pass is: // If pass failed and it is:
// ShadowEnabled - only collect metrics, do not propagate // (Shadow|Fallback)Enabled - only collect metrics, do not propagate
// error to the caller. // error to the caller.
// Enabled - return error back to the caller. // Enabled - return error back to the caller.
if (pass_state == MlirOptimizationPassState::ShadowEnabled) { if (pass_state == MlirOptimizationPassState::ShadowEnabled) {
shadow_run_failure->GetCell("pass", name.str())->IncrementBy(1); mlir_function_optimization_pass_failure->GetCell("pass", name.str())
->IncrementBy(1);
} else if (pass_state == MlirOptimizationPassState::FallbackEnabled) {
LOG(WARNING) << StringRefToView(name)
<< " pass failed, continuing without the pass because the "
"pass has fallback enabled";
mlir_function_pass_failed_fallback->GetCell()->IncrementBy(1);
} else if (pass_state == MlirOptimizationPassState::Enabled) { } else if (pass_state == MlirOptimizationPassState::Enabled) {
return pass_status; return pass_status;
} }
@ -246,9 +285,10 @@ Status MlirFunctionOptimizationPass::Run(
ConvertMlirToGraph(*module_ref, export_config, &empty_graph, ConvertMlirToGraph(*module_ref, export_config, &empty_graph,
&empty_flib, &control_ret_nodes); &empty_flib, &control_ret_nodes);
if (mlir_to_graph_status.ok()) { if (mlir_to_graph_status.ok()) {
shadow_run_success->GetCell()->IncrementBy(1); mlir_function_optimization_pass_success->GetCell()->IncrementBy(1);
} else { } else {
shadow_run_failure->GetCell("mlir_to_graph", "")->IncrementBy(1); mlir_function_optimization_pass_failure->GetCell("mlir_to_graph", "")
->IncrementBy(1);
} }
return Status::OK(); return Status::OK();

View File

@ -29,7 +29,20 @@ namespace tensorflow {
// MLIR passes running on Tensorflow function graphs (Tensorflow V2). // MLIR passes running on Tensorflow function graphs (Tensorflow V2).
// -------------------------------------------------------------------------- // // -------------------------------------------------------------------------- //
enum class MlirOptimizationPassState { Disabled, Enabled, ShadowEnabled }; // Disabled - skip execution of the pass.
// Enabled - execute the pass, propagate errors to the caller if any.
// ShadowEnabled - execute the pass in a shadow mode. The pass should not commit
// any changes to the MLIR module it's processing. Failures are not propagated
// to the caller.
// FallbackEnabled - execute the pass and commit all the changes to the MLIR
// module in case of success. Do not commit any changes in case of failures,
// let the rest of the pipeline run.
enum class MlirOptimizationPassState {
Disabled,
Enabled,
ShadowEnabled,
FallbackEnabled
};
// An API for registering MLIR ModulePass with the Tensorflow runtime. These // An API for registering MLIR ModulePass with the Tensorflow runtime. These
// passes are running only for function graphs built by Tensorflow V2 and // passes are running only for function graphs built by Tensorflow V2 and

View File

@ -17,6 +17,7 @@ limitations under the License.
#include <memory> #include <memory>
#include "mlir/IR/Builders.h" // from @llvm-project
#include "tensorflow/core/platform/test.h" #include "tensorflow/core/platform/test.h"
namespace tensorflow { namespace tensorflow {
@ -39,6 +40,37 @@ class MockMlirOptimizationPass : public MlirOptimizationPass {
mlir::ModuleOp module, const Graph& graph)); mlir::ModuleOp module, const Graph& graph));
}; };
// Test-only MLIR optimization pass that always mutates the module it is given
// (by overwriting the "tf.versions" attribute) and then returns the Status it
// was constructed with. Used to verify that FallbackEnabled passes commit
// changes only on success, by checking whether the original TF graph changed.
class ModifyMlirModulePass : public MlirOptimizationPass {
public:
explicit ModifyMlirModulePass(Status run_status) : run_status_(run_status) {}
// MOCK_METHOD does not work on Windows build, using MOCK_CONST_METHODX
// instead.
MOCK_CONST_METHOD0(name, llvm::StringRef());
MOCK_CONST_METHOD3(GetPassState,
MlirOptimizationPassState(const DeviceSet* device_set,
const ConfigProto& config_proto,
const Graph& graph));
// Just modify MLIR module so that we can check whether original TF graph
// has changed or not.
Status Run(const ConfigProto& config_proto, mlir::ModuleOp module,
const Graph& graph) override {
mlir::Builder b(module.getContext());
// Arbitrary but recognizable attribute values; the exact numbers are not
// meaningful, they only need to differ from the module's original ones.
auto producer = b.getNamedAttr("producer", b.getI32IntegerAttr(0));
auto min_consumer = b.getNamedAttr("min_consumer", b.getI32IntegerAttr(0));
auto bad_consumers =
b.getNamedAttr("bad_consumers", b.getI32ArrayAttr({1, 2, 3, 4}));
module->setAttr("tf.versions",
b.getDictionaryAttr(llvm::ArrayRef<mlir::NamedAttribute>(
{producer, min_consumer, bad_consumers})));
// Report the preconfigured result regardless of the mutation above.
return run_status_;
}
// Status returned by every Run() invocation; set at construction.
Status run_status_;
};
class MlirGraphOptimizationPassTest : public Test { class MlirGraphOptimizationPassTest : public Test {
public: public:
void Init(Status pass_run_result, void Init(Status pass_run_result,
@ -61,15 +93,32 @@ class MlirGraphOptimizationPassTest : public Test {
flib_.reset(new FunctionLibraryDefinition(graph_->flib_def())); flib_.reset(new FunctionLibraryDefinition(graph_->flib_def()));
} }
// Registers a ModifyMlirModulePass (priority 10) that reports `pass_state`
// from GetPassState and returns `run_status` from Run. The pass always
// mutates the MLIR module, letting tests observe whether the change was
// committed back to the TF graph.
void AddModuleModificationPass(MlirOptimizationPassState pass_state,
Status run_status) {
// Add a pass with the given state that modifies the MLIR module.
auto optimization_pass =
std::make_unique<NiceMock<ModifyMlirModulePass>>(run_status);
ON_CALL(*optimization_pass, GetPassState(_, _, _))
.WillByDefault(Return(pass_state));
MlirOptimizationPassRegistry::Global().Add(10,
std::move(optimization_pass));
}
void TearDown() override { void TearDown() override {
MlirOptimizationPassRegistry::Global().ClearPasses(); MlirOptimizationPassRegistry::Global().ClearPasses();
} }
void verifyGraphUnchanged(const GraphDef& original_graph_def) { void verifyGraph(const GraphDef& original_graph_def, bool changed = false) {
// Proto matchers might be unavailable in the OSS. // Proto matchers might be unavailable in the OSS.
#if defined(PLATFORM_GOOGLE) #if defined(PLATFORM_GOOGLE)
GraphDef resulted_graph_def; GraphDef resulted_graph_def;
graph_->ToGraphDef(&resulted_graph_def); graph_->ToGraphDef(&resulted_graph_def);
if (changed)
EXPECT_THAT(resulted_graph_def,
Not(::testing::proto::IgnoringRepeatedFieldOrdering(
::testing::EquivToProto(original_graph_def))));
else
EXPECT_THAT(resulted_graph_def, EXPECT_THAT(resulted_graph_def,
::testing::proto::IgnoringRepeatedFieldOrdering( ::testing::proto::IgnoringRepeatedFieldOrdering(
::testing::EquivToProto(original_graph_def))); ::testing::EquivToProto(original_graph_def)));
@ -96,7 +145,7 @@ TEST_F(MlirGraphOptimizationPassTest, OptimizationPassFailsNoShadow) {
device_set_, config_proto_, &graph_, flib_.get(), device_set_, config_proto_, &graph_, flib_.get(),
&control_ret_node_names_, &control_rets_updated_), &control_ret_node_names_, &control_rets_updated_),
Status(error::Code::ABORTED, "aborted")); Status(error::Code::ABORTED, "aborted"));
verifyGraphUnchanged(original_graph_def); verifyGraph(original_graph_def);
} }
TEST_F(MlirGraphOptimizationPassTest, OptimizationPassFailsShadow) { TEST_F(MlirGraphOptimizationPassTest, OptimizationPassFailsShadow) {
@ -111,7 +160,7 @@ TEST_F(MlirGraphOptimizationPassTest, OptimizationPassFailsShadow) {
device_set_, config_proto_, &graph_, flib_.get(), device_set_, config_proto_, &graph_, flib_.get(),
&control_ret_node_names_, &control_rets_updated_), &control_ret_node_names_, &control_rets_updated_),
Status::OK()); Status::OK());
verifyGraphUnchanged(original_graph_def); verifyGraph(original_graph_def);
} }
TEST_F(MlirGraphOptimizationPassTest, OptimizationPassDoesNotFailShadow) { TEST_F(MlirGraphOptimizationPassTest, OptimizationPassDoesNotFailShadow) {
@ -125,7 +174,7 @@ TEST_F(MlirGraphOptimizationPassTest, OptimizationPassDoesNotFailShadow) {
device_set_, config_proto_, &graph_, flib_.get(), device_set_, config_proto_, &graph_, flib_.get(),
&control_ret_node_names_, &control_rets_updated_), &control_ret_node_names_, &control_rets_updated_),
Status::OK()); Status::OK());
verifyGraphUnchanged(original_graph_def); verifyGraph(original_graph_def);
} }
TEST_F(MlirGraphOptimizationPassTest, TEST_F(MlirGraphOptimizationPassTest,
@ -141,6 +190,44 @@ TEST_F(MlirGraphOptimizationPassTest,
device_set_, config_proto_, &graph_, flib_.get(), device_set_, config_proto_, &graph_, flib_.get(),
&control_ret_node_names_, &control_rets_updated_), &control_ret_node_names_, &control_rets_updated_),
Status(error::Code::ABORTED, "aborted")); Status(error::Code::ABORTED, "aborted"));
verifyGraph(original_graph_def);
}
// A failing FallbackEnabled pass (alongside Disabled and ShadowEnabled ones)
// must neither propagate its error to the caller nor change the TF graph.
TEST_F(MlirGraphOptimizationPassTest,
OptimizationPassFailsShadowDisabledFallback) {
Init(Status(error::Code::ABORTED, "aborted"),
{MlirOptimizationPassState::Disabled,
MlirOptimizationPassState::ShadowEnabled,
MlirOptimizationPassState::FallbackEnabled});
// Snapshot the graph before running passes so we can verify it is unchanged.
GraphDef original_graph_def;
graph_->ToGraphDef(&original_graph_def);
AddModuleModificationPass(MlirOptimizationPassState::FallbackEnabled,
Status(error::Code::ABORTED, "aborted"));
// The pipeline must report OK even though the fallback pass aborted.
EXPECT_EQ(function_optimization_pass_.Run(
device_set_, config_proto_, &graph_, flib_.get(),
&control_ret_node_names_, &control_rets_updated_),
Status::OK());
verifyGraph(original_graph_def);
}
TEST_F(MlirGraphOptimizationPassTest,
OptimizationPassDoesNotFailShadowFallback) {
Init(Status::OK(), {MlirOptimizationPassState::ShadowEnabled,
MlirOptimizationPassState::FallbackEnabled});
GraphDef original_graph_def;
graph_->ToGraphDef(&original_graph_def);
AddModuleModificationPass(MlirOptimizationPassState::FallbackEnabled,
Status::OK());
EXPECT_EQ(function_optimization_pass_.Run(
device_set_, config_proto_, &graph_, flib_.get(),
&control_ret_node_names_, &control_rets_updated_),
Status::OK());
verifyGraph(original_graph_def, true);
} }
TEST(MlirOptimizationPassRegistry, RegisterPassesWithTheSamePriorityFails) { TEST(MlirOptimizationPassRegistry, RegisterPassesWithTheSamePriorityFails) {

View File

@ -76,11 +76,13 @@ MlirOptimizationPassState MlirBridgePass::GetPassState(
MlirBridgeRolloutPolicy policy = MlirBridgeRolloutPolicy policy =
GetMlirBridgeRolloutPolicy(graph, config_proto); GetMlirBridgeRolloutPolicy(graph, config_proto);
if (policy == MlirBridgeRolloutPolicy::kEnabledByUser) { switch (policy) {
case MlirBridgeRolloutPolicy::kEnabledByUser:
return MlirOptimizationPassState::Enabled; return MlirOptimizationPassState::Enabled;
} else if (policy == MlirBridgeRolloutPolicy::kEnabledAfterGraphAnalysis) { case MlirBridgeRolloutPolicy::kEnabledAfterGraphAnalysis:
return MlirOptimizationPassState::ShadowEnabled; return MlirOptimizationPassState::ShadowEnabled;
} else { case MlirBridgeRolloutPolicy::kDisabledByUser:
case MlirBridgeRolloutPolicy::kDisabledAfterGraphAnalysis:
return MlirOptimizationPassState::Disabled; return MlirOptimizationPassState::Disabled;
} }
} }