Extend the ragged version of tf.gather to support batch_dims
and axis
args.
PiperOrigin-RevId: 299158220 Change-Id: I8cac49a4e2bac64c867c0997aa8f829dc569eec4
This commit is contained in:
parent
9d7566d0cb
commit
e2ff7f453e
@ -521,6 +521,7 @@ py_test(
|
||||
name = "ragged_gather_op_test",
|
||||
srcs = ["ragged_gather_op_test.py"],
|
||||
python_version = "PY3",
|
||||
shard_count = 4,
|
||||
srcs_version = "PY2AND3",
|
||||
deps = [
|
||||
":ragged_factory_ops",
|
||||
|
@ -476,12 +476,12 @@ class RaggedBatchGatherOpTest(test_util.TensorFlowTestCase,
|
||||
ragged_indices = ragged_tensor.RaggedTensor.from_row_splits(
|
||||
indices, [0, 2, 4])
|
||||
|
||||
with self.assertRaisesRegexp(
|
||||
ValueError, 'batch_gather does not allow indices with unknown shape.'):
|
||||
with self.assertRaisesRegexp(ValueError, r'batch_dims may only be negative '
|
||||
r'if rank\(indices\) is statically known.'):
|
||||
ragged_batch_gather_ops.batch_gather(params, indices)
|
||||
|
||||
with self.assertRaisesRegexp(
|
||||
ValueError, 'batch_gather does not allow indices with unknown shape.'):
|
||||
with self.assertRaisesRegexp(ValueError, r'batch_dims may only be negative '
|
||||
r'if rank\(indices\) is statically known.'):
|
||||
ragged_batch_gather_ops.batch_gather(params, ragged_indices)
|
||||
|
||||
@parameterized.parameters(
|
||||
@ -489,7 +489,7 @@ class RaggedBatchGatherOpTest(test_util.TensorFlowTestCase,
|
||||
dict(
|
||||
params=ragged_factory_ops.constant_value([['a'], ['b'], ['c']]),
|
||||
indices=ragged_factory_ops.constant_value([[0], [0]]),
|
||||
message='Dimensions 3 and 2 are not compatible'),
|
||||
message=(r'batch shape from indices .* does not match params')),
|
||||
dict(
|
||||
params=[[[1, 2], [3, 4]], [[5, 6], [7, 8]]],
|
||||
indices=ragged_factory_ops.constant_value([[[0, 0], [0, 0, 0]],
|
||||
@ -506,20 +506,21 @@ class RaggedBatchGatherOpTest(test_util.TensorFlowTestCase,
|
||||
[[0]], [[0]]]),
|
||||
indices=ragged_factory_ops.constant_value([[[0, 0]], [[0, 0, 0]],
|
||||
[[0]]]),
|
||||
error=errors.InvalidArgumentError,
|
||||
message='.*Condition x == y did not hold.*'),
|
||||
error=(ValueError, errors.InvalidArgumentError),
|
||||
message=(r'batch shape from indices .* does not match '
|
||||
r'params shape|dimension size mismatch')),
|
||||
dict(
|
||||
params=ragged_factory_ops.constant_value(['a', 'b', 'c']),
|
||||
indices=ragged_factory_ops.constant_value([[0], [0]]),
|
||||
message='batch shape from indices does not match params shape'),
|
||||
message=r'batch_dims must be less than rank\(params\)'),
|
||||
dict(
|
||||
params=ragged_factory_ops.constant_value([['a']]),
|
||||
indices=0,
|
||||
message='indices.rank must be at least 1.'),
|
||||
message='batch_dims=-1 out of bounds: expected 0<=batch_dims<0'),
|
||||
dict(
|
||||
params=ragged_factory_ops.constant_value([['a']]),
|
||||
indices=[[[0]]],
|
||||
message='batch shape from indices does not match params shape'),
|
||||
message=r'batch_dims must be less than rank\(params\)'),
|
||||
])
|
||||
def testRaggedBatchGatherStaticError(self,
|
||||
params,
|
||||
|
@ -18,13 +18,7 @@ from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import check_ops
|
||||
from tensorflow.python.ops import math_ops
|
||||
from tensorflow.python.ops.ragged import ragged_gather_ops
|
||||
from tensorflow.python.ops.ragged import ragged_tensor
|
||||
from tensorflow.python.ops.ragged import ragged_util
|
||||
|
||||
|
||||
#===============================================================================
|
||||
@ -61,64 +55,4 @@ def batch_gather(params, indices, name=None):
|
||||
>>> tf.compat.v1.batch_gather(params, indices)
|
||||
<tf.RaggedTensor [[b'b', b'c', b'a'], [], [], [b'e', b'e']]>
|
||||
"""
|
||||
if not (ragged_tensor.is_ragged(params) or ragged_tensor.is_ragged(indices)):
|
||||
return array_ops.batch_gather(params, indices, name)
|
||||
|
||||
with ops.name_scope(name, 'RaggedBatchGather', [params, indices]):
|
||||
params = ragged_tensor.convert_to_tensor_or_ragged_tensor(
|
||||
params, name='params')
|
||||
indices = ragged_tensor.convert_to_tensor_or_ragged_tensor(
|
||||
indices, name='indices')
|
||||
params, indices = ragged_tensor.match_row_splits_dtypes(params, indices)
|
||||
indices_ndims = indices.shape.ndims
|
||||
if indices_ndims is None:
|
||||
raise ValueError(
|
||||
'batch_gather does not allow indices with unknown shape.')
|
||||
if indices_ndims == 0:
|
||||
raise ValueError('indices.rank must be at least 1.')
|
||||
|
||||
if ragged_tensor.is_ragged(indices):
|
||||
# If the outermost ragged dimension is a batch dimension, recurse.
|
||||
if indices_ndims > 2:
|
||||
if not ragged_tensor.is_ragged(params):
|
||||
raise ValueError('batch shape from indices does '
|
||||
'not match params shape')
|
||||
checks = [check_ops.assert_equal(params.row_splits, indices.row_splits)]
|
||||
with ops.control_dependencies(checks):
|
||||
return ragged_tensor.RaggedTensor.from_row_splits(
|
||||
batch_gather(params.values, indices.values), indices.row_splits,
|
||||
validate=False)
|
||||
|
||||
# Otherwise, indices is a 2D ragged tensor with 1 ragged dimension.
|
||||
else:
|
||||
# Ensure that `params` is ragged and has at least 2 dimensions.
|
||||
if not ragged_tensor.is_ragged(params):
|
||||
if params.shape.ndims is not None and params.shape.ndims < 2:
|
||||
raise ValueError('batch shape from indices does '
|
||||
'not match params shape')
|
||||
params = ragged_tensor.RaggedTensor.from_tensor(
|
||||
params, ragged_rank=1,
|
||||
row_splits_dtype=indices.row_splits.dtype)
|
||||
|
||||
# Adjust indices from within-batch to global (in params.values), and
|
||||
# then use ragged.gather to gather them.
|
||||
num_indices = indices.row_lengths()
|
||||
params_starts = params.row_starts()
|
||||
adjustments = ragged_util.repeat(params_starts, num_indices, axis=0)
|
||||
adjusted_index_values = (
|
||||
math_ops.cast(indices.values, adjustments.dtype) + adjustments)
|
||||
return ragged_tensor.RaggedTensor.from_row_splits(
|
||||
ragged_gather_ops.gather(params.values, adjusted_index_values),
|
||||
indices.row_splits, validate=False)
|
||||
|
||||
else: # params is a RaggedTensor and indices is a Tensor.
|
||||
if indices_ndims == 1:
|
||||
return ragged_gather_ops.gather(params, indices)
|
||||
elif indices_ndims == 2:
|
||||
# Adjust indices from batch-local to global (in params.values)
|
||||
adjustments = array_ops.expand_dims(params.row_starts(), 1)
|
||||
adjusted_indices = (
|
||||
math_ops.cast(indices, adjustments.dtype) + adjustments)
|
||||
return ragged_gather_ops.gather(params.values, adjusted_indices)
|
||||
else:
|
||||
raise ValueError('batch shape from indices does not match params shape')
|
||||
return ragged_gather_ops.gather(params, indices, batch_dims=-1, name=name)
|
||||
|
@ -20,6 +20,8 @@ from __future__ import print_function
|
||||
|
||||
from absl.testing import parameterized
|
||||
|
||||
import numpy as np
|
||||
|
||||
from tensorflow.python.eager import context
|
||||
from tensorflow.python.framework import constant_op
|
||||
from tensorflow.python.framework import dtypes
|
||||
@ -31,106 +33,141 @@ from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import gradients_impl
|
||||
from tensorflow.python.ops.ragged import ragged_factory_ops
|
||||
from tensorflow.python.ops.ragged import ragged_gather_ops
|
||||
from tensorflow.python.ops.ragged import ragged_tensor
|
||||
from tensorflow.python.platform import googletest
|
||||
|
||||
|
||||
@test_util.run_all_in_graph_and_eager_modes
|
||||
class RaggedGatherOpTest(test_util.TensorFlowTestCase, parameterized.TestCase):
|
||||
|
||||
def testDocStringExamples(self):
|
||||
params = constant_op.constant(['a', 'b', 'c', 'd', 'e'])
|
||||
indices = constant_op.constant([3, 1, 2, 1, 0])
|
||||
ragged_params = ragged_factory_ops.constant([['a', 'b', 'c'], ['d'], [],
|
||||
['e']])
|
||||
ragged_indices = ragged_factory_ops.constant([[3, 1, 2], [1], [], [0]])
|
||||
self.assertAllEqual(
|
||||
ragged_gather_ops.gather(params, ragged_indices),
|
||||
[[b'd', b'b', b'c'], [b'b'], [], [b'a']])
|
||||
self.assertAllEqual(
|
||||
ragged_gather_ops.gather(ragged_params, indices),
|
||||
[[b'e'], [b'd'], [], [b'd'], [b'a', b'b', b'c']])
|
||||
self.assertAllEqual(
|
||||
ragged_gather_ops.gather(ragged_params, ragged_indices),
|
||||
[[[b'e'], [b'd'], []], [[b'd']], [], [[b'a', b'b', b'c']]])
|
||||
@parameterized.named_parameters([
|
||||
# Basic gather (axis=0 and batch_dims=0)
|
||||
dict(testcase_name='Params1DTensor_Indices1DTensor',
|
||||
params=['a', 'b', 'c', 'd', 'e'],
|
||||
indices=[2, 0, 2, 1],
|
||||
expected=['c', 'a', 'c', 'b']),
|
||||
dict(testcase_name='Params1DTensor_Indices2DRagged',
|
||||
params=['a', 'b', 'c', 'd', 'e'],
|
||||
indices=[[3, 1, 2], [1], [], [0]],
|
||||
expected=[['d', 'b', 'c'], ['b'], [], ['a']]),
|
||||
dict(testcase_name='Params2DRagged_Indices0DTensor',
|
||||
params=[['a', 'b'], ['c', 'd', 'e'], ['f'], [], ['g']],
|
||||
indices=1,
|
||||
expected=['c', 'd', 'e']),
|
||||
dict(testcase_name='Params2DRagged_Indices1DTensor',
|
||||
params=[['a', 'b', 'c'], ['d'], [], ['e']],
|
||||
indices=[3, 1, 2, 1, 0],
|
||||
expected=[
|
||||
['e'], ['d'], [], ['d'], ['a', 'b', 'c']]),
|
||||
dict(testcase_name='Params2DRagged_Indices2DRagged',
|
||||
params=[['a', 'b', 'c'], ['d'], [], ['e']],
|
||||
indices=[[3, 1, 2], [1], [], [0]],
|
||||
expected=[
|
||||
[['e'], ['d'], []], [['d']], [], [['a', 'b', 'c']]]),
|
||||
dict(testcase_name='Params3DRagged_Indices2DTensor',
|
||||
params=[
|
||||
[['a', 'b'], []], [['c', 'd'], ['e'], ['f']], [['g']]],
|
||||
indices=[[1, 2], [0, 1], [2, 2]],
|
||||
indices_ragged_rank=0,
|
||||
expected=[
|
||||
[[['c', 'd'], ['e'], ['f']], [['g']]],
|
||||
[[['a', 'b'], []], [['c', 'd'], ['e'], ['f']]],
|
||||
[[['g']], [['g']]]]),
|
||||
dict(testcase_name='Params3DRagged_Indices3DTensor',
|
||||
params=[[['a', 'b'], []],
|
||||
[['c', 'd'], ['e'], ['f']],
|
||||
[['g']]],
|
||||
indices=[[[1, 2], [0, 1], [2, 2]], [[0, 0], [1, 2], [0, 1]]],
|
||||
indices_ragged_rank=0,
|
||||
expected=[
|
||||
[[[['c', 'd'], ['e'], ['f']], [['g']]],
|
||||
[[['a', 'b'], []], [['c', 'd'], ['e'], ['f']]],
|
||||
[[['g']], [['g']]]],
|
||||
[[[['a', 'b'], []], [['a', 'b'], []]],
|
||||
[[['c', 'd'], ['e'], ['f']], [['g']]],
|
||||
[[['a', 'b'], []], [['c', 'd'], ['e'], ['f']]]]]),
|
||||
dict(testcase_name='Params1DTensor_Indices4DRaggedRank2',
|
||||
params=['a', 'b', 'c', 'd', 'e', 'f', 'g'],
|
||||
indices=[[[[3, 4], [0, 6]], []],
|
||||
[[[2, 1], [1, 0]], [[2, 5]], [[2, 3]]],
|
||||
[[[1, 0]]]],
|
||||
indices_ragged_rank=2,
|
||||
expected=[
|
||||
[[['d', 'e'], ['a', 'g']], []],
|
||||
[[['c', 'b'], ['b', 'a']], [['c', 'f']], [['c', 'd']]],
|
||||
[[['b', 'a']]]]),
|
||||
# Batch gather (batch_dims=1)
|
||||
dict(testcase_name='Batch1D_Params2DRagged_Indices1DTensor',
|
||||
params=[['a', 'b'], ['c'], ['d', 'e', 'f', 'g'], ['h']],
|
||||
indices=[1, 0, 3, 0],
|
||||
batch_dims=1,
|
||||
expected=['b', 'c', 'g', 'h']),
|
||||
dict(testcase_name='Batch1D_Params2DRagged_Indices2DTensor',
|
||||
params=[['a', 'b'], ['c'], ['d', 'e', 'f', 'g'], ['h']],
|
||||
indices=[[1, 0], [0, 0], [3, 1], [0, 0]],
|
||||
indices_ragged_rank=0,
|
||||
batch_dims=1,
|
||||
expected=[['b', 'a'], ['c', 'c'], ['g', 'e'], ['h', 'h']]),
|
||||
dict(testcase_name='Batch1D_Params2DRagged_Indices2DRagged',
|
||||
params=[['a', 'b'], ['c'], ['d', 'e', 'f', 'g'], ['h']],
|
||||
indices=[[1, 0], [], [3, 2, 1], [0]],
|
||||
batch_dims=1,
|
||||
expected=[['b', 'a'], [], ['g', 'f', 'e'], ['h']]),
|
||||
dict(testcase_name='Batch1D_Params3DRagged_Indices3DRagged',
|
||||
params=[[['a'], ['b', 'c']],
|
||||
[],
|
||||
[['d', 'e', 'f'], ['g'], ['h', 'i'], ['j']],
|
||||
[['k']]],
|
||||
indices=[[[1, 0], []], [], [[3, 2, 1], [0]], [[0]]],
|
||||
batch_dims=1,
|
||||
expected=[[[['b', 'c'], ['a']], []],
|
||||
[],
|
||||
[[['j'], ['h', 'i'], ['g']], [['d', 'e', 'f']]],
|
||||
[[['k']]]]),
|
||||
# Batch gather (batch_dims=2)
|
||||
dict(testcase_name='Batch2D_Params3DRagged_Indices2DRagged',
|
||||
params=[[['a', 'b', 'c'], ['d', 'e'], ['f']],
|
||||
[['g'], ['h', 'i']]],
|
||||
indices=[[0, 1, 0], [0, 1]],
|
||||
batch_dims=2,
|
||||
expected=[['a', 'e', 'f'], ['g', 'i']]),
|
||||
dict(testcase_name='Batch2D_Params3DRagged_Indices3DRagged',
|
||||
params=[[['a', 'b', 'c'], ['d', 'e'], ['f']],
|
||||
[['g'], ['h', 'i']]],
|
||||
indices=[[[2, 1, 0], [1, 1], [0]], [[0], []]],
|
||||
batch_dims=2,
|
||||
expected=[[['c', 'b', 'a'], ['e', 'e'], ['f']], [['g'], []]]),
|
||||
# Batch gather (batch_dims=3)
|
||||
dict(testcase_name='Batch3D_Params4DRagged_Indices3DRagged',
|
||||
params=[[[['a', 'b', 'c'], ['d', 'e'], ['f']],
|
||||
[['g'], ['h', 'i']]], [[['j']]]],
|
||||
indices=[[[0, 1, 0], [0, 1]], [[0]]],
|
||||
batch_dims=3,
|
||||
expected=[[['a', 'e', 'f'], ['g', 'i']], [['j']]]),
|
||||
|
||||
def testTensorParamsAndTensorIndices(self):
|
||||
params = ['a', 'b', 'c', 'd', 'e']
|
||||
indices = [2, 0, 2, 1]
|
||||
self.assertAllEqual(
|
||||
ragged_gather_ops.gather(params, indices), [b'c', b'a', b'c', b'b'])
|
||||
self.assertIsInstance(ragged_gather_ops.gather(params, indices), ops.Tensor)
|
||||
|
||||
def testRaggedParamsAndTensorIndices(self):
|
||||
params = ragged_factory_ops.constant([['a', 'b'], ['c', 'd', 'e'], ['f'],
|
||||
[], ['g']])
|
||||
indices = [2, 0, 2, 1]
|
||||
self.assertAllEqual(
|
||||
ragged_gather_ops.gather(params, indices),
|
||||
[[b'f'], [b'a', b'b'], [b'f'], [b'c', b'd', b'e']])
|
||||
|
||||
def testTensorParamsAndRaggedIndices(self):
|
||||
params = ['a', 'b', 'c', 'd', 'e']
|
||||
indices = ragged_factory_ops.constant([[2, 1], [1, 2, 0], [3]])
|
||||
self.assertAllEqual(
|
||||
ragged_gather_ops.gather(params, indices),
|
||||
[[b'c', b'b'], [b'b', b'c', b'a'], [b'd']])
|
||||
|
||||
def testRaggedParamsAndRaggedIndices(self):
|
||||
params = ragged_factory_ops.constant([['a', 'b'], ['c', 'd', 'e'], ['f'],
|
||||
[], ['g']])
|
||||
indices = ragged_factory_ops.constant([[2, 1], [1, 2, 0], [3]])
|
||||
self.assertAllEqual(
|
||||
ragged_gather_ops.gather(params, indices),
|
||||
[[[b'f'], [b'c', b'd', b'e']], # [[p[2], p[1] ],
|
||||
[[b'c', b'd', b'e'], [b'f'], [b'a', b'b']], # [p[1], p[2], p[0]],
|
||||
[[]]] # [p[3] ]]
|
||||
) # pyformat: disable
|
||||
|
||||
def testRaggedParamsAndScalarIndices(self):
|
||||
params = ragged_factory_ops.constant([['a', 'b'], ['c', 'd', 'e'], ['f'],
|
||||
[], ['g']])
|
||||
indices = 1
|
||||
self.assertAllEqual(
|
||||
ragged_gather_ops.gather(params, indices), [b'c', b'd', b'e'])
|
||||
|
||||
def test3DRaggedParamsAnd2DTensorIndices(self):
|
||||
params = ragged_factory_ops.constant([[['a', 'b'], []],
|
||||
[['c', 'd'], ['e'], ['f']], [['g']]])
|
||||
indices = [[1, 2], [0, 1], [2, 2]]
|
||||
self.assertAllEqual(
|
||||
ragged_gather_ops.gather(params, indices),
|
||||
[[[[b'c', b'd'], [b'e'], [b'f']], [[b'g']]], # [[p1, p2],
|
||||
[[[b'a', b'b'], []], [[b'c', b'd'], [b'e'], [b'f']]], # [p0, p1],
|
||||
[[[b'g']], [[b'g']]]] # [p2, p2]]
|
||||
) # pyformat: disable
|
||||
|
||||
def test3DRaggedParamsAnd3DTensorIndices(self):
|
||||
params = ragged_factory_ops.constant([[['a', 'b'], []], # p0
|
||||
[['c', 'd'], ['e'], ['f']], # p1
|
||||
[['g']] # p2
|
||||
]) # pyformat: disable
|
||||
indices = [[[1, 2], [0, 1], [2, 2]], [[0, 0], [1, 2], [0, 1]]]
|
||||
self.assertAllEqual(
|
||||
ragged_gather_ops.gather(params, indices),
|
||||
[[[[[b'c', b'd'], [b'e'], [b'f']], [[b'g']]], # [[p1, p2],
|
||||
[[[b'a', b'b'], []], [[b'c', b'd'], [b'e'], [b'f']]], # [p0, p1],
|
||||
[[[b'g']], [[b'g']]]], # [p2, p2]]
|
||||
[[[[b'a', b'b'], []], [[b'a', b'b'], []]], # [[p0, p0],
|
||||
[[[b'c', b'd'], [b'e'], [b'f']], [[b'g']]], # [p1, p2],
|
||||
[[[b'a', b'b'], []], [[b'c', b'd'], [b'e'], [b'f']]]]] # [p0, p1]]
|
||||
) # pyformat: disable
|
||||
|
||||
def testTensorParamsAnd4DRaggedIndices(self):
|
||||
def testRaggedGather(self,
|
||||
params,
|
||||
indices,
|
||||
expected,
|
||||
axis=None,
|
||||
batch_dims=0,
|
||||
params_ragged_rank=None,
|
||||
indices_ragged_rank=None):
|
||||
params = ragged_factory_ops.constant(params, ragged_rank=params_ragged_rank)
|
||||
indices = ragged_factory_ops.constant(
|
||||
[[[[3, 4], [0, 6]], []], [[[2, 1], [1, 0]], [[2, 5]], [[2, 3]]],
|
||||
[[[1, 0]]]], # pyformat: disable
|
||||
ragged_rank=2,
|
||||
inner_shape=(2,))
|
||||
params = ['a', 'b', 'c', 'd', 'e', 'f', 'g']
|
||||
self.assertAllEqual(
|
||||
ragged_gather_ops.gather(params, indices),
|
||||
[[[[b'd', b'e'], [b'a', b'g']], []],
|
||||
[[[b'c', b'b'], [b'b', b'a']], [[b'c', b'f']], [[b'c', b'd']]],
|
||||
[[[b'b', b'a']]]]) # pyformat: disable
|
||||
indices, ragged_rank=indices_ragged_rank)
|
||||
actual = ragged_gather_ops.gather(
|
||||
params, indices, axis=axis, batch_dims=batch_dims)
|
||||
self.assertAllEqual(actual, self._str_to_bytes(expected))
|
||||
|
||||
def _str_to_bytes(self, x):
|
||||
if isinstance(x, list):
|
||||
return [self._str_to_bytes(v) for v in x]
|
||||
elif isinstance(x, str) and bytes is not str:
|
||||
return bytes(x, 'utf-8')
|
||||
else:
|
||||
return x
|
||||
|
||||
def testOutOfBoundsError(self):
|
||||
tensor_params = ['a', 'b', 'c']
|
||||
@ -154,7 +191,7 @@ class RaggedGatherOpTest(test_util.TensorFlowTestCase, parameterized.TestCase):
|
||||
indices = constant_op.constant([0], dtype=dtypes.int64)
|
||||
indices = array_ops.placeholder_with_default(indices, None)
|
||||
self.assertRaisesRegexp(ValueError,
|
||||
r'indices\.shape\.ndims must be known statically',
|
||||
r'rank\(indices\) must be known statically',
|
||||
ragged_gather_ops.gather, params, indices)
|
||||
|
||||
# pylint: disable=bad-whitespace
|
||||
@ -272,6 +309,87 @@ class RaggedGatherOpTest(test_util.TensorFlowTestCase, parameterized.TestCase):
|
||||
params_grad = params.with_flat_values(params_flat_values_grad)
|
||||
self.assertAllClose(params_grad, expected_grad, atol=2e-6, rtol=2e-6)
|
||||
|
||||
@parameterized.parameters([
|
||||
# Basic gather (batch_dims == 0, axis == 0)
|
||||
dict(params_shape=[3, 4], indices_shape=[], axis=0),
|
||||
dict(params_shape=[3, 4], indices_shape=[5], axis=0),
|
||||
dict(params_shape=[3, 4], indices_shape=[2, 5], axis=0),
|
||||
# Gather over axis (axis > 0)
|
||||
dict(params_shape=[3, 4], indices_shape=[], axis=1),
|
||||
dict(params_shape=[3, 4], indices_shape=[2], axis=1),
|
||||
dict(params_shape=[3, 4], indices_shape=[2, 5], axis=1),
|
||||
dict(params_shape=[7, 3, 1], indices_shape=[2, 4], axis=1),
|
||||
dict(params_shape=[3, 4, 5, 6], indices_shape=[2, 1, 7], axis=1),
|
||||
dict(params_shape=[7, 3, 5], indices_shape=[], axis=2),
|
||||
dict(params_shape=[7, 3, 5], indices_shape=[2], axis=2),
|
||||
dict(params_shape=[7, 3, 5], indices_shape=[4, 2], axis=2),
|
||||
dict(params_shape=[7, 3, 5, 6], indices_shape=[4, 2], axis=2),
|
||||
dict(params_shape=[7, 3, 5, 6], indices_shape=[], axis=3),
|
||||
dict(params_shape=[7, 3, 5, 6], indices_shape=[4], axis=3),
|
||||
dict(params_shape=[7, 3, 5, 6], indices_shape=[8, 4], axis=3),
|
||||
dict(params_shape=[7, 3, 5, 6], indices_shape=[2, 3, 2, 3], axis=3),
|
||||
# Batched gather (batch_dims > 0)
|
||||
dict(params_shape=[7, 3], indices_shape=[7], batch_dims=1),
|
||||
dict(params_shape=[7, 3], indices_shape=[7, 5], batch_dims=1),
|
||||
dict(params_shape=[5, 3], indices_shape=[5, 7, 4, 2], batch_dims=1),
|
||||
dict(params_shape=[2, 3, 6], indices_shape=[2], batch_dims=1),
|
||||
dict(params_shape=[7, 3, 6], indices_shape=[7, 5, 4, 2], batch_dims=1),
|
||||
dict(params_shape=[7, 3, 5], indices_shape=[7, 3], batch_dims=2),
|
||||
dict(params_shape=[7, 3, 5], indices_shape=[7, 3, 2], batch_dims=2),
|
||||
dict(params_shape=[7, 3, 5, 6], indices_shape=[7, 3, 5], batch_dims=3),
|
||||
dict(params_shape=[2, 3, 5, 6], indices_shape=[2, 3, 5, 7], batch_dims=3),
|
||||
# Batched gather with axis (axis > batch_dims > 0)
|
||||
dict(params_shape=[2, 3, 6], indices_shape=[2], axis=2, batch_dims=1),
|
||||
dict(params_shape=[2, 3, 6], indices_shape=[2, 4], axis=2, batch_dims=1),
|
||||
dict(
|
||||
params_shape=[3, 1, 6, 7], indices_shape=[3, 4], axis=3,
|
||||
batch_dims=1),
|
||||
dict(
|
||||
params_shape=[3, 2, 6, 7], indices_shape=[3, 4], axis=3,
|
||||
batch_dims=1),
|
||||
dict(
|
||||
params_shape=[2, 3, 6, 7], indices_shape=[2, 3], axis=3,
|
||||
batch_dims=2),
|
||||
])
|
||||
def testMatchesDenseGather(self,
|
||||
params_shape,
|
||||
indices_shape,
|
||||
axis=None,
|
||||
batch_dims=0):
|
||||
# Build random params & indices matrics w/ the expected shapes.
|
||||
if axis is None:
|
||||
axis = batch_dims
|
||||
params = np.random.randint(100, size=params_shape, dtype=np.int32)
|
||||
indices = np.random.randint(
|
||||
params_shape[axis], size=indices_shape, dtype=np.int32)
|
||||
|
||||
# Use array_ops.gather to get the expected value.
|
||||
expected = array_ops.gather(
|
||||
params, indices, axis=axis, batch_dims=batch_dims)
|
||||
|
||||
# Build ragged tensors with varying ragged_ranks from params & axis.
|
||||
params_tensors = [params] + [
|
||||
ragged_tensor.RaggedTensor.from_tensor(params, ragged_rank=i)
|
||||
for i in range(1, len(params_shape))
|
||||
]
|
||||
indices_tensors = [indices] + [
|
||||
ragged_tensor.RaggedTensor.from_tensor(indices, ragged_rank=i)
|
||||
for i in range(1, len(indices_shape))
|
||||
]
|
||||
|
||||
# For each combination of params & axis tensors, check that
|
||||
# ragged_gather_ops.gather matches array_ops.gather.
|
||||
for params_tensor in params_tensors:
|
||||
for indices_tensor in indices_tensors:
|
||||
actual = ragged_gather_ops.gather(
|
||||
params_tensor, indices_tensor, axis=axis, batch_dims=batch_dims)
|
||||
if isinstance(actual, ragged_tensor.RaggedTensor):
|
||||
actual = actual.to_tensor()
|
||||
self.assertAllEqual(
|
||||
expected, actual, 'params.ragged_rank=%s, indices.ragged_rank=%s' %
|
||||
(getattr(params_tensor, 'ragged_rank',
|
||||
0), getattr(indices_tensor, 'ragged_rank', 0)))
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
googletest.main()
|
||||
|
@ -33,22 +33,16 @@ from tensorflow.python.ops.ragged import ragged_tensor
|
||||
#===============================================================================
|
||||
# ragged_gather
|
||||
#===============================================================================
|
||||
# TODO(edloper): Add an `axis` argument
|
||||
def gather(params, indices, validate_indices=None, axis=0, batch_dims=0,
|
||||
def gather(params,
|
||||
indices,
|
||||
validate_indices=None,
|
||||
axis=None,
|
||||
batch_dims=0,
|
||||
name=None):
|
||||
"""Gathers ragged slices from `params` axis `0` according to `indices`.
|
||||
|
||||
Returns `RaggedTensor` output, such that:
|
||||
|
||||
```python
|
||||
output.shape = indices.shape + params.shape[1:]
|
||||
output.ragged_rank = indices.shape.ndims + params.ragged_rank
|
||||
output[i...j, d0...dn] = params[indices[i...j], d0...dn]
|
||||
```
|
||||
|
||||
`params` may be ragged. `indices` may be ragged.
|
||||
`indices` must have dtype `int32` or `int64`. If any index is out of bounds,
|
||||
then an error is returned.
|
||||
See `tf.gather` for full documentation. (This version has the same API
|
||||
as `tf.gather`, but supports ragged `params` and `indices`.)
|
||||
|
||||
Examples:
|
||||
|
||||
@ -73,8 +67,8 @@ def gather(params, indices, validate_indices=None, axis=0, batch_dims=0,
|
||||
Must have dtype `int32` or `int64`. Values must be in the range `[0,
|
||||
params.shape[0]]`.
|
||||
validate_indices: Ignored.
|
||||
axis: Must be zero.
|
||||
batch_dims: Must be zero.
|
||||
axis: The axis in `params` to gather `indices` from.
|
||||
batch_dims: The number of batch dimensions.
|
||||
name: A name for the operation (optional).
|
||||
|
||||
Returns:
|
||||
@ -86,10 +80,7 @@ def gather(params, indices, validate_indices=None, axis=0, batch_dims=0,
|
||||
ValueError: If indices.shape.ndims is not known statically.
|
||||
"""
|
||||
del validate_indices
|
||||
if not isinstance(axis, int) or axis != 0:
|
||||
raise ValueError('axis != 0 is not supported for ragged gather yet.')
|
||||
if not isinstance(batch_dims, int) or batch_dims != 0:
|
||||
raise ValueError('batch_dims != 0 is not supported for ragged gather yet.')
|
||||
|
||||
with ops.name_scope(name, 'RaggedGather', [params, indices]):
|
||||
params = ragged_tensor.convert_to_tensor_or_ragged_tensor(
|
||||
params, name='params')
|
||||
@ -97,27 +88,241 @@ def gather(params, indices, validate_indices=None, axis=0, batch_dims=0,
|
||||
indices, name='indices')
|
||||
params, indices = ragged_tensor.match_row_splits_dtypes(params, indices)
|
||||
|
||||
if ragged_tensor.is_ragged(indices):
|
||||
return indices.with_values(gather(params, indices.values))
|
||||
if batch_dims != indices.shape.rank:
|
||||
batch_dims = array_ops.get_positive_axis(
|
||||
batch_dims,
|
||||
indices.shape.rank,
|
||||
axis_name='batch_dims',
|
||||
ndims_name='rank(indices)')
|
||||
if params.shape.rank is not None and batch_dims >= params.shape.rank:
|
||||
raise ValueError('batch_dims must be less than rank(params)')
|
||||
if axis is None:
|
||||
axis = batch_dims
|
||||
axis = array_ops.get_positive_axis(
|
||||
axis, params.shape.rank, ndims_name='rank(params)')
|
||||
if axis < batch_dims:
|
||||
raise ValueError('axis must be greater than or equal to batch_dims')
|
||||
if indices.shape.rank is not None:
|
||||
if not 0 <= batch_dims <= indices.shape.rank:
|
||||
raise ValueError(
|
||||
'batch_dims=%s must be between 0 and rank(indices)=%s' %
|
||||
(batch_dims, indices.shape.rank))
|
||||
|
||||
if not ragged_tensor.is_ragged(params):
|
||||
return array_ops.gather(params, indices)
|
||||
return _gather(params, indices, axis, batch_dims)
|
||||
|
||||
|
||||
def _gather(params, indices, axis, batch_dims):
|
||||
"""Helper that implements the body for ragged gather().
|
||||
|
||||
Assumes that `params` and `indices` have been converted to tensors or
|
||||
ragged tensors, and that `axis` and `batch_dims` have been normalized to
|
||||
be positive. (So these conversions & normalizations can be skipped in
|
||||
recursive calls to _gather).
|
||||
|
||||
Args:
|
||||
params: The tensor from which to gather values.
|
||||
indices: The indices of values to gather.
|
||||
axis: The axis in `params` to gather `indices` from.
|
||||
batch_dims: The number of batch dimensions.
|
||||
|
||||
Returns:
|
||||
A potentially ragged tensor.
|
||||
"""
|
||||
params_is_ragged = ragged_tensor.is_ragged(params)
|
||||
indices_is_ragged = ragged_tensor.is_ragged(indices)
|
||||
|
||||
if not (params_is_ragged or indices_is_ragged):
|
||||
return array_ops.gather(params, indices, axis=axis, batch_dims=batch_dims)
|
||||
|
||||
if batch_dims > 0:
|
||||
return _batch_gather(params, indices, axis, batch_dims)
|
||||
|
||||
if axis > 0:
|
||||
return _axis_gather(params, indices, axis)
|
||||
|
||||
if indices_is_ragged:
|
||||
return indices.with_values(_gather(params, indices.values, 0, 0))
|
||||
|
||||
indices = ops.convert_to_tensor(indices)
|
||||
if indices.shape.ndims is None:
|
||||
raise ValueError('indices.shape.ndims must be known statically')
|
||||
raise ValueError('rank(indices) must be known statically')
|
||||
|
||||
out_ragged_rank = indices.shape.ndims + len(params.nested_row_splits) - 1
|
||||
result = gen_ragged_array_ops.ragged_gather(
|
||||
indices=indices,
|
||||
params_dense_values=params.flat_values,
|
||||
params_nested_splits=params.nested_row_splits,
|
||||
OUTPUT_RAGGED_RANK=indices.shape.ndims + len(params.nested_row_splits) -
|
||||
1)
|
||||
OUTPUT_RAGGED_RANK=out_ragged_rank)
|
||||
|
||||
# Compose the RaggedTensor from splits & values.
|
||||
return ragged_tensor.RaggedTensor.from_nested_row_splits(
|
||||
result = ragged_tensor.RaggedTensor.from_nested_row_splits(
|
||||
result.output_dense_values, result.output_nested_splits, validate=False)
|
||||
|
||||
# Inject uniform_row_lengths into the result RaggedTensors for dimensions
|
||||
# corresponding to dense outer dimensions of `indices`.
|
||||
# TODO(edloper): Change this to construct the result using RowPartition
|
||||
# objects instead, so we don't need to modify private variables.
|
||||
if indices.shape.ndims > 1:
|
||||
target = result
|
||||
indices_shape = array_ops.shape(indices, out_type=params.row_splits.dtype)
|
||||
shape_cumprod = math_ops.cumprod(indices_shape)
|
||||
for dim in range(indices.shape.ndims - 1):
|
||||
# pylint: disable=protected-access
|
||||
target._cached_nrows = shape_cumprod[dim]
|
||||
target._uniform_row_length = indices_shape[dim + 1]
|
||||
target = target.values
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def _batch_gather(params, indices, axis, batch_dims):
|
||||
"""Helper that implements the body for ragged gather() when batch_dims>0.
|
||||
|
||||
Args:
|
||||
params: The tensor from which to gather values.
|
||||
indices: The indices of values to gather.
|
||||
axis: The axis in `params` to gather `indices` from.
|
||||
batch_dims: The number of batch dimensions.
|
||||
|
||||
Returns:
|
||||
A potentially ragged tensor.
|
||||
"""
|
||||
# Perform static checks that `params` and `indices` have compatible batch
|
||||
# dimensions. Note: we do not perform *runtime* checks that `params` and
|
||||
# `indices` actually have the same row-splits (because we wish to avoid the
|
||||
# runtime cost of those checks). If `params` and `indices` are
|
||||
# incompatible, the resulting `RaggedTensor` may be nonsensical.
|
||||
if not params.shape[:batch_dims].is_compatible_with(
|
||||
indices.shape[:batch_dims]):
|
||||
raise ValueError('batch shape from indices %s does not match params '
|
||||
'shape %s' % (indices.shape[:batch_dims], params.shape))
|
||||
|
||||
if batch_dims > 1:
|
||||
# Convert params & indices to ragged tensors.
|
||||
if not isinstance(params, ragged_tensor.RaggedTensor):
|
||||
if indices.uniform_row_length is None:
|
||||
raise ValueError(
|
||||
'batch shape from indices does not match params shape: ragged '
|
||||
'indices dimension corresponds to uniform params dimension')
|
||||
params = ragged_tensor.RaggedTensor.from_tensor(
|
||||
params, ragged_rank=1, row_splits_dtype=indices.row_splits.dtype)
|
||||
if not isinstance(indices, ragged_tensor.RaggedTensor):
|
||||
if params.uniform_row_length is None:
|
||||
raise ValueError(
|
||||
'batch shape from indices does not match params shape: ragged '
|
||||
'params dimension corresponds to uniform indices dimension')
|
||||
indices = ragged_tensor.RaggedTensor.from_tensor(
|
||||
indices, ragged_rank=1, row_splits_dtype=params.row_splits.dtype)
|
||||
# Flatten the two outer batch dimensions into a single batch dimension,
|
||||
# and recurse.
|
||||
return params.with_values(
|
||||
_gather(params.values, indices.values, axis - 1, batch_dims - 1))
|
||||
|
||||
if axis > 1:
|
||||
# Convert an axis dimension into a batch dimension, by adding a dimension
|
||||
# to `indices`, and tiling it to match `params`. E.g., if `params`
|
||||
# had shape `[B, P1, P2]`, and `indices` had shape `[B, I1, I2]`, then we
|
||||
# tile `indices` to have shape `[B, P1, I1, I2]`. That way, we can treat
|
||||
# the `P1` dimension as a batch dimension.
|
||||
if not isinstance(indices, ragged_tensor.RaggedTensor):
|
||||
adjusted_indices = params.with_values(
|
||||
array_ops.repeat(indices, params.row_lengths(), 0))
|
||||
else:
|
||||
if not isinstance(params, ragged_tensor.RaggedTensor):
|
||||
params = ragged_tensor.RaggedTensor.from_tensor(
|
||||
params, ragged_rank=1, row_splits_dtype=indices.row_splits.dtype)
|
||||
adjusted_indices = _gather(
|
||||
indices,
|
||||
params.with_values(
|
||||
array_ops.repeat(
|
||||
math_ops.range(params.nrows()), params.row_lengths())), 0, 0)
|
||||
return _batch_gather(params, adjusted_indices, axis, batch_dims + 1)
|
||||
|
||||
if indices.shape.rank is None:
|
||||
raise ValueError('rank(indices) must be known statically')
|
||||
|
||||
assert batch_dims == 1
|
||||
# If params.shape=[B, P1...PN] and indices.shape=[B, I1...IM], then:
|
||||
#
|
||||
# output[b, i1...im, p2...pn] =
|
||||
# params[b, indices[b, i1...im], p2...pn]
|
||||
#
|
||||
# We construct `output` by flattening `params`, adjusting the `indices` to
|
||||
# point into that flattened list, and recursively calling `gather`.
|
||||
flat_params = _flatten_dims_0_and_1(params)
|
||||
adjustments = _row_starts(params, indices.dtype) # offset for each batch
|
||||
# increase adjustments's rank so it broadcasts w/ the outer dim of indices
|
||||
adjustments = _increase_rank_to(adjustments, indices.shape.ndims)
|
||||
adjusted_indices = indices + adjustments
|
||||
return _gather(flat_params, adjusted_indices, axis - 1, 0)
|
||||
|
||||
|
||||
def _axis_gather(params, indices, axis):
|
||||
"""Helper that implements ragged gather when axis>0 and batch_dims==0.
|
||||
|
||||
Args:
|
||||
params: The tensor from which to gather values.
|
||||
indices: The indices of values to gather.
|
||||
axis: The axis in `params` to gather `indices` from.
|
||||
|
||||
Returns:
|
||||
A potentially ragged tensor.
|
||||
"""
|
||||
if axis > 1:
|
||||
if not isinstance(params, ragged_tensor.RaggedTensor):
|
||||
params = ragged_tensor.RaggedTensor.from_tensor(
|
||||
params, ragged_rank=1, row_splits_dtype=indices.row_splits.dtype)
|
||||
# Recurse, using the flattened params (but do not flatten indices).
|
||||
return params.with_values(_gather(params.values, indices, axis - 1, 0))
|
||||
|
||||
if indices.shape.rank is None:
|
||||
raise ValueError('rank(indices) must be known statically')
|
||||
if (isinstance(params, ragged_tensor.RaggedTensor) and
|
||||
params.uniform_row_length is None):
|
||||
raise ValueError('axis may not be a ragged dimension')
|
||||
|
||||
assert axis == 1
|
||||
# If params.shape=[P1...PN] and indices.shape=[I1...IM], then:
|
||||
#
|
||||
# output[p1, i1...im, p3...pn] =
|
||||
# params[p1, indices[i1...im], p3...pn]
|
||||
#
|
||||
# We construct `output` by flattening `params`, adjusting the `indices` to
|
||||
# have one additional dimension, and to point into that flattened list, and
|
||||
# recursively calling `gather`.
|
||||
flat_params = _flatten_dims_0_and_1(params)
|
||||
adjustments = _row_starts(params, indices.dtype) # offset for each batch
|
||||
adjustments = _increase_rank_to(adjustments, indices.shape.ndims + 1)
|
||||
adjusted_indices = indices + adjustments
|
||||
return _gather(flat_params, adjusted_indices, axis - 1, 0)
|
||||
|
||||
|
||||
def _flatten_dims_0_and_1(t):
|
||||
"""Returns a copy of `t` with the outer two dimensions merged."""
|
||||
if isinstance(t, ragged_tensor.RaggedTensor):
|
||||
return t.values
|
||||
else:
|
||||
t_shape = array_ops.shape(t)
|
||||
return array_ops.reshape(t, array_ops.concat([[-1], t_shape[2:]], axis=0))
|
||||
|
||||
|
||||
def _row_starts(t, dtype):
|
||||
"""Returns the start indices for the rows in `t`."""
|
||||
if isinstance(t, ragged_tensor.RaggedTensor):
|
||||
return math_ops.cast(t.row_starts(), dtype)
|
||||
else:
|
||||
t_shape = array_ops.shape(t, out_type=dtype)
|
||||
return math_ops.range(t_shape[0]) * t_shape[1]
|
||||
|
||||
|
||||
def _increase_rank_to(t, rank):
|
||||
"""Adds *trailing* size-1 dimensions to `t` until it has the given rank."""
|
||||
if isinstance(t, ragged_tensor.RaggedTensor):
|
||||
return t.with_values(_increase_rank_to(t, rank - 1))
|
||||
else:
|
||||
old_dims = array_ops.shape(t)
|
||||
new_dims = array_ops.ones([rank - array_ops.rank(t)], old_dims.dtype)
|
||||
new_shape = array_ops.concat([old_dims, new_dims], axis=0)
|
||||
return array_ops.reshape(t, new_shape)
|
||||
|
||||
|
||||
#===============================================================================
|
||||
# ragged.gather_nd
|
||||
|
Loading…
Reference in New Issue
Block a user