Add contrib/nccl for using all-reduce collectives across GPUs of a single server.
Change: 145475050
A. Unique TensorFlower 2017-01-24 15:19:38 -08:00 committed by gunan
parent a0087e26e6
commit 38daff28c1
11 changed files with 1648 additions and 0 deletions
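The new API is exported as tensorflow.contrib.nccl (see nccl_ops.py and nccl_ops_test.py below). A minimal all-reduce sketch, assuming a TF 1.x graph/session and two visible GPUs; the device names are illustrative and not part of this change:

import numpy as np
import tensorflow as tf
from tensorflow.contrib import nccl

# One input per GPU; nccl collectives require every participant to be placed
# on a GPU device.
tensors = []
for d in ['/gpu:0', '/gpu:1']:
  with tf.device(d):
    tensors.append(tf.constant(np.random.rand(3, 4).astype(np.float32)))

# all_sum returns one output per input, placed on the same device as that input.
summed = nccl.all_sum(tensors)

with tf.Session() as sess:
  # Fetch all outputs in a single run; evaluating only a subset leaves the
  # remaining nccl kernels waiting and the call hangs.
  results = sess.run(summed)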


@@ -0,0 +1,120 @@
# Description:
# Wrap NVIDIA (https://github.com/NVIDIA/nccl) NCCL with tensorflow ops.
# APIs are meant to change over time.
package(
default_visibility = ["//visibility:private"],
features = ["-parse_headers"],
)
licenses(["notice"]) # Apache 2.0
exports_files(["LICENSE"])
load(
"//tensorflow:tensorflow.bzl",
"tf_cuda_cc_test",
"tf_custom_op_library",
"tf_gen_op_libs",
"tf_gen_op_wrapper_py",
)
load("//tensorflow:tensorflow.bzl", "cuda_py_test")
load("@local_config_cuda//cuda:build_defs.bzl", "if_cuda")
tf_custom_op_library(
name = "python/ops/_nccl_ops.so",
srcs = [
"kernels/nccl_manager.cc",
"kernels/nccl_manager.h",
"kernels/nccl_ops.cc",
"ops/nccl_ops.cc",
],
deps = [
"//tensorflow/core:gpu_headers_lib",
"@nccl_archive//:nccl",
],
)
tf_gen_op_libs(
op_lib_names = ["nccl_ops"],
deps = [
"//tensorflow/core:lib",
],
)
tf_gen_op_wrapper_py(
name = "nccl_ops",
deps = [":nccl_ops_op_lib"],
)
py_library(
name = "nccl_py",
srcs = [
"__init__.py",
"python/ops/nccl_ops.py",
],
data = [
":python/ops/_nccl_ops.so",
],
srcs_version = "PY2AND3",
visibility = ["//visibility:public"],
deps = [
":nccl_ops",
"//tensorflow/contrib/util:util_py",
"//tensorflow/python:platform",
],
)
cuda_py_test(
name = "nccl_ops_test",
size = "small",
srcs = ["python/ops/nccl_ops_test.py"],
additional_deps = [
":nccl_py",
"//tensorflow/python:array_ops",
"//tensorflow/python:client_testlib",
"//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:framework_test_lib",
"//tensorflow/python:platform_test",
],
tags = [
"manual",
"requires_cudnn5",
],
)
tf_cuda_cc_test(
name = "nccl_manager_test",
size = "small",
srcs = if_cuda(
[
"kernels/nccl_manager.cc",
"kernels/nccl_manager.h",
"kernels/nccl_manager_test.cc",
],
[],
),
deps = if_cuda(
[
"@nccl_archive//:nccl",
"//tensorflow/core",
"//tensorflow/core:cuda",
],
[],
) + [
"//tensorflow/core:test",
"//tensorflow/core:test_main",
"//tensorflow/core:testlib",
],
)
filegroup(
name = "all_files",
srcs = glob(
["**/*"],
exclude = [
"**/METADATA",
"**/OWNERS",
],
),
visibility = ["//tensorflow:__subpackages__"],
)


@@ -0,0 +1,24 @@
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Ops for nccl AllReduce."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
# go/tf-wildcard-import
# pylint: disable=wildcard-import
from tensorflow.contrib.nccl.python.ops.nccl_ops import *
# pylint: enable=wildcard-import


@@ -0,0 +1,471 @@
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/contrib/nccl/kernels/nccl_manager.h"
#ifdef GOOGLE_CUDA
#include "tensorflow/core/lib/core/threadpool.h"
#include "tensorflow/core/platform/cuda.h"
#include "tensorflow/core/platform/env.h"
namespace tensorflow {
using ::perftools::gputools::cuda::ScopedActivateExecutorContext;
// Contains data for a single stream used for nccl communication; this includes
// a background thread that calls NcclManager::LoopKernelLaunches.
struct NcclManager::NcclStream {
public:
NcclStream() {}
~NcclStream() {
mutex_lock l(mu);
shutdown_requested = true;
cv.notify_all();
}
perftools::gputools::StreamExecutor* executor = nullptr;
// The stream on which to run the nccl collective.
// This is a different stream than the tensorflow compute stream.
std::unique_ptr<perftools::gputools::Stream> stream;
// See NcclManager::LoopKernelLaunches for information on these.
std::unique_ptr<Thread> thread;
mutex mu;
condition_variable cv;
// Holds <collective, rank> pairs.
std::deque<std::pair<Collective*, int>> pending_launches_ GUARDED_BY(mu);
bool shutdown_requested GUARDED_BY(mu) = false;
};
struct NcclManager::CommunicatorMember {
public:
CommunicatorMember() {}
~CommunicatorMember() {
if (nccl_comm != nullptr) ncclCommDestroy(nccl_comm);
}
ncclComm_t nccl_comm;
// Owned by NcclManager::device_to_comm_streams_.
NcclStream* nccl_stream = nullptr;
};
struct NcclManager::Communicator {
public:
Communicator(std::vector<CommunicatorMember> members)
: num_devices(members.size()), members(std::move(members)) {}
const int num_devices;
const std::vector<CommunicatorMember> members; // indexed by rank.
};
namespace {
ncclDataType_t ToNcclType(DataType t) {
switch (t) {
case DT_FLOAT:
return ncclFloat;
case DT_DOUBLE:
return ncclDouble;
case DT_INT32:
return ncclInt;
case DT_INT64:
return ncclInt64;
default:
return ncclFloat;
}
}
} // namespace
// A participant in a Collective. See <Collective> below.
struct NcclManager::Participant {
Participant(const Tensor* in_t, Tensor* out_t, EventMgr* event_mgr,
perftools::gputools::Stream* tensor_stream,
perftools::gputools::StreamExecutor* executor,
NcclManager::DoneCallback done_callback)
: in_t(in_t),
out_t(out_t),
event_mgr(event_mgr),
tensor_stream(tensor_stream),
executor(executor),
done_callback(std::move(done_callback)) {
DCHECK(executor != nullptr);
DCHECK(event_mgr != nullptr);
DCHECK(tensor_stream != nullptr);
}
// Owned by the caller, who must keep it live until <done_callback> is called.
// Is NULL for participants that only receive data.
const Tensor* in_t;
// Owned by the caller, who must keep it live until <done_callback> is called.
// Is NULL for participants that only send data.
Tensor* out_t;
// Owned by the caller, who must keep it live until <done_callback> is called.
EventMgr* const event_mgr;
// Owned by the caller, who must keep it live until <done_callback> is called.
perftools::gputools::Stream* const tensor_stream;
// Matches the executor in CommunicatorMember::nccl_stream. Expected to be live
// for the process lifetime.
perftools::gputools::StreamExecutor* executor = nullptr;
NcclManager::DoneCallback done_callback;
bool root = false;
};
// A Collective tracks a single communicator operation (e.g., a single
// AllReduce call).
struct NcclManager::Collective {
Collective(DataType data_type_in, CollectiveType type_in,
ncclRedOp_t reduction_op_in, int num_devices)
: data_type(data_type_in),
type(type_in),
reduction_op(reduction_op_in),
remaining_participants(num_devices) {
participants.reserve(num_devices);
}
const DataType data_type;
const CollectiveType type;
const ncclRedOp_t reduction_op; // applies when <type> is a reduction.
Communicator* communicator = nullptr;
// All collective participants.
//
// Adding values in this vector is guarded by the mutex of the containing
// NcclManager.
std::vector<std::unique_ptr<Participant>> participants;
// For collective types that have a root (e.g. the root of broadcast is the
// sender), this is the rank of the root.
int root_rank = -1;
// How many participants have been registered so far. The Collective is
// eligible to run once <available_participants> == num_devices.
//
// Guarded by the mutex of the containing NcclManager.
int available_participants = 0;
mutable std::atomic_int_fast32_t remaining_participants;
};
NcclManager::NcclManager() {}
NcclManager::~NcclManager() {}
NcclManager* NcclManager::instance() {
static NcclManager* instance = new NcclManager();
return instance;
}
NcclManager::Communicator* NcclManager::GetCommunicator(
NcclManager::Collective* collective) {
// Sort by executor to make ordering of executors deterministic.
std::sort(collective->participants.begin(), collective->participants.end(),
[](const std::unique_ptr<Participant>& a,
const std::unique_ptr<Participant>& b) {
return a->executor < b->executor;
});
const int num_devices = collective->participants.size();
mutex_lock l(mu_);
// Scan to find an existing communicator that provides nccl communication
// between the executors used by the participants in the collective. For
// example, if a collective is for GPUs 0, 1, and 2 then this will scan
// to find the communicator for GPUs 0, 1, and 2.
//
// Note that each executor identifies a context on one device, so this is the
// same as getting the communicator connecting the devices in the collective.
// A device can be in different communicators as well - for example, a
// communicator for GPUs 0 and 1 is separate from one for GPUs 0, 1, and 2.
//
// Since it's expected that a small number of distinct communicators will
// be needed, communicators_ is not garbage collected currently.
//
// Launching of kernels must be serialized so that, given collectives A and B
// and an ordering between them (e.g., A before B), for each comm_stream
// involved the kernel for A is launched before the kernel for B. This is
// currently guaranteed by a global mutex controlling additions of the kernels
// to per-stream launch queues. The launch queues are processed by
// LoopKernelLaunches.
for (auto& comm : communicators_) {
if (comm->num_devices == num_devices) {
int i;
for (i = 0; i < num_devices; ++i) {
if (comm->members[i].nccl_stream->executor !=
collective->participants[i]->executor) {
break;
}
}
if (i == num_devices) return comm.get();
}
}
auto* env = Env::Default();
std::set<NcclStream*> used_streams;
// Create and initialize a new communicator.
// Note that this is done under the lock; performance is not expected to
// matter as this happens a very small number of times.
std::vector<CommunicatorMember> members(num_devices);
for (int i = 0; i < num_devices; ++i) {
auto* executor = collective->participants[i]->executor;
// Find a communication stream to use for the device.
auto& streams = device_to_comm_streams_[executor];
NcclStream* nccl_stream = nullptr;
for (const auto& s : streams) {
if (used_streams.insert(s.get()).second) {
nccl_stream = s.get();
break;
}
}
if (nccl_stream == nullptr) {
nccl_stream = new NcclStream();
nccl_stream->executor = executor;
nccl_stream->stream.reset(new perftools::gputools::Stream(executor));
nccl_stream->stream->Init();
streams.emplace_back(nccl_stream);
used_streams.insert(nccl_stream);
nccl_stream->thread.reset(env->StartThread(
ThreadOptions(), "nccl_kernel_launch",
[this, nccl_stream] { LoopKernelLaunches(nccl_stream); }));
}
members[i].nccl_stream = nccl_stream;
}
// Call ncclCommInitRank for each member.
ncclUniqueId id;
CHECK_EQ(ncclSuccess, ncclGetUniqueId(&id));
std::unique_ptr<thread::ThreadPool> pool(
new thread::ThreadPool(env, "ncclCommInitRank", num_devices));
std::vector<ncclResult_t> results(num_devices);
for (int rank = 0; rank < num_devices; ++rank) {
CommunicatorMember* member = &members[rank];
ncclResult_t* result = &results[rank];
pool->Schedule([member, num_devices, result, rank, &id]() {
ScopedActivateExecutorContext scoped_context(
member->nccl_stream->executor);
LOG(INFO) << "Calling ncclCommInitRank for rank " << rank;
*result = ncclCommInitRank(&member->nccl_comm, num_devices, id, rank);
LOG(INFO) << "Done calling ncclCommInitRank for rank " << rank << " : "
<< *result;
});
}
pool.reset(); // wait for completion.
for (int i = 0; i < num_devices; ++i) {
CHECK_EQ(results[i], ncclSuccess);
}
communicators_.emplace_back(new Communicator(std::move(members)));
return communicators_.back().get();
}
void NcclManager::AddToAllReduce(int num_devices, const string& key,
ncclRedOp_t reduction_op,
perftools::gputools::StreamExecutor* executor,
EventMgr* event_mgr,
perftools::gputools::Stream* tensor_stream,
const Tensor* in_t, Tensor* out_t,
const DoneCallback& done_callback) {
std::unique_ptr<Participant> participant(new Participant(
in_t, out_t, event_mgr, tensor_stream, executor, done_callback));
AddParticipant(num_devices, key, std::move(participant), in_t->dtype(),
kAllReduce, reduction_op);
}
void NcclManager::AddBroadcastSend(
int num_devices, const string& key,
perftools::gputools::StreamExecutor* executor, EventMgr* event_mgr,
perftools::gputools::Stream* tensor_stream, const Tensor* in_t,
DoneCallback done_callback) {
std::unique_ptr<Participant> participant(
new Participant(in_t, nullptr /* out_t */, event_mgr, tensor_stream,
executor, done_callback));
participant->root = true;
AddParticipant(num_devices, key, std::move(participant), in_t->dtype(),
kBroadcast, ncclSum /* unused */);
}
void NcclManager::AddBroadcastRecv(
int num_devices, const string& key,
perftools::gputools::StreamExecutor* executor, EventMgr* event_mgr,
perftools::gputools::Stream* tensor_stream, Tensor* out_t,
DoneCallback done_callback) {
std::unique_ptr<Participant> participant(
new Participant(nullptr /* in_t */, out_t, event_mgr, tensor_stream,
executor, done_callback));
AddParticipant(num_devices, key, std::move(participant), out_t->dtype(),
kBroadcast, ncclSum /* unused */);
}
void NcclManager::AddParticipant(int num_devices, const string& key,
std::unique_ptr<Participant> participant,
DataType data_type,
CollectiveType collective_type,
ncclRedOp_t reduction_op) {
Collective* to_run = nullptr;
{
mutex_lock l(mu_);
auto& collective_ptr = collectives_[key];
if (collective_ptr == nullptr) {
collective_ptr.reset(new Collective(data_type, collective_type,
reduction_op, num_devices));
}
Collective* collective = collective_ptr.get();
DCHECK_EQ(collective->type, collective_type);
DCHECK_EQ(collective->participants.size(), num_devices);
collective->participants.emplace_back(std::move(participant));
++collective->available_participants;
if (collective->available_participants == num_devices) {
to_run = collective;
// Ownership is going to be transferred to RunCollective.
collective_ptr.release();
collectives_.erase(key);
}
}
if (to_run != nullptr) {
RunCollective(key, to_run);
}
}
void NcclManager::RunCollective(const string& key, Collective* collective) {
static mutex collective_mu;
auto* communicator = GetCommunicator(collective);
collective->communicator = communicator;
const int size = communicator->num_devices;
for (int rank = 0; rank < size; ++rank) {
Participant* p = collective->participants[rank].get();
NcclStream* nccl_stream = communicator->members[rank].nccl_stream;
CHECK(nccl_stream != nullptr);
if (p->in_t != nullptr) {
// Wait to ensure that the kernel that produces the data in the input
// tensor has finished running before the nccl kernel runs on the
// communication stream.
nccl_stream->stream->ThenWaitFor(p->tensor_stream);
}
if (p->root) {
CHECK_EQ(collective->root_rank, -1);
collective->root_rank = rank;
}
}
if (collective->type == kBroadcast) {
CHECK_NE(collective->root_rank, -1);
}
{
// Allow only one collective at a time to queue kernels for launching. This
// is to prevent collectives from deadlocking each other.
// Note that it would be possible to run multiple collectives at once, if
// they have non-intersecting sets of devices.
mutex_lock l(collective_mu);
for (int rank = 0; rank < size; ++rank) {
NcclStream* nccl_stream = communicator->members[rank].nccl_stream;
mutex_lock l(nccl_stream->mu);
nccl_stream->pending_launches_.push_front(
std::make_pair(collective, rank));
nccl_stream->cv.notify_all();
}
}
}
void NcclManager::LoopKernelLaunches(NcclStream* nccl_stream) {
perftools::gputools::Stream* comm_stream = nccl_stream->stream.get();
ScopedActivateExecutorContext scoped_context(nccl_stream->executor);
const cudaStream_t* cu_stream = reinterpret_cast<const cudaStream_t*>(
comm_stream->implementation()->CudaStreamMemberHack());
while (true) {
// Find collective to run.
std::pair<Collective*, int> next_launch;
{
mutex_lock l(nccl_stream->mu);
while (nccl_stream->pending_launches_.empty()) {
if (nccl_stream->shutdown_requested) {
// No work and shutdown requested, exit.
return;
}
nccl_stream->cv.wait(l);
}
next_launch = nccl_stream->pending_launches_.back();
nccl_stream->pending_launches_.pop_back();
}
Collective* collective = next_launch.first;
int rank = next_launch.second;
// Launch the nccl kernel.
ncclDataType_t data_type = ToNcclType(collective->data_type);
Participant* p = collective->participants[rank].get();
auto nccl_comm = collective->communicator->members[rank].nccl_comm;
ncclResult_t nccl_result = ncclSuccess;
switch (collective->type) {
case kAllReduce: {
const void* sendbuff = p->in_t->tensor_data().data();
void* recvbuff = const_cast<char*>(p->out_t->tensor_data().data());
nccl_result =
ncclAllReduce(sendbuff, recvbuff, p->in_t->NumElements(), data_type,
collective->reduction_op, nccl_comm, *cu_stream);
break;
}
case kBroadcast: {
const Tensor* buf_t = p->in_t ? p->in_t : p->out_t;
void* buf = const_cast<char*>(buf_t->tensor_data().data());
nccl_result = ncclBcast(buf, buf_t->NumElements(), data_type,
collective->root_rank, nccl_comm, *cu_stream);
break;
}
}
// Run the done_callback when the nccl kernel finishes running.
auto done_callback = [collective, rank, nccl_result]() {
if (nccl_result == ncclSuccess) {
collective->participants[rank]->done_callback(Status::OK());
} else {
// Propagate the error, but note that if other members of the collective
// did launch their kernels, then they are hanging.
collective->participants[rank]->done_callback(errors::Unknown(
"Error invoking AllReduce: ", ncclGetErrorString(nccl_result)));
}
// TODO(cwhipkey): use RefCounted after figuring out how to use in a
// custom op library.
// See tensorflow/core/lib/core/refcount.h for details on this locking.
if (collective->remaining_participants.load(std::memory_order_acquire) ==
1 ||
collective->remaining_participants.fetch_sub(1) == 1) {
delete collective;
}
};
p->event_mgr->ThenExecute(comm_stream, done_callback);
}
}
} // namespace tensorflow
#endif // GOOGLE_CUDA


@@ -0,0 +1,122 @@
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef THIRD_PARTY_TENSORFLOW_CORE_KERNELS_NCCL_COMMUNICATOR_H_
#define THIRD_PARTY_TENSORFLOW_CORE_KERNELS_NCCL_COMMUNICATOR_H_
#ifdef GOOGLE_CUDA
#include <unordered_map>
#include <vector>
#include "external/nccl_archive/src/nccl.h"
#include "tensorflow/core/common_runtime/gpu/gpu_event_mgr.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/stream_executor.h"
namespace tensorflow {
// The communicator is used to make the asynchronous communicator calls and to
// manage the per-device streams used for communication.
//
// See nccl_ops.cc for example usage, including description of memory
// management and stream synchronization.
class NcclManager {
public:
typedef std::function<void(Status)> DoneCallback;
NcclManager();
~NcclManager();
static NcclManager* instance();
// Add one participant to an all-reduce, sending in data from <in_t> and
// receiving the result of the all-reduce in <out_t>. The device for this
// participant is managed by <executor>, and its events are polled by
// <event_mgr>.
//
// This is an asynchronous call. When <done_callback> is called, <out_t> has
// been set to the all-reduce result (note: the stream may not yet have been
// synced).
//
// <tensor_stream> is the stream that should be waited on to ensure <in_t>'s
// data is available on the GPU for the communication stream to access. It
// is also the stream that will use the produced data; <done_callback> is not
// called until the next kernel launched on <tensor_stream> would see the data.
void AddToAllReduce(int num_devices, const string& key,
ncclRedOp_t reduction_op,
perftools::gputools::StreamExecutor* executor,
EventMgr* event_mgr,
perftools::gputools::Stream* tensor_stream,
const Tensor* in_t, Tensor* out_t,
const DoneCallback& done_callback);
// AddBroadcastSend and AddBroadcastRecv combine to send data from one sender
// to all receivers.
void AddBroadcastSend(int num_devices, const string& key,
perftools::gputools::StreamExecutor* executor,
EventMgr* event_mgr,
perftools::gputools::Stream* tensor_stream,
const Tensor* in_t, DoneCallback done_callback);
void AddBroadcastRecv(int num_devices, const string& key,
perftools::gputools::StreamExecutor* executor,
EventMgr* event_mgr,
perftools::gputools::Stream* tensor_stream,
Tensor* out_t, DoneCallback done_callback);
private:
enum CollectiveType {
kAllReduce = 1,
kBroadcast = 2,
};
struct Collective;
struct Communicator;
struct CommunicatorMember;
struct NcclStream;
struct Participant;
Communicator* GetCommunicator(Collective* collective);
void AddParticipant(int num_devices, const string& key,
std::unique_ptr<Participant> participant,
DataType data_type, CollectiveType collective_type,
ncclRedOp_t reduction_op);
// Runs <collective>. This call takes ownership of <collective>.
void RunCollective(const string& key, Collective* collective);
void LoopKernelLaunches(NcclStream* stream);
mutex mu_;
// Maps key to collectives currently being assembled or run.
std::unordered_map<string, std::unique_ptr<Collective>> collectives_
GUARDED_BY(mu_);
// Maps a device to the communication streams that make up its collective.
// This is used to share the stream across different communicators that
// include the same device.
std::map<perftools::gputools::StreamExecutor*,
std::vector<std::unique_ptr<NcclStream>>>
device_to_comm_streams_ GUARDED_BY(mu_);
std::vector<std::unique_ptr<Communicator>> communicators_;
TF_DISALLOW_COPY_AND_ASSIGN(NcclManager);
};
} // namespace tensorflow
#endif // GOOGLE_CUDA
#endif // THIRD_PARTY_TENSORFLOW_CORE_KERNELS_NCCL_COMMUNICATOR_H_


@@ -0,0 +1,285 @@
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifdef GOOGLE_CUDA
#include <algorithm>
#include <vector>
#include "tensorflow/contrib/nccl/kernels/nccl_manager.h"
#include "tensorflow/core/common_runtime/device_factory.h"
#include "tensorflow/core/common_runtime/gpu/gpu_device.h"
#include "tensorflow/core/framework/tensor_testutil.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/test.h"
namespace tensorflow {
static std::vector<BaseGPUDevice*> GetGPUDevices() {
std::vector<Device*> devices;
SessionOptions session_options;
session_options.env = Env::Default();
Status s = DeviceFactory::GetFactory(DEVICE_GPU)
->AddDevices(session_options, "", &devices);
TF_CHECK_OK(s);
std::vector<BaseGPUDevice*> gpus;
for (Device* d : devices) {
if (d->device_type() == "GPU") {
gpus.push_back(static_cast<BaseGPUDevice*>(d));
} else {
delete d;
}
}
return gpus;
}
class NcclManagerTest : public ::testing::Test {
protected:
static void SetUpTestCase() {
setenv("NCCL_DEBUG", "INFO", 1 /* replace */);
devices = new std::vector<BaseGPUDevice*>(GetGPUDevices());
CHECK(!devices->empty());
LOG(ERROR) << "Running test with " << devices->size() << " gpus";
}
static void TearDownTestCase() {
for (auto device : *devices) delete device;
delete devices;
}
static Allocator* gpu_allocator(BaseGPUDevice* device) {
return device->GetStepAllocator(AllocatorAttributes(),
nullptr /* step_resource_manager */);
}
static std::vector<BaseGPUDevice*>* devices;
template <typename Scalar>
perftools::gputools::DeviceMemory<Scalar> AsDeviceMemory(
const Scalar* cuda_memory) {
perftools::gputools::DeviceMemoryBase wrapped(
const_cast<Scalar*>(cuda_memory));
perftools::gputools::DeviceMemory<Scalar> typed(wrapped);
return typed;
}
// A single all-reduce to apply.
struct TestCase {
string key;
std::vector<Tensor> ins;
std::vector<Tensor> outs;
Tensor expected;
mutex mu;
Status final_status;
int num_completed = 0;
};
TestCase* MakeTestCase(int num_ranks, ncclRedOp_t reduction_op,
TensorShape shape, float value_offset) {
TestCase* test_case = new TestCase();
test_case->expected = Tensor(DT_FLOAT, shape);
if (reduction_op == ncclProd) {
test::FillFn<float>(&test_case->expected, [](int) { return 1; });
} else if (reduction_op == ncclSum) {
test::FillFn<float>(&test_case->expected, [](int) { return 0; });
} else if (reduction_op == ncclMax) {
test::FillFn<float>(&test_case->expected, [](int) {
return -1 * std::numeric_limits<float>::max();
});
} else if (reduction_op == ncclMin) {
test::FillFn<float>(&test_case->expected, [](int) {
return std::numeric_limits<float>::max();
});
} else {
LOG(FATAL) << "Invalid reduction_op " << reduction_op;
}
int mult = 1;
for (int i = 0; i < num_ranks; ++i) {
auto* device = devices->at(i % devices->size());
auto* stream = device->tensorflow_gpu_device_info()->stream;
Tensor in_cpu(DT_FLOAT, shape);
test::FillFn<float>(&in_cpu, [mult, value_offset](int index) {
return value_offset + (index + 1) * mult;
});
for (int j = 0; j < shape.num_elements(); ++j) {
auto in_val = in_cpu.flat<float>()(j);
auto out_expr = test_case->expected.flat<float>();
if (reduction_op == ncclProd) {
out_expr(j) *= in_val;
} else if (reduction_op == ncclSum) {
out_expr(j) += in_val;
} else if (reduction_op == ncclMax) {
if (in_val > out_expr(j)) {
out_expr(j) = in_val;
}
} else if (reduction_op == ncclMin) {
if (in_val < out_expr(j)) {
out_expr(j) = in_val;
}
}
}
mult *= 10;
test_case->ins.emplace_back(gpu_allocator(device), DT_FLOAT, shape);
test_case->outs.emplace_back(gpu_allocator(device), DT_FLOAT, shape);
const Tensor& in_gpu = test_case->ins.back();
auto in_gpu_mem = AsDeviceMemory(in_gpu.flat<float>().data());
stream->ThenMemcpy(&in_gpu_mem, in_cpu.flat<float>().data(),
in_cpu.TotalBytes());
}
return test_case;
}
NcclManager::DoneCallback CreateDoneCallback(TestCase* test_case) {
return [this, test_case](Status s) {
mutex_lock l(test_case->mu);
++test_case->num_completed;
test_case->final_status.Update(s);
};
}
void VerifyResults(const string& case_label, TestCase* test_case) {
// Wait for the done callback to be called.
{
test_case->mu.lock();
while (test_case->num_completed != test_case->outs.size()) {
test_case->mu.unlock();
Env::Default()->SleepForMicroseconds(10);
test_case->mu.lock();
}
test_case->mu.unlock();
}
// Copy memory to host and verify.
for (int i = 0; i < test_case->outs.size(); ++i) {
auto* device = devices->at(i % devices->size());
auto* stream = device->tensorflow_gpu_device_info()->stream;
const Tensor& out_gpu = test_case->outs[i];
Tensor out_cpu(DT_FLOAT, out_gpu.shape());
auto out_gpu_mem = AsDeviceMemory(out_gpu.flat<float>().data());
stream->ThenMemcpy(out_cpu.flat<float>().data(), out_gpu_mem,
out_cpu.TotalBytes());
stream->BlockHostUntilDone();
test::ExpectTensorEqual<float>(test_case->expected, out_cpu);
}
}
};
std::vector<BaseGPUDevice*>* NcclManagerTest::devices = nullptr;
// Test basic sum reduction.
TEST_F(NcclManagerTest, BasicSumReduction) {
const int num_ranks = 3;
for (int op = 0; op < 4; ++op) {
ncclRedOp_t reduction_op = static_cast<ncclRedOp_t>(op);
std::unique_ptr<TestCase> test_case(
MakeTestCase(num_ranks, reduction_op, TensorShape({2, 3}), 0));
for (int device_num = 0; device_num < num_ranks; ++device_num) {
auto* device = devices->at(device_num % devices->size());
auto* event_mgr = device->tensorflow_gpu_device_info()->event_mgr;
auto* stream = device->tensorflow_gpu_device_info()->stream;
NcclManager::instance()->AddToAllReduce(
num_ranks, "allreduce", reduction_op, device->executor(), event_mgr,
stream, &test_case->ins[device_num], &test_case->outs[device_num],
CreateDoneCallback(test_case.get()));
}
LOG(ERROR) << "Verifying results";
VerifyResults("test_case", test_case.get());
}
}
// Same as the Basic test, but with multiple threads launching parts of many
// reductions.
//
// Testing the multi-rank execution is currently reduced as it can hang when run
// with num_ranks > devices->size(), for some GPUs (e.g. K20m).
// To test the higher settings, increase num_ranks,
// num_collectives_per_iteration and time_limit_micros.
TEST_F(NcclManagerTest, MultipleCallers) {
const int num_ranks = 1; // 2;
const int num_collectives_per_iteration = 1; // 1000;
const int num_threads = 3;
const int time_limit_micros = 1; // 60 * 30 * 1000 * 1000;
int64 start = Env::Default()->NowMicros();
srand(Env::Default()->NowMicros());
for (;;) {
std::vector<std::pair<int, int>> case_and_device_num;
std::vector<std::unique_ptr<TestCase>> test_cases;
for (int i = 0; i < num_collectives_per_iteration; ++i) {
test_cases.emplace_back(
MakeTestCase(num_ranks, ncclSum,
TensorShape({100, i % 5 + 1, i % 3 + 1}), i + 0.1 * i));
for (int j = 0; j < num_ranks; ++j) {
case_and_device_num.emplace_back(i, j);
}
}
for (int i = 0; i < num_ranks; ++i) {
auto* device = devices->at(i % devices->size());
auto* stream = device->tensorflow_gpu_device_info()->stream;
stream->BlockHostUntilDone();
}
std::random_shuffle(case_and_device_num.begin(), case_and_device_num.end());
mutex mu; // guards case_and_device_num.
std::unique_ptr<thread::ThreadPool> pool(
new thread::ThreadPool(Env::Default(), "test", num_threads));
const int to_schedule = case_and_device_num.size();
for (int i = 0; i < to_schedule; ++i) {
auto fn = [&]() {
int device_num;
int test_num;
{
mutex_lock l(mu);
test_num = case_and_device_num.back().first;
device_num = case_and_device_num.back().second;
case_and_device_num.pop_back();
}
auto* device = devices->at(device_num % devices->size());
auto* event_mgr = device->tensorflow_gpu_device_info()->event_mgr;
auto* stream = device->tensorflow_gpu_device_info()->stream;
TestCase* test_case = test_cases[test_num].get();
NcclManager::instance()->AddToAllReduce(
num_ranks, strings::StrCat("allreduce", test_num), ncclSum,
device->executor(), event_mgr, stream, &test_case->ins[device_num],
&test_case->outs[device_num], CreateDoneCallback(test_case));
};
pool->Schedule(fn);
}
pool.reset(); // wait for all work to be scheduled.
LOG(ERROR) << "Verifying results for " << num_collectives_per_iteration
<< " collectives";
for (int i = 0; i < test_cases.size(); ++i) {
VerifyResults(strings::StrCat("collective", i), test_cases[i].get());
}
int64 delta = Env::Default()->NowMicros() - start;
if (delta > time_limit_micros) {
LOG(ERROR) << "Ran for " << delta << " quitting";
break;
}
}
}
} // namespace tensorflow
#endif // GOOGLE_CUDA


@@ -0,0 +1,157 @@
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#if GOOGLE_CUDA
#include <unordered_map>
#include <vector>
#include "external/nccl_archive/src/nccl.h"
#include "tensorflow/contrib/nccl/kernels/nccl_manager.h"
#include "tensorflow/core/framework/op_kernel.h"
namespace tensorflow {
// Base class for all communicator ops that use nccl.
//
// About memory management and stream syncing:
// 1. The nccl communicator has a stream for each rank.
// 2. For input tensors to the communicator, the compute stream is passed to the
// NcclManager, which performs the needed
// communicator_stream.ThenWaitFor(input_tensor_stream).
// 3. The done_callback of the async kernel is not called by the
// NcclManager until after the communicator kernel is complete. This
// is enough to a) keep the input tensor data valid for the lifetime of the
// collective; and b) ensure the data in the output tensor is available
// when the async op kernel's done callback is called.
class NcclAsyncOpBase : public AsyncOpKernel {
public:
NcclAsyncOpBase(OpKernelConstruction* c) : AsyncOpKernel(c) {
OP_REQUIRES_OK(c, c->GetAttr("num_devices", &num_devices_));
OP_REQUIRES_OK(c, c->GetAttr("shared_name", &collective_prefix_));
}
string GetCollectiveKey(OpKernelContext* c) {
return strings::StrCat(collective_prefix_, ";", c->step_id(), ";",
c->frame_iter().frame_id, ":",
c->frame_iter().iter_id);
}
int num_devices() const { return num_devices_; }
private:
int num_devices_;
string collective_prefix_;
TF_DISALLOW_COPY_AND_ASSIGN(NcclAsyncOpBase);
};
// To execute a single all-reduce, this kernel is called once for each of the
// <k> devices in the communicator.
class NcclAllReduceOpKernel : public NcclAsyncOpBase {
public:
NcclAllReduceOpKernel(OpKernelConstruction* c) : NcclAsyncOpBase(c) {
string reduction;
OP_REQUIRES_OK(c, c->GetAttr("reduction", &reduction));
if (reduction == "min") {
reduction_op_ = ncclMin;
} else if (reduction == "max") {
reduction_op_ = ncclMax;
} else if (reduction == "sum") {
reduction_op_ = ncclSum;
} else if (reduction == "prod") {
reduction_op_ = ncclProd;
} else {
OP_REQUIRES_OK(c,
errors::InvalidArgument("Invalid reduction: ", reduction));
}
}
void ComputeAsync(OpKernelContext* c, DoneCallback done) override {
const Tensor* in_t = &c->input(0);
Tensor* out_t;
OP_REQUIRES_OK_ASYNC(c, c->allocate_output(0, in_t->shape(), &out_t), done);
auto actual_done = [c, done](Status s) {
OP_REQUIRES_OK_ASYNC(c, s, done);
done();
};
auto* compute_stream = c->op_device_context()->stream();
EventMgr* event_mgr = c->device()->tensorflow_gpu_device_info()->event_mgr;
NcclManager::instance()->AddToAllReduce(
num_devices(), GetCollectiveKey(c), reduction_op_,
compute_stream->parent(), event_mgr, compute_stream, in_t, out_t,
actual_done);
}
private:
ncclRedOp_t reduction_op_;
};
REGISTER_KERNEL_BUILDER(Name("NcclAllReduce").Device(DEVICE_GPU),
NcclAllReduceOpKernel);
class NcclBroadcastSendKernel : public NcclAsyncOpBase {
public:
NcclBroadcastSendKernel(OpKernelConstruction* c) : NcclAsyncOpBase(c) {}
void ComputeAsync(OpKernelContext* c, DoneCallback done) override {
auto actual_done = [c, done](Status s) {
OP_REQUIRES_OK_ASYNC(c, s, done);
done();
};
auto* compute_stream = c->op_device_context()->stream();
EventMgr* event_mgr = c->device()->tensorflow_gpu_device_info()->event_mgr;
NcclManager::instance()->AddBroadcastSend(
num_devices(), GetCollectiveKey(c), compute_stream->parent(), event_mgr,
compute_stream, &c->input(0), std::move(actual_done));
}
};
REGISTER_KERNEL_BUILDER(Name("NcclBroadcastSend").Device(DEVICE_GPU),
NcclBroadcastSendKernel);
class NcclBroadcastRecvKernel : public NcclAsyncOpBase {
public:
NcclBroadcastRecvKernel(OpKernelConstruction* c) : NcclAsyncOpBase(c) {}
void ComputeAsync(OpKernelContext* c, DoneCallback done) override {
const Tensor& shape_t = c->input(0);
TensorShape shape;
OP_REQUIRES_OK_ASYNC(
c, TensorShapeUtils::MakeShape(shape_t.vec<int64>(), &shape), done);
Tensor* out_t;
OP_REQUIRES_OK_ASYNC(c, c->allocate_output(0, shape, &out_t), done);
auto actual_done = [c, done](Status s) {
OP_REQUIRES_OK_ASYNC(c, s, done);
done();
};
auto* compute_stream = c->op_device_context()->stream();
EventMgr* event_mgr = c->device()->tensorflow_gpu_device_info()->event_mgr;
NcclManager::instance()->AddBroadcastRecv(
num_devices(), GetCollectiveKey(c), compute_stream->parent(), event_mgr,
compute_stream, out_t, std::move(actual_done));
}
};
REGISTER_KERNEL_BUILDER(
Name("NcclBroadcastRecv").Device(DEVICE_GPU).HostMemory("shape"),
NcclBroadcastRecvKernel);
} // namespace tensorflow
#endif // GOOGLE_CUDA


@@ -0,0 +1,94 @@
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/framework/common_shape_fns.h"
#include "tensorflow/core/framework/op.h"
namespace tensorflow {
using shape_inference::InferenceContext;
using shape_inference::ShapeHandle;
REGISTER_OP("NcclAllReduce")
.Input("input: T")
.Output("data: T")
.Attr("reduction: {'min', 'max', 'prod', 'sum'}")
.Attr("T: {float, float64, int32, int64}")
.Attr("num_devices: int")
.Attr("shared_name: string")
.SetIsStateful()
.SetShapeFn(shape_inference::UnchangedShape)
.Doc(R"doc(
Outputs a tensor containing the reduction across all input tensors passed to ops
within the same `shared_name`.
The graph should be constructed so that if one op runs with shared_name value `c`,
then `num_devices` ops will run with shared_name value `c`. Failure to do so
will cause the graph execution to fail to complete.
input: the input to the reduction
data: the value of the reduction across all `num_devices` devices.
reduction: the reduction operation to perform.
num_devices: The number of devices participating in this reduction.
shared_name: Identifier that is shared between ops of the same reduction.
)doc");
REGISTER_OP("NcclBroadcastSend")
.Input("input: T")
.Attr("T: {float, float64, int32, int64}")
.Attr("num_devices: int")
.Attr("shared_name: string")
.SetIsStateful()
.SetShapeFn(shape_inference::NoOutputs)
.Doc(R"doc(
Sends `input` to the NcclBroadcastRecv ops registered in the same `shared_name`.
The graph should be constructed so that one device runs `NcclBroadcastSend` and
`num_devices-1` devices run NcclBroadcastRecv ops with shared_name value `c`.
Failure to do so will cause the graph execution to fail to complete.
input: The input to the broadcast
num_devices: The number of devices participating in this broadcast.
shared_name: Identifier that is shared between ops of the same broadcast.
)doc");
REGISTER_OP("NcclBroadcastRecv")
.Input("shape: int64")
.Output("output: T")
.Attr("T: {float, float64, int32, int64}")
.Attr("num_devices: int")
.Attr("shared_name: string")
.SetIsStateful()
.SetShapeFn([](InferenceContext* c) {
ShapeHandle out;
TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(0, &out));
c->set_output(0, out);
return Status::OK();
})
.Doc(R"doc(
Receives data of shape `shape` from the NcclBroadcastSend op registered with the
same `shared_name`.
The graph should be constructed so that one device runs `NcclBroadcastSend` and
`num_devices-1` devices run NcclBroadcastRecv ops with shared_name value `c`.
Failure to do so will cause the graph execution to fail to complete.
shape: The shape of the output.
output: The broadcast data received from the NcclBroadcastSend op.
num_devices: The number of devices participating in this broadcast.
shared_name: Identifier that is shared between ops of the same broadcast.
)doc");
} // namespace tensorflow
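The shared_name/num_devices contract described in the Docs above is normally satisfied by the Python wrapper (nccl_ops.py below), which generates a fresh shared_name per collective. A hand-wired sketch using the generated bindings with an explicit name; the device list and the name 'c0' are illustrative assumptions:

import tensorflow as tf
from tensorflow.contrib.nccl.ops import gen_nccl_ops

devices = ['/gpu:0', '/gpu:1']
outputs = []
for d in devices:
  with tf.device(d):
    t = tf.ones([2, 3], dtype=tf.float32)
    # Every participating op carries the same shared_name and the same
    # num_devices; a mismatch means the collective never assembles and the
    # graph execution does not complete.
    outputs.append(gen_nccl_ops.nccl_all_reduce(
        t,
        reduction='sum',
        num_devices=len(devices),
        shared_name='c0'))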


@@ -0,0 +1,168 @@
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Ops for GPU collective operations implemented using NVIDIA nccl."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import threading
from tensorflow.contrib.nccl.ops import gen_nccl_ops
from tensorflow.contrib.util import loader
from tensorflow.python.framework import device
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.platform import resource_loader
_nccl_ops_so = loader.load_op_library(
resource_loader.get_path_to_datafile('_nccl_ops.so'))
def all_sum(tensors):
"""Returns a list of tensors with the all-reduce sum across `tensors`.
The computation is done with an all-reduce operation, so if only some of the
returned tensors are evaluated then the computation will hang.
Args:
tensors: The input tensors across which to sum; must be assigned
to GPU devices.
Returns:
List of tensors, each with the sum of the input tensors, where tensor i has
the same device as `tensors[i]`.
"""
return _apply_all_reduce('sum', tensors)
def all_prod(tensors):
"""Returns a list of tensors with the all-reduce product across `tensors`.
The computation is done with an all-reduce operation, so if only some of the
returned tensors are evaluated then the computation will hang.
Args:
tensors: The input tensors across which to multiply; must be assigned
to GPU devices.
Returns:
List of tensors, each with the product of the input tensors, where tensor i
has the same device as `tensors[i]`.
"""
return _apply_all_reduce('prod', tensors)
def all_min(tensors):
"""Returns a list of tensors with the all-reduce min across `tensors`.
The computation is done with an all-reduce operation, so if only some of the
returned tensors are evaluated then the computation will hang.
Args:
tensors: The input tensors across which to reduce; must be assigned
to GPU devices.
Returns:
List of tensors, each with the minimum of the input tensors, where tensor i
has the same device as `tensors[i]`.
"""
return _apply_all_reduce('min', tensors)
def all_max(tensors):
"""Returns a list of tensors with the all-reduce max across `tensors`.
The computation is done with an all-reduce operation, so if only some of the
returned tensors are evaluated then the computation will hang.
Args:
tensors: The input tensors across which to reduce; must be assigned
to GPU devices.
Returns:
List of tensors, each with the maximum of the input tensors, where tensor i
has the same device as `tensors[i]`.
"""
return _apply_all_reduce('max', tensors)
def broadcast(src_tensor, dst_devices):
"""Returns a list of tensors on `dst_devices`, each with value `tensor`.
The computation is done with a broadcast nccl operation, so if only some of
the returned tensors and src_tensor are evaluated then the computation will
hang.
Args:
src_tensor: The tensor to send; must be assigned to a GPU device.
dst_devices: The GPU devices to receive the sent tensor.
Returns:
The send op for `src_tensor` and a list of tensors, each with the value of
`src_tensor`, where the device of tensor i is `dst_devices[i]`.
"""
if not dst_devices:
raise ValueError('Must pass >0 dst_devices to broadcast')
all_devices = [src_tensor.device] + dst_devices
shared_name = _get_shared_name()
with ops.device(src_tensor.device):
send = gen_nccl_ops.nccl_broadcast_send(
input=src_tensor, num_devices=len(all_devices), shared_name=shared_name)
shape_op = array_ops.shape(src_tensor, out_type=dtypes.int64)
recvs = []
for d in dst_devices:
with ops.device(d):
recvs.append(
gen_nccl_ops.nccl_broadcast_recv(
shape=shape_op,
T=src_tensor.dtype,
num_devices=len(all_devices),
shared_name=shared_name))
return send, recvs
def _apply_all_reduce(reduction_op, tensors):
if not tensors:
raise ValueError('Must pass >0 tensors to all reduce operations')
shared_name = _get_shared_name()
res = []
for t in tensors:
if not device.canonical_name(t.device):
raise ValueError('Device assignment required for nccl collective ops')
with ops.device(t.device):
res.append(
gen_nccl_ops.nccl_all_reduce(
t,
reduction=reduction_op,
num_devices=len(tensors),
shared_name=shared_name))
return res
_lock = threading.Lock()
_shared_name_counter = 0
def _get_shared_name():
global _shared_name_counter
with _lock:
val = _shared_name_counter
_shared_name_counter += 1
return 'c%s' % val
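A broadcast usage sketch mirroring BroadcastTest in nccl_ops_test.py below, assuming a TF 1.x session and that '/gpu:0' and '/gpu:1' are available; the send op must be fetched together with the received tensors:

import tensorflow as tf
from tensorflow.contrib import nccl

with tf.device('/gpu:0'):
  src = tf.constant([[1.0, 2.0], [3.0, 4.0]])

# One received tensor per destination device; the send op itself produces no
# output value.
send_op, received = nccl.broadcast(src, ['/gpu:1'])

with tf.Session() as sess:
  # Run the send together with every receive, otherwise the receiving nccl
  # kernels block forever.
  results = sess.run(received + [send_op])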


@@ -0,0 +1,151 @@
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the 'License');
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an 'AS IS' BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for nccl ops. See also the cc test for nccl_communicator."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
from tensorflow.contrib import nccl
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.platform import test
class AllReduceTest(test.TestCase):
def testAllReduce(self):
if not test.is_gpu_available():
return # Test requires access to a GPU
for dtype in [np.float32, np.int32, np.int64, np.float64]:
# Create session inside outer loop to test use of
# same communicator across multiple sessions.
with self.test_session(use_gpu=True) as sess:
self._testSingleAllReduce(sess, dtype, nccl.all_sum, lambda x, y: x + y)
self._testSingleAllReduce(sess, dtype, nccl.all_prod,
lambda x, y: x * y)
self._testSingleAllReduce(sess, dtype, nccl.all_min, np.minimum)
self._testSingleAllReduce(sess, dtype, nccl.all_max, np.maximum)
def _testSingleAllReduce(self, sess, np_type, nccl_fn, numpy_accumulation_fn):
for devices in [['/gpu:0', '/gpu:0', '/gpu:0'], ['/gpu:0', '/gpu:0']]:
shape = (3, 4)
np_ans = None
tensors = []
for d in devices:
with ops.device(d):
t = ((np.random.random_sample(shape) - .5) * 1024).astype(np_type)
if np_ans is None:
np_ans = t
else:
np_ans = numpy_accumulation_fn(np_ans, t)
tensors.append(array_ops.identity(t))
all_reduce_tensors = nccl_fn(tensors)
# Test shape inference.
for r in all_reduce_tensors:
self.assertEqual(shape, r.get_shape())
# Test execution and results.
nccl_results = sess.run(all_reduce_tensors)
for r in nccl_results:
self.assertAllClose(r, np_ans)
def testErrors(self):
with self.assertRaisesRegexp(ValueError, 'Device assignment required'):
nccl.all_sum([array_ops.identity(np.random.random_sample((3, 4)))])
with self.assertRaisesRegexp(ValueError, 'Must pass >0 tensors'):
nccl.all_sum([])
class BroadcastTest(test.TestCase):
def testBroadcast(self):
if not test.is_gpu_available():
return # Test requires access to a GPU
for dtype in [np.float32, np.int32, np.int64, np.float64]:
# Create session inside outer loop to test use of
# same communicator across multiple sessions.
with self.test_session(use_gpu=True) as sess:
for devices in [['/gpu:0', '/gpu:0', '/gpu:0'], ['/gpu:0', '/gpu:0']]:
shape = (3, 4)
sender = np.random.randint(0, len(devices) - 1)
with ops.device(devices[sender]):
np_ans = ((
(np.random.random_sample(shape) - .5) * 1024).astype(dtype))
t = array_ops.identity(np_ans)
other_devices = devices[:sender] + devices[sender + 1:]
send_op, received_tensors = nccl.broadcast(t, other_devices)
# Verify shape inference.
for r in received_tensors:
self.assertEqual(shape, r.get_shape())
# Run and verify results.
nccl_results = sess.run(received_tensors + [send_op])
for r in nccl_results[:-1]:
self.assertAllClose(r, np_ans)
class CombinedTest(test.TestCase):
"""Tests using a mix of all-reduce ops in one session.run call."""
def testCombined(self):
if not test.is_gpu_available():
return # Test requires access to a GPU
for dtype in [np.float32, np.int32, np.int64, np.float64]:
# Create session inside outer loop to test use of
# same communicator across multiple sessions.
with self.test_session(use_gpu=True) as sess:
for devices in [['/gpu:0', '/gpu:0', '/gpu:0'], ['/gpu:0', '/gpu:0']]:
shape = (3, 4)
# all-reduce
np_ans = np.zeros(shape=shape, dtype=dtype)
tensors = []
for d in devices:
with ops.device(d):
t = ((np.random.random_sample(shape) - .5) * 1024).astype(dtype)
np_ans += t
tensors.append(array_ops.identity(t))
all_reduce_tensors = nccl.all_sum(tensors)
sender = np.random.randint(0, len(devices) - 1)
other_devices = devices[:sender] + devices[sender + 1:]
send_op, received_tensors = nccl.broadcast(all_reduce_tensors[sender],
other_devices)
# sender doesn't need to be fetched as part of outputs of session.run.
del all_reduce_tensors[sender]
# Verify shape inference.
for r in received_tensors:
self.assertEqual(shape, r.get_shape())
# Run and verify results.
nccl_results = sess.run(
received_tensors + [send_op] + all_reduce_tensors)
for r in nccl_results[:len(received_tensors)]:
self.assertAllClose(r, np_ans)
if __name__ == '__main__':
test.main()


@@ -385,6 +385,14 @@ def tf_workspace(path_prefix = "", tf_repo_name = ""):
actual = "@zlib_archive//:zlib",
)
native.new_http_archive(
name = "nccl_archive",
url = "https://github.com/NVIDIA/nccl/archive/2a974f5ca2aa12b178046b2206b43f1fd69d9fae.tar.gz",
sha256 = "d6aa1a3f20ae85358890d9a96f49c51a75baa1d3af3598501f29ff9ef8a3107d",
strip_prefix = "nccl-2a974f5ca2aa12b178046b2206b43f1fd69d9fae",
build_file = str(Label("//third_party:nccl.BUILD")),
)
# Make junit-4.12 available as //external:junit
native.http_jar(
name = "junit_jar",

third_party/nccl.BUILD

@@ -0,0 +1,48 @@
# NVIDIA nccl
# A package of optimized primitives for collective multi-GPU communication.
licenses(["notice"]) # BSD
exports_files(["LICENSE.txt"])
load("@local_config_cuda//cuda:build_defs.bzl", "cuda_default_copts", "if_cuda")
SRCS = [
"src/all_gather.cu",
"src/all_reduce.cu",
"src/broadcast.cu",
"src/core.cu",
"src/libwrap.cu",
"src/reduce.cu",
"src/reduce_scatter.cu",
]
# Copy .cu to .cu.cc so they can be in srcs of cc_library.
[
genrule(
name = "gen_" + src,
srcs = [src],
outs = [src + ".cc"],
cmd = "cp $(location " + src + ") $(location " + src + ".cc)",
)
for src in SRCS
]
SRCS_CU_CC = [src + ".cc" for src in SRCS]
cc_library(
name = "nccl",
srcs = if_cuda(SRCS_CU_CC + glob(["src/*.h"])),
hdrs = if_cuda(["src/nccl.h"]),
copts = [
"-DCUDA_MAJOR=0",
"-DCUDA_MINOR=0",
"-DNCCL_MAJOR=0",
"-DNCCL_MINOR=0",
"-DNCCL_PATCH=0",
"-Iexternal/nccl_archive/src",
"-O3",
] + cuda_default_copts(),
visibility = ["//visibility:public"],
deps = ["@local_config_cuda//cuda:cuda_headers"],
)