Explicitly fail when experimental_compile=True is used, but XLA is not enabled
PiperOrigin-RevId: 292025089 Change-Id: I6412e21f5701553691205037944f9aab1df64ce2
This commit is contained in:
parent
53889e9671
commit
96e0b87d1e
@ -57,6 +57,7 @@ cc_library(
|
|||||||
visibility = ["//visibility:public"],
|
visibility = ["//visibility:public"],
|
||||||
deps = [
|
deps = [
|
||||||
":jit_compilation_passes",
|
":jit_compilation_passes",
|
||||||
|
":xla_kernel_creator", # buildcleaner: keep
|
||||||
"//tensorflow/compiler/jit/kernels:xla_ops",
|
"//tensorflow/compiler/jit/kernels:xla_ops",
|
||||||
"//tensorflow/compiler/tf2xla/kernels:xla_dummy_ops",
|
"//tensorflow/compiler/tf2xla/kernels:xla_dummy_ops",
|
||||||
"//tensorflow/compiler/tf2xla/kernels:xla_ops",
|
"//tensorflow/compiler/tf2xla/kernels:xla_ops",
|
||||||
@ -70,6 +71,7 @@ cc_library(
|
|||||||
visibility = ["//visibility:public"],
|
visibility = ["//visibility:public"],
|
||||||
deps = if_cuda_or_rocm([
|
deps = if_cuda_or_rocm([
|
||||||
":jit_compilation_passes",
|
":jit_compilation_passes",
|
||||||
|
":xla_kernel_creator", # buildcleaner: keep
|
||||||
"//tensorflow/compiler/jit/kernels:xla_ops",
|
"//tensorflow/compiler/jit/kernels:xla_ops",
|
||||||
"//tensorflow/compiler/tf2xla/kernels:xla_ops",
|
"//tensorflow/compiler/tf2xla/kernels:xla_ops",
|
||||||
"//tensorflow/compiler/tf2xla/kernels:xla_dummy_ops",
|
"//tensorflow/compiler/tf2xla/kernels:xla_dummy_ops",
|
||||||
@ -250,8 +252,6 @@ cc_library(
|
|||||||
}),
|
}),
|
||||||
)
|
)
|
||||||
|
|
||||||
# Internal targets below this point.
|
|
||||||
|
|
||||||
cc_library(
|
cc_library(
|
||||||
name = "flags",
|
name = "flags",
|
||||||
srcs = ["flags.cc"],
|
srcs = ["flags.cc"],
|
||||||
@ -265,6 +265,20 @@ cc_library(
|
|||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Header-only version of "flags" library, for linking from the shared object
|
||||||
|
# without ODR violations.
|
||||||
|
cc_library(
|
||||||
|
name = "flags_headers_only",
|
||||||
|
hdrs = ["flags.h"],
|
||||||
|
visibility = [":friends"],
|
||||||
|
deps = [
|
||||||
|
"//tensorflow/compiler/xla:parse_flags_from_env",
|
||||||
|
"//tensorflow/core:framework_internal",
|
||||||
|
"//tensorflow/core:lib",
|
||||||
|
"@com_google_absl//absl/strings",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
cc_library(
|
cc_library(
|
||||||
name = "common",
|
name = "common",
|
||||||
srcs = [
|
srcs = [
|
||||||
@ -276,6 +290,8 @@ cc_library(
|
|||||||
visibility = [":friends"],
|
visibility = [":friends"],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Internal targets below this point.
|
||||||
|
|
||||||
cc_library(
|
cc_library(
|
||||||
name = "xla_launch_util",
|
name = "xla_launch_util",
|
||||||
srcs = ["xla_launch_util.cc"],
|
srcs = ["xla_launch_util.cc"],
|
||||||
@ -397,6 +413,7 @@ cc_library(
|
|||||||
"xla_kernel_creator.h",
|
"xla_kernel_creator.h",
|
||||||
],
|
],
|
||||||
deps = [
|
deps = [
|
||||||
|
":flags",
|
||||||
":jit_compilation_passes",
|
":jit_compilation_passes",
|
||||||
":xla_kernel_creator_util",
|
":xla_kernel_creator_util",
|
||||||
"//tensorflow/core:core_cpu_internal",
|
"//tensorflow/core:core_cpu_internal",
|
||||||
|
@ -13,13 +13,15 @@ See the License for the specific language governing permissions and
|
|||||||
limitations under the License.
|
limitations under the License.
|
||||||
==============================================================================*/
|
==============================================================================*/
|
||||||
|
|
||||||
|
#include "tensorflow/compiler/jit/flags.h"
|
||||||
|
|
||||||
#include <mutex> // NOLINT
|
#include <mutex> // NOLINT
|
||||||
|
|
||||||
#include "absl/strings/numbers.h"
|
#include "absl/strings/numbers.h"
|
||||||
#include "absl/strings/str_split.h"
|
#include "absl/strings/str_split.h"
|
||||||
#include "absl/strings/strip.h"
|
#include "absl/strings/strip.h"
|
||||||
#include "tensorflow/compiler/jit/flags.h"
|
|
||||||
#include "tensorflow/compiler/xla/parse_flags_from_env.h"
|
#include "tensorflow/compiler/xla/parse_flags_from_env.h"
|
||||||
|
#include "tensorflow/core/platform/macros.h"
|
||||||
#include "tensorflow/core/util/command_line_flags.h"
|
#include "tensorflow/core/util/command_line_flags.h"
|
||||||
|
|
||||||
namespace tensorflow {
|
namespace tensorflow {
|
||||||
@ -247,4 +249,11 @@ void AppendMarkForCompilationPassFlags(std::vector<Flag>* flag_list) {
|
|||||||
std::call_once(flags_init, &AllocateAndParseFlags);
|
std::call_once(flags_init, &AllocateAndParseFlags);
|
||||||
AppendMarkForCompilationPassFlagsInternal(flag_list);
|
AppendMarkForCompilationPassFlagsInternal(flag_list);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
static bool xla_is_enabled = false;
|
||||||
|
|
||||||
|
void SetXlaIsEnabled() { xla_is_enabled = true; }
|
||||||
|
|
||||||
|
bool IsXlaEnabled() { return xla_is_enabled; }
|
||||||
|
|
||||||
} // namespace tensorflow
|
} // namespace tensorflow
|
||||||
|
@ -154,6 +154,15 @@ GetIntroduceFloatingPointJitterPassFlags();
|
|||||||
// Has the side-effect of parsing TF_XLA_FLAGS if that hasn't happened yet.
|
// Has the side-effect of parsing TF_XLA_FLAGS if that hasn't happened yet.
|
||||||
void AppendMarkForCompilationPassFlags(
|
void AppendMarkForCompilationPassFlags(
|
||||||
std::vector<tensorflow::Flag>* flag_list);
|
std::vector<tensorflow::Flag>* flag_list);
|
||||||
|
|
||||||
|
// Makes all future calls to `IsXlaEnabled()` return `true`.
|
||||||
|
//
|
||||||
|
// Should only be called when XLA is linked in.
|
||||||
|
void SetXlaIsEnabled();
|
||||||
|
|
||||||
|
// Returns whether XLA is enabled.
|
||||||
|
bool IsXlaEnabled();
|
||||||
|
|
||||||
} // namespace tensorflow
|
} // namespace tensorflow
|
||||||
|
|
||||||
#endif // TENSORFLOW_COMPILER_JIT_FLAGS_H_
|
#endif // TENSORFLOW_COMPILER_JIT_FLAGS_H_
|
||||||
|
@ -14,6 +14,7 @@ limitations under the License.
|
|||||||
==============================================================================*/
|
==============================================================================*/
|
||||||
#include "tensorflow/compiler/jit/xla_kernel_creator.h"
|
#include "tensorflow/compiler/jit/xla_kernel_creator.h"
|
||||||
|
|
||||||
|
#include "tensorflow/compiler/jit/flags.h"
|
||||||
#include "tensorflow/compiler/jit/xla_kernel_creator_util.h"
|
#include "tensorflow/compiler/jit/xla_kernel_creator_util.h"
|
||||||
#include "tensorflow/core/common_runtime/function.h"
|
#include "tensorflow/core/common_runtime/function.h"
|
||||||
|
|
||||||
@ -39,6 +40,10 @@ bool RegisterLaunchOpCreator() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
static bool register_me = RegisterLaunchOpCreator();
|
static bool register_me = RegisterLaunchOpCreator();
|
||||||
|
static bool register_xla = [] {
|
||||||
|
SetXlaIsEnabled();
|
||||||
|
return true;
|
||||||
|
}();
|
||||||
|
|
||||||
} // end namespace
|
} // end namespace
|
||||||
} // namespace tensorflow
|
} // namespace tensorflow
|
||||||
|
@ -7728,11 +7728,13 @@ tf_python_pybind_extension(
|
|||||||
"//tensorflow/core/profiler/protobuf:xplane_proto_cc",
|
"//tensorflow/core/profiler/protobuf:xplane_proto_cc",
|
||||||
] + if_static(
|
] + if_static(
|
||||||
extra_deps = [
|
extra_deps = [
|
||||||
|
"//tensorflow/compiler/jit:flags",
|
||||||
"//tensorflow/core:eager_service_proto_cc",
|
"//tensorflow/core:eager_service_proto_cc",
|
||||||
"//tensorflow/core:master_proto_cc",
|
"//tensorflow/core:master_proto_cc",
|
||||||
"//tensorflow/core:worker_proto_cc",
|
"//tensorflow/core:worker_proto_cc",
|
||||||
],
|
],
|
||||||
otherwise = [
|
otherwise = [
|
||||||
|
"//tensorflow/compiler/jit:flags_headers_only",
|
||||||
"//tensorflow/core:eager_service_proto_cc_headers_only",
|
"//tensorflow/core:eager_service_proto_cc_headers_only",
|
||||||
"//tensorflow/core:master_proto_cc_headers_only",
|
"//tensorflow/core:master_proto_cc_headers_only",
|
||||||
"//tensorflow/core:worker_proto_cc_headers_only",
|
"//tensorflow/core:worker_proto_cc_headers_only",
|
||||||
|
@ -741,6 +741,27 @@ cuda_py_test(
|
|||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
tf_py_test(
|
||||||
|
name = "def_function_test_cpu_only",
|
||||||
|
srcs = ["def_function_test_cpu_only.py"],
|
||||||
|
python_version = "PY3",
|
||||||
|
# --config=cuda implicitly links in XLA.
|
||||||
|
tags = [
|
||||||
|
"no_cuda_on_cpu_tap",
|
||||||
|
"no_oss", # No way to force no XLA linkage in OSS build from here.
|
||||||
|
"no_pip",
|
||||||
|
"nogpu",
|
||||||
|
],
|
||||||
|
deps = [
|
||||||
|
":def_function",
|
||||||
|
"//tensorflow/python:client_testlib",
|
||||||
|
"//tensorflow/python:constant_op",
|
||||||
|
"//tensorflow/python:framework_ops",
|
||||||
|
"//tensorflow/python/autograph/core",
|
||||||
|
"@absl_py//absl/testing:parameterized",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
cuda_py_test(
|
cuda_py_test(
|
||||||
name = "def_function_xla_jit_test",
|
name = "def_function_xla_jit_test",
|
||||||
srcs = ["def_function_xla_jit_test.py"],
|
srcs = ["def_function_xla_jit_test.py"],
|
||||||
|
@ -23,6 +23,7 @@ import functools
|
|||||||
import threading
|
import threading
|
||||||
import weakref
|
import weakref
|
||||||
|
|
||||||
|
from tensorflow.python import pywrap_tfe
|
||||||
from tensorflow.python.eager import context
|
from tensorflow.python.eager import context
|
||||||
from tensorflow.python.eager import function as function_lib
|
from tensorflow.python.eager import function as function_lib
|
||||||
from tensorflow.python.eager import lift_to_graph
|
from tensorflow.python.eager import lift_to_graph
|
||||||
@ -452,6 +453,10 @@ class Function(object):
|
|||||||
attributes.update(_XlaMustCompile=bool(self._experimental_compile))
|
attributes.update(_XlaMustCompile=bool(self._experimental_compile))
|
||||||
if self._experimental_compile:
|
if self._experimental_compile:
|
||||||
attributes.update(_noinline=True)
|
attributes.update(_noinline=True)
|
||||||
|
if not pywrap_tfe.TF_IsXlaEnabled():
|
||||||
|
raise ValueError("Attempting to use experimental_compile, "
|
||||||
|
"but XLA support is not linked in. "
|
||||||
|
"Rebuild with --define=with_xla_support=true.")
|
||||||
if not attributes:
|
if not attributes:
|
||||||
attributes = None
|
attributes = None
|
||||||
return function_lib.defun_with_attributes(
|
return function_lib.defun_with_attributes(
|
||||||
@ -1196,6 +1201,10 @@ def function(func=None,
|
|||||||
function (and return zero or more `tf.Tensor` objects).
|
function (and return zero or more `tf.Tensor` objects).
|
||||||
If `func` is None, returns a decorator that, when invoked with a single
|
If `func` is None, returns a decorator that, when invoked with a single
|
||||||
`func` argument, returns a callable equivalent to the case above.
|
`func` argument, returns a callable equivalent to the case above.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError when attempting to use experimental_compile, but XLA support is
|
||||||
|
not enabled.
|
||||||
"""
|
"""
|
||||||
if input_signature is not None:
|
if input_signature is not None:
|
||||||
function_lib.validate_signature(input_signature)
|
function_lib.validate_signature(input_signature)
|
||||||
|
51
tensorflow/python/eager/def_function_test_cpu_only.py
Normal file
51
tensorflow/python/eager/def_function_test_cpu_only.py
Normal file
@ -0,0 +1,51 @@
|
|||||||
|
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ==============================================================================
|
||||||
|
|
||||||
|
from __future__ import absolute_import
|
||||||
|
from __future__ import division
|
||||||
|
from __future__ import print_function
|
||||||
|
|
||||||
|
from absl.testing import parameterized
|
||||||
|
|
||||||
|
from tensorflow.python.eager import def_function
|
||||||
|
from tensorflow.python.framework import ops
|
||||||
|
from tensorflow.python.framework import test_util
|
||||||
|
from tensorflow.python.ops import array_ops
|
||||||
|
from tensorflow.python.platform import test
|
||||||
|
|
||||||
|
|
||||||
|
class DefFunctionCpuOnlyTest(test.TestCase, parameterized.TestCase):
|
||||||
|
"""Test that experimental_compile=True correctly throws an exception if XLA is not available.
|
||||||
|
|
||||||
|
This test should only be run without `--config=cuda`, as that implicitly links
|
||||||
|
in XLA JIT.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def testExperimentalCompileRaisesExceptionWhenXlaIsUnsupported(self):
|
||||||
|
if test.is_built_with_rocm() or test_util.is_xla_enabled():
|
||||||
|
return
|
||||||
|
|
||||||
|
with self.assertRaisesRegexp(ValueError, 'XLA support is not'):
|
||||||
|
|
||||||
|
@def_function.function(experimental_compile=True)
|
||||||
|
def fn(x):
|
||||||
|
return array_ops.unique(x).y
|
||||||
|
|
||||||
|
fn([1, 1, 2, 3])
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
ops.enable_eager_execution()
|
||||||
|
test.main()
|
@ -28,6 +28,7 @@ limitations under the License.
|
|||||||
#include "tensorflow/c/eager/c_api_internal.h"
|
#include "tensorflow/c/eager/c_api_internal.h"
|
||||||
#include "tensorflow/c/tf_status.h"
|
#include "tensorflow/c/tf_status.h"
|
||||||
#include "tensorflow/c/tf_status_helper.h"
|
#include "tensorflow/c/tf_status_helper.h"
|
||||||
|
#include "tensorflow/compiler/jit/flags.h"
|
||||||
#include "tensorflow/python/eager/pywrap_tensor_conversion.h"
|
#include "tensorflow/python/eager/pywrap_tensor_conversion.h"
|
||||||
#include "tensorflow/python/eager/pywrap_tfe.h"
|
#include "tensorflow/python/eager/pywrap_tfe.h"
|
||||||
#include "tensorflow/python/lib/core/py_exception_registry.h"
|
#include "tensorflow/python/lib/core/py_exception_registry.h"
|
||||||
@ -338,6 +339,7 @@ PYBIND11_MODULE(_pywrap_tfe, m) {
|
|||||||
m.def("TF_SetXlaConstantFoldingDisabled", &TF_SetXlaConstantFoldingDisabled);
|
m.def("TF_SetXlaConstantFoldingDisabled", &TF_SetXlaConstantFoldingDisabled);
|
||||||
m.def("TF_GetXlaConstantFoldingDisabled", &TF_GetXlaConstantFoldingDisabled);
|
m.def("TF_GetXlaConstantFoldingDisabled", &TF_GetXlaConstantFoldingDisabled);
|
||||||
m.def("TF_SetXlaMinClusterSize", &TF_SetXlaMinClusterSize);
|
m.def("TF_SetXlaMinClusterSize", &TF_SetXlaMinClusterSize);
|
||||||
|
m.def("TF_IsXlaEnabled", [] { return tensorflow::IsXlaEnabled(); });
|
||||||
|
|
||||||
// // TFE_Context Logic
|
// // TFE_Context Logic
|
||||||
m.def(
|
m.def(
|
||||||
|
Loading…
Reference in New Issue
Block a user