Put distribution and cluster parameter modifiers in their own combination

It's easier to add more strategy/cluster modifiers.

PiperOrigin-RevId: 324859132
Change-Id: Ided1a3d1106d71e86335ca3c59d7698550fa2155
This commit is contained in:
Ran Chen 2020-08-04 11:56:38 -07:00 committed by TensorFlower Gardener
parent 6acd86d539
commit e1e2645945
4 changed files with 42 additions and 50 deletions

View File

@ -90,8 +90,8 @@ class SavedModelTest(scripts.TestCase, parameterized.TestCase):
retrain_flag_value=["true", "false"], retrain_flag_value=["true", "false"],
regularization_loss_multiplier=[None, 2], # Test for b/134528831. regularization_loss_multiplier=[None, 2], # Test for b/134528831.
)), )),
test_combinations=(distribute_combinations.NamedGPUCombination(), test_combinations=(distribute_combinations.GPUCombination(),
distribute_combinations.NamedTPUCombination())) distribute_combinations.TPUCombination()))
@combinations.generate(**TEST_MNIST_CNN_GENERATE_KWARGS) @combinations.generate(**TEST_MNIST_CNN_GENERATE_KWARGS)
def test_mnist_cnn(self, use_keras_save_api, named_strategy, def test_mnist_cnn(self, use_keras_save_api, named_strategy,

View File

@ -99,7 +99,24 @@ class ClusterParameters(combinations_lib.ParameterModifier):
return update return update
class NamedGPUCombination(combinations_lib.TestCombination): class DistributionCombination(combinations_lib.TestCombination):
"""Sets up distribution strategy for tests."""
def parameter_modifiers(self):
return [
DistributionParameter(),
combinations_lib.OptionalParameter("use_var_policy"),
]
class ClusterCombination(combinations_lib.TestCombination):
"""Sets up multi worker tests."""
def parameter_modifiers(self):
return [ClusterParameters()]
class GPUCombination(combinations_lib.TestCombination):
"""Enable tests to request GPU hardware and skip non-GPU combinations. """Enable tests to request GPU hardware and skip non-GPU combinations.
This class expects test_combinations to be generated with `NamedDistribution` This class expects test_combinations to be generated with `NamedDistribution`
@ -141,17 +158,7 @@ class NamedGPUCombination(combinations_lib.TestCombination):
return [combinations_lib.OptionalParameter("required_gpus")] return [combinations_lib.OptionalParameter("required_gpus")]
class GPUCombination(NamedGPUCombination): class TPUCombination(combinations_lib.TestCombination):
"""NamedGPUCombination that passes `tf.distribute.Strategy` to the tests."""
def parameter_modifiers(self):
return [
ClusterParameters(),
DistributionParameter(),
] + NamedGPUCombination.parameter_modifiers(self)
class NamedTPUCombination(combinations_lib.TestCombination):
"""Allow to request TPU hardware and skip non-TPU combinations. """Allow to request TPU hardware and skip non-TPU combinations.
This class expects test_combinations to be generated with `NamedDistribution` This class expects test_combinations to be generated with `NamedDistribution`
@ -213,16 +220,6 @@ class NamedTPUCombination(combinations_lib.TestCombination):
] ]
class TPUCombination(NamedTPUCombination):
"""NamedTPUCombination that passes `tf.distribute.Strategy` to the tests."""
def parameter_modifiers(self):
return [
ClusterParameters(),
DistributionParameter(),
] + NamedTPUCombination.parameter_modifiers(self)
class NamedDistribution(object): class NamedDistribution(object):
"""Wraps a `tf.distribute.Strategy` and adds a name for test titles.""" """Wraps a `tf.distribute.Strategy` and adds a name for test titles."""
@ -304,6 +301,8 @@ def generate(combinations, test_combinations=()):
default_combinations = ( default_combinations = (
framework_combinations.EagerGraphCombination(), framework_combinations.EagerGraphCombination(),
framework_combinations.TFVersionCombination(), framework_combinations.TFVersionCombination(),
ClusterCombination(),
DistributionCombination(),
GPUCombination(), GPUCombination(),
TPUCombination(), TPUCombination(),
) )

View File

@ -30,7 +30,7 @@ from tensorflow.python.framework import combinations as framework_combinations
from tensorflow.python.platform import test from tensorflow.python.platform import test
class ClusterParametersTest(test.TestCase, parameterized.TestCase): class ClusterCombinationTest(test.TestCase, parameterized.TestCase):
# For this test we need to use `framework.test_combinations` because our # For this test we need to use `framework.test_combinations` because our
# `generate` eats the cluster parameters. # `generate` eats the cluster parameters.
# #
@ -42,7 +42,7 @@ class ClusterParametersTest(test.TestCase, parameterized.TestCase):
combinations.NamedDistribution( combinations.NamedDistribution(
"HasClusterParams", lambda: None, has_chief=True, num_workers=2), "HasClusterParams", lambda: None, has_chief=True, num_workers=2),
]), ]),
test_combinations=(combinations.GPUCombination(),)) test_combinations=(combinations.ClusterCombination(),))
def testClusterParams(self, distribution, has_chief, num_workers): def testClusterParams(self, distribution, has_chief, num_workers):
self.assertTrue(has_chief) self.assertTrue(has_chief)
self.assertEqual(num_workers, 2) self.assertEqual(num_workers, 2)
@ -51,14 +51,14 @@ class ClusterParametersTest(test.TestCase, parameterized.TestCase):
framework_combinations.combine(distribution=[ framework_combinations.combine(distribution=[
combinations.NamedDistribution("NoClusterParams", lambda: None), combinations.NamedDistribution("NoClusterParams", lambda: None),
]), ]),
test_combinations=(combinations.GPUCombination(),)) test_combinations=(combinations.ClusterCombination(),))
def testClusterParamsHasDefault(self, distribution, has_chief, num_workers): def testClusterParamsHasDefault(self, distribution, has_chief, num_workers):
self.assertFalse(has_chief) self.assertFalse(has_chief)
self.assertEqual(num_workers, 1) self.assertEqual(num_workers, 1)
@framework_combinations.generate( @framework_combinations.generate(
framework_combinations.combine(v=1), framework_combinations.combine(v=1),
test_combinations=(combinations.GPUCombination(),)) test_combinations=(combinations.ClusterCombination(),))
def testClusterParamsNoStrategy(self, v, has_chief, num_workers): def testClusterParamsNoStrategy(self, v, has_chief, num_workers):
self.assertFalse(has_chief) self.assertFalse(has_chief)
self.assertEqual(num_workers, 1) self.assertEqual(num_workers, 1)
@ -69,7 +69,7 @@ class ClusterParametersTest(test.TestCase, parameterized.TestCase):
"WithClusterParams", lambda: None, has_chief=True, num_workers=2), "WithClusterParams", lambda: None, has_chief=True, num_workers=2),
combinations.NamedDistribution("WithoutClusterParams", lambda: None), combinations.NamedDistribution("WithoutClusterParams", lambda: None),
]), ]),
test_combinations=(combinations.GPUCombination(),)) test_combinations=(combinations.ClusterCombination(),))
def testClusterParamsAreOptional(self, distribution): def testClusterParamsAreOptional(self, distribution):
# If combinations library doesn't raise an exception, the test is passed. # If combinations library doesn't raise an exception, the test is passed.
pass pass
@ -83,7 +83,7 @@ class ClusterParametersTest(test.TestCase, parameterized.TestCase):
ds3=combinations.NamedDistribution( ds3=combinations.NamedDistribution(
"Strategy3", lambda: None, has_chief=True, num_workers=0), "Strategy3", lambda: None, has_chief=True, num_workers=0),
), ),
test_combinations=(combinations.GPUCombination(),)) test_combinations=(combinations.ClusterCombination(),))
def testMultipleDistributionSingleWorker(self, ds1, ds2, ds3): def testMultipleDistributionSingleWorker(self, ds1, ds2, ds3):
# If combinations library doesn't raise an exception, the test is passed. # If combinations library doesn't raise an exception, the test is passed.
pass pass
@ -101,7 +101,7 @@ class ClusterParametersShouldFailTest(test.TestCase, parameterized.TestCase):
ds2=combinations.NamedDistribution( ds2=combinations.NamedDistribution(
"Strategy2", lambda: None, has_chief=True, num_workers=2), "Strategy2", lambda: None, has_chief=True, num_workers=2),
), ),
test_combinations=(combinations.GPUCombination(),)) test_combinations=(combinations.ClusterCombination(),))
def testMultipleDistributionMultiWorker(self, ds1, ds2): def testMultipleDistributionMultiWorker(self, ds1, ds2):
# combinations library should raise an exception. # combinations library should raise an exception.
pass pass

View File

@ -95,8 +95,7 @@ class OnWriteVariableSync(test.TestCase, parameterized.TestCase):
sess.run({"complicated": mirrored}) sess.run({"complicated": mirrored})
@combinations.generate(strategy_and_run_tf_function_combinations()) @combinations.generate(strategy_and_run_tf_function_combinations())
def testAssign(self, distribution, experimental_run_tf_function, def testAssign(self, distribution, experimental_run_tf_function):
use_var_policy):
def assign(fn, v, update_value, cross_replica): def assign(fn, v, update_value, cross_replica):
update_fn = lambda: getattr(v, fn)(update_value) update_fn = lambda: getattr(v, fn)(update_value)
@ -136,8 +135,7 @@ class OnWriteVariableSync(test.TestCase, parameterized.TestCase):
self.evaluate(array_ops.ones_like(component))) self.evaluate(array_ops.ones_like(component)))
@combinations.generate(strategy_and_run_tf_function_combinations()) @combinations.generate(strategy_and_run_tf_function_combinations())
def testAssignOnWriteVar(self, distribution, experimental_run_tf_function, def testAssignOnWriteVar(self, distribution, experimental_run_tf_function):
use_var_policy):
with distribution.scope(): with distribution.scope():
v_to_assign = variable_scope.variable( v_to_assign = variable_scope.variable(
@ -182,8 +180,7 @@ class OnWriteVariableSync(test.TestCase, parameterized.TestCase):
self.assertAllEqual(2.0, self.evaluate(component.read_value())) self.assertAllEqual(2.0, self.evaluate(component.read_value()))
@combinations.generate(strategy_and_run_tf_function_combinations()) @combinations.generate(strategy_and_run_tf_function_combinations())
def testAssignPerReplicaVal(self, distribution, experimental_run_tf_function, def testAssignPerReplicaVal(self, distribution, experimental_run_tf_function):
use_var_policy):
if isinstance(distribution, _TPU_STRATEGIES): if isinstance(distribution, _TPU_STRATEGIES):
self.skipTest("Assigning PerReplica values is not supported. See" self.skipTest("Assigning PerReplica values is not supported. See"
@ -241,7 +238,7 @@ class OnWriteVariableSync(test.TestCase, parameterized.TestCase):
self.assertAllEqual(expected, self.evaluate(component.read_value())) self.assertAllEqual(expected, self.evaluate(component.read_value()))
@combinations.generate(strategy_with_var_policy()) @combinations.generate(strategy_with_var_policy())
def testValueInReplicaContext(self, distribution, use_var_policy): def testValueInReplicaContext(self, distribution):
with distribution.scope(): with distribution.scope():
v = variables_lib.Variable( v = variables_lib.Variable(
1., aggregation=variables_lib.VariableAggregation.MEAN) 1., aggregation=variables_lib.VariableAggregation.MEAN)
@ -260,8 +257,7 @@ class OnWriteVariableSync(test.TestCase, parameterized.TestCase):
@combinations.generate(strategy_and_run_tf_function_combinations()) @combinations.generate(strategy_and_run_tf_function_combinations())
def testReadValueInReplicaContext(self, distribution, def testReadValueInReplicaContext(self, distribution,
experimental_run_tf_function, experimental_run_tf_function):
use_var_policy):
aggregations = [ aggregations = [
variables_lib.VariableAggregation.NONE, variables_lib.VariableAggregation.NONE,
variables_lib.VariableAggregation.SUM, variables_lib.VariableAggregation.SUM,
@ -286,8 +282,7 @@ class OnWriteVariableSync(test.TestCase, parameterized.TestCase):
@combinations.generate(strategy_and_run_tf_function_combinations()) @combinations.generate(strategy_and_run_tf_function_combinations())
def testReadValueInCrossReplicaContext(self, distribution, def testReadValueInCrossReplicaContext(self, distribution,
experimental_run_tf_function, experimental_run_tf_function):
use_var_policy):
aggregations = [ aggregations = [
variables_lib.VariableAggregation.NONE, variables_lib.VariableAggregation.NONE,
variables_lib.VariableAggregation.SUM, variables_lib.VariableAggregation.SUM,
@ -312,7 +307,7 @@ class OnWriteVariableSync(test.TestCase, parameterized.TestCase):
self.evaluate(results)) self.evaluate(results))
@combinations.generate(strategy_with_var_policy()) @combinations.generate(strategy_with_var_policy())
def testAssignOutOfScope(self, distribution, use_var_policy): def testAssignOutOfScope(self, distribution):
with distribution.scope(): with distribution.scope():
mirrored = variables_lib.Variable(1.) mirrored = variables_lib.Variable(1.)
self.evaluate(mirrored.assign(3.)) self.evaluate(mirrored.assign(3.))
@ -321,8 +316,7 @@ class OnWriteVariableSync(test.TestCase, parameterized.TestCase):
self.assertEqual(self.evaluate(component.read_value()), 3.) self.assertEqual(self.evaluate(component.read_value()), 3.)
@combinations.generate(strategy_with_var_policy()) @combinations.generate(strategy_with_var_policy())
def testAssignAggregationMeanDTypeNonFloat(self, distribution, def testAssignAggregationMeanDTypeNonFloat(self, distribution):
use_var_policy):
if isinstance(distribution, _TPU_STRATEGIES): if isinstance(distribution, _TPU_STRATEGIES):
self.skipTest("Fix sponge/6e8ab540-4c0f-4da5-aedf-86505ff810c9 before " self.skipTest("Fix sponge/6e8ab540-4c0f-4da5-aedf-86505ff810c9 before "
"reenabling test.") "reenabling test.")
@ -379,8 +373,7 @@ class OnWriteVariableSync(test.TestCase, parameterized.TestCase):
self.assertEqual(self.evaluate(v.read_value()), 4) self.assertEqual(self.evaluate(v.read_value()), 4)
@combinations.generate(strategy_with_var_policy()) @combinations.generate(strategy_with_var_policy())
def testInitializedToSameValueInsideEagerRun(self, distribution, def testInitializedToSameValueInsideEagerRun(self, distribution):
use_var_policy):
if not context.executing_eagerly(): self.skipTest("eager only test") if not context.executing_eagerly(): self.skipTest("eager only test")
v = [None] v = [None]
@ -399,7 +392,7 @@ class OnWriteVariableSync(test.TestCase, parameterized.TestCase):
self.assertAllEqual(vals[0], vals[1]) self.assertAllEqual(vals[0], vals[1])
@combinations.generate(strategy_with_var_policy()) @combinations.generate(strategy_with_var_policy())
def testAggregationOnlyFirstReplica(self, distribution, use_var_policy): def testAggregationOnlyFirstReplica(self, distribution):
with distribution.scope(): with distribution.scope():
v = variable_scope.variable( v = variable_scope.variable(
15., 15.,
@ -420,7 +413,7 @@ class OnWriteVariableSync(test.TestCase, parameterized.TestCase):
per_replica_results) per_replica_results)
@combinations.generate(strategy_with_var_policy()) @combinations.generate(strategy_with_var_policy())
def testInitScope(self, distribution, use_var_policy): def testInitScope(self, distribution):
if not context.executing_eagerly(): self.skipTest("eager only") if not context.executing_eagerly(): self.skipTest("eager only")
class C(object): class C(object):
@ -448,7 +441,7 @@ class OnWriteVariableSync(test.TestCase, parameterized.TestCase):
self.assertAllEqual([2, 2], per_replica_results) self.assertAllEqual([2, 2], per_replica_results)
@combinations.generate(strategy_with_var_policy()) @combinations.generate(strategy_with_var_policy())
def testOperatorOverride(self, distribution, use_var_policy): def testOperatorOverride(self, distribution):
with distribution.scope(): with distribution.scope():
v = variable_scope.variable( v = variable_scope.variable(