Added gradient for ragged_gather.
PiperOrigin-RevId: 255181798
This commit is contained in:
parent
851b3f5a46
commit
f7415d1efb
@ -17,20 +17,24 @@
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from absl.testing import parameterized
|
||||
|
||||
from tensorflow.python.eager import context
|
||||
from tensorflow.python.framework import constant_op
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import errors
|
||||
from tensorflow.python.framework import indexed_slices
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.framework import test_util
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import gradients_impl
|
||||
from tensorflow.python.ops.ragged import ragged_factory_ops
|
||||
from tensorflow.python.ops.ragged import ragged_gather_ops
|
||||
from tensorflow.python.platform import googletest
|
||||
|
||||
|
||||
@test_util.run_all_in_graph_and_eager_modes
|
||||
class RaggedGatherOpTest(test_util.TensorFlowTestCase):
|
||||
class RaggedGatherOpTest(test_util.TensorFlowTestCase, parameterized.TestCase):
|
||||
|
||||
def testDocStringExamples(self):
|
||||
params = constant_op.constant(['a', 'b', 'c', 'd', 'e'])
|
||||
@ -137,6 +141,121 @@ class RaggedGatherOpTest(test_util.TensorFlowTestCase):
|
||||
r'indices\.shape\.ndims must be known statically',
|
||||
ragged_gather_ops.gather, params, indices)
|
||||
|
||||
# pylint: disable=bad-whitespace
|
||||
@parameterized.parameters([
|
||||
# params.shape=[2, None]; indices.shape=[3]
|
||||
dict(
|
||||
params = [[1.0, 2.0], [3.0, 4.0, 5.0]],
|
||||
indices = [0, 0, 1],
|
||||
expected_out = [[1.0, 2.0], [1.0, 2.0], [3.0, 4.0, 5.0]],
|
||||
out_grad = [[0.1, 0.2], [0.3, 0.4], [0.5, 0.6, 0.7]],
|
||||
expected_grad = [[0.4, 0.6], [0.5, 0.6, 0.7]]),
|
||||
# params.shape=[2, None]; indices.shape=[0]
|
||||
dict(
|
||||
params = [[1, 2], [3, 4, 5]],
|
||||
indices = [],
|
||||
expected_out = [],
|
||||
out_grad = [],
|
||||
expected_grad = [[0, 0], [0, 0, 0]]),
|
||||
# params.shape=[2, None]; indices.shape=[2, 2]
|
||||
dict(
|
||||
params = [[1.0, 2.0], [3.0, 4.0, 5.0]],
|
||||
indices = [[0, 0], [1, 0]],
|
||||
expected_out = [[[1.0, 2.0], [1.0, 2.0]],
|
||||
[[3.0, 4.0, 5.0], [1.0, 2.0]]],
|
||||
out_grad = [[[0.1, 0.2], [0.3, 0.4]],
|
||||
[[0.5, 0.6, 0.7], [0.8, 0.9]]],
|
||||
expected_grad = [[1.2, 1.5], [0.5, 0.6, 0.7]]),
|
||||
# params.shape=[3, None, None]; indices.shape=[3]
|
||||
dict(
|
||||
params = [[[1, 2], [3, 4, 5]], [[6.0]], [[7.0, 8.0]]],
|
||||
indices = [2, 1, 2],
|
||||
expected_out = [[[7.0, 8.0]], [[6.0]], [[7.0, 8.0]]],
|
||||
out_grad = [[[0.1, 0.2]], [[0.3]], [[0.4, 0.5]]],
|
||||
expected_grad = [[[0, 0], [0, 0, 0]], [[0.3]], [[0.5, 0.7]]]),
|
||||
# params.shape=[3, None, None]; indices.shape=[0]
|
||||
dict(
|
||||
params = [[[1, 2], [3, 4, 5]], [[6.0]], [[7.0, 8.0]]],
|
||||
indices = [2, 1, 2],
|
||||
expected_out = [[[7.0, 8.0]], [[6.0]], [[7.0, 8.0]]],
|
||||
out_grad = [[[0.1, 0.2]], [[0.3]], [[0.4, 0.5]]],
|
||||
expected_grad = [[[0, 0], [0, 0, 0]], [[0.3]], [[0.5, 0.7]]]),
|
||||
# params.shape=[0, None]; indices.shape=[0]
|
||||
dict(
|
||||
params = [],
|
||||
indices = [],
|
||||
expected_out = [],
|
||||
out_grad = [],
|
||||
expected_grad = [],
|
||||
params_ragged_rank = 1),
|
||||
# params.shape=[2, None, 2]; indices.shape=[3]
|
||||
dict(
|
||||
params = [[[1, 2], [3, 4]], [], [[5, 6]]],
|
||||
indices = [1, 1, 2, 0, 2],
|
||||
expected_out = [[], [], [[5, 6]], [[1, 2], [3, 4]], [[5, 6]]],
|
||||
out_grad = [[], [], [[1, 2]], [[3, 4], [5, 6]], [[7, 7]]],
|
||||
expected_grad = [[[3, 4], [5, 6]], [], [[8, 9]]],
|
||||
params_ragged_rank = 1),
|
||||
]) # pyformat: disable
|
||||
@test_util.run_deprecated_v1
|
||||
def testGradient(self,
|
||||
params,
|
||||
indices,
|
||||
expected_out,
|
||||
out_grad,
|
||||
expected_grad,
|
||||
params_ragged_rank=None):
|
||||
"""Tests that ragged_gather generates the right gradient.
|
||||
|
||||
Args:
|
||||
params: The `params` that should be passed to `gather`.
|
||||
indices: The `indices` that should be passed to `gather`.
|
||||
expected_out: The expected value of `gather(params, indices)`.
|
||||
`expected_out.shape = indices.shape + params.shape[1:]`.
|
||||
out_grad: The value that should be fed in as the gradient for `out`
|
||||
when testing the gradient of `ragged_gather`. Must have the same
|
||||
shape as `expected_out`.
|
||||
expected_grad: The expected gradient for that should be returned for
|
||||
`params`. Must have hte same shape as `params`.
|
||||
params_ragged_rank: The ragged_rank of `params`.
|
||||
"""
|
||||
if context.executing_eagerly():
|
||||
return
|
||||
|
||||
params = ragged_factory_ops.constant(
|
||||
params, dtype=dtypes.float32, ragged_rank=params_ragged_rank)
|
||||
indices = constant_op.constant(indices, dtype=dtypes.int32)
|
||||
out_ragged_rank = params.ragged_rank + indices.shape.ndims - 1
|
||||
out_grad = ragged_factory_ops.constant(
|
||||
out_grad, dtype=dtypes.float32, ragged_rank=out_ragged_rank)
|
||||
expected_out = ragged_factory_ops.constant(
|
||||
expected_out, dtype=dtypes.float32, ragged_rank=out_ragged_rank)
|
||||
expected_grad = ragged_factory_ops.constant(
|
||||
expected_grad,
|
||||
dtype=dtypes.float32,
|
||||
ragged_rank=params.ragged_rank)
|
||||
|
||||
out = ragged_gather_ops.gather(params, indices)
|
||||
self.assertAllClose(out, expected_out)
|
||||
|
||||
grads = gradients_impl.gradients(
|
||||
out.flat_values,
|
||||
(params.nested_row_splits + (params.flat_values, indices,)),
|
||||
out_grad.flat_values)
|
||||
param_nested_splits_grads = grads[:-2]
|
||||
params_flat_values_grad = grads[-2]
|
||||
indices_grad = grads[-1]
|
||||
self.assertEqual(indices_grad, None)
|
||||
for splits_grad in param_nested_splits_grads:
|
||||
self.assertEqual(splits_grad, None)
|
||||
|
||||
# The gradient generates an IndexedSlices; convert back to a normal Tensor.
|
||||
self.assertIsInstance(params_flat_values_grad, indexed_slices.IndexedSlices)
|
||||
params_flat_values_grad = ops.convert_to_tensor(params_flat_values_grad)
|
||||
|
||||
params_grad = params.with_flat_values(params_flat_values_grad)
|
||||
self.assertAllClose(params_grad, expected_grad, atol=2e-6, rtol=2e-6)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
googletest.main()
|
||||
|
@ -19,12 +19,14 @@ from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import indexed_slices
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.framework import tensor_shape
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import gen_ragged_array_ops
|
||||
from tensorflow.python.ops import math_ops
|
||||
from tensorflow.python.ops.ragged import ragged_array_ops
|
||||
from tensorflow.python.ops.ragged import ragged_math_ops
|
||||
from tensorflow.python.ops.ragged import ragged_tensor
|
||||
|
||||
|
||||
@ -261,3 +263,37 @@ def gather_nd(params, indices, batch_dims=0, name=None):
|
||||
|
||||
# Gather using the flattened index tuples and params.
|
||||
return gather(flattened_params, flattened_index_tuples)
|
||||
|
||||
|
||||
#===============================================================================
|
||||
# Gradient for the RaggedGather kernel
|
||||
#===============================================================================
|
||||
@ops.RegisterGradient('RaggedGather')
|
||||
def _ragged_gather_grad(op, *grads):
|
||||
"""Gradient for RaggedGather op."""
|
||||
param_nested_splits = op.inputs[:-2]
|
||||
param_inner_values = op.inputs[-2]
|
||||
indices = op.inputs[-1]
|
||||
grad_inner_values = grads[-1]
|
||||
|
||||
# For each row in `params`, find the range of values in `params.inner_values`
|
||||
# that is covered by that row. In particular, the values in row `i` are
|
||||
# `param_inner_values[combined_splits[i]:combined_splits[i+1]`.
|
||||
combined_splits = param_nested_splits[0]
|
||||
for row_splits in param_nested_splits[1:]:
|
||||
combined_splits = array_ops.gather(row_splits, combined_splits)
|
||||
|
||||
# The outer dimensions of `indices` correspond 1:1 with the outer dimensions
|
||||
# of `ragged_grad` that are encoded by `grad_nested_splits`. Thus, the
|
||||
# flattened `indices` correspond 1:1 with `grad_inner_values`.
|
||||
flat_indices = array_ops.reshape(indices, [-1])
|
||||
|
||||
# Build an IndexedSlices where the values are taken from `flat_grad`.
|
||||
grad_indices = ragged_math_ops.range(
|
||||
array_ops.gather(combined_splits, flat_indices),
|
||||
array_ops.gather(combined_splits[1:], flat_indices)).values
|
||||
|
||||
param_inner_values_grad = indexed_slices.IndexedSlices(
|
||||
values=grad_inner_values, indices=grad_indices,
|
||||
dense_shape=array_ops.shape(param_inner_values))
|
||||
return [None for _ in param_nested_splits] + [param_inner_values_grad, None]
|
||||
|
Loading…
x
Reference in New Issue
Block a user