Added RefCountingHashMap::GetOrTryCreateIfAbsent
to allow for factory methods that may fail.
PiperOrigin-RevId: 344412270 Change-Id: I7b384c571a04e00405c17650e45b417aeb337fb0
This commit is contained in:
parent
31f0b21597
commit
d5b2156e4e
@ -940,6 +940,7 @@ cc_library(
|
|||||||
name = "refcounting_hash_map",
|
name = "refcounting_hash_map",
|
||||||
hdrs = ["refcounting_hash_map.h"],
|
hdrs = ["refcounting_hash_map.h"],
|
||||||
deps = [
|
deps = [
|
||||||
|
":statusor",
|
||||||
"@com_google_absl//absl/base:core_headers",
|
"@com_google_absl//absl/base:core_headers",
|
||||||
"@com_google_absl//absl/container:node_hash_map",
|
"@com_google_absl//absl/container:node_hash_map",
|
||||||
"@com_google_absl//absl/memory",
|
"@com_google_absl//absl/memory",
|
||||||
@ -955,6 +956,7 @@ tf_cc_test(
|
|||||||
":test",
|
":test",
|
||||||
":types",
|
":types",
|
||||||
"//tensorflow/core:test_main",
|
"//tensorflow/core:test_main",
|
||||||
|
"//tensorflow/core/platform:errors",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -23,6 +23,7 @@ limitations under the License.
|
|||||||
#include "absl/container/node_hash_map.h"
|
#include "absl/container/node_hash_map.h"
|
||||||
#include "absl/memory/memory.h"
|
#include "absl/memory/memory.h"
|
||||||
#include "absl/synchronization/mutex.h"
|
#include "absl/synchronization/mutex.h"
|
||||||
|
#include "tensorflow/compiler/xla/statusor.h"
|
||||||
|
|
||||||
namespace xla {
|
namespace xla {
|
||||||
|
|
||||||
@ -58,6 +59,18 @@ class RefcountingHashMap {
|
|||||||
std::shared_ptr<V> GetOrCreateIfAbsent(
|
std::shared_ptr<V> GetOrCreateIfAbsent(
|
||||||
const K& key,
|
const K& key,
|
||||||
const std::function<std::unique_ptr<V>(const K&)>& value_factory) {
|
const std::function<std::unique_ptr<V>(const K&)>& value_factory) {
|
||||||
|
return *GetOrTryCreateIfAbsent(key, [&](const K& key) {
|
||||||
|
return StatusOr<std::unique_ptr<V>>(value_factory(key));
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
// Gets the value for the given key.
|
||||||
|
//
|
||||||
|
// If the map doesn't contain a live value for the key, constructs one
|
||||||
|
// using `value_factory`, or returns the status from `value_factory`.
|
||||||
|
StatusOr<std::shared_ptr<V>> GetOrTryCreateIfAbsent(
|
||||||
|
const K& key, const std::function<StatusOr<std::unique_ptr<V>>(const K&)>&
|
||||||
|
value_factory) {
|
||||||
absl::MutexLock lock(&mu_);
|
absl::MutexLock lock(&mu_);
|
||||||
auto it = map_.find(key);
|
auto it = map_.find(key);
|
||||||
// We ensure that the entry has not expired in case deleter was running when
|
// We ensure that the entry has not expired in case deleter was running when
|
||||||
@ -71,9 +84,9 @@ class RefcountingHashMap {
|
|||||||
|
|
||||||
// Create entry in the map and then set its value, so the value can
|
// Create entry in the map and then set its value, so the value can
|
||||||
// contain a pointer back into the map.
|
// contain a pointer back into the map.
|
||||||
|
TF_ASSIGN_OR_RETURN(std::unique_ptr<V> value_unique, value_factory(key));
|
||||||
it = map_.emplace(key, std::weak_ptr<V>()).first;
|
it = map_.emplace(key, std::weak_ptr<V>()).first;
|
||||||
std::shared_ptr<V> value(value_factory(key).release(),
|
std::shared_ptr<V> value(value_unique.release(), Deleter{&it->first, this});
|
||||||
Deleter{&it->first, this});
|
|
||||||
it->second = value; // Set the weak ptr to the shared ptr.
|
it->second = value; // Set the weak ptr to the shared ptr.
|
||||||
return value;
|
return value;
|
||||||
}
|
}
|
||||||
|
@ -16,9 +16,11 @@ limitations under the License.
|
|||||||
#include "tensorflow/compiler/xla/refcounting_hash_map.h"
|
#include "tensorflow/compiler/xla/refcounting_hash_map.h"
|
||||||
|
|
||||||
#include <functional>
|
#include <functional>
|
||||||
|
#include <memory>
|
||||||
|
|
||||||
#include "tensorflow/compiler/xla/test.h"
|
#include "tensorflow/compiler/xla/test.h"
|
||||||
#include "tensorflow/compiler/xla/types.h"
|
#include "tensorflow/compiler/xla/types.h"
|
||||||
|
#include "tensorflow/core/platform/errors.h"
|
||||||
|
|
||||||
namespace xla {
|
namespace xla {
|
||||||
namespace {
|
namespace {
|
||||||
@ -79,6 +81,21 @@ TEST(RefcountingHashMapTest, CustomFactory) {
|
|||||||
EXPECT_EQ(*m.GetOrCreateIfAbsent(100, factory), 101);
|
EXPECT_EQ(*m.GetOrCreateIfAbsent(100, factory), 101);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST(RefcountingHashMapTest, TrySuccessful) {
|
||||||
|
RefcountingHashMap<int, int> m;
|
||||||
|
auto factory = [](const int&) { return absl::make_unique<int>(7); };
|
||||||
|
StatusOr<std::shared_ptr<int>> result = m.GetOrTryCreateIfAbsent(42, factory);
|
||||||
|
ASSERT_TRUE(result.ok());
|
||||||
|
EXPECT_EQ(**result, 7);
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(RefcountingHashMapTest, TryFailure) {
|
||||||
|
RefcountingHashMap<int, int> m;
|
||||||
|
Status status = tensorflow::errors::Internal("Arrggg!");
|
||||||
|
auto factory = [&](const int&) { return status; };
|
||||||
|
EXPECT_EQ(m.GetOrTryCreateIfAbsent(42, factory).status(), status);
|
||||||
|
}
|
||||||
|
|
||||||
TEST(RefcountingHashMapTest, ForEachEmpty) {
|
TEST(RefcountingHashMapTest, ForEachEmpty) {
|
||||||
RefcountingHashMap<int, int> m;
|
RefcountingHashMap<int, int> m;
|
||||||
int64 count = 0;
|
int64 count = 0;
|
||||||
|
Loading…
Reference in New Issue
Block a user