In RaggedTensorDynamicShape.from_tensor: Check for uniform_row_lengths.
PiperOrigin-RevId: 348138366 Change-Id: Ia0bc08f8fe00ec9714041184f71f4cee7106c31f
This commit is contained in:
parent
e7f1de9661
commit
4ad325ebf9
@ -50,7 +50,9 @@ class RaggedTensorDynamicShape(object):
|
|||||||
Furthermore, there are two ways a dimension might be encoded:
|
Furthermore, there are two ways a dimension might be encoded:
|
||||||
|
|
||||||
* "Partitioned dimensions" are dimensions that are encoded using a
|
* "Partitioned dimensions" are dimensions that are encoded using a
|
||||||
`RowPartition`. The outermostmost partitioned dimension must be uniform.
|
`RaggedTensor`'s `nested_row_splits`. The outermostmost partitioned
|
||||||
|
dimension must be uniform, and the innermost partitioned dimension must
|
||||||
|
be ragged.
|
||||||
|
|
||||||
* "Inner dimensions" are dimensions that are encoded using a
|
* "Inner dimensions" are dimensions that are encoded using a
|
||||||
`RaggedTensor`'s `flat_values`. Inner dimensions are always uniform.
|
`RaggedTensor`'s `flat_values`. Inner dimensions are always uniform.
|
||||||
@ -118,6 +120,8 @@ class RaggedTensorDynamicShape(object):
|
|||||||
dimension_size.shape.with_rank_at_most(1)
|
dimension_size.shape.with_rank_at_most(1)
|
||||||
if partitioned_dim_sizes[0].shape.ndims == 1:
|
if partitioned_dim_sizes[0].shape.ndims == 1:
|
||||||
raise ValueError('outermost partitioned dimension must be uniform')
|
raise ValueError('outermost partitioned dimension must be uniform')
|
||||||
|
if partitioned_dim_sizes[-1].shape.ndims == 0:
|
||||||
|
raise ValueError('innermost partitioned dimension must be ragged')
|
||||||
inner_dim_sizes.shape.assert_has_rank(1)
|
inner_dim_sizes.shape.assert_has_rank(1)
|
||||||
|
|
||||||
# Convert dimension size tensors to a single dtype.
|
# Convert dimension size tensors to a single dtype.
|
||||||
@ -181,17 +185,10 @@ class RaggedTensorDynamicShape(object):
|
|||||||
if not ragged_tensor.is_ragged(rt_input):
|
if not ragged_tensor.is_ragged(rt_input):
|
||||||
return cls([], array_ops.shape(rt_input))
|
return cls([], array_ops.shape(rt_input))
|
||||||
else:
|
else:
|
||||||
partitioned_dim_sizes = [rt_input.nrows()]
|
partitioned_dim_sizes = (
|
||||||
rt = rt_input
|
(rt_input.nrows(),) + rt_input.nested_row_lengths())
|
||||||
while ragged_tensor.is_ragged(rt):
|
|
||||||
if rt.uniform_row_length is None:
|
|
||||||
partitioned_dim_sizes.append(rt.row_lengths())
|
|
||||||
else:
|
|
||||||
partitioned_dim_sizes.append(rt.uniform_row_length)
|
|
||||||
rt = rt.values
|
|
||||||
|
|
||||||
return RaggedTensorDynamicShape(
|
return RaggedTensorDynamicShape(
|
||||||
tuple(partitioned_dim_sizes),
|
partitioned_dim_sizes,
|
||||||
array_ops.shape(rt_input.flat_values)[1:],
|
array_ops.shape(rt_input.flat_values)[1:],
|
||||||
dim_size_dtype=dim_size_dtype)
|
dim_size_dtype=dim_size_dtype)
|
||||||
|
|
||||||
|
@ -30,9 +30,6 @@ from tensorflow.python.ops.ragged.ragged_tensor_shape import RaggedTensorDynamic
|
|||||||
from tensorflow.python.platform import googletest
|
from tensorflow.python.platform import googletest
|
||||||
|
|
||||||
|
|
||||||
# pylint: disable=g-long-lambda
|
|
||||||
|
|
||||||
|
|
||||||
@test_util.run_all_in_graph_and_eager_modes
|
@test_util.run_all_in_graph_and_eager_modes
|
||||||
class RaggedTensorShapeTest(test_util.TensorFlowTestCase,
|
class RaggedTensorShapeTest(test_util.TensorFlowTestCase,
|
||||||
parameterized.TestCase):
|
parameterized.TestCase):
|
||||||
@ -83,15 +80,8 @@ class RaggedTensorShapeTest(test_util.TensorFlowTestCase,
|
|||||||
dict(
|
dict(
|
||||||
value=ragged_factory_ops.constant_value([[[1, 2], [3]], [[4, 5]]]),
|
value=ragged_factory_ops.constant_value([[[1, 2], [3]], [[4, 5]]]),
|
||||||
expected_dim_sizes=[2, [2, 1], [2, 1, 2]]),
|
expected_dim_sizes=[2, [2, 1], [2, 1, 2]]),
|
||||||
dict(
|
|
||||||
value=lambda: ragged_tensor.RaggedTensor.from_uniform_row_length(
|
|
||||||
ragged_factory_ops.constant([[1, 2], [3, 4, 5], [], [6]]),
|
|
||||||
uniform_row_length=2),
|
|
||||||
expected_dim_sizes=[2, 2, [2, 3, 0, 1]]),
|
|
||||||
])
|
])
|
||||||
def testFromTensor(self, value, expected_dim_sizes):
|
def testFromTensor(self, value, expected_dim_sizes):
|
||||||
if callable(value):
|
|
||||||
value = value()
|
|
||||||
shape = RaggedTensorDynamicShape.from_tensor(value)
|
shape = RaggedTensorDynamicShape.from_tensor(value)
|
||||||
expected = RaggedTensorDynamicShape.from_dim_sizes(expected_dim_sizes)
|
expected = RaggedTensorDynamicShape.from_dim_sizes(expected_dim_sizes)
|
||||||
self.assertShapeEq(shape, expected)
|
self.assertShapeEq(shape, expected)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user