[tf.data] Migrating remaining experimental API tests to use TF combinations and performing various minor test cleanup.
PiperOrigin-RevId: 288533288 Change-Id: Iba6a980cd08fa0aba9e9703711b1dcdfbc3cb734
This commit is contained in:
parent
1d41edaee6
commit
febe171d32
@ -17,17 +17,19 @@ from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from absl.testing import parameterized
|
||||
|
||||
from tensorflow.python.data.experimental.ops import testing
|
||||
from tensorflow.python.data.kernel_tests import test_base
|
||||
from tensorflow.python.data.ops import dataset_ops
|
||||
from tensorflow.python.framework import combinations
|
||||
from tensorflow.python.framework import errors
|
||||
from tensorflow.python.framework import test_util
|
||||
from tensorflow.python.platform import test
|
||||
|
||||
|
||||
@test_util.run_all_in_graph_and_eager_modes
|
||||
class AssertNextTest(test_base.DatasetTestBase):
|
||||
class AssertNextTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
|
||||
@combinations.generate(test_base.default_test_combinations())
|
||||
def testAssertNext(self):
|
||||
dataset = dataset_ops.Dataset.from_tensors(0).apply(
|
||||
testing.assert_next(["Map"])).map(lambda x: x)
|
||||
@ -36,6 +38,7 @@ class AssertNextTest(test_base.DatasetTestBase):
|
||||
dataset = dataset.with_options(options)
|
||||
self.assertDatasetProduces(dataset, expected_output=[0])
|
||||
|
||||
@combinations.generate(test_base.default_test_combinations())
|
||||
def testAssertNextInvalid(self):
|
||||
dataset = dataset_ops.Dataset.from_tensors(0).apply(
|
||||
testing.assert_next(["Whoops"])).map(lambda x: x)
|
||||
@ -49,6 +52,7 @@ class AssertNextTest(test_base.DatasetTestBase):
|
||||
"Asserted Whoops transformation at offset 0 but encountered "
|
||||
"Map transformation instead."))
|
||||
|
||||
@combinations.generate(test_base.default_test_combinations())
|
||||
def testAssertNextShort(self):
|
||||
dataset = dataset_ops.Dataset.from_tensors(0).apply(
|
||||
testing.assert_next(["Map", "Whoops"])).map(lambda x: x)
|
||||
|
@ -17,21 +17,20 @@ from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import functools
|
||||
|
||||
from absl.testing import parameterized
|
||||
|
||||
from tensorflow.python.data.experimental.ops import cardinality
|
||||
from tensorflow.python.data.kernel_tests import test_base
|
||||
from tensorflow.python.data.ops import dataset_ops
|
||||
from tensorflow.python.framework import test_util
|
||||
from tensorflow.python.framework import combinations
|
||||
from tensorflow.python.platform import test
|
||||
|
||||
|
||||
@test_util.run_all_in_graph_and_eager_modes
|
||||
class NumElementsTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
"""Tests for `tf.data.experimental.cardinality()`."""
|
||||
|
||||
@parameterized.named_parameters(
|
||||
# pylint: disable=g-long-lambda
|
||||
def _test_combinations():
|
||||
# pylint: disable=g-long-lambda
|
||||
cases = [
|
||||
("Batch1",
|
||||
lambda: dataset_ops.Dataset.range(5).batch(2, drop_remainder=True), 2),
|
||||
("Batch2",
|
||||
@ -151,9 +150,24 @@ class NumElementsTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
("Zip5", lambda: dataset_ops.Dataset.zip((dataset_ops.Dataset.range(
|
||||
5), dataset_ops.Dataset.range(3).filter(lambda _: True))),
|
||||
cardinality.UNKNOWN),
|
||||
# pylint: enable=g-long-lambda
|
||||
)
|
||||
def testNumElements(self, dataset_fn, expected_result):
|
||||
]
|
||||
|
||||
def reduce_fn(x, y):
|
||||
name, dataset_fn, expected_result = y
|
||||
return x + combinations.combine(
|
||||
dataset_fn=combinations.NamedObject(name, dataset_fn),
|
||||
expected_result=expected_result)
|
||||
|
||||
return functools.reduce(reduce_fn, cases, [])
|
||||
|
||||
|
||||
class CardinalityTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
"""Tests for `tf.data.experimental.cardinality()`."""
|
||||
|
||||
@combinations.generate(
|
||||
combinations.times(test_base.default_test_combinations(),
|
||||
_test_combinations()))
|
||||
def testCardinality(self, dataset_fn, expected_result):
|
||||
with self.cached_session() as sess:
|
||||
self.assertEqual(
|
||||
sess.run(cardinality.cardinality(dataset_fn())), expected_result)
|
||||
|
@ -22,14 +22,14 @@ from absl.testing import parameterized
|
||||
from tensorflow.python.data.experimental.ops import testing
|
||||
from tensorflow.python.data.kernel_tests import test_base
|
||||
from tensorflow.python.data.ops import dataset_ops
|
||||
from tensorflow.python.framework import combinations
|
||||
from tensorflow.python.framework import errors
|
||||
from tensorflow.python.framework import test_util
|
||||
from tensorflow.python.platform import test
|
||||
|
||||
|
||||
@test_util.run_all_in_graph_and_eager_modes
|
||||
class ModelDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
|
||||
@combinations.generate(test_base.default_test_combinations())
|
||||
def testAutotuneOption(self):
|
||||
dataset = dataset_ops.Dataset.from_tensors(0)
|
||||
dataset = dataset.map(lambda x: x).apply(
|
||||
|
@ -17,16 +17,18 @@ from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from absl.testing import parameterized
|
||||
|
||||
from tensorflow.python.data.experimental.ops import testing
|
||||
from tensorflow.python.data.kernel_tests import test_base
|
||||
from tensorflow.python.data.ops import dataset_ops
|
||||
from tensorflow.python.framework import test_util
|
||||
from tensorflow.python.framework import combinations
|
||||
from tensorflow.python.platform import test
|
||||
|
||||
|
||||
@test_util.run_all_in_graph_and_eager_modes
|
||||
class NonSerializableTest(test_base.DatasetTestBase):
|
||||
class NonSerializableTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
|
||||
@combinations.generate(test_base.default_test_combinations())
|
||||
def testNonSerializable(self):
|
||||
dataset = dataset_ops.Dataset.from_tensors(0)
|
||||
dataset = dataset.apply(testing.assert_next(["FiniteSkip"]))
|
||||
@ -41,6 +43,7 @@ class NonSerializableTest(test_base.DatasetTestBase):
|
||||
dataset = dataset.with_options(options)
|
||||
self.assertDatasetProduces(dataset, expected_output=[0])
|
||||
|
||||
@combinations.generate(test_base.default_test_combinations())
|
||||
def testNonSerializableAsDirectInput(self):
|
||||
"""Tests that non-serializable dataset can be OptimizeDataset's input."""
|
||||
dataset = dataset_ops.Dataset.from_tensors(0)
|
||||
|
@ -17,6 +17,7 @@ from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import functools
|
||||
import warnings
|
||||
|
||||
from absl.testing import parameterized
|
||||
@ -30,23 +31,17 @@ from tensorflow.python.data.experimental.ops import testing
|
||||
from tensorflow.python.data.experimental.ops import threadpool
|
||||
from tensorflow.python.data.kernel_tests import test_base
|
||||
from tensorflow.python.data.ops import dataset_ops
|
||||
from tensorflow.python.framework import combinations
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import errors
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.framework import test_util
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import random_ops
|
||||
from tensorflow.python.ops import variable_scope
|
||||
from tensorflow.python.platform import test
|
||||
|
||||
|
||||
def _generate_captured_refvar_test_cases():
|
||||
"""Generates testcases.
|
||||
|
||||
Returns:
|
||||
A list of tuples of (testcase_name, make_dataset_fn). make_dataset_fn takes
|
||||
a tf.Variable as input and creates a test dataset that uses that variable.
|
||||
"""
|
||||
def _captured_refvar_test_combinations():
|
||||
|
||||
def make_map_dataset(var):
|
||||
return dataset_ops.Dataset.from_tensors(0).map(lambda x: x + var)
|
||||
@ -88,7 +83,7 @@ def _generate_captured_refvar_test_cases():
|
||||
scan_ops.scan(
|
||||
0, lambda old_state, elem: (old_state + 1, elem + old_state + var)))
|
||||
|
||||
return [
|
||||
cases = [
|
||||
# Core datasets
|
||||
("Map", make_map_dataset),
|
||||
("FlatMap", make_flat_map_dataset),
|
||||
@ -100,10 +95,17 @@ def _generate_captured_refvar_test_cases():
|
||||
("Scan", make_scan_dataset)
|
||||
]
|
||||
|
||||
def reduce_fn(x, y):
|
||||
name, dataset_fn = y
|
||||
return x + combinations.combine(
|
||||
dataset_fn=combinations.NamedObject(name, dataset_fn))
|
||||
|
||||
return functools.reduce(reduce_fn, cases, [])
|
||||
|
||||
|
||||
@test_util.run_all_in_graph_and_eager_modes
|
||||
class OptimizeDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
|
||||
@combinations.generate(test_base.default_test_combinations())
|
||||
def testOptimizationStatefulFunction(self):
|
||||
dataset = dataset_ops.Dataset.range(
|
||||
10).map(lambda _: random_ops.random_uniform([])).batch(10)
|
||||
@ -113,8 +115,9 @@ class OptimizeDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
get_next = self.getNext(dataset)
|
||||
self.evaluate(get_next())
|
||||
|
||||
@test_util.run_v1_only("b/123902160")
|
||||
def testSkipEagerOptimizationLargeInputFromTensor(self):
|
||||
# TODO(b/123902160)
|
||||
@combinations.generate(test_base.graph_only_combinations())
|
||||
def testOptimizationLargeInputFromTensor(self):
|
||||
input_t = array_ops.placeholder(dtypes.int32, (None, None, None))
|
||||
dataset = dataset_ops.Dataset.from_tensors(input_t)
|
||||
options = dataset_ops.Options()
|
||||
@ -128,8 +131,9 @@ class OptimizeDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
sess.run(init_op, {input_t: np.ones([512, 1024, 1025], np.int32)})
|
||||
self.evaluate(get_next)
|
||||
|
||||
@test_util.run_v1_only("b/123902160")
|
||||
def testSkipEagerOptimizationLargeInputFromTensorSlices(self):
|
||||
# TODO(b/123902160)
|
||||
@combinations.generate(test_base.graph_only_combinations())
|
||||
def testOptimizationLargeInputFromTensorSlices(self):
|
||||
input_t = array_ops.placeholder(dtypes.int32, (None, None, None, None))
|
||||
dataset = dataset_ops.Dataset.from_tensor_slices(input_t)
|
||||
options = dataset_ops.Options()
|
||||
@ -143,6 +147,7 @@ class OptimizeDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
sess.run(init_op, {input_t: np.ones([1, 512, 1024, 1025], np.int32)})
|
||||
self.evaluate(get_next)
|
||||
|
||||
@combinations.generate(test_base.default_test_combinations())
|
||||
def testOptimizationNestedDataset(self):
|
||||
|
||||
def flat_map_fn(_):
|
||||
@ -160,6 +165,7 @@ class OptimizeDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
dataset = dataset.with_options(options)
|
||||
self.assertDatasetProduces(dataset, expected_output=[0])
|
||||
|
||||
@combinations.generate(test_base.default_test_combinations())
|
||||
def testOptimizationNestedDatasetWithModifiedRetval(self):
|
||||
|
||||
def flat_map_fn(_):
|
||||
@ -179,6 +185,7 @@ class OptimizeDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
dataset = dataset.with_options(options)
|
||||
self.assertDatasetProduces(dataset, expected_output=[[0]])
|
||||
|
||||
@combinations.generate(test_base.default_test_combinations())
|
||||
def testOptimizationThreadPoolDataset(self):
|
||||
dataset = dataset_ops.Dataset.range(10).batch(10)
|
||||
|
||||
@ -195,9 +202,11 @@ class OptimizeDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
expected_output=[list(range(10))],
|
||||
requires_initialization=True)
|
||||
|
||||
@parameterized.named_parameters(_generate_captured_refvar_test_cases())
|
||||
@test_util.run_v1_only("RefVariables are not supported in eager mode.")
|
||||
def testSkipEagerOptimizationWithCapturedRefVar(self, dataset_fn):
|
||||
# Reference variables are not supported in eager mode.
|
||||
@combinations.generate(
|
||||
combinations.times(test_base.graph_only_combinations(),
|
||||
_captured_refvar_test_combinations()))
|
||||
def testOptimizationWithCapturedRefVar(self, dataset_fn):
|
||||
"""Tests that default optimizations are disabled with ref variables."""
|
||||
variable = variable_scope.get_variable(
|
||||
"v", initializer=0, use_resource=False)
|
||||
@ -241,6 +250,7 @@ class OptimizeDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
except errors.OutOfRangeError:
|
||||
break
|
||||
|
||||
@combinations.generate(test_base.default_test_combinations())
|
||||
def testOptimizationEnabledByDefault(self):
|
||||
"""Tests that some optimizations are applied to datasets by default."""
|
||||
options = dataset_ops.Options()
|
||||
@ -252,6 +262,7 @@ class OptimizeDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
self.assertEqual(
|
||||
set(options._graph_rewrites()), set(expected_optimizations))
|
||||
|
||||
@combinations.generate(test_base.default_test_combinations())
|
||||
def testOptimizationDisableDefault(self):
|
||||
"""Tests that we can disable all graph optimizations enabled by default.
|
||||
|
||||
@ -269,6 +280,7 @@ class OptimizeDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
self.assertEqual(
|
||||
set(options._graph_rewrites()), set(expected_optimizations))
|
||||
|
||||
@combinations.generate(test_base.default_test_combinations())
|
||||
def testAutotuningDefaults(self):
|
||||
options = dataset_ops.Options()
|
||||
|
||||
@ -279,6 +291,7 @@ class OptimizeDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
optimization_options._AutotuneAlgorithm.HILL_CLIMB)
|
||||
self.assertEqual(cpu_budget, 0)
|
||||
|
||||
@combinations.generate(test_base.default_test_combinations())
|
||||
def testAutotuningBufferSizes(self):
|
||||
options = dataset_ops.Options()
|
||||
options.experimental_optimization.autotune_buffers = True
|
||||
|
@ -44,10 +44,8 @@ class WrapDatasetVariantTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
for i in range(100):
|
||||
self.assertEqual(i, self.evaluate(get_next()))
|
||||
|
||||
# TODO(b/123901304)
|
||||
@combinations.generate(
|
||||
combinations.combine(tf_api_version=[1], mode=["graph"]))
|
||||
def testSkipEagerGPU(self):
|
||||
@combinations.generate(test_base.graph_only_combinations())
|
||||
def testGPU(self):
|
||||
ds = dataset_ops.Dataset.range(100)
|
||||
ds_variant = ds._variant_tensor # pylint: disable=protected-access
|
||||
wrapped_variant = gen_dataset_ops.wrap_dataset_variant(ds_variant)
|
||||
|
Loading…
x
Reference in New Issue
Block a user