In RaggedTensorDynamicShape.from_tensor: Check for uniform_row_lengths.
PiperOrigin-RevId: 348138366 Change-Id: Ia0bc08f8fe00ec9714041184f71f4cee7106c31f
This commit is contained in:
parent
e7f1de9661
commit
4ad325ebf9
tensorflow/python/ops/ragged
@ -50,7 +50,9 @@ class RaggedTensorDynamicShape(object):
|
||||
Furthermore, there are two ways a dimension might be encoded:
|
||||
|
||||
* "Partitioned dimensions" are dimensions that are encoded using a
|
||||
`RowPartition`. The outermostmost partitioned dimension must be uniform.
|
||||
`RaggedTensor`'s `nested_row_splits`. The outermostmost partitioned
|
||||
dimension must be uniform, and the innermost partitioned dimension must
|
||||
be ragged.
|
||||
|
||||
* "Inner dimensions" are dimensions that are encoded using a
|
||||
`RaggedTensor`'s `flat_values`. Inner dimensions are always uniform.
|
||||
@ -118,6 +120,8 @@ class RaggedTensorDynamicShape(object):
|
||||
dimension_size.shape.with_rank_at_most(1)
|
||||
if partitioned_dim_sizes[0].shape.ndims == 1:
|
||||
raise ValueError('outermost partitioned dimension must be uniform')
|
||||
if partitioned_dim_sizes[-1].shape.ndims == 0:
|
||||
raise ValueError('innermost partitioned dimension must be ragged')
|
||||
inner_dim_sizes.shape.assert_has_rank(1)
|
||||
|
||||
# Convert dimension size tensors to a single dtype.
|
||||
@ -181,17 +185,10 @@ class RaggedTensorDynamicShape(object):
|
||||
if not ragged_tensor.is_ragged(rt_input):
|
||||
return cls([], array_ops.shape(rt_input))
|
||||
else:
|
||||
partitioned_dim_sizes = [rt_input.nrows()]
|
||||
rt = rt_input
|
||||
while ragged_tensor.is_ragged(rt):
|
||||
if rt.uniform_row_length is None:
|
||||
partitioned_dim_sizes.append(rt.row_lengths())
|
||||
else:
|
||||
partitioned_dim_sizes.append(rt.uniform_row_length)
|
||||
rt = rt.values
|
||||
|
||||
partitioned_dim_sizes = (
|
||||
(rt_input.nrows(),) + rt_input.nested_row_lengths())
|
||||
return RaggedTensorDynamicShape(
|
||||
tuple(partitioned_dim_sizes),
|
||||
partitioned_dim_sizes,
|
||||
array_ops.shape(rt_input.flat_values)[1:],
|
||||
dim_size_dtype=dim_size_dtype)
|
||||
|
||||
|
@ -30,9 +30,6 @@ from tensorflow.python.ops.ragged.ragged_tensor_shape import RaggedTensorDynamic
|
||||
from tensorflow.python.platform import googletest
|
||||
|
||||
|
||||
# pylint: disable=g-long-lambda
|
||||
|
||||
|
||||
@test_util.run_all_in_graph_and_eager_modes
|
||||
class RaggedTensorShapeTest(test_util.TensorFlowTestCase,
|
||||
parameterized.TestCase):
|
||||
@ -83,15 +80,8 @@ class RaggedTensorShapeTest(test_util.TensorFlowTestCase,
|
||||
dict(
|
||||
value=ragged_factory_ops.constant_value([[[1, 2], [3]], [[4, 5]]]),
|
||||
expected_dim_sizes=[2, [2, 1], [2, 1, 2]]),
|
||||
dict(
|
||||
value=lambda: ragged_tensor.RaggedTensor.from_uniform_row_length(
|
||||
ragged_factory_ops.constant([[1, 2], [3, 4, 5], [], [6]]),
|
||||
uniform_row_length=2),
|
||||
expected_dim_sizes=[2, 2, [2, 3, 0, 1]]),
|
||||
])
|
||||
def testFromTensor(self, value, expected_dim_sizes):
|
||||
if callable(value):
|
||||
value = value()
|
||||
shape = RaggedTensorDynamicShape.from_tensor(value)
|
||||
expected = RaggedTensorDynamicShape.from_dim_sizes(expected_dim_sizes)
|
||||
self.assertShapeEq(shape, expected)
|
||||
|
Loading…
Reference in New Issue
Block a user