[tf.data] Rolling forward a previously rolled back change with a fix.

PiperOrigin-RevId: 284036647
Change-Id: I9d50ad7aa8123f6928c055a25bc3dc4d69d2b95d
This commit is contained in:
Jiri Simsa 2019-12-05 13:10:45 -08:00 committed by TensorFlower Gardener
parent 7e2b4b8c96
commit 769892b353
29 changed files with 539 additions and 190 deletions

View File

@ -25,11 +25,11 @@ from tensorflow.python.data.experimental.ops import grouping
from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.eager import context
from tensorflow.python.framework import combinations
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.platform import test
@ -73,14 +73,12 @@ def _get_record_shape(sparse):
return tensor_shape.TensorShape([None])
@test_util.run_all_in_graph_and_eager_modes
class BucketBySequenceLengthTest(test_base.DatasetTestBase,
parameterized.TestCase):
@parameterized.named_parameters(
("WithoutPadding", True),
("WithPadding", False),
)
@combinations.generate(
combinations.times(test_base.default_test_combinations(),
combinations.combine(param_no_padding=[True, False])))
def testBucketDropReminder(self, param_no_padding):
boundaries = [10, 20, 30]
@ -201,10 +199,9 @@ class BucketBySequenceLengthTest(test_base.DatasetTestBase,
_test_bucket_by_padding(param_no_padding)
@parameterized.named_parameters(
("WithoutPadding", True),
("WithPadding", False),
)
@combinations.generate(
combinations.times(test_base.default_test_combinations(),
combinations.combine(param_no_padding=[True, False])))
def testBucket(self, param_no_padding):
boundaries = [10, 20, 30]
@ -347,10 +344,9 @@ class BucketBySequenceLengthTest(test_base.DatasetTestBase,
self.assertAllEqual(batches[4], [[1, 1, 1, 1, 1, 1, 1, 1, 1, 0],
[1, 1, 1, 1, 1, 1, 1, 1, 1, 1]])
@parameterized.named_parameters(
("WithoutPadding", True),
("WithPadding", False),
)
@combinations.generate(
combinations.times(test_base.default_test_combinations(),
combinations.combine(param_no_padding=[True, False])))
def testTupleElements(self, param_no_padding):
def build_dataset(sparse):
@ -381,10 +377,10 @@ class BucketBySequenceLengthTest(test_base.DatasetTestBase,
_test_tuple_elements_by_padding(param_no_padding)
@parameterized.named_parameters(
("DoDropRemainder", True),
("DoNotDropRemainder", False),
)
@combinations.generate(
combinations.times(
test_base.default_test_combinations(),
combinations.combine(param_drop_remainder=[True, False])))
def testBucketSparse(self, param_drop_remainder): # pylint: disable=g-doc-args
"""Tests bucketing of sparse tensors (case where `no_padding` == True).

View File

@ -17,6 +17,8 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from absl.testing import parameterized
from tensorflow.core.protobuf import config_pb2
from tensorflow.python.compat import compat
from tensorflow.python.data.experimental.ops import prefetching_ops
@ -24,6 +26,7 @@ from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.ops import iterator_ops
from tensorflow.python.data.util import structure
from tensorflow.python.framework import combinations
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
@ -35,9 +38,9 @@ from tensorflow.python.util import compat as util_compat
# TODO(b/117581999): add eager coverage when supported.
class CopyToDeviceTest(test_base.DatasetTestBase):
class CopyToDeviceTest(test_base.DatasetTestBase, parameterized.TestCase):
@test_util.deprecated_graph_mode_only
@combinations.generate(test_base.graph_only_combinations())
def testCopyToDevice(self):
host_dataset = dataset_ops.Dataset.range(10)
device_dataset = host_dataset.apply(
@ -62,7 +65,7 @@ class CopyToDeviceTest(test_base.DatasetTestBase):
with self.assertRaises(errors.OutOfRangeError):
self.evaluate(next_element)
@test_util.deprecated_graph_mode_only
@combinations.generate(test_base.graph_only_combinations())
def testCopyToDeviceInt32(self):
host_dataset = dataset_ops.Dataset.from_tensors([0, 1, 2, 3])
device_dataset = host_dataset.apply(
@ -86,7 +89,7 @@ class CopyToDeviceTest(test_base.DatasetTestBase):
with self.assertRaises(errors.OutOfRangeError):
self.evaluate(next_element)
@test_util.deprecated_graph_mode_only
@combinations.generate(test_base.graph_only_combinations())
def testCopyToSameDevice(self):
host_dataset = dataset_ops.Dataset.range(10)
device_dataset = host_dataset.apply(
@ -111,7 +114,7 @@ class CopyToDeviceTest(test_base.DatasetTestBase):
with self.assertRaises(errors.OutOfRangeError):
self.evaluate(next_element)
@test_util.deprecated_graph_mode_only
@combinations.generate(test_base.graph_only_combinations())
def testCopyToDeviceWithPrefetch(self):
host_dataset = dataset_ops.Dataset.range(10)
device_dataset = host_dataset.apply(
@ -136,7 +139,7 @@ class CopyToDeviceTest(test_base.DatasetTestBase):
with self.assertRaises(errors.OutOfRangeError):
self.evaluate(next_element)
@test_util.deprecated_graph_mode_only
@combinations.generate(test_base.graph_only_combinations())
def testCopyDictToDevice(self):
host_dataset = dataset_ops.Dataset.range(10).map(lambda x: {"a": x})
device_dataset = host_dataset.apply(
@ -161,7 +164,7 @@ class CopyToDeviceTest(test_base.DatasetTestBase):
with self.assertRaises(errors.OutOfRangeError):
self.evaluate(next_element)
@test_util.deprecated_graph_mode_only
@combinations.generate(test_base.graph_only_combinations())
def testCopyDictToDeviceWithPrefetch(self):
host_dataset = dataset_ops.Dataset.range(10).map(lambda x: {"a": x})
device_dataset = host_dataset.apply(
@ -186,7 +189,7 @@ class CopyToDeviceTest(test_base.DatasetTestBase):
with self.assertRaises(errors.OutOfRangeError):
self.evaluate(next_element)
@test_util.deprecated_graph_mode_only
@combinations.generate(test_base.graph_only_combinations())
def testCopySparseTensorsToDevice(self):
def make_tensor(i):
@ -219,7 +222,7 @@ class CopyToDeviceTest(test_base.DatasetTestBase):
with self.assertRaises(errors.OutOfRangeError):
self.evaluate(next_element)
@test_util.deprecated_graph_mode_only
@combinations.generate(test_base.graph_only_combinations())
def testCopySparseTensorsToDeviceWithPrefetch(self):
def make_tensor(i):
@ -252,7 +255,7 @@ class CopyToDeviceTest(test_base.DatasetTestBase):
with self.assertRaises(errors.OutOfRangeError):
self.evaluate(next_element)
@test_util.deprecated_graph_mode_only
@combinations.generate(test_base.graph_only_combinations())
def testCopyToDeviceGpu(self):
if not test_util.is_gpu_available():
self.skipTest("No GPU available")
@ -273,7 +276,7 @@ class CopyToDeviceTest(test_base.DatasetTestBase):
with self.assertRaises(errors.OutOfRangeError):
self.evaluate(next_element)
@test_util.deprecated_graph_mode_only
@combinations.generate(test_base.graph_only_combinations())
def testCopyToDeviceGpuWithPrefetch(self):
if not test_util.is_gpu_available():
self.skipTest("No GPU available")
@ -294,7 +297,7 @@ class CopyToDeviceTest(test_base.DatasetTestBase):
with self.assertRaises(errors.OutOfRangeError):
self.evaluate(next_element)
@test_util.deprecated_graph_mode_only
@combinations.generate(test_base.graph_only_combinations())
def testCopyToDeviceGpuWithMap(self):
if not test_util.is_gpu_available():
self.skipTest("No GPU available")
@ -332,7 +335,7 @@ class CopyToDeviceTest(test_base.DatasetTestBase):
with self.assertRaises(errors.OutOfRangeError):
self.evaluate(next_element)
@test_util.deprecated_graph_mode_only
@combinations.generate(test_base.graph_only_combinations())
def testCopyToDeviceGpuInt32(self):
if not test_util.is_gpu_available():
self.skipTest("No GPU available")
@ -352,7 +355,7 @@ class CopyToDeviceTest(test_base.DatasetTestBase):
with self.assertRaises(errors.OutOfRangeError):
self.evaluate(next_element)
@test_util.deprecated_graph_mode_only
@combinations.generate(test_base.graph_only_combinations())
def testCopyToDeviceGpuInt32AndPrefetch(self):
if not test_util.is_gpu_available():
self.skipTest("No GPU available")
@ -372,7 +375,7 @@ class CopyToDeviceTest(test_base.DatasetTestBase):
with self.assertRaises(errors.OutOfRangeError):
self.evaluate(next_element)
@test_util.deprecated_graph_mode_only
@combinations.generate(test_base.graph_only_combinations())
def testCopyToDeviceGpuStrings(self):
if not test_util.is_gpu_available():
self.skipTest("No GPU available")
@ -392,7 +395,7 @@ class CopyToDeviceTest(test_base.DatasetTestBase):
with self.assertRaises(errors.OutOfRangeError):
self.evaluate(next_element)
@test_util.deprecated_graph_mode_only
@combinations.generate(test_base.graph_only_combinations())
def testCopyToDeviceGpuStringsAndPrefetch(self):
if not test_util.is_gpu_available():
self.skipTest("No GPU available")
@ -412,7 +415,7 @@ class CopyToDeviceTest(test_base.DatasetTestBase):
with self.assertRaises(errors.OutOfRangeError):
self.evaluate(next_element)
@test_util.deprecated_graph_mode_only
@combinations.generate(test_base.graph_only_combinations())
def testCopyToDevicePingPongCPUGPU(self):
if not test_util.is_gpu_available():
self.skipTest("No GPU available")
@ -436,7 +439,7 @@ class CopyToDeviceTest(test_base.DatasetTestBase):
with self.assertRaises(errors.OutOfRangeError):
self.evaluate(next_element)
@test_util.deprecated_graph_mode_only
@combinations.generate(test_base.graph_only_combinations())
def testCopyToDeviceWithReInit(self):
host_dataset = dataset_ops.Dataset.range(10)
device_dataset = host_dataset.apply(
@ -465,7 +468,7 @@ class CopyToDeviceTest(test_base.DatasetTestBase):
with self.assertRaises(errors.OutOfRangeError):
self.evaluate(next_element)
@test_util.deprecated_graph_mode_only
@combinations.generate(test_base.graph_only_combinations())
def testCopyToDeviceWithReInitAndPrefetch(self):
host_dataset = dataset_ops.Dataset.range(10)
device_dataset = host_dataset.apply(
@ -494,7 +497,7 @@ class CopyToDeviceTest(test_base.DatasetTestBase):
with self.assertRaises(errors.OutOfRangeError):
self.evaluate(next_element)
@test_util.deprecated_graph_mode_only
@combinations.generate(test_base.graph_only_combinations())
def testCopyToDeviceGpuWithReInit(self):
if not test_util.is_gpu_available():
self.skipTest("No GPU available")
@ -518,7 +521,7 @@ class CopyToDeviceTest(test_base.DatasetTestBase):
with self.assertRaises(errors.OutOfRangeError):
self.evaluate(next_element)
@test_util.deprecated_graph_mode_only
@combinations.generate(test_base.graph_only_combinations())
def testCopyToDeviceGpuWithReInitAndPrefetch(self):
if not test_util.is_gpu_available():
self.skipTest("No GPU available")
@ -542,7 +545,7 @@ class CopyToDeviceTest(test_base.DatasetTestBase):
with self.assertRaises(errors.OutOfRangeError):
self.evaluate(next_element)
@test_util.deprecated_graph_mode_only
@combinations.generate(test_base.graph_only_combinations())
def testIteratorGetNextAsOptionalOnGPU(self):
if not test_util.is_gpu_available():
self.skipTest("No GPU available")

View File

@ -17,35 +17,33 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from absl.testing import parameterized
from tensorflow.python.data.experimental.ops import counter
from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import combinations
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import test_util
from tensorflow.python.platform import test
@test_util.run_all_in_graph_and_eager_modes
class CounterTest(test_base.DatasetTestBase):
class CounterTest(test_base.DatasetTestBase, parameterized.TestCase):
def testCounter(self):
@combinations.generate(
combinations.times(
test_base.default_test_combinations(),
combinations.combine(start=3, step=4, expected_output=[[3, 7, 11]]) +
combinations.combine(start=0, step=-1, expected_output=[[0, -1, -2]]))
)
def testCounter(self, start, step, expected_output):
"""Test dataset construction using `count`."""
dataset = counter.Counter(start=3, step=4)
dataset = counter.Counter(start, step)
self.assertEqual(
[], dataset_ops.get_legacy_output_shapes(dataset).as_list())
self.assertEqual(dtypes.int64, dataset_ops.get_legacy_output_types(dataset))
get_next = self.getNext(dataset)
negative_dataset = counter.Counter(start=0, step=-1)
negative_get_next = self.getNext(negative_dataset)
self.assertEqual(3, self.evaluate(get_next()))
self.assertEqual(3 + 4, self.evaluate(get_next()))
self.assertEqual(3 + 2 * 4, self.evaluate(get_next()))
self.assertEqual(0, self.evaluate(negative_get_next()))
self.assertEqual(-1, self.evaluate(negative_get_next()))
self.assertEqual(-2, self.evaluate(negative_get_next()))
for expected in expected_output:
self.assertEqual(expected, self.evaluate(get_next()))
if __name__ == "__main__":

View File

@ -22,21 +22,22 @@ import gzip
import os
import zlib
from absl.testing import parameterized
from tensorflow.python.data.experimental.ops import error_ops
from tensorflow.python.data.experimental.ops import readers
from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import readers as core_readers
from tensorflow.python.eager import context
from tensorflow.python.framework import combinations
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import test_util
from tensorflow.python.ops import parsing_ops
from tensorflow.python.platform import test
@test_util.run_all_in_graph_and_eager_modes
class CsvDatasetTest(test_base.DatasetTestBase):
class CsvDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
def _setup_files(self, inputs, linebreak='\n', compression_type=None):
filenames = []
@ -117,26 +118,31 @@ class CsvDatasetTest(test_base.DatasetTestBase):
dataset = readers.CsvDataset(filenames, **kwargs)
self._verify_output_or_err(dataset, expected_output, expected_err_re)
@combinations.generate(test_base.default_test_combinations())
def testCsvDataset_requiredFields(self):
record_defaults = [[]] * 4
inputs = [['1,2,3,4']]
self._test_by_comparison(inputs, record_defaults=record_defaults)
@combinations.generate(test_base.default_test_combinations())
def testCsvDataset_int(self):
record_defaults = [[0]] * 4
inputs = [['1,2,3,4', '5,6,7,8']]
self._test_by_comparison(inputs, record_defaults=record_defaults)
@combinations.generate(test_base.default_test_combinations())
def testCsvDataset_float(self):
record_defaults = [[0.0]] * 4
inputs = [['1.0,2.1,3.2,4.3', '5.4,6.5,7.6,8.7']]
self._test_by_comparison(inputs, record_defaults=record_defaults)
@combinations.generate(test_base.default_test_combinations())
def testCsvDataset_string(self):
record_defaults = [['']] * 4
inputs = [['1.0,2.1,hello,4.3', '5.4,6.5,goodbye,8.7']]
self._test_by_comparison(inputs, record_defaults=record_defaults)
@combinations.generate(test_base.default_test_combinations())
def testCsvDataset_withEmptyFields(self):
record_defaults = [[0]] * 4
inputs = [[',,,', '1,1,1,', ',2,2,2']]
@ -144,6 +150,7 @@ class CsvDatasetTest(test_base.DatasetTestBase):
inputs, [[0, 0, 0, 0], [1, 1, 1, 0], [0, 2, 2, 2]],
record_defaults=record_defaults)
@combinations.generate(test_base.default_test_combinations())
def testCsvDataset_errWithUnquotedQuotes(self):
record_defaults = [['']] * 3
inputs = [['1,2"3,4']]
@ -152,6 +159,7 @@ class CsvDatasetTest(test_base.DatasetTestBase):
expected_err_re='Unquoted fields cannot have quotes inside',
record_defaults=record_defaults)
@combinations.generate(test_base.default_test_combinations())
def testCsvDataset_errWithUnescapedQuotes(self):
record_defaults = [['']] * 3
inputs = [['"a"b","c","d"']]
@ -161,6 +169,7 @@ class CsvDatasetTest(test_base.DatasetTestBase):
'Quote inside a string has to be escaped by another quote',
record_defaults=record_defaults)
@combinations.generate(test_base.default_test_combinations())
def testCsvDataset_ignoreErrWithUnescapedQuotes(self):
record_defaults = [['']] * 3
inputs = [['1,"2"3",4', '1,"2"3",4",5,5', 'a,b,"c"d"', 'e,f,g']]
@ -169,6 +178,7 @@ class CsvDatasetTest(test_base.DatasetTestBase):
dataset = dataset.apply(error_ops.ignore_errors())
self._verify_output_or_err(dataset, [['e', 'f', 'g']])
@combinations.generate(test_base.default_test_combinations())
def testCsvDataset_ignoreErrWithUnquotedQuotes(self):
record_defaults = [['']] * 3
inputs = [['1,2"3,4', 'a,b,c"d', '9,8"7,6,5', 'e,f,g']]
@ -177,12 +187,14 @@ class CsvDatasetTest(test_base.DatasetTestBase):
dataset = dataset.apply(error_ops.ignore_errors())
self._verify_output_or_err(dataset, [['e', 'f', 'g']])
@combinations.generate(test_base.default_test_combinations())
def testCsvDataset_withNoQuoteDelimAndUnquotedQuotes(self):
record_defaults = [['']] * 3
inputs = [['1,2"3,4']]
self._test_by_comparison(
inputs, record_defaults=record_defaults, use_quote_delim=False)
@combinations.generate(test_base.default_test_combinations())
def testCsvDataset_mixedTypes(self):
record_defaults = [
constant_op.constant([], dtype=dtypes.int32),
@ -193,30 +205,35 @@ class CsvDatasetTest(test_base.DatasetTestBase):
inputs = [['1,2.1,3.2,4.3', '5,6.5,7.6,8.7']]
self._test_by_comparison(inputs, record_defaults=record_defaults)
@combinations.generate(test_base.default_test_combinations())
def testCsvDataset_withUseQuoteDelimFalse(self):
record_defaults = [['']] * 4
inputs = [['1,2,"3,4"', '"5,6",7,8']]
self._test_by_comparison(
inputs, record_defaults=record_defaults, use_quote_delim=False)
@combinations.generate(test_base.default_test_combinations())
def testCsvDataset_withFieldDelim(self):
record_defaults = [[0]] * 4
inputs = [['1:2:3:4', '5:6:7:8']]
self._test_by_comparison(
inputs, record_defaults=record_defaults, field_delim=':')
@combinations.generate(test_base.default_test_combinations())
def testCsvDataset_withNaValue(self):
record_defaults = [[0]] * 4
inputs = [['1,NA,3,4', 'NA,6,7,8']]
self._test_by_comparison(
inputs, record_defaults=record_defaults, na_value='NA')
@combinations.generate(test_base.default_test_combinations())
def testCsvDataset_withSelectCols(self):
record_defaults = [['']] * 2
inputs = [['1,2,3,4', '"5","6","7","8"']]
self._test_by_comparison(
inputs, record_defaults=record_defaults, select_cols=[1, 2])
@combinations.generate(test_base.default_test_combinations())
def testCsvDataset_withSelectColsTooHigh(self):
record_defaults = [[0]] * 2
inputs = [['1,2,3,4', '5,6,7,8']]
@ -226,23 +243,27 @@ class CsvDatasetTest(test_base.DatasetTestBase):
record_defaults=record_defaults,
select_cols=[3, 4])
@combinations.generate(test_base.default_test_combinations())
def testCsvDataset_withOneCol(self):
record_defaults = [['NA']]
inputs = [['0', '', '2']]
self._test_dataset(
inputs, [['0'], ['NA'], ['2']], record_defaults=record_defaults)
@combinations.generate(test_base.default_test_combinations())
def testCsvDataset_withMultipleFiles(self):
record_defaults = [[0]] * 4
inputs = [['1,2,3,4', '5,6,7,8'], ['5,6,7,8']]
self._test_by_comparison(inputs, record_defaults=record_defaults)
@combinations.generate(test_base.default_test_combinations())
def testCsvDataset_withLeadingAndTrailingSpaces(self):
record_defaults = [[0.0]] * 4
inputs = [['0, 1, 2, 3']]
expected = [[0.0, 1.0, 2.0, 3.0]]
self._test_dataset(inputs, expected, record_defaults=record_defaults)
@combinations.generate(test_base.default_test_combinations())
def testCsvDataset_errorWithMissingDefault(self):
record_defaults = [[]] * 2
inputs = [['0,']]
@ -251,6 +272,7 @@ class CsvDatasetTest(test_base.DatasetTestBase):
expected_err_re='Field 1 is required but missing in record!',
record_defaults=record_defaults)
@combinations.generate(test_base.default_test_combinations())
def testCsvDataset_errorWithFewerDefaultsThanFields(self):
record_defaults = [[0.0]] * 2
inputs = [['0,1,2,3']]
@ -259,6 +281,7 @@ class CsvDatasetTest(test_base.DatasetTestBase):
expected_err_re='Expect 2 fields but have more in record',
record_defaults=record_defaults)
@combinations.generate(test_base.default_test_combinations())
def testCsvDataset_errorWithMoreDefaultsThanFields(self):
record_defaults = [[0.0]] * 5
inputs = [['0,1,2,3']]
@ -267,6 +290,7 @@ class CsvDatasetTest(test_base.DatasetTestBase):
expected_err_re='Expect 5 fields but have 4 in record',
record_defaults=record_defaults)
@combinations.generate(test_base.default_test_combinations())
def testCsvDataset_withHeader(self):
record_defaults = [[0]] * 2
inputs = [['col1,col2', '1,2']]
@ -278,6 +302,7 @@ class CsvDatasetTest(test_base.DatasetTestBase):
header=True,
)
@combinations.generate(test_base.default_test_combinations())
def testCsvDataset_withHeaderAndNoRecords(self):
record_defaults = [[0]] * 2
inputs = [['col1,col2']]
@ -289,6 +314,7 @@ class CsvDatasetTest(test_base.DatasetTestBase):
header=True,
)
@combinations.generate(test_base.default_test_combinations())
def testCsvDataset_errorWithHeaderEmptyFile(self):
record_defaults = [[0]] * 2
inputs = [[]]
@ -300,12 +326,14 @@ class CsvDatasetTest(test_base.DatasetTestBase):
header=True,
)
@combinations.generate(test_base.default_test_combinations())
def testCsvDataset_withEmptyFile(self):
record_defaults = [['']] * 2
inputs = [['']] # Empty file
self._test_dataset(
inputs, expected_output=[], record_defaults=record_defaults)
@combinations.generate(test_base.default_test_combinations())
def testCsvDataset_errorWithEmptyRecord(self):
record_defaults = [['']] * 2
inputs = [['', '1,2']] # First record is empty
@ -314,6 +342,7 @@ class CsvDatasetTest(test_base.DatasetTestBase):
expected_err_re='Expect 2 fields but have 1 in record',
record_defaults=record_defaults)
@combinations.generate(test_base.default_test_combinations())
def testCsvDataset_withChainedOps(self):
# Testing that one dataset can create multiple iterators fine.
# `repeat` creates multiple iterators from the same C++ Dataset.
@ -325,6 +354,7 @@ class CsvDatasetTest(test_base.DatasetTestBase):
ds_actual.repeat(5).prefetch(1),
ds_expected.repeat(5).prefetch(1))
@combinations.generate(test_base.default_test_combinations())
def testCsvDataset_withTypeDefaults(self):
# Testing using dtypes as record_defaults for required fields
record_defaults = [dtypes.float32, [0.0]]
@ -335,6 +365,7 @@ class CsvDatasetTest(test_base.DatasetTestBase):
record_defaults=record_defaults,
)
@combinations.generate(test_base.default_test_combinations())
def testMakeCsvDataset_fieldOrder(self):
data = [[
'1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19',
@ -352,6 +383,7 @@ class CsvDatasetTest(test_base.DatasetTestBase):
## The following tests exercise parsing logic for quoted fields
@combinations.generate(test_base.default_test_combinations())
def testCsvDataset_withQuoted(self):
record_defaults = [['']] * 4
inputs = [['"a","b","c :)","d"', '"e","f","g :(","h"']]
@ -363,6 +395,7 @@ class CsvDatasetTest(test_base.DatasetTestBase):
self._test_dataset(
inputs, [['0'], ['1'], ['2']], record_defaults=record_defaults)
@combinations.generate(test_base.default_test_combinations())
def testCsvDataset_withNewLine(self):
# In this case, we expect it to behave differently from
# TextLineDataset->map(decode_csv) since that flow has bugs
@ -371,6 +404,7 @@ class CsvDatasetTest(test_base.DatasetTestBase):
expected = [['a', 'b', '"c"\n0', 'd\ne'], ['f', 'g', 'h', 'i']]
self._test_dataset(inputs, expected, record_defaults=record_defaults)
@combinations.generate(test_base.default_test_combinations())
def testCsvDataset_withNewLineInUnselectedCol(self):
record_defaults = [['']]
inputs = [['1,"2\n3",4', '5,6,7']]
@ -380,6 +414,7 @@ class CsvDatasetTest(test_base.DatasetTestBase):
record_defaults=record_defaults,
select_cols=[0])
@combinations.generate(test_base.default_test_combinations())
def testCsvDataset_withMultipleNewLines(self):
# In this case, we expect it to behave differently from
# TextLineDataset->map(decode_csv) since that flow has bugs
@ -388,6 +423,7 @@ class CsvDatasetTest(test_base.DatasetTestBase):
expected = [['a', 'b\n\nx', '"c"\n \n0', 'd\ne'], ['f', 'g', 'h', 'i']]
self._test_dataset(inputs, expected, record_defaults=record_defaults)
@combinations.generate(test_base.default_test_combinations())
def testCsvDataset_errorWithTerminateMidRecord(self):
record_defaults = [['']] * 4
inputs = [['a,b,c,"a']]
@ -397,6 +433,7 @@ class CsvDatasetTest(test_base.DatasetTestBase):
'Reached end of file without closing quoted field in record',
record_defaults=record_defaults)
@combinations.generate(test_base.default_test_combinations())
def testCsvDataset_withEscapedQuotes(self):
record_defaults = [['']] * 4
inputs = [['1.0,2.1,"she said: ""hello""",4.3', '5.4,6.5,goodbye,8.7']]
@ -406,6 +443,7 @@ class CsvDatasetTest(test_base.DatasetTestBase):
## Testing that parsing works with all buffer sizes, quoted/unquoted fields,
## and different types of line breaks
@combinations.generate(test_base.default_test_combinations())
def testCsvDataset_withInvalidBufferSize(self):
record_defaults = [['']] * 4
inputs = [['a,b,c,d']]
@ -432,6 +470,7 @@ class CsvDatasetTest(test_base.DatasetTestBase):
record_defaults=record_defaults,
buffer_size=i)
@combinations.generate(test_base.default_test_combinations())
def testCsvDataset_withLF(self):
record_defaults = [['NA']] * 3
inputs = [['abc,def,ghi', '0,1,2', ',,']]
@ -439,6 +478,7 @@ class CsvDatasetTest(test_base.DatasetTestBase):
self._test_dataset_on_buffer_sizes(
inputs, expected, linebreak='\n', record_defaults=record_defaults)
@combinations.generate(test_base.default_test_combinations())
def testCsvDataset_withCR(self):
# Test that when the line separator is '\r', parsing works with all buffer
# sizes
@ -448,6 +488,7 @@ class CsvDatasetTest(test_base.DatasetTestBase):
self._test_dataset_on_buffer_sizes(
inputs, expected, linebreak='\r', record_defaults=record_defaults)
@combinations.generate(test_base.default_test_combinations())
def testCsvDataset_withCRLF(self):
# Test that when the line separator is '\r\n', parsing works with all buffer
# sizes
@ -457,6 +498,7 @@ class CsvDatasetTest(test_base.DatasetTestBase):
self._test_dataset_on_buffer_sizes(
inputs, expected, linebreak='\r\n', record_defaults=record_defaults)
@combinations.generate(test_base.default_test_combinations())
def testCsvDataset_withBufferSizeAndQuoted(self):
record_defaults = [['NA']] * 3
inputs = [['"\n\n\n","\r\r\r","abc"', '"0","1","2"', '"","",""']]
@ -465,6 +507,7 @@ class CsvDatasetTest(test_base.DatasetTestBase):
self._test_dataset_on_buffer_sizes(
inputs, expected, linebreak='\n', record_defaults=record_defaults)
@combinations.generate(test_base.default_test_combinations())
def testCsvDataset_withCRAndQuoted(self):
# Test that when the line separator is '\r', parsing works with all buffer
# sizes
@ -475,6 +518,7 @@ class CsvDatasetTest(test_base.DatasetTestBase):
self._test_dataset_on_buffer_sizes(
inputs, expected, linebreak='\r', record_defaults=record_defaults)
@combinations.generate(test_base.default_test_combinations())
def testCsvDataset_withCRLFAndQuoted(self):
# Test that when the line separator is '\r\n', parsing works with all buffer
# sizes
@ -485,6 +529,7 @@ class CsvDatasetTest(test_base.DatasetTestBase):
self._test_dataset_on_buffer_sizes(
inputs, expected, linebreak='\r\n', record_defaults=record_defaults)
@combinations.generate(test_base.default_test_combinations())
def testCsvDataset_withGzipCompressionType(self):
record_defaults = [['NA']] * 3
inputs = [['"\n\n\n","\r\r\r","abc"', '"0","1","2"', '"","",""']]
@ -497,6 +542,7 @@ class CsvDatasetTest(test_base.DatasetTestBase):
compression_type='GZIP',
record_defaults=record_defaults)
@combinations.generate(test_base.default_test_combinations())
def testCsvDataset_withZlibCompressionType(self):
record_defaults = [['NA']] * 3
inputs = [['"\n\n\n","\r\r\r","abc"', '"0","1","2"', '"","",""']]
@ -509,6 +555,7 @@ class CsvDatasetTest(test_base.DatasetTestBase):
compression_type='ZLIB',
record_defaults=record_defaults)
@combinations.generate(test_base.default_test_combinations())
def testCsvDataset_withScalarDefaults(self):
record_defaults = [constant_op.constant(0, dtype=dtypes.int64)] * 4
inputs = [[',,,', '1,1,1,', ',2,2,2']]
@ -516,6 +563,7 @@ class CsvDatasetTest(test_base.DatasetTestBase):
inputs, [[0, 0, 0, 0], [1, 1, 1, 0], [0, 2, 2, 2]],
record_defaults=record_defaults)
@combinations.generate(test_base.default_test_combinations())
def testCsvDataset_with2DDefaults(self):
record_defaults = [constant_op.constant([[0]], dtype=dtypes.int64)] * 4
inputs = [[',,,', '1,1,1,', ',2,2,2']]

View File

@ -17,20 +17,21 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from absl.testing import parameterized
import numpy as np
from tensorflow.python.data.experimental.ops import batching
from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import combinations
from tensorflow.python.framework import errors
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.platform import test
@test_util.run_all_in_graph_and_eager_modes
class DenseToSparseBatchTest(test_base.DatasetTestBase):
class DenseToSparseBatchTest(test_base.DatasetTestBase, parameterized.TestCase):
@combinations.generate(test_base.default_test_combinations())
def testDenseToSparseBatchDataset(self):
components = np.random.randint(12, size=(100,)).astype(np.int32)
dataset = dataset_ops.Dataset.from_tensor_slices(
@ -53,6 +54,7 @@ class DenseToSparseBatchTest(test_base.DatasetTestBase):
with self.assertRaises(errors.OutOfRangeError):
self.evaluate(get_next())
@combinations.generate(test_base.default_test_combinations())
def testDenseToSparseBatchDatasetWithUnknownShape(self):
components = np.random.randint(5, size=(40,)).astype(np.int32)
dataset = dataset_ops.Dataset.from_tensor_slices(
@ -80,12 +82,14 @@ class DenseToSparseBatchTest(test_base.DatasetTestBase):
with self.assertRaises(errors.OutOfRangeError):
self.evaluate(get_next())
@combinations.generate(test_base.default_test_combinations())
def testDenseToSparseBatchDatasetWithInvalidShape(self):
input_tensor = array_ops.constant([[1]])
with self.assertRaisesRegexp(ValueError, "Dimension -2 must be >= 0"):
dataset_ops.Dataset.from_tensors(input_tensor).apply(
batching.dense_to_sparse_batch(4, [-2]))
@combinations.generate(test_base.default_test_combinations())
def testDenseToSparseBatchDatasetShapeErrors(self):
def dataset_fn(input_tensor):

View File

@ -17,22 +17,24 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from absl.testing import parameterized
import numpy as np
from tensorflow.python.data.experimental.ops import interleave_ops
from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import combinations
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import random_seed
from tensorflow.python.framework import test_util
from tensorflow.python.platform import test
@test_util.run_all_in_graph_and_eager_modes
class DirectedInterleaveDatasetTest(test_base.DatasetTestBase):
class DirectedInterleaveDatasetTest(test_base.DatasetTestBase,
parameterized.TestCase):
@combinations.generate(test_base.default_test_combinations())
def testBasic(self):
selector_dataset = dataset_ops.Dataset.range(10).repeat(100)
input_datasets = [
@ -76,6 +78,7 @@ class DirectedInterleaveDatasetTest(test_base.DatasetTestBase):
return freqs
@combinations.generate(test_base.default_test_combinations())
def testSampleFromDatasets(self):
random_seed.set_random_seed(1619)
num_samples = 5000
@ -95,6 +98,7 @@ class DirectedInterleaveDatasetTest(test_base.DatasetTestBase):
freqs = self._testSampleFromDatasetsHelper(probs_ds, classes, num_samples)
self.assertLess(self._chi2(probs, freqs / num_samples), 1e-2)
@combinations.generate(test_base.default_test_combinations())
def testSelectFromDatasets(self):
words = [b"foo", b"bar", b"baz"]
datasets = [dataset_ops.Dataset.from_tensors(w).repeat() for w in words]
@ -107,6 +111,7 @@ class DirectedInterleaveDatasetTest(test_base.DatasetTestBase):
with self.assertRaises(errors.OutOfRangeError):
self.evaluate(next_element())
@combinations.generate(test_base.default_test_combinations())
def testErrors(self):
with self.assertRaisesRegexp(ValueError,
r"vector of length `len\(datasets\)`"):

View File

@ -23,25 +23,30 @@ from tensorflow.python.data.experimental.ops import get_single_element
from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.eager import function
from tensorflow.python.framework import combinations
from tensorflow.python.framework import errors
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import variables
from tensorflow.python.platform import test
@test_util.run_all_in_graph_and_eager_modes
class GetSingleElementTest(test_base.DatasetTestBase, parameterized.TestCase):
@parameterized.named_parameters(
("Zero", 0, 1),
("Five", 5, 1),
("Ten", 10, 1),
("Empty", 100, 1, errors.InvalidArgumentError, "Dataset was empty."),
("MoreThanOne", 0, 2, errors.InvalidArgumentError,
"Dataset had more than one element."),
)
@combinations.generate(
combinations.times(
test_base.default_test_combinations(),
combinations.combine(
skip=[0, 5, 10], take=[1], error=[None], error_msg=[None]) +
combinations.combine(
skip=[100],
take=[1],
error=[errors.InvalidArgumentError],
error_msg=["Dataset was empty."]) + combinations.combine(
skip=[0],
take=[2],
error=[errors.InvalidArgumentError],
error_msg=["Dataset had more than one element."])))
def testGetSingleElement(self, skip, take, error=None, error_msg=None):
def make_sparse(x):
@ -62,6 +67,7 @@ class GetSingleElementTest(test_base.DatasetTestBase, parameterized.TestCase):
with self.assertRaisesRegexp(error, error_msg):
self.evaluate(get_single_element.get_single_element(dataset))
@combinations.generate(test_base.default_test_combinations())
def testWindow(self):
"""Test that `get_single_element()` can consume a nested dataset."""
def flat_map_func(ds):
@ -73,6 +79,7 @@ class GetSingleElementTest(test_base.DatasetTestBase, parameterized.TestCase):
self.assertDatasetProduces(
dataset, [[0, 1], [2, 3], [4, 5], [6, 7], [8, 9]])
@combinations.generate(test_base.default_test_combinations())
def testSideEffect(self):
counter_var = variables.Variable(0)
@ -92,6 +99,7 @@ class GetSingleElementTest(test_base.DatasetTestBase, parameterized.TestCase):
self.assertEqual(self.evaluate(fn()), b"hello")
self.assertEqual(self.evaluate(counter_var), 1)
@combinations.generate(test_base.default_test_combinations())
def testAutomaticControlDependencies(self):
counter_var = variables.Variable(1)

View File

@ -17,25 +17,26 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from absl.testing import parameterized
import numpy as np
from tensorflow.python.data.experimental.ops import grouping
from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import combinations
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.platform import test
@test_util.run_all_in_graph_and_eager_modes
class GroupByReducerTest(test_base.DatasetTestBase):
class GroupByReducerTest(test_base.DatasetTestBase, parameterized.TestCase):
@combinations.generate(test_base.default_test_combinations())
def testSum(self):
reducer = grouping.Reducer(
init_func=lambda _: np.int64(0),
@ -49,6 +50,7 @@ class GroupByReducerTest(test_base.DatasetTestBase):
expected_shapes=tensor_shape.TensorShape([]),
expected_output=[(i - 1) * i, i * i])
@combinations.generate(test_base.default_test_combinations())
def testAverage(self):
def reduce_fn(x, y):
@ -68,6 +70,7 @@ class GroupByReducerTest(test_base.DatasetTestBase):
expected_shapes=tensor_shape.TensorShape([]),
expected_output=[i - 1, i])
@combinations.generate(test_base.default_test_combinations())
def testConcat(self):
components = np.array(list("abcdefghijklmnopqrst")).view(np.chararray)
reducer = grouping.Reducer(
@ -84,6 +87,7 @@ class GroupByReducerTest(test_base.DatasetTestBase):
expected_shapes=tensor_shape.TensorShape([]),
expected_output=[b"acegikmoqs"[:i], b"bdfhjlnprt"[:i]])
@combinations.generate(test_base.default_test_combinations())
def testSparseSum(self):
def _sparse(i):
return sparse_tensor.SparseTensorValue(
@ -103,6 +107,7 @@ class GroupByReducerTest(test_base.DatasetTestBase):
expected_shapes=tensor_shape.TensorShape([]),
expected_output=[(i - 1) * i, i * i])
@combinations.generate(test_base.default_test_combinations())
def testChangingStateShape(self):
def reduce_fn(x, _):
@ -130,6 +135,7 @@ class GroupByReducerTest(test_base.DatasetTestBase):
with self.assertRaises(errors.OutOfRangeError):
self.evaluate(get_next())
@combinations.generate(test_base.default_test_combinations())
def testTypeMismatch(self):
reducer = grouping.Reducer(
init_func=lambda x: constant_op.constant(1, dtype=dtypes.int32),
@ -144,6 +150,7 @@ class GroupByReducerTest(test_base.DatasetTestBase):
grouping.group_by_reducer(lambda _: np.int64(0), reducer))
# TODO(b/78665031): Remove once non-scalar keys are supported.
@combinations.generate(test_base.default_test_combinations())
def testInvalidKeyShape(self):
reducer = grouping.Reducer(
init_func=lambda x: np.int64(0),
@ -157,6 +164,7 @@ class GroupByReducerTest(test_base.DatasetTestBase):
grouping.group_by_reducer(lambda _: np.int64((0, 0)), reducer))
# TODO(b/78665031): Remove once non-int64 keys are supported.
@combinations.generate(test_base.default_test_combinations())
def testInvalidKeyType(self):
reducer = grouping.Reducer(
init_func=lambda x: np.int64(0),
@ -169,6 +177,7 @@ class GroupByReducerTest(test_base.DatasetTestBase):
dataset.apply(
grouping.group_by_reducer(lambda _: "wrong", reducer))
@combinations.generate(test_base.default_test_combinations())
def testTuple(self):
def init_fn(_):
return np.array([], dtype=np.int64), np.int64(0)

View File

@ -17,17 +17,18 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from absl.testing import parameterized
import numpy as np
from tensorflow.python.data.experimental.ops import grouping
from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import combinations
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import string_ops
@ -37,8 +38,7 @@ from tensorflow.python.platform import test
# NOTE(mrry): These tests are based on the tests in bucket_ops_test.py.
# Currently, they use a constant batch size, though should be made to use a
# different batch size per key.
@test_util.run_all_in_graph_and_eager_modes
class GroupByWindowTest(test_base.DatasetTestBase):
class GroupByWindowTest(test_base.DatasetTestBase, parameterized.TestCase):
def _dynamicPad(self, bucket, window, window_size):
# TODO(mrry): To match `tf.contrib.training.bucket()`, implement a
@ -51,6 +51,7 @@ class GroupByWindowTest(test_base.DatasetTestBase):
32, (tensor_shape.TensorShape([]), tensor_shape.TensorShape(
[None]), tensor_shape.TensorShape([3])))))
@combinations.generate(test_base.default_test_combinations())
def testSingleBucket(self):
def _map_fn(v):
@ -80,6 +81,7 @@ class GroupByWindowTest(test_base.DatasetTestBase):
self.assertAllEqual(expected_unk_int64, bucketed_values[1])
self.assertAllEqual(expected_vec3_str, bucketed_values[2])
@combinations.generate(test_base.default_test_combinations())
def testEvenOddBuckets(self):
def _map_fn(v):
@ -132,6 +134,7 @@ class GroupByWindowTest(test_base.DatasetTestBase):
self.assertAllEqual(expected_unk_int64, bucketed_values_odd[1])
self.assertAllEqual(expected_vec3_str, bucketed_values_odd[2])
@combinations.generate(test_base.default_test_combinations())
def testEvenOddBucketsFilterOutAllOdd(self):
def _map_fn(v):
@ -173,6 +176,7 @@ class GroupByWindowTest(test_base.DatasetTestBase):
self.assertAllEqual(
np.arange(64, 128, 2, dtype=np.int64), bucketed_values_even1["x"])
@combinations.generate(test_base.default_test_combinations())
def testDynamicWindowSize(self):
components = np.arange(100).astype(np.int64)
@ -202,6 +206,7 @@ class GroupByWindowTest(test_base.DatasetTestBase):
self.assertEqual(batches, 15)
@combinations.generate(test_base.default_test_combinations())
def testSimple(self):
components = np.random.randint(100, size=(200,)).astype(np.int64)
dataset = dataset_ops.Dataset.from_tensor_slices(
@ -222,6 +227,7 @@ class GroupByWindowTest(test_base.DatasetTestBase):
self.assertGreaterEqual(num_full_batches, 24)
self.assertTrue(all(c == 4 for c in counts[:num_full_batches]))
@combinations.generate(test_base.default_test_combinations())
def testImmediateOutput(self):
components = np.array(
[0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 0, 0, 2, 2, 0, 0], dtype=np.int64)
@ -240,6 +246,7 @@ class GroupByWindowTest(test_base.DatasetTestBase):
self.assertAllEqual([2, 2, 2, 2], self.evaluate(get_next()))
self.assertAllEqual([0, 0, 0, 0], self.evaluate(get_next()))
@combinations.generate(test_base.default_test_combinations())
def testSmallGroups(self):
components = np.array([0, 0, 0, 0, 1, 1, 1, 1, 1, 0, 0, 0], dtype=np.int64)
dataset = dataset_ops.Dataset.from_tensor_slices(components).apply(
@ -252,6 +259,7 @@ class GroupByWindowTest(test_base.DatasetTestBase):
self.assertAllEqual([0, 0, 0], self.evaluate(get_next()))
self.assertAllEqual([1], self.evaluate(get_next()))
@combinations.generate(test_base.default_test_combinations())
def testEmpty(self):
dataset = dataset_ops.Dataset.range(4).apply(
grouping.group_by_window(lambda _: 0, lambda _, xs: xs, 0))
@ -262,6 +270,7 @@ class GroupByWindowTest(test_base.DatasetTestBase):
"Window size must be greater than zero, but got 0."):
print(self.evaluate(get_next()))
@combinations.generate(test_base.default_test_combinations())
def testReduceFuncError(self):
components = np.random.randint(100, size=(200,)).astype(np.int64)
@ -280,6 +289,7 @@ class GroupByWindowTest(test_base.DatasetTestBase):
with self.assertRaises(errors.InvalidArgumentError):
self.evaluate(get_next())
@combinations.generate(test_base.default_test_combinations())
def testConsumeWindowDatasetMoreThanOnce(self):
components = np.random.randint(50, size=(200,)).astype(np.int64)
@ -311,6 +321,7 @@ class GroupByWindowTest(test_base.DatasetTestBase):
counts.append(tight_result.shape[0])
self.assertEqual(len(components), sum(counts))
@combinations.generate(test_base.default_test_combinations())
def testShortCircuit(self):
dataset = dataset_ops.Dataset.range(10)

View File

@ -19,14 +19,15 @@ from __future__ import print_function
import os
from absl.testing import parameterized
import numpy as np
from tensorflow.python.data.experimental.ops import error_ops
from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.ops import readers
from tensorflow.python.framework import combinations
from tensorflow.python.framework import errors
from tensorflow.python.framework import test_util
from tensorflow.python.lib.io import python_io
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import io_ops
@ -36,9 +37,9 @@ from tensorflow.python.util import compat
_NUMPY_RANDOM_SEED = 42
@test_util.run_all_in_graph_and_eager_modes
class IgnoreErrorsTest(test_base.DatasetTestBase):
class IgnoreErrorsTest(test_base.DatasetTestBase, parameterized.TestCase):
@combinations.generate(test_base.default_test_combinations())
def testMapIgnoreError(self):
components = np.array([1., 2., 3., np.nan, 5.]).astype(np.float32)
@ -53,6 +54,7 @@ class IgnoreErrorsTest(test_base.DatasetTestBase):
with self.assertRaises(errors.OutOfRangeError):
self.evaluate(get_next())
@combinations.generate(test_base.default_test_combinations())
def testParallelMapIgnoreError(self):
components = np.array([1., 2., 3., np.nan, 5.]).astype(np.float32)
@ -67,6 +69,7 @@ class IgnoreErrorsTest(test_base.DatasetTestBase):
with self.assertRaises(errors.OutOfRangeError):
self.evaluate(get_next())
@combinations.generate(test_base.default_test_combinations())
def testReadFileIgnoreError(self):
def write_string_to_file(value, filename):
@ -102,6 +105,7 @@ class IgnoreErrorsTest(test_base.DatasetTestBase):
with self.assertRaises(errors.OutOfRangeError):
self.evaluate(get_next())
@combinations.generate(test_base.default_test_combinations())
def testTFRecordDatasetIgnoreError(self):
filenames = []
for i in range(5):
@ -126,6 +130,7 @@ class IgnoreErrorsTest(test_base.DatasetTestBase):
with self.assertRaises(errors.OutOfRangeError):
self.evaluate(get_next())
@combinations.generate(test_base.default_test_combinations())
def testZipIgnoreError(self):
a = dataset_ops.Dataset.from_tensor_slices([1., 2., 0., 4.])
b = a.map(lambda x: array_ops.check_numerics(1. / x, "error"))

View File

@ -17,26 +17,29 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from absl.testing import parameterized
import numpy as np
from tensorflow.python.data.experimental.kernel_tests import reader_dataset_ops_test_base
from tensorflow.python.data.experimental.ops import readers
from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.ops import readers as core_readers
from tensorflow.python.data.util import nest
from tensorflow.python.framework import combinations
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
from tensorflow.python.ops import io_ops
from tensorflow.python.ops import parsing_ops
from tensorflow.python.platform import test
@test_util.run_all_in_graph_and_eager_modes
class MakeBatchedFeaturesDatasetTest(
reader_dataset_ops_test_base.MakeBatchedFeaturesDatasetTestBase):
reader_dataset_ops_test_base.MakeBatchedFeaturesDatasetTestBase,
parameterized.TestCase):
@combinations.generate(test_base.default_test_combinations())
def testRead(self):
for batch_size in [1, 2]:
for num_epochs in [1, 10]:
@ -85,6 +88,7 @@ class MakeBatchedFeaturesDatasetTest(
with self.assertRaises(errors.OutOfRangeError):
self._next_actual_batch()
@combinations.generate(test_base.default_test_combinations())
def testReadWithEquivalentDataset(self):
features = {
"file": parsing_ops.FixedLenFeature([], dtypes.int64),
@ -103,6 +107,7 @@ class MakeBatchedFeaturesDatasetTest(
with self.assertRaises(errors.OutOfRangeError):
self.evaluate(next_element())
@combinations.generate(test_base.default_test_combinations())
def testReadWithFusedShuffleRepeatDataset(self):
num_epochs = 5
total_records = num_epochs * self._num_records
@ -151,6 +156,7 @@ class MakeBatchedFeaturesDatasetTest(
all_equal = all_equal and np.array_equal(batch1[i], batch2[i])
self.assertFalse(all_equal)
@combinations.generate(test_base.default_test_combinations())
def testParallelReadersAndParsers(self):
num_epochs = 5
for batch_size in [1, 2]:
@ -186,6 +192,7 @@ class MakeBatchedFeaturesDatasetTest(
with self.assertRaises(errors.OutOfRangeError):
self._next_actual_batch()
@combinations.generate(test_base.default_test_combinations())
def testDropFinalBatch(self):
for batch_size in [1, 2]:
for num_epochs in [1, 10]:
@ -201,6 +208,7 @@ class MakeBatchedFeaturesDatasetTest(
if isinstance(tensor, ops.Tensor): # Guard against SparseTensor.
self.assertEqual(tensor.shape[0], batch_size)
@combinations.generate(test_base.default_test_combinations())
def testIndefiniteRepeatShapeInference(self):
dataset = self.make_batch_feature(
filenames=self.test_filenames[0],
@ -213,6 +221,7 @@ class MakeBatchedFeaturesDatasetTest(
if issubclass(clazz, ops.Tensor):
self.assertEqual(32, shape[0])
@combinations.generate(test_base.default_test_combinations())
def testOldStyleReader(self):
with self.assertRaisesRegexp(
TypeError, r"The `reader` argument must return a `Dataset` object. "

View File

@ -21,21 +21,21 @@ import gzip
import os
import zlib
from absl.testing import parameterized
import numpy as np
from tensorflow.python.data.experimental.ops import readers
from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.util import nest
from tensorflow.python.framework import combinations
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import test_util
from tensorflow.python.platform import test
@test_util.run_all_in_graph_and_eager_modes
class MakeCsvDatasetTest(test_base.DatasetTestBase):
class MakeCsvDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
def _make_csv_dataset(self, filenames, batch_size, num_epochs=1, **kwargs):
return readers.make_csv_dataset(
@ -126,6 +126,7 @@ class MakeCsvDatasetTest(test_base.DatasetTestBase):
self._verify_output(dataset, batch_size, num_epochs, label_name,
expected_output, expected_keys)
@combinations.generate(test_base.default_test_combinations())
def testMakeCSVDataset(self):
"""Tests making a CSV dataset with keys and defaults provided."""
record_defaults = [
@ -157,6 +158,7 @@ class MakeCsvDatasetTest(test_base.DatasetTestBase):
column_defaults=record_defaults,
)
@combinations.generate(test_base.default_test_combinations())
def testMakeCSVDataset_withBatchSizeAndEpochs(self):
"""Tests making a CSV dataset with keys and defaults provided."""
record_defaults = [
@ -188,6 +190,7 @@ class MakeCsvDatasetTest(test_base.DatasetTestBase):
column_defaults=record_defaults,
)
@combinations.generate(test_base.default_test_combinations())
def testMakeCSVDataset_withCompressionType(self):
"""Tests `compression_type` argument."""
record_defaults = [
@ -221,6 +224,7 @@ class MakeCsvDatasetTest(test_base.DatasetTestBase):
compression_type=compression_type,
)
@combinations.generate(test_base.default_test_combinations())
def testMakeCSVDataset_withCompressionTypeAndNoColumnNames(self):
"""Tests `compression_type` argument."""
record_defaults = [
@ -269,6 +273,7 @@ class MakeCsvDatasetTest(test_base.DatasetTestBase):
compression_type="ZLIB",
)
@combinations.generate(test_base.default_test_combinations())
def testMakeCSVDataset_withBadInputs(self):
"""Tests that exception is raised when input is malformed.
"""
@ -304,6 +309,7 @@ class MakeCsvDatasetTest(test_base.DatasetTestBase):
label_name="not_a_real_label",
column_names=column_names)
@combinations.generate(test_base.default_test_combinations())
def testMakeCSVDataset_withNoLabel(self):
"""Tests making a CSV dataset with no label provided."""
record_defaults = [
@ -333,6 +339,7 @@ class MakeCsvDatasetTest(test_base.DatasetTestBase):
column_defaults=record_defaults,
)
@combinations.generate(test_base.default_test_combinations())
def testMakeCSVDataset_withNoHeader(self):
"""Tests that datasets can be created from CSV files with no header line.
"""
@ -363,6 +370,7 @@ class MakeCsvDatasetTest(test_base.DatasetTestBase):
column_defaults=record_defaults,
)
@combinations.generate(test_base.default_test_combinations())
def testMakeCSVDataset_withTypes(self):
"""Tests that defaults can be a dtype instead of a Tensor for required vals.
"""
@ -394,6 +402,7 @@ class MakeCsvDatasetTest(test_base.DatasetTestBase):
column_defaults=record_defaults,
)
@combinations.generate(test_base.default_test_combinations())
def testMakeCSVDataset_withNoColNames(self):
"""Tests that datasets can be created when column names are not specified.
@ -427,6 +436,7 @@ class MakeCsvDatasetTest(test_base.DatasetTestBase):
column_defaults=record_defaults,
)
@combinations.generate(test_base.default_test_combinations())
def testMakeCSVDataset_withTypeInferenceMismatch(self):
# Test that error is thrown when num fields doesn't match columns
column_names = ["col%d" % i for i in range(5)]
@ -442,6 +452,7 @@ class MakeCsvDatasetTest(test_base.DatasetTestBase):
batch_size=2,
num_epochs=10)
@combinations.generate(test_base.default_test_combinations())
def testMakeCSVDataset_withTypeInference(self):
"""Tests that datasets can be created when no defaults are specified.
@ -468,6 +479,7 @@ class MakeCsvDatasetTest(test_base.DatasetTestBase):
header=True,
)
@combinations.generate(test_base.default_test_combinations())
def testMakeCSVDataset_withTypeInferenceFallthrough(self):
"""Tests that datasets can be created when no defaults are specified.
@ -498,6 +510,7 @@ class MakeCsvDatasetTest(test_base.DatasetTestBase):
header=True,
)
@combinations.generate(test_base.default_test_combinations())
def testMakeCSVDataset_withNAValuesAndFieldDelim(self):
"""Tests that datasets can be created from different delim and na_value."""
column_names = ["col%d" % i for i in range(5)]
@ -520,6 +533,7 @@ class MakeCsvDatasetTest(test_base.DatasetTestBase):
field_delim=" ",
)
@combinations.generate(test_base.default_test_combinations())
def testMakeCSVDataset_withSelectCols(self):
record_defaults = [
constant_op.constant([], dtypes.int32),
@ -588,6 +602,7 @@ class MakeCsvDatasetTest(test_base.DatasetTestBase):
select_columns=[column_names[i] for i in select_cols],
)
@combinations.generate(test_base.default_test_combinations())
def testMakeCSVDataset_withSelectColsError(self):
record_defaults = [
constant_op.constant([], dtypes.int32),
@ -626,6 +641,7 @@ class MakeCsvDatasetTest(test_base.DatasetTestBase):
label_name=None,
select_columns=["invalid_col_name"])
@combinations.generate(test_base.default_test_combinations())
def testMakeCSVDataset_withShuffle(self):
record_defaults = [
constant_op.constant([], dtypes.int32),
@ -710,6 +726,7 @@ class MakeCsvDatasetTest(test_base.DatasetTestBase):
all_equal = all_equal and np.array_equal(batch1[i], batch2[i])
self.assertFalse(all_equal)
@combinations.generate(test_base.default_test_combinations())
def testIndefiniteRepeatShapeInference(self):
column_names = ["col%d" % i for i in range(5)]
inputs = [[",".join(x for x in column_names), "0,1,2,3,4", "5,6,7,8,9"], [

View File

@ -17,19 +17,22 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from absl.testing import parameterized
from tensorflow.python.data.experimental.kernel_tests import reader_dataset_ops_test_base
from tensorflow.python.data.experimental.ops import readers
from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.util import nest
from tensorflow.python.framework import combinations
from tensorflow.python.framework import errors
from tensorflow.python.framework import test_util
from tensorflow.python.ops import string_ops
from tensorflow.python.platform import test
@test_util.run_all_in_graph_and_eager_modes
class MakeTFRecordDatasetTest(
reader_dataset_ops_test_base.TFRecordDatasetTestBase):
reader_dataset_ops_test_base.TFRecordDatasetTestBase,
parameterized.TestCase):
def _read_test(self, batch_size, num_epochs, file_index=None,
num_parallel_reads=1, drop_final_batch=False, parser_fn=False):
@ -63,6 +66,7 @@ class MakeTFRecordDatasetTest(
with self.assertRaises(errors.OutOfRangeError):
self.evaluate(outputs())
@combinations.generate(test_base.default_test_combinations())
def testRead(self):
for batch_size in [1, 2]:
for num_epochs in [1, 3]:
@ -78,6 +82,7 @@ class MakeTFRecordDatasetTest(
# Basic test: read from both files, with parallel reads.
self._read_test(batch_size, num_epochs, num_parallel_reads=8)
@combinations.generate(test_base.default_test_combinations())
def testDropFinalBatch(self):
for batch_size in [1, 2, 10]:
for num_epochs in [1, 3]:
@ -91,6 +96,7 @@ class MakeTFRecordDatasetTest(
self._read_test(batch_size, num_epochs, num_parallel_reads=8,
drop_final_batch=True)
@combinations.generate(test_base.default_test_combinations())
def testParserFn(self):
for batch_size in [1, 2]:
for num_epochs in [1, 3]:
@ -145,6 +151,7 @@ class MakeTFRecordDatasetTest(
actual.extend(b)
self.assertAllEqual(sorted(expected), sorted(actual))
@combinations.generate(test_base.default_test_combinations())
def testShuffle(self):
for batch_size in [1, 2]:
for num_epochs in [1, 3]:
@ -156,6 +163,7 @@ class MakeTFRecordDatasetTest(
self._shuffle_test(batch_size, num_epochs, num_parallel_reads,
seed=21345)
@combinations.generate(test_base.default_test_combinations())
def testIndefiniteRepeatShapeInference(self):
dataset = readers.make_tf_record_dataset(
file_pattern=self.test_filenames, num_epochs=None, batch_size=32)

View File

@ -19,17 +19,19 @@ from __future__ import print_function
import time
from absl.testing import parameterized
from tensorflow.python.client import session
from tensorflow.python.data.experimental.ops import map_defun
from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.eager import function
from tensorflow.python.framework import combinations
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.framework import tensor_spec
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import check_ops
from tensorflow.python.ops import data_flow_ops
@ -38,9 +40,11 @@ from tensorflow.python.ops import sparse_ops
from tensorflow.python.platform import test
@test_util.run_v1_only("b/123903858: Add eager and V2 test coverage")
class MapDefunTest(test_base.DatasetTestBase):
# TODO(b/123903858): Add eager and V2 test coverage
class MapDefunTest(test_base.DatasetTestBase, parameterized.TestCase):
@combinations.generate(
combinations.combine(tf_api_version=[1], mode=["graph"]))
def testNoIntraOpLimit(self):
@function.defun(input_signature=[tensor_spec.TensorSpec([2], dtypes.int32)])
@ -55,6 +59,8 @@ class MapDefunTest(test_base.DatasetTestBase):
expected = elems * 2 + 3
self.assertAllEqual(self.evaluate(r), self.evaluate(expected))
@combinations.generate(
combinations.combine(tf_api_version=[1], mode=["graph"]))
def testMapDefunSimple(self):
@function.defun(input_signature=[tensor_spec.TensorSpec([2], dtypes.int32)])
@ -67,6 +73,8 @@ class MapDefunTest(test_base.DatasetTestBase):
expected = elems * 2 + 3
self.assertAllEqual(self.evaluate(r), self.evaluate(expected))
@combinations.generate(
combinations.combine(tf_api_version=[1], mode=["graph"]))
def testMapDefunMismatchedTypes(self):
@function.defun(input_signature=[tensor_spec.TensorSpec([], dtypes.int32)])
@ -79,6 +87,8 @@ class MapDefunTest(test_base.DatasetTestBase):
with self.assertRaises(errors.InvalidArgumentError):
self.evaluate(r)
@combinations.generate(
combinations.combine(tf_api_version=[1], mode=["graph"]))
def testMapDefunReduceDim(self):
# Tests where the output has a different rank from the input
@ -92,6 +102,8 @@ class MapDefunTest(test_base.DatasetTestBase):
expected = constant_op.constant([1, 3, 5])
self.assertAllEqual(self.evaluate(r), self.evaluate(expected))
@combinations.generate(
combinations.combine(tf_api_version=[1], mode=["graph"]))
def testMapDefunMultipleOutputs(self):
@function.defun(input_signature=[tensor_spec.TensorSpec([2], dtypes.int32)])
@ -105,6 +117,8 @@ class MapDefunTest(test_base.DatasetTestBase):
expected = [elems, elems * 2 + 3]
self.assertAllEqual(self.evaluate(r), self.evaluate(expected))
@combinations.generate(
combinations.combine(tf_api_version=[1], mode=["graph"]))
def testMapDefunShapeInference(self):
@function.defun(input_signature=[tensor_spec.TensorSpec([2], dtypes.int32)])
@ -116,6 +130,8 @@ class MapDefunTest(test_base.DatasetTestBase):
result = map_defun.map_defun(fn, [elems], [dtypes.int32], [(2,)])[0]
self.assertEqual(result.get_shape(), (3, 2))
@combinations.generate(
combinations.combine(tf_api_version=[1], mode=["graph"]))
def testMapDefunPartialShapeInference(self):
@function.defun(input_signature=[tensor_spec.TensorSpec([2], dtypes.int32)])
@ -126,6 +142,8 @@ class MapDefunTest(test_base.DatasetTestBase):
result = map_defun.map_defun(fn, [elems], [dtypes.int32], [(2,)])
self.assertEqual(result[0].get_shape().as_list(), [None, 2])
@combinations.generate(
combinations.combine(tf_api_version=[1], mode=["graph"]))
def testMapDefunRaisesErrorOnRuntimeShapeMismatch(self):
@function.defun(input_signature=[
@ -145,6 +163,8 @@ class MapDefunTest(test_base.DatasetTestBase):
"All inputs must have the same dimension 0."):
sess.run(result, feed_dict={elems1: [1, 2, 3, 4, 5], elems2: [1, 2, 3]})
@combinations.generate(
combinations.combine(tf_api_version=[1], mode=["graph"]))
def testMapDefunRaisesDefunError(self):
@function.defun(input_signature=[tensor_spec.TensorSpec([], dtypes.int32)])
@ -157,6 +177,8 @@ class MapDefunTest(test_base.DatasetTestBase):
with self.assertRaises(errors.InvalidArgumentError):
self.evaluate(result)
@combinations.generate(
combinations.combine(tf_api_version=[1], mode=["graph"]))
def testMapDefunCancelledCorrectly(self):
@function.defun(input_signature=[tensor_spec.TensorSpec([5], dtypes.int64)])
@ -173,6 +195,8 @@ class MapDefunTest(test_base.DatasetTestBase):
r"indices = 10 is not in \[0, 5\)"):
self.evaluate(map_defun_op)
@combinations.generate(
combinations.combine(tf_api_version=[1], mode=["graph"]))
def testMapDefunWithUnspecifiedOutputShape(self):
@function.defun(input_signature=[tensor_spec.TensorSpec([2], dtypes.int32)])
@ -190,6 +214,8 @@ class MapDefunTest(test_base.DatasetTestBase):
self.assertAllEqual(self.evaluate(r[1]), self.evaluate(expected + 1))
self.assertAllEqual(self.evaluate(r[2]), self.evaluate(expected + 2))
@combinations.generate(
combinations.combine(tf_api_version=[1], mode=["graph"]))
def testMapDefunWithDifferentOutputShapeEachRun(self):
@function.defun(
@ -204,6 +230,8 @@ class MapDefunTest(test_base.DatasetTestBase):
self.assertAllEqual(
sess.run(r, feed_dict={elems: [[0], [1]]}), [[3], [5]])
@combinations.generate(
combinations.combine(tf_api_version=[1], mode=["graph"]))
def testMapDefunWithWrongOutputShape(self):
@function.defun(input_signature=[tensor_spec.TensorSpec([2], dtypes.int32)])
@ -216,6 +244,8 @@ class MapDefunTest(test_base.DatasetTestBase):
with self.assertRaises(errors.InvalidArgumentError):
self.evaluate(r)
@combinations.generate(
combinations.combine(tf_api_version=[1], mode=["graph"]))
def testMapDefunWithInvalidInput(self):
@function.defun(
@ -233,6 +263,8 @@ class MapDefunTest(test_base.DatasetTestBase):
with self.assertRaises(errors.InvalidArgumentError):
sess.run(r, feed_dict={p: 0})
@combinations.generate(
combinations.combine(tf_api_version=[1], mode=["graph"]))
def testMapDefunWithParentCancellation(self):
# Checks that a cancellation of the parent graph is threaded through to
# MapDefunOp correctly.
@ -254,6 +286,8 @@ class MapDefunTest(test_base.DatasetTestBase):
sess.close()
thread.join()
@combinations.generate(
combinations.combine(tf_api_version=[1], mode=["graph"]))
def testMapDefunWithCapturedInputs(self):
c = constant_op.constant(2)
@ -266,6 +300,8 @@ class MapDefunTest(test_base.DatasetTestBase):
expected = x + c
self.assertAllEqual(self.evaluate(expected), self.evaluate(map_defun_op))
@combinations.generate(
combinations.combine(tf_api_version=[1], mode=["graph"]))
def testMapDefunWithVariantTensor(self):
@function.defun(
@ -288,6 +324,8 @@ class MapDefunTest(test_base.DatasetTestBase):
actual = self.evaluate(deserialized)
self.assertValuesEqual(expected, actual)
@combinations.generate(
combinations.combine(tf_api_version=[1], mode=["graph"]))
def testMapDefunWithVariantTensorAsCaptured(self):
st = sparse_tensor.SparseTensor(
@ -309,6 +347,8 @@ class MapDefunTest(test_base.DatasetTestBase):
actual = self.evaluate(deserialized)
self.assertValuesEqual(expected, actual)
@combinations.generate(
combinations.combine(tf_api_version=[1], mode=["graph"]))
def testMapDefunWithStrTensor(self):
@function.defun(input_signature=[tensor_spec.TensorSpec([], dtypes.string)])

View File

@ -28,14 +28,13 @@ from tensorflow.python.data.experimental.ops import threadpool
from tensorflow.python.data.experimental.ops import unique
from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import combinations
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import test_util
from tensorflow.python.ops import script_ops
from tensorflow.python.platform import test
@test_util.run_all_in_graph_and_eager_modes
class OverrideThreadpoolTest(test_base.DatasetTestBase,
parameterized.TestCase):
@ -70,17 +69,13 @@ class OverrideThreadpoolTest(test_base.DatasetTestBase,
# perform work.
self.assertLessEqual(len(thread_ids), num_threads)
@parameterized.named_parameters(
("1", 1, None),
("2", 2, None),
("3", 4, None),
("4", 8, None),
("5", 16, None),
("6", 4, -1),
("7", 4, 0),
("8", 4, 1),
("9", 4, 4),
)
@combinations.generate(
combinations.times(
test_base.default_test_combinations(),
combinations.combine(
num_threads=[1, 2, 4, 8, 16], max_intra_op_parallelism=[None]) +
combinations.combine(
num_threads=[4], max_intra_op_parallelism=[-1, 0, 4])))
def testNumThreadsDeprecated(self, num_threads, max_intra_op_parallelism):
def override_threadpool_fn(dataset):
@ -93,20 +88,17 @@ class OverrideThreadpoolTest(test_base.DatasetTestBase,
self._testNumThreadsHelper(num_threads, override_threadpool_fn)
@parameterized.named_parameters(
("1", 1, None),
("2", 2, None),
("3", 4, None),
("4", 8, None),
("5", 16, None),
("6", None, 0),
("7", None, 1),
("8", None, 4),
("9", 4, 0),
("10", 4, 1),
("11", 4, 4),
("12", None, None),
)
@combinations.generate(
combinations.times(
test_base.default_test_combinations(),
combinations.combine(
num_threads=[1, 2, 4, 8, 16], max_intra_op_parallelism=[None]) +
combinations.combine(
num_threads=[None], max_intra_op_parallelism=[0, 1, 4]) +
combinations.combine(
num_threads=[4], max_intra_op_parallelism=[0, 1, 4]) +
combinations.combine(
num_threads=[None], max_intra_op_parallelism=[None])))
def testNumThreads(self, num_threads, max_intra_op_parallelism):
def override_threadpool_fn(dataset):
@ -121,6 +113,7 @@ class OverrideThreadpoolTest(test_base.DatasetTestBase,
self._testNumThreadsHelper(num_threads, override_threadpool_fn)
@combinations.generate(test_base.default_test_combinations())
def testMaxIntraOpParallelismAsGraphDefInternal(self):
dataset = dataset_ops.Dataset.from_tensors(0)
dataset = dataset_ops._MaxIntraOpParallelismDataset(dataset, 1)

View File

@ -22,24 +22,25 @@ import math
import threading
import time
from absl.testing import parameterized
import numpy as np
from six.moves import zip_longest
from tensorflow.python.data.experimental.ops import interleave_ops
from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import combinations
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.framework import test_util
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import script_ops
from tensorflow.python.ops import sparse_ops
from tensorflow.python.platform import test
@test_util.run_all_in_graph_and_eager_modes
class ParallelInterleaveTest(test_base.DatasetTestBase):
# TODO(feihugis): refactor this test to be parameterized.
class ParallelInterleaveTest(test_base.DatasetTestBase, parameterized.TestCase):
def setUp(self):
@ -116,6 +117,7 @@ class ParallelInterleaveTest(test_base.DatasetTestBase):
num_open -= 1
break
@combinations.generate(test_base.default_test_combinations())
def testPythonImplementation(self):
input_lists = [[4, 4, 4, 4], [5, 5, 5, 5, 5], [6, 6, 6, 6, 6, 6],
[4, 4, 4, 4], [5, 5, 5, 5, 5], [6, 6, 6, 6, 6, 6]]
@ -136,6 +138,7 @@ class ParallelInterleaveTest(test_base.DatasetTestBase):
self.assertEqual(expected, produced, "Values differ at %s. %s != %s" %
(index, expected, produced))
@combinations.generate(test_base.default_test_combinations())
def testPythonImplementationBlockLength(self):
input_lists = [[4] * 4, [5] * 5, [6] * 6] * 2
expected_elements = [
@ -147,6 +150,7 @@ class ParallelInterleaveTest(test_base.DatasetTestBase):
self.assertEqual(expected, produced, "Values differ at %s. %s != %s" %
(index, expected, produced))
@combinations.generate(test_base.default_test_combinations())
def testPythonImplementationEmptyLists(self):
input_lists = [[4, 4, 4, 4], [], [6, 6, 6, 6, 6, 6], [4, 4, 4, 4], [],
[6, 6, 6, 6, 6, 6]]
@ -189,18 +193,23 @@ class ParallelInterleaveTest(test_base.DatasetTestBase):
with self.assertRaises(errors.OutOfRangeError):
self.evaluate(next_element())
@combinations.generate(test_base.default_test_combinations())
def testSingleThreaded(self):
self._testSingleThreaded()
@combinations.generate(test_base.default_test_combinations())
def testSingleThreadedSloppy(self):
self._testSingleThreaded(sloppy=True)
@combinations.generate(test_base.default_test_combinations())
def testSingleThreadedPrefetch1Itr(self):
self._testSingleThreaded(prefetch_input_elements=1)
@combinations.generate(test_base.default_test_combinations())
def testSingleThreadedPrefetch1ItrSloppy(self):
self._testSingleThreaded(prefetch_input_elements=1, sloppy=True)
@combinations.generate(test_base.default_test_combinations())
def testSingleThreadedRagged(self):
# Tests a sequence with wildly different elements per iterator.
self.skipTest("b/131722904")
@ -259,9 +268,11 @@ class ParallelInterleaveTest(test_base.DatasetTestBase):
with self.assertRaises(errors.OutOfRangeError):
self.evaluate(next_element())
@combinations.generate(test_base.default_test_combinations())
def testTwoThreadsNoContention(self):
self._testTwoThreadsNoContention()
@combinations.generate(test_base.default_test_combinations())
def testTwoThreadsNoContentionSloppy(self):
self._testTwoThreadsNoContention(sloppy=True)
@ -306,9 +317,11 @@ class ParallelInterleaveTest(test_base.DatasetTestBase):
with self.assertRaises(errors.OutOfRangeError):
self.evaluate(next_element())
@combinations.generate(test_base.default_test_combinations())
def testTwoThreadsNoContentionWithRaces(self):
self._testTwoThreadsNoContentionWithRaces()
@combinations.generate(test_base.default_test_combinations())
def testTwoThreadsNoContentionWithRacesSloppy(self):
self._testTwoThreadsNoContentionWithRaces(sloppy=True)
@ -343,9 +356,11 @@ class ParallelInterleaveTest(test_base.DatasetTestBase):
with self.assertRaises(errors.OutOfRangeError):
self.evaluate(next_element())
@combinations.generate(test_base.default_test_combinations())
def testTwoThreadsNoContentionBlockLength(self):
self._testTwoThreadsNoContentionBlockLength()
@combinations.generate(test_base.default_test_combinations())
def testTwoThreadsNoContentionBlockLengthSloppy(self):
self._testTwoThreadsNoContentionBlockLength(sloppy=True)
@ -391,9 +406,11 @@ class ParallelInterleaveTest(test_base.DatasetTestBase):
with self.assertRaises(errors.OutOfRangeError):
self.evaluate(next_element())
@combinations.generate(test_base.default_test_combinations())
def testTwoThreadsNoContentionWithRacesAndBlocking(self):
self._testTwoThreadsNoContentionWithRacesAndBlocking()
@combinations.generate(test_base.default_test_combinations())
def testTwoThreadsNoContentionWithRacesAndBlockingSloppy(self):
self._testTwoThreadsNoContentionWithRacesAndBlocking(sloppy=True)
@ -411,9 +428,11 @@ class ParallelInterleaveTest(test_base.DatasetTestBase):
with self.assertRaises(errors.OutOfRangeError):
self.evaluate(next_element())
@combinations.generate(test_base.default_test_combinations())
def testEmptyInput(self):
self._testEmptyInput()
@combinations.generate(test_base.default_test_combinations())
def testEmptyInputSloppy(self):
self._testEmptyInput(sloppy=True)
@ -431,9 +450,11 @@ class ParallelInterleaveTest(test_base.DatasetTestBase):
with self.assertRaises(errors.OutOfRangeError):
self.evaluate(next_element())
@combinations.generate(test_base.default_test_combinations())
def testNonEmptyInputIntoEmptyOutputs(self):
self._testNonEmptyInputIntoEmptyOutputs()
@combinations.generate(test_base.default_test_combinations())
def testNonEmptyInputIntoEmptyOutputsSloppy(self):
self._testNonEmptyInputIntoEmptyOutputs(sloppy=True)
@ -469,12 +490,15 @@ class ParallelInterleaveTest(test_base.DatasetTestBase):
"At index %s: %s expected, got: %s" % (i, expected_element,
actual_element))
@combinations.generate(test_base.default_test_combinations())
def testPartiallyEmptyOutputs(self):
self._testPartiallyEmptyOutputs()
@combinations.generate(test_base.default_test_combinations())
def testPartiallyEmptyOutputsSloppy(self):
self._testPartiallyEmptyOutputs(sloppy=True, prefetch_input_elements=0)
@combinations.generate(test_base.default_test_combinations())
def testDelayedOutputSloppy(self):
# Explicitly control the sequence of events to ensure we correctly avoid
# head-of-line blocking.
@ -500,6 +524,7 @@ class ParallelInterleaveTest(test_base.DatasetTestBase):
with self.assertRaises(errors.OutOfRangeError):
self.evaluate(next_element())
@combinations.generate(test_base.default_test_combinations())
def testBlockLengthWithContentionSloppy(self):
self.skipTest("b/131722904")
self._clear_coordination_events()
@ -557,9 +582,11 @@ class ParallelInterleaveTest(test_base.DatasetTestBase):
self.read_coordination_events[i].acquire()
self.write_coordination_events[i].set()
@combinations.generate(test_base.default_test_combinations())
def testEarlyExit(self):
self._testEarlyExit()
@combinations.generate(test_base.default_test_combinations())
def testEarlyExitSloppy(self):
self._testEarlyExit(sloppy=True)
@ -584,12 +611,15 @@ class ParallelInterleaveTest(test_base.DatasetTestBase):
[[4] * 4, [5] * 5, [6] * 6] * self.repeat_count, 1, 2)
self.assertItemsEqual(output_values, expected_values)
@combinations.generate(test_base.default_test_combinations())
def testTooManyReaders(self):
self._testTooManyReaders()
@combinations.generate(test_base.default_test_combinations())
def testTooManyReadersSloppy(self):
self._testTooManyReaders(sloppy=True)
@combinations.generate(test_base.default_test_combinations())
def testSparse(self):
def _map_fn(i):
return sparse_tensor.SparseTensor(
@ -610,6 +640,7 @@ class ParallelInterleaveTest(test_base.DatasetTestBase):
with self.assertRaises(errors.OutOfRangeError):
self.evaluate(get_next())
@combinations.generate(test_base.default_test_combinations())
def testErrorsInOutputFn(self):
self.skipTest("b/131722904")
self._clear_coordination_events()
@ -642,6 +673,7 @@ class ParallelInterleaveTest(test_base.DatasetTestBase):
with self.assertRaises(errors.OutOfRangeError):
self.evaluate(next_element())
@combinations.generate(test_base.default_test_combinations())
def testErrorsInInputFn(self):
def map_py_fn(x):
@ -687,6 +719,7 @@ class ParallelInterleaveTest(test_base.DatasetTestBase):
with self.assertRaises(errors.OutOfRangeError):
self.evaluate(next_element())
@combinations.generate(test_base.default_test_combinations())
def testErrorsInInterleaveFn(self):
def map_py_fn(x):
@ -730,6 +763,7 @@ class ParallelInterleaveTest(test_base.DatasetTestBase):
with self.assertRaises(errors.OutOfRangeError):
self.evaluate(next_element())
@combinations.generate(test_base.default_test_combinations())
def testShutdownRace(self):
dataset = dataset_ops.Dataset.range(20)
map_fn = lambda x: dataset_ops.Dataset.range(20 * x, 20 * (x + 1))

View File

@ -20,6 +20,7 @@ from __future__ import print_function
import copy
from absl.testing import parameterized
import numpy as np
from tensorflow.core.example import example_pb2
@ -28,11 +29,11 @@ from tensorflow.python.data.experimental.ops import parsing_ops as contrib_parsi
from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.eager import context
from tensorflow.python.framework import combinations
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors_impl
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.framework import test_util
from tensorflow.python.ops import parsing_ops
from tensorflow.python.ops.ragged import ragged_factory_ops
from tensorflow.python.platform import test
@ -50,8 +51,8 @@ feature_lists = lambda d: feature_pb2.FeatureLists(feature_list=d)
sequence_example = example_pb2.SequenceExample
@test_util.run_all_in_graph_and_eager_modes
class ParseExampleDatasetTest(test_base.DatasetTestBase):
class ParseExampleDatasetTest(test_base.DatasetTestBase,
parameterized.TestCase):
def _compare_output_to_expected(self, dict_tensors, expected_tensors):
self.assertEqual(set(dict_tensors.keys()), set(expected_tensors.keys()))
@ -107,6 +108,7 @@ class ParseExampleDatasetTest(test_base.DatasetTestBase):
self.assertEqual(
dataset_ops.get_legacy_output_shapes(dataset)[k].as_list()[1], None)
@combinations.generate(test_base.default_test_combinations())
def testEmptySerializedWithAllDefaults(self):
sparse_name = "st_a"
a_name = "a"
@ -145,7 +147,7 @@ class ParseExampleDatasetTest(test_base.DatasetTestBase):
expected_values=expected_output,
create_iterator_twice=True)
@test_util.run_deprecated_v1
@combinations.generate(test_base.graph_only_combinations())
def testEmptySerializedWithoutDefaultsShouldFail(self):
input_features = {
"st_a":
@ -179,7 +181,7 @@ class ParseExampleDatasetTest(test_base.DatasetTestBase):
expected_err=(errors_impl.InvalidArgumentError,
"Feature: c \\(data type: float\\) is required"))
@test_util.run_deprecated_v1
@combinations.generate(test_base.graph_only_combinations())
def testDenseNotMatchingShapeShouldFail(self):
original = [
example(features=features({
@ -197,6 +199,7 @@ class ParseExampleDatasetTest(test_base.DatasetTestBase):
expected_err=(errors_impl.InvalidArgumentError,
"Key: a, Index: 1. Number of float values"))
@combinations.generate(test_base.default_test_combinations())
def testDenseDefaultNoShapeShouldFail(self):
original = [example(features=features({"a": float_feature([1, 1, 3]),})),]
@ -207,6 +210,7 @@ class ParseExampleDatasetTest(test_base.DatasetTestBase):
{"a": parsing_ops.FixedLenFeature(None, dtypes.float32)},
expected_err=(ValueError, "Missing shape for feature a"))
@combinations.generate(test_base.default_test_combinations())
def testSerializedContainingSparse(self):
original = [
example(features=features({
@ -248,6 +252,7 @@ class ParseExampleDatasetTest(test_base.DatasetTestBase):
expected_values=expected_output,
create_iterator_twice=True)
@combinations.generate(test_base.default_test_combinations())
def testSerializedContainingSparseFeature(self):
original = [
example(features=features({
@ -284,6 +289,7 @@ class ParseExampleDatasetTest(test_base.DatasetTestBase):
expected_values=expected_output,
create_iterator_twice=True)
@combinations.generate(test_base.default_test_combinations())
def testSerializedContainingSparseFeatureReuse(self):
original = [
example(features=features({
@ -325,6 +331,7 @@ class ParseExampleDatasetTest(test_base.DatasetTestBase):
expected_values=expected_output,
create_iterator_twice=True)
@combinations.generate(test_base.default_test_combinations())
def testSerializedContaining3DSparseFeature(self):
original = [
example(features=features({
@ -370,6 +377,7 @@ class ParseExampleDatasetTest(test_base.DatasetTestBase):
expected_values=expected_output,
create_iterator_twice=True)
@combinations.generate(test_base.default_test_combinations())
def testSerializedContainingDense(self):
aname = "a"
bname = "b*has+a:tricky_name"
@ -407,6 +415,7 @@ class ParseExampleDatasetTest(test_base.DatasetTestBase):
# This test is identical as the previous one except
# for the creation of 'serialized'.
@combinations.generate(test_base.default_test_combinations())
def testSerializedContainingDenseWithConcat(self):
aname = "a"
bname = "b*has+a:tricky_name"
@ -452,6 +461,7 @@ class ParseExampleDatasetTest(test_base.DatasetTestBase):
expected_values=expected_output,
create_iterator_twice=True)
@combinations.generate(test_base.default_test_combinations())
def testSerializedContainingDenseScalar(self):
original = [
example(features=features({
@ -476,6 +486,7 @@ class ParseExampleDatasetTest(test_base.DatasetTestBase):
expected_values=expected_output,
create_iterator_twice=True)
@combinations.generate(test_base.default_test_combinations())
def testSerializedContainingDenseWithDefaults(self):
original = [
example(features=features({
@ -514,6 +525,7 @@ class ParseExampleDatasetTest(test_base.DatasetTestBase):
expected_values=expected_output,
create_iterator_twice=True)
@combinations.generate(test_base.default_test_combinations())
def testSerializedSparseAndSparseFeatureAndDenseWithNoDefault(self):
expected_st_a = sparse_tensor.SparseTensorValue( # indices, values, shape
np.empty((0, 2), dtype=np.int64), # indices
@ -569,6 +581,7 @@ class ParseExampleDatasetTest(test_base.DatasetTestBase):
expected_values=expected_output,
create_iterator_twice=True)
@combinations.generate(test_base.default_test_combinations())
def testerializedContainingSparseAndSparseFeatureWithReuse(self):
expected_idx = sparse_tensor.SparseTensorValue( # indices, values, shape
np.array([[0, 0], [0, 1], [1, 0], [1, 1]], dtype=np.int64),
@ -667,11 +680,13 @@ class ParseExampleDatasetTest(test_base.DatasetTestBase):
expected_values=expected_output,
create_iterator_twice=True)
@combinations.generate(test_base.default_test_combinations())
def testSerializedContainingVarLenDenseLargerBatch(self):
np.random.seed(3456)
for batch_size in (1, 10, 20, 100, 256):
self._testSerializedContainingVarLenDenseLargerBatch(batch_size)
@combinations.generate(test_base.default_test_combinations())
def testSerializedShapeMismatch(self):
aname = "a"
bname = "b"
@ -724,7 +739,7 @@ class ParseExampleDatasetTest(test_base.DatasetTestBase):
expected_err=(ValueError,
"Cannot reshape a tensor with 0 elements to shape"))
@test_util.run_deprecated_v1
@combinations.generate(test_base.graph_only_combinations())
def testSerializedContainingVarLenDense(self):
aname = "a"
bname = "b"
@ -877,6 +892,7 @@ class ParseExampleDatasetTest(test_base.DatasetTestBase):
"Unsupported: FixedLenSequenceFeature requires "
"allow_missing to be True."))
@combinations.generate(test_base.default_test_combinations())
def testSerializedContainingRaggedFeatureWithNoPartitions(self):
original = [
example(
@ -922,6 +938,7 @@ class ParseExampleDatasetTest(test_base.DatasetTestBase):
expected_values=expected_output,
create_iterator_twice=True)
@combinations.generate(test_base.default_test_combinations())
def testSerializedContainingRaggedFeatureWithOnePartition(self):
original = [
example(
@ -1040,6 +1057,7 @@ class ParseExampleDatasetTest(test_base.DatasetTestBase):
expected_values=expected_output,
create_iterator_twice=True)
@combinations.generate(test_base.default_test_combinations())
def testSerializedContainingRaggedFeatureWithMultiplePartitions(self):
original = [
# rt shape: [(batch), 2, None, None]

View File

@ -17,11 +17,14 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from absl.testing import parameterized
from tensorflow.core.protobuf import config_pb2
from tensorflow.python.data.experimental.ops import prefetching_ops
from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.util import structure
from tensorflow.python.framework import combinations
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
@ -31,9 +34,9 @@ from tensorflow.python.platform import test
# TODO(b/117581999): add eager coverage when supported.
class PrefetchToDeviceTest(test_base.DatasetTestBase):
class PrefetchToDeviceTest(test_base.DatasetTestBase, parameterized.TestCase):
@test_util.deprecated_graph_mode_only
@combinations.generate(test_base.graph_only_combinations())
def testPrefetchToDevice(self):
host_dataset = dataset_ops.Dataset.range(10)
device_dataset = host_dataset.apply(
@ -57,7 +60,7 @@ class PrefetchToDeviceTest(test_base.DatasetTestBase):
with self.assertRaises(errors.OutOfRangeError):
self.evaluate(next_element)
@test_util.deprecated_graph_mode_only
@combinations.generate(test_base.graph_only_combinations())
def testPrefetchToSameDevice(self):
host_dataset = dataset_ops.Dataset.range(10)
device_dataset = host_dataset.apply(
@ -82,7 +85,7 @@ class PrefetchToDeviceTest(test_base.DatasetTestBase):
with self.assertRaises(errors.OutOfRangeError):
self.evaluate(next_element)
@test_util.deprecated_graph_mode_only
@combinations.generate(test_base.graph_only_combinations())
def testPrefetchDictToDevice(self):
host_dataset = dataset_ops.Dataset.range(10).map(lambda x: {"a": x})
device_dataset = host_dataset.apply(
@ -106,7 +109,7 @@ class PrefetchToDeviceTest(test_base.DatasetTestBase):
with self.assertRaises(errors.OutOfRangeError):
self.evaluate(next_element)
@test_util.deprecated_graph_mode_only
@combinations.generate(test_base.graph_only_combinations())
def testPrefetchSparseTensorsToDevice(self):
def make_tensor(i):
return sparse_tensor.SparseTensorValue(
@ -136,7 +139,7 @@ class PrefetchToDeviceTest(test_base.DatasetTestBase):
with self.assertRaises(errors.OutOfRangeError):
self.evaluate(next_element)
@test_util.deprecated_graph_mode_only
@combinations.generate(test_base.graph_only_combinations())
def testPrefetchToDeviceGpu(self):
if not test_util.is_gpu_available():
self.skipTest("No GPU available")
@ -156,7 +159,7 @@ class PrefetchToDeviceTest(test_base.DatasetTestBase):
with self.assertRaises(errors.OutOfRangeError):
self.evaluate(next_element)
@test_util.deprecated_graph_mode_only
@combinations.generate(test_base.graph_only_combinations())
def testPrefetchToDeviceWithReInit(self):
host_dataset = dataset_ops.Dataset.range(10)
device_dataset = host_dataset.apply(
@ -184,7 +187,7 @@ class PrefetchToDeviceTest(test_base.DatasetTestBase):
with self.assertRaises(errors.OutOfRangeError):
self.evaluate(next_element)
@test_util.deprecated_graph_mode_only
@combinations.generate(test_base.graph_only_combinations())
def testPrefetchToDeviceGpuWithReInit(self):
if not test_util.is_gpu_available():
self.skipTest("No GPU available")

View File

@ -24,16 +24,17 @@ from tensorflow.core.protobuf import config_pb2
from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.ops import multi_device_iterator_ops
from tensorflow.python.framework import combinations
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
from tensorflow.python.platform import test
@test_util.run_all_in_graph_and_eager_modes
class PrefetchWithSlackTest(test_base.DatasetTestBase, parameterized.TestCase):
@test_util.run_v1_only("b/121264236")
# TODO(b/121264236)
@combinations.generate(
combinations.combine(tf_api_version=[1], mode=["graph"]))
def testPrefetchWithSlackOption(self):
"""Determines slack_period based on num devices attached to iterator."""
dataset = dataset_ops.Dataset.range(10)
@ -60,6 +61,7 @@ class PrefetchWithSlackTest(test_base.DatasetTestBase, parameterized.TestCase):
self.evaluate(elem_on_1)
self.evaluate(elem_on_2)
@combinations.generate(test_base.default_test_combinations())
def testPrefetchWithSlackOptionWithoutIterator(self):
"""Defaults to slack period of 1 without iterator."""
dataset = dataset_ops.Dataset.range(10)
@ -72,6 +74,7 @@ class PrefetchWithSlackTest(test_base.DatasetTestBase, parameterized.TestCase):
dataset.options()._graph_rewrite_configs())
self.assertDatasetProduces(dataset, range(10))
@combinations.generate(test_base.default_test_combinations())
def testWithPassthroughDataset(self):
"""Should still work with a passthrough dataset after prefetch()."""
dataset = dataset_ops.Dataset.range(10)
@ -82,6 +85,7 @@ class PrefetchWithSlackTest(test_base.DatasetTestBase, parameterized.TestCase):
dataset = dataset.with_options(options)
self.assertDatasetProduces(dataset, range(1, 11))
@combinations.generate(test_base.default_test_combinations())
def testErrorWithoutPrefetch(self):
"""The rewrite fails if there is no prefetch() in the pipeline."""
dataset = dataset_ops.Dataset.range(10)
@ -92,6 +96,7 @@ class PrefetchWithSlackTest(test_base.DatasetTestBase, parameterized.TestCase):
get_next = self.getNext(dataset)
self.evaluate(get_next())
@combinations.generate(test_base.default_test_combinations())
def testErrorWithInvalidDataset(self):
"""With a nested dataset op after prefetch, the rewrite should fail."""
dataset = dataset_ops.Dataset.range(10)

View File

@ -32,8 +32,8 @@ from tensorflow.python.data.experimental.ops import scan_ops
from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.util import nest
from tensorflow.python.framework import combinations
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import test_util
from tensorflow.python.lib.io import python_io
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
@ -47,13 +47,11 @@ def _flat_shapes(dataset):
return nest.flatten(dataset_ops.get_legacy_output_shapes(dataset))
@test_util.run_all_in_graph_and_eager_modes
class RebatchDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
drop_remainder_cases = [("WithDropRemainder", True),
("WithoutDropRemainder", False)]
@parameterized.named_parameters(drop_remainder_cases)
@combinations.generate(
combinations.times(test_base.default_test_combinations(),
combinations.combine(drop_remainder=[True, False])))
def testBasic(self, drop_remainder):
dataset = dataset_ops.Dataset.range(1024).batch(
32, drop_remainder=drop_remainder)
@ -64,13 +62,16 @@ class RebatchDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
expected_output = [[k for k in range(i, i + 8)] for i in range(0, 1024, 8)] # pylint: disable=g-complex-comprehension
self.assertDatasetProduces(rebatched_dataset, expected_output)
@combinations.generate(test_base.default_test_combinations())
def testScalarInputError(self):
dataset = dataset_ops.Dataset.range(1024)
distribute._RebatchDataset(dataset.batch(4), num_replicas=4)
with self.assertRaisesRegexp(ValueError, "at least one dimension"):
distribute._RebatchDataset(dataset, num_replicas=4)
@parameterized.named_parameters(drop_remainder_cases)
@combinations.generate(
combinations.times(test_base.default_test_combinations(),
combinations.combine(drop_remainder=[True, False])))
def testBatchNotDivisibleByNumReplicas(self, drop_remainder):
dataset = dataset_ops.Dataset.range(1024).batch(
32, drop_remainder=drop_remainder)
@ -89,6 +90,7 @@ class RebatchDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
i += 4
self.assertDatasetProduces(rebatched_dataset, expected_output)
@combinations.generate(test_base.default_test_combinations())
def testBatchSizeNotDivisibleByNumReplicas2(self):
dataset = dataset_ops.Dataset.range(32).batch(16, drop_remainder=True)
rebatched_dataset = distribute._RebatchDataset(dataset, num_replicas=5)
@ -102,6 +104,7 @@ class RebatchDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
expected_output.extend([[]]) # Last replica gets an empty batch
self.assertDatasetProduces(rebatched_dataset, expected_output)
@combinations.generate(test_base.default_test_combinations())
def testTupleOutput(self):
dataset = dataset_ops.Dataset.range(1024).map(lambda x: (x, x)).batch(32)
rebatched_dataset = distribute._RebatchDataset(dataset, num_replicas=4)
@ -110,6 +113,7 @@ class RebatchDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
for i in range(0, 1024, 8)]
self.assertDatasetProduces(rebatched_dataset, expected_output)
@combinations.generate(test_base.default_test_combinations())
def testNestedDictionaryOutput(self):
dataset = dataset_ops.Dataset.range(1024).map(
lambda x: {"a": x, "b": {"c": x}}).batch(32)
@ -119,7 +123,9 @@ class RebatchDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
for i in range(0, 1024, 8)]
self.assertDatasetProduces(rebatched_dataset, expected_output)
@parameterized.named_parameters(drop_remainder_cases)
@combinations.generate(
combinations.times(test_base.default_test_combinations(),
combinations.combine(drop_remainder=[True, False])))
def testFinalPartialBatch(self, drop_remainder):
dataset = dataset_ops.Dataset.range(1032).batch(
32, drop_remainder=drop_remainder)
@ -136,7 +142,9 @@ class RebatchDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
[[k for k in range(i, i + 2)] for i in range(1024, 1032, 2)])
self.assertDatasetProduces(rebatched_dataset, expected_output)
@parameterized.named_parameters(drop_remainder_cases)
@combinations.generate(
combinations.times(test_base.default_test_combinations(),
combinations.combine(drop_remainder=[True, False])))
def testFinalPartialBatchAfterRebatch(self, drop_remainder):
dataset = dataset_ops.Dataset.range(34).batch(
32, drop_remainder=drop_remainder)
@ -150,6 +158,7 @@ class RebatchDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
expected_output += [[32], [33], [], []]
self.assertDatasetProduces(rebatched_dataset, expected_output)
@combinations.generate(test_base.default_test_combinations())
def testMultipleBatches(self):
dataset = dataset_ops.Dataset.range(128).batch(4).batch(8)
self.assertEqual([[None, None]],
@ -170,6 +179,7 @@ class RebatchDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
for i in range(0, 128, 8)]
self.assertDatasetProduces(rebatched_dataset, expected_output)
@combinations.generate(test_base.default_test_combinations())
def testMapAndBatch(self):
dataset = dataset_ops.Dataset.range(1024).apply(
batching.map_and_batch(math_ops.square, 32))
@ -180,6 +190,7 @@ class RebatchDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
for i in range(0, 1024, 8)]
self.assertDatasetProduces(rebatched_dataset, expected_output)
@combinations.generate(test_base.default_test_combinations())
def testMapAndBatchWithCapturedInput(self):
captured_t = variables.Variable(42)
dataset = dataset_ops.Dataset.range(1024).apply(
@ -193,6 +204,7 @@ class RebatchDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
self.assertDatasetProduces(
rebatched_dataset, expected_output, requires_initialization=True)
@combinations.generate(test_base.default_test_combinations())
def testPaddedBatch(self):
dataset = dataset_ops.Dataset.range(128).batch(
4, drop_remainder=True).padded_batch(
@ -213,6 +225,7 @@ class RebatchDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
for i in range(0, 128, 8)]
self.assertDatasetProduces(rebatched_dataset, expected_output)
@combinations.generate(test_base.default_test_combinations())
def testConcatenate(self):
dataset1 = dataset_ops.Dataset.range(64).batch(8)
dataset2 = dataset_ops.Dataset.range(32).batch(8)
@ -224,6 +237,7 @@ class RebatchDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
[[i, i + 1] for i in range(0, 32, 2)])
self.assertDatasetProduces(rebatched_dataset, expected_output)
@combinations.generate(test_base.default_test_combinations())
def testConcatenateDifferentShapes(self):
dataset1 = dataset_ops.Dataset.range(64).batch(16)
dataset2 = dataset_ops.Dataset.range(32).batch(8)
@ -235,6 +249,7 @@ class RebatchDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
[[i, i + 1] for i in range(0, 32, 2)])
self.assertDatasetProduces(rebatched_dataset, expected_output)
@combinations.generate(test_base.default_test_combinations())
def testZip(self):
dataset1 = dataset_ops.Dataset.range(64).batch(8)
dataset2 = dataset_ops.Dataset.range(32).batch(8)
@ -245,6 +260,7 @@ class RebatchDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
expected_output = [([i, i + 1], [i, i + 1]) for i in range(0, 32, 2)]
self.assertDatasetProduces(rebatched_dataset, expected_output)
@combinations.generate(test_base.default_test_combinations())
def testZipDifferentShapes(self):
dataset1 = dataset_ops.Dataset.range(64).batch(16)
dataset2 = dataset_ops.Dataset.range(32).batch(8)
@ -256,6 +272,7 @@ class RebatchDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
for i in range(0, 32, 2)]
self.assertDatasetProduces(rebatched_dataset, expected_output)
@combinations.generate(test_base.default_test_combinations())
def testFlatMapBatching(self):
dataset = dataset_ops.Dataset.range(2).flat_map(
lambda _: dataset_ops.Dataset.range(32).batch( # pylint: disable=g-long-lambda
@ -274,6 +291,7 @@ class RebatchDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
for i in range(0, 32, 8)] # generates 4 elements
self.assertDatasetProduces(rebatched_dataset, expected_output)
@combinations.generate(test_base.default_test_combinations())
def testInterleaveBatching(self):
dataset = dataset_ops.Dataset.range(2).interleave(
lambda _: dataset_ops.Dataset.range(32).batch( # pylint: disable=g-long-lambda
@ -290,6 +308,7 @@ class RebatchDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
expected_output += expected_output
self.assertDatasetProduces(rebatched_dataset, expected_output)
@combinations.generate(test_base.default_test_combinations())
def testParallelInterleaveBatching(self):
dataset = dataset_ops.Dataset.range(2).interleave(
lambda _: dataset_ops.Dataset.range(32).batch( # pylint: disable=g-long-lambda
@ -307,6 +326,7 @@ class RebatchDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
expected_output += expected_output
self.assertDatasetProduces(rebatched_dataset, expected_output)
@combinations.generate(test_base.default_test_combinations())
def testGroupByWindowStaticBatch(self):
dataset = dataset_ops.Dataset.from_tensor_slices(
[[array_ops.constant(i, dtype=dtypes.int64)] * 3 for i in range(40)])
@ -326,6 +346,7 @@ class RebatchDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
for k in range(2)]
self.assertDatasetProduces(rebatched_dataset, expected_output)
@combinations.generate(test_base.default_test_combinations())
def testGroupByWindowDynamicBatch(self):
# {0, 1, 0, 1, ...}
dataset = dataset_ops.Dataset.range(40).map(lambda x: x % 2)
@ -350,6 +371,7 @@ class RebatchDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
expected_output = [[value] * batch_size for batch_size, value in pairs]
self.assertDatasetProduces(dataset, expected_output)
@combinations.generate(test_base.default_test_combinations())
def testGroupByWindowDynamicBatchWithPartialBatch(self):
# {0, 1, 0, 1, ...}
dataset = dataset_ops.Dataset.range(40).map(lambda x: x % 2)
@ -371,6 +393,7 @@ class RebatchDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
expected_output = [[value] * batch_size for batch_size, value in pairs]
self.assertDatasetProduces(dataset, expected_output)
@combinations.generate(test_base.default_test_combinations())
def testGroupByWindowDynamicBatchWithPartialBatchWithDropRemainder(self):
# This test exercises nested batch functionality, dynamic batch size
# and drop_remainder=True together.
@ -395,6 +418,7 @@ class RebatchDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
expected_output = [[value] * batch_size for batch_size, value in pairs]
self.assertDatasetProduces(dataset, expected_output)
@combinations.generate(test_base.default_test_combinations())
def testScanAfterBatch(self):
dataset = dataset_ops.Dataset.range(40).batch(10).apply(
scan_ops.scan(np.int64(2), lambda state, value: (state, value * state)))
@ -405,6 +429,7 @@ class RebatchDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
expected_output = [[i * 2 for i in range(j*5, (j+1)*5)] for j in range(8)] # pylint: disable=g-complex-comprehension
self.assertDatasetProduces(dataset, expected_output)
@combinations.generate(test_base.default_test_combinations())
def testMakeBatchedFeaturesDataset(self):
# Set up
fn = os.path.join(self.get_temp_dir(), "tf_record.txt")
@ -438,6 +463,7 @@ class RebatchDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
} for i in range(0, 1024, 8)] # pylint: disable=g-complex-comprehension
self.assertDatasetProduces(rebatched_dataset, expected_output)
@combinations.generate(test_base.default_test_combinations())
def testRaggedTensorDataset(self):
# Set up a dataset that produces ragged tensors with a static batch size.
row_lengths = np.random.randint(8, size=128)

View File

@ -24,9 +24,9 @@ import numpy as np
from tensorflow.python.data.experimental.ops import resampling
from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import combinations
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import test_util
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import random_ops
from tensorflow.python.ops import string_ops
@ -34,12 +34,11 @@ from tensorflow.python.platform import test
from tensorflow.python.util import compat
@test_util.run_all_in_graph_and_eager_modes
class RejectionResampleTest(test_base.DatasetTestBase, parameterized.TestCase):
@parameterized.named_parameters(
("InitialDistributionKnown", True),
("InitialDistributionUnknown", False))
@combinations.generate(
combinations.times(test_base.default_test_combinations(),
combinations.combine(initial_known=[True, False])))
def testDistribution(self, initial_known):
classes = np.random.randint(5, size=(20000,)) # Uniformly sampled
target_dist = [0.9, 0.05, 0.05, 0.0, 0.0]
@ -72,9 +71,9 @@ class RejectionResampleTest(test_base.DatasetTestBase, parameterized.TestCase):
returned_dist = class_counts / total_returned
self.assertAllClose(target_dist, returned_dist, atol=1e-2)
@parameterized.named_parameters(
("OnlyInitial", True),
("NotInitial", False))
@combinations.generate(
combinations.times(test_base.default_test_combinations(),
combinations.combine(only_initial_dist=[True, False])))
def testEdgeCasesSampleFromInitialDataset(self, only_initial_dist):
init_dist = [0.5, 0.5]
target_dist = [0.5, 0.5] if only_initial_dist else [0.0, 1.0]
@ -99,6 +98,7 @@ class RejectionResampleTest(test_base.DatasetTestBase, parameterized.TestCase):
while True:
returned.append(self.evaluate(get_next()))
@combinations.generate(test_base.default_test_combinations())
def testRandomClasses(self):
init_dist = [0.25, 0.25, 0.25, 0.25]
target_dist = [0.0, 0.0, 0.0, 1.0]

View File

@ -17,18 +17,18 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from absl.testing import parameterized
import numpy as np
from tensorflow.python.data.experimental.ops import shuffle_ops
from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import combinations
from tensorflow.python.framework import errors
from tensorflow.python.framework import test_util
from tensorflow.python.platform import test
@test_util.run_all_in_graph_and_eager_modes
class ShuffleAndRepeatTest(test_base.DatasetTestBase):
class ShuffleAndRepeatTest(test_base.DatasetTestBase, parameterized.TestCase):
def _build_ds(self, seed, count=5, num_elements=20):
return dataset_ops.Dataset.range(num_elements).apply(
@ -44,6 +44,7 @@ class ShuffleAndRepeatTest(test_base.DatasetTestBase):
self.evaluate(get_next())
return outputs
@combinations.generate(test_base.default_test_combinations())
def testCorrectOutput(self):
output = self._gen_outputs(lambda: self._build_ds(10), 100)
self.assertSequenceEqual(
@ -52,6 +53,7 @@ class ShuffleAndRepeatTest(test_base.DatasetTestBase):
for i in range(5):
self.assertSequenceEqual(sorted(output[i * 20:(i + 1) * 20]), range(20))
@combinations.generate(test_base.default_test_combinations())
def testReshuffling(self):
# Check that the output orders of different epochs are indeed different.
output = self._gen_outputs(lambda: self._build_ds(10), 100)
@ -60,17 +62,20 @@ class ShuffleAndRepeatTest(test_base.DatasetTestBase):
epoch2 = output[(i + 1) * 20:(i + 2) * 20]
self.assertNotEqual(epoch1, epoch2)
@combinations.generate(test_base.default_test_combinations())
def testSameOrderForSameSeeds(self):
output1 = self._gen_outputs(lambda: self._build_ds(10), 100)
output2 = self._gen_outputs(lambda: self._build_ds(10), 100)
self.assertEqual(output1, output2)
@combinations.generate(test_base.default_test_combinations())
def testDifferentOrderForDifferentSeeds(self):
output1 = self._gen_outputs(lambda: self._build_ds(10), 100)
output2 = self._gen_outputs(lambda: self._build_ds(20), 100)
self.assertNotEqual(output1, output2)
self.assertEqual(sorted(output1), sorted(output2))
@combinations.generate(test_base.default_test_combinations())
def testCountNone(self):
output1 = self._gen_outputs(
lambda: self._build_ds(10, count=None), 100, verify_exhausted=False)
@ -79,6 +84,7 @@ class ShuffleAndRepeatTest(test_base.DatasetTestBase):
self.assertNotEqual(output1, output2)
self.assertEqual(sorted(output1), sorted(output2))
@combinations.generate(test_base.default_test_combinations())
def testCountMinusOne(self):
output1 = self._gen_outputs(
lambda: self._build_ds(10, count=-1), 100, verify_exhausted=False)
@ -87,6 +93,7 @@ class ShuffleAndRepeatTest(test_base.DatasetTestBase):
self.assertNotEqual(output1, output2)
self.assertEqual(sorted(output1), sorted(output2))
@combinations.generate(test_base.default_test_combinations())
def testInfiniteOutputs(self):
# Asserting the iterator is exhausted after producing 100 items should fail.
with self.assertRaises(AssertionError):
@ -94,6 +101,7 @@ class ShuffleAndRepeatTest(test_base.DatasetTestBase):
with self.assertRaises(AssertionError):
self._gen_outputs(lambda: self._build_ds(10, count=-1), 100)
@combinations.generate(test_base.default_test_combinations())
def testInfiniteEmpty(self):
with self.assertRaises(errors.OutOfRangeError):
self._gen_outputs(lambda: self._build_ds(10, count=None, num_elements=0),
@ -102,12 +110,14 @@ class ShuffleAndRepeatTest(test_base.DatasetTestBase):
self._gen_outputs(lambda: self._build_ds(10, count=-1, num_elements=0),
100)
@combinations.generate(test_base.default_test_combinations())
def testLargeBufferSize(self):
ds = dataset_ops.Dataset.range(20).apply(
shuffle_ops.shuffle_and_repeat(buffer_size=21))
get_next = self.getNext(ds)
self.evaluate(get_next())
@combinations.generate(test_base.default_test_combinations())
def testVeryLargeBufferSize(self):
num_epochs = 1000 * 1000
# Each element being shuffled and repeated has shape (100,). This will OOM

View File

@ -18,18 +18,22 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from absl.testing import parameterized
from tensorflow.python.data.experimental.kernel_tests import sql_dataset_test_base
from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.framework import combinations
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.platform import test
@test_util.run_all_in_graph_and_eager_modes
class SqlDatasetTest(sql_dataset_test_base.SqlDatasetTestBase):
class SqlDatasetTest(sql_dataset_test_base.SqlDatasetTestBase,
parameterized.TestCase):
# Test that SqlDataset can read from a database table.
@combinations.generate(test_base.default_test_combinations())
def testReadResultSet(self):
for _ in range(2): # Run twice to verify statelessness of db operations.
dataset = self._createSqlDataset(
@ -44,6 +48,7 @@ class SqlDatasetTest(sql_dataset_test_base.SqlDatasetTestBase):
num_test_iterations=2)
# Test that SqlDataset works on a join query.
@combinations.generate(test_base.default_test_combinations())
def testReadResultSetJoinQuery(self):
get_next = self.getNext(
self._createSqlDataset(
@ -60,6 +65,7 @@ class SqlDatasetTest(sql_dataset_test_base.SqlDatasetTestBase):
# Test that SqlDataset can read a database entry with a null-terminator
# in the middle of the text and place the entry in a `string` tensor.
@combinations.generate(test_base.default_test_combinations())
def testReadResultSetNullTerminator(self):
get_next = self.getNext(
self._createSqlDataset(
@ -76,6 +82,7 @@ class SqlDatasetTest(sql_dataset_test_base.SqlDatasetTestBase):
# Test that SqlDataset works when used on two different queries.
# Because the output types of the dataset must be determined at graph-creation
# time, the two queries must have the same number and types of columns.
@combinations.generate(test_base.default_test_combinations())
def testReadResultSetReuseSqlDataset(self):
get_next = self.getNext(
self._createSqlDataset(
@ -100,6 +107,7 @@ class SqlDatasetTest(sql_dataset_test_base.SqlDatasetTestBase):
# Test that an `OutOfRangeError` is raised on the first call to
# `get_next_str_only` if result set is empty.
@combinations.generate(test_base.default_test_combinations())
def testReadEmptyResultSet(self):
get_next = self.getNext(
self._createSqlDataset(
@ -110,6 +118,7 @@ class SqlDatasetTest(sql_dataset_test_base.SqlDatasetTestBase):
self.evaluate(get_next())
# Test that an error is raised when `driver_name` is invalid.
@combinations.generate(test_base.default_test_combinations())
def testReadResultSetWithInvalidDriverName(self):
with self.assertRaises(errors.InvalidArgumentError):
dataset = self._createSqlDataset(
@ -120,6 +129,7 @@ class SqlDatasetTest(sql_dataset_test_base.SqlDatasetTestBase):
self.assertDatasetProduces(dataset, expected_output=[])
# Test that an error is raised when a column name in `query` is nonexistent
@combinations.generate(test_base.default_test_combinations())
def testReadResultSetWithInvalidColumnName(self):
get_next = self.getNext(
self._createSqlDataset(
@ -130,6 +140,7 @@ class SqlDatasetTest(sql_dataset_test_base.SqlDatasetTestBase):
self.evaluate(get_next())
# Test that an error is raised when there is a syntax error in `query`.
@combinations.generate(test_base.default_test_combinations())
def testReadResultSetOfQueryWithSyntaxError(self):
get_next = self.getNext(
self._createSqlDataset(
@ -141,6 +152,7 @@ class SqlDatasetTest(sql_dataset_test_base.SqlDatasetTestBase):
# Test that an error is raised when the number of columns in `query`
# does not match the length of `, output_types`.
@combinations.generate(test_base.default_test_combinations())
def testReadResultSetWithMismatchBetweenColumnsAndOutputTypes(self):
get_next = self.getNext(
self._createSqlDataset(
@ -154,6 +166,7 @@ class SqlDatasetTest(sql_dataset_test_base.SqlDatasetTestBase):
# than a select query. In particular, the error refers to the number of
# output types passed to the op not matching the number of columns in the
# result set of the query (namely, 0 for an insert statement.)
@combinations.generate(test_base.default_test_combinations())
def testReadResultSetOfInsertQuery(self):
get_next = self.getNext(
self._createSqlDataset(
@ -165,6 +178,7 @@ class SqlDatasetTest(sql_dataset_test_base.SqlDatasetTestBase):
# Test that `SqlDataset` can read an integer from a SQLite database table and
# place it in an `int8` tensor.
@combinations.generate(test_base.default_test_combinations())
def testReadResultSetInt8(self):
get_next = self.getNext(
self._createSqlDataset(
@ -178,6 +192,7 @@ class SqlDatasetTest(sql_dataset_test_base.SqlDatasetTestBase):
# Test that `SqlDataset` can read a negative or 0-valued integer from a
# SQLite database table and place it in an `int8` tensor.
@combinations.generate(test_base.default_test_combinations())
def testReadResultSetInt8NegativeAndZero(self):
get_next = self.getNext(
self._createSqlDataset(
@ -191,6 +206,7 @@ class SqlDatasetTest(sql_dataset_test_base.SqlDatasetTestBase):
# Test that `SqlDataset` can read a large (positive or negative) integer from
# a SQLite database table and place it in an `int8` tensor.
@combinations.generate(test_base.default_test_combinations())
def testReadResultSetInt8MaxValues(self):
get_next = self.getNext(
self._createSqlDataset(
@ -205,6 +221,7 @@ class SqlDatasetTest(sql_dataset_test_base.SqlDatasetTestBase):
# Test that `SqlDataset` can read an integer from a SQLite database table and
# place it in an `int16` tensor.
@combinations.generate(test_base.default_test_combinations())
def testReadResultSetInt16(self):
get_next = self.getNext(
self._createSqlDataset(
@ -218,6 +235,7 @@ class SqlDatasetTest(sql_dataset_test_base.SqlDatasetTestBase):
# Test that `SqlDataset` can read a negative or 0-valued integer from a
# SQLite database table and place it in an `int16` tensor.
@combinations.generate(test_base.default_test_combinations())
def testReadResultSetInt16NegativeAndZero(self):
get_next = self.getNext(
self._createSqlDataset(
@ -231,6 +249,7 @@ class SqlDatasetTest(sql_dataset_test_base.SqlDatasetTestBase):
# Test that `SqlDataset` can read a large (positive or negative) integer from
# a SQLite database table and place it in an `int16` tensor.
@combinations.generate(test_base.default_test_combinations())
def testReadResultSetInt16MaxValues(self):
get_next = self.getNext(
self._createSqlDataset(
@ -246,6 +265,7 @@ class SqlDatasetTest(sql_dataset_test_base.SqlDatasetTestBase):
# Test that `SqlDataset` can read an integer from a SQLite database table and
# place it in an `int32` tensor.
@combinations.generate(test_base.default_test_combinations())
def testReadResultSetInt32(self):
get_next = self.getNext(
self._createSqlDataset(
@ -257,6 +277,7 @@ class SqlDatasetTest(sql_dataset_test_base.SqlDatasetTestBase):
# Test that `SqlDataset` can read a negative or 0-valued integer from a
# SQLite database table and place it in an `int32` tensor.
@combinations.generate(test_base.default_test_combinations())
def testReadResultSetInt32NegativeAndZero(self):
get_next = self.getNext(
self._createSqlDataset(
@ -270,6 +291,7 @@ class SqlDatasetTest(sql_dataset_test_base.SqlDatasetTestBase):
# Test that `SqlDataset` can read a large (positive or negative) integer from
# a SQLite database table and place it in an `int32` tensor.
@combinations.generate(test_base.default_test_combinations())
def testReadResultSetInt32MaxValues(self):
get_next = self.getNext(
self._createSqlDataset(
@ -285,6 +307,7 @@ class SqlDatasetTest(sql_dataset_test_base.SqlDatasetTestBase):
# Test that `SqlDataset` can read a numeric `varchar` from a SQLite database
# table and place it in an `int32` tensor.
@combinations.generate(test_base.default_test_combinations())
def testReadResultSetInt32VarCharColumnAsInt(self):
get_next = self.getNext(
self._createSqlDataset(
@ -298,6 +321,7 @@ class SqlDatasetTest(sql_dataset_test_base.SqlDatasetTestBase):
# Test that `SqlDataset` can read an integer from a SQLite database table
# and place it in an `int64` tensor.
@combinations.generate(test_base.default_test_combinations())
def testReadResultSetInt64(self):
get_next = self.getNext(
self._createSqlDataset(
@ -311,6 +335,7 @@ class SqlDatasetTest(sql_dataset_test_base.SqlDatasetTestBase):
# Test that `SqlDataset` can read a negative or 0-valued integer from a
# SQLite database table and place it in an `int64` tensor.
@combinations.generate(test_base.default_test_combinations())
def testReadResultSetInt64NegativeAndZero(self):
get_next = self.getNext(
self._createSqlDataset(
@ -324,6 +349,7 @@ class SqlDatasetTest(sql_dataset_test_base.SqlDatasetTestBase):
# Test that `SqlDataset` can read a large (positive or negative) integer from
# a SQLite database table and place it in an `int64` tensor.
@combinations.generate(test_base.default_test_combinations())
def testReadResultSetInt64MaxValues(self):
get_next = self.getNext(
self._createSqlDataset(
@ -339,6 +365,7 @@ class SqlDatasetTest(sql_dataset_test_base.SqlDatasetTestBase):
# Test that `SqlDataset` can read an integer from a SQLite database table and
# place it in a `uint8` tensor.
@combinations.generate(test_base.default_test_combinations())
def testReadResultSetUInt8(self):
get_next = self.getNext(
self._createSqlDataset(
@ -352,6 +379,7 @@ class SqlDatasetTest(sql_dataset_test_base.SqlDatasetTestBase):
# Test that `SqlDataset` can read the minimum and maximum uint8 values from a
# SQLite database table and place them in `uint8` tensors.
@combinations.generate(test_base.default_test_combinations())
def testReadResultSetUInt8MinAndMaxValues(self):
get_next = self.getNext(
self._createSqlDataset(
@ -367,6 +395,7 @@ class SqlDatasetTest(sql_dataset_test_base.SqlDatasetTestBase):
# Test that `SqlDataset` can read an integer from a SQLite database table
# and place it in a `uint16` tensor.
@combinations.generate(test_base.default_test_combinations())
def testReadResultSetUInt16(self):
get_next = self.getNext(
self._createSqlDataset(
@ -380,6 +409,7 @@ class SqlDatasetTest(sql_dataset_test_base.SqlDatasetTestBase):
# Test that `SqlDataset` can read the minimum and maximum uint16 values from a
# SQLite database table and place them in `uint16` tensors.
@combinations.generate(test_base.default_test_combinations())
def testReadResultSetUInt16MinAndMaxValues(self):
get_next = self.getNext(
self._createSqlDataset(
@ -396,6 +426,7 @@ class SqlDatasetTest(sql_dataset_test_base.SqlDatasetTestBase):
# Test that `SqlDataset` can read a 0-valued and 1-valued integer from a
# SQLite database table and place them as `True` and `False` respectively
# in `bool` tensors.
@combinations.generate(test_base.default_test_combinations())
def testReadResultSetBool(self):
get_next = self.getNext(
self._createSqlDataset(
@ -409,6 +440,7 @@ class SqlDatasetTest(sql_dataset_test_base.SqlDatasetTestBase):
# Test that `SqlDataset` can read an integer that is not 0-valued or 1-valued
# from a SQLite database table and place it as `True` in a `bool` tensor.
@combinations.generate(test_base.default_test_combinations())
def testReadResultSetBoolNotZeroOrOne(self):
get_next = self.getNext(
self._createSqlDataset(
@ -422,6 +454,7 @@ class SqlDatasetTest(sql_dataset_test_base.SqlDatasetTestBase):
# Test that `SqlDataset` can read a float from a SQLite database table
# and place it in a `float64` tensor.
@combinations.generate(test_base.default_test_combinations())
def testReadResultSetFloat64(self):
get_next = self.getNext(
self._createSqlDataset(
@ -437,6 +470,7 @@ class SqlDatasetTest(sql_dataset_test_base.SqlDatasetTestBase):
# Test that `SqlDataset` can read a float from a SQLite database table beyond
# the precision of 64-bit IEEE, without throwing an error. Test that
# `SqlDataset` identifies such a value as equal to itself.
@combinations.generate(test_base.default_test_combinations())
def testReadResultSetFloat64OverlyPrecise(self):
get_next = self.getNext(
self._createSqlDataset(
@ -458,6 +492,7 @@ class SqlDatasetTest(sql_dataset_test_base.SqlDatasetTestBase):
# representing the largest integer representable as a 64-bit IEEE float
# such that the previous integer is also representable as a 64-bit IEEE float.
# Test that `SqlDataset` can distinguish these two numbers.
@combinations.generate(test_base.default_test_combinations())
def testReadResultSetFloat64LargestConsecutiveWholeNumbersNotEqual(self):
get_next = self.getNext(
self._createSqlDataset(
@ -472,6 +507,7 @@ class SqlDatasetTest(sql_dataset_test_base.SqlDatasetTestBase):
self.evaluate(get_next())
# Test that SqlDataset can stop correctly when combined with batch
@combinations.generate(test_base.default_test_combinations())
def testReadResultSetWithBatchStop(self):
dataset = self._createSqlDataset(
query="SELECT * FROM data", output_types=(dtypes.int32))

View File

@ -17,6 +17,7 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from absl.testing import parameterized
import numpy as np
from tensorflow.python.data.experimental.kernel_tests import reader_dataset_ops_test_base
@ -24,7 +25,9 @@ from tensorflow.python.data.experimental.kernel_tests import stats_dataset_test_
from tensorflow.python.data.experimental.ops import batching
from tensorflow.python.data.experimental.ops import stats_aggregator
from tensorflow.python.data.experimental.ops import stats_ops
from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import combinations
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
@ -32,8 +35,11 @@ from tensorflow.python.ops import math_ops
from tensorflow.python.platform import test
class StatsDatasetTest(stats_dataset_test_base.StatsDatasetTestBase):
# TODO(jsimsa): Figure out why are graph tests failing.
class StatsDatasetTest(stats_dataset_test_base.StatsDatasetTestBase,
parameterized.TestCase):
@combinations.generate(test_base.eager_only_combinations())
def testBytesProduced(self):
aggregator = stats_aggregator.StatsAggregator()
dataset = dataset_ops.Dataset.range(100).map(
@ -57,6 +63,7 @@ class StatsDatasetTest(stats_dataset_test_base.StatsDatasetTestBase):
self.assertStatisticsHasCount(handle, "bytes_produced", 100.0, 101)
self.assertStatisticsHasSum(handle, "bytes_produced", expected_sum, 101)
@combinations.generate(test_base.eager_only_combinations())
def testLatencyStats(self):
aggregator = stats_aggregator.StatsAggregator()
dataset = dataset_ops.Dataset.range(100).apply(
@ -76,6 +83,7 @@ class StatsDatasetTest(stats_dataset_test_base.StatsDatasetTestBase):
handle = self.getHandle(aggregator)
self.assertStatisticsHasCount(handle, "record_latency", 100.0, 101)
@combinations.generate(test_base.eager_only_combinations())
def testPrefetchBufferUtilization(self):
aggregator = stats_aggregator.StatsAggregator()
dataset = dataset_ops.Dataset.range(100).map(
@ -117,6 +125,7 @@ class StatsDatasetTest(stats_dataset_test_base.StatsDatasetTestBase):
301,
offset=2)
@combinations.generate(test_base.eager_only_combinations())
def testPrefetchBufferScalars(self):
aggregator = stats_aggregator.StatsAggregator()
dataset = dataset_ops.Dataset.range(10).map(
@ -140,6 +149,7 @@ class StatsDatasetTest(stats_dataset_test_base.StatsDatasetTestBase):
with self.assertRaises(errors.OutOfRangeError):
self.evaluate(next_element())
@combinations.generate(test_base.eager_only_combinations())
def testFilteredElementsStats(self):
aggregator = stats_aggregator.StatsAggregator()
dataset = dataset_ops.Dataset.range(101).filter(
@ -167,6 +177,7 @@ class StatsDatasetTest(stats_dataset_test_base.StatsDatasetTestBase):
handle, self.regexForNodeName("FilterDataset", "filtered_elements"),
34.0)
@combinations.generate(test_base.eager_only_combinations())
def testReinitialize(self):
aggregator = stats_aggregator.StatsAggregator()
dataset = dataset_ops.Dataset.range(100).apply(
@ -187,6 +198,7 @@ class StatsDatasetTest(stats_dataset_test_base.StatsDatasetTestBase):
self.assertStatisticsHasCount(handle, "record_latency", (j + 1) * 100.0,
(j * 100) + 101)
@combinations.generate(test_base.eager_only_combinations())
def testNoAggregatorRegistered(self):
dataset = dataset_ops.Dataset.range(100).apply(
stats_ops.latency_stats("record_latency"))
@ -198,6 +210,7 @@ class StatsDatasetTest(stats_dataset_test_base.StatsDatasetTestBase):
with self.assertRaises(errors.OutOfRangeError):
self.evaluate(next_element())
@combinations.generate(test_base.eager_only_combinations())
def testMultipleTags(self):
aggregator = stats_aggregator.StatsAggregator()
dataset = dataset_ops.Dataset.range(100).apply(
@ -221,6 +234,7 @@ class StatsDatasetTest(stats_dataset_test_base.StatsDatasetTestBase):
handle, "record_latency", 100.0, 201, offset=1)
self.assertStatisticsHasCount(handle, "record_latency_2", 100.0, 201)
@combinations.generate(test_base.eager_only_combinations())
def testRepeatedTags(self):
aggregator = stats_aggregator.StatsAggregator()
dataset = dataset_ops.Dataset.range(100).apply(
@ -239,6 +253,7 @@ class StatsDatasetTest(stats_dataset_test_base.StatsDatasetTestBase):
handle = self.getHandle(aggregator)
self.assertStatisticsHasCount(handle, "record_latency", 200.0, 201)
@combinations.generate(test_base.eager_only_combinations())
def testMultipleIteratorsSameAggregator(self):
aggregator = stats_aggregator.StatsAggregator()
dataset = dataset_ops.Dataset.range(100).apply(
@ -259,6 +274,7 @@ class StatsDatasetTest(stats_dataset_test_base.StatsDatasetTestBase):
handle = self.getHandle(aggregator)
self.assertStatisticsHasCount(handle, "record_latency", 200.0, 201)
@combinations.generate(test_base.eager_only_combinations())
def testMultipleDatasetWithPrefixes(self):
aggregator = stats_aggregator.StatsAggregator()
dataset = dataset_ops.Dataset.range(100).apply(
@ -289,6 +305,7 @@ class StatsDatasetTest(stats_dataset_test_base.StatsDatasetTestBase):
self.assertStatisticsHasCount(handle, "dataset2::record_latency", 100.0,
201)
@combinations.generate(test_base.eager_only_combinations())
def testMultiplePrefetchStats(self):
aggregator = stats_aggregator.StatsAggregator()
@ -314,8 +331,10 @@ class StatsDatasetTest(stats_dataset_test_base.StatsDatasetTestBase):
self.evaluate(next_element())
class ThreadUtilizationStatsTest(stats_dataset_test_base.StatsDatasetTestBase):
class ThreadUtilizationStatsTest(stats_dataset_test_base.StatsDatasetTestBase,
parameterized.TestCase):
@combinations.generate(test_base.eager_only_combinations())
def testMapBufferUtilization(self):
def dataset_fn():
@ -326,6 +345,7 @@ class ThreadUtilizationStatsTest(stats_dataset_test_base.StatsDatasetTestBase):
self.parallelCallsStats(
dataset_fn, {"ParallelMapDataset"}, 10, function_processing_time=True)
@combinations.generate(test_base.eager_only_combinations())
def testMapAutoTuneBufferUtilization(self):
def dataset_fn():
@ -336,6 +356,7 @@ class ThreadUtilizationStatsTest(stats_dataset_test_base.StatsDatasetTestBase):
self.parallelCallsStats(
dataset_fn, {"ParallelMapDataset"}, 10, function_processing_time=True)
@combinations.generate(test_base.eager_only_combinations())
def testInterleaveAutoTuneBufferUtilization(self):
def dataset_fn():
@ -351,6 +372,7 @@ class ThreadUtilizationStatsTest(stats_dataset_test_base.StatsDatasetTestBase):
self.parallelCallsStats(dataset_fn, {"ParallelInterleaveDatasetV2"}, 10)
@combinations.generate(test_base.eager_only_combinations())
def testMapAndBatchAutoTuneBufferUtilization(self):
def dataset_fn():
@ -370,8 +392,10 @@ class ThreadUtilizationStatsTest(stats_dataset_test_base.StatsDatasetTestBase):
class FeatureStatsDatasetTest(
stats_dataset_test_base.StatsDatasetTestBase,
reader_dataset_ops_test_base.MakeBatchedFeaturesDatasetTestBase):
reader_dataset_ops_test_base.MakeBatchedFeaturesDatasetTestBase,
parameterized.TestCase):
@combinations.generate(test_base.eager_only_combinations())
def testFeaturesStats(self):
num_epochs = 5
total_records = num_epochs * self._num_records

View File

@ -23,18 +23,21 @@ import numpy as np
from tensorflow.python.data.experimental.ops import take_while_ops
from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import combinations
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import errors
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.platform import test
@test_util.run_all_in_graph_and_eager_modes
class TakeWhileTest(test_base.DatasetTestBase, parameterized.TestCase):
@parameterized.parameters((14, 2), (15, 2), (100, 3))
@combinations.generate(
combinations.times(
test_base.default_test_combinations(),
combinations.combine(num_elements=[14, 15], window_size=[2]) +
combinations.combine(num_elements=[100], window_size=[3])))
def testTakeWhileDataset(self, num_elements, window_size):
def _predicate_func(elem):
@ -49,8 +52,19 @@ class TakeWhileTest(test_base.DatasetTestBase, parameterized.TestCase):
expected_num_elements = int(num_elements / window_size) * window_size
self.assertDatasetProduces(dataset, np.arange(expected_num_elements))
@parameterized.parameters((10, 2, False), (16, 7, False), (100, 99, False),
(100, 101, True), (0, 1, True))
@combinations.generate(
combinations.times(
test_base.default_test_combinations(),
combinations.combine(
num_elements=[10], upper_bound=[2], out_of_bounds=[False]) +
combinations.combine(
num_elements=[16], upper_bound=[7], out_of_bounds=[False]) +
combinations.combine(
num_elements=[100], upper_bound=[99], out_of_bounds=[False]) +
combinations.combine(
num_elements=[100], upper_bound=[101], out_of_bounds=[True]) +
combinations.combine(
num_elements=[0], upper_bound=[1], out_of_bounds=[True])))
def testTakeWhileDatasetRange(self, num_elements, upper_bound, out_of_bounds):
dataset = dataset_ops.Dataset.range(num_elements).apply(
take_while_ops.take_while(lambda x: x < upper_bound))
@ -62,6 +76,7 @@ class TakeWhileTest(test_base.DatasetTestBase, parameterized.TestCase):
else:
self.assertDatasetProduces(dataset, np.arange(upper_bound))
@combinations.generate(test_base.default_test_combinations())
def testTakeWhileDatasetString(self):
def not_equal(string):
@ -79,7 +94,13 @@ class TakeWhileTest(test_base.DatasetTestBase, parameterized.TestCase):
with self.assertRaises(errors.OutOfRangeError):
self.assertEqual(b"test", self.evaluate(next_element()))
@parameterized.parameters((5, 3), (10, 0), (100, 5), (8, 7))
@combinations.generate(
combinations.times(
test_base.default_test_combinations(),
combinations.combine(size=[5], index=[3]) +
combinations.combine(size=[10], index=[0]) +
combinations.combine(size=[100], index=[5]) +
combinations.combine(size=[8], index=[7])))
def testTakewhileDatasetShortCircuit(self, size, index):
def _predicate_func(data_elem):
@ -98,6 +119,7 @@ class TakeWhileTest(test_base.DatasetTestBase, parameterized.TestCase):
with self.assertRaises(errors.OutOfRangeError):
self.evaluate(next_element())
@combinations.generate(test_base.default_test_combinations())
def testTakeWhileDatasetWithRepeat(self):
dataset = dataset_ops.Dataset.range(10).apply(
take_while_ops.take_while(lambda x: x < 2)).repeat(5)

View File

@ -19,14 +19,16 @@ from __future__ import print_function
import os
from absl.testing import parameterized
from tensorflow.python.data.experimental.ops import grouping
from tensorflow.python.data.experimental.ops import writers
from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.ops import readers
from tensorflow.python.eager import function
from tensorflow.python.framework import combinations
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import test_util
from tensorflow.python.lib.io import python_io
from tensorflow.python.lib.io import tf_record
from tensorflow.python.ops import string_ops
@ -34,8 +36,7 @@ from tensorflow.python.platform import test
from tensorflow.python.util import compat
@test_util.run_all_in_graph_and_eager_modes
class TFRecordWriterTest(test_base.DatasetTestBase):
class TFRecordWriterTest(test_base.DatasetTestBase, parameterized.TestCase):
def setUp(self):
super(TFRecordWriterTest, self).setUp()
@ -63,11 +64,13 @@ class TFRecordWriterTest(test_base.DatasetTestBase):
def _outputFilename(self):
return os.path.join(self.get_temp_dir(), "tf_record.out.txt")
@combinations.generate(test_base.default_test_combinations())
def testWrite(self):
self.evaluate(self.writer_fn(self._createFile()))
for i, r in enumerate(tf_record.tf_record_iterator(self._outputFilename())):
self.assertAllEqual(self._record(i), r)
@combinations.generate(test_base.default_test_combinations())
def testWriteZLIB(self):
options = tf_record.TFRecordOptions(tf_record.TFRecordCompressionType.ZLIB)
self.evaluate(
@ -76,6 +79,7 @@ class TFRecordWriterTest(test_base.DatasetTestBase):
tf_record.tf_record_iterator(self._outputFilename(), options=options)):
self.assertAllEqual(self._record(i), r)
@combinations.generate(test_base.default_test_combinations())
def testWriteGZIP(self):
options = tf_record.TFRecordOptions(tf_record.TFRecordCompressionType.GZIP)
self.evaluate(
@ -84,20 +88,24 @@ class TFRecordWriterTest(test_base.DatasetTestBase):
tf_record.tf_record_iterator(self._outputFilename(), options=options)):
self.assertAllEqual(self._record(i), r)
@combinations.generate(test_base.default_test_combinations())
def testFailDataset(self):
with self.assertRaises(TypeError):
writers.TFRecordWriter(self._outputFilename(), "").write("whoops")
@combinations.generate(test_base.default_test_combinations())
def testFailDType(self):
input_dataset = dataset_ops.Dataset.from_tensors(10)
with self.assertRaises(TypeError):
writers.TFRecordWriter(self._outputFilename(), "").write(input_dataset)
@combinations.generate(test_base.default_test_combinations())
def testFailShape(self):
input_dataset = dataset_ops.Dataset.from_tensors([["hello"], ["world"]])
with self.assertRaises(TypeError):
writers.TFRecordWriter(self._outputFilename(), "").write(input_dataset)
@combinations.generate(test_base.default_test_combinations())
def testSideEffect(self):
def writer_fn():
input_dataset = readers.TFRecordDataset(self._createFile())
@ -112,6 +120,7 @@ class TFRecordWriterTest(test_base.DatasetTestBase):
for i, r in enumerate(tf_record.tf_record_iterator(self._outputFilename())):
self.assertAllEqual(self._record(i), r)
@combinations.generate(test_base.default_test_combinations())
def testShard(self):
filename = self._createFile()
dataset = readers.TFRecordDataset([filename])

View File

@ -17,17 +17,18 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from absl.testing import parameterized
from tensorflow.python.data.experimental.ops import unique
from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import combinations
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import test_util
from tensorflow.python.platform import test
from tensorflow.python.util import compat
@test_util.run_all_in_graph_and_eager_modes
class UniqueTest(test_base.DatasetTestBase):
class UniqueTest(test_base.DatasetTestBase, parameterized.TestCase):
def _testSimpleHelper(self, dtype, test_cases):
"""Test the `unique()` transformation on a list of test cases.
@ -52,7 +53,7 @@ class UniqueTest(test_base.DatasetTestBase):
for element in expected
])
@test_util.run_deprecated_v1
@combinations.generate(test_base.graph_only_combinations())
def testSimpleInt(self):
for dtype in [dtypes.int32, dtypes.int64]:
self._testSimpleHelper(dtype, [
@ -65,7 +66,7 @@ class UniqueTest(test_base.DatasetTestBase):
([[1, 1], [1, 1], [2, 2], [3, 3], [1, 1]], [[1, 1], [2, 2], [3, 3]]),
])
@test_util.run_deprecated_v1
@combinations.generate(test_base.graph_only_combinations())
def testSimpleString(self):
self._testSimpleHelper(dtypes.string, [
([], []),

View File

@ -17,16 +17,18 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from absl.testing import parameterized
from tensorflow.python.data.experimental.ops import cardinality
from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import test_util
from tensorflow.python.framework import combinations
from tensorflow.python.platform import test
@test_util.run_all_in_graph_and_eager_modes
class VariantTest(test_base.DatasetTestBase):
class VariantTest(test_base.DatasetTestBase, parameterized.TestCase):
@combinations.generate(test_base.default_test_combinations())
def testRoundtripRange(self):
dataset = dataset_ops.Dataset.range(10)
variant = dataset_ops.to_variant(dataset)
@ -35,6 +37,7 @@ class VariantTest(test_base.DatasetTestBase):
self.assertDatasetProduces(dataset, range(10))
self.assertEqual(self.evaluate(cardinality.cardinality(dataset)), 10)
@combinations.generate(test_base.default_test_combinations())
def testRoundtripMap(self):
dataset = dataset_ops.Dataset.range(10).map(lambda x: x*x)
variant = dataset_ops.to_variant(dataset)

View File

@ -17,18 +17,20 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from absl.testing import parameterized
from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import combinations
from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gen_dataset_ops
from tensorflow.python.platform import test
@test_util.run_all_in_graph_and_eager_modes
class WrapDatasetVariantTest(test_base.DatasetTestBase):
class WrapDatasetVariantTest(test_base.DatasetTestBase, parameterized.TestCase):
@combinations.generate(test_base.default_test_combinations())
def testBasic(self):
ds = dataset_ops.Dataset.range(100)
ds_variant = ds._variant_tensor # pylint: disable=protected-access
@ -42,7 +44,9 @@ class WrapDatasetVariantTest(test_base.DatasetTestBase):
for i in range(100):
self.assertEqual(i, self.evaluate(get_next()))
@test_util.run_v1_only("b/123901304")
# TODO(b/123901304)
@combinations.generate(
combinations.combine(tf_api_version=[1], mode=["graph"]))
def testSkipEagerGPU(self):
ds = dataset_ops.Dataset.range(100)
ds_variant = ds._variant_tensor # pylint: disable=protected-access