Adds tolerance arguments to pfor test functions.
PiperOrigin-RevId: 302563427 Change-Id: I68165eb6052edfaf477ebea28bcc8f664cf8234f
This commit is contained in:
parent
f748283ee0
commit
ac3f66b654
tensorflow/python/ops/parallel_for
@ -109,8 +109,6 @@ cuda_py_test(
|
||||
name = "control_flow_ops_test",
|
||||
srcs = ["control_flow_ops_test.py"],
|
||||
tags = ["no_rocm"],
|
||||
# TODO(b/149957923): The test is flaky
|
||||
xla_enable_strict_auto_jit = False,
|
||||
deps = [
|
||||
":control_flow_ops",
|
||||
":test_util",
|
||||
|
@ -31,6 +31,7 @@ from tensorflow.core.example import feature_pb2
|
||||
from tensorflow.python.client import session
|
||||
from tensorflow.python.eager import backprop
|
||||
from tensorflow.python.eager import def_function
|
||||
from tensorflow.python.framework import config
|
||||
from tensorflow.python.framework import constant_op
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import indexed_slices
|
||||
@ -639,7 +640,7 @@ class RandomTest(PForTestCase):
|
||||
|
||||
# The random values generated in the two implementations are not guaranteed to
|
||||
# match. So we only check the returned shapes.
|
||||
def run_and_assert_equal(self, targets1, targets2):
|
||||
def run_and_assert_equal(self, targets1, targets2, rtol=1e-4, atol=1e-5):
|
||||
outputs = self._run_targets(targets1, targets2)
|
||||
n = len(outputs) // 2
|
||||
for i in range(n):
|
||||
@ -737,7 +738,7 @@ class StatelessRandomTest(PForTestCase):
|
||||
# stateless random numbers can generate different random numbers.
|
||||
# TODO(agarwal): switch to checking for actual values matching once
|
||||
# b/149402339 is resolved.
|
||||
def run_and_assert_equal(self, targets1, targets2):
|
||||
def run_and_assert_equal(self, targets1, targets2, rtol=1e-4, atol=1e-5):
|
||||
outputs = self._run_targets(targets1, targets2)
|
||||
n = len(outputs) // 2
|
||||
for i in range(n):
|
||||
@ -1735,8 +1736,10 @@ class SpectralTest(PForTestCase, parameterized.TestCase):
|
||||
(fft_ops.irfft2d,),
|
||||
(fft_ops.irfft3d,),
|
||||
)
|
||||
# TODO(agarwal): Reenable this once the test flaky is fixed.
|
||||
def disabled_test_irfft(self, op_func):
|
||||
def test_irfft(self, op_func):
|
||||
if config.list_physical_devices("GPU"):
|
||||
# TODO(b/149957923): The test is flaky
|
||||
self.skipTest("b/149957923: irfft vectorization flaky")
|
||||
for dtype in (dtypes.complex64, dtypes.complex128):
|
||||
shape = [2, 3, 4, 3, 4]
|
||||
x = np.random.uniform(size=shape) + 1j * np.random.uniform(size=shape)
|
||||
|
@ -39,20 +39,25 @@ class PForTestCase(test.TestCase):
|
||||
return self.evaluate(targets1 + targets2)
|
||||
|
||||
# TODO(agarwal): Allow tests to pass down tolerances.
|
||||
def run_and_assert_equal(self, targets1, targets2):
|
||||
def run_and_assert_equal(self, targets1, targets2, rtol=1e-4, atol=1e-5):
|
||||
outputs = self._run_targets(targets1, targets2)
|
||||
outputs = nest.flatten(outputs) # flatten SparseTensorValues
|
||||
n = len(outputs) // 2
|
||||
for i in range(n):
|
||||
if outputs[i + n].dtype != np.object:
|
||||
self.assertAllClose(outputs[i + n], outputs[i], rtol=1e-4, atol=1e-4)
|
||||
self.assertAllClose(outputs[i + n], outputs[i], rtol=rtol, atol=atol)
|
||||
else:
|
||||
self.assertAllEqual(outputs[i + n], outputs[i])
|
||||
|
||||
def _test_loop_fn(self, loop_fn, iters, parallel_iterations=None):
|
||||
def _test_loop_fn(self,
|
||||
loop_fn,
|
||||
iters,
|
||||
parallel_iterations=None,
|
||||
rtol=1e-4,
|
||||
atol=1e-5):
|
||||
t1 = pfor_control_flow_ops.pfor(loop_fn, iters=iters,
|
||||
parallel_iterations=parallel_iterations)
|
||||
loop_fn_dtypes = nest.map_structure(lambda x: x.dtype, t1)
|
||||
t2 = pfor_control_flow_ops.for_loop(loop_fn, loop_fn_dtypes, iters=iters,
|
||||
parallel_iterations=parallel_iterations)
|
||||
self.run_and_assert_equal(t1, t2)
|
||||
self.run_and_assert_equal(t1, t2, rtol=rtol, atol=atol)
|
||||
|
Loading…
Reference in New Issue
Block a user