Add MultiWorkerMirroredStrategy to the DistributedVariable tests. Cases that are currently broken are skipped for now and are being fixed in separate changes.
PiperOrigin-RevId: 323948947 Change-Id: I3bd22d5309fb7be491b6036134b328883957e15c
This commit is contained in:
parent
55c87a582d
commit
26dc0e3ee2
@ -26,6 +26,7 @@ import numpy as np
|
||||
|
||||
from tensorflow.core.protobuf import config_pb2
|
||||
from tensorflow.python import tf2
|
||||
from tensorflow.python.distribute import collective_all_reduce_strategy
|
||||
from tensorflow.python.distribute import combinations
|
||||
from tensorflow.python.distribute import distribute_lib
|
||||
from tensorflow.python.distribute import distribute_utils
|
||||
@ -408,6 +409,9 @@ class DistributedDelegateTest(test.TestCase):
|
||||
strategy_combinations.tpu_strategy,
|
||||
strategy_combinations.tpu_strategy_packed_var,
|
||||
strategy_combinations.central_storage_strategy_with_two_gpus,
|
||||
strategy_combinations.multi_worker_mirrored_2x1_cpu,
|
||||
strategy_combinations.multi_worker_mirrored_2x1_gpu,
|
||||
strategy_combinations.multi_worker_mirrored_2x2_gpu
|
||||
],
|
||||
synchronization=[
|
||||
variables_lib.VariableSynchronization.ON_READ,
|
||||
@ -427,7 +431,19 @@ class DistributedVariableTest(test.TestCase, parameterized.TestCase):
|
||||
1., synchronization=synchronization, aggregation=aggregation)
|
||||
self.assertIsInstance(v, variables_lib.Variable)
|
||||
|
||||
def testCheckpointing(self, distribution, synchronization, aggregation):
|
||||
def testCheckpointing(self, distribution, synchronization, aggregation, mode):
|
||||
# TODO(anjs): Remove this when b/162147051 is fixed.
|
||||
if (isinstance(distribution,
|
||||
collective_all_reduce_strategy.CollectiveAllReduceStrategy)
|
||||
and aggregation == variables_lib.VariableAggregation.SUM
|
||||
and synchronization == variables_lib.VariableSynchronization.ON_READ):
|
||||
self.skipTest("b/162147051")
|
||||
|
||||
if (isinstance(distribution,
|
||||
collective_all_reduce_strategy.CollectiveAllReduceStrategy)
|
||||
and mode == "graph"):
|
||||
self.skipTest("MWMS combinations tests do not work well in graph mode.")
|
||||
|
||||
with distribution.scope():
|
||||
v = variables_lib.Variable(
|
||||
constant_op.constant([1., 2., 3., 4]),
|
||||
@ -594,6 +610,12 @@ class DistributedVariableTest(test.TestCase, parameterized.TestCase):
|
||||
self.skipTest("b/148689177: AggregatingVariable doesn't "
|
||||
"conform to Variable interface well")
|
||||
|
||||
# TODO(crccw): Unskip this in cl/323875091.
|
||||
if (isinstance(
|
||||
distribution,
|
||||
collective_all_reduce_strategy.CollectiveAllReduceStrategy)):
|
||||
self.skipTest("Being fixed in cl/323875091")
|
||||
|
||||
# tf.function requires the return value to be Tensors, which is not always
|
||||
# the case for properties and methods of Variable, so we simply discard the
|
||||
# return values.
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user