RaggedTensor.to_tensor() to preserve inferred static shape.

PiperOrigin-RevId: 303189677
Change-Id: Iad8c17ea68d3139cd72d50bf6c3688b7b0e822c8
This commit is contained in:
A. Unique TensorFlower 2020-03-26 14:04:59 -07:00 committed by TensorFlower Gardener
parent b97c90c87f
commit c0da53ec25
2 changed files with 34 additions and 1 deletions

View File

@ -1621,13 +1621,30 @@ class RaggedTensor(composite_tensor.CompositeTensor):
default_value = array_ops.zeros((), self.dtype)
shape_tensor = _shape_as_tensor(shape, row_partition_tensors[0].dtype)
return gen_ragged_conversion_ops.ragged_tensor_to_tensor(
tensor = gen_ragged_conversion_ops.ragged_tensor_to_tensor(
shape=shape_tensor,
values=self.flat_values,
default_value=default_value,
row_partition_types=row_partition_types,
row_partition_tensors=row_partition_tensors)
ragged_shape = self.shape
if ragged_shape.rank is not None and not isinstance(shape, ops.Tensor):
# Merged self.shape and shape, favoring the second one as it takes
# into account potential padding added to the output.
shape = tensor_shape.as_shape(shape)
if shape.rank is None:
output_shape = ragged_shape
else:
# At this point we can assume that hshape.rank == ragged_shape.rank
# because otherwise it would have failed earlier.
output_shape = [s1 if s1 is not None else s2 for (s1, s2)
in zip(shape.as_list(), ragged_shape.as_list())]
tensor.set_shape(output_shape)
return tensor
@classmethod
def from_sparse(cls, st_input, name=None, row_splits_dtype=dtypes.int64):
"""Converts a 2D `tf.SparseTensor` to a `RaggedTensor`.

View File

@ -475,6 +475,22 @@ class RaggedTensorToTensorOpTest(test_util.TensorFlowTestCase,
actual = input_data.to_tensor(shape=[3, 4])
self.assertAllEqual(actual, [[0, 1, 2, 0], [0, 0, 0, 0], [3, 0, 0, 0]])
@parameterized.parameters(
([2, 3, 4], None, [2, 3, 4]),
([2, 3, 4], [None, None, None], [2, 3, 4]),
([2, 3, 4], [None, 3, None], [2, 3, 4]),
([2, 3, 4], [None, 3, 4], [2, 3, 4]),
([2, 3, 4], [2, 3, 4], [2, 3, 4]),
)
def test_preserve_shape_roundtrip(
self, input_shape, to_tensor_shape, expected_shape):
tensor = array_ops.zeros(input_shape)
ragged_from_tensor = RaggedTensor.from_tensor(tensor, ragged_rank=2)
recovered_tensor = ragged_from_tensor.to_tensor(shape=to_tensor_shape)
self.assertAllEqual(tensor.shape.as_list(), expected_shape)
self.assertAllEqual(ragged_from_tensor.shape.as_list(), expected_shape)
self.assertAllEqual(recovered_tensor.shape.as_list(), expected_shape)
def test_empty_tensor_with_shape(self):
input_data = RaggedTensor.from_value_rowids(
values=constant_op.constant([], dtype=dtypes.int64),