Fix the correct test in pooling_ops_test
PiperOrigin-RevId: 237288986
This commit is contained in:
parent
2c27df89ac
commit
509b632a8a
@ -39,6 +39,15 @@ from tensorflow.python.platform import test
|
|||||||
from tensorflow.python.platform import tf_logging
|
from tensorflow.python.platform import tf_logging
|
||||||
|
|
||||||
|
|
||||||
|
def GetDeviceScope(self, use_gpu=False):
|
||||||
|
if context.executing_eagerly():
|
||||||
|
if use_gpu and test.is_gpu_available():
|
||||||
|
return ops.device("GPU:0")
|
||||||
|
return ops.device("CPU:0")
|
||||||
|
else:
|
||||||
|
return self.session(use_gpu=use_gpu)
|
||||||
|
|
||||||
|
|
||||||
def GetTestConfigs(include_nchw_vect_c=False):
|
def GetTestConfigs(include_nchw_vect_c=False):
|
||||||
"""Get all the valid tests configs to run.
|
"""Get all the valid tests configs to run.
|
||||||
|
|
||||||
@ -802,12 +811,7 @@ class PoolingTest(test.TestCase):
|
|||||||
# Generate numbers in a narrow range, so that there are many duplicates
|
# Generate numbers in a narrow range, so that there are many duplicates
|
||||||
# in the input.
|
# in the input.
|
||||||
tensor_input = np.random.random_integers(0, 3, input_shape).astype(dtype)
|
tensor_input = np.random.random_integers(0, 3, input_shape).astype(dtype)
|
||||||
def get_device_scope():
|
with self.cached_session(use_gpu=False):
|
||||||
if context.executing_eagerly():
|
|
||||||
return ops.device("GPU:0")
|
|
||||||
else:
|
|
||||||
return self.cached_session(use_gpu=True)
|
|
||||||
with get_device_scope():
|
|
||||||
t = constant_op.constant(tensor_input, shape=input_shape)
|
t = constant_op.constant(tensor_input, shape=input_shape)
|
||||||
_, argmax_op = nn_ops.max_pool_with_argmax(t, ksize, strides, padding)
|
_, argmax_op = nn_ops.max_pool_with_argmax(t, ksize, strides, padding)
|
||||||
argmax = self.evaluate(argmax_op)
|
argmax = self.evaluate(argmax_op)
|
||||||
@ -840,7 +844,7 @@ class PoolingTest(test.TestCase):
|
|||||||
[True, False, [0, 1, 3, 5, 0, 2, 6, 8]]]
|
[True, False, [0, 1, 3, 5, 0, 2, 6, 8]]]
|
||||||
|
|
||||||
for use_gpu, include_batch_in_index, argmax_exp in configs:
|
for use_gpu, include_batch_in_index, argmax_exp in configs:
|
||||||
with self.session(use_gpu=use_gpu):
|
with GetDeviceScope(self, use_gpu=use_gpu):
|
||||||
t = constant_op.constant(tensor_input, shape=[2, 3, 3, 1])
|
t = constant_op.constant(tensor_input, shape=[2, 3, 3, 1])
|
||||||
out_op, argmax_op = nn_ops.max_pool_with_argmax(
|
out_op, argmax_op = nn_ops.max_pool_with_argmax(
|
||||||
t,
|
t,
|
||||||
@ -868,7 +872,7 @@ class PoolingTest(test.TestCase):
|
|||||||
[True, False, [0, 1, 3, 5, 0, 2, 6, 8]]]
|
[True, False, [0, 1, 3, 5, 0, 2, 6, 8]]]
|
||||||
|
|
||||||
for use_gpu, include_batch_in_index, argmax in configs:
|
for use_gpu, include_batch_in_index, argmax in configs:
|
||||||
with self.session(use_gpu=use_gpu):
|
with GetDeviceScope(self, use_gpu):
|
||||||
orig_in = constant_op.constant(orig_input, shape=[2, 3, 3, 1])
|
orig_in = constant_op.constant(orig_input, shape=[2, 3, 3, 1])
|
||||||
t = constant_op.constant(tensor_input, shape=[2, 2, 2, 1])
|
t = constant_op.constant(tensor_input, shape=[2, 2, 2, 1])
|
||||||
argmax_t = constant_op.constant(
|
argmax_t = constant_op.constant(
|
||||||
|
Loading…
Reference in New Issue
Block a user