Add contrib/nccl for using all-reduce collectives across GPUs of a single server.
Change: 145475050
parent a0087e26e6
commit 38daff28c1
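As a quick orientation (not part of the commit itself), here is a minimal usage sketch of the Python API this change adds under tensorflow/contrib/nccl. The device names, shapes, and session setup are illustrative and assume a CUDA-enabled build of TensorFlow with at least two GPUs.

import numpy as np
import tensorflow as tf
from tensorflow.contrib import nccl

# Place one tensor on each participating GPU (assumed device names).
tensors = []
for d in ['/gpu:0', '/gpu:1']:
  with tf.device(d):
    tensors.append(tf.identity(np.random.rand(3, 4).astype(np.float32)))

# all_sum returns one result tensor per input device.
summed = nccl.all_sum(tensors)

with tf.Session() as sess:
  # Fetch all returned tensors together; evaluating only a subset hangs,
  # because every rank of the collective must run (see nccl_ops.py below).
  results = sess.run(summed)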
tensorflow/contrib/nccl/BUILD (new file, 120 lines)
@@ -0,0 +1,120 @@
# Description:
#   Wrap NVIDIA (https://github.com/NVIDIA/nccl) NCCL with tensorflow ops.
#   APIs are meant to change over time.
package(
    default_visibility = ["//visibility:private"],
    features = ["-parse_headers"],
)

licenses(["notice"])  # Apache 2.0

exports_files(["LICENSE"])

load(
    "//tensorflow:tensorflow.bzl",
    "tf_cuda_cc_test",
    "tf_custom_op_library",
    "tf_gen_op_libs",
    "tf_gen_op_wrapper_py",
)
load("//tensorflow:tensorflow.bzl", "cuda_py_test")
load("@local_config_cuda//cuda:build_defs.bzl", "if_cuda")

tf_custom_op_library(
    name = "python/ops/_nccl_ops.so",
    srcs = [
        "kernels/nccl_manager.cc",
        "kernels/nccl_manager.h",
        "kernels/nccl_ops.cc",
        "ops/nccl_ops.cc",
    ],
    deps = [
        "//tensorflow/core:gpu_headers_lib",
        "@nccl_archive//:nccl",
    ],
)

tf_gen_op_libs(
    op_lib_names = ["nccl_ops"],
    deps = [
        "//tensorflow/core:lib",
    ],
)

tf_gen_op_wrapper_py(
    name = "nccl_ops",
    deps = [":nccl_ops_op_lib"],
)

py_library(
    name = "nccl_py",
    srcs = [
        "__init__.py",
        "python/ops/nccl_ops.py",
    ],
    data = [
        ":python/ops/_nccl_ops.so",
    ],
    srcs_version = "PY2AND3",
    visibility = ["//visibility:public"],
    deps = [
        ":nccl_ops",
        "//tensorflow/contrib/util:util_py",
        "//tensorflow/python:platform",
    ],
)

cuda_py_test(
    name = "nccl_ops_test",
    size = "small",
    srcs = ["python/ops/nccl_ops_test.py"],
    additional_deps = [
        ":nccl_py",
        "//tensorflow/python:array_ops",
        "//tensorflow/python:client_testlib",
        "//tensorflow/python:framework_for_generated_wrappers",
        "//tensorflow/python:framework_test_lib",
        "//tensorflow/python:platform_test",
    ],
    tags = [
        "manual",
        "requires_cudnn5",
    ],
)

tf_cuda_cc_test(
    name = "nccl_manager_test",
    size = "small",
    srcs = if_cuda(
        [
            "kernels/nccl_manager.cc",
            "kernels/nccl_manager.h",
            "kernels/nccl_manager_test.cc",
        ],
        [],
    ),
    deps = if_cuda(
        [
            "@nccl_archive//:nccl",
            "//tensorflow/core",
            "//tensorflow/core:cuda",
        ],
        [],
    ) + [
        "//tensorflow/core:test",
        "//tensorflow/core:test_main",
        "//tensorflow/core:testlib",
    ],
)

filegroup(
    name = "all_files",
    srcs = glob(
        ["**/*"],
        exclude = [
            "**/METADATA",
            "**/OWNERS",
        ],
    ),
    visibility = ["//tensorflow:__subpackages__"],
)
tensorflow/contrib/nccl/__init__.py (new file, 24 lines)
@@ -0,0 +1,24 @@
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Ops for nccl AllReduce."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

# go/tf-wildcard-import
# pylint: disable=wildcard-import
from tensorflow.contrib.nccl.python.ops.nccl_ops import *
# pylint: enable=wildcard-import
tensorflow/contrib/nccl/kernels/nccl_manager.cc (new file, 471 lines)
@@ -0,0 +1,471 @@
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/contrib/nccl/kernels/nccl_manager.h"

#ifdef GOOGLE_CUDA

#include "tensorflow/core/lib/core/threadpool.h"
#include "tensorflow/core/platform/cuda.h"
#include "tensorflow/core/platform/env.h"

namespace tensorflow {

using ::perftools::gputools::cuda::ScopedActivateExecutorContext;

// Contains data for a single stream used for nccl communication; this includes
// a background thread that calls NcclManager::LoopKernelLaunches.
struct NcclManager::NcclStream {
 public:
  NcclStream() {}
  ~NcclStream() {
    mutex_lock l(mu);
    shutdown_requested = true;
    cv.notify_all();
  }

  perftools::gputools::StreamExecutor* executor = nullptr;

  // The stream on which to run the nccl collective.
  // This is a different stream than the tensorflow compute stream.
  std::unique_ptr<perftools::gputools::Stream> stream;

  // See NcclManager::LoopKernelLaunches for information on these.
  std::unique_ptr<Thread> thread;
  mutex mu;
  condition_variable cv;
  // Holds <collective, rank> pairs.
  std::deque<std::pair<Collective*, int>> pending_launches_ GUARDED_BY(mu);
  bool shutdown_requested GUARDED_BY(mu) = false;
};

struct NcclManager::CommunicatorMember {
 public:
  CommunicatorMember() {}
  ~CommunicatorMember() {
    if (nccl_comm != nullptr) ncclCommDestroy(nccl_comm);
  }
  ncclComm_t nccl_comm;

  // Owned by NcclManager::device_to_comm_streams_.
  NcclStream* nccl_stream = nullptr;
};

struct NcclManager::Communicator {
 public:
  Communicator(std::vector<CommunicatorMember> members)
      : num_devices(members.size()), members(std::move(members)) {}

  const int num_devices;
  const std::vector<CommunicatorMember> members;  // indexed by rank.
};

namespace {
ncclDataType_t ToNcclType(DataType t) {
  switch (t) {
    case DT_FLOAT:
      return ncclFloat;
    case DT_DOUBLE:
      return ncclDouble;
    case DT_INT32:
      return ncclInt;
    case DT_INT64:
      return ncclInt64;
    default:
      return ncclFloat;
  }
}
}  // namespace

// A participant in a Collective. See <Collective> below.
struct NcclManager::Participant {
  Participant(const Tensor* in_t, Tensor* out_t, EventMgr* event_mgr,
              perftools::gputools::Stream* tensor_stream,
              perftools::gputools::StreamExecutor* executor,
              NcclManager::DoneCallback done_callback)
      : in_t(in_t),
        out_t(out_t),
        event_mgr(event_mgr),
        tensor_stream(tensor_stream),
        executor(executor),
        done_callback(std::move(done_callback)) {
    DCHECK(executor != nullptr);
    DCHECK(event_mgr != nullptr);
    DCHECK(tensor_stream != nullptr);
  }
  // Owned by the caller, who must keep it live until <done_callback> is called.
  // Is NULL for participants that only receive data.
  const Tensor* in_t;

  // Owned by the caller, who must keep it live until <done_callback> is called.
  // Is NULL for participants that only send data.
  Tensor* out_t;

  // Owned by the caller, who must keep it live until <done_callback> is called.
  EventMgr* const event_mgr;

  // Owned by the caller, who must keep it live until <done_callback> is called.
  perftools::gputools::Stream* const tensor_stream;

  // Matches the executor in CommunicatorMember::stream. Expected to be live for
  // process lifetime.
  perftools::gputools::StreamExecutor* executor = nullptr;

  NcclManager::DoneCallback done_callback;

  bool root = false;
};

// A Collective tracks a single communicator operation (e.g., a single
// AllReduce call).
struct NcclManager::Collective {
  Collective(DataType data_type_in, CollectiveType type_in,
             ncclRedOp_t reduction_op_in, int num_devices)
      : data_type(data_type_in),
        type(type_in),
        reduction_op(reduction_op_in),
        remaining_participants(num_devices) {
    participants.reserve(num_devices);
  }

  const DataType data_type;
  const CollectiveType type;
  const ncclRedOp_t reduction_op;  // applies when <type> is a reduction.

  Communicator* communicator = nullptr;

  // All collective participants.
  //
  // Adding values in this vector is guarded by the mutex of the containing
  // NcclManager.
  std::vector<std::unique_ptr<Participant>> participants;

  // For collective types that have a root (e.g. the root of broadcast is the
  // sender), this is the rank of the root.
  int root_rank = -1;

  // How many participants have been registered so far. The Collective is
  // eligible for running with <available_participants> == participants.size().
  //
  // Guarded by the mutex of the containing Communicator.
  int available_participants = 0;

  mutable std::atomic_int_fast32_t remaining_participants;
};

NcclManager::NcclManager() {}
NcclManager::~NcclManager() {}
NcclManager* NcclManager::instance() {
  static NcclManager* instance = new NcclManager();
  return instance;
}

NcclManager::Communicator* NcclManager::GetCommunicator(
    NcclManager::Collective* collective) {
  // Sort by executor to make ordering of executors deterministic.
  std::sort(collective->participants.begin(), collective->participants.end(),
            [](const std::unique_ptr<Participant>& a,
               const std::unique_ptr<Participant>& b) {
              return a->executor < b->executor;
            });
  const int num_devices = collective->participants.size();

  mutex_lock l(mu_);

  // Scan to find an existing communicator that provides nccl communication
  // between the executors used by the participants in the collective. For
  // example, if a collective is for GPUs 0, 1, and 2 then this will scan
  // to find the communicator for GPUs 0, 1, and 2.
  //
  // Note that each executor identifies a context on one device, so this is the
  // same as getting the communicator connecting the devices in the collective.
  // A device can be in different communicators as well - for example, a
  // communicator for GPUs 0 and 1 is separate from one for GPUs 0, 1, and 2.
  //
  // Since it's expected that a small number of distinct communicators will
  // be needed, communicators_ is not garbage collected currently.
  //
  // Launching of kernels must be serialized so that, given collectives A and B,
  // and an order of them (e.g., A before B), then for each comm_stream
  // involved, the kernel for A is launched before the kernel for B. This is
  // guaranteed currently by a global mutex controlling additions of the kernels
  // to per-stream launch queues. The launch queues are processed by
  // LoopKernelLaunches.
  for (auto& comm : communicators_) {
    if (comm->num_devices == num_devices) {
      int i;
      for (i = 0; i < num_devices; ++i) {
        if (comm->members[i].nccl_stream->executor !=
            collective->participants[i]->executor) {
          break;
        }
      }
      if (i == num_devices) return comm.get();
    }
  }

  auto* env = Env::Default();
  std::set<NcclStream*> used_streams;

  // Create and initialize a new communicator.
  // Note that this is done under the lock; performance is not expected to
  // matter as this happens a very small number of times.
  std::vector<CommunicatorMember> members(num_devices);
  for (int i = 0; i < num_devices; ++i) {
    auto* executor = collective->participants[i]->executor;

    // Find a communication stream to use for the device.
    auto& streams = device_to_comm_streams_[executor];
    NcclStream* nccl_stream = nullptr;
    for (const auto& s : streams) {
      if (used_streams.insert(s.get()).second) {
        nccl_stream = s.get();
        break;
      }
    }
    if (nccl_stream == nullptr) {
      nccl_stream = new NcclStream();
      nccl_stream->executor = executor;
      nccl_stream->stream.reset(new perftools::gputools::Stream(executor));
      nccl_stream->stream->Init();

      streams.emplace_back(nccl_stream);
      used_streams.insert(nccl_stream);

      nccl_stream->thread.reset(env->StartThread(
          ThreadOptions(), "nccl_kernel_launch",
          [this, nccl_stream] { LoopKernelLaunches(nccl_stream); }));
    }

    members[i].nccl_stream = nccl_stream;
  }

  // Call ncclCommInitRank for each member.
  ncclUniqueId id;
  CHECK_EQ(ncclSuccess, ncclGetUniqueId(&id));
  std::unique_ptr<thread::ThreadPool> pool(
      new thread::ThreadPool(env, "ncclCommInitRank", num_devices));
  std::vector<ncclResult_t> results(num_devices);
  for (int rank = 0; rank < num_devices; ++rank) {
    CommunicatorMember* member = &members[rank];
    ncclResult_t* result = &results[rank];
    pool->Schedule([member, num_devices, result, rank, &id]() {
      ScopedActivateExecutorContext scoped_context(
          member->nccl_stream->executor);
      LOG(INFO) << "Calling ncclCommInitRank for rank " << rank;
      *result = ncclCommInitRank(&member->nccl_comm, num_devices, id, rank);
      LOG(INFO) << "Done calling ncclCommInitRank for rank " << rank << " : "
                << *result;
    });
  }

  pool.reset();  // wait for completion.
  for (int i = 0; i < num_devices; ++i) {
    CHECK_EQ(results[i], ncclSuccess);
  }
  communicators_.emplace_back(new Communicator(std::move(members)));
  return communicators_.back().get();
}

void NcclManager::AddToAllReduce(int num_devices, const string& key,
                                 ncclRedOp_t reduction_op,
                                 perftools::gputools::StreamExecutor* executor,
                                 EventMgr* event_mgr,
                                 perftools::gputools::Stream* tensor_stream,
                                 const Tensor* in_t, Tensor* out_t,
                                 const DoneCallback& done_callback) {
  std::unique_ptr<Participant> participant(new Participant(
      in_t, out_t, event_mgr, tensor_stream, executor, done_callback));
  AddParticipant(num_devices, key, std::move(participant), in_t->dtype(),
                 kAllReduce, reduction_op);
}

void NcclManager::AddBroadcastSend(
    int num_devices, const string& key,
    perftools::gputools::StreamExecutor* executor, EventMgr* event_mgr,
    perftools::gputools::Stream* tensor_stream, const Tensor* in_t,
    DoneCallback done_callback) {
  std::unique_ptr<Participant> participant(
      new Participant(in_t, nullptr /* out_t */, event_mgr, tensor_stream,
                      executor, done_callback));
  participant->root = true;
  AddParticipant(num_devices, key, std::move(participant), in_t->dtype(),
                 kBroadcast, ncclSum /* unused */);
}

void NcclManager::AddBroadcastRecv(
    int num_devices, const string& key,
    perftools::gputools::StreamExecutor* executor, EventMgr* event_mgr,
    perftools::gputools::Stream* tensor_stream, Tensor* out_t,
    DoneCallback done_callback) {
  std::unique_ptr<Participant> participant(
      new Participant(nullptr /* in_t */, out_t, event_mgr, tensor_stream,
                      executor, done_callback));
  AddParticipant(num_devices, key, std::move(participant), out_t->dtype(),
                 kBroadcast, ncclSum /* unused */);
}

void NcclManager::AddParticipant(int num_devices, const string& key,
                                 std::unique_ptr<Participant> participant,
                                 DataType data_type,
                                 CollectiveType collective_type,
                                 ncclRedOp_t reduction_op) {
  Collective* to_run = nullptr;
  {
    mutex_lock l(mu_);
    auto& collective_ptr = collectives_[key];
    if (collective_ptr == nullptr) {
      collective_ptr.reset(new Collective(data_type, collective_type,
                                          reduction_op, num_devices));
    }
    Collective* collective = collective_ptr.get();
    DCHECK_EQ(collective->type, collective_type);
    DCHECK_EQ(collective->participants.size(), num_devices);
    collective->participants.emplace_back(std::move(participant));
    ++collective->available_participants;

    if (collective->available_participants == num_devices) {
      to_run = collective;

      // Ownership is going to be transferred to RunCollective.
      collective_ptr.release();
      collectives_.erase(key);
    }
  }

  if (to_run != nullptr) {
    RunCollective(key, to_run);
  }
}

void NcclManager::RunCollective(const string& key, Collective* collective) {
  static mutex collective_mu;

  auto* communicator = GetCommunicator(collective);
  collective->communicator = communicator;
  const int size = communicator->num_devices;

  for (int rank = 0; rank < size; ++rank) {
    Participant* p = collective->participants[rank].get();
    NcclStream* nccl_stream = communicator->members[rank].nccl_stream;
    CHECK(nccl_stream != nullptr);

    if (p->in_t != nullptr) {
      // Wait to ensure that the kernel that produces the data in the input
      // tensor has finished running before the nccl kernel runs on the
      // communication stream.
      nccl_stream->stream->ThenWaitFor(p->tensor_stream);
    }
    if (p->root) {
      CHECK_EQ(collective->root_rank, -1);
      collective->root_rank = rank;
    }
  }

  if (collective->type == kBroadcast) {
    CHECK_NE(collective->root_rank, -1);
  }

  {
    // Allow only one collective at a time to queue kernels for launching. This
    // is to prevent collectives from deadlocking each other.
    // Note that it would be possible to run multiple collectives at once, if
    // they have non-intersecting sets of devices.
    mutex_lock l(collective_mu);
    for (int rank = 0; rank < size; ++rank) {
      NcclStream* nccl_stream = communicator->members[rank].nccl_stream;
      mutex_lock l(nccl_stream->mu);
      nccl_stream->pending_launches_.push_front(
          std::make_pair(collective, rank));
      nccl_stream->cv.notify_all();
    }
  }
}

void NcclManager::LoopKernelLaunches(NcclStream* nccl_stream) {
  perftools::gputools::Stream* comm_stream = nccl_stream->stream.get();
  ScopedActivateExecutorContext scoped_context(nccl_stream->executor);
  const cudaStream_t* cu_stream = reinterpret_cast<const cudaStream_t*>(
      comm_stream->implementation()->CudaStreamMemberHack());

  while (true) {
    // Find collective to run.
    std::pair<Collective*, int> next_launch;
    {
      mutex_lock l(nccl_stream->mu);
      while (nccl_stream->pending_launches_.empty()) {
        if (nccl_stream->shutdown_requested) {
          // No work and shutdown requested, exit.
          return;
        }
        nccl_stream->cv.wait(l);
      }
      next_launch = nccl_stream->pending_launches_.back();
      nccl_stream->pending_launches_.pop_back();
    }
    Collective* collective = next_launch.first;
    int rank = next_launch.second;

    // Launch the nccl kernel.
    ncclDataType_t data_type = ToNcclType(collective->data_type);
    Participant* p = collective->participants[rank].get();

    auto nccl_comm = collective->communicator->members[rank].nccl_comm;
    ncclResult_t nccl_result = ncclSuccess;
    switch (collective->type) {
      case kAllReduce: {
        const void* sendbuff = p->in_t->tensor_data().data();
        void* recvbuff = const_cast<char*>(p->out_t->tensor_data().data());

        nccl_result =
            ncclAllReduce(sendbuff, recvbuff, p->in_t->NumElements(), data_type,
                          collective->reduction_op, nccl_comm, *cu_stream);
        break;
      }
      case kBroadcast: {
        const Tensor* buf_t = p->in_t ? p->in_t : p->out_t;
        void* buf = const_cast<char*>(buf_t->tensor_data().data());
        nccl_result = ncclBcast(buf, buf_t->NumElements(), data_type,
                                collective->root_rank, nccl_comm, *cu_stream);
        break;
      }
    }

    // Run the done_callback when the nccl kernel finishes running.
    auto done_callback = [collective, rank, nccl_result]() {
      if (nccl_result == ncclSuccess) {
        collective->participants[rank]->done_callback(Status::OK());
      } else {
        // Propagate the error, but note that if other members of the collective
        // did launch their kernels, then they are hanging.
        collective->participants[rank]->done_callback(errors::Unknown(
            "Error invoking AllReduce: ", ncclGetErrorString(nccl_result)));
      }

      // TODO(cwhipkey): use RefCounted after figuring out how to use in a
      // custom op library.
      // See tensorflow/core/lib/core/refcount.h for details on this locking.
      if (collective->remaining_participants.load(std::memory_order_acquire) ==
              1 ||
          collective->remaining_participants.fetch_sub(1) == 1) {
        delete collective;
      }
    };
    p->event_mgr->ThenExecute(comm_stream, done_callback);
  }
}

}  // namespace tensorflow

#endif  // GOOGLE_CUDA
tensorflow/contrib/nccl/kernels/nccl_manager.h (new file, 122 lines)
@@ -0,0 +1,122 @@
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef THIRD_PARTY_TENSORFLOW_CORE_KERNELS_NCCL_COMMUNICATOR_H_
#define THIRD_PARTY_TENSORFLOW_CORE_KERNELS_NCCL_COMMUNICATOR_H_

#ifdef GOOGLE_CUDA

#include <unordered_map>
#include <vector>

#include "external/nccl_archive/src/nccl.h"
#include "tensorflow/core/common_runtime/gpu/gpu_event_mgr.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/stream_executor.h"

namespace tensorflow {

// The communicator is used to make the asynchronous communicator calls and to
// manage the per-device streams used for communication.
//
// See nccl_ops.cc for example usage, including description of memory
// management and stream synchronization.
class NcclManager {
 public:
  typedef std::function<void(Status)> DoneCallback;
  NcclManager();
  ~NcclManager();

  static NcclManager* instance();

  // Add one participant to an all-reduce, sending in data from <in_t> and
  // receiving the result of the all-reduce in <out_t>. The device for this
  // participant is managed by <executor>, and its events are polled by
  // <event_mgr>.
  //
  // This is an asynchronous call. When <done_callback> is called, <out_t> has
  // been set to the all-reduce result (note: the stream may not yet have been
  // synced).
  //
  // <tensor_stream> is the stream that should be waited on to ensure <in_t>'s
  // data is available on the GPU for the communication stream to access. It
  // is also the stream that will use the produced data; <done_callback> is
  // not called until the next kernel launched on <tensor_stream> would see the
  // data.
  void AddToAllReduce(int num_devices, const string& key,
                      ncclRedOp_t reduction_op,
                      perftools::gputools::StreamExecutor* executor,
                      EventMgr* event_mgr,
                      perftools::gputools::Stream* tensor_stream,
                      const Tensor* in_t, Tensor* out_t,
                      const DoneCallback& done_callback);

  // AddBroadcastSend and AddBroadcastRecv combine to send data from one sender
  // to all receivers.
  void AddBroadcastSend(int num_devices, const string& key,
                        perftools::gputools::StreamExecutor* executor,
                        EventMgr* event_mgr,
                        perftools::gputools::Stream* tensor_stream,
                        const Tensor* in_t, DoneCallback done_callback);
  void AddBroadcastRecv(int num_devices, const string& key,
                        perftools::gputools::StreamExecutor* executor,
                        EventMgr* event_mgr,
                        perftools::gputools::Stream* tensor_stream,
                        Tensor* out_t, DoneCallback done_callback);

 private:
  enum CollectiveType {
    kAllReduce = 1,
    kBroadcast = 2,
  };
  struct Collective;
  struct Communicator;
  struct CommunicatorMember;
  struct NcclStream;
  struct Participant;

  Communicator* GetCommunicator(Collective* collective);

  void AddParticipant(int num_devices, const string& key,
                      std::unique_ptr<Participant> participant,
                      DataType data_type, CollectiveType collective_type,
                      ncclRedOp_t reduction_op);

  // Run <collective>. This call takes ownership of <collective>.
  void RunCollective(const string& key, Collective* collective);
  void LoopKernelLaunches(NcclStream* stream);

  mutex mu_;

  // Maps key to collectives currently being assembled or run.
  std::unordered_map<string, std::unique_ptr<Collective>> collectives_
      GUARDED_BY(mu_);

  // Maps a device to the communication streams that make up its collective.
  // This is used to share the stream across different communicators that
  // include the same device.
  std::map<perftools::gputools::StreamExecutor*,
           std::vector<std::unique_ptr<NcclStream>>>
      device_to_comm_streams_ GUARDED_BY(mu_);

  std::vector<std::unique_ptr<Communicator>> communicators_;

  TF_DISALLOW_COPY_AND_ASSIGN(NcclManager);
};

}  // namespace tensorflow

#endif  // GOOGLE_CUDA

#endif  // THIRD_PARTY_TENSORFLOW_CORE_KERNELS_NCCL_COMMUNICATOR_H_
tensorflow/contrib/nccl/kernels/nccl_manager_test.cc (new file, 285 lines)
@@ -0,0 +1,285 @@
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifdef GOOGLE_CUDA

#include <algorithm>
#include <vector>

#include "tensorflow/contrib/nccl/kernels/nccl_manager.h"
#include "tensorflow/core/common_runtime/device_factory.h"
#include "tensorflow/core/common_runtime/gpu/gpu_device.h"
#include "tensorflow/core/framework/tensor_testutil.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/test.h"

namespace tensorflow {

static std::vector<BaseGPUDevice*> GetGPUDevices() {
  std::vector<Device*> devices;
  SessionOptions session_options;
  session_options.env = Env::Default();
  Status s = DeviceFactory::GetFactory(DEVICE_GPU)
                 ->AddDevices(session_options, "", &devices);
  TF_CHECK_OK(s);
  std::vector<BaseGPUDevice*> gpus;
  for (Device* d : devices) {
    if (d->device_type() == "GPU") {
      gpus.push_back(static_cast<BaseGPUDevice*>(d));
    } else {
      delete d;
    }
  }
  return gpus;
}

class NcclManagerTest : public ::testing::Test {
 protected:
  static void SetUpTestCase() {
    setenv("NCCL_DEBUG", "INFO", 1 /* replace */);
    devices = new std::vector<BaseGPUDevice*>(GetGPUDevices());
    CHECK(!devices->empty());
    LOG(ERROR) << "Running test with " << devices->size() << " gpus";
  }
  static void TearDownTestCase() {
    for (auto device : *devices) delete device;
    delete devices;
  }

  static Allocator* gpu_allocator(BaseGPUDevice* device) {
    return device->GetStepAllocator(AllocatorAttributes(),
                                    nullptr /* step_resource_manager */);
  }

  static std::vector<BaseGPUDevice*>* devices;

  template <typename Scalar>
  perftools::gputools::DeviceMemory<Scalar> AsDeviceMemory(
      const Scalar* cuda_memory) {
    perftools::gputools::DeviceMemoryBase wrapped(
        const_cast<Scalar*>(cuda_memory));
    perftools::gputools::DeviceMemory<Scalar> typed(wrapped);
    return typed;
  }

  // A single all-reduce to apply.
  struct TestCase {
    string key;
    std::vector<Tensor> ins;
    std::vector<Tensor> outs;
    Tensor expected;

    mutex mu;
    Status final_status;
    int num_completed = 0;
  };

  TestCase* MakeTestCase(int num_ranks, ncclRedOp_t reduction_op,
                         TensorShape shape, float value_offset) {
    TestCase* test_case = new TestCase();
    test_case->expected = Tensor(DT_FLOAT, shape);
    if (reduction_op == ncclProd) {
      test::FillFn<float>(&test_case->expected, [](int) { return 1; });
    } else if (reduction_op == ncclSum) {
      test::FillFn<float>(&test_case->expected, [](int) { return 0; });
    } else if (reduction_op == ncclMax) {
      test::FillFn<float>(&test_case->expected, [](int) {
        return -1 * std::numeric_limits<float>::max();
      });
    } else if (reduction_op == ncclMin) {
      test::FillFn<float>(&test_case->expected, [](int) {
        return std::numeric_limits<float>::max();
      });
    } else {
      LOG(FATAL) << "Invalid reduction_op " << reduction_op;
    }

    int mult = 1;
    for (int i = 0; i < num_ranks; ++i) {
      auto* device = devices->at(i % devices->size());
      auto* stream = device->tensorflow_gpu_device_info()->stream;

      Tensor in_cpu(DT_FLOAT, shape);
      test::FillFn<float>(&in_cpu, [mult, value_offset](int index) {
        return value_offset + (index + 1) * mult;
      });
      for (int j = 0; j < shape.num_elements(); ++j) {
        auto in_val = in_cpu.flat<float>()(j);
        auto out_expr = test_case->expected.flat<float>();
        if (reduction_op == ncclProd) {
          out_expr(j) *= in_val;
        } else if (reduction_op == ncclSum) {
          out_expr(j) += in_val;
        } else if (reduction_op == ncclMax) {
          if (in_val > out_expr(j)) {
            out_expr(j) = in_val;
          }
        } else if (reduction_op == ncclMin) {
          if (in_val < out_expr(j)) {
            out_expr(j) = in_val;
          }
        }
      }

      mult *= 10;
      test_case->ins.emplace_back(gpu_allocator(device), DT_FLOAT, shape);
      test_case->outs.emplace_back(gpu_allocator(device), DT_FLOAT, shape);

      const Tensor& in_gpu = test_case->ins.back();
      auto in_gpu_mem = AsDeviceMemory(in_gpu.flat<float>().data());
      stream->ThenMemcpy(&in_gpu_mem, in_cpu.flat<float>().data(),
                         in_cpu.TotalBytes());
    }
    return test_case;
  }

  NcclManager::DoneCallback CreateDoneCallback(TestCase* test_case) {
    return [this, test_case](Status s) {
      mutex_lock l(test_case->mu);
      ++test_case->num_completed;
      test_case->final_status.Update(s);
    };
  }

  void VerifyResults(const string& case_label, TestCase* test_case) {
    // Wait for the done callback to be called.
    {
      test_case->mu.lock();
      while (test_case->num_completed != test_case->outs.size()) {
        test_case->mu.unlock();
        Env::Default()->SleepForMicroseconds(10);
        test_case->mu.lock();
      }
      test_case->mu.unlock();
    }
    // Copy memory to host and verify.
    for (int i = 0; i < test_case->outs.size(); ++i) {
      auto* device = devices->at(i % devices->size());
      auto* stream = device->tensorflow_gpu_device_info()->stream;
      const Tensor& out_gpu = test_case->outs[i];
      Tensor out_cpu(DT_FLOAT, out_gpu.shape());
      auto out_gpu_mem = AsDeviceMemory(out_gpu.flat<float>().data());
      stream->ThenMemcpy(out_cpu.flat<float>().data(), out_gpu_mem,
                         out_cpu.TotalBytes());
      stream->BlockHostUntilDone();
      test::ExpectTensorEqual<float>(test_case->expected, out_cpu);
    }
  }
};
std::vector<BaseGPUDevice*>* NcclManagerTest::devices = nullptr;

// Test basic sum reduction.
TEST_F(NcclManagerTest, BasicSumReduction) {
  const int num_ranks = 3;

  for (int op = 0; op < 4; ++op) {
    ncclRedOp_t reduction_op = static_cast<ncclRedOp_t>(op);
    std::unique_ptr<TestCase> test_case(
        MakeTestCase(num_ranks, reduction_op, TensorShape({2, 3}), 0));
    for (int device_num = 0; device_num < num_ranks; ++device_num) {
      auto* device = devices->at(device_num % devices->size());
      auto* event_mgr = device->tensorflow_gpu_device_info()->event_mgr;
      auto* stream = device->tensorflow_gpu_device_info()->stream;
      NcclManager::instance()->AddToAllReduce(
          num_ranks, "allreduce", reduction_op, device->executor(), event_mgr,
          stream, &test_case->ins[device_num], &test_case->outs[device_num],
          CreateDoneCallback(test_case.get()));
    }

    LOG(ERROR) << "Verifying results";
    VerifyResults("test_case", test_case.get());
  }
}

// Same as the Basic test, but with multiple threads launching parts of many
// reductions.
//
// Testing the multi-rank execution is currently reduced as it can hang when run
// with num_ranks > devices->size(), for some GPUs (e.g. K20m).
// To test the higher settings, increase num_ranks,
// num_collectives_per_iteration and time_limit_micros.
TEST_F(NcclManagerTest, MultipleCallers) {
  const int num_ranks = 1;                      // 2;
  const int num_collectives_per_iteration = 1;  // 1000;
  const int num_threads = 3;
  const int time_limit_micros = 1;  // 60 * 30 * 1000 * 1000;

  int64 start = Env::Default()->NowMicros();
  srand(Env::Default()->NowMicros());

  for (;;) {
    std::vector<std::pair<int, int>> case_and_device_num;
    std::vector<std::unique_ptr<TestCase>> test_cases;
    for (int i = 0; i < num_collectives_per_iteration; ++i) {
      test_cases.emplace_back(
          MakeTestCase(num_ranks, ncclSum,
                       TensorShape({100, i % 5 + 1, i % 3 + 1}), i + 0.1 * i));
      for (int j = 0; j < num_ranks; ++j) {
        case_and_device_num.emplace_back(i, j);
      }
    }

    for (int i = 0; i < num_ranks; ++i) {
      auto* device = devices->at(i % devices->size());
      auto* stream = device->tensorflow_gpu_device_info()->stream;
      stream->BlockHostUntilDone();
    }

    std::random_shuffle(case_and_device_num.begin(), case_and_device_num.end());

    mutex mu;  // guards case_and_device_num.
    std::unique_ptr<thread::ThreadPool> pool(
        new thread::ThreadPool(Env::Default(), "test", num_threads));
    const int to_schedule = case_and_device_num.size();
    for (int i = 0; i < to_schedule; ++i) {
      auto fn = [&]() {
        int device_num;
        int test_num;
        {
          mutex_lock l(mu);
          test_num = case_and_device_num.back().first;
          device_num = case_and_device_num.back().second;
          case_and_device_num.pop_back();
        }
        auto* device = devices->at(device_num % devices->size());
        auto* event_mgr = device->tensorflow_gpu_device_info()->event_mgr;
        auto* stream = device->tensorflow_gpu_device_info()->stream;
        TestCase* test_case = test_cases[test_num].get();
        NcclManager::instance()->AddToAllReduce(
            num_ranks, strings::StrCat("allreduce", test_num), ncclSum,
            device->executor(), event_mgr, stream, &test_case->ins[device_num],
            &test_case->outs[device_num], CreateDoneCallback(test_case));
      };
      pool->Schedule(fn);
    }
    pool.reset();  // wait for all work to be scheduled.

    LOG(ERROR) << "Verifying results for " << num_collectives_per_iteration
               << " collectives";
    for (int i = 0; i < test_cases.size(); ++i) {
      VerifyResults(strings::StrCat("collective", i), test_cases[i].get());
    }

    int64 delta = Env::Default()->NowMicros() - start;
    if (delta > time_limit_micros) {
      LOG(ERROR) << "Ran for " << delta << " quitting";
      break;
    }
  }
}

}  // namespace tensorflow

#endif  // GOOGLE_CUDA
tensorflow/contrib/nccl/kernels/nccl_ops.cc (new file, 157 lines)
@@ -0,0 +1,157 @@
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#if GOOGLE_CUDA

#include <unordered_map>
#include <vector>

#include "external/nccl_archive/src/nccl.h"
#include "tensorflow/contrib/nccl/kernels/nccl_manager.h"
#include "tensorflow/core/framework/op_kernel.h"

namespace tensorflow {

// Base class for all communicator ops that use nccl.
//
// About memory management and stream syncing:
// 1. The nccl communicator has a stream for each rank.
// 2. For input tensors to the communicator, the compute stream is passed to the
//    NcclManager which will do a needed
//    communicator_stream.ThenWaitFor(input_tensor_stream).
// 3. The done_callback of the async kernel is not called by the
//    NcclManager until after the communicator kernel is complete. This
//    is enough to a) keep the input tensor data valid for the lifetime of the
//    collective; and b) ensure the data in the output tensor is available
//    when the async op kernel's done callback is called.
class NcclAsyncOpBase : public AsyncOpKernel {
 public:
  NcclAsyncOpBase(OpKernelConstruction* c) : AsyncOpKernel(c) {
    OP_REQUIRES_OK(c, c->GetAttr("num_devices", &num_devices_));
    OP_REQUIRES_OK(c, c->GetAttr("shared_name", &collective_prefix_));
  }

  string GetCollectiveKey(OpKernelContext* c) {
    return strings::StrCat(collective_prefix_, ";", c->step_id(), ";",
                           c->frame_iter().frame_id, ":",
                           c->frame_iter().iter_id);
  }

  int num_devices() const { return num_devices_; }

 private:
  int num_devices_;
  string collective_prefix_;

  TF_DISALLOW_COPY_AND_ASSIGN(NcclAsyncOpBase);
};

// To execute a single all-reduce, this kernel is called once for each of the
// <k> devices in the communicator.
class NcclAllReduceOpKernel : public NcclAsyncOpBase {
 public:
  NcclAllReduceOpKernel(OpKernelConstruction* c) : NcclAsyncOpBase(c) {
    string reduction;
    OP_REQUIRES_OK(c, c->GetAttr("reduction", &reduction));
    if (reduction == "min") {
      reduction_op_ = ncclMin;
    } else if (reduction == "max") {
      reduction_op_ = ncclMax;
    } else if (reduction == "sum") {
      reduction_op_ = ncclSum;
    } else if (reduction == "prod") {
      reduction_op_ = ncclProd;
    } else {
      OP_REQUIRES_OK(c,
                     errors::InvalidArgument("Invalid reduction: ", reduction));
    }
  }

  void ComputeAsync(OpKernelContext* c, DoneCallback done) override {
    const Tensor* in_t = &c->input(0);
    Tensor* out_t;
    OP_REQUIRES_OK_ASYNC(c, c->allocate_output(0, in_t->shape(), &out_t), done);

    auto actual_done = [c, done](Status s) {
      OP_REQUIRES_OK_ASYNC(c, s, done);
      done();
    };

    auto* compute_stream = c->op_device_context()->stream();
    EventMgr* event_mgr = c->device()->tensorflow_gpu_device_info()->event_mgr;
    NcclManager::instance()->AddToAllReduce(
        num_devices(), GetCollectiveKey(c), reduction_op_,
        compute_stream->parent(), event_mgr, compute_stream, in_t, out_t,
        actual_done);
  }

 private:
  ncclRedOp_t reduction_op_;
};

REGISTER_KERNEL_BUILDER(Name("NcclAllReduce").Device(DEVICE_GPU),
                        NcclAllReduceOpKernel);

class NcclBroadcastSendKernel : public NcclAsyncOpBase {
 public:
  NcclBroadcastSendKernel(OpKernelConstruction* c) : NcclAsyncOpBase(c) {}

  void ComputeAsync(OpKernelContext* c, DoneCallback done) override {
    auto actual_done = [c, done](Status s) {
      OP_REQUIRES_OK_ASYNC(c, s, done);
      done();
    };

    auto* compute_stream = c->op_device_context()->stream();
    EventMgr* event_mgr = c->device()->tensorflow_gpu_device_info()->event_mgr;
    NcclManager::instance()->AddBroadcastSend(
        num_devices(), GetCollectiveKey(c), compute_stream->parent(), event_mgr,
        compute_stream, &c->input(0), std::move(actual_done));
  }
};
REGISTER_KERNEL_BUILDER(Name("NcclBroadcastSend").Device(DEVICE_GPU),
                        NcclBroadcastSendKernel);

class NcclBroadcastRecvKernel : public NcclAsyncOpBase {
 public:
  NcclBroadcastRecvKernel(OpKernelConstruction* c) : NcclAsyncOpBase(c) {}

  void ComputeAsync(OpKernelContext* c, DoneCallback done) override {
    const Tensor& shape_t = c->input(0);
    TensorShape shape;
    OP_REQUIRES_OK_ASYNC(
        c, TensorShapeUtils::MakeShape(shape_t.vec<int64>(), &shape), done);
    Tensor* out_t;
    OP_REQUIRES_OK_ASYNC(c, c->allocate_output(0, shape, &out_t), done);

    auto actual_done = [c, done](Status s) {
      OP_REQUIRES_OK_ASYNC(c, s, done);
      done();
    };

    auto* compute_stream = c->op_device_context()->stream();
    EventMgr* event_mgr = c->device()->tensorflow_gpu_device_info()->event_mgr;
    NcclManager::instance()->AddBroadcastRecv(
        num_devices(), GetCollectiveKey(c), compute_stream->parent(), event_mgr,
        compute_stream, out_t, std::move(actual_done));
  }
};
REGISTER_KERNEL_BUILDER(
    Name("NcclBroadcastRecv").Device(DEVICE_GPU).HostMemory("shape"),
    NcclBroadcastRecvKernel);

}  // namespace tensorflow

#endif  // GOOGLE_CUDA
tensorflow/contrib/nccl/ops/nccl_ops.cc (new file, 94 lines)
@@ -0,0 +1,94 @@
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/framework/common_shape_fns.h"
#include "tensorflow/core/framework/op.h"

namespace tensorflow {

using shape_inference::InferenceContext;
using shape_inference::ShapeHandle;

REGISTER_OP("NcclAllReduce")
    .Input("input: T")
    .Output("data: T")
    .Attr("reduction: {'min', 'max', 'prod', 'sum'}")
    .Attr("T: {float, float64, int32, int64}")
    .Attr("num_devices: int")
    .Attr("shared_name: string")
    .SetIsStateful()
    .SetShapeFn(shape_inference::UnchangedShape)
    .Doc(R"doc(
Outputs a tensor containing the reduction across all input tensors passed to ops
within the same `shared_name`.

The graph should be constructed so that if one op runs with shared_name value `c`,
then `num_devices` ops will run with shared_name value `c`. Failure to do so
will cause the graph execution to fail to complete.

input: the input to the reduction
data: the value of the reduction across all `num_devices` devices.
reduction: the reduction operation to perform.
num_devices: The number of devices participating in this reduction.
shared_name: Identifier that is shared between ops of the same reduction.
)doc");

REGISTER_OP("NcclBroadcastSend")
    .Input("input: T")
    .Attr("T: {float, float64, int32, int64}")
    .Attr("num_devices: int")
    .Attr("shared_name: string")
    .SetIsStateful()
    .SetShapeFn(shape_inference::NoOutputs)
    .Doc(R"doc(
Sends `input` to the NcclBroadcastRecv ops registered in the same `shared_name`.

The graph should be constructed so that one device runs `NcclBroadcastSend` and
`num_devices-1` devices run NcclBroadcastRecv ops with shared_name value `c`.
Failure to do so will cause the graph execution to fail to complete.

input: The input to the broadcast
num_devices: The number of devices participating in this reduction.
shared_name: Identifier that is shared between ops of the same broadcast.
)doc");

REGISTER_OP("NcclBroadcastRecv")
    .Input("shape: int64")
    .Output("output: T")
    .Attr("T: {float, float64, int32, int64}")
    .Attr("num_devices: int")
    .Attr("shared_name: string")
    .SetIsStateful()
    .SetShapeFn([](InferenceContext* c) {
      ShapeHandle out;
      TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(0, &out));
      c->set_output(0, out);
      return Status::OK();
    })
    .Doc(R"doc(
Sends data of shape `shape` from the NcclBroadcastSend op registered in the
same `shared_name`.

The graph should be constructed so that one device runs `NcclBroadcastSend` and
`num_devices-1` devices run NcclBroadcastRecv ops with shared_name value `c`.
Failure to do so will cause the graph execution to fail to complete.

shape: The shape of the output.
output: The broadcast data received from the NcclBroadcastSend op.
num_devices: The number of devices participating in this reduction.
shared_name: Identifier that is shared between ops of the same broadcast.
)doc");

}  // namespace tensorflow
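To make the shared_name/num_devices contract in the op docs above concrete, here is a hedged sketch of how the generated bindings are grouped into one collective; it mirrors _apply_all_reduce in the next file, and the device names and input values are purely illustrative.

import tensorflow as tf
from tensorflow.contrib.nccl.ops import gen_nccl_ops

devices = ['/gpu:0', '/gpu:1']  # assumed device names
outputs = []
for i, d in enumerate(devices):
  with tf.device(d):
    t = tf.constant([1.0 * (i + 1), 2.0, 3.0])
    # All len(devices) ops share one shared_name, so the runtime pairs them
    # into a single all-reduce; launching fewer than num_devices of them
    # stalls graph execution, as the Doc strings above warn.
    outputs.append(
        gen_nccl_ops.nccl_all_reduce(
            t, reduction='sum', num_devices=len(devices), shared_name='c0'))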
tensorflow/contrib/nccl/python/ops/nccl_ops.py (new file, 168 lines)
@@ -0,0 +1,168 @@
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Ops for GPU collective operations implemented using NVIDIA nccl."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import threading

from tensorflow.contrib.nccl.ops import gen_nccl_ops
from tensorflow.contrib.util import loader
from tensorflow.python.framework import device
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.platform import resource_loader

_nccl_ops_so = loader.load_op_library(
    resource_loader.get_path_to_datafile('_nccl_ops.so'))


def all_sum(tensors):
  """Returns a list of tensors with the all-reduce sum across `tensors`.

  The computation is done with an all-reduce operation, so if only some of the
  returned tensors are evaluated then the computation will hang.

  Args:
    tensors: The input tensors across which to sum; must be assigned
      to GPU devices.

  Returns:
    List of tensors, each with the sum of the input tensors, where tensor i has
    the same device as `tensors[i]`.
  """
  return _apply_all_reduce('sum', tensors)


def all_prod(tensors):
  """Returns a list of tensors with the all-reduce product across `tensors`.

  The computation is done with an all-reduce operation, so if only some of the
  returned tensors are evaluated then the computation will hang.

  Args:
    tensors: The input tensors across which to multiply; must be assigned
      to GPU devices.

  Returns:
    List of tensors, each with the product of the input tensors, where tensor i
    has the same device as `tensors[i]`.
  """
  return _apply_all_reduce('prod', tensors)


def all_min(tensors):
  """Returns a list of tensors with the all-reduce min across `tensors`.

  The computation is done with an all-reduce operation, so if only some of the
  returned tensors are evaluated then the computation will hang.

  Args:
    tensors: The input tensors across which to reduce; must be assigned
      to GPU devices.

  Returns:
    List of tensors, each with the minimum of the input tensors, where tensor i
    has the same device as `tensors[i]`.
  """
  return _apply_all_reduce('min', tensors)


def all_max(tensors):
  """Returns a list of tensors with the all-reduce max across `tensors`.

  The computation is done with an all-reduce operation, so if only some of the
  returned tensors are evaluated then the computation will hang.

  Args:
    tensors: The input tensors across which to reduce; must be assigned
      to GPU devices.

  Returns:
    List of tensors, each with the maximum of the input tensors, where tensor i
    has the same device as `tensors[i]`.
  """
  return _apply_all_reduce('max', tensors)


def broadcast(src_tensor, dst_devices):
  """Returns a list of tensors on `dst_devices`, each with value `src_tensor`.

  The computation is done with a broadcast nccl operation, so if only some of
  the returned tensors and src_tensor are evaluated then the computation will
  hang.

  Args:
    src_tensor: The tensor to send; must be assigned to a GPU device.
    dst_devices: The GPU devices to receive the sent tensor.

  Returns:
    List of tensors, each with the value of `src_tensor`, where the device
    of tensor i is `dst_devices[i]`.
  """
  if not dst_devices:
    raise ValueError('Must pass >0 dst_devices to broadcast')
  all_devices = [src_tensor.device] + dst_devices
  shared_name = _get_shared_name()

  with ops.device(src_tensor.device):
    send = gen_nccl_ops.nccl_broadcast_send(
        input=src_tensor, num_devices=len(all_devices), shared_name=shared_name)

  shape_op = array_ops.shape(src_tensor, out_type=dtypes.int64)
  recvs = []
  for d in dst_devices:
    with ops.device(d):
      recvs.append(
          gen_nccl_ops.nccl_broadcast_recv(
              shape=shape_op,
              T=src_tensor.dtype,
              num_devices=len(all_devices),
              shared_name=shared_name))

  return send, recvs


def _apply_all_reduce(reduction_op, tensors):
  if not tensors:
    raise ValueError('Must pass >0 tensors to all reduce operations')
  shared_name = _get_shared_name()
  res = []
  for t in tensors:
    if not device.canonical_name(t.device):
      raise ValueError('Device assignment required for nccl collective ops')
    with ops.device(t.device):
      res.append(
          gen_nccl_ops.nccl_all_reduce(
              t,
              reduction=reduction_op,
              num_devices=len(tensors),
              shared_name=shared_name))
  return res


_lock = threading.Lock()
_shared_name_counter = 0


def _get_shared_name():
  global _shared_name_counter

  with _lock:
    val = _shared_name_counter
    _shared_name_counter += 1
  return 'c%s' % val
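One note on broadcast() above: it returns both the send op and the list of received tensors, and both sides have to be run together. A small hedged sketch (device names are illustrative; a CUDA build with two GPUs is assumed):

import tensorflow as tf
from tensorflow.contrib import nccl

with tf.device('/gpu:0'):  # assumed source device
  src = tf.constant([1.0, 2.0, 3.0])

# broadcast() returns the NcclBroadcastSend op and one received tensor per
# destination device.
send_op, received = nccl.broadcast(src, ['/gpu:1'])

with tf.Session() as sess:
  # Run the send together with every receive; fetching only one side hangs.
  _, values = sess.run([send_op, received])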
151
tensorflow/contrib/nccl/python/ops/nccl_ops_test.py
Normal file
151
tensorflow/contrib/nccl/python/ops/nccl_ops_test.py
Normal file
@ -0,0 +1,151 @@
|
||||
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the 'License');
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an 'AS IS' BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for nccl ops. See also the cc test for nccl_communicator."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np

from tensorflow.contrib import nccl
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.platform import test


class AllReduceTest(test.TestCase):

  def testAllReduce(self):
    if not test.is_gpu_available():
      return  # Test requires access to a GPU

    for dtype in [np.float32, np.int32, np.int64, np.float64]:
      # Create session inside outer loop to test use of
      # same communicator across multiple sessions.
      with self.test_session(use_gpu=True) as sess:
        self._testSingleAllReduce(sess, dtype, nccl.all_sum, lambda x, y: x + y)
        self._testSingleAllReduce(sess, dtype, nccl.all_prod,
                                  lambda x, y: x * y)
        self._testSingleAllReduce(sess, dtype, nccl.all_min, np.minimum)
        self._testSingleAllReduce(sess, dtype, nccl.all_max, np.maximum)

  def _testSingleAllReduce(self, sess, np_type, nccl_fn, numpy_accumulation_fn):
    for devices in [['/gpu:0', '/gpu:0', '/gpu:0'], ['/gpu:0', '/gpu:0']]:
      shape = (3, 4)
      np_ans = None
      tensors = []
      for d in devices:
        with ops.device(d):
          t = ((np.random.random_sample(shape) - .5) * 1024).astype(np_type)
          if np_ans is None:
            np_ans = t
          else:
            np_ans = numpy_accumulation_fn(np_ans, t)
          tensors.append(array_ops.identity(t))

      all_reduce_tensors = nccl_fn(tensors)

      # Test shape inference.
      for r in all_reduce_tensors:
        self.assertEqual(shape, r.get_shape())

      # Test execution and results.
      nccl_results = sess.run(all_reduce_tensors)
      for r in nccl_results:
        self.assertAllClose(r, np_ans)

  def testErrors(self):
    with self.assertRaisesRegexp(ValueError, 'Device assignment required'):
      nccl.all_sum([array_ops.identity(np.random.random_sample((3, 4)))])
    with self.assertRaisesRegexp(ValueError, 'Must pass >0 tensors'):
      nccl.all_sum([])

class BroadcastTest(test.TestCase):

  def testBroadcast(self):
    if not test.is_gpu_available():
      return  # Test requires access to a GPU

    for dtype in [np.float32, np.int32, np.int64, np.float64]:
      # Create session inside outer loop to test use of
      # same communicator across multiple sessions.
      with self.test_session(use_gpu=True) as sess:
        for devices in [['/gpu:0', '/gpu:0', '/gpu:0'], ['/gpu:0', '/gpu:0']]:
          shape = (3, 4)
          sender = np.random.randint(0, len(devices) - 1)
          with ops.device(devices[sender]):
            np_ans = ((
                (np.random.random_sample(shape) - .5) * 1024).astype(dtype))
            t = array_ops.identity(np_ans)
            other_devices = devices[:sender] + devices[sender + 1:]
            send_op, received_tensors = nccl.broadcast(t, other_devices)

            # Verify shape inference.
            for r in received_tensors:
              self.assertEqual(shape, r.get_shape())

            # Run and verify results.
            nccl_results = sess.run(received_tensors + [send_op])
            for r in nccl_results[:-1]:
              self.assertAllClose(r, np_ans)

class CombinedTest(test.TestCase):
  """Tests using a mix of all-reduce ops in one session.run call."""

  def testCombined(self):
    if not test.is_gpu_available():
      return  # Test requires access to a GPU

    for dtype in [np.float32, np.int32, np.int64, np.float64]:
      # Create session inside outer loop to test use of
      # same communicator across multiple sessions.
      with self.test_session(use_gpu=True) as sess:
        for devices in [['/gpu:0', '/gpu:0', '/gpu:0'], ['/gpu:0', '/gpu:0']]:
          shape = (3, 4)

          # all-reduce
          np_ans = np.zeros(shape=shape, dtype=dtype)
          tensors = []
          for d in devices:
            with ops.device(d):
              t = ((np.random.random_sample(shape) - .5) * 1024).astype(dtype)
              np_ans += t
              tensors.append(array_ops.identity(t))
          all_reduce_tensors = nccl.all_sum(tensors)

          sender = np.random.randint(0, len(devices) - 1)
          other_devices = devices[:sender] + devices[sender + 1:]
          send_op, received_tensors = nccl.broadcast(all_reduce_tensors[sender],
                                                     other_devices)

          # sender doesn't need to be fetched as part of outputs of session.run.
          del all_reduce_tensors[sender]

          # Verify shape inference.
          for r in received_tensors:
            self.assertEqual(shape, r.get_shape())

          # Run and verify results.
          nccl_results = sess.run(
              received_tensors + [send_op] + all_reduce_tensors)
          for r in nccl_results[:len(received_tensors)]:
            self.assertAllClose(r, np_ans)


if __name__ == '__main__':
  test.main()
@ -385,6 +385,14 @@ def tf_workspace(path_prefix = "", tf_repo_name = ""):
      actual = "@zlib_archive//:zlib",
  )

  native.new_http_archive(
      name = "nccl_archive",
      url = "https://github.com/NVIDIA/nccl/archive/2a974f5ca2aa12b178046b2206b43f1fd69d9fae.tar.gz",
      sha256 = "d6aa1a3f20ae85358890d9a96f49c51a75baa1d3af3598501f29ff9ef8a3107d",
      strip_prefix = "nccl-2a974f5ca2aa12b178046b2206b43f1fd69d9fae",
      build_file = str(Label("//third_party:nccl.BUILD")),
  )

  # Make junit-4.12 available as //external:junit
  native.http_jar(
      name = "junit_jar",
48
third_party/nccl.BUILD
vendored
Normal file
@ -0,0 +1,48 @@
# NVIDIA nccl
# A package of optimized primitives for collective multi-GPU communication.

licenses(["notice"])  # BSD

exports_files(["LICENSE.txt"])

load("@local_config_cuda//cuda:build_defs.bzl", "cuda_default_copts", "if_cuda")

SRCS = [
    "src/all_gather.cu",
    "src/all_reduce.cu",
    "src/broadcast.cu",
    "src/core.cu",
    "src/libwrap.cu",
    "src/reduce.cu",
    "src/reduce_scatter.cu",
]

# Copy .cu to .cu.cc so they can be in srcs of cc_library.
[
    genrule(
        name = "gen_" + src,
        srcs = [src],
        outs = [src + ".cc"],
        cmd = "cp $(location " + src + ") $(location " + src + ".cc)",
    )
    for src in SRCS
]

SRCS_CU_CC = [src + ".cc" for src in SRCS]

cc_library(
    name = "nccl",
    srcs = if_cuda(SRCS_CU_CC + glob(["src/*.h"])),
    hdrs = if_cuda(["src/nccl.h"]),
    copts = [
        "-DCUDA_MAJOR=0",
        "-DCUDA_MINOR=0",
        "-DNCCL_MAJOR=0",
        "-DNCCL_MINOR=0",
        "-DNCCL_PATCH=0",
        "-Iexternal/nccl_archive/src",
        "-O3",
    ] + cuda_default_copts(),
    visibility = ["//visibility:public"],
    deps = ["@local_config_cuda//cuda:cuda_headers"],
)