diff --git a/tensorflow/python/distribute/collective_all_reduce_strategy.py b/tensorflow/python/distribute/collective_all_reduce_strategy.py
index eeef87f5765..6a133c7d4b8 100644
--- a/tensorflow/python/distribute/collective_all_reduce_strategy.py
+++ b/tensorflow/python/distribute/collective_all_reduce_strategy.py
@@ -19,6 +19,8 @@ from __future__ import division
 from __future__ import print_function
 
 import copy
+import threading
+import time
 import weakref
 
 from tensorflow.core.protobuf import rewriter_config_pb2
@@ -37,6 +39,8 @@ from tensorflow.python.distribute import values
 from tensorflow.python.distribute.cluster_resolver import SimpleClusterResolver
 from tensorflow.python.distribute.cluster_resolver import TFConfigClusterResolver
 from tensorflow.python.eager import context
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import errors
 from tensorflow.python.framework import ops
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import collective_ops
@@ -176,6 +180,16 @@ class CollectiveAllReduceStrategyV1(distribute_lib.StrategyV1):
 class CollectiveAllReduceExtended(mirrored_strategy.MirroredExtended):
   """Implementation of CollectiveAllReduceStrategy."""
 
+  # Whether to periodically check the health of the cluster. If any worker is
+  # not reachable, collectives are aborted and the user program should get a
+  # tf.errors.UnavailableError. A restart is required to recover.
+  _enable_check_health = False
+  # Interval in seconds between health checks.
+  _check_health_interval = 30
+  # Timeout in seconds for the first health check. The first health check
+  # needs to wait for the cluster to be ready, which may take longer.
+  _check_health_initial_timeout = 1200
+
   def __init__(self,
                container_strategy,
                communication,
@@ -370,6 +384,10 @@ class CollectiveAllReduceExtended(mirrored_strategy.MirroredExtended):
     self._rpc_layer = cluster_resolver.rpc_layer
     self._warn_nccl_no_gpu()
 
+    # TODO(b/151232436): Enable check health thread by default.
+    if self._enable_check_health:
+      self._start_check_health_thread()
+
     logging.info(
         "MultiWorkerMirroredStrategy with cluster_spec = %r, task_type = %r, "
         "task_id = %r, num_workers = %r, local_devices = %r, "
@@ -377,6 +395,10 @@ class CollectiveAllReduceExtended(mirrored_strategy.MirroredExtended):
         task_id, self._num_workers, local_devices,
         self._communication)
 
+  def __del__(self):
+    if self._enable_check_health:
+      self._stop_check_health_thread()
+
   def _input_workers_with_options(self, options=None):
     host_device = device_util.get_host_for_device(self._worker_device)
     if not options or options.experimental_prefetch_to_device:
@@ -607,6 +629,88 @@ class CollectiveAllReduceExtended(mirrored_strategy.MirroredExtended):
         destinations=destinations,
         experimental_hints=experimental_hints)
 
+  def _check_health(self, device, group_key, instance_key):
+    first = True
+    # We need to use a large enough value so that the all-reduce forms a
+    # complete ring. In the RING implementation, when the value is too small,
+    # the all-reduce may degrade into broadcasts. This means that some worker
+    # failures may not be detected.
+    value = array_ops.ones((32, 32), dtype=dtypes.float32)
+    while True:
+      if self._check_health_thread_should_stop.is_set():
+        return
+      timeout = None
+      if first:
+        # For the first health check we set a timeout since it may need to do
+        # group resolution, which may hang if the cluster is never healthy.
+        timeout = self._check_health_initial_timeout
+        first = False
+      try:
+        # We use a dummy all-reduce as a way to check the health of a cluster.
+        # For RING it should be able to detect failed workers in the cluster if
+        # the values are large enough.
+        #
+        # We're not using CrossDeviceOps because we need to run it with
+        # pre-allocated group and instance keys.
+        #
+        # TODO(b/151232436): Replace the reduce with a check health op once we
+        # add that.
+        with ops.device(device):
+          collective_ops.all_reduce(
+              value,
+              group_size=self._num_workers,
+              group_key=group_key,
+              instance_key=instance_key,
+              merge_op="Add",
+              final_op="Id",
+              subdiv_offsets=[0],
+              communication_hint="ring",
+              timeout=timeout)
+        if context.is_async():
+          context.async_wait()
+      except (errors.UnavailableError, errors.DeadlineExceededError,
+              errors.FailedPreconditionError, errors.CancelledError) as e:
+        # TODO(b/151232436): Always raise UnavailableError when a peer fails.
+        # Now there could be many kinds of errors:
+        # - Unavailable: when the peer is not reachable, e.g. it's down.
+        # - FailedPrecondition: when the peer has restarted.
+        # - DeadlineExceeded: when the first health check exceeds the deadline,
+        #   e.g. the peers take too long to be ready.
+        # - Cancelled: when failures in organic collectives abort first,
+        #   outgoing RPCs may be aborted with Cancelled.
+        logging.error("Cluster check alive failed, aborting collectives")
+        context.context().abort_collective_ops(
+            errors.UNAVAILABLE, "cluster check alive failed: %s" % e)
+      except Exception as e:  # pylint: disable=broad-except
+        logging.exception("Unexpected exception in check alive.")
+        context.context().abort_collective_ops(
+            errors.INTERNAL, "unexpected exception in check alive: %s" % e)
+        return
+      time.sleep(self._check_health_interval)
+
+  def _start_check_health_thread(self):
+    # Allocate group and instance key before starting the thread to avoid
+    # indeterminism. There can only be one thread that assigns group keys and
+    # instance keys, otherwise different workers may end up with unmatched keys
+    # since the execution order between threads is arbitrary.
+    device = device_util.canonicalize(self._worker_device)
+    group_key = self._collective_keys.get_group_key([device])
+    instance_key = self._collective_keys.get_op_instance_key()
+    self._check_health_thread_should_stop = threading.Event()
+    # Start the thread as a daemon to avoid it blocking the program from
+    # exiting. We try our best to shut down the thread but __del__ is not
+    # guaranteed to be called when the program exits.
+    self._check_health_thread = threading.Thread(
+        target=self._check_health,
+        args=(device, group_key, instance_key),
+        daemon=True)
+    self._check_health_thread.start()
+
+  def _stop_check_health_thread(self):
+    self._check_health_thread_should_stop.set()
+    self._check_health_thread.join()
+    self._check_health_thread = None
+
   def _warn_nccl_no_gpu(self):
     if ((self._communication ==
          cross_device_ops_lib.CollectiveCommunication.NCCL) and
diff --git a/tensorflow/python/distribute/integration_test/BUILD b/tensorflow/python/distribute/integration_test/BUILD
index 307f2580996..361c8a42dbe 100644
--- a/tensorflow/python/distribute/integration_test/BUILD
+++ b/tensorflow/python/distribute/integration_test/BUILD
@@ -32,6 +32,7 @@ cuda_py_test(
     ],
     deps = [
         "//tensorflow:tensorflow_py",
+        "//tensorflow/python/distribute:collective_all_reduce_strategy",
        "//tensorflow/python/distribute:combinations",
        "//tensorflow/python/distribute:multi_process_runner",
        "//tensorflow/python/distribute:multi_worker_test_base",
diff --git a/tensorflow/python/distribute/integration_test/mwms_peer_failure_test.py b/tensorflow/python/distribute/integration_test/mwms_peer_failure_test.py
index c247be1c280..003fb5f1a33 100644
--- a/tensorflow/python/distribute/integration_test/mwms_peer_failure_test.py
+++ b/tensorflow/python/distribute/integration_test/mwms_peer_failure_test.py
@@ -26,12 +26,19 @@ import os
 
 import tensorflow as tf
 
+from tensorflow.python.distribute import collective_all_reduce_strategy as mwms_lib
 from tensorflow.python.distribute import combinations
 from tensorflow.python.distribute import multi_process_runner
 from tensorflow.python.distribute import multi_worker_test_base
 from tensorflow.python.eager import test
 
 
+# Put these at the top level so they execute in the child processes as well.
+mwms_lib.CollectiveAllReduceExtended._enable_check_health = True
+mwms_lib.CollectiveAllReduceExtended._check_health_interval = 3
+mwms_lib.CollectiveAllReduceExtended._check_health_initial_timeout = 6
+
+
 def get_attempt(strategy, attempts):
   task_type = strategy.cluster_resolver.task_type
   task_id = strategy.cluster_resolver.task_id
@@ -62,11 +69,70 @@ class PeerFailureTest(test.TestCase):
   # events in real world. E.g. some tests make a worker fail on the first
   # attempt only, and asserts that it should recovery.
 
-  def test_creating_variable_broken(self):
+  def test_creating_variable(self):
     # This test simulates the case when a worker fails before or during creating
     # a variable. Creating variables involve broadcasting the initial value from
     # the first replica to all replicas.
 
+    def worker_fn():
+      strategy = tf.distribute.experimental.MultiWorkerMirroredStrategy()
+      with strategy.scope():
+        tf.Variable(1.)
+        # worker-1 dies here.
+        if strategy.cluster_resolver.task_id == 1:
+          quick_exit(1)
+        v = tf.Variable(tf.random.uniform(()))
+        return v.read_value().numpy()
+
+    cluster_spec = multi_worker_test_base.create_cluster_spec(num_workers=2)
+    mpr = multi_process_runner.MultiProcessRunner(worker_fn, cluster_spec)
+    mpr.start()
+    # TODO(b/151232436): Always raise UnavailableError when a peer fails.
+    with self.assertRaises(
+        (tf.errors.UnavailableError, tf.errors.DeadlineExceededError)):
+      mpr.join(timeout=30)
+
+  def test_reduce_small_tensor(self):
+    # This test simulates the case when a worker fails before or during
+    # reducing a small tensor, e.g. reading a metric.
+    #
+    # Note that this is written for a specific corner case that used to happen
+    # only when all of the following conditions are met:
+    # - There are two workers.
+    # - They're reducing a small tensor. The definition of small varies
+    #   per platform.
+    # - They're reducing a single tensor. Batched all-reduces are unaffected.
+    # - It must be worker-1 that fails.
+    # In this case, the all-reduce is effectively two send/recv operations,
+    # the first one from worker-0 to worker-1, and the second one vice versa.
+    # The first one blocks the second one. In send/recv, the sending party is
+    # not aware of the failures of the receiving party.
+
+    def worker_fn():
+      strategy = tf.distribute.experimental.MultiWorkerMirroredStrategy()
+      value = tf.identity([1.])
+      strategy.reduce("sum", value, axis=None)
+      # worker-1 dies here.
+      if strategy.cluster_resolver.task_id == 1:
+        quick_exit(1)
+      strategy.reduce("sum", value, axis=None)
+
+    cluster_spec = multi_worker_test_base.create_cluster_spec(num_workers=2)
+    mpr = multi_process_runner.MultiProcessRunner(worker_fn, cluster_spec)
+    mpr.start()
+    # TODO(b/151232436): Always raise UnavailableError when a peer fails.
+    with self.assertRaises(
+        (tf.errors.UnavailableError, tf.errors.DeadlineExceededError)):
+      mpr.join(timeout=30)
+
+
+class PeerFailureRecoverTest(test.TestCase):
+  # Similar to PeerFailureTest but simulates the situation where there's some
+  # external system that automatically restarts failed workers.
+
+  def test_creating_variable(self):
+    # See PeerFailureTest.test_creating_variable
+
     def worker_fn(attempts):
       strategy = tf.distribute.experimental.MultiWorkerMirroredStrategy()
       task_id, attempt = get_attempt(strategy, attempts)
@@ -83,23 +149,11 @@ class PeerFailureTest(test.TestCase):
     mpr = multi_process_runner.MultiProcessRunner(
         worker_fn, cluster_spec, args=(attempts,), auto_restart=True)
     mpr.start()
-    # TODO(b/151232436): worker-0 should raises Unavailable instead of hanging.
-    # Now after worker-1 fails, worker-0 waits on the second variable creation;
-    # after worker-1 recovers, worker-1 waits on the first variable creation.
-    with self.assertRaises(multi_process_runner.SubprocessTimeoutError):
-      mpr.join(timeout=30)
+    results = mpr.join(timeout=90).return_value
+    self.assertEqual(results[0], results[1])
 
-  def test_reduce_small_tensor_broken(self):
-    # This test simulates the case when a worker fails before or during reducing
-    # a small tensors, e.g. reading a metric.
-    #
-    # Note that this is a rather corner case and only happens when all of the
-    # following conditions are met:
-    # - There're two workers.
-    # - They're reducing a small tensor. The definition of small varies
-    #   per platform.
-    # - They're reducing a single tensor. Batched all-reduce are not affected.
-    # - It must be worker-1 that fails.
+  def test_reduce_small_tensor(self):
+    # See PeerFailureTest.test_reduce_small_tensor
 
     def worker_fn(attempts):
       strategy = tf.distribute.experimental.MultiWorkerMirroredStrategy()
@@ -109,18 +163,15 @@ class PeerFailureTest(test.TestCase):
       # worker-1 dies here.
       if attempt == 1 and task_id == 1:
         quick_exit(1)
-      strategy.reduce("sum", value, axis=None)
+      return strategy.reduce("sum", value, axis=None).numpy()
 
     cluster_spec = multi_worker_test_base.create_cluster_spec(num_workers=2)
     attempts = multi_process_runner.manager().dict()
     mpr = multi_process_runner.MultiProcessRunner(
         worker_fn, cluster_spec, args=(attempts,), auto_restart=True)
     mpr.start()
-    # TODO(b/151232436): worker-0 should raises Unavailable instead of hanging.
-    # Now after worker-1 fails, worker-0 waits on the second reduce; after
-    # worker-1 recovers, worker-1 waits on the first reduce.
-    with self.assertRaises(multi_process_runner.SubprocessTimeoutError):
-      mpr.join(timeout=30)
+    results = mpr.join(timeout=90).return_value
+    self.assertAllEqual(results, [[2.], [2.]])
 
   def test_quick_recover(self):
     # This test simulates the case when a worker fails but recovers quickly
@@ -131,12 +182,14 @@ class PeerFailureTest(test.TestCase):
     # failed workers.
 
     def worker_fn(attempts):
+      # Set a long health check interval to better simulate the case when a
+      # worker fails and recovers within one health check interval.
+      mwms_lib.CollectiveAllReduceExtended._check_health_interval = 30
+      mwms_lib.CollectiveAllReduceExtended._check_health_initial_timeout = 30
+
       strategy = tf.distribute.experimental.MultiWorkerMirroredStrategy()
       task_id, attempt = get_attempt(strategy, attempts)
 
-      if attempt == 2 and task_id == 1:
-        multi_process_runner.barrier().wait()
-
       @tf.function
       def replica_fn():
         ctx = tf.distribute.get_replica_context()
@@ -149,10 +202,6 @@ class PeerFailureTest(test.TestCase):
       # worker-1 dies here.
      if attempt == 1 and task_id == 1:
        quick_exit(1)
-      # Make worker-0 waits for worker-1 to restart before entering the next
-      # collective to simulate a quick recovery of worker-1.
-      if attempt == 1 and task_id == 0:
-        multi_process_runner.barrier().wait()
      strategy.run(replica_fn)
 
    cluster_spec = multi_worker_test_base.create_cluster_spec(num_workers=2)
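For readers who want the shape of the mechanism without the TensorFlow internals, here is a minimal, self-contained sketch of the watchdog pattern that _start_check_health_thread and _check_health implement in the patch above: a daemon thread periodically runs a cheap probe, gives the first probe a longer timeout because the cluster may still be forming, and on failure invokes an abort callback and exits. The names Watchdog, probe, and on_failure are illustrative only and not part of the patch; in the patch the probe is a ring all-reduce via collective_ops.all_reduce with pre-allocated group and instance keys, and the failure action is context.context().abort_collective_ops(...).

import threading
import time


class Watchdog(object):
  """Illustrative only: periodically runs `probe` until stopped or it fails."""

  def __init__(self, probe, on_failure, interval=30, initial_timeout=1200):
    # `probe` takes a timeout (seconds or None) and raises on failure;
    # `on_failure` receives the exception. Both are supplied by the caller.
    self._probe = probe
    self._on_failure = on_failure
    self._interval = interval
    self._initial_timeout = initial_timeout
    self._should_stop = threading.Event()
    # Daemon thread so it never blocks process exit, mirroring the patch.
    self._thread = threading.Thread(target=self._loop, daemon=True)

  def start(self):
    self._thread.start()

  def stop(self):
    self._should_stop.set()
    self._thread.join()

  def _loop(self):
    first = True
    while not self._should_stop.is_set():
      # Give the first probe a longer timeout; the cluster may still be
      # starting up, just like the first health check in the patch.
      timeout = self._initial_timeout if first else None
      first = False
      try:
        self._probe(timeout)
      except Exception as e:  # pylint: disable=broad-except
        self._on_failure(e)
        return
      # Event.wait doubles as an interruptible sleep so stop() is prompt.
      self._should_stop.wait(self._interval)


if __name__ == "__main__":
  watchdog = Watchdog(probe=lambda timeout: None,
                      on_failure=lambda e: print("aborting collectives:", e),
                      interval=1)
  watchdog.start()
  time.sleep(2)
  watchdog.stop()

One difference from the patch: the sketch waits on the stop event instead of calling time.sleep, so stop() returns promptly; with the patch's time.sleep, _stop_check_health_thread may block for up to _check_health_interval seconds before the thread notices the stop flag.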