Split out common NCCL utils.

Changes made while refactoring:
- The rendezvous is now encapsulated in the code that acquires the NCCL clique (see the usage sketch below).
- Fixed an issue where `ncclComm`s might not be cleaned up correctly.
- Don't sort device IDs - order matters for several of the NCCL collective ops.
- Removed support for cleanup in `Rendezvous` as it is no longer needed.
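
For reference, a minimal sketch of the acquisition pattern described above, based on the `AcquireNcclClique`/`LockedNcclClique` API that appears in this change (error handling and the surrounding thunk plumbing elided):

    // The rendezvous is hidden behind AcquireNcclClique: a caller asks for a
    // locked clique and then pulls out its communicator.
    TF_ASSIGN_OR_RETURN(
        LockedNcclClique locked_clique,
        AcquireNcclClique(rendezvous_key, device_ordinal, stream,
                          local_participants, nccl_unique_id_callback));
    ncclComm_t comm =
        locked_clique.clique->GetCommForDeviceOrdinal(device_ordinal);
    // ... issue NCCL collective calls on `comm` ...
    // The clique lock is released when `locked_clique` is destroyed.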

PiperOrigin-RevId: 345156260
Change-Id: I9871093bad4c403876ce5584180a42fedf8d9900
Authored by Yujing Zhang on 2020-12-01 20:37:28 -08:00; committed by TensorFlower Gardener.
Parent: 90881b041f
Commit: ffe615e998
9 changed files with 580 additions and 566 deletions


@@ -254,6 +254,8 @@ class Rendezvous {
return false;
}
virtual void CleanupImpl(O handle, bool is_primary) {}
tensorflow::mutex mu_;
bool initialized_ TF_GUARDED_BY(mu_) = false;
@@ -294,14 +296,34 @@ class Rendezvous {
participant.device_ordinal, participant.stream, key_.ToString());
});
TF_ASSIGN_OR_RETURN(ParticipantImplOutput p, RunCollectiveOp(participant));
StatusOr<ParticipantImplOutput> p_or = RunCollectiveOp(participant);
done_.DecrementCount();
if (!p_or.ok()) {
return p_or.status();
}
ParticipantImplOutput p = p_or.ValueOrDie();
// The primary owns the lock on the NCCL clique. Hold it until all threads
// are done. (We'll release it when we return from this function.)
if (p.is_primary) {
WaitAndLogIfStuck(&done_, [&] {
return absl::StrFormat(
"primary participant waiting for all other participants to "
"complete all-reduce %s",
key_.ToString());
});
}
CleanupImpl(p.custom_output, p.is_primary);
return std::make_pair(p.custom_output, returned_blocking_counter_);
}
const RendezvousKey key_;
tensorflow::BlockingCounter all_participants_present_{
key_.num_local_participants};
tensorflow::BlockingCounter done_{key_.num_local_participants};
// tensorflow::BlockingCounter returned by SubmitParticipant.
std::shared_ptr<tensorflow::BlockingCounter> returned_blocking_counter_{


@@ -455,17 +455,17 @@ tf_cuda_library(
":thunk",
":gpu_executable_run_options",
"@com_google_absl//absl/base:core_headers",
"@com_google_absl//absl/strings:str_format",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:util",
"//tensorflow/compiler/xla:xla_data_proto_cc",
"//tensorflow/compiler/xla/service:buffer_assignment",
"//tensorflow/compiler/xla/service:collective_ops_utils",
"//tensorflow/compiler/xla/service:hlo",
"//tensorflow/compiler/xla/service:hlo_casting_utils",
"//tensorflow/compiler/xla/service:pattern_matcher",
"//tensorflow/compiler/xla:refcounting_hash_map",
"//tensorflow/compiler/xla/service:collective_ops_utils",
"@com_google_absl//absl/container:flat_hash_set",
"@com_google_absl//absl/synchronization",
"//tensorflow/compiler/xla:util",
"//tensorflow/compiler/xla/service:buffer_assignment",
"//tensorflow/compiler/xla/service:hlo",
"//tensorflow/core:lib",
"//tensorflow/core/platform:stream_executor_no_cuda",
"//tensorflow/compiler/xla:xla_data_proto_cc",
] + if_cuda([
"//tensorflow/stream_executor/cuda:cuda_activation",
"//tensorflow/stream_executor/cuda:cuda_gpu_executor",
@@ -474,53 +474,10 @@ tf_cuda_library(
"//tensorflow/stream_executor/rocm:rocm_gpu_executor",
]) + if_nccl([
":virtual_nccl",
":virtual_nccl_utils",
":virtual_rccl",
]),
)
# First level of nested select. NCCL requires both if_cuda and if_nccl.
filegroup(
name = "nccl_utils_srcs",
srcs = if_nccl(["nccl_utils.cc"]),
)
# First level of nested select. NCCL requires both if_cuda and if_nccl.
filegroup(
name = "nccl_utils_hdrs",
srcs = if_nccl(["nccl_utils.h"]),
)
tf_cuda_library(
name = "nccl_utils",
srcs = if_cuda_or_rocm([":nccl_utils_srcs"]),
hdrs = if_cuda_or_rocm([":nccl_utils_hdrs"]),
deps = if_cuda_or_rocm([
":gpu_executable_run_options",
"@com_google_absl//absl/container:flat_hash_set",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/strings:str_format",
"@com_google_absl//absl/synchronization",
"//tensorflow/compiler/xla:refcounting_hash_map",
"//tensorflow/compiler/xla:status",
"//tensorflow/compiler/xla:status_macros",
"//tensorflow/compiler/xla:statusor",
"//tensorflow/compiler/xla:xla_data_proto_cc",
"//tensorflow/compiler/xla/service:collective_ops_utils",
"//tensorflow/compiler/xla/service:global_device_id",
"//tensorflow/core:lib",
"//tensorflow/stream_executor/lib",
]) + if_nccl([
":virtual_nccl",
":virtual_rccl",
]),
)
alias(
name = "virtual_nccl_utils",
actual = if_cuda_or_rocm(":nccl_utils", ":empty"),
)
cc_library(
name = "gpu_debug_info_manager",
srcs = [


@@ -44,6 +44,11 @@ Status NcclAllReduceThunk::ExecuteOnStream(const ExecuteParams& params) {
"compiler, which is necessary to build the NCCL source library.");
}
/*static*/ absl::flat_hash_set<GlobalDeviceId>
NcclAllReduceThunk::DevicesWithOpenNcclChannels() {
return {};
}
NcclAllReduceThunk::NcclAllReduceThunk(
ThunkInfo thunk_info, NcclAllReduceConfig &&config,
std::vector<NcclAllReduceThunk::Buffer> buffers)


@@ -21,7 +21,12 @@ namespace xla {
namespace gpu {
NcclCliqueKey::NcclCliqueKey(std::vector<GlobalDeviceId> devices)
: devices_(std::move(devices)) {}
: devices_(std::move(devices)) {
absl::c_sort(devices_);
CHECK(absl::c_adjacent_find(devices_) == devices_.end())
<< "Duplicate devices are not allowed: "
<< GlobalDeviceIdsToString(devices_);
}
std::string NcclCliqueKey::ToString() const {
return GlobalDeviceIdsToString(devices_);


@@ -22,21 +22,40 @@ limitations under the License.
#include <utility>
#include <vector>
#include "absl/algorithm/container.h"
#include "absl/base/thread_annotations.h"
#include "absl/container/flat_hash_set.h"
#include "absl/memory/memory.h"
#include "absl/strings/str_format.h"
#include "absl/strings/str_join.h"
#include "absl/types/optional.h"
#include "absl/types/span.h"
#if GOOGLE_CUDA
#include "third_party/nccl/nccl.h"
#elif TENSORFLOW_USE_ROCM
#include "rocm/include/rccl/rccl.h"
#endif
#include "tensorflow/compiler/xla/layout_util.h"
#include "tensorflow/compiler/xla/service/gpu/nccl_utils.h"
#include "tensorflow/compiler/xla/refcounting_hash_map.h"
#include "tensorflow/compiler/xla/service/collective_ops_utils.h"
#include "tensorflow/compiler/xla/service/gpu/gpu_executable_run_options.h"
#include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
#include "tensorflow/compiler/xla/service/hlo_instructions.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/lib/core/blocking_counter.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/stream_executor/gpu/gpu_activation.h"
#if TENSORFLOW_USE_ROCM
// Local hipify of cuda symbols
#define cudaError_t hipError_t
#define cudaStream_t hipStream_t
#define cudaGetErrorString hipGetErrorString
#define cudaGetDevice hipGetDevice
#define cudaSetDevice hipSetDevice
#define cudaSuccess hipSuccess
#endif
namespace xla {
namespace gpu {
@@ -63,6 +82,432 @@ namespace gpu {
return true; // Skylark selects this source file if NCCL is enabled.
}
namespace {
bool IsGlobalNcclConfig() {
static bool global_nccl_config = std::getenv("NCCL_COMM_ID") != nullptr;
return global_nccl_config;
}
// Functions to translate an ncclResult_t/cudaError_t to a Status object. Used
// by the macros below.
Status TranslateStatus(ncclResult_t s, const char* file, int64 line,
const char* expr) {
if (s == ncclSuccess) {
return Status::OK();
}
return tensorflow::errors::Internal(
absl::StrFormat("%s:%d: NCCL operation %s failed: %s", file, line, expr,
ncclGetErrorString(s)));
}
Status TranslateStatus(cudaError_t s, const char* file, int64 line,
const char* expr) {
if (s == cudaSuccess) {
return Status::OK();
}
return tensorflow::errors::Internal(
absl::StrFormat("%s:%d: CUDA operation %s failed: %s", file, line, expr,
cudaGetErrorString(s)));
}
// Macros to return or warn on CUDA/NCCL errors. (The same macro works for both
// NCCL and CUDA errors.)
//
// It's tempting to say these macros belong in an XLA header somewhere, but in
// practice we don't do much direct-to-CUDA-API stuff outside of this file.
#define XLA_CUDA_RETURN_IF_ERROR(expr) \
do { \
Status s = ::xla::gpu::TranslateStatus(expr, __FILE__, __LINE__, #expr); \
if (!s.ok()) { \
return s; \
} \
} while (0)
#define XLA_CUDA_WARN_IF_ERROR(expr) \
do { \
Status s = ::xla::gpu::TranslateStatus(expr, __FILE__, __LINE__, #expr); \
if (!s.ok()) { \
LOG(ERROR) << s.ToString(); \
} \
} while (0)
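// Illustrative usage (a sketch): the same macros accept both NCCL and CUDA
// calls, because TranslateStatus is overloaded on ncclResult_t and
// cudaError_t:
//   XLA_CUDA_RETURN_IF_ERROR(ncclGroupStart());  // ncclResult_t overload
//   XLA_CUDA_RETURN_IF_ERROR(cudaSetDevice(0));  // cudaError_t overload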
// RAII class owning a ncclComm_t, ensuring it doesn't leak.
class NcclComm {
public:
explicit NcclComm(ncclComm_t comm) : comm_(comm) {}
// Movable, but not copyable.
NcclComm(NcclComm&& c) noexcept : comm_(c.comm_) { c.comm_.reset(); }
NcclComm& operator=(NcclComm&& c) noexcept {
comm_ = c.comm_;
c.comm_.reset();
return *this;
}
NcclComm(const NcclComm&) = delete;
NcclComm& operator=(const NcclComm&) = delete;
~NcclComm() {
if (comm_.has_value() && *comm_ != nullptr) {
VLOG(3) << absl::StreamFormat("Destroying comm %p", *comm_);
XLA_CUDA_WARN_IF_ERROR(ncclCommDestroy(*comm_));
}
}
ncclComm_t comm() { return *comm_; }
private:
absl::optional<ncclComm_t> comm_;
};
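// Illustrative usage (a sketch, not part of this change): ownership moves
// with the object, and the destructor calls ncclCommDestroy at most once:
//   NcclComm owned(raw_comm);           // raw_comm is a ncclComm_t
//   NcclComm moved = std::move(owned);  // `owned` no longer owns the comm
// When `moved` goes out of scope, the comm is destroyed.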
ncclRedOp_t ReductionKindToNccl(ReductionKind kind) {
switch (kind) {
case ReductionKind::SUM:
return ncclSum;
case ReductionKind::PRODUCT:
return ncclProd;
case ReductionKind::MIN:
return ncclMin;
case ReductionKind::MAX:
return ncclMax;
}
}
absl::optional<ncclDataType_t> DatatypeToNccl(PrimitiveType element_type) {
switch (element_type) {
case S8:
return ncclInt8;
case PRED:
case U8:
return ncclUint8;
case S32:
return ncclInt32;
case U32:
return ncclUint32;
case S64:
return ncclInt64;
case U64:
return ncclUint64;
case F16:
return ncclFloat16;
case F32:
return ncclFloat32;
case F64:
return ncclFloat64;
default:
return absl::nullopt;
}
}
Status StringToNcclUniqueId(const std::string& str_id, ncclUniqueId* nccl_id) {
if (str_id.size() != NCCL_UNIQUE_ID_BYTES) {
return InvalidArgument(
"ncclUniqueId string must have %d bytes, got %d bytes", str_id.size(),
NCCL_UNIQUE_ID_BYTES);
}
// NcclUniqueId is internally just a char[].
static_assert(sizeof(ncclUniqueId) == NCCL_UNIQUE_ID_BYTES,
"NCCL_UNIQUE_ID_BYTES");
std::memcpy(static_cast<void*>(nccl_id), str_id.data(), NCCL_UNIQUE_ID_BYTES);
return Status::OK();
}
// Owns a clique of NCCL comms which can be used for collective operations among
// a particular set of GPUs.
//
// You must ensure this is not in an error state (i.e. status() is OK) before
// touching any other methods.
//
// (Usually allowing objects to be in a constructed-but-uninitialized state is
// an antipattern. We do it here because it allows us to have a
// RefcountingHashMap which contains and automatically constructs NcclCliques.
// This greatly simplifies the rest of this file.)
//
// Note that if you want to do a collective operation among a subset of these
// GPUs, you'll need a different clique.
class NcclClique {
public:
explicit NcclClique(
int64 num_global_devices, std::vector<int64> local_device_ordinals,
std::vector<int64> local_device_ranks,
const StatusOr<absl::optional<std::string>>& nccl_unique_id)
: num_global_devices_(num_global_devices),
local_device_ordinals_(std::move(local_device_ordinals)),
local_device_ranks_(std::move(local_device_ranks)) {
CHECK_EQ(local_device_ordinals_.size(), local_device_ranks_.size());
// It's unusual to pass a StatusOr<> into a class, but since this class
already has an erroneous state, it turns out to be a little easier to
// implement this way than to change RefcountingHashMap.
status_ = Init(nccl_unique_id);
}
Status status() { return status_; }
// A NCCL communicator is the NCCL state associated with a participant (rank)
// in a reduction. This method returns the state associated with a particular
// local device ordinal.
ncclComm_t comm(int64 device_ordinal) {
int64 idx =
std::distance(local_device_ordinals_.begin(),
absl::c_find(local_device_ordinals_, device_ordinal));
return comms_.at(idx).comm();
}
// These methods let you acquire exclusive access to a NCCL clique, ensuring
// no other NCCL operations are taking place on the clique's comms.
//
// We disable thread-safety analysis because in common use, only the primary
// thread in a Rendezvous acquires this lock, and that makes thread-safety
// analysis unhappy. Tread carefully, you are playing with fire.
void Lock() ABSL_NO_THREAD_SAFETY_ANALYSIS {
TF_CHECK_OK(status_);
mu_->lock();
}
void Unlock() ABSL_NO_THREAD_SAFETY_ANALYSIS {
TF_CHECK_OK(status_);
mu_->unlock();
}
private:
Status Init(
const StatusOr<absl::optional<std::string>>& maybe_nccl_unique_id) {
VLOG(3) << absl::StreamFormat(
"Initializing nccl comms for participant device ordinals %s ranks {%s}",
absl::StrJoin(local_device_ordinals_, ", "),
absl::StrJoin(local_device_ranks_, ", "));
// Restore CUDA device after running this. XLA shouldn't care, but maybe
// another consumer does.
int initial_cuda_device;
XLA_CUDA_RETURN_IF_ERROR(cudaGetDevice(&initial_cuda_device));
auto cuda_device_restorer = MakeCleanup(
[&] { XLA_CUDA_WARN_IF_ERROR(cudaSetDevice(initial_cuda_device)); });
// When using ncclGroupStart/End it seems that the ncclComm_t's are not
// populated until the End() call. This unfortunately makes error handling
// tricky.
std::vector<ncclComm_t> raw_comms(local_device_ordinals_.size(), nullptr);
TF_ASSIGN_OR_RETURN(const absl::optional<std::string>& nccl_id_string,
maybe_nccl_unique_id);
ncclUniqueId nccl_id;
if (nccl_id_string) {
TF_RETURN_IF_ERROR(StringToNcclUniqueId(*nccl_id_string, &nccl_id));
} else {
XLA_CUDA_RETURN_IF_ERROR(ncclGetUniqueId(&nccl_id));
}
XLA_CUDA_RETURN_IF_ERROR(ncclGroupStart());
Status status = [&] {
for (int i = 0; i < local_device_ordinals_.size(); ++i) {
XLA_CUDA_RETURN_IF_ERROR(cudaSetDevice(local_device_ordinals_[i]));
XLA_CUDA_RETURN_IF_ERROR(ncclCommInitRank(&raw_comms[i],
num_global_devices_, nccl_id,
local_device_ranks_.at(i)));
}
return Status::OK();
}();
// Always call ncclGroupEnd().
XLA_CUDA_RETURN_IF_ERROR(ncclGroupEnd());
// Populate comms_ from the raw comms we created above. If we encountered
// an error above we'll later clear comms_ thus destroying any raw comms
// that were created before the error.
for (int i = 0; i < local_device_ordinals_.size(); ++i) {
VLOG(3) << absl::StreamFormat("Device ordinal %d assigned ncclComm %p",
local_device_ordinals_[i], raw_comms[i]);
CHECK(raw_comms[i] != nullptr || !status.ok());
comms_.emplace_back(raw_comms[i]);
}
if (!status.ok()) {
comms_.clear();
}
return status;
}
Status status_;
int64 num_global_devices_;
std::vector<int64> local_device_ordinals_;
// NCCL communicator rank for each local device. The rank of a device is equal
// to the offset of the local device in the global device set.
std::vector<int64> local_device_ranks_;
std::vector<NcclComm> comms_;
// This mutex is in a unique_ptr so NcclClique can be movable.
std::unique_ptr<tensorflow::mutex> mu_ =
absl::make_unique<tensorflow::mutex>();
};
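// Illustrative usage (a sketch; in this file, RunCollectiveOp and CleanupImpl
// use the clique this way): the primary participant takes the lock, then each
// participant fetches its communicator by device ordinal:
//   clique->Lock();
//   ncclComm_t comm = clique->comm(device_ordinal);
//   // ... run collective ops on comm ...
//   clique->Unlock();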
// Global cache of NCCL cliques. An entry in this map is kept alive as long as
// there's a reference to it somewhere. A Thunk holds a reference to each
// Clique it's ever used.
//
// A consequence of the fact that this is process-global is that we'll only ever
// have one clique alive for a given set of GPUs. This means that a process
// will never do two collective operations concurrently on the same set of GPUs.
RefcountingHashMap<NcclCliqueKey, NcclClique>& GlobalNcclCliqueMap() {
static auto& m = *new RefcountingHashMap<NcclCliqueKey, NcclClique>();
return m;
}
using RendezvousBase =
Rendezvous<AllReduceParticipantData, std::shared_ptr<NcclClique>>;
class RendezvousNcclAllReduce : public RendezvousBase {
public:
explicit RendezvousNcclAllReduce(const RendezvousKey& k)
: RendezvousBase(k) {}
protected:
StatusOr<ParticipantImplOutput> RunCollectiveOp(
const AllReduceParticipantData& participant) override;
void CleanupImpl(std::shared_ptr<NcclClique> handle,
bool is_primary) override;
};
// Global map of Rendezvous objects. A thread participating in a collective op
// looks up its Rendezvous in this map to find the other threads that it's
// participating with.
//
// Rendezvous objects are one-time use, so they're removed from this map once
// we're through with them.
RefcountingHashMap<RendezvousKey, RendezvousNcclAllReduce>&
GlobalRendezvousMap() {
static auto& m =
*new RefcountingHashMap<RendezvousKey, RendezvousNcclAllReduce>();
return m;
}
StatusOr<RendezvousNcclAllReduce::ParticipantImplOutput>
RendezvousNcclAllReduce::RunCollectiveOp(
const AllReduceParticipantData& participant) {
// We pull into our thread a) the communication handle and b) whether we're
// the "primary" thread for this rendezvous -- the "primary" thread has some
// additional responsibilities for setup/teardown.
ncclComm_t comm;
bool primary;
std::shared_ptr<NcclClique> clique;
{
tensorflow::mutex_lock lock(mu_);
// The first thread to get here has additional responsibilities, such as
// ensuring that there's a NCCL clique available for us to use.
primary = !initialized_;
TF_RET_CHECK(participant.local_devices.size() ==
participant.rendezvous_key.num_local_participants);
// Look up or create the NCCL clique for this set of devices.
NcclCliqueKey clique_key(participant.rendezvous_key.global_devices);
auto clique_factory =
[&](const NcclCliqueKey& key) -> std::unique_ptr<NcclClique> {
std::vector<int64> local_device_ranks;
std::vector<int64> local_device_ordinals;
local_device_ranks.reserve(participant.local_devices.size());
local_device_ordinals.reserve(participant.local_devices.size());
for (const auto& l : participant.local_devices) {
auto it =
absl::c_find(participant.rendezvous_key.global_devices, l.first);
CHECK(it != participant.rendezvous_key.global_devices.end()) << l.first;
local_device_ranks.push_back(std::distance(
participant.rendezvous_key.global_devices.begin(), it));
local_device_ordinals.push_back(l.second);
}
StatusOr<absl::optional<std::string>> nccl_unique_id;
if (participant.nccl_unique_id_callback) {
nccl_unique_id = (*participant.nccl_unique_id_callback)(clique_key);
} else {
if (participant.rendezvous_key.global_devices.size() !=
participant.rendezvous_key.num_local_participants &&
!IsGlobalNcclConfig()) {
nccl_unique_id = InvalidArgument(
"If not local devices are taking part of a collective API on "
"GPU, the nccl_unique_id_callback must be provided by the "
"client.");
} else {
nccl_unique_id = absl::optional<std::string>();
}
}
return absl::make_unique<NcclClique>(
participant.rendezvous_key.global_devices.size(),
std::move(local_device_ordinals), std::move(local_device_ranks),
nccl_unique_id);
};
clique =
GlobalNcclCliqueMap().GetOrCreateIfAbsent(clique_key, clique_factory);
if (primary) {
VLOG(3) << "Primary initializing accounting data.";
initialized_ = true;
// Acquire exclusive access to the NCCL clique itself so that two
// unrelated collective operations won't try to use the clique
// concurrently.
// We'll unlock it in CleanupImpl.
clique->Lock();
}
if (!clique->status().ok()) {
VLOG(1)
<< "SubmitParticipant failing because clique failed to initialize: "
<< clique->status().ToString();
return clique->status();
}
comm = clique->comm(participant.device_ordinal);
// Drop the lock at the end of scope so other participants may enter.
}
VLOG(3) << "Performing all reduce from device ordinal: "
<< participant.device_ordinal;
ncclRedOp_t computation = ReductionKindToNccl(participant.reduction_kind);
se::StreamExecutor* executor = participant.stream->parent();
se::gpu::ScopedActivateExecutorContext scoped_context(executor);
cudaStream_t* cu_stream = reinterpret_cast<cudaStream_t*>(
participant.stream->implementation()->GpuStreamMemberHack());
VLOG(3) << "Using stream pointer: " << cu_stream
<< " on device: " << participant.device_ordinal;
XLA_CUDA_RETURN_IF_ERROR(ncclGroupStart());
for (auto& buffer : participant.buffers) {
void* send_buffer = const_cast<void*>(buffer.source_data.opaque());
void* recv_buffer = const_cast<void*>(buffer.destination_data.opaque());
absl::optional<ncclDataType_t> allreduce_datatype =
DatatypeToNccl(buffer.primitive_type);
CHECK(allreduce_datatype.has_value());
VLOG(3) << absl::StreamFormat(
"Calling ncclAllReduce(send_buffer=%p, recv_buffer=%p, count=%d, "
"comm=%p, stream=%p)",
send_buffer, recv_buffer, buffer.element_count,
static_cast<const void*>(comm), cu_stream);
XLA_CUDA_RETURN_IF_ERROR(ncclAllReduce(send_buffer, recv_buffer,
/*count=*/buffer.element_count,
/*datatype=*/*allreduce_datatype,
/*op=*/computation,
/*comm=*/comm,
/*stream=*/*cu_stream));
}
XLA_CUDA_RETURN_IF_ERROR(ncclGroupEnd());
VLOG(3) << "Done performing all reduce for ordinal: "
<< participant.device_ordinal;
VLOG(3) << "This thread done with all-reduce op.";
return ParticipantImplOutput{primary, clique};
}
void RendezvousNcclAllReduce::CleanupImpl(std::shared_ptr<NcclClique> handle,
bool is_primary) {
// Releases the lock on the clique (held only by the primary thread).
if (is_primary) {
handle->Unlock();
}
}
} // namespace
// Extra data stored in NcclAllReduceThunk that we didn't want to expose in the
// header. In particular, this stores the thunk's cache of all NcclCliques it's
// ever used. This causes those cliques to stay alive as long as the thunk
@@ -105,13 +550,23 @@ NcclAllReduceConfig GetNcclAllReduceConfig(const HloInstruction* instr,
auto operands_are_supported = [crs]() {
return absl::c_all_of(crs->operands(), [](HloInstruction* operand) {
return LayoutUtil::IsDenseArray(operand->shape()) &&
ToNcclDataType(operand->shape().element_type()).ok();
DatatypeToNccl(operand->shape().element_type()).has_value();
});
};
return MatchReductionComputation(crs->to_apply()).has_value() &&
crs->IsCrossReplicaAllReduce() && operands_are_supported();
}
/*static*/ absl::flat_hash_set<GlobalDeviceId>
NcclAllReduceThunk::DevicesWithOpenNcclChannels() {
absl::flat_hash_set<GlobalDeviceId> devices;
GlobalNcclCliqueMap().ForEach(
[&](const NcclCliqueKey& k, const std::shared_ptr<NcclClique>&) {
devices.insert(k.devices().begin(), k.devices().end());
});
return devices;
}
NcclAllReduceThunk::NcclAllReduceThunk(
ThunkInfo thunk_info, NcclAllReduceConfig&& config,
std::vector<NcclAllReduceThunk::Buffer> buffers)
@@ -128,87 +583,97 @@ Status NcclAllReduceThunk::ExecuteOnStream(const ExecuteParams& params) {
auto op_profiler =
params.profiler->MakeScopedInstructionProfiler(profile_index());
se::StreamExecutor* executor = params.stream->parent();
int device_ordinal = executor->device_ordinal();
int64 local_device_ordinal = params.stream->parent()->device_ordinal();
TF_ASSIGN_OR_RETURN(GlobalDeviceId global_device_id,
params.GetGlobalDeviceId());
// Determines the set of global and local devices that are participating in
// the same collective group as the caller.
TF_ASSIGN_OR_RETURN(
std::vector<int64> participating_replicas,
std::vector<int64> global_participating_replicas,
GetParticipatingReplicas(global_device_id, config_.replica_groups,
config_.replica_count, *params.device_assn));
if (IsGlobalNcclConfig() &&
participating_replicas.size() != config_.replica_count) {
global_participating_replicas.size() != config_.replica_count) {
return InvalidArgument(
"Partial replica groups are not allowed when using NCCL_COMM_ID "
"environment configuration.");
}
std::vector<GlobalDeviceId> global_devices;
std::vector<std::pair<GlobalDeviceId, int64>> local_devices;
local_devices.reserve(global_participating_replicas.size());
global_devices.reserve(global_participating_replicas.size());
TF_RET_CHECK(params.device_assn->computation_count() == 1)
<< params.device_assn->ToString();
std::vector<GlobalDeviceId> participants;
participants.reserve(participating_replicas.size());
for (int64 replica : participating_replicas) {
participants.emplace_back(
for (int64 replica : global_participating_replicas) {
GlobalDeviceId global_device(
(*params.device_assn)(replica, /*computation=*/0));
global_devices.push_back(global_device);
if (!params.gpu_global_device_ids) {
local_devices.emplace_back(global_device, global_device.value());
} else {
auto it = absl::c_find(*params.gpu_global_device_ids, global_device);
if (it != params.gpu_global_device_ids->end()) {
local_devices.emplace_back(
*it, std::distance(params.gpu_global_device_ids->begin(), it));
}
}
}
TF_ASSIGN_OR_RETURN(
std::vector<LocalParticipant> local_participants,
GetLocalParticipants(participants, params.gpu_global_device_ids));
absl::c_sort(global_devices);
// Create the rendezvous for this collective operation.
RendezvousKey rendezvous_key(params.run_id, std::move(participants),
local_participants.size(),
config_.collective_op_kind, config_.op_id);
TF_ASSIGN_OR_RETURN(
LockedNcclClique locked_clique,
AcquireNcclClique(rendezvous_key, device_ordinal, params.stream,
local_participants, params.nccl_unique_id_callback));
ncclComm_t comm =
locked_clique.clique->GetCommForDeviceOrdinal(device_ordinal);
VLOG(3) << "Performing all-reduce from device ordinal: " << device_ordinal;
ncclRedOp_t reduction_kind = ToNcclReduction(config_.reduction_kind);
se::gpu::ScopedActivateExecutorContext scoped_context(executor);
cudaStream_t* cu_stream = reinterpret_cast<cudaStream_t*>(
params.stream->implementation()->GpuStreamMemberHack());
VLOG(3) << "Using stream pointer: " << cu_stream
<< " on device: " << device_ordinal;
XLA_CUDA_RETURN_IF_ERROR(ncclGroupStart());
for (size_t i = 0; i < buffers_.size(); ++i) {
const Buffer& buffer = buffers_[i];
const void* send_buffer =
params.buffer_allocations->GetDeviceAddress(buffer.source_buffer)
.opaque();
void* recv_buffer =
params.buffer_allocations->GetDeviceAddress(buffer.destination_buffer)
.opaque();
TF_ASSIGN_OR_RETURN(ncclDataType_t datatype,
ToNcclDataType(config_.operand_element_type[i]));
VLOG(3) << absl::StreamFormat(
"Calling ncclAllReduce(send_buffer=%p, recv_buffer=%p, count=%d, "
"comm=%p, stream=%p)",
send_buffer, recv_buffer, buffer.element_count,
static_cast<const void*>(comm), cu_stream);
XLA_CUDA_RETURN_IF_ERROR(ncclAllReduce(send_buffer, recv_buffer,
/*count=*/buffer.element_count,
datatype,
/*op=*/reduction_kind, comm,
/*stream=*/*cu_stream));
RendezvousKey rendezvous_key(params.run_id, global_devices,
local_devices.size(), config_.collective_op_kind,
config_.op_id);
if (VLOG_IS_ON(2)) {
std::vector<std::string> local_participants;
local_participants.reserve(local_devices.size());
for (const auto& entry : local_devices) {
local_participants.push_back(absl::StrFormat(
"global=%d/local=%d", entry.first.value(), entry.second));
}
VLOG(2) << "Rendezvous key: " << rendezvous_key.ToString()
<< ", global participating replicas: "
<< absl::StrJoin(global_participating_replicas, ", ")
<< ", global participating devices: "
<< GlobalDeviceIdsToString(global_devices)
<< ", local participants: "
<< absl::StrJoin(local_participants, ",");
}
XLA_CUDA_RETURN_IF_ERROR(ncclGroupEnd());
AllReduceParticipantData participant(rendezvous_key, local_device_ordinal,
params.stream);
for (size_t i = 0; i < buffers_.size(); ++i) {
const NcclAllReduceThunk::Buffer& buffer = buffers_[i];
AllReduceParticipantData::Buffer pbuffer;
pbuffer.element_count = buffer.element_count;
pbuffer.source_data =
params.buffer_allocations->GetDeviceAddress(buffer.source_buffer);
pbuffer.destination_data =
params.buffer_allocations->GetDeviceAddress(buffer.destination_buffer);
pbuffer.primitive_type = config_.operand_element_type[i];
participant.buffers.push_back(pbuffer);
}
participant.local_devices = std::move(local_devices);
participant.nccl_unique_id_callback = params.nccl_unique_id_callback;
participant.reduction_kind = config_.reduction_kind;
VLOG(3) << "Done performing all-reduce for ordinal: " << device_ordinal;
auto rendezvous_factory = [](const RendezvousKey& k) {
return absl::make_unique<RendezvousNcclAllReduce>(k);
};
TF_ASSIGN_OR_RETURN(std::shared_ptr<NcclClique> clique,
RendezvousNcclAllReduce::SubmitParticipant(
[&] {
return GlobalRendezvousMap().GetOrCreateIfAbsent(
rendezvous_key, rendezvous_factory);
},
participant));
// Keep the clique we used alive for as long as this Thunk lives. Creating
// new NCCL cliques is expensive, and this is how we avoid thrashing them.
{
tensorflow::mutex_lock lock(config_.aux_data->mu);
config_.aux_data->cliques.insert(std::move(locked_clique.clique));
config_.aux_data->cliques.insert(std::move(clique));
}
return Status::OK();
}


@@ -16,12 +16,17 @@ limitations under the License.
#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_NCCL_ALL_REDUCE_THUNK_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_NCCL_ALL_REDUCE_THUNK_H_
#include "absl/container/flat_hash_set.h"
#include "tensorflow/compiler/xla/service/buffer_assignment.h"
#include "tensorflow/compiler/xla/service/collective_ops_utils.h"
#include "tensorflow/compiler/xla/service/gpu/buffer_allocations.h"
#include "tensorflow/compiler/xla/service/gpu/gpu_executable_run_options.h"
#include "tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.h"
#include "tensorflow/compiler/xla/service/gpu/thunk.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/stream_executor_no_cuda.h"
#include "tensorflow/core/platform/types.h"
namespace xla {
@@ -62,6 +67,14 @@ class NcclAllReduceThunk : public Thunk {
// error.
static bool NcclIsEnabled();
// Gets the set of devices that have a NCCL channel open. This is primarily
// for testing.
//
// (Indeed, because the NCCL channels are a global variable, in the real
// world, the value returned here is stale as soon as you read it, so it's not
// clear how you *could* use it for anything other than tests.)
static absl::flat_hash_set<GlobalDeviceId> DevicesWithOpenNcclChannels();
struct Buffer {
int64 element_count;
BufferAllocation::Slice source_buffer;


@@ -1,311 +0,0 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/xla/service/gpu/nccl_utils.h"
#include <memory>
#include "absl/container/flat_hash_map.h"
#include "absl/strings/str_format.h"
#include "absl/synchronization/mutex.h"
#include "tensorflow/compiler/xla/refcounting_hash_map.h"
#include "tensorflow/compiler/xla/service/collective_ops_utils.h"
#include "tensorflow/compiler/xla/service/global_device_id.h"
#include "tensorflow/compiler/xla/service/gpu/gpu_executable_run_options.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/core/platform/errors.h"
namespace xla {
namespace gpu {
ncclRedOp_t ToNcclReduction(ReductionKind kind) {
switch (kind) {
case ReductionKind::SUM:
return ncclSum;
case ReductionKind::PRODUCT:
return ncclProd;
case ReductionKind::MIN:
return ncclMin;
case ReductionKind::MAX:
return ncclMax;
}
}
StatusOr<ncclDataType_t> ToNcclDataType(PrimitiveType element_type) {
switch (element_type) {
case S8:
return ncclInt8;
case PRED:
case U8:
return ncclUint8;
case S32:
return ncclInt32;
case U32:
return ncclUint32;
case S64:
return ncclInt64;
case U64:
return ncclUint64;
case F16:
return ncclFloat16;
case F32:
return ncclFloat32;
case F64:
return ncclFloat64;
default:
return tensorflow::errors::InvalidArgument(absl::StrFormat(
"Unsupported data type: %s", PrimitiveType_Name(element_type)));
}
}
bool IsGlobalNcclConfig() {
static bool global_nccl_config = std::getenv("NCCL_COMM_ID") != nullptr;
return global_nccl_config;
}
Status ToStatus(ncclResult_t s, const char* file, int64 line,
const char* expr) {
if (s == ncclSuccess) {
return Status::OK();
}
return tensorflow::errors::Internal(
absl::StrFormat("%s:%d: NCCL operation %s failed: %s", file, line, expr,
ncclGetErrorString(s)));
}
Status ToStatus(cudaError_t s, const char* file, int64 line, const char* expr) {
if (s == cudaSuccess) {
return Status::OK();
}
return tensorflow::errors::Internal(
absl::StrFormat("%s:%d: CUDA operation %s failed: %s", file, line, expr,
cudaGetErrorString(s)));
}
NcclClique::NcclClique(
absl::flat_hash_map<int, NcclComm> comms_by_device_ordinal)
: comms_by_device_ordinal_(std::move(comms_by_device_ordinal)) {}
ncclComm_t NcclClique::GetCommForDeviceOrdinal(int device_ordinal) const {
return comms_by_device_ordinal_.at(device_ordinal).get();
}
namespace {
void DestroyNcclComm(ncclComm_t comm) {
VLOG(3) << absl::StreamFormat("Destroying comm %p", comm);
XLA_CUDA_WARN_IF_ERROR(ncclCommDestroy(comm));
}
Status ToNcclUniqueId(const std::string& str_id, ncclUniqueId* nccl_id) {
if (str_id.size() != NCCL_UNIQUE_ID_BYTES) {
return InvalidArgument(
"ncclUniqueId string must have %d bytes, got %d bytes", str_id.size(),
NCCL_UNIQUE_ID_BYTES);
}
// NcclUniqueId is internally just a char[].
static_assert(sizeof(ncclUniqueId) == NCCL_UNIQUE_ID_BYTES,
"NCCL_UNIQUE_ID_BYTES");
std::memcpy(static_cast<void*>(nccl_id), str_id.data(), NCCL_UNIQUE_ID_BYTES);
return Status::OK();
}
std::string LocalParticipantsToString(
const std::vector<LocalParticipant>& local_participants) {
std::vector<std::string> parts;
for (const LocalParticipant& local_participant : local_participants) {
parts.push_back(absl::StrFormat("%d/rank=%d",
local_participant.device_ordinal,
local_participant.rank));
}
return absl::StrJoin(parts, ",");
}
RefcountingHashMap<NcclCliqueKey, NcclClique>& NcclCliqueCache() {
// Global cache of NCCL cliques. An entry in this map is kept alive as long
// as there's a reference to it somewhere. A Thunk holds a reference to each
// Clique it's ever used.
//
// A consequence of the fact that this is process-global is that we'll only
// ever have one clique alive for a given set of GPUs. This means that a
// process will never do two collective operations concurrently on the same
// set of GPUs.
static auto& cache = *new RefcountingHashMap<NcclCliqueKey, NcclClique>();
return cache;
}
StatusOr<std::unique_ptr<NcclClique>> CreateNcclClique(
const NcclCliqueKey& key,
const std::vector<LocalParticipant>& local_participants,
const NcclUniqueIdCallback* callback) {
int num_participants = key.devices().size();
ncclUniqueId unique_id;
if (callback) { // Multi-host collective.
TF_ASSIGN_OR_RETURN(std::string id_string, (*callback)(key));
TF_RETURN_IF_ERROR(ToNcclUniqueId(id_string, &unique_id));
} else {
TF_RET_CHECK((num_participants == local_participants.size()) ||
IsGlobalNcclConfig())
<< "If non-local devices are taking part of a collective API on GPU, "
"the nccl_unique_id_callback must be provided by the client.";
XLA_CUDA_RETURN_IF_ERROR(ncclGetUniqueId(&unique_id));
}
VLOG(3) << "Initializing nccl comms for local participants: "
<< LocalParticipantsToString(local_participants);
// Restore CUDA device after running this. XLA shouldn't care, but maybe
// another consumer does.
int initial_cuda_device;
XLA_CUDA_RETURN_IF_ERROR(cudaGetDevice(&initial_cuda_device));
auto cuda_device_restorer = MakeCleanup(
[&] { XLA_CUDA_WARN_IF_ERROR(cudaSetDevice(initial_cuda_device)); });
// When using ncclGroupStart/End it seems that the ncclComm_t's are not
// populated until the End() call.
std::vector<ncclComm_t> raw_comms(local_participants.size(), nullptr);
XLA_CUDA_RETURN_IF_ERROR(ncclGroupStart());
Status status = [&] {
for (int i = 0; i < local_participants.size(); ++i) {
XLA_CUDA_RETURN_IF_ERROR(
cudaSetDevice(local_participants[i].device_ordinal));
XLA_CUDA_RETURN_IF_ERROR(ncclCommInitRank(&raw_comms[i], num_participants,
unique_id,
local_participants[i].rank));
}
return Status::OK();
}();
// Always call ncclGroupEnd().
status.Update(XLA_CUDA_STATUS(ncclGroupEnd()));
// Always copy raw comms to RAII type, so they are cleaned up properly.
absl::flat_hash_map<int, NcclComm> comms_by_device_ordinal(raw_comms.size());
for (int i = 0; i < raw_comms.size(); ++i) {
int device_ordinal = local_participants[i].device_ordinal;
VLOG(3) << absl::StreamFormat("Device ordinal %d assigned ncclComm %p",
device_ordinal, raw_comms[i]);
CHECK(raw_comms[i] != nullptr || !status.ok());
comms_by_device_ordinal.emplace(device_ordinal,
NcclComm(raw_comms[i], &DestroyNcclComm));
}
// Now we can check if there was an error creating the communicators.
TF_RETURN_IF_ERROR(status);
return std::make_unique<NcclClique>(std::move(comms_by_device_ordinal));
}
struct NcclCliqueParticipantData : public ParticipantData {
using ParticipantData::ParticipantData;
std::string ToString() const override { return ""; }
};
class NcclCliqueRendezvous
: public Rendezvous<NcclCliqueParticipantData, LockedNcclClique> {
public:
NcclCliqueRendezvous(const RendezvousKey& rendezvous_key,
const std::vector<LocalParticipant>& local_participants,
const NcclUniqueIdCallback* callback)
: Rendezvous(rendezvous_key) {
NcclCliqueKey key(std::move(rendezvous_key.global_devices));
maybe_clique_ = NcclCliqueCache().GetOrTryCreateIfAbsent(
key, [&](const NcclCliqueKey& key) {
return CreateNcclClique(key, local_participants, callback);
});
if (maybe_clique_.ok()) {
lock_ = std::make_shared<absl::MutexLock>((*maybe_clique_)->mu());
}
}
StatusOr<ParticipantImplOutput> RunCollectiveOp(
const NcclCliqueParticipantData&) override {
bool primary = InitializationBarrier();
TF_ASSIGN_OR_RETURN(std::shared_ptr<NcclClique> clique, maybe_clique_);
return ParticipantImplOutput{primary, LockedNcclClique{clique, lock_}};
}
private:
StatusOr<std::shared_ptr<NcclClique>> maybe_clique_;
std::shared_ptr<absl::MutexLock> lock_;
};
} // namespace
StatusOr<std::vector<LocalParticipant>> GetLocalParticipants(
const std::vector<GlobalDeviceId>& participants,
const std::vector<GlobalDeviceId>* local_devices) {
std::vector<LocalParticipant> local_participants;
if (local_devices) {
absl::flat_hash_map<GlobalDeviceId, int> device_ranks(participants.size());
for (int rank = 0; rank < participants.size(); ++rank) {
auto result = device_ranks.emplace(participants[rank], rank);
TF_RET_CHECK(result.second) << "Duplicate device found";
}
local_participants.reserve(local_devices->size());
for (int device_ordinal = 0; device_ordinal < local_devices->size();
++device_ordinal) {
auto it = device_ranks.find((*local_devices)[device_ordinal]);
if (it != device_ranks.end()) {
local_participants.push_back({device_ordinal, /*rank=*/it->second});
}
}
} else { // Single host, so use identity mapping (device ordinal == id).
local_participants.reserve(participants.size());
for (int rank = 0; rank < participants.size(); ++rank) {
int device_ordinal = participants[rank].value();
local_participants.push_back({device_ordinal, rank});
}
}
return local_participants;
}
StatusOr<LockedNcclClique> AcquireNcclClique(
const RendezvousKey& rendezvous_key, int local_device_ordinal,
se::Stream* stream, const std::vector<LocalParticipant>& local_participants,
const NcclUniqueIdCallback* callback) {
VLOG(2) << "Rendezvous key: " << rendezvous_key.ToString()
<< ", local participants: "
<< LocalParticipantsToString(local_participants);
static auto& rendezvous_map =
*new RefcountingHashMap<RendezvousKey, NcclCliqueRendezvous>();
NcclCliqueParticipantData participant(rendezvous_key, local_device_ordinal,
stream);
return NcclCliqueRendezvous::SubmitParticipant(
/*rendezvous_getter=*/
[&] {
return rendezvous_map.GetOrCreateIfAbsent(
rendezvous_key, [&](const RendezvousKey& rendezvous_key) {
return std::make_unique<NcclCliqueRendezvous>(
rendezvous_key, local_participants, callback);
});
},
participant);
}
absl::flat_hash_set<GlobalDeviceId> DevicesWithOpenNcclChannels() {
absl::flat_hash_set<GlobalDeviceId> devices;
NcclCliqueCache().ForEach(
[&](const NcclCliqueKey& k, const std::shared_ptr<NcclClique>&) {
devices.insert(k.devices().begin(), k.devices().end());
});
return devices;
}
} // namespace gpu
} // namespace xla


@@ -1,136 +0,0 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_NCCL_UTILS_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_NCCL_UTILS_H_
#include <memory>
#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "absl/synchronization/mutex.h"
#if GOOGLE_CUDA
#include "third_party/nccl/nccl.h"
#elif TENSORFLOW_USE_ROCM
#include "rocm/include/rccl/rccl.h"
#endif
#include "tensorflow/compiler/xla/service/collective_ops_utils.h"
#include "tensorflow/compiler/xla/service/gpu/gpu_executable_run_options.h"
#include "tensorflow/compiler/xla/status.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#if TENSORFLOW_USE_ROCM
// Local hipify of cuda symbols
#define cudaError_t hipError_t
#define cudaStream_t hipStream_t
#define cudaGetErrorString hipGetErrorString
#define cudaGetDevice hipGetDevice
#define cudaSetDevice hipSetDevice
#define cudaSuccess hipSuccess
#endif
namespace xla {
namespace gpu {
ncclRedOp_t ToNcclReduction(ReductionKind kind);
StatusOr<ncclDataType_t> ToNcclDataType(PrimitiveType element_type);
bool IsGlobalNcclConfig();
Status ToStatus(ncclResult_t s, const char* file, int64 line, const char* expr);
Status ToStatus(cudaError_t s, const char* file, int64 line, const char* expr);
// Macros to return or warn on CUDA/NCCL errors. (The same macro works for both
// NCCL and CUDA errors.)
//
// It's tempting to say these macros belong in an XLA header somewhere, but in
// practice we don't do much direct-to-CUDA-API stuff outside of this file.
#define XLA_CUDA_STATUS(expr) \
::xla::gpu::ToStatus(expr, __FILE__, __LINE__, #expr)
#define XLA_CUDA_RETURN_IF_ERROR(expr) \
do { \
Status s = XLA_CUDA_STATUS(expr); \
if (!s.ok()) { \
return s; \
} \
} while (0)
#define XLA_CUDA_WARN_IF_ERROR(expr) \
do { \
Status s = XLA_CUDA_STATUS(expr); \
if (!s.ok()) { \
LOG(ERROR) << s.ToString(); \
} \
} while (0)
// RAII type for NCCL communicators.
using NcclComm = std::unique_ptr<ncclComm, void (*)(ncclComm_t)>;
// Owns a clique of NCCL comms which can be used for collective operations among
// a particular set of GPUs.
//
// Note that if you want to do a collective operation among a subset of these
// GPUs, you'll need a different clique.
class NcclClique {
public:
explicit NcclClique(
absl::flat_hash_map<int, NcclComm> comms_by_device_ordinal);
ncclComm_t GetCommForDeviceOrdinal(int device_ordinal) const;
absl::Mutex* mu() { return &mu_; }
private:
absl::flat_hash_map<int, NcclComm> comms_by_device_ordinal_;
absl::Mutex mu_;
};
struct LocalParticipant {
int device_ordinal;
int rank;
};
StatusOr<std::vector<LocalParticipant>> GetLocalParticipants(
const std::vector<GlobalDeviceId>& participants,
const std::vector<GlobalDeviceId>* local_devices); // may be null
struct LockedNcclClique {
std::shared_ptr<NcclClique> clique;
// Must come after clique, so it is destroyed first.
// This lock prevents other threads from using this clique. All of the threads
// involved should hold onto the lock until they have finished with their
// communicator.
std::shared_ptr<absl::MutexLock> lock;
};
// Acquires a locked NCCL clique for use in NCCL collective operations.
StatusOr<LockedNcclClique> AcquireNcclClique(
const RendezvousKey& rendezvous_key, int local_device_ordinal,
se::Stream* stream, const std::vector<LocalParticipant>& local_participants,
const NcclUniqueIdCallback* callback); // may be null
// Gets the set of devices that have a NCCL channel open. This is primarily
// for testing.
//
// (Indeed, because the NCCL channels are a global variable, in the real
// world, the value returned here is stale as soon as you read it, so it's not
// clear how you *could* use it for anything other than tests.)
absl::flat_hash_set<GlobalDeviceId> DevicesWithOpenNcclChannels();
} // namespace gpu
} // namespace xla
#endif // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_NCCL_UTILS_H_


@@ -17,9 +17,7 @@ limitations under the License.
#include "absl/types/span.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/primitive_util.h"
#ifdef GOOGLE_CUDA
#include "tensorflow/compiler/xla/service/gpu/nccl_utils.h"
#endif
#include "tensorflow/compiler/xla/service/gpu/nccl_all_reduce_thunk.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/test.h"
#include "tensorflow/compiler/xla/test_helpers.h"
@@ -162,11 +160,7 @@ DeviceAssignment MakeDeviceAssn(std::vector<int64> devices) {
// Shorter alias for this function.
absl::flat_hash_set<GlobalDeviceId> OpenNcclChannels() {
#ifdef GOOGLE_CUDA
return gpu::DevicesWithOpenNcclChannels();
#else
return {};
#endif
return gpu::NcclAllReduceThunk::DevicesWithOpenNcclChannels();
}
template <typename T>