[XLA-GPU] Fix undefined behaviour in NCCL utils.

PiperOrigin-RevId: 346755970
Change-Id: Ibb09087dbc3232b8df78382717cc3313174bfdd5
Author: Chris Jones
Date: 2020-12-10 04:19:23 -08:00
Committed by: TensorFlower Gardener
Parent: 9ca469ee6d
Commit: 60a18fee45
3 changed files with 45 additions and 10 deletions

File: tensorflow/compiler/xla/service/collective_ops_utils.h

@@ -236,7 +236,7 @@ class Rendezvous {
           "rendezvous: %p",
           rendezvous.get());
     });
-    return p.first;
+    return std::move(p.first);
   }

  protected:
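Note: implicit move-on-return only applies to an expression naming a complete local variable; a member access like `p.first` is copied instead, which is ill-formed when the member is move-only. A minimal sketch of the same situation, independent of the XLA types (all names here are illustrative):

```cpp
#include <memory>
#include <utility>

std::pair<std::unique_ptr<int>, bool> MakePair() {
  return {std::make_unique<int>(42), true};
}

std::unique_ptr<int> TakeFirst() {
  auto p = MakePair();
  // return p.first;          // ill-formed: would copy a move-only member
  return std::move(p.first);  // OK: explicitly moves the member out
}
```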

File: tensorflow/compiler/xla/service/gpu/nccl_utils.cc

@@ -16,9 +16,11 @@ limitations under the License.
 #include "tensorflow/compiler/xla/service/gpu/nccl_utils.h"

 #include <memory>
+#include <utility>

 #include "absl/container/flat_hash_map.h"
 #include "absl/strings/str_format.h"
+#include "absl/synchronization/blocking_counter.h"
 #include "absl/synchronization/mutex.h"
 #include "tensorflow/compiler/xla/refcounting_hash_map.h"
 #include "tensorflow/compiler/xla/service/collective_ops_utils.h"
@@ -221,7 +223,8 @@ class NcclCliqueRendezvous
       : Rendezvous(rendezvous_key),
         key_(std::move(rendezvous_key.global_devices)),
         local_participants_(local_participants),
-        callback_(callback) {}
+        callback_(callback),
+        counter_(nullptr) {}

   StatusOr<LockedNcclClique> RunCollectiveOp(
       const NcclCliqueParticipantData&) override {
@@ -235,10 +238,13 @@ class NcclCliqueRendezvous
       initialized_ = true;
     }
     TF_ASSIGN_OR_RETURN(std::shared_ptr<NcclClique> clique, maybe_clique_);
+    std::unique_ptr<absl::MutexLock> clique_lock;
     if (primary) {
-      lock_ = std::make_shared<absl::MutexLock>(clique->mu());
+      clique_lock = std::make_unique<absl::MutexLock>(clique->mu());
+      counter_ = new absl::BlockingCounter(local_participants_.size());
     }
-    return LockedNcclClique{clique, lock_};
+    return LockedNcclClique(std::move(clique), std::move(clique_lock),
+                            counter_);
   }

  private:
@@ -247,7 +253,7 @@ class NcclCliqueRendezvous
   const NcclUniqueIdCallback* callback_;

   StatusOr<std::shared_ptr<NcclClique>> maybe_clique_;
-  std::shared_ptr<absl::MutexLock> lock_;
+  absl::BlockingCounter* counter_;
 };

 }  // namespace
@@ -282,6 +288,26 @@ StatusOr<std::vector<LocalParticipant>> GetLocalParticipants(
   return local_participants;
 }

+LockedNcclClique::LockedNcclClique(std::shared_ptr<NcclClique> clique,
+                                   std::unique_ptr<absl::MutexLock> lock,
+                                   absl::BlockingCounter* counter)
+    : clique(std::move(clique)), lock_(std::move(lock)), counter_(counter) {}
+
+LockedNcclClique::LockedNcclClique(LockedNcclClique&& other)
+    : clique(std::move(other.clique)),
+      lock_(std::move(other.lock_)),
+      counter_(std::exchange(other.counter_, nullptr)) {}
+
+LockedNcclClique::~LockedNcclClique() {
+  if (counter_) {
+    counter_->DecrementCount();
+    if (lock_) {
+      counter_->Wait();  // Don't release lock until all threads are finished.
+      delete counter_;
+    }
+  }
+}
+
 StatusOr<LockedNcclClique> AcquireNcclClique(
     const RendezvousKey& rendezvous_key, int local_device_ordinal,
     se::Stream* stream, const std::vector<LocalParticipant>& local_participants,
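Note: the destructor above encodes the likely undefined behaviour being fixed. Previously the absl::MutexLock was shared among all participating threads through a std::shared_ptr, so the mutex was unlocked by whichever thread happened to drop the last reference, not necessarily the thread that locked it. After this change exactly one thread owns the lock, and the BlockingCounter makes it wait until every participant has finished before unlocking on the locking thread. A standalone sketch of the pattern, using std:: primitives instead of absl:: ones (all names are illustrative):

```cpp
#include <condition_variable>
#include <iostream>
#include <mutex>
#include <thread>
#include <vector>

// Hand-rolled stand-in for absl::BlockingCounter.
class BlockingCounter {
 public:
  explicit BlockingCounter(int n) : count_(n) {}
  void DecrementCount() {
    std::lock_guard<std::mutex> g(mu_);
    if (--count_ == 0) cv_.notify_all();
  }
  void Wait() {
    std::unique_lock<std::mutex> g(mu_);
    cv_.wait(g, [this] { return count_ == 0; });
  }

 private:
  std::mutex mu_;
  std::condition_variable cv_;
  int count_;
};

int main() {
  constexpr int kThreads = 4;
  std::mutex clique_mu;  // stands in for NcclClique::mu()
  BlockingCounter counter(kThreads);

  std::vector<std::thread> threads;
  for (int i = 0; i < kThreads; ++i) {
    threads.emplace_back([&, i] {
      const bool primary = (i == 0);
      std::unique_lock<std::mutex> lock(clique_mu, std::defer_lock);
      if (primary) lock.lock();  // only one thread takes the lock
      // ... each thread uses its communicator here ...
      counter.DecrementCount();  // signal "done with the clique"
      if (primary) {
        counter.Wait();  // don't unlock until every participant is done
        // `lock` unlocks here, on the same thread that locked it.
      }
    });
  }
  for (auto& t : threads) t.join();
  std::cout << "all participants released the clique\n";
}
```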

File: tensorflow/compiler/xla/service/gpu/nccl_utils.h

@@ -20,6 +20,7 @@ limitations under the License.
 #include "absl/container/flat_hash_map.h"
 #include "absl/container/flat_hash_set.h"
+#include "absl/synchronization/blocking_counter.h"
 #include "absl/synchronization/mutex.h"

 #if GOOGLE_CUDA
 #include "third_party/nccl/nccl.h"
@@ -110,13 +111,21 @@ StatusOr<std::vector<LocalParticipant>> GetLocalParticipants(
     const std::vector<GlobalDeviceId>& participants,
     const std::vector<GlobalDeviceId>* local_devices);  // may be null

-struct LockedNcclClique {
+class LockedNcclClique {
+ public:
+  LockedNcclClique(std::shared_ptr<NcclClique> clique,
+                   std::unique_ptr<absl::MutexLock> lock,
+                   absl::BlockingCounter* counter);
+  LockedNcclClique(LockedNcclClique&&);
+  ~LockedNcclClique();
+
   std::shared_ptr<NcclClique> clique;

+ private:
   // Must come after clique, so it is destroyed first.
-  // This lock prevents other threads from using this clique. All of the threads
-  // involved should hold onto the lock until they have finished with their
-  // communicator.
-  std::shared_ptr<absl::MutexLock> lock;
+  // One thread holds a lock (it is null in the others).
+  std::unique_ptr<absl::MutexLock> lock_;
+  absl::BlockingCounter* counter_;
 };

 // Acquires a locked NCCL clique for use in NCCL collective operations.
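Note: making LockedNcclClique move-only (deleted copy, explicit move constructor) is what lets a raw counter pointer be owned safely. The move constructor in the .cc hunk uses std::exchange so that only one instance ever owns the pointer and the destructor logic runs exactly once. A minimal sketch of that idiom (Guard is an illustrative name, not part of XLA):

```cpp
#include <utility>

class Guard {
 public:
  explicit Guard(int* p) : p_(p) {}
  // std::exchange takes other's pointer and nulls it out in one expression,
  // so the cleanup in ~Guard() runs exactly once across all moves.
  Guard(Guard&& other) noexcept : p_(std::exchange(other.p_, nullptr)) {}
  ~Guard() { delete p_; }  // safe: moved-from Guards hold nullptr

 private:
  int* p_;
};
```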