[tf.data] Forward compatibility cleanup.

PiperOrigin-RevId: 314743800
Change-Id: If3eea28bf8b284921479f9c0f60ccfad19d13333
This commit is contained in:
Jiri Simsa 2020-06-04 09:29:35 -07:00 committed by TensorFlower Gardener
parent d4ff98bb3d
commit 18a8cfaa37
2 changed files with 10 additions and 127 deletions
tensorflow/python/data

View File

@ -23,8 +23,6 @@ import functools
from absl.testing import parameterized
import numpy as np
from tensorflow.python import tf2
from tensorflow.python.compat import compat
from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.eager import function
@ -331,17 +329,11 @@ class ShuffleTest(test_base.DatasetTestBase, parameterized.TestCase):
with self.assertRaises(errors.OutOfRangeError):
self.evaluate(get_next())
# We skip v2 eager since the v2 eager shuffle dataset is not serializable due
# to its use of an external seed generator resource.
@combinations.generate(
combinations.times(
test_base.graph_only_combinations() +
combinations.combine(mode=["eager"]),
test_base.default_test_combinations(),
combinations.combine(reshuffle=[True, False])))
def testRerandomizeOnReplicate(self, reshuffle):
if tf2.enabled() and not compat.forward_compatible(2020, 5, 22):
self.skipTest("Functionality currently not supported.")
random_seed.set_random_seed(None)
# When no seeds are fixed, each instantiation of the shuffle dataset should
# produce elements in a different order.

View File

@ -30,7 +30,6 @@ from six.moves import queue as Queue # pylint: disable=redefined-builtin
from tensorflow.core.framework import graph_pb2
from tensorflow.python import tf2
from tensorflow.python.compat import compat
from tensorflow.python.data.experimental.ops import distribute_options
from tensorflow.python.data.experimental.ops import optimization_options
from tensorflow.python.data.experimental.ops import stats_options
@ -3609,54 +3608,6 @@ class RangeDataset(DatasetSource):
return self._structure
# This can be deleted after the forward compatibility window for switching
# to using dummy resource expires on 5/20.
class _MemoryCacheDeleter(object):
"""An object which cleans up an anonymous memory cache resource.
An alternative to defining a __del__ method on an object. Even if the parent
object is part of a reference cycle, the cycle will be collectable.
"""
def __init__(self, handle, device, deleter):
self._deleter = deleter
self._handle = handle
self._device = device
self._eager_mode = context.executing_eagerly()
def __del__(self):
with ops.device(self._device):
# Make sure the resource is deleted in the same mode as it was created in.
if self._eager_mode:
with context.eager_mode():
gen_dataset_ops.delete_memory_cache(
handle=self._handle, deleter=self._deleter)
else:
with context.graph_mode():
gen_dataset_ops.delete_memory_cache(
handle=self._handle, deleter=self._deleter)
# This can be deleted after the forward compatibility window for switching
# to using dummy resource expires on 5/20.
class _MemoryCache(object):
"""Represents a memory cache resource."""
def __init__(self):
super(_MemoryCache, self).__init__()
if compat.forward_compatible(2020, 5, 20):
self._handle = gen_dataset_ops.dummy_memory_cache()
else:
self._device = context.context().device_name
self._handle, self._deleter = gen_dataset_ops.anonymous_memory_cache()
self._resource_deleter = _MemoryCacheDeleter(
handle=self._handle, device=self._device, deleter=self._deleter)
@property
def handle(self):
return self._handle
class CacheDataset(UnaryUnchangedStructureDataset):
"""A `Dataset` that caches elements of its input."""
@ -3666,11 +3617,10 @@ class CacheDataset(UnaryUnchangedStructureDataset):
self._filename = ops.convert_to_tensor(
filename, dtype=dtypes.string, name="filename")
if tf2.enabled() and (context.executing_eagerly() or ops.inside_function()):
self._cache = _MemoryCache()
variant_tensor = gen_dataset_ops.cache_dataset_v2(
input_dataset._variant_tensor, # pylint: disable=protected-access
filename=self._filename,
cache=self._cache.handle,
cache=gen_dataset_ops.dummy_memory_cache(),
**self._flat_structure)
else:
variant_tensor = gen_dataset_ops.cache_dataset(
@ -3680,56 +3630,6 @@ class CacheDataset(UnaryUnchangedStructureDataset):
super(CacheDataset, self).__init__(input_dataset, variant_tensor)
# This can be deleted after the forward compatibility window for switching
# to using dummy resource expires on 5/22.
class _SeedGeneratorDeleter(object):
"""An object which cleans up an anonymous seed generator resource.
An alternative to defining a __del__ method on an object. Even if the parent
object is part of a reference cycle, the cycle will be collectable.
"""
def __init__(self, handle, device, deleter):
self._deleter = deleter
self._handle = handle
self._device = device
self._eager_mode = context.executing_eagerly()
def __del__(self):
with ops.device(self._device):
# Make sure the resource is deleted in the same mode as it was created in.
if self._eager_mode:
with context.eager_mode():
gen_dataset_ops.delete_seed_generator(
handle=self._handle, deleter=self._deleter)
else:
with context.graph_mode():
gen_dataset_ops.delete_seed_generator(
handle=self._handle, deleter=self._deleter)
# This can be deleted after the forward compatibility window for switching
# to using dummy resource expires on 5/22.
class _SeedGenerator(object):
"""Represents a fixed seed generator resource."""
def __init__(self, seed, seed2, reshuffle):
super(_SeedGenerator, self).__init__()
if compat.forward_compatible(2020, 5, 22):
self._handle = gen_dataset_ops.dummy_seed_generator()
else:
self._device = context.context().device_name
self._handle, self._deleter = (
gen_dataset_ops.anonymous_seed_generator(
seed=seed, seed2=seed2, reshuffle=reshuffle))
self._resource_deleter = _SeedGeneratorDeleter(
handle=self._handle, device=self._device, deleter=self._deleter)
@property
def handle(self):
return self._handle
class ShuffleDataset(UnaryUnchangedStructureDataset):
"""A `Dataset` that randomly shuffles the elements of its input."""
@ -3767,23 +3667,14 @@ class ShuffleDataset(UnaryUnchangedStructureDataset):
if (tf2.enabled() and
(context.executing_eagerly() or ops.inside_function())):
self._seed_generator = _SeedGenerator(self._seed, self._seed2,
self._reshuffle_each_iteration)
if compat.forward_compatible(2020, 5, 22):
variant_tensor = gen_dataset_ops.shuffle_dataset_v3(
input_dataset._variant_tensor, # pylint: disable=protected-access
buffer_size=self._buffer_size,
seed=self._seed,
seed2=self._seed2,
seed_generator=self._seed_generator.handle,
reshuffle_each_iteration=self._reshuffle_each_iteration,
**self._flat_structure)
else:
variant_tensor = gen_dataset_ops.shuffle_dataset_v2(
input_dataset._variant_tensor, # pylint: disable=protected-access
buffer_size=self._buffer_size,
seed_generator=self._seed_generator.handle,
**self._flat_structure)
variant_tensor = gen_dataset_ops.shuffle_dataset_v3(
input_dataset._variant_tensor, # pylint: disable=protected-access
buffer_size=self._buffer_size,
seed=self._seed,
seed2=self._seed2,
seed_generator=gen_dataset_ops.dummy_seed_generator(),
reshuffle_each_iteration=self._reshuffle_each_iteration,
**self._flat_structure)
else:
variant_tensor = gen_dataset_ops.shuffle_dataset(
input_dataset._variant_tensor, # pylint: disable=protected-access