Split out common NCCL utils.
Changes made while refactoring:
- The rendezvous is now encapsulated in the code that acquires the NCCL clique.
- Fixed an issue where `ncclComm`s might not be cleaned up correctly.
- Don't sort device IDs; order matters for several of the NCCL collective ops.
- Removed support for cleanup in `Rendezvous`, as it is no longer needed.

PiperOrigin-RevId: 345156260
Change-Id: I9871093bad4c403876ce5584180a42fedf8d9900
parent 90881b041f, commit ffe615e998
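For orientation, here is a minimal sketch (not part of the diff) of how a thunk obtains its communicator after this change. AcquireNcclClique, LockedNcclClique, LocalParticipant, GetLocalParticipants, and GetCommForDeviceOrdinal are the names introduced in the new nccl_utils.h further down; the wrapping function and its exact arguments are illustrative only.

// Hypothetical caller, sketched from the new nccl_utils.h API.
Status RunCollectiveSketch(const RendezvousKey& rendezvous_key, int device_ordinal,
                           se::Stream* stream,
                           const std::vector<GlobalDeviceId>& participants,
                           const std::vector<GlobalDeviceId>* local_devices,  // may be null
                           const NcclUniqueIdCallback* callback) {            // may be null
  // Map the global device IDs onto local (device_ordinal, rank) pairs.
  TF_ASSIGN_OR_RETURN(std::vector<LocalParticipant> local_participants,
                      GetLocalParticipants(participants, local_devices));
  // The rendezvous now happens inside AcquireNcclClique; the returned
  // LockedNcclClique keeps the clique locked (and alive) while it is in scope.
  TF_ASSIGN_OR_RETURN(LockedNcclClique locked_clique,
                      AcquireNcclClique(rendezvous_key, device_ordinal, stream,
                                        local_participants, callback));
  ncclComm_t comm = locked_clique.clique->GetCommForDeviceOrdinal(device_ordinal);
  // ... enqueue ncclAllReduce (or another collective) on `comm` ...
  return Status::OK();
}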
@@ -254,6 +254,8 @@ class Rendezvous {
return false;
}

virtual void CleanupImpl(O handle, bool is_primary) {}

tensorflow::mutex mu_;

bool initialized_ TF_GUARDED_BY(mu_) = false;
@@ -294,14 +296,34 @@ class Rendezvous {
participant.device_ordinal, participant.stream, key_.ToString());
});

TF_ASSIGN_OR_RETURN(ParticipantImplOutput p, RunCollectiveOp(participant));
StatusOr<ParticipantImplOutput> p_or = RunCollectiveOp(participant);

done_.DecrementCount();
if (!p_or.ok()) {
return p_or.status();
}
ParticipantImplOutput p = p_or.ValueOrDie();

// The primary owns the lock on the NCCL clique. Hold it until all threads
// are done. (We'll release it when we return from this function.)
if (p.is_primary) {
WaitAndLogIfStuck(&done_, [&] {
return absl::StrFormat(
"primary participant waiting for all other participants to "
"complete all-reduce %s",
key_.ToString());
});
}

CleanupImpl(p.custom_output, p.is_primary);

return std::make_pair(p.custom_output, returned_blocking_counter_);
}

const RendezvousKey key_;

tensorflow::BlockingCounter all_participants_present_{
key_.num_local_participants};
tensorflow::BlockingCounter done_{key_.num_local_participants};

// tensorflow::BlockingCounter returned by SubmitParticipant.
std::shared_ptr<tensorflow::BlockingCounter> returned_blocking_counter_{
@@ -455,17 +455,17 @@ tf_cuda_library(
":thunk",
":gpu_executable_run_options",
"@com_google_absl//absl/base:core_headers",
"@com_google_absl//absl/strings:str_format",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:util",
"//tensorflow/compiler/xla:xla_data_proto_cc",
"//tensorflow/compiler/xla/service:buffer_assignment",
"//tensorflow/compiler/xla/service:collective_ops_utils",
"//tensorflow/compiler/xla/service:hlo",
"//tensorflow/compiler/xla/service:hlo_casting_utils",
"//tensorflow/compiler/xla/service:pattern_matcher",
"//tensorflow/compiler/xla:refcounting_hash_map",
"//tensorflow/compiler/xla/service:collective_ops_utils",
"@com_google_absl//absl/container:flat_hash_set",
"@com_google_absl//absl/synchronization",
"//tensorflow/compiler/xla:util",
"//tensorflow/compiler/xla/service:buffer_assignment",
"//tensorflow/compiler/xla/service:hlo",
"//tensorflow/core:lib",
"//tensorflow/core/platform:stream_executor_no_cuda",
"//tensorflow/compiler/xla:xla_data_proto_cc",
] + if_cuda([
"//tensorflow/stream_executor/cuda:cuda_activation",
"//tensorflow/stream_executor/cuda:cuda_gpu_executor",
@@ -474,53 +474,10 @@ tf_cuda_library(
"//tensorflow/stream_executor/rocm:rocm_gpu_executor",
]) + if_nccl([
":virtual_nccl",
":virtual_nccl_utils",
":virtual_rccl",
]),
)

# First level of nested select. NCCL requires both if_cuda and if_nccl.
filegroup(
name = "nccl_utils_srcs",
srcs = if_nccl(["nccl_utils.cc"]),
)

# First level of nested select. NCCL requires both if_cuda and if_nccl.
filegroup(
name = "nccl_utils_hdrs",
srcs = if_nccl(["nccl_utils.h"]),
)

tf_cuda_library(
name = "nccl_utils",
srcs = if_cuda_or_rocm([":nccl_utils_srcs"]),
hdrs = if_cuda_or_rocm([":nccl_utils_hdrs"]),
deps = if_cuda_or_rocm([
":gpu_executable_run_options",
"@com_google_absl//absl/container:flat_hash_set",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/strings:str_format",
"@com_google_absl//absl/synchronization",
"//tensorflow/compiler/xla:refcounting_hash_map",
"//tensorflow/compiler/xla:status",
"//tensorflow/compiler/xla:status_macros",
"//tensorflow/compiler/xla:statusor",
"//tensorflow/compiler/xla:xla_data_proto_cc",
"//tensorflow/compiler/xla/service:collective_ops_utils",
"//tensorflow/compiler/xla/service:global_device_id",
"//tensorflow/core:lib",
"//tensorflow/stream_executor/lib",
]) + if_nccl([
":virtual_nccl",
":virtual_rccl",
]),
)

alias(
name = "virtual_nccl_utils",
actual = if_cuda_or_rocm(":nccl_utils", ":empty"),
)

cc_library(
name = "gpu_debug_info_manager",
srcs = [
@@ -44,6 +44,11 @@ Status NcclAllReduceThunk::ExecuteOnStream(const ExecuteParams& params) {
"compiler, which is necessary to build the NCCL source library.");
}

/*static*/ absl::flat_hash_set<GlobalDeviceId>
NcclAllReduceThunk::DevicesWithOpenNcclChannels() {
return {};
}

NcclAllReduceThunk::NcclAllReduceThunk(
ThunkInfo thunk_info, NcclAllReduceConfig &&config,
std::vector<NcclAllReduceThunk::Buffer> buffers)

@@ -21,7 +21,12 @@ namespace xla {
namespace gpu {

NcclCliqueKey::NcclCliqueKey(std::vector<GlobalDeviceId> devices)
: devices_(std::move(devices)) {}
: devices_(std::move(devices)) {
absl::c_sort(devices_);
CHECK(absl::c_adjacent_find(devices_) == devices_.end())
<< "Duplicate devices are not allowed: "
<< GlobalDeviceIdsToString(devices_);
}

std::string NcclCliqueKey::ToString() const {
return GlobalDeviceIdsToString(devices_);
@@ -22,21 +22,40 @@ limitations under the License.
#include <utility>
#include <vector>

#include "absl/algorithm/container.h"
#include "absl/base/thread_annotations.h"
#include "absl/container/flat_hash_set.h"
#include "absl/memory/memory.h"
#include "absl/strings/str_format.h"
#include "absl/strings/str_join.h"
#include "absl/types/optional.h"
#include "absl/types/span.h"
#if GOOGLE_CUDA
#include "third_party/nccl/nccl.h"
#elif TENSORFLOW_USE_ROCM
#include "rocm/include/rccl/rccl.h"
#endif
#include "tensorflow/compiler/xla/layout_util.h"
#include "tensorflow/compiler/xla/service/gpu/nccl_utils.h"
#include "tensorflow/compiler/xla/refcounting_hash_map.h"
#include "tensorflow/compiler/xla/service/collective_ops_utils.h"
#include "tensorflow/compiler/xla/service/gpu/gpu_executable_run_options.h"
#include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
#include "tensorflow/compiler/xla/service/hlo_instructions.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/lib/core/blocking_counter.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/stream_executor/gpu/gpu_activation.h"

#if TENSORFLOW_USE_ROCM
// Local hipify of cuda symbols
#define cudaError_t hipError_t
#define cudaStream_t hipStream_t
#define cudaGetErrorString hipGetErrorString
#define cudaGetDevice hipGetDevice
#define cudaSetDevice hipSetDevice
#define cudaSuccess hipSuccess
#endif

namespace xla {
namespace gpu {
@ -63,6 +82,432 @@ namespace gpu {
|
||||
return true; // Skylark selects this source file if NCCL is enabled.
|
||||
}
|
||||
|
||||
namespace {
|
||||
|
||||
bool IsGlobalNcclConfig() {
|
||||
static bool global_nccl_config = std::getenv("NCCL_COMM_ID") != nullptr;
|
||||
return global_nccl_config;
|
||||
}
|
||||
|
||||
// Functions to translate an ncclResult_t/cudaError_t to a Status object. Used
|
||||
// by the macros below.
|
||||
Status TranslateStatus(ncclResult_t s, const char* file, int64 line,
|
||||
const char* expr) {
|
||||
if (s == ncclSuccess) {
|
||||
return Status::OK();
|
||||
}
|
||||
return tensorflow::errors::Internal(
|
||||
absl::StrFormat("%s:%d: NCCL operation %s failed: %s", file, line, expr,
|
||||
ncclGetErrorString(s)));
|
||||
}
|
||||
|
||||
Status TranslateStatus(cudaError_t s, const char* file, int64 line,
|
||||
const char* expr) {
|
||||
if (s == cudaSuccess) {
|
||||
return Status::OK();
|
||||
}
|
||||
return tensorflow::errors::Internal(
|
||||
absl::StrFormat("%s:%d: CUDA operation %s failed: %s", file, line, expr,
|
||||
cudaGetErrorString(s)));
|
||||
}
|
||||
|
||||
// Macros to return or warn on CUDA/NCCL errors. (The same macro works for both
|
||||
// NCCL and CUDA errors.)
|
||||
//
|
||||
// It's tempting to say these macros belong in an XLA header somewhere, but in
|
||||
// practice we don't do much direct-to-CUDA-API stuff outside of this file.
|
||||
#define XLA_CUDA_RETURN_IF_ERROR(expr) \
|
||||
do { \
|
||||
Status s = ::xla::gpu::TranslateStatus(expr, __FILE__, __LINE__, #expr); \
|
||||
if (!s.ok()) { \
|
||||
return s; \
|
||||
} \
|
||||
} while (0)
|
||||
|
||||
#define XLA_CUDA_WARN_IF_ERROR(expr) \
|
||||
do { \
|
||||
Status s = ::xla::gpu::TranslateStatus(expr, __FILE__, __LINE__, #expr); \
|
||||
if (!s.ok()) { \
|
||||
LOG(ERROR) << s.ToString(); \
|
||||
} \
|
||||
} while (0)
|
||||
|
||||
// RAII class owning a ncclComm_t, ensuring it doesn't leak.
|
||||
class NcclComm {
|
||||
public:
|
||||
explicit NcclComm(ncclComm_t comm) : comm_(comm) {}
|
||||
|
||||
// Movable, but not copyable.
|
||||
NcclComm(NcclComm&& c) noexcept : comm_(c.comm_) { c.comm_.reset(); }
|
||||
NcclComm& operator=(NcclComm&& c) noexcept {
|
||||
comm_ = c.comm_;
|
||||
c.comm_.reset();
|
||||
return *this;
|
||||
}
|
||||
NcclComm(const NcclComm&) = delete;
|
||||
NcclComm& operator=(const NcclComm&) = delete;
|
||||
|
||||
~NcclComm() {
|
||||
if (comm_.has_value() && *comm_ != nullptr) {
|
||||
VLOG(3) << absl::StreamFormat("Destroying comm %p", *comm_);
|
||||
XLA_CUDA_WARN_IF_ERROR(ncclCommDestroy(*comm_));
|
||||
}
|
||||
}
|
||||
|
||||
ncclComm_t comm() { return *comm_; }
|
||||
|
||||
private:
|
||||
absl::optional<ncclComm_t> comm_;
|
||||
};
|
||||
|
||||
ncclRedOp_t ReductionKindToNccl(ReductionKind kind) {
|
||||
switch (kind) {
|
||||
case ReductionKind::SUM:
|
||||
return ncclSum;
|
||||
case ReductionKind::PRODUCT:
|
||||
return ncclProd;
|
||||
case ReductionKind::MIN:
|
||||
return ncclMin;
|
||||
case ReductionKind::MAX:
|
||||
return ncclMax;
|
||||
}
|
||||
}
|
||||
|
||||
absl::optional<ncclDataType_t> DatatypeToNccl(PrimitiveType element_type) {
|
||||
switch (element_type) {
|
||||
case S8:
|
||||
return ncclInt8;
|
||||
case PRED:
|
||||
case U8:
|
||||
return ncclUint8;
|
||||
case S32:
|
||||
return ncclInt32;
|
||||
case U32:
|
||||
return ncclUint32;
|
||||
case S64:
|
||||
return ncclInt64;
|
||||
case U64:
|
||||
return ncclUint64;
|
||||
case F16:
|
||||
return ncclFloat16;
|
||||
case F32:
|
||||
return ncclFloat32;
|
||||
case F64:
|
||||
return ncclFloat64;
|
||||
default:
|
||||
return absl::nullopt;
|
||||
}
|
||||
}
|
||||
|
||||
Status StringToNcclUniqueId(const std::string& str_id, ncclUniqueId* nccl_id) {
|
||||
if (str_id.size() != NCCL_UNIQUE_ID_BYTES) {
|
||||
return InvalidArgument(
|
||||
"ncclUniqueId string must have %d bytes, got %d bytes", str_id.size(),
|
||||
NCCL_UNIQUE_ID_BYTES);
|
||||
}
|
||||
// NcclUniqueId is internally just a char[].
|
||||
static_assert(sizeof(ncclUniqueId) == NCCL_UNIQUE_ID_BYTES,
|
||||
"NCCL_UNIQUE_ID_BYTES");
|
||||
std::memcpy(static_cast<void*>(nccl_id), str_id.data(), NCCL_UNIQUE_ID_BYTES);
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// Owns a clique of NCCL comms which can be used for collective operations among
|
||||
// a particular set of GPUs.
|
||||
//
|
||||
// You must ensure this is not in an error state (i.e. status() is OK) before
|
||||
// touching any other methods.
|
||||
//
|
||||
// (Usually allowing objects to be in a constructed-but-uninitialized state is
|
||||
// an antipattern. We do it here because it allows us to have a
|
||||
// RefcountingHashMap which contains and automatically constructs NcclCliques.
|
||||
// This greatly simplifies the rest of this file.)
|
||||
//
|
||||
// Note that if you want to do a collective operation among a subset of these
|
||||
// GPUs, you'll need a different clique.
|
||||
class NcclClique {
|
||||
public:
|
||||
explicit NcclClique(
|
||||
int64 num_global_devices, std::vector<int64> local_device_ordinals,
|
||||
std::vector<int64> local_device_ranks,
|
||||
const StatusOr<absl::optional<std::string>>& nccl_unique_id)
|
||||
: num_global_devices_(num_global_devices),
|
||||
local_device_ordinals_(std::move(local_device_ordinals)),
|
||||
local_device_ranks_(std::move(local_device_ranks)) {
|
||||
CHECK_EQ(local_device_ordinals_.size(), local_device_ranks_.size());
|
||||
// It's unusual to pass a StatusOr<> into a class, but since this class
|
||||
// already has an erroneous state, it turns out to be a little easier to
|
||||
// implement this way than to change RefcountingHashMap.
|
||||
status_ = Init(nccl_unique_id);
|
||||
}
|
||||
|
||||
Status status() { return status_; }
|
||||
|
||||
// A NCCL communicator is the NCCL state associated with a participant (rank)
|
||||
// in a reduction. This method returns the state associated with a particular
|
||||
// local device ordinal.
|
||||
ncclComm_t comm(int64 device_ordinal) {
|
||||
int64 idx =
|
||||
std::distance(local_device_ordinals_.begin(),
|
||||
absl::c_find(local_device_ordinals_, device_ordinal));
|
||||
return comms_.at(idx).comm();
|
||||
}
|
||||
|
||||
// These methods let you acquire exclusive access to a NCCL clique, ensuring
|
||||
// no other NCCL operations are taking place on the clique's comms.
|
||||
//
|
||||
// We disable thread-safety analysis because in common use, only the primary
|
||||
// thread in a Rendezvous acquires this lock, and that makes thread-safety
|
||||
// analysis unhappy. Tread carefully, you are playing with fire.
|
||||
void Lock() ABSL_NO_THREAD_SAFETY_ANALYSIS {
|
||||
TF_CHECK_OK(status_);
|
||||
mu_->lock();
|
||||
}
|
||||
void Unlock() ABSL_NO_THREAD_SAFETY_ANALYSIS {
|
||||
TF_CHECK_OK(status_);
|
||||
mu_->unlock();
|
||||
}
|
||||
|
||||
private:
|
||||
Status Init(
|
||||
const StatusOr<absl::optional<std::string>>& maybe_nccl_unique_id) {
|
||||
VLOG(3) << absl::StreamFormat(
|
||||
"Initializing nccl comms for participant device ordinals %s ranks {%s}",
|
||||
absl::StrJoin(local_device_ordinals_, ", "),
|
||||
absl::StrJoin(local_device_ranks_, ", "));
|
||||
|
||||
// Restore CUDA device after running this. XLA shouldn't care, but maybe
|
||||
// another consumer does.
|
||||
int initial_cuda_device;
|
||||
XLA_CUDA_RETURN_IF_ERROR(cudaGetDevice(&initial_cuda_device));
|
||||
auto cuda_device_restorer = MakeCleanup(
|
||||
[&] { XLA_CUDA_WARN_IF_ERROR(cudaSetDevice(initial_cuda_device)); });
|
||||
|
||||
// When using ncclGroupStart/End it seems that the ncclComm_t's are not
|
||||
// populated until the End() call. This unfortunately makes error handling
|
||||
// tricky.
|
||||
std::vector<ncclComm_t> raw_comms(local_device_ordinals_.size(), nullptr);
|
||||
TF_ASSIGN_OR_RETURN(const absl::optional<std::string>& nccl_id_string,
|
||||
maybe_nccl_unique_id);
|
||||
ncclUniqueId nccl_id;
|
||||
if (nccl_id_string) {
|
||||
TF_RETURN_IF_ERROR(StringToNcclUniqueId(*nccl_id_string, &nccl_id));
|
||||
} else {
|
||||
XLA_CUDA_RETURN_IF_ERROR(ncclGetUniqueId(&nccl_id));
|
||||
}
|
||||
XLA_CUDA_RETURN_IF_ERROR(ncclGroupStart());
|
||||
Status status = [&] {
|
||||
for (int i = 0; i < local_device_ordinals_.size(); ++i) {
|
||||
XLA_CUDA_RETURN_IF_ERROR(cudaSetDevice(local_device_ordinals_[i]));
|
||||
XLA_CUDA_RETURN_IF_ERROR(ncclCommInitRank(&raw_comms[i],
|
||||
num_global_devices_, nccl_id,
|
||||
local_device_ranks_.at(i)));
|
||||
}
|
||||
return Status::OK();
|
||||
}();
|
||||
// Always call ncclGroupEnd().
|
||||
XLA_CUDA_RETURN_IF_ERROR(ncclGroupEnd());
|
||||
|
||||
// Populate comms_ from the raw comms we created above. If we encountered
|
||||
// an error above we'll later clear comms_ thus destroying any raw comms
|
||||
// that were created before the error.
|
||||
for (int i = 0; i < local_device_ordinals_.size(); ++i) {
|
||||
VLOG(3) << absl::StreamFormat("Device ordinal %d assigned ncclComm %p",
|
||||
local_device_ordinals_[i], raw_comms[i]);
|
||||
CHECK(raw_comms[i] != nullptr || !status.ok());
|
||||
comms_.emplace_back(raw_comms[i]);
|
||||
}
|
||||
if (!status.ok()) {
|
||||
comms_.clear();
|
||||
}
|
||||
|
||||
return status;
|
||||
}
|
||||
|
||||
Status status_;
|
||||
int64 num_global_devices_;
|
||||
std::vector<int64> local_device_ordinals_;
|
||||
// NCCL communicator rank for each local device. The rank of a device is equal
|
||||
// to the offset of the local device in the global device set.
|
||||
std::vector<int64> local_device_ranks_;
|
||||
std::vector<NcclComm> comms_;
|
||||
|
||||
// This mutex is in a unique_ptr so NcclClique can be movable.
|
||||
std::unique_ptr<tensorflow::mutex> mu_ =
|
||||
absl::make_unique<tensorflow::mutex>();
|
||||
};
|
||||
|
||||
// Global cache of NCCL cliques. An entry in this map is kept alive as long as
|
||||
// there's a reference to it somewhere. A Thunk holds a reference to each
|
||||
// Clique it's ever used.
|
||||
//
|
||||
// A consequence of the fact that this is process-global is that we'll only ever
|
||||
// have one clique alive for a given set of GPUs. This means that a process
|
||||
// will never do two collective operations concurrently on the same set of GPUs.
|
||||
RefcountingHashMap<NcclCliqueKey, NcclClique>& GlobalNcclCliqueMap() {
|
||||
static auto& m = *new RefcountingHashMap<NcclCliqueKey, NcclClique>();
|
||||
return m;
|
||||
}
|
||||
|
||||
using RendezvousBase =
|
||||
Rendezvous<AllReduceParticipantData, std::shared_ptr<NcclClique>>;
|
||||
class RendezvousNcclAllReduce : public RendezvousBase {
|
||||
public:
|
||||
explicit RendezvousNcclAllReduce(const RendezvousKey& k)
|
||||
: RendezvousBase(k) {}
|
||||
|
||||
protected:
|
||||
StatusOr<ParticipantImplOutput> RunCollectiveOp(
|
||||
const AllReduceParticipantData& participant) override;
|
||||
|
||||
void CleanupImpl(std::shared_ptr<NcclClique> handle,
|
||||
bool is_primary) override;
|
||||
};
|
||||
|
||||
// Global map of Rendezvous objects. A thread participating in a collective op
|
||||
// looks up its Rendezvous in this map to find the other threads that it's
|
||||
// participating with.
|
||||
//
|
||||
// Rendezvous objects are one-time use, so they're removed from this map once
|
||||
// we're through with them.
|
||||
RefcountingHashMap<RendezvousKey, RendezvousNcclAllReduce>&
|
||||
GlobalRendezvousMap() {
|
||||
static auto& m =
|
||||
*new RefcountingHashMap<RendezvousKey, RendezvousNcclAllReduce>();
|
||||
return m;
|
||||
}
|
||||
|
||||
StatusOr<RendezvousNcclAllReduce::ParticipantImplOutput>
|
||||
RendezvousNcclAllReduce::RunCollectiveOp(
|
||||
const AllReduceParticipantData& participant) {
|
||||
// We pull into our thread a) the communication handle and b) whether we're
|
||||
// the "primary" thread for this rendezvous -- the "primary" thread has some
|
||||
// additional responsibilities for setup/teardown.
|
||||
ncclComm_t comm;
|
||||
bool primary;
|
||||
std::shared_ptr<NcclClique> clique;
|
||||
|
||||
{
|
||||
tensorflow::mutex_lock lock(mu_);
|
||||
|
||||
// The first thread to get here has additional responsibilities, such as
|
||||
// ensuring that there's a NCCL clique available for us to use.
|
||||
primary = !initialized_;
|
||||
|
||||
TF_RET_CHECK(participant.local_devices.size() ==
|
||||
participant.rendezvous_key.num_local_participants);
|
||||
|
||||
// Look up or create the NCCL clique for this set of devices.
|
||||
NcclCliqueKey clique_key(participant.rendezvous_key.global_devices);
|
||||
|
||||
auto clique_factory =
|
||||
[&](const NcclCliqueKey& key) -> std::unique_ptr<NcclClique> {
|
||||
std::vector<int64> local_device_ranks;
|
||||
std::vector<int64> local_device_ordinals;
|
||||
local_device_ranks.reserve(participant.local_devices.size());
|
||||
local_device_ordinals.reserve(participant.local_devices.size());
|
||||
for (const auto& l : participant.local_devices) {
|
||||
auto it =
|
||||
absl::c_find(participant.rendezvous_key.global_devices, l.first);
|
||||
CHECK(it != participant.rendezvous_key.global_devices.end()) << l.first;
|
||||
local_device_ranks.push_back(std::distance(
|
||||
participant.rendezvous_key.global_devices.begin(), it));
|
||||
local_device_ordinals.push_back(l.second);
|
||||
}
|
||||
StatusOr<absl::optional<std::string>> nccl_unique_id;
|
||||
if (participant.nccl_unique_id_callback) {
|
||||
nccl_unique_id = (*participant.nccl_unique_id_callback)(clique_key);
|
||||
} else {
|
||||
if (participant.rendezvous_key.global_devices.size() !=
|
||||
participant.rendezvous_key.num_local_participants &&
|
||||
!IsGlobalNcclConfig()) {
|
||||
nccl_unique_id = InvalidArgument(
|
||||
"If not local devices are taking part of a collective API on "
|
||||
"GPU, the nccl_unique_id_callback must be provided by the "
|
||||
"client.");
|
||||
} else {
|
||||
nccl_unique_id = absl::optional<std::string>();
|
||||
}
|
||||
}
|
||||
return absl::make_unique<NcclClique>(
|
||||
participant.rendezvous_key.global_devices.size(),
|
||||
std::move(local_device_ordinals), std::move(local_device_ranks),
|
||||
nccl_unique_id);
|
||||
};
|
||||
clique =
|
||||
GlobalNcclCliqueMap().GetOrCreateIfAbsent(clique_key, clique_factory);
|
||||
|
||||
if (primary) {
|
||||
VLOG(3) << "Primary initializing accounting data.";
|
||||
initialized_ = true;
|
||||
|
||||
// Acquire exclusive access to the NCCL clique itself so that two
|
||||
// unrelated collective operations won't try to use the clique
|
||||
// concurrently.
|
||||
// We'll unlock it in CleanupImpl.
|
||||
clique->Lock();
|
||||
}
|
||||
|
||||
if (!clique->status().ok()) {
|
||||
VLOG(1)
|
||||
<< "SubmitParticipant failing because clique failed to initialize: "
|
||||
<< clique->status().ToString();
|
||||
return clique->status();
|
||||
}
|
||||
|
||||
comm = clique->comm(participant.device_ordinal);
|
||||
|
||||
// Drop the lock at the end of scope so other participants may enter.
|
||||
}
|
||||
|
||||
VLOG(3) << "Performing all reduce from device ordinal: "
|
||||
<< participant.device_ordinal;
|
||||
ncclRedOp_t computation = ReductionKindToNccl(participant.reduction_kind);
|
||||
|
||||
se::StreamExecutor* executor = participant.stream->parent();
|
||||
se::gpu::ScopedActivateExecutorContext scoped_context(executor);
|
||||
cudaStream_t* cu_stream = reinterpret_cast<cudaStream_t*>(
|
||||
participant.stream->implementation()->GpuStreamMemberHack());
|
||||
VLOG(3) << "Using stream pointer: " << cu_stream
|
||||
<< " on device: " << participant.device_ordinal;
|
||||
XLA_CUDA_RETURN_IF_ERROR(ncclGroupStart());
|
||||
for (auto& buffer : participant.buffers) {
|
||||
void* send_buffer = const_cast<void*>(buffer.source_data.opaque());
|
||||
void* recv_buffer = const_cast<void*>(buffer.destination_data.opaque());
|
||||
absl::optional<ncclDataType_t> allreduce_datatype =
|
||||
DatatypeToNccl(buffer.primitive_type);
|
||||
CHECK(allreduce_datatype.has_value());
|
||||
VLOG(3) << absl::StreamFormat(
|
||||
"Calling ncclAllReduce(send_buffer=%p, recv_buffer=%p, count=%d, "
|
||||
"comm=%p, stream=%p)",
|
||||
send_buffer, recv_buffer, buffer.element_count,
|
||||
static_cast<const void*>(comm), cu_stream);
|
||||
XLA_CUDA_RETURN_IF_ERROR(ncclAllReduce(send_buffer, recv_buffer,
|
||||
/*count=*/buffer.element_count,
|
||||
/*datatype=*/*allreduce_datatype,
|
||||
/*op=*/computation,
|
||||
/*comm=*/comm,
|
||||
/*stream=*/*cu_stream));
|
||||
}
|
||||
XLA_CUDA_RETURN_IF_ERROR(ncclGroupEnd());
|
||||
|
||||
VLOG(3) << "Done performing all reduce for ordinal: "
|
||||
<< participant.device_ordinal;
|
||||
VLOG(3) << "This thread done with all-reduce op.";
|
||||
|
||||
return ParticipantImplOutput{primary, clique};
|
||||
}
|
||||
|
||||
void RendezvousNcclAllReduce::CleanupImpl(std::shared_ptr<NcclClique> handle,
|
||||
bool is_primary) {
|
||||
// Releases the lock on the clique (held only by the primary thread).
|
||||
if (is_primary) {
|
||||
handle->Unlock();
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
// Extra data stored in NcclAllReduceThunk that we didn't want to expose in the
|
||||
// header. In particular, this stores the thunk's cache of all NcclCliques it's
|
||||
// ever used. This causes those cliques to stay alive as long as the thunk
|
||||
@ -105,13 +550,23 @@ NcclAllReduceConfig GetNcclAllReduceConfig(const HloInstruction* instr,
|
||||
auto operands_are_supported = [crs]() {
|
||||
return absl::c_all_of(crs->operands(), [](HloInstruction* operand) {
|
||||
return LayoutUtil::IsDenseArray(operand->shape()) &&
|
||||
ToNcclDataType(operand->shape().element_type()).ok();
|
||||
DatatypeToNccl(operand->shape().element_type()).has_value();
|
||||
});
|
||||
};
|
||||
return MatchReductionComputation(crs->to_apply()).has_value() &&
|
||||
crs->IsCrossReplicaAllReduce() && operands_are_supported();
|
||||
}
|
||||
|
||||
/*static*/ absl::flat_hash_set<GlobalDeviceId>
|
||||
NcclAllReduceThunk::DevicesWithOpenNcclChannels() {
|
||||
absl::flat_hash_set<GlobalDeviceId> devices;
|
||||
GlobalNcclCliqueMap().ForEach(
|
||||
[&](const NcclCliqueKey& k, const std::shared_ptr<NcclClique>&) {
|
||||
devices.insert(k.devices().begin(), k.devices().end());
|
||||
});
|
||||
return devices;
|
||||
}
|
||||
|
||||
NcclAllReduceThunk::NcclAllReduceThunk(
|
||||
ThunkInfo thunk_info, NcclAllReduceConfig&& config,
|
||||
std::vector<NcclAllReduceThunk::Buffer> buffers)
|
||||
@ -128,87 +583,97 @@ Status NcclAllReduceThunk::ExecuteOnStream(const ExecuteParams& params) {
|
||||
auto op_profiler =
|
||||
params.profiler->MakeScopedInstructionProfiler(profile_index());
|
||||
|
||||
se::StreamExecutor* executor = params.stream->parent();
|
||||
int device_ordinal = executor->device_ordinal();
|
||||
int64 local_device_ordinal = params.stream->parent()->device_ordinal();
|
||||
TF_ASSIGN_OR_RETURN(GlobalDeviceId global_device_id,
|
||||
params.GetGlobalDeviceId());
|
||||
// Determines the set of global and local devices that are participating in
|
||||
// the same collective group as the caller.
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
std::vector<int64> participating_replicas,
|
||||
std::vector<int64> global_participating_replicas,
|
||||
GetParticipatingReplicas(global_device_id, config_.replica_groups,
|
||||
config_.replica_count, *params.device_assn));
|
||||
if (IsGlobalNcclConfig() &&
|
||||
participating_replicas.size() != config_.replica_count) {
|
||||
global_participating_replicas.size() != config_.replica_count) {
|
||||
return InvalidArgument(
|
||||
"Partial replica groups are not allowed when using NCCL_COMM_ID "
|
||||
"environment configuration.");
|
||||
}
|
||||
|
||||
std::vector<GlobalDeviceId> global_devices;
|
||||
std::vector<std::pair<GlobalDeviceId, int64>> local_devices;
|
||||
local_devices.reserve(global_participating_replicas.size());
|
||||
global_devices.reserve(global_participating_replicas.size());
|
||||
TF_RET_CHECK(params.device_assn->computation_count() == 1)
|
||||
<< params.device_assn->ToString();
|
||||
std::vector<GlobalDeviceId> participants;
|
||||
participants.reserve(participating_replicas.size());
|
||||
for (int64 replica : participating_replicas) {
|
||||
participants.emplace_back(
|
||||
for (int64 replica : global_participating_replicas) {
|
||||
GlobalDeviceId global_device(
|
||||
(*params.device_assn)(replica, /*computation=*/0));
|
||||
global_devices.push_back(global_device);
|
||||
if (!params.gpu_global_device_ids) {
|
||||
local_devices.emplace_back(global_device, global_device.value());
|
||||
} else {
|
||||
auto it = absl::c_find(*params.gpu_global_device_ids, global_device);
|
||||
if (it != params.gpu_global_device_ids->end()) {
|
||||
local_devices.emplace_back(
|
||||
*it, std::distance(params.gpu_global_device_ids->begin(), it));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
std::vector<LocalParticipant> local_participants,
|
||||
GetLocalParticipants(participants, params.gpu_global_device_ids));
|
||||
absl::c_sort(global_devices);
|
||||
|
||||
// Create the rendezvous for this collective operation.
|
||||
RendezvousKey rendezvous_key(params.run_id, std::move(participants),
|
||||
local_participants.size(),
|
||||
config_.collective_op_kind, config_.op_id);
|
||||
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
LockedNcclClique locked_clique,
|
||||
AcquireNcclClique(rendezvous_key, device_ordinal, params.stream,
|
||||
local_participants, params.nccl_unique_id_callback));
|
||||
ncclComm_t comm =
|
||||
locked_clique.clique->GetCommForDeviceOrdinal(device_ordinal);
|
||||
|
||||
VLOG(3) << "Performing all-reduce from device ordinal: " << device_ordinal;
|
||||
ncclRedOp_t reduction_kind = ToNcclReduction(config_.reduction_kind);
|
||||
|
||||
se::gpu::ScopedActivateExecutorContext scoped_context(executor);
|
||||
cudaStream_t* cu_stream = reinterpret_cast<cudaStream_t*>(
|
||||
params.stream->implementation()->GpuStreamMemberHack());
|
||||
VLOG(3) << "Using stream pointer: " << cu_stream
|
||||
<< " on device: " << device_ordinal;
|
||||
XLA_CUDA_RETURN_IF_ERROR(ncclGroupStart());
|
||||
for (size_t i = 0; i < buffers_.size(); ++i) {
|
||||
const Buffer& buffer = buffers_[i];
|
||||
const void* send_buffer =
|
||||
params.buffer_allocations->GetDeviceAddress(buffer.source_buffer)
|
||||
.opaque();
|
||||
void* recv_buffer =
|
||||
params.buffer_allocations->GetDeviceAddress(buffer.destination_buffer)
|
||||
.opaque();
|
||||
TF_ASSIGN_OR_RETURN(ncclDataType_t datatype,
|
||||
ToNcclDataType(config_.operand_element_type[i]));
|
||||
VLOG(3) << absl::StreamFormat(
|
||||
"Calling ncclAllReduce(send_buffer=%p, recv_buffer=%p, count=%d, "
|
||||
"comm=%p, stream=%p)",
|
||||
send_buffer, recv_buffer, buffer.element_count,
|
||||
static_cast<const void*>(comm), cu_stream);
|
||||
XLA_CUDA_RETURN_IF_ERROR(ncclAllReduce(send_buffer, recv_buffer,
|
||||
/*count=*/buffer.element_count,
|
||||
datatype,
|
||||
/*op=*/reduction_kind, comm,
|
||||
/*stream=*/*cu_stream));
|
||||
RendezvousKey rendezvous_key(params.run_id, global_devices,
|
||||
local_devices.size(), config_.collective_op_kind,
|
||||
config_.op_id);
|
||||
if (VLOG_IS_ON(2)) {
|
||||
std::vector<std::string> local_participants;
|
||||
local_participants.reserve(local_devices.size());
|
||||
for (const auto& entry : local_devices) {
|
||||
local_participants.push_back(absl::StrFormat(
|
||||
"global=%d/local=%d", entry.first.value(), entry.second));
|
||||
}
|
||||
VLOG(2) << "Rendezvous key: " << rendezvous_key.ToString()
|
||||
<< ", global participating replicas: "
|
||||
<< absl::StrJoin(global_participating_replicas, ", ")
|
||||
<< ", global participating devices: "
|
||||
<< GlobalDeviceIdsToString(global_devices)
|
||||
<< ", local participants: "
|
||||
<< absl::StrJoin(local_participants, ",");
|
||||
}
|
||||
XLA_CUDA_RETURN_IF_ERROR(ncclGroupEnd());
|
||||
AllReduceParticipantData participant(rendezvous_key, local_device_ordinal,
|
||||
params.stream);
|
||||
for (size_t i = 0; i < buffers_.size(); ++i) {
|
||||
const NcclAllReduceThunk::Buffer& buffer = buffers_[i];
|
||||
AllReduceParticipantData::Buffer pbuffer;
|
||||
pbuffer.element_count = buffer.element_count;
|
||||
pbuffer.source_data =
|
||||
params.buffer_allocations->GetDeviceAddress(buffer.source_buffer);
|
||||
pbuffer.destination_data =
|
||||
params.buffer_allocations->GetDeviceAddress(buffer.destination_buffer);
|
||||
pbuffer.primitive_type = config_.operand_element_type[i];
|
||||
participant.buffers.push_back(pbuffer);
|
||||
}
|
||||
participant.local_devices = std::move(local_devices);
|
||||
participant.nccl_unique_id_callback = params.nccl_unique_id_callback;
|
||||
participant.reduction_kind = config_.reduction_kind;
|
||||
|
||||
VLOG(3) << "Done performing all-reduce for ordinal: " << device_ordinal;
|
||||
auto rendezvous_factory = [](const RendezvousKey& k) {
|
||||
return absl::make_unique<RendezvousNcclAllReduce>(k);
|
||||
};
|
||||
|
||||
TF_ASSIGN_OR_RETURN(std::shared_ptr<NcclClique> clique,
|
||||
RendezvousNcclAllReduce::SubmitParticipant(
|
||||
[&] {
|
||||
return GlobalRendezvousMap().GetOrCreateIfAbsent(
|
||||
rendezvous_key, rendezvous_factory);
|
||||
},
|
||||
participant));
|
||||
|
||||
// Keep the clique we used alive for as long as this Thunk lives. Creating
|
||||
// new NCCL cliques is expensive, and this is how we avoid thrashing them.
|
||||
{
|
||||
tensorflow::mutex_lock lock(config_.aux_data->mu);
|
||||
config_.aux_data->cliques.insert(std::move(locked_clique.clique));
|
||||
config_.aux_data->cliques.insert(std::move(clique));
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
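One of the fixes called out in the commit message is that ncclComm handles are now cleaned up reliably. Condensed here as a sketch: NcclComm, DestroyNcclComm, and XLA_CUDA_WARN_IF_ERROR are the names from nccl_utils.h/.cc shown further down in this diff; the final wrapping line is illustrative only.

// RAII alias for a NCCL communicator (from nccl_utils.h).
using NcclComm = std::unique_ptr<ncclComm, void (*)(ncclComm_t)>;

// Deleter installed for every communicator (from nccl_utils.cc).
void DestroyNcclComm(ncclComm_t comm) {
  XLA_CUDA_WARN_IF_ERROR(ncclCommDestroy(comm));
}

// Each raw comm produced by ncclCommInitRank is wrapped immediately, so it is
// destroyed even if building the rest of the clique fails:
//   NcclComm comm(raw_comm, &DestroyNcclComm);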
@@ -16,12 +16,17 @@ limitations under the License.
#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_NCCL_ALL_REDUCE_THUNK_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_NCCL_ALL_REDUCE_THUNK_H_

#include "absl/container/flat_hash_set.h"
#include "tensorflow/compiler/xla/service/buffer_assignment.h"
#include "tensorflow/compiler/xla/service/collective_ops_utils.h"
#include "tensorflow/compiler/xla/service/gpu/buffer_allocations.h"
#include "tensorflow/compiler/xla/service/gpu/gpu_executable_run_options.h"
#include "tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.h"
#include "tensorflow/compiler/xla/service/gpu/thunk.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/stream_executor_no_cuda.h"
#include "tensorflow/core/platform/types.h"

namespace xla {
@@ -62,6 +67,14 @@ class NcclAllReduceThunk : public Thunk {
// error.
static bool NcclIsEnabled();

// Gets the set of devices that have a NCCL channel open. This is primarily
// for testing.
//
// (Indeed, because the NCCL channels are a global variable, in the real
// world, the value returned here is stale as soon as you read it, so it's not
// clear how you *could* use it for anything other than tests.)
static absl::flat_hash_set<GlobalDeviceId> DevicesWithOpenNcclChannels();

struct Buffer {
int64 element_count;
BufferAllocation::Slice source_buffer;
@ -1,311 +0,0 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/compiler/xla/service/gpu/nccl_utils.h"
|
||||
|
||||
#include <memory>
|
||||
|
||||
#include "absl/container/flat_hash_map.h"
|
||||
#include "absl/strings/str_format.h"
|
||||
#include "absl/synchronization/mutex.h"
|
||||
#include "tensorflow/compiler/xla/refcounting_hash_map.h"
|
||||
#include "tensorflow/compiler/xla/service/collective_ops_utils.h"
|
||||
#include "tensorflow/compiler/xla/service/global_device_id.h"
|
||||
#include "tensorflow/compiler/xla/service/gpu/gpu_executable_run_options.h"
|
||||
#include "tensorflow/compiler/xla/status_macros.h"
|
||||
#include "tensorflow/compiler/xla/statusor.h"
|
||||
#include "tensorflow/core/platform/errors.h"
|
||||
|
||||
namespace xla {
|
||||
namespace gpu {
|
||||
|
||||
ncclRedOp_t ToNcclReduction(ReductionKind kind) {
|
||||
switch (kind) {
|
||||
case ReductionKind::SUM:
|
||||
return ncclSum;
|
||||
case ReductionKind::PRODUCT:
|
||||
return ncclProd;
|
||||
case ReductionKind::MIN:
|
||||
return ncclMin;
|
||||
case ReductionKind::MAX:
|
||||
return ncclMax;
|
||||
}
|
||||
}
|
||||
|
||||
StatusOr<ncclDataType_t> ToNcclDataType(PrimitiveType element_type) {
|
||||
switch (element_type) {
|
||||
case S8:
|
||||
return ncclInt8;
|
||||
case PRED:
|
||||
case U8:
|
||||
return ncclUint8;
|
||||
case S32:
|
||||
return ncclInt32;
|
||||
case U32:
|
||||
return ncclUint32;
|
||||
case S64:
|
||||
return ncclInt64;
|
||||
case U64:
|
||||
return ncclUint64;
|
||||
case F16:
|
||||
return ncclFloat16;
|
||||
case F32:
|
||||
return ncclFloat32;
|
||||
case F64:
|
||||
return ncclFloat64;
|
||||
default:
|
||||
return tensorflow::errors::InvalidArgument(absl::StrFormat(
|
||||
"Unsupported data type: %s", PrimitiveType_Name(element_type)));
|
||||
}
|
||||
}
|
||||
|
||||
bool IsGlobalNcclConfig() {
|
||||
static bool global_nccl_config = std::getenv("NCCL_COMM_ID") != nullptr;
|
||||
return global_nccl_config;
|
||||
}
|
||||
|
||||
Status ToStatus(ncclResult_t s, const char* file, int64 line,
|
||||
const char* expr) {
|
||||
if (s == ncclSuccess) {
|
||||
return Status::OK();
|
||||
}
|
||||
return tensorflow::errors::Internal(
|
||||
absl::StrFormat("%s:%d: NCCL operation %s failed: %s", file, line, expr,
|
||||
ncclGetErrorString(s)));
|
||||
}
|
||||
|
||||
Status ToStatus(cudaError_t s, const char* file, int64 line, const char* expr) {
|
||||
if (s == cudaSuccess) {
|
||||
return Status::OK();
|
||||
}
|
||||
return tensorflow::errors::Internal(
|
||||
absl::StrFormat("%s:%d: CUDA operation %s failed: %s", file, line, expr,
|
||||
cudaGetErrorString(s)));
|
||||
}
|
||||
|
||||
NcclClique::NcclClique(
|
||||
absl::flat_hash_map<int, NcclComm> comms_by_device_ordinal)
|
||||
: comms_by_device_ordinal_(std::move(comms_by_device_ordinal)) {}
|
||||
|
||||
ncclComm_t NcclClique::GetCommForDeviceOrdinal(int device_ordinal) const {
|
||||
return comms_by_device_ordinal_.at(device_ordinal).get();
|
||||
}
|
||||
|
||||
namespace {
|
||||
|
||||
void DestroyNcclComm(ncclComm_t comm) {
|
||||
VLOG(3) << absl::StreamFormat("Destroying comm %p", comm);
|
||||
XLA_CUDA_WARN_IF_ERROR(ncclCommDestroy(comm));
|
||||
}
|
||||
|
||||
Status ToNcclUniqueId(const std::string& str_id, ncclUniqueId* nccl_id) {
|
||||
if (str_id.size() != NCCL_UNIQUE_ID_BYTES) {
|
||||
return InvalidArgument(
|
||||
"ncclUniqueId string must have %d bytes, got %d bytes", str_id.size(),
|
||||
NCCL_UNIQUE_ID_BYTES);
|
||||
}
|
||||
// NcclUniqueId is internally just a char[].
|
||||
static_assert(sizeof(ncclUniqueId) == NCCL_UNIQUE_ID_BYTES,
|
||||
"NCCL_UNIQUE_ID_BYTES");
|
||||
std::memcpy(static_cast<void*>(nccl_id), str_id.data(), NCCL_UNIQUE_ID_BYTES);
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
std::string LocalParticipantsToString(
|
||||
const std::vector<LocalParticipant>& local_participants) {
|
||||
std::vector<std::string> parts;
|
||||
for (const LocalParticipant& local_participant : local_participants) {
|
||||
parts.push_back(absl::StrFormat("%d/rank=%d",
|
||||
local_participant.device_ordinal,
|
||||
local_participant.rank));
|
||||
}
|
||||
return absl::StrJoin(parts, ",");
|
||||
}
|
||||
|
||||
RefcountingHashMap<NcclCliqueKey, NcclClique>& NcclCliqueCache() {
|
||||
// Global cache of NCCL cliques. An entry in this map is kept alive as long
|
||||
// as there's a reference to it somewhere. A Thunk holds a reference to each
|
||||
// Clique it's ever used.
|
||||
//
|
||||
// A consequence of the fact that this is process-global is that we'll only
|
||||
// ever have one clique alive for a given set of GPUs. This means that a
|
||||
// process will never do two collective operations concurrently on the same
|
||||
// set of GPUs.
|
||||
static auto& cache = *new RefcountingHashMap<NcclCliqueKey, NcclClique>();
|
||||
return cache;
|
||||
}
|
||||
|
||||
StatusOr<std::unique_ptr<NcclClique>> CreateNcclClique(
|
||||
const NcclCliqueKey& key,
|
||||
const std::vector<LocalParticipant>& local_participants,
|
||||
const NcclUniqueIdCallback* callback) {
|
||||
int num_participants = key.devices().size();
|
||||
ncclUniqueId unique_id;
|
||||
if (callback) { // Multi-host collective.
|
||||
TF_ASSIGN_OR_RETURN(std::string id_string, (*callback)(key));
|
||||
TF_RETURN_IF_ERROR(ToNcclUniqueId(id_string, &unique_id));
|
||||
} else {
|
||||
TF_RET_CHECK((num_participants == local_participants.size()) ||
|
||||
IsGlobalNcclConfig())
|
||||
<< "If non-local devices are taking part of a collective API on GPU, "
|
||||
"the nccl_unique_id_callback must be provided by the client.";
|
||||
XLA_CUDA_RETURN_IF_ERROR(ncclGetUniqueId(&unique_id));
|
||||
}
|
||||
|
||||
VLOG(3) << "Initializing nccl comms for local participants: "
|
||||
<< LocalParticipantsToString(local_participants);
|
||||
|
||||
// Restore CUDA device after running this. XLA shouldn't care, but maybe
|
||||
// another consumer does.
|
||||
int initial_cuda_device;
|
||||
XLA_CUDA_RETURN_IF_ERROR(cudaGetDevice(&initial_cuda_device));
|
||||
auto cuda_device_restorer = MakeCleanup(
|
||||
[&] { XLA_CUDA_WARN_IF_ERROR(cudaSetDevice(initial_cuda_device)); });
|
||||
|
||||
// When using ncclGroupStart/End it seems that the ncclComm_t's are not
|
||||
// populated until the End() call.
|
||||
std::vector<ncclComm_t> raw_comms(local_participants.size(), nullptr);
|
||||
XLA_CUDA_RETURN_IF_ERROR(ncclGroupStart());
|
||||
Status status = [&] {
|
||||
for (int i = 0; i < local_participants.size(); ++i) {
|
||||
XLA_CUDA_RETURN_IF_ERROR(
|
||||
cudaSetDevice(local_participants[i].device_ordinal));
|
||||
XLA_CUDA_RETURN_IF_ERROR(ncclCommInitRank(&raw_comms[i], num_participants,
|
||||
unique_id,
|
||||
local_participants[i].rank));
|
||||
}
|
||||
return Status::OK();
|
||||
}();
|
||||
// Always call ncclGroupEnd().
|
||||
status.Update(XLA_CUDA_STATUS(ncclGroupEnd()));
|
||||
|
||||
// Always copy raw comms to RAII type, so they are cleaned up properly.
|
||||
absl::flat_hash_map<int, NcclComm> comms_by_device_ordinal(raw_comms.size());
|
||||
for (int i = 0; i < raw_comms.size(); ++i) {
|
||||
int device_ordinal = local_participants[i].device_ordinal;
|
||||
VLOG(3) << absl::StreamFormat("Device ordinal %d assigned ncclComm %p",
|
||||
device_ordinal, raw_comms[i]);
|
||||
CHECK(raw_comms[i] != nullptr || !status.ok());
|
||||
comms_by_device_ordinal.emplace(device_ordinal,
|
||||
NcclComm(raw_comms[i], &DestroyNcclComm));
|
||||
}
|
||||
|
||||
// Now we can check if there was an error creating the communicators.
|
||||
TF_RETURN_IF_ERROR(status);
|
||||
return std::make_unique<NcclClique>(std::move(comms_by_device_ordinal));
|
||||
}
|
||||
|
||||
struct NcclCliqueParticipantData : public ParticipantData {
|
||||
using ParticipantData::ParticipantData;
|
||||
std::string ToString() const override { return ""; }
|
||||
};
|
||||
|
||||
class NcclCliqueRendezvous
|
||||
: public Rendezvous<NcclCliqueParticipantData, LockedNcclClique> {
|
||||
public:
|
||||
NcclCliqueRendezvous(const RendezvousKey& rendezvous_key,
|
||||
const std::vector<LocalParticipant>& local_participants,
|
||||
const NcclUniqueIdCallback* callback)
|
||||
: Rendezvous(rendezvous_key) {
|
||||
NcclCliqueKey key(std::move(rendezvous_key.global_devices));
|
||||
maybe_clique_ = NcclCliqueCache().GetOrTryCreateIfAbsent(
|
||||
key, [&](const NcclCliqueKey& key) {
|
||||
return CreateNcclClique(key, local_participants, callback);
|
||||
});
|
||||
if (maybe_clique_.ok()) {
|
||||
lock_ = std::make_shared<absl::MutexLock>((*maybe_clique_)->mu());
|
||||
}
|
||||
}
|
||||
|
||||
StatusOr<ParticipantImplOutput> RunCollectiveOp(
|
||||
const NcclCliqueParticipantData&) override {
|
||||
bool primary = InitializationBarrier();
|
||||
TF_ASSIGN_OR_RETURN(std::shared_ptr<NcclClique> clique, maybe_clique_);
|
||||
return ParticipantImplOutput{primary, LockedNcclClique{clique, lock_}};
|
||||
}
|
||||
|
||||
private:
|
||||
StatusOr<std::shared_ptr<NcclClique>> maybe_clique_;
|
||||
std::shared_ptr<absl::MutexLock> lock_;
|
||||
};
|
||||
|
||||
} // namespace
|
||||
|
||||
StatusOr<std::vector<LocalParticipant>> GetLocalParticipants(
|
||||
const std::vector<GlobalDeviceId>& participants,
|
||||
const std::vector<GlobalDeviceId>* local_devices) {
|
||||
std::vector<LocalParticipant> local_participants;
|
||||
if (local_devices) {
|
||||
absl::flat_hash_map<GlobalDeviceId, int> device_ranks(participants.size());
|
||||
for (int rank = 0; rank < participants.size(); ++rank) {
|
||||
auto result = device_ranks.emplace(participants[rank], rank);
|
||||
TF_RET_CHECK(result.second) << "Duplicate device found";
|
||||
}
|
||||
|
||||
local_participants.reserve(local_devices->size());
|
||||
for (int device_ordinal = 0; device_ordinal < local_devices->size();
|
||||
++device_ordinal) {
|
||||
auto it = device_ranks.find((*local_devices)[device_ordinal]);
|
||||
if (it != device_ranks.end()) {
|
||||
local_participants.push_back({device_ordinal, /*rank=*/it->second});
|
||||
}
|
||||
}
|
||||
} else { // Single host, so use identity mapping (device ordinal == id).
|
||||
local_participants.reserve(participants.size());
|
||||
for (int rank = 0; rank < participants.size(); ++rank) {
|
||||
int device_ordinal = participants[rank].value();
|
||||
local_participants.push_back({device_ordinal, rank});
|
||||
}
|
||||
}
|
||||
|
||||
return local_participants;
|
||||
}
|
||||
|
||||
StatusOr<LockedNcclClique> AcquireNcclClique(
|
||||
const RendezvousKey& rendezvous_key, int local_device_ordinal,
|
||||
se::Stream* stream, const std::vector<LocalParticipant>& local_participants,
|
||||
const NcclUniqueIdCallback* callback) {
|
||||
VLOG(2) << "Rendezvous key: " << rendezvous_key.ToString()
|
||||
<< ", local participants: "
|
||||
<< LocalParticipantsToString(local_participants);
|
||||
|
||||
static auto& rendezvous_map =
|
||||
*new RefcountingHashMap<RendezvousKey, NcclCliqueRendezvous>();
|
||||
|
||||
NcclCliqueParticipantData participant(rendezvous_key, local_device_ordinal,
|
||||
stream);
|
||||
return NcclCliqueRendezvous::SubmitParticipant(
|
||||
/*rendezvous_getter=*/
|
||||
[&] {
|
||||
return rendezvous_map.GetOrCreateIfAbsent(
|
||||
rendezvous_key, [&](const RendezvousKey& rendezvous_key) {
|
||||
return std::make_unique<NcclCliqueRendezvous>(
|
||||
rendezvous_key, local_participants, callback);
|
||||
});
|
||||
},
|
||||
participant);
|
||||
}
|
||||
|
||||
absl::flat_hash_set<GlobalDeviceId> DevicesWithOpenNcclChannels() {
|
||||
absl::flat_hash_set<GlobalDeviceId> devices;
|
||||
NcclCliqueCache().ForEach(
|
||||
[&](const NcclCliqueKey& k, const std::shared_ptr<NcclClique>&) {
|
||||
devices.insert(k.devices().begin(), k.devices().end());
|
||||
});
|
||||
return devices;
|
||||
}
|
||||
|
||||
} // namespace gpu
|
||||
} // namespace xla
|
@ -1,136 +0,0 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_NCCL_UTILS_H_
|
||||
#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_NCCL_UTILS_H_
|
||||
|
||||
#include <memory>
|
||||
|
||||
#include "absl/container/flat_hash_map.h"
|
||||
#include "absl/container/flat_hash_set.h"
|
||||
#include "absl/synchronization/mutex.h"
|
||||
#if GOOGLE_CUDA
|
||||
#include "third_party/nccl/nccl.h"
|
||||
#elif TENSORFLOW_USE_ROCM
|
||||
#include "rocm/include/rccl/rccl.h"
|
||||
#endif
|
||||
#include "tensorflow/compiler/xla/service/collective_ops_utils.h"
|
||||
#include "tensorflow/compiler/xla/service/gpu/gpu_executable_run_options.h"
|
||||
#include "tensorflow/compiler/xla/status.h"
|
||||
#include "tensorflow/compiler/xla/statusor.h"
|
||||
#include "tensorflow/compiler/xla/xla_data.pb.h"
|
||||
|
||||
#if TENSORFLOW_USE_ROCM
|
||||
// Local hipify of cuda symbols
|
||||
#define cudaError_t hipError_t
|
||||
#define cudaStream_t hipStream_t
|
||||
#define cudaGetErrorString hipGetErrorString
|
||||
#define cudaGetDevice hipGetDevice
|
||||
#define cudaSetDevice hipSetDevice
|
||||
#define cudaSuccess hipSuccess
|
||||
#endif
|
||||
|
||||
namespace xla {
|
||||
namespace gpu {
|
||||
|
||||
ncclRedOp_t ToNcclReduction(ReductionKind kind);
|
||||
StatusOr<ncclDataType_t> ToNcclDataType(PrimitiveType element_type);
|
||||
|
||||
bool IsGlobalNcclConfig();
|
||||
|
||||
Status ToStatus(ncclResult_t s, const char* file, int64 line, const char* expr);
|
||||
Status ToStatus(cudaError_t s, const char* file, int64 line, const char* expr);
|
||||
|
||||
// Macros to return or warn on CUDA/NCCL errors. (The same macro works for both
|
||||
// NCCL and CUDA errors.)
|
||||
//
|
||||
// It's tempting to say these macros belong in an XLA header somewhere, but in
|
||||
// practice we don't do much direct-to-CUDA-API stuff outside of this file.
|
||||
#define XLA_CUDA_STATUS(expr) \
|
||||
::xla::gpu::ToStatus(expr, __FILE__, __LINE__, #expr)
|
||||
|
||||
#define XLA_CUDA_RETURN_IF_ERROR(expr) \
|
||||
do { \
|
||||
Status s = XLA_CUDA_STATUS(expr); \
|
||||
if (!s.ok()) { \
|
||||
return s; \
|
||||
} \
|
||||
} while (0)
|
||||
|
||||
#define XLA_CUDA_WARN_IF_ERROR(expr) \
|
||||
do { \
|
||||
Status s = XLA_CUDA_STATUS(expr); \
|
||||
if (!s.ok()) { \
|
||||
LOG(ERROR) << s.ToString(); \
|
||||
} \
|
||||
} while (0)
|
||||
|
||||
// RAII type for NCCL communicators.
|
||||
using NcclComm = std::unique_ptr<ncclComm, void (*)(ncclComm_t)>;
|
||||
|
||||
// Owns a clique of NCCL comms which can be used for collective operations among
|
||||
// a particular set of GPUs.
|
||||
//
|
||||
// Note that if you want to do a collective operation among a subset of these
|
||||
// GPUs, you'll need a different clique.
|
||||
class NcclClique {
|
||||
public:
|
||||
explicit NcclClique(
|
||||
absl::flat_hash_map<int, NcclComm> comms_by_device_ordinal);
|
||||
|
||||
ncclComm_t GetCommForDeviceOrdinal(int device_ordinal) const;
|
||||
absl::Mutex* mu() { return &mu_; }
|
||||
|
||||
private:
|
||||
absl::flat_hash_map<int, NcclComm> comms_by_device_ordinal_;
|
||||
absl::Mutex mu_;
|
||||
};
|
||||
|
||||
struct LocalParticipant {
|
||||
int device_ordinal;
|
||||
int rank;
|
||||
};
|
||||
|
||||
StatusOr<std::vector<LocalParticipant>> GetLocalParticipants(
|
||||
const std::vector<GlobalDeviceId>& participants,
|
||||
const std::vector<GlobalDeviceId>* local_devices); // may be null
|
||||
|
||||
struct LockedNcclClique {
|
||||
std::shared_ptr<NcclClique> clique;
|
||||
// Must come after clique, so it is destroyed first.
|
||||
// This lock prevents other threads from using this clique. All of the threads
|
||||
// involved should hold onto the lock until they have finished with their
|
||||
// communicator.
|
||||
std::shared_ptr<absl::MutexLock> lock;
|
||||
};
|
||||
|
||||
// Acquires a locked NCCL clique for use in NCCL collective operations.
|
||||
StatusOr<LockedNcclClique> AcquireNcclClique(
|
||||
const RendezvousKey& rendezvous_key, int local_device_ordinal,
|
||||
se::Stream* stream, const std::vector<LocalParticipant>& local_participants,
|
||||
const NcclUniqueIdCallback* callback); // may be null
|
||||
|
||||
// Gets the set of devices that have a NCCL channel open. This is primarily
|
||||
// for testing.
|
||||
//
|
||||
// (Indeed, because the NCCL channels are a global variable, in the real
|
||||
// world, the value returned here is stale as soon as you read it, so it's not
|
||||
// clear how you *could* use it for anything other than tests.)
|
||||
absl::flat_hash_set<GlobalDeviceId> DevicesWithOpenNcclChannels();
|
||||
|
||||
} // namespace gpu
|
||||
} // namespace xla
|
||||
|
||||
#endif // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_NCCL_UTILS_H_
|
@@ -17,9 +17,7 @@ limitations under the License.
#include "absl/types/span.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/primitive_util.h"
#ifdef GOOGLE_CUDA
#include "tensorflow/compiler/xla/service/gpu/nccl_utils.h"
#endif
#include "tensorflow/compiler/xla/service/gpu/nccl_all_reduce_thunk.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/test.h"
#include "tensorflow/compiler/xla/test_helpers.h"
@@ -162,11 +160,7 @@ DeviceAssignment MakeDeviceAssn(std::vector<int64> devices) {

// Shorter alias for this function.
absl::flat_hash_set<GlobalDeviceId> OpenNcclChannels() {
#ifdef GOOGLE_CUDA
return gpu::DevicesWithOpenNcclChannels();
#else
return {};
#endif
return gpu::NcclAllReduceThunk::DevicesWithOpenNcclChannels();
}

template <typename T>