Update RaggedTensorSpec to restore uniform_row_lengths when decoding from the variant encoding (which is used when RaggedTensors are stored in Datasets).

PiperOrigin-RevId: 303877025
Change-Id: I7f1798531aaebfbd16d8c30f9c84b8d51dbde059
This commit is contained in:
Edward Loper 2020-03-30 20:01:52 -07:00 committed by TensorFlower Gardener
parent b4bdbbb065
commit 21539c733c
2 changed files with 136 additions and 3 deletions
tensorflow/python/ops/ragged

View File

@ -33,6 +33,7 @@ from tensorflow.python.framework import sparse_tensor
from tensorflow.python.framework import tensor_like
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_spec
from tensorflow.python.framework import tensor_util
from tensorflow.python.framework import type_spec
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import check_ops
@ -1374,6 +1375,59 @@ class RaggedTensor(composite_tensor.CompositeTensor, tensor_like.TensorLike):
"inner_axis (%d)" % (outer_axis, inner_axis))
return _merge_dims(self, outer_axis, inner_axis)
def _set_shape(self, shape):
  """Updates the static shape of `self` to be `shape`.

  * If a dimension of `shape` has a known size, and is encoded via
    partitioning, then this will update the corresponding partition to
    define `_uniform_row_length` and `nrows`.
  * If a dimension of `shape` has a known size, and is encoded as one
    of the `flat_values` dimensions, then `flat_values.set_shape()` will
    be used to update its shape.

  Warning: Using this method to assert an incorrect shape for a RaggedTensor
  (i.e., one that's not consistent with its actual shape) can cause
  segmentation faults and very difficult-to-diagnose behavior.  Only use this
  method if you are certain that the shape is correct.

  Args:
    shape: `tf.TensorShape` specifying the shape for this `RaggedTensor`.

  Raises:
    ValueError: If `shape` gives a size for a partitioned dimension that
      differs from the statically known value of that partition's existing
      `_uniform_row_length`.
  """
  # TODO(edloper): Refactor this to not directly access private members
  # of RowPartition.
  # pylint: disable=protected-access

  shape = tensor_shape.as_shape(shape)
  if shape.rank is None:
    return  # Nothing to do.

  shape = shape.as_list()

  # Outermost dimension: `row_splits` has one more element than there are
  # rows.
  if shape[0] is not None:
    self._row_partition._row_splits.set_shape(shape[0] + 1)

  # Partitioned dimensions: partition `i` encodes axis `i + 1`.
  dtype = self._row_partition.dtype
  for i, partition in enumerate(self._nested_row_partitions):
    size = shape[i + 1]
    if size is None:
      continue  # No static information to record for this axis.
    if partition._uniform_row_length is not None:
      old_row_length = tensor_util.constant_value(
          partition._uniform_row_length)
      if old_row_length is not None:
        if size != old_row_length:
          raise ValueError("Inconsistent size for axis %s: %s vs %s" %
                           ((i + 1), old_row_length, size))
        continue  # already have shape info for this axis.
    partition._uniform_row_length = ops.convert_to_tensor(size, dtype)
    if partition._nrows is None:
      # Request the partition's dtype explicitly: `array_ops.size` defaults
      # to int32, which would leave `_nrows` with a different dtype than
      # `_uniform_row_length` when the partition uses int64.
      partition._nrows = array_ops.size(
          partition._row_splits, out_type=dtype) - 1

  # Inner dimensions: the leading `flat_values` dimension is left unknown.
  flat_shape = tensor_shape.as_shape([None] + shape[self.ragged_rank + 1:])
  self.flat_values.set_shape(flat_shape)
#=============================================================================
# Tensor Type Conversions
@ -2164,6 +2218,8 @@ class RaggedTensorSpec(type_spec.BatchableTypeSpec):
return [tensor_spec.TensorSpec(None, dtypes.variant)]
def _to_tensor_list(self, value):
# TODO(edloper): Update gen_ragged_conversion_ops that convert to and
# from variant to include all of the row-partitioning tensors.
ragged_rank = value.ragged_rank if isinstance(value, RaggedTensor) else 0
if ragged_rank != self._ragged_rank:
raise ValueError("Ragged rank of value (%d) does not match ragged "
@ -2202,9 +2258,7 @@ class RaggedTensorSpec(type_spec.BatchableTypeSpec):
outer_dim = tensor_shape.dimension_value(self._shape[0])
if outer_dim is not None:
result.row_splits.set_shape([outer_dim + 1])
result.flat_values.set_shape(
tensor_shape.TensorShape([None]).concatenate(
self._shape[1 + self._ragged_rank:]))
result._set_shape(self._shape) # pylint: disable=protected-access
else:
result.set_shape(self._shape)
return result

View File

@ -41,6 +41,7 @@ from tensorflow.python.ops.ragged.ragged_tensor import RaggedTensorSpec
from tensorflow.python.ops.ragged.row_partition import RowPartition
from tensorflow.python.platform import googletest
from tensorflow.python.util import nest
def int32array(values):
@ -1483,6 +1484,73 @@ class RaggedTensorTest(test_util.TensorFlowTestCase, parameterized.TestCase):
with self.assertRaisesRegexp(ValueError, 'only supported in eager mode'):
rt.numpy()
@parameterized.parameters([
    ([[[1, 2], [3, 4, 5]], [[6]]], 2, None),
    ([[[1, 2], [3, 4, 5]], [[6]]], 2, [None, None, None]),
    ([[[1, 2], [3, 4, 5]], [[6]]], 2, [2, None, None]),
    ([[[1, 2, 3], [4, 5, 6]], [[7, 8, 9]]], 1, None),
    ([[[1, 2, 3], [4, 5, 6]], [[7, 8, 9]]], 1, [None, None, None]),
    ([[[1, 2, 3], [4, 5, 6]], [[7, 8, 9]]], 1, [2, None, None]),
    ([[[1, 2, 3], [4, 5, 6]], [[7, 8, 9]]], 1, [2, None, 3]),
    ([[[1, 2, 3]]], 1, [1, 1, None]),
    ([[[1, 2, 3]]], 1, [1, 1, 3]),
])
def testRaggedTensorSetShape(self, rt, rt_ragged_rank, shape):
  """Checks that `_set_shape` makes the static shape agree with `shape`."""
  ragged = ragged_factory_ops.constant(rt, ragged_rank=rt_ragged_rank)
  ragged._set_shape(shape)
  ragged.shape.assert_is_compatible_with(shape)
  if shape is None:
    return  # An unspecified shape leaves nothing further to verify.
  self.assertIsNot(ragged.shape.rank, None)
  # Each dimension that `shape` pins down must match the static shape.
  for actual_dim, expected_dim in zip(ragged.shape, shape):
    if expected_dim is None:
      continue
    self.assertEqual(actual_dim, expected_dim)
@parameterized.parameters([
    ([[[1, 2], [3, 4, 5]], [[6]]], 2, None),
    ([[[1, 2], [3, 4, 5]], [[6]]], 2, [None, None, None]),
    ([[[1, 2], [3, 4, 5]], [[6]]], 2, [2, None, None]),
    ([[[1, 2, 3], [4, 5, 6]], [[7, 8, 9]]], 1, None),
    ([[[1, 2, 3], [4, 5, 6]], [[7, 8, 9]]], 1, [None, None, None]),
    ([[[1, 2, 3], [4, 5, 6]], [[7, 8, 9]]], 1, [2, None, None]),
    ([[[1, 2, 3], [4, 5, 6]], [[7, 8, 9]]], 1, [2, None, 3]),
    ([[[1, 2, 3]]], 1, [1, 1, None]),
    ([[[1, 2, 3]]], 1, [1, 1, 3]),
])
def testRaggedTensorSetShapeWithPlaceholders(self, rt, rt_ragged_rank, shape):
  """Checks `_set_shape` on a RaggedTensor built from placeholders."""

  def _as_placeholder(component):
    # Wrap each component tensor in `placeholder_with_default`, passing
    # `None` as the shape argument.
    return array_ops.placeholder_with_default(component, None)

  constant_rt = ragged_factory_ops.constant(rt, ragged_rank=rt_ragged_rank)
  ragged = nest.map_structure(
      _as_placeholder, constant_rt, expand_composites=True)
  ragged._set_shape(shape)
  ragged.shape.assert_is_compatible_with(shape)
  if shape is None:
    return  # An unspecified shape leaves nothing further to verify.
  self.assertIsNot(ragged.shape.rank, None)
  # Each dimension that `shape` pins down must match the static shape.
  for actual_dim, expected_dim in zip(ragged.shape, shape):
    if expected_dim is None:
      continue
    self.assertEqual(actual_dim, expected_dim)
def testRaggedTensorSetShapeUniformRowLength(self):
  """Calls `_set_shape` with the full shape on tensors made by `from_tensor`.

  The same shape is applied first to the RaggedTensor built directly from a
  nested list, and then to a copy whose components have each been wrapped in
  `placeholder_with_default`. The test passes if neither call raises.
  """
  full_shape = [2, 3, 1]
  values = [[[1], [2], [3]], [[4], [5], [6]]]

  direct = RaggedTensor.from_tensor(values, ragged_rank=1)
  direct._set_shape(full_shape)

  def _as_placeholder(component):
    return array_ops.placeholder_with_default(component, None)

  wrapped = nest.map_structure(
      _as_placeholder, direct, expand_composites=True)
  wrapped._set_shape(full_shape)
def testRaggedTensorSetShapeInconsistentShapeError(self):
  """Checks that `_set_shape` raises for sizes contradicting the shape."""
  values = [[[1], [2], [3]], [[4], [5], [6]]]
  ragged = RaggedTensor.from_tensor(values, ragged_rank=1)
  self.assertEqual(ragged.shape.as_list(), [2, 3, 1])

  # Wrong size for the innermost (flat_values) dimension.
  with self.assertRaises(ValueError):
    ragged._set_shape([None, None, 5])

  # Wrong size for the partitioned dimension.
  with self.assertRaisesRegex(ValueError, 'Inconsistent size'):
    ragged._set_shape([None, 5, None])

  # Wrong size for the outermost dimension.
  with self.assertRaises(ValueError):
    ragged._set_shape([5, None, None])
@test_util.run_all_in_graph_and_eager_modes
class RaggedTensorSpecTest(test_util.TensorFlowTestCase,
@ -1665,6 +1733,17 @@ class RaggedTensorSpecTest(test_util.TensorFlowTestCase,
[t[0] for t in tensor_list])
self.assertAllEqual(rt[0], first_row)
def testToFromBatchedTensorListPreservesUniformRowLengths(self):
  """Checks the static shape survives a tensor-list round trip via the spec."""
  original = RaggedTensor.from_tensor(
      array_ops.zeros([3, 4, 5]), ragged_rank=2)
  spec = original._type_spec

  encoded = spec._to_batched_tensor_list(original)
  restored = spec._from_tensor_list(encoded)

  self.assertAllEqual(original, restored)
  # Both the original and the decoded tensor must have fully defined,
  # identical static shapes.
  self.assertTrue(original.shape.is_fully_defined())
  self.assertTrue(restored.shape.is_fully_defined())
  self.assertEqual(original.shape.as_list(), restored.shape.as_list())
@parameterized.parameters([
(RaggedTensorSpec([2, None], dtypes.float32, 1), 32,
RaggedTensorSpec([32, 2, None], dtypes.float32, 2)),