STT-tensorflow/tensorflow/python/data/kernel_tests/cache_test.py
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for `tf.data.Dataset.cache()`."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import functools
from os import path
import shutil
import tempfile
from absl.testing import parameterized
import numpy as np
from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.eager import context
from tensorflow.python.framework import combinations
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.ops import variables
from tensorflow.python.platform import test


class FileCacheTest(test_base.DatasetTestBase, parameterized.TestCase):

  def setUp(self):
    super(FileCacheTest, self).setUp()
    self.tmp_dir = tempfile.mkdtemp()
    self.cache_prefix = path.join(self.tmp_dir, "cache")

  def tearDown(self):
    if self.tmp_dir:
      shutil.rmtree(self.tmp_dir, ignore_errors=True)
    super(FileCacheTest, self).tearDown()
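
  # End-to-end check of file caching: an uncached run collects the ground
  # truth, a cached run must match it, and a replay with an empty upstream
  # must be served entirely from the cache file.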
  @combinations.generate(test_base.default_test_combinations())
  def testCacheDatasetPassthrough(self):
    components = (np.array([1, 2, 3, 4]), np.array([5, 6, 7, 8]),
                  np.array([9.0, 10.0, 11.0, 12.0]))

    def dataset_fn(count=5, filename=None):
      repeat_dataset = (
          dataset_ops.Dataset.from_tensor_slices(components).repeat(count))
      if filename:
        return repeat_dataset.cache(filename)
      else:
        return repeat_dataset

    self.assertEqual(
        tuple([c.shape[1:] for c in components]),
        dataset_ops.get_legacy_output_shapes(dataset_fn()))

    get_next = self.getNext(dataset_fn())

    # First run without caching to collect the "ground truth".
    elements = []
    for _ in range(20):
      elements.append(self.evaluate(get_next()))
    with self.assertRaises(errors.OutOfRangeError):
      self.evaluate(get_next())

    # Assert that the cached dataset has the same elements as the
    # "ground truth".
    get_next = self.getNext(dataset_fn(filename=self.cache_prefix))
    cached_elements = []
    for _ in range(20):
      cached_elements.append(self.evaluate(get_next()))
    with self.assertRaises(errors.OutOfRangeError):
      self.evaluate(get_next())
    self.assertAllEqual(elements, cached_elements)

    # Re-initialize with an empty upstream (to throw errors.OutOfRangeError
    # if we didn't use the cache).
    get_next = self.getNext(dataset_fn(count=0, filename=self.cache_prefix))
    replayed_elements = []
    for _ in range(20):
      replayed_elements.append(self.evaluate(get_next()))
    with self.assertRaises(errors.OutOfRangeError):
      self.evaluate(get_next())
    self.assertEqual(cached_elements, replayed_elements)

    # Re-initialize with an empty upstream and a missing cache file (should
    # throw errors.OutOfRangeError immediately).
    get_next = self.getNext(
        dataset_fn(count=0, filename=self.cache_prefix + "nonsense"))
    with self.assertRaises(errors.OutOfRangeError):
      self.evaluate(get_next())
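
  # Two iterators must not write to the same cache file at once: the second
  # writer should fail with AlreadyExistsError while the first keeps working.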
  @combinations.generate(test_base.default_test_combinations())
  def testConcurrentWriters(self):
    components = (np.array([1, 2, 3, 4]), np.array([5, 6, 7, 8]),
                  np.array([9.0, 10.0, 11.0, 12.0]))

    cache_dataset1 = (
        dataset_ops.Dataset.from_tensor_slices(components).cache(
            self.cache_prefix))
    cache_dataset2 = (
        dataset_ops.Dataset.from_tensor_slices(components).cache(
            self.cache_prefix))

    get_next1 = self.getNext(cache_dataset1)
    get_next2 = self.getNext(cache_dataset2)

    self.evaluate(get_next1())  # this should succeed
    with self.assertRaises(errors.AlreadyExistsError):
      self.evaluate(get_next2())

    self.evaluate(get_next1())  # this should continue to succeed
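
  # Once one iterator has fully written the cache file, multiple iterators
  # may read it back concurrently, in any interleaving, and each must see
  # the full sequence.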
  @combinations.generate(test_base.default_test_combinations())
  def testConcurrentReaders(self):
    components = (np.array([1, 2, 3, 4]), np.array([5, 6, 7, 8]),
                  np.array([9.0, 10.0, 11.0, 12.0]))

    cache_dataset1 = (
        dataset_ops.Dataset.from_tensor_slices(components).cache(
            self.cache_prefix))
    cache_dataset2 = (
        dataset_ops.Dataset.from_tensor_slices(components).cache(
            self.cache_prefix))

    get_next1 = self.getNext(cache_dataset1)
    get_next2 = self.getNext(cache_dataset2)

    elements = []
    for _ in range(4):
      elements.append(self.evaluate(get_next1()))
    with self.assertRaises(errors.OutOfRangeError):
      self.evaluate(get_next1())

    # Re-initialize
    get_next1 = self.getNext(cache_dataset1, requires_initialization=True)
    get_next2 = self.getNext(cache_dataset2, requires_initialization=True)

    # Reading concurrently should succeed.
    elements_itr1 = []
    elements_itr2 = []
    elements_itr2.append(self.evaluate(get_next2()))
    elements_itr1.append(self.evaluate(get_next1()))
    elements_itr2.append(self.evaluate(get_next2()))
    elements_itr1.append(self.evaluate(get_next1()))
    # Intentionally reversing the order
    elements_itr1.append(self.evaluate(get_next1()))
    elements_itr2.append(self.evaluate(get_next2()))
    elements_itr1.append(self.evaluate(get_next1()))
    elements_itr2.append(self.evaluate(get_next2()))

    with self.assertRaises(errors.OutOfRangeError):
      self.evaluate(get_next2())

    with self.assertRaises(errors.OutOfRangeError):
      self.evaluate(get_next1())

    self.assertAllEqual(elements, elements_itr1)
    self.assertAllEqual(elements, elements_itr2)
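
  # Reading past the end of the cached sequence and then repeating should
  # replay the full dataset from the cache file.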
  @combinations.generate(test_base.default_test_combinations())
  def testReadingPastEndOfSequence(self):
    dataset = dataset_ops.Dataset.range(10).cache(self.cache_prefix)
    dataset = dataset.map(lambda a: a).batch(4).repeat(2)
    expected_output = [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9]] * 2
    self.assertDatasetProduces(dataset, expected_output)
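
  # Destroying an iterator over a partially written cache should clean up
  # the partial cache files, regardless of how many elements were consumed.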
  @combinations.generate(test_base.default_test_combinations())
  def testCleaningUpCacheFiles(self):

    def do_test(i):
      dataset = dataset_ops.Dataset.range(10).cache(self.cache_prefix)
      get_next = self.getNext(dataset)
      for _ in range(i):
        try:
          self.evaluate(get_next())
        except errors.OutOfRangeError:
          break

    if not context.executing_eagerly():
      self.skipTest(
          "Test requires eager mode for iterators to be deconstructed")

    for i in [0, 3, 10, 12, 15]:
      do_test(i)


class MemoryCacheTest(test_base.DatasetTestBase, parameterized.TestCase):
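
  # A memory cache must shield downstream reads from changes in the input
  # pipeline: once the first pass has been cached, emptying the upstream
  # (repeat_count = 0) must not affect replay from the cache.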
  @combinations.generate(test_base.default_test_combinations())
  def testCacheDatasetPassthrough(self):
    with ops.device("cpu:0"):
      repeat_count = variables.Variable(constant_op.constant(10, dtypes.int64))
      dataset = dataset_ops.Dataset.range(3).flat_map(
          lambda x: dataset_ops.Dataset.from_tensors(x).repeat(repeat_count))

      cached_dataset = dataset.cache().repeat(2)
      uncached_dataset = dataset.repeat(2)

      self.evaluate(repeat_count.initializer)
      # Needs to be initializable to capture the variable.
      cached_next = self.getNext(cached_dataset, requires_initialization=True)
      uncached_next = self.getNext(
          uncached_dataset, requires_initialization=True)
      for i in range(3):
        for _ in range(10):
          self.assertEqual(self.evaluate(cached_next()), i)
          self.assertEqual(self.evaluate(uncached_next()), i)

      self.evaluate(repeat_count.assign(0))

      # The uncached iterator should now be empty.
      with self.assertRaises(errors.OutOfRangeError):
        self.evaluate(uncached_next())

      # The cached iterator replays from cache.
      for i in range(3):
        for _ in range(10):
          self.assertEqual(self.evaluate(cached_next()), i)

      # The cached iterator should now be empty.
      with self.assertRaises(errors.OutOfRangeError):
        self.evaluate(cached_next())
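
  # Caching an empty dataset should itself produce an empty dataset rather
  # than failing.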
  @combinations.generate(test_base.default_test_combinations())
  def testEmptyCacheReading(self):
    components = (np.array([1, 2, 3, 4]), np.array([5, 6, 7, 8]),
                  np.array([9.0, 10.0, 11.0, 12.0]))

    repeat_dataset = (
        dataset_ops.Dataset.from_tensor_slices(components).repeat(0))
    cache_dataset = repeat_dataset.cache()

    # The cached empty dataset should produce no elements.
    self.assertDatasetProduces(cache_dataset, expected_output=[])
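
  # Two pipelines built from separate cache() calls keep independent
  # in-memory caches, so their iterators can make interleaved progress.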
  @combinations.generate(test_base.default_test_combinations())
  def testConcurrentReaders(self):

    dataset_fn = lambda: dataset_ops.Dataset.range(5).cache()
    d1 = dataset_fn().map(lambda x: x + 1)
    d2 = dataset_fn().map(lambda x: x + 6)

    get_next1 = self.getNext(d1)
    self.assertEqual(1, self.evaluate(get_next1()))
    self.assertEqual(2, self.evaluate(get_next1()))
    self.assertEqual(3, self.evaluate(get_next1()))

    get_next2 = self.getNext(d2)

    self.assertEqual(6, self.evaluate(get_next2()))
    self.assertEqual(7, self.evaluate(get_next2()))
    self.assertEqual(4, self.evaluate(get_next1()))  # interleave execution
    self.assertEqual([8, 5],
                     [self.evaluate(get_next2()),
                      self.evaluate(get_next1())])
    self.assertEqual(9, self.evaluate(get_next2()))
    self.assertEqual(10, self.evaluate(get_next2()))

    with self.assertRaises(errors.OutOfRangeError):
      self.evaluate(get_next2())
    with self.assertRaises(errors.OutOfRangeError):
      self.evaluate(get_next1())
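
  # take() and repeat() downstream of cache() should produce the taken
  # prefix once per repetition.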
  @combinations.generate(test_base.default_test_combinations())
  def testCacheTakeRepeat(self):
    dataset = dataset_ops.Dataset.range(10).cache().take(5).repeat(2)

    expected_output = [0, 1, 2, 3, 4, 0, 1, 2, 3, 4]
    self.assertDatasetProduces(dataset, expected_output=expected_output)
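
  # The map function should run only during the first epoch; the second
  # epoch must be served from the cache, leaving the counter untouched.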
  @combinations.generate(test_base.default_test_combinations())
  def testCacheRepeatEpochs(self):
    counter = variables.Variable(0)
    self.evaluate(counter.initializer)

    def increment_fn(x):
      counter.assign_add(1)
      return x

    dataset = dataset_ops.Dataset.range(10).map(increment_fn).cache().repeat(2)
    get_next = self.getNext(dataset, requires_initialization=True)

    # first epoch
    for i in range(10):
      self.assertEqual(i, self.evaluate(counter))
      self.assertEqual(i, self.evaluate(get_next()))
    # second epoch
    for i in range(10):
      self.assertEqual(10, self.evaluate(counter))
      self.assertEqual(i, self.evaluate(get_next()))

    with self.assertRaises(errors.OutOfRangeError):
      self.evaluate(get_next())
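
  # Same property as above, but driven by Python iteration in TF2 eager
  # mode: a second for-loop over the dataset must read from the cache.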
  @combinations.generate(combinations.combine(tf_api_version=2, mode="eager"))
  def testCacheIterationEpochs(self):
    counter = variables.Variable(0)
    self.evaluate(counter.initializer)

    def increment_fn(x):
      counter.assign_add(1)
      return x

    dataset = dataset_ops.Dataset.range(10).map(increment_fn).cache()

    # first epoch
    i = 0
    for elem in dataset:
      self.assertEqual(i, self.evaluate(elem))
      i += 1
    self.assertEqual(i, self.evaluate(counter))

    # second epoch
    i = 0
    for elem in dataset:
      self.assertEqual(10, self.evaluate(counter))
      self.assertEqual(i, self.evaluate(elem))
      i += 1
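
  # The in-memory cache resource must survive being captured inside another
  # dataset function (here, via interleave) and still yield every element.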
  @combinations.generate(combinations.combine(tf_api_version=2, mode="eager"))
  def testCacheV2ResourceCapture(self):

    def make_dataset():
      ids = dataset_ops.Dataset.range(10)
      ids = ids.cache()

      def interleave_fn(dataset, _):
        return dataset

      dataset = dataset_ops.Dataset.range(1)
      dataset = dataset.interleave(functools.partial(interleave_fn, ids))
      return dataset

    results = []
    for elem in make_dataset():
      results.append(elem.numpy())

    self.assertAllEqual(results, range(10))
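
  # Two concurrent Python iterators over the same cached dataset should
  # each see the full sequence, advancing in lockstep.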
  @combinations.generate(combinations.combine(tf_api_version=2, mode="eager"))
  def testCacheV2ConcurrentIterators(self):
    dataset = dataset_ops.Dataset.range(10).cache()

    it1 = iter(dataset)
    it2 = iter(dataset)

    for i in range(10):
      self.assertEqual(next(it1), i)
      self.assertEqual(next(it2), i)


if __name__ == "__main__":
  test.main()