tf.data.experimental.dense_to_ragged_batch
to output variable ragged rank.
This means: * For input tensors with shapes known statically, the output will still be a `tf.Tensor`. * For input tensors with a `None` shape in the i-th axis, the output will be a `tf.RaggedTensor` with ragged rank of `i`, where it is the higher axis with `None` shape. PiperOrigin-RevId: 303704654 Change-Id: Id4d232688d7a5e4ee0dbca6093743d27432e5de8
This commit is contained in:
parent
e4b8f9dec3
commit
78c624bcc9
@ -45,11 +45,27 @@ def _make_vector_ds(nrows):
|
||||
return _make_scalar_ds(nrows).map(lambda x: array_ops.fill([x], x))
|
||||
|
||||
|
||||
def _make_matrix_ds(nrows):
|
||||
def _make_matrix_ds1(nrows):
|
||||
"""Create a test dataset with matrix elements (of varying size)."""
|
||||
return _make_scalar_ds(nrows).map(lambda x: array_ops.fill([x, 2], x))
|
||||
|
||||
|
||||
def _make_matrix_ds2(nrows):
|
||||
"""Create a test dataset with matrix elements (of varying size)."""
|
||||
return _make_scalar_ds(nrows).map(lambda x: array_ops.fill([2, x], x))
|
||||
|
||||
|
||||
def _make_matrix_ds_fully_defined(nrows):
|
||||
"""Create a test dataset with matrix elements (of varying size)."""
|
||||
return _make_scalar_ds(nrows).map(lambda x: array_ops.fill([2, 3], x))
|
||||
|
||||
|
||||
def _make_5dtensor_ds(nrows):
|
||||
"""Create a test dataset with matrix elements (of varying size)."""
|
||||
return _make_scalar_ds(nrows).map(
|
||||
lambda x: array_ops.fill([2, x, 3, 2*x, 4], x))
|
||||
|
||||
|
||||
def _make_ragged_ds(nrows):
|
||||
"""Create a test dataset with RaggedTensor elements (of varying size)."""
|
||||
values = [[[i] * (i % 3) for i in range(j)] * (j % 3) for j in range(nrows)]
|
||||
@ -64,6 +80,8 @@ def _make_dict_ds(nrows):
|
||||
'shape=[]': ops.convert_to_tensor(x),
|
||||
'shape=[x]': math_ops.range(x),
|
||||
'shape=[x, 2]': array_ops.fill([x, 2], x),
|
||||
'shape=[2, x]': array_ops.fill([2, x], x),
|
||||
'shape=[2, x, 3, 2x, 4]': array_ops.fill([2, x, 3, 2*x, 4], x)
|
||||
}
|
||||
return _make_scalar_ds(nrows).map(transform)
|
||||
|
||||
@ -88,8 +106,9 @@ class RaggedBatchTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
test_base.default_test_combinations(),
|
||||
combinations.combine(
|
||||
make_dataset=[
|
||||
_make_scalar_ds, _make_vector_ds, _make_matrix_ds,
|
||||
_make_ragged_ds, _make_dict_ds, _make_tuple_ds,
|
||||
_make_scalar_ds, _make_vector_ds, _make_matrix_ds1,
|
||||
_make_matrix_ds2, _make_ragged_ds, _make_5dtensor_ds,
|
||||
_make_dict_ds, _make_tuple_ds, _make_matrix_ds_fully_defined,
|
||||
],
|
||||
nrows=[0, 20, 23],
|
||||
batch_size=[4],
|
||||
|
@ -50,11 +50,23 @@ def dense_to_ragged_batch(batch_size,
|
||||
batch from being produced.
|
||||
|
||||
Unlike `tf.data.Dataset.batch`, the input elements to be batched may have
|
||||
different shapes, and each batch will be encoded as a `tf.RaggedTensor`.
|
||||
different shapes:
|
||||
|
||||
* If an input element is a `tf.Tensor` whose static `tf.TensorShape` is
|
||||
fully defined, then it is batched as normal.
|
||||
* If an input element is a `tf.Tensor` whose static `tf.TensorShape` contains
|
||||
one or more axes with unknown size (i.e., `shape[i]=None`), then the output
|
||||
will contain a `tf.RaggedTensor` that is ragged up to any of such
|
||||
dimensions.
|
||||
* If an input element is a `tf.RaggedTensor` or any other type, then it is
|
||||
batched as normal.
|
||||
|
||||
Example:
|
||||
|
||||
>>> dataset = tf.data.Dataset.from_tensor_slices(np.arange(6))
|
||||
>>> dataset = dataset.map(lambda x: tf.range(x))
|
||||
>>> dataset.element_spec.shape
|
||||
TensorShape([None])
|
||||
>>> dataset = dataset.apply(
|
||||
... tf.data.experimental.dense_to_ragged_batch(batch_size=2))
|
||||
>>> for batch in dataset:
|
||||
@ -385,32 +397,44 @@ class _DenseToRaggedDataset(dataset_ops.UnaryDataset):
|
||||
any new ragged tensors. Existing `tf.RaggedTensor` elements do *not*
|
||||
have their row_splits dtype changed.
|
||||
"""
|
||||
|
||||
# Replace each TensorSpec in the input dataset's structure with a
|
||||
# corresponding RaggedTensorSpec.
|
||||
def to_ragged_spec(spec):
|
||||
if isinstance(spec, tensor_spec.TensorSpec) and spec.shape.ndims != 0:
|
||||
"""Returns the new spec based on RaggedTensors."""
|
||||
if (not isinstance(spec, tensor_spec.TensorSpec) or
|
||||
spec.shape.rank is None or
|
||||
spec.shape.is_fully_defined()):
|
||||
return spec
|
||||
else:
|
||||
ragged_rank = max([
|
||||
axis for (axis, size) in enumerate(spec.shape.as_list())
|
||||
if size is None
|
||||
])
|
||||
return ragged_tensor.RaggedTensorSpec(
|
||||
shape=spec.shape,
|
||||
dtype=spec.dtype,
|
||||
ragged_rank=0,
|
||||
ragged_rank=ragged_rank,
|
||||
row_splits_dtype=row_splits_dtype)
|
||||
else:
|
||||
return spec
|
||||
|
||||
self._structure = nest.map_structure(to_ragged_spec,
|
||||
input_dataset.element_spec)
|
||||
|
||||
# Replace each tf.Tensor value in the input dataset with a variant-encoded
|
||||
# RaggedTensor. Since we're updating the corresponding structure to be
|
||||
# RaggedTensor. Since we're updating the corresponding structure to be
|
||||
# a RaggedTensorSpec, this variant-encoded tensor will be decoded with
|
||||
# RaggedTensorSpec._from_tensor_list.
|
||||
def to_ragged_variant(value):
|
||||
if isinstance(value, ops.Tensor) and value.shape.ndims != 0:
|
||||
spec = to_ragged_spec(tensor_spec.TensorSpec.from_tensor(value))
|
||||
return spec._to_tensor_list(value)[0] # pylint: disable=protected-access
|
||||
else:
|
||||
"""Re-encode Tensors as RaggedTensors."""
|
||||
if (not isinstance(value, ops.Tensor) or
|
||||
value.shape.rank is None or
|
||||
value.shape.is_fully_defined()):
|
||||
return value
|
||||
else:
|
||||
spec = to_ragged_spec(tensor_spec.TensorSpec.from_tensor(value))
|
||||
if spec._ragged_rank > 0: # pylint: disable=protected-access
|
||||
value = ragged_tensor.RaggedTensor.from_tensor(
|
||||
value, ragged_rank=spec._ragged_rank) # pylint: disable=protected-access
|
||||
return spec._to_tensor_list(value)[0] # pylint: disable=protected-access
|
||||
|
||||
# Tuples are automatically unpacked by `dataset.map` so we repack them.
|
||||
if dataset_ops._should_unpack_args(input_dataset.element_spec): # pylint: disable=protected-access
|
||||
|
Loading…
x
Reference in New Issue
Block a user