Add timeout to collective ops to detect deadlocks.
The timeout is passed as an argument to each collective op. When the value is non-zero, a completion timeout is set to detect stale collectives. If the timeout fires, execution is aborted with a DEADLINE_EXCEEDED error.

PiperOrigin-RevId: 313861868
Change-Id: I7fee45736608ad7fbcc9dd980db2fd302c9cb4df
This commit: 66529c35a7
Parent: 85396efcd3
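
As a rough, hedged illustration of the behavior this enables (modeled on the testCollectiveTimeoutV2 case added in the test diff below; the two-logical-CPU setup and the argument values come from that test, while the function name lone_all_reduce is made up):

# Sketch only: mirrors testCollectiveTimeoutV2 below; lone_all_reduce is a
# made-up name, and the argument values are the ones used in that test.
from tensorflow.python.eager import context, def_function
from tensorflow.python.framework import config, constant_op, errors, ops
from tensorflow.python.ops import collective_ops

# Split the single physical CPU into two logical devices, as the test does.
cpus = config.list_physical_devices('CPU')
config.set_logical_device_configuration(
    cpus[0],
    [context.LogicalDeviceConfiguration(), context.LogicalDeviceConfiguration()])

@def_function.function
def lone_all_reduce():
  # Report a group size of 2 but launch only one member: the missing peer
  # means the collective can never complete, so the 4.5 s timeout fires.
  with ops.device('/CPU:0'):
    return collective_ops.all_reduce(
        constant_op.constant([1, 2, 3, 4]),
        group_size=2, group_key=20, instance_key=30,
        merge_op='Add', final_op='Id', timeout=4.5)

try:
  lone_all_reduce()
except errors.DeadlineExceededError as e:
  print('Collective timed out:', e)  # raised instead of hanging forever

Without the timeout, the lone participant would block indefinitely waiting for its peer; with it, the op fails with DEADLINE_EXCEEDED after roughly the requested number of seconds.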
@@ -221,23 +221,42 @@ void BaseCollectiveExecutor::ExecuteAsync(OpKernelContext* ctx,
                                           const CollectiveParams& col_params,
                                           const string& exec_key,
                                           StatusCallback done) {
+  const auto is_callback_called = std::make_shared<std::atomic<bool>>(false);
+
   // On any individual collective Op failure we need to abort the
   // BufRendezvous so that other Ops in the instance don't hang
   // waiting for transmissions that will never happen. Do so after a
   // delay so that the original error status is more likely to
   // propagate up, and peers are unlikely to re-create the purged
   // BufRendezvous by late-arriving requests.
-  StatusCallback done_safe = [this, done](const Status& s) {
-    if (!s.ok()) {
-      Ref();  // Ensure this lasts until the closure executes.
-      SchedNonBlockingClosureAfter(1000000, [this, s] {
-        remote_access_->buf_rendezvous()->StartAbort(s);
-        Unref();
-      });
+  StatusCallback done_safe = [this, done, is_callback_called](const Status& s) {
+    auto should_call_callback = !is_callback_called->exchange(true);
+    if (should_call_callback) {
+      if (!s.ok()) {
+        Ref();  // Ensure this lasts until the closure executes.
+        SchedNonBlockingClosureAfter(1000000, [this, s] {
+          remote_access_->buf_rendezvous()->StartAbort(s);
+          Unref();
+        });
+      }
+      done(s);
     }
-    done(s);
   };
 
+  auto timeout_microseconds = static_cast<int64>(
+      col_params.instance.impl_details.timeout_seconds * 1'000'000);
+  if (timeout_microseconds > 0) {
+    // TODO(xldrx): Share the timeout watchdog thread among collectives.
+    SchedNonBlockingClosureAfter(
+        timeout_microseconds, [is_callback_called, done_safe] {
+          if (!is_callback_called->load()) {
+            auto status = Status(error::DEADLINE_EXCEEDED,
+                                 "Collective has timed out during execution.");
+            done_safe(status);
+          }
+        });
+  }
+
   Tensor* output = ctx->mutable_output(0);
   const Tensor* input = (col_params.instance.type == REDUCTION_COLLECTIVE ||
                          col_params.instance.type == GATHER_COLLECTIVE ||
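
The hunk above wraps done so it can run at most once: whichever caller flips the shared std::atomic<bool> first (via exchange(true)) gets to invoke the callback, and the non-blocking closure scheduled timeout_microseconds later only delivers DEADLINE_EXCEEDED if the real completion has not already claimed the flag. Below is the same call-at-most-once watchdog pattern sketched in plain Python, purely as an illustration; none of these names are TensorFlow APIs.

# Illustration only (not TensorFlow code): a completion callback wrapped so it
# runs at most once, plus a watchdog timer that reports a deadline error if
# the real completion never arrives. `with_timeout` is a made-up helper.
import threading

def with_timeout(done, timeout_seconds):
  called = [False]
  lock = threading.Lock()  # plays the role of std::atomic<bool>::exchange

  def done_once(status):
    with lock:
      should_call = not called[0]
      called[0] = True
    if should_call:
      done(status)

  if timeout_seconds > 0:
    watchdog = threading.Timer(
        timeout_seconds,
        lambda: done_once('DEADLINE_EXCEEDED: collective timed out'))
    watchdog.daemon = True  # the C++ side uses SchedNonBlockingClosureAfter
    watchdog.start()
  return done_once

# The real completion wins if it arrives first; the later timer fire is a no-op.
done_safe = with_timeout(print, timeout_seconds=0.5)
done_safe('OK')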
@@ -284,7 +303,30 @@ void BaseCollectiveExecutor::CompleteParamsAsync(
     const string& device, CollectiveParams* cp, CancellationManager* cancel_mgr,
     StatusCallback done) {
   cp->instance.gpu_ring_order = *gpu_ring_order_;
-  cem_->GetParamResolver()->CompleteParamsAsync(device, cp, cancel_mgr, done);
+  const auto is_callback_called = std::make_shared<std::atomic<bool>>(false);
+  auto done_with_timeout = done;
+  auto timeout_microseconds =
+      static_cast<int64>(cp->instance.impl_details.timeout_seconds * 1'000'000);
+  if (timeout_microseconds > 0) {
+    // TODO(xldrx): Share the timeout watchdog thread among collectives.
+    SchedNonBlockingClosureAfter(
+        timeout_microseconds, [is_callback_called, done] {
+          if (!is_callback_called->load()) {
+            auto status =
+                Status(error::DEADLINE_EXCEEDED,
+                       "Collective has timed out waiting for other workers.");
+            done(status);
+          }
+        });
+    done_with_timeout = [is_callback_called, done](const Status& s) {
+      auto should_call_callback = !is_callback_called->exchange(true);
+      if (should_call_callback) {
+        done(s);
+      }
+    };
+  }
+  cem_->GetParamResolver()->CompleteParamsAsync(device, cp, cancel_mgr,
+                                                done_with_timeout);
 }
 
 Status BaseCollectiveExecutor::CreateCollective(
@@ -84,6 +84,8 @@ struct CollImplDetails {
       dependencies;  // collective instances on which this node depends
   string communication_hint;  // user-supplied hint for implementation choice,
                               // e.g. ring or nccl
+  float timeout_seconds;  // If non zero, set a completion timeout for the
+                          // collective op to detect staleness.
 };
 
 // Data common to all members of a collective instance.
@@ -85,6 +85,9 @@ class CollectiveGatherOpKernel : public CollectiveOpKernel {
     OP_REQUIRES_OK(
         c, c->GetAttr("communication_hint",
                       &col_params_.instance.impl_details.communication_hint));
+    OP_REQUIRES_OK(
+        c, c->GetAttr("timeout_seconds",
+                      &col_params_.instance.impl_details.timeout_seconds));
     const NodeDef& real_node = c->def();
     col_params_.name = strings::StrCat(real_node.name(), ": Gather");
     col_params_.group.device_type = c->device_type();

@@ -176,10 +179,14 @@ class CollectiveReduceOpKernel : public CollectiveOpKernel {
     OP_REQUIRES_OK(
         c, c->GetAttr("communication_hint",
                       &col_params_.instance.impl_details.communication_hint));
+    OP_REQUIRES_OK(
+        c, c->GetAttr("timeout_seconds",
+                      &col_params_.instance.impl_details.timeout_seconds));
     VLOG(2) << "CollectiveReduce instance " << col_params_.instance.instance_key
             << " merge_op " << merge_op_name << " final_op " << final_op_name
             << " communication_hint "
-            << col_params_.instance.impl_details.communication_hint;
+            << col_params_.instance.impl_details.communication_hint
+            << " timeout " << col_params_.instance.impl_details.timeout_seconds;
 
     const NodeDef& real_node = c->def();
     col_params_.name = strings::StrCat(real_node.name(), ": Reduce(",

@@ -284,6 +291,9 @@ class CollectiveBcastSendOpKernel : public CollectiveOpKernel {
     OP_REQUIRES_OK(
         c, c->GetAttr("communication_hint",
                       &col_params_.instance.impl_details.communication_hint));
+    OP_REQUIRES_OK(
+        c, c->GetAttr("timeout_seconds",
+                      &col_params_.instance.impl_details.timeout_seconds));
     col_params_.is_source = true;
     col_params_.instance.impl_details.subdiv_offsets = {0};
 

@@ -363,6 +373,9 @@ class CollectiveBcastRecvOpKernel : public CollectiveOpKernel {
     OP_REQUIRES_OK(
         c, c->GetAttr("communication_hint",
                       &col_params_.instance.impl_details.communication_hint));
+    OP_REQUIRES_OK(
+        c, c->GetAttr("timeout_seconds",
+                      &col_params_.instance.impl_details.timeout_seconds));
     col_params_.is_source = false;
     col_params_.instance.impl_details.subdiv_offsets = {0};
 
@@ -31,6 +31,7 @@ REGISTER_OP("CollectiveReduce")
     .Attr("subdiv_offsets: list(int)")
     .Attr("wait_for: list(int) = []")
     .Attr("communication_hint: string = 'auto'")
+    .Attr("timeout_seconds: float = 0")
     .SetIsStateful()
     .SetShapeFn(shape_inference::UnchangedShape);
 

@@ -43,6 +44,7 @@ REGISTER_OP("CollectiveGather")
    .Attr("instance_key: int")
     .Attr("shape: shape")
     .Attr("communication_hint: string = 'auto'")
+    .Attr("timeout_seconds: float = 0")
     .SetIsStateful()
     .SetShapeFn([](shape_inference::InferenceContext* c) {
       // Scalar input is not supported.

@@ -86,6 +88,7 @@ REGISTER_OP("CollectiveBcastSend")
     .Attr("instance_key: int")
     .Attr("shape: shape")
     .Attr("communication_hint: string = 'auto'")
+    .Attr("timeout_seconds: float = 0")
     .SetIsStateful()
     .SetShapeFn(shape_inference::ExplicitShape);
 

@@ -97,6 +100,7 @@ REGISTER_OP("CollectiveBcastRecv")
     .Attr("instance_key: int")
     .Attr("shape: shape")
     .Attr("communication_hint: string = 'auto'")
+    .Attr("timeout_seconds: float = 0")
     .SetIsStateful()
     .SetShapeFn(shape_inference::ExplicitShape);
 
@@ -20,8 +20,15 @@ from __future__ import print_function
 from tensorflow.python.ops import gen_collective_ops
 
 
-def all_reduce(t, group_size, group_key, instance_key, merge_op, final_op,
-               subdiv_offsets=(0,), communication_hint='auto'):
+def all_reduce(t,
+               group_size,
+               group_key,
+               instance_key,
+               merge_op,
+               final_op,
+               subdiv_offsets=(0,),
+               communication_hint='auto',
+               timeout=0):
   """Reduces tensors collectively, across devices.
 
   Args:

@@ -40,6 +47,9 @@ def all_reduce(t, group_size, group_key, instance_key, merge_op, final_op,
     communication_hint: preferred collective communication. The implementation
       may fall back to another mechanism. Options include `auto`, `ring`, and
       `nccl`.
+    timeout: If set to a non zero, set a completion timeout to detect staleness.
+      If the timer goes off, a DeadlineExceededError is raised.
+      The timeout value in seconds. This feature is experimental.
 
   Returns:
     An Op implementing the distributed reduction.

@@ -57,11 +67,16 @@ def all_reduce(t, group_size, group_key, instance_key, merge_op, final_op,
       merge_op=merge_op,
       final_op=final_op,
       subdiv_offsets=subdiv_offsets,
-      communication_hint=communication_hint.lower())
+      communication_hint=communication_hint.lower(),
+      timeout_seconds=timeout)
 
 
-def all_gather(t, group_size, group_key, instance_key,
-               communication_hint='auto'):
+def all_gather(t,
+               group_size,
+               group_key,
+               instance_key,
+               communication_hint='auto',
+               timeout=0):
   """Accumulates tensors collectively, across devices, along first dimension.
 
   Args:

@@ -73,6 +88,9 @@ def all_gather(t, group_size, group_key, instance_key,
     communication_hint: preferred collective communication. The implementation
       may fall back to another mechanism. Options include `auto`, `ring`, and
       `nccl`.
+    timeout: If set to a non zero, set a completion timeout to detect staleness.
+      If the timer goes off, a DeadlineExceededError is raised.
+      The timeout value in seconds. This feature is experimental.
 
   Returns:
     An Op implementing the distributed operation.

@@ -88,11 +106,18 @@ def all_gather(t, group_size, group_key, instance_key,
       group_size=group_size,
       group_key=group_key,
       instance_key=instance_key,
-      communication_hint=communication_hint.lower())
+      communication_hint=communication_hint.lower(),
+      timeout_seconds=timeout)
 
 
-def broadcast_send(t, shape, dtype, group_size, group_key, instance_key,
-                   communication_hint='auto'):
+def broadcast_send(t,
+                   shape,
+                   dtype,
+                   group_size,
+                   group_key,
+                   instance_key,
+                   communication_hint='auto',
+                   timeout=0):
   """Broadcasts one tensor to a group of others, across devices.
 
   Args:

@@ -107,6 +132,9 @@ def broadcast_send(t, shape, dtype, group_size, group_key, instance_key,
     communication_hint: preferred collective communication. The implementation
      may fall back to another mechanism. Options include `auto`, `ring`, and
       `nccl`.
+    timeout: If set to a non zero, set a completion timeout to detect staleness.
+      If the timer goes off, a DeadlineExceededError is raised.
+      The timeout value in seconds. This feature is experimental.
 
   Returns:
     An Op implementing the distributed broadcast send.

@@ -139,11 +167,17 @@ def broadcast_send(t, shape, dtype, group_size, group_key, instance_key,
       group_size=group_size,
       group_key=group_key,
       instance_key=instance_key,
-      communication_hint=communication_hint.lower())
+      communication_hint=communication_hint.lower(),
+      timeout_seconds=timeout)
 
 
-def broadcast_recv(shape, dtype, group_size, group_key, instance_key,
-                   communication_hint='auto'):
+def broadcast_recv(shape,
+                   dtype,
+                   group_size,
+                   group_key,
+                   instance_key,
+                   communication_hint='auto',
+                   timeout=0):
   """Receives a broadcasts tensor, across devices.
 
   Args:

@@ -157,6 +191,9 @@ def broadcast_recv(shape, dtype, group_size, group_key, instance_key,
     communication_hint: preferred collective communication. The implementation
       may fall back to another mechanism. Options include `auto`, `ring`, and
       `nccl`.
+    timeout: If set to a non zero, set a completion timeout to detect staleness.
+      If the timer goes off, a DeadlineExceededError is raised.
+      The timeout value in seconds. This feature is experimental.
 
   Returns:
     An Op implementing the broadcast receive.

@@ -173,4 +210,5 @@ def broadcast_recv(shape, dtype, group_size, group_key, instance_key,
       group_size=group_size,
       group_key=group_key,
       instance_key=instance_key,
-      communication_hint=communication_hint.lower())
+      communication_hint=communication_hint.lower(),
+      timeout_seconds=timeout)
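
For reference, a short sketch of the broadcast pair with the new timeout argument, using only the signatures added above; the tf.function wrapper, the two-logical-CPU placement, and the shape/group/instance values are illustrative assumptions rather than part of this change.

# Sketch only: exercises the new `timeout` argument on broadcast_send/recv.
# Assumes two logical CPU devices are configured (as in the tests below);
# the shapes and key values here are arbitrary.
from tensorflow.python.eager import def_function
from tensorflow.python.framework import constant_op, dtypes, ops
from tensorflow.python.ops import collective_ops

@def_function.function
def broadcast_pair():
  group_size, group_key, instance_key = 2, 1, 1
  with ops.device('/CPU:0'):
    sent = collective_ops.broadcast_send(
        constant_op.constant([1.0, 2.0, 3.0]), [3], dtypes.float32,
        group_size, group_key, instance_key, timeout=10)
  with ops.device('/CPU:1'):
    received = collective_ops.broadcast_recv(
        [3], dtypes.float32, group_size, group_key, instance_key, timeout=10)
  # If either peer fails to show up within 10 seconds, the op fails with
  # DeadlineExceededError instead of blocking forever.
  return sent, received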
@@ -18,6 +18,8 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
+import time
+
 from tensorflow.core.protobuf import config_pb2
 from tensorflow.core.protobuf import rewriter_config_pb2
 from tensorflow.python.eager import context

@@ -40,11 +42,21 @@ from tensorflow.python.platform import tf_logging as logging
 
 class CollectiveOpTest(test.TestCase):
 
-  def _testCollectiveReduce(self, inputs, expected, set_graph_key,
-                            communication_hint='auto', fp16=False,
-                            instance_key=1, merge_op='Add', final_op='Div'):
+  def _testCollectiveReduce(self,
+                            inputs,
+                            expected,
+                            set_graph_key,
+                            communication_hint='auto',
+                            fp16=False,
+                            instance_key=1,
+                            merge_op='Add',
+                            final_op='Div',
+                            timeout=0,
+                            reported_group_size=None):
     group_key = 1
     group_size = len(inputs)
+    if reported_group_size is None:
+      reported_group_size = group_size
     device_type = 'CPU'
     config = config_pb2.ConfigProto(device_count={device_type: group_size})
     devices = ['/{}:{}'.format(device_type, i) for i in range(group_size)]

@@ -55,9 +67,16 @@ class CollectiveOpTest(test.TestCase):
       with ops.device(devices[i]):
         tensor = constant_op.constant(inputs[i], dtype=(
             dtypes.float16 if fp16 else dtypes.float32))
-        colred.append(collective_ops.all_reduce(
-            tensor, group_size, group_key, instance_key, merge_op, final_op,
-            communication_hint=communication_hint))
+        colred.append(
+            collective_ops.all_reduce(
+                tensor,
+                reported_group_size,
+                group_key,
+                instance_key,
+                merge_op,
+                final_op,
+                communication_hint=communication_hint,
+                timeout=timeout))
     run_options = config_pb2.RunOptions()
     if set_graph_key:
       run_options.experimental.collective_graph_key = 1

@@ -117,6 +136,69 @@ class CollectiveOpTest(test.TestCase):
         [0.3, 1.3, 2.3, 3.3, 4.3, 5.3, 6.3, 7.3],
         [0.2, 1.2, 2.2, 3.2, 4.2, 5.2, 6.2, 7.2])
 
+  @test_util.run_deprecated_v1
+  def testCollectiveTimeoutV1(self):
+    timeout = 4.5
+    kwargs = dict(
+        inputs=[[i + j + 0.1 for i in range(8)] for j in range(3)],
+        expected=[1 + i + 0.1 for i in range(8)],
+        set_graph_key=True,
+        timeout=timeout)
+
+    self._testCollectiveReduce(**kwargs)
+
+    start_time = time.time()
+    with self.assertRaisesRegex(
+        errors.DeadlineExceededError,
+        'Collective has timed out waiting for other workers'):
+      self._testCollectiveReduce(
+          reported_group_size=len(kwargs['inputs']) + 1, **kwargs)
+    elapsed = time.time() - start_time
+    self.assertAllGreaterEqual(elapsed, timeout)
+
+  @test_util.run_v2_only
+  def testCollectiveTimeoutV2(self):
+    context._reset_context()
+    timeout = 4.5
+    cpus = config.list_physical_devices('CPU')
+    self.assertEqual(len(cpus), 1)
+    config.set_logical_device_configuration(cpus[0], [
+        context.LogicalDeviceConfiguration(),
+        context.LogicalDeviceConfiguration()
+    ])
+    context.ensure_initialized()
+
+    @def_function.function
+    def run_all_reduce(group_size, reported_group_size=None):
+      group_key = 20
+      instance_key = 30
+      tensor = [1, 2, 3, 4]
+      results = []
+      if reported_group_size is None:
+        reported_group_size = group_size
+      for i in range(group_size):
+        with ops.device('/CPU:{}'.format(i)):
+          input_data = constant_op.constant(tensor)
+          collective_op = collective_ops.all_reduce(
+              input_data,
+              group_size=reported_group_size,
+              group_key=group_key,
+              instance_key=instance_key,
+              merge_op='Add',
+              final_op='Id',
+              timeout=timeout)
+          results.append(collective_op)
+      return results
+
+    run_all_reduce(2, 2)
+
+    start_time = time.time()
+    with self.assertRaisesRegex(errors.DeadlineExceededError,
+                                'Collective has timed out during execution'):
+      run_all_reduce(1, 2)
+    elapsed = time.time() - start_time
+    self.assertAllGreaterEqual(elapsed, timeout)
+
   @test_util.run_deprecated_v1
   def testNcclHintFallbackToRingReduce(self):
     """Tests that setting `communication_hint=nccl` works on non-GPU builds."""
@@ -702,15 +702,15 @@ tf_module {
   }
   member_method {
     name: "CollectiveBcastRecv"
-    argspec: "args=[\'T\', \'group_size\', \'group_key\', \'instance_key\', \'shape\', \'communication_hint\', \'name\'], varargs=None, keywords=None, defaults=[\'auto\', \'None\'], "
+    argspec: "args=[\'T\', \'group_size\', \'group_key\', \'instance_key\', \'shape\', \'communication_hint\', \'timeout_seconds\', \'name\'], varargs=None, keywords=None, defaults=[\'auto\', \'0\', \'None\'], "
   }
   member_method {
     name: "CollectiveBcastSend"
-    argspec: "args=[\'input\', \'group_size\', \'group_key\', \'instance_key\', \'shape\', \'communication_hint\', \'name\'], varargs=None, keywords=None, defaults=[\'auto\', \'None\'], "
+    argspec: "args=[\'input\', \'group_size\', \'group_key\', \'instance_key\', \'shape\', \'communication_hint\', \'timeout_seconds\', \'name\'], varargs=None, keywords=None, defaults=[\'auto\', \'0\', \'None\'], "
   }
   member_method {
     name: "CollectiveGather"
-    argspec: "args=[\'input\', \'group_size\', \'group_key\', \'instance_key\', \'shape\', \'communication_hint\', \'name\'], varargs=None, keywords=None, defaults=[\'auto\', \'None\'], "
+    argspec: "args=[\'input\', \'group_size\', \'group_key\', \'instance_key\', \'shape\', \'communication_hint\', \'timeout_seconds\', \'name\'], varargs=None, keywords=None, defaults=[\'auto\', \'0\', \'None\'], "
   }
   member_method {
     name: "CollectivePermute"

@@ -718,7 +718,7 @@ tf_module {
   }
   member_method {
     name: "CollectiveReduce"
-    argspec: "args=[\'input\', \'group_size\', \'group_key\', \'instance_key\', \'merge_op\', \'final_op\', \'subdiv_offsets\', \'wait_for\', \'communication_hint\', \'name\'], varargs=None, keywords=None, defaults=[\'[]\', \'auto\', \'None\'], "
+    argspec: "args=[\'input\', \'group_size\', \'group_key\', \'instance_key\', \'merge_op\', \'final_op\', \'subdiv_offsets\', \'wait_for\', \'communication_hint\', \'timeout_seconds\', \'name\'], varargs=None, keywords=None, defaults=[\'[]\', \'auto\', \'0\', \'None\'], "
   }
   member_method {
     name: "CombinedNonMaxSuppression"
@@ -702,15 +702,15 @@ tf_module {
   }
   member_method {
     name: "CollectiveBcastRecv"
-    argspec: "args=[\'T\', \'group_size\', \'group_key\', \'instance_key\', \'shape\', \'communication_hint\', \'name\'], varargs=None, keywords=None, defaults=[\'auto\', \'None\'], "
+    argspec: "args=[\'T\', \'group_size\', \'group_key\', \'instance_key\', \'shape\', \'communication_hint\', \'timeout_seconds\', \'name\'], varargs=None, keywords=None, defaults=[\'auto\', \'0\', \'None\'], "
   }
   member_method {
     name: "CollectiveBcastSend"
-    argspec: "args=[\'input\', \'group_size\', \'group_key\', \'instance_key\', \'shape\', \'communication_hint\', \'name\'], varargs=None, keywords=None, defaults=[\'auto\', \'None\'], "
+    argspec: "args=[\'input\', \'group_size\', \'group_key\', \'instance_key\', \'shape\', \'communication_hint\', \'timeout_seconds\', \'name\'], varargs=None, keywords=None, defaults=[\'auto\', \'0\', \'None\'], "
   }
   member_method {
     name: "CollectiveGather"
-    argspec: "args=[\'input\', \'group_size\', \'group_key\', \'instance_key\', \'shape\', \'communication_hint\', \'name\'], varargs=None, keywords=None, defaults=[\'auto\', \'None\'], "
+    argspec: "args=[\'input\', \'group_size\', \'group_key\', \'instance_key\', \'shape\', \'communication_hint\', \'timeout_seconds\', \'name\'], varargs=None, keywords=None, defaults=[\'auto\', \'0\', \'None\'], "
  }
   member_method {
     name: "CollectivePermute"

@@ -718,7 +718,7 @@ tf_module {
   }
   member_method {
     name: "CollectiveReduce"
-    argspec: "args=[\'input\', \'group_size\', \'group_key\', \'instance_key\', \'merge_op\', \'final_op\', \'subdiv_offsets\', \'wait_for\', \'communication_hint\', \'name\'], varargs=None, keywords=None, defaults=[\'[]\', \'auto\', \'None\'], "
+    argspec: "args=[\'input\', \'group_size\', \'group_key\', \'instance_key\', \'merge_op\', \'final_op\', \'subdiv_offsets\', \'wait_for\', \'communication_hint\', \'timeout_seconds\', \'name\'], varargs=None, keywords=None, defaults=[\'[]\', \'auto\', \'0\', \'None\'], "
   }
   member_method {
     name: "CombinedNonMaxSuppression"