[TF:MLIR] Add generic MLIR pass registration mechanism.
PiperOrigin-RevId: 300198093
Change-Id: I462e5c4096519f382271e4cc6734b28fa7fd2034

parent b6478bba25
commit 891375f55f
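
For orientation, a downstream pass plugs into the new mechanism by subclassing MlirOptimizationPass and registering an instance with the global registry, roughly as sketched below. This is an illustrative sketch only: the ExampleDebugPass name, the priority value, and the empty Run body are not part of this change; the interfaces come from the mlir_graph_optimization_pass.h header added further down.

#include <memory>

#include "tensorflow/compiler/mlir/mlir_graph_optimization_pass.h"

namespace tensorflow {

// Illustrative pass that demonstrates the hook points; it leaves the module
// untouched.
class ExampleDebugPass : public MlirOptimizationPass {
 public:
  llvm::StringRef name() const override { return "example_debug_pass"; }

  // Only run when the session opts into MLIR graph optimizations.
  bool IsEnabled(const ConfigProto& config_proto) const override {
    return config_proto.experimental().enable_mlir_graph_optimization();
  }

  Status Run(const ConfigProto& config_proto, mlir::ModuleOp module) override {
    // A real pass would transform `module` in place here.
    return Status::OK();
  }
};

// Static registration: the priority orders this pass against other
// registered MLIR passes.
static mlir_pass_registration::MlirOptimizationPassRegistration
    register_example_debug_pass(/*priority=*/20,
                                std::make_unique<ExampleDebugPass>());

}  // namespace tensorflow

The V1 path is analogous via MlirV1CompatOptimizationPass and MlirV1CompatOptimizationPassRegistration, as the header below shows.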
@@ -102,6 +102,38 @@ cc_library(
    ],
)

cc_library(
    name = "mlir_graph_optimization_pass",
    srcs = ["mlir_graph_optimization_pass.cc"],
    hdrs = ["mlir_graph_optimization_pass.h"],
    deps = [
        "//tensorflow/compiler/mlir/tensorflow",
        "//tensorflow/compiler/mlir/tensorflow:convert_graphdef",
        "//tensorflow/compiler/mlir/tensorflow:device_util",
        "//tensorflow/compiler/mlir/tensorflow:dump_mlir_util",
        "//tensorflow/compiler/mlir/tensorflow:mlir_roundtrip_flags",
        "//tensorflow/core:core_cpu",
        "@com_google_absl//absl/algorithm:container",
        "@com_google_absl//absl/container:flat_hash_set",
        "@com_google_absl//absl/strings",
        "@llvm-project//llvm:support",
        "@llvm-project//mlir:IR",
    ],
    alwayslink = 1,
)

cc_library(
    name = "mlir_graph_optimization_pass_registration",
    srcs = [
        "mlir_graph_optimization_pass_registration.cc",
    ],
    deps = [
        ":mlir_graph_optimization_pass",
        "//tensorflow/core:core_cpu",
    ],
    alwayslink = 1,
)

tf_cc_binary(
    name = "tf-opt",
    deps = [
tensorflow/compiler/mlir/mlir_graph_optimization_pass.cc (new file, 211 lines)
@@ -0,0 +1,211 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/mlir/mlir_graph_optimization_pass.h"

#include <string>

#include "absl/container/flat_hash_set.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/Support/FormatVariadic.h"
#include "llvm/Support/raw_os_ostream.h"
#include "tensorflow/compiler/mlir/tensorflow/translate/export_graphdef.h"
#include "tensorflow/compiler/mlir/tensorflow/translate/import_model.h"
#include "tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_flags.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/device_util.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util.h"
#include "tensorflow/core/graph/graph_constructor.h"
#include "tensorflow/core/public/session_options.h"

namespace tensorflow {

// Dumps the MLIR module to disk.
// This requires the TF_DUMP_GRAPH_PREFIX environment variable to be set to a
// path that exists (or can be created).
static void DumpModule(mlir::ModuleOp module, std::string file_prefix) {
  std::string prefix = GetDumpDirFromEnvVar();
  if (prefix.empty()) return;

  auto* env = tensorflow::Env::Default();
  auto status = env->RecursivelyCreateDir(prefix);
  if (!status.ok()) {
    LOG(WARNING) << "cannot create directory '" + prefix +
                        "': " + status.error_message();
    return;
  }

  prefix += "/" + file_prefix;
  if (!tensorflow::Env::Default()->CreateUniqueFileName(&prefix, ".mlir")) {
    LOG(WARNING) << "cannot create unique filename, won't dump MLIR module.";
    return;
  }

  std::unique_ptr<WritableFile> file_writer;
  status = env->NewWritableFile(prefix, &file_writer);
  if (!status.ok()) {
    LOG(WARNING) << "cannot open file '" + prefix +
                        "': " + status.error_message();
    return;
  }

  // Print the module to a string before writing to the file.
  std::string txt_module;
  {
    llvm::raw_string_ostream os(txt_module);
    module.print(os);
  }

  status = file_writer->Append(txt_module);
  if (!status.ok()) {
    LOG(WARNING) << "error writing to file '" + prefix +
                        "': " + status.error_message();
    return;
  }
  (void)file_writer->Close();
  VLOG(1) << "Dumped MLIR module to " << prefix;
}

MlirOptimizationPassRegistry& MlirOptimizationPassRegistry::Global() {
  static auto* global = new MlirOptimizationPassRegistry();
  return *global;
}

Status MlirFunctionOptimizationPass::Run(
    const DeviceSet& device_set, const ConfigProto& config_proto,
    std::unique_ptr<Graph>* graph, FunctionLibraryDefinition* flib_def,
    std::vector<std::string>* control_ret_node_names,
    bool* control_rets_updated) {
  // Skip conversion from Graph to MLIR if none of the passes are enabled.
  const bool is_enabled =
      llvm::any_of(registry_->passes(), [&](auto& pass_registration) -> bool {
        return pass_registration.pass->IsEnabled(config_proto);
      });

  if (!is_enabled) {
    VLOG(1) << "None of the MLIR optimization passes are enabled "
            << "(registered " << registry_->passes().size() << ")";
    return Status::OK();
  }

  VLOG(1) << "Running MLIR Graph Optimization Passes "
          << "(registered " << registry_->passes().size() << " passes)";

  GraphDebugInfo debug_info;
  mlir::MLIRContext context;
  GraphImportConfig import_config;
  import_config.graph_as_function = true;
  import_config.control_outputs = *control_ret_node_names;
  TF_ASSIGN_OR_RETURN(auto module_ref,
                      ConvertGraphToMlir(**graph, debug_info, *flib_def,
                                         import_config, &context));

  AddDevicesToOp(*module_ref, &device_set);

  for (auto& pass_registration : registry_->passes()) {
    llvm::StringRef name = pass_registration.pass->name();
    VLOG(2) << "Run MLIR graph optimization pass: " << absl::string_view(name);

    if (VLOG_IS_ON(1)) {
      DumpModule(*module_ref, llvm::formatv("mlir_{0}_before_", name));
    }

    TF_RETURN_IF_ERROR(pass_registration.pass->Run(config_proto, *module_ref));

    if (VLOG_IS_ON(1)) {
      DumpModule(*module_ref, llvm::formatv("mlir_{0}_after_", name));
    }
  }

  GraphExportConfig export_config;
  export_config.graph_as_function = true;
  absl::flat_hash_set<Node*> control_ret_nodes;
  TF_RETURN_WITH_CONTEXT_IF_ERROR(
      ConvertMlirToGraph(*module_ref, export_config, graph, flib_def,
                         &control_ret_nodes),
      "Error converting MLIR module back to graph");

  control_ret_node_names->clear();
  control_ret_node_names->reserve(control_ret_nodes.size());
  for (const auto* node : control_ret_nodes)
    control_ret_node_names->push_back(node->name());

  *control_rets_updated = true;

  return Status::OK();
}

MlirV1CompatOptimizationPassRegistry&
MlirV1CompatOptimizationPassRegistry::Global() {
  static auto* global = new MlirV1CompatOptimizationPassRegistry();
  return *global;
}

Status MlirV1CompatGraphOptimizationPass::Run(
    const GraphOptimizationPassOptions& options) {
  // Skip function graphs as MlirOptimizationPassRegistry_ will be used
  // instead.
  if (options.is_function_graph) return Status::OK();

  // Skip conversion from Graph to MLIR if none of the passes are enabled.
  const bool is_enabled =
      absl::c_any_of(registry_->passes(), [&](auto& pass_registration) -> bool {
        return pass_registration.pass->IsEnabled(
            options.session_options->config);
      });

  if (!is_enabled) {
    VLOG(1) << "None of the MLIR optimization passes are enabled "
            << "(registered " << registry_->passes().size() << " passes)";
    return Status::OK();
  }

  VLOG(1) << "Running MLIR Graph Optimization V1 Compat Passes "
          << "(registered " << registry_->passes().size() << " passes)";

  GraphDebugInfo debug_info;
  mlir::MLIRContext context;
  GraphImportConfig import_config;
  import_config.upgrade_legacy = true;
  TF_ASSIGN_OR_RETURN(
      auto module_ref,
      ConvertGraphToMlir(**options.graph, debug_info, *options.flib_def,
                         import_config, &context));

  AddDevicesToOp(*module_ref, options.device_set);

  for (auto& pass_registration : registry_->passes()) {
    absl::string_view name = pass_registration.pass->name();
    VLOG(2) << "Run MLIR graph optimization pass: " << name;

    if (VLOG_IS_ON(1)) {
      DumpModule(*module_ref, llvm::formatv("mlir_{0}_before_", name));
    }

    TF_RETURN_IF_ERROR(pass_registration.pass->Run(options, *module_ref));

    if (VLOG_IS_ON(1)) {
      DumpModule(*module_ref, llvm::formatv("mlir_{0}_after_", name));
    }
  }

  GraphExportConfig export_config;
  TF_RETURN_WITH_CONTEXT_IF_ERROR(
      ConvertMlirToGraph(*module_ref, export_config, options.graph,
                         options.flib_def),
      "Error converting MLIR module back to graph");

  return Status::OK();
}

}  // namespace tensorflow
tensorflow/compiler/mlir/mlir_graph_optimization_pass.h (new file, 179 lines)
@@ -0,0 +1,179 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_COMPILER_MLIR_MLIR_GRAPH_OPTIMIZATION_PASS_H_
#define TENSORFLOW_COMPILER_MLIR_MLIR_GRAPH_OPTIMIZATION_PASS_H_

#include "mlir/IR/Module.h"  // TF:llvm-project
#include "tensorflow/core/common_runtime/function_optimization_registry.h"
#include "tensorflow/core/common_runtime/optimization_registry.h"

namespace tensorflow {

// -------------------------------------------------------------------------- //
// MLIR passes running on Tensorflow function graphs (Tensorflow V2).
// -------------------------------------------------------------------------- //

// An API for registering MLIR ModulePass with the Tensorflow runtime. These
// passes run only for function graphs built by Tensorflow V2 and instantiated
// by the process_function_library_runtime (see FunctionOptimizationPass for
// details).
class MlirOptimizationPass {
 public:
  virtual ~MlirOptimizationPass() = default;
  virtual llvm::StringRef name() const = 0;
  virtual bool IsEnabled(const ConfigProto& config_proto) const = 0;

  virtual Status Run(const ConfigProto& config_proto,
                     mlir::ModuleOp module) = 0;
};

class MlirOptimizationPassRegistry {
 public:
  struct PassRegistration {
    int priority;
    std::unique_ptr<MlirOptimizationPass> pass;
  };

  struct PriorityComparator {
    bool operator()(const PassRegistration& x,
                    const PassRegistration& y) const {
      return x.priority < y.priority;
    }
  };

  using Passes = std::set<PassRegistration, PriorityComparator>;

  // Returns the global registry of MLIR optimization passes.
  static MlirOptimizationPassRegistry& Global();

  void Add(int priority, std::unique_ptr<MlirOptimizationPass> pass) {
    passes_.insert({priority, std::move(pass)});
  }

  const Passes& passes() const { return passes_; }

 private:
  Passes passes_;
};

// Function optimization pass that runs all MLIR passes registered in
// MlirOptimizationPassRegistry.
class MlirFunctionOptimizationPass : public FunctionOptimizationPass {
 public:
  explicit MlirFunctionOptimizationPass(
      const MlirOptimizationPassRegistry* registry =
          &MlirOptimizationPassRegistry::Global())
      : registry_(registry) {}

  Status Run(const DeviceSet& device_set, const ConfigProto& config_proto,
             std::unique_ptr<Graph>* graph, FunctionLibraryDefinition* flib_def,
             std::vector<std::string>* control_ret_node_names,
             bool* control_rets_updated) override;

 private:
  const MlirOptimizationPassRegistry* registry_;
};

// -------------------------------------------------------------------------- //
// MLIR passes running on Tensorflow V1 graphs.
// -------------------------------------------------------------------------- //

// An API for registering MLIR ModulePass with the Tensorflow runtime. These
// passes run only for V1 graphs (legacy graphs) executed via the Session
// runtime. The graph importer updates legacy graph behavior to V2 constructs
// (e.g. it raises control flow from Switch/Merge nodes to functional control
// flow with If/While operations).
class MlirV1CompatOptimizationPass {
 public:
  virtual ~MlirV1CompatOptimizationPass() = default;
  virtual llvm::StringRef name() const = 0;
  virtual bool IsEnabled(const ConfigProto& config_proto) const = 0;

  virtual Status Run(const GraphOptimizationPassOptions& options,
                     mlir::ModuleOp module) = 0;
};

class MlirV1CompatOptimizationPassRegistry {
 public:
  struct PassRegistration {
    int priority;
    std::unique_ptr<MlirV1CompatOptimizationPass> pass;
  };

  struct PriorityComparator {
    bool operator()(const PassRegistration& x,
                    const PassRegistration& y) const {
      return x.priority < y.priority;
    }
  };

  using Passes = std::set<PassRegistration, PriorityComparator>;

  // Returns the global registry of MLIR optimization passes.
  static MlirV1CompatOptimizationPassRegistry& Global();

  void Add(int priority, std::unique_ptr<MlirV1CompatOptimizationPass> pass) {
    passes_.insert({priority, std::move(pass)});
  }

  const Passes& passes() const { return passes_; }

 private:
  Passes passes_;
};

class MlirV1CompatGraphOptimizationPass : public GraphOptimizationPass {
 public:
  explicit MlirV1CompatGraphOptimizationPass(
      const MlirV1CompatOptimizationPassRegistry* registry =
          &MlirV1CompatOptimizationPassRegistry::Global())
      : registry_(registry) {}

  Status Run(const GraphOptimizationPassOptions& options) override;

 private:
  const MlirV1CompatOptimizationPassRegistry* registry_;
};

// -------------------------------------------------------------------------- //
// Helper classes for static registration of MLIR (V1 Compat) passes in the
// corresponding registry.
// -------------------------------------------------------------------------- //

namespace mlir_pass_registration {

class MlirOptimizationPassRegistration {
 public:
  explicit MlirOptimizationPassRegistration(
      int priority, std::unique_ptr<MlirOptimizationPass> pass) {
    MlirOptimizationPassRegistry::Global().Add(priority, std::move(pass));
  }
};

class MlirV1CompatOptimizationPassRegistration {
 public:
  explicit MlirV1CompatOptimizationPassRegistration(
      int priority, std::unique_ptr<MlirV1CompatOptimizationPass> pass) {
    MlirV1CompatOptimizationPassRegistry::Global().Add(priority,
                                                       std::move(pass));
  }
};

}  // namespace mlir_pass_registration

}  // namespace tensorflow

#endif  // TENSORFLOW_COMPILER_MLIR_MLIR_GRAPH_OPTIMIZATION_PASS_H_
@@ -0,0 +1,30 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <memory>

#include "tensorflow/compiler/mlir/mlir_graph_optimization_pass.h"
#include "tensorflow/core/common_runtime/function_optimization_registry.h"
#include "tensorflow/core/common_runtime/optimization_registry.h"

namespace tensorflow {

static function_optimization_registration::FunctionOptimizationPassRegistration
    register_mlir_passes(std::make_unique<MlirFunctionOptimizationPass>());

REGISTER_OPTIMIZATION(OptimizationPassRegistry::PRE_PLACEMENT, 0,
                      MlirV1CompatGraphOptimizationPass);

}  // namespace tensorflow
@@ -428,6 +428,28 @@ cc_library(
    alwayslink = 1,
)

cc_library(
    name = "graph_optimization_pass",
    srcs = ["transforms/graph_optimization_pass.cc"],
    hdrs = ["transforms/graph_optimization_pass.h"],
    deps = [
        ":tensorflow_passes",
        "//tensorflow/compiler/mlir:mlir_graph_optimization_pass",
    ],
    alwayslink = 1,
)

cc_library(
    name = "graph_optimization_pass_registration",
    srcs = ["transforms/graph_optimization_pass_registration.cc"],
    deps = [
        ":graph_optimization_pass",
        "//tensorflow/compiler/mlir:mlir_graph_optimization_pass",
        "//tensorflow/compiler/mlir:mlir_graph_optimization_pass_registration",
    ],
    alwayslink = 1,
)

# Library with TensorFlow dialect static initialization.
cc_library(
    name = "tensorflow_dialect_registration",
@@ -0,0 +1,33 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/mlir/tensorflow/transforms/graph_optimization_pass.h"

namespace tensorflow {

Status MlirGraphOptimizationPass::Run(const ConfigProto& config_proto,
                                      mlir::ModuleOp module) {
  if (!config_proto.experimental().enable_mlir_graph_optimization()) {
    VLOG(1) << "Skipping MLIR Graph Optimization Pass"
            << ", session flag not enabled";
    return Status::OK();
  }

  // TODO(ezhulenev): Add something here.

  return Status::OK();
}

}  // namespace tensorflow
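
The Run body above is still a TODO in this commit. Once real passes are added, Run would typically drive an mlir::PassManager over the module, along these lines (a hedged sketch only; the pipeline contents, error message, and include choices are illustrative and not part of this change):

#include "mlir/Pass/PassManager.h"  // TF:llvm-project
#include "tensorflow/core/lib/core/errors.h"

// Sketch of a filled-in Run: build an MLIR pass pipeline and apply it to
// `module`, reporting pipeline failure as an error status.
Status MlirGraphOptimizationPass::Run(const ConfigProto& config_proto,
                                      mlir::ModuleOp module) {
  if (!config_proto.experimental().enable_mlir_graph_optimization()) {
    return Status::OK();
  }

  mlir::PassManager pm(module.getContext());
  // Add the MLIR passes that replace the corresponding Grappler
  // optimizations here.
  if (mlir::failed(pm.run(module))) {
    return errors::Internal("MLIR graph optimization pipeline failed");
  }
  return Status::OK();
}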
@@ -0,0 +1,38 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSFORMS_GRAPH_OPTIMIZATION_PASS_H_
#define TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSFORMS_GRAPH_OPTIMIZATION_PASS_H_

#include "tensorflow/compiler/mlir/mlir_graph_optimization_pass.h"

namespace tensorflow {

// Bundles generic MLIR graph optimization passes (some derived from TF
// Grappler graph optimizers) into a single MLIR optimization pass.
class MlirGraphOptimizationPass : public MlirOptimizationPass {
 public:
  llvm::StringRef name() const override { return "graph_optimization"; }

  bool IsEnabled(const ConfigProto& config_proto) const override {
    return config_proto.experimental().enable_mlir_graph_optimization();
  }

  Status Run(const ConfigProto& config_proto, mlir::ModuleOp module) override;
};

}  // namespace tensorflow

#endif  // TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSFORMS_GRAPH_OPTIMIZATION_PASS_H_
@@ -0,0 +1,30 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <memory>

#include "tensorflow/compiler/mlir/tensorflow/transforms/graph_optimization_pass.h"

namespace tensorflow {
namespace {
constexpr int kMlirGraphOptimizationPriority = 0;
}  // namespace

static mlir_pass_registration::MlirOptimizationPassRegistration
    register_mlir_graph_optimization_pass(
        kMlirGraphOptimizationPriority,
        std::make_unique<MlirGraphOptimizationPass>());

}  // namespace tensorflow
@@ -698,6 +698,7 @@ cc_library(
    srcs = ["mlir_bridge_pass.cc"],
    hdrs = ["mlir_bridge_pass.h"],
    deps = [
        "//tensorflow/compiler/mlir:mlir_graph_optimization_pass",
        "//tensorflow/compiler/mlir/tensorflow",
        "//tensorflow/compiler/mlir/tensorflow:convert_graphdef",
        "//tensorflow/compiler/mlir/tensorflow:device_util",
@@ -717,6 +718,7 @@ cc_library(
    ],
    deps = [
        ":mlir_bridge_pass",
        "//tensorflow/compiler/mlir:mlir_graph_optimization_pass_registration",
        "//tensorflow/core:core_cpu",
    ],
    alwayslink = 1,
@@ -17,125 +17,32 @@ limitations under the License.

#include <string>

#include "absl/container/flat_hash_set.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/Support/raw_os_ostream.h"
#include "tensorflow/compiler/mlir/tensorflow/transforms/bridge.h"
#include "tensorflow/compiler/mlir/tensorflow/translate/export_graphdef.h"
#include "tensorflow/compiler/mlir/tensorflow/translate/import_model.h"
#include "tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_flags.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/device_util.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util.h"
#include "tensorflow/core/graph/graph_constructor.h"
#include "tensorflow/core/public/session_options.h"

namespace tensorflow {

// Dumps the MLIR module to disk.
// This requires the TF_DUMP_GRAPH_PREFIX environment variable to be set to a
// path that exists (or can be created).
static void DumpModule(mlir::ModuleOp module, llvm::StringRef file_prefix) {
  std::string prefix = GetDumpDirFromEnvVar();
  if (prefix.empty()) {
    return;
  }

  auto* env = tensorflow::Env::Default();
  auto status = env->RecursivelyCreateDir(prefix);
  if (!status.ok()) {
    LOG(WARNING) << "cannot create directory '" + prefix +
                        "': " + status.error_message();
    return;
  }
  prefix += "/" + file_prefix.str();
  if (!tensorflow::Env::Default()->CreateUniqueFileName(&prefix, ".mlir")) {
    LOG(WARNING) << "cannot create unique filename, won't dump MLIR module.";
    return;
  }

  std::unique_ptr<WritableFile> file_writer;
  status = env->NewWritableFile(prefix, &file_writer);
  if (!status.ok()) {
    LOG(WARNING) << "cannot open file '" + prefix +
                        "': " + status.error_message();
    return;
  }

  // Print the module to a string before writing to the file.
  std::string txt_module;
  {
    llvm::raw_string_ostream os(txt_module);
    module.print(os);
  }

  status = file_writer->Append(txt_module);
  if (!status.ok()) {
    LOG(WARNING) << "error writing to file '" + prefix +
                        "': " + status.error_message();
    return;
  }
  (void)file_writer->Close();
  VLOG(1) << "Dumped MLIR module to " << prefix;
}

// This runs the first phase of the "bridge", transforming the graph in a form
// that can be executed with delegation of some computations to an accelerator.
// This builds on the model of XLA where a subset of the graph is encapsulated
// and attached to a "compile" operation, whose result is fed to an "execute"
// operation. The kernel for these operations is responsible for lowering the
// encapsulated graph to a particular device.
Status MlirBridgePass::Run(const DeviceSet& device_set,
                           const ConfigProto& config_proto,
                           std::unique_ptr<Graph>* graph,
                           FunctionLibraryDefinition* flib_def,
                           std::vector<std::string>* control_ret_node_names,
                           bool* control_rets_updated) {
Status MlirBridgePass::Run(const ConfigProto& config_proto,
                           mlir::ModuleOp module) {
  if (!config_proto.experimental().enable_mlir_bridge()) {
    VLOG(1) << "Skipping MLIR Bridge Pass, session flag not enabled";
    return Status::OK();
  }

  VLOG(1) << "Running MLIR Bridge Pass";

  GraphDebugInfo debug_info;
  mlir::MLIRContext context;
  GraphImportConfig import_config;
  import_config.graph_as_function = true;
  import_config.control_outputs = *control_ret_node_names;
  TF_ASSIGN_OR_RETURN(auto module_ref,
                      ConvertGraphToMlir(**graph, debug_info, *flib_def,
                                         import_config, &context));

  AddDevicesToOp(*module_ref, &device_set);

  if (VLOG_IS_ON(1)) DumpModule(*module_ref, "mlir_bridge_before_");

  // Run the bridge now
  TF_RETURN_IF_ERROR(
      mlir::TFTPU::TPUBridge(*module_ref, /*enable_logging=*/VLOG_IS_ON(1)));

  if (VLOG_IS_ON(1)) DumpModule(*module_ref, "mlir_bridge_after_");

  GraphExportConfig export_config;
  export_config.graph_as_function = true;
  absl::flat_hash_set<Node*> control_ret_nodes;
  TF_RETURN_WITH_CONTEXT_IF_ERROR(
      ConvertMlirToGraph(*module_ref, export_config, graph, flib_def,
                         &control_ret_nodes),
      "Error converting MLIR module back to graph");

  control_ret_node_names->clear();
  control_ret_node_names->reserve(control_ret_nodes.size());
  for (const auto* node : control_ret_nodes)
    control_ret_node_names->push_back(node->name());

  *control_rets_updated = true;
      mlir::TFTPU::TPUBridge(module, /*enable_logging=*/VLOG_IS_ON(1)));

  return Status::OK();
}

Status MlirBridgeV1CompatPass::Run(
    const GraphOptimizationPassOptions& options) {
Status MlirBridgeV1CompatPass::Run(const GraphOptimizationPassOptions& options,
                                   mlir::ModuleOp module) {
  // Skip function graphs as MlirBridgePass will be used instead.
  if (options.is_function_graph) return Status::OK();

@@ -145,31 +52,8 @@ Status MlirBridgeV1CompatPass::Run(
  }

  VLOG(1) << "Running MLIR Bridge V1 Compat Pass";

  GraphDebugInfo debug_info;
  mlir::MLIRContext context;
  GraphImportConfig import_config;
  import_config.upgrade_legacy = true;
  TF_ASSIGN_OR_RETURN(
      auto module_ref,
      ConvertGraphToMlir(**options.graph, debug_info, *options.flib_def,
                         import_config, &context));

  AddDevicesToOp(*module_ref, options.device_set);

  if (VLOG_IS_ON(1)) DumpModule(*module_ref, "mlir_bridge_v1_compat_before_");

  // Run the bridge now
  TF_RETURN_IF_ERROR(mlir::TFTPU::TPUBridgeV1Compat(
      *module_ref, /*enable_logging=*/VLOG_IS_ON(1)));

  if (VLOG_IS_ON(1)) DumpModule(*module_ref, "mlir_bridge_v1_compat_after_");

  GraphExportConfig export_config;
  TF_RETURN_WITH_CONTEXT_IF_ERROR(
      ConvertMlirToGraph(*module_ref, export_config, options.graph,
                         options.flib_def),
      "Error converting MLIR module back to graph");
  TF_RETURN_IF_ERROR(
      mlir::TFTPU::TPUBridgeV1Compat(module, /*enable_logging=*/VLOG_IS_ON(1)));

  return Status::OK();
}
@@ -16,28 +16,42 @@ limitations under the License.
#ifndef TENSORFLOW_COMPILER_TF2XLA_MLIR_BRIDGE_PASS_H_
#define TENSORFLOW_COMPILER_TF2XLA_MLIR_BRIDGE_PASS_H_

#include "tensorflow/core/common_runtime/function_optimization_registry.h"
#include "tensorflow/core/common_runtime/optimization_registry.h"
#include "llvm/ADT/StringRef.h"
#include "tensorflow/compiler/mlir/mlir_graph_optimization_pass.h"

namespace tensorflow {

// This pass uses MLIR to implement all the conversion steps to target XLA from
// a TensorFlow Function Graph. It is meant to expose a very limited set of
// functionalities during the bring-up of the MLIR-based bridge.
class MlirBridgePass : public FunctionOptimizationPass {
class MlirBridgePass : public MlirOptimizationPass {
 public:
  Status Run(const DeviceSet& device_set, const ConfigProto& config_proto,
             std::unique_ptr<Graph>* graph, FunctionLibraryDefinition* flib_def,
             std::vector<std::string>* control_ret_node_names,
             bool* control_rets_updated) override;
  llvm::StringRef name() const override { return "bridge"; }

  bool IsEnabled(const ConfigProto& config_proto) const override {
    return config_proto.experimental().enable_mlir_bridge();
  }

  // This should be used as a thin mapper around mlir::ModulePass::runOnModule
  // API integrated with the Tensorflow runtime.
  Status Run(const ConfigProto& config_proto, mlir::ModuleOp module) override;
};

// This pass uses MLIR to implement all the conversion steps to target XLA from
// a TensorFlow V1 Graph. It is meant to expose a very limited set of
// functionalities during the bring-up of the MLIR-based bridge.
class MlirBridgeV1CompatPass : public GraphOptimizationPass {
class MlirBridgeV1CompatPass : public MlirV1CompatOptimizationPass {
 public:
  Status Run(const GraphOptimizationPassOptions& options) override;
  llvm::StringRef name() const override { return "bridge"; }

  bool IsEnabled(const ConfigProto& config_proto) const override {
    return config_proto.experimental().enable_mlir_bridge();
  }

  // This should be used as a thin mapper around mlir::ModulePass::runOnModule
  // API integrated with the Tensorflow runtime.
  Status Run(const GraphOptimizationPassOptions& options,
             mlir::ModuleOp module) override;
};

}  // namespace tensorflow
@@ -16,15 +16,18 @@ limitations under the License.
#include <memory>

#include "tensorflow/compiler/tf2xla/mlir_bridge_pass.h"
#include "tensorflow/core/common_runtime/function_optimization_registry.h"
#include "tensorflow/core/common_runtime/optimization_registry.h"

namespace tensorflow {
namespace {
constexpr int kMlirBridgePriority = 10;
}  // namespace

static function_optimization_registration::FunctionOptimizationPassRegistration
    register_mlir_bridge_pass(std::make_unique<MlirBridgePass>());
static mlir_pass_registration::MlirOptimizationPassRegistration
    register_mlir_bridge_pass(kMlirBridgePriority,
                              std::make_unique<MlirBridgePass>());

REGISTER_OPTIMIZATION(OptimizationPassRegistry::PRE_PLACEMENT, 0,
                      MlirBridgeV1CompatPass);
static mlir_pass_registration::MlirV1CompatOptimizationPassRegistration
    register_v1_compat_mlir_bridge_pass(
        kMlirBridgePriority, std::make_unique<MlirBridgeV1CompatPass>());

}  // namespace tensorflow
@@ -569,6 +569,13 @@ message ConfigProto {
    // to lower the encapsulated graph to a particular device.
    bool enable_mlir_bridge = 13;

    // Whether to enable the MLIR-based Graph optimizations.
    //
    // This will become a part of the standard Tensorflow graph optimization
    // pipeline; currently it is only used for gradual migration and for
    // testing new passes that are replacing existing optimizations in
    // Grappler.
    bool enable_mlir_graph_optimization = 16;

    // If true, the session will not store an additional copy of the graph for
    // each subgraph.
    //
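
A client opts into the new MLIR graph optimizations through this experimental ConfigProto field. A minimal sketch of doing so from C++ is shown below; the MakeMlirEnabledSessionOptions helper is illustrative and not part of this change:

#include "tensorflow/core/protobuf/config.pb.h"
#include "tensorflow/core/public/session_options.h"

// Sketch: build SessionOptions with the MLIR-based graph optimizations
// enabled via the experimental ConfigProto flag added in this commit.
tensorflow::SessionOptions MakeMlirEnabledSessionOptions() {
  tensorflow::SessionOptions options;
  options.config.mutable_experimental()->set_enable_mlir_graph_optimization(
      true);
  return options;
}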
@@ -75,6 +75,12 @@ tf_proto {
  label: LABEL_OPTIONAL
  type: TYPE_BOOL
}
field {
  name: "enable_mlir_graph_optimization"
  number: 16
  label: LABEL_OPTIONAL
  type: TYPE_BOOL
}
field {
  name: "disable_output_partition_graphs"
  number: 14
@@ -204,6 +204,12 @@ tf_proto {
  label: LABEL_OPTIONAL
  type: TYPE_BOOL
}
field {
  name: "enable_mlir_graph_optimization"
  number: 16
  label: LABEL_OPTIONAL
  type: TYPE_BOOL
}
field {
  name: "disable_output_partition_graphs"
  number: 14