Put distribution and cluster parameter modifiers in their own combination
It's easier to add more strategy/cluster modifiers. PiperOrigin-RevId: 324859132 Change-Id: Ided1a3d1106d71e86335ca3c59d7698550fa2155
This commit is contained in:
parent
6acd86d539
commit
e1e2645945
@ -90,8 +90,8 @@ class SavedModelTest(scripts.TestCase, parameterized.TestCase):
|
|||||||
retrain_flag_value=["true", "false"],
|
retrain_flag_value=["true", "false"],
|
||||||
regularization_loss_multiplier=[None, 2], # Test for b/134528831.
|
regularization_loss_multiplier=[None, 2], # Test for b/134528831.
|
||||||
)),
|
)),
|
||||||
test_combinations=(distribute_combinations.NamedGPUCombination(),
|
test_combinations=(distribute_combinations.GPUCombination(),
|
||||||
distribute_combinations.NamedTPUCombination()))
|
distribute_combinations.TPUCombination()))
|
||||||
|
|
||||||
@combinations.generate(**TEST_MNIST_CNN_GENERATE_KWARGS)
|
@combinations.generate(**TEST_MNIST_CNN_GENERATE_KWARGS)
|
||||||
def test_mnist_cnn(self, use_keras_save_api, named_strategy,
|
def test_mnist_cnn(self, use_keras_save_api, named_strategy,
|
||||||
|
|||||||
@ -99,7 +99,24 @@ class ClusterParameters(combinations_lib.ParameterModifier):
|
|||||||
return update
|
return update
|
||||||
|
|
||||||
|
|
||||||
class NamedGPUCombination(combinations_lib.TestCombination):
|
class DistributionCombination(combinations_lib.TestCombination):
|
||||||
|
"""Sets up distribution strategy for tests."""
|
||||||
|
|
||||||
|
def parameter_modifiers(self):
|
||||||
|
return [
|
||||||
|
DistributionParameter(),
|
||||||
|
combinations_lib.OptionalParameter("use_var_policy"),
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
class ClusterCombination(combinations_lib.TestCombination):
|
||||||
|
"""Sets up multi worker tests."""
|
||||||
|
|
||||||
|
def parameter_modifiers(self):
|
||||||
|
return [ClusterParameters()]
|
||||||
|
|
||||||
|
|
||||||
|
class GPUCombination(combinations_lib.TestCombination):
|
||||||
"""Enable tests to request GPU hardware and skip non-GPU combinations.
|
"""Enable tests to request GPU hardware and skip non-GPU combinations.
|
||||||
|
|
||||||
This class expects test_combinations to be generated with `NamedDistribution`
|
This class expects test_combinations to be generated with `NamedDistribution`
|
||||||
@ -141,17 +158,7 @@ class NamedGPUCombination(combinations_lib.TestCombination):
|
|||||||
return [combinations_lib.OptionalParameter("required_gpus")]
|
return [combinations_lib.OptionalParameter("required_gpus")]
|
||||||
|
|
||||||
|
|
||||||
class GPUCombination(NamedGPUCombination):
|
class TPUCombination(combinations_lib.TestCombination):
|
||||||
"""NamedGPUCombination that passes `tf.distribute.Strategy` to the tests."""
|
|
||||||
|
|
||||||
def parameter_modifiers(self):
|
|
||||||
return [
|
|
||||||
ClusterParameters(),
|
|
||||||
DistributionParameter(),
|
|
||||||
] + NamedGPUCombination.parameter_modifiers(self)
|
|
||||||
|
|
||||||
|
|
||||||
class NamedTPUCombination(combinations_lib.TestCombination):
|
|
||||||
"""Allow to request TPU hardware and skip non-TPU combinations.
|
"""Allow to request TPU hardware and skip non-TPU combinations.
|
||||||
|
|
||||||
This class expects test_combinations to be generated with `NamedDistribution`
|
This class expects test_combinations to be generated with `NamedDistribution`
|
||||||
@ -213,16 +220,6 @@ class NamedTPUCombination(combinations_lib.TestCombination):
|
|||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
class TPUCombination(NamedTPUCombination):
|
|
||||||
"""NamedTPUCombination that passes `tf.distribute.Strategy` to the tests."""
|
|
||||||
|
|
||||||
def parameter_modifiers(self):
|
|
||||||
return [
|
|
||||||
ClusterParameters(),
|
|
||||||
DistributionParameter(),
|
|
||||||
] + NamedTPUCombination.parameter_modifiers(self)
|
|
||||||
|
|
||||||
|
|
||||||
class NamedDistribution(object):
|
class NamedDistribution(object):
|
||||||
"""Wraps a `tf.distribute.Strategy` and adds a name for test titles."""
|
"""Wraps a `tf.distribute.Strategy` and adds a name for test titles."""
|
||||||
|
|
||||||
@ -304,6 +301,8 @@ def generate(combinations, test_combinations=()):
|
|||||||
default_combinations = (
|
default_combinations = (
|
||||||
framework_combinations.EagerGraphCombination(),
|
framework_combinations.EagerGraphCombination(),
|
||||||
framework_combinations.TFVersionCombination(),
|
framework_combinations.TFVersionCombination(),
|
||||||
|
ClusterCombination(),
|
||||||
|
DistributionCombination(),
|
||||||
GPUCombination(),
|
GPUCombination(),
|
||||||
TPUCombination(),
|
TPUCombination(),
|
||||||
)
|
)
|
||||||
|
|||||||
@ -30,7 +30,7 @@ from tensorflow.python.framework import combinations as framework_combinations
|
|||||||
from tensorflow.python.platform import test
|
from tensorflow.python.platform import test
|
||||||
|
|
||||||
|
|
||||||
class ClusterParametersTest(test.TestCase, parameterized.TestCase):
|
class ClusterCombinationTest(test.TestCase, parameterized.TestCase):
|
||||||
# For this test we need to use `framework.test_combinations` because our
|
# For this test we need to use `framework.test_combinations` because our
|
||||||
# `generate` eats the cluster parameters.
|
# `generate` eats the cluster parameters.
|
||||||
#
|
#
|
||||||
@ -42,7 +42,7 @@ class ClusterParametersTest(test.TestCase, parameterized.TestCase):
|
|||||||
combinations.NamedDistribution(
|
combinations.NamedDistribution(
|
||||||
"HasClusterParams", lambda: None, has_chief=True, num_workers=2),
|
"HasClusterParams", lambda: None, has_chief=True, num_workers=2),
|
||||||
]),
|
]),
|
||||||
test_combinations=(combinations.GPUCombination(),))
|
test_combinations=(combinations.ClusterCombination(),))
|
||||||
def testClusterParams(self, distribution, has_chief, num_workers):
|
def testClusterParams(self, distribution, has_chief, num_workers):
|
||||||
self.assertTrue(has_chief)
|
self.assertTrue(has_chief)
|
||||||
self.assertEqual(num_workers, 2)
|
self.assertEqual(num_workers, 2)
|
||||||
@ -51,14 +51,14 @@ class ClusterParametersTest(test.TestCase, parameterized.TestCase):
|
|||||||
framework_combinations.combine(distribution=[
|
framework_combinations.combine(distribution=[
|
||||||
combinations.NamedDistribution("NoClusterParams", lambda: None),
|
combinations.NamedDistribution("NoClusterParams", lambda: None),
|
||||||
]),
|
]),
|
||||||
test_combinations=(combinations.GPUCombination(),))
|
test_combinations=(combinations.ClusterCombination(),))
|
||||||
def testClusterParamsHasDefault(self, distribution, has_chief, num_workers):
|
def testClusterParamsHasDefault(self, distribution, has_chief, num_workers):
|
||||||
self.assertFalse(has_chief)
|
self.assertFalse(has_chief)
|
||||||
self.assertEqual(num_workers, 1)
|
self.assertEqual(num_workers, 1)
|
||||||
|
|
||||||
@framework_combinations.generate(
|
@framework_combinations.generate(
|
||||||
framework_combinations.combine(v=1),
|
framework_combinations.combine(v=1),
|
||||||
test_combinations=(combinations.GPUCombination(),))
|
test_combinations=(combinations.ClusterCombination(),))
|
||||||
def testClusterParamsNoStrategy(self, v, has_chief, num_workers):
|
def testClusterParamsNoStrategy(self, v, has_chief, num_workers):
|
||||||
self.assertFalse(has_chief)
|
self.assertFalse(has_chief)
|
||||||
self.assertEqual(num_workers, 1)
|
self.assertEqual(num_workers, 1)
|
||||||
@ -69,7 +69,7 @@ class ClusterParametersTest(test.TestCase, parameterized.TestCase):
|
|||||||
"WithClusterParams", lambda: None, has_chief=True, num_workers=2),
|
"WithClusterParams", lambda: None, has_chief=True, num_workers=2),
|
||||||
combinations.NamedDistribution("WithoutClusterParams", lambda: None),
|
combinations.NamedDistribution("WithoutClusterParams", lambda: None),
|
||||||
]),
|
]),
|
||||||
test_combinations=(combinations.GPUCombination(),))
|
test_combinations=(combinations.ClusterCombination(),))
|
||||||
def testClusterParamsAreOptional(self, distribution):
|
def testClusterParamsAreOptional(self, distribution):
|
||||||
# If combinations library doesn't raise an exception, the test is passed.
|
# If combinations library doesn't raise an exception, the test is passed.
|
||||||
pass
|
pass
|
||||||
@ -83,7 +83,7 @@ class ClusterParametersTest(test.TestCase, parameterized.TestCase):
|
|||||||
ds3=combinations.NamedDistribution(
|
ds3=combinations.NamedDistribution(
|
||||||
"Strategy3", lambda: None, has_chief=True, num_workers=0),
|
"Strategy3", lambda: None, has_chief=True, num_workers=0),
|
||||||
),
|
),
|
||||||
test_combinations=(combinations.GPUCombination(),))
|
test_combinations=(combinations.ClusterCombination(),))
|
||||||
def testMultipleDistributionSingleWorker(self, ds1, ds2, ds3):
|
def testMultipleDistributionSingleWorker(self, ds1, ds2, ds3):
|
||||||
# If combinations library doesn't raise an exception, the test is passed.
|
# If combinations library doesn't raise an exception, the test is passed.
|
||||||
pass
|
pass
|
||||||
@ -101,7 +101,7 @@ class ClusterParametersShouldFailTest(test.TestCase, parameterized.TestCase):
|
|||||||
ds2=combinations.NamedDistribution(
|
ds2=combinations.NamedDistribution(
|
||||||
"Strategy2", lambda: None, has_chief=True, num_workers=2),
|
"Strategy2", lambda: None, has_chief=True, num_workers=2),
|
||||||
),
|
),
|
||||||
test_combinations=(combinations.GPUCombination(),))
|
test_combinations=(combinations.ClusterCombination(),))
|
||||||
def testMultipleDistributionMultiWorker(self, ds1, ds2):
|
def testMultipleDistributionMultiWorker(self, ds1, ds2):
|
||||||
# combinations library should raise an exception.
|
# combinations library should raise an exception.
|
||||||
pass
|
pass
|
||||||
|
|||||||
@ -95,8 +95,7 @@ class OnWriteVariableSync(test.TestCase, parameterized.TestCase):
|
|||||||
sess.run({"complicated": mirrored})
|
sess.run({"complicated": mirrored})
|
||||||
|
|
||||||
@combinations.generate(strategy_and_run_tf_function_combinations())
|
@combinations.generate(strategy_and_run_tf_function_combinations())
|
||||||
def testAssign(self, distribution, experimental_run_tf_function,
|
def testAssign(self, distribution, experimental_run_tf_function):
|
||||||
use_var_policy):
|
|
||||||
|
|
||||||
def assign(fn, v, update_value, cross_replica):
|
def assign(fn, v, update_value, cross_replica):
|
||||||
update_fn = lambda: getattr(v, fn)(update_value)
|
update_fn = lambda: getattr(v, fn)(update_value)
|
||||||
@ -136,8 +135,7 @@ class OnWriteVariableSync(test.TestCase, parameterized.TestCase):
|
|||||||
self.evaluate(array_ops.ones_like(component)))
|
self.evaluate(array_ops.ones_like(component)))
|
||||||
|
|
||||||
@combinations.generate(strategy_and_run_tf_function_combinations())
|
@combinations.generate(strategy_and_run_tf_function_combinations())
|
||||||
def testAssignOnWriteVar(self, distribution, experimental_run_tf_function,
|
def testAssignOnWriteVar(self, distribution, experimental_run_tf_function):
|
||||||
use_var_policy):
|
|
||||||
|
|
||||||
with distribution.scope():
|
with distribution.scope():
|
||||||
v_to_assign = variable_scope.variable(
|
v_to_assign = variable_scope.variable(
|
||||||
@ -182,8 +180,7 @@ class OnWriteVariableSync(test.TestCase, parameterized.TestCase):
|
|||||||
self.assertAllEqual(2.0, self.evaluate(component.read_value()))
|
self.assertAllEqual(2.0, self.evaluate(component.read_value()))
|
||||||
|
|
||||||
@combinations.generate(strategy_and_run_tf_function_combinations())
|
@combinations.generate(strategy_and_run_tf_function_combinations())
|
||||||
def testAssignPerReplicaVal(self, distribution, experimental_run_tf_function,
|
def testAssignPerReplicaVal(self, distribution, experimental_run_tf_function):
|
||||||
use_var_policy):
|
|
||||||
|
|
||||||
if isinstance(distribution, _TPU_STRATEGIES):
|
if isinstance(distribution, _TPU_STRATEGIES):
|
||||||
self.skipTest("Assigning PerReplica values is not supported. See"
|
self.skipTest("Assigning PerReplica values is not supported. See"
|
||||||
@ -241,7 +238,7 @@ class OnWriteVariableSync(test.TestCase, parameterized.TestCase):
|
|||||||
self.assertAllEqual(expected, self.evaluate(component.read_value()))
|
self.assertAllEqual(expected, self.evaluate(component.read_value()))
|
||||||
|
|
||||||
@combinations.generate(strategy_with_var_policy())
|
@combinations.generate(strategy_with_var_policy())
|
||||||
def testValueInReplicaContext(self, distribution, use_var_policy):
|
def testValueInReplicaContext(self, distribution):
|
||||||
with distribution.scope():
|
with distribution.scope():
|
||||||
v = variables_lib.Variable(
|
v = variables_lib.Variable(
|
||||||
1., aggregation=variables_lib.VariableAggregation.MEAN)
|
1., aggregation=variables_lib.VariableAggregation.MEAN)
|
||||||
@ -260,8 +257,7 @@ class OnWriteVariableSync(test.TestCase, parameterized.TestCase):
|
|||||||
|
|
||||||
@combinations.generate(strategy_and_run_tf_function_combinations())
|
@combinations.generate(strategy_and_run_tf_function_combinations())
|
||||||
def testReadValueInReplicaContext(self, distribution,
|
def testReadValueInReplicaContext(self, distribution,
|
||||||
experimental_run_tf_function,
|
experimental_run_tf_function):
|
||||||
use_var_policy):
|
|
||||||
aggregations = [
|
aggregations = [
|
||||||
variables_lib.VariableAggregation.NONE,
|
variables_lib.VariableAggregation.NONE,
|
||||||
variables_lib.VariableAggregation.SUM,
|
variables_lib.VariableAggregation.SUM,
|
||||||
@ -286,8 +282,7 @@ class OnWriteVariableSync(test.TestCase, parameterized.TestCase):
|
|||||||
|
|
||||||
@combinations.generate(strategy_and_run_tf_function_combinations())
|
@combinations.generate(strategy_and_run_tf_function_combinations())
|
||||||
def testReadValueInCrossReplicaContext(self, distribution,
|
def testReadValueInCrossReplicaContext(self, distribution,
|
||||||
experimental_run_tf_function,
|
experimental_run_tf_function):
|
||||||
use_var_policy):
|
|
||||||
aggregations = [
|
aggregations = [
|
||||||
variables_lib.VariableAggregation.NONE,
|
variables_lib.VariableAggregation.NONE,
|
||||||
variables_lib.VariableAggregation.SUM,
|
variables_lib.VariableAggregation.SUM,
|
||||||
@ -312,7 +307,7 @@ class OnWriteVariableSync(test.TestCase, parameterized.TestCase):
|
|||||||
self.evaluate(results))
|
self.evaluate(results))
|
||||||
|
|
||||||
@combinations.generate(strategy_with_var_policy())
|
@combinations.generate(strategy_with_var_policy())
|
||||||
def testAssignOutOfScope(self, distribution, use_var_policy):
|
def testAssignOutOfScope(self, distribution):
|
||||||
with distribution.scope():
|
with distribution.scope():
|
||||||
mirrored = variables_lib.Variable(1.)
|
mirrored = variables_lib.Variable(1.)
|
||||||
self.evaluate(mirrored.assign(3.))
|
self.evaluate(mirrored.assign(3.))
|
||||||
@ -321,8 +316,7 @@ class OnWriteVariableSync(test.TestCase, parameterized.TestCase):
|
|||||||
self.assertEqual(self.evaluate(component.read_value()), 3.)
|
self.assertEqual(self.evaluate(component.read_value()), 3.)
|
||||||
|
|
||||||
@combinations.generate(strategy_with_var_policy())
|
@combinations.generate(strategy_with_var_policy())
|
||||||
def testAssignAggregationMeanDTypeNonFloat(self, distribution,
|
def testAssignAggregationMeanDTypeNonFloat(self, distribution):
|
||||||
use_var_policy):
|
|
||||||
if isinstance(distribution, _TPU_STRATEGIES):
|
if isinstance(distribution, _TPU_STRATEGIES):
|
||||||
self.skipTest("Fix sponge/6e8ab540-4c0f-4da5-aedf-86505ff810c9 before "
|
self.skipTest("Fix sponge/6e8ab540-4c0f-4da5-aedf-86505ff810c9 before "
|
||||||
"reenabling test.")
|
"reenabling test.")
|
||||||
@ -379,8 +373,7 @@ class OnWriteVariableSync(test.TestCase, parameterized.TestCase):
|
|||||||
self.assertEqual(self.evaluate(v.read_value()), 4)
|
self.assertEqual(self.evaluate(v.read_value()), 4)
|
||||||
|
|
||||||
@combinations.generate(strategy_with_var_policy())
|
@combinations.generate(strategy_with_var_policy())
|
||||||
def testInitializedToSameValueInsideEagerRun(self, distribution,
|
def testInitializedToSameValueInsideEagerRun(self, distribution):
|
||||||
use_var_policy):
|
|
||||||
if not context.executing_eagerly(): self.skipTest("eager only test")
|
if not context.executing_eagerly(): self.skipTest("eager only test")
|
||||||
v = [None]
|
v = [None]
|
||||||
|
|
||||||
@ -399,7 +392,7 @@ class OnWriteVariableSync(test.TestCase, parameterized.TestCase):
|
|||||||
self.assertAllEqual(vals[0], vals[1])
|
self.assertAllEqual(vals[0], vals[1])
|
||||||
|
|
||||||
@combinations.generate(strategy_with_var_policy())
|
@combinations.generate(strategy_with_var_policy())
|
||||||
def testAggregationOnlyFirstReplica(self, distribution, use_var_policy):
|
def testAggregationOnlyFirstReplica(self, distribution):
|
||||||
with distribution.scope():
|
with distribution.scope():
|
||||||
v = variable_scope.variable(
|
v = variable_scope.variable(
|
||||||
15.,
|
15.,
|
||||||
@ -420,7 +413,7 @@ class OnWriteVariableSync(test.TestCase, parameterized.TestCase):
|
|||||||
per_replica_results)
|
per_replica_results)
|
||||||
|
|
||||||
@combinations.generate(strategy_with_var_policy())
|
@combinations.generate(strategy_with_var_policy())
|
||||||
def testInitScope(self, distribution, use_var_policy):
|
def testInitScope(self, distribution):
|
||||||
if not context.executing_eagerly(): self.skipTest("eager only")
|
if not context.executing_eagerly(): self.skipTest("eager only")
|
||||||
|
|
||||||
class C(object):
|
class C(object):
|
||||||
@ -448,7 +441,7 @@ class OnWriteVariableSync(test.TestCase, parameterized.TestCase):
|
|||||||
self.assertAllEqual([2, 2], per_replica_results)
|
self.assertAllEqual([2, 2], per_replica_results)
|
||||||
|
|
||||||
@combinations.generate(strategy_with_var_policy())
|
@combinations.generate(strategy_with_var_policy())
|
||||||
def testOperatorOverride(self, distribution, use_var_policy):
|
def testOperatorOverride(self, distribution):
|
||||||
|
|
||||||
with distribution.scope():
|
with distribution.scope():
|
||||||
v = variable_scope.variable(
|
v = variable_scope.variable(
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user