[XLA:GPU] Cleanups to collective-permute code.

In particular, simplify how our BlockingCounters are created.

PiperOrigin-RevId: 253625335
Author: Justin Lebar, 2019-06-17 11:40:25 -07:00 (committed by TensorFlower Gardener)
Parent: 48bd9b6b56
Commit: 9f65b55d73
2 changed files with 25 additions and 31 deletions
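The change both files make is the same one: BlockingCounter members that used to be wired up in each constructor's initializer list (and, in the all-reduce file, partly filled in at runtime by the primary thread) become default member initializers that read key_. Below is a minimal, self-contained before/after sketch of that pattern; the bare-bones BlockingCounter is a stand-in for tensorflow::BlockingCounter so the example compiles on its own, and the class names are illustrative.

// nsdmi_before_after.cc -- illustrative sketch only, not the XLA source.
#include <memory>

struct BlockingCounter {  // Stand-in; the real class lives in TensorFlow.
  explicit BlockingCounter(int n) : count(n) {}
  int count;
};

struct RendezvousKey {
  int num_participants;
};

// Before: every counter spelled out in the constructor's initializer list.
class RendezvousBefore {
 public:
  explicit RendezvousBefore(const RendezvousKey& key)
      : key_(key),
        all_arrived_(key.num_participants),
        returned_blocking_counter_(
            std::make_shared<BlockingCounter>(key.num_participants)) {}

 private:
  const RendezvousKey key_;
  BlockingCounter all_arrived_;
  std::shared_ptr<BlockingCounter> returned_blocking_counter_;
};

// After: default member initializers reference key_, and the constructor
// collapses to initializing key_ alone.
class RendezvousAfter {
 public:
  explicit RendezvousAfter(const RendezvousKey& key) : key_(key) {}

 private:
  const RendezvousKey key_;
  BlockingCounter all_arrived_{key_.num_participants};
  std::shared_ptr<BlockingCounter> returned_blocking_counter_{
      std::make_shared<BlockingCounter>(key_.num_participants)};
};

int main() {
  RendezvousAfter r(RendezvousKey{4});
  (void)r;
}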

File 1 of 2

@@ -67,7 +67,7 @@ void WaitAndLogIfStuck(tensorflow::BlockingCounter* counter,
 // because we use that information when constructing the Rendezvous.
 struct RendezvousKey {
   RunId run_id;
-  int64 num_participants;
+  int num_participants;  // int, not int64, to match BlockingCounter's counter.
 
   string ToString() const {
     return absl::StrFormat("RendezvousKey{run_id=%s, num_participants=%d}",
@@ -108,11 +108,7 @@ struct ParticipantData {
 // Rendezvous objects can only be used once.
 class Rendezvous {
  public:
-  explicit Rendezvous(const RendezvousKey& key)
-      : key_(key),
-        all_arrived_(key.num_participants),
-        returned_blocking_counter_(
-            std::make_shared<BlockingCounter>(key.num_participants)) {}
+  explicit Rendezvous(const RendezvousKey& key) : key_(key) {}
 
   // Runs the collective permute on the given thread.
   //
@@ -125,10 +121,11 @@ class Rendezvous {
  private:
   const RendezvousKey key_;
 
-  BlockingCounter all_arrived_;
+  BlockingCounter all_arrived_{key_.num_participants};
 
   // BlockingCounter returned by SubmitParticipant.
-  std::shared_ptr<BlockingCounter> returned_blocking_counter_;
+  std::shared_ptr<BlockingCounter> returned_blocking_counter_{
+      std::make_shared<BlockingCounter>(key_.num_participants)};
 
   tensorflow::mutex mu_;
   bool initialized_ GUARDED_BY(mu_) = false;
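Background on the counters themselves (context, not part of the commit): a BlockingCounter is a one-shot barrier, where each participant decrements it on arrival and blocks until the count hits zero. A runnable sketch using a toy counter with the same DecrementCount/Wait surface as the TensorFlow class:

// barrier_sketch.cc -- toy BlockingCounter; illustrative only.
#include <condition_variable>
#include <cstdio>
#include <mutex>
#include <thread>
#include <vector>

class BlockingCounter {
 public:
  explicit BlockingCounter(int n) : count_(n) {}
  void DecrementCount() {
    std::lock_guard<std::mutex> l(mu_);
    if (--count_ == 0) cv_.notify_all();
  }
  void Wait() {
    std::unique_lock<std::mutex> l(mu_);
    cv_.wait(l, [this] { return count_ == 0; });
  }

 private:
  std::mutex mu_;
  std::condition_variable cv_;
  int count_;
};

int main() {
  constexpr int kParticipants = 4;
  BlockingCounter all_arrived(kParticipants);  // Plays the role of all_arrived_.
  std::vector<std::thread> threads;
  for (int i = 0; i < kParticipants; ++i) {
    threads.emplace_back([&all_arrived, i] {
      all_arrived.DecrementCount();  // Announce arrival...
      all_arrived.Wait();            // ...then block until everyone has arrived.
      std::printf("participant %d released\n", i);
    });
  }
  for (auto& t : threads) t.join();
}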

File 2 of 2

@@ -72,6 +72,8 @@ namespace gpu {
 #if GOOGLE_CUDA
 namespace {
 
+using tensorflow::BlockingCounter;
+
 // Functions to translate an ncclResult_t/cudaError_t to a Status object. Used
 // by the macros below.
 Status TranslateStatus(ncclResult_t s, const char* file, int64 line,
@@ -116,8 +118,7 @@ Status TranslateStatus(cudaError_t s, const char* file, int64 line,
   } while (0)
 
 template <typename DescFn>
-void WaitAndLogIfStuck(tensorflow::BlockingCounter* counter,
-                       const DescFn& desc_fn) {
+void WaitAndLogIfStuck(BlockingCounter* counter, const DescFn& desc_fn) {
   VLOG(3) << "Begin: " << desc_fn();
   const std::chrono::milliseconds timeout(5000);
   bool ok = counter->WaitFor(timeout);
@@ -202,6 +203,8 @@ struct RendezvousKey {
         static_cast<int64>(instr->GetModule()->unique_id()));
   }
 
+  int num_participants() const { return participating_replicas.size(); }
+
   template <typename H>
   friend H AbslHashValue(H h, const RendezvousKey& k) {
     return H::combine(std::move(h), k.run_id, k.participating_replicas,
@@ -248,9 +251,7 @@ struct ParticipantData {
   se::DeviceMemoryBase destination_data;
   se::Stream* stream;
 
-  int64 num_participants() const {
-    return rendezvous_key.participating_replicas.size();
-  }
+  int num_participants() const { return rendezvous_key.num_participants(); }
 
   string ToString() const {
     return absl::StrFormat(
@@ -409,8 +410,7 @@ RefcountingHashMap<NcclCliqueKey, NcclClique>& GlobalNcclCliqueMap() {
 // Rendezvous objects can only be used once.
 class Rendezvous {
  public:
-  explicit Rendezvous(const RendezvousKey& k)
-      : key_(k), all_participants_present_(k.participating_replicas.size()) {}
+  explicit Rendezvous(const RendezvousKey& k) : key_(k) {}
 
   // Runs the all-reduce on the given thread. If successful, returns
   // - a handle to the clique that was used, so that the caller may keep the
@@ -418,25 +418,26 @@ class Rendezvous {
   // - a BlockingCounter initialized to the number of participants, so that
   //   the caller can coordinate with the participants one last time if it
   //   chooses. This is useful for coordinating destruction of the Rendezvous.
-  StatusOr<std::pair<std::shared_ptr<NcclClique>,
-                     std::shared_ptr<tensorflow::BlockingCounter>>>
+  StatusOr<
+      std::pair<std::shared_ptr<NcclClique>, std::shared_ptr<BlockingCounter>>>
   SubmitParticipant(ParticipantData participant);
 
  private:
   Status DoAllReduce(ParticipantData participant, ncclComm_t comm);
 
   const RendezvousKey key_;
 
-  tensorflow::BlockingCounter all_participants_present_;
+  BlockingCounter all_participants_present_{key_.num_participants()};
+  BlockingCounter done_{key_.num_participants()};
+
+  // BlockingCounter returned by SubmitParticipant.
+  std::shared_ptr<BlockingCounter> returned_blocking_counter_{
+      std::make_shared<BlockingCounter>(key_.num_participants())};
 
   tensorflow::mutex mu_;
   bool initialized_ GUARDED_BY(mu_) = false;
-  absl::optional<tensorflow::BlockingCounter> done_;
   std::vector<ParticipantData> participants_ GUARDED_BY(mu_);
-
-  // BlockingCounter returned by SubmitParticipant. Initialized by the primary
-  // thread.
-  std::shared_ptr<tensorflow::BlockingCounter> returned_blocking_counter_;
 };
 
 // Global map of Rendezvous objects. A thread participating in a collective op
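Two notes on the new member layout (general C++ facts, not specific to this diff): default member initializers run in declaration order, so the {key_.num_participants()} initializers are safe only because key_ is declared above the counters. And because done_ is now constructed eagerly from key_, it no longer needs to be an absl::optional emplaced by the primary thread, which is what lets the initialization block in the -506,10 hunk below be deleted outright. A tiny sketch of the ordering rule, with illustrative names:

// nsdmi_order.cc -- declaration-order rule the initializers above rely on.
struct Key { int n; };

struct Ok {
  Key key_{3};
  int counter_{key_.n};  // Fine: members initialize in declaration order,
                         // so key_ is ready when counter_ is initialized.
};

// Reversing the declarations would read key_.n before key_ is initialized:
// struct Bad {
//   int counter_{key_.n};  // Reads an uninitialized member.
//   Key key_{3};
// };

int main() {
  Ok ok;
  return ok.counter_ - 3;  // Returns 0.
}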
@@ -451,8 +452,8 @@ RefcountingHashMap<RendezvousKey, Rendezvous>& GlobalRendezvousMap() {
   return m;
 }
 
-StatusOr<std::pair<std::shared_ptr<NcclClique>,
-                   std::shared_ptr<tensorflow::BlockingCounter>>>
+StatusOr<
+    std::pair<std::shared_ptr<NcclClique>, std::shared_ptr<BlockingCounter>>>
 Rendezvous::SubmitParticipant(ParticipantData participant) {
   {
     tensorflow::mutex_lock lock(mu_);
@@ -506,10 +507,6 @@ Rendezvous::SubmitParticipant(ParticipantData participant) {
   if (primary) {
     VLOG(3) << "Primary initializing accounting data.";
     initialized_ = true;
-    done_.emplace(participant.num_participants());
-    returned_blocking_counter_ =
-        std::make_shared<tensorflow::BlockingCounter>(
-            participant.num_participants());
 
     // Acquire exclusive access to the NCCL clique itself so that two
     // unrelated collective operations won't try to use the clique
@@ -535,12 +532,12 @@ Rendezvous::SubmitParticipant(ParticipantData participant) {
   Status all_reduce_status = DoAllReduce(participant, comm);
   VLOG(3) << "This thread done with all-reduce op.";
 
-  done_->DecrementCount();
+  done_.DecrementCount();
 
   // The primary owns the lock on the NCCL clique. Hold it until all threads
   // are done. (We'll release it when we return from this function.)
   if (primary) {
-    WaitAndLogIfStuck(&*done_, [&] {
+    WaitAndLogIfStuck(&done_, [&] {
       return absl::StrFormat(
           "primary participant (device ordinal %d, stream %p) waiting for all "
           "other participants to complete all-reduce %s",