diff --git a/tensorflow/compiler/mlir/BUILD b/tensorflow/compiler/mlir/BUILD index 77db4eb43be..340b5ba1efd 100644 --- a/tensorflow/compiler/mlir/BUILD +++ b/tensorflow/compiler/mlir/BUILD @@ -220,6 +220,7 @@ tf_cc_test( ":mlir_graph_optimization_pass", "//tensorflow/core:test", "//tensorflow/core:test_main", + "@llvm-project//mlir:IR", ], ) diff --git a/tensorflow/compiler/mlir/mlir_graph_optimization_pass.cc b/tensorflow/compiler/mlir/mlir_graph_optimization_pass.cc index 3b713fc0140..9cb44283f6f 100644 --- a/tensorflow/compiler/mlir/mlir_graph_optimization_pass.cc +++ b/tensorflow/compiler/mlir/mlir_graph_optimization_pass.cc @@ -39,13 +39,17 @@ limitations under the License. namespace tensorflow { -auto* shadow_run_success = +auto* mlir_function_optimization_pass_success = monitoring::Counter<0>::New("/tensorflow/core/mlir_shadow_run_success", - "Success count of MLIR shadow runs"); + "Success count of MLIR pass runs"); -auto* shadow_run_failure = monitoring::Counter<2>::New( +auto* mlir_function_optimization_pass_failure = monitoring::Counter<2>::New( "/tensorflow/core/mlir_shadow_run_failure", - "Failure count of MLIR shadow runs", "kind", "name"); + "Failure count of MLIR pass runs", "kind", "name"); + +auto* mlir_function_pass_failed_fallback = monitoring::Counter<0>::New( + "/tensorflow/core/mlir_pass_failed_fallback", + "Failure count of MLIR pass runs when fallback used"); static inline absl::string_view StringRefToView(llvm::StringRef ref) { return {ref.data(), ref.size()}; @@ -117,23 +121,46 @@ Status MlirFunctionOptimizationPass::Run( std::unique_ptr* graph, FunctionLibraryDefinition* flib_def, std::vector* control_ret_node_names, bool* control_rets_updated) { - // This tracks whether at least one pass is enabled, all passes are disabled, - // or there is a mix of disabled and shadow enabled passes. + // overall_state equals to: + // Enabled if at least one pass is Enabled. + // Disabled if all passes are Disabled. + // ShadowEnabled if all non Disabled passes are ShadowEnabled. + // FallbackEnabled if there are no Enabled passes and there is at least one + // FallbackEnabled pass. MlirOptimizationPassState overall_state = MlirOptimizationPassState::Disabled; // Cache per pass state and reuse it during pass execution. std::vector per_pass_state; per_pass_state.reserve(registry_->passes().size()); + int num_passes_enabled = 0, num_passes_disabled = 0, + num_passes_shadow_enabled = 0, num_passes_fallback_enabled = 0; for (const auto& pass_registration : registry_->passes()) { MlirOptimizationPassState pass_state = pass_registration.pass->GetPassState( &device_set, config_proto, **graph); per_pass_state.push_back(pass_state); - if (pass_state == MlirOptimizationPassState::ShadowEnabled && - overall_state == MlirOptimizationPassState::Disabled) { - overall_state = MlirOptimizationPassState::ShadowEnabled; - } else if (pass_state == MlirOptimizationPassState::Enabled) { - overall_state = MlirOptimizationPassState::Enabled; + switch (pass_state) { + case MlirOptimizationPassState::ShadowEnabled: { + if (overall_state == MlirOptimizationPassState::Disabled) + overall_state = MlirOptimizationPassState::ShadowEnabled; + ++num_passes_shadow_enabled; + break; + } + case MlirOptimizationPassState::FallbackEnabled: { + if (overall_state != MlirOptimizationPassState::Enabled) + overall_state = MlirOptimizationPassState::FallbackEnabled; + ++num_passes_fallback_enabled; + break; + } + case MlirOptimizationPassState::Enabled: { + overall_state = MlirOptimizationPassState::Enabled; + ++num_passes_enabled; + break; + } + case MlirOptimizationPassState::Disabled: { + ++num_passes_disabled; + break; + } } } @@ -147,15 +174,12 @@ Status MlirFunctionOptimizationPass::Run( return Status::OK(); } - if (overall_state == MlirOptimizationPassState::Enabled) { - LOG_FIRST_N(INFO, 1) << "At least one MLIR Graph Optimization Pass enabled" - << "(registered " << registry_->passes().size() - << " passes)"; - } else if (overall_state == MlirOptimizationPassState::ShadowEnabled) { - LOG_FIRST_N(INFO, 1) - << "All MLIR Graph Optimization Passes are shadow enabled" - << "(registered " << registry_->passes().size() << " passes)"; - } + LOG_FIRST_N(INFO, 1) << "MLIR Graph Optimization Passes." + << " Enabled: " << num_passes_enabled + << ", Disabled: " << num_passes_disabled + << ", ShadowEnabled: " << num_passes_shadow_enabled + << ", FallbackEnabled: " << num_passes_fallback_enabled + << ", Total: " << registry_->passes().size(); GraphDebugInfo debug_info; mlir::MLIRContext context; @@ -180,12 +204,13 @@ Status MlirFunctionOptimizationPass::Run( return module_ref_status.status(); } - shadow_run_failure->GetCell("graph_to_mlir", "")->IncrementBy(1); + mlir_function_optimization_pass_failure->GetCell("graph_to_mlir", "") + ->IncrementBy(1); // Do not fail, just keep the original TF graph unchanged in shadow mode. return Status::OK(); } - auto module_ref = std::move(module_ref_status.ValueOrDie()); + mlir::OwningModuleRef module_ref = std::move(module_ref_status.ValueOrDie()); AddDevicesToOp(*module_ref, &device_set); int per_pass_state_index = 0; @@ -207,22 +232,36 @@ Status MlirFunctionOptimizationPass::Run( overall_state == MlirOptimizationPassState::ShadowEnabled)) { pass_status = pass_registration.pass->Run(config_proto, *module_ref, **graph); - } else if (pass_state == MlirOptimizationPassState::ShadowEnabled) { - // Make sure that the pass does not modify MLIR module if it's shadow - // enabled. + } else if (pass_state == MlirOptimizationPassState::ShadowEnabled || + pass_state == MlirOptimizationPassState::FallbackEnabled) { + // Make sure when the pass is: + // ShadowEnabled, it does not modify the MLIR module. + // FallbackEnabled, it only modifies the MLIR module in case of + // no failures. auto module_ref_clone = module_ref->clone(); pass_status = pass_registration.pass->Run(config_proto, module_ref_clone, **graph); - module_ref_clone->destroy(); + if (pass_state == MlirOptimizationPassState::FallbackEnabled && + pass_status.ok()) { + module_ref = module_ref_clone; + } else { + module_ref_clone->destroy(); + } } if (!pass_status.ok()) { - // If pass failed and pass is: - // ShadowEnabled - only collect metrics, do not propagate - // error to the caller. + // If pass failed and it is: + // (Shadow|Fallback)Enabled - only collect metrics, do not propagate + // error to the caller. // Enabled - return error back to the caller. if (pass_state == MlirOptimizationPassState::ShadowEnabled) { - shadow_run_failure->GetCell("pass", name.str())->IncrementBy(1); + mlir_function_optimization_pass_failure->GetCell("pass", name.str()) + ->IncrementBy(1); + } else if (pass_state == MlirOptimizationPassState::FallbackEnabled) { + LOG(WARNING) << StringRefToView(name) + << " pass failed, continuing without the pass because the " + "pass has fallback enabled"; + mlir_function_pass_failed_fallback->GetCell()->IncrementBy(1); } else if (pass_state == MlirOptimizationPassState::Enabled) { return pass_status; } @@ -246,9 +285,10 @@ Status MlirFunctionOptimizationPass::Run( ConvertMlirToGraph(*module_ref, export_config, &empty_graph, &empty_flib, &control_ret_nodes); if (mlir_to_graph_status.ok()) { - shadow_run_success->GetCell()->IncrementBy(1); + mlir_function_optimization_pass_success->GetCell()->IncrementBy(1); } else { - shadow_run_failure->GetCell("mlir_to_graph", "")->IncrementBy(1); + mlir_function_optimization_pass_failure->GetCell("mlir_to_graph", "") + ->IncrementBy(1); } return Status::OK(); diff --git a/tensorflow/compiler/mlir/mlir_graph_optimization_pass.h b/tensorflow/compiler/mlir/mlir_graph_optimization_pass.h index f9eb6857a5c..32cba472507 100644 --- a/tensorflow/compiler/mlir/mlir_graph_optimization_pass.h +++ b/tensorflow/compiler/mlir/mlir_graph_optimization_pass.h @@ -29,7 +29,20 @@ namespace tensorflow { // MLIR passes running on Tensorflow function graphs (Tensorflow V2). // -------------------------------------------------------------------------- // -enum class MlirOptimizationPassState { Disabled, Enabled, ShadowEnabled }; +// Disabled - skip execution of the pass. +// Enabled - execute the pass, propagate errors to the caller if any. +// ShadowEnabled - execute the pass in a shadow mode. The pass should not commit +// any changes to the MLIR module it's processing. Failures are not propagated +// to the caller. +// FallbackEnabled - execute the pass and commit all the changes to the MLIR +// module in case of success. Do not commit any changes in case of failures, +// let the rest of the pipeline run. +enum class MlirOptimizationPassState { + Disabled, + Enabled, + ShadowEnabled, + FallbackEnabled +}; // An API for registering MLIR ModulePass with the Tensorflow runtime. These // passes are running only for function graphs built by Tensorflow V2 and diff --git a/tensorflow/compiler/mlir/mlir_graph_optimization_pass_test.cc b/tensorflow/compiler/mlir/mlir_graph_optimization_pass_test.cc index af78874b5c9..123f3a30e67 100644 --- a/tensorflow/compiler/mlir/mlir_graph_optimization_pass_test.cc +++ b/tensorflow/compiler/mlir/mlir_graph_optimization_pass_test.cc @@ -17,6 +17,7 @@ limitations under the License. #include +#include "mlir/IR/Builders.h" // from @llvm-project #include "tensorflow/core/platform/test.h" namespace tensorflow { @@ -39,6 +40,37 @@ class MockMlirOptimizationPass : public MlirOptimizationPass { mlir::ModuleOp module, const Graph& graph)); }; +class ModifyMlirModulePass : public MlirOptimizationPass { + public: + explicit ModifyMlirModulePass(Status run_status) : run_status_(run_status) {} + // MOCK_METHOD does not work on Windows build, using MOCK_CONST_METHODX + // instead. + MOCK_CONST_METHOD0(name, llvm::StringRef()); + MOCK_CONST_METHOD3(GetPassState, + MlirOptimizationPassState(const DeviceSet* device_set, + const ConfigProto& config_proto, + const Graph& graph)); + + // Just modify MLIR module so that we can check whether original TF graph + // has changed or not. + Status Run(const ConfigProto& config_proto, mlir::ModuleOp module, + const Graph& graph) override { + mlir::Builder b(module.getContext()); + auto producer = b.getNamedAttr("producer", b.getI32IntegerAttr(0)); + auto min_consumer = b.getNamedAttr("min_consumer", b.getI32IntegerAttr(0)); + auto bad_consumers = + b.getNamedAttr("bad_consumers", b.getI32ArrayAttr({1, 2, 3, 4})); + + module->setAttr("tf.versions", + b.getDictionaryAttr(llvm::ArrayRef( + {producer, min_consumer, bad_consumers}))); + + return run_status_; + } + + Status run_status_; +}; + class MlirGraphOptimizationPassTest : public Test { public: void Init(Status pass_run_result, @@ -61,18 +93,35 @@ class MlirGraphOptimizationPassTest : public Test { flib_.reset(new FunctionLibraryDefinition(graph_->flib_def())); } + void AddModuleModificationPass(MlirOptimizationPassState pass_state, + Status run_status) { + // Add FallbackEnabled pass that modifies the graph. + auto optimization_pass = + std::make_unique>(run_status); + ON_CALL(*optimization_pass, GetPassState(_, _, _)) + .WillByDefault(Return(pass_state)); + MlirOptimizationPassRegistry::Global().Add(10, + std::move(optimization_pass)); + } + void TearDown() override { MlirOptimizationPassRegistry::Global().ClearPasses(); } - void verifyGraphUnchanged(const GraphDef& original_graph_def) { + void verifyGraph(const GraphDef& original_graph_def, bool changed = false) { // Proto matchers might be unavailable in the OSS. #if defined(PLATFORM_GOOGLE) GraphDef resulted_graph_def; graph_->ToGraphDef(&resulted_graph_def); - EXPECT_THAT(resulted_graph_def, - ::testing::proto::IgnoringRepeatedFieldOrdering( - ::testing::EquivToProto(original_graph_def))); + + if (changed) + EXPECT_THAT(resulted_graph_def, + Not(::testing::proto::IgnoringRepeatedFieldOrdering( + ::testing::EquivToProto(original_graph_def)))); + else + EXPECT_THAT(resulted_graph_def, + ::testing::proto::IgnoringRepeatedFieldOrdering( + ::testing::EquivToProto(original_graph_def))); #endif } @@ -96,7 +145,7 @@ TEST_F(MlirGraphOptimizationPassTest, OptimizationPassFailsNoShadow) { device_set_, config_proto_, &graph_, flib_.get(), &control_ret_node_names_, &control_rets_updated_), Status(error::Code::ABORTED, "aborted")); - verifyGraphUnchanged(original_graph_def); + verifyGraph(original_graph_def); } TEST_F(MlirGraphOptimizationPassTest, OptimizationPassFailsShadow) { @@ -111,7 +160,7 @@ TEST_F(MlirGraphOptimizationPassTest, OptimizationPassFailsShadow) { device_set_, config_proto_, &graph_, flib_.get(), &control_ret_node_names_, &control_rets_updated_), Status::OK()); - verifyGraphUnchanged(original_graph_def); + verifyGraph(original_graph_def); } TEST_F(MlirGraphOptimizationPassTest, OptimizationPassDoesNotFailShadow) { @@ -125,7 +174,7 @@ TEST_F(MlirGraphOptimizationPassTest, OptimizationPassDoesNotFailShadow) { device_set_, config_proto_, &graph_, flib_.get(), &control_ret_node_names_, &control_rets_updated_), Status::OK()); - verifyGraphUnchanged(original_graph_def); + verifyGraph(original_graph_def); } TEST_F(MlirGraphOptimizationPassTest, @@ -141,6 +190,44 @@ TEST_F(MlirGraphOptimizationPassTest, device_set_, config_proto_, &graph_, flib_.get(), &control_ret_node_names_, &control_rets_updated_), Status(error::Code::ABORTED, "aborted")); + verifyGraph(original_graph_def); +} + +TEST_F(MlirGraphOptimizationPassTest, + OptimizationPassFailsShadowDisabledFallback) { + Init(Status(error::Code::ABORTED, "aborted"), + {MlirOptimizationPassState::Disabled, + MlirOptimizationPassState::ShadowEnabled, + MlirOptimizationPassState::FallbackEnabled}); + + GraphDef original_graph_def; + graph_->ToGraphDef(&original_graph_def); + AddModuleModificationPass(MlirOptimizationPassState::FallbackEnabled, + Status(error::Code::ABORTED, "aborted")); + + EXPECT_EQ(function_optimization_pass_.Run( + device_set_, config_proto_, &graph_, flib_.get(), + &control_ret_node_names_, &control_rets_updated_), + Status::OK()); + verifyGraph(original_graph_def); +} + +TEST_F(MlirGraphOptimizationPassTest, + OptimizationPassDoesNotFailShadowFallback) { + Init(Status::OK(), {MlirOptimizationPassState::ShadowEnabled, + MlirOptimizationPassState::FallbackEnabled}); + + GraphDef original_graph_def; + graph_->ToGraphDef(&original_graph_def); + + AddModuleModificationPass(MlirOptimizationPassState::FallbackEnabled, + Status::OK()); + EXPECT_EQ(function_optimization_pass_.Run( + device_set_, config_proto_, &graph_, flib_.get(), + &control_ret_node_names_, &control_rets_updated_), + Status::OK()); + + verifyGraph(original_graph_def, true); } TEST(MlirOptimizationPassRegistry, RegisterPassesWithTheSamePriorityFails) { diff --git a/tensorflow/compiler/tf2xla/mlir_bridge_pass.cc b/tensorflow/compiler/tf2xla/mlir_bridge_pass.cc index 7b47250fc51..273718ee810 100644 --- a/tensorflow/compiler/tf2xla/mlir_bridge_pass.cc +++ b/tensorflow/compiler/tf2xla/mlir_bridge_pass.cc @@ -76,12 +76,14 @@ MlirOptimizationPassState MlirBridgePass::GetPassState( MlirBridgeRolloutPolicy policy = GetMlirBridgeRolloutPolicy(graph, config_proto); - if (policy == MlirBridgeRolloutPolicy::kEnabledByUser) { - return MlirOptimizationPassState::Enabled; - } else if (policy == MlirBridgeRolloutPolicy::kEnabledAfterGraphAnalysis) { - return MlirOptimizationPassState::ShadowEnabled; - } else { - return MlirOptimizationPassState::Disabled; + switch (policy) { + case MlirBridgeRolloutPolicy::kEnabledByUser: + return MlirOptimizationPassState::Enabled; + case MlirBridgeRolloutPolicy::kEnabledAfterGraphAnalysis: + return MlirOptimizationPassState::ShadowEnabled; + case MlirBridgeRolloutPolicy::kDisabledByUser: + case MlirBridgeRolloutPolicy::kDisabledAfterGraphAnalysis: + return MlirOptimizationPassState::Disabled; } }