diff --git a/tensorflow/python/distribute/values_test.py b/tensorflow/python/distribute/values_test.py index 447a8b427eb..a9d7dd2125e 100644 --- a/tensorflow/python/distribute/values_test.py +++ b/tensorflow/python/distribute/values_test.py @@ -26,6 +26,7 @@ import numpy as np from tensorflow.core.protobuf import config_pb2 from tensorflow.python import tf2 +from tensorflow.python.distribute import collective_all_reduce_strategy from tensorflow.python.distribute import combinations from tensorflow.python.distribute import distribute_lib from tensorflow.python.distribute import distribute_utils @@ -408,6 +409,9 @@ class DistributedDelegateTest(test.TestCase): strategy_combinations.tpu_strategy, strategy_combinations.tpu_strategy_packed_var, strategy_combinations.central_storage_strategy_with_two_gpus, + strategy_combinations.multi_worker_mirrored_2x1_cpu, + strategy_combinations.multi_worker_mirrored_2x1_gpu, + strategy_combinations.multi_worker_mirrored_2x2_gpu ], synchronization=[ variables_lib.VariableSynchronization.ON_READ, @@ -427,7 +431,19 @@ class DistributedVariableTest(test.TestCase, parameterized.TestCase): 1., synchronization=synchronization, aggregation=aggregation) self.assertIsInstance(v, variables_lib.Variable) - def testCheckpointing(self, distribution, synchronization, aggregation): + def testCheckpointing(self, distribution, synchronization, aggregation, mode): + # TODO(anjs): Remove this when b/162147051 is fixed. + if (isinstance(distribution, + collective_all_reduce_strategy.CollectiveAllReduceStrategy) + and aggregation == variables_lib.VariableAggregation.SUM + and synchronization == variables_lib.VariableSynchronization.ON_READ): + self.skipTest("b/162147051") + + if (isinstance(distribution, + collective_all_reduce_strategy.CollectiveAllReduceStrategy) + and mode == "graph"): + self.skipTest("MWMS combinations tests do not work well in graph mode.") + with distribution.scope(): v = variables_lib.Variable( constant_op.constant([1., 2., 3., 4]), @@ -594,6 +610,12 @@ class DistributedVariableTest(test.TestCase, parameterized.TestCase): self.skipTest("b/148689177: AggregatingVariable doesn't " "conform to Variable interface well") + # TODO(crccw): Unskip this in cl/323875091. + if (isinstance( + distribution, + collective_all_reduce_strategy.CollectiveAllReduceStrategy)): + self.skipTest("Being fixed in cl/323875091") + # tf.function requires the return value to be Tensors, which is not always # case for properties and methods of Variable, so we simply discard the # return values.