diff --git a/tensorflow/compiler/jit/xla_kernel_creator_util.cc b/tensorflow/compiler/jit/xla_kernel_creator_util.cc index 5c2ba56af7c..96bde65003f 100644 --- a/tensorflow/compiler/jit/xla_kernel_creator_util.cc +++ b/tensorflow/compiler/jit/xla_kernel_creator_util.cc @@ -23,9 +23,7 @@ limitations under the License. #include "tensorflow/compiler/jit/mark_for_compilation_pass.h" #include "tensorflow/compiler/tf2xla/const_analysis.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" -#include "tensorflow/core/common_runtime/function.h" #include "tensorflow/core/framework/node_def_builder.h" -#include "tensorflow/core/framework/node_def_util.h" #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/util/ptr_util.h" @@ -72,42 +70,38 @@ class SinglePassSearch { bool CanCreateXlaKernel(const FunctionLibraryRuntime& flr, const NodeDef& node_def) { - VLOG(2) << "Called CanCreateXlaKernel, input: " << SummarizeNodeDef(node_def); - NameAttrList attr_list; - if (!NameAndAttrsFromFunctionCall(node_def, &attr_list).ok()) { - return false; - } - std::string func_name = attr_list.name(); const FunctionDef* function_def = - flr.GetFunctionLibraryDefinition()->Find(func_name); + flr.GetFunctionLibraryDefinition()->Find(node_def.name()); if (function_def == nullptr) { // The node def is not calling a function. Individual ops can be // run directly using on-demand mode, no need to create XlaLaunch // kernel for them. - VLOG(2) << "Not creating XlaLaunch kernel for " << func_name - << " because it does not seem to be a function"; return false; } // If kXlaCompileAttr is set on the node_def, use its value. const auto& it = node_def.attr().find(kXlaCompileAttr); if (it != node_def.attr().end()) { - bool value = it->second.b(); - VLOG(2) << "Found " << kXlaCompileAttr - << " attribute with value = " << value - << " on node: " << SummarizeNodeDef(node_def); - return value; + return it->second.b(); } - // Otherwise, look for it on the custom defition. - const auto& fit = function_def->attr().find(kXlaCompileAttr); - if (fit != function_def->attr().end()) { - bool value = fit->second.b(); - VLOG(2) << "Found " << kXlaCompileAttr << " attribute on function " - << func_name << " with value = " << value; - return value; + // kXlaCompileAttr is not set on node_def, check if it is set on + // FunctionDef. + bool xla_compile = false; + Status status = flr.GetFunctionLibraryDefinition()->GetAttr( + node_def, kXlaCompileAttr, &xla_compile); + if (!status.ok() || !xla_compile) { + if (VLOG_IS_ON(3)) { + if (!status.ok()) { + VLOG(3) << "No " << kXlaCompileAttr << " attr defined for " + << node_def.op() << ". status=" << status.ToString(); + } else { + VLOG(3) << node_def.op() << " is explicitly marked not to be compiled"; + } + } + return false; } - return false; + return true; } // Given a FunctionLibraryRuntime and a NodeDef calling a function in the @@ -124,11 +118,8 @@ Status GetBodyAndConstantsAndResources(FunctionLibraryRuntime* flr, FunctionLibraryRuntime::Handle handle; // If node_def is not instantiable, e.g., the function does not exist, // simply bail out. - NameAttrList function; - TF_RETURN_IF_ERROR(NameAndAttrsFromFunctionCall(node_def, &function)); - TF_RETURN_IF_ERROR( - flr->Instantiate(function.name(), AttrSlice(&function.attr()), &handle)); + flr->Instantiate(node_def.op(), AttrSlice(&node_def.attr()), &handle)); *fbody = flr->GetFunctionBody(handle); CHECK(*fbody); // Can't be nullptr since we just instantiated it. const DataTypeVector& arg_types = (*fbody)->arg_types; @@ -250,7 +241,9 @@ Status CreateXlaKernel(FunctionLibraryRuntime* flr, const NodeDef& node_def, // Create the kernel. NameAttrList function; - TF_RETURN_IF_ERROR(NameAndAttrsFromFunctionCall(node_def, &function)); + function.set_name(node_def.op()); + *(function.mutable_attr()) = node_def.attr(); + Device* dev = flr->device(); Status s; OpKernelConstruction construction( diff --git a/tensorflow/core/common_runtime/direct_session.cc b/tensorflow/core/common_runtime/direct_session.cc index 1f20e201b7e..c836cb23898 100644 --- a/tensorflow/core/common_runtime/direct_session.cc +++ b/tensorflow/core/common_runtime/direct_session.cc @@ -1302,12 +1302,10 @@ Status DirectSession::CreateExecutors( options_.config.experimental().has_session_metadata() ? &options_.config.experimental().session_metadata() : nullptr; - const CustomKernelCreator* custom_kernel_creator = - GetDefaultCustomKernelCreator(); func_info->proc_flr.reset(new ProcessFunctionLibraryRuntime( device_mgr_.get(), options_.env, &options_.config, graph_def_version, func_info->flib_def.get(), optimizer_opts, thread_pools_[0].first, - nullptr, custom_kernel_creator, session_metadata)); + nullptr, nullptr, session_metadata)); GraphOptimizer optimizer(optimizer_opts); for (auto iter = graphs.begin(); iter != graphs.end(); ++iter) { diff --git a/tensorflow/python/compiler/xla/BUILD b/tensorflow/python/compiler/xla/BUILD index 9f7afae4052..2061f0cca2f 100644 --- a/tensorflow/python/compiler/xla/BUILD +++ b/tensorflow/python/compiler/xla/BUILD @@ -91,20 +91,3 @@ cuda_py_test( "@absl_py//absl/testing:parameterized", ], ) - -cuda_py_test( - name = "experimental_compile_test", - srcs = ["experimental_compile_test.py"], - additional_deps = [ - "//tensorflow/python:client_testlib", - "//tensorflow/python:constant_op", - "//tensorflow/python:framework_ops", - "//tensorflow/python:resource_variable_ops", - ], - python_version = "PY3", - tags = [ - "no_mac", - "no_windows", - ], - xla_enabled = True, -) diff --git a/tensorflow/python/compiler/xla/experimental_compile_test.py b/tensorflow/python/compiler/xla/experimental_compile_test.py deleted file mode 100644 index c8ea3d05001..00000000000 --- a/tensorflow/python/compiler/xla/experimental_compile_test.py +++ /dev/null @@ -1,68 +0,0 @@ -# Copyright 2019 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -from tensorflow.python.client import session -from tensorflow.python.eager import def_function -from tensorflow.python.framework import dtypes -from tensorflow.python.framework import errors -from tensorflow.python.framework import ops -from tensorflow.python.ops import array_ops -from tensorflow.python.platform import test - - -class ExperimentalCompileTest(test.TestCase): - - def testBasic(self): - with ops.Graph().as_default() as g: - - def fn(x, a): - return x + a - - xla_func = def_function.function(fn, experimental_compile=True) - inputs = array_ops.placeholder(dtypes.float32, [5]) - # XLA support is not yet enabled for TF ROCm - if not test.is_built_with_rocm(): - x = xla_func(inputs, 1) - with session.Session(graph=g) as sess: - y = sess.run(x, feed_dict={inputs: [1, 2, 2, 3, 3]}) - self.assertTrue( - x.graph.as_graph_def().library.function[0].attr["_XlaCompile"].b) - self.assertAllClose([2, 3, 3, 4, 4], y) - - # Checking that we crash on an unsupported operation lets us test that the XLA - # compiler was actually invoked. - def testUnsupportedOps(self): - with ops.Graph().as_default() as g: - - def fn(x): - return array_ops.unique(x).y # Unique is not supported by XLA - - xla_func = def_function.function(fn, experimental_compile=True) - inputs = array_ops.placeholder(dtypes.float32, [5]) - x = xla_func(inputs) - # XLA support is not yet enabled for TF ROCm - if not test.is_built_with_rocm(): - with self.assertRaisesRegexp(errors.InvalidArgumentError, - "not compilable"): - with session.Session(graph=g) as sess: - sess.run(x, feed_dict={inputs: [1, 2, 2, 3, 3]}) - - -if __name__ == "__main__": - test.main()