From 509b632a8a4e84cd95f97f99b5ecb770802cf169 Mon Sep 17 00:00:00 2001 From: Akshay Modi Date: Thu, 7 Mar 2019 11:37:38 -0800 Subject: [PATCH] Fix the correct test in pooling_ops_test PiperOrigin-RevId: 237288986 --- .../python/kernel_tests/pooling_ops_test.py | 20 +++++++++++-------- 1 file changed, 12 insertions(+), 8 deletions(-) diff --git a/tensorflow/python/kernel_tests/pooling_ops_test.py b/tensorflow/python/kernel_tests/pooling_ops_test.py index 224405ace67..e87cccac1e6 100644 --- a/tensorflow/python/kernel_tests/pooling_ops_test.py +++ b/tensorflow/python/kernel_tests/pooling_ops_test.py @@ -39,6 +39,15 @@ from tensorflow.python.platform import test from tensorflow.python.platform import tf_logging +def GetDeviceScope(self, use_gpu=False): + if context.executing_eagerly(): + if use_gpu and test.is_gpu_available(): + return ops.device("GPU:0") + return ops.device("CPU:0") + else: + return self.session(use_gpu=use_gpu) + + def GetTestConfigs(include_nchw_vect_c=False): """Get all the valid tests configs to run. @@ -802,12 +811,7 @@ class PoolingTest(test.TestCase): # Generate numbers in a narrow range, so that there are many duplicates # in the input. tensor_input = np.random.random_integers(0, 3, input_shape).astype(dtype) - def get_device_scope(): - if context.executing_eagerly(): - return ops.device("GPU:0") - else: - return self.cached_session(use_gpu=True) - with get_device_scope(): + with self.cached_session(use_gpu=False): t = constant_op.constant(tensor_input, shape=input_shape) _, argmax_op = nn_ops.max_pool_with_argmax(t, ksize, strides, padding) argmax = self.evaluate(argmax_op) @@ -840,7 +844,7 @@ class PoolingTest(test.TestCase): [True, False, [0, 1, 3, 5, 0, 2, 6, 8]]] for use_gpu, include_batch_in_index, argmax_exp in configs: - with self.session(use_gpu=use_gpu): + with GetDeviceScope(self, use_gpu=use_gpu): t = constant_op.constant(tensor_input, shape=[2, 3, 3, 1]) out_op, argmax_op = nn_ops.max_pool_with_argmax( t, @@ -868,7 +872,7 @@ class PoolingTest(test.TestCase): [True, False, [0, 1, 3, 5, 0, 2, 6, 8]]] for use_gpu, include_batch_in_index, argmax in configs: - with self.session(use_gpu=use_gpu): + with GetDeviceScope(self, use_gpu): orig_in = constant_op.constant(orig_input, shape=[2, 3, 3, 1]) t = constant_op.constant(tensor_input, shape=[2, 2, 2, 1]) argmax_t = constant_op.constant(