Explicitly fail when experimental_compile=True is used, but XLA is not enabled

PiperOrigin-RevId: 292025089
Change-Id: I6412e21f5701553691205037944f9aab1df64ce2
This commit is contained in:
George Karpenkov 2020-01-28 15:24:09 -08:00 committed by TensorFlower Gardener
parent 53889e9671
commit 96e0b87d1e
9 changed files with 128 additions and 3 deletions

View File

@ -57,6 +57,7 @@ cc_library(
visibility = ["//visibility:public"],
deps = [
":jit_compilation_passes",
":xla_kernel_creator", # buildcleaner: keep
"//tensorflow/compiler/jit/kernels:xla_ops",
"//tensorflow/compiler/tf2xla/kernels:xla_dummy_ops",
"//tensorflow/compiler/tf2xla/kernels:xla_ops",
@ -70,6 +71,7 @@ cc_library(
visibility = ["//visibility:public"],
deps = if_cuda_or_rocm([
":jit_compilation_passes",
":xla_kernel_creator", # buildcleaner: keep
"//tensorflow/compiler/jit/kernels:xla_ops",
"//tensorflow/compiler/tf2xla/kernels:xla_ops",
"//tensorflow/compiler/tf2xla/kernels:xla_dummy_ops",
@ -250,8 +252,6 @@ cc_library(
}),
)
# Internal targets below this point.
cc_library(
name = "flags",
srcs = ["flags.cc"],
@ -265,6 +265,20 @@ cc_library(
],
)
# Header-only version of "flags" library, for linking from the shared object
# without ODR violations.
cc_library(
name = "flags_headers_only",
hdrs = ["flags.h"],
visibility = [":friends"],
deps = [
"//tensorflow/compiler/xla:parse_flags_from_env",
"//tensorflow/core:framework_internal",
"//tensorflow/core:lib",
"@com_google_absl//absl/strings",
],
)
cc_library(
name = "common",
srcs = [
@ -276,6 +290,8 @@ cc_library(
visibility = [":friends"],
)
# Internal targets below this point.
cc_library(
name = "xla_launch_util",
srcs = ["xla_launch_util.cc"],
@ -397,6 +413,7 @@ cc_library(
"xla_kernel_creator.h",
],
deps = [
":flags",
":jit_compilation_passes",
":xla_kernel_creator_util",
"//tensorflow/core:core_cpu_internal",

View File

@ -13,13 +13,15 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/jit/flags.h"
#include <mutex> // NOLINT
#include "absl/strings/numbers.h"
#include "absl/strings/str_split.h"
#include "absl/strings/strip.h"
#include "tensorflow/compiler/jit/flags.h"
#include "tensorflow/compiler/xla/parse_flags_from_env.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/util/command_line_flags.h"
namespace tensorflow {
@ -247,4 +249,11 @@ void AppendMarkForCompilationPassFlags(std::vector<Flag>* flag_list) {
std::call_once(flags_init, &AllocateAndParseFlags);
AppendMarkForCompilationPassFlagsInternal(flag_list);
}
static bool xla_is_enabled = false;
void SetXlaIsEnabled() { xla_is_enabled = true; }
bool IsXlaEnabled() { return xla_is_enabled; }
} // namespace tensorflow

View File

@ -154,6 +154,15 @@ GetIntroduceFloatingPointJitterPassFlags();
// Has the side-effect of parsing TF_XLA_FLAGS if that hasn't happened yet.
void AppendMarkForCompilationPassFlags(
std::vector<tensorflow::Flag>* flag_list);
// Makes all future calls to `IsXlaEnabled()` return `true`.
//
// Should only be called when XLA is linked in.
void SetXlaIsEnabled();
// Returns whether XLA is enabled.
bool IsXlaEnabled();
} // namespace tensorflow
#endif // TENSORFLOW_COMPILER_JIT_FLAGS_H_

View File

@ -14,6 +14,7 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/jit/xla_kernel_creator.h"
#include "tensorflow/compiler/jit/flags.h"
#include "tensorflow/compiler/jit/xla_kernel_creator_util.h"
#include "tensorflow/core/common_runtime/function.h"
@ -39,6 +40,10 @@ bool RegisterLaunchOpCreator() {
}
static bool register_me = RegisterLaunchOpCreator();
static bool register_xla = [] {
SetXlaIsEnabled();
return true;
}();
} // end namespace
} // namespace tensorflow

View File

@ -7728,11 +7728,13 @@ tf_python_pybind_extension(
"//tensorflow/core/profiler/protobuf:xplane_proto_cc",
] + if_static(
extra_deps = [
"//tensorflow/compiler/jit:flags",
"//tensorflow/core:eager_service_proto_cc",
"//tensorflow/core:master_proto_cc",
"//tensorflow/core:worker_proto_cc",
],
otherwise = [
"//tensorflow/compiler/jit:flags_headers_only",
"//tensorflow/core:eager_service_proto_cc_headers_only",
"//tensorflow/core:master_proto_cc_headers_only",
"//tensorflow/core:worker_proto_cc_headers_only",

View File

@ -741,6 +741,27 @@ cuda_py_test(
],
)
tf_py_test(
name = "def_function_test_cpu_only",
srcs = ["def_function_test_cpu_only.py"],
python_version = "PY3",
# --config=cuda implicitly links in XLA.
tags = [
"no_cuda_on_cpu_tap",
"no_oss", # No way to force no XLA linkage in OSS build from here.
"no_pip",
"nogpu",
],
deps = [
":def_function",
"//tensorflow/python:client_testlib",
"//tensorflow/python:constant_op",
"//tensorflow/python:framework_ops",
"//tensorflow/python/autograph/core",
"@absl_py//absl/testing:parameterized",
],
)
cuda_py_test(
name = "def_function_xla_jit_test",
srcs = ["def_function_xla_jit_test.py"],

View File

@ -23,6 +23,7 @@ import functools
import threading
import weakref
from tensorflow.python import pywrap_tfe
from tensorflow.python.eager import context
from tensorflow.python.eager import function as function_lib
from tensorflow.python.eager import lift_to_graph
@ -452,6 +453,10 @@ class Function(object):
attributes.update(_XlaMustCompile=bool(self._experimental_compile))
if self._experimental_compile:
attributes.update(_noinline=True)
if not pywrap_tfe.TF_IsXlaEnabled():
raise ValueError("Attempting to use experimental_compile, "
"but XLA support is not linked in. "
"Rebuild with --define=with_xla_support=true.")
if not attributes:
attributes = None
return function_lib.defun_with_attributes(
@ -1196,6 +1201,10 @@ def function(func=None,
function (and return zero or more `tf.Tensor` objects).
If `func` is None, returns a decorator that, when invoked with a single
`func` argument, returns a callable equivalent to the case above.
Raises:
ValueError when attempting to use experimental_compile, but XLA support is
not enabled.
"""
if input_signature is not None:
function_lib.validate_signature(input_signature)

View File

@ -0,0 +1,51 @@
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from absl.testing import parameterized
from tensorflow.python.eager import def_function
from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.platform import test
class DefFunctionCpuOnlyTest(test.TestCase, parameterized.TestCase):
"""Test that experimental_compile=True correctly throws an exception if XLA is not available.
This test should only be run without `--config=cuda`, as that implicitly links
in XLA JIT.
"""
def testExperimentalCompileRaisesExceptionWhenXlaIsUnsupported(self):
if test.is_built_with_rocm() or test_util.is_xla_enabled():
return
with self.assertRaisesRegexp(ValueError, 'XLA support is not'):
@def_function.function(experimental_compile=True)
def fn(x):
return array_ops.unique(x).y
fn([1, 1, 2, 3])
if __name__ == '__main__':
ops.enable_eager_execution()
test.main()

View File

@ -28,6 +28,7 @@ limitations under the License.
#include "tensorflow/c/eager/c_api_internal.h"
#include "tensorflow/c/tf_status.h"
#include "tensorflow/c/tf_status_helper.h"
#include "tensorflow/compiler/jit/flags.h"
#include "tensorflow/python/eager/pywrap_tensor_conversion.h"
#include "tensorflow/python/eager/pywrap_tfe.h"
#include "tensorflow/python/lib/core/py_exception_registry.h"
@ -338,6 +339,7 @@ PYBIND11_MODULE(_pywrap_tfe, m) {
m.def("TF_SetXlaConstantFoldingDisabled", &TF_SetXlaConstantFoldingDisabled);
m.def("TF_GetXlaConstantFoldingDisabled", &TF_GetXlaConstantFoldingDisabled);
m.def("TF_SetXlaMinClusterSize", &TF_SetXlaMinClusterSize);
m.def("TF_IsXlaEnabled", [] { return tensorflow::IsXlaEnabled(); });
// // TFE_Context Logic
m.def(