Add NCCL reduce_sum op, register all_sum gradient.

Streamlining nccl test.

PiperOrigin-RevId: 168347428
A. Unique TensorFlower 2017-09-12 02:00:55 -07:00 committed by TensorFlower Gardener
parent bc300318e7
commit 9b9e54b344
7 changed files with 415 additions and 119 deletions

File: tensorflow/contrib/nccl/__init__.py

@@ -18,6 +18,7 @@
@@all_min
@@all_prod
@@all_sum
@@reduce_sum
@@broadcast
"""
@@ -31,6 +32,7 @@ from tensorflow.contrib.nccl.python.ops.nccl_ops import all_min
from tensorflow.contrib.nccl.python.ops.nccl_ops import all_prod
from tensorflow.contrib.nccl.python.ops.nccl_ops import all_sum
from tensorflow.contrib.nccl.python.ops.nccl_ops import broadcast
from tensorflow.contrib.nccl.python.ops.nccl_ops import reduce_sum
from tensorflow.python.util.all_util import remove_undocumented
remove_undocumented(__name__)

File: tensorflow/contrib/nccl/kernels/nccl_manager.cc

@@ -260,7 +260,7 @@ NcclManager::Communicator* NcclManager::GetCommunicator(
std::vector<ncclComm_t> nccl_comms(num_devices);
auto result = ncclCommInitAll(nccl_comms.data(), num_devices, devices.data());
-CHECK_EQ(result, ncclSuccess);
CHECK_EQ(result, ncclSuccess) << ncclGetErrorString(result);
for (int rank = 0; rank < num_devices; ++rank) {
members[rank].nccl_comm = nccl_comms[rank];
}
@@ -307,6 +307,35 @@ void NcclManager::AddBroadcastRecv(
kBroadcast, ncclSum /* unused */);
}
void NcclManager::AddReduceSend(int num_devices, const string& key,
ncclRedOp_t reduction_op,
perftools::gputools::StreamExecutor* executor,
int gpu_device_id, EventMgr* event_mgr,
perftools::gputools::Stream* tensor_stream,
const Tensor* in_t, Tensor* temp_t,
DoneCallback done_callback) {
std::unique_ptr<Participant> participant(
new Participant(in_t, temp_t, event_mgr, tensor_stream, executor,
gpu_device_id, std::move(done_callback)));
AddParticipant(num_devices, key, std::move(participant), in_t->dtype(),
kReduce, reduction_op);
}
void NcclManager::AddReduceRecv(int num_devices, const string& key,
ncclRedOp_t reduction_op,
perftools::gputools::StreamExecutor* executor,
int gpu_device_id, EventMgr* event_mgr,
perftools::gputools::Stream* tensor_stream,
const Tensor* in_t, Tensor* out_t,
DoneCallback done_callback) {
std::unique_ptr<Participant> participant(
new Participant(in_t, out_t, event_mgr, tensor_stream, executor,
gpu_device_id, std::move(done_callback)));
participant->root = true;
AddParticipant(num_devices, key, std::move(participant), in_t->dtype(),
kReduce, reduction_op);
}
void NcclManager::AddParticipant(int num_devices, const string& key,
std::unique_ptr<Participant> participant,
DataType data_type,
@@ -431,6 +460,14 @@ void NcclManager::LoopKernelLaunches(NcclStream* nccl_stream) {
collective->root_rank, nccl_comm, *cu_stream);
break;
}
case kReduce: {
const void* sendbuff = p->in_t->tensor_data().data();
void* recvbuff = const_cast<char*>(p->out_t->tensor_data().data());
nccl_result = ncclReduce(sendbuff, recvbuff, p->in_t->NumElements(),
data_type, collective->reduction_op,
collective->root_rank, nccl_comm, *cu_stream);
break;
}
}
// Run the done_callback when the nccl kernel finishes running.
@@ -441,7 +478,7 @@ void NcclManager::LoopKernelLaunches(NcclStream* nccl_stream) {
// Propagate the error, but note that if other members of the collective
// did launch their kernels, then they are hanging.
collective->participants[rank]->done_callback(errors::Unknown(
-"Error invoking AllReduce: ", ncclGetErrorString(nccl_result)));
"Error invoking NCCL: ", ncclGetErrorString(nccl_result)));
}
// TODO(cwhipkey): use RefCounted after figuring out how to use in a

File: tensorflow/contrib/nccl/kernels/nccl_manager.h

@@ -75,10 +75,28 @@ class NcclManager {
perftools::gputools::Stream* tensor_stream,
Tensor* out_t, DoneCallback done_callback);
// AddReduceSend and AddReduceRecv combine to send data from all senders
// to one receiver.
void AddReduceSend(int num_devices, const string& key,
ncclRedOp_t reduction_op,
perftools::gputools::StreamExecutor* executor,
int gpu_device_id, EventMgr* event_mgr,
perftools::gputools::Stream* tensor_stream,
const Tensor* in_t, Tensor* temp_t,
DoneCallback done_callback);
void AddReduceRecv(int num_devices, const string& key,
ncclRedOp_t reduction_op,
perftools::gputools::StreamExecutor* executor,
int gpu_device_id, EventMgr* event_mgr,
perftools::gputools::Stream* tensor_stream,
const Tensor* in_t, Tensor* out_t,
DoneCallback done_callback);
private:
enum CollectiveType {
kAllReduce = 1,
kBroadcast = 2,
kReduce = 3,
};
struct Collective;
struct Communicator;

File: tensorflow/contrib/nccl/kernels/nccl_ops.cc

@@ -15,6 +15,7 @@ limitations under the License.
#if GOOGLE_CUDA
#include <memory>
#include <unordered_map>
#include <vector>
@@ -58,11 +59,9 @@ class NcclAsyncOpBase : public AsyncOpKernel {
TF_DISALLOW_COPY_AND_ASSIGN(NcclAsyncOpBase);
};
-// To execute a single all-reduce, this kernel is called once for each of the
-// <k> devices in the communicator.
-class NcclAllReduceOpKernel : public NcclAsyncOpBase {
class NcclReduceOpBase : public NcclAsyncOpBase {
public:
-explicit NcclAllReduceOpKernel(OpKernelConstruction* c) : NcclAsyncOpBase(c) {
explicit NcclReduceOpBase(OpKernelConstruction* c) : NcclAsyncOpBase(c) {
string reduction;
OP_REQUIRES_OK(c, c->GetAttr("reduction", &reduction));
if (reduction == "min") {
@@ -79,6 +78,19 @@ class NcclAllReduceOpKernel : public NcclAsyncOpBase {
}
}
ncclRedOp_t reduction_op() const { return reduction_op_; }
private:
ncclRedOp_t reduction_op_;
};
// To execute a single all-reduce, this kernel is called once for each of the
// <k> devices in the communicator.
class NcclAllReduceOpKernel : public NcclReduceOpBase {
public:
explicit NcclAllReduceOpKernel(OpKernelConstruction* c)
: NcclReduceOpBase(c) {}
void ComputeAsync(OpKernelContext* c, DoneCallback done) override {
const Tensor* in_t = &c->input(0);
Tensor* out_t;
@@ -92,18 +104,81 @@ class NcclAllReduceOpKernel : public NcclAsyncOpBase {
auto* compute_stream = c->op_device_context()->stream();
auto* gpu_info = c->device()->tensorflow_gpu_device_info();
NcclManager::instance()->AddToAllReduce(
-num_devices(), GetCollectiveKey(c), reduction_op_,
num_devices(), GetCollectiveKey(c), reduction_op(),
compute_stream->parent(), gpu_info->gpu_id, gpu_info->event_mgr,
-compute_stream, in_t, out_t, actual_done);
compute_stream, in_t, out_t, std::move(actual_done));
}
};
REGISTER_KERNEL_BUILDER(Name("NcclAllReduce").Device(DEVICE_GPU),
NcclAllReduceOpKernel);
// To execute a single reduce, this kernel is called once for all but one of the
// <k> devices in the communicator, and NcclReduceRecvKernel is called once for
// the remaining device.
class NcclReduceSendKernel : public NcclReduceOpBase {
public:
explicit NcclReduceSendKernel(OpKernelConstruction* c)
: NcclReduceOpBase(c) {}
void ComputeAsync(OpKernelContext* c, DoneCallback done) override {
const Tensor& in_t = c->input(0);
std::unique_ptr<Tensor> temp_ptr(new Tensor());
OP_REQUIRES_OK_ASYNC(
c, c->allocate_temp(in_t.dtype(), in_t.shape(), temp_ptr.get()), done);
Tensor* temp_t = temp_ptr.release();
auto actual_done = [c, done, temp_t](Status s) {
delete temp_t;
OP_REQUIRES_OK_ASYNC(c, s, done);
done();
};
auto* compute_stream = c->op_device_context()->stream();
auto* gpu_info = c->device()->tensorflow_gpu_device_info();
NcclManager::instance()->AddReduceSend(
num_devices(), GetCollectiveKey(c), reduction_op(),
compute_stream->parent(), gpu_info->gpu_id, gpu_info->event_mgr,
compute_stream, &in_t, temp_t, std::move(actual_done));
}
};
REGISTER_KERNEL_BUILDER(Name("NcclReduceSend").Device(DEVICE_GPU),
NcclReduceSendKernel);
// To execute a single reduce, this kernel is called once for one device, and
// NcclReduceSendKernel is called for all other <k-1> devices in the
// communicator.
class NcclReduceRecvKernel : public NcclReduceOpBase {
public:
explicit NcclReduceRecvKernel(OpKernelConstruction* c)
: NcclReduceOpBase(c) {}
void ComputeAsync(OpKernelContext* c, DoneCallback done) override {
const Tensor& in_t = c->input(0);
Tensor* out_t;
OP_REQUIRES_OK_ASYNC(c, c->allocate_output(0, in_t.shape(), &out_t), done);
auto actual_done = [c, done](Status s) {
OP_REQUIRES_OK_ASYNC(c, s, done);
done();
};
auto* compute_stream = c->op_device_context()->stream();
auto* gpu_info = c->device()->tensorflow_gpu_device_info();
NcclManager::instance()->AddReduceRecv(
num_devices(), GetCollectiveKey(c), reduction_op(),
compute_stream->parent(), gpu_info->gpu_id, gpu_info->event_mgr,
compute_stream, &in_t, out_t, std::move(actual_done));
}
private:
ncclRedOp_t reduction_op_;
};
REGISTER_KERNEL_BUILDER(Name("NcclReduceRecv").Device(DEVICE_GPU),
NcclReduceRecvKernel);
-REGISTER_KERNEL_BUILDER(Name("NcclAllReduce").Device(DEVICE_GPU),
-NcclAllReduceOpKernel);
// To execute a single broadcast, this kernel is called once for one device, and
// NcclBroadcastRecvKernel is called for all other <k-1> devices in the
// communicator.
class NcclBroadcastSendKernel : public NcclAsyncOpBase {
public:
explicit NcclBroadcastSendKernel(OpKernelConstruction* c)
@@ -126,6 +201,9 @@ class NcclBroadcastSendKernel : public NcclAsyncOpBase {
REGISTER_KERNEL_BUILDER(Name("NcclBroadcastSend").Device(DEVICE_GPU),
NcclBroadcastSendKernel);
// To execute a single broadcast, this kernel is called once for all but one of
// the <k> devices in the communicator, and NcclBroadcastSendKernel is called
// once for the remaining device.
class NcclBroadcastRecvKernel : public NcclAsyncOpBase {
public:
explicit NcclBroadcastRecvKernel(OpKernelConstruction* c)

File: tensorflow/contrib/nccl/ops/nccl_ops.cc

@@ -45,6 +45,51 @@ num_devices: The number of devices participating in this reduction.
shared_name: Identifier that shared between ops of the same reduction.
)doc");
REGISTER_OP("NcclReduceSend")
.Input("input: T")
.Attr("reduction: {'min', 'max', 'prod', 'sum'}")
.Attr("T: {float, float64, int32, int64}")
.Attr("num_devices: int")
.Attr("shared_name: string")
.SetIsStateful()
.SetShapeFn(shape_inference::NoOutputs)
.Doc(R"doc(
Reduces `input` to the NcclReduceRecv op registered in the same `shared_name`.
The graph should be constructed so that 'num_devices-1' devices run
`NcclReduceSend` and one device runs NcclReduceRecv op with shared_name value
`c`. Failure to do so will cause the graph execution to fail to complete.
input: The input to the reduction
reduction: the reduction operation to perform.
num_devices: The number of devices participating in this reduction.
shared_name: Identifier that is shared between ops of the same reduce.
)doc");
REGISTER_OP("NcclReduceRecv")
.Input("input: T")
.Output("data: T")
.Attr("reduction: {'min', 'max', 'prod', 'sum'}")
.Attr("T: {float, float64, int32, int64}")
.Attr("num_devices: int")
.Attr("shared_name: string")
.SetIsStateful()
.SetShapeFn(shape_inference::UnchangedShape)
.Doc(R"doc(
Reduces 'input' from this op and the NcclReduceSend ops registered in the same
`shared_name`.
The graph should be constructed so that 'num_devices-1' devices run
`NcclReduceSend` and one device runs NcclReduceRecv op with shared_name value
`c`. Failure to do so will cause the graph execution to fail to complete.
input: The input to the reduction
data: The reduced data received from this op and the NcclReduceSend op.
reduction: the reduction operation to perform.
num_devices: The number of devices participating in this reduction.
shared_name: Identifier that is shared between ops of the same reduce.
)doc");
REGISTER_OP("NcclBroadcastSend") REGISTER_OP("NcclBroadcastSend")
.Input("input: T") .Input("input: T")
.Attr("T: {float, float64, int32, int64}") .Attr("T: {float, float64, int32, int64}")

File: tensorflow/contrib/nccl/python/ops/nccl_ops.py

@@ -21,6 +21,7 @@ import threading
from tensorflow.contrib.nccl.ops import gen_nccl_ops
from tensorflow.contrib.util import loader
from tensorflow.python.eager import context
from tensorflow.python.framework import device
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
@@ -48,6 +49,35 @@ def all_sum(tensors):
return _apply_all_reduce('sum', tensors)
@ops.RegisterGradient('NcclAllReduce')
def _all_sum_grad(op, grad):
"""The gradients for `all_sum`.
Args:
op: The `all_sum` `Operation` that we are differentiating.
grad: Gradient with respect to the output of the `all_sum` op.
Returns:
The gradient with respect to the output of `all_sum`.
Raises:
LookupError: If `reduction` is not `sum`.
"""
if op.get_attr('reduction') != 'sum':
raise LookupError('No gradient defined for NcclAllReduce except all_sum.')
_check_device_assignment(grad)
num_devices = op.get_attr('num_devices')
shared_name = op.get_attr('shared_name') + '_grad'
with ops.device(grad.device):
return gen_nccl_ops.nccl_all_reduce(
input=grad,
reduction='sum',
num_devices=num_devices,
shared_name=shared_name)
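A short sketch of how this registration is exercised (illustrative tensors; TF 1.x graph mode with two GPUs assumed): tf.gradients looks up the function above for each NcclAllReduce op and emits a second all-sum over the incoming gradients.

import tensorflow as tf
from tensorflow.contrib import nccl

with tf.device('/device:GPU:0'):
  x0 = tf.random_normal([4])
with tf.device('/device:GPU:1'):
  x1 = tf.random_normal([4])

y0, y1 = nccl.all_sum([x0, x1])
# The registered gradient turns what used to be a missing-gradient error into
# a second NcclAllReduce(sum) over dy0/dy1, one op per participating device.
dx0, dx1 = tf.gradients([y0, y1], [x0, x1])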
def all_prod(tensors):
"""Returns a list of tensors with the all-reduce product across `tensors`.
@@ -99,6 +129,24 @@ def all_max(tensors):
return _apply_all_reduce('max', tensors)
def reduce_sum(tensors, dst_device):
"""Returns a tensor with the reduce sum across `tensors`.
The computation is done with a reduce operation, so only one tensor is
returned.
Args:
tensors: The input tensors across which to sum; must be assigned
to GPU devices.
dst_device: The device of the returned tensor.
Returns:
A tensor containing the sum of the input tensors, with the device of the
tensor being `dst_device`.
"""
return _apply_reduce('sum', tensors, dst_device)
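A minimal usage sketch for the new helper (illustrative devices and shapes; TF 1.x graph mode with two GPUs assumed). As implemented here via _apply_reduce, the helper returns both the reduced tensor and the list of send ops, and all participants should be run in the same session.run call:

import tensorflow as tf
from tensorflow.contrib import nccl

with tf.device('/device:GPU:0'):
  a = tf.random_normal([3, 4])
with tf.device('/device:GPU:1'):
  b = tf.random_normal([3, 4])

# Sum both tensors onto GPU:0 only (unlike all_sum, which leaves a copy of
# the result on every participating device).
summed, send_ops = nccl.reduce_sum([a, b], '/device:GPU:0')

with tf.Session() as sess:
  # Fetch the send ops together with the result so every NCCL participant runs.
  result = sess.run([summed] + send_ops)[0]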
def broadcast(src_tensor, dst_devices):
"""Returns a list of tensors on `dst_devices`, each with value `tensor`.
@@ -111,50 +159,93 @@ def broadcast(src_tensor, dst_devices):
dst_devices: The GPU devices to receive the sent tensor.
Returns:
-List of tensors, each with the value of `src_tensor`, which the device
-of tensor i is `dst_devices[i]`.
An `Operation` to send the `src_tensor`, and a list of tensors, each with
the value of `src_tensor`, where the device of tensor i is `dst_devices[i]`.
"""
if not dst_devices:
raise ValueError('Must pass >0 dst_devices to broadcast')
-all_devices = [src_tensor.device] + dst_devices
_check_graph_mode()
_check_device_assignment(src_tensor)
shape = array_ops.shape(src_tensor, out_type=dtypes.int64)
num_devices = len(dst_devices) + 1
shared_name = _get_shared_name()
with ops.device(src_tensor.device):
send = gen_nccl_ops.nccl_broadcast_send(
-input=src_tensor, num_devices=len(all_devices), shared_name=shared_name)
input=src_tensor, num_devices=num_devices, shared_name=shared_name)
-shape_op = array_ops.shape(src_tensor, out_type=dtypes.int64)
recvs = []
for d in dst_devices:
with ops.device(d):
recvs.append(
gen_nccl_ops.nccl_broadcast_recv(
-shape=shape_op,
shape=shape,
T=src_tensor.dtype,
-num_devices=len(all_devices),
num_devices=num_devices,
shared_name=shared_name))
return send, recvs
-def _apply_all_reduce(reduction_op, tensors):
def _apply_all_reduce(reduction, tensors):
"""Helper function for all_* functions."""
if not tensors:
raise ValueError('Must pass >0 tensors to all reduce operations')
_check_graph_mode()
shared_name = _get_shared_name()
res = []
for t in tensors:
-if not device.canonical_name(t.device):
-raise ValueError('Device assignment required for nccl collective ops')
_check_device_assignment(t)
with ops.device(t.device):
res.append(
gen_nccl_ops.nccl_all_reduce(
-t,
input=t,
-reduction=reduction_op,
reduction=reduction,
num_devices=len(tensors),
shared_name=shared_name))
return res
def _apply_reduce(reduction, tensors, dst_device):
"""Helper function for reduce_* functions."""
if not tensors:
raise ValueError('Must pass >0 tensors to reduce operations')
if not dst_device:
raise ValueError('Must pass dst_device to reduce operations')
_check_graph_mode()
try:
recv_index = next(i for i, t in enumerate(tensors)
if t.device == dst_device)
except StopIteration:
raise ValueError('One of the tensors must be assigned to dst_device')
shared_name = _get_shared_name()
sends = []
for t in tensors[:recv_index] + tensors[recv_index + 1:]:
_check_device_assignment(t)
with ops.device(t.device):
sends.append(
gen_nccl_ops.nccl_reduce_send(
input=t,
reduction=reduction,
num_devices=len(tensors),
shared_name=shared_name))
with ops.device(dst_device):
recv = gen_nccl_ops.nccl_reduce_recv(
input=tensors[recv_index],
reduction=reduction,
num_devices=len(tensors),
shared_name=shared_name)
return recv, sends
_lock = threading.Lock()
_shared_name_counter = 0
@@ -166,3 +257,13 @@ def _get_shared_name():
val = _shared_name_counter
_shared_name_counter += 1
return 'c%s' % val
def _check_device_assignment(tensor):
if not device.canonical_name(tensor.device):
raise ValueError('Device assignment required for nccl collective ops')
def _check_graph_mode():
if context.in_eager_mode():
raise ValueError('Nccl ops are not supported in eager mode')

File: tensorflow/contrib/nccl/python/ops/nccl_ops_test.py

@@ -18,6 +18,7 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from functools import partial
import numpy as np
from tensorflow.contrib import nccl
@@ -26,9 +27,45 @@ from tensorflow.python.ops import array_ops
from tensorflow.python.platform import test
-class AllReduceTest(test.TestCase):
-def testAllReduce(self):
def _DeviceTensors(tensors, devices):
res = []
for t, d in zip(tensors, devices):
with ops.device(d):
res.append(array_ops.identity(t))
return res
def _NcclAllReduce(nccl_fun, tensors, devices):
return nccl_fun(_DeviceTensors(tensors, devices)), []
def _NcclReduce(nccl_fun, tensors, devices):
d_tensors = _DeviceTensors(tensors, devices)
receiver = np.random.randint(0, len(devices))
received_tensor, send_ops = nccl_fun(d_tensors, devices[receiver])
return [received_tensor], send_ops
def _NcclBroadcast(tensors, devices):
sender = np.random.randint(0, len(devices))
d_tensor = _DeviceTensors(tensors[0:1], devices[sender:sender + 1])[0]
other_devices = devices[:sender] + devices[sender + 1:]
send_op, received_tensors = nccl.broadcast(d_tensor, other_devices)
return received_tensors, [send_op]
class NcclTestCase(test.TestCase):
def _Test(self, nccl_reduce, numpy_fn):
"""Tests that nccl_reduce does the same as reduction with numpy_fn.
Args:
nccl_reduce: A function taking a list of tensors and a list of devices,
and returns a list of reduced tensors and a list of ops to perform the
reduction.
numpy_fn: A function taking two tensors and returning the reduction of the
two.
"""
if not test.is_gpu_available():
return # Test requires access to a GPU
@@ -36,37 +73,62 @@ class AllReduceTest(test.TestCase):
# Create session inside outer loop to test use of
# same communicator across multiple sessions.
with self.test_session(use_gpu=True) as sess:
-self._testSingleAllReduce(sess, dtype, nccl.all_sum, lambda x, y: x + y)
-self._testSingleAllReduce(sess, dtype, nccl.all_prod,
-lambda x, y: x * y)
-self._testSingleAllReduce(sess, dtype, nccl.all_min, np.minimum)
-self._testSingleAllReduce(sess, dtype, nccl.all_max, np.maximum)
-def _testSingleAllReduce(self, sess, np_type, nccl_fn, numpy_accumulation_fn):
-for devices in [['/device:GPU:1', '/device:GPU:2', '/device:GPU:0'],
-['/device:GPU:1', '/device:GPU:0']]:
-shape = (3, 4)
-np_ans = None
-tensors = []
-for d in devices:
-with ops.device(d):
-t = ((np.random.random_sample(shape) - .5) * 1024).astype(np_type)
-if np_ans is None:
-np_ans = t
-else:
-np_ans = numpy_accumulation_fn(np_ans, t)
-tensors.append(array_ops.identity(t))
-all_reduce_tensors = nccl_fn(tensors)
for devices in [['/device:GPU:1', '/device:GPU:2', '/device:GPU:0'],
['/device:GPU:1', '/device:GPU:0']]:
shape = (3, 4)
random = (np.random.random_sample(shape) - .5) * 1024
tensors = [random.astype(dtype)] * len(devices)
np_ans = tensors[0]
for t in tensors[1:]:
np_ans = numpy_fn(np_ans, t)
reduce_tensors, reduce_ops = nccl_reduce(tensors, devices)
self.assertNotEmpty(reduce_tensors)
# Test shape inference.
-for r in all_reduce_tensors:
for r in reduce_tensors:
self.assertEqual(shape, r.get_shape())
# Test execution and results.
-nccl_results = sess.run(all_reduce_tensors)
nccl_results = sess.run(reduce_tensors + reduce_ops)
-for r in nccl_results:
for r in nccl_results[:len(reduce_tensors)]:
self.assertAllClose(r, np_ans)
def _TestGradient(self, nccl_reduce, numpy_fn):
"""Tests the gradient of nccl_reduce.
Args:
nccl_reduce: A function taking a list of tensors and a list of devices,
and returns a list of reduced tensors and a list of ops to perform the
reduction.
numpy_fn: A function taking two tensors and returning the gradient of the
reduction of the two.
"""
def _Gradient(tensors, devices):
reduce_tensors, _ = nccl_reduce(tensors, devices)
tensor_ops = [t.op for t in reduce_tensors]
d_tensors = _DeviceTensors(tensors, devices)
grad_tensors = [
ops.get_gradient_function(op)(op, loss)
for op, loss in zip(tensor_ops, d_tensors)
]
return grad_tensors, []
self._Test(_Gradient, numpy_fn)
class AllReduceTest(NcclTestCase):
def testAllReduce(self):
self._Test(partial(_NcclAllReduce, nccl.all_sum), lambda x, y: x + y)
self._Test(partial(_NcclAllReduce, nccl.all_prod), lambda x, y: x * y)
self._Test(partial(_NcclAllReduce, nccl.all_min), np.minimum)
self._Test(partial(_NcclAllReduce, nccl.all_max), np.maximum)
def testAllSumGrad(self):
self._TestGradient(
partial(_NcclAllReduce, nccl.all_sum), lambda x, y: x + y)
def testErrors(self):
with self.assertRaisesRegexp(ValueError, 'Device assignment required'):
@@ -75,79 +137,32 @@ class AllReduceTest(test.TestCase):
nccl.all_sum([])
-class BroadcastTest(test.TestCase):
class SingleReduceTest(NcclTestCase):
def testSum(self):
self._Test(partial(_NcclReduce, nccl.reduce_sum), lambda x, y: x + y)
class BroadcastTest(NcclTestCase):
def testBroadcast(self):
-if not test.is_gpu_available():
-return # Test requires access to a GPU
-for dtype in [np.float32, np.int32, np.int64, np.float64]:
-# Create session inside outer loop to test use of
-# same communicator across multiple sessions.
-with self.test_session(use_gpu=True) as sess:
-for devices in [['/device:GPU:1', '/device:GPU:0', '/device:GPU:2'],
-['/device:GPU:1', '/device:GPU:0']]:
-shape = (3, 4)
-sender = np.random.randint(0, len(devices) - 1)
-with ops.device(devices[sender]):
-np_ans = ((
-(np.random.random_sample(shape) - .5) * 1024).astype(dtype))
-t = array_ops.identity(np_ans)
-other_devices = devices[:sender] + devices[sender + 1:]
-send_op, received_tensors = nccl.broadcast(t, other_devices)
-# Verify shape inference.
-for r in received_tensors:
-self.assertEqual(shape, r.get_shape())
-# Run and verify results.
-nccl_results = sess.run(received_tensors + [send_op])
-for r in nccl_results[:-1]:
-self.assertAllClose(r, np_ans)
self._Test(_NcclBroadcast, lambda x, y: x)
-class CombinedTest(test.TestCase):
class CombinedTest(NcclTestCase):
-"""Tests using a mix of all-reduce ops in one session.run call."""
"""Test all-reduce vs. single-reduce plus broadcast in one session.run."""
def _combined(self, tensors, devices):
all_reduce_tensors = _NcclAllReduce(nccl.all_sum, tensors, devices)[0]
single_reduce_tensors, single_reduce_ops = _NcclReduce(
nccl.reduce_sum, tensors, devices)
broadcast_tensors, broadcast_ops = _NcclBroadcast(single_reduce_tensors,
devices)
all_tensors = all_reduce_tensors + single_reduce_tensors + broadcast_tensors
return all_tensors, single_reduce_ops + broadcast_ops
def testCombined(self):
-if not test.is_gpu_available():
-return # Test requires access to a GPU
-for dtype in [np.float32, np.int32, np.int64, np.float64]:
-# Create session inside outer loop to test use of
-# same communicator across multiple sessions.
-with self.test_session(use_gpu=True) as sess:
-for devices in [['/device:GPU:1', '/device:GPU:2', '/device:GPU:0'],
-['/device:GPU:0', '/device:GPU:1']]:
-shape = (3, 4)
-# all-reduce
-np_ans = np.zeros(shape=shape, dtype=dtype)
-tensors = []
-for d in devices:
-with ops.device(d):
-t = ((np.random.random_sample(shape) - .5) * 1024).astype(dtype)
-np_ans += t
-tensors.append(array_ops.identity(t))
-all_reduce_tensors = nccl.all_sum(tensors)
-sender = np.random.randint(0, len(devices) - 1)
-other_devices = devices[:sender] + devices[sender + 1:]
-send_op, received_tensors = nccl.broadcast(all_reduce_tensors[sender],
-other_devices)
-# sender doesn't need to be fetched as part of outputs of session.run.
-del all_reduce_tensors[sender]
-# Verify shape inference.
-for r in received_tensors:
-self.assertEqual(shape, r.get_shape())
-# Run and verify results.
-nccl_results = sess.run(
-received_tensors + [send_op] + all_reduce_tensors)
-for r in nccl_results[:len(received_tensors)]:
self._Test(self._combined, lambda x, y: x + y)
if __name__ == '__main__':