Add timeout to collective ops to detect deadlocks.

The timeout is passed as an argument to each collective op. When the value is non-zero, a completion timeout is set to detect staleness. If the timeout fires, execution is aborted with a DEADLINE_EXCEEDED error.
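As an illustration (not part of the changed files), a minimal usage sketch of the new argument through the Python API modified below. It assumes eager execution and that a second group member never joins, so the watchdog should fire:

# Sketch: with a non-zero `timeout`, a collective whose peers never show up
# should fail with DeadlineExceededError instead of hanging.
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import errors
from tensorflow.python.ops import collective_ops

try:
  collective_ops.all_reduce(
      constant_op.constant([1.0, 2.0]),
      group_size=2,       # two participants expected...
      group_key=1,
      instance_key=1,
      merge_op='Add',
      final_op='Id',
      timeout=5.0)        # ...but only one ever calls, so this fails after ~5s
except errors.DeadlineExceededError:
  print('collective timed out')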

PiperOrigin-RevId: 313861868
Change-Id: I7fee45736608ad7fbcc9dd980db2fd302c9cb4df
Authored by A. Unique TensorFlower on 2020-05-29 15:31:08 -07:00; committed by TensorFlower Gardener.
parent 85396efcd3
commit 66529c35a7
8 changed files with 217 additions and 36 deletions


@ -221,23 +221,42 @@ void BaseCollectiveExecutor::ExecuteAsync(OpKernelContext* ctx,
const CollectiveParams& col_params,
const string& exec_key,
StatusCallback done) {
const auto is_callback_called = std::make_shared<std::atomic<bool>>(false);
// On any individual collective Op failure we need to abort the
// BufRendezvous so that other Ops in the instance don't hang
// waiting for transmissions that will never happen. Do so after a
// delay so that the original error status is more likely to
// propagate up, and peers are unlikely to re-create the purged
// BufRendezvous by late-arriving requests.
StatusCallback done_safe = [this, done](const Status& s) {
if (!s.ok()) {
Ref(); // Ensure this lasts until the closure executes.
SchedNonBlockingClosureAfter(1000000, [this, s] {
remote_access_->buf_rendezvous()->StartAbort(s);
Unref();
});
StatusCallback done_safe = [this, done, is_callback_called](const Status& s) {
auto should_call_callback = !is_callback_called->exchange(true);
if (should_call_callback) {
if (!s.ok()) {
Ref(); // Ensure this lasts until the closure executes.
SchedNonBlockingClosureAfter(1000000, [this, s] {
remote_access_->buf_rendezvous()->StartAbort(s);
Unref();
});
}
done(s);
}
done(s);
};
auto timeout_microseconds = static_cast<int64>(
col_params.instance.impl_details.timeout_seconds * 1'000'000);
if (timeout_microseconds > 0) {
// TODO(xldrx): Share the timeout watchdog thread among collectives.
SchedNonBlockingClosureAfter(
timeout_microseconds, [is_callback_called, done_safe] {
if (!is_callback_called->load()) {
auto status = Status(error::DEADLINE_EXCEEDED,
"Collective has timed out during execution.");
done_safe(status);
}
});
}
Tensor* output = ctx->mutable_output(0);
const Tensor* input = (col_params.instance.type == REDUCTION_COLLECTIVE ||
col_params.instance.type == GATHER_COLLECTIVE ||
@ -284,7 +303,30 @@ void BaseCollectiveExecutor::CompleteParamsAsync(
const string& device, CollectiveParams* cp, CancellationManager* cancel_mgr,
StatusCallback done) {
cp->instance.gpu_ring_order = *gpu_ring_order_;
cem_->GetParamResolver()->CompleteParamsAsync(device, cp, cancel_mgr, done);
const auto is_callback_called = std::make_shared<std::atomic<bool>>(false);
auto done_with_timeout = done;
auto timeout_microseconds =
static_cast<int64>(cp->instance.impl_details.timeout_seconds * 1'000'000);
if (timeout_microseconds > 0) {
// TODO(xldrx): Share the timeout watchdog thread among collectives.
SchedNonBlockingClosureAfter(
timeout_microseconds, [is_callback_called, done] {
if (!is_callback_called->load()) {
auto status =
Status(error::DEADLINE_EXCEEDED,
"Collective has timed out waiting for other workers.");
done(status);
}
});
done_with_timeout = [is_callback_called, done](const Status& s) {
auto should_call_callback = !is_callback_called->exchange(true);
if (should_call_callback) {
done(s);
}
};
}
cem_->GetParamResolver()->CompleteParamsAsync(device, cp, cancel_mgr,
done_with_timeout);
}
Status BaseCollectiveExecutor::CreateCollective(
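The core of the change above is a watchdog plus an at-most-once wrapper around the done callback: whichever of the normal completion path or the timeout path runs first wins, and the other becomes a no-op. A standalone sketch of that pattern (Python here for brevity; not TensorFlow code), mirroring the std::atomic exchange in ExecuteAsync and CompleteParamsAsync:

import threading

def run_with_timeout(work, done, timeout_seconds):
  """Runs `work(done_once)`; reports DEADLINE_EXCEEDED if it takes too long."""
  called = [False]
  lock = threading.Lock()

  def done_once(status):
    # Mirrors is_callback_called->exchange(true): only the first caller
    # actually invokes `done`; any later call is dropped.
    with lock:
      if called[0]:
        return
      called[0] = True
    done(status)

  if timeout_seconds > 0:
    # Watchdog. As in the C++ above, it is never cancelled; if the work
    # finished first, its late callback is simply a no-op.
    threading.Timer(timeout_seconds, done_once,
                    args=('DEADLINE_EXCEEDED',)).start()
  work(done_once)  # `work` calls done_once('OK') when it completes.

# The work completes immediately, so only 'OK' is printed; the watchdog
# fires a second later as a no-op.
run_with_timeout(lambda cb: cb('OK'), print, timeout_seconds=1.0)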


@ -84,6 +84,8 @@ struct CollImplDetails {
dependencies; // collective instances on which this node depends
string communication_hint; // user-supplied hint for implementation choice,
// e.g. ring or nccl
float timeout_seconds; // If non-zero, sets a completion timeout for the
// collective op to detect staleness.
};
// Data common to all members of a collective instance.


@ -85,6 +85,9 @@ class CollectiveGatherOpKernel : public CollectiveOpKernel {
OP_REQUIRES_OK(
c, c->GetAttr("communication_hint",
&col_params_.instance.impl_details.communication_hint));
OP_REQUIRES_OK(
c, c->GetAttr("timeout_seconds",
&col_params_.instance.impl_details.timeout_seconds));
const NodeDef& real_node = c->def();
col_params_.name = strings::StrCat(real_node.name(), ": Gather");
col_params_.group.device_type = c->device_type();
@ -176,10 +179,14 @@ class CollectiveReduceOpKernel : public CollectiveOpKernel {
OP_REQUIRES_OK(
c, c->GetAttr("communication_hint",
&col_params_.instance.impl_details.communication_hint));
OP_REQUIRES_OK(
c, c->GetAttr("timeout_seconds",
&col_params_.instance.impl_details.timeout_seconds));
VLOG(2) << "CollectiveReduce instance " << col_params_.instance.instance_key
<< " merge_op " << merge_op_name << " final_op " << final_op_name
<< " communication_hint "
<< col_params_.instance.impl_details.communication_hint;
<< col_params_.instance.impl_details.communication_hint
<< " timeout " << col_params_.instance.impl_details.timeout_seconds;
const NodeDef& real_node = c->def();
col_params_.name = strings::StrCat(real_node.name(), ": Reduce(",
@ -284,6 +291,9 @@ class CollectiveBcastSendOpKernel : public CollectiveOpKernel {
OP_REQUIRES_OK(
c, c->GetAttr("communication_hint",
&col_params_.instance.impl_details.communication_hint));
OP_REQUIRES_OK(
c, c->GetAttr("timeout_seconds",
&col_params_.instance.impl_details.timeout_seconds));
col_params_.is_source = true;
col_params_.instance.impl_details.subdiv_offsets = {0};
@ -363,6 +373,9 @@ class CollectiveBcastRecvOpKernel : public CollectiveOpKernel {
OP_REQUIRES_OK(
c, c->GetAttr("communication_hint",
&col_params_.instance.impl_details.communication_hint));
OP_REQUIRES_OK(
c, c->GetAttr("timeout_seconds",
&col_params_.instance.impl_details.timeout_seconds));
col_params_.is_source = false;
col_params_.instance.impl_details.subdiv_offsets = {0};


@ -31,6 +31,7 @@ REGISTER_OP("CollectiveReduce")
.Attr("subdiv_offsets: list(int)")
.Attr("wait_for: list(int) = []")
.Attr("communication_hint: string = 'auto'")
.Attr("timeout_seconds: float = 0")
.SetIsStateful()
.SetShapeFn(shape_inference::UnchangedShape);
@ -43,6 +44,7 @@ REGISTER_OP("CollectiveGather")
.Attr("instance_key: int")
.Attr("shape: shape")
.Attr("communication_hint: string = 'auto'")
.Attr("timeout_seconds: float = 0")
.SetIsStateful()
.SetShapeFn([](shape_inference::InferenceContext* c) {
// Scalar input is not supported.
@ -86,6 +88,7 @@ REGISTER_OP("CollectiveBcastSend")
.Attr("instance_key: int")
.Attr("shape: shape")
.Attr("communication_hint: string = 'auto'")
.Attr("timeout_seconds: float = 0")
.SetIsStateful()
.SetShapeFn(shape_inference::ExplicitShape);
@ -97,6 +100,7 @@ REGISTER_OP("CollectiveBcastRecv")
.Attr("instance_key: int")
.Attr("shape: shape")
.Attr("communication_hint: string = 'auto'")
.Attr("timeout_seconds: float = 0")
.SetIsStateful()
.SetShapeFn(shape_inference::ExplicitShape);


@ -20,8 +20,15 @@ from __future__ import print_function
from tensorflow.python.ops import gen_collective_ops
def all_reduce(t, group_size, group_key, instance_key, merge_op, final_op,
subdiv_offsets=(0,), communication_hint='auto'):
def all_reduce(t,
group_size,
group_key,
instance_key,
merge_op,
final_op,
subdiv_offsets=(0,),
communication_hint='auto',
timeout=0):
"""Reduces tensors collectively, across devices.
Args:
@ -40,6 +47,9 @@ def all_reduce(t, group_size, group_key, instance_key, merge_op, final_op,
communication_hint: preferred collective communication. The implementation
may fall back to another mechanism. Options include `auto`, `ring`, and
`nccl`.
timeout: timeout value in seconds. If set to a non-zero value, a completion
timeout is set to detect staleness. If the timer goes off, a
DeadlineExceededError is raised. This feature is experimental.
Returns:
An Op implementing the distributed reduction.
@ -57,11 +67,16 @@ def all_reduce(t, group_size, group_key, instance_key, merge_op, final_op,
merge_op=merge_op,
final_op=final_op,
subdiv_offsets=subdiv_offsets,
communication_hint=communication_hint.lower())
communication_hint=communication_hint.lower(),
timeout_seconds=timeout)
def all_gather(t, group_size, group_key, instance_key,
communication_hint='auto'):
def all_gather(t,
group_size,
group_key,
instance_key,
communication_hint='auto',
timeout=0):
"""Accumulates tensors collectively, across devices, along first dimension.
Args:
@ -73,6 +88,9 @@ def all_gather(t, group_size, group_key, instance_key,
communication_hint: preferred collective communication. The implementation
may fall back to another mechanism. Options include `auto`, `ring`, and
`nccl`.
timeout: timeout value in seconds. If set to a non-zero value, a completion
timeout is set to detect staleness. If the timer goes off, a
DeadlineExceededError is raised. This feature is experimental.
Returns:
An Op implementing the distributed operation.
@ -88,11 +106,18 @@ def all_gather(t, group_size, group_key, instance_key,
group_size=group_size,
group_key=group_key,
instance_key=instance_key,
communication_hint=communication_hint.lower())
communication_hint=communication_hint.lower(),
timeout_seconds=timeout)
def broadcast_send(t, shape, dtype, group_size, group_key, instance_key,
communication_hint='auto'):
def broadcast_send(t,
shape,
dtype,
group_size,
group_key,
instance_key,
communication_hint='auto',
timeout=0):
"""Broadcasts one tensor to a group of others, across devices.
Args:
@ -107,6 +132,9 @@ def broadcast_send(t, shape, dtype, group_size, group_key, instance_key,
communication_hint: preferred collective communication. The implementation
may fall back to another mechanism. Options include `auto`, `ring`, and
`nccl`.
timeout: timeout value in seconds. If set to a non-zero value, a completion
timeout is set to detect staleness. If the timer goes off, a
DeadlineExceededError is raised. This feature is experimental.
Returns:
An Op implementing the distributed broadcast send.
@ -139,11 +167,17 @@ def broadcast_send(t, shape, dtype, group_size, group_key, instance_key,
group_size=group_size,
group_key=group_key,
instance_key=instance_key,
communication_hint=communication_hint.lower())
communication_hint=communication_hint.lower(),
timeout_seconds=timeout)
def broadcast_recv(shape, dtype, group_size, group_key, instance_key,
communication_hint='auto'):
def broadcast_recv(shape,
dtype,
group_size,
group_key,
instance_key,
communication_hint='auto',
timeout=0):
"""Receives a broadcasts tensor, across devices.
Args:
@ -157,6 +191,9 @@ def broadcast_recv(shape, dtype, group_size, group_key, instance_key,
communication_hint: preferred collective communication. The implementation
may fall back to another mechanism. Options include `auto`, `ring`, and
`nccl`.
timeout: timeout value in seconds. If set to a non-zero value, a completion
timeout is set to detect staleness. If the timer goes off, a
DeadlineExceededError is raised. This feature is experimental.
Returns:
An Op implementing the broadcast receive.
@ -173,4 +210,5 @@ def broadcast_recv(shape, dtype, group_size, group_key, instance_key,
group_size=group_size,
group_key=group_key,
instance_key=instance_key,
communication_hint=communication_hint.lower())
communication_hint=communication_hint.lower(),
timeout_seconds=timeout)
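
A second sketch, this time of the broadcast pair with the new argument (illustrative only: the group/instance key values are arbitrary, and two logical CPU devices are assumed to be configured, as in the test below):

from tensorflow.python.eager import def_function
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import collective_ops

@def_function.function
def broadcast_with_timeout():
  # Both ops run in the same function call so send and recv can rendezvous;
  # each gives up with DEADLINE_EXCEEDED if the other side takes > 10s.
  with ops.device('/CPU:0'):
    send = collective_ops.broadcast_send(
        constant_op.constant([1.0, 2.0, 3.0]), shape=[3],
        dtype=dtypes.float32, group_size=2, group_key=2, instance_key=3,
        timeout=10)
  with ops.device('/CPU:1'):
    recv = collective_ops.broadcast_recv(
        shape=[3], dtype=dtypes.float32, group_size=2, group_key=2,
        instance_key=3, timeout=10)
  return send, recv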


@ -18,6 +18,8 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import time
from tensorflow.core.protobuf import config_pb2
from tensorflow.core.protobuf import rewriter_config_pb2
from tensorflow.python.eager import context
@ -40,11 +42,21 @@ from tensorflow.python.platform import tf_logging as logging
class CollectiveOpTest(test.TestCase):
def _testCollectiveReduce(self, inputs, expected, set_graph_key,
communication_hint='auto', fp16=False,
instance_key=1, merge_op='Add', final_op='Div'):
def _testCollectiveReduce(self,
inputs,
expected,
set_graph_key,
communication_hint='auto',
fp16=False,
instance_key=1,
merge_op='Add',
final_op='Div',
timeout=0,
reported_group_size=None):
group_key = 1
group_size = len(inputs)
if reported_group_size is None:
reported_group_size = group_size
device_type = 'CPU'
config = config_pb2.ConfigProto(device_count={device_type: group_size})
devices = ['/{}:{}'.format(device_type, i) for i in range(group_size)]
@ -55,9 +67,16 @@ class CollectiveOpTest(test.TestCase):
with ops.device(devices[i]):
tensor = constant_op.constant(inputs[i], dtype=(
dtypes.float16 if fp16 else dtypes.float32))
colred.append(collective_ops.all_reduce(
tensor, group_size, group_key, instance_key, merge_op, final_op,
communication_hint=communication_hint))
colred.append(
collective_ops.all_reduce(
tensor,
reported_group_size,
group_key,
instance_key,
merge_op,
final_op,
communication_hint=communication_hint,
timeout=timeout))
run_options = config_pb2.RunOptions()
if set_graph_key:
run_options.experimental.collective_graph_key = 1
@ -117,6 +136,69 @@ class CollectiveOpTest(test.TestCase):
[0.3, 1.3, 2.3, 3.3, 4.3, 5.3, 6.3, 7.3],
[0.2, 1.2, 2.2, 3.2, 4.2, 5.2, 6.2, 7.2])
@test_util.run_deprecated_v1
def testCollectiveTimeoutV1(self):
timeout = 4.5
kwargs = dict(
inputs=[[i + j + 0.1 for i in range(8)] for j in range(3)],
expected=[1 + i + 0.1 for i in range(8)],
set_graph_key=True,
timeout=timeout)
self._testCollectiveReduce(**kwargs)
start_time = time.time()
with self.assertRaisesRegex(
errors.DeadlineExceededError,
'Collective has timed out waiting for other workers'):
self._testCollectiveReduce(
reported_group_size=len(kwargs['inputs']) + 1, **kwargs)
elapsed = time.time() - start_time
self.assertAllGreaterEqual(elapsed, timeout)
@test_util.run_v2_only
def testCollectiveTimeoutV2(self):
context._reset_context()
timeout = 4.5
cpus = config.list_physical_devices('CPU')
self.assertEqual(len(cpus), 1)
config.set_logical_device_configuration(cpus[0], [
context.LogicalDeviceConfiguration(),
context.LogicalDeviceConfiguration()
])
context.ensure_initialized()
@def_function.function
def run_all_reduce(group_size, reported_group_size=None):
group_key = 20
instance_key = 30
tensor = [1, 2, 3, 4]
results = []
if reported_group_size is None:
reported_group_size = group_size
for i in range(group_size):
with ops.device('/CPU:{}'.format(i)):
input_data = constant_op.constant(tensor)
collective_op = collective_ops.all_reduce(
input_data,
group_size=reported_group_size,
group_key=group_key,
instance_key=instance_key,
merge_op='Add',
final_op='Id',
timeout=timeout)
results.append(collective_op)
return results
run_all_reduce(2, 2)
start_time = time.time()
with self.assertRaisesRegex(errors.DeadlineExceededError,
'Collective has timed out during execution'):
run_all_reduce(1, 2)
elapsed = time.time() - start_time
self.assertAllGreaterEqual(elapsed, timeout)
@test_util.run_deprecated_v1
def testNcclHintFallbackToRingReduce(self):
"""Tests that setting `communication_hint=nccl` works on non-GPU builds."""


@ -702,15 +702,15 @@ tf_module {
}
member_method {
name: "CollectiveBcastRecv"
argspec: "args=[\'T\', \'group_size\', \'group_key\', \'instance_key\', \'shape\', \'communication_hint\', \'name\'], varargs=None, keywords=None, defaults=[\'auto\', \'None\'], "
argspec: "args=[\'T\', \'group_size\', \'group_key\', \'instance_key\', \'shape\', \'communication_hint\', \'timeout_seconds\', \'name\'], varargs=None, keywords=None, defaults=[\'auto\', \'0\', \'None\'], "
}
member_method {
name: "CollectiveBcastSend"
argspec: "args=[\'input\', \'group_size\', \'group_key\', \'instance_key\', \'shape\', \'communication_hint\', \'name\'], varargs=None, keywords=None, defaults=[\'auto\', \'None\'], "
argspec: "args=[\'input\', \'group_size\', \'group_key\', \'instance_key\', \'shape\', \'communication_hint\', \'timeout_seconds\', \'name\'], varargs=None, keywords=None, defaults=[\'auto\', \'0\', \'None\'], "
}
member_method {
name: "CollectiveGather"
argspec: "args=[\'input\', \'group_size\', \'group_key\', \'instance_key\', \'shape\', \'communication_hint\', \'name\'], varargs=None, keywords=None, defaults=[\'auto\', \'None\'], "
argspec: "args=[\'input\', \'group_size\', \'group_key\', \'instance_key\', \'shape\', \'communication_hint\', \'timeout_seconds\', \'name\'], varargs=None, keywords=None, defaults=[\'auto\', \'0\', \'None\'], "
}
member_method {
name: "CollectivePermute"
@ -718,7 +718,7 @@ tf_module {
}
member_method {
name: "CollectiveReduce"
argspec: "args=[\'input\', \'group_size\', \'group_key\', \'instance_key\', \'merge_op\', \'final_op\', \'subdiv_offsets\', \'wait_for\', \'communication_hint\', \'name\'], varargs=None, keywords=None, defaults=[\'[]\', \'auto\', \'None\'], "
argspec: "args=[\'input\', \'group_size\', \'group_key\', \'instance_key\', \'merge_op\', \'final_op\', \'subdiv_offsets\', \'wait_for\', \'communication_hint\', \'timeout_seconds\', \'name\'], varargs=None, keywords=None, defaults=[\'[]\', \'auto\', \'0\', \'None\'], "
}
member_method {
name: "CombinedNonMaxSuppression"


@ -702,15 +702,15 @@ tf_module {
}
member_method {
name: "CollectiveBcastRecv"
argspec: "args=[\'T\', \'group_size\', \'group_key\', \'instance_key\', \'shape\', \'communication_hint\', \'name\'], varargs=None, keywords=None, defaults=[\'auto\', \'None\'], "
argspec: "args=[\'T\', \'group_size\', \'group_key\', \'instance_key\', \'shape\', \'communication_hint\', \'timeout_seconds\', \'name\'], varargs=None, keywords=None, defaults=[\'auto\', \'0\', \'None\'], "
}
member_method {
name: "CollectiveBcastSend"
argspec: "args=[\'input\', \'group_size\', \'group_key\', \'instance_key\', \'shape\', \'communication_hint\', \'name\'], varargs=None, keywords=None, defaults=[\'auto\', \'None\'], "
argspec: "args=[\'input\', \'group_size\', \'group_key\', \'instance_key\', \'shape\', \'communication_hint\', \'timeout_seconds\', \'name\'], varargs=None, keywords=None, defaults=[\'auto\', \'0\', \'None\'], "
}
member_method {
name: "CollectiveGather"
argspec: "args=[\'input\', \'group_size\', \'group_key\', \'instance_key\', \'shape\', \'communication_hint\', \'name\'], varargs=None, keywords=None, defaults=[\'auto\', \'None\'], "
argspec: "args=[\'input\', \'group_size\', \'group_key\', \'instance_key\', \'shape\', \'communication_hint\', \'timeout_seconds\', \'name\'], varargs=None, keywords=None, defaults=[\'auto\', \'0\', \'None\'], "
}
member_method {
name: "CollectivePermute"
@ -718,7 +718,7 @@ tf_module {
}
member_method {
name: "CollectiveReduce"
argspec: "args=[\'input\', \'group_size\', \'group_key\', \'instance_key\', \'merge_op\', \'final_op\', \'subdiv_offsets\', \'wait_for\', \'communication_hint\', \'name\'], varargs=None, keywords=None, defaults=[\'[]\', \'auto\', \'None\'], "
argspec: "args=[\'input\', \'group_size\', \'group_key\', \'instance_key\', \'merge_op\', \'final_op\', \'subdiv_offsets\', \'wait_for\', \'communication_hint\', \'timeout_seconds\', \'name\'], varargs=None, keywords=None, defaults=[\'[]\', \'auto\', \'0\', \'None\'], "
}
member_method {
name: "CombinedNonMaxSuppression"