[XLA] Add RefcountingHashMap.
RefcountingHashMap is an eager, threadsafe, refcounted cache. When you look up an element, you get a handle to a new or existing element. When all handles to an existing element are destroyed, the element is freed and removed from the cache. Used in a future patch. PiperOrigin-RevId: 249797851
This commit is contained in:
parent
27ab98d8d5
commit
353e06c3a3
@ -881,6 +881,26 @@ tf_cc_test(
|
|||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
cc_library(
|
||||||
|
name = "refcounting_hash_map",
|
||||||
|
hdrs = ["refcounting_hash_map.h"],
|
||||||
|
deps = [
|
||||||
|
"@com_google_absl//absl/container:node_hash_map",
|
||||||
|
"@com_google_absl//absl/memory",
|
||||||
|
"@com_google_absl//absl/synchronization",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
tf_cc_test(
|
||||||
|
name = "refcounting_hash_map_test",
|
||||||
|
srcs = ["refcounting_hash_map_test.cc"],
|
||||||
|
deps = [
|
||||||
|
":refcounting_hash_map",
|
||||||
|
":test",
|
||||||
|
"//tensorflow/core:test_main",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
# -----------------------------------------------------------------------------
|
# -----------------------------------------------------------------------------
|
||||||
|
|
||||||
# This is a headers target that extra XLA devices can use to prevent circular dependencies. Devices that are compiled as separate shared objects can also use it to prevent linking of library code.
|
# This is a headers target that extra XLA devices can use to prevent circular dependencies. Devices that are compiled as separate shared objects can also use it to prevent linking of library code.
|
||||||
|
115
tensorflow/compiler/xla/refcounting_hash_map.h
Normal file
115
tensorflow/compiler/xla/refcounting_hash_map.h
Normal file
@ -0,0 +1,115 @@
|
|||||||
|
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#ifndef TENSORFLOW_COMPILER_XLA_REFCOUNTING_HASH_MAP_H_
|
||||||
|
#define TENSORFLOW_COMPILER_XLA_REFCOUNTING_HASH_MAP_H_
|
||||||
|
|
||||||
|
#include <functional>
|
||||||
|
#include <memory>
|
||||||
|
|
||||||
|
#include "absl/container/node_hash_map.h"
|
||||||
|
#include "absl/memory/memory.h"
|
||||||
|
#include "absl/synchronization/mutex.h"
|
||||||
|
|
||||||
|
namespace xla {
|
||||||
|
|
||||||
|
// RefcountingHashMap is an "eager, thread-safe cache".
|
||||||
|
//
|
||||||
|
// Given a key k you can retrieve a shared_ptr to a value v. If k is not
|
||||||
|
// already in the map, we construct a new V; if it is already in the map, we'll
|
||||||
|
// return the existing v. Once all shared_ptrs are destroyed, the entry is
|
||||||
|
// removed from the map.
|
||||||
|
//
|
||||||
|
// This class is thread-safe.
|
||||||
|
//
|
||||||
|
// Word to the wise: You might want an erase() function here that removes a
|
||||||
|
// value from the map but leaves existing shared_ptrs intact. My experience is,
|
||||||
|
// this is extremely complicated to implement correctly.
|
||||||
|
template <typename K, typename V>
|
||||||
|
class RefcountingHashMap {
|
||||||
|
public:
|
||||||
|
// Default-constructs new values.
|
||||||
|
RefcountingHashMap()
|
||||||
|
: value_factory_([](const K&) { return absl::make_unique<V>(); }) {}
|
||||||
|
|
||||||
|
// Constructs new values according to the given factory function.
|
||||||
|
explicit RefcountingHashMap(
|
||||||
|
std::function<std::unique_ptr<V>(const K&)> value_factory)
|
||||||
|
: value_factory_(std::move(value_factory)) {}
|
||||||
|
|
||||||
|
// Not copyable or movable because this contains internal pointers (namely,
|
||||||
|
// instances of Deleter contain pointers to `this` and into `map_`).
|
||||||
|
RefcountingHashMap(const RefcountingHashMap&) = delete;
|
||||||
|
RefcountingHashMap(RefcountingHashMap&&) = delete;
|
||||||
|
RefcountingHashMap& operator=(const RefcountingHashMap&) = delete;
|
||||||
|
RefcountingHashMap& operator=(RefcountingHashMap&&) = delete;
|
||||||
|
|
||||||
|
// Gets the value for the given key.
|
||||||
|
//
|
||||||
|
// If the map doesn't contain a live value for the key, constructs one
|
||||||
|
// according to the factory passed to the map's constructor.
|
||||||
|
std::shared_ptr<V> operator[](const K& key) {
|
||||||
|
absl::MutexLock lock(&mu_);
|
||||||
|
auto it = map_.find(key);
|
||||||
|
if (it == map_.end()) {
|
||||||
|
// Create entry in the map and then set its value, so the value can
|
||||||
|
// contain a pointer back into the map.
|
||||||
|
it = map_.emplace(key, std::weak_ptr<V>()).first;
|
||||||
|
std::shared_ptr<V> value(value_factory_(key).release(),
|
||||||
|
Deleter{&it->first, this});
|
||||||
|
it->second = value; // Set the weak ptr to the shared ptr.
|
||||||
|
return value;
|
||||||
|
}
|
||||||
|
return it->second.lock();
|
||||||
|
}
|
||||||
|
|
||||||
|
// Runs a function over every key/value in the map.
|
||||||
|
//
|
||||||
|
// Touching the map from within this function may deadlock; don't do it.
|
||||||
|
//
|
||||||
|
// Function signature must be compatible with
|
||||||
|
// void fn(const K&, std::shared_ptr<V>)
|
||||||
|
//
|
||||||
|
template <typename Fn>
|
||||||
|
void ForEach(Fn&& fn) {
|
||||||
|
absl::MutexLock lock(&mu_);
|
||||||
|
for (const auto& kv : map_) {
|
||||||
|
fn(kv.first, kv.second.lock());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
struct Deleter {
|
||||||
|
const K* key; // Points into parent->map_.
|
||||||
|
RefcountingHashMap* parent;
|
||||||
|
|
||||||
|
void operator()(V* v) {
|
||||||
|
delete v;
|
||||||
|
absl::MutexLock lock(&parent->mu_);
|
||||||
|
auto it = parent->map_.find(*key);
|
||||||
|
CHECK(it != parent->map_.end());
|
||||||
|
CHECK(it->second.expired());
|
||||||
|
parent->map_.erase(it);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
std::function<std::unique_ptr<V>(const K&)> value_factory_;
|
||||||
|
absl::Mutex mu_;
|
||||||
|
absl::node_hash_map<K, std::weak_ptr<V>> map_ GUARDED_BY(mu_);
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace xla
|
||||||
|
|
||||||
|
#endif // TENSORFLOW_COMPILER_XLA_REFCOUNTING_HASH_MAP_H_
|
101
tensorflow/compiler/xla/refcounting_hash_map_test.cc
Normal file
101
tensorflow/compiler/xla/refcounting_hash_map_test.cc
Normal file
@ -0,0 +1,101 @@
|
|||||||
|
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#include "tensorflow/compiler/xla/refcounting_hash_map.h"
|
||||||
|
|
||||||
|
#include <functional>
|
||||||
|
|
||||||
|
#include "tensorflow/compiler/xla/test.h"
|
||||||
|
|
||||||
|
namespace xla {
|
||||||
|
namespace {
|
||||||
|
|
||||||
|
struct DeleteNotifier {
|
||||||
|
DeleteNotifier() = default;
|
||||||
|
DeleteNotifier(const DeleteNotifier&) = delete;
|
||||||
|
DeleteNotifier& operator=(const DeleteNotifier&) = delete;
|
||||||
|
DeleteNotifier(DeleteNotifier&& o) noexcept : fn(std::move(o.fn)) {
|
||||||
|
o.fn = nullptr;
|
||||||
|
}
|
||||||
|
DeleteNotifier& operator=(DeleteNotifier&& o) noexcept {
|
||||||
|
fn = o.fn;
|
||||||
|
o.fn = nullptr;
|
||||||
|
return *this;
|
||||||
|
}
|
||||||
|
|
||||||
|
~DeleteNotifier() {
|
||||||
|
if (fn) {
|
||||||
|
fn();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
std::function<void()> fn;
|
||||||
|
};
|
||||||
|
|
||||||
|
TEST(RefcountingHashMapTest, PointerIdentity) {
|
||||||
|
RefcountingHashMap<int, int> m;
|
||||||
|
std::shared_ptr<int> a = m[0];
|
||||||
|
std::shared_ptr<int> b = m[0];
|
||||||
|
std::shared_ptr<int> c = m[1];
|
||||||
|
EXPECT_EQ(a.get(), b.get());
|
||||||
|
EXPECT_NE(a.get(), c.get());
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(RefcountingHashMapTest, DefaultInitialized) {
|
||||||
|
RefcountingHashMap<int, int> m;
|
||||||
|
EXPECT_EQ(*m[42], 0);
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(RefcountingHashMapTest, DeletesEagerly) {
|
||||||
|
RefcountingHashMap<int, DeleteNotifier> m;
|
||||||
|
bool deleted = false;
|
||||||
|
auto handle = m[0];
|
||||||
|
handle->fn = [&] { deleted = true; };
|
||||||
|
EXPECT_FALSE(deleted);
|
||||||
|
handle = nullptr;
|
||||||
|
EXPECT_TRUE(deleted);
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(RefcountingHashMapTest, CustomFactory) {
|
||||||
|
RefcountingHashMap<int, int> m(
|
||||||
|
[](const int& x) { return absl::make_unique<int>(x + 1); });
|
||||||
|
EXPECT_EQ(*m[0], 1);
|
||||||
|
EXPECT_EQ(*m[100], 101);
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(RefcountingHashMapTest, ForEachEmpty) {
|
||||||
|
RefcountingHashMap<int, int> m;
|
||||||
|
int64 count = 0;
|
||||||
|
m.ForEach([&](const int&, std::shared_ptr<int>) { ++count; });
|
||||||
|
EXPECT_EQ(count, 0);
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(RefcountingHashMapTest, ForEachNonempty) {
|
||||||
|
RefcountingHashMap<int, int> m;
|
||||||
|
auto a = m[0];
|
||||||
|
auto b = m[1];
|
||||||
|
|
||||||
|
std::vector<int> seen_keys;
|
||||||
|
std::vector<int*> seen_values;
|
||||||
|
m.ForEach([&](const int& k, std::shared_ptr<int> v) {
|
||||||
|
seen_keys.push_back(k);
|
||||||
|
seen_values.push_back(v.get());
|
||||||
|
});
|
||||||
|
EXPECT_THAT(seen_keys, testing::UnorderedElementsAre(0, 1));
|
||||||
|
EXPECT_THAT(seen_values, testing::UnorderedElementsAre(a.get(), b.get()));
|
||||||
|
}
|
||||||
|
|
||||||
|
} // anonymous namespace
|
||||||
|
} // namespace xla
|
Loading…
x
Reference in New Issue
Block a user