From a05f8bdf0e964d27cab4dfbe36576b7692f055da Mon Sep 17 00:00:00 2001
From: Feng Liu <fengliuai@google.com>
Date: Mon, 12 Oct 2020 16:36:09 -0700
Subject: [PATCH] Add the op expansion graph optimization pass to tensorflow

PiperOrigin-RevId: 336770266
Change-Id: Iadc1917c9be25bc9010129af551814ae72160347
---
 tensorflow/compiler/mlir/tfr/BUILD            |  22 +++-
 .../tfr/integration/graph_decompose_pass.cc   |  19 ++-
 .../tfr/integration/graph_decompose_pass.h    |  11 +-
 .../tfr/integration/graph_decompose_test.py   |  90 +++++++++++++
 .../mlir/tfr/integration/tfr_decompose_ctx.cc |   9 --
 .../mlir/tfr/integration/tfr_decompose_ctx.h  |   5 -
 .../tfr/integration/tfr_decompose_ctx_test.cc |   1 -
 tensorflow/compiler/mlir/tfr/resources/BUILD  |  21 ++-
 tensorflow/compiler/mlir/tfr/tfr.bzl          | 120 ++++++++++++++++++
 9 files changed, 267 insertions(+), 31 deletions(-)
 create mode 100644 tensorflow/compiler/mlir/tfr/integration/graph_decompose_test.py
 create mode 100644 tensorflow/compiler/mlir/tfr/tfr.bzl

diff --git a/tensorflow/compiler/mlir/tfr/BUILD b/tensorflow/compiler/mlir/tfr/BUILD
index 420648dff54..63ce46b853c 100644
--- a/tensorflow/compiler/mlir/tfr/BUILD
+++ b/tensorflow/compiler/mlir/tfr/BUILD
@@ -18,7 +18,8 @@ package_group(
     includes = ["//third_party/mlir:subpackages"],
     packages = [
         "//learning/brain/experimental/mlir/tfr/...",
-        "//tensorflow/compiler/mlir/...",
+        "//tensorflow/c/...",
+        "//tensorflow/compiler/...",
     ],
 )
 
@@ -179,7 +180,6 @@ cc_library(
         "//tensorflow/compiler/mlir/tfr:passes",
         "//tensorflow/core:framework",
         "//tensorflow/core:graph",
-        "//tensorflow/core:lib",
         "//tensorflow/core:protos_all_cc",
         "//tensorflow/core/common_runtime:optimization_registry",
         "//tensorflow/stream_executor/lib",
@@ -201,7 +201,6 @@ tf_cc_test(
         ":tfr_decompose_ctx",
         "//tensorflow/compiler/xla:test",
         "//tensorflow/core:framework",
-        "//tensorflow/core:lib",
         "//tensorflow/core:protos_all_cc",
         "//tensorflow/core:test",
         "//tensorflow/core:test_main",
@@ -216,7 +215,6 @@ cc_library(
     name = "graph_decompose_pass",
     srcs = ["integration/graph_decompose_pass.cc"],
     hdrs = ["integration/graph_decompose_pass.h"],
-    data = ["//tensorflow/compiler/mlir/tfr/resources:decomposition_lib"],
     deps = [
         ":tfr_decompose_ctx",
         "//tensorflow/compiler/mlir:mlir_graph_optimization_pass",
@@ -228,6 +226,22 @@ cc_library(
     alwayslink = 1,
 )
 
+tf_py_test(
+    name = "graph_decompose_test",
+    size = "small",
+    srcs = ["integration/graph_decompose_test.py"],
+    data = ["//tensorflow/compiler/mlir/tfr/resources:decomposition_lib"],
+    python_version = "PY3",
+    tags = [
+        "no_oss",
+        "notap",
+    ],
+    deps = [
+        "//tensorflow/compiler/mlir/tfr/resources:composite_ops",
+        "//tensorflow/python/eager:def_function",
+    ],
+)
+
 tf_python_pybind_extension(
     name = "tfr_wrapper",
     srcs = ["python/tfr_wrapper.cc"],
diff --git a/tensorflow/compiler/mlir/tfr/integration/graph_decompose_pass.cc b/tensorflow/compiler/mlir/tfr/integration/graph_decompose_pass.cc
index 9fd7ee03cb9..be0fab13021 100644
--- a/tensorflow/compiler/mlir/tfr/integration/graph_decompose_pass.cc
+++ b/tensorflow/compiler/mlir/tfr/integration/graph_decompose_pass.cc
@@ -15,7 +15,6 @@ limitations under the License.
 #include "tensorflow/compiler/mlir/tfr/integration/graph_decompose_pass.h"
 
 #include "tensorflow/compiler/mlir/tfr/integration/tfr_decompose_ctx.h"
-#include "tensorflow/core/lib/core/status.h"
 #include "tensorflow/core/platform/env.h"
 #include "tensorflow/core/platform/path.h"
 #include "tensorflow/core/util/env_var.h"
@@ -23,8 +22,21 @@ limitations under the License.
 
 namespace tensorflow {
 
+constexpr const char* const kTFRLibEnv = "TF_MLIR_TFR_LIB_DIR";
+
+bool GraphDecomposePass::IsEnabled(const ConfigProto& config_proto) const {
+  // Opt-in: the pass is enabled only when TF_MLIR_TFR_LIB_DIR is set.
+  return getenv(kTFRLibEnv) != nullptr;
+}
+
 Status GraphDecomposePass::Run(const ConfigProto& config_proto,
                                mlir::ModuleOp module) {
+  if (!IsEnabled(config_proto)) {
+    VLOG(1) << "Skipping Graph Decomposition Pass, decomposition library was "
+               "not found";
+    return Status::OK();
+  }
+  VLOG(1) << "Run Graph Decomposition Passes";
   TF_ASSIGN_OR_RETURN(ctx_, LoadDecompositionLib(module.getContext()));
   TF_RETURN_IF_ERROR(ctx_->Decompose(module));
   return ctx_->Destroy();
@@ -35,8 +47,7 @@ GraphDecomposePass::LoadDecompositionLib(mlir::MLIRContext* mlir_ctx) {
   Env* env = Env::Default();
   std::string tfr_lib_dir;
   TF_RETURN_IF_ERROR(ReadStringFromEnvVar(
-      "TF_MLIR_TFR_LIB_DIR", "tensorflow/compiler/mlir/tfr/resources",
-      &tfr_lib_dir));
+      kTFRLibEnv, "tensorflow/compiler/mlir/tfr/resources", &tfr_lib_dir));
   string composite_mlir_dir = io::JoinPath(env->GetRunfilesDir(), tfr_lib_dir);
   std::vector<string> files;
   TF_RETURN_IF_ERROR(env->GetChildren(composite_mlir_dir, &files));
@@ -59,7 +70,7 @@ GraphDecomposePass::LoadDecompositionLib(mlir::MLIRContext* mlir_ctx) {
 }
 
 namespace {
-constexpr int kMlirGraphDecomposePassPriority = 1;
+constexpr int kMlirGraphDecomposePassPriority = -1;
 
 static mlir_pass_registration::MlirOptimizationPassRegistration
     register_mlir_graph_decompose_pass(kMlirGraphDecomposePassPriority,
diff --git a/tensorflow/compiler/mlir/tfr/integration/graph_decompose_pass.h b/tensorflow/compiler/mlir/tfr/integration/graph_decompose_pass.h
index f0963379928..89db74c72ed 100644
--- a/tensorflow/compiler/mlir/tfr/integration/graph_decompose_pass.h
+++ b/tensorflow/compiler/mlir/tfr/integration/graph_decompose_pass.h
@@ -16,9 +16,9 @@ limitations under the License.
 #define TENSORFLOW_COMPILER_MLIR_TFR_INTEGRATION_GRAPH_DECOMPOSE_PASS_H_
 
 #include "mlir/IR/MLIRContext.h"  // from @llvm-project
-#include "tensorflow/compiler/mlir/tfr/integration/tfr_decompose_ctx.h"
 #include "tensorflow/compiler/mlir/mlir_graph_optimization_pass.h"
-#include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/compiler/mlir/tfr/integration/tfr_decompose_ctx.h"
+#include "tensorflow/stream_executor/lib/statusor.h"
 
 namespace tensorflow {
 
@@ -30,10 +30,9 @@ class GraphDecomposePass : public MlirOptimizationPass {
  public:
   llvm::StringRef name() const override { return "tfr"; }
 
-  bool IsEnabled(const ConfigProto& config_proto) const override {
-    // TODO(fengliuai): make a new flag in config_proto.experimental()
-    return true;
-  }
+  // Whether to run this pass. If this is enabled, the GraphDef will be imported
+  // to MLIR even if no tf composition file is found.
+  bool IsEnabled(const ConfigProto& config_proto) const override;
 
   // This should be used as a thin mapper around mlir::ModulePass::runOnModule
   // API integrated with the Tensorflow runtime.
diff --git a/tensorflow/compiler/mlir/tfr/integration/graph_decompose_test.py b/tensorflow/compiler/mlir/tfr/integration/graph_decompose_test.py
new file mode 100644
index 00000000000..a03de698cb1
--- /dev/null
+++ b/tensorflow/compiler/mlir/tfr/integration/graph_decompose_test.py
@@ -0,0 +1,90 @@
+# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Tests for third_party.tensorflow.compiler.mlir.tfr.integration.graph_decompose."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import os
+
+from tensorflow.compiler.mlir.tfr.resources import gen_composite_ops
+from tensorflow.python.eager import def_function
+from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import load_library
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import nn_ops
+from tensorflow.python.platform import test
+
+_lib_dir = os.path.dirname(gen_composite_ops.__file__)
+_lib_name = os.path.basename(gen_composite_ops.__file__)[4:].replace(
+    '.py', '.so')
+load_library.load_op_library(os.path.join(_lib_dir, _lib_name))
+
+
+class GraphDecomposeTest(test.TestCase):
+
+  def setUp(self):
+    os.environ['TF_MLIR_TFR_LIB_DIR'] = 'tensorflow/compiler/mlir/tfr/resources'
+    super(GraphDecomposeTest, self).setUp()
+
+  def tearDown(self):
+    os.environ.pop('TF_MLIR_TFR_LIB_DIR')
+    super(GraphDecomposeTest, self).tearDown()
+
+  def testAddN(self):
+    add_n = def_function.function(gen_composite_ops.my_add_n)
+    a = constant_op.constant([[1.0, 2.0], [3.0, 4.0]])
+    b = constant_op.constant([[1.0, 2.0], [3.0, 4.0]])
+    c = constant_op.constant([[1.0, 2.0], [3.0, 4.0]])
+    sum_of_one = add_n([a])
+    sum_of_two = add_n([a, b])
+    sum_of_three = add_n([a, b, c])
+    self.assertAllEqual(sum_of_one.numpy().reshape(-1), [1, 2, 3, 4])
+    self.assertAllEqual(sum_of_two.numpy().reshape(-1), [2, 4, 6, 8])
+    self.assertAllEqual(sum_of_three.numpy().reshape(-1), [3, 6, 9, 12])
+
+  def testBiasedDense(self):
+    dense_fn = def_function.function(gen_composite_ops.my_biased_dense)
+    x = constant_op.constant([[1.0, 2.0], [3.0, 4.0]])
+    y = constant_op.constant([[1.0, 2.0], [3.0, 4.0]])
+    z = constant_op.constant([[-10.0, -10.0], [-10.0, -10.0]])
+    actual = dense_fn(x, y, z)
+    self.assertAllEqual(actual.numpy().reshape(-1), [-3, 0, 5, 12])
+
+  def testBiasedDenseRelu(self):
+    dense_fn = def_function.function(gen_composite_ops.my_biased_dense)
+    x = constant_op.constant([[1.0, 2.0], [3.0, 4.0]])
+    y = constant_op.constant([[1.0, 2.0], [3.0, 4.0]])
+    z = constant_op.constant([[-10.0, -10.0], [-10.0, -10.0]])
+    actual = dense_fn(x, y, z, act='relu')
+    self.assertAllEqual(actual.numpy().reshape(-1), [0, 0, 5, 12])
+
+  def testWithKnownKernel(self):
+
+    @def_function.function
+    def biased_dense_elu(x, y, z):
+      dense = gen_composite_ops.my_biased_dense(x, y, z)
+      return nn_ops.elu(dense)  # elu has a known kernel; should not expand.
+
+    x = constant_op.constant([[1.0, 2.0], [3.0, 4.0]])
+    y = constant_op.constant([[1.0, 2.0], [3.0, 4.0]])
+    z = constant_op.constant([[-10.0, -10.0], [-10.0, -10.0]])
+    actual = biased_dense_elu(x, y, z)
+    self.assertAllClose(actual.numpy().reshape(-1), [-0.950213, 0, 5, 12])
+
+
+if __name__ == '__main__':
+  ops.enable_eager_execution()
+  test.main()
diff --git a/tensorflow/compiler/mlir/tfr/integration/tfr_decompose_ctx.cc b/tensorflow/compiler/mlir/tfr/integration/tfr_decompose_ctx.cc
index 7a2962c7b67..4cc7d90f17b 100644
--- a/tensorflow/compiler/mlir/tfr/integration/tfr_decompose_ctx.cc
+++ b/tensorflow/compiler/mlir/tfr/integration/tfr_decompose_ctx.cc
@@ -36,18 +36,9 @@ limitations under the License.
 #include "tensorflow/compiler/mlir/tensorflow/translate/tf_mlir_translate.h"
 #include "tensorflow/compiler/mlir/tfr/ir/tfr_ops.h"
 #include "tensorflow/compiler/mlir/tfr/passes/passes.h"
-#include "tensorflow/core/framework/graph.pb.h"
-#include "tensorflow/core/framework/node_def.pb.h"
-#include "tensorflow/core/framework/node_def_util.h"
-#include "tensorflow/core/framework/versions.pb.h"
 #include "tensorflow/core/graph/graph.h"
 #include "tensorflow/core/graph/node_builder.h"
 #include "tensorflow/core/graph/tensor_id.h"
-#include "tensorflow/core/lib/core/errors.h"
-#include "tensorflow/core/lib/strings/str_util.h"
-#include "tensorflow/core/platform/errors.h"
-#include "tensorflow/core/platform/protobuf.h"
-#include "tensorflow/core/platform/types.h"
 #include "tensorflow/core/protobuf/graph_debug_info.pb.h"
 #include "tensorflow/core/protobuf/struct.pb.h"
 #include "tensorflow/stream_executor/lib/statusor.h"
diff --git a/tensorflow/compiler/mlir/tfr/integration/tfr_decompose_ctx.h b/tensorflow/compiler/mlir/tfr/integration/tfr_decompose_ctx.h
index b51d6158eb2..9ff314ebc92 100644
--- a/tensorflow/compiler/mlir/tfr/integration/tfr_decompose_ctx.h
+++ b/tensorflow/compiler/mlir/tfr/integration/tfr_decompose_ctx.h
@@ -20,13 +20,8 @@ limitations under the License.
 #include "mlir/IR/Module.h"  // from @llvm-project
 #include "mlir/Pass/PassManager.h"  // from @llvm-project
 #include "tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_flags.h"
-#include "tensorflow/core/common_runtime/optimization_registry.h"
 #include "tensorflow/core/framework/function.h"
 #include "tensorflow/core/framework/graph.pb.h"
-#include "tensorflow/core/framework/node_def.pb.h"
-#include "tensorflow/core/framework/node_def_builder.h"
-#include "tensorflow/core/framework/types.pb.h"
-#include "tensorflow/core/platform/stringpiece.h"
 #include "tensorflow/core/protobuf/graph_debug_info.pb.h"
 #include "tensorflow/stream_executor/lib/statusor.h"
 
diff --git a/tensorflow/compiler/mlir/tfr/integration/tfr_decompose_ctx_test.cc b/tensorflow/compiler/mlir/tfr/integration/tfr_decompose_ctx_test.cc
index 8b1b0453cff..5d10936e092 100644
--- a/tensorflow/compiler/mlir/tfr/integration/tfr_decompose_ctx_test.cc
+++ b/tensorflow/compiler/mlir/tfr/integration/tfr_decompose_ctx_test.cc
@@ -29,7 +29,6 @@ limitations under the License.
 #include "tensorflow/core/framework/node_def_builder.h"
 #include "tensorflow/core/framework/types.h"
 #include "tensorflow/core/framework/types.pb.h"
-#include "tensorflow/core/lib/core/errors.h"
 #include "tensorflow/core/lib/core/status_test_util.h"
 #include "tensorflow/core/platform/test.h"
 #include "tensorflow/stream_executor/lib/statusor.h"
diff --git a/tensorflow/compiler/mlir/tfr/resources/BUILD b/tensorflow/compiler/mlir/tfr/resources/BUILD
index 0d0705b1da0..62ca65c5b57 100644
--- a/tensorflow/compiler/mlir/tfr/resources/BUILD
+++ b/tensorflow/compiler/mlir/tfr/resources/BUILD
@@ -33,14 +33,31 @@ cc_library(
     alwayslink = 1,
 )
 
+tf_custom_op_library(
+    name = "composite_ops.so",
+    srcs = [
+        "composite_ops.cc",
+    ],
+)
+
 tf_gen_op_wrapper_py(
-    name = "composite_ops",
-    out = "composite_ops.py",
+    name = "gen_composite_ops",
+    out = "gen_composite_ops.py",
     deps = [
         ":composite_ops_cc",
     ],
 )
 
+tf_custom_op_py_library(
+    name = "composite_ops",
+    dso = [":composite_ops.so"],
+    kernels = [":composite_ops_cc"],
+    visibility = ["//visibility:public"],
+    deps = [
+        ":gen_composite_ops",
+    ],
+)
+
 cc_library(
     name = "test_ops_cc",
     srcs = ["test_ops.cc"],
diff --git a/tensorflow/compiler/mlir/tfr/tfr.bzl b/tensorflow/compiler/mlir/tfr/tfr.bzl
new file mode 100644
index 00000000000..cc1b617f932
--- /dev/null
+++ b/tensorflow/compiler/mlir/tfr/tfr.bzl
@@ -0,0 +1,120 @@
+"""BUILD extension for TF composition project."""
+
+load("//tensorflow:tensorflow.bzl", "py_binary", "tf_gen_op_wrapper_py")
+load("//tensorflow:tensorflow.google.bzl", "pytype_library")
+
+def gen_op_libraries(
+        name,
+        src,
+        deps,
+        tags = [],
+        test = False):
+    """gen_op_libraries() generates all cc and py libraries for composite op source.
+
+    Args:
+        name: used as the name component of all the generated libraries.
+        src: Python (.py) file that contains the composite ops.
+        deps: Libraries the 'src' depends on.
+        tags: Tags applied to the generated genrules.
+        test: If True, the generated cc library is marked testonly.
+    """
+    if not src.endswith(".py"):
+        fail("'src' %s must be a Python (.py) file." % src)
+    if name == src[:-3]:
+        fail("'src' %s conflicts with op Python wrapper. Rename it to be different from 'name'." % src)
+
+    # Binary built from 'src'; the genrule below runs it to produce <name>.inc.cc.
+    gen_op_lib_exec = src[:-3]
+    py_binary(
+        name = gen_op_lib_exec,
+        srcs = [src],
+        srcs_version = "PY2AND3",
+        python_version = "PY3",
+        deps = [
+            "//tensorflow/python:platform",
+        ] + deps,
+    )
+
+    register_op = "register_" + name
+    native.genrule(
+        name = register_op,
+        srcs = [],
+        outs = [name + ".inc.cc"],
+        cmd = "$(location %s) --output=$@ --gen_register_op=true" % gen_op_lib_exec,
+        exec_tools = [":" + gen_op_lib_exec],
+        local = 1,
+        tags = tags,
+    )
+
+    native.cc_library(
+        name = name + "_cc",
+        testonly = test,
+        srcs = [":" + register_op],
+        copts = [
+            "-Wno-unused-result",
+            "-Wno-unused-variable",
+        ],
+        deps = [
+            "//tensorflow/core:framework",
+            "//tensorflow/core:lib",
+            "//tensorflow/core:protos_all_cc",
+        ],
+        alwayslink = 1,
+    )
+
+    tf_gen_op_wrapper_py(
+        name = name,
+        out = name + ".py",
+        deps = [
+            ":%s_cc" % name,
+        ],
+    )
+
+    pytype_library(
+        name = name + "_grads",
+        srcs = [src],
+        srcs_version = "PY2AND3",
+        deps = [
+            "//third_party/py/numpy",
+            "//third_party/py/tensorflow",
+        ] + deps,
+    )
+
+    pytype_library(
+        name = name + "_lib",
+        srcs = [name + ".py"],
+        srcs_version = "PY2AND3",
+        deps = [
+            ":%s" % name,
+            ":%s_cc" % name,
+            ":%s_grads" % name,
+            "//third_party/py/numpy",
+            "//third_party/py/tensorflow",
+        ] + deps,
+    )
+
+    # Link the register op and rebuild the binary
+    gen_tfr_lib_exec = gen_op_lib_exec + "_registered"
+    py_binary(
+        name = gen_tfr_lib_exec,
+        main = src,
+        srcs = [src],
+        srcs_version = "PY2AND3",
+        python_version = "PY3",
+        deps = [
+            "//tensorflow/python:platform",
+            ":%s" % name + "_cc",
+        ] + deps,
+    )
+
+    # Runs the rebuilt binary with --gen_register_op=false to produce <name>.mlir.
+    op_tfr = "composite_" + name
+    native.genrule(
+        name = op_tfr,
+        srcs = [],
+        outs = [name + ".mlir"],
+        cmd = "$(location %s) --output=$@ --gen_register_op=false" % gen_tfr_lib_exec,
+        exec_tools = [":" + gen_tfr_lib_exec],
+        local = 1,
+        tags = tags,
+    )