Added gradient for ragged_gather.

PiperOrigin-RevId: 255181798
This commit is contained in:
Edward Loper 2019-06-26 07:12:47 -07:00 committed by TensorFlower Gardener
parent 851b3f5a46
commit f7415d1efb
2 changed files with 157 additions and 2 deletions

View File

@ -17,20 +17,24 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from absl.testing import parameterized
from tensorflow.python.eager import context
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import indexed_slices
from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gradients_impl
from tensorflow.python.ops.ragged import ragged_factory_ops
from tensorflow.python.ops.ragged import ragged_gather_ops
from tensorflow.python.platform import googletest
@test_util.run_all_in_graph_and_eager_modes
class RaggedGatherOpTest(test_util.TensorFlowTestCase):
class RaggedGatherOpTest(test_util.TensorFlowTestCase, parameterized.TestCase):
def testDocStringExamples(self):
params = constant_op.constant(['a', 'b', 'c', 'd', 'e'])
@ -137,6 +141,121 @@ class RaggedGatherOpTest(test_util.TensorFlowTestCase):
r'indices\.shape\.ndims must be known statically',
ragged_gather_ops.gather, params, indices)
# pylint: disable=bad-whitespace
@parameterized.parameters([
    # params.shape=[2, None]; indices.shape=[3]
    dict(
        params        = [[1.0, 2.0], [3.0, 4.0, 5.0]],
        indices       = [0, 0, 1],
        expected_out  = [[1.0, 2.0], [1.0, 2.0], [3.0, 4.0, 5.0]],
        out_grad      = [[0.1, 0.2], [0.3, 0.4], [0.5, 0.6, 0.7]],
        expected_grad = [[0.4, 0.6], [0.5, 0.6, 0.7]]),
    # params.shape=[2, None]; indices.shape=[0]
    dict(
        params        = [[1, 2], [3, 4, 5]],
        indices       = [],
        expected_out  = [],
        out_grad      = [],
        expected_grad = [[0, 0], [0, 0, 0]]),
    # params.shape=[2, None]; indices.shape=[2, 2]
    dict(
        params        = [[1.0, 2.0], [3.0, 4.0, 5.0]],
        indices       = [[0, 0], [1, 0]],
        expected_out  = [[[1.0, 2.0], [1.0, 2.0]],
                         [[3.0, 4.0, 5.0], [1.0, 2.0]]],
        out_grad      = [[[0.1, 0.2], [0.3, 0.4]],
                         [[0.5, 0.6, 0.7], [0.8, 0.9]]],
        expected_grad = [[1.2, 1.5], [0.5, 0.6, 0.7]]),
    # params.shape=[3, None, None]; indices.shape=[3]
    dict(
        params        = [[[1, 2], [3, 4, 5]], [[6.0]], [[7.0, 8.0]]],
        indices       = [2, 1, 2],
        expected_out  = [[[7.0, 8.0]], [[6.0]], [[7.0, 8.0]]],
        out_grad      = [[[0.1, 0.2]], [[0.3]], [[0.4, 0.5]]],
        expected_grad = [[[0, 0], [0, 0, 0]], [[0.3]], [[0.5, 0.7]]]),
    # params.shape=[3, None, None]; indices.shape=[0]
    # (Nothing is gathered, so every element of `params` gets a zero gradient.)
    dict(
        params        = [[[1, 2], [3, 4, 5]], [[6.0]], [[7.0, 8.0]]],
        indices       = [],
        expected_out  = [],
        out_grad      = [],
        expected_grad = [[[0, 0], [0, 0, 0]], [[0]], [[0, 0]]]),
    # params.shape=[0, None]; indices.shape=[0]
    dict(
        params             = [],
        indices            = [],
        expected_out       = [],
        out_grad           = [],
        expected_grad      = [],
        params_ragged_rank = 1),
    # params.shape=[3, None, 2]; indices.shape=[5]
    dict(
        params             = [[[1, 2], [3, 4]], [], [[5, 6]]],
        indices            = [1, 1, 2, 0, 2],
        expected_out       = [[], [], [[5, 6]], [[1, 2], [3, 4]], [[5, 6]]],
        out_grad           = [[], [], [[1, 2]], [[3, 4], [5, 6]], [[7, 7]]],
        expected_grad      = [[[3, 4], [5, 6]], [], [[8, 9]]],
        params_ragged_rank = 1),
])  # pyformat: disable
@test_util.run_deprecated_v1
def testGradient(self,
                 params,
                 indices,
                 expected_out,
                 out_grad,
                 expected_grad,
                 params_ragged_rank=None):
  """Tests that ragged_gather generates the right gradient.

  Args:
    params: The `params` that should be passed to `gather`.
    indices: The `indices` that should be passed to `gather`.
    expected_out: The expected value of `gather(params, indices)`.
      `expected_out.shape = indices.shape + params.shape[1:]`.
    out_grad: The value that should be fed in as the gradient for `out`
      when testing the gradient of `ragged_gather`.  Must have the same
      shape as `expected_out`.
    expected_grad: The expected gradient that should be returned for
      `params`.  Must have the same shape as `params`.
    params_ragged_rank: The ragged_rank of `params`.  If `None`, it is left
      for `ragged_factory_ops.constant` to choose.
  """
  # This test builds its gradient with `gradients_impl.gradients`, so it is
  # skipped when executing eagerly.
  if context.executing_eagerly():
    return

  params = ragged_factory_ops.constant(
      params, dtype=dtypes.float32, ragged_rank=params_ragged_rank)
  indices = constant_op.constant(indices, dtype=dtypes.int32)
  # Each dimension of `indices` replaces the outermost dimension of `params`,
  # so the result has `indices.shape.ndims - 1` more ragged dimensions.
  out_ragged_rank = params.ragged_rank + indices.shape.ndims - 1
  out_grad = ragged_factory_ops.constant(
      out_grad, dtype=dtypes.float32, ragged_rank=out_ragged_rank)
  expected_out = ragged_factory_ops.constant(
      expected_out, dtype=dtypes.float32, ragged_rank=out_ragged_rank)
  expected_grad = ragged_factory_ops.constant(
      expected_grad,
      dtype=dtypes.float32,
      ragged_rank=params.ragged_rank)

  out = ragged_gather_ops.gather(params, indices)
  self.assertAllClose(out, expected_out)

  # Request gradients with respect to every component of `params` (all of its
  # row-splits tensors and its flat values) as well as `indices`.
  grads = gradients_impl.gradients(
      out.flat_values,
      (params.nested_row_splits + (params.flat_values, indices,)),
      out_grad.flat_values)
  param_nested_splits_grads = grads[:-2]
  params_flat_values_grad = grads[-2]
  indices_grad = grads[-1]

  # Only the flat values are expected to receive a gradient.
  self.assertIsNone(indices_grad)
  for splits_grad in param_nested_splits_grads:
    self.assertIsNone(splits_grad)

  # The gradient generates an IndexedSlices; convert back to a normal Tensor.
  self.assertIsInstance(params_flat_values_grad, indexed_slices.IndexedSlices)
  params_flat_values_grad = ops.convert_to_tensor(params_flat_values_grad)

  params_grad = params.with_flat_values(params_flat_values_grad)
  self.assertAllClose(params_grad, expected_grad, atol=2e-6, rtol=2e-6)
# Run the tests in this file when it is executed directly as a script.
if __name__ == '__main__':
  googletest.main()

View File

@ -19,12 +19,14 @@ from __future__ import division
from __future__ import print_function
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import indexed_slices
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gen_ragged_array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops.ragged import ragged_array_ops
from tensorflow.python.ops.ragged import ragged_math_ops
from tensorflow.python.ops.ragged import ragged_tensor
@ -261,3 +263,37 @@ def gather_nd(params, indices, batch_dims=0, name=None):
# Gather using the flattened index tuples and params.
return gather(flattened_params, flattened_index_tuples)
#===============================================================================
# Gradient for the RaggedGather kernel
#===============================================================================
@ops.RegisterGradient('RaggedGather')
def _ragged_gather_grad(op, *grads):
  """Computes the gradient of the `RaggedGather` op w.r.t. its inputs.

  Args:
    op: The `RaggedGather` operation.  Its inputs are read as: one or more
      row-splits tensors for `params`, then the inner values of `params`,
      then `indices`.
    *grads: Gradients for the outputs of `op`.  Only the last one (the
      gradient for the gathered inner values) is used.

  Returns:
    A list with one entry per input of `op`: `None` for every row-splits
    input, an `IndexedSlices` for the inner values of `params`, and `None`
    for `indices`.
  """
  op_inputs = op.inputs
  splits_inputs = op_inputs[:-2]
  values_input = op_inputs[-2]
  indices_input = op_inputs[-1]
  values_grad = grads[-1]

  # Compose the row-splits tensors from outermost to innermost, so that row
  # `i` of `params` covers
  # `values_input[row_to_value[i]:row_to_value[i + 1]]`.
  row_to_value = splits_inputs[0]
  for inner_splits in splits_inputs[1:]:
    row_to_value = array_ops.gather(inner_splits, row_to_value)

  # Flatten `indices` so there is one entry per gathered row of `params`.
  row_ids = array_ops.reshape(indices_input, [-1])

  # For each gathered row, list the positions in `values_input` that the row
  # was read from.  Concatenated over all gathered rows, these positions line
  # up one-to-one with the entries of `values_grad`.
  row_starts = array_ops.gather(row_to_value, row_ids)
  row_limits = array_ops.gather(row_to_value[1:], row_ids)
  value_positions = ragged_math_ops.range(row_starts, row_limits).values

  values_input_grad = indexed_slices.IndexedSlices(
      values=values_grad,
      indices=value_positions,
      dense_shape=array_ops.shape(values_input))

  # The row-splits inputs and `indices` get no gradient.
  return [None] * len(splits_inputs) + [values_input_grad, None]