RaggedTensor.to_tensor() to preserve inferred static shape.
PiperOrigin-RevId: 303189677 Change-Id: Iad8c17ea68d3139cd72d50bf6c3688b7b0e822c8
This commit is contained in:
parent
b97c90c87f
commit
c0da53ec25
@ -1621,13 +1621,30 @@ class RaggedTensor(composite_tensor.CompositeTensor):
|
|||||||
default_value = array_ops.zeros((), self.dtype)
|
default_value = array_ops.zeros((), self.dtype)
|
||||||
|
|
||||||
shape_tensor = _shape_as_tensor(shape, row_partition_tensors[0].dtype)
|
shape_tensor = _shape_as_tensor(shape, row_partition_tensors[0].dtype)
|
||||||
return gen_ragged_conversion_ops.ragged_tensor_to_tensor(
|
tensor = gen_ragged_conversion_ops.ragged_tensor_to_tensor(
|
||||||
shape=shape_tensor,
|
shape=shape_tensor,
|
||||||
values=self.flat_values,
|
values=self.flat_values,
|
||||||
default_value=default_value,
|
default_value=default_value,
|
||||||
row_partition_types=row_partition_types,
|
row_partition_types=row_partition_types,
|
||||||
row_partition_tensors=row_partition_tensors)
|
row_partition_tensors=row_partition_tensors)
|
||||||
|
|
||||||
|
ragged_shape = self.shape
|
||||||
|
|
||||||
|
if ragged_shape.rank is not None and not isinstance(shape, ops.Tensor):
|
||||||
|
# Merged self.shape and shape, favoring the second one as it takes
|
||||||
|
# into account potential padding added to the output.
|
||||||
|
shape = tensor_shape.as_shape(shape)
|
||||||
|
if shape.rank is None:
|
||||||
|
output_shape = ragged_shape
|
||||||
|
else:
|
||||||
|
# At this point we can assume that hshape.rank == ragged_shape.rank
|
||||||
|
# because otherwise it would have failed earlier.
|
||||||
|
output_shape = [s1 if s1 is not None else s2 for (s1, s2)
|
||||||
|
in zip(shape.as_list(), ragged_shape.as_list())]
|
||||||
|
tensor.set_shape(output_shape)
|
||||||
|
|
||||||
|
return tensor
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_sparse(cls, st_input, name=None, row_splits_dtype=dtypes.int64):
|
def from_sparse(cls, st_input, name=None, row_splits_dtype=dtypes.int64):
|
||||||
"""Converts a 2D `tf.SparseTensor` to a `RaggedTensor`.
|
"""Converts a 2D `tf.SparseTensor` to a `RaggedTensor`.
|
||||||
|
|||||||
@ -475,6 +475,22 @@ class RaggedTensorToTensorOpTest(test_util.TensorFlowTestCase,
|
|||||||
actual = input_data.to_tensor(shape=[3, 4])
|
actual = input_data.to_tensor(shape=[3, 4])
|
||||||
self.assertAllEqual(actual, [[0, 1, 2, 0], [0, 0, 0, 0], [3, 0, 0, 0]])
|
self.assertAllEqual(actual, [[0, 1, 2, 0], [0, 0, 0, 0], [3, 0, 0, 0]])
|
||||||
|
|
||||||
|
@parameterized.parameters(
|
||||||
|
([2, 3, 4], None, [2, 3, 4]),
|
||||||
|
([2, 3, 4], [None, None, None], [2, 3, 4]),
|
||||||
|
([2, 3, 4], [None, 3, None], [2, 3, 4]),
|
||||||
|
([2, 3, 4], [None, 3, 4], [2, 3, 4]),
|
||||||
|
([2, 3, 4], [2, 3, 4], [2, 3, 4]),
|
||||||
|
)
|
||||||
|
def test_preserve_shape_roundtrip(
|
||||||
|
self, input_shape, to_tensor_shape, expected_shape):
|
||||||
|
tensor = array_ops.zeros(input_shape)
|
||||||
|
ragged_from_tensor = RaggedTensor.from_tensor(tensor, ragged_rank=2)
|
||||||
|
recovered_tensor = ragged_from_tensor.to_tensor(shape=to_tensor_shape)
|
||||||
|
self.assertAllEqual(tensor.shape.as_list(), expected_shape)
|
||||||
|
self.assertAllEqual(ragged_from_tensor.shape.as_list(), expected_shape)
|
||||||
|
self.assertAllEqual(recovered_tensor.shape.as_list(), expected_shape)
|
||||||
|
|
||||||
def test_empty_tensor_with_shape(self):
|
def test_empty_tensor_with_shape(self):
|
||||||
input_data = RaggedTensor.from_value_rowids(
|
input_data = RaggedTensor.from_value_rowids(
|
||||||
values=constant_op.constant([], dtype=dtypes.int64),
|
values=constant_op.constant([], dtype=dtypes.int64),
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user