[tf.data] Migrating static optimization tests to use TF combinations.

PiperOrigin-RevId: 283145321
Change-Id: Ic12159919f9e77624986ac5ed3753276f5179ce2
This commit is contained in:
Jiri Simsa 2019-11-30 07:43:54 -08:00 committed by TensorFlower Gardener
parent dbceeb4674
commit d59e6c0bfc
16 changed files with 521 additions and 399 deletions

View File

@ -23,8 +23,8 @@ from tensorflow.python.data.experimental.kernel_tests import reader_dataset_ops_
from tensorflow.python.data.experimental.ops import distribute
from tensorflow.python.data.experimental.ops import distribute_options
from tensorflow.python.data.experimental.ops import interleave_ops
from tensorflow.python.data.experimental.ops import testing
from tensorflow.python.data.experimental.ops import readers
from tensorflow.python.data.experimental.ops import testing
from tensorflow.python.data.experimental.ops import unique
from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops

View File

@ -100,6 +100,24 @@ tf_py_test(
],
)
# Python test target built from inject_prefetch_test.py.
tf_py_test(
name = "inject_prefetch_test",
size = "small",
srcs = ["inject_prefetch_test.py"],
additional_deps = [
"//tensorflow/python:client_testlib",
"//tensorflow/python:errors",
"//tensorflow/python/data/experimental/ops:testing",
"//tensorflow/python/data/kernel_tests:test_base",
"//tensorflow/python/data/ops:dataset_ops",
],
# NOTE(review): these tags presumably exclude the target from OSS,
# pip-package and Windows test runs -- confirm against the test infra.
tags = [
"no_oss",
"no_pip",
"no_windows",
],
)
tf_py_test(
name = "latency_all_edges_test",
size = "small",

View File

@ -23,18 +23,18 @@ from tensorflow.python.data.experimental.ops import optimization
from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.eager import context
from tensorflow.python.framework import combinations
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import test_util
from tensorflow.python.ops import math_ops
from tensorflow.python.platform import test
@test_util.run_all_in_graph_and_eager_modes
class ChooseFastestBranchDatasetTest(test_base.DatasetTestBase,
parameterized.TestCase):
@combinations.generate(test_base.default_test_combinations())
def testSimple(self):
dataset = dataset_ops.Dataset.from_tensor_slices([0, 1, 2, 3, 4])
@ -49,6 +49,7 @@ class ChooseFastestBranchDatasetTest(test_base.DatasetTestBase,
expected_output=[0, 1, 2, 3, 4],
expected_shapes=dataset_ops.get_legacy_output_shapes(dataset))
@combinations.generate(test_base.default_test_combinations())
def testCaptureSimple(self):
dataset = dataset_ops.Dataset.range(10)
@ -67,6 +68,7 @@ class ChooseFastestBranchDatasetTest(test_base.DatasetTestBase,
self.assertDatasetProduces(
choose_fastest, expected_output=list(range(1, 11)))
@combinations.generate(test_base.default_test_combinations())
def testDifferentFunctions(self):
dataset = dataset_ops.Dataset.range(100)
@ -83,6 +85,7 @@ class ChooseFastestBranchDatasetTest(test_base.DatasetTestBase,
choose_fastest,
expected_output=[list(range(10 * x, 10 * x + 10)) for x in range(10)])
@combinations.generate(test_base.default_test_combinations())
def testWithRepeatBeforeAndAfter(self):
dataset = dataset_ops.Dataset.from_tensors(0).repeat(10)
@ -99,6 +102,7 @@ class ChooseFastestBranchDatasetTest(test_base.DatasetTestBase,
self.assertDatasetProduces(
choose_fastest, expected_output=[[0] * 10 for _ in range(10)])
@combinations.generate(test_base.default_test_combinations())
def testWithPrefetch(self):
"""Should maintain ordering even if the branches do prefetching."""
dataset = dataset_ops.Dataset.range(100)
@ -114,6 +118,7 @@ class ChooseFastestBranchDatasetTest(test_base.DatasetTestBase,
self.assertDatasetProduces(choose_fastest, expected_output=list(range(100)))
@combinations.generate(test_base.default_test_combinations())
def testWithMoreOutputThanInput(self):
dataset = dataset_ops.Dataset.from_tensors(0).repeat(1000).batch(100)
@ -128,6 +133,7 @@ class ChooseFastestBranchDatasetTest(test_base.DatasetTestBase,
self.assertDatasetProduces(choose_fastest, expected_output=[0] * 1000)
@combinations.generate(test_base.default_test_combinations())
def testWithBadNumElements(self):
dataset = dataset_ops.Dataset.from_tensors(0).repeat(1000).batch(100)
@ -153,6 +159,7 @@ class ChooseFastestBranchDatasetTest(test_base.DatasetTestBase,
choose_fastest,
expected_error=(errors.InvalidArgumentError, expected_error_msg))
@combinations.generate(test_base.default_test_combinations())
def testErrorWithRepeat(self):
dataset = dataset_ops.Dataset.from_tensors(0)

View File

@ -23,15 +23,15 @@ from tensorflow.python.data.experimental.ops import optimization
from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.eager import context
from tensorflow.python.framework import combinations
from tensorflow.python.framework import errors
from tensorflow.python.framework import test_util
from tensorflow.python.platform import test
@test_util.run_all_in_graph_and_eager_modes
class ChooseFastestDatasetTest(test_base.DatasetTestBase,
parameterized.TestCase):
@combinations.generate(test_base.default_test_combinations())
def testChooseFastestSimple(self):
dataset = dataset_ops.Dataset.from_tensor_slices([0, 1, 2, 3, 4])
merge = optimization._ChooseFastestDataset([dataset, dataset])
@ -40,6 +40,7 @@ class ChooseFastestDatasetTest(test_base.DatasetTestBase,
expected_output=[0, 1, 2, 3, 4],
expected_shapes=dataset_ops.get_legacy_output_shapes(dataset))
@combinations.generate(test_base.default_test_combinations())
def testChooseFastestManyInputs(self):
dataset = dataset_ops.Dataset.from_tensor_slices([0, 1, 2, 3, 4])
merge = optimization._ChooseFastestDataset([dataset for _ in range(5)])
@ -48,6 +49,7 @@ class ChooseFastestDatasetTest(test_base.DatasetTestBase,
expected_output=[0, 1, 2, 3, 4],
expected_shapes=dataset_ops.get_legacy_output_shapes(dataset))
@combinations.generate(test_base.default_test_combinations())
def testChooseFastest(self):
dataset = dataset_ops.Dataset.range(600)
f = lambda x: 2 * x
@ -61,11 +63,25 @@ class ChooseFastestDatasetTest(test_base.DatasetTestBase,
],
expected_shapes=dataset_ops.get_legacy_output_shapes(dataset_a))
@parameterized.named_parameters(
("Shapes", [0], [[1, 2, 3]], "must have compatible output shapes."),
("Types", [0], [0.0], "must have the same output types."),
("NumComponents", [0], ([0], [1]), "must have the same output types."),
("Cardinality", [1, 2, 3], [1], "must have compatible cardinalities."))
@combinations.generate(
combinations.times(
test_base.default_test_combinations(),
combinations.combine(
slices_a=[[0]],
slices_b=[[[1, 2, 3]]],
error_msg="must have compatible output shapes.") +
combinations.combine(
slices_a=[[0]],
slices_b=[[0.0]],
error_msg="must have the same output types.") +
combinations.combine(
slices_a=[[0]],
slices_b=[([0], [1])],
error_msg="must have the same output types.") +
combinations.combine(
slices_a=[[1, 2, 3]],
slices_b=[[0]],
error_msg="must have compatible cardinalities.")))
def testChooseFastestErrorWithIncompatibleInput(self, slices_a, slices_b,
error_msg):
dataset_a = dataset_ops.Dataset.from_tensor_slices(slices_a)

View File

@ -22,47 +22,16 @@ from absl.testing import parameterized
from tensorflow.python.data.experimental.ops import testing
from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import combinations
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import test_util
from tensorflow.python.ops import math_ops
from tensorflow.python.platform import test
def _filter_fusion_test_cases():
"""Generates test cases for the FilterFusion optimization."""
take_all = lambda x: constant_op.constant(True)
is_zero = lambda x: math_ops.equal(x, 0)
greater = lambda x: math_ops.greater(x + 5, 0)
tests = []
filters = [take_all, is_zero, greater]
identity = lambda x: x
for x, predicate_1 in enumerate(filters):
for y, predicate_2 in enumerate(filters):
tests.append(("Mixed{}{}".format(x, y), identity,
[predicate_1, predicate_2]))
for z, predicate_3 in enumerate(filters):
tests.append(("Mixed{}{}{}".format(x, y, z), identity,
[predicate_1, predicate_2, predicate_3]))
take_all_multiple = lambda x, y: constant_op.constant(True)
# Multi output
tests.append(("Multi1", lambda x: (x, x),
[take_all_multiple, take_all_multiple]))
tests.append(("Multi2", lambda x: (x, 2), [
take_all_multiple,
lambda x, y: math_ops.equal(x * math_ops.cast(y, dtypes.int64), 0)
]))
return tuple(tests)
@test_util.run_all_in_graph_and_eager_modes
class FilterFusionTest(test_base.DatasetTestBase, parameterized.TestCase):
@parameterized.named_parameters(*_filter_fusion_test_cases())
def testFilterFusion(self, map_function, predicates):
def _testFilterFusion(self, map_function, predicates):
dataset = dataset_ops.Dataset.range(5).apply(
testing.assert_next(["Map", "Filter",
"MemoryCacheImpl"])).map(map_function)
@ -91,6 +60,26 @@ class FilterFusionTest(test_base.DatasetTestBase, parameterized.TestCase):
expected_output.append(r)
self.assertDatasetProduces(dataset, expected_output=expected_output)
@combinations.generate(test_base.default_test_combinations())
def testFilterFusionScalar(self):
  # Runs the fusion check for every ordered pair, and every ordered triple,
  # of the scalar predicates below, always mapping with the identity.
  identity = lambda x: x
  keep_everything = lambda x: constant_op.constant(True)
  equals_zero = lambda x: math_ops.equal(x, 0)
  above_minus_five = lambda x: math_ops.greater(x + 5, 0)
  filters = [keep_everything, equals_zero, above_minus_five]
  # Loop variables are named so they cannot be confused with the lambda
  # parameter `x`.
  for first in filters:
    for second in filters:
      self._testFilterFusion(identity, [first, second])
      for third in filters:
        self._testFilterFusion(identity, [first, second, third])
@combinations.generate(test_base.default_test_combinations())
def testFilterFusionTuple(self):
  # Same check as the scalar variant, but the map function emits a
  # two-component element, so every predicate takes two arguments.
  accept_all = lambda x, y: constant_op.constant(True)
  product_is_zero = (
      lambda x, y: math_ops.equal(x * math_ops.cast(y, dtypes.int64), 0))
  duplicate = lambda x: (x, x)
  pair_with_two = lambda x: (x, 2)
  cases = [
      (duplicate, [accept_all, accept_all]),
      (pair_with_two, [accept_all, product_is_zero]),
  ]
  for map_function, predicates in cases:
    self._testFilterFusion(map_function, predicates)
if __name__ == "__main__":
test.main()

View File

@ -17,17 +17,20 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from absl.testing import parameterized
from tensorflow.python.data.experimental.ops import testing
from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import test_util
from tensorflow.python.framework import combinations
from tensorflow.python.ops import random_ops
from tensorflow.python.platform import test
@test_util.run_all_in_graph_and_eager_modes
class FilterWithRandomUniformFusionTest(test_base.DatasetTestBase):
class FilterWithRandomUniformFusionTest(test_base.DatasetTestBase,
parameterized.TestCase):
@combinations.generate(test_base.default_test_combinations())
def testFilterWithRandomUniformFusion(self):
dataset = dataset_ops.Dataset.range(10000000).apply(
testing.assert_next(["Sampling"]))

View File

@ -22,44 +22,17 @@ from absl.testing import parameterized
from tensorflow.python.data.experimental.ops import testing
from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import combinations
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import random_ops
from tensorflow.python.platform import test
def _hoist_random_uniform_test_cases():
"""Generates test cases for the HoistRandomUniform optimization."""
plus_one = lambda x: x + 1
def random(_):
return random_ops.random_uniform([],
minval=1,
maxval=10,
dtype=dtypes.float32,
seed=42)
def random_with_assert(x):
y = random(x)
assert_op = control_flow_ops.Assert(math_ops.greater_equal(y, 1), [y])
with ops.control_dependencies([assert_op]):
return y
twice_random = lambda x: (random(x) + random(x)) / 2.
tests = [("PlusOne", plus_one, False), ("RandomUniform", random, True),
("RandomWithAssert", random_with_assert, True),
("TwiceRandom", twice_random, False)]
return tuple(tests)
@test_util.run_all_in_graph_and_eager_modes
class HoistRandomUniformTest(test_base.DatasetTestBase, parameterized.TestCase):
def _testDataset(self, dataset):
@ -78,11 +51,10 @@ class HoistRandomUniformTest(test_base.DatasetTestBase, parameterized.TestCase):
with self.assertRaises(errors.OutOfRangeError):
self.evaluate(get_next())
@parameterized.named_parameters(*_hoist_random_uniform_test_cases())
def testHoisting(self, function, will_optimize):
def _testHoistFunction(self, function, should_optimize):
dataset = dataset_ops.Dataset.range(5).apply(
testing.assert_next(
["Zip[0]", "Map"] if will_optimize else ["Map"])).map(function)
["Zip[0]", "Map"] if should_optimize else ["Map"])).map(function)
options = dataset_ops.Options()
options.experimental_optimization.apply_default_optimizations = False
@ -90,6 +62,32 @@ class HoistRandomUniformTest(test_base.DatasetTestBase, parameterized.TestCase):
dataset = dataset.with_options(options)
self._testDataset(dataset)
@combinations.generate(test_base.default_test_combinations())
def testNoRandom(self):
  # A map function without any random op: the helper is told not to expect
  # the hoisting rewrite.
  plus_one = lambda x: x + 1
  self._testHoistFunction(plus_one, should_optimize=False)
@combinations.generate(test_base.default_test_combinations())
def testRandom(self):
  # Three map functions built around random_uniform, each paired with
  # whether the helper should expect the hoisting rewrite to apply.

  def random(_):
    return random_ops.random_uniform(
        [], minval=1, maxval=10, dtype=dtypes.float32, seed=42)

  def random_with_assert(x):
    value = random(x)
    in_range = math_ops.greater_equal(value, 1)
    check = control_flow_ops.Assert(in_range, [value])
    with ops.control_dependencies([check]):
      return value

  twice_random = lambda x: (random(x) + random(x)) / 2

  cases = [
      (random, True),
      (random_with_assert, True),
      (twice_random, False),
  ]
  for function, expect_rewrite in cases:
    self._testHoistFunction(function, should_optimize=expect_rewrite)
@combinations.generate(test_base.default_test_combinations())
def testCapturedInputs(self):
a = constant_op.constant(1, dtype=dtypes.float32)
b = constant_op.constant(0, dtype=dtypes.float32)

View File

@ -17,35 +17,38 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.python.data.experimental.ops import optimization
from absl.testing import parameterized
from tensorflow.python.data.experimental.ops import testing
from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import test_util
from tensorflow.python.framework import combinations
from tensorflow.python.platform import test
@test_util.run_all_in_graph_and_eager_modes
class InjectPrefetchTest(test_base.DatasetTestBase):
class InjectPrefetchTest(test_base.DatasetTestBase, parameterized.TestCase):
def _enable_autotune_buffers(self, dataset):
  """Returns `dataset` with the `autotune_buffers` optimization option set."""
  autotune_options = dataset_ops.Options()
  autotune_options.experimental_optimization.autotune_buffers = True
  configured = dataset.with_options(autotune_options)
  return configured
@combinations.generate(test_base.default_test_combinations())
def testParallelMap(self):
dataset = dataset_ops.Dataset.range(100)
dataset = dataset.apply(
optimization.assert_next(["ParallelMap", "Prefetch", "FiniteTake"]))
testing.assert_next(["ParallelMap", "Prefetch", "FiniteTake"]))
dataset = dataset.map(
lambda x: x + 1, num_parallel_calls=dataset_ops.AUTOTUNE)
dataset = dataset.take(50)
dataset = self._enable_autotune_buffers(dataset)
self.assertDatasetProduces(dataset, range(1, 51))
@combinations.generate(test_base.default_test_combinations())
def testMapAndBatch(self):
dataset = dataset_ops.Dataset.range(100)
dataset = dataset.apply(
optimization.assert_next(["MapAndBatch", "Prefetch", "FiniteTake"]))
testing.assert_next(["MapAndBatch", "Prefetch", "FiniteTake"]))
dataset = dataset.map(
lambda x: x + 1, num_parallel_calls=dataset_ops.AUTOTUNE)
dataset = dataset.batch(10)
@ -54,10 +57,11 @@ class InjectPrefetchTest(test_base.DatasetTestBase):
self.assertDatasetProduces(
dataset, [list(range(i + 1, i + 11)) for i in range(0, 50, 10)])
@combinations.generate(test_base.default_test_combinations())
def testParallelInterleaveV2(self):
dataset = dataset_ops.Dataset.range(100)
dataset = dataset.apply(
optimization.assert_next(
testing.assert_next(
["ParallelInterleaveV2", "Prefetch", "FiniteTake"]))
dataset = dataset.interleave(
lambda x: dataset_ops.Dataset.from_tensors(x + 1),
@ -66,10 +70,11 @@ class InjectPrefetchTest(test_base.DatasetTestBase):
dataset = self._enable_autotune_buffers(dataset)
self.assertDatasetProduces(dataset, range(1, 51))
@combinations.generate(test_base.default_test_combinations())
def testChainedParallelDatasets(self):
dataset = dataset_ops.Dataset.range(100)
dataset = dataset.apply(
optimization.assert_next([
testing.assert_next([
"ParallelMap", "Prefetch", "ParallelInterleaveV2", "Prefetch",
"MapAndBatch", "Prefetch", "FiniteTake"
]))
@ -85,9 +90,10 @@ class InjectPrefetchTest(test_base.DatasetTestBase):
dataset = self._enable_autotune_buffers(dataset)
self.assertDatasetProduces(dataset, [[i] for i in range(3, 53)])
@combinations.generate(test_base.default_test_combinations())
def testNoRegularMap(self):
dataset = dataset_ops.Dataset.range(100)
dataset = dataset.apply(optimization.assert_next(["Map", "FiniteTake"]))
dataset = dataset.apply(testing.assert_next(["Map", "FiniteTake"]))
dataset = dataset.map(lambda x: x + 1).take(50)
dataset = self._enable_autotune_buffers(dataset)
self.assertDatasetProduces(dataset, range(1, 51))

View File

@ -17,15 +17,22 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from absl.testing import parameterized
from tensorflow.python.data.experimental.kernel_tests import stats_dataset_test_base
from tensorflow.python.data.experimental.ops import testing
from tensorflow.python.data.experimental.ops import stats_aggregator
from tensorflow.python.data.experimental.ops import testing
from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import combinations
from tensorflow.python.platform import test
class LatencyAllEdgesTest(stats_dataset_test_base.StatsDatasetTestBase):
class LatencyAllEdgesTest(stats_dataset_test_base.StatsDatasetTestBase,
parameterized.TestCase):
# TODO(jsimsa): Investigate why graph-mode tests are failing.
@combinations.generate(test_base.eager_only_combinations())
def testLatencyStatsOptimization(self):
aggregator = stats_aggregator.StatsAggregator()
dataset = dataset_ops.Dataset.from_tensors(1).apply(

View File

@ -17,16 +17,18 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from absl.testing import parameterized
from tensorflow.python.data.experimental.ops import testing
from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import test_util
from tensorflow.python.framework import combinations
from tensorflow.python.platform import test
@test_util.run_all_in_graph_and_eager_modes
class MapAndBatchFusionTest(test_base.DatasetTestBase):
class MapAndBatchFusionTest(test_base.DatasetTestBase, parameterized.TestCase):
@combinations.generate(test_base.default_test_combinations())
def testMapAndBatchFusion(self):
dataset = dataset_ops.Dataset.range(10).apply(
testing.assert_next(

View File

@ -22,50 +22,16 @@ from absl.testing import parameterized
from tensorflow.python.data.experimental.ops import testing
from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import combinations
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import test_util
from tensorflow.python.ops import math_ops
from tensorflow.python.platform import test
def _map_and_filter_fusion_test_cases():
"""Generates test cases for the MapAndFilterFusion optimization."""
identity = lambda x: x
increment = lambda x: x + 1
minus_five = lambda x: x - 5
def increment_and_square(x):
y = x + 1
return y * y
take_all = lambda x: constant_op.constant(True)
is_zero = lambda x: math_ops.equal(x, 0)
is_odd = lambda x: math_ops.equal(x % 2, 0)
greater = lambda x: math_ops.greater(x + 5, 0)
functions = [identity, increment, minus_five, increment_and_square]
filters = [take_all, is_zero, is_odd, greater]
tests = []
for x, fun in enumerate(functions):
for y, predicate in enumerate(filters):
tests.append(("Mixed{}{}".format(x, y), fun, predicate))
# Multi output
tests.append(("Multi1", lambda x: (x, x),
lambda x, y: constant_op.constant(True)))
tests.append(
("Multi2", lambda x: (x, 2),
lambda x, y: math_ops.equal(x * math_ops.cast(y, dtypes.int64), 0)))
return tuple(tests)
@test_util.run_all_in_graph_and_eager_modes
class MapAndFilterFusionTest(test_base.DatasetTestBase, parameterized.TestCase):
def _testMapAndFilter(self, dataset, function, predicate):
def _testDataset(self, dataset, function, predicate):
expected_output = []
for x in range(10):
r = function(x)
@ -77,8 +43,7 @@ class MapAndFilterFusionTest(test_base.DatasetTestBase, parameterized.TestCase):
expected_output.append(r)
self.assertDatasetProduces(dataset, expected_output=expected_output)
@parameterized.named_parameters(*_map_and_filter_fusion_test_cases())
def testMapFilterFusion(self, function, predicate):
def _testMapAndFilterFusion(self, function, predicate):
dataset = dataset_ops.Dataset.range(10).apply(
testing.assert_next(["Map", "Filter",
"Map"])).map(function).filter(predicate)
@ -86,8 +51,44 @@ class MapAndFilterFusionTest(test_base.DatasetTestBase, parameterized.TestCase):
options.experimental_optimization.apply_default_optimizations = False
options.experimental_optimization.map_and_filter_fusion = True
dataset = dataset.with_options(options)
self._testMapAndFilter(dataset, function, predicate)
self._testDataset(dataset, function, predicate)
@combinations.generate(test_base.default_test_combinations())
def testMapAndFilterFusionScalar(self):
  # Checks map+filter fusion for every (map function, predicate) pairing of
  # the scalar functions and predicates defined below.
  identity = lambda x: x
  increment = lambda x: x + 1
  minus_five = lambda x: x - 5

  def increment_and_square(x):
    y = x + 1
    return y * y

  functions = [identity, increment, minus_five, increment_and_square]

  take_all = lambda x: constant_op.constant(True)
  is_zero = lambda x: math_ops.equal(x, 0)
  # Previously bound to the name `is_odd`, although the predicate accepts
  # values with x % 2 == 0, i.e. even values.
  is_even = lambda x: math_ops.equal(x % 2, 0)
  greater = lambda x: math_ops.greater(x + 5, 0)
  predicates = [take_all, is_zero, is_even, greater]

  for function in functions:
    for predicate in predicates:
      self._testMapAndFilterFusion(function, predicate)
@combinations.generate(test_base.default_test_combinations())
def testMapAndFilterFusionTuple(self):
  # Tuple-valued variant: each map function emits two components and each
  # predicate takes both of them.
  map_functions = [
      lambda x: (x, x),
      lambda x: (x, 2),
  ]
  accept_all = lambda x, y: constant_op.constant(True)
  product_is_zero = (
      lambda x, y: math_ops.equal(x * math_ops.cast(y, dtypes.int64), 0))
  for map_function in map_functions:
    for keep in [accept_all, product_is_zero]:
      self._testMapAndFilterFusion(map_function, keep)
@combinations.generate(test_base.default_test_combinations())
def testCapturedInputs(self):
a = constant_op.constant(3, dtype=dtypes.int64)
b = constant_op.constant(4, dtype=dtypes.int64)
@ -104,7 +105,7 @@ class MapAndFilterFusionTest(test_base.DatasetTestBase, parameterized.TestCase):
options.experimental_optimization.apply_default_optimizations = False
options.experimental_optimization.map_and_filter_fusion = True
dataset = dataset.with_options(options)
self._testMapAndFilter(dataset, function, predicate)
self._testDataset(dataset, function, predicate)
if __name__ == "__main__":

View File

@ -22,51 +22,13 @@ from absl.testing import parameterized
from tensorflow.python.data.experimental.ops import testing
from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import test_util
from tensorflow.python.framework import combinations
from tensorflow.python.platform import test
def _map_fusion_test_cases():
"""Generates test cases for the MapFusion optimization."""
identity = lambda x: x
increment = lambda x: x + 1
def increment_and_square(x):
y = x + 1
return y * y
functions = [identity, increment, increment_and_square]
tests = []
for i, fun1 in enumerate(functions):
for j, fun2 in enumerate(functions):
tests.append((
"Test{}{}".format(i, j),
[fun1, fun2],
))
for k, fun3 in enumerate(functions):
tests.append((
"Test{}{}{}".format(i, j, k),
[fun1, fun2, fun3],
))
swap = lambda x, n: (n, x)
tests.append((
"Swap1",
[lambda x: (x, 42), swap],
))
tests.append((
"Swap2",
[lambda x: (x, 42), swap, swap],
))
return tuple(tests)
@test_util.run_all_in_graph_and_eager_modes
class MapFusionTest(test_base.DatasetTestBase, parameterized.TestCase):
@parameterized.named_parameters(*_map_fusion_test_cases())
def testMapFusion(self, functions):
def _testMapFusion(self, functions):
dataset = dataset_ops.Dataset.range(5).apply(
testing.assert_next(["Map", "MemoryCacheImpl"]))
for function in functions:
@ -88,6 +50,31 @@ class MapFusionTest(test_base.DatasetTestBase, parameterized.TestCase):
expected_output.append(r)
self.assertDatasetProduces(dataset, expected_output=expected_output)
@combinations.generate(test_base.default_test_combinations())
def testMapFusionScalar(self):
  # Runs the fusion check for every ordered chain of two, and of three,
  # scalar map functions.

  def increment_and_square(x):
    y = x + 1
    return y * y

  no_op = lambda x: x
  add_one = lambda x: x + 1
  candidates = [no_op, add_one, increment_and_square]

  for first in candidates:
    for second in candidates:
      self._testMapFusion([first, second])
      for third in candidates:
        self._testMapFusion([first, second, third])
@combinations.generate(test_base.default_test_combinations())
def testMapAndFilterFusionTuple(self):
  # NOTE(review): this method only calls _testMapFusion, so the
  # "MapAndFilter" in its name looks like a copy-paste leftover -- consider
  # renaming it to testMapFusionTuple.
  attach_42 = lambda x: (x, 42)
  flip = lambda x, y: (y, x)
  for chain in ([attach_42, flip], [attach_42, flip, flip]):
    self._testMapFusion(chain)
if __name__ == "__main__":
test.main()

View File

@ -22,38 +22,20 @@ from absl.testing import parameterized
from tensorflow.python.data.experimental.ops import testing
from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import combinations
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import variables
from tensorflow.python.platform import test
def _map_parallelization_test_cases():
"""Generates test cases for the MapParallelization optimization."""
identity = lambda x: x
increment = lambda x: x + 1
def assert_greater(x):
assert_op = control_flow_ops.Assert(math_ops.greater(x, -1), [x])
with ops.control_dependencies([assert_op]):
return x
return (("Identity", identity, True),
("Increment", increment, True),
("AssertGreater", assert_greater, True))
@test_util.run_all_in_graph_and_eager_modes
class MapParallelizationTest(test_base.DatasetTestBase, parameterized.TestCase):
@parameterized.named_parameters(*_map_parallelization_test_cases())
def testMapParallelization(self, function, should_be_parallel):
next_nodes = ["ParallelMap"] if should_be_parallel else ["Map"]
def _testMapParallelization(self, function, should_optimize):
next_nodes = ["ParallelMap"] if should_optimize else ["Map"]
dataset = dataset_ops.Dataset.range(5).apply(
testing.assert_next(next_nodes)).map(function)
options = dataset_ops.Options()
@ -63,9 +45,26 @@ class MapParallelizationTest(test_base.DatasetTestBase, parameterized.TestCase):
self.assertDatasetProduces(
dataset, expected_output=[function(x) for x in range(5)])
def testMapParallelizationWithCapturedConstant(self):
"""Tests that functions with captured constants are parallelized."""
@combinations.generate(test_base.default_test_combinations())
def testIdentity(self):
  # The identity map function is expected to be parallelized.
  identity = lambda x: x
  self._testMapParallelization(identity, should_optimize=True)
@combinations.generate(test_base.default_test_combinations())
def testIncrement(self):
  # A simple arithmetic map function is expected to be parallelized.
  add_one = lambda x: x + 1
  self._testMapParallelization(add_one, should_optimize=True)
@combinations.generate(test_base.default_test_combinations())
def testAssert(self):
  # A map function guarded by an Assert op is still expected to be
  # parallelized.

  def assert_greater(x):
    above_minus_one = math_ops.greater(x, -1)
    guard = control_flow_ops.Assert(above_minus_one, [x])
    with ops.control_dependencies([guard]):
      return x

  self._testMapParallelization(assert_greater, should_optimize=True)
@combinations.generate(test_base.default_test_combinations())
def testCapturedConstant(self):
captured_t = constant_op.constant(42, dtype=dtypes.int64)
def fn(x):
return x + captured_t
@ -78,9 +77,8 @@ class MapParallelizationTest(test_base.DatasetTestBase, parameterized.TestCase):
self.assertDatasetProduces(
dataset, expected_output=[x + 42 for x in range(5)])
def testMapParallelizationWithCapturedVariable(self):
"""Tests that functions with captured variables are not parallelized."""
@combinations.generate(test_base.default_test_combinations())
def testCapturedVariable(self):
captured_t = variables.Variable(42, dtype=dtypes.int64)
def fn(x):
return x + captured_t

View File

@ -17,6 +17,8 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import functools
from absl.testing import parameterized
import numpy as np
@ -26,12 +28,12 @@ from tensorflow.python.data.experimental.ops import batching
from tensorflow.python.data.experimental.ops import testing
from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import combinations
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import bitwise_ops
from tensorflow.python.ops import check_ops
@ -43,21 +45,45 @@ from tensorflow.python.ops import parsing_ops
from tensorflow.python.platform import test
def _generate_unary_cwise_math_cases():
# TODO(rachelim): Consolidate tests with pfor when APIs are somewhat shared.
bitwise_cases = [("Invert", bitwise_ops.invert)]
logical_cases = [("LogicalNot", math_ops.logical_not)]
complex_cases = [
def _generate_test_combinations(cases):
def reduce_fn(x, y):
name, fn = y
return x + combinations.combine(map_fn=combinations.NamedObject(name, fn))
return functools.reduce(reduce_fn, cases, [])
def _unary_bitwise_test_combinations():
  """Returns test combinations for the unary bitwise op (`invert`)."""
  return _generate_test_combinations([("Invert", bitwise_ops.invert)])
def _unary_logical_test_combinations():
  """Returns test combinations for the unary logical op (`logical_not`)."""
  return _generate_test_combinations([("LogicalNot", math_ops.logical_not)])
def _unary_complex_test_combinations():
cases = [
("Angle", math_ops.angle),
("ComplexAbs", math_ops.abs),
("Conj", math_ops.conj),
("Imag", math_ops.imag),
("Real", math_ops.real),
]
real_cases = [
return _generate_test_combinations(cases)
def _unary_real_test_combinations():
# acosh requires values x >= 1
def safe_acosh(x):
return math_ops.acosh(1 + math_ops.square(x))
cases = [
("Abs", math_ops.abs),
("Acos", math_ops.acos),
("Acosh", lambda x: math_ops.acosh(1 + math_ops.square(x))),
("Acosh", safe_acosh),
("Asin", math_ops.asin),
("Asinh", math_ops.asinh),
("Atan", math_ops.atan),
@ -99,45 +125,26 @@ def _generate_unary_cwise_math_cases():
("Tan", math_ops.tan),
("Tanh", math_ops.tanh),
]
random_input = np.random.rand(3, 5)
complex_component = np.random.rand(3, 5)
random_int = np.random.randint(0, 10, (7, 3, 5))
def bitwise_dataset_factory():
return dataset_ops.Dataset.from_tensor_slices(random_int)
def logical_dataset_factory():
return dataset_ops.Dataset.from_tensor_slices(random_input > 0)
def random_dataset_factory():
return dataset_ops.Dataset.from_tensor_slices(random_input)
def complex_dataset_factory():
return dataset_ops.Dataset.from_tensor_slices(
math_ops.complex(random_input, complex_component))
case_factory_pairs = [
(bitwise_cases, bitwise_dataset_factory),
(logical_cases, logical_dataset_factory),
(complex_cases, complex_dataset_factory),
(real_cases, random_dataset_factory),
]
return [(case[0], case[1], factory)
for cases, factory in case_factory_pairs
for case in cases]
return _generate_test_combinations(cases)
def _generate_binary_cwise_math_cases():
bitwise_cases = [("BitwiseAnd", bitwise_ops.bitwise_and),
("BitwiseOr", bitwise_ops.bitwise_or),
("BitwiseXor", bitwise_ops.bitwise_xor),
("LeftShift", bitwise_ops.left_shift),
("RightShift", bitwise_ops.right_shift)]
def _binary_bitwise_test_combinations():
  """Returns the `map_fn` test combinations for the binary bitwise ops."""
  named_ops = [
      ("BitwiseAnd", bitwise_ops.bitwise_and),
      ("BitwiseOr", bitwise_ops.bitwise_or),
      ("BitwiseXor", bitwise_ops.bitwise_xor),
      ("LeftShift", bitwise_ops.left_shift),
      ("RightShift", bitwise_ops.right_shift),
  ]
  return _generate_test_combinations(named_ops)
logical_cases = [("LogicalAnd", math_ops.logical_and),
("LogicalOr", math_ops.logical_or)]
# Wrapper functions restricting the range of inputs of zeta and polygamma.
def _binary_logical_test_combinations():
  """Returns the `map_fn` test combinations for the binary logical ops."""
  named_ops = [
      ("LogicalAnd", math_ops.logical_and),
      ("LogicalOr", math_ops.logical_or),
  ]
  return _generate_test_combinations(named_ops)
def _binary_real_test_combinations():
def safe_polygamma(x, y):
return math_ops.polygamma(
math_ops.round(clip_ops.clip_by_value(y, 1, 10)), x * x + 1)
@ -145,7 +152,7 @@ def _generate_binary_cwise_math_cases():
def safe_zeta(x, y):
return math_ops.zeta(x * x + 1, y * y)
real_cases = [
cases = [
("Add", math_ops.add),
("AddV2", math_ops.add_v2),
("Atan2", math_ops.atan2),
@ -174,150 +181,10 @@ def _generate_binary_cwise_math_cases():
("TruncateMod", math_ops.truncate_mod),
("Zeta", safe_zeta),
]
# Exercises broadcasting capabilities
x = np.random.rand(7, 3, 5)
y = np.random.rand(3, 5)
x_int = np.random.randint(0, 10, (7, 3, 5))
y_int = np.random.randint(0, 10, (3, 5))
def bitwise_dataset_factory():
return dataset_ops.Dataset.from_tensors((x_int, y_int))
def logical_dataset_factory():
return dataset_ops.Dataset.from_tensors((x > 0, y > 0))
def random_dataset_factory():
return dataset_ops.Dataset.from_tensors((x, y))
case_factory_pairs = [
(bitwise_cases, bitwise_dataset_factory),
(logical_cases, logical_dataset_factory),
(real_cases, random_dataset_factory),
]
return [(case[0], case[1], factory)
for cases, factory in case_factory_pairs
for case in cases]
return _generate_test_combinations(cases)
def _generate_cwise_test_cases():
  """Returns the unary cwise math cases followed by the binary ones."""
  unary_cases = _generate_unary_cwise_math_cases()
  binary_cases = _generate_binary_cwise_math_cases()
  return unary_cases + binary_cases
def _generate_csv_test_case():
  """Returns a `(map_fn, dataset_factory)` pair that exercises `decode_csv`.

  The factory yields the two ":"-delimited records repeated five times; the
  map function decodes each record into float32, int32 and string fields.
  """

  def csv_factory():
    records = ["1.0:2:a", "2.4:5:c"]
    return dataset_ops.Dataset.from_tensor_slices(records).repeat(5)

  def decode_csv_fn(x):
    # The defaults are built inside the map function, as in the original, so
    # that the constants are created when the function is traced.
    defaults = [
        constant_op.constant([], dtypes.float32),
        constant_op.constant([], dtypes.int32),
        constant_op.constant([], dtypes.string)
    ]
    return parsing_ops.decode_csv(
        x, record_defaults=defaults, field_delim=":")

  return decode_csv_fn, csv_factory
def _generate_parse_single_example_test_case():
  """Returns a `(map_fn, dataset_factory)` pair for `parse_single_example`.

  The factory yields ten serialized `Example` protos, each holding a
  "dense_int" int64 feature and a "dense_str" bytes feature; the map function
  parses both as scalar `FixedLenFeature`s.
  """
  # When sparse tensors are used, map_vectorization is not
  # attempted because the output_shapes of the map dataset are not defined.
  # TODO(rachelim): Consider being more lax with checking the output_shapes of
  # the map node.

  def parse_example_factory():

    def _int64_feature(*values):
      int64_list = feature_pb2.Int64List(value=values)
      return feature_pb2.Feature(int64_list=int64_list)

    def _bytes_feature(*values):
      encoded = [v.encode("utf-8") for v in values]
      return feature_pb2.Feature(
          bytes_list=feature_pb2.BytesList(value=encoded))

    def _serialized_example(i):
      feature = {
          "dense_int": _int64_feature(i),
          "dense_str": _bytes_feature(str(i)),
      }
      example = example_pb2.Example(
          features=feature_pb2.Features(feature=feature))
      return example.SerializeToString()

    serialized = [_serialized_example(i) for i in range(10)]
    return dataset_ops.Dataset.from_tensor_slices(
        constant_op.constant(serialized))

  def parse_single_example_fn(x):
    feature_spec = {
        "dense_int": parsing_ops.FixedLenFeature((), dtypes.int64, 0),
        "dense_str": parsing_ops.FixedLenFeature((), dtypes.string, ""),
    }
    return parsing_ops.parse_single_example(x, feature_spec)

  return parse_single_example_fn, parse_example_factory
def _generate_optimization_test_cases():
  """Builds the named-parameter dicts consumed by `testOptimization`.

  Returns:
    A list of dicts with keys "testcase_name", "map_fn",
    "base_dataset_factory" and "num_parallel_calls". Every test case appears
    twice: once with `num_parallel_calls=None` and once with
    `num_parallel_calls=12` (the latter with "Parallel" appended to its name).
  """

  def base_dataset_factory():
    return dataset_ops.Dataset.from_tensors(np.random.rand(10, 3)).repeat(5)

  rand_val = np.random.rand(1, 1, 1, 1, 1, 1)

  csv_test_case = _generate_csv_test_case()
  parse_fn, parse_base = _generate_parse_single_example_test_case()

  def dense_output_only_parse_fn(x):
    # Since we haven't implemented a vectorizer for SerializeSparse, any
    # function with sparse outputs will only be naively vectorized.
    parse_result = parse_fn(x)
    # `parse_single_example` returns a dict keyed by feature name. Iterating
    # the dict itself yields the string keys, which would make the
    # SparseTensor filter a no-op and turn the map output into constant
    # strings, so filter the parsed tensors (the dict values) instead.
    return [
        y for y in parse_result.values()
        if not isinstance(y, sparse_tensor.SparseTensor)
    ]

  def map_fn_with_cycle(x):
    c = lambda i: math_ops.less(i, 10)
    b = lambda i: math_ops.add(i, 1)
    return control_flow_ops.while_loop(c, b, [x])

  # Misc test cases
  test_cases = [
      ("Basic", lambda x: (x, x + 1), base_dataset_factory),
      ("Broadcast", lambda x: x + rand_val, base_dataset_factory),
      ("Cycle", map_fn_with_cycle, lambda: dataset_ops.Dataset.from_tensors(1)),
      ("Const", lambda x: 2, base_dataset_factory),
      ("Cast", lambda x: math_ops.cast(x, dtypes.float64),
       base_dataset_factory),
      ("Reshape", lambda x: array_ops.reshape(x, (-1, 30)),
       base_dataset_factory),
      ("Transpose", array_ops.transpose, base_dataset_factory),
      ("Unpack", array_ops.unstack, base_dataset_factory),
      ("UnpackNegativeAxis", lambda x: array_ops.unstack(x, axis=-1),
       base_dataset_factory),
      # Parsing ops
      ("DecodeCSV", csv_test_case[0], csv_test_case[1]),
      ("ParseSingleExample", parse_fn, parse_base),
      ("ParseSingleExampleDenseOutputOnly", dense_output_only_parse_fn,
       parse_base),
  ] + _generate_cwise_test_cases()

  return [{
      "testcase_name":
          (name + "Parallel") if num_parallel_calls is not None else name,
      "map_fn":
          map_fn,
      "base_dataset_factory":
          factory,
      "num_parallel_calls":
          num_parallel_calls
  } for name, map_fn, factory in test_cases
          for num_parallel_calls in (None, 12)]
@test_util.run_all_in_graph_and_eager_modes
# TODO(rachelim): Consolidate tests with pfor when APIs are somewhat shared.
class MapVectorizationTest(test_base.DatasetTestBase, parameterized.TestCase):
def _enable_map_vectorization(self, dataset, use_choose=True):
@ -370,13 +237,223 @@ class MapVectorizationTest(test_base.DatasetTestBase, parameterized.TestCase):
optimized = self._enable_map_vectorization(optimized)
return unoptimized, optimized
@parameterized.named_parameters(_generate_optimization_test_cases())
def testOptimization(self, map_fn, base_dataset_factory, num_parallel_calls):
base_dataset = base_dataset_factory()
unoptimized, optimized = self._get_test_datasets(base_dataset, map_fn,
def _testOptimization(self, map_fn, dataset_factory, num_parallel_calls):
  """Asserts that the optimized and unoptimized pipelines are equal.

  Args:
    map_fn: Map function forwarded to `self._get_test_datasets`.
    dataset_factory: Zero-argument callable returning the input dataset.
    num_parallel_calls: Forwarded to `self._get_test_datasets`.
  """
  datasets = self._get_test_datasets(dataset_factory(), map_fn,
                                     num_parallel_calls)
  unoptimized, optimized = datasets
  self.assertDatasetsEqual(unoptimized, optimized)
@combinations.generate(
    combinations.times(test_base.default_test_combinations(),
                       combinations.combine(num_parallel_calls=[None, 12])))
def testBasic(self, num_parallel_calls):
  """Compares both pipelines for a map returning `(x, x + 1)`."""
  data = np.random.rand(10, 3)

  def dataset_factory():
    return dataset_ops.Dataset.from_tensors(data).repeat(5)

  def map_fn(x):
    return (x, x + 1)

  self._testOptimization(map_fn, dataset_factory, num_parallel_calls)
@combinations.generate(
    combinations.times(test_base.default_test_combinations(),
                       combinations.combine(num_parallel_calls=[None, 12])))
def testBroadcast(self, num_parallel_calls):
  """Compares both pipelines for a map adding a rank-6 array to its input."""
  data = np.random.rand(10, 3)
  value = np.random.rand(1, 1, 1, 1, 1, 1)

  def dataset_factory():
    return dataset_ops.Dataset.from_tensors(data).repeat(5)

  def map_fn(x):
    return x + value

  self._testOptimization(map_fn, dataset_factory, num_parallel_calls)
@combinations.generate(
    combinations.times(test_base.default_test_combinations(),
                       combinations.combine(num_parallel_calls=[None, 12])))
def testCast(self, num_parallel_calls):
  """Compares both pipelines for a map casting its input to float64."""
  data = np.random.rand(10, 3)

  def dataset_factory():
    return dataset_ops.Dataset.from_tensors(data).repeat(5)

  def map_fn(x):
    return math_ops.cast(x, dtypes.float64)

  self._testOptimization(map_fn, dataset_factory, num_parallel_calls)
@combinations.generate(
    combinations.times(test_base.default_test_combinations(),
                       combinations.combine(num_parallel_calls=[None, 12])))
def testConst(self, num_parallel_calls):
  """Compares both pipelines for a map that ignores its input and returns 2."""
  data = np.random.rand(10, 3)

  def dataset_factory():
    return dataset_ops.Dataset.from_tensors(data).repeat(5)

  def map_fn(unused_x):
    return 2

  self._testOptimization(map_fn, dataset_factory, num_parallel_calls)
@combinations.generate(
    combinations.times(test_base.default_test_combinations(),
                       combinations.combine(num_parallel_calls=[None, 12])))
def testCycle(self, num_parallel_calls):
  """Compares both pipelines for a map function containing a `while_loop`."""

  def dataset_factory():
    return dataset_ops.Dataset.from_tensors(1)

  def map_fn(x):

    def keep_going(i):
      return math_ops.less(i, 10)

    def increment(i):
      return math_ops.add(i, 1)

    return control_flow_ops.while_loop(keep_going, increment, [x])

  self._testOptimization(map_fn, dataset_factory, num_parallel_calls)
@combinations.generate(
    combinations.times(test_base.default_test_combinations(),
                       combinations.combine(num_parallel_calls=[None, 12])))
def testReshape(self, num_parallel_calls):
  """Compares both pipelines for a map reshaping its input to `(-1, 30)`."""
  data = np.random.rand(10, 3)

  def dataset_factory():
    return dataset_ops.Dataset.from_tensors(data).repeat(5)

  def map_fn(x):
    return array_ops.reshape(x, (-1, 30))

  self._testOptimization(map_fn, dataset_factory, num_parallel_calls)
@combinations.generate(
    combinations.times(test_base.default_test_combinations(),
                       combinations.combine(num_parallel_calls=[None, 12])))
def testTranspose(self, num_parallel_calls):
  """Compares both pipelines when mapping `array_ops.transpose`."""
  data = np.random.rand(10, 3)

  def dataset_factory():
    return dataset_ops.Dataset.from_tensors(data).repeat(5)

  self._testOptimization(array_ops.transpose, dataset_factory,
                         num_parallel_calls)
@combinations.generate(
    combinations.times(test_base.default_test_combinations(),
                       combinations.combine(num_parallel_calls=[None, 12])))
def testUnstack(self, num_parallel_calls):
  """Compares both pipelines for `unstack` with default and last axis."""
  data = np.random.rand(10, 3)

  def dataset_factory():
    return dataset_ops.Dataset.from_tensors(data).repeat(5)

  def unstack_last_axis(x):
    return array_ops.unstack(x, axis=-1)

  # Default-axis unstack is checked first, then the negative-axis variant.
  for map_fn in (array_ops.unstack, unstack_last_axis):
    self._testOptimization(map_fn, dataset_factory, num_parallel_calls)
@combinations.generate(
    combinations.times(test_base.default_test_combinations(),
                       _unary_bitwise_test_combinations(),
                       combinations.combine(num_parallel_calls=[None, 12])))
def testUnaryBitwiseOperations(self, map_fn, num_parallel_calls):
  """Compares both pipelines for unary bitwise ops on random integers."""
  values = np.random.randint(0, 10, (7, 3, 5))

  def dataset_factory():
    return dataset_ops.Dataset.from_tensor_slices(values)

  self._testOptimization(map_fn, dataset_factory, num_parallel_calls)
@combinations.generate(
    combinations.times(test_base.default_test_combinations(),
                       _unary_logical_test_combinations(),
                       combinations.combine(num_parallel_calls=[None, 12])))
def testUnaryLogicalOperations(self, map_fn, num_parallel_calls):
  """Compares both pipelines for unary logical ops on random booleans."""
  # `np.random.rand` samples from [0, 1), so thresholding at 0 would make
  # (almost surely) every element `True`. Threshold at 0.5 instead so that the
  # op under test sees both `True` and `False` inputs with high probability.
  x = np.random.rand(3, 5) > 0.5
  dataset_factory = lambda: dataset_ops.Dataset.from_tensor_slices(x)
  self._testOptimization(map_fn, dataset_factory, num_parallel_calls)
@combinations.generate(
    combinations.times(test_base.default_test_combinations(),
                       _unary_complex_test_combinations(),
                       combinations.combine(num_parallel_calls=[None, 12])))
def testUnaryComplexOperations(self, map_fn, num_parallel_calls):
  """Compares both pipelines for unary ops on random complex values."""
  real_part = np.random.rand(3, 5)
  imag_part = np.random.rand(3, 5)
  values = math_ops.complex(real_part, imag_part)

  def dataset_factory():
    return dataset_ops.Dataset.from_tensor_slices(values)

  self._testOptimization(map_fn, dataset_factory, num_parallel_calls)
@combinations.generate(
    combinations.times(test_base.default_test_combinations(),
                       _unary_real_test_combinations(),
                       combinations.combine(num_parallel_calls=[None, 12])))
def testUnaryRealOperations(self, map_fn, num_parallel_calls):
  """Compares both pipelines for unary real-valued ops on random floats."""
  values = np.random.rand(3, 5)

  def dataset_factory():
    return dataset_ops.Dataset.from_tensor_slices(values)

  self._testOptimization(map_fn, dataset_factory, num_parallel_calls)
@combinations.generate(
    combinations.times(test_base.default_test_combinations(),
                       _binary_bitwise_test_combinations(),
                       combinations.combine(num_parallel_calls=[None, 12])))
def testBinaryBitwiseOperations(self, map_fn, num_parallel_calls):
  """Compares both pipelines for binary bitwise ops on random integers."""
  # The operands have shapes (7, 3, 5) and (3, 5) respectively.
  lhs = np.random.randint(0, 10, (7, 3, 5))
  rhs = np.random.randint(0, 10, (3, 5))

  def dataset_factory():
    return dataset_ops.Dataset.from_tensors((lhs, rhs))

  self._testOptimization(map_fn, dataset_factory, num_parallel_calls)
@combinations.generate(
    combinations.times(test_base.default_test_combinations(),
                       _binary_logical_test_combinations(),
                       combinations.combine(num_parallel_calls=[None, 12])))
def testBinaryLogicalOperations(self, map_fn, num_parallel_calls):
  """Compares both pipelines for binary logical ops on random booleans."""
  # `np.random.rand` samples from [0, 1), so thresholding at 0 would make
  # (almost surely) every element `True`, and the ops would never see a
  # `False` operand. Threshold at 0.5 to get a mix of both values.
  x = np.random.rand(7, 3, 5) > 0.5
  y = np.random.rand(3, 5) > 0.5
  dataset_factory = lambda: dataset_ops.Dataset.from_tensors((x, y))
  self._testOptimization(map_fn, dataset_factory, num_parallel_calls)
@combinations.generate(
    combinations.times(test_base.default_test_combinations(),
                       _binary_real_test_combinations(),
                       combinations.combine(num_parallel_calls=[None, 12])))
def testBinaryRealOperations(self, map_fn, num_parallel_calls):
  """Compares both pipelines for binary real-valued ops on random floats."""
  # The operands have shapes (7, 3, 5) and (3, 5) respectively.
  lhs = np.random.rand(7, 3, 5)
  rhs = np.random.rand(3, 5)

  def dataset_factory():
    return dataset_ops.Dataset.from_tensors((lhs, rhs))

  self._testOptimization(map_fn, dataset_factory, num_parallel_calls)
@combinations.generate(
    combinations.times(test_base.default_test_combinations(),
                       combinations.combine(num_parallel_calls=[None, 12])))
def testDecodeCsv(self, num_parallel_calls):
  """Compares both pipelines for a map that decodes ":"-delimited records."""

  def dataset_factory():
    records = ["1.0:2:a", "2.4:5:c"]
    return dataset_ops.Dataset.from_tensor_slices(records).repeat(5)

  def decode_csv_fn(x):
    # The defaults are built inside the map function, as in the original, so
    # that the constants are created when the function is traced.
    defaults = [
        constant_op.constant([], dtypes.float32),
        constant_op.constant([], dtypes.int32),
        constant_op.constant([], dtypes.string)
    ]
    return parsing_ops.decode_csv(
        x, record_defaults=defaults, field_delim=":")

  self._testOptimization(decode_csv_fn, dataset_factory, num_parallel_calls)
@combinations.generate(
    combinations.times(test_base.default_test_combinations(),
                       combinations.combine(num_parallel_calls=[None, 12])))
def testParseSingleExample(self, num_parallel_calls):
  """Compares both pipelines for maps built on `parse_single_example`.

  Two map functions are checked against ten serialized `Example` protos: one
  returning the parsed feature dict, and one returning only the parsed
  tensors that are not `SparseTensor`s.
  """

  def dataset_factory():

    def _int64_feature(*values):
      return feature_pb2.Feature(
          int64_list=feature_pb2.Int64List(value=values))

    def _bytes_feature(*values):
      return feature_pb2.Feature(
          bytes_list=feature_pb2.BytesList(
              value=[v.encode("utf-8") for v in values]))

    # pylint:disable=g-complex-comprehension
    return dataset_ops.Dataset.from_tensor_slices(
        constant_op.constant([
            example_pb2.Example(
                features=feature_pb2.Features(
                    feature={
                        "dense_int": _int64_feature(i),
                        "dense_str": _bytes_feature(str(i)),
                    })).SerializeToString() for i in range(10)
        ]))

  def parse_fn(x):
    features = {
        "dense_int": parsing_ops.FixedLenFeature((), dtypes.int64, 0),
        "dense_str": parsing_ops.FixedLenFeature((), dtypes.string, ""),
    }
    return parsing_ops.parse_single_example(x, features)

  def dense_only_parse_fn(x):
    # `parse_single_example` returns a dict keyed by feature name. Iterating
    # the dict itself yields the string keys, which would make the
    # SparseTensor filter a no-op and turn the map output into constant
    # strings, so filter the parsed tensors (the dict values) instead.
    return [
        y for y in parse_fn(x).values()
        if not isinstance(y, sparse_tensor.SparseTensor)
    ]

  map_fns = [parse_fn, dense_only_parse_fn]
  for map_fn in map_fns:
    self._testOptimization(map_fn, dataset_factory, num_parallel_calls)
@combinations.generate(test_base.default_test_combinations())
def testOptimizationBadMapFn(self):
# Test map functions that give an error
def map_fn(x):
@ -391,6 +468,7 @@ class MapVectorizationTest(test_base.DatasetTestBase, parameterized.TestCase):
nxt = dataset_ops.make_one_shot_iterator(optimized).get_next()
self.evaluate(nxt)
@combinations.generate(test_base.default_test_combinations())
def testOptimizationWithCapturedInputs(self):
# Tests that vectorization works with captured inputs.
y = constant_op.constant(1, shape=(2,))
@ -405,6 +483,7 @@ class MapVectorizationTest(test_base.DatasetTestBase, parameterized.TestCase):
base_dataset, map_fn, expect_optimized=True)
self.assertDatasetsEqual(optimized, unoptimized)
@combinations.generate(test_base.default_test_combinations())
def testOptimizationWithMapAndBatchFusion(self):
# Tests that vectorization works on fused map and batch.
def map_fn(x):
@ -425,12 +504,11 @@ class MapVectorizationTest(test_base.DatasetTestBase, parameterized.TestCase):
optimized = self._enable_map_vectorization(optimized)
self.assertDatasetsEqual(optimized, unoptimized)
@parameterized.named_parameters(
("1", True, True),
("2", True, False),
("3", False, True),
("4", False, False),
)
@combinations.generate(
combinations.times(
test_base.default_test_combinations(),
combinations.combine(
fuse_first=[True, False], fuse_second=[True, False])))
def testOptimizationWithChainedMapAndBatch(self, fuse_first, fuse_second):
# Tests that vectorization works on chained map and batch functions.
def map_fn(x):
@ -474,6 +552,7 @@ class MapVectorizationTest(test_base.DatasetTestBase, parameterized.TestCase):
optimized = self._enable_map_vectorization(optimized)
self.assertDatasetsEqual(optimized, unoptimized)
@combinations.generate(test_base.default_test_combinations())
def testOptimizationIgnoreStateful(self):
def map_fn(x):
@ -488,6 +567,7 @@ class MapVectorizationTest(test_base.DatasetTestBase, parameterized.TestCase):
get_next = self.getNext(dataset)
self.evaluate(get_next())
@combinations.generate(test_base.default_test_combinations())
def testOptimizationIgnoreRagged(self):
# Make sure we ignore inputs that might not be uniformly sized
def map_fn(x):
@ -499,6 +579,7 @@ class MapVectorizationTest(test_base.DatasetTestBase, parameterized.TestCase):
base_dataset, map_fn, expect_optimized=False)
self.assertDatasetsEqual(unoptimized, optimized)
@combinations.generate(test_base.default_test_combinations())
def testOptimizationIgnoreRaggedMap(self):
# Don't optimize when the output of the map fn shapes are unknown.
def map_fn(x):
@ -512,6 +593,7 @@ class MapVectorizationTest(test_base.DatasetTestBase, parameterized.TestCase):
get_next = self.getNext(dataset)
self.evaluate(get_next())
@combinations.generate(test_base.default_test_combinations())
def testOptimizationWithUnknownBatchShape(self):
tensor = sparse_tensor.SparseTensor(
indices=[[0, 0], [1, 2]], values=[1, 2], dense_shape=[3, 4])
@ -526,6 +608,7 @@ class MapVectorizationTest(test_base.DatasetTestBase, parameterized.TestCase):
optimized = self._enable_map_vectorization(unoptimized)
self.assertDatasetsEqual(unoptimized, optimized)
@combinations.generate(test_base.default_test_combinations())
def testOptimizationWithSparseTensor(self):
base_dataset = dataset_ops.Dataset.from_tensors(0)
@ -542,6 +625,7 @@ class MapVectorizationTest(test_base.DatasetTestBase, parameterized.TestCase):
optimized = self._enable_map_vectorization(unoptimized)
self.assertDatasetsEqual(unoptimized, optimized)
@combinations.generate(test_base.default_test_combinations())
def testOptimizationWithPrefetch(self):
dataset = dataset_ops.Dataset.range(10)
dataset = dataset.map(lambda x: x)
@ -550,6 +634,7 @@ class MapVectorizationTest(test_base.DatasetTestBase, parameterized.TestCase):
dataset = self._enable_map_vectorization(dataset)
self.assertDatasetProduces(dataset, [list(range(10))])
@combinations.generate(test_base.default_test_combinations())
def testOptimizationWithoutChooseFastest(self):
dataset = dataset_ops.Dataset.range(10)
dataset = dataset.map(lambda x: x**2)

View File

@ -17,19 +17,21 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from absl.testing import parameterized
from tensorflow.python.data.experimental.ops import testing
from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import combinations
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import test_util
from tensorflow.python.ops import math_ops
from tensorflow.python.platform import test
@test_util.run_all_in_graph_and_eager_modes
class NoopEliminationTest(test_base.DatasetTestBase):
class NoopEliminationTest(test_base.DatasetTestBase, parameterized.TestCase):
@combinations.generate(test_base.default_test_combinations())
def testNoopElimination(self):
a = constant_op.constant(1, dtype=dtypes.int64)
b = constant_op.constant(2, dtype=dtypes.int64)

View File

@ -17,19 +17,22 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from absl.testing import parameterized
from tensorflow.python import tf2
from tensorflow.python.data.experimental.ops import testing
from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.eager import context
from tensorflow.python.framework import combinations
from tensorflow.python.framework import errors
from tensorflow.python.framework import test_util
from tensorflow.python.platform import test
@test_util.run_all_in_graph_and_eager_modes
class ShuffleAndRepeatFusionTest(test_base.DatasetTestBase):
class ShuffleAndRepeatFusionTest(test_base.DatasetTestBase,
parameterized.TestCase):
@combinations.generate(test_base.default_test_combinations())
def testShuffleAndRepeatFusion(self):
if tf2.enabled() and context.executing_eagerly():
expected = "Shuffle"