Return the correct replica id within a sync group for MWMS. Currently we return the local replica id within a worker as opposed to within a sync group.

PiperOrigin-RevId: 336352000
Change-Id: Ie8936ee327340f04ea4edfae6aada6bf719f3488
This commit is contained in:
Anjali Sridhar 2020-10-09 13:15:05 -07:00 committed by TensorFlower Gardener
parent c7f6ac1762
commit 540f6db6f8
11 changed files with 81 additions and 4 deletions

View File

@ -1524,7 +1524,9 @@ cuda_py_test(
":multi_worker_test_base",
":multi_worker_util",
":reduce_util",
":strategy_combinations",
":strategy_test_lib",
":test_util",
"//tensorflow/python:array_ops",
"//tensorflow/python:client_testlib",
"//tensorflow/python:constant_op",

View File

@ -767,3 +767,10 @@ class CollectiveAllReduceExtended(mirrored_strategy.MirroredExtended):
Boolean.
"""
return True
def _get_replica_id_in_sync_group(self, replica_id):
return self._id_in_cluster * len(self.worker_devices) + replica_id
def _get_local_replica_id(self, replica_id_in_sync_group):
return (replica_id_in_sync_group -
self._id_in_cluster * len(self.worker_devices))

View File

@ -32,11 +32,14 @@ from tensorflow.python.distribute import combinations
from tensorflow.python.distribute import cross_device_utils
from tensorflow.python.distribute import distribute_lib
from tensorflow.python.distribute import distribute_utils
from tensorflow.python.distribute import distribution_strategy_context
from tensorflow.python.distribute import input_lib
from tensorflow.python.distribute import multi_worker_test_base
from tensorflow.python.distribute import multi_worker_util
from tensorflow.python.distribute import reduce_util
from tensorflow.python.distribute import strategy_combinations
from tensorflow.python.distribute import strategy_test_lib
from tensorflow.python.distribute import test_util
from tensorflow.python.distribute.cluster_resolver import SimpleClusterResolver
from tensorflow.python.eager import context
from tensorflow.python.framework import config as tf_config
@ -598,5 +601,29 @@ class LogicalDeviceTest(test.TestCase, parameterized.TestCase):
context._reset_context() # pylint: disable=protected-access
@combinations.generate(
combinations.combine(
strategy=[
strategy_combinations.multi_worker_mirrored_2x1_cpu,
strategy_combinations.multi_worker_mirrored_2x1_gpu,
strategy_combinations.multi_worker_mirrored_2x2_gpu,
],
mode=['eager']))
class CollectiveAllReduceStrategyV2Test(test.TestCase, parameterized.TestCase):
def test_replica_id_in_sync_group(self, strategy):
def replica_fn():
replica_ctx = distribution_strategy_context.get_replica_context()
return replica_ctx.replica_id_in_sync_group, replica_ctx._replica_id
results = test_util.gather(strategy, strategy.run(replica_fn))
self.assertAllEqual(list(range(strategy.extended._num_replicas_in_sync)),
results[0].numpy())
self.assertAllEqual(
list(range(len(strategy.extended.worker_devices))) *
strategy.extended._num_workers, results[1].numpy())
if __name__ == '__main__':
test.main()
test_util.main()

View File

@ -2880,6 +2880,11 @@ class ReplicaContext(object):
raise ValueError(
"replica_id_in_sync_group can only be an integer, a Tensor or None.")
self._replica_id_in_sync_group = replica_id_in_sync_group
# We need this check becaused TPUContext extends from ReplicaContext and
# does not pass a strategy object since it is used by TPUEstimator.
if strategy:
self._local_replica_id = strategy.extended._get_local_replica_id(
replica_id_in_sync_group)
self._summary_recording_distribution_strategy = None
@doc_controls.do_not_generate_docs
@ -2979,6 +2984,11 @@ class ReplicaContext(object):
dtypes.int32,
name="replica_id_in_sync_group")
@property
def _replica_id(self):
"""This is the local replica id in a given sync group."""
return self._local_replica_id
@property
def strategy(self):
"""The current `tf.distribute.Strategy` object."""
@ -3404,6 +3414,12 @@ class _DefaultDistributionExtended(StrategyExtendedV1):
def should_save_summary(self):
return True
def _get_local_replica_id(self, replica_id_in_sync_group):
return replica_id_in_sync_group
def _get_replica_id_in_sync_group(self, replica_id):
return replica_id
# TODO(priyag): This should inherit from `InputIterator`, once dependency
# issues have been resolved.
class DefaultInputIterator(object):

View File

@ -124,6 +124,9 @@ class _TestExtended(distribute_lib.StrategyExtendedV1):
else:
return nest.map_structure(self._unwrap, result)
def _get_local_replica_id(self, replica_id_in_sync_group):
return replica_id_in_sync_group
def _assert_in_default_state(t):
t.assertIs(ds_context._get_default_replica_context(),
@ -161,7 +164,7 @@ class TestStrategyTest(test.TestCase):
def run_fn():
replica_context = ds_context.get_replica_context()
self.assertTrue(replica_context is not None)
self.assertIsNotNone(replica_context)
self.assertIs(None, ds_context.get_cross_replica_context())
self.assertFalse(ds_context.in_cross_replica_context())
self.assertTrue(ds_context.has_strategy())

View File

@ -246,6 +246,9 @@ class _MirroredReplicaThread(threading.Thread):
self.distribution = dist
self.devices = devices
self.replica_id = replica_id
self.replica_id_in_sync_group = (
dist.extended._get_replica_id_in_sync_group(replica_id)) # pylint: disable=protected-access
self.variable_creator_fn = variable_creator_fn
# State needed to run and return the results of `fn`.
self.main_fn = fn
@ -310,7 +313,8 @@ class _MirroredReplicaThread(threading.Thread):
_enter_graph(self.graph, self.in_eager,
self._variable_creator_stack), \
context.device_policy(self.context_device_policy), \
_MirroredReplicaContext(self.distribution, self.replica_id), \
_MirroredReplicaContext(self.distribution,
self.replica_id_in_sync_group), \
ops.device(self.devices[self.replica_id]), \
ops.name_scope(self._name_scope), \
variable_scope.variable_scope(

View File

@ -758,3 +758,9 @@ class MirroredExtended(distribute_lib.StrategyExtendedV1):
def _in_multi_worker_mode(self):
"""Whether this strategy indicates working in multi-worker settings."""
return False
def _get_local_replica_id(self, replica_id_in_sync_group):
return replica_id_in_sync_group
def _get_replica_id_in_sync_group(self, replica_id):
return replica_id

View File

@ -458,6 +458,9 @@ class OneDeviceExtended(distribute_lib.StrategyExtendedV1):
def _support_per_replica_values(self):
return False
def _get_local_replica_id(self, replica_id_in_sync_group):
return replica_id_in_sync_group
class _OneDeviceReplicaContext(distribute_lib.ReplicaContext):
"""ReplicaContext for OneDeviceStrategy."""

View File

@ -693,3 +693,9 @@ class ParameterServerStrategyExtended(distribute_lib.StrategyExtendedV1):
Boolean.
"""
return True
def _get_local_replica_id(self, replica_id_in_sync_group):
return replica_id_in_sync_group
def _get_replica_id_in_sync_group(self, replica_id):
return replica_id

View File

@ -1345,6 +1345,9 @@ class TPUExtended(distribute_lib.StrategyExtendedV1):
# TPUStrategy.
return False
def _get_local_replica_id(self, replica_id_in_sync_group):
return replica_id_in_sync_group
class _TPUReplicaContext(distribute_lib.ReplicaContext):
"""Replication Context class for TPU Strategy."""

View File

@ -261,7 +261,7 @@ def get_current_replica_id_as_int():
"""Returns the current replica ID as an integer, or `None`."""
replica_context = ds_context.get_replica_context()
if replica_context:
replica_id = replica_context.replica_id_in_sync_group
replica_id = replica_context._replica_id # pylint: disable=protected-access
if not isinstance(replica_id, int):
replica_id = tensor_util.constant_value(replica_id)
else: