[XLA:GPU] Add NCCL-based AllReduce replica support to XLA.

Enabling this requires a CUDA-configured build, since the NCCL library can
only be compiled when CUDA is enabled. In non-CUDA-configured builds the NCCL
thunk returns an error.

This uses a super-conservative (and quite likely overkill) concurrency
approach. A follow-up CL should optimize for the common case, where we enqueue
many operations with the same replica count onto a stream without
synchronizing, and only force host-thread synchronization when the number of
replicas changes (sketched below).
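
To make that follow-up concrete, here is a rough sketch (not part of this
change) of caching the NCCL communicator clique keyed by the participating
device ordinals, so host threads only pay the setup cost when that set
changes; the names CommClique and CommCliqueCache are hypothetical.

#include <map>
#include <vector>

#include "absl/memory/memory.h"
#include "absl/synchronization/mutex.h"
#include "third_party/nccl/nccl.h"

namespace {

// One ncclComm_t per participating device, created together via
// ncclCommInitAll and destroyed together.
struct CommClique {
  std::vector<ncclComm_t> comms;
  ~CommClique() {
    for (ncclComm_t comm : comms) ncclCommDestroy(comm);
  }
};

class CommCliqueCache {
 public:
  // Returns the cached clique for `ordinals`, creating it on first use. The
  // lock is held only for lookup/creation, not while callers enqueue
  // ncclAllReduce onto their streams.
  CommClique* GetOrCreate(const std::vector<int>& ordinals) {
    absl::MutexLock lock(&mu_);
    auto it = cliques_.find(ordinals);
    if (it == cliques_.end()) {
      auto clique = absl::make_unique<CommClique>();
      clique->comms.resize(ordinals.size());
      ncclResult_t result =
          ncclCommInitAll(clique->comms.data(),
                          static_cast<int>(ordinals.size()),
                          /*devlist=*/ordinals.data());
      if (result != ncclSuccess) {
        clique->comms.clear();  // Real code would return a Status here.
        return nullptr;
      }
      it = cliques_.emplace(ordinals, std::move(clique)).first;
    }
    return it->second.get();
  }

 private:
  absl::Mutex mu_;
  // Guarded by mu_.
  std::map<std::vector<int>, std::unique_ptr<CommClique>> cliques_;
};

}  // namespace

With something like this, the per-call rendezvous would only need to agree on
which device ordinals participate before looking up the cached clique.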

In the future this should likely be unified with NcclManager in
tensorflow/core/nccl. For now it is kept separate because XLA does not use
TensorFlow's EventMgr-style memory allocation strategy, so that library's
memory strategy would likely need to be parameterized first; at that point it
should be reasonable to scoop out this ~200-line implementation in the .cc
file and replace it with the NcclManager abstraction, unifying the two
implementations.

PiperOrigin-RevId: 235632126
Chris Leary 2019-02-25 17:23:06 -08:00 committed by TensorFlower Gardener
parent 5bce34fe7d
commit 9c95751e87
13 changed files with 654 additions and 19 deletions

tensorflow/compiler/xla/service/BUILD

@@ -3470,10 +3470,13 @@ cc_library(
srcs = ["hlo_runner.cc"],
hdrs = ["hlo_runner.h"],
deps = [
":backend",
":compiler",
":computation_placer",
":executable",
":hlo",
":hlo_module_group",
":hlo_parser",
":transfer_manager",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:status_macros",
@@ -3481,11 +3484,9 @@ cc_library(
"//tensorflow/compiler/xla:types",
"//tensorflow/compiler/xla:util",
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/compiler/xla/service:backend",
"//tensorflow/compiler/xla/service:compiler",
"//tensorflow/compiler/xla/service:hlo_parser",
"//tensorflow/core:core_cpu_internal",
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
"//tensorflow/core:stream_executor_no_cuda",
"//third_party/eigen3",
"@com_google_absl//absl/memory",

tensorflow/compiler/xla/service/gpu/BUILD

@@ -7,7 +7,7 @@ load(
"//tensorflow/core:platform/default/build_config_root.bzl",
"tf_cuda_tests_tags",
)
load("//tensorflow:tensorflow.bzl", "tf_cc_test")
load("//tensorflow:tensorflow.bzl", "tf_cc_test", "tf_cuda_library", "if_cuda")
licenses(["notice"]) # Apache 2.0
@@ -156,7 +156,6 @@ cc_library(
"ir_emitter_unnested.h",
],
deps = [
":backend_configs",
":buffer_allocations",
":cudnn_conv_runner",
":elemental_ir_emitter",
@@ -164,8 +163,10 @@ cc_library(
":gpu_executable",
":hlo_to_ir_bindings",
":ir_emission_utils",
":nccl_all_reduce_thunk",
":parallel_loop_emitter",
":partition_assignment",
":thunk",
"//tensorflow/compiler/xla:literal",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:status_macros",
@@ -179,6 +180,7 @@ cc_library(
"//tensorflow/compiler/xla/service:hlo",
"//tensorflow/compiler/xla/service:hlo_casting_utils",
"//tensorflow/compiler/xla/service:name_uniquer",
"//tensorflow/compiler/xla/service:pattern_matcher",
"//tensorflow/compiler/xla/service:while_loop_analysis",
"//tensorflow/compiler/xla/service/llvm_ir:buffer_assignment_util",
"//tensorflow/compiler/xla/service/llvm_ir:dynamic_update_slice_util",
@@ -287,6 +289,40 @@ cc_library(
],
)
cc_library(
name = "thunk",
srcs = ["thunk.cc"],
hdrs = ["thunk.h"],
deps = [
":buffer_allocations",
":hlo_execution_profiler",
"//tensorflow/compiler/xla/service:hlo",
"//tensorflow/core:lib",
"//tensorflow/core:stream_executor_no_cuda",
],
)
tf_cuda_library(
name = "nccl_all_reduce_thunk",
srcs = ["nccl_all_reduce_thunk.cc"],
hdrs = ["nccl_all_reduce_thunk.h"],
deps = [
":buffer_allocations",
":hlo_execution_profiler",
":thunk",
"@com_google_absl//absl/synchronization",
"//tensorflow/compiler/xla:util",
"//tensorflow/compiler/xla/service:buffer_assignment",
"//tensorflow/compiler/xla/service:hlo",
"//tensorflow/core:lib",
"//tensorflow/core:stream_executor_no_cuda",
"//tensorflow/stream_executor/cuda:cuda_activation",
"//tensorflow/stream_executor/cuda:cuda_gpu_executor",
] + if_cuda([
"@local_config_nccl//:nccl",
]),
)
cc_library(
name = "gpu_executable",
srcs = [
@@ -303,7 +339,6 @@ cc_library(
"memset_thunk.cc",
"outfeed_thunk.cc",
"sequential_thunk.cc",
"thunk.cc",
"thunk_schedule.cc",
"triangular_solve_thunk.cc",
"tuple_thunk.cc",
@@ -323,7 +358,6 @@ cc_library(
"memset_thunk.h",
"outfeed_thunk.h",
"sequential_thunk.h",
"thunk.h",
"thunk_schedule.h",
"triangular_solve_thunk.h",
"tuple_thunk.h",
@@ -335,9 +369,11 @@ cc_library(
":hlo_execution_profiler",
":infeed_manager",
":ir_emission_utils",
":nccl_all_reduce_thunk",
":outfeed_manager",
":partition_assignment",
":stream_assignment",
":thunk",
"//tensorflow/compiler/xla:array2d",
"//tensorflow/compiler/xla:literal",
"//tensorflow/compiler/xla:shape_tree",

tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc

@@ -54,6 +54,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/gpu/ir_emitter_context.h"
#include "tensorflow/compiler/xla/service/gpu/kernel_thunk.h"
#include "tensorflow/compiler/xla/service/gpu/memset_thunk.h"
#include "tensorflow/compiler/xla/service/gpu/nccl_all_reduce_thunk.h"
#include "tensorflow/compiler/xla/service/gpu/outfeed_thunk.h"
#include "tensorflow/compiler/xla/service/gpu/parallel_loop_emitter.h"
#include "tensorflow/compiler/xla/service/gpu/partition_assignment.h"
@@ -74,6 +75,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/llvm_ir/sort_util.h"
#include "tensorflow/compiler/xla/service/llvm_ir/tuple_ops.h"
#include "tensorflow/compiler/xla/service/name_uniquer.h"
#include "tensorflow/compiler/xla/service/pattern_matcher.h"
#include "tensorflow/compiler/xla/service/while_loop_analysis.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
@@ -102,6 +104,8 @@ using absl::StrCat;
using llvm_ir::IrArray;
using llvm_ir::IrName;
namespace m = match;
// If a dimension is smaller than this, untiled transposition may be more
// efficient.
const int64 kMinDimensionToTransposeTiled = 16;
@@ -1318,11 +1322,55 @@ Status IrEmitterUnnested::HandleTupleSelect(HloInstruction* tuple_select) {
return IrEmitter::HandleTupleSelect(tuple_select);
}
namespace {
bool IsScalarAddComputation(HloComputation* computation) {
return Match(computation->root_instruction(),
m::AddAnyOrder(m::Parameter(0), m::Parameter(1))
.WithShape(m::Shape().IsEffectiveScalar()));
}
} // namespace
Status IrEmitterUnnested::HandleAllReduce(HloInstruction* crs) {
VLOG(2) << "AllReduce; replica count: " << hlo_module_config_.replica_count()
<< "; operand count: " << crs->operand_count()
<< "; NCCL is enabled: " << NcclAllReduceThunk::NcclIsEnabled();
// Note the replica_count == 1 case is handled via device-to-device copy
// below.
bool should_use_nccl_thunk =
hlo_module_config_.replica_count() > 1 &&
crs->IsCrossReplicaAllReduce() &&
crs->operand_count() == 1 && // One array to reduce.
crs->operand(0)->shape().element_type() == F32 &&
// Check the computation is a summation.
IsScalarAddComputation(crs->to_apply());
if (should_use_nccl_thunk) {
CHECK(crs->operand(0)->shape().IsArray())
<< "Operands to all-reduce must be arrays: " << crs->ToString();
AddThunkToThunkSequence(absl::make_unique<NcclAllReduceThunk>(
/*replica_count=*/hlo_module_config_.replica_count(),
/*elements=*/ShapeUtil::ElementsIn(crs->operand(0)->shape()),
/*source_address=*/GetAllocationSlice(*crs->operand(0)),
/*destination_buffer=*/GetAllocationSlice(*crs), crs));
return Status::OK();
}
if (hlo_module_config_.replica_count() != 1) {
// TODO(b/33011107): Support nontrivial cross replica sum on GPU.
return Unimplemented(
"AllReduce with >1 replica is not implemented on GPU.");
// TODO(b/33011107): Support more AllReduce configurations on GPU.
string message = absl::StrFormat(
"Requested AllReduce not implemented on GPU; replica_count: %d; "
"operand_count: %d; IsCrossReplicaAllReduce: %d; NCCL support: %d",
hlo_module_config_.replica_count(), crs->operand_count(),
crs->IsCrossReplicaAllReduce(), NcclAllReduceThunk::NcclIsEnabled());
if (crs->operand_count() > 0) {
absl::StrAppendFormat(
&message, "; first operand array element-type: %s",
PrimitiveType_Name(crs->operand(0)->shape().element_type()));
}
return Unimplemented("%s", message);
}
// CRS with one operand and one replica is simply the identity function.

tensorflow/compiler/xla/service/gpu/nccl_all_reduce_thunk.cc

@@ -0,0 +1,356 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/xla/service/gpu/nccl_all_reduce_thunk.h"
#include "tensorflow/compiler/xla/util.h"
#if GOOGLE_CUDA
#include "absl/synchronization/blocking_counter.h"
#include "third_party/nccl/nccl.h"
#include "tensorflow/core/lib/core/blocking_counter.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/stream_executor/cuda/cuda_activation.h"
#endif
namespace xla {
namespace gpu {
/* static */ bool NcclAllReduceThunk::NcclIsEnabled() {
#if GOOGLE_CUDA
return true;
#else
return false;
#endif
}
#if GOOGLE_CUDA
namespace {
// GPU-replica-driving host threads (i.e. the threads that call
// GpuExecutable::Execute) build up this structure to describe their
// participating replica, and then call into
// GlobalRendezvousManager::SubmitParticipant.
struct ParticipantData {
// Number of replicas participating in the AllReduce.
int64 replica_count;
int64 element_count;
int64 device_ordinal;
int64 generation_counter;
// TODO(b/125951860): We should vet that we're buffer allocating such that
// source_buffer == destination_buffer if that avoids an NCCL copy (will depend
// on how well the NCCL in-place implementation performs vs the out-of-place
// implementation).
se::DeviceMemoryBase source_data;
se::DeviceMemoryBase destination_data;
se::Stream* stream;
NcclAllReduceThunk* originator;
string ToString() const {
return absl::StrFormat(
"ParticipantData{replica_count=%d, element_count=%d, "
"device_ordinal=%d, generation_counter=%d, stream=%p, originator=%p}",
replica_count, element_count, device_ordinal, generation_counter,
stream, originator);
}
};
// Class that gets instantiated as a singleton in GetGlobalRendezvous() to
// coordinate participating threads in performing an AllReduce operation.
//
// This manager is responsible for establishing communication channels and
// ultimately enqueueing the NCCL library operation onto the participating
// streams.
class GlobalRendezvousManager {
public:
// The GpuExecutable-executing threads call this in order to a) establish the
// all-reduce rendezvous and b) enqueue the AllReduce operation on the caller
// thread's associated stream (given in "participant").
//
// Implementation note: since the rendezvous we're creating here is global, we
// try to be paranoid that the *correct* rendezvous is the one happening. In
// an ideal world we'd have some StreamExecutor se::Platform level construct
// that we could use for cross-device networking primitives (e.g. via a
// NetworkSupport interface) that could be shared between TensorFlow and XLA,
// but this is a reasonable stopgap measure to get multi-GPU-replica up and
// running properly for single-host, single-concurrent-XLA-module usage.
Status SubmitParticipant(ParticipantData participant);
// Returns the current generation number of AllReduce operations.
// (Currently one AllReduce operation occurs per generation.)
int64 GetCurrentGeneration() {
tensorflow::mutex_lock lock(mutex_);
return current_generation_;
}
private:
// Called by the primary thread to set up the communication links.
//
// TODO(b/125951860): This performs lots of (presumably) unnecessary host-side
// synchronization so that we can be paranoid about semantics in the earliest
// implementation. In the limit we should only need to synchronize host
// replica threads when the "number of replicas" or "participating device
// ordinals" change, to set up a new NCCL "communication" context, at which
// point we can enqueue onto device streams without host synchronization in
// our code -- this will likely be helpful for "lots of little AllReduce"
// cases.
Status InitializeCommunicationChannels() EXCLUSIVE_LOCKS_REQUIRED(mutex_);
// Called when all necessary participants are present; the functionality
// that's executed by all participating threads lives in here.
Status DoAllReduce(ParticipantData data, ncclComm_t comm);
// Puts all state back into a "reset" state for the next generation of
// AllReduce requests.
void DeinitializeGeneration() EXCLUSIVE_LOCKS_REQUIRED(mutex_) {
for (ncclComm_t& comm : comms_) {
ncclCommDestroy(comm);
}
comms_.clear();
participants_.clear();
current_generation_++;
initialized_ = false;
done_ = absl::nullopt;
}
tensorflow::mutex mutex_;
tensorflow::condition_variable all_participants_present_;
tensorflow::condition_variable deinitialized_;
// Communication handles that correspond to the participants below.
std::vector<ncclComm_t> comms_ GUARDED_BY(mutex_);
Status initialize_status_ GUARDED_BY(mutex_);
std::vector<ParticipantData> participants_ GUARDED_BY(mutex_);
int64 current_generation_ GUARDED_BY(mutex_) = 0;
bool initialized_ GUARDED_BY(mutex_) = false;
// The participating threads wait for this to count down in order to know we
// can begin the teardown process.
absl::optional<tensorflow::BlockingCounter> done_;
};
Status GlobalRendezvousManager::SubmitParticipant(ParticipantData participant) {
auto all_participants_present = [this, &participant]()
EXCLUSIVE_LOCKS_REQUIRED(mutex_) -> bool {
return participants_.size() >= participant.replica_count;
};
// We remember the participant index at which we are inserted and use that
// same index for referring to auxiliary metadata (e.g. the ncclComm_t handle
// index) below.
int64 index;
{
tensorflow::mutex_lock lock(mutex_);
// Spot check for consistent replica counts among submitting threads.
if (!participants_.empty() &&
(participants_.back().replica_count != participant.replica_count ||
participants_.back().originator != participant.originator)) {
return InvalidArgument(
"Running two XLA modules with AllReduces in parallel is not "
"supported. It is possible this is due to a bug where were try to "
"run two different AllReduces from the same module at once. "
"(Attempted a rendezvous with a different replica count from other "
"participants; existing: %s; submitted: %s)",
participants_.back().ToString(), participant.ToString());
}
index = participants_.size();
participants_.push_back(participant);
if (all_participants_present()) {
all_participants_present_.notify_all();
}
}
// We pull into our thread a) the communication handle and b) whether we're
// the "primary" thread for this rendezvous -- the "primary" thread has some
// additional responsibilities for setup/teardown.
ncclComm_t comm;
bool primary;
{
tensorflow::mutex_lock lock(mutex_);
while (!all_participants_present()) {
// Once all the participants have arrived, all participating threads will
// cross this barrier, though only (the first) one will be the "primary".
all_participants_present_.wait(lock);
}
// Somebody will be the first -- that thread has some additional
// responsibilities.
primary = !initialized_;
CHECK_EQ(participant.generation_counter, current_generation_);
// Bump the generation counter so the other threads know we've completed the
// global rendezvous and have set up the AllReduce.
if (primary) {
VLOG(3) << "Primary initializing accounting data.";
initialized_ = true;
done_.emplace(participant.replica_count);
initialize_status_ = InitializeCommunicationChannels();
VLOG(3) << "Done initializing communication channels; status: "
<< initialize_status_;
if (!initialize_status_.ok()) {
DeinitializeGeneration();
}
}
if (!initialize_status_.ok()) {
// TODO(b/125951860): If this fails once, it will fail forever.
return initialize_status_;
}
comm = comms_[index];
// Drop the lock at the end of scope so other participants may enter.
}
VLOG(3) << "Performing all reduce from device ordinal: "
<< participant.device_ordinal;
Status all_reduce_status = DoAllReduce(participant, comm);
VLOG(3) << "Waiting for all participants to complete enqueue.";
done_->DecrementCount();
if (primary) {
// Primary thread clears out the AllReduce state when everybody is done to
// make it a clean slate for any subsequent AllReduce request (e.g. number of
// replicas may change in the next request).
//
// Note surrounding TODOs for only reinitializing this when the replica
// count / participants actually change -- lots of "playing it safe"
// happening in this first cut.
done_->Wait();
VLOG(3) << "All participants completed enqueue.";
VLOG(3) << "Primary thread clearing.";
tensorflow::mutex_lock lock(mutex_);
DeinitializeGeneration();
VLOG(3) << "Generation is now: " << current_generation_;
deinitialized_.notify_all();
} else {
VLOG(3) << "Waiting to deinitialize.";
tensorflow::mutex_lock lock(mutex_);
while (initialized_) {
deinitialized_.wait(lock);
}
}
VLOG(3) << "Returning status: " << all_reduce_status;
return all_reduce_status;
}
Status GlobalRendezvousManager::InitializeCommunicationChannels() {
std::vector<int> ordinals;
for (ParticipantData& data : participants_) {
ordinals.push_back(data.device_ordinal);
}
comms_.resize(ordinals.size());
VLOG(3) << "Participants: " << participants_.size()
<< "; initializing comms.";
ncclResult_t result = ncclCommInitAll(comms_.data(), comms_.size(),
/*devlist=*/ordinals.data());
if (result != ncclSuccess) {
comms_.clear();
return InternalError(
"Failed to initialize NCCL communication channels for %d participants: "
"%s",
participants_.size(), ncclGetErrorString(result));
}
return Status::OK();
}
Status GlobalRendezvousManager::DoAllReduce(ParticipantData participant,
ncclComm_t comm) {
se::StreamExecutor* executor = participant.stream->parent();
se::cuda::ScopedActivateExecutorContext scoped_context(executor);
cudaStream_t* cu_stream = reinterpret_cast<cudaStream_t*>(
participant.stream->implementation()->GpuStreamMemberHack());
VLOG(3) << "Using stream pointer: " << cu_stream
<< " on device: " << participant.device_ordinal;
void* send_buffer = participant.source_data.opaque();
void* recv_buffer = participant.destination_data.opaque();
ncclResult_t result = ncclAllReduce(send_buffer, recv_buffer,
/*count=*/participant.element_count,
/*datatype=*/ncclFloat,
/*op=*/ncclSum,
/*comm=*/comm,
/*stream=*/*cu_stream);
TF_RET_CHECK(ncclSuccess == result)
<< "Failed to perform all-reduce: " << ncclGetErrorString(result);
VLOG(3) << "Done performing all reduce for ordinal: "
<< participant.device_ordinal;
return Status::OK();
}
static GlobalRendezvousManager* GetGlobalRendezvous() {
static auto* manager = new GlobalRendezvousManager;
return manager;
}
} // namespace
Status NcclAllReduceThunk::ExecuteOnStream(
const BufferAllocations& buffer_allocations, se::Stream* stream,
HloExecutionProfiler* profiler) {
auto* global_rendezvous = GetGlobalRendezvous();
ParticipantData participant;
participant.replica_count = replica_count_;
participant.element_count = element_count_;
participant.device_ordinal = stream->parent()->device_ordinal();
participant.generation_counter = global_rendezvous->GetCurrentGeneration();
participant.source_data = buffer_allocations.GetDeviceAddress(source_buffer_);
participant.destination_data =
buffer_allocations.GetDeviceAddress(destination_buffer_);
participant.stream = stream;
participant.originator = this;
return GetGlobalRendezvous()->SubmitParticipant(std::move(participant));
}
#else
Status NcclAllReduceThunk::ExecuteOnStream(
const BufferAllocations& buffer_allocations, se::Stream* stream,
HloExecutionProfiler* profiler) {
return Unimplemented(
"NCCL support is not available: this binary was not built with a CUDA "
"compiler, which is necessary to build the NCCL source library.");
}
#endif // GOOGLE_CUDA
NcclAllReduceThunk::NcclAllReduceThunk(
int64 replica_count, int64 element_count,
const BufferAllocation::Slice& source_buffer,
const BufferAllocation::Slice& destination_buffer,
const HloInstruction* all_reduce)
: Thunk(Thunk::kNcclAllReduce, all_reduce),
replica_count_(replica_count),
element_count_(element_count),
source_buffer_(source_buffer),
destination_buffer_(destination_buffer) {}
} // namespace gpu
} // namespace xla

tensorflow/compiler/xla/service/gpu/nccl_all_reduce_thunk.h

@@ -0,0 +1,62 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_NCCL_ALL_REDUCE_THUNK_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_NCCL_ALL_REDUCE_THUNK_H_
#include "tensorflow/compiler/xla/service/buffer_assignment.h"
#include "tensorflow/compiler/xla/service/gpu/buffer_allocations.h"
#include "tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.h"
#include "tensorflow/compiler/xla/service/gpu/thunk.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/core/platform/stream_executor_no_cuda.h"
#include "tensorflow/core/platform/types.h"
namespace xla {
namespace gpu {
// Thunk that performs an NCCL-based All-Reduce among CUDA GPU-based replicas.
class NcclAllReduceThunk : public Thunk {
public:
// Returns whether NCCL operations appear possible to perform; e.g. if we
// haven't done a build with the CUDA compiler enabled, we can't compile the
// NCCL header, and thus this will be false.
//
// When this is false, the ExecuteOnStream() call will simply return a status
// error.
static bool NcclIsEnabled();
// TODO(b/125951860): Plumb more datatypes / reduction operators. Initial
// implementation is simply F32 summation.
NcclAllReduceThunk(int64 replica_count, int64 element_count,
const BufferAllocation::Slice& source_buffer,
const BufferAllocation::Slice& destination_buffer,
const HloInstruction* all_reduce);
Status ExecuteOnStream(const BufferAllocations& buffer_allocations,
se::Stream* stream,
HloExecutionProfiler* profiler) override;
private:
const int64 replica_count_;
const int64 element_count_;
const BufferAllocation::Slice source_buffer_;
const BufferAllocation::Slice destination_buffer_;
};
} // namespace gpu
} // namespace xla
#endif // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_NCCL_ALL_REDUCE_THUNK_H_

tensorflow/compiler/xla/service/gpu/thunk.cc

@@ -32,6 +32,8 @@ std::ostream& operator<<(std::ostream& os, Thunk::Kind kind) {
return os << "kCudnnBatchNormForwardInference";
case Thunk::kCudnnBatchNormForwardTraining:
return os << "kCudnnBatchNormForwardTraining";
case Thunk::kNcclAllReduce:
return os << "kNcclAllReduce";
case Thunk::kFft:
return os << "kFft";
case Thunk::kGemm:

tensorflow/compiler/xla/service/gpu/thunk.h

@@ -48,6 +48,7 @@ class Thunk {
kCudnnBatchNormBackward,
kCudnnBatchNormForwardInference,
kCudnnBatchNormForwardTraining,
kNcclAllReduce,
kFft,
kGemm,
kInfeed,

tensorflow/compiler/xla/service/hlo_runner.cc

@@ -27,6 +27,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/transfer_manager.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/core/common_runtime/eigen_thread_pool.h"
#include "tensorflow/core/lib/core/blocking_counter.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/types.h"
@@ -269,8 +270,8 @@ StatusOr<ScopedShapedBuffer> HloRunner::ExecuteWithDeviceBuffers(
}
StatusOr<std::vector<Literal>> HloRunner::ExecuteReplicated(
std::unique_ptr<HloModule> module,
const ReplicatedExecuteOptions& options) {
std::unique_ptr<HloModule> module, const ReplicatedExecuteOptions& options,
bool use_threads) {
TF_ASSIGN_OR_RETURN(
std::unique_ptr<Executable> executable,
CreateExecutable(std::move(module), options.run_hlo_passes));
@@ -369,9 +370,39 @@ StatusOr<std::vector<Literal>> HloRunner::ExecuteReplicated(
}
LOG(INFO) << "Replicated execution started";
TF_ASSIGN_OR_RETURN(std::vector<ScopedShapedBuffer> results,
executable->ExecuteOnStreams(service_run_options,
argument_buffer_slices));
std::vector<ScopedShapedBuffer> results;
if (!use_threads) {
TF_ASSIGN_OR_RETURN(results,
executable->ExecuteOnStreams(service_run_options,
argument_buffer_slices));
} else {
tensorflow::mutex mutex;
std::vector<StatusOr<ScopedShapedBuffer>> thread_results(
options.num_replicas);
{
LOG(INFO) << "Creating thread pool for " << options.num_replicas
<< " replicas";
tensorflow::thread::ThreadPool pool(tensorflow::Env::Default(),
"replicas", options.num_replicas);
for (int64 i = 0; i < options.num_replicas; ++i) {
pool.Schedule([&, i] {
auto result = executable->ExecuteOnStream(
&service_run_options[i], argument_buffer_slices[i], nullptr);
tensorflow::mutex_lock lock(mutex);
thread_results[i] = std::move(result);
});
}
// Note: the thread pool destructor guarantees it completes all work
// before we leave this scope.
}
for (auto& thread_result : thread_results) {
if (!thread_result.ok()) {
return thread_result.status();
}
results.push_back(std::move(thread_result).ValueOrDie());
}
}
LOG(INFO) << "Replicated execution terminated";
std::vector<Literal> exec_results;

tensorflow/compiler/xla/service/hlo_runner.h

@@ -165,9 +165,13 @@ class HloRunner {
// Executes a given HLO module on a set of replicas, and returns the literals
// produced by each replica, indexed by replica number.
//
// use_threads indicates whether this replicated computation will be executed
// with a thread-per-replica, vs using an implicitly async call such as
// Executable::ExecuteOnStreams.
StatusOr<std::vector<Literal>> ExecuteReplicated(
std::unique_ptr<HloModule> module,
const ReplicatedExecuteOptions& options);
const ReplicatedExecuteOptions& options, bool use_threads = false);
// If backend is not created in the constructor, creates and returns the
// default backend. If creation fails, crashes the program.
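
As an illustration of the new parameter (not part of the change itself), here
is a minimal sketch of driving the thread-per-replica path through HloRunner;
the helper name, the "gpu" platform string, and the argument literal are
purely for demonstration.

#include <memory>
#include <utility>
#include <vector>

#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/service/hlo_runner.h"
#include "tensorflow/compiler/xla/service/platform_util.h"

namespace xla {

// Runs `module` on two GPU replicas, one host thread per replica. A single
// thread enqueueing both replicas (e.g. via Executable::ExecuteOnStreams)
// would block inside the NCCL rendezvous waiting for the second participant,
// which is why the threaded path exists.
StatusOr<std::vector<Literal>> RunTwoGpuReplicas(
    std::unique_ptr<HloModule> module) {
  HloRunner runner(PlatformUtil::GetPlatform("gpu").ValueOrDie());
  Literal arg = LiteralUtil::CreateR1<float>({1, 2, 3});
  HloRunner::ReplicatedExecuteOptions options;
  options.num_replicas = 2;
  options.arguments.push_back(&arg);
  return runner.ExecuteReplicated(std::move(module), options,
                                  /*use_threads=*/true);
}

}  // namespace xla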

tensorflow/compiler/xla/tests/BUILD

@@ -1601,6 +1601,39 @@ xla_test(
],
)
xla_test(
name = "multi_device_all_reduce_test",
srcs = ["multi_device_all_reduce_test.cc"],
backends = ["gpu"],
tags = [
"manual",
"multi_gpu",
"no_oss",
"notap",
],
deps = [
"//tensorflow/compiler/xla:literal",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:test",
"//tensorflow/compiler/xla:test_helpers",
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/compiler/xla/client:local_client",
"//tensorflow/compiler/xla/service:hlo",
"//tensorflow/compiler/xla/service:hlo_parser",
"//tensorflow/compiler/xla/service:hlo_runner",
"//tensorflow/compiler/xla/tests:client_library_test_base",
"//tensorflow/compiler/xla/tests:hlo_test_base",
"//tensorflow/compiler/xla/tests:literal_test_util",
"//tensorflow/compiler/xla/tests:test_utils",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
"//tensorflow/core:lib",
"//tensorflow/core:test",
"@com_google_absl//absl/memory",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/types:span",
],
)
xla_test(
name = "bitcast_convert_test",
srcs = ["bitcast_convert_test.cc"],

tensorflow/compiler/xla/tests/hlo_test_base.cc

@@ -207,13 +207,14 @@ Literal HloTestBase::ExecuteAndTransfer(std::unique_ptr<HloModule> module,
StatusOr<std::vector<Literal>> HloTestBase::ExecuteReplicated(
std::unique_ptr<HloModule> module, absl::Span<Literal* const> arguments,
int64 num_replicas) {
int64 num_replicas, bool use_threads) {
HloRunner::ReplicatedExecuteOptions options;
options.num_replicas = num_replicas;
for (auto argument : arguments) {
options.arguments.push_back(argument);
}
return test_runner_.ExecuteReplicated(std::move(module), options);
return test_runner_.ExecuteReplicated(std::move(module), options,
use_threads);
}
StatusOr<std::unique_ptr<HloModule>> HloTestBase::MakeReferenceModule(

tensorflow/compiler/xla/tests/hlo_test_base.h

@@ -174,9 +174,13 @@ class HloTestBase : public ::testing::Test {
absl::Span<Literal* const> arguments);
// Executes the given module on multiple replicas.
//
// use_threads indicates whether this replicated computation will be executed
// with a thread-per-replica, vs using an implicitly async call such as
// Executable::ExecuteOnStreams.
StatusOr<std::vector<Literal>> ExecuteReplicated(
std::unique_ptr<HloModule> module, absl::Span<Literal* const> arguments,
int64 num_replicas);
int64 num_replicas, bool use_threads);
// Executes the given hlo module on two backends and compares results.
//

tensorflow/compiler/xla/tests/multi_device_all_reduce_test.cc

@@ -0,0 +1,56 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/service/hlo_parser.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/test.h"
#include "tensorflow/compiler/xla/test_helpers.h"
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
#include "tensorflow/compiler/xla/tests/test_macros.h"
namespace xla {
namespace {
class MultiDeviceAllReduceTest : public HloTestBase {};
XLA_TEST_F(MultiDeviceAllReduceTest, TwoReplicasOneOperand) {
const char* module_str = R"(
HloModule test
add {
x = f32[] parameter(0)
y = f32[] parameter(1)
add = f32[] add(x, y)
}
ENTRY test_computation {
p = f32[3] parameter(0)
ROOT crs = f32[3] all-reduce(p), to_apply=add
})";
auto config = GetModuleConfigForTest();
config.set_replica_count(2);
auto module = ParseHloString(module_str, config).ValueOrDie();
auto literal = LiteralUtil::CreateR1<float>({1, 2, 3});
auto expected = LiteralUtil::CreateR1<float>({2, 4, 6});
TF_ASSERT_OK_AND_ASSIGN(std::vector<Literal> results,
ExecuteReplicated(std::move(module), {&literal}, 2,
/*use_threads=*/true));
EXPECT_EQ(expected, results[0]);
EXPECT_EQ(expected, results[1]);
}
} // namespace
} // namespace xla