[tf.data] Migrating remaining core API tests to use TF combinations and performing various minor test cleanup.

PiperOrigin-RevId: 283378763
Change-Id: Ice08340d289406eb691fb261c20329ada7c23c8a

parent 2490c87654
commit 2a2c812ab2
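Note: the sketch below is not part of the commit; it illustrates the pattern the migration converges on. Tests extend test_base.DatasetTestBase plus absl's parameterized.TestCase, and each test method is decorated with @combinations.generate(...) using the shared helpers from test_base instead of hand-rolled combinations.combine(...) calls. The class and test names (ExampleTest, testRange) are invented for illustration.

# Minimal sketch of the target test pattern (assumed, not copied from this diff).
from absl.testing import parameterized

from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import combinations
from tensorflow.python.platform import test


class ExampleTest(test_base.DatasetTestBase, parameterized.TestCase):

  # Before this change, a graph-only test spelled its combinations inline:
  #   @combinations.generate(
  #       combinations.combine(tf_api_version=[1, 2], mode="graph"))
  # After, the shared helper expresses the same intent:
  @combinations.generate(test_base.graph_only_combinations())
  def testRange(self):
    dataset = dataset_ops.Dataset.range(5)
    self.assertDatasetProduces(dataset, expected_output=list(range(5)))


if __name__ == "__main__":
  test.main()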
@@ -34,10 +34,12 @@ class MatchingFilesDatasetTest(test_base.DatasetTestBase,
                                parameterized.TestCase):

   def setUp(self):
+    super(MatchingFilesDatasetTest, self).setUp()
     self.tmp_dir = tempfile.mkdtemp()

   def tearDown(self):
     shutil.rmtree(self.tmp_dir, ignore_errors=True)
+    super(MatchingFilesDatasetTest, self).tearDown()

   def _touchTempFiles(self, filenames):
     for filename in filenames:
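The setUp/tearDown edits in this commit all follow one convention: chain to the superclass so base-class fixtures run, with super().setUp() called before any per-test setup and super().tearDown() called after per-test cleanup. A minimal sketch (the class name is hypothetical):

import shutil
import tempfile

from tensorflow.python.platform import test


class TempDirTest(test.TestCase):  # hypothetical name, for illustration only

  def setUp(self):
    super(TempDirTest, self).setUp()  # base-class fixtures first
    self.tmp_dir = tempfile.mkdtemp()

  def tearDown(self):
    shutil.rmtree(self.tmp_dir, ignore_errors=True)  # own cleanup first
    super(TempDirTest, self).tearDown()  # base-class teardown last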
@@ -57,6 +57,7 @@ class DatasetSerializationTestBase(test.TestCase):

   def tearDown(self):
     self._delete_ckpt()
+    super(DatasetSerializationTestBase, self).tearDown()

   # TODO(b/72657739): Remove sparse_tensor argument, which is to test the
   # (deprecated) saveable `SparseTensorSliceDataset`, once the API
@@ -287,6 +287,7 @@ tf_py_test(
     size = "small",
     srcs = ["iterator_cluster_test.py"],
     additional_deps = [
+        ":test_base",
         "//tensorflow/core:protos_all_py",
         "//tensorflow/python/data/ops:dataset_ops",
         "//tensorflow/python/data/ops:iterator_ops",
@@ -400,6 +401,7 @@ tf_py_test(
         "//tensorflow/python:variable_scope",
         "//tensorflow/python/ops/ragged",
     ],
+    shard_count = 4,
 )

 cuda_py_test(
@@ -61,30 +61,30 @@ class AsNumpyIteratorTest(test_base.DatasetTestBase, parameterized.TestCase):
     with self.assertRaises(RuntimeError):
       ds.as_numpy_iterator()

-  def checkInvalidElement(self, element):
+  def _testInvalidElement(self, element):
     ds = dataset_ops.Dataset.from_tensors(element)
     with self.assertRaisesRegex(TypeError,
                                 '.*does not support datasets containing.*'):
       ds.as_numpy_iterator()

   @combinations.generate(test_base.eager_only_combinations())
-  def testInvalidElements(self):
-    self.checkInvalidElement(sparse_tensor.SparseTensorValue([[0]], [0], [1]))
+  def testSparseElement(self):
+    self._testInvalidElement(sparse_tensor.SparseTensorValue([[0]], [0], [1]))

   @combinations.generate(test_base.eager_only_combinations())
   def testRaggedElement(self):
-    self.checkInvalidElement(
+    self._testInvalidElement(
         ragged_tensor_value.RaggedTensorValue(
             np.array([0, 1, 2]), np.array([0, 1, 3], dtype=np.int64)))

   @combinations.generate(test_base.eager_only_combinations())
   def testDatasetElement(self):
-    self.checkInvalidElement(dataset_ops.Dataset.range(3))
+    self._testInvalidElement(dataset_ops.Dataset.range(3))

   @combinations.generate(test_base.eager_only_combinations())
   def testNestedNonTensorElement(self):
     tuple_elem = (constant_op.constant([1, 2, 3]), dataset_ops.Dataset.range(3))
-    self.checkInvalidElement(tuple_elem)
+    self._testInvalidElement(tuple_elem)


 if __name__ == '__main__':
@@ -45,9 +45,9 @@ class FileCacheTest(test_base.DatasetTestBase, parameterized.TestCase):
     self.cache_prefix = path.join(self.tmp_dir, "cache")

   def tearDown(self):
-    super(FileCacheTest, self).tearDown()
     if self.tmp_dir:
       shutil.rmtree(self.tmp_dir, ignore_errors=True)
+    super(FileCacheTest, self).tearDown()

   @combinations.generate(test_base.default_test_combinations())
   def testCacheDatasetPassthrough(self):
@@ -42,11 +42,11 @@ from tensorflow.python.training.tracking import util as trackable_utils
 class CheckpointTest(test_base.DatasetTestBase, parameterized.TestCase):

   def tearDown(self):
-    super(CheckpointTest, self).tearDown()
     prefix = self._iterator_checkpoint_prefix()
     pattern = prefix + "*"
     files = gfile.Glob(pattern)
     map(gfile.Remove, files)
+    super(CheckpointTest, self).tearDown()

   def _iterator_checkpoint_prefix(self):
     return os.path.join(self.get_temp_dir(), "iterator")
@@ -66,8 +66,7 @@ class CheckpointTest(test_base.DatasetTestBase, parameterized.TestCase):
                                     iterator_state_variant)
     return restore_op

-  @combinations.generate(
-      combinations.combine(tf_api_version=[1, 2], mode="graph"))
+  @combinations.generate(test_base.graph_only_combinations())
   def testSaveRestore(self):

     def _build_graph(start, stop):
@@ -118,8 +117,7 @@ class CheckpointTest(test_base.DatasetTestBase, parameterized.TestCase):
       with self.assertRaises(errors.OutOfRangeError):
         sess.run(get_next)

-  @combinations.generate(
-      combinations.combine(tf_api_version=[1, 2], mode="graph"))
+  @combinations.generate(test_base.graph_only_combinations())
   def testInitThenRestore(self):
     # Note: Calling init_op before restore_op is redundant. This test just makes
     # sure we do not fail if restore is called on an already initialized
@@ -157,8 +155,7 @@ class CheckpointTest(test_base.DatasetTestBase, parameterized.TestCase):
       with self.assertRaises(errors.OutOfRangeError):
         sess.run(get_next)

-  @combinations.generate(
-      combinations.combine(tf_api_version=[1, 2], mode="graph"))
+  @combinations.generate(test_base.graph_only_combinations())
   def testMultipleSaves(self):

     def _build_graph(start, stop):
@@ -204,8 +201,7 @@ class CheckpointTest(test_base.DatasetTestBase, parameterized.TestCase):
       with self.assertRaises(errors.OutOfRangeError):
         sess.run(get_next)

-  @combinations.generate(
-      combinations.combine(tf_api_version=[1, 2], mode="graph"))
+  @combinations.generate(test_base.graph_only_combinations())
   def testSaveRestoreWithRepeat(self):

     def _build_graph(start, stop, num_epochs):
@@ -253,8 +249,7 @@ class CheckpointTest(test_base.DatasetTestBase, parameterized.TestCase):
       with self.assertRaises(errors.OutOfRangeError):
         sess.run(get_next)

-  @combinations.generate(
-      combinations.combine(tf_api_version=[1, 2], mode="graph"))
+  @combinations.generate(test_base.graph_only_combinations())
   def testSaveRestoreExhaustedIterator(self):

     def _build_graph(start, stop, num_epochs):
@@ -295,8 +290,7 @@ class CheckpointTest(test_base.DatasetTestBase, parameterized.TestCase):
       with self.assertRaises(errors.OutOfRangeError):
         sess.run(get_next)

-  @combinations.generate(
-      combinations.combine(tf_api_version=[1, 2], mode="eager"))
+  @combinations.generate(test_base.eager_only_combinations())
   def testSaveRestoreOneShotIterator(self):
     checkpoint_directory = self.get_temp_dir()
     checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt")
@@ -319,8 +313,7 @@ class CheckpointTest(test_base.DatasetTestBase, parameterized.TestCase):
     with self.assertRaises(errors.OutOfRangeError):
       get_next()

-  @combinations.generate(
-      combinations.combine(tf_api_version=[1, 2], mode="eager"))
+  @combinations.generate(test_base.eager_only_combinations())
   def testSaveRestoreMultipleIterator(self):
     checkpoint_directory = self.get_temp_dir()
     checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt")
@@ -353,8 +346,7 @@ class CheckpointTest(test_base.DatasetTestBase, parameterized.TestCase):
     self.assertAllEqual([1, 4], get_next_2())
     self.assertAllEqual(3, get_next_3())

-  @combinations.generate(
-      combinations.combine(tf_api_version=[1, 2], mode="eager"))
+  @combinations.generate(test_base.eager_only_combinations())
   def testRestoreExhaustedIterator(self):
     checkpoint_directory = self.get_temp_dir()
     checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt")
@@ -373,8 +365,7 @@ class CheckpointTest(test_base.DatasetTestBase, parameterized.TestCase):
     with self.assertRaises(errors.OutOfRangeError):
       get_next()

-  @combinations.generate(
-      combinations.combine(tf_api_version=[1, 2], mode="eager"))
+  @combinations.generate(test_base.eager_only_combinations())
   def testRestoreInReconstructedIteratorInitializable(self):
     checkpoint_directory = self.get_temp_dir()
     checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt")
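Judging from the one-for-one swaps above, test_base.graph_only_combinations() and test_base.eager_only_combinations() are shorthand for the inline forms they replace. A hedged sketch of that equivalence, inferred from the diff rather than copied from the test_base source:

from tensorflow.python.framework import combinations


def graph_only_combinations():
  # Appears equivalent to the inline form removed in the hunks above.
  return combinations.combine(tf_api_version=[1, 2], mode="graph")


def eager_only_combinations():
  return combinations.combine(tf_api_version=[1, 2], mode="eager")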
@@ -43,7 +43,6 @@ from tensorflow.python.framework import tensor_spec
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import random_ops
 from tensorflow.python.platform import test
 from tensorflow.python.platform import tf_logging


 class DatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
@@ -89,13 +88,13 @@ class DatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
         variant, original_dataset.element_spec)
     self.assertDatasetProduces(revived_dataset, list(original_dataset))

-  def checkNumInputs(self, dataset, num_inputs):
+  def _testNumInputs(self, dataset, num_inputs):
     self.assertLen(dataset._inputs(), num_inputs)

   @combinations.generate(test_base.default_test_combinations())
   def testFixedLengthRecordInputs(self):
     dataset = readers.FixedLengthRecordDataset("", 42)
-    self.checkNumInputs(dataset, 0)
+    self._testNumInputs(dataset, 0)

   @combinations.generate(test_base.default_test_combinations())
   def testFromGeneratorInputs(self):
@@ -103,27 +102,27 @@ class DatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
       yield 42

     dataset = dataset_ops.Dataset.from_generator(gen, dtypes.int32)
-    self.checkNumInputs(dataset, 1)
+    self._testNumInputs(dataset, 1)

   @combinations.generate(test_base.default_test_combinations())
   def testFromTensorsInputs(self):
     dataset = dataset_ops.Dataset.from_tensors([42])
-    self.checkNumInputs(dataset, 0)
+    self._testNumInputs(dataset, 0)

   @combinations.generate(test_base.default_test_combinations())
   def testRangeInputs(self):
     dataset = dataset_ops.Dataset.range(10)
-    self.checkNumInputs(dataset, 0)
+    self._testNumInputs(dataset, 0)

   @combinations.generate(test_base.default_test_combinations())
   def testTextLineInputs(self):
     dataset = readers.TextLineDataset("")
-    self.checkNumInputs(dataset, 0)
+    self._testNumInputs(dataset, 0)

   @combinations.generate(test_base.default_test_combinations())
   def testTFRecordInputs(self):
     dataset = readers.TFRecordDataset("")
-    self.checkNumInputs(dataset, 1)
+    self._testNumInputs(dataset, 1)

   @combinations.generate(
       combinations.combine(tf_api_version=1, mode=["eager", "graph"]))
@@ -135,58 +134,58 @@ class DatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
               dense_shape=np.array([3, 1])))
     self.assertEmpty(dataset_fn._inputs())

-  def checkUnaryInputs(self, dataset_fn):
+  def _testUnaryInputs(self, dataset_fn):
     input_dataset = dataset_ops.Dataset.range(0)
     self.assertEqual([input_dataset], dataset_fn(input_dataset)._inputs())

   @combinations.generate(test_base.default_test_combinations())
   def testBatchInputs(self):
-    self.checkUnaryInputs(lambda x: x.batch(10))
+    self._testUnaryInputs(lambda x: x.batch(10))

   @combinations.generate(test_base.default_test_combinations())
   def testCacheInputs(self):
-    self.checkUnaryInputs(lambda x: x.cache())
+    self._testUnaryInputs(lambda x: x.cache())

   @combinations.generate(test_base.default_test_combinations())
   def testFilterInputs(self):
-    self.checkUnaryInputs(lambda x: x.filter(lambda x: True))
+    self._testUnaryInputs(lambda x: x.filter(lambda x: True))

   @combinations.generate(test_base.default_test_combinations())
   def testFlatMapInputs(self):
-    self.checkUnaryInputs(
+    self._testUnaryInputs(
         lambda x: x.flat_map(lambda x: dataset_ops.Dataset.range(0)))

   @combinations.generate(test_base.default_test_combinations())
   def testMapInputs(self):
-    self.checkUnaryInputs(lambda x: x.map(lambda x: x))
+    self._testUnaryInputs(lambda x: x.map(lambda x: x))

   @combinations.generate(test_base.default_test_combinations())
   def testPaddedBatchInputs(self):
-    self.checkUnaryInputs(lambda x: x.padded_batch(10, []))
+    self._testUnaryInputs(lambda x: x.padded_batch(10, []))

   @combinations.generate(test_base.default_test_combinations())
   def testParallelMapInputs(self):
-    self.checkUnaryInputs(lambda x: x.map(lambda x: x, num_parallel_calls=2))
+    self._testUnaryInputs(lambda x: x.map(lambda x: x, num_parallel_calls=2))

   @combinations.generate(test_base.default_test_combinations())
   def testRepeatInputs(self):
-    self.checkUnaryInputs(lambda x: x.repeat())
+    self._testUnaryInputs(lambda x: x.repeat())

   @combinations.generate(test_base.default_test_combinations())
   def testShuffleInputs(self):
-    self.checkUnaryInputs(lambda x: x.shuffle(10))
+    self._testUnaryInputs(lambda x: x.shuffle(10))

   @combinations.generate(test_base.default_test_combinations())
   def testSkipInputs(self):
-    self.checkUnaryInputs(lambda x: x.skip(1))
+    self._testUnaryInputs(lambda x: x.skip(1))

   @combinations.generate(test_base.default_test_combinations())
   def testTakeInputs(self):
-    self.checkUnaryInputs(lambda x: x.take(1))
+    self._testUnaryInputs(lambda x: x.take(1))

   @combinations.generate(test_base.default_test_combinations())
   def testWindowInputs(self):
-    self.checkUnaryInputs(lambda x: x.window(10))
+    self._testUnaryInputs(lambda x: x.window(10))

   @combinations.generate(test_base.default_test_combinations())
   def testUnaryTransformationInputsApply(self):
@@ -195,7 +194,7 @@ class DatasetTest(test_base.DatasetTestBase, parameterized.TestCase):

     self.assertEqual([input_dataset], dataset._inputs())

-  def checkInputsWithInterleaveFn(self, dataset_fn, interleave_parallelism):
+  def _testInputsWithInterleaveFn(self, dataset_fn, interleave_parallelism):
     input_dataset = dataset_ops.Dataset.range(0)
     dataset = input_dataset.interleave(
         lambda x: dataset_ops.Dataset.range(0),
@@ -205,11 +204,11 @@ class DatasetTest(test_base.DatasetTestBase, parameterized.TestCase):

   @combinations.generate(test_base.default_test_combinations())
   def testParallelInterleaveInputs(self):
-    self.checkInputsWithInterleaveFn(lambda: dataset_ops.range(0), 2)
+    self._testInputsWithInterleaveFn(lambda: dataset_ops.range(0), 2)

   @combinations.generate(test_base.default_test_combinations())
   def testInterleaveInputs(self):
-    self.checkInputsWithInterleaveFn(lambda: dataset_ops.range(0), None)
+    self._testInputsWithInterleaveFn(lambda: dataset_ops.range(0), None)

   @combinations.generate(test_base.default_test_combinations())
   def testNoWarnings(self):
@@ -218,16 +217,16 @@ class DatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
           lambda x: dataset_ops.Dataset.range(0), cycle_length=2)
       self.assertEmpty(mock_log.call_args_list)

-  def checkBinaryInputs(self, dataset_fn):
+  def _testBinaryInputs(self, dataset_fn):
     input1 = dataset_ops.Dataset.range(0)
     input2 = dataset_ops.Dataset.range(1)
     self.assertEqual([input1, input2], dataset_fn(input1, input2)._inputs())

   @combinations.generate(test_base.default_test_combinations())
   def testConcatenateInputs(self):
-    self.checkBinaryInputs(lambda x, y: x.concatenate(y))
+    self._testBinaryInputs(lambda x, y: x.concatenate(y))

-  def checkVariadicInputs(self, dataset_fn, input_datasets):
+  def _testVariadicInputs(self, dataset_fn, input_datasets):
     self.assertEqual(
         nest.flatten(input_datasets),
         dataset_fn(input_datasets)._inputs())
@@ -235,20 +234,20 @@ class DatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
   @combinations.generate(test_base.default_test_combinations())
   def testZipOneInputs(self):
     input_datasets = dataset_ops.Dataset.range(0)
-    self.checkVariadicInputs(dataset_ops.Dataset.zip, input_datasets)
+    self._testVariadicInputs(dataset_ops.Dataset.zip, input_datasets)

   @combinations.generate(test_base.default_test_combinations())
   def testZipNestInputs(self):
     input_datasets = (dataset_ops.Dataset.range(0),
                       (dataset_ops.Dataset.range(1),
                        dataset_ops.Dataset.range(2)))
-    self.checkVariadicInputs(dataset_ops.Dataset.zip, input_datasets)
+    self._testVariadicInputs(dataset_ops.Dataset.zip, input_datasets)

   @combinations.generate(test_base.default_test_combinations())
   def testZipTupleInputs(self):
     input_datasets = (dataset_ops.Dataset.range(0),
                       dataset_ops.Dataset.range(1))
-    self.checkVariadicInputs(dataset_ops.Dataset.zip, input_datasets)
+    self._testVariadicInputs(dataset_ops.Dataset.zip, input_datasets)

   @combinations.generate(test_base.default_test_combinations())
   def testFunctions(self):
@@ -273,7 +272,7 @@ class DatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
     self.assertEqual(2, inputs.count(ds2))
     self.assertEqual(1, inputs.count(ds3))

-  def checkDatasetSpec(self, tf_value, expected_element_structure):
+  def _testDatasetSpec(self, tf_value, expected_element_structure):
     dataset = dataset_ops.Dataset.from_tensors(0).map(lambda _: tf_value)
     dataset_structure = structure.type_spec_from_value(dataset)
     self.assertIsInstance(dataset_structure, dataset_ops.DatasetSpec)
@@ -307,12 +306,12 @@ class DatasetTest(test_base.DatasetTestBase, parameterized.TestCase):

   @combinations.generate(test_base.default_test_combinations())
   def testTensorDatasetSpec(self):
-    self.checkDatasetSpec(
+    self._testDatasetSpec(
         constant_op.constant(37.0), tensor_spec.TensorSpec([], dtypes.float32))

   @combinations.generate(test_base.default_test_combinations())
   def testSparseTensorDatasetSpec(self):
-    self.checkDatasetSpec(
+    self._testDatasetSpec(
         sparse_tensor.SparseTensor(
             indices=[[0]],
             values=constant_op.constant([0], dtype=dtypes.int32),
@@ -320,7 +319,7 @@ class DatasetTest(test_base.DatasetTestBase, parameterized.TestCase):

   @combinations.generate(test_base.default_test_combinations())
   def testNestDatasetSpec(self):
-    self.checkDatasetSpec(
+    self._testDatasetSpec(
         {
             "a": constant_op.constant(37.0),
             "b": (constant_op.constant(["Foo"]), constant_op.constant("Bar"))
@@ -335,20 +334,19 @@ class DatasetTest(test_base.DatasetTestBase, parameterized.TestCase):

   @combinations.generate(test_base.default_test_combinations())
   def testDatasetDatasetSpec(self):
-    self.checkDatasetSpec(
+    self._testDatasetSpec(
         dataset_ops.Dataset.from_tensor_slices(
             constant_op.constant([1, 2, 3])),
         dataset_ops.DatasetSpec(tensor_spec.TensorSpec([], dtypes.int32)))

   @combinations.generate(test_base.default_test_combinations())
   def testOptionalDatasetSpec(self):
-    self.checkDatasetSpec(
+    self._testDatasetSpec(
         optional_ops.Optional.from_value(37.0),
         optional_ops.OptionalSpec(tensor_spec.TensorSpec([], dtypes.float32)))

-  @combinations.generate(
-      combinations.combine(tf_api_version=[1], mode=["graph"]))
-  def testSkipEagerSameGraphErrorOneShot(self):
+  @combinations.generate(test_base.graph_only_combinations())
+  def testSameGraphError(self):
     dataset = dataset_ops.Dataset.range(10)
     with ops.Graph().as_default():
       with self.assertRaisesRegexp(ValueError, "must be from the same graph"):
@@ -356,26 +354,27 @@ class DatasetTest(test_base.DatasetTestBase, parameterized.TestCase):

   @combinations.generate(
       combinations.combine(tf_api_version=[1], mode=["graph"]))
-  def testSkipEagerSameGraphErrorOneShotSimple(self):
+  def testSameGraphErrorOneShot(self):
     dataset = dataset_ops.Dataset.range(10)
     with ops.Graph().as_default():
       with test.mock.patch.object(tf_logging, "warning") as mock_log:
         with self.assertRaisesRegexp(
             ValueError, "Please ensure that all datasets in the pipeline are "
             "created in the same graph as the iterator."):
           _ = dataset_ops.make_one_shot_iterator(dataset)
         self.assertRegexpMatches(
             str(mock_log.call_args), "Please ensure that all datasets in the "
             "pipeline are created in the same graph as the iterator.")

   @combinations.generate(
       combinations.combine(tf_api_version=[1], mode=["graph"]))
-  def testSkipEagerSameGraphErrorInitializable(self):
+  def testSameGraphErrorInitializable(self):
     dataset = dataset_ops.Dataset.range(10)
     with ops.Graph().as_default():
-      with self.assertRaisesRegexp(ValueError, "must be from the same graph"):
+      dataset = dataset.batch(2)
+      with self.assertRaisesRegexp(
+          ValueError, "Please ensure that all datasets in the pipeline are "
+          "created in the same graph as the iterator."):
         _ = dataset_ops.make_initializable_iterator(dataset)

   @combinations.generate(
       combinations.times(
-          combinations.combine(tf_api_version=[1, 2], mode="eager"),
+          test_base.eager_only_combinations(),
           combinations.combine(execution_mode=[context.ASYNC, context.SYNC])))
   def testEagerIteration(self, execution_mode):
     with context.execution_mode(execution_mode):
@@ -30,28 +30,31 @@ from tensorflow.python.ops import math_ops
 from tensorflow.python.platform import test


-def new_and_legacy_filter_fn_combinations():
+def _test_combinations():

-  def new_filter_fn(dataset, predicate):
+  def filter_fn(dataset, predicate):
     return dataset.filter(predicate)

   def legacy_filter_fn(dataset, predicate):
     return dataset.filter_with_legacy_function(predicate)

-  return (combinations.combine(
+  filter_combinations = combinations.combine(
       tf_api_version=[1, 2],
       mode=["eager", "graph"],
-      apply_filter=combinations.NamedObject("new_filter_fn", new_filter_fn)) +
-          combinations.combine(
-              tf_api_version=1,
-              mode=["eager", "graph"],
-              apply_filter=combinations.NamedObject("legacy_filter_fn",
-                                                    legacy_filter_fn)))
+      apply_filter=combinations.NamedObject("filter_fn", filter_fn))
+
+  legacy_filter_combinations = combinations.combine(
+      tf_api_version=1,
+      mode=["eager", "graph"],
+      apply_filter=combinations.NamedObject("legacy_filter_fn",
+                                            legacy_filter_fn))
+
+  return filter_combinations + legacy_filter_combinations


 class FilterTest(test_base.DatasetTestBase, parameterized.TestCase):

-  @combinations.generate(new_and_legacy_filter_fn_combinations())
+  @combinations.generate(_test_combinations())
   def testFilterDataset(self, apply_filter):
     components = (np.arange(7, dtype=np.int64),
                   np.array([[1, 2, 3]], dtype=np.int64) *
@@ -87,14 +90,14 @@ class FilterTest(test_base.DatasetTestBase, parameterized.TestCase):
     # Test an empty dataset.
     do_test(0, 1)

-  @combinations.generate(new_and_legacy_filter_fn_combinations())
+  @combinations.generate(_test_combinations())
   def testFilterRange(self, apply_filter):
     dataset = dataset_ops.Dataset.range(4)
     dataset = apply_filter(dataset,
                            lambda x: math_ops.not_equal(math_ops.mod(x, 3), 2))
     self.assertDatasetProduces(dataset, expected_output=[0, 1, 3])

-  @combinations.generate(new_and_legacy_filter_fn_combinations())
+  @combinations.generate(_test_combinations())
   def testFilterDict(self, apply_filter):
     dataset = dataset_ops.Dataset.range(10).map(
         lambda x: {"foo": x * 2, "bar": x**2})
@@ -104,7 +107,7 @@ class FilterTest(test_base.DatasetTestBase, parameterized.TestCase):
         dataset,
         expected_output=[(i * 2 + i**2) for i in range(10) if not (i**2) % 2])

-  @combinations.generate(new_and_legacy_filter_fn_combinations())
+  @combinations.generate(_test_combinations())
   def testUseStepContainerInFilter(self, apply_filter):
     input_data = np.array([[1, 2, 3], [4, 5, 6]], dtype=np.int64)

@@ -119,7 +122,7 @@ class FilterTest(test_base.DatasetTestBase, parameterized.TestCase):
     dataset = apply_filter(dataset, _predicate)
     self.assertDatasetProduces(dataset, expected_output=[input_data[0]])

-  @combinations.generate(new_and_legacy_filter_fn_combinations())
+  @combinations.generate(_test_combinations())
   def testSparse(self, apply_filter):

     def _map_fn(i):
@@ -137,7 +140,7 @@ class FilterTest(test_base.DatasetTestBase, parameterized.TestCase):
     self.assertDatasetProduces(
         dataset, expected_output=[_map_fn(i * 2)[0] for i in range(5)])

-  @combinations.generate(new_and_legacy_filter_fn_combinations())
+  @combinations.generate(_test_combinations())
   def testShortCircuit(self, apply_filter):
     dataset = dataset_ops.Dataset.zip(
         (dataset_ops.Dataset.range(10),
@@ -146,7 +149,7 @@ class FilterTest(test_base.DatasetTestBase, parameterized.TestCase):
     self.assertDatasetProduces(
         dataset, expected_output=[(i, True) for i in range(10)])

-  @combinations.generate(new_and_legacy_filter_fn_combinations())
+  @combinations.generate(_test_combinations())
   def testParallelFilters(self, apply_filter):
     dataset = dataset_ops.Dataset.range(10)
     dataset = apply_filter(dataset, lambda x: math_ops.equal(x % 2, 0))
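combinations.NamedObject gives each generated variant a readable test name while the wrapped callable is passed into the test as a keyword argument (apply_filter above). A small hypothetical sketch of how such a parameter reaches a test body; the helper and class names here are invented:

from absl.testing import parameterized

from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import combinations


def _map_fn_combinations():
  # One NamedObject per variant; the name shows up in the generated test ID.
  square = lambda ds: ds.map(lambda x: x * x)
  return combinations.combine(
      tf_api_version=[1, 2],
      mode=["eager", "graph"],
      apply_map=combinations.NamedObject("square_fn", square))


class MapVariantTest(test_base.DatasetTestBase, parameterized.TestCase):

  @combinations.generate(_map_fn_combinations())
  def testMapVariant(self, apply_map):
    dataset = apply_map(dataset_ops.Dataset.range(4))
    self.assertDatasetProduces(dataset, expected_output=[0, 1, 4, 9])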
@@ -66,10 +66,8 @@ class FlatMapTest(test_base.DatasetTestBase, parameterized.TestCase):
       expected_output.extend([i] * i)
     self.assertDatasetProduces(dataset, expected_output=expected_output)

-  # Note: no eager mode coverage, session specific test.
-  @combinations.generate(
-      combinations.combine(tf_api_version=[1, 2], mode=["graph"]))
-  def testSkipEagerSharedResourceNestedFlatMapDataset(self):
+  @combinations.generate(test_base.graph_only_combinations())
+  def testSharedResourceNestedFlatMapDataset(self):
     repeats = [[1, 2], [3, 4], [5, 0], [1, 7]]
     components = np.array(repeats, dtype=np.int64)
     iterator = (
@@ -32,62 +32,83 @@ from tensorflow.python.ops import script_ops
 from tensorflow.python.platform import test


-class DatasetConstructorTest(test_base.DatasetTestBase, parameterized.TestCase):
+class FromGeneratorTest(test_base.DatasetTestBase, parameterized.TestCase):

   def _testFromGenerator(self, generator, elem_sequence, num_repeats,
-                         output_types=None):
-    if output_types is None:
-      output_types = dtypes.int64
-    dataset = dataset_ops.Dataset.from_generator(
-        generator, output_types=output_types).repeat(num_repeats).prefetch(5)
-    self.assertDatasetProduces(
-        dataset,
-        elem_sequence * num_repeats,
-        requires_initialization=True,
-        num_test_iterations=2)
-
-  def _testFromGeneratorOneShot(self, generator, elem_sequence, num_repeats):
+                         requires_initialization):
     dataset = dataset_ops.Dataset.from_generator(
         generator, output_types=dtypes.int64).repeat(num_repeats).prefetch(5)
     self.assertDatasetProduces(
-        dataset, elem_sequence * num_repeats, num_test_iterations=2)
+        dataset,
+        elem_sequence * num_repeats,
+        requires_initialization=requires_initialization,
+        num_test_iterations=2)

-  @combinations.generate(test_base.default_test_combinations())
-  def testFromGeneratorUsingFunction(self):
+  @combinations.generate(
+      combinations.times(
+          test_base.default_test_combinations(),
+          combinations.combine(
+              num_repeats=[1, 5], requires_initialization=[True, False])))
+  def testFromGeneratorUsingFn(self, num_repeats, requires_initialization):
+
     def generator():
       for i in range(1, 100):
         yield [i] * i
+
     elem_sequence = list(generator())
-    self._testFromGenerator(generator, elem_sequence, 1)
-    self._testFromGenerator(generator, elem_sequence, 5)
-    self._testFromGeneratorOneShot(generator, elem_sequence, 1)
-    self._testFromGeneratorOneShot(generator, elem_sequence, 5)
+    self._testFromGenerator(
+        generator,
+        elem_sequence,
+        num_repeats=num_repeats,
+        requires_initialization=requires_initialization)

-  @combinations.generate(test_base.default_test_combinations())
-  def testFromGeneratorUsingList(self):
+  @combinations.generate(
+      combinations.times(
+          test_base.default_test_combinations(),
+          combinations.combine(
+              num_repeats=[1, 5], requires_initialization=[True, False])))
+  def testFromGeneratorUsingList(self, num_repeats, requires_initialization):
     generator = lambda: [[i] * i for i in range(1, 100)]
     elem_sequence = list(generator())
-    self._testFromGenerator(generator, elem_sequence, 1)
-    self._testFromGenerator(generator, elem_sequence, 5)
+    self._testFromGenerator(
+        generator,
+        elem_sequence,
+        num_repeats=num_repeats,
+        requires_initialization=requires_initialization)

-  @combinations.generate(test_base.default_test_combinations())
-  def testFromGeneratorUsingNdarray(self):
+  @combinations.generate(
+      combinations.times(
+          test_base.default_test_combinations(),
+          combinations.combine(
+              num_repeats=[1, 5], requires_initialization=[True, False])))
+  def testFromGeneratorUsingNdarray(self, num_repeats, requires_initialization):
     generator = lambda: np.arange(100, dtype=np.int64)
     elem_sequence = list(generator())
-    self._testFromGenerator(generator, elem_sequence, 1, output_types=np.int64)
-    self._testFromGenerator(generator, elem_sequence, 5, output_types=np.int64)
+    self._testFromGenerator(
+        generator,
+        elem_sequence,
+        num_repeats=num_repeats,
+        requires_initialization=requires_initialization)

-  @combinations.generate(test_base.default_test_combinations())
-  def testFromGeneratorUsingGeneratorExpression(self):
-    # NOTE(mrry): Generator *expressions* are not repeatable (or in
-    # general reusable), because they eagerly evaluate the `for`
-    # expression as `iter(range(1, 100))` and discard the means of
-    # reconstructing `range(1, 100)`. Wrapping the generator
-    # expression in a `lambda` makes it repeatable.
+  @combinations.generate(
+      combinations.times(
+          test_base.default_test_combinations(),
+          combinations.combine(
+              num_repeats=[1, 5], requires_initialization=[True, False])))
+  def testFromGeneratorUsingGeneratorExpression(self, num_repeats,
+                                                requires_initialization):
+    # NOTE(mrry): Generator *expressions* are not repeatable (or in general
+    # reusable), because they eagerly evaluate the `for` expression as
+    # `iter(range(1, 100))` and discard the means of reconstructing
+    # `range(1, 100)`. Wrapping the generator expression in a `lambda` makes
+    # it repeatable.
    generator = lambda: ([i] * i for i in range(1, 100))
     elem_sequence = list(generator())
-    self._testFromGenerator(generator, elem_sequence, 1)
-    self._testFromGenerator(generator, elem_sequence, 5)
+    self._testFromGenerator(
+        generator,
+        elem_sequence,
+        num_repeats=num_repeats,
+        requires_initialization=requires_initialization)

   @combinations.generate(test_base.default_test_combinations())
   def testFromMultipleConcurrentGenerators(self):
@@ -392,7 +413,6 @@ class DatasetConstructorTest(test_base.DatasetTestBase, parameterized.TestCase):
     self.assertAllEqual(37, self.evaluate(get_next()))
     with self.assertRaises(errors.OutOfRangeError):
       self.evaluate(get_next())
-    self.assertTrue(event.is_set())

   @combinations.generate(test_base.default_test_combinations())
   def testSharedName(self):
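combinations.times, used above, takes the cartesian product of combination lists: each entry of the base API-version/mode matrix is paired with every assignment of the extra parameters, which are then passed into the test body as keyword arguments. A hedged sketch of that behavior as inferred from its use in this diff:

from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.framework import combinations

# Base matrix from test_base (e.g. v1/v2 crossed with eager/graph modes).
base = test_base.default_test_combinations()

# Extra axes declared by the tests above.
extra = combinations.combine(
    num_repeats=[1, 5], requires_initialization=[True, False])

# Each decorated test body runs once per (base entry, num_repeats,
# requires_initialization) triple -- here, four times the base matrix.
product = combinations.times(base, extra)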
@@ -237,8 +237,8 @@ class FromTensorsTest(test_base.DatasetTestBase, parameterized.TestCase):
     self.assertEqual([3], get_next().shape)

   # TODO(b/121264236): needs mechanism for multiple device in eager mode.
-  @combinations.generate(test_base.default_test_combinations())
-  def testSkipEagerSplitPipeline(self):
+  @combinations.generate(test_base.graph_only_combinations())
+  def testSplitPipeline(self):
     with session.Session(
         target="",
         config=config_pb2.ConfigProto(device_count={"CPU": 2})) as sess:
@@ -17,12 +17,15 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function

+from absl.testing import parameterized
 import numpy as np

 from tensorflow.core.protobuf import config_pb2
 from tensorflow.python.client import session
+from tensorflow.python.data.kernel_tests import test_base
 from tensorflow.python.data.ops import dataset_ops
 from tensorflow.python.data.ops import iterator_ops
+from tensorflow.python.framework import combinations
 from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import errors
@@ -37,9 +40,9 @@ from tensorflow.python.ops import string_ops
 from tensorflow.python.platform import test


-class IteratorClusterTest(test.TestCase):
+class IteratorClusterTest(test.TestCase, parameterized.TestCase):

-  @test_util.run_v1_only("b/120545219")
+  @combinations.generate(test_base.graph_only_combinations())
   def testRemoteIteratorWithoutRemoteCallFail(self):
     worker_config = config_pb2.ConfigProto()
     worker_config.device_count["CPU"] = 2
@@ -95,7 +98,7 @@ class IteratorClusterTest(test.TestCase):
       with self.assertRaises(errors.OutOfRangeError):
         sess.run(remote_op, feed_dict={target_placeholder: device1})

-  @test_util.run_v1_only("b/120545219")
+  @combinations.generate(test_base.graph_only_combinations())
   def testRemoteIteratorUsingRemoteCallOp(self):
     worker_config = config_pb2.ConfigProto()
     worker_config.device_count["CPU"] = 2
@@ -106,7 +109,7 @@ class IteratorClusterTest(test.TestCase):
         "/job:worker/replica:0/task:0/cpu:1",
         worker[0].target)

-  @test_util.run_v1_only("b/120545219")
+  @combinations.generate(test_base.graph_only_combinations())
   def testRemoteIteratorUsingRemoteCallOpCrossProcess(self):
     workers, _ = test_util.create_local_cluster(2, 1)

@@ -114,7 +117,7 @@ class IteratorClusterTest(test.TestCase):
         "/job:worker/replica:0/task:1/cpu:0",
         workers[0].target)

-  @test_util.run_v1_only("b/120545219")
+  @combinations.generate(test_base.graph_only_combinations())
   def testCaptureHashTableInSharedIterator(self):
     worker, _ = test_util.create_local_cluster(1, 1)

@@ -131,10 +134,10 @@ class IteratorClusterTest(test.TestCase):
     input_sentences = dataset_ops.Dataset.from_tensor_slices(
         ["brain brain tank salad surgery", "surgery brain"])

-    iterator = (
-        input_sentences.map(lambda x: string_ops.string_split([x]).values).map(
-            table.lookup)
-        .make_initializable_iterator(shared_name="shared_iterator"))
+    dataset = input_sentences.map(
+        lambda x: string_ops.string_split([x]).values).map(table.lookup)
+    iterator = dataset_ops.make_initializable_iterator(
+        dataset, shared_name="shared_iterator")
     init_op = iterator.initializer
     get_next = iterator.get_next()

@@ -148,7 +151,7 @@ class IteratorClusterTest(test.TestCase):
       with self.assertRaises(errors.OutOfRangeError):
         sess.run(get_next)

-  @test_util.run_v1_only("b/120545219")
+  @combinations.generate(test_base.graph_only_combinations())
   def testImplicitDisposeParallelMapDataset(self):
     # Tests whether a parallel map dataset will be cleaned up correctly when
     # the pipeline does not run it until exhaustion.
@@ -56,8 +56,7 @@ from tensorflow.python.util import compat

 class IteratorTest(test_base.DatasetTestBase, parameterized.TestCase):

-  @combinations.generate(
-      combinations.combine(tf_api_version=[1, 2], mode=["graph"]))
+  @combinations.generate(test_base.graph_only_combinations())
   def testNoGradients(self):
     component = constant_op.constant([1.])
     side = constant_op.constant(0.)
@@ -68,8 +67,7 @@ class IteratorTest(test_base.DatasetTestBase, parameterized.TestCase):
     self.assertIsNone(gradients_impl.gradients(value, side)[0])
     self.assertIsNone(gradients_impl.gradients(value, [component, side])[0])

-  @combinations.generate(
-      combinations.combine(tf_api_version=[1, 2], mode=["graph"]))
+  @combinations.generate(test_base.graph_only_combinations())
   def testCapturingStateInOneShotRaisesException(self):
     var = variables.Variable(37.0, name="myvar")
     dataset = (
@@ -80,8 +78,7 @@ class IteratorTest(test_base.DatasetTestBase, parameterized.TestCase):
                          "datasets that capture stateful objects.+myvar"):
       dataset_ops.make_one_shot_iterator(dataset)

-  @combinations.generate(
-      combinations.combine(tf_api_version=[1, 2], mode=["graph"]))
+  @combinations.generate(test_base.graph_only_combinations())
   def testOneShotIterator(self):
     components = (np.arange(7),
                   np.array([[1, 2, 3]]) * np.arange(7)[:, np.newaxis],
@@ -107,8 +104,7 @@ class IteratorTest(test_base.DatasetTestBase, parameterized.TestCase):
       with self.assertRaises(errors.OutOfRangeError):
         sess.run(get_next)

-  @combinations.generate(
-      combinations.combine(tf_api_version=[1, 2], mode=["graph"]))
+  @combinations.generate(test_base.graph_only_combinations())
   def testOneShotIteratorCaptureByValue(self):
     components = (np.arange(7),
                   np.array([[1, 2, 3]]) * np.arange(7)[:, np.newaxis],
@@ -172,8 +168,7 @@ class IteratorTest(test_base.DatasetTestBase, parameterized.TestCase):
       with self.assertRaises(errors.OutOfRangeError):
         sess.run(get_next)

-  @combinations.generate(
-      combinations.combine(tf_api_version=[1, 2], mode=["graph"]))
+  @combinations.generate(test_base.graph_only_combinations())
   def testOneShotIteratorNonBlocking(self):
     dataset = dataset_ops.Dataset.from_tensors([1, 2, 3]).map(lambda x: x * x)
     iterator = dataset_ops.make_one_shot_iterator(dataset)
@@ -207,13 +202,11 @@ class IteratorTest(test_base.DatasetTestBase, parameterized.TestCase):
     for t in threads:
       t.join()

-    self.assertEqual(num_threads, len(results))
-    self.assertEqual(num_threads - 1,
-                     len([None for r in results if r is None]))
+    self.assertLen(results, num_threads)
+    self.assertLen([None for r in results if r is None], num_threads - 1)
     self.assertAllEqual([[1, 4, 9]], [r for r in results if r is not None])

-  @combinations.generate(
-      combinations.combine(tf_api_version=[1, 2], mode=["graph"]))
+  @combinations.generate(test_base.graph_only_combinations())
   def testOneShotIteratorInitializerFails(self):
     # Define a dataset whose initialization will always fail.
     dataset = dataset_ops.Dataset.from_tensors(array_ops.gather([0], [4]))
@@ -243,8 +236,7 @@ class IteratorTest(test_base.DatasetTestBase, parameterized.TestCase):
     for t in threads:
       t.join()

-  @combinations.generate(
-      combinations.combine(tf_api_version=[1, 2], mode=["graph"]))
+  @combinations.generate(test_base.graph_only_combinations())
   def testSimpleSharedResource(self):
     components = (np.array(1, dtype=np.int64),
                   np.array([1, 2, 3], dtype=np.int64),
@@ -294,8 +286,7 @@ class IteratorTest(test_base.DatasetTestBase, parameterized.TestCase):
       with self.assertRaises(errors.OutOfRangeError):
         sess.run(get_next)

-  @combinations.generate(
-      combinations.combine(tf_api_version=[1, 2], mode=["graph"]))
+  @combinations.generate(test_base.graph_only_combinations())
   def testNotInitializedError(self):
     components = (np.array(1), np.array([1, 2, 3]), np.array(37.0))
     iterator = dataset_ops.make_initializable_iterator(
@@ -307,8 +298,7 @@ class IteratorTest(test_base.DatasetTestBase, parameterized.TestCase):
                                  "iterator has not been initialized"):
         sess.run(get_next)

-  @combinations.generate(
-      combinations.combine(tf_api_version=[1, 2], mode=["graph"]))
+  @combinations.generate(test_base.graph_only_combinations())
   def testReinitializableIterator(self):
     dataset_3 = dataset_ops.Dataset.from_tensors(
         constant_op.constant([1, 2, 3]))
@@ -353,8 +343,7 @@ class IteratorTest(test_base.DatasetTestBase, parameterized.TestCase):
       with self.assertRaises(errors.OutOfRangeError):
         sess.run(get_next)

-  @combinations.generate(
-      combinations.combine(tf_api_version=[1, 2], mode=["graph"]))
+  @combinations.generate(test_base.graph_only_combinations())
   def testReinitializableIteratorWithFunctions(self):

     def g():
@@ -415,8 +404,7 @@ class IteratorTest(test_base.DatasetTestBase, parameterized.TestCase):
             (constant_op.constant([1, 2, 3], dtype=dtypes.int64),
              constant_op.constant([4., 5., 6., 7.], dtype=dtypes.float64))))

-  @combinations.generate(
-      combinations.combine(tf_api_version=[1, 2], mode=["graph"]))
+  @combinations.generate(test_base.graph_only_combinations())
   def testIteratorStringHandle(self):
     dataset_3 = dataset_ops.Dataset.from_tensor_slices([1, 2, 3])
     dataset_4 = dataset_ops.Dataset.from_tensor_slices([10, 20, 30, 40])
@@ -474,8 +462,7 @@ class IteratorTest(test_base.DatasetTestBase, parameterized.TestCase):
         sess.run(
             next_element, feed_dict={handle_placeholder: iterator_4_handle})

-  @combinations.generate(
-      combinations.combine(tf_api_version=[1, 2], mode=["graph"]))
+  @combinations.generate(test_base.graph_only_combinations())
   def testIteratorStringHandleFuture(self):
     with forward_compat.forward_compatibility_horizon(2018, 8, 4):
       dataset_3 = dataset_ops.Dataset.from_tensor_slices([1, 2, 3])
@@ -541,8 +528,7 @@ class IteratorTest(test_base.DatasetTestBase, parameterized.TestCase):
         sess.run(
             next_element, feed_dict={handle_placeholder: iterator_4_handle})

-  @combinations.generate(
-      combinations.combine(tf_api_version=[1, 2], mode=["graph"]))
+  @combinations.generate(test_base.graph_only_combinations())
   def testIteratorStringHandleReuseTensorObject(self):
     dataset = dataset_ops.Dataset.from_tensor_slices([1, 2, 3])
     one_shot_iterator = dataset_ops.make_one_shot_iterator(dataset)
@@ -571,8 +557,7 @@ class IteratorTest(test_base.DatasetTestBase, parameterized.TestCase):
     self.assertEqual("foo_1", handle_with_same_name.op.name)
     self.assertIsNot(handle_with_name, handle_with_same_name)

-  @combinations.generate(
-      combinations.combine(tf_api_version=[1, 2], mode=["graph"]))
+  @combinations.generate(test_base.graph_only_combinations())
   def testIteratorStringHandleError(self):
     dataset_int_scalar = (
         dataset_ops.Dataset.from_tensor_slices([1, 2, 3]).repeat())
@@ -613,8 +598,7 @@ class IteratorTest(test_base.DatasetTestBase, parameterized.TestCase):
             feedable_int_vector.get_next(),
             feed_dict={handle_placeholder: handle_float_vector}))

-  @combinations.generate(
-      combinations.combine(tf_api_version=[1, 2], mode=["graph"]))
+  @combinations.generate(test_base.graph_only_combinations())
   def testRemoteIteratorUsingRemoteCallOpDirectSession(self):
     worker_config = config_pb2.ConfigProto()
     worker_config.device_count["CPU"] = 3
@@ -672,8 +656,7 @@ class IteratorTest(test_base.DatasetTestBase, parameterized.TestCase):
                 target_placeholder: "/job:localhost/replica:0/task:0/cpu:1"
             })

-  @combinations.generate(
-      combinations.combine(tf_api_version=[1, 2], mode=["graph"]))
+  @combinations.generate(test_base.graph_only_combinations())
   def testRemoteIteratorUsingRemoteCallOpMultiWorkers(self):
     s1 = server_lib.Server.create_local_server()
     s2 = server_lib.Server.create_local_server()
@@ -727,8 +710,7 @@ class IteratorTest(test_base.DatasetTestBase, parameterized.TestCase):
       with self.assertRaises(errors.OutOfRangeError):
         sess.run(n)

-  @combinations.generate(
-      combinations.combine(tf_api_version=[1, 2], mode=["graph"]))
+  @combinations.generate(test_base.graph_only_combinations())
   def testRemoteIteratorUsingRemoteCallOpDirectSessionGPUCPU(self):
     if not test_util.is_gpu_available():
       self.skipTest("No GPU available")
@@ -785,8 +767,7 @@ class IteratorTest(test_base.DatasetTestBase, parameterized.TestCase):
                 target_placeholder: "/job:localhost/replica:0/task:0/cpu:0"
             })

-  @combinations.generate(
-      combinations.combine(tf_api_version=[1, 2], mode=["graph"]))
+  @combinations.generate(test_base.graph_only_combinations())
   def testRepeatedGetNextWarning(self):
     iterator = dataset_ops.make_one_shot_iterator(dataset_ops.Dataset.range(10))
     warnings.simplefilter("always")
@@ -929,7 +910,7 @@ class IteratorTest(test_base.DatasetTestBase, parameterized.TestCase):
       self.assertEqual(val, foo.numpy())
       val += 1

-  @combinations.generate(combinations.combine(tf_api_version=2, mode="eager"))
+  @combinations.generate(test_base.eager_only_combinations())
   def testOwnedIteratorFunction(self):

     queue = data_flow_ops.FIFOQueue(10, dtypes.int64)
@@ -946,7 +927,7 @@ class IteratorTest(test_base.DatasetTestBase, parameterized.TestCase):
     for i in range(10):
       self.assertEqual(queue.dequeue().numpy(), i)

-  @combinations.generate(combinations.combine(tf_api_version=2, mode="eager"))
+  @combinations.generate(test_base.eager_only_combinations())
   def testOwnedIteratorFunctionError(self):
     # In this test we verify that a function that raises an error ends up
     # properly deallocating the iterator resource.
@@ -976,7 +957,7 @@ class IteratorTest(test_base.DatasetTestBase, parameterized.TestCase):

     self.assertEqual(queue.size().numpy(), 2)

-  @combinations.generate(combinations.combine(tf_api_version=2, mode="eager"))
+  @combinations.generate(test_base.eager_only_combinations())
   def testLimitedRetracing(self):
     trace_count = [0]

@@ -996,7 +977,7 @@ class IteratorTest(test_base.DatasetTestBase, parameterized.TestCase):
     self.assertEqual(self.evaluate(f(iter(dataset2))), 45)
     self.assertEqual(trace_count[0], 1)

-  @combinations.generate(combinations.combine(tf_api_version=2, mode="eager"))
+  @combinations.generate(test_base.eager_only_combinations())
   def testNestedFunctionsIteratorResource(self):

     @def_function.function
@@ -35,10 +35,12 @@ from tensorflow.python.util import compat
 class ListFilesTest(test_base.DatasetTestBase, parameterized.TestCase):

   def setUp(self):
+    super(ListFilesTest, self).setUp()
     self.tmp_dir = tempfile.mkdtemp()

   def tearDown(self):
     shutil.rmtree(self.tmp_dir, ignore_errors=True)
+    super(ListFilesTest, self).tearDown()

   def _touchTempFiles(self, filenames):
     for filename in filenames:
[File diff suppressed because it is too large.]
@@ -119,8 +119,7 @@ class MemoryCleanupTest(test_base.DatasetTestBase, parameterized.TestCase):
     ]
     self.assertEmpty(tensors, "%d Tensors are still alive." % len(tensors))

-  @combinations.generate(
-      combinations.combine(tf_api_version=[1, 2], mode="eager"))
+  @combinations.generate(test_base.eager_only_combinations())
   def testFilter(self):

     def get_dataset():
@@ -144,8 +143,7 @@ class MemoryCleanupTest(test_base.DatasetTestBase, parameterized.TestCase):

     self._testIteratorMemoryLeak(get_dataset)

-  @combinations.generate(
-      combinations.combine(tf_api_version=[1, 2], mode="eager"))
+  @combinations.generate(test_base.eager_only_combinations())
   def testFlatMap(self):

     def get_dataset():
@@ -157,8 +155,7 @@ class MemoryCleanupTest(test_base.DatasetTestBase, parameterized.TestCase):

     self._testIteratorMemoryLeak(get_dataset)

-  @combinations.generate(
-      combinations.combine(tf_api_version=[1, 2], mode="eager"))
+  @combinations.generate(test_base.eager_only_combinations())
   def testFromGenerator(self):

     def get_dataset():
@@ -171,8 +168,8 @@ class MemoryCleanupTest(test_base.DatasetTestBase, parameterized.TestCase):
     self._testIteratorMemoryLeak(get_dataset)

   @combinations.generate(
-      combinations.combine(
-          tf_api_version=[1, 2], mode="eager", num_parallel_calls=[None, 10]))
+      combinations.times(test_base.eager_only_combinations(),
+                         combinations.combine(num_parallel_calls=[None, 10])))
   def testMap(self, num_parallel_calls):

     def get_dataset():
@@ -201,8 +198,8 @@ class MemoryCleanupTest(test_base.DatasetTestBase, parameterized.TestCase):
     self._testIteratorMemoryLeak(get_dataset)

   @combinations.generate(
-      combinations.combine(
-          tf_api_version=[1, 2], mode="eager", num_parallel_calls=[None, 10]))
+      combinations.times(test_base.eager_only_combinations(),
+                         combinations.combine(num_parallel_calls=[None, 10])))
   def testInterleave(self, num_parallel_calls):

     def get_dataset():
@@ -17,6 +17,8 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function

+import functools
+
 from absl.testing import parameterized
 import numpy as np

@@ -27,6 +29,7 @@ from tensorflow.python.data.ops import optional_ops
 from tensorflow.python.data.util import structure
 from tensorflow.python.eager import context
 from tensorflow.python.eager import def_function
+from tensorflow.python.framework import combinations
 from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import errors
@@ -40,14 +43,90 @@ from tensorflow.python.ops import math_ops
 from tensorflow.python.platform import test


-@test_util.run_all_in_graph_and_eager_modes
+def _optional_spec_test_combinations():
+  # pylint: disable=g-long-lambda
+  cases = [
+      ("Dense", lambda: constant_op.constant(37.0),
+       tensor_spec.TensorSpec([], dtypes.float32)),
+      ("Sparse", lambda: sparse_tensor.SparseTensor(
+          indices=[[0, 1]],
+          values=constant_op.constant([0], dtype=dtypes.int32),
+          dense_shape=[10, 10]),
+       sparse_tensor.SparseTensorSpec([10, 10], dtypes.int32)),
+      ("Nest", lambda: {
+          "a": constant_op.constant(37.0),
+          "b": (constant_op.constant(["Foo"]), constant_op.constant("Bar"))
+      }, {
+          "a":
+              tensor_spec.TensorSpec([], dtypes.float32),
+          "b": (
+              tensor_spec.TensorSpec([1], dtypes.string),
+              tensor_spec.TensorSpec([], dtypes.string),
+          )
+      }),
+      ("Optional", lambda: optional_ops.Optional.from_value(37.0),
+       optional_ops.OptionalSpec(tensor_spec.TensorSpec([], dtypes.float32))),
+  ]
+
+  def reduce_fn(x, y):
+    name, value_fn, expected_structure = y
+    return x + combinations.combine(
+        tf_value_fn=combinations.NamedObject(name, value_fn),
+        expected_value_structure=expected_structure)
+
+  return functools.reduce(reduce_fn, cases, [])
+
+
+def _get_next_as_optional_test_combinations():
+  # pylint: disable=g-long-lambda
+  cases = [
+      ("Dense", np.array([1, 2, 3], dtype=np.int32),
+       lambda: constant_op.constant([4, 5, 6], dtype=dtypes.int32), True),
+      ("Sparse",
+       sparse_tensor.SparseTensorValue(
+           indices=[[0, 0], [1, 1]],
+           values=np.array([-1., 1.], dtype=np.float32),
+           dense_shape=[2, 2]),
+       lambda: sparse_tensor.SparseTensor(
+           indices=[[0, 1], [1, 0]], values=[37.0, 42.0], dense_shape=[2, 2]),
+       False),
+      ("Nest", {
+          "a":
+              np.array([1, 2, 3], dtype=np.int32),
+          "b":
+              sparse_tensor.SparseTensorValue(
+                  indices=[[0, 0], [1, 1]],
+                  values=np.array([-1., 1.], dtype=np.float32),
+                  dense_shape=[2, 2])
+      }, lambda: {
+          "a":
+              constant_op.constant([4, 5, 6], dtype=dtypes.int32),
+          "b":
+              sparse_tensor.SparseTensor(
+                  indices=[[0, 1], [1, 0]],
+                  values=[37.0, 42.0],
+                  dense_shape=[2, 2])
+      }, False),
+  ]
+
+  def reduce_fn(x, y):
+    name, value, value_fn, gpu_compatible = y
+    return x + combinations.combine(
+        np_value=value, tf_value_fn=combinations.NamedObject(name, value_fn),
+        gpu_compatible=gpu_compatible)
+
+  return functools.reduce(reduce_fn, cases, [])
+
+
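The reduce_fn/functools.reduce idiom above builds one flat list of combinations by concatenating a combinations.combine(...) result per named case (combine returns a list, so + concatenates). A stripped-down sketch of the same idiom with hypothetical case names and values:

import functools

from tensorflow.python.framework import combinations

cases = [
    ("Zero", 0),
    ("One", 1),
]


def reduce_fn(x, y):
  name, value = y
  # Each case contributes one named combination to the accumulated list.
  return x + combinations.combine(
      value=combinations.NamedObject(name, value))


all_cases = functools.reduce(reduce_fn, cases, [])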
class OptionalTest(test_base.DatasetTestBase, parameterized.TestCase):

  @combinations.generate(test_base.default_test_combinations())
  def testFromValue(self):
    opt = optional_ops.Optional.from_value(constant_op.constant(37.0))
    self.assertTrue(self.evaluate(opt.has_value()))
    self.assertEqual(37.0, self.evaluate(opt.get_value()))

  @combinations.generate(test_base.default_test_combinations())
  def testFromStructuredValue(self):
    opt = optional_ops.Optional.from_value({
        "a": constant_op.constant(37.0),
@ -59,6 +138,7 @@ class OptionalTest(test_base.DatasetTestBase, parameterized.TestCase):
        "b": ([b"Foo"], b"Bar")
    }, self.evaluate(opt.get_value()))

  @combinations.generate(test_base.default_test_combinations())
  def testFromSparseTensor(self):
    st_0 = sparse_tensor.SparseTensorValue(
        indices=np.array([[0]]),
@ -77,6 +157,7 @@ class OptionalTest(test_base.DatasetTestBase, parameterized.TestCase):
      self.assertAllEqual(expected.dense_shape,
                          self.evaluate(actual.dense_shape))

  @combinations.generate(test_base.default_test_combinations())
  def testFromNone(self):
    value_structure = tensor_spec.TensorSpec([], dtypes.float32)
    opt = optional_ops.Optional.none_from_structure(value_structure)
@ -91,6 +172,7 @@ class OptionalTest(test_base.DatasetTestBase, parameterized.TestCase):
    with self.assertRaises(errors.InvalidArgumentError):
      self.evaluate(opt.get_value())

  @combinations.generate(test_base.default_test_combinations())
  def testAddN(self):
    devices = ["/cpu:0"]
    if test_util.is_gpu_available():
@ -117,6 +199,7 @@ class OptionalTest(test_base.DatasetTestBase, parameterized.TestCase):
                         opt_none1.value_structure)
        self.assertFalse(self.evaluate(add_opt.has_value()))

  @combinations.generate(test_base.default_test_combinations())
  def testNestedAddN(self):
    devices = ["/cpu:0"]
    if test_util.is_gpu_available():
@ -137,6 +220,7 @@ class OptionalTest(test_base.DatasetTestBase, parameterized.TestCase):
                         opt1.value_structure)
        self.assertAllEqual(inner_add_opt.get_value(), [4, 6.0])

  @combinations.generate(test_base.default_test_combinations())
  def testZerosLike(self):
    devices = ["/cpu:0"]
    if test_util.is_gpu_available():
@ -159,6 +243,7 @@ class OptionalTest(test_base.DatasetTestBase, parameterized.TestCase):
                         opt_none.value_structure)
        self.assertFalse(self.evaluate(zeros_opt.has_value()))

  @combinations.generate(test_base.default_test_combinations())
  def testNestedZerosLike(self):
    devices = ["/cpu:0"]
    if test_util.is_gpu_available():
@ -175,6 +260,7 @@ class OptionalTest(test_base.DatasetTestBase, parameterized.TestCase):
                         opt1.value_structure)
        self.assertEqual(self.evaluate(inner_zeros_opt.get_value()), 0.0)

  @combinations.generate(test_base.default_test_combinations())
  def testCopyToGPU(self):
    if not test_util.is_gpu_available():
      self.skipTest("No GPU available")
@ -204,6 +290,7 @@ class OptionalTest(test_base.DatasetTestBase, parameterized.TestCase):
                     self.evaluate(gpu_optional_with_value_values))
    self.assertFalse(self.evaluate(gpu_optional_none_has_value))

  @combinations.generate(test_base.default_test_combinations())
  def testNestedCopyToGPU(self):
    if not test_util.is_gpu_available():
      self.skipTest("No GPU available")
@ -239,42 +326,10 @@ class OptionalTest(test_base.DatasetTestBase, parameterized.TestCase):
    self.assertFalse(self.evaluate(inner_none.has_value()))
    self.assertEqual(1.0, self.evaluate(gpu_nested_optional_values[2]))

  def _assertElementValueEqual(self, expected, actual):
    if isinstance(expected, dict):
      self.assertItemsEqual(list(expected.keys()), list(actual.keys()))
      for k in expected.keys():
        self._assertElementValueEqual(expected[k], actual[k])
    elif isinstance(expected, sparse_tensor.SparseTensorValue):
      self.assertAllEqual(expected.indices, actual.indices)
      self.assertAllEqual(expected.values, actual.values)
      self.assertAllEqual(expected.dense_shape, actual.dense_shape)
    else:
      self.assertAllEqual(expected, actual)

  # pylint: disable=g-long-lambda
  @parameterized.named_parameters(
      ("Tensor", lambda: constant_op.constant(37.0),
       tensor_spec.TensorSpec([], dtypes.float32)),
      ("SparseTensor", lambda: sparse_tensor.SparseTensor(
          indices=[[0, 1]],
          values=constant_op.constant([0], dtype=dtypes.int32),
          dense_shape=[10, 10]),
       sparse_tensor.SparseTensorSpec([10, 10], dtypes.int32)),
      ("Nest", lambda: {
          "a": constant_op.constant(37.0),
          "b": (constant_op.constant(["Foo"]), constant_op.constant("Bar"))
      }, {
          "a":
              tensor_spec.TensorSpec([], dtypes.float32),
          "b": (
              tensor_spec.TensorSpec([1], dtypes.string),
              tensor_spec.TensorSpec([], dtypes.string),
          )
      }),
      ("Optional", lambda: optional_ops.Optional.from_value(37.0),
       optional_ops.OptionalSpec(
           tensor_spec.TensorSpec([], dtypes.float32))),
  )
  @combinations.generate(
      combinations.times(
          test_base.default_test_combinations(),
          _optional_spec_test_combinations()))
  def testOptionalSpec(self, tf_value_fn, expected_value_structure):
    tf_value = tf_value_fn()
    opt = optional_ops.Optional.from_value(tf_value)
@ -304,36 +359,21 @@ class OptionalTest(test_base.DatasetTestBase, parameterized.TestCase):
    round_trip_opt = opt_structure._from_tensor_list(
        opt_structure._to_tensor_list(opt))
    if isinstance(tf_value, optional_ops.Optional):
      self._assertElementValueEqual(
      self.assertValuesEqual(
          self.evaluate(tf_value.get_value()),
          self.evaluate(round_trip_opt.get_value().get_value()))
    else:
      self._assertElementValueEqual(
      self.assertValuesEqual(
          self.evaluate(tf_value),
          self.evaluate(round_trip_opt.get_value()))

  @parameterized.named_parameters(
      ("Tensor", np.array([1, 2, 3], dtype=np.int32),
       lambda: constant_op.constant([4, 5, 6], dtype=dtypes.int32), True),
      ("SparseTensor", sparse_tensor.SparseTensorValue(
          indices=[[0, 0], [1, 1]],
          values=np.array([-1., 1.], dtype=np.float32), dense_shape=[2, 2]),
       lambda: sparse_tensor.SparseTensor(
           indices=[[0, 1], [1, 0]], values=[37.0, 42.0], dense_shape=[2, 2]),
       False),
      ("Nest", {"a": np.array([1, 2, 3], dtype=np.int32),
                "b": sparse_tensor.SparseTensorValue(
                    indices=[[0, 0], [1, 1]],
                    values=np.array([-1., 1.], dtype=np.float32),
                    dense_shape=[2, 2])},
       lambda: {"a": constant_op.constant([4, 5, 6], dtype=dtypes.int32),
                "b": sparse_tensor.SparseTensor(
                    indices=[[0, 1], [1, 0]], values=[37.0, 42.0],
                    dense_shape=[2, 2])}, False),
  )
  @combinations.generate(
      combinations.times(
          test_base.default_test_combinations(),
          _get_next_as_optional_test_combinations()))
  def testIteratorGetNextAsOptional(self, np_value, tf_value_fn,
                                    works_on_gpu):
    if not works_on_gpu and test.is_gpu_available():
                                    gpu_compatible):
    if not gpu_compatible and test.is_gpu_available():
      self.skipTest("Test case not yet supported on GPU.")
    ds = dataset_ops.Dataset.from_tensors(np_value).repeat(3)

@ -348,7 +388,7 @@ class OptionalTest(test_base.DatasetTestBase, parameterized.TestCase):
              next_elem.value_structure,
              structure.type_spec_from_value(tf_value_fn())))
      self.assertTrue(next_elem.has_value())
      self._assertElementValueEqual(np_value, next_elem.get_value())
      self.assertValuesEqual(np_value, next_elem.get_value())
      # After exhausting the iterator, `next_elem.has_value()` will evaluate to
      # false, and attempting to get the value will fail.
      for _ in range(2):
@ -379,7 +419,7 @@ class OptionalTest(test_base.DatasetTestBase, parameterized.TestCase):
        elem_has_value, elem_value = self.evaluate(
            [elem_has_value_t, elem_value_t])
        self.assertTrue(elem_has_value)
        self._assertElementValueEqual(np_value, elem_value)
        self.assertValuesEqual(np_value, elem_value)

      # After exhausting the iterator, `next_elem.has_value()` will evaluate to
      # false, and attempting to get the value will fail.
@ -388,6 +428,7 @@ class OptionalTest(test_base.DatasetTestBase, parameterized.TestCase):
      with self.assertRaises(errors.InvalidArgumentError):
        self.evaluate(elem_value_t)

  @combinations.generate(test_base.default_test_combinations())
  def testFunctionBoundaries(self):
    @def_function.function
    def get_optional():
@ -407,6 +448,7 @@ class OptionalTest(test_base.DatasetTestBase, parameterized.TestCase):
    val = consume_optional(opt_tensor)
    self.assertEqual(self.evaluate(val), 1.0)

  @combinations.generate(test_base.default_test_combinations())
  def testLimitedRetracing(self):
    trace_count = [0]

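# Aside: a minimal sketch (not part of this commit) of the Optional contract
# exercised above. An Optional either holds a value or is "none"; calling
# get_value() on a none Optional fails with InvalidArgumentError:
#
#   opt = optional_ops.Optional.from_value(constant_op.constant(37.0))
#   opt.has_value()   # -> True
#   none = optional_ops.Optional.none_from_structure(
#       tensor_spec.TensorSpec([], dtypes.float32))
#   none.has_value()  # -> False; none.get_value() raises when evaluated
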
@ -18,25 +18,31 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from absl.testing import parameterized

from tensorflow.python.data.experimental.ops import optimization_options
from tensorflow.python.data.experimental.ops import stats_options
from tensorflow.python.data.experimental.ops import threading_options
from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import combinations
from tensorflow.python.platform import test


class OptionsTest(test_base.DatasetTestBase):
class OptionsTest(test_base.DatasetTestBase, parameterized.TestCase):

  @combinations.generate(test_base.default_test_combinations())
  def testOptionsDefault(self):
    ds = dataset_ops.Dataset.range(0)
    self.assertEqual(dataset_ops.Options(), ds.options())

  @combinations.generate(test_base.default_test_combinations())
  def testOptionsOnce(self):
    options = dataset_ops.Options()
    ds = dataset_ops.Dataset.range(0).with_options(options).cache()
    self.assertEqual(options, ds.options())

  @combinations.generate(test_base.default_test_combinations())
  def testOptionsTwiceSame(self):
    options = dataset_ops.Options()
    options.experimental_optimization.autotune = True
@ -44,6 +50,7 @@ class OptionsTest(test_base.DatasetTestBase):
        options)
    self.assertEqual(options, ds.options())

  @combinations.generate(test_base.default_test_combinations())
  def testOptionsTwiceDifferent(self):
    options1 = dataset_ops.Options()
    options1.experimental_optimization.autotune = True
@ -55,6 +62,7 @@ class OptionsTest(test_base.DatasetTestBase):
    # Explicitly check that flag is False since assertFalse allows None
    self.assertIs(ds.options().experimental_deterministic, False)

  @combinations.generate(test_base.default_test_combinations())
  def testOptionsTwiceDifferentError(self):
    options1 = dataset_ops.Options()
    options1.experimental_optimization.autotune = True
@ -64,6 +72,7 @@ class OptionsTest(test_base.DatasetTestBase):
                                 "Cannot merge incompatible values"):
      dataset_ops.Dataset.range(0).with_options(options1).with_options(options2)

  @combinations.generate(test_base.default_test_combinations())
  def testOptionsMergeOptionsFromMultipleInputs(self):
    options1 = dataset_ops.Options()
    options1.experimental_optimization.autotune = True
@ -75,6 +84,7 @@ class OptionsTest(test_base.DatasetTestBase):
    self.assertTrue(ds.options().experimental_optimization.autotune)
    self.assertTrue(ds.options().experimental_deterministic)

  @combinations.generate(test_base.default_test_combinations())
  def testOptionsHaveDefaults(self):
    options1 = dataset_ops.Options()
    options2 = dataset_ops.Options()
@ -84,12 +94,11 @@ class OptionsTest(test_base.DatasetTestBase):
                     options2.experimental_stats)
    self.assertIsNot(options1.experimental_threading,
                     options2.experimental_threading)
    self.assertEquals(options1.experimental_optimization,
                      optimization_options.OptimizationOptions())
    self.assertEquals(options1.experimental_stats,
                      stats_options.StatsOptions())
    self.assertEquals(options1.experimental_threading,
                      threading_options.ThreadingOptions())
    self.assertEqual(options1.experimental_optimization,
                     optimization_options.OptimizationOptions())
    self.assertEqual(options1.experimental_stats, stats_options.StatsOptions())
    self.assertEqual(options1.experimental_threading,
                     threading_options.ThreadingOptions())


if __name__ == "__main__":

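# Aside: a minimal sketch (not part of this commit) of the merge behavior
# these tests pin down. Repeated with_options calls merge; setting the same
# flag to conflicting explicit values fails with a message containing
# "Cannot merge incompatible values":
#
#   options1 = dataset_ops.Options()
#   options1.experimental_deterministic = True
#   options2 = dataset_ops.Options()
#   options2.experimental_deterministic = False
#   ds = dataset_ops.Dataset.range(0).with_options(options1)
#   ds.with_options(options2)  # raises: Cannot merge incompatible values
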
@ -23,43 +23,30 @@ import numpy as np

from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import combinations
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import string_ops
from tensorflow.python.platform import test
from tensorflow.python.util import compat


def _random_seq_lens(count):
  return np.random.randint(20, size=(count,)).astype(np.int32)


@test_util.run_all_in_graph_and_eager_modes
class PaddedBatchTest(test_base.DatasetTestBase, parameterized.TestCase):

  @parameterized.named_parameters(
      ('default_padding', _random_seq_lens(32), 4, [-1], False),
      ('constant_padding', _random_seq_lens(32), 4, [25], False),
      ('uneven_with_remainder', _random_seq_lens(34), 4, [-1], False),
      ('uneven_without_remainder', _random_seq_lens(34), 4, [-1], True),
  )
  def testPaddedBatchDataset(self, seq_lens, batch_size, padded_shapes,
                             drop_remainder):
    """Tests the padded batch dataset logic for various input configurations.

    Args:
      seq_lens: the input sequence lengths
      batch_size: the batch size
      padded_shapes: the padded shapes to use
      drop_remainder: whether a smaller batch size should be produced if batch
        size does not divide number of inputs evenly
    """

  @combinations.generate(
      combinations.times(
          test_base.default_test_combinations(),
          combinations.combine(
              count=[32, 34],
              padded_shapes=[[None], [25]],
              drop_remainder=[True, False])))
  def testPaddedBatchDataset(self, count, padded_shapes, drop_remainder):
    seq_lens = np.random.randint(20, size=(count,)).astype(np.int32)
    batch_size = 4
    dataset = dataset_ops.Dataset.from_tensor_slices(seq_lens).map(
        lambda x: array_ops.fill([x], x)).padded_batch(
            batch_size=batch_size,
@ -81,7 +68,9 @@ class PaddedBatchTest(test_base.DatasetTestBase, parameterized.TestCase):

    if not drop_remainder and len(seq_lens) % batch_size > 0:
      result = self.evaluate(get_next())
      padded_len = np.max(result) if result.size > 0 else 0
      padded_len = padded_shapes[0]
      if padded_len is None or padded_len == -1:
        padded_len = np.max(result) if result.size > 0 else 0
      self.assertEqual((len(seq_lens) % batch_size, padded_len), result.shape)
      for j in range(len(seq_lens) % batch_size):
        seq_len = seq_lens[num_full_batches * batch_size + j]
@ -93,7 +82,7 @@ class PaddedBatchTest(test_base.DatasetTestBase, parameterized.TestCase):
    with self.assertRaises(errors.OutOfRangeError):
      self.evaluate(get_next())

  @test_util.run_deprecated_v1
  @combinations.generate(test_base.default_test_combinations())
  def testPaddedBatchShortPadding(self):
    dataset = (
        dataset_ops.Dataset.from_tensor_slices(
@ -102,6 +91,7 @@ class PaddedBatchTest(test_base.DatasetTestBase, parameterized.TestCase):
    self.assertDatasetProduces(
        dataset, expected_error=(errors.DataLossError, ''))

  @combinations.generate(test_base.default_test_combinations())
  def testPaddedBatchEmptyTensors(self):
    dataset = (
        dataset_ops.Dataset.from_tensor_slices(
@ -109,6 +99,7 @@ class PaddedBatchTest(test_base.DatasetTestBase, parameterized.TestCase):
            batch_size=4, padded_shapes=[-1]))
    self.assertDatasetProduces(dataset, expected_output=[[[], [], [], []]])

  @combinations.generate(test_base.default_test_combinations())
  def testPaddedBatchDatasetNonDefaultPadding(self):

    def fill_tuple(x):
@ -139,6 +130,7 @@ class PaddedBatchTest(test_base.DatasetTestBase, parameterized.TestCase):
    with self.assertRaises(errors.OutOfRangeError):
      self.evaluate(get_next())

  @combinations.generate(test_base.default_test_combinations())
  def testPaddedBatchDatasetUnicode(self):
    # See GitHub issue 16149
    def generator():
@ -156,9 +148,8 @@ class PaddedBatchTest(test_base.DatasetTestBase, parameterized.TestCase):
    next_element = self.getNext(padded_dataset)
    self.evaluate(next_element())

  # NOTE: This test is specific to graph mode and is skipped in eager mode.
  @test_util.run_deprecated_v1
  def testSkipEagerPaddedBatchDatasetShapeSpecifications(self):
  @combinations.generate(test_base.graph_only_combinations())
  def testPaddedBatchDatasetShapeSpecifications(self):
    int_placeholder = array_ops.placeholder(dtypes.int32)
    float_placeholder = array_ops.placeholder(dtypes.float32)
    string_placeholder = array_ops.placeholder(dtypes.string)
@ -190,6 +181,7 @@ class PaddedBatchTest(test_base.DatasetTestBase, parameterized.TestCase):
    self.assertEqual([None, None, None], dataset_output_shapes[1].as_list())
    self.assertEqual([None, 37], dataset_output_shapes[2].as_list())

  @combinations.generate(test_base.default_test_combinations())
  def testPaddedBatchSparseError(self):

    def _map_fn(i):
@ -199,6 +191,7 @@ class PaddedBatchTest(test_base.DatasetTestBase, parameterized.TestCase):
    with self.assertRaises(TypeError):
      _ = dataset_ops.Dataset.range(10).map(_map_fn).padded_batch(10)

  @combinations.generate(test_base.default_test_combinations())
  def testPaddedBatchShapeError(self):
    with self.assertRaisesRegexp(
        ValueError, r'The padded shape \(1,\) is not compatible with the '
@ -230,9 +223,8 @@ class PaddedBatchTest(test_base.DatasetTestBase, parameterized.TestCase):
      _ = dataset_ops.Dataset.range(10).padded_batch(
          5, padded_shapes=shape_as_tensor)

  # NOTE: This test is specific to graph mode and is skipped in eager mode.
  @test_util.run_deprecated_v1
  def testSkipEagerPaddedBatchShapeError(self):
  @combinations.generate(test_base.graph_only_combinations())
  def testPaddedBatchShapeErrorPlaceholder(self):
    with self.assertRaisesRegexp(
        ValueError,
        r'The padded shape \((\?|None), (\?|None)\) is not compatible with the '

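# Aside: a minimal sketch (not part of this commit) of the padding behavior
# under test. Each element of length x is zero-padded to the padded shape;
# [None] (or [-1]) pads to the longest element in the batch:
#
#   dataset = dataset_ops.Dataset.from_tensor_slices([1, 2, 3]).map(
#       lambda x: array_ops.fill([x], x)).padded_batch(
#           batch_size=3, padded_shapes=[None])
#   # -> one batch: [[1, 0, 0], [2, 2, 0], [3, 3, 3]]
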
@ -23,36 +23,41 @@ from absl.testing import parameterized

from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import combinations
from tensorflow.python.framework import errors
from tensorflow.python.framework import test_util
from tensorflow.python.platform import test


@test_util.run_all_in_graph_and_eager_modes
class PrefetchTest(test_base.DatasetTestBase, parameterized.TestCase):

  @parameterized.parameters((-1), (0), (5))
  @combinations.generate(
      combinations.times(test_base.default_test_combinations(),
                         combinations.combine(buffer_size=[-1, None, 0, 42])))
  def testBufferSize(self, buffer_size):
    dataset = dataset_ops.Dataset.range(10).prefetch(buffer_size=buffer_size)
    self.assertDatasetProduces(dataset, expected_output=range(10))

  @parameterized.parameters((-2), (-42))
  @combinations.generate(
      combinations.times(test_base.default_test_combinations(),
                         combinations.combine(buffer_size=[-2, -42])))
  def testInvalidBufferSize(self, buffer_size):
    with self.assertRaises(errors.InvalidArgumentError):
      dataset = dataset_ops.Dataset.range(10).prefetch(buffer_size=buffer_size)
      self.evaluate(dataset._variant_tensor)

  @parameterized.parameters(*[(buffer_size, slack_period)
                              for buffer_size in (-1, None, 0, 5)
                              for slack_period in (1, 8)])
  @combinations.generate(
      combinations.times(
          test_base.default_test_combinations(),
          combinations.combine(
              buffer_size=[-1, None, 0, 42], slack_period=[1, 8])))
  def testPrefetchWithSlack(self, buffer_size, slack_period):
    dataset = dataset_ops.Dataset.range(100)
    dataset = dataset_ops.PrefetchDataset(
        dataset, buffer_size, slack_period=slack_period)
    self.assertDatasetProduces(dataset, expected_output=range(100))

  @test_util.run_v1_only("graph-mode specific test")
  def testSkipEagerPrefetchCancellation(self):
  @combinations.generate(combinations.combine(tf_api_version=1, mode="graph"))
  def testPrefetchCancellation(self):

    def map_py_fn(x):
      while x > -1:

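# Aside: a minimal sketch (not part of this commit) of the buffer-size rules
# these cases probe. buffer_size=-1 is the autotune sentinel
# (tf.data.experimental.AUTOTUNE), 0 is a legal no-op buffer, and any other
# negative value is rejected at runtime:
#
#   dataset_ops.Dataset.range(10).prefetch(buffer_size=-1)  # autotuned
#   dataset_ops.Dataset.range(10).prefetch(buffer_size=0)   # valid
#   dataset_ops.Dataset.range(10).prefetch(buffer_size=-2)  # InvalidArgumentError
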
@ -17,51 +17,60 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from absl.testing import parameterized

from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import combinations
from tensorflow.python.framework import errors
from tensorflow.python.framework import test_util
from tensorflow.python.platform import test


@test_util.run_all_in_graph_and_eager_modes
class RangeTest(test_base.DatasetTestBase):
class RangeTest(test_base.DatasetTestBase, parameterized.TestCase):

  @combinations.generate(test_base.default_test_combinations())
  def testStop(self):
    dataset = dataset_ops.Dataset.range(5)
    self.assertDatasetProduces(dataset, expected_output=range(5))

  @combinations.generate(test_base.default_test_combinations())
  def testStartStop(self):
    start, stop = 2, 5
    dataset = dataset_ops.Dataset.range(start, stop)
    self.assertDatasetProduces(dataset, expected_output=range(2, 5))

  @combinations.generate(test_base.default_test_combinations())
  def testStartStopStep(self):
    start, stop, step = 2, 10, 2
    dataset = dataset_ops.Dataset.range(start, stop, step)
    self.assertDatasetProduces(dataset, expected_output=range(2, 10, 2))

  @combinations.generate(test_base.default_test_combinations())
  def testZeroStep(self):
    start, stop, step = 2, 10, 0
    with self.assertRaises(errors.InvalidArgumentError):
      dataset = dataset_ops.Dataset.range(start, stop, step)
      self.evaluate(dataset._variant_tensor)

  @combinations.generate(test_base.default_test_combinations())
  def testNegativeStep(self):
    start, stop, step = 2, 10, -1
    dataset = dataset_ops.Dataset.range(start, stop, step)
    self.assertDatasetProduces(dataset, expected_output=range(2, 10, -1))

  @combinations.generate(test_base.default_test_combinations())
  def testStopLessThanStart(self):
    start, stop = 10, 2
    dataset = dataset_ops.Dataset.range(start, stop)
    self.assertDatasetProduces(dataset, expected_output=range(10, 2))

  @combinations.generate(test_base.default_test_combinations())
  def testStopLessThanStartWithPositiveStep(self):
    start, stop, step = 10, 2, 2
    dataset = dataset_ops.Dataset.range(start, stop, step)
    self.assertDatasetProduces(dataset, expected_output=range(10, 2, 2))

  @combinations.generate(test_base.default_test_combinations())
  def testStopLessThanStartWithNegativeStep(self):
    start, stop, step = 10, 2, -1
    dataset = dataset_ops.Dataset.range(start, stop, step)

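# Aside: a minimal sketch (not part of this commit): Dataset.range follows
# Python range() semantics, so "empty" ranges simply yield no elements
# (shown here with the eager-only as_numpy_iterator helper):
#
#   list(dataset_ops.Dataset.range(10, 2).as_numpy_iterator())     # -> []
#   list(dataset_ops.Dataset.range(2, 10, -1).as_numpy_iterator()) # -> []
#   list(dataset_ops.Dataset.range(10, 2, -1).as_numpy_iterator()) # -> [10, ..., 3]
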
@ -17,43 +17,33 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from absl.testing import parameterized
import numpy as np

from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import test_util
from tensorflow.python.framework import combinations
from tensorflow.python.platform import test


@test_util.run_all_in_graph_and_eager_modes
class RepeatTest(test_base.DatasetTestBase):
class RepeatTest(test_base.DatasetTestBase, parameterized.TestCase):

  def testRepeatTensorDataset(self):
  @combinations.generate(
      combinations.times(test_base.default_test_combinations(),
                         combinations.combine(count=[0, 3, 7])))
  def testFiniteRepeat(self, count):
    """Test a dataset that repeats its input multiple times."""
    components = (np.array(1), np.array([1, 2, 3]), np.array(37.0))
    # This placeholder can be fed when dataset-definition subgraph
    # runs (i.e. `init_op` below) to configure the number of
    # repetitions used in a particular iterator.
    dataset = dataset_ops.Dataset.from_tensors(components).repeat(count)
    self.assertEqual(
        [c.shape for c in components],
        [shape for shape in dataset_ops.get_legacy_output_shapes(dataset)])
    self.assertDatasetProduces(dataset, [components] * count)

    def do_test(count):
      dataset = dataset_ops.Dataset.from_tensors(components).repeat(count)
      self.assertEqual(
          [c.shape for c in components],
          [shape for shape in dataset_ops.get_legacy_output_shapes(dataset)])
      self.assertDatasetProduces(dataset, [components] * count)

    # Test a finite repetition.
    do_test(3)

    # test a different finite repetition.
    do_test(7)

    # Test an empty repetition.
    do_test(0)

    # Test an infinite repetition.
    # NOTE(mrry): There's not a good way to test that the sequence
    # actually is infinite.
  @combinations.generate(test_base.default_test_combinations())
  def testInfiniteRepeat(self):
    # NOTE(mrry): There's not a good way to test that the sequence is infinite.
    components = (np.array(1), np.array([1, 2, 3]), np.array(37.0))
    dataset = dataset_ops.Dataset.from_tensors(components).repeat(-1)
    self.assertEqual(
        [c.shape for c in components],
@ -64,7 +54,8 @@ class RepeatTest(test_base.DatasetTestBase):
      for component, result_component in zip(components, results):
        self.assertAllEqual(component, result_component)

  def testRepeatRepeatTensorDataset(self):
  @combinations.generate(test_base.default_test_combinations())
  def testRepeatRepeat(self):
    """Test the composition of repeat datasets."""
    components = (np.array(1), np.array([1, 2, 3]), np.array(37.0))
    inner_count, outer_count = 7, 14
@ -77,11 +68,6 @@ class RepeatTest(test_base.DatasetTestBase):
    self.assertDatasetProduces(dataset,
                               [components] * (inner_count * outer_count))

  def testRepeatEmptyDataset(self):
    """Test that repeating an empty dataset does not hang."""
    dataset = dataset_ops.Dataset.from_tensors(0).repeat(10).skip(10).repeat(-1)
    self.assertDatasetProduces(dataset, [])


if __name__ == "__main__":
  test.main()

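# Aside: a minimal sketch (not part of this commit) of repeat(count)
# semantics. repeat(0) yields nothing, repeat(n) concatenates n copies, and a
# negative count (or no argument) repeats indefinitely:
#
#   dataset_ops.Dataset.from_tensors(1).repeat(3)   # -> 1, 1, 1
#   dataset_ops.Dataset.from_tensors(1).repeat(-1)  # -> 1, 1, 1, ... (infinite)
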
@ -17,66 +17,79 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from absl.testing import parameterized

from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import combinations
from tensorflow.python.framework import errors
from tensorflow.python.framework import test_util
from tensorflow.python.platform import test


@test_util.run_v1_only("deprecated API, no eager or V2 test coverage")
class ShardTest(test_base.DatasetTestBase):
class ShardTest(test_base.DatasetTestBase, parameterized.TestCase):

  @combinations.generate(test_base.default_test_combinations())
  def testSimpleCase(self):
    dataset = dataset_ops.Dataset.range(10).shard(5, 2)
    self.assertDatasetProduces(dataset, expected_output=[2, 7])

  @combinations.generate(test_base.default_test_combinations())
  def testNestedData(self):
    dataset_a = dataset_ops.Dataset.range(10)
    dataset_b = dataset_ops.Dataset.range(10, 0, -1)
    dataset = dataset_ops.Dataset.zip((dataset_a, dataset_b)).shard(5, 2)
    self.assertDatasetProduces(dataset, expected_output=[(2, 8), (7, 3)])

  @combinations.generate(test_base.default_test_combinations())
  def testOffsetZero(self):
    dataset = dataset_ops.Dataset.range(10).shard(5, 0)
    self.assertDatasetProduces(dataset, expected_output=[0, 5])

  @combinations.generate(test_base.default_test_combinations())
  def testOffsetGreaterNumShards(self):
    with self.assertRaises(errors.InvalidArgumentError):
      dataset = dataset_ops.Dataset.range(10).shard(5, 7)
      self.evaluate(self.getNext(dataset)())

  @combinations.generate(test_base.default_test_combinations())
  def testNegativeOffset(self):
    with self.assertRaises(errors.InvalidArgumentError):
      dataset = dataset_ops.Dataset.range(10).shard(5, -3)
      self.evaluate(self.getNext(dataset)())

  @combinations.generate(test_base.default_test_combinations())
  def testNegativeNumShards(self):
    with self.assertRaises(errors.InvalidArgumentError):
      dataset = dataset_ops.Dataset.range(10).shard(-3, 1)
      self.evaluate(self.getNext(dataset)())

  @combinations.generate(test_base.default_test_combinations())
  def testZeroNumShards(self):
    with self.assertRaises(errors.InvalidArgumentError):
      dataset = dataset_ops.Dataset.range(10).shard(0, 1)
      self.evaluate(self.getNext(dataset)())

  @combinations.generate(test_base.default_test_combinations())
  def testIteratorEndsBeforeFirstElem(self):
    dataset = dataset_ops.Dataset.range(1).shard(5, 2)
    self.assertDatasetProduces(dataset, expected_output=[])

  @combinations.generate(test_base.default_test_combinations())
  def testLargerWorkerPool(self):
    dataset = dataset_ops.Dataset.range(10).shard(7, 5)
    self.assertDatasetProduces(dataset, expected_output=[5])

  @combinations.generate(test_base.default_test_combinations())
  def testIndexEqualsNumShards(self):
    dataset = dataset_ops.Dataset.range(10).shard(5, 4)
    self.assertDatasetProduces(dataset, expected_output=[4, 9])

  @combinations.generate(test_base.default_test_combinations())
  def testIndexEqualsNumShards2(self):
    dataset = dataset_ops.Dataset.range(10).shard(4, 3)
    self.assertDatasetProduces(dataset, expected_output=[3, 7])

  @combinations.generate(test_base.default_test_combinations())
  def testNumShardsLargerThanDataset(self):
    dataset = dataset_ops.Dataset.range(10).shard(20, 5)
    self.assertDatasetProduces(dataset, expected_output=[5])

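# Aside: a minimal sketch (not part of this commit): shard(num_shards, index)
# keeps exactly the elements whose position satisfies
# position % num_shards == index:
#
#   dataset_ops.Dataset.range(10).shard(5, 2)  # -> 2, 7
#   dataset_ops.Dataset.range(10).shard(20, 5) # -> 5 (only one element matches)
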
@ -40,7 +40,7 @@ from tensorflow.python.platform import test
class ShuffleTest(test_base.DatasetTestBase, parameterized.TestCase):

  @combinations.generate(test_base.default_test_combinations())
  def testShuffleDataset(self):
  def testBasic(self):
    components = (
        np.array([1, 2, 3, 4]), np.array([5, 6, 7, 8]),
        np.array([9.0, 10.0, 11.0, 12.0])
@ -160,7 +160,7 @@ class ShuffleTest(test_base.DatasetTestBase, parameterized.TestCase):

  @combinations.generate(
      combinations.times(
          combinations.combine(tf_api_version=[1, 2], mode="graph"),
          test_base.graph_only_combinations(),
          combinations.combine(reshuffle=[True, False]),
          combinations.combine(graph_seed=38, op_seed=None) +
          combinations.combine(graph_seed=None, op_seed=42) +
@ -188,7 +188,7 @@ class ShuffleTest(test_base.DatasetTestBase, parameterized.TestCase):
  # TODO(b/117581999): enable this test for eager-mode.
  @combinations.generate(
      combinations.times(
          combinations.combine(tf_api_version=[1, 2], mode="graph"),
          test_base.graph_only_combinations(),
          combinations.combine(
              reshuffle=[True, False], initializable=[True, False])))
  def testMultipleIterators(self, reshuffle, initializable):
@ -278,7 +278,7 @@ class ShuffleTest(test_base.DatasetTestBase, parameterized.TestCase):

  @combinations.generate(
      combinations.times(
          combinations.combine(tf_api_version=[1, 2], mode="eager"),
          test_base.eager_only_combinations(),
          combinations.combine(reshuffle=[True, False], seed=[None, 42])))
  def testReshuffleSeparateTransformations(self, reshuffle, seed):
    dataset = dataset_ops.Dataset.range(10)

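# Aside: a minimal sketch (not part of this commit) of the seeding rules
# probed above. With a fixed seed and reshuffle_each_iteration=False, the
# shuffle order is reproducible across epochs and across runs:
#
#   ds = dataset_ops.Dataset.range(10).shuffle(
#       buffer_size=10, seed=42, reshuffle_each_iteration=False)
#   # Iterating ds twice yields the same order both times.
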
@ -17,46 +17,30 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from absl.testing import parameterized
import numpy as np

from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import test_util
from tensorflow.python.framework import combinations
from tensorflow.python.platform import test


@test_util.run_all_in_graph_and_eager_modes
class SkipTest(test_base.DatasetTestBase):
class SkipTest(test_base.DatasetTestBase, parameterized.TestCase):

  def testSkipTensorDataset(self):
  @combinations.generate(
      combinations.times(test_base.default_test_combinations(),
                         combinations.combine(count=[-1, 0, 4, 10, 25])))
  def testBasic(self, count):
    components = (np.arange(10),)

    def do_test(count):
      dataset = dataset_ops.Dataset.from_tensor_slices(components).skip(count)
      self.assertEqual(
          [c.shape[1:] for c in components],
          [shape for shape in dataset_ops.get_legacy_output_shapes(dataset)])
      start_range = min(count, 10) if count != -1 else 10
      self.assertDatasetProduces(
          dataset,
          [tuple(components[0][i:i + 1]) for i in range(start_range, 10)])

    # Skip fewer than input size, we should skip
    # the first 4 elements and then read the rest.
    do_test(4)

    # Skip more than input size: get nothing.
    do_test(25)

    # Skip exactly input size.
    do_test(10)

    # Set -1 for 'count': skip the entire dataset.
    do_test(-1)

    # Skip nothing
    do_test(0)

    dataset = dataset_ops.Dataset.from_tensor_slices(components).skip(count)
    self.assertEqual(
        [c.shape[1:] for c in components],
        [shape for shape in dataset_ops.get_legacy_output_shapes(dataset)])
    start_range = min(count, 10) if count != -1 else 10
    self.assertDatasetProduces(
        dataset,
        [tuple(components[0][i:i + 1]) for i in range(start_range, 10)])


if __name__ == "__main__":

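# Aside: a minimal sketch (not part of this commit) of skip(count) semantics.
# skip(-1) skips everything; skipping more than the dataset size yields an
# empty dataset:
#
#   dataset_ops.Dataset.range(10).skip(4)   # -> 4, 5, ..., 9
#   dataset_ops.Dataset.range(10).skip(25)  # -> empty
#   dataset_ops.Dataset.range(10).skip(-1)  # -> empty
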
@ -17,40 +17,30 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from absl.testing import parameterized
import numpy as np

from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import test_util
from tensorflow.python.framework import combinations
from tensorflow.python.platform import test


@test_util.run_all_in_graph_and_eager_modes
class TakeTest(test_base.DatasetTestBase):
class TakeTest(test_base.DatasetTestBase, parameterized.TestCase):

  def testTakeTensorDataset(self):
  @combinations.generate(
      combinations.times(test_base.default_test_combinations(),
                         combinations.combine(count=[-1, 0, 4, 10, 25])))
  def testBasic(self, count):
    components = (np.arange(10),)
    dataset = dataset_ops.Dataset.from_tensor_slices(components).take(count)
    self.assertEqual(
        [c.shape[1:] for c in components],
        [shape for shape in dataset_ops.get_legacy_output_shapes(dataset)])
    num_output = min(count, 10) if count != -1 else 10
    self.assertDatasetProduces(
        dataset, [tuple(components[0][i:i + 1]) for i in range(num_output)])

    def do_test(count):
      dataset = dataset_ops.Dataset.from_tensor_slices(components).take(count)
      self.assertEqual(
          [c.shape[1:] for c in components],
          [shape for shape in dataset_ops.get_legacy_output_shapes(dataset)])
      num_output = min(count, 10) if count != -1 else 10
      self.assertDatasetProduces(
          dataset, [tuple(components[0][i:i + 1]) for i in range(num_output)])

    # Take fewer than input size
    do_test(4)

    # Take more than input size
    do_test(25)

    # Take all of input
    do_test(-1)

    # Take nothing
    do_test(0)

if __name__ == "__main__":
  test.main()

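# Aside: a minimal sketch (not part of this commit) of take(count) semantics.
# take(-1) takes everything; taking more than the dataset size simply yields
# the whole dataset:
#
#   dataset_ops.Dataset.range(10).take(4)   # -> 0, 1, 2, 3
#   dataset_ops.Dataset.range(10).take(25)  # -> 0, 1, ..., 9
#   dataset_ops.Dataset.range(10).take(-1)  # -> 0, 1, ..., 9
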
@ -58,7 +58,11 @@ class DatasetTestBase(test.TestCase):

  def assertValuesEqual(self, expected, actual):
    """Asserts that two values are equal."""
    if sparse_tensor.is_sparse(expected):
    if isinstance(expected, dict):
      self.assertItemsEqual(list(expected.keys()), list(actual.keys()))
      for k in expected.keys():
        self.assertValuesEqual(expected[k], actual[k])
    elif sparse_tensor.is_sparse(expected):
      self.assertAllEqual(expected.indices, actual.indices)
      self.assertAllEqual(expected.values, actual.values)
      self.assertAllEqual(expected.dense_shape, actual.dense_shape)

@ -21,11 +21,12 @@ import gzip
import os
import zlib

from absl.testing import parameterized

from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.ops import readers
from tensorflow.python.eager import context
from tensorflow.python.framework import test_util
from tensorflow.python.framework import combinations
from tensorflow.python.platform import test
from tensorflow.python.util import compat

@ -37,8 +38,7 @@ except ImportError:
  psutil_import_succeeded = False


@test_util.run_all_in_graph_and_eager_modes
class TextLineDatasetTest(test_base.DatasetTestBase):
class TextLineDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):

  def _lineText(self, f, l):
    return compat.as_bytes("%d: %d" % (f, l))
@ -76,7 +76,11 @@ class TextLineDatasetTest(test_base.DatasetTestBase):

    return filenames

  def _testTextLineDataset(self, compression_type=None):
  @combinations.generate(
      combinations.times(
          test_base.default_test_combinations(),
          combinations.combine(compression_type=[None, "GZIP", "ZLIB"])))
  def testTextLineDataset(self, compression_type):
    test_filenames = self._createFiles(
        2, 5, crlf=True, compression_type=compression_type)

@ -115,6 +119,7 @@ class TextLineDatasetTest(test_base.DatasetTestBase):
        expected_output=[[self._lineText(0, i) for i in range(5)],
                         [self._lineText(1, i) for i in range(5)]] * 10)

  @combinations.generate(test_base.default_test_combinations())
  def testTextLineDatasetParallelRead(self):
    test_filenames = self._createFiles(10, 10)
    files = dataset_ops.Dataset.from_tensor_slices(test_filenames).repeat(10)
@ -125,15 +130,7 @@ class TextLineDatasetTest(test_base.DatasetTestBase):
    self.assertDatasetProduces(
        dataset, expected_output=expected_output * 10, assert_items_equal=True)

  def testTextLineDatasetNoCompression(self):
    self._testTextLineDataset()

  def testTextLineDatasetGzipCompression(self):
    self._testTextLineDataset(compression_type="GZIP")

  def testTextLineDatasetZlibCompression(self):
    self._testTextLineDataset(compression_type="ZLIB")

  @combinations.generate(test_base.default_test_combinations())
  def testTextLineDatasetBuffering(self):
    test_filenames = self._createFiles(2, 5, crlf=True)

@ -143,33 +140,33 @@ class TextLineDatasetTest(test_base.DatasetTestBase):
      expected_output.extend([self._lineText(j, i) for i in range(5)])
    self.assertDatasetProduces(repeat_dataset, expected_output=expected_output)

  @combinations.generate(test_base.eager_only_combinations())
  def testIteratorResourceCleanup(self):
    filename = os.path.join(self.get_temp_dir(), "text.txt")
    with open(filename, "wt") as f:
      for i in range(3):
        f.write("%d\n" % (i,))
    with context.eager_mode():
      first_iterator = iter(readers.TextLineDataset(filename))
      self.assertEqual(b"0", next(first_iterator).numpy())
      second_iterator = iter(readers.TextLineDataset(filename))
      self.assertEqual(b"0", next(second_iterator).numpy())
      # Eager kernel caching is based on op attributes, which includes the
      # Dataset's output shape. Create a different kernel to test that they
      # don't create resources with the same names.
      different_kernel_iterator = iter(
          readers.TextLineDataset(filename).repeat().batch(16))
      self.assertEqual([16], next(different_kernel_iterator).shape)
      # Remove our references to the Python Iterator objects, which (assuming no
      # reference cycles) is enough to trigger DestroyResourceOp and close the
      # partially-read files.
      del first_iterator
      del second_iterator
      del different_kernel_iterator
      if not psutil_import_succeeded:
        self.skipTest(
            "psutil is required to check that we've closed our files.")
      open_files = psutil.Process().open_files()
      self.assertNotIn(filename, [open_file.path for open_file in open_files])
    first_iterator = iter(readers.TextLineDataset(filename))
    self.assertEqual(b"0", next(first_iterator).numpy())
    second_iterator = iter(readers.TextLineDataset(filename))
    self.assertEqual(b"0", next(second_iterator).numpy())
    # Eager kernel caching is based on op attributes, which includes the
    # Dataset's output shape. Create a different kernel to test that they
    # don't create resources with the same names.
    different_kernel_iterator = iter(
        readers.TextLineDataset(filename).repeat().batch(16))
    self.assertEqual([16], next(different_kernel_iterator).shape)
    # Remove our references to the Python Iterator objects, which (assuming no
    # reference cycles) is enough to trigger DestroyResourceOp and close the
    # partially-read files.
    del first_iterator
    del second_iterator
    del different_kernel_iterator
    if not psutil_import_succeeded:
      self.skipTest(
          "psutil is required to check that we've closed our files.")
    open_files = psutil.Process().open_files()
    self.assertNotIn(filename, [open_file.path for open_file in open_files])


if __name__ == "__main__":

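# Aside: a minimal sketch (not part of this commit): TextLineDataset yields
# one string tensor per line and can read compressed files directly:
#
#   dataset = readers.TextLineDataset(
#       ["file.txt.gz"], compression_type="GZIP")
#   # Public-API equivalent: tf.data.TextLineDataset(..., compression_type="GZIP")
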
@ -21,31 +21,31 @@ import gzip
import os
import zlib

from absl.testing import parameterized

from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.ops import readers
from tensorflow.python.framework import combinations
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import test_util
from tensorflow.python.lib.io import python_io
from tensorflow.python.platform import test
from tensorflow.python.util import compat


@test_util.run_all_in_graph_and_eager_modes
class TFRecordDatasetTest(test_base.DatasetTestBase):
class TFRecordDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):

  def setUp(self):
    super(TFRecordDatasetTest, self).setUp()
    self._num_files = 2
    self._num_records = 7

    self.test_filenames = self._createFiles()

  def dataset_fn(self,
                 filenames,
                 compression_type="",
                 num_epochs=1,
                 batch_size=None):
  def _dataset_factory(self,
                       filenames,
                       compression_type="",
                       num_epochs=1,
                       batch_size=None):

    repeat_dataset = readers.TFRecordDataset(
        filenames, compression_type).repeat(num_epochs)
@ -67,6 +67,7 @@ class TFRecordDatasetTest(test_base.DatasetTestBase):
      writer.close()
    return filenames

  @combinations.generate(test_base.default_test_combinations())
  def testTFRecordDatasetConstructorErrorsTensorInput(self):
    with self.assertRaisesRegex(TypeError,
                                "filenames.*must be.*Tensor.*string"):
@ -78,37 +79,40 @@ class TFRecordDatasetTest(test_base.DatasetTestBase):
    with self.assertRaises(Exception):
      readers.TFRecordDataset(object())

  @combinations.generate(test_base.default_test_combinations())
  def testReadOneEpoch(self):
    # Basic test: read from file 0.
    dataset = self.dataset_fn(self.test_filenames[0])
    dataset = self._dataset_factory(self.test_filenames[0])
    self.assertDatasetProduces(
        dataset,
        expected_output=[self._record(0, i) for i in range(self._num_records)])

    # Basic test: read from file 1.
    dataset = self.dataset_fn(self.test_filenames[1])
    dataset = self._dataset_factory(self.test_filenames[1])
    self.assertDatasetProduces(
        dataset,
        expected_output=[self._record(1, i) for i in range(self._num_records)])

    # Basic test: read from both files.
    dataset = self.dataset_fn(self.test_filenames)
    dataset = self._dataset_factory(self.test_filenames)
    expected_output = []
    for j in range(self._num_files):
      expected_output.extend(
          [self._record(j, i) for i in range(self._num_records)])
    self.assertDatasetProduces(dataset, expected_output=expected_output)

  @combinations.generate(test_base.default_test_combinations())
  def testReadTenEpochs(self):
    dataset = self.dataset_fn(self.test_filenames, num_epochs=10)
    dataset = self._dataset_factory(self.test_filenames, num_epochs=10)
    expected_output = []
    for j in range(self._num_files):
      expected_output.extend(
          [self._record(j, i) for i in range(self._num_records)])
    self.assertDatasetProduces(dataset, expected_output=expected_output * 10)

  @combinations.generate(test_base.default_test_combinations())
  def testReadTenEpochsOfBatches(self):
    dataset = self.dataset_fn(
    dataset = self._dataset_factory(
        self.test_filenames, num_epochs=10, batch_size=self._num_records)
    expected_output = []
    for j in range(self._num_files):
@ -116,6 +120,7 @@ class TFRecordDatasetTest(test_base.DatasetTestBase):
          [self._record(j, i) for i in range(self._num_records)])
    self.assertDatasetProduces(dataset, expected_output=expected_output * 10)

  @combinations.generate(test_base.default_test_combinations())
  def testReadZlibFiles(self):
    zlib_files = []
    for i, fn in enumerate(self.test_filenames):
@ -130,9 +135,10 @@ class TFRecordDatasetTest(test_base.DatasetTestBase):
    for j in range(self._num_files):
      expected_output.extend(
          [self._record(j, i) for i in range(self._num_records)])
    dataset = self.dataset_fn(zlib_files, compression_type="ZLIB")
    dataset = self._dataset_factory(zlib_files, compression_type="ZLIB")
    self.assertDatasetProduces(dataset, expected_output=expected_output)

  @combinations.generate(test_base.default_test_combinations())
  def testReadGzipFiles(self):
    gzip_files = []
    for i, fn in enumerate(self.test_filenames):
@ -145,9 +151,10 @@ class TFRecordDatasetTest(test_base.DatasetTestBase):
    for j in range(self._num_files):
      expected_output.extend(
          [self._record(j, i) for i in range(self._num_records)])
    dataset = self.dataset_fn(gzip_files, compression_type="GZIP")
    dataset = self._dataset_factory(gzip_files, compression_type="GZIP")
    self.assertDatasetProduces(dataset, expected_output=expected_output)

  @combinations.generate(test_base.default_test_combinations())
  def testReadWithBuffer(self):
    one_mebibyte = 2**20
    dataset = readers.TFRecordDataset(
@ -158,6 +165,7 @@ class TFRecordDatasetTest(test_base.DatasetTestBase):
          [self._record(j, i) for i in range(self._num_records)])
    self.assertDatasetProduces(dataset, expected_output=expected_output)

  @combinations.generate(test_base.default_test_combinations())
  def testReadFromDatasetOfFiles(self):
    files = dataset_ops.Dataset.from_tensor_slices(self.test_filenames)
    expected_output = []
@ -167,6 +175,7 @@ class TFRecordDatasetTest(test_base.DatasetTestBase):
    dataset = readers.TFRecordDataset(files)
    self.assertDatasetProduces(dataset, expected_output=expected_output)

  @combinations.generate(test_base.default_test_combinations())
  def testReadTenEpochsFromDatasetOfFilesInParallel(self):
    files = dataset_ops.Dataset.from_tensor_slices(
        self.test_filenames).repeat(10)

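# Aside: a minimal sketch (not part of this commit): TFRecordDataset accepts
# a filename, a list of filenames, or a dataset of filenames, plus an
# optional compression type and read buffer size:
#
#   dataset = readers.TFRecordDataset(
#       ["records.tfrecord.gz"], compression_type="GZIP", buffer_size=2**20)
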
@ -23,11 +23,11 @@ import numpy as np

from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import combinations
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import string_ops
@ -36,13 +36,14 @@ from tensorflow.python.platform import test


@test_util.run_all_in_graph_and_eager_modes
class UnbatchTest(test_base.DatasetTestBase, parameterized.TestCase):

  @combinations.generate(test_base.default_test_combinations())
  def testUnbatchWithUnknownRankInput(self):
    dataset = dataset_ops.Dataset.from_tensors([0, 1, 2, 3]).unbatch()
    self.assertDatasetProduces(dataset, range(4))

  @combinations.generate(test_base.default_test_combinations())
  def testUnbatchScalarDataset(self):
    data = tuple([math_ops.range(10) for _ in range(3)])
    data = dataset_ops.Dataset.from_tensor_slices(data)
@ -54,12 +55,14 @@ class UnbatchTest(test_base.DatasetTestBase, parameterized.TestCase):

    self.assertDatasetProduces(data, [(i,) * 3 for i in range(10)])

  @combinations.generate(test_base.default_test_combinations())
  def testUnbatchNestedDataset(self):
    data = dataset_ops.Dataset.from_tensors(
        [dataset_ops.Dataset.range(10) for _ in range(10)])
    data = data.unbatch().flat_map(lambda x: x)
    self.assertDatasetProduces(data, list(range(10)) * 10)

  @combinations.generate(test_base.default_test_combinations())
  def testUnbatchDatasetWithStrings(self):
    data = tuple([math_ops.range(10) for _ in range(3)])
    data = dataset_ops.Dataset.from_tensor_slices(data)
@ -73,6 +76,7 @@ class UnbatchTest(test_base.DatasetTestBase, parameterized.TestCase):
    self.assertDatasetProduces(
        data, [(i, compat.as_bytes(str(i)), i) for i in range(10)])

  @combinations.generate(test_base.default_test_combinations())
  def testUnbatchDatasetWithSparseTensor(self):
    st = sparse_tensor.SparseTensorValue(
        indices=[[i, i] for i in range(10)],
@ -87,6 +91,7 @@ class UnbatchTest(test_base.DatasetTestBase, parameterized.TestCase):
    ]
    self.assertDatasetProduces(data, expected_output=expected_output)

  @combinations.generate(test_base.default_test_combinations())
  def testUnbatchDatasetWithDenseSparseAndRaggedTensor(self):
    st = sparse_tensor.SparseTensorValue(
        indices=[[i, i] for i in range(10)],
@ -104,6 +109,7 @@ class UnbatchTest(test_base.DatasetTestBase, parameterized.TestCase):
    self.assertDatasetProduces(
        data, expected_output=expected_output)

  @combinations.generate(test_base.default_test_combinations())
  def testUnbatchDatasetWithRaggedTensor(self):
    rt = ragged_factory_ops.constant_value([[[0]], [[1]], [[2]], [[3]], [[4]],
                                            [[5]], [[6]], [[7]], [[8]], [[9]]])
@ -119,6 +125,7 @@ class UnbatchTest(test_base.DatasetTestBase, parameterized.TestCase):
    self.assertDatasetProduces(
        data, expected_output=expected_output)

  @combinations.generate(test_base.default_test_combinations())
  def testUnbatchSingleElementTupleDataset(self):
    data = tuple([(math_ops.range(10),) for _ in range(3)])
    data = dataset_ops.Dataset.from_tensor_slices(data)
@ -130,6 +137,7 @@ class UnbatchTest(test_base.DatasetTestBase, parameterized.TestCase):

    self.assertDatasetProduces(data, [((i,),) * 3 for i in range(10)])

  @combinations.generate(test_base.default_test_combinations())
  def testUnbatchMultiElementTupleDataset(self):
    data = tuple([(math_ops.range(10 * i, 10 * i + 10),
                   array_ops.fill([10], "hi")) for i in range(3)])
@ -146,6 +154,7 @@ class UnbatchTest(test_base.DatasetTestBase, parameterized.TestCase):
        data,
        [((i, b"hi"), (10 + i, b"hi"), (20 + i, b"hi")) for i in range(10)])

  @combinations.generate(test_base.default_test_combinations())
  def testUnbatchEmpty(self):
    data = dataset_ops.Dataset.from_tensors(
        (constant_op.constant([]), constant_op.constant([], shape=[0, 4]),
@ -153,15 +162,15 @@ class UnbatchTest(test_base.DatasetTestBase, parameterized.TestCase):
    data = data.unbatch()
    self.assertDatasetProduces(data, [])

  @combinations.generate(test_base.default_test_combinations())
  def testUnbatchStaticShapeMismatch(self):
    data = dataset_ops.Dataset.from_tensors((np.arange(7), np.arange(8),
                                             np.arange(9)))
    with self.assertRaises(ValueError):
      data.unbatch()

  # Note: dynamic shape mismatch is graph specific test.
  @test_util.run_deprecated_v1
  def testSkipEagerUnbatchDynamicShapeMismatch(self):
  @combinations.generate(test_base.graph_only_combinations())
  def testUnbatchDynamicShapeMismatch(self):
    ph1 = array_ops.placeholder(dtypes.int32, shape=[None])
    ph2 = array_ops.placeholder(dtypes.int32, shape=None)
    data = dataset_ops.Dataset.from_tensors((ph1, ph2))
@ -190,6 +199,7 @@ class UnbatchTest(test_base.DatasetTestBase, parameterized.TestCase):
    with self.assertRaises(errors.InvalidArgumentError):
      self.evaluate(next_element)

  @combinations.generate(test_base.default_test_combinations())
  def testUnbatchDatasetWithUintDtypes(self):
    components = (
        np.tile(np.array([[0], [1], [2], [3]], dtype=np.uint8), 2),

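# Aside: a minimal sketch (not part of this commit): unbatch() splits each
# element along its first dimension, the inverse of batch() for a single
# batch:
#
#   dataset_ops.Dataset.from_tensors([0, 1, 2, 3]).unbatch()  # -> 0, 1, 2, 3
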
@ -24,43 +24,32 @@ from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.util import nest
from tensorflow.python.eager import context
from tensorflow.python.framework import combinations
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.platform import test


@test_util.run_all_in_graph_and_eager_modes
class WindowTest(test_base.DatasetTestBase, parameterized.TestCase):

  @parameterized.named_parameters(
      ("1", 20, 14, 7, 1),
      ("2", 20, 17, 9, 1),
      ("3", 20, 14, 14, 1),
      ("4", 20, 10, 14, 1),
      ("5", 20, 14, 19, 1),
      ("6", 20, 4, 1, 2),
      ("7", 20, 2, 1, 6),
      ("8", 20, 4, 7, 2),
      ("9", 20, 2, 7, 6),
      ("10", 1, 10, 4, 1),
      ("11", 0, 10, 4, 1),
      ("12", 20, 14, 7, 1, False),
      ("13", 20, 17, 9, 1, False),
      ("14", 20, 14, 14, 1, False),
      ("15", 20, 10, 14, 1, False),
      ("16", 20, 14, 19, 1, False),
      ("17", 20, 4, 1, 2, False),
      ("18", 20, 2, 1, 6, False),
      ("19", 20, 4, 7, 2, False),
      ("20", 20, 2, 7, 6, False),
      ("21", 1, 10, 4, 1, False),
      ("22", 0, 10, 4, 1, False),
  )
  def testWindowDataset(self, count, size, shift, stride, drop_remainder=True):
  @combinations.generate(
      combinations.times(
          test_base.default_test_combinations(),
          combinations.combine(
              count=20,
              size=[10, 14, 17],
              shift=[7, 14],
              stride=[1, 2, 6],
              drop_remainder=[True, False]) + combinations.combine(
                  count=[0, 1],
                  size=10,
                  shift=4,
                  stride=1,
                  drop_remainder=[True, False])))
  def testWindowDataset(self, count, size, shift, stride, drop_remainder):
    """Tests a dataset that slides a window over its input elements."""
    components = (np.arange(7),
                  np.array([[1, 2, 3]]) * np.arange(7)[:, np.newaxis],
|
@ -111,11 +100,12 @@ class WindowTest(test_base.DatasetTestBase, parameterized.TestCase):
    with self.assertRaises(errors.OutOfRangeError):
      self.evaluate(get_next())

  @parameterized.named_parameters(
      ("1", 14, 0, 3, 1),
      ("2", 14, 3, 0, 1),
      ("3", 14, 3, 3, 0),
  )
  @combinations.generate(
      combinations.times(
          test_base.default_test_combinations(),
          combinations.combine(count=20, size=0, shift=3, stride=1) +
          combinations.combine(count=20, size=3, shift=0, stride=1) +
          combinations.combine(count=20, size=3, shift=3, stride=0)))
  def testWindowDatasetInvalid(self, count, size, shift, stride):
    with self.assertRaises(errors.InvalidArgumentError):
      ds = dataset_ops.Dataset.range(10).map(lambda x: x).repeat(count).window(
|
@ -123,12 +113,14 @@ class WindowTest(test_base.DatasetTestBase, parameterized.TestCase):
          stride=stride).flat_map(lambda x: x.batch(batch_size=size))
      self.evaluate(ds._variant_tensor)

  @combinations.generate(test_base.default_test_combinations())
  def testWindowDifferentNestedStructures(self):
    ds = dataset_ops.Dataset.from_tensor_slices(([1, 2], [3, 4])).window(2)
    self.getNext(ds)
    ds = dataset_ops.Dataset.from_tensor_slices({"a": [1, 2]}).window(2)
    self.getNext(ds)

  @combinations.generate(test_base.default_test_combinations())
  def testWindowSparse(self):

    def _sparse(i):
|
@ -148,6 +140,7 @@ class WindowTest(test_base.DatasetTestBase, parameterized.TestCase):
    ]
    self.assertDatasetProduces(dataset, expected_output=expected_output)

  @combinations.generate(test_base.default_test_combinations())
  def testWindowSparseWithDifferentDenseShapes(self):

    def _sparse(i):
|
@ -177,6 +170,7 @@ class WindowTest(test_base.DatasetTestBase, parameterized.TestCase):
            dense_shape=[5, i * 3 + 5 - 1]))
    self.assertDatasetProduces(dataset, expected_output=expected_output)

  @combinations.generate(test_base.default_test_combinations())
  def testNestedWindowSparse(self):

    def _sparse(i):
|
@ -205,6 +199,7 @@ class WindowTest(test_base.DatasetTestBase, parameterized.TestCase):
    ]
    self.assertDatasetProduces(dataset, expected_output=expected_output)

  @combinations.generate(test_base.default_test_combinations())
  def testWindowShapeError(self):

    def generator():
|
@ -222,6 +217,7 @@ class WindowTest(test_base.DatasetTestBase, parameterized.TestCase):
          r"Cannot batch tensors with different shapes in component 0. "
          r"First element had shape \[3\] and element 2 had shape \[4\]."))

  @combinations.generate(test_base.default_test_combinations())
  def testWindowIgnoreErrors(self):
    input_values = np.float32([1., np.nan, 2., np.nan, 3.])
    dataset = dataset_ops.Dataset.from_tensor_slices(input_values).map(
|
@ -232,6 +228,7 @@ class WindowTest(test_base.DatasetTestBase, parameterized.TestCase):
        dataset, expected_output=[np.float32([1., 2.]),
                                  np.float32([2., 3.])])

  @combinations.generate(test_base.default_test_combinations())
  def testNestedOutput(self):
    if not context.executing_eagerly():
      self.skipTest("self.evaluate() does not work with a dataset")
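For context on what these window tests exercise: `Dataset.window` yields datasets of datasets, which is why the tests `flat_map` each window through `batch` before comparing values. A minimal sketch in plain TF 2.x eager code; the concrete numbers are illustrative and not taken from this diff:

import tensorflow as tf

ds = tf.data.Dataset.range(10)
# size=4, shift=2: a window starts every 2 elements; drop_remainder=True
# discards the final short window.
windows = ds.window(size=4, shift=2, stride=1, drop_remainder=True)
# Each element of `windows` is itself a dataset; batching materializes it
# as a single tensor per window.
batched = windows.flat_map(lambda w: w.batch(4))
for elem in batched:
  print(elem.numpy())  # [0 1 2 3], [2 3 4 5], [4 5 6 7], [6 7 8 9]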
|
@ -17,66 +17,68 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from absl.testing import parameterized
import numpy as np

from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import combinations
from tensorflow.python.framework import errors
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import test_util
from tensorflow.python.platform import test


@test_util.run_all_in_graph_and_eager_modes
class ZipTest(test_base.DatasetTestBase):
def _dataset_factory(components):
  datasets = tuple([
      dataset_ops.Dataset.from_tensor_slices(component)
      for component in components
  ])
  return dataset_ops.Dataset.zip(datasets)

  def testZipDataset(self):

    def dataset_fn(components):
      datasets = tuple([
          dataset_ops.Dataset.from_tensor_slices(component)
          for component in components
      ])
      return dataset_ops.Dataset.zip(datasets)
class ZipTest(test_base.DatasetTestBase, parameterized.TestCase):

    equal_length_components = [
  @combinations.generate(test_base.default_test_combinations())
  def testZipEqual(self):
    components = [
        np.tile(np.array([[1], [2], [3], [4]]), 20),
        np.tile(np.array([[12], [13], [14], [15]]), 22),
        np.array([37.0, 38.0, 39.0, 40.0])
    ]

    get_next = self.getNext(dataset_fn(equal_length_components))
    get_next = self.getNext(_dataset_factory(components))
    for i in range(4):
      results = self.evaluate(get_next())
      for component, result_component in zip(equal_length_components, results):
      for component, result_component in zip(components, results):
        self.assertAllEqual(component[i], result_component)
    with self.assertRaises(errors.OutOfRangeError):
      self.evaluate(get_next())
    with self.assertRaises(errors.OutOfRangeError):
      self.evaluate(get_next())

    variable_length_components = [[1, 2, 3, 4], [1, 2, 3, 4, 5], [1.0, 2.0]]
    get_next = self.getNext(dataset_fn(variable_length_components))
  @combinations.generate(test_base.default_test_combinations())
  def testZipUnequal(self):
    components = [[1, 2, 3, 4], [1, 2, 3, 4, 5], [1.0, 2.0]]
    get_next = self.getNext(_dataset_factory(components))
    for i in range(2):
      results = self.evaluate(get_next())
      for component, result_component in zip(variable_length_components,
                                             results):
      for component, result_component in zip(components, results):
        self.assertAllEqual(component[i], result_component)
    with self.assertRaises(errors.OutOfRangeError):
      self.evaluate(get_next())
    with self.assertRaises(errors.OutOfRangeError):
      self.evaluate(get_next())

  def testNestedZipDataset(self):
  @combinations.generate(test_base.default_test_combinations())
  def testNested(self):

    equal_length_components = [
    components = [
        np.tile(np.array([[1], [2], [3], [4]]), 20),
        np.tile(np.array([[12], [13], [14], [15]]), 22),
        np.array([37.0, 38.0, 39.0, 40.0])
    ]
    datasets = [
        dataset_ops.Dataset.from_tensor_slices(component)
        for component in equal_length_components
        for component in components
    ]
    dataset = dataset_ops.Dataset.zip((datasets[0], (datasets[1], datasets[2])))

|
@ -88,9 +90,9 @@ class ZipTest(test_base.DatasetTestBase):
    get_next = self.getNext(dataset)
    for i in range(4):
      result1, (result2, result3) = self.evaluate(get_next())
      self.assertAllEqual(equal_length_components[0][i], result1)
      self.assertAllEqual(equal_length_components[1][i], result2)
      self.assertAllEqual(equal_length_components[2][i], result3)
      self.assertAllEqual(components[0][i], result1)
      self.assertAllEqual(components[1][i], result2)
      self.assertAllEqual(components[2][i], result3)
    with self.assertRaises(errors.OutOfRangeError):
      self.evaluate(get_next())
    with self.assertRaises(errors.OutOfRangeError):
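The `testZipUnequal` case above depends on `Dataset.zip` stopping as soon as its shortest input is exhausted, mirroring Python's built-in `zip`. A standalone sketch in plain TF 2.x eager code, not part of this change:

import tensorflow as tf

a = tf.data.Dataset.from_tensor_slices([1, 2, 3, 4])
b = tf.data.Dataset.from_tensor_slices([10.0, 20.0])
# Iteration ends after two elements, the length of the shorter dataset.
for x, y in tf.data.Dataset.zip((a, b)):
  print(x.numpy(), y.numpy())  # prints "1 10.0", then "2 20.0"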
|
@ -66,7 +66,6 @@ from tensorflow.python.ops import gen_io_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import script_ops
from tensorflow.python.ops import string_ops
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.training.tracking import base as tracking_base
from tensorflow.python.training.tracking import tracking
from tensorflow.python.util import deprecation
|
@ -2437,26 +2436,25 @@ class DatasetV1Adapter(DatasetV1):

def _ensure_same_dataset_graph(dataset):
  """Walks the dataset graph to ensure all datasets come from the same graph."""
  # pylint: disable=protected-access
  current_graph = ops.get_default_graph()
  bfs_q = Queue.Queue()
  bfs_q.put(dataset)  # pylint: disable=protected-access
  bfs_q.put(dataset)
  visited = []
  while not bfs_q.empty():
    ds = bfs_q.get()
    visited.append(ds)
    ds_graph = ds._graph  # pylint: disable=protected-access
    ds_graph = ds._graph
    if current_graph != ds_graph:
      logging.warning("The graph (" + str(current_graph) + ") of the iterator "
                      "is different from the graph (" + str(ds_graph) + ") "
                      "the dataset: " + str(ds._variant_tensor) + " was "  # pylint: disable=protected-access
                      "created in. If you are using the Estimator API, "
                      "make sure that no part of the dataset returned by the "
                      "`input_fn` function is defined outside the `input_fn` "
                      "function. Please ensure that all datasets in the "
                      "pipeline are created in the same graph as the iterator. "
                      "NOTE: This warning will become an error in future "
                      "versions of TensorFlow.")
    for input_ds in ds._inputs():  # pylint: disable=protected-access
      raise ValueError(
          "The graph (" + str(current_graph) + ") of the iterator is different "
          "from the graph (" + str(ds_graph) + ") the dataset: " +
          str(ds._variant_tensor) + " was created in. If you are using the "
          "Estimator API, make sure that no part of the dataset returned by "
          "the `input_fn` function is defined outside the `input_fn` function. "
          "Please ensure that all datasets in the pipeline are created in the "
          "same graph as the iterator.")
    for input_ds in ds._inputs():
      if input_ds not in visited:
        bfs_q.put(input_ds)

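The change above turns a cross-graph dataset/iterator mismatch from a logged warning into a hard error. A hypothetical repro sketch (TF 1.x-style graph mode; variable names are illustrative): requesting an iterator from a different graph than the one the dataset was built in should now raise `ValueError`:

import tensorflow.compat.v1 as tf

g1 = tf.Graph()
with g1.as_default():
  ds = tf.data.Dataset.range(5)  # the dataset's graph is g1

g2 = tf.Graph()
with g2.as_default():
  # The iterator would live in g2, which differs from the dataset's graph,
  # so the BFS check above raises ValueError (previously only a warning).
  tf.data.make_one_shot_iterator(ds)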