Put distribution and cluster parameter modifiers in their own combination

It's easier to add more strategy/cluster modifiers.

PiperOrigin-RevId: 324859132
Change-Id: Ided1a3d1106d71e86335ca3c59d7698550fa2155
This commit is contained in:
Ran Chen 2020-08-04 11:56:38 -07:00 committed by TensorFlower Gardener
parent 6acd86d539
commit e1e2645945
4 changed files with 42 additions and 50 deletions

View File

@ -90,8 +90,8 @@ class SavedModelTest(scripts.TestCase, parameterized.TestCase):
retrain_flag_value=["true", "false"],
regularization_loss_multiplier=[None, 2], # Test for b/134528831.
)),
test_combinations=(distribute_combinations.NamedGPUCombination(),
distribute_combinations.NamedTPUCombination()))
test_combinations=(distribute_combinations.GPUCombination(),
distribute_combinations.TPUCombination()))
@combinations.generate(**TEST_MNIST_CNN_GENERATE_KWARGS)
def test_mnist_cnn(self, use_keras_save_api, named_strategy,

View File

@ -99,7 +99,24 @@ class ClusterParameters(combinations_lib.ParameterModifier):
return update
class NamedGPUCombination(combinations_lib.TestCombination):
class DistributionCombination(combinations_lib.TestCombination):
"""Sets up distribution strategy for tests."""
def parameter_modifiers(self):
return [
DistributionParameter(),
combinations_lib.OptionalParameter("use_var_policy"),
]
class ClusterCombination(combinations_lib.TestCombination):
"""Sets up multi worker tests."""
def parameter_modifiers(self):
return [ClusterParameters()]
class GPUCombination(combinations_lib.TestCombination):
"""Enable tests to request GPU hardware and skip non-GPU combinations.
This class expects test_combinations to be generated with `NamedDistribution`
@ -141,17 +158,7 @@ class NamedGPUCombination(combinations_lib.TestCombination):
return [combinations_lib.OptionalParameter("required_gpus")]
class GPUCombination(NamedGPUCombination):
"""NamedGPUCombination that passes `tf.distribute.Strategy` to the tests."""
def parameter_modifiers(self):
return [
ClusterParameters(),
DistributionParameter(),
] + NamedGPUCombination.parameter_modifiers(self)
class NamedTPUCombination(combinations_lib.TestCombination):
class TPUCombination(combinations_lib.TestCombination):
"""Allow to request TPU hardware and skip non-TPU combinations.
This class expects test_combinations to be generated with `NamedDistribution`
@ -213,16 +220,6 @@ class NamedTPUCombination(combinations_lib.TestCombination):
]
class TPUCombination(NamedTPUCombination):
"""NamedTPUCombination that passes `tf.distribute.Strategy` to the tests."""
def parameter_modifiers(self):
return [
ClusterParameters(),
DistributionParameter(),
] + NamedTPUCombination.parameter_modifiers(self)
class NamedDistribution(object):
"""Wraps a `tf.distribute.Strategy` and adds a name for test titles."""
@ -304,6 +301,8 @@ def generate(combinations, test_combinations=()):
default_combinations = (
framework_combinations.EagerGraphCombination(),
framework_combinations.TFVersionCombination(),
ClusterCombination(),
DistributionCombination(),
GPUCombination(),
TPUCombination(),
)

View File

@ -30,7 +30,7 @@ from tensorflow.python.framework import combinations as framework_combinations
from tensorflow.python.platform import test
class ClusterParametersTest(test.TestCase, parameterized.TestCase):
class ClusterCombinationTest(test.TestCase, parameterized.TestCase):
# For this test we need to use `framework.test_combinations` because our
# `generate` eats the cluster parameters.
#
@ -42,7 +42,7 @@ class ClusterParametersTest(test.TestCase, parameterized.TestCase):
combinations.NamedDistribution(
"HasClusterParams", lambda: None, has_chief=True, num_workers=2),
]),
test_combinations=(combinations.GPUCombination(),))
test_combinations=(combinations.ClusterCombination(),))
def testClusterParams(self, distribution, has_chief, num_workers):
self.assertTrue(has_chief)
self.assertEqual(num_workers, 2)
@ -51,14 +51,14 @@ class ClusterParametersTest(test.TestCase, parameterized.TestCase):
framework_combinations.combine(distribution=[
combinations.NamedDistribution("NoClusterParams", lambda: None),
]),
test_combinations=(combinations.GPUCombination(),))
test_combinations=(combinations.ClusterCombination(),))
def testClusterParamsHasDefault(self, distribution, has_chief, num_workers):
self.assertFalse(has_chief)
self.assertEqual(num_workers, 1)
@framework_combinations.generate(
framework_combinations.combine(v=1),
test_combinations=(combinations.GPUCombination(),))
test_combinations=(combinations.ClusterCombination(),))
def testClusterParamsNoStrategy(self, v, has_chief, num_workers):
self.assertFalse(has_chief)
self.assertEqual(num_workers, 1)
@ -69,7 +69,7 @@ class ClusterParametersTest(test.TestCase, parameterized.TestCase):
"WithClusterParams", lambda: None, has_chief=True, num_workers=2),
combinations.NamedDistribution("WithoutClusterParams", lambda: None),
]),
test_combinations=(combinations.GPUCombination(),))
test_combinations=(combinations.ClusterCombination(),))
def testClusterParamsAreOptional(self, distribution):
# If combinations library doesn't raise an exception, the test is passed.
pass
@ -83,7 +83,7 @@ class ClusterParametersTest(test.TestCase, parameterized.TestCase):
ds3=combinations.NamedDistribution(
"Strategy3", lambda: None, has_chief=True, num_workers=0),
),
test_combinations=(combinations.GPUCombination(),))
test_combinations=(combinations.ClusterCombination(),))
def testMultipleDistributionSingleWorker(self, ds1, ds2, ds3):
# If combinations library doesn't raise an exception, the test is passed.
pass
@ -101,7 +101,7 @@ class ClusterParametersShouldFailTest(test.TestCase, parameterized.TestCase):
ds2=combinations.NamedDistribution(
"Strategy2", lambda: None, has_chief=True, num_workers=2),
),
test_combinations=(combinations.GPUCombination(),))
test_combinations=(combinations.ClusterCombination(),))
def testMultipleDistributionMultiWorker(self, ds1, ds2):
# combinations library should raise an exception.
pass

View File

@ -95,8 +95,7 @@ class OnWriteVariableSync(test.TestCase, parameterized.TestCase):
sess.run({"complicated": mirrored})
@combinations.generate(strategy_and_run_tf_function_combinations())
def testAssign(self, distribution, experimental_run_tf_function,
use_var_policy):
def testAssign(self, distribution, experimental_run_tf_function):
def assign(fn, v, update_value, cross_replica):
update_fn = lambda: getattr(v, fn)(update_value)
@ -136,8 +135,7 @@ class OnWriteVariableSync(test.TestCase, parameterized.TestCase):
self.evaluate(array_ops.ones_like(component)))
@combinations.generate(strategy_and_run_tf_function_combinations())
def testAssignOnWriteVar(self, distribution, experimental_run_tf_function,
use_var_policy):
def testAssignOnWriteVar(self, distribution, experimental_run_tf_function):
with distribution.scope():
v_to_assign = variable_scope.variable(
@ -182,8 +180,7 @@ class OnWriteVariableSync(test.TestCase, parameterized.TestCase):
self.assertAllEqual(2.0, self.evaluate(component.read_value()))
@combinations.generate(strategy_and_run_tf_function_combinations())
def testAssignPerReplicaVal(self, distribution, experimental_run_tf_function,
use_var_policy):
def testAssignPerReplicaVal(self, distribution, experimental_run_tf_function):
if isinstance(distribution, _TPU_STRATEGIES):
self.skipTest("Assigning PerReplica values is not supported. See"
@ -241,7 +238,7 @@ class OnWriteVariableSync(test.TestCase, parameterized.TestCase):
self.assertAllEqual(expected, self.evaluate(component.read_value()))
@combinations.generate(strategy_with_var_policy())
def testValueInReplicaContext(self, distribution, use_var_policy):
def testValueInReplicaContext(self, distribution):
with distribution.scope():
v = variables_lib.Variable(
1., aggregation=variables_lib.VariableAggregation.MEAN)
@ -260,8 +257,7 @@ class OnWriteVariableSync(test.TestCase, parameterized.TestCase):
@combinations.generate(strategy_and_run_tf_function_combinations())
def testReadValueInReplicaContext(self, distribution,
experimental_run_tf_function,
use_var_policy):
experimental_run_tf_function):
aggregations = [
variables_lib.VariableAggregation.NONE,
variables_lib.VariableAggregation.SUM,
@ -286,8 +282,7 @@ class OnWriteVariableSync(test.TestCase, parameterized.TestCase):
@combinations.generate(strategy_and_run_tf_function_combinations())
def testReadValueInCrossReplicaContext(self, distribution,
experimental_run_tf_function,
use_var_policy):
experimental_run_tf_function):
aggregations = [
variables_lib.VariableAggregation.NONE,
variables_lib.VariableAggregation.SUM,
@ -312,7 +307,7 @@ class OnWriteVariableSync(test.TestCase, parameterized.TestCase):
self.evaluate(results))
@combinations.generate(strategy_with_var_policy())
def testAssignOutOfScope(self, distribution, use_var_policy):
def testAssignOutOfScope(self, distribution):
with distribution.scope():
mirrored = variables_lib.Variable(1.)
self.evaluate(mirrored.assign(3.))
@ -321,8 +316,7 @@ class OnWriteVariableSync(test.TestCase, parameterized.TestCase):
self.assertEqual(self.evaluate(component.read_value()), 3.)
@combinations.generate(strategy_with_var_policy())
def testAssignAggregationMeanDTypeNonFloat(self, distribution,
use_var_policy):
def testAssignAggregationMeanDTypeNonFloat(self, distribution):
if isinstance(distribution, _TPU_STRATEGIES):
self.skipTest("Fix sponge/6e8ab540-4c0f-4da5-aedf-86505ff810c9 before "
"reenabling test.")
@ -379,8 +373,7 @@ class OnWriteVariableSync(test.TestCase, parameterized.TestCase):
self.assertEqual(self.evaluate(v.read_value()), 4)
@combinations.generate(strategy_with_var_policy())
def testInitializedToSameValueInsideEagerRun(self, distribution,
use_var_policy):
def testInitializedToSameValueInsideEagerRun(self, distribution):
if not context.executing_eagerly(): self.skipTest("eager only test")
v = [None]
@ -399,7 +392,7 @@ class OnWriteVariableSync(test.TestCase, parameterized.TestCase):
self.assertAllEqual(vals[0], vals[1])
@combinations.generate(strategy_with_var_policy())
def testAggregationOnlyFirstReplica(self, distribution, use_var_policy):
def testAggregationOnlyFirstReplica(self, distribution):
with distribution.scope():
v = variable_scope.variable(
15.,
@ -420,7 +413,7 @@ class OnWriteVariableSync(test.TestCase, parameterized.TestCase):
per_replica_results)
@combinations.generate(strategy_with_var_policy())
def testInitScope(self, distribution, use_var_policy):
def testInitScope(self, distribution):
if not context.executing_eagerly(): self.skipTest("eager only")
class C(object):
@ -448,7 +441,7 @@ class OnWriteVariableSync(test.TestCase, parameterized.TestCase):
self.assertAllEqual([2, 2], per_replica_results)
@combinations.generate(strategy_with_var_policy())
def testOperatorOverride(self, distribution, use_var_policy):
def testOperatorOverride(self, distribution):
with distribution.scope():
v = variable_scope.variable(