Support empty input shapes to tf.roll on GPU
PiperOrigin-RevId: 277766909 Change-Id: I315bd52d0816f9cb6717451b7b7c3e2b3882b004
This commit is contained in:
parent
e54ba169a1
commit
5609ede228
@ -58,6 +58,7 @@ struct Roll<GPUDevice, T> {
|
||||
const T* input, T* output,
|
||||
const gtl::ArraySlice<int32> threshold,
|
||||
const gtl::ArraySlice<int64> dim_range, const int64 isd) {
|
||||
if (!num_elements) return;
|
||||
const GPUDevice& d = context->eigen_device<GPUDevice>();
|
||||
|
||||
auto dim_bytes = sizeof(int32) * dim_size.size();
|
||||
|
@ -42,12 +42,12 @@ class RollTest(test_util.TensorFlowTestCase):
|
||||
|
||||
def _testRoll(self, np_input, shift, axis):
|
||||
expected_roll = np.roll(np_input, shift, axis)
|
||||
with self.cached_session():
|
||||
with self.cached_session(use_gpu=True):
|
||||
roll = manip_ops.roll(np_input, shift, axis)
|
||||
self.assertAllEqual(roll.eval(), expected_roll)
|
||||
|
||||
def _testGradient(self, np_input, shift, axis):
|
||||
with self.cached_session():
|
||||
with self.cached_session(use_gpu=True):
|
||||
inx = constant_op.constant(np_input.tolist())
|
||||
xs = list(np_input.shape)
|
||||
y = manip_ops.roll(inx, shift, axis)
|
||||
@ -98,12 +98,17 @@ class RollTest(test_util.TensorFlowTestCase):
|
||||
self._testAll(np.random.randint(-100, 100, (5)).astype(np.int32), 3, -1)
|
||||
self._testAll(np.random.randint(-100, 100, (4, 4)).astype(np.int32), 3, -2)
|
||||
# Make sure negative axis should be 0 <= axis + dims < dims
|
||||
with self.cached_session():
|
||||
with self.cached_session(use_gpu=True):
|
||||
with self.assertRaisesRegexp(errors_impl.InvalidArgumentError,
|
||||
"is out of range"):
|
||||
manip_ops.roll(np.random.randint(-100, 100, (4, 4)).astype(np.int32),
|
||||
3, -10).eval()
|
||||
|
||||
@test_util.run_deprecated_v1
|
||||
def testEmptyInput(self):
|
||||
self._testAll(np.zeros([0, 1]), 1, 1)
|
||||
self._testAll(np.zeros([1, 0]), 1, 1)
|
||||
|
||||
@test_util.run_deprecated_v1
|
||||
def testInvalidInputShape(self):
|
||||
# The input should be 1-D or higher, checked in shape function.
|
||||
@ -117,7 +122,7 @@ class RollTest(test_util.TensorFlowTestCase):
|
||||
tensor = array_ops.placeholder(dtype=dtypes.int32)
|
||||
shift = 1
|
||||
axis = 0
|
||||
with self.cached_session():
|
||||
with self.cached_session(use_gpu=True):
|
||||
with self.assertRaisesRegexp(errors_impl.InvalidArgumentError,
|
||||
"input must be 1-D or higher"):
|
||||
manip_ops.roll(tensor, shift, axis).eval(feed_dict={tensor: 7})
|
||||
@ -135,7 +140,7 @@ class RollTest(test_util.TensorFlowTestCase):
|
||||
tensor = [[1, 2], [3, 4]]
|
||||
shift = 1
|
||||
axis = array_ops.placeholder(dtype=dtypes.int32)
|
||||
with self.cached_session():
|
||||
with self.cached_session(use_gpu=True):
|
||||
with self.assertRaisesRegexp(errors_impl.InvalidArgumentError,
|
||||
"axis must be a scalar or a 1-D vector"):
|
||||
manip_ops.roll(tensor, shift, axis).eval(feed_dict={axis: [[0, 1]]})
|
||||
@ -153,7 +158,7 @@ class RollTest(test_util.TensorFlowTestCase):
|
||||
tensor = [[1, 2], [3, 4]]
|
||||
shift = array_ops.placeholder(dtype=dtypes.int32)
|
||||
axis = 1
|
||||
with self.cached_session():
|
||||
with self.cached_session(use_gpu=True):
|
||||
with self.assertRaisesRegexp(errors_impl.InvalidArgumentError,
|
||||
"shift must be a scalar or a 1-D vector"):
|
||||
manip_ops.roll(tensor, shift, axis).eval(feed_dict={shift: [[0, 1]]})
|
||||
@ -170,7 +175,7 @@ class RollTest(test_util.TensorFlowTestCase):
|
||||
tensor = [[1, 2], [3, 4]]
|
||||
shift = array_ops.placeholder(dtype=dtypes.int32)
|
||||
axis = [0, 1]
|
||||
with self.cached_session():
|
||||
with self.cached_session(use_gpu=True):
|
||||
with self.assertRaisesRegexp(errors_impl.InvalidArgumentError,
|
||||
"shift and axis must have the same size"):
|
||||
manip_ops.roll(tensor, shift, axis).eval(feed_dict={shift: [1]})
|
||||
@ -179,7 +184,7 @@ class RollTest(test_util.TensorFlowTestCase):
|
||||
tensor = [1, 2]
|
||||
shift = 1
|
||||
axis = 1
|
||||
with self.cached_session():
|
||||
with self.cached_session(use_gpu=True):
|
||||
with self.assertRaisesRegexp(errors_impl.InvalidArgumentError,
|
||||
"is out of range"):
|
||||
manip_ops.roll(tensor, shift, axis).eval()
|
||||
|
Loading…
Reference in New Issue
Block a user