From 6a4711835ade7ba20547d8605ead805e90835e3d Mon Sep 17 00:00:00 2001 From: Yujing Zhang Date: Fri, 5 Jun 2020 07:46:30 -0700 Subject: [PATCH] Fix testGeomSpace on GPU. PiperOrigin-RevId: 314926636 Change-Id: If5f518524ac7b5fef284949df7e21c0724228b0b --- tensorflow/python/ops/numpy_ops/np_array_ops_test.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/tensorflow/python/ops/numpy_ops/np_array_ops_test.py b/tensorflow/python/ops/numpy_ops/np_array_ops_test.py index d69deda2d73..8194c1b2897 100644 --- a/tensorflow/python/ops/numpy_ops/np_array_ops_test.py +++ b/tensorflow/python/ops/numpy_ops/np_array_ops_test.py @@ -29,6 +29,7 @@ from tensorflow.python.framework import dtypes from tensorflow.python.framework import indexed_slices from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_shape +from tensorflow.python.framework import test_util from tensorflow.python.ops import array_ops from tensorflow.python.ops.numpy_ops import np_array_ops from tensorflow.python.ops.numpy_ops import np_arrays @@ -404,12 +405,16 @@ class ArrayCreationTest(test.TestCase): def run_test(start, stop, **kwargs): arg1 = start arg2 = stop + if test_util.is_gpu_available(): + decimal = 3 + else: + decimal = 4 self.match( np_array_ops.geomspace(arg1, arg2, **kwargs), np.geomspace(arg1, arg2, **kwargs), msg='geomspace({}, {})'.format(arg1, arg2), almost=True, - decimal=4) + decimal=decimal) run_test(1, 1000, num=5) run_test(1, 1000, num=5, endpoint=False)