RaggedTensor.to_tensor() to preserve inferred static shape.
PiperOrigin-RevId: 303189677 Change-Id: Iad8c17ea68d3139cd72d50bf6c3688b7b0e822c8
This commit is contained in:
parent
b97c90c87f
commit
c0da53ec25
@ -1621,13 +1621,30 @@ class RaggedTensor(composite_tensor.CompositeTensor):
|
||||
default_value = array_ops.zeros((), self.dtype)
|
||||
|
||||
shape_tensor = _shape_as_tensor(shape, row_partition_tensors[0].dtype)
|
||||
return gen_ragged_conversion_ops.ragged_tensor_to_tensor(
|
||||
tensor = gen_ragged_conversion_ops.ragged_tensor_to_tensor(
|
||||
shape=shape_tensor,
|
||||
values=self.flat_values,
|
||||
default_value=default_value,
|
||||
row_partition_types=row_partition_types,
|
||||
row_partition_tensors=row_partition_tensors)
|
||||
|
||||
ragged_shape = self.shape
|
||||
|
||||
if ragged_shape.rank is not None and not isinstance(shape, ops.Tensor):
|
||||
# Merged self.shape and shape, favoring the second one as it takes
|
||||
# into account potential padding added to the output.
|
||||
shape = tensor_shape.as_shape(shape)
|
||||
if shape.rank is None:
|
||||
output_shape = ragged_shape
|
||||
else:
|
||||
# At this point we can assume that hshape.rank == ragged_shape.rank
|
||||
# because otherwise it would have failed earlier.
|
||||
output_shape = [s1 if s1 is not None else s2 for (s1, s2)
|
||||
in zip(shape.as_list(), ragged_shape.as_list())]
|
||||
tensor.set_shape(output_shape)
|
||||
|
||||
return tensor
|
||||
|
||||
@classmethod
|
||||
def from_sparse(cls, st_input, name=None, row_splits_dtype=dtypes.int64):
|
||||
"""Converts a 2D `tf.SparseTensor` to a `RaggedTensor`.
|
||||
|
||||
@ -475,6 +475,22 @@ class RaggedTensorToTensorOpTest(test_util.TensorFlowTestCase,
|
||||
actual = input_data.to_tensor(shape=[3, 4])
|
||||
self.assertAllEqual(actual, [[0, 1, 2, 0], [0, 0, 0, 0], [3, 0, 0, 0]])
|
||||
|
||||
@parameterized.parameters(
|
||||
([2, 3, 4], None, [2, 3, 4]),
|
||||
([2, 3, 4], [None, None, None], [2, 3, 4]),
|
||||
([2, 3, 4], [None, 3, None], [2, 3, 4]),
|
||||
([2, 3, 4], [None, 3, 4], [2, 3, 4]),
|
||||
([2, 3, 4], [2, 3, 4], [2, 3, 4]),
|
||||
)
|
||||
def test_preserve_shape_roundtrip(
|
||||
self, input_shape, to_tensor_shape, expected_shape):
|
||||
tensor = array_ops.zeros(input_shape)
|
||||
ragged_from_tensor = RaggedTensor.from_tensor(tensor, ragged_rank=2)
|
||||
recovered_tensor = ragged_from_tensor.to_tensor(shape=to_tensor_shape)
|
||||
self.assertAllEqual(tensor.shape.as_list(), expected_shape)
|
||||
self.assertAllEqual(ragged_from_tensor.shape.as_list(), expected_shape)
|
||||
self.assertAllEqual(recovered_tensor.shape.as_list(), expected_shape)
|
||||
|
||||
def test_empty_tensor_with_shape(self):
|
||||
input_data = RaggedTensor.from_value_rowids(
|
||||
values=constant_op.constant([], dtype=dtypes.int64),
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user