Explicitly fail when experimental_compile=True is used, but XLA is not enabled
PiperOrigin-RevId: 292025089
Change-Id: I6412e21f5701553691205037944f9aab1df64ce2

commit 96e0b87d1e (parent 53889e9671)
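
The user-visible effect, as a hedged sketch: on a TensorFlow binary built without XLA linked in, a function decorated with experimental_compile=True now raises an explicit ValueError when it is used, pointing at the missing build flag. The tf.constant input and the function body below are illustrative only, not part of this change:

import tensorflow as tf

@tf.function(experimental_compile=True)
def double(x):
  return 2 * x

# On a build without XLA (e.g. compiled without --define=with_xla_support=true),
# using the function raises:
#   ValueError: Attempting to use experimental_compile, but XLA support is not
#   linked in. Rebuild with --define=with_xla_support=true.
double(tf.constant(3))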
@@ -57,6 +57,7 @@ cc_library(
    visibility = ["//visibility:public"],
    deps = [
        ":jit_compilation_passes",
        ":xla_kernel_creator",  # buildcleaner: keep
        "//tensorflow/compiler/jit/kernels:xla_ops",
        "//tensorflow/compiler/tf2xla/kernels:xla_dummy_ops",
        "//tensorflow/compiler/tf2xla/kernels:xla_ops",
@@ -70,6 +71,7 @@ cc_library(
    visibility = ["//visibility:public"],
    deps = if_cuda_or_rocm([
        ":jit_compilation_passes",
        ":xla_kernel_creator",  # buildcleaner: keep
        "//tensorflow/compiler/jit/kernels:xla_ops",
        "//tensorflow/compiler/tf2xla/kernels:xla_ops",
        "//tensorflow/compiler/tf2xla/kernels:xla_dummy_ops",
@@ -250,8 +252,6 @@ cc_library(
    }),
)

# Internal targets below this point.

cc_library(
    name = "flags",
    srcs = ["flags.cc"],
@@ -265,6 +265,20 @@ cc_library(
    ],
)

# Header-only version of "flags" library, for linking from the shared object
# without ODR violations.
cc_library(
    name = "flags_headers_only",
    hdrs = ["flags.h"],
    visibility = [":friends"],
    deps = [
        "//tensorflow/compiler/xla:parse_flags_from_env",
        "//tensorflow/core:framework_internal",
        "//tensorflow/core:lib",
        "@com_google_absl//absl/strings",
    ],
)

cc_library(
    name = "common",
    srcs = [
@@ -276,6 +290,8 @@ cc_library(
    visibility = [":friends"],
)

# Internal targets below this point.

cc_library(
    name = "xla_launch_util",
    srcs = ["xla_launch_util.cc"],
@@ -397,6 +413,7 @@ cc_library(
        "xla_kernel_creator.h",
    ],
    deps = [
        ":flags",
        ":jit_compilation_passes",
        ":xla_kernel_creator_util",
        "//tensorflow/core:core_cpu_internal",
@@ -13,13 +13,15 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/jit/flags.h"

#include <mutex>  // NOLINT

#include "absl/strings/numbers.h"
#include "absl/strings/str_split.h"
#include "absl/strings/strip.h"
#include "tensorflow/compiler/jit/flags.h"
#include "tensorflow/compiler/xla/parse_flags_from_env.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/util/command_line_flags.h"

namespace tensorflow {
@@ -247,4 +249,11 @@ void AppendMarkForCompilationPassFlags(std::vector<Flag>* flag_list) {
  std::call_once(flags_init, &AllocateAndParseFlags);
  AppendMarkForCompilationPassFlagsInternal(flag_list);
}

static bool xla_is_enabled = false;

void SetXlaIsEnabled() { xla_is_enabled = true; }

bool IsXlaEnabled() { return xla_is_enabled; }

}  // namespace tensorflow
@@ -154,6 +154,15 @@ GetIntroduceFloatingPointJitterPassFlags();
// Has the side-effect of parsing TF_XLA_FLAGS if that hasn't happened yet.
void AppendMarkForCompilationPassFlags(
    std::vector<tensorflow::Flag>* flag_list);

// Makes all future calls to `IsXlaEnabled()` return `true`.
//
// Should only be called when XLA is linked in.
void SetXlaIsEnabled();

// Returns whether XLA is enabled.
bool IsXlaEnabled();

}  // namespace tensorflow

#endif  // TENSORFLOW_COMPILER_JIT_FLAGS_H_
@@ -14,6 +14,7 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/jit/xla_kernel_creator.h"

#include "tensorflow/compiler/jit/flags.h"
#include "tensorflow/compiler/jit/xla_kernel_creator_util.h"
#include "tensorflow/core/common_runtime/function.h"

@@ -39,6 +40,10 @@ bool RegisterLaunchOpCreator() {
}

static bool register_me = RegisterLaunchOpCreator();
static bool register_xla = [] {
  SetXlaIsEnabled();
  return true;
}();

}  // end namespace
}  // namespace tensorflow
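
The two C++ hunks above form a link-time handshake: flags.cc owns a process-wide boolean that defaults to false, and the call added in this change is a static initializer in xla_kernel_creator.cc, so the bit ends up true exactly when that target is linked into the binary. A rough Python analogue of the pattern (purely illustrative, not TensorFlow code):

# Analogue of flags.cc: a capability flag that starts out false.
_xla_is_enabled = False

def set_xla_is_enabled():  # analogue of SetXlaIsEnabled()
  global _xla_is_enabled
  _xla_is_enabled = True

def is_xla_enabled():  # analogue of IsXlaEnabled()
  return _xla_is_enabled

# Analogue of the static initializer in xla_kernel_creator.cc: this runs only
# if the optional component is present at all, flipping the flag as a side
# effect of being loaded.
set_xla_is_enabled()

assert is_xla_enabled()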
@@ -7728,11 +7728,13 @@ tf_python_pybind_extension(
        "//tensorflow/core/profiler/protobuf:xplane_proto_cc",
    ] + if_static(
        extra_deps = [
            "//tensorflow/compiler/jit:flags",
            "//tensorflow/core:eager_service_proto_cc",
            "//tensorflow/core:master_proto_cc",
            "//tensorflow/core:worker_proto_cc",
        ],
        otherwise = [
            "//tensorflow/compiler/jit:flags_headers_only",
            "//tensorflow/core:eager_service_proto_cc_headers_only",
            "//tensorflow/core:master_proto_cc_headers_only",
            "//tensorflow/core:worker_proto_cc_headers_only",
@@ -741,6 +741,27 @@ cuda_py_test(
    ],
)

tf_py_test(
    name = "def_function_test_cpu_only",
    srcs = ["def_function_test_cpu_only.py"],
    python_version = "PY3",
    # --config=cuda implicitly links in XLA.
    tags = [
        "no_cuda_on_cpu_tap",
        "no_oss",  # No way to force no XLA linkage in OSS build from here.
        "no_pip",
        "nogpu",
    ],
    deps = [
        ":def_function",
        "//tensorflow/python:client_testlib",
        "//tensorflow/python:constant_op",
        "//tensorflow/python:framework_ops",
        "//tensorflow/python/autograph/core",
        "@absl_py//absl/testing:parameterized",
    ],
)

cuda_py_test(
    name = "def_function_xla_jit_test",
    srcs = ["def_function_xla_jit_test.py"],
@@ -23,6 +23,7 @@ import functools
import threading
import weakref

from tensorflow.python import pywrap_tfe
from tensorflow.python.eager import context
from tensorflow.python.eager import function as function_lib
from tensorflow.python.eager import lift_to_graph
@@ -452,6 +453,10 @@ class Function(object):
      attributes.update(_XlaMustCompile=bool(self._experimental_compile))
      if self._experimental_compile:
        attributes.update(_noinline=True)
        if not pywrap_tfe.TF_IsXlaEnabled():
          raise ValueError("Attempting to use experimental_compile, "
                           "but XLA support is not linked in. "
                           "Rebuild with --define=with_xla_support=true.")
    if not attributes:
      attributes = None
    return function_lib.defun_with_attributes(
@@ -1196,6 +1201,10 @@ def function(func=None,
    function (and return zero or more `tf.Tensor` objects).
    If `func` is None, returns a decorator that, when invoked with a single
    `func` argument, returns a callable equivalent to the case above.

  Raises:
    ValueError when attempting to use experimental_compile, but XLA support is
    not enabled.
  """
  if input_signature is not None:
    function_lib.validate_signature(input_signature)
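
The check above only fires when experimental_compile is truthy, so plain tf.function remains usable on builds without XLA. A minimal sketch (illustrative only, not part of this change):

import tensorflow as tf

@tf.function  # experimental_compile is left unset, so the XLA link check is skipped
def add_one(x):
  return x + 1

add_one(tf.constant(1))  # fine on a binary without XLA linked in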
tensorflow/python/eager/def_function_test_cpu_only.py (new file, 51 lines)
@@ -0,0 +1,51 @@
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from absl.testing import parameterized

from tensorflow.python.eager import def_function
from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.platform import test


class DefFunctionCpuOnlyTest(test.TestCase, parameterized.TestCase):
  """Test that experimental_compile=True correctly throws an exception if XLA is not available.

  This test should only be run without `--config=cuda`, as that implicitly links
  in XLA JIT.
  """

  def testExperimentalCompileRaisesExceptionWhenXlaIsUnsupported(self):
    if test.is_built_with_rocm() or test_util.is_xla_enabled():
      return

    with self.assertRaisesRegexp(ValueError, 'XLA support is not'):

      @def_function.function(experimental_compile=True)
      def fn(x):
        return array_ops.unique(x).y

      fn([1, 1, 2, 3])


if __name__ == '__main__':
  ops.enable_eager_execution()
  test.main()
@@ -28,6 +28,7 @@ limitations under the License.
#include "tensorflow/c/eager/c_api_internal.h"
#include "tensorflow/c/tf_status.h"
#include "tensorflow/c/tf_status_helper.h"
#include "tensorflow/compiler/jit/flags.h"
#include "tensorflow/python/eager/pywrap_tensor_conversion.h"
#include "tensorflow/python/eager/pywrap_tfe.h"
#include "tensorflow/python/lib/core/py_exception_registry.h"
@@ -338,6 +339,7 @@ PYBIND11_MODULE(_pywrap_tfe, m) {
  m.def("TF_SetXlaConstantFoldingDisabled", &TF_SetXlaConstantFoldingDisabled);
  m.def("TF_GetXlaConstantFoldingDisabled", &TF_GetXlaConstantFoldingDisabled);
  m.def("TF_SetXlaMinClusterSize", &TF_SetXlaMinClusterSize);
  m.def("TF_IsXlaEnabled", [] { return tensorflow::IsXlaEnabled(); });

  // // TFE_Context Logic
  m.def(
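
Since def_function.py reaches the flag through pywrap_tfe, the binding added above can also be probed directly; a hedged sketch using the internal module (not a stable public API):

from tensorflow.python import pywrap_tfe

# True exactly when SetXlaIsEnabled() has run, i.e. when the xla_kernel_creator
# target (and with it the XLA JIT) was linked into this binary.
print(pywrap_tfe.TF_IsXlaEnabled())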