[tf.data] Rolling forward a previously rolled back change with a fix.
PiperOrigin-RevId: 284036647 Change-Id: I9d50ad7aa8123f6928c055a25bc3dc4d69d2b95d
This commit is contained in:
parent
7e2b4b8c96
commit
769892b353
@ -25,11 +25,11 @@ from tensorflow.python.data.experimental.ops import grouping
|
||||
from tensorflow.python.data.kernel_tests import test_base
|
||||
from tensorflow.python.data.ops import dataset_ops
|
||||
from tensorflow.python.eager import context
|
||||
from tensorflow.python.framework import combinations
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import errors
|
||||
from tensorflow.python.framework import sparse_tensor
|
||||
from tensorflow.python.framework import tensor_shape
|
||||
from tensorflow.python.framework import test_util
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.platform import test
|
||||
|
||||
@ -73,14 +73,12 @@ def _get_record_shape(sparse):
|
||||
return tensor_shape.TensorShape([None])
|
||||
|
||||
|
||||
@test_util.run_all_in_graph_and_eager_modes
|
||||
class BucketBySequenceLengthTest(test_base.DatasetTestBase,
|
||||
parameterized.TestCase):
|
||||
|
||||
@parameterized.named_parameters(
|
||||
("WithoutPadding", True),
|
||||
("WithPadding", False),
|
||||
)
|
||||
@combinations.generate(
|
||||
combinations.times(test_base.default_test_combinations(),
|
||||
combinations.combine(param_no_padding=[True, False])))
|
||||
def testBucketDropReminder(self, param_no_padding):
|
||||
|
||||
boundaries = [10, 20, 30]
|
||||
@ -201,10 +199,9 @@ class BucketBySequenceLengthTest(test_base.DatasetTestBase,
|
||||
|
||||
_test_bucket_by_padding(param_no_padding)
|
||||
|
||||
@parameterized.named_parameters(
|
||||
("WithoutPadding", True),
|
||||
("WithPadding", False),
|
||||
)
|
||||
@combinations.generate(
|
||||
combinations.times(test_base.default_test_combinations(),
|
||||
combinations.combine(param_no_padding=[True, False])))
|
||||
def testBucket(self, param_no_padding):
|
||||
|
||||
boundaries = [10, 20, 30]
|
||||
@ -347,10 +344,9 @@ class BucketBySequenceLengthTest(test_base.DatasetTestBase,
|
||||
self.assertAllEqual(batches[4], [[1, 1, 1, 1, 1, 1, 1, 1, 1, 0],
|
||||
[1, 1, 1, 1, 1, 1, 1, 1, 1, 1]])
|
||||
|
||||
@parameterized.named_parameters(
|
||||
("WithoutPadding", True),
|
||||
("WithPadding", False),
|
||||
)
|
||||
@combinations.generate(
|
||||
combinations.times(test_base.default_test_combinations(),
|
||||
combinations.combine(param_no_padding=[True, False])))
|
||||
def testTupleElements(self, param_no_padding):
|
||||
|
||||
def build_dataset(sparse):
|
||||
@ -381,10 +377,10 @@ class BucketBySequenceLengthTest(test_base.DatasetTestBase,
|
||||
|
||||
_test_tuple_elements_by_padding(param_no_padding)
|
||||
|
||||
@parameterized.named_parameters(
|
||||
("DoDropRemainder", True),
|
||||
("DoNotDropRemainder", False),
|
||||
)
|
||||
@combinations.generate(
|
||||
combinations.times(
|
||||
test_base.default_test_combinations(),
|
||||
combinations.combine(param_drop_remainder=[True, False])))
|
||||
def testBucketSparse(self, param_drop_remainder): # pylint: disable=g-doc-args
|
||||
"""Tests bucketing of sparse tensors (case where `no_padding` == True).
|
||||
|
||||
|
@ -17,6 +17,8 @@ from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from absl.testing import parameterized
|
||||
|
||||
from tensorflow.core.protobuf import config_pb2
|
||||
from tensorflow.python.compat import compat
|
||||
from tensorflow.python.data.experimental.ops import prefetching_ops
|
||||
@ -24,6 +26,7 @@ from tensorflow.python.data.kernel_tests import test_base
|
||||
from tensorflow.python.data.ops import dataset_ops
|
||||
from tensorflow.python.data.ops import iterator_ops
|
||||
from tensorflow.python.data.util import structure
|
||||
from tensorflow.python.framework import combinations
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import errors
|
||||
from tensorflow.python.framework import ops
|
||||
@ -35,9 +38,9 @@ from tensorflow.python.util import compat as util_compat
|
||||
|
||||
|
||||
# TODO(b/117581999): add eager coverage when supported.
|
||||
class CopyToDeviceTest(test_base.DatasetTestBase):
|
||||
class CopyToDeviceTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
|
||||
@test_util.deprecated_graph_mode_only
|
||||
@combinations.generate(test_base.graph_only_combinations())
|
||||
def testCopyToDevice(self):
|
||||
host_dataset = dataset_ops.Dataset.range(10)
|
||||
device_dataset = host_dataset.apply(
|
||||
@ -62,7 +65,7 @@ class CopyToDeviceTest(test_base.DatasetTestBase):
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
self.evaluate(next_element)
|
||||
|
||||
@test_util.deprecated_graph_mode_only
|
||||
@combinations.generate(test_base.graph_only_combinations())
|
||||
def testCopyToDeviceInt32(self):
|
||||
host_dataset = dataset_ops.Dataset.from_tensors([0, 1, 2, 3])
|
||||
device_dataset = host_dataset.apply(
|
||||
@ -86,7 +89,7 @@ class CopyToDeviceTest(test_base.DatasetTestBase):
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
self.evaluate(next_element)
|
||||
|
||||
@test_util.deprecated_graph_mode_only
|
||||
@combinations.generate(test_base.graph_only_combinations())
|
||||
def testCopyToSameDevice(self):
|
||||
host_dataset = dataset_ops.Dataset.range(10)
|
||||
device_dataset = host_dataset.apply(
|
||||
@ -111,7 +114,7 @@ class CopyToDeviceTest(test_base.DatasetTestBase):
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
self.evaluate(next_element)
|
||||
|
||||
@test_util.deprecated_graph_mode_only
|
||||
@combinations.generate(test_base.graph_only_combinations())
|
||||
def testCopyToDeviceWithPrefetch(self):
|
||||
host_dataset = dataset_ops.Dataset.range(10)
|
||||
device_dataset = host_dataset.apply(
|
||||
@ -136,7 +139,7 @@ class CopyToDeviceTest(test_base.DatasetTestBase):
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
self.evaluate(next_element)
|
||||
|
||||
@test_util.deprecated_graph_mode_only
|
||||
@combinations.generate(test_base.graph_only_combinations())
|
||||
def testCopyDictToDevice(self):
|
||||
host_dataset = dataset_ops.Dataset.range(10).map(lambda x: {"a": x})
|
||||
device_dataset = host_dataset.apply(
|
||||
@ -161,7 +164,7 @@ class CopyToDeviceTest(test_base.DatasetTestBase):
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
self.evaluate(next_element)
|
||||
|
||||
@test_util.deprecated_graph_mode_only
|
||||
@combinations.generate(test_base.graph_only_combinations())
|
||||
def testCopyDictToDeviceWithPrefetch(self):
|
||||
host_dataset = dataset_ops.Dataset.range(10).map(lambda x: {"a": x})
|
||||
device_dataset = host_dataset.apply(
|
||||
@ -186,7 +189,7 @@ class CopyToDeviceTest(test_base.DatasetTestBase):
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
self.evaluate(next_element)
|
||||
|
||||
@test_util.deprecated_graph_mode_only
|
||||
@combinations.generate(test_base.graph_only_combinations())
|
||||
def testCopySparseTensorsToDevice(self):
|
||||
|
||||
def make_tensor(i):
|
||||
@ -219,7 +222,7 @@ class CopyToDeviceTest(test_base.DatasetTestBase):
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
self.evaluate(next_element)
|
||||
|
||||
@test_util.deprecated_graph_mode_only
|
||||
@combinations.generate(test_base.graph_only_combinations())
|
||||
def testCopySparseTensorsToDeviceWithPrefetch(self):
|
||||
|
||||
def make_tensor(i):
|
||||
@ -252,7 +255,7 @@ class CopyToDeviceTest(test_base.DatasetTestBase):
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
self.evaluate(next_element)
|
||||
|
||||
@test_util.deprecated_graph_mode_only
|
||||
@combinations.generate(test_base.graph_only_combinations())
|
||||
def testCopyToDeviceGpu(self):
|
||||
if not test_util.is_gpu_available():
|
||||
self.skipTest("No GPU available")
|
||||
@ -273,7 +276,7 @@ class CopyToDeviceTest(test_base.DatasetTestBase):
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
self.evaluate(next_element)
|
||||
|
||||
@test_util.deprecated_graph_mode_only
|
||||
@combinations.generate(test_base.graph_only_combinations())
|
||||
def testCopyToDeviceGpuWithPrefetch(self):
|
||||
if not test_util.is_gpu_available():
|
||||
self.skipTest("No GPU available")
|
||||
@ -294,7 +297,7 @@ class CopyToDeviceTest(test_base.DatasetTestBase):
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
self.evaluate(next_element)
|
||||
|
||||
@test_util.deprecated_graph_mode_only
|
||||
@combinations.generate(test_base.graph_only_combinations())
|
||||
def testCopyToDeviceGpuWithMap(self):
|
||||
if not test_util.is_gpu_available():
|
||||
self.skipTest("No GPU available")
|
||||
@ -332,7 +335,7 @@ class CopyToDeviceTest(test_base.DatasetTestBase):
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
self.evaluate(next_element)
|
||||
|
||||
@test_util.deprecated_graph_mode_only
|
||||
@combinations.generate(test_base.graph_only_combinations())
|
||||
def testCopyToDeviceGpuInt32(self):
|
||||
if not test_util.is_gpu_available():
|
||||
self.skipTest("No GPU available")
|
||||
@ -352,7 +355,7 @@ class CopyToDeviceTest(test_base.DatasetTestBase):
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
self.evaluate(next_element)
|
||||
|
||||
@test_util.deprecated_graph_mode_only
|
||||
@combinations.generate(test_base.graph_only_combinations())
|
||||
def testCopyToDeviceGpuInt32AndPrefetch(self):
|
||||
if not test_util.is_gpu_available():
|
||||
self.skipTest("No GPU available")
|
||||
@ -372,7 +375,7 @@ class CopyToDeviceTest(test_base.DatasetTestBase):
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
self.evaluate(next_element)
|
||||
|
||||
@test_util.deprecated_graph_mode_only
|
||||
@combinations.generate(test_base.graph_only_combinations())
|
||||
def testCopyToDeviceGpuStrings(self):
|
||||
if not test_util.is_gpu_available():
|
||||
self.skipTest("No GPU available")
|
||||
@ -392,7 +395,7 @@ class CopyToDeviceTest(test_base.DatasetTestBase):
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
self.evaluate(next_element)
|
||||
|
||||
@test_util.deprecated_graph_mode_only
|
||||
@combinations.generate(test_base.graph_only_combinations())
|
||||
def testCopyToDeviceGpuStringsAndPrefetch(self):
|
||||
if not test_util.is_gpu_available():
|
||||
self.skipTest("No GPU available")
|
||||
@ -412,7 +415,7 @@ class CopyToDeviceTest(test_base.DatasetTestBase):
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
self.evaluate(next_element)
|
||||
|
||||
@test_util.deprecated_graph_mode_only
|
||||
@combinations.generate(test_base.graph_only_combinations())
|
||||
def testCopyToDevicePingPongCPUGPU(self):
|
||||
if not test_util.is_gpu_available():
|
||||
self.skipTest("No GPU available")
|
||||
@ -436,7 +439,7 @@ class CopyToDeviceTest(test_base.DatasetTestBase):
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
self.evaluate(next_element)
|
||||
|
||||
@test_util.deprecated_graph_mode_only
|
||||
@combinations.generate(test_base.graph_only_combinations())
|
||||
def testCopyToDeviceWithReInit(self):
|
||||
host_dataset = dataset_ops.Dataset.range(10)
|
||||
device_dataset = host_dataset.apply(
|
||||
@ -465,7 +468,7 @@ class CopyToDeviceTest(test_base.DatasetTestBase):
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
self.evaluate(next_element)
|
||||
|
||||
@test_util.deprecated_graph_mode_only
|
||||
@combinations.generate(test_base.graph_only_combinations())
|
||||
def testCopyToDeviceWithReInitAndPrefetch(self):
|
||||
host_dataset = dataset_ops.Dataset.range(10)
|
||||
device_dataset = host_dataset.apply(
|
||||
@ -494,7 +497,7 @@ class CopyToDeviceTest(test_base.DatasetTestBase):
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
self.evaluate(next_element)
|
||||
|
||||
@test_util.deprecated_graph_mode_only
|
||||
@combinations.generate(test_base.graph_only_combinations())
|
||||
def testCopyToDeviceGpuWithReInit(self):
|
||||
if not test_util.is_gpu_available():
|
||||
self.skipTest("No GPU available")
|
||||
@ -518,7 +521,7 @@ class CopyToDeviceTest(test_base.DatasetTestBase):
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
self.evaluate(next_element)
|
||||
|
||||
@test_util.deprecated_graph_mode_only
|
||||
@combinations.generate(test_base.graph_only_combinations())
|
||||
def testCopyToDeviceGpuWithReInitAndPrefetch(self):
|
||||
if not test_util.is_gpu_available():
|
||||
self.skipTest("No GPU available")
|
||||
@ -542,7 +545,7 @@ class CopyToDeviceTest(test_base.DatasetTestBase):
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
self.evaluate(next_element)
|
||||
|
||||
@test_util.deprecated_graph_mode_only
|
||||
@combinations.generate(test_base.graph_only_combinations())
|
||||
def testIteratorGetNextAsOptionalOnGPU(self):
|
||||
if not test_util.is_gpu_available():
|
||||
self.skipTest("No GPU available")
|
||||
|
@ -17,35 +17,33 @@ from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from absl.testing import parameterized
|
||||
|
||||
from tensorflow.python.data.experimental.ops import counter
|
||||
from tensorflow.python.data.kernel_tests import test_base
|
||||
from tensorflow.python.data.ops import dataset_ops
|
||||
from tensorflow.python.framework import combinations
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import test_util
|
||||
from tensorflow.python.platform import test
|
||||
|
||||
|
||||
@test_util.run_all_in_graph_and_eager_modes
|
||||
class CounterTest(test_base.DatasetTestBase):
|
||||
class CounterTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
|
||||
def testCounter(self):
|
||||
@combinations.generate(
|
||||
combinations.times(
|
||||
test_base.default_test_combinations(),
|
||||
combinations.combine(start=3, step=4, expected_output=[[3, 7, 11]]) +
|
||||
combinations.combine(start=0, step=-1, expected_output=[[0, -1, -2]]))
|
||||
)
|
||||
def testCounter(self, start, step, expected_output):
|
||||
"""Test dataset construction using `count`."""
|
||||
dataset = counter.Counter(start=3, step=4)
|
||||
dataset = counter.Counter(start, step)
|
||||
self.assertEqual(
|
||||
[], dataset_ops.get_legacy_output_shapes(dataset).as_list())
|
||||
self.assertEqual(dtypes.int64, dataset_ops.get_legacy_output_types(dataset))
|
||||
get_next = self.getNext(dataset)
|
||||
|
||||
negative_dataset = counter.Counter(start=0, step=-1)
|
||||
negative_get_next = self.getNext(negative_dataset)
|
||||
|
||||
self.assertEqual(3, self.evaluate(get_next()))
|
||||
self.assertEqual(3 + 4, self.evaluate(get_next()))
|
||||
self.assertEqual(3 + 2 * 4, self.evaluate(get_next()))
|
||||
|
||||
self.assertEqual(0, self.evaluate(negative_get_next()))
|
||||
self.assertEqual(-1, self.evaluate(negative_get_next()))
|
||||
self.assertEqual(-2, self.evaluate(negative_get_next()))
|
||||
for expected in expected_output:
|
||||
self.assertEqual(expected, self.evaluate(get_next()))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
@ -22,21 +22,22 @@ import gzip
|
||||
import os
|
||||
import zlib
|
||||
|
||||
from absl.testing import parameterized
|
||||
|
||||
from tensorflow.python.data.experimental.ops import error_ops
|
||||
from tensorflow.python.data.experimental.ops import readers
|
||||
from tensorflow.python.data.kernel_tests import test_base
|
||||
from tensorflow.python.data.ops import readers as core_readers
|
||||
from tensorflow.python.eager import context
|
||||
from tensorflow.python.framework import combinations
|
||||
from tensorflow.python.framework import constant_op
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import errors
|
||||
from tensorflow.python.framework import test_util
|
||||
from tensorflow.python.ops import parsing_ops
|
||||
from tensorflow.python.platform import test
|
||||
|
||||
|
||||
@test_util.run_all_in_graph_and_eager_modes
|
||||
class CsvDatasetTest(test_base.DatasetTestBase):
|
||||
class CsvDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
|
||||
def _setup_files(self, inputs, linebreak='\n', compression_type=None):
|
||||
filenames = []
|
||||
@ -117,26 +118,31 @@ class CsvDatasetTest(test_base.DatasetTestBase):
|
||||
dataset = readers.CsvDataset(filenames, **kwargs)
|
||||
self._verify_output_or_err(dataset, expected_output, expected_err_re)
|
||||
|
||||
@combinations.generate(test_base.default_test_combinations())
|
||||
def testCsvDataset_requiredFields(self):
|
||||
record_defaults = [[]] * 4
|
||||
inputs = [['1,2,3,4']]
|
||||
self._test_by_comparison(inputs, record_defaults=record_defaults)
|
||||
|
||||
@combinations.generate(test_base.default_test_combinations())
|
||||
def testCsvDataset_int(self):
|
||||
record_defaults = [[0]] * 4
|
||||
inputs = [['1,2,3,4', '5,6,7,8']]
|
||||
self._test_by_comparison(inputs, record_defaults=record_defaults)
|
||||
|
||||
@combinations.generate(test_base.default_test_combinations())
|
||||
def testCsvDataset_float(self):
|
||||
record_defaults = [[0.0]] * 4
|
||||
inputs = [['1.0,2.1,3.2,4.3', '5.4,6.5,7.6,8.7']]
|
||||
self._test_by_comparison(inputs, record_defaults=record_defaults)
|
||||
|
||||
@combinations.generate(test_base.default_test_combinations())
|
||||
def testCsvDataset_string(self):
|
||||
record_defaults = [['']] * 4
|
||||
inputs = [['1.0,2.1,hello,4.3', '5.4,6.5,goodbye,8.7']]
|
||||
self._test_by_comparison(inputs, record_defaults=record_defaults)
|
||||
|
||||
@combinations.generate(test_base.default_test_combinations())
|
||||
def testCsvDataset_withEmptyFields(self):
|
||||
record_defaults = [[0]] * 4
|
||||
inputs = [[',,,', '1,1,1,', ',2,2,2']]
|
||||
@ -144,6 +150,7 @@ class CsvDatasetTest(test_base.DatasetTestBase):
|
||||
inputs, [[0, 0, 0, 0], [1, 1, 1, 0], [0, 2, 2, 2]],
|
||||
record_defaults=record_defaults)
|
||||
|
||||
@combinations.generate(test_base.default_test_combinations())
|
||||
def testCsvDataset_errWithUnquotedQuotes(self):
|
||||
record_defaults = [['']] * 3
|
||||
inputs = [['1,2"3,4']]
|
||||
@ -152,6 +159,7 @@ class CsvDatasetTest(test_base.DatasetTestBase):
|
||||
expected_err_re='Unquoted fields cannot have quotes inside',
|
||||
record_defaults=record_defaults)
|
||||
|
||||
@combinations.generate(test_base.default_test_combinations())
|
||||
def testCsvDataset_errWithUnescapedQuotes(self):
|
||||
record_defaults = [['']] * 3
|
||||
inputs = [['"a"b","c","d"']]
|
||||
@ -161,6 +169,7 @@ class CsvDatasetTest(test_base.DatasetTestBase):
|
||||
'Quote inside a string has to be escaped by another quote',
|
||||
record_defaults=record_defaults)
|
||||
|
||||
@combinations.generate(test_base.default_test_combinations())
|
||||
def testCsvDataset_ignoreErrWithUnescapedQuotes(self):
|
||||
record_defaults = [['']] * 3
|
||||
inputs = [['1,"2"3",4', '1,"2"3",4",5,5', 'a,b,"c"d"', 'e,f,g']]
|
||||
@ -169,6 +178,7 @@ class CsvDatasetTest(test_base.DatasetTestBase):
|
||||
dataset = dataset.apply(error_ops.ignore_errors())
|
||||
self._verify_output_or_err(dataset, [['e', 'f', 'g']])
|
||||
|
||||
@combinations.generate(test_base.default_test_combinations())
|
||||
def testCsvDataset_ignoreErrWithUnquotedQuotes(self):
|
||||
record_defaults = [['']] * 3
|
||||
inputs = [['1,2"3,4', 'a,b,c"d', '9,8"7,6,5', 'e,f,g']]
|
||||
@ -177,12 +187,14 @@ class CsvDatasetTest(test_base.DatasetTestBase):
|
||||
dataset = dataset.apply(error_ops.ignore_errors())
|
||||
self._verify_output_or_err(dataset, [['e', 'f', 'g']])
|
||||
|
||||
@combinations.generate(test_base.default_test_combinations())
|
||||
def testCsvDataset_withNoQuoteDelimAndUnquotedQuotes(self):
|
||||
record_defaults = [['']] * 3
|
||||
inputs = [['1,2"3,4']]
|
||||
self._test_by_comparison(
|
||||
inputs, record_defaults=record_defaults, use_quote_delim=False)
|
||||
|
||||
@combinations.generate(test_base.default_test_combinations())
|
||||
def testCsvDataset_mixedTypes(self):
|
||||
record_defaults = [
|
||||
constant_op.constant([], dtype=dtypes.int32),
|
||||
@ -193,30 +205,35 @@ class CsvDatasetTest(test_base.DatasetTestBase):
|
||||
inputs = [['1,2.1,3.2,4.3', '5,6.5,7.6,8.7']]
|
||||
self._test_by_comparison(inputs, record_defaults=record_defaults)
|
||||
|
||||
@combinations.generate(test_base.default_test_combinations())
|
||||
def testCsvDataset_withUseQuoteDelimFalse(self):
|
||||
record_defaults = [['']] * 4
|
||||
inputs = [['1,2,"3,4"', '"5,6",7,8']]
|
||||
self._test_by_comparison(
|
||||
inputs, record_defaults=record_defaults, use_quote_delim=False)
|
||||
|
||||
@combinations.generate(test_base.default_test_combinations())
|
||||
def testCsvDataset_withFieldDelim(self):
|
||||
record_defaults = [[0]] * 4
|
||||
inputs = [['1:2:3:4', '5:6:7:8']]
|
||||
self._test_by_comparison(
|
||||
inputs, record_defaults=record_defaults, field_delim=':')
|
||||
|
||||
@combinations.generate(test_base.default_test_combinations())
|
||||
def testCsvDataset_withNaValue(self):
|
||||
record_defaults = [[0]] * 4
|
||||
inputs = [['1,NA,3,4', 'NA,6,7,8']]
|
||||
self._test_by_comparison(
|
||||
inputs, record_defaults=record_defaults, na_value='NA')
|
||||
|
||||
@combinations.generate(test_base.default_test_combinations())
|
||||
def testCsvDataset_withSelectCols(self):
|
||||
record_defaults = [['']] * 2
|
||||
inputs = [['1,2,3,4', '"5","6","7","8"']]
|
||||
self._test_by_comparison(
|
||||
inputs, record_defaults=record_defaults, select_cols=[1, 2])
|
||||
|
||||
@combinations.generate(test_base.default_test_combinations())
|
||||
def testCsvDataset_withSelectColsTooHigh(self):
|
||||
record_defaults = [[0]] * 2
|
||||
inputs = [['1,2,3,4', '5,6,7,8']]
|
||||
@ -226,23 +243,27 @@ class CsvDatasetTest(test_base.DatasetTestBase):
|
||||
record_defaults=record_defaults,
|
||||
select_cols=[3, 4])
|
||||
|
||||
@combinations.generate(test_base.default_test_combinations())
|
||||
def testCsvDataset_withOneCol(self):
|
||||
record_defaults = [['NA']]
|
||||
inputs = [['0', '', '2']]
|
||||
self._test_dataset(
|
||||
inputs, [['0'], ['NA'], ['2']], record_defaults=record_defaults)
|
||||
|
||||
@combinations.generate(test_base.default_test_combinations())
|
||||
def testCsvDataset_withMultipleFiles(self):
|
||||
record_defaults = [[0]] * 4
|
||||
inputs = [['1,2,3,4', '5,6,7,8'], ['5,6,7,8']]
|
||||
self._test_by_comparison(inputs, record_defaults=record_defaults)
|
||||
|
||||
@combinations.generate(test_base.default_test_combinations())
|
||||
def testCsvDataset_withLeadingAndTrailingSpaces(self):
|
||||
record_defaults = [[0.0]] * 4
|
||||
inputs = [['0, 1, 2, 3']]
|
||||
expected = [[0.0, 1.0, 2.0, 3.0]]
|
||||
self._test_dataset(inputs, expected, record_defaults=record_defaults)
|
||||
|
||||
@combinations.generate(test_base.default_test_combinations())
|
||||
def testCsvDataset_errorWithMissingDefault(self):
|
||||
record_defaults = [[]] * 2
|
||||
inputs = [['0,']]
|
||||
@ -251,6 +272,7 @@ class CsvDatasetTest(test_base.DatasetTestBase):
|
||||
expected_err_re='Field 1 is required but missing in record!',
|
||||
record_defaults=record_defaults)
|
||||
|
||||
@combinations.generate(test_base.default_test_combinations())
|
||||
def testCsvDataset_errorWithFewerDefaultsThanFields(self):
|
||||
record_defaults = [[0.0]] * 2
|
||||
inputs = [['0,1,2,3']]
|
||||
@ -259,6 +281,7 @@ class CsvDatasetTest(test_base.DatasetTestBase):
|
||||
expected_err_re='Expect 2 fields but have more in record',
|
||||
record_defaults=record_defaults)
|
||||
|
||||
@combinations.generate(test_base.default_test_combinations())
|
||||
def testCsvDataset_errorWithMoreDefaultsThanFields(self):
|
||||
record_defaults = [[0.0]] * 5
|
||||
inputs = [['0,1,2,3']]
|
||||
@ -267,6 +290,7 @@ class CsvDatasetTest(test_base.DatasetTestBase):
|
||||
expected_err_re='Expect 5 fields but have 4 in record',
|
||||
record_defaults=record_defaults)
|
||||
|
||||
@combinations.generate(test_base.default_test_combinations())
|
||||
def testCsvDataset_withHeader(self):
|
||||
record_defaults = [[0]] * 2
|
||||
inputs = [['col1,col2', '1,2']]
|
||||
@ -278,6 +302,7 @@ class CsvDatasetTest(test_base.DatasetTestBase):
|
||||
header=True,
|
||||
)
|
||||
|
||||
@combinations.generate(test_base.default_test_combinations())
|
||||
def testCsvDataset_withHeaderAndNoRecords(self):
|
||||
record_defaults = [[0]] * 2
|
||||
inputs = [['col1,col2']]
|
||||
@ -289,6 +314,7 @@ class CsvDatasetTest(test_base.DatasetTestBase):
|
||||
header=True,
|
||||
)
|
||||
|
||||
@combinations.generate(test_base.default_test_combinations())
|
||||
def testCsvDataset_errorWithHeaderEmptyFile(self):
|
||||
record_defaults = [[0]] * 2
|
||||
inputs = [[]]
|
||||
@ -300,12 +326,14 @@ class CsvDatasetTest(test_base.DatasetTestBase):
|
||||
header=True,
|
||||
)
|
||||
|
||||
@combinations.generate(test_base.default_test_combinations())
|
||||
def testCsvDataset_withEmptyFile(self):
|
||||
record_defaults = [['']] * 2
|
||||
inputs = [['']] # Empty file
|
||||
self._test_dataset(
|
||||
inputs, expected_output=[], record_defaults=record_defaults)
|
||||
|
||||
@combinations.generate(test_base.default_test_combinations())
|
||||
def testCsvDataset_errorWithEmptyRecord(self):
|
||||
record_defaults = [['']] * 2
|
||||
inputs = [['', '1,2']] # First record is empty
|
||||
@ -314,6 +342,7 @@ class CsvDatasetTest(test_base.DatasetTestBase):
|
||||
expected_err_re='Expect 2 fields but have 1 in record',
|
||||
record_defaults=record_defaults)
|
||||
|
||||
@combinations.generate(test_base.default_test_combinations())
|
||||
def testCsvDataset_withChainedOps(self):
|
||||
# Testing that one dataset can create multiple iterators fine.
|
||||
# `repeat` creates multiple iterators from the same C++ Dataset.
|
||||
@ -325,6 +354,7 @@ class CsvDatasetTest(test_base.DatasetTestBase):
|
||||
ds_actual.repeat(5).prefetch(1),
|
||||
ds_expected.repeat(5).prefetch(1))
|
||||
|
||||
@combinations.generate(test_base.default_test_combinations())
|
||||
def testCsvDataset_withTypeDefaults(self):
|
||||
# Testing using dtypes as record_defaults for required fields
|
||||
record_defaults = [dtypes.float32, [0.0]]
|
||||
@ -335,6 +365,7 @@ class CsvDatasetTest(test_base.DatasetTestBase):
|
||||
record_defaults=record_defaults,
|
||||
)
|
||||
|
||||
@combinations.generate(test_base.default_test_combinations())
|
||||
def testMakeCsvDataset_fieldOrder(self):
|
||||
data = [[
|
||||
'1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19',
|
||||
@ -352,6 +383,7 @@ class CsvDatasetTest(test_base.DatasetTestBase):
|
||||
|
||||
## The following tests exercise parsing logic for quoted fields
|
||||
|
||||
@combinations.generate(test_base.default_test_combinations())
|
||||
def testCsvDataset_withQuoted(self):
|
||||
record_defaults = [['']] * 4
|
||||
inputs = [['"a","b","c :)","d"', '"e","f","g :(","h"']]
|
||||
@ -363,6 +395,7 @@ class CsvDatasetTest(test_base.DatasetTestBase):
|
||||
self._test_dataset(
|
||||
inputs, [['0'], ['1'], ['2']], record_defaults=record_defaults)
|
||||
|
||||
@combinations.generate(test_base.default_test_combinations())
|
||||
def testCsvDataset_withNewLine(self):
|
||||
# In this case, we expect it to behave differently from
|
||||
# TextLineDataset->map(decode_csv) since that flow has bugs
|
||||
@ -371,6 +404,7 @@ class CsvDatasetTest(test_base.DatasetTestBase):
|
||||
expected = [['a', 'b', '"c"\n0', 'd\ne'], ['f', 'g', 'h', 'i']]
|
||||
self._test_dataset(inputs, expected, record_defaults=record_defaults)
|
||||
|
||||
@combinations.generate(test_base.default_test_combinations())
|
||||
def testCsvDataset_withNewLineInUnselectedCol(self):
|
||||
record_defaults = [['']]
|
||||
inputs = [['1,"2\n3",4', '5,6,7']]
|
||||
@ -380,6 +414,7 @@ class CsvDatasetTest(test_base.DatasetTestBase):
|
||||
record_defaults=record_defaults,
|
||||
select_cols=[0])
|
||||
|
||||
@combinations.generate(test_base.default_test_combinations())
|
||||
def testCsvDataset_withMultipleNewLines(self):
|
||||
# In this case, we expect it to behave differently from
|
||||
# TextLineDataset->map(decode_csv) since that flow has bugs
|
||||
@ -388,6 +423,7 @@ class CsvDatasetTest(test_base.DatasetTestBase):
|
||||
expected = [['a', 'b\n\nx', '"c"\n \n0', 'd\ne'], ['f', 'g', 'h', 'i']]
|
||||
self._test_dataset(inputs, expected, record_defaults=record_defaults)
|
||||
|
||||
@combinations.generate(test_base.default_test_combinations())
|
||||
def testCsvDataset_errorWithTerminateMidRecord(self):
|
||||
record_defaults = [['']] * 4
|
||||
inputs = [['a,b,c,"a']]
|
||||
@ -397,6 +433,7 @@ class CsvDatasetTest(test_base.DatasetTestBase):
|
||||
'Reached end of file without closing quoted field in record',
|
||||
record_defaults=record_defaults)
|
||||
|
||||
@combinations.generate(test_base.default_test_combinations())
|
||||
def testCsvDataset_withEscapedQuotes(self):
|
||||
record_defaults = [['']] * 4
|
||||
inputs = [['1.0,2.1,"she said: ""hello""",4.3', '5.4,6.5,goodbye,8.7']]
|
||||
@ -406,6 +443,7 @@ class CsvDatasetTest(test_base.DatasetTestBase):
|
||||
## Testing that parsing works with all buffer sizes, quoted/unquoted fields,
|
||||
## and different types of line breaks
|
||||
|
||||
@combinations.generate(test_base.default_test_combinations())
|
||||
def testCsvDataset_withInvalidBufferSize(self):
|
||||
record_defaults = [['']] * 4
|
||||
inputs = [['a,b,c,d']]
|
||||
@ -432,6 +470,7 @@ class CsvDatasetTest(test_base.DatasetTestBase):
|
||||
record_defaults=record_defaults,
|
||||
buffer_size=i)
|
||||
|
||||
@combinations.generate(test_base.default_test_combinations())
|
||||
def testCsvDataset_withLF(self):
|
||||
record_defaults = [['NA']] * 3
|
||||
inputs = [['abc,def,ghi', '0,1,2', ',,']]
|
||||
@ -439,6 +478,7 @@ class CsvDatasetTest(test_base.DatasetTestBase):
|
||||
self._test_dataset_on_buffer_sizes(
|
||||
inputs, expected, linebreak='\n', record_defaults=record_defaults)
|
||||
|
||||
@combinations.generate(test_base.default_test_combinations())
|
||||
def testCsvDataset_withCR(self):
|
||||
# Test that when the line separator is '\r', parsing works with all buffer
|
||||
# sizes
|
||||
@ -448,6 +488,7 @@ class CsvDatasetTest(test_base.DatasetTestBase):
|
||||
self._test_dataset_on_buffer_sizes(
|
||||
inputs, expected, linebreak='\r', record_defaults=record_defaults)
|
||||
|
||||
@combinations.generate(test_base.default_test_combinations())
|
||||
def testCsvDataset_withCRLF(self):
|
||||
# Test that when the line separator is '\r\n', parsing works with all buffer
|
||||
# sizes
|
||||
@ -457,6 +498,7 @@ class CsvDatasetTest(test_base.DatasetTestBase):
|
||||
self._test_dataset_on_buffer_sizes(
|
||||
inputs, expected, linebreak='\r\n', record_defaults=record_defaults)
|
||||
|
||||
@combinations.generate(test_base.default_test_combinations())
|
||||
def testCsvDataset_withBufferSizeAndQuoted(self):
|
||||
record_defaults = [['NA']] * 3
|
||||
inputs = [['"\n\n\n","\r\r\r","abc"', '"0","1","2"', '"","",""']]
|
||||
@ -465,6 +507,7 @@ class CsvDatasetTest(test_base.DatasetTestBase):
|
||||
self._test_dataset_on_buffer_sizes(
|
||||
inputs, expected, linebreak='\n', record_defaults=record_defaults)
|
||||
|
||||
@combinations.generate(test_base.default_test_combinations())
|
||||
def testCsvDataset_withCRAndQuoted(self):
|
||||
# Test that when the line separator is '\r', parsing works with all buffer
|
||||
# sizes
|
||||
@ -475,6 +518,7 @@ class CsvDatasetTest(test_base.DatasetTestBase):
|
||||
self._test_dataset_on_buffer_sizes(
|
||||
inputs, expected, linebreak='\r', record_defaults=record_defaults)
|
||||
|
||||
@combinations.generate(test_base.default_test_combinations())
|
||||
def testCsvDataset_withCRLFAndQuoted(self):
|
||||
# Test that when the line separator is '\r\n', parsing works with all buffer
|
||||
# sizes
|
||||
@ -485,6 +529,7 @@ class CsvDatasetTest(test_base.DatasetTestBase):
|
||||
self._test_dataset_on_buffer_sizes(
|
||||
inputs, expected, linebreak='\r\n', record_defaults=record_defaults)
|
||||
|
||||
@combinations.generate(test_base.default_test_combinations())
|
||||
def testCsvDataset_withGzipCompressionType(self):
|
||||
record_defaults = [['NA']] * 3
|
||||
inputs = [['"\n\n\n","\r\r\r","abc"', '"0","1","2"', '"","",""']]
|
||||
@ -497,6 +542,7 @@ class CsvDatasetTest(test_base.DatasetTestBase):
|
||||
compression_type='GZIP',
|
||||
record_defaults=record_defaults)
|
||||
|
||||
@combinations.generate(test_base.default_test_combinations())
|
||||
def testCsvDataset_withZlibCompressionType(self):
|
||||
record_defaults = [['NA']] * 3
|
||||
inputs = [['"\n\n\n","\r\r\r","abc"', '"0","1","2"', '"","",""']]
|
||||
@ -509,6 +555,7 @@ class CsvDatasetTest(test_base.DatasetTestBase):
|
||||
compression_type='ZLIB',
|
||||
record_defaults=record_defaults)
|
||||
|
||||
@combinations.generate(test_base.default_test_combinations())
|
||||
def testCsvDataset_withScalarDefaults(self):
|
||||
record_defaults = [constant_op.constant(0, dtype=dtypes.int64)] * 4
|
||||
inputs = [[',,,', '1,1,1,', ',2,2,2']]
|
||||
@ -516,6 +563,7 @@ class CsvDatasetTest(test_base.DatasetTestBase):
|
||||
inputs, [[0, 0, 0, 0], [1, 1, 1, 0], [0, 2, 2, 2]],
|
||||
record_defaults=record_defaults)
|
||||
|
||||
@combinations.generate(test_base.default_test_combinations())
|
||||
def testCsvDataset_with2DDefaults(self):
|
||||
record_defaults = [constant_op.constant([[0]], dtype=dtypes.int64)] * 4
|
||||
inputs = [[',,,', '1,1,1,', ',2,2,2']]
|
||||
|
@ -17,20 +17,21 @@ from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from absl.testing import parameterized
|
||||
import numpy as np
|
||||
|
||||
from tensorflow.python.data.experimental.ops import batching
|
||||
from tensorflow.python.data.kernel_tests import test_base
|
||||
from tensorflow.python.data.ops import dataset_ops
|
||||
from tensorflow.python.framework import combinations
|
||||
from tensorflow.python.framework import errors
|
||||
from tensorflow.python.framework import test_util
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.platform import test
|
||||
|
||||
|
||||
@test_util.run_all_in_graph_and_eager_modes
|
||||
class DenseToSparseBatchTest(test_base.DatasetTestBase):
|
||||
class DenseToSparseBatchTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
|
||||
@combinations.generate(test_base.default_test_combinations())
|
||||
def testDenseToSparseBatchDataset(self):
|
||||
components = np.random.randint(12, size=(100,)).astype(np.int32)
|
||||
dataset = dataset_ops.Dataset.from_tensor_slices(
|
||||
@ -53,6 +54,7 @@ class DenseToSparseBatchTest(test_base.DatasetTestBase):
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
self.evaluate(get_next())
|
||||
|
||||
@combinations.generate(test_base.default_test_combinations())
|
||||
def testDenseToSparseBatchDatasetWithUnknownShape(self):
|
||||
components = np.random.randint(5, size=(40,)).astype(np.int32)
|
||||
dataset = dataset_ops.Dataset.from_tensor_slices(
|
||||
@ -80,12 +82,14 @@ class DenseToSparseBatchTest(test_base.DatasetTestBase):
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
self.evaluate(get_next())
|
||||
|
||||
@combinations.generate(test_base.default_test_combinations())
|
||||
def testDenseToSparseBatchDatasetWithInvalidShape(self):
|
||||
input_tensor = array_ops.constant([[1]])
|
||||
with self.assertRaisesRegexp(ValueError, "Dimension -2 must be >= 0"):
|
||||
dataset_ops.Dataset.from_tensors(input_tensor).apply(
|
||||
batching.dense_to_sparse_batch(4, [-2]))
|
||||
|
||||
@combinations.generate(test_base.default_test_combinations())
|
||||
def testDenseToSparseBatchDatasetShapeErrors(self):
|
||||
|
||||
def dataset_fn(input_tensor):
|
||||
|
@ -17,22 +17,24 @@ from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from absl.testing import parameterized
|
||||
import numpy as np
|
||||
|
||||
from tensorflow.python.data.experimental.ops import interleave_ops
|
||||
from tensorflow.python.data.kernel_tests import test_base
|
||||
from tensorflow.python.data.ops import dataset_ops
|
||||
from tensorflow.python.framework import combinations
|
||||
from tensorflow.python.framework import constant_op
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import errors
|
||||
from tensorflow.python.framework import random_seed
|
||||
from tensorflow.python.framework import test_util
|
||||
from tensorflow.python.platform import test
|
||||
|
||||
|
||||
@test_util.run_all_in_graph_and_eager_modes
|
||||
class DirectedInterleaveDatasetTest(test_base.DatasetTestBase):
|
||||
class DirectedInterleaveDatasetTest(test_base.DatasetTestBase,
|
||||
parameterized.TestCase):
|
||||
|
||||
@combinations.generate(test_base.default_test_combinations())
|
||||
def testBasic(self):
|
||||
selector_dataset = dataset_ops.Dataset.range(10).repeat(100)
|
||||
input_datasets = [
|
||||
@ -76,6 +78,7 @@ class DirectedInterleaveDatasetTest(test_base.DatasetTestBase):
|
||||
|
||||
return freqs
|
||||
|
||||
@combinations.generate(test_base.default_test_combinations())
|
||||
def testSampleFromDatasets(self):
|
||||
random_seed.set_random_seed(1619)
|
||||
num_samples = 5000
|
||||
@ -95,6 +98,7 @@ class DirectedInterleaveDatasetTest(test_base.DatasetTestBase):
|
||||
freqs = self._testSampleFromDatasetsHelper(probs_ds, classes, num_samples)
|
||||
self.assertLess(self._chi2(probs, freqs / num_samples), 1e-2)
|
||||
|
||||
@combinations.generate(test_base.default_test_combinations())
|
||||
def testSelectFromDatasets(self):
|
||||
words = [b"foo", b"bar", b"baz"]
|
||||
datasets = [dataset_ops.Dataset.from_tensors(w).repeat() for w in words]
|
||||
@ -107,6 +111,7 @@ class DirectedInterleaveDatasetTest(test_base.DatasetTestBase):
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
self.evaluate(next_element())
|
||||
|
||||
@combinations.generate(test_base.default_test_combinations())
|
||||
def testErrors(self):
|
||||
with self.assertRaisesRegexp(ValueError,
|
||||
r"vector of length `len\(datasets\)`"):
|
||||
|
@ -23,25 +23,30 @@ from tensorflow.python.data.experimental.ops import get_single_element
|
||||
from tensorflow.python.data.kernel_tests import test_base
|
||||
from tensorflow.python.data.ops import dataset_ops
|
||||
from tensorflow.python.eager import function
|
||||
from tensorflow.python.framework import combinations
|
||||
from tensorflow.python.framework import errors
|
||||
from tensorflow.python.framework import sparse_tensor
|
||||
from tensorflow.python.framework import test_util
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import variables
|
||||
from tensorflow.python.platform import test
|
||||
|
||||
|
||||
@test_util.run_all_in_graph_and_eager_modes
|
||||
class GetSingleElementTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
|
||||
@parameterized.named_parameters(
|
||||
("Zero", 0, 1),
|
||||
("Five", 5, 1),
|
||||
("Ten", 10, 1),
|
||||
("Empty", 100, 1, errors.InvalidArgumentError, "Dataset was empty."),
|
||||
("MoreThanOne", 0, 2, errors.InvalidArgumentError,
|
||||
"Dataset had more than one element."),
|
||||
)
|
||||
@combinations.generate(
|
||||
combinations.times(
|
||||
test_base.default_test_combinations(),
|
||||
combinations.combine(
|
||||
skip=[0, 5, 10], take=[1], error=[None], error_msg=[None]) +
|
||||
combinations.combine(
|
||||
skip=[100],
|
||||
take=[1],
|
||||
error=[errors.InvalidArgumentError],
|
||||
error_msg=["Dataset was empty."]) + combinations.combine(
|
||||
skip=[0],
|
||||
take=[2],
|
||||
error=[errors.InvalidArgumentError],
|
||||
error_msg=["Dataset had more than one element."])))
|
||||
def testGetSingleElement(self, skip, take, error=None, error_msg=None):
|
||||
|
||||
def make_sparse(x):
|
||||
@ -62,6 +67,7 @@ class GetSingleElementTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
with self.assertRaisesRegexp(error, error_msg):
|
||||
self.evaluate(get_single_element.get_single_element(dataset))
|
||||
|
||||
@combinations.generate(test_base.default_test_combinations())
|
||||
def testWindow(self):
|
||||
"""Test that `get_single_element()` can consume a nested dataset."""
|
||||
def flat_map_func(ds):
|
||||
@ -73,6 +79,7 @@ class GetSingleElementTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
self.assertDatasetProduces(
|
||||
dataset, [[0, 1], [2, 3], [4, 5], [6, 7], [8, 9]])
|
||||
|
||||
@combinations.generate(test_base.default_test_combinations())
|
||||
def testSideEffect(self):
|
||||
counter_var = variables.Variable(0)
|
||||
|
||||
@ -92,6 +99,7 @@ class GetSingleElementTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
self.assertEqual(self.evaluate(fn()), b"hello")
|
||||
self.assertEqual(self.evaluate(counter_var), 1)
|
||||
|
||||
@combinations.generate(test_base.default_test_combinations())
|
||||
def testAutomaticControlDependencies(self):
|
||||
counter_var = variables.Variable(1)
|
||||
|
||||
|
@ -17,25 +17,26 @@ from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from absl.testing import parameterized
|
||||
import numpy as np
|
||||
|
||||
from tensorflow.python.data.experimental.ops import grouping
|
||||
from tensorflow.python.data.kernel_tests import test_base
|
||||
from tensorflow.python.data.ops import dataset_ops
|
||||
from tensorflow.python.framework import combinations
|
||||
from tensorflow.python.framework import constant_op
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import errors
|
||||
from tensorflow.python.framework import sparse_tensor
|
||||
from tensorflow.python.framework import tensor_shape
|
||||
from tensorflow.python.framework import test_util
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import math_ops
|
||||
from tensorflow.python.platform import test
|
||||
|
||||
|
||||
@test_util.run_all_in_graph_and_eager_modes
|
||||
class GroupByReducerTest(test_base.DatasetTestBase):
|
||||
class GroupByReducerTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
|
||||
@combinations.generate(test_base.default_test_combinations())
|
||||
def testSum(self):
|
||||
reducer = grouping.Reducer(
|
||||
init_func=lambda _: np.int64(0),
|
||||
@ -49,6 +50,7 @@ class GroupByReducerTest(test_base.DatasetTestBase):
|
||||
expected_shapes=tensor_shape.TensorShape([]),
|
||||
expected_output=[(i - 1) * i, i * i])
|
||||
|
||||
@combinations.generate(test_base.default_test_combinations())
|
||||
def testAverage(self):
|
||||
|
||||
def reduce_fn(x, y):
|
||||
@ -68,6 +70,7 @@ class GroupByReducerTest(test_base.DatasetTestBase):
|
||||
expected_shapes=tensor_shape.TensorShape([]),
|
||||
expected_output=[i - 1, i])
|
||||
|
||||
@combinations.generate(test_base.default_test_combinations())
|
||||
def testConcat(self):
|
||||
components = np.array(list("abcdefghijklmnopqrst")).view(np.chararray)
|
||||
reducer = grouping.Reducer(
|
||||
@ -84,6 +87,7 @@ class GroupByReducerTest(test_base.DatasetTestBase):
|
||||
expected_shapes=tensor_shape.TensorShape([]),
|
||||
expected_output=[b"acegikmoqs"[:i], b"bdfhjlnprt"[:i]])
|
||||
|
||||
@combinations.generate(test_base.default_test_combinations())
|
||||
def testSparseSum(self):
|
||||
def _sparse(i):
|
||||
return sparse_tensor.SparseTensorValue(
|
||||
@ -103,6 +107,7 @@ class GroupByReducerTest(test_base.DatasetTestBase):
|
||||
expected_shapes=tensor_shape.TensorShape([]),
|
||||
expected_output=[(i - 1) * i, i * i])
|
||||
|
||||
@combinations.generate(test_base.default_test_combinations())
|
||||
def testChangingStateShape(self):
|
||||
|
||||
def reduce_fn(x, _):
|
||||
@ -130,6 +135,7 @@ class GroupByReducerTest(test_base.DatasetTestBase):
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
self.evaluate(get_next())
|
||||
|
||||
@combinations.generate(test_base.default_test_combinations())
|
||||
def testTypeMismatch(self):
|
||||
reducer = grouping.Reducer(
|
||||
init_func=lambda x: constant_op.constant(1, dtype=dtypes.int32),
|
||||
@ -144,6 +150,7 @@ class GroupByReducerTest(test_base.DatasetTestBase):
|
||||
grouping.group_by_reducer(lambda _: np.int64(0), reducer))
|
||||
|
||||
# TODO(b/78665031): Remove once non-scalar keys are supported.
|
||||
@combinations.generate(test_base.default_test_combinations())
|
||||
def testInvalidKeyShape(self):
|
||||
reducer = grouping.Reducer(
|
||||
init_func=lambda x: np.int64(0),
|
||||
@ -157,6 +164,7 @@ class GroupByReducerTest(test_base.DatasetTestBase):
|
||||
grouping.group_by_reducer(lambda _: np.int64((0, 0)), reducer))
|
||||
|
||||
# TODO(b/78665031): Remove once non-int64 keys are supported.
|
||||
@combinations.generate(test_base.default_test_combinations())
|
||||
def testInvalidKeyType(self):
|
||||
reducer = grouping.Reducer(
|
||||
init_func=lambda x: np.int64(0),
|
||||
@ -169,6 +177,7 @@ class GroupByReducerTest(test_base.DatasetTestBase):
|
||||
dataset.apply(
|
||||
grouping.group_by_reducer(lambda _: "wrong", reducer))
|
||||
|
||||
@combinations.generate(test_base.default_test_combinations())
|
||||
def testTuple(self):
|
||||
def init_fn(_):
|
||||
return np.array([], dtype=np.int64), np.int64(0)
|
||||
|
@ -17,17 +17,18 @@ from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from absl.testing import parameterized
|
||||
import numpy as np
|
||||
|
||||
from tensorflow.python.data.experimental.ops import grouping
|
||||
from tensorflow.python.data.kernel_tests import test_base
|
||||
from tensorflow.python.data.ops import dataset_ops
|
||||
from tensorflow.python.framework import combinations
|
||||
from tensorflow.python.framework import constant_op
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import errors
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.framework import tensor_shape
|
||||
from tensorflow.python.framework import test_util
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import math_ops
|
||||
from tensorflow.python.ops import string_ops
|
||||
@ -37,8 +38,7 @@ from tensorflow.python.platform import test
|
||||
# NOTE(mrry): These tests are based on the tests in bucket_ops_test.py.
|
||||
# Currently, they use a constant batch size, though should be made to use a
|
||||
# different batch size per key.
|
||||
@test_util.run_all_in_graph_and_eager_modes
|
||||
class GroupByWindowTest(test_base.DatasetTestBase):
|
||||
class GroupByWindowTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
|
||||
def _dynamicPad(self, bucket, window, window_size):
|
||||
# TODO(mrry): To match `tf.contrib.training.bucket()`, implement a
|
||||
@ -51,6 +51,7 @@ class GroupByWindowTest(test_base.DatasetTestBase):
|
||||
32, (tensor_shape.TensorShape([]), tensor_shape.TensorShape(
|
||||
[None]), tensor_shape.TensorShape([3])))))
|
||||
|
||||
@combinations.generate(test_base.default_test_combinations())
|
||||
def testSingleBucket(self):
|
||||
|
||||
def _map_fn(v):
|
||||
@ -80,6 +81,7 @@ class GroupByWindowTest(test_base.DatasetTestBase):
|
||||
self.assertAllEqual(expected_unk_int64, bucketed_values[1])
|
||||
self.assertAllEqual(expected_vec3_str, bucketed_values[2])
|
||||
|
||||
@combinations.generate(test_base.default_test_combinations())
|
||||
def testEvenOddBuckets(self):
|
||||
|
||||
def _map_fn(v):
|
||||
@ -132,6 +134,7 @@ class GroupByWindowTest(test_base.DatasetTestBase):
|
||||
self.assertAllEqual(expected_unk_int64, bucketed_values_odd[1])
|
||||
self.assertAllEqual(expected_vec3_str, bucketed_values_odd[2])
|
||||
|
||||
@combinations.generate(test_base.default_test_combinations())
|
||||
def testEvenOddBucketsFilterOutAllOdd(self):
|
||||
|
||||
def _map_fn(v):
|
||||
@ -173,6 +176,7 @@ class GroupByWindowTest(test_base.DatasetTestBase):
|
||||
self.assertAllEqual(
|
||||
np.arange(64, 128, 2, dtype=np.int64), bucketed_values_even1["x"])
|
||||
|
||||
@combinations.generate(test_base.default_test_combinations())
|
||||
def testDynamicWindowSize(self):
|
||||
components = np.arange(100).astype(np.int64)
|
||||
|
||||
@ -202,6 +206,7 @@ class GroupByWindowTest(test_base.DatasetTestBase):
|
||||
|
||||
self.assertEqual(batches, 15)
|
||||
|
||||
@combinations.generate(test_base.default_test_combinations())
|
||||
def testSimple(self):
|
||||
components = np.random.randint(100, size=(200,)).astype(np.int64)
|
||||
dataset = dataset_ops.Dataset.from_tensor_slices(
|
||||
@ -222,6 +227,7 @@ class GroupByWindowTest(test_base.DatasetTestBase):
|
||||
self.assertGreaterEqual(num_full_batches, 24)
|
||||
self.assertTrue(all(c == 4 for c in counts[:num_full_batches]))
|
||||
|
||||
@combinations.generate(test_base.default_test_combinations())
|
||||
def testImmediateOutput(self):
|
||||
components = np.array(
|
||||
[0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 0, 0, 2, 2, 0, 0], dtype=np.int64)
|
||||
@ -240,6 +246,7 @@ class GroupByWindowTest(test_base.DatasetTestBase):
|
||||
self.assertAllEqual([2, 2, 2, 2], self.evaluate(get_next()))
|
||||
self.assertAllEqual([0, 0, 0, 0], self.evaluate(get_next()))
|
||||
|
||||
@combinations.generate(test_base.default_test_combinations())
|
||||
def testSmallGroups(self):
|
||||
components = np.array([0, 0, 0, 0, 1, 1, 1, 1, 1, 0, 0, 0], dtype=np.int64)
|
||||
dataset = dataset_ops.Dataset.from_tensor_slices(components).apply(
|
||||
@ -252,6 +259,7 @@ class GroupByWindowTest(test_base.DatasetTestBase):
|
||||
self.assertAllEqual([0, 0, 0], self.evaluate(get_next()))
|
||||
self.assertAllEqual([1], self.evaluate(get_next()))
|
||||
|
||||
@combinations.generate(test_base.default_test_combinations())
|
||||
def testEmpty(self):
|
||||
dataset = dataset_ops.Dataset.range(4).apply(
|
||||
grouping.group_by_window(lambda _: 0, lambda _, xs: xs, 0))
|
||||
@ -262,6 +270,7 @@ class GroupByWindowTest(test_base.DatasetTestBase):
|
||||
"Window size must be greater than zero, but got 0."):
|
||||
print(self.evaluate(get_next()))
|
||||
|
||||
@combinations.generate(test_base.default_test_combinations())
|
||||
def testReduceFuncError(self):
|
||||
components = np.random.randint(100, size=(200,)).astype(np.int64)
|
||||
|
||||
@ -280,6 +289,7 @@ class GroupByWindowTest(test_base.DatasetTestBase):
|
||||
with self.assertRaises(errors.InvalidArgumentError):
|
||||
self.evaluate(get_next())
|
||||
|
||||
@combinations.generate(test_base.default_test_combinations())
|
||||
def testConsumeWindowDatasetMoreThanOnce(self):
|
||||
components = np.random.randint(50, size=(200,)).astype(np.int64)
|
||||
|
||||
@ -311,6 +321,7 @@ class GroupByWindowTest(test_base.DatasetTestBase):
|
||||
counts.append(tight_result.shape[0])
|
||||
self.assertEqual(len(components), sum(counts))
|
||||
|
||||
@combinations.generate(test_base.default_test_combinations())
|
||||
def testShortCircuit(self):
|
||||
|
||||
dataset = dataset_ops.Dataset.range(10)
|
||||
|
@ -19,14 +19,15 @@ from __future__ import print_function
|
||||
|
||||
import os
|
||||
|
||||
from absl.testing import parameterized
|
||||
import numpy as np
|
||||
|
||||
from tensorflow.python.data.experimental.ops import error_ops
|
||||
from tensorflow.python.data.kernel_tests import test_base
|
||||
from tensorflow.python.data.ops import dataset_ops
|
||||
from tensorflow.python.data.ops import readers
|
||||
from tensorflow.python.framework import combinations
|
||||
from tensorflow.python.framework import errors
|
||||
from tensorflow.python.framework import test_util
|
||||
from tensorflow.python.lib.io import python_io
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import io_ops
|
||||
@ -36,9 +37,9 @@ from tensorflow.python.util import compat
|
||||
_NUMPY_RANDOM_SEED = 42
|
||||
|
||||
|
||||
@test_util.run_all_in_graph_and_eager_modes
|
||||
class IgnoreErrorsTest(test_base.DatasetTestBase):
|
||||
class IgnoreErrorsTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
|
||||
@combinations.generate(test_base.default_test_combinations())
|
||||
def testMapIgnoreError(self):
|
||||
components = np.array([1., 2., 3., np.nan, 5.]).astype(np.float32)
|
||||
|
||||
@ -53,6 +54,7 @@ class IgnoreErrorsTest(test_base.DatasetTestBase):
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
self.evaluate(get_next())
|
||||
|
||||
@combinations.generate(test_base.default_test_combinations())
|
||||
def testParallelMapIgnoreError(self):
|
||||
components = np.array([1., 2., 3., np.nan, 5.]).astype(np.float32)
|
||||
|
||||
@ -67,6 +69,7 @@ class IgnoreErrorsTest(test_base.DatasetTestBase):
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
self.evaluate(get_next())
|
||||
|
||||
@combinations.generate(test_base.default_test_combinations())
|
||||
def testReadFileIgnoreError(self):
|
||||
|
||||
def write_string_to_file(value, filename):
|
||||
@ -102,6 +105,7 @@ class IgnoreErrorsTest(test_base.DatasetTestBase):
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
self.evaluate(get_next())
|
||||
|
||||
@combinations.generate(test_base.default_test_combinations())
|
||||
def testTFRecordDatasetIgnoreError(self):
|
||||
filenames = []
|
||||
for i in range(5):
|
||||
@ -126,6 +130,7 @@ class IgnoreErrorsTest(test_base.DatasetTestBase):
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
self.evaluate(get_next())
|
||||
|
||||
@combinations.generate(test_base.default_test_combinations())
|
||||
def testZipIgnoreError(self):
|
||||
a = dataset_ops.Dataset.from_tensor_slices([1., 2., 0., 4.])
|
||||
b = a.map(lambda x: array_ops.check_numerics(1. / x, "error"))
|
||||
|
@ -17,26 +17,29 @@ from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from absl.testing import parameterized
|
||||
import numpy as np
|
||||
|
||||
from tensorflow.python.data.experimental.kernel_tests import reader_dataset_ops_test_base
|
||||
from tensorflow.python.data.experimental.ops import readers
|
||||
from tensorflow.python.data.kernel_tests import test_base
|
||||
from tensorflow.python.data.ops import dataset_ops
|
||||
from tensorflow.python.data.ops import readers as core_readers
|
||||
from tensorflow.python.data.util import nest
|
||||
from tensorflow.python.framework import combinations
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import errors
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.framework import test_util
|
||||
from tensorflow.python.ops import io_ops
|
||||
from tensorflow.python.ops import parsing_ops
|
||||
from tensorflow.python.platform import test
|
||||
|
||||
|
||||
@test_util.run_all_in_graph_and_eager_modes
|
||||
class MakeBatchedFeaturesDatasetTest(
|
||||
reader_dataset_ops_test_base.MakeBatchedFeaturesDatasetTestBase):
|
||||
reader_dataset_ops_test_base.MakeBatchedFeaturesDatasetTestBase,
|
||||
parameterized.TestCase):
|
||||
|
||||
@combinations.generate(test_base.default_test_combinations())
|
||||
def testRead(self):
|
||||
for batch_size in [1, 2]:
|
||||
for num_epochs in [1, 10]:
|
||||
@ -85,6 +88,7 @@ class MakeBatchedFeaturesDatasetTest(
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
self._next_actual_batch()
|
||||
|
||||
@combinations.generate(test_base.default_test_combinations())
|
||||
def testReadWithEquivalentDataset(self):
|
||||
features = {
|
||||
"file": parsing_ops.FixedLenFeature([], dtypes.int64),
|
||||
@ -103,6 +107,7 @@ class MakeBatchedFeaturesDatasetTest(
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
self.evaluate(next_element())
|
||||
|
||||
@combinations.generate(test_base.default_test_combinations())
|
||||
def testReadWithFusedShuffleRepeatDataset(self):
|
||||
num_epochs = 5
|
||||
total_records = num_epochs * self._num_records
|
||||
@ -151,6 +156,7 @@ class MakeBatchedFeaturesDatasetTest(
|
||||
all_equal = all_equal and np.array_equal(batch1[i], batch2[i])
|
||||
self.assertFalse(all_equal)
|
||||
|
||||
@combinations.generate(test_base.default_test_combinations())
|
||||
def testParallelReadersAndParsers(self):
|
||||
num_epochs = 5
|
||||
for batch_size in [1, 2]:
|
||||
@ -186,6 +192,7 @@ class MakeBatchedFeaturesDatasetTest(
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
self._next_actual_batch()
|
||||
|
||||
@combinations.generate(test_base.default_test_combinations())
|
||||
def testDropFinalBatch(self):
|
||||
for batch_size in [1, 2]:
|
||||
for num_epochs in [1, 10]:
|
||||
@ -201,6 +208,7 @@ class MakeBatchedFeaturesDatasetTest(
|
||||
if isinstance(tensor, ops.Tensor): # Guard against SparseTensor.
|
||||
self.assertEqual(tensor.shape[0], batch_size)
|
||||
|
||||
@combinations.generate(test_base.default_test_combinations())
|
||||
def testIndefiniteRepeatShapeInference(self):
|
||||
dataset = self.make_batch_feature(
|
||||
filenames=self.test_filenames[0],
|
||||
@ -213,6 +221,7 @@ class MakeBatchedFeaturesDatasetTest(
|
||||
if issubclass(clazz, ops.Tensor):
|
||||
self.assertEqual(32, shape[0])
|
||||
|
||||
@combinations.generate(test_base.default_test_combinations())
|
||||
def testOldStyleReader(self):
|
||||
with self.assertRaisesRegexp(
|
||||
TypeError, r"The `reader` argument must return a `Dataset` object. "
|
||||
|
@ -21,21 +21,21 @@ import gzip
|
||||
import os
|
||||
import zlib
|
||||
|
||||
from absl.testing import parameterized
|
||||
import numpy as np
|
||||
|
||||
from tensorflow.python.data.experimental.ops import readers
|
||||
from tensorflow.python.data.kernel_tests import test_base
|
||||
from tensorflow.python.data.ops import dataset_ops
|
||||
from tensorflow.python.data.util import nest
|
||||
from tensorflow.python.framework import combinations
|
||||
from tensorflow.python.framework import constant_op
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import errors
|
||||
from tensorflow.python.framework import test_util
|
||||
from tensorflow.python.platform import test
|
||||
|
||||
|
||||
@test_util.run_all_in_graph_and_eager_modes
|
||||
class MakeCsvDatasetTest(test_base.DatasetTestBase):
|
||||
class MakeCsvDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
|
||||
def _make_csv_dataset(self, filenames, batch_size, num_epochs=1, **kwargs):
|
||||
return readers.make_csv_dataset(
|
||||
@ -126,6 +126,7 @@ class MakeCsvDatasetTest(test_base.DatasetTestBase):
|
||||
self._verify_output(dataset, batch_size, num_epochs, label_name,
|
||||
expected_output, expected_keys)
|
||||
|
||||
@combinations.generate(test_base.default_test_combinations())
|
||||
def testMakeCSVDataset(self):
|
||||
"""Tests making a CSV dataset with keys and defaults provided."""
|
||||
record_defaults = [
|
||||
@ -157,6 +158,7 @@ class MakeCsvDatasetTest(test_base.DatasetTestBase):
|
||||
column_defaults=record_defaults,
|
||||
)
|
||||
|
||||
@combinations.generate(test_base.default_test_combinations())
|
||||
def testMakeCSVDataset_withBatchSizeAndEpochs(self):
|
||||
"""Tests making a CSV dataset with keys and defaults provided."""
|
||||
record_defaults = [
|
||||
@ -188,6 +190,7 @@ class MakeCsvDatasetTest(test_base.DatasetTestBase):
|
||||
column_defaults=record_defaults,
|
||||
)
|
||||
|
||||
@combinations.generate(test_base.default_test_combinations())
|
||||
def testMakeCSVDataset_withCompressionType(self):
|
||||
"""Tests `compression_type` argument."""
|
||||
record_defaults = [
|
||||
@ -221,6 +224,7 @@ class MakeCsvDatasetTest(test_base.DatasetTestBase):
|
||||
compression_type=compression_type,
|
||||
)
|
||||
|
||||
@combinations.generate(test_base.default_test_combinations())
|
||||
def testMakeCSVDataset_withCompressionTypeAndNoColumnNames(self):
|
||||
"""Tests `compression_type` argument."""
|
||||
record_defaults = [
|
||||
@ -269,6 +273,7 @@ class MakeCsvDatasetTest(test_base.DatasetTestBase):
|
||||
compression_type="ZLIB",
|
||||
)
|
||||
|
||||
@combinations.generate(test_base.default_test_combinations())
|
||||
def testMakeCSVDataset_withBadInputs(self):
|
||||
"""Tests that exception is raised when input is malformed.
|
||||
"""
|
||||
@ -304,6 +309,7 @@ class MakeCsvDatasetTest(test_base.DatasetTestBase):
|
||||
label_name="not_a_real_label",
|
||||
column_names=column_names)
|
||||
|
||||
@combinations.generate(test_base.default_test_combinations())
|
||||
def testMakeCSVDataset_withNoLabel(self):
|
||||
"""Tests making a CSV dataset with no label provided."""
|
||||
record_defaults = [
|
||||
@ -333,6 +339,7 @@ class MakeCsvDatasetTest(test_base.DatasetTestBase):
|
||||
column_defaults=record_defaults,
|
||||
)
|
||||
|
||||
@combinations.generate(test_base.default_test_combinations())
|
||||
def testMakeCSVDataset_withNoHeader(self):
|
||||
"""Tests that datasets can be created from CSV files with no header line.
|
||||
"""
|
||||
@ -363,6 +370,7 @@ class MakeCsvDatasetTest(test_base.DatasetTestBase):
|
||||
column_defaults=record_defaults,
|
||||
)
|
||||
|
||||
@combinations.generate(test_base.default_test_combinations())
|
||||
def testMakeCSVDataset_withTypes(self):
|
||||
"""Tests that defaults can be a dtype instead of a Tensor for required vals.
|
||||
"""
|
||||
@ -394,6 +402,7 @@ class MakeCsvDatasetTest(test_base.DatasetTestBase):
|
||||
column_defaults=record_defaults,
|
||||
)
|
||||
|
||||
@combinations.generate(test_base.default_test_combinations())
|
||||
def testMakeCSVDataset_withNoColNames(self):
|
||||
"""Tests that datasets can be created when column names are not specified.
|
||||
|
||||
@ -427,6 +436,7 @@ class MakeCsvDatasetTest(test_base.DatasetTestBase):
|
||||
column_defaults=record_defaults,
|
||||
)
|
||||
|
||||
@combinations.generate(test_base.default_test_combinations())
|
||||
def testMakeCSVDataset_withTypeInferenceMismatch(self):
|
||||
# Test that error is thrown when num fields doesn't match columns
|
||||
column_names = ["col%d" % i for i in range(5)]
|
||||
@ -442,6 +452,7 @@ class MakeCsvDatasetTest(test_base.DatasetTestBase):
|
||||
batch_size=2,
|
||||
num_epochs=10)
|
||||
|
||||
@combinations.generate(test_base.default_test_combinations())
|
||||
def testMakeCSVDataset_withTypeInference(self):
|
||||
"""Tests that datasets can be created when no defaults are specified.
|
||||
|
||||
@ -468,6 +479,7 @@ class MakeCsvDatasetTest(test_base.DatasetTestBase):
|
||||
header=True,
|
||||
)
|
||||
|
||||
@combinations.generate(test_base.default_test_combinations())
|
||||
def testMakeCSVDataset_withTypeInferenceFallthrough(self):
|
||||
"""Tests that datasets can be created when no defaults are specified.
|
||||
|
||||
@ -498,6 +510,7 @@ class MakeCsvDatasetTest(test_base.DatasetTestBase):
|
||||
header=True,
|
||||
)
|
||||
|
||||
@combinations.generate(test_base.default_test_combinations())
|
||||
def testMakeCSVDataset_withNAValuesAndFieldDelim(self):
|
||||
"""Tests that datasets can be created from different delim and na_value."""
|
||||
column_names = ["col%d" % i for i in range(5)]
|
||||
@ -520,6 +533,7 @@ class MakeCsvDatasetTest(test_base.DatasetTestBase):
|
||||
field_delim=" ",
|
||||
)
|
||||
|
||||
@combinations.generate(test_base.default_test_combinations())
|
||||
def testMakeCSVDataset_withSelectCols(self):
|
||||
record_defaults = [
|
||||
constant_op.constant([], dtypes.int32),
|
||||
@ -588,6 +602,7 @@ class MakeCsvDatasetTest(test_base.DatasetTestBase):
|
||||
select_columns=[column_names[i] for i in select_cols],
|
||||
)
|
||||
|
||||
@combinations.generate(test_base.default_test_combinations())
|
||||
def testMakeCSVDataset_withSelectColsError(self):
|
||||
record_defaults = [
|
||||
constant_op.constant([], dtypes.int32),
|
||||
@ -626,6 +641,7 @@ class MakeCsvDatasetTest(test_base.DatasetTestBase):
|
||||
label_name=None,
|
||||
select_columns=["invalid_col_name"])
|
||||
|
||||
@combinations.generate(test_base.default_test_combinations())
|
||||
def testMakeCSVDataset_withShuffle(self):
|
||||
record_defaults = [
|
||||
constant_op.constant([], dtypes.int32),
|
||||
@ -710,6 +726,7 @@ class MakeCsvDatasetTest(test_base.DatasetTestBase):
|
||||
all_equal = all_equal and np.array_equal(batch1[i], batch2[i])
|
||||
self.assertFalse(all_equal)
|
||||
|
||||
@combinations.generate(test_base.default_test_combinations())
|
||||
def testIndefiniteRepeatShapeInference(self):
|
||||
column_names = ["col%d" % i for i in range(5)]
|
||||
inputs = [[",".join(x for x in column_names), "0,1,2,3,4", "5,6,7,8,9"], [
|
||||
|
@ -17,19 +17,22 @@ from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from absl.testing import parameterized
|
||||
|
||||
from tensorflow.python.data.experimental.kernel_tests import reader_dataset_ops_test_base
|
||||
from tensorflow.python.data.experimental.ops import readers
|
||||
from tensorflow.python.data.kernel_tests import test_base
|
||||
from tensorflow.python.data.ops import dataset_ops
|
||||
from tensorflow.python.data.util import nest
|
||||
from tensorflow.python.framework import combinations
|
||||
from tensorflow.python.framework import errors
|
||||
from tensorflow.python.framework import test_util
|
||||
from tensorflow.python.ops import string_ops
|
||||
from tensorflow.python.platform import test
|
||||
|
||||
|
||||
@test_util.run_all_in_graph_and_eager_modes
|
||||
class MakeTFRecordDatasetTest(
|
||||
reader_dataset_ops_test_base.TFRecordDatasetTestBase):
|
||||
reader_dataset_ops_test_base.TFRecordDatasetTestBase,
|
||||
parameterized.TestCase):
|
||||
|
||||
def _read_test(self, batch_size, num_epochs, file_index=None,
|
||||
num_parallel_reads=1, drop_final_batch=False, parser_fn=False):
|
||||
@ -63,6 +66,7 @@ class MakeTFRecordDatasetTest(
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
self.evaluate(outputs())
|
||||
|
||||
@combinations.generate(test_base.default_test_combinations())
|
||||
def testRead(self):
|
||||
for batch_size in [1, 2]:
|
||||
for num_epochs in [1, 3]:
|
||||
@ -78,6 +82,7 @@ class MakeTFRecordDatasetTest(
|
||||
# Basic test: read from both files, with parallel reads.
|
||||
self._read_test(batch_size, num_epochs, num_parallel_reads=8)
|
||||
|
||||
@combinations.generate(test_base.default_test_combinations())
|
||||
def testDropFinalBatch(self):
|
||||
for batch_size in [1, 2, 10]:
|
||||
for num_epochs in [1, 3]:
|
||||
@ -91,6 +96,7 @@ class MakeTFRecordDatasetTest(
|
||||
self._read_test(batch_size, num_epochs, num_parallel_reads=8,
|
||||
drop_final_batch=True)
|
||||
|
||||
@combinations.generate(test_base.default_test_combinations())
|
||||
def testParserFn(self):
|
||||
for batch_size in [1, 2]:
|
||||
for num_epochs in [1, 3]:
|
||||
@ -145,6 +151,7 @@ class MakeTFRecordDatasetTest(
|
||||
actual.extend(b)
|
||||
self.assertAllEqual(sorted(expected), sorted(actual))
|
||||
|
||||
@combinations.generate(test_base.default_test_combinations())
|
||||
def testShuffle(self):
|
||||
for batch_size in [1, 2]:
|
||||
for num_epochs in [1, 3]:
|
||||
@ -156,6 +163,7 @@ class MakeTFRecordDatasetTest(
|
||||
self._shuffle_test(batch_size, num_epochs, num_parallel_reads,
|
||||
seed=21345)
|
||||
|
||||
@combinations.generate(test_base.default_test_combinations())
|
||||
def testIndefiniteRepeatShapeInference(self):
|
||||
dataset = readers.make_tf_record_dataset(
|
||||
file_pattern=self.test_filenames, num_epochs=None, batch_size=32)
|
||||
|
@ -19,17 +19,19 @@ from __future__ import print_function
|
||||
|
||||
import time
|
||||
|
||||
from absl.testing import parameterized
|
||||
|
||||
from tensorflow.python.client import session
|
||||
from tensorflow.python.data.experimental.ops import map_defun
|
||||
from tensorflow.python.data.kernel_tests import test_base
|
||||
from tensorflow.python.eager import function
|
||||
from tensorflow.python.framework import combinations
|
||||
from tensorflow.python.framework import constant_op
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import errors
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.framework import sparse_tensor
|
||||
from tensorflow.python.framework import tensor_spec
|
||||
from tensorflow.python.framework import test_util
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import check_ops
|
||||
from tensorflow.python.ops import data_flow_ops
|
||||
@ -38,9 +40,11 @@ from tensorflow.python.ops import sparse_ops
|
||||
from tensorflow.python.platform import test
|
||||
|
||||
|
||||
@test_util.run_v1_only("b/123903858: Add eager and V2 test coverage")
|
||||
class MapDefunTest(test_base.DatasetTestBase):
|
||||
# TODO(b/123903858): Add eager and V2 test coverage
|
||||
class MapDefunTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
|
||||
@combinations.generate(
|
||||
combinations.combine(tf_api_version=[1], mode=["graph"]))
|
||||
def testNoIntraOpLimit(self):
|
||||
|
||||
@function.defun(input_signature=[tensor_spec.TensorSpec([2], dtypes.int32)])
|
||||
@ -55,6 +59,8 @@ class MapDefunTest(test_base.DatasetTestBase):
|
||||
expected = elems * 2 + 3
|
||||
self.assertAllEqual(self.evaluate(r), self.evaluate(expected))
|
||||
|
||||
@combinations.generate(
|
||||
combinations.combine(tf_api_version=[1], mode=["graph"]))
|
||||
def testMapDefunSimple(self):
|
||||
|
||||
@function.defun(input_signature=[tensor_spec.TensorSpec([2], dtypes.int32)])
|
||||
@ -67,6 +73,8 @@ class MapDefunTest(test_base.DatasetTestBase):
|
||||
expected = elems * 2 + 3
|
||||
self.assertAllEqual(self.evaluate(r), self.evaluate(expected))
|
||||
|
||||
@combinations.generate(
|
||||
combinations.combine(tf_api_version=[1], mode=["graph"]))
|
||||
def testMapDefunMismatchedTypes(self):
|
||||
|
||||
@function.defun(input_signature=[tensor_spec.TensorSpec([], dtypes.int32)])
|
||||
@ -79,6 +87,8 @@ class MapDefunTest(test_base.DatasetTestBase):
|
||||
with self.assertRaises(errors.InvalidArgumentError):
|
||||
self.evaluate(r)
|
||||
|
||||
@combinations.generate(
|
||||
combinations.combine(tf_api_version=[1], mode=["graph"]))
|
||||
def testMapDefunReduceDim(self):
|
||||
# Tests where the output has a different rank from the input
|
||||
|
||||
@ -92,6 +102,8 @@ class MapDefunTest(test_base.DatasetTestBase):
|
||||
expected = constant_op.constant([1, 3, 5])
|
||||
self.assertAllEqual(self.evaluate(r), self.evaluate(expected))
|
||||
|
||||
@combinations.generate(
|
||||
combinations.combine(tf_api_version=[1], mode=["graph"]))
|
||||
def testMapDefunMultipleOutputs(self):
|
||||
|
||||
@function.defun(input_signature=[tensor_spec.TensorSpec([2], dtypes.int32)])
|
||||
@ -105,6 +117,8 @@ class MapDefunTest(test_base.DatasetTestBase):
|
||||
expected = [elems, elems * 2 + 3]
|
||||
self.assertAllEqual(self.evaluate(r), self.evaluate(expected))
|
||||
|
||||
@combinations.generate(
|
||||
combinations.combine(tf_api_version=[1], mode=["graph"]))
|
||||
def testMapDefunShapeInference(self):
|
||||
|
||||
@function.defun(input_signature=[tensor_spec.TensorSpec([2], dtypes.int32)])
|
||||
@ -116,6 +130,8 @@ class MapDefunTest(test_base.DatasetTestBase):
|
||||
result = map_defun.map_defun(fn, [elems], [dtypes.int32], [(2,)])[0]
|
||||
self.assertEqual(result.get_shape(), (3, 2))
|
||||
|
||||
@combinations.generate(
|
||||
combinations.combine(tf_api_version=[1], mode=["graph"]))
|
||||
def testMapDefunPartialShapeInference(self):
|
||||
|
||||
@function.defun(input_signature=[tensor_spec.TensorSpec([2], dtypes.int32)])
|
||||
@ -126,6 +142,8 @@ class MapDefunTest(test_base.DatasetTestBase):
|
||||
result = map_defun.map_defun(fn, [elems], [dtypes.int32], [(2,)])
|
||||
self.assertEqual(result[0].get_shape().as_list(), [None, 2])
|
||||
|
||||
@combinations.generate(
|
||||
combinations.combine(tf_api_version=[1], mode=["graph"]))
|
||||
def testMapDefunRaisesErrorOnRuntimeShapeMismatch(self):
|
||||
|
||||
@function.defun(input_signature=[
|
||||
@ -145,6 +163,8 @@ class MapDefunTest(test_base.DatasetTestBase):
|
||||
"All inputs must have the same dimension 0."):
|
||||
sess.run(result, feed_dict={elems1: [1, 2, 3, 4, 5], elems2: [1, 2, 3]})
|
||||
|
||||
@combinations.generate(
|
||||
combinations.combine(tf_api_version=[1], mode=["graph"]))
|
||||
def testMapDefunRaisesDefunError(self):
|
||||
|
||||
@function.defun(input_signature=[tensor_spec.TensorSpec([], dtypes.int32)])
|
||||
@ -157,6 +177,8 @@ class MapDefunTest(test_base.DatasetTestBase):
|
||||
with self.assertRaises(errors.InvalidArgumentError):
|
||||
self.evaluate(result)
|
||||
|
||||
@combinations.generate(
|
||||
combinations.combine(tf_api_version=[1], mode=["graph"]))
|
||||
def testMapDefunCancelledCorrectly(self):
|
||||
|
||||
@function.defun(input_signature=[tensor_spec.TensorSpec([5], dtypes.int64)])
|
||||
@ -173,6 +195,8 @@ class MapDefunTest(test_base.DatasetTestBase):
|
||||
r"indices = 10 is not in \[0, 5\)"):
|
||||
self.evaluate(map_defun_op)
|
||||
|
||||
@combinations.generate(
|
||||
combinations.combine(tf_api_version=[1], mode=["graph"]))
|
||||
def testMapDefunWithUnspecifiedOutputShape(self):
|
||||
|
||||
@function.defun(input_signature=[tensor_spec.TensorSpec([2], dtypes.int32)])
|
||||
@ -190,6 +214,8 @@ class MapDefunTest(test_base.DatasetTestBase):
|
||||
self.assertAllEqual(self.evaluate(r[1]), self.evaluate(expected + 1))
|
||||
self.assertAllEqual(self.evaluate(r[2]), self.evaluate(expected + 2))
|
||||
|
||||
@combinations.generate(
|
||||
combinations.combine(tf_api_version=[1], mode=["graph"]))
|
||||
def testMapDefunWithDifferentOutputShapeEachRun(self):
|
||||
|
||||
@function.defun(
|
||||
@ -204,6 +230,8 @@ class MapDefunTest(test_base.DatasetTestBase):
|
||||
self.assertAllEqual(
|
||||
sess.run(r, feed_dict={elems: [[0], [1]]}), [[3], [5]])
|
||||
|
||||
@combinations.generate(
|
||||
combinations.combine(tf_api_version=[1], mode=["graph"]))
|
||||
def testMapDefunWithWrongOutputShape(self):
|
||||
|
||||
@function.defun(input_signature=[tensor_spec.TensorSpec([2], dtypes.int32)])
|
||||
@ -216,6 +244,8 @@ class MapDefunTest(test_base.DatasetTestBase):
|
||||
with self.assertRaises(errors.InvalidArgumentError):
|
||||
self.evaluate(r)
|
||||
|
||||
@combinations.generate(
|
||||
combinations.combine(tf_api_version=[1], mode=["graph"]))
|
||||
def testMapDefunWithInvalidInput(self):
|
||||
|
||||
@function.defun(
|
||||
@ -233,6 +263,8 @@ class MapDefunTest(test_base.DatasetTestBase):
|
||||
with self.assertRaises(errors.InvalidArgumentError):
|
||||
sess.run(r, feed_dict={p: 0})
|
||||
|
||||
@combinations.generate(
|
||||
combinations.combine(tf_api_version=[1], mode=["graph"]))
|
||||
def testMapDefunWithParentCancellation(self):
|
||||
# Checks that a cancellation of the parent graph is threaded through to
|
||||
# MapDefunOp correctly.
|
||||
@ -254,6 +286,8 @@ class MapDefunTest(test_base.DatasetTestBase):
|
||||
sess.close()
|
||||
thread.join()
|
||||
|
||||
@combinations.generate(
|
||||
combinations.combine(tf_api_version=[1], mode=["graph"]))
|
||||
def testMapDefunWithCapturedInputs(self):
|
||||
c = constant_op.constant(2)
|
||||
|
||||
@ -266,6 +300,8 @@ class MapDefunTest(test_base.DatasetTestBase):
|
||||
expected = x + c
|
||||
self.assertAllEqual(self.evaluate(expected), self.evaluate(map_defun_op))
|
||||
|
||||
@combinations.generate(
|
||||
combinations.combine(tf_api_version=[1], mode=["graph"]))
|
||||
def testMapDefunWithVariantTensor(self):
|
||||
|
||||
@function.defun(
|
||||
@ -288,6 +324,8 @@ class MapDefunTest(test_base.DatasetTestBase):
|
||||
actual = self.evaluate(deserialized)
|
||||
self.assertValuesEqual(expected, actual)
|
||||
|
||||
@combinations.generate(
|
||||
combinations.combine(tf_api_version=[1], mode=["graph"]))
|
||||
def testMapDefunWithVariantTensorAsCaptured(self):
|
||||
|
||||
st = sparse_tensor.SparseTensor(
|
||||
@ -309,6 +347,8 @@ class MapDefunTest(test_base.DatasetTestBase):
|
||||
actual = self.evaluate(deserialized)
|
||||
self.assertValuesEqual(expected, actual)
|
||||
|
||||
@combinations.generate(
|
||||
combinations.combine(tf_api_version=[1], mode=["graph"]))
|
||||
def testMapDefunWithStrTensor(self):
|
||||
|
||||
@function.defun(input_signature=[tensor_spec.TensorSpec([], dtypes.string)])
|
||||
|
@ -28,14 +28,13 @@ from tensorflow.python.data.experimental.ops import threadpool
|
||||
from tensorflow.python.data.experimental.ops import unique
|
||||
from tensorflow.python.data.kernel_tests import test_base
|
||||
from tensorflow.python.data.ops import dataset_ops
|
||||
from tensorflow.python.framework import combinations
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import errors
|
||||
from tensorflow.python.framework import test_util
|
||||
from tensorflow.python.ops import script_ops
|
||||
from tensorflow.python.platform import test
|
||||
|
||||
|
||||
@test_util.run_all_in_graph_and_eager_modes
|
||||
class OverrideThreadpoolTest(test_base.DatasetTestBase,
|
||||
parameterized.TestCase):
|
||||
|
||||
@ -70,17 +69,13 @@ class OverrideThreadpoolTest(test_base.DatasetTestBase,
|
||||
# perform work.
|
||||
self.assertLessEqual(len(thread_ids), num_threads)
|
||||
|
||||
@parameterized.named_parameters(
|
||||
("1", 1, None),
|
||||
("2", 2, None),
|
||||
("3", 4, None),
|
||||
("4", 8, None),
|
||||
("5", 16, None),
|
||||
("6", 4, -1),
|
||||
("7", 4, 0),
|
||||
("8", 4, 1),
|
||||
("9", 4, 4),
|
||||
)
|
||||
@combinations.generate(
|
||||
combinations.times(
|
||||
test_base.default_test_combinations(),
|
||||
combinations.combine(
|
||||
num_threads=[1, 2, 4, 8, 16], max_intra_op_parallelism=[None]) +
|
||||
combinations.combine(
|
||||
num_threads=[4], max_intra_op_parallelism=[-1, 0, 4])))
|
||||
def testNumThreadsDeprecated(self, num_threads, max_intra_op_parallelism):
|
||||
|
||||
def override_threadpool_fn(dataset):
|
||||
@ -93,20 +88,17 @@ class OverrideThreadpoolTest(test_base.DatasetTestBase,
|
||||
|
||||
self._testNumThreadsHelper(num_threads, override_threadpool_fn)
|
||||
|
||||
@parameterized.named_parameters(
|
||||
("1", 1, None),
|
||||
("2", 2, None),
|
||||
("3", 4, None),
|
||||
("4", 8, None),
|
||||
("5", 16, None),
|
||||
("6", None, 0),
|
||||
("7", None, 1),
|
||||
("8", None, 4),
|
||||
("9", 4, 0),
|
||||
("10", 4, 1),
|
||||
("11", 4, 4),
|
||||
("12", None, None),
|
||||
)
|
||||
@combinations.generate(
|
||||
combinations.times(
|
||||
test_base.default_test_combinations(),
|
||||
combinations.combine(
|
||||
num_threads=[1, 2, 4, 8, 16], max_intra_op_parallelism=[None]) +
|
||||
combinations.combine(
|
||||
num_threads=[None], max_intra_op_parallelism=[0, 1, 4]) +
|
||||
combinations.combine(
|
||||
num_threads=[4], max_intra_op_parallelism=[0, 1, 4]) +
|
||||
combinations.combine(
|
||||
num_threads=[None], max_intra_op_parallelism=[None])))
|
||||
def testNumThreads(self, num_threads, max_intra_op_parallelism):
|
||||
|
||||
def override_threadpool_fn(dataset):
|
||||
@ -121,6 +113,7 @@ class OverrideThreadpoolTest(test_base.DatasetTestBase,
|
||||
|
||||
self._testNumThreadsHelper(num_threads, override_threadpool_fn)
|
||||
|
||||
@combinations.generate(test_base.default_test_combinations())
|
||||
def testMaxIntraOpParallelismAsGraphDefInternal(self):
|
||||
dataset = dataset_ops.Dataset.from_tensors(0)
|
||||
dataset = dataset_ops._MaxIntraOpParallelismDataset(dataset, 1)
|
||||
|
@ -22,24 +22,25 @@ import math
|
||||
import threading
|
||||
import time
|
||||
|
||||
from absl.testing import parameterized
|
||||
import numpy as np
|
||||
from six.moves import zip_longest
|
||||
|
||||
from tensorflow.python.data.experimental.ops import interleave_ops
|
||||
from tensorflow.python.data.kernel_tests import test_base
|
||||
from tensorflow.python.data.ops import dataset_ops
|
||||
from tensorflow.python.framework import combinations
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import errors
|
||||
from tensorflow.python.framework import sparse_tensor
|
||||
from tensorflow.python.framework import test_util
|
||||
from tensorflow.python.ops import math_ops
|
||||
from tensorflow.python.ops import script_ops
|
||||
from tensorflow.python.ops import sparse_ops
|
||||
from tensorflow.python.platform import test
|
||||
|
||||
|
||||
@test_util.run_all_in_graph_and_eager_modes
|
||||
class ParallelInterleaveTest(test_base.DatasetTestBase):
|
||||
# TODO(feihugis): refactor this test to be parameterized.
|
||||
class ParallelInterleaveTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
|
||||
def setUp(self):
|
||||
|
||||
@ -116,6 +117,7 @@ class ParallelInterleaveTest(test_base.DatasetTestBase):
|
||||
num_open -= 1
|
||||
break
|
||||
|
||||
@combinations.generate(test_base.default_test_combinations())
|
||||
def testPythonImplementation(self):
|
||||
input_lists = [[4, 4, 4, 4], [5, 5, 5, 5, 5], [6, 6, 6, 6, 6, 6],
|
||||
[4, 4, 4, 4], [5, 5, 5, 5, 5], [6, 6, 6, 6, 6, 6]]
|
||||
@ -136,6 +138,7 @@ class ParallelInterleaveTest(test_base.DatasetTestBase):
|
||||
self.assertEqual(expected, produced, "Values differ at %s. %s != %s" %
|
||||
(index, expected, produced))
|
||||
|
||||
@combinations.generate(test_base.default_test_combinations())
|
||||
def testPythonImplementationBlockLength(self):
|
||||
input_lists = [[4] * 4, [5] * 5, [6] * 6] * 2
|
||||
expected_elements = [
|
||||
@ -147,6 +150,7 @@ class ParallelInterleaveTest(test_base.DatasetTestBase):
|
||||
self.assertEqual(expected, produced, "Values differ at %s. %s != %s" %
|
||||
(index, expected, produced))
|
||||
|
||||
@combinations.generate(test_base.default_test_combinations())
|
||||
def testPythonImplementationEmptyLists(self):
|
||||
input_lists = [[4, 4, 4, 4], [], [6, 6, 6, 6, 6, 6], [4, 4, 4, 4], [],
|
||||
[6, 6, 6, 6, 6, 6]]
|
||||
@ -189,18 +193,23 @@ class ParallelInterleaveTest(test_base.DatasetTestBase):
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
self.evaluate(next_element())
|
||||
|
||||
@combinations.generate(test_base.default_test_combinations())
|
||||
def testSingleThreaded(self):
|
||||
self._testSingleThreaded()
|
||||
|
||||
@combinations.generate(test_base.default_test_combinations())
|
||||
def testSingleThreadedSloppy(self):
|
||||
self._testSingleThreaded(sloppy=True)
|
||||
|
||||
@combinations.generate(test_base.default_test_combinations())
|
||||
def testSingleThreadedPrefetch1Itr(self):
|
||||
self._testSingleThreaded(prefetch_input_elements=1)
|
||||
|
||||
@combinations.generate(test_base.default_test_combinations())
|
||||
def testSingleThreadedPrefetch1ItrSloppy(self):
|
||||
self._testSingleThreaded(prefetch_input_elements=1, sloppy=True)
|
||||
|
||||
@combinations.generate(test_base.default_test_combinations())
|
||||
def testSingleThreadedRagged(self):
|
||||
# Tests a sequence with wildly different elements per iterator.
|
||||
self.skipTest("b/131722904")
|
||||
@ -259,9 +268,11 @@ class ParallelInterleaveTest(test_base.DatasetTestBase):
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
self.evaluate(next_element())
|
||||
|
||||
@combinations.generate(test_base.default_test_combinations())
|
||||
def testTwoThreadsNoContention(self):
|
||||
self._testTwoThreadsNoContention()
|
||||
|
||||
@combinations.generate(test_base.default_test_combinations())
|
||||
def testTwoThreadsNoContentionSloppy(self):
|
||||
self._testTwoThreadsNoContention(sloppy=True)
|
||||
|
||||
@ -306,9 +317,11 @@ class ParallelInterleaveTest(test_base.DatasetTestBase):
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
self.evaluate(next_element())
|
||||
|
||||
@combinations.generate(test_base.default_test_combinations())
|
||||
def testTwoThreadsNoContentionWithRaces(self):
|
||||
self._testTwoThreadsNoContentionWithRaces()
|
||||
|
||||
@combinations.generate(test_base.default_test_combinations())
|
||||
def testTwoThreadsNoContentionWithRacesSloppy(self):
|
||||
self._testTwoThreadsNoContentionWithRaces(sloppy=True)
|
||||
|
||||
@ -343,9 +356,11 @@ class ParallelInterleaveTest(test_base.DatasetTestBase):
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
self.evaluate(next_element())
|
||||
|
||||
@combinations.generate(test_base.default_test_combinations())
|
||||
def testTwoThreadsNoContentionBlockLength(self):
|
||||
self._testTwoThreadsNoContentionBlockLength()
|
||||
|
||||
@combinations.generate(test_base.default_test_combinations())
|
||||
def testTwoThreadsNoContentionBlockLengthSloppy(self):
|
||||
self._testTwoThreadsNoContentionBlockLength(sloppy=True)
|
||||
|
||||
@ -391,9 +406,11 @@ class ParallelInterleaveTest(test_base.DatasetTestBase):
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
self.evaluate(next_element())
|
||||
|
||||
@combinations.generate(test_base.default_test_combinations())
|
||||
def testTwoThreadsNoContentionWithRacesAndBlocking(self):
|
||||
self._testTwoThreadsNoContentionWithRacesAndBlocking()
|
||||
|
||||
@combinations.generate(test_base.default_test_combinations())
|
||||
def testTwoThreadsNoContentionWithRacesAndBlockingSloppy(self):
|
||||
self._testTwoThreadsNoContentionWithRacesAndBlocking(sloppy=True)
|
||||
|
||||
@ -411,9 +428,11 @@ class ParallelInterleaveTest(test_base.DatasetTestBase):
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
self.evaluate(next_element())
|
||||
|
||||
@combinations.generate(test_base.default_test_combinations())
|
||||
def testEmptyInput(self):
|
||||
self._testEmptyInput()
|
||||
|
||||
@combinations.generate(test_base.default_test_combinations())
|
||||
def testEmptyInputSloppy(self):
|
||||
self._testEmptyInput(sloppy=True)
|
||||
|
||||
@ -431,9 +450,11 @@ class ParallelInterleaveTest(test_base.DatasetTestBase):
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
self.evaluate(next_element())
|
||||
|
||||
@combinations.generate(test_base.default_test_combinations())
|
||||
def testNonEmptyInputIntoEmptyOutputs(self):
|
||||
self._testNonEmptyInputIntoEmptyOutputs()
|
||||
|
||||
@combinations.generate(test_base.default_test_combinations())
|
||||
def testNonEmptyInputIntoEmptyOutputsSloppy(self):
|
||||
self._testNonEmptyInputIntoEmptyOutputs(sloppy=True)
|
||||
|
||||
@ -469,12 +490,15 @@ class ParallelInterleaveTest(test_base.DatasetTestBase):
|
||||
"At index %s: %s expected, got: %s" % (i, expected_element,
|
||||
actual_element))
|
||||
|
||||
@combinations.generate(test_base.default_test_combinations())
|
||||
def testPartiallyEmptyOutputs(self):
|
||||
self._testPartiallyEmptyOutputs()
|
||||
|
||||
@combinations.generate(test_base.default_test_combinations())
|
||||
def testPartiallyEmptyOutputsSloppy(self):
|
||||
self._testPartiallyEmptyOutputs(sloppy=True, prefetch_input_elements=0)
|
||||
|
||||
@combinations.generate(test_base.default_test_combinations())
|
||||
def testDelayedOutputSloppy(self):
|
||||
# Explicitly control the sequence of events to ensure we correctly avoid
|
||||
# head-of-line blocking.
|
||||
@ -500,6 +524,7 @@ class ParallelInterleaveTest(test_base.DatasetTestBase):
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
self.evaluate(next_element())
|
||||
|
||||
@combinations.generate(test_base.default_test_combinations())
|
||||
def testBlockLengthWithContentionSloppy(self):
|
||||
self.skipTest("b/131722904")
|
||||
self._clear_coordination_events()
|
||||
@ -557,9 +582,11 @@ class ParallelInterleaveTest(test_base.DatasetTestBase):
|
||||
self.read_coordination_events[i].acquire()
|
||||
self.write_coordination_events[i].set()
|
||||
|
||||
@combinations.generate(test_base.default_test_combinations())
|
||||
def testEarlyExit(self):
|
||||
self._testEarlyExit()
|
||||
|
||||
@combinations.generate(test_base.default_test_combinations())
|
||||
def testEarlyExitSloppy(self):
|
||||
self._testEarlyExit(sloppy=True)
|
||||
|
||||
@ -584,12 +611,15 @@ class ParallelInterleaveTest(test_base.DatasetTestBase):
|
||||
[[4] * 4, [5] * 5, [6] * 6] * self.repeat_count, 1, 2)
|
||||
self.assertItemsEqual(output_values, expected_values)
|
||||
|
||||
@combinations.generate(test_base.default_test_combinations())
|
||||
def testTooManyReaders(self):
|
||||
self._testTooManyReaders()
|
||||
|
||||
@combinations.generate(test_base.default_test_combinations())
|
||||
def testTooManyReadersSloppy(self):
|
||||
self._testTooManyReaders(sloppy=True)
|
||||
|
||||
@combinations.generate(test_base.default_test_combinations())
|
||||
def testSparse(self):
|
||||
def _map_fn(i):
|
||||
return sparse_tensor.SparseTensor(
|
||||
@ -610,6 +640,7 @@ class ParallelInterleaveTest(test_base.DatasetTestBase):
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
self.evaluate(get_next())
|
||||
|
||||
@combinations.generate(test_base.default_test_combinations())
|
||||
def testErrorsInOutputFn(self):
|
||||
self.skipTest("b/131722904")
|
||||
self._clear_coordination_events()
|
||||
@ -642,6 +673,7 @@ class ParallelInterleaveTest(test_base.DatasetTestBase):
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
self.evaluate(next_element())
|
||||
|
||||
@combinations.generate(test_base.default_test_combinations())
|
||||
def testErrorsInInputFn(self):
|
||||
|
||||
def map_py_fn(x):
|
||||
@ -687,6 +719,7 @@ class ParallelInterleaveTest(test_base.DatasetTestBase):
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
self.evaluate(next_element())
|
||||
|
||||
@combinations.generate(test_base.default_test_combinations())
|
||||
def testErrorsInInterleaveFn(self):
|
||||
|
||||
def map_py_fn(x):
|
||||
@ -730,6 +763,7 @@ class ParallelInterleaveTest(test_base.DatasetTestBase):
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
self.evaluate(next_element())
|
||||
|
||||
@combinations.generate(test_base.default_test_combinations())
|
||||
def testShutdownRace(self):
|
||||
dataset = dataset_ops.Dataset.range(20)
|
||||
map_fn = lambda x: dataset_ops.Dataset.range(20 * x, 20 * (x + 1))
|
||||
|
@ -20,6 +20,7 @@ from __future__ import print_function
|
||||
|
||||
import copy
|
||||
|
||||
from absl.testing import parameterized
|
||||
import numpy as np
|
||||
|
||||
from tensorflow.core.example import example_pb2
|
||||
@ -28,11 +29,11 @@ from tensorflow.python.data.experimental.ops import parsing_ops as contrib_parsi
|
||||
from tensorflow.python.data.kernel_tests import test_base
|
||||
from tensorflow.python.data.ops import dataset_ops
|
||||
from tensorflow.python.eager import context
|
||||
from tensorflow.python.framework import combinations
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import errors_impl
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.framework import sparse_tensor
|
||||
from tensorflow.python.framework import test_util
|
||||
from tensorflow.python.ops import parsing_ops
|
||||
from tensorflow.python.ops.ragged import ragged_factory_ops
|
||||
from tensorflow.python.platform import test
|
||||
@ -50,8 +51,8 @@ feature_lists = lambda d: feature_pb2.FeatureLists(feature_list=d)
|
||||
sequence_example = example_pb2.SequenceExample
|
||||
|
||||
|
||||
@test_util.run_all_in_graph_and_eager_modes
|
||||
class ParseExampleDatasetTest(test_base.DatasetTestBase):
|
||||
class ParseExampleDatasetTest(test_base.DatasetTestBase,
|
||||
parameterized.TestCase):
|
||||
|
||||
def _compare_output_to_expected(self, dict_tensors, expected_tensors):
|
||||
self.assertEqual(set(dict_tensors.keys()), set(expected_tensors.keys()))
|
||||
@ -107,6 +108,7 @@ class ParseExampleDatasetTest(test_base.DatasetTestBase):
|
||||
self.assertEqual(
|
||||
dataset_ops.get_legacy_output_shapes(dataset)[k].as_list()[1], None)
|
||||
|
||||
@combinations.generate(test_base.default_test_combinations())
|
||||
def testEmptySerializedWithAllDefaults(self):
|
||||
sparse_name = "st_a"
|
||||
a_name = "a"
|
||||
@ -145,7 +147,7 @@ class ParseExampleDatasetTest(test_base.DatasetTestBase):
|
||||
expected_values=expected_output,
|
||||
create_iterator_twice=True)
|
||||
|
||||
@test_util.run_deprecated_v1
|
||||
@combinations.generate(test_base.graph_only_combinations())
|
||||
def testEmptySerializedWithoutDefaultsShouldFail(self):
|
||||
input_features = {
|
||||
"st_a":
|
||||
@ -179,7 +181,7 @@ class ParseExampleDatasetTest(test_base.DatasetTestBase):
|
||||
expected_err=(errors_impl.InvalidArgumentError,
|
||||
"Feature: c \\(data type: float\\) is required"))
|
||||
|
||||
@test_util.run_deprecated_v1
|
||||
@combinations.generate(test_base.graph_only_combinations())
|
||||
def testDenseNotMatchingShapeShouldFail(self):
|
||||
original = [
|
||||
example(features=features({
|
||||
@ -197,6 +199,7 @@ class ParseExampleDatasetTest(test_base.DatasetTestBase):
|
||||
expected_err=(errors_impl.InvalidArgumentError,
|
||||
"Key: a, Index: 1. Number of float values"))
|
||||
|
||||
@combinations.generate(test_base.default_test_combinations())
|
||||
def testDenseDefaultNoShapeShouldFail(self):
|
||||
original = [example(features=features({"a": float_feature([1, 1, 3]),})),]
|
||||
|
||||
@ -207,6 +210,7 @@ class ParseExampleDatasetTest(test_base.DatasetTestBase):
|
||||
{"a": parsing_ops.FixedLenFeature(None, dtypes.float32)},
|
||||
expected_err=(ValueError, "Missing shape for feature a"))
|
||||
|
||||
@combinations.generate(test_base.default_test_combinations())
|
||||
def testSerializedContainingSparse(self):
|
||||
original = [
|
||||
example(features=features({
|
||||
@ -248,6 +252,7 @@ class ParseExampleDatasetTest(test_base.DatasetTestBase):
|
||||
expected_values=expected_output,
|
||||
create_iterator_twice=True)
|
||||
|
||||
@combinations.generate(test_base.default_test_combinations())
|
||||
def testSerializedContainingSparseFeature(self):
|
||||
original = [
|
||||
example(features=features({
|
||||
@ -284,6 +289,7 @@ class ParseExampleDatasetTest(test_base.DatasetTestBase):
|
||||
expected_values=expected_output,
|
||||
create_iterator_twice=True)
|
||||
|
||||
@combinations.generate(test_base.default_test_combinations())
|
||||
def testSerializedContainingSparseFeatureReuse(self):
|
||||
original = [
|
||||
example(features=features({
|
||||
@ -325,6 +331,7 @@ class ParseExampleDatasetTest(test_base.DatasetTestBase):
|
||||
expected_values=expected_output,
|
||||
create_iterator_twice=True)
|
||||
|
||||
@combinations.generate(test_base.default_test_combinations())
|
||||
def testSerializedContaining3DSparseFeature(self):
|
||||
original = [
|
||||
example(features=features({
|
||||
@ -370,6 +377,7 @@ class ParseExampleDatasetTest(test_base.DatasetTestBase):
|
||||
expected_values=expected_output,
|
||||
create_iterator_twice=True)
|
||||
|
||||
@combinations.generate(test_base.default_test_combinations())
|
||||
def testSerializedContainingDense(self):
|
||||
aname = "a"
|
||||
bname = "b*has+a:tricky_name"
|
||||
@ -407,6 +415,7 @@ class ParseExampleDatasetTest(test_base.DatasetTestBase):
|
||||
|
||||
# This test is identical as the previous one except
|
||||
# for the creation of 'serialized'.
|
||||
@combinations.generate(test_base.default_test_combinations())
|
||||
def testSerializedContainingDenseWithConcat(self):
|
||||
aname = "a"
|
||||
bname = "b*has+a:tricky_name"
|
||||
@ -452,6 +461,7 @@ class ParseExampleDatasetTest(test_base.DatasetTestBase):
|
||||
expected_values=expected_output,
|
||||
create_iterator_twice=True)
|
||||
|
||||
@combinations.generate(test_base.default_test_combinations())
|
||||
def testSerializedContainingDenseScalar(self):
|
||||
original = [
|
||||
example(features=features({
|
||||
@ -476,6 +486,7 @@ class ParseExampleDatasetTest(test_base.DatasetTestBase):
|
||||
expected_values=expected_output,
|
||||
create_iterator_twice=True)
|
||||
|
||||
@combinations.generate(test_base.default_test_combinations())
|
||||
def testSerializedContainingDenseWithDefaults(self):
|
||||
original = [
|
||||
example(features=features({
|
||||
@ -514,6 +525,7 @@ class ParseExampleDatasetTest(test_base.DatasetTestBase):
|
||||
expected_values=expected_output,
|
||||
create_iterator_twice=True)
|
||||
|
||||
@combinations.generate(test_base.default_test_combinations())
|
||||
def testSerializedSparseAndSparseFeatureAndDenseWithNoDefault(self):
|
||||
expected_st_a = sparse_tensor.SparseTensorValue( # indices, values, shape
|
||||
np.empty((0, 2), dtype=np.int64), # indices
|
||||
@ -569,6 +581,7 @@ class ParseExampleDatasetTest(test_base.DatasetTestBase):
|
||||
expected_values=expected_output,
|
||||
create_iterator_twice=True)
|
||||
|
||||
@combinations.generate(test_base.default_test_combinations())
|
||||
def testerializedContainingSparseAndSparseFeatureWithReuse(self):
|
||||
expected_idx = sparse_tensor.SparseTensorValue( # indices, values, shape
|
||||
np.array([[0, 0], [0, 1], [1, 0], [1, 1]], dtype=np.int64),
|
||||
@ -667,11 +680,13 @@ class ParseExampleDatasetTest(test_base.DatasetTestBase):
|
||||
expected_values=expected_output,
|
||||
create_iterator_twice=True)
|
||||
|
||||
@combinations.generate(test_base.default_test_combinations())
|
||||
def testSerializedContainingVarLenDenseLargerBatch(self):
|
||||
np.random.seed(3456)
|
||||
for batch_size in (1, 10, 20, 100, 256):
|
||||
self._testSerializedContainingVarLenDenseLargerBatch(batch_size)
|
||||
|
||||
@combinations.generate(test_base.default_test_combinations())
|
||||
def testSerializedShapeMismatch(self):
|
||||
aname = "a"
|
||||
bname = "b"
|
||||
@ -724,7 +739,7 @@ class ParseExampleDatasetTest(test_base.DatasetTestBase):
|
||||
expected_err=(ValueError,
|
||||
"Cannot reshape a tensor with 0 elements to shape"))
|
||||
|
||||
@test_util.run_deprecated_v1
|
||||
@combinations.generate(test_base.graph_only_combinations())
|
||||
def testSerializedContainingVarLenDense(self):
|
||||
aname = "a"
|
||||
bname = "b"
|
||||
@ -877,6 +892,7 @@ class ParseExampleDatasetTest(test_base.DatasetTestBase):
|
||||
"Unsupported: FixedLenSequenceFeature requires "
|
||||
"allow_missing to be True."))
|
||||
|
||||
@combinations.generate(test_base.default_test_combinations())
|
||||
def testSerializedContainingRaggedFeatureWithNoPartitions(self):
|
||||
original = [
|
||||
example(
|
||||
@ -922,6 +938,7 @@ class ParseExampleDatasetTest(test_base.DatasetTestBase):
|
||||
expected_values=expected_output,
|
||||
create_iterator_twice=True)
|
||||
|
||||
@combinations.generate(test_base.default_test_combinations())
|
||||
def testSerializedContainingRaggedFeatureWithOnePartition(self):
|
||||
original = [
|
||||
example(
|
||||
@ -1040,6 +1057,7 @@ class ParseExampleDatasetTest(test_base.DatasetTestBase):
|
||||
expected_values=expected_output,
|
||||
create_iterator_twice=True)
|
||||
|
||||
@combinations.generate(test_base.default_test_combinations())
|
||||
def testSerializedContainingRaggedFeatureWithMultiplePartitions(self):
|
||||
original = [
|
||||
# rt shape: [(batch), 2, None, None]
|
||||
|
@ -17,11 +17,14 @@ from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from absl.testing import parameterized
|
||||
|
||||
from tensorflow.core.protobuf import config_pb2
|
||||
from tensorflow.python.data.experimental.ops import prefetching_ops
|
||||
from tensorflow.python.data.kernel_tests import test_base
|
||||
from tensorflow.python.data.ops import dataset_ops
|
||||
from tensorflow.python.data.util import structure
|
||||
from tensorflow.python.framework import combinations
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import errors
|
||||
from tensorflow.python.framework import ops
|
||||
@ -31,9 +34,9 @@ from tensorflow.python.platform import test
|
||||
|
||||
|
||||
# TODO(b/117581999): add eager coverage when supported.
|
||||
class PrefetchToDeviceTest(test_base.DatasetTestBase):
|
||||
class PrefetchToDeviceTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
|
||||
@test_util.deprecated_graph_mode_only
|
||||
@combinations.generate(test_base.graph_only_combinations())
|
||||
def testPrefetchToDevice(self):
|
||||
host_dataset = dataset_ops.Dataset.range(10)
|
||||
device_dataset = host_dataset.apply(
|
||||
@ -57,7 +60,7 @@ class PrefetchToDeviceTest(test_base.DatasetTestBase):
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
self.evaluate(next_element)
|
||||
|
||||
@test_util.deprecated_graph_mode_only
|
||||
@combinations.generate(test_base.graph_only_combinations())
|
||||
def testPrefetchToSameDevice(self):
|
||||
host_dataset = dataset_ops.Dataset.range(10)
|
||||
device_dataset = host_dataset.apply(
|
||||
@ -82,7 +85,7 @@ class PrefetchToDeviceTest(test_base.DatasetTestBase):
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
self.evaluate(next_element)
|
||||
|
||||
@test_util.deprecated_graph_mode_only
|
||||
@combinations.generate(test_base.graph_only_combinations())
|
||||
def testPrefetchDictToDevice(self):
|
||||
host_dataset = dataset_ops.Dataset.range(10).map(lambda x: {"a": x})
|
||||
device_dataset = host_dataset.apply(
|
||||
@ -106,7 +109,7 @@ class PrefetchToDeviceTest(test_base.DatasetTestBase):
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
self.evaluate(next_element)
|
||||
|
||||
@test_util.deprecated_graph_mode_only
|
||||
@combinations.generate(test_base.graph_only_combinations())
|
||||
def testPrefetchSparseTensorsToDevice(self):
|
||||
def make_tensor(i):
|
||||
return sparse_tensor.SparseTensorValue(
|
||||
@ -136,7 +139,7 @@ class PrefetchToDeviceTest(test_base.DatasetTestBase):
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
self.evaluate(next_element)
|
||||
|
||||
@test_util.deprecated_graph_mode_only
|
||||
@combinations.generate(test_base.graph_only_combinations())
|
||||
def testPrefetchToDeviceGpu(self):
|
||||
if not test_util.is_gpu_available():
|
||||
self.skipTest("No GPU available")
|
||||
@ -156,7 +159,7 @@ class PrefetchToDeviceTest(test_base.DatasetTestBase):
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
self.evaluate(next_element)
|
||||
|
||||
@test_util.deprecated_graph_mode_only
|
||||
@combinations.generate(test_base.graph_only_combinations())
|
||||
def testPrefetchToDeviceWithReInit(self):
|
||||
host_dataset = dataset_ops.Dataset.range(10)
|
||||
device_dataset = host_dataset.apply(
|
||||
@ -184,7 +187,7 @@ class PrefetchToDeviceTest(test_base.DatasetTestBase):
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
self.evaluate(next_element)
|
||||
|
||||
@test_util.deprecated_graph_mode_only
|
||||
@combinations.generate(test_base.graph_only_combinations())
|
||||
def testPrefetchToDeviceGpuWithReInit(self):
|
||||
if not test_util.is_gpu_available():
|
||||
self.skipTest("No GPU available")
|
||||
|
@ -24,16 +24,17 @@ from tensorflow.core.protobuf import config_pb2
|
||||
from tensorflow.python.data.kernel_tests import test_base
|
||||
from tensorflow.python.data.ops import dataset_ops
|
||||
from tensorflow.python.data.ops import multi_device_iterator_ops
|
||||
from tensorflow.python.framework import combinations
|
||||
from tensorflow.python.framework import errors
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.framework import test_util
|
||||
from tensorflow.python.platform import test
|
||||
|
||||
|
||||
@test_util.run_all_in_graph_and_eager_modes
|
||||
class PrefetchWithSlackTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
|
||||
@test_util.run_v1_only("b/121264236")
|
||||
# TODO(b/121264236)
|
||||
@combinations.generate(
|
||||
combinations.combine(tf_api_version=[1], mode=["graph"]))
|
||||
def testPrefetchWithSlackOption(self):
|
||||
"""Determines slack_period based on num devices attached to iterator."""
|
||||
dataset = dataset_ops.Dataset.range(10)
|
||||
@ -60,6 +61,7 @@ class PrefetchWithSlackTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
self.evaluate(elem_on_1)
|
||||
self.evaluate(elem_on_2)
|
||||
|
||||
@combinations.generate(test_base.default_test_combinations())
|
||||
def testPrefetchWithSlackOptionWithoutIterator(self):
|
||||
"""Defaults to slack period of 1 without iterator."""
|
||||
dataset = dataset_ops.Dataset.range(10)
|
||||
@ -72,6 +74,7 @@ class PrefetchWithSlackTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
dataset.options()._graph_rewrite_configs())
|
||||
self.assertDatasetProduces(dataset, range(10))
|
||||
|
||||
@combinations.generate(test_base.default_test_combinations())
|
||||
def testWithPassthroughDataset(self):
|
||||
"""Should still work with a passthrough dataset after prefetch()."""
|
||||
dataset = dataset_ops.Dataset.range(10)
|
||||
@ -82,6 +85,7 @@ class PrefetchWithSlackTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
dataset = dataset.with_options(options)
|
||||
self.assertDatasetProduces(dataset, range(1, 11))
|
||||
|
||||
@combinations.generate(test_base.default_test_combinations())
|
||||
def testErrorWithoutPrefetch(self):
|
||||
"""The rewrite fails if there is no prefetch() in the pipeline."""
|
||||
dataset = dataset_ops.Dataset.range(10)
|
||||
@ -92,6 +96,7 @@ class PrefetchWithSlackTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
get_next = self.getNext(dataset)
|
||||
self.evaluate(get_next())
|
||||
|
||||
@combinations.generate(test_base.default_test_combinations())
|
||||
def testErrorWithInvalidDataset(self):
|
||||
"""With a nested dataset op after prefetch, the rewrite should fail."""
|
||||
dataset = dataset_ops.Dataset.range(10)
|
||||
|
@ -32,8 +32,8 @@ from tensorflow.python.data.experimental.ops import scan_ops
|
||||
from tensorflow.python.data.kernel_tests import test_base
|
||||
from tensorflow.python.data.ops import dataset_ops
|
||||
from tensorflow.python.data.util import nest
|
||||
from tensorflow.python.framework import combinations
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import test_util
|
||||
from tensorflow.python.lib.io import python_io
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import math_ops
|
||||
@ -47,13 +47,11 @@ def _flat_shapes(dataset):
|
||||
return nest.flatten(dataset_ops.get_legacy_output_shapes(dataset))
|
||||
|
||||
|
||||
@test_util.run_all_in_graph_and_eager_modes
|
||||
class RebatchDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
|
||||
drop_remainder_cases = [("WithDropRemainder", True),
|
||||
("WithoutDropRemainder", False)]
|
||||
|
||||
@parameterized.named_parameters(drop_remainder_cases)
|
||||
@combinations.generate(
|
||||
combinations.times(test_base.default_test_combinations(),
|
||||
combinations.combine(drop_remainder=[True, False])))
|
||||
def testBasic(self, drop_remainder):
|
||||
dataset = dataset_ops.Dataset.range(1024).batch(
|
||||
32, drop_remainder=drop_remainder)
|
||||
@ -64,13 +62,16 @@ class RebatchDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
expected_output = [[k for k in range(i, i + 8)] for i in range(0, 1024, 8)] # pylint: disable=g-complex-comprehension
|
||||
self.assertDatasetProduces(rebatched_dataset, expected_output)
|
||||
|
||||
@combinations.generate(test_base.default_test_combinations())
|
||||
def testScalarInputError(self):
|
||||
dataset = dataset_ops.Dataset.range(1024)
|
||||
distribute._RebatchDataset(dataset.batch(4), num_replicas=4)
|
||||
with self.assertRaisesRegexp(ValueError, "at least one dimension"):
|
||||
distribute._RebatchDataset(dataset, num_replicas=4)
|
||||
|
||||
@parameterized.named_parameters(drop_remainder_cases)
|
||||
@combinations.generate(
|
||||
combinations.times(test_base.default_test_combinations(),
|
||||
combinations.combine(drop_remainder=[True, False])))
|
||||
def testBatchNotDivisibleByNumReplicas(self, drop_remainder):
|
||||
dataset = dataset_ops.Dataset.range(1024).batch(
|
||||
32, drop_remainder=drop_remainder)
|
||||
@ -89,6 +90,7 @@ class RebatchDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
i += 4
|
||||
self.assertDatasetProduces(rebatched_dataset, expected_output)
|
||||
|
||||
@combinations.generate(test_base.default_test_combinations())
|
||||
def testBatchSizeNotDivisibleByNumReplicas2(self):
|
||||
dataset = dataset_ops.Dataset.range(32).batch(16, drop_remainder=True)
|
||||
rebatched_dataset = distribute._RebatchDataset(dataset, num_replicas=5)
|
||||
@ -102,6 +104,7 @@ class RebatchDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
expected_output.extend([[]]) # Last replica gets an empty batch
|
||||
self.assertDatasetProduces(rebatched_dataset, expected_output)
|
||||
|
||||
@combinations.generate(test_base.default_test_combinations())
|
||||
def testTupleOutput(self):
|
||||
dataset = dataset_ops.Dataset.range(1024).map(lambda x: (x, x)).batch(32)
|
||||
rebatched_dataset = distribute._RebatchDataset(dataset, num_replicas=4)
|
||||
@ -110,6 +113,7 @@ class RebatchDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
for i in range(0, 1024, 8)]
|
||||
self.assertDatasetProduces(rebatched_dataset, expected_output)
|
||||
|
||||
@combinations.generate(test_base.default_test_combinations())
|
||||
def testNestedDictionaryOutput(self):
|
||||
dataset = dataset_ops.Dataset.range(1024).map(
|
||||
lambda x: {"a": x, "b": {"c": x}}).batch(32)
|
||||
@ -119,7 +123,9 @@ class RebatchDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
for i in range(0, 1024, 8)]
|
||||
self.assertDatasetProduces(rebatched_dataset, expected_output)
|
||||
|
||||
@parameterized.named_parameters(drop_remainder_cases)
|
||||
@combinations.generate(
|
||||
combinations.times(test_base.default_test_combinations(),
|
||||
combinations.combine(drop_remainder=[True, False])))
|
||||
def testFinalPartialBatch(self, drop_remainder):
|
||||
dataset = dataset_ops.Dataset.range(1032).batch(
|
||||
32, drop_remainder=drop_remainder)
|
||||
@ -136,7 +142,9 @@ class RebatchDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
[[k for k in range(i, i + 2)] for i in range(1024, 1032, 2)])
|
||||
self.assertDatasetProduces(rebatched_dataset, expected_output)
|
||||
|
||||
@parameterized.named_parameters(drop_remainder_cases)
|
||||
@combinations.generate(
|
||||
combinations.times(test_base.default_test_combinations(),
|
||||
combinations.combine(drop_remainder=[True, False])))
|
||||
def testFinalPartialBatchAfterRebatch(self, drop_remainder):
|
||||
dataset = dataset_ops.Dataset.range(34).batch(
|
||||
32, drop_remainder=drop_remainder)
|
||||
@ -150,6 +158,7 @@ class RebatchDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
expected_output += [[32], [33], [], []]
|
||||
self.assertDatasetProduces(rebatched_dataset, expected_output)
|
||||
|
||||
@combinations.generate(test_base.default_test_combinations())
|
||||
def testMultipleBatches(self):
|
||||
dataset = dataset_ops.Dataset.range(128).batch(4).batch(8)
|
||||
self.assertEqual([[None, None]],
|
||||
@ -170,6 +179,7 @@ class RebatchDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
for i in range(0, 128, 8)]
|
||||
self.assertDatasetProduces(rebatched_dataset, expected_output)
|
||||
|
||||
@combinations.generate(test_base.default_test_combinations())
|
||||
def testMapAndBatch(self):
|
||||
dataset = dataset_ops.Dataset.range(1024).apply(
|
||||
batching.map_and_batch(math_ops.square, 32))
|
||||
@ -180,6 +190,7 @@ class RebatchDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
for i in range(0, 1024, 8)]
|
||||
self.assertDatasetProduces(rebatched_dataset, expected_output)
|
||||
|
||||
@combinations.generate(test_base.default_test_combinations())
|
||||
def testMapAndBatchWithCapturedInput(self):
|
||||
captured_t = variables.Variable(42)
|
||||
dataset = dataset_ops.Dataset.range(1024).apply(
|
||||
@ -193,6 +204,7 @@ class RebatchDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
self.assertDatasetProduces(
|
||||
rebatched_dataset, expected_output, requires_initialization=True)
|
||||
|
||||
@combinations.generate(test_base.default_test_combinations())
|
||||
def testPaddedBatch(self):
|
||||
dataset = dataset_ops.Dataset.range(128).batch(
|
||||
4, drop_remainder=True).padded_batch(
|
||||
@ -213,6 +225,7 @@ class RebatchDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
for i in range(0, 128, 8)]
|
||||
self.assertDatasetProduces(rebatched_dataset, expected_output)
|
||||
|
||||
@combinations.generate(test_base.default_test_combinations())
|
||||
def testConcatenate(self):
|
||||
dataset1 = dataset_ops.Dataset.range(64).batch(8)
|
||||
dataset2 = dataset_ops.Dataset.range(32).batch(8)
|
||||
@ -224,6 +237,7 @@ class RebatchDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
[[i, i + 1] for i in range(0, 32, 2)])
|
||||
self.assertDatasetProduces(rebatched_dataset, expected_output)
|
||||
|
||||
@combinations.generate(test_base.default_test_combinations())
|
||||
def testConcatenateDifferentShapes(self):
|
||||
dataset1 = dataset_ops.Dataset.range(64).batch(16)
|
||||
dataset2 = dataset_ops.Dataset.range(32).batch(8)
|
||||
@ -235,6 +249,7 @@ class RebatchDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
[[i, i + 1] for i in range(0, 32, 2)])
|
||||
self.assertDatasetProduces(rebatched_dataset, expected_output)
|
||||
|
||||
@combinations.generate(test_base.default_test_combinations())
|
||||
def testZip(self):
|
||||
dataset1 = dataset_ops.Dataset.range(64).batch(8)
|
||||
dataset2 = dataset_ops.Dataset.range(32).batch(8)
|
||||
@ -245,6 +260,7 @@ class RebatchDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
expected_output = [([i, i + 1], [i, i + 1]) for i in range(0, 32, 2)]
|
||||
self.assertDatasetProduces(rebatched_dataset, expected_output)
|
||||
|
||||
@combinations.generate(test_base.default_test_combinations())
|
||||
def testZipDifferentShapes(self):
|
||||
dataset1 = dataset_ops.Dataset.range(64).batch(16)
|
||||
dataset2 = dataset_ops.Dataset.range(32).batch(8)
|
||||
@ -256,6 +272,7 @@ class RebatchDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
for i in range(0, 32, 2)]
|
||||
self.assertDatasetProduces(rebatched_dataset, expected_output)
|
||||
|
||||
@combinations.generate(test_base.default_test_combinations())
|
||||
def testFlatMapBatching(self):
|
||||
dataset = dataset_ops.Dataset.range(2).flat_map(
|
||||
lambda _: dataset_ops.Dataset.range(32).batch( # pylint: disable=g-long-lambda
|
||||
@ -274,6 +291,7 @@ class RebatchDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
for i in range(0, 32, 8)] # generates 4 elements
|
||||
self.assertDatasetProduces(rebatched_dataset, expected_output)
|
||||
|
||||
@combinations.generate(test_base.default_test_combinations())
|
||||
def testInterleaveBatching(self):
|
||||
dataset = dataset_ops.Dataset.range(2).interleave(
|
||||
lambda _: dataset_ops.Dataset.range(32).batch( # pylint: disable=g-long-lambda
|
||||
@ -290,6 +308,7 @@ class RebatchDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
expected_output += expected_output
|
||||
self.assertDatasetProduces(rebatched_dataset, expected_output)
|
||||
|
||||
@combinations.generate(test_base.default_test_combinations())
|
||||
def testParallelInterleaveBatching(self):
|
||||
dataset = dataset_ops.Dataset.range(2).interleave(
|
||||
lambda _: dataset_ops.Dataset.range(32).batch( # pylint: disable=g-long-lambda
|
||||
@ -307,6 +326,7 @@ class RebatchDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
expected_output += expected_output
|
||||
self.assertDatasetProduces(rebatched_dataset, expected_output)
|
||||
|
||||
@combinations.generate(test_base.default_test_combinations())
|
||||
def testGroupByWindowStaticBatch(self):
|
||||
dataset = dataset_ops.Dataset.from_tensor_slices(
|
||||
[[array_ops.constant(i, dtype=dtypes.int64)] * 3 for i in range(40)])
|
||||
@ -326,6 +346,7 @@ class RebatchDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
for k in range(2)]
|
||||
self.assertDatasetProduces(rebatched_dataset, expected_output)
|
||||
|
||||
@combinations.generate(test_base.default_test_combinations())
|
||||
def testGroupByWindowDynamicBatch(self):
|
||||
# {0, 1, 0, 1, ...}
|
||||
dataset = dataset_ops.Dataset.range(40).map(lambda x: x % 2)
|
||||
@ -350,6 +371,7 @@ class RebatchDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
expected_output = [[value] * batch_size for batch_size, value in pairs]
|
||||
self.assertDatasetProduces(dataset, expected_output)
|
||||
|
||||
@combinations.generate(test_base.default_test_combinations())
|
||||
def testGroupByWindowDynamicBatchWithPartialBatch(self):
|
||||
# {0, 1, 0, 1, ...}
|
||||
dataset = dataset_ops.Dataset.range(40).map(lambda x: x % 2)
|
||||
@ -371,6 +393,7 @@ class RebatchDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
expected_output = [[value] * batch_size for batch_size, value in pairs]
|
||||
self.assertDatasetProduces(dataset, expected_output)
|
||||
|
||||
@combinations.generate(test_base.default_test_combinations())
|
||||
def testGroupByWindowDynamicBatchWithPartialBatchWithDropRemainder(self):
|
||||
# This test exercises nested batch functionality, dynamic batch size
|
||||
# and drop_remainder=True together.
|
||||
@ -395,6 +418,7 @@ class RebatchDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
expected_output = [[value] * batch_size for batch_size, value in pairs]
|
||||
self.assertDatasetProduces(dataset, expected_output)
|
||||
|
||||
@combinations.generate(test_base.default_test_combinations())
|
||||
def testScanAfterBatch(self):
|
||||
dataset = dataset_ops.Dataset.range(40).batch(10).apply(
|
||||
scan_ops.scan(np.int64(2), lambda state, value: (state, value * state)))
|
||||
@ -405,6 +429,7 @@ class RebatchDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
expected_output = [[i * 2 for i in range(j*5, (j+1)*5)] for j in range(8)] # pylint: disable=g-complex-comprehension
|
||||
self.assertDatasetProduces(dataset, expected_output)
|
||||
|
||||
@combinations.generate(test_base.default_test_combinations())
|
||||
def testMakeBatchedFeaturesDataset(self):
|
||||
# Set up
|
||||
fn = os.path.join(self.get_temp_dir(), "tf_record.txt")
|
||||
@ -438,6 +463,7 @@ class RebatchDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
} for i in range(0, 1024, 8)] # pylint: disable=g-complex-comprehension
|
||||
self.assertDatasetProduces(rebatched_dataset, expected_output)
|
||||
|
||||
@combinations.generate(test_base.default_test_combinations())
|
||||
def testRaggedTensorDataset(self):
|
||||
# Set up a dataset that produces ragged tensors with a static batch size.
|
||||
row_lengths = np.random.randint(8, size=128)
|
||||
|
@ -24,9 +24,9 @@ import numpy as np
|
||||
from tensorflow.python.data.experimental.ops import resampling
|
||||
from tensorflow.python.data.kernel_tests import test_base
|
||||
from tensorflow.python.data.ops import dataset_ops
|
||||
from tensorflow.python.framework import combinations
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import errors
|
||||
from tensorflow.python.framework import test_util
|
||||
from tensorflow.python.ops import math_ops
|
||||
from tensorflow.python.ops import random_ops
|
||||
from tensorflow.python.ops import string_ops
|
||||
@ -34,12 +34,11 @@ from tensorflow.python.platform import test
|
||||
from tensorflow.python.util import compat
|
||||
|
||||
|
||||
@test_util.run_all_in_graph_and_eager_modes
|
||||
class RejectionResampleTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
|
||||
@parameterized.named_parameters(
|
||||
("InitialDistributionKnown", True),
|
||||
("InitialDistributionUnknown", False))
|
||||
@combinations.generate(
|
||||
combinations.times(test_base.default_test_combinations(),
|
||||
combinations.combine(initial_known=[True, False])))
|
||||
def testDistribution(self, initial_known):
|
||||
classes = np.random.randint(5, size=(20000,)) # Uniformly sampled
|
||||
target_dist = [0.9, 0.05, 0.05, 0.0, 0.0]
|
||||
@ -72,9 +71,9 @@ class RejectionResampleTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
returned_dist = class_counts / total_returned
|
||||
self.assertAllClose(target_dist, returned_dist, atol=1e-2)
|
||||
|
||||
@parameterized.named_parameters(
|
||||
("OnlyInitial", True),
|
||||
("NotInitial", False))
|
||||
@combinations.generate(
|
||||
combinations.times(test_base.default_test_combinations(),
|
||||
combinations.combine(only_initial_dist=[True, False])))
|
||||
def testEdgeCasesSampleFromInitialDataset(self, only_initial_dist):
|
||||
init_dist = [0.5, 0.5]
|
||||
target_dist = [0.5, 0.5] if only_initial_dist else [0.0, 1.0]
|
||||
@ -99,6 +98,7 @@ class RejectionResampleTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
while True:
|
||||
returned.append(self.evaluate(get_next()))
|
||||
|
||||
@combinations.generate(test_base.default_test_combinations())
|
||||
def testRandomClasses(self):
|
||||
init_dist = [0.25, 0.25, 0.25, 0.25]
|
||||
target_dist = [0.0, 0.0, 0.0, 1.0]
|
||||
|
@ -17,18 +17,18 @@ from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from absl.testing import parameterized
|
||||
import numpy as np
|
||||
|
||||
from tensorflow.python.data.experimental.ops import shuffle_ops
|
||||
from tensorflow.python.data.kernel_tests import test_base
|
||||
from tensorflow.python.data.ops import dataset_ops
|
||||
from tensorflow.python.framework import combinations
|
||||
from tensorflow.python.framework import errors
|
||||
from tensorflow.python.framework import test_util
|
||||
from tensorflow.python.platform import test
|
||||
|
||||
|
||||
@test_util.run_all_in_graph_and_eager_modes
|
||||
class ShuffleAndRepeatTest(test_base.DatasetTestBase):
|
||||
class ShuffleAndRepeatTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
|
||||
def _build_ds(self, seed, count=5, num_elements=20):
|
||||
return dataset_ops.Dataset.range(num_elements).apply(
|
||||
@ -44,6 +44,7 @@ class ShuffleAndRepeatTest(test_base.DatasetTestBase):
|
||||
self.evaluate(get_next())
|
||||
return outputs
|
||||
|
||||
@combinations.generate(test_base.default_test_combinations())
|
||||
def testCorrectOutput(self):
|
||||
output = self._gen_outputs(lambda: self._build_ds(10), 100)
|
||||
self.assertSequenceEqual(
|
||||
@ -52,6 +53,7 @@ class ShuffleAndRepeatTest(test_base.DatasetTestBase):
|
||||
for i in range(5):
|
||||
self.assertSequenceEqual(sorted(output[i * 20:(i + 1) * 20]), range(20))
|
||||
|
||||
@combinations.generate(test_base.default_test_combinations())
|
||||
def testReshuffling(self):
|
||||
# Check that the output orders of different epochs are indeed different.
|
||||
output = self._gen_outputs(lambda: self._build_ds(10), 100)
|
||||
@ -60,17 +62,20 @@ class ShuffleAndRepeatTest(test_base.DatasetTestBase):
|
||||
epoch2 = output[(i + 1) * 20:(i + 2) * 20]
|
||||
self.assertNotEqual(epoch1, epoch2)
|
||||
|
||||
@combinations.generate(test_base.default_test_combinations())
|
||||
def testSameOrderForSameSeeds(self):
|
||||
output1 = self._gen_outputs(lambda: self._build_ds(10), 100)
|
||||
output2 = self._gen_outputs(lambda: self._build_ds(10), 100)
|
||||
self.assertEqual(output1, output2)
|
||||
|
||||
@combinations.generate(test_base.default_test_combinations())
|
||||
def testDifferentOrderForDifferentSeeds(self):
|
||||
output1 = self._gen_outputs(lambda: self._build_ds(10), 100)
|
||||
output2 = self._gen_outputs(lambda: self._build_ds(20), 100)
|
||||
self.assertNotEqual(output1, output2)
|
||||
self.assertEqual(sorted(output1), sorted(output2))
|
||||
|
||||
@combinations.generate(test_base.default_test_combinations())
|
||||
def testCountNone(self):
|
||||
output1 = self._gen_outputs(
|
||||
lambda: self._build_ds(10, count=None), 100, verify_exhausted=False)
|
||||
@ -79,6 +84,7 @@ class ShuffleAndRepeatTest(test_base.DatasetTestBase):
|
||||
self.assertNotEqual(output1, output2)
|
||||
self.assertEqual(sorted(output1), sorted(output2))
|
||||
|
||||
@combinations.generate(test_base.default_test_combinations())
|
||||
def testCountMinusOne(self):
|
||||
output1 = self._gen_outputs(
|
||||
lambda: self._build_ds(10, count=-1), 100, verify_exhausted=False)
|
||||
@ -87,6 +93,7 @@ class ShuffleAndRepeatTest(test_base.DatasetTestBase):
|
||||
self.assertNotEqual(output1, output2)
|
||||
self.assertEqual(sorted(output1), sorted(output2))
|
||||
|
||||
@combinations.generate(test_base.default_test_combinations())
|
||||
def testInfiniteOutputs(self):
|
||||
# Asserting the iterator is exhausted after producing 100 items should fail.
|
||||
with self.assertRaises(AssertionError):
|
||||
@ -94,6 +101,7 @@ class ShuffleAndRepeatTest(test_base.DatasetTestBase):
|
||||
with self.assertRaises(AssertionError):
|
||||
self._gen_outputs(lambda: self._build_ds(10, count=-1), 100)
|
||||
|
||||
@combinations.generate(test_base.default_test_combinations())
|
||||
def testInfiniteEmpty(self):
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
self._gen_outputs(lambda: self._build_ds(10, count=None, num_elements=0),
|
||||
@ -102,12 +110,14 @@ class ShuffleAndRepeatTest(test_base.DatasetTestBase):
|
||||
self._gen_outputs(lambda: self._build_ds(10, count=-1, num_elements=0),
|
||||
100)
|
||||
|
||||
@combinations.generate(test_base.default_test_combinations())
|
||||
def testLargeBufferSize(self):
|
||||
ds = dataset_ops.Dataset.range(20).apply(
|
||||
shuffle_ops.shuffle_and_repeat(buffer_size=21))
|
||||
get_next = self.getNext(ds)
|
||||
self.evaluate(get_next())
|
||||
|
||||
@combinations.generate(test_base.default_test_combinations())
|
||||
def testVeryLargeBufferSize(self):
|
||||
num_epochs = 1000 * 1000
|
||||
# Each element being shuffled and repeated has shape (100,). This will OOM
|
||||
|
@ -18,18 +18,22 @@ from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from absl.testing import parameterized
|
||||
|
||||
from tensorflow.python.data.experimental.kernel_tests import sql_dataset_test_base
|
||||
from tensorflow.python.data.kernel_tests import test_base
|
||||
from tensorflow.python.framework import combinations
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import errors
|
||||
from tensorflow.python.framework import test_util
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.platform import test
|
||||
|
||||
|
||||
@test_util.run_all_in_graph_and_eager_modes
|
||||
class SqlDatasetTest(sql_dataset_test_base.SqlDatasetTestBase):
|
||||
class SqlDatasetTest(sql_dataset_test_base.SqlDatasetTestBase,
|
||||
parameterized.TestCase):
|
||||
|
||||
# Test that SqlDataset can read from a database table.
|
||||
@combinations.generate(test_base.default_test_combinations())
|
||||
def testReadResultSet(self):
|
||||
for _ in range(2): # Run twice to verify statelessness of db operations.
|
||||
dataset = self._createSqlDataset(
|
||||
@ -44,6 +48,7 @@ class SqlDatasetTest(sql_dataset_test_base.SqlDatasetTestBase):
|
||||
num_test_iterations=2)
|
||||
|
||||
# Test that SqlDataset works on a join query.
|
||||
@combinations.generate(test_base.default_test_combinations())
|
||||
def testReadResultSetJoinQuery(self):
|
||||
get_next = self.getNext(
|
||||
self._createSqlDataset(
|
||||
@ -60,6 +65,7 @@ class SqlDatasetTest(sql_dataset_test_base.SqlDatasetTestBase):
|
||||
|
||||
# Test that SqlDataset can read a database entry with a null-terminator
|
||||
# in the middle of the text and place the entry in a `string` tensor.
|
||||
@combinations.generate(test_base.default_test_combinations())
|
||||
def testReadResultSetNullTerminator(self):
|
||||
get_next = self.getNext(
|
||||
self._createSqlDataset(
|
||||
@ -76,6 +82,7 @@ class SqlDatasetTest(sql_dataset_test_base.SqlDatasetTestBase):
|
||||
# Test that SqlDataset works when used on two different queries.
|
||||
# Because the output types of the dataset must be determined at graph-creation
|
||||
# time, the two queries must have the same number and types of columns.
|
||||
@combinations.generate(test_base.default_test_combinations())
|
||||
def testReadResultSetReuseSqlDataset(self):
|
||||
get_next = self.getNext(
|
||||
self._createSqlDataset(
|
||||
@ -100,6 +107,7 @@ class SqlDatasetTest(sql_dataset_test_base.SqlDatasetTestBase):
|
||||
|
||||
# Test that an `OutOfRangeError` is raised on the first call to
|
||||
# `get_next_str_only` if result set is empty.
|
||||
@combinations.generate(test_base.default_test_combinations())
|
||||
def testReadEmptyResultSet(self):
|
||||
get_next = self.getNext(
|
||||
self._createSqlDataset(
|
||||
@ -110,6 +118,7 @@ class SqlDatasetTest(sql_dataset_test_base.SqlDatasetTestBase):
|
||||
self.evaluate(get_next())
|
||||
|
||||
# Test that an error is raised when `driver_name` is invalid.
|
||||
@combinations.generate(test_base.default_test_combinations())
|
||||
def testReadResultSetWithInvalidDriverName(self):
|
||||
with self.assertRaises(errors.InvalidArgumentError):
|
||||
dataset = self._createSqlDataset(
|
||||
@ -120,6 +129,7 @@ class SqlDatasetTest(sql_dataset_test_base.SqlDatasetTestBase):
|
||||
self.assertDatasetProduces(dataset, expected_output=[])
|
||||
|
||||
# Test that an error is raised when a column name in `query` is nonexistent
|
||||
@combinations.generate(test_base.default_test_combinations())
|
||||
def testReadResultSetWithInvalidColumnName(self):
|
||||
get_next = self.getNext(
|
||||
self._createSqlDataset(
|
||||
@ -130,6 +140,7 @@ class SqlDatasetTest(sql_dataset_test_base.SqlDatasetTestBase):
|
||||
self.evaluate(get_next())
|
||||
|
||||
# Test that an error is raised when there is a syntax error in `query`.
|
||||
@combinations.generate(test_base.default_test_combinations())
|
||||
def testReadResultSetOfQueryWithSyntaxError(self):
|
||||
get_next = self.getNext(
|
||||
self._createSqlDataset(
|
||||
@ -141,6 +152,7 @@ class SqlDatasetTest(sql_dataset_test_base.SqlDatasetTestBase):
|
||||
|
||||
# Test that an error is raised when the number of columns in `query`
|
||||
# does not match the length of `, output_types`.
|
||||
@combinations.generate(test_base.default_test_combinations())
|
||||
def testReadResultSetWithMismatchBetweenColumnsAndOutputTypes(self):
|
||||
get_next = self.getNext(
|
||||
self._createSqlDataset(
|
||||
@ -154,6 +166,7 @@ class SqlDatasetTest(sql_dataset_test_base.SqlDatasetTestBase):
|
||||
# than a select query. In particular, the error refers to the number of
|
||||
# output types passed to the op not matching the number of columns in the
|
||||
# result set of the query (namely, 0 for an insert statement.)
|
||||
@combinations.generate(test_base.default_test_combinations())
|
||||
def testReadResultSetOfInsertQuery(self):
|
||||
get_next = self.getNext(
|
||||
self._createSqlDataset(
|
||||
@ -165,6 +178,7 @@ class SqlDatasetTest(sql_dataset_test_base.SqlDatasetTestBase):
|
||||
|
||||
# Test that `SqlDataset` can read an integer from a SQLite database table and
|
||||
# place it in an `int8` tensor.
|
||||
@combinations.generate(test_base.default_test_combinations())
|
||||
def testReadResultSetInt8(self):
|
||||
get_next = self.getNext(
|
||||
self._createSqlDataset(
|
||||
@ -178,6 +192,7 @@ class SqlDatasetTest(sql_dataset_test_base.SqlDatasetTestBase):
|
||||
|
||||
# Test that `SqlDataset` can read a negative or 0-valued integer from a
|
||||
# SQLite database table and place it in an `int8` tensor.
|
||||
@combinations.generate(test_base.default_test_combinations())
|
||||
def testReadResultSetInt8NegativeAndZero(self):
|
||||
get_next = self.getNext(
|
||||
self._createSqlDataset(
|
||||
@ -191,6 +206,7 @@ class SqlDatasetTest(sql_dataset_test_base.SqlDatasetTestBase):
|
||||
|
||||
# Test that `SqlDataset` can read a large (positive or negative) integer from
|
||||
# a SQLite database table and place it in an `int8` tensor.
|
||||
@combinations.generate(test_base.default_test_combinations())
|
||||
def testReadResultSetInt8MaxValues(self):
|
||||
get_next = self.getNext(
|
||||
self._createSqlDataset(
|
||||
@ -205,6 +221,7 @@ class SqlDatasetTest(sql_dataset_test_base.SqlDatasetTestBase):
|
||||
|
||||
# Test that `SqlDataset` can read an integer from a SQLite database table and
|
||||
# place it in an `int16` tensor.
|
||||
@combinations.generate(test_base.default_test_combinations())
|
||||
def testReadResultSetInt16(self):
|
||||
get_next = self.getNext(
|
||||
self._createSqlDataset(
|
||||
@ -218,6 +235,7 @@ class SqlDatasetTest(sql_dataset_test_base.SqlDatasetTestBase):
|
||||
|
||||
# Test that `SqlDataset` can read a negative or 0-valued integer from a
|
||||
# SQLite database table and place it in an `int16` tensor.
|
||||
@combinations.generate(test_base.default_test_combinations())
|
||||
def testReadResultSetInt16NegativeAndZero(self):
|
||||
get_next = self.getNext(
|
||||
self._createSqlDataset(
|
||||
@ -231,6 +249,7 @@ class SqlDatasetTest(sql_dataset_test_base.SqlDatasetTestBase):
|
||||
|
||||
# Test that `SqlDataset` can read a large (positive or negative) integer from
|
||||
# a SQLite database table and place it in an `int16` tensor.
|
||||
@combinations.generate(test_base.default_test_combinations())
|
||||
def testReadResultSetInt16MaxValues(self):
|
||||
get_next = self.getNext(
|
||||
self._createSqlDataset(
|
||||
@ -246,6 +265,7 @@ class SqlDatasetTest(sql_dataset_test_base.SqlDatasetTestBase):
|
||||
|
||||
# Test that `SqlDataset` can read an integer from a SQLite database table and
|
||||
# place it in an `int32` tensor.
|
||||
@combinations.generate(test_base.default_test_combinations())
|
||||
def testReadResultSetInt32(self):
|
||||
get_next = self.getNext(
|
||||
self._createSqlDataset(
|
||||
@ -257,6 +277,7 @@ class SqlDatasetTest(sql_dataset_test_base.SqlDatasetTestBase):
|
||||
|
||||
# Test that `SqlDataset` can read a negative or 0-valued integer from a
|
||||
# SQLite database table and place it in an `int32` tensor.
|
||||
@combinations.generate(test_base.default_test_combinations())
|
||||
def testReadResultSetInt32NegativeAndZero(self):
|
||||
get_next = self.getNext(
|
||||
self._createSqlDataset(
|
||||
@ -270,6 +291,7 @@ class SqlDatasetTest(sql_dataset_test_base.SqlDatasetTestBase):
|
||||
|
||||
# Test that `SqlDataset` can read a large (positive or negative) integer from
|
||||
# a SQLite database table and place it in an `int32` tensor.
|
||||
@combinations.generate(test_base.default_test_combinations())
|
||||
def testReadResultSetInt32MaxValues(self):
|
||||
get_next = self.getNext(
|
||||
self._createSqlDataset(
|
||||
@ -285,6 +307,7 @@ class SqlDatasetTest(sql_dataset_test_base.SqlDatasetTestBase):
|
||||
|
||||
# Test that `SqlDataset` can read a numeric `varchar` from a SQLite database
|
||||
# table and place it in an `int32` tensor.
|
||||
@combinations.generate(test_base.default_test_combinations())
|
||||
def testReadResultSetInt32VarCharColumnAsInt(self):
|
||||
get_next = self.getNext(
|
||||
self._createSqlDataset(
|
||||
@ -298,6 +321,7 @@ class SqlDatasetTest(sql_dataset_test_base.SqlDatasetTestBase):
|
||||
|
||||
# Test that `SqlDataset` can read an integer from a SQLite database table
|
||||
# and place it in an `int64` tensor.
|
||||
@combinations.generate(test_base.default_test_combinations())
|
||||
def testReadResultSetInt64(self):
|
||||
get_next = self.getNext(
|
||||
self._createSqlDataset(
|
||||
@ -311,6 +335,7 @@ class SqlDatasetTest(sql_dataset_test_base.SqlDatasetTestBase):
|
||||
|
||||
# Test that `SqlDataset` can read a negative or 0-valued integer from a
|
||||
# SQLite database table and place it in an `int64` tensor.
|
||||
@combinations.generate(test_base.default_test_combinations())
|
||||
def testReadResultSetInt64NegativeAndZero(self):
|
||||
get_next = self.getNext(
|
||||
self._createSqlDataset(
|
||||
@ -324,6 +349,7 @@ class SqlDatasetTest(sql_dataset_test_base.SqlDatasetTestBase):
|
||||
|
||||
# Test that `SqlDataset` can read a large (positive or negative) integer from
|
||||
# a SQLite database table and place it in an `int64` tensor.
|
||||
@combinations.generate(test_base.default_test_combinations())
|
||||
def testReadResultSetInt64MaxValues(self):
|
||||
get_next = self.getNext(
|
||||
self._createSqlDataset(
|
||||
@ -339,6 +365,7 @@ class SqlDatasetTest(sql_dataset_test_base.SqlDatasetTestBase):
|
||||
|
||||
# Test that `SqlDataset` can read an integer from a SQLite database table and
|
||||
# place it in a `uint8` tensor.
|
||||
@combinations.generate(test_base.default_test_combinations())
|
||||
def testReadResultSetUInt8(self):
|
||||
get_next = self.getNext(
|
||||
self._createSqlDataset(
|
||||
@ -352,6 +379,7 @@ class SqlDatasetTest(sql_dataset_test_base.SqlDatasetTestBase):
|
||||
|
||||
# Test that `SqlDataset` can read the minimum and maximum uint8 values from a
|
||||
# SQLite database table and place them in `uint8` tensors.
|
||||
@combinations.generate(test_base.default_test_combinations())
|
||||
def testReadResultSetUInt8MinAndMaxValues(self):
|
||||
get_next = self.getNext(
|
||||
self._createSqlDataset(
|
||||
@ -367,6 +395,7 @@ class SqlDatasetTest(sql_dataset_test_base.SqlDatasetTestBase):
|
||||
|
||||
# Test that `SqlDataset` can read an integer from a SQLite database table
|
||||
# and place it in a `uint16` tensor.
|
||||
@combinations.generate(test_base.default_test_combinations())
|
||||
def testReadResultSetUInt16(self):
|
||||
get_next = self.getNext(
|
||||
self._createSqlDataset(
|
||||
@ -380,6 +409,7 @@ class SqlDatasetTest(sql_dataset_test_base.SqlDatasetTestBase):
|
||||
|
||||
# Test that `SqlDataset` can read the minimum and maximum uint16 values from a
|
||||
# SQLite database table and place them in `uint16` tensors.
|
||||
@combinations.generate(test_base.default_test_combinations())
|
||||
def testReadResultSetUInt16MinAndMaxValues(self):
|
||||
get_next = self.getNext(
|
||||
self._createSqlDataset(
|
||||
@ -396,6 +426,7 @@ class SqlDatasetTest(sql_dataset_test_base.SqlDatasetTestBase):
|
||||
# Test that `SqlDataset` can read a 0-valued and 1-valued integer from a
|
||||
# SQLite database table and place them as `True` and `False` respectively
|
||||
# in `bool` tensors.
|
||||
@combinations.generate(test_base.default_test_combinations())
|
||||
def testReadResultSetBool(self):
|
||||
get_next = self.getNext(
|
||||
self._createSqlDataset(
|
||||
@ -409,6 +440,7 @@ class SqlDatasetTest(sql_dataset_test_base.SqlDatasetTestBase):
|
||||
|
||||
# Test that `SqlDataset` can read an integer that is not 0-valued or 1-valued
|
||||
# from a SQLite database table and place it as `True` in a `bool` tensor.
|
||||
@combinations.generate(test_base.default_test_combinations())
|
||||
def testReadResultSetBoolNotZeroOrOne(self):
|
||||
get_next = self.getNext(
|
||||
self._createSqlDataset(
|
||||
@ -422,6 +454,7 @@ class SqlDatasetTest(sql_dataset_test_base.SqlDatasetTestBase):
|
||||
|
||||
# Test that `SqlDataset` can read a float from a SQLite database table
|
||||
# and place it in a `float64` tensor.
|
||||
@combinations.generate(test_base.default_test_combinations())
|
||||
def testReadResultSetFloat64(self):
|
||||
get_next = self.getNext(
|
||||
self._createSqlDataset(
|
||||
@ -437,6 +470,7 @@ class SqlDatasetTest(sql_dataset_test_base.SqlDatasetTestBase):
|
||||
# Test that `SqlDataset` can read a float from a SQLite database table beyond
|
||||
# the precision of 64-bit IEEE, without throwing an error. Test that
|
||||
# `SqlDataset` identifies such a value as equal to itself.
|
||||
@combinations.generate(test_base.default_test_combinations())
|
||||
def testReadResultSetFloat64OverlyPrecise(self):
|
||||
get_next = self.getNext(
|
||||
self._createSqlDataset(
|
||||
@ -458,6 +492,7 @@ class SqlDatasetTest(sql_dataset_test_base.SqlDatasetTestBase):
|
||||
# representing the largest integer representable as a 64-bit IEEE float
|
||||
# such that the previous integer is also representable as a 64-bit IEEE float.
|
||||
# Test that `SqlDataset` can distinguish these two numbers.
|
||||
@combinations.generate(test_base.default_test_combinations())
|
||||
def testReadResultSetFloat64LargestConsecutiveWholeNumbersNotEqual(self):
|
||||
get_next = self.getNext(
|
||||
self._createSqlDataset(
|
||||
@ -472,6 +507,7 @@ class SqlDatasetTest(sql_dataset_test_base.SqlDatasetTestBase):
|
||||
self.evaluate(get_next())
|
||||
|
||||
# Test that SqlDataset can stop correctly when combined with batch
|
||||
@combinations.generate(test_base.default_test_combinations())
|
||||
def testReadResultSetWithBatchStop(self):
|
||||
dataset = self._createSqlDataset(
|
||||
query="SELECT * FROM data", output_types=(dtypes.int32))
|
||||
|
@ -17,6 +17,7 @@ from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from absl.testing import parameterized
|
||||
import numpy as np
|
||||
|
||||
from tensorflow.python.data.experimental.kernel_tests import reader_dataset_ops_test_base
|
||||
@ -24,7 +25,9 @@ from tensorflow.python.data.experimental.kernel_tests import stats_dataset_test_
|
||||
from tensorflow.python.data.experimental.ops import batching
|
||||
from tensorflow.python.data.experimental.ops import stats_aggregator
|
||||
from tensorflow.python.data.experimental.ops import stats_ops
|
||||
from tensorflow.python.data.kernel_tests import test_base
|
||||
from tensorflow.python.data.ops import dataset_ops
|
||||
from tensorflow.python.framework import combinations
|
||||
from tensorflow.python.framework import errors
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.ops import array_ops
|
||||
@ -32,8 +35,11 @@ from tensorflow.python.ops import math_ops
|
||||
from tensorflow.python.platform import test
|
||||
|
||||
|
||||
class StatsDatasetTest(stats_dataset_test_base.StatsDatasetTestBase):
|
||||
# TODO(jsimsa): Figure out why are graph tests failing.
|
||||
class StatsDatasetTest(stats_dataset_test_base.StatsDatasetTestBase,
|
||||
parameterized.TestCase):
|
||||
|
||||
@combinations.generate(test_base.eager_only_combinations())
|
||||
def testBytesProduced(self):
|
||||
aggregator = stats_aggregator.StatsAggregator()
|
||||
dataset = dataset_ops.Dataset.range(100).map(
|
||||
@ -57,6 +63,7 @@ class StatsDatasetTest(stats_dataset_test_base.StatsDatasetTestBase):
|
||||
self.assertStatisticsHasCount(handle, "bytes_produced", 100.0, 101)
|
||||
self.assertStatisticsHasSum(handle, "bytes_produced", expected_sum, 101)
|
||||
|
||||
@combinations.generate(test_base.eager_only_combinations())
|
||||
def testLatencyStats(self):
|
||||
aggregator = stats_aggregator.StatsAggregator()
|
||||
dataset = dataset_ops.Dataset.range(100).apply(
|
||||
@ -76,6 +83,7 @@ class StatsDatasetTest(stats_dataset_test_base.StatsDatasetTestBase):
|
||||
handle = self.getHandle(aggregator)
|
||||
self.assertStatisticsHasCount(handle, "record_latency", 100.0, 101)
|
||||
|
||||
@combinations.generate(test_base.eager_only_combinations())
|
||||
def testPrefetchBufferUtilization(self):
|
||||
aggregator = stats_aggregator.StatsAggregator()
|
||||
dataset = dataset_ops.Dataset.range(100).map(
|
||||
@ -117,6 +125,7 @@ class StatsDatasetTest(stats_dataset_test_base.StatsDatasetTestBase):
|
||||
301,
|
||||
offset=2)
|
||||
|
||||
@combinations.generate(test_base.eager_only_combinations())
|
||||
def testPrefetchBufferScalars(self):
|
||||
aggregator = stats_aggregator.StatsAggregator()
|
||||
dataset = dataset_ops.Dataset.range(10).map(
|
||||
@ -140,6 +149,7 @@ class StatsDatasetTest(stats_dataset_test_base.StatsDatasetTestBase):
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
self.evaluate(next_element())
|
||||
|
||||
@combinations.generate(test_base.eager_only_combinations())
|
||||
def testFilteredElementsStats(self):
|
||||
aggregator = stats_aggregator.StatsAggregator()
|
||||
dataset = dataset_ops.Dataset.range(101).filter(
|
||||
@ -167,6 +177,7 @@ class StatsDatasetTest(stats_dataset_test_base.StatsDatasetTestBase):
|
||||
handle, self.regexForNodeName("FilterDataset", "filtered_elements"),
|
||||
34.0)
|
||||
|
||||
@combinations.generate(test_base.eager_only_combinations())
|
||||
def testReinitialize(self):
|
||||
aggregator = stats_aggregator.StatsAggregator()
|
||||
dataset = dataset_ops.Dataset.range(100).apply(
|
||||
@ -187,6 +198,7 @@ class StatsDatasetTest(stats_dataset_test_base.StatsDatasetTestBase):
|
||||
self.assertStatisticsHasCount(handle, "record_latency", (j + 1) * 100.0,
|
||||
(j * 100) + 101)
|
||||
|
||||
@combinations.generate(test_base.eager_only_combinations())
|
||||
def testNoAggregatorRegistered(self):
|
||||
dataset = dataset_ops.Dataset.range(100).apply(
|
||||
stats_ops.latency_stats("record_latency"))
|
||||
@ -198,6 +210,7 @@ class StatsDatasetTest(stats_dataset_test_base.StatsDatasetTestBase):
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
self.evaluate(next_element())
|
||||
|
||||
@combinations.generate(test_base.eager_only_combinations())
|
||||
def testMultipleTags(self):
|
||||
aggregator = stats_aggregator.StatsAggregator()
|
||||
dataset = dataset_ops.Dataset.range(100).apply(
|
||||
@ -221,6 +234,7 @@ class StatsDatasetTest(stats_dataset_test_base.StatsDatasetTestBase):
|
||||
handle, "record_latency", 100.0, 201, offset=1)
|
||||
self.assertStatisticsHasCount(handle, "record_latency_2", 100.0, 201)
|
||||
|
||||
@combinations.generate(test_base.eager_only_combinations())
|
||||
def testRepeatedTags(self):
|
||||
aggregator = stats_aggregator.StatsAggregator()
|
||||
dataset = dataset_ops.Dataset.range(100).apply(
|
||||
@ -239,6 +253,7 @@ class StatsDatasetTest(stats_dataset_test_base.StatsDatasetTestBase):
|
||||
handle = self.getHandle(aggregator)
|
||||
self.assertStatisticsHasCount(handle, "record_latency", 200.0, 201)
|
||||
|
||||
@combinations.generate(test_base.eager_only_combinations())
|
||||
def testMultipleIteratorsSameAggregator(self):
|
||||
aggregator = stats_aggregator.StatsAggregator()
|
||||
dataset = dataset_ops.Dataset.range(100).apply(
|
||||
@ -259,6 +274,7 @@ class StatsDatasetTest(stats_dataset_test_base.StatsDatasetTestBase):
|
||||
handle = self.getHandle(aggregator)
|
||||
self.assertStatisticsHasCount(handle, "record_latency", 200.0, 201)
|
||||
|
||||
@combinations.generate(test_base.eager_only_combinations())
|
||||
def testMultipleDatasetWithPrefixes(self):
|
||||
aggregator = stats_aggregator.StatsAggregator()
|
||||
dataset = dataset_ops.Dataset.range(100).apply(
|
||||
@ -289,6 +305,7 @@ class StatsDatasetTest(stats_dataset_test_base.StatsDatasetTestBase):
|
||||
self.assertStatisticsHasCount(handle, "dataset2::record_latency", 100.0,
|
||||
201)
|
||||
|
||||
@combinations.generate(test_base.eager_only_combinations())
|
||||
def testMultiplePrefetchStats(self):
|
||||
|
||||
aggregator = stats_aggregator.StatsAggregator()
|
||||
@ -314,8 +331,10 @@ class StatsDatasetTest(stats_dataset_test_base.StatsDatasetTestBase):
|
||||
self.evaluate(next_element())
|
||||
|
||||
|
||||
class ThreadUtilizationStatsTest(stats_dataset_test_base.StatsDatasetTestBase):
|
||||
class ThreadUtilizationStatsTest(stats_dataset_test_base.StatsDatasetTestBase,
|
||||
parameterized.TestCase):
|
||||
|
||||
@combinations.generate(test_base.eager_only_combinations())
|
||||
def testMapBufferUtilization(self):
|
||||
|
||||
def dataset_fn():
|
||||
@ -326,6 +345,7 @@ class ThreadUtilizationStatsTest(stats_dataset_test_base.StatsDatasetTestBase):
|
||||
self.parallelCallsStats(
|
||||
dataset_fn, {"ParallelMapDataset"}, 10, function_processing_time=True)
|
||||
|
||||
@combinations.generate(test_base.eager_only_combinations())
|
||||
def testMapAutoTuneBufferUtilization(self):
|
||||
|
||||
def dataset_fn():
|
||||
@ -336,6 +356,7 @@ class ThreadUtilizationStatsTest(stats_dataset_test_base.StatsDatasetTestBase):
|
||||
self.parallelCallsStats(
|
||||
dataset_fn, {"ParallelMapDataset"}, 10, function_processing_time=True)
|
||||
|
||||
@combinations.generate(test_base.eager_only_combinations())
|
||||
def testInterleaveAutoTuneBufferUtilization(self):
|
||||
|
||||
def dataset_fn():
|
||||
@ -351,6 +372,7 @@ class ThreadUtilizationStatsTest(stats_dataset_test_base.StatsDatasetTestBase):
|
||||
|
||||
self.parallelCallsStats(dataset_fn, {"ParallelInterleaveDatasetV2"}, 10)
|
||||
|
||||
@combinations.generate(test_base.eager_only_combinations())
|
||||
def testMapAndBatchAutoTuneBufferUtilization(self):
|
||||
|
||||
def dataset_fn():
|
||||
@ -370,8 +392,10 @@ class ThreadUtilizationStatsTest(stats_dataset_test_base.StatsDatasetTestBase):
|
||||
|
||||
class FeatureStatsDatasetTest(
|
||||
stats_dataset_test_base.StatsDatasetTestBase,
|
||||
reader_dataset_ops_test_base.MakeBatchedFeaturesDatasetTestBase):
|
||||
reader_dataset_ops_test_base.MakeBatchedFeaturesDatasetTestBase,
|
||||
parameterized.TestCase):
|
||||
|
||||
@combinations.generate(test_base.eager_only_combinations())
|
||||
def testFeaturesStats(self):
|
||||
num_epochs = 5
|
||||
total_records = num_epochs * self._num_records
|
||||
|
@ -23,18 +23,21 @@ import numpy as np
|
||||
from tensorflow.python.data.experimental.ops import take_while_ops
|
||||
from tensorflow.python.data.kernel_tests import test_base
|
||||
from tensorflow.python.data.ops import dataset_ops
|
||||
from tensorflow.python.framework import combinations
|
||||
from tensorflow.python.framework import constant_op
|
||||
from tensorflow.python.framework import errors
|
||||
from tensorflow.python.framework import test_util
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import math_ops
|
||||
from tensorflow.python.platform import test
|
||||
|
||||
|
||||
@test_util.run_all_in_graph_and_eager_modes
|
||||
class TakeWhileTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
|
||||
@parameterized.parameters((14, 2), (15, 2), (100, 3))
|
||||
@combinations.generate(
|
||||
combinations.times(
|
||||
test_base.default_test_combinations(),
|
||||
combinations.combine(num_elements=[14, 15], window_size=[2]) +
|
||||
combinations.combine(num_elements=[100], window_size=[3])))
|
||||
def testTakeWhileDataset(self, num_elements, window_size):
|
||||
|
||||
def _predicate_func(elem):
|
||||
@ -49,8 +52,19 @@ class TakeWhileTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
expected_num_elements = int(num_elements / window_size) * window_size
|
||||
self.assertDatasetProduces(dataset, np.arange(expected_num_elements))
|
||||
|
||||
@parameterized.parameters((10, 2, False), (16, 7, False), (100, 99, False),
|
||||
(100, 101, True), (0, 1, True))
|
||||
@combinations.generate(
|
||||
combinations.times(
|
||||
test_base.default_test_combinations(),
|
||||
combinations.combine(
|
||||
num_elements=[10], upper_bound=[2], out_of_bounds=[False]) +
|
||||
combinations.combine(
|
||||
num_elements=[16], upper_bound=[7], out_of_bounds=[False]) +
|
||||
combinations.combine(
|
||||
num_elements=[100], upper_bound=[99], out_of_bounds=[False]) +
|
||||
combinations.combine(
|
||||
num_elements=[100], upper_bound=[101], out_of_bounds=[True]) +
|
||||
combinations.combine(
|
||||
num_elements=[0], upper_bound=[1], out_of_bounds=[True])))
|
||||
def testTakeWhileDatasetRange(self, num_elements, upper_bound, out_of_bounds):
|
||||
dataset = dataset_ops.Dataset.range(num_elements).apply(
|
||||
take_while_ops.take_while(lambda x: x < upper_bound))
|
||||
@ -62,6 +76,7 @@ class TakeWhileTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
else:
|
||||
self.assertDatasetProduces(dataset, np.arange(upper_bound))
|
||||
|
||||
@combinations.generate(test_base.default_test_combinations())
|
||||
def testTakeWhileDatasetString(self):
|
||||
|
||||
def not_equal(string):
|
||||
@ -79,7 +94,13 @@ class TakeWhileTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
self.assertEqual(b"test", self.evaluate(next_element()))
|
||||
|
||||
@parameterized.parameters((5, 3), (10, 0), (100, 5), (8, 7))
|
||||
@combinations.generate(
|
||||
combinations.times(
|
||||
test_base.default_test_combinations(),
|
||||
combinations.combine(size=[5], index=[3]) +
|
||||
combinations.combine(size=[10], index=[0]) +
|
||||
combinations.combine(size=[100], index=[5]) +
|
||||
combinations.combine(size=[8], index=[7])))
|
||||
def testTakewhileDatasetShortCircuit(self, size, index):
|
||||
|
||||
def _predicate_func(data_elem):
|
||||
@ -98,6 +119,7 @@ class TakeWhileTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
self.evaluate(next_element())
|
||||
|
||||
@combinations.generate(test_base.default_test_combinations())
|
||||
def testTakeWhileDatasetWithRepeat(self):
|
||||
dataset = dataset_ops.Dataset.range(10).apply(
|
||||
take_while_ops.take_while(lambda x: x < 2)).repeat(5)
|
||||
|
@ -19,14 +19,16 @@ from __future__ import print_function
|
||||
|
||||
import os
|
||||
|
||||
from absl.testing import parameterized
|
||||
|
||||
from tensorflow.python.data.experimental.ops import grouping
|
||||
from tensorflow.python.data.experimental.ops import writers
|
||||
from tensorflow.python.data.kernel_tests import test_base
|
||||
from tensorflow.python.data.ops import dataset_ops
|
||||
from tensorflow.python.data.ops import readers
|
||||
from tensorflow.python.eager import function
|
||||
from tensorflow.python.framework import combinations
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import test_util
|
||||
from tensorflow.python.lib.io import python_io
|
||||
from tensorflow.python.lib.io import tf_record
|
||||
from tensorflow.python.ops import string_ops
|
||||
@ -34,8 +36,7 @@ from tensorflow.python.platform import test
|
||||
from tensorflow.python.util import compat
|
||||
|
||||
|
||||
@test_util.run_all_in_graph_and_eager_modes
|
||||
class TFRecordWriterTest(test_base.DatasetTestBase):
|
||||
class TFRecordWriterTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
|
||||
def setUp(self):
|
||||
super(TFRecordWriterTest, self).setUp()
|
||||
@ -63,11 +64,13 @@ class TFRecordWriterTest(test_base.DatasetTestBase):
|
||||
def _outputFilename(self):
|
||||
return os.path.join(self.get_temp_dir(), "tf_record.out.txt")
|
||||
|
||||
@combinations.generate(test_base.default_test_combinations())
|
||||
def testWrite(self):
|
||||
self.evaluate(self.writer_fn(self._createFile()))
|
||||
for i, r in enumerate(tf_record.tf_record_iterator(self._outputFilename())):
|
||||
self.assertAllEqual(self._record(i), r)
|
||||
|
||||
@combinations.generate(test_base.default_test_combinations())
|
||||
def testWriteZLIB(self):
|
||||
options = tf_record.TFRecordOptions(tf_record.TFRecordCompressionType.ZLIB)
|
||||
self.evaluate(
|
||||
@ -76,6 +79,7 @@ class TFRecordWriterTest(test_base.DatasetTestBase):
|
||||
tf_record.tf_record_iterator(self._outputFilename(), options=options)):
|
||||
self.assertAllEqual(self._record(i), r)
|
||||
|
||||
@combinations.generate(test_base.default_test_combinations())
|
||||
def testWriteGZIP(self):
|
||||
options = tf_record.TFRecordOptions(tf_record.TFRecordCompressionType.GZIP)
|
||||
self.evaluate(
|
||||
@ -84,20 +88,24 @@ class TFRecordWriterTest(test_base.DatasetTestBase):
|
||||
tf_record.tf_record_iterator(self._outputFilename(), options=options)):
|
||||
self.assertAllEqual(self._record(i), r)
|
||||
|
||||
@combinations.generate(test_base.default_test_combinations())
|
||||
def testFailDataset(self):
|
||||
with self.assertRaises(TypeError):
|
||||
writers.TFRecordWriter(self._outputFilename(), "").write("whoops")
|
||||
|
||||
@combinations.generate(test_base.default_test_combinations())
|
||||
def testFailDType(self):
|
||||
input_dataset = dataset_ops.Dataset.from_tensors(10)
|
||||
with self.assertRaises(TypeError):
|
||||
writers.TFRecordWriter(self._outputFilename(), "").write(input_dataset)
|
||||
|
||||
@combinations.generate(test_base.default_test_combinations())
|
||||
def testFailShape(self):
|
||||
input_dataset = dataset_ops.Dataset.from_tensors([["hello"], ["world"]])
|
||||
with self.assertRaises(TypeError):
|
||||
writers.TFRecordWriter(self._outputFilename(), "").write(input_dataset)
|
||||
|
||||
@combinations.generate(test_base.default_test_combinations())
|
||||
def testSideEffect(self):
|
||||
def writer_fn():
|
||||
input_dataset = readers.TFRecordDataset(self._createFile())
|
||||
@ -112,6 +120,7 @@ class TFRecordWriterTest(test_base.DatasetTestBase):
|
||||
for i, r in enumerate(tf_record.tf_record_iterator(self._outputFilename())):
|
||||
self.assertAllEqual(self._record(i), r)
|
||||
|
||||
@combinations.generate(test_base.default_test_combinations())
|
||||
def testShard(self):
|
||||
filename = self._createFile()
|
||||
dataset = readers.TFRecordDataset([filename])
|
||||
|
@ -17,17 +17,18 @@ from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from absl.testing import parameterized
|
||||
|
||||
from tensorflow.python.data.experimental.ops import unique
|
||||
from tensorflow.python.data.kernel_tests import test_base
|
||||
from tensorflow.python.data.ops import dataset_ops
|
||||
from tensorflow.python.framework import combinations
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import test_util
|
||||
from tensorflow.python.platform import test
|
||||
from tensorflow.python.util import compat
|
||||
|
||||
|
||||
@test_util.run_all_in_graph_and_eager_modes
|
||||
class UniqueTest(test_base.DatasetTestBase):
|
||||
class UniqueTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
|
||||
def _testSimpleHelper(self, dtype, test_cases):
|
||||
"""Test the `unique()` transformation on a list of test cases.
|
||||
@ -52,7 +53,7 @@ class UniqueTest(test_base.DatasetTestBase):
|
||||
for element in expected
|
||||
])
|
||||
|
||||
@test_util.run_deprecated_v1
|
||||
@combinations.generate(test_base.graph_only_combinations())
|
||||
def testSimpleInt(self):
|
||||
for dtype in [dtypes.int32, dtypes.int64]:
|
||||
self._testSimpleHelper(dtype, [
|
||||
@ -65,7 +66,7 @@ class UniqueTest(test_base.DatasetTestBase):
|
||||
([[1, 1], [1, 1], [2, 2], [3, 3], [1, 1]], [[1, 1], [2, 2], [3, 3]]),
|
||||
])
|
||||
|
||||
@test_util.run_deprecated_v1
|
||||
@combinations.generate(test_base.graph_only_combinations())
|
||||
def testSimpleString(self):
|
||||
self._testSimpleHelper(dtypes.string, [
|
||||
([], []),
|
||||
|
@ -17,16 +17,18 @@ from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from absl.testing import parameterized
|
||||
|
||||
from tensorflow.python.data.experimental.ops import cardinality
|
||||
from tensorflow.python.data.kernel_tests import test_base
|
||||
from tensorflow.python.data.ops import dataset_ops
|
||||
from tensorflow.python.framework import test_util
|
||||
from tensorflow.python.framework import combinations
|
||||
from tensorflow.python.platform import test
|
||||
|
||||
|
||||
@test_util.run_all_in_graph_and_eager_modes
|
||||
class VariantTest(test_base.DatasetTestBase):
|
||||
class VariantTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
|
||||
@combinations.generate(test_base.default_test_combinations())
|
||||
def testRoundtripRange(self):
|
||||
dataset = dataset_ops.Dataset.range(10)
|
||||
variant = dataset_ops.to_variant(dataset)
|
||||
@ -35,6 +37,7 @@ class VariantTest(test_base.DatasetTestBase):
|
||||
self.assertDatasetProduces(dataset, range(10))
|
||||
self.assertEqual(self.evaluate(cardinality.cardinality(dataset)), 10)
|
||||
|
||||
@combinations.generate(test_base.default_test_combinations())
|
||||
def testRoundtripMap(self):
|
||||
dataset = dataset_ops.Dataset.range(10).map(lambda x: x*x)
|
||||
variant = dataset_ops.to_variant(dataset)
|
||||
|
@ -17,18 +17,20 @@ from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from absl.testing import parameterized
|
||||
|
||||
from tensorflow.python.data.kernel_tests import test_base
|
||||
from tensorflow.python.data.ops import dataset_ops
|
||||
from tensorflow.python.framework import combinations
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.framework import test_util
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import gen_dataset_ops
|
||||
from tensorflow.python.platform import test
|
||||
|
||||
|
||||
@test_util.run_all_in_graph_and_eager_modes
|
||||
class WrapDatasetVariantTest(test_base.DatasetTestBase):
|
||||
class WrapDatasetVariantTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
|
||||
@combinations.generate(test_base.default_test_combinations())
|
||||
def testBasic(self):
|
||||
ds = dataset_ops.Dataset.range(100)
|
||||
ds_variant = ds._variant_tensor # pylint: disable=protected-access
|
||||
@ -42,7 +44,9 @@ class WrapDatasetVariantTest(test_base.DatasetTestBase):
|
||||
for i in range(100):
|
||||
self.assertEqual(i, self.evaluate(get_next()))
|
||||
|
||||
@test_util.run_v1_only("b/123901304")
|
||||
# TODO(b/123901304)
|
||||
@combinations.generate(
|
||||
combinations.combine(tf_api_version=[1], mode=["graph"]))
|
||||
def testSkipEagerGPU(self):
|
||||
ds = dataset_ops.Dataset.range(100)
|
||||
ds_variant = ds._variant_tensor # pylint: disable=protected-access
|
||||
|
Loading…
x
Reference in New Issue
Block a user