Added RefCountingHashMap::GetOrTryCreateIfAbsent to allow for factory methods that may fail.
				
					
				
			PiperOrigin-RevId: 344412270 Change-Id: I7b384c571a04e00405c17650e45b417aeb337fb0
This commit is contained in:
		
							parent
							
								
									31f0b21597
								
							
						
					
					
						commit
						d5b2156e4e
					
				@ -940,6 +940,7 @@ cc_library(
 | 
				
			|||||||
    name = "refcounting_hash_map",
 | 
					    name = "refcounting_hash_map",
 | 
				
			||||||
    hdrs = ["refcounting_hash_map.h"],
 | 
					    hdrs = ["refcounting_hash_map.h"],
 | 
				
			||||||
    deps = [
 | 
					    deps = [
 | 
				
			||||||
 | 
					        ":statusor",
 | 
				
			||||||
        "@com_google_absl//absl/base:core_headers",
 | 
					        "@com_google_absl//absl/base:core_headers",
 | 
				
			||||||
        "@com_google_absl//absl/container:node_hash_map",
 | 
					        "@com_google_absl//absl/container:node_hash_map",
 | 
				
			||||||
        "@com_google_absl//absl/memory",
 | 
					        "@com_google_absl//absl/memory",
 | 
				
			||||||
@ -955,6 +956,7 @@ tf_cc_test(
 | 
				
			|||||||
        ":test",
 | 
					        ":test",
 | 
				
			||||||
        ":types",
 | 
					        ":types",
 | 
				
			||||||
        "//tensorflow/core:test_main",
 | 
					        "//tensorflow/core:test_main",
 | 
				
			||||||
 | 
					        "//tensorflow/core/platform:errors",
 | 
				
			||||||
    ],
 | 
					    ],
 | 
				
			||||||
)
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
				
			|||||||
@ -23,6 +23,7 @@ limitations under the License.
 | 
				
			|||||||
#include "absl/container/node_hash_map.h"
 | 
					#include "absl/container/node_hash_map.h"
 | 
				
			||||||
#include "absl/memory/memory.h"
 | 
					#include "absl/memory/memory.h"
 | 
				
			||||||
#include "absl/synchronization/mutex.h"
 | 
					#include "absl/synchronization/mutex.h"
 | 
				
			||||||
 | 
					#include "tensorflow/compiler/xla/statusor.h"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
namespace xla {
 | 
					namespace xla {
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@ -58,6 +59,18 @@ class RefcountingHashMap {
 | 
				
			|||||||
  std::shared_ptr<V> GetOrCreateIfAbsent(
 | 
					  std::shared_ptr<V> GetOrCreateIfAbsent(
 | 
				
			||||||
      const K& key,
 | 
					      const K& key,
 | 
				
			||||||
      const std::function<std::unique_ptr<V>(const K&)>& value_factory) {
 | 
					      const std::function<std::unique_ptr<V>(const K&)>& value_factory) {
 | 
				
			||||||
 | 
					    return *GetOrTryCreateIfAbsent(key, [&](const K& key) {
 | 
				
			||||||
 | 
					      return StatusOr<std::unique_ptr<V>>(value_factory(key));
 | 
				
			||||||
 | 
					    });
 | 
				
			||||||
 | 
					  }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  // Gets the value for the given key.
 | 
				
			||||||
 | 
					  //
 | 
				
			||||||
 | 
					  // If the map doesn't contain a live value for the key, constructs one
 | 
				
			||||||
 | 
					  // using `value_factory`, or returns the status from `value_factory`.
 | 
				
			||||||
 | 
					  StatusOr<std::shared_ptr<V>> GetOrTryCreateIfAbsent(
 | 
				
			||||||
 | 
					      const K& key, const std::function<StatusOr<std::unique_ptr<V>>(const K&)>&
 | 
				
			||||||
 | 
					                        value_factory) {
 | 
				
			||||||
    absl::MutexLock lock(&mu_);
 | 
					    absl::MutexLock lock(&mu_);
 | 
				
			||||||
    auto it = map_.find(key);
 | 
					    auto it = map_.find(key);
 | 
				
			||||||
    // We ensure that the entry has not expired in case deleter was running when
 | 
					    // We ensure that the entry has not expired in case deleter was running when
 | 
				
			||||||
@ -71,9 +84,9 @@ class RefcountingHashMap {
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
    // Create entry in the map and then set its value, so the value can
 | 
					    // Create entry in the map and then set its value, so the value can
 | 
				
			||||||
    // contain a pointer back into the map.
 | 
					    // contain a pointer back into the map.
 | 
				
			||||||
 | 
					    TF_ASSIGN_OR_RETURN(std::unique_ptr<V> value_unique, value_factory(key));
 | 
				
			||||||
    it = map_.emplace(key, std::weak_ptr<V>()).first;
 | 
					    it = map_.emplace(key, std::weak_ptr<V>()).first;
 | 
				
			||||||
    std::shared_ptr<V> value(value_factory(key).release(),
 | 
					    std::shared_ptr<V> value(value_unique.release(), Deleter{&it->first, this});
 | 
				
			||||||
                             Deleter{&it->first, this});
 | 
					 | 
				
			||||||
    it->second = value;  // Set the weak ptr to the shared ptr.
 | 
					    it->second = value;  // Set the weak ptr to the shared ptr.
 | 
				
			||||||
    return value;
 | 
					    return value;
 | 
				
			||||||
  }
 | 
					  }
 | 
				
			||||||
 | 
				
			|||||||
@ -16,9 +16,11 @@ limitations under the License.
 | 
				
			|||||||
#include "tensorflow/compiler/xla/refcounting_hash_map.h"
 | 
					#include "tensorflow/compiler/xla/refcounting_hash_map.h"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
#include <functional>
 | 
					#include <functional>
 | 
				
			||||||
 | 
					#include <memory>
 | 
				
			||||||
 | 
					
 | 
				
			||||||
#include "tensorflow/compiler/xla/test.h"
 | 
					#include "tensorflow/compiler/xla/test.h"
 | 
				
			||||||
#include "tensorflow/compiler/xla/types.h"
 | 
					#include "tensorflow/compiler/xla/types.h"
 | 
				
			||||||
 | 
					#include "tensorflow/core/platform/errors.h"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
namespace xla {
 | 
					namespace xla {
 | 
				
			||||||
namespace {
 | 
					namespace {
 | 
				
			||||||
@ -79,6 +81,21 @@ TEST(RefcountingHashMapTest, CustomFactory) {
 | 
				
			|||||||
  EXPECT_EQ(*m.GetOrCreateIfAbsent(100, factory), 101);
 | 
					  EXPECT_EQ(*m.GetOrCreateIfAbsent(100, factory), 101);
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					TEST(RefcountingHashMapTest, TrySuccessful) {
 | 
				
			||||||
 | 
					  RefcountingHashMap<int, int> m;
 | 
				
			||||||
 | 
					  auto factory = [](const int&) { return absl::make_unique<int>(7); };
 | 
				
			||||||
 | 
					  StatusOr<std::shared_ptr<int>> result = m.GetOrTryCreateIfAbsent(42, factory);
 | 
				
			||||||
 | 
					  ASSERT_TRUE(result.ok());
 | 
				
			||||||
 | 
					  EXPECT_EQ(**result, 7);
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					TEST(RefcountingHashMapTest, TryFailure) {
 | 
				
			||||||
 | 
					  RefcountingHashMap<int, int> m;
 | 
				
			||||||
 | 
					  Status status = tensorflow::errors::Internal("Arrggg!");
 | 
				
			||||||
 | 
					  auto factory = [&](const int&) { return status; };
 | 
				
			||||||
 | 
					  EXPECT_EQ(m.GetOrTryCreateIfAbsent(42, factory).status(), status);
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
TEST(RefcountingHashMapTest, ForEachEmpty) {
 | 
					TEST(RefcountingHashMapTest, ForEachEmpty) {
 | 
				
			||||||
  RefcountingHashMap<int, int> m;
 | 
					  RefcountingHashMap<int, int> m;
 | 
				
			||||||
  int64 count = 0;
 | 
					  int64 count = 0;
 | 
				
			||||||
 | 
				
			|||||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user