[tf.data] Promoting unbatch from experimental to core API.

PiperOrigin-RevId: 248960845
Jiri Simsa 2019-05-19 15:18:17 -07:00 committed by TensorFlower Gardener
parent ee4657facf
commit 575ab84dab
21 changed files with 267 additions and 193 deletions

View File

@@ -219,7 +219,7 @@ def assert_element_shape(expected_shapes):
     output_shapes = _merge_output_shapes(
         dataset_ops.get_legacy_output_shapes(dataset), expected_shapes)
     # pylint: disable=protected-access
-    return batching._RestructuredDataset(
+    return dataset_ops._RestructuredDataset(
         dataset.map(_check_shape),
         dataset_ops.get_legacy_output_types(dataset),
         output_shapes=output_shapes,

View File

@@ -726,33 +726,6 @@ py_test(
     ],
 )
 
-py_test(
-    name = "unbatch_test",
-    size = "medium",
-    srcs = ["unbatch_test.py"],
-    python_version = "PY2",
-    deps = [
-        "//tensorflow/python:array_ops",
-        "//tensorflow/python:client_testlib",
-        "//tensorflow/python:constant_op",
-        "//tensorflow/python:dtypes",
-        "//tensorflow/python:errors",
-        "//tensorflow/python:framework_ops",
-        "//tensorflow/python:math_ops",
-        "//tensorflow/python:session",
-        "//tensorflow/python:sparse_tensor",
-        "//tensorflow/python:string_ops",
-        "//tensorflow/python:tensor_shape",
-        "//tensorflow/python:util",
-        "//tensorflow/python/data/experimental/ops:batching",
-        "//tensorflow/python/data/kernel_tests:test_base",
-        "//tensorflow/python/data/ops:dataset_ops",
-        "//tensorflow/python/ops/ragged",
-        "//third_party/py/numpy",
-        "@absl_py//absl/testing:parameterized",
-    ],
-)
-
 py_test(
     name = "unique_test",
     size = "small",

View File

@@ -17,7 +17,6 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
-from tensorflow.python.data.experimental.ops import batching
 from tensorflow.python.data.kernel_tests import test_base
 from tensorflow.python.data.ops import dataset_ops
 from tensorflow.python.data.util import nest
@@ -46,7 +45,8 @@ class RestructuredDatasetTest(test_base.DatasetTestBase):
     for new_types, new_shape_lists in test_cases:
       # pylint: disable=protected-access
-      new = batching._RestructuredDataset(dataset, new_types, new_shape_lists)
+      new = dataset_ops._RestructuredDataset(dataset, new_types,
+                                             new_shape_lists)
       # pylint: enable=protected-access
       self.assertEqual(new_types, dataset_ops.get_legacy_output_types(new))
       if new_shape_lists is not None:
@@ -67,7 +67,8 @@ class RestructuredDatasetTest(test_base.DatasetTestBase):
     for new_types, new_shape_lists in fail_cases:
       with self.assertRaises(ValueError):
         # pylint: disable=protected-access
-        new = batching._RestructuredDataset(dataset, new_types, new_shape_lists)
+        new = dataset_ops._RestructuredDataset(dataset, new_types,
+                                               new_shape_lists)
         # pylint: enable=protected-access

View File

@@ -19,7 +19,6 @@ from __future__ import print_function
 
 from tensorflow.python.data.ops import dataset_ops
 from tensorflow.python.data.util import convert
-from tensorflow.python.data.util import nest
 from tensorflow.python.data.util import structure
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import ops
@@ -199,6 +198,7 @@ def map_and_batch(map_func,
   return _apply_fn
 
 
+@deprecation.deprecated(None, "Use `tf.data.Dataset.unbatch()`.")
 @tf_export("data.experimental.unbatch")
 def unbatch():
   """Splits elements of a dataset into multiple elements on the batch dimension.
@@ -223,34 +223,7 @@ def unbatch():
   """
 
   def _apply_fn(dataset):
-    """Function from `Dataset` to `Dataset` that applies the transformation."""
-    # NOTE(mrry): We must ensure that any SparseTensors in `dataset`
-    # are normalized to the rank-1 dense representation, so that the
-    # sparse-oblivious unbatching logic will slice them
-    # appropriately. This leads to a somewhat inefficient re-encoding step
-    # for all SparseTensor components.
-    # TODO(mrry): Consider optimizing this in future if it turns out to be
-    # a bottleneck.
-    def normalize(arg, *rest):
-      # pylint: disable=protected-access
-      if rest:
-        return dataset._element_structure._to_batched_tensor_list((arg,) + rest)
-      else:
-        return dataset._element_structure._to_batched_tensor_list(arg)
-
-    normalized_dataset = dataset.map(normalize)
-
-    # NOTE(mrry): Our `map()` has lost information about the sparseness
-    # of any SparseTensor components, so re-apply the structure of the
-    # original dataset.
-    restructured_dataset = _RestructuredDataset(
-        normalized_dataset,
-        dataset_ops.get_legacy_output_types(dataset),
-        dataset_ops.get_legacy_output_shapes(dataset),
-        dataset_ops.get_legacy_output_classes(dataset),
-        allow_unsafe_cast=True)
-    return _UnbatchDataset(restructured_dataset)
+    return dataset.unbatch()
 
   return _apply_fn
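
After this change, the experimental transformation is a thin deprecated wrapper around the new core method. A minimal migration sketch (the dataset contents here are illustrative):

```python
import tensorflow as tf

ds = tf.data.Dataset.from_tensors([1, 2, 3])

# Deprecated spelling, still works via the wrapper above:
legacy = ds.apply(tf.data.experimental.unbatch())

# New core spelling, equivalent behavior:
modern = ds.unbatch()  # yields 1, 2, 3 as separate elements
```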
@@ -330,120 +303,3 @@ class _MapAndBatchDataset(dataset_ops.UnaryDataset):
   @property
   def _element_structure(self):
     return self._structure
-
-
-class _RestructuredDataset(dataset_ops.UnaryDataset):
-  """An internal helper for changing the structure and shape of a dataset."""
-
-  def __init__(self,
-               dataset,
-               output_types,
-               output_shapes=None,
-               output_classes=None,
-               allow_unsafe_cast=False):
-    """Creates a new dataset with the given output types and shapes.
-
-    The given `dataset` must have a structure that is convertible:
-    * `dataset.output_types` must be the same as `output_types` module nesting.
-    * Each shape in `dataset.output_shapes` must be compatible with each shape
-      in `output_shapes` (if given).
-
-    Note: This helper permits "unsafe casts" for shapes, equivalent to using
-    `tf.Tensor.set_shape()` where domain-specific knowledge is available.
-
-    Args:
-      dataset: A `Dataset` object.
-      output_types: A nested structure of `tf.DType` objects.
-      output_shapes: (Optional.) A nested structure of `tf.TensorShape` objects.
-        If omitted, the shapes will be inherited from `dataset`.
-      output_classes: (Optional.) A nested structure of class types. If omitted,
-        the class types will be inherited from `dataset`.
-      allow_unsafe_cast: (Optional.) If `True`, the caller may switch the
-        reported output types and shapes of the restructured dataset, e.g. to
-        switch a sparse tensor represented as `tf.variant` to its user-visible
-        type and shape.
-
-    Raises:
-      ValueError: If either `output_types` or `output_shapes` is not compatible
-        with the structure of `dataset`.
-    """
-    self._input_dataset = dataset
-
-    input_types = dataset_ops.get_legacy_output_types(dataset)
-    if not allow_unsafe_cast:
-      # Validate that the types are compatible.
-      output_types = nest.map_structure(dtypes.as_dtype, output_types)
-      flat_original_types = nest.flatten(input_types)
-      flat_new_types = nest.flatten(output_types)
-      if flat_original_types != flat_new_types:
-        raise ValueError(
-            "Dataset with output types %r cannot be restructured to have "
-            "output types %r" %
-            (dataset_ops.get_legacy_output_types(dataset), output_types))
-
-    input_shapes = dataset_ops.get_legacy_output_shapes(dataset)
-    if output_shapes is None:
-      # Inherit shapes from the original `dataset`.
-      output_shapes = nest.pack_sequence_as(
-          output_types, nest.flatten(input_shapes))
-    else:
-      if not allow_unsafe_cast:
-        # Validate that the shapes are compatible.
-        nest.assert_same_structure(output_types, output_shapes)
-        flat_original_shapes = nest.flatten(input_shapes)
-        flat_new_shapes = nest.flatten_up_to(output_types, output_shapes)
-
-        for original_shape, new_shape in zip(flat_original_shapes,
-                                             flat_new_shapes):
-          if not original_shape.is_compatible_with(new_shape):
-            raise ValueError(
-                "Dataset with output shapes %r cannot be restructured to have "
-                "incompatible output shapes %r" % (input_shapes,
-                                                   output_shapes))
-      output_shapes = nest.map_structure_up_to(
-          output_types, tensor_shape.as_shape, output_shapes)
-
-    input_classes = dataset_ops.get_legacy_output_classes(dataset)
-    if output_classes is None:
-      # Inherit class types from the original `dataset`.
-      output_classes = nest.pack_sequence_as(
-          output_types, nest.flatten(input_classes))
-
-    self._structure = structure.convert_legacy_structure(
-        output_types, output_shapes, output_classes)
-    variant_tensor = self._input_dataset._variant_tensor  # pylint: disable=protected-access
-    super(_RestructuredDataset, self).__init__(dataset, variant_tensor)
-
-  @property
-  def _element_structure(self):
-    return self._structure
-
-
-class _UnbatchDataset(dataset_ops.UnaryDataset):
-  """A dataset that splits the elements of its input into multiple elements."""
-
-  def __init__(self, input_dataset):
-    """See `unbatch()` for more details."""
-    input_shapes = dataset_ops.get_legacy_output_shapes(input_dataset)
-    flat_shapes = nest.flatten(input_shapes)
-    if any(s.ndims == 0 for s in flat_shapes):
-      raise ValueError("Cannot unbatch an input with scalar components.")
-    known_batch_dim = tensor_shape.Dimension(None)
-    for s in flat_shapes:
-      try:
-        known_batch_dim = known_batch_dim.merge_with(s[0])
-      except ValueError:
-        raise ValueError("Cannot unbatch an input whose components have "
-                         "different batch sizes.")
-    self._input_dataset = input_dataset
-
-    self._structure = dataset_ops.get_structure(input_dataset)._unbatch()  # pylint: disable=protected-access
-
-    variant_tensor = ged_ops.experimental_unbatch_dataset(
-        self._input_dataset._variant_tensor,  # pylint: disable=protected-access
-        **dataset_ops.flat_structure(self))
-    super(_UnbatchDataset, self).__init__(input_dataset, variant_tensor)
-
-  @property
-  def _element_structure(self):
-    return self._structure

View File

@@ -715,6 +715,31 @@ py_library(
     ],
 )
 
+tf_py_test(
+    name = "unbatch_test",
+    size = "medium",
+    srcs = ["unbatch_test.py"],
+    additional_deps = [
+        "@absl_py//absl/testing:parameterized",
+        "//third_party/py/numpy",
+        "//tensorflow/python:array_ops",
+        "//tensorflow/python:client_testlib",
+        "//tensorflow/python:constant_op",
+        "//tensorflow/python:dtypes",
+        "//tensorflow/python:errors",
+        "//tensorflow/python:framework_ops",
+        "//tensorflow/python:math_ops",
+        "//tensorflow/python:session",
+        "//tensorflow/python:sparse_tensor",
+        "//tensorflow/python:string_ops",
+        "//tensorflow/python:tensor_shape",
+        "//tensorflow/python:util",
+        "//tensorflow/python/data/kernel_tests:test_base",
+        "//tensorflow/python/data/ops:dataset_ops",
+        "//tensorflow/python/ops/ragged",
+    ],
+)
+
 tf_py_test(
     name = "window_test",
     size = "medium",

View File

@@ -21,7 +21,6 @@ from __future__ import print_function
 from absl.testing import parameterized
 import numpy as np
 
-from tensorflow.python.data.experimental.ops import batching
 from tensorflow.python.data.kernel_tests import test_base
 from tensorflow.python.data.ops import dataset_ops
 from tensorflow.python.framework import constant_op
@@ -41,8 +40,7 @@ from tensorflow.python.util import compat
 class UnbatchTest(test_base.DatasetTestBase, parameterized.TestCase):
 
   def testUnbatchWithUnknownRankInput(self):
-    dataset = dataset_ops.Dataset.from_tensors([0, 1, 2,
-                                                3]).apply(batching.unbatch())
+    dataset = dataset_ops.Dataset.from_tensors([0, 1, 2, 3]).unbatch()
     self.assertDatasetProduces(dataset, range(4))
 
   def testUnbatchScalarDataset(self):
@@ -51,7 +49,7 @@ class UnbatchTest(test_base.DatasetTestBase, parameterized.TestCase):
     expected_types = (dtypes.int32,) * 3
     data = data.batch(2)
     self.assertEqual(expected_types, dataset_ops.get_legacy_output_types(data))
-    data = data.apply(batching.unbatch())
+    data = data.unbatch()
     self.assertEqual(expected_types, dataset_ops.get_legacy_output_types(data))
 
     self.assertDatasetProduces(data, [(i,) * 3 for i in range(10)])
@@ -63,7 +61,7 @@ class UnbatchTest(test_base.DatasetTestBase, parameterized.TestCase):
     expected_types = (dtypes.int32, dtypes.string, dtypes.int32)
     data = data.batch(2)
     self.assertEqual(expected_types, dataset_ops.get_legacy_output_types(data))
-    data = data.apply(batching.unbatch())
+    data = data.unbatch()
     self.assertEqual(expected_types, dataset_ops.get_legacy_output_types(data))
 
     self.assertDatasetProduces(
@@ -75,9 +73,9 @@ class UnbatchTest(test_base.DatasetTestBase, parameterized.TestCase):
         values=list(range(10)),
         dense_shape=[10, 10])
     data = dataset_ops.Dataset.from_tensors(st)
-    data = data.apply(batching.unbatch())
+    data = data.unbatch()
     data = data.batch(5)
-    data = data.apply(batching.unbatch())
+    data = data.unbatch()
     expected_output = [
         sparse_tensor.SparseTensorValue([[i]], [i], [10]) for i in range(10)
     ]
@@ -91,9 +89,9 @@ class UnbatchTest(test_base.DatasetTestBase, parameterized.TestCase):
     rt = ragged_factory_ops.constant_value([[[0]], [[1]], [[2]], [[3]], [[4]],
                                             [[5]], [[6]], [[7]], [[8]], [[9]]])
     data = dataset_ops.Dataset.from_tensors((list(range(10)), st, rt))
-    data = data.apply(batching.unbatch())
+    data = data.unbatch()
     data = data.batch(5)
-    data = data.apply(batching.unbatch())
+    data = data.unbatch()
     expected_output = [(i, sparse_tensor.SparseTensorValue([[i]], [i], [10]),
                         ragged_factory_ops.constant_value([[i]]))
                        for i in range(10)]
@@ -104,10 +102,10 @@ class UnbatchTest(test_base.DatasetTestBase, parameterized.TestCase):
     rt = ragged_factory_ops.constant_value([[[0]], [[1]], [[2]], [[3]], [[4]],
                                             [[5]], [[6]], [[7]], [[8]], [[9]]])
     data = dataset_ops.Dataset.from_tensors(rt)
-    data = data.apply(batching.unbatch())
+    data = data.unbatch()
     data = data.batch(5)
     data = data.batch(2)
-    data = data.apply(batching.unbatch())
+    data = data.unbatch()
     expected_output = [
         ragged_factory_ops.constant_value([[[0]], [[1]], [[2]], [[3]], [[4]]]),
         ragged_factory_ops.constant_value([[[5]], [[6]], [[7]], [[8]], [[9]]]),
@@ -121,7 +119,7 @@ class UnbatchTest(test_base.DatasetTestBase, parameterized.TestCase):
     expected_types = ((dtypes.int32,),) * 3
     data = data.batch(2)
     self.assertEqual(expected_types, dataset_ops.get_legacy_output_types(data))
-    data = data.apply(batching.unbatch())
+    data = data.unbatch()
     self.assertEqual(expected_types, dataset_ops.get_legacy_output_types(data))
 
     self.assertDatasetProduces(data, [((i,),) * 3 for i in range(10)])
@@ -134,7 +132,7 @@ class UnbatchTest(test_base.DatasetTestBase, parameterized.TestCase):
     data = data.batch(2)
     self.assertAllEqual(expected_types,
                         dataset_ops.get_legacy_output_types(data))
-    data = data.apply(batching.unbatch())
+    data = data.unbatch()
     self.assertAllEqual(expected_types,
                         dataset_ops.get_legacy_output_types(data))
 
@@ -146,14 +144,14 @@ class UnbatchTest(test_base.DatasetTestBase, parameterized.TestCase):
     data = dataset_ops.Dataset.from_tensors(
         (constant_op.constant([]), constant_op.constant([], shape=[0, 4]),
          constant_op.constant([], shape=[0, 4, 0])))
-    data = data.apply(batching.unbatch())
+    data = data.unbatch()
     self.assertDatasetProduces(data, [])
 
   def testUnbatchStaticShapeMismatch(self):
     data = dataset_ops.Dataset.from_tensors((np.arange(7), np.arange(8),
                                              np.arange(9)))
     with self.assertRaises(ValueError):
-      data.apply(batching.unbatch())
+      data.unbatch()
 
   # Note: dynamic shape mismatch is graph specific test.
   @test_util.run_deprecated_v1
@@ -161,7 +159,7 @@ class UnbatchTest(test_base.DatasetTestBase, parameterized.TestCase):
     ph1 = array_ops.placeholder(dtypes.int32, shape=[None])
     ph2 = array_ops.placeholder(dtypes.int32, shape=None)
     data = dataset_ops.Dataset.from_tensors((ph1, ph2))
-    data = data.apply(batching.unbatch())
+    data = data.unbatch()
     iterator = dataset_ops.make_initializable_iterator(data)
     next_element = iterator.get_next()
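
The sparse and ragged test cases above exercise the normalization path described in `unbatch()`. A minimal sketch of the sparse round trip using public APIs (values shown are illustrative):

```python
import tensorflow as tf

# One element holding a [10, 10] sparse matrix with a diagonal of 0..9.
st = tf.sparse.SparseTensor(
    indices=[[i, i] for i in range(10)],
    values=tf.range(10, dtype=tf.int64),
    dense_shape=[10, 10])

ds = tf.data.Dataset.from_tensors(st)
ds = ds.unbatch()           # 10 elements, each a rank-1 sparse tensor of shape [10]
ds = ds.batch(5).unbatch()  # re-batching and unbatching preserves sparseness
```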

View File

@@ -1452,6 +1452,54 @@ class DatasetV2(tracking_base.Trackable, composite_tensor.CompositeTensor):
           output_shapes=state_structure._flat_shapes,
           output_types=state_structure._flat_types))
 
+  def unbatch(self):
+    """Splits elements of a dataset into multiple elements.
+
+    For example, if elements of the dataset are shaped `[B, a0, a1, ...]`,
+    where `B` may vary for each input element, then for each element in the
+    dataset, the unbatched dataset will contain `B` consecutive elements
+    of shape `[a0, a1, ...]`.
+
+    ```python
+    # NOTE: The following example uses `{ ... }` to represent the contents
+    # of a dataset.
+    ds = { ['a', 'b', 'c'], ['a', 'b'], ['a', 'b', 'c', 'd'] }
+
+    ds.unbatch() == {'a', 'b', 'c', 'a', 'b', 'a', 'b', 'c', 'd'}
+    ```
+
+    Returns:
+      A `Dataset`.
+    """
+    # NOTE(mrry): We must ensure that any SparseTensors in `dataset`
+    # are normalized to the rank-1 dense representation, so that the
+    # sparse-oblivious unbatching logic will slice them
+    # appropriately. This leads to a somewhat inefficient re-encoding step
+    # for all SparseTensor components.
+    # TODO(mrry): Consider optimizing this in future if it turns out to be
+    # a bottleneck.
+    def normalize(arg, *rest):
+      # pylint: disable=protected-access
+      if rest:
+        return self._element_structure._to_batched_tensor_list((arg,) + rest)
+      else:
+        return self._element_structure._to_batched_tensor_list(arg)
+
+    normalized_dataset = self.map(normalize)
+
+    # NOTE(mrry): Our `map()` has lost information about the sparseness
+    # of any SparseTensor components, so re-apply the structure of the
+    # original dataset.
+    restructured_dataset = _RestructuredDataset(
+        normalized_dataset,
+        get_legacy_output_types(self),
+        get_legacy_output_shapes(self),
+        get_legacy_output_classes(self),
+        allow_unsafe_cast=True)
+    return _UnbatchDataset(restructured_dataset)
+
   def with_options(self, options):
     """Returns a new `tf.data.Dataset` with the given options set.
@@ -3556,3 +3604,120 @@ class _PrivateThreadPoolDataset(UnaryUnchangedStructureDataset):
         **flat_structure(self))
     super(_PrivateThreadPoolDataset, self).__init__(input_dataset,
                                                     variant_tensor)
+
+
+class _RestructuredDataset(UnaryDataset):
+  """An internal helper for changing the structure and shape of a dataset."""
+
+  def __init__(self,
+               dataset,
+               output_types,
+               output_shapes=None,
+               output_classes=None,
+               allow_unsafe_cast=False):
+    """Creates a new dataset with the given output types and shapes.
+
+    The given `dataset` must have a structure that is convertible:
+    * `dataset.output_types` must be the same as `output_types` modulo nesting.
+    * Each shape in `dataset.output_shapes` must be compatible with each shape
+      in `output_shapes` (if given).
+
+    Note: This helper permits "unsafe casts" for shapes, equivalent to using
+    `tf.Tensor.set_shape()` where domain-specific knowledge is available.
+
+    Args:
+      dataset: A `Dataset` object.
+      output_types: A nested structure of `tf.DType` objects.
+      output_shapes: (Optional.) A nested structure of `tf.TensorShape` objects.
+        If omitted, the shapes will be inherited from `dataset`.
+      output_classes: (Optional.) A nested structure of class types. If omitted,
+        the class types will be inherited from `dataset`.
+      allow_unsafe_cast: (Optional.) If `True`, the caller may switch the
+        reported output types and shapes of the restructured dataset, e.g. to
+        switch a sparse tensor represented as `tf.variant` to its user-visible
+        type and shape.
+
+    Raises:
+      ValueError: If either `output_types` or `output_shapes` is not compatible
+        with the structure of `dataset`.
+    """
+    self._input_dataset = dataset
+
+    input_types = get_legacy_output_types(dataset)
+    if not allow_unsafe_cast:
+      # Validate that the types are compatible.
+      output_types = nest.map_structure(dtypes.as_dtype, output_types)
+      flat_original_types = nest.flatten(input_types)
+      flat_new_types = nest.flatten(output_types)
+      if flat_original_types != flat_new_types:
+        raise ValueError(
+            "Dataset with output types %r cannot be restructured to have "
+            "output types %r" %
+            (get_legacy_output_types(dataset), output_types))
+
+    input_shapes = get_legacy_output_shapes(dataset)
+    if output_shapes is None:
+      # Inherit shapes from the original `dataset`.
+      output_shapes = nest.pack_sequence_as(
+          output_types, nest.flatten(input_shapes))
+    else:
+      if not allow_unsafe_cast:
+        # Validate that the shapes are compatible.
+        nest.assert_same_structure(output_types, output_shapes)
+        flat_original_shapes = nest.flatten(input_shapes)
+        flat_new_shapes = nest.flatten_up_to(output_types, output_shapes)
+
+        for original_shape, new_shape in zip(flat_original_shapes,
+                                             flat_new_shapes):
+          if not original_shape.is_compatible_with(new_shape):
+            raise ValueError(
+                "Dataset with output shapes %r cannot be restructured to have "
+                "incompatible output shapes %r" % (input_shapes,
+                                                   output_shapes))
+      output_shapes = nest.map_structure_up_to(
+          output_types, tensor_shape.as_shape, output_shapes)
+
+    input_classes = get_legacy_output_classes(dataset)
+    if output_classes is None:
+      # Inherit class types from the original `dataset`.
+      output_classes = nest.pack_sequence_as(
+          output_types, nest.flatten(input_classes))
+
+    self._structure = structure_lib.convert_legacy_structure(
+        output_types, output_shapes, output_classes)
+    variant_tensor = self._input_dataset._variant_tensor  # pylint: disable=protected-access
+    super(_RestructuredDataset, self).__init__(dataset, variant_tensor)
+
+  @property
+  def _element_structure(self):
+    return self._structure
+
+
+class _UnbatchDataset(UnaryDataset):
+  """A dataset that splits the elements of its input into multiple elements."""
+
+  def __init__(self, input_dataset):
+    """See `unbatch()` for more details."""
+    input_shapes = get_legacy_output_shapes(input_dataset)
+    flat_shapes = nest.flatten(input_shapes)
+    if any(s.ndims == 0 for s in flat_shapes):
+      raise ValueError("Cannot unbatch an input with scalar components.")
+    known_batch_dim = tensor_shape.Dimension(None)
+    for s in flat_shapes:
+      try:
+        known_batch_dim = known_batch_dim.merge_with(s[0])
+      except ValueError:
+        raise ValueError("Cannot unbatch an input whose components have "
+                         "different batch sizes.")
+    self._input_dataset = input_dataset
+
+    self._structure = get_structure(input_dataset)._unbatch()  # pylint: disable=protected-access
+
+    variant_tensor = ged_ops.experimental_unbatch_dataset(
+        self._input_dataset._variant_tensor,  # pylint: disable=protected-access
+        **flat_structure(self))
+    super(_UnbatchDataset, self).__init__(input_dataset, variant_tensor)
+
+  @property
+  def _element_structure(self):
+    return self._structure

View File

@@ -133,6 +133,10 @@ tf_class {
     name: "take"
     argspec: "args=[\'self\', \'count\'], varargs=None, keywords=None, defaults=None"
   }
+  member_method {
+    name: "unbatch"
+    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+  }
   member_method {
     name: "window"
     argspec: "args=[\'self\', \'size\', \'shift\', \'stride\', \'drop_remainder\'], varargs=None, keywords=None, defaults=[\'None\', \'1\', \'False\'], "

View File

@@ -135,6 +135,10 @@ tf_class {
     name: "take"
     argspec: "args=[\'self\', \'count\'], varargs=None, keywords=None, defaults=None"
   }
+  member_method {
+    name: "unbatch"
+    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+  }
   member_method {
     name: "window"
     argspec: "args=[\'self\', \'size\', \'shift\', \'stride\', \'drop_remainder\'], varargs=None, keywords=None, defaults=[\'None\', \'1\', \'False\'], "

View File

@@ -135,6 +135,10 @@ tf_class {
     name: "take"
     argspec: "args=[\'self\', \'count\'], varargs=None, keywords=None, defaults=None"
   }
+  member_method {
+    name: "unbatch"
+    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+  }
   member_method {
     name: "window"
     argspec: "args=[\'self\', \'size\', \'shift\', \'stride\', \'drop_remainder\'], varargs=None, keywords=None, defaults=[\'None\', \'1\', \'False\'], "

View File

@@ -135,6 +135,10 @@ tf_class {
     name: "take"
     argspec: "args=[\'self\', \'count\'], varargs=None, keywords=None, defaults=None"
  }
+  member_method {
+    name: "unbatch"
+    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+  }
   member_method {
     name: "window"
     argspec: "args=[\'self\', \'size\', \'shift\', \'stride\', \'drop_remainder\'], varargs=None, keywords=None, defaults=[\'None\', \'1\', \'False\'], "

View File

@@ -135,6 +135,10 @@ tf_class {
     name: "take"
     argspec: "args=[\'self\', \'count\'], varargs=None, keywords=None, defaults=None"
   }
+  member_method {
+    name: "unbatch"
+    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+  }
   member_method {
     name: "window"
     argspec: "args=[\'self\', \'size\', \'shift\', \'stride\', \'drop_remainder\'], varargs=None, keywords=None, defaults=[\'None\', \'1\', \'False\'], "

View File

@@ -135,6 +135,10 @@ tf_class {
     name: "take"
     argspec: "args=[\'self\', \'count\'], varargs=None, keywords=None, defaults=None"
   }
+  member_method {
+    name: "unbatch"
+    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+  }
   member_method {
     name: "window"
     argspec: "args=[\'self\', \'size\', \'shift\', \'stride\', \'drop_remainder\'], varargs=None, keywords=None, defaults=[\'None\', \'1\', \'False\'], "

View File

@@ -135,6 +135,10 @@ tf_class {
     name: "take"
     argspec: "args=[\'self\', \'count\'], varargs=None, keywords=None, defaults=None"
   }
+  member_method {
+    name: "unbatch"
+    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+  }
   member_method {
     name: "window"
     argspec: "args=[\'self\', \'size\', \'shift\', \'stride\', \'drop_remainder\'], varargs=None, keywords=None, defaults=[\'None\', \'1\', \'False\'], "

View File

@@ -100,6 +100,10 @@ tf_class {
     name: "take"
     argspec: "args=[\'self\', \'count\'], varargs=None, keywords=None, defaults=None"
   }
+  member_method {
+    name: "unbatch"
+    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+  }
   member_method {
     name: "window"
     argspec: "args=[\'self\', \'size\', \'shift\', \'stride\', \'drop_remainder\'], varargs=None, keywords=None, defaults=[\'None\', \'1\', \'False\'], "

View File

@@ -102,6 +102,10 @@ tf_class {
     name: "take"
     argspec: "args=[\'self\', \'count\'], varargs=None, keywords=None, defaults=None"
   }
+  member_method {
+    name: "unbatch"
+    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+  }
   member_method {
     name: "window"
     argspec: "args=[\'self\', \'size\', \'shift\', \'stride\', \'drop_remainder\'], varargs=None, keywords=None, defaults=[\'None\', \'1\', \'False\'], "

View File

@@ -101,6 +101,10 @@ tf_class {
     name: "take"
     argspec: "args=[\'self\', \'count\'], varargs=None, keywords=None, defaults=None"
   }
+  member_method {
+    name: "unbatch"
+    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+  }
   member_method {
     name: "window"
     argspec: "args=[\'self\', \'size\', \'shift\', \'stride\', \'drop_remainder\'], varargs=None, keywords=None, defaults=[\'None\', \'1\', \'False\'], "

View File

@@ -102,6 +102,10 @@ tf_class {
     name: "take"
     argspec: "args=[\'self\', \'count\'], varargs=None, keywords=None, defaults=None"
   }
+  member_method {
+    name: "unbatch"
+    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+  }
   member_method {
     name: "window"
     argspec: "args=[\'self\', \'size\', \'shift\', \'stride\', \'drop_remainder\'], varargs=None, keywords=None, defaults=[\'None\', \'1\', \'False\'], "

View File

@@ -102,6 +102,10 @@ tf_class {
     name: "take"
     argspec: "args=[\'self\', \'count\'], varargs=None, keywords=None, defaults=None"
   }
+  member_method {
+    name: "unbatch"
+    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+  }
   member_method {
     name: "window"
     argspec: "args=[\'self\', \'size\', \'shift\', \'stride\', \'drop_remainder\'], varargs=None, keywords=None, defaults=[\'None\', \'1\', \'False\'], "

View File

@@ -102,6 +102,10 @@ tf_class {
     name: "take"
     argspec: "args=[\'self\', \'count\'], varargs=None, keywords=None, defaults=None"
   }
+  member_method {
+    name: "unbatch"
+    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+  }
   member_method {
     name: "window"
     argspec: "args=[\'self\', \'size\', \'shift\', \'stride\', \'drop_remainder\'], varargs=None, keywords=None, defaults=[\'None\', \'1\', \'False\'], "

View File

@@ -102,6 +102,10 @@ tf_class {
     name: "take"
     argspec: "args=[\'self\', \'count\'], varargs=None, keywords=None, defaults=None"
   }
+  member_method {
+    name: "unbatch"
+    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+  }
   member_method {
     name: "window"
     argspec: "args=[\'self\', \'size\', \'shift\', \'stride\', \'drop_remainder\'], varargs=None, keywords=None, defaults=[\'None\', \'1\', \'False\'], "