Fix gather gradient for empty slices

We can't use -1's in the reshape if the size is 0, since the -1 would be
ambiguous.  Also simplify the code now that array_ops.shape does its own
checking for fully defined shapes.
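
As a quick NumPy illustration of the ambiguity (a sketch, not part of this
change): when the gathered slices are empty, the incoming gradient has size 0,
so every leading dimension satisfies the reshape and the -1 has no unique
value.

    import numpy as np

    # Gradient flowing into gather for params of shape (7, 0) and 5 indices.
    grad = np.zeros((5, 0))

    # 0 == k * 0 for every k, so the -1 cannot be resolved; NumPy refuses.
    try:
        grad.reshape((-1, 0))
    except ValueError as err:
        print(err)

    # Passing the indices count explicitly is always well defined.
    num_indices = 5
    print(grad.reshape((num_indices, 0)).shape)  # (5, 0)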

A previous version of this fix used too much colocate_with and broke a
distributed model.  This version uses colocate_with only for params, as the
code did before this change.
Change: 130720593
Geoffrey Irving 2016-08-18 21:45:55 -08:00 committed by TensorFlower Gardener
parent 37000ef3b5
commit d35ba035a0
2 changed files with 33 additions and 30 deletions

@@ -57,26 +57,29 @@ class GatherTest(tf.test.TestCase):
   def testHigherRank(self):
     np.random.seed(1)
-    shape = (4, 3, 2)
-    params = np.random.randn(*shape)
-    indices = np.random.randint(shape[0], size=15).reshape(3, 5)
-    with self.test_session(use_gpu=self.use_gpu):
-      tf_params = tf.constant(params)
-      tf_indices = tf.constant(indices)
-      gather = tf.gather(tf_params, tf_indices)
-      self.assertAllEqual(params[indices], gather.eval())
-      self.assertEqual(indices.shape + params.shape[1:], gather.get_shape())
-      # Test gradients
-      gather_grad = np.random.randn(*gather.get_shape().as_list())
-      params_grad, indices_grad = tf.gradients(gather, [tf_params, tf_indices],
-                                               gather_grad)
-      self.assertEqual(indices_grad, None)
-      self.assertEqual(type(params_grad), tf.IndexedSlices)
-      params_grad = tf.convert_to_tensor(params_grad)
-      correct_params_grad = np.zeros(shape)
-      for i, g in zip(indices.ravel(), gather_grad.reshape((15,) + shape[1:])):
-        correct_params_grad[i] += g
-      self.assertAllClose(correct_params_grad, params_grad.eval())
+    # We check that scalar and empty shapes work as well
+    for shape in (7, 0), (4, 3, 2):
+      for indices_shape in (), (0,), (3, 0), (3, 5):
+        params = np.random.randn(*shape)
+        indices = np.random.randint(shape[0], size=indices_shape)
+        with self.test_session(use_gpu=self.use_gpu):
+          tf_params = tf.constant(params)
+          tf_indices = tf.constant(indices)
+          gather = tf.gather(tf_params, tf_indices)
+          self.assertAllEqual(params[indices], gather.eval())
+          self.assertEqual(indices.shape + params.shape[1:], gather.get_shape())
+          # Test gradients
+          gather_grad = np.random.randn(*gather.get_shape().as_list())
+          params_grad, indices_grad = tf.gradients(
+              gather, [tf_params, tf_indices], gather_grad)
+          self.assertEqual(indices_grad, None)
+          self.assertEqual(type(params_grad), tf.IndexedSlices)
+          params_grad = tf.convert_to_tensor(params_grad)
+          correct_params_grad = np.zeros(shape)
+          for i, g in zip(indices.flat,
+                          gather_grad.reshape((indices.size,) + shape[1:])):
+            correct_params_grad[i] += g
+          self.assertAllClose(correct_params_grad, params_grad.eval())
 
   def testUnknownIndices(self):
     params = tf.constant([[0, 1, 2]])
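
The shape identity the new test asserts also holds for NumPy fancy indexing,
which makes the expected shapes easy to sanity-check by hand (an illustrative
sketch, not part of the change):

    import numpy as np

    params = np.random.randn(4, 3, 2)

    # Gathering with indices of any shape, including scalar and empty ones,
    # yields a result of shape indices.shape + params.shape[1:].
    for indices_shape in (), (0,), (3, 0), (3, 5):
        indices = np.random.randint(4, size=indices_shape)
        assert params[indices].shape == indices.shape + params.shape[1:]

    print(params[np.zeros((3, 0), dtype=int)].shape)  # (3, 0, 3, 2)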

View File

@@ -276,18 +276,18 @@ ops.NoGradient("ZerosLike")
 
 @ops.RegisterGradient("Gather")
 def _GatherGrad(op, grad):
   """Gradient for Gather op."""
-  if op.inputs[0].get_shape().is_fully_defined():
-    dense_shape = constant_op.constant(op.inputs[0].get_shape().as_list())
-    values_shape = [-1] + op.inputs[0].get_shape()[1:].as_list()
-  else:
-    # op.inputs[0] can be large, so colocate the shape calculation with it.
-    with ops.colocate_with(op.inputs[0]):
-      dense_shape = array_ops.shape(op.inputs[0])
-      values_shape = array_ops.concat(0, [[-1], dense_shape[1:]])
+  # params can be large, so colocate the shape calculation with it.
+  params = op.inputs[0]
+  with ops.colocate_with(params):
+    params_shape = array_ops.shape(params)
+  # Build appropriately shaped IndexedSlices
+  indices = op.inputs[1]
+  size = array_ops.expand_dims(array_ops.size(indices), 0)
+  values_shape = array_ops.concat(0, [size, params_shape[1:]])
   values = array_ops.reshape(grad, values_shape)
-  indices = array_ops.reshape(op.inputs[1], [-1])
-  return [ops.IndexedSlices(values, indices, dense_shape), None]
+  indices = array_ops.reshape(indices, size)
+  return [ops.IndexedSlices(values, indices, params_shape), None]
 
 
 @ops.RegisterGradient("GatherNd")
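
For intuition, the gradient the new code constructs can be emulated in NumPy
(a sketch of the idea, not TensorFlow's implementation): flatten the indices,
reshape the incoming gradient to (indices.size,) + params.shape[1:] using the
explicit size rather than -1, and scatter-add the rows, which is what
converting the resulting IndexedSlices to a dense tensor does.

    import numpy as np

    def gather_grad_dense(params_shape, indices, grad):
        """Dense equivalent of the IndexedSlices built by _GatherGrad."""
        size = indices.size  # explicit size instead of -1, so size 0 is fine
        values = grad.reshape((size,) + tuple(params_shape[1:]))
        flat_indices = indices.reshape(size)
        dense = np.zeros(params_shape)
        np.add.at(dense, flat_indices, values)  # accumulate duplicate indices
        return dense

    # The empty-slice case that used to fail: params of shape (7, 0).
    indices = np.random.randint(7, size=(3, 5))
    grad = np.random.randn(3, 5, 0)
    print(gather_grad_dense((7, 0), indices, grad).shape)  # (7, 0)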