Add a check alive thread in MWMS behind a flag
This is now disabled by default but can let us test the idea. The thread keeps checking the cluster and aborts the collective if any worker is not reachable. Currently the only way to recovery is to restart, since once collectives are aborted, all subsequent collectives fail immediately. The thread uses a RING all-reduce as a check mechanism, since we don't have a check alive op yet. PiperOrigin-RevId: 327493026 Change-Id: I9ab9e5be1f5c1a15b3741a9f26e42f25b9a59a12
This commit is contained in:
parent
d179d2d42f
commit
9c828254cd
@ -19,6 +19,8 @@ from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import copy
|
||||
import threading
|
||||
import time
|
||||
import weakref
|
||||
|
||||
from tensorflow.core.protobuf import rewriter_config_pb2
|
||||
@ -37,6 +39,8 @@ from tensorflow.python.distribute import values
|
||||
from tensorflow.python.distribute.cluster_resolver import SimpleClusterResolver
|
||||
from tensorflow.python.distribute.cluster_resolver import TFConfigClusterResolver
|
||||
from tensorflow.python.eager import context
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import errors
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import collective_ops
|
||||
@ -176,6 +180,16 @@ class CollectiveAllReduceStrategyV1(distribute_lib.StrategyV1):
|
||||
class CollectiveAllReduceExtended(mirrored_strategy.MirroredExtended):
|
||||
"""Implementation of CollectiveAllReduceStrategy."""
|
||||
|
||||
# Whether to perdically check the health of the cluster. If any worker is not
|
||||
# reachable, collectives are aborted and the user program should get a
|
||||
# tf.errors.UnavailableError. It's required to restart in order to recover.
|
||||
_enable_check_health = False
|
||||
# Check health interval in seconds.
|
||||
_check_health_interval = 30
|
||||
# Timeout in seconds for the first check health. The first check health needs
|
||||
# to wait for cluster, which may make a longer time.
|
||||
_check_health_initial_timeout = 1200
|
||||
|
||||
def __init__(self,
|
||||
container_strategy,
|
||||
communication,
|
||||
@ -370,6 +384,10 @@ class CollectiveAllReduceExtended(mirrored_strategy.MirroredExtended):
|
||||
self._rpc_layer = cluster_resolver.rpc_layer
|
||||
self._warn_nccl_no_gpu()
|
||||
|
||||
# TODO(b/151232436): Enable check health thread by default.
|
||||
if self._enable_check_health:
|
||||
self._start_check_health_thread()
|
||||
|
||||
logging.info(
|
||||
"MultiWorkerMirroredStrategy with cluster_spec = %r, task_type = %r, "
|
||||
"task_id = %r, num_workers = %r, local_devices = %r, "
|
||||
@ -377,6 +395,10 @@ class CollectiveAllReduceExtended(mirrored_strategy.MirroredExtended):
|
||||
task_id, self._num_workers, local_devices,
|
||||
self._communication)
|
||||
|
||||
def __del__(self):
|
||||
if self._enable_check_health:
|
||||
self._stop_check_health_thread()
|
||||
|
||||
def _input_workers_with_options(self, options=None):
|
||||
host_device = device_util.get_host_for_device(self._worker_device)
|
||||
if not options or options.experimental_prefetch_to_device:
|
||||
@ -607,6 +629,88 @@ class CollectiveAllReduceExtended(mirrored_strategy.MirroredExtended):
|
||||
destinations=destinations,
|
||||
experimental_hints=experimental_hints)
|
||||
|
||||
def _check_health(self, device, group_key, instance_key):
|
||||
first = True
|
||||
# We need to use a large enough value so that the all-reduce forms a
|
||||
# complete RING. In RING implementation, when value is too small, the
|
||||
# all-reduce may degrade into broadcasts. This means that some worker
|
||||
# failure may not be detected.
|
||||
value = array_ops.ones((32, 32), dtype=dtypes.float32)
|
||||
while True:
|
||||
if self._check_health_thread_should_stop.is_set():
|
||||
return
|
||||
timeout = None
|
||||
if first:
|
||||
# For the first check health we set timeout since it may need to do
|
||||
# group resolution, which may hang if the cluster is never healthy.
|
||||
timeout = self._check_health_initial_timeout
|
||||
first = False
|
||||
try:
|
||||
# We use an dummy all-reduce as a way to check the health of a cluster.
|
||||
# For RING it should be able to detect failed workers in the cluster if
|
||||
# the values are large enough.
|
||||
#
|
||||
# We're not using CrossDeviceOps because we need to run it with
|
||||
# pre-allocated group and instance keys.
|
||||
#
|
||||
# TODO(b/151232436): Replace the reduce with a check health op once we
|
||||
# add that.
|
||||
with ops.device(device):
|
||||
collective_ops.all_reduce(
|
||||
value,
|
||||
group_size=self._num_workers,
|
||||
group_key=group_key,
|
||||
instance_key=instance_key,
|
||||
merge_op="Add",
|
||||
final_op="Id",
|
||||
subdiv_offsets=[0],
|
||||
communication_hint="ring",
|
||||
timeout=timeout)
|
||||
if context.is_async():
|
||||
context.async_wait()
|
||||
except (errors.UnavailableError, errors.DeadlineExceededError,
|
||||
errors.FailedPreconditionError, errors.CancelledError) as e:
|
||||
# TODO(b/151232436): Always raise UnavailableError when a peer fails.
|
||||
# Now there could be many kinds of errors:
|
||||
# - Unavailable: when the peer is not reachable, e.g. it's down.
|
||||
# - FailedPrecondition: when the peer has restarted.
|
||||
# - DeadlineExceeded: when the first check health exceeds the deadline,
|
||||
# e.g. the peers take too long to be ready.
|
||||
# - Cancelled: when failures in organic collectives aborts first,
|
||||
# outgoing RPCs may be aborted with Cancelled.
|
||||
logging.error("Cluster check alive failed, aborting collectives")
|
||||
context.context().abort_collective_ops(
|
||||
errors.UNAVAILABLE, "cluster check alive failed: %s" % e)
|
||||
except Exception as e: # pylint: disable=broad-except
|
||||
logging.exception("Unexpected exception in check alive.")
|
||||
context.context().abort_collective_ops(
|
||||
errors.INTERNAL, "unexecpted exception in check alive: %s" % e)
|
||||
return
|
||||
time.sleep(self._check_health_interval)
|
||||
|
||||
def _start_check_health_thread(self):
|
||||
# Allocate group and instance key before starting the thread to avoid
|
||||
# indeterminism. There can only be one thread that assigns group keys and
|
||||
# instance keys, otherwise different workers may end up with unmatched keys
|
||||
# since execution order between threads are arbitrary.
|
||||
device = device_util.canonicalize(self._worker_device)
|
||||
group_key = self._collective_keys.get_group_key([device])
|
||||
instance_key = self._collective_keys.get_op_instance_key()
|
||||
self._check_health_thread_should_stop = threading.Event()
|
||||
# Start the thread as daemon to avoid it blocking the program from exiting.
|
||||
# We try best to shutdown the thread but __del__ is not guaranteed to be
|
||||
# called when program exists.
|
||||
self._check_health_thread = threading.Thread(
|
||||
target=self._check_health,
|
||||
args=(device, group_key, instance_key),
|
||||
daemon=True)
|
||||
self._check_health_thread.start()
|
||||
|
||||
def _stop_check_health_thread(self):
|
||||
self._check_health_thread_should_stop.set()
|
||||
self._check_health_thread.join()
|
||||
self._check_health_thread = None
|
||||
|
||||
def _warn_nccl_no_gpu(self):
|
||||
if ((self._communication ==
|
||||
cross_device_ops_lib.CollectiveCommunication.NCCL) and
|
||||
|
@ -32,6 +32,7 @@ cuda_py_test(
|
||||
],
|
||||
deps = [
|
||||
"//tensorflow:tensorflow_py",
|
||||
"//tensorflow/python/distribute:collective_all_reduce_strategy",
|
||||
"//tensorflow/python/distribute:combinations",
|
||||
"//tensorflow/python/distribute:multi_process_runner",
|
||||
"//tensorflow/python/distribute:multi_worker_test_base",
|
||||
|
@ -26,12 +26,19 @@ import os
|
||||
|
||||
import tensorflow as tf
|
||||
|
||||
from tensorflow.python.distribute import collective_all_reduce_strategy as mwms_lib
|
||||
from tensorflow.python.distribute import combinations
|
||||
from tensorflow.python.distribute import multi_process_runner
|
||||
from tensorflow.python.distribute import multi_worker_test_base
|
||||
from tensorflow.python.eager import test
|
||||
|
||||
|
||||
# Put it in top level so it executes in the child processes as well.
|
||||
mwms_lib.CollectiveAllReduceExtended._enable_check_health = True
|
||||
mwms_lib.CollectiveAllReduceExtended._check_health_interval = 3
|
||||
mwms_lib.CollectiveAllReduceExtended._check_health_initial_timeout = 6
|
||||
|
||||
|
||||
def get_attempt(strategy, attempts):
|
||||
task_type = strategy.cluster_resolver.task_type
|
||||
task_id = strategy.cluster_resolver.task_id
|
||||
@ -62,11 +69,70 @@ class PeerFailureTest(test.TestCase):
|
||||
# events in real world. E.g. some tests make a worker fail on the first
|
||||
# attempt only, and asserts that it should recovery.
|
||||
|
||||
def test_creating_variable_broken(self):
|
||||
def test_creating_variable(self):
|
||||
# This test simulates the case when a worker fails before or during creating
|
||||
# a variable. Creating variables involve broadcasting the initial value from
|
||||
# the first replica to all replicas.
|
||||
|
||||
def worker_fn():
|
||||
strategy = tf.distribute.experimental.MultiWorkerMirroredStrategy()
|
||||
with strategy.scope():
|
||||
tf.Variable(1.)
|
||||
# worker-1 dies here.
|
||||
if strategy.cluster_resolver.task_id == 1:
|
||||
quick_exit(1)
|
||||
v = tf.Variable(tf.random.uniform(()))
|
||||
return v.read_value().numpy()
|
||||
|
||||
cluster_spec = multi_worker_test_base.create_cluster_spec(num_workers=2)
|
||||
mpr = multi_process_runner.MultiProcessRunner(worker_fn, cluster_spec)
|
||||
mpr.start()
|
||||
# TODO(b/151232436): Always raise UnavailableError when a peer fails.
|
||||
with self.assertRaises(
|
||||
(tf.errors.UnavailableError, tf.errors.DeadlineExceededError)):
|
||||
mpr.join(timeout=30)
|
||||
|
||||
def test_reduce_small_tensor(self):
|
||||
# This test simulates the case when a worker fails before or during reducing
|
||||
# a small tensors, e.g. reading a metric.
|
||||
#
|
||||
# Note that this is written for a specific corner case that used to happen
|
||||
# only when all of the following conditions are met:
|
||||
# - There're two workers.
|
||||
# - They're reducing a small tensor. The definition of small varies
|
||||
# per platform.
|
||||
# - They're reducing a single tensor. Batched all-reduce are not affected.
|
||||
# - It must be worker-1 that fails.
|
||||
# Under this case, the all-reduce is effectively two send/recv operation,
|
||||
# the first one from worker-0 to worker-1, and the second one vice versa.
|
||||
# The first one blocks the second one. In send/recv, the sending party is
|
||||
# not aware of the failures of the receiving party.
|
||||
|
||||
def worker_fn():
|
||||
strategy = tf.distribute.experimental.MultiWorkerMirroredStrategy()
|
||||
value = tf.identity([1.])
|
||||
strategy.reduce("sum", value, axis=None)
|
||||
# worker-1 dies here.
|
||||
if strategy.cluster_resolver.task_id == 1:
|
||||
quick_exit(1)
|
||||
strategy.reduce("sum", value, axis=None)
|
||||
|
||||
cluster_spec = multi_worker_test_base.create_cluster_spec(num_workers=2)
|
||||
mpr = multi_process_runner.MultiProcessRunner(worker_fn, cluster_spec)
|
||||
mpr.start()
|
||||
# TODO(b/151232436): Always raise UnavailableError when a peer fails.
|
||||
with self.assertRaises(
|
||||
(tf.errors.UnavailableError, tf.errors.DeadlineExceededError)):
|
||||
mpr.join(timeout=30)
|
||||
|
||||
|
||||
class PeerFailureRecoverTest(test.TestCase):
|
||||
# Similar to PeerFailureTest but simulates the situation where there's some
|
||||
# external system that automatically restarts failed workers.
|
||||
|
||||
def test_creating_variable(self):
|
||||
# See PeerFailureTest.test_creating_variable
|
||||
|
||||
def worker_fn(attempts):
|
||||
strategy = tf.distribute.experimental.MultiWorkerMirroredStrategy()
|
||||
task_id, attempt = get_attempt(strategy, attempts)
|
||||
@ -83,23 +149,11 @@ class PeerFailureTest(test.TestCase):
|
||||
mpr = multi_process_runner.MultiProcessRunner(
|
||||
worker_fn, cluster_spec, args=(attempts,), auto_restart=True)
|
||||
mpr.start()
|
||||
# TODO(b/151232436): worker-0 should raises Unavailable instead of hanging.
|
||||
# Now after worker-1 fails, worker-0 waits on the second variable creation;
|
||||
# after worker-1 recovers, worker-1 waits on the first variable creation.
|
||||
with self.assertRaises(multi_process_runner.SubprocessTimeoutError):
|
||||
mpr.join(timeout=30)
|
||||
results = mpr.join(timeout=90).return_value
|
||||
self.assertEqual(results[0], results[1])
|
||||
|
||||
def test_reduce_small_tensor_broken(self):
|
||||
# This test simulates the case when a worker fails before or during reducing
|
||||
# a small tensors, e.g. reading a metric.
|
||||
#
|
||||
# Note that this is a rather corner case and only happens when all of the
|
||||
# following conditions are met:
|
||||
# - There're two workers.
|
||||
# - They're reducing a small tensor. The definition of small varies
|
||||
# per platform.
|
||||
# - They're reducing a single tensor. Batched all-reduce are not affected.
|
||||
# - It must be worker-1 that fails.
|
||||
def test_reduce_small_tensor(self):
|
||||
# See PeerFailureTest.test_reduce_small_tensor
|
||||
|
||||
def worker_fn(attempts):
|
||||
strategy = tf.distribute.experimental.MultiWorkerMirroredStrategy()
|
||||
@ -109,18 +163,15 @@ class PeerFailureTest(test.TestCase):
|
||||
# worker-1 dies here.
|
||||
if attempt == 1 and task_id == 1:
|
||||
quick_exit(1)
|
||||
strategy.reduce("sum", value, axis=None)
|
||||
return strategy.reduce("sum", value, axis=None).numpy()
|
||||
|
||||
cluster_spec = multi_worker_test_base.create_cluster_spec(num_workers=2)
|
||||
attempts = multi_process_runner.manager().dict()
|
||||
mpr = multi_process_runner.MultiProcessRunner(
|
||||
worker_fn, cluster_spec, args=(attempts,), auto_restart=True)
|
||||
mpr.start()
|
||||
# TODO(b/151232436): worker-0 should raises Unavailable instead of hanging.
|
||||
# Now after worker-1 fails, worker-0 waits on the second reduce; after
|
||||
# worker-1 recovers, worker-1 waits on the first reduce.
|
||||
with self.assertRaises(multi_process_runner.SubprocessTimeoutError):
|
||||
mpr.join(timeout=30)
|
||||
results = mpr.join(timeout=90).return_value
|
||||
self.assertAllEqual(results, [[2.], [2.]])
|
||||
|
||||
def test_quick_recover(self):
|
||||
# This test simulates the case when a worker fails but recovers quickly
|
||||
@ -131,12 +182,14 @@ class PeerFailureTest(test.TestCase):
|
||||
# failed workers.
|
||||
|
||||
def worker_fn(attempts):
|
||||
# Set a long check alive interval to better simulate the case when a
|
||||
# worker fails and recovers during a check alive interval.
|
||||
mwms_lib.CollectiveAllReduceExtended._check_alive_interval = 30
|
||||
mwms_lib.CollectiveAllReduceExtended._check_alive_initial_timeout = 30
|
||||
|
||||
strategy = tf.distribute.experimental.MultiWorkerMirroredStrategy()
|
||||
task_id, attempt = get_attempt(strategy, attempts)
|
||||
|
||||
if attempt == 2 and task_id == 1:
|
||||
multi_process_runner.barrier().wait()
|
||||
|
||||
@tf.function
|
||||
def replica_fn():
|
||||
ctx = tf.distribute.get_replica_context()
|
||||
@ -149,10 +202,6 @@ class PeerFailureTest(test.TestCase):
|
||||
# worker-1 dies here.
|
||||
if attempt == 1 and task_id == 1:
|
||||
quick_exit(1)
|
||||
# Make worker-0 waits for worker-1 to restart before entering the next
|
||||
# collective to simulate a quick recovery of worker-1.
|
||||
if attempt == 1 and task_id == 0:
|
||||
multi_process_runner.barrier().wait()
|
||||
strategy.run(replica_fn)
|
||||
|
||||
cluster_spec = multi_worker_test_base.create_cluster_spec(num_workers=2)
|
||||
|
Loading…
x
Reference in New Issue
Block a user