Adds tolerance arguments to pfor test functions.

PiperOrigin-RevId: 302563427
Change-Id: I68165eb6052edfaf477ebea28bcc8f664cf8234f
This commit is contained in:
A. Unique TensorFlower 2020-03-23 18:21:18 -07:00 committed by TensorFlower Gardener
parent f748283ee0
commit ac3f66b654
3 changed files with 16 additions and 10 deletions
tensorflow/python/ops/parallel_for

View File

@ -109,8 +109,6 @@ cuda_py_test(
name = "control_flow_ops_test",
srcs = ["control_flow_ops_test.py"],
tags = ["no_rocm"],
# TODO(b/149957923): The test is flaky
xla_enable_strict_auto_jit = False,
deps = [
":control_flow_ops",
":test_util",

View File

@ -31,6 +31,7 @@ from tensorflow.core.example import feature_pb2
from tensorflow.python.client import session
from tensorflow.python.eager import backprop
from tensorflow.python.eager import def_function
from tensorflow.python.framework import config
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import indexed_slices
@ -639,7 +640,7 @@ class RandomTest(PForTestCase):
# The random values generated in the two implementations are not guaranteed to
# match. So we only check the returned shapes.
def run_and_assert_equal(self, targets1, targets2):
def run_and_assert_equal(self, targets1, targets2, rtol=1e-4, atol=1e-5):
outputs = self._run_targets(targets1, targets2)
n = len(outputs) // 2
for i in range(n):
@ -737,7 +738,7 @@ class StatelessRandomTest(PForTestCase):
# stateless random numbers can generate different random numbers.
# TODO(agarwal): switch to checking for actual values matching once
# b/149402339 is resolved.
def run_and_assert_equal(self, targets1, targets2):
def run_and_assert_equal(self, targets1, targets2, rtol=1e-4, atol=1e-5):
outputs = self._run_targets(targets1, targets2)
n = len(outputs) // 2
for i in range(n):
@ -1735,8 +1736,10 @@ class SpectralTest(PForTestCase, parameterized.TestCase):
(fft_ops.irfft2d,),
(fft_ops.irfft3d,),
)
# TODO(agarwal): Reenable this once the test flaky is fixed.
def disabled_test_irfft(self, op_func):
def test_irfft(self, op_func):
if config.list_physical_devices("GPU"):
# TODO(b/149957923): The test is flaky
self.skipTest("b/149957923: irfft vectorization flaky")
for dtype in (dtypes.complex64, dtypes.complex128):
shape = [2, 3, 4, 3, 4]
x = np.random.uniform(size=shape) + 1j * np.random.uniform(size=shape)

View File

@ -39,20 +39,25 @@ class PForTestCase(test.TestCase):
return self.evaluate(targets1 + targets2)
# TODO(agarwal): Allow tests to pass down tolerances.
def run_and_assert_equal(self, targets1, targets2):
def run_and_assert_equal(self, targets1, targets2, rtol=1e-4, atol=1e-5):
outputs = self._run_targets(targets1, targets2)
outputs = nest.flatten(outputs) # flatten SparseTensorValues
n = len(outputs) // 2
for i in range(n):
if outputs[i + n].dtype != np.object:
self.assertAllClose(outputs[i + n], outputs[i], rtol=1e-4, atol=1e-4)
self.assertAllClose(outputs[i + n], outputs[i], rtol=rtol, atol=atol)
else:
self.assertAllEqual(outputs[i + n], outputs[i])
def _test_loop_fn(self, loop_fn, iters, parallel_iterations=None):
def _test_loop_fn(self,
loop_fn,
iters,
parallel_iterations=None,
rtol=1e-4,
atol=1e-5):
t1 = pfor_control_flow_ops.pfor(loop_fn, iters=iters,
parallel_iterations=parallel_iterations)
loop_fn_dtypes = nest.map_structure(lambda x: x.dtype, t1)
t2 = pfor_control_flow_ops.for_loop(loop_fn, loop_fn_dtypes, iters=iters,
parallel_iterations=parallel_iterations)
self.run_and_assert_equal(t1, t2)
self.run_and_assert_equal(t1, t2, rtol=rtol, atol=atol)