[tf.data] Rolling forward a previously rolled back change with a fix.

PiperOrigin-RevId: 284036647
Change-Id: I9d50ad7aa8123f6928c055a25bc3dc4d69d2b95d
This commit is contained in:
Jiri Simsa 2019-12-05 13:10:45 -08:00 committed by TensorFlower Gardener
parent 7e2b4b8c96
commit 769892b353
29 changed files with 539 additions and 190 deletions

View File

@ -25,11 +25,11 @@ from tensorflow.python.data.experimental.ops import grouping
from tensorflow.python.data.kernel_tests import test_base from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.eager import context from tensorflow.python.eager import context
from tensorflow.python.framework import combinations
from tensorflow.python.framework import dtypes from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors from tensorflow.python.framework import errors
from tensorflow.python.framework import sparse_tensor from tensorflow.python.framework import sparse_tensor
from tensorflow.python.framework import tensor_shape from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops from tensorflow.python.ops import array_ops
from tensorflow.python.platform import test from tensorflow.python.platform import test
@ -73,14 +73,12 @@ def _get_record_shape(sparse):
return tensor_shape.TensorShape([None]) return tensor_shape.TensorShape([None])
@test_util.run_all_in_graph_and_eager_modes
class BucketBySequenceLengthTest(test_base.DatasetTestBase, class BucketBySequenceLengthTest(test_base.DatasetTestBase,
parameterized.TestCase): parameterized.TestCase):
@parameterized.named_parameters( @combinations.generate(
("WithoutPadding", True), combinations.times(test_base.default_test_combinations(),
("WithPadding", False), combinations.combine(param_no_padding=[True, False])))
)
def testBucketDropReminder(self, param_no_padding): def testBucketDropReminder(self, param_no_padding):
boundaries = [10, 20, 30] boundaries = [10, 20, 30]
@ -201,10 +199,9 @@ class BucketBySequenceLengthTest(test_base.DatasetTestBase,
_test_bucket_by_padding(param_no_padding) _test_bucket_by_padding(param_no_padding)
@parameterized.named_parameters( @combinations.generate(
("WithoutPadding", True), combinations.times(test_base.default_test_combinations(),
("WithPadding", False), combinations.combine(param_no_padding=[True, False])))
)
def testBucket(self, param_no_padding): def testBucket(self, param_no_padding):
boundaries = [10, 20, 30] boundaries = [10, 20, 30]
@ -347,10 +344,9 @@ class BucketBySequenceLengthTest(test_base.DatasetTestBase,
self.assertAllEqual(batches[4], [[1, 1, 1, 1, 1, 1, 1, 1, 1, 0], self.assertAllEqual(batches[4], [[1, 1, 1, 1, 1, 1, 1, 1, 1, 0],
[1, 1, 1, 1, 1, 1, 1, 1, 1, 1]]) [1, 1, 1, 1, 1, 1, 1, 1, 1, 1]])
@parameterized.named_parameters( @combinations.generate(
("WithoutPadding", True), combinations.times(test_base.default_test_combinations(),
("WithPadding", False), combinations.combine(param_no_padding=[True, False])))
)
def testTupleElements(self, param_no_padding): def testTupleElements(self, param_no_padding):
def build_dataset(sparse): def build_dataset(sparse):
@ -381,10 +377,10 @@ class BucketBySequenceLengthTest(test_base.DatasetTestBase,
_test_tuple_elements_by_padding(param_no_padding) _test_tuple_elements_by_padding(param_no_padding)
@parameterized.named_parameters( @combinations.generate(
("DoDropRemainder", True), combinations.times(
("DoNotDropRemainder", False), test_base.default_test_combinations(),
) combinations.combine(param_drop_remainder=[True, False])))
def testBucketSparse(self, param_drop_remainder): # pylint: disable=g-doc-args def testBucketSparse(self, param_drop_remainder): # pylint: disable=g-doc-args
"""Tests bucketing of sparse tensors (case where `no_padding` == True). """Tests bucketing of sparse tensors (case where `no_padding` == True).

View File

@ -17,6 +17,8 @@ from __future__ import absolute_import
from __future__ import division from __future__ import division
from __future__ import print_function from __future__ import print_function
from absl.testing import parameterized
from tensorflow.core.protobuf import config_pb2 from tensorflow.core.protobuf import config_pb2
from tensorflow.python.compat import compat from tensorflow.python.compat import compat
from tensorflow.python.data.experimental.ops import prefetching_ops from tensorflow.python.data.experimental.ops import prefetching_ops
@ -24,6 +26,7 @@ from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.ops import iterator_ops from tensorflow.python.data.ops import iterator_ops
from tensorflow.python.data.util import structure from tensorflow.python.data.util import structure
from tensorflow.python.framework import combinations
from tensorflow.python.framework import dtypes from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors from tensorflow.python.framework import errors
from tensorflow.python.framework import ops from tensorflow.python.framework import ops
@ -35,9 +38,9 @@ from tensorflow.python.util import compat as util_compat
# TODO(b/117581999): add eager coverage when supported. # TODO(b/117581999): add eager coverage when supported.
class CopyToDeviceTest(test_base.DatasetTestBase): class CopyToDeviceTest(test_base.DatasetTestBase, parameterized.TestCase):
@test_util.deprecated_graph_mode_only @combinations.generate(test_base.graph_only_combinations())
def testCopyToDevice(self): def testCopyToDevice(self):
host_dataset = dataset_ops.Dataset.range(10) host_dataset = dataset_ops.Dataset.range(10)
device_dataset = host_dataset.apply( device_dataset = host_dataset.apply(
@ -62,7 +65,7 @@ class CopyToDeviceTest(test_base.DatasetTestBase):
with self.assertRaises(errors.OutOfRangeError): with self.assertRaises(errors.OutOfRangeError):
self.evaluate(next_element) self.evaluate(next_element)
@test_util.deprecated_graph_mode_only @combinations.generate(test_base.graph_only_combinations())
def testCopyToDeviceInt32(self): def testCopyToDeviceInt32(self):
host_dataset = dataset_ops.Dataset.from_tensors([0, 1, 2, 3]) host_dataset = dataset_ops.Dataset.from_tensors([0, 1, 2, 3])
device_dataset = host_dataset.apply( device_dataset = host_dataset.apply(
@ -86,7 +89,7 @@ class CopyToDeviceTest(test_base.DatasetTestBase):
with self.assertRaises(errors.OutOfRangeError): with self.assertRaises(errors.OutOfRangeError):
self.evaluate(next_element) self.evaluate(next_element)
@test_util.deprecated_graph_mode_only @combinations.generate(test_base.graph_only_combinations())
def testCopyToSameDevice(self): def testCopyToSameDevice(self):
host_dataset = dataset_ops.Dataset.range(10) host_dataset = dataset_ops.Dataset.range(10)
device_dataset = host_dataset.apply( device_dataset = host_dataset.apply(
@ -111,7 +114,7 @@ class CopyToDeviceTest(test_base.DatasetTestBase):
with self.assertRaises(errors.OutOfRangeError): with self.assertRaises(errors.OutOfRangeError):
self.evaluate(next_element) self.evaluate(next_element)
@test_util.deprecated_graph_mode_only @combinations.generate(test_base.graph_only_combinations())
def testCopyToDeviceWithPrefetch(self): def testCopyToDeviceWithPrefetch(self):
host_dataset = dataset_ops.Dataset.range(10) host_dataset = dataset_ops.Dataset.range(10)
device_dataset = host_dataset.apply( device_dataset = host_dataset.apply(
@ -136,7 +139,7 @@ class CopyToDeviceTest(test_base.DatasetTestBase):
with self.assertRaises(errors.OutOfRangeError): with self.assertRaises(errors.OutOfRangeError):
self.evaluate(next_element) self.evaluate(next_element)
@test_util.deprecated_graph_mode_only @combinations.generate(test_base.graph_only_combinations())
def testCopyDictToDevice(self): def testCopyDictToDevice(self):
host_dataset = dataset_ops.Dataset.range(10).map(lambda x: {"a": x}) host_dataset = dataset_ops.Dataset.range(10).map(lambda x: {"a": x})
device_dataset = host_dataset.apply( device_dataset = host_dataset.apply(
@ -161,7 +164,7 @@ class CopyToDeviceTest(test_base.DatasetTestBase):
with self.assertRaises(errors.OutOfRangeError): with self.assertRaises(errors.OutOfRangeError):
self.evaluate(next_element) self.evaluate(next_element)
@test_util.deprecated_graph_mode_only @combinations.generate(test_base.graph_only_combinations())
def testCopyDictToDeviceWithPrefetch(self): def testCopyDictToDeviceWithPrefetch(self):
host_dataset = dataset_ops.Dataset.range(10).map(lambda x: {"a": x}) host_dataset = dataset_ops.Dataset.range(10).map(lambda x: {"a": x})
device_dataset = host_dataset.apply( device_dataset = host_dataset.apply(
@ -186,7 +189,7 @@ class CopyToDeviceTest(test_base.DatasetTestBase):
with self.assertRaises(errors.OutOfRangeError): with self.assertRaises(errors.OutOfRangeError):
self.evaluate(next_element) self.evaluate(next_element)
@test_util.deprecated_graph_mode_only @combinations.generate(test_base.graph_only_combinations())
def testCopySparseTensorsToDevice(self): def testCopySparseTensorsToDevice(self):
def make_tensor(i): def make_tensor(i):
@ -219,7 +222,7 @@ class CopyToDeviceTest(test_base.DatasetTestBase):
with self.assertRaises(errors.OutOfRangeError): with self.assertRaises(errors.OutOfRangeError):
self.evaluate(next_element) self.evaluate(next_element)
@test_util.deprecated_graph_mode_only @combinations.generate(test_base.graph_only_combinations())
def testCopySparseTensorsToDeviceWithPrefetch(self): def testCopySparseTensorsToDeviceWithPrefetch(self):
def make_tensor(i): def make_tensor(i):
@ -252,7 +255,7 @@ class CopyToDeviceTest(test_base.DatasetTestBase):
with self.assertRaises(errors.OutOfRangeError): with self.assertRaises(errors.OutOfRangeError):
self.evaluate(next_element) self.evaluate(next_element)
@test_util.deprecated_graph_mode_only @combinations.generate(test_base.graph_only_combinations())
def testCopyToDeviceGpu(self): def testCopyToDeviceGpu(self):
if not test_util.is_gpu_available(): if not test_util.is_gpu_available():
self.skipTest("No GPU available") self.skipTest("No GPU available")
@ -273,7 +276,7 @@ class CopyToDeviceTest(test_base.DatasetTestBase):
with self.assertRaises(errors.OutOfRangeError): with self.assertRaises(errors.OutOfRangeError):
self.evaluate(next_element) self.evaluate(next_element)
@test_util.deprecated_graph_mode_only @combinations.generate(test_base.graph_only_combinations())
def testCopyToDeviceGpuWithPrefetch(self): def testCopyToDeviceGpuWithPrefetch(self):
if not test_util.is_gpu_available(): if not test_util.is_gpu_available():
self.skipTest("No GPU available") self.skipTest("No GPU available")
@ -294,7 +297,7 @@ class CopyToDeviceTest(test_base.DatasetTestBase):
with self.assertRaises(errors.OutOfRangeError): with self.assertRaises(errors.OutOfRangeError):
self.evaluate(next_element) self.evaluate(next_element)
@test_util.deprecated_graph_mode_only @combinations.generate(test_base.graph_only_combinations())
def testCopyToDeviceGpuWithMap(self): def testCopyToDeviceGpuWithMap(self):
if not test_util.is_gpu_available(): if not test_util.is_gpu_available():
self.skipTest("No GPU available") self.skipTest("No GPU available")
@ -332,7 +335,7 @@ class CopyToDeviceTest(test_base.DatasetTestBase):
with self.assertRaises(errors.OutOfRangeError): with self.assertRaises(errors.OutOfRangeError):
self.evaluate(next_element) self.evaluate(next_element)
@test_util.deprecated_graph_mode_only @combinations.generate(test_base.graph_only_combinations())
def testCopyToDeviceGpuInt32(self): def testCopyToDeviceGpuInt32(self):
if not test_util.is_gpu_available(): if not test_util.is_gpu_available():
self.skipTest("No GPU available") self.skipTest("No GPU available")
@ -352,7 +355,7 @@ class CopyToDeviceTest(test_base.DatasetTestBase):
with self.assertRaises(errors.OutOfRangeError): with self.assertRaises(errors.OutOfRangeError):
self.evaluate(next_element) self.evaluate(next_element)
@test_util.deprecated_graph_mode_only @combinations.generate(test_base.graph_only_combinations())
def testCopyToDeviceGpuInt32AndPrefetch(self): def testCopyToDeviceGpuInt32AndPrefetch(self):
if not test_util.is_gpu_available(): if not test_util.is_gpu_available():
self.skipTest("No GPU available") self.skipTest("No GPU available")
@ -372,7 +375,7 @@ class CopyToDeviceTest(test_base.DatasetTestBase):
with self.assertRaises(errors.OutOfRangeError): with self.assertRaises(errors.OutOfRangeError):
self.evaluate(next_element) self.evaluate(next_element)
@test_util.deprecated_graph_mode_only @combinations.generate(test_base.graph_only_combinations())
def testCopyToDeviceGpuStrings(self): def testCopyToDeviceGpuStrings(self):
if not test_util.is_gpu_available(): if not test_util.is_gpu_available():
self.skipTest("No GPU available") self.skipTest("No GPU available")
@ -392,7 +395,7 @@ class CopyToDeviceTest(test_base.DatasetTestBase):
with self.assertRaises(errors.OutOfRangeError): with self.assertRaises(errors.OutOfRangeError):
self.evaluate(next_element) self.evaluate(next_element)
@test_util.deprecated_graph_mode_only @combinations.generate(test_base.graph_only_combinations())
def testCopyToDeviceGpuStringsAndPrefetch(self): def testCopyToDeviceGpuStringsAndPrefetch(self):
if not test_util.is_gpu_available(): if not test_util.is_gpu_available():
self.skipTest("No GPU available") self.skipTest("No GPU available")
@ -412,7 +415,7 @@ class CopyToDeviceTest(test_base.DatasetTestBase):
with self.assertRaises(errors.OutOfRangeError): with self.assertRaises(errors.OutOfRangeError):
self.evaluate(next_element) self.evaluate(next_element)
@test_util.deprecated_graph_mode_only @combinations.generate(test_base.graph_only_combinations())
def testCopyToDevicePingPongCPUGPU(self): def testCopyToDevicePingPongCPUGPU(self):
if not test_util.is_gpu_available(): if not test_util.is_gpu_available():
self.skipTest("No GPU available") self.skipTest("No GPU available")
@ -436,7 +439,7 @@ class CopyToDeviceTest(test_base.DatasetTestBase):
with self.assertRaises(errors.OutOfRangeError): with self.assertRaises(errors.OutOfRangeError):
self.evaluate(next_element) self.evaluate(next_element)
@test_util.deprecated_graph_mode_only @combinations.generate(test_base.graph_only_combinations())
def testCopyToDeviceWithReInit(self): def testCopyToDeviceWithReInit(self):
host_dataset = dataset_ops.Dataset.range(10) host_dataset = dataset_ops.Dataset.range(10)
device_dataset = host_dataset.apply( device_dataset = host_dataset.apply(
@ -465,7 +468,7 @@ class CopyToDeviceTest(test_base.DatasetTestBase):
with self.assertRaises(errors.OutOfRangeError): with self.assertRaises(errors.OutOfRangeError):
self.evaluate(next_element) self.evaluate(next_element)
@test_util.deprecated_graph_mode_only @combinations.generate(test_base.graph_only_combinations())
def testCopyToDeviceWithReInitAndPrefetch(self): def testCopyToDeviceWithReInitAndPrefetch(self):
host_dataset = dataset_ops.Dataset.range(10) host_dataset = dataset_ops.Dataset.range(10)
device_dataset = host_dataset.apply( device_dataset = host_dataset.apply(
@ -494,7 +497,7 @@ class CopyToDeviceTest(test_base.DatasetTestBase):
with self.assertRaises(errors.OutOfRangeError): with self.assertRaises(errors.OutOfRangeError):
self.evaluate(next_element) self.evaluate(next_element)
@test_util.deprecated_graph_mode_only @combinations.generate(test_base.graph_only_combinations())
def testCopyToDeviceGpuWithReInit(self): def testCopyToDeviceGpuWithReInit(self):
if not test_util.is_gpu_available(): if not test_util.is_gpu_available():
self.skipTest("No GPU available") self.skipTest("No GPU available")
@ -518,7 +521,7 @@ class CopyToDeviceTest(test_base.DatasetTestBase):
with self.assertRaises(errors.OutOfRangeError): with self.assertRaises(errors.OutOfRangeError):
self.evaluate(next_element) self.evaluate(next_element)
@test_util.deprecated_graph_mode_only @combinations.generate(test_base.graph_only_combinations())
def testCopyToDeviceGpuWithReInitAndPrefetch(self): def testCopyToDeviceGpuWithReInitAndPrefetch(self):
if not test_util.is_gpu_available(): if not test_util.is_gpu_available():
self.skipTest("No GPU available") self.skipTest("No GPU available")
@ -542,7 +545,7 @@ class CopyToDeviceTest(test_base.DatasetTestBase):
with self.assertRaises(errors.OutOfRangeError): with self.assertRaises(errors.OutOfRangeError):
self.evaluate(next_element) self.evaluate(next_element)
@test_util.deprecated_graph_mode_only @combinations.generate(test_base.graph_only_combinations())
def testIteratorGetNextAsOptionalOnGPU(self): def testIteratorGetNextAsOptionalOnGPU(self):
if not test_util.is_gpu_available(): if not test_util.is_gpu_available():
self.skipTest("No GPU available") self.skipTest("No GPU available")

View File

@ -17,35 +17,33 @@ from __future__ import absolute_import
from __future__ import division from __future__ import division
from __future__ import print_function from __future__ import print_function
from absl.testing import parameterized
from tensorflow.python.data.experimental.ops import counter from tensorflow.python.data.experimental.ops import counter
from tensorflow.python.data.kernel_tests import test_base from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import combinations
from tensorflow.python.framework import dtypes from tensorflow.python.framework import dtypes
from tensorflow.python.framework import test_util
from tensorflow.python.platform import test from tensorflow.python.platform import test
@test_util.run_all_in_graph_and_eager_modes class CounterTest(test_base.DatasetTestBase, parameterized.TestCase):
class CounterTest(test_base.DatasetTestBase):
def testCounter(self): @combinations.generate(
combinations.times(
test_base.default_test_combinations(),
combinations.combine(start=3, step=4, expected_output=[[3, 7, 11]]) +
combinations.combine(start=0, step=-1, expected_output=[[0, -1, -2]]))
)
def testCounter(self, start, step, expected_output):
"""Test dataset construction using `count`.""" """Test dataset construction using `count`."""
dataset = counter.Counter(start=3, step=4) dataset = counter.Counter(start, step)
self.assertEqual( self.assertEqual(
[], dataset_ops.get_legacy_output_shapes(dataset).as_list()) [], dataset_ops.get_legacy_output_shapes(dataset).as_list())
self.assertEqual(dtypes.int64, dataset_ops.get_legacy_output_types(dataset)) self.assertEqual(dtypes.int64, dataset_ops.get_legacy_output_types(dataset))
get_next = self.getNext(dataset) get_next = self.getNext(dataset)
for expected in expected_output:
negative_dataset = counter.Counter(start=0, step=-1) self.assertEqual(expected, self.evaluate(get_next()))
negative_get_next = self.getNext(negative_dataset)
self.assertEqual(3, self.evaluate(get_next()))
self.assertEqual(3 + 4, self.evaluate(get_next()))
self.assertEqual(3 + 2 * 4, self.evaluate(get_next()))
self.assertEqual(0, self.evaluate(negative_get_next()))
self.assertEqual(-1, self.evaluate(negative_get_next()))
self.assertEqual(-2, self.evaluate(negative_get_next()))
if __name__ == "__main__": if __name__ == "__main__":

View File

@ -22,21 +22,22 @@ import gzip
import os import os
import zlib import zlib
from absl.testing import parameterized
from tensorflow.python.data.experimental.ops import error_ops from tensorflow.python.data.experimental.ops import error_ops
from tensorflow.python.data.experimental.ops import readers from tensorflow.python.data.experimental.ops import readers
from tensorflow.python.data.kernel_tests import test_base from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import readers as core_readers from tensorflow.python.data.ops import readers as core_readers
from tensorflow.python.eager import context from tensorflow.python.eager import context
from tensorflow.python.framework import combinations
from tensorflow.python.framework import constant_op from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors from tensorflow.python.framework import errors
from tensorflow.python.framework import test_util
from tensorflow.python.ops import parsing_ops from tensorflow.python.ops import parsing_ops
from tensorflow.python.platform import test from tensorflow.python.platform import test
@test_util.run_all_in_graph_and_eager_modes class CsvDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
class CsvDatasetTest(test_base.DatasetTestBase):
def _setup_files(self, inputs, linebreak='\n', compression_type=None): def _setup_files(self, inputs, linebreak='\n', compression_type=None):
filenames = [] filenames = []
@ -117,26 +118,31 @@ class CsvDatasetTest(test_base.DatasetTestBase):
dataset = readers.CsvDataset(filenames, **kwargs) dataset = readers.CsvDataset(filenames, **kwargs)
self._verify_output_or_err(dataset, expected_output, expected_err_re) self._verify_output_or_err(dataset, expected_output, expected_err_re)
@combinations.generate(test_base.default_test_combinations())
def testCsvDataset_requiredFields(self): def testCsvDataset_requiredFields(self):
record_defaults = [[]] * 4 record_defaults = [[]] * 4
inputs = [['1,2,3,4']] inputs = [['1,2,3,4']]
self._test_by_comparison(inputs, record_defaults=record_defaults) self._test_by_comparison(inputs, record_defaults=record_defaults)
@combinations.generate(test_base.default_test_combinations())
def testCsvDataset_int(self): def testCsvDataset_int(self):
record_defaults = [[0]] * 4 record_defaults = [[0]] * 4
inputs = [['1,2,3,4', '5,6,7,8']] inputs = [['1,2,3,4', '5,6,7,8']]
self._test_by_comparison(inputs, record_defaults=record_defaults) self._test_by_comparison(inputs, record_defaults=record_defaults)
@combinations.generate(test_base.default_test_combinations())
def testCsvDataset_float(self): def testCsvDataset_float(self):
record_defaults = [[0.0]] * 4 record_defaults = [[0.0]] * 4
inputs = [['1.0,2.1,3.2,4.3', '5.4,6.5,7.6,8.7']] inputs = [['1.0,2.1,3.2,4.3', '5.4,6.5,7.6,8.7']]
self._test_by_comparison(inputs, record_defaults=record_defaults) self._test_by_comparison(inputs, record_defaults=record_defaults)
@combinations.generate(test_base.default_test_combinations())
def testCsvDataset_string(self): def testCsvDataset_string(self):
record_defaults = [['']] * 4 record_defaults = [['']] * 4
inputs = [['1.0,2.1,hello,4.3', '5.4,6.5,goodbye,8.7']] inputs = [['1.0,2.1,hello,4.3', '5.4,6.5,goodbye,8.7']]
self._test_by_comparison(inputs, record_defaults=record_defaults) self._test_by_comparison(inputs, record_defaults=record_defaults)
@combinations.generate(test_base.default_test_combinations())
def testCsvDataset_withEmptyFields(self): def testCsvDataset_withEmptyFields(self):
record_defaults = [[0]] * 4 record_defaults = [[0]] * 4
inputs = [[',,,', '1,1,1,', ',2,2,2']] inputs = [[',,,', '1,1,1,', ',2,2,2']]
@ -144,6 +150,7 @@ class CsvDatasetTest(test_base.DatasetTestBase):
inputs, [[0, 0, 0, 0], [1, 1, 1, 0], [0, 2, 2, 2]], inputs, [[0, 0, 0, 0], [1, 1, 1, 0], [0, 2, 2, 2]],
record_defaults=record_defaults) record_defaults=record_defaults)
@combinations.generate(test_base.default_test_combinations())
def testCsvDataset_errWithUnquotedQuotes(self): def testCsvDataset_errWithUnquotedQuotes(self):
record_defaults = [['']] * 3 record_defaults = [['']] * 3
inputs = [['1,2"3,4']] inputs = [['1,2"3,4']]
@ -152,6 +159,7 @@ class CsvDatasetTest(test_base.DatasetTestBase):
expected_err_re='Unquoted fields cannot have quotes inside', expected_err_re='Unquoted fields cannot have quotes inside',
record_defaults=record_defaults) record_defaults=record_defaults)
@combinations.generate(test_base.default_test_combinations())
def testCsvDataset_errWithUnescapedQuotes(self): def testCsvDataset_errWithUnescapedQuotes(self):
record_defaults = [['']] * 3 record_defaults = [['']] * 3
inputs = [['"a"b","c","d"']] inputs = [['"a"b","c","d"']]
@ -161,6 +169,7 @@ class CsvDatasetTest(test_base.DatasetTestBase):
'Quote inside a string has to be escaped by another quote', 'Quote inside a string has to be escaped by another quote',
record_defaults=record_defaults) record_defaults=record_defaults)
@combinations.generate(test_base.default_test_combinations())
def testCsvDataset_ignoreErrWithUnescapedQuotes(self): def testCsvDataset_ignoreErrWithUnescapedQuotes(self):
record_defaults = [['']] * 3 record_defaults = [['']] * 3
inputs = [['1,"2"3",4', '1,"2"3",4",5,5', 'a,b,"c"d"', 'e,f,g']] inputs = [['1,"2"3",4', '1,"2"3",4",5,5', 'a,b,"c"d"', 'e,f,g']]
@ -169,6 +178,7 @@ class CsvDatasetTest(test_base.DatasetTestBase):
dataset = dataset.apply(error_ops.ignore_errors()) dataset = dataset.apply(error_ops.ignore_errors())
self._verify_output_or_err(dataset, [['e', 'f', 'g']]) self._verify_output_or_err(dataset, [['e', 'f', 'g']])
@combinations.generate(test_base.default_test_combinations())
def testCsvDataset_ignoreErrWithUnquotedQuotes(self): def testCsvDataset_ignoreErrWithUnquotedQuotes(self):
record_defaults = [['']] * 3 record_defaults = [['']] * 3
inputs = [['1,2"3,4', 'a,b,c"d', '9,8"7,6,5', 'e,f,g']] inputs = [['1,2"3,4', 'a,b,c"d', '9,8"7,6,5', 'e,f,g']]
@ -177,12 +187,14 @@ class CsvDatasetTest(test_base.DatasetTestBase):
dataset = dataset.apply(error_ops.ignore_errors()) dataset = dataset.apply(error_ops.ignore_errors())
self._verify_output_or_err(dataset, [['e', 'f', 'g']]) self._verify_output_or_err(dataset, [['e', 'f', 'g']])
@combinations.generate(test_base.default_test_combinations())
def testCsvDataset_withNoQuoteDelimAndUnquotedQuotes(self): def testCsvDataset_withNoQuoteDelimAndUnquotedQuotes(self):
record_defaults = [['']] * 3 record_defaults = [['']] * 3
inputs = [['1,2"3,4']] inputs = [['1,2"3,4']]
self._test_by_comparison( self._test_by_comparison(
inputs, record_defaults=record_defaults, use_quote_delim=False) inputs, record_defaults=record_defaults, use_quote_delim=False)
@combinations.generate(test_base.default_test_combinations())
def testCsvDataset_mixedTypes(self): def testCsvDataset_mixedTypes(self):
record_defaults = [ record_defaults = [
constant_op.constant([], dtype=dtypes.int32), constant_op.constant([], dtype=dtypes.int32),
@ -193,30 +205,35 @@ class CsvDatasetTest(test_base.DatasetTestBase):
inputs = [['1,2.1,3.2,4.3', '5,6.5,7.6,8.7']] inputs = [['1,2.1,3.2,4.3', '5,6.5,7.6,8.7']]
self._test_by_comparison(inputs, record_defaults=record_defaults) self._test_by_comparison(inputs, record_defaults=record_defaults)
@combinations.generate(test_base.default_test_combinations())
def testCsvDataset_withUseQuoteDelimFalse(self): def testCsvDataset_withUseQuoteDelimFalse(self):
record_defaults = [['']] * 4 record_defaults = [['']] * 4
inputs = [['1,2,"3,4"', '"5,6",7,8']] inputs = [['1,2,"3,4"', '"5,6",7,8']]
self._test_by_comparison( self._test_by_comparison(
inputs, record_defaults=record_defaults, use_quote_delim=False) inputs, record_defaults=record_defaults, use_quote_delim=False)
@combinations.generate(test_base.default_test_combinations())
def testCsvDataset_withFieldDelim(self): def testCsvDataset_withFieldDelim(self):
record_defaults = [[0]] * 4 record_defaults = [[0]] * 4
inputs = [['1:2:3:4', '5:6:7:8']] inputs = [['1:2:3:4', '5:6:7:8']]
self._test_by_comparison( self._test_by_comparison(
inputs, record_defaults=record_defaults, field_delim=':') inputs, record_defaults=record_defaults, field_delim=':')
@combinations.generate(test_base.default_test_combinations())
def testCsvDataset_withNaValue(self): def testCsvDataset_withNaValue(self):
record_defaults = [[0]] * 4 record_defaults = [[0]] * 4
inputs = [['1,NA,3,4', 'NA,6,7,8']] inputs = [['1,NA,3,4', 'NA,6,7,8']]
self._test_by_comparison( self._test_by_comparison(
inputs, record_defaults=record_defaults, na_value='NA') inputs, record_defaults=record_defaults, na_value='NA')
@combinations.generate(test_base.default_test_combinations())
def testCsvDataset_withSelectCols(self): def testCsvDataset_withSelectCols(self):
record_defaults = [['']] * 2 record_defaults = [['']] * 2
inputs = [['1,2,3,4', '"5","6","7","8"']] inputs = [['1,2,3,4', '"5","6","7","8"']]
self._test_by_comparison( self._test_by_comparison(
inputs, record_defaults=record_defaults, select_cols=[1, 2]) inputs, record_defaults=record_defaults, select_cols=[1, 2])
@combinations.generate(test_base.default_test_combinations())
def testCsvDataset_withSelectColsTooHigh(self): def testCsvDataset_withSelectColsTooHigh(self):
record_defaults = [[0]] * 2 record_defaults = [[0]] * 2
inputs = [['1,2,3,4', '5,6,7,8']] inputs = [['1,2,3,4', '5,6,7,8']]
@ -226,23 +243,27 @@ class CsvDatasetTest(test_base.DatasetTestBase):
record_defaults=record_defaults, record_defaults=record_defaults,
select_cols=[3, 4]) select_cols=[3, 4])
@combinations.generate(test_base.default_test_combinations())
def testCsvDataset_withOneCol(self): def testCsvDataset_withOneCol(self):
record_defaults = [['NA']] record_defaults = [['NA']]
inputs = [['0', '', '2']] inputs = [['0', '', '2']]
self._test_dataset( self._test_dataset(
inputs, [['0'], ['NA'], ['2']], record_defaults=record_defaults) inputs, [['0'], ['NA'], ['2']], record_defaults=record_defaults)
@combinations.generate(test_base.default_test_combinations())
def testCsvDataset_withMultipleFiles(self): def testCsvDataset_withMultipleFiles(self):
record_defaults = [[0]] * 4 record_defaults = [[0]] * 4
inputs = [['1,2,3,4', '5,6,7,8'], ['5,6,7,8']] inputs = [['1,2,3,4', '5,6,7,8'], ['5,6,7,8']]
self._test_by_comparison(inputs, record_defaults=record_defaults) self._test_by_comparison(inputs, record_defaults=record_defaults)
@combinations.generate(test_base.default_test_combinations())
def testCsvDataset_withLeadingAndTrailingSpaces(self): def testCsvDataset_withLeadingAndTrailingSpaces(self):
record_defaults = [[0.0]] * 4 record_defaults = [[0.0]] * 4
inputs = [['0, 1, 2, 3']] inputs = [['0, 1, 2, 3']]
expected = [[0.0, 1.0, 2.0, 3.0]] expected = [[0.0, 1.0, 2.0, 3.0]]
self._test_dataset(inputs, expected, record_defaults=record_defaults) self._test_dataset(inputs, expected, record_defaults=record_defaults)
@combinations.generate(test_base.default_test_combinations())
def testCsvDataset_errorWithMissingDefault(self): def testCsvDataset_errorWithMissingDefault(self):
record_defaults = [[]] * 2 record_defaults = [[]] * 2
inputs = [['0,']] inputs = [['0,']]
@ -251,6 +272,7 @@ class CsvDatasetTest(test_base.DatasetTestBase):
expected_err_re='Field 1 is required but missing in record!', expected_err_re='Field 1 is required but missing in record!',
record_defaults=record_defaults) record_defaults=record_defaults)
@combinations.generate(test_base.default_test_combinations())
def testCsvDataset_errorWithFewerDefaultsThanFields(self): def testCsvDataset_errorWithFewerDefaultsThanFields(self):
record_defaults = [[0.0]] * 2 record_defaults = [[0.0]] * 2
inputs = [['0,1,2,3']] inputs = [['0,1,2,3']]
@ -259,6 +281,7 @@ class CsvDatasetTest(test_base.DatasetTestBase):
expected_err_re='Expect 2 fields but have more in record', expected_err_re='Expect 2 fields but have more in record',
record_defaults=record_defaults) record_defaults=record_defaults)
@combinations.generate(test_base.default_test_combinations())
def testCsvDataset_errorWithMoreDefaultsThanFields(self): def testCsvDataset_errorWithMoreDefaultsThanFields(self):
record_defaults = [[0.0]] * 5 record_defaults = [[0.0]] * 5
inputs = [['0,1,2,3']] inputs = [['0,1,2,3']]
@ -267,6 +290,7 @@ class CsvDatasetTest(test_base.DatasetTestBase):
expected_err_re='Expect 5 fields but have 4 in record', expected_err_re='Expect 5 fields but have 4 in record',
record_defaults=record_defaults) record_defaults=record_defaults)
@combinations.generate(test_base.default_test_combinations())
def testCsvDataset_withHeader(self): def testCsvDataset_withHeader(self):
record_defaults = [[0]] * 2 record_defaults = [[0]] * 2
inputs = [['col1,col2', '1,2']] inputs = [['col1,col2', '1,2']]
@ -278,6 +302,7 @@ class CsvDatasetTest(test_base.DatasetTestBase):
header=True, header=True,
) )
@combinations.generate(test_base.default_test_combinations())
def testCsvDataset_withHeaderAndNoRecords(self): def testCsvDataset_withHeaderAndNoRecords(self):
record_defaults = [[0]] * 2 record_defaults = [[0]] * 2
inputs = [['col1,col2']] inputs = [['col1,col2']]
@ -289,6 +314,7 @@ class CsvDatasetTest(test_base.DatasetTestBase):
header=True, header=True,
) )
@combinations.generate(test_base.default_test_combinations())
def testCsvDataset_errorWithHeaderEmptyFile(self): def testCsvDataset_errorWithHeaderEmptyFile(self):
record_defaults = [[0]] * 2 record_defaults = [[0]] * 2
inputs = [[]] inputs = [[]]
@ -300,12 +326,14 @@ class CsvDatasetTest(test_base.DatasetTestBase):
header=True, header=True,
) )
@combinations.generate(test_base.default_test_combinations())
def testCsvDataset_withEmptyFile(self): def testCsvDataset_withEmptyFile(self):
record_defaults = [['']] * 2 record_defaults = [['']] * 2
inputs = [['']] # Empty file inputs = [['']] # Empty file
self._test_dataset( self._test_dataset(
inputs, expected_output=[], record_defaults=record_defaults) inputs, expected_output=[], record_defaults=record_defaults)
@combinations.generate(test_base.default_test_combinations())
def testCsvDataset_errorWithEmptyRecord(self): def testCsvDataset_errorWithEmptyRecord(self):
record_defaults = [['']] * 2 record_defaults = [['']] * 2
inputs = [['', '1,2']] # First record is empty inputs = [['', '1,2']] # First record is empty
@ -314,6 +342,7 @@ class CsvDatasetTest(test_base.DatasetTestBase):
expected_err_re='Expect 2 fields but have 1 in record', expected_err_re='Expect 2 fields but have 1 in record',
record_defaults=record_defaults) record_defaults=record_defaults)
@combinations.generate(test_base.default_test_combinations())
def testCsvDataset_withChainedOps(self): def testCsvDataset_withChainedOps(self):
# Testing that one dataset can create multiple iterators fine. # Testing that one dataset can create multiple iterators fine.
# `repeat` creates multiple iterators from the same C++ Dataset. # `repeat` creates multiple iterators from the same C++ Dataset.
@ -325,6 +354,7 @@ class CsvDatasetTest(test_base.DatasetTestBase):
ds_actual.repeat(5).prefetch(1), ds_actual.repeat(5).prefetch(1),
ds_expected.repeat(5).prefetch(1)) ds_expected.repeat(5).prefetch(1))
@combinations.generate(test_base.default_test_combinations())
def testCsvDataset_withTypeDefaults(self): def testCsvDataset_withTypeDefaults(self):
# Testing using dtypes as record_defaults for required fields # Testing using dtypes as record_defaults for required fields
record_defaults = [dtypes.float32, [0.0]] record_defaults = [dtypes.float32, [0.0]]
@ -335,6 +365,7 @@ class CsvDatasetTest(test_base.DatasetTestBase):
record_defaults=record_defaults, record_defaults=record_defaults,
) )
@combinations.generate(test_base.default_test_combinations())
def testMakeCsvDataset_fieldOrder(self): def testMakeCsvDataset_fieldOrder(self):
data = [[ data = [[
'1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19', '1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19',
@ -352,6 +383,7 @@ class CsvDatasetTest(test_base.DatasetTestBase):
## The following tests exercise parsing logic for quoted fields ## The following tests exercise parsing logic for quoted fields
@combinations.generate(test_base.default_test_combinations())
def testCsvDataset_withQuoted(self): def testCsvDataset_withQuoted(self):
record_defaults = [['']] * 4 record_defaults = [['']] * 4
inputs = [['"a","b","c :)","d"', '"e","f","g :(","h"']] inputs = [['"a","b","c :)","d"', '"e","f","g :(","h"']]
@ -363,6 +395,7 @@ class CsvDatasetTest(test_base.DatasetTestBase):
self._test_dataset( self._test_dataset(
inputs, [['0'], ['1'], ['2']], record_defaults=record_defaults) inputs, [['0'], ['1'], ['2']], record_defaults=record_defaults)
@combinations.generate(test_base.default_test_combinations())
def testCsvDataset_withNewLine(self): def testCsvDataset_withNewLine(self):
# In this case, we expect it to behave differently from # In this case, we expect it to behave differently from
# TextLineDataset->map(decode_csv) since that flow has bugs # TextLineDataset->map(decode_csv) since that flow has bugs
@ -371,6 +404,7 @@ class CsvDatasetTest(test_base.DatasetTestBase):
expected = [['a', 'b', '"c"\n0', 'd\ne'], ['f', 'g', 'h', 'i']] expected = [['a', 'b', '"c"\n0', 'd\ne'], ['f', 'g', 'h', 'i']]
self._test_dataset(inputs, expected, record_defaults=record_defaults) self._test_dataset(inputs, expected, record_defaults=record_defaults)
@combinations.generate(test_base.default_test_combinations())
def testCsvDataset_withNewLineInUnselectedCol(self): def testCsvDataset_withNewLineInUnselectedCol(self):
record_defaults = [['']] record_defaults = [['']]
inputs = [['1,"2\n3",4', '5,6,7']] inputs = [['1,"2\n3",4', '5,6,7']]
@ -380,6 +414,7 @@ class CsvDatasetTest(test_base.DatasetTestBase):
record_defaults=record_defaults, record_defaults=record_defaults,
select_cols=[0]) select_cols=[0])
@combinations.generate(test_base.default_test_combinations())
def testCsvDataset_withMultipleNewLines(self): def testCsvDataset_withMultipleNewLines(self):
# In this case, we expect it to behave differently from # In this case, we expect it to behave differently from
# TextLineDataset->map(decode_csv) since that flow has bugs # TextLineDataset->map(decode_csv) since that flow has bugs
@ -388,6 +423,7 @@ class CsvDatasetTest(test_base.DatasetTestBase):
expected = [['a', 'b\n\nx', '"c"\n \n0', 'd\ne'], ['f', 'g', 'h', 'i']] expected = [['a', 'b\n\nx', '"c"\n \n0', 'd\ne'], ['f', 'g', 'h', 'i']]
self._test_dataset(inputs, expected, record_defaults=record_defaults) self._test_dataset(inputs, expected, record_defaults=record_defaults)
@combinations.generate(test_base.default_test_combinations())
def testCsvDataset_errorWithTerminateMidRecord(self): def testCsvDataset_errorWithTerminateMidRecord(self):
record_defaults = [['']] * 4 record_defaults = [['']] * 4
inputs = [['a,b,c,"a']] inputs = [['a,b,c,"a']]
@ -397,6 +433,7 @@ class CsvDatasetTest(test_base.DatasetTestBase):
'Reached end of file without closing quoted field in record', 'Reached end of file without closing quoted field in record',
record_defaults=record_defaults) record_defaults=record_defaults)
@combinations.generate(test_base.default_test_combinations())
def testCsvDataset_withEscapedQuotes(self): def testCsvDataset_withEscapedQuotes(self):
record_defaults = [['']] * 4 record_defaults = [['']] * 4
inputs = [['1.0,2.1,"she said: ""hello""",4.3', '5.4,6.5,goodbye,8.7']] inputs = [['1.0,2.1,"she said: ""hello""",4.3', '5.4,6.5,goodbye,8.7']]
@ -406,6 +443,7 @@ class CsvDatasetTest(test_base.DatasetTestBase):
## Testing that parsing works with all buffer sizes, quoted/unquoted fields, ## Testing that parsing works with all buffer sizes, quoted/unquoted fields,
## and different types of line breaks ## and different types of line breaks
@combinations.generate(test_base.default_test_combinations())
def testCsvDataset_withInvalidBufferSize(self): def testCsvDataset_withInvalidBufferSize(self):
record_defaults = [['']] * 4 record_defaults = [['']] * 4
inputs = [['a,b,c,d']] inputs = [['a,b,c,d']]
@ -432,6 +470,7 @@ class CsvDatasetTest(test_base.DatasetTestBase):
record_defaults=record_defaults, record_defaults=record_defaults,
buffer_size=i) buffer_size=i)
@combinations.generate(test_base.default_test_combinations())
def testCsvDataset_withLF(self): def testCsvDataset_withLF(self):
record_defaults = [['NA']] * 3 record_defaults = [['NA']] * 3
inputs = [['abc,def,ghi', '0,1,2', ',,']] inputs = [['abc,def,ghi', '0,1,2', ',,']]
@ -439,6 +478,7 @@ class CsvDatasetTest(test_base.DatasetTestBase):
self._test_dataset_on_buffer_sizes( self._test_dataset_on_buffer_sizes(
inputs, expected, linebreak='\n', record_defaults=record_defaults) inputs, expected, linebreak='\n', record_defaults=record_defaults)
@combinations.generate(test_base.default_test_combinations())
def testCsvDataset_withCR(self): def testCsvDataset_withCR(self):
# Test that when the line separator is '\r', parsing works with all buffer # Test that when the line separator is '\r', parsing works with all buffer
# sizes # sizes
@ -448,6 +488,7 @@ class CsvDatasetTest(test_base.DatasetTestBase):
self._test_dataset_on_buffer_sizes( self._test_dataset_on_buffer_sizes(
inputs, expected, linebreak='\r', record_defaults=record_defaults) inputs, expected, linebreak='\r', record_defaults=record_defaults)
@combinations.generate(test_base.default_test_combinations())
def testCsvDataset_withCRLF(self): def testCsvDataset_withCRLF(self):
# Test that when the line separator is '\r\n', parsing works with all buffer # Test that when the line separator is '\r\n', parsing works with all buffer
# sizes # sizes
@ -457,6 +498,7 @@ class CsvDatasetTest(test_base.DatasetTestBase):
self._test_dataset_on_buffer_sizes( self._test_dataset_on_buffer_sizes(
inputs, expected, linebreak='\r\n', record_defaults=record_defaults) inputs, expected, linebreak='\r\n', record_defaults=record_defaults)
@combinations.generate(test_base.default_test_combinations())
def testCsvDataset_withBufferSizeAndQuoted(self): def testCsvDataset_withBufferSizeAndQuoted(self):
record_defaults = [['NA']] * 3 record_defaults = [['NA']] * 3
inputs = [['"\n\n\n","\r\r\r","abc"', '"0","1","2"', '"","",""']] inputs = [['"\n\n\n","\r\r\r","abc"', '"0","1","2"', '"","",""']]
@ -465,6 +507,7 @@ class CsvDatasetTest(test_base.DatasetTestBase):
self._test_dataset_on_buffer_sizes( self._test_dataset_on_buffer_sizes(
inputs, expected, linebreak='\n', record_defaults=record_defaults) inputs, expected, linebreak='\n', record_defaults=record_defaults)
@combinations.generate(test_base.default_test_combinations())
def testCsvDataset_withCRAndQuoted(self): def testCsvDataset_withCRAndQuoted(self):
# Test that when the line separator is '\r', parsing works with all buffer # Test that when the line separator is '\r', parsing works with all buffer
# sizes # sizes
@ -475,6 +518,7 @@ class CsvDatasetTest(test_base.DatasetTestBase):
self._test_dataset_on_buffer_sizes( self._test_dataset_on_buffer_sizes(
inputs, expected, linebreak='\r', record_defaults=record_defaults) inputs, expected, linebreak='\r', record_defaults=record_defaults)
@combinations.generate(test_base.default_test_combinations())
def testCsvDataset_withCRLFAndQuoted(self): def testCsvDataset_withCRLFAndQuoted(self):
# Test that when the line separator is '\r\n', parsing works with all buffer # Test that when the line separator is '\r\n', parsing works with all buffer
# sizes # sizes
@ -485,6 +529,7 @@ class CsvDatasetTest(test_base.DatasetTestBase):
self._test_dataset_on_buffer_sizes( self._test_dataset_on_buffer_sizes(
inputs, expected, linebreak='\r\n', record_defaults=record_defaults) inputs, expected, linebreak='\r\n', record_defaults=record_defaults)
@combinations.generate(test_base.default_test_combinations())
def testCsvDataset_withGzipCompressionType(self): def testCsvDataset_withGzipCompressionType(self):
record_defaults = [['NA']] * 3 record_defaults = [['NA']] * 3
inputs = [['"\n\n\n","\r\r\r","abc"', '"0","1","2"', '"","",""']] inputs = [['"\n\n\n","\r\r\r","abc"', '"0","1","2"', '"","",""']]
@ -497,6 +542,7 @@ class CsvDatasetTest(test_base.DatasetTestBase):
compression_type='GZIP', compression_type='GZIP',
record_defaults=record_defaults) record_defaults=record_defaults)
@combinations.generate(test_base.default_test_combinations())
def testCsvDataset_withZlibCompressionType(self): def testCsvDataset_withZlibCompressionType(self):
record_defaults = [['NA']] * 3 record_defaults = [['NA']] * 3
inputs = [['"\n\n\n","\r\r\r","abc"', '"0","1","2"', '"","",""']] inputs = [['"\n\n\n","\r\r\r","abc"', '"0","1","2"', '"","",""']]
@ -509,6 +555,7 @@ class CsvDatasetTest(test_base.DatasetTestBase):
compression_type='ZLIB', compression_type='ZLIB',
record_defaults=record_defaults) record_defaults=record_defaults)
@combinations.generate(test_base.default_test_combinations())
def testCsvDataset_withScalarDefaults(self): def testCsvDataset_withScalarDefaults(self):
record_defaults = [constant_op.constant(0, dtype=dtypes.int64)] * 4 record_defaults = [constant_op.constant(0, dtype=dtypes.int64)] * 4
inputs = [[',,,', '1,1,1,', ',2,2,2']] inputs = [[',,,', '1,1,1,', ',2,2,2']]
@ -516,6 +563,7 @@ class CsvDatasetTest(test_base.DatasetTestBase):
inputs, [[0, 0, 0, 0], [1, 1, 1, 0], [0, 2, 2, 2]], inputs, [[0, 0, 0, 0], [1, 1, 1, 0], [0, 2, 2, 2]],
record_defaults=record_defaults) record_defaults=record_defaults)
@combinations.generate(test_base.default_test_combinations())
def testCsvDataset_with2DDefaults(self): def testCsvDataset_with2DDefaults(self):
record_defaults = [constant_op.constant([[0]], dtype=dtypes.int64)] * 4 record_defaults = [constant_op.constant([[0]], dtype=dtypes.int64)] * 4
inputs = [[',,,', '1,1,1,', ',2,2,2']] inputs = [[',,,', '1,1,1,', ',2,2,2']]

View File

@ -17,20 +17,21 @@ from __future__ import absolute_import
from __future__ import division from __future__ import division
from __future__ import print_function from __future__ import print_function
from absl.testing import parameterized
import numpy as np import numpy as np
from tensorflow.python.data.experimental.ops import batching from tensorflow.python.data.experimental.ops import batching
from tensorflow.python.data.kernel_tests import test_base from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import combinations
from tensorflow.python.framework import errors from tensorflow.python.framework import errors
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops from tensorflow.python.ops import array_ops
from tensorflow.python.platform import test from tensorflow.python.platform import test
@test_util.run_all_in_graph_and_eager_modes class DenseToSparseBatchTest(test_base.DatasetTestBase, parameterized.TestCase):
class DenseToSparseBatchTest(test_base.DatasetTestBase):
@combinations.generate(test_base.default_test_combinations())
def testDenseToSparseBatchDataset(self): def testDenseToSparseBatchDataset(self):
components = np.random.randint(12, size=(100,)).astype(np.int32) components = np.random.randint(12, size=(100,)).astype(np.int32)
dataset = dataset_ops.Dataset.from_tensor_slices( dataset = dataset_ops.Dataset.from_tensor_slices(
@ -53,6 +54,7 @@ class DenseToSparseBatchTest(test_base.DatasetTestBase):
with self.assertRaises(errors.OutOfRangeError): with self.assertRaises(errors.OutOfRangeError):
self.evaluate(get_next()) self.evaluate(get_next())
@combinations.generate(test_base.default_test_combinations())
def testDenseToSparseBatchDatasetWithUnknownShape(self): def testDenseToSparseBatchDatasetWithUnknownShape(self):
components = np.random.randint(5, size=(40,)).astype(np.int32) components = np.random.randint(5, size=(40,)).astype(np.int32)
dataset = dataset_ops.Dataset.from_tensor_slices( dataset = dataset_ops.Dataset.from_tensor_slices(
@ -80,12 +82,14 @@ class DenseToSparseBatchTest(test_base.DatasetTestBase):
with self.assertRaises(errors.OutOfRangeError): with self.assertRaises(errors.OutOfRangeError):
self.evaluate(get_next()) self.evaluate(get_next())
@combinations.generate(test_base.default_test_combinations())
def testDenseToSparseBatchDatasetWithInvalidShape(self): def testDenseToSparseBatchDatasetWithInvalidShape(self):
input_tensor = array_ops.constant([[1]]) input_tensor = array_ops.constant([[1]])
with self.assertRaisesRegexp(ValueError, "Dimension -2 must be >= 0"): with self.assertRaisesRegexp(ValueError, "Dimension -2 must be >= 0"):
dataset_ops.Dataset.from_tensors(input_tensor).apply( dataset_ops.Dataset.from_tensors(input_tensor).apply(
batching.dense_to_sparse_batch(4, [-2])) batching.dense_to_sparse_batch(4, [-2]))
@combinations.generate(test_base.default_test_combinations())
def testDenseToSparseBatchDatasetShapeErrors(self): def testDenseToSparseBatchDatasetShapeErrors(self):
def dataset_fn(input_tensor): def dataset_fn(input_tensor):

View File

@ -17,22 +17,24 @@ from __future__ import absolute_import
from __future__ import division from __future__ import division
from __future__ import print_function from __future__ import print_function
from absl.testing import parameterized
import numpy as np import numpy as np
from tensorflow.python.data.experimental.ops import interleave_ops from tensorflow.python.data.experimental.ops import interleave_ops
from tensorflow.python.data.kernel_tests import test_base from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import combinations
from tensorflow.python.framework import constant_op from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors from tensorflow.python.framework import errors
from tensorflow.python.framework import random_seed from tensorflow.python.framework import random_seed
from tensorflow.python.framework import test_util
from tensorflow.python.platform import test from tensorflow.python.platform import test
@test_util.run_all_in_graph_and_eager_modes class DirectedInterleaveDatasetTest(test_base.DatasetTestBase,
class DirectedInterleaveDatasetTest(test_base.DatasetTestBase): parameterized.TestCase):
@combinations.generate(test_base.default_test_combinations())
def testBasic(self): def testBasic(self):
selector_dataset = dataset_ops.Dataset.range(10).repeat(100) selector_dataset = dataset_ops.Dataset.range(10).repeat(100)
input_datasets = [ input_datasets = [
@ -76,6 +78,7 @@ class DirectedInterleaveDatasetTest(test_base.DatasetTestBase):
return freqs return freqs
@combinations.generate(test_base.default_test_combinations())
def testSampleFromDatasets(self): def testSampleFromDatasets(self):
random_seed.set_random_seed(1619) random_seed.set_random_seed(1619)
num_samples = 5000 num_samples = 5000
@ -95,6 +98,7 @@ class DirectedInterleaveDatasetTest(test_base.DatasetTestBase):
freqs = self._testSampleFromDatasetsHelper(probs_ds, classes, num_samples) freqs = self._testSampleFromDatasetsHelper(probs_ds, classes, num_samples)
self.assertLess(self._chi2(probs, freqs / num_samples), 1e-2) self.assertLess(self._chi2(probs, freqs / num_samples), 1e-2)
@combinations.generate(test_base.default_test_combinations())
def testSelectFromDatasets(self): def testSelectFromDatasets(self):
words = [b"foo", b"bar", b"baz"] words = [b"foo", b"bar", b"baz"]
datasets = [dataset_ops.Dataset.from_tensors(w).repeat() for w in words] datasets = [dataset_ops.Dataset.from_tensors(w).repeat() for w in words]
@ -107,6 +111,7 @@ class DirectedInterleaveDatasetTest(test_base.DatasetTestBase):
with self.assertRaises(errors.OutOfRangeError): with self.assertRaises(errors.OutOfRangeError):
self.evaluate(next_element()) self.evaluate(next_element())
@combinations.generate(test_base.default_test_combinations())
def testErrors(self): def testErrors(self):
with self.assertRaisesRegexp(ValueError, with self.assertRaisesRegexp(ValueError,
r"vector of length `len\(datasets\)`"): r"vector of length `len\(datasets\)`"):

View File

@ -23,25 +23,30 @@ from tensorflow.python.data.experimental.ops import get_single_element
from tensorflow.python.data.kernel_tests import test_base from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.eager import function from tensorflow.python.eager import function
from tensorflow.python.framework import combinations
from tensorflow.python.framework import errors from tensorflow.python.framework import errors
from tensorflow.python.framework import sparse_tensor from tensorflow.python.framework import sparse_tensor
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops from tensorflow.python.ops import array_ops
from tensorflow.python.ops import variables from tensorflow.python.ops import variables
from tensorflow.python.platform import test from tensorflow.python.platform import test
@test_util.run_all_in_graph_and_eager_modes
class GetSingleElementTest(test_base.DatasetTestBase, parameterized.TestCase): class GetSingleElementTest(test_base.DatasetTestBase, parameterized.TestCase):
@parameterized.named_parameters( @combinations.generate(
("Zero", 0, 1), combinations.times(
("Five", 5, 1), test_base.default_test_combinations(),
("Ten", 10, 1), combinations.combine(
("Empty", 100, 1, errors.InvalidArgumentError, "Dataset was empty."), skip=[0, 5, 10], take=[1], error=[None], error_msg=[None]) +
("MoreThanOne", 0, 2, errors.InvalidArgumentError, combinations.combine(
"Dataset had more than one element."), skip=[100],
) take=[1],
error=[errors.InvalidArgumentError],
error_msg=["Dataset was empty."]) + combinations.combine(
skip=[0],
take=[2],
error=[errors.InvalidArgumentError],
error_msg=["Dataset had more than one element."])))
def testGetSingleElement(self, skip, take, error=None, error_msg=None): def testGetSingleElement(self, skip, take, error=None, error_msg=None):
def make_sparse(x): def make_sparse(x):
@ -62,6 +67,7 @@ class GetSingleElementTest(test_base.DatasetTestBase, parameterized.TestCase):
with self.assertRaisesRegexp(error, error_msg): with self.assertRaisesRegexp(error, error_msg):
self.evaluate(get_single_element.get_single_element(dataset)) self.evaluate(get_single_element.get_single_element(dataset))
@combinations.generate(test_base.default_test_combinations())
def testWindow(self): def testWindow(self):
"""Test that `get_single_element()` can consume a nested dataset.""" """Test that `get_single_element()` can consume a nested dataset."""
def flat_map_func(ds): def flat_map_func(ds):
@ -73,6 +79,7 @@ class GetSingleElementTest(test_base.DatasetTestBase, parameterized.TestCase):
self.assertDatasetProduces( self.assertDatasetProduces(
dataset, [[0, 1], [2, 3], [4, 5], [6, 7], [8, 9]]) dataset, [[0, 1], [2, 3], [4, 5], [6, 7], [8, 9]])
@combinations.generate(test_base.default_test_combinations())
def testSideEffect(self): def testSideEffect(self):
counter_var = variables.Variable(0) counter_var = variables.Variable(0)
@ -92,6 +99,7 @@ class GetSingleElementTest(test_base.DatasetTestBase, parameterized.TestCase):
self.assertEqual(self.evaluate(fn()), b"hello") self.assertEqual(self.evaluate(fn()), b"hello")
self.assertEqual(self.evaluate(counter_var), 1) self.assertEqual(self.evaluate(counter_var), 1)
@combinations.generate(test_base.default_test_combinations())
def testAutomaticControlDependencies(self): def testAutomaticControlDependencies(self):
counter_var = variables.Variable(1) counter_var = variables.Variable(1)

View File

@ -17,25 +17,26 @@ from __future__ import absolute_import
from __future__ import division from __future__ import division
from __future__ import print_function from __future__ import print_function
from absl.testing import parameterized
import numpy as np import numpy as np
from tensorflow.python.data.experimental.ops import grouping from tensorflow.python.data.experimental.ops import grouping
from tensorflow.python.data.kernel_tests import test_base from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import combinations
from tensorflow.python.framework import constant_op from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors from tensorflow.python.framework import errors
from tensorflow.python.framework import sparse_tensor from tensorflow.python.framework import sparse_tensor
from tensorflow.python.framework import tensor_shape from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops from tensorflow.python.ops import math_ops
from tensorflow.python.platform import test from tensorflow.python.platform import test
@test_util.run_all_in_graph_and_eager_modes class GroupByReducerTest(test_base.DatasetTestBase, parameterized.TestCase):
class GroupByReducerTest(test_base.DatasetTestBase):
@combinations.generate(test_base.default_test_combinations())
def testSum(self): def testSum(self):
reducer = grouping.Reducer( reducer = grouping.Reducer(
init_func=lambda _: np.int64(0), init_func=lambda _: np.int64(0),
@ -49,6 +50,7 @@ class GroupByReducerTest(test_base.DatasetTestBase):
expected_shapes=tensor_shape.TensorShape([]), expected_shapes=tensor_shape.TensorShape([]),
expected_output=[(i - 1) * i, i * i]) expected_output=[(i - 1) * i, i * i])
@combinations.generate(test_base.default_test_combinations())
def testAverage(self): def testAverage(self):
def reduce_fn(x, y): def reduce_fn(x, y):
@ -68,6 +70,7 @@ class GroupByReducerTest(test_base.DatasetTestBase):
expected_shapes=tensor_shape.TensorShape([]), expected_shapes=tensor_shape.TensorShape([]),
expected_output=[i - 1, i]) expected_output=[i - 1, i])
@combinations.generate(test_base.default_test_combinations())
def testConcat(self): def testConcat(self):
components = np.array(list("abcdefghijklmnopqrst")).view(np.chararray) components = np.array(list("abcdefghijklmnopqrst")).view(np.chararray)
reducer = grouping.Reducer( reducer = grouping.Reducer(
@ -84,6 +87,7 @@ class GroupByReducerTest(test_base.DatasetTestBase):
expected_shapes=tensor_shape.TensorShape([]), expected_shapes=tensor_shape.TensorShape([]),
expected_output=[b"acegikmoqs"[:i], b"bdfhjlnprt"[:i]]) expected_output=[b"acegikmoqs"[:i], b"bdfhjlnprt"[:i]])
@combinations.generate(test_base.default_test_combinations())
def testSparseSum(self): def testSparseSum(self):
def _sparse(i): def _sparse(i):
return sparse_tensor.SparseTensorValue( return sparse_tensor.SparseTensorValue(
@ -103,6 +107,7 @@ class GroupByReducerTest(test_base.DatasetTestBase):
expected_shapes=tensor_shape.TensorShape([]), expected_shapes=tensor_shape.TensorShape([]),
expected_output=[(i - 1) * i, i * i]) expected_output=[(i - 1) * i, i * i])
@combinations.generate(test_base.default_test_combinations())
def testChangingStateShape(self): def testChangingStateShape(self):
def reduce_fn(x, _): def reduce_fn(x, _):
@ -130,6 +135,7 @@ class GroupByReducerTest(test_base.DatasetTestBase):
with self.assertRaises(errors.OutOfRangeError): with self.assertRaises(errors.OutOfRangeError):
self.evaluate(get_next()) self.evaluate(get_next())
@combinations.generate(test_base.default_test_combinations())
def testTypeMismatch(self): def testTypeMismatch(self):
reducer = grouping.Reducer( reducer = grouping.Reducer(
init_func=lambda x: constant_op.constant(1, dtype=dtypes.int32), init_func=lambda x: constant_op.constant(1, dtype=dtypes.int32),
@ -144,6 +150,7 @@ class GroupByReducerTest(test_base.DatasetTestBase):
grouping.group_by_reducer(lambda _: np.int64(0), reducer)) grouping.group_by_reducer(lambda _: np.int64(0), reducer))
# TODO(b/78665031): Remove once non-scalar keys are supported. # TODO(b/78665031): Remove once non-scalar keys are supported.
@combinations.generate(test_base.default_test_combinations())
def testInvalidKeyShape(self): def testInvalidKeyShape(self):
reducer = grouping.Reducer( reducer = grouping.Reducer(
init_func=lambda x: np.int64(0), init_func=lambda x: np.int64(0),
@ -157,6 +164,7 @@ class GroupByReducerTest(test_base.DatasetTestBase):
grouping.group_by_reducer(lambda _: np.int64((0, 0)), reducer)) grouping.group_by_reducer(lambda _: np.int64((0, 0)), reducer))
# TODO(b/78665031): Remove once non-int64 keys are supported. # TODO(b/78665031): Remove once non-int64 keys are supported.
@combinations.generate(test_base.default_test_combinations())
def testInvalidKeyType(self): def testInvalidKeyType(self):
reducer = grouping.Reducer( reducer = grouping.Reducer(
init_func=lambda x: np.int64(0), init_func=lambda x: np.int64(0),
@ -169,6 +177,7 @@ class GroupByReducerTest(test_base.DatasetTestBase):
dataset.apply( dataset.apply(
grouping.group_by_reducer(lambda _: "wrong", reducer)) grouping.group_by_reducer(lambda _: "wrong", reducer))
@combinations.generate(test_base.default_test_combinations())
def testTuple(self): def testTuple(self):
def init_fn(_): def init_fn(_):
return np.array([], dtype=np.int64), np.int64(0) return np.array([], dtype=np.int64), np.int64(0)

View File

@ -17,17 +17,18 @@ from __future__ import absolute_import
from __future__ import division from __future__ import division
from __future__ import print_function from __future__ import print_function
from absl.testing import parameterized
import numpy as np import numpy as np
from tensorflow.python.data.experimental.ops import grouping from tensorflow.python.data.experimental.ops import grouping
from tensorflow.python.data.kernel_tests import test_base from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import combinations
from tensorflow.python.framework import constant_op from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors from tensorflow.python.framework import errors
from tensorflow.python.framework import ops from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops from tensorflow.python.ops import math_ops
from tensorflow.python.ops import string_ops from tensorflow.python.ops import string_ops
@ -37,8 +38,7 @@ from tensorflow.python.platform import test
# NOTE(mrry): These tests are based on the tests in bucket_ops_test.py. # NOTE(mrry): These tests are based on the tests in bucket_ops_test.py.
# Currently, they use a constant batch size, though should be made to use a # Currently, they use a constant batch size, though should be made to use a
# different batch size per key. # different batch size per key.
@test_util.run_all_in_graph_and_eager_modes class GroupByWindowTest(test_base.DatasetTestBase, parameterized.TestCase):
class GroupByWindowTest(test_base.DatasetTestBase):
def _dynamicPad(self, bucket, window, window_size): def _dynamicPad(self, bucket, window, window_size):
# TODO(mrry): To match `tf.contrib.training.bucket()`, implement a # TODO(mrry): To match `tf.contrib.training.bucket()`, implement a
@ -51,6 +51,7 @@ class GroupByWindowTest(test_base.DatasetTestBase):
32, (tensor_shape.TensorShape([]), tensor_shape.TensorShape( 32, (tensor_shape.TensorShape([]), tensor_shape.TensorShape(
[None]), tensor_shape.TensorShape([3]))))) [None]), tensor_shape.TensorShape([3])))))
@combinations.generate(test_base.default_test_combinations())
def testSingleBucket(self): def testSingleBucket(self):
def _map_fn(v): def _map_fn(v):
@ -80,6 +81,7 @@ class GroupByWindowTest(test_base.DatasetTestBase):
self.assertAllEqual(expected_unk_int64, bucketed_values[1]) self.assertAllEqual(expected_unk_int64, bucketed_values[1])
self.assertAllEqual(expected_vec3_str, bucketed_values[2]) self.assertAllEqual(expected_vec3_str, bucketed_values[2])
@combinations.generate(test_base.default_test_combinations())
def testEvenOddBuckets(self): def testEvenOddBuckets(self):
def _map_fn(v): def _map_fn(v):
@ -132,6 +134,7 @@ class GroupByWindowTest(test_base.DatasetTestBase):
self.assertAllEqual(expected_unk_int64, bucketed_values_odd[1]) self.assertAllEqual(expected_unk_int64, bucketed_values_odd[1])
self.assertAllEqual(expected_vec3_str, bucketed_values_odd[2]) self.assertAllEqual(expected_vec3_str, bucketed_values_odd[2])
@combinations.generate(test_base.default_test_combinations())
def testEvenOddBucketsFilterOutAllOdd(self): def testEvenOddBucketsFilterOutAllOdd(self):
def _map_fn(v): def _map_fn(v):
@ -173,6 +176,7 @@ class GroupByWindowTest(test_base.DatasetTestBase):
self.assertAllEqual( self.assertAllEqual(
np.arange(64, 128, 2, dtype=np.int64), bucketed_values_even1["x"]) np.arange(64, 128, 2, dtype=np.int64), bucketed_values_even1["x"])
@combinations.generate(test_base.default_test_combinations())
def testDynamicWindowSize(self): def testDynamicWindowSize(self):
components = np.arange(100).astype(np.int64) components = np.arange(100).astype(np.int64)
@ -202,6 +206,7 @@ class GroupByWindowTest(test_base.DatasetTestBase):
self.assertEqual(batches, 15) self.assertEqual(batches, 15)
@combinations.generate(test_base.default_test_combinations())
def testSimple(self): def testSimple(self):
components = np.random.randint(100, size=(200,)).astype(np.int64) components = np.random.randint(100, size=(200,)).astype(np.int64)
dataset = dataset_ops.Dataset.from_tensor_slices( dataset = dataset_ops.Dataset.from_tensor_slices(
@ -222,6 +227,7 @@ class GroupByWindowTest(test_base.DatasetTestBase):
self.assertGreaterEqual(num_full_batches, 24) self.assertGreaterEqual(num_full_batches, 24)
self.assertTrue(all(c == 4 for c in counts[:num_full_batches])) self.assertTrue(all(c == 4 for c in counts[:num_full_batches]))
@combinations.generate(test_base.default_test_combinations())
def testImmediateOutput(self): def testImmediateOutput(self):
components = np.array( components = np.array(
[0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 0, 0, 2, 2, 0, 0], dtype=np.int64) [0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 0, 0, 2, 2, 0, 0], dtype=np.int64)
@ -240,6 +246,7 @@ class GroupByWindowTest(test_base.DatasetTestBase):
self.assertAllEqual([2, 2, 2, 2], self.evaluate(get_next())) self.assertAllEqual([2, 2, 2, 2], self.evaluate(get_next()))
self.assertAllEqual([0, 0, 0, 0], self.evaluate(get_next())) self.assertAllEqual([0, 0, 0, 0], self.evaluate(get_next()))
@combinations.generate(test_base.default_test_combinations())
def testSmallGroups(self): def testSmallGroups(self):
components = np.array([0, 0, 0, 0, 1, 1, 1, 1, 1, 0, 0, 0], dtype=np.int64) components = np.array([0, 0, 0, 0, 1, 1, 1, 1, 1, 0, 0, 0], dtype=np.int64)
dataset = dataset_ops.Dataset.from_tensor_slices(components).apply( dataset = dataset_ops.Dataset.from_tensor_slices(components).apply(
@ -252,6 +259,7 @@ class GroupByWindowTest(test_base.DatasetTestBase):
self.assertAllEqual([0, 0, 0], self.evaluate(get_next())) self.assertAllEqual([0, 0, 0], self.evaluate(get_next()))
self.assertAllEqual([1], self.evaluate(get_next())) self.assertAllEqual([1], self.evaluate(get_next()))
@combinations.generate(test_base.default_test_combinations())
def testEmpty(self): def testEmpty(self):
dataset = dataset_ops.Dataset.range(4).apply( dataset = dataset_ops.Dataset.range(4).apply(
grouping.group_by_window(lambda _: 0, lambda _, xs: xs, 0)) grouping.group_by_window(lambda _: 0, lambda _, xs: xs, 0))
@ -262,6 +270,7 @@ class GroupByWindowTest(test_base.DatasetTestBase):
"Window size must be greater than zero, but got 0."): "Window size must be greater than zero, but got 0."):
print(self.evaluate(get_next())) print(self.evaluate(get_next()))
@combinations.generate(test_base.default_test_combinations())
def testReduceFuncError(self): def testReduceFuncError(self):
components = np.random.randint(100, size=(200,)).astype(np.int64) components = np.random.randint(100, size=(200,)).astype(np.int64)
@ -280,6 +289,7 @@ class GroupByWindowTest(test_base.DatasetTestBase):
with self.assertRaises(errors.InvalidArgumentError): with self.assertRaises(errors.InvalidArgumentError):
self.evaluate(get_next()) self.evaluate(get_next())
@combinations.generate(test_base.default_test_combinations())
def testConsumeWindowDatasetMoreThanOnce(self): def testConsumeWindowDatasetMoreThanOnce(self):
components = np.random.randint(50, size=(200,)).astype(np.int64) components = np.random.randint(50, size=(200,)).astype(np.int64)
@ -311,6 +321,7 @@ class GroupByWindowTest(test_base.DatasetTestBase):
counts.append(tight_result.shape[0]) counts.append(tight_result.shape[0])
self.assertEqual(len(components), sum(counts)) self.assertEqual(len(components), sum(counts))
@combinations.generate(test_base.default_test_combinations())
def testShortCircuit(self): def testShortCircuit(self):
dataset = dataset_ops.Dataset.range(10) dataset = dataset_ops.Dataset.range(10)

View File

@ -19,14 +19,15 @@ from __future__ import print_function
import os import os
from absl.testing import parameterized
import numpy as np import numpy as np
from tensorflow.python.data.experimental.ops import error_ops from tensorflow.python.data.experimental.ops import error_ops
from tensorflow.python.data.kernel_tests import test_base from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.ops import readers from tensorflow.python.data.ops import readers
from tensorflow.python.framework import combinations
from tensorflow.python.framework import errors from tensorflow.python.framework import errors
from tensorflow.python.framework import test_util
from tensorflow.python.lib.io import python_io from tensorflow.python.lib.io import python_io
from tensorflow.python.ops import array_ops from tensorflow.python.ops import array_ops
from tensorflow.python.ops import io_ops from tensorflow.python.ops import io_ops
@ -36,9 +37,9 @@ from tensorflow.python.util import compat
_NUMPY_RANDOM_SEED = 42 _NUMPY_RANDOM_SEED = 42
@test_util.run_all_in_graph_and_eager_modes class IgnoreErrorsTest(test_base.DatasetTestBase, parameterized.TestCase):
class IgnoreErrorsTest(test_base.DatasetTestBase):
@combinations.generate(test_base.default_test_combinations())
def testMapIgnoreError(self): def testMapIgnoreError(self):
components = np.array([1., 2., 3., np.nan, 5.]).astype(np.float32) components = np.array([1., 2., 3., np.nan, 5.]).astype(np.float32)
@ -53,6 +54,7 @@ class IgnoreErrorsTest(test_base.DatasetTestBase):
with self.assertRaises(errors.OutOfRangeError): with self.assertRaises(errors.OutOfRangeError):
self.evaluate(get_next()) self.evaluate(get_next())
@combinations.generate(test_base.default_test_combinations())
def testParallelMapIgnoreError(self): def testParallelMapIgnoreError(self):
components = np.array([1., 2., 3., np.nan, 5.]).astype(np.float32) components = np.array([1., 2., 3., np.nan, 5.]).astype(np.float32)
@ -67,6 +69,7 @@ class IgnoreErrorsTest(test_base.DatasetTestBase):
with self.assertRaises(errors.OutOfRangeError): with self.assertRaises(errors.OutOfRangeError):
self.evaluate(get_next()) self.evaluate(get_next())
@combinations.generate(test_base.default_test_combinations())
def testReadFileIgnoreError(self): def testReadFileIgnoreError(self):
def write_string_to_file(value, filename): def write_string_to_file(value, filename):
@ -102,6 +105,7 @@ class IgnoreErrorsTest(test_base.DatasetTestBase):
with self.assertRaises(errors.OutOfRangeError): with self.assertRaises(errors.OutOfRangeError):
self.evaluate(get_next()) self.evaluate(get_next())
@combinations.generate(test_base.default_test_combinations())
def testTFRecordDatasetIgnoreError(self): def testTFRecordDatasetIgnoreError(self):
filenames = [] filenames = []
for i in range(5): for i in range(5):
@ -126,6 +130,7 @@ class IgnoreErrorsTest(test_base.DatasetTestBase):
with self.assertRaises(errors.OutOfRangeError): with self.assertRaises(errors.OutOfRangeError):
self.evaluate(get_next()) self.evaluate(get_next())
@combinations.generate(test_base.default_test_combinations())
def testZipIgnoreError(self): def testZipIgnoreError(self):
a = dataset_ops.Dataset.from_tensor_slices([1., 2., 0., 4.]) a = dataset_ops.Dataset.from_tensor_slices([1., 2., 0., 4.])
b = a.map(lambda x: array_ops.check_numerics(1. / x, "error")) b = a.map(lambda x: array_ops.check_numerics(1. / x, "error"))

View File

@ -17,26 +17,29 @@ from __future__ import absolute_import
from __future__ import division from __future__ import division
from __future__ import print_function from __future__ import print_function
from absl.testing import parameterized
import numpy as np import numpy as np
from tensorflow.python.data.experimental.kernel_tests import reader_dataset_ops_test_base from tensorflow.python.data.experimental.kernel_tests import reader_dataset_ops_test_base
from tensorflow.python.data.experimental.ops import readers from tensorflow.python.data.experimental.ops import readers
from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.ops import readers as core_readers from tensorflow.python.data.ops import readers as core_readers
from tensorflow.python.data.util import nest from tensorflow.python.data.util import nest
from tensorflow.python.framework import combinations
from tensorflow.python.framework import dtypes from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors from tensorflow.python.framework import errors
from tensorflow.python.framework import ops from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
from tensorflow.python.ops import io_ops from tensorflow.python.ops import io_ops
from tensorflow.python.ops import parsing_ops from tensorflow.python.ops import parsing_ops
from tensorflow.python.platform import test from tensorflow.python.platform import test
@test_util.run_all_in_graph_and_eager_modes
class MakeBatchedFeaturesDatasetTest( class MakeBatchedFeaturesDatasetTest(
reader_dataset_ops_test_base.MakeBatchedFeaturesDatasetTestBase): reader_dataset_ops_test_base.MakeBatchedFeaturesDatasetTestBase,
parameterized.TestCase):
@combinations.generate(test_base.default_test_combinations())
def testRead(self): def testRead(self):
for batch_size in [1, 2]: for batch_size in [1, 2]:
for num_epochs in [1, 10]: for num_epochs in [1, 10]:
@ -85,6 +88,7 @@ class MakeBatchedFeaturesDatasetTest(
with self.assertRaises(errors.OutOfRangeError): with self.assertRaises(errors.OutOfRangeError):
self._next_actual_batch() self._next_actual_batch()
@combinations.generate(test_base.default_test_combinations())
def testReadWithEquivalentDataset(self): def testReadWithEquivalentDataset(self):
features = { features = {
"file": parsing_ops.FixedLenFeature([], dtypes.int64), "file": parsing_ops.FixedLenFeature([], dtypes.int64),
@ -103,6 +107,7 @@ class MakeBatchedFeaturesDatasetTest(
with self.assertRaises(errors.OutOfRangeError): with self.assertRaises(errors.OutOfRangeError):
self.evaluate(next_element()) self.evaluate(next_element())
@combinations.generate(test_base.default_test_combinations())
def testReadWithFusedShuffleRepeatDataset(self): def testReadWithFusedShuffleRepeatDataset(self):
num_epochs = 5 num_epochs = 5
total_records = num_epochs * self._num_records total_records = num_epochs * self._num_records
@ -151,6 +156,7 @@ class MakeBatchedFeaturesDatasetTest(
all_equal = all_equal and np.array_equal(batch1[i], batch2[i]) all_equal = all_equal and np.array_equal(batch1[i], batch2[i])
self.assertFalse(all_equal) self.assertFalse(all_equal)
@combinations.generate(test_base.default_test_combinations())
def testParallelReadersAndParsers(self): def testParallelReadersAndParsers(self):
num_epochs = 5 num_epochs = 5
for batch_size in [1, 2]: for batch_size in [1, 2]:
@ -186,6 +192,7 @@ class MakeBatchedFeaturesDatasetTest(
with self.assertRaises(errors.OutOfRangeError): with self.assertRaises(errors.OutOfRangeError):
self._next_actual_batch() self._next_actual_batch()
@combinations.generate(test_base.default_test_combinations())
def testDropFinalBatch(self): def testDropFinalBatch(self):
for batch_size in [1, 2]: for batch_size in [1, 2]:
for num_epochs in [1, 10]: for num_epochs in [1, 10]:
@ -201,6 +208,7 @@ class MakeBatchedFeaturesDatasetTest(
if isinstance(tensor, ops.Tensor): # Guard against SparseTensor. if isinstance(tensor, ops.Tensor): # Guard against SparseTensor.
self.assertEqual(tensor.shape[0], batch_size) self.assertEqual(tensor.shape[0], batch_size)
@combinations.generate(test_base.default_test_combinations())
def testIndefiniteRepeatShapeInference(self): def testIndefiniteRepeatShapeInference(self):
dataset = self.make_batch_feature( dataset = self.make_batch_feature(
filenames=self.test_filenames[0], filenames=self.test_filenames[0],
@ -213,6 +221,7 @@ class MakeBatchedFeaturesDatasetTest(
if issubclass(clazz, ops.Tensor): if issubclass(clazz, ops.Tensor):
self.assertEqual(32, shape[0]) self.assertEqual(32, shape[0])
@combinations.generate(test_base.default_test_combinations())
def testOldStyleReader(self): def testOldStyleReader(self):
with self.assertRaisesRegexp( with self.assertRaisesRegexp(
TypeError, r"The `reader` argument must return a `Dataset` object. " TypeError, r"The `reader` argument must return a `Dataset` object. "

View File

@ -21,21 +21,21 @@ import gzip
import os import os
import zlib import zlib
from absl.testing import parameterized
import numpy as np import numpy as np
from tensorflow.python.data.experimental.ops import readers from tensorflow.python.data.experimental.ops import readers
from tensorflow.python.data.kernel_tests import test_base from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.util import nest from tensorflow.python.data.util import nest
from tensorflow.python.framework import combinations
from tensorflow.python.framework import constant_op from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors from tensorflow.python.framework import errors
from tensorflow.python.framework import test_util
from tensorflow.python.platform import test from tensorflow.python.platform import test
@test_util.run_all_in_graph_and_eager_modes class MakeCsvDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
class MakeCsvDatasetTest(test_base.DatasetTestBase):
def _make_csv_dataset(self, filenames, batch_size, num_epochs=1, **kwargs): def _make_csv_dataset(self, filenames, batch_size, num_epochs=1, **kwargs):
return readers.make_csv_dataset( return readers.make_csv_dataset(
@ -126,6 +126,7 @@ class MakeCsvDatasetTest(test_base.DatasetTestBase):
self._verify_output(dataset, batch_size, num_epochs, label_name, self._verify_output(dataset, batch_size, num_epochs, label_name,
expected_output, expected_keys) expected_output, expected_keys)
@combinations.generate(test_base.default_test_combinations())
def testMakeCSVDataset(self): def testMakeCSVDataset(self):
"""Tests making a CSV dataset with keys and defaults provided.""" """Tests making a CSV dataset with keys and defaults provided."""
record_defaults = [ record_defaults = [
@ -157,6 +158,7 @@ class MakeCsvDatasetTest(test_base.DatasetTestBase):
column_defaults=record_defaults, column_defaults=record_defaults,
) )
@combinations.generate(test_base.default_test_combinations())
def testMakeCSVDataset_withBatchSizeAndEpochs(self): def testMakeCSVDataset_withBatchSizeAndEpochs(self):
"""Tests making a CSV dataset with keys and defaults provided.""" """Tests making a CSV dataset with keys and defaults provided."""
record_defaults = [ record_defaults = [
@ -188,6 +190,7 @@ class MakeCsvDatasetTest(test_base.DatasetTestBase):
column_defaults=record_defaults, column_defaults=record_defaults,
) )
@combinations.generate(test_base.default_test_combinations())
def testMakeCSVDataset_withCompressionType(self): def testMakeCSVDataset_withCompressionType(self):
"""Tests `compression_type` argument.""" """Tests `compression_type` argument."""
record_defaults = [ record_defaults = [
@ -221,6 +224,7 @@ class MakeCsvDatasetTest(test_base.DatasetTestBase):
compression_type=compression_type, compression_type=compression_type,
) )
@combinations.generate(test_base.default_test_combinations())
def testMakeCSVDataset_withCompressionTypeAndNoColumnNames(self): def testMakeCSVDataset_withCompressionTypeAndNoColumnNames(self):
"""Tests `compression_type` argument.""" """Tests `compression_type` argument."""
record_defaults = [ record_defaults = [
@ -269,6 +273,7 @@ class MakeCsvDatasetTest(test_base.DatasetTestBase):
compression_type="ZLIB", compression_type="ZLIB",
) )
@combinations.generate(test_base.default_test_combinations())
def testMakeCSVDataset_withBadInputs(self): def testMakeCSVDataset_withBadInputs(self):
"""Tests that exception is raised when input is malformed. """Tests that exception is raised when input is malformed.
""" """
@ -304,6 +309,7 @@ class MakeCsvDatasetTest(test_base.DatasetTestBase):
label_name="not_a_real_label", label_name="not_a_real_label",
column_names=column_names) column_names=column_names)
@combinations.generate(test_base.default_test_combinations())
def testMakeCSVDataset_withNoLabel(self): def testMakeCSVDataset_withNoLabel(self):
"""Tests making a CSV dataset with no label provided.""" """Tests making a CSV dataset with no label provided."""
record_defaults = [ record_defaults = [
@ -333,6 +339,7 @@ class MakeCsvDatasetTest(test_base.DatasetTestBase):
column_defaults=record_defaults, column_defaults=record_defaults,
) )
@combinations.generate(test_base.default_test_combinations())
def testMakeCSVDataset_withNoHeader(self): def testMakeCSVDataset_withNoHeader(self):
"""Tests that datasets can be created from CSV files with no header line. """Tests that datasets can be created from CSV files with no header line.
""" """
@ -363,6 +370,7 @@ class MakeCsvDatasetTest(test_base.DatasetTestBase):
column_defaults=record_defaults, column_defaults=record_defaults,
) )
@combinations.generate(test_base.default_test_combinations())
def testMakeCSVDataset_withTypes(self): def testMakeCSVDataset_withTypes(self):
"""Tests that defaults can be a dtype instead of a Tensor for required vals. """Tests that defaults can be a dtype instead of a Tensor for required vals.
""" """
@ -394,6 +402,7 @@ class MakeCsvDatasetTest(test_base.DatasetTestBase):
column_defaults=record_defaults, column_defaults=record_defaults,
) )
@combinations.generate(test_base.default_test_combinations())
def testMakeCSVDataset_withNoColNames(self): def testMakeCSVDataset_withNoColNames(self):
"""Tests that datasets can be created when column names are not specified. """Tests that datasets can be created when column names are not specified.
@ -427,6 +436,7 @@ class MakeCsvDatasetTest(test_base.DatasetTestBase):
column_defaults=record_defaults, column_defaults=record_defaults,
) )
@combinations.generate(test_base.default_test_combinations())
def testMakeCSVDataset_withTypeInferenceMismatch(self): def testMakeCSVDataset_withTypeInferenceMismatch(self):
# Test that error is thrown when num fields doesn't match columns # Test that error is thrown when num fields doesn't match columns
column_names = ["col%d" % i for i in range(5)] column_names = ["col%d" % i for i in range(5)]
@ -442,6 +452,7 @@ class MakeCsvDatasetTest(test_base.DatasetTestBase):
batch_size=2, batch_size=2,
num_epochs=10) num_epochs=10)
@combinations.generate(test_base.default_test_combinations())
def testMakeCSVDataset_withTypeInference(self): def testMakeCSVDataset_withTypeInference(self):
"""Tests that datasets can be created when no defaults are specified. """Tests that datasets can be created when no defaults are specified.
@ -468,6 +479,7 @@ class MakeCsvDatasetTest(test_base.DatasetTestBase):
header=True, header=True,
) )
@combinations.generate(test_base.default_test_combinations())
def testMakeCSVDataset_withTypeInferenceFallthrough(self): def testMakeCSVDataset_withTypeInferenceFallthrough(self):
"""Tests that datasets can be created when no defaults are specified. """Tests that datasets can be created when no defaults are specified.
@ -498,6 +510,7 @@ class MakeCsvDatasetTest(test_base.DatasetTestBase):
header=True, header=True,
) )
@combinations.generate(test_base.default_test_combinations())
def testMakeCSVDataset_withNAValuesAndFieldDelim(self): def testMakeCSVDataset_withNAValuesAndFieldDelim(self):
"""Tests that datasets can be created from different delim and na_value.""" """Tests that datasets can be created from different delim and na_value."""
column_names = ["col%d" % i for i in range(5)] column_names = ["col%d" % i for i in range(5)]
@ -520,6 +533,7 @@ class MakeCsvDatasetTest(test_base.DatasetTestBase):
field_delim=" ", field_delim=" ",
) )
@combinations.generate(test_base.default_test_combinations())
def testMakeCSVDataset_withSelectCols(self): def testMakeCSVDataset_withSelectCols(self):
record_defaults = [ record_defaults = [
constant_op.constant([], dtypes.int32), constant_op.constant([], dtypes.int32),
@ -588,6 +602,7 @@ class MakeCsvDatasetTest(test_base.DatasetTestBase):
select_columns=[column_names[i] for i in select_cols], select_columns=[column_names[i] for i in select_cols],
) )
@combinations.generate(test_base.default_test_combinations())
def testMakeCSVDataset_withSelectColsError(self): def testMakeCSVDataset_withSelectColsError(self):
record_defaults = [ record_defaults = [
constant_op.constant([], dtypes.int32), constant_op.constant([], dtypes.int32),
@ -626,6 +641,7 @@ class MakeCsvDatasetTest(test_base.DatasetTestBase):
label_name=None, label_name=None,
select_columns=["invalid_col_name"]) select_columns=["invalid_col_name"])
@combinations.generate(test_base.default_test_combinations())
def testMakeCSVDataset_withShuffle(self): def testMakeCSVDataset_withShuffle(self):
record_defaults = [ record_defaults = [
constant_op.constant([], dtypes.int32), constant_op.constant([], dtypes.int32),
@ -710,6 +726,7 @@ class MakeCsvDatasetTest(test_base.DatasetTestBase):
all_equal = all_equal and np.array_equal(batch1[i], batch2[i]) all_equal = all_equal and np.array_equal(batch1[i], batch2[i])
self.assertFalse(all_equal) self.assertFalse(all_equal)
@combinations.generate(test_base.default_test_combinations())
def testIndefiniteRepeatShapeInference(self): def testIndefiniteRepeatShapeInference(self):
column_names = ["col%d" % i for i in range(5)] column_names = ["col%d" % i for i in range(5)]
inputs = [[",".join(x for x in column_names), "0,1,2,3,4", "5,6,7,8,9"], [ inputs = [[",".join(x for x in column_names), "0,1,2,3,4", "5,6,7,8,9"], [

View File

@ -17,19 +17,22 @@ from __future__ import absolute_import
from __future__ import division from __future__ import division
from __future__ import print_function from __future__ import print_function
from absl.testing import parameterized
from tensorflow.python.data.experimental.kernel_tests import reader_dataset_ops_test_base from tensorflow.python.data.experimental.kernel_tests import reader_dataset_ops_test_base
from tensorflow.python.data.experimental.ops import readers from tensorflow.python.data.experimental.ops import readers
from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.util import nest from tensorflow.python.data.util import nest
from tensorflow.python.framework import combinations
from tensorflow.python.framework import errors from tensorflow.python.framework import errors
from tensorflow.python.framework import test_util
from tensorflow.python.ops import string_ops from tensorflow.python.ops import string_ops
from tensorflow.python.platform import test from tensorflow.python.platform import test
@test_util.run_all_in_graph_and_eager_modes
class MakeTFRecordDatasetTest( class MakeTFRecordDatasetTest(
reader_dataset_ops_test_base.TFRecordDatasetTestBase): reader_dataset_ops_test_base.TFRecordDatasetTestBase,
parameterized.TestCase):
def _read_test(self, batch_size, num_epochs, file_index=None, def _read_test(self, batch_size, num_epochs, file_index=None,
num_parallel_reads=1, drop_final_batch=False, parser_fn=False): num_parallel_reads=1, drop_final_batch=False, parser_fn=False):
@ -63,6 +66,7 @@ class MakeTFRecordDatasetTest(
with self.assertRaises(errors.OutOfRangeError): with self.assertRaises(errors.OutOfRangeError):
self.evaluate(outputs()) self.evaluate(outputs())
@combinations.generate(test_base.default_test_combinations())
def testRead(self): def testRead(self):
for batch_size in [1, 2]: for batch_size in [1, 2]:
for num_epochs in [1, 3]: for num_epochs in [1, 3]:
@ -78,6 +82,7 @@ class MakeTFRecordDatasetTest(
# Basic test: read from both files, with parallel reads. # Basic test: read from both files, with parallel reads.
self._read_test(batch_size, num_epochs, num_parallel_reads=8) self._read_test(batch_size, num_epochs, num_parallel_reads=8)
@combinations.generate(test_base.default_test_combinations())
def testDropFinalBatch(self): def testDropFinalBatch(self):
for batch_size in [1, 2, 10]: for batch_size in [1, 2, 10]:
for num_epochs in [1, 3]: for num_epochs in [1, 3]:
@ -91,6 +96,7 @@ class MakeTFRecordDatasetTest(
self._read_test(batch_size, num_epochs, num_parallel_reads=8, self._read_test(batch_size, num_epochs, num_parallel_reads=8,
drop_final_batch=True) drop_final_batch=True)
@combinations.generate(test_base.default_test_combinations())
def testParserFn(self): def testParserFn(self):
for batch_size in [1, 2]: for batch_size in [1, 2]:
for num_epochs in [1, 3]: for num_epochs in [1, 3]:
@ -145,6 +151,7 @@ class MakeTFRecordDatasetTest(
actual.extend(b) actual.extend(b)
self.assertAllEqual(sorted(expected), sorted(actual)) self.assertAllEqual(sorted(expected), sorted(actual))
@combinations.generate(test_base.default_test_combinations())
def testShuffle(self): def testShuffle(self):
for batch_size in [1, 2]: for batch_size in [1, 2]:
for num_epochs in [1, 3]: for num_epochs in [1, 3]:
@ -156,6 +163,7 @@ class MakeTFRecordDatasetTest(
self._shuffle_test(batch_size, num_epochs, num_parallel_reads, self._shuffle_test(batch_size, num_epochs, num_parallel_reads,
seed=21345) seed=21345)
@combinations.generate(test_base.default_test_combinations())
def testIndefiniteRepeatShapeInference(self): def testIndefiniteRepeatShapeInference(self):
dataset = readers.make_tf_record_dataset( dataset = readers.make_tf_record_dataset(
file_pattern=self.test_filenames, num_epochs=None, batch_size=32) file_pattern=self.test_filenames, num_epochs=None, batch_size=32)

View File

@ -19,17 +19,19 @@ from __future__ import print_function
import time import time
from absl.testing import parameterized
from tensorflow.python.client import session from tensorflow.python.client import session
from tensorflow.python.data.experimental.ops import map_defun from tensorflow.python.data.experimental.ops import map_defun
from tensorflow.python.data.kernel_tests import test_base from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.eager import function from tensorflow.python.eager import function
from tensorflow.python.framework import combinations
from tensorflow.python.framework import constant_op from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors from tensorflow.python.framework import errors
from tensorflow.python.framework import ops from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor from tensorflow.python.framework import sparse_tensor
from tensorflow.python.framework import tensor_spec from tensorflow.python.framework import tensor_spec
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops from tensorflow.python.ops import array_ops
from tensorflow.python.ops import check_ops from tensorflow.python.ops import check_ops
from tensorflow.python.ops import data_flow_ops from tensorflow.python.ops import data_flow_ops
@ -38,9 +40,11 @@ from tensorflow.python.ops import sparse_ops
from tensorflow.python.platform import test from tensorflow.python.platform import test
@test_util.run_v1_only("b/123903858: Add eager and V2 test coverage") # TODO(b/123903858): Add eager and V2 test coverage
class MapDefunTest(test_base.DatasetTestBase): class MapDefunTest(test_base.DatasetTestBase, parameterized.TestCase):
@combinations.generate(
combinations.combine(tf_api_version=[1], mode=["graph"]))
def testNoIntraOpLimit(self): def testNoIntraOpLimit(self):
@function.defun(input_signature=[tensor_spec.TensorSpec([2], dtypes.int32)]) @function.defun(input_signature=[tensor_spec.TensorSpec([2], dtypes.int32)])
@ -55,6 +59,8 @@ class MapDefunTest(test_base.DatasetTestBase):
expected = elems * 2 + 3 expected = elems * 2 + 3
self.assertAllEqual(self.evaluate(r), self.evaluate(expected)) self.assertAllEqual(self.evaluate(r), self.evaluate(expected))
@combinations.generate(
combinations.combine(tf_api_version=[1], mode=["graph"]))
def testMapDefunSimple(self): def testMapDefunSimple(self):
@function.defun(input_signature=[tensor_spec.TensorSpec([2], dtypes.int32)]) @function.defun(input_signature=[tensor_spec.TensorSpec([2], dtypes.int32)])
@ -67,6 +73,8 @@ class MapDefunTest(test_base.DatasetTestBase):
expected = elems * 2 + 3 expected = elems * 2 + 3
self.assertAllEqual(self.evaluate(r), self.evaluate(expected)) self.assertAllEqual(self.evaluate(r), self.evaluate(expected))
@combinations.generate(
combinations.combine(tf_api_version=[1], mode=["graph"]))
def testMapDefunMismatchedTypes(self): def testMapDefunMismatchedTypes(self):
@function.defun(input_signature=[tensor_spec.TensorSpec([], dtypes.int32)]) @function.defun(input_signature=[tensor_spec.TensorSpec([], dtypes.int32)])
@ -79,6 +87,8 @@ class MapDefunTest(test_base.DatasetTestBase):
with self.assertRaises(errors.InvalidArgumentError): with self.assertRaises(errors.InvalidArgumentError):
self.evaluate(r) self.evaluate(r)
@combinations.generate(
combinations.combine(tf_api_version=[1], mode=["graph"]))
def testMapDefunReduceDim(self): def testMapDefunReduceDim(self):
# Tests where the output has a different rank from the input # Tests where the output has a different rank from the input
@ -92,6 +102,8 @@ class MapDefunTest(test_base.DatasetTestBase):
expected = constant_op.constant([1, 3, 5]) expected = constant_op.constant([1, 3, 5])
self.assertAllEqual(self.evaluate(r), self.evaluate(expected)) self.assertAllEqual(self.evaluate(r), self.evaluate(expected))
@combinations.generate(
combinations.combine(tf_api_version=[1], mode=["graph"]))
def testMapDefunMultipleOutputs(self): def testMapDefunMultipleOutputs(self):
@function.defun(input_signature=[tensor_spec.TensorSpec([2], dtypes.int32)]) @function.defun(input_signature=[tensor_spec.TensorSpec([2], dtypes.int32)])
@ -105,6 +117,8 @@ class MapDefunTest(test_base.DatasetTestBase):
expected = [elems, elems * 2 + 3] expected = [elems, elems * 2 + 3]
self.assertAllEqual(self.evaluate(r), self.evaluate(expected)) self.assertAllEqual(self.evaluate(r), self.evaluate(expected))
@combinations.generate(
combinations.combine(tf_api_version=[1], mode=["graph"]))
def testMapDefunShapeInference(self): def testMapDefunShapeInference(self):
@function.defun(input_signature=[tensor_spec.TensorSpec([2], dtypes.int32)]) @function.defun(input_signature=[tensor_spec.TensorSpec([2], dtypes.int32)])
@ -116,6 +130,8 @@ class MapDefunTest(test_base.DatasetTestBase):
result = map_defun.map_defun(fn, [elems], [dtypes.int32], [(2,)])[0] result = map_defun.map_defun(fn, [elems], [dtypes.int32], [(2,)])[0]
self.assertEqual(result.get_shape(), (3, 2)) self.assertEqual(result.get_shape(), (3, 2))
@combinations.generate(
combinations.combine(tf_api_version=[1], mode=["graph"]))
def testMapDefunPartialShapeInference(self): def testMapDefunPartialShapeInference(self):
@function.defun(input_signature=[tensor_spec.TensorSpec([2], dtypes.int32)]) @function.defun(input_signature=[tensor_spec.TensorSpec([2], dtypes.int32)])
@ -126,6 +142,8 @@ class MapDefunTest(test_base.DatasetTestBase):
result = map_defun.map_defun(fn, [elems], [dtypes.int32], [(2,)]) result = map_defun.map_defun(fn, [elems], [dtypes.int32], [(2,)])
self.assertEqual(result[0].get_shape().as_list(), [None, 2]) self.assertEqual(result[0].get_shape().as_list(), [None, 2])
@combinations.generate(
combinations.combine(tf_api_version=[1], mode=["graph"]))
def testMapDefunRaisesErrorOnRuntimeShapeMismatch(self): def testMapDefunRaisesErrorOnRuntimeShapeMismatch(self):
@function.defun(input_signature=[ @function.defun(input_signature=[
@ -145,6 +163,8 @@ class MapDefunTest(test_base.DatasetTestBase):
"All inputs must have the same dimension 0."): "All inputs must have the same dimension 0."):
sess.run(result, feed_dict={elems1: [1, 2, 3, 4, 5], elems2: [1, 2, 3]}) sess.run(result, feed_dict={elems1: [1, 2, 3, 4, 5], elems2: [1, 2, 3]})
@combinations.generate(
combinations.combine(tf_api_version=[1], mode=["graph"]))
def testMapDefunRaisesDefunError(self): def testMapDefunRaisesDefunError(self):
@function.defun(input_signature=[tensor_spec.TensorSpec([], dtypes.int32)]) @function.defun(input_signature=[tensor_spec.TensorSpec([], dtypes.int32)])
@ -157,6 +177,8 @@ class MapDefunTest(test_base.DatasetTestBase):
with self.assertRaises(errors.InvalidArgumentError): with self.assertRaises(errors.InvalidArgumentError):
self.evaluate(result) self.evaluate(result)
@combinations.generate(
combinations.combine(tf_api_version=[1], mode=["graph"]))
def testMapDefunCancelledCorrectly(self): def testMapDefunCancelledCorrectly(self):
@function.defun(input_signature=[tensor_spec.TensorSpec([5], dtypes.int64)]) @function.defun(input_signature=[tensor_spec.TensorSpec([5], dtypes.int64)])
@ -173,6 +195,8 @@ class MapDefunTest(test_base.DatasetTestBase):
r"indices = 10 is not in \[0, 5\)"): r"indices = 10 is not in \[0, 5\)"):
self.evaluate(map_defun_op) self.evaluate(map_defun_op)
@combinations.generate(
combinations.combine(tf_api_version=[1], mode=["graph"]))
def testMapDefunWithUnspecifiedOutputShape(self): def testMapDefunWithUnspecifiedOutputShape(self):
@function.defun(input_signature=[tensor_spec.TensorSpec([2], dtypes.int32)]) @function.defun(input_signature=[tensor_spec.TensorSpec([2], dtypes.int32)])
@ -190,6 +214,8 @@ class MapDefunTest(test_base.DatasetTestBase):
self.assertAllEqual(self.evaluate(r[1]), self.evaluate(expected + 1)) self.assertAllEqual(self.evaluate(r[1]), self.evaluate(expected + 1))
self.assertAllEqual(self.evaluate(r[2]), self.evaluate(expected + 2)) self.assertAllEqual(self.evaluate(r[2]), self.evaluate(expected + 2))
@combinations.generate(
combinations.combine(tf_api_version=[1], mode=["graph"]))
def testMapDefunWithDifferentOutputShapeEachRun(self): def testMapDefunWithDifferentOutputShapeEachRun(self):
@function.defun( @function.defun(
@ -204,6 +230,8 @@ class MapDefunTest(test_base.DatasetTestBase):
self.assertAllEqual( self.assertAllEqual(
sess.run(r, feed_dict={elems: [[0], [1]]}), [[3], [5]]) sess.run(r, feed_dict={elems: [[0], [1]]}), [[3], [5]])
@combinations.generate(
combinations.combine(tf_api_version=[1], mode=["graph"]))
def testMapDefunWithWrongOutputShape(self): def testMapDefunWithWrongOutputShape(self):
@function.defun(input_signature=[tensor_spec.TensorSpec([2], dtypes.int32)]) @function.defun(input_signature=[tensor_spec.TensorSpec([2], dtypes.int32)])
@ -216,6 +244,8 @@ class MapDefunTest(test_base.DatasetTestBase):
with self.assertRaises(errors.InvalidArgumentError): with self.assertRaises(errors.InvalidArgumentError):
self.evaluate(r) self.evaluate(r)
@combinations.generate(
combinations.combine(tf_api_version=[1], mode=["graph"]))
def testMapDefunWithInvalidInput(self): def testMapDefunWithInvalidInput(self):
@function.defun( @function.defun(
@ -233,6 +263,8 @@ class MapDefunTest(test_base.DatasetTestBase):
with self.assertRaises(errors.InvalidArgumentError): with self.assertRaises(errors.InvalidArgumentError):
sess.run(r, feed_dict={p: 0}) sess.run(r, feed_dict={p: 0})
@combinations.generate(
combinations.combine(tf_api_version=[1], mode=["graph"]))
def testMapDefunWithParentCancellation(self): def testMapDefunWithParentCancellation(self):
# Checks that a cancellation of the parent graph is threaded through to # Checks that a cancellation of the parent graph is threaded through to
# MapDefunOp correctly. # MapDefunOp correctly.
@ -254,6 +286,8 @@ class MapDefunTest(test_base.DatasetTestBase):
sess.close() sess.close()
thread.join() thread.join()
@combinations.generate(
combinations.combine(tf_api_version=[1], mode=["graph"]))
def testMapDefunWithCapturedInputs(self): def testMapDefunWithCapturedInputs(self):
c = constant_op.constant(2) c = constant_op.constant(2)
@ -266,6 +300,8 @@ class MapDefunTest(test_base.DatasetTestBase):
expected = x + c expected = x + c
self.assertAllEqual(self.evaluate(expected), self.evaluate(map_defun_op)) self.assertAllEqual(self.evaluate(expected), self.evaluate(map_defun_op))
@combinations.generate(
combinations.combine(tf_api_version=[1], mode=["graph"]))
def testMapDefunWithVariantTensor(self): def testMapDefunWithVariantTensor(self):
@function.defun( @function.defun(
@ -288,6 +324,8 @@ class MapDefunTest(test_base.DatasetTestBase):
actual = self.evaluate(deserialized) actual = self.evaluate(deserialized)
self.assertValuesEqual(expected, actual) self.assertValuesEqual(expected, actual)
@combinations.generate(
combinations.combine(tf_api_version=[1], mode=["graph"]))
def testMapDefunWithVariantTensorAsCaptured(self): def testMapDefunWithVariantTensorAsCaptured(self):
st = sparse_tensor.SparseTensor( st = sparse_tensor.SparseTensor(
@ -309,6 +347,8 @@ class MapDefunTest(test_base.DatasetTestBase):
actual = self.evaluate(deserialized) actual = self.evaluate(deserialized)
self.assertValuesEqual(expected, actual) self.assertValuesEqual(expected, actual)
@combinations.generate(
combinations.combine(tf_api_version=[1], mode=["graph"]))
def testMapDefunWithStrTensor(self): def testMapDefunWithStrTensor(self):
@function.defun(input_signature=[tensor_spec.TensorSpec([], dtypes.string)]) @function.defun(input_signature=[tensor_spec.TensorSpec([], dtypes.string)])

View File

@ -28,14 +28,13 @@ from tensorflow.python.data.experimental.ops import threadpool
from tensorflow.python.data.experimental.ops import unique from tensorflow.python.data.experimental.ops import unique
from tensorflow.python.data.kernel_tests import test_base from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import combinations
from tensorflow.python.framework import dtypes from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors from tensorflow.python.framework import errors
from tensorflow.python.framework import test_util
from tensorflow.python.ops import script_ops from tensorflow.python.ops import script_ops
from tensorflow.python.platform import test from tensorflow.python.platform import test
@test_util.run_all_in_graph_and_eager_modes
class OverrideThreadpoolTest(test_base.DatasetTestBase, class OverrideThreadpoolTest(test_base.DatasetTestBase,
parameterized.TestCase): parameterized.TestCase):
@ -70,17 +69,13 @@ class OverrideThreadpoolTest(test_base.DatasetTestBase,
# perform work. # perform work.
self.assertLessEqual(len(thread_ids), num_threads) self.assertLessEqual(len(thread_ids), num_threads)
@parameterized.named_parameters( @combinations.generate(
("1", 1, None), combinations.times(
("2", 2, None), test_base.default_test_combinations(),
("3", 4, None), combinations.combine(
("4", 8, None), num_threads=[1, 2, 4, 8, 16], max_intra_op_parallelism=[None]) +
("5", 16, None), combinations.combine(
("6", 4, -1), num_threads=[4], max_intra_op_parallelism=[-1, 0, 4])))
("7", 4, 0),
("8", 4, 1),
("9", 4, 4),
)
def testNumThreadsDeprecated(self, num_threads, max_intra_op_parallelism): def testNumThreadsDeprecated(self, num_threads, max_intra_op_parallelism):
def override_threadpool_fn(dataset): def override_threadpool_fn(dataset):
@ -93,20 +88,17 @@ class OverrideThreadpoolTest(test_base.DatasetTestBase,
self._testNumThreadsHelper(num_threads, override_threadpool_fn) self._testNumThreadsHelper(num_threads, override_threadpool_fn)
@parameterized.named_parameters( @combinations.generate(
("1", 1, None), combinations.times(
("2", 2, None), test_base.default_test_combinations(),
("3", 4, None), combinations.combine(
("4", 8, None), num_threads=[1, 2, 4, 8, 16], max_intra_op_parallelism=[None]) +
("5", 16, None), combinations.combine(
("6", None, 0), num_threads=[None], max_intra_op_parallelism=[0, 1, 4]) +
("7", None, 1), combinations.combine(
("8", None, 4), num_threads=[4], max_intra_op_parallelism=[0, 1, 4]) +
("9", 4, 0), combinations.combine(
("10", 4, 1), num_threads=[None], max_intra_op_parallelism=[None])))
("11", 4, 4),
("12", None, None),
)
def testNumThreads(self, num_threads, max_intra_op_parallelism): def testNumThreads(self, num_threads, max_intra_op_parallelism):
def override_threadpool_fn(dataset): def override_threadpool_fn(dataset):
@ -121,6 +113,7 @@ class OverrideThreadpoolTest(test_base.DatasetTestBase,
self._testNumThreadsHelper(num_threads, override_threadpool_fn) self._testNumThreadsHelper(num_threads, override_threadpool_fn)
@combinations.generate(test_base.default_test_combinations())
def testMaxIntraOpParallelismAsGraphDefInternal(self): def testMaxIntraOpParallelismAsGraphDefInternal(self):
dataset = dataset_ops.Dataset.from_tensors(0) dataset = dataset_ops.Dataset.from_tensors(0)
dataset = dataset_ops._MaxIntraOpParallelismDataset(dataset, 1) dataset = dataset_ops._MaxIntraOpParallelismDataset(dataset, 1)

View File

@ -22,24 +22,25 @@ import math
import threading import threading
import time import time
from absl.testing import parameterized
import numpy as np import numpy as np
from six.moves import zip_longest from six.moves import zip_longest
from tensorflow.python.data.experimental.ops import interleave_ops from tensorflow.python.data.experimental.ops import interleave_ops
from tensorflow.python.data.kernel_tests import test_base from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import combinations
from tensorflow.python.framework import dtypes from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors from tensorflow.python.framework import errors
from tensorflow.python.framework import sparse_tensor from tensorflow.python.framework import sparse_tensor
from tensorflow.python.framework import test_util
from tensorflow.python.ops import math_ops from tensorflow.python.ops import math_ops
from tensorflow.python.ops import script_ops from tensorflow.python.ops import script_ops
from tensorflow.python.ops import sparse_ops from tensorflow.python.ops import sparse_ops
from tensorflow.python.platform import test from tensorflow.python.platform import test
@test_util.run_all_in_graph_and_eager_modes # TODO(feihugis): refactor this test to be parameterized.
class ParallelInterleaveTest(test_base.DatasetTestBase): class ParallelInterleaveTest(test_base.DatasetTestBase, parameterized.TestCase):
def setUp(self): def setUp(self):
@ -116,6 +117,7 @@ class ParallelInterleaveTest(test_base.DatasetTestBase):
num_open -= 1 num_open -= 1
break break
@combinations.generate(test_base.default_test_combinations())
def testPythonImplementation(self): def testPythonImplementation(self):
input_lists = [[4, 4, 4, 4], [5, 5, 5, 5, 5], [6, 6, 6, 6, 6, 6], input_lists = [[4, 4, 4, 4], [5, 5, 5, 5, 5], [6, 6, 6, 6, 6, 6],
[4, 4, 4, 4], [5, 5, 5, 5, 5], [6, 6, 6, 6, 6, 6]] [4, 4, 4, 4], [5, 5, 5, 5, 5], [6, 6, 6, 6, 6, 6]]
@ -136,6 +138,7 @@ class ParallelInterleaveTest(test_base.DatasetTestBase):
self.assertEqual(expected, produced, "Values differ at %s. %s != %s" % self.assertEqual(expected, produced, "Values differ at %s. %s != %s" %
(index, expected, produced)) (index, expected, produced))
@combinations.generate(test_base.default_test_combinations())
def testPythonImplementationBlockLength(self): def testPythonImplementationBlockLength(self):
input_lists = [[4] * 4, [5] * 5, [6] * 6] * 2 input_lists = [[4] * 4, [5] * 5, [6] * 6] * 2
expected_elements = [ expected_elements = [
@ -147,6 +150,7 @@ class ParallelInterleaveTest(test_base.DatasetTestBase):
self.assertEqual(expected, produced, "Values differ at %s. %s != %s" % self.assertEqual(expected, produced, "Values differ at %s. %s != %s" %
(index, expected, produced)) (index, expected, produced))
@combinations.generate(test_base.default_test_combinations())
def testPythonImplementationEmptyLists(self): def testPythonImplementationEmptyLists(self):
input_lists = [[4, 4, 4, 4], [], [6, 6, 6, 6, 6, 6], [4, 4, 4, 4], [], input_lists = [[4, 4, 4, 4], [], [6, 6, 6, 6, 6, 6], [4, 4, 4, 4], [],
[6, 6, 6, 6, 6, 6]] [6, 6, 6, 6, 6, 6]]
@ -189,18 +193,23 @@ class ParallelInterleaveTest(test_base.DatasetTestBase):
with self.assertRaises(errors.OutOfRangeError): with self.assertRaises(errors.OutOfRangeError):
self.evaluate(next_element()) self.evaluate(next_element())
@combinations.generate(test_base.default_test_combinations())
def testSingleThreaded(self): def testSingleThreaded(self):
self._testSingleThreaded() self._testSingleThreaded()
@combinations.generate(test_base.default_test_combinations())
def testSingleThreadedSloppy(self): def testSingleThreadedSloppy(self):
self._testSingleThreaded(sloppy=True) self._testSingleThreaded(sloppy=True)
@combinations.generate(test_base.default_test_combinations())
def testSingleThreadedPrefetch1Itr(self): def testSingleThreadedPrefetch1Itr(self):
self._testSingleThreaded(prefetch_input_elements=1) self._testSingleThreaded(prefetch_input_elements=1)
@combinations.generate(test_base.default_test_combinations())
def testSingleThreadedPrefetch1ItrSloppy(self): def testSingleThreadedPrefetch1ItrSloppy(self):
self._testSingleThreaded(prefetch_input_elements=1, sloppy=True) self._testSingleThreaded(prefetch_input_elements=1, sloppy=True)
@combinations.generate(test_base.default_test_combinations())
def testSingleThreadedRagged(self): def testSingleThreadedRagged(self):
# Tests a sequence with wildly different elements per iterator. # Tests a sequence with wildly different elements per iterator.
self.skipTest("b/131722904") self.skipTest("b/131722904")
@ -259,9 +268,11 @@ class ParallelInterleaveTest(test_base.DatasetTestBase):
with self.assertRaises(errors.OutOfRangeError): with self.assertRaises(errors.OutOfRangeError):
self.evaluate(next_element()) self.evaluate(next_element())
@combinations.generate(test_base.default_test_combinations())
def testTwoThreadsNoContention(self): def testTwoThreadsNoContention(self):
self._testTwoThreadsNoContention() self._testTwoThreadsNoContention()
@combinations.generate(test_base.default_test_combinations())
def testTwoThreadsNoContentionSloppy(self): def testTwoThreadsNoContentionSloppy(self):
self._testTwoThreadsNoContention(sloppy=True) self._testTwoThreadsNoContention(sloppy=True)
@ -306,9 +317,11 @@ class ParallelInterleaveTest(test_base.DatasetTestBase):
with self.assertRaises(errors.OutOfRangeError): with self.assertRaises(errors.OutOfRangeError):
self.evaluate(next_element()) self.evaluate(next_element())
@combinations.generate(test_base.default_test_combinations())
def testTwoThreadsNoContentionWithRaces(self): def testTwoThreadsNoContentionWithRaces(self):
self._testTwoThreadsNoContentionWithRaces() self._testTwoThreadsNoContentionWithRaces()
@combinations.generate(test_base.default_test_combinations())
def testTwoThreadsNoContentionWithRacesSloppy(self): def testTwoThreadsNoContentionWithRacesSloppy(self):
self._testTwoThreadsNoContentionWithRaces(sloppy=True) self._testTwoThreadsNoContentionWithRaces(sloppy=True)
@ -343,9 +356,11 @@ class ParallelInterleaveTest(test_base.DatasetTestBase):
with self.assertRaises(errors.OutOfRangeError): with self.assertRaises(errors.OutOfRangeError):
self.evaluate(next_element()) self.evaluate(next_element())
@combinations.generate(test_base.default_test_combinations())
def testTwoThreadsNoContentionBlockLength(self): def testTwoThreadsNoContentionBlockLength(self):
self._testTwoThreadsNoContentionBlockLength() self._testTwoThreadsNoContentionBlockLength()
@combinations.generate(test_base.default_test_combinations())
def testTwoThreadsNoContentionBlockLengthSloppy(self): def testTwoThreadsNoContentionBlockLengthSloppy(self):
self._testTwoThreadsNoContentionBlockLength(sloppy=True) self._testTwoThreadsNoContentionBlockLength(sloppy=True)
@ -391,9 +406,11 @@ class ParallelInterleaveTest(test_base.DatasetTestBase):
with self.assertRaises(errors.OutOfRangeError): with self.assertRaises(errors.OutOfRangeError):
self.evaluate(next_element()) self.evaluate(next_element())
@combinations.generate(test_base.default_test_combinations())
def testTwoThreadsNoContentionWithRacesAndBlocking(self): def testTwoThreadsNoContentionWithRacesAndBlocking(self):
self._testTwoThreadsNoContentionWithRacesAndBlocking() self._testTwoThreadsNoContentionWithRacesAndBlocking()
@combinations.generate(test_base.default_test_combinations())
def testTwoThreadsNoContentionWithRacesAndBlockingSloppy(self): def testTwoThreadsNoContentionWithRacesAndBlockingSloppy(self):
self._testTwoThreadsNoContentionWithRacesAndBlocking(sloppy=True) self._testTwoThreadsNoContentionWithRacesAndBlocking(sloppy=True)
@ -411,9 +428,11 @@ class ParallelInterleaveTest(test_base.DatasetTestBase):
with self.assertRaises(errors.OutOfRangeError): with self.assertRaises(errors.OutOfRangeError):
self.evaluate(next_element()) self.evaluate(next_element())
@combinations.generate(test_base.default_test_combinations())
def testEmptyInput(self): def testEmptyInput(self):
self._testEmptyInput() self._testEmptyInput()
@combinations.generate(test_base.default_test_combinations())
def testEmptyInputSloppy(self): def testEmptyInputSloppy(self):
self._testEmptyInput(sloppy=True) self._testEmptyInput(sloppy=True)
@ -431,9 +450,11 @@ class ParallelInterleaveTest(test_base.DatasetTestBase):
with self.assertRaises(errors.OutOfRangeError): with self.assertRaises(errors.OutOfRangeError):
self.evaluate(next_element()) self.evaluate(next_element())
@combinations.generate(test_base.default_test_combinations())
def testNonEmptyInputIntoEmptyOutputs(self): def testNonEmptyInputIntoEmptyOutputs(self):
self._testNonEmptyInputIntoEmptyOutputs() self._testNonEmptyInputIntoEmptyOutputs()
@combinations.generate(test_base.default_test_combinations())
def testNonEmptyInputIntoEmptyOutputsSloppy(self): def testNonEmptyInputIntoEmptyOutputsSloppy(self):
self._testNonEmptyInputIntoEmptyOutputs(sloppy=True) self._testNonEmptyInputIntoEmptyOutputs(sloppy=True)
@ -469,12 +490,15 @@ class ParallelInterleaveTest(test_base.DatasetTestBase):
"At index %s: %s expected, got: %s" % (i, expected_element, "At index %s: %s expected, got: %s" % (i, expected_element,
actual_element)) actual_element))
@combinations.generate(test_base.default_test_combinations())
def testPartiallyEmptyOutputs(self): def testPartiallyEmptyOutputs(self):
self._testPartiallyEmptyOutputs() self._testPartiallyEmptyOutputs()
@combinations.generate(test_base.default_test_combinations())
def testPartiallyEmptyOutputsSloppy(self): def testPartiallyEmptyOutputsSloppy(self):
self._testPartiallyEmptyOutputs(sloppy=True, prefetch_input_elements=0) self._testPartiallyEmptyOutputs(sloppy=True, prefetch_input_elements=0)
@combinations.generate(test_base.default_test_combinations())
def testDelayedOutputSloppy(self): def testDelayedOutputSloppy(self):
# Explicitly control the sequence of events to ensure we correctly avoid # Explicitly control the sequence of events to ensure we correctly avoid
# head-of-line blocking. # head-of-line blocking.
@ -500,6 +524,7 @@ class ParallelInterleaveTest(test_base.DatasetTestBase):
with self.assertRaises(errors.OutOfRangeError): with self.assertRaises(errors.OutOfRangeError):
self.evaluate(next_element()) self.evaluate(next_element())
@combinations.generate(test_base.default_test_combinations())
def testBlockLengthWithContentionSloppy(self): def testBlockLengthWithContentionSloppy(self):
self.skipTest("b/131722904") self.skipTest("b/131722904")
self._clear_coordination_events() self._clear_coordination_events()
@ -557,9 +582,11 @@ class ParallelInterleaveTest(test_base.DatasetTestBase):
self.read_coordination_events[i].acquire() self.read_coordination_events[i].acquire()
self.write_coordination_events[i].set() self.write_coordination_events[i].set()
@combinations.generate(test_base.default_test_combinations())
def testEarlyExit(self): def testEarlyExit(self):
self._testEarlyExit() self._testEarlyExit()
@combinations.generate(test_base.default_test_combinations())
def testEarlyExitSloppy(self): def testEarlyExitSloppy(self):
self._testEarlyExit(sloppy=True) self._testEarlyExit(sloppy=True)
@ -584,12 +611,15 @@ class ParallelInterleaveTest(test_base.DatasetTestBase):
[[4] * 4, [5] * 5, [6] * 6] * self.repeat_count, 1, 2) [[4] * 4, [5] * 5, [6] * 6] * self.repeat_count, 1, 2)
self.assertItemsEqual(output_values, expected_values) self.assertItemsEqual(output_values, expected_values)
@combinations.generate(test_base.default_test_combinations())
def testTooManyReaders(self): def testTooManyReaders(self):
self._testTooManyReaders() self._testTooManyReaders()
@combinations.generate(test_base.default_test_combinations())
def testTooManyReadersSloppy(self): def testTooManyReadersSloppy(self):
self._testTooManyReaders(sloppy=True) self._testTooManyReaders(sloppy=True)
@combinations.generate(test_base.default_test_combinations())
def testSparse(self): def testSparse(self):
def _map_fn(i): def _map_fn(i):
return sparse_tensor.SparseTensor( return sparse_tensor.SparseTensor(
@ -610,6 +640,7 @@ class ParallelInterleaveTest(test_base.DatasetTestBase):
with self.assertRaises(errors.OutOfRangeError): with self.assertRaises(errors.OutOfRangeError):
self.evaluate(get_next()) self.evaluate(get_next())
@combinations.generate(test_base.default_test_combinations())
def testErrorsInOutputFn(self): def testErrorsInOutputFn(self):
self.skipTest("b/131722904") self.skipTest("b/131722904")
self._clear_coordination_events() self._clear_coordination_events()
@ -642,6 +673,7 @@ class ParallelInterleaveTest(test_base.DatasetTestBase):
with self.assertRaises(errors.OutOfRangeError): with self.assertRaises(errors.OutOfRangeError):
self.evaluate(next_element()) self.evaluate(next_element())
@combinations.generate(test_base.default_test_combinations())
def testErrorsInInputFn(self): def testErrorsInInputFn(self):
def map_py_fn(x): def map_py_fn(x):
@ -687,6 +719,7 @@ class ParallelInterleaveTest(test_base.DatasetTestBase):
with self.assertRaises(errors.OutOfRangeError): with self.assertRaises(errors.OutOfRangeError):
self.evaluate(next_element()) self.evaluate(next_element())
@combinations.generate(test_base.default_test_combinations())
def testErrorsInInterleaveFn(self): def testErrorsInInterleaveFn(self):
def map_py_fn(x): def map_py_fn(x):
@ -730,6 +763,7 @@ class ParallelInterleaveTest(test_base.DatasetTestBase):
with self.assertRaises(errors.OutOfRangeError): with self.assertRaises(errors.OutOfRangeError):
self.evaluate(next_element()) self.evaluate(next_element())
@combinations.generate(test_base.default_test_combinations())
def testShutdownRace(self): def testShutdownRace(self):
dataset = dataset_ops.Dataset.range(20) dataset = dataset_ops.Dataset.range(20)
map_fn = lambda x: dataset_ops.Dataset.range(20 * x, 20 * (x + 1)) map_fn = lambda x: dataset_ops.Dataset.range(20 * x, 20 * (x + 1))

View File

@ -20,6 +20,7 @@ from __future__ import print_function
import copy import copy
from absl.testing import parameterized
import numpy as np import numpy as np
from tensorflow.core.example import example_pb2 from tensorflow.core.example import example_pb2
@ -28,11 +29,11 @@ from tensorflow.python.data.experimental.ops import parsing_ops as contrib_parsi
from tensorflow.python.data.kernel_tests import test_base from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.eager import context from tensorflow.python.eager import context
from tensorflow.python.framework import combinations
from tensorflow.python.framework import dtypes from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors_impl from tensorflow.python.framework import errors_impl
from tensorflow.python.framework import ops from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor from tensorflow.python.framework import sparse_tensor
from tensorflow.python.framework import test_util
from tensorflow.python.ops import parsing_ops from tensorflow.python.ops import parsing_ops
from tensorflow.python.ops.ragged import ragged_factory_ops from tensorflow.python.ops.ragged import ragged_factory_ops
from tensorflow.python.platform import test from tensorflow.python.platform import test
@ -50,8 +51,8 @@ feature_lists = lambda d: feature_pb2.FeatureLists(feature_list=d)
sequence_example = example_pb2.SequenceExample sequence_example = example_pb2.SequenceExample
@test_util.run_all_in_graph_and_eager_modes class ParseExampleDatasetTest(test_base.DatasetTestBase,
class ParseExampleDatasetTest(test_base.DatasetTestBase): parameterized.TestCase):
def _compare_output_to_expected(self, dict_tensors, expected_tensors): def _compare_output_to_expected(self, dict_tensors, expected_tensors):
self.assertEqual(set(dict_tensors.keys()), set(expected_tensors.keys())) self.assertEqual(set(dict_tensors.keys()), set(expected_tensors.keys()))
@ -107,6 +108,7 @@ class ParseExampleDatasetTest(test_base.DatasetTestBase):
self.assertEqual( self.assertEqual(
dataset_ops.get_legacy_output_shapes(dataset)[k].as_list()[1], None) dataset_ops.get_legacy_output_shapes(dataset)[k].as_list()[1], None)
@combinations.generate(test_base.default_test_combinations())
def testEmptySerializedWithAllDefaults(self): def testEmptySerializedWithAllDefaults(self):
sparse_name = "st_a" sparse_name = "st_a"
a_name = "a" a_name = "a"
@ -145,7 +147,7 @@ class ParseExampleDatasetTest(test_base.DatasetTestBase):
expected_values=expected_output, expected_values=expected_output,
create_iterator_twice=True) create_iterator_twice=True)
@test_util.run_deprecated_v1 @combinations.generate(test_base.graph_only_combinations())
def testEmptySerializedWithoutDefaultsShouldFail(self): def testEmptySerializedWithoutDefaultsShouldFail(self):
input_features = { input_features = {
"st_a": "st_a":
@ -179,7 +181,7 @@ class ParseExampleDatasetTest(test_base.DatasetTestBase):
expected_err=(errors_impl.InvalidArgumentError, expected_err=(errors_impl.InvalidArgumentError,
"Feature: c \\(data type: float\\) is required")) "Feature: c \\(data type: float\\) is required"))
@test_util.run_deprecated_v1 @combinations.generate(test_base.graph_only_combinations())
def testDenseNotMatchingShapeShouldFail(self): def testDenseNotMatchingShapeShouldFail(self):
original = [ original = [
example(features=features({ example(features=features({
@ -197,6 +199,7 @@ class ParseExampleDatasetTest(test_base.DatasetTestBase):
expected_err=(errors_impl.InvalidArgumentError, expected_err=(errors_impl.InvalidArgumentError,
"Key: a, Index: 1. Number of float values")) "Key: a, Index: 1. Number of float values"))
@combinations.generate(test_base.default_test_combinations())
def testDenseDefaultNoShapeShouldFail(self): def testDenseDefaultNoShapeShouldFail(self):
original = [example(features=features({"a": float_feature([1, 1, 3]),})),] original = [example(features=features({"a": float_feature([1, 1, 3]),})),]
@ -207,6 +210,7 @@ class ParseExampleDatasetTest(test_base.DatasetTestBase):
{"a": parsing_ops.FixedLenFeature(None, dtypes.float32)}, {"a": parsing_ops.FixedLenFeature(None, dtypes.float32)},
expected_err=(ValueError, "Missing shape for feature a")) expected_err=(ValueError, "Missing shape for feature a"))
@combinations.generate(test_base.default_test_combinations())
def testSerializedContainingSparse(self): def testSerializedContainingSparse(self):
original = [ original = [
example(features=features({ example(features=features({
@ -248,6 +252,7 @@ class ParseExampleDatasetTest(test_base.DatasetTestBase):
expected_values=expected_output, expected_values=expected_output,
create_iterator_twice=True) create_iterator_twice=True)
@combinations.generate(test_base.default_test_combinations())
def testSerializedContainingSparseFeature(self): def testSerializedContainingSparseFeature(self):
original = [ original = [
example(features=features({ example(features=features({
@ -284,6 +289,7 @@ class ParseExampleDatasetTest(test_base.DatasetTestBase):
expected_values=expected_output, expected_values=expected_output,
create_iterator_twice=True) create_iterator_twice=True)
@combinations.generate(test_base.default_test_combinations())
def testSerializedContainingSparseFeatureReuse(self): def testSerializedContainingSparseFeatureReuse(self):
original = [ original = [
example(features=features({ example(features=features({
@ -325,6 +331,7 @@ class ParseExampleDatasetTest(test_base.DatasetTestBase):
expected_values=expected_output, expected_values=expected_output,
create_iterator_twice=True) create_iterator_twice=True)
@combinations.generate(test_base.default_test_combinations())
def testSerializedContaining3DSparseFeature(self): def testSerializedContaining3DSparseFeature(self):
original = [ original = [
example(features=features({ example(features=features({
@ -370,6 +377,7 @@ class ParseExampleDatasetTest(test_base.DatasetTestBase):
expected_values=expected_output, expected_values=expected_output,
create_iterator_twice=True) create_iterator_twice=True)
@combinations.generate(test_base.default_test_combinations())
def testSerializedContainingDense(self): def testSerializedContainingDense(self):
aname = "a" aname = "a"
bname = "b*has+a:tricky_name" bname = "b*has+a:tricky_name"
@ -407,6 +415,7 @@ class ParseExampleDatasetTest(test_base.DatasetTestBase):
# This test is identical as the previous one except # This test is identical as the previous one except
# for the creation of 'serialized'. # for the creation of 'serialized'.
@combinations.generate(test_base.default_test_combinations())
def testSerializedContainingDenseWithConcat(self): def testSerializedContainingDenseWithConcat(self):
aname = "a" aname = "a"
bname = "b*has+a:tricky_name" bname = "b*has+a:tricky_name"
@ -452,6 +461,7 @@ class ParseExampleDatasetTest(test_base.DatasetTestBase):
expected_values=expected_output, expected_values=expected_output,
create_iterator_twice=True) create_iterator_twice=True)
@combinations.generate(test_base.default_test_combinations())
def testSerializedContainingDenseScalar(self): def testSerializedContainingDenseScalar(self):
original = [ original = [
example(features=features({ example(features=features({
@ -476,6 +486,7 @@ class ParseExampleDatasetTest(test_base.DatasetTestBase):
expected_values=expected_output, expected_values=expected_output,
create_iterator_twice=True) create_iterator_twice=True)
@combinations.generate(test_base.default_test_combinations())
def testSerializedContainingDenseWithDefaults(self): def testSerializedContainingDenseWithDefaults(self):
original = [ original = [
example(features=features({ example(features=features({
@ -514,6 +525,7 @@ class ParseExampleDatasetTest(test_base.DatasetTestBase):
expected_values=expected_output, expected_values=expected_output,
create_iterator_twice=True) create_iterator_twice=True)
@combinations.generate(test_base.default_test_combinations())
def testSerializedSparseAndSparseFeatureAndDenseWithNoDefault(self): def testSerializedSparseAndSparseFeatureAndDenseWithNoDefault(self):
expected_st_a = sparse_tensor.SparseTensorValue( # indices, values, shape expected_st_a = sparse_tensor.SparseTensorValue( # indices, values, shape
np.empty((0, 2), dtype=np.int64), # indices np.empty((0, 2), dtype=np.int64), # indices
@ -569,6 +581,7 @@ class ParseExampleDatasetTest(test_base.DatasetTestBase):
expected_values=expected_output, expected_values=expected_output,
create_iterator_twice=True) create_iterator_twice=True)
@combinations.generate(test_base.default_test_combinations())
def testerializedContainingSparseAndSparseFeatureWithReuse(self): def testerializedContainingSparseAndSparseFeatureWithReuse(self):
expected_idx = sparse_tensor.SparseTensorValue( # indices, values, shape expected_idx = sparse_tensor.SparseTensorValue( # indices, values, shape
np.array([[0, 0], [0, 1], [1, 0], [1, 1]], dtype=np.int64), np.array([[0, 0], [0, 1], [1, 0], [1, 1]], dtype=np.int64),
@ -667,11 +680,13 @@ class ParseExampleDatasetTest(test_base.DatasetTestBase):
expected_values=expected_output, expected_values=expected_output,
create_iterator_twice=True) create_iterator_twice=True)
@combinations.generate(test_base.default_test_combinations())
def testSerializedContainingVarLenDenseLargerBatch(self): def testSerializedContainingVarLenDenseLargerBatch(self):
np.random.seed(3456) np.random.seed(3456)
for batch_size in (1, 10, 20, 100, 256): for batch_size in (1, 10, 20, 100, 256):
self._testSerializedContainingVarLenDenseLargerBatch(batch_size) self._testSerializedContainingVarLenDenseLargerBatch(batch_size)
@combinations.generate(test_base.default_test_combinations())
def testSerializedShapeMismatch(self): def testSerializedShapeMismatch(self):
aname = "a" aname = "a"
bname = "b" bname = "b"
@ -724,7 +739,7 @@ class ParseExampleDatasetTest(test_base.DatasetTestBase):
expected_err=(ValueError, expected_err=(ValueError,
"Cannot reshape a tensor with 0 elements to shape")) "Cannot reshape a tensor with 0 elements to shape"))
@test_util.run_deprecated_v1 @combinations.generate(test_base.graph_only_combinations())
def testSerializedContainingVarLenDense(self): def testSerializedContainingVarLenDense(self):
aname = "a" aname = "a"
bname = "b" bname = "b"
@ -877,6 +892,7 @@ class ParseExampleDatasetTest(test_base.DatasetTestBase):
"Unsupported: FixedLenSequenceFeature requires " "Unsupported: FixedLenSequenceFeature requires "
"allow_missing to be True.")) "allow_missing to be True."))
@combinations.generate(test_base.default_test_combinations())
def testSerializedContainingRaggedFeatureWithNoPartitions(self): def testSerializedContainingRaggedFeatureWithNoPartitions(self):
original = [ original = [
example( example(
@ -922,6 +938,7 @@ class ParseExampleDatasetTest(test_base.DatasetTestBase):
expected_values=expected_output, expected_values=expected_output,
create_iterator_twice=True) create_iterator_twice=True)
@combinations.generate(test_base.default_test_combinations())
def testSerializedContainingRaggedFeatureWithOnePartition(self): def testSerializedContainingRaggedFeatureWithOnePartition(self):
original = [ original = [
example( example(
@ -1040,6 +1057,7 @@ class ParseExampleDatasetTest(test_base.DatasetTestBase):
expected_values=expected_output, expected_values=expected_output,
create_iterator_twice=True) create_iterator_twice=True)
@combinations.generate(test_base.default_test_combinations())
def testSerializedContainingRaggedFeatureWithMultiplePartitions(self): def testSerializedContainingRaggedFeatureWithMultiplePartitions(self):
original = [ original = [
# rt shape: [(batch), 2, None, None] # rt shape: [(batch), 2, None, None]

View File

@ -17,11 +17,14 @@ from __future__ import absolute_import
from __future__ import division from __future__ import division
from __future__ import print_function from __future__ import print_function
from absl.testing import parameterized
from tensorflow.core.protobuf import config_pb2 from tensorflow.core.protobuf import config_pb2
from tensorflow.python.data.experimental.ops import prefetching_ops from tensorflow.python.data.experimental.ops import prefetching_ops
from tensorflow.python.data.kernel_tests import test_base from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.util import structure from tensorflow.python.data.util import structure
from tensorflow.python.framework import combinations
from tensorflow.python.framework import dtypes from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors from tensorflow.python.framework import errors
from tensorflow.python.framework import ops from tensorflow.python.framework import ops
@ -31,9 +34,9 @@ from tensorflow.python.platform import test
# TODO(b/117581999): add eager coverage when supported. # TODO(b/117581999): add eager coverage when supported.
class PrefetchToDeviceTest(test_base.DatasetTestBase): class PrefetchToDeviceTest(test_base.DatasetTestBase, parameterized.TestCase):
@test_util.deprecated_graph_mode_only @combinations.generate(test_base.graph_only_combinations())
def testPrefetchToDevice(self): def testPrefetchToDevice(self):
host_dataset = dataset_ops.Dataset.range(10) host_dataset = dataset_ops.Dataset.range(10)
device_dataset = host_dataset.apply( device_dataset = host_dataset.apply(
@ -57,7 +60,7 @@ class PrefetchToDeviceTest(test_base.DatasetTestBase):
with self.assertRaises(errors.OutOfRangeError): with self.assertRaises(errors.OutOfRangeError):
self.evaluate(next_element) self.evaluate(next_element)
@test_util.deprecated_graph_mode_only @combinations.generate(test_base.graph_only_combinations())
def testPrefetchToSameDevice(self): def testPrefetchToSameDevice(self):
host_dataset = dataset_ops.Dataset.range(10) host_dataset = dataset_ops.Dataset.range(10)
device_dataset = host_dataset.apply( device_dataset = host_dataset.apply(
@ -82,7 +85,7 @@ class PrefetchToDeviceTest(test_base.DatasetTestBase):
with self.assertRaises(errors.OutOfRangeError): with self.assertRaises(errors.OutOfRangeError):
self.evaluate(next_element) self.evaluate(next_element)
@test_util.deprecated_graph_mode_only @combinations.generate(test_base.graph_only_combinations())
def testPrefetchDictToDevice(self): def testPrefetchDictToDevice(self):
host_dataset = dataset_ops.Dataset.range(10).map(lambda x: {"a": x}) host_dataset = dataset_ops.Dataset.range(10).map(lambda x: {"a": x})
device_dataset = host_dataset.apply( device_dataset = host_dataset.apply(
@ -106,7 +109,7 @@ class PrefetchToDeviceTest(test_base.DatasetTestBase):
with self.assertRaises(errors.OutOfRangeError): with self.assertRaises(errors.OutOfRangeError):
self.evaluate(next_element) self.evaluate(next_element)
@test_util.deprecated_graph_mode_only @combinations.generate(test_base.graph_only_combinations())
def testPrefetchSparseTensorsToDevice(self): def testPrefetchSparseTensorsToDevice(self):
def make_tensor(i): def make_tensor(i):
return sparse_tensor.SparseTensorValue( return sparse_tensor.SparseTensorValue(
@ -136,7 +139,7 @@ class PrefetchToDeviceTest(test_base.DatasetTestBase):
with self.assertRaises(errors.OutOfRangeError): with self.assertRaises(errors.OutOfRangeError):
self.evaluate(next_element) self.evaluate(next_element)
@test_util.deprecated_graph_mode_only @combinations.generate(test_base.graph_only_combinations())
def testPrefetchToDeviceGpu(self): def testPrefetchToDeviceGpu(self):
if not test_util.is_gpu_available(): if not test_util.is_gpu_available():
self.skipTest("No GPU available") self.skipTest("No GPU available")
@ -156,7 +159,7 @@ class PrefetchToDeviceTest(test_base.DatasetTestBase):
with self.assertRaises(errors.OutOfRangeError): with self.assertRaises(errors.OutOfRangeError):
self.evaluate(next_element) self.evaluate(next_element)
@test_util.deprecated_graph_mode_only @combinations.generate(test_base.graph_only_combinations())
def testPrefetchToDeviceWithReInit(self): def testPrefetchToDeviceWithReInit(self):
host_dataset = dataset_ops.Dataset.range(10) host_dataset = dataset_ops.Dataset.range(10)
device_dataset = host_dataset.apply( device_dataset = host_dataset.apply(
@ -184,7 +187,7 @@ class PrefetchToDeviceTest(test_base.DatasetTestBase):
with self.assertRaises(errors.OutOfRangeError): with self.assertRaises(errors.OutOfRangeError):
self.evaluate(next_element) self.evaluate(next_element)
@test_util.deprecated_graph_mode_only @combinations.generate(test_base.graph_only_combinations())
def testPrefetchToDeviceGpuWithReInit(self): def testPrefetchToDeviceGpuWithReInit(self):
if not test_util.is_gpu_available(): if not test_util.is_gpu_available():
self.skipTest("No GPU available") self.skipTest("No GPU available")

View File

@ -24,16 +24,17 @@ from tensorflow.core.protobuf import config_pb2
from tensorflow.python.data.kernel_tests import test_base from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.ops import multi_device_iterator_ops from tensorflow.python.data.ops import multi_device_iterator_ops
from tensorflow.python.framework import combinations
from tensorflow.python.framework import errors from tensorflow.python.framework import errors
from tensorflow.python.framework import ops from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
from tensorflow.python.platform import test from tensorflow.python.platform import test
@test_util.run_all_in_graph_and_eager_modes
class PrefetchWithSlackTest(test_base.DatasetTestBase, parameterized.TestCase): class PrefetchWithSlackTest(test_base.DatasetTestBase, parameterized.TestCase):
@test_util.run_v1_only("b/121264236") # TODO(b/121264236)
@combinations.generate(
combinations.combine(tf_api_version=[1], mode=["graph"]))
def testPrefetchWithSlackOption(self): def testPrefetchWithSlackOption(self):
"""Determines slack_period based on num devices attached to iterator.""" """Determines slack_period based on num devices attached to iterator."""
dataset = dataset_ops.Dataset.range(10) dataset = dataset_ops.Dataset.range(10)
@ -60,6 +61,7 @@ class PrefetchWithSlackTest(test_base.DatasetTestBase, parameterized.TestCase):
self.evaluate(elem_on_1) self.evaluate(elem_on_1)
self.evaluate(elem_on_2) self.evaluate(elem_on_2)
@combinations.generate(test_base.default_test_combinations())
def testPrefetchWithSlackOptionWithoutIterator(self): def testPrefetchWithSlackOptionWithoutIterator(self):
"""Defaults to slack period of 1 without iterator.""" """Defaults to slack period of 1 without iterator."""
dataset = dataset_ops.Dataset.range(10) dataset = dataset_ops.Dataset.range(10)
@ -72,6 +74,7 @@ class PrefetchWithSlackTest(test_base.DatasetTestBase, parameterized.TestCase):
dataset.options()._graph_rewrite_configs()) dataset.options()._graph_rewrite_configs())
self.assertDatasetProduces(dataset, range(10)) self.assertDatasetProduces(dataset, range(10))
@combinations.generate(test_base.default_test_combinations())
def testWithPassthroughDataset(self): def testWithPassthroughDataset(self):
"""Should still work with a passthrough dataset after prefetch().""" """Should still work with a passthrough dataset after prefetch()."""
dataset = dataset_ops.Dataset.range(10) dataset = dataset_ops.Dataset.range(10)
@ -82,6 +85,7 @@ class PrefetchWithSlackTest(test_base.DatasetTestBase, parameterized.TestCase):
dataset = dataset.with_options(options) dataset = dataset.with_options(options)
self.assertDatasetProduces(dataset, range(1, 11)) self.assertDatasetProduces(dataset, range(1, 11))
@combinations.generate(test_base.default_test_combinations())
def testErrorWithoutPrefetch(self): def testErrorWithoutPrefetch(self):
"""The rewrite fails if there is no prefetch() in the pipeline.""" """The rewrite fails if there is no prefetch() in the pipeline."""
dataset = dataset_ops.Dataset.range(10) dataset = dataset_ops.Dataset.range(10)
@ -92,6 +96,7 @@ class PrefetchWithSlackTest(test_base.DatasetTestBase, parameterized.TestCase):
get_next = self.getNext(dataset) get_next = self.getNext(dataset)
self.evaluate(get_next()) self.evaluate(get_next())
@combinations.generate(test_base.default_test_combinations())
def testErrorWithInvalidDataset(self): def testErrorWithInvalidDataset(self):
"""With a nested dataset op after prefetch, the rewrite should fail.""" """With a nested dataset op after prefetch, the rewrite should fail."""
dataset = dataset_ops.Dataset.range(10) dataset = dataset_ops.Dataset.range(10)

View File

@ -32,8 +32,8 @@ from tensorflow.python.data.experimental.ops import scan_ops
from tensorflow.python.data.kernel_tests import test_base from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.util import nest from tensorflow.python.data.util import nest
from tensorflow.python.framework import combinations
from tensorflow.python.framework import dtypes from tensorflow.python.framework import dtypes
from tensorflow.python.framework import test_util
from tensorflow.python.lib.io import python_io from tensorflow.python.lib.io import python_io
from tensorflow.python.ops import array_ops from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops from tensorflow.python.ops import math_ops
@ -47,13 +47,11 @@ def _flat_shapes(dataset):
return nest.flatten(dataset_ops.get_legacy_output_shapes(dataset)) return nest.flatten(dataset_ops.get_legacy_output_shapes(dataset))
@test_util.run_all_in_graph_and_eager_modes
class RebatchDatasetTest(test_base.DatasetTestBase, parameterized.TestCase): class RebatchDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
drop_remainder_cases = [("WithDropRemainder", True), @combinations.generate(
("WithoutDropRemainder", False)] combinations.times(test_base.default_test_combinations(),
combinations.combine(drop_remainder=[True, False])))
@parameterized.named_parameters(drop_remainder_cases)
def testBasic(self, drop_remainder): def testBasic(self, drop_remainder):
dataset = dataset_ops.Dataset.range(1024).batch( dataset = dataset_ops.Dataset.range(1024).batch(
32, drop_remainder=drop_remainder) 32, drop_remainder=drop_remainder)
@ -64,13 +62,16 @@ class RebatchDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
expected_output = [[k for k in range(i, i + 8)] for i in range(0, 1024, 8)] # pylint: disable=g-complex-comprehension expected_output = [[k for k in range(i, i + 8)] for i in range(0, 1024, 8)] # pylint: disable=g-complex-comprehension
self.assertDatasetProduces(rebatched_dataset, expected_output) self.assertDatasetProduces(rebatched_dataset, expected_output)
@combinations.generate(test_base.default_test_combinations())
def testScalarInputError(self): def testScalarInputError(self):
dataset = dataset_ops.Dataset.range(1024) dataset = dataset_ops.Dataset.range(1024)
distribute._RebatchDataset(dataset.batch(4), num_replicas=4) distribute._RebatchDataset(dataset.batch(4), num_replicas=4)
with self.assertRaisesRegexp(ValueError, "at least one dimension"): with self.assertRaisesRegexp(ValueError, "at least one dimension"):
distribute._RebatchDataset(dataset, num_replicas=4) distribute._RebatchDataset(dataset, num_replicas=4)
@parameterized.named_parameters(drop_remainder_cases) @combinations.generate(
combinations.times(test_base.default_test_combinations(),
combinations.combine(drop_remainder=[True, False])))
def testBatchNotDivisibleByNumReplicas(self, drop_remainder): def testBatchNotDivisibleByNumReplicas(self, drop_remainder):
dataset = dataset_ops.Dataset.range(1024).batch( dataset = dataset_ops.Dataset.range(1024).batch(
32, drop_remainder=drop_remainder) 32, drop_remainder=drop_remainder)
@ -89,6 +90,7 @@ class RebatchDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
i += 4 i += 4
self.assertDatasetProduces(rebatched_dataset, expected_output) self.assertDatasetProduces(rebatched_dataset, expected_output)
@combinations.generate(test_base.default_test_combinations())
def testBatchSizeNotDivisibleByNumReplicas2(self): def testBatchSizeNotDivisibleByNumReplicas2(self):
dataset = dataset_ops.Dataset.range(32).batch(16, drop_remainder=True) dataset = dataset_ops.Dataset.range(32).batch(16, drop_remainder=True)
rebatched_dataset = distribute._RebatchDataset(dataset, num_replicas=5) rebatched_dataset = distribute._RebatchDataset(dataset, num_replicas=5)
@ -102,6 +104,7 @@ class RebatchDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
expected_output.extend([[]]) # Last replica gets an empty batch expected_output.extend([[]]) # Last replica gets an empty batch
self.assertDatasetProduces(rebatched_dataset, expected_output) self.assertDatasetProduces(rebatched_dataset, expected_output)
@combinations.generate(test_base.default_test_combinations())
def testTupleOutput(self): def testTupleOutput(self):
dataset = dataset_ops.Dataset.range(1024).map(lambda x: (x, x)).batch(32) dataset = dataset_ops.Dataset.range(1024).map(lambda x: (x, x)).batch(32)
rebatched_dataset = distribute._RebatchDataset(dataset, num_replicas=4) rebatched_dataset = distribute._RebatchDataset(dataset, num_replicas=4)
@ -110,6 +113,7 @@ class RebatchDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
for i in range(0, 1024, 8)] for i in range(0, 1024, 8)]
self.assertDatasetProduces(rebatched_dataset, expected_output) self.assertDatasetProduces(rebatched_dataset, expected_output)
@combinations.generate(test_base.default_test_combinations())
def testNestedDictionaryOutput(self): def testNestedDictionaryOutput(self):
dataset = dataset_ops.Dataset.range(1024).map( dataset = dataset_ops.Dataset.range(1024).map(
lambda x: {"a": x, "b": {"c": x}}).batch(32) lambda x: {"a": x, "b": {"c": x}}).batch(32)
@ -119,7 +123,9 @@ class RebatchDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
for i in range(0, 1024, 8)] for i in range(0, 1024, 8)]
self.assertDatasetProduces(rebatched_dataset, expected_output) self.assertDatasetProduces(rebatched_dataset, expected_output)
@parameterized.named_parameters(drop_remainder_cases) @combinations.generate(
combinations.times(test_base.default_test_combinations(),
combinations.combine(drop_remainder=[True, False])))
def testFinalPartialBatch(self, drop_remainder): def testFinalPartialBatch(self, drop_remainder):
dataset = dataset_ops.Dataset.range(1032).batch( dataset = dataset_ops.Dataset.range(1032).batch(
32, drop_remainder=drop_remainder) 32, drop_remainder=drop_remainder)
@ -136,7 +142,9 @@ class RebatchDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
[[k for k in range(i, i + 2)] for i in range(1024, 1032, 2)]) [[k for k in range(i, i + 2)] for i in range(1024, 1032, 2)])
self.assertDatasetProduces(rebatched_dataset, expected_output) self.assertDatasetProduces(rebatched_dataset, expected_output)
@parameterized.named_parameters(drop_remainder_cases) @combinations.generate(
combinations.times(test_base.default_test_combinations(),
combinations.combine(drop_remainder=[True, False])))
def testFinalPartialBatchAfterRebatch(self, drop_remainder): def testFinalPartialBatchAfterRebatch(self, drop_remainder):
dataset = dataset_ops.Dataset.range(34).batch( dataset = dataset_ops.Dataset.range(34).batch(
32, drop_remainder=drop_remainder) 32, drop_remainder=drop_remainder)
@ -150,6 +158,7 @@ class RebatchDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
expected_output += [[32], [33], [], []] expected_output += [[32], [33], [], []]
self.assertDatasetProduces(rebatched_dataset, expected_output) self.assertDatasetProduces(rebatched_dataset, expected_output)
@combinations.generate(test_base.default_test_combinations())
def testMultipleBatches(self): def testMultipleBatches(self):
dataset = dataset_ops.Dataset.range(128).batch(4).batch(8) dataset = dataset_ops.Dataset.range(128).batch(4).batch(8)
self.assertEqual([[None, None]], self.assertEqual([[None, None]],
@ -170,6 +179,7 @@ class RebatchDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
for i in range(0, 128, 8)] for i in range(0, 128, 8)]
self.assertDatasetProduces(rebatched_dataset, expected_output) self.assertDatasetProduces(rebatched_dataset, expected_output)
@combinations.generate(test_base.default_test_combinations())
def testMapAndBatch(self): def testMapAndBatch(self):
dataset = dataset_ops.Dataset.range(1024).apply( dataset = dataset_ops.Dataset.range(1024).apply(
batching.map_and_batch(math_ops.square, 32)) batching.map_and_batch(math_ops.square, 32))
@ -180,6 +190,7 @@ class RebatchDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
for i in range(0, 1024, 8)] for i in range(0, 1024, 8)]
self.assertDatasetProduces(rebatched_dataset, expected_output) self.assertDatasetProduces(rebatched_dataset, expected_output)
@combinations.generate(test_base.default_test_combinations())
def testMapAndBatchWithCapturedInput(self): def testMapAndBatchWithCapturedInput(self):
captured_t = variables.Variable(42) captured_t = variables.Variable(42)
dataset = dataset_ops.Dataset.range(1024).apply( dataset = dataset_ops.Dataset.range(1024).apply(
@ -193,6 +204,7 @@ class RebatchDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
self.assertDatasetProduces( self.assertDatasetProduces(
rebatched_dataset, expected_output, requires_initialization=True) rebatched_dataset, expected_output, requires_initialization=True)
@combinations.generate(test_base.default_test_combinations())
def testPaddedBatch(self): def testPaddedBatch(self):
dataset = dataset_ops.Dataset.range(128).batch( dataset = dataset_ops.Dataset.range(128).batch(
4, drop_remainder=True).padded_batch( 4, drop_remainder=True).padded_batch(
@ -213,6 +225,7 @@ class RebatchDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
for i in range(0, 128, 8)] for i in range(0, 128, 8)]
self.assertDatasetProduces(rebatched_dataset, expected_output) self.assertDatasetProduces(rebatched_dataset, expected_output)
@combinations.generate(test_base.default_test_combinations())
def testConcatenate(self): def testConcatenate(self):
dataset1 = dataset_ops.Dataset.range(64).batch(8) dataset1 = dataset_ops.Dataset.range(64).batch(8)
dataset2 = dataset_ops.Dataset.range(32).batch(8) dataset2 = dataset_ops.Dataset.range(32).batch(8)
@ -224,6 +237,7 @@ class RebatchDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
[[i, i + 1] for i in range(0, 32, 2)]) [[i, i + 1] for i in range(0, 32, 2)])
self.assertDatasetProduces(rebatched_dataset, expected_output) self.assertDatasetProduces(rebatched_dataset, expected_output)
@combinations.generate(test_base.default_test_combinations())
def testConcatenateDifferentShapes(self): def testConcatenateDifferentShapes(self):
dataset1 = dataset_ops.Dataset.range(64).batch(16) dataset1 = dataset_ops.Dataset.range(64).batch(16)
dataset2 = dataset_ops.Dataset.range(32).batch(8) dataset2 = dataset_ops.Dataset.range(32).batch(8)
@ -235,6 +249,7 @@ class RebatchDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
[[i, i + 1] for i in range(0, 32, 2)]) [[i, i + 1] for i in range(0, 32, 2)])
self.assertDatasetProduces(rebatched_dataset, expected_output) self.assertDatasetProduces(rebatched_dataset, expected_output)
@combinations.generate(test_base.default_test_combinations())
def testZip(self): def testZip(self):
dataset1 = dataset_ops.Dataset.range(64).batch(8) dataset1 = dataset_ops.Dataset.range(64).batch(8)
dataset2 = dataset_ops.Dataset.range(32).batch(8) dataset2 = dataset_ops.Dataset.range(32).batch(8)
@ -245,6 +260,7 @@ class RebatchDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
expected_output = [([i, i + 1], [i, i + 1]) for i in range(0, 32, 2)] expected_output = [([i, i + 1], [i, i + 1]) for i in range(0, 32, 2)]
self.assertDatasetProduces(rebatched_dataset, expected_output) self.assertDatasetProduces(rebatched_dataset, expected_output)
@combinations.generate(test_base.default_test_combinations())
def testZipDifferentShapes(self): def testZipDifferentShapes(self):
dataset1 = dataset_ops.Dataset.range(64).batch(16) dataset1 = dataset_ops.Dataset.range(64).batch(16)
dataset2 = dataset_ops.Dataset.range(32).batch(8) dataset2 = dataset_ops.Dataset.range(32).batch(8)
@ -256,6 +272,7 @@ class RebatchDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
for i in range(0, 32, 2)] for i in range(0, 32, 2)]
self.assertDatasetProduces(rebatched_dataset, expected_output) self.assertDatasetProduces(rebatched_dataset, expected_output)
@combinations.generate(test_base.default_test_combinations())
def testFlatMapBatching(self): def testFlatMapBatching(self):
dataset = dataset_ops.Dataset.range(2).flat_map( dataset = dataset_ops.Dataset.range(2).flat_map(
lambda _: dataset_ops.Dataset.range(32).batch( # pylint: disable=g-long-lambda lambda _: dataset_ops.Dataset.range(32).batch( # pylint: disable=g-long-lambda
@ -274,6 +291,7 @@ class RebatchDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
for i in range(0, 32, 8)] # generates 4 elements for i in range(0, 32, 8)] # generates 4 elements
self.assertDatasetProduces(rebatched_dataset, expected_output) self.assertDatasetProduces(rebatched_dataset, expected_output)
@combinations.generate(test_base.default_test_combinations())
def testInterleaveBatching(self): def testInterleaveBatching(self):
dataset = dataset_ops.Dataset.range(2).interleave( dataset = dataset_ops.Dataset.range(2).interleave(
lambda _: dataset_ops.Dataset.range(32).batch( # pylint: disable=g-long-lambda lambda _: dataset_ops.Dataset.range(32).batch( # pylint: disable=g-long-lambda
@ -290,6 +308,7 @@ class RebatchDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
expected_output += expected_output expected_output += expected_output
self.assertDatasetProduces(rebatched_dataset, expected_output) self.assertDatasetProduces(rebatched_dataset, expected_output)
@combinations.generate(test_base.default_test_combinations())
def testParallelInterleaveBatching(self): def testParallelInterleaveBatching(self):
dataset = dataset_ops.Dataset.range(2).interleave( dataset = dataset_ops.Dataset.range(2).interleave(
lambda _: dataset_ops.Dataset.range(32).batch( # pylint: disable=g-long-lambda lambda _: dataset_ops.Dataset.range(32).batch( # pylint: disable=g-long-lambda
@ -307,6 +326,7 @@ class RebatchDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
expected_output += expected_output expected_output += expected_output
self.assertDatasetProduces(rebatched_dataset, expected_output) self.assertDatasetProduces(rebatched_dataset, expected_output)
@combinations.generate(test_base.default_test_combinations())
def testGroupByWindowStaticBatch(self): def testGroupByWindowStaticBatch(self):
dataset = dataset_ops.Dataset.from_tensor_slices( dataset = dataset_ops.Dataset.from_tensor_slices(
[[array_ops.constant(i, dtype=dtypes.int64)] * 3 for i in range(40)]) [[array_ops.constant(i, dtype=dtypes.int64)] * 3 for i in range(40)])
@ -326,6 +346,7 @@ class RebatchDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
for k in range(2)] for k in range(2)]
self.assertDatasetProduces(rebatched_dataset, expected_output) self.assertDatasetProduces(rebatched_dataset, expected_output)
@combinations.generate(test_base.default_test_combinations())
def testGroupByWindowDynamicBatch(self): def testGroupByWindowDynamicBatch(self):
# {0, 1, 0, 1, ...} # {0, 1, 0, 1, ...}
dataset = dataset_ops.Dataset.range(40).map(lambda x: x % 2) dataset = dataset_ops.Dataset.range(40).map(lambda x: x % 2)
@ -350,6 +371,7 @@ class RebatchDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
expected_output = [[value] * batch_size for batch_size, value in pairs] expected_output = [[value] * batch_size for batch_size, value in pairs]
self.assertDatasetProduces(dataset, expected_output) self.assertDatasetProduces(dataset, expected_output)
@combinations.generate(test_base.default_test_combinations())
def testGroupByWindowDynamicBatchWithPartialBatch(self): def testGroupByWindowDynamicBatchWithPartialBatch(self):
# {0, 1, 0, 1, ...} # {0, 1, 0, 1, ...}
dataset = dataset_ops.Dataset.range(40).map(lambda x: x % 2) dataset = dataset_ops.Dataset.range(40).map(lambda x: x % 2)
@ -371,6 +393,7 @@ class RebatchDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
expected_output = [[value] * batch_size for batch_size, value in pairs] expected_output = [[value] * batch_size for batch_size, value in pairs]
self.assertDatasetProduces(dataset, expected_output) self.assertDatasetProduces(dataset, expected_output)
@combinations.generate(test_base.default_test_combinations())
def testGroupByWindowDynamicBatchWithPartialBatchWithDropRemainder(self): def testGroupByWindowDynamicBatchWithPartialBatchWithDropRemainder(self):
# This test exercises nested batch functionality, dynamic batch size # This test exercises nested batch functionality, dynamic batch size
# and drop_remainder=True together. # and drop_remainder=True together.
@ -395,6 +418,7 @@ class RebatchDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
expected_output = [[value] * batch_size for batch_size, value in pairs] expected_output = [[value] * batch_size for batch_size, value in pairs]
self.assertDatasetProduces(dataset, expected_output) self.assertDatasetProduces(dataset, expected_output)
@combinations.generate(test_base.default_test_combinations())
def testScanAfterBatch(self): def testScanAfterBatch(self):
dataset = dataset_ops.Dataset.range(40).batch(10).apply( dataset = dataset_ops.Dataset.range(40).batch(10).apply(
scan_ops.scan(np.int64(2), lambda state, value: (state, value * state))) scan_ops.scan(np.int64(2), lambda state, value: (state, value * state)))
@ -405,6 +429,7 @@ class RebatchDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
expected_output = [[i * 2 for i in range(j*5, (j+1)*5)] for j in range(8)] # pylint: disable=g-complex-comprehension expected_output = [[i * 2 for i in range(j*5, (j+1)*5)] for j in range(8)] # pylint: disable=g-complex-comprehension
self.assertDatasetProduces(dataset, expected_output) self.assertDatasetProduces(dataset, expected_output)
@combinations.generate(test_base.default_test_combinations())
def testMakeBatchedFeaturesDataset(self): def testMakeBatchedFeaturesDataset(self):
# Set up # Set up
fn = os.path.join(self.get_temp_dir(), "tf_record.txt") fn = os.path.join(self.get_temp_dir(), "tf_record.txt")
@ -438,6 +463,7 @@ class RebatchDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
} for i in range(0, 1024, 8)] # pylint: disable=g-complex-comprehension } for i in range(0, 1024, 8)] # pylint: disable=g-complex-comprehension
self.assertDatasetProduces(rebatched_dataset, expected_output) self.assertDatasetProduces(rebatched_dataset, expected_output)
@combinations.generate(test_base.default_test_combinations())
def testRaggedTensorDataset(self): def testRaggedTensorDataset(self):
# Set up a dataset that produces ragged tensors with a static batch size. # Set up a dataset that produces ragged tensors with a static batch size.
row_lengths = np.random.randint(8, size=128) row_lengths = np.random.randint(8, size=128)

View File

@ -24,9 +24,9 @@ import numpy as np
from tensorflow.python.data.experimental.ops import resampling from tensorflow.python.data.experimental.ops import resampling
from tensorflow.python.data.kernel_tests import test_base from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import combinations
from tensorflow.python.framework import dtypes from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors from tensorflow.python.framework import errors
from tensorflow.python.framework import test_util
from tensorflow.python.ops import math_ops from tensorflow.python.ops import math_ops
from tensorflow.python.ops import random_ops from tensorflow.python.ops import random_ops
from tensorflow.python.ops import string_ops from tensorflow.python.ops import string_ops
@ -34,12 +34,11 @@ from tensorflow.python.platform import test
from tensorflow.python.util import compat from tensorflow.python.util import compat
@test_util.run_all_in_graph_and_eager_modes
class RejectionResampleTest(test_base.DatasetTestBase, parameterized.TestCase): class RejectionResampleTest(test_base.DatasetTestBase, parameterized.TestCase):
@parameterized.named_parameters( @combinations.generate(
("InitialDistributionKnown", True), combinations.times(test_base.default_test_combinations(),
("InitialDistributionUnknown", False)) combinations.combine(initial_known=[True, False])))
def testDistribution(self, initial_known): def testDistribution(self, initial_known):
classes = np.random.randint(5, size=(20000,)) # Uniformly sampled classes = np.random.randint(5, size=(20000,)) # Uniformly sampled
target_dist = [0.9, 0.05, 0.05, 0.0, 0.0] target_dist = [0.9, 0.05, 0.05, 0.0, 0.0]
@ -72,9 +71,9 @@ class RejectionResampleTest(test_base.DatasetTestBase, parameterized.TestCase):
returned_dist = class_counts / total_returned returned_dist = class_counts / total_returned
self.assertAllClose(target_dist, returned_dist, atol=1e-2) self.assertAllClose(target_dist, returned_dist, atol=1e-2)
@parameterized.named_parameters( @combinations.generate(
("OnlyInitial", True), combinations.times(test_base.default_test_combinations(),
("NotInitial", False)) combinations.combine(only_initial_dist=[True, False])))
def testEdgeCasesSampleFromInitialDataset(self, only_initial_dist): def testEdgeCasesSampleFromInitialDataset(self, only_initial_dist):
init_dist = [0.5, 0.5] init_dist = [0.5, 0.5]
target_dist = [0.5, 0.5] if only_initial_dist else [0.0, 1.0] target_dist = [0.5, 0.5] if only_initial_dist else [0.0, 1.0]
@ -99,6 +98,7 @@ class RejectionResampleTest(test_base.DatasetTestBase, parameterized.TestCase):
while True: while True:
returned.append(self.evaluate(get_next())) returned.append(self.evaluate(get_next()))
@combinations.generate(test_base.default_test_combinations())
def testRandomClasses(self): def testRandomClasses(self):
init_dist = [0.25, 0.25, 0.25, 0.25] init_dist = [0.25, 0.25, 0.25, 0.25]
target_dist = [0.0, 0.0, 0.0, 1.0] target_dist = [0.0, 0.0, 0.0, 1.0]

View File

@ -17,18 +17,18 @@ from __future__ import absolute_import
from __future__ import division from __future__ import division
from __future__ import print_function from __future__ import print_function
from absl.testing import parameterized
import numpy as np import numpy as np
from tensorflow.python.data.experimental.ops import shuffle_ops from tensorflow.python.data.experimental.ops import shuffle_ops
from tensorflow.python.data.kernel_tests import test_base from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import combinations
from tensorflow.python.framework import errors from tensorflow.python.framework import errors
from tensorflow.python.framework import test_util
from tensorflow.python.platform import test from tensorflow.python.platform import test
@test_util.run_all_in_graph_and_eager_modes class ShuffleAndRepeatTest(test_base.DatasetTestBase, parameterized.TestCase):
class ShuffleAndRepeatTest(test_base.DatasetTestBase):
def _build_ds(self, seed, count=5, num_elements=20): def _build_ds(self, seed, count=5, num_elements=20):
return dataset_ops.Dataset.range(num_elements).apply( return dataset_ops.Dataset.range(num_elements).apply(
@ -44,6 +44,7 @@ class ShuffleAndRepeatTest(test_base.DatasetTestBase):
self.evaluate(get_next()) self.evaluate(get_next())
return outputs return outputs
@combinations.generate(test_base.default_test_combinations())
def testCorrectOutput(self): def testCorrectOutput(self):
output = self._gen_outputs(lambda: self._build_ds(10), 100) output = self._gen_outputs(lambda: self._build_ds(10), 100)
self.assertSequenceEqual( self.assertSequenceEqual(
@ -52,6 +53,7 @@ class ShuffleAndRepeatTest(test_base.DatasetTestBase):
for i in range(5): for i in range(5):
self.assertSequenceEqual(sorted(output[i * 20:(i + 1) * 20]), range(20)) self.assertSequenceEqual(sorted(output[i * 20:(i + 1) * 20]), range(20))
@combinations.generate(test_base.default_test_combinations())
def testReshuffling(self): def testReshuffling(self):
# Check that the output orders of different epochs are indeed different. # Check that the output orders of different epochs are indeed different.
output = self._gen_outputs(lambda: self._build_ds(10), 100) output = self._gen_outputs(lambda: self._build_ds(10), 100)
@ -60,17 +62,20 @@ class ShuffleAndRepeatTest(test_base.DatasetTestBase):
epoch2 = output[(i + 1) * 20:(i + 2) * 20] epoch2 = output[(i + 1) * 20:(i + 2) * 20]
self.assertNotEqual(epoch1, epoch2) self.assertNotEqual(epoch1, epoch2)
@combinations.generate(test_base.default_test_combinations())
def testSameOrderForSameSeeds(self): def testSameOrderForSameSeeds(self):
output1 = self._gen_outputs(lambda: self._build_ds(10), 100) output1 = self._gen_outputs(lambda: self._build_ds(10), 100)
output2 = self._gen_outputs(lambda: self._build_ds(10), 100) output2 = self._gen_outputs(lambda: self._build_ds(10), 100)
self.assertEqual(output1, output2) self.assertEqual(output1, output2)
@combinations.generate(test_base.default_test_combinations())
def testDifferentOrderForDifferentSeeds(self): def testDifferentOrderForDifferentSeeds(self):
output1 = self._gen_outputs(lambda: self._build_ds(10), 100) output1 = self._gen_outputs(lambda: self._build_ds(10), 100)
output2 = self._gen_outputs(lambda: self._build_ds(20), 100) output2 = self._gen_outputs(lambda: self._build_ds(20), 100)
self.assertNotEqual(output1, output2) self.assertNotEqual(output1, output2)
self.assertEqual(sorted(output1), sorted(output2)) self.assertEqual(sorted(output1), sorted(output2))
@combinations.generate(test_base.default_test_combinations())
def testCountNone(self): def testCountNone(self):
output1 = self._gen_outputs( output1 = self._gen_outputs(
lambda: self._build_ds(10, count=None), 100, verify_exhausted=False) lambda: self._build_ds(10, count=None), 100, verify_exhausted=False)
@ -79,6 +84,7 @@ class ShuffleAndRepeatTest(test_base.DatasetTestBase):
self.assertNotEqual(output1, output2) self.assertNotEqual(output1, output2)
self.assertEqual(sorted(output1), sorted(output2)) self.assertEqual(sorted(output1), sorted(output2))
@combinations.generate(test_base.default_test_combinations())
def testCountMinusOne(self): def testCountMinusOne(self):
output1 = self._gen_outputs( output1 = self._gen_outputs(
lambda: self._build_ds(10, count=-1), 100, verify_exhausted=False) lambda: self._build_ds(10, count=-1), 100, verify_exhausted=False)
@ -87,6 +93,7 @@ class ShuffleAndRepeatTest(test_base.DatasetTestBase):
self.assertNotEqual(output1, output2) self.assertNotEqual(output1, output2)
self.assertEqual(sorted(output1), sorted(output2)) self.assertEqual(sorted(output1), sorted(output2))
@combinations.generate(test_base.default_test_combinations())
def testInfiniteOutputs(self): def testInfiniteOutputs(self):
# Asserting the iterator is exhausted after producing 100 items should fail. # Asserting the iterator is exhausted after producing 100 items should fail.
with self.assertRaises(AssertionError): with self.assertRaises(AssertionError):
@ -94,6 +101,7 @@ class ShuffleAndRepeatTest(test_base.DatasetTestBase):
with self.assertRaises(AssertionError): with self.assertRaises(AssertionError):
self._gen_outputs(lambda: self._build_ds(10, count=-1), 100) self._gen_outputs(lambda: self._build_ds(10, count=-1), 100)
@combinations.generate(test_base.default_test_combinations())
def testInfiniteEmpty(self): def testInfiniteEmpty(self):
with self.assertRaises(errors.OutOfRangeError): with self.assertRaises(errors.OutOfRangeError):
self._gen_outputs(lambda: self._build_ds(10, count=None, num_elements=0), self._gen_outputs(lambda: self._build_ds(10, count=None, num_elements=0),
@ -102,12 +110,14 @@ class ShuffleAndRepeatTest(test_base.DatasetTestBase):
self._gen_outputs(lambda: self._build_ds(10, count=-1, num_elements=0), self._gen_outputs(lambda: self._build_ds(10, count=-1, num_elements=0),
100) 100)
@combinations.generate(test_base.default_test_combinations())
def testLargeBufferSize(self): def testLargeBufferSize(self):
ds = dataset_ops.Dataset.range(20).apply( ds = dataset_ops.Dataset.range(20).apply(
shuffle_ops.shuffle_and_repeat(buffer_size=21)) shuffle_ops.shuffle_and_repeat(buffer_size=21))
get_next = self.getNext(ds) get_next = self.getNext(ds)
self.evaluate(get_next()) self.evaluate(get_next())
@combinations.generate(test_base.default_test_combinations())
def testVeryLargeBufferSize(self): def testVeryLargeBufferSize(self):
num_epochs = 1000 * 1000 num_epochs = 1000 * 1000
# Each element being shuffled and repeated has shape (100,). This will OOM # Each element being shuffled and repeated has shape (100,). This will OOM

View File

@ -18,18 +18,22 @@ from __future__ import absolute_import
from __future__ import division from __future__ import division
from __future__ import print_function from __future__ import print_function
from absl.testing import parameterized
from tensorflow.python.data.experimental.kernel_tests import sql_dataset_test_base from tensorflow.python.data.experimental.kernel_tests import sql_dataset_test_base
from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.framework import combinations
from tensorflow.python.framework import dtypes from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors from tensorflow.python.framework import errors
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops from tensorflow.python.ops import array_ops
from tensorflow.python.platform import test from tensorflow.python.platform import test
@test_util.run_all_in_graph_and_eager_modes class SqlDatasetTest(sql_dataset_test_base.SqlDatasetTestBase,
class SqlDatasetTest(sql_dataset_test_base.SqlDatasetTestBase): parameterized.TestCase):
# Test that SqlDataset can read from a database table. # Test that SqlDataset can read from a database table.
@combinations.generate(test_base.default_test_combinations())
def testReadResultSet(self): def testReadResultSet(self):
for _ in range(2): # Run twice to verify statelessness of db operations. for _ in range(2): # Run twice to verify statelessness of db operations.
dataset = self._createSqlDataset( dataset = self._createSqlDataset(
@ -44,6 +48,7 @@ class SqlDatasetTest(sql_dataset_test_base.SqlDatasetTestBase):
num_test_iterations=2) num_test_iterations=2)
# Test that SqlDataset works on a join query. # Test that SqlDataset works on a join query.
@combinations.generate(test_base.default_test_combinations())
def testReadResultSetJoinQuery(self): def testReadResultSetJoinQuery(self):
get_next = self.getNext( get_next = self.getNext(
self._createSqlDataset( self._createSqlDataset(
@ -60,6 +65,7 @@ class SqlDatasetTest(sql_dataset_test_base.SqlDatasetTestBase):
# Test that SqlDataset can read a database entry with a null-terminator # Test that SqlDataset can read a database entry with a null-terminator
# in the middle of the text and place the entry in a `string` tensor. # in the middle of the text and place the entry in a `string` tensor.
@combinations.generate(test_base.default_test_combinations())
def testReadResultSetNullTerminator(self): def testReadResultSetNullTerminator(self):
get_next = self.getNext( get_next = self.getNext(
self._createSqlDataset( self._createSqlDataset(
@ -76,6 +82,7 @@ class SqlDatasetTest(sql_dataset_test_base.SqlDatasetTestBase):
# Test that SqlDataset works when used on two different queries. # Test that SqlDataset works when used on two different queries.
# Because the output types of the dataset must be determined at graph-creation # Because the output types of the dataset must be determined at graph-creation
# time, the two queries must have the same number and types of columns. # time, the two queries must have the same number and types of columns.
@combinations.generate(test_base.default_test_combinations())
def testReadResultSetReuseSqlDataset(self): def testReadResultSetReuseSqlDataset(self):
get_next = self.getNext( get_next = self.getNext(
self._createSqlDataset( self._createSqlDataset(
@ -100,6 +107,7 @@ class SqlDatasetTest(sql_dataset_test_base.SqlDatasetTestBase):
# Test that an `OutOfRangeError` is raised on the first call to # Test that an `OutOfRangeError` is raised on the first call to
# `get_next_str_only` if result set is empty. # `get_next_str_only` if result set is empty.
@combinations.generate(test_base.default_test_combinations())
def testReadEmptyResultSet(self): def testReadEmptyResultSet(self):
get_next = self.getNext( get_next = self.getNext(
self._createSqlDataset( self._createSqlDataset(
@ -110,6 +118,7 @@ class SqlDatasetTest(sql_dataset_test_base.SqlDatasetTestBase):
self.evaluate(get_next()) self.evaluate(get_next())
# Test that an error is raised when `driver_name` is invalid. # Test that an error is raised when `driver_name` is invalid.
@combinations.generate(test_base.default_test_combinations())
def testReadResultSetWithInvalidDriverName(self): def testReadResultSetWithInvalidDriverName(self):
with self.assertRaises(errors.InvalidArgumentError): with self.assertRaises(errors.InvalidArgumentError):
dataset = self._createSqlDataset( dataset = self._createSqlDataset(
@ -120,6 +129,7 @@ class SqlDatasetTest(sql_dataset_test_base.SqlDatasetTestBase):
self.assertDatasetProduces(dataset, expected_output=[]) self.assertDatasetProduces(dataset, expected_output=[])
# Test that an error is raised when a column name in `query` is nonexistent # Test that an error is raised when a column name in `query` is nonexistent
@combinations.generate(test_base.default_test_combinations())
def testReadResultSetWithInvalidColumnName(self): def testReadResultSetWithInvalidColumnName(self):
get_next = self.getNext( get_next = self.getNext(
self._createSqlDataset( self._createSqlDataset(
@ -130,6 +140,7 @@ class SqlDatasetTest(sql_dataset_test_base.SqlDatasetTestBase):
self.evaluate(get_next()) self.evaluate(get_next())
# Test that an error is raised when there is a syntax error in `query`. # Test that an error is raised when there is a syntax error in `query`.
@combinations.generate(test_base.default_test_combinations())
def testReadResultSetOfQueryWithSyntaxError(self): def testReadResultSetOfQueryWithSyntaxError(self):
get_next = self.getNext( get_next = self.getNext(
self._createSqlDataset( self._createSqlDataset(
@ -141,6 +152,7 @@ class SqlDatasetTest(sql_dataset_test_base.SqlDatasetTestBase):
# Test that an error is raised when the number of columns in `query` # Test that an error is raised when the number of columns in `query`
# does not match the length of `, output_types`. # does not match the length of `, output_types`.
@combinations.generate(test_base.default_test_combinations())
def testReadResultSetWithMismatchBetweenColumnsAndOutputTypes(self): def testReadResultSetWithMismatchBetweenColumnsAndOutputTypes(self):
get_next = self.getNext( get_next = self.getNext(
self._createSqlDataset( self._createSqlDataset(
@ -154,6 +166,7 @@ class SqlDatasetTest(sql_dataset_test_base.SqlDatasetTestBase):
# than a select query. In particular, the error refers to the number of # than a select query. In particular, the error refers to the number of
# output types passed to the op not matching the number of columns in the # output types passed to the op not matching the number of columns in the
# result set of the query (namely, 0 for an insert statement.) # result set of the query (namely, 0 for an insert statement.)
@combinations.generate(test_base.default_test_combinations())
def testReadResultSetOfInsertQuery(self): def testReadResultSetOfInsertQuery(self):
get_next = self.getNext( get_next = self.getNext(
self._createSqlDataset( self._createSqlDataset(
@ -165,6 +178,7 @@ class SqlDatasetTest(sql_dataset_test_base.SqlDatasetTestBase):
# Test that `SqlDataset` can read an integer from a SQLite database table and # Test that `SqlDataset` can read an integer from a SQLite database table and
# place it in an `int8` tensor. # place it in an `int8` tensor.
@combinations.generate(test_base.default_test_combinations())
def testReadResultSetInt8(self): def testReadResultSetInt8(self):
get_next = self.getNext( get_next = self.getNext(
self._createSqlDataset( self._createSqlDataset(
@ -178,6 +192,7 @@ class SqlDatasetTest(sql_dataset_test_base.SqlDatasetTestBase):
# Test that `SqlDataset` can read a negative or 0-valued integer from a # Test that `SqlDataset` can read a negative or 0-valued integer from a
# SQLite database table and place it in an `int8` tensor. # SQLite database table and place it in an `int8` tensor.
@combinations.generate(test_base.default_test_combinations())
def testReadResultSetInt8NegativeAndZero(self): def testReadResultSetInt8NegativeAndZero(self):
get_next = self.getNext( get_next = self.getNext(
self._createSqlDataset( self._createSqlDataset(
@ -191,6 +206,7 @@ class SqlDatasetTest(sql_dataset_test_base.SqlDatasetTestBase):
# Test that `SqlDataset` can read a large (positive or negative) integer from # Test that `SqlDataset` can read a large (positive or negative) integer from
# a SQLite database table and place it in an `int8` tensor. # a SQLite database table and place it in an `int8` tensor.
@combinations.generate(test_base.default_test_combinations())
def testReadResultSetInt8MaxValues(self): def testReadResultSetInt8MaxValues(self):
get_next = self.getNext( get_next = self.getNext(
self._createSqlDataset( self._createSqlDataset(
@ -205,6 +221,7 @@ class SqlDatasetTest(sql_dataset_test_base.SqlDatasetTestBase):
# Test that `SqlDataset` can read an integer from a SQLite database table and # Test that `SqlDataset` can read an integer from a SQLite database table and
# place it in an `int16` tensor. # place it in an `int16` tensor.
@combinations.generate(test_base.default_test_combinations())
def testReadResultSetInt16(self): def testReadResultSetInt16(self):
get_next = self.getNext( get_next = self.getNext(
self._createSqlDataset( self._createSqlDataset(
@ -218,6 +235,7 @@ class SqlDatasetTest(sql_dataset_test_base.SqlDatasetTestBase):
# Test that `SqlDataset` can read a negative or 0-valued integer from a # Test that `SqlDataset` can read a negative or 0-valued integer from a
# SQLite database table and place it in an `int16` tensor. # SQLite database table and place it in an `int16` tensor.
@combinations.generate(test_base.default_test_combinations())
def testReadResultSetInt16NegativeAndZero(self): def testReadResultSetInt16NegativeAndZero(self):
get_next = self.getNext( get_next = self.getNext(
self._createSqlDataset( self._createSqlDataset(
@ -231,6 +249,7 @@ class SqlDatasetTest(sql_dataset_test_base.SqlDatasetTestBase):
# Test that `SqlDataset` can read a large (positive or negative) integer from # Test that `SqlDataset` can read a large (positive or negative) integer from
# a SQLite database table and place it in an `int16` tensor. # a SQLite database table and place it in an `int16` tensor.
@combinations.generate(test_base.default_test_combinations())
def testReadResultSetInt16MaxValues(self): def testReadResultSetInt16MaxValues(self):
get_next = self.getNext( get_next = self.getNext(
self._createSqlDataset( self._createSqlDataset(
@ -246,6 +265,7 @@ class SqlDatasetTest(sql_dataset_test_base.SqlDatasetTestBase):
# Test that `SqlDataset` can read an integer from a SQLite database table and # Test that `SqlDataset` can read an integer from a SQLite database table and
# place it in an `int32` tensor. # place it in an `int32` tensor.
@combinations.generate(test_base.default_test_combinations())
def testReadResultSetInt32(self): def testReadResultSetInt32(self):
get_next = self.getNext( get_next = self.getNext(
self._createSqlDataset( self._createSqlDataset(
@ -257,6 +277,7 @@ class SqlDatasetTest(sql_dataset_test_base.SqlDatasetTestBase):
# Test that `SqlDataset` can read a negative or 0-valued integer from a # Test that `SqlDataset` can read a negative or 0-valued integer from a
# SQLite database table and place it in an `int32` tensor. # SQLite database table and place it in an `int32` tensor.
@combinations.generate(test_base.default_test_combinations())
def testReadResultSetInt32NegativeAndZero(self): def testReadResultSetInt32NegativeAndZero(self):
get_next = self.getNext( get_next = self.getNext(
self._createSqlDataset( self._createSqlDataset(
@ -270,6 +291,7 @@ class SqlDatasetTest(sql_dataset_test_base.SqlDatasetTestBase):
# Test that `SqlDataset` can read a large (positive or negative) integer from # Test that `SqlDataset` can read a large (positive or negative) integer from
# a SQLite database table and place it in an `int32` tensor. # a SQLite database table and place it in an `int32` tensor.
@combinations.generate(test_base.default_test_combinations())
def testReadResultSetInt32MaxValues(self): def testReadResultSetInt32MaxValues(self):
get_next = self.getNext( get_next = self.getNext(
self._createSqlDataset( self._createSqlDataset(
@ -285,6 +307,7 @@ class SqlDatasetTest(sql_dataset_test_base.SqlDatasetTestBase):
# Test that `SqlDataset` can read a numeric `varchar` from a SQLite database # Test that `SqlDataset` can read a numeric `varchar` from a SQLite database
# table and place it in an `int32` tensor. # table and place it in an `int32` tensor.
@combinations.generate(test_base.default_test_combinations())
def testReadResultSetInt32VarCharColumnAsInt(self): def testReadResultSetInt32VarCharColumnAsInt(self):
get_next = self.getNext( get_next = self.getNext(
self._createSqlDataset( self._createSqlDataset(
@ -298,6 +321,7 @@ class SqlDatasetTest(sql_dataset_test_base.SqlDatasetTestBase):
# Test that `SqlDataset` can read an integer from a SQLite database table # Test that `SqlDataset` can read an integer from a SQLite database table
# and place it in an `int64` tensor. # and place it in an `int64` tensor.
@combinations.generate(test_base.default_test_combinations())
def testReadResultSetInt64(self): def testReadResultSetInt64(self):
get_next = self.getNext( get_next = self.getNext(
self._createSqlDataset( self._createSqlDataset(
@ -311,6 +335,7 @@ class SqlDatasetTest(sql_dataset_test_base.SqlDatasetTestBase):
# Test that `SqlDataset` can read a negative or 0-valued integer from a # Test that `SqlDataset` can read a negative or 0-valued integer from a
# SQLite database table and place it in an `int64` tensor. # SQLite database table and place it in an `int64` tensor.
@combinations.generate(test_base.default_test_combinations())
def testReadResultSetInt64NegativeAndZero(self): def testReadResultSetInt64NegativeAndZero(self):
get_next = self.getNext( get_next = self.getNext(
self._createSqlDataset( self._createSqlDataset(
@ -324,6 +349,7 @@ class SqlDatasetTest(sql_dataset_test_base.SqlDatasetTestBase):
# Test that `SqlDataset` can read a large (positive or negative) integer from # Test that `SqlDataset` can read a large (positive or negative) integer from
# a SQLite database table and place it in an `int64` tensor. # a SQLite database table and place it in an `int64` tensor.
@combinations.generate(test_base.default_test_combinations())
def testReadResultSetInt64MaxValues(self): def testReadResultSetInt64MaxValues(self):
get_next = self.getNext( get_next = self.getNext(
self._createSqlDataset( self._createSqlDataset(
@ -339,6 +365,7 @@ class SqlDatasetTest(sql_dataset_test_base.SqlDatasetTestBase):
# Test that `SqlDataset` can read an integer from a SQLite database table and # Test that `SqlDataset` can read an integer from a SQLite database table and
# place it in a `uint8` tensor. # place it in a `uint8` tensor.
@combinations.generate(test_base.default_test_combinations())
def testReadResultSetUInt8(self): def testReadResultSetUInt8(self):
get_next = self.getNext( get_next = self.getNext(
self._createSqlDataset( self._createSqlDataset(
@ -352,6 +379,7 @@ class SqlDatasetTest(sql_dataset_test_base.SqlDatasetTestBase):
# Test that `SqlDataset` can read the minimum and maximum uint8 values from a # Test that `SqlDataset` can read the minimum and maximum uint8 values from a
# SQLite database table and place them in `uint8` tensors. # SQLite database table and place them in `uint8` tensors.
@combinations.generate(test_base.default_test_combinations())
def testReadResultSetUInt8MinAndMaxValues(self): def testReadResultSetUInt8MinAndMaxValues(self):
get_next = self.getNext( get_next = self.getNext(
self._createSqlDataset( self._createSqlDataset(
@ -367,6 +395,7 @@ class SqlDatasetTest(sql_dataset_test_base.SqlDatasetTestBase):
# Test that `SqlDataset` can read an integer from a SQLite database table # Test that `SqlDataset` can read an integer from a SQLite database table
# and place it in a `uint16` tensor. # and place it in a `uint16` tensor.
@combinations.generate(test_base.default_test_combinations())
def testReadResultSetUInt16(self): def testReadResultSetUInt16(self):
get_next = self.getNext( get_next = self.getNext(
self._createSqlDataset( self._createSqlDataset(
@ -380,6 +409,7 @@ class SqlDatasetTest(sql_dataset_test_base.SqlDatasetTestBase):
# Test that `SqlDataset` can read the minimum and maximum uint16 values from a # Test that `SqlDataset` can read the minimum and maximum uint16 values from a
# SQLite database table and place them in `uint16` tensors. # SQLite database table and place them in `uint16` tensors.
@combinations.generate(test_base.default_test_combinations())
def testReadResultSetUInt16MinAndMaxValues(self): def testReadResultSetUInt16MinAndMaxValues(self):
get_next = self.getNext( get_next = self.getNext(
self._createSqlDataset( self._createSqlDataset(
@ -396,6 +426,7 @@ class SqlDatasetTest(sql_dataset_test_base.SqlDatasetTestBase):
# Test that `SqlDataset` can read a 0-valued and 1-valued integer from a # Test that `SqlDataset` can read a 0-valued and 1-valued integer from a
# SQLite database table and place them as `True` and `False` respectively # SQLite database table and place them as `True` and `False` respectively
# in `bool` tensors. # in `bool` tensors.
@combinations.generate(test_base.default_test_combinations())
def testReadResultSetBool(self): def testReadResultSetBool(self):
get_next = self.getNext( get_next = self.getNext(
self._createSqlDataset( self._createSqlDataset(
@ -409,6 +440,7 @@ class SqlDatasetTest(sql_dataset_test_base.SqlDatasetTestBase):
# Test that `SqlDataset` can read an integer that is not 0-valued or 1-valued # Test that `SqlDataset` can read an integer that is not 0-valued or 1-valued
# from a SQLite database table and place it as `True` in a `bool` tensor. # from a SQLite database table and place it as `True` in a `bool` tensor.
@combinations.generate(test_base.default_test_combinations())
def testReadResultSetBoolNotZeroOrOne(self): def testReadResultSetBoolNotZeroOrOne(self):
get_next = self.getNext( get_next = self.getNext(
self._createSqlDataset( self._createSqlDataset(
@ -422,6 +454,7 @@ class SqlDatasetTest(sql_dataset_test_base.SqlDatasetTestBase):
# Test that `SqlDataset` can read a float from a SQLite database table # Test that `SqlDataset` can read a float from a SQLite database table
# and place it in a `float64` tensor. # and place it in a `float64` tensor.
@combinations.generate(test_base.default_test_combinations())
def testReadResultSetFloat64(self): def testReadResultSetFloat64(self):
get_next = self.getNext( get_next = self.getNext(
self._createSqlDataset( self._createSqlDataset(
@ -437,6 +470,7 @@ class SqlDatasetTest(sql_dataset_test_base.SqlDatasetTestBase):
# Test that `SqlDataset` can read a float from a SQLite database table beyond # Test that `SqlDataset` can read a float from a SQLite database table beyond
# the precision of 64-bit IEEE, without throwing an error. Test that # the precision of 64-bit IEEE, without throwing an error. Test that
# `SqlDataset` identifies such a value as equal to itself. # `SqlDataset` identifies such a value as equal to itself.
@combinations.generate(test_base.default_test_combinations())
def testReadResultSetFloat64OverlyPrecise(self): def testReadResultSetFloat64OverlyPrecise(self):
get_next = self.getNext( get_next = self.getNext(
self._createSqlDataset( self._createSqlDataset(
@ -458,6 +492,7 @@ class SqlDatasetTest(sql_dataset_test_base.SqlDatasetTestBase):
# representing the largest integer representable as a 64-bit IEEE float # representing the largest integer representable as a 64-bit IEEE float
# such that the previous integer is also representable as a 64-bit IEEE float. # such that the previous integer is also representable as a 64-bit IEEE float.
# Test that `SqlDataset` can distinguish these two numbers. # Test that `SqlDataset` can distinguish these two numbers.
@combinations.generate(test_base.default_test_combinations())
def testReadResultSetFloat64LargestConsecutiveWholeNumbersNotEqual(self): def testReadResultSetFloat64LargestConsecutiveWholeNumbersNotEqual(self):
get_next = self.getNext( get_next = self.getNext(
self._createSqlDataset( self._createSqlDataset(
@ -472,6 +507,7 @@ class SqlDatasetTest(sql_dataset_test_base.SqlDatasetTestBase):
self.evaluate(get_next()) self.evaluate(get_next())
# Test that SqlDataset can stop correctly when combined with batch # Test that SqlDataset can stop correctly when combined with batch
@combinations.generate(test_base.default_test_combinations())
def testReadResultSetWithBatchStop(self): def testReadResultSetWithBatchStop(self):
dataset = self._createSqlDataset( dataset = self._createSqlDataset(
query="SELECT * FROM data", output_types=(dtypes.int32)) query="SELECT * FROM data", output_types=(dtypes.int32))

View File

@ -17,6 +17,7 @@ from __future__ import absolute_import
from __future__ import division from __future__ import division
from __future__ import print_function from __future__ import print_function
from absl.testing import parameterized
import numpy as np import numpy as np
from tensorflow.python.data.experimental.kernel_tests import reader_dataset_ops_test_base from tensorflow.python.data.experimental.kernel_tests import reader_dataset_ops_test_base
@ -24,7 +25,9 @@ from tensorflow.python.data.experimental.kernel_tests import stats_dataset_test_
from tensorflow.python.data.experimental.ops import batching from tensorflow.python.data.experimental.ops import batching
from tensorflow.python.data.experimental.ops import stats_aggregator from tensorflow.python.data.experimental.ops import stats_aggregator
from tensorflow.python.data.experimental.ops import stats_ops from tensorflow.python.data.experimental.ops import stats_ops
from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import combinations
from tensorflow.python.framework import errors from tensorflow.python.framework import errors
from tensorflow.python.framework import ops from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops from tensorflow.python.ops import array_ops
@ -32,8 +35,11 @@ from tensorflow.python.ops import math_ops
from tensorflow.python.platform import test from tensorflow.python.platform import test
class StatsDatasetTest(stats_dataset_test_base.StatsDatasetTestBase): # TODO(jsimsa): Figure out why are graph tests failing.
class StatsDatasetTest(stats_dataset_test_base.StatsDatasetTestBase,
parameterized.TestCase):
@combinations.generate(test_base.eager_only_combinations())
def testBytesProduced(self): def testBytesProduced(self):
aggregator = stats_aggregator.StatsAggregator() aggregator = stats_aggregator.StatsAggregator()
dataset = dataset_ops.Dataset.range(100).map( dataset = dataset_ops.Dataset.range(100).map(
@ -57,6 +63,7 @@ class StatsDatasetTest(stats_dataset_test_base.StatsDatasetTestBase):
self.assertStatisticsHasCount(handle, "bytes_produced", 100.0, 101) self.assertStatisticsHasCount(handle, "bytes_produced", 100.0, 101)
self.assertStatisticsHasSum(handle, "bytes_produced", expected_sum, 101) self.assertStatisticsHasSum(handle, "bytes_produced", expected_sum, 101)
@combinations.generate(test_base.eager_only_combinations())
def testLatencyStats(self): def testLatencyStats(self):
aggregator = stats_aggregator.StatsAggregator() aggregator = stats_aggregator.StatsAggregator()
dataset = dataset_ops.Dataset.range(100).apply( dataset = dataset_ops.Dataset.range(100).apply(
@ -76,6 +83,7 @@ class StatsDatasetTest(stats_dataset_test_base.StatsDatasetTestBase):
handle = self.getHandle(aggregator) handle = self.getHandle(aggregator)
self.assertStatisticsHasCount(handle, "record_latency", 100.0, 101) self.assertStatisticsHasCount(handle, "record_latency", 100.0, 101)
@combinations.generate(test_base.eager_only_combinations())
def testPrefetchBufferUtilization(self): def testPrefetchBufferUtilization(self):
aggregator = stats_aggregator.StatsAggregator() aggregator = stats_aggregator.StatsAggregator()
dataset = dataset_ops.Dataset.range(100).map( dataset = dataset_ops.Dataset.range(100).map(
@ -117,6 +125,7 @@ class StatsDatasetTest(stats_dataset_test_base.StatsDatasetTestBase):
301, 301,
offset=2) offset=2)
@combinations.generate(test_base.eager_only_combinations())
def testPrefetchBufferScalars(self): def testPrefetchBufferScalars(self):
aggregator = stats_aggregator.StatsAggregator() aggregator = stats_aggregator.StatsAggregator()
dataset = dataset_ops.Dataset.range(10).map( dataset = dataset_ops.Dataset.range(10).map(
@ -140,6 +149,7 @@ class StatsDatasetTest(stats_dataset_test_base.StatsDatasetTestBase):
with self.assertRaises(errors.OutOfRangeError): with self.assertRaises(errors.OutOfRangeError):
self.evaluate(next_element()) self.evaluate(next_element())
@combinations.generate(test_base.eager_only_combinations())
def testFilteredElementsStats(self): def testFilteredElementsStats(self):
aggregator = stats_aggregator.StatsAggregator() aggregator = stats_aggregator.StatsAggregator()
dataset = dataset_ops.Dataset.range(101).filter( dataset = dataset_ops.Dataset.range(101).filter(
@ -167,6 +177,7 @@ class StatsDatasetTest(stats_dataset_test_base.StatsDatasetTestBase):
handle, self.regexForNodeName("FilterDataset", "filtered_elements"), handle, self.regexForNodeName("FilterDataset", "filtered_elements"),
34.0) 34.0)
@combinations.generate(test_base.eager_only_combinations())
def testReinitialize(self): def testReinitialize(self):
aggregator = stats_aggregator.StatsAggregator() aggregator = stats_aggregator.StatsAggregator()
dataset = dataset_ops.Dataset.range(100).apply( dataset = dataset_ops.Dataset.range(100).apply(
@ -187,6 +198,7 @@ class StatsDatasetTest(stats_dataset_test_base.StatsDatasetTestBase):
self.assertStatisticsHasCount(handle, "record_latency", (j + 1) * 100.0, self.assertStatisticsHasCount(handle, "record_latency", (j + 1) * 100.0,
(j * 100) + 101) (j * 100) + 101)
@combinations.generate(test_base.eager_only_combinations())
def testNoAggregatorRegistered(self): def testNoAggregatorRegistered(self):
dataset = dataset_ops.Dataset.range(100).apply( dataset = dataset_ops.Dataset.range(100).apply(
stats_ops.latency_stats("record_latency")) stats_ops.latency_stats("record_latency"))
@ -198,6 +210,7 @@ class StatsDatasetTest(stats_dataset_test_base.StatsDatasetTestBase):
with self.assertRaises(errors.OutOfRangeError): with self.assertRaises(errors.OutOfRangeError):
self.evaluate(next_element()) self.evaluate(next_element())
@combinations.generate(test_base.eager_only_combinations())
def testMultipleTags(self): def testMultipleTags(self):
aggregator = stats_aggregator.StatsAggregator() aggregator = stats_aggregator.StatsAggregator()
dataset = dataset_ops.Dataset.range(100).apply( dataset = dataset_ops.Dataset.range(100).apply(
@ -221,6 +234,7 @@ class StatsDatasetTest(stats_dataset_test_base.StatsDatasetTestBase):
handle, "record_latency", 100.0, 201, offset=1) handle, "record_latency", 100.0, 201, offset=1)
self.assertStatisticsHasCount(handle, "record_latency_2", 100.0, 201) self.assertStatisticsHasCount(handle, "record_latency_2", 100.0, 201)
@combinations.generate(test_base.eager_only_combinations())
def testRepeatedTags(self): def testRepeatedTags(self):
aggregator = stats_aggregator.StatsAggregator() aggregator = stats_aggregator.StatsAggregator()
dataset = dataset_ops.Dataset.range(100).apply( dataset = dataset_ops.Dataset.range(100).apply(
@ -239,6 +253,7 @@ class StatsDatasetTest(stats_dataset_test_base.StatsDatasetTestBase):
handle = self.getHandle(aggregator) handle = self.getHandle(aggregator)
self.assertStatisticsHasCount(handle, "record_latency", 200.0, 201) self.assertStatisticsHasCount(handle, "record_latency", 200.0, 201)
@combinations.generate(test_base.eager_only_combinations())
def testMultipleIteratorsSameAggregator(self): def testMultipleIteratorsSameAggregator(self):
aggregator = stats_aggregator.StatsAggregator() aggregator = stats_aggregator.StatsAggregator()
dataset = dataset_ops.Dataset.range(100).apply( dataset = dataset_ops.Dataset.range(100).apply(
@ -259,6 +274,7 @@ class StatsDatasetTest(stats_dataset_test_base.StatsDatasetTestBase):
handle = self.getHandle(aggregator) handle = self.getHandle(aggregator)
self.assertStatisticsHasCount(handle, "record_latency", 200.0, 201) self.assertStatisticsHasCount(handle, "record_latency", 200.0, 201)
@combinations.generate(test_base.eager_only_combinations())
def testMultipleDatasetWithPrefixes(self): def testMultipleDatasetWithPrefixes(self):
aggregator = stats_aggregator.StatsAggregator() aggregator = stats_aggregator.StatsAggregator()
dataset = dataset_ops.Dataset.range(100).apply( dataset = dataset_ops.Dataset.range(100).apply(
@ -289,6 +305,7 @@ class StatsDatasetTest(stats_dataset_test_base.StatsDatasetTestBase):
self.assertStatisticsHasCount(handle, "dataset2::record_latency", 100.0, self.assertStatisticsHasCount(handle, "dataset2::record_latency", 100.0,
201) 201)
@combinations.generate(test_base.eager_only_combinations())
def testMultiplePrefetchStats(self): def testMultiplePrefetchStats(self):
aggregator = stats_aggregator.StatsAggregator() aggregator = stats_aggregator.StatsAggregator()
@ -314,8 +331,10 @@ class StatsDatasetTest(stats_dataset_test_base.StatsDatasetTestBase):
self.evaluate(next_element()) self.evaluate(next_element())
class ThreadUtilizationStatsTest(stats_dataset_test_base.StatsDatasetTestBase): class ThreadUtilizationStatsTest(stats_dataset_test_base.StatsDatasetTestBase,
parameterized.TestCase):
@combinations.generate(test_base.eager_only_combinations())
def testMapBufferUtilization(self): def testMapBufferUtilization(self):
def dataset_fn(): def dataset_fn():
@ -326,6 +345,7 @@ class ThreadUtilizationStatsTest(stats_dataset_test_base.StatsDatasetTestBase):
self.parallelCallsStats( self.parallelCallsStats(
dataset_fn, {"ParallelMapDataset"}, 10, function_processing_time=True) dataset_fn, {"ParallelMapDataset"}, 10, function_processing_time=True)
@combinations.generate(test_base.eager_only_combinations())
def testMapAutoTuneBufferUtilization(self): def testMapAutoTuneBufferUtilization(self):
def dataset_fn(): def dataset_fn():
@ -336,6 +356,7 @@ class ThreadUtilizationStatsTest(stats_dataset_test_base.StatsDatasetTestBase):
self.parallelCallsStats( self.parallelCallsStats(
dataset_fn, {"ParallelMapDataset"}, 10, function_processing_time=True) dataset_fn, {"ParallelMapDataset"}, 10, function_processing_time=True)
@combinations.generate(test_base.eager_only_combinations())
def testInterleaveAutoTuneBufferUtilization(self): def testInterleaveAutoTuneBufferUtilization(self):
def dataset_fn(): def dataset_fn():
@ -351,6 +372,7 @@ class ThreadUtilizationStatsTest(stats_dataset_test_base.StatsDatasetTestBase):
self.parallelCallsStats(dataset_fn, {"ParallelInterleaveDatasetV2"}, 10) self.parallelCallsStats(dataset_fn, {"ParallelInterleaveDatasetV2"}, 10)
@combinations.generate(test_base.eager_only_combinations())
def testMapAndBatchAutoTuneBufferUtilization(self): def testMapAndBatchAutoTuneBufferUtilization(self):
def dataset_fn(): def dataset_fn():
@ -370,8 +392,10 @@ class ThreadUtilizationStatsTest(stats_dataset_test_base.StatsDatasetTestBase):
class FeatureStatsDatasetTest( class FeatureStatsDatasetTest(
stats_dataset_test_base.StatsDatasetTestBase, stats_dataset_test_base.StatsDatasetTestBase,
reader_dataset_ops_test_base.MakeBatchedFeaturesDatasetTestBase): reader_dataset_ops_test_base.MakeBatchedFeaturesDatasetTestBase,
parameterized.TestCase):
@combinations.generate(test_base.eager_only_combinations())
def testFeaturesStats(self): def testFeaturesStats(self):
num_epochs = 5 num_epochs = 5
total_records = num_epochs * self._num_records total_records = num_epochs * self._num_records

View File

@ -23,18 +23,21 @@ import numpy as np
from tensorflow.python.data.experimental.ops import take_while_ops from tensorflow.python.data.experimental.ops import take_while_ops
from tensorflow.python.data.kernel_tests import test_base from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import combinations
from tensorflow.python.framework import constant_op from tensorflow.python.framework import constant_op
from tensorflow.python.framework import errors from tensorflow.python.framework import errors
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops from tensorflow.python.ops import math_ops
from tensorflow.python.platform import test from tensorflow.python.platform import test
@test_util.run_all_in_graph_and_eager_modes
class TakeWhileTest(test_base.DatasetTestBase, parameterized.TestCase): class TakeWhileTest(test_base.DatasetTestBase, parameterized.TestCase):
@parameterized.parameters((14, 2), (15, 2), (100, 3)) @combinations.generate(
combinations.times(
test_base.default_test_combinations(),
combinations.combine(num_elements=[14, 15], window_size=[2]) +
combinations.combine(num_elements=[100], window_size=[3])))
def testTakeWhileDataset(self, num_elements, window_size): def testTakeWhileDataset(self, num_elements, window_size):
def _predicate_func(elem): def _predicate_func(elem):
@ -49,8 +52,19 @@ class TakeWhileTest(test_base.DatasetTestBase, parameterized.TestCase):
expected_num_elements = int(num_elements / window_size) * window_size expected_num_elements = int(num_elements / window_size) * window_size
self.assertDatasetProduces(dataset, np.arange(expected_num_elements)) self.assertDatasetProduces(dataset, np.arange(expected_num_elements))
@parameterized.parameters((10, 2, False), (16, 7, False), (100, 99, False), @combinations.generate(
(100, 101, True), (0, 1, True)) combinations.times(
test_base.default_test_combinations(),
combinations.combine(
num_elements=[10], upper_bound=[2], out_of_bounds=[False]) +
combinations.combine(
num_elements=[16], upper_bound=[7], out_of_bounds=[False]) +
combinations.combine(
num_elements=[100], upper_bound=[99], out_of_bounds=[False]) +
combinations.combine(
num_elements=[100], upper_bound=[101], out_of_bounds=[True]) +
combinations.combine(
num_elements=[0], upper_bound=[1], out_of_bounds=[True])))
def testTakeWhileDatasetRange(self, num_elements, upper_bound, out_of_bounds): def testTakeWhileDatasetRange(self, num_elements, upper_bound, out_of_bounds):
dataset = dataset_ops.Dataset.range(num_elements).apply( dataset = dataset_ops.Dataset.range(num_elements).apply(
take_while_ops.take_while(lambda x: x < upper_bound)) take_while_ops.take_while(lambda x: x < upper_bound))
@ -62,6 +76,7 @@ class TakeWhileTest(test_base.DatasetTestBase, parameterized.TestCase):
else: else:
self.assertDatasetProduces(dataset, np.arange(upper_bound)) self.assertDatasetProduces(dataset, np.arange(upper_bound))
@combinations.generate(test_base.default_test_combinations())
def testTakeWhileDatasetString(self): def testTakeWhileDatasetString(self):
def not_equal(string): def not_equal(string):
@ -79,7 +94,13 @@ class TakeWhileTest(test_base.DatasetTestBase, parameterized.TestCase):
with self.assertRaises(errors.OutOfRangeError): with self.assertRaises(errors.OutOfRangeError):
self.assertEqual(b"test", self.evaluate(next_element())) self.assertEqual(b"test", self.evaluate(next_element()))
@parameterized.parameters((5, 3), (10, 0), (100, 5), (8, 7)) @combinations.generate(
combinations.times(
test_base.default_test_combinations(),
combinations.combine(size=[5], index=[3]) +
combinations.combine(size=[10], index=[0]) +
combinations.combine(size=[100], index=[5]) +
combinations.combine(size=[8], index=[7])))
def testTakewhileDatasetShortCircuit(self, size, index): def testTakewhileDatasetShortCircuit(self, size, index):
def _predicate_func(data_elem): def _predicate_func(data_elem):
@ -98,6 +119,7 @@ class TakeWhileTest(test_base.DatasetTestBase, parameterized.TestCase):
with self.assertRaises(errors.OutOfRangeError): with self.assertRaises(errors.OutOfRangeError):
self.evaluate(next_element()) self.evaluate(next_element())
@combinations.generate(test_base.default_test_combinations())
def testTakeWhileDatasetWithRepeat(self): def testTakeWhileDatasetWithRepeat(self):
dataset = dataset_ops.Dataset.range(10).apply( dataset = dataset_ops.Dataset.range(10).apply(
take_while_ops.take_while(lambda x: x < 2)).repeat(5) take_while_ops.take_while(lambda x: x < 2)).repeat(5)

View File

@ -19,14 +19,16 @@ from __future__ import print_function
import os import os
from absl.testing import parameterized
from tensorflow.python.data.experimental.ops import grouping from tensorflow.python.data.experimental.ops import grouping
from tensorflow.python.data.experimental.ops import writers from tensorflow.python.data.experimental.ops import writers
from tensorflow.python.data.kernel_tests import test_base from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.ops import readers from tensorflow.python.data.ops import readers
from tensorflow.python.eager import function from tensorflow.python.eager import function
from tensorflow.python.framework import combinations
from tensorflow.python.framework import dtypes from tensorflow.python.framework import dtypes
from tensorflow.python.framework import test_util
from tensorflow.python.lib.io import python_io from tensorflow.python.lib.io import python_io
from tensorflow.python.lib.io import tf_record from tensorflow.python.lib.io import tf_record
from tensorflow.python.ops import string_ops from tensorflow.python.ops import string_ops
@ -34,8 +36,7 @@ from tensorflow.python.platform import test
from tensorflow.python.util import compat from tensorflow.python.util import compat
@test_util.run_all_in_graph_and_eager_modes class TFRecordWriterTest(test_base.DatasetTestBase, parameterized.TestCase):
class TFRecordWriterTest(test_base.DatasetTestBase):
def setUp(self): def setUp(self):
super(TFRecordWriterTest, self).setUp() super(TFRecordWriterTest, self).setUp()
@ -63,11 +64,13 @@ class TFRecordWriterTest(test_base.DatasetTestBase):
def _outputFilename(self): def _outputFilename(self):
return os.path.join(self.get_temp_dir(), "tf_record.out.txt") return os.path.join(self.get_temp_dir(), "tf_record.out.txt")
@combinations.generate(test_base.default_test_combinations())
def testWrite(self): def testWrite(self):
self.evaluate(self.writer_fn(self._createFile())) self.evaluate(self.writer_fn(self._createFile()))
for i, r in enumerate(tf_record.tf_record_iterator(self._outputFilename())): for i, r in enumerate(tf_record.tf_record_iterator(self._outputFilename())):
self.assertAllEqual(self._record(i), r) self.assertAllEqual(self._record(i), r)
@combinations.generate(test_base.default_test_combinations())
def testWriteZLIB(self): def testWriteZLIB(self):
options = tf_record.TFRecordOptions(tf_record.TFRecordCompressionType.ZLIB) options = tf_record.TFRecordOptions(tf_record.TFRecordCompressionType.ZLIB)
self.evaluate( self.evaluate(
@ -76,6 +79,7 @@ class TFRecordWriterTest(test_base.DatasetTestBase):
tf_record.tf_record_iterator(self._outputFilename(), options=options)): tf_record.tf_record_iterator(self._outputFilename(), options=options)):
self.assertAllEqual(self._record(i), r) self.assertAllEqual(self._record(i), r)
@combinations.generate(test_base.default_test_combinations())
def testWriteGZIP(self): def testWriteGZIP(self):
options = tf_record.TFRecordOptions(tf_record.TFRecordCompressionType.GZIP) options = tf_record.TFRecordOptions(tf_record.TFRecordCompressionType.GZIP)
self.evaluate( self.evaluate(
@ -84,20 +88,24 @@ class TFRecordWriterTest(test_base.DatasetTestBase):
tf_record.tf_record_iterator(self._outputFilename(), options=options)): tf_record.tf_record_iterator(self._outputFilename(), options=options)):
self.assertAllEqual(self._record(i), r) self.assertAllEqual(self._record(i), r)
@combinations.generate(test_base.default_test_combinations())
def testFailDataset(self): def testFailDataset(self):
with self.assertRaises(TypeError): with self.assertRaises(TypeError):
writers.TFRecordWriter(self._outputFilename(), "").write("whoops") writers.TFRecordWriter(self._outputFilename(), "").write("whoops")
@combinations.generate(test_base.default_test_combinations())
def testFailDType(self): def testFailDType(self):
input_dataset = dataset_ops.Dataset.from_tensors(10) input_dataset = dataset_ops.Dataset.from_tensors(10)
with self.assertRaises(TypeError): with self.assertRaises(TypeError):
writers.TFRecordWriter(self._outputFilename(), "").write(input_dataset) writers.TFRecordWriter(self._outputFilename(), "").write(input_dataset)
@combinations.generate(test_base.default_test_combinations())
def testFailShape(self): def testFailShape(self):
input_dataset = dataset_ops.Dataset.from_tensors([["hello"], ["world"]]) input_dataset = dataset_ops.Dataset.from_tensors([["hello"], ["world"]])
with self.assertRaises(TypeError): with self.assertRaises(TypeError):
writers.TFRecordWriter(self._outputFilename(), "").write(input_dataset) writers.TFRecordWriter(self._outputFilename(), "").write(input_dataset)
@combinations.generate(test_base.default_test_combinations())
def testSideEffect(self): def testSideEffect(self):
def writer_fn(): def writer_fn():
input_dataset = readers.TFRecordDataset(self._createFile()) input_dataset = readers.TFRecordDataset(self._createFile())
@ -112,6 +120,7 @@ class TFRecordWriterTest(test_base.DatasetTestBase):
for i, r in enumerate(tf_record.tf_record_iterator(self._outputFilename())): for i, r in enumerate(tf_record.tf_record_iterator(self._outputFilename())):
self.assertAllEqual(self._record(i), r) self.assertAllEqual(self._record(i), r)
@combinations.generate(test_base.default_test_combinations())
def testShard(self): def testShard(self):
filename = self._createFile() filename = self._createFile()
dataset = readers.TFRecordDataset([filename]) dataset = readers.TFRecordDataset([filename])

View File

@ -17,17 +17,18 @@ from __future__ import absolute_import
from __future__ import division from __future__ import division
from __future__ import print_function from __future__ import print_function
from absl.testing import parameterized
from tensorflow.python.data.experimental.ops import unique from tensorflow.python.data.experimental.ops import unique
from tensorflow.python.data.kernel_tests import test_base from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import combinations
from tensorflow.python.framework import dtypes from tensorflow.python.framework import dtypes
from tensorflow.python.framework import test_util
from tensorflow.python.platform import test from tensorflow.python.platform import test
from tensorflow.python.util import compat from tensorflow.python.util import compat
@test_util.run_all_in_graph_and_eager_modes class UniqueTest(test_base.DatasetTestBase, parameterized.TestCase):
class UniqueTest(test_base.DatasetTestBase):
def _testSimpleHelper(self, dtype, test_cases): def _testSimpleHelper(self, dtype, test_cases):
"""Test the `unique()` transformation on a list of test cases. """Test the `unique()` transformation on a list of test cases.
@ -52,7 +53,7 @@ class UniqueTest(test_base.DatasetTestBase):
for element in expected for element in expected
]) ])
@test_util.run_deprecated_v1 @combinations.generate(test_base.graph_only_combinations())
def testSimpleInt(self): def testSimpleInt(self):
for dtype in [dtypes.int32, dtypes.int64]: for dtype in [dtypes.int32, dtypes.int64]:
self._testSimpleHelper(dtype, [ self._testSimpleHelper(dtype, [
@ -65,7 +66,7 @@ class UniqueTest(test_base.DatasetTestBase):
([[1, 1], [1, 1], [2, 2], [3, 3], [1, 1]], [[1, 1], [2, 2], [3, 3]]), ([[1, 1], [1, 1], [2, 2], [3, 3], [1, 1]], [[1, 1], [2, 2], [3, 3]]),
]) ])
@test_util.run_deprecated_v1 @combinations.generate(test_base.graph_only_combinations())
def testSimpleString(self): def testSimpleString(self):
self._testSimpleHelper(dtypes.string, [ self._testSimpleHelper(dtypes.string, [
([], []), ([], []),

View File

@ -17,16 +17,18 @@ from __future__ import absolute_import
from __future__ import division from __future__ import division
from __future__ import print_function from __future__ import print_function
from absl.testing import parameterized
from tensorflow.python.data.experimental.ops import cardinality from tensorflow.python.data.experimental.ops import cardinality
from tensorflow.python.data.kernel_tests import test_base from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import test_util from tensorflow.python.framework import combinations
from tensorflow.python.platform import test from tensorflow.python.platform import test
@test_util.run_all_in_graph_and_eager_modes class VariantTest(test_base.DatasetTestBase, parameterized.TestCase):
class VariantTest(test_base.DatasetTestBase):
@combinations.generate(test_base.default_test_combinations())
def testRoundtripRange(self): def testRoundtripRange(self):
dataset = dataset_ops.Dataset.range(10) dataset = dataset_ops.Dataset.range(10)
variant = dataset_ops.to_variant(dataset) variant = dataset_ops.to_variant(dataset)
@ -35,6 +37,7 @@ class VariantTest(test_base.DatasetTestBase):
self.assertDatasetProduces(dataset, range(10)) self.assertDatasetProduces(dataset, range(10))
self.assertEqual(self.evaluate(cardinality.cardinality(dataset)), 10) self.assertEqual(self.evaluate(cardinality.cardinality(dataset)), 10)
@combinations.generate(test_base.default_test_combinations())
def testRoundtripMap(self): def testRoundtripMap(self):
dataset = dataset_ops.Dataset.range(10).map(lambda x: x*x) dataset = dataset_ops.Dataset.range(10).map(lambda x: x*x)
variant = dataset_ops.to_variant(dataset) variant = dataset_ops.to_variant(dataset)

View File

@ -17,18 +17,20 @@ from __future__ import absolute_import
from __future__ import division from __future__ import division
from __future__ import print_function from __future__ import print_function
from absl.testing import parameterized
from tensorflow.python.data.kernel_tests import test_base from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import combinations
from tensorflow.python.framework import ops from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gen_dataset_ops from tensorflow.python.ops import gen_dataset_ops
from tensorflow.python.platform import test from tensorflow.python.platform import test
@test_util.run_all_in_graph_and_eager_modes class WrapDatasetVariantTest(test_base.DatasetTestBase, parameterized.TestCase):
class WrapDatasetVariantTest(test_base.DatasetTestBase):
@combinations.generate(test_base.default_test_combinations())
def testBasic(self): def testBasic(self):
ds = dataset_ops.Dataset.range(100) ds = dataset_ops.Dataset.range(100)
ds_variant = ds._variant_tensor # pylint: disable=protected-access ds_variant = ds._variant_tensor # pylint: disable=protected-access
@ -42,7 +44,9 @@ class WrapDatasetVariantTest(test_base.DatasetTestBase):
for i in range(100): for i in range(100):
self.assertEqual(i, self.evaluate(get_next())) self.assertEqual(i, self.evaluate(get_next()))
@test_util.run_v1_only("b/123901304") # TODO(b/123901304)
@combinations.generate(
combinations.combine(tf_api_version=[1], mode=["graph"]))
def testSkipEagerGPU(self): def testSkipEagerGPU(self):
ds = dataset_ops.Dataset.range(100) ds = dataset_ops.Dataset.range(100)
ds_variant = ds._variant_tensor # pylint: disable=protected-access ds_variant = ds._variant_tensor # pylint: disable=protected-access