From d5b2156e4e5367cc336d26fbbac3a6e76781de05 Mon Sep 17 00:00:00 2001 From: Chris Jones Date: Thu, 26 Nov 2020 06:01:48 -0800 Subject: [PATCH] Added `RefCountingHashMap::GetOrTryCreateIfAbsent` to allow for factory methods that may fail. PiperOrigin-RevId: 344412270 Change-Id: I7b384c571a04e00405c17650e45b417aeb337fb0 --- tensorflow/compiler/xla/BUILD | 2 ++ tensorflow/compiler/xla/refcounting_hash_map.h | 17 +++++++++++++++-- .../compiler/xla/refcounting_hash_map_test.cc | 17 +++++++++++++++++ 3 files changed, 34 insertions(+), 2 deletions(-) diff --git a/tensorflow/compiler/xla/BUILD b/tensorflow/compiler/xla/BUILD index 2db1caa1e05..6de4ae517f3 100644 --- a/tensorflow/compiler/xla/BUILD +++ b/tensorflow/compiler/xla/BUILD @@ -940,6 +940,7 @@ cc_library( name = "refcounting_hash_map", hdrs = ["refcounting_hash_map.h"], deps = [ + ":statusor", "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/container:node_hash_map", "@com_google_absl//absl/memory", @@ -955,6 +956,7 @@ tf_cc_test( ":test", ":types", "//tensorflow/core:test_main", + "//tensorflow/core/platform:errors", ], ) diff --git a/tensorflow/compiler/xla/refcounting_hash_map.h b/tensorflow/compiler/xla/refcounting_hash_map.h index efa1b9e3a50..a9d07a741c4 100644 --- a/tensorflow/compiler/xla/refcounting_hash_map.h +++ b/tensorflow/compiler/xla/refcounting_hash_map.h @@ -23,6 +23,7 @@ limitations under the License. #include "absl/container/node_hash_map.h" #include "absl/memory/memory.h" #include "absl/synchronization/mutex.h" +#include "tensorflow/compiler/xla/statusor.h" namespace xla { @@ -58,6 +59,18 @@ class RefcountingHashMap { std::shared_ptr GetOrCreateIfAbsent( const K& key, const std::function(const K&)>& value_factory) { + return *GetOrTryCreateIfAbsent(key, [&](const K& key) { + return StatusOr>(value_factory(key)); + }); + } + + // Gets the value for the given key. + // + // If the map doesn't contain a live value for the key, constructs one + // using `value_factory`, or returns the status from `value_factory`. + StatusOr> GetOrTryCreateIfAbsent( + const K& key, const std::function>(const K&)>& + value_factory) { absl::MutexLock lock(&mu_); auto it = map_.find(key); // We ensure that the entry has not expired in case deleter was running when @@ -71,9 +84,9 @@ class RefcountingHashMap { // Create entry in the map and then set its value, so the value can // contain a pointer back into the map. + TF_ASSIGN_OR_RETURN(std::unique_ptr value_unique, value_factory(key)); it = map_.emplace(key, std::weak_ptr()).first; - std::shared_ptr value(value_factory(key).release(), - Deleter{&it->first, this}); + std::shared_ptr value(value_unique.release(), Deleter{&it->first, this}); it->second = value; // Set the weak ptr to the shared ptr. return value; } diff --git a/tensorflow/compiler/xla/refcounting_hash_map_test.cc b/tensorflow/compiler/xla/refcounting_hash_map_test.cc index acb7d7afb46..8ead034d1bc 100644 --- a/tensorflow/compiler/xla/refcounting_hash_map_test.cc +++ b/tensorflow/compiler/xla/refcounting_hash_map_test.cc @@ -16,9 +16,11 @@ limitations under the License. #include "tensorflow/compiler/xla/refcounting_hash_map.h" #include +#include #include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/types.h" +#include "tensorflow/core/platform/errors.h" namespace xla { namespace { @@ -79,6 +81,21 @@ TEST(RefcountingHashMapTest, CustomFactory) { EXPECT_EQ(*m.GetOrCreateIfAbsent(100, factory), 101); } +TEST(RefcountingHashMapTest, TrySuccessful) { + RefcountingHashMap m; + auto factory = [](const int&) { return absl::make_unique(7); }; + StatusOr> result = m.GetOrTryCreateIfAbsent(42, factory); + ASSERT_TRUE(result.ok()); + EXPECT_EQ(**result, 7); +} + +TEST(RefcountingHashMapTest, TryFailure) { + RefcountingHashMap m; + Status status = tensorflow::errors::Internal("Arrggg!"); + auto factory = [&](const int&) { return status; }; + EXPECT_EQ(m.GetOrTryCreateIfAbsent(42, factory).status(), status); +} + TEST(RefcountingHashMapTest, ForEachEmpty) { RefcountingHashMap m; int64 count = 0;