Store NcclCliques in new NcclCliqueMap.

PiperOrigin-RevId: 354505971
Change-Id: If7f52af22a37e125a587971abbe665f855335e37
This commit is contained in:
Loren Maggiore 2021-01-29 04:08:22 -08:00 committed by TensorFlower Gardener
parent 78250f869b
commit 90ec1352d9
6 changed files with 57 additions and 97 deletions

View File

@ -59,18 +59,6 @@ class RefcountingHashMap {
std::shared_ptr<V> GetOrCreateIfAbsent(
const K& key,
const std::function<std::unique_ptr<V>(const K&)>& value_factory) {
return *GetOrTryCreateIfAbsent(key, [&](const K& key) {
return StatusOr<std::unique_ptr<V>>(value_factory(key));
});
}
// Gets the value for the given key.
//
// If the map doesn't contain a live value for the key, constructs one
// using `value_factory`, or returns the status from `value_factory`.
StatusOr<std::shared_ptr<V>> GetOrTryCreateIfAbsent(
const K& key, const std::function<StatusOr<std::unique_ptr<V>>(const K&)>&
value_factory) {
absl::MutexLock lock(&mu_);
auto it = map_.find(key);
if (it != map_.end()) {
@ -83,28 +71,13 @@ class RefcountingHashMap {
// Create entry in the map and then set its value, so the value can
// contain a pointer back into the map.
TF_ASSIGN_OR_RETURN(std::unique_ptr<V> value_unique, value_factory(key));
it = map_.emplace(key, std::weak_ptr<V>()).first;
std::shared_ptr<V> value(value_unique.release(), Deleter{it->first, *this});
std::shared_ptr<V> value(value_factory(key).release(),
Deleter{it->first, *this});
it->second = value; // Set the weak ptr to the shared ptr.
return value;
}
// Runs a function over every key/value in the map.
//
// Touching the map from within this function may deadlock; don't do it.
//
// Function signature must be compatible with
// void fn(const K&, std::shared_ptr<V>)
//
template <typename Fn>
void ForEach(Fn&& fn) {
absl::MutexLock lock(&mu_);
for (const auto& kv : map_) {
fn(kv.first, kv.second.lock());
}
}
private:
struct Deleter {
const K& key; // Points into parent->map_.

View File

@ -81,43 +81,5 @@ TEST(RefcountingHashMapTest, CustomFactory) {
EXPECT_EQ(*m.GetOrCreateIfAbsent(100, factory), 101);
}
TEST(RefcountingHashMapTest, TrySuccessful) {
RefcountingHashMap<int, int> m;
auto factory = [](const int&) { return absl::make_unique<int>(7); };
StatusOr<std::shared_ptr<int>> result = m.GetOrTryCreateIfAbsent(42, factory);
ASSERT_TRUE(result.ok());
EXPECT_EQ(**result, 7);
}
TEST(RefcountingHashMapTest, TryFailure) {
RefcountingHashMap<int, int> m;
Status status = tensorflow::errors::Internal("Arrggg!");
auto factory = [&](const int&) { return status; };
EXPECT_EQ(m.GetOrTryCreateIfAbsent(42, factory).status(), status);
}
TEST(RefcountingHashMapTest, ForEachEmpty) {
RefcountingHashMap<int, int> m;
int64 count = 0;
m.ForEach([&](const int&, std::shared_ptr<int>) { ++count; });
EXPECT_EQ(count, 0);
}
TEST(RefcountingHashMapTest, ForEachNonempty) {
RefcountingHashMap<int, int> m;
auto factory = [](const int&) { return absl::make_unique<int>(); };
auto a = m.GetOrCreateIfAbsent(0, factory);
auto b = m.GetOrCreateIfAbsent(1, factory);
std::vector<int> seen_keys;
std::vector<int*> seen_values;
m.ForEach([&](const int& k, std::shared_ptr<int> v) {
seen_keys.push_back(k);
seen_values.push_back(v.get());
});
EXPECT_THAT(seen_keys, testing::UnorderedElementsAre(0, 1));
EXPECT_THAT(seen_values, testing::UnorderedElementsAre(a.get(), b.get()));
}
} // anonymous namespace
} // namespace xla

View File

@ -115,16 +115,12 @@ Status NcclCollectiveThunk::ExecuteOnStream(const ExecuteParams& params) {
AcquireNcclClique(rendezvous_key, device_ordinal, params.stream,
local_participants, params.nccl_unique_id_callback));
ncclComm_t comm =
locked_clique.clique->GetCommForDeviceOrdinal(device_ordinal);
locked_clique.clique.GetCommForDeviceOrdinal(device_ordinal);
se::StreamExecutor* executor = params.stream->parent();
se::gpu::ScopedActivateExecutorContext scoped_context(executor);
TF_RETURN_IF_ERROR(RunNcclCollective(params, comm));
// Keep the clique we used alive for as long as this thunk lives.
absl::MutexLock lock(&mu_);
cliques_.insert(std::move(locked_clique.clique));
return Status::OK();
}

View File

@ -93,19 +93,12 @@ class NcclCollectiveThunk : public Thunk {
// error.
static bool NcclIsEnabled();
Status ExecuteOnStream(const ExecuteParams& params) override
ABSL_LOCKS_EXCLUDED(mu_);
Status ExecuteOnStream(const ExecuteParams& params) override;
protected:
virtual Status RunNcclCollective(const ExecuteParams& params,
ncclComm_t comm) = 0;
virtual const NcclCollectiveConfig& config() const = 0;
private:
// Creating NCCL cliques is expensive, so we cache them.
absl::Mutex mu_;
absl::flat_hash_set<std::shared_ptr<NcclClique>> cliques_
ABSL_GUARDED_BY(mu_);
};
} // namespace gpu

View File

@ -105,16 +105,14 @@ ncclComm_t NcclClique::GetCommForDeviceOrdinal(int device_ordinal) const {
return comms_by_device_ordinal_.at(device_ordinal).get();
}
RefcountingHashMap<NcclCliqueKey, NcclClique>& NcclCliqueCache() {
// Global cache of NCCL cliques. An entry in this map is kept alive as long
// as there's a reference to it somewhere. A Thunk holds a reference to each
// Clique it's ever used.
NcclCliqueMap& NcclCliqueCache() {
// Global cache of NCCL cliques. An entry in this map is always kept alive.
//
// A consequence of the fact that this is process-global is that we'll only
// ever have one clique alive for a given set of GPUs. This means that a
// process will never do two collective operations concurrently on the same
// set of GPUs.
static auto& cache = *new RefcountingHashMap<NcclCliqueKey, NcclClique>();
static auto& cache = *new NcclCliqueMap();
return cache;
}
@ -237,14 +235,13 @@ class NcclCliqueRendezvous
});
initialized_ = true;
}
TF_ASSIGN_OR_RETURN(std::shared_ptr<NcclClique> clique, maybe_clique_);
TF_ASSIGN_OR_RETURN(NcclClique * clique, maybe_clique_);
std::unique_ptr<absl::MutexLock> clique_lock;
if (primary) {
clique_lock = std::make_unique<absl::MutexLock>(clique->mu());
counter_ = new absl::BlockingCounter(local_participants_.size());
}
return LockedNcclClique(std::move(clique), std::move(clique_lock),
counter_);
return LockedNcclClique(*clique, std::move(clique_lock), counter_);
}
private:
@ -252,7 +249,7 @@ class NcclCliqueRendezvous
const std::vector<LocalParticipant>& local_participants_;
const NcclUniqueIdCallback* callback_;
StatusOr<std::shared_ptr<NcclClique>> maybe_clique_;
StatusOr<NcclClique*> maybe_clique_;
absl::BlockingCounter* counter_;
};
@ -288,13 +285,13 @@ StatusOr<std::vector<LocalParticipant>> GetLocalParticipants(
return local_participants;
}
LockedNcclClique::LockedNcclClique(std::shared_ptr<NcclClique> clique,
LockedNcclClique::LockedNcclClique(NcclClique& clique,
std::unique_ptr<absl::MutexLock> lock,
absl::BlockingCounter* counter)
: clique(std::move(clique)), lock_(std::move(lock)), counter_(counter) {}
: clique(clique), lock_(std::move(lock)), counter_(counter) {}
LockedNcclClique::LockedNcclClique(LockedNcclClique&& other)
: clique(std::move(other.clique)),
: clique(other.clique),
lock_(std::move(other.lock_)),
counter_(std::exchange(other.counter_, nullptr)) {}
@ -308,6 +305,27 @@ LockedNcclClique::~LockedNcclClique() {
}
}
StatusOr<NcclClique*> NcclCliqueMap::GetOrTryCreateIfAbsent(
const NcclCliqueKey& key,
const std::function<StatusOr<std::unique_ptr<NcclClique>>(
const NcclCliqueKey&)>& value_factory) {
absl::MutexLock lock(&mu_);
auto it = map_.find(key);
if (it == map_.end()) {
TF_ASSIGN_OR_RETURN(std::unique_ptr<NcclClique> value, value_factory(key));
it = map_.emplace(key, std::move(value)).first;
}
return it->second.get();
}
void NcclCliqueMap::ForEach(
const std::function<void(const NcclCliqueKey&, const NcclClique&)>& fn) {
absl::MutexLock lock(&mu_);
for (const auto& kv : map_) {
fn(kv.first, *kv.second);
}
}
StatusOr<LockedNcclClique> AcquireNcclClique(
const RendezvousKey& rendezvous_key, int local_device_ordinal,
se::Stream* stream, const std::vector<LocalParticipant>& local_participants,

View File

@ -100,8 +100,6 @@ class NcclClique {
absl::Mutex mu_;
};
RefcountingHashMap<NcclCliqueKey, NcclClique>& NcclCliqueCache();
struct LocalParticipant {
int device_ordinal;
int rank;
@ -113,13 +111,12 @@ StatusOr<std::vector<LocalParticipant>> GetLocalParticipants(
class LockedNcclClique {
public:
LockedNcclClique(std::shared_ptr<NcclClique> clique,
std::unique_ptr<absl::MutexLock> lock,
LockedNcclClique(NcclClique& clique, std::unique_ptr<absl::MutexLock> lock,
absl::BlockingCounter* counter);
LockedNcclClique(LockedNcclClique&&);
~LockedNcclClique();
std::shared_ptr<NcclClique> clique;
NcclClique& clique;
private:
// Must come after clique, so it is destroyed first.
@ -128,6 +125,27 @@ class LockedNcclClique {
absl::BlockingCounter* counter_;
};
// Threadsafe leaky map from NcclCliqueKeys to NcclCliques.
class NcclCliqueMap {
public:
StatusOr<NcclClique*> GetOrTryCreateIfAbsent(
const NcclCliqueKey& key,
const std::function<StatusOr<std::unique_ptr<NcclClique>>(
const NcclCliqueKey&)>& value_factory) ABSL_LOCKS_EXCLUDED(mu_);
// Runs a function over every key/value in the map.
void ForEach(
const std::function<void(const NcclCliqueKey&, const NcclClique&)>& fn)
ABSL_LOCKS_EXCLUDED(mu_);
private:
absl::Mutex mu_;
absl::flat_hash_map<NcclCliqueKey, std::unique_ptr<NcclClique>> map_
ABSL_GUARDED_BY(mu_);
};
NcclCliqueMap& NcclCliqueCache();
// Acquires a locked NCCL clique for use in NCCL collective operations.
StatusOr<LockedNcclClique> AcquireNcclClique(
const RendezvousKey& rendezvous_key, int local_device_ordinal,