[XLA:GPU] Cleanups to collective-permute code.

In particular, simplify how our BlockingCounters are created.

PiperOrigin-RevId: 253625335
Justin Lebar 2019-06-17 11:40:25 -07:00 committed by TensorFlower Gardener
parent 48bd9b6b56
commit 9f65b55d73
2 changed files with 25 additions and 31 deletions
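The simplification: BlockingCounter members that were sized in the Rendezvous constructor's initializer list become in-class default member initializers that read the already-stored key_. A minimal standalone sketch of the before/after pattern (the Counter, Key, and Rendezvous* names below are illustrative stand-ins, not the actual XLA types):

// Sketch only: a stand-in for tensorflow::BlockingCounter built on the
// standard library, plus "before" and "after" versions of the member
// initialization pattern this commit switches to.
#include <condition_variable>
#include <cstdio>
#include <memory>
#include <mutex>

// Minimal BlockingCounter-like helper (illustrative, not the TF class).
class Counter {
 public:
  explicit Counter(int n) : count_(n) {}
  void DecrementCount() {
    std::lock_guard<std::mutex> lock(mu_);
    if (--count_ == 0) cv_.notify_all();
  }
  void Wait() {
    std::unique_lock<std::mutex> lock(mu_);
    cv_.wait(lock, [this] { return count_ == 0; });
  }

 private:
  std::mutex mu_;
  std::condition_variable cv_;
  int count_;
};

// Stand-in for RendezvousKey: only the participant count matters here.
struct Key {
  int num_participants;  // int, matching Counter's constructor argument
};

// Before: every counter is sized explicitly in the constructor's init list.
class RendezvousBefore {
 public:
  explicit RendezvousBefore(const Key& key)
      : key_(key),
        all_arrived_(key.num_participants),
        returned_counter_(std::make_shared<Counter>(key.num_participants)) {}

 private:
  const Key key_;
  Counter all_arrived_;
  std::shared_ptr<Counter> returned_counter_;
};

// After: the constructor only stores key_; the counters are in-class default
// member initializers that read key_.  Because key_ is declared before the
// counters, it is initialized first, so this is well-defined.
class RendezvousAfter {
 public:
  explicit RendezvousAfter(const Key& key) : key_(key) {}

 private:
  const Key key_;
  Counter all_arrived_{key_.num_participants};
  std::shared_ptr<Counter> returned_counter_{
      std::make_shared<Counter>(key_.num_participants)};
};

int main() {
  RendezvousBefore before(Key{4});
  RendezvousAfter after(Key{4});
  (void)before;
  (void)after;
  std::puts("both variants size their counters from the key");
  return 0;
}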

View File

@@ -67,7 +67,7 @@ void WaitAndLogIfStuck(tensorflow::BlockingCounter* counter,
 // because we use that information when constructing the Rendezvous.
 struct RendezvousKey {
   RunId run_id;
-  int64 num_participants;
+  int num_participants;  // int, not int64, to match BlockingCounter's counter.
 
   string ToString() const {
     return absl::StrFormat("RendezvousKey{run_id=%s, num_participants=%d}",
@@ -108,11 +108,7 @@ struct ParticipantData {
 // Rendezvous objects can only be used once.
 class Rendezvous {
  public:
-  explicit Rendezvous(const RendezvousKey& key)
-      : key_(key),
-        all_arrived_(key.num_participants),
-        returned_blocking_counter_(
-            std::make_shared<BlockingCounter>(key.num_participants)) {}
+  explicit Rendezvous(const RendezvousKey& key) : key_(key) {}
 
   // Runs the collective permute on the given thread.
   //
@@ -125,10 +121,11 @@ class Rendezvous {
  private:
   const RendezvousKey key_;
-  BlockingCounter all_arrived_;
+  BlockingCounter all_arrived_{key_.num_participants};
 
   // BlockingCounter returned by SubmitParticipant.
-  std::shared_ptr<BlockingCounter> returned_blocking_counter_;
+  std::shared_ptr<BlockingCounter> returned_blocking_counter_{
+      std::make_shared<BlockingCounter>(key_.num_participants)};
 
   tensorflow::mutex mu_;
   bool initialized_ GUARDED_BY(mu_) = false;

View File

@@ -72,6 +72,8 @@ namespace gpu {
 #if GOOGLE_CUDA
 namespace {
 
+using tensorflow::BlockingCounter;
+
 // Functions to translate an ncclResult_t/cudaError_t to a Status object. Used
 // by the macros below.
 Status TranslateStatus(ncclResult_t s, const char* file, int64 line,
@@ -116,8 +118,7 @@ Status TranslateStatus(cudaError_t s, const char* file, int64 line,
 } while (0)
 
 template <typename DescFn>
-void WaitAndLogIfStuck(tensorflow::BlockingCounter* counter,
-                       const DescFn& desc_fn) {
+void WaitAndLogIfStuck(BlockingCounter* counter, const DescFn& desc_fn) {
   VLOG(3) << "Begin: " << desc_fn();
   const std::chrono::milliseconds timeout(5000);
   bool ok = counter->WaitFor(timeout);
@@ -202,6 +203,8 @@ struct RendezvousKey {
         static_cast<int64>(instr->GetModule()->unique_id()));
   }
 
+  int num_participants() const { return participating_replicas.size(); }
+
   template <typename H>
   friend H AbslHashValue(H h, const RendezvousKey& k) {
     return H::combine(std::move(h), k.run_id, k.participating_replicas,
@@ -248,9 +251,7 @@ struct ParticipantData {
   se::DeviceMemoryBase destination_data;
   se::Stream* stream;
 
-  int64 num_participants() const {
-    return rendezvous_key.participating_replicas.size();
-  }
+  int num_participants() const { return rendezvous_key.num_participants(); }
 
   string ToString() const {
     return absl::StrFormat(
@@ -409,8 +410,7 @@ RefcountingHashMap<NcclCliqueKey, NcclClique>& GlobalNcclCliqueMap() {
 // Rendezvous objects can only be used once.
 class Rendezvous {
  public:
-  explicit Rendezvous(const RendezvousKey& k)
-      : key_(k), all_participants_present_(k.participating_replicas.size()) {}
+  explicit Rendezvous(const RendezvousKey& k) : key_(k) {}
 
   // Runs the all-reduce on the given thread. If successful, returns
   // - a handle to the clique that was used, so that the caller may keep the
@@ -418,25 +418,26 @@ class Rendezvous {
   // - a BlockingCounter initialized to the number of participants, so that
   //   the caller can coordinate with the participants one last time if it
   //   chooses. This is useful for coordinating destruction of the Rendezvous.
-  StatusOr<std::pair<std::shared_ptr<NcclClique>,
-                     std::shared_ptr<tensorflow::BlockingCounter>>>
+  StatusOr<
+      std::pair<std::shared_ptr<NcclClique>, std::shared_ptr<BlockingCounter>>>
   SubmitParticipant(ParticipantData participant);
 
  private:
   Status DoAllReduce(ParticipantData participant, ncclComm_t comm);
 
   const RendezvousKey key_;
 
-  tensorflow::BlockingCounter all_participants_present_;
+  BlockingCounter all_participants_present_{key_.num_participants()};
+  BlockingCounter done_{key_.num_participants()};
+
+  // BlockingCounter returned by SubmitParticipant.
+  std::shared_ptr<BlockingCounter> returned_blocking_counter_{
+      std::make_shared<BlockingCounter>(key_.num_participants())};
 
   tensorflow::mutex mu_;
   bool initialized_ GUARDED_BY(mu_) = false;
-  absl::optional<tensorflow::BlockingCounter> done_;
-  std::vector<ParticipantData> participants_ GUARDED_BY(mu_);
-
-  // BlockingCounter returned by SubmitParticipant. Initialized by the primary
-  // thread.
-  std::shared_ptr<tensorflow::BlockingCounter> returned_blocking_counter_;
+  std::vector<ParticipantData> participants_ GUARDED_BY(mu_);
 };
 
 // Global map of Rendezvous objects. A thread participating in a collective op
@@ -451,8 +452,8 @@ RefcountingHashMap<RendezvousKey, Rendezvous>& GlobalRendezvousMap() {
   return m;
 }
 
-StatusOr<std::pair<std::shared_ptr<NcclClique>,
-                   std::shared_ptr<tensorflow::BlockingCounter>>>
+StatusOr<
+    std::pair<std::shared_ptr<NcclClique>, std::shared_ptr<BlockingCounter>>>
 Rendezvous::SubmitParticipant(ParticipantData participant) {
   {
     tensorflow::mutex_lock lock(mu_);
@@ -506,10 +507,6 @@ Rendezvous::SubmitParticipant(ParticipantData participant) {
   if (primary) {
     VLOG(3) << "Primary initializing accounting data.";
     initialized_ = true;
-    done_.emplace(participant.num_participants());
-    returned_blocking_counter_ =
-        std::make_shared<tensorflow::BlockingCounter>(
-            participant.num_participants());
 
     // Acquire exclusive access to the NCCL clique itself so that two
     // unrelated collective operations won't try to use the clique
@@ -535,12 +532,12 @@ Rendezvous::SubmitParticipant(ParticipantData participant) {
   Status all_reduce_status = DoAllReduce(participant, comm);
   VLOG(3) << "This thread done with all-reduce op.";
 
-  done_->DecrementCount();
+  done_.DecrementCount();
 
   // The primary owns the lock on the NCCL clique. Hold it until all threads
   // are done. (We'll release it when we return from this function.)
   if (primary) {
-    WaitAndLogIfStuck(&*done_, [&] {
+    WaitAndLogIfStuck(&done_, [&] {
       return absl::StrFormat(
           "primary participant (device ordinal %d, stream %p) waiting for all "
           "other participants to complete all-reduce %s",