Fix testGeomSpace on GPU.
PiperOrigin-RevId: 314926636 Change-Id: If5f518524ac7b5fef284949df7e21c0724228b0b
This commit is contained in:
parent
967782ff2a
commit
6a4711835a
@ -29,6 +29,7 @@ from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import indexed_slices
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.framework import tensor_shape
|
||||
from tensorflow.python.framework import test_util
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops.numpy_ops import np_array_ops
|
||||
from tensorflow.python.ops.numpy_ops import np_arrays
|
||||
@ -404,12 +405,16 @@ class ArrayCreationTest(test.TestCase):
|
||||
def run_test(start, stop, **kwargs):
|
||||
arg1 = start
|
||||
arg2 = stop
|
||||
if test_util.is_gpu_available():
|
||||
decimal = 3
|
||||
else:
|
||||
decimal = 4
|
||||
self.match(
|
||||
np_array_ops.geomspace(arg1, arg2, **kwargs),
|
||||
np.geomspace(arg1, arg2, **kwargs),
|
||||
msg='geomspace({}, {})'.format(arg1, arg2),
|
||||
almost=True,
|
||||
decimal=4)
|
||||
decimal=decimal)
|
||||
|
||||
run_test(1, 1000, num=5)
|
||||
run_test(1, 1000, num=5, endpoint=False)
|
||||
|
Loading…
Reference in New Issue
Block a user