STT-tensorflow/tensorflow/python/ops/ragged/ragged_getitem_test.py
Gaurav Jain f618ab4955 Move away from deprecated asserts
- assertEquals -> assertEqual
- assertRaisesRegexp -> assertRegexpMatches
- assertRegexpMatches -> assertRegex

PiperOrigin-RevId: 319118081
Change-Id: Ieb457128522920ab55d6b69a7f244ab798a7d689
2020-06-30 16:10:22 -07:00

591 lines
26 KiB
Python

# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for third_party.tensorflow.python.ops.ragged_tensor."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import re
from absl.testing import parameterized
from tensorflow.python.eager import context
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops.ragged import ragged_factory_ops
from tensorflow.python.ops.ragged.ragged_tensor import RaggedTensor
from tensorflow.python.platform import googletest
class _SliceBuilder(object):
  """Captures the key that Python passes to `__getitem__` for a subscript.

  Subscripting an instance simply hands the key back, so
  `_SliceBuilder()[1:3, ...]` evaluates to `(slice(1, 3, None), Ellipsis)`.
  This lets test tables spell slice specs with ordinary subscript syntax.
  """

  def __getitem__(self, slice_spec):
    # Identity: the subscript key itself is the value the caller wants.
    return slice_spec


# Shared instance used to write slice specs in the parameterized test tables.
SLICE_BUILDER = _SliceBuilder()
def _make_tensor_slice_spec(slice_spec, use_constant=True):
  """Returns `slice_spec` with every Python int replaced by a scalar tensor.

  Used to exercise `__getitem__` with tensor-valued indices and slice bounds
  instead of plain Python integers.

  A tuple spec is converted piece by piece; any other spec is treated as a
  single piece. Within a piece, an int is wrapped, and a `slice` is rebuilt
  from its wrapped start/stop/step. Every other value (None, Ellipsis,
  floats, strings, ...) is passed through unchanged.

  Args:
    slice_spec: The extended slice spec.
    use_constant: If true, each int `i` becomes `constant_op.constant(i)`.
      If false, that constant is additionally wrapped with
      `array_ops.placeholder_with_default(..., [])`.

  Returns:
    The converted slice spec (a tuple if `slice_spec` was a tuple).
  """

  def _wrap(item):
    # Slices are rebuilt field by field so nested ints get wrapped too.
    if isinstance(item, slice):
      return slice(_wrap(item.start), _wrap(item.stop), _wrap(item.step))
    # Anything that is not an int is left exactly as it was given.
    if not isinstance(item, int):
      return item
    tensor = constant_op.constant(item)
    if use_constant:
      return tensor
    return array_ops.placeholder_with_default(tensor, [])

  if not isinstance(slice_spec, tuple):
    return _wrap(slice_spec)
  return tuple(_wrap(item) for item in slice_spec)
# A 2-D ragged example (one ragged dimension, scalar string values), given
# both as nested Python lists and as row splits + flat values. Row i of the
# nested form corresponds to VALUES[SPLITS[i]:SPLITS[i + 1]]. The nested form
# holds bytes while VALUES holds str; the tests compare a tensor built from
# VALUES against the nested form.
EXAMPLE_RAGGED_TENSOR_2D = [
    [b'a', b'b'],
    [b'c', b'd', b'e'],
    [b'f'],
    [],
    [b'g'],
]
EXAMPLE_RAGGED_TENSOR_2D_SPLITS = [0, 2, 5, 6, 6, 7]
EXAMPLE_RAGGED_TENSOR_2D_VALUES = ['a', 'b', 'c', 'd', 'e', 'f', 'g']
# A 4-D ragged example with two ragged dimensions and flat values of shape
# [2], given both as nested Python lists and as nested row splits + values.
# SPLITS2 groups VALUES into inner rows; SPLITS1 groups those inner rows into
# the outer rows rt[0] .. rt[3].
EXAMPLE_RAGGED_TENSOR_4D = [
    [                                    # rt[0]
        [[1, 2], [3, 4], [5, 6]],        # rt[0][0]
        [[7, 8], [9, 10], [11, 12]],     # rt[0][1]
    ],
    [],                                  # rt[1]
    [                                    # rt[2]
        [[13, 14], [15, 16], [17, 18]],  # rt[2][0]
    ],
    [                                    # rt[3]
        [[19, 20]],                      # rt[3][0]
    ],
]  # pyformat: disable
EXAMPLE_RAGGED_TENSOR_4D_SPLITS1 = [0, 2, 2, 3, 4]
EXAMPLE_RAGGED_TENSOR_4D_SPLITS2 = [0, 3, 6, 9, 10]
EXAMPLE_RAGGED_TENSOR_4D_VALUES = [
    [1, 2], [3, 4], [5, 6], [7, 8], [9, 10],
    [11, 12], [13, 14], [15, 16], [17, 18], [19, 20],
]
# A 3-D example whose middle dimension is uniform (each outer row holds
# EXAMPLE_RAGGED_TENSOR_3D_ROWLEN inner rows) while the innermost dimension
# is ragged. SPLITS + VALUES describe the six inner rows.
EXAMPLE_RAGGED_TENSOR_3D = [
    [[1, 2, 3], [4], [5, 6]],
    [[], [7, 8, 9], []],
]
EXAMPLE_RAGGED_TENSOR_3D_ROWLEN = 3
EXAMPLE_RAGGED_TENSOR_3D_SPLITS = [0, 3, 4, 6, 6, 9, 9]
EXAMPLE_RAGGED_TENSOR_3D_VALUES = [1, 2, 3, 4, 5, 6, 7, 8, 9]
@test_util.run_all_in_graph_and_eager_modes
class RaggedGetItemTest(test_util.TensorFlowTestCase, parameterized.TestCase):
longMessage = True # Property in unittest.Testcase. pylint: disable=invalid-name
#=============================================================================
# RaggedTensor.__getitem__
#=============================================================================
def _TestGetItem(self, rt, slice_spec, expected, expected_shape=None):
  """Checks `rt.__getitem__(slice_spec)` against `expected` in three ways.

  The subscript is applied once for each variant of the slice spec:

    * `slice_spec` exactly as given (plain Python int values);
    * the spec from `_make_tensor_slice_spec(slice_spec, True)`;
    * the spec from `_make_tensor_slice_spec(slice_spec, False)`.

  All three results are computed before any of them is asserted on.

  Args:
    rt: The RaggedTensor to index.
    slice_spec: The slice spec passed to `__getitem__`.
    expected: The expected value of `rt.__getitem__(slice_spec)`, as a python
      list (or scalar).
    expected_shape: If not None, the static shape of every result must be
      compatible with this shape.
  """
  spec_variants = (
      slice_spec,
      _make_tensor_slice_spec(slice_spec, True),
      _make_tensor_slice_spec(slice_spec, False),
  )
  results = [rt.__getitem__(spec) for spec in spec_variants]

  # The failure message always reports the original (int-valued) spec.
  failure_msg = 'slice_spec=%s' % (slice_spec,)
  for result in results:
    self.assertAllEqual(result, expected, failure_msg)

  if expected_shape is None:
    return
  for result in results:
    result.shape.assert_is_compatible_with(expected_shape)
def _TestGetItemException(self, rt, slice_spec, expected, message):
  """Checks that evaluating `rt.__getitem__(slice_spec)` raises an error.

  The check runs twice: with `slice_spec` as given, and with the spec
  returned by `_make_tensor_slice_spec(slice_spec, True)`.

  Args:
    rt: The RaggedTensor to index.
    slice_spec: The slice spec passed to `__getitem__`.
    expected: Exception class (or tuple of classes) that must be raised.
    message: Regular expression the exception message must match.
  """
  spec_variants = (slice_spec, _make_tensor_slice_spec(slice_spec, True))
  for spec in spec_variants:
    with self.assertRaisesRegex(expected, message):
      self.evaluate(rt.__getitem__(spec))
@parameterized.parameters(
    # Each entry is (slice_spec, expected). The expected values are derived
    # from the nested-list form of the tensor with plain Python indexing and
    # slicing, so Python list semantics act as the reference behavior
    # (including out-of-range slice bounds such as [-6:] and [5:]).
    # Tests for rt[i]
    (SLICE_BUILDER[-5], EXAMPLE_RAGGED_TENSOR_2D[-5]),
    (SLICE_BUILDER[-4], EXAMPLE_RAGGED_TENSOR_2D[-4]),
    (SLICE_BUILDER[-1], EXAMPLE_RAGGED_TENSOR_2D[-1]),
    (SLICE_BUILDER[0], EXAMPLE_RAGGED_TENSOR_2D[0]),
    (SLICE_BUILDER[1], EXAMPLE_RAGGED_TENSOR_2D[1]),
    (SLICE_BUILDER[4], EXAMPLE_RAGGED_TENSOR_2D[4]),
    # Tests for rt[i:]
    (SLICE_BUILDER[-6:], EXAMPLE_RAGGED_TENSOR_2D[-6:]),
    (SLICE_BUILDER[-3:], EXAMPLE_RAGGED_TENSOR_2D[-3:]),
    (SLICE_BUILDER[-1:], EXAMPLE_RAGGED_TENSOR_2D[-1:]),
    (SLICE_BUILDER[0:], EXAMPLE_RAGGED_TENSOR_2D[0:]),
    (SLICE_BUILDER[3:], EXAMPLE_RAGGED_TENSOR_2D[3:]),
    (SLICE_BUILDER[5:], EXAMPLE_RAGGED_TENSOR_2D[5:]),
    # Tests for rt[:j]
    (SLICE_BUILDER[:-6], EXAMPLE_RAGGED_TENSOR_2D[:-6]),
    (SLICE_BUILDER[:-3], EXAMPLE_RAGGED_TENSOR_2D[:-3]),
    (SLICE_BUILDER[:-1], EXAMPLE_RAGGED_TENSOR_2D[:-1]),
    (SLICE_BUILDER[:0], EXAMPLE_RAGGED_TENSOR_2D[:0]),
    (SLICE_BUILDER[:3], EXAMPLE_RAGGED_TENSOR_2D[:3]),
    (SLICE_BUILDER[:5], EXAMPLE_RAGGED_TENSOR_2D[:5]),
    # Tests for rt[i:j]
    (SLICE_BUILDER[0:3], EXAMPLE_RAGGED_TENSOR_2D[0:3]),
    (SLICE_BUILDER[3:5], EXAMPLE_RAGGED_TENSOR_2D[3:5]),
    (SLICE_BUILDER[-5:3], EXAMPLE_RAGGED_TENSOR_2D[-5:3]),
    (SLICE_BUILDER[3:1], EXAMPLE_RAGGED_TENSOR_2D[3:1]),
    (SLICE_BUILDER[-1:1], EXAMPLE_RAGGED_TENSOR_2D[-1:1]),
    (SLICE_BUILDER[1:-1], EXAMPLE_RAGGED_TENSOR_2D[1:-1]),
    # Tests for rt[i, j]
    (SLICE_BUILDER[0, 1], EXAMPLE_RAGGED_TENSOR_2D[0][1]),
    (SLICE_BUILDER[1, 2], EXAMPLE_RAGGED_TENSOR_2D[1][2]),
    (SLICE_BUILDER[-1, 0], EXAMPLE_RAGGED_TENSOR_2D[-1][0]),
    (SLICE_BUILDER[-3, 0], EXAMPLE_RAGGED_TENSOR_2D[-3][0]),
    (SLICE_BUILDER[:], EXAMPLE_RAGGED_TENSOR_2D),
    (SLICE_BUILDER[:, :], EXAMPLE_RAGGED_TENSOR_2D),
    # Empty slice spec.
    ([], EXAMPLE_RAGGED_TENSOR_2D),
    # Test for ellipsis
    (SLICE_BUILDER[...], EXAMPLE_RAGGED_TENSOR_2D),
    (SLICE_BUILDER[2, ...], EXAMPLE_RAGGED_TENSOR_2D[2]),
    (SLICE_BUILDER[..., :], EXAMPLE_RAGGED_TENSOR_2D),
    (SLICE_BUILDER[..., 2, 0], EXAMPLE_RAGGED_TENSOR_2D[2][0]),
    (SLICE_BUILDER[2, ..., 0], EXAMPLE_RAGGED_TENSOR_2D[2][0]),
    (SLICE_BUILDER[2, 0, ...], EXAMPLE_RAGGED_TENSOR_2D[2][0]),
    # Test for array_ops.newaxis
    (SLICE_BUILDER[array_ops.newaxis, :], [EXAMPLE_RAGGED_TENSOR_2D]),
    (SLICE_BUILDER[:, array_ops.newaxis],
     [[row] for row in EXAMPLE_RAGGED_TENSOR_2D]),
    # Slicing inner ragged dimensions.
    (SLICE_BUILDER[-1:,
                   1:4], [row[1:4] for row in EXAMPLE_RAGGED_TENSOR_2D[-1:]]),
    (SLICE_BUILDER[:, 1:4], [row[1:4] for row in EXAMPLE_RAGGED_TENSOR_2D]),
    (SLICE_BUILDER[:, -2:], [row[-2:] for row in EXAMPLE_RAGGED_TENSOR_2D]),
    # Strided slices
    (SLICE_BUILDER[::2], EXAMPLE_RAGGED_TENSOR_2D[::2]),
    (SLICE_BUILDER[::-1], EXAMPLE_RAGGED_TENSOR_2D[::-1]),
    (SLICE_BUILDER[::-2], EXAMPLE_RAGGED_TENSOR_2D[::-2]),
    (SLICE_BUILDER[::-3], EXAMPLE_RAGGED_TENSOR_2D[::-3]),
    (SLICE_BUILDER[:, ::2], [row[::2] for row in EXAMPLE_RAGGED_TENSOR_2D]),
    (SLICE_BUILDER[:, ::-1], [row[::-1] for row in EXAMPLE_RAGGED_TENSOR_2D]),
    (SLICE_BUILDER[:, ::-2], [row[::-2] for row in EXAMPLE_RAGGED_TENSOR_2D]),
    (SLICE_BUILDER[:, ::-3], [row[::-3] for row in EXAMPLE_RAGGED_TENSOR_2D]),
    (SLICE_BUILDER[:, 2::-1],
     [row[2::-1] for row in EXAMPLE_RAGGED_TENSOR_2D]),
    (SLICE_BUILDER[:, -1::-1],
     [row[-1::-1] for row in EXAMPLE_RAGGED_TENSOR_2D]),
    (SLICE_BUILDER[..., -1::-1],
     [row[-1::-1] for row in EXAMPLE_RAGGED_TENSOR_2D]),
    (SLICE_BUILDER[:, 2::-2],
     [row[2::-2] for row in EXAMPLE_RAGGED_TENSOR_2D]),
    (SLICE_BUILDER[::-1, ::-1],
     [row[::-1] for row in EXAMPLE_RAGGED_TENSOR_2D[::-1]]),
)  # pyformat: disable
def testWithRaggedRank1(self, slice_spec, expected):
  """Tests `rt.__getitem__(slice_spec) == expected` on the 2D example.

  Args:
    slice_spec: The slice spec to apply to the ragged tensor.
    expected: The expected result, as a python list or scalar.
  """
  # Ragged tensor built from the flat values and row splits; the assertion
  # below confirms it matches the nested-list form before it is sliced.
  rt = RaggedTensor.from_row_splits(EXAMPLE_RAGGED_TENSOR_2D_VALUES,
                                    EXAMPLE_RAGGED_TENSOR_2D_SPLITS)
  self.assertAllEqual(rt, EXAMPLE_RAGGED_TENSOR_2D)
  self._TestGetItem(rt, slice_spec, expected)
# pylint: disable=g-complex-comprehension
@parameterized.parameters([(start, stop)
                           for start in [-2, -1, None, 0, 1, 2]
                           for stop in [-2, -1, None, 0, 1, 2]])
def testWithStridedSlices(self, start, stop):
  """Compares strided slicing of a RaggedTensor with Python list slicing.

  For the given `start` and `stop`, every step in [-3, -2, -1, 1, 2, 3] is
  applied to both the outer dimension and the inner (ragged) dimension.

  Args:
    start: Slice start (an int or None).
    stop: Slice stop (an int or None).
  """
  nested = [[1, 2, 3, 4, 5], [6, 7], [8, 9, 10], [], [9],
            [1, 2, 3, 4, 5, 6, 7, 8]]
  rt = ragged_factory_ops.constant(nested)
  for step in [-3, -2, -1, 1, 2, 3]:
    strided = slice(start, stop, step)
    label = 'slice=%s:%s:%s' % (start, stop, step)
    # Outer dimension: slice the list of rows.
    self.assertAllEqual(rt[strided], nested[strided], label)
    # Inner dimension: slice within every row.
    self.assertAllEqual(rt[:, strided], [row[strided] for row in nested],
                        label)
# pylint: disable=invalid-slice-index
@parameterized.parameters(
    # Each entry is (slice_spec, expected exception class(es), message regex).
    # Tests for out-of-bound errors
    (SLICE_BUILDER[5], (IndexError, ValueError, errors.InvalidArgumentError),
     '.*out of bounds.*'),
    (SLICE_BUILDER[-6], (IndexError, ValueError, errors.InvalidArgumentError),
     '.*out of bounds.*'),
    (SLICE_BUILDER[0, 2], (IndexError, ValueError,
                           errors.InvalidArgumentError), '.*out of bounds.*'),
    (SLICE_BUILDER[3, 0], (IndexError, ValueError,
                           errors.InvalidArgumentError), '.*out of bounds.*'),
    # Indexing into an inner ragged dimension
    (SLICE_BUILDER[:, 3], ValueError,
     'Cannot index into an inner ragged dimension'),
    (SLICE_BUILDER[:1, 3], ValueError,
     'Cannot index into an inner ragged dimension'),
    (SLICE_BUILDER[..., 3], ValueError,
     'Cannot index into an inner ragged dimension'),
    # Tests for type errors
    (SLICE_BUILDER[0.5], TypeError, re.escape(array_ops._SLICE_TYPE_ERROR)),
    (SLICE_BUILDER[1:3:0.5], TypeError, re.escape(
        array_ops._SLICE_TYPE_ERROR)),
    (SLICE_BUILDER[:, 1:3:0.5], TypeError,
     'slice strides must be integers or None'),
    (SLICE_BUILDER[:, 0.5:1.5], TypeError,
     'slice offsets must be integers or None'),
    (SLICE_BUILDER['foo'], TypeError, re.escape(array_ops._SLICE_TYPE_ERROR)),
    (SLICE_BUILDER[:, 'foo':'foo'], TypeError,
     'slice offsets must be integers or None'),
    # Tests for other errors
    (SLICE_BUILDER[..., 0, 0,
                   0], IndexError, 'Too many indices for RaggedTensor'),
)
def testErrorsWithRaggedRank1(self, slice_spec, expected, message):
  """Tests that `rt.__getitem__(slice_spec)` raises on the 2D example.

  Args:
    slice_spec: The invalid slice spec to apply to the ragged tensor.
    expected: The exception class (or tuple of classes) that must be raised.
    message: Regular expression that the exception message must match.
  """
  # Ragged tensor built from the flat values and row splits; the assertion
  # below confirms it matches the nested-list form before it is indexed.
  rt = RaggedTensor.from_row_splits(EXAMPLE_RAGGED_TENSOR_2D_VALUES,
                                    EXAMPLE_RAGGED_TENSOR_2D_SPLITS)
  self.assertAllEqual(rt, EXAMPLE_RAGGED_TENSOR_2D)
  self._TestGetItemException(rt, slice_spec, expected, message)
@parameterized.parameters(
    # Each entry is (slice_spec, expected). Expected values are either spelled
    # out literally or derived from the nested-list form of the tensor with
    # plain Python indexing and slicing.
    # Tests for rt[index, index, ...]
    (SLICE_BUILDER[2, 0], EXAMPLE_RAGGED_TENSOR_4D[2][0]),
    (SLICE_BUILDER[2, 0, 1], EXAMPLE_RAGGED_TENSOR_4D[2][0][1]),
    (SLICE_BUILDER[2, 0, 1, 1], EXAMPLE_RAGGED_TENSOR_4D[2][0][1][1]),
    (SLICE_BUILDER[2, 0, 1:], EXAMPLE_RAGGED_TENSOR_4D[2][0][1:]),
    (SLICE_BUILDER[2, 0, 1:, 1:], [[16], [18]]),
    (SLICE_BUILDER[2, 0, :, 1], [14, 16, 18]),
    (SLICE_BUILDER[2, 0, 1, :], EXAMPLE_RAGGED_TENSOR_4D[2][0][1]),
    # Tests for rt[index, slice, ...]
    (SLICE_BUILDER[0, :], EXAMPLE_RAGGED_TENSOR_4D[0]),
    (SLICE_BUILDER[1, :], EXAMPLE_RAGGED_TENSOR_4D[1]),
    (SLICE_BUILDER[0, :, :, 1], [[2, 4, 6], [8, 10, 12]]),
    (SLICE_BUILDER[1, :, :, 1], []),
    (SLICE_BUILDER[2, :, :, 1], [[14, 16, 18]]),
    (SLICE_BUILDER[3, :, :, 1], [[20]]),
    # Tests for rt[slice, slice, ...]
    (SLICE_BUILDER[:, :], EXAMPLE_RAGGED_TENSOR_4D),
    (SLICE_BUILDER[:, :, :, 1], [[[2, 4, 6], [8, 10, 12]], [], [[14, 16, 18]],
                                 [[20]]]),
    (SLICE_BUILDER[1:, :, :, 1], [[], [[14, 16, 18]], [[20]]]),
    (SLICE_BUILDER[-3:, :, :, 1], [[], [[14, 16, 18]], [[20]]]),
    # Test for ellipsis
    (SLICE_BUILDER[...], EXAMPLE_RAGGED_TENSOR_4D),
    (SLICE_BUILDER[2, ...], EXAMPLE_RAGGED_TENSOR_4D[2]),
    (SLICE_BUILDER[2, 0, ...], EXAMPLE_RAGGED_TENSOR_4D[2][0]),
    (SLICE_BUILDER[..., 0], [[[1, 3, 5], [7, 9, 11]], [], [[13, 15, 17]],
                             [[19]]]),
    (SLICE_BUILDER[2, ..., 0], [[13, 15, 17]]),
    (SLICE_BUILDER[2, 0, ..., 0], [13, 15, 17]),
    # Test for array_ops.newaxis
    (SLICE_BUILDER[array_ops.newaxis, :], [EXAMPLE_RAGGED_TENSOR_4D]),
    (SLICE_BUILDER[:, array_ops.newaxis],
     [[row] for row in EXAMPLE_RAGGED_TENSOR_4D]),
    # Empty slice spec.
    ([], EXAMPLE_RAGGED_TENSOR_4D),
    # Slicing inner ragged dimensions.
    (SLICE_BUILDER[:, 1:4], [row[1:4] for row in EXAMPLE_RAGGED_TENSOR_4D]),
    (SLICE_BUILDER[:, -2:], [row[-2:] for row in EXAMPLE_RAGGED_TENSOR_4D]),
    (SLICE_BUILDER[:, :, :-1],
     [[v[:-1] for v in row] for row in EXAMPLE_RAGGED_TENSOR_4D]),
    (SLICE_BUILDER[:, :, 1:2],
     [[v[1:2] for v in row] for row in EXAMPLE_RAGGED_TENSOR_4D]),
    (SLICE_BUILDER[1:, 1:3, 1:2],
     [[v[1:2] for v in row[1:3]] for row in EXAMPLE_RAGGED_TENSOR_4D[1:]]),
    # Strided slices
    (SLICE_BUILDER[::2], EXAMPLE_RAGGED_TENSOR_4D[::2]),
    (SLICE_BUILDER[::-1], EXAMPLE_RAGGED_TENSOR_4D[::-1]),
    (SLICE_BUILDER[::-2], EXAMPLE_RAGGED_TENSOR_4D[::-2]),
    (SLICE_BUILDER[1::2], EXAMPLE_RAGGED_TENSOR_4D[1::2]),
    (SLICE_BUILDER[:, ::2], [row[::2] for row in EXAMPLE_RAGGED_TENSOR_4D]),
    (SLICE_BUILDER[:, 1::2], [row[1::2] for row in EXAMPLE_RAGGED_TENSOR_4D]),
    (SLICE_BUILDER[:, :, ::2],
     [[v[::2] for v in row] for row in EXAMPLE_RAGGED_TENSOR_4D]),
    (SLICE_BUILDER[:, :, 1::2],
     [[v[1::2] for v in row] for row in EXAMPLE_RAGGED_TENSOR_4D]),
    (SLICE_BUILDER[:, :, ::-1],
     [[v[::-1] for v in row] for row in EXAMPLE_RAGGED_TENSOR_4D]),
    (SLICE_BUILDER[:, :, ::-2],
     [[v[::-2] for v in row] for row in EXAMPLE_RAGGED_TENSOR_4D]),
    (SLICE_BUILDER[..., ::-1, :],
     [[v[::-1] for v in row] for row in EXAMPLE_RAGGED_TENSOR_4D]),
    (SLICE_BUILDER[..., ::-1], [[[v[::-1] for v in col] for col in row]
                                for row in EXAMPLE_RAGGED_TENSOR_4D]),
)  # pyformat: disable
def testWithRaggedRank2(self, slice_spec, expected):
  """Tests `rt.__getitem__(slice_spec) == expected` on the 4D example.

  Args:
    slice_spec: The slice spec to apply to the ragged tensor.
    expected: The expected result, as a python list or scalar.
  """
  # Built from the flat values plus both levels of row splits; the assertion
  # below confirms it matches the nested-list form before it is sliced.
  rt = RaggedTensor.from_nested_row_splits(
      EXAMPLE_RAGGED_TENSOR_4D_VALUES,
      [EXAMPLE_RAGGED_TENSOR_4D_SPLITS1, EXAMPLE_RAGGED_TENSOR_4D_SPLITS2])
  self.assertAllEqual(rt, EXAMPLE_RAGGED_TENSOR_4D)
  self._TestGetItem(rt, slice_spec, expected)
@parameterized.parameters(
    # Test for errors in unsupported cases
    (SLICE_BUILDER[:, 0],
     ValueError,
     'Cannot index into an inner ragged dimension.'),
    (SLICE_BUILDER[:, :, 0],
     ValueError,
     'Cannot index into an inner ragged dimension.'),
    # Test for out-of-bounds errors.
    (SLICE_BUILDER[1, 0],
     (IndexError, ValueError, errors.InvalidArgumentError),
     '.*out of bounds.*'),
    (SLICE_BUILDER[0, 0, 3],
     (IndexError, ValueError, errors.InvalidArgumentError),
     '.*out of bounds.*'),
    (SLICE_BUILDER[5],
     (IndexError, ValueError, errors.InvalidArgumentError),
     '.*out of bounds.*'),
    (SLICE_BUILDER[0, 5],
     (IndexError, ValueError, errors.InvalidArgumentError),
     '.*out of bounds.*'),
)
def testErrorsWithRaggedRank2(self, slice_spec, expected, message):
  """Tests that `rt.__getitem__(slice_spec)` raises on the 4D example.

  Args:
    slice_spec: The invalid slice spec to apply to the ragged tensor.
    expected: The exception class (or tuple of classes) that must be raised.
    message: Regular expression that the exception message must match.
  """
  nested_splits = [
      EXAMPLE_RAGGED_TENSOR_4D_SPLITS1,
      EXAMPLE_RAGGED_TENSOR_4D_SPLITS2,
  ]
  ragged = RaggedTensor.from_nested_row_splits(
      EXAMPLE_RAGGED_TENSOR_4D_VALUES, nested_splits)
  # Confirm the fixture matches its nested-list form before indexing it.
  self.assertAllEqual(ragged, EXAMPLE_RAGGED_TENSOR_4D)
  self._TestGetItemException(ragged, slice_spec, expected, message)
@parameterized.parameters(
    (SLICE_BUILDER[:], []),
    (SLICE_BUILDER[2:], []),
    (SLICE_BUILDER[:-3], []),
)
def testWithEmptyTensor(self, slice_spec, expected):
  """Tests slicing a RaggedTensor that has no values and no rows.

  Args:
    slice_spec: The slice spec to apply to the empty ragged tensor.
    expected: The expected result, as a python list.
  """
  empty_rt = RaggedTensor.from_row_splits([], [0])
  self._TestGetItem(empty_rt, slice_spec, expected)
@parameterized.parameters(
    (SLICE_BUILDER[0],
     (IndexError, ValueError, errors.InvalidArgumentError),
     '.*out of bounds.*'),
    (SLICE_BUILDER[-1],
     (IndexError, ValueError, errors.InvalidArgumentError),
     '.*out of bounds.*'),
)
def testErrorsWithEmptyTensor(self, slice_spec, expected, message):
  """Tests that indexing an empty RaggedTensor raises an error.

  Args:
    slice_spec: The index to apply to the empty ragged tensor.
    expected: The exception class (or tuple of classes) that must be raised.
    message: Regular expression that the exception message must match.
  """
  empty_rt = RaggedTensor.from_row_splits([], [0])
  self._TestGetItemException(empty_rt, slice_spec, expected, message)
@parameterized.parameters(
    (SLICE_BUILDER[-4], EXAMPLE_RAGGED_TENSOR_2D[-4]),
    (SLICE_BUILDER[0], EXAMPLE_RAGGED_TENSOR_2D[0]),
    (SLICE_BUILDER[-3:], EXAMPLE_RAGGED_TENSOR_2D[-3:]),
    (SLICE_BUILDER[:3], EXAMPLE_RAGGED_TENSOR_2D[:3]),
    (SLICE_BUILDER[3:5], EXAMPLE_RAGGED_TENSOR_2D[3:5]),
    (SLICE_BUILDER[0, 1], EXAMPLE_RAGGED_TENSOR_2D[0][1]),
    (SLICE_BUILDER[-3, 0], EXAMPLE_RAGGED_TENSOR_2D[-3][0]),
)
def testWithPlaceholderShapes(self, slice_spec, expected):
  """Tests `__getitem__` when the row splits have an unknown static shape.

  Args:
    slice_spec: The slice spec to apply to the ragged tensor.
    expected: The expected result, as a python list or scalar.
  """
  # The splits are deliberately routed through a placeholder whose shape is
  # None, forcing the code path where nrows is unknown at graph construction
  # time.
  row_splits = array_ops.placeholder_with_default(
      constant_op.constant(
          EXAMPLE_RAGGED_TENSOR_2D_SPLITS, dtype=dtypes.int64), None)
  ragged = RaggedTensor.from_row_splits(EXAMPLE_RAGGED_TENSOR_2D_VALUES,
                                        row_splits)
  self.assertAllEqual(ragged, EXAMPLE_RAGGED_TENSOR_2D)
  self._TestGetItem(ragged, slice_spec, expected)
@parameterized.parameters(
    (SLICE_BUILDER[..., 2],
     ValueError,
     'Ellipsis not supported for unknown shape RaggedTensors'),
)
def testErrorsWithPlaceholderShapes(self, slice_spec, expected, message):
  """Tests that an ellipsis is rejected when the values shape is unknown.

  The check only runs outside eager execution; when executing eagerly the
  method returns without asserting anything.

  Args:
    slice_spec: The slice spec to apply to the ragged tensor.
    expected: The exception class that must be raised.
    message: Regular expression that the exception message must match.
  """
  if context.executing_eagerly():
    return
  # Intentionally use an unknown shape for `values`.
  unknown_shape_values = array_ops.placeholder_with_default([0], None)
  ragged = RaggedTensor.from_row_splits(unknown_shape_values, [0, 1])
  self._TestGetItemException(ragged, slice_spec, expected, message)
def testNewAxis(self):
  """Tests inserting `array_ops.newaxis` at each position of a 4D tensor.

  For every insertion position the resulting value, ragged_rank and static
  shape are checked.
  """
  # rt: [[[['a', 'b'], ['c', 'd']], [], [['e', 'f']]], []]
  flat_values = constant_op.constant([['a', 'b'], ['c', 'd'], ['e', 'f']])
  rt = RaggedTensor.from_nested_row_splits(
      flat_values, [[0, 3, 3], [0, 2, 2, 3]])

  # One entry per insertion position:
  # (sliced tensor, expected value, expected ragged_rank, expected shape).
  cases = [
      (rt[array_ops.newaxis],
       [[[[[b'a', b'b'], [b'c', b'd']], [], [[b'e', b'f']]], []]],
       3,
       [1, 2, None, None, 2]),
      (rt[:, array_ops.newaxis],
       [[[[[b'a', b'b'], [b'c', b'd']], [], [[b'e', b'f']]]], [[]]],
       3,
       [2, 1, None, None, 2]),
      (rt[:, :, array_ops.newaxis],
       [[[[[b'a', b'b'], [b'c', b'd']]], [[]], [[[b'e', b'f']]]], []],
       3,
       [2, None, 1, None, 2]),
      (rt[:, :, :, array_ops.newaxis],
       [[[[[b'a', b'b']], [[b'c', b'd']]], [], [[[b'e', b'f']]]], []],
       2,
       [2, None, None, 1, 2]),
      (rt[:, :, :, :, array_ops.newaxis],
       [[[[[b'a'], [b'b']], [[b'c'], [b'd']]], [], [[[b'e'], [b'f']]]], []],
       2,
       [2, None, None, 2, 1]),
  ]

  # Values first (starting with the unsliced tensor), then ragged ranks,
  # then static shapes.
  self.assertAllEqual(
      rt, [[[[b'a', b'b'], [b'c', b'd']], [], [[b'e', b'f']]], []])
  for sliced, expected_value, _, _ in cases:
    self.assertAllEqual(sliced, expected_value)

  self.assertEqual(rt.ragged_rank, 2)
  for sliced, _, expected_ragged_rank, _ in cases:
    self.assertEqual(sliced.ragged_rank, expected_ragged_rank)

  for sliced, _, _, expected_shape in cases:
    self.assertEqual(sliced.shape.as_list(), expected_shape)
@parameterized.parameters(
    # EXAMPLE_RAGGED_TENSOR_3D.shape = [2, 3, None]
    # Indexing into uniform_row_splits dimension:
    (SLICE_BUILDER[:, 1],
     [r[1] for r in EXAMPLE_RAGGED_TENSOR_3D],
     [2, None]),
    (SLICE_BUILDER[:, 2],
     [r[2] for r in EXAMPLE_RAGGED_TENSOR_3D],
     [2, None]),
    (SLICE_BUILDER[:, -2],
     [r[-2] for r in EXAMPLE_RAGGED_TENSOR_3D],
     [2, None]),
    (SLICE_BUILDER[:, -3],
     [r[-3] for r in EXAMPLE_RAGGED_TENSOR_3D],
     [2, None]),
    (SLICE_BUILDER[1:, 2],
     [r[2] for r in EXAMPLE_RAGGED_TENSOR_3D[1:]],
     [1, None]),
    (SLICE_BUILDER[:, 1, 1:],
     [r[1][1:] for r in EXAMPLE_RAGGED_TENSOR_3D],
     [2, None]),
    (SLICE_BUILDER[1:, 1, 1:],
     [r[1][1:] for r in EXAMPLE_RAGGED_TENSOR_3D[1:]],
     [1, None]),
    # Slicing uniform_row_splits dimension:
    (SLICE_BUILDER[:, 2:],
     [r[2:] for r in EXAMPLE_RAGGED_TENSOR_3D],
     [2, 1, None]),
    (SLICE_BUILDER[:, -2:],
     [r[-2:] for r in EXAMPLE_RAGGED_TENSOR_3D],
     [2, 2, None]),
    (SLICE_BUILDER[:, :, 1:],
     [[c[1:] for c in r] for r in EXAMPLE_RAGGED_TENSOR_3D],
     [2, 3, None]),
    (SLICE_BUILDER[:, 5:],
     [r[5:] for r in EXAMPLE_RAGGED_TENSOR_3D],
     [2, 0, None]),
    # Slicing uniform_row_splits dimension with a non-default step size:
    (SLICE_BUILDER[:, ::2],
     [r[::2] for r in EXAMPLE_RAGGED_TENSOR_3D],
     [2, 2, None]),
    (SLICE_BUILDER[:, ::-1],
     [r[::-1] for r in EXAMPLE_RAGGED_TENSOR_3D],
     [2, 3, None]),
)  # pyformat: disable
def testWithUniformRowLength(self, slice_spec, expected, expected_shape):
  """Tests `__getitem__` on a tensor built with `from_uniform_row_length`.

  Besides the value and static shape, a 3D result is also required to keep
  a `uniform_row_length` equal to `expected_shape[1]`.

  Args:
    slice_spec: The slice spec to apply to the ragged tensor.
    expected: The expected result, as a python list.
    expected_shape: The expected static shape of the result.
  """
  inner_rows = RaggedTensor.from_row_splits(EXAMPLE_RAGGED_TENSOR_3D_VALUES,
                                            EXAMPLE_RAGGED_TENSOR_3D_SPLITS)
  rt = RaggedTensor.from_uniform_row_length(inner_rows,
                                            EXAMPLE_RAGGED_TENSOR_3D_ROWLEN)
  self.assertAllEqual(rt, EXAMPLE_RAGGED_TENSOR_3D)
  self.assertIsNot(rt.uniform_row_length, None)
  self._TestGetItem(rt, slice_spec, expected, expected_shape)

  # Only a 3D result is expected to still carry a uniform row length.
  sliced = rt.__getitem__(slice_spec)
  if sliced.shape.rank != 3:
    return
  self.assertIsNot(sliced.uniform_row_length, None)
  self.assertAllEqual(sliced.uniform_row_length, expected_shape[1])
@parameterized.parameters(
    (SLICE_BUILDER[:, 3], errors.InvalidArgumentError, 'out of bounds'),
    (SLICE_BUILDER[:, -4], errors.InvalidArgumentError, 'out of bounds'),
    (SLICE_BUILDER[:, 10], errors.InvalidArgumentError, 'out of bounds'),
    (SLICE_BUILDER[:, -10], errors.InvalidArgumentError, 'out of bounds'),
)
def testErrorsWithUniformRowLength(self, slice_spec, expected, message):
  """Tests out-of-bounds indexing into the uniform-row-length dimension.

  Args:
    slice_spec: The invalid slice spec to apply to the ragged tensor.
    expected: The exception class that must be raised.
    message: Regular expression that the exception message must match.
  """
  inner_rows = RaggedTensor.from_row_splits(EXAMPLE_RAGGED_TENSOR_3D_VALUES,
                                            EXAMPLE_RAGGED_TENSOR_3D_SPLITS)
  ragged = RaggedTensor.from_uniform_row_length(
      inner_rows, EXAMPLE_RAGGED_TENSOR_3D_ROWLEN)
  # Confirm the fixture matches its nested-list form before indexing it.
  self.assertAllEqual(ragged, EXAMPLE_RAGGED_TENSOR_3D)
  self._TestGetItemException(ragged, slice_spec, expected, message)
# When the file is executed directly, run its tests via
# tensorflow.python.platform.googletest.
if __name__ == '__main__':
  googletest.main()