[XLA] Fix race condition in RefcountingHashMap that could lead to a use-after-free segfault.

PiperOrigin-RevId: 347661005
Change-Id: I3bf498db9822cb87aacbbd6f0d66e267e6d81671
This commit is contained in:
Chris Jones 2020-12-15 11:51:23 -08:00 committed by TensorFlower Gardener
parent 4961664aa4
commit 2dd9bc663e

View File

@ -73,20 +73,19 @@ class RefcountingHashMap {
value_factory) {
absl::MutexLock lock(&mu_);
auto it = map_.find(key);
// We ensure that the entry has not expired in case deleter was running when
// we have entered this block.
if (it != map_.end()) {
// We ensure that the entry has not expired in case deleter was running
// when we have entered this block.
if (std::shared_ptr<V> value = it->second.lock()) {
return value;
}
map_.erase(it);
}
// Create entry in the map and then set its value, so the value can
// contain a pointer back into the map.
TF_ASSIGN_OR_RETURN(std::unique_ptr<V> value_unique, value_factory(key));
it = map_.emplace(key, std::weak_ptr<V>()).first;
std::shared_ptr<V> value(value_unique.release(), Deleter{&it->first, this});
std::shared_ptr<V> value(value_unique.release(), Deleter{it->first, *this});
it->second = value; // Set the weak ptr to the shared ptr.
return value;
}
@ -108,15 +107,17 @@ class RefcountingHashMap {
private:
struct Deleter {
const K* key; // Points into parent->map_.
RefcountingHashMap* parent;
const K& key; // Points into parent->map_.
RefcountingHashMap& parent;
void operator()(V* v) {
delete v;
absl::MutexLock lock(&parent->mu_);
auto it = parent->map_.find(*key);
if (it != parent->map_.end() && it->second.expired()) {
parent->map_.erase(it);
absl::MutexLock lock(&parent.mu_);
// We must check if that the entry is still expired in case the value was
// replaced while the deleter was running.
auto it = parent.map_.find(key);
if (it != parent.map_.end() && it->second.expired()) {
parent.map_.erase(it);
}
}
};