[tf.data] Refactoring optimization test methods.

This CL breaks down large tests that iterate over different test cases into smaller ones -- one per test case.

PiperOrigin-RevId: 283993741
Change-Id: I0e67958279d924d0b139164108e971bf39de96ca
This commit is contained in:
Jiri Simsa 2019-12-05 09:47:03 -08:00 committed by TensorFlower Gardener
parent f575856f03
commit 1bcbacd3df
6 changed files with 190 additions and 132 deletions

View File

@ -17,6 +17,8 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import functools
from absl.testing import parameterized
from tensorflow.python.data.experimental.ops import testing
@ -29,12 +31,42 @@ from tensorflow.python.ops import math_ops
from tensorflow.python.platform import test
def _test_combinations():
cases = []
take_all = lambda x: constant_op.constant(True)
is_zero = lambda x: math_ops.equal(x, 0)
greater = lambda x: math_ops.greater(x + 5, 0)
predicates = [take_all, is_zero, greater]
for i, x in enumerate(predicates):
for j, y in enumerate(predicates):
cases.append((lambda x: x, "Scalar{}{}".format(i, j), [x, y]))
for k, z in enumerate(predicates):
cases.append((lambda x: x, "Scalar{}{}{}".format(i, j, k), [x, y, z]))
take_all = lambda x, y: constant_op.constant(True)
is_zero = lambda x, y: math_ops.equal(x * math_ops.cast(y, dtypes.int64), 0)
cases.append((lambda x: (x, x), "Tuple1", [take_all, take_all]))
cases.append((lambda x: (x, 2), "Tuple2", [take_all, is_zero]))
def reduce_fn(x, y):
function, name, predicates = y
return x + combinations.combine(
function=function,
predicates=combinations.NamedObject(name, predicates))
return functools.reduce(reduce_fn, cases, [])
class FilterFusionTest(test_base.DatasetTestBase, parameterized.TestCase):
def _testFilterFusion(self, map_function, predicates):
@combinations.generate(
combinations.times(test_base.default_test_combinations(),
_test_combinations()))
def testFilterFusion(self, function, predicates):
dataset = dataset_ops.Dataset.range(5).apply(
testing.assert_next(["Map", "Filter",
"MemoryCacheImpl"])).map(map_function)
testing.assert_next(["Map", "Filter", "MemoryCacheImpl"])).map(function)
for predicate in predicates:
dataset = dataset.filter(predicate)
@ -45,7 +77,7 @@ class FilterFusionTest(test_base.DatasetTestBase, parameterized.TestCase):
dataset = dataset.with_options(options)
expected_output = []
for x in range(5):
r = map_function(x)
r = function(x)
filtered = False
for predicate in predicates:
if isinstance(r, tuple):
@ -60,26 +92,6 @@ class FilterFusionTest(test_base.DatasetTestBase, parameterized.TestCase):
expected_output.append(r)
self.assertDatasetProduces(dataset, expected_output=expected_output)
@combinations.generate(test_base.default_test_combinations())
def testFilterFusionScalar(self):
take_all = lambda x: constant_op.constant(True)
is_zero = lambda x: math_ops.equal(x, 0)
greater = lambda x: math_ops.greater(x + 5, 0)
predicates = [take_all, is_zero, greater]
for x in predicates:
for y in predicates:
self._testFilterFusion(lambda x: x, [x, y])
for z in predicates:
self._testFilterFusion(lambda x: x, [x, y, z])
@combinations.generate(test_base.default_test_combinations())
def testFilterFusionTuple(self):
take_all = lambda x, y: constant_op.constant(True)
is_zero = lambda x, y: math_ops.equal(x * math_ops.cast(y, dtypes.int64), 0)
self._testFilterFusion(lambda x: (x, x), [take_all, take_all])
self._testFilterFusion(lambda x: (x, 2), [take_all, is_zero])
if __name__ == "__main__":
test.main()

View File

@ -17,6 +17,8 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import functools
from absl.testing import parameterized
from tensorflow.python.data.experimental.ops import testing
@ -33,6 +35,36 @@ from tensorflow.python.ops import random_ops
from tensorflow.python.platform import test
def _test_combinations():
def random(_):
return random_ops.random_uniform([],
minval=1,
maxval=10,
dtype=dtypes.float32,
seed=42)
def random_with_assert(x):
y = random(x)
assert_op = control_flow_ops.Assert(math_ops.greater_equal(y, 1), [y])
with ops.control_dependencies([assert_op]):
return y
cases = [
("Increment", lambda x: x + 1, False),
("Random", random, True),
("RandomWithAssert", random_with_assert, True),
("Complex", lambda x: (random(x) + random(x)) / 2, False),
]
def reduce_fn(x, y):
name, map_fn, should_optimize = y
return x + combinations.combine(
map_fn=combinations.NamedObject(name, map_fn),
should_optimize=should_optimize)
return functools.reduce(reduce_fn, cases, [])
class HoistRandomUniformTest(test_base.DatasetTestBase, parameterized.TestCase):
def _testDataset(self, dataset):
@ -51,10 +83,13 @@ class HoistRandomUniformTest(test_base.DatasetTestBase, parameterized.TestCase):
with self.assertRaises(errors.OutOfRangeError):
self.evaluate(get_next())
def _testHoistFunction(self, function, should_optimize):
@combinations.generate(
combinations.times(test_base.default_test_combinations(),
_test_combinations()))
def testHoistFunction(self, map_fn, should_optimize):
dataset = dataset_ops.Dataset.range(5).apply(
testing.assert_next(
["Zip[0]", "Map"] if should_optimize else ["Map"])).map(function)
["Zip[0]", "Map"] if should_optimize else ["Map"])).map(map_fn)
options = dataset_ops.Options()
options.experimental_optimization.apply_default_optimizations = False
@ -62,31 +97,6 @@ class HoistRandomUniformTest(test_base.DatasetTestBase, parameterized.TestCase):
dataset = dataset.with_options(options)
self._testDataset(dataset)
@combinations.generate(test_base.default_test_combinations())
def testNoRandom(self):
self._testHoistFunction(lambda x: x + 1, should_optimize=False)
@combinations.generate(test_base.default_test_combinations())
def testRandom(self):
def random(_):
return random_ops.random_uniform([],
minval=1,
maxval=10,
dtype=dtypes.float32,
seed=42)
def random_with_assert(x):
y = random(x)
assert_op = control_flow_ops.Assert(math_ops.greater_equal(y, 1), [y])
with ops.control_dependencies([assert_op]):
return y
self._testHoistFunction(random, should_optimize=True)
self._testHoistFunction(random_with_assert, should_optimize=True)
self._testHoistFunction(
lambda x: (random(x) + random(x)) / 2, should_optimize=False)
@combinations.generate(test_base.default_test_combinations())
def testCapturedInputs(self):
a = constant_op.constant(1, dtype=dtypes.float32)

View File

@ -17,6 +17,8 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import functools
from absl.testing import parameterized
from tensorflow.python.data.experimental.ops import testing
@ -29,6 +31,49 @@ from tensorflow.python.ops import math_ops
from tensorflow.python.platform import test
def _test_combinations():
cases = []
identity = lambda x: x
increment = lambda x: x + 1
minus_five = lambda x: x - 5
def increment_and_square(x):
y = x + 1
return y * y
functions = [identity, increment, minus_five, increment_and_square]
take_all = lambda x: constant_op.constant(True)
is_zero = lambda x: math_ops.equal(x, 0)
is_odd = lambda x: math_ops.equal(x % 2, 0)
greater = lambda x: math_ops.greater(x + 5, 0)
predicates = [take_all, is_zero, is_odd, greater]
for i, function in enumerate(functions):
for j, predicate in enumerate(predicates):
cases.append((function, "Scalar{}{}".format(i, j), predicate))
replicate = lambda x: (x, x)
with_two = lambda x: (x, 2)
functions = [replicate, with_two]
take_all = lambda x, y: constant_op.constant(True)
is_zero = lambda x, y: math_ops.equal(x * math_ops.cast(y, dtypes.int64), 0)
predicates = [take_all, is_zero]
for i, function in enumerate(functions):
for j, predicate in enumerate(predicates):
cases.append((function, "Tuple{}{}".format(i, j), predicate))
def reduce_fn(x, y):
function, name, predicate = y
return x + combinations.combine(
function=function,
predicate=combinations.NamedObject(name, predicate))
return functools.reduce(reduce_fn, cases, [])
class MapAndFilterFusionTest(test_base.DatasetTestBase, parameterized.TestCase):
def _testDataset(self, dataset, function, predicate):
@ -43,7 +88,10 @@ class MapAndFilterFusionTest(test_base.DatasetTestBase, parameterized.TestCase):
expected_output.append(r)
self.assertDatasetProduces(dataset, expected_output=expected_output)
def _testMapAndFilterFusion(self, function, predicate):
@combinations.generate(
combinations.times(test_base.default_test_combinations(),
_test_combinations()))
def testMapAndFilterFusion(self, function, predicate):
dataset = dataset_ops.Dataset.range(10).apply(
testing.assert_next(["Map", "Filter",
"Map"])).map(function).filter(predicate)
@ -53,41 +101,6 @@ class MapAndFilterFusionTest(test_base.DatasetTestBase, parameterized.TestCase):
dataset = dataset.with_options(options)
self._testDataset(dataset, function, predicate)
@combinations.generate(test_base.default_test_combinations())
def testMapAndFilterFusionScalar(self):
identity = lambda x: x
increment = lambda x: x + 1
minus_five = lambda x: x - 5
def increment_and_square(x):
y = x + 1
return y * y
functions = [identity, increment, minus_five, increment_and_square]
take_all = lambda x: constant_op.constant(True)
is_zero = lambda x: math_ops.equal(x, 0)
is_odd = lambda x: math_ops.equal(x % 2, 0)
greater = lambda x: math_ops.greater(x + 5, 0)
predicates = [take_all, is_zero, is_odd, greater]
for function in functions:
for predicate in predicates:
self._testMapAndFilterFusion(function, predicate)
@combinations.generate(test_base.default_test_combinations())
def testMapAndFilterFusionTuple(self):
replicate = lambda x: (x, x)
with_two = lambda x: (x, 2)
functions = [replicate, with_two]
take_all = lambda x, y: constant_op.constant(True)
is_zero = lambda x, y: math_ops.equal(x * math_ops.cast(y, dtypes.int64), 0)
predicates = [take_all, is_zero]
for function in functions:
for predicate in predicates:
self._testMapAndFilterFusion(function, predicate)
@combinations.generate(test_base.default_test_combinations())
def testCapturedInputs(self):
a = constant_op.constant(3, dtype=dtypes.int64)

View File

@ -17,6 +17,8 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import functools
from absl.testing import parameterized
from tensorflow.python.data.experimental.ops import testing
@ -26,9 +28,44 @@ from tensorflow.python.framework import combinations
from tensorflow.python.platform import test
def _test_combinations():
cases = []
identity = lambda x: x
increment = lambda x: x + 1
def increment_and_square(x):
y = x + 1
return y * y
functions = [identity, increment, increment_and_square]
for i, x in enumerate(functions):
for j, y in enumerate(functions):
cases.append(("Scalar{}{}".format(i, j), [x, y]))
for k, z in enumerate(functions):
cases.append(("Scalar{}{}{}".format(i, j, k), [x, y, z]))
with_42 = lambda x: (x, 42)
swap = lambda x, y: (y, x)
cases.append(("Tuple1", [with_42, swap]))
cases.append(("Tuple2", [with_42, swap, swap]))
def reduce_fn(x, y):
name, functions = y
return x + combinations.combine(
functions=combinations.NamedObject(name, functions))
return functools.reduce(reduce_fn, cases, [])
class MapFusionTest(test_base.DatasetTestBase, parameterized.TestCase):
def _testMapFusion(self, functions):
@combinations.generate(
combinations.times(test_base.default_test_combinations(),
_test_combinations()))
def testMapFusion(self, functions):
dataset = dataset_ops.Dataset.range(5).apply(
testing.assert_next(["Map", "MemoryCacheImpl"]))
for function in functions:
@ -50,31 +87,6 @@ class MapFusionTest(test_base.DatasetTestBase, parameterized.TestCase):
expected_output.append(r)
self.assertDatasetProduces(dataset, expected_output=expected_output)
@combinations.generate(test_base.default_test_combinations())
def testMapFusionScalar(self):
identity = lambda x: x
increment = lambda x: x + 1
def increment_and_square(x):
y = x + 1
return y * y
functions = [identity, increment, increment_and_square]
for x in functions:
for y in functions:
self._testMapFusion([x, y])
for z in functions:
self._testMapFusion([x, y, z])
@combinations.generate(test_base.default_test_combinations())
def testMapAndFilterFusionTuple(self):
with_42 = lambda x: (x, 42)
swap = lambda x, y: (y, x)
self._testMapFusion([with_42, swap])
self._testMapFusion([with_42, swap, swap])
if __name__ == "__main__":
test.main()

View File

@ -17,6 +17,8 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import functools
from absl.testing import parameterized
from tensorflow.python.data.experimental.ops import testing
@ -32,9 +34,33 @@ from tensorflow.python.ops import variables
from tensorflow.python.platform import test
def _test_combinations():
def assert_greater(x):
assert_op = control_flow_ops.Assert(math_ops.greater(x, -1), [x])
with ops.control_dependencies([assert_op]):
return x
cases = [
("Identity", lambda x: x, True),
("Increment", lambda x: x + 1, True),
("AssertGreater", assert_greater, True),
]
def reduce_fn(x, y):
name, function, should_optimize = y
return x + combinations.combine(
function=combinations.NamedObject(name, function),
should_optimize=should_optimize)
return functools.reduce(reduce_fn, cases, [])
class MapParallelizationTest(test_base.DatasetTestBase, parameterized.TestCase):
def _testMapParallelization(self, function, should_optimize):
@combinations.generate(
combinations.times(test_base.default_test_combinations(),
_test_combinations()))
def testMapParallelization(self, function, should_optimize):
next_nodes = ["ParallelMap"] if should_optimize else ["Map"]
dataset = dataset_ops.Dataset.range(5).apply(
testing.assert_next(next_nodes)).map(function)
@ -45,24 +71,6 @@ class MapParallelizationTest(test_base.DatasetTestBase, parameterized.TestCase):
self.assertDatasetProduces(
dataset, expected_output=[function(x) for x in range(5)])
@combinations.generate(test_base.default_test_combinations())
def testIdentity(self):
self._testMapParallelization(lambda x: x, should_optimize=True)
@combinations.generate(test_base.default_test_combinations())
def testIncrement(self):
self._testMapParallelization(lambda x: x + 1, should_optimize=True)
@combinations.generate(test_base.default_test_combinations())
def testAssert(self):
def assert_greater(x):
assert_op = control_flow_ops.Assert(math_ops.greater(x, -1), [x])
with ops.control_dependencies([assert_op]):
return x
self._testMapParallelization(assert_greater, should_optimize=True)
@combinations.generate(test_base.default_test_combinations())
def testCapturedConstant(self):
captured_t = constant_op.constant(42, dtype=dtypes.int64)

View File

@ -400,6 +400,9 @@ class NamedObject(object):
def __call__(self, *args, **kwargs):
return self._obj(*args, **kwargs)
def __iter__(self):
return self._obj.__iter__()
def __repr__(self):
return self._name