From 11d3a2d7f2eaf50bed0f2b399353a73e7c09ac5a Mon Sep 17 00:00:00 2001 From: Andy Ly Date: Tue, 11 Feb 2020 14:27:56 -0800 Subject: [PATCH] Add and register MLIR bridge v1 compat pass. This pass is similar to the MLIR bridge pass but is specific to TensorFlow V1 Graphs. In addition, this pass will only trigger on non function graphs as they are already handled by the MLIR bridge pass. PiperOrigin-RevId: 294521035 Change-Id: If80bb613b51ff9fd8b2dad61b4357a55452a4093 --- .../compiler/tf2xla/mlir_bridge_pass.cc | 41 ++++++++++++++++++- tensorflow/compiler/tf2xla/mlir_bridge_pass.h | 11 ++++- .../tf2xla/mlir_bridge_pass_registration.cc | 4 ++ 3 files changed, 54 insertions(+), 2 deletions(-) diff --git a/tensorflow/compiler/tf2xla/mlir_bridge_pass.cc b/tensorflow/compiler/tf2xla/mlir_bridge_pass.cc index 51b737e5144..b09f4d4eb8d 100644 --- a/tensorflow/compiler/tf2xla/mlir_bridge_pass.cc +++ b/tensorflow/compiler/tf2xla/mlir_bridge_pass.cc @@ -105,7 +105,6 @@ Status MlirBridgePass::Run(const DeviceSet& device_set, GraphImportConfig import_config; import_config.graph_as_function = true; import_config.control_outputs = *control_ret_node_names; - TF_ASSIGN_OR_RETURN(auto module_ref, ConvertGraphToMlir(**graph, debug_info, *flib_def, import_config, &context)); @@ -138,4 +137,44 @@ Status MlirBridgePass::Run(const DeviceSet& device_set, return Status::OK(); } +Status MlirBridgeV1CompatPass::Run( + const GraphOptimizationPassOptions& options) { + // Skip function graphs as MlirBridgePass will be used instead. + if (options.is_function_graph) return Status::OK(); + + if (!options.session_options->config.experimental().enable_mlir_bridge()) { + VLOG(1) << "Skipping MLIR Bridge V1 Compat Pass, session flag not enabled"; + return Status::OK(); + } + + VLOG(1) << "Running MLIR Bridge V1 Compat Pass"; + + GraphDebugInfo debug_info; + mlir::MLIRContext context; + GraphImportConfig import_config; + import_config.upgrade_legacy = true; + TF_ASSIGN_OR_RETURN( + auto module_ref, + ConvertGraphToMlir(**options.graph, debug_info, *options.flib_def, + import_config, &context)); + + AddDevicesToOp(*module_ref, options.device_set); + + if (VLOG_IS_ON(1)) DumpModule(*module_ref, "mlir_bridge_v1_compat_before_"); + + // Run the bridge now + TF_RETURN_IF_ERROR(mlir::TFTPU::TPUBridgeV1Compat( + *module_ref, /*enable_logging=*/VLOG_IS_ON(1))); + + if (VLOG_IS_ON(1)) DumpModule(*module_ref, "mlir_bridge_v1_compat_after_"); + + GraphExportConfig export_config; + TF_RETURN_WITH_CONTEXT_IF_ERROR( + ConvertMlirToGraph(*module_ref, export_config, options.graph, + options.flib_def), + "Error converting MLIR module back to graph"); + + return Status::OK(); +} + } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/mlir_bridge_pass.h b/tensorflow/compiler/tf2xla/mlir_bridge_pass.h index 58d6e54f367..e7f3fee79ca 100644 --- a/tensorflow/compiler/tf2xla/mlir_bridge_pass.h +++ b/tensorflow/compiler/tf2xla/mlir_bridge_pass.h @@ -17,11 +17,12 @@ limitations under the License. #define TENSORFLOW_COMPILER_TF2XLA_MLIR_BRIDGE_PASS_H_ #include "tensorflow/core/common_runtime/function_optimization_registry.h" +#include "tensorflow/core/common_runtime/optimization_registry.h" namespace tensorflow { // This pass uses MLIR to implement all the conversion steps to target XLA from -// a TensorFlow Graph. It is meant to expose a very limited set of +// a TensorFlow Function Graph. It is meant to expose a very limited set of // functionalities during the bring-up of MLIR-based bridge. class MlirBridgePass : public FunctionOptimizationPass { public: @@ -31,6 +32,14 @@ class MlirBridgePass : public FunctionOptimizationPass { bool* control_rets_updated) override; }; +// This pass uses MLIR to implement all the conversion steps to target XLA from +// a TensorFlow V1 Graph. It is meant to expose a very limited set of +// functionalities during the bring-up of MLIR-based bridge. +class MlirBridgeV1CompatPass : public GraphOptimizationPass { + public: + Status Run(const GraphOptimizationPassOptions& options) override; +}; + } // namespace tensorflow #endif // TENSORFLOW_COMPILER_TF2XLA_MLIR_BRIDGE_PASS_H_ diff --git a/tensorflow/compiler/tf2xla/mlir_bridge_pass_registration.cc b/tensorflow/compiler/tf2xla/mlir_bridge_pass_registration.cc index 32820790677..ac6e54d4e76 100644 --- a/tensorflow/compiler/tf2xla/mlir_bridge_pass_registration.cc +++ b/tensorflow/compiler/tf2xla/mlir_bridge_pass_registration.cc @@ -17,10 +17,14 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/mlir_bridge_pass.h" #include "tensorflow/core/common_runtime/function_optimization_registry.h" +#include "tensorflow/core/common_runtime/optimization_registry.h" namespace tensorflow { static function_optimization_registration::FunctionOptimizationPassRegistration register_mlir_bridge_pass(std::make_unique()); +REGISTER_OPTIMIZATION(OptimizationPassRegistry::PRE_PLACEMENT, 0, + MlirBridgeV1CompatPass); + } // namespace tensorflow