Return the correct replica id within a sync group for MWMS. Currently we return the local replica id within a worker as opposed to within a sync group.
PiperOrigin-RevId: 336352000 Change-Id: Ie8936ee327340f04ea4edfae6aada6bf719f3488
This commit is contained in:
parent
c7f6ac1762
commit
540f6db6f8
@ -1524,7 +1524,9 @@ cuda_py_test(
|
||||
":multi_worker_test_base",
|
||||
":multi_worker_util",
|
||||
":reduce_util",
|
||||
":strategy_combinations",
|
||||
":strategy_test_lib",
|
||||
":test_util",
|
||||
"//tensorflow/python:array_ops",
|
||||
"//tensorflow/python:client_testlib",
|
||||
"//tensorflow/python:constant_op",
|
||||
|
@ -767,3 +767,10 @@ class CollectiveAllReduceExtended(mirrored_strategy.MirroredExtended):
|
||||
Boolean.
|
||||
"""
|
||||
return True
|
||||
|
||||
def _get_replica_id_in_sync_group(self, replica_id):
|
||||
return self._id_in_cluster * len(self.worker_devices) + replica_id
|
||||
|
||||
def _get_local_replica_id(self, replica_id_in_sync_group):
|
||||
return (replica_id_in_sync_group -
|
||||
self._id_in_cluster * len(self.worker_devices))
|
||||
|
@ -32,11 +32,14 @@ from tensorflow.python.distribute import combinations
|
||||
from tensorflow.python.distribute import cross_device_utils
|
||||
from tensorflow.python.distribute import distribute_lib
|
||||
from tensorflow.python.distribute import distribute_utils
|
||||
from tensorflow.python.distribute import distribution_strategy_context
|
||||
from tensorflow.python.distribute import input_lib
|
||||
from tensorflow.python.distribute import multi_worker_test_base
|
||||
from tensorflow.python.distribute import multi_worker_util
|
||||
from tensorflow.python.distribute import reduce_util
|
||||
from tensorflow.python.distribute import strategy_combinations
|
||||
from tensorflow.python.distribute import strategy_test_lib
|
||||
from tensorflow.python.distribute import test_util
|
||||
from tensorflow.python.distribute.cluster_resolver import SimpleClusterResolver
|
||||
from tensorflow.python.eager import context
|
||||
from tensorflow.python.framework import config as tf_config
|
||||
@ -598,5 +601,29 @@ class LogicalDeviceTest(test.TestCase, parameterized.TestCase):
|
||||
context._reset_context() # pylint: disable=protected-access
|
||||
|
||||
|
||||
@combinations.generate(
|
||||
combinations.combine(
|
||||
strategy=[
|
||||
strategy_combinations.multi_worker_mirrored_2x1_cpu,
|
||||
strategy_combinations.multi_worker_mirrored_2x1_gpu,
|
||||
strategy_combinations.multi_worker_mirrored_2x2_gpu,
|
||||
],
|
||||
mode=['eager']))
|
||||
class CollectiveAllReduceStrategyV2Test(test.TestCase, parameterized.TestCase):
|
||||
|
||||
def test_replica_id_in_sync_group(self, strategy):
|
||||
|
||||
def replica_fn():
|
||||
replica_ctx = distribution_strategy_context.get_replica_context()
|
||||
return replica_ctx.replica_id_in_sync_group, replica_ctx._replica_id
|
||||
|
||||
results = test_util.gather(strategy, strategy.run(replica_fn))
|
||||
self.assertAllEqual(list(range(strategy.extended._num_replicas_in_sync)),
|
||||
results[0].numpy())
|
||||
self.assertAllEqual(
|
||||
list(range(len(strategy.extended.worker_devices))) *
|
||||
strategy.extended._num_workers, results[1].numpy())
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test.main()
|
||||
test_util.main()
|
||||
|
@ -2880,6 +2880,11 @@ class ReplicaContext(object):
|
||||
raise ValueError(
|
||||
"replica_id_in_sync_group can only be an integer, a Tensor or None.")
|
||||
self._replica_id_in_sync_group = replica_id_in_sync_group
|
||||
# We need this check becaused TPUContext extends from ReplicaContext and
|
||||
# does not pass a strategy object since it is used by TPUEstimator.
|
||||
if strategy:
|
||||
self._local_replica_id = strategy.extended._get_local_replica_id(
|
||||
replica_id_in_sync_group)
|
||||
self._summary_recording_distribution_strategy = None
|
||||
|
||||
@doc_controls.do_not_generate_docs
|
||||
@ -2979,6 +2984,11 @@ class ReplicaContext(object):
|
||||
dtypes.int32,
|
||||
name="replica_id_in_sync_group")
|
||||
|
||||
@property
|
||||
def _replica_id(self):
|
||||
"""This is the local replica id in a given sync group."""
|
||||
return self._local_replica_id
|
||||
|
||||
@property
|
||||
def strategy(self):
|
||||
"""The current `tf.distribute.Strategy` object."""
|
||||
@ -3404,6 +3414,12 @@ class _DefaultDistributionExtended(StrategyExtendedV1):
|
||||
def should_save_summary(self):
|
||||
return True
|
||||
|
||||
def _get_local_replica_id(self, replica_id_in_sync_group):
|
||||
return replica_id_in_sync_group
|
||||
|
||||
def _get_replica_id_in_sync_group(self, replica_id):
|
||||
return replica_id
|
||||
|
||||
# TODO(priyag): This should inherit from `InputIterator`, once dependency
|
||||
# issues have been resolved.
|
||||
class DefaultInputIterator(object):
|
||||
|
@ -124,6 +124,9 @@ class _TestExtended(distribute_lib.StrategyExtendedV1):
|
||||
else:
|
||||
return nest.map_structure(self._unwrap, result)
|
||||
|
||||
def _get_local_replica_id(self, replica_id_in_sync_group):
|
||||
return replica_id_in_sync_group
|
||||
|
||||
|
||||
def _assert_in_default_state(t):
|
||||
t.assertIs(ds_context._get_default_replica_context(),
|
||||
@ -161,7 +164,7 @@ class TestStrategyTest(test.TestCase):
|
||||
|
||||
def run_fn():
|
||||
replica_context = ds_context.get_replica_context()
|
||||
self.assertTrue(replica_context is not None)
|
||||
self.assertIsNotNone(replica_context)
|
||||
self.assertIs(None, ds_context.get_cross_replica_context())
|
||||
self.assertFalse(ds_context.in_cross_replica_context())
|
||||
self.assertTrue(ds_context.has_strategy())
|
||||
|
@ -246,6 +246,9 @@ class _MirroredReplicaThread(threading.Thread):
|
||||
self.distribution = dist
|
||||
self.devices = devices
|
||||
self.replica_id = replica_id
|
||||
self.replica_id_in_sync_group = (
|
||||
dist.extended._get_replica_id_in_sync_group(replica_id)) # pylint: disable=protected-access
|
||||
|
||||
self.variable_creator_fn = variable_creator_fn
|
||||
# State needed to run and return the results of `fn`.
|
||||
self.main_fn = fn
|
||||
@ -310,7 +313,8 @@ class _MirroredReplicaThread(threading.Thread):
|
||||
_enter_graph(self.graph, self.in_eager,
|
||||
self._variable_creator_stack), \
|
||||
context.device_policy(self.context_device_policy), \
|
||||
_MirroredReplicaContext(self.distribution, self.replica_id), \
|
||||
_MirroredReplicaContext(self.distribution,
|
||||
self.replica_id_in_sync_group), \
|
||||
ops.device(self.devices[self.replica_id]), \
|
||||
ops.name_scope(self._name_scope), \
|
||||
variable_scope.variable_scope(
|
||||
|
@ -758,3 +758,9 @@ class MirroredExtended(distribute_lib.StrategyExtendedV1):
|
||||
def _in_multi_worker_mode(self):
|
||||
"""Whether this strategy indicates working in multi-worker settings."""
|
||||
return False
|
||||
|
||||
def _get_local_replica_id(self, replica_id_in_sync_group):
|
||||
return replica_id_in_sync_group
|
||||
|
||||
def _get_replica_id_in_sync_group(self, replica_id):
|
||||
return replica_id
|
||||
|
@ -458,6 +458,9 @@ class OneDeviceExtended(distribute_lib.StrategyExtendedV1):
|
||||
def _support_per_replica_values(self):
|
||||
return False
|
||||
|
||||
def _get_local_replica_id(self, replica_id_in_sync_group):
|
||||
return replica_id_in_sync_group
|
||||
|
||||
|
||||
class _OneDeviceReplicaContext(distribute_lib.ReplicaContext):
|
||||
"""ReplicaContext for OneDeviceStrategy."""
|
||||
|
@ -693,3 +693,9 @@ class ParameterServerStrategyExtended(distribute_lib.StrategyExtendedV1):
|
||||
Boolean.
|
||||
"""
|
||||
return True
|
||||
|
||||
def _get_local_replica_id(self, replica_id_in_sync_group):
|
||||
return replica_id_in_sync_group
|
||||
|
||||
def _get_replica_id_in_sync_group(self, replica_id):
|
||||
return replica_id
|
||||
|
@ -1345,6 +1345,9 @@ class TPUExtended(distribute_lib.StrategyExtendedV1):
|
||||
# TPUStrategy.
|
||||
return False
|
||||
|
||||
def _get_local_replica_id(self, replica_id_in_sync_group):
|
||||
return replica_id_in_sync_group
|
||||
|
||||
|
||||
class _TPUReplicaContext(distribute_lib.ReplicaContext):
|
||||
"""Replication Context class for TPU Strategy."""
|
||||
|
@ -261,7 +261,7 @@ def get_current_replica_id_as_int():
|
||||
"""Returns the current replica ID as an integer, or `None`."""
|
||||
replica_context = ds_context.get_replica_context()
|
||||
if replica_context:
|
||||
replica_id = replica_context.replica_id_in_sync_group
|
||||
replica_id = replica_context._replica_id # pylint: disable=protected-access
|
||||
if not isinstance(replica_id, int):
|
||||
replica_id = tensor_util.constant_value(replica_id)
|
||||
else:
|
||||
|
Loading…
Reference in New Issue
Block a user