Preserve shape information when passing SparseTensors to dataset functions

When we flatten SparseTensors into Tensors, the dense_shape of the SparseTensor is stored as a Tensor of dimensions instead of as a shape. Function tracing uses placeholder Tensors with no content, making it look as though all input SparseTensors have undefined shape.

This CL improves tracing by restoring SparseTensors' dense_shapes from their original SparseTensorSpecs.

PiperOrigin-RevId: 264927072
This commit is contained in:
Andrew Audibert 2019-08-22 15:01:05 -07:00 committed by TensorFlower Gardener
parent 9f00c8dbdf
commit 0f65838cb9
2 changed files with 46 additions and 4 deletions
tensorflow/python
data/kernel_tests
framework

View File

@ -733,6 +733,30 @@ class MapTest(test_base.DatasetTestBase, parameterized.TestCase):
dataset,
expected_output=[self.evaluate(_check(_sparse(i))) for i in range(10)])
def testSparseMapShapeInference(self):
if not context.executing_eagerly():
self.skipTest("SparseTensor shape inference requires eager mode")
row_lengths = np.random.randint(0, 4, size=128)
values = np.ones(np.sum(row_lengths))
sparse = ragged_tensor.RaggedTensor.from_row_lengths(
values, row_lengths).to_sparse()
dataset = dataset_ops.Dataset.from_tensor_slices(sparse)
dataset = dataset.batch(32, drop_remainder=True)
dataset = dataset.map(lambda x: x)
self.assertEqual((32, 3), dataset.element_spec.shape)
def testSparseMapShapeInferencePartial(self):
if not context.executing_eagerly():
self.skipTest("SparseTensor shape inference requires eager mode")
row_lengths = np.random.randint(0, 4, size=128)
values = np.ones(np.sum(row_lengths))
sparse = ragged_tensor.RaggedTensor.from_row_lengths(
values, row_lengths).to_sparse()
dataset = dataset_ops.Dataset.from_tensor_slices(sparse)
dataset = dataset.batch(32, drop_remainder=False)
dataset = dataset.map(lambda x: x)
self.assertEqual([None, 3], dataset.element_spec.shape.as_list())
def testTensorArray(self):
def _tensor_array(i):

View File

@ -24,6 +24,7 @@ import numpy as np
from tensorflow.python import pywrap_tensorflow
from tensorflow.python import tf2
from tensorflow.python.framework import composite_tensor
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_like
@ -338,11 +339,28 @@ class SparseTensorSpec(type_spec.BatchableTypeSpec):
def _from_compatible_tensor_list(self, tensor_list):
tensor_list = gen_sparse_ops.deserialize_sparse(tensor_list[0], self._dtype)
result = SparseTensor(*tensor_list)
indices, values, dense_shape = tensor_list
rank = self._shape.ndims
result.indices.set_shape([None, rank])
result.dense_shape.set_shape([rank])
return result
indices.set_shape([None, rank])
# We restore the dense_shape from the SparseTypeSpec. This is necessary
# for shape inference when using placeholder SparseTensors in function
# tracing.
if self._shape.is_fully_defined():
dense_shape = ops.convert_to_tensor(
self._shape, dtype=dtypes.int64, name="shape")
elif (self._shape.rank is not None and
any(dim.value is not None for dim in self._shape.dims)):
# array_ops imports sparse_tensor.py. Local import to avoid import cycle.
from tensorflow.python.ops import array_ops # pylint: disable=g-import-not-at-top
pieces = array_ops.unstack(dense_shape, num=self._shape.rank)
for i, dim in enumerate(self._shape.dims):
if dim.value is not None:
pieces[i] = constant_op.constant(dim.value, dense_shape.dtype)
dense_shape = array_ops.stack(pieces)
else:
dense_shape.set_shape([rank])
return SparseTensor(indices, values, dense_shape)
def _batch(self, batch_size):
return SparseTensorSpec(