diff --git a/tensorflow/python/data/experimental/kernel_tests/bucket_by_sequence_length_test.py b/tensorflow/python/data/experimental/kernel_tests/bucket_by_sequence_length_test.py index d9c463d744d..d829863b994 100644 --- a/tensorflow/python/data/experimental/kernel_tests/bucket_by_sequence_length_test.py +++ b/tensorflow/python/data/experimental/kernel_tests/bucket_by_sequence_length_test.py @@ -25,11 +25,11 @@ from tensorflow.python.data.experimental.ops import grouping from tensorflow.python.data.kernel_tests import test_base from tensorflow.python.data.ops import dataset_ops from tensorflow.python.eager import context +from tensorflow.python.framework import combinations from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors from tensorflow.python.framework import sparse_tensor from tensorflow.python.framework import tensor_shape -from tensorflow.python.framework import test_util from tensorflow.python.ops import array_ops from tensorflow.python.platform import test @@ -73,14 +73,12 @@ def _get_record_shape(sparse): return tensor_shape.TensorShape([None]) -@test_util.run_all_in_graph_and_eager_modes class BucketBySequenceLengthTest(test_base.DatasetTestBase, parameterized.TestCase): - @parameterized.named_parameters( - ("WithoutPadding", True), - ("WithPadding", False), - ) + @combinations.generate( + combinations.times(test_base.default_test_combinations(), + combinations.combine(param_no_padding=[True, False]))) def testBucketDropReminder(self, param_no_padding): boundaries = [10, 20, 30] @@ -201,10 +199,9 @@ class BucketBySequenceLengthTest(test_base.DatasetTestBase, _test_bucket_by_padding(param_no_padding) - @parameterized.named_parameters( - ("WithoutPadding", True), - ("WithPadding", False), - ) + @combinations.generate( + combinations.times(test_base.default_test_combinations(), + combinations.combine(param_no_padding=[True, False]))) def testBucket(self, param_no_padding): boundaries = [10, 20, 30] @@ -347,10 +344,9 @@ class BucketBySequenceLengthTest(test_base.DatasetTestBase, self.assertAllEqual(batches[4], [[1, 1, 1, 1, 1, 1, 1, 1, 1, 0], [1, 1, 1, 1, 1, 1, 1, 1, 1, 1]]) - @parameterized.named_parameters( - ("WithoutPadding", True), - ("WithPadding", False), - ) + @combinations.generate( + combinations.times(test_base.default_test_combinations(), + combinations.combine(param_no_padding=[True, False]))) def testTupleElements(self, param_no_padding): def build_dataset(sparse): @@ -381,10 +377,10 @@ class BucketBySequenceLengthTest(test_base.DatasetTestBase, _test_tuple_elements_by_padding(param_no_padding) - @parameterized.named_parameters( - ("DoDropRemainder", True), - ("DoNotDropRemainder", False), - ) + @combinations.generate( + combinations.times( + test_base.default_test_combinations(), + combinations.combine(param_drop_remainder=[True, False]))) def testBucketSparse(self, param_drop_remainder): # pylint: disable=g-doc-args """Tests bucketing of sparse tensors (case where `no_padding` == True). diff --git a/tensorflow/python/data/experimental/kernel_tests/copy_to_device_test.py b/tensorflow/python/data/experimental/kernel_tests/copy_to_device_test.py index 36c61636798..2fa149fcbaa 100644 --- a/tensorflow/python/data/experimental/kernel_tests/copy_to_device_test.py +++ b/tensorflow/python/data/experimental/kernel_tests/copy_to_device_test.py @@ -17,6 +17,8 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +from absl.testing import parameterized + from tensorflow.core.protobuf import config_pb2 from tensorflow.python.compat import compat from tensorflow.python.data.experimental.ops import prefetching_ops @@ -24,6 +26,7 @@ from tensorflow.python.data.kernel_tests import test_base from tensorflow.python.data.ops import dataset_ops from tensorflow.python.data.ops import iterator_ops from tensorflow.python.data.util import structure +from tensorflow.python.framework import combinations from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors from tensorflow.python.framework import ops @@ -35,9 +38,9 @@ from tensorflow.python.util import compat as util_compat # TODO(b/117581999): add eager coverage when supported. -class CopyToDeviceTest(test_base.DatasetTestBase): +class CopyToDeviceTest(test_base.DatasetTestBase, parameterized.TestCase): - @test_util.deprecated_graph_mode_only + @combinations.generate(test_base.graph_only_combinations()) def testCopyToDevice(self): host_dataset = dataset_ops.Dataset.range(10) device_dataset = host_dataset.apply( @@ -62,7 +65,7 @@ class CopyToDeviceTest(test_base.DatasetTestBase): with self.assertRaises(errors.OutOfRangeError): self.evaluate(next_element) - @test_util.deprecated_graph_mode_only + @combinations.generate(test_base.graph_only_combinations()) def testCopyToDeviceInt32(self): host_dataset = dataset_ops.Dataset.from_tensors([0, 1, 2, 3]) device_dataset = host_dataset.apply( @@ -86,7 +89,7 @@ class CopyToDeviceTest(test_base.DatasetTestBase): with self.assertRaises(errors.OutOfRangeError): self.evaluate(next_element) - @test_util.deprecated_graph_mode_only + @combinations.generate(test_base.graph_only_combinations()) def testCopyToSameDevice(self): host_dataset = dataset_ops.Dataset.range(10) device_dataset = host_dataset.apply( @@ -111,7 +114,7 @@ class CopyToDeviceTest(test_base.DatasetTestBase): with self.assertRaises(errors.OutOfRangeError): self.evaluate(next_element) - @test_util.deprecated_graph_mode_only + @combinations.generate(test_base.graph_only_combinations()) def testCopyToDeviceWithPrefetch(self): host_dataset = dataset_ops.Dataset.range(10) device_dataset = host_dataset.apply( @@ -136,7 +139,7 @@ class CopyToDeviceTest(test_base.DatasetTestBase): with self.assertRaises(errors.OutOfRangeError): self.evaluate(next_element) - @test_util.deprecated_graph_mode_only + @combinations.generate(test_base.graph_only_combinations()) def testCopyDictToDevice(self): host_dataset = dataset_ops.Dataset.range(10).map(lambda x: {"a": x}) device_dataset = host_dataset.apply( @@ -161,7 +164,7 @@ class CopyToDeviceTest(test_base.DatasetTestBase): with self.assertRaises(errors.OutOfRangeError): self.evaluate(next_element) - @test_util.deprecated_graph_mode_only + @combinations.generate(test_base.graph_only_combinations()) def testCopyDictToDeviceWithPrefetch(self): host_dataset = dataset_ops.Dataset.range(10).map(lambda x: {"a": x}) device_dataset = host_dataset.apply( @@ -186,7 +189,7 @@ class CopyToDeviceTest(test_base.DatasetTestBase): with self.assertRaises(errors.OutOfRangeError): self.evaluate(next_element) - @test_util.deprecated_graph_mode_only + @combinations.generate(test_base.graph_only_combinations()) def testCopySparseTensorsToDevice(self): def make_tensor(i): @@ -219,7 +222,7 @@ class CopyToDeviceTest(test_base.DatasetTestBase): with self.assertRaises(errors.OutOfRangeError): self.evaluate(next_element) - @test_util.deprecated_graph_mode_only + @combinations.generate(test_base.graph_only_combinations()) def testCopySparseTensorsToDeviceWithPrefetch(self): def make_tensor(i): @@ -252,7 +255,7 @@ class CopyToDeviceTest(test_base.DatasetTestBase): with self.assertRaises(errors.OutOfRangeError): self.evaluate(next_element) - @test_util.deprecated_graph_mode_only + @combinations.generate(test_base.graph_only_combinations()) def testCopyToDeviceGpu(self): if not test_util.is_gpu_available(): self.skipTest("No GPU available") @@ -273,7 +276,7 @@ class CopyToDeviceTest(test_base.DatasetTestBase): with self.assertRaises(errors.OutOfRangeError): self.evaluate(next_element) - @test_util.deprecated_graph_mode_only + @combinations.generate(test_base.graph_only_combinations()) def testCopyToDeviceGpuWithPrefetch(self): if not test_util.is_gpu_available(): self.skipTest("No GPU available") @@ -294,7 +297,7 @@ class CopyToDeviceTest(test_base.DatasetTestBase): with self.assertRaises(errors.OutOfRangeError): self.evaluate(next_element) - @test_util.deprecated_graph_mode_only + @combinations.generate(test_base.graph_only_combinations()) def testCopyToDeviceGpuWithMap(self): if not test_util.is_gpu_available(): self.skipTest("No GPU available") @@ -332,7 +335,7 @@ class CopyToDeviceTest(test_base.DatasetTestBase): with self.assertRaises(errors.OutOfRangeError): self.evaluate(next_element) - @test_util.deprecated_graph_mode_only + @combinations.generate(test_base.graph_only_combinations()) def testCopyToDeviceGpuInt32(self): if not test_util.is_gpu_available(): self.skipTest("No GPU available") @@ -352,7 +355,7 @@ class CopyToDeviceTest(test_base.DatasetTestBase): with self.assertRaises(errors.OutOfRangeError): self.evaluate(next_element) - @test_util.deprecated_graph_mode_only + @combinations.generate(test_base.graph_only_combinations()) def testCopyToDeviceGpuInt32AndPrefetch(self): if not test_util.is_gpu_available(): self.skipTest("No GPU available") @@ -372,7 +375,7 @@ class CopyToDeviceTest(test_base.DatasetTestBase): with self.assertRaises(errors.OutOfRangeError): self.evaluate(next_element) - @test_util.deprecated_graph_mode_only + @combinations.generate(test_base.graph_only_combinations()) def testCopyToDeviceGpuStrings(self): if not test_util.is_gpu_available(): self.skipTest("No GPU available") @@ -392,7 +395,7 @@ class CopyToDeviceTest(test_base.DatasetTestBase): with self.assertRaises(errors.OutOfRangeError): self.evaluate(next_element) - @test_util.deprecated_graph_mode_only + @combinations.generate(test_base.graph_only_combinations()) def testCopyToDeviceGpuStringsAndPrefetch(self): if not test_util.is_gpu_available(): self.skipTest("No GPU available") @@ -412,7 +415,7 @@ class CopyToDeviceTest(test_base.DatasetTestBase): with self.assertRaises(errors.OutOfRangeError): self.evaluate(next_element) - @test_util.deprecated_graph_mode_only + @combinations.generate(test_base.graph_only_combinations()) def testCopyToDevicePingPongCPUGPU(self): if not test_util.is_gpu_available(): self.skipTest("No GPU available") @@ -436,7 +439,7 @@ class CopyToDeviceTest(test_base.DatasetTestBase): with self.assertRaises(errors.OutOfRangeError): self.evaluate(next_element) - @test_util.deprecated_graph_mode_only + @combinations.generate(test_base.graph_only_combinations()) def testCopyToDeviceWithReInit(self): host_dataset = dataset_ops.Dataset.range(10) device_dataset = host_dataset.apply( @@ -465,7 +468,7 @@ class CopyToDeviceTest(test_base.DatasetTestBase): with self.assertRaises(errors.OutOfRangeError): self.evaluate(next_element) - @test_util.deprecated_graph_mode_only + @combinations.generate(test_base.graph_only_combinations()) def testCopyToDeviceWithReInitAndPrefetch(self): host_dataset = dataset_ops.Dataset.range(10) device_dataset = host_dataset.apply( @@ -494,7 +497,7 @@ class CopyToDeviceTest(test_base.DatasetTestBase): with self.assertRaises(errors.OutOfRangeError): self.evaluate(next_element) - @test_util.deprecated_graph_mode_only + @combinations.generate(test_base.graph_only_combinations()) def testCopyToDeviceGpuWithReInit(self): if not test_util.is_gpu_available(): self.skipTest("No GPU available") @@ -518,7 +521,7 @@ class CopyToDeviceTest(test_base.DatasetTestBase): with self.assertRaises(errors.OutOfRangeError): self.evaluate(next_element) - @test_util.deprecated_graph_mode_only + @combinations.generate(test_base.graph_only_combinations()) def testCopyToDeviceGpuWithReInitAndPrefetch(self): if not test_util.is_gpu_available(): self.skipTest("No GPU available") @@ -542,7 +545,7 @@ class CopyToDeviceTest(test_base.DatasetTestBase): with self.assertRaises(errors.OutOfRangeError): self.evaluate(next_element) - @test_util.deprecated_graph_mode_only + @combinations.generate(test_base.graph_only_combinations()) def testIteratorGetNextAsOptionalOnGPU(self): if not test_util.is_gpu_available(): self.skipTest("No GPU available") diff --git a/tensorflow/python/data/experimental/kernel_tests/counter_test.py b/tensorflow/python/data/experimental/kernel_tests/counter_test.py index 79e4523ea43..455e49aafc7 100644 --- a/tensorflow/python/data/experimental/kernel_tests/counter_test.py +++ b/tensorflow/python/data/experimental/kernel_tests/counter_test.py @@ -17,35 +17,33 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +from absl.testing import parameterized + from tensorflow.python.data.experimental.ops import counter from tensorflow.python.data.kernel_tests import test_base from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.framework import combinations from tensorflow.python.framework import dtypes -from tensorflow.python.framework import test_util from tensorflow.python.platform import test -@test_util.run_all_in_graph_and_eager_modes -class CounterTest(test_base.DatasetTestBase): +class CounterTest(test_base.DatasetTestBase, parameterized.TestCase): - def testCounter(self): + @combinations.generate( + combinations.times( + test_base.default_test_combinations(), + combinations.combine(start=3, step=4, expected_output=[[3, 7, 11]]) + + combinations.combine(start=0, step=-1, expected_output=[[0, -1, -2]])) + ) + def testCounter(self, start, step, expected_output): """Test dataset construction using `count`.""" - dataset = counter.Counter(start=3, step=4) + dataset = counter.Counter(start, step) self.assertEqual( [], dataset_ops.get_legacy_output_shapes(dataset).as_list()) self.assertEqual(dtypes.int64, dataset_ops.get_legacy_output_types(dataset)) get_next = self.getNext(dataset) - - negative_dataset = counter.Counter(start=0, step=-1) - negative_get_next = self.getNext(negative_dataset) - - self.assertEqual(3, self.evaluate(get_next())) - self.assertEqual(3 + 4, self.evaluate(get_next())) - self.assertEqual(3 + 2 * 4, self.evaluate(get_next())) - - self.assertEqual(0, self.evaluate(negative_get_next())) - self.assertEqual(-1, self.evaluate(negative_get_next())) - self.assertEqual(-2, self.evaluate(negative_get_next())) + for expected in expected_output: + self.assertEqual(expected, self.evaluate(get_next())) if __name__ == "__main__": diff --git a/tensorflow/python/data/experimental/kernel_tests/csv_dataset_test.py b/tensorflow/python/data/experimental/kernel_tests/csv_dataset_test.py index 4b349ebd811..941ca209848 100644 --- a/tensorflow/python/data/experimental/kernel_tests/csv_dataset_test.py +++ b/tensorflow/python/data/experimental/kernel_tests/csv_dataset_test.py @@ -22,21 +22,22 @@ import gzip import os import zlib +from absl.testing import parameterized + from tensorflow.python.data.experimental.ops import error_ops from tensorflow.python.data.experimental.ops import readers from tensorflow.python.data.kernel_tests import test_base from tensorflow.python.data.ops import readers as core_readers from tensorflow.python.eager import context +from tensorflow.python.framework import combinations from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors -from tensorflow.python.framework import test_util from tensorflow.python.ops import parsing_ops from tensorflow.python.platform import test -@test_util.run_all_in_graph_and_eager_modes -class CsvDatasetTest(test_base.DatasetTestBase): +class CsvDatasetTest(test_base.DatasetTestBase, parameterized.TestCase): def _setup_files(self, inputs, linebreak='\n', compression_type=None): filenames = [] @@ -117,26 +118,31 @@ class CsvDatasetTest(test_base.DatasetTestBase): dataset = readers.CsvDataset(filenames, **kwargs) self._verify_output_or_err(dataset, expected_output, expected_err_re) + @combinations.generate(test_base.default_test_combinations()) def testCsvDataset_requiredFields(self): record_defaults = [[]] * 4 inputs = [['1,2,3,4']] self._test_by_comparison(inputs, record_defaults=record_defaults) + @combinations.generate(test_base.default_test_combinations()) def testCsvDataset_int(self): record_defaults = [[0]] * 4 inputs = [['1,2,3,4', '5,6,7,8']] self._test_by_comparison(inputs, record_defaults=record_defaults) + @combinations.generate(test_base.default_test_combinations()) def testCsvDataset_float(self): record_defaults = [[0.0]] * 4 inputs = [['1.0,2.1,3.2,4.3', '5.4,6.5,7.6,8.7']] self._test_by_comparison(inputs, record_defaults=record_defaults) + @combinations.generate(test_base.default_test_combinations()) def testCsvDataset_string(self): record_defaults = [['']] * 4 inputs = [['1.0,2.1,hello,4.3', '5.4,6.5,goodbye,8.7']] self._test_by_comparison(inputs, record_defaults=record_defaults) + @combinations.generate(test_base.default_test_combinations()) def testCsvDataset_withEmptyFields(self): record_defaults = [[0]] * 4 inputs = [[',,,', '1,1,1,', ',2,2,2']] @@ -144,6 +150,7 @@ class CsvDatasetTest(test_base.DatasetTestBase): inputs, [[0, 0, 0, 0], [1, 1, 1, 0], [0, 2, 2, 2]], record_defaults=record_defaults) + @combinations.generate(test_base.default_test_combinations()) def testCsvDataset_errWithUnquotedQuotes(self): record_defaults = [['']] * 3 inputs = [['1,2"3,4']] @@ -152,6 +159,7 @@ class CsvDatasetTest(test_base.DatasetTestBase): expected_err_re='Unquoted fields cannot have quotes inside', record_defaults=record_defaults) + @combinations.generate(test_base.default_test_combinations()) def testCsvDataset_errWithUnescapedQuotes(self): record_defaults = [['']] * 3 inputs = [['"a"b","c","d"']] @@ -161,6 +169,7 @@ class CsvDatasetTest(test_base.DatasetTestBase): 'Quote inside a string has to be escaped by another quote', record_defaults=record_defaults) + @combinations.generate(test_base.default_test_combinations()) def testCsvDataset_ignoreErrWithUnescapedQuotes(self): record_defaults = [['']] * 3 inputs = [['1,"2"3",4', '1,"2"3",4",5,5', 'a,b,"c"d"', 'e,f,g']] @@ -169,6 +178,7 @@ class CsvDatasetTest(test_base.DatasetTestBase): dataset = dataset.apply(error_ops.ignore_errors()) self._verify_output_or_err(dataset, [['e', 'f', 'g']]) + @combinations.generate(test_base.default_test_combinations()) def testCsvDataset_ignoreErrWithUnquotedQuotes(self): record_defaults = [['']] * 3 inputs = [['1,2"3,4', 'a,b,c"d', '9,8"7,6,5', 'e,f,g']] @@ -177,12 +187,14 @@ class CsvDatasetTest(test_base.DatasetTestBase): dataset = dataset.apply(error_ops.ignore_errors()) self._verify_output_or_err(dataset, [['e', 'f', 'g']]) + @combinations.generate(test_base.default_test_combinations()) def testCsvDataset_withNoQuoteDelimAndUnquotedQuotes(self): record_defaults = [['']] * 3 inputs = [['1,2"3,4']] self._test_by_comparison( inputs, record_defaults=record_defaults, use_quote_delim=False) + @combinations.generate(test_base.default_test_combinations()) def testCsvDataset_mixedTypes(self): record_defaults = [ constant_op.constant([], dtype=dtypes.int32), @@ -193,30 +205,35 @@ class CsvDatasetTest(test_base.DatasetTestBase): inputs = [['1,2.1,3.2,4.3', '5,6.5,7.6,8.7']] self._test_by_comparison(inputs, record_defaults=record_defaults) + @combinations.generate(test_base.default_test_combinations()) def testCsvDataset_withUseQuoteDelimFalse(self): record_defaults = [['']] * 4 inputs = [['1,2,"3,4"', '"5,6",7,8']] self._test_by_comparison( inputs, record_defaults=record_defaults, use_quote_delim=False) + @combinations.generate(test_base.default_test_combinations()) def testCsvDataset_withFieldDelim(self): record_defaults = [[0]] * 4 inputs = [['1:2:3:4', '5:6:7:8']] self._test_by_comparison( inputs, record_defaults=record_defaults, field_delim=':') + @combinations.generate(test_base.default_test_combinations()) def testCsvDataset_withNaValue(self): record_defaults = [[0]] * 4 inputs = [['1,NA,3,4', 'NA,6,7,8']] self._test_by_comparison( inputs, record_defaults=record_defaults, na_value='NA') + @combinations.generate(test_base.default_test_combinations()) def testCsvDataset_withSelectCols(self): record_defaults = [['']] * 2 inputs = [['1,2,3,4', '"5","6","7","8"']] self._test_by_comparison( inputs, record_defaults=record_defaults, select_cols=[1, 2]) + @combinations.generate(test_base.default_test_combinations()) def testCsvDataset_withSelectColsTooHigh(self): record_defaults = [[0]] * 2 inputs = [['1,2,3,4', '5,6,7,8']] @@ -226,23 +243,27 @@ class CsvDatasetTest(test_base.DatasetTestBase): record_defaults=record_defaults, select_cols=[3, 4]) + @combinations.generate(test_base.default_test_combinations()) def testCsvDataset_withOneCol(self): record_defaults = [['NA']] inputs = [['0', '', '2']] self._test_dataset( inputs, [['0'], ['NA'], ['2']], record_defaults=record_defaults) + @combinations.generate(test_base.default_test_combinations()) def testCsvDataset_withMultipleFiles(self): record_defaults = [[0]] * 4 inputs = [['1,2,3,4', '5,6,7,8'], ['5,6,7,8']] self._test_by_comparison(inputs, record_defaults=record_defaults) + @combinations.generate(test_base.default_test_combinations()) def testCsvDataset_withLeadingAndTrailingSpaces(self): record_defaults = [[0.0]] * 4 inputs = [['0, 1, 2, 3']] expected = [[0.0, 1.0, 2.0, 3.0]] self._test_dataset(inputs, expected, record_defaults=record_defaults) + @combinations.generate(test_base.default_test_combinations()) def testCsvDataset_errorWithMissingDefault(self): record_defaults = [[]] * 2 inputs = [['0,']] @@ -251,6 +272,7 @@ class CsvDatasetTest(test_base.DatasetTestBase): expected_err_re='Field 1 is required but missing in record!', record_defaults=record_defaults) + @combinations.generate(test_base.default_test_combinations()) def testCsvDataset_errorWithFewerDefaultsThanFields(self): record_defaults = [[0.0]] * 2 inputs = [['0,1,2,3']] @@ -259,6 +281,7 @@ class CsvDatasetTest(test_base.DatasetTestBase): expected_err_re='Expect 2 fields but have more in record', record_defaults=record_defaults) + @combinations.generate(test_base.default_test_combinations()) def testCsvDataset_errorWithMoreDefaultsThanFields(self): record_defaults = [[0.0]] * 5 inputs = [['0,1,2,3']] @@ -267,6 +290,7 @@ class CsvDatasetTest(test_base.DatasetTestBase): expected_err_re='Expect 5 fields but have 4 in record', record_defaults=record_defaults) + @combinations.generate(test_base.default_test_combinations()) def testCsvDataset_withHeader(self): record_defaults = [[0]] * 2 inputs = [['col1,col2', '1,2']] @@ -278,6 +302,7 @@ class CsvDatasetTest(test_base.DatasetTestBase): header=True, ) + @combinations.generate(test_base.default_test_combinations()) def testCsvDataset_withHeaderAndNoRecords(self): record_defaults = [[0]] * 2 inputs = [['col1,col2']] @@ -289,6 +314,7 @@ class CsvDatasetTest(test_base.DatasetTestBase): header=True, ) + @combinations.generate(test_base.default_test_combinations()) def testCsvDataset_errorWithHeaderEmptyFile(self): record_defaults = [[0]] * 2 inputs = [[]] @@ -300,12 +326,14 @@ class CsvDatasetTest(test_base.DatasetTestBase): header=True, ) + @combinations.generate(test_base.default_test_combinations()) def testCsvDataset_withEmptyFile(self): record_defaults = [['']] * 2 inputs = [['']] # Empty file self._test_dataset( inputs, expected_output=[], record_defaults=record_defaults) + @combinations.generate(test_base.default_test_combinations()) def testCsvDataset_errorWithEmptyRecord(self): record_defaults = [['']] * 2 inputs = [['', '1,2']] # First record is empty @@ -314,6 +342,7 @@ class CsvDatasetTest(test_base.DatasetTestBase): expected_err_re='Expect 2 fields but have 1 in record', record_defaults=record_defaults) + @combinations.generate(test_base.default_test_combinations()) def testCsvDataset_withChainedOps(self): # Testing that one dataset can create multiple iterators fine. # `repeat` creates multiple iterators from the same C++ Dataset. @@ -325,6 +354,7 @@ class CsvDatasetTest(test_base.DatasetTestBase): ds_actual.repeat(5).prefetch(1), ds_expected.repeat(5).prefetch(1)) + @combinations.generate(test_base.default_test_combinations()) def testCsvDataset_withTypeDefaults(self): # Testing using dtypes as record_defaults for required fields record_defaults = [dtypes.float32, [0.0]] @@ -335,6 +365,7 @@ class CsvDatasetTest(test_base.DatasetTestBase): record_defaults=record_defaults, ) + @combinations.generate(test_base.default_test_combinations()) def testMakeCsvDataset_fieldOrder(self): data = [[ '1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19', @@ -352,6 +383,7 @@ class CsvDatasetTest(test_base.DatasetTestBase): ## The following tests exercise parsing logic for quoted fields + @combinations.generate(test_base.default_test_combinations()) def testCsvDataset_withQuoted(self): record_defaults = [['']] * 4 inputs = [['"a","b","c :)","d"', '"e","f","g :(","h"']] @@ -363,6 +395,7 @@ class CsvDatasetTest(test_base.DatasetTestBase): self._test_dataset( inputs, [['0'], ['1'], ['2']], record_defaults=record_defaults) + @combinations.generate(test_base.default_test_combinations()) def testCsvDataset_withNewLine(self): # In this case, we expect it to behave differently from # TextLineDataset->map(decode_csv) since that flow has bugs @@ -371,6 +404,7 @@ class CsvDatasetTest(test_base.DatasetTestBase): expected = [['a', 'b', '"c"\n0', 'd\ne'], ['f', 'g', 'h', 'i']] self._test_dataset(inputs, expected, record_defaults=record_defaults) + @combinations.generate(test_base.default_test_combinations()) def testCsvDataset_withNewLineInUnselectedCol(self): record_defaults = [['']] inputs = [['1,"2\n3",4', '5,6,7']] @@ -380,6 +414,7 @@ class CsvDatasetTest(test_base.DatasetTestBase): record_defaults=record_defaults, select_cols=[0]) + @combinations.generate(test_base.default_test_combinations()) def testCsvDataset_withMultipleNewLines(self): # In this case, we expect it to behave differently from # TextLineDataset->map(decode_csv) since that flow has bugs @@ -388,6 +423,7 @@ class CsvDatasetTest(test_base.DatasetTestBase): expected = [['a', 'b\n\nx', '"c"\n \n0', 'd\ne'], ['f', 'g', 'h', 'i']] self._test_dataset(inputs, expected, record_defaults=record_defaults) + @combinations.generate(test_base.default_test_combinations()) def testCsvDataset_errorWithTerminateMidRecord(self): record_defaults = [['']] * 4 inputs = [['a,b,c,"a']] @@ -397,6 +433,7 @@ class CsvDatasetTest(test_base.DatasetTestBase): 'Reached end of file without closing quoted field in record', record_defaults=record_defaults) + @combinations.generate(test_base.default_test_combinations()) def testCsvDataset_withEscapedQuotes(self): record_defaults = [['']] * 4 inputs = [['1.0,2.1,"she said: ""hello""",4.3', '5.4,6.5,goodbye,8.7']] @@ -406,6 +443,7 @@ class CsvDatasetTest(test_base.DatasetTestBase): ## Testing that parsing works with all buffer sizes, quoted/unquoted fields, ## and different types of line breaks + @combinations.generate(test_base.default_test_combinations()) def testCsvDataset_withInvalidBufferSize(self): record_defaults = [['']] * 4 inputs = [['a,b,c,d']] @@ -432,6 +470,7 @@ class CsvDatasetTest(test_base.DatasetTestBase): record_defaults=record_defaults, buffer_size=i) + @combinations.generate(test_base.default_test_combinations()) def testCsvDataset_withLF(self): record_defaults = [['NA']] * 3 inputs = [['abc,def,ghi', '0,1,2', ',,']] @@ -439,6 +478,7 @@ class CsvDatasetTest(test_base.DatasetTestBase): self._test_dataset_on_buffer_sizes( inputs, expected, linebreak='\n', record_defaults=record_defaults) + @combinations.generate(test_base.default_test_combinations()) def testCsvDataset_withCR(self): # Test that when the line separator is '\r', parsing works with all buffer # sizes @@ -448,6 +488,7 @@ class CsvDatasetTest(test_base.DatasetTestBase): self._test_dataset_on_buffer_sizes( inputs, expected, linebreak='\r', record_defaults=record_defaults) + @combinations.generate(test_base.default_test_combinations()) def testCsvDataset_withCRLF(self): # Test that when the line separator is '\r\n', parsing works with all buffer # sizes @@ -457,6 +498,7 @@ class CsvDatasetTest(test_base.DatasetTestBase): self._test_dataset_on_buffer_sizes( inputs, expected, linebreak='\r\n', record_defaults=record_defaults) + @combinations.generate(test_base.default_test_combinations()) def testCsvDataset_withBufferSizeAndQuoted(self): record_defaults = [['NA']] * 3 inputs = [['"\n\n\n","\r\r\r","abc"', '"0","1","2"', '"","",""']] @@ -465,6 +507,7 @@ class CsvDatasetTest(test_base.DatasetTestBase): self._test_dataset_on_buffer_sizes( inputs, expected, linebreak='\n', record_defaults=record_defaults) + @combinations.generate(test_base.default_test_combinations()) def testCsvDataset_withCRAndQuoted(self): # Test that when the line separator is '\r', parsing works with all buffer # sizes @@ -475,6 +518,7 @@ class CsvDatasetTest(test_base.DatasetTestBase): self._test_dataset_on_buffer_sizes( inputs, expected, linebreak='\r', record_defaults=record_defaults) + @combinations.generate(test_base.default_test_combinations()) def testCsvDataset_withCRLFAndQuoted(self): # Test that when the line separator is '\r\n', parsing works with all buffer # sizes @@ -485,6 +529,7 @@ class CsvDatasetTest(test_base.DatasetTestBase): self._test_dataset_on_buffer_sizes( inputs, expected, linebreak='\r\n', record_defaults=record_defaults) + @combinations.generate(test_base.default_test_combinations()) def testCsvDataset_withGzipCompressionType(self): record_defaults = [['NA']] * 3 inputs = [['"\n\n\n","\r\r\r","abc"', '"0","1","2"', '"","",""']] @@ -497,6 +542,7 @@ class CsvDatasetTest(test_base.DatasetTestBase): compression_type='GZIP', record_defaults=record_defaults) + @combinations.generate(test_base.default_test_combinations()) def testCsvDataset_withZlibCompressionType(self): record_defaults = [['NA']] * 3 inputs = [['"\n\n\n","\r\r\r","abc"', '"0","1","2"', '"","",""']] @@ -509,6 +555,7 @@ class CsvDatasetTest(test_base.DatasetTestBase): compression_type='ZLIB', record_defaults=record_defaults) + @combinations.generate(test_base.default_test_combinations()) def testCsvDataset_withScalarDefaults(self): record_defaults = [constant_op.constant(0, dtype=dtypes.int64)] * 4 inputs = [[',,,', '1,1,1,', ',2,2,2']] @@ -516,6 +563,7 @@ class CsvDatasetTest(test_base.DatasetTestBase): inputs, [[0, 0, 0, 0], [1, 1, 1, 0], [0, 2, 2, 2]], record_defaults=record_defaults) + @combinations.generate(test_base.default_test_combinations()) def testCsvDataset_with2DDefaults(self): record_defaults = [constant_op.constant([[0]], dtype=dtypes.int64)] * 4 inputs = [[',,,', '1,1,1,', ',2,2,2']] diff --git a/tensorflow/python/data/experimental/kernel_tests/dense_to_sparse_batch_test.py b/tensorflow/python/data/experimental/kernel_tests/dense_to_sparse_batch_test.py index cca7ae073ee..5dd1bb0532c 100644 --- a/tensorflow/python/data/experimental/kernel_tests/dense_to_sparse_batch_test.py +++ b/tensorflow/python/data/experimental/kernel_tests/dense_to_sparse_batch_test.py @@ -17,20 +17,21 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +from absl.testing import parameterized import numpy as np from tensorflow.python.data.experimental.ops import batching from tensorflow.python.data.kernel_tests import test_base from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.framework import combinations from tensorflow.python.framework import errors -from tensorflow.python.framework import test_util from tensorflow.python.ops import array_ops from tensorflow.python.platform import test -@test_util.run_all_in_graph_and_eager_modes -class DenseToSparseBatchTest(test_base.DatasetTestBase): +class DenseToSparseBatchTest(test_base.DatasetTestBase, parameterized.TestCase): + @combinations.generate(test_base.default_test_combinations()) def testDenseToSparseBatchDataset(self): components = np.random.randint(12, size=(100,)).astype(np.int32) dataset = dataset_ops.Dataset.from_tensor_slices( @@ -53,6 +54,7 @@ class DenseToSparseBatchTest(test_base.DatasetTestBase): with self.assertRaises(errors.OutOfRangeError): self.evaluate(get_next()) + @combinations.generate(test_base.default_test_combinations()) def testDenseToSparseBatchDatasetWithUnknownShape(self): components = np.random.randint(5, size=(40,)).astype(np.int32) dataset = dataset_ops.Dataset.from_tensor_slices( @@ -80,12 +82,14 @@ class DenseToSparseBatchTest(test_base.DatasetTestBase): with self.assertRaises(errors.OutOfRangeError): self.evaluate(get_next()) + @combinations.generate(test_base.default_test_combinations()) def testDenseToSparseBatchDatasetWithInvalidShape(self): input_tensor = array_ops.constant([[1]]) with self.assertRaisesRegexp(ValueError, "Dimension -2 must be >= 0"): dataset_ops.Dataset.from_tensors(input_tensor).apply( batching.dense_to_sparse_batch(4, [-2])) + @combinations.generate(test_base.default_test_combinations()) def testDenseToSparseBatchDatasetShapeErrors(self): def dataset_fn(input_tensor): diff --git a/tensorflow/python/data/experimental/kernel_tests/directed_interleave_dataset_test.py b/tensorflow/python/data/experimental/kernel_tests/directed_interleave_dataset_test.py index 4a8c7d1ccc6..fc18afaa842 100644 --- a/tensorflow/python/data/experimental/kernel_tests/directed_interleave_dataset_test.py +++ b/tensorflow/python/data/experimental/kernel_tests/directed_interleave_dataset_test.py @@ -17,22 +17,24 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +from absl.testing import parameterized import numpy as np from tensorflow.python.data.experimental.ops import interleave_ops from tensorflow.python.data.kernel_tests import test_base from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.framework import combinations from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors from tensorflow.python.framework import random_seed -from tensorflow.python.framework import test_util from tensorflow.python.platform import test -@test_util.run_all_in_graph_and_eager_modes -class DirectedInterleaveDatasetTest(test_base.DatasetTestBase): +class DirectedInterleaveDatasetTest(test_base.DatasetTestBase, + parameterized.TestCase): + @combinations.generate(test_base.default_test_combinations()) def testBasic(self): selector_dataset = dataset_ops.Dataset.range(10).repeat(100) input_datasets = [ @@ -76,6 +78,7 @@ class DirectedInterleaveDatasetTest(test_base.DatasetTestBase): return freqs + @combinations.generate(test_base.default_test_combinations()) def testSampleFromDatasets(self): random_seed.set_random_seed(1619) num_samples = 5000 @@ -95,6 +98,7 @@ class DirectedInterleaveDatasetTest(test_base.DatasetTestBase): freqs = self._testSampleFromDatasetsHelper(probs_ds, classes, num_samples) self.assertLess(self._chi2(probs, freqs / num_samples), 1e-2) + @combinations.generate(test_base.default_test_combinations()) def testSelectFromDatasets(self): words = [b"foo", b"bar", b"baz"] datasets = [dataset_ops.Dataset.from_tensors(w).repeat() for w in words] @@ -107,6 +111,7 @@ class DirectedInterleaveDatasetTest(test_base.DatasetTestBase): with self.assertRaises(errors.OutOfRangeError): self.evaluate(next_element()) + @combinations.generate(test_base.default_test_combinations()) def testErrors(self): with self.assertRaisesRegexp(ValueError, r"vector of length `len\(datasets\)`"): diff --git a/tensorflow/python/data/experimental/kernel_tests/get_single_element_test.py b/tensorflow/python/data/experimental/kernel_tests/get_single_element_test.py index f65740c5651..59c2ef68d99 100644 --- a/tensorflow/python/data/experimental/kernel_tests/get_single_element_test.py +++ b/tensorflow/python/data/experimental/kernel_tests/get_single_element_test.py @@ -23,25 +23,30 @@ from tensorflow.python.data.experimental.ops import get_single_element from tensorflow.python.data.kernel_tests import test_base from tensorflow.python.data.ops import dataset_ops from tensorflow.python.eager import function +from tensorflow.python.framework import combinations from tensorflow.python.framework import errors from tensorflow.python.framework import sparse_tensor -from tensorflow.python.framework import test_util from tensorflow.python.ops import array_ops from tensorflow.python.ops import variables from tensorflow.python.platform import test -@test_util.run_all_in_graph_and_eager_modes class GetSingleElementTest(test_base.DatasetTestBase, parameterized.TestCase): - @parameterized.named_parameters( - ("Zero", 0, 1), - ("Five", 5, 1), - ("Ten", 10, 1), - ("Empty", 100, 1, errors.InvalidArgumentError, "Dataset was empty."), - ("MoreThanOne", 0, 2, errors.InvalidArgumentError, - "Dataset had more than one element."), - ) + @combinations.generate( + combinations.times( + test_base.default_test_combinations(), + combinations.combine( + skip=[0, 5, 10], take=[1], error=[None], error_msg=[None]) + + combinations.combine( + skip=[100], + take=[1], + error=[errors.InvalidArgumentError], + error_msg=["Dataset was empty."]) + combinations.combine( + skip=[0], + take=[2], + error=[errors.InvalidArgumentError], + error_msg=["Dataset had more than one element."]))) def testGetSingleElement(self, skip, take, error=None, error_msg=None): def make_sparse(x): @@ -62,6 +67,7 @@ class GetSingleElementTest(test_base.DatasetTestBase, parameterized.TestCase): with self.assertRaisesRegexp(error, error_msg): self.evaluate(get_single_element.get_single_element(dataset)) + @combinations.generate(test_base.default_test_combinations()) def testWindow(self): """Test that `get_single_element()` can consume a nested dataset.""" def flat_map_func(ds): @@ -73,6 +79,7 @@ class GetSingleElementTest(test_base.DatasetTestBase, parameterized.TestCase): self.assertDatasetProduces( dataset, [[0, 1], [2, 3], [4, 5], [6, 7], [8, 9]]) + @combinations.generate(test_base.default_test_combinations()) def testSideEffect(self): counter_var = variables.Variable(0) @@ -92,6 +99,7 @@ class GetSingleElementTest(test_base.DatasetTestBase, parameterized.TestCase): self.assertEqual(self.evaluate(fn()), b"hello") self.assertEqual(self.evaluate(counter_var), 1) + @combinations.generate(test_base.default_test_combinations()) def testAutomaticControlDependencies(self): counter_var = variables.Variable(1) diff --git a/tensorflow/python/data/experimental/kernel_tests/group_by_reducer_test.py b/tensorflow/python/data/experimental/kernel_tests/group_by_reducer_test.py index 0e9042b2ef8..bf823143d57 100644 --- a/tensorflow/python/data/experimental/kernel_tests/group_by_reducer_test.py +++ b/tensorflow/python/data/experimental/kernel_tests/group_by_reducer_test.py @@ -17,25 +17,26 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +from absl.testing import parameterized import numpy as np from tensorflow.python.data.experimental.ops import grouping from tensorflow.python.data.kernel_tests import test_base from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.framework import combinations from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors from tensorflow.python.framework import sparse_tensor from tensorflow.python.framework import tensor_shape -from tensorflow.python.framework import test_util from tensorflow.python.ops import array_ops from tensorflow.python.ops import math_ops from tensorflow.python.platform import test -@test_util.run_all_in_graph_and_eager_modes -class GroupByReducerTest(test_base.DatasetTestBase): +class GroupByReducerTest(test_base.DatasetTestBase, parameterized.TestCase): + @combinations.generate(test_base.default_test_combinations()) def testSum(self): reducer = grouping.Reducer( init_func=lambda _: np.int64(0), @@ -49,6 +50,7 @@ class GroupByReducerTest(test_base.DatasetTestBase): expected_shapes=tensor_shape.TensorShape([]), expected_output=[(i - 1) * i, i * i]) + @combinations.generate(test_base.default_test_combinations()) def testAverage(self): def reduce_fn(x, y): @@ -68,6 +70,7 @@ class GroupByReducerTest(test_base.DatasetTestBase): expected_shapes=tensor_shape.TensorShape([]), expected_output=[i - 1, i]) + @combinations.generate(test_base.default_test_combinations()) def testConcat(self): components = np.array(list("abcdefghijklmnopqrst")).view(np.chararray) reducer = grouping.Reducer( @@ -84,6 +87,7 @@ class GroupByReducerTest(test_base.DatasetTestBase): expected_shapes=tensor_shape.TensorShape([]), expected_output=[b"acegikmoqs"[:i], b"bdfhjlnprt"[:i]]) + @combinations.generate(test_base.default_test_combinations()) def testSparseSum(self): def _sparse(i): return sparse_tensor.SparseTensorValue( @@ -103,6 +107,7 @@ class GroupByReducerTest(test_base.DatasetTestBase): expected_shapes=tensor_shape.TensorShape([]), expected_output=[(i - 1) * i, i * i]) + @combinations.generate(test_base.default_test_combinations()) def testChangingStateShape(self): def reduce_fn(x, _): @@ -130,6 +135,7 @@ class GroupByReducerTest(test_base.DatasetTestBase): with self.assertRaises(errors.OutOfRangeError): self.evaluate(get_next()) + @combinations.generate(test_base.default_test_combinations()) def testTypeMismatch(self): reducer = grouping.Reducer( init_func=lambda x: constant_op.constant(1, dtype=dtypes.int32), @@ -144,6 +150,7 @@ class GroupByReducerTest(test_base.DatasetTestBase): grouping.group_by_reducer(lambda _: np.int64(0), reducer)) # TODO(b/78665031): Remove once non-scalar keys are supported. + @combinations.generate(test_base.default_test_combinations()) def testInvalidKeyShape(self): reducer = grouping.Reducer( init_func=lambda x: np.int64(0), @@ -157,6 +164,7 @@ class GroupByReducerTest(test_base.DatasetTestBase): grouping.group_by_reducer(lambda _: np.int64((0, 0)), reducer)) # TODO(b/78665031): Remove once non-int64 keys are supported. + @combinations.generate(test_base.default_test_combinations()) def testInvalidKeyType(self): reducer = grouping.Reducer( init_func=lambda x: np.int64(0), @@ -169,6 +177,7 @@ class GroupByReducerTest(test_base.DatasetTestBase): dataset.apply( grouping.group_by_reducer(lambda _: "wrong", reducer)) + @combinations.generate(test_base.default_test_combinations()) def testTuple(self): def init_fn(_): return np.array([], dtype=np.int64), np.int64(0) diff --git a/tensorflow/python/data/experimental/kernel_tests/group_by_window_test.py b/tensorflow/python/data/experimental/kernel_tests/group_by_window_test.py index e529364e509..2495083cf63 100644 --- a/tensorflow/python/data/experimental/kernel_tests/group_by_window_test.py +++ b/tensorflow/python/data/experimental/kernel_tests/group_by_window_test.py @@ -17,17 +17,18 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +from absl.testing import parameterized import numpy as np from tensorflow.python.data.experimental.ops import grouping from tensorflow.python.data.kernel_tests import test_base from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.framework import combinations from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_shape -from tensorflow.python.framework import test_util from tensorflow.python.ops import array_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import string_ops @@ -37,8 +38,7 @@ from tensorflow.python.platform import test # NOTE(mrry): These tests are based on the tests in bucket_ops_test.py. # Currently, they use a constant batch size, though should be made to use a # different batch size per key. -@test_util.run_all_in_graph_and_eager_modes -class GroupByWindowTest(test_base.DatasetTestBase): +class GroupByWindowTest(test_base.DatasetTestBase, parameterized.TestCase): def _dynamicPad(self, bucket, window, window_size): # TODO(mrry): To match `tf.contrib.training.bucket()`, implement a @@ -51,6 +51,7 @@ class GroupByWindowTest(test_base.DatasetTestBase): 32, (tensor_shape.TensorShape([]), tensor_shape.TensorShape( [None]), tensor_shape.TensorShape([3]))))) + @combinations.generate(test_base.default_test_combinations()) def testSingleBucket(self): def _map_fn(v): @@ -80,6 +81,7 @@ class GroupByWindowTest(test_base.DatasetTestBase): self.assertAllEqual(expected_unk_int64, bucketed_values[1]) self.assertAllEqual(expected_vec3_str, bucketed_values[2]) + @combinations.generate(test_base.default_test_combinations()) def testEvenOddBuckets(self): def _map_fn(v): @@ -132,6 +134,7 @@ class GroupByWindowTest(test_base.DatasetTestBase): self.assertAllEqual(expected_unk_int64, bucketed_values_odd[1]) self.assertAllEqual(expected_vec3_str, bucketed_values_odd[2]) + @combinations.generate(test_base.default_test_combinations()) def testEvenOddBucketsFilterOutAllOdd(self): def _map_fn(v): @@ -173,6 +176,7 @@ class GroupByWindowTest(test_base.DatasetTestBase): self.assertAllEqual( np.arange(64, 128, 2, dtype=np.int64), bucketed_values_even1["x"]) + @combinations.generate(test_base.default_test_combinations()) def testDynamicWindowSize(self): components = np.arange(100).astype(np.int64) @@ -202,6 +206,7 @@ class GroupByWindowTest(test_base.DatasetTestBase): self.assertEqual(batches, 15) + @combinations.generate(test_base.default_test_combinations()) def testSimple(self): components = np.random.randint(100, size=(200,)).astype(np.int64) dataset = dataset_ops.Dataset.from_tensor_slices( @@ -222,6 +227,7 @@ class GroupByWindowTest(test_base.DatasetTestBase): self.assertGreaterEqual(num_full_batches, 24) self.assertTrue(all(c == 4 for c in counts[:num_full_batches])) + @combinations.generate(test_base.default_test_combinations()) def testImmediateOutput(self): components = np.array( [0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 0, 0, 2, 2, 0, 0], dtype=np.int64) @@ -240,6 +246,7 @@ class GroupByWindowTest(test_base.DatasetTestBase): self.assertAllEqual([2, 2, 2, 2], self.evaluate(get_next())) self.assertAllEqual([0, 0, 0, 0], self.evaluate(get_next())) + @combinations.generate(test_base.default_test_combinations()) def testSmallGroups(self): components = np.array([0, 0, 0, 0, 1, 1, 1, 1, 1, 0, 0, 0], dtype=np.int64) dataset = dataset_ops.Dataset.from_tensor_slices(components).apply( @@ -252,6 +259,7 @@ class GroupByWindowTest(test_base.DatasetTestBase): self.assertAllEqual([0, 0, 0], self.evaluate(get_next())) self.assertAllEqual([1], self.evaluate(get_next())) + @combinations.generate(test_base.default_test_combinations()) def testEmpty(self): dataset = dataset_ops.Dataset.range(4).apply( grouping.group_by_window(lambda _: 0, lambda _, xs: xs, 0)) @@ -262,6 +270,7 @@ class GroupByWindowTest(test_base.DatasetTestBase): "Window size must be greater than zero, but got 0."): print(self.evaluate(get_next())) + @combinations.generate(test_base.default_test_combinations()) def testReduceFuncError(self): components = np.random.randint(100, size=(200,)).astype(np.int64) @@ -280,6 +289,7 @@ class GroupByWindowTest(test_base.DatasetTestBase): with self.assertRaises(errors.InvalidArgumentError): self.evaluate(get_next()) + @combinations.generate(test_base.default_test_combinations()) def testConsumeWindowDatasetMoreThanOnce(self): components = np.random.randint(50, size=(200,)).astype(np.int64) @@ -311,6 +321,7 @@ class GroupByWindowTest(test_base.DatasetTestBase): counts.append(tight_result.shape[0]) self.assertEqual(len(components), sum(counts)) + @combinations.generate(test_base.default_test_combinations()) def testShortCircuit(self): dataset = dataset_ops.Dataset.range(10) diff --git a/tensorflow/python/data/experimental/kernel_tests/ignore_errors_test.py b/tensorflow/python/data/experimental/kernel_tests/ignore_errors_test.py index c37439f328b..5ed72767425 100644 --- a/tensorflow/python/data/experimental/kernel_tests/ignore_errors_test.py +++ b/tensorflow/python/data/experimental/kernel_tests/ignore_errors_test.py @@ -19,14 +19,15 @@ from __future__ import print_function import os +from absl.testing import parameterized import numpy as np from tensorflow.python.data.experimental.ops import error_ops from tensorflow.python.data.kernel_tests import test_base from tensorflow.python.data.ops import dataset_ops from tensorflow.python.data.ops import readers +from tensorflow.python.framework import combinations from tensorflow.python.framework import errors -from tensorflow.python.framework import test_util from tensorflow.python.lib.io import python_io from tensorflow.python.ops import array_ops from tensorflow.python.ops import io_ops @@ -36,9 +37,9 @@ from tensorflow.python.util import compat _NUMPY_RANDOM_SEED = 42 -@test_util.run_all_in_graph_and_eager_modes -class IgnoreErrorsTest(test_base.DatasetTestBase): +class IgnoreErrorsTest(test_base.DatasetTestBase, parameterized.TestCase): + @combinations.generate(test_base.default_test_combinations()) def testMapIgnoreError(self): components = np.array([1., 2., 3., np.nan, 5.]).astype(np.float32) @@ -53,6 +54,7 @@ class IgnoreErrorsTest(test_base.DatasetTestBase): with self.assertRaises(errors.OutOfRangeError): self.evaluate(get_next()) + @combinations.generate(test_base.default_test_combinations()) def testParallelMapIgnoreError(self): components = np.array([1., 2., 3., np.nan, 5.]).astype(np.float32) @@ -67,6 +69,7 @@ class IgnoreErrorsTest(test_base.DatasetTestBase): with self.assertRaises(errors.OutOfRangeError): self.evaluate(get_next()) + @combinations.generate(test_base.default_test_combinations()) def testReadFileIgnoreError(self): def write_string_to_file(value, filename): @@ -102,6 +105,7 @@ class IgnoreErrorsTest(test_base.DatasetTestBase): with self.assertRaises(errors.OutOfRangeError): self.evaluate(get_next()) + @combinations.generate(test_base.default_test_combinations()) def testTFRecordDatasetIgnoreError(self): filenames = [] for i in range(5): @@ -126,6 +130,7 @@ class IgnoreErrorsTest(test_base.DatasetTestBase): with self.assertRaises(errors.OutOfRangeError): self.evaluate(get_next()) + @combinations.generate(test_base.default_test_combinations()) def testZipIgnoreError(self): a = dataset_ops.Dataset.from_tensor_slices([1., 2., 0., 4.]) b = a.map(lambda x: array_ops.check_numerics(1. / x, "error")) diff --git a/tensorflow/python/data/experimental/kernel_tests/make_batched_features_dataset_test.py b/tensorflow/python/data/experimental/kernel_tests/make_batched_features_dataset_test.py index 2ddff457bc4..980fd03b073 100644 --- a/tensorflow/python/data/experimental/kernel_tests/make_batched_features_dataset_test.py +++ b/tensorflow/python/data/experimental/kernel_tests/make_batched_features_dataset_test.py @@ -17,26 +17,29 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +from absl.testing import parameterized import numpy as np from tensorflow.python.data.experimental.kernel_tests import reader_dataset_ops_test_base from tensorflow.python.data.experimental.ops import readers +from tensorflow.python.data.kernel_tests import test_base from tensorflow.python.data.ops import dataset_ops from tensorflow.python.data.ops import readers as core_readers from tensorflow.python.data.util import nest +from tensorflow.python.framework import combinations from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors from tensorflow.python.framework import ops -from tensorflow.python.framework import test_util from tensorflow.python.ops import io_ops from tensorflow.python.ops import parsing_ops from tensorflow.python.platform import test -@test_util.run_all_in_graph_and_eager_modes class MakeBatchedFeaturesDatasetTest( - reader_dataset_ops_test_base.MakeBatchedFeaturesDatasetTestBase): + reader_dataset_ops_test_base.MakeBatchedFeaturesDatasetTestBase, + parameterized.TestCase): + @combinations.generate(test_base.default_test_combinations()) def testRead(self): for batch_size in [1, 2]: for num_epochs in [1, 10]: @@ -85,6 +88,7 @@ class MakeBatchedFeaturesDatasetTest( with self.assertRaises(errors.OutOfRangeError): self._next_actual_batch() + @combinations.generate(test_base.default_test_combinations()) def testReadWithEquivalentDataset(self): features = { "file": parsing_ops.FixedLenFeature([], dtypes.int64), @@ -103,6 +107,7 @@ class MakeBatchedFeaturesDatasetTest( with self.assertRaises(errors.OutOfRangeError): self.evaluate(next_element()) + @combinations.generate(test_base.default_test_combinations()) def testReadWithFusedShuffleRepeatDataset(self): num_epochs = 5 total_records = num_epochs * self._num_records @@ -151,6 +156,7 @@ class MakeBatchedFeaturesDatasetTest( all_equal = all_equal and np.array_equal(batch1[i], batch2[i]) self.assertFalse(all_equal) + @combinations.generate(test_base.default_test_combinations()) def testParallelReadersAndParsers(self): num_epochs = 5 for batch_size in [1, 2]: @@ -186,6 +192,7 @@ class MakeBatchedFeaturesDatasetTest( with self.assertRaises(errors.OutOfRangeError): self._next_actual_batch() + @combinations.generate(test_base.default_test_combinations()) def testDropFinalBatch(self): for batch_size in [1, 2]: for num_epochs in [1, 10]: @@ -201,6 +208,7 @@ class MakeBatchedFeaturesDatasetTest( if isinstance(tensor, ops.Tensor): # Guard against SparseTensor. self.assertEqual(tensor.shape[0], batch_size) + @combinations.generate(test_base.default_test_combinations()) def testIndefiniteRepeatShapeInference(self): dataset = self.make_batch_feature( filenames=self.test_filenames[0], @@ -213,6 +221,7 @@ class MakeBatchedFeaturesDatasetTest( if issubclass(clazz, ops.Tensor): self.assertEqual(32, shape[0]) + @combinations.generate(test_base.default_test_combinations()) def testOldStyleReader(self): with self.assertRaisesRegexp( TypeError, r"The `reader` argument must return a `Dataset` object. " diff --git a/tensorflow/python/data/experimental/kernel_tests/make_csv_dataset_test.py b/tensorflow/python/data/experimental/kernel_tests/make_csv_dataset_test.py index 16c323b3790..5f8382f43c4 100644 --- a/tensorflow/python/data/experimental/kernel_tests/make_csv_dataset_test.py +++ b/tensorflow/python/data/experimental/kernel_tests/make_csv_dataset_test.py @@ -21,21 +21,21 @@ import gzip import os import zlib +from absl.testing import parameterized import numpy as np from tensorflow.python.data.experimental.ops import readers from tensorflow.python.data.kernel_tests import test_base from tensorflow.python.data.ops import dataset_ops from tensorflow.python.data.util import nest +from tensorflow.python.framework import combinations from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors -from tensorflow.python.framework import test_util from tensorflow.python.platform import test -@test_util.run_all_in_graph_and_eager_modes -class MakeCsvDatasetTest(test_base.DatasetTestBase): +class MakeCsvDatasetTest(test_base.DatasetTestBase, parameterized.TestCase): def _make_csv_dataset(self, filenames, batch_size, num_epochs=1, **kwargs): return readers.make_csv_dataset( @@ -126,6 +126,7 @@ class MakeCsvDatasetTest(test_base.DatasetTestBase): self._verify_output(dataset, batch_size, num_epochs, label_name, expected_output, expected_keys) + @combinations.generate(test_base.default_test_combinations()) def testMakeCSVDataset(self): """Tests making a CSV dataset with keys and defaults provided.""" record_defaults = [ @@ -157,6 +158,7 @@ class MakeCsvDatasetTest(test_base.DatasetTestBase): column_defaults=record_defaults, ) + @combinations.generate(test_base.default_test_combinations()) def testMakeCSVDataset_withBatchSizeAndEpochs(self): """Tests making a CSV dataset with keys and defaults provided.""" record_defaults = [ @@ -188,6 +190,7 @@ class MakeCsvDatasetTest(test_base.DatasetTestBase): column_defaults=record_defaults, ) + @combinations.generate(test_base.default_test_combinations()) def testMakeCSVDataset_withCompressionType(self): """Tests `compression_type` argument.""" record_defaults = [ @@ -221,6 +224,7 @@ class MakeCsvDatasetTest(test_base.DatasetTestBase): compression_type=compression_type, ) + @combinations.generate(test_base.default_test_combinations()) def testMakeCSVDataset_withCompressionTypeAndNoColumnNames(self): """Tests `compression_type` argument.""" record_defaults = [ @@ -269,6 +273,7 @@ class MakeCsvDatasetTest(test_base.DatasetTestBase): compression_type="ZLIB", ) + @combinations.generate(test_base.default_test_combinations()) def testMakeCSVDataset_withBadInputs(self): """Tests that exception is raised when input is malformed. """ @@ -304,6 +309,7 @@ class MakeCsvDatasetTest(test_base.DatasetTestBase): label_name="not_a_real_label", column_names=column_names) + @combinations.generate(test_base.default_test_combinations()) def testMakeCSVDataset_withNoLabel(self): """Tests making a CSV dataset with no label provided.""" record_defaults = [ @@ -333,6 +339,7 @@ class MakeCsvDatasetTest(test_base.DatasetTestBase): column_defaults=record_defaults, ) + @combinations.generate(test_base.default_test_combinations()) def testMakeCSVDataset_withNoHeader(self): """Tests that datasets can be created from CSV files with no header line. """ @@ -363,6 +370,7 @@ class MakeCsvDatasetTest(test_base.DatasetTestBase): column_defaults=record_defaults, ) + @combinations.generate(test_base.default_test_combinations()) def testMakeCSVDataset_withTypes(self): """Tests that defaults can be a dtype instead of a Tensor for required vals. """ @@ -394,6 +402,7 @@ class MakeCsvDatasetTest(test_base.DatasetTestBase): column_defaults=record_defaults, ) + @combinations.generate(test_base.default_test_combinations()) def testMakeCSVDataset_withNoColNames(self): """Tests that datasets can be created when column names are not specified. @@ -427,6 +436,7 @@ class MakeCsvDatasetTest(test_base.DatasetTestBase): column_defaults=record_defaults, ) + @combinations.generate(test_base.default_test_combinations()) def testMakeCSVDataset_withTypeInferenceMismatch(self): # Test that error is thrown when num fields doesn't match columns column_names = ["col%d" % i for i in range(5)] @@ -442,6 +452,7 @@ class MakeCsvDatasetTest(test_base.DatasetTestBase): batch_size=2, num_epochs=10) + @combinations.generate(test_base.default_test_combinations()) def testMakeCSVDataset_withTypeInference(self): """Tests that datasets can be created when no defaults are specified. @@ -468,6 +479,7 @@ class MakeCsvDatasetTest(test_base.DatasetTestBase): header=True, ) + @combinations.generate(test_base.default_test_combinations()) def testMakeCSVDataset_withTypeInferenceFallthrough(self): """Tests that datasets can be created when no defaults are specified. @@ -498,6 +510,7 @@ class MakeCsvDatasetTest(test_base.DatasetTestBase): header=True, ) + @combinations.generate(test_base.default_test_combinations()) def testMakeCSVDataset_withNAValuesAndFieldDelim(self): """Tests that datasets can be created from different delim and na_value.""" column_names = ["col%d" % i for i in range(5)] @@ -520,6 +533,7 @@ class MakeCsvDatasetTest(test_base.DatasetTestBase): field_delim=" ", ) + @combinations.generate(test_base.default_test_combinations()) def testMakeCSVDataset_withSelectCols(self): record_defaults = [ constant_op.constant([], dtypes.int32), @@ -588,6 +602,7 @@ class MakeCsvDatasetTest(test_base.DatasetTestBase): select_columns=[column_names[i] for i in select_cols], ) + @combinations.generate(test_base.default_test_combinations()) def testMakeCSVDataset_withSelectColsError(self): record_defaults = [ constant_op.constant([], dtypes.int32), @@ -626,6 +641,7 @@ class MakeCsvDatasetTest(test_base.DatasetTestBase): label_name=None, select_columns=["invalid_col_name"]) + @combinations.generate(test_base.default_test_combinations()) def testMakeCSVDataset_withShuffle(self): record_defaults = [ constant_op.constant([], dtypes.int32), @@ -710,6 +726,7 @@ class MakeCsvDatasetTest(test_base.DatasetTestBase): all_equal = all_equal and np.array_equal(batch1[i], batch2[i]) self.assertFalse(all_equal) + @combinations.generate(test_base.default_test_combinations()) def testIndefiniteRepeatShapeInference(self): column_names = ["col%d" % i for i in range(5)] inputs = [[",".join(x for x in column_names), "0,1,2,3,4", "5,6,7,8,9"], [ diff --git a/tensorflow/python/data/experimental/kernel_tests/make_tf_record_dataset_test.py b/tensorflow/python/data/experimental/kernel_tests/make_tf_record_dataset_test.py index ec1760398fa..a67ccd92842 100644 --- a/tensorflow/python/data/experimental/kernel_tests/make_tf_record_dataset_test.py +++ b/tensorflow/python/data/experimental/kernel_tests/make_tf_record_dataset_test.py @@ -17,19 +17,22 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +from absl.testing import parameterized + from tensorflow.python.data.experimental.kernel_tests import reader_dataset_ops_test_base from tensorflow.python.data.experimental.ops import readers +from tensorflow.python.data.kernel_tests import test_base from tensorflow.python.data.ops import dataset_ops from tensorflow.python.data.util import nest +from tensorflow.python.framework import combinations from tensorflow.python.framework import errors -from tensorflow.python.framework import test_util from tensorflow.python.ops import string_ops from tensorflow.python.platform import test -@test_util.run_all_in_graph_and_eager_modes class MakeTFRecordDatasetTest( - reader_dataset_ops_test_base.TFRecordDatasetTestBase): + reader_dataset_ops_test_base.TFRecordDatasetTestBase, + parameterized.TestCase): def _read_test(self, batch_size, num_epochs, file_index=None, num_parallel_reads=1, drop_final_batch=False, parser_fn=False): @@ -63,6 +66,7 @@ class MakeTFRecordDatasetTest( with self.assertRaises(errors.OutOfRangeError): self.evaluate(outputs()) + @combinations.generate(test_base.default_test_combinations()) def testRead(self): for batch_size in [1, 2]: for num_epochs in [1, 3]: @@ -78,6 +82,7 @@ class MakeTFRecordDatasetTest( # Basic test: read from both files, with parallel reads. self._read_test(batch_size, num_epochs, num_parallel_reads=8) + @combinations.generate(test_base.default_test_combinations()) def testDropFinalBatch(self): for batch_size in [1, 2, 10]: for num_epochs in [1, 3]: @@ -91,6 +96,7 @@ class MakeTFRecordDatasetTest( self._read_test(batch_size, num_epochs, num_parallel_reads=8, drop_final_batch=True) + @combinations.generate(test_base.default_test_combinations()) def testParserFn(self): for batch_size in [1, 2]: for num_epochs in [1, 3]: @@ -145,6 +151,7 @@ class MakeTFRecordDatasetTest( actual.extend(b) self.assertAllEqual(sorted(expected), sorted(actual)) + @combinations.generate(test_base.default_test_combinations()) def testShuffle(self): for batch_size in [1, 2]: for num_epochs in [1, 3]: @@ -156,6 +163,7 @@ class MakeTFRecordDatasetTest( self._shuffle_test(batch_size, num_epochs, num_parallel_reads, seed=21345) + @combinations.generate(test_base.default_test_combinations()) def testIndefiniteRepeatShapeInference(self): dataset = readers.make_tf_record_dataset( file_pattern=self.test_filenames, num_epochs=None, batch_size=32) diff --git a/tensorflow/python/data/experimental/kernel_tests/map_defun_op_test.py b/tensorflow/python/data/experimental/kernel_tests/map_defun_op_test.py index a42ce40fb29..a2cc54d104e 100644 --- a/tensorflow/python/data/experimental/kernel_tests/map_defun_op_test.py +++ b/tensorflow/python/data/experimental/kernel_tests/map_defun_op_test.py @@ -19,17 +19,19 @@ from __future__ import print_function import time +from absl.testing import parameterized + from tensorflow.python.client import session from tensorflow.python.data.experimental.ops import map_defun from tensorflow.python.data.kernel_tests import test_base from tensorflow.python.eager import function +from tensorflow.python.framework import combinations from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors from tensorflow.python.framework import ops from tensorflow.python.framework import sparse_tensor from tensorflow.python.framework import tensor_spec -from tensorflow.python.framework import test_util from tensorflow.python.ops import array_ops from tensorflow.python.ops import check_ops from tensorflow.python.ops import data_flow_ops @@ -38,9 +40,11 @@ from tensorflow.python.ops import sparse_ops from tensorflow.python.platform import test -@test_util.run_v1_only("b/123903858: Add eager and V2 test coverage") -class MapDefunTest(test_base.DatasetTestBase): +# TODO(b/123903858): Add eager and V2 test coverage +class MapDefunTest(test_base.DatasetTestBase, parameterized.TestCase): + @combinations.generate( + combinations.combine(tf_api_version=[1], mode=["graph"])) def testNoIntraOpLimit(self): @function.defun(input_signature=[tensor_spec.TensorSpec([2], dtypes.int32)]) @@ -55,6 +59,8 @@ class MapDefunTest(test_base.DatasetTestBase): expected = elems * 2 + 3 self.assertAllEqual(self.evaluate(r), self.evaluate(expected)) + @combinations.generate( + combinations.combine(tf_api_version=[1], mode=["graph"])) def testMapDefunSimple(self): @function.defun(input_signature=[tensor_spec.TensorSpec([2], dtypes.int32)]) @@ -67,6 +73,8 @@ class MapDefunTest(test_base.DatasetTestBase): expected = elems * 2 + 3 self.assertAllEqual(self.evaluate(r), self.evaluate(expected)) + @combinations.generate( + combinations.combine(tf_api_version=[1], mode=["graph"])) def testMapDefunMismatchedTypes(self): @function.defun(input_signature=[tensor_spec.TensorSpec([], dtypes.int32)]) @@ -79,6 +87,8 @@ class MapDefunTest(test_base.DatasetTestBase): with self.assertRaises(errors.InvalidArgumentError): self.evaluate(r) + @combinations.generate( + combinations.combine(tf_api_version=[1], mode=["graph"])) def testMapDefunReduceDim(self): # Tests where the output has a different rank from the input @@ -92,6 +102,8 @@ class MapDefunTest(test_base.DatasetTestBase): expected = constant_op.constant([1, 3, 5]) self.assertAllEqual(self.evaluate(r), self.evaluate(expected)) + @combinations.generate( + combinations.combine(tf_api_version=[1], mode=["graph"])) def testMapDefunMultipleOutputs(self): @function.defun(input_signature=[tensor_spec.TensorSpec([2], dtypes.int32)]) @@ -105,6 +117,8 @@ class MapDefunTest(test_base.DatasetTestBase): expected = [elems, elems * 2 + 3] self.assertAllEqual(self.evaluate(r), self.evaluate(expected)) + @combinations.generate( + combinations.combine(tf_api_version=[1], mode=["graph"])) def testMapDefunShapeInference(self): @function.defun(input_signature=[tensor_spec.TensorSpec([2], dtypes.int32)]) @@ -116,6 +130,8 @@ class MapDefunTest(test_base.DatasetTestBase): result = map_defun.map_defun(fn, [elems], [dtypes.int32], [(2,)])[0] self.assertEqual(result.get_shape(), (3, 2)) + @combinations.generate( + combinations.combine(tf_api_version=[1], mode=["graph"])) def testMapDefunPartialShapeInference(self): @function.defun(input_signature=[tensor_spec.TensorSpec([2], dtypes.int32)]) @@ -126,6 +142,8 @@ class MapDefunTest(test_base.DatasetTestBase): result = map_defun.map_defun(fn, [elems], [dtypes.int32], [(2,)]) self.assertEqual(result[0].get_shape().as_list(), [None, 2]) + @combinations.generate( + combinations.combine(tf_api_version=[1], mode=["graph"])) def testMapDefunRaisesErrorOnRuntimeShapeMismatch(self): @function.defun(input_signature=[ @@ -145,6 +163,8 @@ class MapDefunTest(test_base.DatasetTestBase): "All inputs must have the same dimension 0."): sess.run(result, feed_dict={elems1: [1, 2, 3, 4, 5], elems2: [1, 2, 3]}) + @combinations.generate( + combinations.combine(tf_api_version=[1], mode=["graph"])) def testMapDefunRaisesDefunError(self): @function.defun(input_signature=[tensor_spec.TensorSpec([], dtypes.int32)]) @@ -157,6 +177,8 @@ class MapDefunTest(test_base.DatasetTestBase): with self.assertRaises(errors.InvalidArgumentError): self.evaluate(result) + @combinations.generate( + combinations.combine(tf_api_version=[1], mode=["graph"])) def testMapDefunCancelledCorrectly(self): @function.defun(input_signature=[tensor_spec.TensorSpec([5], dtypes.int64)]) @@ -173,6 +195,8 @@ class MapDefunTest(test_base.DatasetTestBase): r"indices = 10 is not in \[0, 5\)"): self.evaluate(map_defun_op) + @combinations.generate( + combinations.combine(tf_api_version=[1], mode=["graph"])) def testMapDefunWithUnspecifiedOutputShape(self): @function.defun(input_signature=[tensor_spec.TensorSpec([2], dtypes.int32)]) @@ -190,6 +214,8 @@ class MapDefunTest(test_base.DatasetTestBase): self.assertAllEqual(self.evaluate(r[1]), self.evaluate(expected + 1)) self.assertAllEqual(self.evaluate(r[2]), self.evaluate(expected + 2)) + @combinations.generate( + combinations.combine(tf_api_version=[1], mode=["graph"])) def testMapDefunWithDifferentOutputShapeEachRun(self): @function.defun( @@ -204,6 +230,8 @@ class MapDefunTest(test_base.DatasetTestBase): self.assertAllEqual( sess.run(r, feed_dict={elems: [[0], [1]]}), [[3], [5]]) + @combinations.generate( + combinations.combine(tf_api_version=[1], mode=["graph"])) def testMapDefunWithWrongOutputShape(self): @function.defun(input_signature=[tensor_spec.TensorSpec([2], dtypes.int32)]) @@ -216,6 +244,8 @@ class MapDefunTest(test_base.DatasetTestBase): with self.assertRaises(errors.InvalidArgumentError): self.evaluate(r) + @combinations.generate( + combinations.combine(tf_api_version=[1], mode=["graph"])) def testMapDefunWithInvalidInput(self): @function.defun( @@ -233,6 +263,8 @@ class MapDefunTest(test_base.DatasetTestBase): with self.assertRaises(errors.InvalidArgumentError): sess.run(r, feed_dict={p: 0}) + @combinations.generate( + combinations.combine(tf_api_version=[1], mode=["graph"])) def testMapDefunWithParentCancellation(self): # Checks that a cancellation of the parent graph is threaded through to # MapDefunOp correctly. @@ -254,6 +286,8 @@ class MapDefunTest(test_base.DatasetTestBase): sess.close() thread.join() + @combinations.generate( + combinations.combine(tf_api_version=[1], mode=["graph"])) def testMapDefunWithCapturedInputs(self): c = constant_op.constant(2) @@ -266,6 +300,8 @@ class MapDefunTest(test_base.DatasetTestBase): expected = x + c self.assertAllEqual(self.evaluate(expected), self.evaluate(map_defun_op)) + @combinations.generate( + combinations.combine(tf_api_version=[1], mode=["graph"])) def testMapDefunWithVariantTensor(self): @function.defun( @@ -288,6 +324,8 @@ class MapDefunTest(test_base.DatasetTestBase): actual = self.evaluate(deserialized) self.assertValuesEqual(expected, actual) + @combinations.generate( + combinations.combine(tf_api_version=[1], mode=["graph"])) def testMapDefunWithVariantTensorAsCaptured(self): st = sparse_tensor.SparseTensor( @@ -309,6 +347,8 @@ class MapDefunTest(test_base.DatasetTestBase): actual = self.evaluate(deserialized) self.assertValuesEqual(expected, actual) + @combinations.generate( + combinations.combine(tf_api_version=[1], mode=["graph"])) def testMapDefunWithStrTensor(self): @function.defun(input_signature=[tensor_spec.TensorSpec([], dtypes.string)]) diff --git a/tensorflow/python/data/experimental/kernel_tests/override_threadpool_test.py b/tensorflow/python/data/experimental/kernel_tests/override_threadpool_test.py index 811a58262ef..d7944042c6e 100644 --- a/tensorflow/python/data/experimental/kernel_tests/override_threadpool_test.py +++ b/tensorflow/python/data/experimental/kernel_tests/override_threadpool_test.py @@ -28,14 +28,13 @@ from tensorflow.python.data.experimental.ops import threadpool from tensorflow.python.data.experimental.ops import unique from tensorflow.python.data.kernel_tests import test_base from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.framework import combinations from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors -from tensorflow.python.framework import test_util from tensorflow.python.ops import script_ops from tensorflow.python.platform import test -@test_util.run_all_in_graph_and_eager_modes class OverrideThreadpoolTest(test_base.DatasetTestBase, parameterized.TestCase): @@ -70,17 +69,13 @@ class OverrideThreadpoolTest(test_base.DatasetTestBase, # perform work. self.assertLessEqual(len(thread_ids), num_threads) - @parameterized.named_parameters( - ("1", 1, None), - ("2", 2, None), - ("3", 4, None), - ("4", 8, None), - ("5", 16, None), - ("6", 4, -1), - ("7", 4, 0), - ("8", 4, 1), - ("9", 4, 4), - ) + @combinations.generate( + combinations.times( + test_base.default_test_combinations(), + combinations.combine( + num_threads=[1, 2, 4, 8, 16], max_intra_op_parallelism=[None]) + + combinations.combine( + num_threads=[4], max_intra_op_parallelism=[-1, 0, 4]))) def testNumThreadsDeprecated(self, num_threads, max_intra_op_parallelism): def override_threadpool_fn(dataset): @@ -93,20 +88,17 @@ class OverrideThreadpoolTest(test_base.DatasetTestBase, self._testNumThreadsHelper(num_threads, override_threadpool_fn) - @parameterized.named_parameters( - ("1", 1, None), - ("2", 2, None), - ("3", 4, None), - ("4", 8, None), - ("5", 16, None), - ("6", None, 0), - ("7", None, 1), - ("8", None, 4), - ("9", 4, 0), - ("10", 4, 1), - ("11", 4, 4), - ("12", None, None), - ) + @combinations.generate( + combinations.times( + test_base.default_test_combinations(), + combinations.combine( + num_threads=[1, 2, 4, 8, 16], max_intra_op_parallelism=[None]) + + combinations.combine( + num_threads=[None], max_intra_op_parallelism=[0, 1, 4]) + + combinations.combine( + num_threads=[4], max_intra_op_parallelism=[0, 1, 4]) + + combinations.combine( + num_threads=[None], max_intra_op_parallelism=[None]))) def testNumThreads(self, num_threads, max_intra_op_parallelism): def override_threadpool_fn(dataset): @@ -121,6 +113,7 @@ class OverrideThreadpoolTest(test_base.DatasetTestBase, self._testNumThreadsHelper(num_threads, override_threadpool_fn) + @combinations.generate(test_base.default_test_combinations()) def testMaxIntraOpParallelismAsGraphDefInternal(self): dataset = dataset_ops.Dataset.from_tensors(0) dataset = dataset_ops._MaxIntraOpParallelismDataset(dataset, 1) diff --git a/tensorflow/python/data/experimental/kernel_tests/parallel_interleave_test.py b/tensorflow/python/data/experimental/kernel_tests/parallel_interleave_test.py index 1fe5655ec02..14d3c9d6d7f 100644 --- a/tensorflow/python/data/experimental/kernel_tests/parallel_interleave_test.py +++ b/tensorflow/python/data/experimental/kernel_tests/parallel_interleave_test.py @@ -22,24 +22,25 @@ import math import threading import time +from absl.testing import parameterized import numpy as np from six.moves import zip_longest from tensorflow.python.data.experimental.ops import interleave_ops from tensorflow.python.data.kernel_tests import test_base from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.framework import combinations from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors from tensorflow.python.framework import sparse_tensor -from tensorflow.python.framework import test_util from tensorflow.python.ops import math_ops from tensorflow.python.ops import script_ops from tensorflow.python.ops import sparse_ops from tensorflow.python.platform import test -@test_util.run_all_in_graph_and_eager_modes -class ParallelInterleaveTest(test_base.DatasetTestBase): +# TODO(feihugis): refactor this test to be parameterized. +class ParallelInterleaveTest(test_base.DatasetTestBase, parameterized.TestCase): def setUp(self): @@ -116,6 +117,7 @@ class ParallelInterleaveTest(test_base.DatasetTestBase): num_open -= 1 break + @combinations.generate(test_base.default_test_combinations()) def testPythonImplementation(self): input_lists = [[4, 4, 4, 4], [5, 5, 5, 5, 5], [6, 6, 6, 6, 6, 6], [4, 4, 4, 4], [5, 5, 5, 5, 5], [6, 6, 6, 6, 6, 6]] @@ -136,6 +138,7 @@ class ParallelInterleaveTest(test_base.DatasetTestBase): self.assertEqual(expected, produced, "Values differ at %s. %s != %s" % (index, expected, produced)) + @combinations.generate(test_base.default_test_combinations()) def testPythonImplementationBlockLength(self): input_lists = [[4] * 4, [5] * 5, [6] * 6] * 2 expected_elements = [ @@ -147,6 +150,7 @@ class ParallelInterleaveTest(test_base.DatasetTestBase): self.assertEqual(expected, produced, "Values differ at %s. %s != %s" % (index, expected, produced)) + @combinations.generate(test_base.default_test_combinations()) def testPythonImplementationEmptyLists(self): input_lists = [[4, 4, 4, 4], [], [6, 6, 6, 6, 6, 6], [4, 4, 4, 4], [], [6, 6, 6, 6, 6, 6]] @@ -189,18 +193,23 @@ class ParallelInterleaveTest(test_base.DatasetTestBase): with self.assertRaises(errors.OutOfRangeError): self.evaluate(next_element()) + @combinations.generate(test_base.default_test_combinations()) def testSingleThreaded(self): self._testSingleThreaded() + @combinations.generate(test_base.default_test_combinations()) def testSingleThreadedSloppy(self): self._testSingleThreaded(sloppy=True) + @combinations.generate(test_base.default_test_combinations()) def testSingleThreadedPrefetch1Itr(self): self._testSingleThreaded(prefetch_input_elements=1) + @combinations.generate(test_base.default_test_combinations()) def testSingleThreadedPrefetch1ItrSloppy(self): self._testSingleThreaded(prefetch_input_elements=1, sloppy=True) + @combinations.generate(test_base.default_test_combinations()) def testSingleThreadedRagged(self): # Tests a sequence with wildly different elements per iterator. self.skipTest("b/131722904") @@ -259,9 +268,11 @@ class ParallelInterleaveTest(test_base.DatasetTestBase): with self.assertRaises(errors.OutOfRangeError): self.evaluate(next_element()) + @combinations.generate(test_base.default_test_combinations()) def testTwoThreadsNoContention(self): self._testTwoThreadsNoContention() + @combinations.generate(test_base.default_test_combinations()) def testTwoThreadsNoContentionSloppy(self): self._testTwoThreadsNoContention(sloppy=True) @@ -306,9 +317,11 @@ class ParallelInterleaveTest(test_base.DatasetTestBase): with self.assertRaises(errors.OutOfRangeError): self.evaluate(next_element()) + @combinations.generate(test_base.default_test_combinations()) def testTwoThreadsNoContentionWithRaces(self): self._testTwoThreadsNoContentionWithRaces() + @combinations.generate(test_base.default_test_combinations()) def testTwoThreadsNoContentionWithRacesSloppy(self): self._testTwoThreadsNoContentionWithRaces(sloppy=True) @@ -343,9 +356,11 @@ class ParallelInterleaveTest(test_base.DatasetTestBase): with self.assertRaises(errors.OutOfRangeError): self.evaluate(next_element()) + @combinations.generate(test_base.default_test_combinations()) def testTwoThreadsNoContentionBlockLength(self): self._testTwoThreadsNoContentionBlockLength() + @combinations.generate(test_base.default_test_combinations()) def testTwoThreadsNoContentionBlockLengthSloppy(self): self._testTwoThreadsNoContentionBlockLength(sloppy=True) @@ -391,9 +406,11 @@ class ParallelInterleaveTest(test_base.DatasetTestBase): with self.assertRaises(errors.OutOfRangeError): self.evaluate(next_element()) + @combinations.generate(test_base.default_test_combinations()) def testTwoThreadsNoContentionWithRacesAndBlocking(self): self._testTwoThreadsNoContentionWithRacesAndBlocking() + @combinations.generate(test_base.default_test_combinations()) def testTwoThreadsNoContentionWithRacesAndBlockingSloppy(self): self._testTwoThreadsNoContentionWithRacesAndBlocking(sloppy=True) @@ -411,9 +428,11 @@ class ParallelInterleaveTest(test_base.DatasetTestBase): with self.assertRaises(errors.OutOfRangeError): self.evaluate(next_element()) + @combinations.generate(test_base.default_test_combinations()) def testEmptyInput(self): self._testEmptyInput() + @combinations.generate(test_base.default_test_combinations()) def testEmptyInputSloppy(self): self._testEmptyInput(sloppy=True) @@ -431,9 +450,11 @@ class ParallelInterleaveTest(test_base.DatasetTestBase): with self.assertRaises(errors.OutOfRangeError): self.evaluate(next_element()) + @combinations.generate(test_base.default_test_combinations()) def testNonEmptyInputIntoEmptyOutputs(self): self._testNonEmptyInputIntoEmptyOutputs() + @combinations.generate(test_base.default_test_combinations()) def testNonEmptyInputIntoEmptyOutputsSloppy(self): self._testNonEmptyInputIntoEmptyOutputs(sloppy=True) @@ -469,12 +490,15 @@ class ParallelInterleaveTest(test_base.DatasetTestBase): "At index %s: %s expected, got: %s" % (i, expected_element, actual_element)) + @combinations.generate(test_base.default_test_combinations()) def testPartiallyEmptyOutputs(self): self._testPartiallyEmptyOutputs() + @combinations.generate(test_base.default_test_combinations()) def testPartiallyEmptyOutputsSloppy(self): self._testPartiallyEmptyOutputs(sloppy=True, prefetch_input_elements=0) + @combinations.generate(test_base.default_test_combinations()) def testDelayedOutputSloppy(self): # Explicitly control the sequence of events to ensure we correctly avoid # head-of-line blocking. @@ -500,6 +524,7 @@ class ParallelInterleaveTest(test_base.DatasetTestBase): with self.assertRaises(errors.OutOfRangeError): self.evaluate(next_element()) + @combinations.generate(test_base.default_test_combinations()) def testBlockLengthWithContentionSloppy(self): self.skipTest("b/131722904") self._clear_coordination_events() @@ -557,9 +582,11 @@ class ParallelInterleaveTest(test_base.DatasetTestBase): self.read_coordination_events[i].acquire() self.write_coordination_events[i].set() + @combinations.generate(test_base.default_test_combinations()) def testEarlyExit(self): self._testEarlyExit() + @combinations.generate(test_base.default_test_combinations()) def testEarlyExitSloppy(self): self._testEarlyExit(sloppy=True) @@ -584,12 +611,15 @@ class ParallelInterleaveTest(test_base.DatasetTestBase): [[4] * 4, [5] * 5, [6] * 6] * self.repeat_count, 1, 2) self.assertItemsEqual(output_values, expected_values) + @combinations.generate(test_base.default_test_combinations()) def testTooManyReaders(self): self._testTooManyReaders() + @combinations.generate(test_base.default_test_combinations()) def testTooManyReadersSloppy(self): self._testTooManyReaders(sloppy=True) + @combinations.generate(test_base.default_test_combinations()) def testSparse(self): def _map_fn(i): return sparse_tensor.SparseTensor( @@ -610,6 +640,7 @@ class ParallelInterleaveTest(test_base.DatasetTestBase): with self.assertRaises(errors.OutOfRangeError): self.evaluate(get_next()) + @combinations.generate(test_base.default_test_combinations()) def testErrorsInOutputFn(self): self.skipTest("b/131722904") self._clear_coordination_events() @@ -642,6 +673,7 @@ class ParallelInterleaveTest(test_base.DatasetTestBase): with self.assertRaises(errors.OutOfRangeError): self.evaluate(next_element()) + @combinations.generate(test_base.default_test_combinations()) def testErrorsInInputFn(self): def map_py_fn(x): @@ -687,6 +719,7 @@ class ParallelInterleaveTest(test_base.DatasetTestBase): with self.assertRaises(errors.OutOfRangeError): self.evaluate(next_element()) + @combinations.generate(test_base.default_test_combinations()) def testErrorsInInterleaveFn(self): def map_py_fn(x): @@ -730,6 +763,7 @@ class ParallelInterleaveTest(test_base.DatasetTestBase): with self.assertRaises(errors.OutOfRangeError): self.evaluate(next_element()) + @combinations.generate(test_base.default_test_combinations()) def testShutdownRace(self): dataset = dataset_ops.Dataset.range(20) map_fn = lambda x: dataset_ops.Dataset.range(20 * x, 20 * (x + 1)) diff --git a/tensorflow/python/data/experimental/kernel_tests/parse_example_dataset_test.py b/tensorflow/python/data/experimental/kernel_tests/parse_example_dataset_test.py index 794f72365df..58cba64617d 100644 --- a/tensorflow/python/data/experimental/kernel_tests/parse_example_dataset_test.py +++ b/tensorflow/python/data/experimental/kernel_tests/parse_example_dataset_test.py @@ -20,6 +20,7 @@ from __future__ import print_function import copy +from absl.testing import parameterized import numpy as np from tensorflow.core.example import example_pb2 @@ -28,11 +29,11 @@ from tensorflow.python.data.experimental.ops import parsing_ops as contrib_parsi from tensorflow.python.data.kernel_tests import test_base from tensorflow.python.data.ops import dataset_ops from tensorflow.python.eager import context +from tensorflow.python.framework import combinations from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors_impl from tensorflow.python.framework import ops from tensorflow.python.framework import sparse_tensor -from tensorflow.python.framework import test_util from tensorflow.python.ops import parsing_ops from tensorflow.python.ops.ragged import ragged_factory_ops from tensorflow.python.platform import test @@ -50,8 +51,8 @@ feature_lists = lambda d: feature_pb2.FeatureLists(feature_list=d) sequence_example = example_pb2.SequenceExample -@test_util.run_all_in_graph_and_eager_modes -class ParseExampleDatasetTest(test_base.DatasetTestBase): +class ParseExampleDatasetTest(test_base.DatasetTestBase, + parameterized.TestCase): def _compare_output_to_expected(self, dict_tensors, expected_tensors): self.assertEqual(set(dict_tensors.keys()), set(expected_tensors.keys())) @@ -107,6 +108,7 @@ class ParseExampleDatasetTest(test_base.DatasetTestBase): self.assertEqual( dataset_ops.get_legacy_output_shapes(dataset)[k].as_list()[1], None) + @combinations.generate(test_base.default_test_combinations()) def testEmptySerializedWithAllDefaults(self): sparse_name = "st_a" a_name = "a" @@ -145,7 +147,7 @@ class ParseExampleDatasetTest(test_base.DatasetTestBase): expected_values=expected_output, create_iterator_twice=True) - @test_util.run_deprecated_v1 + @combinations.generate(test_base.graph_only_combinations()) def testEmptySerializedWithoutDefaultsShouldFail(self): input_features = { "st_a": @@ -179,7 +181,7 @@ class ParseExampleDatasetTest(test_base.DatasetTestBase): expected_err=(errors_impl.InvalidArgumentError, "Feature: c \\(data type: float\\) is required")) - @test_util.run_deprecated_v1 + @combinations.generate(test_base.graph_only_combinations()) def testDenseNotMatchingShapeShouldFail(self): original = [ example(features=features({ @@ -197,6 +199,7 @@ class ParseExampleDatasetTest(test_base.DatasetTestBase): expected_err=(errors_impl.InvalidArgumentError, "Key: a, Index: 1. Number of float values")) + @combinations.generate(test_base.default_test_combinations()) def testDenseDefaultNoShapeShouldFail(self): original = [example(features=features({"a": float_feature([1, 1, 3]),})),] @@ -207,6 +210,7 @@ class ParseExampleDatasetTest(test_base.DatasetTestBase): {"a": parsing_ops.FixedLenFeature(None, dtypes.float32)}, expected_err=(ValueError, "Missing shape for feature a")) + @combinations.generate(test_base.default_test_combinations()) def testSerializedContainingSparse(self): original = [ example(features=features({ @@ -248,6 +252,7 @@ class ParseExampleDatasetTest(test_base.DatasetTestBase): expected_values=expected_output, create_iterator_twice=True) + @combinations.generate(test_base.default_test_combinations()) def testSerializedContainingSparseFeature(self): original = [ example(features=features({ @@ -284,6 +289,7 @@ class ParseExampleDatasetTest(test_base.DatasetTestBase): expected_values=expected_output, create_iterator_twice=True) + @combinations.generate(test_base.default_test_combinations()) def testSerializedContainingSparseFeatureReuse(self): original = [ example(features=features({ @@ -325,6 +331,7 @@ class ParseExampleDatasetTest(test_base.DatasetTestBase): expected_values=expected_output, create_iterator_twice=True) + @combinations.generate(test_base.default_test_combinations()) def testSerializedContaining3DSparseFeature(self): original = [ example(features=features({ @@ -370,6 +377,7 @@ class ParseExampleDatasetTest(test_base.DatasetTestBase): expected_values=expected_output, create_iterator_twice=True) + @combinations.generate(test_base.default_test_combinations()) def testSerializedContainingDense(self): aname = "a" bname = "b*has+a:tricky_name" @@ -407,6 +415,7 @@ class ParseExampleDatasetTest(test_base.DatasetTestBase): # This test is identical as the previous one except # for the creation of 'serialized'. + @combinations.generate(test_base.default_test_combinations()) def testSerializedContainingDenseWithConcat(self): aname = "a" bname = "b*has+a:tricky_name" @@ -452,6 +461,7 @@ class ParseExampleDatasetTest(test_base.DatasetTestBase): expected_values=expected_output, create_iterator_twice=True) + @combinations.generate(test_base.default_test_combinations()) def testSerializedContainingDenseScalar(self): original = [ example(features=features({ @@ -476,6 +486,7 @@ class ParseExampleDatasetTest(test_base.DatasetTestBase): expected_values=expected_output, create_iterator_twice=True) + @combinations.generate(test_base.default_test_combinations()) def testSerializedContainingDenseWithDefaults(self): original = [ example(features=features({ @@ -514,6 +525,7 @@ class ParseExampleDatasetTest(test_base.DatasetTestBase): expected_values=expected_output, create_iterator_twice=True) + @combinations.generate(test_base.default_test_combinations()) def testSerializedSparseAndSparseFeatureAndDenseWithNoDefault(self): expected_st_a = sparse_tensor.SparseTensorValue( # indices, values, shape np.empty((0, 2), dtype=np.int64), # indices @@ -569,6 +581,7 @@ class ParseExampleDatasetTest(test_base.DatasetTestBase): expected_values=expected_output, create_iterator_twice=True) + @combinations.generate(test_base.default_test_combinations()) def testerializedContainingSparseAndSparseFeatureWithReuse(self): expected_idx = sparse_tensor.SparseTensorValue( # indices, values, shape np.array([[0, 0], [0, 1], [1, 0], [1, 1]], dtype=np.int64), @@ -667,11 +680,13 @@ class ParseExampleDatasetTest(test_base.DatasetTestBase): expected_values=expected_output, create_iterator_twice=True) + @combinations.generate(test_base.default_test_combinations()) def testSerializedContainingVarLenDenseLargerBatch(self): np.random.seed(3456) for batch_size in (1, 10, 20, 100, 256): self._testSerializedContainingVarLenDenseLargerBatch(batch_size) + @combinations.generate(test_base.default_test_combinations()) def testSerializedShapeMismatch(self): aname = "a" bname = "b" @@ -724,7 +739,7 @@ class ParseExampleDatasetTest(test_base.DatasetTestBase): expected_err=(ValueError, "Cannot reshape a tensor with 0 elements to shape")) - @test_util.run_deprecated_v1 + @combinations.generate(test_base.graph_only_combinations()) def testSerializedContainingVarLenDense(self): aname = "a" bname = "b" @@ -877,6 +892,7 @@ class ParseExampleDatasetTest(test_base.DatasetTestBase): "Unsupported: FixedLenSequenceFeature requires " "allow_missing to be True.")) + @combinations.generate(test_base.default_test_combinations()) def testSerializedContainingRaggedFeatureWithNoPartitions(self): original = [ example( @@ -922,6 +938,7 @@ class ParseExampleDatasetTest(test_base.DatasetTestBase): expected_values=expected_output, create_iterator_twice=True) + @combinations.generate(test_base.default_test_combinations()) def testSerializedContainingRaggedFeatureWithOnePartition(self): original = [ example( @@ -1040,6 +1057,7 @@ class ParseExampleDatasetTest(test_base.DatasetTestBase): expected_values=expected_output, create_iterator_twice=True) + @combinations.generate(test_base.default_test_combinations()) def testSerializedContainingRaggedFeatureWithMultiplePartitions(self): original = [ # rt shape: [(batch), 2, None, None] diff --git a/tensorflow/python/data/experimental/kernel_tests/prefetch_to_device_test.py b/tensorflow/python/data/experimental/kernel_tests/prefetch_to_device_test.py index f51da6e8b66..8ac4e239881 100644 --- a/tensorflow/python/data/experimental/kernel_tests/prefetch_to_device_test.py +++ b/tensorflow/python/data/experimental/kernel_tests/prefetch_to_device_test.py @@ -17,11 +17,14 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +from absl.testing import parameterized + from tensorflow.core.protobuf import config_pb2 from tensorflow.python.data.experimental.ops import prefetching_ops from tensorflow.python.data.kernel_tests import test_base from tensorflow.python.data.ops import dataset_ops from tensorflow.python.data.util import structure +from tensorflow.python.framework import combinations from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors from tensorflow.python.framework import ops @@ -31,9 +34,9 @@ from tensorflow.python.platform import test # TODO(b/117581999): add eager coverage when supported. -class PrefetchToDeviceTest(test_base.DatasetTestBase): +class PrefetchToDeviceTest(test_base.DatasetTestBase, parameterized.TestCase): - @test_util.deprecated_graph_mode_only + @combinations.generate(test_base.graph_only_combinations()) def testPrefetchToDevice(self): host_dataset = dataset_ops.Dataset.range(10) device_dataset = host_dataset.apply( @@ -57,7 +60,7 @@ class PrefetchToDeviceTest(test_base.DatasetTestBase): with self.assertRaises(errors.OutOfRangeError): self.evaluate(next_element) - @test_util.deprecated_graph_mode_only + @combinations.generate(test_base.graph_only_combinations()) def testPrefetchToSameDevice(self): host_dataset = dataset_ops.Dataset.range(10) device_dataset = host_dataset.apply( @@ -82,7 +85,7 @@ class PrefetchToDeviceTest(test_base.DatasetTestBase): with self.assertRaises(errors.OutOfRangeError): self.evaluate(next_element) - @test_util.deprecated_graph_mode_only + @combinations.generate(test_base.graph_only_combinations()) def testPrefetchDictToDevice(self): host_dataset = dataset_ops.Dataset.range(10).map(lambda x: {"a": x}) device_dataset = host_dataset.apply( @@ -106,7 +109,7 @@ class PrefetchToDeviceTest(test_base.DatasetTestBase): with self.assertRaises(errors.OutOfRangeError): self.evaluate(next_element) - @test_util.deprecated_graph_mode_only + @combinations.generate(test_base.graph_only_combinations()) def testPrefetchSparseTensorsToDevice(self): def make_tensor(i): return sparse_tensor.SparseTensorValue( @@ -136,7 +139,7 @@ class PrefetchToDeviceTest(test_base.DatasetTestBase): with self.assertRaises(errors.OutOfRangeError): self.evaluate(next_element) - @test_util.deprecated_graph_mode_only + @combinations.generate(test_base.graph_only_combinations()) def testPrefetchToDeviceGpu(self): if not test_util.is_gpu_available(): self.skipTest("No GPU available") @@ -156,7 +159,7 @@ class PrefetchToDeviceTest(test_base.DatasetTestBase): with self.assertRaises(errors.OutOfRangeError): self.evaluate(next_element) - @test_util.deprecated_graph_mode_only + @combinations.generate(test_base.graph_only_combinations()) def testPrefetchToDeviceWithReInit(self): host_dataset = dataset_ops.Dataset.range(10) device_dataset = host_dataset.apply( @@ -184,7 +187,7 @@ class PrefetchToDeviceTest(test_base.DatasetTestBase): with self.assertRaises(errors.OutOfRangeError): self.evaluate(next_element) - @test_util.deprecated_graph_mode_only + @combinations.generate(test_base.graph_only_combinations()) def testPrefetchToDeviceGpuWithReInit(self): if not test_util.is_gpu_available(): self.skipTest("No GPU available") diff --git a/tensorflow/python/data/experimental/kernel_tests/prefetch_with_slack_test.py b/tensorflow/python/data/experimental/kernel_tests/prefetch_with_slack_test.py index abc9eb5f0ad..ff1f1680a76 100644 --- a/tensorflow/python/data/experimental/kernel_tests/prefetch_with_slack_test.py +++ b/tensorflow/python/data/experimental/kernel_tests/prefetch_with_slack_test.py @@ -24,16 +24,17 @@ from tensorflow.core.protobuf import config_pb2 from tensorflow.python.data.kernel_tests import test_base from tensorflow.python.data.ops import dataset_ops from tensorflow.python.data.ops import multi_device_iterator_ops +from tensorflow.python.framework import combinations from tensorflow.python.framework import errors from tensorflow.python.framework import ops -from tensorflow.python.framework import test_util from tensorflow.python.platform import test -@test_util.run_all_in_graph_and_eager_modes class PrefetchWithSlackTest(test_base.DatasetTestBase, parameterized.TestCase): - @test_util.run_v1_only("b/121264236") + # TODO(b/121264236) + @combinations.generate( + combinations.combine(tf_api_version=[1], mode=["graph"])) def testPrefetchWithSlackOption(self): """Determines slack_period based on num devices attached to iterator.""" dataset = dataset_ops.Dataset.range(10) @@ -60,6 +61,7 @@ class PrefetchWithSlackTest(test_base.DatasetTestBase, parameterized.TestCase): self.evaluate(elem_on_1) self.evaluate(elem_on_2) + @combinations.generate(test_base.default_test_combinations()) def testPrefetchWithSlackOptionWithoutIterator(self): """Defaults to slack period of 1 without iterator.""" dataset = dataset_ops.Dataset.range(10) @@ -72,6 +74,7 @@ class PrefetchWithSlackTest(test_base.DatasetTestBase, parameterized.TestCase): dataset.options()._graph_rewrite_configs()) self.assertDatasetProduces(dataset, range(10)) + @combinations.generate(test_base.default_test_combinations()) def testWithPassthroughDataset(self): """Should still work with a passthrough dataset after prefetch().""" dataset = dataset_ops.Dataset.range(10) @@ -82,6 +85,7 @@ class PrefetchWithSlackTest(test_base.DatasetTestBase, parameterized.TestCase): dataset = dataset.with_options(options) self.assertDatasetProduces(dataset, range(1, 11)) + @combinations.generate(test_base.default_test_combinations()) def testErrorWithoutPrefetch(self): """The rewrite fails if there is no prefetch() in the pipeline.""" dataset = dataset_ops.Dataset.range(10) @@ -92,6 +96,7 @@ class PrefetchWithSlackTest(test_base.DatasetTestBase, parameterized.TestCase): get_next = self.getNext(dataset) self.evaluate(get_next()) + @combinations.generate(test_base.default_test_combinations()) def testErrorWithInvalidDataset(self): """With a nested dataset op after prefetch, the rewrite should fail.""" dataset = dataset_ops.Dataset.range(10) diff --git a/tensorflow/python/data/experimental/kernel_tests/rebatch_dataset_test.py b/tensorflow/python/data/experimental/kernel_tests/rebatch_dataset_test.py index 32bcdbe183b..30496658529 100644 --- a/tensorflow/python/data/experimental/kernel_tests/rebatch_dataset_test.py +++ b/tensorflow/python/data/experimental/kernel_tests/rebatch_dataset_test.py @@ -32,8 +32,8 @@ from tensorflow.python.data.experimental.ops import scan_ops from tensorflow.python.data.kernel_tests import test_base from tensorflow.python.data.ops import dataset_ops from tensorflow.python.data.util import nest +from tensorflow.python.framework import combinations from tensorflow.python.framework import dtypes -from tensorflow.python.framework import test_util from tensorflow.python.lib.io import python_io from tensorflow.python.ops import array_ops from tensorflow.python.ops import math_ops @@ -47,13 +47,11 @@ def _flat_shapes(dataset): return nest.flatten(dataset_ops.get_legacy_output_shapes(dataset)) -@test_util.run_all_in_graph_and_eager_modes class RebatchDatasetTest(test_base.DatasetTestBase, parameterized.TestCase): - drop_remainder_cases = [("WithDropRemainder", True), - ("WithoutDropRemainder", False)] - - @parameterized.named_parameters(drop_remainder_cases) + @combinations.generate( + combinations.times(test_base.default_test_combinations(), + combinations.combine(drop_remainder=[True, False]))) def testBasic(self, drop_remainder): dataset = dataset_ops.Dataset.range(1024).batch( 32, drop_remainder=drop_remainder) @@ -64,13 +62,16 @@ class RebatchDatasetTest(test_base.DatasetTestBase, parameterized.TestCase): expected_output = [[k for k in range(i, i + 8)] for i in range(0, 1024, 8)] # pylint: disable=g-complex-comprehension self.assertDatasetProduces(rebatched_dataset, expected_output) + @combinations.generate(test_base.default_test_combinations()) def testScalarInputError(self): dataset = dataset_ops.Dataset.range(1024) distribute._RebatchDataset(dataset.batch(4), num_replicas=4) with self.assertRaisesRegexp(ValueError, "at least one dimension"): distribute._RebatchDataset(dataset, num_replicas=4) - @parameterized.named_parameters(drop_remainder_cases) + @combinations.generate( + combinations.times(test_base.default_test_combinations(), + combinations.combine(drop_remainder=[True, False]))) def testBatchNotDivisibleByNumReplicas(self, drop_remainder): dataset = dataset_ops.Dataset.range(1024).batch( 32, drop_remainder=drop_remainder) @@ -89,6 +90,7 @@ class RebatchDatasetTest(test_base.DatasetTestBase, parameterized.TestCase): i += 4 self.assertDatasetProduces(rebatched_dataset, expected_output) + @combinations.generate(test_base.default_test_combinations()) def testBatchSizeNotDivisibleByNumReplicas2(self): dataset = dataset_ops.Dataset.range(32).batch(16, drop_remainder=True) rebatched_dataset = distribute._RebatchDataset(dataset, num_replicas=5) @@ -102,6 +104,7 @@ class RebatchDatasetTest(test_base.DatasetTestBase, parameterized.TestCase): expected_output.extend([[]]) # Last replica gets an empty batch self.assertDatasetProduces(rebatched_dataset, expected_output) + @combinations.generate(test_base.default_test_combinations()) def testTupleOutput(self): dataset = dataset_ops.Dataset.range(1024).map(lambda x: (x, x)).batch(32) rebatched_dataset = distribute._RebatchDataset(dataset, num_replicas=4) @@ -110,6 +113,7 @@ class RebatchDatasetTest(test_base.DatasetTestBase, parameterized.TestCase): for i in range(0, 1024, 8)] self.assertDatasetProduces(rebatched_dataset, expected_output) + @combinations.generate(test_base.default_test_combinations()) def testNestedDictionaryOutput(self): dataset = dataset_ops.Dataset.range(1024).map( lambda x: {"a": x, "b": {"c": x}}).batch(32) @@ -119,7 +123,9 @@ class RebatchDatasetTest(test_base.DatasetTestBase, parameterized.TestCase): for i in range(0, 1024, 8)] self.assertDatasetProduces(rebatched_dataset, expected_output) - @parameterized.named_parameters(drop_remainder_cases) + @combinations.generate( + combinations.times(test_base.default_test_combinations(), + combinations.combine(drop_remainder=[True, False]))) def testFinalPartialBatch(self, drop_remainder): dataset = dataset_ops.Dataset.range(1032).batch( 32, drop_remainder=drop_remainder) @@ -136,7 +142,9 @@ class RebatchDatasetTest(test_base.DatasetTestBase, parameterized.TestCase): [[k for k in range(i, i + 2)] for i in range(1024, 1032, 2)]) self.assertDatasetProduces(rebatched_dataset, expected_output) - @parameterized.named_parameters(drop_remainder_cases) + @combinations.generate( + combinations.times(test_base.default_test_combinations(), + combinations.combine(drop_remainder=[True, False]))) def testFinalPartialBatchAfterRebatch(self, drop_remainder): dataset = dataset_ops.Dataset.range(34).batch( 32, drop_remainder=drop_remainder) @@ -150,6 +158,7 @@ class RebatchDatasetTest(test_base.DatasetTestBase, parameterized.TestCase): expected_output += [[32], [33], [], []] self.assertDatasetProduces(rebatched_dataset, expected_output) + @combinations.generate(test_base.default_test_combinations()) def testMultipleBatches(self): dataset = dataset_ops.Dataset.range(128).batch(4).batch(8) self.assertEqual([[None, None]], @@ -170,6 +179,7 @@ class RebatchDatasetTest(test_base.DatasetTestBase, parameterized.TestCase): for i in range(0, 128, 8)] self.assertDatasetProduces(rebatched_dataset, expected_output) + @combinations.generate(test_base.default_test_combinations()) def testMapAndBatch(self): dataset = dataset_ops.Dataset.range(1024).apply( batching.map_and_batch(math_ops.square, 32)) @@ -180,6 +190,7 @@ class RebatchDatasetTest(test_base.DatasetTestBase, parameterized.TestCase): for i in range(0, 1024, 8)] self.assertDatasetProduces(rebatched_dataset, expected_output) + @combinations.generate(test_base.default_test_combinations()) def testMapAndBatchWithCapturedInput(self): captured_t = variables.Variable(42) dataset = dataset_ops.Dataset.range(1024).apply( @@ -193,6 +204,7 @@ class RebatchDatasetTest(test_base.DatasetTestBase, parameterized.TestCase): self.assertDatasetProduces( rebatched_dataset, expected_output, requires_initialization=True) + @combinations.generate(test_base.default_test_combinations()) def testPaddedBatch(self): dataset = dataset_ops.Dataset.range(128).batch( 4, drop_remainder=True).padded_batch( @@ -213,6 +225,7 @@ class RebatchDatasetTest(test_base.DatasetTestBase, parameterized.TestCase): for i in range(0, 128, 8)] self.assertDatasetProduces(rebatched_dataset, expected_output) + @combinations.generate(test_base.default_test_combinations()) def testConcatenate(self): dataset1 = dataset_ops.Dataset.range(64).batch(8) dataset2 = dataset_ops.Dataset.range(32).batch(8) @@ -224,6 +237,7 @@ class RebatchDatasetTest(test_base.DatasetTestBase, parameterized.TestCase): [[i, i + 1] for i in range(0, 32, 2)]) self.assertDatasetProduces(rebatched_dataset, expected_output) + @combinations.generate(test_base.default_test_combinations()) def testConcatenateDifferentShapes(self): dataset1 = dataset_ops.Dataset.range(64).batch(16) dataset2 = dataset_ops.Dataset.range(32).batch(8) @@ -235,6 +249,7 @@ class RebatchDatasetTest(test_base.DatasetTestBase, parameterized.TestCase): [[i, i + 1] for i in range(0, 32, 2)]) self.assertDatasetProduces(rebatched_dataset, expected_output) + @combinations.generate(test_base.default_test_combinations()) def testZip(self): dataset1 = dataset_ops.Dataset.range(64).batch(8) dataset2 = dataset_ops.Dataset.range(32).batch(8) @@ -245,6 +260,7 @@ class RebatchDatasetTest(test_base.DatasetTestBase, parameterized.TestCase): expected_output = [([i, i + 1], [i, i + 1]) for i in range(0, 32, 2)] self.assertDatasetProduces(rebatched_dataset, expected_output) + @combinations.generate(test_base.default_test_combinations()) def testZipDifferentShapes(self): dataset1 = dataset_ops.Dataset.range(64).batch(16) dataset2 = dataset_ops.Dataset.range(32).batch(8) @@ -256,6 +272,7 @@ class RebatchDatasetTest(test_base.DatasetTestBase, parameterized.TestCase): for i in range(0, 32, 2)] self.assertDatasetProduces(rebatched_dataset, expected_output) + @combinations.generate(test_base.default_test_combinations()) def testFlatMapBatching(self): dataset = dataset_ops.Dataset.range(2).flat_map( lambda _: dataset_ops.Dataset.range(32).batch( # pylint: disable=g-long-lambda @@ -274,6 +291,7 @@ class RebatchDatasetTest(test_base.DatasetTestBase, parameterized.TestCase): for i in range(0, 32, 8)] # generates 4 elements self.assertDatasetProduces(rebatched_dataset, expected_output) + @combinations.generate(test_base.default_test_combinations()) def testInterleaveBatching(self): dataset = dataset_ops.Dataset.range(2).interleave( lambda _: dataset_ops.Dataset.range(32).batch( # pylint: disable=g-long-lambda @@ -290,6 +308,7 @@ class RebatchDatasetTest(test_base.DatasetTestBase, parameterized.TestCase): expected_output += expected_output self.assertDatasetProduces(rebatched_dataset, expected_output) + @combinations.generate(test_base.default_test_combinations()) def testParallelInterleaveBatching(self): dataset = dataset_ops.Dataset.range(2).interleave( lambda _: dataset_ops.Dataset.range(32).batch( # pylint: disable=g-long-lambda @@ -307,6 +326,7 @@ class RebatchDatasetTest(test_base.DatasetTestBase, parameterized.TestCase): expected_output += expected_output self.assertDatasetProduces(rebatched_dataset, expected_output) + @combinations.generate(test_base.default_test_combinations()) def testGroupByWindowStaticBatch(self): dataset = dataset_ops.Dataset.from_tensor_slices( [[array_ops.constant(i, dtype=dtypes.int64)] * 3 for i in range(40)]) @@ -326,6 +346,7 @@ class RebatchDatasetTest(test_base.DatasetTestBase, parameterized.TestCase): for k in range(2)] self.assertDatasetProduces(rebatched_dataset, expected_output) + @combinations.generate(test_base.default_test_combinations()) def testGroupByWindowDynamicBatch(self): # {0, 1, 0, 1, ...} dataset = dataset_ops.Dataset.range(40).map(lambda x: x % 2) @@ -350,6 +371,7 @@ class RebatchDatasetTest(test_base.DatasetTestBase, parameterized.TestCase): expected_output = [[value] * batch_size for batch_size, value in pairs] self.assertDatasetProduces(dataset, expected_output) + @combinations.generate(test_base.default_test_combinations()) def testGroupByWindowDynamicBatchWithPartialBatch(self): # {0, 1, 0, 1, ...} dataset = dataset_ops.Dataset.range(40).map(lambda x: x % 2) @@ -371,6 +393,7 @@ class RebatchDatasetTest(test_base.DatasetTestBase, parameterized.TestCase): expected_output = [[value] * batch_size for batch_size, value in pairs] self.assertDatasetProduces(dataset, expected_output) + @combinations.generate(test_base.default_test_combinations()) def testGroupByWindowDynamicBatchWithPartialBatchWithDropRemainder(self): # This test exercises nested batch functionality, dynamic batch size # and drop_remainder=True together. @@ -395,6 +418,7 @@ class RebatchDatasetTest(test_base.DatasetTestBase, parameterized.TestCase): expected_output = [[value] * batch_size for batch_size, value in pairs] self.assertDatasetProduces(dataset, expected_output) + @combinations.generate(test_base.default_test_combinations()) def testScanAfterBatch(self): dataset = dataset_ops.Dataset.range(40).batch(10).apply( scan_ops.scan(np.int64(2), lambda state, value: (state, value * state))) @@ -405,6 +429,7 @@ class RebatchDatasetTest(test_base.DatasetTestBase, parameterized.TestCase): expected_output = [[i * 2 for i in range(j*5, (j+1)*5)] for j in range(8)] # pylint: disable=g-complex-comprehension self.assertDatasetProduces(dataset, expected_output) + @combinations.generate(test_base.default_test_combinations()) def testMakeBatchedFeaturesDataset(self): # Set up fn = os.path.join(self.get_temp_dir(), "tf_record.txt") @@ -438,6 +463,7 @@ class RebatchDatasetTest(test_base.DatasetTestBase, parameterized.TestCase): } for i in range(0, 1024, 8)] # pylint: disable=g-complex-comprehension self.assertDatasetProduces(rebatched_dataset, expected_output) + @combinations.generate(test_base.default_test_combinations()) def testRaggedTensorDataset(self): # Set up a dataset that produces ragged tensors with a static batch size. row_lengths = np.random.randint(8, size=128) diff --git a/tensorflow/python/data/experimental/kernel_tests/rejection_resample_test.py b/tensorflow/python/data/experimental/kernel_tests/rejection_resample_test.py index 673e77fc3bb..fb1d4ea5d3a 100644 --- a/tensorflow/python/data/experimental/kernel_tests/rejection_resample_test.py +++ b/tensorflow/python/data/experimental/kernel_tests/rejection_resample_test.py @@ -24,9 +24,9 @@ import numpy as np from tensorflow.python.data.experimental.ops import resampling from tensorflow.python.data.kernel_tests import test_base from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.framework import combinations from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors -from tensorflow.python.framework import test_util from tensorflow.python.ops import math_ops from tensorflow.python.ops import random_ops from tensorflow.python.ops import string_ops @@ -34,12 +34,11 @@ from tensorflow.python.platform import test from tensorflow.python.util import compat -@test_util.run_all_in_graph_and_eager_modes class RejectionResampleTest(test_base.DatasetTestBase, parameterized.TestCase): - @parameterized.named_parameters( - ("InitialDistributionKnown", True), - ("InitialDistributionUnknown", False)) + @combinations.generate( + combinations.times(test_base.default_test_combinations(), + combinations.combine(initial_known=[True, False]))) def testDistribution(self, initial_known): classes = np.random.randint(5, size=(20000,)) # Uniformly sampled target_dist = [0.9, 0.05, 0.05, 0.0, 0.0] @@ -72,9 +71,9 @@ class RejectionResampleTest(test_base.DatasetTestBase, parameterized.TestCase): returned_dist = class_counts / total_returned self.assertAllClose(target_dist, returned_dist, atol=1e-2) - @parameterized.named_parameters( - ("OnlyInitial", True), - ("NotInitial", False)) + @combinations.generate( + combinations.times(test_base.default_test_combinations(), + combinations.combine(only_initial_dist=[True, False]))) def testEdgeCasesSampleFromInitialDataset(self, only_initial_dist): init_dist = [0.5, 0.5] target_dist = [0.5, 0.5] if only_initial_dist else [0.0, 1.0] @@ -99,6 +98,7 @@ class RejectionResampleTest(test_base.DatasetTestBase, parameterized.TestCase): while True: returned.append(self.evaluate(get_next())) + @combinations.generate(test_base.default_test_combinations()) def testRandomClasses(self): init_dist = [0.25, 0.25, 0.25, 0.25] target_dist = [0.0, 0.0, 0.0, 1.0] diff --git a/tensorflow/python/data/experimental/kernel_tests/shuffle_and_repeat_test.py b/tensorflow/python/data/experimental/kernel_tests/shuffle_and_repeat_test.py index 92ae528b940..8bb109a6519 100644 --- a/tensorflow/python/data/experimental/kernel_tests/shuffle_and_repeat_test.py +++ b/tensorflow/python/data/experimental/kernel_tests/shuffle_and_repeat_test.py @@ -17,18 +17,18 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +from absl.testing import parameterized import numpy as np from tensorflow.python.data.experimental.ops import shuffle_ops from tensorflow.python.data.kernel_tests import test_base from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.framework import combinations from tensorflow.python.framework import errors -from tensorflow.python.framework import test_util from tensorflow.python.platform import test -@test_util.run_all_in_graph_and_eager_modes -class ShuffleAndRepeatTest(test_base.DatasetTestBase): +class ShuffleAndRepeatTest(test_base.DatasetTestBase, parameterized.TestCase): def _build_ds(self, seed, count=5, num_elements=20): return dataset_ops.Dataset.range(num_elements).apply( @@ -44,6 +44,7 @@ class ShuffleAndRepeatTest(test_base.DatasetTestBase): self.evaluate(get_next()) return outputs + @combinations.generate(test_base.default_test_combinations()) def testCorrectOutput(self): output = self._gen_outputs(lambda: self._build_ds(10), 100) self.assertSequenceEqual( @@ -52,6 +53,7 @@ class ShuffleAndRepeatTest(test_base.DatasetTestBase): for i in range(5): self.assertSequenceEqual(sorted(output[i * 20:(i + 1) * 20]), range(20)) + @combinations.generate(test_base.default_test_combinations()) def testReshuffling(self): # Check that the output orders of different epochs are indeed different. output = self._gen_outputs(lambda: self._build_ds(10), 100) @@ -60,17 +62,20 @@ class ShuffleAndRepeatTest(test_base.DatasetTestBase): epoch2 = output[(i + 1) * 20:(i + 2) * 20] self.assertNotEqual(epoch1, epoch2) + @combinations.generate(test_base.default_test_combinations()) def testSameOrderForSameSeeds(self): output1 = self._gen_outputs(lambda: self._build_ds(10), 100) output2 = self._gen_outputs(lambda: self._build_ds(10), 100) self.assertEqual(output1, output2) + @combinations.generate(test_base.default_test_combinations()) def testDifferentOrderForDifferentSeeds(self): output1 = self._gen_outputs(lambda: self._build_ds(10), 100) output2 = self._gen_outputs(lambda: self._build_ds(20), 100) self.assertNotEqual(output1, output2) self.assertEqual(sorted(output1), sorted(output2)) + @combinations.generate(test_base.default_test_combinations()) def testCountNone(self): output1 = self._gen_outputs( lambda: self._build_ds(10, count=None), 100, verify_exhausted=False) @@ -79,6 +84,7 @@ class ShuffleAndRepeatTest(test_base.DatasetTestBase): self.assertNotEqual(output1, output2) self.assertEqual(sorted(output1), sorted(output2)) + @combinations.generate(test_base.default_test_combinations()) def testCountMinusOne(self): output1 = self._gen_outputs( lambda: self._build_ds(10, count=-1), 100, verify_exhausted=False) @@ -87,6 +93,7 @@ class ShuffleAndRepeatTest(test_base.DatasetTestBase): self.assertNotEqual(output1, output2) self.assertEqual(sorted(output1), sorted(output2)) + @combinations.generate(test_base.default_test_combinations()) def testInfiniteOutputs(self): # Asserting the iterator is exhausted after producing 100 items should fail. with self.assertRaises(AssertionError): @@ -94,6 +101,7 @@ class ShuffleAndRepeatTest(test_base.DatasetTestBase): with self.assertRaises(AssertionError): self._gen_outputs(lambda: self._build_ds(10, count=-1), 100) + @combinations.generate(test_base.default_test_combinations()) def testInfiniteEmpty(self): with self.assertRaises(errors.OutOfRangeError): self._gen_outputs(lambda: self._build_ds(10, count=None, num_elements=0), @@ -102,12 +110,14 @@ class ShuffleAndRepeatTest(test_base.DatasetTestBase): self._gen_outputs(lambda: self._build_ds(10, count=-1, num_elements=0), 100) + @combinations.generate(test_base.default_test_combinations()) def testLargeBufferSize(self): ds = dataset_ops.Dataset.range(20).apply( shuffle_ops.shuffle_and_repeat(buffer_size=21)) get_next = self.getNext(ds) self.evaluate(get_next()) + @combinations.generate(test_base.default_test_combinations()) def testVeryLargeBufferSize(self): num_epochs = 1000 * 1000 # Each element being shuffled and repeated has shape (100,). This will OOM diff --git a/tensorflow/python/data/experimental/kernel_tests/sql_dataset_test.py b/tensorflow/python/data/experimental/kernel_tests/sql_dataset_test.py index f55f62f5cb0..8e1dd4bd8dc 100644 --- a/tensorflow/python/data/experimental/kernel_tests/sql_dataset_test.py +++ b/tensorflow/python/data/experimental/kernel_tests/sql_dataset_test.py @@ -18,18 +18,22 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +from absl.testing import parameterized + from tensorflow.python.data.experimental.kernel_tests import sql_dataset_test_base +from tensorflow.python.data.kernel_tests import test_base +from tensorflow.python.framework import combinations from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors -from tensorflow.python.framework import test_util from tensorflow.python.ops import array_ops from tensorflow.python.platform import test -@test_util.run_all_in_graph_and_eager_modes -class SqlDatasetTest(sql_dataset_test_base.SqlDatasetTestBase): +class SqlDatasetTest(sql_dataset_test_base.SqlDatasetTestBase, + parameterized.TestCase): # Test that SqlDataset can read from a database table. + @combinations.generate(test_base.default_test_combinations()) def testReadResultSet(self): for _ in range(2): # Run twice to verify statelessness of db operations. dataset = self._createSqlDataset( @@ -44,6 +48,7 @@ class SqlDatasetTest(sql_dataset_test_base.SqlDatasetTestBase): num_test_iterations=2) # Test that SqlDataset works on a join query. + @combinations.generate(test_base.default_test_combinations()) def testReadResultSetJoinQuery(self): get_next = self.getNext( self._createSqlDataset( @@ -60,6 +65,7 @@ class SqlDatasetTest(sql_dataset_test_base.SqlDatasetTestBase): # Test that SqlDataset can read a database entry with a null-terminator # in the middle of the text and place the entry in a `string` tensor. + @combinations.generate(test_base.default_test_combinations()) def testReadResultSetNullTerminator(self): get_next = self.getNext( self._createSqlDataset( @@ -76,6 +82,7 @@ class SqlDatasetTest(sql_dataset_test_base.SqlDatasetTestBase): # Test that SqlDataset works when used on two different queries. # Because the output types of the dataset must be determined at graph-creation # time, the two queries must have the same number and types of columns. + @combinations.generate(test_base.default_test_combinations()) def testReadResultSetReuseSqlDataset(self): get_next = self.getNext( self._createSqlDataset( @@ -100,6 +107,7 @@ class SqlDatasetTest(sql_dataset_test_base.SqlDatasetTestBase): # Test that an `OutOfRangeError` is raised on the first call to # `get_next_str_only` if result set is empty. + @combinations.generate(test_base.default_test_combinations()) def testReadEmptyResultSet(self): get_next = self.getNext( self._createSqlDataset( @@ -110,6 +118,7 @@ class SqlDatasetTest(sql_dataset_test_base.SqlDatasetTestBase): self.evaluate(get_next()) # Test that an error is raised when `driver_name` is invalid. + @combinations.generate(test_base.default_test_combinations()) def testReadResultSetWithInvalidDriverName(self): with self.assertRaises(errors.InvalidArgumentError): dataset = self._createSqlDataset( @@ -120,6 +129,7 @@ class SqlDatasetTest(sql_dataset_test_base.SqlDatasetTestBase): self.assertDatasetProduces(dataset, expected_output=[]) # Test that an error is raised when a column name in `query` is nonexistent + @combinations.generate(test_base.default_test_combinations()) def testReadResultSetWithInvalidColumnName(self): get_next = self.getNext( self._createSqlDataset( @@ -130,6 +140,7 @@ class SqlDatasetTest(sql_dataset_test_base.SqlDatasetTestBase): self.evaluate(get_next()) # Test that an error is raised when there is a syntax error in `query`. + @combinations.generate(test_base.default_test_combinations()) def testReadResultSetOfQueryWithSyntaxError(self): get_next = self.getNext( self._createSqlDataset( @@ -141,6 +152,7 @@ class SqlDatasetTest(sql_dataset_test_base.SqlDatasetTestBase): # Test that an error is raised when the number of columns in `query` # does not match the length of `, output_types`. + @combinations.generate(test_base.default_test_combinations()) def testReadResultSetWithMismatchBetweenColumnsAndOutputTypes(self): get_next = self.getNext( self._createSqlDataset( @@ -154,6 +166,7 @@ class SqlDatasetTest(sql_dataset_test_base.SqlDatasetTestBase): # than a select query. In particular, the error refers to the number of # output types passed to the op not matching the number of columns in the # result set of the query (namely, 0 for an insert statement.) + @combinations.generate(test_base.default_test_combinations()) def testReadResultSetOfInsertQuery(self): get_next = self.getNext( self._createSqlDataset( @@ -165,6 +178,7 @@ class SqlDatasetTest(sql_dataset_test_base.SqlDatasetTestBase): # Test that `SqlDataset` can read an integer from a SQLite database table and # place it in an `int8` tensor. + @combinations.generate(test_base.default_test_combinations()) def testReadResultSetInt8(self): get_next = self.getNext( self._createSqlDataset( @@ -178,6 +192,7 @@ class SqlDatasetTest(sql_dataset_test_base.SqlDatasetTestBase): # Test that `SqlDataset` can read a negative or 0-valued integer from a # SQLite database table and place it in an `int8` tensor. + @combinations.generate(test_base.default_test_combinations()) def testReadResultSetInt8NegativeAndZero(self): get_next = self.getNext( self._createSqlDataset( @@ -191,6 +206,7 @@ class SqlDatasetTest(sql_dataset_test_base.SqlDatasetTestBase): # Test that `SqlDataset` can read a large (positive or negative) integer from # a SQLite database table and place it in an `int8` tensor. + @combinations.generate(test_base.default_test_combinations()) def testReadResultSetInt8MaxValues(self): get_next = self.getNext( self._createSqlDataset( @@ -205,6 +221,7 @@ class SqlDatasetTest(sql_dataset_test_base.SqlDatasetTestBase): # Test that `SqlDataset` can read an integer from a SQLite database table and # place it in an `int16` tensor. + @combinations.generate(test_base.default_test_combinations()) def testReadResultSetInt16(self): get_next = self.getNext( self._createSqlDataset( @@ -218,6 +235,7 @@ class SqlDatasetTest(sql_dataset_test_base.SqlDatasetTestBase): # Test that `SqlDataset` can read a negative or 0-valued integer from a # SQLite database table and place it in an `int16` tensor. + @combinations.generate(test_base.default_test_combinations()) def testReadResultSetInt16NegativeAndZero(self): get_next = self.getNext( self._createSqlDataset( @@ -231,6 +249,7 @@ class SqlDatasetTest(sql_dataset_test_base.SqlDatasetTestBase): # Test that `SqlDataset` can read a large (positive or negative) integer from # a SQLite database table and place it in an `int16` tensor. + @combinations.generate(test_base.default_test_combinations()) def testReadResultSetInt16MaxValues(self): get_next = self.getNext( self._createSqlDataset( @@ -246,6 +265,7 @@ class SqlDatasetTest(sql_dataset_test_base.SqlDatasetTestBase): # Test that `SqlDataset` can read an integer from a SQLite database table and # place it in an `int32` tensor. + @combinations.generate(test_base.default_test_combinations()) def testReadResultSetInt32(self): get_next = self.getNext( self._createSqlDataset( @@ -257,6 +277,7 @@ class SqlDatasetTest(sql_dataset_test_base.SqlDatasetTestBase): # Test that `SqlDataset` can read a negative or 0-valued integer from a # SQLite database table and place it in an `int32` tensor. + @combinations.generate(test_base.default_test_combinations()) def testReadResultSetInt32NegativeAndZero(self): get_next = self.getNext( self._createSqlDataset( @@ -270,6 +291,7 @@ class SqlDatasetTest(sql_dataset_test_base.SqlDatasetTestBase): # Test that `SqlDataset` can read a large (positive or negative) integer from # a SQLite database table and place it in an `int32` tensor. + @combinations.generate(test_base.default_test_combinations()) def testReadResultSetInt32MaxValues(self): get_next = self.getNext( self._createSqlDataset( @@ -285,6 +307,7 @@ class SqlDatasetTest(sql_dataset_test_base.SqlDatasetTestBase): # Test that `SqlDataset` can read a numeric `varchar` from a SQLite database # table and place it in an `int32` tensor. + @combinations.generate(test_base.default_test_combinations()) def testReadResultSetInt32VarCharColumnAsInt(self): get_next = self.getNext( self._createSqlDataset( @@ -298,6 +321,7 @@ class SqlDatasetTest(sql_dataset_test_base.SqlDatasetTestBase): # Test that `SqlDataset` can read an integer from a SQLite database table # and place it in an `int64` tensor. + @combinations.generate(test_base.default_test_combinations()) def testReadResultSetInt64(self): get_next = self.getNext( self._createSqlDataset( @@ -311,6 +335,7 @@ class SqlDatasetTest(sql_dataset_test_base.SqlDatasetTestBase): # Test that `SqlDataset` can read a negative or 0-valued integer from a # SQLite database table and place it in an `int64` tensor. + @combinations.generate(test_base.default_test_combinations()) def testReadResultSetInt64NegativeAndZero(self): get_next = self.getNext( self._createSqlDataset( @@ -324,6 +349,7 @@ class SqlDatasetTest(sql_dataset_test_base.SqlDatasetTestBase): # Test that `SqlDataset` can read a large (positive or negative) integer from # a SQLite database table and place it in an `int64` tensor. + @combinations.generate(test_base.default_test_combinations()) def testReadResultSetInt64MaxValues(self): get_next = self.getNext( self._createSqlDataset( @@ -339,6 +365,7 @@ class SqlDatasetTest(sql_dataset_test_base.SqlDatasetTestBase): # Test that `SqlDataset` can read an integer from a SQLite database table and # place it in a `uint8` tensor. + @combinations.generate(test_base.default_test_combinations()) def testReadResultSetUInt8(self): get_next = self.getNext( self._createSqlDataset( @@ -352,6 +379,7 @@ class SqlDatasetTest(sql_dataset_test_base.SqlDatasetTestBase): # Test that `SqlDataset` can read the minimum and maximum uint8 values from a # SQLite database table and place them in `uint8` tensors. + @combinations.generate(test_base.default_test_combinations()) def testReadResultSetUInt8MinAndMaxValues(self): get_next = self.getNext( self._createSqlDataset( @@ -367,6 +395,7 @@ class SqlDatasetTest(sql_dataset_test_base.SqlDatasetTestBase): # Test that `SqlDataset` can read an integer from a SQLite database table # and place it in a `uint16` tensor. + @combinations.generate(test_base.default_test_combinations()) def testReadResultSetUInt16(self): get_next = self.getNext( self._createSqlDataset( @@ -380,6 +409,7 @@ class SqlDatasetTest(sql_dataset_test_base.SqlDatasetTestBase): # Test that `SqlDataset` can read the minimum and maximum uint16 values from a # SQLite database table and place them in `uint16` tensors. + @combinations.generate(test_base.default_test_combinations()) def testReadResultSetUInt16MinAndMaxValues(self): get_next = self.getNext( self._createSqlDataset( @@ -396,6 +426,7 @@ class SqlDatasetTest(sql_dataset_test_base.SqlDatasetTestBase): # Test that `SqlDataset` can read a 0-valued and 1-valued integer from a # SQLite database table and place them as `True` and `False` respectively # in `bool` tensors. + @combinations.generate(test_base.default_test_combinations()) def testReadResultSetBool(self): get_next = self.getNext( self._createSqlDataset( @@ -409,6 +440,7 @@ class SqlDatasetTest(sql_dataset_test_base.SqlDatasetTestBase): # Test that `SqlDataset` can read an integer that is not 0-valued or 1-valued # from a SQLite database table and place it as `True` in a `bool` tensor. + @combinations.generate(test_base.default_test_combinations()) def testReadResultSetBoolNotZeroOrOne(self): get_next = self.getNext( self._createSqlDataset( @@ -422,6 +454,7 @@ class SqlDatasetTest(sql_dataset_test_base.SqlDatasetTestBase): # Test that `SqlDataset` can read a float from a SQLite database table # and place it in a `float64` tensor. + @combinations.generate(test_base.default_test_combinations()) def testReadResultSetFloat64(self): get_next = self.getNext( self._createSqlDataset( @@ -437,6 +470,7 @@ class SqlDatasetTest(sql_dataset_test_base.SqlDatasetTestBase): # Test that `SqlDataset` can read a float from a SQLite database table beyond # the precision of 64-bit IEEE, without throwing an error. Test that # `SqlDataset` identifies such a value as equal to itself. + @combinations.generate(test_base.default_test_combinations()) def testReadResultSetFloat64OverlyPrecise(self): get_next = self.getNext( self._createSqlDataset( @@ -458,6 +492,7 @@ class SqlDatasetTest(sql_dataset_test_base.SqlDatasetTestBase): # representing the largest integer representable as a 64-bit IEEE float # such that the previous integer is also representable as a 64-bit IEEE float. # Test that `SqlDataset` can distinguish these two numbers. + @combinations.generate(test_base.default_test_combinations()) def testReadResultSetFloat64LargestConsecutiveWholeNumbersNotEqual(self): get_next = self.getNext( self._createSqlDataset( @@ -472,6 +507,7 @@ class SqlDatasetTest(sql_dataset_test_base.SqlDatasetTestBase): self.evaluate(get_next()) # Test that SqlDataset can stop correctly when combined with batch + @combinations.generate(test_base.default_test_combinations()) def testReadResultSetWithBatchStop(self): dataset = self._createSqlDataset( query="SELECT * FROM data", output_types=(dtypes.int32)) diff --git a/tensorflow/python/data/experimental/kernel_tests/stats_dataset_ops_test.py b/tensorflow/python/data/experimental/kernel_tests/stats_dataset_ops_test.py index 4f04a0a3639..f77f2f21bf7 100644 --- a/tensorflow/python/data/experimental/kernel_tests/stats_dataset_ops_test.py +++ b/tensorflow/python/data/experimental/kernel_tests/stats_dataset_ops_test.py @@ -17,6 +17,7 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +from absl.testing import parameterized import numpy as np from tensorflow.python.data.experimental.kernel_tests import reader_dataset_ops_test_base @@ -24,7 +25,9 @@ from tensorflow.python.data.experimental.kernel_tests import stats_dataset_test_ from tensorflow.python.data.experimental.ops import batching from tensorflow.python.data.experimental.ops import stats_aggregator from tensorflow.python.data.experimental.ops import stats_ops +from tensorflow.python.data.kernel_tests import test_base from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.framework import combinations from tensorflow.python.framework import errors from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops @@ -32,8 +35,11 @@ from tensorflow.python.ops import math_ops from tensorflow.python.platform import test -class StatsDatasetTest(stats_dataset_test_base.StatsDatasetTestBase): +# TODO(jsimsa): Figure out why are graph tests failing. +class StatsDatasetTest(stats_dataset_test_base.StatsDatasetTestBase, + parameterized.TestCase): + @combinations.generate(test_base.eager_only_combinations()) def testBytesProduced(self): aggregator = stats_aggregator.StatsAggregator() dataset = dataset_ops.Dataset.range(100).map( @@ -57,6 +63,7 @@ class StatsDatasetTest(stats_dataset_test_base.StatsDatasetTestBase): self.assertStatisticsHasCount(handle, "bytes_produced", 100.0, 101) self.assertStatisticsHasSum(handle, "bytes_produced", expected_sum, 101) + @combinations.generate(test_base.eager_only_combinations()) def testLatencyStats(self): aggregator = stats_aggregator.StatsAggregator() dataset = dataset_ops.Dataset.range(100).apply( @@ -76,6 +83,7 @@ class StatsDatasetTest(stats_dataset_test_base.StatsDatasetTestBase): handle = self.getHandle(aggregator) self.assertStatisticsHasCount(handle, "record_latency", 100.0, 101) + @combinations.generate(test_base.eager_only_combinations()) def testPrefetchBufferUtilization(self): aggregator = stats_aggregator.StatsAggregator() dataset = dataset_ops.Dataset.range(100).map( @@ -117,6 +125,7 @@ class StatsDatasetTest(stats_dataset_test_base.StatsDatasetTestBase): 301, offset=2) + @combinations.generate(test_base.eager_only_combinations()) def testPrefetchBufferScalars(self): aggregator = stats_aggregator.StatsAggregator() dataset = dataset_ops.Dataset.range(10).map( @@ -140,6 +149,7 @@ class StatsDatasetTest(stats_dataset_test_base.StatsDatasetTestBase): with self.assertRaises(errors.OutOfRangeError): self.evaluate(next_element()) + @combinations.generate(test_base.eager_only_combinations()) def testFilteredElementsStats(self): aggregator = stats_aggregator.StatsAggregator() dataset = dataset_ops.Dataset.range(101).filter( @@ -167,6 +177,7 @@ class StatsDatasetTest(stats_dataset_test_base.StatsDatasetTestBase): handle, self.regexForNodeName("FilterDataset", "filtered_elements"), 34.0) + @combinations.generate(test_base.eager_only_combinations()) def testReinitialize(self): aggregator = stats_aggregator.StatsAggregator() dataset = dataset_ops.Dataset.range(100).apply( @@ -187,6 +198,7 @@ class StatsDatasetTest(stats_dataset_test_base.StatsDatasetTestBase): self.assertStatisticsHasCount(handle, "record_latency", (j + 1) * 100.0, (j * 100) + 101) + @combinations.generate(test_base.eager_only_combinations()) def testNoAggregatorRegistered(self): dataset = dataset_ops.Dataset.range(100).apply( stats_ops.latency_stats("record_latency")) @@ -198,6 +210,7 @@ class StatsDatasetTest(stats_dataset_test_base.StatsDatasetTestBase): with self.assertRaises(errors.OutOfRangeError): self.evaluate(next_element()) + @combinations.generate(test_base.eager_only_combinations()) def testMultipleTags(self): aggregator = stats_aggregator.StatsAggregator() dataset = dataset_ops.Dataset.range(100).apply( @@ -221,6 +234,7 @@ class StatsDatasetTest(stats_dataset_test_base.StatsDatasetTestBase): handle, "record_latency", 100.0, 201, offset=1) self.assertStatisticsHasCount(handle, "record_latency_2", 100.0, 201) + @combinations.generate(test_base.eager_only_combinations()) def testRepeatedTags(self): aggregator = stats_aggregator.StatsAggregator() dataset = dataset_ops.Dataset.range(100).apply( @@ -239,6 +253,7 @@ class StatsDatasetTest(stats_dataset_test_base.StatsDatasetTestBase): handle = self.getHandle(aggregator) self.assertStatisticsHasCount(handle, "record_latency", 200.0, 201) + @combinations.generate(test_base.eager_only_combinations()) def testMultipleIteratorsSameAggregator(self): aggregator = stats_aggregator.StatsAggregator() dataset = dataset_ops.Dataset.range(100).apply( @@ -259,6 +274,7 @@ class StatsDatasetTest(stats_dataset_test_base.StatsDatasetTestBase): handle = self.getHandle(aggregator) self.assertStatisticsHasCount(handle, "record_latency", 200.0, 201) + @combinations.generate(test_base.eager_only_combinations()) def testMultipleDatasetWithPrefixes(self): aggregator = stats_aggregator.StatsAggregator() dataset = dataset_ops.Dataset.range(100).apply( @@ -289,6 +305,7 @@ class StatsDatasetTest(stats_dataset_test_base.StatsDatasetTestBase): self.assertStatisticsHasCount(handle, "dataset2::record_latency", 100.0, 201) + @combinations.generate(test_base.eager_only_combinations()) def testMultiplePrefetchStats(self): aggregator = stats_aggregator.StatsAggregator() @@ -314,8 +331,10 @@ class StatsDatasetTest(stats_dataset_test_base.StatsDatasetTestBase): self.evaluate(next_element()) -class ThreadUtilizationStatsTest(stats_dataset_test_base.StatsDatasetTestBase): +class ThreadUtilizationStatsTest(stats_dataset_test_base.StatsDatasetTestBase, + parameterized.TestCase): + @combinations.generate(test_base.eager_only_combinations()) def testMapBufferUtilization(self): def dataset_fn(): @@ -326,6 +345,7 @@ class ThreadUtilizationStatsTest(stats_dataset_test_base.StatsDatasetTestBase): self.parallelCallsStats( dataset_fn, {"ParallelMapDataset"}, 10, function_processing_time=True) + @combinations.generate(test_base.eager_only_combinations()) def testMapAutoTuneBufferUtilization(self): def dataset_fn(): @@ -336,6 +356,7 @@ class ThreadUtilizationStatsTest(stats_dataset_test_base.StatsDatasetTestBase): self.parallelCallsStats( dataset_fn, {"ParallelMapDataset"}, 10, function_processing_time=True) + @combinations.generate(test_base.eager_only_combinations()) def testInterleaveAutoTuneBufferUtilization(self): def dataset_fn(): @@ -351,6 +372,7 @@ class ThreadUtilizationStatsTest(stats_dataset_test_base.StatsDatasetTestBase): self.parallelCallsStats(dataset_fn, {"ParallelInterleaveDatasetV2"}, 10) + @combinations.generate(test_base.eager_only_combinations()) def testMapAndBatchAutoTuneBufferUtilization(self): def dataset_fn(): @@ -370,8 +392,10 @@ class ThreadUtilizationStatsTest(stats_dataset_test_base.StatsDatasetTestBase): class FeatureStatsDatasetTest( stats_dataset_test_base.StatsDatasetTestBase, - reader_dataset_ops_test_base.MakeBatchedFeaturesDatasetTestBase): + reader_dataset_ops_test_base.MakeBatchedFeaturesDatasetTestBase, + parameterized.TestCase): + @combinations.generate(test_base.eager_only_combinations()) def testFeaturesStats(self): num_epochs = 5 total_records = num_epochs * self._num_records diff --git a/tensorflow/python/data/experimental/kernel_tests/take_while_test.py b/tensorflow/python/data/experimental/kernel_tests/take_while_test.py index b2b0effb0df..959837faa24 100644 --- a/tensorflow/python/data/experimental/kernel_tests/take_while_test.py +++ b/tensorflow/python/data/experimental/kernel_tests/take_while_test.py @@ -23,18 +23,21 @@ import numpy as np from tensorflow.python.data.experimental.ops import take_while_ops from tensorflow.python.data.kernel_tests import test_base from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.framework import combinations from tensorflow.python.framework import constant_op from tensorflow.python.framework import errors -from tensorflow.python.framework import test_util from tensorflow.python.ops import array_ops from tensorflow.python.ops import math_ops from tensorflow.python.platform import test -@test_util.run_all_in_graph_and_eager_modes class TakeWhileTest(test_base.DatasetTestBase, parameterized.TestCase): - @parameterized.parameters((14, 2), (15, 2), (100, 3)) + @combinations.generate( + combinations.times( + test_base.default_test_combinations(), + combinations.combine(num_elements=[14, 15], window_size=[2]) + + combinations.combine(num_elements=[100], window_size=[3]))) def testTakeWhileDataset(self, num_elements, window_size): def _predicate_func(elem): @@ -49,8 +52,19 @@ class TakeWhileTest(test_base.DatasetTestBase, parameterized.TestCase): expected_num_elements = int(num_elements / window_size) * window_size self.assertDatasetProduces(dataset, np.arange(expected_num_elements)) - @parameterized.parameters((10, 2, False), (16, 7, False), (100, 99, False), - (100, 101, True), (0, 1, True)) + @combinations.generate( + combinations.times( + test_base.default_test_combinations(), + combinations.combine( + num_elements=[10], upper_bound=[2], out_of_bounds=[False]) + + combinations.combine( + num_elements=[16], upper_bound=[7], out_of_bounds=[False]) + + combinations.combine( + num_elements=[100], upper_bound=[99], out_of_bounds=[False]) + + combinations.combine( + num_elements=[100], upper_bound=[101], out_of_bounds=[True]) + + combinations.combine( + num_elements=[0], upper_bound=[1], out_of_bounds=[True]))) def testTakeWhileDatasetRange(self, num_elements, upper_bound, out_of_bounds): dataset = dataset_ops.Dataset.range(num_elements).apply( take_while_ops.take_while(lambda x: x < upper_bound)) @@ -62,6 +76,7 @@ class TakeWhileTest(test_base.DatasetTestBase, parameterized.TestCase): else: self.assertDatasetProduces(dataset, np.arange(upper_bound)) + @combinations.generate(test_base.default_test_combinations()) def testTakeWhileDatasetString(self): def not_equal(string): @@ -79,7 +94,13 @@ class TakeWhileTest(test_base.DatasetTestBase, parameterized.TestCase): with self.assertRaises(errors.OutOfRangeError): self.assertEqual(b"test", self.evaluate(next_element())) - @parameterized.parameters((5, 3), (10, 0), (100, 5), (8, 7)) + @combinations.generate( + combinations.times( + test_base.default_test_combinations(), + combinations.combine(size=[5], index=[3]) + + combinations.combine(size=[10], index=[0]) + + combinations.combine(size=[100], index=[5]) + + combinations.combine(size=[8], index=[7]))) def testTakewhileDatasetShortCircuit(self, size, index): def _predicate_func(data_elem): @@ -98,6 +119,7 @@ class TakeWhileTest(test_base.DatasetTestBase, parameterized.TestCase): with self.assertRaises(errors.OutOfRangeError): self.evaluate(next_element()) + @combinations.generate(test_base.default_test_combinations()) def testTakeWhileDatasetWithRepeat(self): dataset = dataset_ops.Dataset.range(10).apply( take_while_ops.take_while(lambda x: x < 2)).repeat(5) diff --git a/tensorflow/python/data/experimental/kernel_tests/tf_record_writer_test.py b/tensorflow/python/data/experimental/kernel_tests/tf_record_writer_test.py index 136a446bbd8..a327fc82466 100644 --- a/tensorflow/python/data/experimental/kernel_tests/tf_record_writer_test.py +++ b/tensorflow/python/data/experimental/kernel_tests/tf_record_writer_test.py @@ -19,14 +19,16 @@ from __future__ import print_function import os +from absl.testing import parameterized + from tensorflow.python.data.experimental.ops import grouping from tensorflow.python.data.experimental.ops import writers from tensorflow.python.data.kernel_tests import test_base from tensorflow.python.data.ops import dataset_ops from tensorflow.python.data.ops import readers from tensorflow.python.eager import function +from tensorflow.python.framework import combinations from tensorflow.python.framework import dtypes -from tensorflow.python.framework import test_util from tensorflow.python.lib.io import python_io from tensorflow.python.lib.io import tf_record from tensorflow.python.ops import string_ops @@ -34,8 +36,7 @@ from tensorflow.python.platform import test from tensorflow.python.util import compat -@test_util.run_all_in_graph_and_eager_modes -class TFRecordWriterTest(test_base.DatasetTestBase): +class TFRecordWriterTest(test_base.DatasetTestBase, parameterized.TestCase): def setUp(self): super(TFRecordWriterTest, self).setUp() @@ -63,11 +64,13 @@ class TFRecordWriterTest(test_base.DatasetTestBase): def _outputFilename(self): return os.path.join(self.get_temp_dir(), "tf_record.out.txt") + @combinations.generate(test_base.default_test_combinations()) def testWrite(self): self.evaluate(self.writer_fn(self._createFile())) for i, r in enumerate(tf_record.tf_record_iterator(self._outputFilename())): self.assertAllEqual(self._record(i), r) + @combinations.generate(test_base.default_test_combinations()) def testWriteZLIB(self): options = tf_record.TFRecordOptions(tf_record.TFRecordCompressionType.ZLIB) self.evaluate( @@ -76,6 +79,7 @@ class TFRecordWriterTest(test_base.DatasetTestBase): tf_record.tf_record_iterator(self._outputFilename(), options=options)): self.assertAllEqual(self._record(i), r) + @combinations.generate(test_base.default_test_combinations()) def testWriteGZIP(self): options = tf_record.TFRecordOptions(tf_record.TFRecordCompressionType.GZIP) self.evaluate( @@ -84,20 +88,24 @@ class TFRecordWriterTest(test_base.DatasetTestBase): tf_record.tf_record_iterator(self._outputFilename(), options=options)): self.assertAllEqual(self._record(i), r) + @combinations.generate(test_base.default_test_combinations()) def testFailDataset(self): with self.assertRaises(TypeError): writers.TFRecordWriter(self._outputFilename(), "").write("whoops") + @combinations.generate(test_base.default_test_combinations()) def testFailDType(self): input_dataset = dataset_ops.Dataset.from_tensors(10) with self.assertRaises(TypeError): writers.TFRecordWriter(self._outputFilename(), "").write(input_dataset) + @combinations.generate(test_base.default_test_combinations()) def testFailShape(self): input_dataset = dataset_ops.Dataset.from_tensors([["hello"], ["world"]]) with self.assertRaises(TypeError): writers.TFRecordWriter(self._outputFilename(), "").write(input_dataset) + @combinations.generate(test_base.default_test_combinations()) def testSideEffect(self): def writer_fn(): input_dataset = readers.TFRecordDataset(self._createFile()) @@ -112,6 +120,7 @@ class TFRecordWriterTest(test_base.DatasetTestBase): for i, r in enumerate(tf_record.tf_record_iterator(self._outputFilename())): self.assertAllEqual(self._record(i), r) + @combinations.generate(test_base.default_test_combinations()) def testShard(self): filename = self._createFile() dataset = readers.TFRecordDataset([filename]) diff --git a/tensorflow/python/data/experimental/kernel_tests/unique_test.py b/tensorflow/python/data/experimental/kernel_tests/unique_test.py index 42d76a2eb30..9a51c4224ff 100644 --- a/tensorflow/python/data/experimental/kernel_tests/unique_test.py +++ b/tensorflow/python/data/experimental/kernel_tests/unique_test.py @@ -17,17 +17,18 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +from absl.testing import parameterized + from tensorflow.python.data.experimental.ops import unique from tensorflow.python.data.kernel_tests import test_base from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.framework import combinations from tensorflow.python.framework import dtypes -from tensorflow.python.framework import test_util from tensorflow.python.platform import test from tensorflow.python.util import compat -@test_util.run_all_in_graph_and_eager_modes -class UniqueTest(test_base.DatasetTestBase): +class UniqueTest(test_base.DatasetTestBase, parameterized.TestCase): def _testSimpleHelper(self, dtype, test_cases): """Test the `unique()` transformation on a list of test cases. @@ -52,7 +53,7 @@ class UniqueTest(test_base.DatasetTestBase): for element in expected ]) - @test_util.run_deprecated_v1 + @combinations.generate(test_base.graph_only_combinations()) def testSimpleInt(self): for dtype in [dtypes.int32, dtypes.int64]: self._testSimpleHelper(dtype, [ @@ -65,7 +66,7 @@ class UniqueTest(test_base.DatasetTestBase): ([[1, 1], [1, 1], [2, 2], [3, 3], [1, 1]], [[1, 1], [2, 2], [3, 3]]), ]) - @test_util.run_deprecated_v1 + @combinations.generate(test_base.graph_only_combinations()) def testSimpleString(self): self._testSimpleHelper(dtypes.string, [ ([], []), diff --git a/tensorflow/python/data/experimental/kernel_tests/variant_test.py b/tensorflow/python/data/experimental/kernel_tests/variant_test.py index 6a3a1424d12..897aa223371 100644 --- a/tensorflow/python/data/experimental/kernel_tests/variant_test.py +++ b/tensorflow/python/data/experimental/kernel_tests/variant_test.py @@ -17,16 +17,18 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +from absl.testing import parameterized + from tensorflow.python.data.experimental.ops import cardinality from tensorflow.python.data.kernel_tests import test_base from tensorflow.python.data.ops import dataset_ops -from tensorflow.python.framework import test_util +from tensorflow.python.framework import combinations from tensorflow.python.platform import test -@test_util.run_all_in_graph_and_eager_modes -class VariantTest(test_base.DatasetTestBase): +class VariantTest(test_base.DatasetTestBase, parameterized.TestCase): + @combinations.generate(test_base.default_test_combinations()) def testRoundtripRange(self): dataset = dataset_ops.Dataset.range(10) variant = dataset_ops.to_variant(dataset) @@ -35,6 +37,7 @@ class VariantTest(test_base.DatasetTestBase): self.assertDatasetProduces(dataset, range(10)) self.assertEqual(self.evaluate(cardinality.cardinality(dataset)), 10) + @combinations.generate(test_base.default_test_combinations()) def testRoundtripMap(self): dataset = dataset_ops.Dataset.range(10).map(lambda x: x*x) variant = dataset_ops.to_variant(dataset) diff --git a/tensorflow/python/data/experimental/kernel_tests/wrap_unwrap_test.py b/tensorflow/python/data/experimental/kernel_tests/wrap_unwrap_test.py index 09627d02994..3fd252ab3ac 100644 --- a/tensorflow/python/data/experimental/kernel_tests/wrap_unwrap_test.py +++ b/tensorflow/python/data/experimental/kernel_tests/wrap_unwrap_test.py @@ -17,18 +17,20 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +from absl.testing import parameterized + from tensorflow.python.data.kernel_tests import test_base from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.framework import combinations from tensorflow.python.framework import ops -from tensorflow.python.framework import test_util from tensorflow.python.ops import array_ops from tensorflow.python.ops import gen_dataset_ops from tensorflow.python.platform import test -@test_util.run_all_in_graph_and_eager_modes -class WrapDatasetVariantTest(test_base.DatasetTestBase): +class WrapDatasetVariantTest(test_base.DatasetTestBase, parameterized.TestCase): + @combinations.generate(test_base.default_test_combinations()) def testBasic(self): ds = dataset_ops.Dataset.range(100) ds_variant = ds._variant_tensor # pylint: disable=protected-access @@ -42,7 +44,9 @@ class WrapDatasetVariantTest(test_base.DatasetTestBase): for i in range(100): self.assertEqual(i, self.evaluate(get_next())) - @test_util.run_v1_only("b/123901304") + # TODO(b/123901304) + @combinations.generate( + combinations.combine(tf_api_version=[1], mode=["graph"])) def testSkipEagerGPU(self): ds = dataset_ops.Dataset.range(100) ds_variant = ds._variant_tensor # pylint: disable=protected-access