From 21539c733c5d028cd24dc620a24a4cba4fd2cbbc Mon Sep 17 00:00:00 2001
From: Edward Loper
Date: Mon, 30 Mar 2020 20:01:52 -0700
Subject: [PATCH] Update RaggedTensorSpec to restore uniform_row_lengths when
 decoding from the variant encoding (which is used when RaggedTensors are
 stored in Datasets).

PiperOrigin-RevId: 303877025
Change-Id: I7f1798531aaebfbd16d8c30f9c84b8d51dbde059
---
 tensorflow/python/ops/ragged/ragged_tensor.py | 60 +++++++++++++-
 .../python/ops/ragged/ragged_tensor_test.py   | 79 +++++++++++++++++++
 2 files changed, 136 insertions(+), 3 deletions(-)

diff --git a/tensorflow/python/ops/ragged/ragged_tensor.py b/tensorflow/python/ops/ragged/ragged_tensor.py
index 28c4c2db093..afb631ed0f2 100644
--- a/tensorflow/python/ops/ragged/ragged_tensor.py
+++ b/tensorflow/python/ops/ragged/ragged_tensor.py
@@ -33,6 +33,7 @@ from tensorflow.python.framework import sparse_tensor
 from tensorflow.python.framework import tensor_like
 from tensorflow.python.framework import tensor_shape
 from tensorflow.python.framework import tensor_spec
+from tensorflow.python.framework import tensor_util
 from tensorflow.python.framework import type_spec
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import check_ops
@@ -1374,6 +1375,59 @@ class RaggedTensor(composite_tensor.CompositeTensor, tensor_like.TensorLike):
           "inner_axis (%d)" % (outer_axis, inner_axis))
     return _merge_dims(self, outer_axis, inner_axis)
 
+  def _set_shape(self, shape):
+    """Updates the static shape of `self` to be `shape`.
+
+    * If a dimension of `shape` has a known size, and is encoded via
+      partitioning, then this will update the corresponding partition to
+      define `_uniform_row_length` and `nrows`.
+    * If a dimension of `shape` has a known size, and is encoded as one
+      of the `flat_values` dimensions, then `flat_values.set_shape()` will
+      be used to update its shape.
+
+    Warning: Using this method to assert an incorrect shape for a RaggedTensor
+    (i.e., one that's not consistent with its actual shape) can cause
+    segmentation faults and very difficult-to-diagnose behavior. Only use this
+    method if you are certain that the shape is correct.
+
+    Args:
+      shape: `tf.TensorShape` specifying the shape for this `RaggedTensor`.
+    """
+    # TODO(edloper): Refactor this to not directly access private members
+    # of RowPartition.
+    # pylint: disable=protected-access
+
+    shape = tensor_shape.as_shape(shape)
+    if shape.rank is None:
+      return  # Nothing to do.
+
+    shape = shape.as_list()
+
+    # Outermost dimension
+    if shape[0] is not None:
+      self._row_partition._row_splits.set_shape(shape[0] + 1)
+
+    # Partitioned dimensions
+    dtype = self._row_partition.dtype
+    for i, partition in enumerate(self._nested_row_partitions):
+      size = shape[i + 1]
+      if size is not None:
+        if partition._uniform_row_length is not None:
+          old_row_length = tensor_util.constant_value(
+              partition._uniform_row_length)
+          if old_row_length is not None:
+            if size == old_row_length:
+              continue  # already have shape info for this axis.
+            else:
+              raise ValueError("Inconsistent size for axis %s: %s vs %s" %
+                               ((i + 1), old_row_length, size))
+        partition._uniform_row_length = ops.convert_to_tensor(size, dtype)
+        if partition._nrows is None:
+          partition._nrows = array_ops.size(partition._row_splits) - 1
+
+    # Inner dimensions
+    flat_shape = tensor_shape.as_shape([None] + shape[self.ragged_rank + 1:])
+    self.flat_values.set_shape(flat_shape)
 
   #=============================================================================
   # Tensor Type Conversions
@@ -2164,6 +2218,8 @@
     return [tensor_spec.TensorSpec(None, dtypes.variant)]
 
   def _to_tensor_list(self, value):
+    # TODO(edloper): Update gen_ragged_conversion_ops that convert to and
+    # from variant to include all of the row-partitioning tensors.
     ragged_rank = value.ragged_rank if isinstance(value, RaggedTensor) else 0
     if ragged_rank != self._ragged_rank:
       raise ValueError("Ragged rank of value (%d) does not match ragged "
@@ -2202,9 +2258,7 @@
       outer_dim = tensor_shape.dimension_value(self._shape[0])
       if outer_dim is not None:
         result.row_splits.set_shape([outer_dim + 1])
-      result.flat_values.set_shape(
-          tensor_shape.TensorShape([None]).concatenate(
-              self._shape[1 + self._ragged_rank:]))
+      result._set_shape(self._shape)  # pylint: disable=protected-access
     else:
       result.set_shape(self._shape)
     return result
diff --git a/tensorflow/python/ops/ragged/ragged_tensor_test.py b/tensorflow/python/ops/ragged/ragged_tensor_test.py
index 20c21bd5947..5b6521b5aa5 100644
--- a/tensorflow/python/ops/ragged/ragged_tensor_test.py
+++ b/tensorflow/python/ops/ragged/ragged_tensor_test.py
@@ -41,6 +41,7 @@ from tensorflow.python.ops.ragged.ragged_tensor import RaggedTensorSpec
 from tensorflow.python.ops.ragged.row_partition import RowPartition
 from tensorflow.python.platform import googletest
+from tensorflow.python.util import nest
 
 
 def int32array(values):
@@ -1483,6 +1484,73 @@ class RaggedTensorTest(test_util.TensorFlowTestCase, parameterized.TestCase):
     with self.assertRaisesRegexp(ValueError, 'only supported in eager mode'):
       rt.numpy()
 
+  @parameterized.parameters([
+      ([[[1, 2], [3, 4, 5]], [[6]]], 2, None),
+      ([[[1, 2], [3, 4, 5]], [[6]]], 2, [None, None, None]),
+      ([[[1, 2], [3, 4, 5]], [[6]]], 2, [2, None, None]),
+      ([[[1, 2, 3], [4, 5, 6]], [[7, 8, 9]]], 1, None),
+      ([[[1, 2, 3], [4, 5, 6]], [[7, 8, 9]]], 1, [None, None, None]),
+      ([[[1, 2, 3], [4, 5, 6]], [[7, 8, 9]]], 1, [2, None, None]),
+      ([[[1, 2, 3], [4, 5, 6]], [[7, 8, 9]]], 1, [2, None, 3]),
+      ([[[1, 2, 3]]], 1, [1, 1, None]),
+      ([[[1, 2, 3]]], 1, [1, 1, 3]),
+  ])
+  def testRaggedTensorSetShape(self, rt, rt_ragged_rank, shape):
+    rt1 = ragged_factory_ops.constant(rt, ragged_rank=rt_ragged_rank)
+    rt1._set_shape(shape)
+    rt1.shape.assert_is_compatible_with(shape)
+    if shape is not None:
+      self.assertIsNot(rt1.shape.rank, None)
+      for a, b in zip(rt1.shape, shape):
+        if b is not None:
+          self.assertEqual(a, b)
+
+  @parameterized.parameters([
+      ([[[1, 2], [3, 4, 5]], [[6]]], 2, None),
+      ([[[1, 2], [3, 4, 5]], [[6]]], 2, [None, None, None]),
+      ([[[1, 2], [3, 4, 5]], [[6]]], 2, [2, None, None]),
+      ([[[1, 2, 3], [4, 5, 6]], [[7, 8, 9]]], 1, None),
+      ([[[1, 2, 3], [4, 5, 6]], [[7, 8, 9]]], 1, [None, None, None]),
+      ([[[1, 2, 3], [4, 5, 6]], [[7, 8, 9]]], 1, [2, None, None]),
+      ([[[1, 2, 3], [4, 5, 6]], [[7, 8, 9]]], 1, [2, None, 3]),
+      ([[[1, 2, 3]]], 1, [1, 1, None]),
+      ([[[1, 2, 3]]], 1, [1, 1, 3]),
+  ])
+  def testRaggedTensorSetShapeWithPlaceholders(self, rt, rt_ragged_rank, shape):
+    rt2 = nest.map_structure(
+        lambda x: array_ops.placeholder_with_default(x, None),
+        ragged_factory_ops.constant(rt, ragged_rank=rt_ragged_rank),
+        expand_composites=True)
+    rt2._set_shape(shape)
+    rt2.shape.assert_is_compatible_with(shape)
+    if shape is not None:
+      self.assertIsNot(rt2.shape.rank, None)
+      for a, b in zip(rt2.shape, shape):
+        if b is not None:
+          self.assertEqual(a, b)
+
+  def testRaggedTensorSetShapeUniformRowLength(self):
+    rt = [[[1], [2], [3]], [[4], [5], [6]]]
+
+    rt1 = RaggedTensor.from_tensor(rt, ragged_rank=1)
+    rt1._set_shape([2, 3, 1])
+
+    rt2 = nest.map_structure(
+        lambda x: array_ops.placeholder_with_default(x, None),
+        rt1, expand_composites=True)
+    rt2._set_shape([2, 3, 1])
+
+  def testRaggedTensorSetShapeInconsistentShapeError(self):
+    rt = RaggedTensor.from_tensor([[[1], [2], [3]], [[4], [5], [6]]],
+                                  ragged_rank=1)
+    self.assertEqual(rt.shape.as_list(), [2, 3, 1])
+    with self.assertRaises(ValueError):
+      rt._set_shape([None, None, 5])
+    with self.assertRaisesRegex(ValueError, 'Inconsistent size'):
+      rt._set_shape([None, 5, None])
+    with self.assertRaises(ValueError):
+      rt._set_shape([5, None, None])
+
 
 @test_util.run_all_in_graph_and_eager_modes
 class RaggedTensorSpecTest(test_util.TensorFlowTestCase,
@@ -1665,6 +1733,17 @@ class RaggedTensorSpecTest(test_util.TensorFlowTestCase,
                         [t[0] for t in tensor_list])
     self.assertAllEqual(rt[0], first_row)
 
+  def testToFromBatchedTensorListPreservesUniformRowLengths(self):
+    rt = RaggedTensor.from_tensor(array_ops.zeros([3, 4, 5]),
+                                  ragged_rank=2)
+    rt_spec = rt._type_spec
+    tensor_list = rt_spec._to_batched_tensor_list(rt)
+    rt_reconstructed = rt_spec._from_tensor_list(tensor_list)
+    self.assertAllEqual(rt, rt_reconstructed)
+    self.assertTrue(rt.shape.is_fully_defined())
+    self.assertTrue(rt_reconstructed.shape.is_fully_defined())
+    self.assertEqual(rt.shape.as_list(), rt_reconstructed.shape.as_list())
+
   @parameterized.parameters([
       (RaggedTensorSpec([2, None], dtypes.float32, 1), 32,
       RaggedTensorSpec([32, 2, None], dtypes.float32, 2)),
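
For context, a minimal end-to-end sketch of the behavior this change targets. This is illustrative only (not taken from the patch) and assumes eager TF 2.x with the public tf.RaggedTensor.from_tensor and tf.data.Dataset.from_tensors APIs:

import tensorflow as tf

# A RaggedTensor whose partitioned dimensions happen to be uniform; its
# static shape is fully defined: [3, 4, 5].
rt = tf.RaggedTensor.from_tensor(tf.zeros([3, 4, 5]), ragged_rank=2)
assert rt.shape.is_fully_defined()

# Storing the value in a Dataset round-trips it through the variant encoding
# that RaggedTensorSpec handles.
ds = tf.data.Dataset.from_tensors(rt)
element = next(iter(ds))

# Previously the decoded element lost its uniform row lengths and its static
# shape degraded to [3, None, None]; with this change the row partitions are
# restored, so the shape stays [3, 4, 5].
print(element.shape)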
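
A similar sketch of the new RaggedTensor._set_shape helper itself, mirroring the added tests (the helper is private and the concrete values are illustrative):

import tensorflow as tf

rt = tf.RaggedTensor.from_tensor([[[1], [2], [3]], [[4], [5], [6]]],
                                 ragged_rank=1)
print(rt.shape)  # [2, 3, 1]: the row partition already has uniform_row_length=3.

# Merging in a compatible shape is harmless; dimensions given as None are
# simply skipped.
rt._set_shape([2, None, 1])

# A size that contradicts an existing uniform row length raises a ValueError
# (roughly: "Inconsistent size for axis 1: 3 vs 5").
try:
  rt._set_shape([None, 5, None])
except ValueError as e:
  print(e)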