Add CollectiveBcastSendV2/CollectiveBcastRecvV2, which take input tensors rather than attributes.

CollectiveBcastSend/Recv accept group_size, group_key, and instance_key as attributes on the op, which means those values are embedded in the NodeDef. The use case motivating this change is a compact representation for SPMD computation: by accepting these inputs as tensors supplied at runtime rather than as attributes, the graph builder can avoid an early explosion of the SPMD program. These ops are not exposed in the `tf.` namespace for now and should be considered experimental.

PiperOrigin-RevId: 348049802
Change-Id: I15ffcb562fb75cc64e5970ac7795cb81ff187fd8
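For illustration only (an assumption layered on top of this change, not code from the commit): with the V2 ops, the group metadata can be fed as ordinary tensors through the Python wrappers added below, so a single traced function can serve many collective groups. The helper name `send_to_group` is hypothetical, and a matching broadcast_recv_v2 must run on the peer devices in the same group.

import tensorflow as tf
from tensorflow.python.ops import collective_ops  # wrappers added in this change

@tf.function
def send_to_group(value, group_size, group_key, instance_key):
  # group_size/group_key/instance_key arrive as int32 tensors at runtime, so
  # they are not baked into the NodeDef as attributes the way the V1 op
  # (CollectiveBcastSend) requires.
  return collective_ops.broadcast_send_v2(
      value, group_size, group_key, instance_key, communication_hint='auto')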
Parent: 94264ee6f2
Commit: e38610399a
Changed paths: tensorflow/core/api_def/base_api, tensorflow/core/grappler/optimizers, tensorflow/core/kernels, tensorflow/core/ops, tensorflow/python, tensorflow/tools/api/golden
tensorflow/core/api_def/base_api (new api_def for CollectiveBcastRecvV2):

@@ -0,0 +1,5 @@
op {
  graph_op_name: "CollectiveBcastRecvV2"
  summary: "Receives a tensor value broadcast from another device."
  visibility: HIDDEN
}

tensorflow/core/api_def/base_api (new api_def for CollectiveBcastSendV2):

@@ -0,0 +1,5 @@
op {
  graph_op_name: "CollectiveBcastSendV2"
  summary: "Broadcasts a tensor value to one or more other devices."
  visibility: HIDDEN
}
tensorflow/core/grappler/optimizers (the new ops join the side-effect exemption list):

@@ -830,7 +830,8 @@ const bool IsExemptFromSideEffectsExecutionValidation(const string& op) {
       // to run asynchronously to avoid deadlock.
       "CollectiveGather", "CollectiveGatherV2", "CollectiveReduce",
       "CollectiveReduceV2", "CollectiveBcastSend", "CollectiveBcastRecv",
-      "NcclAllReduce", "Send", "Recv",
+      "CollectiveBcastSendV2", "CollectiveBcastRecvV2", "NcclAllReduce",
+      "Send", "Recv",

       // Legacy random ops.
       // See details in tensorflow/python/framework/auto_control_deps.py.
tensorflow/core/kernels (tensor_util.h provides tensor::MakeShape, used by the new recv kernel):

@@ -16,6 +16,7 @@ limitations under the License.
#include "tensorflow/core/framework/collective.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor_util.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/lib/core/errors.h"
@@ -742,5 +743,261 @@ REGISTER_KERNEL_BUILDER(Name("CollectiveGatherV2")
                            .HostMemory("instance_key"),
                        CollectiveGatherV2OpKernel);

class CollectiveBcastSendV2OpKernel : public AsyncOpKernel {
 public:
  explicit CollectiveBcastSendV2OpKernel(OpKernelConstruction* c)
      : AsyncOpKernel(c), device_type_(DEVICE_DEFAULT) {
    OP_REQUIRES_OK(c, c->GetAttr("T", &data_type_));
    OP_REQUIRES_OK(c, c->GetAttr("communication_hint", &communication_hint_));
    OP_REQUIRES_OK(c, c->GetAttr("timeout_seconds", &timeout_seconds_));
    const bool is_source = true;
    name_ = strings::StrCat(name(), ": Broadcast(", is_source, ")");
  }

 protected:
  void ComputeAsync(OpKernelContext* c, DoneCallback done) override {
    CollectiveExecutor* col_exec = c->collective_executor();
    OP_REQUIRES_ASYNC(
        c, col_exec,
        errors::Internal(
            "Failed to get CollectiveExecutor from OpKernelContext for Op ",
            name_),
        done);
    const Tensor& input = c->input(0);
    const Tensor& group_size = c->input(1);
    const Tensor& group_key = c->input(2);
    const Tensor& instance_key = c->input(3);
    OP_REQUIRES_ASYNC(
        c, group_size.dims() == 0,
        errors::Internal("Unexpected dimensions on input group_size"), done);
    OP_REQUIRES_ASYNC(
        c, group_key.dims() == 0,
        errors::Internal("Unexpected dimensions on input group_key"), done);
    OP_REQUIRES_ASYNC(
        c, instance_key.dims() == 0,
        errors::Internal("Unexpected dimensions on input instance_key"), done);

    auto col_params = new CollectiveParams();
    col_params->name = name_;
    col_params->group.device_type = device_type_;
    col_params->group.group_size = group_size.unaligned_flat<int32>()(0);
    col_params->group.group_key = group_key.unaligned_flat<int32>()(0);
    col_params->instance.type = BROADCAST_COLLECTIVE;
    col_params->instance.instance_key = instance_key.unaligned_flat<int32>()(0);
    col_params->instance.data_type = data_type_;
    col_params->instance.impl_details.communication_hint = communication_hint_;
    col_params->instance.impl_details.timeout_seconds = timeout_seconds_;
    col_params->is_source = true;
    // Add a default value for subdiv offsets, which is the same as the default
    // value in the V1 op's attribute.
    col_params->instance.impl_details.subdiv_offsets.push_back(0);
    VLOG(1) << "CollectiveBcastSendV2 group_size "
            << col_params->group.group_size << " group_key "
            << col_params->group.group_key << " instance_key "
            << col_params->instance.instance_key;

    auto done_with_cleanup = [col_params, done = std::move(done)]() {
      delete col_params;
      done();
    };

    // Allocate the output tensor, trying to reuse the input.
    Tensor* output = nullptr;
    OP_REQUIRES_OK_ASYNC(
        c, c->forward_input_or_allocate_output({0}, 0, input.shape(), &output),
        done_with_cleanup);
    col_params->instance.shape = input.shape();

    // Resolve the collective params.
    // Schedule the `CompleteParamsAsync` call on a work queue that can handle
    // blocking work because it's not guaranteed that this call cannot block.
    c->collective_executor()->RunClosure([c,
                                          done = std::move(done_with_cleanup),
                                          col_params, col_exec]() {
      VLOG(1) << "CollectiveBcastSendV2 CompleteParams for collective "
              << col_params->name << " device " << c->device()->name()
              << " group " << col_params->group.group_key << " instance "
              << col_params->instance.instance_key;
      col_exec->CompleteParamsAsync(
          c->device()->attributes(), col_params, c->cancellation_manager(),
          [c, done = std::move(done), col_params, col_exec](const Status& s) {
            if (s.ok()) {
              auto actual_done = [c, group_key = col_params->group.group_key,
                                  instance_key =
                                      col_params->instance.instance_key,
                                  done = std::move(done)](const Status& s) {
                VLOG(1) << "CollectiveBcastSendV2 ExecuteAsync done for "
                           "collective "
                        << c->op_kernel().name() << " device "
                        << c->device()->name() << " group " << group_key
                        << " instance " << instance_key << " status " << s;
                OP_REQUIRES_OK_ASYNC(c, s, done);
                done();
              };
              VLOG(1) << "CollectiveBcastSendV2 ExecuteAsync start for "
                         "collective "
                      << col_params->name << " device " << c->device()->name()
                      << " group " << col_params->group.group_key
                      << " instance " << col_params->instance.instance_key;
              col_exec->ExecuteAsync(
                  c, *col_params,
                  CollectiveKey(c, col_params->group.group_key,
                                col_params->instance.instance_key),
                  actual_done);
            } else {
              c->SetStatus(s);
              done();
            }
          });
    });
  }

 private:
  DeviceType device_type_;
  DataType data_type_ = DT_INVALID;
  string communication_hint_;
  float timeout_seconds_ = 0;
  string name_;
};

REGISTER_KERNEL_BUILDER(Name("CollectiveBcastSendV2").Device(DEVICE_CPU),
                        CollectiveBcastSendV2OpKernel);
REGISTER_KERNEL_BUILDER(Name("CollectiveBcastSendV2")
                            .Device(DEVICE_GPU)
                            .HostMemory("group_size")
                            .HostMemory("group_key")
                            .HostMemory("instance_key"),
                        CollectiveBcastSendV2OpKernel);

class CollectiveBcastRecvV2OpKernel : public AsyncOpKernel {
 public:
  explicit CollectiveBcastRecvV2OpKernel(OpKernelConstruction* c)
      : AsyncOpKernel(c), device_type_(DEVICE_DEFAULT) {
    OP_REQUIRES_OK(c, c->GetAttr("T", &data_type_));
    OP_REQUIRES_OK(c, c->GetAttr("communication_hint", &communication_hint_));
    OP_REQUIRES_OK(c, c->GetAttr("timeout_seconds", &timeout_seconds_));
    const bool is_source = false;
    name_ = strings::StrCat(name(), ": Broadcast(", is_source, ")");
  }

 protected:
  void ComputeAsync(OpKernelContext* c, DoneCallback done) override {
    CollectiveExecutor* col_exec = c->collective_executor();
    OP_REQUIRES_ASYNC(
        c, col_exec,
        errors::Internal(
            "Failed to get CollectiveExecutor from OpKernelContext for Op ",
            name_),
        done);
    const Tensor& group_size = c->input(0);
    const Tensor& group_key = c->input(1);
    const Tensor& instance_key = c->input(2);
    const Tensor& shape_tensor = c->input(3);
    OP_REQUIRES_ASYNC(
        c, group_size.dims() == 0,
        errors::Internal("Unexpected dimensions on input group_size"), done);
    OP_REQUIRES_ASYNC(
        c, group_key.dims() == 0,
        errors::Internal("Unexpected dimensions on input group_key"), done);
    OP_REQUIRES_ASYNC(
        c, instance_key.dims() == 0,
        errors::Internal("Unexpected dimensions on input instance_key"), done);

    auto col_params = new CollectiveParams();
    auto done_with_cleanup = [col_params, done = std::move(done)]() {
      delete col_params;
      done();
    };

    OP_REQUIRES_OK_ASYNC(
        c, tensor::MakeShape(shape_tensor, &col_params->instance.shape),
        done_with_cleanup);
    col_params->name = name_;
    col_params->group.device_type = device_type_;
    col_params->group.group_size = group_size.unaligned_flat<int32>()(0);
    col_params->group.group_key = group_key.unaligned_flat<int32>()(0);
    col_params->instance.type = BROADCAST_COLLECTIVE;
    col_params->instance.instance_key = instance_key.unaligned_flat<int32>()(0);
    col_params->instance.data_type = data_type_;
    col_params->instance.impl_details.communication_hint = communication_hint_;
    col_params->instance.impl_details.timeout_seconds = timeout_seconds_;
    col_params->is_source = false;
    // Add a default value for subdiv offsets, which is the same as the default
    // value in the V1 op's attribute.
    col_params->instance.impl_details.subdiv_offsets.push_back(0);
    VLOG(1) << "CollectiveBcastRecvV2 group_size "
            << col_params->group.group_size << " group_key "
            << col_params->group.group_key << " instance_key "
            << col_params->instance.instance_key;

    // Allocate the output tensor.
    Tensor* output = nullptr;
    OP_REQUIRES_OK_ASYNC(c,
                         c->forward_input_or_allocate_output(
                             {0}, 0, col_params->instance.shape, &output),
                         done_with_cleanup);

    // Resolve the collective params.
    // Schedule the `CompleteParamsAsync` call on a work queue that can handle
    // blocking work because it's not guaranteed that this call cannot block.
    c->collective_executor()->RunClosure([c,
                                          done = std::move(done_with_cleanup),
                                          col_params, col_exec]() {
      VLOG(1) << "CollectiveBcastRecvV2 CompleteParams for collective "
              << col_params->name << " device " << c->device()->name()
              << " group " << col_params->group.group_key << " instance "
              << col_params->instance.instance_key;
      col_exec->CompleteParamsAsync(
          c->device()->attributes(), col_params, c->cancellation_manager(),
          [c, done = std::move(done), col_params, col_exec](const Status& s) {
            if (s.ok()) {
              auto actual_done = [c, group_key = col_params->group.group_key,
                                  instance_key =
                                      col_params->instance.instance_key,
                                  done = std::move(done)](const Status& s) {
                VLOG(1) << "CollectiveBcastRecvV2 ExecuteAsync done for "
                           "collective "
                        << c->op_kernel().name() << " device "
                        << c->device()->name() << " group " << group_key
                        << " instance " << instance_key << " status " << s;
                OP_REQUIRES_OK_ASYNC(c, s, done);
                done();
              };
              VLOG(1) << "CollectiveBcastRecvV2 ExecuteAsync start for "
                         "collective "
                      << col_params->name << " device " << c->device()->name()
                      << " group " << col_params->group.group_key
                      << " instance " << col_params->instance.instance_key;
              col_exec->ExecuteAsync(
                  c, *col_params,
                  CollectiveKey(c, col_params->group.group_key,
                                col_params->instance.instance_key),
                  actual_done);
            } else {
              c->SetStatus(s);
              done();
            }
          });
    });
  }

 private:
  DeviceType device_type_;
  DataType data_type_ = DT_INVALID;
  string communication_hint_;
  float timeout_seconds_ = 0;
  string name_;
};

REGISTER_KERNEL_BUILDER(Name("CollectiveBcastRecvV2").Device(DEVICE_CPU),
                        CollectiveBcastRecvV2OpKernel);
REGISTER_KERNEL_BUILDER(Name("CollectiveBcastRecvV2")
                            .Device(DEVICE_GPU)
                            .HostMemory("group_size")
                            .HostMemory("group_key")
                            .HostMemory("instance_key")
                            .HostMemory("shape"),
                        CollectiveBcastRecvV2OpKernel);

}  // namespace
}  // namespace tensorflow
tensorflow/core/ops (op registrations):

@@ -145,4 +145,35 @@ REGISTER_OP("CollectiveGatherV2")
      return Status::OK();
    });

REGISTER_OP("CollectiveBcastSendV2")
    .Input("input: T")
    .Output("data: T")
    .Attr("T: {bool, float, float16, float64, int32, int64}")
    .Input("group_size: int32")
    .Input("group_key: int32")
    .Input("instance_key: int32")
    .Attr("communication_hint: string = 'auto'")
    .Attr("timeout_seconds: float = 0")
    .SetIsStateful()
    .SetShapeFn(shape_inference::UnchangedShape);

REGISTER_OP("CollectiveBcastRecvV2")
    .Output("data: T")
    .Attr("T: {bool, float, float16, float64, int32, int64}")
    .Input("group_size: int32")
    .Input("group_key: int32")
    .Input("instance_key: int32")
    .Input("shape: Tshape")
    .Attr("Tshape: {int32, int64} = DT_INT32")
    .Attr("communication_hint: string = 'auto'")
    .Attr("timeout_seconds: float = 0")
    .SetIsStateful()
    .SetShapeFn([](shape_inference::InferenceContext* c) {
      // The output shape is given by the `shape` input at index 3.
      shape_inference::ShapeHandle out;
      TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(/*input_idx=*/3, &out));
      c->set_output(/*idx=*/0, out);
      return Status::OK();
    });

}  // namespace tensorflow
tensorflow/python (the V2 ops are added to ASYNC_STATEFUL_OPS):

@@ -45,7 +45,9 @@ ASYNC_STATEFUL_OPS = [
    "CollectiveReduce",
    "CollectiveReduceV2",
    "CollectiveBcastSend",
    "CollectiveBcastSendV2",
    "CollectiveBcastRecv",
    "CollectiveBcastRecvV2",
    "NcclAllReduce",
    # We do not add "Send" here since we want it to be added as a control output
    # in order to avoid being pruned.
tensorflow/python (collective ops tests):

@@ -43,6 +43,8 @@ from tensorflow.python.platform import test
class CollectiveOpsV1(object):
  all_reduce = _collective_ops.all_reduce
  all_gather = _collective_ops.all_gather
  broadcast_send = _collective_ops.broadcast_send
  broadcast_recv = _collective_ops.broadcast_recv


class CollectiveOpsV2(object):

@@ -63,6 +65,25 @@ class CollectiveOpsV2(object):
    return _collective_ops.all_gather_v2(t, group_size, group_key, instance_key,
                                         *args, **kwargs)

  @staticmethod
  def broadcast_send(t, shape, dtype, group_size, group_key, instance_key,
                     *args, **kwargs):
    group_size = array_ops.identity(group_size)
    group_key = array_ops.identity(group_key)
    instance_key = array_ops.identity(instance_key)
    return _collective_ops.broadcast_send_v2(t, group_size, group_key,
                                             instance_key, *args, **kwargs)

  @staticmethod
  def broadcast_recv(shape, dtype, group_size, group_key, instance_key, *args,
                     **kwargs):
    group_size = array_ops.identity(group_size)
    group_key = array_ops.identity(group_key)
    instance_key = array_ops.identity(instance_key)
    shape = array_ops.identity(shape)
    return _collective_ops.broadcast_recv_v2(
        shape, dtype, group_size, group_key, instance_key, *args, **kwargs)


device_combination = (
    combinations.combine(device='CPU', communication='RING', required_gpus=0) +

@@ -191,6 +212,42 @@ class CollectiveOpsTest(test.TestCase, parameterized.TestCase):
    for result in run_all_gather_2devices():
      self.assertAllClose(result, [1., 1.], rtol=1e-5, atol=1e-5)

  def testBroadcast(self, collective_ops, device, communication):
    dev0 = '/device:%s:0' % device
    dev1 = '/device:%s:1' % device

    @def_function.function
    def run_broadcast_2devices():
      shape = [3]
      in_value = constant_op.constant([1., 2., 3.], shape=shape)
      group_size = 2
      group_key = 2
      instance_key = 2
      collectives = []
      with ops.device(dev0):
        collectives.append(
            collective_ops.broadcast_send(
                in_value,
                shape,
                in_value.dtype,
                group_size,
                group_key,
                instance_key,
                communication_hint=communication))
      with ops.device(dev1):
        collectives.append(
            collective_ops.broadcast_recv(
                shape,
                in_value.dtype,
                group_size,
                group_key,
                instance_key,
                communication_hint=communication))
      return collectives

    for result in run_broadcast_2devices():
      self.assertAllClose(result, [1., 2., 3.], rtol=1e-5, atol=1e-5)

  def testInstanceKeyScopedUnderGroupKey(self, collective_ops, device,
                                         communication):
    if device == 'GPU' and context.num_gpus() < 4:
tensorflow/python (collective_ops wrappers for the new ops):

@@ -261,6 +261,40 @@ def broadcast_send(t,
      timeout_seconds=timeout)


def broadcast_send_v2(t,
                      group_size,
                      group_key,
                      instance_key,
                      communication_hint='auto',
                      timeout=0):
  """Broadcasts one tensor to a group of others, across devices.

  Args:
    t: the tensor to be sent.
    group_size: an int32 tensor. One plus the number of receiving tensors, i.e.
      the total number of devices participating. Each tensor must reside on a
      different device.
    group_key: an int32 tensor identifying the group of devices.
    instance_key: an int32 tensor identifying the participating group of Ops.
    communication_hint: preferred collective communication. The implementation
      may fall back to another mechanism. Options include `auto`, `ring`, and
      `nccl`.
    timeout: a float. If set to a non-zero value, a completion timeout (in
      seconds) is set to detect staleness; if the timer fires, a
      DeadlineExceededError is raised. This feature is experimental.

  Returns:
    An Op implementing the distributed broadcast send.
  """
  return gen_collective_ops.collective_bcast_send_v2(
      t,
      group_size=group_size,
      group_key=group_key,
      instance_key=instance_key,
      communication_hint=communication_hint.lower(),
      timeout_seconds=timeout)


def broadcast_recv(shape,
                   dtype,
                   group_size,

@@ -302,3 +336,41 @@ def broadcast_recv(shape,
      instance_key=instance_key,
      communication_hint=communication_hint.lower(),
      timeout_seconds=timeout)


def broadcast_recv_v2(shape,
                      dtype,
                      group_size,
                      group_key,
                      instance_key,
                      communication_hint='auto',
                      timeout=0):
  """Receives a broadcast tensor, across devices.

  Args:
    shape: an int tensor. Shape of the tensor to be received.
    dtype: Type of the tensor to be received.
    group_size: an int32 tensor. One plus the number of receiving tensors, i.e.
      the total number of devices participating. Each tensor must reside on a
      different device.
    group_key: an int32 tensor identifying the group of devices.
    instance_key: an int32 tensor identifying the participating group of Ops.
    communication_hint: preferred collective communication. The implementation
      may fall back to another mechanism. Options include `auto`, `ring`, and
      `nccl`.
    timeout: a float. If set to a non-zero value, a completion timeout (in
      seconds) is set to detect staleness; if the timer fires, a
      DeadlineExceededError is raised. This feature is experimental.

  Returns:
    An Op implementing the broadcast receive.
  """
  return gen_collective_ops.collective_bcast_recv_v2(
      T=dtype,
      group_size=group_size,
      group_key=group_key,
      instance_key=instance_key,
      shape=shape,
      communication_hint=communication_hint.lower(),
      timeout_seconds=timeout)
tensorflow/tools/api/golden (the same hunk appears in two golden files):

@@ -752,10 +752,18 @@ tf_module {
    name: "CollectiveBcastRecv"
    argspec: "args=[\'T\', \'group_size\', \'group_key\', \'instance_key\', \'shape\', \'communication_hint\', \'timeout_seconds\', \'name\'], varargs=None, keywords=None, defaults=[\'auto\', \'0\', \'None\'], "
  }
  member_method {
    name: "CollectiveBcastRecvV2"
    argspec: "args=[\'group_size\', \'group_key\', \'instance_key\', \'shape\', \'T\', \'communication_hint\', \'timeout_seconds\', \'name\'], varargs=None, keywords=None, defaults=[\'auto\', \'0\', \'None\'], "
  }
  member_method {
    name: "CollectiveBcastSend"
    argspec: "args=[\'input\', \'group_size\', \'group_key\', \'instance_key\', \'shape\', \'communication_hint\', \'timeout_seconds\', \'name\'], varargs=None, keywords=None, defaults=[\'auto\', \'0\', \'None\'], "
  }
  member_method {
    name: "CollectiveBcastSendV2"
    argspec: "args=[\'input\', \'group_size\', \'group_key\', \'instance_key\', \'communication_hint\', \'timeout_seconds\', \'name\'], varargs=None, keywords=None, defaults=[\'auto\', \'0\', \'None\'], "
  }
  member_method {
    name: "CollectiveGather"
    argspec: "args=[\'input\', \'group_size\', \'group_key\', \'instance_key\', \'shape\', \'communication_hint\', \'timeout_seconds\', \'name\'], varargs=None, keywords=None, defaults=[\'auto\', \'0\', \'None\'], "

@@ -752,10 +752,18 @@ tf_module {
    name: "CollectiveBcastRecv"
    argspec: "args=[\'T\', \'group_size\', \'group_key\', \'instance_key\', \'shape\', \'communication_hint\', \'timeout_seconds\', \'name\'], varargs=None, keywords=None, defaults=[\'auto\', \'0\', \'None\'], "
  }
  member_method {
    name: "CollectiveBcastRecvV2"
    argspec: "args=[\'group_size\', \'group_key\', \'instance_key\', \'shape\', \'T\', \'communication_hint\', \'timeout_seconds\', \'name\'], varargs=None, keywords=None, defaults=[\'auto\', \'0\', \'None\'], "
  }
  member_method {
    name: "CollectiveBcastSend"
    argspec: "args=[\'input\', \'group_size\', \'group_key\', \'instance_key\', \'shape\', \'communication_hint\', \'timeout_seconds\', \'name\'], varargs=None, keywords=None, defaults=[\'auto\', \'0\', \'None\'], "
  }
  member_method {
    name: "CollectiveBcastSendV2"
    argspec: "args=[\'input\', \'group_size\', \'group_key\', \'instance_key\', \'communication_hint\', \'timeout_seconds\', \'name\'], varargs=None, keywords=None, defaults=[\'auto\', \'0\', \'None\'], "
  }
  member_method {
    name: "CollectiveGather"
    argspec: "args=[\'input\', \'group_size\', \'group_key\', \'instance_key\', \'shape\', \'communication_hint\', \'timeout_seconds\', \'name\'], varargs=None, keywords=None, defaults=[\'auto\', \'0\', \'None\'], "
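For reference only (a hedged sketch, not part of the commit): although the V2 ops are hidden from the `tf.` namespace, the goldens above record them under tf.raw_ops, so the receive side could be spelled roughly as follows. The group values are arbitrary examples, and the call only completes when a matching CollectiveBcastSendV2 runs on another device in the same group.

import tensorflow as tf

received = tf.raw_ops.CollectiveBcastRecvV2(
    group_size=tf.constant(2),    # scalar int32: total participating devices
    group_key=tf.constant(2),     # scalar int32: identifies the device group
    instance_key=tf.constant(2),  # scalar int32: identifies this collective instance
    shape=tf.constant([3]),       # shape of the incoming tensor
    T=tf.float32,                 # dtype of the incoming tensor
    communication_hint='auto')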