In tf.data, set Tensor dimension sizes when reconstructing a RaggedTensor from its boxed encoding.
PiperOrigin-RevId: 250548634
This commit is contained in:
parent
9cd9ebced6
commit
53fd64291f
@ -835,8 +835,16 @@ class RaggedTensorStructure(Structure):
|
||||
raise ValueError(
|
||||
"ragged_rank must be greater than zero. Found ragged_rank: %d" %
|
||||
self._ragged_rank)
|
||||
return ragged_tensor.RaggedTensor._from_variant(
|
||||
result = ragged_tensor.RaggedTensor._from_variant(
|
||||
flat_value[0], dtype=self._dtype, output_ragged_rank=self._ragged_rank)
|
||||
if self._shape.ndims is not None:
|
||||
outer_dim = tensor_shape.dimension_value(self._shape[0])
|
||||
if outer_dim is not None:
|
||||
result.row_splits.set_shape([outer_dim + 1])
|
||||
result.flat_values.set_shape(
|
||||
tensor_shape.TensorShape([None]).concatenate(
|
||||
self._shape[1 + self._ragged_rank:]))
|
||||
return result
|
||||
|
||||
@staticmethod
|
||||
def from_value(value):
|
||||
|
@ -342,6 +342,24 @@ class StructureTest(test_base.DatasetTestBase, parameterized.TestCase,
|
||||
|
||||
# pylint: enable=g-long-lambda
|
||||
|
||||
def preserveStaticShape(self):
|
||||
rt = ragged_factory_ops.constant([[1, 2], [], [3]])
|
||||
rt_s = structure.Structure.from_value(rt)
|
||||
rt_after = rt_s._from_tensor_list(rt_s._to_tensor_list(rt))
|
||||
self.assertEqual(rt_after.row_splits.shape.as_list(),
|
||||
rt.row_splits.shape.as_list())
|
||||
self.assertEqual(rt_after.values.shape.as_list(), [None])
|
||||
|
||||
st = sparse_tensor.SparseTensor(
|
||||
indices=[[3, 4]], values=[-1], dense_shape=[4, 5])
|
||||
st_s = structure.Structure.from_value(st)
|
||||
st_after = st_s._from_tensor_list(st_s._to_tensor_list(st))
|
||||
self.assertEqual(st_after.indices.shape.as_list(),
|
||||
[None, 2])
|
||||
self.assertEqual(st_after.values.shape.as_list(), [None])
|
||||
self.assertEqual(st_after.dense_shape.shape.as_list(),
|
||||
st.dense_shape.shape.as_list())
|
||||
|
||||
def testIncompatibleStructure(self):
|
||||
# Define three mutually incompatible values/structures, and assert that:
|
||||
# 1. Using one structure to flatten a value with an incompatible structure
|
||||
|
Loading…
x
Reference in New Issue
Block a user