Update RaggedTensor.__getitem__ to (1) allow indexing into all uniform dimensions, and (2) preserve uniform dimensions. In particular:

(1) When slicing a ragged dimension where uniform_row_length is defined, preserve uniform_row_length.
(2) Allow indexing into a ragged dimension where uniform_row_length is defined.

PiperOrigin-RevId: 295789259
Change-Id: I4bfacf02b8941aa9e96ca944bcc997b7669810c6
This commit is contained in:
Edward Loper 2020-02-18 11:58:36 -08:00 committed by TensorFlower Gardener
parent f079f59af2
commit 6a202bc94b
2 changed files with 163 additions and 21 deletions

View File

@ -19,9 +19,12 @@ from __future__ import division
from __future__ import print_function
from tensorflow.python.eager import context
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import check_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops.ragged import ragged_gather_ops
@ -41,9 +44,6 @@ def ragged_tensor_getitem(self, key):
principles of Python ("In the face of ambiguity, refuse the temptation to
guess"), we simply disallow this operation.
Any dimensions added by `array_ops.newaxis` will be ragged if the following
dimension is ragged.
Args:
self: The RaggedTensor to slice.
key: Indicates which piece of the RaggedTensor to return, using standard
@ -134,15 +134,26 @@ def _ragged_getitem(rt_input, key_list):
# that puts all values in a single row.
if row_key is array_ops.newaxis:
inner_rt = _ragged_getitem(rt_input, inner_keys)
nsplits = array_ops.shape(inner_rt.row_splits,
out_type=inner_rt.row_splits.dtype)[0]
return ragged_tensor.RaggedTensor.from_row_splits(
inner_rt, array_ops.stack([0, nsplits - 1]), validate=False)
nsplits = tensor_shape.dimension_at_index(inner_rt.row_splits.shape, 0)
if nsplits.value is not None:
nsplits = nsplits.value
else:
nsplits = array_ops.shape(inner_rt.row_splits,
out_type=inner_rt.row_splits.dtype)[0]
return ragged_tensor.RaggedTensor.from_uniform_row_length(
inner_rt, nsplits - 1, nrows=1, validate=False)
# Slicing a range of rows: first slice the outer dimension, and then
# call `_ragged_getitem_inner_dimensions` to handle the inner keys.
if isinstance(row_key, slice):
sliced_rt_input = _slice_ragged_row_dimension(rt_input, row_key)
if rt_input.uniform_row_length is not None:
# If the inner dimension has uniform_row_length, then preserve it (by
# re-wrapping the values in a new RaggedTensor). Note that the row
# length won't have changed, since we're slicing a range of rows (and not
# slicing the rows themselves).
sliced_rt_input = ragged_tensor.RaggedTensor.from_uniform_row_length(
sliced_rt_input.values, rt_input.uniform_row_length)
return _ragged_getitem_inner_dimensions(sliced_rt_input, inner_keys)
# Indexing a single row: slice values to get the indicated row, and then
@ -245,11 +256,14 @@ def _ragged_getitem_inner_dimensions(rt_input, key_list):
# RaggedTensor that puts each value in its own row.
if column_key is array_ops.newaxis:
inner_rt = _ragged_getitem_inner_dimensions(rt_input, key_list[1:])
nsplits = array_ops.shape(inner_rt.row_splits,
out_type=inner_rt.row_splits.dtype)[0]
return ragged_tensor.RaggedTensor.from_row_splits(inner_rt,
math_ops.range(nsplits),
validate=False)
nsplits = tensor_shape.dimension_at_index(inner_rt.row_splits.shape, 0)
if nsplits.value is not None:
nsplits = nsplits.value
else:
nsplits = array_ops.shape(inner_rt.row_splits,
out_type=inner_rt.row_splits.dtype)[0]
return ragged_tensor.RaggedTensor.from_uniform_row_length(
inner_rt, 1, nrows=nsplits - 1, validate=False)
# Slicing a range of columns in a ragged inner dimension. We use a
# recursive call to process the values, and then assemble a RaggedTensor
@ -292,15 +306,59 @@ def _ragged_getitem_inner_dimensions(rt_input, key_list):
lambda: math_ops.maximum(limits + stop_offset, lower_bound))
inner_rt = _build_ragged_tensor_from_value_ranges(
inner_rt_starts, inner_rt_limits, column_key.step, rt_input.values)
# If the row dimension is uniform, then calculate the new
# uniform_row_length, and rebuild inner_rt using that uniform_row_length.
if rt_input.uniform_row_length is not None:
new_row_length = _slice_length(rt_input.uniform_row_length, column_key)
inner_rt = ragged_tensor.RaggedTensor.from_uniform_row_length(
inner_rt.values, new_row_length, rt_input.nrows())
return inner_rt.with_values(
_ragged_getitem_inner_dimensions(inner_rt.values, key_list[1:]))
# Indexing a single column in a ragged inner dimension: raise an Exception.
# See RaggedTensor.__getitem__.__doc__ for an explanation of why indexing
# into a ragged inner dimension is problematic.
else:
if rt_input.uniform_row_length is None:
raise ValueError("Cannot index into an inner ragged dimension.")
# Indexing a single column in a uniform inner dimension: check that the
# given index is in-bounds, and then use a strided slice over rt_input.values
# to take the indicated element from each row.
row_length = rt_input.uniform_row_length
column_key = math_ops.cast(column_key, row_length.dtype)
oob_err_msg = "Index out of bounds when indexing into a ragged tensor"
oob_checks = [
check_ops.assert_greater_equal(
column_key, -row_length, message=oob_err_msg),
check_ops.assert_less(column_key, row_length, message=oob_err_msg),
]
with ops.control_dependencies(oob_checks):
offset = _if_ge_zero(column_key, lambda: column_key,
lambda: row_length + column_key)
sliced_rt = rt_input.values[offset::row_length]
return _ragged_getitem_inner_dimensions(sliced_rt, key_list[1:])
def _slice_length(value_length, slice_key):
  """Returns how many elements `slice_key` selects from `value_length` items.

  The result is the equivalent of: `len(range(value_length)[slice_key])`

  Args:
    value_length: Scalar int `Tensor`: the length of the value being sliced.
    slice_key: A `slice` object used to slice elements from the value.

  Returns:
    The number of elements in the sliced value, using the same dtype as
    `value_length`.
  """
  # The count could be derived arithmetically with some variant of
  # (stop-start)//step, but that needs extra ops for bounds checks, negative
  # indices, negative step sizes, etc.  This is expected to be an uncommon
  # operation, so instead we slice a throwaway boolean vector of the right
  # length and measure what is left.
  dummy = array_ops.zeros(value_length, dtype=dtypes.bool)
  selected = dummy[slice_key]
  return array_ops.size(selected, out_type=value_length.dtype)
def _expand_ellipsis(key_list, num_remaining_dims):
"""Expands the ellipsis at the start of `key_list`.

View File

@ -116,6 +116,12 @@ EXAMPLE_RAGGED_TENSOR_4D_VALUES = [[1, 2], [3, 4], [5, 6], [7, 8], [9, 10],
[11, 12], [13, 14], [15, 16], [17, 18],
[19, 20]]
# Example 3D ragged tensor with a uniform_row_length.
EXAMPLE_RAGGED_TENSOR_3D = [[[1, 2, 3], [4], [5, 6]], [[], [7, 8, 9], []]]
EXAMPLE_RAGGED_TENSOR_3D_ROWLEN = 3
EXAMPLE_RAGGED_TENSOR_3D_SPLITS = [0, 3, 4, 6, 6, 9, 9]
EXAMPLE_RAGGED_TENSOR_3D_VALUES = [1, 2, 3, 4, 5, 6, 7, 8, 9]
def int32array(values):
  """Returns `values` as a NumPy array with dtype `int32`."""
  as_int32 = np.array(values, dtype=np.int32)
  return as_int32
@ -837,7 +843,7 @@ class RaggedTensorTest(test_util.TensorFlowTestCase,
# RaggedTensor.__getitem__
#=============================================================================
def _TestGetItem(self, rt, slice_spec, expected):
def _TestGetItem(self, rt, slice_spec, expected, expected_shape=None):
"""Helper function for testing RaggedTensor.__getitem__.
Checks that calling `rt.__getitem__(slice_spec) returns the expected value.
@ -855,6 +861,7 @@ class RaggedTensorTest(test_util.TensorFlowTestCase,
slice_spec: The slice spec.
expected: The expected value of rt.__getitem__(slice_spec), as a python
list; or an exception class.
expected_shape: The expected shape for `rt.__getitem__(slice_spec)`.
"""
tensor_slice_spec1 = _make_tensor_slice_spec(slice_spec, True)
tensor_slice_spec2 = _make_tensor_slice_spec(slice_spec, False)
@ -864,13 +871,18 @@ class RaggedTensorTest(test_util.TensorFlowTestCase,
self.assertAllEqual(value1, expected, 'slice_spec=%s' % (slice_spec,))
self.assertAllEqual(value2, expected, 'slice_spec=%s' % (slice_spec,))
self.assertAllEqual(value3, expected, 'slice_spec=%s' % (slice_spec,))
if expected_shape is not None:
value1.shape.assert_is_compatible_with(expected_shape)
value2.shape.assert_is_compatible_with(expected_shape)
value3.shape.assert_is_compatible_with(expected_shape)
def _TestGetItemException(self, rt, slice_spec, expected, message):
  """Helper function for testing RaggedTensor.__getitem__ exceptions.

  Checks that evaluating `rt.__getitem__(...)` raises `expected` with a
  message matching `message`, once for `slice_spec` as given and once for the
  spec returned by `_make_tensor_slice_spec(slice_spec, True)`.

  Args:
    rt: The RaggedTensor to index.
    slice_spec: The slice spec.
    expected: The exception class that is expected to be raised.
    message: Regular expression that the exception message must match.
  """
  # NOTE(review): presumably this converts the values in the spec to tensors;
  # confirm against `_make_tensor_slice_spec`.
  tensor_slice_spec = _make_tensor_slice_spec(slice_spec, True)
  # The result is evaluated inside the assertion context because errors that
  # come from runtime checks may only surface when the value is evaluated,
  # not when `__getitem__` itself is called.
  with self.assertRaisesRegexp(expected, message):
    self.evaluate(rt.__getitem__(slice_spec))
  with self.assertRaisesRegexp(expected, message):
    self.evaluate(rt.__getitem__(tensor_slice_spec))
@parameterized.parameters(
# Tests for rt[i]
@ -1225,12 +1237,84 @@ class RaggedTensorTest(test_util.TensorFlowTestCase,
self.assertEqual(rt_newaxis3.ragged_rank, 2)
self.assertEqual(rt_newaxis4.ragged_rank, 2)
self.assertEqual(rt_newaxis0.shape.as_list(), [1, None, None, None, 2])
self.assertEqual(rt_newaxis1.shape.as_list(), [2, None, None, None, 2])
self.assertEqual(rt_newaxis2.shape.as_list(), [2, None, None, None, 2])
self.assertEqual(rt_newaxis0.shape.as_list(), [1, 2, None, None, 2])
self.assertEqual(rt_newaxis1.shape.as_list(), [2, 1, None, None, 2])
self.assertEqual(rt_newaxis2.shape.as_list(), [2, None, 1, None, 2])
self.assertEqual(rt_newaxis3.shape.as_list(), [2, None, None, 1, 2])
self.assertEqual(rt_newaxis4.shape.as_list(), [2, None, None, 2, 1])
@parameterized.parameters(
# EXAMPLE_RAGGED_TENSOR_3D.shape = [2, 3, None]
# Indexing into the uniform_row_length dimension:
(SLICE_BUILDER[:, 1], [r[1] for r in EXAMPLE_RAGGED_TENSOR_3D],
[2, None]),
(SLICE_BUILDER[:, 2], [r[2] for r in EXAMPLE_RAGGED_TENSOR_3D],
[2, None]),
(SLICE_BUILDER[:, -2], [r[-2] for r in EXAMPLE_RAGGED_TENSOR_3D],
[2, None]),
(SLICE_BUILDER[:, -3], [r[-3] for r in EXAMPLE_RAGGED_TENSOR_3D],
[2, None]),
(SLICE_BUILDER[1:, 2], [r[2] for r in EXAMPLE_RAGGED_TENSOR_3D[1:]],
[1, None]),
(SLICE_BUILDER[:, 1, 1:], [r[1][1:] for r in EXAMPLE_RAGGED_TENSOR_3D],
[2, None]),
(SLICE_BUILDER[1:, 1, 1:],
[r[1][1:] for r in EXAMPLE_RAGGED_TENSOR_3D[1:]],
[1, None]),
# Slicing the uniform_row_length dimension:
(SLICE_BUILDER[:, 2:], [r[2:] for r in EXAMPLE_RAGGED_TENSOR_3D],
[2, 1, None]),
(SLICE_BUILDER[:, -2:], [r[-2:] for r in EXAMPLE_RAGGED_TENSOR_3D],
[2, 2, None]),
(SLICE_BUILDER[:, :, 1:],
[[c[1:] for c in r] for r in EXAMPLE_RAGGED_TENSOR_3D],
[2, 3, None]),
(SLICE_BUILDER[:, 5:], [r[5:] for r in EXAMPLE_RAGGED_TENSOR_3D],
[2, 0, None]),
# Slicing the uniform_row_length dimension with a non-default step size:
(SLICE_BUILDER[:, ::2], [r[::2] for r in EXAMPLE_RAGGED_TENSOR_3D],
[2, 2, None]),
(SLICE_BUILDER[:, ::-1], [r[::-1] for r in EXAMPLE_RAGGED_TENSOR_3D],
[2, 3, None]),
)
def testRaggedTensorGetItemWithUniformRowLength(self, slice_spec, expected,
                                                expected_shape):
  """Checks __getitem__ on a RaggedTensor that has a uniform_row_length.

  Verifies that `rt.__getitem__(slice_spec)` equals `expected` and has a shape
  compatible with `expected_shape`; when the result is 3D, also verifies that
  its uniform_row_length is set and equals `expected_shape[1]`.
  """
  ragged_rows = RaggedTensor.from_row_splits(
      EXAMPLE_RAGGED_TENSOR_3D_VALUES, EXAMPLE_RAGGED_TENSOR_3D_SPLITS)
  rt = RaggedTensor.from_uniform_row_length(
      ragged_rows, EXAMPLE_RAGGED_TENSOR_3D_ROWLEN)

  # Sanity-check the fixture before slicing it.
  self.assertAllEqual(rt, EXAMPLE_RAGGED_TENSOR_3D)
  self.assertIsNot(rt.uniform_row_length, None)

  self._TestGetItem(rt, slice_spec, expected, expected_shape)

  # Only a 3D result still has an outer uniform dimension to check.
  sliced = rt.__getitem__(slice_spec)
  if sliced.shape.rank != 3:
    return
  self.assertIsNot(sliced.uniform_row_length, None)
  self.assertAllEqual(sliced.uniform_row_length, expected_shape[1])
@parameterized.parameters(
(SLICE_BUILDER[:, 3], errors.InvalidArgumentError, 'out of bounds'),
(SLICE_BUILDER[:, -4], errors.InvalidArgumentError, 'out of bounds'),
(SLICE_BUILDER[:, 10], errors.InvalidArgumentError, 'out of bounds'),
(SLICE_BUILDER[:, -10], errors.InvalidArgumentError, 'out of bounds'),
)
def testRaggedTensorGetItemErrorsWithUniformRowLength(self, slice_spec,
                                                      expected, message):
  """Test that rt.__getitem__(slice_spec) raises `expected` with `message`."""
  ragged_rows = RaggedTensor.from_row_splits(
      EXAMPLE_RAGGED_TENSOR_3D_VALUES, EXAMPLE_RAGGED_TENSOR_3D_SPLITS)
  rt = RaggedTensor.from_uniform_row_length(
      ragged_rows, EXAMPLE_RAGGED_TENSOR_3D_ROWLEN)
  # Sanity-check the fixture before exercising the failing index.
  self.assertAllEqual(rt, EXAMPLE_RAGGED_TENSOR_3D)
  self._TestGetItemException(rt, slice_spec, expected, message)
#=============================================================================
# RaggedTensor.__str__
#=============================================================================