Store NcclCliques in new NcclCliqueMap.
PiperOrigin-RevId: 354505971 Change-Id: If7f52af22a37e125a587971abbe665f855335e37
This commit is contained in:
parent
78250f869b
commit
90ec1352d9
tensorflow/compiler/xla
@ -59,18 +59,6 @@ class RefcountingHashMap {
|
||||
std::shared_ptr<V> GetOrCreateIfAbsent(
|
||||
const K& key,
|
||||
const std::function<std::unique_ptr<V>(const K&)>& value_factory) {
|
||||
return *GetOrTryCreateIfAbsent(key, [&](const K& key) {
|
||||
return StatusOr<std::unique_ptr<V>>(value_factory(key));
|
||||
});
|
||||
}
|
||||
|
||||
// Gets the value for the given key.
|
||||
//
|
||||
// If the map doesn't contain a live value for the key, constructs one
|
||||
// using `value_factory`, or returns the status from `value_factory`.
|
||||
StatusOr<std::shared_ptr<V>> GetOrTryCreateIfAbsent(
|
||||
const K& key, const std::function<StatusOr<std::unique_ptr<V>>(const K&)>&
|
||||
value_factory) {
|
||||
absl::MutexLock lock(&mu_);
|
||||
auto it = map_.find(key);
|
||||
if (it != map_.end()) {
|
||||
@ -83,28 +71,13 @@ class RefcountingHashMap {
|
||||
|
||||
// Create entry in the map and then set its value, so the value can
|
||||
// contain a pointer back into the map.
|
||||
TF_ASSIGN_OR_RETURN(std::unique_ptr<V> value_unique, value_factory(key));
|
||||
it = map_.emplace(key, std::weak_ptr<V>()).first;
|
||||
std::shared_ptr<V> value(value_unique.release(), Deleter{it->first, *this});
|
||||
std::shared_ptr<V> value(value_factory(key).release(),
|
||||
Deleter{it->first, *this});
|
||||
it->second = value; // Set the weak ptr to the shared ptr.
|
||||
return value;
|
||||
}
|
||||
|
||||
// Runs a function over every key/value in the map.
|
||||
//
|
||||
// Touching the map from within this function may deadlock; don't do it.
|
||||
//
|
||||
// Function signature must be compatible with
|
||||
// void fn(const K&, std::shared_ptr<V>)
|
||||
//
|
||||
template <typename Fn>
|
||||
void ForEach(Fn&& fn) {
|
||||
absl::MutexLock lock(&mu_);
|
||||
for (const auto& kv : map_) {
|
||||
fn(kv.first, kv.second.lock());
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
struct Deleter {
|
||||
const K& key; // Points into parent->map_.
|
||||
|
@ -81,43 +81,5 @@ TEST(RefcountingHashMapTest, CustomFactory) {
|
||||
EXPECT_EQ(*m.GetOrCreateIfAbsent(100, factory), 101);
|
||||
}
|
||||
|
||||
// A factory that succeeds must produce an OK result holding the value the
// factory constructed.
TEST(RefcountingHashMapTest, TrySuccessful) {
  RefcountingHashMap<int, int> cache;
  auto make_seven = [](const int& /*key*/) {
    return absl::make_unique<int>(7);
  };

  StatusOr<std::shared_ptr<int>> lookup =
      cache.GetOrTryCreateIfAbsent(42, make_seven);

  ASSERT_TRUE(lookup.ok());
  EXPECT_EQ(**lookup, 7);
}
|
||||
|
||||
// A factory that returns an error status must have that exact status
// surfaced by GetOrTryCreateIfAbsent.
TEST(RefcountingHashMapTest, TryFailure) {
  RefcountingHashMap<int, int> cache;
  Status expected_error = tensorflow::errors::Internal("Arrggg!");
  auto failing_factory = [&](const int& /*key*/) { return expected_error; };

  EXPECT_EQ(cache.GetOrTryCreateIfAbsent(42, failing_factory).status(),
            expected_error);
}
|
||||
|
||||
// ForEach on a map with no entries must never invoke the callback.
TEST(RefcountingHashMapTest, ForEachEmpty) {
  RefcountingHashMap<int, int> cache;
  int64 visits = 0;
  auto count_visit = [&](const int& /*key*/, std::shared_ptr<int> /*value*/) {
    ++visits;
  };

  cache.ForEach(count_visit);

  EXPECT_EQ(visits, 0);
}
|
||||
|
||||
// ForEach must visit every entry exactly once, passing the same key and the
// same underlying object that GetOrCreateIfAbsent handed out.
TEST(RefcountingHashMapTest, ForEachNonempty) {
  RefcountingHashMap<int, int> cache;
  auto make_default_int = [](const int& /*key*/) {
    return absl::make_unique<int>();
  };
  // The returned shared_ptrs stay in scope for the whole test so both
  // values are alive while ForEach runs.
  auto first = cache.GetOrCreateIfAbsent(0, make_default_int);
  auto second = cache.GetOrCreateIfAbsent(1, make_default_int);

  std::vector<int> visited_keys;
  std::vector<int*> visited_values;
  cache.ForEach([&](const int& key, std::shared_ptr<int> value) {
    visited_keys.push_back(key);
    visited_values.push_back(value.get());
  });

  // Iteration order is unspecified, hence the unordered matchers.
  EXPECT_THAT(visited_keys, testing::UnorderedElementsAre(0, 1));
  EXPECT_THAT(visited_values,
              testing::UnorderedElementsAre(first.get(), second.get()));
}
|
||||
|
||||
} // anonymous namespace
|
||||
} // namespace xla
|
||||
|
@ -115,16 +115,12 @@ Status NcclCollectiveThunk::ExecuteOnStream(const ExecuteParams& params) {
|
||||
AcquireNcclClique(rendezvous_key, device_ordinal, params.stream,
|
||||
local_participants, params.nccl_unique_id_callback));
|
||||
ncclComm_t comm =
|
||||
locked_clique.clique->GetCommForDeviceOrdinal(device_ordinal);
|
||||
locked_clique.clique.GetCommForDeviceOrdinal(device_ordinal);
|
||||
|
||||
se::StreamExecutor* executor = params.stream->parent();
|
||||
se::gpu::ScopedActivateExecutorContext scoped_context(executor);
|
||||
|
||||
TF_RETURN_IF_ERROR(RunNcclCollective(params, comm));
|
||||
|
||||
// Keep the clique we used alive for as long as this thunk lives.
|
||||
absl::MutexLock lock(&mu_);
|
||||
cliques_.insert(std::move(locked_clique.clique));
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
|
@ -93,19 +93,12 @@ class NcclCollectiveThunk : public Thunk {
|
||||
// error.
|
||||
static bool NcclIsEnabled();
|
||||
|
||||
Status ExecuteOnStream(const ExecuteParams& params) override
|
||||
ABSL_LOCKS_EXCLUDED(mu_);
|
||||
Status ExecuteOnStream(const ExecuteParams& params) override;
|
||||
|
||||
protected:
|
||||
virtual Status RunNcclCollective(const ExecuteParams& params,
|
||||
ncclComm_t comm) = 0;
|
||||
virtual const NcclCollectiveConfig& config() const = 0;
|
||||
|
||||
private:
|
||||
// Creating NCCL cliques is expensive, so we cache them.
|
||||
absl::Mutex mu_;
|
||||
absl::flat_hash_set<std::shared_ptr<NcclClique>> cliques_
|
||||
ABSL_GUARDED_BY(mu_);
|
||||
};
|
||||
|
||||
} // namespace gpu
|
||||
|
@ -105,16 +105,14 @@ ncclComm_t NcclClique::GetCommForDeviceOrdinal(int device_ordinal) const {
|
||||
return comms_by_device_ordinal_.at(device_ordinal).get();
|
||||
}
|
||||
|
||||
RefcountingHashMap<NcclCliqueKey, NcclClique>& NcclCliqueCache() {
|
||||
// Global cache of NCCL cliques. An entry in this map is kept alive as long
|
||||
// as there's a reference to it somewhere. A Thunk holds a reference to each
|
||||
// Clique it's ever used.
|
||||
NcclCliqueMap& NcclCliqueCache() {
|
||||
// Global cache of NCCL cliques. An entry in this map is always kept alive.
|
||||
//
|
||||
// A consequence of the fact that this is process-global is that we'll only
|
||||
// ever have one clique alive for a given set of GPUs. This means that a
|
||||
// process will never do two collective operations concurrently on the same
|
||||
// set of GPUs.
|
||||
static auto& cache = *new RefcountingHashMap<NcclCliqueKey, NcclClique>();
|
||||
static auto& cache = *new NcclCliqueMap();
|
||||
return cache;
|
||||
}
|
||||
|
||||
@ -237,14 +235,13 @@ class NcclCliqueRendezvous
|
||||
});
|
||||
initialized_ = true;
|
||||
}
|
||||
TF_ASSIGN_OR_RETURN(std::shared_ptr<NcclClique> clique, maybe_clique_);
|
||||
TF_ASSIGN_OR_RETURN(NcclClique * clique, maybe_clique_);
|
||||
std::unique_ptr<absl::MutexLock> clique_lock;
|
||||
if (primary) {
|
||||
clique_lock = std::make_unique<absl::MutexLock>(clique->mu());
|
||||
counter_ = new absl::BlockingCounter(local_participants_.size());
|
||||
}
|
||||
return LockedNcclClique(std::move(clique), std::move(clique_lock),
|
||||
counter_);
|
||||
return LockedNcclClique(*clique, std::move(clique_lock), counter_);
|
||||
}
|
||||
|
||||
private:
|
||||
@ -252,7 +249,7 @@ class NcclCliqueRendezvous
|
||||
const std::vector<LocalParticipant>& local_participants_;
|
||||
const NcclUniqueIdCallback* callback_;
|
||||
|
||||
StatusOr<std::shared_ptr<NcclClique>> maybe_clique_;
|
||||
StatusOr<NcclClique*> maybe_clique_;
|
||||
absl::BlockingCounter* counter_;
|
||||
};
|
||||
|
||||
@ -288,13 +285,13 @@ StatusOr<std::vector<LocalParticipant>> GetLocalParticipants(
|
||||
return local_participants;
|
||||
}
|
||||
|
||||
LockedNcclClique::LockedNcclClique(std::shared_ptr<NcclClique> clique,
|
||||
LockedNcclClique::LockedNcclClique(NcclClique& clique,
|
||||
std::unique_ptr<absl::MutexLock> lock,
|
||||
absl::BlockingCounter* counter)
|
||||
: clique(std::move(clique)), lock_(std::move(lock)), counter_(counter) {}
|
||||
: clique(clique), lock_(std::move(lock)), counter_(counter) {}
|
||||
|
||||
LockedNcclClique::LockedNcclClique(LockedNcclClique&& other)
|
||||
: clique(std::move(other.clique)),
|
||||
: clique(other.clique),
|
||||
lock_(std::move(other.lock_)),
|
||||
counter_(std::exchange(other.counter_, nullptr)) {}
|
||||
|
||||
@ -308,6 +305,27 @@ LockedNcclClique::~LockedNcclClique() {
|
||||
}
|
||||
}
|
||||
|
||||
// Returns a pointer to the clique stored under `key`, creating it with
// `value_factory` if the map has no entry for that key yet.
//
// The pointer is non-owning: the map keeps the clique in a std::unique_ptr.
// If `value_factory` returns an error, that error is returned and nothing is
// inserted. `mu_` is held for the whole call, including while
// `value_factory` runs.
StatusOr<NcclClique*> NcclCliqueMap::GetOrTryCreateIfAbsent(
    const NcclCliqueKey& key,
    const std::function<StatusOr<std::unique_ptr<NcclClique>>(
        const NcclCliqueKey&)>& value_factory) {
  absl::MutexLock lock(&mu_);

  // Fast path: the clique already exists.
  auto existing = map_.find(key);
  if (existing != map_.end()) {
    return existing->second.get();
  }

  // Slow path: build the clique, then hand ownership to the map.
  TF_ASSIGN_OR_RETURN(std::unique_ptr<NcclClique> created, value_factory(key));
  auto inserted = map_.emplace(key, std::move(created));
  return inserted.first->second.get();
}
|
||||
|
||||
void NcclCliqueMap::ForEach(
|
||||
const std::function<void(const NcclCliqueKey&, const NcclClique&)>& fn) {
|
||||
absl::MutexLock lock(&mu_);
|
||||
for (const auto& kv : map_) {
|
||||
fn(kv.first, *kv.second);
|
||||
}
|
||||
}
|
||||
|
||||
StatusOr<LockedNcclClique> AcquireNcclClique(
|
||||
const RendezvousKey& rendezvous_key, int local_device_ordinal,
|
||||
se::Stream* stream, const std::vector<LocalParticipant>& local_participants,
|
||||
|
@ -100,8 +100,6 @@ class NcclClique {
|
||||
absl::Mutex mu_;
|
||||
};
|
||||
|
||||
RefcountingHashMap<NcclCliqueKey, NcclClique>& NcclCliqueCache();
|
||||
|
||||
struct LocalParticipant {
|
||||
int device_ordinal;
|
||||
int rank;
|
||||
@ -113,13 +111,12 @@ StatusOr<std::vector<LocalParticipant>> GetLocalParticipants(
|
||||
|
||||
class LockedNcclClique {
|
||||
public:
|
||||
LockedNcclClique(std::shared_ptr<NcclClique> clique,
|
||||
std::unique_ptr<absl::MutexLock> lock,
|
||||
LockedNcclClique(NcclClique& clique, std::unique_ptr<absl::MutexLock> lock,
|
||||
absl::BlockingCounter* counter);
|
||||
LockedNcclClique(LockedNcclClique&&);
|
||||
~LockedNcclClique();
|
||||
|
||||
std::shared_ptr<NcclClique> clique;
|
||||
NcclClique& clique;
|
||||
|
||||
private:
|
||||
// Must come after clique, so it is destroyed first.
|
||||
@ -128,6 +125,27 @@ class LockedNcclClique {
|
||||
absl::BlockingCounter* counter_;
|
||||
};
|
||||
|
||||
// Threadsafe leaky map from NcclCliqueKeys to NcclCliques.
|
||||
class NcclCliqueMap {
|
||||
public:
|
||||
StatusOr<NcclClique*> GetOrTryCreateIfAbsent(
|
||||
const NcclCliqueKey& key,
|
||||
const std::function<StatusOr<std::unique_ptr<NcclClique>>(
|
||||
const NcclCliqueKey&)>& value_factory) ABSL_LOCKS_EXCLUDED(mu_);
|
||||
|
||||
// Runs a function over every key/value in the map.
|
||||
void ForEach(
|
||||
const std::function<void(const NcclCliqueKey&, const NcclClique&)>& fn)
|
||||
ABSL_LOCKS_EXCLUDED(mu_);
|
||||
|
||||
private:
|
||||
absl::Mutex mu_;
|
||||
absl::flat_hash_map<NcclCliqueKey, std::unique_ptr<NcclClique>> map_
|
||||
ABSL_GUARDED_BY(mu_);
|
||||
};
|
||||
|
||||
NcclCliqueMap& NcclCliqueCache();
|
||||
|
||||
// Acquires a locked NCCL clique for use in NCCL collective operations.
|
||||
StatusOr<LockedNcclClique> AcquireNcclClique(
|
||||
const RendezvousKey& rendezvous_key, int local_device_ordinal,
|
||||
|
Loading…
Reference in New Issue
Block a user