Add test_util.no_xla_auto_jit decorator to more tests.
Also add the decorator to tensorflow/python:layout_optimizer_test_gpu. For convenience, add a for_all_test_methods decorator, which applies a given decorator to all test methods of a class. PiperOrigin-RevId: 245904023
This commit is contained in:
parent
f3f791ee9c
commit
795f0849f5
@ -1554,20 +1554,31 @@ def disable_xla(description):
|
|||||||
return disable_xla_impl
|
return disable_xla_impl
|
||||||
|
|
||||||
|
|
||||||
# The description is just for documentation purposes.
|
def for_all_test_methods(decorator, *args, **kwargs):
|
||||||
def disable_all_xla(description):
|
"""Generate class-level decorator from given method-level decorator.
|
||||||
|
|
||||||
def disable_all_impl(cls):
|
It is expected for the given decorator to take some arguments and return
|
||||||
"""Execute all test methods in this class only if xla is not enabled."""
|
a method that is then called on the test method to produce a decorated
|
||||||
base_decorator = disable_xla
|
method.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
decorator: The decorator to apply.
|
||||||
|
*args: Positional arguments
|
||||||
|
**kwargs: Keyword arguments
|
||||||
|
Returns: Function that will decorate a given classes test methods with the
|
||||||
|
decorator.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def all_test_methods_impl(cls):
|
||||||
|
"""Apply decorator to all test methods in class."""
|
||||||
for name in dir(cls):
|
for name in dir(cls):
|
||||||
value = getattr(cls, name)
|
value = getattr(cls, name)
|
||||||
if callable(value) and name.startswith(
|
if callable(value) and name.startswith(
|
||||||
"test") and not name == "test_session":
|
"test") and (name != "test_session"):
|
||||||
setattr(cls, name, base_decorator(description)(value))
|
setattr(cls, name, decorator(*args, **kwargs)(value))
|
||||||
return cls
|
return cls
|
||||||
|
|
||||||
return disable_all_impl
|
return all_test_methods_impl
|
||||||
|
|
||||||
|
|
||||||
# The description is just for documentation purposes.
|
# The description is just for documentation purposes.
|
||||||
|
@ -201,6 +201,8 @@ def _is_permute(node):
|
|||||||
'VecPermuteNCHWToNHWC-LayoutOptimizer')
|
'VecPermuteNCHWToNHWC-LayoutOptimizer')
|
||||||
|
|
||||||
|
|
||||||
|
@test_util.for_all_test_methods(test_util.no_xla_auto_jit,
|
||||||
|
'Test does not apply in XLA setting')
|
||||||
class LayoutOptimizerTest(test.TestCase):
|
class LayoutOptimizerTest(test.TestCase):
|
||||||
"""Tests the Grappler layout optimizer."""
|
"""Tests the Grappler layout optimizer."""
|
||||||
|
|
||||||
|
@ -421,7 +421,8 @@ class ZeroPaddingTest(keras_parameterized.TestCase):
|
|||||||
keras.layers.ZeroPadding3D(padding=None)
|
keras.layers.ZeroPadding3D(padding=None)
|
||||||
|
|
||||||
|
|
||||||
@test_util.disable_all_xla('align_corners=False not supported by XLA')
|
@test_util.for_all_test_methods(test_util.disable_xla,
|
||||||
|
'align_corners=False not supported by XLA')
|
||||||
@keras_parameterized.run_all_keras_modes
|
@keras_parameterized.run_all_keras_modes
|
||||||
class UpSamplingTest(keras_parameterized.TestCase):
|
class UpSamplingTest(keras_parameterized.TestCase):
|
||||||
|
|
||||||
|
@ -257,7 +257,8 @@ class TruncatedNormalTest(test.TestCase):
|
|||||||
self.assertAllEqual(rnd1, rnd2)
|
self.assertAllEqual(rnd1, rnd2)
|
||||||
|
|
||||||
|
|
||||||
@test_util.disable_all_xla("This never passed on XLA")
|
@test_util.for_all_test_methods(test_util.disable_xla,
|
||||||
|
"This never passed on XLA")
|
||||||
class RandomUniformTest(RandomOpTestCommon):
|
class RandomUniformTest(RandomOpTestCommon):
|
||||||
|
|
||||||
def _Sampler(self, num, minv, maxv, dtype, use_gpu, seed=None):
|
def _Sampler(self, num, minv, maxv, dtype, use_gpu, seed=None):
|
||||||
|
@ -29,7 +29,8 @@ from tensorflow.python.ops import image_ops
|
|||||||
from tensorflow.python.platform import test
|
from tensorflow.python.platform import test
|
||||||
|
|
||||||
|
|
||||||
@test_util.disable_all_xla('align_corners=False not supported by XLA')
|
@test_util.for_all_test_methods(test_util.disable_xla,
|
||||||
|
'align_corners=False not supported by XLA')
|
||||||
class ResizeNearestNeighborOpTest(test.TestCase):
|
class ResizeNearestNeighborOpTest(test.TestCase):
|
||||||
|
|
||||||
TYPES = [np.float32, np.float64]
|
TYPES = [np.float32, np.float64]
|
||||||
|
@ -3349,7 +3349,7 @@ class ResizeImageWithPadV1Test(test_util.TensorFlowTestCase):
|
|||||||
|
|
||||||
|
|
||||||
# half_pixel_centers not supported by XLA
|
# half_pixel_centers not supported by XLA
|
||||||
@test_util.disable_all_xla("b/127616992")
|
@test_util.for_all_test_methods(test_util.disable_xla, "b/127616992")
|
||||||
class ResizeImageWithPadV2Test(test_util.TensorFlowTestCase):
|
class ResizeImageWithPadV2Test(test_util.TensorFlowTestCase):
|
||||||
|
|
||||||
def _ResizeImageWithPad(self, x, target_height, target_width,
|
def _ResizeImageWithPad(self, x, target_height, target_width,
|
||||||
|
Loading…
Reference in New Issue
Block a user