Fix gather gradient for empty slices
We can't use -1 in the reshape when the size is 0, since -1 would be ambiguous there. Also simplify the code, now that array_ops.shape does its own checking for fully defined shapes. A previous version used colocate_with too aggressively and broke a distributed model; this version uses colocate_with only for params, following the previous version of the code. Change: 130720593
This commit is contained in:
parent
37000ef3b5
commit
d35ba035a0
@ -57,26 +57,29 @@ class GatherTest(tf.test.TestCase):
|
||||
|
||||
def testHigherRank(self):
  """Tests gather and its gradient for higher-rank indices.

  Covers scalar and empty shapes for both params and indices — the empty
  cases (size 0) are what previously broke the gradient's -1-based reshape,
  since -1 is ambiguous when the total size is 0.
  """
  # Fixed seed so the randomized params/indices are reproducible.
  np.random.seed(1)
  # We check that scalar and empty shapes work as well.
  for shape in (7, 0), (4, 3, 2):
    for indices_shape in (), (0,), (3, 0), (3, 5):
      params = np.random.randn(*shape)
      indices = np.random.randint(shape[0], size=indices_shape)
      with self.test_session(use_gpu=self.use_gpu):
        tf_params = tf.constant(params)
        tf_indices = tf.constant(indices)
        gather = tf.gather(tf_params, tf_indices)
        # Forward pass: values and static shape must match numpy indexing.
        self.assertAllEqual(params[indices], gather.eval())
        self.assertEqual(indices.shape + params.shape[1:], gather.get_shape())

        # Test gradients: only params receives a gradient (as IndexedSlices);
        # integer indices get None.
        gather_grad = np.random.randn(*gather.get_shape().as_list())
        params_grad, indices_grad = tf.gradients(
            gather, [tf_params, tf_indices], gather_grad)
        self.assertEqual(indices_grad, None)
        self.assertEqual(type(params_grad), tf.IndexedSlices)
        params_grad = tf.convert_to_tensor(params_grad)
        # Accumulate the expected dense gradient by scattering each incoming
        # gradient slice onto the row it was gathered from.
        correct_params_grad = np.zeros(shape)
        for i, g in zip(indices.flat,
                        gather_grad.reshape((indices.size,) + shape[1:])):
          correct_params_grad[i] += g
        self.assertAllClose(correct_params_grad, params_grad.eval())
||||
def testUnknownIndices(self):
|
||||
params = tf.constant([[0, 1, 2]])
|
||||
|
@ -276,18 +276,18 @@ ops.NoGradient("ZerosLike")
|
||||
@ops.RegisterGradient("Gather")
def _GatherGrad(op, grad):
  """Gradient for Gather op.

  Args:
    op: the Gather op, with inputs (params, indices).
    grad: the incoming gradient with respect to the gathered values.

  Returns:
    [params_grad, None]: an IndexedSlices gradient for params and no
    gradient for the integer indices.

  The values reshape uses the explicit size of indices rather than -1,
  because -1 is ambiguous when the tensor has 0 elements (empty slices).
  """
  # params can be large, so colocate the shape calculation with it.
  # NOTE: colocate_with is deliberately limited to this computation only —
  # a previous, broader use of colocate_with broke a distributed model.
  params = op.inputs[0]
  with ops.colocate_with(params):
    params_shape = array_ops.shape(params)

  # Build appropriately shaped IndexedSlices: flatten indices to a vector
  # and reshape grad to (num_indices,) + params_shape[1:].
  indices = op.inputs[1]
  size = array_ops.expand_dims(array_ops.size(indices), 0)
  values_shape = array_ops.concat(0, [size, params_shape[1:]])
  values = array_ops.reshape(grad, values_shape)
  indices = array_ops.reshape(indices, size)
  return [ops.IndexedSlices(values, indices, params_shape), None]
||||
|
||||
@ops.RegisterGradient("GatherNd")
|
||||
|
Loading…
Reference in New Issue
Block a user