Convert linspace tests to parameterized tests
Helps with TSAN timeouts. PiperOrigin-RevId: 333195144 Change-Id: Iceb62706c9664ef14cb78f59919631af81e30389
This commit is contained in:
parent
070efb9ebf
commit
4c71606397
tensorflow/python
@ -5282,7 +5282,6 @@ cuda_py_test(
|
||||
python_version = "PY3",
|
||||
tags = [
|
||||
"no_windows_gpu",
|
||||
"notsan", # b/168815578
|
||||
],
|
||||
deps = [
|
||||
":framework_for_generated_wrappers",
|
||||
|
@ -21,8 +21,8 @@ from __future__ import print_function
|
||||
# Using distutils.version.LooseVersion was resulting in an error, so importing
|
||||
# directly.
|
||||
from distutils.version import LooseVersion # pylint: disable=g-importing-member
|
||||
import itertools
|
||||
|
||||
from absl.testing import parameterized
|
||||
import numpy as np
|
||||
|
||||
from tensorflow.python.framework import test_util
|
||||
@ -31,29 +31,36 @@ from tensorflow.python.platform import googletest
|
||||
|
||||
|
||||
@test_util.run_all_in_graph_and_eager_modes
|
||||
class LinspaceTest(test_util.TensorFlowTestCase):
|
||||
class LinspaceTest(test_util.TensorFlowTestCase, parameterized.TestCase):
|
||||
|
||||
def testLinspaceBroadcasts(self):
|
||||
# pylint: disable=g-complex-comprehension
|
||||
@parameterized.parameters([
|
||||
{
|
||||
"start_shape": start_shape,
|
||||
"stop_shape": stop_shape,
|
||||
"dtype": dtype,
|
||||
"num": num
|
||||
}
|
||||
for start_shape in [(), (2,), (2, 2)]
|
||||
for stop_shape in [(), (2,), (2, 2)]
|
||||
for dtype in [np.float64, np.int64]
|
||||
for num in [0, 1, 2, 20]
|
||||
])
|
||||
# pylint: enable=g-complex-comprehension
|
||||
def testLinspaceBroadcasts(self, start_shape, stop_shape, dtype, num):
|
||||
if LooseVersion(np.version.version) < LooseVersion("1.16.0"):
|
||||
self.skipTest("numpy doesn't support axes before version 1.16.0")
|
||||
|
||||
shapes = [(), (2,), (2, 2)]
|
||||
ndims = max(len(start_shape), len(stop_shape))
|
||||
for axis in range(-ndims, ndims):
|
||||
start = np.ones(start_shape, dtype)
|
||||
stop = 10 * np.ones(stop_shape, dtype)
|
||||
|
||||
types = [np.float64, np.int64]
|
||||
np_ans = np.linspace(start, stop, num, axis=axis)
|
||||
tf_ans = self.evaluate(
|
||||
math_ops.linspace_nd(start, stop, num, axis=axis))
|
||||
|
||||
for start_shape, stop_shape in itertools.product(shapes, repeat=2):
|
||||
for num in [0, 1, 2, 20]:
|
||||
ndims = max(len(start_shape), len(stop_shape))
|
||||
for axis in range(-ndims, ndims):
|
||||
for dtype in types:
|
||||
start = np.ones(start_shape, dtype)
|
||||
stop = 10 * np.ones(stop_shape, dtype)
|
||||
|
||||
np_ans = np.linspace(start, stop, num, axis=axis)
|
||||
tf_ans = self.evaluate(
|
||||
math_ops.linspace_nd(start, stop, num, axis=axis))
|
||||
|
||||
self.assertAllClose(np_ans, tf_ans)
|
||||
self.assertAllClose(np_ans, tf_ans)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
Loading…
Reference in New Issue
Block a user