Add timeout to collective ops to detect deadlocks.
The timeout is passed as an argument to each collective op. When the value is non-zero, a completion timeout is set to detect staleness. If the timeout fires, execution is aborted with a DEADLINE_EXCEEDED error.

PiperOrigin-RevId: 313861868
Change-Id: I7fee45736608ad7fbcc9dd980db2fd302c9cb4df
commit 66529c35a7 (parent 85396efcd3)
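For context, a minimal sketch of how the new argument is meant to be used from the Python API (not part of the commit; the device string and key values are illustrative, and eager execution with a missing second group member is assumed):

    from tensorflow.python.framework import constant_op
    from tensorflow.python.framework import errors
    from tensorflow.python.framework import ops
    from tensorflow.python.ops import collective_ops

    with ops.device('/CPU:0'):
      try:
        collective_ops.all_reduce(
            constant_op.constant([1.0, 2.0]),
            group_size=2,    # two members expected...
            group_key=1,
            instance_key=1,
            merge_op='Add',
            final_op='Id',
            timeout=5)       # ...but give up after 5 seconds
      except errors.DeadlineExceededError:
        print('collective timed out instead of hanging forever')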
@@ -221,23 +221,42 @@ void BaseCollectiveExecutor::ExecuteAsync(OpKernelContext* ctx,
                                           const CollectiveParams& col_params,
                                           const string& exec_key,
                                           StatusCallback done) {
+  const auto is_callback_called = std::make_shared<std::atomic<bool>>(false);
+
   // On any individual collective Op failure we need to abort the
   // BufRendezvous so that other Ops in the instance don't hang
   // waiting for transmissions that will never happen. Do so after a
   // delay so that the original error status is more likely to
   // propagate up, and peers are unlikely to re-create the purged
   // BufRendezvous by late-arriving requests.
-  StatusCallback done_safe = [this, done](const Status& s) {
-    if (!s.ok()) {
-      Ref();  // Ensure this lasts until the closure executes.
-      SchedNonBlockingClosureAfter(1000000, [this, s] {
-        remote_access_->buf_rendezvous()->StartAbort(s);
-        Unref();
-      });
-    }
-    done(s);
-  };
+  StatusCallback done_safe = [this, done, is_callback_called](const Status& s) {
+    auto should_call_callback = !is_callback_called->exchange(true);
+    if (should_call_callback) {
+      if (!s.ok()) {
+        Ref();  // Ensure this lasts until the closure executes.
+        SchedNonBlockingClosureAfter(1000000, [this, s] {
+          remote_access_->buf_rendezvous()->StartAbort(s);
+          Unref();
+        });
+      }
+      done(s);
+    }
+  };
+
+  auto timeout_microseconds = static_cast<int64>(
+      col_params.instance.impl_details.timeout_seconds * 1'000'000);
+  if (timeout_microseconds > 0) {
+    // TODO(xldrx): Share the timeout watchdog thread among collectives.
+    SchedNonBlockingClosureAfter(
+        timeout_microseconds, [is_callback_called, done_safe] {
+          if (!is_callback_called->load()) {
+            auto status = Status(error::DEADLINE_EXCEEDED,
+                                 "Collective has timed out during execution.");
+            done_safe(status);
+          }
+        });
+  }
 
   Tensor* output = ctx->mutable_output(0);
   const Tensor* input = (col_params.instance.type == REDUCTION_COLLECTIVE ||
                          col_params.instance.type == GATHER_COLLECTIVE ||
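The invariant above, and in CompleteParamsAsync below, is that the user callback runs at most once: whichever fires first, the normal completion or the scheduled watchdog, flips the shared atomic flag and the other path backs off. A rough Python rendering of that first-callback-wins pattern (illustrative only; `with_deadline` is a hypothetical helper and `threading.Timer` stands in for `SchedNonBlockingClosureAfter`):

    import threading

    def with_deadline(done, timeout_seconds):
      """Wraps `done` so it fires at most once, even if the deadline hits."""
      lock = threading.Lock()
      called = [False]

      def done_once(status):
        with lock:  # plays the role of std::atomic<bool>::exchange
          if called[0]:
            return  # the other path (completion or watchdog) already won
          called[0] = True
        done(status)

      if timeout_seconds > 0:
        watchdog = threading.Timer(
            timeout_seconds,
            lambda: done_once('DEADLINE_EXCEEDED: collective timed out'))
        watchdog.daemon = True
        watchdog.start()
      return done_once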
@@ -284,7 +303,30 @@ void BaseCollectiveExecutor::CompleteParamsAsync(
     const string& device, CollectiveParams* cp, CancellationManager* cancel_mgr,
     StatusCallback done) {
   cp->instance.gpu_ring_order = *gpu_ring_order_;
-  cem_->GetParamResolver()->CompleteParamsAsync(device, cp, cancel_mgr, done);
+  const auto is_callback_called = std::make_shared<std::atomic<bool>>(false);
+  auto done_with_timeout = done;
+  auto timeout_microseconds =
+      static_cast<int64>(cp->instance.impl_details.timeout_seconds * 1'000'000);
+  if (timeout_microseconds > 0) {
+    // TODO(xldrx): Share the timeout watchdog thread among collectives.
+    SchedNonBlockingClosureAfter(
+        timeout_microseconds, [is_callback_called, done] {
+          if (!is_callback_called->load()) {
+            auto status =
+                Status(error::DEADLINE_EXCEEDED,
+                       "Collective has timed out waiting for other workers.");
+            done(status);
+          }
+        });
+    done_with_timeout = [is_callback_called, done](const Status& s) {
+      auto should_call_callback = !is_callback_called->exchange(true);
+      if (should_call_callback) {
+        done(s);
+      }
+    };
+  }
+  cem_->GetParamResolver()->CompleteParamsAsync(device, cp, cancel_mgr,
+                                                done_with_timeout);
 }
 
 Status BaseCollectiveExecutor::CreateCollective(
@@ -84,6 +84,8 @@ struct CollImplDetails {
       dependencies;           // collective instances on which this node depends
   string communication_hint;  // user-supplied hint for implementation choice,
                               // e.g. ring or nccl
+  float timeout_seconds;      // If non zero, set a completion timeout for the
+                              // collective op to detect staleness.
 };
 
 // Data common to all members of a collective instance.
@@ -85,6 +85,9 @@ class CollectiveGatherOpKernel : public CollectiveOpKernel {
     OP_REQUIRES_OK(
         c, c->GetAttr("communication_hint",
                       &col_params_.instance.impl_details.communication_hint));
+    OP_REQUIRES_OK(
+        c, c->GetAttr("timeout_seconds",
+                      &col_params_.instance.impl_details.timeout_seconds));
     const NodeDef& real_node = c->def();
     col_params_.name = strings::StrCat(real_node.name(), ": Gather");
     col_params_.group.device_type = c->device_type();
@@ -176,10 +179,14 @@ class CollectiveReduceOpKernel : public CollectiveOpKernel {
     OP_REQUIRES_OK(
         c, c->GetAttr("communication_hint",
                       &col_params_.instance.impl_details.communication_hint));
+    OP_REQUIRES_OK(
+        c, c->GetAttr("timeout_seconds",
+                      &col_params_.instance.impl_details.timeout_seconds));
     VLOG(2) << "CollectiveReduce instance " << col_params_.instance.instance_key
             << " merge_op " << merge_op_name << " final_op " << final_op_name
             << " communication_hint "
-            << col_params_.instance.impl_details.communication_hint;
+            << col_params_.instance.impl_details.communication_hint
+            << " timeout " << col_params_.instance.impl_details.timeout_seconds;
 
     const NodeDef& real_node = c->def();
     col_params_.name = strings::StrCat(real_node.name(), ": Reduce(",
@@ -284,6 +291,9 @@ class CollectiveBcastSendOpKernel : public CollectiveOpKernel {
     OP_REQUIRES_OK(
         c, c->GetAttr("communication_hint",
                       &col_params_.instance.impl_details.communication_hint));
+    OP_REQUIRES_OK(
+        c, c->GetAttr("timeout_seconds",
+                      &col_params_.instance.impl_details.timeout_seconds));
     col_params_.is_source = true;
     col_params_.instance.impl_details.subdiv_offsets = {0};
 
@@ -363,6 +373,9 @@ class CollectiveBcastRecvOpKernel : public CollectiveOpKernel {
     OP_REQUIRES_OK(
         c, c->GetAttr("communication_hint",
                       &col_params_.instance.impl_details.communication_hint));
+    OP_REQUIRES_OK(
+        c, c->GetAttr("timeout_seconds",
+                      &col_params_.instance.impl_details.timeout_seconds));
     col_params_.is_source = false;
     col_params_.instance.impl_details.subdiv_offsets = {0};
 
@ -31,6 +31,7 @@ REGISTER_OP("CollectiveReduce")
|
|||||||
.Attr("subdiv_offsets: list(int)")
|
.Attr("subdiv_offsets: list(int)")
|
||||||
.Attr("wait_for: list(int) = []")
|
.Attr("wait_for: list(int) = []")
|
||||||
.Attr("communication_hint: string = 'auto'")
|
.Attr("communication_hint: string = 'auto'")
|
||||||
|
.Attr("timeout_seconds: float = 0")
|
||||||
.SetIsStateful()
|
.SetIsStateful()
|
||||||
.SetShapeFn(shape_inference::UnchangedShape);
|
.SetShapeFn(shape_inference::UnchangedShape);
|
||||||
|
|
||||||
@ -43,6 +44,7 @@ REGISTER_OP("CollectiveGather")
|
|||||||
.Attr("instance_key: int")
|
.Attr("instance_key: int")
|
||||||
.Attr("shape: shape")
|
.Attr("shape: shape")
|
||||||
.Attr("communication_hint: string = 'auto'")
|
.Attr("communication_hint: string = 'auto'")
|
||||||
|
.Attr("timeout_seconds: float = 0")
|
||||||
.SetIsStateful()
|
.SetIsStateful()
|
||||||
.SetShapeFn([](shape_inference::InferenceContext* c) {
|
.SetShapeFn([](shape_inference::InferenceContext* c) {
|
||||||
// Scalar input is not supported.
|
// Scalar input is not supported.
|
||||||
@ -86,6 +88,7 @@ REGISTER_OP("CollectiveBcastSend")
|
|||||||
.Attr("instance_key: int")
|
.Attr("instance_key: int")
|
||||||
.Attr("shape: shape")
|
.Attr("shape: shape")
|
||||||
.Attr("communication_hint: string = 'auto'")
|
.Attr("communication_hint: string = 'auto'")
|
||||||
|
.Attr("timeout_seconds: float = 0")
|
||||||
.SetIsStateful()
|
.SetIsStateful()
|
||||||
.SetShapeFn(shape_inference::ExplicitShape);
|
.SetShapeFn(shape_inference::ExplicitShape);
|
||||||
|
|
||||||
@ -97,6 +100,7 @@ REGISTER_OP("CollectiveBcastRecv")
|
|||||||
.Attr("instance_key: int")
|
.Attr("instance_key: int")
|
||||||
.Attr("shape: shape")
|
.Attr("shape: shape")
|
||||||
.Attr("communication_hint: string = 'auto'")
|
.Attr("communication_hint: string = 'auto'")
|
||||||
|
.Attr("timeout_seconds: float = 0")
|
||||||
.SetIsStateful()
|
.SetIsStateful()
|
||||||
.SetShapeFn(shape_inference::ExplicitShape);
|
.SetShapeFn(shape_inference::ExplicitShape);
|
||||||
|
|
||||||
|
@@ -20,8 +20,15 @@ from __future__ import print_function
 from tensorflow.python.ops import gen_collective_ops
 
 
-def all_reduce(t, group_size, group_key, instance_key, merge_op, final_op,
-               subdiv_offsets=(0,), communication_hint='auto'):
+def all_reduce(t,
+               group_size,
+               group_key,
+               instance_key,
+               merge_op,
+               final_op,
+               subdiv_offsets=(0,),
+               communication_hint='auto',
+               timeout=0):
   """Reduces tensors collectively, across devices.
 
   Args:
@@ -40,6 +47,9 @@ def all_reduce(t, group_size, group_key, instance_key, merge_op, final_op,
     communication_hint: preferred collective communication. The implementation
       may fall back to another mechanism. Options include `auto`, `ring`, and
       `nccl`.
+    timeout: If set to a non zero, set a completion timeout to detect staleness.
+      If the timer goes off, a DeadlineExceededError is raised.
+      The timeout value in seconds. This feature is experimental.
 
   Returns:
     An Op implementing the distributed reduction.
@@ -57,11 +67,16 @@ def all_reduce(t, group_size, group_key, instance_key, merge_op, final_op,
       merge_op=merge_op,
       final_op=final_op,
       subdiv_offsets=subdiv_offsets,
-      communication_hint=communication_hint.lower())
+      communication_hint=communication_hint.lower(),
+      timeout_seconds=timeout)
 
 
-def all_gather(t, group_size, group_key, instance_key,
-               communication_hint='auto'):
+def all_gather(t,
+               group_size,
+               group_key,
+               instance_key,
+               communication_hint='auto',
+               timeout=0):
   """Accumulates tensors collectively, across devices, along first dimension.
 
   Args:
@@ -73,6 +88,9 @@ def all_gather(t, group_size, group_key, instance_key,
     communication_hint: preferred collective communication. The implementation
       may fall back to another mechanism. Options include `auto`, `ring`, and
       `nccl`.
+    timeout: If set to a non zero, set a completion timeout to detect staleness.
+      If the timer goes off, a DeadlineExceededError is raised.
+      The timeout value in seconds. This feature is experimental.
 
   Returns:
     An Op implementing the distributed operation.
@@ -88,11 +106,18 @@ def all_gather(t, group_size, group_key, instance_key,
       group_size=group_size,
       group_key=group_key,
       instance_key=instance_key,
-      communication_hint=communication_hint.lower())
+      communication_hint=communication_hint.lower(),
+      timeout_seconds=timeout)
 
 
-def broadcast_send(t, shape, dtype, group_size, group_key, instance_key,
-                   communication_hint='auto'):
+def broadcast_send(t,
+                   shape,
+                   dtype,
+                   group_size,
+                   group_key,
+                   instance_key,
+                   communication_hint='auto',
+                   timeout=0):
   """Broadcasts one tensor to a group of others, across devices.
 
   Args:
@@ -107,6 +132,9 @@ def broadcast_send(t, shape, dtype, group_size, group_key, instance_key,
     communication_hint: preferred collective communication. The implementation
      may fall back to another mechanism. Options include `auto`, `ring`, and
      `nccl`.
+    timeout: If set to a non zero, set a completion timeout to detect staleness.
+      If the timer goes off, a DeadlineExceededError is raised.
+      The timeout value in seconds. This feature is experimental.
 
   Returns:
     An Op implementing the distributed broadcast send.
@@ -139,11 +167,17 @@ def broadcast_send(t, shape, dtype, group_size, group_key, instance_key,
       group_size=group_size,
       group_key=group_key,
       instance_key=instance_key,
-      communication_hint=communication_hint.lower())
+      communication_hint=communication_hint.lower(),
+      timeout_seconds=timeout)
 
 
-def broadcast_recv(shape, dtype, group_size, group_key, instance_key,
-                   communication_hint='auto'):
+def broadcast_recv(shape,
+                   dtype,
+                   group_size,
+                   group_key,
+                   instance_key,
+                   communication_hint='auto',
+                   timeout=0):
   """Receives a broadcasts tensor, across devices.
 
   Args:
@@ -157,6 +191,9 @@ def broadcast_recv(shape, dtype, group_size, group_key, instance_key,
     communication_hint: preferred collective communication. The implementation
       may fall back to another mechanism. Options include `auto`, `ring`, and
      `nccl`.
+    timeout: If set to a non zero, set a completion timeout to detect staleness.
+      If the timer goes off, a DeadlineExceededError is raised.
+      The timeout value in seconds. This feature is experimental.
 
   Returns:
     An Op implementing the broadcast receive.
@@ -173,4 +210,5 @@ def broadcast_recv(shape, dtype, group_size, group_key, instance_key,
       group_size=group_size,
       group_key=group_key,
       instance_key=instance_key,
-      communication_hint=communication_hint.lower())
+      communication_hint=communication_hint.lower(),
+      timeout_seconds=timeout)
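Each wrapper above forwards the new argument as `timeout_seconds` on the underlying op, so both sides of a collective can be given the same deadline. A small sketch of a broadcast pair (illustrative values; assumes two logical CPU devices are configured, as the tests further down do):

    from tensorflow.python.framework import constant_op
    from tensorflow.python.framework import dtypes
    from tensorflow.python.framework import ops
    from tensorflow.python.ops import collective_ops

    group_size, group_key, instance_key = 2, 1, 1
    with ops.device('/CPU:0'):
      value = constant_op.constant([1.0, 2.0, 3.0])
      send = collective_ops.broadcast_send(
          value, value.shape, value.dtype,
          group_size, group_key, instance_key,
          timeout=5)
    with ops.device('/CPU:1'):
      recv = collective_ops.broadcast_recv(
          [3], dtypes.float32,
          group_size, group_key, instance_key,
          timeout=5)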
@@ -18,6 +18,8 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
+import time
+
 from tensorflow.core.protobuf import config_pb2
 from tensorflow.core.protobuf import rewriter_config_pb2
 from tensorflow.python.eager import context
@@ -40,11 +42,21 @@ from tensorflow.python.platform import tf_logging as logging
 
 class CollectiveOpTest(test.TestCase):
 
-  def _testCollectiveReduce(self, inputs, expected, set_graph_key,
-                            communication_hint='auto', fp16=False,
-                            instance_key=1, merge_op='Add', final_op='Div'):
+  def _testCollectiveReduce(self,
+                            inputs,
+                            expected,
+                            set_graph_key,
+                            communication_hint='auto',
+                            fp16=False,
+                            instance_key=1,
+                            merge_op='Add',
+                            final_op='Div',
+                            timeout=0,
+                            reported_group_size=None):
     group_key = 1
     group_size = len(inputs)
+    if reported_group_size is None:
+      reported_group_size = group_size
     device_type = 'CPU'
     config = config_pb2.ConfigProto(device_count={device_type: group_size})
     devices = ['/{}:{}'.format(device_type, i) for i in range(group_size)]
@@ -55,9 +67,16 @@ class CollectiveOpTest(test.TestCase):
         with ops.device(devices[i]):
           tensor = constant_op.constant(inputs[i], dtype=(
               dtypes.float16 if fp16 else dtypes.float32))
-          colred.append(collective_ops.all_reduce(
-              tensor, group_size, group_key, instance_key, merge_op, final_op,
-              communication_hint=communication_hint))
+          colred.append(
+              collective_ops.all_reduce(
+                  tensor,
+                  reported_group_size,
+                  group_key,
+                  instance_key,
+                  merge_op,
+                  final_op,
+                  communication_hint=communication_hint,
+                  timeout=timeout))
     run_options = config_pb2.RunOptions()
     if set_graph_key:
       run_options.experimental.collective_graph_key = 1
@@ -117,6 +136,69 @@ class CollectiveOpTest(test.TestCase):
         [0.3, 1.3, 2.3, 3.3, 4.3, 5.3, 6.3, 7.3],
         [0.2, 1.2, 2.2, 3.2, 4.2, 5.2, 6.2, 7.2])
 
+  @test_util.run_deprecated_v1
+  def testCollectiveTimeoutV1(self):
+    timeout = 4.5
+    kwargs = dict(
+        inputs=[[i + j + 0.1 for i in range(8)] for j in range(3)],
+        expected=[1 + i + 0.1 for i in range(8)],
+        set_graph_key=True,
+        timeout=timeout)
+
+    self._testCollectiveReduce(**kwargs)
+
+    start_time = time.time()
+    with self.assertRaisesRegex(
+        errors.DeadlineExceededError,
+        'Collective has timed out waiting for other workers'):
+      self._testCollectiveReduce(
+          reported_group_size=len(kwargs['inputs']) + 1, **kwargs)
+    elapsed = time.time() - start_time
+    self.assertAllGreaterEqual(elapsed, timeout)
+
+  @test_util.run_v2_only
+  def testCollectiveTimeoutV2(self):
+    context._reset_context()
+    timeout = 4.5
+    cpus = config.list_physical_devices('CPU')
+    self.assertEqual(len(cpus), 1)
+    config.set_logical_device_configuration(cpus[0], [
+        context.LogicalDeviceConfiguration(),
+        context.LogicalDeviceConfiguration()
+    ])
+    context.ensure_initialized()
+
+    @def_function.function
+    def run_all_reduce(group_size, reported_group_size=None):
+      group_key = 20
+      instance_key = 30
+      tensor = [1, 2, 3, 4]
+      results = []
+      if reported_group_size is None:
+        reported_group_size = group_size
+      for i in range(group_size):
+        with ops.device('/CPU:{}'.format(i)):
+          input_data = constant_op.constant(tensor)
+          collective_op = collective_ops.all_reduce(
+              input_data,
+              group_size=reported_group_size,
+              group_key=group_key,
+              instance_key=instance_key,
+              merge_op='Add',
+              final_op='Id',
+              timeout=timeout)
+          results.append(collective_op)
+      return results
+
+    run_all_reduce(2, 2)
+
+    start_time = time.time()
+    with self.assertRaisesRegex(errors.DeadlineExceededError,
+                                'Collective has timed out during execution'):
+      run_all_reduce(1, 2)
+    elapsed = time.time() - start_time
+    self.assertAllGreaterEqual(elapsed, timeout)
+
   @test_util.run_deprecated_v1
   def testNcclHintFallbackToRingReduce(self):
     """Tests that setting `communication_hint=nccl` works on non-GPU builds."""
@@ -702,15 +702,15 @@ tf_module {
   }
   member_method {
     name: "CollectiveBcastRecv"
-    argspec: "args=[\'T\', \'group_size\', \'group_key\', \'instance_key\', \'shape\', \'communication_hint\', \'name\'], varargs=None, keywords=None, defaults=[\'auto\', \'None\'], "
+    argspec: "args=[\'T\', \'group_size\', \'group_key\', \'instance_key\', \'shape\', \'communication_hint\', \'timeout_seconds\', \'name\'], varargs=None, keywords=None, defaults=[\'auto\', \'0\', \'None\'], "
   }
   member_method {
     name: "CollectiveBcastSend"
-    argspec: "args=[\'input\', \'group_size\', \'group_key\', \'instance_key\', \'shape\', \'communication_hint\', \'name\'], varargs=None, keywords=None, defaults=[\'auto\', \'None\'], "
+    argspec: "args=[\'input\', \'group_size\', \'group_key\', \'instance_key\', \'shape\', \'communication_hint\', \'timeout_seconds\', \'name\'], varargs=None, keywords=None, defaults=[\'auto\', \'0\', \'None\'], "
   }
   member_method {
     name: "CollectiveGather"
-    argspec: "args=[\'input\', \'group_size\', \'group_key\', \'instance_key\', \'shape\', \'communication_hint\', \'name\'], varargs=None, keywords=None, defaults=[\'auto\', \'None\'], "
+    argspec: "args=[\'input\', \'group_size\', \'group_key\', \'instance_key\', \'shape\', \'communication_hint\', \'timeout_seconds\', \'name\'], varargs=None, keywords=None, defaults=[\'auto\', \'0\', \'None\'], "
   }
   member_method {
     name: "CollectivePermute"
@@ -718,7 +718,7 @@ tf_module {
   }
   member_method {
     name: "CollectiveReduce"
-    argspec: "args=[\'input\', \'group_size\', \'group_key\', \'instance_key\', \'merge_op\', \'final_op\', \'subdiv_offsets\', \'wait_for\', \'communication_hint\', \'name\'], varargs=None, keywords=None, defaults=[\'[]\', \'auto\', \'None\'], "
+    argspec: "args=[\'input\', \'group_size\', \'group_key\', \'instance_key\', \'merge_op\', \'final_op\', \'subdiv_offsets\', \'wait_for\', \'communication_hint\', \'timeout_seconds\', \'name\'], varargs=None, keywords=None, defaults=[\'[]\', \'auto\', \'0\', \'None\'], "
   }
   member_method {
     name: "CombinedNonMaxSuppression"
@@ -702,15 +702,15 @@ tf_module {
   }
   member_method {
     name: "CollectiveBcastRecv"
-    argspec: "args=[\'T\', \'group_size\', \'group_key\', \'instance_key\', \'shape\', \'communication_hint\', \'name\'], varargs=None, keywords=None, defaults=[\'auto\', \'None\'], "
+    argspec: "args=[\'T\', \'group_size\', \'group_key\', \'instance_key\', \'shape\', \'communication_hint\', \'timeout_seconds\', \'name\'], varargs=None, keywords=None, defaults=[\'auto\', \'0\', \'None\'], "
   }
   member_method {
     name: "CollectiveBcastSend"
-    argspec: "args=[\'input\', \'group_size\', \'group_key\', \'instance_key\', \'shape\', \'communication_hint\', \'name\'], varargs=None, keywords=None, defaults=[\'auto\', \'None\'], "
+    argspec: "args=[\'input\', \'group_size\', \'group_key\', \'instance_key\', \'shape\', \'communication_hint\', \'timeout_seconds\', \'name\'], varargs=None, keywords=None, defaults=[\'auto\', \'0\', \'None\'], "
   }
   member_method {
     name: "CollectiveGather"
-    argspec: "args=[\'input\', \'group_size\', \'group_key\', \'instance_key\', \'shape\', \'communication_hint\', \'name\'], varargs=None, keywords=None, defaults=[\'auto\', \'None\'], "
+    argspec: "args=[\'input\', \'group_size\', \'group_key\', \'instance_key\', \'shape\', \'communication_hint\', \'timeout_seconds\', \'name\'], varargs=None, keywords=None, defaults=[\'auto\', \'0\', \'None\'], "
  }
  member_method {
    name: "CollectivePermute"
@@ -718,7 +718,7 @@ tf_module {
   }
   member_method {
     name: "CollectiveReduce"
-    argspec: "args=[\'input\', \'group_size\', \'group_key\', \'instance_key\', \'merge_op\', \'final_op\', \'subdiv_offsets\', \'wait_for\', \'communication_hint\', \'name\'], varargs=None, keywords=None, defaults=[\'[]\', \'auto\', \'None\'], "
+    argspec: "args=[\'input\', \'group_size\', \'group_key\', \'instance_key\', \'merge_op\', \'final_op\', \'subdiv_offsets\', \'wait_for\', \'communication_hint\', \'timeout_seconds\', \'name\'], varargs=None, keywords=None, defaults=[\'[]\', \'auto\', \'0\', \'None\'], "
   }
   member_method {
     name: "CombinedNonMaxSuppression"