diff --git a/tensorflow/compiler/jit/BUILD b/tensorflow/compiler/jit/BUILD index f380032ca33..111bed82c54 100644 --- a/tensorflow/compiler/jit/BUILD +++ b/tensorflow/compiler/jit/BUILD @@ -57,6 +57,7 @@ cc_library( visibility = ["//visibility:public"], deps = [ ":jit_compilation_passes", + ":xla_kernel_creator", # buildcleaner: keep "//tensorflow/compiler/jit/kernels:xla_ops", "//tensorflow/compiler/tf2xla/kernels:xla_dummy_ops", "//tensorflow/compiler/tf2xla/kernels:xla_ops", @@ -70,6 +71,7 @@ cc_library( visibility = ["//visibility:public"], deps = if_cuda_or_rocm([ ":jit_compilation_passes", + ":xla_kernel_creator", # buildcleaner: keep "//tensorflow/compiler/jit/kernels:xla_ops", "//tensorflow/compiler/tf2xla/kernels:xla_ops", "//tensorflow/compiler/tf2xla/kernels:xla_dummy_ops", @@ -250,8 +252,6 @@ cc_library( }), ) -# Internal targets below this point. - cc_library( name = "flags", srcs = ["flags.cc"], @@ -265,6 +265,20 @@ cc_library( ], ) +# Header-only version of "flags" library, for linking from the shared object +# without ODR violations. +cc_library( + name = "flags_headers_only", + hdrs = ["flags.h"], + visibility = [":friends"], + deps = [ + "//tensorflow/compiler/xla:parse_flags_from_env", + "//tensorflow/core:framework_internal", + "//tensorflow/core:lib", + "@com_google_absl//absl/strings", + ], +) + cc_library( name = "common", srcs = [ @@ -276,6 +290,8 @@ cc_library( visibility = [":friends"], ) +# Internal targets below this point. + cc_library( name = "xla_launch_util", srcs = ["xla_launch_util.cc"], @@ -397,6 +413,7 @@ cc_library( "xla_kernel_creator.h", ], deps = [ + ":flags", ":jit_compilation_passes", ":xla_kernel_creator_util", "//tensorflow/core:core_cpu_internal", diff --git a/tensorflow/compiler/jit/flags.cc b/tensorflow/compiler/jit/flags.cc index 991ad82daa1..a3698571715 100644 --- a/tensorflow/compiler/jit/flags.cc +++ b/tensorflow/compiler/jit/flags.cc @@ -13,13 +13,15 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include "tensorflow/compiler/jit/flags.h" + #include <mutex> // NOLINT #include "absl/strings/numbers.h" #include "absl/strings/str_split.h" #include "absl/strings/strip.h" -#include "tensorflow/compiler/jit/flags.h" #include "tensorflow/compiler/xla/parse_flags_from_env.h" +#include "tensorflow/core/platform/macros.h" #include "tensorflow/core/util/command_line_flags.h" namespace tensorflow { @@ -247,4 +249,11 @@ void AppendMarkForCompilationPassFlags(std::vector<Flag>* flag_list) { std::call_once(flags_init, &AllocateAndParseFlags); AppendMarkForCompilationPassFlagsInternal(flag_list); } + +static bool xla_is_enabled = false; + +void SetXlaIsEnabled() { xla_is_enabled = true; } + +bool IsXlaEnabled() { return xla_is_enabled; } + } // namespace tensorflow diff --git a/tensorflow/compiler/jit/flags.h b/tensorflow/compiler/jit/flags.h index 618e839fa36..b77a009b49f 100644 --- a/tensorflow/compiler/jit/flags.h +++ b/tensorflow/compiler/jit/flags.h @@ -154,6 +154,15 @@ GetIntroduceFloatingPointJitterPassFlags(); // Has the side-effect of parsing TF_XLA_FLAGS if that hasn't happened yet. void AppendMarkForCompilationPassFlags( std::vector<tensorflow::Flag>* flag_list); + +// Makes all future calls to `IsXlaEnabled()` return `true`. +// +// Should only be called when XLA is linked in. +void SetXlaIsEnabled(); + +// Returns whether XLA is enabled. +bool IsXlaEnabled(); + } // namespace tensorflow #endif // TENSORFLOW_COMPILER_JIT_FLAGS_H_ diff --git a/tensorflow/compiler/jit/xla_kernel_creator.cc b/tensorflow/compiler/jit/xla_kernel_creator.cc index 23bd7425dbd..6ee1db2c7c5 100644 --- a/tensorflow/compiler/jit/xla_kernel_creator.cc +++ b/tensorflow/compiler/jit/xla_kernel_creator.cc @@ -14,6 +14,7 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/compiler/jit/xla_kernel_creator.h" +#include "tensorflow/compiler/jit/flags.h" #include "tensorflow/compiler/jit/xla_kernel_creator_util.h" #include "tensorflow/core/common_runtime/function.h" @@ -39,6 +40,10 @@ bool RegisterLaunchOpCreator() { } static bool register_me = RegisterLaunchOpCreator(); +static bool register_xla = [] { + SetXlaIsEnabled(); + return true; +}(); } // end namespace } // namespace tensorflow diff --git a/tensorflow/python/BUILD b/tensorflow/python/BUILD index be7f8da7411..17c5ce2ca8d 100644 --- a/tensorflow/python/BUILD +++ b/tensorflow/python/BUILD @@ -7728,11 +7728,13 @@ tf_python_pybind_extension( "//tensorflow/core/profiler/protobuf:xplane_proto_cc", ] + if_static( extra_deps = [ + "//tensorflow/compiler/jit:flags", "//tensorflow/core:eager_service_proto_cc", "//tensorflow/core:master_proto_cc", "//tensorflow/core:worker_proto_cc", ], otherwise = [ + "//tensorflow/compiler/jit:flags_headers_only", "//tensorflow/core:eager_service_proto_cc_headers_only", "//tensorflow/core:master_proto_cc_headers_only", "//tensorflow/core:worker_proto_cc_headers_only", diff --git a/tensorflow/python/eager/BUILD b/tensorflow/python/eager/BUILD index c4fd95b0fd9..b5d85879c07 100644 --- a/tensorflow/python/eager/BUILD +++ b/tensorflow/python/eager/BUILD @@ -741,6 +741,27 @@ cuda_py_test( ], ) +tf_py_test( + name = "def_function_test_cpu_only", + srcs = ["def_function_test_cpu_only.py"], + python_version = "PY3", + # --config=cuda implicitly links in XLA. + tags = [ + "no_cuda_on_cpu_tap", + "no_oss", # No way to force no XLA linkage in OSS build from here. + "no_pip", + "nogpu", + ], + deps = [ + ":def_function", + "//tensorflow/python:client_testlib", + "//tensorflow/python:constant_op", + "//tensorflow/python:framework_ops", + "//tensorflow/python/autograph/core", + "@absl_py//absl/testing:parameterized", + ], +) + cuda_py_test( name = "def_function_xla_jit_test", srcs = ["def_function_xla_jit_test.py"], diff --git a/tensorflow/python/eager/def_function.py b/tensorflow/python/eager/def_function.py index c2d5e9f0fd2..58bb8faabfe 100644 --- a/tensorflow/python/eager/def_function.py +++ b/tensorflow/python/eager/def_function.py @@ -23,6 +23,7 @@ import functools import threading import weakref +from tensorflow.python import pywrap_tfe from tensorflow.python.eager import context from tensorflow.python.eager import function as function_lib from tensorflow.python.eager import lift_to_graph @@ -452,6 +453,10 @@ class Function(object): attributes.update(_XlaMustCompile=bool(self._experimental_compile)) if self._experimental_compile: attributes.update(_noinline=True) + if not pywrap_tfe.TF_IsXlaEnabled(): + raise ValueError("Attempting to use experimental_compile, " + "but XLA support is not linked in. " + "Rebuild with --define=with_xla_support=true.") if not attributes: attributes = None return function_lib.defun_with_attributes( @@ -1196,6 +1201,10 @@ def function(func=None, function (and return zero or more `tf.Tensor` objects). If `func` is None, returns a decorator that, when invoked with a single `func` argument, returns a callable equivalent to the case above. + + Raises: + ValueError when attempting to use experimental_compile, but XLA support is + not enabled. """ if input_signature is not None: function_lib.validate_signature(input_signature) diff --git a/tensorflow/python/eager/def_function_test_cpu_only.py b/tensorflow/python/eager/def_function_test_cpu_only.py new file mode 100644 index 00000000000..bd3774269ea --- /dev/null +++ b/tensorflow/python/eager/def_function_test_cpu_only.py @@ -0,0 +1,51 @@ +# Copyright 2020 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from absl.testing import parameterized + +from tensorflow.python.eager import def_function +from tensorflow.python.framework import ops +from tensorflow.python.framework import test_util +from tensorflow.python.ops import array_ops +from tensorflow.python.platform import test + + +class DefFunctionCpuOnlyTest(test.TestCase, parameterized.TestCase): + """Test that experimental_compile=True correctly throws an exception if XLA is not available. + + This test should only be run without `--config=cuda`, as that implicitly links + in XLA JIT. + """ + + def testExperimentalCompileRaisesExceptionWhenXlaIsUnsupported(self): + if test.is_built_with_rocm() or test_util.is_xla_enabled(): + return + + with self.assertRaisesRegexp(ValueError, 'XLA support is not'): + + @def_function.function(experimental_compile=True) + def fn(x): + return array_ops.unique(x).y + + fn([1, 1, 2, 3]) + + +if __name__ == '__main__': + ops.enable_eager_execution() + test.main() diff --git a/tensorflow/python/tfe_wrapper.cc b/tensorflow/python/tfe_wrapper.cc index 9de5a19c115..1c135bb9dbf 100644 --- a/tensorflow/python/tfe_wrapper.cc +++ b/tensorflow/python/tfe_wrapper.cc @@ -28,6 +28,7 @@ limitations under the License. #include "tensorflow/c/eager/c_api_internal.h" #include "tensorflow/c/tf_status.h" #include "tensorflow/c/tf_status_helper.h" +#include "tensorflow/compiler/jit/flags.h" #include "tensorflow/python/eager/pywrap_tensor_conversion.h" #include "tensorflow/python/eager/pywrap_tfe.h" #include "tensorflow/python/lib/core/py_exception_registry.h" @@ -338,6 +339,7 @@ PYBIND11_MODULE(_pywrap_tfe, m) { m.def("TF_SetXlaConstantFoldingDisabled", &TF_SetXlaConstantFoldingDisabled); m.def("TF_GetXlaConstantFoldingDisabled", &TF_GetXlaConstantFoldingDisabled); m.def("TF_SetXlaMinClusterSize", &TF_SetXlaMinClusterSize); + m.def("TF_IsXlaEnabled", [] { return tensorflow::IsXlaEnabled(); }); // // TFE_Context Logic m.def(