Adds tolerance arguments to pfor test functions.
PiperOrigin-RevId: 302563427 Change-Id: I68165eb6052edfaf477ebea28bcc8f664cf8234f
This commit is contained in:
parent
f748283ee0
commit
ac3f66b654
@ -109,8 +109,6 @@ cuda_py_test(
|
|||||||
name = "control_flow_ops_test",
|
name = "control_flow_ops_test",
|
||||||
srcs = ["control_flow_ops_test.py"],
|
srcs = ["control_flow_ops_test.py"],
|
||||||
tags = ["no_rocm"],
|
tags = ["no_rocm"],
|
||||||
# TODO(b/149957923): The test is flaky
|
|
||||||
xla_enable_strict_auto_jit = False,
|
|
||||||
deps = [
|
deps = [
|
||||||
":control_flow_ops",
|
":control_flow_ops",
|
||||||
":test_util",
|
":test_util",
|
||||||
|
@ -31,6 +31,7 @@ from tensorflow.core.example import feature_pb2
|
|||||||
from tensorflow.python.client import session
|
from tensorflow.python.client import session
|
||||||
from tensorflow.python.eager import backprop
|
from tensorflow.python.eager import backprop
|
||||||
from tensorflow.python.eager import def_function
|
from tensorflow.python.eager import def_function
|
||||||
|
from tensorflow.python.framework import config
|
||||||
from tensorflow.python.framework import constant_op
|
from tensorflow.python.framework import constant_op
|
||||||
from tensorflow.python.framework import dtypes
|
from tensorflow.python.framework import dtypes
|
||||||
from tensorflow.python.framework import indexed_slices
|
from tensorflow.python.framework import indexed_slices
|
||||||
@ -639,7 +640,7 @@ class RandomTest(PForTestCase):
|
|||||||
|
|
||||||
# The random values generated in the two implementations are not guaranteed to
|
# The random values generated in the two implementations are not guaranteed to
|
||||||
# match. So we only check the returned shapes.
|
# match. So we only check the returned shapes.
|
||||||
def run_and_assert_equal(self, targets1, targets2):
|
def run_and_assert_equal(self, targets1, targets2, rtol=1e-4, atol=1e-5):
|
||||||
outputs = self._run_targets(targets1, targets2)
|
outputs = self._run_targets(targets1, targets2)
|
||||||
n = len(outputs) // 2
|
n = len(outputs) // 2
|
||||||
for i in range(n):
|
for i in range(n):
|
||||||
@ -737,7 +738,7 @@ class StatelessRandomTest(PForTestCase):
|
|||||||
# stateless random numbers can generate different random numbers.
|
# stateless random numbers can generate different random numbers.
|
||||||
# TODO(agarwal): switch to checking for actual values matching once
|
# TODO(agarwal): switch to checking for actual values matching once
|
||||||
# b/149402339 is resolved.
|
# b/149402339 is resolved.
|
||||||
def run_and_assert_equal(self, targets1, targets2):
|
def run_and_assert_equal(self, targets1, targets2, rtol=1e-4, atol=1e-5):
|
||||||
outputs = self._run_targets(targets1, targets2)
|
outputs = self._run_targets(targets1, targets2)
|
||||||
n = len(outputs) // 2
|
n = len(outputs) // 2
|
||||||
for i in range(n):
|
for i in range(n):
|
||||||
@ -1735,8 +1736,10 @@ class SpectralTest(PForTestCase, parameterized.TestCase):
|
|||||||
(fft_ops.irfft2d,),
|
(fft_ops.irfft2d,),
|
||||||
(fft_ops.irfft3d,),
|
(fft_ops.irfft3d,),
|
||||||
)
|
)
|
||||||
# TODO(agarwal): Reenable this once the test flaky is fixed.
|
def test_irfft(self, op_func):
|
||||||
def disabled_test_irfft(self, op_func):
|
if config.list_physical_devices("GPU"):
|
||||||
|
# TODO(b/149957923): The test is flaky
|
||||||
|
self.skipTest("b/149957923: irfft vectorization flaky")
|
||||||
for dtype in (dtypes.complex64, dtypes.complex128):
|
for dtype in (dtypes.complex64, dtypes.complex128):
|
||||||
shape = [2, 3, 4, 3, 4]
|
shape = [2, 3, 4, 3, 4]
|
||||||
x = np.random.uniform(size=shape) + 1j * np.random.uniform(size=shape)
|
x = np.random.uniform(size=shape) + 1j * np.random.uniform(size=shape)
|
||||||
|
@ -39,20 +39,25 @@ class PForTestCase(test.TestCase):
|
|||||||
return self.evaluate(targets1 + targets2)
|
return self.evaluate(targets1 + targets2)
|
||||||
|
|
||||||
# TODO(agarwal): Allow tests to pass down tolerances.
|
# TODO(agarwal): Allow tests to pass down tolerances.
|
||||||
def run_and_assert_equal(self, targets1, targets2):
|
def run_and_assert_equal(self, targets1, targets2, rtol=1e-4, atol=1e-5):
|
||||||
outputs = self._run_targets(targets1, targets2)
|
outputs = self._run_targets(targets1, targets2)
|
||||||
outputs = nest.flatten(outputs) # flatten SparseTensorValues
|
outputs = nest.flatten(outputs) # flatten SparseTensorValues
|
||||||
n = len(outputs) // 2
|
n = len(outputs) // 2
|
||||||
for i in range(n):
|
for i in range(n):
|
||||||
if outputs[i + n].dtype != np.object:
|
if outputs[i + n].dtype != np.object:
|
||||||
self.assertAllClose(outputs[i + n], outputs[i], rtol=1e-4, atol=1e-4)
|
self.assertAllClose(outputs[i + n], outputs[i], rtol=rtol, atol=atol)
|
||||||
else:
|
else:
|
||||||
self.assertAllEqual(outputs[i + n], outputs[i])
|
self.assertAllEqual(outputs[i + n], outputs[i])
|
||||||
|
|
||||||
def _test_loop_fn(self, loop_fn, iters, parallel_iterations=None):
|
def _test_loop_fn(self,
|
||||||
|
loop_fn,
|
||||||
|
iters,
|
||||||
|
parallel_iterations=None,
|
||||||
|
rtol=1e-4,
|
||||||
|
atol=1e-5):
|
||||||
t1 = pfor_control_flow_ops.pfor(loop_fn, iters=iters,
|
t1 = pfor_control_flow_ops.pfor(loop_fn, iters=iters,
|
||||||
parallel_iterations=parallel_iterations)
|
parallel_iterations=parallel_iterations)
|
||||||
loop_fn_dtypes = nest.map_structure(lambda x: x.dtype, t1)
|
loop_fn_dtypes = nest.map_structure(lambda x: x.dtype, t1)
|
||||||
t2 = pfor_control_flow_ops.for_loop(loop_fn, loop_fn_dtypes, iters=iters,
|
t2 = pfor_control_flow_ops.for_loop(loop_fn, loop_fn_dtypes, iters=iters,
|
||||||
parallel_iterations=parallel_iterations)
|
parallel_iterations=parallel_iterations)
|
||||||
self.run_and_assert_equal(t1, t2)
|
self.run_and_assert_equal(t1, t2, rtol=rtol, atol=atol)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user