Fix the correct test in pooling_ops_test
PiperOrigin-RevId: 237288986
This commit is contained in:
parent
2c27df89ac
commit
509b632a8a
@ -39,6 +39,15 @@ from tensorflow.python.platform import test
|
||||
from tensorflow.python.platform import tf_logging
|
||||
|
||||
|
||||
def GetDeviceScope(self, use_gpu=False):
|
||||
if context.executing_eagerly():
|
||||
if use_gpu and test.is_gpu_available():
|
||||
return ops.device("GPU:0")
|
||||
return ops.device("CPU:0")
|
||||
else:
|
||||
return self.session(use_gpu=use_gpu)
|
||||
|
||||
|
||||
def GetTestConfigs(include_nchw_vect_c=False):
|
||||
"""Get all the valid tests configs to run.
|
||||
|
||||
@ -802,12 +811,7 @@ class PoolingTest(test.TestCase):
|
||||
# Generate numbers in a narrow range, so that there are many duplicates
|
||||
# in the input.
|
||||
tensor_input = np.random.random_integers(0, 3, input_shape).astype(dtype)
|
||||
def get_device_scope():
|
||||
if context.executing_eagerly():
|
||||
return ops.device("GPU:0")
|
||||
else:
|
||||
return self.cached_session(use_gpu=True)
|
||||
with get_device_scope():
|
||||
with self.cached_session(use_gpu=False):
|
||||
t = constant_op.constant(tensor_input, shape=input_shape)
|
||||
_, argmax_op = nn_ops.max_pool_with_argmax(t, ksize, strides, padding)
|
||||
argmax = self.evaluate(argmax_op)
|
||||
@ -840,7 +844,7 @@ class PoolingTest(test.TestCase):
|
||||
[True, False, [0, 1, 3, 5, 0, 2, 6, 8]]]
|
||||
|
||||
for use_gpu, include_batch_in_index, argmax_exp in configs:
|
||||
with self.session(use_gpu=use_gpu):
|
||||
with GetDeviceScope(self, use_gpu=use_gpu):
|
||||
t = constant_op.constant(tensor_input, shape=[2, 3, 3, 1])
|
||||
out_op, argmax_op = nn_ops.max_pool_with_argmax(
|
||||
t,
|
||||
@ -868,7 +872,7 @@ class PoolingTest(test.TestCase):
|
||||
[True, False, [0, 1, 3, 5, 0, 2, 6, 8]]]
|
||||
|
||||
for use_gpu, include_batch_in_index, argmax in configs:
|
||||
with self.session(use_gpu=use_gpu):
|
||||
with GetDeviceScope(self, use_gpu):
|
||||
orig_in = constant_op.constant(orig_input, shape=[2, 3, 3, 1])
|
||||
t = constant_op.constant(tensor_input, shape=[2, 2, 2, 1])
|
||||
argmax_t = constant_op.constant(
|
||||
|
Loading…
Reference in New Issue
Block a user