Added RefCountingHashMap::GetOrTryCreateIfAbsent
to allow for factory methods that may fail.
PiperOrigin-RevId: 344412270 Change-Id: I7b384c571a04e00405c17650e45b417aeb337fb0
This commit is contained in:
parent
31f0b21597
commit
d5b2156e4e
tensorflow/compiler/xla
@ -940,6 +940,7 @@ cc_library(
|
||||
name = "refcounting_hash_map",
|
||||
hdrs = ["refcounting_hash_map.h"],
|
||||
deps = [
|
||||
":statusor",
|
||||
"@com_google_absl//absl/base:core_headers",
|
||||
"@com_google_absl//absl/container:node_hash_map",
|
||||
"@com_google_absl//absl/memory",
|
||||
@ -955,6 +956,7 @@ tf_cc_test(
|
||||
":test",
|
||||
":types",
|
||||
"//tensorflow/core:test_main",
|
||||
"//tensorflow/core/platform:errors",
|
||||
],
|
||||
)
|
||||
|
||||
|
@ -23,6 +23,7 @@ limitations under the License.
|
||||
#include "absl/container/node_hash_map.h"
|
||||
#include "absl/memory/memory.h"
|
||||
#include "absl/synchronization/mutex.h"
|
||||
#include "tensorflow/compiler/xla/statusor.h"
|
||||
|
||||
namespace xla {
|
||||
|
||||
@ -58,6 +59,18 @@ class RefcountingHashMap {
|
||||
std::shared_ptr<V> GetOrCreateIfAbsent(
|
||||
const K& key,
|
||||
const std::function<std::unique_ptr<V>(const K&)>& value_factory) {
|
||||
return *GetOrTryCreateIfAbsent(key, [&](const K& key) {
|
||||
return StatusOr<std::unique_ptr<V>>(value_factory(key));
|
||||
});
|
||||
}
|
||||
|
||||
// Gets the value for the given key.
|
||||
//
|
||||
// If the map doesn't contain a live value for the key, constructs one
|
||||
// using `value_factory`, or returns the status from `value_factory`.
|
||||
StatusOr<std::shared_ptr<V>> GetOrTryCreateIfAbsent(
|
||||
const K& key, const std::function<StatusOr<std::unique_ptr<V>>(const K&)>&
|
||||
value_factory) {
|
||||
absl::MutexLock lock(&mu_);
|
||||
auto it = map_.find(key);
|
||||
// We ensure that the entry has not expired in case deleter was running when
|
||||
@ -71,9 +84,9 @@ class RefcountingHashMap {
|
||||
|
||||
// Create entry in the map and then set its value, so the value can
|
||||
// contain a pointer back into the map.
|
||||
TF_ASSIGN_OR_RETURN(std::unique_ptr<V> value_unique, value_factory(key));
|
||||
it = map_.emplace(key, std::weak_ptr<V>()).first;
|
||||
std::shared_ptr<V> value(value_factory(key).release(),
|
||||
Deleter{&it->first, this});
|
||||
std::shared_ptr<V> value(value_unique.release(), Deleter{&it->first, this});
|
||||
it->second = value; // Set the weak ptr to the shared ptr.
|
||||
return value;
|
||||
}
|
||||
|
@ -16,9 +16,11 @@ limitations under the License.
|
||||
#include "tensorflow/compiler/xla/refcounting_hash_map.h"
|
||||
|
||||
#include <functional>
|
||||
#include <memory>
|
||||
|
||||
#include "tensorflow/compiler/xla/test.h"
|
||||
#include "tensorflow/compiler/xla/types.h"
|
||||
#include "tensorflow/core/platform/errors.h"
|
||||
|
||||
namespace xla {
|
||||
namespace {
|
||||
@ -79,6 +81,21 @@ TEST(RefcountingHashMapTest, CustomFactory) {
|
||||
EXPECT_EQ(*m.GetOrCreateIfAbsent(100, factory), 101);
|
||||
}
|
||||
|
||||
TEST(RefcountingHashMapTest, TrySuccessful) {
|
||||
RefcountingHashMap<int, int> m;
|
||||
auto factory = [](const int&) { return absl::make_unique<int>(7); };
|
||||
StatusOr<std::shared_ptr<int>> result = m.GetOrTryCreateIfAbsent(42, factory);
|
||||
ASSERT_TRUE(result.ok());
|
||||
EXPECT_EQ(**result, 7);
|
||||
}
|
||||
|
||||
TEST(RefcountingHashMapTest, TryFailure) {
|
||||
RefcountingHashMap<int, int> m;
|
||||
Status status = tensorflow::errors::Internal("Arrggg!");
|
||||
auto factory = [&](const int&) { return status; };
|
||||
EXPECT_EQ(m.GetOrTryCreateIfAbsent(42, factory).status(), status);
|
||||
}
|
||||
|
||||
TEST(RefcountingHashMapTest, ForEachEmpty) {
|
||||
RefcountingHashMap<int, int> m;
|
||||
int64 count = 0;
|
||||
|
Loading…
Reference in New Issue
Block a user