Adding NCCL sum op, register all_sum gradient.
Streamlining nccl test.

PiperOrigin-RevId: 168347428

commit 9b9e54b344, parent bc300318e7

Changed files (all under tensorflow/contrib/nccl):
  __init__.py
  kernels/nccl_manager.cc
  kernels/nccl_manager.h
  kernels/nccl_ops.cc
  ops/nccl_ops.cc
  python/ops/nccl_ops.py
  python/ops/nccl_ops_test.py
tensorflow/contrib/nccl/__init__.py

@@ -18,6 +18,7 @@
 @@all_min
 @@all_prod
 @@all_sum
+@@reduce_sum
 @@broadcast
 
 """
@@ -31,6 +32,7 @@ from tensorflow.contrib.nccl.python.ops.nccl_ops import all_min
 from tensorflow.contrib.nccl.python.ops.nccl_ops import all_prod
 from tensorflow.contrib.nccl.python.ops.nccl_ops import all_sum
 from tensorflow.contrib.nccl.python.ops.nccl_ops import broadcast
+from tensorflow.contrib.nccl.python.ops.nccl_ops import reduce_sum
 
 from tensorflow.python.util.all_util import remove_undocumented
 remove_undocumented(__name__)
tensorflow/contrib/nccl/kernels/nccl_manager.cc

@@ -260,7 +260,7 @@ NcclManager::Communicator* NcclManager::GetCommunicator(
 
   std::vector<ncclComm_t> nccl_comms(num_devices);
   auto result = ncclCommInitAll(nccl_comms.data(), num_devices, devices.data());
-  CHECK_EQ(result, ncclSuccess);
+  CHECK_EQ(result, ncclSuccess) << ncclGetErrorString(result);
   for (int rank = 0; rank < num_devices; ++rank) {
     members[rank].nccl_comm = nccl_comms[rank];
   }
@@ -307,6 +307,35 @@ void NcclManager::AddBroadcastRecv(
                  kBroadcast, ncclSum /* unused */);
 }
 
+void NcclManager::AddReduceSend(int num_devices, const string& key,
+                                ncclRedOp_t reduction_op,
+                                perftools::gputools::StreamExecutor* executor,
+                                int gpu_device_id, EventMgr* event_mgr,
+                                perftools::gputools::Stream* tensor_stream,
+                                const Tensor* in_t, Tensor* temp_t,
+                                DoneCallback done_callback) {
+  std::unique_ptr<Participant> participant(
+      new Participant(in_t, temp_t, event_mgr, tensor_stream, executor,
+                      gpu_device_id, std::move(done_callback)));
+  AddParticipant(num_devices, key, std::move(participant), in_t->dtype(),
+                 kReduce, reduction_op);
+}
+
+void NcclManager::AddReduceRecv(int num_devices, const string& key,
+                                ncclRedOp_t reduction_op,
+                                perftools::gputools::StreamExecutor* executor,
+                                int gpu_device_id, EventMgr* event_mgr,
+                                perftools::gputools::Stream* tensor_stream,
+                                const Tensor* in_t, Tensor* out_t,
+                                DoneCallback done_callback) {
+  std::unique_ptr<Participant> participant(
+      new Participant(in_t, out_t, event_mgr, tensor_stream, executor,
+                      gpu_device_id, std::move(done_callback)));
+  participant->root = true;
+  AddParticipant(num_devices, key, std::move(participant), in_t->dtype(),
+                 kReduce, reduction_op);
+}
+
 void NcclManager::AddParticipant(int num_devices, const string& key,
                                  std::unique_ptr<Participant> participant,
                                  DataType data_type,
@@ -431,6 +460,14 @@ void NcclManager::LoopKernelLaunches(NcclStream* nccl_stream) {
             collective->root_rank, nccl_comm, *cu_stream);
         break;
       }
+      case kReduce: {
+        const void* sendbuff = p->in_t->tensor_data().data();
+        void* recvbuff = const_cast<char*>(p->out_t->tensor_data().data());
+        nccl_result = ncclReduce(sendbuff, recvbuff, p->in_t->NumElements(),
+                                 data_type, collective->reduction_op,
+                                 collective->root_rank, nccl_comm, *cu_stream);
+        break;
+      }
     }
 
     // Run the done_callback when the nccl kernel finishes running.
@@ -441,7 +478,7 @@ void NcclManager::LoopKernelLaunches(NcclStream* nccl_stream) {
       // Propagate the error, but note that if other members of the collective
       // did launch their kernels, then they are hanging.
       collective->participants[rank]->done_callback(errors::Unknown(
-          "Error invoking AllReduce: ", ncclGetErrorString(nccl_result)));
+          "Error invoking NCCL: ", ncclGetErrorString(nccl_result)));
     }
 
 // TODO(cwhipkey): use RefCounted after figuring out how to use in a
tensorflow/contrib/nccl/kernels/nccl_manager.h

@@ -75,10 +75,28 @@ class NcclManager {
                      perftools::gputools::Stream* tensor_stream,
                      Tensor* out_t, DoneCallback done_callback);
 
+  // AddReduceSend and AddReduceRecv combine to send data from all senders
+  // to one receiver.
+  void AddReduceSend(int num_devices, const string& key,
+                     ncclRedOp_t reduction_op,
+                     perftools::gputools::StreamExecutor* executor,
+                     int gpu_device_id, EventMgr* event_mgr,
+                     perftools::gputools::Stream* tensor_stream,
+                     const Tensor* in_t, Tensor* temp_t,
+                     DoneCallback done_callback);
+  void AddReduceRecv(int num_devices, const string& key,
+                     ncclRedOp_t reduction_op,
+                     perftools::gputools::StreamExecutor* executor,
+                     int gpu_device_id, EventMgr* event_mgr,
+                     perftools::gputools::Stream* tensor_stream,
+                     const Tensor* in_t, Tensor* out_t,
+                     DoneCallback done_callback);
+
  private:
   enum CollectiveType {
     kAllReduce = 1,
     kBroadcast = 2,
+    kReduce = 3,
   };
   struct Collective;
   struct Communicator;
tensorflow/contrib/nccl/kernels/nccl_ops.cc

@@ -15,6 +15,7 @@ limitations under the License.
 
 #if GOOGLE_CUDA
 
+#include <memory>
 #include <unordered_map>
 #include <vector>
 
@@ -58,11 +59,9 @@ class NcclAsyncOpBase : public AsyncOpKernel {
   TF_DISALLOW_COPY_AND_ASSIGN(NcclAsyncOpBase);
 };
 
-// To execute a single all-reduce, this kernel is called once for each of the
-// <k> devices in the communicator.
-class NcclAllReduceOpKernel : public NcclAsyncOpBase {
+class NcclReduceOpBase : public NcclAsyncOpBase {
  public:
-  explicit NcclAllReduceOpKernel(OpKernelConstruction* c) : NcclAsyncOpBase(c) {
+  explicit NcclReduceOpBase(OpKernelConstruction* c) : NcclAsyncOpBase(c) {
     string reduction;
     OP_REQUIRES_OK(c, c->GetAttr("reduction", &reduction));
     if (reduction == "min") {
@@ -79,6 +78,19 @@ class NcclAllReduceOpKernel : public NcclAsyncOpBase {
     }
   }
 
+  ncclRedOp_t reduction_op() const { return reduction_op_; }
+
+ private:
+  ncclRedOp_t reduction_op_;
+};
+
+// To execute a single all-reduce, this kernel is called once for each of the
+// <k> devices in the communicator.
+class NcclAllReduceOpKernel : public NcclReduceOpBase {
+ public:
+  explicit NcclAllReduceOpKernel(OpKernelConstruction* c)
+      : NcclReduceOpBase(c) {}
+
   void ComputeAsync(OpKernelContext* c, DoneCallback done) override {
     const Tensor* in_t = &c->input(0);
     Tensor* out_t;
@@ -92,18 +104,81 @@ class NcclAllReduceOpKernel : public NcclAsyncOpBase {
     auto* compute_stream = c->op_device_context()->stream();
     auto* gpu_info = c->device()->tensorflow_gpu_device_info();
     NcclManager::instance()->AddToAllReduce(
-        num_devices(), GetCollectiveKey(c), reduction_op_,
+        num_devices(), GetCollectiveKey(c), reduction_op(),
         compute_stream->parent(), gpu_info->gpu_id, gpu_info->event_mgr,
-        compute_stream, in_t, out_t, actual_done);
+        compute_stream, in_t, out_t, std::move(actual_done));
   }
+};
+REGISTER_KERNEL_BUILDER(Name("NcclAllReduce").Device(DEVICE_GPU),
+                        NcclAllReduceOpKernel);
+
+// To execute a single reduce, this kernel is called once for all but one of the
+// <k> devices in the communicator, and NcclReduceRecvKernel is called once for
+// the remaining device.
+class NcclReduceSendKernel : public NcclReduceOpBase {
+ public:
+  explicit NcclReduceSendKernel(OpKernelConstruction* c)
+      : NcclReduceOpBase(c) {}
+
+  void ComputeAsync(OpKernelContext* c, DoneCallback done) override {
+    const Tensor& in_t = c->input(0);
+    std::unique_ptr<Tensor> temp_ptr(new Tensor());
+    OP_REQUIRES_OK_ASYNC(
+        c, c->allocate_temp(in_t.dtype(), in_t.shape(), temp_ptr.get()), done);
+    Tensor* temp_t = temp_ptr.release();
+
+    auto actual_done = [c, done, temp_t](Status s) {
+      delete temp_t;
+      OP_REQUIRES_OK_ASYNC(c, s, done);
+      done();
+    };
+
+    auto* compute_stream = c->op_device_context()->stream();
+    auto* gpu_info = c->device()->tensorflow_gpu_device_info();
+    NcclManager::instance()->AddReduceSend(
+        num_devices(), GetCollectiveKey(c), reduction_op(),
+        compute_stream->parent(), gpu_info->gpu_id, gpu_info->event_mgr,
+        compute_stream, &in_t, temp_t, std::move(actual_done));
+  }
+};
+REGISTER_KERNEL_BUILDER(Name("NcclReduceSend").Device(DEVICE_GPU),
+                        NcclReduceSendKernel);
+
+// To execute a single reduce, this kernel is called once for one device, and
+// NcclReduceSendKernel is called for all other <k-1> devices in the
+// communicator.
+class NcclReduceRecvKernel : public NcclReduceOpBase {
+ public:
+  explicit NcclReduceRecvKernel(OpKernelConstruction* c)
+      : NcclReduceOpBase(c) {}
+
+  void ComputeAsync(OpKernelContext* c, DoneCallback done) override {
+    const Tensor& in_t = c->input(0);
+    Tensor* out_t;
+    OP_REQUIRES_OK_ASYNC(c, c->allocate_output(0, in_t.shape(), &out_t), done);
+
+    auto actual_done = [c, done](Status s) {
+      OP_REQUIRES_OK_ASYNC(c, s, done);
+      done();
+    };
+
+    auto* compute_stream = c->op_device_context()->stream();
+    auto* gpu_info = c->device()->tensorflow_gpu_device_info();
+    NcclManager::instance()->AddReduceRecv(
+        num_devices(), GetCollectiveKey(c), reduction_op(),
+        compute_stream->parent(), gpu_info->gpu_id, gpu_info->event_mgr,
+        compute_stream, &in_t, out_t, std::move(actual_done));
+  }
 
  private:
   ncclRedOp_t reduction_op_;
 };
-REGISTER_KERNEL_BUILDER(Name("NcclAllReduce").Device(DEVICE_GPU),
-                        NcclAllReduceOpKernel);
+REGISTER_KERNEL_BUILDER(Name("NcclReduceRecv").Device(DEVICE_GPU),
+                        NcclReduceRecvKernel);
 
 // To execute a single broadcast, this kernel is called once for one device, and
 // NcclBroadcastRecvKernel is called for all other <k-1> devices in the
 // communicator.
 class NcclBroadcastSendKernel : public NcclAsyncOpBase {
  public:
  explicit NcclBroadcastSendKernel(OpKernelConstruction* c)
@@ -126,6 +201,9 @@ class NcclBroadcastSendKernel : public NcclAsyncOpBase {
 REGISTER_KERNEL_BUILDER(Name("NcclBroadcastSend").Device(DEVICE_GPU),
                         NcclBroadcastSendKernel);
 
+// To execute a single broadcast, this kernel is called once for all but one of
+// the <k> devices in the communicator, and NcclBroadcastSendKernel is called
+// once for the remaining device.
 class NcclBroadcastRecvKernel : public NcclAsyncOpBase {
  public:
  explicit NcclBroadcastRecvKernel(OpKernelConstruction* c)
tensorflow/contrib/nccl/ops/nccl_ops.cc

@@ -45,6 +45,51 @@ num_devices: The number of devices participating in this reduction.
 shared_name: Identifier that shared between ops of the same reduction.
 )doc");
 
+REGISTER_OP("NcclReduceSend")
+    .Input("input: T")
+    .Attr("reduction: {'min', 'max', 'prod', 'sum'}")
+    .Attr("T: {float, float64, int32, int64}")
+    .Attr("num_devices: int")
+    .Attr("shared_name: string")
+    .SetIsStateful()
+    .SetShapeFn(shape_inference::NoOutputs)
+    .Doc(R"doc(
+Reduces `input` to the NcclReduceRecv op registered in the same `shared_name`.
+
+The graph should be constructed so that 'num_devices-1' devices run
+`NcclReduceSend` and one device runs NcclReduceRecv op with shared_name value
+`c`. Failure to do so will cause the graph execution to fail to complete.
+
+input: The input to the reduction
+reduction: the reduction operation to perform.
+num_devices: The number of devices participating in this reduction.
+shared_name: Identifier that is shared between ops of the same reduce.
+)doc");
+
+REGISTER_OP("NcclReduceRecv")
+    .Input("input: T")
+    .Output("data: T")
+    .Attr("reduction: {'min', 'max', 'prod', 'sum'}")
+    .Attr("T: {float, float64, int32, int64}")
+    .Attr("num_devices: int")
+    .Attr("shared_name: string")
+    .SetIsStateful()
+    .SetShapeFn(shape_inference::UnchangedShape)
+    .Doc(R"doc(
+Reduces 'input' from this op and the NcclReduceSend ops registered in the same
+`shared_name`.
+
+The graph should be constructed so that 'num_devices-1' devices run
+`NcclReduceSend` and one device runs NcclReduceRecv op with shared_name value
+`c`. Failure to do so will cause the graph execution to fail to complete.
+
+input: The input to the reduction
+data: The reduced data received from this op and the NcclReduceSend op.
+reduction: the reduction operation to perform.
+num_devices: The number of devices participating in this reduction.
+shared_name: Identifier that is shared between ops of the same reduce.
+)doc");
+
 REGISTER_OP("NcclBroadcastSend")
     .Input("input: T")
     .Attr("T: {float, float64, int32, int64}")
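Example (not part of this commit, a sketch only): the op docs above define the
pairing contract: num_devices - 1 NcclReduceSend ops and exactly one
NcclReduceRecv op, all constructed with the same shared_name. A minimal
hand-built TF 1.x graph honoring that contract, assuming two visible GPUs:

    import tensorflow as tf
    from tensorflow.contrib.nccl.ops import gen_nccl_ops

    with tf.device('/device:GPU:1'):
      b = tf.constant([3.0, 4.0])
      # One of num_devices - 1 senders (here: the only one).
      send = gen_nccl_ops.nccl_reduce_send(
          input=b, reduction='sum', num_devices=2, shared_name='c0')
    with tf.device('/device:GPU:0'):
      a = tf.constant([1.0, 2.0])
      # The single receiver; produces a + b on GPU:0.
      total = gen_nccl_ops.nccl_reduce_recv(
          input=a, reduction='sum', num_devices=2, shared_name='c0')

    with tf.Session() as sess:
      # The send op must be fetched too, or the collective never completes.
      result = sess.run([total, send])[0]

The _apply_reduce helper added to nccl_ops.py below builds this same pattern;
the sketch only spells out the shared_name contract.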
tensorflow/contrib/nccl/python/ops/nccl_ops.py

@@ -21,6 +21,7 @@ import threading
 
 from tensorflow.contrib.nccl.ops import gen_nccl_ops
 from tensorflow.contrib.util import loader
+from tensorflow.python.eager import context
 from tensorflow.python.framework import device
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import ops
@@ -48,6 +49,35 @@ def all_sum(tensors):
   return _apply_all_reduce('sum', tensors)
 
 
+@ops.RegisterGradient('NcclAllReduce')
+def _all_sum_grad(op, grad):
+  """The gradients for `all_sum`.
+
+  Args:
+    op: The `all_sum` `Operation` that we are differentiating.
+    grad: Gradient with respect to the output of the `all_sum` op.
+
+  Returns:
+    The gradient with respect to the output of `all_sum`.
+
+  Raises:
+    LookupError: If `reduction` is not `sum`.
+  """
+  if op.get_attr('reduction') != 'sum':
+    raise LookupError('No gradient defined for NcclAllReduce except all_sum.')
+
+  _check_device_assignment(grad)
+  num_devices = op.get_attr('num_devices')
+  shared_name = op.get_attr('shared_name') + '_grad'
+
+  with ops.device(grad.device):
+    return gen_nccl_ops.nccl_all_reduce(
+        input=grad,
+        reduction='sum',
+        num_devices=num_devices,
+        shared_name=shared_name)
+
+
 def all_prod(tensors):
   """Returns a list of tensors with the all-reduce product across `tensors`.
 
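Example (not part of this commit, a sketch only): with the gradient registered
above, tf.gradients can differentiate through all_sum, and the backward pass is
itself a sum all-reduce over the incoming gradients (one NcclAllReduce per
device, sharing the original shared_name plus the '_grad' suffix). Assuming two
visible GPUs:

    import tensorflow as tf
    from tensorflow.contrib import nccl

    with tf.device('/device:GPU:0'):
      a = tf.constant([1.0, 2.0])
    with tf.device('/device:GPU:1'):
      b = tf.constant([3.0, 4.0])

    summed = nccl.all_sum([a, b])          # one NcclAllReduce per device
    loss = tf.add_n([tf.reduce_sum(t) for t in summed])
    grads = tf.gradients(loss, [a, b])     # backward pass is another all_sum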
@@ -99,6 +129,24 @@ def all_max(tensors):
   return _apply_all_reduce('max', tensors)
 
 
+def reduce_sum(tensors, dst_device):
+  """Returns a tensor with the reduce sum across `tensors`.
+
+  The computation is done with a reduce operation, so only one tensor is
+  returned.
+
+  Args:
+    tensors: The input tensors across which to sum; must be assigned
+      to GPU devices.
+    dst_device: The device of the returned tensor.
+
+  Returns:
+    A tensor containing the sum of the input tensors, with the device of the
+    tensor being `dst_device`.
+  """
+  return _apply_reduce('sum', tensors, dst_device)
+
+
 def broadcast(src_tensor, dst_devices):
   """Returns a list of tensors on `dst_devices`, each with value `tensor`.
 
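Example (not part of this commit, a sketch only): usage of the new public
reduce_sum. Note that the implementation (via _apply_reduce below) returns the
received tensor together with the send ops, as the tests in this commit rely
on, and the send ops must be fetched for the receive to complete. Assuming two
visible GPUs:

    import tensorflow as tf
    from tensorflow.contrib import nccl

    with tf.device('/device:GPU:0'):
      a = tf.constant([1.0, 2.0])
    with tf.device('/device:GPU:1'):
      b = tf.constant([3.0, 4.0])

    # dst_device must match the device of one of the input tensors.
    total, send_ops = nccl.reduce_sum([a, b], '/device:GPU:0')

    with tf.Session() as sess:
      print(sess.run([total] + send_ops)[0])  # -> [4.0, 6.0]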
@@ -111,50 +159,93 @@ def broadcast(src_tensor, dst_devices):
     dst_devices: The GPU devices to receive the sent tensor.
 
   Returns:
-    List of tensors, each with the value of `src_tensor`, which the device
-    of tensor i is `dst_devices[i]`.
+    An `Operation` to send the `src_tensor`, and a list of tensors, each with
+    the value of `src_tensor`, where the device of tensor i is `dst_devices[i]`.
   """
   if not dst_devices:
     raise ValueError('Must pass >0 dst_devices to broadcast')
-  all_devices = [src_tensor.device] + dst_devices
+  _check_graph_mode()
+  _check_device_assignment(src_tensor)
+
+  shape = array_ops.shape(src_tensor, out_type=dtypes.int64)
+  num_devices = len(dst_devices) + 1
   shared_name = _get_shared_name()
 
   with ops.device(src_tensor.device):
     send = gen_nccl_ops.nccl_broadcast_send(
-        input=src_tensor, num_devices=len(all_devices), shared_name=shared_name)
+        input=src_tensor, num_devices=num_devices, shared_name=shared_name)
 
-  shape_op = array_ops.shape(src_tensor, out_type=dtypes.int64)
   recvs = []
   for d in dst_devices:
     with ops.device(d):
       recvs.append(
           gen_nccl_ops.nccl_broadcast_recv(
-              shape=shape_op,
+              shape=shape,
               T=src_tensor.dtype,
-              num_devices=len(all_devices),
+              num_devices=num_devices,
               shared_name=shared_name))
 
   return send, recvs
 
 
-def _apply_all_reduce(reduction_op, tensors):
+def _apply_all_reduce(reduction, tensors):
   """Helper function for all_* functions."""
   if not tensors:
     raise ValueError('Must pass >0 tensors to all reduce operations')
+  _check_graph_mode()
 
   shared_name = _get_shared_name()
   res = []
+
   for t in tensors:
-    if not device.canonical_name(t.device):
-      raise ValueError('Device assignment required for nccl collective ops')
+    _check_device_assignment(t)
     with ops.device(t.device):
       res.append(
           gen_nccl_ops.nccl_all_reduce(
-              t,
-              reduction=reduction_op,
+              input=t,
+              reduction=reduction,
               num_devices=len(tensors),
               shared_name=shared_name))
 
   return res
 
 
+def _apply_reduce(reduction, tensors, dst_device):
+  """Helper function for reduce_* functions."""
+  if not tensors:
+    raise ValueError('Must pass >0 tensors to reduce operations')
+  if not dst_device:
+    raise ValueError('Must pass dst_device to reduce operations')
+  _check_graph_mode()
+
+  try:
+    recv_index = next(i for i, t in enumerate(tensors)
+                      if t.device == dst_device)
+  except StopIteration:
+    raise ValueError('One of the tensors must be assigned to dst_device')
+  shared_name = _get_shared_name()
+
+  sends = []
+  for t in tensors[:recv_index] + tensors[recv_index + 1:]:
+    _check_device_assignment(t)
+    with ops.device(t.device):
+      sends.append(
+          gen_nccl_ops.nccl_reduce_send(
+              input=t,
+              reduction=reduction,
+              num_devices=len(tensors),
+              shared_name=shared_name))
+
+  with ops.device(dst_device):
+    recv = gen_nccl_ops.nccl_reduce_recv(
+        input=tensors[recv_index],
+        reduction=reduction,
+        num_devices=len(tensors),
+        shared_name=shared_name)
+
+  return recv, sends
 
 
 _lock = threading.Lock()
 _shared_name_counter = 0
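Example (not part of this commit, a sketch only): the updated broadcast
contract. The send Operation is now returned alongside the received tensors
and must be run in the same session.run call as the receives, or the
collective hangs. Assuming two visible GPUs:

    import tensorflow as tf
    from tensorflow.contrib import nccl

    with tf.device('/device:GPU:0'):
      src = tf.constant([1.0, 2.0])

    send_op, recvs = nccl.broadcast(src, ['/device:GPU:1'])

    with tf.Session() as sess:
      outputs = sess.run(recvs + [send_op])[:len(recvs)]
      # each element of outputs == [1.0, 2.0]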
@@ -166,3 +257,13 @@ def _get_shared_name():
     val = _shared_name_counter
     _shared_name_counter += 1
   return 'c%s' % val
+
+
+def _check_device_assignment(tensor):
+  if not device.canonical_name(tensor.device):
+    raise ValueError('Device assignment required for nccl collective ops')
+
+
+def _check_graph_mode():
+  if context.in_eager_mode():
+    raise ValueError('Nccl ops are not supported in eager mode')
tensorflow/contrib/nccl/python/ops/nccl_ops_test.py

@@ -18,6 +18,7 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
+from functools import partial
 import numpy as np
 
 from tensorflow.contrib import nccl
@@ -26,9 +27,45 @@ from tensorflow.python.ops import array_ops
 from tensorflow.python.platform import test
 
 
-class AllReduceTest(test.TestCase):
+def _DeviceTensors(tensors, devices):
+  res = []
+  for t, d in zip(tensors, devices):
+    with ops.device(d):
+      res.append(array_ops.identity(t))
+  return res
 
-  def testAllReduce(self):
+
+def _NcclAllReduce(nccl_fun, tensors, devices):
+  return nccl_fun(_DeviceTensors(tensors, devices)), []
+
+
+def _NcclReduce(nccl_fun, tensors, devices):
+  d_tensors = _DeviceTensors(tensors, devices)
+  receiver = np.random.randint(0, len(devices))
+  received_tensor, send_ops = nccl_fun(d_tensors, devices[receiver])
+  return [received_tensor], send_ops
+
+
+def _NcclBroadcast(tensors, devices):
+  sender = np.random.randint(0, len(devices))
+  d_tensor = _DeviceTensors(tensors[0:1], devices[sender:sender + 1])[0]
+  other_devices = devices[:sender] + devices[sender + 1:]
+  send_op, received_tensors = nccl.broadcast(d_tensor, other_devices)
+  return received_tensors, [send_op]
+
+
+class NcclTestCase(test.TestCase):
+
+  def _Test(self, nccl_reduce, numpy_fn):
+    """Tests that nccl_reduce does the same as reduction with numpy_fn.
+
+    Args:
+      nccl_reduce: A function taking a list of tensors and a list of devices,
+        and returns a list of reduced tensors and a list of ops to perform the
+        reduction.
+      numpy_fn: A function taking two tensors and returning the reduction of the
+        two.
+    """
     if not test.is_gpu_available():
       return  # Test requires access to a GPU
 
@@ -36,37 +73,62 @@ class AllReduceTest(test.TestCase):
       # Create session inside outer loop to test use of
       # same communicator across multiple sessions.
       with self.test_session(use_gpu=True) as sess:
-        self._testSingleAllReduce(sess, dtype, nccl.all_sum, lambda x, y: x + y)
-        self._testSingleAllReduce(sess, dtype, nccl.all_prod,
-                                  lambda x, y: x * y)
-        self._testSingleAllReduce(sess, dtype, nccl.all_min, np.minimum)
-        self._testSingleAllReduce(sess, dtype, nccl.all_max, np.maximum)
-
-  def _testSingleAllReduce(self, sess, np_type, nccl_fn, numpy_accumulation_fn):
-    for devices in [['/device:GPU:1', '/device:GPU:2', '/device:GPU:0'],
-                    ['/device:GPU:1', '/device:GPU:0']]:
-      shape = (3, 4)
-      np_ans = None
-      tensors = []
-      for d in devices:
-        with ops.device(d):
-          t = ((np.random.random_sample(shape) - .5) * 1024).astype(np_type)
-          if np_ans is None:
-            np_ans = t
-          else:
-            np_ans = numpy_accumulation_fn(np_ans, t)
-          tensors.append(array_ops.identity(t))
+        for devices in [['/device:GPU:1', '/device:GPU:2', '/device:GPU:0'],
+                        ['/device:GPU:1', '/device:GPU:0']]:
+          shape = (3, 4)
+          random = (np.random.random_sample(shape) - .5) * 1024
+          tensors = [random.astype(dtype)] * len(devices)
+          np_ans = tensors[0]
+          for t in tensors[1:]:
+            np_ans = numpy_fn(np_ans, t)
 
-      all_reduce_tensors = nccl_fn(tensors)
+          reduce_tensors, reduce_ops = nccl_reduce(tensors, devices)
+          self.assertNotEmpty(reduce_tensors)
 
-      # Test shape inference.
-      for r in all_reduce_tensors:
-        self.assertEqual(shape, r.get_shape())
+          # Test shape inference.
+          for r in reduce_tensors:
+            self.assertEqual(shape, r.get_shape())
 
-      # Test execution and results.
-      nccl_results = sess.run(all_reduce_tensors)
-      for r in nccl_results:
-        self.assertAllClose(r, np_ans)
+          # Test execution and results.
+          nccl_results = sess.run(reduce_tensors + reduce_ops)
+          for r in nccl_results[:len(reduce_tensors)]:
+            self.assertAllClose(r, np_ans)
+
+  def _TestGradient(self, nccl_reduce, numpy_fn):
+    """Tests the gradient of nccl_reduce.
+
+    Args:
+      nccl_reduce: A function taking a list of tensors and a list of devices,
+        and returns a list of reduced tensors and a list of ops to perform the
+        reduction.
+      numpy_fn: A function taking two tensors and returning the gradient of the
+        reduction of the two.
+    """
+
+    def _Gradient(tensors, devices):
+      reduce_tensors, _ = nccl_reduce(tensors, devices)
+      tensor_ops = [t.op for t in reduce_tensors]
+      d_tensors = _DeviceTensors(tensors, devices)
+      grad_tensors = [
+          ops.get_gradient_function(op)(op, loss)
+          for op, loss in zip(tensor_ops, d_tensors)
+      ]
+      return grad_tensors, []
+
+    self._Test(_Gradient, numpy_fn)
+
+
+class AllReduceTest(NcclTestCase):
+
+  def testAllReduce(self):
+    self._Test(partial(_NcclAllReduce, nccl.all_sum), lambda x, y: x + y)
+    self._Test(partial(_NcclAllReduce, nccl.all_prod), lambda x, y: x * y)
+    self._Test(partial(_NcclAllReduce, nccl.all_min), np.minimum)
+    self._Test(partial(_NcclAllReduce, nccl.all_max), np.maximum)
+
+  def testAllSumGrad(self):
+    self._TestGradient(
+        partial(_NcclAllReduce, nccl.all_sum), lambda x, y: x + y)
 
   def testErrors(self):
     with self.assertRaisesRegexp(ValueError, 'Device assignment required'):
@@ -75,79 +137,32 @@ class AllReduceTest(test.TestCase):
       nccl.all_sum([])
 
 
-class BroadcastTest(test.TestCase):
+class SingleReduceTest(NcclTestCase):
+
+  def testSum(self):
+    self._Test(partial(_NcclReduce, nccl.reduce_sum), lambda x, y: x + y)
+
+
+class BroadcastTest(NcclTestCase):
 
   def testBroadcast(self):
-    if not test.is_gpu_available():
-      return  # Test requires access to a GPU
-
-    for dtype in [np.float32, np.int32, np.int64, np.float64]:
-      # Create session inside outer loop to test use of
-      # same communicator across multiple sessions.
-      with self.test_session(use_gpu=True) as sess:
-        for devices in [['/device:GPU:1', '/device:GPU:0', '/device:GPU:2'],
-                        ['/device:GPU:1', '/device:GPU:0']]:
-          shape = (3, 4)
-          sender = np.random.randint(0, len(devices) - 1)
-          with ops.device(devices[sender]):
-            np_ans = ((
-                (np.random.random_sample(shape) - .5) * 1024).astype(dtype))
-            t = array_ops.identity(np_ans)
-          other_devices = devices[:sender] + devices[sender + 1:]
-          send_op, received_tensors = nccl.broadcast(t, other_devices)
-
-          # Verify shape inference.
-          for r in received_tensors:
-            self.assertEqual(shape, r.get_shape())
-
-          # Run and verify results.
-          nccl_results = sess.run(received_tensors + [send_op])
-          for r in nccl_results[:-1]:
-            self.assertAllClose(r, np_ans)
+    self._Test(_NcclBroadcast, lambda x, y: x)
 
 
-class CombinedTest(test.TestCase):
-  """Tests using a mix of all-reduce ops in one session.run call."""
+class CombinedTest(NcclTestCase):
+  """Test all-reduce vs. single-reduce plus broadcast in one session.run."""
+
+  def _combined(self, tensors, devices):
+    all_reduce_tensors = _NcclAllReduce(nccl.all_sum, tensors, devices)[0]
+    single_reduce_tensors, single_reduce_ops = _NcclReduce(
+        nccl.reduce_sum, tensors, devices)
+    broadcast_tensors, broadcast_ops = _NcclBroadcast(single_reduce_tensors,
+                                                      devices)
+    all_tensors = all_reduce_tensors + single_reduce_tensors + broadcast_tensors
+    return all_tensors, single_reduce_ops + broadcast_ops
 
   def testCombined(self):
-    if not test.is_gpu_available():
-      return  # Test requires access to a GPU
-
-    for dtype in [np.float32, np.int32, np.int64, np.float64]:
-      # Create session inside outer loop to test use of
-      # same communicator across multiple sessions.
-      with self.test_session(use_gpu=True) as sess:
-        for devices in [['/device:GPU:1', '/device:GPU:2', '/device:GPU:0'],
-                        ['/device:GPU:0', '/device:GPU:1']]:
-          shape = (3, 4)
-
-          # all-reduce
-          np_ans = np.zeros(shape=shape, dtype=dtype)
-          tensors = []
-          for d in devices:
-            with ops.device(d):
-              t = ((np.random.random_sample(shape) - .5) * 1024).astype(dtype)
-              np_ans += t
-              tensors.append(array_ops.identity(t))
-          all_reduce_tensors = nccl.all_sum(tensors)
-
-          sender = np.random.randint(0, len(devices) - 1)
-          other_devices = devices[:sender] + devices[sender + 1:]
-          send_op, received_tensors = nccl.broadcast(all_reduce_tensors[sender],
-                                                     other_devices)
-
-          # sender doesn't need to be fetched as part of outputs of session.run.
-          del all_reduce_tensors[sender]
-
-          # Verify shape inference.
-          for r in received_tensors:
-            self.assertEqual(shape, r.get_shape())
-
-          # Run and verify results.
-          nccl_results = sess.run(
-              received_tensors + [send_op] + all_reduce_tensors)
-          for r in nccl_results[:len(received_tensors)]:
-            self.assertAllClose(r, np_ans)
+    self._Test(self._combined, lambda x, y: x + y)
 
 
 if __name__ == '__main__':