tf.data.experimental.dense_to_ragged_batch to output variable ragged rank.

This means:
* For input tensors with shapes known statically, the output will still be a `tf.Tensor`.
* For input tensors with a `None` shape in the i-th axis, the output will be a `tf.RaggedTensor` with ragged rank of `i`, where it is the higher axis with `None` shape.

PiperOrigin-RevId: 303704654
Change-Id: Id4d232688d7a5e4ee0dbca6093743d27432e5de8
This commit is contained in:
A. Unique TensorFlower 2020-03-30 03:18:06 -07:00 committed by TensorFlower Gardener
parent e4b8f9dec3
commit 78c624bcc9
2 changed files with 57 additions and 14 deletions

View File

@ -45,11 +45,27 @@ def _make_vector_ds(nrows):
return _make_scalar_ds(nrows).map(lambda x: array_ops.fill([x], x))
def _make_matrix_ds(nrows):
def _make_matrix_ds1(nrows):
"""Create a test dataset with matrix elements (of varying size)."""
return _make_scalar_ds(nrows).map(lambda x: array_ops.fill([x, 2], x))
def _make_matrix_ds2(nrows):
"""Create a test dataset with matrix elements (of varying size)."""
return _make_scalar_ds(nrows).map(lambda x: array_ops.fill([2, x], x))
def _make_matrix_ds_fully_defined(nrows):
"""Create a test dataset with matrix elements (of varying size)."""
return _make_scalar_ds(nrows).map(lambda x: array_ops.fill([2, 3], x))
def _make_5dtensor_ds(nrows):
"""Create a test dataset with matrix elements (of varying size)."""
return _make_scalar_ds(nrows).map(
lambda x: array_ops.fill([2, x, 3, 2*x, 4], x))
def _make_ragged_ds(nrows):
"""Create a test dataset with RaggedTensor elements (of varying size)."""
values = [[[i] * (i % 3) for i in range(j)] * (j % 3) for j in range(nrows)]
@ -64,6 +80,8 @@ def _make_dict_ds(nrows):
'shape=[]': ops.convert_to_tensor(x),
'shape=[x]': math_ops.range(x),
'shape=[x, 2]': array_ops.fill([x, 2], x),
'shape=[2, x]': array_ops.fill([2, x], x),
'shape=[2, x, 3, 2x, 4]': array_ops.fill([2, x, 3, 2*x, 4], x)
}
return _make_scalar_ds(nrows).map(transform)
@ -88,8 +106,9 @@ class RaggedBatchTest(test_base.DatasetTestBase, parameterized.TestCase):
test_base.default_test_combinations(),
combinations.combine(
make_dataset=[
_make_scalar_ds, _make_vector_ds, _make_matrix_ds,
_make_ragged_ds, _make_dict_ds, _make_tuple_ds,
_make_scalar_ds, _make_vector_ds, _make_matrix_ds1,
_make_matrix_ds2, _make_ragged_ds, _make_5dtensor_ds,
_make_dict_ds, _make_tuple_ds, _make_matrix_ds_fully_defined,
],
nrows=[0, 20, 23],
batch_size=[4],

View File

@ -50,11 +50,23 @@ def dense_to_ragged_batch(batch_size,
batch from being produced.
Unlike `tf.data.Dataset.batch`, the input elements to be batched may have
different shapes, and each batch will be encoded as a `tf.RaggedTensor`.
different shapes:
* If an input element is a `tf.Tensor` whose static `tf.TensorShape` is
fully defined, then it is batched as normal.
* If an input element is a `tf.Tensor` whose static `tf.TensorShape` contains
one or more axes with unknown size (i.e., `shape[i]=None`), then the output
will contain a `tf.RaggedTensor` that is ragged up to any of such
dimensions.
* If an input element is a `tf.RaggedTensor` or any other type, then it is
batched as normal.
Example:
>>> dataset = tf.data.Dataset.from_tensor_slices(np.arange(6))
>>> dataset = dataset.map(lambda x: tf.range(x))
>>> dataset.element_spec.shape
TensorShape([None])
>>> dataset = dataset.apply(
... tf.data.experimental.dense_to_ragged_batch(batch_size=2))
>>> for batch in dataset:
@ -385,32 +397,44 @@ class _DenseToRaggedDataset(dataset_ops.UnaryDataset):
any new ragged tensors. Existing `tf.RaggedTensor` elements do *not*
have their row_splits dtype changed.
"""
# Replace each TensorSpec in the input dataset's structure with a
# corresponding RaggedTensorSpec.
def to_ragged_spec(spec):
if isinstance(spec, tensor_spec.TensorSpec) and spec.shape.ndims != 0:
"""Returns the new spec based on RaggedTensors."""
if (not isinstance(spec, tensor_spec.TensorSpec) or
spec.shape.rank is None or
spec.shape.is_fully_defined()):
return spec
else:
ragged_rank = max([
axis for (axis, size) in enumerate(spec.shape.as_list())
if size is None
])
return ragged_tensor.RaggedTensorSpec(
shape=spec.shape,
dtype=spec.dtype,
ragged_rank=0,
ragged_rank=ragged_rank,
row_splits_dtype=row_splits_dtype)
else:
return spec
self._structure = nest.map_structure(to_ragged_spec,
input_dataset.element_spec)
# Replace each tf.Tensor value in the input dataset with a variant-encoded
# RaggedTensor. Since we're updating the corresponding structure to be
# RaggedTensor. Since we're updating the corresponding structure to be
# a RaggedTensorSpec, this variant-encoded tensor will be decoded with
# RaggedTensorSpec._from_tensor_list.
def to_ragged_variant(value):
if isinstance(value, ops.Tensor) and value.shape.ndims != 0:
spec = to_ragged_spec(tensor_spec.TensorSpec.from_tensor(value))
return spec._to_tensor_list(value)[0] # pylint: disable=protected-access
else:
"""Re-encode Tensors as RaggedTensors."""
if (not isinstance(value, ops.Tensor) or
value.shape.rank is None or
value.shape.is_fully_defined()):
return value
else:
spec = to_ragged_spec(tensor_spec.TensorSpec.from_tensor(value))
if spec._ragged_rank > 0: # pylint: disable=protected-access
value = ragged_tensor.RaggedTensor.from_tensor(
value, ragged_rank=spec._ragged_rank) # pylint: disable=protected-access
return spec._to_tensor_list(value)[0] # pylint: disable=protected-access
# Tuples are automatically unpacked by `dataset.map` so we repack them.
if dataset_ops._should_unpack_args(input_dataset.element_spec): # pylint: disable=protected-access