[tf.data] Migrating remaining experimental API tests to use TF combinations and performing various minor test cleanup.

PiperOrigin-RevId: 288533288
Change-Id: Iba6a980cd08fa0aba9e9703711b1dcdfbc3cb734
This commit is contained in:
Jiri Simsa 2020-01-07 11:10:42 -08:00 committed by TensorFlower Gardener
parent 1d41edaee6
commit febe171d32
6 changed files with 71 additions and 39 deletions

View File

@ -17,17 +17,19 @@ from __future__ import absolute_import
from __future__ import division from __future__ import division
from __future__ import print_function from __future__ import print_function
from absl.testing import parameterized
from tensorflow.python.data.experimental.ops import testing from tensorflow.python.data.experimental.ops import testing
from tensorflow.python.data.kernel_tests import test_base from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import combinations
from tensorflow.python.framework import errors from tensorflow.python.framework import errors
from tensorflow.python.framework import test_util
from tensorflow.python.platform import test from tensorflow.python.platform import test
@test_util.run_all_in_graph_and_eager_modes class AssertNextTest(test_base.DatasetTestBase, parameterized.TestCase):
class AssertNextTest(test_base.DatasetTestBase):
@combinations.generate(test_base.default_test_combinations())
def testAssertNext(self): def testAssertNext(self):
dataset = dataset_ops.Dataset.from_tensors(0).apply( dataset = dataset_ops.Dataset.from_tensors(0).apply(
testing.assert_next(["Map"])).map(lambda x: x) testing.assert_next(["Map"])).map(lambda x: x)
@ -36,6 +38,7 @@ class AssertNextTest(test_base.DatasetTestBase):
dataset = dataset.with_options(options) dataset = dataset.with_options(options)
self.assertDatasetProduces(dataset, expected_output=[0]) self.assertDatasetProduces(dataset, expected_output=[0])
@combinations.generate(test_base.default_test_combinations())
def testAssertNextInvalid(self): def testAssertNextInvalid(self):
dataset = dataset_ops.Dataset.from_tensors(0).apply( dataset = dataset_ops.Dataset.from_tensors(0).apply(
testing.assert_next(["Whoops"])).map(lambda x: x) testing.assert_next(["Whoops"])).map(lambda x: x)
@ -49,6 +52,7 @@ class AssertNextTest(test_base.DatasetTestBase):
"Asserted Whoops transformation at offset 0 but encountered " "Asserted Whoops transformation at offset 0 but encountered "
"Map transformation instead.")) "Map transformation instead."))
@combinations.generate(test_base.default_test_combinations())
def testAssertNextShort(self): def testAssertNextShort(self):
dataset = dataset_ops.Dataset.from_tensors(0).apply( dataset = dataset_ops.Dataset.from_tensors(0).apply(
testing.assert_next(["Map", "Whoops"])).map(lambda x: x) testing.assert_next(["Map", "Whoops"])).map(lambda x: x)

View File

@ -17,21 +17,20 @@ from __future__ import absolute_import
from __future__ import division from __future__ import division
from __future__ import print_function from __future__ import print_function
import functools
from absl.testing import parameterized from absl.testing import parameterized
from tensorflow.python.data.experimental.ops import cardinality from tensorflow.python.data.experimental.ops import cardinality
from tensorflow.python.data.kernel_tests import test_base from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import test_util from tensorflow.python.framework import combinations
from tensorflow.python.platform import test from tensorflow.python.platform import test
@test_util.run_all_in_graph_and_eager_modes def _test_combinations():
class NumElementsTest(test_base.DatasetTestBase, parameterized.TestCase): # pylint: disable=g-long-lambda
"""Tests for `tf.data.experimental.cardinality()`.""" cases = [
@parameterized.named_parameters(
# pylint: disable=g-long-lambda
("Batch1", ("Batch1",
lambda: dataset_ops.Dataset.range(5).batch(2, drop_remainder=True), 2), lambda: dataset_ops.Dataset.range(5).batch(2, drop_remainder=True), 2),
("Batch2", ("Batch2",
@ -151,9 +150,24 @@ class NumElementsTest(test_base.DatasetTestBase, parameterized.TestCase):
("Zip5", lambda: dataset_ops.Dataset.zip((dataset_ops.Dataset.range( ("Zip5", lambda: dataset_ops.Dataset.zip((dataset_ops.Dataset.range(
5), dataset_ops.Dataset.range(3).filter(lambda _: True))), 5), dataset_ops.Dataset.range(3).filter(lambda _: True))),
cardinality.UNKNOWN), cardinality.UNKNOWN),
# pylint: enable=g-long-lambda ]
)
def testNumElements(self, dataset_fn, expected_result): def reduce_fn(x, y):
name, dataset_fn, expected_result = y
return x + combinations.combine(
dataset_fn=combinations.NamedObject(name, dataset_fn),
expected_result=expected_result)
return functools.reduce(reduce_fn, cases, [])
class CardinalityTest(test_base.DatasetTestBase, parameterized.TestCase):
"""Tests for `tf.data.experimental.cardinality()`."""
@combinations.generate(
combinations.times(test_base.default_test_combinations(),
_test_combinations()))
def testCardinality(self, dataset_fn, expected_result):
with self.cached_session() as sess: with self.cached_session() as sess:
self.assertEqual( self.assertEqual(
sess.run(cardinality.cardinality(dataset_fn())), expected_result) sess.run(cardinality.cardinality(dataset_fn())), expected_result)

View File

@ -22,14 +22,14 @@ from absl.testing import parameterized
from tensorflow.python.data.experimental.ops import testing from tensorflow.python.data.experimental.ops import testing
from tensorflow.python.data.kernel_tests import test_base from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import combinations
from tensorflow.python.framework import errors from tensorflow.python.framework import errors
from tensorflow.python.framework import test_util
from tensorflow.python.platform import test from tensorflow.python.platform import test
@test_util.run_all_in_graph_and_eager_modes
class ModelDatasetTest(test_base.DatasetTestBase, parameterized.TestCase): class ModelDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
@combinations.generate(test_base.default_test_combinations())
def testAutotuneOption(self): def testAutotuneOption(self):
dataset = dataset_ops.Dataset.from_tensors(0) dataset = dataset_ops.Dataset.from_tensors(0)
dataset = dataset.map(lambda x: x).apply( dataset = dataset.map(lambda x: x).apply(

View File

@ -17,16 +17,18 @@ from __future__ import absolute_import
from __future__ import division from __future__ import division
from __future__ import print_function from __future__ import print_function
from absl.testing import parameterized
from tensorflow.python.data.experimental.ops import testing from tensorflow.python.data.experimental.ops import testing
from tensorflow.python.data.kernel_tests import test_base from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import test_util from tensorflow.python.framework import combinations
from tensorflow.python.platform import test from tensorflow.python.platform import test
@test_util.run_all_in_graph_and_eager_modes class NonSerializableTest(test_base.DatasetTestBase, parameterized.TestCase):
class NonSerializableTest(test_base.DatasetTestBase):
@combinations.generate(test_base.default_test_combinations())
def testNonSerializable(self): def testNonSerializable(self):
dataset = dataset_ops.Dataset.from_tensors(0) dataset = dataset_ops.Dataset.from_tensors(0)
dataset = dataset.apply(testing.assert_next(["FiniteSkip"])) dataset = dataset.apply(testing.assert_next(["FiniteSkip"]))
@ -41,6 +43,7 @@ class NonSerializableTest(test_base.DatasetTestBase):
dataset = dataset.with_options(options) dataset = dataset.with_options(options)
self.assertDatasetProduces(dataset, expected_output=[0]) self.assertDatasetProduces(dataset, expected_output=[0])
@combinations.generate(test_base.default_test_combinations())
def testNonSerializableAsDirectInput(self): def testNonSerializableAsDirectInput(self):
"""Tests that non-serializable dataset can be OptimizeDataset's input.""" """Tests that non-serializable dataset can be OptimizeDataset's input."""
dataset = dataset_ops.Dataset.from_tensors(0) dataset = dataset_ops.Dataset.from_tensors(0)

View File

@ -17,6 +17,7 @@ from __future__ import absolute_import
from __future__ import division from __future__ import division
from __future__ import print_function from __future__ import print_function
import functools
import warnings import warnings
from absl.testing import parameterized from absl.testing import parameterized
@ -30,23 +31,17 @@ from tensorflow.python.data.experimental.ops import testing
from tensorflow.python.data.experimental.ops import threadpool from tensorflow.python.data.experimental.ops import threadpool
from tensorflow.python.data.kernel_tests import test_base from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import combinations
from tensorflow.python.framework import dtypes from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors from tensorflow.python.framework import errors
from tensorflow.python.framework import ops from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops from tensorflow.python.ops import array_ops
from tensorflow.python.ops import random_ops from tensorflow.python.ops import random_ops
from tensorflow.python.ops import variable_scope from tensorflow.python.ops import variable_scope
from tensorflow.python.platform import test from tensorflow.python.platform import test
def _generate_captured_refvar_test_cases(): def _captured_refvar_test_combinations():
"""Generates testcases.
Returns:
A list of tuples of (testcase_name, make_dataset_fn). make_dataset_fn takes
a tf.Variable as input and creates a test dataset that uses that variable.
"""
def make_map_dataset(var): def make_map_dataset(var):
return dataset_ops.Dataset.from_tensors(0).map(lambda x: x + var) return dataset_ops.Dataset.from_tensors(0).map(lambda x: x + var)
@ -88,7 +83,7 @@ def _generate_captured_refvar_test_cases():
scan_ops.scan( scan_ops.scan(
0, lambda old_state, elem: (old_state + 1, elem + old_state + var))) 0, lambda old_state, elem: (old_state + 1, elem + old_state + var)))
return [ cases = [
# Core datasets # Core datasets
("Map", make_map_dataset), ("Map", make_map_dataset),
("FlatMap", make_flat_map_dataset), ("FlatMap", make_flat_map_dataset),
@ -100,10 +95,17 @@ def _generate_captured_refvar_test_cases():
("Scan", make_scan_dataset) ("Scan", make_scan_dataset)
] ]
def reduce_fn(x, y):
name, dataset_fn = y
return x + combinations.combine(
dataset_fn=combinations.NamedObject(name, dataset_fn))
return functools.reduce(reduce_fn, cases, [])
@test_util.run_all_in_graph_and_eager_modes
class OptimizeDatasetTest(test_base.DatasetTestBase, parameterized.TestCase): class OptimizeDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
@combinations.generate(test_base.default_test_combinations())
def testOptimizationStatefulFunction(self): def testOptimizationStatefulFunction(self):
dataset = dataset_ops.Dataset.range( dataset = dataset_ops.Dataset.range(
10).map(lambda _: random_ops.random_uniform([])).batch(10) 10).map(lambda _: random_ops.random_uniform([])).batch(10)
@ -113,8 +115,9 @@ class OptimizeDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
get_next = self.getNext(dataset) get_next = self.getNext(dataset)
self.evaluate(get_next()) self.evaluate(get_next())
@test_util.run_v1_only("b/123902160") # TODO(b/123902160)
def testSkipEagerOptimizationLargeInputFromTensor(self): @combinations.generate(test_base.graph_only_combinations())
def testOptimizationLargeInputFromTensor(self):
input_t = array_ops.placeholder(dtypes.int32, (None, None, None)) input_t = array_ops.placeholder(dtypes.int32, (None, None, None))
dataset = dataset_ops.Dataset.from_tensors(input_t) dataset = dataset_ops.Dataset.from_tensors(input_t)
options = dataset_ops.Options() options = dataset_ops.Options()
@ -128,8 +131,9 @@ class OptimizeDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
sess.run(init_op, {input_t: np.ones([512, 1024, 1025], np.int32)}) sess.run(init_op, {input_t: np.ones([512, 1024, 1025], np.int32)})
self.evaluate(get_next) self.evaluate(get_next)
@test_util.run_v1_only("b/123902160") # TODO(b/123902160)
def testSkipEagerOptimizationLargeInputFromTensorSlices(self): @combinations.generate(test_base.graph_only_combinations())
def testOptimizationLargeInputFromTensorSlices(self):
input_t = array_ops.placeholder(dtypes.int32, (None, None, None, None)) input_t = array_ops.placeholder(dtypes.int32, (None, None, None, None))
dataset = dataset_ops.Dataset.from_tensor_slices(input_t) dataset = dataset_ops.Dataset.from_tensor_slices(input_t)
options = dataset_ops.Options() options = dataset_ops.Options()
@ -143,6 +147,7 @@ class OptimizeDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
sess.run(init_op, {input_t: np.ones([1, 512, 1024, 1025], np.int32)}) sess.run(init_op, {input_t: np.ones([1, 512, 1024, 1025], np.int32)})
self.evaluate(get_next) self.evaluate(get_next)
@combinations.generate(test_base.default_test_combinations())
def testOptimizationNestedDataset(self): def testOptimizationNestedDataset(self):
def flat_map_fn(_): def flat_map_fn(_):
@ -160,6 +165,7 @@ class OptimizeDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
dataset = dataset.with_options(options) dataset = dataset.with_options(options)
self.assertDatasetProduces(dataset, expected_output=[0]) self.assertDatasetProduces(dataset, expected_output=[0])
@combinations.generate(test_base.default_test_combinations())
def testOptimizationNestedDatasetWithModifiedRetval(self): def testOptimizationNestedDatasetWithModifiedRetval(self):
def flat_map_fn(_): def flat_map_fn(_):
@ -179,6 +185,7 @@ class OptimizeDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
dataset = dataset.with_options(options) dataset = dataset.with_options(options)
self.assertDatasetProduces(dataset, expected_output=[[0]]) self.assertDatasetProduces(dataset, expected_output=[[0]])
@combinations.generate(test_base.default_test_combinations())
def testOptimizationThreadPoolDataset(self): def testOptimizationThreadPoolDataset(self):
dataset = dataset_ops.Dataset.range(10).batch(10) dataset = dataset_ops.Dataset.range(10).batch(10)
@ -195,9 +202,11 @@ class OptimizeDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
expected_output=[list(range(10))], expected_output=[list(range(10))],
requires_initialization=True) requires_initialization=True)
@parameterized.named_parameters(_generate_captured_refvar_test_cases()) # Reference variables are not supported in eager mode.
@test_util.run_v1_only("RefVariables are not supported in eager mode.") @combinations.generate(
def testSkipEagerOptimizationWithCapturedRefVar(self, dataset_fn): combinations.times(test_base.graph_only_combinations(),
_captured_refvar_test_combinations()))
def testOptimizationWithCapturedRefVar(self, dataset_fn):
"""Tests that default optimizations are disabled with ref variables.""" """Tests that default optimizations are disabled with ref variables."""
variable = variable_scope.get_variable( variable = variable_scope.get_variable(
"v", initializer=0, use_resource=False) "v", initializer=0, use_resource=False)
@ -241,6 +250,7 @@ class OptimizeDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
except errors.OutOfRangeError: except errors.OutOfRangeError:
break break
@combinations.generate(test_base.default_test_combinations())
def testOptimizationEnabledByDefault(self): def testOptimizationEnabledByDefault(self):
"""Tests that some optimizations are applied to datasets by default.""" """Tests that some optimizations are applied to datasets by default."""
options = dataset_ops.Options() options = dataset_ops.Options()
@ -252,6 +262,7 @@ class OptimizeDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
self.assertEqual( self.assertEqual(
set(options._graph_rewrites()), set(expected_optimizations)) set(options._graph_rewrites()), set(expected_optimizations))
@combinations.generate(test_base.default_test_combinations())
def testOptimizationDisableDefault(self): def testOptimizationDisableDefault(self):
"""Tests that we can disable all graph optimizations enabled by default. """Tests that we can disable all graph optimizations enabled by default.
@ -269,6 +280,7 @@ class OptimizeDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
self.assertEqual( self.assertEqual(
set(options._graph_rewrites()), set(expected_optimizations)) set(options._graph_rewrites()), set(expected_optimizations))
@combinations.generate(test_base.default_test_combinations())
def testAutotuningDefaults(self): def testAutotuningDefaults(self):
options = dataset_ops.Options() options = dataset_ops.Options()
@ -279,6 +291,7 @@ class OptimizeDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
optimization_options._AutotuneAlgorithm.HILL_CLIMB) optimization_options._AutotuneAlgorithm.HILL_CLIMB)
self.assertEqual(cpu_budget, 0) self.assertEqual(cpu_budget, 0)
@combinations.generate(test_base.default_test_combinations())
def testAutotuningBufferSizes(self): def testAutotuningBufferSizes(self):
options = dataset_ops.Options() options = dataset_ops.Options()
options.experimental_optimization.autotune_buffers = True options.experimental_optimization.autotune_buffers = True

View File

@ -44,10 +44,8 @@ class WrapDatasetVariantTest(test_base.DatasetTestBase, parameterized.TestCase):
for i in range(100): for i in range(100):
self.assertEqual(i, self.evaluate(get_next())) self.assertEqual(i, self.evaluate(get_next()))
# TODO(b/123901304) @combinations.generate(test_base.graph_only_combinations())
@combinations.generate( def testGPU(self):
combinations.combine(tf_api_version=[1], mode=["graph"]))
def testSkipEagerGPU(self):
ds = dataset_ops.Dataset.range(100) ds = dataset_ops.Dataset.range(100)
ds_variant = ds._variant_tensor # pylint: disable=protected-access ds_variant = ds._variant_tensor # pylint: disable=protected-access
wrapped_variant = gen_dataset_ops.wrap_dataset_variant(ds_variant) wrapped_variant = gen_dataset_ops.wrap_dataset_variant(ds_variant)