[tf.data] Migrating remaining core API tests to use TF combinations and performing various minor test cleanup.

PiperOrigin-RevId: 283378763
Change-Id: Ice08340d289406eb691fb261c20329ada7c23c8a
This commit is contained in:
Jiri Simsa 2019-12-02 11:20:05 -08:00 committed by TensorFlower Gardener
parent 2490c87654
commit 2a2c812ab2
33 changed files with 1018 additions and 962 deletions

View File

@ -34,10 +34,12 @@ class MatchingFilesDatasetTest(test_base.DatasetTestBase,
parameterized.TestCase):
def setUp(self):
super(MatchingFilesDatasetTest, self).setUp()
self.tmp_dir = tempfile.mkdtemp()
def tearDown(self):
shutil.rmtree(self.tmp_dir, ignore_errors=True)
super(MatchingFilesDatasetTest, self).tearDown()
def _touchTempFiles(self, filenames):
for filename in filenames:

View File

@ -57,6 +57,7 @@ class DatasetSerializationTestBase(test.TestCase):
def tearDown(self):
self._delete_ckpt()
super(DatasetSerializationTestBase, self).tearDown()
# TODO(b/72657739): Remove sparse_tensor argument, which is to test the
# (deprecated) saveable `SparseTensorSliceDataset`, once the API

View File

@ -287,6 +287,7 @@ tf_py_test(
size = "small",
srcs = ["iterator_cluster_test.py"],
additional_deps = [
":test_base",
"//tensorflow/core:protos_all_py",
"//tensorflow/python/data/ops:dataset_ops",
"//tensorflow/python/data/ops:iterator_ops",
@ -400,6 +401,7 @@ tf_py_test(
"//tensorflow/python:variable_scope",
"//tensorflow/python/ops/ragged",
],
shard_count = 4,
)
cuda_py_test(

View File

@ -61,30 +61,30 @@ class AsNumpyIteratorTest(test_base.DatasetTestBase, parameterized.TestCase):
with self.assertRaises(RuntimeError):
ds.as_numpy_iterator()
def checkInvalidElement(self, element):
def _testInvalidElement(self, element):
ds = dataset_ops.Dataset.from_tensors(element)
with self.assertRaisesRegex(TypeError,
'.*does not support datasets containing.*'):
ds.as_numpy_iterator()
@combinations.generate(test_base.eager_only_combinations())
def testInvalidElements(self):
self.checkInvalidElement(sparse_tensor.SparseTensorValue([[0]], [0], [1]))
def testSparseElement(self):
self._testInvalidElement(sparse_tensor.SparseTensorValue([[0]], [0], [1]))
@combinations.generate(test_base.eager_only_combinations())
def testRaggedElement(self):
self.checkInvalidElement(
self._testInvalidElement(
ragged_tensor_value.RaggedTensorValue(
np.array([0, 1, 2]), np.array([0, 1, 3], dtype=np.int64)))
@combinations.generate(test_base.eager_only_combinations())
def testDatasetElement(self):
self.checkInvalidElement(dataset_ops.Dataset.range(3))
self._testInvalidElement(dataset_ops.Dataset.range(3))
@combinations.generate(test_base.eager_only_combinations())
def testNestedNonTensorElement(self):
tuple_elem = (constant_op.constant([1, 2, 3]), dataset_ops.Dataset.range(3))
self.checkInvalidElement(tuple_elem)
self._testInvalidElement(tuple_elem)
if __name__ == '__main__':

View File

@ -45,9 +45,9 @@ class FileCacheTest(test_base.DatasetTestBase, parameterized.TestCase):
self.cache_prefix = path.join(self.tmp_dir, "cache")
def tearDown(self):
super(FileCacheTest, self).tearDown()
if self.tmp_dir:
shutil.rmtree(self.tmp_dir, ignore_errors=True)
super(FileCacheTest, self).tearDown()
@combinations.generate(test_base.default_test_combinations())
def testCacheDatasetPassthrough(self):

View File

@ -42,11 +42,11 @@ from tensorflow.python.training.tracking import util as trackable_utils
class CheckpointTest(test_base.DatasetTestBase, parameterized.TestCase):
def tearDown(self):
super(CheckpointTest, self).tearDown()
prefix = self._iterator_checkpoint_prefix()
pattern = prefix + "*"
files = gfile.Glob(pattern)
map(gfile.Remove, files)
super(CheckpointTest, self).tearDown()
def _iterator_checkpoint_prefix(self):
return os.path.join(self.get_temp_dir(), "iterator")
@ -66,8 +66,7 @@ class CheckpointTest(test_base.DatasetTestBase, parameterized.TestCase):
iterator_state_variant)
return restore_op
@combinations.generate(
combinations.combine(tf_api_version=[1, 2], mode="graph"))
@combinations.generate(test_base.graph_only_combinations())
def testSaveRestore(self):
def _build_graph(start, stop):
@ -118,8 +117,7 @@ class CheckpointTest(test_base.DatasetTestBase, parameterized.TestCase):
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)
@combinations.generate(
combinations.combine(tf_api_version=[1, 2], mode="graph"))
@combinations.generate(test_base.graph_only_combinations())
def testInitThenRestore(self):
# Note: Calling init_op before restore_op is redundant. This test just makes
# sure we do not fail if restore is called on an already initialized
@ -157,8 +155,7 @@ class CheckpointTest(test_base.DatasetTestBase, parameterized.TestCase):
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)
@combinations.generate(
combinations.combine(tf_api_version=[1, 2], mode="graph"))
@combinations.generate(test_base.graph_only_combinations())
def testMultipleSaves(self):
def _build_graph(start, stop):
@ -204,8 +201,7 @@ class CheckpointTest(test_base.DatasetTestBase, parameterized.TestCase):
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)
@combinations.generate(
combinations.combine(tf_api_version=[1, 2], mode="graph"))
@combinations.generate(test_base.graph_only_combinations())
def testSaveRestoreWithRepeat(self):
def _build_graph(start, stop, num_epochs):
@ -253,8 +249,7 @@ class CheckpointTest(test_base.DatasetTestBase, parameterized.TestCase):
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)
@combinations.generate(
combinations.combine(tf_api_version=[1, 2], mode="graph"))
@combinations.generate(test_base.graph_only_combinations())
def testSaveRestoreExhaustedIterator(self):
def _build_graph(start, stop, num_epochs):
@ -295,8 +290,7 @@ class CheckpointTest(test_base.DatasetTestBase, parameterized.TestCase):
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)
@combinations.generate(
combinations.combine(tf_api_version=[1, 2], mode="eager"))
@combinations.generate(test_base.eager_only_combinations())
def testSaveRestoreOneShotIterator(self):
checkpoint_directory = self.get_temp_dir()
checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt")
@ -319,8 +313,7 @@ class CheckpointTest(test_base.DatasetTestBase, parameterized.TestCase):
with self.assertRaises(errors.OutOfRangeError):
get_next()
@combinations.generate(
combinations.combine(tf_api_version=[1, 2], mode="eager"))
@combinations.generate(test_base.eager_only_combinations())
def testSaveRestoreMultipleIterator(self):
checkpoint_directory = self.get_temp_dir()
checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt")
@ -353,8 +346,7 @@ class CheckpointTest(test_base.DatasetTestBase, parameterized.TestCase):
self.assertAllEqual([1, 4], get_next_2())
self.assertAllEqual(3, get_next_3())
@combinations.generate(
combinations.combine(tf_api_version=[1, 2], mode="eager"))
@combinations.generate(test_base.eager_only_combinations())
def testRestoreExhaustedIterator(self):
checkpoint_directory = self.get_temp_dir()
checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt")
@ -373,8 +365,7 @@ class CheckpointTest(test_base.DatasetTestBase, parameterized.TestCase):
with self.assertRaises(errors.OutOfRangeError):
get_next()
@combinations.generate(
combinations.combine(tf_api_version=[1, 2], mode="eager"))
@combinations.generate(test_base.eager_only_combinations())
def testRestoreInReconstructedIteratorInitializable(self):
checkpoint_directory = self.get_temp_dir()
checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt")

View File

@ -43,7 +43,6 @@ from tensorflow.python.framework import tensor_spec
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import random_ops
from tensorflow.python.platform import test
from tensorflow.python.platform import tf_logging
class DatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
@ -89,13 +88,13 @@ class DatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
variant, original_dataset.element_spec)
self.assertDatasetProduces(revived_dataset, list(original_dataset))
def checkNumInputs(self, dataset, num_inputs):
def _testNumInputs(self, dataset, num_inputs):
self.assertLen(dataset._inputs(), num_inputs)
@combinations.generate(test_base.default_test_combinations())
def testFixedLengthRecordInputs(self):
dataset = readers.FixedLengthRecordDataset("", 42)
self.checkNumInputs(dataset, 0)
self._testNumInputs(dataset, 0)
@combinations.generate(test_base.default_test_combinations())
def testFromGeneratorInputs(self):
@ -103,27 +102,27 @@ class DatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
yield 42
dataset = dataset_ops.Dataset.from_generator(gen, dtypes.int32)
self.checkNumInputs(dataset, 1)
self._testNumInputs(dataset, 1)
@combinations.generate(test_base.default_test_combinations())
def testFromTensorsInputs(self):
dataset = dataset_ops.Dataset.from_tensors([42])
self.checkNumInputs(dataset, 0)
self._testNumInputs(dataset, 0)
@combinations.generate(test_base.default_test_combinations())
def testRangeInputs(self):
dataset = dataset_ops.Dataset.range(10)
self.checkNumInputs(dataset, 0)
self._testNumInputs(dataset, 0)
@combinations.generate(test_base.default_test_combinations())
def testTextLineInputs(self):
dataset = readers.TextLineDataset("")
self.checkNumInputs(dataset, 0)
self._testNumInputs(dataset, 0)
@combinations.generate(test_base.default_test_combinations())
def testTFRecordInputs(self):
dataset = readers.TFRecordDataset("")
self.checkNumInputs(dataset, 1)
self._testNumInputs(dataset, 1)
@combinations.generate(
combinations.combine(tf_api_version=1, mode=["eager", "graph"]))
@ -135,58 +134,58 @@ class DatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
dense_shape=np.array([3, 1])))
self.assertEmpty(dataset_fn._inputs())
def checkUnaryInputs(self, dataset_fn):
def _testUnaryInputs(self, dataset_fn):
input_dataset = dataset_ops.Dataset.range(0)
self.assertEqual([input_dataset], dataset_fn(input_dataset)._inputs())
@combinations.generate(test_base.default_test_combinations())
def testBatchInputs(self):
self.checkUnaryInputs(lambda x: x.batch(10))
self._testUnaryInputs(lambda x: x.batch(10))
@combinations.generate(test_base.default_test_combinations())
def testCacheInputs(self):
self.checkUnaryInputs(lambda x: x.cache())
self._testUnaryInputs(lambda x: x.cache())
@combinations.generate(test_base.default_test_combinations())
def testFilterInputs(self):
self.checkUnaryInputs(lambda x: x.filter(lambda x: True))
self._testUnaryInputs(lambda x: x.filter(lambda x: True))
@combinations.generate(test_base.default_test_combinations())
def testFlatMapInputs(self):
self.checkUnaryInputs(
self._testUnaryInputs(
lambda x: x.flat_map(lambda x: dataset_ops.Dataset.range(0)))
@combinations.generate(test_base.default_test_combinations())
def testMapInputs(self):
self.checkUnaryInputs(lambda x: x.map(lambda x: x))
self._testUnaryInputs(lambda x: x.map(lambda x: x))
@combinations.generate(test_base.default_test_combinations())
def testPaddedBatchInputs(self):
self.checkUnaryInputs(lambda x: x.padded_batch(10, []))
self._testUnaryInputs(lambda x: x.padded_batch(10, []))
@combinations.generate(test_base.default_test_combinations())
def testParallelMapInputs(self):
self.checkUnaryInputs(lambda x: x.map(lambda x: x, num_parallel_calls=2))
self._testUnaryInputs(lambda x: x.map(lambda x: x, num_parallel_calls=2))
@combinations.generate(test_base.default_test_combinations())
def testRepeatInputs(self):
self.checkUnaryInputs(lambda x: x.repeat())
self._testUnaryInputs(lambda x: x.repeat())
@combinations.generate(test_base.default_test_combinations())
def testShuffleInputs(self):
self.checkUnaryInputs(lambda x: x.shuffle(10))
self._testUnaryInputs(lambda x: x.shuffle(10))
@combinations.generate(test_base.default_test_combinations())
def testSkipInputs(self):
self.checkUnaryInputs(lambda x: x.skip(1))
self._testUnaryInputs(lambda x: x.skip(1))
@combinations.generate(test_base.default_test_combinations())
def testTakeInputs(self):
self.checkUnaryInputs(lambda x: x.take(1))
self._testUnaryInputs(lambda x: x.take(1))
@combinations.generate(test_base.default_test_combinations())
def testWindowInputs(self):
self.checkUnaryInputs(lambda x: x.window(10))
self._testUnaryInputs(lambda x: x.window(10))
@combinations.generate(test_base.default_test_combinations())
def testUnaryTransformationInputsApply(self):
@ -195,7 +194,7 @@ class DatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
self.assertEqual([input_dataset], dataset._inputs())
def checkInputsWithInterleaveFn(self, dataset_fn, interleave_parallelism):
def _testInputsWithInterleaveFn(self, dataset_fn, interleave_parallelism):
input_dataset = dataset_ops.Dataset.range(0)
dataset = input_dataset.interleave(
lambda x: dataset_ops.Dataset.range(0),
@ -205,11 +204,11 @@ class DatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
@combinations.generate(test_base.default_test_combinations())
def testParallelInterleaveInputs(self):
self.checkInputsWithInterleaveFn(lambda: dataset_ops.range(0), 2)
self._testInputsWithInterleaveFn(lambda: dataset_ops.range(0), 2)
@combinations.generate(test_base.default_test_combinations())
def testInterleaveInputs(self):
self.checkInputsWithInterleaveFn(lambda: dataset_ops.range(0), None)
self._testInputsWithInterleaveFn(lambda: dataset_ops.range(0), None)
@combinations.generate(test_base.default_test_combinations())
def testNoWarnings(self):
@ -218,16 +217,16 @@ class DatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
lambda x: dataset_ops.Dataset.range(0), cycle_length=2)
self.assertEmpty(mock_log.call_args_list)
def checkBinaryInputs(self, dataset_fn):
def _testBinaryInputs(self, dataset_fn):
input1 = dataset_ops.Dataset.range(0)
input2 = dataset_ops.Dataset.range(1)
self.assertEqual([input1, input2], dataset_fn(input1, input2)._inputs())
@combinations.generate(test_base.default_test_combinations())
def testConcatenateInputs(self):
self.checkBinaryInputs(lambda x, y: x.concatenate(y))
self._testBinaryInputs(lambda x, y: x.concatenate(y))
def checkVariadicInputs(self, dataset_fn, input_datasets):
def _testVariadicInputs(self, dataset_fn, input_datasets):
self.assertEqual(
nest.flatten(input_datasets),
dataset_fn(input_datasets)._inputs())
@ -235,20 +234,20 @@ class DatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
@combinations.generate(test_base.default_test_combinations())
def testZipOneInputs(self):
input_datasets = dataset_ops.Dataset.range(0)
self.checkVariadicInputs(dataset_ops.Dataset.zip, input_datasets)
self._testVariadicInputs(dataset_ops.Dataset.zip, input_datasets)
@combinations.generate(test_base.default_test_combinations())
def testZipNestInputs(self):
input_datasets = (dataset_ops.Dataset.range(0),
(dataset_ops.Dataset.range(1),
dataset_ops.Dataset.range(2)))
self.checkVariadicInputs(dataset_ops.Dataset.zip, input_datasets)
self._testVariadicInputs(dataset_ops.Dataset.zip, input_datasets)
@combinations.generate(test_base.default_test_combinations())
def testZipTupleInputs(self):
input_datasets = (dataset_ops.Dataset.range(0),
dataset_ops.Dataset.range(1))
self.checkVariadicInputs(dataset_ops.Dataset.zip, input_datasets)
self._testVariadicInputs(dataset_ops.Dataset.zip, input_datasets)
@combinations.generate(test_base.default_test_combinations())
def testFunctions(self):
@ -273,7 +272,7 @@ class DatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
self.assertEqual(2, inputs.count(ds2))
self.assertEqual(1, inputs.count(ds3))
def checkDatasetSpec(self, tf_value, expected_element_structure):
def _testDatasetSpec(self, tf_value, expected_element_structure):
dataset = dataset_ops.Dataset.from_tensors(0).map(lambda _: tf_value)
dataset_structure = structure.type_spec_from_value(dataset)
self.assertIsInstance(dataset_structure, dataset_ops.DatasetSpec)
@ -307,12 +306,12 @@ class DatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
@combinations.generate(test_base.default_test_combinations())
def testTensorDatasetSpec(self):
self.checkDatasetSpec(
self._testDatasetSpec(
constant_op.constant(37.0), tensor_spec.TensorSpec([], dtypes.float32))
@combinations.generate(test_base.default_test_combinations())
def testSparseTensorDatasetSpec(self):
self.checkDatasetSpec(
self._testDatasetSpec(
sparse_tensor.SparseTensor(
indices=[[0]],
values=constant_op.constant([0], dtype=dtypes.int32),
@ -320,7 +319,7 @@ class DatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
@combinations.generate(test_base.default_test_combinations())
def testNestDatasetSpec(self):
self.checkDatasetSpec(
self._testDatasetSpec(
{
"a": constant_op.constant(37.0),
"b": (constant_op.constant(["Foo"]), constant_op.constant("Bar"))
@ -335,20 +334,19 @@ class DatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
@combinations.generate(test_base.default_test_combinations())
def testDatasetDatasetSpec(self):
self.checkDatasetSpec(
self._testDatasetSpec(
dataset_ops.Dataset.from_tensor_slices(
constant_op.constant([1, 2, 3])),
dataset_ops.DatasetSpec(tensor_spec.TensorSpec([], dtypes.int32)))
@combinations.generate(test_base.default_test_combinations())
def testOptionalDatasetSpec(self):
self.checkDatasetSpec(
self._testDatasetSpec(
optional_ops.Optional.from_value(37.0),
optional_ops.OptionalSpec(tensor_spec.TensorSpec([], dtypes.float32)))
@combinations.generate(
combinations.combine(tf_api_version=[1], mode=["graph"]))
def testSkipEagerSameGraphErrorOneShot(self):
@combinations.generate(test_base.graph_only_combinations())
def testSameGraphError(self):
dataset = dataset_ops.Dataset.range(10)
with ops.Graph().as_default():
with self.assertRaisesRegexp(ValueError, "must be from the same graph"):
@ -356,26 +354,27 @@ class DatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
@combinations.generate(
combinations.combine(tf_api_version=[1], mode=["graph"]))
def testSkipEagerSameGraphErrorOneShotSimple(self):
def testSameGraphErrorOneShot(self):
dataset = dataset_ops.Dataset.range(10)
with ops.Graph().as_default():
with test.mock.patch.object(tf_logging, "warning") as mock_log:
with self.assertRaisesRegexp(
ValueError, "Please ensure that all datasets in the pipeline are "
"created in the same graph as the iterator."):
_ = dataset_ops.make_one_shot_iterator(dataset)
self.assertRegexpMatches(
str(mock_log.call_args), "Please ensure that all datasets in the "
"pipeline are created in the same graph as the iterator.")
@combinations.generate(
combinations.combine(tf_api_version=[1], mode=["graph"]))
def testSkipEagerSameGraphErrorInitializable(self):
def testSameGraphErrorInitializable(self):
dataset = dataset_ops.Dataset.range(10)
with ops.Graph().as_default():
with self.assertRaisesRegexp(ValueError, "must be from the same graph"):
dataset = dataset.batch(2)
with self.assertRaisesRegexp(
ValueError, "Please ensure that all datasets in the pipeline are "
"created in the same graph as the iterator."):
_ = dataset_ops.make_initializable_iterator(dataset)
@combinations.generate(
combinations.times(
combinations.combine(tf_api_version=[1, 2], mode="eager"),
test_base.eager_only_combinations(),
combinations.combine(execution_mode=[context.ASYNC, context.SYNC])))
def testEagerIteration(self, execution_mode):
with context.execution_mode(execution_mode):

View File

@ -30,28 +30,31 @@ from tensorflow.python.ops import math_ops
from tensorflow.python.platform import test
def new_and_legacy_filter_fn_combinations():
def _test_combinations():
def new_filter_fn(dataset, predicate):
def filter_fn(dataset, predicate):
return dataset.filter(predicate)
def legacy_filter_fn(dataset, predicate):
return dataset.filter_with_legacy_function(predicate)
return (combinations.combine(
filter_combinations = combinations.combine(
tf_api_version=[1, 2],
mode=["eager", "graph"],
apply_filter=combinations.NamedObject("new_filter_fn", new_filter_fn)) +
combinations.combine(
tf_api_version=1,
mode=["eager", "graph"],
apply_filter=combinations.NamedObject("legacy_filter_fn",
legacy_filter_fn)))
apply_filter=combinations.NamedObject("filter_fn", filter_fn))
legacy_filter_combinations = combinations.combine(
tf_api_version=1,
mode=["eager", "graph"],
apply_filter=combinations.NamedObject("legacy_filter_fn",
legacy_filter_fn))
return filter_combinations + legacy_filter_combinations
class FilterTest(test_base.DatasetTestBase, parameterized.TestCase):
@combinations.generate(new_and_legacy_filter_fn_combinations())
@combinations.generate(_test_combinations())
def testFilterDataset(self, apply_filter):
components = (np.arange(7, dtype=np.int64),
np.array([[1, 2, 3]], dtype=np.int64) *
@ -87,14 +90,14 @@ class FilterTest(test_base.DatasetTestBase, parameterized.TestCase):
# Test an empty dataset.
do_test(0, 1)
@combinations.generate(new_and_legacy_filter_fn_combinations())
@combinations.generate(_test_combinations())
def testFilterRange(self, apply_filter):
dataset = dataset_ops.Dataset.range(4)
dataset = apply_filter(dataset,
lambda x: math_ops.not_equal(math_ops.mod(x, 3), 2))
self.assertDatasetProduces(dataset, expected_output=[0, 1, 3])
@combinations.generate(new_and_legacy_filter_fn_combinations())
@combinations.generate(_test_combinations())
def testFilterDict(self, apply_filter):
dataset = dataset_ops.Dataset.range(10).map(
lambda x: {"foo": x * 2, "bar": x**2})
@ -104,7 +107,7 @@ class FilterTest(test_base.DatasetTestBase, parameterized.TestCase):
dataset,
expected_output=[(i * 2 + i**2) for i in range(10) if not (i**2) % 2])
@combinations.generate(new_and_legacy_filter_fn_combinations())
@combinations.generate(_test_combinations())
def testUseStepContainerInFilter(self, apply_filter):
input_data = np.array([[1, 2, 3], [4, 5, 6]], dtype=np.int64)
@ -119,7 +122,7 @@ class FilterTest(test_base.DatasetTestBase, parameterized.TestCase):
dataset = apply_filter(dataset, _predicate)
self.assertDatasetProduces(dataset, expected_output=[input_data[0]])
@combinations.generate(new_and_legacy_filter_fn_combinations())
@combinations.generate(_test_combinations())
def testSparse(self, apply_filter):
def _map_fn(i):
@ -137,7 +140,7 @@ class FilterTest(test_base.DatasetTestBase, parameterized.TestCase):
self.assertDatasetProduces(
dataset, expected_output=[_map_fn(i * 2)[0] for i in range(5)])
@combinations.generate(new_and_legacy_filter_fn_combinations())
@combinations.generate(_test_combinations())
def testShortCircuit(self, apply_filter):
dataset = dataset_ops.Dataset.zip(
(dataset_ops.Dataset.range(10),
@ -146,7 +149,7 @@ class FilterTest(test_base.DatasetTestBase, parameterized.TestCase):
self.assertDatasetProduces(
dataset, expected_output=[(i, True) for i in range(10)])
@combinations.generate(new_and_legacy_filter_fn_combinations())
@combinations.generate(_test_combinations())
def testParallelFilters(self, apply_filter):
dataset = dataset_ops.Dataset.range(10)
dataset = apply_filter(dataset, lambda x: math_ops.equal(x % 2, 0))

View File

@ -66,10 +66,8 @@ class FlatMapTest(test_base.DatasetTestBase, parameterized.TestCase):
expected_output.extend([i] * i)
self.assertDatasetProduces(dataset, expected_output=expected_output)
# Note: no eager mode coverage, session specific test.
@combinations.generate(
combinations.combine(tf_api_version=[1, 2], mode=["graph"]))
def testSkipEagerSharedResourceNestedFlatMapDataset(self):
@combinations.generate(test_base.graph_only_combinations())
def testSharedResourceNestedFlatMapDataset(self):
repeats = [[1, 2], [3, 4], [5, 0], [1, 7]]
components = np.array(repeats, dtype=np.int64)
iterator = (

View File

@ -32,62 +32,83 @@ from tensorflow.python.ops import script_ops
from tensorflow.python.platform import test
class DatasetConstructorTest(test_base.DatasetTestBase, parameterized.TestCase):
class FromGeneratorTest(test_base.DatasetTestBase, parameterized.TestCase):
def _testFromGenerator(self, generator, elem_sequence, num_repeats,
output_types=None):
if output_types is None:
output_types = dtypes.int64
dataset = dataset_ops.Dataset.from_generator(
generator, output_types=output_types).repeat(num_repeats).prefetch(5)
self.assertDatasetProduces(
dataset,
elem_sequence * num_repeats,
requires_initialization=True,
num_test_iterations=2)
def _testFromGeneratorOneShot(self, generator, elem_sequence, num_repeats):
requires_initialization):
dataset = dataset_ops.Dataset.from_generator(
generator, output_types=dtypes.int64).repeat(num_repeats).prefetch(5)
self.assertDatasetProduces(
dataset, elem_sequence * num_repeats, num_test_iterations=2)
dataset,
elem_sequence * num_repeats,
requires_initialization=requires_initialization,
num_test_iterations=2)
@combinations.generate(
combinations.times(
test_base.default_test_combinations(),
combinations.combine(
num_repeats=[1, 5], requires_initialization=[True, False])))
def testFromGeneratorUsingFn(self, num_repeats, requires_initialization):
@combinations.generate(test_base.default_test_combinations())
def testFromGeneratorUsingFunction(self):
def generator():
for i in range(1, 100):
yield [i] * i
elem_sequence = list(generator())
self._testFromGenerator(generator, elem_sequence, 1)
self._testFromGenerator(generator, elem_sequence, 5)
self._testFromGeneratorOneShot(generator, elem_sequence, 1)
self._testFromGeneratorOneShot(generator, elem_sequence, 5)
@combinations.generate(test_base.default_test_combinations())
def testFromGeneratorUsingList(self):
elem_sequence = list(generator())
self._testFromGenerator(
generator,
elem_sequence,
num_repeats=num_repeats,
requires_initialization=requires_initialization)
@combinations.generate(
combinations.times(
test_base.default_test_combinations(),
combinations.combine(
num_repeats=[1, 5], requires_initialization=[True, False])))
def testFromGeneratorUsingList(self, num_repeats, requires_initialization):
generator = lambda: [[i] * i for i in range(1, 100)]
elem_sequence = list(generator())
self._testFromGenerator(generator, elem_sequence, 1)
self._testFromGenerator(generator, elem_sequence, 5)
self._testFromGenerator(
generator,
elem_sequence,
num_repeats=num_repeats,
requires_initialization=requires_initialization)
@combinations.generate(test_base.default_test_combinations())
def testFromGeneratorUsingNdarray(self):
@combinations.generate(
combinations.times(
test_base.default_test_combinations(),
combinations.combine(
num_repeats=[1, 5], requires_initialization=[True, False])))
def testFromGeneratorUsingNdarray(self, num_repeats, requires_initialization):
generator = lambda: np.arange(100, dtype=np.int64)
elem_sequence = list(generator())
self._testFromGenerator(generator, elem_sequence, 1, output_types=np.int64)
self._testFromGenerator(generator, elem_sequence, 5, output_types=np.int64)
self._testFromGenerator(
generator,
elem_sequence,
num_repeats=num_repeats,
requires_initialization=requires_initialization)
@combinations.generate(test_base.default_test_combinations())
def testFromGeneratorUsingGeneratorExpression(self):
# NOTE(mrry): Generator *expressions* are not repeatable (or in
# general reusable), because they eagerly evaluate the `for`
# expression as `iter(range(1, 100))` and discard the means of
# reconstructing `range(1, 100)`. Wrapping the generator
# expression in a `lambda` makes it repeatable.
@combinations.generate(
combinations.times(
test_base.default_test_combinations(),
combinations.combine(
num_repeats=[1, 5], requires_initialization=[True, False])))
def testFromGeneratorUsingGeneratorExpression(self, num_repeats,
requires_initialization):
# NOTE(mrry): Generator *expressions* are not repeatable (or in general
# reusable), because they eagerly evaluate the `for` expression as
# `iter(range(1, 100))` and discard the means of reconstructing
# `range(1, 100)`. Wrapping the generator expression in a `lambda` makes
# it repeatable.
generator = lambda: ([i] * i for i in range(1, 100))
elem_sequence = list(generator())
self._testFromGenerator(generator, elem_sequence, 1)
self._testFromGenerator(generator, elem_sequence, 5)
self._testFromGenerator(
generator,
elem_sequence,
num_repeats=num_repeats,
requires_initialization=requires_initialization)
@combinations.generate(test_base.default_test_combinations())
def testFromMultipleConcurrentGenerators(self):
@ -392,7 +413,6 @@ class DatasetConstructorTest(test_base.DatasetTestBase, parameterized.TestCase):
self.assertAllEqual(37, self.evaluate(get_next()))
with self.assertRaises(errors.OutOfRangeError):
self.evaluate(get_next())
self.assertTrue(event.is_set())
@combinations.generate(test_base.default_test_combinations())
def testSharedName(self):

View File

@ -237,8 +237,8 @@ class FromTensorsTest(test_base.DatasetTestBase, parameterized.TestCase):
self.assertEqual([3], get_next().shape)
# TODO(b/121264236): needs mechanism for multiple device in eager mode.
@combinations.generate(test_base.default_test_combinations())
def testSkipEagerSplitPipeline(self):
@combinations.generate(test_base.graph_only_combinations())
def testSplitPipeline(self):
with session.Session(
target="",
config=config_pb2.ConfigProto(device_count={"CPU": 2})) as sess:

View File

@ -17,12 +17,15 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from absl.testing import parameterized
import numpy as np
from tensorflow.core.protobuf import config_pb2
from tensorflow.python.client import session
from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.ops import iterator_ops
from tensorflow.python.framework import combinations
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
@ -37,9 +40,9 @@ from tensorflow.python.ops import string_ops
from tensorflow.python.platform import test
class IteratorClusterTest(test.TestCase):
class IteratorClusterTest(test.TestCase, parameterized.TestCase):
@test_util.run_v1_only("b/120545219")
@combinations.generate(test_base.graph_only_combinations())
def testRemoteIteratorWithoutRemoteCallFail(self):
worker_config = config_pb2.ConfigProto()
worker_config.device_count["CPU"] = 2
@ -95,7 +98,7 @@ class IteratorClusterTest(test.TestCase):
with self.assertRaises(errors.OutOfRangeError):
sess.run(remote_op, feed_dict={target_placeholder: device1})
@test_util.run_v1_only("b/120545219")
@combinations.generate(test_base.graph_only_combinations())
def testRemoteIteratorUsingRemoteCallOp(self):
worker_config = config_pb2.ConfigProto()
worker_config.device_count["CPU"] = 2
@ -106,7 +109,7 @@ class IteratorClusterTest(test.TestCase):
"/job:worker/replica:0/task:0/cpu:1",
worker[0].target)
@test_util.run_v1_only("b/120545219")
@combinations.generate(test_base.graph_only_combinations())
def testRemoteIteratorUsingRemoteCallOpCrossProcess(self):
workers, _ = test_util.create_local_cluster(2, 1)
@ -114,7 +117,7 @@ class IteratorClusterTest(test.TestCase):
"/job:worker/replica:0/task:1/cpu:0",
workers[0].target)
@test_util.run_v1_only("b/120545219")
@combinations.generate(test_base.graph_only_combinations())
def testCaptureHashTableInSharedIterator(self):
worker, _ = test_util.create_local_cluster(1, 1)
@ -131,10 +134,10 @@ class IteratorClusterTest(test.TestCase):
input_sentences = dataset_ops.Dataset.from_tensor_slices(
["brain brain tank salad surgery", "surgery brain"])
iterator = (
input_sentences.map(lambda x: string_ops.string_split([x]).values).map(
table.lookup)
.make_initializable_iterator(shared_name="shared_iterator"))
dataset = input_sentences.map(
lambda x: string_ops.string_split([x]).values).map(table.lookup)
iterator = dataset_ops.make_initializable_iterator(
dataset, shared_name="shared_iterator")
init_op = iterator.initializer
get_next = iterator.get_next()
@ -148,7 +151,7 @@ class IteratorClusterTest(test.TestCase):
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)
@test_util.run_v1_only("b/120545219")
@combinations.generate(test_base.graph_only_combinations())
def testImplicitDisposeParallelMapDataset(self):
# Tests whether a parallel map dataset will be cleaned up correctly when
# the pipeline does not run it until exhaustion.

View File

@ -56,8 +56,7 @@ from tensorflow.python.util import compat
class IteratorTest(test_base.DatasetTestBase, parameterized.TestCase):
@combinations.generate(
combinations.combine(tf_api_version=[1, 2], mode=["graph"]))
@combinations.generate(test_base.graph_only_combinations())
def testNoGradients(self):
component = constant_op.constant([1.])
side = constant_op.constant(0.)
@ -68,8 +67,7 @@ class IteratorTest(test_base.DatasetTestBase, parameterized.TestCase):
self.assertIsNone(gradients_impl.gradients(value, side)[0])
self.assertIsNone(gradients_impl.gradients(value, [component, side])[0])
@combinations.generate(
combinations.combine(tf_api_version=[1, 2], mode=["graph"]))
@combinations.generate(test_base.graph_only_combinations())
def testCapturingStateInOneShotRaisesException(self):
var = variables.Variable(37.0, name="myvar")
dataset = (
@ -80,8 +78,7 @@ class IteratorTest(test_base.DatasetTestBase, parameterized.TestCase):
"datasets that capture stateful objects.+myvar"):
dataset_ops.make_one_shot_iterator(dataset)
@combinations.generate(
combinations.combine(tf_api_version=[1, 2], mode=["graph"]))
@combinations.generate(test_base.graph_only_combinations())
def testOneShotIterator(self):
components = (np.arange(7),
np.array([[1, 2, 3]]) * np.arange(7)[:, np.newaxis],
@ -107,8 +104,7 @@ class IteratorTest(test_base.DatasetTestBase, parameterized.TestCase):
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)
@combinations.generate(
combinations.combine(tf_api_version=[1, 2], mode=["graph"]))
@combinations.generate(test_base.graph_only_combinations())
def testOneShotIteratorCaptureByValue(self):
components = (np.arange(7),
np.array([[1, 2, 3]]) * np.arange(7)[:, np.newaxis],
@ -172,8 +168,7 @@ class IteratorTest(test_base.DatasetTestBase, parameterized.TestCase):
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)
@combinations.generate(
combinations.combine(tf_api_version=[1, 2], mode=["graph"]))
@combinations.generate(test_base.graph_only_combinations())
def testOneShotIteratorNonBlocking(self):
dataset = dataset_ops.Dataset.from_tensors([1, 2, 3]).map(lambda x: x * x)
iterator = dataset_ops.make_one_shot_iterator(dataset)
@ -207,13 +202,11 @@ class IteratorTest(test_base.DatasetTestBase, parameterized.TestCase):
for t in threads:
t.join()
self.assertEqual(num_threads, len(results))
self.assertEqual(num_threads - 1,
len([None for r in results if r is None]))
self.assertLen(results, num_threads)
self.assertLen([None for r in results if r is None], num_threads - 1)
self.assertAllEqual([[1, 4, 9]], [r for r in results if r is not None])
@combinations.generate(
combinations.combine(tf_api_version=[1, 2], mode=["graph"]))
@combinations.generate(test_base.graph_only_combinations())
def testOneShotIteratorInitializerFails(self):
# Define a dataset whose initialization will always fail.
dataset = dataset_ops.Dataset.from_tensors(array_ops.gather([0], [4]))
@ -243,8 +236,7 @@ class IteratorTest(test_base.DatasetTestBase, parameterized.TestCase):
for t in threads:
t.join()
@combinations.generate(
combinations.combine(tf_api_version=[1, 2], mode=["graph"]))
@combinations.generate(test_base.graph_only_combinations())
def testSimpleSharedResource(self):
components = (np.array(1, dtype=np.int64),
np.array([1, 2, 3], dtype=np.int64),
@ -294,8 +286,7 @@ class IteratorTest(test_base.DatasetTestBase, parameterized.TestCase):
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)
@combinations.generate(
combinations.combine(tf_api_version=[1, 2], mode=["graph"]))
@combinations.generate(test_base.graph_only_combinations())
def testNotInitializedError(self):
components = (np.array(1), np.array([1, 2, 3]), np.array(37.0))
iterator = dataset_ops.make_initializable_iterator(
@ -307,8 +298,7 @@ class IteratorTest(test_base.DatasetTestBase, parameterized.TestCase):
"iterator has not been initialized"):
sess.run(get_next)
@combinations.generate(
combinations.combine(tf_api_version=[1, 2], mode=["graph"]))
@combinations.generate(test_base.graph_only_combinations())
def testReinitializableIterator(self):
dataset_3 = dataset_ops.Dataset.from_tensors(
constant_op.constant([1, 2, 3]))
@ -353,8 +343,7 @@ class IteratorTest(test_base.DatasetTestBase, parameterized.TestCase):
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)
@combinations.generate(
combinations.combine(tf_api_version=[1, 2], mode=["graph"]))
@combinations.generate(test_base.graph_only_combinations())
def testReinitializableIteratorWithFunctions(self):
def g():
@ -415,8 +404,7 @@ class IteratorTest(test_base.DatasetTestBase, parameterized.TestCase):
(constant_op.constant([1, 2, 3], dtype=dtypes.int64),
constant_op.constant([4., 5., 6., 7.], dtype=dtypes.float64))))
@combinations.generate(
combinations.combine(tf_api_version=[1, 2], mode=["graph"]))
@combinations.generate(test_base.graph_only_combinations())
def testIteratorStringHandle(self):
dataset_3 = dataset_ops.Dataset.from_tensor_slices([1, 2, 3])
dataset_4 = dataset_ops.Dataset.from_tensor_slices([10, 20, 30, 40])
@ -474,8 +462,7 @@ class IteratorTest(test_base.DatasetTestBase, parameterized.TestCase):
sess.run(
next_element, feed_dict={handle_placeholder: iterator_4_handle})
@combinations.generate(
combinations.combine(tf_api_version=[1, 2], mode=["graph"]))
@combinations.generate(test_base.graph_only_combinations())
def testIteratorStringHandleFuture(self):
with forward_compat.forward_compatibility_horizon(2018, 8, 4):
dataset_3 = dataset_ops.Dataset.from_tensor_slices([1, 2, 3])
@ -541,8 +528,7 @@ class IteratorTest(test_base.DatasetTestBase, parameterized.TestCase):
sess.run(
next_element, feed_dict={handle_placeholder: iterator_4_handle})
@combinations.generate(
combinations.combine(tf_api_version=[1, 2], mode=["graph"]))
@combinations.generate(test_base.graph_only_combinations())
def testIteratorStringHandleReuseTensorObject(self):
dataset = dataset_ops.Dataset.from_tensor_slices([1, 2, 3])
one_shot_iterator = dataset_ops.make_one_shot_iterator(dataset)
@ -571,8 +557,7 @@ class IteratorTest(test_base.DatasetTestBase, parameterized.TestCase):
self.assertEqual("foo_1", handle_with_same_name.op.name)
self.assertIsNot(handle_with_name, handle_with_same_name)
@combinations.generate(
combinations.combine(tf_api_version=[1, 2], mode=["graph"]))
@combinations.generate(test_base.graph_only_combinations())
def testIteratorStringHandleError(self):
dataset_int_scalar = (
dataset_ops.Dataset.from_tensor_slices([1, 2, 3]).repeat())
@ -613,8 +598,7 @@ class IteratorTest(test_base.DatasetTestBase, parameterized.TestCase):
feedable_int_vector.get_next(),
feed_dict={handle_placeholder: handle_float_vector}))
@combinations.generate(
combinations.combine(tf_api_version=[1, 2], mode=["graph"]))
@combinations.generate(test_base.graph_only_combinations())
def testRemoteIteratorUsingRemoteCallOpDirectSession(self):
worker_config = config_pb2.ConfigProto()
worker_config.device_count["CPU"] = 3
@ -672,8 +656,7 @@ class IteratorTest(test_base.DatasetTestBase, parameterized.TestCase):
target_placeholder: "/job:localhost/replica:0/task:0/cpu:1"
})
@combinations.generate(
combinations.combine(tf_api_version=[1, 2], mode=["graph"]))
@combinations.generate(test_base.graph_only_combinations())
def testRemoteIteratorUsingRemoteCallOpMultiWorkers(self):
s1 = server_lib.Server.create_local_server()
s2 = server_lib.Server.create_local_server()
@ -727,8 +710,7 @@ class IteratorTest(test_base.DatasetTestBase, parameterized.TestCase):
with self.assertRaises(errors.OutOfRangeError):
sess.run(n)
@combinations.generate(
combinations.combine(tf_api_version=[1, 2], mode=["graph"]))
@combinations.generate(test_base.graph_only_combinations())
def testRemoteIteratorUsingRemoteCallOpDirectSessionGPUCPU(self):
if not test_util.is_gpu_available():
self.skipTest("No GPU available")
@ -785,8 +767,7 @@ class IteratorTest(test_base.DatasetTestBase, parameterized.TestCase):
target_placeholder: "/job:localhost/replica:0/task:0/cpu:0"
})
@combinations.generate(
combinations.combine(tf_api_version=[1, 2], mode=["graph"]))
@combinations.generate(test_base.graph_only_combinations())
def testRepeatedGetNextWarning(self):
iterator = dataset_ops.make_one_shot_iterator(dataset_ops.Dataset.range(10))
warnings.simplefilter("always")
@ -929,7 +910,7 @@ class IteratorTest(test_base.DatasetTestBase, parameterized.TestCase):
self.assertEqual(val, foo.numpy())
val += 1
@combinations.generate(combinations.combine(tf_api_version=2, mode="eager"))
@combinations.generate(test_base.eager_only_combinations())
def testOwnedIteratorFunction(self):
queue = data_flow_ops.FIFOQueue(10, dtypes.int64)
@ -946,7 +927,7 @@ class IteratorTest(test_base.DatasetTestBase, parameterized.TestCase):
for i in range(10):
self.assertEqual(queue.dequeue().numpy(), i)
@combinations.generate(combinations.combine(tf_api_version=2, mode="eager"))
@combinations.generate(test_base.eager_only_combinations())
def testOwnedIteratorFunctionError(self):
# In this test we verify that a function that raises an error ends up
# properly deallocating the iterator resource.
@ -976,7 +957,7 @@ class IteratorTest(test_base.DatasetTestBase, parameterized.TestCase):
self.assertEqual(queue.size().numpy(), 2)
@combinations.generate(combinations.combine(tf_api_version=2, mode="eager"))
@combinations.generate(test_base.eager_only_combinations())
def testLimitedRetracing(self):
trace_count = [0]
@ -996,7 +977,7 @@ class IteratorTest(test_base.DatasetTestBase, parameterized.TestCase):
self.assertEqual(self.evaluate(f(iter(dataset2))), 45)
self.assertEqual(trace_count[0], 1)
@combinations.generate(combinations.combine(tf_api_version=2, mode="eager"))
@combinations.generate(test_base.eager_only_combinations())
def testNestedFunctionsIteratorResource(self):
@def_function.function

View File

@ -35,10 +35,12 @@ from tensorflow.python.util import compat
class ListFilesTest(test_base.DatasetTestBase, parameterized.TestCase):
def setUp(self):
super(ListFilesTest, self).setUp()
self.tmp_dir = tempfile.mkdtemp()
def tearDown(self):
shutil.rmtree(self.tmp_dir, ignore_errors=True)
super(ListFilesTest, self).tearDown()
def _touchTempFiles(self, filenames):
for filename in filenames:

File diff suppressed because it is too large Load Diff

View File

@ -119,8 +119,7 @@ class MemoryCleanupTest(test_base.DatasetTestBase, parameterized.TestCase):
]
self.assertEmpty(tensors, "%d Tensors are still alive." % len(tensors))
@combinations.generate(
combinations.combine(tf_api_version=[1, 2], mode="eager"))
@combinations.generate(test_base.eager_only_combinations())
def testFilter(self):
def get_dataset():
@ -144,8 +143,7 @@ class MemoryCleanupTest(test_base.DatasetTestBase, parameterized.TestCase):
self._testIteratorMemoryLeak(get_dataset)
@combinations.generate(
combinations.combine(tf_api_version=[1, 2], mode="eager"))
@combinations.generate(test_base.eager_only_combinations())
def testFlatMap(self):
def get_dataset():
@ -157,8 +155,7 @@ class MemoryCleanupTest(test_base.DatasetTestBase, parameterized.TestCase):
self._testIteratorMemoryLeak(get_dataset)
@combinations.generate(
combinations.combine(tf_api_version=[1, 2], mode="eager"))
@combinations.generate(test_base.eager_only_combinations())
def testFromGenerator(self):
def get_dataset():
@ -171,8 +168,8 @@ class MemoryCleanupTest(test_base.DatasetTestBase, parameterized.TestCase):
self._testIteratorMemoryLeak(get_dataset)
@combinations.generate(
combinations.combine(
tf_api_version=[1, 2], mode="eager", num_parallel_calls=[None, 10]))
combinations.times(test_base.eager_only_combinations(),
combinations.combine(num_parallel_calls=[None, 10])))
def testMap(self, num_parallel_calls):
def get_dataset():
@ -201,8 +198,8 @@ class MemoryCleanupTest(test_base.DatasetTestBase, parameterized.TestCase):
self._testIteratorMemoryLeak(get_dataset)
@combinations.generate(
combinations.combine(
tf_api_version=[1, 2], mode="eager", num_parallel_calls=[None, 10]))
combinations.times(test_base.eager_only_combinations(),
combinations.combine(num_parallel_calls=[None, 10])))
def testInterleave(self, num_parallel_calls):
def get_dataset():

View File

@ -17,6 +17,8 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import functools
from absl.testing import parameterized
import numpy as np
@ -27,6 +29,7 @@ from tensorflow.python.data.ops import optional_ops
from tensorflow.python.data.util import structure
from tensorflow.python.eager import context
from tensorflow.python.eager import def_function
from tensorflow.python.framework import combinations
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
@ -40,14 +43,90 @@ from tensorflow.python.ops import math_ops
from tensorflow.python.platform import test
@test_util.run_all_in_graph_and_eager_modes
def _optional_spec_test_combinations():
# pylint: disable=g-long-lambda
cases = [
("Dense", lambda: constant_op.constant(37.0),
tensor_spec.TensorSpec([], dtypes.float32)),
("Sparse", lambda: sparse_tensor.SparseTensor(
indices=[[0, 1]],
values=constant_op.constant([0], dtype=dtypes.int32),
dense_shape=[10, 10]),
sparse_tensor.SparseTensorSpec([10, 10], dtypes.int32)),
("Nest", lambda: {
"a": constant_op.constant(37.0),
"b": (constant_op.constant(["Foo"]), constant_op.constant("Bar"))
}, {
"a":
tensor_spec.TensorSpec([], dtypes.float32),
"b": (
tensor_spec.TensorSpec([1], dtypes.string),
tensor_spec.TensorSpec([], dtypes.string),
)
}),
("Optional", lambda: optional_ops.Optional.from_value(37.0),
optional_ops.OptionalSpec(tensor_spec.TensorSpec([], dtypes.float32))),
]
def reduce_fn(x, y):
name, value_fn, expected_structure = y
return x + combinations.combine(
tf_value_fn=combinations.NamedObject(name, value_fn),
expected_value_structure=expected_structure)
return functools.reduce(reduce_fn, cases, [])
def _get_next_as_optional_test_combinations():
# pylint: disable=g-long-lambda
cases = [
("Dense", np.array([1, 2, 3], dtype=np.int32),
lambda: constant_op.constant([4, 5, 6], dtype=dtypes.int32), True),
("Sparse",
sparse_tensor.SparseTensorValue(
indices=[[0, 0], [1, 1]],
values=np.array([-1., 1.], dtype=np.float32),
dense_shape=[2, 2]),
lambda: sparse_tensor.SparseTensor(
indices=[[0, 1], [1, 0]], values=[37.0, 42.0], dense_shape=[2, 2]),
False),
("Nest", {
"a":
np.array([1, 2, 3], dtype=np.int32),
"b":
sparse_tensor.SparseTensorValue(
indices=[[0, 0], [1, 1]],
values=np.array([-1., 1.], dtype=np.float32),
dense_shape=[2, 2])
}, lambda: {
"a":
constant_op.constant([4, 5, 6], dtype=dtypes.int32),
"b":
sparse_tensor.SparseTensor(
indices=[[0, 1], [1, 0]],
values=[37.0, 42.0],
dense_shape=[2, 2])
}, False),
]
def reduce_fn(x, y):
name, value, value_fn, gpu_compatible = y
return x + combinations.combine(
np_value=value, tf_value_fn=combinations.NamedObject(name, value_fn),
gpu_compatible=gpu_compatible)
return functools.reduce(reduce_fn, cases, [])
class OptionalTest(test_base.DatasetTestBase, parameterized.TestCase):
@combinations.generate(test_base.default_test_combinations())
def testFromValue(self):
opt = optional_ops.Optional.from_value(constant_op.constant(37.0))
self.assertTrue(self.evaluate(opt.has_value()))
self.assertEqual(37.0, self.evaluate(opt.get_value()))
@combinations.generate(test_base.default_test_combinations())
def testFromStructuredValue(self):
opt = optional_ops.Optional.from_value({
"a": constant_op.constant(37.0),
@ -59,6 +138,7 @@ class OptionalTest(test_base.DatasetTestBase, parameterized.TestCase):
"b": ([b"Foo"], b"Bar")
}, self.evaluate(opt.get_value()))
@combinations.generate(test_base.default_test_combinations())
def testFromSparseTensor(self):
st_0 = sparse_tensor.SparseTensorValue(
indices=np.array([[0]]),
@ -77,6 +157,7 @@ class OptionalTest(test_base.DatasetTestBase, parameterized.TestCase):
self.assertAllEqual(expected.dense_shape,
self.evaluate(actual.dense_shape))
@combinations.generate(test_base.default_test_combinations())
def testFromNone(self):
value_structure = tensor_spec.TensorSpec([], dtypes.float32)
opt = optional_ops.Optional.none_from_structure(value_structure)
@ -91,6 +172,7 @@ class OptionalTest(test_base.DatasetTestBase, parameterized.TestCase):
with self.assertRaises(errors.InvalidArgumentError):
self.evaluate(opt.get_value())
@combinations.generate(test_base.default_test_combinations())
def testAddN(self):
devices = ["/cpu:0"]
if test_util.is_gpu_available():
@ -117,6 +199,7 @@ class OptionalTest(test_base.DatasetTestBase, parameterized.TestCase):
opt_none1.value_structure)
self.assertFalse(self.evaluate(add_opt.has_value()))
@combinations.generate(test_base.default_test_combinations())
def testNestedAddN(self):
devices = ["/cpu:0"]
if test_util.is_gpu_available():
@ -137,6 +220,7 @@ class OptionalTest(test_base.DatasetTestBase, parameterized.TestCase):
opt1.value_structure)
self.assertAllEqual(inner_add_opt.get_value(), [4, 6.0])
@combinations.generate(test_base.default_test_combinations())
def testZerosLike(self):
devices = ["/cpu:0"]
if test_util.is_gpu_available():
@ -159,6 +243,7 @@ class OptionalTest(test_base.DatasetTestBase, parameterized.TestCase):
opt_none.value_structure)
self.assertFalse(self.evaluate(zeros_opt.has_value()))
@combinations.generate(test_base.default_test_combinations())
def testNestedZerosLike(self):
devices = ["/cpu:0"]
if test_util.is_gpu_available():
@ -175,6 +260,7 @@ class OptionalTest(test_base.DatasetTestBase, parameterized.TestCase):
opt1.value_structure)
self.assertEqual(self.evaluate(inner_zeros_opt.get_value()), 0.0)
@combinations.generate(test_base.default_test_combinations())
def testCopyToGPU(self):
if not test_util.is_gpu_available():
self.skipTest("No GPU available")
@ -204,6 +290,7 @@ class OptionalTest(test_base.DatasetTestBase, parameterized.TestCase):
self.evaluate(gpu_optional_with_value_values))
self.assertFalse(self.evaluate(gpu_optional_none_has_value))
@combinations.generate(test_base.default_test_combinations())
def testNestedCopyToGPU(self):
if not test_util.is_gpu_available():
self.skipTest("No GPU available")
@ -239,42 +326,10 @@ class OptionalTest(test_base.DatasetTestBase, parameterized.TestCase):
self.assertFalse(self.evaluate(inner_none.has_value()))
self.assertEqual(1.0, self.evaluate(gpu_nested_optional_values[2]))
def _assertElementValueEqual(self, expected, actual):
if isinstance(expected, dict):
self.assertItemsEqual(list(expected.keys()), list(actual.keys()))
for k in expected.keys():
self._assertElementValueEqual(expected[k], actual[k])
elif isinstance(expected, sparse_tensor.SparseTensorValue):
self.assertAllEqual(expected.indices, actual.indices)
self.assertAllEqual(expected.values, actual.values)
self.assertAllEqual(expected.dense_shape, actual.dense_shape)
else:
self.assertAllEqual(expected, actual)
# pylint: disable=g-long-lambda
@parameterized.named_parameters(
("Tensor", lambda: constant_op.constant(37.0),
tensor_spec.TensorSpec([], dtypes.float32)),
("SparseTensor", lambda: sparse_tensor.SparseTensor(
indices=[[0, 1]],
values=constant_op.constant([0], dtype=dtypes.int32),
dense_shape=[10, 10]),
sparse_tensor.SparseTensorSpec([10, 10], dtypes.int32)),
("Nest", lambda: {
"a": constant_op.constant(37.0),
"b": (constant_op.constant(["Foo"]), constant_op.constant("Bar"))
}, {
"a":
tensor_spec.TensorSpec([], dtypes.float32),
"b": (
tensor_spec.TensorSpec([1], dtypes.string),
tensor_spec.TensorSpec([], dtypes.string),
)
}),
("Optional", lambda: optional_ops.Optional.from_value(37.0),
optional_ops.OptionalSpec(
tensor_spec.TensorSpec([], dtypes.float32))),
)
@combinations.generate(
combinations.times(
test_base.default_test_combinations(),
_optional_spec_test_combinations()))
def testOptionalSpec(self, tf_value_fn, expected_value_structure):
tf_value = tf_value_fn()
opt = optional_ops.Optional.from_value(tf_value)
@ -304,36 +359,21 @@ class OptionalTest(test_base.DatasetTestBase, parameterized.TestCase):
round_trip_opt = opt_structure._from_tensor_list(
opt_structure._to_tensor_list(opt))
if isinstance(tf_value, optional_ops.Optional):
self._assertElementValueEqual(
self.assertValuesEqual(
self.evaluate(tf_value.get_value()),
self.evaluate(round_trip_opt.get_value().get_value()))
else:
self._assertElementValueEqual(
self.assertValuesEqual(
self.evaluate(tf_value),
self.evaluate(round_trip_opt.get_value()))
@parameterized.named_parameters(
("Tensor", np.array([1, 2, 3], dtype=np.int32),
lambda: constant_op.constant([4, 5, 6], dtype=dtypes.int32), True),
("SparseTensor", sparse_tensor.SparseTensorValue(
indices=[[0, 0], [1, 1]],
values=np.array([-1., 1.], dtype=np.float32), dense_shape=[2, 2]),
lambda: sparse_tensor.SparseTensor(
indices=[[0, 1], [1, 0]], values=[37.0, 42.0], dense_shape=[2, 2]),
False),
("Nest", {"a": np.array([1, 2, 3], dtype=np.int32),
"b": sparse_tensor.SparseTensorValue(
indices=[[0, 0], [1, 1]],
values=np.array([-1., 1.], dtype=np.float32),
dense_shape=[2, 2])},
lambda: {"a": constant_op.constant([4, 5, 6], dtype=dtypes.int32),
"b": sparse_tensor.SparseTensor(
indices=[[0, 1], [1, 0]], values=[37.0, 42.0],
dense_shape=[2, 2])}, False),
)
@combinations.generate(
combinations.times(
test_base.default_test_combinations(),
_get_next_as_optional_test_combinations()))
def testIteratorGetNextAsOptional(self, np_value, tf_value_fn,
works_on_gpu):
if not works_on_gpu and test.is_gpu_available():
gpu_compatible):
if not gpu_compatible and test.is_gpu_available():
self.skipTest("Test case not yet supported on GPU.")
ds = dataset_ops.Dataset.from_tensors(np_value).repeat(3)
@ -348,7 +388,7 @@ class OptionalTest(test_base.DatasetTestBase, parameterized.TestCase):
next_elem.value_structure,
structure.type_spec_from_value(tf_value_fn())))
self.assertTrue(next_elem.has_value())
self._assertElementValueEqual(np_value, next_elem.get_value())
self.assertValuesEqual(np_value, next_elem.get_value())
# After exhausting the iterator, `next_elem.has_value()` will evaluate to
# false, and attempting to get the value will fail.
for _ in range(2):
@ -379,7 +419,7 @@ class OptionalTest(test_base.DatasetTestBase, parameterized.TestCase):
elem_has_value, elem_value = self.evaluate(
[elem_has_value_t, elem_value_t])
self.assertTrue(elem_has_value)
self._assertElementValueEqual(np_value, elem_value)
self.assertValuesEqual(np_value, elem_value)
# After exhausting the iterator, `next_elem.has_value()` will evaluate to
# false, and attempting to get the value will fail.
@ -388,6 +428,7 @@ class OptionalTest(test_base.DatasetTestBase, parameterized.TestCase):
with self.assertRaises(errors.InvalidArgumentError):
self.evaluate(elem_value_t)
@combinations.generate(test_base.default_test_combinations())
def testFunctionBoundaries(self):
@def_function.function
def get_optional():
@ -407,6 +448,7 @@ class OptionalTest(test_base.DatasetTestBase, parameterized.TestCase):
val = consume_optional(opt_tensor)
self.assertEqual(self.evaluate(val), 1.0)
@combinations.generate(test_base.default_test_combinations())
def testLimitedRetracing(self):
trace_count = [0]

View File

@ -18,25 +18,31 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from absl.testing import parameterized
from tensorflow.python.data.experimental.ops import optimization_options
from tensorflow.python.data.experimental.ops import stats_options
from tensorflow.python.data.experimental.ops import threading_options
from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import combinations
from tensorflow.python.platform import test
class OptionsTest(test_base.DatasetTestBase):
class OptionsTest(test_base.DatasetTestBase, parameterized.TestCase):
@combinations.generate(test_base.default_test_combinations())
def testOptionsDefault(self):
ds = dataset_ops.Dataset.range(0)
self.assertEqual(dataset_ops.Options(), ds.options())
@combinations.generate(test_base.default_test_combinations())
def testOptionsOnce(self):
options = dataset_ops.Options()
ds = dataset_ops.Dataset.range(0).with_options(options).cache()
self.assertEqual(options, ds.options())
@combinations.generate(test_base.default_test_combinations())
def testOptionsTwiceSame(self):
options = dataset_ops.Options()
options.experimental_optimization.autotune = True
@ -44,6 +50,7 @@ class OptionsTest(test_base.DatasetTestBase):
options)
self.assertEqual(options, ds.options())
@combinations.generate(test_base.default_test_combinations())
def testOptionsTwiceDifferent(self):
options1 = dataset_ops.Options()
options1.experimental_optimization.autotune = True
@ -55,6 +62,7 @@ class OptionsTest(test_base.DatasetTestBase):
# Explicitly check that flag is False since assertFalse allows None
self.assertIs(ds.options().experimental_deterministic, False)
@combinations.generate(test_base.default_test_combinations())
def testOptionsTwiceDifferentError(self):
options1 = dataset_ops.Options()
options1.experimental_optimization.autotune = True
@ -64,6 +72,7 @@ class OptionsTest(test_base.DatasetTestBase):
"Cannot merge incompatible values"):
dataset_ops.Dataset.range(0).with_options(options1).with_options(options2)
@combinations.generate(test_base.default_test_combinations())
def testOptionsMergeOptionsFromMultipleInputs(self):
options1 = dataset_ops.Options()
options1.experimental_optimization.autotune = True
@ -75,6 +84,7 @@ class OptionsTest(test_base.DatasetTestBase):
self.assertTrue(ds.options().experimental_optimization.autotune)
self.assertTrue(ds.options().experimental_deterministic)
@combinations.generate(test_base.default_test_combinations())
def testOptionsHaveDefaults(self):
options1 = dataset_ops.Options()
options2 = dataset_ops.Options()
@ -84,12 +94,11 @@ class OptionsTest(test_base.DatasetTestBase):
options2.experimental_stats)
self.assertIsNot(options1.experimental_threading,
options2.experimental_threading)
self.assertEquals(options1.experimental_optimization,
optimization_options.OptimizationOptions())
self.assertEquals(options1.experimental_stats,
stats_options.StatsOptions())
self.assertEquals(options1.experimental_threading,
threading_options.ThreadingOptions())
self.assertEqual(options1.experimental_optimization,
optimization_options.OptimizationOptions())
self.assertEqual(options1.experimental_stats, stats_options.StatsOptions())
self.assertEqual(options1.experimental_threading,
threading_options.ThreadingOptions())
if __name__ == "__main__":

View File

@ -23,43 +23,30 @@ import numpy as np
from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import combinations
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import string_ops
from tensorflow.python.platform import test
from tensorflow.python.util import compat
def _random_seq_lens(count):
return np.random.randint(20, size=(count,)).astype(np.int32)
@test_util.run_all_in_graph_and_eager_modes
class PaddedBatchTest(test_base.DatasetTestBase, parameterized.TestCase):
@parameterized.named_parameters(
('default_padding', _random_seq_lens(32), 4, [-1], False),
('constant_padding', _random_seq_lens(32), 4, [25], False),
('uneven_with_remainder', _random_seq_lens(34), 4, [-1], False),
('uneven_without_remainder', _random_seq_lens(34), 4, [-1], True),
)
def testPaddedBatchDataset(self, seq_lens, batch_size, padded_shapes,
drop_remainder):
"""Tests the padded batch dataset logic for various input configurations.
Args:
seq_lens: the input sequence lengths
batch_size: the batch size
padded_shapes: the padded shapes to use
drop_remainder: whether a smaller batch size should be produced if batch
size does not divide number of inputs evenly
"""
@combinations.generate(
combinations.times(
test_base.default_test_combinations(),
combinations.combine(
count=[32, 34],
padded_shapes=[[None], [25]],
drop_remainder=[True, False])))
def testPaddedBatchDataset(self, count, padded_shapes, drop_remainder):
seq_lens = np.random.randint(20, size=(count,)).astype(np.int32)
batch_size = 4
dataset = dataset_ops.Dataset.from_tensor_slices(seq_lens).map(
lambda x: array_ops.fill([x], x)).padded_batch(
batch_size=batch_size,
@ -81,7 +68,9 @@ class PaddedBatchTest(test_base.DatasetTestBase, parameterized.TestCase):
if not drop_remainder and len(seq_lens) % batch_size > 0:
result = self.evaluate(get_next())
padded_len = np.max(result) if result.size > 0 else 0
padded_len = padded_shapes[0]
if padded_len is None or padded_len == -1:
padded_len = np.max(result) if result.size > 0 else 0
self.assertEqual((len(seq_lens) % batch_size, padded_len), result.shape)
for j in range(len(seq_lens) % batch_size):
seq_len = seq_lens[num_full_batches * batch_size + j]
@ -93,7 +82,7 @@ class PaddedBatchTest(test_base.DatasetTestBase, parameterized.TestCase):
with self.assertRaises(errors.OutOfRangeError):
self.evaluate(get_next())
@test_util.run_deprecated_v1
@combinations.generate(test_base.default_test_combinations())
def testPaddedBatchShortPadding(self):
dataset = (
dataset_ops.Dataset.from_tensor_slices(
@ -102,6 +91,7 @@ class PaddedBatchTest(test_base.DatasetTestBase, parameterized.TestCase):
self.assertDatasetProduces(
dataset, expected_error=(errors.DataLossError, ''))
@combinations.generate(test_base.default_test_combinations())
def testPaddedBatchEmptyTensors(self):
dataset = (
dataset_ops.Dataset.from_tensor_slices(
@ -109,6 +99,7 @@ class PaddedBatchTest(test_base.DatasetTestBase, parameterized.TestCase):
batch_size=4, padded_shapes=[-1]))
self.assertDatasetProduces(dataset, expected_output=[[[], [], [], []]])
@combinations.generate(test_base.default_test_combinations())
def testPaddedBatchDatasetNonDefaultPadding(self):
def fill_tuple(x):
@ -139,6 +130,7 @@ class PaddedBatchTest(test_base.DatasetTestBase, parameterized.TestCase):
with self.assertRaises(errors.OutOfRangeError):
self.evaluate(get_next())
@combinations.generate(test_base.default_test_combinations())
def testPaddedBatchDatasetUnicode(self):
# See GitHub issue 16149
def generator():
@ -156,9 +148,8 @@ class PaddedBatchTest(test_base.DatasetTestBase, parameterized.TestCase):
next_element = self.getNext(padded_dataset)
self.evaluate(next_element())
# NOTE: This test is specific to graph mode and is skipped in eager mode.
@test_util.run_deprecated_v1
def testSkipEagerPaddedBatchDatasetShapeSpecifications(self):
@combinations.generate(test_base.graph_only_combinations())
def testPaddedBatchDatasetShapeSpecifications(self):
int_placeholder = array_ops.placeholder(dtypes.int32)
float_placeholder = array_ops.placeholder(dtypes.float32)
string_placeholder = array_ops.placeholder(dtypes.string)
@ -190,6 +181,7 @@ class PaddedBatchTest(test_base.DatasetTestBase, parameterized.TestCase):
self.assertEqual([None, None, None], dataset_output_shapes[1].as_list())
self.assertEqual([None, 37], dataset_output_shapes[2].as_list())
@combinations.generate(test_base.default_test_combinations())
def testPaddedBatchSparseError(self):
def _map_fn(i):
@ -199,6 +191,7 @@ class PaddedBatchTest(test_base.DatasetTestBase, parameterized.TestCase):
with self.assertRaises(TypeError):
_ = dataset_ops.Dataset.range(10).map(_map_fn).padded_batch(10)
@combinations.generate(test_base.default_test_combinations())
def testPaddedBatchShapeError(self):
with self.assertRaisesRegexp(
ValueError, r'The padded shape \(1,\) is not compatible with the '
@ -230,9 +223,8 @@ class PaddedBatchTest(test_base.DatasetTestBase, parameterized.TestCase):
_ = dataset_ops.Dataset.range(10).padded_batch(
5, padded_shapes=shape_as_tensor)
# NOTE: This test is specific to graph mode and is skipped in eager mode.
@test_util.run_deprecated_v1
def testSkipEagerPaddedBatchShapeError(self):
@combinations.generate(test_base.graph_only_combinations())
def testPaddedBatchShapeErrorPlaceholder(self):
with self.assertRaisesRegexp(
ValueError,
r'The padded shape \((\?|None), (\?|None)\) is not compatible with the '

View File

@ -23,36 +23,41 @@ from absl.testing import parameterized
from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import combinations
from tensorflow.python.framework import errors
from tensorflow.python.framework import test_util
from tensorflow.python.platform import test
@test_util.run_all_in_graph_and_eager_modes
class PrefetchTest(test_base.DatasetTestBase, parameterized.TestCase):
@parameterized.parameters((-1), (0), (5))
@combinations.generate(
combinations.times(test_base.default_test_combinations(),
combinations.combine(buffer_size=[-1, None, 0, 42])))
def testBufferSize(self, buffer_size):
dataset = dataset_ops.Dataset.range(10).prefetch(buffer_size=buffer_size)
self.assertDatasetProduces(dataset, expected_output=range(10))
@parameterized.parameters((-2), (-42))
@combinations.generate(
combinations.times(test_base.default_test_combinations(),
combinations.combine(buffer_size=[-2, -42])))
def testInvalidBufferSize(self, buffer_size):
with self.assertRaises(errors.InvalidArgumentError):
dataset = dataset_ops.Dataset.range(10).prefetch(buffer_size=buffer_size)
self.evaluate(dataset._variant_tensor)
@parameterized.parameters(*[(buffer_size, slack_period)
for buffer_size in (-1, None, 0, 5)
for slack_period in (1, 8)])
@combinations.generate(
combinations.times(
test_base.default_test_combinations(),
combinations.combine(
buffer_size=[-1, None, 0, 42], slack_period=[1, 8])))
def testPrefetchWithSlack(self, buffer_size, slack_period):
dataset = dataset_ops.Dataset.range(100)
dataset = dataset_ops.PrefetchDataset(
dataset, buffer_size, slack_period=slack_period)
self.assertDatasetProduces(dataset, expected_output=range(100))
@test_util.run_v1_only("graph-mode specific test")
def testSkipEagerPrefetchCancellation(self):
@combinations.generate(combinations.combine(tf_api_version=1, mode="graph"))
def testPrefetchCancellation(self):
def map_py_fn(x):
while x > -1:

View File

@ -17,51 +17,60 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from absl.testing import parameterized
from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import combinations
from tensorflow.python.framework import errors
from tensorflow.python.framework import test_util
from tensorflow.python.platform import test
@test_util.run_all_in_graph_and_eager_modes
class RangeTest(test_base.DatasetTestBase):
class RangeTest(test_base.DatasetTestBase, parameterized.TestCase):
@combinations.generate(test_base.default_test_combinations())
def testStop(self):
dataset = dataset_ops.Dataset.range(5)
self.assertDatasetProduces(dataset, expected_output=range(5))
@combinations.generate(test_base.default_test_combinations())
def testStartStop(self):
start, stop = 2, 5
dataset = dataset_ops.Dataset.range(start, stop)
self.assertDatasetProduces(dataset, expected_output=range(2, 5))
@combinations.generate(test_base.default_test_combinations())
def testStartStopStep(self):
start, stop, step = 2, 10, 2
dataset = dataset_ops.Dataset.range(start, stop, step)
self.assertDatasetProduces(dataset, expected_output=range(2, 10, 2))
@combinations.generate(test_base.default_test_combinations())
def testZeroStep(self):
start, stop, step = 2, 10, 0
with self.assertRaises(errors.InvalidArgumentError):
dataset = dataset_ops.Dataset.range(start, stop, step)
self.evaluate(dataset._variant_tensor)
@combinations.generate(test_base.default_test_combinations())
def testNegativeStep(self):
start, stop, step = 2, 10, -1
dataset = dataset_ops.Dataset.range(start, stop, step)
self.assertDatasetProduces(dataset, expected_output=range(2, 10, -1))
@combinations.generate(test_base.default_test_combinations())
def testStopLessThanStart(self):
start, stop = 10, 2
dataset = dataset_ops.Dataset.range(start, stop)
self.assertDatasetProduces(dataset, expected_output=range(10, 2))
@combinations.generate(test_base.default_test_combinations())
def testStopLessThanStartWithPositiveStep(self):
start, stop, step = 10, 2, 2
dataset = dataset_ops.Dataset.range(start, stop, step)
self.assertDatasetProduces(dataset, expected_output=range(10, 2, 2))
@combinations.generate(test_base.default_test_combinations())
def testStopLessThanStartWithNegativeStep(self):
start, stop, step = 10, 2, -1
dataset = dataset_ops.Dataset.range(start, stop, step)

View File

@ -17,43 +17,33 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from absl.testing import parameterized
import numpy as np
from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import test_util
from tensorflow.python.framework import combinations
from tensorflow.python.platform import test
@test_util.run_all_in_graph_and_eager_modes
class RepeatTest(test_base.DatasetTestBase):
class RepeatTest(test_base.DatasetTestBase, parameterized.TestCase):
def testRepeatTensorDataset(self):
@combinations.generate(
combinations.times(test_base.default_test_combinations(),
combinations.combine(count=[0, 3, 7])))
def testFiniteRepeat(self, count):
"""Test a dataset that repeats its input multiple times."""
components = (np.array(1), np.array([1, 2, 3]), np.array(37.0))
# This placeholder can be fed when dataset-definition subgraph
# runs (i.e. `init_op` below) to configure the number of
# repetitions used in a particular iterator.
dataset = dataset_ops.Dataset.from_tensors(components).repeat(count)
self.assertEqual(
[c.shape for c in components],
[shape for shape in dataset_ops.get_legacy_output_shapes(dataset)])
self.assertDatasetProduces(dataset, [components] * count)
def do_test(count):
dataset = dataset_ops.Dataset.from_tensors(components).repeat(count)
self.assertEqual(
[c.shape for c in components],
[shape for shape in dataset_ops.get_legacy_output_shapes(dataset)])
self.assertDatasetProduces(dataset, [components] * count)
# Test a finite repetition.
do_test(3)
# test a different finite repetition.
do_test(7)
# Test an empty repetition.
do_test(0)
# Test an infinite repetition.
# NOTE(mrry): There's not a good way to test that the sequence
# actually is infinite.
@combinations.generate(test_base.default_test_combinations())
def testInfiniteRepeat(self):
# NOTE(mrry): There's not a good way to test that the sequence is infinite.
components = (np.array(1), np.array([1, 2, 3]), np.array(37.0))
dataset = dataset_ops.Dataset.from_tensors(components).repeat(-1)
self.assertEqual(
[c.shape for c in components],
@ -64,7 +54,8 @@ class RepeatTest(test_base.DatasetTestBase):
for component, result_component in zip(components, results):
self.assertAllEqual(component, result_component)
def testRepeatRepeatTensorDataset(self):
@combinations.generate(test_base.default_test_combinations())
def testRepeatRepeat(self):
"""Test the composition of repeat datasets."""
components = (np.array(1), np.array([1, 2, 3]), np.array(37.0))
inner_count, outer_count = 7, 14
@ -77,11 +68,6 @@ class RepeatTest(test_base.DatasetTestBase):
self.assertDatasetProduces(dataset,
[components] * (inner_count * outer_count))
def testRepeatEmptyDataset(self):
"""Test that repeating an empty dataset does not hang."""
dataset = dataset_ops.Dataset.from_tensors(0).repeat(10).skip(10).repeat(-1)
self.assertDatasetProduces(dataset, [])
if __name__ == "__main__":
test.main()

View File

@ -17,66 +17,79 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from absl.testing import parameterized
from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import combinations
from tensorflow.python.framework import errors
from tensorflow.python.framework import test_util
from tensorflow.python.platform import test
@test_util.run_v1_only("deprecated API, no eager or V2 test coverage")
class ShardTest(test_base.DatasetTestBase):
class ShardTest(test_base.DatasetTestBase, parameterized.TestCase):
@combinations.generate(test_base.default_test_combinations())
def testSimpleCase(self):
dataset = dataset_ops.Dataset.range(10).shard(5, 2)
self.assertDatasetProduces(dataset, expected_output=[2, 7])
@combinations.generate(test_base.default_test_combinations())
def testNestedData(self):
dataset_a = dataset_ops.Dataset.range(10)
dataset_b = dataset_ops.Dataset.range(10, 0, -1)
dataset = dataset_ops.Dataset.zip((dataset_a, dataset_b)).shard(5, 2)
self.assertDatasetProduces(dataset, expected_output=[(2, 8), (7, 3)])
@combinations.generate(test_base.default_test_combinations())
def testOffsetZero(self):
dataset = dataset_ops.Dataset.range(10).shard(5, 0)
self.assertDatasetProduces(dataset, expected_output=[0, 5])
@combinations.generate(test_base.default_test_combinations())
def testOffsetGreaterNumShards(self):
with self.assertRaises(errors.InvalidArgumentError):
dataset = dataset_ops.Dataset.range(10).shard(5, 7)
self.evaluate(self.getNext(dataset)())
@combinations.generate(test_base.default_test_combinations())
def testNegativeOffset(self):
with self.assertRaises(errors.InvalidArgumentError):
dataset = dataset_ops.Dataset.range(10).shard(5, -3)
self.evaluate(self.getNext(dataset)())
@combinations.generate(test_base.default_test_combinations())
def testNegativeNumShards(self):
with self.assertRaises(errors.InvalidArgumentError):
dataset = dataset_ops.Dataset.range(10).shard(-3, 1)
self.evaluate(self.getNext(dataset)())
@combinations.generate(test_base.default_test_combinations())
def testZeroNumShards(self):
with self.assertRaises(errors.InvalidArgumentError):
dataset = dataset_ops.Dataset.range(10).shard(0, 1)
self.evaluate(self.getNext(dataset)())
@combinations.generate(test_base.default_test_combinations())
def testIteratorEndsBeforeFirstElem(self):
dataset = dataset_ops.Dataset.range(1).shard(5, 2)
self.assertDatasetProduces(dataset, expected_output=[])
@combinations.generate(test_base.default_test_combinations())
def testLargerWorkerPool(self):
dataset = dataset_ops.Dataset.range(10).shard(7, 5)
self.assertDatasetProduces(dataset, expected_output=[5])
@combinations.generate(test_base.default_test_combinations())
def testIndexEqualsNumShards(self):
dataset = dataset_ops.Dataset.range(10).shard(5, 4)
self.assertDatasetProduces(dataset, expected_output=[4, 9])
@combinations.generate(test_base.default_test_combinations())
def testIndexEqualsNumShards2(self):
dataset = dataset_ops.Dataset.range(10).shard(4, 3)
self.assertDatasetProduces(dataset, expected_output=[3, 7])
@combinations.generate(test_base.default_test_combinations())
def testNumShardsLargerThanDataset(self):
dataset = dataset_ops.Dataset.range(10).shard(20, 5)
self.assertDatasetProduces(dataset, expected_output=[5])

View File

@ -40,7 +40,7 @@ from tensorflow.python.platform import test
class ShuffleTest(test_base.DatasetTestBase, parameterized.TestCase):
@combinations.generate(test_base.default_test_combinations())
def testShuffleDataset(self):
def testBasic(self):
components = (
np.array([1, 2, 3, 4]), np.array([5, 6, 7, 8]),
np.array([9.0, 10.0, 11.0, 12.0])
@ -160,7 +160,7 @@ class ShuffleTest(test_base.DatasetTestBase, parameterized.TestCase):
@combinations.generate(
combinations.times(
combinations.combine(tf_api_version=[1, 2], mode="graph"),
test_base.graph_only_combinations(),
combinations.combine(reshuffle=[True, False]),
combinations.combine(graph_seed=38, op_seed=None) +
combinations.combine(graph_seed=None, op_seed=42) +
@ -188,7 +188,7 @@ class ShuffleTest(test_base.DatasetTestBase, parameterized.TestCase):
# TODO(b/117581999): enable this test for eager-mode.
@combinations.generate(
combinations.times(
combinations.combine(tf_api_version=[1, 2], mode="graph"),
test_base.graph_only_combinations(),
combinations.combine(
reshuffle=[True, False], initializable=[True, False])))
def testMultipleIterators(self, reshuffle, initializable):
@ -278,7 +278,7 @@ class ShuffleTest(test_base.DatasetTestBase, parameterized.TestCase):
@combinations.generate(
combinations.times(
combinations.combine(tf_api_version=[1, 2], mode="eager"),
test_base.eager_only_combinations(),
combinations.combine(reshuffle=[True, False], seed=[None, 42])))
def testReshuffleSeparateTransformations(self, reshuffle, seed):
dataset = dataset_ops.Dataset.range(10)

View File

@ -17,46 +17,30 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from absl.testing import parameterized
import numpy as np
from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import test_util
from tensorflow.python.framework import combinations
from tensorflow.python.platform import test
@test_util.run_all_in_graph_and_eager_modes
class SkipTest(test_base.DatasetTestBase):
class SkipTest(test_base.DatasetTestBase, parameterized.TestCase):
def testSkipTensorDataset(self):
@combinations.generate(
combinations.times(test_base.default_test_combinations(),
combinations.combine(count=[-1, 0, 4, 10, 25])))
def testBasic(self, count):
components = (np.arange(10),)
def do_test(count):
dataset = dataset_ops.Dataset.from_tensor_slices(components).skip(count)
self.assertEqual(
[c.shape[1:] for c in components],
[shape for shape in dataset_ops.get_legacy_output_shapes(dataset)])
start_range = min(count, 10) if count != -1 else 10
self.assertDatasetProduces(
dataset,
[tuple(components[0][i:i + 1]) for i in range(start_range, 10)])
# Skip fewer than input size, we should skip
# the first 4 elements and then read the rest.
do_test(4)
# Skip more than input size: get nothing.
do_test(25)
# Skip exactly input size.
do_test(10)
# Set -1 for 'count': skip the entire dataset.
do_test(-1)
# Skip nothing
do_test(0)
dataset = dataset_ops.Dataset.from_tensor_slices(components).skip(count)
self.assertEqual(
[c.shape[1:] for c in components],
[shape for shape in dataset_ops.get_legacy_output_shapes(dataset)])
start_range = min(count, 10) if count != -1 else 10
self.assertDatasetProduces(
dataset,
[tuple(components[0][i:i + 1]) for i in range(start_range, 10)])
if __name__ == "__main__":

View File

@ -17,40 +17,30 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from absl.testing import parameterized
import numpy as np
from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import test_util
from tensorflow.python.framework import combinations
from tensorflow.python.platform import test
@test_util.run_all_in_graph_and_eager_modes
class TakeTest(test_base.DatasetTestBase):
class TakeTest(test_base.DatasetTestBase, parameterized.TestCase):
def testTakeTensorDataset(self):
@combinations.generate(
combinations.times(test_base.default_test_combinations(),
combinations.combine(count=[-1, 0, 4, 10, 25])))
def testBasic(self, count):
components = (np.arange(10),)
dataset = dataset_ops.Dataset.from_tensor_slices(components).take(count)
self.assertEqual(
[c.shape[1:] for c in components],
[shape for shape in dataset_ops.get_legacy_output_shapes(dataset)])
num_output = min(count, 10) if count != -1 else 10
self.assertDatasetProduces(
dataset, [tuple(components[0][i:i + 1]) for i in range(num_output)])
def do_test(count):
dataset = dataset_ops.Dataset.from_tensor_slices(components).take(count)
self.assertEqual(
[c.shape[1:] for c in components],
[shape for shape in dataset_ops.get_legacy_output_shapes(dataset)])
num_output = min(count, 10) if count != -1 else 10
self.assertDatasetProduces(
dataset, [tuple(components[0][i:i + 1]) for i in range(num_output)])
# Take fewer than input size
do_test(4)
# Take more than input size
do_test(25)
# Take all of input
do_test(-1)
# Take nothing
do_test(0)
if __name__ == "__main__":
test.main()

View File

@ -58,7 +58,11 @@ class DatasetTestBase(test.TestCase):
def assertValuesEqual(self, expected, actual):
"""Asserts that two values are equal."""
if sparse_tensor.is_sparse(expected):
if isinstance(expected, dict):
self.assertItemsEqual(list(expected.keys()), list(actual.keys()))
for k in expected.keys():
self.assertValuesEqual(expected[k], actual[k])
elif sparse_tensor.is_sparse(expected):
self.assertAllEqual(expected.indices, actual.indices)
self.assertAllEqual(expected.values, actual.values)
self.assertAllEqual(expected.dense_shape, actual.dense_shape)

View File

@ -21,11 +21,12 @@ import gzip
import os
import zlib
from absl.testing import parameterized
from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.ops import readers
from tensorflow.python.eager import context
from tensorflow.python.framework import test_util
from tensorflow.python.framework import combinations
from tensorflow.python.platform import test
from tensorflow.python.util import compat
@ -37,8 +38,7 @@ except ImportError:
psutil_import_succeeded = False
@test_util.run_all_in_graph_and_eager_modes
class TextLineDatasetTest(test_base.DatasetTestBase):
class TextLineDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
def _lineText(self, f, l):
return compat.as_bytes("%d: %d" % (f, l))
@ -76,7 +76,11 @@ class TextLineDatasetTest(test_base.DatasetTestBase):
return filenames
def _testTextLineDataset(self, compression_type=None):
@combinations.generate(
combinations.times(
test_base.default_test_combinations(),
combinations.combine(compression_type=[None, "GZIP", "ZLIB"])))
def testTextLineDataset(self, compression_type):
test_filenames = self._createFiles(
2, 5, crlf=True, compression_type=compression_type)
@ -115,6 +119,7 @@ class TextLineDatasetTest(test_base.DatasetTestBase):
expected_output=[[self._lineText(0, i) for i in range(5)],
[self._lineText(1, i) for i in range(5)]] * 10)
@combinations.generate(test_base.default_test_combinations())
def testTextLineDatasetParallelRead(self):
test_filenames = self._createFiles(10, 10)
files = dataset_ops.Dataset.from_tensor_slices(test_filenames).repeat(10)
@ -125,15 +130,7 @@ class TextLineDatasetTest(test_base.DatasetTestBase):
self.assertDatasetProduces(
dataset, expected_output=expected_output * 10, assert_items_equal=True)
def testTextLineDatasetNoCompression(self):
self._testTextLineDataset()
def testTextLineDatasetGzipCompression(self):
self._testTextLineDataset(compression_type="GZIP")
def testTextLineDatasetZlibCompression(self):
self._testTextLineDataset(compression_type="ZLIB")
@combinations.generate(test_base.default_test_combinations())
def testTextLineDatasetBuffering(self):
test_filenames = self._createFiles(2, 5, crlf=True)
@ -143,33 +140,33 @@ class TextLineDatasetTest(test_base.DatasetTestBase):
expected_output.extend([self._lineText(j, i) for i in range(5)])
self.assertDatasetProduces(repeat_dataset, expected_output=expected_output)
@combinations.generate(test_base.eager_only_combinations())
def testIteratorResourceCleanup(self):
filename = os.path.join(self.get_temp_dir(), "text.txt")
with open(filename, "wt") as f:
for i in range(3):
f.write("%d\n" % (i,))
with context.eager_mode():
first_iterator = iter(readers.TextLineDataset(filename))
self.assertEqual(b"0", next(first_iterator).numpy())
second_iterator = iter(readers.TextLineDataset(filename))
self.assertEqual(b"0", next(second_iterator).numpy())
# Eager kernel caching is based on op attributes, which includes the
# Dataset's output shape. Create a different kernel to test that they
# don't create resources with the same names.
different_kernel_iterator = iter(
readers.TextLineDataset(filename).repeat().batch(16))
self.assertEqual([16], next(different_kernel_iterator).shape)
# Remove our references to the Python Iterator objects, which (assuming no
# reference cycles) is enough to trigger DestroyResourceOp and close the
# partially-read files.
del first_iterator
del second_iterator
del different_kernel_iterator
if not psutil_import_succeeded:
self.skipTest(
"psutil is required to check that we've closed our files.")
open_files = psutil.Process().open_files()
self.assertNotIn(filename, [open_file.path for open_file in open_files])
first_iterator = iter(readers.TextLineDataset(filename))
self.assertEqual(b"0", next(first_iterator).numpy())
second_iterator = iter(readers.TextLineDataset(filename))
self.assertEqual(b"0", next(second_iterator).numpy())
# Eager kernel caching is based on op attributes, which includes the
# Dataset's output shape. Create a different kernel to test that they
# don't create resources with the same names.
different_kernel_iterator = iter(
readers.TextLineDataset(filename).repeat().batch(16))
self.assertEqual([16], next(different_kernel_iterator).shape)
# Remove our references to the Python Iterator objects, which (assuming no
# reference cycles) is enough to trigger DestroyResourceOp and close the
# partially-read files.
del first_iterator
del second_iterator
del different_kernel_iterator
if not psutil_import_succeeded:
self.skipTest(
"psutil is required to check that we've closed our files.")
open_files = psutil.Process().open_files()
self.assertNotIn(filename, [open_file.path for open_file in open_files])
if __name__ == "__main__":

View File

@ -21,31 +21,31 @@ import gzip
import os
import zlib
from absl.testing import parameterized
from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.ops import readers
from tensorflow.python.framework import combinations
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import test_util
from tensorflow.python.lib.io import python_io
from tensorflow.python.platform import test
from tensorflow.python.util import compat
@test_util.run_all_in_graph_and_eager_modes
class TFRecordDatasetTest(test_base.DatasetTestBase):
class TFRecordDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
def setUp(self):
super(TFRecordDatasetTest, self).setUp()
self._num_files = 2
self._num_records = 7
self.test_filenames = self._createFiles()
def dataset_fn(self,
filenames,
compression_type="",
num_epochs=1,
batch_size=None):
def _dataset_factory(self,
filenames,
compression_type="",
num_epochs=1,
batch_size=None):
repeat_dataset = readers.TFRecordDataset(
filenames, compression_type).repeat(num_epochs)
@ -67,6 +67,7 @@ class TFRecordDatasetTest(test_base.DatasetTestBase):
writer.close()
return filenames
@combinations.generate(test_base.default_test_combinations())
def testTFRecordDatasetConstructorErrorsTensorInput(self):
with self.assertRaisesRegex(TypeError,
"filenames.*must be.*Tensor.*string"):
@ -78,37 +79,40 @@ class TFRecordDatasetTest(test_base.DatasetTestBase):
with self.assertRaises(Exception):
readers.TFRecordDataset(object())
@combinations.generate(test_base.default_test_combinations())
def testReadOneEpoch(self):
# Basic test: read from file 0.
dataset = self.dataset_fn(self.test_filenames[0])
dataset = self._dataset_factory(self.test_filenames[0])
self.assertDatasetProduces(
dataset,
expected_output=[self._record(0, i) for i in range(self._num_records)])
# Basic test: read from file 1.
dataset = self.dataset_fn(self.test_filenames[1])
dataset = self._dataset_factory(self.test_filenames[1])
self.assertDatasetProduces(
dataset,
expected_output=[self._record(1, i) for i in range(self._num_records)])
# Basic test: read from both files.
dataset = self.dataset_fn(self.test_filenames)
dataset = self._dataset_factory(self.test_filenames)
expected_output = []
for j in range(self._num_files):
expected_output.extend(
[self._record(j, i) for i in range(self._num_records)])
self.assertDatasetProduces(dataset, expected_output=expected_output)
@combinations.generate(test_base.default_test_combinations())
def testReadTenEpochs(self):
dataset = self.dataset_fn(self.test_filenames, num_epochs=10)
dataset = self._dataset_factory(self.test_filenames, num_epochs=10)
expected_output = []
for j in range(self._num_files):
expected_output.extend(
[self._record(j, i) for i in range(self._num_records)])
self.assertDatasetProduces(dataset, expected_output=expected_output * 10)
@combinations.generate(test_base.default_test_combinations())
def testReadTenEpochsOfBatches(self):
dataset = self.dataset_fn(
dataset = self._dataset_factory(
self.test_filenames, num_epochs=10, batch_size=self._num_records)
expected_output = []
for j in range(self._num_files):
@ -116,6 +120,7 @@ class TFRecordDatasetTest(test_base.DatasetTestBase):
[self._record(j, i) for i in range(self._num_records)])
self.assertDatasetProduces(dataset, expected_output=expected_output * 10)
@combinations.generate(test_base.default_test_combinations())
def testReadZlibFiles(self):
zlib_files = []
for i, fn in enumerate(self.test_filenames):
@ -130,9 +135,10 @@ class TFRecordDatasetTest(test_base.DatasetTestBase):
for j in range(self._num_files):
expected_output.extend(
[self._record(j, i) for i in range(self._num_records)])
dataset = self.dataset_fn(zlib_files, compression_type="ZLIB")
dataset = self._dataset_factory(zlib_files, compression_type="ZLIB")
self.assertDatasetProduces(dataset, expected_output=expected_output)
@combinations.generate(test_base.default_test_combinations())
def testReadGzipFiles(self):
gzip_files = []
for i, fn in enumerate(self.test_filenames):
@ -145,9 +151,10 @@ class TFRecordDatasetTest(test_base.DatasetTestBase):
for j in range(self._num_files):
expected_output.extend(
[self._record(j, i) for i in range(self._num_records)])
dataset = self.dataset_fn(gzip_files, compression_type="GZIP")
dataset = self._dataset_factory(gzip_files, compression_type="GZIP")
self.assertDatasetProduces(dataset, expected_output=expected_output)
@combinations.generate(test_base.default_test_combinations())
def testReadWithBuffer(self):
one_mebibyte = 2**20
dataset = readers.TFRecordDataset(
@ -158,6 +165,7 @@ class TFRecordDatasetTest(test_base.DatasetTestBase):
[self._record(j, i) for i in range(self._num_records)])
self.assertDatasetProduces(dataset, expected_output=expected_output)
@combinations.generate(test_base.default_test_combinations())
def testReadFromDatasetOfFiles(self):
files = dataset_ops.Dataset.from_tensor_slices(self.test_filenames)
expected_output = []
@ -167,6 +175,7 @@ class TFRecordDatasetTest(test_base.DatasetTestBase):
dataset = readers.TFRecordDataset(files)
self.assertDatasetProduces(dataset, expected_output=expected_output)
@combinations.generate(test_base.default_test_combinations())
def testReadTenEpochsFromDatasetOfFilesInParallel(self):
files = dataset_ops.Dataset.from_tensor_slices(
self.test_filenames).repeat(10)

View File

@ -23,11 +23,11 @@ import numpy as np
from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import combinations
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import string_ops
@ -36,13 +36,14 @@ from tensorflow.python.platform import test
from tensorflow.python.util import compat
@test_util.run_all_in_graph_and_eager_modes
class UnbatchTest(test_base.DatasetTestBase, parameterized.TestCase):
@combinations.generate(test_base.default_test_combinations())
def testUnbatchWithUnknownRankInput(self):
dataset = dataset_ops.Dataset.from_tensors([0, 1, 2, 3]).unbatch()
self.assertDatasetProduces(dataset, range(4))
@combinations.generate(test_base.default_test_combinations())
def testUnbatchScalarDataset(self):
data = tuple([math_ops.range(10) for _ in range(3)])
data = dataset_ops.Dataset.from_tensor_slices(data)
@ -54,12 +55,14 @@ class UnbatchTest(test_base.DatasetTestBase, parameterized.TestCase):
self.assertDatasetProduces(data, [(i,) * 3 for i in range(10)])
@combinations.generate(test_base.default_test_combinations())
def testUnbatchNestedDataset(self):
data = dataset_ops.Dataset.from_tensors(
[dataset_ops.Dataset.range(10) for _ in range(10)])
data = data.unbatch().flat_map(lambda x: x)
self.assertDatasetProduces(data, list(range(10)) * 10)
@combinations.generate(test_base.default_test_combinations())
def testUnbatchDatasetWithStrings(self):
data = tuple([math_ops.range(10) for _ in range(3)])
data = dataset_ops.Dataset.from_tensor_slices(data)
@ -73,6 +76,7 @@ class UnbatchTest(test_base.DatasetTestBase, parameterized.TestCase):
self.assertDatasetProduces(
data, [(i, compat.as_bytes(str(i)), i) for i in range(10)])
@combinations.generate(test_base.default_test_combinations())
def testUnbatchDatasetWithSparseTensor(self):
st = sparse_tensor.SparseTensorValue(
indices=[[i, i] for i in range(10)],
@ -87,6 +91,7 @@ class UnbatchTest(test_base.DatasetTestBase, parameterized.TestCase):
]
self.assertDatasetProduces(data, expected_output=expected_output)
@combinations.generate(test_base.default_test_combinations())
def testUnbatchDatasetWithDenseSparseAndRaggedTensor(self):
st = sparse_tensor.SparseTensorValue(
indices=[[i, i] for i in range(10)],
@ -104,6 +109,7 @@ class UnbatchTest(test_base.DatasetTestBase, parameterized.TestCase):
self.assertDatasetProduces(
data, expected_output=expected_output)
@combinations.generate(test_base.default_test_combinations())
def testUnbatchDatasetWithRaggedTensor(self):
rt = ragged_factory_ops.constant_value([[[0]], [[1]], [[2]], [[3]], [[4]],
[[5]], [[6]], [[7]], [[8]], [[9]]])
@ -119,6 +125,7 @@ class UnbatchTest(test_base.DatasetTestBase, parameterized.TestCase):
self.assertDatasetProduces(
data, expected_output=expected_output)
@combinations.generate(test_base.default_test_combinations())
def testUnbatchSingleElementTupleDataset(self):
data = tuple([(math_ops.range(10),) for _ in range(3)])
data = dataset_ops.Dataset.from_tensor_slices(data)
@ -130,6 +137,7 @@ class UnbatchTest(test_base.DatasetTestBase, parameterized.TestCase):
self.assertDatasetProduces(data, [((i,),) * 3 for i in range(10)])
@combinations.generate(test_base.default_test_combinations())
def testUnbatchMultiElementTupleDataset(self):
data = tuple([(math_ops.range(10 * i, 10 * i + 10),
array_ops.fill([10], "hi")) for i in range(3)])
@ -146,6 +154,7 @@ class UnbatchTest(test_base.DatasetTestBase, parameterized.TestCase):
data,
[((i, b"hi"), (10 + i, b"hi"), (20 + i, b"hi")) for i in range(10)])
@combinations.generate(test_base.default_test_combinations())
def testUnbatchEmpty(self):
data = dataset_ops.Dataset.from_tensors(
(constant_op.constant([]), constant_op.constant([], shape=[0, 4]),
@ -153,15 +162,15 @@ class UnbatchTest(test_base.DatasetTestBase, parameterized.TestCase):
data = data.unbatch()
self.assertDatasetProduces(data, [])
@combinations.generate(test_base.default_test_combinations())
def testUnbatchStaticShapeMismatch(self):
data = dataset_ops.Dataset.from_tensors((np.arange(7), np.arange(8),
np.arange(9)))
with self.assertRaises(ValueError):
data.unbatch()
# Note: dynamic shape mismatch is graph specific test.
@test_util.run_deprecated_v1
def testSkipEagerUnbatchDynamicShapeMismatch(self):
@combinations.generate(test_base.graph_only_combinations())
def testUnbatchDynamicShapeMismatch(self):
ph1 = array_ops.placeholder(dtypes.int32, shape=[None])
ph2 = array_ops.placeholder(dtypes.int32, shape=None)
data = dataset_ops.Dataset.from_tensors((ph1, ph2))
@ -190,6 +199,7 @@ class UnbatchTest(test_base.DatasetTestBase, parameterized.TestCase):
with self.assertRaises(errors.InvalidArgumentError):
self.evaluate(next_element)
@combinations.generate(test_base.default_test_combinations())
def testUnbatchDatasetWithUintDtypes(self):
components = (
np.tile(np.array([[0], [1], [2], [3]], dtype=np.uint8), 2),

View File

@ -24,43 +24,32 @@ from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.util import nest
from tensorflow.python.eager import context
from tensorflow.python.framework import combinations
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.platform import test
@test_util.run_all_in_graph_and_eager_modes
class WindowTest(test_base.DatasetTestBase, parameterized.TestCase):
@parameterized.named_parameters(
("1", 20, 14, 7, 1),
("2", 20, 17, 9, 1),
("3", 20, 14, 14, 1),
("4", 20, 10, 14, 1),
("5", 20, 14, 19, 1),
("6", 20, 4, 1, 2),
("7", 20, 2, 1, 6),
("8", 20, 4, 7, 2),
("9", 20, 2, 7, 6),
("10", 1, 10, 4, 1),
("11", 0, 10, 4, 1),
("12", 20, 14, 7, 1, False),
("13", 20, 17, 9, 1, False),
("14", 20, 14, 14, 1, False),
("15", 20, 10, 14, 1, False),
("16", 20, 14, 19, 1, False),
("17", 20, 4, 1, 2, False),
("18", 20, 2, 1, 6, False),
("19", 20, 4, 7, 2, False),
("20", 20, 2, 7, 6, False),
("21", 1, 10, 4, 1, False),
("22", 0, 10, 4, 1, False),
)
def testWindowDataset(self, count, size, shift, stride, drop_remainder=True):
@combinations.generate(
combinations.times(
test_base.default_test_combinations(),
combinations.combine(
count=20,
size=[10, 14, 17],
shift=[7, 14],
stride=[1, 2, 6],
drop_remainder=[True, False]) + combinations.combine(
count=[0, 1],
size=10,
shift=4,
stride=1,
drop_remainder=[True, False])))
def testWindowDataset(self, count, size, shift, stride, drop_remainder):
"""Tests a dataset that slides a window its input elements."""
components = (np.arange(7),
np.array([[1, 2, 3]]) * np.arange(7)[:, np.newaxis],
@ -111,11 +100,12 @@ class WindowTest(test_base.DatasetTestBase, parameterized.TestCase):
with self.assertRaises(errors.OutOfRangeError):
self.evaluate(get_next())
@parameterized.named_parameters(
("1", 14, 0, 3, 1),
("2", 14, 3, 0, 1),
("3", 14, 3, 3, 0),
)
@combinations.generate(
combinations.times(
test_base.default_test_combinations(),
combinations.combine(count=20, size=0, shift=3, stride=1) +
combinations.combine(count=20, size=3, shift=0, stride=1) +
combinations.combine(count=20, size=3, shift=3, stride=0)))
def testWindowDatasetInvalid(self, count, size, shift, stride):
with self.assertRaises(errors.InvalidArgumentError):
ds = dataset_ops.Dataset.range(10).map(lambda x: x).repeat(count).window(
@ -123,12 +113,14 @@ class WindowTest(test_base.DatasetTestBase, parameterized.TestCase):
stride=stride).flat_map(lambda x: x.batch(batch_size=size))
self.evaluate(ds._variant_tensor)
@combinations.generate(test_base.default_test_combinations())
def testWindowDifferentNestedStructures(self):
ds = dataset_ops.Dataset.from_tensor_slices(([1, 2], [3, 4])).window(2)
self.getNext(ds)
ds = dataset_ops.Dataset.from_tensor_slices({"a": [1, 2]}).window(2)
self.getNext(ds)
@combinations.generate(test_base.default_test_combinations())
def testWindowSparse(self):
def _sparse(i):
@ -148,6 +140,7 @@ class WindowTest(test_base.DatasetTestBase, parameterized.TestCase):
]
self.assertDatasetProduces(dataset, expected_output=expected_output)
@combinations.generate(test_base.default_test_combinations())
def testWindowSparseWithDifferentDenseShapes(self):
def _sparse(i):
@ -177,6 +170,7 @@ class WindowTest(test_base.DatasetTestBase, parameterized.TestCase):
dense_shape=[5, i * 3 + 5 - 1]))
self.assertDatasetProduces(dataset, expected_output=expected_output)
@combinations.generate(test_base.default_test_combinations())
def testNestedWindowSparse(self):
def _sparse(i):
@ -205,6 +199,7 @@ class WindowTest(test_base.DatasetTestBase, parameterized.TestCase):
]
self.assertDatasetProduces(dataset, expected_output=expected_output)
@combinations.generate(test_base.default_test_combinations())
def testWindowShapeError(self):
def generator():
@ -222,6 +217,7 @@ class WindowTest(test_base.DatasetTestBase, parameterized.TestCase):
r"Cannot batch tensors with different shapes in component 0. "
r"First element had shape \[3\] and element 2 had shape \[4\]."))
@combinations.generate(test_base.default_test_combinations())
def testWindowIgnoreErrors(self):
input_values = np.float32([1., np.nan, 2., np.nan, 3.])
dataset = dataset_ops.Dataset.from_tensor_slices(input_values).map(
@ -232,6 +228,7 @@ class WindowTest(test_base.DatasetTestBase, parameterized.TestCase):
dataset, expected_output=[np.float32([1., 2.]),
np.float32([2., 3.])])
@combinations.generate(test_base.default_test_combinations())
def testNestedOutput(self):
if not context.executing_eagerly():
self.skipTest("self.evaluate() does not work with a dataset")

View File

@ -17,66 +17,68 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from absl.testing import parameterized
import numpy as np
from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import combinations
from tensorflow.python.framework import errors
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import test_util
from tensorflow.python.platform import test
@test_util.run_all_in_graph_and_eager_modes
class ZipTest(test_base.DatasetTestBase):
def _dataset_factory(components):
datasets = tuple([
dataset_ops.Dataset.from_tensor_slices(component)
for component in components
])
return dataset_ops.Dataset.zip(datasets)
def testZipDataset(self):
def dataset_fn(components):
datasets = tuple([
dataset_ops.Dataset.from_tensor_slices(component)
for component in components
])
return dataset_ops.Dataset.zip(datasets)
class ZipTest(test_base.DatasetTestBase, parameterized.TestCase):
equal_length_components = [
@combinations.generate(test_base.default_test_combinations())
def testZipEqual(self):
components = [
np.tile(np.array([[1], [2], [3], [4]]), 20),
np.tile(np.array([[12], [13], [14], [15]]), 22),
np.array([37.0, 38.0, 39.0, 40.0])
]
get_next = self.getNext(dataset_fn(equal_length_components))
get_next = self.getNext(_dataset_factory(components))
for i in range(4):
results = self.evaluate(get_next())
for component, result_component in zip(equal_length_components, results):
for component, result_component in zip(components, results):
self.assertAllEqual(component[i], result_component)
with self.assertRaises(errors.OutOfRangeError):
self.evaluate(get_next())
with self.assertRaises(errors.OutOfRangeError):
self.evaluate(get_next())
variable_length_components = [[1, 2, 3, 4], [1, 2, 3, 4, 5], [1.0, 2.0]]
get_next = self.getNext(dataset_fn(variable_length_components))
@combinations.generate(test_base.default_test_combinations())
def testZipUnequal(self):
components = [[1, 2, 3, 4], [1, 2, 3, 4, 5], [1.0, 2.0]]
get_next = self.getNext(_dataset_factory(components))
for i in range(2):
results = self.evaluate(get_next())
for component, result_component in zip(variable_length_components,
results):
for component, result_component in zip(components, results):
self.assertAllEqual(component[i], result_component)
with self.assertRaises(errors.OutOfRangeError):
self.evaluate(get_next())
with self.assertRaises(errors.OutOfRangeError):
self.evaluate(get_next())
def testNestedZipDataset(self):
@combinations.generate(test_base.default_test_combinations())
def testNested(self):
equal_length_components = [
components = [
np.tile(np.array([[1], [2], [3], [4]]), 20),
np.tile(np.array([[12], [13], [14], [15]]), 22),
np.array([37.0, 38.0, 39.0, 40.0])
]
datasets = [
dataset_ops.Dataset.from_tensor_slices(component)
for component in equal_length_components
for component in components
]
dataset = dataset_ops.Dataset.zip((datasets[0], (datasets[1], datasets[2])))
@ -88,9 +90,9 @@ class ZipTest(test_base.DatasetTestBase):
get_next = self.getNext(dataset)
for i in range(4):
result1, (result2, result3) = self.evaluate(get_next())
self.assertAllEqual(equal_length_components[0][i], result1)
self.assertAllEqual(equal_length_components[1][i], result2)
self.assertAllEqual(equal_length_components[2][i], result3)
self.assertAllEqual(components[0][i], result1)
self.assertAllEqual(components[1][i], result2)
self.assertAllEqual(components[2][i], result3)
with self.assertRaises(errors.OutOfRangeError):
self.evaluate(get_next())
with self.assertRaises(errors.OutOfRangeError):

View File

@ -66,7 +66,6 @@ from tensorflow.python.ops import gen_io_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import script_ops
from tensorflow.python.ops import string_ops
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.training.tracking import base as tracking_base
from tensorflow.python.training.tracking import tracking
from tensorflow.python.util import deprecation
@ -2437,26 +2436,25 @@ class DatasetV1Adapter(DatasetV1):
def _ensure_same_dataset_graph(dataset):
"""Walks the dataset graph to ensure all datasets come from the same graph."""
# pylint: disable=protected-access
current_graph = ops.get_default_graph()
bfs_q = Queue.Queue()
bfs_q.put(dataset) # pylint: disable=protected-access
bfs_q.put(dataset)
visited = []
while not bfs_q.empty():
ds = bfs_q.get()
visited.append(ds)
ds_graph = ds._graph # pylint: disable=protected-access
ds_graph = ds._graph
if current_graph != ds_graph:
logging.warning("The graph (" + str(current_graph) + ") of the iterator "
"is different from the graph (" + str(ds_graph) + ") "
"the dataset: " + str(ds._variant_tensor) + " was " # pylint: disable=protected-access
"created in. If you are using the Estimator API, "
"make sure that no part of the dataset returned by the "
"`input_fn` function is defined outside the `input_fn` "
"function. Please ensure that all datasets in the "
"pipeline are created in the same graph as the iterator. "
"NOTE: This warning will become an error in future "
"versions of TensorFlow.")
for input_ds in ds._inputs(): # pylint: disable=protected-access
raise ValueError(
"The graph (" + str(current_graph) + ") of the iterator is different "
"from the graph (" + str(ds_graph) + ") the dataset: " +
str(ds._variant_tensor) + " was created in. If you are using the "
"Estimator API, make sure that no part of the dataset returned by "
"the `input_fn` function is defined outside the `input_fn` function. "
"Please ensure that all datasets in the pipeline are created in the "
"same graph as the iterator.")
for input_ds in ds._inputs():
if input_ds not in visited:
bfs_q.put(input_ds)