Support empty input shapes to tf.roll on GPU

PiperOrigin-RevId: 277766909
Change-Id: I315bd52d0816f9cb6717451b7b7c3e2b3882b004
This commit is contained in:
Reed Wanderman-Milne 2019-10-31 11:36:57 -07:00 committed by TensorFlower Gardener
parent e54ba169a1
commit 5609ede228
2 changed files with 14 additions and 8 deletions

View File

@ -58,6 +58,7 @@ struct Roll<GPUDevice, T> {
const T* input, T* output,
const gtl::ArraySlice<int32> threshold,
const gtl::ArraySlice<int64> dim_range, const int64 isd) {
if (!num_elements) return;
const GPUDevice& d = context->eigen_device<GPUDevice>();
auto dim_bytes = sizeof(int32) * dim_size.size();

View File

@ -42,12 +42,12 @@ class RollTest(test_util.TensorFlowTestCase):
def _testRoll(self, np_input, shift, axis):
expected_roll = np.roll(np_input, shift, axis)
with self.cached_session():
with self.cached_session(use_gpu=True):
roll = manip_ops.roll(np_input, shift, axis)
self.assertAllEqual(roll.eval(), expected_roll)
def _testGradient(self, np_input, shift, axis):
with self.cached_session():
with self.cached_session(use_gpu=True):
inx = constant_op.constant(np_input.tolist())
xs = list(np_input.shape)
y = manip_ops.roll(inx, shift, axis)
@ -98,12 +98,17 @@ class RollTest(test_util.TensorFlowTestCase):
self._testAll(np.random.randint(-100, 100, (5)).astype(np.int32), 3, -1)
self._testAll(np.random.randint(-100, 100, (4, 4)).astype(np.int32), 3, -2)
# Make sure negative axis should be 0 <= axis + dims < dims
with self.cached_session():
with self.cached_session(use_gpu=True):
with self.assertRaisesRegexp(errors_impl.InvalidArgumentError,
"is out of range"):
manip_ops.roll(np.random.randint(-100, 100, (4, 4)).astype(np.int32),
3, -10).eval()
@test_util.run_deprecated_v1
def testEmptyInput(self):
self._testAll(np.zeros([0, 1]), 1, 1)
self._testAll(np.zeros([1, 0]), 1, 1)
@test_util.run_deprecated_v1
def testInvalidInputShape(self):
# The input should be 1-D or higher, checked in shape function.
@ -117,7 +122,7 @@ class RollTest(test_util.TensorFlowTestCase):
tensor = array_ops.placeholder(dtype=dtypes.int32)
shift = 1
axis = 0
with self.cached_session():
with self.cached_session(use_gpu=True):
with self.assertRaisesRegexp(errors_impl.InvalidArgumentError,
"input must be 1-D or higher"):
manip_ops.roll(tensor, shift, axis).eval(feed_dict={tensor: 7})
@ -135,7 +140,7 @@ class RollTest(test_util.TensorFlowTestCase):
tensor = [[1, 2], [3, 4]]
shift = 1
axis = array_ops.placeholder(dtype=dtypes.int32)
with self.cached_session():
with self.cached_session(use_gpu=True):
with self.assertRaisesRegexp(errors_impl.InvalidArgumentError,
"axis must be a scalar or a 1-D vector"):
manip_ops.roll(tensor, shift, axis).eval(feed_dict={axis: [[0, 1]]})
@ -153,7 +158,7 @@ class RollTest(test_util.TensorFlowTestCase):
tensor = [[1, 2], [3, 4]]
shift = array_ops.placeholder(dtype=dtypes.int32)
axis = 1
with self.cached_session():
with self.cached_session(use_gpu=True):
with self.assertRaisesRegexp(errors_impl.InvalidArgumentError,
"shift must be a scalar or a 1-D vector"):
manip_ops.roll(tensor, shift, axis).eval(feed_dict={shift: [[0, 1]]})
@ -170,7 +175,7 @@ class RollTest(test_util.TensorFlowTestCase):
tensor = [[1, 2], [3, 4]]
shift = array_ops.placeholder(dtype=dtypes.int32)
axis = [0, 1]
with self.cached_session():
with self.cached_session(use_gpu=True):
with self.assertRaisesRegexp(errors_impl.InvalidArgumentError,
"shift and axis must have the same size"):
manip_ops.roll(tensor, shift, axis).eval(feed_dict={shift: [1]})
@ -179,7 +184,7 @@ class RollTest(test_util.TensorFlowTestCase):
tensor = [1, 2]
shift = 1
axis = 1
with self.cached_session():
with self.cached_session(use_gpu=True):
with self.assertRaisesRegexp(errors_impl.InvalidArgumentError,
"is out of range"):
manip_ops.roll(tensor, shift, axis).eval()