[tf.data] Migrating static optimization tests to use TF combinations.
PiperOrigin-RevId: 283145321 Change-Id: Ic12159919f9e77624986ac5ed3753276f5179ce2
This commit is contained in:
parent
dbceeb4674
commit
d59e6c0bfc
tensorflow/python/data/experimental/kernel_tests
auto_shard_dataset_test.py
optimization
BUILDchoose_fastest_branch_dataset_test.pychoose_fastest_dataset_test.pyfilter_fusion_test.pyfilter_with_random_uniform_fusion_test.pyhoist_random_uniform_test.pyinject_prefetch_test.pylatency_all_edges_test.pymap_and_batch_fusion_test.pymap_and_filter_fusion_test.pymap_fusion_test.pymap_parallelization_test.pymap_vectorization_test.pynoop_elimination_test.pyshuffle_and_repeat_fusion_test.py
@ -23,8 +23,8 @@ from tensorflow.python.data.experimental.kernel_tests import reader_dataset_ops_
|
||||
from tensorflow.python.data.experimental.ops import distribute
|
||||
from tensorflow.python.data.experimental.ops import distribute_options
|
||||
from tensorflow.python.data.experimental.ops import interleave_ops
|
||||
from tensorflow.python.data.experimental.ops import testing
|
||||
from tensorflow.python.data.experimental.ops import readers
|
||||
from tensorflow.python.data.experimental.ops import testing
|
||||
from tensorflow.python.data.experimental.ops import unique
|
||||
from tensorflow.python.data.kernel_tests import test_base
|
||||
from tensorflow.python.data.ops import dataset_ops
|
||||
|
@ -100,6 +100,24 @@ tf_py_test(
|
||||
],
|
||||
)
|
||||
|
||||
tf_py_test(
|
||||
name = "inject_prefetch_test",
|
||||
size = "small",
|
||||
srcs = ["inject_prefetch_test.py"],
|
||||
additional_deps = [
|
||||
"//tensorflow/python:client_testlib",
|
||||
"//tensorflow/python:errors",
|
||||
"//tensorflow/python/data/experimental/ops:testing",
|
||||
"//tensorflow/python/data/kernel_tests:test_base",
|
||||
"//tensorflow/python/data/ops:dataset_ops",
|
||||
],
|
||||
tags = [
|
||||
"no_oss",
|
||||
"no_pip",
|
||||
"no_windows",
|
||||
],
|
||||
)
|
||||
|
||||
tf_py_test(
|
||||
name = "latency_all_edges_test",
|
||||
size = "small",
|
||||
|
@ -23,18 +23,18 @@ from tensorflow.python.data.experimental.ops import optimization
|
||||
from tensorflow.python.data.kernel_tests import test_base
|
||||
from tensorflow.python.data.ops import dataset_ops
|
||||
from tensorflow.python.eager import context
|
||||
from tensorflow.python.framework import combinations
|
||||
from tensorflow.python.framework import constant_op
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import errors
|
||||
from tensorflow.python.framework import test_util
|
||||
from tensorflow.python.ops import math_ops
|
||||
from tensorflow.python.platform import test
|
||||
|
||||
|
||||
@test_util.run_all_in_graph_and_eager_modes
|
||||
class ChooseFastestBranchDatasetTest(test_base.DatasetTestBase,
|
||||
parameterized.TestCase):
|
||||
|
||||
@combinations.generate(test_base.default_test_combinations())
|
||||
def testSimple(self):
|
||||
dataset = dataset_ops.Dataset.from_tensor_slices([0, 1, 2, 3, 4])
|
||||
|
||||
@ -49,6 +49,7 @@ class ChooseFastestBranchDatasetTest(test_base.DatasetTestBase,
|
||||
expected_output=[0, 1, 2, 3, 4],
|
||||
expected_shapes=dataset_ops.get_legacy_output_shapes(dataset))
|
||||
|
||||
@combinations.generate(test_base.default_test_combinations())
|
||||
def testCaptureSimple(self):
|
||||
dataset = dataset_ops.Dataset.range(10)
|
||||
|
||||
@ -67,6 +68,7 @@ class ChooseFastestBranchDatasetTest(test_base.DatasetTestBase,
|
||||
self.assertDatasetProduces(
|
||||
choose_fastest, expected_output=list(range(1, 11)))
|
||||
|
||||
@combinations.generate(test_base.default_test_combinations())
|
||||
def testDifferentFunctions(self):
|
||||
dataset = dataset_ops.Dataset.range(100)
|
||||
|
||||
@ -83,6 +85,7 @@ class ChooseFastestBranchDatasetTest(test_base.DatasetTestBase,
|
||||
choose_fastest,
|
||||
expected_output=[list(range(10 * x, 10 * x + 10)) for x in range(10)])
|
||||
|
||||
@combinations.generate(test_base.default_test_combinations())
|
||||
def testWithRepeatBeforeAndAfter(self):
|
||||
dataset = dataset_ops.Dataset.from_tensors(0).repeat(10)
|
||||
|
||||
@ -99,6 +102,7 @@ class ChooseFastestBranchDatasetTest(test_base.DatasetTestBase,
|
||||
self.assertDatasetProduces(
|
||||
choose_fastest, expected_output=[[0] * 10 for _ in range(10)])
|
||||
|
||||
@combinations.generate(test_base.default_test_combinations())
|
||||
def testWithPrefetch(self):
|
||||
"""Should maintain ordering even if the branches do prefetching."""
|
||||
dataset = dataset_ops.Dataset.range(100)
|
||||
@ -114,6 +118,7 @@ class ChooseFastestBranchDatasetTest(test_base.DatasetTestBase,
|
||||
|
||||
self.assertDatasetProduces(choose_fastest, expected_output=list(range(100)))
|
||||
|
||||
@combinations.generate(test_base.default_test_combinations())
|
||||
def testWithMoreOutputThanInput(self):
|
||||
|
||||
dataset = dataset_ops.Dataset.from_tensors(0).repeat(1000).batch(100)
|
||||
@ -128,6 +133,7 @@ class ChooseFastestBranchDatasetTest(test_base.DatasetTestBase,
|
||||
|
||||
self.assertDatasetProduces(choose_fastest, expected_output=[0] * 1000)
|
||||
|
||||
@combinations.generate(test_base.default_test_combinations())
|
||||
def testWithBadNumElements(self):
|
||||
|
||||
dataset = dataset_ops.Dataset.from_tensors(0).repeat(1000).batch(100)
|
||||
@ -153,6 +159,7 @@ class ChooseFastestBranchDatasetTest(test_base.DatasetTestBase,
|
||||
choose_fastest,
|
||||
expected_error=(errors.InvalidArgumentError, expected_error_msg))
|
||||
|
||||
@combinations.generate(test_base.default_test_combinations())
|
||||
def testErrorWithRepeat(self):
|
||||
dataset = dataset_ops.Dataset.from_tensors(0)
|
||||
|
||||
|
@ -23,15 +23,15 @@ from tensorflow.python.data.experimental.ops import optimization
|
||||
from tensorflow.python.data.kernel_tests import test_base
|
||||
from tensorflow.python.data.ops import dataset_ops
|
||||
from tensorflow.python.eager import context
|
||||
from tensorflow.python.framework import combinations
|
||||
from tensorflow.python.framework import errors
|
||||
from tensorflow.python.framework import test_util
|
||||
from tensorflow.python.platform import test
|
||||
|
||||
|
||||
@test_util.run_all_in_graph_and_eager_modes
|
||||
class ChooseFastestDatasetTest(test_base.DatasetTestBase,
|
||||
parameterized.TestCase):
|
||||
|
||||
@combinations.generate(test_base.default_test_combinations())
|
||||
def testChooseFastestSimple(self):
|
||||
dataset = dataset_ops.Dataset.from_tensor_slices([0, 1, 2, 3, 4])
|
||||
merge = optimization._ChooseFastestDataset([dataset, dataset])
|
||||
@ -40,6 +40,7 @@ class ChooseFastestDatasetTest(test_base.DatasetTestBase,
|
||||
expected_output=[0, 1, 2, 3, 4],
|
||||
expected_shapes=dataset_ops.get_legacy_output_shapes(dataset))
|
||||
|
||||
@combinations.generate(test_base.default_test_combinations())
|
||||
def testChooseFastestManyInputs(self):
|
||||
dataset = dataset_ops.Dataset.from_tensor_slices([0, 1, 2, 3, 4])
|
||||
merge = optimization._ChooseFastestDataset([dataset for _ in range(5)])
|
||||
@ -48,6 +49,7 @@ class ChooseFastestDatasetTest(test_base.DatasetTestBase,
|
||||
expected_output=[0, 1, 2, 3, 4],
|
||||
expected_shapes=dataset_ops.get_legacy_output_shapes(dataset))
|
||||
|
||||
@combinations.generate(test_base.default_test_combinations())
|
||||
def testChooseFastest(self):
|
||||
dataset = dataset_ops.Dataset.range(600)
|
||||
f = lambda x: 2 * x
|
||||
@ -61,11 +63,25 @@ class ChooseFastestDatasetTest(test_base.DatasetTestBase,
|
||||
],
|
||||
expected_shapes=dataset_ops.get_legacy_output_shapes(dataset_a))
|
||||
|
||||
@parameterized.named_parameters(
|
||||
("Shapes", [0], [[1, 2, 3]], "must have compatible output shapes."),
|
||||
("Types", [0], [0.0], "must have the same output types."),
|
||||
("NumComponents", [0], ([0], [1]), "must have the same output types."),
|
||||
("Cardinality", [1, 2, 3], [1], "must have compatible cardinalities."))
|
||||
@combinations.generate(
|
||||
combinations.times(
|
||||
test_base.default_test_combinations(),
|
||||
combinations.combine(
|
||||
slices_a=[[0]],
|
||||
slices_b=[[[1, 2, 3]]],
|
||||
error_msg="must have compatible output shapes.") +
|
||||
combinations.combine(
|
||||
slices_a=[[0]],
|
||||
slices_b=[[0.0]],
|
||||
error_msg="must have the same output types.") +
|
||||
combinations.combine(
|
||||
slices_a=[[0]],
|
||||
slices_b=[([0], [1])],
|
||||
error_msg="must have the same output types.") +
|
||||
combinations.combine(
|
||||
slices_a=[[1, 2, 3]],
|
||||
slices_b=[[0]],
|
||||
error_msg="must have compatible cardinalities.")))
|
||||
def testChooseFastestErrorWithIncompatibleInput(self, slices_a, slices_b,
|
||||
error_msg):
|
||||
dataset_a = dataset_ops.Dataset.from_tensor_slices(slices_a)
|
||||
|
@ -22,47 +22,16 @@ from absl.testing import parameterized
|
||||
from tensorflow.python.data.experimental.ops import testing
|
||||
from tensorflow.python.data.kernel_tests import test_base
|
||||
from tensorflow.python.data.ops import dataset_ops
|
||||
from tensorflow.python.framework import combinations
|
||||
from tensorflow.python.framework import constant_op
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import test_util
|
||||
from tensorflow.python.ops import math_ops
|
||||
from tensorflow.python.platform import test
|
||||
|
||||
|
||||
def _filter_fusion_test_cases():
|
||||
"""Generates test cases for the FilterFusion optimization."""
|
||||
|
||||
take_all = lambda x: constant_op.constant(True)
|
||||
is_zero = lambda x: math_ops.equal(x, 0)
|
||||
greater = lambda x: math_ops.greater(x + 5, 0)
|
||||
|
||||
tests = []
|
||||
filters = [take_all, is_zero, greater]
|
||||
identity = lambda x: x
|
||||
for x, predicate_1 in enumerate(filters):
|
||||
for y, predicate_2 in enumerate(filters):
|
||||
tests.append(("Mixed{}{}".format(x, y), identity,
|
||||
[predicate_1, predicate_2]))
|
||||
for z, predicate_3 in enumerate(filters):
|
||||
tests.append(("Mixed{}{}{}".format(x, y, z), identity,
|
||||
[predicate_1, predicate_2, predicate_3]))
|
||||
|
||||
take_all_multiple = lambda x, y: constant_op.constant(True)
|
||||
# Multi output
|
||||
tests.append(("Multi1", lambda x: (x, x),
|
||||
[take_all_multiple, take_all_multiple]))
|
||||
tests.append(("Multi2", lambda x: (x, 2), [
|
||||
take_all_multiple,
|
||||
lambda x, y: math_ops.equal(x * math_ops.cast(y, dtypes.int64), 0)
|
||||
]))
|
||||
return tuple(tests)
|
||||
|
||||
|
||||
@test_util.run_all_in_graph_and_eager_modes
|
||||
class FilterFusionTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
|
||||
@parameterized.named_parameters(*_filter_fusion_test_cases())
|
||||
def testFilterFusion(self, map_function, predicates):
|
||||
def _testFilterFusion(self, map_function, predicates):
|
||||
dataset = dataset_ops.Dataset.range(5).apply(
|
||||
testing.assert_next(["Map", "Filter",
|
||||
"MemoryCacheImpl"])).map(map_function)
|
||||
@ -91,6 +60,26 @@ class FilterFusionTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
expected_output.append(r)
|
||||
self.assertDatasetProduces(dataset, expected_output=expected_output)
|
||||
|
||||
@combinations.generate(test_base.default_test_combinations())
|
||||
def testFilterFusionScalar(self):
|
||||
take_all = lambda x: constant_op.constant(True)
|
||||
is_zero = lambda x: math_ops.equal(x, 0)
|
||||
greater = lambda x: math_ops.greater(x + 5, 0)
|
||||
predicates = [take_all, is_zero, greater]
|
||||
for x in predicates:
|
||||
for y in predicates:
|
||||
self._testFilterFusion(lambda x: x, [x, y])
|
||||
for z in predicates:
|
||||
self._testFilterFusion(lambda x: x, [x, y, z])
|
||||
|
||||
@combinations.generate(test_base.default_test_combinations())
|
||||
def testFilterFusionTuple(self):
|
||||
take_all = lambda x, y: constant_op.constant(True)
|
||||
is_zero = lambda x, y: math_ops.equal(x * math_ops.cast(y, dtypes.int64), 0)
|
||||
|
||||
self._testFilterFusion(lambda x: (x, x), [take_all, take_all])
|
||||
self._testFilterFusion(lambda x: (x, 2), [take_all, is_zero])
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test.main()
|
||||
|
@ -17,17 +17,20 @@ from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from absl.testing import parameterized
|
||||
|
||||
from tensorflow.python.data.experimental.ops import testing
|
||||
from tensorflow.python.data.kernel_tests import test_base
|
||||
from tensorflow.python.data.ops import dataset_ops
|
||||
from tensorflow.python.framework import test_util
|
||||
from tensorflow.python.framework import combinations
|
||||
from tensorflow.python.ops import random_ops
|
||||
from tensorflow.python.platform import test
|
||||
|
||||
|
||||
@test_util.run_all_in_graph_and_eager_modes
|
||||
class FilterWithRandomUniformFusionTest(test_base.DatasetTestBase):
|
||||
class FilterWithRandomUniformFusionTest(test_base.DatasetTestBase,
|
||||
parameterized.TestCase):
|
||||
|
||||
@combinations.generate(test_base.default_test_combinations())
|
||||
def testFilterWithRandomUniformFusion(self):
|
||||
dataset = dataset_ops.Dataset.range(10000000).apply(
|
||||
testing.assert_next(["Sampling"]))
|
||||
|
@ -22,44 +22,17 @@ from absl.testing import parameterized
|
||||
from tensorflow.python.data.experimental.ops import testing
|
||||
from tensorflow.python.data.kernel_tests import test_base
|
||||
from tensorflow.python.data.ops import dataset_ops
|
||||
from tensorflow.python.framework import combinations
|
||||
from tensorflow.python.framework import constant_op
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import errors
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.framework import test_util
|
||||
from tensorflow.python.ops import control_flow_ops
|
||||
from tensorflow.python.ops import math_ops
|
||||
from tensorflow.python.ops import random_ops
|
||||
from tensorflow.python.platform import test
|
||||
|
||||
|
||||
def _hoist_random_uniform_test_cases():
|
||||
"""Generates test cases for the HoistRandomUniform optimization."""
|
||||
|
||||
plus_one = lambda x: x + 1
|
||||
|
||||
def random(_):
|
||||
return random_ops.random_uniform([],
|
||||
minval=1,
|
||||
maxval=10,
|
||||
dtype=dtypes.float32,
|
||||
seed=42)
|
||||
|
||||
def random_with_assert(x):
|
||||
y = random(x)
|
||||
assert_op = control_flow_ops.Assert(math_ops.greater_equal(y, 1), [y])
|
||||
with ops.control_dependencies([assert_op]):
|
||||
return y
|
||||
|
||||
twice_random = lambda x: (random(x) + random(x)) / 2.
|
||||
|
||||
tests = [("PlusOne", plus_one, False), ("RandomUniform", random, True),
|
||||
("RandomWithAssert", random_with_assert, True),
|
||||
("TwiceRandom", twice_random, False)]
|
||||
return tuple(tests)
|
||||
|
||||
|
||||
@test_util.run_all_in_graph_and_eager_modes
|
||||
class HoistRandomUniformTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
|
||||
def _testDataset(self, dataset):
|
||||
@ -78,11 +51,10 @@ class HoistRandomUniformTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
self.evaluate(get_next())
|
||||
|
||||
@parameterized.named_parameters(*_hoist_random_uniform_test_cases())
|
||||
def testHoisting(self, function, will_optimize):
|
||||
def _testHoistFunction(self, function, should_optimize):
|
||||
dataset = dataset_ops.Dataset.range(5).apply(
|
||||
testing.assert_next(
|
||||
["Zip[0]", "Map"] if will_optimize else ["Map"])).map(function)
|
||||
["Zip[0]", "Map"] if should_optimize else ["Map"])).map(function)
|
||||
|
||||
options = dataset_ops.Options()
|
||||
options.experimental_optimization.apply_default_optimizations = False
|
||||
@ -90,6 +62,32 @@ class HoistRandomUniformTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
dataset = dataset.with_options(options)
|
||||
self._testDataset(dataset)
|
||||
|
||||
@combinations.generate(test_base.default_test_combinations())
|
||||
def testNoRandom(self):
|
||||
self._testHoistFunction(lambda x: x + 1, should_optimize=False)
|
||||
|
||||
@combinations.generate(test_base.default_test_combinations())
|
||||
def testRandom(self):
|
||||
|
||||
def random(_):
|
||||
return random_ops.random_uniform([],
|
||||
minval=1,
|
||||
maxval=10,
|
||||
dtype=dtypes.float32,
|
||||
seed=42)
|
||||
|
||||
def random_with_assert(x):
|
||||
y = random(x)
|
||||
assert_op = control_flow_ops.Assert(math_ops.greater_equal(y, 1), [y])
|
||||
with ops.control_dependencies([assert_op]):
|
||||
return y
|
||||
|
||||
self._testHoistFunction(random, should_optimize=True)
|
||||
self._testHoistFunction(random_with_assert, should_optimize=True)
|
||||
self._testHoistFunction(
|
||||
lambda x: (random(x) + random(x)) / 2, should_optimize=False)
|
||||
|
||||
@combinations.generate(test_base.default_test_combinations())
|
||||
def testCapturedInputs(self):
|
||||
a = constant_op.constant(1, dtype=dtypes.float32)
|
||||
b = constant_op.constant(0, dtype=dtypes.float32)
|
||||
|
@ -17,35 +17,38 @@ from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from tensorflow.python.data.experimental.ops import optimization
|
||||
from absl.testing import parameterized
|
||||
|
||||
from tensorflow.python.data.experimental.ops import testing
|
||||
from tensorflow.python.data.kernel_tests import test_base
|
||||
from tensorflow.python.data.ops import dataset_ops
|
||||
from tensorflow.python.framework import test_util
|
||||
from tensorflow.python.framework import combinations
|
||||
from tensorflow.python.platform import test
|
||||
|
||||
|
||||
@test_util.run_all_in_graph_and_eager_modes
|
||||
class InjectPrefetchTest(test_base.DatasetTestBase):
|
||||
class InjectPrefetchTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
|
||||
def _enable_autotune_buffers(self, dataset):
|
||||
options = dataset_ops.Options()
|
||||
options.experimental_optimization.autotune_buffers = True
|
||||
return dataset.with_options(options)
|
||||
|
||||
@combinations.generate(test_base.default_test_combinations())
|
||||
def testParallelMap(self):
|
||||
dataset = dataset_ops.Dataset.range(100)
|
||||
dataset = dataset.apply(
|
||||
optimization.assert_next(["ParallelMap", "Prefetch", "FiniteTake"]))
|
||||
testing.assert_next(["ParallelMap", "Prefetch", "FiniteTake"]))
|
||||
dataset = dataset.map(
|
||||
lambda x: x + 1, num_parallel_calls=dataset_ops.AUTOTUNE)
|
||||
dataset = dataset.take(50)
|
||||
dataset = self._enable_autotune_buffers(dataset)
|
||||
self.assertDatasetProduces(dataset, range(1, 51))
|
||||
|
||||
@combinations.generate(test_base.default_test_combinations())
|
||||
def testMapAndBatch(self):
|
||||
dataset = dataset_ops.Dataset.range(100)
|
||||
dataset = dataset.apply(
|
||||
optimization.assert_next(["MapAndBatch", "Prefetch", "FiniteTake"]))
|
||||
testing.assert_next(["MapAndBatch", "Prefetch", "FiniteTake"]))
|
||||
dataset = dataset.map(
|
||||
lambda x: x + 1, num_parallel_calls=dataset_ops.AUTOTUNE)
|
||||
dataset = dataset.batch(10)
|
||||
@ -54,10 +57,11 @@ class InjectPrefetchTest(test_base.DatasetTestBase):
|
||||
self.assertDatasetProduces(
|
||||
dataset, [list(range(i + 1, i + 11)) for i in range(0, 50, 10)])
|
||||
|
||||
@combinations.generate(test_base.default_test_combinations())
|
||||
def testParallelInterleaveV2(self):
|
||||
dataset = dataset_ops.Dataset.range(100)
|
||||
dataset = dataset.apply(
|
||||
optimization.assert_next(
|
||||
testing.assert_next(
|
||||
["ParallelInterleaveV2", "Prefetch", "FiniteTake"]))
|
||||
dataset = dataset.interleave(
|
||||
lambda x: dataset_ops.Dataset.from_tensors(x + 1),
|
||||
@ -66,10 +70,11 @@ class InjectPrefetchTest(test_base.DatasetTestBase):
|
||||
dataset = self._enable_autotune_buffers(dataset)
|
||||
self.assertDatasetProduces(dataset, range(1, 51))
|
||||
|
||||
@combinations.generate(test_base.default_test_combinations())
|
||||
def testChainedParallelDatasets(self):
|
||||
dataset = dataset_ops.Dataset.range(100)
|
||||
dataset = dataset.apply(
|
||||
optimization.assert_next([
|
||||
testing.assert_next([
|
||||
"ParallelMap", "Prefetch", "ParallelInterleaveV2", "Prefetch",
|
||||
"MapAndBatch", "Prefetch", "FiniteTake"
|
||||
]))
|
||||
@ -85,9 +90,10 @@ class InjectPrefetchTest(test_base.DatasetTestBase):
|
||||
dataset = self._enable_autotune_buffers(dataset)
|
||||
self.assertDatasetProduces(dataset, [[i] for i in range(3, 53)])
|
||||
|
||||
@combinations.generate(test_base.default_test_combinations())
|
||||
def testNoRegularMap(self):
|
||||
dataset = dataset_ops.Dataset.range(100)
|
||||
dataset = dataset.apply(optimization.assert_next(["Map", "FiniteTake"]))
|
||||
dataset = dataset.apply(testing.assert_next(["Map", "FiniteTake"]))
|
||||
dataset = dataset.map(lambda x: x + 1).take(50)
|
||||
dataset = self._enable_autotune_buffers(dataset)
|
||||
self.assertDatasetProduces(dataset, range(1, 51))
|
||||
|
@ -17,15 +17,22 @@ from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from absl.testing import parameterized
|
||||
|
||||
from tensorflow.python.data.experimental.kernel_tests import stats_dataset_test_base
|
||||
from tensorflow.python.data.experimental.ops import testing
|
||||
from tensorflow.python.data.experimental.ops import stats_aggregator
|
||||
from tensorflow.python.data.experimental.ops import testing
|
||||
from tensorflow.python.data.kernel_tests import test_base
|
||||
from tensorflow.python.data.ops import dataset_ops
|
||||
from tensorflow.python.framework import combinations
|
||||
from tensorflow.python.platform import test
|
||||
|
||||
|
||||
class LatencyAllEdgesTest(stats_dataset_test_base.StatsDatasetTestBase):
|
||||
class LatencyAllEdgesTest(stats_dataset_test_base.StatsDatasetTestBase,
|
||||
parameterized.TestCase):
|
||||
|
||||
# TODO(jsimsa): Investigate why are graph-mode tests failing.
|
||||
@combinations.generate(test_base.eager_only_combinations())
|
||||
def testLatencyStatsOptimization(self):
|
||||
aggregator = stats_aggregator.StatsAggregator()
|
||||
dataset = dataset_ops.Dataset.from_tensors(1).apply(
|
||||
|
@ -17,16 +17,18 @@ from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from absl.testing import parameterized
|
||||
|
||||
from tensorflow.python.data.experimental.ops import testing
|
||||
from tensorflow.python.data.kernel_tests import test_base
|
||||
from tensorflow.python.data.ops import dataset_ops
|
||||
from tensorflow.python.framework import test_util
|
||||
from tensorflow.python.framework import combinations
|
||||
from tensorflow.python.platform import test
|
||||
|
||||
|
||||
@test_util.run_all_in_graph_and_eager_modes
|
||||
class MapAndBatchFusionTest(test_base.DatasetTestBase):
|
||||
class MapAndBatchFusionTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
|
||||
@combinations.generate(test_base.default_test_combinations())
|
||||
def testMapAndBatchFusion(self):
|
||||
dataset = dataset_ops.Dataset.range(10).apply(
|
||||
testing.assert_next(
|
||||
|
@ -22,50 +22,16 @@ from absl.testing import parameterized
|
||||
from tensorflow.python.data.experimental.ops import testing
|
||||
from tensorflow.python.data.kernel_tests import test_base
|
||||
from tensorflow.python.data.ops import dataset_ops
|
||||
from tensorflow.python.framework import combinations
|
||||
from tensorflow.python.framework import constant_op
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import test_util
|
||||
from tensorflow.python.ops import math_ops
|
||||
from tensorflow.python.platform import test
|
||||
|
||||
|
||||
def _map_and_filter_fusion_test_cases():
|
||||
"""Generates test cases for the MapAndFilterFusion optimization."""
|
||||
|
||||
identity = lambda x: x
|
||||
increment = lambda x: x + 1
|
||||
minus_five = lambda x: x - 5
|
||||
|
||||
def increment_and_square(x):
|
||||
y = x + 1
|
||||
return y * y
|
||||
|
||||
take_all = lambda x: constant_op.constant(True)
|
||||
is_zero = lambda x: math_ops.equal(x, 0)
|
||||
is_odd = lambda x: math_ops.equal(x % 2, 0)
|
||||
greater = lambda x: math_ops.greater(x + 5, 0)
|
||||
|
||||
functions = [identity, increment, minus_five, increment_and_square]
|
||||
filters = [take_all, is_zero, is_odd, greater]
|
||||
tests = []
|
||||
|
||||
for x, fun in enumerate(functions):
|
||||
for y, predicate in enumerate(filters):
|
||||
tests.append(("Mixed{}{}".format(x, y), fun, predicate))
|
||||
|
||||
# Multi output
|
||||
tests.append(("Multi1", lambda x: (x, x),
|
||||
lambda x, y: constant_op.constant(True)))
|
||||
tests.append(
|
||||
("Multi2", lambda x: (x, 2),
|
||||
lambda x, y: math_ops.equal(x * math_ops.cast(y, dtypes.int64), 0)))
|
||||
return tuple(tests)
|
||||
|
||||
|
||||
@test_util.run_all_in_graph_and_eager_modes
|
||||
class MapAndFilterFusionTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
|
||||
def _testMapAndFilter(self, dataset, function, predicate):
|
||||
def _testDataset(self, dataset, function, predicate):
|
||||
expected_output = []
|
||||
for x in range(10):
|
||||
r = function(x)
|
||||
@ -77,8 +43,7 @@ class MapAndFilterFusionTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
expected_output.append(r)
|
||||
self.assertDatasetProduces(dataset, expected_output=expected_output)
|
||||
|
||||
@parameterized.named_parameters(*_map_and_filter_fusion_test_cases())
|
||||
def testMapFilterFusion(self, function, predicate):
|
||||
def _testMapAndFilterFusion(self, function, predicate):
|
||||
dataset = dataset_ops.Dataset.range(10).apply(
|
||||
testing.assert_next(["Map", "Filter",
|
||||
"Map"])).map(function).filter(predicate)
|
||||
@ -86,8 +51,44 @@ class MapAndFilterFusionTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
options.experimental_optimization.apply_default_optimizations = False
|
||||
options.experimental_optimization.map_and_filter_fusion = True
|
||||
dataset = dataset.with_options(options)
|
||||
self._testMapAndFilter(dataset, function, predicate)
|
||||
self._testDataset(dataset, function, predicate)
|
||||
|
||||
@combinations.generate(test_base.default_test_combinations())
|
||||
def testMapAndFilterFusionScalar(self):
|
||||
identity = lambda x: x
|
||||
increment = lambda x: x + 1
|
||||
minus_five = lambda x: x - 5
|
||||
|
||||
def increment_and_square(x):
|
||||
y = x + 1
|
||||
return y * y
|
||||
|
||||
functions = [identity, increment, minus_five, increment_and_square]
|
||||
|
||||
take_all = lambda x: constant_op.constant(True)
|
||||
is_zero = lambda x: math_ops.equal(x, 0)
|
||||
is_odd = lambda x: math_ops.equal(x % 2, 0)
|
||||
greater = lambda x: math_ops.greater(x + 5, 0)
|
||||
predicates = [take_all, is_zero, is_odd, greater]
|
||||
|
||||
for function in functions:
|
||||
for predicate in predicates:
|
||||
self._testMapAndFilterFusion(function, predicate)
|
||||
|
||||
@combinations.generate(test_base.default_test_combinations())
|
||||
def testMapAndFilterFusionTuple(self):
|
||||
replicate = lambda x: (x, x)
|
||||
with_two = lambda x: (x, 2)
|
||||
functions = [replicate, with_two]
|
||||
take_all = lambda x, y: constant_op.constant(True)
|
||||
is_zero = lambda x, y: math_ops.equal(x * math_ops.cast(y, dtypes.int64), 0)
|
||||
predicates = [take_all, is_zero]
|
||||
|
||||
for function in functions:
|
||||
for predicate in predicates:
|
||||
self._testMapAndFilterFusion(function, predicate)
|
||||
|
||||
@combinations.generate(test_base.default_test_combinations())
|
||||
def testCapturedInputs(self):
|
||||
a = constant_op.constant(3, dtype=dtypes.int64)
|
||||
b = constant_op.constant(4, dtype=dtypes.int64)
|
||||
@ -104,7 +105,7 @@ class MapAndFilterFusionTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
options.experimental_optimization.apply_default_optimizations = False
|
||||
options.experimental_optimization.map_and_filter_fusion = True
|
||||
dataset = dataset.with_options(options)
|
||||
self._testMapAndFilter(dataset, function, predicate)
|
||||
self._testDataset(dataset, function, predicate)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
@ -22,51 +22,13 @@ from absl.testing import parameterized
|
||||
from tensorflow.python.data.experimental.ops import testing
|
||||
from tensorflow.python.data.kernel_tests import test_base
|
||||
from tensorflow.python.data.ops import dataset_ops
|
||||
from tensorflow.python.framework import test_util
|
||||
from tensorflow.python.framework import combinations
|
||||
from tensorflow.python.platform import test
|
||||
|
||||
|
||||
def _map_fusion_test_cases():
|
||||
"""Generates test cases for the MapFusion optimization."""
|
||||
|
||||
identity = lambda x: x
|
||||
increment = lambda x: x + 1
|
||||
|
||||
def increment_and_square(x):
|
||||
y = x + 1
|
||||
return y * y
|
||||
|
||||
functions = [identity, increment, increment_and_square]
|
||||
tests = []
|
||||
for i, fun1 in enumerate(functions):
|
||||
for j, fun2 in enumerate(functions):
|
||||
tests.append((
|
||||
"Test{}{}".format(i, j),
|
||||
[fun1, fun2],
|
||||
))
|
||||
for k, fun3 in enumerate(functions):
|
||||
tests.append((
|
||||
"Test{}{}{}".format(i, j, k),
|
||||
[fun1, fun2, fun3],
|
||||
))
|
||||
|
||||
swap = lambda x, n: (n, x)
|
||||
tests.append((
|
||||
"Swap1",
|
||||
[lambda x: (x, 42), swap],
|
||||
))
|
||||
tests.append((
|
||||
"Swap2",
|
||||
[lambda x: (x, 42), swap, swap],
|
||||
))
|
||||
return tuple(tests)
|
||||
|
||||
|
||||
@test_util.run_all_in_graph_and_eager_modes
|
||||
class MapFusionTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
|
||||
@parameterized.named_parameters(*_map_fusion_test_cases())
|
||||
def testMapFusion(self, functions):
|
||||
def _testMapFusion(self, functions):
|
||||
dataset = dataset_ops.Dataset.range(5).apply(
|
||||
testing.assert_next(["Map", "MemoryCacheImpl"]))
|
||||
for function in functions:
|
||||
@ -88,6 +50,31 @@ class MapFusionTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
expected_output.append(r)
|
||||
self.assertDatasetProduces(dataset, expected_output=expected_output)
|
||||
|
||||
@combinations.generate(test_base.default_test_combinations())
|
||||
def testMapFusionScalar(self):
|
||||
identity = lambda x: x
|
||||
increment = lambda x: x + 1
|
||||
|
||||
def increment_and_square(x):
|
||||
y = x + 1
|
||||
return y * y
|
||||
|
||||
functions = [identity, increment, increment_and_square]
|
||||
|
||||
for x in functions:
|
||||
for y in functions:
|
||||
self._testMapFusion([x, y])
|
||||
for z in functions:
|
||||
self._testMapFusion([x, y, z])
|
||||
|
||||
@combinations.generate(test_base.default_test_combinations())
|
||||
def testMapAndFilterFusionTuple(self):
|
||||
with_42 = lambda x: (x, 42)
|
||||
swap = lambda x, y: (y, x)
|
||||
|
||||
self._testMapFusion([with_42, swap])
|
||||
self._testMapFusion([with_42, swap, swap])
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test.main()
|
||||
|
@ -22,38 +22,20 @@ from absl.testing import parameterized
|
||||
from tensorflow.python.data.experimental.ops import testing
|
||||
from tensorflow.python.data.kernel_tests import test_base
|
||||
from tensorflow.python.data.ops import dataset_ops
|
||||
from tensorflow.python.framework import combinations
|
||||
from tensorflow.python.framework import constant_op
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.framework import test_util
|
||||
from tensorflow.python.ops import control_flow_ops
|
||||
from tensorflow.python.ops import math_ops
|
||||
from tensorflow.python.ops import variables
|
||||
from tensorflow.python.platform import test
|
||||
|
||||
|
||||
def _map_parallelization_test_cases():
|
||||
"""Generates test cases for the MapParallelization optimization."""
|
||||
|
||||
identity = lambda x: x
|
||||
increment = lambda x: x + 1
|
||||
|
||||
def assert_greater(x):
|
||||
assert_op = control_flow_ops.Assert(math_ops.greater(x, -1), [x])
|
||||
with ops.control_dependencies([assert_op]):
|
||||
return x
|
||||
|
||||
return (("Identity", identity, True),
|
||||
("Increment", increment, True),
|
||||
("AssertGreater", assert_greater, True))
|
||||
|
||||
|
||||
@test_util.run_all_in_graph_and_eager_modes
|
||||
class MapParallelizationTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
|
||||
@parameterized.named_parameters(*_map_parallelization_test_cases())
|
||||
def testMapParallelization(self, function, should_be_parallel):
|
||||
next_nodes = ["ParallelMap"] if should_be_parallel else ["Map"]
|
||||
def _testMapParallelization(self, function, should_optimize):
|
||||
next_nodes = ["ParallelMap"] if should_optimize else ["Map"]
|
||||
dataset = dataset_ops.Dataset.range(5).apply(
|
||||
testing.assert_next(next_nodes)).map(function)
|
||||
options = dataset_ops.Options()
|
||||
@ -63,9 +45,26 @@ class MapParallelizationTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
self.assertDatasetProduces(
|
||||
dataset, expected_output=[function(x) for x in range(5)])
|
||||
|
||||
def testMapParallelizationWithCapturedConstant(self):
|
||||
"""Tests that functions with captured constants are parallelized."""
|
||||
@combinations.generate(test_base.default_test_combinations())
|
||||
def testIdentity(self):
|
||||
self._testMapParallelization(lambda x: x, should_optimize=True)
|
||||
|
||||
@combinations.generate(test_base.default_test_combinations())
|
||||
def testIncrement(self):
|
||||
self._testMapParallelization(lambda x: x + 1, should_optimize=True)
|
||||
|
||||
@combinations.generate(test_base.default_test_combinations())
|
||||
def testAssert(self):
|
||||
|
||||
def assert_greater(x):
|
||||
assert_op = control_flow_ops.Assert(math_ops.greater(x, -1), [x])
|
||||
with ops.control_dependencies([assert_op]):
|
||||
return x
|
||||
|
||||
self._testMapParallelization(assert_greater, should_optimize=True)
|
||||
|
||||
@combinations.generate(test_base.default_test_combinations())
|
||||
def testCapturedConstant(self):
|
||||
captured_t = constant_op.constant(42, dtype=dtypes.int64)
|
||||
def fn(x):
|
||||
return x + captured_t
|
||||
@ -78,9 +77,8 @@ class MapParallelizationTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
self.assertDatasetProduces(
|
||||
dataset, expected_output=[x + 42 for x in range(5)])
|
||||
|
||||
def testMapParallelizationWithCapturedVariable(self):
|
||||
"""Tests that functions with captured variables are not parallelized."""
|
||||
|
||||
@combinations.generate(test_base.default_test_combinations())
|
||||
def testCapturedVariable(self):
|
||||
captured_t = variables.Variable(42, dtype=dtypes.int64)
|
||||
def fn(x):
|
||||
return x + captured_t
|
||||
|
@ -17,6 +17,8 @@ from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import functools
|
||||
|
||||
from absl.testing import parameterized
|
||||
import numpy as np
|
||||
|
||||
@ -26,12 +28,12 @@ from tensorflow.python.data.experimental.ops import batching
|
||||
from tensorflow.python.data.experimental.ops import testing
|
||||
from tensorflow.python.data.kernel_tests import test_base
|
||||
from tensorflow.python.data.ops import dataset_ops
|
||||
from tensorflow.python.framework import combinations
|
||||
from tensorflow.python.framework import constant_op
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import errors
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.framework import sparse_tensor
|
||||
from tensorflow.python.framework import test_util
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import bitwise_ops
|
||||
from tensorflow.python.ops import check_ops
|
||||
@ -43,21 +45,45 @@ from tensorflow.python.ops import parsing_ops
|
||||
from tensorflow.python.platform import test
|
||||
|
||||
|
||||
def _generate_unary_cwise_math_cases():
|
||||
# TODO(rachelim): Consolidate tests with pfor when APIs are somewhat shared.
|
||||
bitwise_cases = [("Invert", bitwise_ops.invert)]
|
||||
logical_cases = [("LogicalNot", math_ops.logical_not)]
|
||||
complex_cases = [
|
||||
def _generate_test_combinations(cases):
|
||||
|
||||
def reduce_fn(x, y):
|
||||
name, fn = y
|
||||
return x + combinations.combine(map_fn=combinations.NamedObject(name, fn))
|
||||
|
||||
return functools.reduce(reduce_fn, cases, [])
|
||||
|
||||
|
||||
def _unary_bitwise_test_combinations():
|
||||
cases = [("Invert", bitwise_ops.invert)]
|
||||
return _generate_test_combinations(cases)
|
||||
|
||||
|
||||
def _unary_logical_test_combinations():
|
||||
cases = [("LogicalNot", math_ops.logical_not)]
|
||||
return _generate_test_combinations(cases)
|
||||
|
||||
|
||||
def _unary_complex_test_combinations():
|
||||
cases = [
|
||||
("Angle", math_ops.angle),
|
||||
("ComplexAbs", math_ops.abs),
|
||||
("Conj", math_ops.conj),
|
||||
("Imag", math_ops.imag),
|
||||
("Real", math_ops.real),
|
||||
]
|
||||
real_cases = [
|
||||
return _generate_test_combinations(cases)
|
||||
|
||||
|
||||
def _unary_real_test_combinations():
|
||||
# acosh requires values x >= 1
|
||||
def safe_acosh(x):
|
||||
return math_ops.acosh(1 + math_ops.square(x))
|
||||
|
||||
cases = [
|
||||
("Abs", math_ops.abs),
|
||||
("Acos", math_ops.acos),
|
||||
("Acosh", lambda x: math_ops.acosh(1 + math_ops.square(x))),
|
||||
("Acosh", safe_acosh),
|
||||
("Asin", math_ops.asin),
|
||||
("Asinh", math_ops.asinh),
|
||||
("Atan", math_ops.atan),
|
||||
@ -99,45 +125,26 @@ def _generate_unary_cwise_math_cases():
|
||||
("Tan", math_ops.tan),
|
||||
("Tanh", math_ops.tanh),
|
||||
]
|
||||
random_input = np.random.rand(3, 5)
|
||||
complex_component = np.random.rand(3, 5)
|
||||
random_int = np.random.randint(0, 10, (7, 3, 5))
|
||||
|
||||
def bitwise_dataset_factory():
|
||||
return dataset_ops.Dataset.from_tensor_slices(random_int)
|
||||
|
||||
def logical_dataset_factory():
|
||||
return dataset_ops.Dataset.from_tensor_slices(random_input > 0)
|
||||
|
||||
def random_dataset_factory():
|
||||
return dataset_ops.Dataset.from_tensor_slices(random_input)
|
||||
|
||||
def complex_dataset_factory():
|
||||
return dataset_ops.Dataset.from_tensor_slices(
|
||||
math_ops.complex(random_input, complex_component))
|
||||
|
||||
case_factory_pairs = [
|
||||
(bitwise_cases, bitwise_dataset_factory),
|
||||
(logical_cases, logical_dataset_factory),
|
||||
(complex_cases, complex_dataset_factory),
|
||||
(real_cases, random_dataset_factory),
|
||||
]
|
||||
return [(case[0], case[1], factory)
|
||||
for cases, factory in case_factory_pairs
|
||||
for case in cases]
|
||||
return _generate_test_combinations(cases)
|
||||
|
||||
|
||||
def _generate_binary_cwise_math_cases():
|
||||
bitwise_cases = [("BitwiseAnd", bitwise_ops.bitwise_and),
|
||||
("BitwiseOr", bitwise_ops.bitwise_or),
|
||||
("BitwiseXor", bitwise_ops.bitwise_xor),
|
||||
("LeftShift", bitwise_ops.left_shift),
|
||||
("RightShift", bitwise_ops.right_shift)]
|
||||
def _binary_bitwise_test_combinations():
|
||||
cases = [("BitwiseAnd", bitwise_ops.bitwise_and),
|
||||
("BitwiseOr", bitwise_ops.bitwise_or),
|
||||
("BitwiseXor", bitwise_ops.bitwise_xor),
|
||||
("LeftShift", bitwise_ops.left_shift),
|
||||
("RightShift", bitwise_ops.right_shift)]
|
||||
return _generate_test_combinations(cases)
|
||||
|
||||
logical_cases = [("LogicalAnd", math_ops.logical_and),
|
||||
("LogicalOr", math_ops.logical_or)]
|
||||
|
||||
# Wrapper functions restricting the range of inputs of zeta and polygamma.
|
||||
def _binary_logical_test_combinations():
|
||||
cases = [("LogicalAnd", math_ops.logical_and),
|
||||
("LogicalOr", math_ops.logical_or)]
|
||||
return _generate_test_combinations(cases)
|
||||
|
||||
|
||||
def _binary_real_test_combinations():
|
||||
|
||||
def safe_polygamma(x, y):
|
||||
return math_ops.polygamma(
|
||||
math_ops.round(clip_ops.clip_by_value(y, 1, 10)), x * x + 1)
|
||||
@ -145,7 +152,7 @@ def _generate_binary_cwise_math_cases():
|
||||
def safe_zeta(x, y):
|
||||
return math_ops.zeta(x * x + 1, y * y)
|
||||
|
||||
real_cases = [
|
||||
cases = [
|
||||
("Add", math_ops.add),
|
||||
("AddV2", math_ops.add_v2),
|
||||
("Atan2", math_ops.atan2),
|
||||
@ -174,150 +181,10 @@ def _generate_binary_cwise_math_cases():
|
||||
("TruncateMod", math_ops.truncate_mod),
|
||||
("Zeta", safe_zeta),
|
||||
]
|
||||
|
||||
# Exercises broadcasting capabilities
|
||||
x = np.random.rand(7, 3, 5)
|
||||
y = np.random.rand(3, 5)
|
||||
|
||||
x_int = np.random.randint(0, 10, (7, 3, 5))
|
||||
y_int = np.random.randint(0, 10, (3, 5))
|
||||
|
||||
def bitwise_dataset_factory():
|
||||
return dataset_ops.Dataset.from_tensors((x_int, y_int))
|
||||
|
||||
def logical_dataset_factory():
|
||||
return dataset_ops.Dataset.from_tensors((x > 0, y > 0))
|
||||
|
||||
def random_dataset_factory():
|
||||
return dataset_ops.Dataset.from_tensors((x, y))
|
||||
|
||||
case_factory_pairs = [
|
||||
(bitwise_cases, bitwise_dataset_factory),
|
||||
(logical_cases, logical_dataset_factory),
|
||||
(real_cases, random_dataset_factory),
|
||||
]
|
||||
return [(case[0], case[1], factory)
|
||||
for cases, factory in case_factory_pairs
|
||||
for case in cases]
|
||||
return _generate_test_combinations(cases)
|
||||
|
||||
|
||||
def _generate_cwise_test_cases():
|
||||
return _generate_unary_cwise_math_cases() + _generate_binary_cwise_math_cases(
|
||||
)
|
||||
|
||||
|
||||
def _generate_csv_test_case():
|
||||
|
||||
def csv_factory():
|
||||
return dataset_ops.Dataset.from_tensor_slices(["1.0:2:a",
|
||||
"2.4:5:c"]).repeat(5)
|
||||
|
||||
def decode_csv_fn(x):
|
||||
return parsing_ops.decode_csv(
|
||||
x,
|
||||
record_defaults=[
|
||||
constant_op.constant([], dtypes.float32),
|
||||
constant_op.constant([], dtypes.int32),
|
||||
constant_op.constant([], dtypes.string)
|
||||
],
|
||||
field_delim=":")
|
||||
|
||||
return decode_csv_fn, csv_factory
|
||||
|
||||
|
||||
def _generate_parse_single_example_test_case():
|
||||
# When sparse tensors are used, map_vectorization is not
|
||||
# attempted because the output_shapes of the map dataset are not defined.
|
||||
# TODO(rachelim): Consider being more lax with checking the output_shapes of
|
||||
# the map node.
|
||||
|
||||
def parse_example_factory():
|
||||
|
||||
def _int64_feature(*values):
|
||||
return feature_pb2.Feature(int64_list=feature_pb2.Int64List(value=values))
|
||||
|
||||
def _bytes_feature(*values):
|
||||
return feature_pb2.Feature(
|
||||
bytes_list=feature_pb2.BytesList(
|
||||
value=[v.encode("utf-8") for v in values]))
|
||||
|
||||
return dataset_ops.Dataset.from_tensor_slices(
|
||||
constant_op.constant([
|
||||
example_pb2.Example(
|
||||
features=feature_pb2.Features(
|
||||
feature={
|
||||
"dense_int": _int64_feature(i),
|
||||
"dense_str": _bytes_feature(str(i)),
|
||||
})).SerializeToString() for i in range(10)
|
||||
]))
|
||||
|
||||
def parse_single_example_fn(x):
|
||||
features = {
|
||||
"dense_int": parsing_ops.FixedLenFeature((), dtypes.int64, 0),
|
||||
"dense_str": parsing_ops.FixedLenFeature((), dtypes.string, ""),
|
||||
}
|
||||
return parsing_ops.parse_single_example(x, features)
|
||||
|
||||
return parse_single_example_fn, parse_example_factory
|
||||
|
||||
|
||||
def _generate_optimization_test_cases():
|
||||
|
||||
def base_dataset_factory():
|
||||
return dataset_ops.Dataset.from_tensors(np.random.rand(10, 3)).repeat(5)
|
||||
|
||||
rand_val = np.random.rand(1, 1, 1, 1, 1, 1)
|
||||
|
||||
csv_test_case = _generate_csv_test_case()
|
||||
parse_fn, parse_base = _generate_parse_single_example_test_case()
|
||||
|
||||
def dense_output_only_parse_fn(x):
|
||||
# Since we haven't implemented a vectorizer for SerializeSparse, any
|
||||
# function with sparse outputs will only be naively vectorized.
|
||||
parse_result = parse_fn(x)
|
||||
return [
|
||||
y for y in parse_result if not isinstance(y, sparse_tensor.SparseTensor)
|
||||
]
|
||||
|
||||
def map_fn_with_cycle(x):
|
||||
c = lambda i: math_ops.less(i, 10)
|
||||
b = lambda i: math_ops.add(i, 1)
|
||||
return control_flow_ops.while_loop(c, b, [x])
|
||||
|
||||
# Misc test cases
|
||||
test_cases = [
|
||||
("Basic", lambda x: (x, x + 1), base_dataset_factory),
|
||||
("Broadcast", lambda x: x + rand_val, base_dataset_factory),
|
||||
("Cycle", map_fn_with_cycle, lambda: dataset_ops.Dataset.from_tensors(1)),
|
||||
("Const", lambda x: 2, base_dataset_factory),
|
||||
("Cast", lambda x: math_ops.cast(x, dtypes.float64),
|
||||
base_dataset_factory),
|
||||
("Reshape", lambda x: array_ops.reshape(x, (-1, 30)),
|
||||
base_dataset_factory),
|
||||
("Transpose", array_ops.transpose, base_dataset_factory),
|
||||
("Unpack", array_ops.unstack, base_dataset_factory),
|
||||
("UnpackNegativeAxis", lambda x: array_ops.unstack(x, axis=-1),
|
||||
base_dataset_factory),
|
||||
# Parsing ops
|
||||
("DecodeCSV", csv_test_case[0], csv_test_case[1]),
|
||||
("ParseSingleExample", parse_fn, parse_base),
|
||||
("ParseSingleExampleDenseOutputOnly", dense_output_only_parse_fn,
|
||||
parse_base),
|
||||
] + _generate_cwise_test_cases()
|
||||
|
||||
return [{
|
||||
"testcase_name":
|
||||
x[0] + "Parallel" if num_parallel_calls is not None else x[0],
|
||||
"map_fn":
|
||||
x[1],
|
||||
"base_dataset_factory":
|
||||
x[2],
|
||||
"num_parallel_calls":
|
||||
num_parallel_calls
|
||||
} for x in test_cases for num_parallel_calls in (None, 12)]
|
||||
|
||||
|
||||
@test_util.run_all_in_graph_and_eager_modes
|
||||
# TODO(rachelim): Consolidate tests with pfor when APIs are somewhat shared.
|
||||
class MapVectorizationTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
|
||||
def _enable_map_vectorization(self, dataset, use_choose=True):
|
||||
@ -370,13 +237,223 @@ class MapVectorizationTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
optimized = self._enable_map_vectorization(optimized)
|
||||
return unoptimized, optimized
|
||||
|
||||
@parameterized.named_parameters(_generate_optimization_test_cases())
|
||||
def testOptimization(self, map_fn, base_dataset_factory, num_parallel_calls):
|
||||
base_dataset = base_dataset_factory()
|
||||
unoptimized, optimized = self._get_test_datasets(base_dataset, map_fn,
|
||||
def _testOptimization(self, map_fn, dataset_factory, num_parallel_calls):
|
||||
dataset = dataset_factory()
|
||||
unoptimized, optimized = self._get_test_datasets(dataset, map_fn,
|
||||
num_parallel_calls)
|
||||
self.assertDatasetsEqual(unoptimized, optimized)
|
||||
|
||||
@combinations.generate(
|
||||
combinations.times(test_base.default_test_combinations(),
|
||||
combinations.combine(num_parallel_calls=[None, 12])))
|
||||
def testBasic(self, num_parallel_calls):
|
||||
data = np.random.rand(10, 3)
|
||||
dataset_factory = lambda: dataset_ops.Dataset.from_tensors(data).repeat(5)
|
||||
map_fn = lambda x: (x, x + 1)
|
||||
self._testOptimization(map_fn, dataset_factory, num_parallel_calls)
|
||||
|
||||
@combinations.generate(
|
||||
combinations.times(test_base.default_test_combinations(),
|
||||
combinations.combine(num_parallel_calls=[None, 12])))
|
||||
def testBroadcast(self, num_parallel_calls):
|
||||
data = np.random.rand(10, 3)
|
||||
dataset_factory = lambda: dataset_ops.Dataset.from_tensors(data).repeat(5)
|
||||
value = np.random.rand(1, 1, 1, 1, 1, 1)
|
||||
map_fn = lambda x: x + value
|
||||
self._testOptimization(map_fn, dataset_factory, num_parallel_calls)
|
||||
|
||||
@combinations.generate(
|
||||
combinations.times(test_base.default_test_combinations(),
|
||||
combinations.combine(num_parallel_calls=[None, 12])))
|
||||
def testCast(self, num_parallel_calls):
|
||||
data = np.random.rand(10, 3)
|
||||
dataset_factory = lambda: dataset_ops.Dataset.from_tensors(data).repeat(5)
|
||||
map_fn = lambda x: math_ops.cast(x, dtypes.float64)
|
||||
self._testOptimization(map_fn, dataset_factory, num_parallel_calls)
|
||||
|
||||
@combinations.generate(
|
||||
combinations.times(test_base.default_test_combinations(),
|
||||
combinations.combine(num_parallel_calls=[None, 12])))
|
||||
def testConst(self, num_parallel_calls):
|
||||
data = np.random.rand(10, 3)
|
||||
dataset_factory = lambda: dataset_ops.Dataset.from_tensors(data).repeat(5)
|
||||
map_fn = lambda x: 2
|
||||
self._testOptimization(map_fn, dataset_factory, num_parallel_calls)
|
||||
|
||||
@combinations.generate(
|
||||
combinations.times(test_base.default_test_combinations(),
|
||||
combinations.combine(num_parallel_calls=[None, 12])))
|
||||
def testCycle(self, num_parallel_calls):
|
||||
dataset_factory = lambda: dataset_ops.Dataset.from_tensors(1)
|
||||
|
||||
def map_fn(x):
|
||||
c = lambda i: math_ops.less(i, 10)
|
||||
b = lambda i: math_ops.add(i, 1)
|
||||
return control_flow_ops.while_loop(c, b, [x])
|
||||
|
||||
self._testOptimization(map_fn, dataset_factory, num_parallel_calls)
|
||||
|
||||
@combinations.generate(
|
||||
combinations.times(test_base.default_test_combinations(),
|
||||
combinations.combine(num_parallel_calls=[None, 12])))
|
||||
def testReshape(self, num_parallel_calls):
|
||||
data = np.random.rand(10, 3)
|
||||
dataset_factory = lambda: dataset_ops.Dataset.from_tensors(data).repeat(5)
|
||||
map_fn = lambda x: array_ops.reshape(x, (-1, 30))
|
||||
self._testOptimization(map_fn, dataset_factory, num_parallel_calls)
|
||||
|
||||
@combinations.generate(
|
||||
combinations.times(test_base.default_test_combinations(),
|
||||
combinations.combine(num_parallel_calls=[None, 12])))
|
||||
def testTranspose(self, num_parallel_calls):
|
||||
data = np.random.rand(10, 3)
|
||||
dataset_factory = lambda: dataset_ops.Dataset.from_tensors(data).repeat(5)
|
||||
map_fn = array_ops.transpose
|
||||
self._testOptimization(map_fn, dataset_factory, num_parallel_calls)
|
||||
|
||||
@combinations.generate(
|
||||
combinations.times(test_base.default_test_combinations(),
|
||||
combinations.combine(num_parallel_calls=[None, 12])))
|
||||
def testUnstack(self, num_parallel_calls):
|
||||
data = np.random.rand(10, 3)
|
||||
dataset_factory = lambda: dataset_ops.Dataset.from_tensors(data).repeat(5)
|
||||
map_fns = [array_ops.unstack, lambda x: array_ops.unstack(x, axis=-1)]
|
||||
for map_fn in map_fns:
|
||||
self._testOptimization(map_fn, dataset_factory, num_parallel_calls)
|
||||
|
||||
@combinations.generate(
|
||||
combinations.times(test_base.default_test_combinations(),
|
||||
_unary_bitwise_test_combinations(),
|
||||
combinations.combine(num_parallel_calls=[None, 12])))
|
||||
def testUnaryBitwiseOperations(self, map_fn, num_parallel_calls):
|
||||
x = np.random.randint(0, 10, (7, 3, 5))
|
||||
dataset_factory = lambda: dataset_ops.Dataset.from_tensor_slices(x)
|
||||
self._testOptimization(map_fn, dataset_factory, num_parallel_calls)
|
||||
|
||||
@combinations.generate(
|
||||
combinations.times(test_base.default_test_combinations(),
|
||||
_unary_logical_test_combinations(),
|
||||
combinations.combine(num_parallel_calls=[None, 12])))
|
||||
def testUnaryLogicalOperations(self, map_fn, num_parallel_calls):
|
||||
x = np.random.rand(3, 5)
|
||||
dataset_factory = lambda: dataset_ops.Dataset.from_tensor_slices(x > 0)
|
||||
self._testOptimization(map_fn, dataset_factory, num_parallel_calls)
|
||||
|
||||
@combinations.generate(
|
||||
combinations.times(test_base.default_test_combinations(),
|
||||
_unary_complex_test_combinations(),
|
||||
combinations.combine(num_parallel_calls=[None, 12])))
|
||||
def testUnaryComplexOperations(self, map_fn, num_parallel_calls):
|
||||
x = math_ops.complex(np.random.rand(3, 5), np.random.rand(3, 5))
|
||||
dataset_factory = lambda: dataset_ops.Dataset.from_tensor_slices(x)
|
||||
self._testOptimization(map_fn, dataset_factory, num_parallel_calls)
|
||||
|
||||
@combinations.generate(
|
||||
combinations.times(test_base.default_test_combinations(),
|
||||
_unary_real_test_combinations(),
|
||||
combinations.combine(num_parallel_calls=[None, 12])))
|
||||
def testUnaryRealOperations(self, map_fn, num_parallel_calls):
|
||||
x = np.random.rand(3, 5)
|
||||
dataset_factory = lambda: dataset_ops.Dataset.from_tensor_slices(x)
|
||||
self._testOptimization(map_fn, dataset_factory, num_parallel_calls)
|
||||
|
||||
@combinations.generate(
|
||||
combinations.times(test_base.default_test_combinations(),
|
||||
_binary_bitwise_test_combinations(),
|
||||
combinations.combine(num_parallel_calls=[None, 12])))
|
||||
def testBinaryBitwiseOperations(self, map_fn, num_parallel_calls):
|
||||
x = np.random.randint(0, 10, (7, 3, 5))
|
||||
y = np.random.randint(0, 10, (3, 5))
|
||||
dataset_factory = lambda: dataset_ops.Dataset.from_tensors((x, y))
|
||||
self._testOptimization(map_fn, dataset_factory, num_parallel_calls)
|
||||
|
||||
@combinations.generate(
|
||||
combinations.times(test_base.default_test_combinations(),
|
||||
_binary_logical_test_combinations(),
|
||||
combinations.combine(num_parallel_calls=[None, 12])))
|
||||
def testBinaryLogicalOperations(self, map_fn, num_parallel_calls):
|
||||
x = np.random.rand(7, 3, 5)
|
||||
y = np.random.rand(3, 5)
|
||||
dataset_factory = lambda: dataset_ops.Dataset.from_tensors((x > 0, y > 0))
|
||||
self._testOptimization(map_fn, dataset_factory, num_parallel_calls)
|
||||
|
||||
@combinations.generate(
|
||||
combinations.times(test_base.default_test_combinations(),
|
||||
_binary_real_test_combinations(),
|
||||
combinations.combine(num_parallel_calls=[None, 12])))
|
||||
def testBinaryRealOperations(self, map_fn, num_parallel_calls):
|
||||
x = np.random.rand(7, 3, 5)
|
||||
y = np.random.rand(3, 5)
|
||||
dataset_factory = lambda: dataset_ops.Dataset.from_tensors((x, y))
|
||||
self._testOptimization(map_fn, dataset_factory, num_parallel_calls)
|
||||
|
||||
@combinations.generate(
|
||||
combinations.times(test_base.default_test_combinations(),
|
||||
combinations.combine(num_parallel_calls=[None, 12])))
|
||||
def testDecodeCsv(self, num_parallel_calls):
|
||||
|
||||
def dataset_factory():
|
||||
return dataset_ops.Dataset.from_tensor_slices(["1.0:2:a",
|
||||
"2.4:5:c"]).repeat(5)
|
||||
|
||||
def decode_csv_fn(x):
|
||||
return parsing_ops.decode_csv(
|
||||
x,
|
||||
record_defaults=[
|
||||
constant_op.constant([], dtypes.float32),
|
||||
constant_op.constant([], dtypes.int32),
|
||||
constant_op.constant([], dtypes.string)
|
||||
],
|
||||
field_delim=":")
|
||||
|
||||
self._testOptimization(decode_csv_fn, dataset_factory, num_parallel_calls)
|
||||
|
||||
@combinations.generate(
|
||||
combinations.times(test_base.default_test_combinations(),
|
||||
combinations.combine(num_parallel_calls=[None, 12])))
|
||||
def testParseSingleExample(self, num_parallel_calls):
|
||||
|
||||
def dataset_factory():
|
||||
|
||||
def _int64_feature(*values):
|
||||
return feature_pb2.Feature(
|
||||
int64_list=feature_pb2.Int64List(value=values))
|
||||
|
||||
def _bytes_feature(*values):
|
||||
return feature_pb2.Feature(
|
||||
bytes_list=feature_pb2.BytesList(
|
||||
value=[v.encode("utf-8") for v in values]))
|
||||
|
||||
# pylint:disable=g-complex-comprehension
|
||||
return dataset_ops.Dataset.from_tensor_slices(
|
||||
constant_op.constant([
|
||||
example_pb2.Example(
|
||||
features=feature_pb2.Features(
|
||||
feature={
|
||||
"dense_int": _int64_feature(i),
|
||||
"dense_str": _bytes_feature(str(i)),
|
||||
})).SerializeToString() for i in range(10)
|
||||
]))
|
||||
|
||||
def parse_fn(x):
|
||||
features = {
|
||||
"dense_int": parsing_ops.FixedLenFeature((), dtypes.int64, 0),
|
||||
"dense_str": parsing_ops.FixedLenFeature((), dtypes.string, ""),
|
||||
}
|
||||
return parsing_ops.parse_single_example(x, features)
|
||||
|
||||
def dense_only_parse_fn(x):
|
||||
return [
|
||||
y for y in parse_fn(x)
|
||||
if not isinstance(y, sparse_tensor.SparseTensor)
|
||||
]
|
||||
|
||||
map_fns = [parse_fn, dense_only_parse_fn]
|
||||
|
||||
for map_fn in map_fns:
|
||||
self._testOptimization(map_fn, dataset_factory, num_parallel_calls)
|
||||
|
||||
@combinations.generate(test_base.default_test_combinations())
|
||||
def testOptimizationBadMapFn(self):
|
||||
# Test map functions that give an error
|
||||
def map_fn(x):
|
||||
@ -391,6 +468,7 @@ class MapVectorizationTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
nxt = dataset_ops.make_one_shot_iterator(optimized).get_next()
|
||||
self.evaluate(nxt)
|
||||
|
||||
@combinations.generate(test_base.default_test_combinations())
|
||||
def testOptimizationWithCapturedInputs(self):
|
||||
# Tests that vectorization works with captured inputs.
|
||||
y = constant_op.constant(1, shape=(2,))
|
||||
@ -405,6 +483,7 @@ class MapVectorizationTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
base_dataset, map_fn, expect_optimized=True)
|
||||
self.assertDatasetsEqual(optimized, unoptimized)
|
||||
|
||||
@combinations.generate(test_base.default_test_combinations())
|
||||
def testOptimizationWithMapAndBatchFusion(self):
|
||||
# Tests that vectorization works on fused map and batch.
|
||||
def map_fn(x):
|
||||
@ -425,12 +504,11 @@ class MapVectorizationTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
optimized = self._enable_map_vectorization(optimized)
|
||||
self.assertDatasetsEqual(optimized, unoptimized)
|
||||
|
||||
@parameterized.named_parameters(
|
||||
("1", True, True),
|
||||
("2", True, False),
|
||||
("3", False, True),
|
||||
("4", False, False),
|
||||
)
|
||||
@combinations.generate(
|
||||
combinations.times(
|
||||
test_base.default_test_combinations(),
|
||||
combinations.combine(
|
||||
fuse_first=[True, False], fuse_second=[True, False])))
|
||||
def testOptimizationWithChainedMapAndBatch(self, fuse_first, fuse_second):
|
||||
# Tests that vectorization works on chained map and batch functions.
|
||||
def map_fn(x):
|
||||
@ -474,6 +552,7 @@ class MapVectorizationTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
optimized = self._enable_map_vectorization(optimized)
|
||||
self.assertDatasetsEqual(optimized, unoptimized)
|
||||
|
||||
@combinations.generate(test_base.default_test_combinations())
|
||||
def testOptimizationIgnoreStateful(self):
|
||||
|
||||
def map_fn(x):
|
||||
@ -488,6 +567,7 @@ class MapVectorizationTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
get_next = self.getNext(dataset)
|
||||
self.evaluate(get_next())
|
||||
|
||||
@combinations.generate(test_base.default_test_combinations())
|
||||
def testOptimizationIgnoreRagged(self):
|
||||
# Make sure we ignore inputs that might not be uniformly sized
|
||||
def map_fn(x):
|
||||
@ -499,6 +579,7 @@ class MapVectorizationTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
base_dataset, map_fn, expect_optimized=False)
|
||||
self.assertDatasetsEqual(unoptimized, optimized)
|
||||
|
||||
@combinations.generate(test_base.default_test_combinations())
|
||||
def testOptimizationIgnoreRaggedMap(self):
|
||||
# Don't optimize when the output of the map fn shapes are unknown.
|
||||
def map_fn(x):
|
||||
@ -512,6 +593,7 @@ class MapVectorizationTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
get_next = self.getNext(dataset)
|
||||
self.evaluate(get_next())
|
||||
|
||||
@combinations.generate(test_base.default_test_combinations())
|
||||
def testOptimizationWithUnknownBatchShape(self):
|
||||
tensor = sparse_tensor.SparseTensor(
|
||||
indices=[[0, 0], [1, 2]], values=[1, 2], dense_shape=[3, 4])
|
||||
@ -526,6 +608,7 @@ class MapVectorizationTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
optimized = self._enable_map_vectorization(unoptimized)
|
||||
self.assertDatasetsEqual(unoptimized, optimized)
|
||||
|
||||
@combinations.generate(test_base.default_test_combinations())
|
||||
def testOptimizationWithSparseTensor(self):
|
||||
base_dataset = dataset_ops.Dataset.from_tensors(0)
|
||||
|
||||
@ -542,6 +625,7 @@ class MapVectorizationTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
optimized = self._enable_map_vectorization(unoptimized)
|
||||
self.assertDatasetsEqual(unoptimized, optimized)
|
||||
|
||||
@combinations.generate(test_base.default_test_combinations())
|
||||
def testOptimizationWithPrefetch(self):
|
||||
dataset = dataset_ops.Dataset.range(10)
|
||||
dataset = dataset.map(lambda x: x)
|
||||
@ -550,6 +634,7 @@ class MapVectorizationTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
dataset = self._enable_map_vectorization(dataset)
|
||||
self.assertDatasetProduces(dataset, [list(range(10))])
|
||||
|
||||
@combinations.generate(test_base.default_test_combinations())
|
||||
def testOptimizationWithoutChooseFastest(self):
|
||||
dataset = dataset_ops.Dataset.range(10)
|
||||
dataset = dataset.map(lambda x: x**2)
|
||||
|
@ -17,19 +17,21 @@ from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from absl.testing import parameterized
|
||||
|
||||
from tensorflow.python.data.experimental.ops import testing
|
||||
from tensorflow.python.data.kernel_tests import test_base
|
||||
from tensorflow.python.data.ops import dataset_ops
|
||||
from tensorflow.python.framework import combinations
|
||||
from tensorflow.python.framework import constant_op
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import test_util
|
||||
from tensorflow.python.ops import math_ops
|
||||
from tensorflow.python.platform import test
|
||||
|
||||
|
||||
@test_util.run_all_in_graph_and_eager_modes
|
||||
class NoopEliminationTest(test_base.DatasetTestBase):
|
||||
class NoopEliminationTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
|
||||
@combinations.generate(test_base.default_test_combinations())
|
||||
def testNoopElimination(self):
|
||||
a = constant_op.constant(1, dtype=dtypes.int64)
|
||||
b = constant_op.constant(2, dtype=dtypes.int64)
|
||||
|
@ -17,19 +17,22 @@ from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from absl.testing import parameterized
|
||||
|
||||
from tensorflow.python import tf2
|
||||
from tensorflow.python.data.experimental.ops import testing
|
||||
from tensorflow.python.data.kernel_tests import test_base
|
||||
from tensorflow.python.data.ops import dataset_ops
|
||||
from tensorflow.python.eager import context
|
||||
from tensorflow.python.framework import combinations
|
||||
from tensorflow.python.framework import errors
|
||||
from tensorflow.python.framework import test_util
|
||||
from tensorflow.python.platform import test
|
||||
|
||||
|
||||
@test_util.run_all_in_graph_and_eager_modes
|
||||
class ShuffleAndRepeatFusionTest(test_base.DatasetTestBase):
|
||||
class ShuffleAndRepeatFusionTest(test_base.DatasetTestBase,
|
||||
parameterized.TestCase):
|
||||
|
||||
@combinations.generate(test_base.default_test_combinations())
|
||||
def testShuffleAndRepeatFusion(self):
|
||||
if tf2.enabled() and context.executing_eagerly():
|
||||
expected = "Shuffle"
|
||||
|
Loading…
Reference in New Issue
Block a user