[XLA] Add RefcountingHashMap.
RefcountingHashMap is an eager, threadsafe, refcounted cache. When you look up an element, you get a handle to a new or existing element. When all handles to an existing element are destroyed, the element is freed and removed from the cache. Used in a future patch. PiperOrigin-RevId: 249797851
This commit is contained in:
parent
27ab98d8d5
commit
353e06c3a3
@ -881,6 +881,26 @@ tf_cc_test(
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "refcounting_hash_map",
|
||||
hdrs = ["refcounting_hash_map.h"],
|
||||
deps = [
|
||||
"@com_google_absl//absl/container:node_hash_map",
|
||||
"@com_google_absl//absl/memory",
|
||||
"@com_google_absl//absl/synchronization",
|
||||
],
|
||||
)
|
||||
|
||||
tf_cc_test(
|
||||
name = "refcounting_hash_map_test",
|
||||
srcs = ["refcounting_hash_map_test.cc"],
|
||||
deps = [
|
||||
":refcounting_hash_map",
|
||||
":test",
|
||||
"//tensorflow/core:test_main",
|
||||
],
|
||||
)
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
|
||||
# This is a headers target that extra XLA devices can use to prevent circular dependencies. Devices that are compiled as separate shared objects can also use it to prevent linking of library code.
|
||||
|
115
tensorflow/compiler/xla/refcounting_hash_map.h
Normal file
115
tensorflow/compiler/xla/refcounting_hash_map.h
Normal file
@ -0,0 +1,115 @@
|
||||
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef TENSORFLOW_COMPILER_XLA_REFCOUNTING_HASH_MAP_H_
|
||||
#define TENSORFLOW_COMPILER_XLA_REFCOUNTING_HASH_MAP_H_
|
||||
|
||||
#include <functional>
|
||||
#include <memory>
|
||||
|
||||
#include "absl/container/node_hash_map.h"
|
||||
#include "absl/memory/memory.h"
|
||||
#include "absl/synchronization/mutex.h"
|
||||
|
||||
namespace xla {
|
||||
|
||||
// RefcountingHashMap is an "eager, thread-safe cache".
|
||||
//
|
||||
// Given a key k you can retrieve a shared_ptr to a value v. If k is not
|
||||
// already in the map, we construct a new V; if it is already in the map, we'll
|
||||
// return the existing v. Once all shared_ptrs are destroyed, the entry is
|
||||
// removed from the map.
|
||||
//
|
||||
// This class is thread-safe.
|
||||
//
|
||||
// Word to the wise: You might want an erase() function here that removes a
|
||||
// value from the map but leaves existing shared_ptrs intact. My experience is,
|
||||
// this is extremely complicated to implement correctly.
|
||||
template <typename K, typename V>
|
||||
class RefcountingHashMap {
|
||||
public:
|
||||
// Default-constructs new values.
|
||||
RefcountingHashMap()
|
||||
: value_factory_([](const K&) { return absl::make_unique<V>(); }) {}
|
||||
|
||||
// Constructs new values according to the given factory function.
|
||||
explicit RefcountingHashMap(
|
||||
std::function<std::unique_ptr<V>(const K&)> value_factory)
|
||||
: value_factory_(std::move(value_factory)) {}
|
||||
|
||||
// Not copyable or movable because this contains internal pointers (namely,
|
||||
// instances of Deleter contain pointers to `this` and into `map_`).
|
||||
RefcountingHashMap(const RefcountingHashMap&) = delete;
|
||||
RefcountingHashMap(RefcountingHashMap&&) = delete;
|
||||
RefcountingHashMap& operator=(const RefcountingHashMap&) = delete;
|
||||
RefcountingHashMap& operator=(RefcountingHashMap&&) = delete;
|
||||
|
||||
// Gets the value for the given key.
|
||||
//
|
||||
// If the map doesn't contain a live value for the key, constructs one
|
||||
// according to the factory passed to the map's constructor.
|
||||
std::shared_ptr<V> operator[](const K& key) {
|
||||
absl::MutexLock lock(&mu_);
|
||||
auto it = map_.find(key);
|
||||
if (it == map_.end()) {
|
||||
// Create entry in the map and then set its value, so the value can
|
||||
// contain a pointer back into the map.
|
||||
it = map_.emplace(key, std::weak_ptr<V>()).first;
|
||||
std::shared_ptr<V> value(value_factory_(key).release(),
|
||||
Deleter{&it->first, this});
|
||||
it->second = value; // Set the weak ptr to the shared ptr.
|
||||
return value;
|
||||
}
|
||||
return it->second.lock();
|
||||
}
|
||||
|
||||
// Runs a function over every key/value in the map.
|
||||
//
|
||||
// Touching the map from within this function may deadlock; don't do it.
|
||||
//
|
||||
// Function signature must be compatible with
|
||||
// void fn(const K&, std::shared_ptr<V>)
|
||||
//
|
||||
template <typename Fn>
|
||||
void ForEach(Fn&& fn) {
|
||||
absl::MutexLock lock(&mu_);
|
||||
for (const auto& kv : map_) {
|
||||
fn(kv.first, kv.second.lock());
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
struct Deleter {
|
||||
const K* key; // Points into parent->map_.
|
||||
RefcountingHashMap* parent;
|
||||
|
||||
void operator()(V* v) {
|
||||
delete v;
|
||||
absl::MutexLock lock(&parent->mu_);
|
||||
auto it = parent->map_.find(*key);
|
||||
CHECK(it != parent->map_.end());
|
||||
CHECK(it->second.expired());
|
||||
parent->map_.erase(it);
|
||||
}
|
||||
};
|
||||
|
||||
std::function<std::unique_ptr<V>(const K&)> value_factory_;
|
||||
absl::Mutex mu_;
|
||||
absl::node_hash_map<K, std::weak_ptr<V>> map_ GUARDED_BY(mu_);
|
||||
};
|
||||
|
||||
} // namespace xla
|
||||
|
||||
#endif // TENSORFLOW_COMPILER_XLA_REFCOUNTING_HASH_MAP_H_
|
101
tensorflow/compiler/xla/refcounting_hash_map_test.cc
Normal file
101
tensorflow/compiler/xla/refcounting_hash_map_test.cc
Normal file
@ -0,0 +1,101 @@
|
||||
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/compiler/xla/refcounting_hash_map.h"
|
||||
|
||||
#include <functional>
|
||||
|
||||
#include "tensorflow/compiler/xla/test.h"
|
||||
|
||||
namespace xla {
|
||||
namespace {
|
||||
|
||||
struct DeleteNotifier {
|
||||
DeleteNotifier() = default;
|
||||
DeleteNotifier(const DeleteNotifier&) = delete;
|
||||
DeleteNotifier& operator=(const DeleteNotifier&) = delete;
|
||||
DeleteNotifier(DeleteNotifier&& o) noexcept : fn(std::move(o.fn)) {
|
||||
o.fn = nullptr;
|
||||
}
|
||||
DeleteNotifier& operator=(DeleteNotifier&& o) noexcept {
|
||||
fn = o.fn;
|
||||
o.fn = nullptr;
|
||||
return *this;
|
||||
}
|
||||
|
||||
~DeleteNotifier() {
|
||||
if (fn) {
|
||||
fn();
|
||||
}
|
||||
}
|
||||
|
||||
std::function<void()> fn;
|
||||
};
|
||||
|
||||
TEST(RefcountingHashMapTest, PointerIdentity) {
|
||||
RefcountingHashMap<int, int> m;
|
||||
std::shared_ptr<int> a = m[0];
|
||||
std::shared_ptr<int> b = m[0];
|
||||
std::shared_ptr<int> c = m[1];
|
||||
EXPECT_EQ(a.get(), b.get());
|
||||
EXPECT_NE(a.get(), c.get());
|
||||
}
|
||||
|
||||
TEST(RefcountingHashMapTest, DefaultInitialized) {
|
||||
RefcountingHashMap<int, int> m;
|
||||
EXPECT_EQ(*m[42], 0);
|
||||
}
|
||||
|
||||
TEST(RefcountingHashMapTest, DeletesEagerly) {
|
||||
RefcountingHashMap<int, DeleteNotifier> m;
|
||||
bool deleted = false;
|
||||
auto handle = m[0];
|
||||
handle->fn = [&] { deleted = true; };
|
||||
EXPECT_FALSE(deleted);
|
||||
handle = nullptr;
|
||||
EXPECT_TRUE(deleted);
|
||||
}
|
||||
|
||||
TEST(RefcountingHashMapTest, CustomFactory) {
|
||||
RefcountingHashMap<int, int> m(
|
||||
[](const int& x) { return absl::make_unique<int>(x + 1); });
|
||||
EXPECT_EQ(*m[0], 1);
|
||||
EXPECT_EQ(*m[100], 101);
|
||||
}
|
||||
|
||||
TEST(RefcountingHashMapTest, ForEachEmpty) {
|
||||
RefcountingHashMap<int, int> m;
|
||||
int64 count = 0;
|
||||
m.ForEach([&](const int&, std::shared_ptr<int>) { ++count; });
|
||||
EXPECT_EQ(count, 0);
|
||||
}
|
||||
|
||||
TEST(RefcountingHashMapTest, ForEachNonempty) {
|
||||
RefcountingHashMap<int, int> m;
|
||||
auto a = m[0];
|
||||
auto b = m[1];
|
||||
|
||||
std::vector<int> seen_keys;
|
||||
std::vector<int*> seen_values;
|
||||
m.ForEach([&](const int& k, std::shared_ptr<int> v) {
|
||||
seen_keys.push_back(k);
|
||||
seen_values.push_back(v.get());
|
||||
});
|
||||
EXPECT_THAT(seen_keys, testing::UnorderedElementsAre(0, 1));
|
||||
EXPECT_THAT(seen_values, testing::UnorderedElementsAre(a.get(), b.get()));
|
||||
}
|
||||
|
||||
} // anonymous namespace
|
||||
} // namespace xla
|
Loading…
x
Reference in New Issue
Block a user