Pass to convert MLIR to Graph and running user specified Graph passes on it.

This pass converts MLIR to Graph, runs user-specified Graph Optimization Passes in the order specified and converts the graph back to MLIR.
Usage: mlir-tf-opt <input.mlir> --run-tf-graph-optimization --graph-passes=<comma separated list of passes> -o <output.mlir>
PiperOrigin-RevId: 256382837
This commit is contained in:
Prakalp Srivastava 2019-07-03 09:30:37 -07:00 committed by TensorFlower Gardener
parent 7a8fd724ea
commit 26f9dbe537
5 changed files with 229 additions and 0 deletions

View File

@ -373,6 +373,7 @@ cc_library(
":convert_tensor", ":convert_tensor",
":eval_util", ":eval_util",
":tensorflow", ":tensorflow",
":tf_graph_optimization_pass",
"//tensorflow/c:tf_status", "//tensorflow/c:tf_status",
"//tensorflow/c/eager:c_api", "//tensorflow/c/eager:c_api",
"//tensorflow/core:framework", "//tensorflow/core:framework",
@ -396,6 +397,27 @@ cc_library(
], ],
) )
cc_library(
name = "tf_graph_optimization_pass",
srcs = ["transforms/tf_graph_optimization_pass.cc"],
deps = [
":convert_graphdef",
":mlir_roundtrip_flags",
"//tensorflow/compiler/tf2xla:functionalize_control_flow_pass_registration",
"//tensorflow/core:core_cpu",
"//tensorflow/core:framework",
"//tensorflow/core:graph",
"//tensorflow/core:lib",
"//tensorflow/core:ops",
"//tensorflow/core:protos_all_proto_cc",
"//tensorflow/stream_executor/lib",
"@llvm//:support",
"@local_config_mlir//:IR",
"@local_config_mlir//:Pass",
],
alwayslink = 1,
)
cc_library( cc_library(
name = "eval_util", name = "eval_util",
srcs = ["utils/eval_util.cc"], srcs = ["utils/eval_util.cc"],

View File

@ -0,0 +1,18 @@
// RUN: tf-opt %s --run-tf-graph-optimization --graph-passes=FunctionalizeControlFlowPass 2>&1 | FileCheck %s; test ${PIPESTATUS[1]} -eq 0
// CHECK: FunctionalizeControlFlowPass: Graph contains node with inputs predicated on incompatible predicates: {s(Cond:0,then)} and {s(Cond:0,else)}
// CHECK-NEXT: for node {{[{][{]node Add[}][}]}}
func @main() {
%0 = "_tf._TPUReplicate"() {computation = @foo, Tinputs = [], Tbroadcast_inputs = [], NumVariables = 0, Tguaranteed_constants = [], output_types = []} : () -> !_tf.control loc("_TPUReplicate")
return
}
func @foo() {
%0:2 = "_tf.Const"() {device = "", dtype = "tfdtype$DT_INT32", value = dense<17> : tensor<i32>} : () -> (tensor<i32>, !_tf.control) loc("x")
%1:2 = "_tf.Const"() {device = "", dtype = "tfdtype$DT_BOOL", value = dense<true> : tensor<i1>} : () -> (tensor<i1>, !_tf.control) loc("Cond")
%2:3 = "_tf.Switch"(%0#0, %1#0) {T = "tfdtype$DT_INT32", device = ""} : (tensor<i32>, tensor<i1>) -> (tensor<i32>, tensor<i32>, !_tf.control) loc("switch")
%3:2 = "_tf.Add"(%2#0, %2#1) {T = "tfdtype$DT_INT32", device = ""} : (tensor<i32>, tensor<i32>) -> (tensor<i32>, !_tf.control) loc("Add")
%4:2 = "_tf.Mul"(%2#1, %2#0) {T = "tfdtype$DT_INT32", device = ""} : (tensor<i32>, tensor<i32>) -> (tensor<i32>, !_tf.control) loc("Square")
%5:3 = "_tf.Merge"(%3#0, %4#0) {N = 2 : i64, T = "tfdtype$DT_INT32", device = "", name = "_tf.Merge"} : (tensor<i32>, tensor<i32>) -> (tensor<i32>, tensor<i32>, !_tf.control) loc("Merge")
return
}

View File

@ -0,0 +1,34 @@
// RUN: tf-opt %s --run-tf-graph-optimization --graph-passes=FunctionalizeControlFlowPass | FileCheck %s
func @main() {
%0 = "_tf._TPUReplicate"() {computation = @foo, Tinputs = [], Tbroadcast_inputs = [], NumVariables = 0, Tguaranteed_constants = [], output_types = []} : () -> !_tf.control loc("_TPUReplicate")
return
}
func @foo() {
%0:2 = "_tf.Const"() {dtype = "tfdtype$DT_INT32", value = dense<17> : tensor<i32>} : () -> (tensor<i32>, !_tf.control) loc("x")
%1:2 = "_tf.Const"() {dtype = "tfdtype$DT_BOOL", value = dense<true> : tensor<i1>} : () -> (tensor<i1>, !_tf.control) loc("predicate")
%2:3 = "_tf.Switch"(%0#0, %1#0) {T = "tfdtype$DT_INT32"} : (tensor<i32>, tensor<i1>) -> (tensor<i32>, tensor<i32>, !_tf.control) loc("switch")
%3:2 = "_tf.Add"(%2#0, %2#0) {T = "tfdtype$DT_INT32"} : (tensor<i32>, tensor<i32>) -> (tensor<i32>, !_tf.control) loc("Addition")
%4:2 = "_tf.Mul"(%2#1, %2#1) {T = "tfdtype$DT_INT32"} : (tensor<i32>, tensor<i32>) -> (tensor<i32>, !_tf.control) loc("Multiplication")
%5:3 = "_tf.Merge"(%3#0, %4#0) {N = 2 : i64, T = "tfdtype$DT_INT32"} : (tensor<i32>, tensor<i32>) -> (tensor<i32>, tensor<i32>, !_tf.control) loc("Merge")
return
}
// Match the name of the cloned function with functionalized control-flow at call site
// CHECK: func @main()
// CHECK-NEXT: computation = @[[FUNCTIONALIZE_FUNC:[A-Za-z0-9_]*]]
// In the newly cloned function, check that we have a _tf.If operation and capture the then and else branch.
// CHECK: func @[[FUNCTIONALIZE_FUNC]]
// CHECK: "_tf.If"
// CHECK-SAME: else_branch = @[[ELSE_FUNC:[A-Za-z0-9_]*]]
// CHECK-SAME: then_branch = @[[THEN_FUNC:[A-Za-z0-9_]*]]
// We expect the _tf.Add in the else func and the _tf.Mul in the then func
// CHECK: func @[[ELSE_FUNC]]
// CHECK: "_tf.Add"
// CHECK: func @[[THEN_FUNC]]
// CHECK: "_tf.Mul"

View File

@ -0,0 +1,154 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "llvm/Support/CommandLine.h"
#include "mlir/IR/Builders.h" // TF:local_config_mlir
#include "mlir/IR/Identifier.h" // TF:local_config_mlir
#include "mlir/IR/Location.h" // TF:local_config_mlir
#include "mlir/Pass/Pass.h" // TF:local_config_mlir
#include "tensorflow/compiler/mlir/tensorflow/translate/export_graphdef.h"
#include "tensorflow/compiler/mlir/tensorflow/translate/import_graphdef.h"
#include "tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_flags.h"
#include "tensorflow/core/common_runtime/optimization_registry.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/graph/graph_constructor.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/protobuf/graph_debug_info.pb.h"
#include "tensorflow/core/public/session_options.h"
#include "tensorflow/stream_executor/lib/statusor.h"
#define DEBUG_TYPE "run-tf-graph-optimization"
// TODO(prakalps): Move these flags and pass registration to a header file so
// that it is clear that this is a generic pass library and command line is used
// for testing only.
// NOLINTNEXTLINE
static llvm::cl::OptionCategory clOptionsCategory(DEBUG_TYPE " options");
// NOLINTNEXTLINE
static llvm::cl::list<std::string> cl_pass_list(
"graph-passes", llvm::cl::value_desc("list"),
llvm::cl::desc("comma seprarated list of GraphOptimizationPass to run."),
llvm::cl::CommaSeparated, llvm::cl::cat(clOptionsCategory));
namespace tensorflow {
// Creates a pass to convert MLIR to Graph, run user-specified Graph
// Optimization Passes and convert back to MLIR.
// Constraints: This pass expects that all operations in the MLIR module either
// belong to 'tf' or '_tf' dialect. The output is in '_tf' dialect.
class GraphOptPass : public mlir::ModulePass<GraphOptPass> {
public:
explicit GraphOptPass() : pass_names_(cl_pass_list) {}
explicit GraphOptPass(std::vector<string> pass_names)
: pass_names_(pass_names) {}
private:
// Returns a vector of all the passes requested by the user.
std::vector<GraphOptimizationPass*> FindPassIds();
void runOnModule() override;
// The Graph passes are executed in the order present in the pass_names_
// vector.
std::vector<string> pass_names_;
};
std::vector<GraphOptimizationPass*> GraphOptPass::FindPassIds() {
std::vector<GraphOptimizationPass*> pass_ids(pass_names_.size(), nullptr);
for (const auto& group : OptimizationPassRegistry::Global()->groups()) {
for (const auto& phase : group.second) {
for (const auto& pass : phase.second) {
// Iterate over the pass_names_ and insert the pass pointer at all the
// corresponding indices in the pass_ids vector.
auto iter = pass_names_.begin();
while ((iter = std::find(iter, pass_names_.end(), pass->name())) !=
pass_names_.end()) {
pass_ids[std::distance(pass_names_.begin(), iter)] = pass.get();
iter++;
}
}
}
}
return pass_ids;
}
void GraphOptPass::runOnModule() {
mlir::Module module_in = getModule();
mlir::MLIRContext& ctx = getContext();
// Convert MLIR to Graph
FunctionLibraryDefinition flib_def(OpRegistry::Global(),
FunctionDefLibrary());
ExporterConfigs confs;
auto graph = absl::make_unique<Graph>(flib_def);
Status status = ConvertMlirToGraph(module_in, confs, &graph, &flib_def);
if (!status.ok()) {
mlir::emitError(mlir::UnknownLoc::get(&ctx)) << status.error_message();
return signalPassFailure();
}
// Run each of the passes that were selected.
std::vector<GraphOptimizationPass*> passes = FindPassIds();
GraphConstructorOptions opts;
opts.allow_internal_ops = true;
opts.expect_device_spec = false;
GraphOptimizationPassOptions options;
SessionOptions sess_options;
options.graph = &graph;
options.flib_def = &flib_def;
options.session_options = &sess_options;
for (auto p = passes.begin(), e = passes.end(); p != e; ++p) {
auto pass = *p;
if (pass == nullptr) {
mlir::emitError(mlir::UnknownLoc::get(&ctx))
<< "Could not find pass "
<< pass_names_[std::distance(passes.begin(), p)];
return signalPassFailure();
}
Status status = pass->Run(options);
if (!status.ok()) {
mlir::emitError(mlir::UnknownLoc::get(&ctx))
<< pass->name() << ": " << status.error_message();
return signalPassFailure();
}
}
// Convert Graph to MLIR
GraphDebugInfo debug_info;
NodeSpecs specs;
auto module_or_status =
ConvertGraphToMlir(**options.graph, debug_info, flib_def, specs, &ctx);
if (!module_or_status.ok()) {
mlir::emitError(mlir::UnknownLoc::get(&ctx))
<< module_or_status.status().error_message();
return signalPassFailure();
}
auto module_out = std::move(module_or_status).ValueOrDie();
// We cannot replace the module in a ModulePass. So we simply copy the
// Function list from module_out to module_in.
module_in.clear();
module_in.splice(module_in.getFunctions().end(), *module_out);
}
} // namespace tensorflow
static mlir::PassRegistration<tensorflow::GraphOptPass> pass(
DEBUG_TYPE, "runs passes registered as tensorflow::GraphOptimizationPass.");

View File

@ -561,6 +561,7 @@ cc_library(
srcs = [ srcs = [
"functionalize_control_flow_pass_registration.cc", "functionalize_control_flow_pass_registration.cc",
], ],
visibility = [":friends"],
deps = [ deps = [
":functionalize_control_flow", ":functionalize_control_flow",
], ],