Update RaggedTensor.__getitem__ to (1) allow indexing into all uniform dimensions, and (2) preserve uniform dimensions. In particular:
(1) When slicing a ragged dimension where uniform_row_length is defined, preserve uniform_row_length. (2) Allow indexing into a ragged dimension where uniform_row_length is defined. PiperOrigin-RevId: 295789259 Change-Id: I4bfacf02b8941aa9e96ca944bcc997b7669810c6
This commit is contained in:
parent
f079f59af2
commit
6a202bc94b
@ -19,9 +19,12 @@ from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from tensorflow.python.eager import context
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.framework import tensor_shape
|
||||
from tensorflow.python.framework import tensor_util
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import check_ops
|
||||
from tensorflow.python.ops import control_flow_ops
|
||||
from tensorflow.python.ops import math_ops
|
||||
from tensorflow.python.ops.ragged import ragged_gather_ops
|
||||
@ -41,9 +44,6 @@ def ragged_tensor_getitem(self, key):
|
||||
principles of Python ("In the face of ambiguity, refuse the temptation to
|
||||
guess"), we simply disallow this operation.
|
||||
|
||||
Any dimensions added by `array_ops.newaxis` will be ragged if the following
|
||||
dimension is ragged.
|
||||
|
||||
Args:
|
||||
self: The RaggedTensor to slice.
|
||||
key: Indicates which piece of the RaggedTensor to return, using standard
|
||||
@ -134,15 +134,26 @@ def _ragged_getitem(rt_input, key_list):
|
||||
# that puts all values in a single row.
|
||||
if row_key is array_ops.newaxis:
|
||||
inner_rt = _ragged_getitem(rt_input, inner_keys)
|
||||
nsplits = array_ops.shape(inner_rt.row_splits,
|
||||
out_type=inner_rt.row_splits.dtype)[0]
|
||||
return ragged_tensor.RaggedTensor.from_row_splits(
|
||||
inner_rt, array_ops.stack([0, nsplits - 1]), validate=False)
|
||||
nsplits = tensor_shape.dimension_at_index(inner_rt.row_splits.shape, 0)
|
||||
if nsplits.value is not None:
|
||||
nsplits = nsplits.value
|
||||
else:
|
||||
nsplits = array_ops.shape(inner_rt.row_splits,
|
||||
out_type=inner_rt.row_splits.dtype)[0]
|
||||
return ragged_tensor.RaggedTensor.from_uniform_row_length(
|
||||
inner_rt, nsplits - 1, nrows=1, validate=False)
|
||||
|
||||
# Slicing a range of rows: first slice the outer dimension, and then
|
||||
# call `_ragged_getitem_inner_dimensions` to handle the inner keys.
|
||||
if isinstance(row_key, slice):
|
||||
sliced_rt_input = _slice_ragged_row_dimension(rt_input, row_key)
|
||||
if rt_input.uniform_row_length is not None:
|
||||
# If the inner dimension has uniform_row_length, then preserve it (by
|
||||
# re-wrapping the values in a new RaggedTensor). Note that the row
|
||||
# length won't have changed, since we're slicing a range of rows (and not
|
||||
# slicing the rows themselves).
|
||||
sliced_rt_input = ragged_tensor.RaggedTensor.from_uniform_row_length(
|
||||
sliced_rt_input.values, rt_input.uniform_row_length)
|
||||
return _ragged_getitem_inner_dimensions(sliced_rt_input, inner_keys)
|
||||
|
||||
# Indexing a single row: slice values to get the indicated row, and then
|
||||
@ -245,11 +256,14 @@ def _ragged_getitem_inner_dimensions(rt_input, key_list):
|
||||
# RaggedTensor that puts each value in its own row.
|
||||
if column_key is array_ops.newaxis:
|
||||
inner_rt = _ragged_getitem_inner_dimensions(rt_input, key_list[1:])
|
||||
nsplits = array_ops.shape(inner_rt.row_splits,
|
||||
out_type=inner_rt.row_splits.dtype)[0]
|
||||
return ragged_tensor.RaggedTensor.from_row_splits(inner_rt,
|
||||
math_ops.range(nsplits),
|
||||
validate=False)
|
||||
nsplits = tensor_shape.dimension_at_index(inner_rt.row_splits.shape, 0)
|
||||
if nsplits.value is not None:
|
||||
nsplits = nsplits.value
|
||||
else:
|
||||
nsplits = array_ops.shape(inner_rt.row_splits,
|
||||
out_type=inner_rt.row_splits.dtype)[0]
|
||||
return ragged_tensor.RaggedTensor.from_uniform_row_length(
|
||||
inner_rt, 1, nrows=nsplits - 1, validate=False)
|
||||
|
||||
# Slicing a range of columns in a ragged inner dimension. We use a
|
||||
# recursive call to process the values, and then assemble a RaggedTensor
|
||||
@ -292,15 +306,59 @@ def _ragged_getitem_inner_dimensions(rt_input, key_list):
|
||||
lambda: math_ops.maximum(limits + stop_offset, lower_bound))
|
||||
inner_rt = _build_ragged_tensor_from_value_ranges(
|
||||
inner_rt_starts, inner_rt_limits, column_key.step, rt_input.values)
|
||||
# If the row dimension is uniform, then calculate the new
|
||||
# uniform_row_length, and rebuild inner_rt using that uniform_row_lengths.
|
||||
if rt_input.uniform_row_length is not None:
|
||||
new_row_length = _slice_length(rt_input.uniform_row_length, column_key)
|
||||
inner_rt = ragged_tensor.RaggedTensor.from_uniform_row_length(
|
||||
inner_rt.values, new_row_length, rt_input.nrows())
|
||||
return inner_rt.with_values(
|
||||
_ragged_getitem_inner_dimensions(inner_rt.values, key_list[1:]))
|
||||
|
||||
# Indexing a single column in a ragged inner dimension: raise an Exception.
|
||||
# See RaggedTensor.__getitem__.__doc__ for an explanation of why indexing
|
||||
# into a ragged inner dimension is problematic.
|
||||
else:
|
||||
if rt_input.uniform_row_length is None:
|
||||
raise ValueError("Cannot index into an inner ragged dimension.")
|
||||
|
||||
# Indexing a single column in a uniform inner dimension: check that the
|
||||
# given index is in-bounds, and then use a strided slice over rt_input.values
|
||||
# to take the indicated element from each row.
|
||||
row_length = rt_input.uniform_row_length
|
||||
column_key = math_ops.cast(column_key, row_length.dtype)
|
||||
oob_err_msg = "Index out of bounds when indexing into a ragged tensor"
|
||||
oob_checks = [
|
||||
check_ops.assert_greater_equal(
|
||||
column_key, -row_length, message=oob_err_msg),
|
||||
check_ops.assert_less(column_key, row_length, message=oob_err_msg),
|
||||
]
|
||||
with ops.control_dependencies(oob_checks):
|
||||
offset = _if_ge_zero(column_key, lambda: column_key,
|
||||
lambda: row_length + column_key)
|
||||
sliced_rt = rt_input.values[offset::row_length]
|
||||
return _ragged_getitem_inner_dimensions(sliced_rt, key_list[1:])
|
||||
|
||||
|
||||
def _slice_length(value_length, slice_key):
|
||||
"""Computes the number of elements in a slice of a value with a given length.
|
||||
|
||||
Returns the equivalent of: `len(range(value_length)[slice_key])`
|
||||
|
||||
Args:
|
||||
value_length: Scalar int `Tensor`: the length of the value being sliced.
|
||||
slice_key: A `slice` object used to slice elements from the the value.
|
||||
|
||||
Returns:
|
||||
The number of elements in the sliced value.
|
||||
"""
|
||||
# Note: we could compute the slice length without creating a zeros tensor
|
||||
# with some variant of (stop-start)//step, but doing so would require more
|
||||
# ops (for checking bounds, handling negative indices, negative step sizes,
|
||||
# etc); and we expect this to be an uncommon operation, so we use this
|
||||
# simpler implementation.
|
||||
zeros = array_ops.zeros(value_length, dtype=dtypes.bool)
|
||||
return array_ops.size(zeros[slice_key], out_type=value_length.dtype)
|
||||
|
||||
|
||||
def _expand_ellipsis(key_list, num_remaining_dims):
|
||||
"""Expands the ellipsis at the start of `key_list`.
|
||||
|
@ -116,6 +116,12 @@ EXAMPLE_RAGGED_TENSOR_4D_VALUES = [[1, 2], [3, 4], [5, 6], [7, 8], [9, 10],
|
||||
[11, 12], [13, 14], [15, 16], [17, 18],
|
||||
[19, 20]]
|
||||
|
||||
# Example 3D ragged tensor with uniform_row_lengths.
|
||||
EXAMPLE_RAGGED_TENSOR_3D = [[[1, 2, 3], [4], [5, 6]], [[], [7, 8, 9], []]]
|
||||
EXAMPLE_RAGGED_TENSOR_3D_ROWLEN = 3
|
||||
EXAMPLE_RAGGED_TENSOR_3D_SPLITS = [0, 3, 4, 6, 6, 9, 9]
|
||||
EXAMPLE_RAGGED_TENSOR_3D_VALUES = [1, 2, 3, 4, 5, 6, 7, 8, 9]
|
||||
|
||||
|
||||
def int32array(values):
|
||||
return np.array(values, dtype=np.int32)
|
||||
@ -837,7 +843,7 @@ class RaggedTensorTest(test_util.TensorFlowTestCase,
|
||||
# RaggedTensor.__getitem__
|
||||
#=============================================================================
|
||||
|
||||
def _TestGetItem(self, rt, slice_spec, expected):
|
||||
def _TestGetItem(self, rt, slice_spec, expected, expected_shape=None):
|
||||
"""Helper function for testing RaggedTensor.__getitem__.
|
||||
|
||||
Checks that calling `rt.__getitem__(slice_spec) returns the expected value.
|
||||
@ -855,6 +861,7 @@ class RaggedTensorTest(test_util.TensorFlowTestCase,
|
||||
slice_spec: The slice spec.
|
||||
expected: The expected value of rt.__getitem__(slice_spec), as a python
|
||||
list; or an exception class.
|
||||
expected_shape: The expected shape for `rt.__getitem__(slice_spec)`.
|
||||
"""
|
||||
tensor_slice_spec1 = _make_tensor_slice_spec(slice_spec, True)
|
||||
tensor_slice_spec2 = _make_tensor_slice_spec(slice_spec, False)
|
||||
@ -864,13 +871,18 @@ class RaggedTensorTest(test_util.TensorFlowTestCase,
|
||||
self.assertAllEqual(value1, expected, 'slice_spec=%s' % (slice_spec,))
|
||||
self.assertAllEqual(value2, expected, 'slice_spec=%s' % (slice_spec,))
|
||||
self.assertAllEqual(value3, expected, 'slice_spec=%s' % (slice_spec,))
|
||||
if expected_shape is not None:
|
||||
value1.shape.assert_is_compatible_with(expected_shape)
|
||||
value2.shape.assert_is_compatible_with(expected_shape)
|
||||
value3.shape.assert_is_compatible_with(expected_shape)
|
||||
|
||||
def _TestGetItemException(self, rt, slice_spec, expected, message):
|
||||
"""Helper function for testing RaggedTensor.__getitem__ exceptions."""
|
||||
tensor_slice_spec1 = _make_tensor_slice_spec(slice_spec, True)
|
||||
self.assertRaisesRegexp(expected, message, rt.__getitem__, slice_spec)
|
||||
self.assertRaisesRegexp(expected, message, rt.__getitem__,
|
||||
tensor_slice_spec1)
|
||||
tensor_slice_spec = _make_tensor_slice_spec(slice_spec, True)
|
||||
with self.assertRaisesRegexp(expected, message):
|
||||
self.evaluate(rt.__getitem__(slice_spec))
|
||||
with self.assertRaisesRegexp(expected, message):
|
||||
self.evaluate(rt.__getitem__(tensor_slice_spec))
|
||||
|
||||
@parameterized.parameters(
|
||||
# Tests for rt[i]
|
||||
@ -1225,12 +1237,84 @@ class RaggedTensorTest(test_util.TensorFlowTestCase,
|
||||
self.assertEqual(rt_newaxis3.ragged_rank, 2)
|
||||
self.assertEqual(rt_newaxis4.ragged_rank, 2)
|
||||
|
||||
self.assertEqual(rt_newaxis0.shape.as_list(), [1, None, None, None, 2])
|
||||
self.assertEqual(rt_newaxis1.shape.as_list(), [2, None, None, None, 2])
|
||||
self.assertEqual(rt_newaxis2.shape.as_list(), [2, None, None, None, 2])
|
||||
self.assertEqual(rt_newaxis0.shape.as_list(), [1, 2, None, None, 2])
|
||||
self.assertEqual(rt_newaxis1.shape.as_list(), [2, 1, None, None, 2])
|
||||
self.assertEqual(rt_newaxis2.shape.as_list(), [2, None, 1, None, 2])
|
||||
self.assertEqual(rt_newaxis3.shape.as_list(), [2, None, None, 1, 2])
|
||||
self.assertEqual(rt_newaxis4.shape.as_list(), [2, None, None, 2, 1])
|
||||
|
||||
@parameterized.parameters(
|
||||
# EXAMPLE_RAGGED_TENSOR_3D.shape = [2, 3, None]
|
||||
|
||||
# Indexing into uniform_row_splits dimension:
|
||||
(SLICE_BUILDER[:, 1], [r[1] for r in EXAMPLE_RAGGED_TENSOR_3D],
|
||||
[2, None]),
|
||||
(SLICE_BUILDER[:, 2], [r[2] for r in EXAMPLE_RAGGED_TENSOR_3D],
|
||||
[2, None]),
|
||||
(SLICE_BUILDER[:, -2], [r[-2] for r in EXAMPLE_RAGGED_TENSOR_3D],
|
||||
[2, None]),
|
||||
(SLICE_BUILDER[:, -3], [r[-3] for r in EXAMPLE_RAGGED_TENSOR_3D],
|
||||
[2, None]),
|
||||
(SLICE_BUILDER[1:, 2], [r[2] for r in EXAMPLE_RAGGED_TENSOR_3D[1:]],
|
||||
[1, None]),
|
||||
(SLICE_BUILDER[:, 1, 1:], [r[1][1:] for r in EXAMPLE_RAGGED_TENSOR_3D],
|
||||
[2, None]),
|
||||
(SLICE_BUILDER[1:, 1, 1:],
|
||||
[r[1][1:] for r in EXAMPLE_RAGGED_TENSOR_3D[1:]],
|
||||
[1, None]),
|
||||
|
||||
# Slicing uniform_row_splits dimension:
|
||||
(SLICE_BUILDER[:, 2:], [r[2:] for r in EXAMPLE_RAGGED_TENSOR_3D],
|
||||
[2, 1, None]),
|
||||
(SLICE_BUILDER[:, -2:], [r[-2:] for r in EXAMPLE_RAGGED_TENSOR_3D],
|
||||
[2, 2, None]),
|
||||
(SLICE_BUILDER[:, :, 1:],
|
||||
[[c[1:] for c in r] for r in EXAMPLE_RAGGED_TENSOR_3D],
|
||||
[2, 3, None]),
|
||||
(SLICE_BUILDER[:, 5:], [r[5:] for r in EXAMPLE_RAGGED_TENSOR_3D],
|
||||
[2, 0, None]),
|
||||
|
||||
# Slicing uniform_row_splits dimension with a non-default step size:
|
||||
(SLICE_BUILDER[:, ::2], [r[::2] for r in EXAMPLE_RAGGED_TENSOR_3D],
|
||||
[2, 2, None]),
|
||||
(SLICE_BUILDER[:, ::-1], [r[::-1] for r in EXAMPLE_RAGGED_TENSOR_3D],
|
||||
[2, 3, None]),
|
||||
)
|
||||
def testRaggedTensorGetItemWithUniformRowLength(self, slice_spec, expected,
|
||||
expected_shape):
|
||||
"""Test that rt.__getitem__(slice_spec) == expected."""
|
||||
rt = RaggedTensor.from_uniform_row_length(
|
||||
RaggedTensor.from_row_splits(
|
||||
EXAMPLE_RAGGED_TENSOR_3D_VALUES,
|
||||
EXAMPLE_RAGGED_TENSOR_3D_SPLITS),
|
||||
EXAMPLE_RAGGED_TENSOR_3D_ROWLEN)
|
||||
self.assertAllEqual(rt, EXAMPLE_RAGGED_TENSOR_3D)
|
||||
self.assertIsNot(rt.uniform_row_length, None)
|
||||
self._TestGetItem(rt, slice_spec, expected, expected_shape)
|
||||
|
||||
# If the result is 3D, then check that it still has a uniform row length:
|
||||
actual = rt.__getitem__(slice_spec)
|
||||
if actual.shape.rank == 3:
|
||||
self.assertIsNot(actual.uniform_row_length, None)
|
||||
self.assertAllEqual(actual.uniform_row_length, expected_shape[1])
|
||||
|
||||
@parameterized.parameters(
|
||||
(SLICE_BUILDER[:, 3], errors.InvalidArgumentError, 'out of bounds'),
|
||||
(SLICE_BUILDER[:, -4], errors.InvalidArgumentError, 'out of bounds'),
|
||||
(SLICE_BUILDER[:, 10], errors.InvalidArgumentError, 'out of bounds'),
|
||||
(SLICE_BUILDER[:, -10], errors.InvalidArgumentError, 'out of bounds'),
|
||||
)
|
||||
def testRaggedTensorGetItemErrorsWithUniformRowLength(self, slice_spec,
|
||||
expected, message):
|
||||
"""Test that rt.__getitem__(slice_spec) == expected."""
|
||||
rt = RaggedTensor.from_uniform_row_length(
|
||||
RaggedTensor.from_row_splits(
|
||||
EXAMPLE_RAGGED_TENSOR_3D_VALUES,
|
||||
EXAMPLE_RAGGED_TENSOR_3D_SPLITS),
|
||||
EXAMPLE_RAGGED_TENSOR_3D_ROWLEN)
|
||||
self.assertAllEqual(rt, EXAMPLE_RAGGED_TENSOR_3D)
|
||||
self._TestGetItemException(rt, slice_spec, expected, message)
|
||||
|
||||
#=============================================================================
|
||||
# RaggedTensor.__str__
|
||||
#=============================================================================
|
||||
|
Loading…
Reference in New Issue
Block a user