From 540f6db6f85f4d2431fb0bc9c2bb56c731f6f347 Mon Sep 17 00:00:00 2001 From: Anjali Sridhar Date: Fri, 9 Oct 2020 13:15:05 -0700 Subject: [PATCH] Return the correct replica id within a sync group for MWMS. Currently we return the local replica id within a worker as opposed to within a sync group. PiperOrigin-RevId: 336352000 Change-Id: Ie8936ee327340f04ea4edfae6aada6bf719f3488 --- tensorflow/python/distribute/BUILD | 2 ++ .../collective_all_reduce_strategy.py | 7 +++++ .../collective_all_reduce_strategy_test.py | 29 ++++++++++++++++++- .../python/distribute/distribute_lib.py | 16 ++++++++++ .../python/distribute/distribute_lib_test.py | 5 +++- tensorflow/python/distribute/mirrored_run.py | 6 +++- .../python/distribute/mirrored_strategy.py | 6 ++++ .../python/distribute/one_device_strategy.py | 3 ++ .../distribute/parameter_server_strategy.py | 6 ++++ tensorflow/python/distribute/tpu_strategy.py | 3 ++ tensorflow/python/distribute/values_util.py | 2 +- 11 files changed, 81 insertions(+), 4 deletions(-) diff --git a/tensorflow/python/distribute/BUILD b/tensorflow/python/distribute/BUILD index ff1c5355ecb..ae59ff50705 100644 --- a/tensorflow/python/distribute/BUILD +++ b/tensorflow/python/distribute/BUILD @@ -1524,7 +1524,9 @@ cuda_py_test( ":multi_worker_test_base", ":multi_worker_util", ":reduce_util", + ":strategy_combinations", ":strategy_test_lib", + ":test_util", "//tensorflow/python:array_ops", "//tensorflow/python:client_testlib", "//tensorflow/python:constant_op", diff --git a/tensorflow/python/distribute/collective_all_reduce_strategy.py b/tensorflow/python/distribute/collective_all_reduce_strategy.py index 363b84acc51..500782d0e89 100644 --- a/tensorflow/python/distribute/collective_all_reduce_strategy.py +++ b/tensorflow/python/distribute/collective_all_reduce_strategy.py @@ -767,3 +767,10 @@ class CollectiveAllReduceExtended(mirrored_strategy.MirroredExtended): Boolean. 
""" return True + + def _get_replica_id_in_sync_group(self, replica_id): + return self._id_in_cluster * len(self.worker_devices) + replica_id + + def _get_local_replica_id(self, replica_id_in_sync_group): + return (replica_id_in_sync_group - + self._id_in_cluster * len(self.worker_devices)) diff --git a/tensorflow/python/distribute/collective_all_reduce_strategy_test.py b/tensorflow/python/distribute/collective_all_reduce_strategy_test.py index 67e156f1a3d..4faad331a06 100644 --- a/tensorflow/python/distribute/collective_all_reduce_strategy_test.py +++ b/tensorflow/python/distribute/collective_all_reduce_strategy_test.py @@ -32,11 +32,14 @@ from tensorflow.python.distribute import combinations from tensorflow.python.distribute import cross_device_utils from tensorflow.python.distribute import distribute_lib from tensorflow.python.distribute import distribute_utils +from tensorflow.python.distribute import distribution_strategy_context from tensorflow.python.distribute import input_lib from tensorflow.python.distribute import multi_worker_test_base from tensorflow.python.distribute import multi_worker_util from tensorflow.python.distribute import reduce_util +from tensorflow.python.distribute import strategy_combinations from tensorflow.python.distribute import strategy_test_lib +from tensorflow.python.distribute import test_util from tensorflow.python.distribute.cluster_resolver import SimpleClusterResolver from tensorflow.python.eager import context from tensorflow.python.framework import config as tf_config @@ -598,5 +601,29 @@ class LogicalDeviceTest(test.TestCase, parameterized.TestCase): context._reset_context() # pylint: disable=protected-access +@combinations.generate( + combinations.combine( + strategy=[ + strategy_combinations.multi_worker_mirrored_2x1_cpu, + strategy_combinations.multi_worker_mirrored_2x1_gpu, + strategy_combinations.multi_worker_mirrored_2x2_gpu, + ], + mode=['eager'])) +class CollectiveAllReduceStrategyV2Test(test.TestCase, 
parameterized.TestCase): + + def test_replica_id_in_sync_group(self, strategy): + + def replica_fn(): + replica_ctx = distribution_strategy_context.get_replica_context() + return replica_ctx.replica_id_in_sync_group, replica_ctx._replica_id + + results = test_util.gather(strategy, strategy.run(replica_fn)) + self.assertAllEqual(list(range(strategy.extended._num_replicas_in_sync)), + results[0].numpy()) + self.assertAllEqual( + list(range(len(strategy.extended.worker_devices))) * + strategy.extended._num_workers, results[1].numpy()) + + if __name__ == '__main__': - test.main() + test_util.main() diff --git a/tensorflow/python/distribute/distribute_lib.py b/tensorflow/python/distribute/distribute_lib.py index c4b14a849a3..9f84067e8f9 100644 --- a/tensorflow/python/distribute/distribute_lib.py +++ b/tensorflow/python/distribute/distribute_lib.py @@ -2880,6 +2880,11 @@ class ReplicaContext(object): raise ValueError( "replica_id_in_sync_group can only be an integer, a Tensor or None.") self._replica_id_in_sync_group = replica_id_in_sync_group + # We need this check because TPUContext extends from ReplicaContext and + does not pass a strategy object since it is used by TPUEstimator. 
+ if strategy: + self._local_replica_id = strategy.extended._get_local_replica_id( + replica_id_in_sync_group) self._summary_recording_distribution_strategy = None @doc_controls.do_not_generate_docs @@ -2979,6 +2984,11 @@ class ReplicaContext(object): dtypes.int32, name="replica_id_in_sync_group") + @property + def _replica_id(self): + """This is the local replica id in a given sync group.""" + return self._local_replica_id + @property def strategy(self): """The current `tf.distribute.Strategy` object.""" @@ -3404,6 +3414,12 @@ class _DefaultDistributionExtended(StrategyExtendedV1): def should_save_summary(self): return True + def _get_local_replica_id(self, replica_id_in_sync_group): + return replica_id_in_sync_group + + def _get_replica_id_in_sync_group(self, replica_id): + return replica_id + # TODO(priyag): This should inherit from `InputIterator`, once dependency # issues have been resolved. class DefaultInputIterator(object): diff --git a/tensorflow/python/distribute/distribute_lib_test.py b/tensorflow/python/distribute/distribute_lib_test.py index 0fe05b52d6f..938cc42f070 100644 --- a/tensorflow/python/distribute/distribute_lib_test.py +++ b/tensorflow/python/distribute/distribute_lib_test.py @@ -124,6 +124,9 @@ class _TestExtended(distribute_lib.StrategyExtendedV1): else: return nest.map_structure(self._unwrap, result) + def _get_local_replica_id(self, replica_id_in_sync_group): + return replica_id_in_sync_group + def _assert_in_default_state(t): t.assertIs(ds_context._get_default_replica_context(), @@ -161,7 +164,7 @@ class TestStrategyTest(test.TestCase): def run_fn(): replica_context = ds_context.get_replica_context() - self.assertTrue(replica_context is not None) + self.assertIsNotNone(replica_context) self.assertIs(None, ds_context.get_cross_replica_context()) self.assertFalse(ds_context.in_cross_replica_context()) self.assertTrue(ds_context.has_strategy()) diff --git a/tensorflow/python/distribute/mirrored_run.py 
b/tensorflow/python/distribute/mirrored_run.py index 2cf23e96e67..4f1f48d30cc 100644 --- a/tensorflow/python/distribute/mirrored_run.py +++ b/tensorflow/python/distribute/mirrored_run.py @@ -246,6 +246,9 @@ class _MirroredReplicaThread(threading.Thread): self.distribution = dist self.devices = devices self.replica_id = replica_id + self.replica_id_in_sync_group = ( + dist.extended._get_replica_id_in_sync_group(replica_id)) # pylint: disable=protected-access + self.variable_creator_fn = variable_creator_fn # State needed to run and return the results of `fn`. self.main_fn = fn @@ -310,7 +313,8 @@ class _MirroredReplicaThread(threading.Thread): _enter_graph(self.graph, self.in_eager, self._variable_creator_stack), \ context.device_policy(self.context_device_policy), \ - _MirroredReplicaContext(self.distribution, self.replica_id), \ + _MirroredReplicaContext(self.distribution, + self.replica_id_in_sync_group), \ ops.device(self.devices[self.replica_id]), \ ops.name_scope(self._name_scope), \ variable_scope.variable_scope( diff --git a/tensorflow/python/distribute/mirrored_strategy.py b/tensorflow/python/distribute/mirrored_strategy.py index fc360a98f1e..d573b3966e3 100644 --- a/tensorflow/python/distribute/mirrored_strategy.py +++ b/tensorflow/python/distribute/mirrored_strategy.py @@ -758,3 +758,9 @@ class MirroredExtended(distribute_lib.StrategyExtendedV1): def _in_multi_worker_mode(self): """Whether this strategy indicates working in multi-worker settings.""" return False + + def _get_local_replica_id(self, replica_id_in_sync_group): + return replica_id_in_sync_group + + def _get_replica_id_in_sync_group(self, replica_id): + return replica_id diff --git a/tensorflow/python/distribute/one_device_strategy.py b/tensorflow/python/distribute/one_device_strategy.py index 5ef2ce03efe..3d5175d9055 100644 --- a/tensorflow/python/distribute/one_device_strategy.py +++ b/tensorflow/python/distribute/one_device_strategy.py @@ -458,6 +458,9 @@ class 
OneDeviceExtended(distribute_lib.StrategyExtendedV1): def _support_per_replica_values(self): return False + def _get_local_replica_id(self, replica_id_in_sync_group): + return replica_id_in_sync_group + class _OneDeviceReplicaContext(distribute_lib.ReplicaContext): """ReplicaContext for OneDeviceStrategy.""" diff --git a/tensorflow/python/distribute/parameter_server_strategy.py b/tensorflow/python/distribute/parameter_server_strategy.py index e495efff62a..5b10edc69b6 100644 --- a/tensorflow/python/distribute/parameter_server_strategy.py +++ b/tensorflow/python/distribute/parameter_server_strategy.py @@ -693,3 +693,9 @@ class ParameterServerStrategyExtended(distribute_lib.StrategyExtendedV1): Boolean. """ return True + + def _get_local_replica_id(self, replica_id_in_sync_group): + return replica_id_in_sync_group + + def _get_replica_id_in_sync_group(self, replica_id): + return replica_id diff --git a/tensorflow/python/distribute/tpu_strategy.py b/tensorflow/python/distribute/tpu_strategy.py index 27d9bb9da90..ae27d6d14bb 100644 --- a/tensorflow/python/distribute/tpu_strategy.py +++ b/tensorflow/python/distribute/tpu_strategy.py @@ -1345,6 +1345,9 @@ class TPUExtended(distribute_lib.StrategyExtendedV1): # TPUStrategy. 
return False + def _get_local_replica_id(self, replica_id_in_sync_group): + return replica_id_in_sync_group + class _TPUReplicaContext(distribute_lib.ReplicaContext): """Replication Context class for TPU Strategy.""" diff --git a/tensorflow/python/distribute/values_util.py b/tensorflow/python/distribute/values_util.py index d1ba958ae0f..9653be0087e 100644 --- a/tensorflow/python/distribute/values_util.py +++ b/tensorflow/python/distribute/values_util.py @@ -261,7 +261,7 @@ def get_current_replica_id_as_int(): """Returns the current replica ID as an integer, or `None`.""" replica_context = ds_context.get_replica_context() if replica_context: - replica_id = replica_context.replica_id_in_sync_group + replica_id = replica_context._replica_id # pylint: disable=protected-access if not isinstance(replica_id, int): replica_id = tensor_util.constant_value(replica_id) else: