From 26f9dbe5373c6aed74d4c18cddd81b94bf120bd0 Mon Sep 17 00:00:00 2001
From: Prakalp Srivastava
Date: Wed, 3 Jul 2019 09:30:37 -0700
Subject: [PATCH] Pass to convert MLIR to Graph and run user-specified Graph passes on it.

This pass converts the MLIR module to a TensorFlow Graph, runs the
user-specified Graph optimization passes on it in the order given, and
converts the Graph back to MLIR.

Usage:
mlir-tf-opt --run-tf-graph-optimization --graph-passes=<comma-separated pass names> <input file> -o <output file>

PiperOrigin-RevId: 256382837
---
 tensorflow/compiler/mlir/tensorflow/BUILD    |  22 +++
 .../tests/functionalize-if-fail.mlir         |  21 +++
 .../tensorflow/tests/functionalize-if.mlir   |  34 ++++
 .../transforms/tf_graph_optimization_pass.cc | 166 ++++++++++++++++++
 tensorflow/compiler/tf2xla/BUILD             |   1 +
 5 files changed, 244 insertions(+)
 create mode 100644 tensorflow/compiler/mlir/tensorflow/tests/functionalize-if-fail.mlir
 create mode 100644 tensorflow/compiler/mlir/tensorflow/tests/functionalize-if.mlir
 create mode 100644 tensorflow/compiler/mlir/tensorflow/transforms/tf_graph_optimization_pass.cc

diff --git a/tensorflow/compiler/mlir/tensorflow/BUILD b/tensorflow/compiler/mlir/tensorflow/BUILD
index 5298e17e4e0..2e248dcfc65 100644
--- a/tensorflow/compiler/mlir/tensorflow/BUILD
+++ b/tensorflow/compiler/mlir/tensorflow/BUILD
@@ -373,6 +373,7 @@ cc_library(
         ":convert_tensor",
         ":eval_util",
         ":tensorflow",
+        ":tf_graph_optimization_pass",
         "//tensorflow/c:tf_status",
         "//tensorflow/c/eager:c_api",
         "//tensorflow/core:framework",
@@ -396,6 +397,27 @@ cc_library(
     ],
 )
 
+cc_library(
+    name = "tf_graph_optimization_pass",
+    srcs = ["transforms/tf_graph_optimization_pass.cc"],
+    deps = [
+        ":convert_graphdef",
+        ":mlir_roundtrip_flags",
+        "//tensorflow/compiler/tf2xla:functionalize_control_flow_pass_registration",
+        "//tensorflow/core:core_cpu",
+        "//tensorflow/core:framework",
+        "//tensorflow/core:graph",
+        "//tensorflow/core:lib",
+        "//tensorflow/core:ops",
+        "//tensorflow/core:protos_all_proto_cc",
+        "//tensorflow/stream_executor/lib",
+        "@llvm//:support",
+        "@local_config_mlir//:IR",
+        "@local_config_mlir//:Pass",
+    ],
+    alwayslink = 1,
+)
+
 cc_library(
     name = "eval_util",
     srcs = ["utils/eval_util.cc"],
diff --git a/tensorflow/compiler/mlir/tensorflow/tests/functionalize-if-fail.mlir b/tensorflow/compiler/mlir/tensorflow/tests/functionalize-if-fail.mlir
new file mode 100644
index 00000000000..779fe9011ff
--- /dev/null
+++ b/tensorflow/compiler/mlir/tensorflow/tests/functionalize-if-fail.mlir
@@ -0,0 +1,21 @@
+// RUN: tf-opt %s --run-tf-graph-optimization --graph-passes=FunctionalizeControlFlowPass 2>&1 | FileCheck %s; test ${PIPESTATUS[1]} -eq 0
+
+// CHECK: FunctionalizeControlFlowPass: Graph contains node with inputs predicated on incompatible predicates: {s(Cond:0,then)} and {s(Cond:0,else)}
+// CHECK-NEXT: for node {{[{][{]node Add[}][}]}}
+func @main() {
+  %0 = "_tf._TPUReplicate"() {computation = @foo, Tinputs = [], Tbroadcast_inputs = [], NumVariables = 0, Tguaranteed_constants = [], output_types = []} : () -> !_tf.control loc("_TPUReplicate")
+  return
+}
+
+func @foo() {
+  %0:2 = "_tf.Const"() {device = "", dtype = "tfdtype$DT_INT32", value = dense<17> : tensor<i32>} : () -> (tensor<i32>, !_tf.control) loc("x")
+  %1:2 = "_tf.Const"() {device = "", dtype = "tfdtype$DT_BOOL", value = dense<true> : tensor<i1>} : () -> (tensor<i1>, !_tf.control) loc("Cond")
+  %2:3 = "_tf.Switch"(%0#0, %1#0) {T = "tfdtype$DT_INT32", device = ""} : (tensor<i32>, tensor<i1>) -> (tensor<i32>, tensor<i32>, !_tf.control) loc("switch")
+  %3:2 = "_tf.Add"(%2#0, %2#1) {T = "tfdtype$DT_INT32", device = ""} : (tensor<i32>, tensor<i32>) -> (tensor<i32>, !_tf.control) loc("Add")
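+  // Note: "Add" consumes both outputs of the Switch, i.e. values predicated
+  // on the "then" and on the "else" branch at once, which is why
+  // FunctionalizeControlFlowPass rejects this graph (see CHECK lines above).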
!_tf.control) loc("Add") + %4:2 = "_tf.Mul"(%2#1, %2#0) {T = "tfdtype$DT_INT32", device = ""} : (tensor, tensor) -> (tensor, !_tf.control) loc("Square") + %5:3 = "_tf.Merge"(%3#0, %4#0) {N = 2 : i64, T = "tfdtype$DT_INT32", device = "", name = "_tf.Merge"} : (tensor, tensor) -> (tensor, tensor, !_tf.control) loc("Merge") + return +} diff --git a/tensorflow/compiler/mlir/tensorflow/tests/functionalize-if.mlir b/tensorflow/compiler/mlir/tensorflow/tests/functionalize-if.mlir new file mode 100644 index 00000000000..d3b2d835c27 --- /dev/null +++ b/tensorflow/compiler/mlir/tensorflow/tests/functionalize-if.mlir @@ -0,0 +1,34 @@ +// RUN: tf-opt %s --run-tf-graph-optimization --graph-passes=FunctionalizeControlFlowPass | FileCheck %s + +func @main() { + %0 = "_tf._TPUReplicate"() {computation = @foo, Tinputs = [], Tbroadcast_inputs = [], NumVariables = 0, Tguaranteed_constants = [], output_types = []} : () -> !_tf.control loc("_TPUReplicate") + return +} + +func @foo() { + %0:2 = "_tf.Const"() {dtype = "tfdtype$DT_INT32", value = dense<17> : tensor} : () -> (tensor, !_tf.control) loc("x") + %1:2 = "_tf.Const"() {dtype = "tfdtype$DT_BOOL", value = dense : tensor} : () -> (tensor, !_tf.control) loc("predicate") + %2:3 = "_tf.Switch"(%0#0, %1#0) {T = "tfdtype$DT_INT32"} : (tensor, tensor) -> (tensor, tensor, !_tf.control) loc("switch") + %3:2 = "_tf.Add"(%2#0, %2#0) {T = "tfdtype$DT_INT32"} : (tensor, tensor) -> (tensor, !_tf.control) loc("Addition") + %4:2 = "_tf.Mul"(%2#1, %2#1) {T = "tfdtype$DT_INT32"} : (tensor, tensor) -> (tensor, !_tf.control) loc("Multiplication") + %5:3 = "_tf.Merge"(%3#0, %4#0) {N = 2 : i64, T = "tfdtype$DT_INT32"} : (tensor, tensor) -> (tensor, tensor, !_tf.control) loc("Merge") + return +} + +// Match the name of the cloned function with functionalized control-flow at call site +// CHECK: func @main() +// CHECK-NEXT: computation = @[[FUNCTIONALIZE_FUNC:[A-Za-z0-9_]*]] + + +// In the newly cloned function, check that we have a _tf.If operation and capture the then and else branch. +// CHECK: func @[[FUNCTIONALIZE_FUNC]] +// CHECK: "_tf.If" +// CHECK-SAME: else_branch = @[[ELSE_FUNC:[A-Za-z0-9_]*]] +// CHECK-SAME: then_branch = @[[THEN_FUNC:[A-Za-z0-9_]*]] + +// We expect the _tf.Add in the else func and the _tf.Mul in the then func + +// CHECK: func @[[ELSE_FUNC]] +// CHECK: "_tf.Add" +// CHECK: func @[[THEN_FUNC]] +// CHECK: "_tf.Mul" diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/tf_graph_optimization_pass.cc b/tensorflow/compiler/mlir/tensorflow/transforms/tf_graph_optimization_pass.cc new file mode 100644 index 00000000000..1f75362117f --- /dev/null +++ b/tensorflow/compiler/mlir/tensorflow/transforms/tf_graph_optimization_pass.cc @@ -0,0 +1,154 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+
+#include <algorithm>
+#include <string>
+#include <vector>
+
+#include "llvm/Support/CommandLine.h"
+#include "mlir/IR/Builders.h"  // TF:local_config_mlir
+#include "mlir/IR/Identifier.h"  // TF:local_config_mlir
+#include "mlir/IR/Location.h"  // TF:local_config_mlir
+#include "mlir/Pass/Pass.h"  // TF:local_config_mlir
+#include "tensorflow/compiler/mlir/tensorflow/translate/export_graphdef.h"
+#include "tensorflow/compiler/mlir/tensorflow/translate/import_graphdef.h"
+#include "tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_flags.h"
+#include "tensorflow/core/common_runtime/optimization_registry.h"
+#include "tensorflow/core/framework/function.h"
+#include "tensorflow/core/graph/graph.h"
+#include "tensorflow/core/graph/graph_constructor.h"
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/protobuf/graph_debug_info.pb.h"
+#include "tensorflow/core/public/session_options.h"
+#include "tensorflow/stream_executor/lib/statusor.h"
+
+#define DEBUG_TYPE "run-tf-graph-optimization"
+
+// TODO(prakalps): Move these flags and pass registration to a header file so
+// that it is clear that this is a generic pass library, and that the command
+// line is used for testing only.
+
+// NOLINTNEXTLINE
+static llvm::cl::OptionCategory clOptionsCategory(DEBUG_TYPE " options");
+
+// NOLINTNEXTLINE
+static llvm::cl::list<std::string> cl_pass_list(
+    "graph-passes", llvm::cl::value_desc("list"),
+    llvm::cl::desc("comma-separated list of GraphOptimizationPasses to run."),
+    llvm::cl::CommaSeparated, llvm::cl::cat(clOptionsCategory));
+
+namespace tensorflow {
+
+// Creates a pass to convert MLIR to Graph, run user-specified Graph
+// Optimization Passes, and convert back to MLIR.
+// Constraints: This pass expects that all operations in the MLIR module
+// either belong to the 'tf' or '_tf' dialect. The output is in '_tf' dialect.
+class GraphOptPass : public mlir::ModulePass<GraphOptPass> {
+ public:
+  explicit GraphOptPass() : pass_names_(cl_pass_list) {}
+  explicit GraphOptPass(std::vector<std::string> pass_names)
+      : pass_names_(pass_names) {}
+
+ private:
+  // Returns a vector of all the passes requested by the user.
+  std::vector<GraphOptimizationPass*> FindPassIds();
+
+  void runOnModule() override;
+
+  // The Graph passes are executed in the order present in the pass_names_
+  // vector.
+  std::vector<std::string> pass_names_;
+};
+
+std::vector<GraphOptimizationPass*> GraphOptPass::FindPassIds() {
+  std::vector<GraphOptimizationPass*> pass_ids(pass_names_.size(), nullptr);
+
+  for (const auto& group : OptimizationPassRegistry::Global()->groups()) {
+    for (const auto& phase : group.second) {
+      for (const auto& pass : phase.second) {
+        // Iterate over the pass_names_ and insert the pass pointer at all the
+        // corresponding indices in the pass_ids vector.
+        auto iter = pass_names_.begin();
+        while ((iter = std::find(iter, pass_names_.end(), pass->name())) !=
+               pass_names_.end()) {
+          pass_ids[std::distance(pass_names_.begin(), iter)] = pass.get();
+          iter++;
+        }
+      }
+    }
+  }
+  return pass_ids;
+}
+
+void GraphOptPass::runOnModule() {
+  mlir::Module module_in = getModule();
+  mlir::MLIRContext& ctx = getContext();
+
+  // Convert MLIR to Graph
+  FunctionLibraryDefinition flib_def(OpRegistry::Global(),
+                                     FunctionDefLibrary());
+  ExporterConfigs confs;
+  auto graph = absl::make_unique<Graph>(flib_def);
+  Status status = ConvertMlirToGraph(module_in, confs, &graph, &flib_def);
+  if (!status.ok()) {
+    mlir::emitError(mlir::UnknownLoc::get(&ctx)) << status.error_message();
+    return signalPassFailure();
+  }
+
+  // Run each of the passes that were selected.
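+  // FindPassIds leaves a null entry for every requested pass name that did
+  // not match any pass registered in the global OptimizationPassRegistry;
+  // those are reported as errors in the loop below.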
+  std::vector<GraphOptimizationPass*> passes = FindPassIds();
+  GraphConstructorOptions opts;
+  opts.allow_internal_ops = true;
+  opts.expect_device_spec = false;
+
+  GraphOptimizationPassOptions options;
+  SessionOptions sess_options;
+  options.graph = &graph;
+  options.flib_def = &flib_def;
+  options.session_options = &sess_options;
+
+  for (auto p = passes.begin(), e = passes.end(); p != e; ++p) {
+    auto pass = *p;
+    if (pass == nullptr) {
+      mlir::emitError(mlir::UnknownLoc::get(&ctx))
+          << "Could not find pass "
+          << pass_names_[std::distance(passes.begin(), p)];
+      return signalPassFailure();
+    }
+    Status status = pass->Run(options);
+    if (!status.ok()) {
+      mlir::emitError(mlir::UnknownLoc::get(&ctx))
+          << pass->name() << ": " << status.error_message();
+      return signalPassFailure();
+    }
+  }
+
+  // Convert Graph to MLIR
+  GraphDebugInfo debug_info;
+  NodeSpecs specs;
+  auto module_or_status =
+      ConvertGraphToMlir(**options.graph, debug_info, flib_def, specs, &ctx);
+  if (!module_or_status.ok()) {
+    mlir::emitError(mlir::UnknownLoc::get(&ctx))
+        << module_or_status.status().error_message();
+    return signalPassFailure();
+  }
+  auto module_out = std::move(module_or_status).ValueOrDie();
+
+  // We cannot replace the module in a ModulePass. So we simply copy the
+  // Function list from module_out to module_in.
+  module_in.clear();
+  module_in.splice(module_in.getFunctions().end(), *module_out);
+}
+
+}  // namespace tensorflow
+
+static mlir::PassRegistration<tensorflow::GraphOptPass> pass(
+    DEBUG_TYPE, "runs passes registered as tensorflow::GraphOptimizationPass.");
diff --git a/tensorflow/compiler/tf2xla/BUILD b/tensorflow/compiler/tf2xla/BUILD
index ec853c09cfb..3827eeab60b 100644
--- a/tensorflow/compiler/tf2xla/BUILD
+++ b/tensorflow/compiler/tf2xla/BUILD
@@ -561,6 +561,7 @@ cc_library(
     srcs = [
         "functionalize_control_flow_pass_registration.cc",
     ],
+    visibility = [":friends"],
     deps = [
         ":functionalize_control_flow",
     ],
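
--
Example invocation (file names are placeholders), mirroring the RUN lines in
the new tests:

  tf-opt input.mlir --run-tf-graph-optimization \
    --graph-passes=FunctionalizeControlFlowPass -o output.mlir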