diff --git a/tensorflow/python/distribute/collective_all_reduce_strategy.py b/tensorflow/python/distribute/collective_all_reduce_strategy.py index e97f5eef322..6724a228513 100644 --- a/tensorflow/python/distribute/collective_all_reduce_strategy.py +++ b/tensorflow/python/distribute/collective_all_reduce_strategy.py @@ -189,10 +189,7 @@ class CollectiveAllReduceExtended(mirrored_strategy.MirroredExtended): _check_health_interval = 30 # Timeout in seconds for the first check health. The first check health needs # to wait for cluster, which may make a longer time. - # - # TODO(b/151232436): now the inital barrier may hang in a rare case, so we - # need a finite timeout. - _check_health_initial_timeout = 1200 + _check_health_initial_timeout = 0 # Times to retry before considering the peer is down. _check_health_retry_limit = 3 @@ -683,8 +680,8 @@ class CollectiveAllReduceExtended(mirrored_strategy.MirroredExtended): # # TODO(b/151232436): change to an explicit barrier if we have it. dummy_value = ops.convert_to_tensor([]) - logging.info("Waiting for the cluster, timeout = %d", - self._check_health_initial_timeout) + logging.info("Waiting for the cluster, timeout = %s", + self._check_health_initial_timeout or "inf") try: self._host_cross_device_ops.reduce( reduce_util.ReduceOp.SUM, diff --git a/tensorflow/python/distribute/integration_test/mwms_peer_failure_test.py b/tensorflow/python/distribute/integration_test/mwms_peer_failure_test.py index 003fb5f1a33..6d822ca1b97 100644 --- a/tensorflow/python/distribute/integration_test/mwms_peer_failure_test.py +++ b/tensorflow/python/distribute/integration_test/mwms_peer_failure_test.py @@ -36,7 +36,7 @@ from tensorflow.python.eager import test # Put it in top level so it executes in the child processes as well. mwms_lib.CollectiveAllReduceExtended._enable_check_health = True mwms_lib.CollectiveAllReduceExtended._check_health_interval = 3 -mwms_lib.CollectiveAllReduceExtended._check_health_initial_timeout = 6 +mwms_lib.CollectiveAllReduceExtended._check_health_initial_timeout = 0 def get_attempt(strategy, attempts):