Add multi worker mirrored strategy to DistributedVariable test. Some cases that are broken are currently skipped and being fixed in separate changes.

PiperOrigin-RevId: 323948947
Change-Id: I3bd22d5309fb7be491b6036134b328883957e15c
This commit is contained in:
Priya Gupta 2020-07-30 00:04:00 -07:00 committed by TensorFlower Gardener
parent 55c87a582d
commit 26dc0e3ee2

View File

@ -26,6 +26,7 @@ import numpy as np
from tensorflow.core.protobuf import config_pb2
from tensorflow.python import tf2
from tensorflow.python.distribute import collective_all_reduce_strategy
from tensorflow.python.distribute import combinations
from tensorflow.python.distribute import distribute_lib
from tensorflow.python.distribute import distribute_utils
@ -408,6 +409,9 @@ class DistributedDelegateTest(test.TestCase):
strategy_combinations.tpu_strategy,
strategy_combinations.tpu_strategy_packed_var,
strategy_combinations.central_storage_strategy_with_two_gpus,
strategy_combinations.multi_worker_mirrored_2x1_cpu,
strategy_combinations.multi_worker_mirrored_2x1_gpu,
strategy_combinations.multi_worker_mirrored_2x2_gpu
],
synchronization=[
variables_lib.VariableSynchronization.ON_READ,
@ -427,7 +431,19 @@ class DistributedVariableTest(test.TestCase, parameterized.TestCase):
1., synchronization=synchronization, aggregation=aggregation)
self.assertIsInstance(v, variables_lib.Variable)
def testCheckpointing(self, distribution, synchronization, aggregation):
def testCheckpointing(self, distribution, synchronization, aggregation, mode):
# TODO(anjs): Remove this when b/162147051 is fixed.
if (isinstance(distribution,
collective_all_reduce_strategy.CollectiveAllReduceStrategy)
and aggregation == variables_lib.VariableAggregation.SUM
and synchronization == variables_lib.VariableSynchronization.ON_READ):
self.skipTest("b/162147051")
if (isinstance(distribution,
collective_all_reduce_strategy.CollectiveAllReduceStrategy)
and mode == "graph"):
self.skipTest("MWMS combinations tests do not work well in graph mode.")
with distribution.scope():
v = variables_lib.Variable(
constant_op.constant([1., 2., 3., 4]),
@ -594,6 +610,12 @@ class DistributedVariableTest(test.TestCase, parameterized.TestCase):
self.skipTest("b/148689177: AggregatingVariable doesn't "
"conform to Variable interface well")
# TODO(crccw): Unskip this in cl/323875091.
if (isinstance(
distribution,
collective_all_reduce_strategy.CollectiveAllReduceStrategy)):
self.skipTest("Being fixed in cl/323875091")
# tf.function requires the return value to be Tensors, which is not always
# case for properties and methods of Variable, so we simply discard the
# return values.