Fix testGeomSpace on GPU.
PiperOrigin-RevId: 314926636 Change-Id: If5f518524ac7b5fef284949df7e21c0724228b0b
This commit is contained in:
parent
967782ff2a
commit
6a4711835a
@ -29,6 +29,7 @@ from tensorflow.python.framework import dtypes
|
|||||||
from tensorflow.python.framework import indexed_slices
|
from tensorflow.python.framework import indexed_slices
|
||||||
from tensorflow.python.framework import ops
|
from tensorflow.python.framework import ops
|
||||||
from tensorflow.python.framework import tensor_shape
|
from tensorflow.python.framework import tensor_shape
|
||||||
|
from tensorflow.python.framework import test_util
|
||||||
from tensorflow.python.ops import array_ops
|
from tensorflow.python.ops import array_ops
|
||||||
from tensorflow.python.ops.numpy_ops import np_array_ops
|
from tensorflow.python.ops.numpy_ops import np_array_ops
|
||||||
from tensorflow.python.ops.numpy_ops import np_arrays
|
from tensorflow.python.ops.numpy_ops import np_arrays
|
||||||
@ -404,12 +405,16 @@ class ArrayCreationTest(test.TestCase):
|
|||||||
def run_test(start, stop, **kwargs):
|
def run_test(start, stop, **kwargs):
|
||||||
arg1 = start
|
arg1 = start
|
||||||
arg2 = stop
|
arg2 = stop
|
||||||
|
if test_util.is_gpu_available():
|
||||||
|
decimal = 3
|
||||||
|
else:
|
||||||
|
decimal = 4
|
||||||
self.match(
|
self.match(
|
||||||
np_array_ops.geomspace(arg1, arg2, **kwargs),
|
np_array_ops.geomspace(arg1, arg2, **kwargs),
|
||||||
np.geomspace(arg1, arg2, **kwargs),
|
np.geomspace(arg1, arg2, **kwargs),
|
||||||
msg='geomspace({}, {})'.format(arg1, arg2),
|
msg='geomspace({}, {})'.format(arg1, arg2),
|
||||||
almost=True,
|
almost=True,
|
||||||
decimal=4)
|
decimal=decimal)
|
||||||
|
|
||||||
run_test(1, 1000, num=5)
|
run_test(1, 1000, num=5)
|
||||||
run_test(1, 1000, num=5, endpoint=False)
|
run_test(1, 1000, num=5, endpoint=False)
|
||||||
|
Loading…
Reference in New Issue
Block a user