[tf.data] Forward compatibility cleanup.
PiperOrigin-RevId: 314743800 Change-Id: If3eea28bf8b284921479f9c0f60ccfad19d13333
This commit is contained in:
parent
d4ff98bb3d
commit
18a8cfaa37
tensorflow/python/data
@ -23,8 +23,6 @@ import functools
|
||||
from absl.testing import parameterized
|
||||
import numpy as np
|
||||
|
||||
from tensorflow.python import tf2
|
||||
from tensorflow.python.compat import compat
|
||||
from tensorflow.python.data.kernel_tests import test_base
|
||||
from tensorflow.python.data.ops import dataset_ops
|
||||
from tensorflow.python.eager import function
|
||||
@ -331,17 +329,11 @@ class ShuffleTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
self.evaluate(get_next())
|
||||
|
||||
# We skip v2 eager since the v2 eager shuffle dataset is not serializable due
|
||||
# to its use of an external seed generator resource.
|
||||
@combinations.generate(
|
||||
combinations.times(
|
||||
test_base.graph_only_combinations() +
|
||||
combinations.combine(mode=["eager"]),
|
||||
test_base.default_test_combinations(),
|
||||
combinations.combine(reshuffle=[True, False])))
|
||||
def testRerandomizeOnReplicate(self, reshuffle):
|
||||
if tf2.enabled() and not compat.forward_compatible(2020, 5, 22):
|
||||
self.skipTest("Functionality currently not supported.")
|
||||
|
||||
random_seed.set_random_seed(None)
|
||||
# When no seeds are fixed, each instantiation of the shuffle dataset should
|
||||
# produce elements in a different order.
|
||||
|
@ -30,7 +30,6 @@ from six.moves import queue as Queue # pylint: disable=redefined-builtin
|
||||
|
||||
from tensorflow.core.framework import graph_pb2
|
||||
from tensorflow.python import tf2
|
||||
from tensorflow.python.compat import compat
|
||||
from tensorflow.python.data.experimental.ops import distribute_options
|
||||
from tensorflow.python.data.experimental.ops import optimization_options
|
||||
from tensorflow.python.data.experimental.ops import stats_options
|
||||
@ -3609,54 +3608,6 @@ class RangeDataset(DatasetSource):
|
||||
return self._structure
|
||||
|
||||
|
||||
# This can be deleted after the forward compatibility window for switching
|
||||
# to using dummy resource expires on 5/20.
|
||||
class _MemoryCacheDeleter(object):
|
||||
"""An object which cleans up an anonymous memory cache resource.
|
||||
|
||||
An alternative to defining a __del__ method on an object. Even if the parent
|
||||
object is part of a reference cycle, the cycle will be collectable.
|
||||
"""
|
||||
|
||||
def __init__(self, handle, device, deleter):
|
||||
self._deleter = deleter
|
||||
self._handle = handle
|
||||
self._device = device
|
||||
self._eager_mode = context.executing_eagerly()
|
||||
|
||||
def __del__(self):
|
||||
with ops.device(self._device):
|
||||
# Make sure the resource is deleted in the same mode as it was created in.
|
||||
if self._eager_mode:
|
||||
with context.eager_mode():
|
||||
gen_dataset_ops.delete_memory_cache(
|
||||
handle=self._handle, deleter=self._deleter)
|
||||
else:
|
||||
with context.graph_mode():
|
||||
gen_dataset_ops.delete_memory_cache(
|
||||
handle=self._handle, deleter=self._deleter)
|
||||
|
||||
|
||||
# This can be deleted after the forward compatibility window for switching
|
||||
# to using dummy resource expires on 5/20.
|
||||
class _MemoryCache(object):
|
||||
"""Represents a memory cache resource."""
|
||||
|
||||
def __init__(self):
|
||||
super(_MemoryCache, self).__init__()
|
||||
if compat.forward_compatible(2020, 5, 20):
|
||||
self._handle = gen_dataset_ops.dummy_memory_cache()
|
||||
else:
|
||||
self._device = context.context().device_name
|
||||
self._handle, self._deleter = gen_dataset_ops.anonymous_memory_cache()
|
||||
self._resource_deleter = _MemoryCacheDeleter(
|
||||
handle=self._handle, device=self._device, deleter=self._deleter)
|
||||
|
||||
@property
|
||||
def handle(self):
|
||||
return self._handle
|
||||
|
||||
|
||||
class CacheDataset(UnaryUnchangedStructureDataset):
|
||||
"""A `Dataset` that caches elements of its input."""
|
||||
|
||||
@ -3666,11 +3617,10 @@ class CacheDataset(UnaryUnchangedStructureDataset):
|
||||
self._filename = ops.convert_to_tensor(
|
||||
filename, dtype=dtypes.string, name="filename")
|
||||
if tf2.enabled() and (context.executing_eagerly() or ops.inside_function()):
|
||||
self._cache = _MemoryCache()
|
||||
variant_tensor = gen_dataset_ops.cache_dataset_v2(
|
||||
input_dataset._variant_tensor, # pylint: disable=protected-access
|
||||
filename=self._filename,
|
||||
cache=self._cache.handle,
|
||||
cache=gen_dataset_ops.dummy_memory_cache(),
|
||||
**self._flat_structure)
|
||||
else:
|
||||
variant_tensor = gen_dataset_ops.cache_dataset(
|
||||
@ -3680,56 +3630,6 @@ class CacheDataset(UnaryUnchangedStructureDataset):
|
||||
super(CacheDataset, self).__init__(input_dataset, variant_tensor)
|
||||
|
||||
|
||||
# This can be deleted after the forward compatibility window for switching
|
||||
# to using dummy resource expires on 5/22.
|
||||
class _SeedGeneratorDeleter(object):
|
||||
"""An object which cleans up an anonymous seed generator resource.
|
||||
|
||||
An alternative to defining a __del__ method on an object. Even if the parent
|
||||
object is part of a reference cycle, the cycle will be collectable.
|
||||
"""
|
||||
|
||||
def __init__(self, handle, device, deleter):
|
||||
self._deleter = deleter
|
||||
self._handle = handle
|
||||
self._device = device
|
||||
self._eager_mode = context.executing_eagerly()
|
||||
|
||||
def __del__(self):
|
||||
with ops.device(self._device):
|
||||
# Make sure the resource is deleted in the same mode as it was created in.
|
||||
if self._eager_mode:
|
||||
with context.eager_mode():
|
||||
gen_dataset_ops.delete_seed_generator(
|
||||
handle=self._handle, deleter=self._deleter)
|
||||
else:
|
||||
with context.graph_mode():
|
||||
gen_dataset_ops.delete_seed_generator(
|
||||
handle=self._handle, deleter=self._deleter)
|
||||
|
||||
|
||||
# This can be deleted after the forward compatibility window for switching
|
||||
# to using dummy resource expires on 5/22.
|
||||
class _SeedGenerator(object):
|
||||
"""Represents a fixed seed generator resource."""
|
||||
|
||||
def __init__(self, seed, seed2, reshuffle):
|
||||
super(_SeedGenerator, self).__init__()
|
||||
if compat.forward_compatible(2020, 5, 22):
|
||||
self._handle = gen_dataset_ops.dummy_seed_generator()
|
||||
else:
|
||||
self._device = context.context().device_name
|
||||
self._handle, self._deleter = (
|
||||
gen_dataset_ops.anonymous_seed_generator(
|
||||
seed=seed, seed2=seed2, reshuffle=reshuffle))
|
||||
self._resource_deleter = _SeedGeneratorDeleter(
|
||||
handle=self._handle, device=self._device, deleter=self._deleter)
|
||||
|
||||
@property
|
||||
def handle(self):
|
||||
return self._handle
|
||||
|
||||
|
||||
class ShuffleDataset(UnaryUnchangedStructureDataset):
|
||||
"""A `Dataset` that randomly shuffles the elements of its input."""
|
||||
|
||||
@ -3767,23 +3667,14 @@ class ShuffleDataset(UnaryUnchangedStructureDataset):
|
||||
|
||||
if (tf2.enabled() and
|
||||
(context.executing_eagerly() or ops.inside_function())):
|
||||
self._seed_generator = _SeedGenerator(self._seed, self._seed2,
|
||||
self._reshuffle_each_iteration)
|
||||
if compat.forward_compatible(2020, 5, 22):
|
||||
variant_tensor = gen_dataset_ops.shuffle_dataset_v3(
|
||||
input_dataset._variant_tensor, # pylint: disable=protected-access
|
||||
buffer_size=self._buffer_size,
|
||||
seed=self._seed,
|
||||
seed2=self._seed2,
|
||||
seed_generator=self._seed_generator.handle,
|
||||
reshuffle_each_iteration=self._reshuffle_each_iteration,
|
||||
**self._flat_structure)
|
||||
else:
|
||||
variant_tensor = gen_dataset_ops.shuffle_dataset_v2(
|
||||
input_dataset._variant_tensor, # pylint: disable=protected-access
|
||||
buffer_size=self._buffer_size,
|
||||
seed_generator=self._seed_generator.handle,
|
||||
**self._flat_structure)
|
||||
variant_tensor = gen_dataset_ops.shuffle_dataset_v3(
|
||||
input_dataset._variant_tensor, # pylint: disable=protected-access
|
||||
buffer_size=self._buffer_size,
|
||||
seed=self._seed,
|
||||
seed2=self._seed2,
|
||||
seed_generator=gen_dataset_ops.dummy_seed_generator(),
|
||||
reshuffle_each_iteration=self._reshuffle_each_iteration,
|
||||
**self._flat_structure)
|
||||
else:
|
||||
variant_tensor = gen_dataset_ops.shuffle_dataset(
|
||||
input_dataset._variant_tensor, # pylint: disable=protected-access
|
||||
|
Loading…
Reference in New Issue
Block a user