Update RaggedTensorSpec to restore uniform_row_lengths when decoding from the variant encoding (which is used when RaggedTensors are stored in Datasets).
PiperOrigin-RevId: 303877025
Change-Id: I7f1798531aaebfbd16d8c30f9c84b8d51dbde059
parent b4bdbbb065
commit 21539c733c
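To make the commit message concrete, here is a minimal sketch (not part of the commit) of the scenario it describes: a RaggedTensor with uniform row lengths being round-tripped through tf.data. It assumes a TF 2.x build that includes this change.

```python
import tensorflow as tf

# A RaggedTensor built from a dense tensor encodes its inner dimensions via
# uniform_row_length, so its static shape is fully defined.
rt = tf.RaggedTensor.from_tensor(tf.zeros([3, 4, 5]), ragged_rank=2)
print(rt.shape.is_fully_defined())        # True -- shape is (3, 4, 5)

# Storing the RaggedTensor in a Dataset encodes it as a variant tensor;
# decoding goes back through RaggedTensorSpec.
ds = tf.data.Dataset.from_tensors(rt)
element = next(iter(ds))

# Previously only the outer dimension and the flat_values dimensions were
# restored, so the decoded shape was partially unknown; with this change the
# uniform row lengths (and hence the full static shape) are restored.
print(element.shape.is_fully_defined())   # True
```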
tensorflow/python/ops/ragged
@@ -33,6 +33,7 @@ from tensorflow.python.framework import sparse_tensor
 from tensorflow.python.framework import tensor_like
 from tensorflow.python.framework import tensor_shape
 from tensorflow.python.framework import tensor_spec
+from tensorflow.python.framework import tensor_util
 from tensorflow.python.framework import type_spec
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import check_ops
@@ -1374,6 +1375,59 @@ class RaggedTensor(composite_tensor.CompositeTensor, tensor_like.TensorLike):
                        "inner_axis (%d)" % (outer_axis, inner_axis))
     return _merge_dims(self, outer_axis, inner_axis)
 
+  def _set_shape(self, shape):
+    """Updates the static shape of `self` to be `shape`.
+
+    * If a dimension of `shape` has a known size, and is encoded via
+      partitioning, then this will update the corresponding partition to
+      define `_uniform_row_length` and `nrows`.
+    * If a dimension of `shape` has a known size, and is encoded as one
+      of the `flat_values` dimensions, then `flat_values.set_shape()` will
+      be used to update its shape.
+
+    Warning: Using this method to assert an incorrect shape for a RaggedTensor
+    (i.e., one that's not consistent with its actual shape) can cause
+    segmentation faults and very difficult-to-diagnose behavior.  Only use this
+    method if you are certain that the shape is correct.
+
+    Args:
+      shape: `tf.TensorShape` specifying the shape for this `RaggedTensor`.
+    """
+    # TODO(edloper): Refactor this to not directly access private members
+    # of RowPartition.
+    # pylint: disable=protected-access
+
+    shape = tensor_shape.as_shape(shape)
+    if shape.rank is None:
+      return  # Nothing to do.
+
+    shape = shape.as_list()
+
+    # Outermost dimension
+    if shape[0] is not None:
+      self._row_partition._row_splits.set_shape(shape[0] + 1)
+
+    # Partitioned dimensions
+    dtype = self._row_partition.dtype
+    for i, partition in enumerate(self._nested_row_partitions):
+      size = shape[i + 1]
+      if size is not None:
+        if partition._uniform_row_length is not None:
+          old_row_length = tensor_util.constant_value(
+              partition._uniform_row_length)
+          if old_row_length is not None:
+            if size == old_row_length:
+              continue  # already have shape info for this axis.
+            else:
+              raise ValueError("Inconsistent size for axis %s: %s vs %s" %
+                               ((i + 1), old_row_length, size))
+        partition._uniform_row_length = ops.convert_to_tensor(size, dtype)
+        if partition._nrows is None:
+          partition._nrows = array_ops.size(partition._row_splits) - 1
+
+    # Inner dimensions
+    flat_shape = tensor_shape.as_shape([None] + shape[self.ragged_rank + 1:])
+    self.flat_values.set_shape(flat_shape)
+
   #=============================================================================
   # Tensor Type Conversions
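The `_set_shape` docstring above describes how partition-encoded and `flat_values`-encoded dimensions are handled. A minimal usage sketch (not part of the commit), mirroring the tests added later in this change; note that `_set_shape` is a private helper called by `RaggedTensorSpec` while decoding, not a public API:

```python
import tensorflow as tf

# One ragged dimension plus a uniform inner dimension: static shape (2, None, 3).
rt = tf.ragged.constant([[[1, 2, 3], [4, 5, 6]], [[7, 8, 9]]], ragged_rank=1)

# Asserting a shape that is consistent with the data is accepted and is
# propagated to the row partition and to flat_values.
rt._set_shape([2, None, 3])
rt.shape.assert_is_compatible_with([2, None, 3])

# Asserting an inconsistent uniform row length raises ValueError.
rt2 = tf.RaggedTensor.from_tensor(tf.zeros([2, 3, 1]), ragged_rank=1)
try:
  rt2._set_shape([None, 5, None])   # rows actually have uniform length 3
except ValueError as e:
  print(e)                          # e.g. "Inconsistent size for axis 1: 3 vs 5"
```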
@@ -2164,6 +2218,8 @@ class RaggedTensorSpec(type_spec.BatchableTypeSpec):
     return [tensor_spec.TensorSpec(None, dtypes.variant)]
 
   def _to_tensor_list(self, value):
+    # TODO(edloper): Update gen_ragged_conversion_ops that convert to and
+    # from variant to include all of the row-partitioning tensors.
     ragged_rank = value.ragged_rank if isinstance(value, RaggedTensor) else 0
     if ragged_rank != self._ragged_rank:
       raise ValueError("Ragged rank of value (%d) does not match ragged "
@@ -2202,9 +2258,7 @@ class RaggedTensorSpec(type_spec.BatchableTypeSpec):
       outer_dim = tensor_shape.dimension_value(self._shape[0])
       if outer_dim is not None:
         result.row_splits.set_shape([outer_dim + 1])
-      result.flat_values.set_shape(
-          tensor_shape.TensorShape([None]).concatenate(
-              self._shape[1 + self._ragged_rank:]))
+      result._set_shape(self._shape)  # pylint: disable=protected-access
     else:
       result.set_shape(self._shape)
     return result
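The hunk above replaces the old fix-up, which only restored the outer dimension and the `flat_values` dimensions, with a call to `_set_shape`, so dimensions encoded in the row partitions (uniform row lengths) are also restored when decoding from the variant encoding. A condensed sketch of that decode path, essentially what the new `testToFromBatchedTensorListPreservesUniformRowLengths` below checks (private APIs; assumes this change is applied):

```python
import tensorflow as tf

rt = tf.RaggedTensor.from_tensor(tf.zeros([3, 4, 5]), ragged_rank=2)
spec = rt._type_spec                        # RaggedTensorSpec with shape [3, 4, 5]

tensors = spec._to_batched_tensor_list(rt)  # encode as a variant tensor list
decoded = spec._from_tensor_list(tensors)   # decode; now calls result._set_shape(...)

print(decoded.shape.as_list())              # [3, 4, 5] -- uniform row lengths restored
```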
@@ -41,6 +41,7 @@ from tensorflow.python.ops.ragged.ragged_tensor import RaggedTensorSpec
 from tensorflow.python.ops.ragged.row_partition import RowPartition
 
 from tensorflow.python.platform import googletest
+from tensorflow.python.util import nest
 
 
 def int32array(values):
@@ -1483,6 +1484,73 @@ class RaggedTensorTest(test_util.TensorFlowTestCase, parameterized.TestCase):
     with self.assertRaisesRegexp(ValueError, 'only supported in eager mode'):
       rt.numpy()
 
+  @parameterized.parameters([
+      ([[[1, 2], [3, 4, 5]], [[6]]], 2, None),
+      ([[[1, 2], [3, 4, 5]], [[6]]], 2, [None, None, None]),
+      ([[[1, 2], [3, 4, 5]], [[6]]], 2, [2, None, None]),
+      ([[[1, 2, 3], [4, 5, 6]], [[7, 8, 9]]], 1, None),
+      ([[[1, 2, 3], [4, 5, 6]], [[7, 8, 9]]], 1, [None, None, None]),
+      ([[[1, 2, 3], [4, 5, 6]], [[7, 8, 9]]], 1, [2, None, None]),
+      ([[[1, 2, 3], [4, 5, 6]], [[7, 8, 9]]], 1, [2, None, 3]),
+      ([[[1, 2, 3]]], 1, [1, 1, None]),
+      ([[[1, 2, 3]]], 1, [1, 1, 3]),
+  ])
+  def testRaggedTensorSetShape(self, rt, rt_ragged_rank, shape):
+    rt1 = ragged_factory_ops.constant(rt, ragged_rank=rt_ragged_rank)
+    rt1._set_shape(shape)
+    rt1.shape.assert_is_compatible_with(shape)
+    if shape is not None:
+      self.assertIsNot(rt1.shape.rank, None)
+      for a, b in zip(rt1.shape, shape):
+        if b is not None:
+          self.assertEqual(a, b)
+
+  @parameterized.parameters([
+      ([[[1, 2], [3, 4, 5]], [[6]]], 2, None),
+      ([[[1, 2], [3, 4, 5]], [[6]]], 2, [None, None, None]),
+      ([[[1, 2], [3, 4, 5]], [[6]]], 2, [2, None, None]),
+      ([[[1, 2, 3], [4, 5, 6]], [[7, 8, 9]]], 1, None),
+      ([[[1, 2, 3], [4, 5, 6]], [[7, 8, 9]]], 1, [None, None, None]),
+      ([[[1, 2, 3], [4, 5, 6]], [[7, 8, 9]]], 1, [2, None, None]),
+      ([[[1, 2, 3], [4, 5, 6]], [[7, 8, 9]]], 1, [2, None, 3]),
+      ([[[1, 2, 3]]], 1, [1, 1, None]),
+      ([[[1, 2, 3]]], 1, [1, 1, 3]),
+  ])
+  def testRaggedTensorSetShapeWithPlaceholders(self, rt, rt_ragged_rank, shape):
+    rt2 = nest.map_structure(
+        lambda x: array_ops.placeholder_with_default(x, None),
+        ragged_factory_ops.constant(rt, ragged_rank=rt_ragged_rank),
+        expand_composites=True)
+    rt2._set_shape(shape)
+    rt2.shape.assert_is_compatible_with(shape)
+    if shape is not None:
+      self.assertIsNot(rt2.shape.rank, None)
+      for a, b in zip(rt2.shape, shape):
+        if b is not None:
+          self.assertEqual(a, b)
+
+  def testRaggedTensorSetShapeUniformRowLength(self):
+    rt = [[[1], [2], [3]], [[4], [5], [6]]]
+
+    rt1 = RaggedTensor.from_tensor(rt, ragged_rank=1)
+    rt1._set_shape([2, 3, 1])
+
+    rt2 = nest.map_structure(
+        lambda x: array_ops.placeholder_with_default(x, None),
+        rt1, expand_composites=True)
+    rt2._set_shape([2, 3, 1])
+
+  def testRaggedTensorSetShapeInconsistentShapeError(self):
+    rt = RaggedTensor.from_tensor([[[1], [2], [3]], [[4], [5], [6]]],
+                                  ragged_rank=1)
+    self.assertEqual(rt.shape.as_list(), [2, 3, 1])
+    with self.assertRaises(ValueError):
+      rt._set_shape([None, None, 5])
+    with self.assertRaisesRegex(ValueError, 'Inconsistent size'):
+      rt._set_shape([None, 5, None])
+    with self.assertRaises(ValueError):
+      rt._set_shape([5, None, None])
+
 
 @test_util.run_all_in_graph_and_eager_modes
 class RaggedTensorSpecTest(test_util.TensorFlowTestCase,
@@ -1665,6 +1733,17 @@ class RaggedTensorSpecTest(test_util.TensorFlowTestCase,
                         [t[0] for t in tensor_list])
     self.assertAllEqual(rt[0], first_row)
 
+  def testToFromBatchedTensorListPreservesUniformRowLengths(self):
+    rt = RaggedTensor.from_tensor(array_ops.zeros([3, 4, 5]),
+                                  ragged_rank=2)
+    rt_spec = rt._type_spec
+    tensor_list = rt_spec._to_batched_tensor_list(rt)
+    rt_reconstructed = rt_spec._from_tensor_list(tensor_list)
+    self.assertAllEqual(rt, rt_reconstructed)
+    self.assertTrue(rt.shape.is_fully_defined())
+    self.assertTrue(rt_reconstructed.shape.is_fully_defined())
+    self.assertEqual(rt.shape.as_list(), rt_reconstructed.shape.as_list())
+
   @parameterized.parameters([
       (RaggedTensorSpec([2, None], dtypes.float32, 1), 32,
        RaggedTensorSpec([32, 2, None], dtypes.float32, 2)),