[tf.data] Promoting `unbatch` from experimental to core API.

PiperOrigin-RevId: 248960845

parent ee4657facf
commit 575ab84dab
@@ -219,7 +219,7 @@ def assert_element_shape(expected_shapes):
     output_shapes = _merge_output_shapes(
         dataset_ops.get_legacy_output_shapes(dataset), expected_shapes)
     # pylint: disable=protected-access
-    return batching._RestructuredDataset(
+    return dataset_ops._RestructuredDataset(
         dataset.map(_check_shape),
         dataset_ops.get_legacy_output_types(dataset),
        output_shapes=output_shapes,
@@ -726,33 +726,6 @@ py_test(
     ],
 )
 
-py_test(
-    name = "unbatch_test",
-    size = "medium",
-    srcs = ["unbatch_test.py"],
-    python_version = "PY2",
-    deps = [
-        "//tensorflow/python:array_ops",
-        "//tensorflow/python:client_testlib",
-        "//tensorflow/python:constant_op",
-        "//tensorflow/python:dtypes",
-        "//tensorflow/python:errors",
-        "//tensorflow/python:framework_ops",
-        "//tensorflow/python:math_ops",
-        "//tensorflow/python:session",
-        "//tensorflow/python:sparse_tensor",
-        "//tensorflow/python:string_ops",
-        "//tensorflow/python:tensor_shape",
-        "//tensorflow/python:util",
-        "//tensorflow/python/data/experimental/ops:batching",
-        "//tensorflow/python/data/kernel_tests:test_base",
-        "//tensorflow/python/data/ops:dataset_ops",
-        "//tensorflow/python/ops/ragged",
-        "//third_party/py/numpy",
-        "@absl_py//absl/testing:parameterized",
-    ],
-)
-
 py_test(
     name = "unique_test",
     size = "small",
@@ -17,7 +17,6 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
-from tensorflow.python.data.experimental.ops import batching
 from tensorflow.python.data.kernel_tests import test_base
 from tensorflow.python.data.ops import dataset_ops
 from tensorflow.python.data.util import nest
@@ -46,7 +45,8 @@ class RestructuredDatasetTest(test_base.DatasetTestBase):
 
     for new_types, new_shape_lists in test_cases:
       # pylint: disable=protected-access
-      new = batching._RestructuredDataset(dataset, new_types, new_shape_lists)
+      new = dataset_ops._RestructuredDataset(dataset, new_types,
+                                             new_shape_lists)
       # pylint: enable=protected-access
       self.assertEqual(new_types, dataset_ops.get_legacy_output_types(new))
       if new_shape_lists is not None:
@@ -67,7 +67,8 @@ class RestructuredDatasetTest(test_base.DatasetTestBase):
     for new_types, new_shape_lists in fail_cases:
       with self.assertRaises(ValueError):
         # pylint: disable=protected-access
-        new = batching._RestructuredDataset(dataset, new_types, new_shape_lists)
+        new = dataset_ops._RestructuredDataset(dataset, new_types,
+                                               new_shape_lists)
         # pylint: enable=protected-access
 
 
@@ -19,7 +19,6 @@ from __future__ import print_function
 
 from tensorflow.python.data.ops import dataset_ops
 from tensorflow.python.data.util import convert
-from tensorflow.python.data.util import nest
 from tensorflow.python.data.util import structure
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import ops
@@ -199,6 +198,7 @@ def map_and_batch(map_func,
   return _apply_fn
 
 
+@deprecation.deprecated(None, "Use `tf.data.Dataset.unbatch()`.")
 @tf_export("data.experimental.unbatch")
 def unbatch():
   """Splits elements of a dataset into multiple elements on the batch dimension.
@@ -223,34 +223,7 @@ def unbatch():
   """
 
   def _apply_fn(dataset):
     """Function from `Dataset` to `Dataset` that applies the transformation."""
-
-    # NOTE(mrry): We must ensure that any SparseTensors in `dataset`
-    # are normalized to the rank-1 dense representation, so that the
-    # sparse-oblivious unbatching logic will slice them
-    # appropriately. This leads to a somewhat inefficient re-encoding step
-    # for all SparseTensor components.
-    # TODO(mrry): Consider optimizing this in future if it turns out to be
-    # a bottleneck.
-    def normalize(arg, *rest):
-      # pylint: disable=protected-access
-      if rest:
-        return dataset._element_structure._to_batched_tensor_list((arg,) + rest)
-      else:
-        return dataset._element_structure._to_batched_tensor_list(arg)
-
-    normalized_dataset = dataset.map(normalize)
-
-    # NOTE(mrry): Our `map()` has lost information about the sparseness
-    # of any SparseTensor components, so re-apply the structure of the
-    # original dataset.
-    restructured_dataset = _RestructuredDataset(
-        normalized_dataset,
-        dataset_ops.get_legacy_output_types(dataset),
-        dataset_ops.get_legacy_output_shapes(dataset),
-        dataset_ops.get_legacy_output_classes(dataset),
-        allow_unsafe_cast=True)
-    return _UnbatchDataset(restructured_dataset)
+    return dataset.unbatch()
 
   return _apply_fn
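With this change, the experimental transformation becomes a deprecated thin wrapper that forwards to the new core method. A minimal migration sketch (assuming a TensorFlow build that includes this commit; both pipelines yield the scalars 1, 2, 3, 4):

```python
import tensorflow as tf

ds = tf.data.Dataset.from_tensor_slices([[1, 2], [3, 4]])

# Before this commit: apply the experimental transformation (now deprecated).
old_style = ds.apply(tf.data.experimental.unbatch())

# After this commit: call the core method directly.
new_style = ds.unbatch()
```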
@@ -330,120 +303,3 @@ class _MapAndBatchDataset(dataset_ops.UnaryDataset):
   @property
   def _element_structure(self):
     return self._structure
-
-
-class _RestructuredDataset(dataset_ops.UnaryDataset):
-  """An internal helper for changing the structure and shape of a dataset."""
-
-  def __init__(self,
-               dataset,
-               output_types,
-               output_shapes=None,
-               output_classes=None,
-               allow_unsafe_cast=False):
-    """Creates a new dataset with the given output types and shapes.
-
-    The given `dataset` must have a structure that is convertible:
-    * `dataset.output_types` must be the same as `output_types` modulo nesting.
-    * Each shape in `dataset.output_shapes` must be compatible with each shape
-      in `output_shapes` (if given).
-
-    Note: This helper permits "unsafe casts" for shapes, equivalent to using
-    `tf.Tensor.set_shape()` where domain-specific knowledge is available.
-
-    Args:
-      dataset: A `Dataset` object.
-      output_types: A nested structure of `tf.DType` objects.
-      output_shapes: (Optional.) A nested structure of `tf.TensorShape` objects.
-        If omitted, the shapes will be inherited from `dataset`.
-      output_classes: (Optional.) A nested structure of class types. If omitted,
-        the class types will be inherited from `dataset`.
-      allow_unsafe_cast: (Optional.) If `True`, the caller may switch the
-        reported output types and shapes of the restructured dataset, e.g. to
-        switch a sparse tensor represented as `tf.variant` to its user-visible
-        type and shape.
-
-    Raises:
-      ValueError: If either `output_types` or `output_shapes` is not compatible
-        with the structure of `dataset`.
-    """
-    self._input_dataset = dataset
-
-    input_types = dataset_ops.get_legacy_output_types(dataset)
-    if not allow_unsafe_cast:
-      # Validate that the types are compatible.
-      output_types = nest.map_structure(dtypes.as_dtype, output_types)
-      flat_original_types = nest.flatten(input_types)
-      flat_new_types = nest.flatten(output_types)
-      if flat_original_types != flat_new_types:
-        raise ValueError(
-            "Dataset with output types %r cannot be restructured to have "
-            "output types %r" %
-            (dataset_ops.get_legacy_output_types(dataset), output_types))
-
-    input_shapes = dataset_ops.get_legacy_output_shapes(dataset)
-    if output_shapes is None:
-      # Inherit shapes from the original `dataset`.
-      output_shapes = nest.pack_sequence_as(
-          output_types, nest.flatten(input_shapes))
-    else:
-      if not allow_unsafe_cast:
-        # Validate that the shapes are compatible.
-        nest.assert_same_structure(output_types, output_shapes)
-        flat_original_shapes = nest.flatten(input_shapes)
-        flat_new_shapes = nest.flatten_up_to(output_types, output_shapes)
-
-        for original_shape, new_shape in zip(flat_original_shapes,
-                                             flat_new_shapes):
-          if not original_shape.is_compatible_with(new_shape):
-            raise ValueError(
-                "Dataset with output shapes %r cannot be restructured to have "
-                "incompatible output shapes %r" % (input_shapes,
-                                                   output_shapes))
-      output_shapes = nest.map_structure_up_to(
-          output_types, tensor_shape.as_shape, output_shapes)
-
-    input_classes = dataset_ops.get_legacy_output_classes(dataset)
-    if output_classes is None:
-      # Inherit class types from the original `dataset`.
-      output_classes = nest.pack_sequence_as(
-          output_types, nest.flatten(input_classes))
-
-    self._structure = structure.convert_legacy_structure(
-        output_types, output_shapes, output_classes)
-    variant_tensor = self._input_dataset._variant_tensor  # pylint: disable=protected-access
-    super(_RestructuredDataset, self).__init__(dataset, variant_tensor)
-
-  @property
-  def _element_structure(self):
-    return self._structure
-
-
-class _UnbatchDataset(dataset_ops.UnaryDataset):
-  """A dataset that splits the elements of its input into multiple elements."""
-
-  def __init__(self, input_dataset):
-    """See `unbatch()` for more details."""
-    input_shapes = dataset_ops.get_legacy_output_shapes(input_dataset)
-    flat_shapes = nest.flatten(input_shapes)
-    if any(s.ndims == 0 for s in flat_shapes):
-      raise ValueError("Cannot unbatch an input with scalar components.")
-    known_batch_dim = tensor_shape.Dimension(None)
-    for s in flat_shapes:
-      try:
-        known_batch_dim = known_batch_dim.merge_with(s[0])
-      except ValueError:
-        raise ValueError("Cannot unbatch an input whose components have "
-                         "different batch sizes.")
-    self._input_dataset = input_dataset
-
-    self._structure = dataset_ops.get_structure(input_dataset)._unbatch()  # pylint: disable=protected-access
-
-    variant_tensor = ged_ops.experimental_unbatch_dataset(
-        self._input_dataset._variant_tensor,  # pylint: disable=protected-access
-        **dataset_ops.flat_structure(self))
-    super(_UnbatchDataset, self).__init__(input_dataset, variant_tensor)
-
-  @property
-  def _element_structure(self):
-    return self._structure
@@ -715,6 +715,31 @@ py_library(
     ],
 )
 
+tf_py_test(
+    name = "unbatch_test",
+    size = "medium",
+    srcs = ["unbatch_test.py"],
+    additional_deps = [
+        "@absl_py//absl/testing:parameterized",
+        "//third_party/py/numpy",
+        "//tensorflow/python:array_ops",
+        "//tensorflow/python:client_testlib",
+        "//tensorflow/python:constant_op",
+        "//tensorflow/python:dtypes",
+        "//tensorflow/python:errors",
+        "//tensorflow/python:framework_ops",
+        "//tensorflow/python:math_ops",
+        "//tensorflow/python:session",
+        "//tensorflow/python:sparse_tensor",
+        "//tensorflow/python:string_ops",
+        "//tensorflow/python:tensor_shape",
+        "//tensorflow/python:util",
+        "//tensorflow/python/data/kernel_tests:test_base",
+        "//tensorflow/python/data/ops:dataset_ops",
+        "//tensorflow/python/ops/ragged",
+    ],
+)
+
 tf_py_test(
     name = "window_test",
     size = "medium",
@@ -21,7 +21,6 @@ from __future__ import print_function
 from absl.testing import parameterized
 import numpy as np
 
-from tensorflow.python.data.experimental.ops import batching
 from tensorflow.python.data.kernel_tests import test_base
 from tensorflow.python.data.ops import dataset_ops
 from tensorflow.python.framework import constant_op
@@ -41,8 +40,7 @@ from tensorflow.python.util import compat
 class UnbatchTest(test_base.DatasetTestBase, parameterized.TestCase):
 
   def testUnbatchWithUnknownRankInput(self):
-    dataset = dataset_ops.Dataset.from_tensors([0, 1, 2,
-                                                3]).apply(batching.unbatch())
+    dataset = dataset_ops.Dataset.from_tensors([0, 1, 2, 3]).unbatch()
     self.assertDatasetProduces(dataset, range(4))
 
   def testUnbatchScalarDataset(self):
@@ -51,7 +49,7 @@ class UnbatchTest(test_base.DatasetTestBase, parameterized.TestCase):
     expected_types = (dtypes.int32,) * 3
     data = data.batch(2)
     self.assertEqual(expected_types, dataset_ops.get_legacy_output_types(data))
-    data = data.apply(batching.unbatch())
+    data = data.unbatch()
     self.assertEqual(expected_types, dataset_ops.get_legacy_output_types(data))
 
     self.assertDatasetProduces(data, [(i,) * 3 for i in range(10)])
@@ -63,7 +61,7 @@ class UnbatchTest(test_base.DatasetTestBase, parameterized.TestCase):
     expected_types = (dtypes.int32, dtypes.string, dtypes.int32)
     data = data.batch(2)
     self.assertEqual(expected_types, dataset_ops.get_legacy_output_types(data))
-    data = data.apply(batching.unbatch())
+    data = data.unbatch()
     self.assertEqual(expected_types, dataset_ops.get_legacy_output_types(data))
 
     self.assertDatasetProduces(
@@ -75,9 +73,9 @@ class UnbatchTest(test_base.DatasetTestBase, parameterized.TestCase):
         values=list(range(10)),
         dense_shape=[10, 10])
     data = dataset_ops.Dataset.from_tensors(st)
-    data = data.apply(batching.unbatch())
+    data = data.unbatch()
     data = data.batch(5)
-    data = data.apply(batching.unbatch())
+    data = data.unbatch()
     expected_output = [
         sparse_tensor.SparseTensorValue([[i]], [i], [10]) for i in range(10)
     ]
@@ -91,9 +89,9 @@ class UnbatchTest(test_base.DatasetTestBase, parameterized.TestCase):
     rt = ragged_factory_ops.constant_value([[[0]], [[1]], [[2]], [[3]], [[4]],
                                             [[5]], [[6]], [[7]], [[8]], [[9]]])
     data = dataset_ops.Dataset.from_tensors((list(range(10)), st, rt))
-    data = data.apply(batching.unbatch())
+    data = data.unbatch()
     data = data.batch(5)
-    data = data.apply(batching.unbatch())
+    data = data.unbatch()
     expected_output = [(i, sparse_tensor.SparseTensorValue([[i]], [i], [10]),
                         ragged_factory_ops.constant_value([[i]]))
                        for i in range(10)]
@@ -104,10 +102,10 @@ class UnbatchTest(test_base.DatasetTestBase, parameterized.TestCase):
     rt = ragged_factory_ops.constant_value([[[0]], [[1]], [[2]], [[3]], [[4]],
                                             [[5]], [[6]], [[7]], [[8]], [[9]]])
     data = dataset_ops.Dataset.from_tensors(rt)
-    data = data.apply(batching.unbatch())
+    data = data.unbatch()
     data = data.batch(5)
     data = data.batch(2)
-    data = data.apply(batching.unbatch())
+    data = data.unbatch()
     expected_output = [
         ragged_factory_ops.constant_value([[[0]], [[1]], [[2]], [[3]], [[4]]]),
         ragged_factory_ops.constant_value([[[5]], [[6]], [[7]], [[8]], [[9]]]),
@@ -121,7 +119,7 @@ class UnbatchTest(test_base.DatasetTestBase, parameterized.TestCase):
     expected_types = ((dtypes.int32,),) * 3
     data = data.batch(2)
     self.assertEqual(expected_types, dataset_ops.get_legacy_output_types(data))
-    data = data.apply(batching.unbatch())
+    data = data.unbatch()
     self.assertEqual(expected_types, dataset_ops.get_legacy_output_types(data))
 
     self.assertDatasetProduces(data, [((i,),) * 3 for i in range(10)])
@@ -134,7 +132,7 @@ class UnbatchTest(test_base.DatasetTestBase, parameterized.TestCase):
     data = data.batch(2)
     self.assertAllEqual(expected_types,
                         dataset_ops.get_legacy_output_types(data))
-    data = data.apply(batching.unbatch())
+    data = data.unbatch()
     self.assertAllEqual(expected_types,
                         dataset_ops.get_legacy_output_types(data))
 
@@ -146,14 +144,14 @@ class UnbatchTest(test_base.DatasetTestBase, parameterized.TestCase):
     data = dataset_ops.Dataset.from_tensors(
         (constant_op.constant([]), constant_op.constant([], shape=[0, 4]),
          constant_op.constant([], shape=[0, 4, 0])))
-    data = data.apply(batching.unbatch())
+    data = data.unbatch()
     self.assertDatasetProduces(data, [])
 
   def testUnbatchStaticShapeMismatch(self):
     data = dataset_ops.Dataset.from_tensors((np.arange(7), np.arange(8),
                                              np.arange(9)))
     with self.assertRaises(ValueError):
-      data.apply(batching.unbatch())
+      data.unbatch()
 
   # Note: dynamic shape mismatch is graph specific test.
   @test_util.run_deprecated_v1
@@ -161,7 +159,7 @@ class UnbatchTest(test_base.DatasetTestBase, parameterized.TestCase):
     ph1 = array_ops.placeholder(dtypes.int32, shape=[None])
     ph2 = array_ops.placeholder(dtypes.int32, shape=None)
     data = dataset_ops.Dataset.from_tensors((ph1, ph2))
-    data = data.apply(batching.unbatch())
+    data = data.unbatch()
     iterator = dataset_ops.make_initializable_iterator(data)
     next_element = iterator.get_next()
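The sparse and ragged cases above exercise `unbatch` through batch/unbatch round trips. A hedged standalone sketch of the sparse round trip, outside the test harness (names are illustrative):

```python
import tensorflow as tf

# A 10x10 sparse tensor with one value per row.
st = tf.SparseTensor(
    indices=[[i, i] for i in range(10)],
    values=list(range(10)),
    dense_shape=[10, 10])

# Unbatch into ten rank-1 sparse elements, rebatch in fives, and unbatch
# again; the sparse element structure survives both round trips.
data = tf.data.Dataset.from_tensors(st).unbatch().batch(5).unbatch()
```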
@@ -1452,6 +1452,54 @@ class DatasetV2(tracking_base.Trackable, composite_tensor.CompositeTensor):
             output_shapes=state_structure._flat_shapes,
             output_types=state_structure._flat_types))
 
+  def unbatch(self):
+    """Splits elements of a dataset into multiple elements.
+
+    For example, if elements of the dataset are shaped `[B, a0, a1, ...]`,
+    where `B` may vary for each input element, then for each element in the
+    dataset, the unbatched dataset will contain `B` consecutive elements
+    of shape `[a0, a1, ...]`.
+
+    ```python
+    # NOTE: The following example uses `{ ... }` to represent the contents
+    # of a dataset.
+    ds = { ['a', 'b', 'c'], ['a', 'b'], ['a', 'b', 'c', 'd'] }
+
+    ds.unbatch() == {'a', 'b', 'c', 'a', 'b', 'a', 'b', 'c', 'd'}
+    ```
+
+    Returns:
+      A `Dataset`.
+    """
+
+    # NOTE(mrry): We must ensure that any SparseTensors in `dataset`
+    # are normalized to the rank-1 dense representation, so that the
+    # sparse-oblivious unbatching logic will slice them
+    # appropriately. This leads to a somewhat inefficient re-encoding step
+    # for all SparseTensor components.
+    # TODO(mrry): Consider optimizing this in future if it turns out to be
+    # a bottleneck.
+    def normalize(arg, *rest):
+      # pylint: disable=protected-access
+      if rest:
+        return self._element_structure._to_batched_tensor_list((arg,) + rest)
+      else:
+        return self._element_structure._to_batched_tensor_list(arg)
+
+    normalized_dataset = self.map(normalize)
+
+    # NOTE(mrry): Our `map()` has lost information about the sparseness
+    # of any SparseTensor components, so re-apply the structure of the
+    # original dataset.
+    restructured_dataset = _RestructuredDataset(
+        normalized_dataset,
+        get_legacy_output_types(self),
+        get_legacy_output_shapes(self),
+        get_legacy_output_classes(self),
+        allow_unsafe_cast=True)
+    return _UnbatchDataset(restructured_dataset)
+
   def with_options(self, options):
     """Returns a new `tf.data.Dataset` with the given options set.
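To make the `[B, a0, a1, ...]` contract concrete, here is a hedged sketch in which `B` differs per element, mirroring the docstring example; the generator-based construction is illustrative, not part of this commit:

```python
import tensorflow as tf

def gen():
  # Batch dimensions B = 3, 2, 4, as in the docstring example.
  yield [b'a', b'b', b'c']
  yield [b'a', b'b']
  yield [b'a', b'b', b'c', b'd']

ds = tf.data.Dataset.from_generator(gen, tf.string, tf.TensorShape([None]))

# unbatch() flattens the batch dimension: 3 + 2 + 4 = 9 scalar elements,
# in order a, b, c, a, b, a, b, c, d.
unbatched = ds.unbatch()
```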
@@ -3556,3 +3604,120 @@ class _PrivateThreadPoolDataset(UnaryUnchangedStructureDataset):
         **flat_structure(self))
     super(_PrivateThreadPoolDataset, self).__init__(input_dataset,
                                                     variant_tensor)
+
+
+class _RestructuredDataset(UnaryDataset):
+  """An internal helper for changing the structure and shape of a dataset."""
+
+  def __init__(self,
+               dataset,
+               output_types,
+               output_shapes=None,
+               output_classes=None,
+               allow_unsafe_cast=False):
+    """Creates a new dataset with the given output types and shapes.
+
+    The given `dataset` must have a structure that is convertible:
+    * `dataset.output_types` must be the same as `output_types` modulo nesting.
+    * Each shape in `dataset.output_shapes` must be compatible with each shape
+      in `output_shapes` (if given).
+
+    Note: This helper permits "unsafe casts" for shapes, equivalent to using
+    `tf.Tensor.set_shape()` where domain-specific knowledge is available.
+
+    Args:
+      dataset: A `Dataset` object.
+      output_types: A nested structure of `tf.DType` objects.
+      output_shapes: (Optional.) A nested structure of `tf.TensorShape` objects.
+        If omitted, the shapes will be inherited from `dataset`.
+      output_classes: (Optional.) A nested structure of class types. If omitted,
+        the class types will be inherited from `dataset`.
+      allow_unsafe_cast: (Optional.) If `True`, the caller may switch the
+        reported output types and shapes of the restructured dataset, e.g. to
+        switch a sparse tensor represented as `tf.variant` to its user-visible
+        type and shape.
+
+    Raises:
+      ValueError: If either `output_types` or `output_shapes` is not compatible
+        with the structure of `dataset`.
+    """
+    self._input_dataset = dataset
+
+    input_types = get_legacy_output_types(dataset)
+    if not allow_unsafe_cast:
+      # Validate that the types are compatible.
+      output_types = nest.map_structure(dtypes.as_dtype, output_types)
+      flat_original_types = nest.flatten(input_types)
+      flat_new_types = nest.flatten(output_types)
+      if flat_original_types != flat_new_types:
+        raise ValueError(
+            "Dataset with output types %r cannot be restructured to have "
+            "output types %r" %
+            (get_legacy_output_types(dataset), output_types))
+
+    input_shapes = get_legacy_output_shapes(dataset)
+    if output_shapes is None:
+      # Inherit shapes from the original `dataset`.
+      output_shapes = nest.pack_sequence_as(
+          output_types, nest.flatten(input_shapes))
+    else:
+      if not allow_unsafe_cast:
+        # Validate that the shapes are compatible.
+        nest.assert_same_structure(output_types, output_shapes)
+        flat_original_shapes = nest.flatten(input_shapes)
+        flat_new_shapes = nest.flatten_up_to(output_types, output_shapes)
+
+        for original_shape, new_shape in zip(flat_original_shapes,
+                                             flat_new_shapes):
+          if not original_shape.is_compatible_with(new_shape):
+            raise ValueError(
+                "Dataset with output shapes %r cannot be restructured to have "
+                "incompatible output shapes %r" % (input_shapes,
+                                                   output_shapes))
+      output_shapes = nest.map_structure_up_to(
+          output_types, tensor_shape.as_shape, output_shapes)
+
+    input_classes = get_legacy_output_classes(dataset)
+    if output_classes is None:
+      # Inherit class types from the original `dataset`.
+      output_classes = nest.pack_sequence_as(
+          output_types, nest.flatten(input_classes))
+
+    self._structure = structure_lib.convert_legacy_structure(
+        output_types, output_shapes, output_classes)
+    variant_tensor = self._input_dataset._variant_tensor  # pylint: disable=protected-access
+    super(_RestructuredDataset, self).__init__(dataset, variant_tensor)
+
+  @property
+  def _element_structure(self):
+    return self._structure
+
+
+class _UnbatchDataset(UnaryDataset):
+  """A dataset that splits the elements of its input into multiple elements."""
+
+  def __init__(self, input_dataset):
+    """See `unbatch()` for more details."""
+    input_shapes = get_legacy_output_shapes(input_dataset)
+    flat_shapes = nest.flatten(input_shapes)
+    if any(s.ndims == 0 for s in flat_shapes):
+      raise ValueError("Cannot unbatch an input with scalar components.")
+    known_batch_dim = tensor_shape.Dimension(None)
+    for s in flat_shapes:
+      try:
+        known_batch_dim = known_batch_dim.merge_with(s[0])
+      except ValueError:
+        raise ValueError("Cannot unbatch an input whose components have "
+                         "different batch sizes.")
+    self._input_dataset = input_dataset
+
+    self._structure = get_structure(input_dataset)._unbatch()  # pylint: disable=protected-access
+
+    variant_tensor = ged_ops.experimental_unbatch_dataset(
+        self._input_dataset._variant_tensor,  # pylint: disable=protected-access
+        **flat_structure(self))
+    super(_UnbatchDataset, self).__init__(input_dataset, variant_tensor)
+
+  @property
+  def _element_structure(self):
+    return self._structure
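As the `_UnbatchDataset` constructor above shows, unbatching rejects inputs with scalar components and inputs whose components disagree on the static batch dimension. A small sketch of both failure modes through the public method (comments paraphrase the checks above):

```python
import tensorflow as tf

# A scalar element has no batch dimension to split along.
try:
  tf.data.Dataset.from_tensors(0).unbatch()
except ValueError:
  pass  # Scalar components cannot be unbatched.

# Components with static batch dimensions 7 and 8 cannot be unbatched
# together.
try:
  tf.data.Dataset.from_tensors(([1] * 7, [2.0] * 8)).unbatch()
except ValueError:
  pass  # Components have different batch sizes.
```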
@@ -133,6 +133,10 @@ tf_class {
     name: "take"
     argspec: "args=[\'self\', \'count\'], varargs=None, keywords=None, defaults=None"
   }
+  member_method {
+    name: "unbatch"
+    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+  }
   member_method {
     name: "window"
     argspec: "args=[\'self\', \'size\', \'shift\', \'stride\', \'drop_remainder\'], varargs=None, keywords=None, defaults=[\'None\', \'1\', \'False\'], "
@@ -135,6 +135,10 @@ tf_class {
     name: "take"
     argspec: "args=[\'self\', \'count\'], varargs=None, keywords=None, defaults=None"
   }
+  member_method {
+    name: "unbatch"
+    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+  }
   member_method {
     name: "window"
     argspec: "args=[\'self\', \'size\', \'shift\', \'stride\', \'drop_remainder\'], varargs=None, keywords=None, defaults=[\'None\', \'1\', \'False\'], "
@@ -135,6 +135,10 @@ tf_class {
     name: "take"
     argspec: "args=[\'self\', \'count\'], varargs=None, keywords=None, defaults=None"
   }
+  member_method {
+    name: "unbatch"
+    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+  }
   member_method {
     name: "window"
     argspec: "args=[\'self\', \'size\', \'shift\', \'stride\', \'drop_remainder\'], varargs=None, keywords=None, defaults=[\'None\', \'1\', \'False\'], "
@@ -135,6 +135,10 @@ tf_class {
     name: "take"
     argspec: "args=[\'self\', \'count\'], varargs=None, keywords=None, defaults=None"
   }
+  member_method {
+    name: "unbatch"
+    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+  }
   member_method {
     name: "window"
     argspec: "args=[\'self\', \'size\', \'shift\', \'stride\', \'drop_remainder\'], varargs=None, keywords=None, defaults=[\'None\', \'1\', \'False\'], "
@@ -135,6 +135,10 @@ tf_class {
     name: "take"
     argspec: "args=[\'self\', \'count\'], varargs=None, keywords=None, defaults=None"
   }
+  member_method {
+    name: "unbatch"
+    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+  }
   member_method {
     name: "window"
     argspec: "args=[\'self\', \'size\', \'shift\', \'stride\', \'drop_remainder\'], varargs=None, keywords=None, defaults=[\'None\', \'1\', \'False\'], "
@@ -135,6 +135,10 @@ tf_class {
     name: "take"
     argspec: "args=[\'self\', \'count\'], varargs=None, keywords=None, defaults=None"
   }
+  member_method {
+    name: "unbatch"
+    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+  }
   member_method {
     name: "window"
     argspec: "args=[\'self\', \'size\', \'shift\', \'stride\', \'drop_remainder\'], varargs=None, keywords=None, defaults=[\'None\', \'1\', \'False\'], "
@@ -135,6 +135,10 @@ tf_class {
     name: "take"
     argspec: "args=[\'self\', \'count\'], varargs=None, keywords=None, defaults=None"
   }
+  member_method {
+    name: "unbatch"
+    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+  }
   member_method {
     name: "window"
     argspec: "args=[\'self\', \'size\', \'shift\', \'stride\', \'drop_remainder\'], varargs=None, keywords=None, defaults=[\'None\', \'1\', \'False\'], "
@@ -100,6 +100,10 @@ tf_class {
     name: "take"
     argspec: "args=[\'self\', \'count\'], varargs=None, keywords=None, defaults=None"
   }
+  member_method {
+    name: "unbatch"
+    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+  }
   member_method {
     name: "window"
     argspec: "args=[\'self\', \'size\', \'shift\', \'stride\', \'drop_remainder\'], varargs=None, keywords=None, defaults=[\'None\', \'1\', \'False\'], "
@@ -102,6 +102,10 @@ tf_class {
     name: "take"
     argspec: "args=[\'self\', \'count\'], varargs=None, keywords=None, defaults=None"
   }
+  member_method {
+    name: "unbatch"
+    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+  }
   member_method {
     name: "window"
     argspec: "args=[\'self\', \'size\', \'shift\', \'stride\', \'drop_remainder\'], varargs=None, keywords=None, defaults=[\'None\', \'1\', \'False\'], "
@@ -101,6 +101,10 @@ tf_class {
     name: "take"
     argspec: "args=[\'self\', \'count\'], varargs=None, keywords=None, defaults=None"
   }
+  member_method {
+    name: "unbatch"
+    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+  }
   member_method {
     name: "window"
     argspec: "args=[\'self\', \'size\', \'shift\', \'stride\', \'drop_remainder\'], varargs=None, keywords=None, defaults=[\'None\', \'1\', \'False\'], "
@@ -102,6 +102,10 @@ tf_class {
     name: "take"
     argspec: "args=[\'self\', \'count\'], varargs=None, keywords=None, defaults=None"
   }
+  member_method {
+    name: "unbatch"
+    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+  }
   member_method {
     name: "window"
     argspec: "args=[\'self\', \'size\', \'shift\', \'stride\', \'drop_remainder\'], varargs=None, keywords=None, defaults=[\'None\', \'1\', \'False\'], "
@@ -102,6 +102,10 @@ tf_class {
     name: "take"
     argspec: "args=[\'self\', \'count\'], varargs=None, keywords=None, defaults=None"
   }
+  member_method {
+    name: "unbatch"
+    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+  }
   member_method {
     name: "window"
     argspec: "args=[\'self\', \'size\', \'shift\', \'stride\', \'drop_remainder\'], varargs=None, keywords=None, defaults=[\'None\', \'1\', \'False\'], "
@@ -102,6 +102,10 @@ tf_class {
     name: "take"
     argspec: "args=[\'self\', \'count\'], varargs=None, keywords=None, defaults=None"
   }
+  member_method {
+    name: "unbatch"
+    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+  }
   member_method {
     name: "window"
     argspec: "args=[\'self\', \'size\', \'shift\', \'stride\', \'drop_remainder\'], varargs=None, keywords=None, defaults=[\'None\', \'1\', \'False\'], "
@@ -102,6 +102,10 @@ tf_class {
     name: "take"
     argspec: "args=[\'self\', \'count\'], varargs=None, keywords=None, defaults=None"
   }
+  member_method {
+    name: "unbatch"
+    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+  }
   member_method {
     name: "window"
     argspec: "args=[\'self\', \'size\', \'shift\', \'stride\', \'drop_remainder\'], varargs=None, keywords=None, defaults=[\'None\', \'1\', \'False\'], "