[tf.data] Forward compatibility cleanup.

PiperOrigin-RevId: 261816972
Author: Jiri Simsa (2019-08-05 18:43:58 -07:00), committed by TensorFlower Gardener
parent 2d6cca0122
commit 9239c61f20
23 changed files with 180 additions and 458 deletions
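Background note: tf.compat.forward_compatible(year, month, day) compares the given date against a horizon constant in tensorflow/python/compat/compat.py that automation advances roughly daily; guards are conventionally written about three weeks in the future so graphs produced by new code still load in slightly stale binaries. By this commit the 2019-08-03 horizon had passed, so every guard below is statically True and its else branch is dead code. A minimal sketch of the pattern being deleted, reusing the unique_dataset ops that appear in this diff (the helper name is hypothetical, for illustration only):

from tensorflow.python.compat import compat
from tensorflow.python.ops import gen_experimental_dataset_ops as ged_ops

def _make_variant(input_dataset, flat_structure):
  # Hypothetical helper, not part of this commit.
  if compat.forward_compatible(2019, 8, 3):
    # Horizon passed: use the op registered without the "Experimental" prefix.
    return ged_ops.unique_dataset(
        input_dataset._variant_tensor,  # pylint: disable=protected-access
        **flat_structure)
  else:
    # Dead branch once the horizon date is in the past; this is what the
    # cleanup removes, file by file, below.
    return ged_ops.experimental_unique_dataset(
        input_dataset._variant_tensor,  # pylint: disable=protected-access
        **flat_structure)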

View File

@@ -362,7 +362,7 @@ class ThreadUtilizationStatsTest(stats_dataset_test_base.StatsDatasetTestBase):
     num_output = 100 // 16 + 1
     self.parallelCallsStats(
-        dataset_fn, {"ExperimentalMapAndBatchDataset"},
+        dataset_fn, {"MapAndBatchDataset"},
         num_output,
         check_elements=False,
         function_processing_time=True)
@@ -391,7 +391,7 @@ class FeatureStatsDatasetTest(
     num_output = total_records // batch_size + 1
     self.parallelCallsStats(
-        dataset_fn, {"ExperimentalParseExampleDataset"},
+        dataset_fn, {"ParseExampleDataset"},
         num_output,
         check_elements=False)
@@ -409,19 +409,19 @@ class FeatureStatsDatasetTest(
     handle = self.getHandle(aggregator)
     self.assertStatisticsHasCount(
         handle,
-        self.regexForNodeName("record_stats::ExperimentalParseExampleDataset",
-                              "features_count"), total_records)
+        self.regexForNodeName("record_stats::ParseExampleDataset",
+                              "features_count"), total_records)
     self.assertStatisticsHasCount(
         handle,
-        self.regexForNodeName("record_stats::ExperimentalParseExampleDataset",
-                              "feature_values_count"), total_records)
+        self.regexForNodeName("record_stats::ParseExampleDataset",
+                              "feature_values_count"), total_records)
     self.assertStatisticsHasSum(
         handle,
-        self.regexForNodeName("record_stats::ExperimentalParseExampleDataset",
-                              "features_count"), total_records * 4)
+        self.regexForNodeName("record_stats::ParseExampleDataset",
+                              "features_count"), total_records * 4)
     self.assertStatisticsHasSum(
         handle,
-        self.regexForNodeName("record_stats::ExperimentalParseExampleDataset",
-                              "feature_values_count"),
-        self._sum_keywords(1) * num_epochs + 3 * total_records)
+        self.regexForNodeName("record_stats::ParseExampleDataset",
+                              "feature_values_count"),
+        self._sum_keywords(1) * num_epochs + 3 * total_records)
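(Before this cleanup, a test wanting to exercise the old op names could pin the horizon explicitly; a hypothetical sketch, not part of this commit:

from tensorflow.python.compat import compat

with compat.forward_compatibility_horizon(2019, 8, 2):
  # Inside this block forward_compatible(2019, 8, 3) returns False, so the
  # code under test would fall back to the experimental_* op names.
  build_and_check_dataset()  # hypothetical test body
)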

View File

@@ -17,7 +17,6 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
-from tensorflow.python.compat import compat
 from tensorflow.python.data.ops import dataset_ops
 from tensorflow.python.data.util import convert
 from tensorflow.python.data.util import nest
@@ -247,18 +246,11 @@ class _DenseToSparseBatchDataset(dataset_ops.UnaryDataset):
         tensor_shape.TensorShape([None]).concatenate(self._row_shape),
         dataset_ops.get_legacy_output_types(input_dataset))
-    if compat.forward_compatible(2019, 8, 3):
-      variant_tensor = ged_ops.dense_to_sparse_batch_dataset(
-          self._input_dataset._variant_tensor,  # pylint: disable=protected-access
-          self._batch_size,
-          row_shape=convert.partial_shape_to_tensor(self._row_shape),
-          **self._flat_structure)
-    else:
-      variant_tensor = ged_ops.experimental_dense_to_sparse_batch_dataset(
-          self._input_dataset._variant_tensor,  # pylint: disable=protected-access
-          self._batch_size,
-          row_shape=convert.partial_shape_to_tensor(self._row_shape),
-          **self._flat_structure)
+    variant_tensor = ged_ops.dense_to_sparse_batch_dataset(
+        self._input_dataset._variant_tensor,  # pylint: disable=protected-access
+        self._batch_size,
+        row_shape=convert.partial_shape_to_tensor(self._row_shape),
+        **self._flat_structure)
     super(_DenseToSparseBatchDataset, self).__init__(input_dataset,
                                                      variant_tensor)
@@ -302,7 +294,6 @@ class _MapAndBatchDataset(dataset_ops.UnaryDataset):
         lambda component_spec: component_spec._batch(None),
         self._map_func.output_structure)
     # pylint: enable=protected-access
-    if compat.forward_compatible(2019, 8, 3):
-      variant_tensor = ged_ops.map_and_batch_dataset(
-          self._input_dataset._variant_tensor,  # pylint: disable=protected-access
-          self._map_func.function.captured_inputs,
+    variant_tensor = ged_ops.map_and_batch_dataset(
+        self._input_dataset._variant_tensor,  # pylint: disable=protected-access
+        self._map_func.function.captured_inputs,
@@ -312,16 +303,6 @@ class _MapAndBatchDataset(dataset_ops.UnaryDataset):
-          drop_remainder=self._drop_remainder_t,
-          preserve_cardinality=True,
-          **self._flat_structure)
-    else:
-      variant_tensor = ged_ops.experimental_map_and_batch_dataset(
-          self._input_dataset._variant_tensor,  # pylint: disable=protected-access
-          self._map_func.function.captured_inputs,
-          f=self._map_func.function,
-          batch_size=self._batch_size_t,
-          num_parallel_calls=self._num_parallel_calls_t,
-          drop_remainder=self._drop_remainder_t,
-          preserve_cardinality=True,
-          **self._flat_structure)
+        drop_remainder=self._drop_remainder_t,
+        preserve_cardinality=True,
+        **self._flat_structure)
     super(_MapAndBatchDataset, self).__init__(input_dataset, variant_tensor)
 
   def _functions(self):

View File

@@ -17,7 +17,6 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
-from tensorflow.python.compat import compat
 from tensorflow.python.ops import gen_experimental_dataset_ops as ged_ops
 from tensorflow.python.util.tf_export import tf_export
@@ -49,7 +48,4 @@ def cardinality(dataset):
     constant `INFINITE_CARDINALITY` and `UNKNOWN_CARDINALITY` respectively.
   """
-  if compat.forward_compatible(2019, 8, 3):
-    return ged_ops.dataset_cardinality(dataset._variant_tensor)  # pylint: disable=protected-access
-  else:
-    return ged_ops.experimental_dataset_cardinality(dataset._variant_tensor)  # pylint: disable=protected-access
+  return ged_ops.dataset_cardinality(dataset._variant_tensor)  # pylint: disable=protected-access
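(For reference, a usage sketch of the public wrapper around this op, assuming TF 2.x eager execution:

import tensorflow as tf

ds = tf.data.Dataset.range(10).batch(3)
print(tf.data.experimental.cardinality(ds).numpy())  # 4, i.e. ceil(10 / 3)

repeated = tf.data.Dataset.range(10).repeat()
assert (tf.data.experimental.cardinality(repeated).numpy() ==
        tf.data.experimental.INFINITE_CARDINALITY)
)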

View File

@@ -49,18 +49,11 @@ class _AutoShardDataset(dataset_ops.UnaryDataset):
     self._input_dataset = input_dataset
     self._element_spec = input_dataset.element_spec
-    if compat.forward_compatible(2019, 8, 3):
-      variant_tensor = ged_ops.auto_shard_dataset(
-          self._input_dataset._variant_tensor,  # pylint: disable=protected-access
-          num_workers=num_workers,
-          index=index,
-          **self._flat_structure)
-    else:
-      variant_tensor = ged_ops.experimental_auto_shard_dataset(
-          self._input_dataset._variant_tensor,  # pylint: disable=protected-access
-          num_workers=num_workers,
-          index=index,
-          **self._flat_structure)
+    variant_tensor = ged_ops.auto_shard_dataset(
+        self._input_dataset._variant_tensor,  # pylint: disable=protected-access
+        num_workers=num_workers,
+        index=index,
+        **self._flat_structure)
     super(_AutoShardDataset, self).__init__(input_dataset, variant_tensor)
 
   @property
@@ -112,13 +105,8 @@ class _RebatchDataset(dataset_ops.UnaryDataset):
           num_workers=num_workers,
           use_fallback=use_fallback,
           **self._flat_structure)
-    elif compat.forward_compatible(2019, 8, 3):
-      variant_tensor = ged_ops.rebatch_dataset(
-          self._input_dataset._variant_tensor,  # pylint: disable=protected-access
-          num_workers=num_workers,
-          **self._flat_structure)
     else:
-      variant_tensor = ged_ops.experimental_rebatch_dataset(
+      variant_tensor = ged_ops.rebatch_dataset(
           self._input_dataset._variant_tensor,  # pylint: disable=protected-access
           num_workers=num_workers,
           **self._flat_structure)

View File

@@ -17,7 +17,6 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
-from tensorflow.python.compat import compat
 from tensorflow.python.data.ops import dataset_ops
 from tensorflow.python.ops import gen_experimental_dataset_ops
@@ -60,14 +59,8 @@ class _IgnoreErrorsDataset(dataset_ops.UnaryUnchangedStructureDataset):
   def __init__(self, input_dataset):
     """See `Dataset.ignore_errors()` for details."""
     self._input_dataset = input_dataset
-    if compat.forward_compatible(2019, 8, 3):
-      variant_tensor = (
-          gen_experimental_dataset_ops.ignore_errors_dataset(
-              self._input_dataset._variant_tensor,  # pylint: disable=protected-access
-              **self._flat_structure))
-    else:
-      variant_tensor = (
-          gen_experimental_dataset_ops.experimental_ignore_errors_dataset(
-              self._input_dataset._variant_tensor,  # pylint: disable=protected-access
-              **self._flat_structure))
+    variant_tensor = (
+        gen_experimental_dataset_ops.ignore_errors_dataset(
+            self._input_dataset._variant_tensor,  # pylint: disable=protected-access
+            **self._flat_structure))
     super(_IgnoreErrorsDataset, self).__init__(input_dataset, variant_tensor)

View File

@@ -19,7 +19,6 @@ from __future__ import print_function
 
 import numpy as np
 
-from tensorflow.python.compat import compat
 from tensorflow.python.data.ops import dataset_ops
 from tensorflow.python.data.util import nest
 from tensorflow.python.data.util import structure
@@ -255,7 +254,6 @@ class _GroupByReducerDataset(dataset_ops.UnaryDataset):
     self._make_init_func(reducer.init_func)
     self._make_reduce_func(reducer.reduce_func, input_dataset)
     self._make_finalize_func(reducer.finalize_func)
-    if compat.forward_compatible(2019, 8, 3):
-      variant_tensor = ged_ops.experimental_group_by_reducer_dataset(
-          self._input_dataset._variant_tensor,  # pylint: disable=protected-access
-          self._key_func.function.captured_inputs,
+    variant_tensor = ged_ops.experimental_group_by_reducer_dataset(
+        self._input_dataset._variant_tensor,  # pylint: disable=protected-access
+        self._key_func.function.captured_inputs,
@@ -267,18 +265,6 @@ class _GroupByReducerDataset(dataset_ops.UnaryDataset):
-          reduce_func=self._reduce_func.function,
-          finalize_func=self._finalize_func.function,
-          **self._flat_structure)
-    else:
-      variant_tensor = ged_ops.group_by_reducer_dataset(
-          self._input_dataset._variant_tensor,  # pylint: disable=protected-access
-          self._key_func.function.captured_inputs,
-          self._init_func.function.captured_inputs,
-          self._reduce_func.function.captured_inputs,
-          self._finalize_func.function.captured_inputs,
-          key_func=self._key_func.function,
-          init_func=self._init_func.function,
-          reduce_func=self._reduce_func.function,
-          finalize_func=self._finalize_func.function,
-          **self._flat_structure)
+        reduce_func=self._reduce_func.function,
+        finalize_func=self._finalize_func.function,
+        **self._flat_structure)
     super(_GroupByReducerDataset, self).__init__(input_dataset, variant_tensor)
 
   def _make_key_func(self, key_func, input_dataset):
@@ -390,7 +376,6 @@ class _GroupByWindowDataset(dataset_ops.UnaryDataset):
     self._make_key_func(key_func, input_dataset)
     self._make_reduce_func(reduce_func, input_dataset)
     self._make_window_size_func(window_size_func)
-    if compat.forward_compatible(2019, 8, 3):
-      variant_tensor = ged_ops.group_by_window_dataset(
-          self._input_dataset._variant_tensor,  # pylint: disable=protected-access
-          self._key_func.function.captured_inputs,
+    variant_tensor = ged_ops.group_by_window_dataset(
+        self._input_dataset._variant_tensor,  # pylint: disable=protected-access
+        self._key_func.function.captured_inputs,
@@ -400,16 +385,6 @@ class _GroupByWindowDataset(dataset_ops.UnaryDataset):
-          reduce_func=self._reduce_func.function,
-          window_size_func=self._window_size_func.function,
-          **self._flat_structure)
-    else:
-      variant_tensor = ged_ops.experimental_group_by_window_dataset(
-          self._input_dataset._variant_tensor,  # pylint: disable=protected-access
-          self._key_func.function.captured_inputs,
-          self._reduce_func.function.captured_inputs,
-          self._window_size_func.function.captured_inputs,
-          key_func=self._key_func.function,
-          reduce_func=self._reduce_func.function,
-          window_size_func=self._window_size_func.function,
-          **self._flat_structure)
+        reduce_func=self._reduce_func.function,
+        window_size_func=self._window_size_func.function,
+        **self._flat_structure)
     super(_GroupByWindowDataset, self).__init__(input_dataset, variant_tensor)
 
   def _make_window_size_func(self, window_size_func):

View File

@@ -17,7 +17,6 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
-from tensorflow.python.compat import compat
 from tensorflow.python.data.experimental.ops import random_ops
 from tensorflow.python.data.ops import dataset_ops
 from tensorflow.python.data.ops import readers
@@ -127,19 +126,11 @@ class _DirectedInterleaveDataset(dataset_ops.Dataset):
   def _as_variant_tensor(self):
     # pylint: disable=protected-access
-    if compat.forward_compatible(2019, 8, 3):
-      return (
-          gen_experimental_dataset_ops.directed_interleave_dataset(
-              self._selector_input._variant_tensor,
-              [data_input._variant_tensor for data_input in self._data_inputs],
-              **self._flat_structure))
-    else:
-      return (
-          gen_experimental_dataset_ops.experimental_directed_interleave_dataset(
-              self._selector_input._variant_tensor,
-              [data_input._variant_tensor for data_input in self._data_inputs],
-              **self._flat_structure))
-    # pylint: enable=protected-access
+    return (
+        gen_experimental_dataset_ops.directed_interleave_dataset(
+            self._selector_input._variant_tensor,
+            [data_input._variant_tensor for data_input in self._data_inputs],
+            **self._flat_structure))
 
   def _inputs(self):
     return [self._selector_input] + self._data_inputs

View File

@@ -18,7 +18,6 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
-from tensorflow.python.compat import compat
 from tensorflow.python.data.ops import dataset_ops
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import ops
@@ -32,11 +31,7 @@ class MatchingFilesDataset(dataset_ops.DatasetSource):
   def __init__(self, patterns):
     self._patterns = ops.convert_to_tensor(
         patterns, dtype=dtypes.string, name="patterns")
-    if compat.forward_compatible(2019, 8, 3):
-      variant_tensor = ged_ops.matching_files_dataset(self._patterns)
-    else:
-      variant_tensor = ged_ops.experimental_matching_files_dataset(
-          self._patterns)
+    variant_tensor = ged_ops.matching_files_dataset(self._patterns)
     super(MatchingFilesDataset, self).__init__(variant_tensor)
 
   @property

View File

@@ -17,7 +17,6 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
-from tensorflow.python.compat import compat
 from tensorflow.python.data.ops import dataset_ops
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import ops
@@ -105,18 +104,11 @@ class _AssertNextDataset(dataset_ops.UnaryUnchangedStructureDataset):
       raise ValueError("At least one transformation should be specified")
     self._transformations = ops.convert_to_tensor(
         transformations, dtype=dtypes.string, name="transformations")
-    if compat.forward_compatible(2019, 8, 3):
-      variant_tensor = (
-          gen_experimental_dataset_ops.assert_next_dataset(
-              self._input_dataset._variant_tensor,  # pylint: disable=protected-access
-              self._transformations,
-              **self._flat_structure))
-    else:
-      variant_tensor = (
-          gen_experimental_dataset_ops.experimental_assert_next_dataset(
-              self._input_dataset._variant_tensor,  # pylint: disable=protected-access
-              self._transformations,
-              **self._flat_structure))
+    variant_tensor = (
+        gen_experimental_dataset_ops.assert_next_dataset(
+            self._input_dataset._variant_tensor,  # pylint: disable=protected-access
+            self._transformations,
+            **self._flat_structure))
     super(_AssertNextDataset, self).__init__(input_dataset, variant_tensor)
@@ -126,16 +118,10 @@ class _NonSerializableDataset(dataset_ops.UnaryUnchangedStructureDataset):
 
   def __init__(self, input_dataset):
     """See `non_serializable()` for details."""
     self._input_dataset = input_dataset
-    if compat.forward_compatible(2019, 8, 3):
-      variant_tensor = (
-          gen_experimental_dataset_ops.non_serializable_dataset(
-              self._input_dataset._variant_tensor,  # pylint: disable=protected-access
-              **self._flat_structure))
-    else:
-      variant_tensor = (
-          gen_experimental_dataset_ops.experimental_non_serializable_dataset(
-              self._input_dataset._variant_tensor,  # pylint: disable=protected-access
-              **self._flat_structure))
+    variant_tensor = (
+        gen_experimental_dataset_ops.non_serializable_dataset(
+            self._input_dataset._variant_tensor,  # pylint: disable=protected-access
+            **self._flat_structure))
     super(_NonSerializableDataset, self).__init__(input_dataset, variant_tensor)
@@ -171,18 +157,11 @@ class _ChooseFastestDataset(dataset_ops.DatasetV2):
     """
     self._datasets = list(datasets)
     self._element_spec = self._datasets[0].element_spec
-    if compat.forward_compatible(2019, 8, 3):
-      variant_tensor = (
-          gen_experimental_dataset_ops.choose_fastest_dataset(
-              [dataset._variant_tensor for dataset in self._datasets],  # pylint: disable=protected-access
-              num_experiments=num_experiments,
-              **self._flat_structure))
-    else:
-      variant_tensor = (
-          gen_experimental_dataset_ops.experimental_choose_fastest_dataset(
-              [dataset._variant_tensor for dataset in self._datasets],  # pylint: disable=protected-access
-              num_experiments=num_experiments,
-              **self._flat_structure))
+    variant_tensor = (
+        gen_experimental_dataset_ops.choose_fastest_dataset(
+            [dataset._variant_tensor for dataset in self._datasets],  # pylint: disable=protected-access
+            num_experiments=num_experiments,
+            **self._flat_structure))
     super(_ChooseFastestDataset, self).__init__(variant_tensor)
 
   def _inputs(self):

View File

@@ -17,7 +17,6 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
-from tensorflow.python.compat import compat
 from tensorflow.python.data.ops import dataset_ops
 from tensorflow.python.data.util import structure
 from tensorflow.python.framework import dtypes
@@ -81,7 +80,6 @@ class _ParseExampleDataset(dataset_ops.UnaryDataset):
     self._element_spec = structure.convert_legacy_structure(
         output_types, output_shapes, output_classes)
-    if compat.forward_compatible(2019, 8, 3):
-      variant_tensor = (
-          gen_experimental_dataset_ops.parse_example_dataset(
-              self._input_dataset._variant_tensor,  # pylint: disable=protected-access
+    variant_tensor = (
+        gen_experimental_dataset_ops.parse_example_dataset(
+            self._input_dataset._variant_tensor,  # pylint: disable=protected-access
@@ -92,17 +90,6 @@ class _ParseExampleDataset(dataset_ops.UnaryDataset):
-              self._sparse_types,
-              self._dense_shapes,
-              **self._flat_structure))
-    else:
-      variant_tensor = (
-          gen_experimental_dataset_ops.experimental_parse_example_dataset(
-              self._input_dataset._variant_tensor,  # pylint: disable=protected-access
-              self._num_parallel_calls,
-              self._dense_defaults,
-              self._sparse_keys,
-              self._dense_keys,
-              self._sparse_types,
-              self._dense_shapes,
-              **self._flat_structure))
+            self._sparse_types,
+            self._dense_shapes,
+            **self._flat_structure))
     super(_ParseExampleDataset, self).__init__(input_dataset, variant_tensor)
 
   @property

View File

@@ -19,7 +19,6 @@ from __future__ import print_function
 
 import functools
 
-from tensorflow.python.compat import compat
 from tensorflow.python.data.ops import dataset_ops
 from tensorflow.python.data.util import random_seed
 from tensorflow.python.framework import dtypes
@@ -35,12 +34,8 @@ class RandomDatasetV2(dataset_ops.DatasetSource):
 
   def __init__(self, seed=None):
     """A `Dataset` of pseudorandom values."""
     self._seed, self._seed2 = random_seed.get_seed(seed)
-    if compat.forward_compatible(2019, 8, 3):
-      variant_tensor = gen_experimental_dataset_ops.random_dataset(
-          seed=self._seed, seed2=self._seed2, **self._flat_structure)
-    else:
-      variant_tensor = gen_experimental_dataset_ops.experimental_random_dataset(
-          seed=self._seed, seed2=self._seed2, **self._flat_structure)
+    variant_tensor = gen_experimental_dataset_ops.random_dataset(
+        seed=self._seed, seed2=self._seed2, **self._flat_structure)
     super(RandomDatasetV2, self).__init__(variant_tensor)
 
   @property

View File

@@ -19,12 +19,11 @@ from __future__ import print_function
 
 import collections
 import csv
-import gzip
 import functools
+import gzip
 
 import numpy as np
 
-from tensorflow.python.compat import compat
 from tensorflow.python.data.experimental.ops import batching
 from tensorflow.python.data.experimental.ops import error_ops
 from tensorflow.python.data.experimental.ops import interleave_ops
@@ -687,7 +686,6 @@ class CsvDatasetV2(dataset_ops.DatasetSource):
     )
     self._element_spec = tuple(
         tensor_spec.TensorSpec([], d.dtype) for d in self._record_defaults)
-    if compat.forward_compatible(2019, 8, 3):
-      variant_tensor = gen_experimental_dataset_ops.csv_dataset(
-          filenames=self._filenames,
-          record_defaults=self._record_defaults,
+    variant_tensor = gen_experimental_dataset_ops.csv_dataset(
+        filenames=self._filenames,
+        record_defaults=self._record_defaults,
@@ -699,18 +697,6 @@ class CsvDatasetV2(dataset_ops.DatasetSource):
-          na_value=self._na_value,
-          select_cols=self._select_cols,
-          compression_type=self._compression_type)
-    else:
-      variant_tensor = gen_experimental_dataset_ops.experimental_csv_dataset(
-          filenames=self._filenames,
-          record_defaults=self._record_defaults,
-          buffer_size=self._buffer_size,
-          header=self._header,
-          output_shapes=self._flat_shapes,
-          field_delim=self._field_delim,
-          use_quote_delim=self._use_quote_delim,
-          na_value=self._na_value,
-          select_cols=self._select_cols,
-          compression_type=self._compression_type)
+        na_value=self._na_value,
+        select_cols=self._select_cols,
+        compression_type=self._compression_type)
     super(CsvDatasetV2, self).__init__(variant_tensor)
 
   @property
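(A usage sketch of the public CsvDataset wrapper whose internals change above; the file path and column layout are made up for illustration:

import tensorflow as tf

# Assumes /tmp/example.csv holds rows like: 1,2.5,hello
ds = tf.data.experimental.CsvDataset(
    "/tmp/example.csv",
    record_defaults=[tf.int32, tf.float32, tf.string],
    header=False)
for features in ds:  # each element is a tuple of three scalar tensors
  print([t.numpy() for t in features])
)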
@@ -993,14 +979,9 @@ class SqlDatasetV2(dataset_ops.DatasetSource):
         query, dtype=dtypes.string, name="query")
     self._element_spec = nest.map_structure(
         lambda dtype: tensor_spec.TensorSpec([], dtype), output_types)
-    if compat.forward_compatible(2019, 8, 3):
-      variant_tensor = gen_experimental_dataset_ops.sql_dataset(
-          self._driver_name, self._data_source_name, self._query,
-          **self._flat_structure)
-    else:
-      variant_tensor = gen_experimental_dataset_ops.experimental_sql_dataset(
-          self._driver_name, self._data_source_name, self._query,
-          **self._flat_structure)
+    variant_tensor = gen_experimental_dataset_ops.sql_dataset(
+        self._driver_name, self._data_source_name, self._query,
+        **self._flat_structure)
     super(SqlDatasetV2, self).__init__(variant_tensor)
 
   @property

View File

@@ -17,7 +17,6 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
-from tensorflow.python.compat import compat
 from tensorflow.python.data.ops import dataset_ops
 from tensorflow.python.data.util import nest
 from tensorflow.python.data.util import structure
@@ -121,7 +120,6 @@ class _ScanDataset(dataset_ops.UnaryDataset):
     self._scan_func = wrapped_func
     self._scan_func.function.add_to_graph(ops.get_default_graph())
     # pylint: disable=protected-access
-    if compat.forward_compatible(2019, 8, 3):
-      variant_tensor = gen_experimental_dataset_ops.scan_dataset(
-          self._input_dataset._variant_tensor,
-          structure.to_tensor_list(self._state_structure, self._initial_state),
+    variant_tensor = gen_experimental_dataset_ops.scan_dataset(
+        self._input_dataset._variant_tensor,
+        structure.to_tensor_list(self._state_structure, self._initial_state),
@@ -129,14 +127,6 @@ class _ScanDataset(dataset_ops.UnaryDataset):
-          f=self._scan_func.function,
-          preserve_cardinality=True,
-          **self._flat_structure)
-    else:
-      variant_tensor = gen_experimental_dataset_ops.experimental_scan_dataset(
-          self._input_dataset._variant_tensor,
-          structure.to_tensor_list(self._state_structure, self._initial_state),
-          self._scan_func.function.captured_inputs,
-          f=self._scan_func.function,
-          preserve_cardinality=True,
-          **self._flat_structure)
+        f=self._scan_func.function,
+        preserve_cardinality=True,
+        **self._flat_structure)
     super(_ScanDataset, self).__init__(input_dataset, variant_tensor)
 
   def _functions(self):

View File

@@ -17,7 +17,6 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
-from tensorflow.python.compat import compat
 from tensorflow.python.data.ops import dataset_ops
 from tensorflow.python.ops import gen_experimental_dataset_ops
@@ -28,16 +27,10 @@ class _SleepDataset(dataset_ops.UnaryUnchangedStructureDataset):
 
   def __init__(self, input_dataset, sleep_microseconds):
     self._input_dataset = input_dataset
     self._sleep_microseconds = sleep_microseconds
-    if compat.forward_compatible(2019, 8, 3):
-      variant_tensor = gen_experimental_dataset_ops.sleep_dataset(
-          self._input_dataset._variant_tensor,  # pylint: disable=protected-access
-          self._sleep_microseconds,
-          **self._flat_structure)
-    else:
-      variant_tensor = gen_experimental_dataset_ops.experimental_sleep_dataset(
-          self._input_dataset._variant_tensor,  # pylint: disable=protected-access
-          self._sleep_microseconds,
-          **self._flat_structure)
+    variant_tensor = gen_experimental_dataset_ops.sleep_dataset(
+        self._input_dataset._variant_tensor,  # pylint: disable=protected-access
+        self._sleep_microseconds,
+        **self._flat_structure)
     super(_SleepDataset, self).__init__(input_dataset, variant_tensor)

View File

@@ -19,7 +19,6 @@ from __future__ import print_function
 
 import tempfile
 
-from tensorflow.python.compat import compat
 from tensorflow.python.ops import gen_experimental_dataset_ops as ged_ops
 from tensorflow.python.ops import summary_ops_v2
 from tensorflow.python.util.tf_export import tf_export
@@ -126,10 +125,7 @@ class StatsAggregatorV1(object):
   def __init__(self):
     """Creates a `StatsAggregator`."""
-    if compat.forward_compatible(2019, 8, 3):
-      self._resource = ged_ops.stats_aggregator_handle()
-    else:
-      self._resource = ged_ops.experimental_stats_aggregator_handle()
+    self._resource = ged_ops.stats_aggregator_handle()
 
   def get_summary(self):
     """Returns a string `tf.Tensor` that summarizes the aggregated statistics.
@@ -141,10 +137,7 @@ class StatsAggregatorV1(object):
     Returns:
       A scalar string `tf.Tensor` that summarizes the aggregated statistics.
     """
-    if compat.forward_compatible(2019, 8, 3):
-      return ged_ops.stats_aggregator_summary(self._resource)
-    else:
-      return ged_ops.experimental_stats_aggregator_summary(self._resource)
+    return ged_ops.stats_aggregator_summary(self._resource)
 
 # TODO(b/116314787): Change this to StatsAggregatorV2 when we have stable

View File

@@ -17,7 +17,6 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
-from tensorflow.python.compat import compat
 from tensorflow.python.data.ops import dataset_ops
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import ops
@@ -66,14 +65,8 @@ def bytes_produced_stats(tag):
   """
 
   def _apply_fn(dataset):
-    if compat.forward_compatible(2019, 8, 3):
-      return _StatsDataset(
-          dataset, gen_experimental_dataset_ops.bytes_produced_stats_dataset,
-          tag)
-    else:
-      return _StatsDataset(
-          dataset, gen_experimental_dataset_ops
-          .experimental_bytes_produced_stats_dataset, tag)
+    return _StatsDataset(
+        dataset, gen_experimental_dataset_ops.bytes_produced_stats_dataset, tag)
 
   return _apply_fn
@@ -95,14 +88,8 @@ def latency_stats(tag):
   """
 
   def _apply_fn(dataset):
-    if compat.forward_compatible(2019, 8, 3):
-      return _StatsDataset(
-          dataset,
-          gen_experimental_dataset_ops.latency_stats_dataset, tag)
-    else:
-      return _StatsDataset(
-          dataset,
-          gen_experimental_dataset_ops.experimental_latency_stats_dataset, tag)
+    return _StatsDataset(
+        dataset, gen_experimental_dataset_ops.latency_stats_dataset, tag)
 
   return _apply_fn

View File

@@ -17,7 +17,6 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
-from tensorflow.python.compat import compat
 from tensorflow.python.data.ops import dataset_ops
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import tensor_spec
@@ -42,18 +41,11 @@ class _TakeWhileDataset(dataset_ops.UnaryUnchangedStructureDataset):
       raise ValueError("`predicate` must return a scalar boolean tensor.")
 
     self._predicate = wrapped_func
-    if compat.forward_compatible(2019, 8, 3):
-      var_tensor = gen_experimental_dataset_ops.take_while_dataset(
-          self._input_dataset._variant_tensor,  # pylint: disable=protected-access
-          other_arguments=self._predicate.function.captured_inputs,
-          predicate=self._predicate.function,
-          **self._flat_structure)
-    else:
-      var_tensor = gen_experimental_dataset_ops.experimental_take_while_dataset(
-          self._input_dataset._variant_tensor,  # pylint: disable=protected-access
-          other_arguments=self._predicate.function.captured_inputs,
-          predicate=self._predicate.function,
-          **self._flat_structure)
+    var_tensor = gen_experimental_dataset_ops.take_while_dataset(
+        self._input_dataset._variant_tensor,  # pylint: disable=protected-access
+        other_arguments=self._predicate.function.captured_inputs,
+        predicate=self._predicate.function,
+        **self._flat_structure)
     super(_TakeWhileDataset, self).__init__(input_dataset, var_tensor)
 
   def _functions(self):
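(A usage sketch of the public wrapper for this dataset, assuming TF 2.x eager execution:

import tensorflow as tf

ds = tf.data.Dataset.range(10).apply(
    tf.data.experimental.take_while(lambda x: x < 5))
print([int(x) for x in ds])  # [0, 1, 2, 3, 4]
)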

View File

@@ -19,7 +19,6 @@ from __future__ import print_function
 
 import threading
 
-from tensorflow.python.compat import compat
 from tensorflow.python.data.ops import dataset_ops
 from tensorflow.python.eager import context
 from tensorflow.python.ops import gen_experimental_dataset_ops as ged_ops
@@ -47,31 +46,18 @@ class PrivateThreadPool(object):
     """Creates a `PrivateThreadPool` with the given number of threads."""
     if context.executing_eagerly():
       shared_name = _generate_shared_name("privatethreadpool")
-      if compat.forward_compatible(2019, 8, 3):
-        self._resource = ged_ops.thread_pool_handle(
-            num_threads=num_threads,
-            max_intra_op_parallelism=max_intra_op_parallelism,
-            display_name=display_name,
-            shared_name=shared_name)
-      else:
-        self._resource = ged_ops.experimental_thread_pool_handle(
-            num_threads=num_threads,
-            max_intra_op_parallelism=max_intra_op_parallelism,
-            display_name=display_name,
-            shared_name=shared_name)
+      self._resource = ged_ops.thread_pool_handle(
+          num_threads=num_threads,
+          max_intra_op_parallelism=max_intra_op_parallelism,
+          display_name=display_name,
+          shared_name=shared_name)
       self._resource_deleter = resource_variable_ops.EagerResourceDeleter(
           handle=self._resource, handle_device=context.context().device_name)
     else:
-      if compat.forward_compatible(2019, 8, 3):
-        self._resource = ged_ops.thread_pool_handle(
-            num_threads=num_threads,
-            max_intra_op_parallelism=max_intra_op_parallelism,
-            display_name=display_name)
-      else:
-        self._resource = ged_ops.experimental_thread_pool_handle(
-            num_threads=num_threads,
-            max_intra_op_parallelism=max_intra_op_parallelism,
-            display_name=display_name)
+      self._resource = ged_ops.thread_pool_handle(
+          num_threads=num_threads,
+          max_intra_op_parallelism=max_intra_op_parallelism,
+          display_name=display_name)
 
 
 class _ThreadPoolDataset(dataset_ops.UnaryUnchangedStructureDataset):
@@ -80,16 +66,10 @@ class _ThreadPoolDataset(dataset_ops.UnaryUnchangedStructureDataset):
   def __init__(self, input_dataset, thread_pool):
     self._input_dataset = input_dataset
     self._thread_pool = thread_pool
-    if compat.forward_compatible(2019, 8, 3):
-      variant_tensor = ged_ops.thread_pool_dataset(
-          self._input_dataset._variant_tensor,  # pylint: disable=protected-access
-          self._thread_pool._resource,  # pylint: disable=protected-access
-          **self._flat_structure)
-    else:
-      variant_tensor = ged_ops.experimental_thread_pool_dataset(
-          self._input_dataset._variant_tensor,  # pylint: disable=protected-access
-          self._thread_pool._resource,  # pylint: disable=protected-access
-          **self._flat_structure)
+    variant_tensor = ged_ops.thread_pool_dataset(
+        self._input_dataset._variant_tensor,  # pylint: disable=protected-access
+        self._thread_pool._resource,  # pylint: disable=protected-access
+        **self._flat_structure)
     super(_ThreadPoolDataset, self).__init__(input_dataset, variant_tensor)

View File

@@ -17,7 +17,6 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
-from tensorflow.python.compat import compat
 from tensorflow.python.data.ops import dataset_ops
 from tensorflow.python.framework import dtypes
 from tensorflow.python.ops import gen_experimental_dataset_ops
@@ -60,12 +59,7 @@ class _UniqueDataset(dataset_ops.UnaryUnchangedStructureDataset):
       raise TypeError(
          "`tf.data.experimental.unique()` only supports inputs with a single "
          "`tf.int32`, `tf.int64`, or `tf.string` component.")
-    if compat.forward_compatible(2019, 8, 3):
-      variant_tensor = gen_experimental_dataset_ops.unique_dataset(
-          self._input_dataset._variant_tensor,  # pylint: disable=protected-access
-          **self._flat_structure)
-    else:
-      variant_tensor = gen_experimental_dataset_ops.experimental_unique_dataset(
-          self._input_dataset._variant_tensor,  # pylint: disable=protected-access
-          **self._flat_structure)
+    variant_tensor = gen_experimental_dataset_ops.unique_dataset(
+        self._input_dataset._variant_tensor,  # pylint: disable=protected-access
+        **self._flat_structure)
     super(_UniqueDataset, self).__init__(input_dataset, variant_tensor)
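(A usage sketch of the public wrapper for this dataset, assuming TF 2.x eager execution:

import tensorflow as tf

ds = tf.data.Dataset.from_tensor_slices([1, 2, 1, 3, 2]).apply(
    tf.data.experimental.unique())
print([int(x) for x in ds])  # [1, 2, 3], in first-occurrence order
)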

View File

@@ -17,7 +17,6 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
-from tensorflow.python.compat import compat
 from tensorflow.python.data.ops import dataset_ops
 from tensorflow.python.data.util import convert
 from tensorflow.python.framework import dtypes
@@ -84,9 +83,5 @@ class TFRecordWriter(object):
           "produces shape {0} and types {1}".format(
               dataset_ops.get_legacy_output_shapes(dataset),
               dataset_ops.get_legacy_output_types(dataset)))
-    if compat.forward_compatible(2019, 8, 3):
-      return gen_experimental_dataset_ops.dataset_to_tf_record(
-          dataset._variant_tensor, self._filename, self._compression_type)  # pylint: disable=protected-access
-    else:
-      return gen_experimental_dataset_ops.experimental_dataset_to_tf_record(
-          dataset._variant_tensor, self._filename, self._compression_type)  # pylint: disable=protected-access
+    return gen_experimental_dataset_ops.dataset_to_tf_record(
+        dataset._variant_tensor, self._filename, self._compression_type)  # pylint: disable=protected-access
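(A usage sketch of the TFRecordWriter wrapper whose write() method changes above; the output path is made up for illustration:

import tensorflow as tf

ds = tf.data.Dataset.from_tensor_slices([b"a", b"b", b"c"])
writer = tf.data.experimental.TFRecordWriter("/tmp/out.tfrecord")
writer.write(ds)  # runs eagerly in TF 2.x; returns an op in graph mode

round_trip = tf.data.TFRecordDataset("/tmp/out.tfrecord")
print([x.numpy() for x in round_trip])  # [b'a', b'b', b'c']
)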

View File

@@ -31,7 +31,6 @@ from six.moves import queue as Queue  # pylint: disable=redefined-builtin
 from tensorflow.core.framework import graph_pb2
 from tensorflow.python import tf2
-from tensorflow.python.compat import compat
 from tensorflow.python.data.experimental.ops import distribute_options
 from tensorflow.python.data.experimental.ops import optimization_options
 from tensorflow.python.data.experimental.ops import stats_options
@@ -1743,12 +1742,8 @@ class DatasetV1(DatasetV2):
     dataset = self._apply_options()
     if shared_name is None:
       shared_name = ""
-    if compat.forward_compatible(2018, 8, 3):
-      iterator_resource = gen_dataset_ops.iterator_v2(
-          container="", shared_name=shared_name, **self._flat_structure)
-    else:
-      iterator_resource = gen_dataset_ops.iterator(
-          container="", shared_name=shared_name, **self._flat_structure)
+    iterator_resource = gen_dataset_ops.iterator_v2(
+        container="", shared_name=shared_name, **self._flat_structure)
     with ops.colocate_with(iterator_resource):
       initializer = gen_dataset_ops.make_iterator(
           dataset._variant_tensor,  # pylint: disable=protected-access
@@ -3755,20 +3750,12 @@ class _SetStatsAggregatorDataset(UnaryUnchangedStructureDataset):
     self._stats_aggregator = aggregator
     self._prefix = prefix
     self._counter_prefix = counter_prefix
-    if compat.forward_compatible(2019, 8, 3):
-      variant_tensor = ged_ops.set_stats_aggregator_dataset(
-          input_dataset._variant_tensor,  # pylint: disable=protected-access
-          self._stats_aggregator._resource,  # pylint: disable=protected-access
-          self._prefix,
-          self._counter_prefix,
-          **self._flat_structure)
-    else:
-      variant_tensor = ged_ops.experimental_set_stats_aggregator_dataset(
-          input_dataset._variant_tensor,  # pylint: disable=protected-access
-          self._stats_aggregator._resource,  # pylint: disable=protected-access
-          self._prefix,
-          self._counter_prefix,
-          **self._flat_structure)
+    variant_tensor = ged_ops.set_stats_aggregator_dataset(
+        input_dataset._variant_tensor,  # pylint: disable=protected-access
+        self._stats_aggregator._resource,  # pylint: disable=protected-access
+        self._prefix,
+        self._counter_prefix,
+        **self._flat_structure)
     super(_SetStatsAggregatorDataset, self).__init__(input_dataset,
                                                      variant_tensor)
@@ -3782,16 +3769,10 @@ class _MaxIntraOpParallelismDataset(UnaryUnchangedStructureDataset):
         max_intra_op_parallelism,
         dtype=dtypes.int64,
         name="max_intra_op_parallelism")
-    if compat.forward_compatible(2019, 8, 3):
-      variant_tensor = ged_ops.max_intra_op_parallelism_dataset(
-          input_dataset._variant_tensor,  # pylint: disable=protected-access
-          self._max_intra_op_parallelism,
-          **self._flat_structure)
-    else:
-      variant_tensor = ged_ops.experimental_max_intra_op_parallelism_dataset(
-          input_dataset._variant_tensor,  # pylint: disable=protected-access
-          self._max_intra_op_parallelism,
-          **self._flat_structure)
+    variant_tensor = ged_ops.max_intra_op_parallelism_dataset(
+        input_dataset._variant_tensor,  # pylint: disable=protected-access
+        self._max_intra_op_parallelism,
+        **self._flat_structure)
     super(_MaxIntraOpParallelismDataset, self).__init__(input_dataset,
                                                         variant_tensor)
@@ -3803,16 +3784,10 @@ class _PrivateThreadPoolDataset(UnaryUnchangedStructureDataset):
     self._input_dataset = input_dataset
     self._num_threads = ops.convert_to_tensor(
         num_threads, dtype=dtypes.int64, name="num_threads")
-    if compat.forward_compatible(2019, 8, 3):
-      variant_tensor = ged_ops.private_thread_pool_dataset(
-          input_dataset._variant_tensor,  # pylint: disable=protected-access
-          self._num_threads,
-          **self._flat_structure)
-    else:
-      variant_tensor = ged_ops.experimental_private_thread_pool_dataset(
-          input_dataset._variant_tensor,  # pylint: disable=protected-access
-          self._num_threads,
-          **self._flat_structure)
+    variant_tensor = ged_ops.private_thread_pool_dataset(
+        input_dataset._variant_tensor,  # pylint: disable=protected-access
+        self._num_threads,
+        **self._flat_structure)
     super(_PrivateThreadPoolDataset, self).__init__(input_dataset,
                                                     variant_tensor)
@@ -3851,14 +3826,9 @@ class _UnbatchDataset(UnaryDataset):
     self._structure = nest.map_structure(
         lambda component_spec: component_spec._unbatch(),  # pylint: disable=protected-access
         get_structure(input_dataset))
-    if compat.forward_compatible(2019, 8, 3):
-      variant_tensor = ged_ops.unbatch_dataset(
-          self._input_dataset._variant_tensor,  # pylint: disable=protected-access
-          **self._flat_structure)
-    else:
-      variant_tensor = ged_ops.experimental_unbatch_dataset(
-          self._input_dataset._variant_tensor,  # pylint: disable=protected-access
-          **self._flat_structure)
+    variant_tensor = ged_ops.unbatch_dataset(
+        self._input_dataset._variant_tensor,  # pylint: disable=protected-access
+        **self._flat_structure)
     super(_UnbatchDataset, self).__init__(input_dataset, variant_tensor)
 
   @property

View File

@@ -20,7 +20,6 @@ from __future__ import print_function
 
 import threading
 import warnings
 
-from tensorflow.python.compat import compat
 from tensorflow.python.data.ops import optional_ops
 from tensorflow.python.data.util import nest
 from tensorflow.python.data.util import structure
@@ -201,7 +200,6 @@ class Iterator(trackable.Trackable):
         output_types, output_shapes, output_classes)
     if shared_name is None:
       shared_name = ""
-    if compat.forward_compatible(2018, 8, 3):
-      if _device_stack_is_empty():
-        with ops.device("/cpu:0"):
-          iterator_resource = gen_dataset_ops.iterator_v2(
+    if _device_stack_is_empty():
+      with ops.device("/cpu:0"):
+        iterator_resource = gen_dataset_ops.iterator_v2(
@@ -218,12 +216,6 @@ class Iterator(trackable.Trackable):
-            output_types=structure.get_flat_tensor_types(output_structure),
-            output_shapes=structure.get_flat_tensor_shapes(
-                output_structure))
-    else:
-      iterator_resource = gen_dataset_ops.iterator(
-          container="",
-          shared_name=shared_name,
-          output_types=structure.get_flat_tensor_types(output_structure),
-          output_shapes=structure.get_flat_tensor_shapes(output_structure))
+          output_types=structure.get_flat_tensor_types(output_structure),
+          output_shapes=structure.get_flat_tensor_shapes(
+              output_structure))
     return Iterator(iterator_resource, None, output_types, output_shapes,
                     output_classes)
@@ -291,7 +283,6 @@ class Iterator(trackable.Trackable):
     output_structure = structure.convert_legacy_structure(
         output_types, output_shapes, output_classes)
     string_handle = ops.convert_to_tensor(string_handle, dtype=dtypes.string)
-    if compat.forward_compatible(2018, 8, 3):
-      if _device_stack_is_empty():
-        with ops.device("/cpu:0"):
-          iterator_resource = gen_dataset_ops.iterator_from_string_handle_v2(
+    if _device_stack_is_empty():
+      with ops.device("/cpu:0"):
+        iterator_resource = gen_dataset_ops.iterator_from_string_handle_v2(
@@ -303,11 +294,6 @@ class Iterator(trackable.Trackable):
-          string_handle,
-          output_types=structure.get_flat_tensor_types(output_structure),
-          output_shapes=structure.get_flat_tensor_shapes(output_structure))
-    else:
-      iterator_resource = gen_dataset_ops.iterator_from_string_handle(
-          string_handle,
-          output_types=structure.get_flat_tensor_types(output_structure),
-          output_shapes=structure.get_flat_tensor_shapes(output_structure))
+        string_handle,
+        output_types=structure.get_flat_tensor_types(output_structure),
+        output_shapes=structure.get_flat_tensor_shapes(output_structure))
     return Iterator(iterator_resource, None, output_types, output_shapes,
                     output_classes)

View File

@@ -17,7 +17,6 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
-from tensorflow.python.compat import compat
 from tensorflow.python.data.ops import dataset_ops
 from tensorflow.python.data.util import convert
 from tensorflow.python.framework import dtypes
@@ -239,7 +238,6 @@ class ParallelInterleaveDataset(dataset_ops.UnaryDataset):
         "prefetch_input_elements",
         prefetch_input_elements,
         argument_default=2 * cycle_length)
-    if compat.forward_compatible(2019, 8, 3):
-      variant_tensor = ged_ops.parallel_interleave_dataset(
-          self._input_dataset._variant_tensor,  # pylint: disable=protected-access
-          self._map_func.function.captured_inputs,
+    variant_tensor = ged_ops.parallel_interleave_dataset(
+        self._input_dataset._variant_tensor,  # pylint: disable=protected-access
+        self._map_func.function.captured_inputs,
@@ -250,17 +248,6 @@ class ParallelInterleaveDataset(dataset_ops.UnaryDataset):
-          self._prefetch_input_elements,
-          f=self._map_func.function,
-          **self._flat_structure)
-    else:
-      variant_tensor = ged_ops.experimental_parallel_interleave_dataset(
-          self._input_dataset._variant_tensor,  # pylint: disable=protected-access
-          self._map_func.function.captured_inputs,
-          self._cycle_length,
-          self._block_length,
-          self._sloppy,
-          self._buffer_output_elements,
-          self._prefetch_input_elements,
-          f=self._map_func.function,
-          **self._flat_structure)
+        self._prefetch_input_elements,
+        f=self._map_func.function,
+        **self._flat_structure)
     super(ParallelInterleaveDataset, self).__init__(input_dataset,
                                                     variant_tensor)
@@ -407,15 +394,9 @@ class _FixedLengthRecordDataset(dataset_ops.DatasetSource):
         compression_type,
         argument_default="",
         argument_dtype=dtypes.string)
-    if (self._compression_type is not None or
-        compat.forward_compatible(2018, 11, 30)):
-      variant_tensor = gen_dataset_ops.fixed_length_record_dataset_v2(
-          self._filenames, self._header_bytes, self._record_bytes,
-          self._footer_bytes, self._buffer_size, self._compression_type)
-    else:
-      variant_tensor = gen_dataset_ops.fixed_length_record_dataset(
-          self._filenames, self._header_bytes, self._record_bytes,
-          self._footer_bytes, self._buffer_size)
+    variant_tensor = gen_dataset_ops.fixed_length_record_dataset_v2(
+        self._filenames, self._header_bytes, self._record_bytes,
+        self._footer_bytes, self._buffer_size, self._compression_type)
     super(_FixedLengthRecordDataset, self).__init__(variant_tensor)
 
   @property