diff --git a/tensorflow/compiler/xla/service/BUILD b/tensorflow/compiler/xla/service/BUILD
index 881d4509475..f2e85a5abfb 100644
--- a/tensorflow/compiler/xla/service/BUILD
+++ b/tensorflow/compiler/xla/service/BUILD
@@ -3470,10 +3470,13 @@ cc_library(
     srcs = ["hlo_runner.cc"],
     hdrs = ["hlo_runner.h"],
     deps = [
+        ":backend",
+        ":compiler",
         ":computation_placer",
         ":executable",
         ":hlo",
         ":hlo_module_group",
+        ":hlo_parser",
         ":transfer_manager",
         "//tensorflow/compiler/xla:shape_util",
         "//tensorflow/compiler/xla:status_macros",
@@ -3481,11 +3484,9 @@ cc_library(
         "//tensorflow/compiler/xla:types",
         "//tensorflow/compiler/xla:util",
         "//tensorflow/compiler/xla:xla_data_proto",
-        "//tensorflow/compiler/xla/service:backend",
-        "//tensorflow/compiler/xla/service:compiler",
-        "//tensorflow/compiler/xla/service:hlo_parser",
         "//tensorflow/core:core_cpu_internal",
         "//tensorflow/core:lib",
+        "//tensorflow/core:lib_internal",
         "//tensorflow/core:stream_executor_no_cuda",
         "//third_party/eigen3",
         "@com_google_absl//absl/memory",
diff --git a/tensorflow/compiler/xla/service/gpu/BUILD b/tensorflow/compiler/xla/service/gpu/BUILD
index 16492c085cc..5b5ad63ec94 100644
--- a/tensorflow/compiler/xla/service/gpu/BUILD
+++ b/tensorflow/compiler/xla/service/gpu/BUILD
@@ -7,7 +7,7 @@ load(
     "//tensorflow/core:platform/default/build_config_root.bzl",
     "tf_cuda_tests_tags",
 )
-load("//tensorflow:tensorflow.bzl", "tf_cc_test")
+load("//tensorflow:tensorflow.bzl", "tf_cc_test", "tf_cuda_library", "if_cuda")
 
 licenses(["notice"])  # Apache 2.0
 
@@ -156,7 +156,6 @@ cc_library(
         "ir_emitter_unnested.h",
     ],
     deps = [
-        ":backend_configs",
         ":buffer_allocations",
         ":cudnn_conv_runner",
         ":elemental_ir_emitter",
@@ -164,8 +163,10 @@ cc_library(
         ":gpu_executable",
         ":hlo_to_ir_bindings",
         ":ir_emission_utils",
+        ":nccl_all_reduce_thunk",
         ":parallel_loop_emitter",
         ":partition_assignment",
+        ":thunk",
         "//tensorflow/compiler/xla:literal",
         "//tensorflow/compiler/xla:shape_util",
         "//tensorflow/compiler/xla:status_macros",
@@ -179,6 +180,7 @@ cc_library(
         "//tensorflow/compiler/xla/service:hlo",
         "//tensorflow/compiler/xla/service:hlo_casting_utils",
         "//tensorflow/compiler/xla/service:name_uniquer",
+        "//tensorflow/compiler/xla/service:pattern_matcher",
         "//tensorflow/compiler/xla/service:while_loop_analysis",
         "//tensorflow/compiler/xla/service/llvm_ir:buffer_assignment_util",
         "//tensorflow/compiler/xla/service/llvm_ir:dynamic_update_slice_util",
@@ -287,6 +289,40 @@ cc_library(
     ],
 )
 
+cc_library(
+    name = "thunk",
+    srcs = ["thunk.cc"],
+    hdrs = ["thunk.h"],
+    deps = [
+        ":buffer_allocations",
+        ":hlo_execution_profiler",
+        "//tensorflow/compiler/xla/service:hlo",
+        "//tensorflow/core:lib",
+        "//tensorflow/core:stream_executor_no_cuda",
+    ],
+)
+
+tf_cuda_library(
+    name = "nccl_all_reduce_thunk",
+    srcs = ["nccl_all_reduce_thunk.cc"],
+    hdrs = ["nccl_all_reduce_thunk.h"],
+    deps = [
+        ":buffer_allocations",
+        ":hlo_execution_profiler",
+        ":thunk",
+        "@com_google_absl//absl/synchronization",
+        "//tensorflow/compiler/xla:util",
+        "//tensorflow/compiler/xla/service:buffer_assignment",
+        "//tensorflow/compiler/xla/service:hlo",
+        "//tensorflow/core:lib",
+        "//tensorflow/core:stream_executor_no_cuda",
+        "//tensorflow/stream_executor/cuda:cuda_activation",
+        "//tensorflow/stream_executor/cuda:cuda_gpu_executor",
+    ] + if_cuda([
+        "@local_config_nccl//:nccl",
+    ]),
+)
+
 cc_library(
     name = "gpu_executable",
     srcs = [
@@ -303,7 +339,6 @@ cc_library(
         "memset_thunk.cc",
        "outfeed_thunk.cc",
         "sequential_thunk.cc",
-        "thunk.cc",
         "thunk_schedule.cc",
"triangular_solve_thunk.cc", "tuple_thunk.cc", @@ -323,7 +358,6 @@ cc_library( "memset_thunk.h", "outfeed_thunk.h", "sequential_thunk.h", - "thunk.h", "thunk_schedule.h", "triangular_solve_thunk.h", "tuple_thunk.h", @@ -335,9 +369,11 @@ cc_library( ":hlo_execution_profiler", ":infeed_manager", ":ir_emission_utils", + ":nccl_all_reduce_thunk", ":outfeed_manager", ":partition_assignment", ":stream_assignment", + ":thunk", "//tensorflow/compiler/xla:array2d", "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:shape_tree", diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc index 4e5b86adb19..cacf6b97315 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc +++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc @@ -54,6 +54,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/gpu/ir_emitter_context.h" #include "tensorflow/compiler/xla/service/gpu/kernel_thunk.h" #include "tensorflow/compiler/xla/service/gpu/memset_thunk.h" +#include "tensorflow/compiler/xla/service/gpu/nccl_all_reduce_thunk.h" #include "tensorflow/compiler/xla/service/gpu/outfeed_thunk.h" #include "tensorflow/compiler/xla/service/gpu/parallel_loop_emitter.h" #include "tensorflow/compiler/xla/service/gpu/partition_assignment.h" @@ -74,6 +75,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/llvm_ir/sort_util.h" #include "tensorflow/compiler/xla/service/llvm_ir/tuple_ops.h" #include "tensorflow/compiler/xla/service/name_uniquer.h" +#include "tensorflow/compiler/xla/service/pattern_matcher.h" #include "tensorflow/compiler/xla/service/while_loop_analysis.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status_macros.h" @@ -102,6 +104,8 @@ using absl::StrCat; using llvm_ir::IrArray; using llvm_ir::IrName; +namespace m = match; + // If a dimensions is smaller than this, untiled transposition may be more // efficient. const int64 kMinDimensionToTransposeTiled = 16; @@ -1318,11 +1322,55 @@ Status IrEmitterUnnested::HandleTupleSelect(HloInstruction* tuple_select) { return IrEmitter::HandleTupleSelect(tuple_select); } +namespace { + +bool IsScalarAddComputation(HloComputation* computation) { + return Match(computation->root_instruction(), + m::AddAnyOrder(m::Parameter(0), m::Parameter(1)) + .WithShape(m::Shape().IsEffectiveScalar())); +} + +} // namespace + Status IrEmitterUnnested::HandleAllReduce(HloInstruction* crs) { + VLOG(2) << "AllReduce; replica count: " << hlo_module_config_.replica_count() + << "; operand count: " << crs->operand_count() + << "; NCCL is enabled: " << NcclAllReduceThunk::NcclIsEnabled(); + + // Note the replica_count == 1 case is handled via device-to-device copy + // below. + bool should_use_nccl_thunk = + hlo_module_config_.replica_count() > 1 && + crs->IsCrossReplicaAllReduce() && + crs->operand_count() == 1 && // One array to reduce. + crs->operand(0)->shape().element_type() == F32 && + // Check the computation is a summation. 
+      IsScalarAddComputation(crs->to_apply());
+
+  if (should_use_nccl_thunk) {
+    CHECK(crs->operand(0)->shape().IsArray())
+        << "Operands to all-reduce must be arrays: " << crs->ToString();
+    AddThunkToThunkSequence(absl::make_unique<NcclAllReduceThunk>(
+        /*replica_count=*/hlo_module_config_.replica_count(),
+        /*elements=*/ShapeUtil::ElementsIn(crs->operand(0)->shape()),
+        /*source_address=*/GetAllocationSlice(*crs->operand(0)),
+        /*destination_buffer=*/GetAllocationSlice(*crs), crs));
+    return Status::OK();
+  }
+
   if (hlo_module_config_.replica_count() != 1) {
-    // TODO(b/33011107): Support nontrivial cross replica sum on GPU.
-    return Unimplemented(
-        "AllReduce with >1 replica is not implemented on GPU.");
+    // TODO(b/33011107): Support more AllReduce configurations on GPU.
+    string message = absl::StrFormat(
+        "Requested AllReduce not implemented on GPU; replica_count: %d; "
+        "operand_count: %d; IsCrossReplicaAllReduce: %d; NCCL support: %d",
+        hlo_module_config_.replica_count(), crs->operand_count(),
+        crs->IsCrossReplicaAllReduce(), NcclAllReduceThunk::NcclIsEnabled());
+    if (crs->operand_count() > 0) {
+      absl::StrAppendFormat(
+          &message, "; first operand array element-type: %s",
+          PrimitiveType_Name(crs->operand(0)->shape().element_type()));
+    }
+    return Unimplemented("%s", message);
   }
 
   // CRS with one operand and one replica is simply the identity function.
diff --git a/tensorflow/compiler/xla/service/gpu/nccl_all_reduce_thunk.cc b/tensorflow/compiler/xla/service/gpu/nccl_all_reduce_thunk.cc
new file mode 100644
index 00000000000..3051db3af4a
--- /dev/null
+++ b/tensorflow/compiler/xla/service/gpu/nccl_all_reduce_thunk.cc
@@ -0,0 +1,356 @@
+/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/service/gpu/nccl_all_reduce_thunk.h"
+
+#include "tensorflow/compiler/xla/util.h"
+
+#if GOOGLE_CUDA
+#include "absl/synchronization/blocking_counter.h"
+#include "third_party/nccl/nccl.h"
+#include "tensorflow/core/lib/core/blocking_counter.h"
+#include "tensorflow/core/platform/mutex.h"
+#include "tensorflow/stream_executor/cuda/cuda_activation.h"
+#endif
+
+namespace xla {
+namespace gpu {
+
+/* static */ bool NcclAllReduceThunk::NcclIsEnabled() {
+#if GOOGLE_CUDA
+  return true;
+#else
+  return false;
+#endif
+}
+
+#if GOOGLE_CUDA
+namespace {
+
+// GPU-replica-driving host threads (i.e. the threads that call
+// GpuExecutable::Execute) build up this structure to describe their
+// participating replica, and then call to
+// GlobalRendezvousManager::SubmitParticipant.
+struct ParticipantData {
+  // Number of replicas participating in the AllReduce.
+  int64 replica_count;
+
+  int64 element_count;
+  int64 device_ordinal;
+  int64 generation_counter;
+
+  // TODO(b/125951860): We should vet that we're buffer allocating such that
+  // source_buffer == destination_buffer if that avoids a NCCL copy (will
+  // depend on how well the NCCL in-place implementation performs vs the
+  // out-of-place implementation).
+  se::DeviceMemoryBase source_data;
+  se::DeviceMemoryBase destination_data;
+  se::Stream* stream;
+
+  NcclAllReduceThunk* originator;
+
+  string ToString() const {
+    return absl::StrFormat(
+        "ParticipantData{replica_count=%d, element_count=%d, "
+        "device_ordinal=%d, generation_counter=%d, stream=%p, originator=%p}",
+        replica_count, element_count, device_ordinal, generation_counter,
+        stream, originator);
+  }
+};
+
+// Class that gets instantiated as a singleton in GetGlobalRendezvous() to
+// coordinate participating threads in performing an AllReduce operation.
+//
+// This manager is responsible for establishing communication channels and
+// ultimately enqueueing the NCCL library operation onto the participating
+// streams.
+class GlobalRendezvousManager {
+ public:
+  // The GpuExecutable-executing threads call this in order to a) establish the
+  // all-reduce rendezvous and b) enqueue the AllReduce operation on the caller
+  // thread's associated stream (given in "participant").
+  //
+  // Implementation note: since the rendezvous we're creating here is global, we
+  // try to be paranoid about the fact that the *correct* one is happening.  In
+  // an ideal world we'd have some StreamExecutor se::Platform level construct
+  // that we could use for cross-device networking primitives (e.g. via a
+  // NetworkSupport interface) that could be shared between TensorFlow and XLA,
+  // but this is a reasonable stopgap measure to get multi-GPU-replica up and
+  // running properly for single-host, single-concurrent-XLA-module usage.
+  Status SubmitParticipant(ParticipantData participant);
+
+  // Returns the current generation number of AllReduce operations.
+  // (Currently one AllReduce operation occurs per generation.)
+  int64 GetCurrentGeneration() {
+    tensorflow::mutex_lock lock(mutex_);
+    return current_generation_;
+  }
+
+ private:
+  // Called by the primary thread to set up the communication links.
+  //
+  // TODO(b/125951860): This performs lots of (presumably) unnecessary host-side
+  // synchronization so that we can be paranoid about semantics in the earliest
+  // implementation. In the limit we should only need to synchronize host
+  // replica threads when the "number of replicas" or "participating device
+  // ordinals" change, to set up a new NCCL "communication" context, at which
+  // point we can enqueue onto device streams without host synchronization in
+  // our code -- this will likely be helpful for "lots of little AllReduce"
+  // cases.
+  Status InitializeCommunicationChannels() EXCLUSIVE_LOCKS_REQUIRED(mutex_);
+
+  // Called when all necessary participants are present; the functionality
+  // that's implemented by all executing threads lives in here.
+  Status DoAllReduce(ParticipantData data, ncclComm_t comm);
+
+  // Puts all state back into a "reset" state for the next generation of
+  // AllReduce requests.
+  void DeinitializeGeneration() EXCLUSIVE_LOCKS_REQUIRED(mutex_) {
+    for (ncclComm_t& comm : comms_) {
+      ncclCommDestroy(comm);
+    }
+    comms_.clear();
+    participants_.clear();
+    current_generation_++;
+    initialized_ = false;
+    done_ = absl::nullopt;
+  }
+
+  tensorflow::mutex mutex_;
+  tensorflow::condition_variable all_participants_present_;
+  tensorflow::condition_variable deinitialized_;
+
+  // Communication handles that correspond to the participants below.
+  std::vector<ncclComm_t> comms_ GUARDED_BY(mutex_);
+
+  Status initialize_status_ GUARDED_BY(mutex_);
+  std::vector<ParticipantData> participants_ GUARDED_BY(mutex_);
+  int64 current_generation_ GUARDED_BY(mutex_) = 0;
+  bool initialized_ GUARDED_BY(mutex_) = false;
+
+  // The participating threads wait for this to count down in order to know we
+  // can begin the teardown process.
+  absl::optional<tensorflow::BlockingCounter> done_;
+};
+
+Status GlobalRendezvousManager::SubmitParticipant(ParticipantData participant) {
+  auto all_participants_present = [this, &participant]()
+                                      EXCLUSIVE_LOCKS_REQUIRED(mutex_) -> bool {
+    return participants_.size() >= participant.replica_count;
+  };
+
+  // We remember the participant index at which we are inserted and use that
+  // same index for referring to auxiliary metadata (e.g. the ncclComm_t handle
+  // index) below.
+  int64 index;
+
+  {
+    tensorflow::mutex_lock lock(mutex_);
+
+    // Spot check for consistent replica counts among submitting threads.
+    if (!participants_.empty() &&
+        (participants_.back().replica_count != participant.replica_count ||
+         participants_.back().originator != participant.originator)) {
+      return InvalidArgument(
+          "Running two XLA modules with AllReduces in parallel is not "
+          "supported. It is possible this is due to a bug where we try to "
+          "run two different AllReduces from the same module at once. "
+          "(Attempted a rendezvous with a different replica count from other "
+          "participants; existing: %s; submitted: %s)",
+          participants_.back().ToString(), participant.ToString());
+    }
+    index = participants_.size();
+    participants_.push_back(participant);
+
+    if (all_participants_present()) {
+      all_participants_present_.notify_all();
+    }
+  }
+
+  // We pull into our thread a) the communication handle and b) whether we're
+  // the "primary" thread for this rendezvous -- the "primary" thread has some
+  // additional responsibilities for setup/teardown.
+  ncclComm_t comm;
+  bool primary;
+
+  {
+    tensorflow::mutex_lock lock(mutex_);
+    while (!all_participants_present()) {
+      // Once all the participants have arrived, all participating threads will
+      // cross this barrier, though only (the first) one will be the "primary".
+      all_participants_present_.wait(lock);
+    }
+
+    // Somebody will be the first -- that thread has some additional
+    // responsibilities.
+    primary = !initialized_;
+
+    CHECK_EQ(participant.generation_counter, current_generation_);
+
+    // Bump the generation counter so the other threads know we've completed
+    // the global rendezvous and have set up the AllReduce.
+    if (primary) {
+      VLOG(3) << "Primary initializing accounting data.";
+      initialized_ = true;
+      done_.emplace(participant.replica_count);
+      initialize_status_ = InitializeCommunicationChannels();
+      VLOG(3) << "Done initializing communication channels; status: "
+              << initialize_status_;
+      if (!initialize_status_.ok()) {
+        DeinitializeGeneration();
+      }
+    }
+
+    if (!initialize_status_.ok()) {
+      // TODO(b/125951860): If this fails once, it will fail forever.
+      return initialize_status_;
+    }
+
+    comm = comms_[index];
+
+    // Drop the lock at the end of scope so other participants may enter.
+  }
+
+  VLOG(3) << "Performing all reduce from device ordinal: "
+          << participant.device_ordinal;
+
+  Status all_reduce_status = DoAllReduce(participant, comm);
+
+  VLOG(3) << "Waiting for all participants to complete enqueue.";
+
+  done_->DecrementCount();
+
+  if (primary) {
+    // Primary thread clears out the AllReduce state when everybody is done to
+    // make it clean-slate for any subsequent AllReduce request (e.g. number of
+    // replicas may change in the next request).
+    //
+    // Note surrounding TODOs for only reinitializing this when the replica
+    // count / participants actually change -- lots of "playing it safe"
+    // happening in this first cut.
+    done_->Wait();
+    VLOG(3) << "All participants completed enqueue.";
+    VLOG(3) << "Primary thread clearing.";
+    tensorflow::mutex_lock lock(mutex_);
+    DeinitializeGeneration();
+    VLOG(3) << "Generation is now: " << current_generation_;
+    deinitialized_.notify_all();
+  } else {
+    VLOG(3) << "Waiting to deinitialize.";
+    tensorflow::mutex_lock lock(mutex_);
+    while (initialized_) {
+      deinitialized_.wait(lock);
+    }
+  }
+
+  VLOG(3) << "Returning status: " << all_reduce_status;
+  return all_reduce_status;
+}
+
+Status GlobalRendezvousManager::InitializeCommunicationChannels() {
+  std::vector<int> ordinals;
+  for (ParticipantData& data : participants_) {
+    ordinals.push_back(data.device_ordinal);
+  }
+  comms_.resize(ordinals.size());
+  VLOG(3) << "Participants: " << participants_.size()
+          << "; initializing comms.";
+  ncclResult_t result = ncclCommInitAll(comms_.data(), comms_.size(),
+                                        /*devlist=*/ordinals.data());
+  if (result != ncclSuccess) {
+    comms_.clear();
+    return InternalError(
+        "Failed to initialize NCCL communication channels for %d participants: "
+        "%s",
+        participants_.size(), ncclGetErrorString(result));
+  }
+  return Status::OK();
+}
+
+Status GlobalRendezvousManager::DoAllReduce(ParticipantData participant,
+                                            ncclComm_t comm) {
+  se::StreamExecutor* executor = participant.stream->parent();
+  se::cuda::ScopedActivateExecutorContext scoped_context(executor);
+  cudaStream_t* cu_stream = reinterpret_cast<cudaStream_t*>(
+      participant.stream->implementation()->GpuStreamMemberHack());
+  VLOG(3) << "Using stream pointer: " << cu_stream
+          << " on device: " << participant.device_ordinal;
+  void* send_buffer = participant.source_data.opaque();
+  void* recv_buffer = participant.destination_data.opaque();
+  ncclResult_t result = ncclAllReduce(send_buffer, recv_buffer,
+                                      /*count=*/participant.element_count,
+                                      /*datatype=*/ncclFloat,
+                                      /*op=*/ncclSum,
+                                      /*comm=*/comm,
+                                      /*stream=*/*cu_stream);
+  TF_RET_CHECK(ncclSuccess == result)
+      << "Failed to perform all-reduce: " << ncclGetErrorString(result);
+
+  VLOG(3) << "Done performing all reduce for ordinal: "
+          << participant.device_ordinal;
+
+  return Status::OK();
+}
+
+static GlobalRendezvousManager* GetGlobalRendezvous() {
+  static auto* manager = new GlobalRendezvousManager;
+  return manager;
+}
+
+}  // namespace
+
+Status NcclAllReduceThunk::ExecuteOnStream(
+    const BufferAllocations& buffer_allocations, se::Stream* stream,
+    HloExecutionProfiler* profiler) {
+  auto* global_rendezvous = GetGlobalRendezvous();
+
+  ParticipantData participant;
+  participant.replica_count = replica_count_;
+  participant.element_count = element_count_;
+  participant.device_ordinal = stream->parent()->device_ordinal();
+  participant.generation_counter = global_rendezvous->GetCurrentGeneration();
+  participant.source_data = buffer_allocations.GetDeviceAddress(source_buffer_);
+  participant.destination_data =
+      buffer_allocations.GetDeviceAddress(destination_buffer_);
+  participant.stream = stream;
+  participant.originator = this;
+
+  return GetGlobalRendezvous()->SubmitParticipant(std::move(participant));
+}
+#else
+
+Status NcclAllReduceThunk::ExecuteOnStream(
+    const BufferAllocations& buffer_allocations, se::Stream* stream,
+    HloExecutionProfiler* profiler) {
+  return Unimplemented(
+      "NCCL support is not available: this binary was not built with a CUDA "
+      "compiler, which is necessary to build the NCCL source library.");
+}
+
+#endif  // GOOGLE_CUDA
+
+NcclAllReduceThunk::NcclAllReduceThunk(
+    int64 replica_count, int64 element_count,
+    const BufferAllocation::Slice& source_buffer,
+    const BufferAllocation::Slice& destination_buffer,
+    const HloInstruction* all_reduce)
+    : Thunk(Thunk::kNcclAllReduce, all_reduce),
+      replica_count_(replica_count),
+      element_count_(element_count),
+      source_buffer_(source_buffer),
+      destination_buffer_(destination_buffer) {}
+
+}  // namespace gpu
+}  // namespace xla
diff --git a/tensorflow/compiler/xla/service/gpu/nccl_all_reduce_thunk.h b/tensorflow/compiler/xla/service/gpu/nccl_all_reduce_thunk.h
new file mode 100644
index 00000000000..1a8d1356c00
--- /dev/null
+++ b/tensorflow/compiler/xla/service/gpu/nccl_all_reduce_thunk.h
@@ -0,0 +1,62 @@
+/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_NCCL_ALL_REDUCE_THUNK_H_
+#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_NCCL_ALL_REDUCE_THUNK_H_
+
+#include "tensorflow/compiler/xla/service/buffer_assignment.h"
+#include "tensorflow/compiler/xla/service/gpu/buffer_allocations.h"
+#include "tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.h"
+#include "tensorflow/compiler/xla/service/gpu/thunk.h"
+#include "tensorflow/compiler/xla/service/hlo_instruction.h"
+#include "tensorflow/core/platform/stream_executor_no_cuda.h"
+#include "tensorflow/core/platform/types.h"
+
+namespace xla {
+namespace gpu {
+
+// Thunk that performs a NCCL-based All-Reduce among CUDA GPU-based replicas.
+class NcclAllReduceThunk : public Thunk {
+ public:
+  // Returns whether NCCL operations appear possible to perform; e.g. if we
+  // haven't done a build with the CUDA compiler enabled, we can't compile the
+  // NCCL header, and thus this will be false.
+  //
+  // When this is false, the ExecuteOnStream() call will simply return a status
+  // error.
+  static bool NcclIsEnabled();
+
+  // TODO(b/125951860): Plumb more datatypes / reduction operators. Initial
+  // implementation is simply F32 summation.
+  NcclAllReduceThunk(int64 replica_count, int64 element_count,
+                     const BufferAllocation::Slice& source_buffer,
+                     const BufferAllocation::Slice& destination_buffer,
+                     const HloInstruction* all_reduce);
+
+  Status ExecuteOnStream(const BufferAllocations& buffer_allocations,
+                         se::Stream* stream,
+                         HloExecutionProfiler* profiler) override;
+
+ private:
+  const int64 replica_count_;
+  const int64 element_count_;
+  const BufferAllocation::Slice source_buffer_;
+  const BufferAllocation::Slice destination_buffer_;
+};
+
+}  // namespace gpu
+}  // namespace xla
+
+#endif  // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_NCCL_ALL_REDUCE_THUNK_H_
diff --git a/tensorflow/compiler/xla/service/gpu/thunk.cc b/tensorflow/compiler/xla/service/gpu/thunk.cc
index a677617727c..2968c9ee39c 100644
--- a/tensorflow/compiler/xla/service/gpu/thunk.cc
+++ b/tensorflow/compiler/xla/service/gpu/thunk.cc
@@ -32,6 +32,8 @@ std::ostream& operator<<(std::ostream& os, Thunk::Kind kind) {
       return os << "kCudnnBatchNormForwardInference";
     case Thunk::kCudnnBatchNormForwardTraining:
       return os << "kCudnnBatchNormForwardTraining";
+    case Thunk::kNcclAllReduce:
+      return os << "kNcclAllReduce";
     case Thunk::kFft:
       return os << "kFft";
     case Thunk::kGemm:
diff --git a/tensorflow/compiler/xla/service/gpu/thunk.h b/tensorflow/compiler/xla/service/gpu/thunk.h
index bc69af897a0..728aef82a76 100644
--- a/tensorflow/compiler/xla/service/gpu/thunk.h
+++ b/tensorflow/compiler/xla/service/gpu/thunk.h
@@ -48,6 +48,7 @@ class Thunk {
     kCudnnBatchNormBackward,
     kCudnnBatchNormForwardInference,
     kCudnnBatchNormForwardTraining,
+    kNcclAllReduce,
     kFft,
     kGemm,
     kInfeed,
diff --git a/tensorflow/compiler/xla/service/hlo_runner.cc b/tensorflow/compiler/xla/service/hlo_runner.cc
index 5a5401e3513..8f44e1b37ee 100644
--- a/tensorflow/compiler/xla/service/hlo_runner.cc
+++ b/tensorflow/compiler/xla/service/hlo_runner.cc
@@ -27,6 +27,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/transfer_manager.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/core/common_runtime/eigen_thread_pool.h" +#include "tensorflow/core/lib/core/blocking_counter.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/types.h" @@ -269,8 +270,8 @@ StatusOr HloRunner::ExecuteWithDeviceBuffers( } StatusOr> HloRunner::ExecuteReplicated( - std::unique_ptr module, - const ReplicatedExecuteOptions& options) { + std::unique_ptr module, const ReplicatedExecuteOptions& options, + bool use_threads) { TF_ASSIGN_OR_RETURN( std::unique_ptr executable, CreateExecutable(std::move(module), options.run_hlo_passes)); @@ -369,9 +370,39 @@ StatusOr> HloRunner::ExecuteReplicated( } LOG(INFO) << "Replicated execution started"; - TF_ASSIGN_OR_RETURN(std::vector results, - executable->ExecuteOnStreams(service_run_options, - argument_buffer_slices)); + std::vector results; + if (!use_threads) { + TF_ASSIGN_OR_RETURN(results, + executable->ExecuteOnStreams(service_run_options, + argument_buffer_slices)); + } else { + tensorflow::mutex mutex; + std::vector> thread_results( + options.num_replicas); + { + LOG(INFO) << "Creating thread pool for " << options.num_replicas + << " replicas"; + tensorflow::thread::ThreadPool pool(tensorflow::Env::Default(), + "replicas", options.num_replicas); + for (int64 i = 0; i < options.num_replicas; ++i) { + pool.Schedule([&, i] { + auto result = executable->ExecuteOnStream( + &service_run_options[i], argument_buffer_slices[i], nullptr); + tensorflow::mutex_lock lock(mutex); + thread_results[i] = std::move(result); + }); + } + + // Note: the thread pool destructor guarantees it completes all work + // before we leave this scope. + } + for (auto& thread_result : thread_results) { + if (!thread_result.ok()) { + return thread_result.status(); + } + results.push_back(std::move(thread_result).ValueOrDie()); + } + } LOG(INFO) << "Replicated execution terminated"; std::vector exec_results; diff --git a/tensorflow/compiler/xla/service/hlo_runner.h b/tensorflow/compiler/xla/service/hlo_runner.h index 098989cd4c7..88a137e6452 100644 --- a/tensorflow/compiler/xla/service/hlo_runner.h +++ b/tensorflow/compiler/xla/service/hlo_runner.h @@ -165,9 +165,13 @@ class HloRunner { // Executes a given HLO module into a set of replicas, and returns a map // with the replica number as key, and the corresponding returned literal as // value. + // + // use_threads indicates whether this replicated computation will be executed + // with a thread-per-replica, vs using an implicitly async call such as + // Executable::ExecuteOnStreams. StatusOr> ExecuteReplicated( std::unique_ptr module, - const ReplicatedExecuteOptions& options); + const ReplicatedExecuteOptions& options, bool use_threads = false); // If backend is not created in the constructor, creates and returns the // default backend. If creation fails, crashes the program. 
diff --git a/tensorflow/compiler/xla/tests/BUILD b/tensorflow/compiler/xla/tests/BUILD
index 7158708e9c3..79a5b7539db 100644
--- a/tensorflow/compiler/xla/tests/BUILD
+++ b/tensorflow/compiler/xla/tests/BUILD
@@ -1601,6 +1601,39 @@ xla_test(
     ],
 )
 
+xla_test(
+    name = "multi_device_all_reduce_test",
+    srcs = ["multi_device_all_reduce_test.cc"],
+    backends = ["gpu"],
+    tags = [
+        "manual",
+        "multi_gpu",
+        "no_oss",
+        "notap",
+    ],
+    deps = [
+        "//tensorflow/compiler/xla:literal",
+        "//tensorflow/compiler/xla:shape_util",
+        "//tensorflow/compiler/xla:test",
+        "//tensorflow/compiler/xla:test_helpers",
+        "//tensorflow/compiler/xla:xla_data_proto",
+        "//tensorflow/compiler/xla/client:local_client",
+        "//tensorflow/compiler/xla/service:hlo",
+        "//tensorflow/compiler/xla/service:hlo_parser",
+        "//tensorflow/compiler/xla/service:hlo_runner",
+        "//tensorflow/compiler/xla/tests:client_library_test_base",
+        "//tensorflow/compiler/xla/tests:hlo_test_base",
+        "//tensorflow/compiler/xla/tests:literal_test_util",
+        "//tensorflow/compiler/xla/tests:test_utils",
+        "//tensorflow/compiler/xla/tests:xla_internal_test_main",
+        "//tensorflow/core:lib",
+        "//tensorflow/core:test",
+        "@com_google_absl//absl/memory",
+        "@com_google_absl//absl/strings",
+        "@com_google_absl//absl/types:span",
+    ],
+)
+
 xla_test(
     name = "bitcast_convert_test",
     srcs = ["bitcast_convert_test.cc"],
diff --git a/tensorflow/compiler/xla/tests/hlo_test_base.cc b/tensorflow/compiler/xla/tests/hlo_test_base.cc
index 0151981ef16..62e2b465cfe 100644
--- a/tensorflow/compiler/xla/tests/hlo_test_base.cc
+++ b/tensorflow/compiler/xla/tests/hlo_test_base.cc
@@ -207,13 +207,14 @@ Literal HloTestBase::ExecuteAndTransfer(std::unique_ptr<HloModule> module,
 
 StatusOr<std::vector<Literal>> HloTestBase::ExecuteReplicated(
     std::unique_ptr<HloModule> module, absl::Span<Literal* const> arguments,
-    int64 num_replicas) {
+    int64 num_replicas, bool use_threads) {
   HloRunner::ReplicatedExecuteOptions options;
   options.num_replicas = num_replicas;
   for (auto argument : arguments) {
    options.arguments.push_back(argument);
   }
-  return test_runner_.ExecuteReplicated(std::move(module), options);
+  return test_runner_.ExecuteReplicated(std::move(module), options,
+                                        use_threads);
 }
 
 StatusOr<std::unique_ptr<HloModule>> HloTestBase::MakeReferenceModule(
diff --git a/tensorflow/compiler/xla/tests/hlo_test_base.h b/tensorflow/compiler/xla/tests/hlo_test_base.h
index 3c2bcbb5df5..df9c29a186f 100644
--- a/tensorflow/compiler/xla/tests/hlo_test_base.h
+++ b/tensorflow/compiler/xla/tests/hlo_test_base.h
@@ -174,9 +174,13 @@ class HloTestBase : public ::testing::Test {
                              absl::Span<Literal* const> arguments);
 
   // Executes the given module on multiple replicas.
+  //
+  // use_threads indicates whether this replicated computation will be executed
+  // with a thread-per-replica, vs using an implicitly async call such as
+  // Executable::ExecuteOnStreams.
   StatusOr<std::vector<Literal>> ExecuteReplicated(
       std::unique_ptr<HloModule> module, absl::Span<Literal* const> arguments,
-      int64 num_replicas);
+      int64 num_replicas, bool use_threads);
 
   // Executes the given hlo module on two backends and compares results.
   //
diff --git a/tensorflow/compiler/xla/tests/multi_device_all_reduce_test.cc b/tensorflow/compiler/xla/tests/multi_device_all_reduce_test.cc
new file mode 100644
index 00000000000..1513d89ba9c
--- /dev/null
+++ b/tensorflow/compiler/xla/tests/multi_device_all_reduce_test.cc
@@ -0,0 +1,56 @@
+/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/literal.h"
+#include "tensorflow/compiler/xla/service/hlo_parser.h"
+#include "tensorflow/compiler/xla/shape_util.h"
+#include "tensorflow/compiler/xla/test.h"
+#include "tensorflow/compiler/xla/test_helpers.h"
+#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
+#include "tensorflow/compiler/xla/tests/test_macros.h"
+
+namespace xla {
+namespace {
+
+class MultiDeviceAllReduceTest : public HloTestBase {};
+
+XLA_TEST_F(MultiDeviceAllReduceTest, TwoReplicasOneOperand) {
+  const char* module_str = R"(
+  HloModule test
+
+  add {
+    x = f32[] parameter(0)
+    y = f32[] parameter(1)
+    add = f32[] add(x, y)
+  }
+
+  ENTRY test_computation {
+    p = f32[3] parameter(0)
+    ROOT crs = f32[3] all-reduce(p), to_apply=add
+  })";
+  auto config = GetModuleConfigForTest();
+  config.set_replica_count(2);
+  auto module = ParseHloString(module_str, config).ValueOrDie();
+  auto literal = LiteralUtil::CreateR1<float>({1, 2, 3});
+  auto expected = LiteralUtil::CreateR1<float>({2, 4, 6});
+  TF_ASSERT_OK_AND_ASSIGN(std::vector<Literal> results,
+                          ExecuteReplicated(std::move(module), {&literal}, 2,
+                                            /*use_threads=*/true));
+  EXPECT_EQ(expected, results[0]);
+  EXPECT_EQ(expected, results[1]);
+}
+
+}  // namespace
+}  // namespace xla
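Aside (not part of the patch): for reference, a minimal standalone sketch of the NCCL call sequence the thunk wraps -- one communicator per device ordinal from ncclCommInitAll, then a float32 ncclAllReduce with ncclSum on each device's stream, the same datatype/op combination the emitter currently selects. This is not the thunk implementation: it assumes two visible GPUs, drives both devices from one thread (hence the ncclGroupStart/ncclGroupEnd bracket, which the thunk's one-thread-per-replica design does not need), links against -lnccl and -lcudart, and trims error handling to a macro.

#include <cstdio>
#include <vector>

#include <cuda_runtime.h>
#include <nccl.h>

#define CHECK_CUDA(cmd)                                   \
  do {                                                    \
    cudaError_t e = (cmd);                                \
    if (e != cudaSuccess) {                               \
      std::printf("CUDA error: %s\n", cudaGetErrorString(e)); \
      return 1;                                           \
    }                                                     \
  } while (0)

#define CHECK_NCCL(cmd)                                   \
  do {                                                    \
    ncclResult_t r = (cmd);                               \
    if (r != ncclSuccess) {                               \
      std::printf("NCCL error: %s\n", ncclGetErrorString(r)); \
      return 1;                                           \
    }                                                     \
  } while (0)

int main() {
  constexpr int kNumDevices = 2;   // Mirrors replica_count=2 in the test above.
  constexpr size_t kElements = 3;  // Mirrors the f32[3] operand.

  std::vector<int> devices = {0, 1};
  std::vector<ncclComm_t> comms(kNumDevices);
  // One communicator per device ordinal, as in
  // GlobalRendezvousManager::InitializeCommunicationChannels().
  CHECK_NCCL(ncclCommInitAll(comms.data(), kNumDevices, devices.data()));

  std::vector<float*> send(kNumDevices), recv(kNumDevices);
  std::vector<cudaStream_t> streams(kNumDevices);
  float host_input[kElements] = {1.0f, 2.0f, 3.0f};
  for (int i = 0; i < kNumDevices; ++i) {
    CHECK_CUDA(cudaSetDevice(devices[i]));
    CHECK_CUDA(cudaMalloc(&send[i], kElements * sizeof(float)));
    CHECK_CUDA(cudaMalloc(&recv[i], kElements * sizeof(float)));
    CHECK_CUDA(cudaMemcpy(send[i], host_input, kElements * sizeof(float),
                          cudaMemcpyHostToDevice));
    CHECK_CUDA(cudaStreamCreate(&streams[i]));
  }

  // The per-participant enqueue, as in GlobalRendezvousManager::DoAllReduce().
  // Group calls are required here only because a single thread launches on
  // multiple communicators.
  CHECK_NCCL(ncclGroupStart());
  for (int i = 0; i < kNumDevices; ++i) {
    CHECK_CUDA(cudaSetDevice(devices[i]));
    CHECK_NCCL(ncclAllReduce(send[i], recv[i], kElements, ncclFloat, ncclSum,
                             comms[i], streams[i]));
  }
  CHECK_NCCL(ncclGroupEnd());

  for (int i = 0; i < kNumDevices; ++i) {
    CHECK_CUDA(cudaSetDevice(devices[i]));
    CHECK_CUDA(cudaStreamSynchronize(streams[i]));
    float host_result[kElements];
    CHECK_CUDA(cudaMemcpy(host_result, recv[i], kElements * sizeof(float),
                          cudaMemcpyDeviceToHost));
    // With two identical {1, 2, 3} inputs, every device should see {2, 4, 6},
    // matching the expectation in MultiDeviceAllReduceTest.
    std::printf("device %d result: %f %f %f\n", devices[i], host_result[0],
                host_result[1], host_result[2]);
    CHECK_CUDA(cudaFree(send[i]));
    CHECK_CUDA(cudaFree(recv[i]));
    CHECK_CUDA(cudaStreamDestroy(streams[i]));
    ncclCommDestroy(comms[i]);
  }
  return 0;
}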