[TF:XLA] Enable XLA through autojit for all tensorflow/python/kernel_test/ tests.
Some test methods are disabled, but all tests now have a new "_xla" version of the test for XLA:GPU testing. This will run 2 different tests. One with XLA and one without. PiperOrigin-RevId: 229149574
This commit is contained in:
parent
6dd6ad9fd7
commit
4ac9bda9d0
@ -1103,6 +1103,14 @@ py_library(
|
|||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Including this as a dependency will result in tests using
|
||||||
|
# :framework_test_lib to use XLA.
|
||||||
|
py_library(
|
||||||
|
name = "is_xla_test_true",
|
||||||
|
srcs = ["framework/is_xla_test_true.py"],
|
||||||
|
srcs_version = "PY2AND3",
|
||||||
|
)
|
||||||
|
|
||||||
py_library(
|
py_library(
|
||||||
name = "distributed_framework_test_lib",
|
name = "distributed_framework_test_lib",
|
||||||
srcs_version = "PY2AND3",
|
srcs_version = "PY2AND3",
|
||||||
|
@ -284,6 +284,7 @@ class FunctionTest(test.TestCase):
|
|||||||
out, = sess.run(dlogits, {logits: x, labels: y})
|
out, = sess.run(dlogits, {logits: x, labels: y})
|
||||||
self.assertAllClose(out, np.exp(prob - y))
|
self.assertAllClose(out, np.exp(prob - y))
|
||||||
|
|
||||||
|
@test_util.disable_xla("This test never passed for XLA")
|
||||||
def testCustomGradientError(self):
|
def testCustomGradientError(self):
|
||||||
dtype = dtypes.float32
|
dtype = dtypes.float32
|
||||||
|
|
||||||
|
29
tensorflow/python/framework/is_xla_test_true.py
Normal file
29
tensorflow/python/framework/is_xla_test_true.py
Normal file
@ -0,0 +1,29 @@
|
|||||||
|
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ==============================================================================
|
||||||
|
"""Including this as a dependency will result in Tensorflow tests using XLA.
|
||||||
|
|
||||||
|
This function is defined by default in test_util.py to False. The test_util then
|
||||||
|
attempts to import this module. If this file is made available through the BUILD
|
||||||
|
rule, then this function is overridden and will instead cause Tensorflow graphs
|
||||||
|
to be compiled with XLA.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import absolute_import
|
||||||
|
from __future__ import division
|
||||||
|
from __future__ import print_function
|
||||||
|
|
||||||
|
def is_xla_enabled():
|
||||||
|
"""Returns true to state XLA should be enabled for Tensorflow tests."""
|
||||||
|
return True
|
@ -82,6 +82,19 @@ from tensorflow.python.util.protobuf import compare
|
|||||||
from tensorflow.python.util.tf_export import tf_export
|
from tensorflow.python.util.tf_export import tf_export
|
||||||
|
|
||||||
|
|
||||||
|
# If the above import is made available through the BUILD rule, then this
|
||||||
|
# function is overridden and will instead return True and cause Tensorflow
|
||||||
|
# graphs to be compiled with XLA.
|
||||||
|
def is_xla_enabled():
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
try:
|
||||||
|
from tensorflow.python.framework.is_xla_test_true import is_xla_enabled # pylint: disable=g-import-not-at-top
|
||||||
|
except:
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
@tf_export("test.gpu_device_name")
|
@tf_export("test.gpu_device_name")
|
||||||
def gpu_device_name():
|
def gpu_device_name():
|
||||||
"""Returns the name of a GPU device if available or the empty string."""
|
"""Returns the name of a GPU device if available or the empty string."""
|
||||||
@ -97,6 +110,7 @@ def assert_ops_in_graph(expected_ops, graph):
|
|||||||
Args:
|
Args:
|
||||||
expected_ops: `dict<string, string>` of op name to op type.
|
expected_ops: `dict<string, string>` of op name to op type.
|
||||||
graph: Graph to check.
|
graph: Graph to check.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
`dict<string, node>` of node name to node.
|
`dict<string, node>` of node name to node.
|
||||||
|
|
||||||
@ -149,7 +163,7 @@ def assert_equal_graph_def_v1(actual, expected, checkpoint_v2=False):
|
|||||||
actual: The `GraphDef` we have.
|
actual: The `GraphDef` we have.
|
||||||
expected: The `GraphDef` we expected.
|
expected: The `GraphDef` we expected.
|
||||||
checkpoint_v2: boolean determining whether to ignore randomized attribute
|
checkpoint_v2: boolean determining whether to ignore randomized attribute
|
||||||
values that appear in V2 checkpoints.
|
values that appear in V2 checkpoints.
|
||||||
|
|
||||||
Raises:
|
Raises:
|
||||||
AssertionError: If the `GraphDef`s do not match.
|
AssertionError: If the `GraphDef`s do not match.
|
||||||
@ -360,7 +374,8 @@ def skip_if(condition):
|
|||||||
|
|
||||||
Args:
|
Args:
|
||||||
condition: Either an expression that can be used in "if not condition"
|
condition: Either an expression that can be used in "if not condition"
|
||||||
statement, or a callable whose result should be a boolean.
|
statement, or a callable whose result should be a boolean.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
The wrapped function
|
The wrapped function
|
||||||
"""
|
"""
|
||||||
@ -483,9 +498,11 @@ def disable_control_flow_v2(unused_msg):
|
|||||||
Returns:
|
Returns:
|
||||||
The wrapped function with _disable_control_flow_v2 attr set to True.
|
The wrapped function with _disable_control_flow_v2 attr set to True.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def wrapper(func):
|
def wrapper(func):
|
||||||
func._disable_control_flow_v2 = True
|
func._disable_control_flow_v2 = True
|
||||||
return func
|
return func
|
||||||
|
|
||||||
return wrapper
|
return wrapper
|
||||||
|
|
||||||
|
|
||||||
@ -568,6 +585,7 @@ def assert_no_new_tensors(f):
|
|||||||
|
|
||||||
Args:
|
Args:
|
||||||
f: The test case to run.
|
f: The test case to run.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
The decorated test case.
|
The decorated test case.
|
||||||
"""
|
"""
|
||||||
@ -727,6 +745,7 @@ def assert_no_garbage_created(f):
|
|||||||
|
|
||||||
Args:
|
Args:
|
||||||
f: The function to decorate.
|
f: The function to decorate.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
The decorated function.
|
The decorated function.
|
||||||
"""
|
"""
|
||||||
@ -799,8 +818,8 @@ def _combine_named_parameters(**kwargs):
|
|||||||
can be computed using `times()`.
|
can be computed using `times()`.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
**kwargs: keyword arguments of form `option=[possibilities, ...]`
|
**kwargs: keyword arguments of form `option=[possibilities, ...]` or
|
||||||
or `option=the_only_possibility`.
|
`option=the_only_possibility`.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
a list of dictionaries for each combination. Keys in the dictionaries are
|
a list of dictionaries for each combination. Keys in the dictionaries are
|
||||||
@ -838,8 +857,8 @@ def generate_combinations_with_testcase_name(**kwargs):
|
|||||||
parameterized tests.
|
parameterized tests.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
**kwargs: keyword arguments of form `option=[possibilities, ...]`
|
**kwargs: keyword arguments of form `option=[possibilities, ...]` or
|
||||||
or `option=the_only_possibility`.
|
`option=the_only_possibility`.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
a list of dictionaries for each combination. Keys in the dictionaries are
|
a list of dictionaries for each combination. Keys in the dictionaries are
|
||||||
@ -866,11 +885,12 @@ def generate_combinations_with_testcase_name(**kwargs):
|
|||||||
def run_all_in_graph_and_eager_modes(cls):
|
def run_all_in_graph_and_eager_modes(cls):
|
||||||
"""Execute all test methods in the given class with and without eager."""
|
"""Execute all test methods in the given class with and without eager."""
|
||||||
base_decorator = run_in_graph_and_eager_modes
|
base_decorator = run_in_graph_and_eager_modes
|
||||||
for name, value in cls.__dict__.copy().items():
|
for name in dir(cls):
|
||||||
if (callable(value) and
|
value = getattr(cls, name)
|
||||||
name.startswith(unittest.TestLoader.testMethodPrefix) and
|
if callable(value) and name.startswith(
|
||||||
not (name.startswith("testSkipEager")
|
unittest.TestLoader.testMethodPrefix) and not (
|
||||||
or name.startswith("test_skip_eager"))):
|
name.startswith("testSkipEager") or
|
||||||
|
name.startswith("test_skip_eager") or name == "test_session"):
|
||||||
setattr(cls, name, base_decorator(value))
|
setattr(cls, name, base_decorator(value))
|
||||||
return cls
|
return cls
|
||||||
|
|
||||||
@ -917,8 +937,8 @@ def run_in_graph_and_eager_modes(func=None,
|
|||||||
func: function to be annotated. If `func` is None, this method returns a
|
func: function to be annotated. If `func` is None, this method returns a
|
||||||
decorator the can be applied to a function. If `func` is not None this
|
decorator the can be applied to a function. If `func` is not None this
|
||||||
returns the decorator applied to `func`.
|
returns the decorator applied to `func`.
|
||||||
config: An optional config_pb2.ConfigProto to use to configure the
|
config: An optional config_pb2.ConfigProto to use to configure the session
|
||||||
session when executing graphs.
|
when executing graphs.
|
||||||
use_gpu: If True, attempt to run as many operations as possible on GPU.
|
use_gpu: If True, attempt to run as many operations as possible on GPU.
|
||||||
reset_test: If True, tearDown and SetUp the test case between the two
|
reset_test: If True, tearDown and SetUp the test case between the two
|
||||||
executions of the test (once with and once without eager execution).
|
executions of the test (once with and once without eager execution).
|
||||||
@ -932,6 +952,7 @@ def run_in_graph_and_eager_modes(func=None,
|
|||||||
collected elsewhere in the unit test file will not work). Additionally,
|
collected elsewhere in the unit test file will not work). Additionally,
|
||||||
checks that nothing still has a reference to Tensors that the test
|
checks that nothing still has a reference to Tensors that the test
|
||||||
allocated.
|
allocated.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Returns a decorator that will run the decorated test method twice:
|
Returns a decorator that will run the decorated test method twice:
|
||||||
once by constructing and executing a graph in a session and once with
|
once by constructing and executing a graph in a session and once with
|
||||||
@ -992,9 +1013,10 @@ def py_func_if_in_function(f):
|
|||||||
if not ops.get_default_graph()._building_function:
|
if not ops.get_default_graph()._building_function:
|
||||||
return f(*args, **kwds)
|
return f(*args, **kwds)
|
||||||
|
|
||||||
tensor_args, tensor_indices = zip(
|
tensor_args, tensor_indices = zip(*[(x, i)
|
||||||
*[(x, i) for i, x in enumerate(args)
|
for i, x in enumerate(args)
|
||||||
if isinstance(x, (ops.Tensor, variables.Variable))])
|
if isinstance(x, (ops.Tensor,
|
||||||
|
variables.Variable))])
|
||||||
|
|
||||||
def inner_f(*inner_tensor_args):
|
def inner_f(*inner_tensor_args):
|
||||||
my_args = list(args)
|
my_args = list(args)
|
||||||
@ -1056,6 +1078,7 @@ def deprecated_graph_mode_only(func=None):
|
|||||||
func: function to be annotated. If `func` is None, this method returns a
|
func: function to be annotated. If `func` is None, this method returns a
|
||||||
decorator the can be applied to a function. If `func` is not None this
|
decorator the can be applied to a function. If `func` is not None this
|
||||||
returns the decorator applied to `func`.
|
returns the decorator applied to `func`.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Returns a decorator that will run the decorated test method in graph mode.
|
Returns a decorator that will run the decorated test method in graph mode.
|
||||||
"""
|
"""
|
||||||
@ -1390,8 +1413,7 @@ class FakeEagerSession(object):
|
|||||||
|
|
||||||
|
|
||||||
class ErrorLoggingSession(session.Session):
|
class ErrorLoggingSession(session.Session):
|
||||||
"""Wrapper around a Session that logs errors in run().
|
"""Wrapper around a Session that logs errors in run()."""
|
||||||
"""
|
|
||||||
|
|
||||||
def run(self, *args, **kwargs):
|
def run(self, *args, **kwargs):
|
||||||
try:
|
try:
|
||||||
@ -1405,13 +1427,56 @@ class ErrorLoggingSession(session.Session):
|
|||||||
raise
|
raise
|
||||||
|
|
||||||
|
|
||||||
|
# The description is just for documentation purposes.
|
||||||
|
def disable_xla(description):
|
||||||
|
|
||||||
|
def disable_xla_impl(func):
|
||||||
|
"""Execute the test method only if xla is not enabled."""
|
||||||
|
|
||||||
|
def decorator(func):
|
||||||
|
|
||||||
|
def decorated(self, *args, **kwargs):
|
||||||
|
if is_xla_enabled():
|
||||||
|
return
|
||||||
|
else:
|
||||||
|
return func(self, *args, **kwargs)
|
||||||
|
|
||||||
|
return decorated
|
||||||
|
|
||||||
|
if func is not None:
|
||||||
|
return decorator(func)
|
||||||
|
|
||||||
|
return decorator
|
||||||
|
|
||||||
|
return disable_xla_impl
|
||||||
|
|
||||||
|
|
||||||
|
# The description is just for documentation purposes.
|
||||||
|
def disable_all_xla(description):
|
||||||
|
|
||||||
|
def disable_all_impl(cls):
|
||||||
|
"""Execute all test methods in this class only if xla is not enabled."""
|
||||||
|
base_decorator = disable_xla
|
||||||
|
for name in dir(cls):
|
||||||
|
value = getattr(cls, name)
|
||||||
|
if callable(value) and name.startswith(
|
||||||
|
"test") and not name == "test_session":
|
||||||
|
setattr(cls, name, base_decorator(value))
|
||||||
|
return cls
|
||||||
|
|
||||||
|
return disable_all_impl
|
||||||
|
|
||||||
|
|
||||||
@tf_export("test.TestCase")
|
@tf_export("test.TestCase")
|
||||||
class TensorFlowTestCase(googletest.TestCase):
|
class TensorFlowTestCase(googletest.TestCase):
|
||||||
"""Base class for tests that need to test TensorFlow.
|
"""Base class for tests that need to test TensorFlow."""
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, methodName="runTest"): # pylint: disable=invalid-name
|
def __init__(self, methodName="runTest"): # pylint: disable=invalid-name
|
||||||
super(TensorFlowTestCase, self).__init__(methodName)
|
super(TensorFlowTestCase, self).__init__(methodName)
|
||||||
|
if is_xla_enabled():
|
||||||
|
os.putenv(
|
||||||
|
"TF_XLA_FLAGS", "--tf_xla_auto_jit=2 --tf_xla_min_cluster_size=1 "
|
||||||
|
"--tf_xla_enable_lazy_compilation=false")
|
||||||
self._threads = []
|
self._threads = []
|
||||||
self._tempdir = None
|
self._tempdir = None
|
||||||
self._cached_session = None
|
self._cached_session = None
|
||||||
@ -1489,9 +1554,9 @@ class TensorFlowTestCase(googletest.TestCase):
|
|||||||
```
|
```
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
stream: The stream whose writes should be captured. This
|
stream: The stream whose writes should be captured. This stream must have
|
||||||
stream must have a file descriptor, support writing via using that
|
a file descriptor, support writing via using that file descriptor, and
|
||||||
file descriptor, and must have a `.flush()` method.
|
must have a `.flush()` method.
|
||||||
|
|
||||||
Yields:
|
Yields:
|
||||||
A `CapturedWrites` object that contains all writes to the specified stream
|
A `CapturedWrites` object that contains all writes to the specified stream
|
||||||
@ -1840,7 +1905,6 @@ class TensorFlowTestCase(googletest.TestCase):
|
|||||||
self._threads.append(ret)
|
self._threads.append(ret)
|
||||||
return ret
|
return ret
|
||||||
|
|
||||||
|
|
||||||
# pylint: enable=invalid-name
|
# pylint: enable=invalid-name
|
||||||
@py_func_if_in_function
|
@py_func_if_in_function
|
||||||
def assertNear(self, f1, f2, err, msg=None):
|
def assertNear(self, f1, f2, err, msg=None):
|
||||||
@ -1857,9 +1921,8 @@ class TensorFlowTestCase(googletest.TestCase):
|
|||||||
"""
|
"""
|
||||||
# f1 == f2 is needed here as we might have: f1, f2 = inf, inf
|
# f1 == f2 is needed here as we might have: f1, f2 = inf, inf
|
||||||
self.assertTrue(
|
self.assertTrue(
|
||||||
f1 == f2 or math.fabs(f1 - f2) <= err,
|
f1 == f2 or math.fabs(f1 - f2) <= err, "%f != %f +/- %f%s" %
|
||||||
"%f != %f +/- %f%s" % (f1, f2, err, " (%s)" % msg
|
(f1, f2, err, " (%s)" % msg if msg is not None else ""))
|
||||||
if msg is not None else ""))
|
|
||||||
|
|
||||||
@py_func_if_in_function
|
@py_func_if_in_function
|
||||||
def assertArrayNear(self, farray1, farray2, err, msg=None):
|
def assertArrayNear(self, farray1, farray2, err, msg=None):
|
||||||
@ -2028,11 +2091,11 @@ class TensorFlowTestCase(googletest.TestCase):
|
|||||||
|
|
||||||
Args:
|
Args:
|
||||||
a: The expected numpy `ndarray`, or anything that can be converted into a
|
a: The expected numpy `ndarray`, or anything that can be converted into a
|
||||||
numpy `ndarray` (including Tensor), or any arbitrarily nested of
|
numpy `ndarray` (including Tensor), or any arbitrarily nested of
|
||||||
structure of these.
|
structure of these.
|
||||||
b: The actual numpy `ndarray`, or anything that can be converted into a
|
b: The actual numpy `ndarray`, or anything that can be converted into a
|
||||||
numpy `ndarray` (including Tensor), or any arbitrarily nested of
|
numpy `ndarray` (including Tensor), or any arbitrarily nested of
|
||||||
structure of these.
|
structure of these.
|
||||||
rtol: relative tolerance.
|
rtol: relative tolerance.
|
||||||
atol: absolute tolerance.
|
atol: absolute tolerance.
|
||||||
msg: Optional message to report on failure.
|
msg: Optional message to report on failure.
|
||||||
@ -2160,8 +2223,8 @@ class TensorFlowTestCase(googletest.TestCase):
|
|||||||
"""Assert element values are all greater than a target value.
|
"""Assert element values are all greater than a target value.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
a: The numpy `ndarray`, or anything that can be converted into a
|
a: The numpy `ndarray`, or anything that can be converted into a numpy
|
||||||
numpy `ndarray` (including Tensor).
|
`ndarray` (including Tensor).
|
||||||
comparison_target: The target value of comparison.
|
comparison_target: The target value of comparison.
|
||||||
"""
|
"""
|
||||||
a = self._GetNdArray(a)
|
a = self._GetNdArray(a)
|
||||||
@ -2172,8 +2235,8 @@ class TensorFlowTestCase(googletest.TestCase):
|
|||||||
"""Assert element values are all less than a target value.
|
"""Assert element values are all less than a target value.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
a: The numpy `ndarray`, or anything that can be converted into a
|
a: The numpy `ndarray`, or anything that can be converted into a numpy
|
||||||
numpy `ndarray` (including Tensor).
|
`ndarray` (including Tensor).
|
||||||
comparison_target: The target value of comparison.
|
comparison_target: The target value of comparison.
|
||||||
"""
|
"""
|
||||||
a = self._GetNdArray(a)
|
a = self._GetNdArray(a)
|
||||||
@ -2184,8 +2247,8 @@ class TensorFlowTestCase(googletest.TestCase):
|
|||||||
"""Assert element values are all greater than or equal to a target value.
|
"""Assert element values are all greater than or equal to a target value.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
a: The numpy `ndarray`, or anything that can be converted into a
|
a: The numpy `ndarray`, or anything that can be converted into a numpy
|
||||||
numpy `ndarray` (including Tensor).
|
`ndarray` (including Tensor).
|
||||||
comparison_target: The target value of comparison.
|
comparison_target: The target value of comparison.
|
||||||
"""
|
"""
|
||||||
a = self._GetNdArray(a)
|
a = self._GetNdArray(a)
|
||||||
@ -2196,8 +2259,8 @@ class TensorFlowTestCase(googletest.TestCase):
|
|||||||
"""Assert element values are all less than or equal to a target value.
|
"""Assert element values are all less than or equal to a target value.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
a: The numpy `ndarray`, or anything that can be converted into a
|
a: The numpy `ndarray`, or anything that can be converted into a numpy
|
||||||
numpy `ndarray` (including Tensor).
|
`ndarray` (including Tensor).
|
||||||
comparison_target: The target value of comparison.
|
comparison_target: The target value of comparison.
|
||||||
"""
|
"""
|
||||||
a = self._GetNdArray(a)
|
a = self._GetNdArray(a)
|
||||||
@ -2245,7 +2308,7 @@ class TensorFlowTestCase(googletest.TestCase):
|
|||||||
|
|
||||||
Args:
|
Args:
|
||||||
target: The numpy `ndarray`, or anything that can be converted into a
|
target: The numpy `ndarray`, or anything that can be converted into a
|
||||||
numpy `ndarray` (including Tensor).
|
numpy `ndarray` (including Tensor).
|
||||||
lower_bound: lower bound of the range
|
lower_bound: lower bound of the range
|
||||||
upper_bound: upper bound of the range
|
upper_bound: upper bound of the range
|
||||||
open_lower_bound: (`bool`) whether the lower bound is open (i.e., > rather
|
open_lower_bound: (`bool`) whether the lower bound is open (i.e., > rather
|
||||||
@ -2279,8 +2342,8 @@ class TensorFlowTestCase(googletest.TestCase):
|
|||||||
str(upper_bound) + (")" if open_upper_bound else "]"))
|
str(upper_bound) + (")" if open_upper_bound else "]"))
|
||||||
|
|
||||||
violations = (
|
violations = (
|
||||||
np.less_equal(target, lower_bound)
|
np.less_equal(target, lower_bound) if open_lower_bound else np.less(
|
||||||
if open_lower_bound else np.less(target, lower_bound))
|
target, lower_bound))
|
||||||
violations = np.logical_or(
|
violations = np.logical_or(
|
||||||
violations,
|
violations,
|
||||||
np.greater_equal(target, upper_bound)
|
np.greater_equal(target, upper_bound)
|
||||||
@ -2299,7 +2362,7 @@ class TensorFlowTestCase(googletest.TestCase):
|
|||||||
|
|
||||||
Args:
|
Args:
|
||||||
target: The numpy `ndarray`, or anything that can be converted into a
|
target: The numpy `ndarray`, or anything that can be converted into a
|
||||||
numpy `ndarray` (including Tensor).
|
numpy `ndarray` (including Tensor).
|
||||||
expected_set: (`list`, `tuple` or `set`) The closed set that the elements
|
expected_set: (`list`, `tuple` or `set`) The closed set that the elements
|
||||||
of the value of `target` are expected to fall into.
|
of the value of `target` are expected to fall into.
|
||||||
|
|
||||||
@ -2321,7 +2384,7 @@ class TensorFlowTestCase(googletest.TestCase):
|
|||||||
|
|
||||||
Args:
|
Args:
|
||||||
target: The numpy `ndarray`, or anything that can be converted into a
|
target: The numpy `ndarray`, or anything that can be converted into a
|
||||||
numpy `ndarray` (including Tensor).
|
numpy `ndarray` (including Tensor).
|
||||||
expected_dtype: Expected data type.
|
expected_dtype: Expected data type.
|
||||||
"""
|
"""
|
||||||
target = self._GetNdArray(target)
|
target = self._GetNdArray(target)
|
||||||
@ -2342,9 +2405,9 @@ class TensorFlowTestCase(googletest.TestCase):
|
|||||||
Args:
|
Args:
|
||||||
exception_type: The expected type of exception that should be raised.
|
exception_type: The expected type of exception that should be raised.
|
||||||
expected_err_re_or_predicate: If this is callable, it should be a function
|
expected_err_re_or_predicate: If this is callable, it should be a function
|
||||||
of one argument that inspects the passed-in exception and
|
of one argument that inspects the passed-in exception and returns True
|
||||||
returns True (success) or False (please fail the test). Otherwise, the
|
(success) or False (please fail the test). Otherwise, the error message
|
||||||
error message is expected to match this regular expression partially.
|
is expected to match this regular expression partially.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
A context manager to surround code that is expected to raise an
|
A context manager to surround code that is expected to raise an
|
||||||
@ -2445,6 +2508,7 @@ class TensorFlowTestCase(googletest.TestCase):
|
|||||||
|
|
||||||
def _create_session(self, graph, config, force_gpu):
|
def _create_session(self, graph, config, force_gpu):
|
||||||
"""See session() for details."""
|
"""See session() for details."""
|
||||||
|
|
||||||
def prepare_config(config):
|
def prepare_config(config):
|
||||||
"""Returns a config for sessions.
|
"""Returns a config for sessions.
|
||||||
|
|
||||||
@ -2547,10 +2611,10 @@ def create_local_cluster(num_workers,
|
|||||||
Args:
|
Args:
|
||||||
num_workers: Number of worker servers to start.
|
num_workers: Number of worker servers to start.
|
||||||
num_ps: Number of PS servers to start.
|
num_ps: Number of PS servers to start.
|
||||||
protocol: Communication protocol. Allowed values are documented in
|
protocol: Communication protocol. Allowed values are documented in the
|
||||||
the documentation of `tf.train.Server`.
|
documentation of `tf.train.Server`.
|
||||||
worker_config: (optional) ConfigProto to initialize workers. Can be used
|
worker_config: (optional) ConfigProto to initialize workers. Can be used to
|
||||||
to instantiate multiple devices etc.
|
instantiate multiple devices etc.
|
||||||
ps_config: (optional) ConfigProto to initialize PS servers.
|
ps_config: (optional) ConfigProto to initialize PS servers.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
|
@ -131,6 +131,7 @@ cuda_py_test(
|
|||||||
"//tensorflow/python:client_testlib",
|
"//tensorflow/python:client_testlib",
|
||||||
],
|
],
|
||||||
grpc_enabled = True,
|
grpc_enabled = True,
|
||||||
|
xla_enable_strict_auto_jit = True,
|
||||||
)
|
)
|
||||||
|
|
||||||
cuda_py_test(
|
cuda_py_test(
|
||||||
@ -145,6 +146,7 @@ cuda_py_test(
|
|||||||
"//tensorflow/python:platform_benchmark",
|
"//tensorflow/python:platform_benchmark",
|
||||||
],
|
],
|
||||||
tags = ["no_windows"],
|
tags = ["no_windows"],
|
||||||
|
xla_enable_strict_auto_jit = True,
|
||||||
)
|
)
|
||||||
|
|
||||||
cuda_py_test(
|
cuda_py_test(
|
||||||
@ -161,6 +163,7 @@ cuda_py_test(
|
|||||||
"//tensorflow/python:platform",
|
"//tensorflow/python:platform",
|
||||||
"//tensorflow/python:platform_benchmark",
|
"//tensorflow/python:platform_benchmark",
|
||||||
],
|
],
|
||||||
|
xla_enable_strict_auto_jit = True,
|
||||||
)
|
)
|
||||||
|
|
||||||
tf_py_test(
|
tf_py_test(
|
||||||
@ -226,6 +229,7 @@ cuda_py_test(
|
|||||||
],
|
],
|
||||||
shard_count = 5,
|
shard_count = 5,
|
||||||
tags = ["no_windows_gpu"],
|
tags = ["no_windows_gpu"],
|
||||||
|
xla_enable_strict_auto_jit = True,
|
||||||
)
|
)
|
||||||
|
|
||||||
tf_py_test(
|
tf_py_test(
|
||||||
@ -282,6 +286,7 @@ cuda_py_test(
|
|||||||
"//tensorflow/python:framework_for_generated_wrappers",
|
"//tensorflow/python:framework_for_generated_wrappers",
|
||||||
"//tensorflow/python:gradients",
|
"//tensorflow/python:gradients",
|
||||||
],
|
],
|
||||||
|
xla_enable_strict_auto_jit = True,
|
||||||
)
|
)
|
||||||
|
|
||||||
tf_py_test(
|
tf_py_test(
|
||||||
@ -387,6 +392,7 @@ cuda_py_test(
|
|||||||
"//tensorflow/python:framework_for_generated_wrappers",
|
"//tensorflow/python:framework_for_generated_wrappers",
|
||||||
"//tensorflow/python:linalg_ops",
|
"//tensorflow/python:linalg_ops",
|
||||||
],
|
],
|
||||||
|
xla_enable_strict_auto_jit = True,
|
||||||
)
|
)
|
||||||
|
|
||||||
tf_py_test(
|
tf_py_test(
|
||||||
@ -645,6 +651,7 @@ cuda_py_test(
|
|||||||
"//tensorflow/python:math_ops",
|
"//tensorflow/python:math_ops",
|
||||||
],
|
],
|
||||||
tags = ["optonly"],
|
tags = ["optonly"],
|
||||||
|
xla_enable_strict_auto_jit = True,
|
||||||
)
|
)
|
||||||
|
|
||||||
cuda_py_test(
|
cuda_py_test(
|
||||||
@ -659,6 +666,7 @@ cuda_py_test(
|
|||||||
"//tensorflow/python:linalg_ops",
|
"//tensorflow/python:linalg_ops",
|
||||||
"//tensorflow/python:math_ops",
|
"//tensorflow/python:math_ops",
|
||||||
],
|
],
|
||||||
|
xla_enable_strict_auto_jit = True,
|
||||||
)
|
)
|
||||||
|
|
||||||
cuda_py_test(
|
cuda_py_test(
|
||||||
@ -671,6 +679,7 @@ cuda_py_test(
|
|||||||
"//tensorflow/python:framework_for_generated_wrappers",
|
"//tensorflow/python:framework_for_generated_wrappers",
|
||||||
"//tensorflow/python:linalg_ops",
|
"//tensorflow/python:linalg_ops",
|
||||||
],
|
],
|
||||||
|
xla_enable_strict_auto_jit = True,
|
||||||
)
|
)
|
||||||
|
|
||||||
cuda_py_test(
|
cuda_py_test(
|
||||||
@ -684,6 +693,7 @@ cuda_py_test(
|
|||||||
"//tensorflow/python:framework_for_generated_wrappers",
|
"//tensorflow/python:framework_for_generated_wrappers",
|
||||||
"//tensorflow/python:linalg_ops",
|
"//tensorflow/python:linalg_ops",
|
||||||
],
|
],
|
||||||
|
xla_enable_strict_auto_jit = True,
|
||||||
)
|
)
|
||||||
|
|
||||||
cuda_py_test(
|
cuda_py_test(
|
||||||
@ -695,6 +705,7 @@ cuda_py_test(
|
|||||||
"//tensorflow/python:client_testlib",
|
"//tensorflow/python:client_testlib",
|
||||||
"//tensorflow/python:linalg_ops",
|
"//tensorflow/python:linalg_ops",
|
||||||
],
|
],
|
||||||
|
xla_enable_strict_auto_jit = True,
|
||||||
)
|
)
|
||||||
|
|
||||||
cuda_py_test(
|
cuda_py_test(
|
||||||
@ -713,6 +724,7 @@ cuda_py_test(
|
|||||||
"//tensorflow/python:platform",
|
"//tensorflow/python:platform",
|
||||||
"//tensorflow/python:random_ops",
|
"//tensorflow/python:random_ops",
|
||||||
],
|
],
|
||||||
|
xla_enable_strict_auto_jit = True,
|
||||||
)
|
)
|
||||||
|
|
||||||
tf_py_test(
|
tf_py_test(
|
||||||
@ -795,6 +807,7 @@ cuda_py_test(
|
|||||||
"//tensorflow/python:resource_variable_ops",
|
"//tensorflow/python:resource_variable_ops",
|
||||||
"//tensorflow/python:variables",
|
"//tensorflow/python:variables",
|
||||||
],
|
],
|
||||||
|
xla_enable_strict_auto_jit = True,
|
||||||
)
|
)
|
||||||
|
|
||||||
tf_py_test(
|
tf_py_test(
|
||||||
@ -853,6 +866,7 @@ cuda_py_test(
|
|||||||
"//tensorflow/python:resource_variable_ops",
|
"//tensorflow/python:resource_variable_ops",
|
||||||
],
|
],
|
||||||
tags = ["noasan"], # http://b/32635055
|
tags = ["noasan"], # http://b/32635055
|
||||||
|
xla_enable_strict_auto_jit = True,
|
||||||
)
|
)
|
||||||
|
|
||||||
tf_py_test(
|
tf_py_test(
|
||||||
@ -1159,6 +1173,7 @@ cuda_py_test(
|
|||||||
"//tensorflow/python:nn_grad",
|
"//tensorflow/python:nn_grad",
|
||||||
"//tensorflow/python:nn_ops",
|
"//tensorflow/python:nn_ops",
|
||||||
],
|
],
|
||||||
|
xla_enable_strict_auto_jit = True,
|
||||||
)
|
)
|
||||||
|
|
||||||
cuda_py_test(
|
cuda_py_test(
|
||||||
@ -1174,6 +1189,7 @@ cuda_py_test(
|
|||||||
"//tensorflow/python:nn_grad",
|
"//tensorflow/python:nn_grad",
|
||||||
"//tensorflow/python:nn_ops",
|
"//tensorflow/python:nn_ops",
|
||||||
],
|
],
|
||||||
|
xla_enable_strict_auto_jit = True,
|
||||||
)
|
)
|
||||||
|
|
||||||
tf_py_test(
|
tf_py_test(
|
||||||
@ -1290,6 +1306,7 @@ cuda_py_test(
|
|||||||
"//tensorflow/python:client_testlib",
|
"//tensorflow/python:client_testlib",
|
||||||
"//tensorflow/python:framework_for_generated_wrappers",
|
"//tensorflow/python:framework_for_generated_wrappers",
|
||||||
],
|
],
|
||||||
|
xla_enable_strict_auto_jit = True,
|
||||||
)
|
)
|
||||||
|
|
||||||
cuda_py_test(
|
cuda_py_test(
|
||||||
@ -1311,6 +1328,7 @@ cuda_py_test(
|
|||||||
"noguitar",
|
"noguitar",
|
||||||
"notap",
|
"notap",
|
||||||
],
|
],
|
||||||
|
xla_enable_strict_auto_jit = True,
|
||||||
)
|
)
|
||||||
|
|
||||||
cuda_py_test(
|
cuda_py_test(
|
||||||
@ -1326,6 +1344,7 @@ cuda_py_test(
|
|||||||
"//tensorflow/python:variables",
|
"//tensorflow/python:variables",
|
||||||
],
|
],
|
||||||
tags = ["notsan"],
|
tags = ["notsan"],
|
||||||
|
xla_enable_strict_auto_jit = True,
|
||||||
)
|
)
|
||||||
|
|
||||||
cuda_py_test(
|
cuda_py_test(
|
||||||
@ -1342,6 +1361,7 @@ cuda_py_test(
|
|||||||
],
|
],
|
||||||
shard_count = 2,
|
shard_count = 2,
|
||||||
tags = ["no_windows_gpu"],
|
tags = ["no_windows_gpu"],
|
||||||
|
xla_enable_strict_auto_jit = True,
|
||||||
)
|
)
|
||||||
|
|
||||||
tf_py_test(
|
tf_py_test(
|
||||||
@ -1374,6 +1394,7 @@ cuda_py_test(
|
|||||||
"//tensorflow/python:framework_for_generated_wrappers",
|
"//tensorflow/python:framework_for_generated_wrappers",
|
||||||
"//tensorflow/python:math_ops",
|
"//tensorflow/python:math_ops",
|
||||||
],
|
],
|
||||||
|
xla_enable_strict_auto_jit = True,
|
||||||
)
|
)
|
||||||
|
|
||||||
cuda_py_test(
|
cuda_py_test(
|
||||||
@ -1385,6 +1406,7 @@ cuda_py_test(
|
|||||||
"//tensorflow/python:client_testlib",
|
"//tensorflow/python:client_testlib",
|
||||||
"//tensorflow/python:math_ops",
|
"//tensorflow/python:math_ops",
|
||||||
],
|
],
|
||||||
|
xla_enable_strict_auto_jit = True,
|
||||||
)
|
)
|
||||||
|
|
||||||
cuda_py_test(
|
cuda_py_test(
|
||||||
@ -1413,6 +1435,7 @@ cuda_py_test(
|
|||||||
"noasan", # times out
|
"noasan", # times out
|
||||||
"optonly", # times out
|
"optonly", # times out
|
||||||
],
|
],
|
||||||
|
xla_enable_strict_auto_jit = True,
|
||||||
)
|
)
|
||||||
|
|
||||||
cuda_py_test(
|
cuda_py_test(
|
||||||
@ -1425,6 +1448,7 @@ cuda_py_test(
|
|||||||
"//tensorflow/python:client",
|
"//tensorflow/python:client",
|
||||||
"//tensorflow/python:client_testlib",
|
"//tensorflow/python:client_testlib",
|
||||||
],
|
],
|
||||||
|
xla_enable_strict_auto_jit = True,
|
||||||
)
|
)
|
||||||
|
|
||||||
cuda_py_test(
|
cuda_py_test(
|
||||||
@ -1441,6 +1465,7 @@ cuda_py_test(
|
|||||||
"//tensorflow/python:math_ops",
|
"//tensorflow/python:math_ops",
|
||||||
],
|
],
|
||||||
shard_count = 10,
|
shard_count = 10,
|
||||||
|
xla_enable_strict_auto_jit = True,
|
||||||
)
|
)
|
||||||
|
|
||||||
cuda_py_test(
|
cuda_py_test(
|
||||||
@ -1455,6 +1480,7 @@ cuda_py_test(
|
|||||||
"//tensorflow/python:math_ops",
|
"//tensorflow/python:math_ops",
|
||||||
],
|
],
|
||||||
shard_count = 20,
|
shard_count = 20,
|
||||||
|
xla_enable_strict_auto_jit = True,
|
||||||
)
|
)
|
||||||
|
|
||||||
cuda_py_test(
|
cuda_py_test(
|
||||||
@ -1468,6 +1494,7 @@ cuda_py_test(
|
|||||||
"//tensorflow/python:client_testlib",
|
"//tensorflow/python:client_testlib",
|
||||||
"//tensorflow/python:framework_for_generated_wrappers",
|
"//tensorflow/python:framework_for_generated_wrappers",
|
||||||
],
|
],
|
||||||
|
xla_enable_strict_auto_jit = True,
|
||||||
)
|
)
|
||||||
|
|
||||||
cuda_py_test(
|
cuda_py_test(
|
||||||
@ -1482,6 +1509,7 @@ cuda_py_test(
|
|||||||
"//tensorflow/python:math_ops",
|
"//tensorflow/python:math_ops",
|
||||||
"//tensorflow/python:platform",
|
"//tensorflow/python:platform",
|
||||||
],
|
],
|
||||||
|
xla_enable_strict_auto_jit = True,
|
||||||
)
|
)
|
||||||
|
|
||||||
cuda_py_test(
|
cuda_py_test(
|
||||||
@ -1497,6 +1525,7 @@ cuda_py_test(
|
|||||||
"//tensorflow/python:nn_grad",
|
"//tensorflow/python:nn_grad",
|
||||||
"//tensorflow/python:nn_ops",
|
"//tensorflow/python:nn_ops",
|
||||||
],
|
],
|
||||||
|
xla_enable_strict_auto_jit = True,
|
||||||
)
|
)
|
||||||
|
|
||||||
cuda_py_test(
|
cuda_py_test(
|
||||||
@ -1509,6 +1538,7 @@ cuda_py_test(
|
|||||||
"//tensorflow/python:client_testlib",
|
"//tensorflow/python:client_testlib",
|
||||||
"//tensorflow/python:framework_for_generated_wrappers",
|
"//tensorflow/python:framework_for_generated_wrappers",
|
||||||
],
|
],
|
||||||
|
xla_enable_strict_auto_jit = True,
|
||||||
)
|
)
|
||||||
|
|
||||||
cuda_py_test(
|
cuda_py_test(
|
||||||
@ -1527,6 +1557,7 @@ cuda_py_test(
|
|||||||
"//tensorflow/python:framework",
|
"//tensorflow/python:framework",
|
||||||
"//tensorflow/python:framework_for_generated_wrappers",
|
"//tensorflow/python:framework_for_generated_wrappers",
|
||||||
],
|
],
|
||||||
|
xla_enable_strict_auto_jit = True,
|
||||||
)
|
)
|
||||||
|
|
||||||
cuda_py_test(
|
cuda_py_test(
|
||||||
@ -1542,6 +1573,7 @@ cuda_py_test(
|
|||||||
"//tensorflow/python:math_ops",
|
"//tensorflow/python:math_ops",
|
||||||
"//tensorflow/python:util",
|
"//tensorflow/python:util",
|
||||||
],
|
],
|
||||||
|
xla_enable_strict_auto_jit = True,
|
||||||
)
|
)
|
||||||
|
|
||||||
cuda_py_test(
|
cuda_py_test(
|
||||||
@ -1560,6 +1592,7 @@ cuda_py_test(
|
|||||||
"//tensorflow/python:math_ops",
|
"//tensorflow/python:math_ops",
|
||||||
"//tensorflow/python:util",
|
"//tensorflow/python:util",
|
||||||
],
|
],
|
||||||
|
xla_enable_strict_auto_jit = True,
|
||||||
)
|
)
|
||||||
|
|
||||||
cuda_py_test(
|
cuda_py_test(
|
||||||
@ -1600,6 +1633,7 @@ cuda_py_test(
|
|||||||
],
|
],
|
||||||
shard_count = 16,
|
shard_count = 16,
|
||||||
tags = ["no_gpu"], # TODO(b/117928656)
|
tags = ["no_gpu"], # TODO(b/117928656)
|
||||||
|
xla_enable_strict_auto_jit = True,
|
||||||
)
|
)
|
||||||
|
|
||||||
tf_py_test(
|
tf_py_test(
|
||||||
@ -1640,6 +1674,7 @@ cuda_py_test(
|
|||||||
"//tensorflow/python:framework_for_generated_wrappers",
|
"//tensorflow/python:framework_for_generated_wrappers",
|
||||||
"//tensorflow/python:nn_ops",
|
"//tensorflow/python:nn_ops",
|
||||||
],
|
],
|
||||||
|
xla_enable_strict_auto_jit = True,
|
||||||
)
|
)
|
||||||
|
|
||||||
cuda_py_test(
|
cuda_py_test(
|
||||||
@ -1654,6 +1689,7 @@ cuda_py_test(
|
|||||||
"//tensorflow/python:nn_grad",
|
"//tensorflow/python:nn_grad",
|
||||||
"//tensorflow/python:nn_ops",
|
"//tensorflow/python:nn_ops",
|
||||||
],
|
],
|
||||||
|
xla_enable_strict_auto_jit = True,
|
||||||
)
|
)
|
||||||
|
|
||||||
cuda_py_test(
|
cuda_py_test(
|
||||||
@ -1668,6 +1704,7 @@ cuda_py_test(
|
|||||||
"//tensorflow/python:nn_grad",
|
"//tensorflow/python:nn_grad",
|
||||||
"//tensorflow/python:nn_ops",
|
"//tensorflow/python:nn_ops",
|
||||||
],
|
],
|
||||||
|
xla_enable_strict_auto_jit = True,
|
||||||
)
|
)
|
||||||
|
|
||||||
cuda_py_test(
|
cuda_py_test(
|
||||||
@ -1679,6 +1716,7 @@ cuda_py_test(
|
|||||||
"//tensorflow/python:client_testlib",
|
"//tensorflow/python:client_testlib",
|
||||||
"//tensorflow/python:math_ops",
|
"//tensorflow/python:math_ops",
|
||||||
],
|
],
|
||||||
|
xla_enable_strict_auto_jit = True,
|
||||||
)
|
)
|
||||||
|
|
||||||
cuda_py_test(
|
cuda_py_test(
|
||||||
@ -1692,6 +1730,7 @@ cuda_py_test(
|
|||||||
"//tensorflow/python:framework_for_generated_wrappers",
|
"//tensorflow/python:framework_for_generated_wrappers",
|
||||||
"//tensorflow/python:platform",
|
"//tensorflow/python:platform",
|
||||||
],
|
],
|
||||||
|
xla_enable_strict_auto_jit = True,
|
||||||
)
|
)
|
||||||
|
|
||||||
cuda_py_test(
|
cuda_py_test(
|
||||||
@ -1707,6 +1746,7 @@ cuda_py_test(
|
|||||||
"//tensorflow/python:state_ops",
|
"//tensorflow/python:state_ops",
|
||||||
"//tensorflow/python:variables",
|
"//tensorflow/python:variables",
|
||||||
],
|
],
|
||||||
|
xla_enable_strict_auto_jit = True,
|
||||||
)
|
)
|
||||||
|
|
||||||
cuda_py_test(
|
cuda_py_test(
|
||||||
@ -1721,6 +1761,7 @@ cuda_py_test(
|
|||||||
"//tensorflow/python:math_ops",
|
"//tensorflow/python:math_ops",
|
||||||
],
|
],
|
||||||
tags = ["no_windows_gpu"],
|
tags = ["no_windows_gpu"],
|
||||||
|
xla_enable_strict_auto_jit = True,
|
||||||
)
|
)
|
||||||
|
|
||||||
cuda_py_test(
|
cuda_py_test(
|
||||||
@ -1733,6 +1774,7 @@ cuda_py_test(
|
|||||||
"//tensorflow/python:framework_for_generated_wrappers",
|
"//tensorflow/python:framework_for_generated_wrappers",
|
||||||
],
|
],
|
||||||
tags = ["manual"],
|
tags = ["manual"],
|
||||||
|
xla_enable_strict_auto_jit = True,
|
||||||
)
|
)
|
||||||
|
|
||||||
cuda_py_test(
|
cuda_py_test(
|
||||||
@ -1748,6 +1790,7 @@ cuda_py_test(
|
|||||||
"//tensorflow/python:framework_for_generated_wrappers",
|
"//tensorflow/python:framework_for_generated_wrappers",
|
||||||
"//tensorflow/python:gradients",
|
"//tensorflow/python:gradients",
|
||||||
],
|
],
|
||||||
|
xla_enable_strict_auto_jit = True,
|
||||||
)
|
)
|
||||||
|
|
||||||
cuda_py_test(
|
cuda_py_test(
|
||||||
@ -1762,6 +1805,7 @@ cuda_py_test(
|
|||||||
"//tensorflow/python:framework_for_generated_wrappers",
|
"//tensorflow/python:framework_for_generated_wrappers",
|
||||||
"//tensorflow/python:gradients",
|
"//tensorflow/python:gradients",
|
||||||
],
|
],
|
||||||
|
xla_enable_strict_auto_jit = True,
|
||||||
)
|
)
|
||||||
|
|
||||||
cuda_py_test(
|
cuda_py_test(
|
||||||
@ -1774,6 +1818,7 @@ cuda_py_test(
|
|||||||
"//tensorflow/python:client_testlib",
|
"//tensorflow/python:client_testlib",
|
||||||
"//tensorflow/python:framework_for_generated_wrappers",
|
"//tensorflow/python:framework_for_generated_wrappers",
|
||||||
],
|
],
|
||||||
|
xla_enable_strict_auto_jit = True,
|
||||||
)
|
)
|
||||||
|
|
||||||
cuda_py_test(
|
cuda_py_test(
|
||||||
@ -1786,6 +1831,7 @@ cuda_py_test(
|
|||||||
"//tensorflow/python:client_testlib",
|
"//tensorflow/python:client_testlib",
|
||||||
"//tensorflow/python:framework_for_generated_wrappers",
|
"//tensorflow/python:framework_for_generated_wrappers",
|
||||||
],
|
],
|
||||||
|
xla_enable_strict_auto_jit = True,
|
||||||
)
|
)
|
||||||
|
|
||||||
cuda_py_test(
|
cuda_py_test(
|
||||||
@ -1812,6 +1858,7 @@ cuda_py_test(
|
|||||||
grpc_enabled = True,
|
grpc_enabled = True,
|
||||||
shard_count = 2,
|
shard_count = 2,
|
||||||
tags = ["no_windows"],
|
tags = ["no_windows"],
|
||||||
|
xla_enable_strict_auto_jit = True,
|
||||||
)
|
)
|
||||||
|
|
||||||
cuda_py_test(
|
cuda_py_test(
|
||||||
@ -1827,6 +1874,7 @@ cuda_py_test(
|
|||||||
"//tensorflow/python:gradients",
|
"//tensorflow/python:gradients",
|
||||||
"//tensorflow/python:variables",
|
"//tensorflow/python:variables",
|
||||||
],
|
],
|
||||||
|
xla_enable_strict_auto_jit = True,
|
||||||
)
|
)
|
||||||
|
|
||||||
cuda_py_test(
|
cuda_py_test(
|
||||||
@ -1841,6 +1889,7 @@ cuda_py_test(
|
|||||||
"//tensorflow/python:framework_for_generated_wrappers",
|
"//tensorflow/python:framework_for_generated_wrappers",
|
||||||
"//tensorflow/python:gradients",
|
"//tensorflow/python:gradients",
|
||||||
],
|
],
|
||||||
|
xla_enable_strict_auto_jit = True,
|
||||||
)
|
)
|
||||||
|
|
||||||
cuda_py_test(
|
cuda_py_test(
|
||||||
@ -1854,6 +1903,7 @@ cuda_py_test(
|
|||||||
"//tensorflow/python:gradients",
|
"//tensorflow/python:gradients",
|
||||||
"//tensorflow/python:math_ops",
|
"//tensorflow/python:math_ops",
|
||||||
],
|
],
|
||||||
|
xla_enable_strict_auto_jit = True,
|
||||||
)
|
)
|
||||||
|
|
||||||
cuda_py_test(
|
cuda_py_test(
|
||||||
@ -1881,6 +1931,7 @@ cuda_py_test(
|
|||||||
"noasan",
|
"noasan",
|
||||||
"notap",
|
"notap",
|
||||||
],
|
],
|
||||||
|
xla_enable_strict_auto_jit = True,
|
||||||
)
|
)
|
||||||
|
|
||||||
cuda_py_test(
|
cuda_py_test(
|
||||||
@ -1898,6 +1949,7 @@ cuda_py_test(
|
|||||||
"//tensorflow/python/ops/linalg",
|
"//tensorflow/python/ops/linalg",
|
||||||
],
|
],
|
||||||
tags = ["no_windows_gpu"],
|
tags = ["no_windows_gpu"],
|
||||||
|
xla_enable_strict_auto_jit = True,
|
||||||
)
|
)
|
||||||
|
|
||||||
cuda_py_test(
|
cuda_py_test(
|
||||||
@ -1913,6 +1965,7 @@ cuda_py_test(
|
|||||||
"//tensorflow/python:nn",
|
"//tensorflow/python:nn",
|
||||||
"//tensorflow/python:nn_grad",
|
"//tensorflow/python:nn_grad",
|
||||||
],
|
],
|
||||||
|
xla_enable_strict_auto_jit = True,
|
||||||
)
|
)
|
||||||
|
|
||||||
cuda_py_test(
|
cuda_py_test(
|
||||||
@ -1929,6 +1982,7 @@ cuda_py_test(
|
|||||||
"//tensorflow/python:platform",
|
"//tensorflow/python:platform",
|
||||||
"//tensorflow/python/ops/linalg",
|
"//tensorflow/python/ops/linalg",
|
||||||
],
|
],
|
||||||
|
xla_enable_strict_auto_jit = True,
|
||||||
)
|
)
|
||||||
|
|
||||||
cuda_py_test(
|
cuda_py_test(
|
||||||
@ -1942,6 +1996,7 @@ cuda_py_test(
|
|||||||
"//tensorflow/python:framework_for_generated_wrappers",
|
"//tensorflow/python:framework_for_generated_wrappers",
|
||||||
],
|
],
|
||||||
tags = ["no_windows_gpu"],
|
tags = ["no_windows_gpu"],
|
||||||
|
xla_enable_strict_auto_jit = True,
|
||||||
)
|
)
|
||||||
|
|
||||||
cuda_py_test(
|
cuda_py_test(
|
||||||
@ -1959,6 +2014,7 @@ cuda_py_test(
|
|||||||
"//tensorflow/python:variables",
|
"//tensorflow/python:variables",
|
||||||
],
|
],
|
||||||
shard_count = 20,
|
shard_count = 20,
|
||||||
|
xla_enable_strict_auto_jit = True,
|
||||||
)
|
)
|
||||||
|
|
||||||
cuda_py_test(
|
cuda_py_test(
|
||||||
@ -1972,6 +2028,7 @@ cuda_py_test(
|
|||||||
"//tensorflow/python:nn_grad",
|
"//tensorflow/python:nn_grad",
|
||||||
"//tensorflow/python:nn_ops",
|
"//tensorflow/python:nn_ops",
|
||||||
],
|
],
|
||||||
|
xla_enable_strict_auto_jit = True,
|
||||||
)
|
)
|
||||||
|
|
||||||
cuda_py_test(
|
cuda_py_test(
|
||||||
@ -1987,6 +2044,7 @@ cuda_py_test(
|
|||||||
"//tensorflow/python:math_ops",
|
"//tensorflow/python:math_ops",
|
||||||
"//tensorflow/python:numerics",
|
"//tensorflow/python:numerics",
|
||||||
],
|
],
|
||||||
|
xla_enable_strict_auto_jit = True,
|
||||||
)
|
)
|
||||||
|
|
||||||
cuda_py_test(
|
cuda_py_test(
|
||||||
@ -2000,6 +2058,7 @@ cuda_py_test(
|
|||||||
"//tensorflow/python:framework_for_generated_wrappers",
|
"//tensorflow/python:framework_for_generated_wrappers",
|
||||||
],
|
],
|
||||||
tags = ["no_windows_gpu"],
|
tags = ["no_windows_gpu"],
|
||||||
|
xla_enable_strict_auto_jit = True,
|
||||||
)
|
)
|
||||||
|
|
||||||
cuda_py_test(
|
cuda_py_test(
|
||||||
@ -2014,6 +2073,7 @@ cuda_py_test(
|
|||||||
"//tensorflow/python:framework_for_generated_wrappers",
|
"//tensorflow/python:framework_for_generated_wrappers",
|
||||||
"//tensorflow/python:variables",
|
"//tensorflow/python:variables",
|
||||||
],
|
],
|
||||||
|
xla_enable_strict_auto_jit = True,
|
||||||
)
|
)
|
||||||
|
|
||||||
cuda_py_test(
|
cuda_py_test(
|
||||||
@ -2026,6 +2086,7 @@ cuda_py_test(
|
|||||||
"//tensorflow/python:client_testlib",
|
"//tensorflow/python:client_testlib",
|
||||||
"//tensorflow/python:framework_for_generated_wrappers",
|
"//tensorflow/python:framework_for_generated_wrappers",
|
||||||
],
|
],
|
||||||
|
xla_enable_strict_auto_jit = True,
|
||||||
)
|
)
|
||||||
|
|
||||||
cuda_py_test(
|
cuda_py_test(
|
||||||
@ -2040,6 +2101,7 @@ cuda_py_test(
|
|||||||
"//tensorflow/python:errors",
|
"//tensorflow/python:errors",
|
||||||
"//tensorflow/python:framework_for_generated_wrappers",
|
"//tensorflow/python:framework_for_generated_wrappers",
|
||||||
],
|
],
|
||||||
|
xla_enable_strict_auto_jit = True,
|
||||||
)
|
)
|
||||||
|
|
||||||
cuda_py_test(
|
cuda_py_test(
|
||||||
@ -2058,6 +2120,7 @@ cuda_py_test(
|
|||||||
"//tensorflow/python/eager:function",
|
"//tensorflow/python/eager:function",
|
||||||
],
|
],
|
||||||
tags = ["no_windows"],
|
tags = ["no_windows"],
|
||||||
|
xla_enable_strict_auto_jit = True,
|
||||||
)
|
)
|
||||||
|
|
||||||
cuda_py_test(
|
cuda_py_test(
|
||||||
@ -2071,6 +2134,7 @@ cuda_py_test(
|
|||||||
"//tensorflow/python:framework_for_generated_wrappers",
|
"//tensorflow/python:framework_for_generated_wrappers",
|
||||||
"//tensorflow/python:string_ops",
|
"//tensorflow/python:string_ops",
|
||||||
],
|
],
|
||||||
|
xla_enable_strict_auto_jit = True,
|
||||||
)
|
)
|
||||||
|
|
||||||
cuda_py_test(
|
cuda_py_test(
|
||||||
@ -2086,6 +2150,7 @@ cuda_py_test(
|
|||||||
],
|
],
|
||||||
shard_count = 4,
|
shard_count = 4,
|
||||||
tags = ["no_windows_gpu"],
|
tags = ["no_windows_gpu"],
|
||||||
|
xla_enable_strict_auto_jit = True,
|
||||||
)
|
)
|
||||||
|
|
||||||
cuda_py_test(
|
cuda_py_test(
|
||||||
@ -2106,6 +2171,7 @@ cuda_py_test(
|
|||||||
"noguitar",
|
"noguitar",
|
||||||
"notap",
|
"notap",
|
||||||
],
|
],
|
||||||
|
xla_enable_strict_auto_jit = True,
|
||||||
)
|
)
|
||||||
|
|
||||||
cuda_py_test(
|
cuda_py_test(
|
||||||
@ -2124,6 +2190,7 @@ cuda_py_test(
|
|||||||
"//tensorflow/python/eager:backprop",
|
"//tensorflow/python/eager:backprop",
|
||||||
"//tensorflow/python:tf2",
|
"//tensorflow/python:tf2",
|
||||||
],
|
],
|
||||||
|
xla_enable_strict_auto_jit = True,
|
||||||
)
|
)
|
||||||
|
|
||||||
cuda_py_test(
|
cuda_py_test(
|
||||||
@ -2136,6 +2203,7 @@ cuda_py_test(
|
|||||||
"//tensorflow/python:client_testlib",
|
"//tensorflow/python:client_testlib",
|
||||||
"//tensorflow/python:framework_for_generated_wrappers",
|
"//tensorflow/python:framework_for_generated_wrappers",
|
||||||
],
|
],
|
||||||
|
xla_enable_strict_auto_jit = True,
|
||||||
)
|
)
|
||||||
|
|
||||||
cuda_py_test(
|
cuda_py_test(
|
||||||
@ -2148,6 +2216,7 @@ cuda_py_test(
|
|||||||
"//tensorflow/python:client_testlib",
|
"//tensorflow/python:client_testlib",
|
||||||
"//tensorflow/python:framework_for_generated_wrappers",
|
"//tensorflow/python:framework_for_generated_wrappers",
|
||||||
],
|
],
|
||||||
|
xla_enable_strict_auto_jit = True,
|
||||||
)
|
)
|
||||||
|
|
||||||
cuda_py_test(
|
cuda_py_test(
|
||||||
@ -2160,6 +2229,7 @@ cuda_py_test(
|
|||||||
"//tensorflow/python:client_testlib",
|
"//tensorflow/python:client_testlib",
|
||||||
"//tensorflow/python:framework_for_generated_wrappers",
|
"//tensorflow/python:framework_for_generated_wrappers",
|
||||||
],
|
],
|
||||||
|
xla_enable_strict_auto_jit = True,
|
||||||
)
|
)
|
||||||
|
|
||||||
cuda_py_test(
|
cuda_py_test(
|
||||||
@ -2178,6 +2248,7 @@ cuda_py_test(
|
|||||||
"//tensorflow/python:random_ops",
|
"//tensorflow/python:random_ops",
|
||||||
"//tensorflow/python:sparse_ops",
|
"//tensorflow/python:sparse_ops",
|
||||||
],
|
],
|
||||||
|
xla_enable_strict_auto_jit = True,
|
||||||
)
|
)
|
||||||
|
|
||||||
cuda_py_test(
|
cuda_py_test(
|
||||||
@ -2191,6 +2262,7 @@ cuda_py_test(
|
|||||||
"//tensorflow/python:framework_for_generated_wrappers",
|
"//tensorflow/python:framework_for_generated_wrappers",
|
||||||
"//tensorflow/python:math_ops",
|
"//tensorflow/python:math_ops",
|
||||||
],
|
],
|
||||||
|
xla_enable_strict_auto_jit = True,
|
||||||
)
|
)
|
||||||
|
|
||||||
cuda_py_test(
|
cuda_py_test(
|
||||||
@ -2203,6 +2275,7 @@ cuda_py_test(
|
|||||||
"//tensorflow/python:math_ops",
|
"//tensorflow/python:math_ops",
|
||||||
"//tensorflow/python:session_ops",
|
"//tensorflow/python:session_ops",
|
||||||
],
|
],
|
||||||
|
xla_enable_strict_auto_jit = True,
|
||||||
)
|
)
|
||||||
|
|
||||||
cuda_py_test(
|
cuda_py_test(
|
||||||
@ -2219,6 +2292,7 @@ cuda_py_test(
|
|||||||
"//tensorflow/python:framework_for_generated_wrappers",
|
"//tensorflow/python:framework_for_generated_wrappers",
|
||||||
"//tensorflow/python:gradients",
|
"//tensorflow/python:gradients",
|
||||||
],
|
],
|
||||||
|
xla_enable_strict_auto_jit = True,
|
||||||
)
|
)
|
||||||
|
|
||||||
cuda_py_test(
|
cuda_py_test(
|
||||||
@ -2233,6 +2307,7 @@ cuda_py_test(
|
|||||||
"//tensorflow/python:framework_for_generated_wrappers",
|
"//tensorflow/python:framework_for_generated_wrappers",
|
||||||
"//tensorflow/python:nn_ops",
|
"//tensorflow/python:nn_ops",
|
||||||
],
|
],
|
||||||
|
xla_enable_strict_auto_jit = True,
|
||||||
)
|
)
|
||||||
|
|
||||||
cuda_py_test(
|
cuda_py_test(
|
||||||
@ -2246,6 +2321,7 @@ cuda_py_test(
|
|||||||
"//tensorflow/python:nn_grad",
|
"//tensorflow/python:nn_grad",
|
||||||
"//tensorflow/python:nn_ops",
|
"//tensorflow/python:nn_ops",
|
||||||
],
|
],
|
||||||
|
xla_enable_strict_auto_jit = True,
|
||||||
)
|
)
|
||||||
|
|
||||||
cuda_py_test(
|
cuda_py_test(
|
||||||
@ -2259,6 +2335,7 @@ cuda_py_test(
|
|||||||
"//tensorflow/python:nn_grad",
|
"//tensorflow/python:nn_grad",
|
||||||
"//tensorflow/python:nn_ops",
|
"//tensorflow/python:nn_ops",
|
||||||
],
|
],
|
||||||
|
xla_enable_strict_auto_jit = True,
|
||||||
)
|
)
|
||||||
|
|
||||||
cuda_py_test(
|
cuda_py_test(
|
||||||
@ -2274,6 +2351,7 @@ cuda_py_test(
|
|||||||
"//tensorflow/python:framework_for_generated_wrappers",
|
"//tensorflow/python:framework_for_generated_wrappers",
|
||||||
"//tensorflow/python:math_ops",
|
"//tensorflow/python:math_ops",
|
||||||
],
|
],
|
||||||
|
xla_enable_strict_auto_jit = True,
|
||||||
)
|
)
|
||||||
|
|
||||||
cuda_py_test(
|
cuda_py_test(
|
||||||
@ -2288,6 +2366,7 @@ cuda_py_test(
|
|||||||
"//tensorflow/python:math_ops",
|
"//tensorflow/python:math_ops",
|
||||||
],
|
],
|
||||||
tags = ["no_windows_gpu"],
|
tags = ["no_windows_gpu"],
|
||||||
|
xla_enable_strict_auto_jit = True,
|
||||||
)
|
)
|
||||||
|
|
||||||
tf_py_test(
|
tf_py_test(
|
||||||
@ -2332,6 +2411,7 @@ cuda_py_test(
|
|||||||
"//tensorflow/python:sparse_grad",
|
"//tensorflow/python:sparse_grad",
|
||||||
"//tensorflow/python:sparse_ops",
|
"//tensorflow/python:sparse_ops",
|
||||||
],
|
],
|
||||||
|
xla_enable_strict_auto_jit = True,
|
||||||
)
|
)
|
||||||
|
|
||||||
cuda_py_test(
|
cuda_py_test(
|
||||||
@ -2356,6 +2436,7 @@ cuda_py_test(
|
|||||||
"//tensorflow/python:random_ops",
|
"//tensorflow/python:random_ops",
|
||||||
"//tensorflow/python:variables",
|
"//tensorflow/python:variables",
|
||||||
],
|
],
|
||||||
|
xla_enable_strict_auto_jit = True,
|
||||||
)
|
)
|
||||||
|
|
||||||
cuda_py_test(
|
cuda_py_test(
|
||||||
@ -2370,6 +2451,7 @@ cuda_py_test(
|
|||||||
"//tensorflow/python:gradients",
|
"//tensorflow/python:gradients",
|
||||||
"//tensorflow/python:math_ops",
|
"//tensorflow/python:math_ops",
|
||||||
],
|
],
|
||||||
|
xla_enable_strict_auto_jit = True,
|
||||||
)
|
)
|
||||||
|
|
||||||
cuda_py_test(
|
cuda_py_test(
|
||||||
@ -2385,6 +2467,7 @@ cuda_py_test(
|
|||||||
"//tensorflow/python:framework_for_generated_wrappers",
|
"//tensorflow/python:framework_for_generated_wrappers",
|
||||||
"//tensorflow/python:math_ops",
|
"//tensorflow/python:math_ops",
|
||||||
],
|
],
|
||||||
|
xla_enable_strict_auto_jit = True,
|
||||||
)
|
)
|
||||||
|
|
||||||
cuda_py_test(
|
cuda_py_test(
|
||||||
@ -2397,6 +2480,7 @@ cuda_py_test(
|
|||||||
"//tensorflow/python:framework_for_generated_wrappers",
|
"//tensorflow/python:framework_for_generated_wrappers",
|
||||||
"//tensorflow/python:string_ops",
|
"//tensorflow/python:string_ops",
|
||||||
],
|
],
|
||||||
|
xla_enable_strict_auto_jit = True,
|
||||||
)
|
)
|
||||||
|
|
||||||
cuda_py_test(
|
cuda_py_test(
|
||||||
@ -2409,6 +2493,7 @@ cuda_py_test(
|
|||||||
"//tensorflow/python:framework_for_generated_wrappers",
|
"//tensorflow/python:framework_for_generated_wrappers",
|
||||||
"//tensorflow/python:parsing_ops",
|
"//tensorflow/python:parsing_ops",
|
||||||
],
|
],
|
||||||
|
xla_enable_strict_auto_jit = True,
|
||||||
)
|
)
|
||||||
|
|
||||||
cuda_py_test(
|
cuda_py_test(
|
||||||
@ -2422,6 +2507,7 @@ cuda_py_test(
|
|||||||
"//tensorflow/python:framework_for_generated_wrappers",
|
"//tensorflow/python:framework_for_generated_wrappers",
|
||||||
"//tensorflow/python:summary",
|
"//tensorflow/python:summary",
|
||||||
],
|
],
|
||||||
|
xla_enable_strict_auto_jit = True,
|
||||||
)
|
)
|
||||||
|
|
||||||
cuda_py_test(
|
cuda_py_test(
|
||||||
@ -2437,6 +2523,7 @@ cuda_py_test(
|
|||||||
"//tensorflow/python:nn_grad",
|
"//tensorflow/python:nn_grad",
|
||||||
"//tensorflow/python:summary",
|
"//tensorflow/python:summary",
|
||||||
],
|
],
|
||||||
|
xla_enable_strict_auto_jit = True,
|
||||||
)
|
)
|
||||||
|
|
||||||
cuda_py_test(
|
cuda_py_test(
|
||||||
@ -2467,6 +2554,7 @@ cuda_py_test(
|
|||||||
"//tensorflow/python/eager:context",
|
"//tensorflow/python/eager:context",
|
||||||
],
|
],
|
||||||
flaky = 1, # create_local_cluster sometimes times out.
|
flaky = 1, # create_local_cluster sometimes times out.
|
||||||
|
xla_enable_strict_auto_jit = True,
|
||||||
)
|
)
|
||||||
|
|
||||||
cuda_py_test(
|
cuda_py_test(
|
||||||
@ -2479,6 +2567,7 @@ cuda_py_test(
|
|||||||
"//tensorflow/python:math_ops",
|
"//tensorflow/python:math_ops",
|
||||||
],
|
],
|
||||||
tags = ["no_windows_gpu"],
|
tags = ["no_windows_gpu"],
|
||||||
|
xla_enable_strict_auto_jit = True,
|
||||||
)
|
)
|
||||||
|
|
||||||
cuda_py_test(
|
cuda_py_test(
|
||||||
@ -2497,6 +2586,7 @@ cuda_py_test(
|
|||||||
"no_oss",
|
"no_oss",
|
||||||
"optonly", # times out
|
"optonly", # times out
|
||||||
],
|
],
|
||||||
|
xla_enable_strict_auto_jit = True,
|
||||||
)
|
)
|
||||||
|
|
||||||
cuda_py_test(
|
cuda_py_test(
|
||||||
@ -2509,6 +2599,7 @@ cuda_py_test(
|
|||||||
"//tensorflow/python:client_testlib",
|
"//tensorflow/python:client_testlib",
|
||||||
"//tensorflow/python:framework_for_generated_wrappers",
|
"//tensorflow/python:framework_for_generated_wrappers",
|
||||||
],
|
],
|
||||||
|
xla_enable_strict_auto_jit = True,
|
||||||
)
|
)
|
||||||
|
|
||||||
cuda_py_test(
|
cuda_py_test(
|
||||||
@ -2526,6 +2617,7 @@ cuda_py_test(
|
|||||||
"//tensorflow/python:state_ops_gen",
|
"//tensorflow/python:state_ops_gen",
|
||||||
"//tensorflow/python:variables",
|
"//tensorflow/python:variables",
|
||||||
],
|
],
|
||||||
|
xla_enable_strict_auto_jit = True,
|
||||||
)
|
)
|
||||||
|
|
||||||
cuda_py_test(
|
cuda_py_test(
|
||||||
@ -2542,6 +2634,7 @@ cuda_py_test(
|
|||||||
"//tensorflow/python:nn_ops",
|
"//tensorflow/python:nn_ops",
|
||||||
"//tensorflow/python:nn_ops_gen",
|
"//tensorflow/python:nn_ops_gen",
|
||||||
],
|
],
|
||||||
|
xla_enable_strict_auto_jit = True,
|
||||||
)
|
)
|
||||||
|
|
||||||
cuda_py_test(
|
cuda_py_test(
|
||||||
@ -2553,6 +2646,7 @@ cuda_py_test(
|
|||||||
"//tensorflow/python:errors",
|
"//tensorflow/python:errors",
|
||||||
"//tensorflow/python:framework_for_generated_wrappers",
|
"//tensorflow/python:framework_for_generated_wrappers",
|
||||||
],
|
],
|
||||||
|
xla_enable_strict_auto_jit = True,
|
||||||
)
|
)
|
||||||
|
|
||||||
cuda_py_test(
|
cuda_py_test(
|
||||||
@ -2572,6 +2666,7 @@ cuda_py_test(
|
|||||||
tags = [
|
tags = [
|
||||||
"no_gpu", # Flaky: b/80127739
|
"no_gpu", # Flaky: b/80127739
|
||||||
],
|
],
|
||||||
|
xla_enable_strict_auto_jit = True,
|
||||||
)
|
)
|
||||||
|
|
||||||
cuda_py_test(
|
cuda_py_test(
|
||||||
@ -2587,6 +2682,7 @@ cuda_py_test(
|
|||||||
"//tensorflow/python:nn_ops",
|
"//tensorflow/python:nn_ops",
|
||||||
],
|
],
|
||||||
tags = ["manual"],
|
tags = ["manual"],
|
||||||
|
xla_enable_strict_auto_jit = True,
|
||||||
)
|
)
|
||||||
|
|
||||||
cuda_py_test(
|
cuda_py_test(
|
||||||
@ -2600,6 +2696,7 @@ cuda_py_test(
|
|||||||
"//tensorflow/python:nn_grad",
|
"//tensorflow/python:nn_grad",
|
||||||
"//tensorflow/python:nn_ops",
|
"//tensorflow/python:nn_ops",
|
||||||
],
|
],
|
||||||
|
xla_enable_strict_auto_jit = True,
|
||||||
)
|
)
|
||||||
|
|
||||||
cuda_py_test(
|
cuda_py_test(
|
||||||
@ -2618,6 +2715,7 @@ cuda_py_test(
|
|||||||
tags = [
|
tags = [
|
||||||
"optonly", # flaky timeouts unless optimized
|
"optonly", # flaky timeouts unless optimized
|
||||||
],
|
],
|
||||||
|
xla_enable_strict_auto_jit = True,
|
||||||
)
|
)
|
||||||
|
|
||||||
cuda_py_test(
|
cuda_py_test(
|
||||||
@ -2631,6 +2729,7 @@ cuda_py_test(
|
|||||||
"//tensorflow/python:nn_grad",
|
"//tensorflow/python:nn_grad",
|
||||||
"//tensorflow/python:nn_ops",
|
"//tensorflow/python:nn_ops",
|
||||||
],
|
],
|
||||||
|
xla_enable_strict_auto_jit = True,
|
||||||
)
|
)
|
||||||
|
|
||||||
cuda_py_test(
|
cuda_py_test(
|
||||||
@ -2659,6 +2758,7 @@ cuda_py_test(
|
|||||||
"no_oss",
|
"no_oss",
|
||||||
"optonly", # times out
|
"optonly", # times out
|
||||||
],
|
],
|
||||||
|
xla_enable_strict_auto_jit = True,
|
||||||
)
|
)
|
||||||
|
|
||||||
cuda_py_test(
|
cuda_py_test(
|
||||||
@ -2676,6 +2776,7 @@ cuda_py_test(
|
|||||||
],
|
],
|
||||||
# TODO(b/118842098): Re-enable this test in Kokoro.
|
# TODO(b/118842098): Re-enable this test in Kokoro.
|
||||||
tags = ["no_oss"],
|
tags = ["no_oss"],
|
||||||
|
xla_enable_strict_auto_jit = True,
|
||||||
)
|
)
|
||||||
|
|
||||||
tf_py_test(
|
tf_py_test(
|
||||||
@ -2704,6 +2805,7 @@ cuda_py_test(
|
|||||||
"//tensorflow/python:framework_for_generated_wrappers",
|
"//tensorflow/python:framework_for_generated_wrappers",
|
||||||
],
|
],
|
||||||
tags = ["manual"],
|
tags = ["manual"],
|
||||||
|
xla_enable_strict_auto_jit = True,
|
||||||
)
|
)
|
||||||
|
|
||||||
cuda_py_test(
|
cuda_py_test(
|
||||||
@ -2717,6 +2819,7 @@ cuda_py_test(
|
|||||||
"//tensorflow/python:nn_grad",
|
"//tensorflow/python:nn_grad",
|
||||||
"//tensorflow/python:nn_ops",
|
"//tensorflow/python:nn_ops",
|
||||||
],
|
],
|
||||||
|
xla_enable_strict_auto_jit = True,
|
||||||
)
|
)
|
||||||
|
|
||||||
cuda_py_test(
|
cuda_py_test(
|
||||||
@ -2735,6 +2838,7 @@ cuda_py_test(
|
|||||||
"//tensorflow/python:nn_ops_gen",
|
"//tensorflow/python:nn_ops_gen",
|
||||||
],
|
],
|
||||||
shard_count = 4,
|
shard_count = 4,
|
||||||
|
xla_enable_strict_auto_jit = True,
|
||||||
)
|
)
|
||||||
|
|
||||||
cuda_py_test(
|
cuda_py_test(
|
||||||
@ -2764,6 +2868,7 @@ cuda_py_test(
|
|||||||
"//tensorflow/python/eager:context",
|
"//tensorflow/python/eager:context",
|
||||||
],
|
],
|
||||||
shard_count = 10,
|
shard_count = 10,
|
||||||
|
xla_enable_strict_auto_jit = True,
|
||||||
)
|
)
|
||||||
|
|
||||||
cuda_py_test(
|
cuda_py_test(
|
||||||
@ -2779,6 +2884,7 @@ cuda_py_test(
|
|||||||
],
|
],
|
||||||
shard_count = 2,
|
shard_count = 2,
|
||||||
tags = ["optonly"],
|
tags = ["optonly"],
|
||||||
|
xla_enable_strict_auto_jit = True,
|
||||||
)
|
)
|
||||||
|
|
||||||
cuda_py_test(
|
cuda_py_test(
|
||||||
@ -2793,6 +2899,7 @@ cuda_py_test(
|
|||||||
"//tensorflow/python:framework_for_generated_wrappers",
|
"//tensorflow/python:framework_for_generated_wrappers",
|
||||||
"//tensorflow/python:gradients",
|
"//tensorflow/python:gradients",
|
||||||
],
|
],
|
||||||
|
xla_enable_strict_auto_jit = True,
|
||||||
)
|
)
|
||||||
|
|
||||||
cuda_py_test(
|
cuda_py_test(
|
||||||
@ -2809,6 +2916,7 @@ cuda_py_test(
|
|||||||
tags = [
|
tags = [
|
||||||
"no_oss", # Requires 4GB+ RAM
|
"no_oss", # Requires 4GB+ RAM
|
||||||
],
|
],
|
||||||
|
xla_enable_strict_auto_jit = True,
|
||||||
)
|
)
|
||||||
|
|
||||||
cuda_py_test(
|
cuda_py_test(
|
||||||
@ -2822,6 +2930,7 @@ cuda_py_test(
|
|||||||
"//tensorflow/python:math_ops",
|
"//tensorflow/python:math_ops",
|
||||||
],
|
],
|
||||||
tags = ["no_windows"],
|
tags = ["no_windows"],
|
||||||
|
xla_enable_strict_auto_jit = True,
|
||||||
)
|
)
|
||||||
|
|
||||||
cuda_py_test(
|
cuda_py_test(
|
||||||
@ -2845,6 +2954,7 @@ cuda_py_test(
|
|||||||
"noasan",
|
"noasan",
|
||||||
"optonly", # b/77589990
|
"optonly", # b/77589990
|
||||||
],
|
],
|
||||||
|
xla_enable_strict_auto_jit = True,
|
||||||
)
|
)
|
||||||
|
|
||||||
cuda_py_test(
|
cuda_py_test(
|
||||||
@ -2864,6 +2974,7 @@ cuda_py_test(
|
|||||||
"//tensorflow/python:platform",
|
"//tensorflow/python:platform",
|
||||||
"//tensorflow/python:sparse_ops",
|
"//tensorflow/python:sparse_ops",
|
||||||
],
|
],
|
||||||
|
xla_enable_strict_auto_jit = True,
|
||||||
)
|
)
|
||||||
|
|
||||||
# TODO(gpapan): Revisit the gradient of extract_image_patches_op to resolve
|
# TODO(gpapan): Revisit the gradient of extract_image_patches_op to resolve
|
||||||
@ -2880,6 +2991,7 @@ cuda_py_test(
|
|||||||
"//tensorflow/python:framework_for_generated_wrappers",
|
"//tensorflow/python:framework_for_generated_wrappers",
|
||||||
],
|
],
|
||||||
tags = ["notap"], # http://b/31080670
|
tags = ["notap"], # http://b/31080670
|
||||||
|
xla_enable_strict_auto_jit = True,
|
||||||
)
|
)
|
||||||
|
|
||||||
cuda_py_test(
|
cuda_py_test(
|
||||||
@ -2894,6 +3006,7 @@ cuda_py_test(
|
|||||||
"//tensorflow/python:util",
|
"//tensorflow/python:util",
|
||||||
"//tensorflow/python:data_flow_ops",
|
"//tensorflow/python:data_flow_ops",
|
||||||
],
|
],
|
||||||
|
xla_enable_strict_auto_jit = True,
|
||||||
)
|
)
|
||||||
|
|
||||||
cuda_py_test(
|
cuda_py_test(
|
||||||
@ -2908,6 +3021,7 @@ cuda_py_test(
|
|||||||
"//tensorflow/python:util",
|
"//tensorflow/python:util",
|
||||||
"//tensorflow/python:data_flow_ops",
|
"//tensorflow/python:data_flow_ops",
|
||||||
],
|
],
|
||||||
|
xla_enable_strict_auto_jit = True,
|
||||||
)
|
)
|
||||||
|
|
||||||
cuda_py_test(
|
cuda_py_test(
|
||||||
@ -2925,6 +3039,7 @@ cuda_py_test(
|
|||||||
"//tensorflow/python:math_ops",
|
"//tensorflow/python:math_ops",
|
||||||
"//tensorflow/python:variables",
|
"//tensorflow/python:variables",
|
||||||
],
|
],
|
||||||
|
xla_enable_strict_auto_jit = True,
|
||||||
)
|
)
|
||||||
|
|
||||||
cuda_py_test(
|
cuda_py_test(
|
||||||
@ -2940,6 +3055,7 @@ cuda_py_test(
|
|||||||
"nomsan",
|
"nomsan",
|
||||||
"notsan",
|
"notsan",
|
||||||
],
|
],
|
||||||
|
xla_enable_strict_auto_jit = True,
|
||||||
)
|
)
|
||||||
|
|
||||||
cuda_py_test(
|
cuda_py_test(
|
||||||
@ -2953,6 +3069,7 @@ cuda_py_test(
|
|||||||
"//tensorflow/python:nn_ops",
|
"//tensorflow/python:nn_ops",
|
||||||
],
|
],
|
||||||
shard_count = 30,
|
shard_count = 30,
|
||||||
|
xla_enable_strict_auto_jit = True,
|
||||||
)
|
)
|
||||||
|
|
||||||
cuda_py_test(
|
cuda_py_test(
|
||||||
@ -2973,6 +3090,7 @@ cuda_py_test(
|
|||||||
"//tensorflow/python:variables",
|
"//tensorflow/python:variables",
|
||||||
],
|
],
|
||||||
shard_count = 50,
|
shard_count = 50,
|
||||||
|
xla_enable_strict_auto_jit = True,
|
||||||
)
|
)
|
||||||
|
|
||||||
cuda_py_test(
|
cuda_py_test(
|
||||||
@ -2993,6 +3111,7 @@ cuda_py_test(
|
|||||||
"//tensorflow/python:variables",
|
"//tensorflow/python:variables",
|
||||||
],
|
],
|
||||||
shard_count = 50,
|
shard_count = 50,
|
||||||
|
xla_enable_strict_auto_jit = True,
|
||||||
)
|
)
|
||||||
|
|
||||||
cuda_py_test(
|
cuda_py_test(
|
||||||
@ -3013,6 +3132,7 @@ cuda_py_test(
|
|||||||
"//tensorflow/python:variables",
|
"//tensorflow/python:variables",
|
||||||
],
|
],
|
||||||
shard_count = 50,
|
shard_count = 50,
|
||||||
|
xla_enable_strict_auto_jit = True,
|
||||||
)
|
)
|
||||||
|
|
||||||
cuda_py_test(
|
cuda_py_test(
|
||||||
@ -3038,6 +3158,7 @@ cuda_py_test(
|
|||||||
"//tensorflow/python:variables",
|
"//tensorflow/python:variables",
|
||||||
],
|
],
|
||||||
shard_count = 20,
|
shard_count = 20,
|
||||||
|
xla_enable_strict_auto_jit = True,
|
||||||
)
|
)
|
||||||
|
|
||||||
cuda_py_test(
|
cuda_py_test(
|
||||||
@ -3055,6 +3176,7 @@ cuda_py_test(
|
|||||||
"//tensorflow/python/ops/linalg",
|
"//tensorflow/python/ops/linalg",
|
||||||
],
|
],
|
||||||
shard_count = 20,
|
shard_count = 20,
|
||||||
|
xla_enable_strict_auto_jit = True,
|
||||||
)
|
)
|
||||||
|
|
||||||
cuda_py_test(
|
cuda_py_test(
|
||||||
@ -3068,6 +3190,7 @@ cuda_py_test(
|
|||||||
"//tensorflow/python:framework_for_generated_wrappers",
|
"//tensorflow/python:framework_for_generated_wrappers",
|
||||||
],
|
],
|
||||||
shard_count = 20,
|
shard_count = 20,
|
||||||
|
xla_enable_strict_auto_jit = True,
|
||||||
)
|
)
|
||||||
|
|
||||||
cuda_py_test(
|
cuda_py_test(
|
||||||
@ -3085,6 +3208,7 @@ cuda_py_test(
|
|||||||
data = ["//tensorflow/python/kernel_tests/testdata:self_adjoint_eig_op_test_files"],
|
data = ["//tensorflow/python/kernel_tests/testdata:self_adjoint_eig_op_test_files"],
|
||||||
shard_count = 20,
|
shard_count = 20,
|
||||||
tags = ["no_windows"],
|
tags = ["no_windows"],
|
||||||
|
xla_enable_strict_auto_jit = True,
|
||||||
)
|
)
|
||||||
|
|
||||||
cuda_py_test(
|
cuda_py_test(
|
||||||
@ -3100,6 +3224,7 @@ cuda_py_test(
|
|||||||
"//tensorflow/python:math_ops",
|
"//tensorflow/python:math_ops",
|
||||||
],
|
],
|
||||||
shard_count = 20,
|
shard_count = 20,
|
||||||
|
xla_enable_strict_auto_jit = True,
|
||||||
)
|
)
|
||||||
|
|
||||||
cuda_py_test(
|
cuda_py_test(
|
||||||
@ -3119,6 +3244,7 @@ cuda_py_test(
|
|||||||
"no_oss", # b/117185141.
|
"no_oss", # b/117185141.
|
||||||
"nomsan", # TODO(b/117236102): Re-enable in msan build.
|
"nomsan", # TODO(b/117236102): Re-enable in msan build.
|
||||||
],
|
],
|
||||||
|
xla_enable_strict_auto_jit = True,
|
||||||
)
|
)
|
||||||
|
|
||||||
cuda_py_test(
|
cuda_py_test(
|
||||||
@ -3138,6 +3264,7 @@ cuda_py_test(
|
|||||||
"no_windows_gpu",
|
"no_windows_gpu",
|
||||||
"nomsan",
|
"nomsan",
|
||||||
],
|
],
|
||||||
|
xla_enable_strict_auto_jit = True,
|
||||||
)
|
)
|
||||||
|
|
||||||
cuda_py_test(
|
cuda_py_test(
|
||||||
@ -3153,6 +3280,7 @@ cuda_py_test(
|
|||||||
"//tensorflow/python:math_ops",
|
"//tensorflow/python:math_ops",
|
||||||
],
|
],
|
||||||
shard_count = 20,
|
shard_count = 20,
|
||||||
|
xla_enable_strict_auto_jit = True,
|
||||||
)
|
)
|
||||||
|
|
||||||
sycl_py_test(
|
sycl_py_test(
|
||||||
@ -3399,6 +3527,7 @@ cuda_py_test(
|
|||||||
"//tensorflow/python:while_v2",
|
"//tensorflow/python:while_v2",
|
||||||
],
|
],
|
||||||
grpc_enabled = True,
|
grpc_enabled = True,
|
||||||
|
xla_enable_strict_auto_jit = True,
|
||||||
)
|
)
|
||||||
|
|
||||||
cuda_py_test(
|
cuda_py_test(
|
||||||
@ -3424,4 +3553,5 @@ cuda_py_test(
|
|||||||
"//tensorflow/python:while_v2",
|
"//tensorflow/python:while_v2",
|
||||||
],
|
],
|
||||||
grpc_enabled = True,
|
grpc_enabled = True,
|
||||||
|
xla_enable_strict_auto_jit = True,
|
||||||
)
|
)
|
||||||
|
@ -139,6 +139,7 @@ class AtrousConv2DTest(test.TestCase):
|
|||||||
y1.eval(), self.evaluate(y2), rtol=1e-2, atol=1e-2)
|
y1.eval(), self.evaluate(y2), rtol=1e-2, atol=1e-2)
|
||||||
|
|
||||||
@test_util.run_deprecated_v1
|
@test_util.run_deprecated_v1
|
||||||
|
@test_util.disable_xla("This test never passed for XLA") # larger error range
|
||||||
def testGradient(self):
|
def testGradient(self):
|
||||||
with self.session(use_gpu=True):
|
with self.session(use_gpu=True):
|
||||||
# Input: [batch, height, width, input_depth]
|
# Input: [batch, height, width, input_depth]
|
||||||
|
@ -26,6 +26,7 @@ import numpy as np
|
|||||||
from tensorflow.core.util import test_log_pb2
|
from tensorflow.core.util import test_log_pb2
|
||||||
from tensorflow.python.client import session
|
from tensorflow.python.client import session
|
||||||
from tensorflow.python.framework import dtypes
|
from tensorflow.python.framework import dtypes
|
||||||
|
from tensorflow.python.framework import test_util
|
||||||
from tensorflow.python.ops import array_ops
|
from tensorflow.python.ops import array_ops
|
||||||
from tensorflow.python.platform import benchmark
|
from tensorflow.python.platform import benchmark
|
||||||
from tensorflow.python.platform import gfile
|
from tensorflow.python.platform import gfile
|
||||||
@ -125,6 +126,7 @@ class BenchmarkTest(test.TestCase):
|
|||||||
self.assertFalse(_ran_somebenchmark_2[0])
|
self.assertFalse(_ran_somebenchmark_2[0])
|
||||||
self.assertFalse(_ran_somebenchmark_but_shouldnt[0])
|
self.assertFalse(_ran_somebenchmark_but_shouldnt[0])
|
||||||
|
|
||||||
|
@test_util.disable_xla("This test never passed for XLA")
|
||||||
def testReportingBenchmark(self):
|
def testReportingBenchmark(self):
|
||||||
tempdir = test.get_temp_dir()
|
tempdir = test.get_temp_dir()
|
||||||
try:
|
try:
|
||||||
|
@ -889,6 +889,8 @@ class EnsureShapeTest(test.TestCase):
|
|||||||
|
|
||||||
# Dynamic shape check
|
# Dynamic shape check
|
||||||
@test_util.run_deprecated_v1
|
@test_util.run_deprecated_v1
|
||||||
|
@test_util.disable_xla("This test never passed for XLA"
|
||||||
|
) # Dynamic shapes not supported now with XLA
|
||||||
def testEnsuresDynamicShape_RaisesError(self):
|
def testEnsuresDynamicShape_RaisesError(self):
|
||||||
placeholder = array_ops.placeholder(dtypes.int32)
|
placeholder = array_ops.placeholder(dtypes.int32)
|
||||||
derived = math_ops.divide(placeholder, 3, name="MyDivide")
|
derived = math_ops.divide(placeholder, 3, name="MyDivide")
|
||||||
@ -902,6 +904,8 @@ class EnsureShapeTest(test.TestCase):
|
|||||||
sess.run(derived, feed_dict={placeholder: feed_val})
|
sess.run(derived, feed_dict={placeholder: feed_val})
|
||||||
|
|
||||||
@test_util.run_deprecated_v1
|
@test_util.run_deprecated_v1
|
||||||
|
@test_util.disable_xla("This test never passed for XLA"
|
||||||
|
) # Dynamic shapes not supported now with XLA
|
||||||
def testEnsuresDynamicShape_RaisesErrorDimUnknown(self):
|
def testEnsuresDynamicShape_RaisesErrorDimUnknown(self):
|
||||||
placeholder = array_ops.placeholder(dtypes.int32)
|
placeholder = array_ops.placeholder(dtypes.int32)
|
||||||
derived = placeholder / 3
|
derived = placeholder / 3
|
||||||
|
@ -163,6 +163,7 @@ class CholeskyOpTest(test.TestCase):
|
|||||||
with self.assertRaises(ValueError):
|
with self.assertRaises(ValueError):
|
||||||
linalg_ops.cholesky(tensor3)
|
linalg_ops.cholesky(tensor3)
|
||||||
|
|
||||||
|
@test_util.disable_xla("This test never passed for XLA") # all nan on XLA
|
||||||
def testNotInvertibleCPU(self):
|
def testNotInvertibleCPU(self):
|
||||||
# The input should be invertible.
|
# The input should be invertible.
|
||||||
with self.session(use_gpu=True):
|
with self.session(use_gpu=True):
|
||||||
|
@ -33,6 +33,7 @@ from tensorflow.python.ops import variables
|
|||||||
from tensorflow.python.platform import test
|
from tensorflow.python.platform import test
|
||||||
|
|
||||||
|
|
||||||
|
@test_util.disable_all_xla("This test never passed for XLA")
|
||||||
class ConcatOpTest(test.TestCase):
|
class ConcatOpTest(test.TestCase):
|
||||||
|
|
||||||
@test_util.run_deprecated_v1
|
@test_util.run_deprecated_v1
|
||||||
@ -640,6 +641,8 @@ class ConcatOpTest(test.TestCase):
|
|||||||
output = self.evaluate(c)
|
output = self.evaluate(c)
|
||||||
self.assertAllEqual([[1, 2, 3, 7, 8, 9], [4, 5, 6, 10, 11, 12]], output)
|
self.assertAllEqual([[1, 2, 3, 7, 8, 9], [4, 5, 6, 10, 11, 12]], output)
|
||||||
|
|
||||||
|
|
||||||
|
@test_util.disable_all_xla("This test never passed for XLA")
|
||||||
class ConcatOffsetTest(test.TestCase):
|
class ConcatOffsetTest(test.TestCase):
|
||||||
|
|
||||||
def testBasic(self):
|
def testBasic(self):
|
||||||
@ -683,6 +686,8 @@ class ConcatOffsetTest(test.TestCase):
|
|||||||
self.evaluate(off)
|
self.evaluate(off)
|
||||||
|
|
||||||
@test_util.run_deprecated_v1
|
@test_util.run_deprecated_v1
|
||||||
|
@test_util.disable_xla(
|
||||||
|
"This test never passed for XLA") # Different error message on XLA
|
||||||
def testSizeMismatch(self):
|
def testSizeMismatch(self):
|
||||||
cdim = constant_op.constant(1, dtypes.int32)
|
cdim = constant_op.constant(1, dtypes.int32)
|
||||||
s0 = constant_op.constant([2, 3, 5], dtypes.int32)
|
s0 = constant_op.constant([2, 3, 5], dtypes.int32)
|
||||||
|
@ -2380,6 +2380,8 @@ class ControlFlowTest(test.TestCase):
|
|||||||
self.assertEqual(i_val, 3)
|
self.assertEqual(i_val, 3)
|
||||||
self.assertAllClose(x_val, 1.0)
|
self.assertAllClose(x_val, 1.0)
|
||||||
|
|
||||||
|
@test_util.disable_xla("This test never passed for XLA"
|
||||||
|
) # Resource variable issue for ControlFlowV2
|
||||||
@test_util.run_gpu_only
|
@test_util.run_gpu_only
|
||||||
def testGpuResourceAccess(self):
|
def testGpuResourceAccess(self):
|
||||||
with ops.device(test.gpu_device_name()):
|
with ops.device(test.gpu_device_name()):
|
||||||
|
@ -295,6 +295,7 @@ class DepthToSpaceTest(test.TestCase):
|
|||||||
actual_vals, expected_vals = self.evaluate([actual, expected])
|
actual_vals, expected_vals = self.evaluate([actual, expected])
|
||||||
self.assertTrue(np.array_equal(actual_vals, expected_vals))
|
self.assertTrue(np.array_equal(actual_vals, expected_vals))
|
||||||
|
|
||||||
|
@test_util.disable_xla("This test never passed for XLA")
|
||||||
def testAgainstTranspose(self):
|
def testAgainstTranspose(self):
|
||||||
self.compareToTranspose(3, 2, 3, 1, 2, "NHWC", False)
|
self.compareToTranspose(3, 2, 3, 1, 2, "NHWC", False)
|
||||||
self.compareToTranspose(3, 2, 3, 2, 2, "NHWC", False)
|
self.compareToTranspose(3, 2, 3, 2, 2, "NHWC", False)
|
||||||
|
@ -481,6 +481,7 @@ class DepthwiseConv2DTest(test.TestCase):
|
|||||||
data_format="NCHW")
|
data_format="NCHW")
|
||||||
|
|
||||||
@test_util.run_v1_only("b/120545219")
|
@test_util.run_v1_only("b/120545219")
|
||||||
|
@test_util.disable_xla("This test never passed for XLA")
|
||||||
def testDepthwiseConv2DFilterGrad(self):
|
def testDepthwiseConv2DFilterGrad(self):
|
||||||
for index, (input_size, filter_size, output_size, stride,
|
for index, (input_size, filter_size, output_size, stride,
|
||||||
padding) in enumerate(CheckGradConfigsToTest()):
|
padding) in enumerate(CheckGradConfigsToTest()):
|
||||||
@ -612,6 +613,7 @@ class DepthwiseConv2DTest(test.TestCase):
|
|||||||
cpu_value = _GetVal(use_gpu=False)
|
cpu_value = _GetVal(use_gpu=False)
|
||||||
self.assertAllClose(cpu_value, gpu_value, rtol=1e-4, atol=1e-4)
|
self.assertAllClose(cpu_value, gpu_value, rtol=1e-4, atol=1e-4)
|
||||||
|
|
||||||
|
@test_util.disable_xla("This test never passed for XLA")
|
||||||
def testDepthwiseConv2DFilterGradCompare(self):
|
def testDepthwiseConv2DFilterGradCompare(self):
|
||||||
for index, (input_size, filter_size, output_size, stride,
|
for index, (input_size, filter_size, output_size, stride,
|
||||||
padding) in enumerate(ConfigsToTest()):
|
padding) in enumerate(ConfigsToTest()):
|
||||||
|
@ -65,6 +65,7 @@ class MatrixDiagTest(test.TestCase):
|
|||||||
array_ops.matrix_diag(0)
|
array_ops.matrix_diag(0)
|
||||||
|
|
||||||
@test_util.run_deprecated_v1
|
@test_util.run_deprecated_v1
|
||||||
|
@test_util.disable_xla("This test never passed for XLA")
|
||||||
def testInvalidShapeAtEval(self):
|
def testInvalidShapeAtEval(self):
|
||||||
with self.session(use_gpu=True):
|
with self.session(use_gpu=True):
|
||||||
v = array_ops.placeholder(dtype=dtypes_lib.float32)
|
v = array_ops.placeholder(dtype=dtypes_lib.float32)
|
||||||
@ -269,6 +270,7 @@ class MatrixDiagPartTest(test.TestCase):
|
|||||||
array_ops.matrix_diag_part(0)
|
array_ops.matrix_diag_part(0)
|
||||||
|
|
||||||
@test_util.run_deprecated_v1
|
@test_util.run_deprecated_v1
|
||||||
|
@test_util.disable_xla("This test never passed for XLA")
|
||||||
def testInvalidShapeAtEval(self):
|
def testInvalidShapeAtEval(self):
|
||||||
with self.session(use_gpu=True):
|
with self.session(use_gpu=True):
|
||||||
v = array_ops.placeholder(dtype=dtypes_lib.float32)
|
v = array_ops.placeholder(dtype=dtypes_lib.float32)
|
||||||
|
@ -23,6 +23,7 @@ cuda_py_test(
|
|||||||
"//tensorflow/python:math_ops",
|
"//tensorflow/python:math_ops",
|
||||||
"//tensorflow/python:platform_test",
|
"//tensorflow/python:platform_test",
|
||||||
],
|
],
|
||||||
|
xla_enable_strict_auto_jit = True,
|
||||||
)
|
)
|
||||||
|
|
||||||
cuda_py_test(
|
cuda_py_test(
|
||||||
@ -42,6 +43,7 @@ cuda_py_test(
|
|||||||
"//tensorflow/python:platform_test",
|
"//tensorflow/python:platform_test",
|
||||||
],
|
],
|
||||||
shard_count = 3,
|
shard_count = 3,
|
||||||
|
xla_enable_strict_auto_jit = True,
|
||||||
)
|
)
|
||||||
|
|
||||||
cuda_py_test(
|
cuda_py_test(
|
||||||
@ -54,6 +56,7 @@ cuda_py_test(
|
|||||||
"//tensorflow/python:client_testlib",
|
"//tensorflow/python:client_testlib",
|
||||||
"//tensorflow/python:platform_test",
|
"//tensorflow/python:platform_test",
|
||||||
],
|
],
|
||||||
|
xla_enable_strict_auto_jit = True,
|
||||||
)
|
)
|
||||||
|
|
||||||
cuda_py_test(
|
cuda_py_test(
|
||||||
@ -71,6 +74,7 @@ cuda_py_test(
|
|||||||
"//tensorflow/python:nn_ops",
|
"//tensorflow/python:nn_ops",
|
||||||
"//tensorflow/python:platform_test",
|
"//tensorflow/python:platform_test",
|
||||||
],
|
],
|
||||||
|
xla_enable_strict_auto_jit = True,
|
||||||
)
|
)
|
||||||
|
|
||||||
cuda_py_test(
|
cuda_py_test(
|
||||||
@ -86,6 +90,7 @@ cuda_py_test(
|
|||||||
"//tensorflow/python:math_ops",
|
"//tensorflow/python:math_ops",
|
||||||
"//tensorflow/python:platform_test",
|
"//tensorflow/python:platform_test",
|
||||||
],
|
],
|
||||||
|
xla_enable_strict_auto_jit = True,
|
||||||
)
|
)
|
||||||
|
|
||||||
cuda_py_test(
|
cuda_py_test(
|
||||||
@ -105,6 +110,7 @@ cuda_py_test(
|
|||||||
"//tensorflow/python:platform_test",
|
"//tensorflow/python:platform_test",
|
||||||
"//tensorflow/python:random_ops",
|
"//tensorflow/python:random_ops",
|
||||||
],
|
],
|
||||||
|
xla_enable_strict_auto_jit = True,
|
||||||
)
|
)
|
||||||
|
|
||||||
cuda_py_test(
|
cuda_py_test(
|
||||||
@ -119,6 +125,7 @@ cuda_py_test(
|
|||||||
"//tensorflow/python:framework_test_lib",
|
"//tensorflow/python:framework_test_lib",
|
||||||
"//tensorflow/python:platform_test",
|
"//tensorflow/python:platform_test",
|
||||||
],
|
],
|
||||||
|
xla_enable_strict_auto_jit = True,
|
||||||
)
|
)
|
||||||
|
|
||||||
cuda_py_test(
|
cuda_py_test(
|
||||||
@ -139,6 +146,7 @@ cuda_py_test(
|
|||||||
"noguitar", # b/110489471
|
"noguitar", # b/110489471
|
||||||
"notap", # b/110489471
|
"notap", # b/110489471
|
||||||
],
|
],
|
||||||
|
xla_enable_strict_auto_jit = True,
|
||||||
)
|
)
|
||||||
|
|
||||||
cuda_py_test(
|
cuda_py_test(
|
||||||
@ -154,6 +162,7 @@ cuda_py_test(
|
|||||||
"//tensorflow/python:nn_ops",
|
"//tensorflow/python:nn_ops",
|
||||||
"//tensorflow/python:platform_test",
|
"//tensorflow/python:platform_test",
|
||||||
],
|
],
|
||||||
|
xla_enable_strict_auto_jit = True,
|
||||||
)
|
)
|
||||||
|
|
||||||
cuda_py_test(
|
cuda_py_test(
|
||||||
@ -169,6 +178,7 @@ cuda_py_test(
|
|||||||
"//tensorflow/python:nn_ops",
|
"//tensorflow/python:nn_ops",
|
||||||
"//tensorflow/python:platform_test",
|
"//tensorflow/python:platform_test",
|
||||||
],
|
],
|
||||||
|
xla_enable_strict_auto_jit = True,
|
||||||
)
|
)
|
||||||
|
|
||||||
cuda_py_test(
|
cuda_py_test(
|
||||||
@ -184,6 +194,7 @@ cuda_py_test(
|
|||||||
"//tensorflow/python:nn_ops",
|
"//tensorflow/python:nn_ops",
|
||||||
"//tensorflow/python:platform_test",
|
"//tensorflow/python:platform_test",
|
||||||
],
|
],
|
||||||
|
xla_enable_strict_auto_jit = True,
|
||||||
)
|
)
|
||||||
|
|
||||||
cuda_py_test(
|
cuda_py_test(
|
||||||
@ -200,6 +211,7 @@ cuda_py_test(
|
|||||||
"//tensorflow/python:platform_test",
|
"//tensorflow/python:platform_test",
|
||||||
],
|
],
|
||||||
tags = ["manual"], # b/69001419
|
tags = ["manual"], # b/69001419
|
||||||
|
xla_enable_strict_auto_jit = True,
|
||||||
)
|
)
|
||||||
|
|
||||||
cuda_py_test(
|
cuda_py_test(
|
||||||
@ -224,6 +236,7 @@ cuda_py_test(
|
|||||||
# disable to avoid false positives from scipy.
|
# disable to avoid false positives from scipy.
|
||||||
"nomsan",
|
"nomsan",
|
||||||
],
|
],
|
||||||
|
xla_enable_strict_auto_jit = True,
|
||||||
)
|
)
|
||||||
|
|
||||||
cuda_py_test(
|
cuda_py_test(
|
||||||
@ -240,6 +253,7 @@ cuda_py_test(
|
|||||||
"//tensorflow/python:framework_test_lib",
|
"//tensorflow/python:framework_test_lib",
|
||||||
"//tensorflow/python:math_ops",
|
"//tensorflow/python:math_ops",
|
||||||
],
|
],
|
||||||
|
xla_enable_strict_auto_jit = True,
|
||||||
)
|
)
|
||||||
|
|
||||||
cuda_py_test(
|
cuda_py_test(
|
||||||
@ -258,6 +272,7 @@ cuda_py_test(
|
|||||||
"//tensorflow/python:platform_test",
|
"//tensorflow/python:platform_test",
|
||||||
"//tensorflow/python:variables",
|
"//tensorflow/python:variables",
|
||||||
],
|
],
|
||||||
|
xla_enable_strict_auto_jit = True,
|
||||||
)
|
)
|
||||||
|
|
||||||
cuda_py_test(
|
cuda_py_test(
|
||||||
@ -274,6 +289,7 @@ cuda_py_test(
|
|||||||
"//tensorflow/python:platform_test",
|
"//tensorflow/python:platform_test",
|
||||||
"//tensorflow/python:variables",
|
"//tensorflow/python:variables",
|
||||||
],
|
],
|
||||||
|
xla_enable_strict_auto_jit = True,
|
||||||
)
|
)
|
||||||
|
|
||||||
cuda_py_test(
|
cuda_py_test(
|
||||||
@ -291,4 +307,5 @@ cuda_py_test(
|
|||||||
"//tensorflow/python:math_ops",
|
"//tensorflow/python:math_ops",
|
||||||
"//tensorflow/python:platform_test",
|
"//tensorflow/python:platform_test",
|
||||||
],
|
],
|
||||||
|
xla_enable_strict_auto_jit = True,
|
||||||
)
|
)
|
||||||
|
@ -58,6 +58,7 @@ def simple_scoped_fn(a, x):
|
|||||||
|
|
||||||
|
|
||||||
@test_util.with_control_flow_v2
|
@test_util.with_control_flow_v2
|
||||||
|
@test_util.disable_all_xla("This test never passed for XLA")
|
||||||
class FunctionalOpsTest(test.TestCase):
|
class FunctionalOpsTest(test.TestCase):
|
||||||
|
|
||||||
@test_util.run_in_graph_and_eager_modes
|
@test_util.run_in_graph_and_eager_modes
|
||||||
@ -826,6 +827,8 @@ class FunctionalOpsTest(test.TestCase):
|
|||||||
self.assertAllEqual(Run(100., True), 5050.)
|
self.assertAllEqual(Run(100., True), 5050.)
|
||||||
|
|
||||||
@test_util.run_v1_only("b/120545219")
|
@test_util.run_v1_only("b/120545219")
|
||||||
|
@test_util.disable_xla(
|
||||||
|
"This test never passed for XLA") # Different error message
|
||||||
def testWhileError(self):
|
def testWhileError(self):
|
||||||
for use_gpu in (True, False):
|
for use_gpu in (True, False):
|
||||||
with ops.Graph().as_default() as g:
|
with ops.Graph().as_default() as g:
|
||||||
@ -1102,6 +1105,7 @@ class FunctionalOpsTest(test.TestCase):
|
|||||||
|
|
||||||
# TODO(akshayka): Replace `function.Defun` with tf.contrib.eager.defun` in the
|
# TODO(akshayka): Replace `function.Defun` with tf.contrib.eager.defun` in the
|
||||||
# below test cases.
|
# below test cases.
|
||||||
|
@test_util.disable_all_xla("This test never passed for XLA")
|
||||||
class PartitionedCallTest(test.TestCase):
|
class PartitionedCallTest(test.TestCase):
|
||||||
|
|
||||||
@test_util.run_deprecated_v1
|
@test_util.run_deprecated_v1
|
||||||
|
@ -34,6 +34,7 @@ from tensorflow.python.ops import variables
|
|||||||
from tensorflow.python.platform import test
|
from tensorflow.python.platform import test
|
||||||
|
|
||||||
|
|
||||||
|
@test_util.disable_all_xla("This test never passed for XLA")
|
||||||
class GatherNdTest(test.TestCase):
|
class GatherNdTest(test.TestCase):
|
||||||
|
|
||||||
def _testSimpleDtype(self, dtype):
|
def _testSimpleDtype(self, dtype):
|
||||||
@ -56,6 +57,7 @@ class GatherNdTest(test.TestCase):
|
|||||||
self._testSimpleDtype("|S") # byte strings in python2 + 3
|
self._testSimpleDtype("|S") # byte strings in python2 + 3
|
||||||
|
|
||||||
@test_util.run_deprecated_v1
|
@test_util.run_deprecated_v1
|
||||||
|
@test_util.disable_xla("This test never passed for XLA")
|
||||||
def testEmptyIndicesAndParamsOKButJustEmptyParamsFails(self):
|
def testEmptyIndicesAndParamsOKButJustEmptyParamsFails(self):
|
||||||
with self.session(use_gpu=True):
|
with self.session(use_gpu=True):
|
||||||
params = np.ones((3, 3), dtype=np.float32)
|
params = np.ones((3, 3), dtype=np.float32)
|
||||||
@ -358,6 +360,7 @@ class GatherNdTest(test.TestCase):
|
|||||||
self.assertAllEqual(expected_grads, ops.convert_to_tensor(grads).eval())
|
self.assertAllEqual(expected_grads, ops.convert_to_tensor(grads).eval())
|
||||||
|
|
||||||
|
|
||||||
|
@test_util.disable_all_xla("This test never passed for XLA")
|
||||||
class GatherNdOpBenchmark(test.Benchmark):
|
class GatherNdOpBenchmark(test.Benchmark):
|
||||||
|
|
||||||
def benchmark_gather_nd_op(self):
|
def benchmark_gather_nd_op(self):
|
||||||
|
@ -213,6 +213,8 @@ class GatherTest(test.TestCase, parameterized.TestCase):
|
|||||||
array_ops.gather(params, [[7]], axis=1).eval()
|
array_ops.gather(params, [[7]], axis=1).eval()
|
||||||
|
|
||||||
@test_util.run_deprecated_v1
|
@test_util.run_deprecated_v1
|
||||||
|
@test_util.disable_xla(
|
||||||
|
"This test never passed for XLA") # Different error message.
|
||||||
def testBadAxis(self):
|
def testBadAxis(self):
|
||||||
with self.session(use_gpu=True):
|
with self.session(use_gpu=True):
|
||||||
params = [0, 1, 2]
|
params = [0, 1, 2]
|
||||||
|
@ -22,6 +22,7 @@ cuda_py_test(
|
|||||||
"//tensorflow/python:math_ops",
|
"//tensorflow/python:math_ops",
|
||||||
"//tensorflow/python:platform_test",
|
"//tensorflow/python:platform_test",
|
||||||
],
|
],
|
||||||
|
xla_enable_strict_auto_jit = True,
|
||||||
)
|
)
|
||||||
|
|
||||||
cuda_py_test(
|
cuda_py_test(
|
||||||
@ -38,6 +39,7 @@ cuda_py_test(
|
|||||||
"//tensorflow/python:math_ops",
|
"//tensorflow/python:math_ops",
|
||||||
"//tensorflow/python:platform_test",
|
"//tensorflow/python:platform_test",
|
||||||
],
|
],
|
||||||
|
xla_enable_strict_auto_jit = True,
|
||||||
)
|
)
|
||||||
|
|
||||||
cuda_py_test(
|
cuda_py_test(
|
||||||
@ -60,6 +62,7 @@ cuda_py_test(
|
|||||||
"noasan", # times out, b/63678675
|
"noasan", # times out, b/63678675
|
||||||
"optonly", # times out
|
"optonly", # times out
|
||||||
],
|
],
|
||||||
|
xla_enable_strict_auto_jit = True,
|
||||||
)
|
)
|
||||||
|
|
||||||
cuda_py_test(
|
cuda_py_test(
|
||||||
@ -76,6 +79,7 @@ cuda_py_test(
|
|||||||
"//tensorflow/python:math_ops",
|
"//tensorflow/python:math_ops",
|
||||||
"//tensorflow/python:platform_test",
|
"//tensorflow/python:platform_test",
|
||||||
],
|
],
|
||||||
|
xla_enable_strict_auto_jit = True,
|
||||||
)
|
)
|
||||||
|
|
||||||
cuda_py_test(
|
cuda_py_test(
|
||||||
@ -98,6 +102,7 @@ cuda_py_test(
|
|||||||
"noasan",
|
"noasan",
|
||||||
"optonly",
|
"optonly",
|
||||||
],
|
],
|
||||||
|
xla_enable_strict_auto_jit = True,
|
||||||
)
|
)
|
||||||
|
|
||||||
cuda_py_test(
|
cuda_py_test(
|
||||||
@ -120,6 +125,7 @@ cuda_py_test(
|
|||||||
"noasan", # times out, b/63678675
|
"noasan", # times out, b/63678675
|
||||||
"optonly", # times out
|
"optonly", # times out
|
||||||
],
|
],
|
||||||
|
xla_enable_strict_auto_jit = True,
|
||||||
)
|
)
|
||||||
|
|
||||||
cuda_py_test(
|
cuda_py_test(
|
||||||
@ -144,6 +150,7 @@ cuda_py_test(
|
|||||||
"noasan", # times out, b/63678675
|
"noasan", # times out, b/63678675
|
||||||
"optonly", # times out, b/79171797
|
"optonly", # times out, b/79171797
|
||||||
],
|
],
|
||||||
|
xla_enable_strict_auto_jit = True,
|
||||||
)
|
)
|
||||||
|
|
||||||
cuda_py_test(
|
cuda_py_test(
|
||||||
@ -166,6 +173,7 @@ cuda_py_test(
|
|||||||
"noasan",
|
"noasan",
|
||||||
"optonly",
|
"optonly",
|
||||||
],
|
],
|
||||||
|
xla_enable_strict_auto_jit = True,
|
||||||
)
|
)
|
||||||
|
|
||||||
cuda_py_test(
|
cuda_py_test(
|
||||||
@ -187,6 +195,7 @@ cuda_py_test(
|
|||||||
"noasan",
|
"noasan",
|
||||||
"optonly",
|
"optonly",
|
||||||
],
|
],
|
||||||
|
xla_enable_strict_auto_jit = True,
|
||||||
)
|
)
|
||||||
|
|
||||||
cuda_py_test(
|
cuda_py_test(
|
||||||
@ -209,6 +218,7 @@ cuda_py_test(
|
|||||||
"noasan", # times out, b/63678675
|
"noasan", # times out, b/63678675
|
||||||
"optonly", # times out
|
"optonly", # times out
|
||||||
],
|
],
|
||||||
|
xla_enable_strict_auto_jit = True,
|
||||||
)
|
)
|
||||||
|
|
||||||
cuda_py_test(
|
cuda_py_test(
|
||||||
@ -229,6 +239,7 @@ cuda_py_test(
|
|||||||
"noasan",
|
"noasan",
|
||||||
"optonly",
|
"optonly",
|
||||||
],
|
],
|
||||||
|
xla_enable_strict_auto_jit = True,
|
||||||
)
|
)
|
||||||
|
|
||||||
cuda_py_test(
|
cuda_py_test(
|
||||||
@ -251,6 +262,7 @@ cuda_py_test(
|
|||||||
"noasan",
|
"noasan",
|
||||||
"optonly",
|
"optonly",
|
||||||
],
|
],
|
||||||
|
xla_enable_strict_auto_jit = True,
|
||||||
)
|
)
|
||||||
|
|
||||||
cuda_py_test(
|
cuda_py_test(
|
||||||
@ -270,6 +282,7 @@ cuda_py_test(
|
|||||||
"noasan",
|
"noasan",
|
||||||
"optonly",
|
"optonly",
|
||||||
],
|
],
|
||||||
|
xla_enable_strict_auto_jit = True,
|
||||||
)
|
)
|
||||||
|
|
||||||
cuda_py_test(
|
cuda_py_test(
|
||||||
@ -290,6 +303,7 @@ cuda_py_test(
|
|||||||
"noasan", # times out
|
"noasan", # times out
|
||||||
"optonly",
|
"optonly",
|
||||||
],
|
],
|
||||||
|
xla_enable_strict_auto_jit = True,
|
||||||
)
|
)
|
||||||
|
|
||||||
cuda_py_test(
|
cuda_py_test(
|
||||||
@ -310,6 +324,7 @@ cuda_py_test(
|
|||||||
"noasan",
|
"noasan",
|
||||||
"optonly",
|
"optonly",
|
||||||
],
|
],
|
||||||
|
xla_enable_strict_auto_jit = True,
|
||||||
)
|
)
|
||||||
|
|
||||||
cuda_py_test(
|
cuda_py_test(
|
||||||
@ -328,4 +343,5 @@ cuda_py_test(
|
|||||||
],
|
],
|
||||||
shard_count = 5,
|
shard_count = 5,
|
||||||
tags = ["optonly"], # Test is flaky without optimization.
|
tags = ["optonly"], # Test is flaky without optimization.
|
||||||
|
xla_enable_strict_auto_jit = True,
|
||||||
)
|
)
|
||||||
|
@ -303,6 +303,7 @@ class PoolingTest(test.TestCase):
|
|||||||
self.assertLess(err, err_tolerance)
|
self.assertLess(err, err_tolerance)
|
||||||
|
|
||||||
@test_util.run_deprecated_v1
|
@test_util.run_deprecated_v1
|
||||||
|
@test_util.disable_xla("This test never passed for XLA") # Much larger error
|
||||||
def testGradient1D(self):
|
def testGradient1D(self):
|
||||||
with self.session(use_gpu=test.is_gpu_available()):
|
with self.session(use_gpu=test.is_gpu_available()):
|
||||||
for padding in ["SAME", "VALID"]:
|
for padding in ["SAME", "VALID"]:
|
||||||
|
@ -730,6 +730,7 @@ class PoolingTest(test.TestCase):
|
|||||||
t = nn_ops.max_pool(
|
t = nn_ops.max_pool(
|
||||||
t, ksize=ksize, strides=strides, padding="SAME").eval()
|
t, ksize=ksize, strides=strides, padding="SAME").eval()
|
||||||
|
|
||||||
|
@test_util.disable_xla("This test never passed for XLA")
|
||||||
def testDepthwiseMaxPoolInvalidConfigs(self):
|
def testDepthwiseMaxPoolInvalidConfigs(self):
|
||||||
self._testDepthwiseMaxPoolInvalidConfig(
|
self._testDepthwiseMaxPoolInvalidConfig(
|
||||||
[1, 2, 2, 4], [1, 2, 2, 2], [1, 1, 1, 2],
|
[1, 2, 2, 4], [1, 2, 2, 2], [1, 1, 1, 2],
|
||||||
@ -1174,6 +1175,7 @@ class PoolingTest(test.TestCase):
|
|||||||
use_gpu=use_gpu)
|
use_gpu=use_gpu)
|
||||||
|
|
||||||
@test_util.run_deprecated_v1
|
@test_util.run_deprecated_v1
|
||||||
|
@test_util.disable_xla("This test never passed for XLA")
|
||||||
def testMaxPoolGrad(self):
|
def testMaxPoolGrad(self):
|
||||||
for (data_format, use_gpu) in GetTestConfigs():
|
for (data_format, use_gpu) in GetTestConfigs():
|
||||||
self._testMaxPoolGradValidPadding1_1(data_format, use_gpu)
|
self._testMaxPoolGradValidPadding1_1(data_format, use_gpu)
|
||||||
@ -1210,6 +1212,7 @@ class PoolingTest(test.TestCase):
|
|||||||
[1, window_rows, window_cols, 1],
|
[1, window_rows, window_cols, 1],
|
||||||
[1, row_stride, col_stride, 1], padding)
|
[1, row_stride, col_stride, 1], padding)
|
||||||
|
|
||||||
|
@test_util.disable_xla("This test never passed for XLA")
|
||||||
def _testMaxPoolGradDirect(self, input_data, output_backprop,
|
def _testMaxPoolGradDirect(self, input_data, output_backprop,
|
||||||
expected_input_backprop, input_sizes, output_sizes,
|
expected_input_backprop, input_sizes, output_sizes,
|
||||||
window_rows, window_cols, row_stride, col_stride,
|
window_rows, window_cols, row_stride, col_stride,
|
||||||
@ -1625,6 +1628,7 @@ class PoolingTest(test.TestCase):
|
|||||||
use_gpu=use_gpu)
|
use_gpu=use_gpu)
|
||||||
|
|
||||||
@test_util.run_deprecated_v1
|
@test_util.run_deprecated_v1
|
||||||
|
@test_util.disable_xla("This test never passed for XLA")
|
||||||
def testMaxPoolGradGrad(self):
|
def testMaxPoolGradGrad(self):
|
||||||
for (data_format, use_gpu) in GetTestConfigs():
|
for (data_format, use_gpu) in GetTestConfigs():
|
||||||
self._testMaxPoolGradGradValidPadding1_1(data_format, use_gpu)
|
self._testMaxPoolGradGradValidPadding1_1(data_format, use_gpu)
|
||||||
@ -1659,6 +1663,7 @@ class PoolingTest(test.TestCase):
|
|||||||
[1, row_stride, col_stride, 1], padding)
|
[1, row_stride, col_stride, 1], padding)
|
||||||
|
|
||||||
@test_util.run_deprecated_v1
|
@test_util.run_deprecated_v1
|
||||||
|
@test_util.disable_xla("This test never passed for XLA")
|
||||||
def testAvgPoolGrad(self):
|
def testAvgPoolGrad(self):
|
||||||
for (data_format, use_gpu) in GetTestConfigs():
|
for (data_format, use_gpu) in GetTestConfigs():
|
||||||
self._testAvgPoolGradValidPadding1_1(data_format, use_gpu)
|
self._testAvgPoolGradValidPadding1_1(data_format, use_gpu)
|
||||||
@ -1818,6 +1823,7 @@ class PoolingTest(test.TestCase):
|
|||||||
padding="SAME")
|
padding="SAME")
|
||||||
|
|
||||||
@test_util.run_deprecated_v1
|
@test_util.run_deprecated_v1
|
||||||
|
@test_util.disable_xla("This test never passed for XLA")
|
||||||
def testOpEdgeCases(self):
|
def testOpEdgeCases(self):
|
||||||
with self.session(use_gpu=test.is_gpu_available()) as sess:
|
with self.session(use_gpu=test.is_gpu_available()) as sess:
|
||||||
pool_funcs = [nn_ops.max_pool, nn_ops.avg_pool]
|
pool_funcs = [nn_ops.max_pool, nn_ops.avg_pool]
|
||||||
@ -1893,9 +1899,17 @@ if __name__ == "__main__":
|
|||||||
padding_) in GetShrunkInceptionMaxPoolShapes():
|
padding_) in GetShrunkInceptionMaxPoolShapes():
|
||||||
setattr(PoolingTest, "testMaxPoolFwd_" + name_,
|
setattr(PoolingTest, "testMaxPoolFwd_" + name_,
|
||||||
GetMaxPoolFwdTest(input_size_, filter_size_, stride_, padding_))
|
GetMaxPoolFwdTest(input_size_, filter_size_, stride_, padding_))
|
||||||
setattr(PoolingTest, "testMaxPoolGrad_" + name_,
|
if name_ == "maxpool5":
|
||||||
GetMaxPoolGradTest(input_size_, filter_size_, output_size_, stride_,
|
setattr(
|
||||||
padding_))
|
PoolingTest, "testMaxPoolGrad_" + name_,
|
||||||
|
test_util.disable_xla("maxpool5 fails while all others pass")(
|
||||||
|
GetMaxPoolGradTest(input_size_, filter_size_, output_size_,
|
||||||
|
stride_, padding_)))
|
||||||
|
else:
|
||||||
|
setattr(
|
||||||
|
PoolingTest, "testMaxPoolGrad_" + name_,
|
||||||
|
GetMaxPoolGradTest(input_size_, filter_size_, output_size_, stride_,
|
||||||
|
padding_))
|
||||||
setattr(PoolingTest, "testMaxPoolGradGrad_" + name_,
|
setattr(PoolingTest, "testMaxPoolGradGrad_" + name_,
|
||||||
GetMaxPoolGradGradTest(input_size_, filter_size_, output_size_,
|
GetMaxPoolGradGradTest(input_size_, filter_size_, output_size_,
|
||||||
stride_, padding_))
|
stride_, padding_))
|
||||||
|
@ -45,6 +45,7 @@ cuda_py_test(
|
|||||||
"//tensorflow/python:math_ops",
|
"//tensorflow/python:math_ops",
|
||||||
"//tensorflow/python:random_ops",
|
"//tensorflow/python:random_ops",
|
||||||
],
|
],
|
||||||
|
xla_enable_strict_auto_jit = True,
|
||||||
)
|
)
|
||||||
|
|
||||||
cuda_py_test(
|
cuda_py_test(
|
||||||
@ -64,6 +65,7 @@ cuda_py_test(
|
|||||||
"//tensorflow/python:random_ops",
|
"//tensorflow/python:random_ops",
|
||||||
],
|
],
|
||||||
shard_count = 3,
|
shard_count = 3,
|
||||||
|
xla_enable_strict_auto_jit = True,
|
||||||
)
|
)
|
||||||
|
|
||||||
cuda_py_test(
|
cuda_py_test(
|
||||||
@ -75,6 +77,7 @@ cuda_py_test(
|
|||||||
"//tensorflow/python:client_testlib",
|
"//tensorflow/python:client_testlib",
|
||||||
"//tensorflow/python:random_ops",
|
"//tensorflow/python:random_ops",
|
||||||
],
|
],
|
||||||
|
xla_enable_strict_auto_jit = True,
|
||||||
)
|
)
|
||||||
|
|
||||||
cuda_py_test(
|
cuda_py_test(
|
||||||
@ -88,6 +91,7 @@ cuda_py_test(
|
|||||||
"//tensorflow/python:framework_for_generated_wrappers",
|
"//tensorflow/python:framework_for_generated_wrappers",
|
||||||
"//tensorflow/python:random_ops",
|
"//tensorflow/python:random_ops",
|
||||||
],
|
],
|
||||||
|
xla_enable_strict_auto_jit = True,
|
||||||
)
|
)
|
||||||
|
|
||||||
cuda_py_test(
|
cuda_py_test(
|
||||||
@ -103,6 +107,7 @@ cuda_py_test(
|
|||||||
"//tensorflow/python:random_ops",
|
"//tensorflow/python:random_ops",
|
||||||
"//tensorflow/python:stateless_random_ops",
|
"//tensorflow/python:stateless_random_ops",
|
||||||
],
|
],
|
||||||
|
xla_enable_strict_auto_jit = True,
|
||||||
)
|
)
|
||||||
|
|
||||||
cuda_py_test(
|
cuda_py_test(
|
||||||
@ -120,6 +125,7 @@ cuda_py_test(
|
|||||||
],
|
],
|
||||||
shard_count = 4,
|
shard_count = 4,
|
||||||
tags = ["nozapfhahn"],
|
tags = ["nozapfhahn"],
|
||||||
|
xla_enable_strict_auto_jit = True,
|
||||||
)
|
)
|
||||||
|
|
||||||
cuda_py_test(
|
cuda_py_test(
|
||||||
@ -137,6 +143,7 @@ cuda_py_test(
|
|||||||
"//tensorflow/python:random_grad",
|
"//tensorflow/python:random_grad",
|
||||||
"//tensorflow/python:random_ops",
|
"//tensorflow/python:random_ops",
|
||||||
],
|
],
|
||||||
|
xla_enable_strict_auto_jit = True,
|
||||||
)
|
)
|
||||||
|
|
||||||
cuda_py_test(
|
cuda_py_test(
|
||||||
@ -152,4 +159,5 @@ cuda_py_test(
|
|||||||
"//tensorflow/python:platform",
|
"//tensorflow/python:platform",
|
||||||
"//tensorflow/python:random_ops",
|
"//tensorflow/python:random_ops",
|
||||||
],
|
],
|
||||||
|
xla_enable_strict_auto_jit = True,
|
||||||
)
|
)
|
||||||
|
@ -257,6 +257,7 @@ class TruncatedNormalTest(test.TestCase):
|
|||||||
self.assertAllEqual(rnd1, rnd2)
|
self.assertAllEqual(rnd1, rnd2)
|
||||||
|
|
||||||
|
|
||||||
|
@test_util.disable_all_xla("This never passed on XLA")
|
||||||
class RandomUniformTest(RandomOpTestCommon):
|
class RandomUniformTest(RandomOpTestCommon):
|
||||||
|
|
||||||
def _Sampler(self, num, minv, maxv, dtype, use_gpu, seed=None):
|
def _Sampler(self, num, minv, maxv, dtype, use_gpu, seed=None):
|
||||||
|
@ -86,6 +86,7 @@ class ReluTest(test.TestCase):
|
|||||||
self.assertAllClose(np_relu, tf_relu)
|
self.assertAllClose(np_relu, tf_relu)
|
||||||
self.assertShapeEqual(np_relu, tf_relu)
|
self.assertShapeEqual(np_relu, tf_relu)
|
||||||
|
|
||||||
|
@test_util.disable_xla("This test never passed for XLA")
|
||||||
def testReluInt8x4BadShape(self):
|
def testReluInt8x4BadShape(self):
|
||||||
if not test.is_gpu_available(cuda_only=True):
|
if not test.is_gpu_available(cuda_only=True):
|
||||||
self.skipTest("No GPU available")
|
self.skipTest("No GPU available")
|
||||||
|
@ -741,6 +741,7 @@ class ResourceVariableOpsTest(test_util.TensorFlowTestCase):
|
|||||||
|
|
||||||
@test_util.run_in_graph_and_eager_modes
|
@test_util.run_in_graph_and_eager_modes
|
||||||
@test_util.run_v1_only("b/120545219")
|
@test_util.run_v1_only("b/120545219")
|
||||||
|
@test_util.disable_xla("This test never passed for XLA")
|
||||||
def testDestroyResource(self):
|
def testDestroyResource(self):
|
||||||
v = resource_variable_ops.ResourceVariable(3.0, name="var0")
|
v = resource_variable_ops.ResourceVariable(3.0, name="var0")
|
||||||
self.evaluate(variables.global_variables_initializer())
|
self.evaluate(variables.global_variables_initializer())
|
||||||
|
@ -70,6 +70,7 @@ def handle_options(func, x, axis, exclusive, reverse):
|
|||||||
return x
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
@test_util.disable_all_xla("This test never passed for XLA")
|
||||||
class CumsumTest(test.TestCase):
|
class CumsumTest(test.TestCase):
|
||||||
|
|
||||||
valid_dtypes = [
|
valid_dtypes = [
|
||||||
@ -193,6 +194,7 @@ class CumsumTest(test.TestCase):
|
|||||||
self._compareGradient([5, 10], axis, exclusive, reverse)
|
self._compareGradient([5, 10], axis, exclusive, reverse)
|
||||||
|
|
||||||
|
|
||||||
|
@test_util.disable_all_xla("This test never passed for XLA")
|
||||||
class CumprodTest(test.TestCase):
|
class CumprodTest(test.TestCase):
|
||||||
|
|
||||||
valid_dtypes = [
|
valid_dtypes = [
|
||||||
|
@ -295,6 +295,7 @@ class StatefulScatterNdTest(test.TestCase):
|
|||||||
updates).get_shape().as_list(), shape)
|
updates).get_shape().as_list(), shape)
|
||||||
|
|
||||||
@test_util.run_v1_only("b/120545219")
|
@test_util.run_v1_only("b/120545219")
|
||||||
|
@test_util.disable_xla("This test never passed for XLA")
|
||||||
def testResVarInvalidOutputShape(self):
|
def testResVarInvalidOutputShape(self):
|
||||||
res = variables.Variable(
|
res = variables.Variable(
|
||||||
initial_value=lambda: array_ops.zeros(shape=[], dtype=dtypes.float32),
|
initial_value=lambda: array_ops.zeros(shape=[], dtype=dtypes.float32),
|
||||||
|
@ -29,6 +29,7 @@ cuda_py_tests(
|
|||||||
"//tensorflow/python:spectral_ops_test_util",
|
"//tensorflow/python:spectral_ops_test_util",
|
||||||
"//tensorflow/python/ops/signal",
|
"//tensorflow/python/ops/signal",
|
||||||
],
|
],
|
||||||
|
xla_enable_strict_auto_jit = True,
|
||||||
)
|
)
|
||||||
|
|
||||||
cuda_py_tests(
|
cuda_py_tests(
|
||||||
@ -45,6 +46,7 @@ cuda_py_tests(
|
|||||||
],
|
],
|
||||||
shard_count = 4,
|
shard_count = 4,
|
||||||
tags = ["optonly"],
|
tags = ["optonly"],
|
||||||
|
xla_enable_strict_auto_jit = True,
|
||||||
)
|
)
|
||||||
|
|
||||||
cuda_py_tests(
|
cuda_py_tests(
|
||||||
@ -56,6 +58,7 @@ cuda_py_tests(
|
|||||||
"//tensorflow/python:client_testlib",
|
"//tensorflow/python:client_testlib",
|
||||||
"//tensorflow/python/ops/signal",
|
"//tensorflow/python/ops/signal",
|
||||||
],
|
],
|
||||||
|
xla_enable_strict_auto_jit = True,
|
||||||
)
|
)
|
||||||
|
|
||||||
cuda_py_tests(
|
cuda_py_tests(
|
||||||
@ -70,6 +73,7 @@ cuda_py_tests(
|
|||||||
"//tensorflow/python/ops/signal",
|
"//tensorflow/python/ops/signal",
|
||||||
"//tensorflow/python:spectral_ops_test_util",
|
"//tensorflow/python:spectral_ops_test_util",
|
||||||
],
|
],
|
||||||
|
xla_enable_strict_auto_jit = True,
|
||||||
)
|
)
|
||||||
|
|
||||||
cuda_py_tests(
|
cuda_py_tests(
|
||||||
@ -87,6 +91,7 @@ cuda_py_tests(
|
|||||||
"//tensorflow/python/ops/signal",
|
"//tensorflow/python/ops/signal",
|
||||||
"//tensorflow/python:platform_test",
|
"//tensorflow/python:platform_test",
|
||||||
],
|
],
|
||||||
|
xla_enable_strict_auto_jit = True,
|
||||||
)
|
)
|
||||||
|
|
||||||
cuda_py_tests(
|
cuda_py_tests(
|
||||||
@ -104,6 +109,7 @@ cuda_py_tests(
|
|||||||
"//tensorflow/python/ops/signal",
|
"//tensorflow/python/ops/signal",
|
||||||
"//tensorflow/python:platform_test",
|
"//tensorflow/python:platform_test",
|
||||||
],
|
],
|
||||||
|
xla_enable_strict_auto_jit = True,
|
||||||
)
|
)
|
||||||
|
|
||||||
cuda_py_tests(
|
cuda_py_tests(
|
||||||
@ -125,6 +131,7 @@ cuda_py_tests(
|
|||||||
"//tensorflow/python/ops/signal",
|
"//tensorflow/python/ops/signal",
|
||||||
],
|
],
|
||||||
tags = ["nomac"],
|
tags = ["nomac"],
|
||||||
|
xla_enable_strict_auto_jit = True,
|
||||||
)
|
)
|
||||||
|
|
||||||
cuda_py_tests(
|
cuda_py_tests(
|
||||||
@ -140,4 +147,5 @@ cuda_py_tests(
|
|||||||
"//tensorflow/python/ops/signal",
|
"//tensorflow/python/ops/signal",
|
||||||
"//tensorflow/python:platform_test",
|
"//tensorflow/python:platform_test",
|
||||||
],
|
],
|
||||||
|
xla_enable_strict_auto_jit = True,
|
||||||
)
|
)
|
||||||
|
@ -465,6 +465,7 @@ class RFFTOpsTest(BaseFFTOpsTest):
|
|||||||
gen_complex(complex_dims), rank, (size,) * rank)
|
gen_complex(complex_dims), rank, (size,) * rank)
|
||||||
|
|
||||||
@test_util.run_deprecated_v1
|
@test_util.run_deprecated_v1
|
||||||
|
@test_util.disable_xla("This test never passed for XLA")
|
||||||
def testError(self):
|
def testError(self):
|
||||||
with spectral_ops_test_util.fft_kernel_label_map():
|
with spectral_ops_test_util.fft_kernel_label_map():
|
||||||
for rank in VALID_FFT_RANKS:
|
for rank in VALID_FFT_RANKS:
|
||||||
|
@ -285,6 +285,7 @@ class SpaceToDepthTest(test.TestCase):
|
|||||||
actual_vals, expected_vals = self.evaluate([actual, expected])
|
actual_vals, expected_vals = self.evaluate([actual, expected])
|
||||||
self.assertTrue(np.array_equal(actual_vals, expected_vals))
|
self.assertTrue(np.array_equal(actual_vals, expected_vals))
|
||||||
|
|
||||||
|
@test_util.disable_xla("This test never passed for XLA")
|
||||||
def testAgainstTranspose(self):
|
def testAgainstTranspose(self):
|
||||||
self.compareToTranspose(3, 2, 3, 1, 2, "NHWC", False)
|
self.compareToTranspose(3, 2, 3, 1, 2, "NHWC", False)
|
||||||
self.compareToTranspose(1, 2, 3, 2, 2, "NHWC", False)
|
self.compareToTranspose(1, 2, 3, 2, 2, "NHWC", False)
|
||||||
|
@ -373,6 +373,7 @@ class SplitOpTest(test.TestCase):
|
|||||||
assert s1.shape.as_list() == [1]
|
assert s1.shape.as_list() == [1]
|
||||||
|
|
||||||
@test_util.run_deprecated_v1
|
@test_util.run_deprecated_v1
|
||||||
|
@test_util.disable_xla("This test never passed for XLA")
|
||||||
def testNonexistentDimTensor(self):
|
def testNonexistentDimTensor(self):
|
||||||
x = array_ops.placeholder(dtypes.int32)
|
x = array_ops.placeholder(dtypes.int32)
|
||||||
values = np.zeros([5, 30])
|
values = np.zeros([5, 30])
|
||||||
|
@ -35,6 +35,7 @@ class BitwiseOpTest(test_util.TensorFlowTestCase):
|
|||||||
super(BitwiseOpTest, self).__init__(method_name)
|
super(BitwiseOpTest, self).__init__(method_name)
|
||||||
|
|
||||||
@test_util.run_deprecated_v1
|
@test_util.run_deprecated_v1
|
||||||
|
@test_util.disable_xla("This test never passed for XLA")
|
||||||
def testBinaryOps(self):
|
def testBinaryOps(self):
|
||||||
dtype_list = [dtypes.int8, dtypes.int16, dtypes.int32, dtypes.int64,
|
dtype_list = [dtypes.int8, dtypes.int16, dtypes.int32, dtypes.int64,
|
||||||
dtypes.uint8, dtypes.uint16, dtypes.uint32, dtypes.uint64]
|
dtypes.uint8, dtypes.uint16, dtypes.uint32, dtypes.uint64]
|
||||||
@ -72,6 +73,7 @@ class BitwiseOpTest(test_util.TensorFlowTestCase):
|
|||||||
self.assertAllEqual(truth, popcnt_result)
|
self.assertAllEqual(truth, popcnt_result)
|
||||||
|
|
||||||
@test_util.run_deprecated_v1
|
@test_util.run_deprecated_v1
|
||||||
|
@test_util.disable_xla("This test never passed for XLA")
|
||||||
def testInvertOp(self):
|
def testInvertOp(self):
|
||||||
dtype_list = [dtypes.int8, dtypes.int16, dtypes.int32, dtypes.int64,
|
dtype_list = [dtypes.int8, dtypes.int16, dtypes.int32, dtypes.int64,
|
||||||
dtypes.uint8, dtypes.uint16, dtypes.uint32, dtypes.uint64]
|
dtypes.uint8, dtypes.uint16, dtypes.uint32, dtypes.uint64]
|
||||||
@ -97,6 +99,7 @@ class BitwiseOpTest(test_util.TensorFlowTestCase):
|
|||||||
self.assertAllEqual(inverted, expected)
|
self.assertAllEqual(inverted, expected)
|
||||||
|
|
||||||
@test_util.run_deprecated_v1
|
@test_util.run_deprecated_v1
|
||||||
|
@test_util.disable_xla("This test never passed for XLA")
|
||||||
def testShiftsWithPositiveLHS(self):
|
def testShiftsWithPositiveLHS(self):
|
||||||
dtype_list = [np.int8, np.int16, np.int32, np.int64,
|
dtype_list = [np.int8, np.int16, np.int32, np.int64,
|
||||||
np.uint8, np.uint16, np.uint32, np.uint64]
|
np.uint8, np.uint16, np.uint32, np.uint64]
|
||||||
|
@ -206,6 +206,7 @@ class BatchNormalizationTest(test.TestCase):
|
|||||||
2)
|
2)
|
||||||
|
|
||||||
@test_util.run_deprecated_v1
|
@test_util.run_deprecated_v1
|
||||||
|
@test_util.disable_xla("This test never passed for XLA")
|
||||||
def testBatchNormGradImpl(self):
|
def testBatchNormGradImpl(self):
|
||||||
x_shape = [7, 5, 4, 6]
|
x_shape = [7, 5, 4, 6]
|
||||||
param_shape = [6]
|
param_shape = [6]
|
||||||
|
@ -523,6 +523,7 @@ class BatchNormalizationTest(test.TestCase):
|
|||||||
data_format='NHWC',
|
data_format='NHWC',
|
||||||
is_training=is_training)
|
is_training=is_training)
|
||||||
|
|
||||||
|
@test_util.disable_xla('This test never passed for XLA')
|
||||||
def testBatchNormGradShape5(self):
|
def testBatchNormGradShape5(self):
|
||||||
for is_training in [True, False]:
|
for is_training in [True, False]:
|
||||||
x_shape = [0, 7, 11, 4]
|
x_shape = [0, 7, 11, 4]
|
||||||
|
@ -41,6 +41,7 @@ from tensorflow.python.ops.nn_impl import _compute_sampled_logits
|
|||||||
from tensorflow.python.platform import test as test_lib
|
from tensorflow.python.platform import test as test_lib
|
||||||
|
|
||||||
|
|
||||||
|
@test_util.disable_all_xla("This test never passed for XLA")
|
||||||
class ZeroFractionTest(test_lib.TestCase):
|
class ZeroFractionTest(test_lib.TestCase):
|
||||||
|
|
||||||
def _ZeroFraction(self, x):
|
def _ZeroFraction(self, x):
|
||||||
@ -1017,6 +1018,7 @@ class LeakyReluTest(test_lib.TestCase):
|
|||||||
class SwishTest(test_lib.TestCase):
|
class SwishTest(test_lib.TestCase):
|
||||||
|
|
||||||
@test_util.run_deprecated_v1
|
@test_util.run_deprecated_v1
|
||||||
|
@test_util.disable_xla("This test never passed for XLA")
|
||||||
def testValues(self):
|
def testValues(self):
|
||||||
np_values = np.array(
|
np_values = np.array(
|
||||||
[np.linspace(-10.0, 0.0, 100),
|
[np.linspace(-10.0, 0.0, 100),
|
||||||
|
@ -114,6 +114,7 @@ cuda_py_test(
|
|||||||
"//tensorflow/python:random_ops",
|
"//tensorflow/python:random_ops",
|
||||||
"//tensorflow/python:util",
|
"//tensorflow/python:util",
|
||||||
],
|
],
|
||||||
|
xla_enable_strict_auto_jit = True,
|
||||||
)
|
)
|
||||||
|
|
||||||
cuda_py_test(
|
cuda_py_test(
|
||||||
@ -171,4 +172,5 @@ cuda_py_test(
|
|||||||
"//tensorflow/python/ops/losses",
|
"//tensorflow/python/ops/losses",
|
||||||
],
|
],
|
||||||
tags = ["optonly"], # Too slow in non-opt mode
|
tags = ["optonly"], # Too slow in non-opt mode
|
||||||
|
xla_enable_strict_auto_jit = True,
|
||||||
)
|
)
|
||||||
|
@ -251,6 +251,7 @@ class NNTest(PForTestCase):
|
|||||||
|
|
||||||
self._test_loop_fn(loop_fn, 3, loop_fn_dtypes=[dtypes.float32] * 3)
|
self._test_loop_fn(loop_fn, 3, loop_fn_dtypes=[dtypes.float32] * 3)
|
||||||
|
|
||||||
|
@test_util.disable_xla("This test never passed for XLA")
|
||||||
def test_fused_batch_norm(self):
|
def test_fused_batch_norm(self):
|
||||||
data_formats = ["NHWC"]
|
data_formats = ["NHWC"]
|
||||||
if test.is_gpu_available():
|
if test.is_gpu_available():
|
||||||
|
@ -482,6 +482,7 @@ class GradientsTest(test.TestCase):
|
|||||||
pfor_jacobian, while_jacobian = create_lstm_batch_jacobian(8, 4, 2)
|
pfor_jacobian, while_jacobian = create_lstm_batch_jacobian(8, 4, 2)
|
||||||
self.run_and_assert_equal(pfor_jacobian, while_jacobian)
|
self.run_and_assert_equal(pfor_jacobian, while_jacobian)
|
||||||
|
|
||||||
|
@test_util.disable_xla("This test never passed for XLA")
|
||||||
def test_dynamic_lstm_batch_jacobian(self):
|
def test_dynamic_lstm_batch_jacobian(self):
|
||||||
pfor_jacobian, while_gradients = create_dynamic_lstm_batch_jacobian(8, 4, 3)
|
pfor_jacobian, while_gradients = create_dynamic_lstm_batch_jacobian(8, 4, 3)
|
||||||
with session.Session() as sess:
|
with session.Session() as sess:
|
||||||
|
@ -119,6 +119,7 @@ class LBetaTest(test.TestCase):
|
|||||||
special_math_ops.lbeta(x).get_shape())
|
special_math_ops.lbeta(x).get_shape())
|
||||||
|
|
||||||
@test_util.run_in_graph_and_eager_modes
|
@test_util.run_in_graph_and_eager_modes
|
||||||
|
@test_util.disable_xla('This test never passed for XLA')
|
||||||
def test_length_1_last_dimension_results_in_one(self):
|
def test_length_1_last_dimension_results_in_one(self):
|
||||||
# If there is only one coefficient, the formula still works, and we get one
|
# If there is only one coefficient, the formula still works, and we get one
|
||||||
# as the answer, always.
|
# as the answer, always.
|
||||||
|
@ -1802,8 +1802,17 @@ def tf_py_test(
|
|||||||
additional_visibility = [],
|
additional_visibility = [],
|
||||||
kernels = [],
|
kernels = [],
|
||||||
flaky = 0,
|
flaky = 0,
|
||||||
|
xla_enable_strict_auto_jit = False,
|
||||||
xla_enabled = False,
|
xla_enabled = False,
|
||||||
grpc_enabled = False):
|
grpc_enabled = False):
|
||||||
|
"""Create one or more python tests with extra tensorflow dependencies."""
|
||||||
|
xla_test_true_list = []
|
||||||
|
|
||||||
|
# xla_enable_strict_auto_jit is used to run Tensorflow unit tests with all XLA compilable
|
||||||
|
# kernels compiled with XLA.
|
||||||
|
if xla_enable_strict_auto_jit:
|
||||||
|
xla_enabled = True
|
||||||
|
xla_test_true_list += ["//tensorflow/python:is_xla_test_true"]
|
||||||
if xla_enabled:
|
if xla_enabled:
|
||||||
additional_deps = additional_deps + tf_additional_xla_deps_py()
|
additional_deps = additional_deps + tf_additional_xla_deps_py()
|
||||||
if grpc_enabled:
|
if grpc_enabled:
|
||||||
@ -1824,7 +1833,7 @@ def tf_py_test(
|
|||||||
deps = [
|
deps = [
|
||||||
clean_dep("//tensorflow/python:extra_py_tests_deps"),
|
clean_dep("//tensorflow/python:extra_py_tests_deps"),
|
||||||
clean_dep("//tensorflow/python:gradient_checker"),
|
clean_dep("//tensorflow/python:gradient_checker"),
|
||||||
] + additional_deps,
|
] + additional_deps + xla_test_true_list,
|
||||||
)
|
)
|
||||||
|
|
||||||
register_extension_info(
|
register_extension_info(
|
||||||
@ -1844,8 +1853,12 @@ def cuda_py_test(
|
|||||||
kernels = [],
|
kernels = [],
|
||||||
tags = [],
|
tags = [],
|
||||||
flaky = 0,
|
flaky = 0,
|
||||||
|
xla_enable_strict_auto_jit = False,
|
||||||
xla_enabled = False,
|
xla_enabled = False,
|
||||||
grpc_enabled = False):
|
grpc_enabled = False):
|
||||||
|
# TODO(b/122522101): Don't ignore xla_enable_strict_auto_jit and enable additional
|
||||||
|
# XLA tests once enough compute resources are available.
|
||||||
|
_ignored = [xla_enable_strict_auto_jit]
|
||||||
if main == None:
|
if main == None:
|
||||||
main = name + ".py"
|
main = name + ".py"
|
||||||
for config in ["cpu", "gpu"]:
|
for config in ["cpu", "gpu"]:
|
||||||
@ -1868,6 +1881,7 @@ def cuda_py_test(
|
|||||||
shard_count = shard_count,
|
shard_count = shard_count,
|
||||||
tags = test_tags,
|
tags = test_tags,
|
||||||
xla_enabled = xla_enabled,
|
xla_enabled = xla_enabled,
|
||||||
|
xla_enable_strict_auto_jit = False,
|
||||||
)
|
)
|
||||||
|
|
||||||
register_extension_info(
|
register_extension_info(
|
||||||
@ -1921,6 +1935,7 @@ def py_tests(
|
|||||||
tags = [],
|
tags = [],
|
||||||
shard_count = 1,
|
shard_count = 1,
|
||||||
prefix = "",
|
prefix = "",
|
||||||
|
xla_enable_strict_auto_jit = False,
|
||||||
xla_enabled = False,
|
xla_enabled = False,
|
||||||
grpc_enabled = False):
|
grpc_enabled = False):
|
||||||
for src in srcs:
|
for src in srcs:
|
||||||
@ -1939,6 +1954,7 @@ def py_tests(
|
|||||||
shard_count = shard_count,
|
shard_count = shard_count,
|
||||||
tags = tags,
|
tags = tags,
|
||||||
xla_enabled = xla_enabled,
|
xla_enabled = xla_enabled,
|
||||||
|
xla_enable_strict_auto_jit = xla_enable_strict_auto_jit,
|
||||||
)
|
)
|
||||||
|
|
||||||
def cuda_py_tests(
|
def cuda_py_tests(
|
||||||
@ -1951,8 +1967,12 @@ def cuda_py_tests(
|
|||||||
shard_count = 1,
|
shard_count = 1,
|
||||||
tags = [],
|
tags = [],
|
||||||
prefix = "",
|
prefix = "",
|
||||||
|
xla_enable_strict_auto_jit = False,
|
||||||
xla_enabled = False,
|
xla_enabled = False,
|
||||||
grpc_enabled = False):
|
grpc_enabled = False):
|
||||||
|
# TODO(b/122522101): Don't ignore xla_enable_strict_auto_jit and enable additional
|
||||||
|
# XLA tests once enough compute resources are available.
|
||||||
|
_ignored = [xla_enable_strict_auto_jit]
|
||||||
test_tags = tags + tf_cuda_tests_tags()
|
test_tags = tags + tf_cuda_tests_tags()
|
||||||
py_tests(
|
py_tests(
|
||||||
name = name,
|
name = name,
|
||||||
@ -1966,6 +1986,7 @@ def cuda_py_tests(
|
|||||||
shard_count = shard_count,
|
shard_count = shard_count,
|
||||||
tags = test_tags,
|
tags = test_tags,
|
||||||
xla_enabled = xla_enabled,
|
xla_enabled = xla_enabled,
|
||||||
|
xla_enable_strict_auto_jit = False,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Creates a genrule named <name> for running tools/proto_text's generator to
|
# Creates a genrule named <name> for running tools/proto_text's generator to
|
||||||
|
Loading…
Reference in New Issue
Block a user