[XLA-GPU] NFC: Split out common base class for NCCL collective thunks.

PiperOrigin-RevId: 345717933
Change-Id: I56cf6ce929413aeacd20207408d75609ed6fd16f
Author: Chris Jones (committed by TensorFlower Gardener)
Date: 2020-12-04 11:19:31 -08:00
Parent: 46cf2ef65b
Commit: e6fe105d22
7 changed files with 389 additions and 197 deletions

File: tensorflow/compiler/xla/service/gpu/BUILD

@@ -418,15 +418,6 @@ cc_library(
],
)
# First level of nested select. NCCL requires both if_cuda and if_nccl.
filegroup(
name = "nccl_all_reduce_thunk_src",
srcs = if_nccl(
["nccl_all_reduce_thunk.cc"],
["dummy_all_reduce_thunk.cc"],
),
)
# use alias since nested select statements not possible
cc_library(
name = "empty",
@@ -442,6 +433,55 @@ alias(
actual = if_rocm("@local_config_rocm//rocm:rccl", ":empty"),
)
# First level of nested select. NCCL requires both if_cuda and if_nccl.
filegroup(
name = "nccl_collective_thunk_src",
srcs = if_nccl(
["nccl_collective_thunk.cc"],
["dummy_collective_thunk.cc"],
),
)
tf_cuda_library(
name = "nccl_collective_thunk",
srcs = if_cuda_or_rocm(
[":nccl_collective_thunk_src"],
["dummy_collective_thunk.cc"],
),
hdrs = ["nccl_collective_thunk.h"],
deps = [
":thunk",
"@com_google_absl//absl/container:flat_hash_set",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/strings:str_format",
"//tensorflow/compiler/xla/service:collective_ops_utils",
"//tensorflow/compiler/xla/service:global_device_id",
"//tensorflow/compiler/xla/service:hlo",
"//tensorflow/compiler/xla:util",
"//tensorflow/compiler/xla:xla_data_proto_cc",
"//tensorflow/core:lib",
] + if_cuda([
"//tensorflow/stream_executor/cuda:cuda_activation",
"//tensorflow/stream_executor/cuda:cuda_gpu_executor",
]) + if_rocm([
"//tensorflow/stream_executor/rocm:rocm_activation",
"//tensorflow/stream_executor/rocm:rocm_gpu_executor",
]) + if_nccl([
":virtual_nccl",
":virtual_nccl_utils",
":virtual_rccl",
]),
)
# First level of nested select. NCCL requires both if_cuda and if_nccl.
filegroup(
name = "nccl_all_reduce_thunk_src",
srcs = if_nccl(
["nccl_all_reduce_thunk.cc"],
["dummy_all_reduce_thunk.cc"],
),
)
tf_cuda_library(
name = "nccl_all_reduce_thunk",
srcs = if_cuda_or_rocm(
@@ -451,28 +491,22 @@ tf_cuda_library(
hdrs = ["nccl_all_reduce_thunk.h"],
deps = [
":buffer_allocations",
":hlo_execution_profiler",
":thunk",
":gpu_executable_run_options",
":hlo_execution_profiler",
":nccl_collective_thunk",
":thunk",
"@com_google_absl//absl/base:core_headers",
"@com_google_absl//absl/strings:str_format",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:util",
"//tensorflow/compiler/xla:xla_data_proto_cc",
"//tensorflow/compiler/xla/service:buffer_assignment",
"//tensorflow/compiler/xla/service:collective_ops_utils",
"//tensorflow/compiler/xla/service:hlo",
"//tensorflow/compiler/xla/service:hlo_casting_utils",
"//tensorflow/compiler/xla/service:pattern_matcher",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:util",
"//tensorflow/compiler/xla:xla_data_proto_cc",
"//tensorflow/core:lib",
"//tensorflow/core/platform:stream_executor_no_cuda",
] + if_cuda([
"//tensorflow/stream_executor/cuda:cuda_activation",
"//tensorflow/stream_executor/cuda:cuda_gpu_executor",
]) + if_rocm([
"//tensorflow/stream_executor/rocm:rocm_activation",
"//tensorflow/stream_executor/rocm:rocm_gpu_executor",
]) + if_nccl([
] + if_nccl([
":virtual_nccl",
":virtual_nccl_utils",
":virtual_rccl",

File: tensorflow/compiler/xla/service/gpu/dummy_all_reduce_thunk.cc

@@ -19,37 +19,33 @@ limitations under the License.
namespace xla {
namespace gpu {
struct NcclAllReduceConfig::AuxData {};
NcclAllReduceConfig::NcclAllReduceConfig(NcclAllReduceConfig &&) = default;
NcclAllReduceConfig::~NcclAllReduceConfig() = default;
NcclAllReduceConfig GetNcclAllReduceConfig(const HloInstruction *instr,
NcclAllReduceConfig GetNcclAllReduceConfig(const HloInstruction* hlo,
int64 replica_count) {
NcclAllReduceConfig config = {};
return config;
return NcclAllReduceConfig();
}
/* static */ bool NcclAllReduceThunk::NcclIsEnabled() {
return false; // Skylark selects this source file if NCCL is disabled.
}
NcclAllReduceThunk::NcclAllReduceThunk(
ThunkInfo thunk_info, NcclAllReduceConfig config,
std::vector<NcclAllReduceThunk::Buffer> buffers)
: NcclCollectiveThunk(Thunk::kNcclAllReduce, thunk_info),
config_(std::move(config)),
buffers_(std::move(buffers)) {}
/* static */ bool NcclAllReduceThunk::CanImplement(const HloInstruction* crs) {
/* static */ bool NcclAllReduceThunk::CanImplement(const HloInstruction* hlo) {
return false;
}
Status NcclAllReduceThunk::ExecuteOnStream(const ExecuteParams& params) {
Status NcclAllReduceThunk::RunNcclCollective(const ExecuteParams&, ncclComm_t) {
return Unimplemented(
"NCCL support is not available: this binary was not built with a CUDA "
"compiler, which is necessary to build the NCCL source library.");
}
NcclAllReduceThunk::NcclAllReduceThunk(
ThunkInfo thunk_info, NcclAllReduceConfig &&config,
std::vector<NcclAllReduceThunk::Buffer> buffers)
: Thunk(Thunk::kNcclAllReduce, thunk_info),
config_(std::move(config)),
buffers_(std::move(buffers)) {}
const NcclCollectiveConfig& NcclAllReduceThunk::config() const {
// This function will never be called.
const NcclCollectiveConfig* config = nullptr;
return *config;
}
} // namespace gpu
} // namespace xla

File: tensorflow/compiler/xla/service/gpu/dummy_collective_thunk.cc

@@ -0,0 +1,46 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/xla/service/gpu/nccl_collective_thunk.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
namespace xla {
namespace gpu {
struct NcclCollectiveConfig::AuxData {};
NcclCollectiveConfig::NcclCollectiveConfig() = default;
NcclCollectiveConfig::NcclCollectiveConfig(NcclCollectiveConfig &&) = default;
NcclCollectiveConfig::~NcclCollectiveConfig() = default;
NcclCollectiveConfig &NcclCollectiveConfig::operator=(NcclCollectiveConfig &&) =
default;
NcclCollectiveConfig GetNcclCollectiveConfig(const HloInstruction *hlo,
int64 replica_count) {
return NcclCollectiveConfig();
}
/* static */ bool NcclCollectiveThunk::NcclIsEnabled() {
return false; // Skylark selects this source file if NCCL is disabled.
}
Status NcclCollectiveThunk::ExecuteOnStream(const ExecuteParams &) {
return Unimplemented(
"NCCL support is not available: this binary was not built with a CUDA "
"compiler, which is necessary to build the NCCL source library.");
}
} // namespace gpu
} // namespace xla

File: tensorflow/compiler/xla/service/gpu/nccl_all_reduce_thunk.cc

@@ -22,10 +22,8 @@ limitations under the License.
#include <utility>
#include <vector>
#include "absl/container/flat_hash_set.h"
#include "absl/strings/str_format.h"
#include "tensorflow/compiler/xla/service/collective_ops_utils.h"
#include "tensorflow/compiler/xla/service/global_device_id.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#if GOOGLE_CUDA
#include "third_party/nccl/nccl.h"
#elif TENSORFLOW_USE_ROCM
@@ -36,140 +34,51 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
#include "tensorflow/compiler/xla/service/hlo_instructions.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/stream_executor/gpu/gpu_activation.h"
namespace xla {
namespace gpu {
// This file runs collective ops (i.e. ops that communicate between multiple
// GPUs) using NCCL. Currently only kAllReduce is implemented.
//
// Here's a high-level overview of how running an op works.
//
// - Multiple threads call NcclAllReduceThunk::ExecuteOnStream.
// - All threads that "go together" (i.e. are participating in the "same"
// collective op) choose the same Rendezvous object from a global map.
// - Once all threads have arrived at the Rendezvous, we know exactly which
// GPUs are participating in the op, so we get or create a NcclClique
// containing those GPUs.
// - We perform the NCCL operation using the clique, then destroy the
// Rendezvous. The clique is cached, see below.
//
// Creating NCCL cliques is expensive, so we cache them. Our policy is, a thunk
// keeps alive all cliques it's ever used. When the thunk is destroyed, it
// releases its handle on the cliques, and cliques whose refcounts go to 0 are
// destroyed.
/* static */ bool NcclAllReduceThunk::NcclIsEnabled() {
return true; // Skylark selects this source file if NCCL is enabled.
}
// Extra data stored in NcclAllReduceThunk that we didn't want to expose in the
// header. In particular, this stores the thunk's cache of all NcclCliques it's
// ever used. This causes those cliques to stay alive as long as the thunk
// lives, which is how we avoid expensive reinitialization of NCCL cliques.
struct NcclAllReduceConfig::AuxData {
tensorflow::mutex mu;
absl::flat_hash_set<std::shared_ptr<NcclClique>> cliques TF_GUARDED_BY(mu);
};
NcclAllReduceConfig::NcclAllReduceConfig(NcclAllReduceConfig&&) = default;
NcclAllReduceConfig::~NcclAllReduceConfig() = default;
NcclAllReduceConfig GetNcclAllReduceConfig(const HloInstruction* instr,
NcclAllReduceConfig GetNcclAllReduceConfig(const HloInstruction* hlo,
int64 replica_count) {
NcclAllReduceConfig config;
config.operand_count = instr->operands().size();
config.operand_element_type.reserve(config.operand_count);
for (int i = 0; i < config.operand_count; i++) {
config.operand_element_type.push_back(
instr->operand(i)->shape().element_type());
}
config.replica_count = replica_count;
config.replica_groups = instr->replica_groups();
auto reduction_kind = MatchReductionComputation(instr->to_apply());
auto reduction_kind = MatchReductionComputation(hlo->to_apply());
CHECK(reduction_kind.has_value());
config.reduction_kind = reduction_kind.value();
if (instr->channel_id().has_value()) {
config.collective_op_kind = RendezvousKey::kCrossModule;
config.op_id = instr->channel_id().value();
} else {
config.collective_op_kind = RendezvousKey::kCrossReplica;
config.op_id = static_cast<int64>(instr->GetModule()->unique_id());
}
config.aux_data = std::make_unique<NcclAllReduceConfig::AuxData>();
NcclAllReduceConfig config;
config.config = GetNcclCollectiveConfig(hlo, replica_count);
config.reduction_kind = reduction_kind.value();
return config;
}
/*static*/ bool NcclAllReduceThunk::CanImplement(const HloInstruction* crs) {
auto operands_are_supported = [crs]() {
return absl::c_all_of(crs->operands(), [](HloInstruction* operand) {
/*static*/ bool NcclAllReduceThunk::CanImplement(const HloInstruction* hlo) {
auto operands_are_supported = [hlo]() {
return absl::c_all_of(hlo->operands(), [](HloInstruction* operand) {
return LayoutUtil::IsDenseArray(operand->shape()) &&
ToNcclDataType(operand->shape().element_type()).ok();
});
};
return MatchReductionComputation(crs->to_apply()).has_value() &&
crs->IsCrossReplicaAllReduce() && operands_are_supported();
return MatchReductionComputation(hlo->to_apply()).has_value() &&
hlo->IsCrossReplicaAllReduce() && operands_are_supported();
}
NcclAllReduceThunk::NcclAllReduceThunk(
ThunkInfo thunk_info, NcclAllReduceConfig&& config,
ThunkInfo thunk_info, NcclAllReduceConfig config,
std::vector<NcclAllReduceThunk::Buffer> buffers)
: Thunk(Thunk::kNcclAllReduce, thunk_info),
: NcclCollectiveThunk(Thunk::kNcclAllReduce, thunk_info),
config_(std::move(config)),
buffers_(std::move(buffers)) {
CHECK_EQ(config_.operand_count, buffers_.size());
CHECK_EQ(config_.config.operand_count, buffers_.size());
}
// Figures out which devices (named by their replica-ids) are participating in
// the all-reduce subgroup that contains device_ordinal.
Status NcclAllReduceThunk::ExecuteOnStream(const ExecuteParams& params) {
VLOG(1) << "Starting NcclAllReduceThunk.";
auto op_profiler =
params.profiler->MakeScopedInstructionProfiler(profile_index());
se::StreamExecutor* executor = params.stream->parent();
int device_ordinal = executor->device_ordinal();
TF_ASSIGN_OR_RETURN(GlobalDeviceId global_device_id,
params.GetGlobalDeviceId());
TF_ASSIGN_OR_RETURN(
std::vector<GlobalDeviceId> participants,
GetParticipatingDevices(global_device_id, *params.device_assn,
config_.replica_count, config_.replica_groups));
if (IsGlobalNcclConfig() && (participants.size() != config_.replica_count)) {
return InvalidArgument(
"Partial replica groups are not allowed when using NCCL_COMM_ID "
"environment configuration.");
}
TF_ASSIGN_OR_RETURN(
std::vector<LocalParticipant> local_participants,
GetLocalParticipants(participants, params.gpu_global_device_ids));
// Create the rendezvous for this collective operation.
RendezvousKey rendezvous_key(params.run_id, std::move(participants),
local_participants.size(),
config_.collective_op_kind, config_.op_id);
TF_ASSIGN_OR_RETURN(
LockedNcclClique locked_clique,
AcquireNcclClique(rendezvous_key, device_ordinal, params.stream,
local_participants, params.nccl_unique_id_callback));
ncclComm_t comm =
locked_clique.clique->GetCommForDeviceOrdinal(device_ordinal);
Status NcclAllReduceThunk::RunNcclCollective(const ExecuteParams& params,
ncclComm_t comm) {
int device_ordinal = params.stream->parent()->device_ordinal();
VLOG(3) << "Performing all-reduce from device ordinal: " << device_ordinal;
ncclRedOp_t reduction_kind = ToNcclReduction(config_.reduction_kind);
se::gpu::ScopedActivateExecutorContext scoped_context(executor);
ncclRedOp_t reduce_op = ToNcclReduction(config_.reduction_kind);
cudaStream_t* cu_stream = reinterpret_cast<cudaStream_t*>(
params.stream->implementation()->GpuStreamMemberHack());
VLOG(3) << "Using stream pointer: " << cu_stream
<< " on device: " << device_ordinal;
XLA_CUDA_RETURN_IF_ERROR(ncclGroupStart());
for (size_t i = 0; i < buffers_.size(); ++i) {
const Buffer& buffer = buffers_[i];
@@ -179,31 +88,29 @@ Status NcclAllReduceThunk::ExecuteOnStream(const ExecuteParams& params) {
void* recv_buffer =
params.buffer_allocations->GetDeviceAddress(buffer.destination_buffer)
.opaque();
TF_ASSIGN_OR_RETURN(ncclDataType_t datatype,
ToNcclDataType(config_.operand_element_type[i]));
ToNcclDataType(config_.config.operand_element_type[i]));
VLOG(3) << absl::StreamFormat(
"Calling ncclAllReduce(send_buffer=%p, recv_buffer=%p, count=%d, "
"comm=%p, stream=%p)",
send_buffer, recv_buffer, buffer.element_count,
static_cast<const void*>(comm), cu_stream);
XLA_CUDA_RETURN_IF_ERROR(ncclAllReduce(send_buffer, recv_buffer,
/*count=*/buffer.element_count,
datatype,
/*op=*/reduction_kind, comm,
/*stream=*/*cu_stream));
buffer.element_count, datatype,
reduce_op, comm, *cu_stream));
}
XLA_CUDA_RETURN_IF_ERROR(ncclGroupEnd());
VLOG(3) << "Done performing all-reduce for ordinal: " << device_ordinal;
// Keep the clique we used alive for as long as this Thunk lives. Creating
// new NCCL cliques is expensive, and this is how we avoid thrashing them.
{
tensorflow::mutex_lock lock(config_.aux_data->mu);
config_.aux_data->cliques.insert(std::move(locked_clique.clique));
}
return Status::OK();
}
const NcclCollectiveConfig& NcclAllReduceThunk::config() const {
return config_.config;
}
} // namespace gpu
} // namespace xla

File: tensorflow/compiler/xla/service/gpu/nccl_all_reduce_thunk.h

@@ -16,10 +16,9 @@ limitations under the License.
#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_NCCL_ALL_REDUCE_THUNK_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_NCCL_ALL_REDUCE_THUNK_H_
#include "tensorflow/compiler/xla/service/buffer_assignment.h"
#include "tensorflow/compiler/xla/service/collective_ops_utils.h"
#include "tensorflow/compiler/xla/service/gpu/buffer_allocations.h"
#include "tensorflow/compiler/xla/service/gpu/thunk.h"
#include "tensorflow/compiler/xla/service/gpu/nccl_collective_thunk.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/platform/types.h"
@@ -28,53 +27,34 @@ namespace xla {
namespace gpu {
struct NcclAllReduceConfig {
int64 operand_count;
std::vector<PrimitiveType> operand_element_type;
int64 replica_count;
std::vector<ReplicaGroup> replica_groups;
NcclCollectiveConfig config;
ReductionKind reduction_kind;
RendezvousKey::CollectiveOpKind collective_op_kind;
int64 op_id;
NcclAllReduceConfig() = default;
NcclAllReduceConfig(NcclAllReduceConfig &&);
~NcclAllReduceConfig();
// Extra data stored in NcclAllReduceThunk whose types we don't want exposed
// in the header file. (This is mainly because the implementation of
// NcclAllReduceThunk is different depending on whether CUDA is enabled in the
// build, and we don't want to expose *that* mess in the header.)
struct AuxData;
std::unique_ptr<AuxData> aux_data;
};
NcclAllReduceConfig GetNcclAllReduceConfig(const HloInstruction *instr,
NcclAllReduceConfig GetNcclAllReduceConfig(const HloInstruction* hlo,
int64 replica_count);
// Thunk that performs a NCCL-based All-Reduce among CUDA GPU-based replicas.
class NcclAllReduceThunk : public Thunk {
class NcclAllReduceThunk : public NcclCollectiveThunk {
public:
// Returns whether NCCL operations appear possible to perform; e.g. if we
// haven't done a build with the CUDA compiler enabled, we can't compile the
// NCCL header, and thus this will be false.
//
// When this is false, the ExecuteOnStream() call will simply return a status
// error.
static bool NcclIsEnabled();
struct Buffer {
int64 element_count;
BufferAllocation::Slice source_buffer;
BufferAllocation::Slice destination_buffer;
};
NcclAllReduceThunk(ThunkInfo thunk_info, NcclAllReduceConfig &&config,
std::vector<Buffer> buffers);
Status ExecuteOnStream(const ExecuteParams& params) override;
NcclAllReduceThunk(ThunkInfo thunk_info, NcclAllReduceConfig config,
std::vector<Buffer> buffers);
// Returns whether the given instruction can be lowered to a nccl all-reduce
// call.
static bool CanImplement(const HloInstruction* crs);
static bool CanImplement(const HloInstruction* hlo);
protected:
Status RunNcclCollective(const ExecuteParams& params,
ncclComm_t comm) override;
const NcclCollectiveConfig& config() const override;
private:
const NcclAllReduceConfig config_;
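For orientation, the refactored all-reduce thunk is still built from GetNcclAllReduceConfig plus a buffer list; only the base class and the by-value config parameter change. Below is a minimal emission-side sketch; MakeAllReduceThunk and its arguments (hlo, replica_count, thunk_info, buffers) are illustrative placeholders, not code from this commit.

#include <memory>
#include <utility>
#include <vector>

#include "tensorflow/compiler/xla/service/gpu/nccl_all_reduce_thunk.h"

namespace xla {
namespace gpu {

// Hypothetical helper; callers are expected to have checked
// NcclAllReduceThunk::CanImplement(hlo) before reaching this point.
std::unique_ptr<Thunk> MakeAllReduceThunk(
    const HloInstruction* hlo, int64 replica_count,
    Thunk::ThunkInfo thunk_info,
    std::vector<NcclAllReduceThunk::Buffer> buffers) {
  // The config now wraps the shared NcclCollectiveConfig and adds only the
  // all-reduce-specific reduction kind.
  NcclAllReduceConfig config = GetNcclAllReduceConfig(hlo, replica_count);
  // After this change the constructor takes the config by value rather than
  // by rvalue reference, so we move it in.
  return std::make_unique<NcclAllReduceThunk>(thunk_info, std::move(config),
                                              std::move(buffers));
}

}  // namespace gpu
}  // namespace xla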

File: tensorflow/compiler/xla/service/gpu/nccl_collective_thunk.cc

@@ -0,0 +1,150 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/xla/service/gpu/nccl_collective_thunk.h"
#include <chrono> // NOLINT (required by TF interfaces)
#include <cstdlib>
#include <memory>
#include <string>
#include <utility>
#include <vector>
#include "absl/container/flat_hash_set.h"
#include "absl/strings/str_format.h"
#include "tensorflow/compiler/xla/service/collective_ops_utils.h"
#include "tensorflow/compiler/xla/service/global_device_id.h"
#include "tensorflow/compiler/xla/service/gpu/nccl_utils.h"
#include "tensorflow/compiler/xla/service/hlo_instructions.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/stream_executor/gpu/gpu_activation.h"
namespace xla {
namespace gpu {
// This file runs collective ops (i.e. ops that communicate between multiple
// GPUs) using NCCL.
//
// Here's a high-level overview of how running an op works.
//
// - Multiple threads call ExecuteOnStream.
// - All threads that "go together" (i.e. are participating in the "same"
// collective op) choose the same Rendezvous object from a global map.
// - Once all threads have arrived at the Rendezvous, we know exactly which
// GPUs are participating in the op, so we get or create a NcclClique
// containing those GPUs.
// - We perform the NCCL operation using the clique.
//
// Creating NCCL cliques is expensive, so we cache them. Our policy is, a thunk
// keeps alive all cliques it's ever used. When the thunk is destroyed, it
// releases its handle on the cliques, and cliques whose refcounts go to 0 are
// destroyed.
// Extra data stored in NcclCollectiveThunk that we didn't want to expose in the
// header. In particular, this stores the thunk's cache of all NcclCliques it's
// ever used. This causes those cliques to stay alive as long as the thunk
// lives, which is how we avoid expensive reinitialization of NCCL cliques.
struct NcclCollectiveConfig::AuxData {
tensorflow::mutex mu;
absl::flat_hash_set<std::shared_ptr<NcclClique>> cliques TF_GUARDED_BY(mu);
};
NcclCollectiveConfig::NcclCollectiveConfig() = default;
NcclCollectiveConfig::NcclCollectiveConfig(NcclCollectiveConfig&&) = default;
NcclCollectiveConfig::~NcclCollectiveConfig() = default;
NcclCollectiveConfig& NcclCollectiveConfig::operator=(NcclCollectiveConfig&&) =
default;
NcclCollectiveConfig GetNcclCollectiveConfig(const HloInstruction* hlo,
int64 replica_count) {
NcclCollectiveConfig config;
config.operand_count = hlo->operands().size();
config.operand_element_type.reserve(config.operand_count);
for (int i = 0; i < config.operand_count; i++) {
config.operand_element_type.push_back(
hlo->operand(i)->shape().element_type());
}
config.replica_count = replica_count;
config.replica_groups = hlo->replica_groups();
if (hlo->channel_id().has_value()) {
config.collective_op_kind = RendezvousKey::kCrossModule;
config.op_id = hlo->channel_id().value();
} else {
config.collective_op_kind = RendezvousKey::kCrossReplica;
config.op_id = static_cast<int64>(hlo->GetModule()->unique_id());
}
config.aux_data = std::make_unique<NcclCollectiveConfig::AuxData>();
return config;
}
/* static */ bool NcclCollectiveThunk::NcclIsEnabled() {
return true; // Skylark selects this source file if NCCL is enabled.
}
Status NcclCollectiveThunk::ExecuteOnStream(const ExecuteParams& params) {
VLOG(1) << absl::StreamFormat("Starting %s.", ThunkKindToString(kind()));
auto op_profiler =
params.profiler->MakeScopedInstructionProfiler(profile_index());
TF_ASSIGN_OR_RETURN(GlobalDeviceId global_device_id,
params.GetGlobalDeviceId());
TF_ASSIGN_OR_RETURN(
std::vector<GlobalDeviceId> participants,
GetParticipatingDevices(global_device_id, *params.device_assn,
config().replica_count, config().replica_groups));
if (IsGlobalNcclConfig() && (participants.size() != config().replica_count)) {
return InvalidArgument(
"Partial replica groups are not allowed when using NCCL_COMM_ID "
"environment configuration.");
}
TF_ASSIGN_OR_RETURN(
std::vector<LocalParticipant> local_participants,
GetLocalParticipants(participants, params.gpu_global_device_ids));
// Create the rendezvous for this collective operation.
RendezvousKey rendezvous_key(params.run_id, std::move(participants),
local_participants.size(),
config().collective_op_kind, config().op_id);
int device_ordinal = params.stream->parent()->device_ordinal();
TF_ASSIGN_OR_RETURN(
LockedNcclClique locked_clique,
AcquireNcclClique(rendezvous_key, device_ordinal, params.stream,
local_participants, params.nccl_unique_id_callback));
ncclComm_t comm =
locked_clique.clique->GetCommForDeviceOrdinal(device_ordinal);
se::StreamExecutor* executor = params.stream->parent();
se::gpu::ScopedActivateExecutorContext scoped_context(executor);
TF_RETURN_IF_ERROR(RunNcclCollective(params, comm));
// Keep the clique we used alive for as long as this Thunk lives. Creating
// new NCCL cliques is expensive, and this is how we avoid thrashing them.
{
tensorflow::mutex_lock lock(config().aux_data->mu);
config().aux_data->cliques.insert(std::move(locked_clique.clique));
}
return Status::OK();
}
} // namespace gpu
} // namespace xla

File: tensorflow/compiler/xla/service/gpu/nccl_collective_thunk.h

@@ -0,0 +1,79 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_NCCL_COLLECTIVE_THUNK_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_NCCL_COLLECTIVE_THUNK_H_
#include "tensorflow/compiler/xla/service/collective_ops_utils.h"
#include "tensorflow/compiler/xla/service/gpu/thunk.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/platform/types.h"
struct ncclComm;
using ncclComm_t = ncclComm*;
namespace xla {
namespace gpu {
struct NcclCollectiveConfig {
NcclCollectiveConfig();
NcclCollectiveConfig(NcclCollectiveConfig&&);
~NcclCollectiveConfig();
NcclCollectiveConfig& operator=(NcclCollectiveConfig&&);
int64 operand_count;
std::vector<PrimitiveType> operand_element_type;
int64 replica_count;
std::vector<ReplicaGroup> replica_groups;
RendezvousKey::CollectiveOpKind collective_op_kind;
int64 op_id;
// Extra data stored in NcclCollectiveConfig whose types we don't want exposed
// in the header file. (This is mainly because the implementation of
// NcclCollectiveConfig is different depending on whether CUDA is enabled in
// the build, and we don't want to expose *that* mess in the header.)
struct AuxData;
std::unique_ptr<AuxData> aux_data;
};
NcclCollectiveConfig GetNcclCollectiveConfig(const HloInstruction* hlo,
int64 replica_count);
// Thunk base class for NCCL collective operations.
class NcclCollectiveThunk : public Thunk {
public:
using Thunk::Thunk;
// Returns whether NCCL operations appear possible to perform; e.g. if we
// haven't done a build with the CUDA compiler enabled, we can't compile the
// NCCL header, and thus this will be false.
//
// When this is false, the ExecuteOnStream() call will simply return a status
// error.
static bool NcclIsEnabled();
Status ExecuteOnStream(const ExecuteParams& params) override;
protected:
virtual Status RunNcclCollective(const ExecuteParams& params,
ncclComm_t comm) = 0;
virtual const NcclCollectiveConfig& config() const = 0;
};
} // namespace gpu
} // namespace xla
#endif // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_NCCL_COLLECTIVE_THUNK_H_
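To illustrate the intended extension point: a new collective thunk only needs to derive from NcclCollectiveThunk, supply its NcclCollectiveConfig, and implement RunNcclCollective with the actual NCCL call; ExecuteOnStream in the base class handles rendezvous, clique acquisition, and clique caching. The sketch below uses a hypothetical NcclAllGatherThunk and an assumed Thunk::kNcclAllGather kind, neither of which is part of this commit.

#include <utility>

#include "tensorflow/compiler/xla/service/gpu/nccl_collective_thunk.h"
#include "tensorflow/compiler/xla/util.h"

namespace xla {
namespace gpu {

// Hypothetical derived thunk; only the override surface mirrors this commit.
class NcclAllGatherThunk : public NcclCollectiveThunk {
 public:
  NcclAllGatherThunk(ThunkInfo thunk_info, NcclCollectiveConfig config)
      // Thunk::kNcclAllGather is assumed here; a real subclass would also add
      // a new Thunk::Kind value.
      : NcclCollectiveThunk(Thunk::kNcclAllGather, thunk_info),
        config_(std::move(config)) {}

 protected:
  // Called by NcclCollectiveThunk::ExecuteOnStream with a communicator that
  // the base class acquired for this device ordinal.
  Status RunNcclCollective(const ExecuteParams& params,
                           ncclComm_t comm) override {
    // Buffer handling and the ncclAllGather call are elided in this sketch.
    return Unimplemented("NcclAllGatherThunk is an illustrative sketch only.");
  }

  // Lets the base class read the replica count, replica groups, op kind, etc.
  const NcclCollectiveConfig& config() const override { return config_; }

 private:
  NcclCollectiveConfig config_;
};

}  // namespace gpu
}  // namespace xla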