From e1e264594584d740ce5b077d92d58ca167318f22 Mon Sep 17 00:00:00 2001
From: Ran Chen
Date: Tue, 4 Aug 2020 11:56:38 -0700
Subject: [PATCH] Put distribution and cluster parameter modifiers in their own combination

This makes it easier to add more strategy/cluster modifiers.

PiperOrigin-RevId: 324859132
Change-Id: Ided1a3d1106d71e86335ca3c59d7698550fa2155
---
 .../integration_tests/saved_model_test.py    |  4 +-
 tensorflow/python/distribute/combinations.py | 43 +++++++++----------
 .../python/distribute/combinations_test.py   | 14 +++---
 tensorflow/python/distribute/vars_test.py    | 31 ++++++-------
 4 files changed, 42 insertions(+), 50 deletions(-)

diff --git a/tensorflow/examples/saved_model/integration_tests/saved_model_test.py b/tensorflow/examples/saved_model/integration_tests/saved_model_test.py
index 6333e55999e..434d5ed4ad5 100644
--- a/tensorflow/examples/saved_model/integration_tests/saved_model_test.py
+++ b/tensorflow/examples/saved_model/integration_tests/saved_model_test.py
@@ -90,8 +90,8 @@ class SavedModelTest(scripts.TestCase, parameterized.TestCase):
           retrain_flag_value=["true", "false"],
           regularization_loss_multiplier=[None, 2],  # Test for b/134528831.
       )),
-      test_combinations=(distribute_combinations.NamedGPUCombination(),
-                         distribute_combinations.NamedTPUCombination()))
+      test_combinations=(distribute_combinations.GPUCombination(),
+                         distribute_combinations.TPUCombination()))
 
   @combinations.generate(**TEST_MNIST_CNN_GENERATE_KWARGS)
   def test_mnist_cnn(self, use_keras_save_api, named_strategy,
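GPUCombination and TPUCombination (renamed here from the Named* variants) are
implementations of the framework's TestCombination interface, which decides
whether a given parameter combination should run and which extra parameters a
test receives. A minimal sketch of that interface (method names follow
tensorflow/python/framework/test_combinations.py; the combination class and
its "verbose" parameter below are hypothetical):

    from tensorflow.python.framework import test_combinations as combinations_lib


    class VerboseCombination(combinations_lib.TestCombination):
      """Hypothetical combination: skips some runs, forwards one parameter."""

      def should_execute_combination(self, kwargs):
        # Returning (False, reason) skips this combination and reports `reason`.
        if kwargs.get("verbose") and kwargs.get("mode") == "graph":
          return (False, "verbose output is only wired up for eager mode")
        return (True, None)

      def parameter_modifiers(self):
        # `verbose` is passed only to tests whose signatures request it.
        return [combinations_lib.OptionalParameter("verbose")]
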
diff --git a/tensorflow/python/distribute/combinations.py b/tensorflow/python/distribute/combinations.py
index 17bc285b222..3856b6fd132 100644
--- a/tensorflow/python/distribute/combinations.py
+++ b/tensorflow/python/distribute/combinations.py
@@ -99,7 +99,24 @@ class ClusterParameters(combinations_lib.ParameterModifier):
     return update
 
 
-class NamedGPUCombination(combinations_lib.TestCombination):
+class DistributionCombination(combinations_lib.TestCombination):
+  """Sets up distribution strategy for tests."""
+
+  def parameter_modifiers(self):
+    return [
+        DistributionParameter(),
+        combinations_lib.OptionalParameter("use_var_policy"),
+    ]
+
+
+class ClusterCombination(combinations_lib.TestCombination):
+  """Sets up multi worker tests."""
+
+  def parameter_modifiers(self):
+    return [ClusterParameters()]
+
+
+class GPUCombination(combinations_lib.TestCombination):
   """Enable tests to request GPU hardware and skip non-GPU combinations.
 
   This class expects test_combinations to be generated with `NamedDistribution`
@@ -141,17 +158,7 @@
     return [combinations_lib.OptionalParameter("required_gpus")]
 
 
-class GPUCombination(NamedGPUCombination):
-  """NamedGPUCombination that passes `tf.distribute.Strategy` to the tests."""
-
-  def parameter_modifiers(self):
-    return [
-        ClusterParameters(),
-        DistributionParameter(),
-    ] + NamedGPUCombination.parameter_modifiers(self)
-
-
-class NamedTPUCombination(combinations_lib.TestCombination):
+class TPUCombination(combinations_lib.TestCombination):
   """Allow to request TPU hardware and skip non-TPU combinations.
 
   This class expects test_combinations to be generated with `NamedDistribution`
@@ -213,16 +220,6 @@ class NamedTPUCombination(combinations_lib.TestCombination):
     ]
 
 
-class TPUCombination(NamedTPUCombination):
-  """NamedTPUCombination that passes `tf.distribute.Strategy` to the tests."""
-
-  def parameter_modifiers(self):
-    return [
-        ClusterParameters(),
-        DistributionParameter(),
-    ] + NamedTPUCombination.parameter_modifiers(self)
-
-
 class NamedDistribution(object):
   """Wraps a `tf.distribute.Strategy` and adds a name for test titles."""
 
@@ -304,6 +301,8 @@ def generate(combinations, test_combinations=()):
   default_combinations = (
       framework_combinations.EagerGraphCombination(),
      framework_combinations.TFVersionCombination(),
+      ClusterCombination(),
+      DistributionCombination(),
       GPUCombination(),
       TPUCombination(),
   )
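With the modifiers split apart, `generate` now stacks several single-purpose
combinations instead of funneling everything through the GPU/TPU classes:
DistributionCombination resolves the `distribution` argument (and the optional
`use_var_policy`), ClusterCombination injects the cluster topology, and
GPUCombination/TPUCombination are left with only hardware checks. A sketch of
an ordinary test driven by these defaults (assuming a standard TF test target;
`mirrored_strategy_with_one_cpu` is one of the predefined strategies in
strategy_combinations):

    from absl.testing import parameterized

    from tensorflow.python.distribute import combinations
    from tensorflow.python.distribute import strategy_combinations
    from tensorflow.python.ops import variables as variables_lib
    from tensorflow.python.platform import test


    class ExampleTest(test.TestCase, parameterized.TestCase):

      @combinations.generate(
          combinations.combine(
              distribution=[strategy_combinations.mirrored_strategy_with_one_cpu],
              mode=["graph", "eager"]))
      def testCreateVariable(self, distribution):
        # `distribution` arrives already resolved to a tf.distribute.Strategy.
        with distribution.scope():
          v = variables_lib.Variable(1.)
        self.assertIsNotNone(v)


    if __name__ == "__main__":
      test.main()
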
diff --git a/tensorflow/python/distribute/combinations_test.py b/tensorflow/python/distribute/combinations_test.py
index 6d9d0b2570f..3fc3735d560 100644
--- a/tensorflow/python/distribute/combinations_test.py
+++ b/tensorflow/python/distribute/combinations_test.py
@@ -30,7 +30,7 @@ from tensorflow.python.framework import combinations as framework_combinations
 from tensorflow.python.platform import test
 
 
-class ClusterParametersTest(test.TestCase, parameterized.TestCase):
+class ClusterCombinationTest(test.TestCase, parameterized.TestCase):
   # For this test we need to use `framework.test_combinations` because our
   # `generate` eats the cluster parameters.
   #
@@ -42,7 +42,7 @@ class ClusterParametersTest(test.TestCase, parameterized.TestCase):
           combinations.NamedDistribution(
               "HasClusterParams", lambda: None, has_chief=True, num_workers=2),
       ]),
-      test_combinations=(combinations.GPUCombination(),))
+      test_combinations=(combinations.ClusterCombination(),))
   def testClusterParams(self, distribution, has_chief, num_workers):
     self.assertTrue(has_chief)
     self.assertEqual(num_workers, 2)
@@ -51,14 +51,14 @@ class ClusterParametersTest(test.TestCase, parameterized.TestCase):
       framework_combinations.combine(distribution=[
           combinations.NamedDistribution("NoClusterParams", lambda: None),
       ]),
-      test_combinations=(combinations.GPUCombination(),))
+      test_combinations=(combinations.ClusterCombination(),))
   def testClusterParamsHasDefault(self, distribution, has_chief, num_workers):
     self.assertFalse(has_chief)
     self.assertEqual(num_workers, 1)
 
   @framework_combinations.generate(
       framework_combinations.combine(v=1),
-      test_combinations=(combinations.GPUCombination(),))
+      test_combinations=(combinations.ClusterCombination(),))
   def testClusterParamsNoStrategy(self, v, has_chief, num_workers):
     self.assertFalse(has_chief)
     self.assertEqual(num_workers, 1)
@@ -69,7 +69,7 @@ class ClusterParametersTest(test.TestCase, parameterized.TestCase):
               "WithClusterParams", lambda: None, has_chief=True, num_workers=2),
           combinations.NamedDistribution("WithoutClusterParams", lambda: None),
       ]),
-      test_combinations=(combinations.GPUCombination(),))
+      test_combinations=(combinations.ClusterCombination(),))
   def testClusterParamsAreOptional(self, distribution):
     # If combinations library doesn't raise an exception, the test is passed.
     pass
@@ -83,7 +83,7 @@ class ClusterParametersTest(test.TestCase, parameterized.TestCase):
       ds3=combinations.NamedDistribution(
           "Strategy3", lambda: None, has_chief=True, num_workers=0),
       ),
-      test_combinations=(combinations.GPUCombination(),))
+      test_combinations=(combinations.ClusterCombination(),))
   def testMultipleDistributionSingleWorker(self, ds1, ds2, ds3):
     # If combinations library doesn't raise an exception, the test is passed.
     pass
@@ -101,7 +101,7 @@ class ClusterParametersShouldFailTest(test.TestCase, parameterized.TestCase):
       ds2=combinations.NamedDistribution(
           "Strategy2", lambda: None, has_chief=True, num_workers=2),
       ),
-      test_combinations=(combinations.GPUCombination(),))
+      test_combinations=(combinations.ClusterCombination(),))
   def testMultipleDistributionMultiWorker(self, ds1, ds2):
     # combinations library should raise an exception.
     pass
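The tests above exercise ClusterParameters, the ParameterModifier that
ClusterCombination wraps: `has_chief` and `num_workers` are read off the
NamedDistribution when present and default to False and 1 otherwise. For
illustration, a hypothetical modifier in the same style (the real one lives in
tensorflow/python/distribute/combinations.py; `TopologyParameter` and its
logic below are invented):

    from tensorflow.python.framework import test_combinations as combinations_lib


    class TopologyParameter(combinations_lib.ParameterModifier):
      """Hypothetical modifier that fills `num_workers` from the strategy."""

      def modified_arguments(self, kwargs, requested_parameters):
        strategy = kwargs.get("distribution")
        num_workers = getattr(strategy, "num_workers", 1) if strategy else 1
        # Only inject arguments that the test function actually declares.
        update = {}
        if "num_workers" in requested_parameters:
          update["num_workers"] = num_workers
        return update
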
See" @@ -241,7 +238,7 @@ class OnWriteVariableSync(test.TestCase, parameterized.TestCase): self.assertAllEqual(expected, self.evaluate(component.read_value())) @combinations.generate(strategy_with_var_policy()) - def testValueInReplicaContext(self, distribution, use_var_policy): + def testValueInReplicaContext(self, distribution): with distribution.scope(): v = variables_lib.Variable( 1., aggregation=variables_lib.VariableAggregation.MEAN) @@ -260,8 +257,7 @@ class OnWriteVariableSync(test.TestCase, parameterized.TestCase): @combinations.generate(strategy_and_run_tf_function_combinations()) def testReadValueInReplicaContext(self, distribution, - experimental_run_tf_function, - use_var_policy): + experimental_run_tf_function): aggregations = [ variables_lib.VariableAggregation.NONE, variables_lib.VariableAggregation.SUM, @@ -286,8 +282,7 @@ class OnWriteVariableSync(test.TestCase, parameterized.TestCase): @combinations.generate(strategy_and_run_tf_function_combinations()) def testReadValueInCrossReplicaContext(self, distribution, - experimental_run_tf_function, - use_var_policy): + experimental_run_tf_function): aggregations = [ variables_lib.VariableAggregation.NONE, variables_lib.VariableAggregation.SUM, @@ -312,7 +307,7 @@ class OnWriteVariableSync(test.TestCase, parameterized.TestCase): self.evaluate(results)) @combinations.generate(strategy_with_var_policy()) - def testAssignOutOfScope(self, distribution, use_var_policy): + def testAssignOutOfScope(self, distribution): with distribution.scope(): mirrored = variables_lib.Variable(1.) self.evaluate(mirrored.assign(3.)) @@ -321,8 +316,7 @@ class OnWriteVariableSync(test.TestCase, parameterized.TestCase): self.assertEqual(self.evaluate(component.read_value()), 3.) @combinations.generate(strategy_with_var_policy()) - def testAssignAggregationMeanDTypeNonFloat(self, distribution, - use_var_policy): + def testAssignAggregationMeanDTypeNonFloat(self, distribution): if isinstance(distribution, _TPU_STRATEGIES): self.skipTest("Fix sponge/6e8ab540-4c0f-4da5-aedf-86505ff810c9 before " "reenabling test.") @@ -379,8 +373,7 @@ class OnWriteVariableSync(test.TestCase, parameterized.TestCase): self.assertEqual(self.evaluate(v.read_value()), 4) @combinations.generate(strategy_with_var_policy()) - def testInitializedToSameValueInsideEagerRun(self, distribution, - use_var_policy): + def testInitializedToSameValueInsideEagerRun(self, distribution): if not context.executing_eagerly(): self.skipTest("eager only test") v = [None] @@ -399,7 +392,7 @@ class OnWriteVariableSync(test.TestCase, parameterized.TestCase): self.assertAllEqual(vals[0], vals[1]) @combinations.generate(strategy_with_var_policy()) - def testAggregationOnlyFirstReplica(self, distribution, use_var_policy): + def testAggregationOnlyFirstReplica(self, distribution): with distribution.scope(): v = variable_scope.variable( 15., @@ -420,7 +413,7 @@ class OnWriteVariableSync(test.TestCase, parameterized.TestCase): per_replica_results) @combinations.generate(strategy_with_var_policy()) - def testInitScope(self, distribution, use_var_policy): + def testInitScope(self, distribution): if not context.executing_eagerly(): self.skipTest("eager only") class C(object): @@ -448,7 +441,7 @@ class OnWriteVariableSync(test.TestCase, parameterized.TestCase): self.assertAllEqual([2, 2], per_replica_results) @combinations.generate(strategy_with_var_policy()) - def testOperatorOverride(self, distribution, use_var_policy): + def testOperatorOverride(self, distribution): with distribution.scope(): v = 
variable_scope.variable(
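The vars_test.py signatures can drop `use_var_policy` because
DistributionCombination registers it as an OptionalParameter (see the
combinations.py hunk above): the parameter still fans out test combinations,
but it is only passed to tests that declare it. A sketch of both sides of that
behavior (hypothetical test class, reusing the predefined
`mirrored_strategy_with_one_cpu`):

    from absl.testing import parameterized

    from tensorflow.python.distribute import combinations
    from tensorflow.python.distribute import strategy_combinations
    from tensorflow.python.platform import test


    class VarPolicyTest(test.TestCase, parameterized.TestCase):

      # Declared in the signature, so OptionalParameter forwards the value.
      @combinations.generate(
          combinations.combine(
              distribution=[strategy_combinations.mirrored_strategy_with_one_cpu],
              mode=["eager"],
              use_var_policy=[True, False]))
      def testPolicyRequested(self, distribution, use_var_policy):
        self.assertIn(use_var_policy, [True, False])

      # Not declared: both runs still happen, but the argument is dropped.
      @combinations.generate(
          combinations.combine(
              distribution=[strategy_combinations.mirrored_strategy_with_one_cpu],
              mode=["eager"],
              use_var_policy=[True, False]))
      def testPolicyNotRequested(self, distribution):
        pass


    if __name__ == "__main__":
      test.main()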