Split out common NCCL utils.
Changes made while refactoring:
- The rendezvous is now encapsulated in the code that acquires the NCCL clique.
- Fixed an issue where `ncclComm`s might not be cleaned up correctly.
- Don't sort device IDs - order matters for several of the NCCL collective ops.
- Removed support for cleanup in `Rendezvous`, as it is no longer needed.

PiperOrigin-RevId: 345429594
Change-Id: Ide30958a12fd8d63e16cd764926c844b097c6be1
parent 21f003d154
commit 8d5103318f
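The "cleaned up correctly" fix amounts to giving every ncclComm_t an RAII owner the moment it is created, so a failure partway through clique initialization still destroys every communicator that did get created. A minimal sketch of that ownership pattern (illustrative only; the new nccl_utils.cc below uses its own equivalent NcclComm wrapper):

#include <memory>

#include "third_party/nccl/nccl.h"  // ncclComm_t, ncclCommDestroy

// Own a raw communicator in a unique_ptr with a custom deleter; the comm is
// destroyed automatically when the owner goes out of scope.
using OwnedNcclComm = std::unique_ptr<ncclComm, ncclResult_t (*)(ncclComm_t)>;

OwnedNcclComm OwnComm(ncclComm_t raw) {
  return OwnedNcclComm(raw, &ncclCommDestroy);
}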
@@ -254,8 +254,6 @@ class Rendezvous {
     return false;
   }
 
-  virtual void CleanupImpl(O handle, bool is_primary) {}
-
   tensorflow::mutex mu_;
 
   bool initialized_ TF_GUARDED_BY(mu_) = false;
@@ -296,34 +294,14 @@ class Rendezvous {
          participant.device_ordinal, participant.stream, key_.ToString());
    });
 
-    StatusOr<ParticipantImplOutput> p_or = RunCollectiveOp(participant);
+    TF_ASSIGN_OR_RETURN(ParticipantImplOutput p, RunCollectiveOp(participant));
 
-    done_.DecrementCount();
-    if (!p_or.ok()) {
-      return p_or.status();
-    }
-    ParticipantImplOutput p = p_or.ValueOrDie();
-
-    // The primary owns the lock on the NCCL clique. Hold it until all threads
-    // are done. (We'll release it when we return from this function.)
-    if (p.is_primary) {
-      WaitAndLogIfStuck(&done_, [&] {
-        return absl::StrFormat(
-            "primary participant waiting for all other participants to "
-            "complete all-reduce %s",
-            key_.ToString());
-      });
-    }
-
-    CleanupImpl(p.custom_output, p.is_primary);
-
    return std::make_pair(p.custom_output, returned_blocking_counter_);
  }
 
  const RendezvousKey key_;
 
  tensorflow::BlockingCounter all_participants_present_{
      key_.num_local_participants};
-  tensorflow::BlockingCounter done_{key_.num_local_participants};
 
  // tensorflow::BlockingCounter returned by SubmitParticipant.
  std::shared_ptr<tensorflow::BlockingCounter> returned_blocking_counter_{
@@ -455,17 +455,17 @@ tf_cuda_library(
        ":thunk",
        ":gpu_executable_run_options",
        "@com_google_absl//absl/base:core_headers",
-        "//tensorflow/compiler/xla/service:pattern_matcher",
-        "//tensorflow/compiler/xla:refcounting_hash_map",
-        "//tensorflow/compiler/xla/service:collective_ops_utils",
-        "@com_google_absl//absl/container:flat_hash_set",
-        "@com_google_absl//absl/synchronization",
+        "@com_google_absl//absl/strings:str_format",
+        "//tensorflow/compiler/xla:shape_util",
        "//tensorflow/compiler/xla:util",
+        "//tensorflow/compiler/xla:xla_data_proto_cc",
        "//tensorflow/compiler/xla/service:buffer_assignment",
+        "//tensorflow/compiler/xla/service:collective_ops_utils",
        "//tensorflow/compiler/xla/service:hlo",
+        "//tensorflow/compiler/xla/service:hlo_casting_utils",
+        "//tensorflow/compiler/xla/service:pattern_matcher",
        "//tensorflow/core:lib",
        "//tensorflow/core/platform:stream_executor_no_cuda",
-        "//tensorflow/compiler/xla:xla_data_proto_cc",
    ] + if_cuda([
        "//tensorflow/stream_executor/cuda:cuda_activation",
        "//tensorflow/stream_executor/cuda:cuda_gpu_executor",
@@ -474,10 +474,77 @@ tf_cuda_library(
        "//tensorflow/stream_executor/rocm:rocm_gpu_executor",
    ]) + if_nccl([
        ":virtual_nccl",
+        ":virtual_nccl_utils",
        ":virtual_rccl",
    ]),
 )
 
+# First level of nested select. NCCL requires both if_cuda and if_nccl.
+filegroup(
+    name = "nccl_test_utils_src",
+    srcs = if_nccl(
+        ["nccl_test_utils.cc"],
+        ["dummy_nccl_test_utils.cc"],
+    ),
+)
+
+tf_cuda_library(
+    name = "nccl_test_utils",
+    srcs = if_cuda_or_rocm(
+        [":nccl_test_utils_src"],
+        ["dummy_nccl_test_utils.cc"],
+    ),
+    hdrs = ["nccl_test_utils.h"],
+    deps = [
+        "@com_google_absl//absl/container:flat_hash_set",
+        "//tensorflow/compiler/xla/service:global_device_id",
+    ] + if_nccl([
+        ":virtual_nccl_utils",
+    ]),
+)
+
+# First level of nested select. NCCL requires both if_cuda and if_nccl.
+filegroup(
+    name = "nccl_utils_srcs",
+    srcs = if_nccl(["nccl_utils.cc"]),
+)
+
+# First level of nested select. NCCL requires both if_cuda and if_nccl.
+filegroup(
+    name = "nccl_utils_hdrs",
+    srcs = if_nccl(["nccl_utils.h"]),
+)
+
+tf_cuda_library(
+    name = "nccl_utils",
+    srcs = if_cuda_or_rocm([":nccl_utils_srcs"]),
+    hdrs = if_cuda_or_rocm([":nccl_utils_hdrs"]),
+    deps = if_cuda_or_rocm([
+        ":gpu_executable_run_options",
+        "@com_google_absl//absl/container:flat_hash_set",
+        "@com_google_absl//absl/container:flat_hash_map",
+        "@com_google_absl//absl/strings:str_format",
+        "@com_google_absl//absl/synchronization",
+        "//tensorflow/compiler/xla:refcounting_hash_map",
+        "//tensorflow/compiler/xla:status",
+        "//tensorflow/compiler/xla:status_macros",
+        "//tensorflow/compiler/xla:statusor",
+        "//tensorflow/compiler/xla:xla_data_proto_cc",
+        "//tensorflow/compiler/xla/service:collective_ops_utils",
+        "//tensorflow/compiler/xla/service:global_device_id",
+        "//tensorflow/core:lib",
+        "//tensorflow/stream_executor/lib",
+    ]) + if_nccl([
+        ":virtual_nccl",
+        ":virtual_rccl",
+    ]),
+)
+
+alias(
+    name = "virtual_nccl_utils",
+    actual = if_cuda_or_rocm(":nccl_utils", ":empty"),
+)
+
 cc_library(
     name = "gpu_debug_info_manager",
     srcs = [
@@ -44,11 +44,6 @@ Status NcclAllReduceThunk::ExecuteOnStream(const ExecuteParams& params) {
      "compiler, which is necessary to build the NCCL source library.");
 }
 
-/*static*/ absl::flat_hash_set<GlobalDeviceId>
-NcclAllReduceThunk::DevicesWithOpenNcclChannels() {
-  return {};
-}
-
 NcclAllReduceThunk::NcclAllReduceThunk(
    ThunkInfo thunk_info, NcclAllReduceConfig &&config,
    std::vector<NcclAllReduceThunk::Buffer> buffers)
tensorflow/compiler/xla/service/gpu/dummy_nccl_test_utils.cc (new file, 24 lines)
@@ -0,0 +1,24 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/service/gpu/nccl_test_utils.h"
+
+namespace xla {
+namespace gpu {
+
+absl::flat_hash_set<GlobalDeviceId> DevicesWithOpenNcclChannels() { return {}; }
+
+}  // namespace gpu
+}  // namespace xla
@@ -21,12 +21,7 @@ namespace xla {
 namespace gpu {
 
 NcclCliqueKey::NcclCliqueKey(std::vector<GlobalDeviceId> devices)
-    : devices_(std::move(devices)) {
-  absl::c_sort(devices_);
-  CHECK(absl::c_adjacent_find(devices_) == devices_.end())
-      << "Duplicate devices are not allowed: "
-      << GlobalDeviceIdsToString(devices_);
-}
+    : devices_(std::move(devices)) {}
 
 std::string NcclCliqueKey::ToString() const {
  return GlobalDeviceIdsToString(devices_);
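The constructor change above is the "don't sort device IDs" fix: a participant's NCCL rank is derived from its position in the clique's device list, so sorting the key's devices would silently renumber ranks and reorder the results of rank-sensitive collectives. Roughly, as a stand-alone illustration (hypothetical names, not the diff's code):

#include <algorithm>
#include <cstdint>
#include <vector>

using GlobalDeviceId = int64_t;  // stand-in for XLA's strongly typed ID

// A device's NCCL rank is its position in the (unsorted) device list, so the
// list order is part of the clique's identity and must be preserved.
int RankOf(const std::vector<GlobalDeviceId>& devices, GlobalDeviceId id) {
  auto it = std::find(devices.begin(), devices.end(), id);
  return static_cast<int>(it - devices.begin());
}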
@@ -22,40 +22,21 @@ limitations under the License.
 #include <utility>
 #include <vector>
 
-#include "absl/algorithm/container.h"
-#include "absl/base/thread_annotations.h"
 #include "absl/container/flat_hash_set.h"
-#include "absl/memory/memory.h"
 #include "absl/strings/str_format.h"
-#include "absl/strings/str_join.h"
-#include "absl/types/optional.h"
-#include "absl/types/span.h"
 #if GOOGLE_CUDA
 #include "third_party/nccl/nccl.h"
 #elif TENSORFLOW_USE_ROCM
 #include "rocm/include/rccl/rccl.h"
 #endif
 #include "tensorflow/compiler/xla/layout_util.h"
-#include "tensorflow/compiler/xla/refcounting_hash_map.h"
-#include "tensorflow/compiler/xla/service/collective_ops_utils.h"
-#include "tensorflow/compiler/xla/service/gpu/gpu_executable_run_options.h"
+#include "tensorflow/compiler/xla/service/gpu/nccl_utils.h"
 #include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
 #include "tensorflow/compiler/xla/service/hlo_instructions.h"
 #include "tensorflow/compiler/xla/util.h"
-#include "tensorflow/core/lib/core/blocking_counter.h"
 #include "tensorflow/core/platform/mutex.h"
 #include "tensorflow/stream_executor/gpu/gpu_activation.h"
 
-#if TENSORFLOW_USE_ROCM
-// Local hipify of cuda symbols
-#define cudaError_t hipError_t
-#define cudaStream_t hipStream_t
-#define cudaGetErrorString hipGetErrorString
-#define cudaGetDevice hipGetDevice
-#define cudaSetDevice hipSetDevice
-#define cudaSuccess hipSuccess
-#endif
-
 namespace xla {
 namespace gpu {
 
@@ -82,432 +63,6 @@ namespace gpu {
  return true;  // Skylark selects this source file if NCCL is enabled.
 }
 
-namespace {
-
-bool IsGlobalNcclConfig() {
-  static bool global_nccl_config = std::getenv("NCCL_COMM_ID") != nullptr;
-  return global_nccl_config;
-}
-
-// Functions to translate an ncclResult_t/cudaError_t to a Status object. Used
-// by the macros below.
-Status TranslateStatus(ncclResult_t s, const char* file, int64 line,
-                       const char* expr) {
-  if (s == ncclSuccess) {
-    return Status::OK();
-  }
-  return tensorflow::errors::Internal(
-      absl::StrFormat("%s:%d: NCCL operation %s failed: %s", file, line, expr,
-                      ncclGetErrorString(s)));
-}
-
-Status TranslateStatus(cudaError_t s, const char* file, int64 line,
-                       const char* expr) {
-  if (s == cudaSuccess) {
-    return Status::OK();
-  }
-  return tensorflow::errors::Internal(
-      absl::StrFormat("%s:%d: CUDA operation %s failed: %s", file, line, expr,
-                      cudaGetErrorString(s)));
-}
-
-// Macros to return or warn on CUDA/NCCL errors. (The same macro works for both
-// NCCL and CUDA errors.)
-//
-// It's tempting to say these macros belong in an XLA header somewhere, but in
-// practice we don't do much direct-to-CUDA-API stuff outside of this file.
-#define XLA_CUDA_RETURN_IF_ERROR(expr)                                       \
-  do {                                                                       \
-    Status s = ::xla::gpu::TranslateStatus(expr, __FILE__, __LINE__, #expr); \
-    if (!s.ok()) {                                                           \
-      return s;                                                              \
-    }                                                                        \
-  } while (0)
-
-#define XLA_CUDA_WARN_IF_ERROR(expr)                                         \
-  do {                                                                       \
-    Status s = ::xla::gpu::TranslateStatus(expr, __FILE__, __LINE__, #expr); \
-    if (!s.ok()) {                                                           \
-      LOG(ERROR) << s.ToString();                                            \
-    }                                                                        \
-  } while (0)
-
-// RAII class owning a ncclComm_t, ensuring it doesn't leak.
-class NcclComm {
- public:
-  explicit NcclComm(ncclComm_t comm) : comm_(comm) {}
-
-  // Movable, but not copyable.
-  NcclComm(NcclComm&& c) noexcept : comm_(c.comm_) { c.comm_.reset(); }
-  NcclComm& operator=(NcclComm&& c) noexcept {
-    comm_ = c.comm_;
-    c.comm_.reset();
-    return *this;
-  }
-  NcclComm(const NcclComm&) = delete;
-  NcclComm& operator=(const NcclComm&) = delete;
-
-  ~NcclComm() {
-    if (comm_.has_value() && *comm_ != nullptr) {
-      VLOG(3) << absl::StreamFormat("Destroying comm %p", *comm_);
-      XLA_CUDA_WARN_IF_ERROR(ncclCommDestroy(*comm_));
-    }
-  }
-
-  ncclComm_t comm() { return *comm_; }
-
- private:
-  absl::optional<ncclComm_t> comm_;
-};
-
-ncclRedOp_t ReductionKindToNccl(ReductionKind kind) {
-  switch (kind) {
-    case ReductionKind::SUM:
-      return ncclSum;
-    case ReductionKind::PRODUCT:
-      return ncclProd;
-    case ReductionKind::MIN:
-      return ncclMin;
-    case ReductionKind::MAX:
-      return ncclMax;
-  }
-}
-
-absl::optional<ncclDataType_t> DatatypeToNccl(PrimitiveType element_type) {
-  switch (element_type) {
-    case S8:
-      return ncclInt8;
-    case PRED:
-    case U8:
-      return ncclUint8;
-    case S32:
-      return ncclInt32;
-    case U32:
-      return ncclUint32;
-    case S64:
-      return ncclInt64;
-    case U64:
-      return ncclUint64;
-    case F16:
-      return ncclFloat16;
-    case F32:
-      return ncclFloat32;
-    case F64:
-      return ncclFloat64;
-    default:
-      return absl::nullopt;
-  }
-}
-
-Status StringToNcclUniqueId(const std::string& str_id, ncclUniqueId* nccl_id) {
-  if (str_id.size() != NCCL_UNIQUE_ID_BYTES) {
-    return InvalidArgument(
-        "ncclUniqueId string must have %d bytes, got %d bytes", str_id.size(),
-        NCCL_UNIQUE_ID_BYTES);
-  }
-  // NcclUniqueId is internally just a char[].
-  static_assert(sizeof(ncclUniqueId) == NCCL_UNIQUE_ID_BYTES,
-                "NCCL_UNIQUE_ID_BYTES");
-  std::memcpy(static_cast<void*>(nccl_id), str_id.data(), NCCL_UNIQUE_ID_BYTES);
-  return Status::OK();
-}
-
-// Owns a clique of NCCL comms which can be used for collective operations among
-// a particular set of GPUs.
-//
-// You must ensure this is not in an error state (i.e. status() is OK) before
-// touching any other methods.
-//
-// (Usually allowing objects to be in a constructed-but-uninitialized state is
-// an antipattern. We do it here because it allows us to have a
-// RefcountingHashMap which contains and automatically constructs NcclCliques.
-// This greatly simplifies the rest of this file.)
-//
-// Note that if you want to do a collective operation among a subset of these
-// GPUs, you'll need a different clique.
-class NcclClique {
- public:
-  explicit NcclClique(
-      int64 num_global_devices, std::vector<int64> local_device_ordinals,
-      std::vector<int64> local_device_ranks,
-      const StatusOr<absl::optional<std::string>>& nccl_unique_id)
-      : num_global_devices_(num_global_devices),
-        local_device_ordinals_(std::move(local_device_ordinals)),
-        local_device_ranks_(std::move(local_device_ranks)) {
-    CHECK_EQ(local_device_ordinals_.size(), local_device_ranks_.size());
-    // It's unusual to pass a StatusOr<> into a class, but since this class
-    // already has a erroneous state, it turns out to be a little easier to
-    // implement this way than to change RefcountingHashMap.
-    status_ = Init(nccl_unique_id);
-  }
-
-  Status status() { return status_; }
-
-  // A NCCL communicator is the NCCL state associated with a participant (rank)
-  // in a reduction. This method returns the state associated with a particular
-  // local device ordinal.
-  ncclComm_t comm(int64 device_ordinal) {
-    int64 idx =
-        std::distance(local_device_ordinals_.begin(),
-                      absl::c_find(local_device_ordinals_, device_ordinal));
-    return comms_.at(idx).comm();
-  }
-
-  // These methods let you acquire exclusive access to a NCCL clique, ensuring
-  // no other NCCL operations are taking place on the clique's comms.
-  //
-  // We disable thread-safety analysis because in common use, only the primary
-  // thread in a Rendezvous acquires this lock, and that makes thread-safety
-  // analysis unhappy. Tread carefully, you are playing with fire.
-  void Lock() ABSL_NO_THREAD_SAFETY_ANALYSIS {
-    TF_CHECK_OK(status_);
-    mu_->lock();
-  }
-  void Unlock() ABSL_NO_THREAD_SAFETY_ANALYSIS {
-    TF_CHECK_OK(status_);
-    mu_->unlock();
-  }
-
- private:
-  Status Init(
-      const StatusOr<absl::optional<std::string>>& maybe_nccl_unique_id) {
-    VLOG(3) << absl::StreamFormat(
-        "Initializing nccl comms for participant device ordinals %s ranks {%s}",
-        absl::StrJoin(local_device_ordinals_, ", "),
-        absl::StrJoin(local_device_ranks_, ", "));
-
-    // Restore CUDA device after running this. XLA shouldn't care, but maybe
-    // another consumer does.
-    int initial_cuda_device;
-    XLA_CUDA_RETURN_IF_ERROR(cudaGetDevice(&initial_cuda_device));
-    auto cuda_device_restorer = MakeCleanup(
-        [&] { XLA_CUDA_WARN_IF_ERROR(cudaSetDevice(initial_cuda_device)); });
-
-    // When using ncclGroupStart/End it seems that the ncclComm_t's are not
-    // populated until the End() call. This unfortunately makes error handling
-    // tricky.
-    std::vector<ncclComm_t> raw_comms(local_device_ordinals_.size(), nullptr);
-    TF_ASSIGN_OR_RETURN(const absl::optional<std::string>& nccl_id_string,
-                        maybe_nccl_unique_id);
-    ncclUniqueId nccl_id;
-    if (nccl_id_string) {
-      TF_RETURN_IF_ERROR(StringToNcclUniqueId(*nccl_id_string, &nccl_id));
-    } else {
-      XLA_CUDA_RETURN_IF_ERROR(ncclGetUniqueId(&nccl_id));
-    }
-    XLA_CUDA_RETURN_IF_ERROR(ncclGroupStart());
-    Status status = [&] {
-      for (int i = 0; i < local_device_ordinals_.size(); ++i) {
-        XLA_CUDA_RETURN_IF_ERROR(cudaSetDevice(local_device_ordinals_[i]));
-        XLA_CUDA_RETURN_IF_ERROR(ncclCommInitRank(&raw_comms[i],
-                                                  num_global_devices_, nccl_id,
-                                                  local_device_ranks_.at(i)));
-      }
-      return Status::OK();
-    }();
-    // Always call ncclGroupEnd().
-    XLA_CUDA_RETURN_IF_ERROR(ncclGroupEnd());
-
-    // Populate comms_ from the raw comms we created above. If we encountered
-    // an error above we'll later clear comms_ thus destroying any raw comms
-    // that were created before the error.
-    for (int i = 0; i < local_device_ordinals_.size(); ++i) {
-      VLOG(3) << absl::StreamFormat("Device ordinal %d assigned ncclComm %p",
-                                    local_device_ordinals_[i], raw_comms[i]);
-      CHECK(raw_comms[i] != nullptr || !status.ok());
-      comms_.emplace_back(raw_comms[i]);
-    }
-    if (!status.ok()) {
-      comms_.clear();
-    }
-
-    return status;
-  }
-
-  Status status_;
-  int64 num_global_devices_;
-  std::vector<int64> local_device_ordinals_;
-  // NCCL communicator rank for each local device. The rank of a device is equal
-  // to the offset of the local device in the global device set.
-  std::vector<int64> local_device_ranks_;
-  std::vector<NcclComm> comms_;
-
-  // This mutex is in a unique_ptr so NcclClique can be movable.
-  std::unique_ptr<tensorflow::mutex> mu_ =
-      absl::make_unique<tensorflow::mutex>();
-};
-
-// Global cache of NCCL cliques. An entry in this map is kept alive as long as
-// there's a reference to it somewhere. A Thunk holds a reference to each
-// Clique it's ever used.
-//
-// A consequence of the fact that this is process-global is that we'll only ever
-// have one clique alive for a given set of GPUs. This means that a process
-// will never do two collective operations concurrently on the same set of GPUs.
-RefcountingHashMap<NcclCliqueKey, NcclClique>& GlobalNcclCliqueMap() {
-  static auto& m = *new RefcountingHashMap<NcclCliqueKey, NcclClique>();
-  return m;
-}
-
-using RendezvousBase =
-    Rendezvous<AllReduceParticipantData, std::shared_ptr<NcclClique>>;
-class RendezvousNcclAllReduce : public RendezvousBase {
- public:
-  explicit RendezvousNcclAllReduce(const RendezvousKey& k)
-      : RendezvousBase(k) {}
-
- protected:
-  StatusOr<ParticipantImplOutput> RunCollectiveOp(
-      const AllReduceParticipantData& participant) override;
-
-  void CleanupImpl(std::shared_ptr<NcclClique> handle,
-                   bool is_primary) override;
-};
-
-// Global map of Rendezvous objects. A thread participating in a collective op
-// looks up its Rendezvous in this map to find the other threads that it's
-// participating with.
-//
-// Rendezvous objects are one-time use, so they're removed from this map once
-// we're through with them.
-RefcountingHashMap<RendezvousKey, RendezvousNcclAllReduce>&
-GlobalRendezvousMap() {
-  static auto& m =
-      *new RefcountingHashMap<RendezvousKey, RendezvousNcclAllReduce>();
-  return m;
-}
-
-StatusOr<RendezvousNcclAllReduce::ParticipantImplOutput>
-RendezvousNcclAllReduce::RunCollectiveOp(
-    const AllReduceParticipantData& participant) {
-  // We pull into our thread a) the communication handle and b) whether we're
-  // the "primary" thread for this rendezvous -- the "primary" thread has some
-  // additional responsibilities for setup/teardown.
-  ncclComm_t comm;
-  bool primary;
-  std::shared_ptr<NcclClique> clique;
-
-  {
-    tensorflow::mutex_lock lock(mu_);
-
-    // The first thread to get here has additional responsibilities, such as
-    // ensuring that there's a NCCL clique available for us to use.
-    primary = !initialized_;
-
-    TF_RET_CHECK(participant.local_devices.size() ==
-                 participant.rendezvous_key.num_local_participants);
-
-    // Look up or create the NCCL clique for this set of devices.
-    NcclCliqueKey clique_key(participant.rendezvous_key.global_devices);
-
-    auto clique_factory =
-        [&](const NcclCliqueKey& key) -> std::unique_ptr<NcclClique> {
-      std::vector<int64> local_device_ranks;
-      std::vector<int64> local_device_ordinals;
-      local_device_ranks.reserve(participant.local_devices.size());
-      local_device_ordinals.reserve(participant.local_devices.size());
-      for (const auto& l : participant.local_devices) {
-        auto it =
-            absl::c_find(participant.rendezvous_key.global_devices, l.first);
-        CHECK(it != participant.rendezvous_key.global_devices.end()) << l.first;
-        local_device_ranks.push_back(std::distance(
-            participant.rendezvous_key.global_devices.begin(), it));
-        local_device_ordinals.push_back(l.second);
-      }
-      StatusOr<absl::optional<std::string>> nccl_unique_id;
-      if (participant.nccl_unique_id_callback) {
-        nccl_unique_id = (*participant.nccl_unique_id_callback)(clique_key);
-      } else {
-        if (participant.rendezvous_key.global_devices.size() !=
-                participant.rendezvous_key.num_local_participants &&
-            !IsGlobalNcclConfig()) {
-          nccl_unique_id = InvalidArgument(
-              "If not local devices are taking part of a collective API on "
-              "GPU, the nccl_unique_id_callback must be provided by the "
-              "client.");
-        } else {
-          nccl_unique_id = absl::optional<std::string>();
-        }
-      }
-      return absl::make_unique<NcclClique>(
-          participant.rendezvous_key.global_devices.size(),
-          std::move(local_device_ordinals), std::move(local_device_ranks),
-          nccl_unique_id);
-    };
-    clique =
-        GlobalNcclCliqueMap().GetOrCreateIfAbsent(clique_key, clique_factory);
-
-    if (primary) {
-      VLOG(3) << "Primary initializing accounting data.";
-      initialized_ = true;
-
-      // Acquire exclusive access to the NCCL clique itself so that two
-      // unrelated collective operations won't try to use the clique
-      // concurrently.
-      // We'll unlock it in CleanupImpl.
-      clique->Lock();
-    }
-
-    if (!clique->status().ok()) {
-      VLOG(1)
-          << "SubmitParticipant failing because clique failed to initialize: "
-          << clique->status().ToString();
-      return clique->status();
-    }
-
-    comm = clique->comm(participant.device_ordinal);
-
-    // Drop the lock at the end of scope so other participants may enter.
-  }
-
-  VLOG(3) << "Performing all reduce from device ordinal: "
-          << participant.device_ordinal;
-  ncclRedOp_t computation = ReductionKindToNccl(participant.reduction_kind);
-
-  se::StreamExecutor* executor = participant.stream->parent();
-  se::gpu::ScopedActivateExecutorContext scoped_context(executor);
-  cudaStream_t* cu_stream = reinterpret_cast<cudaStream_t*>(
-      participant.stream->implementation()->GpuStreamMemberHack());
-  VLOG(3) << "Using stream pointer: " << cu_stream
-          << " on device: " << participant.device_ordinal;
-  XLA_CUDA_RETURN_IF_ERROR(ncclGroupStart());
-  for (auto& buffer : participant.buffers) {
-    void* send_buffer = const_cast<void*>(buffer.source_data.opaque());
-    void* recv_buffer = const_cast<void*>(buffer.destination_data.opaque());
-    absl::optional<ncclDataType_t> allreduce_datatype =
-        DatatypeToNccl(buffer.primitive_type);
-    CHECK(allreduce_datatype.has_value());
-    VLOG(3) << absl::StreamFormat(
-        "Calling ncclAllReduce(send_buffer=%p, recv_buffer=%p, count=%d, "
-        "comm=%p, stream=%p)",
-        send_buffer, recv_buffer, buffer.element_count,
-        static_cast<const void*>(comm), cu_stream);
-    XLA_CUDA_RETURN_IF_ERROR(ncclAllReduce(send_buffer, recv_buffer,
-                                           /*count=*/buffer.element_count,
-                                           /*datatype=*/*allreduce_datatype,
-                                           /*op=*/computation,
-                                           /*comm=*/comm,
-                                           /*stream=*/*cu_stream));
-  }
-  XLA_CUDA_RETURN_IF_ERROR(ncclGroupEnd());
-
-  VLOG(3) << "Done performing all reduce for ordinal: "
-          << participant.device_ordinal;
-  VLOG(3) << "This thread done with all-reduce op.";
-
-  return ParticipantImplOutput{primary, clique};
-}
-
-void RendezvousNcclAllReduce::CleanupImpl(std::shared_ptr<NcclClique> handle,
-                                          bool is_primary) {
-  // Releases the lock on the clique (held only by the primary thread).
-  if (is_primary) {
-    handle->Unlock();
-  }
-}
-
-}  // namespace
-
 // Extra data stored in NcclAllReduceThunk that we didn't want to expose in the
 // header. In particular, this stores the thunk's cache of all NcclCliques it's
 // ever used. This causes those cliques to stay alive as long as the thunk
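Everything removed above reappears, reorganized, in the new nccl_utils.cc at the end of this diff. One notable change in how the clique lock travels: instead of the Lock()/Unlock() pair that needed ABSL_NO_THREAD_SAFETY_ANALYSIS, the primary participant now takes an absl::MutexLock held through a shared_ptr, so the clique unlocks automatically when the last participant drops its LockedNcclClique. A minimal sketch of that idiom (hypothetical names, not the diff's code):

#include <memory>

#include "absl/synchronization/mutex.h"

struct Clique {
  absl::Mutex mu;
};

// Every participant shares ownership of one lock; the mutex is released
// when the last shared_ptr copy is destroyed, with no manual Unlock() call.
std::shared_ptr<absl::MutexLock> AcquireShared(Clique* clique) {
  return std::make_shared<absl::MutexLock>(&clique->mu);
}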
@@ -550,23 +105,13 @@ NcclAllReduceConfig GetNcclAllReduceConfig(const HloInstruction* instr,
  auto operands_are_supported = [crs]() {
    return absl::c_all_of(crs->operands(), [](HloInstruction* operand) {
      return LayoutUtil::IsDenseArray(operand->shape()) &&
-             DatatypeToNccl(operand->shape().element_type()).has_value();
+             ToNcclDataType(operand->shape().element_type()).ok();
    });
  };
  return MatchReductionComputation(crs->to_apply()).has_value() &&
         crs->IsCrossReplicaAllReduce() && operands_are_supported();
 }
 
-/*static*/ absl::flat_hash_set<GlobalDeviceId>
-NcclAllReduceThunk::DevicesWithOpenNcclChannels() {
-  absl::flat_hash_set<GlobalDeviceId> devices;
-  GlobalNcclCliqueMap().ForEach(
-      [&](const NcclCliqueKey& k, const std::shared_ptr<NcclClique>&) {
-        devices.insert(k.devices().begin(), k.devices().end());
-      });
-  return devices;
-}
-
 NcclAllReduceThunk::NcclAllReduceThunk(
    ThunkInfo thunk_info, NcclAllReduceConfig&& config,
    std::vector<NcclAllReduceThunk::Buffer> buffers)
@@ -583,97 +128,87 @@ Status NcclAllReduceThunk::ExecuteOnStream(const ExecuteParams& params) {
  auto op_profiler =
      params.profiler->MakeScopedInstructionProfiler(profile_index());
 
-  int64 local_device_ordinal = params.stream->parent()->device_ordinal();
+  se::StreamExecutor* executor = params.stream->parent();
+  int device_ordinal = executor->device_ordinal();
  TF_ASSIGN_OR_RETURN(GlobalDeviceId global_device_id,
                      params.GetGlobalDeviceId());
  // Determines the set of global and local devices that are participating in
  // the same collective group as the caller.
  TF_ASSIGN_OR_RETURN(
-      std::vector<int64> global_participating_replicas,
+      std::vector<int64> participating_replicas,
      GetParticipatingReplicas(global_device_id, config_.replica_groups,
                               config_.replica_count, *params.device_assn));
  if (IsGlobalNcclConfig() &&
-      global_participating_replicas.size() != config_.replica_count) {
+      participating_replicas.size() != config_.replica_count) {
    return InvalidArgument(
        "Partial replica groups are not allowed when using NCCL_COMM_ID "
        "environment configuration.");
  }
 
-  std::vector<GlobalDeviceId> global_devices;
-  std::vector<std::pair<GlobalDeviceId, int64>> local_devices;
-  local_devices.reserve(global_participating_replicas.size());
-  global_devices.reserve(global_participating_replicas.size());
  TF_RET_CHECK(params.device_assn->computation_count() == 1)
      << params.device_assn->ToString();
-  for (int64 replica : global_participating_replicas) {
-    GlobalDeviceId global_device(
+  std::vector<GlobalDeviceId> participants;
+  participants.reserve(participating_replicas.size());
+  for (int64 replica : participating_replicas) {
+    participants.emplace_back(
        (*params.device_assn)(replica, /*computation=*/0));
-    global_devices.push_back(global_device);
-    if (!params.gpu_global_device_ids) {
-      local_devices.emplace_back(global_device, global_device.value());
-    } else {
-      auto it = absl::c_find(*params.gpu_global_device_ids, global_device);
-      if (it != params.gpu_global_device_ids->end()) {
-        local_devices.emplace_back(
-            *it, std::distance(params.gpu_global_device_ids->begin(), it));
-      }
-    }
  }
-  absl::c_sort(global_devices);
+  TF_ASSIGN_OR_RETURN(
+      std::vector<LocalParticipant> local_participants,
+      GetLocalParticipants(participants, params.gpu_global_device_ids));
 
  // Create the rendezvous for this collective operation.
-  RendezvousKey rendezvous_key(params.run_id, global_devices,
-                               local_devices.size(), config_.collective_op_kind,
-                               config_.op_id);
-  if (VLOG_IS_ON(2)) {
-    std::vector<std::string> local_participants;
-    local_participants.reserve(local_devices.size());
-    for (const auto& entry : local_devices) {
-      local_participants.push_back(absl::StrFormat(
-          "global=%d/local=%d", entry.first.value(), entry.second));
-    }
-    VLOG(2) << "Rendezvous key: " << rendezvous_key.ToString()
-            << ", global participating replicas: "
-            << absl::StrJoin(global_participating_replicas, ", ")
-            << ", global participating devices: "
-            << GlobalDeviceIdsToString(global_devices)
-            << ", local participants: "
-            << absl::StrJoin(local_participants, ",");
-  }
-  AllReduceParticipantData participant(rendezvous_key, local_device_ordinal,
-                                       params.stream);
+  RendezvousKey rendezvous_key(params.run_id, std::move(participants),
+                               local_participants.size(),
+                               config_.collective_op_kind, config_.op_id);
+  TF_ASSIGN_OR_RETURN(
+      LockedNcclClique locked_clique,
+      AcquireNcclClique(rendezvous_key, device_ordinal, params.stream,
+                        local_participants, params.nccl_unique_id_callback));
+  ncclComm_t comm =
+      locked_clique.clique->GetCommForDeviceOrdinal(device_ordinal);
+
+  VLOG(3) << "Performing all-reduce from device ordinal: " << device_ordinal;
+  ncclRedOp_t reduction_kind = ToNcclReduction(config_.reduction_kind);
+
+  se::gpu::ScopedActivateExecutorContext scoped_context(executor);
+  cudaStream_t* cu_stream = reinterpret_cast<cudaStream_t*>(
+      params.stream->implementation()->GpuStreamMemberHack());
+  VLOG(3) << "Using stream pointer: " << cu_stream
+          << " on device: " << device_ordinal;
+  XLA_CUDA_RETURN_IF_ERROR(ncclGroupStart());
  for (size_t i = 0; i < buffers_.size(); ++i) {
-    const NcclAllReduceThunk::Buffer& buffer = buffers_[i];
-    AllReduceParticipantData::Buffer pbuffer;
-    pbuffer.element_count = buffer.element_count;
-    pbuffer.source_data =
-        params.buffer_allocations->GetDeviceAddress(buffer.source_buffer);
-    pbuffer.destination_data =
-        params.buffer_allocations->GetDeviceAddress(buffer.destination_buffer);
-    pbuffer.primitive_type = config_.operand_element_type[i];
-    participant.buffers.push_back(pbuffer);
+    const Buffer& buffer = buffers_[i];
+    const void* send_buffer =
+        params.buffer_allocations->GetDeviceAddress(buffer.source_buffer)
+            .opaque();
+    void* recv_buffer =
+        params.buffer_allocations->GetDeviceAddress(buffer.destination_buffer)
+            .opaque();
+    TF_ASSIGN_OR_RETURN(ncclDataType_t datatype,
+                        ToNcclDataType(config_.operand_element_type[i]));
+    VLOG(3) << absl::StreamFormat(
+        "Calling ncclAllReduce(send_buffer=%p, recv_buffer=%p, count=%d, "
+        "comm=%p, stream=%p)",
+        send_buffer, recv_buffer, buffer.element_count,
+        static_cast<const void*>(comm), cu_stream);
+    XLA_CUDA_RETURN_IF_ERROR(ncclAllReduce(send_buffer, recv_buffer,
+                                           /*count=*/buffer.element_count,
+                                           datatype,
+                                           /*op=*/reduction_kind, comm,
+                                           /*stream=*/*cu_stream));
  }
-  participant.local_devices = std::move(local_devices);
-  participant.nccl_unique_id_callback = params.nccl_unique_id_callback;
-  participant.reduction_kind = config_.reduction_kind;
-
-  auto rendezvous_factory = [](const RendezvousKey& k) {
-    return absl::make_unique<RendezvousNcclAllReduce>(k);
-  };
-
-  TF_ASSIGN_OR_RETURN(std::shared_ptr<NcclClique> clique,
-                      RendezvousNcclAllReduce::SubmitParticipant(
-                          [&] {
-                            return GlobalRendezvousMap().GetOrCreateIfAbsent(
-                                rendezvous_key, rendezvous_factory);
-                          },
-                          participant));
+  XLA_CUDA_RETURN_IF_ERROR(ncclGroupEnd());
+
+  VLOG(3) << "Done performing all-reduce for ordinal: " << device_ordinal;
 
  // Keep the clique we used alive for as long as this Thunk lives. Creating
  // new NCCL cliques is expensive, and this is how we avoid thrashing them.
  {
    tensorflow::mutex_lock lock(config_.aux_data->mu);
-    config_.aux_data->cliques.insert(std::move(clique));
+    config_.aux_data->cliques.insert(std::move(locked_clique.clique));
  }
  return Status::OK();
 }
@@ -16,17 +16,12 @@ limitations under the License.
 #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_NCCL_ALL_REDUCE_THUNK_H_
 #define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_NCCL_ALL_REDUCE_THUNK_H_
 
-#include "absl/container/flat_hash_set.h"
 #include "tensorflow/compiler/xla/service/buffer_assignment.h"
 #include "tensorflow/compiler/xla/service/collective_ops_utils.h"
 #include "tensorflow/compiler/xla/service/gpu/buffer_allocations.h"
-#include "tensorflow/compiler/xla/service/gpu/gpu_executable_run_options.h"
-#include "tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.h"
 #include "tensorflow/compiler/xla/service/gpu/thunk.h"
 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
 #include "tensorflow/compiler/xla/xla_data.pb.h"
-#include "tensorflow/core/platform/mutex.h"
-#include "tensorflow/core/platform/stream_executor_no_cuda.h"
 #include "tensorflow/core/platform/types.h"
 
 namespace xla {
@@ -67,14 +62,6 @@ class NcclAllReduceThunk : public Thunk {
  // error.
  static bool NcclIsEnabled();
 
-  // Gets the set of devices that have a NCCL channel open. This is primarily
-  // for testing.
-  //
-  // (Indeed, because the NCCL channels are a global variable, in the real
-  // world, the value returned here is stale as soon as you read it, so it's not
-  // clear how you *could* use it for anything other than tests.)
-  static absl::flat_hash_set<GlobalDeviceId> DevicesWithOpenNcclChannels();
-
  struct Buffer {
    int64 element_count;
    BufferAllocation::Slice source_buffer;
tensorflow/compiler/xla/service/gpu/nccl_test_utils.cc (new file, 35 lines)
@@ -0,0 +1,35 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/service/gpu/nccl_test_utils.h"
+
+#include <memory>
+
+#include "tensorflow/compiler/xla/service/gpu/nccl_utils.h"
+
+namespace xla {
+namespace gpu {
+
+absl::flat_hash_set<GlobalDeviceId> DevicesWithOpenNcclChannels() {
+  absl::flat_hash_set<GlobalDeviceId> devices;
+  NcclCliqueCache().ForEach(
+      [&](const NcclCliqueKey& k, const std::shared_ptr<NcclClique>&) {
+        devices.insert(k.devices().begin(), k.devices().end());
+      });
+  return devices;
+}
+
+}  // namespace gpu
+}  // namespace xla
tensorflow/compiler/xla/service/gpu/nccl_test_utils.h (new file, 36 lines)
@@ -0,0 +1,36 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_NCCL_TEST_UTILS_H_
+#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_NCCL_TEST_UTILS_H_
+
+#include "absl/container/flat_hash_set.h"
+#include "tensorflow/compiler/xla/service/global_device_id.h"
+
+namespace xla {
+namespace gpu {
+
+// Gets the set of devices that have a NCCL channel open. This is primarily
+// for testing.
+//
+// (Indeed, because the NCCL channels are a global variable, in the real
+// world, the value returned here is stale as soon as you read it, so it's not
+// clear how you *could* use it for anything other than tests.)
+absl::flat_hash_set<GlobalDeviceId> DevicesWithOpenNcclChannels();
+
+}  // namespace gpu
+}  // namespace xla
+
+#endif  // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_NCCL_TEST_UTILS_H_
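The header's comment positions DevicesWithOpenNcclChannels as a test-only observability hook. A typical use in a test might look like the following (hypothetical helper, not from this commit):

#include "absl/container/flat_hash_set.h"
#include "tensorflow/compiler/xla/service/gpu/nccl_test_utils.h"

// After running a collective, assert that a channel was opened on the
// expected device. The returned set is a point-in-time snapshot.
bool HasOpenChannelOn(xla::GlobalDeviceId id) {
  absl::flat_hash_set<xla::GlobalDeviceId> open =
      xla::gpu::DevicesWithOpenNcclChannels();
  return open.contains(id);
}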
311
tensorflow/compiler/xla/service/gpu/nccl_utils.cc
Normal file
311
tensorflow/compiler/xla/service/gpu/nccl_utils.cc
Normal file
@ -0,0 +1,311 @@
|
|||||||
|
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#include "tensorflow/compiler/xla/service/gpu/nccl_utils.h"
|
||||||
|
|
||||||
|
#include <memory>
|
||||||
|
|
||||||
|
#include "absl/container/flat_hash_map.h"
|
||||||
|
#include "absl/strings/str_format.h"
|
||||||
|
#include "absl/synchronization/mutex.h"
|
||||||
|
#include "tensorflow/compiler/xla/refcounting_hash_map.h"
|
||||||
|
#include "tensorflow/compiler/xla/service/collective_ops_utils.h"
|
||||||
|
#include "tensorflow/compiler/xla/service/global_device_id.h"
|
||||||
|
#include "tensorflow/compiler/xla/service/gpu/gpu_executable_run_options.h"
|
||||||
|
#include "tensorflow/compiler/xla/status_macros.h"
|
||||||
|
#include "tensorflow/compiler/xla/statusor.h"
|
||||||
|
#include "tensorflow/core/platform/errors.h"
|
||||||
|
|
||||||
|
namespace xla {
|
||||||
|
namespace gpu {
|
||||||
|
|
||||||
|
ncclRedOp_t ToNcclReduction(ReductionKind kind) {
|
||||||
|
switch (kind) {
|
||||||
|
case ReductionKind::SUM:
|
||||||
|
return ncclSum;
|
||||||
|
case ReductionKind::PRODUCT:
|
||||||
|
return ncclProd;
|
||||||
|
case ReductionKind::MIN:
|
||||||
|
return ncclMin;
|
||||||
|
case ReductionKind::MAX:
|
||||||
|
return ncclMax;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
StatusOr<ncclDataType_t> ToNcclDataType(PrimitiveType element_type) {
|
||||||
|
switch (element_type) {
|
||||||
|
case S8:
|
||||||
|
return ncclInt8;
|
||||||
|
case PRED:
|
||||||
|
case U8:
|
||||||
|
return ncclUint8;
|
||||||
|
case S32:
|
||||||
|
return ncclInt32;
|
||||||
|
case U32:
|
||||||
|
return ncclUint32;
|
||||||
|
case S64:
|
||||||
|
return ncclInt64;
|
||||||
|
case U64:
|
||||||
|
return ncclUint64;
|
||||||
|
case F16:
|
||||||
|
return ncclFloat16;
|
||||||
|
case F32:
|
||||||
|
return ncclFloat32;
|
||||||
|
case F64:
|
||||||
|
return ncclFloat64;
|
||||||
|
default:
|
||||||
|
return tensorflow::errors::InvalidArgument(absl::StrFormat(
|
||||||
|
"Unsupported data type: %s", PrimitiveType_Name(element_type)));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
bool IsGlobalNcclConfig() {
|
||||||
|
static bool global_nccl_config = std::getenv("NCCL_COMM_ID") != nullptr;
|
||||||
|
return global_nccl_config;
|
||||||
|
}
|
||||||
|
|
||||||
|
Status ToStatus(ncclResult_t s, const char* file, int64 line,
|
||||||
|
const char* expr) {
|
||||||
|
if (s == ncclSuccess) {
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
return tensorflow::errors::Internal(
|
||||||
|
absl::StrFormat("%s:%d: NCCL operation %s failed: %s", file, line, expr,
|
||||||
|
ncclGetErrorString(s)));
|
||||||
|
}
|
||||||
|
|
||||||
|
Status ToStatus(cudaError_t s, const char* file, int64 line, const char* expr) {
|
||||||
|
if (s == cudaSuccess) {
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
return tensorflow::errors::Internal(
|
||||||
|
absl::StrFormat("%s:%d: CUDA operation %s failed: %s", file, line, expr,
|
||||||
|
cudaGetErrorString(s)));
|
||||||
|
}
|
||||||
|
|
||||||
|
NcclClique::NcclClique(
|
||||||
|
absl::flat_hash_map<int, NcclComm> comms_by_device_ordinal)
|
||||||
|
: comms_by_device_ordinal_(std::move(comms_by_device_ordinal)) {}
|
||||||
|
|
||||||
|
ncclComm_t NcclClique::GetCommForDeviceOrdinal(int device_ordinal) const {
|
||||||
|
return comms_by_device_ordinal_.at(device_ordinal).get();
|
||||||
|
}
|
||||||
|
|
||||||
|
RefcountingHashMap<NcclCliqueKey, NcclClique>& NcclCliqueCache() {
|
||||||
|
// Global cache of NCCL cliques. An entry in this map is kept alive as long
|
||||||
|
// as there's a reference to it somewhere. A Thunk holds a reference to each
|
||||||
|
// Clique it's ever used.
|
||||||
|
//
|
||||||
|
// A consequence of the fact that this is process-global is that we'll only
|
||||||
|
// ever have one clique alive for a given set of GPUs. This means that a
|
||||||
|
// process will never do two collective operations concurrently on the same
|
||||||
|
// set of GPUs.
|
||||||
|
static auto& cache = *new RefcountingHashMap<NcclCliqueKey, NcclClique>();
|
||||||
|
return cache;
|
||||||
|
}
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
|
||||||
|
void DestroyNcclComm(ncclComm_t comm) {
|
||||||
|
VLOG(3) << absl::StreamFormat("Destroying comm %p", comm);
|
||||||
|
XLA_CUDA_WARN_IF_ERROR(ncclCommDestroy(comm));
|
||||||
|
}
|
||||||
|
|
||||||
|
Status ToNcclUniqueId(const std::string& str_id, ncclUniqueId* nccl_id) {
|
||||||
|
if (str_id.size() != NCCL_UNIQUE_ID_BYTES) {
|
||||||
|
return InvalidArgument(
|
||||||
|
"ncclUniqueId string must have %d bytes, got %d bytes", str_id.size(),
|
||||||
|
NCCL_UNIQUE_ID_BYTES);
|
||||||
|
}
|
||||||
|
// NcclUniqueId is internally just a char[].
|
||||||
|
static_assert(sizeof(ncclUniqueId) == NCCL_UNIQUE_ID_BYTES,
|
||||||
|
"NCCL_UNIQUE_ID_BYTES");
|
||||||
|
std::memcpy(static_cast<void*>(nccl_id), str_id.data(), NCCL_UNIQUE_ID_BYTES);
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
|
std::string LocalParticipantsToString(
|
||||||
|
const std::vector<LocalParticipant>& local_participants) {
|
||||||
|
std::vector<std::string> parts;
|
||||||
|
for (const LocalParticipant& local_participant : local_participants) {
|
||||||
|
parts.push_back(absl::StrFormat("%d/rank=%d",
|
||||||
|
local_participant.device_ordinal,
|
||||||
|
local_participant.rank));
|
||||||
|
}
|
||||||
|
return absl::StrJoin(parts, ",");
|
||||||
|
}
|
||||||
|
|
||||||
|
StatusOr<std::unique_ptr<NcclClique>> CreateNcclClique(
    const NcclCliqueKey& key,
    const std::vector<LocalParticipant>& local_participants,
    const NcclUniqueIdCallback* callback) {
  int num_participants = key.devices().size();
  ncclUniqueId unique_id;
  if (callback) {  // Multi-host collective.
    TF_ASSIGN_OR_RETURN(std::string id_string, (*callback)(key));
    TF_RETURN_IF_ERROR(ToNcclUniqueId(id_string, &unique_id));
  } else {
    TF_RET_CHECK((num_participants == local_participants.size()) ||
                 IsGlobalNcclConfig())
        << "If non-local devices are taking part in a collective API on GPU, "
           "the nccl_unique_id_callback must be provided by the client.";
    XLA_CUDA_RETURN_IF_ERROR(ncclGetUniqueId(&unique_id));
  }

  VLOG(3) << "Initializing nccl comms for local participants: "
          << LocalParticipantsToString(local_participants);

  // Restore the CUDA device after running this. XLA shouldn't care, but maybe
  // another consumer does.
  int initial_cuda_device;
  XLA_CUDA_RETURN_IF_ERROR(cudaGetDevice(&initial_cuda_device));
  auto cuda_device_restorer = MakeCleanup(
      [&] { XLA_CUDA_WARN_IF_ERROR(cudaSetDevice(initial_cuda_device)); });

  // When using ncclGroupStart/End it seems that the ncclComm_t's are not
  // populated until the End() call.
  std::vector<ncclComm_t> raw_comms(local_participants.size(), nullptr);
  XLA_CUDA_RETURN_IF_ERROR(ncclGroupStart());
  Status status = [&] {
    for (int i = 0; i < local_participants.size(); ++i) {
      XLA_CUDA_RETURN_IF_ERROR(
          cudaSetDevice(local_participants[i].device_ordinal));
      XLA_CUDA_RETURN_IF_ERROR(ncclCommInitRank(&raw_comms[i], num_participants,
                                                unique_id,
                                                local_participants[i].rank));
    }
    return Status::OK();
  }();
  // Always call ncclGroupEnd(), even if initialization failed inside the group.
  status.Update(XLA_CUDA_STATUS(ncclGroupEnd()));

  // Always copy the raw comms into the RAII type, so they are cleaned up
  // properly even on error.
  absl::flat_hash_map<int, NcclComm> comms_by_device_ordinal(raw_comms.size());
  for (int i = 0; i < raw_comms.size(); ++i) {
    int device_ordinal = local_participants[i].device_ordinal;
    VLOG(3) << absl::StreamFormat("Device ordinal %d assigned ncclComm %p",
                                  device_ordinal, raw_comms[i]);
    CHECK(raw_comms[i] != nullptr || !status.ok());
    comms_by_device_ordinal.emplace(device_ordinal,
                                    NcclComm(raw_comms[i], &DestroyNcclComm));
  }

  // Now we can check whether there was an error creating the communicators.
  TF_RETURN_IF_ERROR(status);
  return std::make_unique<NcclClique>(std::move(comms_by_device_ordinal));
}

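// For readers unfamiliar with NCCL: CreateNcclClique above wraps the standard
// grouped-initialization idiom. A minimal sketch of that bare idiom (editor's
// illustration, not part of this change; error handling elided):
//
//   ncclUniqueId id;          // must be shared by all ranks, e.g. via the
//   ncclGetUniqueId(&id);     // unique-id callback on multi-host setups
//   ncclGroupStart();
//   for (int i = 0; i < nranks; ++i) {
//     cudaSetDevice(device_ordinal[i]);
//     ncclCommInitRank(&comms[i], nranks, id, rank[i]);
//   }
//   ncclGroupEnd();           // the comms only become valid here
//
// (`nranks`, `device_ordinal`, `rank`, and `comms` are hypothetical names.)
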
struct NcclCliqueParticipantData : public ParticipantData {
  using ParticipantData::ParticipantData;
  std::string ToString() const override { return ""; }
};

class NcclCliqueRendezvous
    : public Rendezvous<NcclCliqueParticipantData, LockedNcclClique> {
 public:
  NcclCliqueRendezvous(const RendezvousKey& rendezvous_key,
                       const std::vector<LocalParticipant>& local_participants,
                       const NcclUniqueIdCallback* callback)
      : Rendezvous(rendezvous_key),
        key_(rendezvous_key.global_devices),
        local_participants_(local_participants),
        callback_(callback) {}

  StatusOr<ParticipantImplOutput> RunCollectiveOp(
      const NcclCliqueParticipantData&) override {
    tensorflow::mutex_lock lock(mu_);
    bool primary = !initialized_;
    if (primary) {
      maybe_clique_ = NcclCliqueCache().GetOrTryCreateIfAbsent(
          key_, [&](const NcclCliqueKey& key) {
            return CreateNcclClique(key, local_participants_, callback_);
          });
      initialized_ = true;
    }
    TF_ASSIGN_OR_RETURN(std::shared_ptr<NcclClique> clique, maybe_clique_);
    if (primary) {
      lock_ = std::make_shared<absl::MutexLock>(clique->mu());
    }
    return ParticipantImplOutput{primary, LockedNcclClique{clique, lock_}};
  }

 private:
  NcclCliqueKey key_;
  const std::vector<LocalParticipant>& local_participants_;
  const NcclUniqueIdCallback* callback_;

  StatusOr<std::shared_ptr<NcclClique>> maybe_clique_;
  std::shared_ptr<absl::MutexLock> lock_;
};

}  // namespace

StatusOr<std::vector<LocalParticipant>> GetLocalParticipants(
    const std::vector<GlobalDeviceId>& participants,
    const std::vector<GlobalDeviceId>* local_devices) {
  std::vector<LocalParticipant> local_participants;
  if (local_devices) {
    absl::flat_hash_map<GlobalDeviceId, int> device_ranks(participants.size());
    for (int rank = 0; rank < participants.size(); ++rank) {
      auto result = device_ranks.emplace(participants[rank], rank);
      TF_RET_CHECK(result.second) << "Duplicate device found";
    }

    local_participants.reserve(local_devices->size());
    for (int device_ordinal = 0; device_ordinal < local_devices->size();
         ++device_ordinal) {
      auto it = device_ranks.find((*local_devices)[device_ordinal]);
      if (it != device_ranks.end()) {
        local_participants.push_back({device_ordinal, /*rank=*/it->second});
      }
    }
  } else {  // Single host, so use an identity mapping (device ordinal == id).
    local_participants.reserve(participants.size());
    for (int rank = 0; rank < participants.size(); ++rank) {
      int device_ordinal = participants[rank].value();
      local_participants.push_back({device_ordinal, rank});
    }
  }

  return local_participants;
}

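// Worked example (editor's illustration; the device IDs are made up): with
// participants = {GlobalDeviceId(3), GlobalDeviceId(1)} and
// *local_devices = {GlobalDeviceId(1), GlobalDeviceId(3)}, device 1 is local
// ordinal 0 and holds rank 1, while device 3 is local ordinal 1 and holds
// rank 0, so the result is
// {{/*device_ordinal=*/0, /*rank=*/1}, {/*device_ordinal=*/1, /*rank=*/0}}.
// Note that rank order follows `participants`, which is why the caller must
// not sort the device IDs.
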
StatusOr<LockedNcclClique> AcquireNcclClique(
    const RendezvousKey& rendezvous_key, int local_device_ordinal,
    se::Stream* stream,
    const std::vector<LocalParticipant>& local_participants,
    const NcclUniqueIdCallback* callback) {
  VLOG(2) << "Rendezvous key: " << rendezvous_key.ToString()
          << ", local participants: "
          << LocalParticipantsToString(local_participants);

  static auto& rendezvous_map =
      *new RefcountingHashMap<RendezvousKey, NcclCliqueRendezvous>();

  NcclCliqueParticipantData participant(rendezvous_key, local_device_ordinal,
                                        stream);
  return NcclCliqueRendezvous::SubmitParticipant(
      /*rendezvous_getter=*/
      [&] {
        return rendezvous_map.GetOrCreateIfAbsent(
            rendezvous_key, [&](const RendezvousKey& rendezvous_key) {
              return std::make_unique<NcclCliqueRendezvous>(
                  rendezvous_key, local_participants, callback);
            });
      },
      participant);
}

}  // namespace gpu
}  // namespace xla
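Taken together, a thunk that needs a clique would use the pieces above roughly as follows. This is a minimal editor's sketch, not code from this change: the surrounding names (`participants`, `rendezvous_key`, `device_ordinal`, `stream`, `id_callback`) are assumed for illustration, while the called functions and types are the ones declared in nccl_utils.h below.

    // Editor's sketch of a caller, assuming the surrounding names exist.
    TF_ASSIGN_OR_RETURN(
        std::vector<LocalParticipant> local_participants,
        GetLocalParticipants(participants, /*local_devices=*/nullptr));
    TF_ASSIGN_OR_RETURN(
        LockedNcclClique locked_clique,
        AcquireNcclClique(rendezvous_key, device_ordinal, stream,
                          local_participants, id_callback));
    // The clique lock is held for as long as `locked_clique` lives; use the
    // communicator for this device only while it is held.
    ncclComm_t comm =
        locked_clique.clique->GetCommForDeviceOrdinal(device_ordinal);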
tensorflow/compiler/xla/service/gpu/nccl_utils.h (new file, 131 lines)
@@ -0,0 +1,131 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_NCCL_UTILS_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_NCCL_UTILS_H_

#include <memory>

#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "absl/synchronization/mutex.h"
#if GOOGLE_CUDA
#include "third_party/nccl/nccl.h"
#elif TENSORFLOW_USE_ROCM
#include "rocm/include/rccl/rccl.h"
#endif
#include "tensorflow/compiler/xla/refcounting_hash_map.h"
#include "tensorflow/compiler/xla/service/collective_ops_utils.h"
#include "tensorflow/compiler/xla/service/gpu/gpu_executable_run_options.h"
#include "tensorflow/compiler/xla/status.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"

#if TENSORFLOW_USE_ROCM
// Local hipify of CUDA symbols.
#define cudaError_t hipError_t
#define cudaStream_t hipStream_t
#define cudaGetErrorString hipGetErrorString
#define cudaGetDevice hipGetDevice
#define cudaSetDevice hipSetDevice
#define cudaSuccess hipSuccess
#endif

namespace xla {
namespace gpu {

ncclRedOp_t ToNcclReduction(ReductionKind kind);
StatusOr<ncclDataType_t> ToNcclDataType(PrimitiveType element_type);

bool IsGlobalNcclConfig();

Status ToStatus(ncclResult_t s, const char* file, int64 line, const char* expr);
Status ToStatus(cudaError_t s, const char* file, int64 line, const char* expr);

// Macros to return or warn on CUDA/NCCL errors. (The same macro works for both
// NCCL and CUDA errors.)
//
// It's tempting to say these macros belong in an XLA header somewhere, but in
// practice we don't do much direct-to-CUDA-API stuff outside of this file.
#define XLA_CUDA_STATUS(expr) \
  ::xla::gpu::ToStatus(expr, __FILE__, __LINE__, #expr)

#define XLA_CUDA_RETURN_IF_ERROR(expr) \
  do {                                 \
    Status s = XLA_CUDA_STATUS(expr);  \
    if (!s.ok()) {                     \
      return s;                        \
    }                                  \
  } while (0)

#define XLA_CUDA_WARN_IF_ERROR(expr)  \
  do {                                \
    Status s = XLA_CUDA_STATUS(expr); \
    if (!s.ok()) {                    \
      LOG(ERROR) << s.ToString();     \
    }                                 \
  } while (0)

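// Usage sketch (editor's illustration, not part of this header): inside any
// function returning Status,
//
//   XLA_CUDA_RETURN_IF_ERROR(cudaSetDevice(ordinal));  // propagate failures
//   XLA_CUDA_WARN_IF_ERROR(ncclCommDestroy(comm));     // log but continue
//
// works for both CUDA and NCCL calls because ToStatus is overloaded on the
// result type. (`ordinal` and `comm` are hypothetical names.)
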
// RAII type for NCCL communicators.
using NcclComm = std::unique_ptr<ncclComm, void (*)(ncclComm_t)>;

// Owns a clique of NCCL comms which can be used for collective operations
// among a particular set of GPUs.
//
// Note that if you want to do a collective operation among a subset of these
// GPUs, you'll need a different clique.
class NcclClique {
 public:
  explicit NcclClique(
      absl::flat_hash_map<int, NcclComm> comms_by_device_ordinal);

  ncclComm_t GetCommForDeviceOrdinal(int device_ordinal) const;
  absl::Mutex* mu() { return &mu_; }

 private:
  absl::flat_hash_map<int, NcclComm> comms_by_device_ordinal_;
  absl::Mutex mu_;
};

RefcountingHashMap<NcclCliqueKey, NcclClique>& NcclCliqueCache();

struct LocalParticipant {
  int device_ordinal;
  int rank;
};

StatusOr<std::vector<LocalParticipant>> GetLocalParticipants(
    const std::vector<GlobalDeviceId>& participants,
    const std::vector<GlobalDeviceId>* local_devices);  // may be null

struct LockedNcclClique {
  std::shared_ptr<NcclClique> clique;
  // Must come after clique, so it is destroyed first.
  // This lock prevents other threads from using this clique. All of the
  // threads involved should hold onto the lock until they have finished with
  // their communicator.
  std::shared_ptr<absl::MutexLock> lock;
};

// Acquires a locked NCCL clique for use in NCCL collective operations.
StatusOr<LockedNcclClique> AcquireNcclClique(
    const RendezvousKey& rendezvous_key, int local_device_ordinal,
    se::Stream* stream,
    const std::vector<LocalParticipant>& local_participants,
    const NcclUniqueIdCallback* callback);  // may be null

}  // namespace gpu
}  // namespace xla

#endif  // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_NCCL_UTILS_H_
@@ -2026,7 +2026,7 @@ xla_test(
         "//tensorflow/compiler/xla/client:local_client",
         "//tensorflow/compiler/xla/service:hlo",
         "//tensorflow/compiler/xla/service:hlo_runner",
-        "//tensorflow/compiler/xla/service/gpu:nccl_all_reduce_thunk",
+        "//tensorflow/compiler/xla/service/gpu:nccl_test_utils",
         "//tensorflow/core:lib",
         "//tensorflow/core:lib_internal",
         "//tensorflow/core:test",
@@ -17,7 +17,7 @@ limitations under the License.
 #include "absl/types/span.h"
 #include "tensorflow/compiler/xla/literal.h"
 #include "tensorflow/compiler/xla/primitive_util.h"
-#include "tensorflow/compiler/xla/service/gpu/nccl_all_reduce_thunk.h"
+#include "tensorflow/compiler/xla/service/gpu/nccl_test_utils.h"
 #include "tensorflow/compiler/xla/shape_util.h"
 #include "tensorflow/compiler/xla/test.h"
 #include "tensorflow/compiler/xla/test_helpers.h"
@@ -160,7 +160,7 @@ DeviceAssignment MakeDeviceAssn(std::vector<int64> devices) {

 // Shorter alias for this function.
 absl::flat_hash_set<GlobalDeviceId> OpenNcclChannels() {
-  return gpu::NcclAllReduceThunk::DevicesWithOpenNcclChannels();
+  return gpu::DevicesWithOpenNcclChannels();
 }

 template <typename T>