From 575ab84dabc8a4116abf20356f17dd0beaba011d Mon Sep 17 00:00:00 2001 From: Jiri Simsa Date: Sun, 19 May 2019 15:18:17 -0700 Subject: [PATCH] [tf.data] Promoting `unbatch` from experimental to core API. PiperOrigin-RevId: 248960845 --- .../contrib/data/python/ops/batching.py | 2 +- .../data/experimental/kernel_tests/BUILD | 27 --- .../kernel_tests/restructured_dataset_test.py | 7 +- .../python/data/experimental/ops/batching.py | 148 +--------------- tensorflow/python/data/kernel_tests/BUILD | 25 +++ .../kernel_tests/unbatch_test.py | 30 ++-- tensorflow/python/data/ops/dataset_ops.py | 165 ++++++++++++++++++ .../golden/v1/tensorflow.data.-dataset.pbtxt | 4 + ...ow.data.-fixed-length-record-dataset.pbtxt | 4 + .../tensorflow.data.-t-f-record-dataset.pbtxt | 4 + .../tensorflow.data.-text-line-dataset.pbtxt | 4 + ...rflow.data.experimental.-csv-dataset.pbtxt | 4 + ...ow.data.experimental.-random-dataset.pbtxt | 4 + ...rflow.data.experimental.-sql-dataset.pbtxt | 4 + .../golden/v2/tensorflow.data.-dataset.pbtxt | 4 + ...ow.data.-fixed-length-record-dataset.pbtxt | 4 + .../tensorflow.data.-t-f-record-dataset.pbtxt | 4 + .../tensorflow.data.-text-line-dataset.pbtxt | 4 + ...rflow.data.experimental.-csv-dataset.pbtxt | 4 + ...ow.data.experimental.-random-dataset.pbtxt | 4 + ...rflow.data.experimental.-sql-dataset.pbtxt | 4 + 21 files changed, 267 insertions(+), 193 deletions(-) rename tensorflow/python/data/{experimental => }/kernel_tests/unbatch_test.py (90%) diff --git a/tensorflow/contrib/data/python/ops/batching.py b/tensorflow/contrib/data/python/ops/batching.py index 6a88cc68162..0bff4fb7bcd 100644 --- a/tensorflow/contrib/data/python/ops/batching.py +++ b/tensorflow/contrib/data/python/ops/batching.py @@ -219,7 +219,7 @@ def assert_element_shape(expected_shapes): output_shapes = _merge_output_shapes( dataset_ops.get_legacy_output_shapes(dataset), expected_shapes) # pylint: disable=protected-access - return batching._RestructuredDataset( + return dataset_ops._RestructuredDataset( dataset.map(_check_shape), dataset_ops.get_legacy_output_types(dataset), output_shapes=output_shapes, diff --git a/tensorflow/python/data/experimental/kernel_tests/BUILD b/tensorflow/python/data/experimental/kernel_tests/BUILD index d90c7a99176..1652bd1ccff 100644 --- a/tensorflow/python/data/experimental/kernel_tests/BUILD +++ b/tensorflow/python/data/experimental/kernel_tests/BUILD @@ -726,33 +726,6 @@ py_test( ], ) -py_test( - name = "unbatch_test", - size = "medium", - srcs = ["unbatch_test.py"], - python_version = "PY2", - deps = [ - "//tensorflow/python:array_ops", - "//tensorflow/python:client_testlib", - "//tensorflow/python:constant_op", - "//tensorflow/python:dtypes", - "//tensorflow/python:errors", - "//tensorflow/python:framework_ops", - "//tensorflow/python:math_ops", - "//tensorflow/python:session", - "//tensorflow/python:sparse_tensor", - "//tensorflow/python:string_ops", - "//tensorflow/python:tensor_shape", - "//tensorflow/python:util", - "//tensorflow/python/data/experimental/ops:batching", - "//tensorflow/python/data/kernel_tests:test_base", - "//tensorflow/python/data/ops:dataset_ops", - "//tensorflow/python/ops/ragged", - "//third_party/py/numpy", - "@absl_py//absl/testing:parameterized", - ], -) - py_test( name = "unique_test", size = "small", diff --git a/tensorflow/python/data/experimental/kernel_tests/restructured_dataset_test.py b/tensorflow/python/data/experimental/kernel_tests/restructured_dataset_test.py index 88c14f0a6ea..cf76a73eee6 100644 --- 
a/tensorflow/python/data/experimental/kernel_tests/restructured_dataset_test.py +++ b/tensorflow/python/data/experimental/kernel_tests/restructured_dataset_test.py @@ -17,7 +17,6 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.python.data.experimental.ops import batching from tensorflow.python.data.kernel_tests import test_base from tensorflow.python.data.ops import dataset_ops from tensorflow.python.data.util import nest @@ -46,7 +45,8 @@ class RestructuredDatasetTest(test_base.DatasetTestBase): for new_types, new_shape_lists in test_cases: # pylint: disable=protected-access - new = batching._RestructuredDataset(dataset, new_types, new_shape_lists) + new = dataset_ops._RestructuredDataset(dataset, new_types, + new_shape_lists) # pylint: enable=protected-access self.assertEqual(new_types, dataset_ops.get_legacy_output_types(new)) if new_shape_lists is not None: @@ -67,7 +67,8 @@ class RestructuredDatasetTest(test_base.DatasetTestBase): for new_types, new_shape_lists in fail_cases: with self.assertRaises(ValueError): # pylint: disable=protected-access - new = batching._RestructuredDataset(dataset, new_types, new_shape_lists) + new = dataset_ops._RestructuredDataset(dataset, new_types, + new_shape_lists) # pylint: enable=protected-access diff --git a/tensorflow/python/data/experimental/ops/batching.py b/tensorflow/python/data/experimental/ops/batching.py index 202a80a7779..2be96fe4878 100644 --- a/tensorflow/python/data/experimental/ops/batching.py +++ b/tensorflow/python/data/experimental/ops/batching.py @@ -19,7 +19,6 @@ from __future__ import print_function from tensorflow.python.data.ops import dataset_ops from tensorflow.python.data.util import convert -from tensorflow.python.data.util import nest from tensorflow.python.data.util import structure from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops @@ -199,6 +198,7 @@ def map_and_batch(map_func, return _apply_fn +@deprecation.deprecated(None, "Use `tf.data.Dataset.unbatch()`.") @tf_export("data.experimental.unbatch") def unbatch(): """Splits elements of a dataset into multiple elements on the batch dimension. @@ -223,34 +223,7 @@ def unbatch(): """ def _apply_fn(dataset): - """Function from `Dataset` to `Dataset` that applies the transformation.""" - - # NOTE(mrry): We must ensure that any SparseTensors in `dataset` - # are normalized to the rank-1 dense representation, so that the - # sparse-oblivious unbatching logic will slice them - # appropriately. This leads to a somewhat inefficient re-encoding step - # for all SparseTensor components. - # TODO(mrry): Consider optimizing this in future if it turns out to be - # a bottleneck. - def normalize(arg, *rest): - # pylint: disable=protected-access - if rest: - return dataset._element_structure._to_batched_tensor_list((arg,) + rest) - else: - return dataset._element_structure._to_batched_tensor_list(arg) - - normalized_dataset = dataset.map(normalize) - - # NOTE(mrry): Our `map()` has lost information about the sparseness - # of any SparseTensor components, so re-apply the structure of the - # original dataset. 
- restructured_dataset = _RestructuredDataset( - normalized_dataset, - dataset_ops.get_legacy_output_types(dataset), - dataset_ops.get_legacy_output_shapes(dataset), - dataset_ops.get_legacy_output_classes(dataset), - allow_unsafe_cast=True) - return _UnbatchDataset(restructured_dataset) + return dataset.unbatch() return _apply_fn @@ -330,120 +303,3 @@ class _MapAndBatchDataset(dataset_ops.UnaryDataset): @property def _element_structure(self): return self._structure - - -class _RestructuredDataset(dataset_ops.UnaryDataset): - """An internal helper for changing the structure and shape of a dataset.""" - - def __init__(self, - dataset, - output_types, - output_shapes=None, - output_classes=None, - allow_unsafe_cast=False): - """Creates a new dataset with the given output types and shapes. - - The given `dataset` must have a structure that is convertible: - * `dataset.output_types` must be the same as `output_types` module nesting. - * Each shape in `dataset.output_shapes` must be compatible with each shape - in `output_shapes` (if given). - - Note: This helper permits "unsafe casts" for shapes, equivalent to using - `tf.Tensor.set_shape()` where domain-specific knowledge is available. - - Args: - dataset: A `Dataset` object. - output_types: A nested structure of `tf.DType` objects. - output_shapes: (Optional.) A nested structure of `tf.TensorShape` objects. - If omitted, the shapes will be inherited from `dataset`. - output_classes: (Optional.) A nested structure of class types. If omitted, - the class types will be inherited from `dataset`. - allow_unsafe_cast: (Optional.) If `True`, the caller may switch the - reported output types and shapes of the restructured dataset, e.g. to - switch a sparse tensor represented as `tf.variant` to its user-visible - type and shape. - - Raises: - ValueError: If either `output_types` or `output_shapes` is not compatible - with the structure of `dataset`. - """ - self._input_dataset = dataset - - input_types = dataset_ops.get_legacy_output_types(dataset) - if not allow_unsafe_cast: - # Validate that the types are compatible. - output_types = nest.map_structure(dtypes.as_dtype, output_types) - flat_original_types = nest.flatten(input_types) - flat_new_types = nest.flatten(output_types) - if flat_original_types != flat_new_types: - raise ValueError( - "Dataset with output types %r cannot be restructured to have " - "output types %r" % - (dataset_ops.get_legacy_output_types(dataset), output_types)) - - input_shapes = dataset_ops.get_legacy_output_shapes(dataset) - if output_shapes is None: - # Inherit shapes from the original `dataset`. - output_shapes = nest.pack_sequence_as( - output_types, nest.flatten(input_shapes)) - else: - if not allow_unsafe_cast: - # Validate that the shapes are compatible. - nest.assert_same_structure(output_types, output_shapes) - flat_original_shapes = nest.flatten(input_shapes) - flat_new_shapes = nest.flatten_up_to(output_types, output_shapes) - - for original_shape, new_shape in zip(flat_original_shapes, - flat_new_shapes): - if not original_shape.is_compatible_with(new_shape): - raise ValueError( - "Dataset with output shapes %r cannot be restructured to have " - "incompatible output shapes %r" % (input_shapes, - output_shapes)) - output_shapes = nest.map_structure_up_to( - output_types, tensor_shape.as_shape, output_shapes) - - input_classes = dataset_ops.get_legacy_output_classes(dataset) - if output_classes is None: - # Inherit class types from the original `dataset`. 
- output_classes = nest.pack_sequence_as( - output_types, nest.flatten(input_classes)) - - self._structure = structure.convert_legacy_structure( - output_types, output_shapes, output_classes) - variant_tensor = self._input_dataset._variant_tensor # pylint: disable=protected-access - super(_RestructuredDataset, self).__init__(dataset, variant_tensor) - - @property - def _element_structure(self): - return self._structure - - -class _UnbatchDataset(dataset_ops.UnaryDataset): - """A dataset that splits the elements of its input into multiple elements.""" - - def __init__(self, input_dataset): - """See `unbatch()` for more details.""" - input_shapes = dataset_ops.get_legacy_output_shapes(input_dataset) - flat_shapes = nest.flatten(input_shapes) - if any(s.ndims == 0 for s in flat_shapes): - raise ValueError("Cannot unbatch an input with scalar components.") - known_batch_dim = tensor_shape.Dimension(None) - for s in flat_shapes: - try: - known_batch_dim = known_batch_dim.merge_with(s[0]) - except ValueError: - raise ValueError("Cannot unbatch an input whose components have " - "different batch sizes.") - self._input_dataset = input_dataset - - self._structure = dataset_ops.get_structure(input_dataset)._unbatch() # pylint: disable=protected-access - - variant_tensor = ged_ops.experimental_unbatch_dataset( - self._input_dataset._variant_tensor, # pylint: disable=protected-access - **dataset_ops.flat_structure(self)) - super(_UnbatchDataset, self).__init__(input_dataset, variant_tensor) - - @property - def _element_structure(self): - return self._structure diff --git a/tensorflow/python/data/kernel_tests/BUILD b/tensorflow/python/data/kernel_tests/BUILD index 088e76922fb..8ff589d6976 100644 --- a/tensorflow/python/data/kernel_tests/BUILD +++ b/tensorflow/python/data/kernel_tests/BUILD @@ -715,6 +715,31 @@ py_library( ], ) +tf_py_test( + name = "unbatch_test", + size = "medium", + srcs = ["unbatch_test.py"], + additional_deps = [ + "@absl_py//absl/testing:parameterized", + "//third_party/py/numpy", + "//tensorflow/python:array_ops", + "//tensorflow/python:client_testlib", + "//tensorflow/python:constant_op", + "//tensorflow/python:dtypes", + "//tensorflow/python:errors", + "//tensorflow/python:framework_ops", + "//tensorflow/python:math_ops", + "//tensorflow/python:session", + "//tensorflow/python:sparse_tensor", + "//tensorflow/python:string_ops", + "//tensorflow/python:tensor_shape", + "//tensorflow/python:util", + "//tensorflow/python/data/kernel_tests:test_base", + "//tensorflow/python/data/ops:dataset_ops", + "//tensorflow/python/ops/ragged", + ], +) + tf_py_test( name = "window_test", size = "medium", diff --git a/tensorflow/python/data/experimental/kernel_tests/unbatch_test.py b/tensorflow/python/data/kernel_tests/unbatch_test.py similarity index 90% rename from tensorflow/python/data/experimental/kernel_tests/unbatch_test.py rename to tensorflow/python/data/kernel_tests/unbatch_test.py index 22a9b9c8d60..6bc8f442cf9 100644 --- a/tensorflow/python/data/experimental/kernel_tests/unbatch_test.py +++ b/tensorflow/python/data/kernel_tests/unbatch_test.py @@ -21,7 +21,6 @@ from __future__ import print_function from absl.testing import parameterized import numpy as np -from tensorflow.python.data.experimental.ops import batching from tensorflow.python.data.kernel_tests import test_base from tensorflow.python.data.ops import dataset_ops from tensorflow.python.framework import constant_op @@ -41,8 +40,7 @@ from tensorflow.python.util import compat class UnbatchTest(test_base.DatasetTestBase, 
parameterized.TestCase): def testUnbatchWithUnknownRankInput(self): - dataset = dataset_ops.Dataset.from_tensors([0, 1, 2, - 3]).apply(batching.unbatch()) + dataset = dataset_ops.Dataset.from_tensors([0, 1, 2, 3]).unbatch() self.assertDatasetProduces(dataset, range(4)) def testUnbatchScalarDataset(self): @@ -51,7 +49,7 @@ class UnbatchTest(test_base.DatasetTestBase, parameterized.TestCase): expected_types = (dtypes.int32,) * 3 data = data.batch(2) self.assertEqual(expected_types, dataset_ops.get_legacy_output_types(data)) - data = data.apply(batching.unbatch()) + data = data.unbatch() self.assertEqual(expected_types, dataset_ops.get_legacy_output_types(data)) self.assertDatasetProduces(data, [(i,) * 3 for i in range(10)]) @@ -63,7 +61,7 @@ class UnbatchTest(test_base.DatasetTestBase, parameterized.TestCase): expected_types = (dtypes.int32, dtypes.string, dtypes.int32) data = data.batch(2) self.assertEqual(expected_types, dataset_ops.get_legacy_output_types(data)) - data = data.apply(batching.unbatch()) + data = data.unbatch() self.assertEqual(expected_types, dataset_ops.get_legacy_output_types(data)) self.assertDatasetProduces( @@ -75,9 +73,9 @@ class UnbatchTest(test_base.DatasetTestBase, parameterized.TestCase): values=list(range(10)), dense_shape=[10, 10]) data = dataset_ops.Dataset.from_tensors(st) - data = data.apply(batching.unbatch()) + data = data.unbatch() data = data.batch(5) - data = data.apply(batching.unbatch()) + data = data.unbatch() expected_output = [ sparse_tensor.SparseTensorValue([[i]], [i], [10]) for i in range(10) ] @@ -91,9 +89,9 @@ class UnbatchTest(test_base.DatasetTestBase, parameterized.TestCase): rt = ragged_factory_ops.constant_value([[[0]], [[1]], [[2]], [[3]], [[4]], [[5]], [[6]], [[7]], [[8]], [[9]]]) data = dataset_ops.Dataset.from_tensors((list(range(10)), st, rt)) - data = data.apply(batching.unbatch()) + data = data.unbatch() data = data.batch(5) - data = data.apply(batching.unbatch()) + data = data.unbatch() expected_output = [(i, sparse_tensor.SparseTensorValue([[i]], [i], [10]), ragged_factory_ops.constant_value([[i]])) for i in range(10)] @@ -104,10 +102,10 @@ class UnbatchTest(test_base.DatasetTestBase, parameterized.TestCase): rt = ragged_factory_ops.constant_value([[[0]], [[1]], [[2]], [[3]], [[4]], [[5]], [[6]], [[7]], [[8]], [[9]]]) data = dataset_ops.Dataset.from_tensors(rt) - data = data.apply(batching.unbatch()) + data = data.unbatch() data = data.batch(5) data = data.batch(2) - data = data.apply(batching.unbatch()) + data = data.unbatch() expected_output = [ ragged_factory_ops.constant_value([[[0]], [[1]], [[2]], [[3]], [[4]]]), ragged_factory_ops.constant_value([[[5]], [[6]], [[7]], [[8]], [[9]]]), @@ -121,7 +119,7 @@ class UnbatchTest(test_base.DatasetTestBase, parameterized.TestCase): expected_types = ((dtypes.int32,),) * 3 data = data.batch(2) self.assertEqual(expected_types, dataset_ops.get_legacy_output_types(data)) - data = data.apply(batching.unbatch()) + data = data.unbatch() self.assertEqual(expected_types, dataset_ops.get_legacy_output_types(data)) self.assertDatasetProduces(data, [((i,),) * 3 for i in range(10)]) @@ -134,7 +132,7 @@ class UnbatchTest(test_base.DatasetTestBase, parameterized.TestCase): data = data.batch(2) self.assertAllEqual(expected_types, dataset_ops.get_legacy_output_types(data)) - data = data.apply(batching.unbatch()) + data = data.unbatch() self.assertAllEqual(expected_types, dataset_ops.get_legacy_output_types(data)) @@ -146,14 +144,14 @@ class UnbatchTest(test_base.DatasetTestBase, parameterized.TestCase): 
parameterized.TestCase):
     data = dataset_ops.Dataset.from_tensors(
         (constant_op.constant([]), constant_op.constant([], shape=[0, 4]),
          constant_op.constant([], shape=[0, 4, 0])))
-    data = data.apply(batching.unbatch())
+    data = data.unbatch()
     self.assertDatasetProduces(data, [])
 
   def testUnbatchStaticShapeMismatch(self):
     data = dataset_ops.Dataset.from_tensors((np.arange(7), np.arange(8),
                                              np.arange(9)))
     with self.assertRaises(ValueError):
-      data.apply(batching.unbatch())
+      data.unbatch()
 
   # Note: dynamic shape mismatch is graph specific test.
   @test_util.run_deprecated_v1
   def testUnbatchDynamicShapeMismatch(self):
     ph1 = array_ops.placeholder(dtypes.int32, shape=[None])
     ph2 = array_ops.placeholder(dtypes.int32, shape=None)
     data = dataset_ops.Dataset.from_tensors((ph1, ph2))
-    data = data.apply(batching.unbatch())
+    data = data.unbatch()
     iterator = dataset_ops.make_initializable_iterator(data)
     next_element = iterator.get_next()
 
diff --git a/tensorflow/python/data/ops/dataset_ops.py b/tensorflow/python/data/ops/dataset_ops.py
index 086c63d1976..3becbb7fabd 100644
--- a/tensorflow/python/data/ops/dataset_ops.py
+++ b/tensorflow/python/data/ops/dataset_ops.py
@@ -1452,6 +1452,54 @@ class DatasetV2(tracking_base.Trackable, composite_tensor.CompositeTensor):
             output_shapes=state_structure._flat_shapes,
             output_types=state_structure._flat_types))
 
+  def unbatch(self):
+    """Splits elements of a dataset into multiple elements.
+
+    For example, if elements of the dataset are shaped `[B, a0, a1, ...]`,
+    where `B` may vary for each input element, then for each element in the
+    dataset, the unbatched dataset will contain `B` consecutive elements
+    of shape `[a0, a1, ...]`.
+
+    ```python
+    # NOTE: The following example uses `{ ... }` to represent the contents
+    # of a dataset.
+    ds = { ['a', 'b', 'c'], ['a', 'b'], ['a', 'b', 'c', 'd'] }
+
+    ds.unbatch() == {'a', 'b', 'c', 'a', 'b', 'a', 'b', 'c', 'd'}
+    ```
+
+    Returns:
+      A `Dataset` whose elements are the elements of this dataset,
+      unbatched along their leading dimension.
+    """
+
+    # NOTE(mrry): We must ensure that any SparseTensors in `dataset`
+    # are normalized to the rank-1 dense representation, so that the
+    # sparse-oblivious unbatching logic will slice them
+    # appropriately. This leads to a somewhat inefficient re-encoding step
+    # for all SparseTensor components.
+    # TODO(mrry): Consider optimizing this in future if it turns out to be
+    # a bottleneck.
+    def normalize(arg, *rest):
+      # pylint: disable=protected-access
+      if rest:
+        return self._element_structure._to_batched_tensor_list((arg,) + rest)
+      else:
+        return self._element_structure._to_batched_tensor_list(arg)
+
+    normalized_dataset = self.map(normalize)
+
+    # NOTE(mrry): Our `map()` has lost information about the sparseness
+    # of any SparseTensor components, so re-apply the structure of the
+    # original dataset.
+    restructured_dataset = _RestructuredDataset(
+        normalized_dataset,
+        get_legacy_output_types(self),
+        get_legacy_output_shapes(self),
+        get_legacy_output_classes(self),
+        allow_unsafe_cast=True)
+    return _UnbatchDataset(restructured_dataset)
+
   def with_options(self, options):
     """Returns a new `tf.data.Dataset` with the given options set.
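For reviewers who want to try the promoted API, here is a minimal usage sketch. It is not part of the patch, assumes a TensorFlow build that includes this change, and the variable names are illustrative:

```python
import tensorflow as tf

# The promoted core method: each element of shape [B, ...] is split into
# B consecutive elements of shape [...].
ds = tf.data.Dataset.from_tensor_slices([[1, 2, 3], [4, 5, 6]])  # two [3] vectors
ds = ds.unbatch()  # six scalar elements: 1, 2, 3, 4, 5, 6

# The experimental spelling still works, but after this change it simply
# forwards to `Dataset.unbatch()` and emits a deprecation warning.
ds_old = tf.data.Dataset.from_tensors([0, 1, 2, 3]).apply(
    tf.data.experimental.unbatch())
```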
@@ -3556,3 +3604,120 @@ class _PrivateThreadPoolDataset(UnaryUnchangedStructureDataset):
         **flat_structure(self))
     super(_PrivateThreadPoolDataset, self).__init__(input_dataset,
                                                     variant_tensor)
+
+
+class _RestructuredDataset(UnaryDataset):
+  """An internal helper for changing the structure and shape of a dataset."""
+
+  def __init__(self,
+               dataset,
+               output_types,
+               output_shapes=None,
+               output_classes=None,
+               allow_unsafe_cast=False):
+    """Creates a new dataset with the given output types and shapes.
+
+    The given `dataset` must have a structure that is convertible:
+    * `dataset.output_types` must be the same as `output_types` modulo nesting.
+    * Each shape in `dataset.output_shapes` must be compatible with each shape
+      in `output_shapes` (if given).
+
+    Note: This helper permits "unsafe casts" for shapes, equivalent to using
+    `tf.Tensor.set_shape()` where domain-specific knowledge is available.
+
+    Args:
+      dataset: A `Dataset` object.
+      output_types: A nested structure of `tf.DType` objects.
+      output_shapes: (Optional.) A nested structure of `tf.TensorShape` objects.
+        If omitted, the shapes will be inherited from `dataset`.
+      output_classes: (Optional.) A nested structure of class types. If omitted,
+        the class types will be inherited from `dataset`.
+      allow_unsafe_cast: (Optional.) If `True`, the caller may switch the
+        reported output types and shapes of the restructured dataset, e.g. to
+        switch a sparse tensor represented as `tf.variant` to its user-visible
+        type and shape.
+
+    Raises:
+      ValueError: If either `output_types` or `output_shapes` is not compatible
+        with the structure of `dataset`.
+    """
+    self._input_dataset = dataset
+
+    input_types = get_legacy_output_types(dataset)
+    if not allow_unsafe_cast:
+      # Validate that the types are compatible.
+      output_types = nest.map_structure(dtypes.as_dtype, output_types)
+      flat_original_types = nest.flatten(input_types)
+      flat_new_types = nest.flatten(output_types)
+      if flat_original_types != flat_new_types:
+        raise ValueError(
+            "Dataset with output types %r cannot be restructured to have "
+            "output types %r" %
+            (get_legacy_output_types(dataset), output_types))
+
+    input_shapes = get_legacy_output_shapes(dataset)
+    if output_shapes is None:
+      # Inherit shapes from the original `dataset`.
+      output_shapes = nest.pack_sequence_as(
+          output_types, nest.flatten(input_shapes))
+    else:
+      if not allow_unsafe_cast:
+        # Validate that the shapes are compatible.
+        nest.assert_same_structure(output_types, output_shapes)
+        flat_original_shapes = nest.flatten(input_shapes)
+        flat_new_shapes = nest.flatten_up_to(output_types, output_shapes)
+
+        for original_shape, new_shape in zip(flat_original_shapes,
+                                             flat_new_shapes):
+          if not original_shape.is_compatible_with(new_shape):
+            raise ValueError(
+                "Dataset with output shapes %r cannot be restructured to have "
+                "incompatible output shapes %r" % (input_shapes,
+                                                   output_shapes))
+      output_shapes = nest.map_structure_up_to(
+          output_types, tensor_shape.as_shape, output_shapes)
+
+    input_classes = get_legacy_output_classes(dataset)
+    if output_classes is None:
+      # Inherit class types from the original `dataset`.
+ output_classes = nest.pack_sequence_as( + output_types, nest.flatten(input_classes)) + + self._structure = structure_lib.convert_legacy_structure( + output_types, output_shapes, output_classes) + variant_tensor = self._input_dataset._variant_tensor # pylint: disable=protected-access + super(_RestructuredDataset, self).__init__(dataset, variant_tensor) + + @property + def _element_structure(self): + return self._structure + + +class _UnbatchDataset(UnaryDataset): + """A dataset that splits the elements of its input into multiple elements.""" + + def __init__(self, input_dataset): + """See `unbatch()` for more details.""" + input_shapes = get_legacy_output_shapes(input_dataset) + flat_shapes = nest.flatten(input_shapes) + if any(s.ndims == 0 for s in flat_shapes): + raise ValueError("Cannot unbatch an input with scalar components.") + known_batch_dim = tensor_shape.Dimension(None) + for s in flat_shapes: + try: + known_batch_dim = known_batch_dim.merge_with(s[0]) + except ValueError: + raise ValueError("Cannot unbatch an input whose components have " + "different batch sizes.") + self._input_dataset = input_dataset + + self._structure = get_structure(input_dataset)._unbatch() # pylint: disable=protected-access + + variant_tensor = ged_ops.experimental_unbatch_dataset( + self._input_dataset._variant_tensor, # pylint: disable=protected-access + **flat_structure(self)) + super(_UnbatchDataset, self).__init__(input_dataset, variant_tensor) + + @property + def _element_structure(self): + return self._structure diff --git a/tensorflow/tools/api/golden/v1/tensorflow.data.-dataset.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.data.-dataset.pbtxt index 94ffbca003f..cedf443100c 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.data.-dataset.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.data.-dataset.pbtxt @@ -133,6 +133,10 @@ tf_class { name: "take" argspec: "args=[\'self\', \'count\'], varargs=None, keywords=None, defaults=None" } + member_method { + name: "unbatch" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } member_method { name: "window" argspec: "args=[\'self\', \'size\', \'shift\', \'stride\', \'drop_remainder\'], varargs=None, keywords=None, defaults=[\'None\', \'1\', \'False\'], " diff --git a/tensorflow/tools/api/golden/v1/tensorflow.data.-fixed-length-record-dataset.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.data.-fixed-length-record-dataset.pbtxt index 0ed2d44e551..de0c0d4cfe6 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.data.-fixed-length-record-dataset.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.data.-fixed-length-record-dataset.pbtxt @@ -135,6 +135,10 @@ tf_class { name: "take" argspec: "args=[\'self\', \'count\'], varargs=None, keywords=None, defaults=None" } + member_method { + name: "unbatch" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } member_method { name: "window" argspec: "args=[\'self\', \'size\', \'shift\', \'stride\', \'drop_remainder\'], varargs=None, keywords=None, defaults=[\'None\', \'1\', \'False\'], " diff --git a/tensorflow/tools/api/golden/v1/tensorflow.data.-t-f-record-dataset.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.data.-t-f-record-dataset.pbtxt index 60f7e1f4c72..9a0ca3467ed 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.data.-t-f-record-dataset.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.data.-t-f-record-dataset.pbtxt @@ -135,6 +135,10 @@ tf_class { name: "take" argspec: "args=[\'self\', \'count\'], varargs=None, keywords=None, 
defaults=None" } + member_method { + name: "unbatch" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } member_method { name: "window" argspec: "args=[\'self\', \'size\', \'shift\', \'stride\', \'drop_remainder\'], varargs=None, keywords=None, defaults=[\'None\', \'1\', \'False\'], " diff --git a/tensorflow/tools/api/golden/v1/tensorflow.data.-text-line-dataset.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.data.-text-line-dataset.pbtxt index d335061158d..4a3ac523fe6 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.data.-text-line-dataset.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.data.-text-line-dataset.pbtxt @@ -135,6 +135,10 @@ tf_class { name: "take" argspec: "args=[\'self\', \'count\'], varargs=None, keywords=None, defaults=None" } + member_method { + name: "unbatch" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } member_method { name: "window" argspec: "args=[\'self\', \'size\', \'shift\', \'stride\', \'drop_remainder\'], varargs=None, keywords=None, defaults=[\'None\', \'1\', \'False\'], " diff --git a/tensorflow/tools/api/golden/v1/tensorflow.data.experimental.-csv-dataset.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.data.experimental.-csv-dataset.pbtxt index 39431952268..6bdfd7c818d 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.data.experimental.-csv-dataset.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.data.experimental.-csv-dataset.pbtxt @@ -135,6 +135,10 @@ tf_class { name: "take" argspec: "args=[\'self\', \'count\'], varargs=None, keywords=None, defaults=None" } + member_method { + name: "unbatch" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } member_method { name: "window" argspec: "args=[\'self\', \'size\', \'shift\', \'stride\', \'drop_remainder\'], varargs=None, keywords=None, defaults=[\'None\', \'1\', \'False\'], " diff --git a/tensorflow/tools/api/golden/v1/tensorflow.data.experimental.-random-dataset.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.data.experimental.-random-dataset.pbtxt index 6221aaa0b0d..e34175ca9a9 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.data.experimental.-random-dataset.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.data.experimental.-random-dataset.pbtxt @@ -135,6 +135,10 @@ tf_class { name: "take" argspec: "args=[\'self\', \'count\'], varargs=None, keywords=None, defaults=None" } + member_method { + name: "unbatch" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } member_method { name: "window" argspec: "args=[\'self\', \'size\', \'shift\', \'stride\', \'drop_remainder\'], varargs=None, keywords=None, defaults=[\'None\', \'1\', \'False\'], " diff --git a/tensorflow/tools/api/golden/v1/tensorflow.data.experimental.-sql-dataset.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.data.experimental.-sql-dataset.pbtxt index d1903301787..bb7b40e9f6f 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.data.experimental.-sql-dataset.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.data.experimental.-sql-dataset.pbtxt @@ -135,6 +135,10 @@ tf_class { name: "take" argspec: "args=[\'self\', \'count\'], varargs=None, keywords=None, defaults=None" } + member_method { + name: "unbatch" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } member_method { name: "window" argspec: "args=[\'self\', \'size\', \'shift\', \'stride\', \'drop_remainder\'], varargs=None, keywords=None, defaults=[\'None\', \'1\', \'False\'], " diff --git 
a/tensorflow/tools/api/golden/v2/tensorflow.data.-dataset.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.data.-dataset.pbtxt index bb56967c18a..f85436d45d0 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.data.-dataset.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.data.-dataset.pbtxt @@ -100,6 +100,10 @@ tf_class { name: "take" argspec: "args=[\'self\', \'count\'], varargs=None, keywords=None, defaults=None" } + member_method { + name: "unbatch" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } member_method { name: "window" argspec: "args=[\'self\', \'size\', \'shift\', \'stride\', \'drop_remainder\'], varargs=None, keywords=None, defaults=[\'None\', \'1\', \'False\'], " diff --git a/tensorflow/tools/api/golden/v2/tensorflow.data.-fixed-length-record-dataset.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.data.-fixed-length-record-dataset.pbtxt index 597c5bce102..2f9c7de6f79 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.data.-fixed-length-record-dataset.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.data.-fixed-length-record-dataset.pbtxt @@ -102,6 +102,10 @@ tf_class { name: "take" argspec: "args=[\'self\', \'count\'], varargs=None, keywords=None, defaults=None" } + member_method { + name: "unbatch" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } member_method { name: "window" argspec: "args=[\'self\', \'size\', \'shift\', \'stride\', \'drop_remainder\'], varargs=None, keywords=None, defaults=[\'None\', \'1\', \'False\'], " diff --git a/tensorflow/tools/api/golden/v2/tensorflow.data.-t-f-record-dataset.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.data.-t-f-record-dataset.pbtxt index c24bac5bd95..a09cff47376 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.data.-t-f-record-dataset.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.data.-t-f-record-dataset.pbtxt @@ -101,6 +101,10 @@ tf_class { name: "take" argspec: "args=[\'self\', \'count\'], varargs=None, keywords=None, defaults=None" } + member_method { + name: "unbatch" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } member_method { name: "window" argspec: "args=[\'self\', \'size\', \'shift\', \'stride\', \'drop_remainder\'], varargs=None, keywords=None, defaults=[\'None\', \'1\', \'False\'], " diff --git a/tensorflow/tools/api/golden/v2/tensorflow.data.-text-line-dataset.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.data.-text-line-dataset.pbtxt index 8946cecfc83..1e3fb4a7010 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.data.-text-line-dataset.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.data.-text-line-dataset.pbtxt @@ -102,6 +102,10 @@ tf_class { name: "take" argspec: "args=[\'self\', \'count\'], varargs=None, keywords=None, defaults=None" } + member_method { + name: "unbatch" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } member_method { name: "window" argspec: "args=[\'self\', \'size\', \'shift\', \'stride\', \'drop_remainder\'], varargs=None, keywords=None, defaults=[\'None\', \'1\', \'False\'], " diff --git a/tensorflow/tools/api/golden/v2/tensorflow.data.experimental.-csv-dataset.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.data.experimental.-csv-dataset.pbtxt index 2365c62a61c..38447355e7f 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.data.experimental.-csv-dataset.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.data.experimental.-csv-dataset.pbtxt @@ -102,6 +102,10 @@ tf_class { name: "take" argspec: "args=[\'self\', 
\'count\'], varargs=None, keywords=None, defaults=None" } + member_method { + name: "unbatch" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } member_method { name: "window" argspec: "args=[\'self\', \'size\', \'shift\', \'stride\', \'drop_remainder\'], varargs=None, keywords=None, defaults=[\'None\', \'1\', \'False\'], " diff --git a/tensorflow/tools/api/golden/v2/tensorflow.data.experimental.-random-dataset.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.data.experimental.-random-dataset.pbtxt index af008c6ad5b..3fa00204d59 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.data.experimental.-random-dataset.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.data.experimental.-random-dataset.pbtxt @@ -102,6 +102,10 @@ tf_class { name: "take" argspec: "args=[\'self\', \'count\'], varargs=None, keywords=None, defaults=None" } + member_method { + name: "unbatch" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } member_method { name: "window" argspec: "args=[\'self\', \'size\', \'shift\', \'stride\', \'drop_remainder\'], varargs=None, keywords=None, defaults=[\'None\', \'1\', \'False\'], " diff --git a/tensorflow/tools/api/golden/v2/tensorflow.data.experimental.-sql-dataset.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.data.experimental.-sql-dataset.pbtxt index 34370adc7da..549f5da506b 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.data.experimental.-sql-dataset.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.data.experimental.-sql-dataset.pbtxt @@ -102,6 +102,10 @@ tf_class { name: "take" argspec: "args=[\'self\', \'count\'], varargs=None, keywords=None, defaults=None" } + member_method { + name: "unbatch" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } member_method { name: "window" argspec: "args=[\'self\', \'size\', \'shift\', \'stride\', \'drop_remainder\'], varargs=None, keywords=None, defaults=[\'None\', \'1\', \'False\'], "
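As a closing illustration of the shape validation that now lives in `dataset_ops._UnbatchDataset`, the sketch below reproduces the two error cases exercised by the kernel tests above. It is not part of the patch and assumes a TensorFlow build that includes this change:

```python
import numpy as np
import tensorflow as tf

# Static leading dimensions that disagree (7, 8, 9) cannot be merged into a
# single batch size, so unbatch() raises when the dataset is constructed
# (cf. testUnbatchStaticShapeMismatch).
data = tf.data.Dataset.from_tensors((np.arange(7), np.arange(8), np.arange(9)))
try:
  data.unbatch()
except ValueError as e:
  print(e)  # Cannot unbatch an input whose components have different batch sizes.

# Scalar components have no batch dimension to split along, so this raises too.
scalars = tf.data.Dataset.from_tensors((1, 2, 3))
try:
  scalars.unbatch()
except ValueError as e:
  print(e)  # Cannot unbatch an input with scalar components.
```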