[tf.data] Forward compatibility cleanup.
PiperOrigin-RevId: 261816972
parent 2d6cca0122
commit 9239c61f20
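The pattern removed throughout this commit is the forward-compatibility gate around renamed tf.data kernels: each transformation picked between the new op name and its legacy experimental_* alias via compat.forward_compatible. What follows is a minimal sketch of the before and after shapes, condensed from the unique transformation in this diff; the wrapper function names and the bare input_dataset/flat_structure parameters are illustrative stand-ins for the dataset class internals, not code from the tree.

from tensorflow.python.compat import compat
from tensorflow.python.ops import gen_experimental_dataset_ops

def unique_variant_before(input_dataset, flat_structure):
  # Before: choose the op by whether the 2019-8-3 horizon has been reached,
  # so graphs consumed by older binaries still resolve to a registered op.
  if compat.forward_compatible(2019, 8, 3):
    return gen_experimental_dataset_ops.unique_dataset(
        input_dataset._variant_tensor,  # pylint: disable=protected-access
        **flat_structure)
  return gen_experimental_dataset_ops.experimental_unique_dataset(
      input_dataset._variant_tensor,  # pylint: disable=protected-access
      **flat_structure)

def unique_variant_after(input_dataset, flat_structure):
  # After: the horizon has passed for all supported binaries, so the guard
  # and the experimental_* fallback are deleted outright.
  return gen_experimental_dataset_ops.unique_dataset(
      input_dataset._variant_tensor,  # pylint: disable=protected-access
      **flat_structure)

The same deletion is applied mechanically to every gated transformation below, and each file's now-unused `from tensorflow.python.compat import compat` import goes with it; the `Iterator` and `_FixedLengthRecordDataset` hunks clear already-expired horizons (2018-8-3 and 2018-11-30) in the same way.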
@@ -362,7 +362,7 @@ class ThreadUtilizationStatsTest(stats_dataset_test_base.StatsDatasetTestBase):

     num_output = 100 // 16 + 1
     self.parallelCallsStats(
-        dataset_fn, {"ExperimentalMapAndBatchDataset"},
+        dataset_fn, {"MapAndBatchDataset"},
         num_output,
         check_elements=False,
         function_processing_time=True)
@@ -391,7 +391,7 @@ class FeatureStatsDatasetTest(
     num_output = total_records // batch_size + 1

     self.parallelCallsStats(
-        dataset_fn, {"ExperimentalParseExampleDataset"},
+        dataset_fn, {"ParseExampleDataset"},
         num_output,
         check_elements=False)

@@ -409,19 +409,19 @@ class FeatureStatsDatasetTest(
     handle = self.getHandle(aggregator)
     self.assertStatisticsHasCount(
         handle,
-        self.regexForNodeName("record_stats::ExperimentalParseExampleDataset",
+        self.regexForNodeName("record_stats::ParseExampleDataset",
                               "features_count"), total_records)
     self.assertStatisticsHasCount(
         handle,
-        self.regexForNodeName("record_stats::ExperimentalParseExampleDataset",
+        self.regexForNodeName("record_stats::ParseExampleDataset",
                               "feature_values_count"), total_records)
     self.assertStatisticsHasSum(
         handle,
-        self.regexForNodeName("record_stats::ExperimentalParseExampleDataset",
+        self.regexForNodeName("record_stats::ParseExampleDataset",
                               "features_count"), total_records * 4)
     self.assertStatisticsHasSum(
         handle,
-        self.regexForNodeName("record_stats::ExperimentalParseExampleDataset",
+        self.regexForNodeName("record_stats::ParseExampleDataset",
                               "feature_values_count"),
         self._sum_keywords(1) * num_epochs + 3 * total_records)

@@ -17,7 +17,6 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function

-from tensorflow.python.compat import compat
 from tensorflow.python.data.ops import dataset_ops
 from tensorflow.python.data.util import convert
 from tensorflow.python.data.util import nest
@@ -247,18 +246,11 @@ class _DenseToSparseBatchDataset(dataset_ops.UnaryDataset):
         tensor_shape.TensorShape([None]).concatenate(self._row_shape),
         dataset_ops.get_legacy_output_types(input_dataset))

-    if compat.forward_compatible(2019, 8, 3):
-      variant_tensor = ged_ops.dense_to_sparse_batch_dataset(
-          self._input_dataset._variant_tensor,  # pylint: disable=protected-access
-          self._batch_size,
-          row_shape=convert.partial_shape_to_tensor(self._row_shape),
-          **self._flat_structure)
-    else:
-      variant_tensor = ged_ops.experimental_dense_to_sparse_batch_dataset(
-          self._input_dataset._variant_tensor,  # pylint: disable=protected-access
-          self._batch_size,
-          row_shape=convert.partial_shape_to_tensor(self._row_shape),
-          **self._flat_structure)
+    variant_tensor = ged_ops.dense_to_sparse_batch_dataset(
+        self._input_dataset._variant_tensor,  # pylint: disable=protected-access
+        self._batch_size,
+        row_shape=convert.partial_shape_to_tensor(self._row_shape),
+        **self._flat_structure)
     super(_DenseToSparseBatchDataset, self).__init__(input_dataset,
                                                      variant_tensor)

@@ -302,26 +294,15 @@ class _MapAndBatchDataset(dataset_ops.UnaryDataset):
         lambda component_spec: component_spec._batch(None),
         self._map_func.output_structure)
     # pylint: enable=protected-access
-    if compat.forward_compatible(2019, 8, 3):
-      variant_tensor = ged_ops.map_and_batch_dataset(
-          self._input_dataset._variant_tensor,  # pylint: disable=protected-access
-          self._map_func.function.captured_inputs,
-          f=self._map_func.function,
-          batch_size=self._batch_size_t,
-          num_parallel_calls=self._num_parallel_calls_t,
-          drop_remainder=self._drop_remainder_t,
-          preserve_cardinality=True,
-          **self._flat_structure)
-    else:
-      variant_tensor = ged_ops.experimental_map_and_batch_dataset(
-          self._input_dataset._variant_tensor,  # pylint: disable=protected-access
-          self._map_func.function.captured_inputs,
-          f=self._map_func.function,
-          batch_size=self._batch_size_t,
-          num_parallel_calls=self._num_parallel_calls_t,
-          drop_remainder=self._drop_remainder_t,
-          preserve_cardinality=True,
-          **self._flat_structure)
+    variant_tensor = ged_ops.map_and_batch_dataset(
+        self._input_dataset._variant_tensor,  # pylint: disable=protected-access
+        self._map_func.function.captured_inputs,
+        f=self._map_func.function,
+        batch_size=self._batch_size_t,
+        num_parallel_calls=self._num_parallel_calls_t,
+        drop_remainder=self._drop_remainder_t,
+        preserve_cardinality=True,
+        **self._flat_structure)
     super(_MapAndBatchDataset, self).__init__(input_dataset, variant_tensor)

   def _functions(self):
@@ -17,7 +17,6 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function

-from tensorflow.python.compat import compat
 from tensorflow.python.ops import gen_experimental_dataset_ops as ged_ops
 from tensorflow.python.util.tf_export import tf_export

@@ -49,7 +48,4 @@ def cardinality(dataset):
     constant `INFINITE_CARDINALITY` and `UNKNOWN_CARDINALITY` respectively.
   """

-  if compat.forward_compatible(2019, 8, 3):
-    return ged_ops.dataset_cardinality(dataset._variant_tensor)  # pylint: disable=protected-access
-  else:
-    return ged_ops.experimental_dataset_cardinality(dataset._variant_tensor)  # pylint: disable=protected-access
+  return ged_ops.dataset_cardinality(dataset._variant_tensor)  # pylint: disable=protected-access
@@ -49,18 +49,11 @@ class _AutoShardDataset(dataset_ops.UnaryDataset):
     self._input_dataset = input_dataset

     self._element_spec = input_dataset.element_spec
-    if compat.forward_compatible(2019, 8, 3):
-      variant_tensor = ged_ops.auto_shard_dataset(
-          self._input_dataset._variant_tensor,  # pylint: disable=protected-access
-          num_workers=num_workers,
-          index=index,
-          **self._flat_structure)
-    else:
-      variant_tensor = ged_ops.experimental_auto_shard_dataset(
-          self._input_dataset._variant_tensor,  # pylint: disable=protected-access
-          num_workers=num_workers,
-          index=index,
-          **self._flat_structure)
+    variant_tensor = ged_ops.auto_shard_dataset(
+        self._input_dataset._variant_tensor,  # pylint: disable=protected-access
+        num_workers=num_workers,
+        index=index,
+        **self._flat_structure)
     super(_AutoShardDataset, self).__init__(input_dataset, variant_tensor)

   @property
@@ -112,13 +105,8 @@ class _RebatchDataset(dataset_ops.UnaryDataset):
           num_workers=num_workers,
           use_fallback=use_fallback,
           **self._flat_structure)
-    elif compat.forward_compatible(2019, 8, 3):
-      variant_tensor = ged_ops.rebatch_dataset(
-          self._input_dataset._variant_tensor,  # pylint: disable=protected-access
-          num_workers=num_workers,
-          **self._flat_structure)
     else:
-      variant_tensor = ged_ops.experimental_rebatch_dataset(
+      variant_tensor = ged_ops.rebatch_dataset(
           self._input_dataset._variant_tensor,  # pylint: disable=protected-access
           num_workers=num_workers,
           **self._flat_structure)
@@ -17,7 +17,6 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function

-from tensorflow.python.compat import compat
 from tensorflow.python.data.ops import dataset_ops
 from tensorflow.python.ops import gen_experimental_dataset_ops
 from tensorflow.python.util.tf_export import tf_export
@@ -60,14 +59,8 @@ class _IgnoreErrorsDataset(dataset_ops.UnaryUnchangedStructureDataset):
   def __init__(self, input_dataset):
     """See `Dataset.ignore_errors()` for details."""
     self._input_dataset = input_dataset
-    if compat.forward_compatible(2019, 8, 3):
-      variant_tensor = (
-          gen_experimental_dataset_ops.ignore_errors_dataset(
-              self._input_dataset._variant_tensor,  # pylint: disable=protected-access
-              **self._flat_structure))
-    else:
-      variant_tensor = (
-          gen_experimental_dataset_ops.experimental_ignore_errors_dataset(
-              self._input_dataset._variant_tensor,  # pylint: disable=protected-access
-              **self._flat_structure))
+    variant_tensor = (
+        gen_experimental_dataset_ops.ignore_errors_dataset(
+            self._input_dataset._variant_tensor,  # pylint: disable=protected-access
+            **self._flat_structure))
     super(_IgnoreErrorsDataset, self).__init__(input_dataset, variant_tensor)
@@ -19,7 +19,6 @@ from __future__ import print_function

 import numpy as np

-from tensorflow.python.compat import compat
 from tensorflow.python.data.ops import dataset_ops
 from tensorflow.python.data.util import nest
 from tensorflow.python.data.util import structure
@@ -255,30 +254,17 @@ class _GroupByReducerDataset(dataset_ops.UnaryDataset):
     self._make_init_func(reducer.init_func)
     self._make_reduce_func(reducer.reduce_func, input_dataset)
     self._make_finalize_func(reducer.finalize_func)
-    if compat.forward_compatible(2019, 8, 3):
-      variant_tensor = ged_ops.experimental_group_by_reducer_dataset(
-          self._input_dataset._variant_tensor,  # pylint: disable=protected-access
-          self._key_func.function.captured_inputs,
-          self._init_func.function.captured_inputs,
-          self._reduce_func.function.captured_inputs,
-          self._finalize_func.function.captured_inputs,
-          key_func=self._key_func.function,
-          init_func=self._init_func.function,
-          reduce_func=self._reduce_func.function,
-          finalize_func=self._finalize_func.function,
-          **self._flat_structure)
-    else:
-      variant_tensor = ged_ops.group_by_reducer_dataset(
-          self._input_dataset._variant_tensor,  # pylint: disable=protected-access
-          self._key_func.function.captured_inputs,
-          self._init_func.function.captured_inputs,
-          self._reduce_func.function.captured_inputs,
-          self._finalize_func.function.captured_inputs,
-          key_func=self._key_func.function,
-          init_func=self._init_func.function,
-          reduce_func=self._reduce_func.function,
-          finalize_func=self._finalize_func.function,
-          **self._flat_structure)
+    variant_tensor = ged_ops.experimental_group_by_reducer_dataset(
+        self._input_dataset._variant_tensor,  # pylint: disable=protected-access
+        self._key_func.function.captured_inputs,
+        self._init_func.function.captured_inputs,
+        self._reduce_func.function.captured_inputs,
+        self._finalize_func.function.captured_inputs,
+        key_func=self._key_func.function,
+        init_func=self._init_func.function,
+        reduce_func=self._reduce_func.function,
+        finalize_func=self._finalize_func.function,
+        **self._flat_structure)
     super(_GroupByReducerDataset, self).__init__(input_dataset, variant_tensor)

   def _make_key_func(self, key_func, input_dataset):
@@ -390,26 +376,15 @@ class _GroupByWindowDataset(dataset_ops.UnaryDataset):
     self._make_key_func(key_func, input_dataset)
     self._make_reduce_func(reduce_func, input_dataset)
     self._make_window_size_func(window_size_func)
-    if compat.forward_compatible(2019, 8, 3):
-      variant_tensor = ged_ops.group_by_window_dataset(
-          self._input_dataset._variant_tensor,  # pylint: disable=protected-access
-          self._key_func.function.captured_inputs,
-          self._reduce_func.function.captured_inputs,
-          self._window_size_func.function.captured_inputs,
-          key_func=self._key_func.function,
-          reduce_func=self._reduce_func.function,
-          window_size_func=self._window_size_func.function,
-          **self._flat_structure)
-    else:
-      variant_tensor = ged_ops.experimental_group_by_window_dataset(
-          self._input_dataset._variant_tensor,  # pylint: disable=protected-access
-          self._key_func.function.captured_inputs,
-          self._reduce_func.function.captured_inputs,
-          self._window_size_func.function.captured_inputs,
-          key_func=self._key_func.function,
-          reduce_func=self._reduce_func.function,
-          window_size_func=self._window_size_func.function,
-          **self._flat_structure)
+    variant_tensor = ged_ops.group_by_window_dataset(
+        self._input_dataset._variant_tensor,  # pylint: disable=protected-access
+        self._key_func.function.captured_inputs,
+        self._reduce_func.function.captured_inputs,
+        self._window_size_func.function.captured_inputs,
+        key_func=self._key_func.function,
+        reduce_func=self._reduce_func.function,
+        window_size_func=self._window_size_func.function,
+        **self._flat_structure)
     super(_GroupByWindowDataset, self).__init__(input_dataset, variant_tensor)

   def _make_window_size_func(self, window_size_func):
@@ -17,7 +17,6 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function

-from tensorflow.python.compat import compat
 from tensorflow.python.data.experimental.ops import random_ops
 from tensorflow.python.data.ops import dataset_ops
 from tensorflow.python.data.ops import readers
@@ -127,19 +126,11 @@ class _DirectedInterleaveDataset(dataset_ops.Dataset):

   def _as_variant_tensor(self):
     # pylint: disable=protected-access
-    if compat.forward_compatible(2019, 8, 3):
-      return (
-          gen_experimental_dataset_ops.directed_interleave_dataset(
-              self._selector_input._variant_tensor,
-              [data_input._variant_tensor for data_input in self._data_inputs],
-              **self._flat_structure))
-    else:
-      return (
-          gen_experimental_dataset_ops.experimental_directed_interleave_dataset(
-              self._selector_input._variant_tensor,
-              [data_input._variant_tensor for data_input in self._data_inputs],
-              **self._flat_structure))
-    # pylint: enable=protected-access
+    return (
+        gen_experimental_dataset_ops.directed_interleave_dataset(
+            self._selector_input._variant_tensor,
+            [data_input._variant_tensor for data_input in self._data_inputs],
+            **self._flat_structure))

   def _inputs(self):
     return [self._selector_input] + self._data_inputs
@@ -18,7 +18,6 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function

-from tensorflow.python.compat import compat
 from tensorflow.python.data.ops import dataset_ops
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import ops
@@ -32,11 +31,7 @@ class MatchingFilesDataset(dataset_ops.DatasetSource):
   def __init__(self, patterns):
     self._patterns = ops.convert_to_tensor(
         patterns, dtype=dtypes.string, name="patterns")
-    if compat.forward_compatible(2019, 8, 3):
-      variant_tensor = ged_ops.matching_files_dataset(self._patterns)
-    else:
-      variant_tensor = ged_ops.experimental_matching_files_dataset(
-          self._patterns)
+    variant_tensor = ged_ops.matching_files_dataset(self._patterns)
     super(MatchingFilesDataset, self).__init__(variant_tensor)

   @property
@@ -17,7 +17,6 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function

-from tensorflow.python.compat import compat
 from tensorflow.python.data.ops import dataset_ops
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import ops
@@ -105,18 +104,11 @@ class _AssertNextDataset(dataset_ops.UnaryUnchangedStructureDataset):
       raise ValueError("At least one transformation should be specified")
     self._transformations = ops.convert_to_tensor(
         transformations, dtype=dtypes.string, name="transformations")
-    if compat.forward_compatible(2019, 8, 3):
-      variant_tensor = (
-          gen_experimental_dataset_ops.assert_next_dataset(
-              self._input_dataset._variant_tensor,  # pylint: disable=protected-access
-              self._transformations,
-              **self._flat_structure))
-    else:
-      variant_tensor = (
-          gen_experimental_dataset_ops.experimental_assert_next_dataset(
-              self._input_dataset._variant_tensor,  # pylint: disable=protected-access
-              self._transformations,
-              **self._flat_structure))
+    variant_tensor = (
+        gen_experimental_dataset_ops.assert_next_dataset(
+            self._input_dataset._variant_tensor,  # pylint: disable=protected-access
+            self._transformations,
+            **self._flat_structure))
     super(_AssertNextDataset, self).__init__(input_dataset, variant_tensor)


@@ -126,16 +118,10 @@ class _NonSerializableDataset(dataset_ops.UnaryUnchangedStructureDataset):
   def __init__(self, input_dataset):
     """See `non_serializable()` for details."""
     self._input_dataset = input_dataset
-    if compat.forward_compatible(2019, 8, 3):
-      variant_tensor = (
-          gen_experimental_dataset_ops.non_serializable_dataset(
-              self._input_dataset._variant_tensor,  # pylint: disable=protected-access
-              **self._flat_structure))
-    else:
-      variant_tensor = (
-          gen_experimental_dataset_ops.experimental_non_serializable_dataset(
-              self._input_dataset._variant_tensor,  # pylint: disable=protected-access
-              **self._flat_structure))
+    variant_tensor = (
+        gen_experimental_dataset_ops.non_serializable_dataset(
+            self._input_dataset._variant_tensor,  # pylint: disable=protected-access
+            **self._flat_structure))
     super(_NonSerializableDataset, self).__init__(input_dataset, variant_tensor)


@@ -171,18 +157,11 @@ class _ChooseFastestDataset(dataset_ops.DatasetV2):
     """
     self._datasets = list(datasets)
     self._element_spec = self._datasets[0].element_spec
-    if compat.forward_compatible(2019, 8, 3):
-      variant_tensor = (
-          gen_experimental_dataset_ops.choose_fastest_dataset(
-              [dataset._variant_tensor for dataset in self._datasets],  # pylint: disable=protected-access
-              num_experiments=num_experiments,
-              **self._flat_structure))
-    else:
-      variant_tensor = (
-          gen_experimental_dataset_ops.experimental_choose_fastest_dataset(
-              [dataset._variant_tensor for dataset in self._datasets],  # pylint: disable=protected-access
-              num_experiments=num_experiments,
-              **self._flat_structure))
+    variant_tensor = (
+        gen_experimental_dataset_ops.choose_fastest_dataset(
+            [dataset._variant_tensor for dataset in self._datasets],  # pylint: disable=protected-access
+            num_experiments=num_experiments,
+            **self._flat_structure))
     super(_ChooseFastestDataset, self).__init__(variant_tensor)

   def _inputs(self):
@@ -17,7 +17,6 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function

-from tensorflow.python.compat import compat
 from tensorflow.python.data.ops import dataset_ops
 from tensorflow.python.data.util import structure
 from tensorflow.python.framework import dtypes
@@ -81,28 +80,16 @@ class _ParseExampleDataset(dataset_ops.UnaryDataset):
     self._element_spec = structure.convert_legacy_structure(
         output_types, output_shapes, output_classes)

-    if compat.forward_compatible(2019, 8, 3):
-      variant_tensor = (
-          gen_experimental_dataset_ops.parse_example_dataset(
-              self._input_dataset._variant_tensor,  # pylint: disable=protected-access
-              self._num_parallel_calls,
-              self._dense_defaults,
-              self._sparse_keys,
-              self._dense_keys,
-              self._sparse_types,
-              self._dense_shapes,
-              **self._flat_structure))
-    else:
-      variant_tensor = (
-          gen_experimental_dataset_ops.experimental_parse_example_dataset(
-              self._input_dataset._variant_tensor,  # pylint: disable=protected-access
-              self._num_parallel_calls,
-              self._dense_defaults,
-              self._sparse_keys,
-              self._dense_keys,
-              self._sparse_types,
-              self._dense_shapes,
-              **self._flat_structure))
+    variant_tensor = (
+        gen_experimental_dataset_ops.parse_example_dataset(
+            self._input_dataset._variant_tensor,  # pylint: disable=protected-access
+            self._num_parallel_calls,
+            self._dense_defaults,
+            self._sparse_keys,
+            self._dense_keys,
+            self._sparse_types,
+            self._dense_shapes,
+            **self._flat_structure))
     super(_ParseExampleDataset, self).__init__(input_dataset, variant_tensor)

   @property
@@ -19,7 +19,6 @@ from __future__ import print_function

 import functools

-from tensorflow.python.compat import compat
 from tensorflow.python.data.ops import dataset_ops
 from tensorflow.python.data.util import random_seed
 from tensorflow.python.framework import dtypes
@@ -35,12 +34,8 @@ class RandomDatasetV2(dataset_ops.DatasetSource):
   def __init__(self, seed=None):
     """A `Dataset` of pseudorandom values."""
     self._seed, self._seed2 = random_seed.get_seed(seed)
-    if compat.forward_compatible(2019, 8, 3):
-      variant_tensor = gen_experimental_dataset_ops.random_dataset(
-          seed=self._seed, seed2=self._seed2, **self._flat_structure)
-    else:
-      variant_tensor = gen_experimental_dataset_ops.experimental_random_dataset(
-          seed=self._seed, seed2=self._seed2, **self._flat_structure)
+    variant_tensor = gen_experimental_dataset_ops.random_dataset(
+        seed=self._seed, seed2=self._seed2, **self._flat_structure)
     super(RandomDatasetV2, self).__init__(variant_tensor)

   @property
@@ -19,12 +19,11 @@ from __future__ import print_function

 import collections
 import csv
-import gzip
 import functools
+import gzip

 import numpy as np

-from tensorflow.python.compat import compat
 from tensorflow.python.data.experimental.ops import batching
 from tensorflow.python.data.experimental.ops import error_ops
 from tensorflow.python.data.experimental.ops import interleave_ops
@@ -687,30 +686,17 @@ class CsvDatasetV2(dataset_ops.DatasetSource):
     )
     self._element_spec = tuple(
         tensor_spec.TensorSpec([], d.dtype) for d in self._record_defaults)
-    if compat.forward_compatible(2019, 8, 3):
-      variant_tensor = gen_experimental_dataset_ops.csv_dataset(
-          filenames=self._filenames,
-          record_defaults=self._record_defaults,
-          buffer_size=self._buffer_size,
-          header=self._header,
-          output_shapes=self._flat_shapes,
-          field_delim=self._field_delim,
-          use_quote_delim=self._use_quote_delim,
-          na_value=self._na_value,
-          select_cols=self._select_cols,
-          compression_type=self._compression_type)
-    else:
-      variant_tensor = gen_experimental_dataset_ops.experimental_csv_dataset(
-          filenames=self._filenames,
-          record_defaults=self._record_defaults,
-          buffer_size=self._buffer_size,
-          header=self._header,
-          output_shapes=self._flat_shapes,
-          field_delim=self._field_delim,
-          use_quote_delim=self._use_quote_delim,
-          na_value=self._na_value,
-          select_cols=self._select_cols,
-          compression_type=self._compression_type)
+    variant_tensor = gen_experimental_dataset_ops.csv_dataset(
+        filenames=self._filenames,
+        record_defaults=self._record_defaults,
+        buffer_size=self._buffer_size,
+        header=self._header,
+        output_shapes=self._flat_shapes,
+        field_delim=self._field_delim,
+        use_quote_delim=self._use_quote_delim,
+        na_value=self._na_value,
+        select_cols=self._select_cols,
+        compression_type=self._compression_type)
     super(CsvDatasetV2, self).__init__(variant_tensor)

   @property
@@ -993,14 +979,9 @@ class SqlDatasetV2(dataset_ops.DatasetSource):
         query, dtype=dtypes.string, name="query")
     self._element_spec = nest.map_structure(
         lambda dtype: tensor_spec.TensorSpec([], dtype), output_types)
-    if compat.forward_compatible(2019, 8, 3):
-      variant_tensor = gen_experimental_dataset_ops.sql_dataset(
-          self._driver_name, self._data_source_name, self._query,
-          **self._flat_structure)
-    else:
-      variant_tensor = gen_experimental_dataset_ops.experimental_sql_dataset(
-          self._driver_name, self._data_source_name, self._query,
-          **self._flat_structure)
+    variant_tensor = gen_experimental_dataset_ops.sql_dataset(
+        self._driver_name, self._data_source_name, self._query,
+        **self._flat_structure)
     super(SqlDatasetV2, self).__init__(variant_tensor)

   @property
@@ -17,7 +17,6 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function

-from tensorflow.python.compat import compat
 from tensorflow.python.data.ops import dataset_ops
 from tensorflow.python.data.util import nest
 from tensorflow.python.data.util import structure
@@ -121,22 +120,13 @@ class _ScanDataset(dataset_ops.UnaryDataset):
     self._scan_func = wrapped_func
     self._scan_func.function.add_to_graph(ops.get_default_graph())
     # pylint: disable=protected-access
-    if compat.forward_compatible(2019, 8, 3):
-      variant_tensor = gen_experimental_dataset_ops.scan_dataset(
-          self._input_dataset._variant_tensor,
-          structure.to_tensor_list(self._state_structure, self._initial_state),
-          self._scan_func.function.captured_inputs,
-          f=self._scan_func.function,
-          preserve_cardinality=True,
-          **self._flat_structure)
-    else:
-      variant_tensor = gen_experimental_dataset_ops.experimental_scan_dataset(
-          self._input_dataset._variant_tensor,
-          structure.to_tensor_list(self._state_structure, self._initial_state),
-          self._scan_func.function.captured_inputs,
-          f=self._scan_func.function,
-          preserve_cardinality=True,
-          **self._flat_structure)
+    variant_tensor = gen_experimental_dataset_ops.scan_dataset(
+        self._input_dataset._variant_tensor,
+        structure.to_tensor_list(self._state_structure, self._initial_state),
+        self._scan_func.function.captured_inputs,
+        f=self._scan_func.function,
+        preserve_cardinality=True,
+        **self._flat_structure)
     super(_ScanDataset, self).__init__(input_dataset, variant_tensor)

   def _functions(self):
@@ -17,7 +17,6 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function

-from tensorflow.python.compat import compat
 from tensorflow.python.data.ops import dataset_ops
 from tensorflow.python.ops import gen_experimental_dataset_ops

@@ -28,16 +27,10 @@ class _SleepDataset(dataset_ops.UnaryUnchangedStructureDataset):
   def __init__(self, input_dataset, sleep_microseconds):
     self._input_dataset = input_dataset
     self._sleep_microseconds = sleep_microseconds
-    if compat.forward_compatible(2019, 8, 3):
-      variant_tensor = gen_experimental_dataset_ops.sleep_dataset(
-          self._input_dataset._variant_tensor,  # pylint: disable=protected-access
-          self._sleep_microseconds,
-          **self._flat_structure)
-    else:
-      variant_tensor = gen_experimental_dataset_ops.experimental_sleep_dataset(
-          self._input_dataset._variant_tensor,  # pylint: disable=protected-access
-          self._sleep_microseconds,
-          **self._flat_structure)
+    variant_tensor = gen_experimental_dataset_ops.sleep_dataset(
+        self._input_dataset._variant_tensor,  # pylint: disable=protected-access
+        self._sleep_microseconds,
+        **self._flat_structure)
     super(_SleepDataset, self).__init__(input_dataset, variant_tensor)


@@ -19,7 +19,6 @@ from __future__ import print_function

 import tempfile

-from tensorflow.python.compat import compat
 from tensorflow.python.ops import gen_experimental_dataset_ops as ged_ops
 from tensorflow.python.ops import summary_ops_v2
 from tensorflow.python.util.tf_export import tf_export
@@ -126,10 +125,7 @@ class StatsAggregatorV1(object):

   def __init__(self):
     """Creates a `StatsAggregator`."""
-    if compat.forward_compatible(2019, 8, 3):
-      self._resource = ged_ops.stats_aggregator_handle()
-    else:
-      self._resource = ged_ops.experimental_stats_aggregator_handle()
+    self._resource = ged_ops.stats_aggregator_handle()

   def get_summary(self):
     """Returns a string `tf.Tensor` that summarizes the aggregated statistics.
@@ -141,10 +137,7 @@ class StatsAggregatorV1(object):
     Returns:
       A scalar string `tf.Tensor` that summarizes the aggregated statistics.
     """
-    if compat.forward_compatible(2019, 8, 3):
-      return ged_ops.stats_aggregator_summary(self._resource)
-    else:
-      return ged_ops.experimental_stats_aggregator_summary(self._resource)
+    return ged_ops.stats_aggregator_summary(self._resource)


 # TODO(b/116314787): Change this to StatsAggregatorV2 when we have stable
@@ -17,7 +17,6 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function

-from tensorflow.python.compat import compat
 from tensorflow.python.data.ops import dataset_ops
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import ops
@@ -66,14 +65,8 @@ def bytes_produced_stats(tag):
   """

   def _apply_fn(dataset):
-    if compat.forward_compatible(2019, 8, 3):
-      return _StatsDataset(
-          dataset, gen_experimental_dataset_ops.bytes_produced_stats_dataset,
-          tag)
-    else:
-      return _StatsDataset(
-          dataset, gen_experimental_dataset_ops
-          .experimental_bytes_produced_stats_dataset, tag)
+    return _StatsDataset(
+        dataset, gen_experimental_dataset_ops.bytes_produced_stats_dataset, tag)

   return _apply_fn

@@ -95,14 +88,8 @@ def latency_stats(tag):
   """

   def _apply_fn(dataset):
-    if compat.forward_compatible(2019, 8, 3):
-      return _StatsDataset(
-          dataset,
-          gen_experimental_dataset_ops.latency_stats_dataset, tag)
-    else:
-      return _StatsDataset(
-          dataset,
-          gen_experimental_dataset_ops.experimental_latency_stats_dataset, tag)
+    return _StatsDataset(
+        dataset, gen_experimental_dataset_ops.latency_stats_dataset, tag)

   return _apply_fn

@@ -17,7 +17,6 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function

-from tensorflow.python.compat import compat
 from tensorflow.python.data.ops import dataset_ops
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import tensor_spec
@@ -42,18 +41,11 @@ class _TakeWhileDataset(dataset_ops.UnaryUnchangedStructureDataset):
       raise ValueError("`predicate` must return a scalar boolean tensor.")

     self._predicate = wrapped_func
-    if compat.forward_compatible(2019, 8, 3):
-      var_tensor = gen_experimental_dataset_ops.take_while_dataset(
-          self._input_dataset._variant_tensor,  # pylint: disable=protected-access
-          other_arguments=self._predicate.function.captured_inputs,
-          predicate=self._predicate.function,
-          **self._flat_structure)
-    else:
-      var_tensor = gen_experimental_dataset_ops.experimental_take_while_dataset(
-          self._input_dataset._variant_tensor,  # pylint: disable=protected-access
-          other_arguments=self._predicate.function.captured_inputs,
-          predicate=self._predicate.function,
-          **self._flat_structure)
+    var_tensor = gen_experimental_dataset_ops.take_while_dataset(
+        self._input_dataset._variant_tensor,  # pylint: disable=protected-access
+        other_arguments=self._predicate.function.captured_inputs,
+        predicate=self._predicate.function,
+        **self._flat_structure)
     super(_TakeWhileDataset, self).__init__(input_dataset, var_tensor)

   def _functions(self):
@@ -19,7 +19,6 @@ from __future__ import print_function

 import threading

-from tensorflow.python.compat import compat
 from tensorflow.python.data.ops import dataset_ops
 from tensorflow.python.eager import context
 from tensorflow.python.ops import gen_experimental_dataset_ops as ged_ops
@@ -47,31 +46,18 @@ class PrivateThreadPool(object):
     """Creates a `PrivateThreadPool` with the given number of threads."""
     if context.executing_eagerly():
       shared_name = _generate_shared_name("privatethreadpool")
-      if compat.forward_compatible(2019, 8, 3):
-        self._resource = ged_ops.thread_pool_handle(
-            num_threads=num_threads,
-            max_intra_op_parallelism=max_intra_op_parallelism,
-            display_name=display_name,
-            shared_name=shared_name)
-      else:
-        self._resource = ged_ops.experimental_thread_pool_handle(
-            num_threads=num_threads,
-            max_intra_op_parallelism=max_intra_op_parallelism,
-            display_name=display_name,
-            shared_name=shared_name)
+      self._resource = ged_ops.thread_pool_handle(
+          num_threads=num_threads,
+          max_intra_op_parallelism=max_intra_op_parallelism,
+          display_name=display_name,
+          shared_name=shared_name)
       self._resource_deleter = resource_variable_ops.EagerResourceDeleter(
           handle=self._resource, handle_device=context.context().device_name)
     else:
-      if compat.forward_compatible(2019, 8, 3):
-        self._resource = ged_ops.thread_pool_handle(
-            num_threads=num_threads,
-            max_intra_op_parallelism=max_intra_op_parallelism,
-            display_name=display_name)
-      else:
-        self._resource = ged_ops.experimental_thread_pool_handle(
-            num_threads=num_threads,
-            max_intra_op_parallelism=max_intra_op_parallelism,
-            display_name=display_name)
+      self._resource = ged_ops.thread_pool_handle(
+          num_threads=num_threads,
+          max_intra_op_parallelism=max_intra_op_parallelism,
+          display_name=display_name)


 class _ThreadPoolDataset(dataset_ops.UnaryUnchangedStructureDataset):
@@ -80,16 +66,10 @@ class _ThreadPoolDataset(dataset_ops.UnaryUnchangedStructureDataset):
   def __init__(self, input_dataset, thread_pool):
     self._input_dataset = input_dataset
     self._thread_pool = thread_pool
-    if compat.forward_compatible(2019, 8, 3):
-      variant_tensor = ged_ops.thread_pool_dataset(
-          self._input_dataset._variant_tensor,  # pylint: disable=protected-access
-          self._thread_pool._resource,  # pylint: disable=protected-access
-          **self._flat_structure)
-    else:
-      variant_tensor = ged_ops.experimental_thread_pool_dataset(
-          self._input_dataset._variant_tensor,  # pylint: disable=protected-access
-          self._thread_pool._resource,  # pylint: disable=protected-access
-          **self._flat_structure)
+    variant_tensor = ged_ops.thread_pool_dataset(
+        self._input_dataset._variant_tensor,  # pylint: disable=protected-access
+        self._thread_pool._resource,  # pylint: disable=protected-access
+        **self._flat_structure)
     super(_ThreadPoolDataset, self).__init__(input_dataset, variant_tensor)


@@ -17,7 +17,6 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function

-from tensorflow.python.compat import compat
 from tensorflow.python.data.ops import dataset_ops
 from tensorflow.python.framework import dtypes
 from tensorflow.python.ops import gen_experimental_dataset_ops
@@ -60,12 +59,7 @@ class _UniqueDataset(dataset_ops.UnaryUnchangedStructureDataset):
       raise TypeError(
           "`tf.data.experimental.unique()` only supports inputs with a single "
          "`tf.int32`, `tf.int64`, or `tf.string` component.")
-    if compat.forward_compatible(2019, 8, 3):
-      variant_tensor = gen_experimental_dataset_ops.unique_dataset(
-          self._input_dataset._variant_tensor,  # pylint: disable=protected-access
-          **self._flat_structure)
-    else:
-      variant_tensor = gen_experimental_dataset_ops.experimental_unique_dataset(
-          self._input_dataset._variant_tensor,  # pylint: disable=protected-access
-          **self._flat_structure)
+    variant_tensor = gen_experimental_dataset_ops.unique_dataset(
+        self._input_dataset._variant_tensor,  # pylint: disable=protected-access
+        **self._flat_structure)
     super(_UniqueDataset, self).__init__(input_dataset, variant_tensor)
@@ -17,7 +17,6 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function

-from tensorflow.python.compat import compat
 from tensorflow.python.data.ops import dataset_ops
 from tensorflow.python.data.util import convert
 from tensorflow.python.framework import dtypes
@@ -84,9 +83,5 @@ class TFRecordWriter(object):
           "produces shape {0} and types {1}".format(
               dataset_ops.get_legacy_output_shapes(dataset),
               dataset_ops.get_legacy_output_types(dataset)))
-    if compat.forward_compatible(2019, 8, 3):
-      return gen_experimental_dataset_ops.dataset_to_tf_record(
-          dataset._variant_tensor, self._filename, self._compression_type)  # pylint: disable=protected-access
-    else:
-      return gen_experimental_dataset_ops.experimental_dataset_to_tf_record(
-          dataset._variant_tensor, self._filename, self._compression_type)  # pylint: disable=protected-access
+    return gen_experimental_dataset_ops.dataset_to_tf_record(
+        dataset._variant_tensor, self._filename, self._compression_type)  # pylint: disable=protected-access
@@ -31,7 +31,6 @@ from six.moves import queue as Queue  # pylint: disable=redefined-builtin

 from tensorflow.core.framework import graph_pb2
 from tensorflow.python import tf2
-from tensorflow.python.compat import compat
 from tensorflow.python.data.experimental.ops import distribute_options
 from tensorflow.python.data.experimental.ops import optimization_options
 from tensorflow.python.data.experimental.ops import stats_options
@@ -1743,12 +1742,8 @@ class DatasetV1(DatasetV2):
     dataset = self._apply_options()
     if shared_name is None:
       shared_name = ""
-    if compat.forward_compatible(2018, 8, 3):
-      iterator_resource = gen_dataset_ops.iterator_v2(
-          container="", shared_name=shared_name, **self._flat_structure)
-    else:
-      iterator_resource = gen_dataset_ops.iterator(
-          container="", shared_name=shared_name, **self._flat_structure)
+    iterator_resource = gen_dataset_ops.iterator_v2(
+        container="", shared_name=shared_name, **self._flat_structure)
     with ops.colocate_with(iterator_resource):
       initializer = gen_dataset_ops.make_iterator(
           dataset._variant_tensor,  # pylint: disable=protected-access
@@ -3755,20 +3750,12 @@ class _SetStatsAggregatorDataset(UnaryUnchangedStructureDataset):
     self._stats_aggregator = aggregator
     self._prefix = prefix
     self._counter_prefix = counter_prefix
-    if compat.forward_compatible(2019, 8, 3):
-      variant_tensor = ged_ops.set_stats_aggregator_dataset(
-          input_dataset._variant_tensor,  # pylint: disable=protected-access
-          self._stats_aggregator._resource,  # pylint: disable=protected-access
-          self._prefix,
-          self._counter_prefix,
-          **self._flat_structure)
-    else:
-      variant_tensor = ged_ops.experimental_set_stats_aggregator_dataset(
-          input_dataset._variant_tensor,  # pylint: disable=protected-access
-          self._stats_aggregator._resource,  # pylint: disable=protected-access
-          self._prefix,
-          self._counter_prefix,
-          **self._flat_structure)
+    variant_tensor = ged_ops.set_stats_aggregator_dataset(
+        input_dataset._variant_tensor,  # pylint: disable=protected-access
+        self._stats_aggregator._resource,  # pylint: disable=protected-access
+        self._prefix,
+        self._counter_prefix,
+        **self._flat_structure)
     super(_SetStatsAggregatorDataset, self).__init__(input_dataset,
                                                      variant_tensor)

@@ -3782,16 +3769,10 @@ class _MaxIntraOpParallelismDataset(UnaryUnchangedStructureDataset):
         max_intra_op_parallelism,
         dtype=dtypes.int64,
         name="max_intra_op_parallelism")
-    if compat.forward_compatible(2019, 8, 3):
-      variant_tensor = ged_ops.max_intra_op_parallelism_dataset(
-          input_dataset._variant_tensor,  # pylint: disable=protected-access
-          self._max_intra_op_parallelism,
-          **self._flat_structure)
-    else:
-      variant_tensor = ged_ops.experimental_max_intra_op_parallelism_dataset(
-          input_dataset._variant_tensor,  # pylint: disable=protected-access
-          self._max_intra_op_parallelism,
-          **self._flat_structure)
+    variant_tensor = ged_ops.max_intra_op_parallelism_dataset(
+        input_dataset._variant_tensor,  # pylint: disable=protected-access
+        self._max_intra_op_parallelism,
+        **self._flat_structure)
     super(_MaxIntraOpParallelismDataset, self).__init__(input_dataset,
                                                         variant_tensor)

@@ -3803,16 +3784,10 @@ class _PrivateThreadPoolDataset(UnaryUnchangedStructureDataset):
     self._input_dataset = input_dataset
     self._num_threads = ops.convert_to_tensor(
         num_threads, dtype=dtypes.int64, name="num_threads")
-    if compat.forward_compatible(2019, 8, 3):
-      variant_tensor = ged_ops.private_thread_pool_dataset(
-          input_dataset._variant_tensor,  # pylint: disable=protected-access
-          self._num_threads,
-          **self._flat_structure)
-    else:
-      variant_tensor = ged_ops.experimental_private_thread_pool_dataset(
-          input_dataset._variant_tensor,  # pylint: disable=protected-access
-          self._num_threads,
-          **self._flat_structure)
+    variant_tensor = ged_ops.private_thread_pool_dataset(
+        input_dataset._variant_tensor,  # pylint: disable=protected-access
+        self._num_threads,
+        **self._flat_structure)
     super(_PrivateThreadPoolDataset, self).__init__(input_dataset,
                                                     variant_tensor)

@@ -3851,14 +3826,9 @@ class _UnbatchDataset(UnaryDataset):
     self._structure = nest.map_structure(
         lambda component_spec: component_spec._unbatch(),  # pylint: disable=protected-access
         get_structure(input_dataset))
-    if compat.forward_compatible(2019, 8, 3):
-      variant_tensor = ged_ops.unbatch_dataset(
-          self._input_dataset._variant_tensor,  # pylint: disable=protected-access
-          **self._flat_structure)
-    else:
-      variant_tensor = ged_ops.experimental_unbatch_dataset(
-          self._input_dataset._variant_tensor,  # pylint: disable=protected-access
-          **self._flat_structure)
+    variant_tensor = ged_ops.unbatch_dataset(
+        self._input_dataset._variant_tensor,  # pylint: disable=protected-access
+        **self._flat_structure)
     super(_UnbatchDataset, self).__init__(input_dataset, variant_tensor)

   @property
@@ -20,7 +20,6 @@ from __future__ import print_function
 import threading
 import warnings

-from tensorflow.python.compat import compat
 from tensorflow.python.data.ops import optional_ops
 from tensorflow.python.data.util import nest
 from tensorflow.python.data.util import structure
@@ -201,29 +200,22 @@ class Iterator(trackable.Trackable):
         output_types, output_shapes, output_classes)
     if shared_name is None:
       shared_name = ""
-    if compat.forward_compatible(2018, 8, 3):
-      if _device_stack_is_empty():
-        with ops.device("/cpu:0"):
-          iterator_resource = gen_dataset_ops.iterator_v2(
-              container="",
-              shared_name=shared_name,
-              output_types=structure.get_flat_tensor_types(
-                  output_structure),
-              output_shapes=structure.get_flat_tensor_shapes(
-                  output_structure))
-      else:
+    if _device_stack_is_empty():
+      with ops.device("/cpu:0"):
         iterator_resource = gen_dataset_ops.iterator_v2(
             container="",
             shared_name=shared_name,
-            output_types=structure.get_flat_tensor_types(output_structure),
+            output_types=structure.get_flat_tensor_types(
+                output_structure),
             output_shapes=structure.get_flat_tensor_shapes(
                 output_structure))
     else:
-      iterator_resource = gen_dataset_ops.iterator(
+      iterator_resource = gen_dataset_ops.iterator_v2(
           container="",
           shared_name=shared_name,
           output_types=structure.get_flat_tensor_types(output_structure),
-          output_shapes=structure.get_flat_tensor_shapes(output_structure))
+          output_shapes=structure.get_flat_tensor_shapes(
+              output_structure))
     return Iterator(iterator_resource, None, output_types, output_shapes,
                     output_classes)

@@ -291,20 +283,14 @@ class Iterator(trackable.Trackable):
     output_structure = structure.convert_legacy_structure(
         output_types, output_shapes, output_classes)
     string_handle = ops.convert_to_tensor(string_handle, dtype=dtypes.string)
-    if compat.forward_compatible(2018, 8, 3):
-      if _device_stack_is_empty():
-        with ops.device("/cpu:0"):
-          iterator_resource = gen_dataset_ops.iterator_from_string_handle_v2(
-              string_handle,
-              output_types=structure.get_flat_tensor_types(output_structure),
-              output_shapes=structure.get_flat_tensor_shapes(output_structure))
-      else:
+    if _device_stack_is_empty():
+      with ops.device("/cpu:0"):
         iterator_resource = gen_dataset_ops.iterator_from_string_handle_v2(
             string_handle,
             output_types=structure.get_flat_tensor_types(output_structure),
             output_shapes=structure.get_flat_tensor_shapes(output_structure))
     else:
-      iterator_resource = gen_dataset_ops.iterator_from_string_handle(
+      iterator_resource = gen_dataset_ops.iterator_from_string_handle_v2(
          string_handle,
          output_types=structure.get_flat_tensor_types(output_structure),
          output_shapes=structure.get_flat_tensor_shapes(output_structure))
@@ -17,7 +17,6 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function

-from tensorflow.python.compat import compat
 from tensorflow.python.data.ops import dataset_ops
 from tensorflow.python.data.util import convert
 from tensorflow.python.framework import dtypes
@@ -239,28 +238,16 @@ class ParallelInterleaveDataset(dataset_ops.UnaryDataset):
         "prefetch_input_elements",
         prefetch_input_elements,
         argument_default=2 * cycle_length)
-    if compat.forward_compatible(2019, 8, 3):
-      variant_tensor = ged_ops.parallel_interleave_dataset(
-          self._input_dataset._variant_tensor,  # pylint: disable=protected-access
-          self._map_func.function.captured_inputs,
-          self._cycle_length,
-          self._block_length,
-          self._sloppy,
-          self._buffer_output_elements,
-          self._prefetch_input_elements,
-          f=self._map_func.function,
-          **self._flat_structure)
-    else:
-      variant_tensor = ged_ops.experimental_parallel_interleave_dataset(
-          self._input_dataset._variant_tensor,  # pylint: disable=protected-access
-          self._map_func.function.captured_inputs,
-          self._cycle_length,
-          self._block_length,
-          self._sloppy,
-          self._buffer_output_elements,
-          self._prefetch_input_elements,
-          f=self._map_func.function,
-          **self._flat_structure)
+    variant_tensor = ged_ops.parallel_interleave_dataset(
+        self._input_dataset._variant_tensor,  # pylint: disable=protected-access
+        self._map_func.function.captured_inputs,
+        self._cycle_length,
+        self._block_length,
+        self._sloppy,
+        self._buffer_output_elements,
+        self._prefetch_input_elements,
+        f=self._map_func.function,
+        **self._flat_structure)
     super(ParallelInterleaveDataset, self).__init__(input_dataset,
                                                     variant_tensor)

@@ -407,15 +394,9 @@ class _FixedLengthRecordDataset(dataset_ops.DatasetSource):
         compression_type,
         argument_default="",
         argument_dtype=dtypes.string)
-    if (self._compression_type is not None or
-        compat.forward_compatible(2018, 11, 30)):
-      variant_tensor = gen_dataset_ops.fixed_length_record_dataset_v2(
-          self._filenames, self._header_bytes, self._record_bytes,
-          self._footer_bytes, self._buffer_size, self._compression_type)
-    else:
-      variant_tensor = gen_dataset_ops.fixed_length_record_dataset(
-          self._filenames, self._header_bytes, self._record_bytes,
-          self._footer_bytes, self._buffer_size)
+    variant_tensor = gen_dataset_ops.fixed_length_record_dataset_v2(
+        self._filenames, self._header_bytes, self._record_bytes,
+        self._footer_bytes, self._buffer_size, self._compression_type)
     super(_FixedLengthRecordDataset, self).__init__(variant_tensor)

   @property