[tf.data] Promoting unbatch from experimental to core API.

PiperOrigin-RevId: 248960845
Commit: 575ab84dab (parent: ee4657facf)
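For orientation, a minimal before/after sketch of the user-facing change (illustrative only, not part of the diff below; assumes a TF build at this revision):

    import tensorflow as tf

    ds = tf.data.Dataset.from_tensor_slices([[1, 2, 3], [4, 5, 6]])

    # Before this commit: unbatch was an experimental transformation.
    flat = ds.apply(tf.data.experimental.unbatch())

    # After this commit: unbatch is a core Dataset method.
    flat = ds.unbatch()  # yields the scalars 1, 2, 3, 4, 5, 6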
@@ -219,7 +219,7 @@ def assert_element_shape(expected_shapes):
     output_shapes = _merge_output_shapes(
         dataset_ops.get_legacy_output_shapes(dataset), expected_shapes)
     # pylint: disable=protected-access
-    return batching._RestructuredDataset(
+    return dataset_ops._RestructuredDataset(
         dataset.map(_check_shape),
         dataset_ops.get_legacy_output_types(dataset),
         output_shapes=output_shapes,
@@ -726,33 +726,6 @@ py_test(
     ],
 )

-py_test(
-    name = "unbatch_test",
-    size = "medium",
-    srcs = ["unbatch_test.py"],
-    python_version = "PY2",
-    deps = [
-        "//tensorflow/python:array_ops",
-        "//tensorflow/python:client_testlib",
-        "//tensorflow/python:constant_op",
-        "//tensorflow/python:dtypes",
-        "//tensorflow/python:errors",
-        "//tensorflow/python:framework_ops",
-        "//tensorflow/python:math_ops",
-        "//tensorflow/python:session",
-        "//tensorflow/python:sparse_tensor",
-        "//tensorflow/python:string_ops",
-        "//tensorflow/python:tensor_shape",
-        "//tensorflow/python:util",
-        "//tensorflow/python/data/experimental/ops:batching",
-        "//tensorflow/python/data/kernel_tests:test_base",
-        "//tensorflow/python/data/ops:dataset_ops",
-        "//tensorflow/python/ops/ragged",
-        "//third_party/py/numpy",
-        "@absl_py//absl/testing:parameterized",
-    ],
-)
-
 py_test(
     name = "unique_test",
     size = "small",
@@ -17,7 +17,6 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function

-from tensorflow.python.data.experimental.ops import batching
 from tensorflow.python.data.kernel_tests import test_base
 from tensorflow.python.data.ops import dataset_ops
 from tensorflow.python.data.util import nest
@@ -46,7 +45,8 @@ class RestructuredDatasetTest(test_base.DatasetTestBase):

     for new_types, new_shape_lists in test_cases:
       # pylint: disable=protected-access
-      new = batching._RestructuredDataset(dataset, new_types, new_shape_lists)
+      new = dataset_ops._RestructuredDataset(dataset, new_types,
+                                             new_shape_lists)
       # pylint: enable=protected-access
       self.assertEqual(new_types, dataset_ops.get_legacy_output_types(new))
       if new_shape_lists is not None:
@@ -67,7 +67,8 @@ class RestructuredDatasetTest(test_base.DatasetTestBase):
     for new_types, new_shape_lists in fail_cases:
       with self.assertRaises(ValueError):
         # pylint: disable=protected-access
-        new = batching._RestructuredDataset(dataset, new_types, new_shape_lists)
+        new = dataset_ops._RestructuredDataset(dataset, new_types,
+                                               new_shape_lists)
         # pylint: enable=protected-access


@@ -19,7 +19,6 @@ from __future__ import print_function

 from tensorflow.python.data.ops import dataset_ops
 from tensorflow.python.data.util import convert
-from tensorflow.python.data.util import nest
 from tensorflow.python.data.util import structure
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import ops
@@ -199,6 +198,7 @@ def map_and_batch(map_func,
   return _apply_fn


+@deprecation.deprecated(None, "Use `tf.data.Dataset.unbatch()`.")
 @tf_export("data.experimental.unbatch")
 def unbatch():
   """Splits elements of a dataset into multiple elements on the batch dimension.
@@ -223,34 +223,7 @@ def unbatch():
   """

   def _apply_fn(dataset):
-    """Function from `Dataset` to `Dataset` that applies the transformation."""
-    # NOTE(mrry): We must ensure that any SparseTensors in `dataset`
-    # are normalized to the rank-1 dense representation, so that the
-    # sparse-oblivious unbatching logic will slice them
-    # appropriately. This leads to a somewhat inefficient re-encoding step
-    # for all SparseTensor components.
-    # TODO(mrry): Consider optimizing this in future if it turns out to be
-    # a bottleneck.
-    def normalize(arg, *rest):
-      # pylint: disable=protected-access
-      if rest:
-        return dataset._element_structure._to_batched_tensor_list((arg,) + rest)
-      else:
-        return dataset._element_structure._to_batched_tensor_list(arg)
-
-    normalized_dataset = dataset.map(normalize)
-
-    # NOTE(mrry): Our `map()` has lost information about the sparseness
-    # of any SparseTensor components, so re-apply the structure of the
-    # original dataset.
-    restructured_dataset = _RestructuredDataset(
-        normalized_dataset,
-        dataset_ops.get_legacy_output_types(dataset),
-        dataset_ops.get_legacy_output_shapes(dataset),
-        dataset_ops.get_legacy_output_classes(dataset),
-        allow_unsafe_cast=True)
-    return _UnbatchDataset(restructured_dataset)
+    return dataset.unbatch()

   return _apply_fn

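With the two hunks above, tf.data.experimental.unbatch() becomes a deprecated thin wrapper whose _apply_fn simply calls the new core method. A sketch of the resulting equivalence (illustrative, not from the diff):

    import tensorflow as tf

    ds = tf.data.Dataset.from_tensors([0, 1, 2, 3])

    # Emits a deprecation warning pointing at tf.data.Dataset.unbatch():
    a = ds.apply(tf.data.experimental.unbatch())
    # Preferred spelling after this commit:
    b = ds.unbatch()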
@@ -330,120 +303,3 @@ class _MapAndBatchDataset(dataset_ops.UnaryDataset):
   @property
   def _element_structure(self):
     return self._structure
-
-
-class _RestructuredDataset(dataset_ops.UnaryDataset):
-  """An internal helper for changing the structure and shape of a dataset."""
-
-  def __init__(self,
-               dataset,
-               output_types,
-               output_shapes=None,
-               output_classes=None,
-               allow_unsafe_cast=False):
-    """Creates a new dataset with the given output types and shapes.
-
-    The given `dataset` must have a structure that is convertible:
-    * `dataset.output_types` must be the same as `output_types` module nesting.
-    * Each shape in `dataset.output_shapes` must be compatible with each shape
-      in `output_shapes` (if given).
-
-    Note: This helper permits "unsafe casts" for shapes, equivalent to using
-    `tf.Tensor.set_shape()` where domain-specific knowledge is available.
-
-    Args:
-      dataset: A `Dataset` object.
-      output_types: A nested structure of `tf.DType` objects.
-      output_shapes: (Optional.) A nested structure of `tf.TensorShape` objects.
-        If omitted, the shapes will be inherited from `dataset`.
-      output_classes: (Optional.) A nested structure of class types. If omitted,
-        the class types will be inherited from `dataset`.
-      allow_unsafe_cast: (Optional.) If `True`, the caller may switch the
-        reported output types and shapes of the restructured dataset, e.g. to
-        switch a sparse tensor represented as `tf.variant` to its user-visible
-        type and shape.
-
-    Raises:
-      ValueError: If either `output_types` or `output_shapes` is not compatible
-        with the structure of `dataset`.
-    """
-    self._input_dataset = dataset
-
-    input_types = dataset_ops.get_legacy_output_types(dataset)
-    if not allow_unsafe_cast:
-      # Validate that the types are compatible.
-      output_types = nest.map_structure(dtypes.as_dtype, output_types)
-      flat_original_types = nest.flatten(input_types)
-      flat_new_types = nest.flatten(output_types)
-      if flat_original_types != flat_new_types:
-        raise ValueError(
-            "Dataset with output types %r cannot be restructured to have "
-            "output types %r" %
-            (dataset_ops.get_legacy_output_types(dataset), output_types))
-
-    input_shapes = dataset_ops.get_legacy_output_shapes(dataset)
-    if output_shapes is None:
-      # Inherit shapes from the original `dataset`.
-      output_shapes = nest.pack_sequence_as(
-          output_types, nest.flatten(input_shapes))
-    else:
-      if not allow_unsafe_cast:
-        # Validate that the shapes are compatible.
-        nest.assert_same_structure(output_types, output_shapes)
-        flat_original_shapes = nest.flatten(input_shapes)
-        flat_new_shapes = nest.flatten_up_to(output_types, output_shapes)
-
-        for original_shape, new_shape in zip(flat_original_shapes,
-                                             flat_new_shapes):
-          if not original_shape.is_compatible_with(new_shape):
-            raise ValueError(
-                "Dataset with output shapes %r cannot be restructured to have "
-                "incompatible output shapes %r" % (input_shapes,
-                                                   output_shapes))
-      output_shapes = nest.map_structure_up_to(
-          output_types, tensor_shape.as_shape, output_shapes)
-
-    input_classes = dataset_ops.get_legacy_output_classes(dataset)
-    if output_classes is None:
-      # Inherit class types from the original `dataset`.
-      output_classes = nest.pack_sequence_as(
-          output_types, nest.flatten(input_classes))
-
-    self._structure = structure.convert_legacy_structure(
-        output_types, output_shapes, output_classes)
-    variant_tensor = self._input_dataset._variant_tensor  # pylint: disable=protected-access
-    super(_RestructuredDataset, self).__init__(dataset, variant_tensor)
-
-  @property
-  def _element_structure(self):
-    return self._structure
-
-
-class _UnbatchDataset(dataset_ops.UnaryDataset):
-  """A dataset that splits the elements of its input into multiple elements."""
-
-  def __init__(self, input_dataset):
-    """See `unbatch()` for more details."""
-    input_shapes = dataset_ops.get_legacy_output_shapes(input_dataset)
-    flat_shapes = nest.flatten(input_shapes)
-    if any(s.ndims == 0 for s in flat_shapes):
-      raise ValueError("Cannot unbatch an input with scalar components.")
-    known_batch_dim = tensor_shape.Dimension(None)
-    for s in flat_shapes:
-      try:
-        known_batch_dim = known_batch_dim.merge_with(s[0])
-      except ValueError:
-        raise ValueError("Cannot unbatch an input whose components have "
-                         "different batch sizes.")
-    self._input_dataset = input_dataset
-
-    self._structure = dataset_ops.get_structure(input_dataset)._unbatch()  # pylint: disable=protected-access
-
-    variant_tensor = ged_ops.experimental_unbatch_dataset(
-        self._input_dataset._variant_tensor,  # pylint: disable=protected-access
-        **dataset_ops.flat_structure(self))
-    super(_UnbatchDataset, self).__init__(input_dataset, variant_tensor)
-
-  @property
-  def _element_structure(self):
-    return self._structure
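The two private helpers deleted above are moved, not removed; they reappear in dataset_ops.py later in this diff. Internal callers change spelling accordingly, e.g. (a sketch mirroring the earlier test hunk; both names are internal API):

    # was: batching._RestructuredDataset(...)
    from tensorflow.python.data.ops import dataset_ops

    new = dataset_ops._RestructuredDataset(  # pylint: disable=protected-access
        dataset, new_types, new_shape_lists)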
@@ -715,6 +715,31 @@ py_library(
     ],
 )

+tf_py_test(
+    name = "unbatch_test",
+    size = "medium",
+    srcs = ["unbatch_test.py"],
+    additional_deps = [
+        "@absl_py//absl/testing:parameterized",
+        "//third_party/py/numpy",
+        "//tensorflow/python:array_ops",
+        "//tensorflow/python:client_testlib",
+        "//tensorflow/python:constant_op",
+        "//tensorflow/python:dtypes",
+        "//tensorflow/python:errors",
+        "//tensorflow/python:framework_ops",
+        "//tensorflow/python:math_ops",
+        "//tensorflow/python:session",
+        "//tensorflow/python:sparse_tensor",
+        "//tensorflow/python:string_ops",
+        "//tensorflow/python:tensor_shape",
+        "//tensorflow/python:util",
+        "//tensorflow/python/data/kernel_tests:test_base",
+        "//tensorflow/python/data/ops:dataset_ops",
+        "//tensorflow/python/ops/ragged",
+    ],
+)
+
 tf_py_test(
     name = "window_test",
     size = "medium",
@@ -21,7 +21,6 @@ from __future__ import print_function
 from absl.testing import parameterized
 import numpy as np

-from tensorflow.python.data.experimental.ops import batching
 from tensorflow.python.data.kernel_tests import test_base
 from tensorflow.python.data.ops import dataset_ops
 from tensorflow.python.framework import constant_op
@@ -41,8 +40,7 @@ from tensorflow.python.util import compat
 class UnbatchTest(test_base.DatasetTestBase, parameterized.TestCase):

   def testUnbatchWithUnknownRankInput(self):
-    dataset = dataset_ops.Dataset.from_tensors([0, 1, 2,
-                                                3]).apply(batching.unbatch())
+    dataset = dataset_ops.Dataset.from_tensors([0, 1, 2, 3]).unbatch()
     self.assertDatasetProduces(dataset, range(4))

   def testUnbatchScalarDataset(self):
@@ -51,7 +49,7 @@ class UnbatchTest(test_base.DatasetTestBase, parameterized.TestCase):
     expected_types = (dtypes.int32,) * 3
     data = data.batch(2)
     self.assertEqual(expected_types, dataset_ops.get_legacy_output_types(data))
-    data = data.apply(batching.unbatch())
+    data = data.unbatch()
     self.assertEqual(expected_types, dataset_ops.get_legacy_output_types(data))

     self.assertDatasetProduces(data, [(i,) * 3 for i in range(10)])
@@ -63,7 +61,7 @@ class UnbatchTest(test_base.DatasetTestBase, parameterized.TestCase):
     expected_types = (dtypes.int32, dtypes.string, dtypes.int32)
     data = data.batch(2)
     self.assertEqual(expected_types, dataset_ops.get_legacy_output_types(data))
-    data = data.apply(batching.unbatch())
+    data = data.unbatch()
     self.assertEqual(expected_types, dataset_ops.get_legacy_output_types(data))

     self.assertDatasetProduces(
@@ -75,9 +73,9 @@ class UnbatchTest(test_base.DatasetTestBase, parameterized.TestCase):
         values=list(range(10)),
         dense_shape=[10, 10])
     data = dataset_ops.Dataset.from_tensors(st)
-    data = data.apply(batching.unbatch())
+    data = data.unbatch()
     data = data.batch(5)
-    data = data.apply(batching.unbatch())
+    data = data.unbatch()
     expected_output = [
         sparse_tensor.SparseTensorValue([[i]], [i], [10]) for i in range(10)
     ]
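The sparse test above round-trips a SparseTensor through unbatch/batch/unbatch. A standalone sketch of the first step (assumes TF at this revision, where tf.sparse.SparseTensor is the public alias):

    import tensorflow as tf

    st = tf.sparse.SparseTensor(indices=[[i, i] for i in range(10)],
                                values=list(range(10)),
                                dense_shape=[10, 10])
    data = tf.data.Dataset.from_tensors(st).unbatch()
    # Each of the 10 elements is a rank-1 SparseTensor with dense_shape [10].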
@@ -91,9 +89,9 @@ class UnbatchTest(test_base.DatasetTestBase, parameterized.TestCase):
     rt = ragged_factory_ops.constant_value([[[0]], [[1]], [[2]], [[3]], [[4]],
                                             [[5]], [[6]], [[7]], [[8]], [[9]]])
     data = dataset_ops.Dataset.from_tensors((list(range(10)), st, rt))
-    data = data.apply(batching.unbatch())
+    data = data.unbatch()
     data = data.batch(5)
-    data = data.apply(batching.unbatch())
+    data = data.unbatch()
     expected_output = [(i, sparse_tensor.SparseTensorValue([[i]], [i], [10]),
                         ragged_factory_ops.constant_value([[i]]))
                        for i in range(10)]
@@ -104,10 +102,10 @@ class UnbatchTest(test_base.DatasetTestBase, parameterized.TestCase):
     rt = ragged_factory_ops.constant_value([[[0]], [[1]], [[2]], [[3]], [[4]],
                                             [[5]], [[6]], [[7]], [[8]], [[9]]])
     data = dataset_ops.Dataset.from_tensors(rt)
-    data = data.apply(batching.unbatch())
+    data = data.unbatch()
     data = data.batch(5)
     data = data.batch(2)
-    data = data.apply(batching.unbatch())
+    data = data.unbatch()
     expected_output = [
         ragged_factory_ops.constant_value([[[0]], [[1]], [[2]], [[3]], [[4]]]),
         ragged_factory_ops.constant_value([[[5]], [[6]], [[7]], [[8]], [[9]]]),
@@ -121,7 +119,7 @@ class UnbatchTest(test_base.DatasetTestBase, parameterized.TestCase):
     expected_types = ((dtypes.int32,),) * 3
     data = data.batch(2)
     self.assertEqual(expected_types, dataset_ops.get_legacy_output_types(data))
-    data = data.apply(batching.unbatch())
+    data = data.unbatch()
     self.assertEqual(expected_types, dataset_ops.get_legacy_output_types(data))

     self.assertDatasetProduces(data, [((i,),) * 3 for i in range(10)])
@@ -134,7 +132,7 @@ class UnbatchTest(test_base.DatasetTestBase, parameterized.TestCase):
     data = data.batch(2)
     self.assertAllEqual(expected_types,
                         dataset_ops.get_legacy_output_types(data))
-    data = data.apply(batching.unbatch())
+    data = data.unbatch()
     self.assertAllEqual(expected_types,
                         dataset_ops.get_legacy_output_types(data))

@@ -146,14 +144,14 @@ class UnbatchTest(test_base.DatasetTestBase, parameterized.TestCase):
     data = dataset_ops.Dataset.from_tensors(
         (constant_op.constant([]), constant_op.constant([], shape=[0, 4]),
          constant_op.constant([], shape=[0, 4, 0])))
-    data = data.apply(batching.unbatch())
+    data = data.unbatch()
     self.assertDatasetProduces(data, [])

   def testUnbatchStaticShapeMismatch(self):
     data = dataset_ops.Dataset.from_tensors((np.arange(7), np.arange(8),
                                              np.arange(9)))
     with self.assertRaises(ValueError):
-      data.apply(batching.unbatch())
+      data.unbatch()

   # Note: dynamic shape mismatch is graph specific test.
   @test_util.run_deprecated_v1
@@ -161,7 +159,7 @@ class UnbatchTest(test_base.DatasetTestBase, parameterized.TestCase):
     ph1 = array_ops.placeholder(dtypes.int32, shape=[None])
     ph2 = array_ops.placeholder(dtypes.int32, shape=None)
     data = dataset_ops.Dataset.from_tensors((ph1, ph2))
-    data = data.apply(batching.unbatch())
+    data = data.unbatch()
     iterator = dataset_ops.make_initializable_iterator(data)
     next_element = iterator.get_next()

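As the static-shape-mismatch test above shows, the validation moves with the implementation: Dataset.unbatch() raises at pipeline-construction time when components have different known batch sizes. A sketch:

    import numpy as np
    import tensorflow as tf

    data = tf.data.Dataset.from_tensors((np.arange(7), np.arange(8)))
    # ValueError: "Cannot unbatch an input whose components have
    # different batch sizes."
    data.unbatch()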
@@ -1452,6 +1452,54 @@ class DatasetV2(tracking_base.Trackable, composite_tensor.CompositeTensor):
             output_shapes=state_structure._flat_shapes,
             output_types=state_structure._flat_types))

+  def unbatch(self):
+    """Splits elements of a dataset into multiple elements.
+
+    For example, if elements of the dataset are shaped `[B, a0, a1, ...]`,
+    where `B` may vary for each input element, then for each element in the
+    dataset, the unbatched dataset will contain `B` consecutive elements
+    of shape `[a0, a1, ...]`.
+
+    ```python
+    # NOTE: The following example uses `{ ... }` to represent the contents
+    # of a dataset.
+    ds = { ['a', 'b', 'c'], ['a', 'b'], ['a', 'b', 'c', 'd'] }
+
+    ds.unbatch() == {'a', 'b', 'c', 'a', 'b', 'a', 'b', 'c', 'd'}
+    ```
+
+    Returns:
+      A `Dataset` whose elements are the elements of this dataset, unbatched
+      along their first dimension.
+    """
+
+    # NOTE(mrry): We must ensure that any SparseTensors in `dataset`
+    # are normalized to the rank-1 dense representation, so that the
+    # sparse-oblivious unbatching logic will slice them
+    # appropriately. This leads to a somewhat inefficient re-encoding step
+    # for all SparseTensor components.
+    # TODO(mrry): Consider optimizing this in future if it turns out to be
+    # a bottleneck.
+    def normalize(arg, *rest):
+      # pylint: disable=protected-access
+      if rest:
+        return self._element_structure._to_batched_tensor_list((arg,) + rest)
+      else:
+        return self._element_structure._to_batched_tensor_list(arg)
+
+    normalized_dataset = self.map(normalize)
+
+    # NOTE(mrry): Our `map()` has lost information about the sparseness
+    # of any SparseTensor components, so re-apply the structure of the
+    # original dataset.
+    restructured_dataset = _RestructuredDataset(
+        normalized_dataset,
+        get_legacy_output_types(self),
+        get_legacy_output_shapes(self),
+        get_legacy_output_classes(self),
+        allow_unsafe_cast=True)
+    return _UnbatchDataset(restructured_dataset)
+
   def with_options(self, options):
     """Returns a new `tf.data.Dataset` with the given options set.

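The docstring above uses set notation for dataset contents; a runnable equivalent with a varying batch dimension is sketched below (assumes Dataset.from_generator with output_shapes=[None] behaves as at this revision):

    import tensorflow as tf

    ds = tf.data.Dataset.from_generator(
        lambda: iter([['a', 'b', 'c'], ['a', 'b'], ['a', 'b', 'c', 'd']]),
        output_types=tf.string,
        output_shapes=[None])
    flat = ds.unbatch()  # 'a', 'b', 'c', 'a', 'b', 'a', 'b', 'c', 'd'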
@@ -3556,3 +3604,120 @@ class _PrivateThreadPoolDataset(UnaryUnchangedStructureDataset):
         **flat_structure(self))
     super(_PrivateThreadPoolDataset, self).__init__(input_dataset,
                                                     variant_tensor)
+
+
+class _RestructuredDataset(UnaryDataset):
+  """An internal helper for changing the structure and shape of a dataset."""
+
+  def __init__(self,
+               dataset,
+               output_types,
+               output_shapes=None,
+               output_classes=None,
+               allow_unsafe_cast=False):
+    """Creates a new dataset with the given output types and shapes.
+
+    The given `dataset` must have a structure that is convertible:
+    * `dataset.output_types` must be the same as `output_types` modulo nesting.
+    * Each shape in `dataset.output_shapes` must be compatible with each shape
+      in `output_shapes` (if given).
+
+    Note: This helper permits "unsafe casts" for shapes, equivalent to using
+    `tf.Tensor.set_shape()` where domain-specific knowledge is available.
+
+    Args:
+      dataset: A `Dataset` object.
+      output_types: A nested structure of `tf.DType` objects.
+      output_shapes: (Optional.) A nested structure of `tf.TensorShape` objects.
+        If omitted, the shapes will be inherited from `dataset`.
+      output_classes: (Optional.) A nested structure of class types. If omitted,
+        the class types will be inherited from `dataset`.
+      allow_unsafe_cast: (Optional.) If `True`, the caller may switch the
+        reported output types and shapes of the restructured dataset, e.g. to
+        switch a sparse tensor represented as `tf.variant` to its user-visible
+        type and shape.
+
+    Raises:
+      ValueError: If either `output_types` or `output_shapes` is not compatible
+        with the structure of `dataset`.
+    """
+    self._input_dataset = dataset
+
+    input_types = get_legacy_output_types(dataset)
+    if not allow_unsafe_cast:
+      # Validate that the types are compatible.
+      output_types = nest.map_structure(dtypes.as_dtype, output_types)
+      flat_original_types = nest.flatten(input_types)
+      flat_new_types = nest.flatten(output_types)
+      if flat_original_types != flat_new_types:
+        raise ValueError(
+            "Dataset with output types %r cannot be restructured to have "
+            "output types %r" %
+            (get_legacy_output_types(dataset), output_types))
+
+    input_shapes = get_legacy_output_shapes(dataset)
+    if output_shapes is None:
+      # Inherit shapes from the original `dataset`.
+      output_shapes = nest.pack_sequence_as(
+          output_types, nest.flatten(input_shapes))
+    else:
+      if not allow_unsafe_cast:
+        # Validate that the shapes are compatible.
+        nest.assert_same_structure(output_types, output_shapes)
+        flat_original_shapes = nest.flatten(input_shapes)
+        flat_new_shapes = nest.flatten_up_to(output_types, output_shapes)
+
+        for original_shape, new_shape in zip(flat_original_shapes,
+                                             flat_new_shapes):
+          if not original_shape.is_compatible_with(new_shape):
+            raise ValueError(
+                "Dataset with output shapes %r cannot be restructured to have "
+                "incompatible output shapes %r" % (input_shapes,
+                                                   output_shapes))
+      output_shapes = nest.map_structure_up_to(
+          output_types, tensor_shape.as_shape, output_shapes)
+
+    input_classes = get_legacy_output_classes(dataset)
+    if output_classes is None:
+      # Inherit class types from the original `dataset`.
+      output_classes = nest.pack_sequence_as(
+          output_types, nest.flatten(input_classes))
+
+    self._structure = structure_lib.convert_legacy_structure(
+        output_types, output_shapes, output_classes)
+    variant_tensor = self._input_dataset._variant_tensor  # pylint: disable=protected-access
+    super(_RestructuredDataset, self).__init__(dataset, variant_tensor)
+
+  @property
+  def _element_structure(self):
+    return self._structure
+
+
+class _UnbatchDataset(UnaryDataset):
+  """A dataset that splits the elements of its input into multiple elements."""
+
+  def __init__(self, input_dataset):
+    """See `unbatch()` for more details."""
+    input_shapes = get_legacy_output_shapes(input_dataset)
+    flat_shapes = nest.flatten(input_shapes)
+    if any(s.ndims == 0 for s in flat_shapes):
+      raise ValueError("Cannot unbatch an input with scalar components.")
+    known_batch_dim = tensor_shape.Dimension(None)
+    for s in flat_shapes:
+      try:
+        known_batch_dim = known_batch_dim.merge_with(s[0])
+      except ValueError:
+        raise ValueError("Cannot unbatch an input whose components have "
+                         "different batch sizes.")
+    self._input_dataset = input_dataset
+
+    self._structure = get_structure(input_dataset)._unbatch()  # pylint: disable=protected-access
+
+    variant_tensor = ged_ops.experimental_unbatch_dataset(
+        self._input_dataset._variant_tensor,  # pylint: disable=protected-access
+        **flat_structure(self))
+    super(_UnbatchDataset, self).__init__(input_dataset, variant_tensor)
+
+  @property
+  def _element_structure(self):
+    return self._structure
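_UnbatchDataset.__init__ above rejects scalar components and merges the leading dimension across components via Dimension.merge_with. The same check in isolation (a sketch using the internal tensor_shape module named in the diff):

    from tensorflow.python.framework import tensor_shape

    shapes = [tensor_shape.TensorShape([None, 3]),
              tensor_shape.TensorShape([10, 2])]
    known_batch_dim = tensor_shape.Dimension(None)
    for s in shapes:
      known_batch_dim = known_batch_dim.merge_with(s[0])
    # known_batch_dim is now Dimension(10); a conflicting known leading
    # dimension would make merge_with raise ValueError.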
@@ -133,6 +133,10 @@ tf_class {
     name: "take"
     argspec: "args=[\'self\', \'count\'], varargs=None, keywords=None, defaults=None"
   }
+  member_method {
+    name: "unbatch"
+    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+  }
   member_method {
     name: "window"
     argspec: "args=[\'self\', \'size\', \'shift\', \'stride\', \'drop_remainder\'], varargs=None, keywords=None, defaults=[\'None\', \'1\', \'False\'], "
@@ -135,6 +135,10 @@ tf_class {
     name: "take"
     argspec: "args=[\'self\', \'count\'], varargs=None, keywords=None, defaults=None"
   }
+  member_method {
+    name: "unbatch"
+    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+  }
   member_method {
     name: "window"
     argspec: "args=[\'self\', \'size\', \'shift\', \'stride\', \'drop_remainder\'], varargs=None, keywords=None, defaults=[\'None\', \'1\', \'False\'], "
@@ -135,6 +135,10 @@ tf_class {
     name: "take"
     argspec: "args=[\'self\', \'count\'], varargs=None, keywords=None, defaults=None"
   }
+  member_method {
+    name: "unbatch"
+    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+  }
   member_method {
     name: "window"
     argspec: "args=[\'self\', \'size\', \'shift\', \'stride\', \'drop_remainder\'], varargs=None, keywords=None, defaults=[\'None\', \'1\', \'False\'], "
@@ -135,6 +135,10 @@ tf_class {
     name: "take"
     argspec: "args=[\'self\', \'count\'], varargs=None, keywords=None, defaults=None"
   }
+  member_method {
+    name: "unbatch"
+    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+  }
   member_method {
     name: "window"
     argspec: "args=[\'self\', \'size\', \'shift\', \'stride\', \'drop_remainder\'], varargs=None, keywords=None, defaults=[\'None\', \'1\', \'False\'], "
@@ -135,6 +135,10 @@ tf_class {
     name: "take"
     argspec: "args=[\'self\', \'count\'], varargs=None, keywords=None, defaults=None"
  }
+  member_method {
+    name: "unbatch"
+    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+  }
 member_method {
     name: "window"
     argspec: "args=[\'self\', \'size\', \'shift\', \'stride\', \'drop_remainder\'], varargs=None, keywords=None, defaults=[\'None\', \'1\', \'False\'], "
@@ -135,6 +135,10 @@ tf_class {
     name: "take"
     argspec: "args=[\'self\', \'count\'], varargs=None, keywords=None, defaults=None"
   }
+  member_method {
+    name: "unbatch"
+    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+  }
   member_method {
     name: "window"
     argspec: "args=[\'self\', \'size\', \'shift\', \'stride\', \'drop_remainder\'], varargs=None, keywords=None, defaults=[\'None\', \'1\', \'False\'], "
@@ -135,6 +135,10 @@ tf_class {
     name: "take"
     argspec: "args=[\'self\', \'count\'], varargs=None, keywords=None, defaults=None"
   }
+  member_method {
+    name: "unbatch"
+    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+  }
   member_method {
     name: "window"
     argspec: "args=[\'self\', \'size\', \'shift\', \'stride\', \'drop_remainder\'], varargs=None, keywords=None, defaults=[\'None\', \'1\', \'False\'], "
@@ -100,6 +100,10 @@ tf_class {
     name: "take"
     argspec: "args=[\'self\', \'count\'], varargs=None, keywords=None, defaults=None"
   }
+  member_method {
+    name: "unbatch"
+    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+  }
   member_method {
     name: "window"
     argspec: "args=[\'self\', \'size\', \'shift\', \'stride\', \'drop_remainder\'], varargs=None, keywords=None, defaults=[\'None\', \'1\', \'False\'], "
@@ -102,6 +102,10 @@ tf_class {
     name: "take"
     argspec: "args=[\'self\', \'count\'], varargs=None, keywords=None, defaults=None"
   }
+  member_method {
+    name: "unbatch"
+    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+  }
   member_method {
     name: "window"
     argspec: "args=[\'self\', \'size\', \'shift\', \'stride\', \'drop_remainder\'], varargs=None, keywords=None, defaults=[\'None\', \'1\', \'False\'], "
@@ -101,6 +101,10 @@ tf_class {
     name: "take"
     argspec: "args=[\'self\', \'count\'], varargs=None, keywords=None, defaults=None"
   }
+  member_method {
+    name: "unbatch"
+    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+  }
   member_method {
     name: "window"
     argspec: "args=[\'self\', \'size\', \'shift\', \'stride\', \'drop_remainder\'], varargs=None, keywords=None, defaults=[\'None\', \'1\', \'False\'], "
@@ -102,6 +102,10 @@ tf_class {
     name: "take"
     argspec: "args=[\'self\', \'count\'], varargs=None, keywords=None, defaults=None"
   }
+  member_method {
+    name: "unbatch"
+    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+  }
   member_method {
     name: "window"
     argspec: "args=[\'self\', \'size\', \'shift\', \'stride\', \'drop_remainder\'], varargs=None, keywords=None, defaults=[\'None\', \'1\', \'False\'], "
@@ -102,6 +102,10 @@ tf_class {
     name: "take"
     argspec: "args=[\'self\', \'count\'], varargs=None, keywords=None, defaults=None"
   }
+  member_method {
+    name: "unbatch"
+    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+  }
   member_method {
     name: "window"
     argspec: "args=[\'self\', \'size\', \'shift\', \'stride\', \'drop_remainder\'], varargs=None, keywords=None, defaults=[\'None\', \'1\', \'False\'], "
@@ -102,6 +102,10 @@ tf_class {
     name: "take"
     argspec: "args=[\'self\', \'count\'], varargs=None, keywords=None, defaults=None"
   }
+  member_method {
+    name: "unbatch"
+    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+  }
   member_method {
     name: "window"
     argspec: "args=[\'self\', \'size\', \'shift\', \'stride\', \'drop_remainder\'], varargs=None, keywords=None, defaults=[\'None\', \'1\', \'False\'], "
@@ -102,6 +102,10 @@ tf_class {
     name: "take"
     argspec: "args=[\'self\', \'count\'], varargs=None, keywords=None, defaults=None"
   }
+  member_method {
+    name: "unbatch"
+    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+  }
   member_method {
     name: "window"
     argspec: "args=[\'self\', \'size\', \'shift\', \'stride\', \'drop_remainder\'], varargs=None, keywords=None, defaults=[\'None\', \'1\', \'False\'], "