Put distribution and cluster parameter modifiers in their own combination
It's easier to add more strategy/cluster modifiers. PiperOrigin-RevId: 324859132 Change-Id: Ided1a3d1106d71e86335ca3c59d7698550fa2155
This commit is contained in:
parent
6acd86d539
commit
e1e2645945
@ -90,8 +90,8 @@ class SavedModelTest(scripts.TestCase, parameterized.TestCase):
|
||||
retrain_flag_value=["true", "false"],
|
||||
regularization_loss_multiplier=[None, 2], # Test for b/134528831.
|
||||
)),
|
||||
test_combinations=(distribute_combinations.NamedGPUCombination(),
|
||||
distribute_combinations.NamedTPUCombination()))
|
||||
test_combinations=(distribute_combinations.GPUCombination(),
|
||||
distribute_combinations.TPUCombination()))
|
||||
|
||||
@combinations.generate(**TEST_MNIST_CNN_GENERATE_KWARGS)
|
||||
def test_mnist_cnn(self, use_keras_save_api, named_strategy,
|
||||
|
||||
@ -99,7 +99,24 @@ class ClusterParameters(combinations_lib.ParameterModifier):
|
||||
return update
|
||||
|
||||
|
||||
class NamedGPUCombination(combinations_lib.TestCombination):
|
||||
class DistributionCombination(combinations_lib.TestCombination):
|
||||
"""Sets up distribution strategy for tests."""
|
||||
|
||||
def parameter_modifiers(self):
|
||||
return [
|
||||
DistributionParameter(),
|
||||
combinations_lib.OptionalParameter("use_var_policy"),
|
||||
]
|
||||
|
||||
|
||||
class ClusterCombination(combinations_lib.TestCombination):
|
||||
"""Sets up multi worker tests."""
|
||||
|
||||
def parameter_modifiers(self):
|
||||
return [ClusterParameters()]
|
||||
|
||||
|
||||
class GPUCombination(combinations_lib.TestCombination):
|
||||
"""Enable tests to request GPU hardware and skip non-GPU combinations.
|
||||
|
||||
This class expects test_combinations to be generated with `NamedDistribution`
|
||||
@ -141,17 +158,7 @@ class NamedGPUCombination(combinations_lib.TestCombination):
|
||||
return [combinations_lib.OptionalParameter("required_gpus")]
|
||||
|
||||
|
||||
class GPUCombination(NamedGPUCombination):
|
||||
"""NamedGPUCombination that passes `tf.distribute.Strategy` to the tests."""
|
||||
|
||||
def parameter_modifiers(self):
|
||||
return [
|
||||
ClusterParameters(),
|
||||
DistributionParameter(),
|
||||
] + NamedGPUCombination.parameter_modifiers(self)
|
||||
|
||||
|
||||
class NamedTPUCombination(combinations_lib.TestCombination):
|
||||
class TPUCombination(combinations_lib.TestCombination):
|
||||
"""Allow to request TPU hardware and skip non-TPU combinations.
|
||||
|
||||
This class expects test_combinations to be generated with `NamedDistribution`
|
||||
@ -213,16 +220,6 @@ class NamedTPUCombination(combinations_lib.TestCombination):
|
||||
]
|
||||
|
||||
|
||||
class TPUCombination(NamedTPUCombination):
|
||||
"""NamedTPUCombination that passes `tf.distribute.Strategy` to the tests."""
|
||||
|
||||
def parameter_modifiers(self):
|
||||
return [
|
||||
ClusterParameters(),
|
||||
DistributionParameter(),
|
||||
] + NamedTPUCombination.parameter_modifiers(self)
|
||||
|
||||
|
||||
class NamedDistribution(object):
|
||||
"""Wraps a `tf.distribute.Strategy` and adds a name for test titles."""
|
||||
|
||||
@ -304,6 +301,8 @@ def generate(combinations, test_combinations=()):
|
||||
default_combinations = (
|
||||
framework_combinations.EagerGraphCombination(),
|
||||
framework_combinations.TFVersionCombination(),
|
||||
ClusterCombination(),
|
||||
DistributionCombination(),
|
||||
GPUCombination(),
|
||||
TPUCombination(),
|
||||
)
|
||||
|
||||
@ -30,7 +30,7 @@ from tensorflow.python.framework import combinations as framework_combinations
|
||||
from tensorflow.python.platform import test
|
||||
|
||||
|
||||
class ClusterParametersTest(test.TestCase, parameterized.TestCase):
|
||||
class ClusterCombinationTest(test.TestCase, parameterized.TestCase):
|
||||
# For this test we need to use `framework.test_combinations` because our
|
||||
# `generate` eats the cluster parameters.
|
||||
#
|
||||
@ -42,7 +42,7 @@ class ClusterParametersTest(test.TestCase, parameterized.TestCase):
|
||||
combinations.NamedDistribution(
|
||||
"HasClusterParams", lambda: None, has_chief=True, num_workers=2),
|
||||
]),
|
||||
test_combinations=(combinations.GPUCombination(),))
|
||||
test_combinations=(combinations.ClusterCombination(),))
|
||||
def testClusterParams(self, distribution, has_chief, num_workers):
|
||||
self.assertTrue(has_chief)
|
||||
self.assertEqual(num_workers, 2)
|
||||
@ -51,14 +51,14 @@ class ClusterParametersTest(test.TestCase, parameterized.TestCase):
|
||||
framework_combinations.combine(distribution=[
|
||||
combinations.NamedDistribution("NoClusterParams", lambda: None),
|
||||
]),
|
||||
test_combinations=(combinations.GPUCombination(),))
|
||||
test_combinations=(combinations.ClusterCombination(),))
|
||||
def testClusterParamsHasDefault(self, distribution, has_chief, num_workers):
|
||||
self.assertFalse(has_chief)
|
||||
self.assertEqual(num_workers, 1)
|
||||
|
||||
@framework_combinations.generate(
|
||||
framework_combinations.combine(v=1),
|
||||
test_combinations=(combinations.GPUCombination(),))
|
||||
test_combinations=(combinations.ClusterCombination(),))
|
||||
def testClusterParamsNoStrategy(self, v, has_chief, num_workers):
|
||||
self.assertFalse(has_chief)
|
||||
self.assertEqual(num_workers, 1)
|
||||
@ -69,7 +69,7 @@ class ClusterParametersTest(test.TestCase, parameterized.TestCase):
|
||||
"WithClusterParams", lambda: None, has_chief=True, num_workers=2),
|
||||
combinations.NamedDistribution("WithoutClusterParams", lambda: None),
|
||||
]),
|
||||
test_combinations=(combinations.GPUCombination(),))
|
||||
test_combinations=(combinations.ClusterCombination(),))
|
||||
def testClusterParamsAreOptional(self, distribution):
|
||||
# If combinations library doesn't raise an exception, the test is passed.
|
||||
pass
|
||||
@ -83,7 +83,7 @@ class ClusterParametersTest(test.TestCase, parameterized.TestCase):
|
||||
ds3=combinations.NamedDistribution(
|
||||
"Strategy3", lambda: None, has_chief=True, num_workers=0),
|
||||
),
|
||||
test_combinations=(combinations.GPUCombination(),))
|
||||
test_combinations=(combinations.ClusterCombination(),))
|
||||
def testMultipleDistributionSingleWorker(self, ds1, ds2, ds3):
|
||||
# If combinations library doesn't raise an exception, the test is passed.
|
||||
pass
|
||||
@ -101,7 +101,7 @@ class ClusterParametersShouldFailTest(test.TestCase, parameterized.TestCase):
|
||||
ds2=combinations.NamedDistribution(
|
||||
"Strategy2", lambda: None, has_chief=True, num_workers=2),
|
||||
),
|
||||
test_combinations=(combinations.GPUCombination(),))
|
||||
test_combinations=(combinations.ClusterCombination(),))
|
||||
def testMultipleDistributionMultiWorker(self, ds1, ds2):
|
||||
# combinations library should raise an exception.
|
||||
pass
|
||||
|
||||
@ -95,8 +95,7 @@ class OnWriteVariableSync(test.TestCase, parameterized.TestCase):
|
||||
sess.run({"complicated": mirrored})
|
||||
|
||||
@combinations.generate(strategy_and_run_tf_function_combinations())
|
||||
def testAssign(self, distribution, experimental_run_tf_function,
|
||||
use_var_policy):
|
||||
def testAssign(self, distribution, experimental_run_tf_function):
|
||||
|
||||
def assign(fn, v, update_value, cross_replica):
|
||||
update_fn = lambda: getattr(v, fn)(update_value)
|
||||
@ -136,8 +135,7 @@ class OnWriteVariableSync(test.TestCase, parameterized.TestCase):
|
||||
self.evaluate(array_ops.ones_like(component)))
|
||||
|
||||
@combinations.generate(strategy_and_run_tf_function_combinations())
|
||||
def testAssignOnWriteVar(self, distribution, experimental_run_tf_function,
|
||||
use_var_policy):
|
||||
def testAssignOnWriteVar(self, distribution, experimental_run_tf_function):
|
||||
|
||||
with distribution.scope():
|
||||
v_to_assign = variable_scope.variable(
|
||||
@ -182,8 +180,7 @@ class OnWriteVariableSync(test.TestCase, parameterized.TestCase):
|
||||
self.assertAllEqual(2.0, self.evaluate(component.read_value()))
|
||||
|
||||
@combinations.generate(strategy_and_run_tf_function_combinations())
|
||||
def testAssignPerReplicaVal(self, distribution, experimental_run_tf_function,
|
||||
use_var_policy):
|
||||
def testAssignPerReplicaVal(self, distribution, experimental_run_tf_function):
|
||||
|
||||
if isinstance(distribution, _TPU_STRATEGIES):
|
||||
self.skipTest("Assigning PerReplica values is not supported. See"
|
||||
@ -241,7 +238,7 @@ class OnWriteVariableSync(test.TestCase, parameterized.TestCase):
|
||||
self.assertAllEqual(expected, self.evaluate(component.read_value()))
|
||||
|
||||
@combinations.generate(strategy_with_var_policy())
|
||||
def testValueInReplicaContext(self, distribution, use_var_policy):
|
||||
def testValueInReplicaContext(self, distribution):
|
||||
with distribution.scope():
|
||||
v = variables_lib.Variable(
|
||||
1., aggregation=variables_lib.VariableAggregation.MEAN)
|
||||
@ -260,8 +257,7 @@ class OnWriteVariableSync(test.TestCase, parameterized.TestCase):
|
||||
|
||||
@combinations.generate(strategy_and_run_tf_function_combinations())
|
||||
def testReadValueInReplicaContext(self, distribution,
|
||||
experimental_run_tf_function,
|
||||
use_var_policy):
|
||||
experimental_run_tf_function):
|
||||
aggregations = [
|
||||
variables_lib.VariableAggregation.NONE,
|
||||
variables_lib.VariableAggregation.SUM,
|
||||
@ -286,8 +282,7 @@ class OnWriteVariableSync(test.TestCase, parameterized.TestCase):
|
||||
|
||||
@combinations.generate(strategy_and_run_tf_function_combinations())
|
||||
def testReadValueInCrossReplicaContext(self, distribution,
|
||||
experimental_run_tf_function,
|
||||
use_var_policy):
|
||||
experimental_run_tf_function):
|
||||
aggregations = [
|
||||
variables_lib.VariableAggregation.NONE,
|
||||
variables_lib.VariableAggregation.SUM,
|
||||
@ -312,7 +307,7 @@ class OnWriteVariableSync(test.TestCase, parameterized.TestCase):
|
||||
self.evaluate(results))
|
||||
|
||||
@combinations.generate(strategy_with_var_policy())
|
||||
def testAssignOutOfScope(self, distribution, use_var_policy):
|
||||
def testAssignOutOfScope(self, distribution):
|
||||
with distribution.scope():
|
||||
mirrored = variables_lib.Variable(1.)
|
||||
self.evaluate(mirrored.assign(3.))
|
||||
@ -321,8 +316,7 @@ class OnWriteVariableSync(test.TestCase, parameterized.TestCase):
|
||||
self.assertEqual(self.evaluate(component.read_value()), 3.)
|
||||
|
||||
@combinations.generate(strategy_with_var_policy())
|
||||
def testAssignAggregationMeanDTypeNonFloat(self, distribution,
|
||||
use_var_policy):
|
||||
def testAssignAggregationMeanDTypeNonFloat(self, distribution):
|
||||
if isinstance(distribution, _TPU_STRATEGIES):
|
||||
self.skipTest("Fix sponge/6e8ab540-4c0f-4da5-aedf-86505ff810c9 before "
|
||||
"reenabling test.")
|
||||
@ -379,8 +373,7 @@ class OnWriteVariableSync(test.TestCase, parameterized.TestCase):
|
||||
self.assertEqual(self.evaluate(v.read_value()), 4)
|
||||
|
||||
@combinations.generate(strategy_with_var_policy())
|
||||
def testInitializedToSameValueInsideEagerRun(self, distribution,
|
||||
use_var_policy):
|
||||
def testInitializedToSameValueInsideEagerRun(self, distribution):
|
||||
if not context.executing_eagerly(): self.skipTest("eager only test")
|
||||
v = [None]
|
||||
|
||||
@ -399,7 +392,7 @@ class OnWriteVariableSync(test.TestCase, parameterized.TestCase):
|
||||
self.assertAllEqual(vals[0], vals[1])
|
||||
|
||||
@combinations.generate(strategy_with_var_policy())
|
||||
def testAggregationOnlyFirstReplica(self, distribution, use_var_policy):
|
||||
def testAggregationOnlyFirstReplica(self, distribution):
|
||||
with distribution.scope():
|
||||
v = variable_scope.variable(
|
||||
15.,
|
||||
@ -420,7 +413,7 @@ class OnWriteVariableSync(test.TestCase, parameterized.TestCase):
|
||||
per_replica_results)
|
||||
|
||||
@combinations.generate(strategy_with_var_policy())
|
||||
def testInitScope(self, distribution, use_var_policy):
|
||||
def testInitScope(self, distribution):
|
||||
if not context.executing_eagerly(): self.skipTest("eager only")
|
||||
|
||||
class C(object):
|
||||
@ -448,7 +441,7 @@ class OnWriteVariableSync(test.TestCase, parameterized.TestCase):
|
||||
self.assertAllEqual([2, 2], per_replica_results)
|
||||
|
||||
@combinations.generate(strategy_with_var_policy())
|
||||
def testOperatorOverride(self, distribution, use_var_policy):
|
||||
def testOperatorOverride(self, distribution):
|
||||
|
||||
with distribution.scope():
|
||||
v = variable_scope.variable(
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user