[XLA-GPU] Fix undefined behaviour in NCCL utils.
PiperOrigin-RevId: 346755970 Change-Id: Ibb09087dbc3232b8df78382717cc3313174bfdd5
This commit is contained in:
parent
9ca469ee6d
commit
60a18fee45
@ -236,7 +236,7 @@ class Rendezvous {
|
||||
"rendezvous: %p",
|
||||
rendezvous.get());
|
||||
});
|
||||
return p.first;
|
||||
return std::move(p.first);
|
||||
}
|
||||
|
||||
protected:
|
||||
|
@ -16,9 +16,11 @@ limitations under the License.
|
||||
#include "tensorflow/compiler/xla/service/gpu/nccl_utils.h"
|
||||
|
||||
#include <memory>
|
||||
#include <utility>
|
||||
|
||||
#include "absl/container/flat_hash_map.h"
|
||||
#include "absl/strings/str_format.h"
|
||||
#include "absl/synchronization/blocking_counter.h"
|
||||
#include "absl/synchronization/mutex.h"
|
||||
#include "tensorflow/compiler/xla/refcounting_hash_map.h"
|
||||
#include "tensorflow/compiler/xla/service/collective_ops_utils.h"
|
||||
@ -221,7 +223,8 @@ class NcclCliqueRendezvous
|
||||
: Rendezvous(rendezvous_key),
|
||||
key_(std::move(rendezvous_key.global_devices)),
|
||||
local_participants_(local_participants),
|
||||
callback_(callback) {}
|
||||
callback_(callback),
|
||||
counter_(nullptr) {}
|
||||
|
||||
StatusOr<LockedNcclClique> RunCollectiveOp(
|
||||
const NcclCliqueParticipantData&) override {
|
||||
@ -235,10 +238,13 @@ class NcclCliqueRendezvous
|
||||
initialized_ = true;
|
||||
}
|
||||
TF_ASSIGN_OR_RETURN(std::shared_ptr<NcclClique> clique, maybe_clique_);
|
||||
std::unique_ptr<absl::MutexLock> clique_lock;
|
||||
if (primary) {
|
||||
lock_ = std::make_shared<absl::MutexLock>(clique->mu());
|
||||
clique_lock = std::make_unique<absl::MutexLock>(clique->mu());
|
||||
counter_ = new absl::BlockingCounter(local_participants_.size());
|
||||
}
|
||||
return LockedNcclClique{clique, lock_};
|
||||
return LockedNcclClique(std::move(clique), std::move(clique_lock),
|
||||
counter_);
|
||||
}
|
||||
|
||||
private:
|
||||
@ -247,7 +253,7 @@ class NcclCliqueRendezvous
|
||||
const NcclUniqueIdCallback* callback_;
|
||||
|
||||
StatusOr<std::shared_ptr<NcclClique>> maybe_clique_;
|
||||
std::shared_ptr<absl::MutexLock> lock_;
|
||||
absl::BlockingCounter* counter_;
|
||||
};
|
||||
|
||||
} // namespace
|
||||
@ -282,6 +288,26 @@ StatusOr<std::vector<LocalParticipant>> GetLocalParticipants(
|
||||
return local_participants;
|
||||
}
|
||||
|
||||
LockedNcclClique::LockedNcclClique(std::shared_ptr<NcclClique> clique,
|
||||
std::unique_ptr<absl::MutexLock> lock,
|
||||
absl::BlockingCounter* counter)
|
||||
: clique(std::move(clique)), lock_(std::move(lock)), counter_(counter) {}
|
||||
|
||||
LockedNcclClique::LockedNcclClique(LockedNcclClique&& other)
|
||||
: clique(std::move(other.clique)),
|
||||
lock_(std::move(other.lock_)),
|
||||
counter_(std::exchange(other.counter_, nullptr)) {}
|
||||
|
||||
LockedNcclClique::~LockedNcclClique() {
|
||||
if (counter_) {
|
||||
counter_->DecrementCount();
|
||||
if (lock_) {
|
||||
counter_->Wait(); // Don't release lock until all threads are finished.
|
||||
delete counter_;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
StatusOr<LockedNcclClique> AcquireNcclClique(
|
||||
const RendezvousKey& rendezvous_key, int local_device_ordinal,
|
||||
se::Stream* stream, const std::vector<LocalParticipant>& local_participants,
|
||||
|
@ -20,6 +20,7 @@ limitations under the License.
|
||||
|
||||
#include "absl/container/flat_hash_map.h"
|
||||
#include "absl/container/flat_hash_set.h"
|
||||
#include "absl/synchronization/blocking_counter.h"
|
||||
#include "absl/synchronization/mutex.h"
|
||||
#if GOOGLE_CUDA
|
||||
#include "third_party/nccl/nccl.h"
|
||||
@ -110,13 +111,21 @@ StatusOr<std::vector<LocalParticipant>> GetLocalParticipants(
|
||||
const std::vector<GlobalDeviceId>& participants,
|
||||
const std::vector<GlobalDeviceId>* local_devices); // may be null
|
||||
|
||||
struct LockedNcclClique {
|
||||
class LockedNcclClique {
|
||||
public:
|
||||
LockedNcclClique(std::shared_ptr<NcclClique> clique,
|
||||
std::unique_ptr<absl::MutexLock> lock,
|
||||
absl::BlockingCounter* counter);
|
||||
LockedNcclClique(LockedNcclClique&&);
|
||||
~LockedNcclClique();
|
||||
|
||||
std::shared_ptr<NcclClique> clique;
|
||||
|
||||
private:
|
||||
// Must come after clique, so it is destroyed first.
|
||||
// This lock prevents other threads from using this clique. All of the threads
|
||||
// involved should hold onto the lock until they have finished with their
|
||||
// communicator.
|
||||
std::shared_ptr<absl::MutexLock> lock;
|
||||
// One thread holds a lock (it is null in the others).
|
||||
std::unique_ptr<absl::MutexLock> lock_;
|
||||
absl::BlockingCounter* counter_;
|
||||
};
|
||||
|
||||
// Acquires a locked NCCL clique for use in NCCL collective operations.
|
||||
|
Loading…
x
Reference in New Issue
Block a user