Update to NCCL version 1.3.5. Remove the temporary buffer for ncclReduce; it is no longer needed in this version.
PiperOrigin-RevId: 169221983
parent 23da21150d
commit 5882ae35d3
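Background for the change: in NCCL 1.3.5, ncclReduce writes through recvbuff only on the root rank, so non-root participants can pass a null receive buffer instead of allocating a same-sized scratch tensor (the workaround the diff below deletes). A minimal per-rank sketch of that calling convention, with communicator, stream, and buffer setup omitted and illustrative names:

#include <cuda_runtime.h>
#include <nccl.h>

// One rank's contribution to a reduce. Only the root needs a real receive
// buffer; every other rank may pass nullptr under NCCL 1.3.5.
ncclResult_t LaunchReduce(const float* d_send, float* d_recv, int count,
                          int rank, int root, ncclComm_t comm,
                          cudaStream_t stream) {
  void* recvbuff = (rank == root) ? static_cast<void*>(d_recv) : nullptr;
  return ncclReduce(d_send, recvbuff, count, ncclFloat, ncclSum, root, comm,
                    stream);
}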
@@ -312,11 +312,11 @@ void NcclManager::AddReduceSend(int num_devices, const string& key,
                                 perftools::gputools::StreamExecutor* executor,
                                 int gpu_device_id, EventMgr* event_mgr,
                                 perftools::gputools::Stream* tensor_stream,
-                                const Tensor* in_t, Tensor* temp_t,
+                                const Tensor* in_t,
                                 DoneCallback done_callback) {
   std::unique_ptr<Participant> participant(
-      new Participant(in_t, temp_t, event_mgr, tensor_stream, executor,
-                      gpu_device_id, std::move(done_callback)));
+      new Participant(in_t, nullptr /* out_t */, event_mgr, tensor_stream,
+                      executor, gpu_device_id, std::move(done_callback)));
   AddParticipant(num_devices, key, std::move(participant), in_t->dtype(),
                  kReduce, reduction_op);
 }
@@ -462,7 +462,9 @@ void NcclManager::LoopKernelLaunches(NcclStream* nccl_stream) {
     }
     case kReduce: {
       const void* sendbuff = p->in_t->tensor_data().data();
-      void* recvbuff = const_cast<char*>(p->out_t->tensor_data().data());
+      void* recvbuff = p->out_t
+                           ? const_cast<char*>(p->out_t->tensor_data().data())
+                           : nullptr;
       nccl_result = ncclReduce(sendbuff, recvbuff, p->in_t->NumElements(),
                                data_type, collective->reduction_op,
                                collective->root_rank, nccl_comm, *cu_stream);
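The kReduce branch above issues one ncclReduce per participating GPU from the manager's launch loop. For reference, a self-contained single-process sketch of the same calling pattern against plain NCCL 1.x (one communicator per device via ncclCommInitAll; error handling and memory setup trimmed):

#include <vector>

#include <cuda_runtime.h>
#include <nccl.h>

// Reduces ndev per-device buffers into d_recv_root on device `root`.
// Every device makes the call; only the root passes a non-null recvbuff,
// mirroring the kReduce launch above.
void ReduceToRoot(float** d_send, float* d_recv_root, int count, int ndev,
                  int root) {
  std::vector<int> devs(ndev);
  for (int i = 0; i < ndev; ++i) devs[i] = i;
  std::vector<ncclComm_t> comms(ndev);
  ncclCommInitAll(comms.data(), ndev, devs.data());
  for (int i = 0; i < ndev; ++i) {
    cudaSetDevice(devs[i]);
    ncclReduce(d_send[i], i == root ? d_recv_root : nullptr, count, ncclFloat,
               ncclSum, root, comms[i], /*stream=*/0);
  }
  for (int i = 0; i < ndev; ++i) ncclCommDestroy(comms[i]);
}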
@@ -82,8 +82,7 @@ class NcclManager {
                      perftools::gputools::StreamExecutor* executor,
                      int gpu_device_id, EventMgr* event_mgr,
                      perftools::gputools::Stream* tensor_stream,
-                     const Tensor* in_t, Tensor* temp_t,
-                     DoneCallback done_callback);
+                     const Tensor* in_t, DoneCallback done_callback);
   void AddReduceRecv(int num_devices, const string& key,
                      ncclRedOp_t reduction_op,
                      perftools::gputools::StreamExecutor* executor,
@@ -121,14 +121,7 @@ class NcclReduceSendKernel : public NcclReduceOpBase {
       : NcclReduceOpBase(c) {}
 
   void ComputeAsync(OpKernelContext* c, DoneCallback done) override {
-    const Tensor& in_t = c->input(0);
-    std::unique_ptr<Tensor> temp_ptr(new Tensor());
-    OP_REQUIRES_OK_ASYNC(
-        c, c->allocate_temp(in_t.dtype(), in_t.shape(), temp_ptr.get()), done);
-    Tensor* temp_t = temp_ptr.release();
-
-    auto actual_done = [c, done, temp_t](Status s) {
-      delete temp_t;
+    auto actual_done = [c, done](Status s) {
       OP_REQUIRES_OK_ASYNC(c, s, done);
       done();
     };
@@ -138,7 +131,7 @@ class NcclReduceSendKernel : public NcclReduceOpBase {
     NcclManager::instance()->AddReduceSend(
         num_devices(), GetCollectiveKey(c), reduction_op(),
         compute_stream->parent(), gpu_info->gpu_id, gpu_info->event_mgr,
-        compute_stream, &in_t, temp_t, std::move(actual_done));
+        compute_stream, &c->input(0), std::move(actual_done));
   }
 };
 REGISTER_KERNEL_BUILDER(Name("NcclReduceSend").Device(DEVICE_GPU),
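The two kernel hunks above delete a scratch-tensor lifetime pattern worth spelling out once: because the NCCL launch outlives ComputeAsync, the old code heap-allocated the temp Tensor, passed ownership into the completion lambda, and deleted it there. A sketch of that (now unnecessary) pattern, with an illustrative class name:

#include <memory>

#include "tensorflow/core/framework/op_kernel.h"

using namespace tensorflow;  // for brevity in this sketch only

// Illustrative AsyncOpKernel showing how the old NcclReduceSendKernel kept
// its scratch tensor alive until the asynchronous work completed.
class ScratchExampleKernel : public AsyncOpKernel {
 public:
  explicit ScratchExampleKernel(OpKernelConstruction* c) : AsyncOpKernel(c) {}

  void ComputeAsync(OpKernelContext* c, DoneCallback done) override {
    const Tensor& in_t = c->input(0);
    std::unique_ptr<Tensor> temp_ptr(new Tensor());
    OP_REQUIRES_OK_ASYNC(
        c, c->allocate_temp(in_t.dtype(), in_t.shape(), temp_ptr.get()), done);
    Tensor* temp_t = temp_ptr.release();  // now owned by the lambda below

    auto actual_done = [c, done, temp_t](Status s) {
      delete temp_t;  // freed only once the async work has finished
      OP_REQUIRES_OK_ASYNC(c, s, done);
      done();
    };
    // A real kernel would enqueue asynchronous work here, handing it temp_t
    // and actual_done; this sketch just completes immediately.
    actual_done(Status::OK());
  }
};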
@@ -629,11 +629,11 @@ def tf_workspace(path_prefix="", tf_repo_name=""):
   temp_workaround_http_archive(
       name = "nccl_archive",
       urls = [
-          "http://mirror.bazel.build/github.com/nvidia/nccl/archive/29a1a916dc14bb2c00feed3d4820d51fa85be1e6.tar.gz",
-          "https://github.com/nvidia/nccl/archive/29a1a916dc14bb2c00feed3d4820d51fa85be1e6.tar.gz",
+          "http://mirror.bazel.build/github.com/nvidia/nccl/archive/03d856977ecbaac87e598c0c4bafca96761b9ac7.tar.gz",
+          "https://github.com/nvidia/nccl/archive/03d856977ecbaac87e598c0c4bafca96761b9ac7.tar.gz",
       ],
-      sha256 = "6387030e37d14762f87eefbc86ee527293ec04745c66ccd820cf7fc0fdc23f92",
-      strip_prefix = "nccl-29a1a916dc14bb2c00feed3d4820d51fa85be1e6",
+      sha256 = "2ca86fb6179ecbff789cc67c836139c1bbc0324ed8c04643405a30bf26325176",
+      strip_prefix = "nccl-03d856977ecbaac87e598c0c4bafca96761b9ac7",
       build_file = str(Label("//third_party:nccl.BUILD")),
       repository = tf_repo_name,
   )