[XLA:GPU] Cleanups to collective-permute code.
In particular, simplify how our BlockingCounters are created. PiperOrigin-RevId: 253625335
This commit is contained in:
parent
48bd9b6b56
commit
9f65b55d73
@ -67,7 +67,7 @@ void WaitAndLogIfStuck(tensorflow::BlockingCounter* counter,
|
||||
// because we use that information when constructing the Rendezvous.
|
||||
struct RendezvousKey {
|
||||
RunId run_id;
|
||||
int64 num_participants;
|
||||
int num_participants; // int, not int64, to match BlockingCounter's counter.
|
||||
|
||||
string ToString() const {
|
||||
return absl::StrFormat("RendezvousKey{run_id=%s, num_participants=%d}",
|
||||
@ -108,11 +108,7 @@ struct ParticipantData {
|
||||
// Rendezvous objects can only be used once.
|
||||
class Rendezvous {
|
||||
public:
|
||||
explicit Rendezvous(const RendezvousKey& key)
|
||||
: key_(key),
|
||||
all_arrived_(key.num_participants),
|
||||
returned_blocking_counter_(
|
||||
std::make_shared<BlockingCounter>(key.num_participants)) {}
|
||||
explicit Rendezvous(const RendezvousKey& key) : key_(key) {}
|
||||
|
||||
// Runs the collective permute on the given thread.
|
||||
//
|
||||
@ -125,10 +121,11 @@ class Rendezvous {
|
||||
|
||||
private:
|
||||
const RendezvousKey key_;
|
||||
BlockingCounter all_arrived_;
|
||||
BlockingCounter all_arrived_{key_.num_participants};
|
||||
|
||||
// BlockingCounter returned by SubmitParticipant.
|
||||
std::shared_ptr<BlockingCounter> returned_blocking_counter_;
|
||||
std::shared_ptr<BlockingCounter> returned_blocking_counter_{
|
||||
std::make_shared<BlockingCounter>(key_.num_participants)};
|
||||
|
||||
tensorflow::mutex mu_;
|
||||
bool initialized_ GUARDED_BY(mu_) = false;
|
||||
|
@ -72,6 +72,8 @@ namespace gpu {
|
||||
#if GOOGLE_CUDA
|
||||
namespace {
|
||||
|
||||
using tensorflow::BlockingCounter;
|
||||
|
||||
// Functions to translate an ncclResult_t/cudaError_t to a Status object. Used
|
||||
// by the macros below.
|
||||
Status TranslateStatus(ncclResult_t s, const char* file, int64 line,
|
||||
@ -116,8 +118,7 @@ Status TranslateStatus(cudaError_t s, const char* file, int64 line,
|
||||
} while (0)
|
||||
|
||||
template <typename DescFn>
|
||||
void WaitAndLogIfStuck(tensorflow::BlockingCounter* counter,
|
||||
const DescFn& desc_fn) {
|
||||
void WaitAndLogIfStuck(BlockingCounter* counter, const DescFn& desc_fn) {
|
||||
VLOG(3) << "Begin: " << desc_fn();
|
||||
const std::chrono::milliseconds timeout(5000);
|
||||
bool ok = counter->WaitFor(timeout);
|
||||
@ -202,6 +203,8 @@ struct RendezvousKey {
|
||||
static_cast<int64>(instr->GetModule()->unique_id()));
|
||||
}
|
||||
|
||||
int num_participants() const { return participating_replicas.size(); }
|
||||
|
||||
template <typename H>
|
||||
friend H AbslHashValue(H h, const RendezvousKey& k) {
|
||||
return H::combine(std::move(h), k.run_id, k.participating_replicas,
|
||||
@ -248,9 +251,7 @@ struct ParticipantData {
|
||||
se::DeviceMemoryBase destination_data;
|
||||
se::Stream* stream;
|
||||
|
||||
int64 num_participants() const {
|
||||
return rendezvous_key.participating_replicas.size();
|
||||
}
|
||||
int num_participants() const { return rendezvous_key.num_participants(); }
|
||||
|
||||
string ToString() const {
|
||||
return absl::StrFormat(
|
||||
@ -409,8 +410,7 @@ RefcountingHashMap<NcclCliqueKey, NcclClique>& GlobalNcclCliqueMap() {
|
||||
// Rendezvous objects can only be used once.
|
||||
class Rendezvous {
|
||||
public:
|
||||
explicit Rendezvous(const RendezvousKey& k)
|
||||
: key_(k), all_participants_present_(k.participating_replicas.size()) {}
|
||||
explicit Rendezvous(const RendezvousKey& k) : key_(k) {}
|
||||
|
||||
// Runs the all-reduce on the given thread. If successful, returns
|
||||
// - a handle to the clique that was used, so that the caller may keep the
|
||||
@ -418,25 +418,26 @@ class Rendezvous {
|
||||
// - a BlockingCounter initialized to the number of participants, so that
|
||||
// the caller can coordinate with the participants one last time if it
|
||||
// chooses. This is useful for coordinating destruction of the Rendezvous.
|
||||
StatusOr<std::pair<std::shared_ptr<NcclClique>,
|
||||
std::shared_ptr<tensorflow::BlockingCounter>>>
|
||||
StatusOr<
|
||||
std::pair<std::shared_ptr<NcclClique>, std::shared_ptr<BlockingCounter>>>
|
||||
SubmitParticipant(ParticipantData participant);
|
||||
|
||||
private:
|
||||
Status DoAllReduce(ParticipantData participant, ncclComm_t comm);
|
||||
|
||||
const RendezvousKey key_;
|
||||
tensorflow::BlockingCounter all_participants_present_;
|
||||
|
||||
BlockingCounter all_participants_present_{key_.num_participants()};
|
||||
BlockingCounter done_{key_.num_participants()};
|
||||
// BlockingCounter returned by SubmitParticipant.
|
||||
std::shared_ptr<BlockingCounter> returned_blocking_counter_{
|
||||
std::make_shared<BlockingCounter>(key_.num_participants())};
|
||||
|
||||
tensorflow::mutex mu_;
|
||||
|
||||
bool initialized_ GUARDED_BY(mu_) = false;
|
||||
absl::optional<tensorflow::BlockingCounter> done_;
|
||||
std::vector<ParticipantData> participants_ GUARDED_BY(mu_);
|
||||
|
||||
// BlockingCounter returned by SubmitParticipant. Initialized by the primary
|
||||
// thread.
|
||||
std::shared_ptr<tensorflow::BlockingCounter> returned_blocking_counter_;
|
||||
std::vector<ParticipantData> participants_ GUARDED_BY(mu_);
|
||||
};
|
||||
|
||||
// Global map of Rendezvous objects. A thread participating in a collective op
|
||||
@ -451,8 +452,8 @@ RefcountingHashMap<RendezvousKey, Rendezvous>& GlobalRendezvousMap() {
|
||||
return m;
|
||||
}
|
||||
|
||||
StatusOr<std::pair<std::shared_ptr<NcclClique>,
|
||||
std::shared_ptr<tensorflow::BlockingCounter>>>
|
||||
StatusOr<
|
||||
std::pair<std::shared_ptr<NcclClique>, std::shared_ptr<BlockingCounter>>>
|
||||
Rendezvous::SubmitParticipant(ParticipantData participant) {
|
||||
{
|
||||
tensorflow::mutex_lock lock(mu_);
|
||||
@ -506,10 +507,6 @@ Rendezvous::SubmitParticipant(ParticipantData participant) {
|
||||
if (primary) {
|
||||
VLOG(3) << "Primary initializing accounting data.";
|
||||
initialized_ = true;
|
||||
done_.emplace(participant.num_participants());
|
||||
returned_blocking_counter_ =
|
||||
std::make_shared<tensorflow::BlockingCounter>(
|
||||
participant.num_participants());
|
||||
|
||||
// Acquire exclusive access to the NCCL clique itself so that two
|
||||
// unrelated collective operations won't try to use the clique
|
||||
@ -535,12 +532,12 @@ Rendezvous::SubmitParticipant(ParticipantData participant) {
|
||||
Status all_reduce_status = DoAllReduce(participant, comm);
|
||||
VLOG(3) << "This thread done with all-reduce op.";
|
||||
|
||||
done_->DecrementCount();
|
||||
done_.DecrementCount();
|
||||
|
||||
// The primary owns the lock on the NCCL clique. Hold it until all threads
|
||||
// are done. (We'll release it when we return from this function.)
|
||||
if (primary) {
|
||||
WaitAndLogIfStuck(&*done_, [&] {
|
||||
WaitAndLogIfStuck(&done_, [&] {
|
||||
return absl::StrFormat(
|
||||
"primary participant (device ordinal %d, stream %p) waiting for all "
|
||||
"other participants to complete all-reduce %s",
|
||||
|
Loading…
Reference in New Issue
Block a user