diff --git a/tensorflow/python/kernel_tests/pooling_ops_test.py b/tensorflow/python/kernel_tests/pooling_ops_test.py index 224405ace67..e87cccac1e6 100644 --- a/tensorflow/python/kernel_tests/pooling_ops_test.py +++ b/tensorflow/python/kernel_tests/pooling_ops_test.py @@ -39,6 +39,15 @@ from tensorflow.python.platform import test from tensorflow.python.platform import tf_logging +def GetDeviceScope(self, use_gpu=False): + if context.executing_eagerly(): + if use_gpu and test.is_gpu_available(): + return ops.device("GPU:0") + return ops.device("CPU:0") + else: + return self.session(use_gpu=use_gpu) + + def GetTestConfigs(include_nchw_vect_c=False): """Get all the valid tests configs to run. @@ -802,12 +811,7 @@ class PoolingTest(test.TestCase): # Generate numbers in a narrow range, so that there are many duplicates # in the input. tensor_input = np.random.random_integers(0, 3, input_shape).astype(dtype) - def get_device_scope(): - if context.executing_eagerly(): - return ops.device("GPU:0") - else: - return self.cached_session(use_gpu=True) - with get_device_scope(): + with self.cached_session(use_gpu=False): t = constant_op.constant(tensor_input, shape=input_shape) _, argmax_op = nn_ops.max_pool_with_argmax(t, ksize, strides, padding) argmax = self.evaluate(argmax_op) @@ -840,7 +844,7 @@ class PoolingTest(test.TestCase): [True, False, [0, 1, 3, 5, 0, 2, 6, 8]]] for use_gpu, include_batch_in_index, argmax_exp in configs: - with self.session(use_gpu=use_gpu): + with GetDeviceScope(self, use_gpu=use_gpu): t = constant_op.constant(tensor_input, shape=[2, 3, 3, 1]) out_op, argmax_op = nn_ops.max_pool_with_argmax( t, @@ -868,7 +872,7 @@ class PoolingTest(test.TestCase): [True, False, [0, 1, 3, 5, 0, 2, 6, 8]]] for use_gpu, include_batch_in_index, argmax in configs: - with self.session(use_gpu=use_gpu): + with GetDeviceScope(self, use_gpu): orig_in = constant_op.constant(orig_input, shape=[2, 3, 3, 1]) t = constant_op.constant(tensor_input, shape=[2, 2, 2, 1]) argmax_t = constant_op.constant(