[TF:MLIR] Add generic MLIR pass registration mechanism.
PiperOrigin-RevId: 300198093 Change-Id: I462e5c4096519f382271e4cc6734b28fa7fd2034
This commit is contained in:
parent
b6478bba25
commit
891375f55f
@ -102,6 +102,38 @@ cc_library(
|
|||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
cc_library(
|
||||||
|
name = "mlir_graph_optimization_pass",
|
||||||
|
srcs = ["mlir_graph_optimization_pass.cc"],
|
||||||
|
hdrs = ["mlir_graph_optimization_pass.h"],
|
||||||
|
deps = [
|
||||||
|
"//tensorflow/compiler/mlir/tensorflow",
|
||||||
|
"//tensorflow/compiler/mlir/tensorflow:convert_graphdef",
|
||||||
|
"//tensorflow/compiler/mlir/tensorflow:device_util",
|
||||||
|
"//tensorflow/compiler/mlir/tensorflow:dump_mlir_util",
|
||||||
|
"//tensorflow/compiler/mlir/tensorflow:mlir_roundtrip_flags",
|
||||||
|
"//tensorflow/core:core_cpu",
|
||||||
|
"@com_google_absl//absl/algorithm:container",
|
||||||
|
"@com_google_absl//absl/container:flat_hash_set",
|
||||||
|
"@com_google_absl//absl/strings",
|
||||||
|
"@llvm-project//llvm:support",
|
||||||
|
"@llvm-project//mlir:IR",
|
||||||
|
],
|
||||||
|
alwayslink = 1,
|
||||||
|
)
|
||||||
|
|
||||||
|
cc_library(
|
||||||
|
name = "mlir_graph_optimization_pass_registration",
|
||||||
|
srcs = [
|
||||||
|
"mlir_graph_optimization_pass_registration.cc",
|
||||||
|
],
|
||||||
|
deps = [
|
||||||
|
":mlir_graph_optimization_pass",
|
||||||
|
"//tensorflow/core:core_cpu",
|
||||||
|
],
|
||||||
|
alwayslink = 1,
|
||||||
|
)
|
||||||
|
|
||||||
tf_cc_binary(
|
tf_cc_binary(
|
||||||
name = "tf-opt",
|
name = "tf-opt",
|
||||||
deps = [
|
deps = [
|
||||||
|
211
tensorflow/compiler/mlir/mlir_graph_optimization_pass.cc
Normal file
211
tensorflow/compiler/mlir/mlir_graph_optimization_pass.cc
Normal file
@ -0,0 +1,211 @@
|
|||||||
|
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#include "tensorflow/compiler/mlir/mlir_graph_optimization_pass.h"
|
||||||
|
|
||||||
|
#include <string>
|
||||||
|
|
||||||
|
#include "absl/container/flat_hash_set.h"
|
||||||
|
#include "llvm/ADT/STLExtras.h"
|
||||||
|
#include "llvm/Support/FormatVariadic.h"
|
||||||
|
#include "llvm/Support/raw_os_ostream.h"
|
||||||
|
#include "tensorflow/compiler/mlir/tensorflow/translate/export_graphdef.h"
|
||||||
|
#include "tensorflow/compiler/mlir/tensorflow/translate/import_model.h"
|
||||||
|
#include "tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_flags.h"
|
||||||
|
#include "tensorflow/compiler/mlir/tensorflow/utils/device_util.h"
|
||||||
|
#include "tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util.h"
|
||||||
|
#include "tensorflow/core/graph/graph_constructor.h"
|
||||||
|
#include "tensorflow/core/public/session_options.h"
|
||||||
|
|
||||||
|
namespace tensorflow {
|
||||||
|
|
||||||
|
// Dumps the MLIR module to disk.
|
||||||
|
// This require the TF_DUMP_GRAPH_PREFIX to be set to a path that exist (or can
|
||||||
|
// be created).
|
||||||
|
static void DumpModule(mlir::ModuleOp module, std::string file_prefix) {
|
||||||
|
std::string prefix = GetDumpDirFromEnvVar();
|
||||||
|
if (prefix.empty()) return;
|
||||||
|
|
||||||
|
auto* env = tensorflow::Env::Default();
|
||||||
|
auto status = env->RecursivelyCreateDir(prefix);
|
||||||
|
if (!status.ok()) {
|
||||||
|
LOG(WARNING) << "cannot create directory '" + prefix +
|
||||||
|
"': " + status.error_message();
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
prefix += "/" + file_prefix;
|
||||||
|
if (!tensorflow::Env::Default()->CreateUniqueFileName(&prefix, ".mlir")) {
|
||||||
|
LOG(WARNING) << "cannot create unique filename, won't dump MLIR module.";
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
std::unique_ptr<WritableFile> file_writer;
|
||||||
|
status = env->NewWritableFile(prefix, &file_writer);
|
||||||
|
if (!status.ok()) {
|
||||||
|
LOG(WARNING) << "cannot open file '" + prefix +
|
||||||
|
"': " + status.error_message();
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Print the module to a string before writing to the file.
|
||||||
|
std::string txt_module;
|
||||||
|
{
|
||||||
|
llvm::raw_string_ostream os(txt_module);
|
||||||
|
module.print(os);
|
||||||
|
}
|
||||||
|
|
||||||
|
status = file_writer->Append(txt_module);
|
||||||
|
if (!status.ok()) {
|
||||||
|
LOG(WARNING) << "error writing to file '" + prefix +
|
||||||
|
"': " + status.error_message();
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
(void)file_writer->Close();
|
||||||
|
VLOG(1) << "Dumped MLIR module to " << prefix;
|
||||||
|
}
|
||||||
|
|
||||||
|
MlirOptimizationPassRegistry& MlirOptimizationPassRegistry::Global() {
|
||||||
|
static auto* global = new MlirOptimizationPassRegistry();
|
||||||
|
return *global;
|
||||||
|
}
|
||||||
|
|
||||||
|
Status MlirFunctionOptimizationPass::Run(
|
||||||
|
const DeviceSet& device_set, const ConfigProto& config_proto,
|
||||||
|
std::unique_ptr<Graph>* graph, FunctionLibraryDefinition* flib_def,
|
||||||
|
std::vector<std::string>* control_ret_node_names,
|
||||||
|
bool* control_rets_updated) {
|
||||||
|
// Skip conversion from Graph to MLIR if none of the passes are enabled.
|
||||||
|
const bool is_enabled =
|
||||||
|
llvm::any_of(registry_->passes(), [&](auto& pass_registration) -> bool {
|
||||||
|
return pass_registration.pass->IsEnabled(config_proto);
|
||||||
|
});
|
||||||
|
|
||||||
|
if (!is_enabled) {
|
||||||
|
VLOG(1) << "None of the MLIR optimization passes are enabled "
|
||||||
|
<< "(registered " << registry_->passes().size() << ")";
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
|
VLOG(1) << "Running MLIR Graph Optimization Passes "
|
||||||
|
<< "(registered " << registry_->passes().size() << " passes)";
|
||||||
|
|
||||||
|
GraphDebugInfo debug_info;
|
||||||
|
mlir::MLIRContext context;
|
||||||
|
GraphImportConfig import_config;
|
||||||
|
import_config.graph_as_function = true;
|
||||||
|
import_config.control_outputs = *control_ret_node_names;
|
||||||
|
TF_ASSIGN_OR_RETURN(auto module_ref,
|
||||||
|
ConvertGraphToMlir(**graph, debug_info, *flib_def,
|
||||||
|
import_config, &context));
|
||||||
|
|
||||||
|
AddDevicesToOp(*module_ref, &device_set);
|
||||||
|
|
||||||
|
for (auto& pass_registration : registry_->passes()) {
|
||||||
|
llvm::StringRef name = pass_registration.pass->name();
|
||||||
|
VLOG(2) << "Run MLIR graph optimization pass: " << absl::string_view(name);
|
||||||
|
|
||||||
|
if (VLOG_IS_ON(1)) {
|
||||||
|
DumpModule(*module_ref, llvm::formatv("mlir_{0}_before_", name));
|
||||||
|
}
|
||||||
|
|
||||||
|
TF_RETURN_IF_ERROR(pass_registration.pass->Run(config_proto, *module_ref));
|
||||||
|
|
||||||
|
if (VLOG_IS_ON(1)) {
|
||||||
|
DumpModule(*module_ref, llvm::formatv("mlir_{0}_after_", name));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
GraphExportConfig export_config;
|
||||||
|
export_config.graph_as_function = true;
|
||||||
|
absl::flat_hash_set<Node*> control_ret_nodes;
|
||||||
|
TF_RETURN_WITH_CONTEXT_IF_ERROR(
|
||||||
|
ConvertMlirToGraph(*module_ref, export_config, graph, flib_def,
|
||||||
|
&control_ret_nodes),
|
||||||
|
"Error converting MLIR module back to graph");
|
||||||
|
|
||||||
|
control_ret_node_names->clear();
|
||||||
|
control_ret_node_names->reserve(control_ret_nodes.size());
|
||||||
|
for (const auto* node : control_ret_nodes)
|
||||||
|
control_ret_node_names->push_back(node->name());
|
||||||
|
|
||||||
|
*control_rets_updated = true;
|
||||||
|
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
|
MlirV1CompatOptimizationPassRegistry&
|
||||||
|
MlirV1CompatOptimizationPassRegistry::Global() {
|
||||||
|
static auto* global = new MlirV1CompatOptimizationPassRegistry();
|
||||||
|
return *global;
|
||||||
|
}
|
||||||
|
|
||||||
|
Status MlirV1CompatGraphOptimizationPass::Run(
|
||||||
|
const GraphOptimizationPassOptions& options) {
|
||||||
|
// Skip function graphs as MlirOptimizationPassRegistry_ will be used instead.
|
||||||
|
if (options.is_function_graph) return Status::OK();
|
||||||
|
|
||||||
|
// Skip conversion from Graph to MLIR if none of the passes are enabled.
|
||||||
|
const bool is_enabled =
|
||||||
|
absl::c_any_of(registry_->passes(), [&](auto& pass_registration) -> bool {
|
||||||
|
return pass_registration.pass->IsEnabled(
|
||||||
|
options.session_options->config);
|
||||||
|
});
|
||||||
|
|
||||||
|
if (!is_enabled) {
|
||||||
|
VLOG(1) << "None of the MLIR optimization passes are enabled "
|
||||||
|
<< "(registered" << registry_->passes().size() << " passes)";
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
|
VLOG(1) << "Running MLIR Graph Optimization V1 Compat Passes "
|
||||||
|
<< "(registered" << registry_->passes().size() << " passes)";
|
||||||
|
|
||||||
|
GraphDebugInfo debug_info;
|
||||||
|
mlir::MLIRContext context;
|
||||||
|
GraphImportConfig import_config;
|
||||||
|
import_config.upgrade_legacy = true;
|
||||||
|
TF_ASSIGN_OR_RETURN(
|
||||||
|
auto module_ref,
|
||||||
|
ConvertGraphToMlir(**options.graph, debug_info, *options.flib_def,
|
||||||
|
import_config, &context));
|
||||||
|
|
||||||
|
AddDevicesToOp(*module_ref, options.device_set);
|
||||||
|
|
||||||
|
for (auto& pass_registration : registry_->passes()) {
|
||||||
|
absl::string_view name = pass_registration.pass->name();
|
||||||
|
VLOG(2) << "Run MLIR graph optimization pass: " << name;
|
||||||
|
|
||||||
|
if (VLOG_IS_ON(1)) {
|
||||||
|
DumpModule(*module_ref, llvm::formatv("mlir_{0}_before_", name));
|
||||||
|
}
|
||||||
|
|
||||||
|
TF_RETURN_IF_ERROR(pass_registration.pass->Run(options, *module_ref));
|
||||||
|
|
||||||
|
if (VLOG_IS_ON(1)) {
|
||||||
|
DumpModule(*module_ref, llvm::formatv("mlir_{0}_after_", name));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
GraphExportConfig export_config;
|
||||||
|
TF_RETURN_WITH_CONTEXT_IF_ERROR(
|
||||||
|
ConvertMlirToGraph(*module_ref, export_config, options.graph,
|
||||||
|
options.flib_def),
|
||||||
|
"Error converting MLIR module back to graph");
|
||||||
|
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace tensorflow
|
179
tensorflow/compiler/mlir/mlir_graph_optimization_pass.h
Normal file
179
tensorflow/compiler/mlir/mlir_graph_optimization_pass.h
Normal file
@ -0,0 +1,179 @@
|
|||||||
|
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#ifndef TENSORFLOW_COMPILER_MLIR_MLIR_GRAPH_OPTIMIZATION_PASS_H_
|
||||||
|
#define TENSORFLOW_COMPILER_MLIR_MLIR_GRAPH_OPTIMIZATION_PASS_H_
|
||||||
|
|
||||||
|
#include "mlir/IR/Module.h" // TF:llvm-project
|
||||||
|
#include "tensorflow/core/common_runtime/function_optimization_registry.h"
|
||||||
|
#include "tensorflow/core/common_runtime/optimization_registry.h"
|
||||||
|
|
||||||
|
namespace tensorflow {
|
||||||
|
|
||||||
|
// -------------------------------------------------------------------------- //
|
||||||
|
// MLIR passes running on Tensorflow function graphs (Tensorflow V2).
|
||||||
|
// -------------------------------------------------------------------------- //
|
||||||
|
|
||||||
|
// An API for registering MLIR ModulePass with the Tensorflow runtime. These
|
||||||
|
// passes are running only for function graphs built by Tensorflow V2 and
|
||||||
|
// instantiated by the process_function_library_runtime (see
|
||||||
|
// FunctionOptimizationPass for details).
|
||||||
|
class MlirOptimizationPass {
|
||||||
|
public:
|
||||||
|
virtual ~MlirOptimizationPass() = default;
|
||||||
|
virtual llvm::StringRef name() const = 0;
|
||||||
|
virtual bool IsEnabled(const ConfigProto& config_proto) const = 0;
|
||||||
|
|
||||||
|
virtual Status Run(const ConfigProto& config_proto,
|
||||||
|
mlir::ModuleOp module) = 0;
|
||||||
|
};
|
||||||
|
|
||||||
|
class MlirOptimizationPassRegistry {
|
||||||
|
public:
|
||||||
|
struct PassRegistration {
|
||||||
|
int priority;
|
||||||
|
std::unique_ptr<MlirOptimizationPass> pass;
|
||||||
|
};
|
||||||
|
|
||||||
|
struct PriorityComparator {
|
||||||
|
bool operator()(const PassRegistration& x,
|
||||||
|
const PassRegistration& y) const {
|
||||||
|
return x.priority < y.priority;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
using Passes = std::set<PassRegistration, PriorityComparator>;
|
||||||
|
|
||||||
|
// Returns the global registry of MLIR optimization passes.
|
||||||
|
static MlirOptimizationPassRegistry& Global();
|
||||||
|
|
||||||
|
void Add(int priority, std::unique_ptr<MlirOptimizationPass> pass) {
|
||||||
|
passes_.insert({priority, std::move(pass)});
|
||||||
|
}
|
||||||
|
|
||||||
|
const Passes& passes() const { return passes_; }
|
||||||
|
|
||||||
|
private:
|
||||||
|
Passes passes_;
|
||||||
|
};
|
||||||
|
|
||||||
|
// Function optimization pass that runs all MLIR passes registered in
|
||||||
|
// MlirOptimizationPassRegistry.
|
||||||
|
class MlirFunctionOptimizationPass : public FunctionOptimizationPass {
|
||||||
|
public:
|
||||||
|
explicit MlirFunctionOptimizationPass(
|
||||||
|
const MlirOptimizationPassRegistry* registry =
|
||||||
|
&MlirOptimizationPassRegistry::Global())
|
||||||
|
: registry_(registry) {}
|
||||||
|
|
||||||
|
Status Run(const DeviceSet& device_set, const ConfigProto& config_proto,
|
||||||
|
std::unique_ptr<Graph>* graph, FunctionLibraryDefinition* flib_def,
|
||||||
|
std::vector<std::string>* control_ret_node_names,
|
||||||
|
bool* control_rets_updated) override;
|
||||||
|
|
||||||
|
private:
|
||||||
|
const MlirOptimizationPassRegistry* registry_;
|
||||||
|
};
|
||||||
|
|
||||||
|
// -------------------------------------------------------------------------- //
|
||||||
|
// MLIR passes running on Tensorflow V1 graphs.
|
||||||
|
// -------------------------------------------------------------------------- //
|
||||||
|
|
||||||
|
// An API for registering MLIR ModulePass with the Tensorflow runtime. These
|
||||||
|
// passes are running only for V1 graphs (legacy graphs) executed via Session
|
||||||
|
// runtime. Graph importer updates legacy graph behavior to V2 constructs (e.g.
|
||||||
|
// it raises control flow from Switch/Merge nodes to functional control flow
|
||||||
|
// with If/While operations).
|
||||||
|
class MlirV1CompatOptimizationPass {
|
||||||
|
public:
|
||||||
|
virtual ~MlirV1CompatOptimizationPass() = default;
|
||||||
|
virtual llvm::StringRef name() const = 0;
|
||||||
|
virtual bool IsEnabled(const ConfigProto& config_proto) const = 0;
|
||||||
|
|
||||||
|
virtual Status Run(const GraphOptimizationPassOptions& options,
|
||||||
|
mlir::ModuleOp module) = 0;
|
||||||
|
};
|
||||||
|
|
||||||
|
class MlirV1CompatOptimizationPassRegistry {
|
||||||
|
public:
|
||||||
|
struct PassRegistration {
|
||||||
|
int priority;
|
||||||
|
std::unique_ptr<MlirV1CompatOptimizationPass> pass;
|
||||||
|
};
|
||||||
|
|
||||||
|
struct PriorityComparator {
|
||||||
|
bool operator()(const PassRegistration& x,
|
||||||
|
const PassRegistration& y) const {
|
||||||
|
return x.priority < y.priority;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
using Passes = std::set<PassRegistration, PriorityComparator>;
|
||||||
|
|
||||||
|
// Returns the global registry of MLIR optimization passes.
|
||||||
|
static MlirV1CompatOptimizationPassRegistry& Global();
|
||||||
|
|
||||||
|
void Add(int priority, std::unique_ptr<MlirV1CompatOptimizationPass> pass) {
|
||||||
|
passes_.insert({priority, std::move(pass)});
|
||||||
|
}
|
||||||
|
|
||||||
|
const Passes& passes() const { return passes_; }
|
||||||
|
|
||||||
|
private:
|
||||||
|
Passes passes_;
|
||||||
|
};
|
||||||
|
|
||||||
|
class MlirV1CompatGraphOptimizationPass : public GraphOptimizationPass {
|
||||||
|
public:
|
||||||
|
explicit MlirV1CompatGraphOptimizationPass(
|
||||||
|
const MlirV1CompatOptimizationPassRegistry* registry =
|
||||||
|
&MlirV1CompatOptimizationPassRegistry::Global())
|
||||||
|
: registry_(registry) {}
|
||||||
|
|
||||||
|
Status Run(const GraphOptimizationPassOptions& options) override;
|
||||||
|
|
||||||
|
private:
|
||||||
|
const MlirV1CompatOptimizationPassRegistry* registry_;
|
||||||
|
};
|
||||||
|
|
||||||
|
// -------------------------------------------------------------------------- //
|
||||||
|
// Helper classes for static registration of MLIR (V1 Compat) passes in the
|
||||||
|
// corresponding registry.
|
||||||
|
// -------------------------------------------------------------------------- //
|
||||||
|
|
||||||
|
namespace mlir_pass_registration {
|
||||||
|
|
||||||
|
class MlirOptimizationPassRegistration {
|
||||||
|
public:
|
||||||
|
explicit MlirOptimizationPassRegistration(
|
||||||
|
int priority, std::unique_ptr<MlirOptimizationPass> pass) {
|
||||||
|
MlirOptimizationPassRegistry::Global().Add(priority, std::move(pass));
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
class MlirV1CompatOptimizationPassRegistration {
|
||||||
|
public:
|
||||||
|
explicit MlirV1CompatOptimizationPassRegistration(
|
||||||
|
int priority, std::unique_ptr<MlirV1CompatOptimizationPass> pass) {
|
||||||
|
MlirV1CompatOptimizationPassRegistry::Global().Add(priority,
|
||||||
|
std::move(pass));
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace mlir_pass_registration
|
||||||
|
|
||||||
|
} // namespace tensorflow
|
||||||
|
|
||||||
|
#endif // TENSORFLOW_COMPILER_MLIR_MLIR_GRAPH_OPTIMIZATION_PASS_H_
|
@ -0,0 +1,30 @@
|
|||||||
|
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#include <memory>
|
||||||
|
|
||||||
|
#include "tensorflow/compiler/mlir/mlir_graph_optimization_pass.h"
|
||||||
|
#include "tensorflow/core/common_runtime/function_optimization_registry.h"
|
||||||
|
#include "tensorflow/core/common_runtime/optimization_registry.h"
|
||||||
|
|
||||||
|
namespace tensorflow {
|
||||||
|
|
||||||
|
static function_optimization_registration::FunctionOptimizationPassRegistration
|
||||||
|
register_mlir_passes(std::make_unique<MlirFunctionOptimizationPass>());
|
||||||
|
|
||||||
|
REGISTER_OPTIMIZATION(OptimizationPassRegistry::PRE_PLACEMENT, 0,
|
||||||
|
MlirV1CompatGraphOptimizationPass);
|
||||||
|
|
||||||
|
} // namespace tensorflow
|
@ -428,6 +428,28 @@ cc_library(
|
|||||||
alwayslink = 1,
|
alwayslink = 1,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
cc_library(
|
||||||
|
name = "graph_optimization_pass",
|
||||||
|
srcs = ["transforms/graph_optimization_pass.cc"],
|
||||||
|
hdrs = ["transforms/graph_optimization_pass.h"],
|
||||||
|
deps = [
|
||||||
|
":tensorflow_passes",
|
||||||
|
"//tensorflow/compiler/mlir:mlir_graph_optimization_pass",
|
||||||
|
],
|
||||||
|
alwayslink = 1,
|
||||||
|
)
|
||||||
|
|
||||||
|
cc_library(
|
||||||
|
name = "graph_optimization_pass_registration",
|
||||||
|
srcs = ["transforms/graph_optimization_pass_registration.cc"],
|
||||||
|
deps = [
|
||||||
|
":graph_optimization_pass",
|
||||||
|
"//tensorflow/compiler/mlir:mlir_graph_optimization_pass",
|
||||||
|
"//tensorflow/compiler/mlir:mlir_graph_optimization_pass_registration",
|
||||||
|
],
|
||||||
|
alwayslink = 1,
|
||||||
|
)
|
||||||
|
|
||||||
# Library with TensorFlow dialect static initialization.
|
# Library with TensorFlow dialect static initialization.
|
||||||
cc_library(
|
cc_library(
|
||||||
name = "tensorflow_dialect_registration",
|
name = "tensorflow_dialect_registration",
|
||||||
|
@ -0,0 +1,33 @@
|
|||||||
|
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#include "tensorflow/compiler/mlir/tensorflow/transforms/graph_optimization_pass.h"
|
||||||
|
|
||||||
|
namespace tensorflow {
|
||||||
|
|
||||||
|
Status MlirGraphOptimizationPass::Run(const ConfigProto& config_proto,
|
||||||
|
mlir::ModuleOp module) {
|
||||||
|
if (!config_proto.experimental().enable_mlir_graph_optimization()) {
|
||||||
|
VLOG(1) << "Skipping MLIR Graph Optimization Pass"
|
||||||
|
<< ", session flag not enabled";
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
|
// TODO(ezhulenev): Add something here.
|
||||||
|
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace tensorflow
|
@ -0,0 +1,38 @@
|
|||||||
|
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#ifndef TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSFORMS_GRAPH_OPTIMIZATION_PASS_H_
|
||||||
|
#define TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSFORMS_GRAPH_OPTIMIZATION_PASS_H_
|
||||||
|
|
||||||
|
#include "tensorflow/compiler/mlir/mlir_graph_optimization_pass.h"
|
||||||
|
|
||||||
|
namespace tensorflow {
|
||||||
|
|
||||||
|
// Bundle generic MLIR graph optimization passes (some derived from TF Grappler
|
||||||
|
// graph optimizers) into a single MLIR optimization pass.
|
||||||
|
class MlirGraphOptimizationPass : public MlirOptimizationPass {
|
||||||
|
public:
|
||||||
|
llvm::StringRef name() const override { return "graph_optimization"; }
|
||||||
|
|
||||||
|
bool IsEnabled(const ConfigProto& config_proto) const override {
|
||||||
|
return config_proto.experimental().enable_mlir_graph_optimization();
|
||||||
|
}
|
||||||
|
|
||||||
|
Status Run(const ConfigProto& config_proto, mlir::ModuleOp module) override;
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace tensorflow
|
||||||
|
|
||||||
|
#endif // TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSFORMS_GRAPH_OPTIMIZATION_PASS_H_
|
@ -0,0 +1,30 @@
|
|||||||
|
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#include <memory>
|
||||||
|
|
||||||
|
#include "tensorflow/compiler/mlir/tensorflow/transforms/graph_optimization_pass.h"
|
||||||
|
|
||||||
|
namespace tensorflow {
|
||||||
|
namespace {
|
||||||
|
constexpr int kMlirGraphOptimizationPriority = 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
static mlir_pass_registration::MlirOptimizationPassRegistration
|
||||||
|
register_mlir_graph_optimization_pass(
|
||||||
|
kMlirGraphOptimizationPriority,
|
||||||
|
std::make_unique<MlirGraphOptimizationPass>());
|
||||||
|
|
||||||
|
} // namespace tensorflow
|
@ -698,6 +698,7 @@ cc_library(
|
|||||||
srcs = ["mlir_bridge_pass.cc"],
|
srcs = ["mlir_bridge_pass.cc"],
|
||||||
hdrs = ["mlir_bridge_pass.h"],
|
hdrs = ["mlir_bridge_pass.h"],
|
||||||
deps = [
|
deps = [
|
||||||
|
"//tensorflow/compiler/mlir:mlir_graph_optimization_pass",
|
||||||
"//tensorflow/compiler/mlir/tensorflow",
|
"//tensorflow/compiler/mlir/tensorflow",
|
||||||
"//tensorflow/compiler/mlir/tensorflow:convert_graphdef",
|
"//tensorflow/compiler/mlir/tensorflow:convert_graphdef",
|
||||||
"//tensorflow/compiler/mlir/tensorflow:device_util",
|
"//tensorflow/compiler/mlir/tensorflow:device_util",
|
||||||
@ -717,6 +718,7 @@ cc_library(
|
|||||||
],
|
],
|
||||||
deps = [
|
deps = [
|
||||||
":mlir_bridge_pass",
|
":mlir_bridge_pass",
|
||||||
|
"//tensorflow/compiler/mlir:mlir_graph_optimization_pass_registration",
|
||||||
"//tensorflow/core:core_cpu",
|
"//tensorflow/core:core_cpu",
|
||||||
],
|
],
|
||||||
alwayslink = 1,
|
alwayslink = 1,
|
||||||
|
@ -17,125 +17,32 @@ limitations under the License.
|
|||||||
|
|
||||||
#include <string>
|
#include <string>
|
||||||
|
|
||||||
#include "absl/container/flat_hash_set.h"
|
|
||||||
#include "llvm/ADT/StringRef.h"
|
|
||||||
#include "llvm/Support/raw_os_ostream.h"
|
|
||||||
#include "tensorflow/compiler/mlir/tensorflow/transforms/bridge.h"
|
#include "tensorflow/compiler/mlir/tensorflow/transforms/bridge.h"
|
||||||
#include "tensorflow/compiler/mlir/tensorflow/translate/export_graphdef.h"
|
|
||||||
#include "tensorflow/compiler/mlir/tensorflow/translate/import_model.h"
|
|
||||||
#include "tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_flags.h"
|
|
||||||
#include "tensorflow/compiler/mlir/tensorflow/utils/device_util.h"
|
|
||||||
#include "tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util.h"
|
|
||||||
#include "tensorflow/core/graph/graph_constructor.h"
|
|
||||||
#include "tensorflow/core/public/session_options.h"
|
#include "tensorflow/core/public/session_options.h"
|
||||||
|
|
||||||
namespace tensorflow {
|
namespace tensorflow {
|
||||||
|
|
||||||
// Dumps the MLIR module to disk.
|
|
||||||
// This require the TF_DUMP_GRAPH_PREFIX to be set to a path that exist (or can
|
|
||||||
// be created).
|
|
||||||
static void DumpModule(mlir::ModuleOp module, llvm::StringRef file_prefix) {
|
|
||||||
std::string prefix = GetDumpDirFromEnvVar();
|
|
||||||
if (prefix.empty()) {
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
auto* env = tensorflow::Env::Default();
|
|
||||||
auto status = env->RecursivelyCreateDir(prefix);
|
|
||||||
if (!status.ok()) {
|
|
||||||
LOG(WARNING) << "cannot create directory '" + prefix +
|
|
||||||
"': " + status.error_message();
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
prefix += "/" + file_prefix.str();
|
|
||||||
if (!tensorflow::Env::Default()->CreateUniqueFileName(&prefix, ".mlir")) {
|
|
||||||
LOG(WARNING) << "cannot create unique filename, won't dump MLIR module.";
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
std::unique_ptr<WritableFile> file_writer;
|
|
||||||
status = env->NewWritableFile(prefix, &file_writer);
|
|
||||||
if (!status.ok()) {
|
|
||||||
LOG(WARNING) << "cannot open file '" + prefix +
|
|
||||||
"': " + status.error_message();
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Print the module to a string before writing to the file.
|
|
||||||
std::string txt_module;
|
|
||||||
{
|
|
||||||
llvm::raw_string_ostream os(txt_module);
|
|
||||||
module.print(os);
|
|
||||||
}
|
|
||||||
|
|
||||||
status = file_writer->Append(txt_module);
|
|
||||||
if (!status.ok()) {
|
|
||||||
LOG(WARNING) << "error writing to file '" + prefix +
|
|
||||||
"': " + status.error_message();
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
(void)file_writer->Close();
|
|
||||||
VLOG(1) << "Dumped MLIR module to " << prefix;
|
|
||||||
}
|
|
||||||
|
|
||||||
// This runs the first phase of the "bridge", transforming the graph in a form
|
// This runs the first phase of the "bridge", transforming the graph in a form
|
||||||
// that can be executed with delegation of some computations to an accelerator.
|
// that can be executed with delegation of some computations to an accelerator.
|
||||||
// This builds on the model of XLA where a subset of the graph is encapsulated
|
// This builds on the model of XLA where a subset of the graph is encapsulated
|
||||||
// and attached to a "compile" operation, whose result is fed to an "execute"
|
// and attached to a "compile" operation, whose result is fed to an "execute"
|
||||||
// operation. The kernel for these operations is responsible to lower the
|
// operation. The kernel for these operations is responsible to lower the
|
||||||
// encapsulated graph to a particular device.
|
// encapsulated graph to a particular device.
|
||||||
Status MlirBridgePass::Run(const DeviceSet& device_set,
|
Status MlirBridgePass::Run(const ConfigProto& config_proto,
|
||||||
const ConfigProto& config_proto,
|
mlir::ModuleOp module) {
|
||||||
std::unique_ptr<Graph>* graph,
|
|
||||||
FunctionLibraryDefinition* flib_def,
|
|
||||||
std::vector<std::string>* control_ret_node_names,
|
|
||||||
bool* control_rets_updated) {
|
|
||||||
if (!config_proto.experimental().enable_mlir_bridge()) {
|
if (!config_proto.experimental().enable_mlir_bridge()) {
|
||||||
VLOG(1) << "Skipping MLIR Bridge Pass, session flag not enabled";
|
VLOG(1) << "Skipping MLIR Bridge Pass, session flag not enabled";
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
||||||
VLOG(1) << "Running MLIR Bridge Pass";
|
VLOG(1) << "Running MLIR Bridge Pass";
|
||||||
|
|
||||||
GraphDebugInfo debug_info;
|
|
||||||
mlir::MLIRContext context;
|
|
||||||
GraphImportConfig import_config;
|
|
||||||
import_config.graph_as_function = true;
|
|
||||||
import_config.control_outputs = *control_ret_node_names;
|
|
||||||
TF_ASSIGN_OR_RETURN(auto module_ref,
|
|
||||||
ConvertGraphToMlir(**graph, debug_info, *flib_def,
|
|
||||||
import_config, &context));
|
|
||||||
|
|
||||||
AddDevicesToOp(*module_ref, &device_set);
|
|
||||||
|
|
||||||
if (VLOG_IS_ON(1)) DumpModule(*module_ref, "mlir_bridge_before_");
|
|
||||||
|
|
||||||
// Run the bridge now
|
|
||||||
TF_RETURN_IF_ERROR(
|
TF_RETURN_IF_ERROR(
|
||||||
mlir::TFTPU::TPUBridge(*module_ref, /*enable_logging=*/VLOG_IS_ON(1)));
|
mlir::TFTPU::TPUBridge(module, /*enable_logging=*/VLOG_IS_ON(1)));
|
||||||
|
|
||||||
if (VLOG_IS_ON(1)) DumpModule(*module_ref, "mlir_bridge_after_");
|
|
||||||
|
|
||||||
GraphExportConfig export_config;
|
|
||||||
export_config.graph_as_function = true;
|
|
||||||
absl::flat_hash_set<Node*> control_ret_nodes;
|
|
||||||
TF_RETURN_WITH_CONTEXT_IF_ERROR(
|
|
||||||
ConvertMlirToGraph(*module_ref, export_config, graph, flib_def,
|
|
||||||
&control_ret_nodes),
|
|
||||||
"Error converting MLIR module back to graph");
|
|
||||||
|
|
||||||
control_ret_node_names->clear();
|
|
||||||
control_ret_node_names->reserve(control_ret_nodes.size());
|
|
||||||
for (const auto* node : control_ret_nodes)
|
|
||||||
control_ret_node_names->push_back(node->name());
|
|
||||||
|
|
||||||
*control_rets_updated = true;
|
|
||||||
|
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
Status MlirBridgeV1CompatPass::Run(const GraphOptimizationPassOptions& options,
|
||||||
Status MlirBridgeV1CompatPass::Run(
|
mlir::ModuleOp module) {
|
||||||
const GraphOptimizationPassOptions& options) {
|
|
||||||
// Skip function graphs as MlirBridgePass will be used instead.
|
// Skip function graphs as MlirBridgePass will be used instead.
|
||||||
if (options.is_function_graph) return Status::OK();
|
if (options.is_function_graph) return Status::OK();
|
||||||
|
|
||||||
@ -145,31 +52,8 @@ Status MlirBridgeV1CompatPass::Run(
|
|||||||
}
|
}
|
||||||
|
|
||||||
VLOG(1) << "Running MLIR Bridge V1 Compat Pass";
|
VLOG(1) << "Running MLIR Bridge V1 Compat Pass";
|
||||||
|
TF_RETURN_IF_ERROR(
|
||||||
GraphDebugInfo debug_info;
|
mlir::TFTPU::TPUBridgeV1Compat(module, /*enable_logging=*/VLOG_IS_ON(1)));
|
||||||
mlir::MLIRContext context;
|
|
||||||
GraphImportConfig import_config;
|
|
||||||
import_config.upgrade_legacy = true;
|
|
||||||
TF_ASSIGN_OR_RETURN(
|
|
||||||
auto module_ref,
|
|
||||||
ConvertGraphToMlir(**options.graph, debug_info, *options.flib_def,
|
|
||||||
import_config, &context));
|
|
||||||
|
|
||||||
AddDevicesToOp(*module_ref, options.device_set);
|
|
||||||
|
|
||||||
if (VLOG_IS_ON(1)) DumpModule(*module_ref, "mlir_bridge_v1_compat_before_");
|
|
||||||
|
|
||||||
// Run the bridge now
|
|
||||||
TF_RETURN_IF_ERROR(mlir::TFTPU::TPUBridgeV1Compat(
|
|
||||||
*module_ref, /*enable_logging=*/VLOG_IS_ON(1)));
|
|
||||||
|
|
||||||
if (VLOG_IS_ON(1)) DumpModule(*module_ref, "mlir_bridge_v1_compat_after_");
|
|
||||||
|
|
||||||
GraphExportConfig export_config;
|
|
||||||
TF_RETURN_WITH_CONTEXT_IF_ERROR(
|
|
||||||
ConvertMlirToGraph(*module_ref, export_config, options.graph,
|
|
||||||
options.flib_def),
|
|
||||||
"Error converting MLIR module back to graph");
|
|
||||||
|
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
@ -16,28 +16,42 @@ limitations under the License.
|
|||||||
#ifndef TENSORFLOW_COMPILER_TF2XLA_MLIR_BRIDGE_PASS_H_
|
#ifndef TENSORFLOW_COMPILER_TF2XLA_MLIR_BRIDGE_PASS_H_
|
||||||
#define TENSORFLOW_COMPILER_TF2XLA_MLIR_BRIDGE_PASS_H_
|
#define TENSORFLOW_COMPILER_TF2XLA_MLIR_BRIDGE_PASS_H_
|
||||||
|
|
||||||
#include "tensorflow/core/common_runtime/function_optimization_registry.h"
|
#include "llvm/ADT/StringRef.h"
|
||||||
#include "tensorflow/core/common_runtime/optimization_registry.h"
|
#include "tensorflow/compiler/mlir/mlir_graph_optimization_pass.h"
|
||||||
|
|
||||||
namespace tensorflow {
|
namespace tensorflow {
|
||||||
|
|
||||||
// This pass uses MLIR to implement all the conversion steps to target XLA from
|
// This pass uses MLIR to implement all the conversion steps to target XLA from
|
||||||
// a TensorFlow Function Graph. It is meant to expose a very limited set of
|
// a TensorFlow Function Graph. It is meant to expose a very limited set of
|
||||||
// functionalities during the bring-up of MLIR-based bridge.
|
// functionalities during the bring-up of MLIR-based bridge.
|
||||||
class MlirBridgePass : public FunctionOptimizationPass {
|
class MlirBridgePass : public MlirOptimizationPass {
|
||||||
public:
|
public:
|
||||||
Status Run(const DeviceSet& device_set, const ConfigProto& config_proto,
|
llvm::StringRef name() const override { return "bridge"; }
|
||||||
std::unique_ptr<Graph>* graph, FunctionLibraryDefinition* flib_def,
|
|
||||||
std::vector<std::string>* control_ret_node_names,
|
bool IsEnabled(const ConfigProto& config_proto) const override {
|
||||||
bool* control_rets_updated) override;
|
return config_proto.experimental().enable_mlir_bridge();
|
||||||
|
}
|
||||||
|
|
||||||
|
// This should be used as a thin mapper around mlir::ModulePass::runOnModule
|
||||||
|
// API integrated with the Tensorflow runtime.
|
||||||
|
Status Run(const ConfigProto& config_proto, mlir::ModuleOp module) override;
|
||||||
};
|
};
|
||||||
|
|
||||||
// This pass uses MLIR to implement all the conversion steps to target XLA from
|
// This pass uses MLIR to implement all the conversion steps to target XLA from
|
||||||
// a TensorFlow V1 Graph. It is meant to expose a very limited set of
|
// a TensorFlow V1 Graph. It is meant to expose a very limited set of
|
||||||
// functionalities during the bring-up of MLIR-based bridge.
|
// functionalities during the bring-up of MLIR-based bridge.
|
||||||
class MlirBridgeV1CompatPass : public GraphOptimizationPass {
|
class MlirBridgeV1CompatPass : public MlirV1CompatOptimizationPass {
|
||||||
public:
|
public:
|
||||||
Status Run(const GraphOptimizationPassOptions& options) override;
|
llvm::StringRef name() const override { return "bridge"; }
|
||||||
|
|
||||||
|
bool IsEnabled(const ConfigProto& config_proto) const override {
|
||||||
|
return config_proto.experimental().enable_mlir_bridge();
|
||||||
|
}
|
||||||
|
|
||||||
|
// This should be used as a thin mapper around mlir::ModulePass::runOnModule
|
||||||
|
// API integrated with the Tensorflow runtime.
|
||||||
|
Status Run(const GraphOptimizationPassOptions& options,
|
||||||
|
mlir::ModuleOp module) override;
|
||||||
};
|
};
|
||||||
|
|
||||||
} // namespace tensorflow
|
} // namespace tensorflow
|
||||||
|
@ -16,15 +16,18 @@ limitations under the License.
|
|||||||
#include <memory>
|
#include <memory>
|
||||||
|
|
||||||
#include "tensorflow/compiler/tf2xla/mlir_bridge_pass.h"
|
#include "tensorflow/compiler/tf2xla/mlir_bridge_pass.h"
|
||||||
#include "tensorflow/core/common_runtime/function_optimization_registry.h"
|
|
||||||
#include "tensorflow/core/common_runtime/optimization_registry.h"
|
|
||||||
|
|
||||||
namespace tensorflow {
|
namespace tensorflow {
|
||||||
|
namespace {
|
||||||
|
constexpr int kMlirBridgePriority = 10;
|
||||||
|
}
|
||||||
|
|
||||||
static function_optimization_registration::FunctionOptimizationPassRegistration
|
static mlir_pass_registration::MlirOptimizationPassRegistration
|
||||||
register_mlir_bridge_pass(std::make_unique<MlirBridgePass>());
|
register_mlir_bridge_pass(kMlirBridgePriority,
|
||||||
|
std::make_unique<MlirBridgePass>());
|
||||||
|
|
||||||
REGISTER_OPTIMIZATION(OptimizationPassRegistry::PRE_PLACEMENT, 0,
|
static mlir_pass_registration::MlirV1CompatOptimizationPassRegistration
|
||||||
MlirBridgeV1CompatPass);
|
register_v1_compat_mlir_bridge_pass(
|
||||||
|
kMlirBridgePriority, std::make_unique<MlirBridgeV1CompatPass>());
|
||||||
|
|
||||||
} // namespace tensorflow
|
} // namespace tensorflow
|
||||||
|
@ -569,6 +569,13 @@ message ConfigProto {
|
|||||||
// to lower the encapsulated graph to a particular device.
|
// to lower the encapsulated graph to a particular device.
|
||||||
bool enable_mlir_bridge = 13;
|
bool enable_mlir_bridge = 13;
|
||||||
|
|
||||||
|
// Whether to enable the MLIR-based Graph optimizations.
|
||||||
|
//
|
||||||
|
// This will become a part of standard Tensorflow graph optimization
|
||||||
|
// pipeline, currently this is only used for gradual migration and testing
|
||||||
|
// new passes that are replacing existing optimizations in Grappler.
|
||||||
|
bool enable_mlir_graph_optimization = 16;
|
||||||
|
|
||||||
// If true, the session will not store an additional copy of the graph for
|
// If true, the session will not store an additional copy of the graph for
|
||||||
// each subgraph.
|
// each subgraph.
|
||||||
//
|
//
|
||||||
|
@ -75,6 +75,12 @@ tf_proto {
|
|||||||
label: LABEL_OPTIONAL
|
label: LABEL_OPTIONAL
|
||||||
type: TYPE_BOOL
|
type: TYPE_BOOL
|
||||||
}
|
}
|
||||||
|
field {
|
||||||
|
name: "enable_mlir_graph_optimization"
|
||||||
|
number: 16
|
||||||
|
label: LABEL_OPTIONAL
|
||||||
|
type: TYPE_BOOL
|
||||||
|
}
|
||||||
field {
|
field {
|
||||||
name: "disable_output_partition_graphs"
|
name: "disable_output_partition_graphs"
|
||||||
number: 14
|
number: 14
|
||||||
|
@ -204,6 +204,12 @@ tf_proto {
|
|||||||
label: LABEL_OPTIONAL
|
label: LABEL_OPTIONAL
|
||||||
type: TYPE_BOOL
|
type: TYPE_BOOL
|
||||||
}
|
}
|
||||||
|
field {
|
||||||
|
name: "enable_mlir_graph_optimization"
|
||||||
|
number: 16
|
||||||
|
label: LABEL_OPTIONAL
|
||||||
|
type: TYPE_BOOL
|
||||||
|
}
|
||||||
field {
|
field {
|
||||||
name: "disable_output_partition_graphs"
|
name: "disable_output_partition_graphs"
|
||||||
number: 14
|
number: 14
|
||||||
|
Loading…
Reference in New Issue
Block a user