[XLA-GPU] Add NCCL implementation of AllGather op.

PiperOrigin-RevId: 346133758
Change-Id: I530eb132890b14cf25be209777b93eee1bddc1a8
Chris Jones 2020-12-07 11:08:07 -08:00 committed by TensorFlower Gardener
parent e9deb8b804
commit 083d01651a
9 changed files with 396 additions and 0 deletions
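For context, a minimal HLO sketch (illustrative, not part of this commit) of the kind of op the new thunk lowers: an all-gather over dense array operands with gather dimension 0, which is what NcclAllGatherThunk::CanImplement accepts. The module name, parameter, and replica count of 4 are assumptions for the example; each replica contributes its u32[2] operand and receives the concatenation of all four along dimension 0.

HloModule all_gather_example
ENTRY example {
  p = u32[2] parameter(0)
  ROOT ag = u32[8] all-gather(p), dimensions={0}
}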

View File

@@ -259,6 +259,7 @@ cc_library(
":hlo_to_ir_bindings",
":ir_emission_utils",
":launch_dimensions",
":nccl_all_gather_thunk",
":nccl_all_reduce_thunk",
":nccl_all_to_all_thunk",
":parallel_loop_emitter",
@@ -474,6 +475,46 @@ tf_cuda_library(
]),
)
# First level of nested select. NCCL requires both if_cuda and if_nccl.
filegroup(
name = "nccl_all_gather_thunk_src",
srcs = if_nccl(
["nccl_all_gather_thunk.cc"],
["dummy_all_gather_thunk.cc"],
),
)
tf_cuda_library(
name = "nccl_all_gather_thunk",
srcs = if_cuda_or_rocm(
[":nccl_all_gather_thunk_src"],
["dummy_all_gather_thunk.cc"],
),
hdrs = ["nccl_all_gather_thunk.h"],
deps = [
":buffer_allocations",
":gpu_executable_run_options",
":hlo_execution_profiler",
":nccl_collective_thunk",
":thunk",
"@com_google_absl//absl/base:core_headers",
"@com_google_absl//absl/strings:str_format",
"//tensorflow/compiler/xla/service:buffer_assignment",
"//tensorflow/compiler/xla/service:collective_ops_utils",
"//tensorflow/compiler/xla/service:hlo",
"//tensorflow/compiler/xla/service:hlo_casting_utils",
"//tensorflow/compiler/xla/service:pattern_matcher",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:util",
"//tensorflow/compiler/xla:xla_data_proto_cc",
"//tensorflow/core:lib",
] + if_nccl([
":virtual_nccl",
":virtual_nccl_utils",
":virtual_rccl",
]),
)
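# Note (illustrative summary of the nested selects above):
#   CUDA/ROCm build with NCCL/RCCL -> nccl_all_gather_thunk.cc (real thunk)
#   CUDA/ROCm build without NCCL   -> dummy_all_gather_thunk.cc (inner filegroup)
#   build without CUDA or ROCm     -> dummy_all_gather_thunk.cc (outer if_cuda_or_rocm)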
# First level of nested select. NCCL requires both if_cuda and if_nccl.
filegroup(
name = "nccl_all_reduce_thunk_src",

View File

@@ -0,0 +1,51 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/xla/service/gpu/nccl_all_gather_thunk.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
namespace xla {
namespace gpu {
NcclAllGatherConfig GetNcclAllGatherConfig(const HloInstruction* hlo,
int64 replica_count) {
return NcclAllGatherConfig();
}
NcclAllGatherThunk::NcclAllGatherThunk(
ThunkInfo thunk_info, NcclAllGatherConfig config,
std::vector<NcclAllGatherThunk::Buffer> buffers)
: NcclCollectiveThunk(Thunk::kNcclAllGather, thunk_info),
config_(std::move(config)),
buffers_(std::move(buffers)) {}
/* static */ bool NcclAllGatherThunk::CanImplement(const HloInstruction* hlo) {
return false;
}
Status NcclAllGatherThunk::RunNcclCollective(const ExecuteParams&, ncclComm_t) {
return Unimplemented(
"NCCL support is not available: this binary was not built with a CUDA "
"compiler, which is necessary to build the NCCL source library.");
}
const NcclCollectiveConfig& NcclAllGatherThunk::config() const {
// This function will never be called.
const NcclCollectiveConfig* config = nullptr;
return *config;
}
} // namespace gpu
} // namespace xla

View File

@@ -71,6 +71,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/gpu/kernel_thunk.h"
#include "tensorflow/compiler/xla/service/gpu/launch_dimensions.h"
#include "tensorflow/compiler/xla/service/gpu/memset_thunk.h"
#include "tensorflow/compiler/xla/service/gpu/nccl_all_gather_thunk.h"
#include "tensorflow/compiler/xla/service/gpu/nccl_all_reduce_thunk.h"
#include "tensorflow/compiler/xla/service/gpu/nccl_all_to_all_thunk.h"
#include "tensorflow/compiler/xla/service/gpu/parallel_loop_emitter.h"
@@ -2435,6 +2436,103 @@ Status IrEmitterUnnested::HandleCollectivePermute(HloInstruction* hlo) {
return Status::OK();
}
Status IrEmitterUnnested::HandleAllGather(HloInstruction* hlo) {
VLOG(2) << "AllGather; replica count: " << hlo_module_config_.replica_count()
<< "; operand count: " << hlo->operand_count()
<< "; NCCL is enabled: " << NcclAllGatherThunk::NcclIsEnabled();
// Note the replica_count == 1 case is handled via device-to-device copy
// below.
bool should_use_nccl_thunk = hlo_module_config_.replica_count() > 1 &&
NcclAllGatherThunk::CanImplement(hlo);
if (should_use_nccl_thunk) {
std::vector<NcclAllGatherThunk::Buffer> buffers;
std::vector<BufferAllocation::Slice> tuple_element_buffers;
buffers.resize(hlo->operand_count());
tuple_element_buffers.reserve(hlo->operand_count());
CHECK(hlo->shape().IsArray() && hlo->operand_count() == 1 ||
hlo->shape().IsTuple() &&
hlo->shape().tuple_shapes_size() == hlo->operand_count());
for (int i = 0; i < hlo->operand_count(); ++i) {
CHECK(hlo->operand(i)->shape().IsArray())
<< "Operands to all-gather must be arrays: " << hlo->ToString();
buffers[i].element_count =
ShapeUtil::ElementsIn(hlo->operand(i)->shape());
buffers[i].source_buffer = GetAllocationSlice(*hlo->operand(i));
buffers[i].destination_buffer = GetAllocationSlice(
*hlo, hlo->shape().IsTuple() ? ShapeIndex({i}) : ShapeIndex({}));
tuple_element_buffers.push_back(buffers[i].destination_buffer);
}
NcclAllGatherConfig config =
GetNcclAllGatherConfig(hlo, hlo_module_config_.replica_count());
auto all_gather_thunk = absl::make_unique<NcclAllGatherThunk>(
GetThunkInfo(hlo), std::move(config),
/*buffers=*/std::move(buffers));
if (hlo->shape().IsTuple()) {
std::vector<std::unique_ptr<Thunk>> thunks;
thunks.push_back(std::move(all_gather_thunk));
thunks.push_back(absl::make_unique<TupleThunk>(
Thunk::ThunkInfo(), tuple_element_buffers, GetAllocationSlice(*hlo)));
AddThunkToThunkSequence(absl::make_unique<SequentialThunk>(
GetThunkInfo(hlo), std::move(thunks)));
} else {
AddThunkToThunkSequence(std::move(all_gather_thunk));
}
return Status::OK();
}
if (hlo_module_config_.replica_count() != 1) {
string message = absl::StrFormat(
"Requested AllGather not implemented on GPU; replica_count: %d; "
"operand_count: %d; NCCL support: %d",
hlo_module_config_.replica_count(), hlo->operand_count(),
NcclAllGatherThunk::NcclIsEnabled());
if (hlo->operand_count() > 0) {
absl::StrAppendFormat(
&message, "; first operand array element-type: %s",
PrimitiveType_Name(hlo->operand(0)->shape().element_type()));
}
return Unimplemented("%s", message);
}
// All-gather with one operand and one replica is simply the identity
// function. Buffer assignment expects a copy, so that's what we do.
if (hlo->operand_count() == 1) {
CHECK(hlo->operand(0)->shape().IsArray())
<< "Operands to all-gather must be arrays: " << hlo->ToString();
AddThunkToThunkSequence(absl::make_unique<DeviceToDeviceCopyThunk>(
GetThunkInfo(hlo),
/*source_address=*/GetAllocationSlice(*hlo->operand(0)),
/*destination_buffer=*/GetAllocationSlice(*hlo),
/*mem_size=*/ShapeUtil::ByteSizeOf(hlo->shape())));
return Status::OK();
}
// One-replica all-gather with multiple operands produces a tuple of the
// inputs. Again, buffer assignment expects us to copy each.
std::vector<std::unique_ptr<Thunk>> thunks;
std::vector<BufferAllocation::Slice> tuple_element_buffers;
for (int64 i = 0; i < hlo->operand_count(); ++i) {
tuple_element_buffers.push_back(ir_emitter_context_->buffer_assignment()
.GetUniqueSlice(hlo, {i})
.ValueOrDie());
thunks.push_back(absl::make_unique<DeviceToDeviceCopyThunk>(
Thunk::ThunkInfo(),
/*source_address=*/GetAllocationSlice(*hlo->operand(i)),
/*destination_buffer=*/tuple_element_buffers.back(),
/*mem_size=*/ShapeUtil::ByteSizeOf(hlo->operand(i)->shape())));
}
// Output a tuple of the buffers above.
thunks.push_back(absl::make_unique<TupleThunk>(
Thunk::ThunkInfo(), tuple_element_buffers, GetAllocationSlice(*hlo)));
AddThunkToThunkSequence(
absl::make_unique<SequentialThunk>(GetThunkInfo(hlo), std::move(thunks)));
return Status::OK();
}
Status IrEmitterUnnested::HandleAllReduce(HloInstruction* crs) {
VLOG(2) << "AllReduce; replica count: " << hlo_module_config_.replica_count()
<< "; operand count: " << crs->operand_count()

View File

@@ -192,6 +192,7 @@ class IrEmitterUnnested : public IrEmitter,
Status HandleSort(HloInstruction* sort) override;
Status EmitSortFromMlir(MlirEmitterInput mlir_input);
Status HandleTriangularSolve(HloInstruction* hlo) override;
Status HandleAllGather(HloInstruction* hlo) override;
Status HandleAllReduce(HloInstruction* crs) override;
Status HandleAllToAll(HloInstruction* hlo) override;
Status HandleAfterAll(HloInstruction* after_all) override;

View File

@@ -0,0 +1,109 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/xla/service/gpu/nccl_all_gather_thunk.h"
#include <chrono> // NOLINT (required by TF interfaces)
#include <cstdlib>
#include <memory>
#include <string>
#include <utility>
#include <vector>
#include "absl/strings/str_format.h"
#if GOOGLE_CUDA
#include "third_party/nccl/nccl.h"
#elif TENSORFLOW_USE_ROCM
#include "rocm/include/rccl/rccl.h"
#endif
#include "tensorflow/compiler/xla/layout_util.h"
#include "tensorflow/compiler/xla/service/gpu/nccl_utils.h"
#include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
#include "tensorflow/compiler/xla/service/hlo_instructions.h"
#include "tensorflow/compiler/xla/util.h"
namespace xla {
namespace gpu {
NcclAllGatherConfig GetNcclAllGatherConfig(const HloInstruction* hlo,
int64 replica_count) {
NcclAllGatherConfig config;
config.config = GetNcclCollectiveConfig(hlo, replica_count);
return config;
}
/*static*/ bool NcclAllGatherThunk::CanImplement(const HloInstruction* hlo) {
auto operands_are_supported = [hlo]() {
return absl::c_all_of(hlo->operands(), [](HloInstruction* operand) {
return LayoutUtil::IsDenseArray(operand->shape()) &&
ToNcclDataType(operand->shape().element_type()).ok();
});
};
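// Only gather dimension 0 is supported; for XLA's default descending layouts
// this makes each replica's contribution a single contiguous chunk, matching
// the concatenation that ncclAllGather performs.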
return (Cast<HloAllGatherInstruction>(hlo)->all_gather_dimension() == 0) &&
operands_are_supported();
}
NcclAllGatherThunk::NcclAllGatherThunk(
ThunkInfo thunk_info, NcclAllGatherConfig config,
std::vector<NcclAllGatherThunk::Buffer> buffers)
: NcclCollectiveThunk(Thunk::kNcclAllGather, thunk_info),
config_(std::move(config)),
buffers_(std::move(buffers)) {
CHECK_EQ(config_.config.operand_count, buffers_.size());
}
Status NcclAllGatherThunk::RunNcclCollective(const ExecuteParams& params,
ncclComm_t comm) {
int device_ordinal = params.stream->parent()->device_ordinal();
VLOG(3) << "Performing all-gather from device ordinal: " << device_ordinal;
cudaStream_t* cu_stream = reinterpret_cast<cudaStream_t*>(
params.stream->implementation()->GpuStreamMemberHack());
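// Issue one ncclAllGather per buffer inside a single NCCL group so the calls
// are submitted to the communicator together rather than one at a time.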
XLA_CUDA_RETURN_IF_ERROR(ncclGroupStart());
for (size_t i = 0; i < buffers_.size(); ++i) {
const Buffer& buffer = buffers_[i];
const void* send_buffer =
params.buffer_allocations->GetDeviceAddress(buffer.source_buffer)
.opaque();
void* recv_buffer =
params.buffer_allocations->GetDeviceAddress(buffer.destination_buffer)
.opaque();
TF_ASSIGN_OR_RETURN(ncclDataType_t datatype,
ToNcclDataType(config_.config.operand_element_type[i]));
VLOG(3) << absl::StreamFormat(
"Calling ncclAllGather(send_buffer=%p, recv_buffer=%p, count=%d, "
"comm=%p, stream=%p)",
send_buffer, recv_buffer, buffer.element_count,
static_cast<const void*>(comm), cu_stream);
XLA_CUDA_RETURN_IF_ERROR(ncclAllGather(send_buffer, recv_buffer,
buffer.element_count, datatype, comm,
*cu_stream));
}
XLA_CUDA_RETURN_IF_ERROR(ncclGroupEnd());
VLOG(3) << "Done performing all-gather for ordinal: " << device_ordinal;
return Status::OK();
}
const NcclCollectiveConfig& NcclAllGatherThunk::config() const {
return config_.config;
}
} // namespace gpu
} // namespace xla

View File

@@ -0,0 +1,66 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_NCCL_ALL_GATHER_THUNK_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_NCCL_ALL_GATHER_THUNK_H_
#include "tensorflow/compiler/xla/service/collective_ops_utils.h"
#include "tensorflow/compiler/xla/service/gpu/buffer_allocations.h"
#include "tensorflow/compiler/xla/service/gpu/nccl_collective_thunk.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/platform/types.h"
namespace xla {
namespace gpu {
struct NcclAllGatherConfig {
NcclCollectiveConfig config;
};
NcclAllGatherConfig GetNcclAllGatherConfig(const HloInstruction* hlo,
int64 replica_count);
// Thunk that performs a NCCL-based All-Gather among CUDA GPU-based replicas.
class NcclAllGatherThunk : public NcclCollectiveThunk {
public:
struct Buffer {
int64 element_count;
BufferAllocation::Slice source_buffer;
BufferAllocation::Slice destination_buffer;
};
NcclAllGatherThunk(ThunkInfo thunk_info, NcclAllGatherConfig config,
std::vector<Buffer> buffers);
// Returns whether the given instruction can be lowered to a nccl all-gather
// call.
static bool CanImplement(const HloInstruction* hlo);
protected:
Status RunNcclCollective(const ExecuteParams& params,
ncclComm_t comm) override;
const NcclCollectiveConfig& config() const override;
private:
const NcclAllGatherConfig config_;
const std::vector<Buffer> buffers_;
};
} // namespace gpu
} // namespace xla
#endif // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_NCCL_ALL_GATHER_THUNK_H_

View File

@@ -50,6 +50,8 @@ absl::string_view ThunkKindToString(Thunk::Kind kind) {
return "kCudnnBatchNormForwardTraining";
case Thunk::kCustomCall:
return "kCustomCall";
case Thunk::kNcclAllGather:
return "kNcclAllGather";
case Thunk::kNcclAllReduce:
return "kNcclAllReduce";
case Thunk::kNcclAllToAll:

View File

@@ -59,6 +59,7 @@ class Thunk {
kKernel,
kMemset32BitValue,
kMemzero,
kNcclAllGather,
kNcclAllReduce,
kNcclAllToAll,
kOutfeed,

View File

@@ -738,6 +738,33 @@ XLA_TEST_F(CollectiveOpsTest, DISABLED_ON_CPU(AllToAll_SplitDimension)) {
results[3]);
}
XLA_TEST_F(CollectiveOpsTest, DISABLED_ON_CPU(AllGather)) {
const char* const kModuleStr = R"(
HloModule test
ENTRY test_computation {
id = u32[] replica-id()
id2 = u32[1, 2] broadcast(id), dimensions={}
a0 = u32[1, 2] constant({{10, 15}})
a1 = u32[1, 2] add(id2, a0)
allgather = u32[4, 2] all-gather(a1), dimensions={0}
ROOT out = u32[8] reshape(allgather)
}
)";
const int64 kNumReplicas = 4;
auto config = GetModuleConfigForTest(kNumReplicas);
TF_ASSERT_OK_AND_ASSIGN(auto module,
ParseAndReturnVerifiedModule(kModuleStr, config));
TF_ASSERT_OK_AND_ASSIGN(std::vector<Literal> results,
ExecuteReplicated(std::move(module), {}, kNumReplicas,
/*use_threads=*/true));
ASSERT_EQ(results.size(), kNumReplicas);
for (const Literal& result : results) {
LiteralTestUtil::ExpectR1Equal<uint32>({10, 15, 11, 16, 12, 17, 13, 18},
result);
}
}
XLA_TEST_F(CollectiveOpsTest, AllReduce_TupleAllReduce) {
std::string hlo_string = R"(
HloModule test