From e6fe105d22fe74ed75a0323917c25d943171cc8c Mon Sep 17 00:00:00 2001 From: Chris Jones Date: Fri, 4 Dec 2020 11:19:31 -0800 Subject: [PATCH] [XLA-GPU] NFC: Split out common base class for NCCL collective thunks. PiperOrigin-RevId: 345717933 Change-Id: I56cf6ce929413aeacd20207408d75609ed6fd16f --- tensorflow/compiler/xla/service/gpu/BUILD | 78 ++++++--- .../xla/service/gpu/dummy_all_reduce_thunk.cc | 34 ++-- .../xla/service/gpu/dummy_collective_thunk.cc | 46 ++++++ .../xla/service/gpu/nccl_all_reduce_thunk.cc | 153 ++++-------------- .../xla/service/gpu/nccl_all_reduce_thunk.h | 46 ++---- .../xla/service/gpu/nccl_collective_thunk.cc | 150 +++++++++++++++++ .../xla/service/gpu/nccl_collective_thunk.h | 79 +++++++++ 7 files changed, 389 insertions(+), 197 deletions(-) create mode 100644 tensorflow/compiler/xla/service/gpu/dummy_collective_thunk.cc create mode 100644 tensorflow/compiler/xla/service/gpu/nccl_collective_thunk.cc create mode 100644 tensorflow/compiler/xla/service/gpu/nccl_collective_thunk.h diff --git a/tensorflow/compiler/xla/service/gpu/BUILD b/tensorflow/compiler/xla/service/gpu/BUILD index 3f275abe453..898314c4f7d 100644 --- a/tensorflow/compiler/xla/service/gpu/BUILD +++ b/tensorflow/compiler/xla/service/gpu/BUILD @@ -418,15 +418,6 @@ cc_library( ], ) -# First level of nested select. NCCL requires both if_cuda and if_nccl. -filegroup( - name = "nccl_all_reduce_thunk_src", - srcs = if_nccl( - ["nccl_all_reduce_thunk.cc"], - ["dummy_all_reduce_thunk.cc"], - ), -) - # use alias since nested select statements not possible cc_library( name = "empty", @@ -442,6 +433,55 @@ alias( actual = if_rocm("@local_config_rocm//rocm:rccl", ":empty"), ) +# First level of nested select. NCCL requires both if_cuda and if_nccl. +filegroup( + name = "nccl_collective_thunk_src", + srcs = if_nccl( + ["nccl_collective_thunk.cc"], + ["dummy_collective_thunk.cc"], + ), +) + +tf_cuda_library( + name = "nccl_collective_thunk", + srcs = if_cuda_or_rocm( + [":nccl_collective_thunk_src"], + ["dummy_collective_thunk.cc"], + ), + hdrs = ["nccl_collective_thunk.h"], + deps = [ + ":thunk", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", + "//tensorflow/compiler/xla/service:collective_ops_utils", + "//tensorflow/compiler/xla/service:global_device_id", + "//tensorflow/compiler/xla/service:hlo", + "//tensorflow/compiler/xla:util", + "//tensorflow/compiler/xla:xla_data_proto_cc", + "//tensorflow/core:lib", + ] + if_cuda([ + "//tensorflow/stream_executor/cuda:cuda_activation", + "//tensorflow/stream_executor/cuda:cuda_gpu_executor", + ]) + if_rocm([ + "//tensorflow/stream_executor/rocm:rocm_activation", + "//tensorflow/stream_executor/rocm:rocm_gpu_executor", + ]) + if_nccl([ + ":virtual_nccl", + ":virtual_nccl_utils", + ":virtual_rccl", + ]), +) + +# First level of nested select. NCCL requires both if_cuda and if_nccl. 
+filegroup( + name = "nccl_all_reduce_thunk_src", + srcs = if_nccl( + ["nccl_all_reduce_thunk.cc"], + ["dummy_all_reduce_thunk.cc"], + ), +) + tf_cuda_library( name = "nccl_all_reduce_thunk", srcs = if_cuda_or_rocm( @@ -451,28 +491,22 @@ tf_cuda_library( hdrs = ["nccl_all_reduce_thunk.h"], deps = [ ":buffer_allocations", - ":hlo_execution_profiler", - ":thunk", ":gpu_executable_run_options", + ":hlo_execution_profiler", + ":nccl_collective_thunk", + ":thunk", "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/strings:str_format", - "//tensorflow/compiler/xla:shape_util", - "//tensorflow/compiler/xla:util", - "//tensorflow/compiler/xla:xla_data_proto_cc", "//tensorflow/compiler/xla/service:buffer_assignment", "//tensorflow/compiler/xla/service:collective_ops_utils", "//tensorflow/compiler/xla/service:hlo", "//tensorflow/compiler/xla/service:hlo_casting_utils", "//tensorflow/compiler/xla/service:pattern_matcher", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:util", + "//tensorflow/compiler/xla:xla_data_proto_cc", "//tensorflow/core:lib", - "//tensorflow/core/platform:stream_executor_no_cuda", - ] + if_cuda([ - "//tensorflow/stream_executor/cuda:cuda_activation", - "//tensorflow/stream_executor/cuda:cuda_gpu_executor", - ]) + if_rocm([ - "//tensorflow/stream_executor/rocm:rocm_activation", - "//tensorflow/stream_executor/rocm:rocm_gpu_executor", - ]) + if_nccl([ + ] + if_nccl([ ":virtual_nccl", ":virtual_nccl_utils", ":virtual_rccl", diff --git a/tensorflow/compiler/xla/service/gpu/dummy_all_reduce_thunk.cc b/tensorflow/compiler/xla/service/gpu/dummy_all_reduce_thunk.cc index 61e7b4965ab..bf5ff8c52eb 100644 --- a/tensorflow/compiler/xla/service/gpu/dummy_all_reduce_thunk.cc +++ b/tensorflow/compiler/xla/service/gpu/dummy_all_reduce_thunk.cc @@ -19,37 +19,33 @@ limitations under the License. namespace xla { namespace gpu { -struct NcclAllReduceConfig::AuxData {}; - -NcclAllReduceConfig::NcclAllReduceConfig(NcclAllReduceConfig &&) = default; -NcclAllReduceConfig::~NcclAllReduceConfig() = default; - -NcclAllReduceConfig GetNcclAllReduceConfig(const HloInstruction *instr, +NcclAllReduceConfig GetNcclAllReduceConfig(const HloInstruction* hlo, int64 replica_count) { - NcclAllReduceConfig config = {}; - return config; + return NcclAllReduceConfig(); } -/* static */ bool NcclAllReduceThunk::NcclIsEnabled() { - return false; // Skylark selects this source file if NCCL is disabled. -} +NcclAllReduceThunk::NcclAllReduceThunk( + ThunkInfo thunk_info, NcclAllReduceConfig config, + std::vector buffers) + : NcclCollectiveThunk(Thunk::kNcclAllReduce, thunk_info), + config_(std::move(config)), + buffers_(std::move(buffers)) {} -/* static */ bool NcclAllReduceThunk::CanImplement(const HloInstruction* crs) { +/* static */ bool NcclAllReduceThunk::CanImplement(const HloInstruction* hlo) { return false; } -Status NcclAllReduceThunk::ExecuteOnStream(const ExecuteParams& params) { +Status NcclAllReduceThunk::RunNcclCollective(const ExecuteParams&, ncclComm_t) { return Unimplemented( "NCCL support is not available: this binary was not built with a CUDA " "compiler, which is necessary to build the NCCL source library."); } -NcclAllReduceThunk::NcclAllReduceThunk( - ThunkInfo thunk_info, NcclAllReduceConfig &&config, - std::vector buffers) - : Thunk(Thunk::kNcclAllReduce, thunk_info), - config_(std::move(config)), - buffers_(std::move(buffers)) {} +const NcclCollectiveConfig& NcclAllReduceThunk::config() const { + // This function will never be called. 
+ const NcclCollectiveConfig* config = nullptr; + return *config; +} } // namespace gpu } // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/dummy_collective_thunk.cc b/tensorflow/compiler/xla/service/gpu/dummy_collective_thunk.cc new file mode 100644 index 00000000000..0c49b2d690a --- /dev/null +++ b/tensorflow/compiler/xla/service/gpu/dummy_collective_thunk.cc @@ -0,0 +1,46 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/gpu/nccl_collective_thunk.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" + +namespace xla { +namespace gpu { + +struct NcclCollectiveConfig::AuxData {}; + +NcclCollectiveConfig::NcclCollectiveConfig() = default; +NcclCollectiveConfig::NcclCollectiveConfig(NcclCollectiveConfig &&) = default; +NcclCollectiveConfig::~NcclCollectiveConfig() = default; +NcclCollectiveConfig &NcclCollectiveConfig::operator=(NcclCollectiveConfig &&) = + default; + +NcclCollectiveConfig GetNcclCollectiveConfig(const HloInstruction *hlo, + int64 replica_count) { + return NcclCollectiveConfig(); +} + +/* static */ bool NcclCollectiveThunk::NcclIsEnabled() { + return false; // Skylark selects this source file if NCCL is disabled. +} + +Status NcclCollectiveThunk::ExecuteOnStream(const ExecuteParams &) { + return Unimplemented( + "NCCL support is not available: this binary was not built with a CUDA " + "compiler, which is necessary to build the NCCL source library."); +} + +} // namespace gpu +} // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/nccl_all_reduce_thunk.cc b/tensorflow/compiler/xla/service/gpu/nccl_all_reduce_thunk.cc index 9b797f7e9e6..6af0505ff38 100644 --- a/tensorflow/compiler/xla/service/gpu/nccl_all_reduce_thunk.cc +++ b/tensorflow/compiler/xla/service/gpu/nccl_all_reduce_thunk.cc @@ -22,10 +22,8 @@ limitations under the License. #include #include -#include "absl/container/flat_hash_set.h" #include "absl/strings/str_format.h" -#include "tensorflow/compiler/xla/service/collective_ops_utils.h" -#include "tensorflow/compiler/xla/service/global_device_id.h" +#include "tensorflow/compiler/xla/service/hlo_computation.h" #if GOOGLE_CUDA #include "third_party/nccl/nccl.h" #elif TENSORFLOW_USE_ROCM @@ -36,140 +34,51 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_casting_utils.h" #include "tensorflow/compiler/xla/service/hlo_instructions.h" #include "tensorflow/compiler/xla/util.h" -#include "tensorflow/core/platform/mutex.h" -#include "tensorflow/stream_executor/gpu/gpu_activation.h" namespace xla { namespace gpu { -// This file runs collective ops (i.e. ops that communicate between multiple -// GPUs) using NCCL. Currently only kAllReduce is implemented. -// -// Here's a high-level overview of how running an op works. -// -// - Multiple threads call NcclAllReduceThunk::ExecuteOnStream. -// - All threads that "go together" (i.e. 
are participating in the "same" -// collective op) choose the same Rendezvous object from a global map. -// - Once all threads have arrived at the Rendezvous, we know exactly which -// GPUs are participating in the op, so we get or create a NcclClique -// containing those GPUs. -// - We perform the NCCL operation using the clique, then destroy the -// Rendezvous. The clique is cached, see below. -// -// Creating NCCL cliques is expensive, so we cache them. Our policy is, a thunk -// keeps alive all cliques it's ever used. When the thunk is destroyed, it -// releases its handle on the cliques, and cliques whose refcounts go to 0 are -// destroyed. - -/* static */ bool NcclAllReduceThunk::NcclIsEnabled() { - return true; // Skylark selects this source file if NCCL is enabled. -} - -// Extra data stored in NcclAllReduceThunk that we didn't want to expose in the -// header. In particular, this stores the thunk's cache of all NcclCliques it's -// ever used. This causes those cliques to stay alive as long as the thunk -// lives, which is how we avoid expensive reinitialization of NCCL cliques. -struct NcclAllReduceConfig::AuxData { - tensorflow::mutex mu; - absl::flat_hash_set> cliques TF_GUARDED_BY(mu); -}; - -NcclAllReduceConfig::NcclAllReduceConfig(NcclAllReduceConfig&&) = default; -NcclAllReduceConfig::~NcclAllReduceConfig() = default; - -NcclAllReduceConfig GetNcclAllReduceConfig(const HloInstruction* instr, +NcclAllReduceConfig GetNcclAllReduceConfig(const HloInstruction* hlo, int64 replica_count) { - NcclAllReduceConfig config; - config.operand_count = instr->operands().size(); - config.operand_element_type.reserve(config.operand_count); - for (int i = 0; i < config.operand_count; i++) { - config.operand_element_type.push_back( - instr->operand(i)->shape().element_type()); - } - config.replica_count = replica_count; - config.replica_groups = instr->replica_groups(); - auto reduction_kind = MatchReductionComputation(instr->to_apply()); + auto reduction_kind = MatchReductionComputation(hlo->to_apply()); CHECK(reduction_kind.has_value()); - config.reduction_kind = reduction_kind.value(); - if (instr->channel_id().has_value()) { - config.collective_op_kind = RendezvousKey::kCrossModule; - config.op_id = instr->channel_id().value(); - } else { - config.collective_op_kind = RendezvousKey::kCrossReplica; - config.op_id = static_cast(instr->GetModule()->unique_id()); - } - config.aux_data = std::make_unique(); + NcclAllReduceConfig config; + config.config = GetNcclCollectiveConfig(hlo, replica_count); + config.reduction_kind = reduction_kind.value(); return config; } -/*static*/ bool NcclAllReduceThunk::CanImplement(const HloInstruction* crs) { - auto operands_are_supported = [crs]() { - return absl::c_all_of(crs->operands(), [](HloInstruction* operand) { +/*static*/ bool NcclAllReduceThunk::CanImplement(const HloInstruction* hlo) { + auto operands_are_supported = [hlo]() { + return absl::c_all_of(hlo->operands(), [](HloInstruction* operand) { return LayoutUtil::IsDenseArray(operand->shape()) && ToNcclDataType(operand->shape().element_type()).ok(); }); }; - return MatchReductionComputation(crs->to_apply()).has_value() && - crs->IsCrossReplicaAllReduce() && operands_are_supported(); + return MatchReductionComputation(hlo->to_apply()).has_value() && + hlo->IsCrossReplicaAllReduce() && operands_are_supported(); } NcclAllReduceThunk::NcclAllReduceThunk( - ThunkInfo thunk_info, NcclAllReduceConfig&& config, + ThunkInfo thunk_info, NcclAllReduceConfig config, std::vector buffers) - : 
Thunk(Thunk::kNcclAllReduce, thunk_info), + : NcclCollectiveThunk(Thunk::kNcclAllReduce, thunk_info), config_(std::move(config)), buffers_(std::move(buffers)) { - CHECK_EQ(config_.operand_count, buffers_.size()); + CHECK_EQ(config_.config.operand_count, buffers_.size()); } -// Figures out which devices (named by their replica-ids) are participating in -// the all-reduce subgroup that contains device_ordinal. -Status NcclAllReduceThunk::ExecuteOnStream(const ExecuteParams& params) { - VLOG(1) << "Starting NcclAllReduceThunk."; - auto op_profiler = - params.profiler->MakeScopedInstructionProfiler(profile_index()); - - se::StreamExecutor* executor = params.stream->parent(); - int device_ordinal = executor->device_ordinal(); - TF_ASSIGN_OR_RETURN(GlobalDeviceId global_device_id, - params.GetGlobalDeviceId()); - - TF_ASSIGN_OR_RETURN( - std::vector participants, - GetParticipatingDevices(global_device_id, *params.device_assn, - config_.replica_count, config_.replica_groups)); - - if (IsGlobalNcclConfig() && (participants.size() != config_.replica_count)) { - return InvalidArgument( - "Partial replica groups are not allowed when using NCCL_COMM_ID " - "environment configuration."); - } - - TF_ASSIGN_OR_RETURN( - std::vector local_participants, - GetLocalParticipants(participants, params.gpu_global_device_ids)); - - // Create the rendezvous for this collective operation. - RendezvousKey rendezvous_key(params.run_id, std::move(participants), - local_participants.size(), - config_.collective_op_kind, config_.op_id); - - TF_ASSIGN_OR_RETURN( - LockedNcclClique locked_clique, - AcquireNcclClique(rendezvous_key, device_ordinal, params.stream, - local_participants, params.nccl_unique_id_callback)); - ncclComm_t comm = - locked_clique.clique->GetCommForDeviceOrdinal(device_ordinal); - +Status NcclAllReduceThunk::RunNcclCollective(const ExecuteParams& params, + ncclComm_t comm) { + int device_ordinal = params.stream->parent()->device_ordinal(); VLOG(3) << "Performing all-reduce from device ordinal: " << device_ordinal; - ncclRedOp_t reduction_kind = ToNcclReduction(config_.reduction_kind); - se::gpu::ScopedActivateExecutorContext scoped_context(executor); + ncclRedOp_t reduce_op = ToNcclReduction(config_.reduction_kind); + cudaStream_t* cu_stream = reinterpret_cast( params.stream->implementation()->GpuStreamMemberHack()); - VLOG(3) << "Using stream pointer: " << cu_stream - << " on device: " << device_ordinal; + XLA_CUDA_RETURN_IF_ERROR(ncclGroupStart()); for (size_t i = 0; i < buffers_.size(); ++i) { const Buffer& buffer = buffers_[i]; @@ -179,31 +88,29 @@ Status NcclAllReduceThunk::ExecuteOnStream(const ExecuteParams& params) { void* recv_buffer = params.buffer_allocations->GetDeviceAddress(buffer.destination_buffer) .opaque(); + TF_ASSIGN_OR_RETURN(ncclDataType_t datatype, - ToNcclDataType(config_.operand_element_type[i])); + ToNcclDataType(config_.config.operand_element_type[i])); + VLOG(3) << absl::StreamFormat( "Calling ncclAllReduce(send_buffer=%p, recv_buffer=%p, count=%d, " "comm=%p, stream=%p)", send_buffer, recv_buffer, buffer.element_count, static_cast(comm), cu_stream); + XLA_CUDA_RETURN_IF_ERROR(ncclAllReduce(send_buffer, recv_buffer, - /*count=*/buffer.element_count, - datatype, - /*op=*/reduction_kind, comm, - /*stream=*/*cu_stream)); + buffer.element_count, datatype, + reduce_op, comm, *cu_stream)); } XLA_CUDA_RETURN_IF_ERROR(ncclGroupEnd()); VLOG(3) << "Done performing all-reduce for ordinal: " << device_ordinal; - - // Keep the clique we used alive for as long as this Thunk lives. 
Creating
-  // new NCCL cliques is expensive, and this is how we avoid thrashing them.
-  {
-    tensorflow::mutex_lock lock(config_.aux_data->mu);
-    config_.aux_data->cliques.insert(std::move(locked_clique.clique));
-  }
   return Status::OK();
 }
 
+const NcclCollectiveConfig& NcclAllReduceThunk::config() const {
+  return config_.config;
+}
+
 }  // namespace gpu
 }  // namespace xla
diff --git a/tensorflow/compiler/xla/service/gpu/nccl_all_reduce_thunk.h b/tensorflow/compiler/xla/service/gpu/nccl_all_reduce_thunk.h
index d6f37d9b953..a9261026c06 100644
--- a/tensorflow/compiler/xla/service/gpu/nccl_all_reduce_thunk.h
+++ b/tensorflow/compiler/xla/service/gpu/nccl_all_reduce_thunk.h
@@ -16,10 +16,9 @@ limitations under the License.
 #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_NCCL_ALL_REDUCE_THUNK_H_
 #define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_NCCL_ALL_REDUCE_THUNK_H_
 
-#include "tensorflow/compiler/xla/service/buffer_assignment.h"
 #include "tensorflow/compiler/xla/service/collective_ops_utils.h"
 #include "tensorflow/compiler/xla/service/gpu/buffer_allocations.h"
-#include "tensorflow/compiler/xla/service/gpu/thunk.h"
+#include "tensorflow/compiler/xla/service/gpu/nccl_collective_thunk.h"
 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
 #include "tensorflow/compiler/xla/xla_data.pb.h"
 #include "tensorflow/core/platform/types.h"
@@ -28,53 +27,34 @@ namespace xla {
 namespace gpu {
 
 struct NcclAllReduceConfig {
-  int64 operand_count;
-  std::vector<PrimitiveType> operand_element_type;
-  int64 replica_count;
-  std::vector<ReplicaGroup> replica_groups;
+  NcclCollectiveConfig config;
   ReductionKind reduction_kind;
-  RendezvousKey::CollectiveOpKind collective_op_kind;
-  int64 op_id;
-
-  NcclAllReduceConfig() = default;
-  NcclAllReduceConfig(NcclAllReduceConfig &&);
-  ~NcclAllReduceConfig();
-
-  // Extra data stored in NcclAllReduceThunk whose types we don't want exposed
-  // in the header file.  (This is mainly because the implementation of
-  // NcclAllReduceThunk is different depending on whether CUDA is enabled in the
-  // build, and we don't want to expose *that* mess in the header.)
-  struct AuxData;
-  std::unique_ptr<AuxData> aux_data;
 };
 
-NcclAllReduceConfig GetNcclAllReduceConfig(const HloInstruction *instr,
+NcclAllReduceConfig GetNcclAllReduceConfig(const HloInstruction* hlo,
                                            int64 replica_count);
 
 // Thunk that performs a NCCL-based All-Reduce among CUDA GPU-based replicas.
-class NcclAllReduceThunk : public Thunk {
+class NcclAllReduceThunk : public NcclCollectiveThunk {
  public:
-  // Returns whether NCCL operations appear possible to perform; e.g. if we
-  // haven't done a build with the CUDA compiler enabled, we can't compile the
-  // NCCL header, and thus this will be false.
-  //
-  // When this is false, the ExecuteOnStream() call will simply return a status
-  // error.
-  static bool NcclIsEnabled();
-
   struct Buffer {
     int64 element_count;
     BufferAllocation::Slice source_buffer;
     BufferAllocation::Slice destination_buffer;
   };
-  NcclAllReduceThunk(ThunkInfo thunk_info, NcclAllReduceConfig &&config,
-                     std::vector<Buffer> buffers);
-  Status ExecuteOnStream(const ExecuteParams& params) override;
+  NcclAllReduceThunk(ThunkInfo thunk_info, NcclAllReduceConfig config,
+                     std::vector<Buffer> buffers);
 
   // Returns whether the given instruction can be lowered to a nccl all-reduce
   // call.
-  static bool CanImplement(const HloInstruction* crs);
+  static bool CanImplement(const HloInstruction* hlo);
+
+ protected:
+  Status RunNcclCollective(const ExecuteParams& params,
+                           ncclComm_t comm) override;
+
+  const NcclCollectiveConfig& config() const override;
 
  private:
   const NcclAllReduceConfig config_;
diff --git a/tensorflow/compiler/xla/service/gpu/nccl_collective_thunk.cc b/tensorflow/compiler/xla/service/gpu/nccl_collective_thunk.cc
new file mode 100644
index 00000000000..03d289ed54a
--- /dev/null
+++ b/tensorflow/compiler/xla/service/gpu/nccl_collective_thunk.cc
@@ -0,0 +1,150 @@
+/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/service/gpu/nccl_collective_thunk.h"
+
+#include <chrono>  // NOLINT (required by TF interfaces)
+#include <cstdlib>
+#include <memory>
+#include <string>
+#include <utility>
+#include <vector>
+
+#include "absl/container/flat_hash_set.h"
+#include "absl/strings/str_format.h"
+#include "tensorflow/compiler/xla/service/collective_ops_utils.h"
+#include "tensorflow/compiler/xla/service/global_device_id.h"
+#include "tensorflow/compiler/xla/service/gpu/nccl_utils.h"
+#include "tensorflow/compiler/xla/service/hlo_instructions.h"
+#include "tensorflow/compiler/xla/util.h"
+#include "tensorflow/core/platform/mutex.h"
+#include "tensorflow/stream_executor/gpu/gpu_activation.h"
+
+namespace xla {
+namespace gpu {
+
+// This file runs collective ops (i.e. ops that communicate between multiple
+// GPUs) using NCCL.
+//
+// Here's a high-level overview of how running an op works.
+//
+//  - Multiple threads call ExecuteOnStream.
+//  - All threads that "go together" (i.e. are participating in the "same"
+//    collective op) choose the same Rendezvous object from a global map.
+//  - Once all threads have arrived at the Rendezvous, we know exactly which
+//    GPUs are participating in the op, so we get or create a NcclClique
+//    containing those GPUs.
+//  - We perform the NCCL operation using the clique.
+//
+// Creating NCCL cliques is expensive, so we cache them. Our policy is, a thunk
+// keeps alive all cliques it's ever used. When the thunk is destroyed, it
+// releases its handle on the cliques, and cliques whose refcounts go to 0 are
+// destroyed.
+
+// Extra data stored in NcclCollectiveThunk that we didn't want to expose in the
+// header. In particular, this stores the thunk's cache of all NcclCliques it's
+// ever used. This causes those cliques to stay alive as long as the thunk
+// lives, which is how we avoid expensive reinitialization of NCCL cliques.
+struct NcclCollectiveConfig::AuxData {
+  tensorflow::mutex mu;
+  absl::flat_hash_set<std::shared_ptr<NcclClique>> cliques TF_GUARDED_BY(mu);
+};
+
+NcclCollectiveConfig::NcclCollectiveConfig() = default;
+NcclCollectiveConfig::NcclCollectiveConfig(NcclCollectiveConfig&&) = default;
+NcclCollectiveConfig::~NcclCollectiveConfig() = default;
+NcclCollectiveConfig& NcclCollectiveConfig::operator=(NcclCollectiveConfig&&) =
+    default;
+
+NcclCollectiveConfig GetNcclCollectiveConfig(const HloInstruction* hlo,
+                                             int64 replica_count) {
+  NcclCollectiveConfig config;
+  config.operand_count = hlo->operands().size();
+  config.operand_element_type.reserve(config.operand_count);
+  for (int i = 0; i < config.operand_count; i++) {
+    config.operand_element_type.push_back(
+        hlo->operand(i)->shape().element_type());
+  }
+  config.replica_count = replica_count;
+  config.replica_groups = hlo->replica_groups();
+
+  if (hlo->channel_id().has_value()) {
+    config.collective_op_kind = RendezvousKey::kCrossModule;
+    config.op_id = hlo->channel_id().value();
+  } else {
+    config.collective_op_kind = RendezvousKey::kCrossReplica;
+    config.op_id = static_cast<int64>(hlo->GetModule()->unique_id());
+  }
+  config.aux_data = std::make_unique<AuxData>();
+  return config;
+}
+
+/* static */ bool NcclCollectiveThunk::NcclIsEnabled() {
+  return true;  // Skylark selects this source file if NCCL is enabled.
+}
+
+Status NcclCollectiveThunk::ExecuteOnStream(const ExecuteParams& params) {
+  VLOG(1) << absl::StreamFormat("Starting %s.", ThunkKindToString(kind()));
+  auto op_profiler =
+      params.profiler->MakeScopedInstructionProfiler(profile_index());
+
+  TF_ASSIGN_OR_RETURN(GlobalDeviceId global_device_id,
+                      params.GetGlobalDeviceId());
+
+  TF_ASSIGN_OR_RETURN(
+      std::vector<GlobalDeviceId> participants,
+      GetParticipatingDevices(global_device_id, *params.device_assn,
+                              config().replica_count,
+                              config().replica_groups));
+
+  if (IsGlobalNcclConfig() &&
+      (participants.size() != config().replica_count)) {
+    return InvalidArgument(
+        "Partial replica groups are not allowed when using NCCL_COMM_ID "
+        "environment configuration.");
+  }
+
+  TF_ASSIGN_OR_RETURN(
+      std::vector<LocalParticipant> local_participants,
+      GetLocalParticipants(participants, params.gpu_global_device_ids));
+
+  // Create the rendezvous for this collective operation.
+  RendezvousKey rendezvous_key(params.run_id, std::move(participants),
+                               local_participants.size(),
+                               config().collective_op_kind, config().op_id);
+
+  int device_ordinal = params.stream->parent()->device_ordinal();
+
+  TF_ASSIGN_OR_RETURN(
+      LockedNcclClique locked_clique,
+      AcquireNcclClique(rendezvous_key, device_ordinal, params.stream,
+                        local_participants, params.nccl_unique_id_callback));
+  ncclComm_t comm =
+      locked_clique.clique->GetCommForDeviceOrdinal(device_ordinal);
+
+  se::StreamExecutor* executor = params.stream->parent();
+  se::gpu::ScopedActivateExecutorContext scoped_context(executor);
+
+  TF_RETURN_IF_ERROR(RunNcclCollective(params, comm));
+
+  // Keep the clique we used alive for as long as this Thunk lives.  Creating
+  // new NCCL cliques is expensive, and this is how we avoid thrashing them.
+  {
+    tensorflow::mutex_lock lock(config().aux_data->mu);
+    config().aux_data->cliques.insert(std::move(locked_clique.clique));
+  }
+  return Status::OK();
+}
+
+}  // namespace gpu
+}  // namespace xla
diff --git a/tensorflow/compiler/xla/service/gpu/nccl_collective_thunk.h b/tensorflow/compiler/xla/service/gpu/nccl_collective_thunk.h
new file mode 100644
index 00000000000..7f60c70c3bd
--- /dev/null
+++ b/tensorflow/compiler/xla/service/gpu/nccl_collective_thunk.h
@@ -0,0 +1,79 @@
+/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_NCCL_COLLECTIVE_THUNK_H_
+#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_NCCL_COLLECTIVE_THUNK_H_
+
+#include "tensorflow/compiler/xla/service/collective_ops_utils.h"
+#include "tensorflow/compiler/xla/service/gpu/thunk.h"
+#include "tensorflow/compiler/xla/service/hlo_instruction.h"
+#include "tensorflow/compiler/xla/xla_data.pb.h"
+#include "tensorflow/core/platform/types.h"
+
+struct ncclComm;
+using ncclComm_t = ncclComm*;
+
+namespace xla {
+namespace gpu {
+
+struct NcclCollectiveConfig {
+  NcclCollectiveConfig();
+  NcclCollectiveConfig(NcclCollectiveConfig&&);
+  ~NcclCollectiveConfig();
+
+  NcclCollectiveConfig& operator=(NcclCollectiveConfig&&);
+
+  int64 operand_count;
+  std::vector<PrimitiveType> operand_element_type;
+  int64 replica_count;
+  std::vector<ReplicaGroup> replica_groups;
+  RendezvousKey::CollectiveOpKind collective_op_kind;
+  int64 op_id;
+
+  // Extra data stored in NcclCollectiveConfig whose types we don't want exposed
+  // in the header file.  (This is mainly because the implementation of
+  // NcclCollectiveConfig is different depending on whether CUDA is enabled in
+  // the build, and we don't want to expose *that* mess in the header.)
+  struct AuxData;
+  std::unique_ptr<AuxData> aux_data;
+};
+
+NcclCollectiveConfig GetNcclCollectiveConfig(const HloInstruction* hlo,
+                                             int64 replica_count);
+
+// Thunk base class for NCCL collective operations.
+class NcclCollectiveThunk : public Thunk {
+ public:
+  using Thunk::Thunk;
+
+  // Returns whether NCCL operations appear possible to perform; e.g. if we
+  // haven't done a build with the CUDA compiler enabled, we can't compile the
+  // NCCL header, and thus this will be false.
+  //
+  // When this is false, the ExecuteOnStream() call will simply return a status
+  // error.
+  static bool NcclIsEnabled();
+
+  Status ExecuteOnStream(const ExecuteParams& params) override;
+
+ protected:
+  virtual Status RunNcclCollective(const ExecuteParams& params,
+                                   ncclComm_t comm) = 0;
+  virtual const NcclCollectiveConfig& config() const = 0;
+};
+
+}  // namespace gpu
+}  // namespace xla
+
+#endif  // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_NCCL_COLLECTIVE_THUNK_H_
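
Editor's note (not part of the patch): the refactor above turns ExecuteOnStream into a template method. NcclCollectiveThunk resolves the participating devices, acquires and caches the NcclClique, activates the executor, and then hands a ready ncclComm_t to the subclass via RunNcclCollective(), while config() lets the base class read the shared replica/op metadata. The following is a minimal sketch of how a future collective op could plug into that extension point. It assumes a hypothetical NcclAllGatherThunk, NcclAllGatherConfig, and Thunk::kNcclAllGather kind, none of which exist in this change; the only real API it relies on is the NcclCollectiveThunk interface introduced here, and the NCCL call itself is left as a comment.

// Hypothetical example only -- NOT part of this patch.
#include "tensorflow/compiler/xla/service/gpu/nccl_collective_thunk.h"

namespace xla {
namespace gpu {

// Assumed analogue of NcclAllReduceConfig: wraps the shared config plus any
// op-specific fields (all-gather would need none).
struct NcclAllGatherConfig {
  NcclCollectiveConfig config;
};

class NcclAllGatherThunk : public NcclCollectiveThunk {
 public:
  // Thunk::kNcclAllGather is an assumed Thunk::Kind value; this patch only
  // uses kNcclAllReduce.
  NcclAllGatherThunk(ThunkInfo thunk_info, NcclAllGatherConfig config)
      : NcclCollectiveThunk(Thunk::kNcclAllGather, thunk_info),
        config_(std::move(config)) {}

 protected:
  // Called by NcclCollectiveThunk::ExecuteOnStream after the rendezvous has
  // completed and the clique has been acquired; only the NCCL call belongs
  // here, mirroring NcclAllReduceThunk::RunNcclCollective.
  Status RunNcclCollective(const ExecuteParams& params,
                           ncclComm_t comm) override {
    // Issue ncclAllGather(...) on `comm` for each buffer here.
    return Status::OK();
  }

  // Exposes the shared metadata (replica_count, replica_groups, op_id, clique
  // cache) to the base class without it knowing the concrete config type.
  const NcclCollectiveConfig& config() const override { return config_.config; }

 private:
  const NcclAllGatherConfig config_;
};

}  // namespace gpu
}  // namespace xla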