diff --git a/tensorflow/python/ops/ragged/BUILD b/tensorflow/python/ops/ragged/BUILD index efef01de764..bb70a706b4b 100644 --- a/tensorflow/python/ops/ragged/BUILD +++ b/tensorflow/python/ops/ragged/BUILD @@ -521,6 +521,7 @@ py_test( name = "ragged_gather_op_test", srcs = ["ragged_gather_op_test.py"], python_version = "PY3", + shard_count = 4, srcs_version = "PY2AND3", deps = [ ":ragged_factory_ops", diff --git a/tensorflow/python/ops/ragged/ragged_batch_gather_op_test.py b/tensorflow/python/ops/ragged/ragged_batch_gather_op_test.py index 5f6a7018b85..549a660ee12 100644 --- a/tensorflow/python/ops/ragged/ragged_batch_gather_op_test.py +++ b/tensorflow/python/ops/ragged/ragged_batch_gather_op_test.py @@ -476,12 +476,12 @@ class RaggedBatchGatherOpTest(test_util.TensorFlowTestCase, ragged_indices = ragged_tensor.RaggedTensor.from_row_splits( indices, [0, 2, 4]) - with self.assertRaisesRegexp( - ValueError, 'batch_gather does not allow indices with unknown shape.'): + with self.assertRaisesRegexp(ValueError, r'batch_dims may only be negative ' + r'if rank\(indices\) is statically known.'): ragged_batch_gather_ops.batch_gather(params, indices) - with self.assertRaisesRegexp( - ValueError, 'batch_gather does not allow indices with unknown shape.'): + with self.assertRaisesRegexp(ValueError, r'batch_dims may only be negative ' + r'if rank\(indices\) is statically known.'): ragged_batch_gather_ops.batch_gather(params, ragged_indices) @parameterized.parameters( @@ -489,7 +489,7 @@ class RaggedBatchGatherOpTest(test_util.TensorFlowTestCase, dict( params=ragged_factory_ops.constant_value([['a'], ['b'], ['c']]), indices=ragged_factory_ops.constant_value([[0], [0]]), - message='Dimensions 3 and 2 are not compatible'), + message=(r'batch shape from indices .* does not match params')), dict( params=[[[1, 2], [3, 4]], [[5, 6], [7, 8]]], indices=ragged_factory_ops.constant_value([[[0, 0], [0, 0, 0]], @@ -506,20 +506,21 @@ class RaggedBatchGatherOpTest(test_util.TensorFlowTestCase, 
[[0]], [[0]]]), indices=ragged_factory_ops.constant_value([[[0, 0]], [[0, 0, 0]], [[0]]]), - error=errors.InvalidArgumentError, - message='.*Condition x == y did not hold.*'), + error=(ValueError, errors.InvalidArgumentError), + message=(r'batch shape from indices .* does not match ' + r'params shape|dimension size mismatch')), dict( params=ragged_factory_ops.constant_value(['a', 'b', 'c']), indices=ragged_factory_ops.constant_value([[0], [0]]), - message='batch shape from indices does not match params shape'), + message=r'batch_dims must be less than rank\(params\)'), dict( params=ragged_factory_ops.constant_value([['a']]), indices=0, - message='indices.rank must be at least 1.'), + message='batch_dims=-1 out of bounds: expected 0<=batch_dims<0'), dict( params=ragged_factory_ops.constant_value([['a']]), indices=[[[0]]], - message='batch shape from indices does not match params shape'), + message=r'batch_dims must be less than rank\(params\)'), ]) def testRaggedBatchGatherStaticError(self, params, diff --git a/tensorflow/python/ops/ragged/ragged_batch_gather_ops.py b/tensorflow/python/ops/ragged/ragged_batch_gather_ops.py index 8f4271fc821..7f9d483663c 100644 --- a/tensorflow/python/ops/ragged/ragged_batch_gather_ops.py +++ b/tensorflow/python/ops/ragged/ragged_batch_gather_ops.py @@ -18,13 +18,7 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.python.framework import ops -from tensorflow.python.ops import array_ops -from tensorflow.python.ops import check_ops -from tensorflow.python.ops import math_ops from tensorflow.python.ops.ragged import ragged_gather_ops -from tensorflow.python.ops.ragged import ragged_tensor -from tensorflow.python.ops.ragged import ragged_util #=============================================================================== @@ -61,64 +55,4 @@ def batch_gather(params, indices, name=None): >>> tf.compat.v1.batch_gather(params, indices) """ - if not 
(ragged_tensor.is_ragged(params) or ragged_tensor.is_ragged(indices)): - return array_ops.batch_gather(params, indices, name) - - with ops.name_scope(name, 'RaggedBatchGather', [params, indices]): - params = ragged_tensor.convert_to_tensor_or_ragged_tensor( - params, name='params') - indices = ragged_tensor.convert_to_tensor_or_ragged_tensor( - indices, name='indices') - params, indices = ragged_tensor.match_row_splits_dtypes(params, indices) - indices_ndims = indices.shape.ndims - if indices_ndims is None: - raise ValueError( - 'batch_gather does not allow indices with unknown shape.') - if indices_ndims == 0: - raise ValueError('indices.rank must be at least 1.') - - if ragged_tensor.is_ragged(indices): - # If the outermost ragged dimension is a batch dimension, recurse. - if indices_ndims > 2: - if not ragged_tensor.is_ragged(params): - raise ValueError('batch shape from indices does ' - 'not match params shape') - checks = [check_ops.assert_equal(params.row_splits, indices.row_splits)] - with ops.control_dependencies(checks): - return ragged_tensor.RaggedTensor.from_row_splits( - batch_gather(params.values, indices.values), indices.row_splits, - validate=False) - - # Otherwise, indices is a 2D ragged tensor with 1 ragged dimension. - else: - # Ensure that `params` is ragged and has at least 2 dimensions. - if not ragged_tensor.is_ragged(params): - if params.shape.ndims is not None and params.shape.ndims < 2: - raise ValueError('batch shape from indices does ' - 'not match params shape') - params = ragged_tensor.RaggedTensor.from_tensor( - params, ragged_rank=1, - row_splits_dtype=indices.row_splits.dtype) - - # Adjust indices from within-batch to global (in params.values), and - # then use ragged.gather to gather them. 
- num_indices = indices.row_lengths() - params_starts = params.row_starts() - adjustments = ragged_util.repeat(params_starts, num_indices, axis=0) - adjusted_index_values = ( - math_ops.cast(indices.values, adjustments.dtype) + adjustments) - return ragged_tensor.RaggedTensor.from_row_splits( - ragged_gather_ops.gather(params.values, adjusted_index_values), - indices.row_splits, validate=False) - - else: # params is a RaggedTensor and indices is a Tensor. - if indices_ndims == 1: - return ragged_gather_ops.gather(params, indices) - elif indices_ndims == 2: - # Adjust indices from batch-local to global (in params.values) - adjustments = array_ops.expand_dims(params.row_starts(), 1) - adjusted_indices = ( - math_ops.cast(indices, adjustments.dtype) + adjustments) - return ragged_gather_ops.gather(params.values, adjusted_indices) - else: - raise ValueError('batch shape from indices does not match params shape') + return ragged_gather_ops.gather(params, indices, batch_dims=-1, name=name) diff --git a/tensorflow/python/ops/ragged/ragged_gather_op_test.py b/tensorflow/python/ops/ragged/ragged_gather_op_test.py index 99f6316c26c..928e634989c 100644 --- a/tensorflow/python/ops/ragged/ragged_gather_op_test.py +++ b/tensorflow/python/ops/ragged/ragged_gather_op_test.py @@ -20,6 +20,8 @@ from __future__ import print_function from absl.testing import parameterized +import numpy as np + from tensorflow.python.eager import context from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes @@ -31,106 +33,141 @@ from tensorflow.python.ops import array_ops from tensorflow.python.ops import gradients_impl from tensorflow.python.ops.ragged import ragged_factory_ops from tensorflow.python.ops.ragged import ragged_gather_ops +from tensorflow.python.ops.ragged import ragged_tensor from tensorflow.python.platform import googletest +@test_util.run_all_in_graph_and_eager_modes class RaggedGatherOpTest(test_util.TensorFlowTestCase, 
parameterized.TestCase): - def testDocStringExamples(self): - params = constant_op.constant(['a', 'b', 'c', 'd', 'e']) - indices = constant_op.constant([3, 1, 2, 1, 0]) - ragged_params = ragged_factory_ops.constant([['a', 'b', 'c'], ['d'], [], - ['e']]) - ragged_indices = ragged_factory_ops.constant([[3, 1, 2], [1], [], [0]]) - self.assertAllEqual( - ragged_gather_ops.gather(params, ragged_indices), - [[b'd', b'b', b'c'], [b'b'], [], [b'a']]) - self.assertAllEqual( - ragged_gather_ops.gather(ragged_params, indices), - [[b'e'], [b'd'], [], [b'd'], [b'a', b'b', b'c']]) - self.assertAllEqual( - ragged_gather_ops.gather(ragged_params, ragged_indices), - [[[b'e'], [b'd'], []], [[b'd']], [], [[b'a', b'b', b'c']]]) + @parameterized.named_parameters([ + # Basic gather (axis=0 and batch_dims=0) + dict(testcase_name='Params1DTensor_Indices1DTensor', + params=['a', 'b', 'c', 'd', 'e'], + indices=[2, 0, 2, 1], + expected=['c', 'a', 'c', 'b']), + dict(testcase_name='Params1DTensor_Indices2DRagged', + params=['a', 'b', 'c', 'd', 'e'], + indices=[[3, 1, 2], [1], [], [0]], + expected=[['d', 'b', 'c'], ['b'], [], ['a']]), + dict(testcase_name='Params2DRagged_Indices0DTensor', + params=[['a', 'b'], ['c', 'd', 'e'], ['f'], [], ['g']], + indices=1, + expected=['c', 'd', 'e']), + dict(testcase_name='Params2DRagged_Indices1DTensor', + params=[['a', 'b', 'c'], ['d'], [], ['e']], + indices=[3, 1, 2, 1, 0], + expected=[ + ['e'], ['d'], [], ['d'], ['a', 'b', 'c']]), + dict(testcase_name='Params2DRagged_Indices2DRagged', + params=[['a', 'b', 'c'], ['d'], [], ['e']], + indices=[[3, 1, 2], [1], [], [0]], + expected=[ + [['e'], ['d'], []], [['d']], [], [['a', 'b', 'c']]]), + dict(testcase_name='Params3DRagged_Indices2DTensor', + params=[ + [['a', 'b'], []], [['c', 'd'], ['e'], ['f']], [['g']]], + indices=[[1, 2], [0, 1], [2, 2]], + indices_ragged_rank=0, + expected=[ + [[['c', 'd'], ['e'], ['f']], [['g']]], + [[['a', 'b'], []], [['c', 'd'], ['e'], ['f']]], + [[['g']], [['g']]]]), + 
dict(testcase_name='Params3DRagged_Indices3DTensor', + params=[[['a', 'b'], []], + [['c', 'd'], ['e'], ['f']], + [['g']]], + indices=[[[1, 2], [0, 1], [2, 2]], [[0, 0], [1, 2], [0, 1]]], + indices_ragged_rank=0, + expected=[ + [[[['c', 'd'], ['e'], ['f']], [['g']]], + [[['a', 'b'], []], [['c', 'd'], ['e'], ['f']]], + [[['g']], [['g']]]], + [[[['a', 'b'], []], [['a', 'b'], []]], + [[['c', 'd'], ['e'], ['f']], [['g']]], + [[['a', 'b'], []], [['c', 'd'], ['e'], ['f']]]]]), + dict(testcase_name='Params1DTensor_Indices4DRaggedRank2', + params=['a', 'b', 'c', 'd', 'e', 'f', 'g'], + indices=[[[[3, 4], [0, 6]], []], + [[[2, 1], [1, 0]], [[2, 5]], [[2, 3]]], + [[[1, 0]]]], + indices_ragged_rank=2, + expected=[ + [[['d', 'e'], ['a', 'g']], []], + [[['c', 'b'], ['b', 'a']], [['c', 'f']], [['c', 'd']]], + [[['b', 'a']]]]), + # Batch gather (batch_dims=1) + dict(testcase_name='Batch1D_Params2DRagged_Indices1DTensor', + params=[['a', 'b'], ['c'], ['d', 'e', 'f', 'g'], ['h']], + indices=[1, 0, 3, 0], + batch_dims=1, + expected=['b', 'c', 'g', 'h']), + dict(testcase_name='Batch1D_Params2DRagged_Indices2DTensor', + params=[['a', 'b'], ['c'], ['d', 'e', 'f', 'g'], ['h']], + indices=[[1, 0], [0, 0], [3, 1], [0, 0]], + indices_ragged_rank=0, + batch_dims=1, + expected=[['b', 'a'], ['c', 'c'], ['g', 'e'], ['h', 'h']]), + dict(testcase_name='Batch1D_Params2DRagged_Indices2DRagged', + params=[['a', 'b'], ['c'], ['d', 'e', 'f', 'g'], ['h']], + indices=[[1, 0], [], [3, 2, 1], [0]], + batch_dims=1, + expected=[['b', 'a'], [], ['g', 'f', 'e'], ['h']]), + dict(testcase_name='Batch1D_Params3DRagged_Indices3DRagged', + params=[[['a'], ['b', 'c']], + [], + [['d', 'e', 'f'], ['g'], ['h', 'i'], ['j']], + [['k']]], + indices=[[[1, 0], []], [], [[3, 2, 1], [0]], [[0]]], + batch_dims=1, + expected=[[[['b', 'c'], ['a']], []], + [], + [[['j'], ['h', 'i'], ['g']], [['d', 'e', 'f']]], + [[['k']]]]), + # Batch gather (batch_dims=2) + dict(testcase_name='Batch2D_Params3DRagged_Indices2DRagged', + 
params=[[['a', 'b', 'c'], ['d', 'e'], ['f']], + [['g'], ['h', 'i']]], + indices=[[0, 1, 0], [0, 1]], + batch_dims=2, + expected=[['a', 'e', 'f'], ['g', 'i']]), + dict(testcase_name='Batch2D_Params3DRagged_Indices3DRagged', + params=[[['a', 'b', 'c'], ['d', 'e'], ['f']], + [['g'], ['h', 'i']]], + indices=[[[2, 1, 0], [1, 1], [0]], [[0], []]], + batch_dims=2, + expected=[[['c', 'b', 'a'], ['e', 'e'], ['f']], [['g'], []]]), + # Batch gather (batch_dims=3) + dict(testcase_name='Batch3D_Params4DRagged_Indices3DRagged', + params=[[[['a', 'b', 'c'], ['d', 'e'], ['f']], + [['g'], ['h', 'i']]], [[['j']]]], + indices=[[[0, 1, 0], [0, 1]], [[0]]], + batch_dims=3, + expected=[[['a', 'e', 'f'], ['g', 'i']], [['j']]]), - def testTensorParamsAndTensorIndices(self): - params = ['a', 'b', 'c', 'd', 'e'] - indices = [2, 0, 2, 1] - self.assertAllEqual( - ragged_gather_ops.gather(params, indices), [b'c', b'a', b'c', b'b']) - self.assertIsInstance(ragged_gather_ops.gather(params, indices), ops.Tensor) - - def testRaggedParamsAndTensorIndices(self): - params = ragged_factory_ops.constant([['a', 'b'], ['c', 'd', 'e'], ['f'], - [], ['g']]) - indices = [2, 0, 2, 1] - self.assertAllEqual( - ragged_gather_ops.gather(params, indices), - [[b'f'], [b'a', b'b'], [b'f'], [b'c', b'd', b'e']]) - - def testTensorParamsAndRaggedIndices(self): - params = ['a', 'b', 'c', 'd', 'e'] - indices = ragged_factory_ops.constant([[2, 1], [1, 2, 0], [3]]) - self.assertAllEqual( - ragged_gather_ops.gather(params, indices), - [[b'c', b'b'], [b'b', b'c', b'a'], [b'd']]) - - def testRaggedParamsAndRaggedIndices(self): - params = ragged_factory_ops.constant([['a', 'b'], ['c', 'd', 'e'], ['f'], - [], ['g']]) - indices = ragged_factory_ops.constant([[2, 1], [1, 2, 0], [3]]) - self.assertAllEqual( - ragged_gather_ops.gather(params, indices), - [[[b'f'], [b'c', b'd', b'e']], # [[p[2], p[1] ], - [[b'c', b'd', b'e'], [b'f'], [b'a', b'b']], # [p[1], p[2], p[0]], - [[]]] # [p[3] ]] - ) # pyformat: disable - - def 
testRaggedParamsAndScalarIndices(self): - params = ragged_factory_ops.constant([['a', 'b'], ['c', 'd', 'e'], ['f'], - [], ['g']]) - indices = 1 - self.assertAllEqual( - ragged_gather_ops.gather(params, indices), [b'c', b'd', b'e']) - - def test3DRaggedParamsAnd2DTensorIndices(self): - params = ragged_factory_ops.constant([[['a', 'b'], []], - [['c', 'd'], ['e'], ['f']], [['g']]]) - indices = [[1, 2], [0, 1], [2, 2]] - self.assertAllEqual( - ragged_gather_ops.gather(params, indices), - [[[[b'c', b'd'], [b'e'], [b'f']], [[b'g']]], # [[p1, p2], - [[[b'a', b'b'], []], [[b'c', b'd'], [b'e'], [b'f']]], # [p0, p1], - [[[b'g']], [[b'g']]]] # [p2, p2]] - ) # pyformat: disable - - def test3DRaggedParamsAnd3DTensorIndices(self): - params = ragged_factory_ops.constant([[['a', 'b'], []], # p0 - [['c', 'd'], ['e'], ['f']], # p1 - [['g']] # p2 - ]) # pyformat: disable - indices = [[[1, 2], [0, 1], [2, 2]], [[0, 0], [1, 2], [0, 1]]] - self.assertAllEqual( - ragged_gather_ops.gather(params, indices), - [[[[[b'c', b'd'], [b'e'], [b'f']], [[b'g']]], # [[p1, p2], - [[[b'a', b'b'], []], [[b'c', b'd'], [b'e'], [b'f']]], # [p0, p1], - [[[b'g']], [[b'g']]]], # [p2, p2]] - [[[[b'a', b'b'], []], [[b'a', b'b'], []]], # [[p0, p0], - [[[b'c', b'd'], [b'e'], [b'f']], [[b'g']]], # [p1, p2], - [[[b'a', b'b'], []], [[b'c', b'd'], [b'e'], [b'f']]]]] # [p0, p1]] - ) # pyformat: disable - - def testTensorParamsAnd4DRaggedIndices(self): + ]) # pyformat: disable + def testRaggedGather(self, + params, + indices, + expected, + axis=None, + batch_dims=0, + params_ragged_rank=None, + indices_ragged_rank=None): + params = ragged_factory_ops.constant(params, ragged_rank=params_ragged_rank) indices = ragged_factory_ops.constant( - [[[[3, 4], [0, 6]], []], [[[2, 1], [1, 0]], [[2, 5]], [[2, 3]]], - [[[1, 0]]]], # pyformat: disable - ragged_rank=2, - inner_shape=(2,)) - params = ['a', 'b', 'c', 'd', 'e', 'f', 'g'] - self.assertAllEqual( - ragged_gather_ops.gather(params, indices), - [[[[b'd', b'e'], [b'a', 
b'g']], []], - [[[b'c', b'b'], [b'b', b'a']], [[b'c', b'f']], [[b'c', b'd']]], - [[[b'b', b'a']]]]) # pyformat: disable + indices, ragged_rank=indices_ragged_rank) + actual = ragged_gather_ops.gather( + params, indices, axis=axis, batch_dims=batch_dims) + self.assertAllEqual(actual, self._str_to_bytes(expected)) + + def _str_to_bytes(self, x): + if isinstance(x, list): + return [self._str_to_bytes(v) for v in x] + elif isinstance(x, str) and bytes is not str: + return bytes(x, 'utf-8') + else: + return x def testOutOfBoundsError(self): tensor_params = ['a', 'b', 'c'] @@ -154,7 +191,7 @@ class RaggedGatherOpTest(test_util.TensorFlowTestCase, parameterized.TestCase): indices = constant_op.constant([0], dtype=dtypes.int64) indices = array_ops.placeholder_with_default(indices, None) self.assertRaisesRegexp(ValueError, - r'indices\.shape\.ndims must be known statically', + r'rank\(indices\) must be known statically', ragged_gather_ops.gather, params, indices) # pylint: disable=bad-whitespace @@ -272,6 +309,87 @@ class RaggedGatherOpTest(test_util.TensorFlowTestCase, parameterized.TestCase): params_grad = params.with_flat_values(params_flat_values_grad) self.assertAllClose(params_grad, expected_grad, atol=2e-6, rtol=2e-6) + @parameterized.parameters([ + # Basic gather (batch_dims == 0, axis == 0) + dict(params_shape=[3, 4], indices_shape=[], axis=0), + dict(params_shape=[3, 4], indices_shape=[5], axis=0), + dict(params_shape=[3, 4], indices_shape=[2, 5], axis=0), + # Gather over axis (axis > 0) + dict(params_shape=[3, 4], indices_shape=[], axis=1), + dict(params_shape=[3, 4], indices_shape=[2], axis=1), + dict(params_shape=[3, 4], indices_shape=[2, 5], axis=1), + dict(params_shape=[7, 3, 1], indices_shape=[2, 4], axis=1), + dict(params_shape=[3, 4, 5, 6], indices_shape=[2, 1, 7], axis=1), + dict(params_shape=[7, 3, 5], indices_shape=[], axis=2), + dict(params_shape=[7, 3, 5], indices_shape=[2], axis=2), + dict(params_shape=[7, 3, 5], indices_shape=[4, 2], axis=2), + 
dict(params_shape=[7, 3, 5, 6], indices_shape=[4, 2], axis=2), + dict(params_shape=[7, 3, 5, 6], indices_shape=[], axis=3), + dict(params_shape=[7, 3, 5, 6], indices_shape=[4], axis=3), + dict(params_shape=[7, 3, 5, 6], indices_shape=[8, 4], axis=3), + dict(params_shape=[7, 3, 5, 6], indices_shape=[2, 3, 2, 3], axis=3), + # Batched gather (batch_dims > 0) + dict(params_shape=[7, 3], indices_shape=[7], batch_dims=1), + dict(params_shape=[7, 3], indices_shape=[7, 5], batch_dims=1), + dict(params_shape=[5, 3], indices_shape=[5, 7, 4, 2], batch_dims=1), + dict(params_shape=[2, 3, 6], indices_shape=[2], batch_dims=1), + dict(params_shape=[7, 3, 6], indices_shape=[7, 5, 4, 2], batch_dims=1), + dict(params_shape=[7, 3, 5], indices_shape=[7, 3], batch_dims=2), + dict(params_shape=[7, 3, 5], indices_shape=[7, 3, 2], batch_dims=2), + dict(params_shape=[7, 3, 5, 6], indices_shape=[7, 3, 5], batch_dims=3), + dict(params_shape=[2, 3, 5, 6], indices_shape=[2, 3, 5, 7], batch_dims=3), + # Batched gather with axis (axis > batch_dims > 0) + dict(params_shape=[2, 3, 6], indices_shape=[2], axis=2, batch_dims=1), + dict(params_shape=[2, 3, 6], indices_shape=[2, 4], axis=2, batch_dims=1), + dict( + params_shape=[3, 1, 6, 7], indices_shape=[3, 4], axis=3, + batch_dims=1), + dict( + params_shape=[3, 2, 6, 7], indices_shape=[3, 4], axis=3, + batch_dims=1), + dict( + params_shape=[2, 3, 6, 7], indices_shape=[2, 3], axis=3, + batch_dims=2), + ]) + def testMatchesDenseGather(self, + params_shape, + indices_shape, + axis=None, + batch_dims=0): + # Build random params & indices matrices w/ the expected shapes. + if axis is None: + axis = batch_dims + params = np.random.randint(100, size=params_shape, dtype=np.int32) + indices = np.random.randint( + params_shape[axis], size=indices_shape, dtype=np.int32) + + # Use array_ops.gather to get the expected value. 
+ expected = array_ops.gather( + params, indices, axis=axis, batch_dims=batch_dims) + + # Build ragged tensors with varying ragged_ranks from params & indices. + params_tensors = [params] + [ + ragged_tensor.RaggedTensor.from_tensor(params, ragged_rank=i) + for i in range(1, len(params_shape)) + ] + indices_tensors = [indices] + [ + ragged_tensor.RaggedTensor.from_tensor(indices, ragged_rank=i) + for i in range(1, len(indices_shape)) + ] + + # For each combination of params & indices tensors, check that + # ragged_gather_ops.gather matches array_ops.gather. + for params_tensor in params_tensors: + for indices_tensor in indices_tensors: + actual = ragged_gather_ops.gather( + params_tensor, indices_tensor, axis=axis, batch_dims=batch_dims) + if isinstance(actual, ragged_tensor.RaggedTensor): + actual = actual.to_tensor() + self.assertAllEqual( + expected, actual, 'params.ragged_rank=%s, indices.ragged_rank=%s' % + (getattr(params_tensor, 'ragged_rank', + 0), getattr(indices_tensor, 'ragged_rank', 0))) + if __name__ == '__main__': googletest.main() diff --git a/tensorflow/python/ops/ragged/ragged_gather_ops.py b/tensorflow/python/ops/ragged/ragged_gather_ops.py index 9b8fdb21c11..11f429a3695 100644 --- a/tensorflow/python/ops/ragged/ragged_gather_ops.py +++ b/tensorflow/python/ops/ragged/ragged_gather_ops.py @@ -33,22 +33,16 @@ from tensorflow.python.ops.ragged import ragged_tensor #=============================================================================== # ragged_gather #=============================================================================== -# TODO(edloper): Add an `axis` argument -def gather(params, indices, validate_indices=None, axis=0, batch_dims=0, +def gather(params, + indices, + validate_indices=None, + axis=None, + batch_dims=0, name=None): """Gathers ragged slices from `params` axis `0` according to `indices`. 
- Returns `RaggedTensor` output, such that: - - ```python - output.shape = indices.shape + params.shape[1:] - output.ragged_rank = indices.shape.ndims + params.ragged_rank - output[i...j, d0...dn] = params[indices[i...j], d0...dn] - ``` - - `params` may be ragged. `indices` may be ragged. - `indices` must have dtype `int32` or `int64`. If any index is out of bounds, - then an error is returned. + See `tf.gather` for full documentation. (This version has the same API + as `tf.gather`, but supports ragged `params` and `indices`.) Examples: @@ -73,8 +67,8 @@ def gather(params, indices, validate_indices=None, axis=0, batch_dims=0, Must have dtype `int32` or `int64`. Values must be in the range `[0, params.shape[0]]`. validate_indices: Ignored. - axis: Must be zero. - batch_dims: Must be zero. + axis: The axis in `params` to gather `indices` from. + batch_dims: The number of batch dimensions. name: A name for the operation (optional). Returns: @@ -86,10 +80,7 @@ def gather(params, indices, validate_indices=None, axis=0, batch_dims=0, ValueError: If indices.shape.ndims is not known statically. 
""" del validate_indices - if not isinstance(axis, int) or axis != 0: - raise ValueError('axis != 0 is not supported for ragged gather yet.') - if not isinstance(batch_dims, int) or batch_dims != 0: - raise ValueError('batch_dims != 0 is not supported for ragged gather yet.') + with ops.name_scope(name, 'RaggedGather', [params, indices]): params = ragged_tensor.convert_to_tensor_or_ragged_tensor( params, name='params') @@ -97,26 +88,240 @@ def gather(params, indices, validate_indices=None, axis=0, batch_dims=0, indices, name='indices') params, indices = ragged_tensor.match_row_splits_dtypes(params, indices) - if ragged_tensor.is_ragged(indices): - return indices.with_values(gather(params, indices.values)) + if batch_dims != indices.shape.rank: + batch_dims = array_ops.get_positive_axis( + batch_dims, + indices.shape.rank, + axis_name='batch_dims', + ndims_name='rank(indices)') + if params.shape.rank is not None and batch_dims >= params.shape.rank: + raise ValueError('batch_dims must be less than rank(params)') + if axis is None: + axis = batch_dims + axis = array_ops.get_positive_axis( + axis, params.shape.rank, ndims_name='rank(params)') + if axis < batch_dims: + raise ValueError('axis must be greater than or equal to batch_dims') + if indices.shape.rank is not None: + if not 0 <= batch_dims <= indices.shape.rank: + raise ValueError( + 'batch_dims=%s must be between 0 and rank(indices)=%s' % + (batch_dims, indices.shape.rank)) - if not ragged_tensor.is_ragged(params): - return array_ops.gather(params, indices) + return _gather(params, indices, axis, batch_dims) - indices = ops.convert_to_tensor(indices) - if indices.shape.ndims is None: - raise ValueError('indices.shape.ndims must be known statically') - result = gen_ragged_array_ops.ragged_gather( - indices=indices, - params_dense_values=params.flat_values, - params_nested_splits=params.nested_row_splits, - OUTPUT_RAGGED_RANK=indices.shape.ndims + len(params.nested_row_splits) - - 1) +def _gather(params, indices, 
axis, batch_dims): + """Helper that implements the body for ragged gather(). - # Compose the RaggedTensor from splits & values. - return ragged_tensor.RaggedTensor.from_nested_row_splits( - result.output_dense_values, result.output_nested_splits, validate=False) + Assumes that `params` and `indices` have been converted to tensors or + ragged tensors, and that `axis` and `batch_dims` have been normalized to + be positive. (So these conversions & normalizations can be skipped in + recursive calls to _gather). + + Args: + params: The tensor from which to gather values. + indices: The indices of values to gather. + axis: The axis in `params` to gather `indices` from. + batch_dims: The number of batch dimensions. + + Returns: + A potentially ragged tensor. + """ + params_is_ragged = ragged_tensor.is_ragged(params) + indices_is_ragged = ragged_tensor.is_ragged(indices) + + if not (params_is_ragged or indices_is_ragged): + return array_ops.gather(params, indices, axis=axis, batch_dims=batch_dims) + + if batch_dims > 0: + return _batch_gather(params, indices, axis, batch_dims) + + if axis > 0: + return _axis_gather(params, indices, axis) + + if indices_is_ragged: + return indices.with_values(_gather(params, indices.values, 0, 0)) + + if indices.shape.ndims is None: + raise ValueError('rank(indices) must be known statically') + + out_ragged_rank = indices.shape.ndims + len(params.nested_row_splits) - 1 + result = gen_ragged_array_ops.ragged_gather( + indices=indices, + params_dense_values=params.flat_values, + params_nested_splits=params.nested_row_splits, + OUTPUT_RAGGED_RANK=out_ragged_rank) + + result = ragged_tensor.RaggedTensor.from_nested_row_splits( + result.output_dense_values, result.output_nested_splits, validate=False) + + # Inject uniform_row_lengths into the result RaggedTensors for dimensions + # corresponding to dense outer dimensions of `indices`. 
+ # TODO(edloper): Change this to construct the result using RowPartition + # objects instead, so we don't need to modify private variables. + if indices.shape.ndims > 1: + target = result + indices_shape = array_ops.shape(indices, out_type=params.row_splits.dtype) + shape_cumprod = math_ops.cumprod(indices_shape) + for dim in range(indices.shape.ndims - 1): + # pylint: disable=protected-access + target._cached_nrows = shape_cumprod[dim] + target._uniform_row_length = indices_shape[dim + 1] + target = target.values + + return result + + +def _batch_gather(params, indices, axis, batch_dims): + """Helper that implements the body for ragged gather() when batch_dims>0. + + Args: + params: The tensor from which to gather values. + indices: The indices of values to gather. + axis: The axis in `params` to gather `indices` from. + batch_dims: The number of batch dimensions. + + Returns: + A potentially ragged tensor. + """ + # Perform static checks that `params` and `indices` have compatible batch + # dimensions. Note: we do not perform *runtime* checks that `params` and + # `indices` actually have the same row-splits (because we wish to avoid the + # runtime cost of those checks). If `params` and `indices` are + # incompatible, the resulting `RaggedTensor` may be nonsensical. + if not params.shape[:batch_dims].is_compatible_with( + indices.shape[:batch_dims]): + raise ValueError('batch shape from indices %s does not match params ' + 'shape %s' % (indices.shape[:batch_dims], params.shape)) + + if batch_dims > 1: + # Convert params & indices to ragged tensors. 
+ if not isinstance(params, ragged_tensor.RaggedTensor): + if indices.uniform_row_length is None: + raise ValueError( + 'batch shape from indices does not match params shape: ragged ' + 'indices dimension corresponds to uniform params dimension') + params = ragged_tensor.RaggedTensor.from_tensor( + params, ragged_rank=1, row_splits_dtype=indices.row_splits.dtype) + if not isinstance(indices, ragged_tensor.RaggedTensor): + if params.uniform_row_length is None: + raise ValueError( + 'batch shape from indices does not match params shape: ragged ' + 'params dimension corresponds to uniform indices dimension') + indices = ragged_tensor.RaggedTensor.from_tensor( + indices, ragged_rank=1, row_splits_dtype=params.row_splits.dtype) + # Flatten the two outer batch dimensions into a single batch dimension, + # and recurse. + return params.with_values( + _gather(params.values, indices.values, axis - 1, batch_dims - 1)) + + if axis > 1: + # Convert an axis dimension into a batch dimension, by adding a dimension + # to `indices`, and tiling it to match `params`. E.g., if `params` + # had shape `[B, P1, P2]`, and `indices` had shape `[B, I1, I2]`, then we + # tile `indices` to have shape `[B, P1, I1, I2]`. That way, we can treat + # the `P1` dimension as a batch dimension. 
+ if not isinstance(indices, ragged_tensor.RaggedTensor): + adjusted_indices = params.with_values( + array_ops.repeat(indices, params.row_lengths(), 0)) + else: + if not isinstance(params, ragged_tensor.RaggedTensor): + params = ragged_tensor.RaggedTensor.from_tensor( + params, ragged_rank=1, row_splits_dtype=indices.row_splits.dtype) + adjusted_indices = _gather( + indices, + params.with_values( + array_ops.repeat( + math_ops.range(params.nrows()), params.row_lengths())), 0, 0) + return _batch_gather(params, adjusted_indices, axis, batch_dims + 1) + + if indices.shape.rank is None: + raise ValueError('rank(indices) must be known statically') + + assert batch_dims == 1 + # If params.shape=[B, P1...PN] and indices.shape=[B, I1...IM], then: + # + # output[b, i1...im, p2...pn] = + # params[b, indices[b, i1...im], p2...pn] + # + # We construct `output` by flattening `params`, adjusting the `indices` to + # point into that flattened list, and recursively calling `gather`. + flat_params = _flatten_dims_0_and_1(params) + adjustments = _row_starts(params, indices.dtype) # offset for each batch + # increase adjustments's rank so it broadcasts w/ the outer dim of indices + adjustments = _increase_rank_to(adjustments, indices.shape.ndims) + adjusted_indices = indices + adjustments + return _gather(flat_params, adjusted_indices, axis - 1, 0) + + +def _axis_gather(params, indices, axis): + """Helper that implements ragged gather when axis>0 and batch_dims==0. + + Args: + params: The tensor from which to gather values. + indices: The indices of values to gather. + axis: The axis in `params` to gather `indices` from. + + Returns: + A potentially ragged tensor. + """ + if axis > 1: + if not isinstance(params, ragged_tensor.RaggedTensor): + params = ragged_tensor.RaggedTensor.from_tensor( + params, ragged_rank=1, row_splits_dtype=indices.row_splits.dtype) + # Recurse, using the flattened params (but do not flatten indices). 
+ return params.with_values(_gather(params.values, indices, axis - 1, 0)) + + if indices.shape.rank is None: + raise ValueError('rank(indices) must be known statically') + if (isinstance(params, ragged_tensor.RaggedTensor) and + params.uniform_row_length is None): + raise ValueError('axis may not be a ragged dimension') + + assert axis == 1 + # If params.shape=[P1...PN] and indices.shape=[I1...IM], then: + # + # output[p1, i1...im, p3...pn] = + # params[p1, indices[i1...im], p3...pn] + # + # We construct `output` by flattening `params`, adjusting the `indices` to + # have one additional dimension, and to point into that flattened list, and + # recursively calling `gather`. + flat_params = _flatten_dims_0_and_1(params) + adjustments = _row_starts(params, indices.dtype) # offset for each batch + adjustments = _increase_rank_to(adjustments, indices.shape.ndims + 1) + adjusted_indices = indices + adjustments + return _gather(flat_params, adjusted_indices, axis - 1, 0) + + +def _flatten_dims_0_and_1(t): + """Returns a copy of `t` with the outer two dimensions merged.""" + if isinstance(t, ragged_tensor.RaggedTensor): + return t.values + else: + t_shape = array_ops.shape(t) + return array_ops.reshape(t, array_ops.concat([[-1], t_shape[2:]], axis=0)) + + +def _row_starts(t, dtype): + """Returns the start indices for the rows in `t`.""" + if isinstance(t, ragged_tensor.RaggedTensor): + return math_ops.cast(t.row_starts(), dtype) + else: + t_shape = array_ops.shape(t, out_type=dtype) + return math_ops.range(t_shape[0]) * t_shape[1] + + +def _increase_rank_to(t, rank): + """Adds *trailing* size-1 dimensions to `t` until it has the given rank.""" + if isinstance(t, ragged_tensor.RaggedTensor): + return t.with_values(_increase_rank_to(t, rank - 1)) + else: + old_dims = array_ops.shape(t) + new_dims = array_ops.ones([rank - array_ops.rank(t)], old_dims.dtype) + new_shape = array_ops.concat([old_dims, new_dims], axis=0) + return array_ops.reshape(t, new_shape) 
#===============================================================================