diff --git a/tensorflow/python/data/experimental/kernel_tests/dense_to_ragged_batch_test.py b/tensorflow/python/data/experimental/kernel_tests/dense_to_ragged_batch_test.py index 8fb92ec1d95..e66f401ed8e 100644 --- a/tensorflow/python/data/experimental/kernel_tests/dense_to_ragged_batch_test.py +++ b/tensorflow/python/data/experimental/kernel_tests/dense_to_ragged_batch_test.py @@ -45,11 +45,27 @@ def _make_vector_ds(nrows): return _make_scalar_ds(nrows).map(lambda x: array_ops.fill([x], x)) -def _make_matrix_ds(nrows): +def _make_matrix_ds1(nrows): """Create a test dataset with matrix elements (of varying size).""" return _make_scalar_ds(nrows).map(lambda x: array_ops.fill([x, 2], x)) +def _make_matrix_ds2(nrows): + """Create a test dataset with matrix elements (varying number of columns).""" + return _make_scalar_ds(nrows).map(lambda x: array_ops.fill([2, x], x)) + + +def _make_matrix_ds_fully_defined(nrows): + """Create a test dataset with matrix elements of fixed shape [2, 3].""" + return _make_scalar_ds(nrows).map(lambda x: array_ops.fill([2, 3], x)) + + +def _make_5dtensor_ds(nrows): + """Create a test dataset with 5-D tensor elements (of varying size).""" + return _make_scalar_ds(nrows).map( + lambda x: array_ops.fill([2, x, 3, 2*x, 4], x)) + + def _make_ragged_ds(nrows): """Create a test dataset with RaggedTensor elements (of varying size).""" values = [[[i] * (i % 3) for i in range(j)] * (j % 3) for j in range(nrows)] @@ -64,6 +80,8 @@ def _make_dict_ds(nrows): 'shape=[]': ops.convert_to_tensor(x), 'shape=[x]': math_ops.range(x), 'shape=[x, 2]': array_ops.fill([x, 2], x), + 'shape=[2, x]': array_ops.fill([2, x], x), + 'shape=[2, x, 3, 2x, 4]': array_ops.fill([2, x, 3, 2*x, 4], x) } return _make_scalar_ds(nrows).map(transform) @@ -88,8 +106,9 @@ class RaggedBatchTest(test_base.DatasetTestBase, parameterized.TestCase): test_base.default_test_combinations(), combinations.combine( make_dataset=[ - _make_scalar_ds, _make_vector_ds, _make_matrix_ds, - _make_ragged_ds, 
_make_dict_ds, _make_tuple_ds, + _make_scalar_ds, _make_vector_ds, _make_matrix_ds1, + _make_matrix_ds2, _make_ragged_ds, _make_5dtensor_ds, + _make_dict_ds, _make_tuple_ds, _make_matrix_ds_fully_defined, ], nrows=[0, 20, 23], batch_size=[4], diff --git a/tensorflow/python/data/experimental/ops/batching.py b/tensorflow/python/data/experimental/ops/batching.py index a2d3e0086dd..398ec98a7cb 100644 --- a/tensorflow/python/data/experimental/ops/batching.py +++ b/tensorflow/python/data/experimental/ops/batching.py @@ -50,11 +50,23 @@ def dense_to_ragged_batch(batch_size, batch from being produced. Unlike `tf.data.Dataset.batch`, the input elements to be batched may have - different shapes, and each batch will be encoded as a `tf.RaggedTensor`. + different shapes: + + * If an input element is a `tf.Tensor` whose static `tf.TensorShape` is + fully defined, then it is batched as normal. + * If an input element is a `tf.Tensor` whose static `tf.TensorShape` contains + one or more axes with unknown size (i.e., `shape[i]=None`), then the output + will contain a `tf.RaggedTensor` that is ragged up to and including the + last such axis. + * If an input element is a `tf.RaggedTensor` or any other type, then it is + batched as normal. + Example: >>> dataset = tf.data.Dataset.from_tensor_slices(np.arange(6)) >>> dataset = dataset.map(lambda x: tf.range(x)) + >>> dataset.element_spec.shape + TensorShape([None]) >>> dataset = dataset.apply( ... tf.data.experimental.dense_to_ragged_batch(batch_size=2)) >>> for batch in dataset: @@ -385,32 +397,44 @@ class _DenseToRaggedDataset(dataset_ops.UnaryDataset): any new ragged tensors. Existing `tf.RaggedTensor` elements do *not* have their row_splits dtype changed. """ - # Replace each TensorSpec in the input dataset's structure with a # corresponding RaggedTensorSpec. 
def to_ragged_spec(spec): - if isinstance(spec, tensor_spec.TensorSpec) and spec.shape.ndims != 0: + """Returns the new spec based on RaggedTensors.""" + if (not isinstance(spec, tensor_spec.TensorSpec) or + spec.shape.rank is None or + spec.shape.is_fully_defined()): + return spec + else: + ragged_rank = max([ + axis for (axis, size) in enumerate(spec.shape.as_list()) + if size is None + ]) return ragged_tensor.RaggedTensorSpec( shape=spec.shape, dtype=spec.dtype, - ragged_rank=0, + ragged_rank=ragged_rank, row_splits_dtype=row_splits_dtype) - else: - return spec self._structure = nest.map_structure(to_ragged_spec, input_dataset.element_spec) # Replace each tf.Tensor value in the input dataset with a variant-encoded - # RaggedTensor. Since we're updating the corresponding structure to be + # RaggedTensor. Since we're updating the corresponding structure to be # a RaggedTensorSpec, this variant-encoded tensor will be decoded with # RaggedTensorSpec._from_tensor_list. def to_ragged_variant(value): - if isinstance(value, ops.Tensor) and value.shape.ndims != 0: - spec = to_ragged_spec(tensor_spec.TensorSpec.from_tensor(value)) - return spec._to_tensor_list(value)[0] # pylint: disable=protected-access - else: + """Re-encode Tensors as RaggedTensors.""" + if (not isinstance(value, ops.Tensor) or + value.shape.rank is None or + value.shape.is_fully_defined()): return value + else: + spec = to_ragged_spec(tensor_spec.TensorSpec.from_tensor(value)) + if spec._ragged_rank > 0: # pylint: disable=protected-access + value = ragged_tensor.RaggedTensor.from_tensor( + value, ragged_rank=spec._ragged_rank) # pylint: disable=protected-access + return spec._to_tensor_list(value)[0] # pylint: disable=protected-access # Tuples are automatically unpacked by `dataset.map` so we repack them. if dataset_ops._should_unpack_args(input_dataset.element_spec): # pylint: disable=protected-access