From 5609ede2288e9f8bd5c7c9c7050199d60ba1830b Mon Sep 17 00:00:00 2001 From: Reed Wanderman-Milne Date: Thu, 31 Oct 2019 11:36:57 -0700 Subject: [PATCH] Support empty input shapes to tf.roll on GPU PiperOrigin-RevId: 277766909 Change-Id: I315bd52d0816f9cb6717451b7b7c3e2b3882b004 --- tensorflow/core/kernels/roll_op_gpu.cu.cc | 1 + .../python/kernel_tests/manip_ops_test.py | 21 ++++++++++++------- 2 files changed, 14 insertions(+), 8 deletions(-) diff --git a/tensorflow/core/kernels/roll_op_gpu.cu.cc b/tensorflow/core/kernels/roll_op_gpu.cu.cc index d4171edaca8..4df0305569e 100644 --- a/tensorflow/core/kernels/roll_op_gpu.cu.cc +++ b/tensorflow/core/kernels/roll_op_gpu.cu.cc @@ -58,6 +58,7 @@ struct Roll { const T* input, T* output, const gtl::ArraySlice threshold, const gtl::ArraySlice dim_range, const int64 isd) { + if (!num_elements) return; const GPUDevice& d = context->eigen_device(); auto dim_bytes = sizeof(int32) * dim_size.size(); diff --git a/tensorflow/python/kernel_tests/manip_ops_test.py b/tensorflow/python/kernel_tests/manip_ops_test.py index 5700db4b950..e6cb06ca477 100644 --- a/tensorflow/python/kernel_tests/manip_ops_test.py +++ b/tensorflow/python/kernel_tests/manip_ops_test.py @@ -42,12 +42,12 @@ class RollTest(test_util.TensorFlowTestCase): def _testRoll(self, np_input, shift, axis): expected_roll = np.roll(np_input, shift, axis) - with self.cached_session(): + with self.cached_session(use_gpu=True): roll = manip_ops.roll(np_input, shift, axis) self.assertAllEqual(roll.eval(), expected_roll) def _testGradient(self, np_input, shift, axis): - with self.cached_session(): + with self.cached_session(use_gpu=True): inx = constant_op.constant(np_input.tolist()) xs = list(np_input.shape) y = manip_ops.roll(inx, shift, axis) @@ -98,12 +98,17 @@ class RollTest(test_util.TensorFlowTestCase): self._testAll(np.random.randint(-100, 100, (5)).astype(np.int32), 3, -1) self._testAll(np.random.randint(-100, 100, (4, 4)).astype(np.int32), 3, -2) # Make sure negative axis should be 0 <= axis + dims < dims - with self.cached_session(): + with self.cached_session(use_gpu=True): with self.assertRaisesRegexp(errors_impl.InvalidArgumentError, "is out of range"): manip_ops.roll(np.random.randint(-100, 100, (4, 4)).astype(np.int32), 3, -10).eval() + @test_util.run_deprecated_v1 + def testEmptyInput(self): + self._testAll(np.zeros([0, 1]), 1, 1) + self._testAll(np.zeros([1, 0]), 1, 1) + @test_util.run_deprecated_v1 def testInvalidInputShape(self): # The input should be 1-D or higher, checked in shape function. @@ -117,7 +122,7 @@ class RollTest(test_util.TensorFlowTestCase): tensor = array_ops.placeholder(dtype=dtypes.int32) shift = 1 axis = 0 - with self.cached_session(): + with self.cached_session(use_gpu=True): with self.assertRaisesRegexp(errors_impl.InvalidArgumentError, "input must be 1-D or higher"): manip_ops.roll(tensor, shift, axis).eval(feed_dict={tensor: 7}) @@ -135,7 +140,7 @@ class RollTest(test_util.TensorFlowTestCase): tensor = [[1, 2], [3, 4]] shift = 1 axis = array_ops.placeholder(dtype=dtypes.int32) - with self.cached_session(): + with self.cached_session(use_gpu=True): with self.assertRaisesRegexp(errors_impl.InvalidArgumentError, "axis must be a scalar or a 1-D vector"): manip_ops.roll(tensor, shift, axis).eval(feed_dict={axis: [[0, 1]]}) @@ -153,7 +158,7 @@ class RollTest(test_util.TensorFlowTestCase): tensor = [[1, 2], [3, 4]] shift = array_ops.placeholder(dtype=dtypes.int32) axis = 1 - with self.cached_session(): + with self.cached_session(use_gpu=True): with self.assertRaisesRegexp(errors_impl.InvalidArgumentError, "shift must be a scalar or a 1-D vector"): manip_ops.roll(tensor, shift, axis).eval(feed_dict={shift: [[0, 1]]}) @@ -170,7 +175,7 @@ class RollTest(test_util.TensorFlowTestCase): tensor = [[1, 2], [3, 4]] shift = array_ops.placeholder(dtype=dtypes.int32) axis = [0, 1] - with self.cached_session(): + with self.cached_session(use_gpu=True): with self.assertRaisesRegexp(errors_impl.InvalidArgumentError, "shift and axis must have the same size"): manip_ops.roll(tensor, shift, axis).eval(feed_dict={shift: [1]}) @@ -179,7 +184,7 @@ class RollTest(test_util.TensorFlowTestCase): tensor = [1, 2] shift = 1 axis = 1 - with self.cached_session(): + with self.cached_session(use_gpu=True): with self.assertRaisesRegexp(errors_impl.InvalidArgumentError, "is out of range"): manip_ops.roll(tensor, shift, axis).eval()