From 0f65838cb9787c47f41ede1884e72a0144ad2fe0 Mon Sep 17 00:00:00 2001
From: Andrew Audibert
Date: Thu, 22 Aug 2019 15:01:05 -0700
Subject: [PATCH] Preserve shape information when passing SparseTensors to
 dataset functions

When we flatten SparseTensors into Tensors, the dense_shape of the
SparseTensor is stored as a Tensor of dimensions instead of as a shape.
Function tracing uses placeholder Tensors with no content, making it look
as though all input SparseTensors have undefined shape. This CL improves
tracing by restoring SparseTensors' dense_shapes from their original
SparseTensorSpecs.

PiperOrigin-RevId: 264927072
---
 .../python/data/kernel_tests/map_test.py     | 24 ++++++++++++++++++++++++
 tensorflow/python/framework/sparse_tensor.py | 26 ++++++++++++++++++++++----
 2 files changed, 46 insertions(+), 4 deletions(-)

diff --git a/tensorflow/python/data/kernel_tests/map_test.py b/tensorflow/python/data/kernel_tests/map_test.py
index eed46dad723..0847cdd7a0d 100644
--- a/tensorflow/python/data/kernel_tests/map_test.py
+++ b/tensorflow/python/data/kernel_tests/map_test.py
@@ -733,6 +733,30 @@ class MapTest(test_base.DatasetTestBase, parameterized.TestCase):
         dataset,
         expected_output=[self.evaluate(_check(_sparse(i))) for i in range(10)])
 
+  def testSparseMapShapeInference(self):
+    if not context.executing_eagerly():
+      self.skipTest("SparseTensor shape inference requires eager mode")
+    row_lengths = np.random.randint(0, 4, size=128)
+    values = np.ones(np.sum(row_lengths))
+    sparse = ragged_tensor.RaggedTensor.from_row_lengths(
+        values, row_lengths).to_sparse()
+    dataset = dataset_ops.Dataset.from_tensor_slices(sparse)
+    dataset = dataset.batch(32, drop_remainder=True)
+    dataset = dataset.map(lambda x: x)
+    self.assertEqual((32, 3), dataset.element_spec.shape)
+
+  def testSparseMapShapeInferencePartial(self):
+    if not context.executing_eagerly():
+      self.skipTest("SparseTensor shape inference requires eager mode")
+    row_lengths = np.random.randint(0, 4, size=128)
+    values = np.ones(np.sum(row_lengths))
+    sparse = ragged_tensor.RaggedTensor.from_row_lengths(
+        values, row_lengths).to_sparse()
+    dataset = dataset_ops.Dataset.from_tensor_slices(sparse)
+    dataset = dataset.batch(32, drop_remainder=False)
+    dataset = dataset.map(lambda x: x)
+    self.assertEqual([None, 3], dataset.element_spec.shape.as_list())
+
   def testTensorArray(self):
 
     def _tensor_array(i):
diff --git a/tensorflow/python/framework/sparse_tensor.py b/tensorflow/python/framework/sparse_tensor.py
index a598f43b4ab..fe0c42ffde1 100644
--- a/tensorflow/python/framework/sparse_tensor.py
+++ b/tensorflow/python/framework/sparse_tensor.py
@@ -24,6 +24,7 @@ import numpy as np
 from tensorflow.python import pywrap_tensorflow
 from tensorflow.python import tf2
 from tensorflow.python.framework import composite_tensor
+from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import ops
 from tensorflow.python.framework import tensor_like
@@ -338,11 +339,28 @@ class SparseTensorSpec(type_spec.BatchableTypeSpec):
 
   def _from_compatible_tensor_list(self, tensor_list):
     tensor_list = gen_sparse_ops.deserialize_sparse(tensor_list[0], self._dtype)
-    result = SparseTensor(*tensor_list)
+    indices, values, dense_shape = tensor_list
     rank = self._shape.ndims
-    result.indices.set_shape([None, rank])
-    result.dense_shape.set_shape([rank])
-    return result
+    indices.set_shape([None, rank])
+    # We restore the dense_shape from the SparseTensorSpec. This is necessary
+    # for shape inference when using placeholder SparseTensors in function
+    # tracing.
+    if self._shape.is_fully_defined():
+      dense_shape = ops.convert_to_tensor(
+          self._shape, dtype=dtypes.int64, name="shape")
+    elif (self._shape.rank is not None and
+          any(dim.value is not None for dim in self._shape.dims)):
+      # array_ops imports sparse_tensor.py. Local import to avoid import cycle.
+      from tensorflow.python.ops import array_ops  # pylint: disable=g-import-not-at-top
+      pieces = array_ops.unstack(dense_shape, num=self._shape.rank)
+      for i, dim in enumerate(self._shape.dims):
+        if dim.value is not None:
+          pieces[i] = constant_op.constant(dim.value, dense_shape.dtype)
+      dense_shape = array_ops.stack(pieces)
+    else:
+      dense_shape.set_shape([rank])
+
+    return SparseTensor(indices, values, dense_shape)
 
   def _batch(self, batch_size):
     return SparseTensorSpec(
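
Illustrative usage note (not part of the patch): the sketch below mirrors testSparseMapShapeInference above, but is written against the public tf.* API rather than the internal dataset_ops/ragged_tensor test imports; treating tf.data and tf.RaggedTensor as the entry points is an assumption, and it presumes a TF 2.x build with eager execution enabled.

    import numpy as np
    import tensorflow as tf

    # Build a SparseTensor from 128 ragged rows of length 0-3; its
    # dense_shape is (128, 3) as long as at least one row has length 3.
    row_lengths = np.random.randint(0, 4, size=128)
    values = np.ones(np.sum(row_lengths))
    sparse = tf.RaggedTensor.from_row_lengths(values, row_lengths).to_sparse()

    # Slice into a dataset, batch with a static batch size, and map an
    # identity function over the sparse elements.
    dataset = tf.data.Dataset.from_tensor_slices(sparse)
    dataset = dataset.batch(32, drop_remainder=True)
    dataset = dataset.map(lambda x: x)

    # With this change, tracing the mapped function restores the placeholder
    # SparseTensor's dense_shape from its SparseTensorSpec, so the element
    # spec keeps the static shape instead of reporting it as unknown.
    print(dataset.element_spec.shape)  # expected: (32, 3)

With drop_remainder=False the batch dimension stays unknown, so the spec reports [None, 3], matching testSparseMapShapeInferencePartial.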