From ac3f66b6549d672ffd63d24712a1b51806cf37d6 Mon Sep 17 00:00:00 2001
From: "A. Unique TensorFlower" <gardener@tensorflow.org>
Date: Mon, 23 Mar 2020 18:21:18 -0700
Subject: [PATCH] Adds tolerance arguments to pfor test functions.

PiperOrigin-RevId: 302563427
Change-Id: I68165eb6052edfaf477ebea28bcc8f664cf8234f
---
 tensorflow/python/ops/parallel_for/BUILD            |  2 --
 .../ops/parallel_for/control_flow_ops_test.py       | 11 +++++++----
 tensorflow/python/ops/parallel_for/test_util.py     | 13 +++++++++----
 3 files changed, 16 insertions(+), 10 deletions(-)

diff --git a/tensorflow/python/ops/parallel_for/BUILD b/tensorflow/python/ops/parallel_for/BUILD
index 6bc33e10a23..88ddf7a7ec8 100644
--- a/tensorflow/python/ops/parallel_for/BUILD
+++ b/tensorflow/python/ops/parallel_for/BUILD
@@ -109,8 +109,6 @@ cuda_py_test(
     name = "control_flow_ops_test",
     srcs = ["control_flow_ops_test.py"],
     tags = ["no_rocm"],
-    # TODO(b/149957923): The test is flaky
-    xla_enable_strict_auto_jit = False,
     deps = [
         ":control_flow_ops",
         ":test_util",
diff --git a/tensorflow/python/ops/parallel_for/control_flow_ops_test.py b/tensorflow/python/ops/parallel_for/control_flow_ops_test.py
index e6a67efa301..e33b7765ab1 100644
--- a/tensorflow/python/ops/parallel_for/control_flow_ops_test.py
+++ b/tensorflow/python/ops/parallel_for/control_flow_ops_test.py
@@ -31,6 +31,7 @@ from tensorflow.core.example import feature_pb2
 from tensorflow.python.client import session
 from tensorflow.python.eager import backprop
 from tensorflow.python.eager import def_function
+from tensorflow.python.framework import config
 from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import indexed_slices
@@ -639,7 +640,7 @@ class RandomTest(PForTestCase):
 
   # The random values generated in the two implementations are not guaranteed to
   # match. So we only check the returned shapes.
-  def run_and_assert_equal(self, targets1, targets2):
+  def run_and_assert_equal(self, targets1, targets2, rtol=1e-4, atol=1e-5):
     outputs = self._run_targets(targets1, targets2)
     n = len(outputs) // 2
     for i in range(n):
@@ -737,7 +738,7 @@ class StatelessRandomTest(PForTestCase):
   # stateless random numbers can generate different random numbers.
   # TODO(agarwal): switch to checking for actual values matching once
   # b/149402339 is resolved.
-  def run_and_assert_equal(self, targets1, targets2):
+  def run_and_assert_equal(self, targets1, targets2, rtol=1e-4, atol=1e-5):
     outputs = self._run_targets(targets1, targets2)
     n = len(outputs) // 2
     for i in range(n):
@@ -1735,8 +1736,10 @@ class SpectralTest(PForTestCase, parameterized.TestCase):
       (fft_ops.irfft2d,),
       (fft_ops.irfft3d,),
   )
-  # TODO(agarwal): Reenable this once the test flaky is fixed.
-  def disabled_test_irfft(self, op_func):
+  def test_irfft(self, op_func):
+    if config.list_physical_devices("GPU"):
+      # TODO(b/149957923): The test is flaky
+      self.skipTest("b/149957923: irfft vectorization flaky")
     for dtype in (dtypes.complex64, dtypes.complex128):
       shape = [2, 3, 4, 3, 4]
       x = np.random.uniform(size=shape) + 1j * np.random.uniform(size=shape)
diff --git a/tensorflow/python/ops/parallel_for/test_util.py b/tensorflow/python/ops/parallel_for/test_util.py
index c8eed9ca54e..7d8a3d86a77 100644
--- a/tensorflow/python/ops/parallel_for/test_util.py
+++ b/tensorflow/python/ops/parallel_for/test_util.py
@@ -39,20 +39,25 @@ class PForTestCase(test.TestCase):
     return self.evaluate(targets1 + targets2)
 
   # TODO(agarwal): Allow tests to pass down tolerances.
-  def run_and_assert_equal(self, targets1, targets2):
+  def run_and_assert_equal(self, targets1, targets2, rtol=1e-4, atol=1e-5):
     outputs = self._run_targets(targets1, targets2)
     outputs = nest.flatten(outputs)  # flatten SparseTensorValues
     n = len(outputs) // 2
     for i in range(n):
       if outputs[i + n].dtype != np.object:
-        self.assertAllClose(outputs[i + n], outputs[i], rtol=1e-4, atol=1e-4)
+        self.assertAllClose(outputs[i + n], outputs[i], rtol=rtol, atol=atol)
       else:
         self.assertAllEqual(outputs[i + n], outputs[i])
 
-  def _test_loop_fn(self, loop_fn, iters, parallel_iterations=None):
+  def _test_loop_fn(self,
+                    loop_fn,
+                    iters,
+                    parallel_iterations=None,
+                    rtol=1e-4,
+                    atol=1e-5):
     t1 = pfor_control_flow_ops.pfor(loop_fn, iters=iters,
                                     parallel_iterations=parallel_iterations)
     loop_fn_dtypes = nest.map_structure(lambda x: x.dtype, t1)
     t2 = pfor_control_flow_ops.for_loop(loop_fn, loop_fn_dtypes, iters=iters,
                                         parallel_iterations=parallel_iterations)
-    self.run_and_assert_equal(t1, t2)
+    self.run_and_assert_equal(t1, t2, rtol=rtol, atol=atol)