diff --git a/tensorflow/compiler/jit/xla_kernel_creator.cc b/tensorflow/compiler/jit/xla_kernel_creator.cc
index e3706a09278..23bd7425dbd 100644
--- a/tensorflow/compiler/jit/xla_kernel_creator.cc
+++ b/tensorflow/compiler/jit/xla_kernel_creator.cc
@@ -21,7 +21,7 @@ namespace tensorflow {
 
 bool XlaKernelCreator::CanCreateKernel(const FunctionLibraryRuntime& flr,
                                        const NodeDef& node_def) const {
-  return CanCreateXlaKernel(flr, node_def);
+  return CanCreateXlaKernel(node_def);
 }
 
 Status XlaKernelCreator::CreateKernel(FunctionLibraryRuntime* flr,
diff --git a/tensorflow/compiler/jit/xla_kernel_creator_test.cc b/tensorflow/compiler/jit/xla_kernel_creator_test.cc
index f1d9689268c..7ec37332906 100644
--- a/tensorflow/compiler/jit/xla_kernel_creator_test.cc
+++ b/tensorflow/compiler/jit/xla_kernel_creator_test.cc
@@ -98,12 +98,14 @@ TEST_F(XlaKernelCreatorTest, OneFloatOneResourceArgument) {
   (*fdef.mutable_attr())["_XlaMustCompile"] = BoolAttr(true);
   Init({fdef});
   XlaKernelCreator xla_kernel_creator;
-
-  Status status = xla_kernel_creator.CreateKernel(
-      flr_, ToNodeDef(R"pb(
+  NodeDef callsite =
+      ToNodeDef(R"pb(
         name: 'XTimesY' op: 'XTimesY' input: 'a' input: 'b'
-      )pb"),
-      &kernel_);
+      )pb");
+  (*callsite.mutable_attr())["_XlaMustCompile"] = BoolAttr(true);
+
+  // Note: the attribute needs to be set on the call-site node as well.
+  Status status = xla_kernel_creator.CreateKernel(flr_, callsite, &kernel_);
   ASSERT_TRUE(status.ok()) << status.ToString();
 
   EXPECT_EQ("XTimesY", kernel_->name());
diff --git a/tensorflow/compiler/jit/xla_kernel_creator_util.cc b/tensorflow/compiler/jit/xla_kernel_creator_util.cc
index 43d5f0b924e..94727fdf35a 100644
--- a/tensorflow/compiler/jit/xla_kernel_creator_util.cc
+++ b/tensorflow/compiler/jit/xla_kernel_creator_util.cc
@@ -23,7 +23,9 @@ limitations under the License.
 #include "tensorflow/compiler/jit/mark_for_compilation_pass.h"
 #include "tensorflow/compiler/tf2xla/const_analysis.h"
 #include "tensorflow/compiler/tf2xla/xla_op_registry.h"
+#include "tensorflow/core/common_runtime/function.h"
 #include "tensorflow/core/framework/node_def_builder.h"
+#include "tensorflow/core/framework/node_def_util.h"
 #include "tensorflow/core/lib/core/status.h"
 #include "tensorflow/core/util/ptr_util.h"
 
@@ -68,40 +70,10 @@ class SinglePassSearch {
 };
 }  // namespace
 
-bool CanCreateXlaKernel(const FunctionLibraryRuntime& flr,
-                        const NodeDef& node_def) {
-  const FunctionDef* function_def =
-      flr.GetFunctionLibraryDefinition()->Find(node_def.name());
-  if (function_def == nullptr) {
-    // The node def is not calling a function. Individual ops can be
-    // run directly using on-demand mode, no need to create XlaLaunch
-    // kernel for them.
-    return false;
-  }
-
+bool CanCreateXlaKernel(const NodeDef& node_def) {
   // If kXlaMustCompileAttr is set on the node_def, use its value.
   const auto& it = node_def.attr().find(kXlaMustCompileAttr);
-  if (it != node_def.attr().end()) {
-    return it->second.b();
-  }
-
-  // kXlaMustCompileAttr is not set on node_def, check if it is set on
-  // FunctionDef.
-  bool xla_compile = false;
-  Status status = flr.GetFunctionLibraryDefinition()->GetAttr(
-      node_def, kXlaMustCompileAttr, &xla_compile);
-  if (!status.ok() || !xla_compile) {
-    if (VLOG_IS_ON(3)) {
-      if (!status.ok()) {
-        VLOG(3) << "No " << kXlaMustCompileAttr << " attr defined for "
-                << node_def.op() << ". status=" << status.ToString();
status=" << status.ToString(); - } else { - VLOG(3) << node_def.op() << " is explicitly marked not to be compiled"; - } - } - return false; - } - return true; + return it != node_def.attr().end() && it->second.b(); } // Given a FunctionLibraryRuntime and a NodeDef calling a function in the @@ -118,8 +90,11 @@ Status GetBodyAndConstantsAndResources(FunctionLibraryRuntime* flr, FunctionLibraryRuntime::Handle handle; // If node_def is not instantiable, e.g., the function does not exist, // simply bail out. + NameAttrList function; + TF_RETURN_IF_ERROR(NameAndAttrsFromFunctionCall(node_def, &function)); + TF_RETURN_IF_ERROR( - flr->Instantiate(node_def.op(), AttrSlice(&node_def.attr()), &handle)); + flr->Instantiate(function.name(), AttrSlice(&function.attr()), &handle)); *fbody = flr->GetFunctionBody(handle); CHECK(*fbody); // Can't be nullptr since we just instantiated it. const DataTypeVector& arg_types = (*fbody)->arg_types; @@ -149,7 +124,7 @@ Status GetBodyAndConstantsAndResources(FunctionLibraryRuntime* flr, Status CreateXlaKernel(FunctionLibraryRuntime* flr, const NodeDef& node_def, std::unique_ptr* kernel) { - if (!CanCreateXlaKernel(*flr, node_def)) { + if (!CanCreateXlaKernel(node_def)) { return errors::Internal("Invalid node: ", node_def.ShortDebugString()); } @@ -241,9 +216,7 @@ Status CreateXlaKernel(FunctionLibraryRuntime* flr, const NodeDef& node_def, // Create the kernel. NameAttrList function; - function.set_name(node_def.op()); - *(function.mutable_attr()) = node_def.attr(); - + TF_RETURN_IF_ERROR(NameAndAttrsFromFunctionCall(node_def, &function)); Device* dev = flr->device(); Status s; OpKernelConstruction construction( diff --git a/tensorflow/compiler/jit/xla_kernel_creator_util.h b/tensorflow/compiler/jit/xla_kernel_creator_util.h index 71398c334fc..5ec8df01f77 100644 --- a/tensorflow/compiler/jit/xla_kernel_creator_util.h +++ b/tensorflow/compiler/jit/xla_kernel_creator_util.h @@ -24,11 +24,9 @@ namespace tensorflow { class FunctionLibraryRuntime; class OpKernel; - // Given a NodeDef 'node_def' and the function library runtime 'flr', returns - // true if 'node_def' is a call to a compilable function defined in 'flr', - // with the kXlaCompileAttr set. -bool CanCreateXlaKernel(const FunctionLibraryRuntime& flr, - const NodeDef& node_def); +// Given a NodeDef `node_def` returns true iff `node_def` has kXlaCompileAttr +// set. +bool CanCreateXlaKernel(const NodeDef& node_def); // Given a supported NodeDef, returns a XlaLaunchOp that computes the node. Status CreateXlaKernel(FunctionLibraryRuntime* flr, const NodeDef& node_def, diff --git a/tensorflow/core/BUILD b/tensorflow/core/BUILD index 26e9df7b970..a00a551edf4 100644 --- a/tensorflow/core/BUILD +++ b/tensorflow/core/BUILD @@ -2533,6 +2533,7 @@ tf_cuda_library( "//tensorflow/core/profiler/lib:traceme", "//tensorflow/core/util:port", "//tensorflow/core/util:stats_calculator_portable", + "//tensorflow/compiler/jit:common", ] + if_static( extra_deps = ["@com_google_protobuf//:protobuf"], otherwise = ["@com_google_protobuf//:protobuf_headers"], diff --git a/tensorflow/core/common_runtime/direct_session.cc b/tensorflow/core/common_runtime/direct_session.cc index 0d9f897c5d7..9731d74b069 100644 --- a/tensorflow/core/common_runtime/direct_session.cc +++ b/tensorflow/core/common_runtime/direct_session.cc @@ -1313,10 +1313,12 @@ Status DirectSession::CreateExecutors( options_.config.experimental().has_session_metadata() ? 
&options_.config.experimental().session_metadata() : nullptr; + const CustomKernelCreator* custom_kernel_creator = + GetDefaultCustomKernelCreator(); func_info->proc_flr.reset(new ProcessFunctionLibraryRuntime( device_mgr_.get(), options_.env, &options_.config, graph_def_version, func_info->flib_def.get(), optimizer_opts, thread_pools_[0].first, - nullptr, nullptr, session_metadata)); + nullptr, custom_kernel_creator, session_metadata)); GraphOptimizer optimizer(optimizer_opts); for (auto iter = graphs.begin(); iter != graphs.end(); ++iter) { diff --git a/tensorflow/core/framework/memory_types.cc b/tensorflow/core/framework/memory_types.cc index 246f50acd26..5393b162e80 100644 --- a/tensorflow/core/framework/memory_types.cc +++ b/tensorflow/core/framework/memory_types.cc @@ -17,6 +17,8 @@ limitations under the License. #include +#include "tensorflow/compiler/jit/defs.h" +#include "tensorflow/core/framework/attr_value.pb.h" #include "tensorflow/core/framework/kernel_def.pb.h" #include "tensorflow/core/framework/node_def.pb.h" #include "tensorflow/core/framework/node_def_util.h" @@ -97,6 +99,11 @@ Status MemoryTypesForNode(const OpRegistryInterface* op_registry, inp_mtypes->clear(); out_mtypes->clear(); + bool has_xla_compile = [&] { + const auto& it = ndef.attr().find(kXlaMustCompileAttr); + return it != ndef.attr().end() && it->second.b(); + }(); + // For functions (which have no KernelDef) and their gradients, we can only // best-effort derive the memory type from the data type. For now, we assume // int32 is always on host memory and other types are always on device memory. @@ -104,7 +111,7 @@ Status MemoryTypesForNode(const OpRegistryInterface* op_registry, // to derive the correct input/output memory types. We should also split // host-memory and non host-memory arguments into separate type lists. if (!status.ok() || IsFunctionCallOp(ndef.op())) { - if (device_type.type_string() == "TPU") { + if (device_type.type_string() == "TPU" || has_xla_compile) { // Here we assume that if tf.function() is called within // "with tf.device('/device:TPU:0')", the whole function will be compiled // and executed on TPU. This is true today, but when we implement auto diff --git a/tensorflow/python/compiler/xla/BUILD b/tensorflow/python/compiler/xla/BUILD index 2061f0cca2f..9f7afae4052 100644 --- a/tensorflow/python/compiler/xla/BUILD +++ b/tensorflow/python/compiler/xla/BUILD @@ -91,3 +91,20 @@ cuda_py_test( "@absl_py//absl/testing:parameterized", ], ) + +cuda_py_test( + name = "experimental_compile_test", + srcs = ["experimental_compile_test.py"], + additional_deps = [ + "//tensorflow/python:client_testlib", + "//tensorflow/python:constant_op", + "//tensorflow/python:framework_ops", + "//tensorflow/python:resource_variable_ops", + ], + python_version = "PY3", + tags = [ + "no_mac", + "no_windows", + ], + xla_enabled = True, +) diff --git a/tensorflow/python/compiler/xla/experimental_compile_test.py b/tensorflow/python/compiler/xla/experimental_compile_test.py new file mode 100644 index 00000000000..c0a1c4bf307 --- /dev/null +++ b/tensorflow/python/compiler/xla/experimental_compile_test.py @@ -0,0 +1,113 @@ +# Copyright 2019 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.python.client import session
+from tensorflow.python.eager import backprop
+from tensorflow.python.eager import def_function
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import errors
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import array_ops
+from tensorflow.python.platform import test
+
+
+class ExperimentalCompileTest(test.TestCase):
+
+  def testBasic(self):
+    with ops.Graph().as_default() as g:
+
+      def fn(x, a):
+        return x + a
+
+      xla_func = def_function.function(fn, experimental_compile=True)
+      inputs = array_ops.placeholder(dtypes.float32, [5])
+      # XLA support is not yet enabled for TF ROCm
+      if not test.is_built_with_rocm():
+        x = xla_func(inputs, 1)
+        with session.Session(graph=g) as sess:
+          y = sess.run(x, feed_dict={inputs: [1, 2, 2, 3, 3]})
+          self.assertTrue(x.graph.as_graph_def().library.function[0]
+                          .attr["_XlaMustCompile"].b)
+          self.assertAllClose([2, 3, 3, 4, 4], y)
+
+  def testDerivative(self):
+    # XLA support is not yet enabled for TF ROCm
+    if test.is_built_with_rocm():
+      return
+
+    def fn(x, a):
+      return 2 * x + a
+
+    with ops.Graph().as_default() as g:
+      xla_func = def_function.function(fn, experimental_compile=True)
+      with backprop.GradientTape() as tape:
+        inputs = array_ops.placeholder(dtypes.float32, [5])
+        tape.watch(inputs)
+        outputs = xla_func(inputs, 1)
+      grads = tape.gradient(outputs, inputs)
+
+      with session.Session(graph=g) as sess:
+        grads_tensor = sess.run(grads, feed_dict={inputs: [1, 2, 2, 3, 3]})
+        self.assertAllClose([2, 2, 2, 2, 2], grads_tensor)
+        (forward, backward) = xla_func.get_concrete_function(
+            inputs, 1)._delayed_rewrite_functions.forward_backward()
+
+        # Check that the must-compile attribute gets correctly propagated to the
+        # created derivatives.
+        self.assertTrue(forward.definition.attr["_XlaMustCompile"])
+        self.assertTrue(backward.function_def.attr["_XlaMustCompile"])
+
+  def testBasicInt32(self):
+    with ops.Graph().as_default() as g:
+
+      def fn(x, a):
+        return x + a
+
+      xla_func = def_function.function(fn, experimental_compile=True)
+      inputs = array_ops.placeholder(dtypes.int32, [5])
+      # XLA support is not yet enabled for TF ROCm
+      if not test.is_built_with_rocm():
+        x = xla_func(inputs, 1)
+        with session.Session(graph=g) as sess:
+          y = sess.run(x, feed_dict={inputs: [1, 2, 2, 3, 3]})
+          self.assertTrue(x.graph.as_graph_def().library.function[0]
+                          .attr["_XlaMustCompile"].b)
+          self.assertAllClose([2, 3, 3, 4, 4], y)
+
+  # Checking that we crash on an unsupported operation lets us test that the XLA
+  # compiler was actually invoked.
+  def testUnsupportedOps(self):
+    with ops.Graph().as_default() as g:
+
+      def fn(x):
+        return array_ops.unique(x).y  # Unique is not supported by XLA
+
+      xla_func = def_function.function(fn, experimental_compile=True)
+      inputs = array_ops.placeholder(dtypes.float32, [5])
+      x = xla_func(inputs)
+      # XLA support is not yet enabled for TF ROCm
+      if not test.is_built_with_rocm():
+        with self.assertRaisesRegexp(errors.InvalidArgumentError,
+                                     "not compilable"):
+          with session.Session(graph=g) as sess:
+            sess.run(x, feed_dict={inputs: [1, 2, 2, 3, 3]})
+
+
+if __name__ == "__main__":
+  test.main()
diff --git a/tensorflow/python/eager/def_function_xla_jit_test.py b/tensorflow/python/eager/def_function_xla_jit_test.py
index 20acd42ac68..86d6f31848c 100644
--- a/tensorflow/python/eager/def_function_xla_jit_test.py
+++ b/tensorflow/python/eager/def_function_xla_jit_test.py
@@ -20,6 +20,7 @@ from __future__ import print_function
 from tensorflow.python.eager import backprop
 from tensorflow.python.eager import def_function
 from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import errors
 from tensorflow.python.framework import ops
 from tensorflow.python.ops import array_ops
@@ -45,6 +46,18 @@ class DefFunctionTest(test.TestCase):
       # XLA support is not yet enabled for TF ROCm
       self.assertAllClose([2, 3, 3, 4, 4], xla_func(inputs, 1))
 
+  def testBasicInt32(self):
+
+    def fn(x, a):
+      return x + a
+
+    xla_func = def_function.function(fn, experimental_compile=True)
+
+    inputs = constant_op.constant([1, 2, 2, 3, 3], dtype=dtypes.int32)
+    if not test.is_built_with_rocm():
+      # XLA support is not yet enabled for TF ROCm
+      self.assertAllClose([2, 3, 3, 4, 4], xla_func(inputs, 1))
+
   def testDerivative(self):
     if test.is_built_with_rocm():
       return
diff --git a/tensorflow/python/ops/functional_ops.py b/tensorflow/python/ops/functional_ops.py
index c66970e8876..4698f870785 100644
--- a/tensorflow/python/ops/functional_ops.py
+++ b/tensorflow/python/ops/functional_ops.py
@@ -1163,17 +1163,19 @@ def partitioned_call(args,
   graph = ops.get_default_graph()
   f.add_to_graph(graph)
   op_name = "StatefulPartitionedCall" if f.stateful_ops else "PartitionedCall"
-  op = graph.create_op(
-      op_name,
-      args,
-      tout,
-      name=op_name,
-      attrs={
-          "Tin": tin_attr,
-          "Tout": tout_attr,
-          "f": func_attr,
-          "config_proto": config_proto,
-          "executor_type": executor_type_attr,
-      })
+
+  # Propagate the attribute indicating the need to compile from the function
+  # to the call itself.
+  xla_compile_attr = "_XlaMustCompile"
+  op_attrs = {
+      "Tin": tin_attr,
+      "Tout": tout_attr,
+      "f": func_attr,
+      "config_proto": config_proto,
+      "executor_type": executor_type_attr,
+  }
+  if xla_compile_attr in f.definition.attr:
+    op_attrs[xla_compile_attr] = f.definition.attr[xla_compile_attr]
+  op = graph.create_op(op_name, args, tout, name=op_name, attrs=op_attrs)
   outputs = op.outputs
   return outputs if outputs else op
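Reviewer note: taken together, these changes wire `experimental_compile=True` up in graph mode. `def_function.function` stamps `_XlaMustCompile` on the generated FunctionDef, `partitioned_call()` now copies that attribute onto the emitted `PartitionedCall`/`StatefulPartitionedCall` node, `CanCreateXlaKernel()` can then answer from the NodeDef alone, and `DirectSession` installs the default `CustomKernelCreator` so the call lowers to an XlaLaunch kernel. The sketch below walks that flow end to end; it is not part of the patch, it uses only APIs that appear in the tests above, and the function name `fn` and the exact op type observed are illustrative assumptions.

```python
# Sketch (not part of the patch): how _XlaMustCompile flows from a
# tf.function to its graph-mode call site under this change.
from tensorflow.python.client import session
from tensorflow.python.eager import def_function
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops


def fn(x):
  return 2 * x + 1


with ops.Graph().as_default() as g:
  # experimental_compile=True stamps _XlaMustCompile on the FunctionDef.
  xla_func = def_function.function(fn, experimental_compile=True)
  inputs = array_ops.placeholder(dtypes.float32, [3])
  outputs = xla_func(inputs)

  # partitioned_call() now mirrors the attribute onto the call op itself,
  # which is exactly what CanCreateXlaKernel() inspects.
  call_op = next(op for op in g.get_operations()
                 if op.type.endswith("PartitionedCall"))
  assert call_op.get_attr("_XlaMustCompile")

  # With GetDefaultCustomKernelCreator() plumbed into DirectSession, the
  # call is compiled and executed through an XlaLaunch kernel.
  with session.Session(graph=g) as sess:
    print(sess.run(outputs, feed_dict={inputs: [1., 2., 3.]}))  # [3. 5. 7.]
```

Checking the attribute on the call-site node (rather than only on the FunctionDef in the library) is what lets `CanCreateXlaKernel()` drop its `FunctionLibraryRuntime` parameter, and it is why `MemoryTypesForNode()` can key host/device placement off the NodeDef attribute as well.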