Convert linspace tests to parameterized tests
Helps with TSAN timeouts. PiperOrigin-RevId: 333195144 Change-Id: Iceb62706c9664ef14cb78f59919631af81e30389
This commit is contained in:
parent
070efb9ebf
commit
4c71606397
@ -5282,7 +5282,6 @@ cuda_py_test(
|
|||||||
python_version = "PY3",
|
python_version = "PY3",
|
||||||
tags = [
|
tags = [
|
||||||
"no_windows_gpu",
|
"no_windows_gpu",
|
||||||
"notsan", # b/168815578
|
|
||||||
],
|
],
|
||||||
deps = [
|
deps = [
|
||||||
":framework_for_generated_wrappers",
|
":framework_for_generated_wrappers",
|
||||||
|
@ -21,8 +21,8 @@ from __future__ import print_function
|
|||||||
# Using distutils.version.LooseVersion was resulting in an error, so importing
|
# Using distutils.version.LooseVersion was resulting in an error, so importing
|
||||||
# directly.
|
# directly.
|
||||||
from distutils.version import LooseVersion # pylint: disable=g-importing-member
|
from distutils.version import LooseVersion # pylint: disable=g-importing-member
|
||||||
import itertools
|
|
||||||
|
|
||||||
|
from absl.testing import parameterized
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
from tensorflow.python.framework import test_util
|
from tensorflow.python.framework import test_util
|
||||||
@ -31,29 +31,36 @@ from tensorflow.python.platform import googletest
|
|||||||
|
|
||||||
|
|
||||||
@test_util.run_all_in_graph_and_eager_modes
|
@test_util.run_all_in_graph_and_eager_modes
|
||||||
class LinspaceTest(test_util.TensorFlowTestCase):
|
class LinspaceTest(test_util.TensorFlowTestCase, parameterized.TestCase):
|
||||||
|
|
||||||
def testLinspaceBroadcasts(self):
|
# pylint: disable=g-complex-comprehension
|
||||||
|
@parameterized.parameters([
|
||||||
|
{
|
||||||
|
"start_shape": start_shape,
|
||||||
|
"stop_shape": stop_shape,
|
||||||
|
"dtype": dtype,
|
||||||
|
"num": num
|
||||||
|
}
|
||||||
|
for start_shape in [(), (2,), (2, 2)]
|
||||||
|
for stop_shape in [(), (2,), (2, 2)]
|
||||||
|
for dtype in [np.float64, np.int64]
|
||||||
|
for num in [0, 1, 2, 20]
|
||||||
|
])
|
||||||
|
# pylint: enable=g-complex-comprehension
|
||||||
|
def testLinspaceBroadcasts(self, start_shape, stop_shape, dtype, num):
|
||||||
if LooseVersion(np.version.version) < LooseVersion("1.16.0"):
|
if LooseVersion(np.version.version) < LooseVersion("1.16.0"):
|
||||||
self.skipTest("numpy doesn't support axes before version 1.16.0")
|
self.skipTest("numpy doesn't support axes before version 1.16.0")
|
||||||
|
|
||||||
shapes = [(), (2,), (2, 2)]
|
ndims = max(len(start_shape), len(stop_shape))
|
||||||
|
for axis in range(-ndims, ndims):
|
||||||
|
start = np.ones(start_shape, dtype)
|
||||||
|
stop = 10 * np.ones(stop_shape, dtype)
|
||||||
|
|
||||||
types = [np.float64, np.int64]
|
np_ans = np.linspace(start, stop, num, axis=axis)
|
||||||
|
tf_ans = self.evaluate(
|
||||||
|
math_ops.linspace_nd(start, stop, num, axis=axis))
|
||||||
|
|
||||||
for start_shape, stop_shape in itertools.product(shapes, repeat=2):
|
self.assertAllClose(np_ans, tf_ans)
|
||||||
for num in [0, 1, 2, 20]:
|
|
||||||
ndims = max(len(start_shape), len(stop_shape))
|
|
||||||
for axis in range(-ndims, ndims):
|
|
||||||
for dtype in types:
|
|
||||||
start = np.ones(start_shape, dtype)
|
|
||||||
stop = 10 * np.ones(stop_shape, dtype)
|
|
||||||
|
|
||||||
np_ans = np.linspace(start, stop, num, axis=axis)
|
|
||||||
tf_ans = self.evaluate(
|
|
||||||
math_ops.linspace_nd(start, stop, num, axis=axis))
|
|
||||||
|
|
||||||
self.assertAllClose(np_ans, tf_ans)
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
Loading…
x
Reference in New Issue
Block a user