Update RaggedTensor.__getitem__ to (1) allow indexing into all uniform dimensions, and (2) preserve uniform dimensions. In particular:

(1) When slicing a ragged dimension where uniform_row_length is defined, preserve uniform_row_length.
(2) Allow indexing into a ragged dimension where uniform_row_length is defined.

PiperOrigin-RevId: 295789259
Change-Id: I4bfacf02b8941aa9e96ca944bcc997b7669810c6
This commit is contained in:
Edward Loper 2020-02-18 11:58:36 -08:00 committed by TensorFlower Gardener
parent f079f59af2
commit 6a202bc94b
2 changed files with 163 additions and 21 deletions

View File

@ -19,9 +19,12 @@ from __future__ import division
from __future__ import print_function from __future__ import print_function
from tensorflow.python.eager import context from tensorflow.python.eager import context
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_util from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import array_ops from tensorflow.python.ops import array_ops
from tensorflow.python.ops import check_ops
from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops from tensorflow.python.ops import math_ops
from tensorflow.python.ops.ragged import ragged_gather_ops from tensorflow.python.ops.ragged import ragged_gather_ops
@ -41,9 +44,6 @@ def ragged_tensor_getitem(self, key):
principles of Python ("In the face of ambiguity, refuse the temptation to principles of Python ("In the face of ambiguity, refuse the temptation to
guess"), we simply disallow this operation. guess"), we simply disallow this operation.
Any dimensions added by `array_ops.newaxis` will be ragged if the following
dimension is ragged.
Args: Args:
self: The RaggedTensor to slice. self: The RaggedTensor to slice.
key: Indicates which piece of the RaggedTensor to return, using standard key: Indicates which piece of the RaggedTensor to return, using standard
@ -134,15 +134,26 @@ def _ragged_getitem(rt_input, key_list):
# that puts all values in a single row. # that puts all values in a single row.
if row_key is array_ops.newaxis: if row_key is array_ops.newaxis:
inner_rt = _ragged_getitem(rt_input, inner_keys) inner_rt = _ragged_getitem(rt_input, inner_keys)
nsplits = array_ops.shape(inner_rt.row_splits, nsplits = tensor_shape.dimension_at_index(inner_rt.row_splits.shape, 0)
out_type=inner_rt.row_splits.dtype)[0] if nsplits.value is not None:
return ragged_tensor.RaggedTensor.from_row_splits( nsplits = nsplits.value
inner_rt, array_ops.stack([0, nsplits - 1]), validate=False) else:
nsplits = array_ops.shape(inner_rt.row_splits,
out_type=inner_rt.row_splits.dtype)[0]
return ragged_tensor.RaggedTensor.from_uniform_row_length(
inner_rt, nsplits - 1, nrows=1, validate=False)
# Slicing a range of rows: first slice the outer dimension, and then # Slicing a range of rows: first slice the outer dimension, and then
# call `_ragged_getitem_inner_dimensions` to handle the inner keys. # call `_ragged_getitem_inner_dimensions` to handle the inner keys.
if isinstance(row_key, slice): if isinstance(row_key, slice):
sliced_rt_input = _slice_ragged_row_dimension(rt_input, row_key) sliced_rt_input = _slice_ragged_row_dimension(rt_input, row_key)
if rt_input.uniform_row_length is not None:
# If the inner dimension has uniform_row_length, then preserve it (by
# re-wrapping the values in a new RaggedTensor). Note that the row
# length won't have changed, since we're slicing a range of rows (and not
# slicing the rows themselves).
sliced_rt_input = ragged_tensor.RaggedTensor.from_uniform_row_length(
sliced_rt_input.values, rt_input.uniform_row_length)
return _ragged_getitem_inner_dimensions(sliced_rt_input, inner_keys) return _ragged_getitem_inner_dimensions(sliced_rt_input, inner_keys)
# Indexing a single row: slice values to get the indicated row, and then # Indexing a single row: slice values to get the indicated row, and then
@ -245,11 +256,14 @@ def _ragged_getitem_inner_dimensions(rt_input, key_list):
# RaggedTensor that puts each value in its own row. # RaggedTensor that puts each value in its own row.
if column_key is array_ops.newaxis: if column_key is array_ops.newaxis:
inner_rt = _ragged_getitem_inner_dimensions(rt_input, key_list[1:]) inner_rt = _ragged_getitem_inner_dimensions(rt_input, key_list[1:])
nsplits = array_ops.shape(inner_rt.row_splits, nsplits = tensor_shape.dimension_at_index(inner_rt.row_splits.shape, 0)
out_type=inner_rt.row_splits.dtype)[0] if nsplits.value is not None:
return ragged_tensor.RaggedTensor.from_row_splits(inner_rt, nsplits = nsplits.value
math_ops.range(nsplits), else:
validate=False) nsplits = array_ops.shape(inner_rt.row_splits,
out_type=inner_rt.row_splits.dtype)[0]
return ragged_tensor.RaggedTensor.from_uniform_row_length(
inner_rt, 1, nrows=nsplits - 1, validate=False)
# Slicing a range of columns in a ragged inner dimension. We use a # Slicing a range of columns in a ragged inner dimension. We use a
# recursive call to process the values, and then assemble a RaggedTensor # recursive call to process the values, and then assemble a RaggedTensor
@ -292,15 +306,59 @@ def _ragged_getitem_inner_dimensions(rt_input, key_list):
lambda: math_ops.maximum(limits + stop_offset, lower_bound)) lambda: math_ops.maximum(limits + stop_offset, lower_bound))
inner_rt = _build_ragged_tensor_from_value_ranges( inner_rt = _build_ragged_tensor_from_value_ranges(
inner_rt_starts, inner_rt_limits, column_key.step, rt_input.values) inner_rt_starts, inner_rt_limits, column_key.step, rt_input.values)
# If the row dimension is uniform, then calculate the new
# uniform_row_length, and rebuild inner_rt using that uniform_row_lengths.
if rt_input.uniform_row_length is not None:
new_row_length = _slice_length(rt_input.uniform_row_length, column_key)
inner_rt = ragged_tensor.RaggedTensor.from_uniform_row_length(
inner_rt.values, new_row_length, rt_input.nrows())
return inner_rt.with_values( return inner_rt.with_values(
_ragged_getitem_inner_dimensions(inner_rt.values, key_list[1:])) _ragged_getitem_inner_dimensions(inner_rt.values, key_list[1:]))
# Indexing a single column in a ragged inner dimension: raise an Exception. # Indexing a single column in a ragged inner dimension: raise an Exception.
# See RaggedTensor.__getitem__.__doc__ for an explanation of why indexing # See RaggedTensor.__getitem__.__doc__ for an explanation of why indexing
# into a ragged inner dimension is problematic. # into a ragged inner dimension is problematic.
else: if rt_input.uniform_row_length is None:
raise ValueError("Cannot index into an inner ragged dimension.") raise ValueError("Cannot index into an inner ragged dimension.")
# Indexing a single column in a uniform inner dimension: check that the
# given index is in-bounds, and then use a strided slice over rt_input.values
# to take the indicated element from each row.
row_length = rt_input.uniform_row_length
column_key = math_ops.cast(column_key, row_length.dtype)
oob_err_msg = "Index out of bounds when indexing into a ragged tensor"
oob_checks = [
check_ops.assert_greater_equal(
column_key, -row_length, message=oob_err_msg),
check_ops.assert_less(column_key, row_length, message=oob_err_msg),
]
with ops.control_dependencies(oob_checks):
offset = _if_ge_zero(column_key, lambda: column_key,
lambda: row_length + column_key)
sliced_rt = rt_input.values[offset::row_length]
return _ragged_getitem_inner_dimensions(sliced_rt, key_list[1:])
def _slice_length(value_length, slice_key):
"""Computes the number of elements in a slice of a value with a given length.
Returns the equivalent of: `len(range(value_length)[slice_key])`
Args:
value_length: Scalar int `Tensor`: the length of the value being sliced.
slice_key: A `slice` object used to slice elements from the the value.
Returns:
The number of elements in the sliced value.
"""
# Note: we could compute the slice length without creating a zeros tensor
# with some variant of (stop-start)//step, but doing so would require more
# ops (for checking bounds, handling negative indices, negative step sizes,
# etc); and we expect this to be an uncommon operation, so we use this
# simpler implementation.
zeros = array_ops.zeros(value_length, dtype=dtypes.bool)
return array_ops.size(zeros[slice_key], out_type=value_length.dtype)
def _expand_ellipsis(key_list, num_remaining_dims): def _expand_ellipsis(key_list, num_remaining_dims):
"""Expands the ellipsis at the start of `key_list`. """Expands the ellipsis at the start of `key_list`.

View File

@ -116,6 +116,12 @@ EXAMPLE_RAGGED_TENSOR_4D_VALUES = [[1, 2], [3, 4], [5, 6], [7, 8], [9, 10],
[11, 12], [13, 14], [15, 16], [17, 18], [11, 12], [13, 14], [15, 16], [17, 18],
[19, 20]] [19, 20]]
# Example 3D ragged tensor with uniform_row_lengths.
EXAMPLE_RAGGED_TENSOR_3D = [[[1, 2, 3], [4], [5, 6]], [[], [7, 8, 9], []]]
EXAMPLE_RAGGED_TENSOR_3D_ROWLEN = 3
EXAMPLE_RAGGED_TENSOR_3D_SPLITS = [0, 3, 4, 6, 6, 9, 9]
EXAMPLE_RAGGED_TENSOR_3D_VALUES = [1, 2, 3, 4, 5, 6, 7, 8, 9]
def int32array(values): def int32array(values):
return np.array(values, dtype=np.int32) return np.array(values, dtype=np.int32)
@ -837,7 +843,7 @@ class RaggedTensorTest(test_util.TensorFlowTestCase,
# RaggedTensor.__getitem__ # RaggedTensor.__getitem__
#============================================================================= #=============================================================================
def _TestGetItem(self, rt, slice_spec, expected): def _TestGetItem(self, rt, slice_spec, expected, expected_shape=None):
"""Helper function for testing RaggedTensor.__getitem__. """Helper function for testing RaggedTensor.__getitem__.
Checks that calling `rt.__getitem__(slice_spec) returns the expected value. Checks that calling `rt.__getitem__(slice_spec) returns the expected value.
@ -855,6 +861,7 @@ class RaggedTensorTest(test_util.TensorFlowTestCase,
slice_spec: The slice spec. slice_spec: The slice spec.
expected: The expected value of rt.__getitem__(slice_spec), as a python expected: The expected value of rt.__getitem__(slice_spec), as a python
list; or an exception class. list; or an exception class.
expected_shape: The expected shape for `rt.__getitem__(slice_spec)`.
""" """
tensor_slice_spec1 = _make_tensor_slice_spec(slice_spec, True) tensor_slice_spec1 = _make_tensor_slice_spec(slice_spec, True)
tensor_slice_spec2 = _make_tensor_slice_spec(slice_spec, False) tensor_slice_spec2 = _make_tensor_slice_spec(slice_spec, False)
@ -864,13 +871,18 @@ class RaggedTensorTest(test_util.TensorFlowTestCase,
self.assertAllEqual(value1, expected, 'slice_spec=%s' % (slice_spec,)) self.assertAllEqual(value1, expected, 'slice_spec=%s' % (slice_spec,))
self.assertAllEqual(value2, expected, 'slice_spec=%s' % (slice_spec,)) self.assertAllEqual(value2, expected, 'slice_spec=%s' % (slice_spec,))
self.assertAllEqual(value3, expected, 'slice_spec=%s' % (slice_spec,)) self.assertAllEqual(value3, expected, 'slice_spec=%s' % (slice_spec,))
if expected_shape is not None:
value1.shape.assert_is_compatible_with(expected_shape)
value2.shape.assert_is_compatible_with(expected_shape)
value3.shape.assert_is_compatible_with(expected_shape)
def _TestGetItemException(self, rt, slice_spec, expected, message): def _TestGetItemException(self, rt, slice_spec, expected, message):
"""Helper function for testing RaggedTensor.__getitem__ exceptions.""" """Helper function for testing RaggedTensor.__getitem__ exceptions."""
tensor_slice_spec1 = _make_tensor_slice_spec(slice_spec, True) tensor_slice_spec = _make_tensor_slice_spec(slice_spec, True)
self.assertRaisesRegexp(expected, message, rt.__getitem__, slice_spec) with self.assertRaisesRegexp(expected, message):
self.assertRaisesRegexp(expected, message, rt.__getitem__, self.evaluate(rt.__getitem__(slice_spec))
tensor_slice_spec1) with self.assertRaisesRegexp(expected, message):
self.evaluate(rt.__getitem__(tensor_slice_spec))
@parameterized.parameters( @parameterized.parameters(
# Tests for rt[i] # Tests for rt[i]
@ -1225,12 +1237,84 @@ class RaggedTensorTest(test_util.TensorFlowTestCase,
self.assertEqual(rt_newaxis3.ragged_rank, 2) self.assertEqual(rt_newaxis3.ragged_rank, 2)
self.assertEqual(rt_newaxis4.ragged_rank, 2) self.assertEqual(rt_newaxis4.ragged_rank, 2)
self.assertEqual(rt_newaxis0.shape.as_list(), [1, None, None, None, 2]) self.assertEqual(rt_newaxis0.shape.as_list(), [1, 2, None, None, 2])
self.assertEqual(rt_newaxis1.shape.as_list(), [2, None, None, None, 2]) self.assertEqual(rt_newaxis1.shape.as_list(), [2, 1, None, None, 2])
self.assertEqual(rt_newaxis2.shape.as_list(), [2, None, None, None, 2]) self.assertEqual(rt_newaxis2.shape.as_list(), [2, None, 1, None, 2])
self.assertEqual(rt_newaxis3.shape.as_list(), [2, None, None, 1, 2]) self.assertEqual(rt_newaxis3.shape.as_list(), [2, None, None, 1, 2])
self.assertEqual(rt_newaxis4.shape.as_list(), [2, None, None, 2, 1]) self.assertEqual(rt_newaxis4.shape.as_list(), [2, None, None, 2, 1])
@parameterized.parameters(
# EXAMPLE_RAGGED_TENSOR_3D.shape = [2, 3, None]
# Indexing into uniform_row_splits dimension:
(SLICE_BUILDER[:, 1], [r[1] for r in EXAMPLE_RAGGED_TENSOR_3D],
[2, None]),
(SLICE_BUILDER[:, 2], [r[2] for r in EXAMPLE_RAGGED_TENSOR_3D],
[2, None]),
(SLICE_BUILDER[:, -2], [r[-2] for r in EXAMPLE_RAGGED_TENSOR_3D],
[2, None]),
(SLICE_BUILDER[:, -3], [r[-3] for r in EXAMPLE_RAGGED_TENSOR_3D],
[2, None]),
(SLICE_BUILDER[1:, 2], [r[2] for r in EXAMPLE_RAGGED_TENSOR_3D[1:]],
[1, None]),
(SLICE_BUILDER[:, 1, 1:], [r[1][1:] for r in EXAMPLE_RAGGED_TENSOR_3D],
[2, None]),
(SLICE_BUILDER[1:, 1, 1:],
[r[1][1:] for r in EXAMPLE_RAGGED_TENSOR_3D[1:]],
[1, None]),
# Slicing uniform_row_splits dimension:
(SLICE_BUILDER[:, 2:], [r[2:] for r in EXAMPLE_RAGGED_TENSOR_3D],
[2, 1, None]),
(SLICE_BUILDER[:, -2:], [r[-2:] for r in EXAMPLE_RAGGED_TENSOR_3D],
[2, 2, None]),
(SLICE_BUILDER[:, :, 1:],
[[c[1:] for c in r] for r in EXAMPLE_RAGGED_TENSOR_3D],
[2, 3, None]),
(SLICE_BUILDER[:, 5:], [r[5:] for r in EXAMPLE_RAGGED_TENSOR_3D],
[2, 0, None]),
# Slicing uniform_row_splits dimension with a non-default step size:
(SLICE_BUILDER[:, ::2], [r[::2] for r in EXAMPLE_RAGGED_TENSOR_3D],
[2, 2, None]),
(SLICE_BUILDER[:, ::-1], [r[::-1] for r in EXAMPLE_RAGGED_TENSOR_3D],
[2, 3, None]),
)
def testRaggedTensorGetItemWithUniformRowLength(self, slice_spec, expected,
expected_shape):
"""Test that rt.__getitem__(slice_spec) == expected."""
rt = RaggedTensor.from_uniform_row_length(
RaggedTensor.from_row_splits(
EXAMPLE_RAGGED_TENSOR_3D_VALUES,
EXAMPLE_RAGGED_TENSOR_3D_SPLITS),
EXAMPLE_RAGGED_TENSOR_3D_ROWLEN)
self.assertAllEqual(rt, EXAMPLE_RAGGED_TENSOR_3D)
self.assertIsNot(rt.uniform_row_length, None)
self._TestGetItem(rt, slice_spec, expected, expected_shape)
# If the result is 3D, then check that it still has a uniform row length:
actual = rt.__getitem__(slice_spec)
if actual.shape.rank == 3:
self.assertIsNot(actual.uniform_row_length, None)
self.assertAllEqual(actual.uniform_row_length, expected_shape[1])
@parameterized.parameters(
(SLICE_BUILDER[:, 3], errors.InvalidArgumentError, 'out of bounds'),
(SLICE_BUILDER[:, -4], errors.InvalidArgumentError, 'out of bounds'),
(SLICE_BUILDER[:, 10], errors.InvalidArgumentError, 'out of bounds'),
(SLICE_BUILDER[:, -10], errors.InvalidArgumentError, 'out of bounds'),
)
def testRaggedTensorGetItemErrorsWithUniformRowLength(self, slice_spec,
expected, message):
"""Test that rt.__getitem__(slice_spec) == expected."""
rt = RaggedTensor.from_uniform_row_length(
RaggedTensor.from_row_splits(
EXAMPLE_RAGGED_TENSOR_3D_VALUES,
EXAMPLE_RAGGED_TENSOR_3D_SPLITS),
EXAMPLE_RAGGED_TENSOR_3D_ROWLEN)
self.assertAllEqual(rt, EXAMPLE_RAGGED_TENSOR_3D)
self._TestGetItemException(rt, slice_spec, expected, message)
#============================================================================= #=============================================================================
# RaggedTensor.__str__ # RaggedTensor.__str__
#============================================================================= #=============================================================================