[XLA] Fix race condition in RefcountingHashMap that could lead to a use-after-free segfault.

PiperOrigin-RevId: 347661005
Change-Id: I3bf498db9822cb87aacbbd6f0d66e267e6d81671
This commit is contained in:
Chris Jones 2020-12-15 11:51:23 -08:00 committed by TensorFlower Gardener
parent 4961664aa4
commit 2dd9bc663e

View File

@ -73,20 +73,19 @@ class RefcountingHashMap {
value_factory) { value_factory) {
absl::MutexLock lock(&mu_); absl::MutexLock lock(&mu_);
auto it = map_.find(key); auto it = map_.find(key);
// We ensure that the entry has not expired in case deleter was running when
// we have entered this block.
if (it != map_.end()) { if (it != map_.end()) {
// We ensure that the entry has not expired in case deleter was running
// when we have entered this block.
if (std::shared_ptr<V> value = it->second.lock()) { if (std::shared_ptr<V> value = it->second.lock()) {
return value; return value;
} }
map_.erase(it);
} }
// Create entry in the map and then set its value, so the value can // Create entry in the map and then set its value, so the value can
// contain a pointer back into the map. // contain a pointer back into the map.
TF_ASSIGN_OR_RETURN(std::unique_ptr<V> value_unique, value_factory(key)); TF_ASSIGN_OR_RETURN(std::unique_ptr<V> value_unique, value_factory(key));
it = map_.emplace(key, std::weak_ptr<V>()).first; it = map_.emplace(key, std::weak_ptr<V>()).first;
std::shared_ptr<V> value(value_unique.release(), Deleter{&it->first, this}); std::shared_ptr<V> value(value_unique.release(), Deleter{it->first, *this});
it->second = value; // Set the weak ptr to the shared ptr. it->second = value; // Set the weak ptr to the shared ptr.
return value; return value;
} }
@ -108,15 +107,17 @@ class RefcountingHashMap {
private: private:
struct Deleter { struct Deleter {
const K* key; // Points into parent->map_. const K& key; // Points into parent->map_.
RefcountingHashMap* parent; RefcountingHashMap& parent;
void operator()(V* v) { void operator()(V* v) {
delete v; delete v;
absl::MutexLock lock(&parent->mu_); absl::MutexLock lock(&parent.mu_);
auto it = parent->map_.find(*key); // We must check if that the entry is still expired in case the value was
if (it != parent->map_.end() && it->second.expired()) { // replaced while the deleter was running.
parent->map_.erase(it); auto it = parent.map_.find(key);
if (it != parent.map_.end() && it->second.expired()) {
parent.map_.erase(it);
} }
} }
}; };