diff --git a/tensorflow/python/ops/parallel_for/BUILD b/tensorflow/python/ops/parallel_for/BUILD index 6bc33e10a23..88ddf7a7ec8 100644 --- a/tensorflow/python/ops/parallel_for/BUILD +++ b/tensorflow/python/ops/parallel_for/BUILD @@ -109,8 +109,6 @@ cuda_py_test( name = "control_flow_ops_test", srcs = ["control_flow_ops_test.py"], tags = ["no_rocm"], - # TODO(b/149957923): The test is flaky - xla_enable_strict_auto_jit = False, deps = [ ":control_flow_ops", ":test_util", diff --git a/tensorflow/python/ops/parallel_for/control_flow_ops_test.py b/tensorflow/python/ops/parallel_for/control_flow_ops_test.py index e6a67efa301..e33b7765ab1 100644 --- a/tensorflow/python/ops/parallel_for/control_flow_ops_test.py +++ b/tensorflow/python/ops/parallel_for/control_flow_ops_test.py @@ -31,6 +31,7 @@ from tensorflow.core.example import feature_pb2 from tensorflow.python.client import session from tensorflow.python.eager import backprop from tensorflow.python.eager import def_function +from tensorflow.python.framework import config from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import indexed_slices @@ -639,7 +640,7 @@ class RandomTest(PForTestCase): # The random values generated in the two implementations are not guaranteed to # match. So we only check the returned shapes. - def run_and_assert_equal(self, targets1, targets2): + def run_and_assert_equal(self, targets1, targets2, rtol=1e-4, atol=1e-5): outputs = self._run_targets(targets1, targets2) n = len(outputs) // 2 for i in range(n): @@ -737,7 +738,7 @@ class StatelessRandomTest(PForTestCase): # stateless random numbers can generate different random numbers. # TODO(agarwal): switch to checking for actual values matching once # b/149402339 is resolved. 
- def run_and_assert_equal(self, targets1, targets2): + def run_and_assert_equal(self, targets1, targets2, rtol=1e-4, atol=1e-5): outputs = self._run_targets(targets1, targets2) n = len(outputs) // 2 for i in range(n): @@ -1735,8 +1736,10 @@ class SpectralTest(PForTestCase, parameterized.TestCase): (fft_ops.irfft2d,), (fft_ops.irfft3d,), ) - # TODO(agarwal): Reenable this once the test flaky is fixed. - def disabled_test_irfft(self, op_func): + def test_irfft(self, op_func): + if config.list_physical_devices("GPU"): + # TODO(b/149957923): The test is flaky + self.skipTest("b/149957923: irfft vectorization flaky") for dtype in (dtypes.complex64, dtypes.complex128): shape = [2, 3, 4, 3, 4] x = np.random.uniform(size=shape) + 1j * np.random.uniform(size=shape) diff --git a/tensorflow/python/ops/parallel_for/test_util.py b/tensorflow/python/ops/parallel_for/test_util.py index c8eed9ca54e..7d8a3d86a77 100644 --- a/tensorflow/python/ops/parallel_for/test_util.py +++ b/tensorflow/python/ops/parallel_for/test_util.py @@ -39,20 +39,24 @@ class PForTestCase(test.TestCase): return self.evaluate(targets1 + targets2) - # TODO(agarwal): Allow tests to pass down tolerances. 
- def run_and_assert_equal(self, targets1, targets2): + def run_and_assert_equal(self, targets1, targets2, rtol=1e-4, atol=1e-5): outputs = self._run_targets(targets1, targets2) outputs = nest.flatten(outputs) # flatten SparseTensorValues n = len(outputs) // 2 for i in range(n): if outputs[i + n].dtype != np.object: - self.assertAllClose(outputs[i + n], outputs[i], rtol=1e-4, atol=1e-4) + self.assertAllClose(outputs[i + n], outputs[i], rtol=rtol, atol=atol) else: self.assertAllEqual(outputs[i + n], outputs[i]) - def _test_loop_fn(self, loop_fn, iters, parallel_iterations=None): + def _test_loop_fn(self, + loop_fn, + iters, + parallel_iterations=None, + rtol=1e-4, + atol=1e-5): t1 = pfor_control_flow_ops.pfor(loop_fn, iters=iters, parallel_iterations=parallel_iterations) loop_fn_dtypes = nest.map_structure(lambda x: x.dtype, t1) t2 = pfor_control_flow_ops.for_loop(loop_fn, loop_fn_dtypes, iters=iters, parallel_iterations=parallel_iterations) - self.run_and_assert_equal(t1, t2) + self.run_and_assert_equal(t1, t2, rtol=rtol, atol=atol)