diff --git a/tensorflow/python/framework/test_util.py b/tensorflow/python/framework/test_util.py index 5210c2e2d32..65f4257313a 100644 --- a/tensorflow/python/framework/test_util.py +++ b/tensorflow/python/framework/test_util.py @@ -1554,20 +1554,31 @@ def disable_xla(description): return disable_xla_impl -# The description is just for documentation purposes. -def disable_all_xla(description): +def for_all_test_methods(decorator, *args, **kwargs): + """Generate class-level decorator from given method-level decorator. - def disable_all_impl(cls): - """Execute all test methods in this class only if xla is not enabled.""" - base_decorator = disable_xla + It is expected for the given decorator to take some arguments and return + a method that is then called on the test method to produce a decorated + method. + + Args: + decorator: The decorator to apply. + *args: Positional arguments + **kwargs: Keyword arguments + Returns: Function that will decorate a given classes test methods with the + decorator. + """ + + def all_test_methods_impl(cls): + """Apply decorator to all test methods in class.""" for name in dir(cls): value = getattr(cls, name) if callable(value) and name.startswith( - "test") and not name == "test_session": - setattr(cls, name, base_decorator(description)(value)) + "test") and (name != "test_session"): + setattr(cls, name, decorator(*args, **kwargs)(value)) return cls - return disable_all_impl + return all_test_methods_impl # The description is just for documentation purposes. diff --git a/tensorflow/python/grappler/layout_optimizer_test.py b/tensorflow/python/grappler/layout_optimizer_test.py index 3b6d2ce26af..7a7761b5174 100644 --- a/tensorflow/python/grappler/layout_optimizer_test.py +++ b/tensorflow/python/grappler/layout_optimizer_test.py @@ -201,6 +201,8 @@ def _is_permute(node): 'VecPermuteNCHWToNHWC-LayoutOptimizer') +@test_util.for_all_test_methods(test_util.no_xla_auto_jit, + 'Test does not apply in XLA setting') class LayoutOptimizerTest(test.TestCase): """Tests the Grappler layout optimizer.""" diff --git a/tensorflow/python/keras/layers/convolutional_test.py b/tensorflow/python/keras/layers/convolutional_test.py index ea088783606..82238fad694 100644 --- a/tensorflow/python/keras/layers/convolutional_test.py +++ b/tensorflow/python/keras/layers/convolutional_test.py @@ -421,7 +421,8 @@ class ZeroPaddingTest(keras_parameterized.TestCase): keras.layers.ZeroPadding3D(padding=None) -@test_util.disable_all_xla('align_corners=False not supported by XLA') +@test_util.for_all_test_methods(test_util.disable_xla, + 'align_corners=False not supported by XLA') @keras_parameterized.run_all_keras_modes class UpSamplingTest(keras_parameterized.TestCase): diff --git a/tensorflow/python/kernel_tests/random/random_ops_test.py b/tensorflow/python/kernel_tests/random/random_ops_test.py index 68672a04bbd..1023b8f7901 100644 --- a/tensorflow/python/kernel_tests/random/random_ops_test.py +++ b/tensorflow/python/kernel_tests/random/random_ops_test.py @@ -257,7 +257,8 @@ class TruncatedNormalTest(test.TestCase): self.assertAllEqual(rnd1, rnd2) -@test_util.disable_all_xla("This never passed on XLA") +@test_util.for_all_test_methods(test_util.disable_xla, + "This never passed on XLA") class RandomUniformTest(RandomOpTestCommon): def _Sampler(self, num, minv, maxv, dtype, use_gpu, seed=None): diff --git a/tensorflow/python/ops/image_grad_test.py b/tensorflow/python/ops/image_grad_test.py index ea41ea39f98..43d9699980e 100644 --- a/tensorflow/python/ops/image_grad_test.py +++ b/tensorflow/python/ops/image_grad_test.py @@ -29,7 +29,8 @@ from tensorflow.python.ops import image_ops from tensorflow.python.platform import test -@test_util.disable_all_xla('align_corners=False not supported by XLA') +@test_util.for_all_test_methods(test_util.disable_xla, + 'align_corners=False not supported by XLA') class ResizeNearestNeighborOpTest(test.TestCase): TYPES = [np.float32, np.float64] diff --git a/tensorflow/python/ops/image_ops_test.py b/tensorflow/python/ops/image_ops_test.py index c55590a0f98..3526e0eed2b 100644 --- a/tensorflow/python/ops/image_ops_test.py +++ b/tensorflow/python/ops/image_ops_test.py @@ -3349,7 +3349,7 @@ class ResizeImageWithPadV1Test(test_util.TensorFlowTestCase): # half_pixel_centers not supported by XLA -@test_util.disable_all_xla("b/127616992") +@test_util.for_all_test_methods(test_util.disable_xla, "b/127616992") class ResizeImageWithPadV2Test(test_util.TensorFlowTestCase): def _ResizeImageWithPad(self, x, target_height, target_width,