Return the correct replica id within a sync group for MultiWorkerMirroredStrategy (MWMS). Before this change, we returned the local replica id within a worker rather than the replica id within the sync group.
PiperOrigin-RevId: 336352000 Change-Id: Ie8936ee327340f04ea4edfae6aada6bf719f3488
This commit is contained in:
parent
c7f6ac1762
commit
540f6db6f8
@ -1524,7 +1524,9 @@ cuda_py_test(
|
|||||||
":multi_worker_test_base",
|
":multi_worker_test_base",
|
||||||
":multi_worker_util",
|
":multi_worker_util",
|
||||||
":reduce_util",
|
":reduce_util",
|
||||||
|
":strategy_combinations",
|
||||||
":strategy_test_lib",
|
":strategy_test_lib",
|
||||||
|
":test_util",
|
||||||
"//tensorflow/python:array_ops",
|
"//tensorflow/python:array_ops",
|
||||||
"//tensorflow/python:client_testlib",
|
"//tensorflow/python:client_testlib",
|
||||||
"//tensorflow/python:constant_op",
|
"//tensorflow/python:constant_op",
|
||||||
|
@ -767,3 +767,10 @@ class CollectiveAllReduceExtended(mirrored_strategy.MirroredExtended):
|
|||||||
Boolean.
|
Boolean.
|
||||||
"""
|
"""
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
def _get_replica_id_in_sync_group(self, replica_id):
|
||||||
|
return self._id_in_cluster * len(self.worker_devices) + replica_id
|
||||||
|
|
||||||
|
def _get_local_replica_id(self, replica_id_in_sync_group):
|
||||||
|
return (replica_id_in_sync_group -
|
||||||
|
self._id_in_cluster * len(self.worker_devices))
|
||||||
|
@ -32,11 +32,14 @@ from tensorflow.python.distribute import combinations
|
|||||||
from tensorflow.python.distribute import cross_device_utils
|
from tensorflow.python.distribute import cross_device_utils
|
||||||
from tensorflow.python.distribute import distribute_lib
|
from tensorflow.python.distribute import distribute_lib
|
||||||
from tensorflow.python.distribute import distribute_utils
|
from tensorflow.python.distribute import distribute_utils
|
||||||
|
from tensorflow.python.distribute import distribution_strategy_context
|
||||||
from tensorflow.python.distribute import input_lib
|
from tensorflow.python.distribute import input_lib
|
||||||
from tensorflow.python.distribute import multi_worker_test_base
|
from tensorflow.python.distribute import multi_worker_test_base
|
||||||
from tensorflow.python.distribute import multi_worker_util
|
from tensorflow.python.distribute import multi_worker_util
|
||||||
from tensorflow.python.distribute import reduce_util
|
from tensorflow.python.distribute import reduce_util
|
||||||
|
from tensorflow.python.distribute import strategy_combinations
|
||||||
from tensorflow.python.distribute import strategy_test_lib
|
from tensorflow.python.distribute import strategy_test_lib
|
||||||
|
from tensorflow.python.distribute import test_util
|
||||||
from tensorflow.python.distribute.cluster_resolver import SimpleClusterResolver
|
from tensorflow.python.distribute.cluster_resolver import SimpleClusterResolver
|
||||||
from tensorflow.python.eager import context
|
from tensorflow.python.eager import context
|
||||||
from tensorflow.python.framework import config as tf_config
|
from tensorflow.python.framework import config as tf_config
|
||||||
@ -598,5 +601,29 @@ class LogicalDeviceTest(test.TestCase, parameterized.TestCase):
|
|||||||
context._reset_context() # pylint: disable=protected-access
|
context._reset_context() # pylint: disable=protected-access
|
||||||
|
|
||||||
|
|
||||||
|
@combinations.generate(
|
||||||
|
combinations.combine(
|
||||||
|
strategy=[
|
||||||
|
strategy_combinations.multi_worker_mirrored_2x1_cpu,
|
||||||
|
strategy_combinations.multi_worker_mirrored_2x1_gpu,
|
||||||
|
strategy_combinations.multi_worker_mirrored_2x2_gpu,
|
||||||
|
],
|
||||||
|
mode=['eager']))
|
||||||
|
class CollectiveAllReduceStrategyV2Test(test.TestCase, parameterized.TestCase):
|
||||||
|
|
||||||
|
def test_replica_id_in_sync_group(self, strategy):
|
||||||
|
|
||||||
|
def replica_fn():
|
||||||
|
replica_ctx = distribution_strategy_context.get_replica_context()
|
||||||
|
return replica_ctx.replica_id_in_sync_group, replica_ctx._replica_id
|
||||||
|
|
||||||
|
results = test_util.gather(strategy, strategy.run(replica_fn))
|
||||||
|
self.assertAllEqual(list(range(strategy.extended._num_replicas_in_sync)),
|
||||||
|
results[0].numpy())
|
||||||
|
self.assertAllEqual(
|
||||||
|
list(range(len(strategy.extended.worker_devices))) *
|
||||||
|
strategy.extended._num_workers, results[1].numpy())
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
test.main()
|
test_util.main()
|
||||||
|
@ -2880,6 +2880,11 @@ class ReplicaContext(object):
|
|||||||
raise ValueError(
|
raise ValueError(
|
||||||
"replica_id_in_sync_group can only be an integer, a Tensor or None.")
|
"replica_id_in_sync_group can only be an integer, a Tensor or None.")
|
||||||
self._replica_id_in_sync_group = replica_id_in_sync_group
|
self._replica_id_in_sync_group = replica_id_in_sync_group
|
||||||
|
# We need this check because TPUContext extends from ReplicaContext and
|
||||||
|
# does not pass a strategy object since it is used by TPUEstimator.
|
||||||
|
if strategy:
|
||||||
|
self._local_replica_id = strategy.extended._get_local_replica_id(
|
||||||
|
replica_id_in_sync_group)
|
||||||
self._summary_recording_distribution_strategy = None
|
self._summary_recording_distribution_strategy = None
|
||||||
|
|
||||||
@doc_controls.do_not_generate_docs
|
@doc_controls.do_not_generate_docs
|
||||||
@ -2979,6 +2984,11 @@ class ReplicaContext(object):
|
|||||||
dtypes.int32,
|
dtypes.int32,
|
||||||
name="replica_id_in_sync_group")
|
name="replica_id_in_sync_group")
|
||||||
|
|
||||||
|
@property
|
||||||
|
def _replica_id(self):
|
||||||
|
"""This is the local replica id within a worker, not within the sync group."""
|
||||||
|
return self._local_replica_id
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def strategy(self):
|
def strategy(self):
|
||||||
"""The current `tf.distribute.Strategy` object."""
|
"""The current `tf.distribute.Strategy` object."""
|
||||||
@ -3404,6 +3414,12 @@ class _DefaultDistributionExtended(StrategyExtendedV1):
|
|||||||
def should_save_summary(self):
|
def should_save_summary(self):
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
def _get_local_replica_id(self, replica_id_in_sync_group):
|
||||||
|
return replica_id_in_sync_group
|
||||||
|
|
||||||
|
def _get_replica_id_in_sync_group(self, replica_id):
|
||||||
|
return replica_id
|
||||||
|
|
||||||
# TODO(priyag): This should inherit from `InputIterator`, once dependency
|
# TODO(priyag): This should inherit from `InputIterator`, once dependency
|
||||||
# issues have been resolved.
|
# issues have been resolved.
|
||||||
class DefaultInputIterator(object):
|
class DefaultInputIterator(object):
|
||||||
|
@ -124,6 +124,9 @@ class _TestExtended(distribute_lib.StrategyExtendedV1):
|
|||||||
else:
|
else:
|
||||||
return nest.map_structure(self._unwrap, result)
|
return nest.map_structure(self._unwrap, result)
|
||||||
|
|
||||||
|
def _get_local_replica_id(self, replica_id_in_sync_group):
|
||||||
|
return replica_id_in_sync_group
|
||||||
|
|
||||||
|
|
||||||
def _assert_in_default_state(t):
|
def _assert_in_default_state(t):
|
||||||
t.assertIs(ds_context._get_default_replica_context(),
|
t.assertIs(ds_context._get_default_replica_context(),
|
||||||
@ -161,7 +164,7 @@ class TestStrategyTest(test.TestCase):
|
|||||||
|
|
||||||
def run_fn():
|
def run_fn():
|
||||||
replica_context = ds_context.get_replica_context()
|
replica_context = ds_context.get_replica_context()
|
||||||
self.assertTrue(replica_context is not None)
|
self.assertIsNotNone(replica_context)
|
||||||
self.assertIs(None, ds_context.get_cross_replica_context())
|
self.assertIs(None, ds_context.get_cross_replica_context())
|
||||||
self.assertFalse(ds_context.in_cross_replica_context())
|
self.assertFalse(ds_context.in_cross_replica_context())
|
||||||
self.assertTrue(ds_context.has_strategy())
|
self.assertTrue(ds_context.has_strategy())
|
||||||
|
@ -246,6 +246,9 @@ class _MirroredReplicaThread(threading.Thread):
|
|||||||
self.distribution = dist
|
self.distribution = dist
|
||||||
self.devices = devices
|
self.devices = devices
|
||||||
self.replica_id = replica_id
|
self.replica_id = replica_id
|
||||||
|
self.replica_id_in_sync_group = (
|
||||||
|
dist.extended._get_replica_id_in_sync_group(replica_id)) # pylint: disable=protected-access
|
||||||
|
|
||||||
self.variable_creator_fn = variable_creator_fn
|
self.variable_creator_fn = variable_creator_fn
|
||||||
# State needed to run and return the results of `fn`.
|
# State needed to run and return the results of `fn`.
|
||||||
self.main_fn = fn
|
self.main_fn = fn
|
||||||
@ -310,7 +313,8 @@ class _MirroredReplicaThread(threading.Thread):
|
|||||||
_enter_graph(self.graph, self.in_eager,
|
_enter_graph(self.graph, self.in_eager,
|
||||||
self._variable_creator_stack), \
|
self._variable_creator_stack), \
|
||||||
context.device_policy(self.context_device_policy), \
|
context.device_policy(self.context_device_policy), \
|
||||||
_MirroredReplicaContext(self.distribution, self.replica_id), \
|
_MirroredReplicaContext(self.distribution,
|
||||||
|
self.replica_id_in_sync_group), \
|
||||||
ops.device(self.devices[self.replica_id]), \
|
ops.device(self.devices[self.replica_id]), \
|
||||||
ops.name_scope(self._name_scope), \
|
ops.name_scope(self._name_scope), \
|
||||||
variable_scope.variable_scope(
|
variable_scope.variable_scope(
|
||||||
|
@ -758,3 +758,9 @@ class MirroredExtended(distribute_lib.StrategyExtendedV1):
|
|||||||
def _in_multi_worker_mode(self):
|
def _in_multi_worker_mode(self):
|
||||||
"""Whether this strategy indicates working in multi-worker settings."""
|
"""Whether this strategy indicates working in multi-worker settings."""
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
def _get_local_replica_id(self, replica_id_in_sync_group):
|
||||||
|
return replica_id_in_sync_group
|
||||||
|
|
||||||
|
def _get_replica_id_in_sync_group(self, replica_id):
|
||||||
|
return replica_id
|
||||||
|
@ -458,6 +458,9 @@ class OneDeviceExtended(distribute_lib.StrategyExtendedV1):
|
|||||||
def _support_per_replica_values(self):
|
def _support_per_replica_values(self):
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
def _get_local_replica_id(self, replica_id_in_sync_group):
|
||||||
|
return replica_id_in_sync_group
|
||||||
|
|
||||||
|
|
||||||
class _OneDeviceReplicaContext(distribute_lib.ReplicaContext):
|
class _OneDeviceReplicaContext(distribute_lib.ReplicaContext):
|
||||||
"""ReplicaContext for OneDeviceStrategy."""
|
"""ReplicaContext for OneDeviceStrategy."""
|
||||||
|
@ -693,3 +693,9 @@ class ParameterServerStrategyExtended(distribute_lib.StrategyExtendedV1):
|
|||||||
Boolean.
|
Boolean.
|
||||||
"""
|
"""
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
def _get_local_replica_id(self, replica_id_in_sync_group):
|
||||||
|
return replica_id_in_sync_group
|
||||||
|
|
||||||
|
def _get_replica_id_in_sync_group(self, replica_id):
|
||||||
|
return replica_id
|
||||||
|
@ -1345,6 +1345,9 @@ class TPUExtended(distribute_lib.StrategyExtendedV1):
|
|||||||
# TPUStrategy.
|
# TPUStrategy.
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
def _get_local_replica_id(self, replica_id_in_sync_group):
|
||||||
|
return replica_id_in_sync_group
|
||||||
|
|
||||||
|
|
||||||
class _TPUReplicaContext(distribute_lib.ReplicaContext):
|
class _TPUReplicaContext(distribute_lib.ReplicaContext):
|
||||||
"""Replication Context class for TPU Strategy."""
|
"""Replication Context class for TPU Strategy."""
|
||||||
|
@ -261,7 +261,7 @@ def get_current_replica_id_as_int():
|
|||||||
"""Returns the current replica ID as an integer, or `None`."""
|
"""Returns the current replica ID as an integer, or `None`."""
|
||||||
replica_context = ds_context.get_replica_context()
|
replica_context = ds_context.get_replica_context()
|
||||||
if replica_context:
|
if replica_context:
|
||||||
replica_id = replica_context.replica_id_in_sync_group
|
replica_id = replica_context._replica_id # pylint: disable=protected-access
|
||||||
if not isinstance(replica_id, int):
|
if not isinstance(replica_id, int):
|
||||||
replica_id = tensor_util.constant_value(replica_id)
|
replica_id = tensor_util.constant_value(replica_id)
|
||||||
else:
|
else:
|
||||||
|
Loading…
x
Reference in New Issue
Block a user