diff --git a/tensorflow/python/BUILD b/tensorflow/python/BUILD index 2b057b5db57..d93c2314954 100644 --- a/tensorflow/python/BUILD +++ b/tensorflow/python/BUILD @@ -5079,6 +5079,21 @@ cuda_py_test( ], ) +cuda_py_test( + name = "math_ops_linspace_test", + size = "medium", + srcs = ["ops/math_ops_linspace_test.py"], + python_version = "PY3", + tags = ["no_windows_gpu"], + deps = [ + ":framework_for_generated_wrappers", + ":framework_test_lib", + ":math_ops", + ":platform_test", + "//third_party/py/numpy", + ], +) + cuda_py_test( name = "nn_batchnorm_test", size = "medium", diff --git a/tensorflow/python/ops/math_ops_linspace_test.py b/tensorflow/python/ops/math_ops_linspace_test.py new file mode 100644 index 00000000000..f56b1c9284d --- /dev/null +++ b/tensorflow/python/ops/math_ops_linspace_test.py @@ -0,0 +1,60 @@ +# Copyright 2020 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for tensorflow.ops.math_ops.linspace.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +# Using distutils.version.LooseVersion was resulting in an error, so importing +# directly. +from distutils.version import LooseVersion # pylint: disable=g-importing-member +import itertools + +import numpy as np + +from tensorflow.python.framework import test_util +from tensorflow.python.ops import math_ops +from tensorflow.python.platform import googletest + + +@test_util.run_all_in_graph_and_eager_modes +class LinspaceTest(test_util.TensorFlowTestCase): + + def testLinspaceBroadcasts(self): + if LooseVersion(np.version.version) < LooseVersion("1.16.0"): + self.skipTest("numpy doesn't support axes before version 1.16.0") + + shapes = [(), (2,), (2, 2)] + + types = [np.float64, np.int64] + + for start_shape, stop_shape in itertools.product(shapes, repeat=2): + for num in [0, 1, 2, 20]: + ndims = max(len(start_shape), len(stop_shape)) + for axis in range(-ndims, ndims): + for dtype in types: + start = np.ones(start_shape, dtype) + stop = 10 * np.ones(stop_shape, dtype) + + np_ans = np.linspace(start, stop, num, axis=axis) + tf_ans = self.evaluate( + math_ops.linspace_nd(start, stop, num, axis=axis)) + + self.assertAllClose(np_ans, tf_ans) + + +if __name__ == "__main__": + googletest.main() diff --git a/tensorflow/python/ops/math_ops_test.py b/tensorflow/python/ops/math_ops_test.py index 85a5afc6c16..afa1dbdbaf7 100644 --- a/tensorflow/python/ops/math_ops_test.py +++ b/tensorflow/python/ops/math_ops_test.py @@ -17,9 +17,6 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -import distutils -import itertools - import numpy as np from tensorflow.python.eager import backprop @@ -824,32 +821,5 @@ class RangeTest(test_util.TensorFlowTestCase): self.assertAllEqual(values, self.evaluate(tensor)) -@test_util.run_all_in_graph_and_eager_modes -class LinspaceTest(test_util.TensorFlowTestCase): - - def testLinspaceBroadcasts(self): - if distutils.version.LooseVersion( - np.version.version) < distutils.version.LooseVersion("1.16.0"): - self.skipTest("numpy doesn't support axes before version 1.16.0") - - shapes = [(), (2,), (2, 2)] - - types = [np.float64, np.int64] - - for start_shape, stop_shape in itertools.product(shapes, repeat=2): - for num in [0, 1, 2, 20]: - ndims = max(len(start_shape), len(stop_shape)) - for axis in range(-ndims, ndims): - for dtype in types: - start = np.ones(start_shape, dtype) - stop = 10 * np.ones(stop_shape, dtype) - - np_ans = np.linspace(start, stop, num, axis=axis) - tf_ans = self.evaluate( - math_ops.linspace_nd(start, stop, num, axis=axis)) - - self.assertAllClose(np_ans, tf_ans) - - if __name__ == "__main__": googletest.main()