From 891375f55f57e0cf67cdb8df930702c617a60599 Mon Sep 17 00:00:00 2001 From: Eugene Zhulenev Date: Tue, 10 Mar 2020 16:01:28 -0700 Subject: [PATCH] [TF:MLIR] Add generic MLIR pass registration mechanism. PiperOrigin-RevId: 300198093 Change-Id: I462e5c4096519f382271e4cc6734b28fa7fd2034 --- tensorflow/compiler/mlir/BUILD | 32 +++ .../mlir/mlir_graph_optimization_pass.cc | 211 ++++++++++++++++++ .../mlir/mlir_graph_optimization_pass.h | 179 +++++++++++++++ ...ir_graph_optimization_pass_registration.cc | 30 +++ tensorflow/compiler/mlir/tensorflow/BUILD | 22 ++ .../transforms/graph_optimization_pass.cc | 33 +++ .../transforms/graph_optimization_pass.h | 38 ++++ .../graph_optimization_pass_registration.cc | 30 +++ tensorflow/compiler/tf2xla/BUILD | 2 + .../compiler/tf2xla/mlir_bridge_pass.cc | 130 +---------- tensorflow/compiler/tf2xla/mlir_bridge_pass.h | 32 ++- .../tf2xla/mlir_bridge_pass_registration.cc | 15 +- tensorflow/core/protobuf/config.proto | 7 + ...nsorflow.-config-proto.-experimental.pbtxt | 6 + .../golden/v1/tensorflow.-config-proto.pbtxt | 6 + 15 files changed, 635 insertions(+), 138 deletions(-) create mode 100644 tensorflow/compiler/mlir/mlir_graph_optimization_pass.cc create mode 100644 tensorflow/compiler/mlir/mlir_graph_optimization_pass.h create mode 100644 tensorflow/compiler/mlir/mlir_graph_optimization_pass_registration.cc create mode 100644 tensorflow/compiler/mlir/tensorflow/transforms/graph_optimization_pass.cc create mode 100644 tensorflow/compiler/mlir/tensorflow/transforms/graph_optimization_pass.h create mode 100644 tensorflow/compiler/mlir/tensorflow/transforms/graph_optimization_pass_registration.cc diff --git a/tensorflow/compiler/mlir/BUILD b/tensorflow/compiler/mlir/BUILD index ef4ddb619a8..0bc34780cd1 100644 --- a/tensorflow/compiler/mlir/BUILD +++ b/tensorflow/compiler/mlir/BUILD @@ -102,6 +102,38 @@ cc_library( ], ) +cc_library( + name = "mlir_graph_optimization_pass", + srcs = ["mlir_graph_optimization_pass.cc"], + hdrs = ["mlir_graph_optimization_pass.h"], + deps = [ + "//tensorflow/compiler/mlir/tensorflow", + "//tensorflow/compiler/mlir/tensorflow:convert_graphdef", + "//tensorflow/compiler/mlir/tensorflow:device_util", + "//tensorflow/compiler/mlir/tensorflow:dump_mlir_util", + "//tensorflow/compiler/mlir/tensorflow:mlir_roundtrip_flags", + "//tensorflow/core:core_cpu", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/strings", + "@llvm-project//llvm:support", + "@llvm-project//mlir:IR", + ], + alwayslink = 1, +) + +cc_library( + name = "mlir_graph_optimization_pass_registration", + srcs = [ + "mlir_graph_optimization_pass_registration.cc", + ], + deps = [ + ":mlir_graph_optimization_pass", + "//tensorflow/core:core_cpu", + ], + alwayslink = 1, +) + tf_cc_binary( name = "tf-opt", deps = [ diff --git a/tensorflow/compiler/mlir/mlir_graph_optimization_pass.cc b/tensorflow/compiler/mlir/mlir_graph_optimization_pass.cc new file mode 100644 index 00000000000..29c506f93fd --- /dev/null +++ b/tensorflow/compiler/mlir/mlir_graph_optimization_pass.cc @@ -0,0 +1,211 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/mlir/mlir_graph_optimization_pass.h" + +#include + +#include "absl/container/flat_hash_set.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/Support/FormatVariadic.h" +#include "llvm/Support/raw_os_ostream.h" +#include "tensorflow/compiler/mlir/tensorflow/translate/export_graphdef.h" +#include "tensorflow/compiler/mlir/tensorflow/translate/import_model.h" +#include "tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_flags.h" +#include "tensorflow/compiler/mlir/tensorflow/utils/device_util.h" +#include "tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util.h" +#include "tensorflow/core/graph/graph_constructor.h" +#include "tensorflow/core/public/session_options.h" + +namespace tensorflow { + +// Dumps the MLIR module to disk. +// This require the TF_DUMP_GRAPH_PREFIX to be set to a path that exist (or can +// be created). +static void DumpModule(mlir::ModuleOp module, std::string file_prefix) { + std::string prefix = GetDumpDirFromEnvVar(); + if (prefix.empty()) return; + + auto* env = tensorflow::Env::Default(); + auto status = env->RecursivelyCreateDir(prefix); + if (!status.ok()) { + LOG(WARNING) << "cannot create directory '" + prefix + + "': " + status.error_message(); + return; + } + + prefix += "/" + file_prefix; + if (!tensorflow::Env::Default()->CreateUniqueFileName(&prefix, ".mlir")) { + LOG(WARNING) << "cannot create unique filename, won't dump MLIR module."; + return; + } + + std::unique_ptr file_writer; + status = env->NewWritableFile(prefix, &file_writer); + if (!status.ok()) { + LOG(WARNING) << "cannot open file '" + prefix + + "': " + status.error_message(); + return; + } + + // Print the module to a string before writing to the file. + std::string txt_module; + { + llvm::raw_string_ostream os(txt_module); + module.print(os); + } + + status = file_writer->Append(txt_module); + if (!status.ok()) { + LOG(WARNING) << "error writing to file '" + prefix + + "': " + status.error_message(); + return; + } + (void)file_writer->Close(); + VLOG(1) << "Dumped MLIR module to " << prefix; +} + +MlirOptimizationPassRegistry& MlirOptimizationPassRegistry::Global() { + static auto* global = new MlirOptimizationPassRegistry(); + return *global; +} + +Status MlirFunctionOptimizationPass::Run( + const DeviceSet& device_set, const ConfigProto& config_proto, + std::unique_ptr* graph, FunctionLibraryDefinition* flib_def, + std::vector* control_ret_node_names, + bool* control_rets_updated) { + // Skip conversion from Graph to MLIR if none of the passes are enabled. + const bool is_enabled = + llvm::any_of(registry_->passes(), [&](auto& pass_registration) -> bool { + return pass_registration.pass->IsEnabled(config_proto); + }); + + if (!is_enabled) { + VLOG(1) << "None of the MLIR optimization passes are enabled " + << "(registered " << registry_->passes().size() << ")"; + return Status::OK(); + } + + VLOG(1) << "Running MLIR Graph Optimization Passes " + << "(registered " << registry_->passes().size() << " passes)"; + + GraphDebugInfo debug_info; + mlir::MLIRContext context; + GraphImportConfig import_config; + import_config.graph_as_function = true; + import_config.control_outputs = *control_ret_node_names; + TF_ASSIGN_OR_RETURN(auto module_ref, + ConvertGraphToMlir(**graph, debug_info, *flib_def, + import_config, &context)); + + AddDevicesToOp(*module_ref, &device_set); + + for (auto& pass_registration : registry_->passes()) { + llvm::StringRef name = pass_registration.pass->name(); + VLOG(2) << "Run MLIR graph optimization pass: " << absl::string_view(name); + + if (VLOG_IS_ON(1)) { + DumpModule(*module_ref, llvm::formatv("mlir_{0}_before_", name)); + } + + TF_RETURN_IF_ERROR(pass_registration.pass->Run(config_proto, *module_ref)); + + if (VLOG_IS_ON(1)) { + DumpModule(*module_ref, llvm::formatv("mlir_{0}_after_", name)); + } + } + + GraphExportConfig export_config; + export_config.graph_as_function = true; + absl::flat_hash_set control_ret_nodes; + TF_RETURN_WITH_CONTEXT_IF_ERROR( + ConvertMlirToGraph(*module_ref, export_config, graph, flib_def, + &control_ret_nodes), + "Error converting MLIR module back to graph"); + + control_ret_node_names->clear(); + control_ret_node_names->reserve(control_ret_nodes.size()); + for (const auto* node : control_ret_nodes) + control_ret_node_names->push_back(node->name()); + + *control_rets_updated = true; + + return Status::OK(); +} + +MlirV1CompatOptimizationPassRegistry& +MlirV1CompatOptimizationPassRegistry::Global() { + static auto* global = new MlirV1CompatOptimizationPassRegistry(); + return *global; +} + +Status MlirV1CompatGraphOptimizationPass::Run( + const GraphOptimizationPassOptions& options) { + // Skip function graphs as MlirOptimizationPassRegistry_ will be used instead. + if (options.is_function_graph) return Status::OK(); + + // Skip conversion from Graph to MLIR if none of the passes are enabled. + const bool is_enabled = + absl::c_any_of(registry_->passes(), [&](auto& pass_registration) -> bool { + return pass_registration.pass->IsEnabled( + options.session_options->config); + }); + + if (!is_enabled) { + VLOG(1) << "None of the MLIR optimization passes are enabled " + << "(registered" << registry_->passes().size() << " passes)"; + return Status::OK(); + } + + VLOG(1) << "Running MLIR Graph Optimization V1 Compat Passes " + << "(registered" << registry_->passes().size() << " passes)"; + + GraphDebugInfo debug_info; + mlir::MLIRContext context; + GraphImportConfig import_config; + import_config.upgrade_legacy = true; + TF_ASSIGN_OR_RETURN( + auto module_ref, + ConvertGraphToMlir(**options.graph, debug_info, *options.flib_def, + import_config, &context)); + + AddDevicesToOp(*module_ref, options.device_set); + + for (auto& pass_registration : registry_->passes()) { + absl::string_view name = pass_registration.pass->name(); + VLOG(2) << "Run MLIR graph optimization pass: " << name; + + if (VLOG_IS_ON(1)) { + DumpModule(*module_ref, llvm::formatv("mlir_{0}_before_", name)); + } + + TF_RETURN_IF_ERROR(pass_registration.pass->Run(options, *module_ref)); + + if (VLOG_IS_ON(1)) { + DumpModule(*module_ref, llvm::formatv("mlir_{0}_after_", name)); + } + } + + GraphExportConfig export_config; + TF_RETURN_WITH_CONTEXT_IF_ERROR( + ConvertMlirToGraph(*module_ref, export_config, options.graph, + options.flib_def), + "Error converting MLIR module back to graph"); + + return Status::OK(); +} + +} // namespace tensorflow diff --git a/tensorflow/compiler/mlir/mlir_graph_optimization_pass.h b/tensorflow/compiler/mlir/mlir_graph_optimization_pass.h new file mode 100644 index 00000000000..aed5307d39d --- /dev/null +++ b/tensorflow/compiler/mlir/mlir_graph_optimization_pass.h @@ -0,0 +1,179 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_MLIR_MLIR_GRAPH_OPTIMIZATION_PASS_H_ +#define TENSORFLOW_COMPILER_MLIR_MLIR_GRAPH_OPTIMIZATION_PASS_H_ + +#include "mlir/IR/Module.h" // TF:llvm-project +#include "tensorflow/core/common_runtime/function_optimization_registry.h" +#include "tensorflow/core/common_runtime/optimization_registry.h" + +namespace tensorflow { + +// -------------------------------------------------------------------------- // +// MLIR passes running on Tensorflow function graphs (Tensorflow V2). +// -------------------------------------------------------------------------- // + +// An API for registering MLIR ModulePass with the Tensorflow runtime. These +// passes are running only for function graphs built by Tensorflow V2 and +// instantiated by the process_function_library_runtime (see +// FunctionOptimizationPass for details). +class MlirOptimizationPass { + public: + virtual ~MlirOptimizationPass() = default; + virtual llvm::StringRef name() const = 0; + virtual bool IsEnabled(const ConfigProto& config_proto) const = 0; + + virtual Status Run(const ConfigProto& config_proto, + mlir::ModuleOp module) = 0; +}; + +class MlirOptimizationPassRegistry { + public: + struct PassRegistration { + int priority; + std::unique_ptr pass; + }; + + struct PriorityComparator { + bool operator()(const PassRegistration& x, + const PassRegistration& y) const { + return x.priority < y.priority; + } + }; + + using Passes = std::set; + + // Returns the global registry of MLIR optimization passes. + static MlirOptimizationPassRegistry& Global(); + + void Add(int priority, std::unique_ptr pass) { + passes_.insert({priority, std::move(pass)}); + } + + const Passes& passes() const { return passes_; } + + private: + Passes passes_; +}; + +// Function optimization pass that runs all MLIR passes registered in +// MlirOptimizationPassRegistry. +class MlirFunctionOptimizationPass : public FunctionOptimizationPass { + public: + explicit MlirFunctionOptimizationPass( + const MlirOptimizationPassRegistry* registry = + &MlirOptimizationPassRegistry::Global()) + : registry_(registry) {} + + Status Run(const DeviceSet& device_set, const ConfigProto& config_proto, + std::unique_ptr* graph, FunctionLibraryDefinition* flib_def, + std::vector* control_ret_node_names, + bool* control_rets_updated) override; + + private: + const MlirOptimizationPassRegistry* registry_; +}; + +// -------------------------------------------------------------------------- // +// MLIR passes running on Tensorflow V1 graphs. +// -------------------------------------------------------------------------- // + +// An API for registering MLIR ModulePass with the Tensorflow runtime. These +// passes are running only for V1 graphs (legacy graphs) executed via Session +// runtime. Graph importer updates legacy graph behavior to V2 constructs (e.g. +// it raises control flow from Switch/Merge nodes to functional control flow +// with If/While operations). +class MlirV1CompatOptimizationPass { + public: + virtual ~MlirV1CompatOptimizationPass() = default; + virtual llvm::StringRef name() const = 0; + virtual bool IsEnabled(const ConfigProto& config_proto) const = 0; + + virtual Status Run(const GraphOptimizationPassOptions& options, + mlir::ModuleOp module) = 0; +}; + +class MlirV1CompatOptimizationPassRegistry { + public: + struct PassRegistration { + int priority; + std::unique_ptr pass; + }; + + struct PriorityComparator { + bool operator()(const PassRegistration& x, + const PassRegistration& y) const { + return x.priority < y.priority; + } + }; + + using Passes = std::set; + + // Returns the global registry of MLIR optimization passes. + static MlirV1CompatOptimizationPassRegistry& Global(); + + void Add(int priority, std::unique_ptr pass) { + passes_.insert({priority, std::move(pass)}); + } + + const Passes& passes() const { return passes_; } + + private: + Passes passes_; +}; + +class MlirV1CompatGraphOptimizationPass : public GraphOptimizationPass { + public: + explicit MlirV1CompatGraphOptimizationPass( + const MlirV1CompatOptimizationPassRegistry* registry = + &MlirV1CompatOptimizationPassRegistry::Global()) + : registry_(registry) {} + + Status Run(const GraphOptimizationPassOptions& options) override; + + private: + const MlirV1CompatOptimizationPassRegistry* registry_; +}; + +// -------------------------------------------------------------------------- // +// Helper classes for static registration of MLIR (V1 Compat) passes in the +// corresponding registry. +// -------------------------------------------------------------------------- // + +namespace mlir_pass_registration { + +class MlirOptimizationPassRegistration { + public: + explicit MlirOptimizationPassRegistration( + int priority, std::unique_ptr pass) { + MlirOptimizationPassRegistry::Global().Add(priority, std::move(pass)); + } +}; + +class MlirV1CompatOptimizationPassRegistration { + public: + explicit MlirV1CompatOptimizationPassRegistration( + int priority, std::unique_ptr pass) { + MlirV1CompatOptimizationPassRegistry::Global().Add(priority, + std::move(pass)); + } +}; + +} // namespace mlir_pass_registration + +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_MLIR_MLIR_GRAPH_OPTIMIZATION_PASS_H_ diff --git a/tensorflow/compiler/mlir/mlir_graph_optimization_pass_registration.cc b/tensorflow/compiler/mlir/mlir_graph_optimization_pass_registration.cc new file mode 100644 index 00000000000..8155af6505e --- /dev/null +++ b/tensorflow/compiler/mlir/mlir_graph_optimization_pass_registration.cc @@ -0,0 +1,30 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include + +#include "tensorflow/compiler/mlir/mlir_graph_optimization_pass.h" +#include "tensorflow/core/common_runtime/function_optimization_registry.h" +#include "tensorflow/core/common_runtime/optimization_registry.h" + +namespace tensorflow { + +static function_optimization_registration::FunctionOptimizationPassRegistration + register_mlir_passes(std::make_unique()); + +REGISTER_OPTIMIZATION(OptimizationPassRegistry::PRE_PLACEMENT, 0, + MlirV1CompatGraphOptimizationPass); + +} // namespace tensorflow diff --git a/tensorflow/compiler/mlir/tensorflow/BUILD b/tensorflow/compiler/mlir/tensorflow/BUILD index d9dbc66f77a..50c5d0b40b9 100644 --- a/tensorflow/compiler/mlir/tensorflow/BUILD +++ b/tensorflow/compiler/mlir/tensorflow/BUILD @@ -428,6 +428,28 @@ cc_library( alwayslink = 1, ) +cc_library( + name = "graph_optimization_pass", + srcs = ["transforms/graph_optimization_pass.cc"], + hdrs = ["transforms/graph_optimization_pass.h"], + deps = [ + ":tensorflow_passes", + "//tensorflow/compiler/mlir:mlir_graph_optimization_pass", + ], + alwayslink = 1, +) + +cc_library( + name = "graph_optimization_pass_registration", + srcs = ["transforms/graph_optimization_pass_registration.cc"], + deps = [ + ":graph_optimization_pass", + "//tensorflow/compiler/mlir:mlir_graph_optimization_pass", + "//tensorflow/compiler/mlir:mlir_graph_optimization_pass_registration", + ], + alwayslink = 1, +) + # Library with TensorFlow dialect static initialization. cc_library( name = "tensorflow_dialect_registration", diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/graph_optimization_pass.cc b/tensorflow/compiler/mlir/tensorflow/transforms/graph_optimization_pass.cc new file mode 100644 index 00000000000..281a6011af6 --- /dev/null +++ b/tensorflow/compiler/mlir/tensorflow/transforms/graph_optimization_pass.cc @@ -0,0 +1,33 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/mlir/tensorflow/transforms/graph_optimization_pass.h" + +namespace tensorflow { + +Status MlirGraphOptimizationPass::Run(const ConfigProto& config_proto, + mlir::ModuleOp module) { + if (!config_proto.experimental().enable_mlir_graph_optimization()) { + VLOG(1) << "Skipping MLIR Graph Optimization Pass" + << ", session flag not enabled"; + return Status::OK(); + } + + // TODO(ezhulenev): Add something here. + + return Status::OK(); +} + +} // namespace tensorflow diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/graph_optimization_pass.h b/tensorflow/compiler/mlir/tensorflow/transforms/graph_optimization_pass.h new file mode 100644 index 00000000000..955da470494 --- /dev/null +++ b/tensorflow/compiler/mlir/tensorflow/transforms/graph_optimization_pass.h @@ -0,0 +1,38 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSFORMS_GRAPH_OPTIMIZATION_PASS_H_ +#define TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSFORMS_GRAPH_OPTIMIZATION_PASS_H_ + +#include "tensorflow/compiler/mlir/mlir_graph_optimization_pass.h" + +namespace tensorflow { + +// Bundle generic MLIR graph optimization passes (some derived from TF Grappler +// graph optimizers) into a single MLIR optimization pass. +class MlirGraphOptimizationPass : public MlirOptimizationPass { + public: + llvm::StringRef name() const override { return "graph_optimization"; } + + bool IsEnabled(const ConfigProto& config_proto) const override { + return config_proto.experimental().enable_mlir_graph_optimization(); + } + + Status Run(const ConfigProto& config_proto, mlir::ModuleOp module) override; +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSFORMS_GRAPH_OPTIMIZATION_PASS_H_ diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/graph_optimization_pass_registration.cc b/tensorflow/compiler/mlir/tensorflow/transforms/graph_optimization_pass_registration.cc new file mode 100644 index 00000000000..4681f8a0f33 --- /dev/null +++ b/tensorflow/compiler/mlir/tensorflow/transforms/graph_optimization_pass_registration.cc @@ -0,0 +1,30 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include + +#include "tensorflow/compiler/mlir/tensorflow/transforms/graph_optimization_pass.h" + +namespace tensorflow { +namespace { +constexpr int kMlirGraphOptimizationPriority = 0; +} + +static mlir_pass_registration::MlirOptimizationPassRegistration + register_mlir_graph_optimization_pass( + kMlirGraphOptimizationPriority, + std::make_unique()); + +} // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/BUILD b/tensorflow/compiler/tf2xla/BUILD index 8b382240ef1..d13c817d9f1 100644 --- a/tensorflow/compiler/tf2xla/BUILD +++ b/tensorflow/compiler/tf2xla/BUILD @@ -698,6 +698,7 @@ cc_library( srcs = ["mlir_bridge_pass.cc"], hdrs = ["mlir_bridge_pass.h"], deps = [ + "//tensorflow/compiler/mlir:mlir_graph_optimization_pass", "//tensorflow/compiler/mlir/tensorflow", "//tensorflow/compiler/mlir/tensorflow:convert_graphdef", "//tensorflow/compiler/mlir/tensorflow:device_util", @@ -717,6 +718,7 @@ cc_library( ], deps = [ ":mlir_bridge_pass", + "//tensorflow/compiler/mlir:mlir_graph_optimization_pass_registration", "//tensorflow/core:core_cpu", ], alwayslink = 1, diff --git a/tensorflow/compiler/tf2xla/mlir_bridge_pass.cc b/tensorflow/compiler/tf2xla/mlir_bridge_pass.cc index 7ac4cb8fb06..6d0d569724f 100644 --- a/tensorflow/compiler/tf2xla/mlir_bridge_pass.cc +++ b/tensorflow/compiler/tf2xla/mlir_bridge_pass.cc @@ -17,125 +17,32 @@ limitations under the License. #include -#include "absl/container/flat_hash_set.h" -#include "llvm/ADT/StringRef.h" -#include "llvm/Support/raw_os_ostream.h" #include "tensorflow/compiler/mlir/tensorflow/transforms/bridge.h" -#include "tensorflow/compiler/mlir/tensorflow/translate/export_graphdef.h" -#include "tensorflow/compiler/mlir/tensorflow/translate/import_model.h" -#include "tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_flags.h" -#include "tensorflow/compiler/mlir/tensorflow/utils/device_util.h" -#include "tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util.h" -#include "tensorflow/core/graph/graph_constructor.h" #include "tensorflow/core/public/session_options.h" namespace tensorflow { -// Dumps the MLIR module to disk. -// This require the TF_DUMP_GRAPH_PREFIX to be set to a path that exist (or can -// be created). -static void DumpModule(mlir::ModuleOp module, llvm::StringRef file_prefix) { - std::string prefix = GetDumpDirFromEnvVar(); - if (prefix.empty()) { - return; - } - - auto* env = tensorflow::Env::Default(); - auto status = env->RecursivelyCreateDir(prefix); - if (!status.ok()) { - LOG(WARNING) << "cannot create directory '" + prefix + - "': " + status.error_message(); - return; - } - prefix += "/" + file_prefix.str(); - if (!tensorflow::Env::Default()->CreateUniqueFileName(&prefix, ".mlir")) { - LOG(WARNING) << "cannot create unique filename, won't dump MLIR module."; - return; - } - - std::unique_ptr file_writer; - status = env->NewWritableFile(prefix, &file_writer); - if (!status.ok()) { - LOG(WARNING) << "cannot open file '" + prefix + - "': " + status.error_message(); - return; - } - - // Print the module to a string before writing to the file. - std::string txt_module; - { - llvm::raw_string_ostream os(txt_module); - module.print(os); - } - - status = file_writer->Append(txt_module); - if (!status.ok()) { - LOG(WARNING) << "error writing to file '" + prefix + - "': " + status.error_message(); - return; - } - (void)file_writer->Close(); - VLOG(1) << "Dumped MLIR module to " << prefix; -} - // This runs the first phase of the "bridge", transforming the graph in a form // that can be executed with delegation of some computations to an accelerator. // This builds on the model of XLA where a subset of the graph is encapsulated // and attached to a "compile" operation, whose result is fed to an "execute" // operation. The kernel for these operations is responsible to lower the // encapsulated graph to a particular device. -Status MlirBridgePass::Run(const DeviceSet& device_set, - const ConfigProto& config_proto, - std::unique_ptr* graph, - FunctionLibraryDefinition* flib_def, - std::vector* control_ret_node_names, - bool* control_rets_updated) { +Status MlirBridgePass::Run(const ConfigProto& config_proto, + mlir::ModuleOp module) { if (!config_proto.experimental().enable_mlir_bridge()) { VLOG(1) << "Skipping MLIR Bridge Pass, session flag not enabled"; return Status::OK(); } VLOG(1) << "Running MLIR Bridge Pass"; - - GraphDebugInfo debug_info; - mlir::MLIRContext context; - GraphImportConfig import_config; - import_config.graph_as_function = true; - import_config.control_outputs = *control_ret_node_names; - TF_ASSIGN_OR_RETURN(auto module_ref, - ConvertGraphToMlir(**graph, debug_info, *flib_def, - import_config, &context)); - - AddDevicesToOp(*module_ref, &device_set); - - if (VLOG_IS_ON(1)) DumpModule(*module_ref, "mlir_bridge_before_"); - - // Run the bridge now TF_RETURN_IF_ERROR( - mlir::TFTPU::TPUBridge(*module_ref, /*enable_logging=*/VLOG_IS_ON(1))); - - if (VLOG_IS_ON(1)) DumpModule(*module_ref, "mlir_bridge_after_"); - - GraphExportConfig export_config; - export_config.graph_as_function = true; - absl::flat_hash_set control_ret_nodes; - TF_RETURN_WITH_CONTEXT_IF_ERROR( - ConvertMlirToGraph(*module_ref, export_config, graph, flib_def, - &control_ret_nodes), - "Error converting MLIR module back to graph"); - - control_ret_node_names->clear(); - control_ret_node_names->reserve(control_ret_nodes.size()); - for (const auto* node : control_ret_nodes) - control_ret_node_names->push_back(node->name()); - - *control_rets_updated = true; + mlir::TFTPU::TPUBridge(module, /*enable_logging=*/VLOG_IS_ON(1))); return Status::OK(); } - -Status MlirBridgeV1CompatPass::Run( - const GraphOptimizationPassOptions& options) { +Status MlirBridgeV1CompatPass::Run(const GraphOptimizationPassOptions& options, + mlir::ModuleOp module) { // Skip function graphs as MlirBridgePass will be used instead. if (options.is_function_graph) return Status::OK(); @@ -145,31 +52,8 @@ Status MlirBridgeV1CompatPass::Run( } VLOG(1) << "Running MLIR Bridge V1 Compat Pass"; - - GraphDebugInfo debug_info; - mlir::MLIRContext context; - GraphImportConfig import_config; - import_config.upgrade_legacy = true; - TF_ASSIGN_OR_RETURN( - auto module_ref, - ConvertGraphToMlir(**options.graph, debug_info, *options.flib_def, - import_config, &context)); - - AddDevicesToOp(*module_ref, options.device_set); - - if (VLOG_IS_ON(1)) DumpModule(*module_ref, "mlir_bridge_v1_compat_before_"); - - // Run the bridge now - TF_RETURN_IF_ERROR(mlir::TFTPU::TPUBridgeV1Compat( - *module_ref, /*enable_logging=*/VLOG_IS_ON(1))); - - if (VLOG_IS_ON(1)) DumpModule(*module_ref, "mlir_bridge_v1_compat_after_"); - - GraphExportConfig export_config; - TF_RETURN_WITH_CONTEXT_IF_ERROR( - ConvertMlirToGraph(*module_ref, export_config, options.graph, - options.flib_def), - "Error converting MLIR module back to graph"); + TF_RETURN_IF_ERROR( + mlir::TFTPU::TPUBridgeV1Compat(module, /*enable_logging=*/VLOG_IS_ON(1))); return Status::OK(); } diff --git a/tensorflow/compiler/tf2xla/mlir_bridge_pass.h b/tensorflow/compiler/tf2xla/mlir_bridge_pass.h index e7f3fee79ca..b7f8ef203f7 100644 --- a/tensorflow/compiler/tf2xla/mlir_bridge_pass.h +++ b/tensorflow/compiler/tf2xla/mlir_bridge_pass.h @@ -16,28 +16,42 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_TF2XLA_MLIR_BRIDGE_PASS_H_ #define TENSORFLOW_COMPILER_TF2XLA_MLIR_BRIDGE_PASS_H_ -#include "tensorflow/core/common_runtime/function_optimization_registry.h" -#include "tensorflow/core/common_runtime/optimization_registry.h" +#include "llvm/ADT/StringRef.h" +#include "tensorflow/compiler/mlir/mlir_graph_optimization_pass.h" namespace tensorflow { // This pass uses MLIR to implement all the conversion steps to target XLA from // a TensorFlow Function Graph. It is meant to expose a very limited set of // functionalities during the bring-up of MLIR-based bridge. -class MlirBridgePass : public FunctionOptimizationPass { +class MlirBridgePass : public MlirOptimizationPass { public: - Status Run(const DeviceSet& device_set, const ConfigProto& config_proto, - std::unique_ptr* graph, FunctionLibraryDefinition* flib_def, - std::vector* control_ret_node_names, - bool* control_rets_updated) override; + llvm::StringRef name() const override { return "bridge"; } + + bool IsEnabled(const ConfigProto& config_proto) const override { + return config_proto.experimental().enable_mlir_bridge(); + } + + // This should be used as a thin mapper around mlir::ModulePass::runOnModule + // API integrated with the Tensorflow runtime. + Status Run(const ConfigProto& config_proto, mlir::ModuleOp module) override; }; // This pass uses MLIR to implement all the conversion steps to target XLA from // a TensorFlow V1 Graph. It is meant to expose a very limited set of // functionalities during the bring-up of MLIR-based bridge. -class MlirBridgeV1CompatPass : public GraphOptimizationPass { +class MlirBridgeV1CompatPass : public MlirV1CompatOptimizationPass { public: - Status Run(const GraphOptimizationPassOptions& options) override; + llvm::StringRef name() const override { return "bridge"; } + + bool IsEnabled(const ConfigProto& config_proto) const override { + return config_proto.experimental().enable_mlir_bridge(); + } + + // This should be used as a thin mapper around mlir::ModulePass::runOnModule + // API integrated with the Tensorflow runtime. + Status Run(const GraphOptimizationPassOptions& options, + mlir::ModuleOp module) override; }; } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/mlir_bridge_pass_registration.cc b/tensorflow/compiler/tf2xla/mlir_bridge_pass_registration.cc index ac6e54d4e76..21791ff4427 100644 --- a/tensorflow/compiler/tf2xla/mlir_bridge_pass_registration.cc +++ b/tensorflow/compiler/tf2xla/mlir_bridge_pass_registration.cc @@ -16,15 +16,18 @@ limitations under the License. #include #include "tensorflow/compiler/tf2xla/mlir_bridge_pass.h" -#include "tensorflow/core/common_runtime/function_optimization_registry.h" -#include "tensorflow/core/common_runtime/optimization_registry.h" namespace tensorflow { +namespace { +constexpr int kMlirBridgePriority = 10; +} -static function_optimization_registration::FunctionOptimizationPassRegistration - register_mlir_bridge_pass(std::make_unique()); +static mlir_pass_registration::MlirOptimizationPassRegistration + register_mlir_bridge_pass(kMlirBridgePriority, + std::make_unique()); -REGISTER_OPTIMIZATION(OptimizationPassRegistry::PRE_PLACEMENT, 0, - MlirBridgeV1CompatPass); +static mlir_pass_registration::MlirV1CompatOptimizationPassRegistration + register_v1_compat_mlir_bridge_pass( + kMlirBridgePriority, std::make_unique()); } // namespace tensorflow diff --git a/tensorflow/core/protobuf/config.proto b/tensorflow/core/protobuf/config.proto index df2eed45900..7973e002762 100644 --- a/tensorflow/core/protobuf/config.proto +++ b/tensorflow/core/protobuf/config.proto @@ -569,6 +569,13 @@ message ConfigProto { // to lower the encapsulated graph to a particular device. bool enable_mlir_bridge = 13; + // Whether to enable the MLIR-based Graph optimizations. + // + // This will become a part of standard Tensorflow graph optimization + // pipeline, currently this is only used for gradual migration and testing + // new passes that are replacing existing optimizations in Grappler. + bool enable_mlir_graph_optimization = 16; + // If true, the session will not store an additional copy of the graph for // each subgraph. // diff --git a/tensorflow/tools/api/golden/v1/tensorflow.-config-proto.-experimental.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.-config-proto.-experimental.pbtxt index cde90e76f5d..3d8187ca752 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.-config-proto.-experimental.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.-config-proto.-experimental.pbtxt @@ -75,6 +75,12 @@ tf_proto { label: LABEL_OPTIONAL type: TYPE_BOOL } + field { + name: "enable_mlir_graph_optimization" + number: 16 + label: LABEL_OPTIONAL + type: TYPE_BOOL + } field { name: "disable_output_partition_graphs" number: 14 diff --git a/tensorflow/tools/api/golden/v1/tensorflow.-config-proto.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.-config-proto.pbtxt index 1309eb79938..c32fdee5af0 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.-config-proto.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.-config-proto.pbtxt @@ -204,6 +204,12 @@ tf_proto { label: LABEL_OPTIONAL type: TYPE_BOOL } + field { + name: "enable_mlir_graph_optimization" + number: 16 + label: LABEL_OPTIONAL + type: TYPE_BOOL + } field { name: "disable_output_partition_graphs" number: 14