[TF:MLIR] Add generic MLIR pass registration mechanism.

PiperOrigin-RevId: 300198093
Change-Id: I462e5c4096519f382271e4cc6734b28fa7fd2034
This commit is contained in:
Eugene Zhulenev 2020-03-10 16:01:28 -07:00 committed by TensorFlower Gardener
parent b6478bba25
commit 891375f55f
15 changed files with 635 additions and 138 deletions

View File

@ -102,6 +102,38 @@ cc_library(
],
)
cc_library(
name = "mlir_graph_optimization_pass",
srcs = ["mlir_graph_optimization_pass.cc"],
hdrs = ["mlir_graph_optimization_pass.h"],
deps = [
"//tensorflow/compiler/mlir/tensorflow",
"//tensorflow/compiler/mlir/tensorflow:convert_graphdef",
"//tensorflow/compiler/mlir/tensorflow:device_util",
"//tensorflow/compiler/mlir/tensorflow:dump_mlir_util",
"//tensorflow/compiler/mlir/tensorflow:mlir_roundtrip_flags",
"//tensorflow/core:core_cpu",
"@com_google_absl//absl/algorithm:container",
"@com_google_absl//absl/container:flat_hash_set",
"@com_google_absl//absl/strings",
"@llvm-project//llvm:support",
"@llvm-project//mlir:IR",
],
alwayslink = 1,
)
cc_library(
name = "mlir_graph_optimization_pass_registration",
srcs = [
"mlir_graph_optimization_pass_registration.cc",
],
deps = [
":mlir_graph_optimization_pass",
"//tensorflow/core:core_cpu",
],
alwayslink = 1,
)
tf_cc_binary(
name = "tf-opt",
deps = [

View File

@ -0,0 +1,211 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/mlir/mlir_graph_optimization_pass.h"
#include <string>
#include "absl/container/flat_hash_set.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/Support/FormatVariadic.h"
#include "llvm/Support/raw_os_ostream.h"
#include "tensorflow/compiler/mlir/tensorflow/translate/export_graphdef.h"
#include "tensorflow/compiler/mlir/tensorflow/translate/import_model.h"
#include "tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_flags.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/device_util.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util.h"
#include "tensorflow/core/graph/graph_constructor.h"
#include "tensorflow/core/public/session_options.h"
namespace tensorflow {
// Dumps the MLIR module to disk.
// This require the TF_DUMP_GRAPH_PREFIX to be set to a path that exist (or can
// be created).
static void DumpModule(mlir::ModuleOp module, std::string file_prefix) {
std::string prefix = GetDumpDirFromEnvVar();
if (prefix.empty()) return;
auto* env = tensorflow::Env::Default();
auto status = env->RecursivelyCreateDir(prefix);
if (!status.ok()) {
LOG(WARNING) << "cannot create directory '" + prefix +
"': " + status.error_message();
return;
}
prefix += "/" + file_prefix;
if (!tensorflow::Env::Default()->CreateUniqueFileName(&prefix, ".mlir")) {
LOG(WARNING) << "cannot create unique filename, won't dump MLIR module.";
return;
}
std::unique_ptr<WritableFile> file_writer;
status = env->NewWritableFile(prefix, &file_writer);
if (!status.ok()) {
LOG(WARNING) << "cannot open file '" + prefix +
"': " + status.error_message();
return;
}
// Print the module to a string before writing to the file.
std::string txt_module;
{
llvm::raw_string_ostream os(txt_module);
module.print(os);
}
status = file_writer->Append(txt_module);
if (!status.ok()) {
LOG(WARNING) << "error writing to file '" + prefix +
"': " + status.error_message();
return;
}
(void)file_writer->Close();
VLOG(1) << "Dumped MLIR module to " << prefix;
}
MlirOptimizationPassRegistry& MlirOptimizationPassRegistry::Global() {
static auto* global = new MlirOptimizationPassRegistry();
return *global;
}
Status MlirFunctionOptimizationPass::Run(
const DeviceSet& device_set, const ConfigProto& config_proto,
std::unique_ptr<Graph>* graph, FunctionLibraryDefinition* flib_def,
std::vector<std::string>* control_ret_node_names,
bool* control_rets_updated) {
// Skip conversion from Graph to MLIR if none of the passes are enabled.
const bool is_enabled =
llvm::any_of(registry_->passes(), [&](auto& pass_registration) -> bool {
return pass_registration.pass->IsEnabled(config_proto);
});
if (!is_enabled) {
VLOG(1) << "None of the MLIR optimization passes are enabled "
<< "(registered " << registry_->passes().size() << ")";
return Status::OK();
}
VLOG(1) << "Running MLIR Graph Optimization Passes "
<< "(registered " << registry_->passes().size() << " passes)";
GraphDebugInfo debug_info;
mlir::MLIRContext context;
GraphImportConfig import_config;
import_config.graph_as_function = true;
import_config.control_outputs = *control_ret_node_names;
TF_ASSIGN_OR_RETURN(auto module_ref,
ConvertGraphToMlir(**graph, debug_info, *flib_def,
import_config, &context));
AddDevicesToOp(*module_ref, &device_set);
for (auto& pass_registration : registry_->passes()) {
llvm::StringRef name = pass_registration.pass->name();
VLOG(2) << "Run MLIR graph optimization pass: " << absl::string_view(name);
if (VLOG_IS_ON(1)) {
DumpModule(*module_ref, llvm::formatv("mlir_{0}_before_", name));
}
TF_RETURN_IF_ERROR(pass_registration.pass->Run(config_proto, *module_ref));
if (VLOG_IS_ON(1)) {
DumpModule(*module_ref, llvm::formatv("mlir_{0}_after_", name));
}
}
GraphExportConfig export_config;
export_config.graph_as_function = true;
absl::flat_hash_set<Node*> control_ret_nodes;
TF_RETURN_WITH_CONTEXT_IF_ERROR(
ConvertMlirToGraph(*module_ref, export_config, graph, flib_def,
&control_ret_nodes),
"Error converting MLIR module back to graph");
control_ret_node_names->clear();
control_ret_node_names->reserve(control_ret_nodes.size());
for (const auto* node : control_ret_nodes)
control_ret_node_names->push_back(node->name());
*control_rets_updated = true;
return Status::OK();
}
MlirV1CompatOptimizationPassRegistry&
MlirV1CompatOptimizationPassRegistry::Global() {
static auto* global = new MlirV1CompatOptimizationPassRegistry();
return *global;
}
Status MlirV1CompatGraphOptimizationPass::Run(
const GraphOptimizationPassOptions& options) {
// Skip function graphs as MlirOptimizationPassRegistry_ will be used instead.
if (options.is_function_graph) return Status::OK();
// Skip conversion from Graph to MLIR if none of the passes are enabled.
const bool is_enabled =
absl::c_any_of(registry_->passes(), [&](auto& pass_registration) -> bool {
return pass_registration.pass->IsEnabled(
options.session_options->config);
});
if (!is_enabled) {
VLOG(1) << "None of the MLIR optimization passes are enabled "
<< "(registered" << registry_->passes().size() << " passes)";
return Status::OK();
}
VLOG(1) << "Running MLIR Graph Optimization V1 Compat Passes "
<< "(registered" << registry_->passes().size() << " passes)";
GraphDebugInfo debug_info;
mlir::MLIRContext context;
GraphImportConfig import_config;
import_config.upgrade_legacy = true;
TF_ASSIGN_OR_RETURN(
auto module_ref,
ConvertGraphToMlir(**options.graph, debug_info, *options.flib_def,
import_config, &context));
AddDevicesToOp(*module_ref, options.device_set);
for (auto& pass_registration : registry_->passes()) {
absl::string_view name = pass_registration.pass->name();
VLOG(2) << "Run MLIR graph optimization pass: " << name;
if (VLOG_IS_ON(1)) {
DumpModule(*module_ref, llvm::formatv("mlir_{0}_before_", name));
}
TF_RETURN_IF_ERROR(pass_registration.pass->Run(options, *module_ref));
if (VLOG_IS_ON(1)) {
DumpModule(*module_ref, llvm::formatv("mlir_{0}_after_", name));
}
}
GraphExportConfig export_config;
TF_RETURN_WITH_CONTEXT_IF_ERROR(
ConvertMlirToGraph(*module_ref, export_config, options.graph,
options.flib_def),
"Error converting MLIR module back to graph");
return Status::OK();
}
} // namespace tensorflow

View File

@ -0,0 +1,179 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_COMPILER_MLIR_MLIR_GRAPH_OPTIMIZATION_PASS_H_
#define TENSORFLOW_COMPILER_MLIR_MLIR_GRAPH_OPTIMIZATION_PASS_H_
#include "mlir/IR/Module.h" // TF:llvm-project
#include "tensorflow/core/common_runtime/function_optimization_registry.h"
#include "tensorflow/core/common_runtime/optimization_registry.h"
namespace tensorflow {
// -------------------------------------------------------------------------- //
// MLIR passes running on Tensorflow function graphs (Tensorflow V2).
// -------------------------------------------------------------------------- //
// An API for registering MLIR ModulePass with the Tensorflow runtime. These
// passes are running only for function graphs built by Tensorflow V2 and
// instantiated by the process_function_library_runtime (see
// FunctionOptimizationPass for details).
class MlirOptimizationPass {
public:
virtual ~MlirOptimizationPass() = default;
virtual llvm::StringRef name() const = 0;
virtual bool IsEnabled(const ConfigProto& config_proto) const = 0;
virtual Status Run(const ConfigProto& config_proto,
mlir::ModuleOp module) = 0;
};
class MlirOptimizationPassRegistry {
public:
struct PassRegistration {
int priority;
std::unique_ptr<MlirOptimizationPass> pass;
};
struct PriorityComparator {
bool operator()(const PassRegistration& x,
const PassRegistration& y) const {
return x.priority < y.priority;
}
};
using Passes = std::set<PassRegistration, PriorityComparator>;
// Returns the global registry of MLIR optimization passes.
static MlirOptimizationPassRegistry& Global();
void Add(int priority, std::unique_ptr<MlirOptimizationPass> pass) {
passes_.insert({priority, std::move(pass)});
}
const Passes& passes() const { return passes_; }
private:
Passes passes_;
};
// Function optimization pass that runs all MLIR passes registered in
// MlirOptimizationPassRegistry.
class MlirFunctionOptimizationPass : public FunctionOptimizationPass {
public:
explicit MlirFunctionOptimizationPass(
const MlirOptimizationPassRegistry* registry =
&MlirOptimizationPassRegistry::Global())
: registry_(registry) {}
Status Run(const DeviceSet& device_set, const ConfigProto& config_proto,
std::unique_ptr<Graph>* graph, FunctionLibraryDefinition* flib_def,
std::vector<std::string>* control_ret_node_names,
bool* control_rets_updated) override;
private:
const MlirOptimizationPassRegistry* registry_;
};
// -------------------------------------------------------------------------- //
// MLIR passes running on Tensorflow V1 graphs.
// -------------------------------------------------------------------------- //
// An API for registering MLIR ModulePass with the Tensorflow runtime. These
// passes are running only for V1 graphs (legacy graphs) executed via Session
// runtime. Graph importer updates legacy graph behavior to V2 constructs (e.g.
// it raises control flow from Switch/Merge nodes to functional control flow
// with If/While operations).
class MlirV1CompatOptimizationPass {
public:
virtual ~MlirV1CompatOptimizationPass() = default;
virtual llvm::StringRef name() const = 0;
virtual bool IsEnabled(const ConfigProto& config_proto) const = 0;
virtual Status Run(const GraphOptimizationPassOptions& options,
mlir::ModuleOp module) = 0;
};
class MlirV1CompatOptimizationPassRegistry {
public:
struct PassRegistration {
int priority;
std::unique_ptr<MlirV1CompatOptimizationPass> pass;
};
struct PriorityComparator {
bool operator()(const PassRegistration& x,
const PassRegistration& y) const {
return x.priority < y.priority;
}
};
using Passes = std::set<PassRegistration, PriorityComparator>;
// Returns the global registry of MLIR optimization passes.
static MlirV1CompatOptimizationPassRegistry& Global();
void Add(int priority, std::unique_ptr<MlirV1CompatOptimizationPass> pass) {
passes_.insert({priority, std::move(pass)});
}
const Passes& passes() const { return passes_; }
private:
Passes passes_;
};
class MlirV1CompatGraphOptimizationPass : public GraphOptimizationPass {
public:
explicit MlirV1CompatGraphOptimizationPass(
const MlirV1CompatOptimizationPassRegistry* registry =
&MlirV1CompatOptimizationPassRegistry::Global())
: registry_(registry) {}
Status Run(const GraphOptimizationPassOptions& options) override;
private:
const MlirV1CompatOptimizationPassRegistry* registry_;
};
// -------------------------------------------------------------------------- //
// Helper classes for static registration of MLIR (V1 Compat) passes in the
// corresponding registry.
// -------------------------------------------------------------------------- //
namespace mlir_pass_registration {
class MlirOptimizationPassRegistration {
public:
explicit MlirOptimizationPassRegistration(
int priority, std::unique_ptr<MlirOptimizationPass> pass) {
MlirOptimizationPassRegistry::Global().Add(priority, std::move(pass));
}
};
class MlirV1CompatOptimizationPassRegistration {
public:
explicit MlirV1CompatOptimizationPassRegistration(
int priority, std::unique_ptr<MlirV1CompatOptimizationPass> pass) {
MlirV1CompatOptimizationPassRegistry::Global().Add(priority,
std::move(pass));
}
};
} // namespace mlir_pass_registration
} // namespace tensorflow
#endif // TENSORFLOW_COMPILER_MLIR_MLIR_GRAPH_OPTIMIZATION_PASS_H_

View File

@ -0,0 +1,30 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <memory>
#include "tensorflow/compiler/mlir/mlir_graph_optimization_pass.h"
#include "tensorflow/core/common_runtime/function_optimization_registry.h"
#include "tensorflow/core/common_runtime/optimization_registry.h"
namespace tensorflow {
static function_optimization_registration::FunctionOptimizationPassRegistration
register_mlir_passes(std::make_unique<MlirFunctionOptimizationPass>());
REGISTER_OPTIMIZATION(OptimizationPassRegistry::PRE_PLACEMENT, 0,
MlirV1CompatGraphOptimizationPass);
} // namespace tensorflow

View File

@ -428,6 +428,28 @@ cc_library(
alwayslink = 1,
)
cc_library(
name = "graph_optimization_pass",
srcs = ["transforms/graph_optimization_pass.cc"],
hdrs = ["transforms/graph_optimization_pass.h"],
deps = [
":tensorflow_passes",
"//tensorflow/compiler/mlir:mlir_graph_optimization_pass",
],
alwayslink = 1,
)
cc_library(
name = "graph_optimization_pass_registration",
srcs = ["transforms/graph_optimization_pass_registration.cc"],
deps = [
":graph_optimization_pass",
"//tensorflow/compiler/mlir:mlir_graph_optimization_pass",
"//tensorflow/compiler/mlir:mlir_graph_optimization_pass_registration",
],
alwayslink = 1,
)
# Library with TensorFlow dialect static initialization.
cc_library(
name = "tensorflow_dialect_registration",

View File

@ -0,0 +1,33 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/mlir/tensorflow/transforms/graph_optimization_pass.h"
namespace tensorflow {
Status MlirGraphOptimizationPass::Run(const ConfigProto& config_proto,
mlir::ModuleOp module) {
if (!config_proto.experimental().enable_mlir_graph_optimization()) {
VLOG(1) << "Skipping MLIR Graph Optimization Pass"
<< ", session flag not enabled";
return Status::OK();
}
// TODO(ezhulenev): Add something here.
return Status::OK();
}
} // namespace tensorflow

View File

@ -0,0 +1,38 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSFORMS_GRAPH_OPTIMIZATION_PASS_H_
#define TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSFORMS_GRAPH_OPTIMIZATION_PASS_H_
#include "tensorflow/compiler/mlir/mlir_graph_optimization_pass.h"
namespace tensorflow {
// Bundle generic MLIR graph optimization passes (some derived from TF Grappler
// graph optimizers) into a single MLIR optimization pass.
class MlirGraphOptimizationPass : public MlirOptimizationPass {
public:
llvm::StringRef name() const override { return "graph_optimization"; }
bool IsEnabled(const ConfigProto& config_proto) const override {
return config_proto.experimental().enable_mlir_graph_optimization();
}
Status Run(const ConfigProto& config_proto, mlir::ModuleOp module) override;
};
} // namespace tensorflow
#endif // TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSFORMS_GRAPH_OPTIMIZATION_PASS_H_

View File

@ -0,0 +1,30 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <memory>
#include "tensorflow/compiler/mlir/tensorflow/transforms/graph_optimization_pass.h"
namespace tensorflow {
namespace {
constexpr int kMlirGraphOptimizationPriority = 0;
}
static mlir_pass_registration::MlirOptimizationPassRegistration
register_mlir_graph_optimization_pass(
kMlirGraphOptimizationPriority,
std::make_unique<MlirGraphOptimizationPass>());
} // namespace tensorflow

View File

@ -698,6 +698,7 @@ cc_library(
srcs = ["mlir_bridge_pass.cc"],
hdrs = ["mlir_bridge_pass.h"],
deps = [
"//tensorflow/compiler/mlir:mlir_graph_optimization_pass",
"//tensorflow/compiler/mlir/tensorflow",
"//tensorflow/compiler/mlir/tensorflow:convert_graphdef",
"//tensorflow/compiler/mlir/tensorflow:device_util",
@ -717,6 +718,7 @@ cc_library(
],
deps = [
":mlir_bridge_pass",
"//tensorflow/compiler/mlir:mlir_graph_optimization_pass_registration",
"//tensorflow/core:core_cpu",
],
alwayslink = 1,

View File

@ -17,125 +17,32 @@ limitations under the License.
#include <string>
#include "absl/container/flat_hash_set.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/Support/raw_os_ostream.h"
#include "tensorflow/compiler/mlir/tensorflow/transforms/bridge.h"
#include "tensorflow/compiler/mlir/tensorflow/translate/export_graphdef.h"
#include "tensorflow/compiler/mlir/tensorflow/translate/import_model.h"
#include "tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_flags.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/device_util.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util.h"
#include "tensorflow/core/graph/graph_constructor.h"
#include "tensorflow/core/public/session_options.h"
namespace tensorflow {
// Dumps the MLIR module to disk.
// This require the TF_DUMP_GRAPH_PREFIX to be set to a path that exist (or can
// be created).
static void DumpModule(mlir::ModuleOp module, llvm::StringRef file_prefix) {
std::string prefix = GetDumpDirFromEnvVar();
if (prefix.empty()) {
return;
}
auto* env = tensorflow::Env::Default();
auto status = env->RecursivelyCreateDir(prefix);
if (!status.ok()) {
LOG(WARNING) << "cannot create directory '" + prefix +
"': " + status.error_message();
return;
}
prefix += "/" + file_prefix.str();
if (!tensorflow::Env::Default()->CreateUniqueFileName(&prefix, ".mlir")) {
LOG(WARNING) << "cannot create unique filename, won't dump MLIR module.";
return;
}
std::unique_ptr<WritableFile> file_writer;
status = env->NewWritableFile(prefix, &file_writer);
if (!status.ok()) {
LOG(WARNING) << "cannot open file '" + prefix +
"': " + status.error_message();
return;
}
// Print the module to a string before writing to the file.
std::string txt_module;
{
llvm::raw_string_ostream os(txt_module);
module.print(os);
}
status = file_writer->Append(txt_module);
if (!status.ok()) {
LOG(WARNING) << "error writing to file '" + prefix +
"': " + status.error_message();
return;
}
(void)file_writer->Close();
VLOG(1) << "Dumped MLIR module to " << prefix;
}
// This runs the first phase of the "bridge", transforming the graph in a form
// that can be executed with delegation of some computations to an accelerator.
// This builds on the model of XLA where a subset of the graph is encapsulated
// and attached to a "compile" operation, whose result is fed to an "execute"
// operation. The kernel for these operations is responsible to lower the
// encapsulated graph to a particular device.
Status MlirBridgePass::Run(const DeviceSet& device_set,
const ConfigProto& config_proto,
std::unique_ptr<Graph>* graph,
FunctionLibraryDefinition* flib_def,
std::vector<std::string>* control_ret_node_names,
bool* control_rets_updated) {
Status MlirBridgePass::Run(const ConfigProto& config_proto,
mlir::ModuleOp module) {
if (!config_proto.experimental().enable_mlir_bridge()) {
VLOG(1) << "Skipping MLIR Bridge Pass, session flag not enabled";
return Status::OK();
}
VLOG(1) << "Running MLIR Bridge Pass";
GraphDebugInfo debug_info;
mlir::MLIRContext context;
GraphImportConfig import_config;
import_config.graph_as_function = true;
import_config.control_outputs = *control_ret_node_names;
TF_ASSIGN_OR_RETURN(auto module_ref,
ConvertGraphToMlir(**graph, debug_info, *flib_def,
import_config, &context));
AddDevicesToOp(*module_ref, &device_set);
if (VLOG_IS_ON(1)) DumpModule(*module_ref, "mlir_bridge_before_");
// Run the bridge now
TF_RETURN_IF_ERROR(
mlir::TFTPU::TPUBridge(*module_ref, /*enable_logging=*/VLOG_IS_ON(1)));
if (VLOG_IS_ON(1)) DumpModule(*module_ref, "mlir_bridge_after_");
GraphExportConfig export_config;
export_config.graph_as_function = true;
absl::flat_hash_set<Node*> control_ret_nodes;
TF_RETURN_WITH_CONTEXT_IF_ERROR(
ConvertMlirToGraph(*module_ref, export_config, graph, flib_def,
&control_ret_nodes),
"Error converting MLIR module back to graph");
control_ret_node_names->clear();
control_ret_node_names->reserve(control_ret_nodes.size());
for (const auto* node : control_ret_nodes)
control_ret_node_names->push_back(node->name());
*control_rets_updated = true;
mlir::TFTPU::TPUBridge(module, /*enable_logging=*/VLOG_IS_ON(1)));
return Status::OK();
}
Status MlirBridgeV1CompatPass::Run(
const GraphOptimizationPassOptions& options) {
Status MlirBridgeV1CompatPass::Run(const GraphOptimizationPassOptions& options,
mlir::ModuleOp module) {
// Skip function graphs as MlirBridgePass will be used instead.
if (options.is_function_graph) return Status::OK();
@ -145,31 +52,8 @@ Status MlirBridgeV1CompatPass::Run(
}
VLOG(1) << "Running MLIR Bridge V1 Compat Pass";
GraphDebugInfo debug_info;
mlir::MLIRContext context;
GraphImportConfig import_config;
import_config.upgrade_legacy = true;
TF_ASSIGN_OR_RETURN(
auto module_ref,
ConvertGraphToMlir(**options.graph, debug_info, *options.flib_def,
import_config, &context));
AddDevicesToOp(*module_ref, options.device_set);
if (VLOG_IS_ON(1)) DumpModule(*module_ref, "mlir_bridge_v1_compat_before_");
// Run the bridge now
TF_RETURN_IF_ERROR(mlir::TFTPU::TPUBridgeV1Compat(
*module_ref, /*enable_logging=*/VLOG_IS_ON(1)));
if (VLOG_IS_ON(1)) DumpModule(*module_ref, "mlir_bridge_v1_compat_after_");
GraphExportConfig export_config;
TF_RETURN_WITH_CONTEXT_IF_ERROR(
ConvertMlirToGraph(*module_ref, export_config, options.graph,
options.flib_def),
"Error converting MLIR module back to graph");
TF_RETURN_IF_ERROR(
mlir::TFTPU::TPUBridgeV1Compat(module, /*enable_logging=*/VLOG_IS_ON(1)));
return Status::OK();
}

View File

@ -16,28 +16,42 @@ limitations under the License.
#ifndef TENSORFLOW_COMPILER_TF2XLA_MLIR_BRIDGE_PASS_H_
#define TENSORFLOW_COMPILER_TF2XLA_MLIR_BRIDGE_PASS_H_
#include "tensorflow/core/common_runtime/function_optimization_registry.h"
#include "tensorflow/core/common_runtime/optimization_registry.h"
#include "llvm/ADT/StringRef.h"
#include "tensorflow/compiler/mlir/mlir_graph_optimization_pass.h"
namespace tensorflow {
// This pass uses MLIR to implement all the conversion steps to target XLA from
// a TensorFlow Function Graph. It is meant to expose a very limited set of
// functionalities during the bring-up of MLIR-based bridge.
class MlirBridgePass : public FunctionOptimizationPass {
class MlirBridgePass : public MlirOptimizationPass {
public:
Status Run(const DeviceSet& device_set, const ConfigProto& config_proto,
std::unique_ptr<Graph>* graph, FunctionLibraryDefinition* flib_def,
std::vector<std::string>* control_ret_node_names,
bool* control_rets_updated) override;
llvm::StringRef name() const override { return "bridge"; }
bool IsEnabled(const ConfigProto& config_proto) const override {
return config_proto.experimental().enable_mlir_bridge();
}
// This should be used as a thin mapper around mlir::ModulePass::runOnModule
// API integrated with the Tensorflow runtime.
Status Run(const ConfigProto& config_proto, mlir::ModuleOp module) override;
};
// This pass uses MLIR to implement all the conversion steps to target XLA from
// a TensorFlow V1 Graph. It is meant to expose a very limited set of
// functionalities during the bring-up of MLIR-based bridge.
class MlirBridgeV1CompatPass : public GraphOptimizationPass {
class MlirBridgeV1CompatPass : public MlirV1CompatOptimizationPass {
public:
Status Run(const GraphOptimizationPassOptions& options) override;
llvm::StringRef name() const override { return "bridge"; }
bool IsEnabled(const ConfigProto& config_proto) const override {
return config_proto.experimental().enable_mlir_bridge();
}
// This should be used as a thin mapper around mlir::ModulePass::runOnModule
// API integrated with the Tensorflow runtime.
Status Run(const GraphOptimizationPassOptions& options,
mlir::ModuleOp module) override;
};
} // namespace tensorflow

View File

@ -16,15 +16,18 @@ limitations under the License.
#include <memory>
#include "tensorflow/compiler/tf2xla/mlir_bridge_pass.h"
#include "tensorflow/core/common_runtime/function_optimization_registry.h"
#include "tensorflow/core/common_runtime/optimization_registry.h"
namespace tensorflow {
namespace {
constexpr int kMlirBridgePriority = 10;
}
static function_optimization_registration::FunctionOptimizationPassRegistration
register_mlir_bridge_pass(std::make_unique<MlirBridgePass>());
static mlir_pass_registration::MlirOptimizationPassRegistration
register_mlir_bridge_pass(kMlirBridgePriority,
std::make_unique<MlirBridgePass>());
REGISTER_OPTIMIZATION(OptimizationPassRegistry::PRE_PLACEMENT, 0,
MlirBridgeV1CompatPass);
static mlir_pass_registration::MlirV1CompatOptimizationPassRegistration
register_v1_compat_mlir_bridge_pass(
kMlirBridgePriority, std::make_unique<MlirBridgeV1CompatPass>());
} // namespace tensorflow

View File

@ -569,6 +569,13 @@ message ConfigProto {
// to lower the encapsulated graph to a particular device.
bool enable_mlir_bridge = 13;
// Whether to enable the MLIR-based Graph optimizations.
//
// This will become a part of standard Tensorflow graph optimization
// pipeline, currently this is only used for gradual migration and testing
// new passes that are replacing existing optimizations in Grappler.
bool enable_mlir_graph_optimization = 16;
// If true, the session will not store an additional copy of the graph for
// each subgraph.
//

View File

@ -75,6 +75,12 @@ tf_proto {
label: LABEL_OPTIONAL
type: TYPE_BOOL
}
field {
name: "enable_mlir_graph_optimization"
number: 16
label: LABEL_OPTIONAL
type: TYPE_BOOL
}
field {
name: "disable_output_partition_graphs"
number: 14

View File

@ -204,6 +204,12 @@ tf_proto {
label: LABEL_OPTIONAL
type: TYPE_BOOL
}
field {
name: "enable_mlir_graph_optimization"
number: 16
label: LABEL_OPTIONAL
type: TYPE_BOOL
}
field {
name: "disable_output_partition_graphs"
number: 14