Add CollectiveBcastSend/RecvV2, which take these inputs as tensors rather than attributes.

CollectiveBcastSend/Recv accepts the following inputs as attributes on the op:
group_size, group_key, and instance_key.  Attributes imply these values are
embedded in the NodeDef.

The use case motivating this change is a compact representation for SPMD
computation.  The goal is to pass those collective op inputs that can be
supplied at runtime as tensors rather than as attributes.  This enables
the graph builder to avoid early explosion of the SPMD program.

This op is not exposed in the `tf.` namespace for now, and should be considered
experimental.

PiperOrigin-RevId: 348049802
Change-Id: I15ffcb562fb75cc64e5970ac7795cb81ff187fd8
This commit is contained in:
Ayush Dubey 2020-12-17 10:40:16 -08:00 committed by TensorFlower Gardener
parent 94264ee6f2
commit e38610399a
10 changed files with 447 additions and 1 deletions

View File

@ -0,0 +1,5 @@
# API definition for the CollectiveBcastRecvV2 graph op (receiving side of a
# broadcast). The op is marked HIDDEN.
op {
graph_op_name: "CollectiveBcastRecvV2"
summary: "Receives a tensor value broadcast from another device."
visibility: HIDDEN
}

View File

@ -0,0 +1,5 @@
# API definition for the CollectiveBcastSendV2 graph op (sending side of a
# broadcast). The op is marked HIDDEN.
op {
graph_op_name: "CollectiveBcastSendV2"
summary: "Broadcasts a tensor value to one or more other devices."
visibility: HIDDEN
}

View File

@ -830,7 +830,8 @@ const bool IsExemptFromSideEffectsExecutionValidation(const string& op) {
// to run asynchronously to avoid deadlock.
"CollectiveGather", "CollectiveGatherV2", "CollectiveReduce",
"CollectiveReduceV2", "CollectiveBcastSend", "CollectiveBcastRecv",
"NcclAllReduce", "Send", "Recv",
"CollectiveBcastSendV2", "CollectiveBcastRecvV2", "NcclAllReduce",
"Send", "Recv",
// Legacy random ops.
// See details in tensorflow/python/framework/auto_control_deps.py.

View File

@ -16,6 +16,7 @@ limitations under the License.
#include "tensorflow/core/framework/collective.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor_util.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/lib/core/errors.h"
@ -742,5 +743,261 @@ REGISTER_KERNEL_BUILDER(Name("CollectiveGatherV2")
.HostMemory("instance_key"),
CollectiveGatherV2OpKernel);
// Asynchronous kernel for the CollectiveBcastSendV2 op: the source side of a
// broadcast collective.
//
// Inputs read at execution time:
//   0: the tensor to broadcast
//   1: group_size   (scalar, read as int32)
//   2: group_key    (scalar, read as int32)
//   3: instance_key (scalar, read as int32)
// Output 0 has the same shape as input 0 and reuses its buffer when possible.
class CollectiveBcastSendV2OpKernel : public AsyncOpKernel {
public:
// Reads the `T`, `communication_hint` and `timeout_seconds` attributes and
// builds the descriptive name stored in `name_`.
explicit CollectiveBcastSendV2OpKernel(OpKernelConstruction* c)
: AsyncOpKernel(c), device_type_(DEVICE_DEFAULT) {
OP_REQUIRES_OK(c, c->GetAttr("T", &data_type_));
OP_REQUIRES_OK(c, c->GetAttr("communication_hint", &communication_hint_));
OP_REQUIRES_OK(c, c->GetAttr("timeout_seconds", &timeout_seconds_));
const bool is_source = true;
name_ = strings::StrCat(name(), ": Broadcast(", is_source, ")");
}
protected:
// Validates the inputs, fills in a CollectiveParams from them, allocates the
// output, then resolves the params and executes the broadcast asynchronously.
// `done` is passed along the callback chain and invoked when the op finishes
// or fails.
void ComputeAsync(OpKernelContext* c, DoneCallback done) override {
CollectiveExecutor* col_exec = c->collective_executor();
OP_REQUIRES_ASYNC(
c, col_exec,
errors::Internal(
"Failed to get CollectiveExecutor from OpKernelContext for Op ",
name_),
done);
const Tensor& input = c->input(0);
const Tensor& group_size = c->input(1);
const Tensor& group_key = c->input(2);
const Tensor& instance_key = c->input(3);
// The three key inputs must be scalars (rank 0). These checks run before
// `col_params` is allocated, so the plain `done` callback is sufficient.
OP_REQUIRES_ASYNC(
c, group_size.dims() == 0,
errors::Internal("Unexpected dimensions on input group_size"), done);
OP_REQUIRES_ASYNC(
c, group_key.dims() == 0,
errors::Internal("Unexpected dimensions on input group_key"), done);
OP_REQUIRES_ASYNC(
c, instance_key.dims() == 0,
errors::Internal("Unexpected dimensions on input instance_key"), done);
// Heap-allocated so it can outlive this call; it is deleted by
// `done_with_cleanup` below.
auto col_params = new CollectiveParams();
col_params->name = name_;
col_params->group.device_type = device_type_;
col_params->group.group_size = group_size.unaligned_flat<int32>()(0);
col_params->group.group_key = group_key.unaligned_flat<int32>()(0);
col_params->instance.type = BROADCAST_COLLECTIVE;
col_params->instance.instance_key = instance_key.unaligned_flat<int32>()(0);
col_params->instance.data_type = data_type_;
col_params->instance.impl_details.communication_hint = communication_hint_;
col_params->instance.impl_details.timeout_seconds = timeout_seconds_;
col_params->is_source = true;
// Add a default value for subdiv offsets, which is the same as the default
// value in the V1 op's attribute.
col_params->instance.impl_details.subdiv_offsets.push_back(0);
VLOG(1) << "CollectiveBcastSendV2 group_size "
<< col_params->group.group_size << " group_key "
<< col_params->group.group_key << " instance_key "
<< col_params->instance.instance_key;
// From here on every exit path must go through `done_with_cleanup`, which
// frees `col_params` before invoking the original `done`.
auto done_with_cleanup = [col_params, done = std::move(done)]() {
delete col_params;
done();
};
// Allocate the output tensor, trying to reuse the input.
Tensor* output = nullptr;
OP_REQUIRES_OK_ASYNC(
c, c->forward_input_or_allocate_output({0}, 0, input.shape(), &output),
done_with_cleanup);
col_params->instance.shape = input.shape();
// Resolve the collective params.
// Schedule the `CompleteParamsAsync` call on a work queue that can handle
// blocking work because it's not guaranteed that this call cannot block.
c->collective_executor()->RunClosure([c,
done = std::move(done_with_cleanup),
col_params, col_exec]() {
VLOG(1) << "CollectiveBcastSendV2 CompleteParams for collective "
<< col_params->name << " device " << c->device()->name()
<< " group " << col_params->group.group_key << " instance "
<< col_params->instance.instance_key;
col_exec->CompleteParamsAsync(
c->device()->attributes(), col_params, c->cancellation_manager(),
[c, done = std::move(done), col_params, col_exec](const Status& s) {
if (s.ok()) {
// Completion callback for ExecuteAsync. It captures the group and
// instance keys by value, so its log line does not dereference
// `col_params`.
auto actual_done = [c, group_key = col_params->group.group_key,
instance_key =
col_params->instance.instance_key,
done = std::move(done)](const Status& s) {
VLOG(1) << "CollectiveBcastSendV2 ExecuteAsync done for "
"collective "
<< c->op_kernel().name() << " device "
<< c->device()->name() << " group " << group_key
<< " instance " << instance_key << " status " << s;
OP_REQUIRES_OK_ASYNC(c, s, done);
done();
};
VLOG(1) << "CollectiveBcastSendV2 ExecuteAsync start for "
"collective "
<< col_params->name << " device " << c->device()->name()
<< " group " << col_params->group.group_key
<< " instance " << col_params->instance.instance_key;
col_exec->ExecuteAsync(
c, *col_params,
CollectiveKey(c, col_params->group.group_key,
col_params->instance.instance_key),
actual_done);
} else {
// Param resolution failed: record the error and finish the op.
c->SetStatus(s);
done();
}
});
});
}
private:
// Initialized to DEVICE_DEFAULT in the constructor.
DeviceType device_type_;
// Value of the `T` attribute.
DataType data_type_ = DT_INVALID;
// Value of the `communication_hint` attribute.
string communication_hint_;
// Value of the `timeout_seconds` attribute.
float timeout_seconds_ = 0;
// Kernel name plus a ": Broadcast(<is_source>)" suffix; used as the
// collective's name and in error messages.
string name_;
};
// CPU registration for CollectiveBcastSendV2.
REGISTER_KERNEL_BUILDER(Name("CollectiveBcastSendV2").Device(DEVICE_CPU),
CollectiveBcastSendV2OpKernel);
// GPU registration for CollectiveBcastSendV2. The group_size, group_key and
// instance_key inputs are placed in host memory.
REGISTER_KERNEL_BUILDER(Name("CollectiveBcastSendV2")
.Device(DEVICE_GPU)
.HostMemory("group_size")
.HostMemory("group_key")
.HostMemory("instance_key"),
CollectiveBcastSendV2OpKernel);
// Asynchronous kernel for the CollectiveBcastRecvV2 op: the receiving side of
// a broadcast collective.
//
// Inputs read at execution time:
//   0: group_size   (scalar, read as int32)
//   1: group_key    (scalar, read as int32)
//   2: instance_key (scalar, read as int32)
//   3: shape        (shape of the tensor to receive)
// Output 0 is a freshly allocated tensor of the requested shape that the
// collective executor fills in.
class CollectiveBcastRecvV2OpKernel : public AsyncOpKernel {
 public:
  // Reads the `T`, `communication_hint` and `timeout_seconds` attributes and
  // builds the descriptive name stored in `name_`.
  explicit CollectiveBcastRecvV2OpKernel(OpKernelConstruction* c)
      : AsyncOpKernel(c), device_type_(DEVICE_DEFAULT) {
    OP_REQUIRES_OK(c, c->GetAttr("T", &data_type_));
    OP_REQUIRES_OK(c, c->GetAttr("communication_hint", &communication_hint_));
    OP_REQUIRES_OK(c, c->GetAttr("timeout_seconds", &timeout_seconds_));
    const bool is_source = false;
    name_ = strings::StrCat(name(), ": Broadcast(", is_source, ")");
  }

 protected:
  // Validates the inputs, fills in a CollectiveParams from them, allocates the
  // output, then resolves the params and executes the broadcast
  // asynchronously. `done` is passed along the callback chain and invoked
  // when the op finishes or fails.
  void ComputeAsync(OpKernelContext* c, DoneCallback done) override {
    CollectiveExecutor* col_exec = c->collective_executor();
    OP_REQUIRES_ASYNC(
        c, col_exec,
        errors::Internal(
            "Failed to get CollectiveExecutor from OpKernelContext for Op ",
            name_),
        done);
    const Tensor& group_size = c->input(0);
    const Tensor& group_key = c->input(1);
    const Tensor& instance_key = c->input(2);
    const Tensor& shape_tensor = c->input(3);
    // The three key inputs must be scalars (rank 0). These checks run before
    // `col_params` is allocated, so the plain `done` callback is sufficient.
    OP_REQUIRES_ASYNC(
        c, group_size.dims() == 0,
        errors::Internal("Unexpected dimensions on input group_size"), done);
    OP_REQUIRES_ASYNC(
        c, group_key.dims() == 0,
        errors::Internal("Unexpected dimensions on input group_key"), done);
    OP_REQUIRES_ASYNC(
        c, instance_key.dims() == 0,
        errors::Internal("Unexpected dimensions on input instance_key"), done);

    // Heap-allocated so it can outlive this call. From here on every exit
    // path must go through `done_with_cleanup`, which frees `col_params`
    // before invoking the original `done`.
    auto col_params = new CollectiveParams();
    auto done_with_cleanup = [col_params, done = std::move(done)]() {
      delete col_params;
      done();
    };
    OP_REQUIRES_OK_ASYNC(
        c, tensor::MakeShape(shape_tensor, &col_params->instance.shape),
        done_with_cleanup);
    col_params->name = name_;
    col_params->group.device_type = device_type_;
    col_params->group.group_size = group_size.unaligned_flat<int32>()(0);
    col_params->group.group_key = group_key.unaligned_flat<int32>()(0);
    col_params->instance.type = BROADCAST_COLLECTIVE;
    col_params->instance.instance_key =
        instance_key.unaligned_flat<int32>()(0);
    col_params->instance.data_type = data_type_;
    col_params->instance.impl_details.communication_hint = communication_hint_;
    col_params->instance.impl_details.timeout_seconds = timeout_seconds_;
    col_params->is_source = false;
    // Add a default value for subdiv offsets, which is the same as the default
    // value in the V1 op's attribute.
    col_params->instance.impl_details.subdiv_offsets.push_back(0);
    VLOG(1) << "CollectiveBcastRecvV2 group_size "
            << col_params->group.group_size << " group_key "
            << col_params->group.group_key << " instance_key "
            << col_params->instance.instance_key;

    // Allocate the output tensor. The receiver has no data input: input 0 is
    // the group_size scalar, so no input may be forwarded as the output
    // buffer. Always allocate a fresh tensor of the requested shape.
    Tensor* output = nullptr;
    OP_REQUIRES_OK_ASYNC(
        c, c->allocate_output(0, col_params->instance.shape, &output),
        done_with_cleanup);

    // Resolve the collective params.
    // Schedule the `CompleteParamsAsync` call on a work queue that can handle
    // blocking work because it's not guaranteed that this call cannot block.
    c->collective_executor()->RunClosure([c,
                                          done = std::move(done_with_cleanup),
                                          col_params, col_exec]() {
      VLOG(1) << "CollectiveBcastRecvV2 CompleteParams for collective "
              << col_params->name << " device " << c->device()->name()
              << " group " << col_params->group.group_key << " instance "
              << col_params->instance.instance_key;
      col_exec->CompleteParamsAsync(
          c->device()->attributes(), col_params, c->cancellation_manager(),
          [c, done = std::move(done), col_params, col_exec](const Status& s) {
            if (s.ok()) {
              // Completion callback for ExecuteAsync. It captures the group
              // and instance keys by value, so its log line does not
              // dereference `col_params`.
              auto actual_done = [c, group_key = col_params->group.group_key,
                                  instance_key =
                                      col_params->instance.instance_key,
                                  done = std::move(done)](const Status& s) {
                VLOG(1) << "CollectiveBcastRecvV2 ExecuteAsync done for "
                           "collective "
                        << c->op_kernel().name() << " device "
                        << c->device()->name() << " group " << group_key
                        << " instance " << instance_key << " status " << s;
                OP_REQUIRES_OK_ASYNC(c, s, done);
                done();
              };
              VLOG(1) << "CollectiveBcastRecvV2 ExecuteAsync start for "
                         "collective "
                      << col_params->name << " device " << c->device()->name()
                      << " group " << col_params->group.group_key
                      << " instance " << col_params->instance.instance_key;
              col_exec->ExecuteAsync(
                  c, *col_params,
                  CollectiveKey(c, col_params->group.group_key,
                                col_params->instance.instance_key),
                  actual_done);
            } else {
              // Param resolution failed: record the error and finish the op.
              c->SetStatus(s);
              done();
            }
          });
    });
  }

 private:
  // Initialized to DEVICE_DEFAULT in the constructor.
  DeviceType device_type_;
  // Value of the `T` attribute.
  DataType data_type_ = DT_INVALID;
  // Value of the `communication_hint` attribute.
  string communication_hint_;
  // Value of the `timeout_seconds` attribute.
  float timeout_seconds_ = 0;
  // Kernel name plus a ": Broadcast(<is_source>)" suffix; used as the
  // collective's name and in error messages.
  string name_;
};
// CPU registration for CollectiveBcastRecvV2.
REGISTER_KERNEL_BUILDER(Name("CollectiveBcastRecvV2").Device(DEVICE_CPU),
CollectiveBcastRecvV2OpKernel);
// GPU registration for CollectiveBcastRecvV2. The group_size, group_key,
// instance_key and shape inputs are placed in host memory.
REGISTER_KERNEL_BUILDER(Name("CollectiveBcastRecvV2")
.Device(DEVICE_GPU)
.HostMemory("group_size")
.HostMemory("group_key")
.HostMemory("instance_key")
.HostMemory("shape"),
CollectiveBcastRecvV2OpKernel);
} // namespace
} // namespace tensorflow

View File

@ -145,4 +145,35 @@ REGISTER_OP("CollectiveGatherV2")
return Status::OK();
});
// Op definition for the broadcast sender. Unlike the V1 op, group_size,
// group_key and instance_key are int32 inputs rather than attributes.
// Inputs, in order: input, group_size, group_key, instance_key.
// The output shape is the same as the shape of `input`.
REGISTER_OP("CollectiveBcastSendV2")
.Input("input: T")
.Output("data: T")
.Attr("T: {bool, float, float16, float64, int32, int64}")
.Input("group_size: int32")
.Input("group_key: int32")
.Input("instance_key: int32")
.Attr("communication_hint: string = 'auto'")
.Attr("timeout_seconds: float = 0")
.SetIsStateful()
.SetShapeFn(shape_inference::UnchangedShape);
// Op definition for the broadcast receiver. Unlike the V1 op, group_size,
// group_key, instance_key and shape are inputs rather than attributes.
// Inputs, in order: group_size, group_key, instance_key, shape.
// `shape` may be int32 or int64 (int32 by default).
REGISTER_OP("CollectiveBcastRecvV2")
.Output("data: T")
.Attr("T: {bool, float, float16, float64, int32, int64}")
.Input("group_size: int32")
.Input("group_key: int32")
.Input("instance_key: int32")
.Input("shape: Tshape")
.Attr("Tshape: {int32, int64} = DT_INT32")
.Attr("communication_hint: string = 'auto'")
.Attr("timeout_seconds: float = 0")
.SetIsStateful()
.SetShapeFn([](shape_inference::InferenceContext* c) {
// The output shape is given by the `shape` input at index 3.
shape_inference::ShapeHandle out;
TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(/*input_idx=*/3, &out));
c->set_output(/*idx=*/0, out);
return Status::OK();
});
} // namespace tensorflow

View File

@ -45,7 +45,9 @@ ASYNC_STATEFUL_OPS = [
"CollectiveReduce",
"CollectiveReduceV2",
"CollectiveBcastSend",
"CollectiveBcastSendV2",
"CollectiveBcastRecv",
"CollectiveBcastRecvV2",
"NcclAllReduce",
# We do not add "Send" here since we want it to be added as a control output
# in order to avoid being pruned.

View File

@ -43,6 +43,8 @@ from tensorflow.python.platform import test
# Namespace that exposes the V1 collective ops under short names. The names
# match the static methods on CollectiveOpsV2 -- presumably so tests can be
# parameterized over both; confirm against the test decorators.
class CollectiveOpsV1(object):
all_reduce = _collective_ops.all_reduce
all_gather = _collective_ops.all_gather
broadcast_send = _collective_ops.broadcast_send
broadcast_recv = _collective_ops.broadcast_recv
class CollectiveOpsV2(object):
@ -63,6 +65,25 @@ class CollectiveOpsV2(object):
return _collective_ops.all_gather_v2(t, group_size, group_key, instance_key,
*args, **kwargs)
@staticmethod
def broadcast_send(t, shape, dtype, group_size, group_key, instance_key,
                   *args, **kwargs):
  """Calls `broadcast_send_v2` with the three keys wrapped in identity ops.

  `shape` and `dtype` are accepted but not forwarded; the V2 sender takes
  neither.
  """
  size_t, group_t, instance_t = [
      array_ops.identity(key)
      for key in (group_size, group_key, instance_key)
  ]
  return _collective_ops.broadcast_send_v2(t, size_t, group_t, instance_t,
                                           *args, **kwargs)
@staticmethod
def broadcast_recv(shape, dtype, group_size, group_key, instance_key, *args,
                   **kwargs):
  """Calls `broadcast_recv_v2` with the keys and shape wrapped in identity ops.

  `dtype` is forwarded unchanged.
  """
  size_t, group_t, instance_t, shape_t = [
      array_ops.identity(value)
      for value in (group_size, group_key, instance_key, shape)
  ]
  return _collective_ops.broadcast_recv_v2(
      shape_t, dtype, size_t, group_t, instance_t, *args, **kwargs)
device_combination = (
combinations.combine(device='CPU', communication='RING', required_gpus=0) +
@ -191,6 +212,42 @@ class CollectiveOpsTest(test.TestCase, parameterized.TestCase):
for result in run_all_gather_2devices():
self.assertAllClose(result, [1., 1.], rtol=1e-5, atol=1e-5)
def testBroadcast(self, collective_ops, device, communication):
  """Broadcasts a 3-element tensor from device 0 and receives it on device 1.

  Both the sender's output and the receiver's output are checked against the
  broadcast value.
  """
  sender_device = '/device:%s:0' % device
  receiver_device = '/device:%s:1' % device

  @def_function.function
  def run_broadcast_2devices():
    shape = [3]
    in_value = constant_op.constant([1., 2., 3.], shape=shape)
    # Both sides use the same group size, group key and instance key.
    group_size = 2
    group_key = 2
    instance_key = 2
    with ops.device(sender_device):
      sent = collective_ops.broadcast_send(
          in_value,
          shape,
          in_value.dtype,
          group_size,
          group_key,
          instance_key,
          communication_hint=communication)
    with ops.device(receiver_device):
      received = collective_ops.broadcast_recv(
          shape,
          in_value.dtype,
          group_size,
          group_key,
          instance_key,
          communication_hint=communication)
    return [sent, received]

  for result in run_broadcast_2devices():
    self.assertAllClose(result, [1., 2., 3.], rtol=1e-5, atol=1e-5)
def testInstanceKeyScopedUnderGroupKey(self, collective_ops, device,
communication):
if device == 'GPU' and context.num_gpus() < 4:

View File

@ -261,6 +261,40 @@ def broadcast_send(t,
timeout_seconds=timeout)
def broadcast_send_v2(t,
                      group_size,
                      group_key,
                      instance_key,
                      communication_hint='auto',
                      timeout=0):
  """Broadcasts one tensor to a group of others, across devices.

  Unlike `broadcast_send`, the group size and keys are passed as tensors.

  Args:
    t: the tensor to be sent.
    group_size: an int32 tensor. One plus the number of receiving tensors,
      i.e. the total number of devices participating. Each tensor must reside
      on a different device.
    group_key: an int32 tensor identifying the group of devices.
    instance_key: an int32 tensor identifying the participating group of Ops.
    communication_hint: preferred collective communication. The
      implementation may fall back to another mechanism. Options include
      `auto`, `ring`, and `nccl`. Lower-cased before being passed to the op.
    timeout: completion timeout in seconds, used to detect staleness. If
      non-zero and the timer goes off, a DeadlineExceededError is raised.
      This feature is experimental.

  Returns:
    An Op implementing the distributed broadcast send.
  """
  hint = communication_hint.lower()
  send_op = gen_collective_ops.collective_bcast_send_v2(
      t,
      group_size=group_size,
      group_key=group_key,
      instance_key=instance_key,
      communication_hint=hint,
      timeout_seconds=timeout)
  return send_op
def broadcast_recv(shape,
dtype,
group_size,
@ -302,3 +336,41 @@ def broadcast_recv(shape,
instance_key=instance_key,
communication_hint=communication_hint.lower(),
timeout_seconds=timeout)
def broadcast_recv_v2(shape,
                      dtype,
                      group_size,
                      group_key,
                      instance_key,
                      communication_hint='auto',
                      timeout=0):
  """Receives a broadcast tensor, across devices.

  Unlike `broadcast_recv`, the shape, group size and keys are passed as
  tensors.

  Args:
    shape: an int tensor. Shape of the tensor to be received.
    dtype: Type of the tensor to be received.
    group_size: an int32 tensor. One plus the number of receiving tensors,
      i.e. the total number of devices participating. Each tensor must reside
      on a different device.
    group_key: an int32 tensor identifying the group of devices.
    instance_key: an int32 tensor identifying the participating group of Ops.
    communication_hint: preferred collective communication. The
      implementation may fall back to another mechanism. Options include
      `auto`, `ring`, and `nccl`. Lower-cased before being passed to the op.
    timeout: completion timeout in seconds, used to detect staleness. If
      non-zero and the timer goes off, a DeadlineExceededError is raised.
      This feature is experimental.

  Returns:
    An Op implementing the broadcast receive.
  """
  hint = communication_hint.lower()
  recv_op = gen_collective_ops.collective_bcast_recv_v2(
      T=dtype,
      group_size=group_size,
      group_key=group_key,
      instance_key=instance_key,
      shape=shape,
      communication_hint=hint,
      timeout_seconds=timeout)
  return recv_op

View File

@ -752,10 +752,18 @@ tf_module {
name: "CollectiveBcastRecv"
argspec: "args=[\'T\', \'group_size\', \'group_key\', \'instance_key\', \'shape\', \'communication_hint\', \'timeout_seconds\', \'name\'], varargs=None, keywords=None, defaults=[\'auto\', \'0\', \'None\'], "
}
member_method {
name: "CollectiveBcastRecvV2"
argspec: "args=[\'group_size\', \'group_key\', \'instance_key\', \'shape\', \'T\', \'communication_hint\', \'timeout_seconds\', \'name\'], varargs=None, keywords=None, defaults=[\'auto\', \'0\', \'None\'], "
}
member_method {
name: "CollectiveBcastSend"
argspec: "args=[\'input\', \'group_size\', \'group_key\', \'instance_key\', \'shape\', \'communication_hint\', \'timeout_seconds\', \'name\'], varargs=None, keywords=None, defaults=[\'auto\', \'0\', \'None\'], "
}
member_method {
name: "CollectiveBcastSendV2"
argspec: "args=[\'input\', \'group_size\', \'group_key\', \'instance_key\', \'communication_hint\', \'timeout_seconds\', \'name\'], varargs=None, keywords=None, defaults=[\'auto\', \'0\', \'None\'], "
}
member_method {
name: "CollectiveGather"
argspec: "args=[\'input\', \'group_size\', \'group_key\', \'instance_key\', \'shape\', \'communication_hint\', \'timeout_seconds\', \'name\'], varargs=None, keywords=None, defaults=[\'auto\', \'0\', \'None\'], "

View File

@ -752,10 +752,18 @@ tf_module {
name: "CollectiveBcastRecv"
argspec: "args=[\'T\', \'group_size\', \'group_key\', \'instance_key\', \'shape\', \'communication_hint\', \'timeout_seconds\', \'name\'], varargs=None, keywords=None, defaults=[\'auto\', \'0\', \'None\'], "
}
member_method {
name: "CollectiveBcastRecvV2"
argspec: "args=[\'group_size\', \'group_key\', \'instance_key\', \'shape\', \'T\', \'communication_hint\', \'timeout_seconds\', \'name\'], varargs=None, keywords=None, defaults=[\'auto\', \'0\', \'None\'], "
}
member_method {
name: "CollectiveBcastSend"
argspec: "args=[\'input\', \'group_size\', \'group_key\', \'instance_key\', \'shape\', \'communication_hint\', \'timeout_seconds\', \'name\'], varargs=None, keywords=None, defaults=[\'auto\', \'0\', \'None\'], "
}
member_method {
name: "CollectiveBcastSendV2"
argspec: "args=[\'input\', \'group_size\', \'group_key\', \'instance_key\', \'communication_hint\', \'timeout_seconds\', \'name\'], varargs=None, keywords=None, defaults=[\'auto\', \'0\', \'None\'], "
}
member_method {
name: "CollectiveGather"
argspec: "args=[\'input\', \'group_size\', \'group_key\', \'instance_key\', \'shape\', \'communication_hint\', \'timeout_seconds\', \'name\'], varargs=None, keywords=None, defaults=[\'auto\', \'0\', \'None\'], "