Adding NCCL sum op, register all_sum gradient.

Streamlining nccl test.

PiperOrigin-RevId: 168347428
This commit is contained in:
A. Unique TensorFlower 2017-09-12 02:00:55 -07:00 committed by TensorFlower Gardener
parent bc300318e7
commit 9b9e54b344
7 changed files with 415 additions and 119 deletions

View File

@@ -18,6 +18,7 @@
@@all_min
@@all_prod
@@all_sum
@@reduce_sum
@@broadcast
"""
@@ -31,6 +32,7 @@ from tensorflow.contrib.nccl.python.ops.nccl_ops import all_min
from tensorflow.contrib.nccl.python.ops.nccl_ops import all_prod
from tensorflow.contrib.nccl.python.ops.nccl_ops import all_sum
from tensorflow.contrib.nccl.python.ops.nccl_ops import broadcast
from tensorflow.contrib.nccl.python.ops.nccl_ops import reduce_sum
from tensorflow.python.util.all_util import remove_undocumented
remove_undocumented(__name__)
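
For context, a minimal sketch of the expanded Python surface (illustrative only; `t0` and `t1` are assumed to be tensors already placed on two distinct GPUs):

    from tensorflow.contrib import nccl

    summed = nccl.all_sum([t0, t1])                         # existing all-reduce, one output per device
    recv, send_ops = nccl.reduce_sum([t0, t1], t0.device)   # new: single reduced output on t0's device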

View File

@@ -260,7 +260,7 @@ NcclManager::Communicator* NcclManager::GetCommunicator(
std::vector<ncclComm_t> nccl_comms(num_devices);
auto result = ncclCommInitAll(nccl_comms.data(), num_devices, devices.data());
CHECK_EQ(result, ncclSuccess);
CHECK_EQ(result, ncclSuccess) << ncclGetErrorString(result);
for (int rank = 0; rank < num_devices; ++rank) {
members[rank].nccl_comm = nccl_comms[rank];
}
@@ -307,6 +307,35 @@ void NcclManager::AddBroadcastRecv(
kBroadcast, ncclSum /* unused */);
}
void NcclManager::AddReduceSend(int num_devices, const string& key,
ncclRedOp_t reduction_op,
perftools::gputools::StreamExecutor* executor,
int gpu_device_id, EventMgr* event_mgr,
perftools::gputools::Stream* tensor_stream,
const Tensor* in_t, Tensor* temp_t,
DoneCallback done_callback) {
std::unique_ptr<Participant> participant(
new Participant(in_t, temp_t, event_mgr, tensor_stream, executor,
gpu_device_id, std::move(done_callback)));
AddParticipant(num_devices, key, std::move(participant), in_t->dtype(),
kReduce, reduction_op);
}
void NcclManager::AddReduceRecv(int num_devices, const string& key,
ncclRedOp_t reduction_op,
perftools::gputools::StreamExecutor* executor,
int gpu_device_id, EventMgr* event_mgr,
perftools::gputools::Stream* tensor_stream,
const Tensor* in_t, Tensor* out_t,
DoneCallback done_callback) {
std::unique_ptr<Participant> participant(
new Participant(in_t, out_t, event_mgr, tensor_stream, executor,
gpu_device_id, std::move(done_callback)));
participant->root = true;
AddParticipant(num_devices, key, std::move(participant), in_t->dtype(),
kReduce, reduction_op);
}
void NcclManager::AddParticipant(int num_devices, const string& key,
std::unique_ptr<Participant> participant,
DataType data_type,
@@ -431,6 +460,14 @@ void NcclManager::LoopKernelLaunches(NcclStream* nccl_stream) {
collective->root_rank, nccl_comm, *cu_stream);
break;
}
case kReduce: {
const void* sendbuff = p->in_t->tensor_data().data();
void* recvbuff = const_cast<char*>(p->out_t->tensor_data().data());
nccl_result = ncclReduce(sendbuff, recvbuff, p->in_t->NumElements(),
data_type, collective->reduction_op,
collective->root_rank, nccl_comm, *cu_stream);
break;
}
}
// Run the done_callback when the nccl kernel finishes running.
@@ -441,7 +478,7 @@ void NcclManager::LoopKernelLaunches(NcclStream* nccl_stream) {
// Propagate the error, but note that if other members of the collective
// did launch their kernels, then they are hanging.
collective->participants[rank]->done_callback(errors::Unknown(
"Error invoking AllReduce: ", ncclGetErrorString(nccl_result)));
"Error invoking NCCL: ", ncclGetErrorString(nccl_result)));
}
// TODO(cwhipkey): use RefCounted after figuring out how to use in a

View File

@@ -75,10 +75,28 @@ class NcclManager {
perftools::gputools::Stream* tensor_stream,
Tensor* out_t, DoneCallback done_callback);
// AddReduceSend and AddReduceRecv combine to send data from all senders
// to one receiver.
void AddReduceSend(int num_devices, const string& key,
ncclRedOp_t reduction_op,
perftools::gputools::StreamExecutor* executor,
int gpu_device_id, EventMgr* event_mgr,
perftools::gputools::Stream* tensor_stream,
const Tensor* in_t, Tensor* temp_t,
DoneCallback done_callback);
void AddReduceRecv(int num_devices, const string& key,
ncclRedOp_t reduction_op,
perftools::gputools::StreamExecutor* executor,
int gpu_device_id, EventMgr* event_mgr,
perftools::gputools::Stream* tensor_stream,
const Tensor* in_t, Tensor* out_t,
DoneCallback done_callback);
private:
enum CollectiveType {
kAllReduce = 1,
kBroadcast = 2,
kReduce = 3,
};
struct Collective;
struct Communicator;

View File

@@ -15,6 +15,7 @@ limitations under the License.
#if GOOGLE_CUDA
#include <memory>
#include <unordered_map>
#include <vector>
@@ -58,11 +59,9 @@ class NcclAsyncOpBase : public AsyncOpKernel {
TF_DISALLOW_COPY_AND_ASSIGN(NcclAsyncOpBase);
};
// To execute a single all-reduce, this kernel is called once for each of the
// <k> devices in the communicator.
class NcclAllReduceOpKernel : public NcclAsyncOpBase {
class NcclReduceOpBase : public NcclAsyncOpBase {
public:
explicit NcclAllReduceOpKernel(OpKernelConstruction* c) : NcclAsyncOpBase(c) {
explicit NcclReduceOpBase(OpKernelConstruction* c) : NcclAsyncOpBase(c) {
string reduction;
OP_REQUIRES_OK(c, c->GetAttr("reduction", &reduction));
if (reduction == "min") {
@@ -79,6 +78,19 @@ class NcclAllReduceOpKernel : public NcclAsyncOpBase {
}
}
ncclRedOp_t reduction_op() const { return reduction_op_; }
private:
ncclRedOp_t reduction_op_;
};
// To execute a single all-reduce, this kernel is called once for each of the
// <k> devices in the communicator.
class NcclAllReduceOpKernel : public NcclReduceOpBase {
public:
explicit NcclAllReduceOpKernel(OpKernelConstruction* c)
: NcclReduceOpBase(c) {}
void ComputeAsync(OpKernelContext* c, DoneCallback done) override {
const Tensor* in_t = &c->input(0);
Tensor* out_t;
@@ -92,18 +104,81 @@ class NcclAllReduceOpKernel : public NcclAsyncOpBase {
auto* compute_stream = c->op_device_context()->stream();
auto* gpu_info = c->device()->tensorflow_gpu_device_info();
NcclManager::instance()->AddToAllReduce(
num_devices(), GetCollectiveKey(c), reduction_op_,
num_devices(), GetCollectiveKey(c), reduction_op(),
compute_stream->parent(), gpu_info->gpu_id, gpu_info->event_mgr,
compute_stream, in_t, out_t, actual_done);
compute_stream, in_t, out_t, std::move(actual_done));
}
};
REGISTER_KERNEL_BUILDER(Name("NcclAllReduce").Device(DEVICE_GPU),
NcclAllReduceOpKernel);
// To execute a single reduce, this kernel is called once for all but one of the
// <k> devices in the communicator, and NcclReduceRecvKernel is called once for
// the remaining device.
class NcclReduceSendKernel : public NcclReduceOpBase {
public:
explicit NcclReduceSendKernel(OpKernelConstruction* c)
: NcclReduceOpBase(c) {}
void ComputeAsync(OpKernelContext* c, DoneCallback done) override {
const Tensor& in_t = c->input(0);
std::unique_ptr<Tensor> temp_ptr(new Tensor());
OP_REQUIRES_OK_ASYNC(
c, c->allocate_temp(in_t.dtype(), in_t.shape(), temp_ptr.get()), done);
Tensor* temp_t = temp_ptr.release();
auto actual_done = [c, done, temp_t](Status s) {
delete temp_t;
OP_REQUIRES_OK_ASYNC(c, s, done);
done();
};
auto* compute_stream = c->op_device_context()->stream();
auto* gpu_info = c->device()->tensorflow_gpu_device_info();
NcclManager::instance()->AddReduceSend(
num_devices(), GetCollectiveKey(c), reduction_op(),
compute_stream->parent(), gpu_info->gpu_id, gpu_info->event_mgr,
compute_stream, &in_t, temp_t, std::move(actual_done));
}
};
REGISTER_KERNEL_BUILDER(Name("NcclReduceSend").Device(DEVICE_GPU),
NcclReduceSendKernel);
// To execute a single reduce, this kernel is called once for one device, and
// NcclReduceSendKernel is called for all other <k-1> devices in the
// communicator.
class NcclReduceRecvKernel : public NcclReduceOpBase {
public:
explicit NcclReduceRecvKernel(OpKernelConstruction* c)
: NcclReduceOpBase(c) {}
void ComputeAsync(OpKernelContext* c, DoneCallback done) override {
const Tensor& in_t = c->input(0);
Tensor* out_t;
OP_REQUIRES_OK_ASYNC(c, c->allocate_output(0, in_t.shape(), &out_t), done);
auto actual_done = [c, done](Status s) {
OP_REQUIRES_OK_ASYNC(c, s, done);
done();
};
auto* compute_stream = c->op_device_context()->stream();
auto* gpu_info = c->device()->tensorflow_gpu_device_info();
NcclManager::instance()->AddReduceRecv(
num_devices(), GetCollectiveKey(c), reduction_op(),
compute_stream->parent(), gpu_info->gpu_id, gpu_info->event_mgr,
compute_stream, &in_t, out_t, std::move(actual_done));
}
};
REGISTER_KERNEL_BUILDER(Name("NcclReduceRecv").Device(DEVICE_GPU),
NcclReduceRecvKernel);
REGISTER_KERNEL_BUILDER(Name("NcclAllReduce").Device(DEVICE_GPU),
NcclAllReduceOpKernel);
// To execute a single broadcast, this kernel is called once for one device, and
// NcclBroadcastRecvKernel is called for all other <k-1> devices in the
// communicator.
class NcclBroadcastSendKernel : public NcclAsyncOpBase {
public:
explicit NcclBroadcastSendKernel(OpKernelConstruction* c)
@@ -126,6 +201,9 @@ class NcclBroadcastSendKernel : public NcclAsyncOpBase {
REGISTER_KERNEL_BUILDER(Name("NcclBroadcastSend").Device(DEVICE_GPU),
NcclBroadcastSendKernel);
// To execute a single broadcast, this kernel is called once for all but one of
// the <k> devices in the communicator, and NcclBroadcastSendKernel is called
// once for the remaining device.
class NcclBroadcastRecvKernel : public NcclAsyncOpBase {
public:
explicit NcclBroadcastRecvKernel(OpKernelConstruction* c)

View File

@@ -45,6 +45,51 @@ num_devices: The number of devices participating in this reduction.
shared_name: Identifier that is shared between ops of the same reduction.
)doc");
REGISTER_OP("NcclReduceSend")
.Input("input: T")
.Attr("reduction: {'min', 'max', 'prod', 'sum'}")
.Attr("T: {float, float64, int32, int64}")
.Attr("num_devices: int")
.Attr("shared_name: string")
.SetIsStateful()
.SetShapeFn(shape_inference::NoOutputs)
.Doc(R"doc(
Reduces `input` to the NcclReduceRecv op registered in the same `shared_name`.
The graph should be constructed so that `num_devices-1` devices run
`NcclReduceSend` and one device runs the `NcclReduceRecv` op, all with the same
`shared_name` value. Failure to do so will cause the graph execution to fail to
complete.
input: The input to the reduction.
reduction: The reduction operation to perform.
num_devices: The number of devices participating in this reduction.
shared_name: Identifier that is shared between ops of the same reduce.
)doc");
REGISTER_OP("NcclReduceRecv")
.Input("input: T")
.Output("data: T")
.Attr("reduction: {'min', 'max', 'prod', 'sum'}")
.Attr("T: {float, float64, int32, int64}")
.Attr("num_devices: int")
.Attr("shared_name: string")
.SetIsStateful()
.SetShapeFn(shape_inference::UnchangedShape)
.Doc(R"doc(
Reduces `input` from this op and the NcclReduceSend ops registered in the same
`shared_name`.
The graph should be constructed so that `num_devices-1` devices run
`NcclReduceSend` and one device runs the `NcclReduceRecv` op, all with the same
`shared_name` value. Failure to do so will cause the graph execution to fail to
complete.
input: The input to the reduction.
data: The reduced data received from this op and the NcclReduceSend ops.
reduction: The reduction operation to perform.
num_devices: The number of devices participating in this reduction.
shared_name: Identifier that is shared between ops of the same reduce.
)doc");
REGISTER_OP("NcclBroadcastSend")
.Input("input: T")
.Attr("T: {float, float64, int32, int64}")

View File

@@ -21,6 +21,7 @@ import threading
from tensorflow.contrib.nccl.ops import gen_nccl_ops
from tensorflow.contrib.util import loader
from tensorflow.python.eager import context
from tensorflow.python.framework import device
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
@@ -48,6 +49,35 @@ def all_sum(tensors):
return _apply_all_reduce('sum', tensors)
@ops.RegisterGradient('NcclAllReduce')
def _all_sum_grad(op, grad):
"""The gradients for `all_sum`.
Args:
op: The `all_sum` `Operation` that we are differentiating.
grad: Gradient with respect to the output of the `all_sum` op.
Returns:
The gradient with respect to the input of `all_sum`.
Raises:
LookupError: If `reduction` is not `sum`.
"""
if op.get_attr('reduction') != 'sum':
raise LookupError('No gradient defined for NcclAllReduce except all_sum.')
_check_device_assignment(grad)
num_devices = op.get_attr('num_devices')
shared_name = op.get_attr('shared_name') + '_grad'
with ops.device(grad.device):
return gen_nccl_ops.nccl_all_reduce(
input=grad,
reduction='sum',
num_devices=num_devices,
shared_name=shared_name)
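
A small usage sketch of the gradient path, assuming a hypothetical two-GPU setup; per the function above, each gradient is itself an NcclAllReduce with reduction 'sum' and a '_grad'-suffixed shared name:

    import tensorflow as tf
    from tensorflow.contrib import nccl

    with tf.device('/device:GPU:0'):
      a = tf.constant([1., 2.])
    with tf.device('/device:GPU:1'):
      b = tf.constant([3., 4.])
    summed = nccl.all_sum([a, b])                        # forward all-reduce
    loss = tf.add_n([tf.reduce_sum(t) for t in summed])
    grads = tf.gradients(loss, [a, b])                   # backward pass also all-reduces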
def all_prod(tensors):
"""Returns a list of tensors with the all-reduce product across `tensors`.
@@ -99,6 +129,24 @@ def all_max(tensors):
return _apply_all_reduce('max', tensors)
def reduce_sum(tensors, dst_device):
"""Returns a tensor with the reduce sum across `tensors`.
The computation is done with a reduce operation, so only one tensor is
returned.
Args:
tensors: The input tensors across which to sum; must be assigned
to GPU devices.
dst_device: The device of the returned tensor.
Returns:
A tensor containing the sum of the input tensors, placed on `dst_device`, and
the send ops from the other devices that must be run to complete the
reduction.
"""
return _apply_reduce('sum', tensors, dst_device)
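
A usage sketch under the same assumptions as the test below: one of the inputs must already live on `dst_device`, and the send ops have to be fetched together with the reduced tensor. The tensors, session, and device names are illustrative:

    # t0, t1, t2 are assumed to be tensors placed on '/device:GPU:0', ':1', ':2'.
    recv, send_ops = nccl.reduce_sum([t0, t1, t2], dst_device='/device:GPU:0')
    result = sess.run([recv] + send_ops)[0]   # all participants must run in one step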
def broadcast(src_tensor, dst_devices):
"""Returns a list of tensors on `dst_devices`, each with value `tensor`.
@@ -111,50 +159,93 @@ def broadcast(src_tensor, dst_devices):
dst_devices: The GPU devices to receive the sent tensor.
Returns:
List of tensors, each with the value of `src_tensor`, which the device
of tensor i is `dst_devices[i]`.
An `Operation` to send the `src_tensor`, and a list of tensors, each with
the value of `src_tensor`, where the device of tensor i is `dst_devices[i]`.
"""
if not dst_devices:
raise ValueError('Must pass >0 dst_devices to broadcast')
all_devices = [src_tensor.device] + dst_devices
_check_graph_mode()
_check_device_assignment(src_tensor)
shape = array_ops.shape(src_tensor, out_type=dtypes.int64)
num_devices = len(dst_devices) + 1
shared_name = _get_shared_name()
with ops.device(src_tensor.device):
send = gen_nccl_ops.nccl_broadcast_send(
input=src_tensor, num_devices=len(all_devices), shared_name=shared_name)
input=src_tensor, num_devices=num_devices, shared_name=shared_name)
shape_op = array_ops.shape(src_tensor, out_type=dtypes.int64)
recvs = []
for d in dst_devices:
with ops.device(d):
recvs.append(
gen_nccl_ops.nccl_broadcast_recv(
shape=shape_op,
shape=shape,
T=src_tensor.dtype,
num_devices=len(all_devices),
num_devices=num_devices,
shared_name=shared_name))
return send, recvs
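
A matching sketch for the updated return value: the send `Operation` has to be run in the same step as the receiving tensors. The source tensor, session, and device names are illustrative:

    # src is assumed to be a tensor already placed on '/device:GPU:0'.
    send_op, received = nccl.broadcast(src, ['/device:GPU:1', '/device:GPU:2'])
    outputs = sess.run(received + [send_op])[:len(received)]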
def _apply_all_reduce(reduction_op, tensors):
def _apply_all_reduce(reduction, tensors):
"""Helper function for all_* functions."""
if not tensors:
raise ValueError('Must pass >0 tensors to all reduce operations')
_check_graph_mode()
shared_name = _get_shared_name()
res = []
for t in tensors:
if not device.canonical_name(t.device):
raise ValueError('Device assignment required for nccl collective ops')
_check_device_assignment(t)
with ops.device(t.device):
res.append(
gen_nccl_ops.nccl_all_reduce(
t,
reduction=reduction_op,
input=t,
reduction=reduction,
num_devices=len(tensors),
shared_name=shared_name))
return res
def _apply_reduce(reduction, tensors, dst_device):
"""Helper function for reduce_* functions."""
if not tensors:
raise ValueError('Must pass >0 tensors to reduce operations')
if not dst_device:
raise ValueError('Must pass dst_device to reduce operations')
_check_graph_mode()
try:
recv_index = next(i for i, t in enumerate(tensors)
if t.device == dst_device)
except StopIteration:
raise ValueError('One of the tensors must be assigned to dst_device')
shared_name = _get_shared_name()
sends = []
for t in tensors[:recv_index] + tensors[recv_index + 1:]:
_check_device_assignment(t)
with ops.device(t.device):
sends.append(
gen_nccl_ops.nccl_reduce_send(
input=t,
reduction=reduction,
num_devices=len(tensors),
shared_name=shared_name))
with ops.device(dst_device):
recv = gen_nccl_ops.nccl_reduce_recv(
input=tensors[recv_index],
reduction=reduction,
num_devices=len(tensors),
shared_name=shared_name)
return recv, sends
_lock = threading.Lock()
_shared_name_counter = 0
@@ -166,3 +257,13 @@ def _get_shared_name():
val = _shared_name_counter
_shared_name_counter += 1
return 'c%s' % val
def _check_device_assignment(tensor):
if not device.canonical_name(tensor.device):
raise ValueError('Device assignment required for nccl collective ops')
def _check_graph_mode():
if context.in_eager_mode():
raise ValueError('Nccl ops are not supported in eager mode')

View File

@@ -18,6 +18,7 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from functools import partial
import numpy as np
from tensorflow.contrib import nccl
@@ -26,9 +27,45 @@ from tensorflow.python.ops import array_ops
from tensorflow.python.platform import test
class AllReduceTest(test.TestCase):
def _DeviceTensors(tensors, devices):
res = []
for t, d in zip(tensors, devices):
with ops.device(d):
res.append(array_ops.identity(t))
return res
def testAllReduce(self):
def _NcclAllReduce(nccl_fun, tensors, devices):
return nccl_fun(_DeviceTensors(tensors, devices)), []
def _NcclReduce(nccl_fun, tensors, devices):
d_tensors = _DeviceTensors(tensors, devices)
receiver = np.random.randint(0, len(devices))
received_tensor, send_ops = nccl_fun(d_tensors, devices[receiver])
return [received_tensor], send_ops
def _NcclBroadcast(tensors, devices):
sender = np.random.randint(0, len(devices))
d_tensor = _DeviceTensors(tensors[0:1], devices[sender:sender + 1])[0]
other_devices = devices[:sender] + devices[sender + 1:]
send_op, received_tensors = nccl.broadcast(d_tensor, other_devices)
return received_tensors, [send_op]
class NcclTestCase(test.TestCase):
def _Test(self, nccl_reduce, numpy_fn):
"""Tests that nccl_reduce does the same as reduction with numpy_fn.
Args:
nccl_reduce: A function that takes a list of tensors and a list of devices,
and returns a list of reduced tensors and a list of ops to perform the
reduction.
numpy_fn: A function taking two tensors and returning the reduction of the
two.
"""
if not test.is_gpu_available():
return # Test requires access to a GPU
@@ -36,37 +73,62 @@ class AllReduceTest(test.TestCase):
# Create session inside outer loop to test use of
# same communicator across multiple sessions.
with self.test_session(use_gpu=True) as sess:
self._testSingleAllReduce(sess, dtype, nccl.all_sum, lambda x, y: x + y)
self._testSingleAllReduce(sess, dtype, nccl.all_prod,
lambda x, y: x * y)
self._testSingleAllReduce(sess, dtype, nccl.all_min, np.minimum)
self._testSingleAllReduce(sess, dtype, nccl.all_max, np.maximum)
def _testSingleAllReduce(self, sess, np_type, nccl_fn, numpy_accumulation_fn):
for devices in [['/device:GPU:1', '/device:GPU:2', '/device:GPU:0'],
['/device:GPU:1', '/device:GPU:0']]:
shape = (3, 4)
np_ans = None
tensors = []
for d in devices:
with ops.device(d):
t = ((np.random.random_sample(shape) - .5) * 1024).astype(np_type)
if np_ans is None:
np_ans = t
else:
np_ans = numpy_accumulation_fn(np_ans, t)
tensors.append(array_ops.identity(t))
for devices in [['/device:GPU:1', '/device:GPU:2', '/device:GPU:0'],
['/device:GPU:1', '/device:GPU:0']]:
shape = (3, 4)
random = (np.random.random_sample(shape) - .5) * 1024
tensors = [random.astype(dtype)] * len(devices)
np_ans = tensors[0]
for t in tensors[1:]:
np_ans = numpy_fn(np_ans, t)
all_reduce_tensors = nccl_fn(tensors)
reduce_tensors, reduce_ops = nccl_reduce(tensors, devices)
self.assertNotEmpty(reduce_tensors)
# Test shape inference.
for r in all_reduce_tensors:
self.assertEqual(shape, r.get_shape())
# Test shape inference.
for r in reduce_tensors:
self.assertEqual(shape, r.get_shape())
# Test execution and results.
nccl_results = sess.run(all_reduce_tensors)
for r in nccl_results:
self.assertAllClose(r, np_ans)
# Test execution and results.
nccl_results = sess.run(reduce_tensors + reduce_ops)
for r in nccl_results[:len(reduce_tensors)]:
self.assertAllClose(r, np_ans)
def _TestGradient(self, nccl_reduce, numpy_fn):
"""Tests the gradient of nccl_reduce.
Args:
nccl_reduce: A function that takes a list of tensors and a list of devices,
and returns a list of reduced tensors and a list of ops to perform the
reduction.
numpy_fn: A function taking two tensors and returning the gradient of the
reduction of the two.
"""
def _Gradient(tensors, devices):
reduce_tensors, _ = nccl_reduce(tensors, devices)
tensor_ops = [t.op for t in reduce_tensors]
d_tensors = _DeviceTensors(tensors, devices)
grad_tensors = [
ops.get_gradient_function(op)(op, loss)
for op, loss in zip(tensor_ops, d_tensors)
]
return grad_tensors, []
self._Test(_Gradient, numpy_fn)
class AllReduceTest(NcclTestCase):
def testAllReduce(self):
self._Test(partial(_NcclAllReduce, nccl.all_sum), lambda x, y: x + y)
self._Test(partial(_NcclAllReduce, nccl.all_prod), lambda x, y: x * y)
self._Test(partial(_NcclAllReduce, nccl.all_min), np.minimum)
self._Test(partial(_NcclAllReduce, nccl.all_max), np.maximum)
def testAllSumGrad(self):
self._TestGradient(
partial(_NcclAllReduce, nccl.all_sum), lambda x, y: x + y)
def testErrors(self):
with self.assertRaisesRegexp(ValueError, 'Device assignment required'):
@@ -75,79 +137,32 @@ class AllReduceTest(test.TestCase):
nccl.all_sum([])
class BroadcastTest(test.TestCase):
class SingleReduceTest(NcclTestCase):
def testSum(self):
self._Test(partial(_NcclReduce, nccl.reduce_sum), lambda x, y: x + y)
class BroadcastTest(NcclTestCase):
def testBroadcast(self):
if not test.is_gpu_available():
return # Test requires access to a GPU
for dtype in [np.float32, np.int32, np.int64, np.float64]:
# Create session inside outer loop to test use of
# same communicator across multiple sessions.
with self.test_session(use_gpu=True) as sess:
for devices in [['/device:GPU:1', '/device:GPU:0', '/device:GPU:2'],
['/device:GPU:1', '/device:GPU:0']]:
shape = (3, 4)
sender = np.random.randint(0, len(devices) - 1)
with ops.device(devices[sender]):
np_ans = ((
(np.random.random_sample(shape) - .5) * 1024).astype(dtype))
t = array_ops.identity(np_ans)
other_devices = devices[:sender] + devices[sender + 1:]
send_op, received_tensors = nccl.broadcast(t, other_devices)
# Verify shape inference.
for r in received_tensors:
self.assertEqual(shape, r.get_shape())
# Run and verify results.
nccl_results = sess.run(received_tensors + [send_op])
for r in nccl_results[:-1]:
self.assertAllClose(r, np_ans)
self._Test(_NcclBroadcast, lambda x, y: x)
class CombinedTest(test.TestCase):
"""Tests using a mix of all-reduce ops in one session.run call."""
class CombinedTest(NcclTestCase):
"""Test all-reduce vs. single-reduce plus broadcast in one session.run."""
def _combined(self, tensors, devices):
all_reduce_tensors = _NcclAllReduce(nccl.all_sum, tensors, devices)[0]
single_reduce_tensors, single_reduce_ops = _NcclReduce(
nccl.reduce_sum, tensors, devices)
broadcast_tensors, broadcast_ops = _NcclBroadcast(single_reduce_tensors,
devices)
all_tensors = all_reduce_tensors + single_reduce_tensors + broadcast_tensors
return all_tensors, single_reduce_ops + broadcast_ops
def testCombined(self):
if not test.is_gpu_available():
return # Test requires access to a GPU
for dtype in [np.float32, np.int32, np.int64, np.float64]:
# Create session inside outer loop to test use of
# same communicator across multiple sessions.
with self.test_session(use_gpu=True) as sess:
for devices in [['/device:GPU:1', '/device:GPU:2', '/device:GPU:0'],
['/device:GPU:0', '/device:GPU:1']]:
shape = (3, 4)
# all-reduce
np_ans = np.zeros(shape=shape, dtype=dtype)
tensors = []
for d in devices:
with ops.device(d):
t = ((np.random.random_sample(shape) - .5) * 1024).astype(dtype)
np_ans += t
tensors.append(array_ops.identity(t))
all_reduce_tensors = nccl.all_sum(tensors)
sender = np.random.randint(0, len(devices) - 1)
other_devices = devices[:sender] + devices[sender + 1:]
send_op, received_tensors = nccl.broadcast(all_reduce_tensors[sender],
other_devices)
# sender doesn't need to be fetched as part of outputs of session.run.
del all_reduce_tensors[sender]
# Verify shape inference.
for r in received_tensors:
self.assertEqual(shape, r.get_shape())
# Run and verify results.
nccl_results = sess.run(
received_tensors + [send_op] + all_reduce_tensors)
for r in nccl_results[:len(received_tensors)]:
self.assertAllClose(r, np_ans)
self._Test(self._combined, lambda x, y: x + y)
if __name__ == '__main__':