Add test_util.no_xla_auto_jit decorator to mark tests that should never be compiled.
We have some tests that exercise behavior specific to the TF classic implementation and hence fail when run with XLA. This new decorator documents these cases and makes such tests pass even if compilation with XLA is enabled. The first use is PoolingTest.testMaxPoolGradDirect in tensorflow/python/kernel_tests:pooling_ops_test.

PiperOrigin-RevId: 245709342
parent 246dd2439c
commit 2900bfa9b0
tensorflow/python/framework/test_util.py
@@ -1570,6 +1570,31 @@ def disable_all_xla(description):
   return disable_all_impl
 
 
+# The description is just for documentation purposes.
+def no_xla_auto_jit(description):  # pylint: disable=unused-argument
+
+  def no_xla_auto_jit_impl(func):
+    """This test is not intended to be run with XLA auto jit enabled."""
+
+    def decorator(func):
+
+      def decorated(self, *args, **kwargs):
+        if is_xla_enabled():
+          # Skip test if using XLA is forced.
+          return
+        else:
+          return func(self, *args, **kwargs)
+
+      return decorated
+
+    if func is not None:
+      return decorator(func)
+
+    return decorator
+
+  return no_xla_auto_jit_impl
+
+
 class EagerSessionWarner(object):
 
   def __getattr__(self, attr):
tensorflow/python/kernel_tests/pooling_ops_test.py
@@ -1412,7 +1412,7 @@ class PoolingTest(test.TestCase):
           use_gpu=use_gpu,
           v2=v2)
 
-  @test_util.disable_xla("b/123923733")  # NaNs handled differently
+  @test_util.no_xla_auto_jit("b/123923733")  # NaNs handled differently
   def _testMaxPoolGradDirectWithNans2_1(self):
     input_data = [float("nan")] * 16
     output_backprop = [11.0, 12.0, 13.0, 15.0, 16.0, 17.0, 19.0, 20.0, 21.0]
@@ -1487,7 +1487,7 @@ class PoolingTest(test.TestCase):
     else:
       del os.environ["TF_ENABLE_MAXPOOL_NANPROP"]
 
-  @test_util.disable_xla("b/123923733")  # NaNs handled differently
+  @test_util.no_xla_auto_jit("b/123923733")  # NaNs handled differently
   def _testMaxPoolGradDirectWithNans2_2(self):
     input_data = [float("nan")] * 16
     output_backprop = [
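
Usage sketch (not part of this commit): a test that depends on TF classic behavior could opt out of XLA auto-jit as shown below, assuming the usual test_util import path; the test class, test method, assertion, and bug ID are illustrative placeholders only.

from tensorflow.python.framework import test_util
from tensorflow.python.platform import test


class ClassicBehaviorTest(test.TestCase):

  # Hypothetical test: the bug ID and assertion are placeholders.
  @test_util.no_xla_auto_jit("b/000000000")  # relies on TF classic semantics
  def testClassicOnlyBehavior(self):
    self.assertAllEqual([1.0, 2.0], [1.0, 2.0])


if __name__ == "__main__":
  test.main()

When XLA auto-jit is forced on for the test run, the decorated test body is skipped (the decorator returns early via is_xla_enabled()); otherwise the test runs normally.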