Introduce MlirOptimizationPassState::FallbackEnabled state to the MlirOptimizationPass.

This allows running the pass in the MlirFunctionOptimizationPass pipeline so that it only transforms the MLIR module in case of successful execution. A failure in the pass execution must not affect the MLIR module. Failures during pass execution are tracked by the /tensorflow/core/mlir_pass_failed_fallback metric.

PiperOrigin-RevId: 356316824
Change-Id: Ida9b761c3dd39202982ad6a5d5804fb0e5433c4d
This commit is contained in:
Roman Dzhabarov 2021-02-08 11:48:34 -08:00 committed by TensorFlower Gardener
parent ed365f7817
commit f798728d37
5 changed files with 189 additions and 46 deletions

View File

@ -220,6 +220,7 @@ tf_cc_test(
":mlir_graph_optimization_pass",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
"@llvm-project//mlir:IR",
],
)

View File

@ -39,13 +39,17 @@ limitations under the License.
namespace tensorflow {
auto* shadow_run_success =
auto* mlir_function_optimization_pass_success =
monitoring::Counter<0>::New("/tensorflow/core/mlir_shadow_run_success",
"Success count of MLIR shadow runs");
"Success count of MLIR pass runs");
auto* shadow_run_failure = monitoring::Counter<2>::New(
auto* mlir_function_optimization_pass_failure = monitoring::Counter<2>::New(
"/tensorflow/core/mlir_shadow_run_failure",
"Failure count of MLIR shadow runs", "kind", "name");
"Failure count of MLIR pass runs", "kind", "name");
auto* mlir_function_pass_failed_fallback = monitoring::Counter<0>::New(
"/tensorflow/core/mlir_pass_failed_fallback",
"Failure count of MLIR pass runs when fallback used");
static inline absl::string_view StringRefToView(llvm::StringRef ref) {
return {ref.data(), ref.size()};
@ -117,23 +121,46 @@ Status MlirFunctionOptimizationPass::Run(
std::unique_ptr<Graph>* graph, FunctionLibraryDefinition* flib_def,
std::vector<std::string>* control_ret_node_names,
bool* control_rets_updated) {
// This tracks whether at least one pass is enabled, all passes are disabled,
// or there is a mix of disabled and shadow enabled passes.
// overall_state equals to:
// Enabled if at least one pass is Enabled.
// Disabled if all passes are Disabled.
// ShadowEnabled if all non Disabled passes are ShadowEnabled.
// FallbackEnabled if there are no Enabled passes and there is at least one
// FallbackEnabled pass.
MlirOptimizationPassState overall_state = MlirOptimizationPassState::Disabled;
// Cache per pass state and reuse it during pass execution.
std::vector<MlirOptimizationPassState> per_pass_state;
per_pass_state.reserve(registry_->passes().size());
int num_passes_enabled = 0, num_passes_disabled = 0,
num_passes_shadow_enabled = 0, num_passes_fallback_enabled = 0;
for (const auto& pass_registration : registry_->passes()) {
MlirOptimizationPassState pass_state = pass_registration.pass->GetPassState(
&device_set, config_proto, **graph);
per_pass_state.push_back(pass_state);
if (pass_state == MlirOptimizationPassState::ShadowEnabled &&
overall_state == MlirOptimizationPassState::Disabled) {
overall_state = MlirOptimizationPassState::ShadowEnabled;
} else if (pass_state == MlirOptimizationPassState::Enabled) {
overall_state = MlirOptimizationPassState::Enabled;
switch (pass_state) {
case MlirOptimizationPassState::ShadowEnabled: {
if (overall_state == MlirOptimizationPassState::Disabled)
overall_state = MlirOptimizationPassState::ShadowEnabled;
++num_passes_shadow_enabled;
break;
}
case MlirOptimizationPassState::FallbackEnabled: {
if (overall_state != MlirOptimizationPassState::Enabled)
overall_state = MlirOptimizationPassState::FallbackEnabled;
++num_passes_fallback_enabled;
break;
}
case MlirOptimizationPassState::Enabled: {
overall_state = MlirOptimizationPassState::Enabled;
++num_passes_enabled;
break;
}
case MlirOptimizationPassState::Disabled: {
++num_passes_disabled;
break;
}
}
}
@ -147,15 +174,12 @@ Status MlirFunctionOptimizationPass::Run(
return Status::OK();
}
if (overall_state == MlirOptimizationPassState::Enabled) {
LOG_FIRST_N(INFO, 1) << "At least one MLIR Graph Optimization Pass enabled"
<< "(registered " << registry_->passes().size()
<< " passes)";
} else if (overall_state == MlirOptimizationPassState::ShadowEnabled) {
LOG_FIRST_N(INFO, 1)
<< "All MLIR Graph Optimization Passes are shadow enabled"
<< "(registered " << registry_->passes().size() << " passes)";
}
LOG_FIRST_N(INFO, 1) << "MLIR Graph Optimization Passes."
<< " Enabled: " << num_passes_enabled
<< ", Disabled: " << num_passes_disabled
<< ", ShadowEnabled: " << num_passes_shadow_enabled
<< ", FallbackEnabled: " << num_passes_fallback_enabled
<< ", Total: " << registry_->passes().size();
GraphDebugInfo debug_info;
mlir::MLIRContext context;
@ -180,12 +204,13 @@ Status MlirFunctionOptimizationPass::Run(
return module_ref_status.status();
}
shadow_run_failure->GetCell("graph_to_mlir", "")->IncrementBy(1);
mlir_function_optimization_pass_failure->GetCell("graph_to_mlir", "")
->IncrementBy(1);
// Do not fail, just keep the original TF graph unchanged in shadow mode.
return Status::OK();
}
auto module_ref = std::move(module_ref_status.ValueOrDie());
mlir::OwningModuleRef module_ref = std::move(module_ref_status.ValueOrDie());
AddDevicesToOp(*module_ref, &device_set);
int per_pass_state_index = 0;
@ -207,22 +232,36 @@ Status MlirFunctionOptimizationPass::Run(
overall_state == MlirOptimizationPassState::ShadowEnabled)) {
pass_status =
pass_registration.pass->Run(config_proto, *module_ref, **graph);
} else if (pass_state == MlirOptimizationPassState::ShadowEnabled) {
// Make sure that the pass does not modify MLIR module if it's shadow
// enabled.
} else if (pass_state == MlirOptimizationPassState::ShadowEnabled ||
pass_state == MlirOptimizationPassState::FallbackEnabled) {
// Make sure when the pass is:
// ShadowEnabled, it does not modify the MLIR module.
// FallbackEnabled, it only modifies the MLIR module in case of
// no failures.
auto module_ref_clone = module_ref->clone();
pass_status =
pass_registration.pass->Run(config_proto, module_ref_clone, **graph);
module_ref_clone->destroy();
if (pass_state == MlirOptimizationPassState::FallbackEnabled &&
pass_status.ok()) {
module_ref = module_ref_clone;
} else {
module_ref_clone->destroy();
}
}
if (!pass_status.ok()) {
// If pass failed and pass is:
// ShadowEnabled - only collect metrics, do not propagate
// error to the caller.
// If pass failed and it is:
// (Shadow|Fallback)Enabled - only collect metrics, do not propagate
// error to the caller.
// Enabled - return error back to the caller.
if (pass_state == MlirOptimizationPassState::ShadowEnabled) {
shadow_run_failure->GetCell("pass", name.str())->IncrementBy(1);
mlir_function_optimization_pass_failure->GetCell("pass", name.str())
->IncrementBy(1);
} else if (pass_state == MlirOptimizationPassState::FallbackEnabled) {
LOG(WARNING) << StringRefToView(name)
<< " pass failed, continuing without the pass because the "
"pass has fallback enabled";
mlir_function_pass_failed_fallback->GetCell()->IncrementBy(1);
} else if (pass_state == MlirOptimizationPassState::Enabled) {
return pass_status;
}
@ -246,9 +285,10 @@ Status MlirFunctionOptimizationPass::Run(
ConvertMlirToGraph(*module_ref, export_config, &empty_graph,
&empty_flib, &control_ret_nodes);
if (mlir_to_graph_status.ok()) {
shadow_run_success->GetCell()->IncrementBy(1);
mlir_function_optimization_pass_success->GetCell()->IncrementBy(1);
} else {
shadow_run_failure->GetCell("mlir_to_graph", "")->IncrementBy(1);
mlir_function_optimization_pass_failure->GetCell("mlir_to_graph", "")
->IncrementBy(1);
}
return Status::OK();

View File

@ -29,7 +29,20 @@ namespace tensorflow {
// MLIR passes running on Tensorflow function graphs (Tensorflow V2).
// -------------------------------------------------------------------------- //
enum class MlirOptimizationPassState { Disabled, Enabled, ShadowEnabled };
// Disabled - skip execution of the pass.
// Enabled - execute the pass, propagate errors to the caller if any.
// ShadowEnabled - execute the pass in a shadow mode. The pass should not commit
// any changes to the MLIR module it's processing. Failures are not propagated
// to the caller.
// FallbackEnabled - execute the pass and commit all the changes to the MLIR
// module in case of success. Do not commit any changes in case of failures,
// let the rest of the pipeline run.
enum class MlirOptimizationPassState {
Disabled,
Enabled,
ShadowEnabled,
FallbackEnabled
};
// An API for registering MLIR ModulePass with the Tensorflow runtime. These
// passes are running only for function graphs built by Tensorflow V2 and

View File

@ -17,6 +17,7 @@ limitations under the License.
#include <memory>
#include "mlir/IR/Builders.h" // from @llvm-project
#include "tensorflow/core/platform/test.h"
namespace tensorflow {
@ -39,6 +40,37 @@ class MockMlirOptimizationPass : public MlirOptimizationPass {
mlir::ModuleOp module, const Graph& graph));
};
// Test-only pass that always mutates the MLIR module it is handed and then
// returns a status chosen at construction time. Lets tests check whether a
// (Shadow|Fallback)Enabled pass's module changes were committed or discarded.
class ModifyMlirModulePass : public MlirOptimizationPass {
public:
// `run_status` is the status Run() will report, so tests can simulate both
// successful and failing pass executions.
explicit ModifyMlirModulePass(Status run_status) : run_status_(run_status) {}
// MOCK_METHOD does not work on Windows build, using MOCK_CONST_METHODX
// instead.
MOCK_CONST_METHOD0(name, llvm::StringRef());
MOCK_CONST_METHOD3(GetPassState,
MlirOptimizationPassState(const DeviceSet* device_set,
const ConfigProto& config_proto,
const Graph& graph));
// Just modify MLIR module so that we can check whether original TF graph
// has changed or not.
Status Run(const ConfigProto& config_proto, mlir::ModuleOp module,
const Graph& graph) override {
mlir::Builder b(module.getContext());
// Overwrite the module's "tf.versions" dictionary attribute with sentinel
// values; presumably the bad_consumers entry makes the mutation visible in
// the exported GraphDef — confirm against verifyGraph's proto comparison.
auto producer = b.getNamedAttr("producer", b.getI32IntegerAttr(0));
auto min_consumer = b.getNamedAttr("min_consumer", b.getI32IntegerAttr(0));
auto bad_consumers =
b.getNamedAttr("bad_consumers", b.getI32ArrayAttr({1, 2, 3, 4}));
module->setAttr("tf.versions",
b.getDictionaryAttr(llvm::ArrayRef<mlir::NamedAttribute>(
{producer, min_consumer, bad_consumers})));
return run_status_;
}
// Status returned by Run(); fixed at construction.
Status run_status_;
};
class MlirGraphOptimizationPassTest : public Test {
public:
void Init(Status pass_run_result,
@ -61,18 +93,35 @@ class MlirGraphOptimizationPassTest : public Test {
flib_.reset(new FunctionLibraryDefinition(graph_->flib_def()));
}
// Registers a ModifyMlirModulePass (priority 10) whose GetPassState reports
// `pass_state` and whose Run returns `run_status`. The pass always mutates
// the MLIR module, so tests can observe whether its changes were committed.
void AddModuleModificationPass(MlirOptimizationPassState pass_state,
Status run_status) {
// Add a module-modifying pass with the caller-chosen state (not necessarily
// FallbackEnabled).
auto optimization_pass =
std::make_unique<NiceMock<ModifyMlirModulePass>>(run_status);
ON_CALL(*optimization_pass, GetPassState(_, _, _))
.WillByDefault(Return(pass_state));
MlirOptimizationPassRegistry::Global().Add(10,
std::move(optimization_pass));
}
// Clear all registered passes after each test so registrations do not leak
// across test cases (the registry accessed via Global() is shared state).
void TearDown() override {
MlirOptimizationPassRegistry::Global().ClearPasses();
}
void verifyGraphUnchanged(const GraphDef& original_graph_def) {
void verifyGraph(const GraphDef& original_graph_def, bool changed = false) {
// Proto matchers might be unavailable in the OSS.
#if defined(PLATFORM_GOOGLE)
GraphDef resulted_graph_def;
graph_->ToGraphDef(&resulted_graph_def);
EXPECT_THAT(resulted_graph_def,
::testing::proto::IgnoringRepeatedFieldOrdering(
::testing::EquivToProto(original_graph_def)));
if (changed)
EXPECT_THAT(resulted_graph_def,
Not(::testing::proto::IgnoringRepeatedFieldOrdering(
::testing::EquivToProto(original_graph_def))));
else
EXPECT_THAT(resulted_graph_def,
::testing::proto::IgnoringRepeatedFieldOrdering(
::testing::EquivToProto(original_graph_def)));
#endif
}
@ -96,7 +145,7 @@ TEST_F(MlirGraphOptimizationPassTest, OptimizationPassFailsNoShadow) {
device_set_, config_proto_, &graph_, flib_.get(),
&control_ret_node_names_, &control_rets_updated_),
Status(error::Code::ABORTED, "aborted"));
verifyGraphUnchanged(original_graph_def);
verifyGraph(original_graph_def);
}
TEST_F(MlirGraphOptimizationPassTest, OptimizationPassFailsShadow) {
@ -111,7 +160,7 @@ TEST_F(MlirGraphOptimizationPassTest, OptimizationPassFailsShadow) {
device_set_, config_proto_, &graph_, flib_.get(),
&control_ret_node_names_, &control_rets_updated_),
Status::OK());
verifyGraphUnchanged(original_graph_def);
verifyGraph(original_graph_def);
}
TEST_F(MlirGraphOptimizationPassTest, OptimizationPassDoesNotFailShadow) {
@ -125,7 +174,7 @@ TEST_F(MlirGraphOptimizationPassTest, OptimizationPassDoesNotFailShadow) {
device_set_, config_proto_, &graph_, flib_.get(),
&control_ret_node_names_, &control_rets_updated_),
Status::OK());
verifyGraphUnchanged(original_graph_def);
verifyGraph(original_graph_def);
}
TEST_F(MlirGraphOptimizationPassTest,
@ -141,6 +190,44 @@ TEST_F(MlirGraphOptimizationPassTest,
device_set_, config_proto_, &graph_, flib_.get(),
&control_ret_node_names_, &control_rets_updated_),
Status(error::Code::ABORTED, "aborted"));
verifyGraph(original_graph_def);
}
// A failing FallbackEnabled pass must neither propagate its error (Run()
// returns OK) nor commit its module changes (graph stays unchanged).
TEST_F(MlirGraphOptimizationPassTest,
OptimizationPassFailsShadowDisabledFallback) {
Init(Status(error::Code::ABORTED, "aborted"),
{MlirOptimizationPassState::Disabled,
MlirOptimizationPassState::ShadowEnabled,
MlirOptimizationPassState::FallbackEnabled});
GraphDef original_graph_def;
graph_->ToGraphDef(&original_graph_def);
// This pass fails, so its MLIR module modifications must be discarded.
AddModuleModificationPass(MlirOptimizationPassState::FallbackEnabled,
Status(error::Code::ABORTED, "aborted"));
EXPECT_EQ(function_optimization_pass_.Run(
device_set_, config_proto_, &graph_, flib_.get(),
&control_ret_node_names_, &control_rets_updated_),
Status::OK());
verifyGraph(original_graph_def);
}
// A succeeding FallbackEnabled pass commits its module changes, so the
// resulting TF graph differs from the original (verifyGraph changed=true).
TEST_F(MlirGraphOptimizationPassTest,
OptimizationPassDoesNotFailShadowFallback) {
Init(Status::OK(), {MlirOptimizationPassState::ShadowEnabled,
MlirOptimizationPassState::FallbackEnabled});
GraphDef original_graph_def;
graph_->ToGraphDef(&original_graph_def);
// This pass succeeds, so its MLIR module modifications must be kept.
AddModuleModificationPass(MlirOptimizationPassState::FallbackEnabled,
Status::OK());
EXPECT_EQ(function_optimization_pass_.Run(
device_set_, config_proto_, &graph_, flib_.get(),
&control_ret_node_names_, &control_rets_updated_),
Status::OK());
verifyGraph(original_graph_def, true);
}
TEST(MlirOptimizationPassRegistry, RegisterPassesWithTheSamePriorityFails) {

View File

@ -76,12 +76,14 @@ MlirOptimizationPassState MlirBridgePass::GetPassState(
MlirBridgeRolloutPolicy policy =
GetMlirBridgeRolloutPolicy(graph, config_proto);
if (policy == MlirBridgeRolloutPolicy::kEnabledByUser) {
return MlirOptimizationPassState::Enabled;
} else if (policy == MlirBridgeRolloutPolicy::kEnabledAfterGraphAnalysis) {
return MlirOptimizationPassState::ShadowEnabled;
} else {
return MlirOptimizationPassState::Disabled;
switch (policy) {
case MlirBridgeRolloutPolicy::kEnabledByUser:
return MlirOptimizationPassState::Enabled;
case MlirBridgeRolloutPolicy::kEnabledAfterGraphAnalysis:
return MlirOptimizationPassState::ShadowEnabled;
case MlirBridgeRolloutPolicy::kDisabledByUser:
case MlirBridgeRolloutPolicy::kDisabledAfterGraphAnalysis:
return MlirOptimizationPassState::Disabled;
}
}