Clean up expired forward compatibility checks.

PiperOrigin-RevId: 308455990
Change-Id: Icb38f1296a4c0326c3146e36e0e23a8dfaa37695
Author: Andrew Audibert (2020-04-25 19:14:41 -07:00)
Committer: TensorFlower Gardener
parent 2f55b904b5
commit 3b58a7de85
5 changed files with 34 additions and 84 deletions
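
For context: compat.forward_compatible(year, month, day) returns True once TensorFlow's forward-compatibility horizon has moved past the given date; until then, callers emit the older op so that graphs stay consumable by slightly stale binaries. Once the date has passed everywhere, the check is permanently True and the fallback branch is dead code, and deleting those branches is all this commit does. A minimal sketch of the pattern (the op names here are illustrative):

    from tensorflow.python.compat import compat

    # Before the horizon expires: choose the op via the compatibility gate.
    if compat.forward_compatible(2020, 3, 6):
      op_name = "ParallelMapV2"  # new op, safe once the window has passed
    else:
      op_name = "ParallelMap"    # fallback for binaries predating the new op

    # After the horizon expires (as in this commit): the gate is always
    # True, so the conditional collapses to a constant.
    op_name = "ParallelMapV2"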


@@ -19,7 +19,6 @@ from __future__ import print_function
 
 from absl.testing import parameterized
 
-from tensorflow.python.compat import compat
 from tensorflow.python.data.experimental.ops import testing
 from tensorflow.python.data.kernel_tests import test_base
 from tensorflow.python.data.ops import dataset_ops
@@ -37,11 +36,8 @@ class InjectPrefetchTest(test_base.DatasetTestBase, parameterized.TestCase):
   @combinations.generate(test_base.default_test_combinations())
   def testParallelMap(self):
     dataset = dataset_ops.Dataset.range(100)
-    parallel_map = "ParallelMap"
-    if compat.forward_compatible(2020, 3, 6):
-      parallel_map = "ParallelMapV2"
     dataset = dataset.apply(
-        testing.assert_next([parallel_map, "Prefetch", "FiniteTake"]))
+        testing.assert_next(["ParallelMapV2", "Prefetch", "FiniteTake"]))
     dataset = dataset.map(
         lambda x: x + 1, num_parallel_calls=dataset_ops.AUTOTUNE)
     dataset = dataset.take(50)
@@ -64,11 +60,8 @@ class InjectPrefetchTest(test_base.DatasetTestBase, parameterized.TestCase):
   @combinations.generate(test_base.default_test_combinations())
   def testParallelInterleave(self):
     dataset = dataset_ops.Dataset.range(100)
-    parallel_interleave = "ParallelInterleaveV3"
-    if compat.forward_compatible(2020, 3, 6):
-      parallel_interleave = "ParallelInterleaveV4"
     dataset = dataset.apply(
-        testing.assert_next([parallel_interleave, "Prefetch", "FiniteTake"]))
+        testing.assert_next(["ParallelInterleaveV4", "Prefetch", "FiniteTake"]))
     dataset = dataset.interleave(
         lambda x: dataset_ops.Dataset.from_tensors(x + 1),
         num_parallel_calls=dataset_ops.AUTOTUNE)
@@ -79,15 +72,9 @@ class InjectPrefetchTest(test_base.DatasetTestBase, parameterized.TestCase):
   @combinations.generate(test_base.default_test_combinations())
   def testChainedParallelDatasets(self):
     dataset = dataset_ops.Dataset.range(100)
-    parallel_interleave = "ParallelInterleaveV3"
-    if compat.forward_compatible(2020, 3, 6):
-      parallel_interleave = "ParallelInterleaveV4"
-    parallel_map = "ParallelMap"
-    if compat.forward_compatible(2020, 3, 6):
-      parallel_map = "ParallelMapV2"
     dataset = dataset.apply(
         testing.assert_next([
-            parallel_map, "Prefetch", parallel_interleave, "Prefetch",
+            "ParallelMapV2", "Prefetch", "ParallelInterleaveV4", "Prefetch",
             "MapAndBatch", "Prefetch", "FiniteTake"
         ]))
     dataset = dataset.map(
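
These tests rely on testing.assert_next, a transformation that fails at iteration time unless the datasets following it in the rewritten pipeline match the listed op names; that is how they observe the inject_prefetch rewrite adding a "Prefetch" after each autotuned op. A hedged, runnable sketch of the same pattern against the public TF 2.x API (assuming the rewrite is enabled by default, as these tests do):

    import tensorflow as tf
    from tensorflow.python.data.experimental.ops import testing

    ds = tf.data.Dataset.range(10)
    # assert_next raises if the optimized graph does not continue with
    # exactly these ops; "FiniteTake" is what take() with a known count
    # lowers to.
    ds = ds.apply(
        testing.assert_next(["ParallelMapV2", "Prefetch", "FiniteTake"]))
    ds = ds.map(lambda x: x + 1,
                num_parallel_calls=tf.data.experimental.AUTOTUNE)
    ds = ds.take(5)
    print(list(ds.as_numpy_iterator()))  # [1, 2, 3, 4, 5]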


@@ -25,7 +25,6 @@ import numpy as np
 
 from tensorflow.core.example import example_pb2
 from tensorflow.core.example import feature_pb2
-from tensorflow.python.compat import compat
 from tensorflow.python.data.experimental.ops import batching
 from tensorflow.python.data.experimental.ops import testing
 from tensorflow.python.data.kernel_tests import test_base
@@ -222,9 +221,7 @@ class MapVectorizationTest(test_base.DatasetTestBase, parameterized.TestCase):
     """
     map_node_name = "Map"
     if num_parallel_calls is not None:
-      map_node_name = "ParallelMap"
-      if compat.forward_compatible(2020, 3, 6):
-        map_node_name = "ParallelMapV2"
+      map_node_name = "ParallelMapV2"
 
     def _make_dataset(node_names):
       dataset = base_dataset.apply(testing.assert_next(node_names))


@@ -21,7 +21,6 @@ import functools
 
 from absl.testing import parameterized
 
-from tensorflow.python.compat import compat
 from tensorflow.python.data.experimental.ops import testing
 from tensorflow.python.data.kernel_tests import test_base
 from tensorflow.python.data.ops import dataset_ops
@@ -51,8 +50,7 @@ def _test_combinations():
     ds = ds.map(lambda x: (x, x), num_parallel_calls=2)  # Not eliminated
     return ds.map(lambda x, y: (x, y))  # Eliminated
 
-  parallel_map_name = "ParallelMapV2" if compat.forward_compatible(
-      2020, 3, 6) else "ParallelMap"
+  parallel_map_name = "ParallelMapV2"
 
   cases = [
       ("Skip0", lambda ds: ds.skip(0), None),


@@ -17,7 +17,6 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
-from tensorflow.python.compat import compat
 from tensorflow.python.data.ops import dataset_ops
 from tensorflow.python.data.util import structure
 from tensorflow.python.framework import dtypes
@@ -85,35 +84,20 @@ class _ParseExampleDataset(dataset_ops.UnaryDataset):
       self._element_spec[key] = ragged_tensor.RaggedTensorSpec(
           input_dataset_shape.concatenate([None]), value_type, 1, splits_type)
 
-    if deterministic is not None or compat.forward_compatible(2020, 3, 6):
-      variant_tensor = (
-          gen_experimental_dataset_ops.parse_example_dataset_v2(
-              self._input_dataset._variant_tensor,  # pylint: disable=protected-access
-              self._num_parallel_calls,
-              self._dense_defaults,
-              self._sparse_keys,
-              self._dense_keys,
-              self._sparse_types,
-              self._dense_shapes,
-              deterministic=self._deterministic,
-              ragged_keys=self._ragged_keys,
-              ragged_value_types=self._ragged_value_types,
-              ragged_split_types=self._ragged_split_types,
-              **self._flat_structure))
-    else:
-      variant_tensor = (
-          gen_experimental_dataset_ops.parse_example_dataset(
-              self._input_dataset._variant_tensor,  # pylint: disable=protected-access
-              self._num_parallel_calls,
-              self._dense_defaults,
-              self._sparse_keys,
-              self._dense_keys,
-              self._sparse_types,
-              self._dense_shapes,
-              ragged_keys=self._ragged_keys,
-              ragged_value_types=self._ragged_value_types,
-              ragged_split_types=self._ragged_split_types,
-              **self._flat_structure))
+    variant_tensor = (
+        gen_experimental_dataset_ops.parse_example_dataset_v2(
+            self._input_dataset._variant_tensor,  # pylint: disable=protected-access
+            self._num_parallel_calls,
+            self._dense_defaults,
+            self._sparse_keys,
+            self._dense_keys,
+            self._sparse_types,
+            self._dense_shapes,
+            deterministic=self._deterministic,
+            ragged_keys=self._ragged_keys,
+            ragged_value_types=self._ragged_value_types,
+            ragged_split_types=self._ragged_split_types,
+            **self._flat_structure))
     super(_ParseExampleDataset, self).__init__(input_dataset, variant_tensor)
 
   @property
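
With the expired gate gone, _ParseExampleDataset unconditionally emits ParseExampleDatasetV2, whose deterministic attr ("default", "true", or "false", derived from the user-facing deterministic argument) subsumes what the V1 op could express. A hedged usage sketch of the public wrapper, assuming the TF 2.2-era signature with a deterministic parameter:

    import tensorflow as tf

    # One serialized tf.train.Example to parse (illustrative data).
    example = tf.train.Example(features=tf.train.Features(feature={
        "x": tf.train.Feature(int64_list=tf.train.Int64List(value=[7]))}))
    ds = tf.data.Dataset.from_tensors(example.SerializeToString())

    ds = ds.apply(tf.data.experimental.parse_example_dataset(
        {"x": tf.io.FixedLenFeature([], tf.int64)},
        num_parallel_calls=2,
        deterministic=None))  # None maps to deterministic="default"
    print(list(ds.as_numpy_iterator()))  # [{'x': 7}]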


@@ -18,7 +18,6 @@ from __future__ import division
 from __future__ import print_function
 
 from tensorflow.python import tf2
-from tensorflow.python.compat import compat
 from tensorflow.python.data.ops import dataset_ops
 from tensorflow.python.data.util import convert
 from tensorflow.python.framework import dtypes
@@ -249,9 +248,6 @@ class ParallelInterleaveDataset(dataset_ops.UnaryDataset):
         cycle_length, dtype=dtypes.int64, name="cycle_length")
     self._block_length = ops.convert_to_tensor(
         block_length, dtype=dtypes.int64, name="block_length")
-    if sloppy is not None:
-      self._sloppy = ops.convert_to_tensor(
-          sloppy, dtype=dtypes.bool, name="sloppy")
     self._buffer_output_elements = convert.optional_param_to_tensor(
         "buffer_output_elements",
         buffer_output_elements,
@@ -260,34 +256,22 @@ class ParallelInterleaveDataset(dataset_ops.UnaryDataset):
         "prefetch_input_elements",
         prefetch_input_elements,
         argument_default=2 * cycle_length)
-    if sloppy is None or compat.forward_compatible(2020, 3, 6):
-      if sloppy is None:
-        self._deterministic = "default"
-      elif sloppy:
-        self._deterministic = "false"
-      else:
-        self._deterministic = "true"
-      variant_tensor = ged_ops.legacy_parallel_interleave_dataset_v2(
-          self._input_dataset._variant_tensor,  # pylint: disable=protected-access
-          self._map_func.function.captured_inputs,
-          self._cycle_length,
-          self._block_length,
-          self._buffer_output_elements,
-          self._prefetch_input_elements,
-          f=self._map_func.function,
-          deterministic=self._deterministic,
-          **self._flat_structure)
-    else:
-      variant_tensor = ged_ops.parallel_interleave_dataset(
-          self._input_dataset._variant_tensor,  # pylint: disable=protected-access
-          self._map_func.function.captured_inputs,
-          self._cycle_length,
-          self._block_length,
-          self._sloppy,
-          self._buffer_output_elements,
-          self._prefetch_input_elements,
-          f=self._map_func.function,
-          **self._flat_structure)
+    if sloppy is None:
+      self._deterministic = "default"
+    elif sloppy:
+      self._deterministic = "false"
+    else:
+      self._deterministic = "true"
+    variant_tensor = ged_ops.legacy_parallel_interleave_dataset_v2(
+        self._input_dataset._variant_tensor,  # pylint: disable=protected-access
+        self._map_func.function.captured_inputs,
+        self._cycle_length,
+        self._block_length,
+        self._buffer_output_elements,
+        self._prefetch_input_elements,
+        f=self._map_func.function,
+        deterministic=self._deterministic,
+        **self._flat_structure)
     super(ParallelInterleaveDataset, self).__init__(input_dataset,
                                                     variant_tensor)
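
The surviving branch maps the legacy boolean sloppy argument onto the string-valued deterministic attr that LegacyParallelInterleaveDatasetV2 expects. Spelled out as a standalone helper (a sketch mirroring the logic above; the function name is ours, not TensorFlow's):

    def _sloppy_to_deterministic(sloppy):
      """None keeps the kernel default; otherwise deterministic = not sloppy."""
      if sloppy is None:
        return "default"
      return "false" if sloppy else "true"

    assert _sloppy_to_deterministic(None) == "default"
    assert _sloppy_to_deterministic(True) == "false"
    assert _sloppy_to_deterministic(False) == "true"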