From 0505a9e2cbd43f11f99686141f388638b5fd5ab1 Mon Sep 17 00:00:00 2001
From: "A. Unique TensorFlower" <gardener@tensorflow.org>
Date: Fri, 13 Dec 2019 18:26:14 -0800
Subject: [PATCH] Fix experimental_compile=True for graph mode

Previously the attribute only worked in eager mode, and was a no-op otherwise.
Note that this also resolves the problem of a function with
experimental_compile=True not being compiled when called from an
experimental_compile=False context.
PiperOrigin-RevId: 285509461
Change-Id: I3f8d5611fea5b7430feba1c58f937e121d71b75c
---
 .../compiler/jit/xla_kernel_creator_util.cc   | 51 ++++++--------
 .../core/common_runtime/direct_session.cc     |  4 +-
 tensorflow/python/compiler/xla/BUILD          | 17 -----
 .../compiler/xla/experimental_compile_test.py | 68 -------------------
 4 files changed, 23 insertions(+), 117 deletions(-)
 delete mode 100644 tensorflow/python/compiler/xla/experimental_compile_test.py

diff --git a/tensorflow/compiler/jit/xla_kernel_creator_util.cc b/tensorflow/compiler/jit/xla_kernel_creator_util.cc
index 5c2ba56af7c..96bde65003f 100644
--- a/tensorflow/compiler/jit/xla_kernel_creator_util.cc
+++ b/tensorflow/compiler/jit/xla_kernel_creator_util.cc
@@ -23,9 +23,7 @@ limitations under the License.
 #include "tensorflow/compiler/jit/mark_for_compilation_pass.h"
 #include "tensorflow/compiler/tf2xla/const_analysis.h"
 #include "tensorflow/compiler/tf2xla/xla_op_registry.h"
-#include "tensorflow/core/common_runtime/function.h"
 #include "tensorflow/core/framework/node_def_builder.h"
-#include "tensorflow/core/framework/node_def_util.h"
 #include "tensorflow/core/lib/core/status.h"
 #include "tensorflow/core/util/ptr_util.h"
 
@@ -72,42 +70,38 @@ class SinglePassSearch {
 
 bool CanCreateXlaKernel(const FunctionLibraryRuntime& flr,
                         const NodeDef& node_def) {
-  VLOG(2) << "Called CanCreateXlaKernel, input: " << SummarizeNodeDef(node_def);
-  NameAttrList attr_list;
-  if (!NameAndAttrsFromFunctionCall(node_def, &attr_list).ok()) {
-    return false;
-  }
-  std::string func_name = attr_list.name();
   const FunctionDef* function_def =
-      flr.GetFunctionLibraryDefinition()->Find(func_name);
+      flr.GetFunctionLibraryDefinition()->Find(node_def.name());
   if (function_def == nullptr) {
     // The node def is not calling a function. Individual ops can be
     // run directly using on-demand mode, no need to create XlaLaunch
     // kernel for them.
-    VLOG(2) << "Not creating XlaLaunch kernel for " << func_name
-            << " because it does not seem to be a function";
     return false;
   }
 
   // If kXlaCompileAttr is set on the node_def, use its value.
   const auto& it = node_def.attr().find(kXlaCompileAttr);
   if (it != node_def.attr().end()) {
-    bool value = it->second.b();
-    VLOG(2) << "Found " << kXlaCompileAttr
-            << " attribute  with value = " << value
-            << " on node: " << SummarizeNodeDef(node_def);
-    return value;
+    return it->second.b();
   }
 
-  // Otherwise, look for it on the custom defition.
-  const auto& fit = function_def->attr().find(kXlaCompileAttr);
-  if (fit != function_def->attr().end()) {
-    bool value = fit->second.b();
-    VLOG(2) << "Found " << kXlaCompileAttr << " attribute on function "
-            << func_name << " with value = " << value;
-    return value;
+  // kXlaCompileAttr is not set on node_def; check whether it is set on
+  // the FunctionDef instead.
+  bool xla_compile = false;
+  Status status = flr.GetFunctionLibraryDefinition()->GetAttr(
+      node_def, kXlaCompileAttr, &xla_compile);
+  if (!status.ok() || !xla_compile) {
+    if (VLOG_IS_ON(3)) {
+      if (!status.ok()) {
+        VLOG(3) << "No " << kXlaCompileAttr << " attr defined for "
+                << node_def.op() << ". status=" << status.ToString();
+      } else {
+        VLOG(3) << node_def.op() << " is explicitly marked not to be compiled";
+      }
+    }
+    return false;
   }
-  return false;
+  return true;
 }
 
 // Given a FunctionLibraryRuntime and a NodeDef calling a function in the
@@ -124,11 +118,8 @@ Status GetBodyAndConstantsAndResources(FunctionLibraryRuntime* flr,
   FunctionLibraryRuntime::Handle handle;
   // If node_def is not instantiable, e.g., the function does not exist,
   // simply bail out.
-  NameAttrList function;
-  TF_RETURN_IF_ERROR(NameAndAttrsFromFunctionCall(node_def, &function));
-
   TF_RETURN_IF_ERROR(
-      flr->Instantiate(function.name(), AttrSlice(&function.attr()), &handle));
+      flr->Instantiate(node_def.op(), AttrSlice(&node_def.attr()), &handle));
   *fbody = flr->GetFunctionBody(handle);
   CHECK(*fbody);  // Can't be nullptr since we just instantiated it.
   const DataTypeVector& arg_types = (*fbody)->arg_types;
@@ -250,7 +241,9 @@ Status CreateXlaKernel(FunctionLibraryRuntime* flr, const NodeDef& node_def,
 
   // Create the kernel.
   NameAttrList function;
-  TF_RETURN_IF_ERROR(NameAndAttrsFromFunctionCall(node_def, &function));
+  function.set_name(node_def.op());
+  *(function.mutable_attr()) = node_def.attr();
+
   Device* dev = flr->device();
   Status s;
   OpKernelConstruction construction(
diff --git a/tensorflow/core/common_runtime/direct_session.cc b/tensorflow/core/common_runtime/direct_session.cc
index 1f20e201b7e..c836cb23898 100644
--- a/tensorflow/core/common_runtime/direct_session.cc
+++ b/tensorflow/core/common_runtime/direct_session.cc
@@ -1302,12 +1302,10 @@ Status DirectSession::CreateExecutors(
       options_.config.experimental().has_session_metadata()
           ? &options_.config.experimental().session_metadata()
           : nullptr;
-  const CustomKernelCreator* custom_kernel_creator =
-      GetDefaultCustomKernelCreator();
   func_info->proc_flr.reset(new ProcessFunctionLibraryRuntime(
       device_mgr_.get(), options_.env, &options_.config, graph_def_version,
       func_info->flib_def.get(), optimizer_opts, thread_pools_[0].first,
-      nullptr, custom_kernel_creator, session_metadata));
+      nullptr, nullptr, session_metadata));
 
   GraphOptimizer optimizer(optimizer_opts);
   for (auto iter = graphs.begin(); iter != graphs.end(); ++iter) {
diff --git a/tensorflow/python/compiler/xla/BUILD b/tensorflow/python/compiler/xla/BUILD
index 9f7afae4052..2061f0cca2f 100644
--- a/tensorflow/python/compiler/xla/BUILD
+++ b/tensorflow/python/compiler/xla/BUILD
@@ -91,20 +91,3 @@ cuda_py_test(
         "@absl_py//absl/testing:parameterized",
     ],
 )
-
-cuda_py_test(
-    name = "experimental_compile_test",
-    srcs = ["experimental_compile_test.py"],
-    additional_deps = [
-        "//tensorflow/python:client_testlib",
-        "//tensorflow/python:constant_op",
-        "//tensorflow/python:framework_ops",
-        "//tensorflow/python:resource_variable_ops",
-    ],
-    python_version = "PY3",
-    tags = [
-        "no_mac",
-        "no_windows",
-    ],
-    xla_enabled = True,
-)
diff --git a/tensorflow/python/compiler/xla/experimental_compile_test.py b/tensorflow/python/compiler/xla/experimental_compile_test.py
deleted file mode 100644
index c8ea3d05001..00000000000
--- a/tensorflow/python/compiler/xla/experimental_compile_test.py
+++ /dev/null
@@ -1,68 +0,0 @@
-# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-from tensorflow.python.client import session
-from tensorflow.python.eager import def_function
-from tensorflow.python.framework import dtypes
-from tensorflow.python.framework import errors
-from tensorflow.python.framework import ops
-from tensorflow.python.ops import array_ops
-from tensorflow.python.platform import test
-
-
-class ExperimentalCompileTest(test.TestCase):
-
-  def testBasic(self):
-    with ops.Graph().as_default() as g:
-
-      def fn(x, a):
-        return x + a
-
-      xla_func = def_function.function(fn, experimental_compile=True)
-      inputs = array_ops.placeholder(dtypes.float32, [5])
-      # XLA support is not yet enabled for TF ROCm
-      if not test.is_built_with_rocm():
-        x = xla_func(inputs, 1)
-        with session.Session(graph=g) as sess:
-          y = sess.run(x, feed_dict={inputs: [1, 2, 2, 3, 3]})
-          self.assertTrue(
-              x.graph.as_graph_def().library.function[0].attr["_XlaCompile"].b)
-          self.assertAllClose([2, 3, 3, 4, 4], y)
-
-  # Checking that we crash on an unsupported operation lets us test that the XLA
-  # compiler was actually invoked.
-  def testUnsupportedOps(self):
-    with ops.Graph().as_default() as g:
-
-      def fn(x):
-        return array_ops.unique(x).y  # Unique is not supported by XLA
-
-      xla_func = def_function.function(fn, experimental_compile=True)
-      inputs = array_ops.placeholder(dtypes.float32, [5])
-      x = xla_func(inputs)
-      # XLA support is not yet enabled for TF ROCm
-      if not test.is_built_with_rocm():
-        with self.assertRaisesRegexp(errors.InvalidArgumentError,
-                                     "not compilable"):
-          with session.Session(graph=g) as sess:
-            sess.run(x, feed_dict={inputs: [1, 2, 2, 3, 3]})
-
-
-if __name__ == "__main__":
-  test.main()