Fix the correct test in pooling_ops_test

PiperOrigin-RevId: 237288986
This commit is contained in:
Akshay Modi 2019-03-07 11:37:38 -08:00 committed by TensorFlower Gardener
parent 2c27df89ac
commit 509b632a8a

View File

@ -39,6 +39,15 @@ from tensorflow.python.platform import test
from tensorflow.python.platform import tf_logging
def GetDeviceScope(self, use_gpu=False):
if context.executing_eagerly():
if use_gpu and test.is_gpu_available():
return ops.device("GPU:0")
return ops.device("CPU:0")
else:
return self.session(use_gpu=use_gpu)
def GetTestConfigs(include_nchw_vect_c=False):
"""Get all the valid tests configs to run.
@ -802,12 +811,7 @@ class PoolingTest(test.TestCase):
# Generate numbers in a narrow range, so that there are many duplicates
# in the input.
tensor_input = np.random.random_integers(0, 3, input_shape).astype(dtype)
def get_device_scope():
if context.executing_eagerly():
return ops.device("GPU:0")
else:
return self.cached_session(use_gpu=True)
with get_device_scope():
with self.cached_session(use_gpu=False):
t = constant_op.constant(tensor_input, shape=input_shape)
_, argmax_op = nn_ops.max_pool_with_argmax(t, ksize, strides, padding)
argmax = self.evaluate(argmax_op)
@ -840,7 +844,7 @@ class PoolingTest(test.TestCase):
[True, False, [0, 1, 3, 5, 0, 2, 6, 8]]]
for use_gpu, include_batch_in_index, argmax_exp in configs:
with self.session(use_gpu=use_gpu):
with GetDeviceScope(self, use_gpu=use_gpu):
t = constant_op.constant(tensor_input, shape=[2, 3, 3, 1])
out_op, argmax_op = nn_ops.max_pool_with_argmax(
t,
@ -868,7 +872,7 @@ class PoolingTest(test.TestCase):
[True, False, [0, 1, 3, 5, 0, 2, 6, 8]]]
for use_gpu, include_batch_in_index, argmax in configs:
with self.session(use_gpu=use_gpu):
with GetDeviceScope(self, use_gpu):
orig_in = constant_op.constant(orig_input, shape=[2, 3, 3, 1])
t = constant_op.constant(tensor_input, shape=[2, 2, 2, 1])
argmax_t = constant_op.constant(