In tf.data, set Tensor dimension sizes when reconstructing a RaggedTensor from its boxed encoding.

PiperOrigin-RevId: 250548634
This commit is contained in:
Edward Loper 2019-05-29 12:56:22 -07:00 committed by TensorFlower Gardener
parent 9cd9ebced6
commit 53fd64291f
2 changed files with 27 additions and 1 deletions

View File

@ -835,8 +835,16 @@ class RaggedTensorStructure(Structure):
raise ValueError(
"ragged_rank must be greater than zero. Found ragged_rank: %d" %
self._ragged_rank)
return ragged_tensor.RaggedTensor._from_variant(
result = ragged_tensor.RaggedTensor._from_variant(
flat_value[0], dtype=self._dtype, output_ragged_rank=self._ragged_rank)
if self._shape.ndims is not None:
outer_dim = tensor_shape.dimension_value(self._shape[0])
if outer_dim is not None:
result.row_splits.set_shape([outer_dim + 1])
result.flat_values.set_shape(
tensor_shape.TensorShape([None]).concatenate(
self._shape[1 + self._ragged_rank:]))
return result
@staticmethod
def from_value(value):

View File

@ -342,6 +342,24 @@ class StructureTest(test_base.DatasetTestBase, parameterized.TestCase,
# pylint: enable=g-long-lambda
def preserveStaticShape(self):
rt = ragged_factory_ops.constant([[1, 2], [], [3]])
rt_s = structure.Structure.from_value(rt)
rt_after = rt_s._from_tensor_list(rt_s._to_tensor_list(rt))
self.assertEqual(rt_after.row_splits.shape.as_list(),
rt.row_splits.shape.as_list())
self.assertEqual(rt_after.values.shape.as_list(), [None])
st = sparse_tensor.SparseTensor(
indices=[[3, 4]], values=[-1], dense_shape=[4, 5])
st_s = structure.Structure.from_value(st)
st_after = st_s._from_tensor_list(st_s._to_tensor_list(st))
self.assertEqual(st_after.indices.shape.as_list(),
[None, 2])
self.assertEqual(st_after.values.shape.as_list(), [None])
self.assertEqual(st_after.dense_shape.shape.as_list(),
st.dense_shape.shape.as_list())
def testIncompatibleStructure(self):
# Define three mutually incompatible values/structures, and assert that:
# 1. Using one structure to flatten a value with an incompatible structure