diff --git a/tensorflow/python/ops/ragged/ragged_gather_op_test.py b/tensorflow/python/ops/ragged/ragged_gather_op_test.py index ee0b219ddb4..8138a10b6c7 100644 --- a/tensorflow/python/ops/ragged/ragged_gather_op_test.py +++ b/tensorflow/python/ops/ragged/ragged_gather_op_test.py @@ -17,20 +17,24 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function + +from absl.testing import parameterized + from tensorflow.python.eager import context from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors +from tensorflow.python.framework import indexed_slices from tensorflow.python.framework import ops from tensorflow.python.framework import test_util from tensorflow.python.ops import array_ops +from tensorflow.python.ops import gradients_impl from tensorflow.python.ops.ragged import ragged_factory_ops from tensorflow.python.ops.ragged import ragged_gather_ops from tensorflow.python.platform import googletest -@test_util.run_all_in_graph_and_eager_modes -class RaggedGatherOpTest(test_util.TensorFlowTestCase): +class RaggedGatherOpTest(test_util.TensorFlowTestCase, parameterized.TestCase): def testDocStringExamples(self): params = constant_op.constant(['a', 'b', 'c', 'd', 'e']) @@ -137,6 +141,121 @@ class RaggedGatherOpTest(test_util.TensorFlowTestCase): r'indices\.shape\.ndims must be known statically', ragged_gather_ops.gather, params, indices) + # pylint: disable=bad-whitespace + @parameterized.parameters([ + # params.shape=[2, None]; indices.shape=[3] + dict( + params = [[1.0, 2.0], [3.0, 4.0, 5.0]], + indices = [0, 0, 1], + expected_out = [[1.0, 2.0], [1.0, 2.0], [3.0, 4.0, 5.0]], + out_grad = [[0.1, 0.2], [0.3, 0.4], [0.5, 0.6, 0.7]], + expected_grad = [[0.4, 0.6], [0.5, 0.6, 0.7]]), + # params.shape=[2, None]; indices.shape=[0] + dict( + params = [[1, 2], [3, 4, 5]], + indices = [], + expected_out = [], + out_grad = [], + 
expected_grad = [[0, 0], [0, 0, 0]]), + # params.shape=[2, None]; indices.shape=[2, 2] + dict( + params = [[1.0, 2.0], [3.0, 4.0, 5.0]], + indices = [[0, 0], [1, 0]], + expected_out = [[[1.0, 2.0], [1.0, 2.0]], + [[3.0, 4.0, 5.0], [1.0, 2.0]]], + out_grad = [[[0.1, 0.2], [0.3, 0.4]], + [[0.5, 0.6, 0.7], [0.8, 0.9]]], + expected_grad = [[1.2, 1.5], [0.5, 0.6, 0.7]]), + # params.shape=[3, None, None]; indices.shape=[3] + dict( + params = [[[1, 2], [3, 4, 5]], [[6.0]], [[7.0, 8.0]]], + indices = [2, 1, 2], + expected_out = [[[7.0, 8.0]], [[6.0]], [[7.0, 8.0]]], + out_grad = [[[0.1, 0.2]], [[0.3]], [[0.4, 0.5]]], + expected_grad = [[[0, 0], [0, 0, 0]], [[0.3]], [[0.5, 0.7]]]), + # params.shape=[3, None, None]; indices.shape=[3] (NOTE: identical to the previous case; an indices.shape=[0] case is not actually exercised here) + dict( + params = [[[1, 2], [3, 4, 5]], [[6.0]], [[7.0, 8.0]]], + indices = [2, 1, 2], + expected_out = [[[7.0, 8.0]], [[6.0]], [[7.0, 8.0]]], + out_grad = [[[0.1, 0.2]], [[0.3]], [[0.4, 0.5]]], + expected_grad = [[[0, 0], [0, 0, 0]], [[0.3]], [[0.5, 0.7]]]), + # params.shape=[0, None]; indices.shape=[0] + dict( + params = [], + indices = [], + expected_out = [], + out_grad = [], + expected_grad = [], + params_ragged_rank = 1), + # params.shape=[3, None, 2]; indices.shape=[5] + dict( + params = [[[1, 2], [3, 4]], [], [[5, 6]]], + indices = [1, 1, 2, 0, 2], + expected_out = [[], [], [[5, 6]], [[1, 2], [3, 4]], [[5, 6]]], + out_grad = [[], [], [[1, 2]], [[3, 4], [5, 6]], [[7, 7]]], + expected_grad = [[[3, 4], [5, 6]], [], [[8, 9]]], + params_ragged_rank = 1), + ]) # pyformat: disable + @test_util.run_deprecated_v1 + def testGradient(self, + params, + indices, + expected_out, + out_grad, + expected_grad, + params_ragged_rank=None): + """Tests that ragged_gather generates the right gradient. + + Args: + params: The `params` that should be passed to `gather`. + indices: The `indices` that should be passed to `gather`. + expected_out: The expected value of `gather(params, indices)`. + `expected_out.shape = indices.shape + params.shape[1:]`. 
+ out_grad: The value that should be fed in as the gradient for `out` + when testing the gradient of `ragged_gather`. Must have the same + shape as `expected_out`. + expected_grad: The expected gradient that should be returned for + `params`. Must have the same shape as `params`. + params_ragged_rank: The ragged_rank of `params`. + """ + if context.executing_eagerly(): + return + + params = ragged_factory_ops.constant( + params, dtype=dtypes.float32, ragged_rank=params_ragged_rank) + indices = constant_op.constant(indices, dtype=dtypes.int32) + out_ragged_rank = params.ragged_rank + indices.shape.ndims - 1 + out_grad = ragged_factory_ops.constant( + out_grad, dtype=dtypes.float32, ragged_rank=out_ragged_rank) + expected_out = ragged_factory_ops.constant( + expected_out, dtype=dtypes.float32, ragged_rank=out_ragged_rank) + expected_grad = ragged_factory_ops.constant( + expected_grad, + dtype=dtypes.float32, + ragged_rank=params.ragged_rank) + + out = ragged_gather_ops.gather(params, indices) + self.assertAllClose(out, expected_out) + + grads = gradients_impl.gradients( + out.flat_values, + (params.nested_row_splits + (params.flat_values, indices,)), + out_grad.flat_values) + param_nested_splits_grads = grads[:-2] + params_flat_values_grad = grads[-2] + indices_grad = grads[-1] + self.assertEqual(indices_grad, None) + for splits_grad in param_nested_splits_grads: + self.assertEqual(splits_grad, None) + + # The gradient generates an IndexedSlices; convert back to a normal Tensor. 
+ self.assertIsInstance(params_flat_values_grad, indexed_slices.IndexedSlices) + params_flat_values_grad = ops.convert_to_tensor(params_flat_values_grad) + + params_grad = params.with_flat_values(params_flat_values_grad) + self.assertAllClose(params_grad, expected_grad, atol=2e-6, rtol=2e-6) + if __name__ == '__main__': googletest.main() diff --git a/tensorflow/python/ops/ragged/ragged_gather_ops.py b/tensorflow/python/ops/ragged/ragged_gather_ops.py index ba3beefa8c9..89a501df170 100644 --- a/tensorflow/python/ops/ragged/ragged_gather_ops.py +++ b/tensorflow/python/ops/ragged/ragged_gather_ops.py @@ -19,12 +19,14 @@ from __future__ import division from __future__ import print_function from tensorflow.python.framework import dtypes +from tensorflow.python.framework import indexed_slices from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_shape from tensorflow.python.ops import array_ops from tensorflow.python.ops import gen_ragged_array_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops.ragged import ragged_array_ops +from tensorflow.python.ops.ragged import ragged_math_ops from tensorflow.python.ops.ragged import ragged_tensor @@ -261,3 +263,37 @@ def gather_nd(params, indices, batch_dims=0, name=None): # Gather using the flattened index tuples and params. return gather(flattened_params, flattened_index_tuples) + + +#=============================================================================== +# Gradient for the RaggedGather kernel +#=============================================================================== +@ops.RegisterGradient('RaggedGather') +def _ragged_gather_grad(op, *grads): + """Gradient for RaggedGather op.""" + param_nested_splits = op.inputs[:-2] + param_inner_values = op.inputs[-2] + indices = op.inputs[-1] + grad_inner_values = grads[-1] + + # For each row in `params`, find the range of values in `params.inner_values` + # that is covered by that row. 
In particular, the values in row `i` are + # `param_inner_values[combined_splits[i]:combined_splits[i+1]]`. + combined_splits = param_nested_splits[0] + for row_splits in param_nested_splits[1:]: + combined_splits = array_ops.gather(row_splits, combined_splits) + + # The outer dimensions of `indices` correspond 1:1 with the outer dimensions + # of `ragged_grad` that are encoded by `grad_nested_splits`. Thus, the + # flattened `indices` correspond 1:1 with `grad_inner_values`. + flat_indices = array_ops.reshape(indices, [-1]) + + # Build an IndexedSlices where the values are taken from `grad_inner_values`. + grad_indices = ragged_math_ops.range( + array_ops.gather(combined_splits, flat_indices), + array_ops.gather(combined_splits[1:], flat_indices)).values + + param_inner_values_grad = indexed_slices.IndexedSlices( + values=grad_inner_values, indices=grad_indices, + dense_shape=array_ops.shape(param_inner_values)) + return [None for _ in param_nested_splits] + [param_inner_values_grad, None]