parent 9c7ffad45c
commit b8111870ca
tensorflow/python/data
tensorflow/python/data/experimental/ops/interleave_ops.py

@@ -20,84 +20,20 @@ from __future__ import print_function
 from tensorflow.python.compat import compat
 from tensorflow.python.data.experimental.ops import random_ops
 from tensorflow.python.data.ops import dataset_ops
-from tensorflow.python.data.util import convert
+from tensorflow.python.data.ops import readers
 from tensorflow.python.data.util import nest
 from tensorflow.python.data.util import structure
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import ops
 from tensorflow.python.framework import tensor_spec
 from tensorflow.python.ops import array_ops
-from tensorflow.python.ops import gen_experimental_dataset_ops as ged_ops
+from tensorflow.python.ops import gen_experimental_dataset_ops
 from tensorflow.python.ops import gen_stateless_random_ops
 from tensorflow.python.ops import math_ops
 from tensorflow.python.util import deprecation
 from tensorflow.python.util.tf_export import tf_export


-class _ParallelInterleaveDataset(dataset_ops.UnaryDataset):
-  """A `Dataset` that maps a function over its input and flattens the result."""
-
-  def __init__(self, input_dataset, map_func, cycle_length, block_length,
-               sloppy, buffer_output_elements, prefetch_input_elements):
-    """See `tf.data.experimental.parallel_interleave()` for details."""
-    self._input_dataset = input_dataset
-    self._map_func = dataset_ops.StructuredFunctionWrapper(
-        map_func, self._transformation_name(), dataset=input_dataset)
-    if not isinstance(self._map_func.output_structure, dataset_ops.DatasetSpec):
-      raise TypeError("`map_func` must return a `Dataset` object.")
-    self._element_spec = self._map_func.output_structure._element_spec  # pylint: disable=protected-access
-    self._cycle_length = ops.convert_to_tensor(
-        cycle_length, dtype=dtypes.int64, name="cycle_length")
-    self._block_length = ops.convert_to_tensor(
-        block_length, dtype=dtypes.int64, name="block_length")
-    self._sloppy = ops.convert_to_tensor(
-        sloppy, dtype=dtypes.bool, name="sloppy")
-    self._buffer_output_elements = convert.optional_param_to_tensor(
-        "buffer_output_elements",
-        buffer_output_elements,
-        argument_default=2 * block_length)
-    self._prefetch_input_elements = convert.optional_param_to_tensor(
-        "prefetch_input_elements",
-        prefetch_input_elements,
-        argument_default=2 * cycle_length)
-    # pylint: disable=protected-access
-    if compat.forward_compatible(2019, 8, 3):
-      variant_tensor = ged_ops.parallel_interleave_dataset(
-          self._input_dataset._variant_tensor,
-          self._map_func.function.captured_inputs,
-          self._cycle_length,
-          self._block_length,
-          self._sloppy,
-          self._buffer_output_elements,
-          self._prefetch_input_elements,
-          f=self._map_func.function,
-          **self._flat_structure)
-    else:
-      variant_tensor = ged_ops.experimental_parallel_interleave_dataset(
-          self._input_dataset._variant_tensor,
-          self._map_func.function.captured_inputs,
-          self._cycle_length,
-          self._block_length,
-          self._sloppy,
-          self._buffer_output_elements,
-          self._prefetch_input_elements,
-          f=self._map_func.function,
-          **self._flat_structure)
-    # pylint: enable=protected-access
-    super(_ParallelInterleaveDataset, self).__init__(input_dataset,
-                                                     variant_tensor)
-
-  def _functions(self):
-    return [self._map_func]
-
-  @property
-  def element_spec(self):
-    return self._element_spec
-
-  def _transformation_name(self):
-    return "tf.data.experimental.parallel_interleave()"
-
-
 @deprecation.deprecated(
     None,
     "Use `tf.data.Dataset.interleave(map_func, cycle_length, block_length, "
@@ -154,7 +90,7 @@ def parallel_interleave(map_func,
     `tf.data.Dataset.apply`.
   """
   def _apply_fn(dataset):
-    return _ParallelInterleaveDataset(
+    return readers.ParallelInterleaveDataset(
         dataset, map_func, cycle_length, block_length, sloppy,
         buffer_output_elements, prefetch_input_elements)

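For orientation (an illustration, not part of the commit): the deprecation message above points users at the core `Dataset.interleave` API. A minimal migration sketch, assuming TF 1.14-era public APIs and a hypothetical file pattern:

import tensorflow as tf

files = tf.data.Dataset.list_files("/data/*.tfrecord", shuffle=False)

# Deprecated transformation that this commit routes through
# readers.ParallelInterleaveDataset:
ds_old = files.apply(
    tf.data.experimental.parallel_interleave(
        tf.data.TFRecordDataset, cycle_length=4, sloppy=True))

# Replacement suggested by the deprecation message: core interleave plus
# a non-deterministic option, the documented equivalent of sloppy=True.
options = tf.data.Options()
options.experimental_deterministic = False
ds_new = files.interleave(
    tf.data.TFRecordDataset, cycle_length=4,
    num_parallel_calls=tf.data.experimental.AUTOTUNE).with_options(options)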
@@ -193,13 +129,13 @@ class _DirectedInterleaveDataset(dataset_ops.Dataset):
     # pylint: disable=protected-access
     if compat.forward_compatible(2019, 8, 3):
       return (
-          ged_ops.directed_interleave_dataset(
+          gen_experimental_dataset_ops.directed_interleave_dataset(
               self._selector_input._variant_tensor,
               [data_input._variant_tensor for data_input in self._data_inputs],
               **self._flat_structure))
     else:
       return (
-          ged_ops.experimental_directed_interleave_dataset(
+          gen_experimental_dataset_ops.experimental_directed_interleave_dataset(
               self._selector_input._variant_tensor,
               [data_input._variant_tensor for data_input in self._data_inputs],
               **self._flat_structure))
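A note on the `compat.forward_compatible(2019, 8, 3)` checks that recur in this commit (a hedged sketch of the pattern, not code from the diff): the function returns False until the forward-compatibility window, three weeks by default, has elapsed past the given date, so graphs serialized in the interim keep using the `experimental_*` op names that older binaries still understand.

from tensorflow.python.compat import compat
from tensorflow.python.ops import gen_experimental_dataset_ops

# Gate pattern: emit the renamed kernel only once old consumers are
# expected to have picked up a binary that registers it.
if compat.forward_compatible(2019, 8, 3):
  build_op = gen_experimental_dataset_ops.directed_interleave_dataset
else:
  build_op = gen_experimental_dataset_ops.experimental_directed_interleave_dataset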
@@ -358,4 +294,3 @@ choose_from_datasets_v1.__doc__ = choose_from_datasets_v2.__doc__
 # these aliases in place.
 choose_from_datasets = choose_from_datasets_v1
 sample_from_datasets = sample_from_datasets_v1
-
tensorflow/python/data/experimental/ops/readers.py

@@ -26,6 +26,7 @@ import numpy as np
 from tensorflow.python.compat import compat
 from tensorflow.python.data.experimental.ops import batching
 from tensorflow.python.data.experimental.ops import error_ops
+from tensorflow.python.data.experimental.ops import interleave_ops
 from tensorflow.python.data.experimental.ops import parsing_ops
 from tensorflow.python.data.experimental.ops import shuffle_ops
 from tensorflow.python.data.ops import dataset_ops
@@ -493,18 +494,9 @@ def make_csv_dataset_v2(
     return features

-  # Read files sequentially (if num_parallel_reads=1) or in parallel
-  cycle_length = num_parallel_reads
-  if num_parallel_reads == dataset_ops.AUTOTUNE:
-    cycle_length = core_readers.DEFAULT_CYCLE_LENGTH
-  dataset = dataset.interleave(
-      filename_to_dataset,
-      cycle_length,
-      num_parallel_calls=num_parallel_reads)

-  if sloppy:
-    options = dataset_ops.Options()
-    options.experimental_deterministic = False
-    dataset = dataset.with_options(options)
+  dataset = dataset.apply(
+      interleave_ops.parallel_interleave(
+          filename_to_dataset, cycle_length=num_parallel_reads, sloppy=sloppy))

   dataset = _maybe_shuffle_and_repeat(
       dataset, num_epochs, shuffle, shuffle_buffer_size, shuffle_seed)
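The user-visible effect (a hedged usage sketch; the file pattern is hypothetical): file-level reads in `make_csv_dataset` again flow through `parallel_interleave`, so `sloppy=True` relaxes ordering inside the reader rather than via `Options.experimental_deterministic`.

import tensorflow as tf

dataset = tf.data.experimental.make_csv_dataset(
    "/data/part-*.csv",    # hypothetical file pattern
    batch_size=32,
    num_parallel_reads=4,  # becomes the interleave cycle_length
    sloppy=True)           # allow non-deterministic file order for speed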
@@ -846,18 +838,11 @@ def make_batched_features_dataset_v2(file_pattern,
     reader_args = []

-  # Read files sequentially (if reader_num_threads=1) or in parallel
-  cycle_length = reader_num_threads
-  if reader_num_threads == dataset_ops.AUTOTUNE:
-    cycle_length = core_readers.DEFAULT_CYCLE_LENGTH
-  dataset = dataset.interleave(
-      lambda filename: reader(filename, *reader_args),
-      cycle_length,
-      num_parallel_calls=reader_num_threads)

-  if sloppy_ordering:
-    options = dataset_ops.Options()
-    options.experimental_deterministic = False
-    dataset = dataset.with_options(options)
+  dataset = dataset.apply(
+      interleave_ops.parallel_interleave(
+          lambda filename: reader(filename, *reader_args),
+          cycle_length=reader_num_threads,
+          sloppy=sloppy_ordering))

   # Extract values if the `Example` tensors are stored as key-value tuples.
   if dataset_ops.get_legacy_output_types(dataset) == (
tensorflow/python/data/ops/readers.py

@@ -26,17 +26,13 @@ from tensorflow.python.framework import tensor_shape
 from tensorflow.python.framework import tensor_spec
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import gen_dataset_ops
+from tensorflow.python.ops import gen_experimental_dataset_ops as ged_ops
 from tensorflow.python.util.tf_export import tf_export


 # TODO(b/64974358): Increase default buffer size to 256 MB.
 _DEFAULT_READER_BUFFER_SIZE_BYTES = 256 * 1024  # 256 KB

-# If the user requests the degree of interleave parallelism to be autotuned,
-# cycle length controls the maximum level of parallelism. We set it to a small
-# constant as a tradeoff between effective parallelism and memory and CPU usage.
-DEFAULT_CYCLE_LENGTH = 10
-

 def _create_or_validate_filenames_dataset(filenames):
   """Creates (or validates) a dataset of filenames.
@@ -84,13 +80,10 @@ def _create_dataset_reader(dataset_creator, filenames, num_parallel_reads=None):
   if num_parallel_reads is None:
     return filenames.flat_map(read_one_file)
   else:
-    cycle_length = num_parallel_reads
-    if num_parallel_reads == dataset_ops.AUTOTUNE:
-      cycle_length = DEFAULT_CYCLE_LENGTH
-    return filenames.interleave(
-        read_one_file,
-        cycle_length,
-        num_parallel_calls=num_parallel_reads)
+    return ParallelInterleaveDataset(
+        filenames, read_one_file, cycle_length=num_parallel_reads,
+        block_length=1, sloppy=False, buffer_output_elements=None,
+        prefetch_input_elements=None)


 class _TextLineDataset(dataset_ops.DatasetSource):
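`_create_dataset_reader` backs the `num_parallel_reads` argument of the core file readers, so this hunk changes their parallel path. A small usage sketch (file names hypothetical):

import tensorflow as tf

filenames = ["/data/a.tfrecord", "/data/b.tfrecord"]

# num_parallel_reads unset (default): files are read sequentially
# via the flat_map branch above.
sequential = tf.data.TFRecordDataset(filenames)

# num_parallel_reads=2: files are read through ParallelInterleaveDataset
# with sloppy=False, so the output order stays deterministic.
parallel = tf.data.TFRecordDataset(filenames, num_parallel_reads=2)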
@@ -220,6 +213,68 @@ class _TFRecordDataset(dataset_ops.DatasetSource):
     return tensor_spec.TensorSpec([], dtypes.string)


+class ParallelInterleaveDataset(dataset_ops.UnaryDataset):
+  """A `Dataset` that maps a function over its input and flattens the result."""
+
+  def __init__(self, input_dataset, map_func, cycle_length, block_length,
+               sloppy, buffer_output_elements, prefetch_input_elements):
+    """See `tf.data.experimental.parallel_interleave()` for details."""
+    self._input_dataset = input_dataset
+    self._map_func = dataset_ops.StructuredFunctionWrapper(
+        map_func, self._transformation_name(), dataset=input_dataset)
+    if not isinstance(self._map_func.output_structure, dataset_ops.DatasetSpec):
+      raise TypeError("`map_func` must return a `Dataset` object.")
+    self._element_spec = self._map_func.output_structure._element_spec  # pylint: disable=protected-access
+    self._cycle_length = ops.convert_to_tensor(
+        cycle_length, dtype=dtypes.int64, name="cycle_length")
+    self._block_length = ops.convert_to_tensor(
+        block_length, dtype=dtypes.int64, name="block_length")
+    self._sloppy = ops.convert_to_tensor(
+        sloppy, dtype=dtypes.bool, name="sloppy")
+    self._buffer_output_elements = convert.optional_param_to_tensor(
+        "buffer_output_elements",
+        buffer_output_elements,
+        argument_default=2 * block_length)
+    self._prefetch_input_elements = convert.optional_param_to_tensor(
+        "prefetch_input_elements",
+        prefetch_input_elements,
+        argument_default=2 * cycle_length)
+    if compat.forward_compatible(2019, 8, 3):
+      variant_tensor = ged_ops.parallel_interleave_dataset(
+          self._input_dataset._variant_tensor,  # pylint: disable=protected-access
+          self._map_func.function.captured_inputs,
+          self._cycle_length,
+          self._block_length,
+          self._sloppy,
+          self._buffer_output_elements,
+          self._prefetch_input_elements,
+          f=self._map_func.function,
+          **self._flat_structure)
+    else:
+      variant_tensor = ged_ops.experimental_parallel_interleave_dataset(
+          self._input_dataset._variant_tensor,  # pylint: disable=protected-access
+          self._map_func.function.captured_inputs,
+          self._cycle_length,
+          self._block_length,
+          self._sloppy,
+          self._buffer_output_elements,
+          self._prefetch_input_elements,
+          f=self._map_func.function,
+          **self._flat_structure)
+    super(ParallelInterleaveDataset, self).__init__(input_dataset,
+                                                    variant_tensor)
+
+  def _functions(self):
+    return [self._map_func]
+
+  @property
+  def element_spec(self):
+    return self._element_spec
+
+  def _transformation_name(self):
+    return "tf.data.experimental.parallel_interleave()"
+
+
 @tf_export("data.TFRecordDataset", v1=[])
 class TFRecordDatasetV2(dataset_ops.DatasetV2):
   """A `Dataset` comprising records from one or more TFRecord files."""
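One detail of the class above worth spelling out: `convert.optional_param_to_tensor` supplies defaults when the two buffer arguments are passed as `None`. A pure-Python equivalent of that defaulting (function name hypothetical):

def _interleave_defaults(cycle_length, block_length,
                         buffer_output_elements=None,
                         prefetch_input_elements=None):
  # Each open sub-iterator buffers two blocks of output elements...
  if buffer_output_elements is None:
    buffer_output_elements = 2 * block_length
  # ...and the op keeps two cycles' worth of input elements opened ahead.
  if prefetch_input_elements is None:
    prefetch_input_elements = 2 * cycle_length
  return buffer_output_elements, prefetch_input_elements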