[tf.data] Refactoring optimization test methods.
This CL breaks down large tests that iterate over different test cases into smaller ones -- one per test case. PiperOrigin-RevId: 283993741 Change-Id: I0e67958279d924d0b139164108e971bf39de96ca
This commit is contained in:
parent
f575856f03
commit
1bcbacd3df
@ -17,6 +17,8 @@ from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import functools
|
||||
|
||||
from absl.testing import parameterized
|
||||
|
||||
from tensorflow.python.data.experimental.ops import testing
|
||||
@ -29,12 +31,42 @@ from tensorflow.python.ops import math_ops
|
||||
from tensorflow.python.platform import test
|
||||
|
||||
|
||||
def _test_combinations():
|
||||
cases = []
|
||||
|
||||
take_all = lambda x: constant_op.constant(True)
|
||||
is_zero = lambda x: math_ops.equal(x, 0)
|
||||
greater = lambda x: math_ops.greater(x + 5, 0)
|
||||
predicates = [take_all, is_zero, greater]
|
||||
for i, x in enumerate(predicates):
|
||||
for j, y in enumerate(predicates):
|
||||
cases.append((lambda x: x, "Scalar{}{}".format(i, j), [x, y]))
|
||||
for k, z in enumerate(predicates):
|
||||
cases.append((lambda x: x, "Scalar{}{}{}".format(i, j, k), [x, y, z]))
|
||||
|
||||
take_all = lambda x, y: constant_op.constant(True)
|
||||
is_zero = lambda x, y: math_ops.equal(x * math_ops.cast(y, dtypes.int64), 0)
|
||||
|
||||
cases.append((lambda x: (x, x), "Tuple1", [take_all, take_all]))
|
||||
cases.append((lambda x: (x, 2), "Tuple2", [take_all, is_zero]))
|
||||
|
||||
def reduce_fn(x, y):
|
||||
function, name, predicates = y
|
||||
return x + combinations.combine(
|
||||
function=function,
|
||||
predicates=combinations.NamedObject(name, predicates))
|
||||
|
||||
return functools.reduce(reduce_fn, cases, [])
|
||||
|
||||
|
||||
class FilterFusionTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
|
||||
def _testFilterFusion(self, map_function, predicates):
|
||||
@combinations.generate(
|
||||
combinations.times(test_base.default_test_combinations(),
|
||||
_test_combinations()))
|
||||
def testFilterFusion(self, function, predicates):
|
||||
dataset = dataset_ops.Dataset.range(5).apply(
|
||||
testing.assert_next(["Map", "Filter",
|
||||
"MemoryCacheImpl"])).map(map_function)
|
||||
testing.assert_next(["Map", "Filter", "MemoryCacheImpl"])).map(function)
|
||||
for predicate in predicates:
|
||||
dataset = dataset.filter(predicate)
|
||||
|
||||
@ -45,7 +77,7 @@ class FilterFusionTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
dataset = dataset.with_options(options)
|
||||
expected_output = []
|
||||
for x in range(5):
|
||||
r = map_function(x)
|
||||
r = function(x)
|
||||
filtered = False
|
||||
for predicate in predicates:
|
||||
if isinstance(r, tuple):
|
||||
@ -60,26 +92,6 @@ class FilterFusionTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
expected_output.append(r)
|
||||
self.assertDatasetProduces(dataset, expected_output=expected_output)
|
||||
|
||||
@combinations.generate(test_base.default_test_combinations())
|
||||
def testFilterFusionScalar(self):
|
||||
take_all = lambda x: constant_op.constant(True)
|
||||
is_zero = lambda x: math_ops.equal(x, 0)
|
||||
greater = lambda x: math_ops.greater(x + 5, 0)
|
||||
predicates = [take_all, is_zero, greater]
|
||||
for x in predicates:
|
||||
for y in predicates:
|
||||
self._testFilterFusion(lambda x: x, [x, y])
|
||||
for z in predicates:
|
||||
self._testFilterFusion(lambda x: x, [x, y, z])
|
||||
|
||||
@combinations.generate(test_base.default_test_combinations())
|
||||
def testFilterFusionTuple(self):
|
||||
take_all = lambda x, y: constant_op.constant(True)
|
||||
is_zero = lambda x, y: math_ops.equal(x * math_ops.cast(y, dtypes.int64), 0)
|
||||
|
||||
self._testFilterFusion(lambda x: (x, x), [take_all, take_all])
|
||||
self._testFilterFusion(lambda x: (x, 2), [take_all, is_zero])
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test.main()
|
||||
|
@ -17,6 +17,8 @@ from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import functools
|
||||
|
||||
from absl.testing import parameterized
|
||||
|
||||
from tensorflow.python.data.experimental.ops import testing
|
||||
@ -33,6 +35,36 @@ from tensorflow.python.ops import random_ops
|
||||
from tensorflow.python.platform import test
|
||||
|
||||
|
||||
def _test_combinations():
|
||||
def random(_):
|
||||
return random_ops.random_uniform([],
|
||||
minval=1,
|
||||
maxval=10,
|
||||
dtype=dtypes.float32,
|
||||
seed=42)
|
||||
|
||||
def random_with_assert(x):
|
||||
y = random(x)
|
||||
assert_op = control_flow_ops.Assert(math_ops.greater_equal(y, 1), [y])
|
||||
with ops.control_dependencies([assert_op]):
|
||||
return y
|
||||
|
||||
cases = [
|
||||
("Increment", lambda x: x + 1, False),
|
||||
("Random", random, True),
|
||||
("RandomWithAssert", random_with_assert, True),
|
||||
("Complex", lambda x: (random(x) + random(x)) / 2, False),
|
||||
]
|
||||
|
||||
def reduce_fn(x, y):
|
||||
name, map_fn, should_optimize = y
|
||||
return x + combinations.combine(
|
||||
map_fn=combinations.NamedObject(name, map_fn),
|
||||
should_optimize=should_optimize)
|
||||
|
||||
return functools.reduce(reduce_fn, cases, [])
|
||||
|
||||
|
||||
class HoistRandomUniformTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
|
||||
def _testDataset(self, dataset):
|
||||
@ -51,10 +83,13 @@ class HoistRandomUniformTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
self.evaluate(get_next())
|
||||
|
||||
def _testHoistFunction(self, function, should_optimize):
|
||||
@combinations.generate(
|
||||
combinations.times(test_base.default_test_combinations(),
|
||||
_test_combinations()))
|
||||
def testHoistFunction(self, map_fn, should_optimize):
|
||||
dataset = dataset_ops.Dataset.range(5).apply(
|
||||
testing.assert_next(
|
||||
["Zip[0]", "Map"] if should_optimize else ["Map"])).map(function)
|
||||
["Zip[0]", "Map"] if should_optimize else ["Map"])).map(map_fn)
|
||||
|
||||
options = dataset_ops.Options()
|
||||
options.experimental_optimization.apply_default_optimizations = False
|
||||
@ -62,31 +97,6 @@ class HoistRandomUniformTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
dataset = dataset.with_options(options)
|
||||
self._testDataset(dataset)
|
||||
|
||||
@combinations.generate(test_base.default_test_combinations())
|
||||
def testNoRandom(self):
|
||||
self._testHoistFunction(lambda x: x + 1, should_optimize=False)
|
||||
|
||||
@combinations.generate(test_base.default_test_combinations())
|
||||
def testRandom(self):
|
||||
|
||||
def random(_):
|
||||
return random_ops.random_uniform([],
|
||||
minval=1,
|
||||
maxval=10,
|
||||
dtype=dtypes.float32,
|
||||
seed=42)
|
||||
|
||||
def random_with_assert(x):
|
||||
y = random(x)
|
||||
assert_op = control_flow_ops.Assert(math_ops.greater_equal(y, 1), [y])
|
||||
with ops.control_dependencies([assert_op]):
|
||||
return y
|
||||
|
||||
self._testHoistFunction(random, should_optimize=True)
|
||||
self._testHoistFunction(random_with_assert, should_optimize=True)
|
||||
self._testHoistFunction(
|
||||
lambda x: (random(x) + random(x)) / 2, should_optimize=False)
|
||||
|
||||
@combinations.generate(test_base.default_test_combinations())
|
||||
def testCapturedInputs(self):
|
||||
a = constant_op.constant(1, dtype=dtypes.float32)
|
||||
|
@ -17,6 +17,8 @@ from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import functools
|
||||
|
||||
from absl.testing import parameterized
|
||||
|
||||
from tensorflow.python.data.experimental.ops import testing
|
||||
@ -29,6 +31,49 @@ from tensorflow.python.ops import math_ops
|
||||
from tensorflow.python.platform import test
|
||||
|
||||
|
||||
def _test_combinations():
|
||||
cases = []
|
||||
|
||||
identity = lambda x: x
|
||||
increment = lambda x: x + 1
|
||||
minus_five = lambda x: x - 5
|
||||
|
||||
def increment_and_square(x):
|
||||
y = x + 1
|
||||
return y * y
|
||||
|
||||
functions = [identity, increment, minus_five, increment_and_square]
|
||||
|
||||
take_all = lambda x: constant_op.constant(True)
|
||||
is_zero = lambda x: math_ops.equal(x, 0)
|
||||
is_odd = lambda x: math_ops.equal(x % 2, 0)
|
||||
greater = lambda x: math_ops.greater(x + 5, 0)
|
||||
predicates = [take_all, is_zero, is_odd, greater]
|
||||
|
||||
for i, function in enumerate(functions):
|
||||
for j, predicate in enumerate(predicates):
|
||||
cases.append((function, "Scalar{}{}".format(i, j), predicate))
|
||||
|
||||
replicate = lambda x: (x, x)
|
||||
with_two = lambda x: (x, 2)
|
||||
functions = [replicate, with_two]
|
||||
take_all = lambda x, y: constant_op.constant(True)
|
||||
is_zero = lambda x, y: math_ops.equal(x * math_ops.cast(y, dtypes.int64), 0)
|
||||
predicates = [take_all, is_zero]
|
||||
|
||||
for i, function in enumerate(functions):
|
||||
for j, predicate in enumerate(predicates):
|
||||
cases.append((function, "Tuple{}{}".format(i, j), predicate))
|
||||
|
||||
def reduce_fn(x, y):
|
||||
function, name, predicate = y
|
||||
return x + combinations.combine(
|
||||
function=function,
|
||||
predicate=combinations.NamedObject(name, predicate))
|
||||
|
||||
return functools.reduce(reduce_fn, cases, [])
|
||||
|
||||
|
||||
class MapAndFilterFusionTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
|
||||
def _testDataset(self, dataset, function, predicate):
|
||||
@ -43,7 +88,10 @@ class MapAndFilterFusionTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
expected_output.append(r)
|
||||
self.assertDatasetProduces(dataset, expected_output=expected_output)
|
||||
|
||||
def _testMapAndFilterFusion(self, function, predicate):
|
||||
@combinations.generate(
|
||||
combinations.times(test_base.default_test_combinations(),
|
||||
_test_combinations()))
|
||||
def testMapAndFilterFusion(self, function, predicate):
|
||||
dataset = dataset_ops.Dataset.range(10).apply(
|
||||
testing.assert_next(["Map", "Filter",
|
||||
"Map"])).map(function).filter(predicate)
|
||||
@ -53,41 +101,6 @@ class MapAndFilterFusionTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
dataset = dataset.with_options(options)
|
||||
self._testDataset(dataset, function, predicate)
|
||||
|
||||
@combinations.generate(test_base.default_test_combinations())
|
||||
def testMapAndFilterFusionScalar(self):
|
||||
identity = lambda x: x
|
||||
increment = lambda x: x + 1
|
||||
minus_five = lambda x: x - 5
|
||||
|
||||
def increment_and_square(x):
|
||||
y = x + 1
|
||||
return y * y
|
||||
|
||||
functions = [identity, increment, minus_five, increment_and_square]
|
||||
|
||||
take_all = lambda x: constant_op.constant(True)
|
||||
is_zero = lambda x: math_ops.equal(x, 0)
|
||||
is_odd = lambda x: math_ops.equal(x % 2, 0)
|
||||
greater = lambda x: math_ops.greater(x + 5, 0)
|
||||
predicates = [take_all, is_zero, is_odd, greater]
|
||||
|
||||
for function in functions:
|
||||
for predicate in predicates:
|
||||
self._testMapAndFilterFusion(function, predicate)
|
||||
|
||||
@combinations.generate(test_base.default_test_combinations())
|
||||
def testMapAndFilterFusionTuple(self):
|
||||
replicate = lambda x: (x, x)
|
||||
with_two = lambda x: (x, 2)
|
||||
functions = [replicate, with_two]
|
||||
take_all = lambda x, y: constant_op.constant(True)
|
||||
is_zero = lambda x, y: math_ops.equal(x * math_ops.cast(y, dtypes.int64), 0)
|
||||
predicates = [take_all, is_zero]
|
||||
|
||||
for function in functions:
|
||||
for predicate in predicates:
|
||||
self._testMapAndFilterFusion(function, predicate)
|
||||
|
||||
@combinations.generate(test_base.default_test_combinations())
|
||||
def testCapturedInputs(self):
|
||||
a = constant_op.constant(3, dtype=dtypes.int64)
|
||||
|
@ -17,6 +17,8 @@ from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import functools
|
||||
|
||||
from absl.testing import parameterized
|
||||
|
||||
from tensorflow.python.data.experimental.ops import testing
|
||||
@ -26,9 +28,44 @@ from tensorflow.python.framework import combinations
|
||||
from tensorflow.python.platform import test
|
||||
|
||||
|
||||
def _test_combinations():
|
||||
cases = []
|
||||
|
||||
identity = lambda x: x
|
||||
increment = lambda x: x + 1
|
||||
|
||||
def increment_and_square(x):
|
||||
y = x + 1
|
||||
return y * y
|
||||
|
||||
functions = [identity, increment, increment_and_square]
|
||||
|
||||
for i, x in enumerate(functions):
|
||||
for j, y in enumerate(functions):
|
||||
cases.append(("Scalar{}{}".format(i, j), [x, y]))
|
||||
for k, z in enumerate(functions):
|
||||
cases.append(("Scalar{}{}{}".format(i, j, k), [x, y, z]))
|
||||
|
||||
with_42 = lambda x: (x, 42)
|
||||
swap = lambda x, y: (y, x)
|
||||
|
||||
cases.append(("Tuple1", [with_42, swap]))
|
||||
cases.append(("Tuple2", [with_42, swap, swap]))
|
||||
|
||||
def reduce_fn(x, y):
|
||||
name, functions = y
|
||||
return x + combinations.combine(
|
||||
functions=combinations.NamedObject(name, functions))
|
||||
|
||||
return functools.reduce(reduce_fn, cases, [])
|
||||
|
||||
|
||||
class MapFusionTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
|
||||
def _testMapFusion(self, functions):
|
||||
@combinations.generate(
|
||||
combinations.times(test_base.default_test_combinations(),
|
||||
_test_combinations()))
|
||||
def testMapFusion(self, functions):
|
||||
dataset = dataset_ops.Dataset.range(5).apply(
|
||||
testing.assert_next(["Map", "MemoryCacheImpl"]))
|
||||
for function in functions:
|
||||
@ -50,31 +87,6 @@ class MapFusionTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
expected_output.append(r)
|
||||
self.assertDatasetProduces(dataset, expected_output=expected_output)
|
||||
|
||||
@combinations.generate(test_base.default_test_combinations())
|
||||
def testMapFusionScalar(self):
|
||||
identity = lambda x: x
|
||||
increment = lambda x: x + 1
|
||||
|
||||
def increment_and_square(x):
|
||||
y = x + 1
|
||||
return y * y
|
||||
|
||||
functions = [identity, increment, increment_and_square]
|
||||
|
||||
for x in functions:
|
||||
for y in functions:
|
||||
self._testMapFusion([x, y])
|
||||
for z in functions:
|
||||
self._testMapFusion([x, y, z])
|
||||
|
||||
@combinations.generate(test_base.default_test_combinations())
|
||||
def testMapAndFilterFusionTuple(self):
|
||||
with_42 = lambda x: (x, 42)
|
||||
swap = lambda x, y: (y, x)
|
||||
|
||||
self._testMapFusion([with_42, swap])
|
||||
self._testMapFusion([with_42, swap, swap])
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test.main()
|
||||
|
@ -17,6 +17,8 @@ from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import functools
|
||||
|
||||
from absl.testing import parameterized
|
||||
|
||||
from tensorflow.python.data.experimental.ops import testing
|
||||
@ -32,9 +34,33 @@ from tensorflow.python.ops import variables
|
||||
from tensorflow.python.platform import test
|
||||
|
||||
|
||||
def _test_combinations():
|
||||
def assert_greater(x):
|
||||
assert_op = control_flow_ops.Assert(math_ops.greater(x, -1), [x])
|
||||
with ops.control_dependencies([assert_op]):
|
||||
return x
|
||||
|
||||
cases = [
|
||||
("Identity", lambda x: x, True),
|
||||
("Increment", lambda x: x + 1, True),
|
||||
("AssertGreater", assert_greater, True),
|
||||
]
|
||||
|
||||
def reduce_fn(x, y):
|
||||
name, function, should_optimize = y
|
||||
return x + combinations.combine(
|
||||
function=combinations.NamedObject(name, function),
|
||||
should_optimize=should_optimize)
|
||||
|
||||
return functools.reduce(reduce_fn, cases, [])
|
||||
|
||||
|
||||
class MapParallelizationTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
|
||||
def _testMapParallelization(self, function, should_optimize):
|
||||
@combinations.generate(
|
||||
combinations.times(test_base.default_test_combinations(),
|
||||
_test_combinations()))
|
||||
def testMapParallelization(self, function, should_optimize):
|
||||
next_nodes = ["ParallelMap"] if should_optimize else ["Map"]
|
||||
dataset = dataset_ops.Dataset.range(5).apply(
|
||||
testing.assert_next(next_nodes)).map(function)
|
||||
@ -45,24 +71,6 @@ class MapParallelizationTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
self.assertDatasetProduces(
|
||||
dataset, expected_output=[function(x) for x in range(5)])
|
||||
|
||||
@combinations.generate(test_base.default_test_combinations())
|
||||
def testIdentity(self):
|
||||
self._testMapParallelization(lambda x: x, should_optimize=True)
|
||||
|
||||
@combinations.generate(test_base.default_test_combinations())
|
||||
def testIncrement(self):
|
||||
self._testMapParallelization(lambda x: x + 1, should_optimize=True)
|
||||
|
||||
@combinations.generate(test_base.default_test_combinations())
|
||||
def testAssert(self):
|
||||
|
||||
def assert_greater(x):
|
||||
assert_op = control_flow_ops.Assert(math_ops.greater(x, -1), [x])
|
||||
with ops.control_dependencies([assert_op]):
|
||||
return x
|
||||
|
||||
self._testMapParallelization(assert_greater, should_optimize=True)
|
||||
|
||||
@combinations.generate(test_base.default_test_combinations())
|
||||
def testCapturedConstant(self):
|
||||
captured_t = constant_op.constant(42, dtype=dtypes.int64)
|
||||
|
@ -400,6 +400,9 @@ class NamedObject(object):
|
||||
def __call__(self, *args, **kwargs):
|
||||
return self._obj(*args, **kwargs)
|
||||
|
||||
def __iter__(self):
|
||||
return self._obj.__iter__()
|
||||
|
||||
def __repr__(self):
|
||||
return self._name
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user