Added RefCountingHashMap::GetOrTryCreateIfAbsent to allow for factory methods that may fail.

PiperOrigin-RevId: 344412270
Change-Id: I7b384c571a04e00405c17650e45b417aeb337fb0
This commit is contained in:
Chris Jones 2020-11-26 06:01:48 -08:00 committed by TensorFlower Gardener
parent 31f0b21597
commit d5b2156e4e
3 changed files with 34 additions and 2 deletions

View File

@ -940,6 +940,7 @@ cc_library(
name = "refcounting_hash_map",
hdrs = ["refcounting_hash_map.h"],
deps = [
":statusor",
"@com_google_absl//absl/base:core_headers",
"@com_google_absl//absl/container:node_hash_map",
"@com_google_absl//absl/memory",
@ -955,6 +956,7 @@ tf_cc_test(
":test",
":types",
"//tensorflow/core:test_main",
"//tensorflow/core/platform:errors",
],
)

View File

@ -23,6 +23,7 @@ limitations under the License.
#include "absl/container/node_hash_map.h"
#include "absl/memory/memory.h"
#include "absl/synchronization/mutex.h"
#include "tensorflow/compiler/xla/statusor.h"
namespace xla {
@ -58,6 +59,18 @@ class RefcountingHashMap {
std::shared_ptr<V> GetOrCreateIfAbsent(
const K& key,
const std::function<std::unique_ptr<V>(const K&)>& value_factory) {
return *GetOrTryCreateIfAbsent(key, [&](const K& key) {
return StatusOr<std::unique_ptr<V>>(value_factory(key));
});
}
// Gets the value for the given key.
//
// If the map doesn't contain a live value for the key, constructs one
// using `value_factory`, or returns the status from `value_factory`.
StatusOr<std::shared_ptr<V>> GetOrTryCreateIfAbsent(
const K& key, const std::function<StatusOr<std::unique_ptr<V>>(const K&)>&
value_factory) {
absl::MutexLock lock(&mu_);
auto it = map_.find(key);
// We ensure that the entry has not expired in case deleter was running when
@ -71,9 +84,9 @@ class RefcountingHashMap {
// Create entry in the map and then set its value, so the value can
// contain a pointer back into the map.
TF_ASSIGN_OR_RETURN(std::unique_ptr<V> value_unique, value_factory(key));
it = map_.emplace(key, std::weak_ptr<V>()).first;
std::shared_ptr<V> value(value_factory(key).release(),
Deleter{&it->first, this});
std::shared_ptr<V> value(value_unique.release(), Deleter{&it->first, this});
it->second = value; // Set the weak ptr to the shared ptr.
return value;
}

View File

@ -16,9 +16,11 @@ limitations under the License.
#include "tensorflow/compiler/xla/refcounting_hash_map.h"
#include <functional>
#include <memory>
#include "tensorflow/compiler/xla/test.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/core/platform/errors.h"
namespace xla {
namespace {
@ -79,6 +81,21 @@ TEST(RefcountingHashMapTest, CustomFactory) {
EXPECT_EQ(*m.GetOrCreateIfAbsent(100, factory), 101);
}
TEST(RefcountingHashMapTest, TrySuccessful) {
RefcountingHashMap<int, int> m;
auto factory = [](const int&) { return absl::make_unique<int>(7); };
StatusOr<std::shared_ptr<int>> result = m.GetOrTryCreateIfAbsent(42, factory);
ASSERT_TRUE(result.ok());
EXPECT_EQ(**result, 7);
}
TEST(RefcountingHashMapTest, TryFailure) {
RefcountingHashMap<int, int> m;
Status status = tensorflow::errors::Internal("Arrggg!");
auto factory = [&](const int&) { return status; };
EXPECT_EQ(m.GetOrTryCreateIfAbsent(42, factory).status(), status);
}
TEST(RefcountingHashMapTest, ForEachEmpty) {
RefcountingHashMap<int, int> m;
int64 count = 0;