From f798728d37810d4cba4dda9dfb92a7b0fe708643 Mon Sep 17 00:00:00 2001 From: Roman Dzhabarov <rdzhabarov@google.com> Date: Mon, 8 Feb 2021 11:48:34 -0800 Subject: [PATCH] Introduce MlirOptimizationPassState::Fallback state to the MlirOptimizationPass. This allows running the pass in the MlirFunctionOptimizationPass pipeline so that it only transforms the MLIR module in case of successful execution. Failure in the pass execution must not affect the MLIR module. Track failures in the execution by the /tensorflow/core/mlir_pass_fallback metric. PiperOrigin-RevId: 356316824 Change-Id: Ida9b761c3dd39202982ad6a5d5804fb0e5433c4d --- tensorflow/compiler/mlir/BUILD | 1 + .../mlir/mlir_graph_optimization_pass.cc | 104 ++++++++++++------ .../mlir/mlir_graph_optimization_pass.h | 15 ++- .../mlir/mlir_graph_optimization_pass_test.cc | 101 +++++++++++++++-- .../compiler/tf2xla/mlir_bridge_pass.cc | 14 ++- 5 files changed, 189 insertions(+), 46 deletions(-) diff --git a/tensorflow/compiler/mlir/BUILD b/tensorflow/compiler/mlir/BUILD index 77db4eb43be..340b5ba1efd 100644 --- a/tensorflow/compiler/mlir/BUILD +++ b/tensorflow/compiler/mlir/BUILD @@ -220,6 +220,7 @@ tf_cc_test( ":mlir_graph_optimization_pass", "//tensorflow/core:test", "//tensorflow/core:test_main", + "@llvm-project//mlir:IR", ], ) diff --git a/tensorflow/compiler/mlir/mlir_graph_optimization_pass.cc b/tensorflow/compiler/mlir/mlir_graph_optimization_pass.cc index 3b713fc0140..9cb44283f6f 100644 --- a/tensorflow/compiler/mlir/mlir_graph_optimization_pass.cc +++ b/tensorflow/compiler/mlir/mlir_graph_optimization_pass.cc @@ -39,13 +39,17 @@ limitations under the License. namespace tensorflow { -auto* shadow_run_success = +auto* mlir_function_optimization_pass_success = monitoring::Counter<0>::New("/tensorflow/core/mlir_shadow_run_success", - "Success count of MLIR shadow runs"); + "Success count of MLIR pass runs"); -auto* shadow_run_failure = monitoring::Counter<2>::New( +auto* mlir_function_optimization_pass_failure = monitoring::Counter<2>::New( "/tensorflow/core/mlir_shadow_run_failure", - "Failure count of MLIR shadow runs", "kind", "name"); + "Failure count of MLIR pass runs", "kind", "name"); + +auto* mlir_function_pass_failed_fallback = monitoring::Counter<0>::New( + "/tensorflow/core/mlir_pass_failed_fallback", + "Failure count of MLIR pass runs when fallback used"); static inline absl::string_view StringRefToView(llvm::StringRef ref) { return {ref.data(), ref.size()}; @@ -117,23 +121,46 @@ Status MlirFunctionOptimizationPass::Run( std::unique_ptr<Graph>* graph, FunctionLibraryDefinition* flib_def, std::vector<std::string>* control_ret_node_names, bool* control_rets_updated) { - // This tracks whether at least one pass is enabled, all passes are disabled, - // or there is a mix of disabled and shadow enabled passes. + // overall_state equals to: + // Enabled if at least one pass is Enabled. + // Disabled if all passes are Disabled. + // ShadowEnabled if all non Disabled passes are ShadowEnabled. + // FallbackEnabled if there are no Enabled passes and there is at least one + // FallbackEnabled pass. MlirOptimizationPassState overall_state = MlirOptimizationPassState::Disabled; // Cache per pass state and reuse it during pass execution. std::vector<MlirOptimizationPassState> per_pass_state; per_pass_state.reserve(registry_->passes().size()); + int num_passes_enabled = 0, num_passes_disabled = 0, + num_passes_shadow_enabled = 0, num_passes_fallback_enabled = 0; for (const auto& pass_registration : registry_->passes()) { MlirOptimizationPassState pass_state = pass_registration.pass->GetPassState( &device_set, config_proto, **graph); per_pass_state.push_back(pass_state); - if (pass_state == MlirOptimizationPassState::ShadowEnabled && - overall_state == MlirOptimizationPassState::Disabled) { - overall_state = MlirOptimizationPassState::ShadowEnabled; - } else if (pass_state == MlirOptimizationPassState::Enabled) { - overall_state = MlirOptimizationPassState::Enabled; + switch (pass_state) { + case MlirOptimizationPassState::ShadowEnabled: { + if (overall_state == MlirOptimizationPassState::Disabled) + overall_state = MlirOptimizationPassState::ShadowEnabled; + ++num_passes_shadow_enabled; + break; + } + case MlirOptimizationPassState::FallbackEnabled: { + if (overall_state != MlirOptimizationPassState::Enabled) + overall_state = MlirOptimizationPassState::FallbackEnabled; + ++num_passes_fallback_enabled; + break; + } + case MlirOptimizationPassState::Enabled: { + overall_state = MlirOptimizationPassState::Enabled; + ++num_passes_enabled; + break; + } + case MlirOptimizationPassState::Disabled: { + ++num_passes_disabled; + break; + } } } @@ -147,15 +174,12 @@ Status MlirFunctionOptimizationPass::Run( return Status::OK(); } - if (overall_state == MlirOptimizationPassState::Enabled) { - LOG_FIRST_N(INFO, 1) << "At least one MLIR Graph Optimization Pass enabled" - << "(registered " << registry_->passes().size() - << " passes)"; - } else if (overall_state == MlirOptimizationPassState::ShadowEnabled) { - LOG_FIRST_N(INFO, 1) - << "All MLIR Graph Optimization Passes are shadow enabled" - << "(registered " << registry_->passes().size() << " passes)"; - } + LOG_FIRST_N(INFO, 1) << "MLIR Graph Optimization Passes." + << " Enabled: " << num_passes_enabled + << ", Disabled: " << num_passes_disabled + << ", ShadowEnabled: " << num_passes_shadow_enabled + << ", FallbackEnabled: " << num_passes_fallback_enabled + << ", Total: " << registry_->passes().size(); GraphDebugInfo debug_info; mlir::MLIRContext context; @@ -180,12 +204,13 @@ Status MlirFunctionOptimizationPass::Run( return module_ref_status.status(); } - shadow_run_failure->GetCell("graph_to_mlir", "")->IncrementBy(1); + mlir_function_optimization_pass_failure->GetCell("graph_to_mlir", "") + ->IncrementBy(1); // Do not fail, just keep the original TF graph unchanged in shadow mode. return Status::OK(); } - auto module_ref = std::move(module_ref_status.ValueOrDie()); + mlir::OwningModuleRef module_ref = std::move(module_ref_status.ValueOrDie()); AddDevicesToOp(*module_ref, &device_set); int per_pass_state_index = 0; @@ -207,22 +232,36 @@ Status MlirFunctionOptimizationPass::Run( overall_state == MlirOptimizationPassState::ShadowEnabled)) { pass_status = pass_registration.pass->Run(config_proto, *module_ref, **graph); - } else if (pass_state == MlirOptimizationPassState::ShadowEnabled) { - // Make sure that the pass does not modify MLIR module if it's shadow - // enabled. + } else if (pass_state == MlirOptimizationPassState::ShadowEnabled || + pass_state == MlirOptimizationPassState::FallbackEnabled) { + // Make sure when the pass is: + // ShadowEnabled, it does not modify the MLIR module. + // FallbackEnabled, it only modifies the MLIR module in case of + // no failures. auto module_ref_clone = module_ref->clone(); pass_status = pass_registration.pass->Run(config_proto, module_ref_clone, **graph); - module_ref_clone->destroy(); + if (pass_state == MlirOptimizationPassState::FallbackEnabled && + pass_status.ok()) { + module_ref = module_ref_clone; + } else { + module_ref_clone->destroy(); + } } if (!pass_status.ok()) { - // If pass failed and pass is: - // ShadowEnabled - only collect metrics, do not propagate - // error to the caller. + // If pass failed and it is: + // (Shadow|Fallback)Enabled - only collect metrics, do not propagate + // error to the caller. // Enabled - return error back to the caller. if (pass_state == MlirOptimizationPassState::ShadowEnabled) { - shadow_run_failure->GetCell("pass", name.str())->IncrementBy(1); + mlir_function_optimization_pass_failure->GetCell("pass", name.str()) + ->IncrementBy(1); + } else if (pass_state == MlirOptimizationPassState::FallbackEnabled) { + LOG(WARNING) << StringRefToView(name) + << " pass failed, continuing without the pass because the " + "pass has fallback enabled"; + mlir_function_pass_failed_fallback->GetCell()->IncrementBy(1); } else if (pass_state == MlirOptimizationPassState::Enabled) { return pass_status; } @@ -246,9 +285,10 @@ Status MlirFunctionOptimizationPass::Run( ConvertMlirToGraph(*module_ref, export_config, &empty_graph, &empty_flib, &control_ret_nodes); if (mlir_to_graph_status.ok()) { - shadow_run_success->GetCell()->IncrementBy(1); + mlir_function_optimization_pass_success->GetCell()->IncrementBy(1); } else { - shadow_run_failure->GetCell("mlir_to_graph", "")->IncrementBy(1); + mlir_function_optimization_pass_failure->GetCell("mlir_to_graph", "") + ->IncrementBy(1); } return Status::OK(); diff --git a/tensorflow/compiler/mlir/mlir_graph_optimization_pass.h b/tensorflow/compiler/mlir/mlir_graph_optimization_pass.h index f9eb6857a5c..32cba472507 100644 --- a/tensorflow/compiler/mlir/mlir_graph_optimization_pass.h +++ b/tensorflow/compiler/mlir/mlir_graph_optimization_pass.h @@ -29,7 +29,20 @@ namespace tensorflow { // MLIR passes running on Tensorflow function graphs (Tensorflow V2). // -------------------------------------------------------------------------- // -enum class MlirOptimizationPassState { Disabled, Enabled, ShadowEnabled }; +// Disabled - skip execution of the pass. +// Enabled - execute the pass, propagate errors to the caller if any. +// ShadowEnabled - execute the pass in a shadow mode. The pass should not commit +// any changes to the MLIR module it's processing. Failures are not propagated +// to the caller. +// FallbackEnabled - execute the pass and commit all the changes to the MLIR +// module in case of success. Do not commit any changes in case of failures, +// let the rest of the pipeline run. +enum class MlirOptimizationPassState { + Disabled, + Enabled, + ShadowEnabled, + FallbackEnabled +}; // An API for registering MLIR ModulePass with the Tensorflow runtime. These // passes are running only for function graphs built by Tensorflow V2 and diff --git a/tensorflow/compiler/mlir/mlir_graph_optimization_pass_test.cc b/tensorflow/compiler/mlir/mlir_graph_optimization_pass_test.cc index af78874b5c9..123f3a30e67 100644 --- a/tensorflow/compiler/mlir/mlir_graph_optimization_pass_test.cc +++ b/tensorflow/compiler/mlir/mlir_graph_optimization_pass_test.cc @@ -17,6 +17,7 @@ limitations under the License. #include <memory> +#include "mlir/IR/Builders.h" // from @llvm-project #include "tensorflow/core/platform/test.h" namespace tensorflow { @@ -39,6 +40,37 @@ class MockMlirOptimizationPass : public MlirOptimizationPass { mlir::ModuleOp module, const Graph& graph)); }; +class ModifyMlirModulePass : public MlirOptimizationPass { + public: + explicit ModifyMlirModulePass(Status run_status) : run_status_(run_status) {} + // MOCK_METHOD does not work on Windows build, using MOCK_CONST_METHODX + // instead. + MOCK_CONST_METHOD0(name, llvm::StringRef()); + MOCK_CONST_METHOD3(GetPassState, + MlirOptimizationPassState(const DeviceSet* device_set, + const ConfigProto& config_proto, + const Graph& graph)); + + // Just modify MLIR module so that we can check whether original TF graph + // has changed or not. + Status Run(const ConfigProto& config_proto, mlir::ModuleOp module, + const Graph& graph) override { + mlir::Builder b(module.getContext()); + auto producer = b.getNamedAttr("producer", b.getI32IntegerAttr(0)); + auto min_consumer = b.getNamedAttr("min_consumer", b.getI32IntegerAttr(0)); + auto bad_consumers = + b.getNamedAttr("bad_consumers", b.getI32ArrayAttr({1, 2, 3, 4})); + + module->setAttr("tf.versions", + b.getDictionaryAttr(llvm::ArrayRef<mlir::NamedAttribute>( + {producer, min_consumer, bad_consumers}))); + + return run_status_; + } + + Status run_status_; +}; + class MlirGraphOptimizationPassTest : public Test { public: void Init(Status pass_run_result, @@ -61,18 +93,35 @@ class MlirGraphOptimizationPassTest : public Test { flib_.reset(new FunctionLibraryDefinition(graph_->flib_def())); } + void AddModuleModificationPass(MlirOptimizationPassState pass_state, + Status run_status) { + // Add FallbackEnabled pass that modifies the graph. + auto optimization_pass = + std::make_unique<NiceMock<ModifyMlirModulePass>>(run_status); + ON_CALL(*optimization_pass, GetPassState(_, _, _)) + .WillByDefault(Return(pass_state)); + MlirOptimizationPassRegistry::Global().Add(10, + std::move(optimization_pass)); + } + void TearDown() override { MlirOptimizationPassRegistry::Global().ClearPasses(); } - void verifyGraphUnchanged(const GraphDef& original_graph_def) { + void verifyGraph(const GraphDef& original_graph_def, bool changed = false) { // Proto matchers might be unavailable in the OSS. #if defined(PLATFORM_GOOGLE) GraphDef resulted_graph_def; graph_->ToGraphDef(&resulted_graph_def); - EXPECT_THAT(resulted_graph_def, - ::testing::proto::IgnoringRepeatedFieldOrdering( - ::testing::EquivToProto(original_graph_def))); + + if (changed) + EXPECT_THAT(resulted_graph_def, + Not(::testing::proto::IgnoringRepeatedFieldOrdering( + ::testing::EquivToProto(original_graph_def)))); + else + EXPECT_THAT(resulted_graph_def, + ::testing::proto::IgnoringRepeatedFieldOrdering( + ::testing::EquivToProto(original_graph_def))); #endif } @@ -96,7 +145,7 @@ TEST_F(MlirGraphOptimizationPassTest, OptimizationPassFailsNoShadow) { device_set_, config_proto_, &graph_, flib_.get(), &control_ret_node_names_, &control_rets_updated_), Status(error::Code::ABORTED, "aborted")); - verifyGraphUnchanged(original_graph_def); + verifyGraph(original_graph_def); } TEST_F(MlirGraphOptimizationPassTest, OptimizationPassFailsShadow) { @@ -111,7 +160,7 @@ TEST_F(MlirGraphOptimizationPassTest, OptimizationPassFailsShadow) { device_set_, config_proto_, &graph_, flib_.get(), &control_ret_node_names_, &control_rets_updated_), Status::OK()); - verifyGraphUnchanged(original_graph_def); + verifyGraph(original_graph_def); } TEST_F(MlirGraphOptimizationPassTest, OptimizationPassDoesNotFailShadow) { @@ -125,7 +174,7 @@ TEST_F(MlirGraphOptimizationPassTest, OptimizationPassDoesNotFailShadow) { device_set_, config_proto_, &graph_, flib_.get(), &control_ret_node_names_, &control_rets_updated_), Status::OK()); - verifyGraphUnchanged(original_graph_def); + verifyGraph(original_graph_def); } TEST_F(MlirGraphOptimizationPassTest, @@ -141,6 +190,44 @@ TEST_F(MlirGraphOptimizationPassTest, device_set_, config_proto_, &graph_, flib_.get(), &control_ret_node_names_, &control_rets_updated_), Status(error::Code::ABORTED, "aborted")); + verifyGraph(original_graph_def); +} + +TEST_F(MlirGraphOptimizationPassTest, + OptimizationPassFailsShadowDisabledFallback) { + Init(Status(error::Code::ABORTED, "aborted"), + {MlirOptimizationPassState::Disabled, + MlirOptimizationPassState::ShadowEnabled, + MlirOptimizationPassState::FallbackEnabled}); + + GraphDef original_graph_def; + graph_->ToGraphDef(&original_graph_def); + AddModuleModificationPass(MlirOptimizationPassState::FallbackEnabled, + Status(error::Code::ABORTED, "aborted")); + + EXPECT_EQ(function_optimization_pass_.Run( + device_set_, config_proto_, &graph_, flib_.get(), + &control_ret_node_names_, &control_rets_updated_), + Status::OK()); + verifyGraph(original_graph_def); +} + +TEST_F(MlirGraphOptimizationPassTest, + OptimizationPassDoesNotFailShadowFallback) { + Init(Status::OK(), {MlirOptimizationPassState::ShadowEnabled, + MlirOptimizationPassState::FallbackEnabled}); + + GraphDef original_graph_def; + graph_->ToGraphDef(&original_graph_def); + + AddModuleModificationPass(MlirOptimizationPassState::FallbackEnabled, + Status::OK()); + EXPECT_EQ(function_optimization_pass_.Run( + device_set_, config_proto_, &graph_, flib_.get(), + &control_ret_node_names_, &control_rets_updated_), + Status::OK()); + + verifyGraph(original_graph_def, true); } TEST(MlirOptimizationPassRegistry, RegisterPassesWithTheSamePriorityFails) { diff --git a/tensorflow/compiler/tf2xla/mlir_bridge_pass.cc b/tensorflow/compiler/tf2xla/mlir_bridge_pass.cc index 7b47250fc51..273718ee810 100644 --- a/tensorflow/compiler/tf2xla/mlir_bridge_pass.cc +++ b/tensorflow/compiler/tf2xla/mlir_bridge_pass.cc @@ -76,12 +76,14 @@ MlirOptimizationPassState MlirBridgePass::GetPassState( MlirBridgeRolloutPolicy policy = GetMlirBridgeRolloutPolicy(graph, config_proto); - if (policy == MlirBridgeRolloutPolicy::kEnabledByUser) { - return MlirOptimizationPassState::Enabled; - } else if (policy == MlirBridgeRolloutPolicy::kEnabledAfterGraphAnalysis) { - return MlirOptimizationPassState::ShadowEnabled; - } else { - return MlirOptimizationPassState::Disabled; + switch (policy) { + case MlirBridgeRolloutPolicy::kEnabledByUser: + return MlirOptimizationPassState::Enabled; + case MlirBridgeRolloutPolicy::kEnabledAfterGraphAnalysis: + return MlirOptimizationPassState::ShadowEnabled; + case MlirBridgeRolloutPolicy::kDisabledByUser: + case MlirBridgeRolloutPolicy::kDisabledAfterGraphAnalysis: + return MlirOptimizationPassState::Disabled; } }