[XLA:GPU] Cleanups to collective-permute code.
In particular, simplify how our BlockingCounters are created.

PiperOrigin-RevId: 253625335
parent
48bd9b6b56
commit
9f65b55d73
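
The gist of the cleanup: the BlockingCounter members used to be wired up in each Rendezvous constructor's initializer list; now they are in-class default member initializers that read the already-declared key member. A minimal standalone sketch of the pattern (stand-in `Counter` type, not the actual XLA classes):

    #include <memory>

    // Stand-in for tensorflow::BlockingCounter: constructed with the number
    // of DecrementCount() calls it waits for.
    struct Counter {
      explicit Counter(int n) : remaining(n) {}
      int remaining;
    };

    // Before: every counter is initialized in the constructor's init list.
    class Before {
     public:
      explicit Before(int num_participants)
          : num_participants_(num_participants),
            all_arrived_(num_participants),
            returned_(std::make_shared<Counter>(num_participants)) {}

     private:
      int num_participants_;
      Counter all_arrived_;
      std::shared_ptr<Counter> returned_;
    };

    // After: in-class initializers read num_participants_. Non-static data
    // members are initialized in declaration order, so num_participants_ is
    // already set when all_arrived_'s initializer runs.
    class After {
     public:
      explicit After(int num_participants)
          : num_participants_(num_participants) {}

     private:
      int num_participants_;
      Counter all_arrived_{num_participants_};
      std::shared_ptr<Counter> returned_{
          std::make_shared<Counter>(num_participants_)};
    };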
@@ -67,7 +67,7 @@ void WaitAndLogIfStuck(tensorflow::BlockingCounter* counter,
 // because we use that information when constructing the Rendezvous.
 struct RendezvousKey {
   RunId run_id;
-  int64 num_participants;
+  int num_participants;  // int, not int64, to match BlockingCounter's counter.
 
   string ToString() const {
     return absl::StrFormat("RendezvousKey{run_id=%s, num_participants=%d}",
@@ -108,11 +108,7 @@ struct ParticipantData {
 // Rendezvous objects can only be used once.
 class Rendezvous {
  public:
-  explicit Rendezvous(const RendezvousKey& key)
-      : key_(key),
-        all_arrived_(key.num_participants),
-        returned_blocking_counter_(
-            std::make_shared<BlockingCounter>(key.num_participants)) {}
+  explicit Rendezvous(const RendezvousKey& key) : key_(key) {}
 
   // Runs the collective permute on the given thread.
   //
@@ -125,10 +121,11 @@ class Rendezvous {
 
  private:
   const RendezvousKey key_;
-  BlockingCounter all_arrived_;
+  BlockingCounter all_arrived_{key_.num_participants};
 
   // BlockingCounter returned by SubmitParticipant.
-  std::shared_ptr<BlockingCounter> returned_blocking_counter_;
+  std::shared_ptr<BlockingCounter> returned_blocking_counter_{
+      std::make_shared<BlockingCounter>(key_.num_participants)};
 
   tensorflow::mutex mu_;
   bool initialized_ GUARDED_BY(mu_) = false;
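
For reference, the slice of the tensorflow::BlockingCounter API these hunks rely on: the counter is constructed with the participant count, each participant calls DecrementCount(), and a waiter either blocks in Wait() or polls WaitFor(timeout), which returns false on timeout. A minimal usage sketch (assumed header path; only the members exercised above are used):

    #include <chrono>
    #include <thread>

    #include "tensorflow/core/lib/core/blocking_counter.h"

    void Example() {
      // Expects exactly two DecrementCount() calls.
      tensorflow::BlockingCounter counter(2);

      std::thread t1([&] { counter.DecrementCount(); });
      std::thread t2([&] { counter.DecrementCount(); });

      // WaitFor returns false if the count did not reach zero within the
      // timeout; WaitAndLogIfStuck uses this to warn about stuck participants.
      if (!counter.WaitFor(std::chrono::milliseconds(5000))) {
        // The real code logs an error here and keeps waiting.
      }

      t1.join();
      t2.join();
    }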
@@ -72,6 +72,8 @@ namespace gpu {
 #if GOOGLE_CUDA
 namespace {
 
+using tensorflow::BlockingCounter;
+
 // Functions to translate an ncclResult_t/cudaError_t to a Status object. Used
 // by the macros below.
 Status TranslateStatus(ncclResult_t s, const char* file, int64 line,
@@ -116,8 +118,7 @@ Status TranslateStatus(cudaError_t s, const char* file, int64 line,
   } while (0)
 
 template <typename DescFn>
-void WaitAndLogIfStuck(tensorflow::BlockingCounter* counter,
-                       const DescFn& desc_fn) {
+void WaitAndLogIfStuck(BlockingCounter* counter, const DescFn& desc_fn) {
   VLOG(3) << "Begin: " << desc_fn();
   const std::chrono::milliseconds timeout(5000);
   bool ok = counter->WaitFor(timeout);
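
A side note on this signature: desc_fn stays a template parameter rather than a pre-built string, so call sites can defer the message formatting to the logging paths. A stubbed, self-contained sketch of that pattern (the real code uses absl::StrFormat and VLOG; plain std::string stands in here so the sketch compiles on its own):

    #include <iostream>
    #include <string>

    // Only the slow path pays for building the description.
    template <typename DescFn>
    void WaitAndLogIfStuckSketch(bool timed_out, const DescFn& desc_fn) {
      if (timed_out) {
        std::cerr << "This thread has been waiting for a while: " << desc_fn()
                  << "\n";
      }
    }

    int main() {
      int device_ordinal = 0;  // hypothetical value, for illustration only
      WaitAndLogIfStuckSketch(/*timed_out=*/true, [&] {
        return "participant (device ordinal " +
               std::to_string(device_ordinal) + ") waiting for peers";
      });
    }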
@@ -202,6 +203,8 @@ struct RendezvousKey {
                      static_cast<int64>(instr->GetModule()->unique_id()));
   }
 
+  int num_participants() const { return participating_replicas.size(); }
+
   template <typename H>
   friend H AbslHashValue(H h, const RendezvousKey& k) {
     return H::combine(std::move(h), k.run_id, k.participating_replicas,
@@ -248,9 +251,7 @@ struct ParticipantData {
   se::DeviceMemoryBase destination_data;
   se::Stream* stream;
 
-  int64 num_participants() const {
-    return rendezvous_key.participating_replicas.size();
-  }
+  int num_participants() const { return rendezvous_key.num_participants(); }
 
   string ToString() const {
     return absl::StrFormat(
@@ -409,8 +410,7 @@ RefcountingHashMap<NcclCliqueKey, NcclClique>& GlobalNcclCliqueMap() {
 // Rendezvous objects can only be used once.
 class Rendezvous {
  public:
-  explicit Rendezvous(const RendezvousKey& k)
-      : key_(k), all_participants_present_(k.participating_replicas.size()) {}
+  explicit Rendezvous(const RendezvousKey& k) : key_(k) {}
 
   // Runs the all-reduce on the given thread. If successful, returns
   // - a handle to the clique that was used, so that the caller may keep the
@@ -418,25 +418,26 @@ class Rendezvous {
   // - a BlockingCounter initialized to the number of participants, so that
   //   the caller can coordinate with the participants one last time if it
   //   chooses. This is useful for coordinating destruction of the Rendezvous.
-  StatusOr<std::pair<std::shared_ptr<NcclClique>,
-                     std::shared_ptr<tensorflow::BlockingCounter>>>
+  StatusOr<
+      std::pair<std::shared_ptr<NcclClique>, std::shared_ptr<BlockingCounter>>>
   SubmitParticipant(ParticipantData participant);
 
  private:
   Status DoAllReduce(ParticipantData participant, ncclComm_t comm);
 
   const RendezvousKey key_;
-  tensorflow::BlockingCounter all_participants_present_;
+  BlockingCounter all_participants_present_{key_.num_participants()};
+  BlockingCounter done_{key_.num_participants()};
+
+  // BlockingCounter returned by SubmitParticipant.
+  std::shared_ptr<BlockingCounter> returned_blocking_counter_{
+      std::make_shared<BlockingCounter>(key_.num_participants())};
 
   tensorflow::mutex mu_;
 
   bool initialized_ GUARDED_BY(mu_) = false;
-  absl::optional<tensorflow::BlockingCounter> done_;
-  std::vector<ParticipantData> participants_ GUARDED_BY(mu_);
 
-  // BlockingCounter returned by SubmitParticipant. Initialized by the primary
-  // thread.
-  std::shared_ptr<tensorflow::BlockingCounter> returned_blocking_counter_;
+  std::vector<ParticipantData> participants_ GUARDED_BY(mu_);
 };
 
 // Global map of Rendezvous objects. A thread participating in a collective op
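
Two things fall out of the new in-class initializers here. First, done_ no longer needs to be absl::optional: BlockingCounter has no default constructor (it needs its count up front), so before this change the primary thread had to emplace it after the rendezvous was underway; now RendezvousKey::num_participants() supplies the count at construction time, and the lazy initialization in SubmitParticipant (the -506,10 hunk below) can be deleted. Second, returned_blocking_counter_ is valid from construction, not only after the primary thread runs. A self-contained sketch of the teardown contract the comment above describes (assumed header path; not the actual thunk code):

    #include <memory>
    #include <thread>
    #include <vector>

    #include "tensorflow/core/lib/core/blocking_counter.h"

    // Every participant shares one counter; each decrements it when it is
    // done with the shared object, and the thread responsible for destruction
    // waits for the count to hit zero first.
    void TeardownSketch(int num_participants) {
      auto done_with_rendezvous =
          std::make_shared<tensorflow::BlockingCounter>(num_participants);

      std::vector<std::thread> participants;
      for (int i = 0; i < num_participants; ++i) {
        participants.emplace_back([done_with_rendezvous] {
          // ... use the shared Rendezvous here ...
          done_with_rendezvous->DecrementCount();
        });
      }

      // Safe to destroy the shared object only after this returns.
      done_with_rendezvous->Wait();
      for (std::thread& t : participants) t.join();
    }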
@@ -451,8 +452,8 @@ RefcountingHashMap<RendezvousKey, Rendezvous>& GlobalRendezvousMap() {
   return m;
 }
 
-StatusOr<std::pair<std::shared_ptr<NcclClique>,
-                   std::shared_ptr<tensorflow::BlockingCounter>>>
+StatusOr<
+    std::pair<std::shared_ptr<NcclClique>, std::shared_ptr<BlockingCounter>>>
 Rendezvous::SubmitParticipant(ParticipantData participant) {
   {
     tensorflow::mutex_lock lock(mu_);
@@ -506,10 +507,6 @@ Rendezvous::SubmitParticipant(ParticipantData participant) {
   if (primary) {
     VLOG(3) << "Primary initializing accounting data.";
     initialized_ = true;
-    done_.emplace(participant.num_participants());
-    returned_blocking_counter_ =
-        std::make_shared<tensorflow::BlockingCounter>(
-            participant.num_participants());
 
     // Acquire exclusive access to the NCCL clique itself so that two
     // unrelated collective operations won't try to use the clique
@@ -535,12 +532,12 @@ Rendezvous::SubmitParticipant(ParticipantData participant) {
   Status all_reduce_status = DoAllReduce(participant, comm);
   VLOG(3) << "This thread done with all-reduce op.";
 
-  done_->DecrementCount();
+  done_.DecrementCount();
 
   // The primary owns the lock on the NCCL clique. Hold it until all threads
   // are done. (We'll release it when we return from this function.)
   if (primary) {
-    WaitAndLogIfStuck(&*done_, [&] {
+    WaitAndLogIfStuck(&done_, [&] {
       return absl::StrFormat(
           "primary participant (device ordinal %d, stream %p) waiting for all "
           "other participants to complete all-reduce %s",