Pass to convert MLIR to Graph and run user-specified Graph passes on it.
This pass converts MLIR to Graph, runs user-specified Graph Optimization Passes in the order specified and converts the graph back to MLIR. Usage: mlir-tf-opt <input.mlir> --run-tf-graph-optimization --graph-passes=<comma separated list of passes> -o <output.mlir> PiperOrigin-RevId: 256382837
This commit is contained in:
parent
7a8fd724ea
commit
26f9dbe537
@ -373,6 +373,7 @@ cc_library(
|
||||
":convert_tensor",
|
||||
":eval_util",
|
||||
":tensorflow",
|
||||
":tf_graph_optimization_pass",
|
||||
"//tensorflow/c:tf_status",
|
||||
"//tensorflow/c/eager:c_api",
|
||||
"//tensorflow/core:framework",
|
||||
@ -396,6 +397,27 @@ cc_library(
|
||||
],
|
||||
)
|
||||
|
||||
# MLIR pass that round-trips a module through a tensorflow::Graph in order to
# run registered tensorflow::GraphOptimizationPass passes on it.
cc_library(
    name = "tf_graph_optimization_pass",
    srcs = ["transforms/tf_graph_optimization_pass.cc"],
    deps = [
        ":convert_graphdef",
        ":mlir_roundtrip_flags",
        "//tensorflow/compiler/tf2xla:functionalize_control_flow_pass_registration",
        "//tensorflow/core:core_cpu",
        "//tensorflow/core:framework",
        "//tensorflow/core:graph",
        "//tensorflow/core:lib",
        "//tensorflow/core:ops",
        "//tensorflow/core:protos_all_proto_cc",
        "//tensorflow/stream_executor/lib",
        "@llvm//:support",
        "@local_config_mlir//:IR",
        "@local_config_mlir//:Pass",
    ],
    # The pass registers itself through a static mlir::PassRegistration
    # initializer; alwayslink keeps that initializer from being dropped by the
    # linker when nothing references a symbol from this library directly.
    alwayslink = 1,
)
|
||||
|
||||
cc_library(
|
||||
name = "eval_util",
|
||||
srcs = ["utils/eval_util.cc"],
|
||||
|
@ -0,0 +1,18 @@
|
||||
// RUN: tf-opt %s --run-tf-graph-optimization --graph-passes=FunctionalizeControlFlowPass 2>&1 | FileCheck %s; test ${PIPESTATUS[1]} -eq 0

// Verifies that a failure inside a Graph optimization pass is surfaced as an
// error prefixed with the failing pass's name, and (via the PIPESTATUS check
// in the RUN line) that tf-opt's diagnostic output is still well-formed for
// FileCheck.
// CHECK: FunctionalizeControlFlowPass: Graph contains node with inputs predicated on incompatible predicates: {s(Cond:0,then)} and {s(Cond:0,else)}
// CHECK-NEXT: for node {{[{][{]node Add[}][}]}}
func @main() {
  %0 = "_tf._TPUReplicate"() {computation = @foo, Tinputs = [], Tbroadcast_inputs = [], NumVariables = 0, Tguaranteed_constants = [], output_types = []} : () -> !_tf.control loc("_TPUReplicate")
  return
}

// The Add node consumes both outputs of the Switch (%2#0 and %2#1), so its
// inputs are live under the then- and else-predicates at the same time; this
// is what makes functionalization of this graph fail.
func @foo() {
  %0:2 = "_tf.Const"() {device = "", dtype = "tfdtype$DT_INT32", value = dense<17> : tensor<i32>} : () -> (tensor<i32>, !_tf.control) loc("x")
  %1:2 = "_tf.Const"() {device = "", dtype = "tfdtype$DT_BOOL", value = dense<true> : tensor<i1>} : () -> (tensor<i1>, !_tf.control) loc("Cond")
  %2:3 = "_tf.Switch"(%0#0, %1#0) {T = "tfdtype$DT_INT32", device = ""} : (tensor<i32>, tensor<i1>) -> (tensor<i32>, tensor<i32>, !_tf.control) loc("switch")
  %3:2 = "_tf.Add"(%2#0, %2#1) {T = "tfdtype$DT_INT32", device = ""} : (tensor<i32>, tensor<i32>) -> (tensor<i32>, !_tf.control) loc("Add")
  %4:2 = "_tf.Mul"(%2#1, %2#0) {T = "tfdtype$DT_INT32", device = ""} : (tensor<i32>, tensor<i32>) -> (tensor<i32>, !_tf.control) loc("Square")
  %5:3 = "_tf.Merge"(%3#0, %4#0) {N = 2 : i64, T = "tfdtype$DT_INT32", device = "", name = "_tf.Merge"} : (tensor<i32>, tensor<i32>) -> (tensor<i32>, tensor<i32>, !_tf.control) loc("Merge")
  return
}
|
@ -0,0 +1,34 @@
|
||||
// RUN: tf-opt %s --run-tf-graph-optimization --graph-passes=FunctionalizeControlFlowPass | FileCheck %s

func @main() {
  %0 = "_tf._TPUReplicate"() {computation = @foo, Tinputs = [], Tbroadcast_inputs = [], NumVariables = 0, Tguaranteed_constants = [], output_types = []} : () -> !_tf.control loc("_TPUReplicate")
  return
}

// Switch/Merge-based (dataflow) control flow; the Add only uses the Switch's
// output 0 and the Mul only its output 1, so FunctionalizeControlFlowPass is
// expected to rewrite this body into a single _tf.If with separate branch
// functions (checked below).
func @foo() {
  %0:2 = "_tf.Const"() {dtype = "tfdtype$DT_INT32", value = dense<17> : tensor<i32>} : () -> (tensor<i32>, !_tf.control) loc("x")
  %1:2 = "_tf.Const"() {dtype = "tfdtype$DT_BOOL", value = dense<true> : tensor<i1>} : () -> (tensor<i1>, !_tf.control) loc("predicate")
  %2:3 = "_tf.Switch"(%0#0, %1#0) {T = "tfdtype$DT_INT32"} : (tensor<i32>, tensor<i1>) -> (tensor<i32>, tensor<i32>, !_tf.control) loc("switch")
  %3:2 = "_tf.Add"(%2#0, %2#0) {T = "tfdtype$DT_INT32"} : (tensor<i32>, tensor<i32>) -> (tensor<i32>, !_tf.control) loc("Addition")
  %4:2 = "_tf.Mul"(%2#1, %2#1) {T = "tfdtype$DT_INT32"} : (tensor<i32>, tensor<i32>) -> (tensor<i32>, !_tf.control) loc("Multiplication")
  %5:3 = "_tf.Merge"(%3#0, %4#0) {N = 2 : i64, T = "tfdtype$DT_INT32"} : (tensor<i32>, tensor<i32>) -> (tensor<i32>, tensor<i32>, !_tf.control) loc("Merge")
  return
}

// Match the name of the cloned function with functionalized control-flow at call site
// CHECK: func @main()
// CHECK-NEXT: computation = @[[FUNCTIONALIZE_FUNC:[A-Za-z0-9_]*]]


// In the newly cloned function, check that we have a _tf.If operation and capture the then and else branch.
// CHECK: func @[[FUNCTIONALIZE_FUNC]]
// CHECK: "_tf.If"
// CHECK-SAME: else_branch = @[[ELSE_FUNC:[A-Za-z0-9_]*]]
// CHECK-SAME: then_branch = @[[THEN_FUNC:[A-Za-z0-9_]*]]

// We expect the _tf.Add in the else func and the _tf.Mul in the then func

// CHECK: func @[[ELSE_FUNC]]
// CHECK: "_tf.Add"
// CHECK: func @[[THEN_FUNC]]
// CHECK: "_tf.Mul"
|
@ -0,0 +1,154 @@
|
||||
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include <string>
#include <utility>
#include <vector>

#include "llvm/Support/CommandLine.h"
#include "mlir/IR/Builders.h"  // TF:local_config_mlir
#include "mlir/IR/Identifier.h"  // TF:local_config_mlir
#include "mlir/IR/Location.h"  // TF:local_config_mlir
#include "mlir/Pass/Pass.h"  // TF:local_config_mlir
#include "tensorflow/compiler/mlir/tensorflow/translate/export_graphdef.h"
#include "tensorflow/compiler/mlir/tensorflow/translate/import_graphdef.h"
#include "tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_flags.h"
#include "tensorflow/core/common_runtime/optimization_registry.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/graph/graph_constructor.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/protobuf/graph_debug_info.pb.h"
#include "tensorflow/core/public/session_options.h"
#include "tensorflow/stream_executor/lib/statusor.h"
|
||||
|
||||
#define DEBUG_TYPE "run-tf-graph-optimization"
|
||||
|
||||
// TODO(prakalps): Move these flags and pass registration to a header file so
|
||||
// that it is clear that this is a generic pass library and command line is used
|
||||
// for testing only.
|
||||
|
||||
// NOLINTNEXTLINE
|
||||
static llvm::cl::OptionCategory clOptionsCategory(DEBUG_TYPE " options");
|
||||
|
||||
// NOLINTNEXTLINE
|
||||
static llvm::cl::list<std::string> cl_pass_list(
|
||||
"graph-passes", llvm::cl::value_desc("list"),
|
||||
llvm::cl::desc("comma seprarated list of GraphOptimizationPass to run."),
|
||||
llvm::cl::CommaSeparated, llvm::cl::cat(clOptionsCategory));
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
// Creates a pass to convert MLIR to Graph, run user-specified Graph
|
||||
// Optimization Passes and convert back to MLIR.
|
||||
// Constraints: This pass expects that all operations in the MLIR module either
|
||||
// belong to 'tf' or '_tf' dialect. The output is in '_tf' dialect.
|
||||
class GraphOptPass : public mlir::ModulePass<GraphOptPass> {
|
||||
public:
|
||||
explicit GraphOptPass() : pass_names_(cl_pass_list) {}
|
||||
explicit GraphOptPass(std::vector<string> pass_names)
|
||||
: pass_names_(pass_names) {}
|
||||
|
||||
private:
|
||||
// Returns a vector of all the passes requested by the user.
|
||||
std::vector<GraphOptimizationPass*> FindPassIds();
|
||||
|
||||
void runOnModule() override;
|
||||
|
||||
// The Graph passes are executed in the order present in the pass_names_
|
||||
// vector.
|
||||
std::vector<string> pass_names_;
|
||||
};
|
||||
|
||||
std::vector<GraphOptimizationPass*> GraphOptPass::FindPassIds() {
|
||||
std::vector<GraphOptimizationPass*> pass_ids(pass_names_.size(), nullptr);
|
||||
|
||||
for (const auto& group : OptimizationPassRegistry::Global()->groups()) {
|
||||
for (const auto& phase : group.second) {
|
||||
for (const auto& pass : phase.second) {
|
||||
// Iterate over the pass_names_ and insert the pass pointer at all the
|
||||
// corresponding indices in the pass_ids vector.
|
||||
auto iter = pass_names_.begin();
|
||||
while ((iter = std::find(iter, pass_names_.end(), pass->name())) !=
|
||||
pass_names_.end()) {
|
||||
pass_ids[std::distance(pass_names_.begin(), iter)] = pass.get();
|
||||
iter++;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return pass_ids;
|
||||
}
|
||||
|
||||
// Round-trips the module MLIR -> tensorflow::Graph -> (Graph passes) -> MLIR.
// Any failure at any stage is reported as an MLIR error on the module's
// unknown location and the pass is marked failed.
void GraphOptPass::runOnModule() {
  mlir::Module module_in = getModule();
  mlir::MLIRContext& ctx = getContext();

  // Convert MLIR to Graph
  FunctionLibraryDefinition flib_def(OpRegistry::Global(),
                                     FunctionDefLibrary());
  ExporterConfigs confs;
  auto graph = absl::make_unique<Graph>(flib_def);
  Status status = ConvertMlirToGraph(module_in, confs, &graph, &flib_def);
  if (!status.ok()) {
    mlir::emitError(mlir::UnknownLoc::get(&ctx)) << status.error_message();
    return signalPassFailure();
  }

  // Run each of the passes that were selected.
  std::vector<GraphOptimizationPass*> passes = FindPassIds();
  // NOTE(review): `opts` is configured here but not visibly passed to any
  // call in this function — confirm whether it is dead or used elsewhere.
  GraphConstructorOptions opts;
  opts.allow_internal_ops = true;
  opts.expect_device_spec = false;

  GraphOptimizationPassOptions options;
  SessionOptions sess_options;
  options.graph = &graph;
  options.flib_def = &flib_def;
  options.session_options = &sess_options;

  for (auto p = passes.begin(), e = passes.end(); p != e; ++p) {
    auto pass = *p;
    // FindPassIds leaves a nullptr entry for any requested name that was not
    // found in the registry; surface that as an error naming the bad entry.
    if (pass == nullptr) {
      mlir::emitError(mlir::UnknownLoc::get(&ctx))
          << "Could not find pass "
          << pass_names_[std::distance(passes.begin(), p)];
      return signalPassFailure();
    }
    Status status = pass->Run(options);
    if (!status.ok()) {
      // Prefix the diagnostic with the pass name so test output (see the
      // error-path .mlir test) identifies which Graph pass failed.
      mlir::emitError(mlir::UnknownLoc::get(&ctx))
          << pass->name() << ": " << status.error_message();
      return signalPassFailure();
    }
  }

  // Convert Graph to MLIR
  GraphDebugInfo debug_info;
  NodeSpecs specs;
  auto module_or_status =
      ConvertGraphToMlir(**options.graph, debug_info, flib_def, specs, &ctx);
  if (!module_or_status.ok()) {
    mlir::emitError(mlir::UnknownLoc::get(&ctx))
        << module_or_status.status().error_message();
    return signalPassFailure();
  }
  auto module_out = std::move(module_or_status).ValueOrDie();

  // We cannot replace the module in a ModulePass. So we simply copy the
  // Function list from module_out to module_in.
  module_in.clear();
  module_in.splice(module_in.getFunctions().end(), *module_out);
}
|
||||
} // namespace tensorflow
|
||||
|
||||
// Registers the pass under DEBUG_TYPE ("run-tf-graph-optimization") via a
// static initializer so it is selectable from the tf-opt command line; the
// BUILD rule uses alwayslink = 1 to keep this initializer in the binary.
static mlir::PassRegistration<tensorflow::GraphOptPass> pass(
    DEBUG_TYPE, "runs passes registered as tensorflow::GraphOptimizationPass.");
|
@ -561,6 +561,7 @@ cc_library(
|
||||
srcs = [
|
||||
"functionalize_control_flow_pass_registration.cc",
|
||||
],
|
||||
visibility = [":friends"],
|
||||
deps = [
|
||||
":functionalize_control_flow",
|
||||
],
|
||||
|
Loading…
Reference in New Issue
Block a user