In RaggedTensorDynamicShape.from_tensor: Check for uniform_row_lengths.

PiperOrigin-RevId: 348138366
Change-Id: Ia0bc08f8fe00ec9714041184f71f4cee7106c31f
This commit is contained in:
A. Unique TensorFlower 2020-12-17 20:01:17 -08:00 committed by TensorFlower Gardener
parent e7f1de9661
commit 4ad325ebf9
2 changed files with 8 additions and 21 deletions

View File

@ -50,7 +50,9 @@ class RaggedTensorDynamicShape(object):
Furthermore, there are two ways a dimension might be encoded:
* "Partitioned dimensions" are dimensions that are encoded using a
`RowPartition`. The outermostmost partitioned dimension must be uniform.
`RaggedTensor`'s `nested_row_splits`. The outermostmost partitioned
dimension must be uniform, and the innermost partitioned dimension must
be ragged.
* "Inner dimensions" are dimensions that are encoded using a
`RaggedTensor`'s `flat_values`. Inner dimensions are always uniform.
@ -118,6 +120,8 @@ class RaggedTensorDynamicShape(object):
dimension_size.shape.with_rank_at_most(1)
if partitioned_dim_sizes[0].shape.ndims == 1:
raise ValueError('outermost partitioned dimension must be uniform')
if partitioned_dim_sizes[-1].shape.ndims == 0:
raise ValueError('innermost partitioned dimension must be ragged')
inner_dim_sizes.shape.assert_has_rank(1)
# Convert dimension size tensors to a single dtype.
@ -181,17 +185,10 @@ class RaggedTensorDynamicShape(object):
if not ragged_tensor.is_ragged(rt_input):
return cls([], array_ops.shape(rt_input))
else:
partitioned_dim_sizes = [rt_input.nrows()]
rt = rt_input
while ragged_tensor.is_ragged(rt):
if rt.uniform_row_length is None:
partitioned_dim_sizes.append(rt.row_lengths())
else:
partitioned_dim_sizes.append(rt.uniform_row_length)
rt = rt.values
partitioned_dim_sizes = (
(rt_input.nrows(),) + rt_input.nested_row_lengths())
return RaggedTensorDynamicShape(
tuple(partitioned_dim_sizes),
partitioned_dim_sizes,
array_ops.shape(rt_input.flat_values)[1:],
dim_size_dtype=dim_size_dtype)

View File

@ -30,9 +30,6 @@ from tensorflow.python.ops.ragged.ragged_tensor_shape import RaggedTensorDynamic
from tensorflow.python.platform import googletest
# pylint: disable=g-long-lambda
@test_util.run_all_in_graph_and_eager_modes
class RaggedTensorShapeTest(test_util.TensorFlowTestCase,
parameterized.TestCase):
@ -83,15 +80,8 @@ class RaggedTensorShapeTest(test_util.TensorFlowTestCase,
dict(
value=ragged_factory_ops.constant_value([[[1, 2], [3]], [[4, 5]]]),
expected_dim_sizes=[2, [2, 1], [2, 1, 2]]),
dict(
value=lambda: ragged_tensor.RaggedTensor.from_uniform_row_length(
ragged_factory_ops.constant([[1, 2], [3, 4, 5], [], [6]]),
uniform_row_length=2),
expected_dim_sizes=[2, 2, [2, 3, 0, 1]]),
])
def testFromTensor(self, value, expected_dim_sizes):
if callable(value):
value = value()
shape = RaggedTensorDynamicShape.from_tensor(value)
expected = RaggedTensorDynamicShape.from_dim_sizes(expected_dim_sizes)
self.assertShapeEq(shape, expected)