From bb5d0144b9f2a54c4202e7583e85c154447d09dc Mon Sep 17 00:00:00 2001
From: "A. Unique TensorFlower"
Date: Tue, 25 Feb 2020 08:48:40 -0800
Subject: [PATCH] PR #36468: Add block cache for low level table library

Imported from GitHub PR https://github.com/tensorflow/tensorflow/pull/36468

This is part of a patch series aiming to improve the performance of on-disk
dataset.cache() (CacheDatasetV2). Currently CacheDataset uses
core/util/tensor_bundle to cache dataset elements on disk, and the bundle
uses a sorted string table (SST) to index dataset elements. Unlike
checkpoints, which hold a modest number of tensors, caching a large dataset
can produce far more tensors and therefore far more index blocks. If an
index block is present in an in-memory LRU block cache, fetching a dataset
element needs only one round trip to storage instead of two. This is
particularly useful when the cache is read from a higher-latency remote file
system such as HDFS or GCS.

Almost all of the code is imported from the LevelDB project, in particular
the hash function used to shard the LRU cache. (Using Hash32 from
core/lib/hash currently fails the EvictionPolicy test.) I make only two
modifications to the original cache:

1. Alias leveldb::Slice to tensorflow::StringPiece.
2. Switch to tensorflow::mutex for all mutexes.

Ping @jsimsa to review.

Copybara import of the project:

--
4c28247f5f3f6fcd12e82757befd7d90bf413e2c by Bairen Yi:

Add block cache for low level table library

This is part of a patch series aiming to improve the performance of on-disk
dataset.cache() (CacheDatasetV2). Currently CacheDataset uses
core/util/tensor_bundle to cache dataset elements on disk, and the bundle
uses a sorted string table (SST) to index dataset elements. Unlike
checkpoints, which hold a modest number of tensors, caching a large dataset
can produce far more tensors and therefore far more index blocks. If an
index block is present in an in-memory LRU block cache, fetching a dataset
element needs only one round trip to storage instead of two. This is
particularly useful when the cache is read from a higher-latency remote file
system such as HDFS or GCS.

Almost all of the code is imported from the LevelDB project, in particular
the hash function used to shard the LRU cache. (Using Hash32 from
core/lib/hash currently fails the EvictionPolicy test.) I make only two
modifications to the original cache:

1. Alias leveldb::Slice to tensorflow::StringPiece, which transitively
   aliases absl::string_view.
2. Switch to tensorflow::mutex for all mutexes.

Signed-off-by: Bairen Yi

--
b69b43382ea7692ffd60ad50b118ac0646ceecc8 by Bairen Yi:

tensor_bundle: Enable cache for metadata table

The index cache is disabled by default unless the
TF_TABLE_INDEX_CACHE_SIZE_IN_MB environment variable is set.
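As an illustration of the new knob (not part of the patch): a minimal C++ sketch of how the index cache would be enabled for a BundleReader. The 64 MB size and the bundle prefix are hypothetical; the only real requirement is that TF_TABLE_INDEX_CACHE_SIZE_IN_MB is set before the reader is constructed, since BundleReader reads it in its constructor.

```c++
// Sketch only: enabling the metadata-table index cache for BundleReader.
// The environment variable must be set before the reader is created;
// "64" (MB) and the bundle prefix "/tmp/my_bundle" are hypothetical values.
#include <cstdlib>

#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/util/tensor_bundle/tensor_bundle.h"

int main() {
  // POSIX setenv; in a real job this would typically be set in the launch
  // environment rather than in code.
  setenv("TF_TABLE_INDEX_CACHE_SIZE_IN_MB", "64", /*overwrite=*/1);

  tensorflow::BundleReader reader(tensorflow::Env::Default(), "/tmp/my_bundle");
  if (!reader.status().ok()) {
    return 1;
  }
  // Lookups on the metadata table now go through the in-memory LRU block
  // cache, avoiding a second read of the index block on remote file systems.
  return 0;
}
```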
Signed-off-by: Bairen Yi PiperOrigin-RevId: 297125962 Change-Id: Ibfec97b19f337d40f5726f656ee9c6487ce552d0 --- tensorflow/core/lib/io/BUILD | 20 + tensorflow/core/lib/io/cache.cc | 450 ++++++++++++++++++ tensorflow/core/lib/io/cache.h | 125 +++++ tensorflow/core/lib/io/cache_test.cc | 238 +++++++++ tensorflow/core/lib/io/table.cc | 55 ++- tensorflow/core/lib/io/table_options.h | 8 + .../core/util/tensor_bundle/tensor_bundle.cc | 17 +- .../core/util/tensor_bundle/tensor_bundle.h | 5 +- 8 files changed, 903 insertions(+), 15 deletions(-) create mode 100644 tensorflow/core/lib/io/cache.cc create mode 100644 tensorflow/core/lib/io/cache.h create mode 100644 tensorflow/core/lib/io/cache_test.cc diff --git a/tensorflow/core/lib/io/BUILD b/tensorflow/core/lib/io/BUILD index 87b5090a59f..d03a895b429 100644 --- a/tensorflow/core/lib/io/BUILD +++ b/tensorflow/core/lib/io/BUILD @@ -208,6 +208,21 @@ cc_library( alwayslink = True, ) +cc_library( + name = "cache", + srcs = [ + "cache.cc", + ], + hdrs = [ + "cache.h", + ], + deps = [ + "//tensorflow/core/platform:coding", + "//tensorflow/core/platform:mutex", + "//tensorflow/core/platform:stringpiece", + ], +) + cc_library( name = "table", srcs = [ @@ -220,6 +235,7 @@ cc_library( ], deps = [ ":block", + ":cache", ":iterator", ":table_options", "//tensorflow/core/lib/core:coding", @@ -290,6 +306,8 @@ filegroup( "block_builder.h", "buffered_inputstream.cc", "buffered_inputstream.h", + "cache.cc", + "cache.h", "compression.cc", "compression.h", "format.cc", @@ -352,6 +370,7 @@ filegroup( name = "legacy_lib_io_all_tests", srcs = [ "buffered_inputstream_test.cc", + "cache_test.cc", "inputbuffer_test.cc", "inputstream_interface_test.cc", "path_test.cc", @@ -369,6 +388,7 @@ filegroup( name = "legacy_lib_io_headers", srcs = [ "buffered_inputstream.h", + "cache.h", "compression.h", "inputstream_interface.h", "path.h", diff --git a/tensorflow/core/lib/io/cache.cc b/tensorflow/core/lib/io/cache.cc new file mode 100644 index 00000000000..b5521b1752b --- /dev/null +++ b/tensorflow/core/lib/io/cache.cc @@ -0,0 +1,450 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/core/lib/io/cache.h" + +#include +#include +#include +#include + +#include "tensorflow/core/platform/coding.h" +#include "tensorflow/core/platform/mutex.h" + +namespace tensorflow { + +namespace table { + +Cache::~Cache() {} + +namespace { + +// LRU cache implementation +// +// Cache entries have an "in_cache" boolean indicating whether the cache has a +// reference on the entry. The only ways that this can become false without the +// entry being passed to its "deleter" are via Erase(), via Insert() when +// an element with a duplicate key is inserted, or on destruction of the cache. +// +// The cache keeps two linked lists of items in the cache. All items in the +// cache are in one list or the other, and never both. 
Items still referenced +// by clients but erased from the cache are in neither list. The lists are: +// - in-use: contains the items currently referenced by clients, in no +// particular order. (This list is used for invariant checking. If we +// removed the check, elements that would otherwise be on this list could be +// left as disconnected singleton lists.) +// - LRU: contains the items not currently referenced by clients, in LRU order +// Elements are moved between these lists by the Ref() and Unref() methods, +// when they detect an element in the cache acquiring or losing its only +// external reference. + +// An entry is a variable length heap-allocated structure. Entries +// are kept in a circular doubly linked list ordered by access time. +struct LRUHandle { + void* value; + void (*deleter)(const Slice&, void* value); + LRUHandle* next_hash; + LRUHandle* next; + LRUHandle* prev; + size_t charge; // TODO(opt): Only allow uint32_t? + size_t key_length; + bool in_cache; // Whether entry is in the cache. + uint32_t refs; // References, including cache reference, if present. + uint32_t hash; // Hash of key(); used for fast sharding and comparisons + char key_data[1]; // Beginning of key + + Slice key() const { + // next_ is only equal to this if the LRU handle is the list head of an + // empty list. List heads never have meaningful keys. + assert(next != this); + + return Slice(key_data, key_length); + } +}; + +// We provide our own simple hash table since it removes a whole bunch +// of porting hacks and is also faster than some of the built-in hash +// table implementations in some of the compiler/runtime combinations +// we have tested. E.g., readrandom speeds up by ~5% over the g++ +// 4.4.3's builtin hashtable. +class HandleTable { + public: + HandleTable() : length_(0), elems_(0), list_(nullptr) { Resize(); } + ~HandleTable() { delete[] list_; } + + LRUHandle* Lookup(const Slice& key, uint32_t hash) { + return *FindPointer(key, hash); + } + + LRUHandle* Insert(LRUHandle* h) { + LRUHandle** ptr = FindPointer(h->key(), h->hash); + LRUHandle* old = *ptr; + h->next_hash = (old == nullptr ? nullptr : old->next_hash); + *ptr = h; + if (old == nullptr) { + ++elems_; + if (elems_ > length_) { + // Since each cache entry is fairly large, we aim for a small + // average linked list length (<= 1). + Resize(); + } + } + return old; + } + + LRUHandle* Remove(const Slice& key, uint32_t hash) { + LRUHandle** ptr = FindPointer(key, hash); + LRUHandle* result = *ptr; + if (result != nullptr) { + *ptr = result->next_hash; + --elems_; + } + return result; + } + + private: + // The table consists of an array of buckets where each bucket is + // a linked list of cache entries that hash into the bucket. + uint32_t length_; + uint32_t elems_; + LRUHandle** list_; + + // Return a pointer to slot that points to a cache entry that + // matches key/hash. If there is no such cache entry, return a + // pointer to the trailing slot in the corresponding linked list. 
+ LRUHandle** FindPointer(const Slice& key, uint32_t hash) { + LRUHandle** ptr = &list_[hash & (length_ - 1)]; + while (*ptr != nullptr && ((*ptr)->hash != hash || key != (*ptr)->key())) { + ptr = &(*ptr)->next_hash; + } + return ptr; + } + + void Resize() { + uint32_t new_length = 4; + while (new_length < elems_) { + new_length *= 2; + } + LRUHandle** new_list = new LRUHandle*[new_length]; + memset(new_list, 0, sizeof(new_list[0]) * new_length); + uint32_t count = 0; + for (uint32_t i = 0; i < length_; i++) { + LRUHandle* h = list_[i]; + while (h != nullptr) { + LRUHandle* next = h->next_hash; + uint32_t hash = h->hash; + LRUHandle** ptr = &new_list[hash & (new_length - 1)]; + h->next_hash = *ptr; + *ptr = h; + h = next; + count++; + } + } + assert(elems_ == count); + delete[] list_; + list_ = new_list; + length_ = new_length; + } +}; + +// A single shard of sharded cache. +class LRUCache { + public: + LRUCache(); + ~LRUCache(); + + // Separate from constructor so caller can easily make an array of LRUCache + void SetCapacity(size_t capacity) { capacity_ = capacity; } + + // Like Cache methods, but with an extra "hash" parameter. + Cache::Handle* Insert(const Slice& key, uint32_t hash, void* value, + size_t charge, + void (*deleter)(const Slice& key, void* value)); + Cache::Handle* Lookup(const Slice& key, uint32_t hash); + void Release(Cache::Handle* handle); + void Erase(const Slice& key, uint32_t hash); + void Prune(); + size_t TotalCharge() const { + mutex_lock l(mutex_); + return usage_; + } + + private: + void LRU_Remove(LRUHandle* e); + void LRU_Append(LRUHandle* list, LRUHandle* e); + void Ref(LRUHandle* e); + void Unref(LRUHandle* e); + bool FinishErase(LRUHandle* e) EXCLUSIVE_LOCKS_REQUIRED(mutex_); + + // Initialized before use. + size_t capacity_; + + // mutex_ protects the following state. + mutable mutex mutex_; + size_t usage_ GUARDED_BY(mutex_); + + // Dummy head of LRU list. + // lru.prev is newest entry, lru.next is oldest entry. + // Entries have refs==1 and in_cache==true. + LRUHandle lru_ GUARDED_BY(mutex_); + + // Dummy head of in-use list. + // Entries are in use by clients, and have refs >= 2 and in_cache==true. + LRUHandle in_use_ GUARDED_BY(mutex_); + + HandleTable table_ GUARDED_BY(mutex_); +}; + +LRUCache::LRUCache() : capacity_(0), usage_(0) { + // Make empty circular linked lists. + lru_.next = &lru_; + lru_.prev = &lru_; + in_use_.next = &in_use_; + in_use_.prev = &in_use_; +} + +LRUCache::~LRUCache() { + assert(in_use_.next == &in_use_); // Error if caller has an unreleased handle + for (LRUHandle* e = lru_.next; e != &lru_;) { + LRUHandle* next = e->next; + assert(e->in_cache); + e->in_cache = false; + assert(e->refs == 1); // Invariant of lru_ list. + Unref(e); + e = next; + } +} + +void LRUCache::Ref(LRUHandle* e) { + if (e->refs == 1 && e->in_cache) { // If on lru_ list, move to in_use_ list. + LRU_Remove(e); + LRU_Append(&in_use_, e); + } + e->refs++; +} + +void LRUCache::Unref(LRUHandle* e) { + assert(e->refs > 0); + e->refs--; + if (e->refs == 0) { // Deallocate. + assert(!e->in_cache); + (*e->deleter)(e->key(), e->value); + free(e); + } else if (e->in_cache && e->refs == 1) { + // No longer in use; move to lru_ list. 
+ LRU_Remove(e); + LRU_Append(&lru_, e); + } +} + +void LRUCache::LRU_Remove(LRUHandle* e) { + e->next->prev = e->prev; + e->prev->next = e->next; +} + +void LRUCache::LRU_Append(LRUHandle* list, LRUHandle* e) { + // Make "e" newest entry by inserting just before *list + e->next = list; + e->prev = list->prev; + e->prev->next = e; + e->next->prev = e; +} + +Cache::Handle* LRUCache::Lookup(const Slice& key, uint32_t hash) { + mutex_lock l(mutex_); + LRUHandle* e = table_.Lookup(key, hash); + if (e != nullptr) { + Ref(e); + } + return reinterpret_cast(e); +} + +void LRUCache::Release(Cache::Handle* handle) { + mutex_lock l(mutex_); + Unref(reinterpret_cast(handle)); +} + +Cache::Handle* LRUCache::Insert(const Slice& key, uint32_t hash, void* value, + size_t charge, + void (*deleter)(const Slice& key, + void* value)) { + mutex_lock l(mutex_); + + LRUHandle* e = + reinterpret_cast(malloc(sizeof(LRUHandle) - 1 + key.size())); + e->value = value; + e->deleter = deleter; + e->charge = charge; + e->key_length = key.size(); + e->hash = hash; + e->in_cache = false; + e->refs = 1; // for the returned handle. + memcpy(e->key_data, key.data(), key.size()); + + if (capacity_ > 0) { + e->refs++; // for the cache's reference. + e->in_cache = true; + LRU_Append(&in_use_, e); + usage_ += charge; + FinishErase(table_.Insert(e)); + } else { // don't cache. (capacity_==0 is supported and turns off caching.) + // next is read by key() in an assert, so it must be initialized + e->next = nullptr; + } + while (usage_ > capacity_ && lru_.next != &lru_) { + LRUHandle* old = lru_.next; + assert(old->refs == 1); + bool erased = FinishErase(table_.Remove(old->key(), old->hash)); + if (!erased) { // to avoid unused variable when compiled NDEBUG + assert(erased); + } + } + + return reinterpret_cast(e); +} + +// If e != nullptr, finish removing *e from the cache; it has already been +// removed from the hash table. Return whether e != nullptr. 
+bool LRUCache::FinishErase(LRUHandle* e) { + if (e != nullptr) { + assert(e->in_cache); + LRU_Remove(e); + e->in_cache = false; + usage_ -= e->charge; + Unref(e); + } + return e != nullptr; +} + +void LRUCache::Erase(const Slice& key, uint32_t hash) { + mutex_lock l(mutex_); + FinishErase(table_.Remove(key, hash)); +} + +void LRUCache::Prune() { + mutex_lock l(mutex_); + while (lru_.next != &lru_) { + LRUHandle* e = lru_.next; + assert(e->refs == 1); + bool erased = FinishErase(table_.Remove(e->key(), e->hash)); + if (!erased) { // to avoid unused variable when compiled NDEBUG + assert(erased); + } + } +} + +static const int kNumShardBits = 4; +static const int kNumShards = 1 << kNumShardBits; + +class ShardedLRUCache : public Cache { + private: + LRUCache shard_[kNumShards]; + mutex id_mutex_; + uint64_t last_id_; + + static inline uint32_t HashSlice(const Slice& s) { + return Hash(s.data(), s.size(), 0); + } + + static uint32_t Shard(uint32_t hash) { return hash >> (32 - kNumShardBits); } + + public: + explicit ShardedLRUCache(size_t capacity) : last_id_(0) { + const size_t per_shard = (capacity + (kNumShards - 1)) / kNumShards; + for (int s = 0; s < kNumShards; s++) { + shard_[s].SetCapacity(per_shard); + } + } + ~ShardedLRUCache() override {} + Handle* Insert(const Slice& key, void* value, size_t charge, + void (*deleter)(const Slice& key, void* value)) override { + const uint32_t hash = HashSlice(key); + return shard_[Shard(hash)].Insert(key, hash, value, charge, deleter); + } + Handle* Lookup(const Slice& key) override { + const uint32_t hash = HashSlice(key); + return shard_[Shard(hash)].Lookup(key, hash); + } + void Release(Handle* handle) override { + LRUHandle* h = reinterpret_cast(handle); + shard_[Shard(h->hash)].Release(handle); + } + void Erase(const Slice& key) override { + const uint32_t hash = HashSlice(key); + shard_[Shard(hash)].Erase(key, hash); + } + void* Value(Handle* handle) override { + return reinterpret_cast(handle)->value; + } + uint64_t NewId() override { + mutex_lock l(id_mutex_); + return ++(last_id_); + } + void Prune() override { + for (int s = 0; s < kNumShards; s++) { + shard_[s].Prune(); + } + } + size_t TotalCharge() const override { + size_t total = 0; + for (int s = 0; s < kNumShards; s++) { + total += shard_[s].TotalCharge(); + } + return total; + } + + private: + // TODO(byronyi): Figure out why Hash32 fails EvictionPolicy test. + static uint32_t Hash(const char* data, size_t n, uint32_t seed) { + // Similar to murmur hash + const uint32_t m = 0xc6a4a793; + const uint32_t r = 24; + const char* limit = data + n; + uint32_t h = seed ^ (n * m); + + // Pick up four bytes at a time + while (data + 4 <= limit) { + uint32_t w = core::DecodeFixed32(data); + data += 4; + h += w; + h *= m; + h ^= (h >> 16); + } + + // Pick up remaining bytes + switch (limit - data) { + case 3: + h += static_cast(data[2]) << 16; + ABSL_FALLTHROUGH_INTENDED; + case 2: + h += static_cast(data[1]) << 8; + ABSL_FALLTHROUGH_INTENDED; + case 1: + h += static_cast(data[0]); + h *= m; + h ^= (h >> r); + break; + } + return h; + } +}; + +} // end anonymous namespace + +Cache* NewLRUCache(size_t capacity) { return new ShardedLRUCache(capacity); } + +} // namespace table + +} // namespace tensorflow diff --git a/tensorflow/core/lib/io/cache.h b/tensorflow/core/lib/io/cache.h new file mode 100644 index 00000000000..788a637077a --- /dev/null +++ b/tensorflow/core/lib/io/cache.h @@ -0,0 +1,125 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_LIB_IO_CACHE_H_ +#define TENSORFLOW_CORE_LIB_IO_CACHE_H_ + +#include "tensorflow/core/platform/stringpiece.h" + +// A Cache is an interface that maps keys to values. It has internal +// synchronization and may be safely accessed concurrently from +// multiple threads. It may automatically evict entries to make room +// for new entries. Values have a specified charge against the cache +// capacity. For example, a cache where the values are variable +// length strings, may use the length of the string as the charge for +// the string. +// +// A builtin cache implementation with a least-recently-used eviction +// policy is provided. Clients may use their own implementations if +// they want something more sophisticated (like scan-resistance, a +// custom eviction policy, variable cache sizing, etc.) + +namespace tensorflow { + +using Slice = StringPiece; + +namespace table { + +class Cache; + +// Create a new cache with a fixed size capacity. This implementation +// of Cache uses a least-recently-used eviction policy. +Cache* NewLRUCache(size_t capacity); + +class Cache { + public: + Cache() = default; + + Cache(const Cache&) = delete; + Cache& operator=(const Cache&) = delete; + + // Destroys all existing entries by calling the "deleter" + // function that was passed to the constructor. + virtual ~Cache(); + + // Opaque handle to an entry stored in the cache. + struct Handle {}; + + // Insert a mapping from key->value into the cache and assign it + // the specified charge against the total cache capacity. + // + // Returns a handle that corresponds to the mapping. The caller + // must call this->Release(handle) when the returned mapping is no + // longer needed. + // + // When the inserted entry is no longer needed, the key and + // value will be passed to "deleter". + virtual Handle* Insert(const Slice& key, void* value, size_t charge, + void (*deleter)(const Slice& key, void* value)) = 0; + + // If the cache has no mapping for "key", returns nullptr. + // + // Else return a handle that corresponds to the mapping. The caller + // must call this->Release(handle) when the returned mapping is no + // longer needed. + virtual Handle* Lookup(const Slice& key) = 0; + + // Release a mapping returned by a previous Lookup(). + // REQUIRES: handle must not have been released yet. + // REQUIRES: handle must have been returned by a method on *this. + virtual void Release(Handle* handle) = 0; + + // Return the value encapsulated in a handle returned by a + // successful Lookup(). + // REQUIRES: handle must not have been released yet. + // REQUIRES: handle must have been returned by a method on *this. + virtual void* Value(Handle* handle) = 0; + + // If the cache contains entry for key, erase it. Note that the + // underlying entry will be kept around until all existing handles + // to it have been released. 
+ virtual void Erase(const Slice& key) = 0; + + // Return a new numeric id. May be used by multiple clients who are + // sharing the same cache to partition the key space. Typically the + // client will allocate a new id at startup and prepend the id to + // its cache keys. + virtual uint64_t NewId() = 0; + + // Remove all cache entries that are not actively in use. Memory-constrained + // applications may wish to call this method to reduce memory usage. + // Default implementation of Prune() does nothing. Subclasses are strongly + // encouraged to override the default implementation. A future release of + // leveldb may change Prune() to a pure abstract method. + virtual void Prune() {} + + // Return an estimate of the combined charges of all elements stored in the + // cache. + virtual size_t TotalCharge() const = 0; + + private: + void LRU_Remove(Handle* e); + void LRU_Append(Handle* e); + void Unref(Handle* e); + + struct Rep; + Rep* rep_; +}; + +} // namespace table + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_LIB_IO_CACHE_H_ diff --git a/tensorflow/core/lib/io/cache_test.cc b/tensorflow/core/lib/io/cache_test.cc new file mode 100644 index 00000000000..38552d43b34 --- /dev/null +++ b/tensorflow/core/lib/io/cache_test.cc @@ -0,0 +1,238 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/core/lib/io/cache.h" + +#include +#include + +#include "tensorflow/core/lib/core/coding.h" +#include "tensorflow/core/platform/test.h" + +namespace tensorflow { + +namespace table { +// Conversions between numeric keys/values and the types expected by Cache. +static std::string EncodeKey(int k) { + std::string result; + core::PutFixed32(&result, k); + return result; +} +static int DecodeKey(const Slice& k) { + assert(k.size() == 4); + return core::DecodeFixed32(k.data()); +} +static void* EncodeValue(uintptr_t v) { return reinterpret_cast(v); } +static int DecodeValue(void* v) { return reinterpret_cast(v); } + +class CacheTest : public ::testing::Test { + public: + static void Deleter(const Slice& key, void* v) { + current_->deleted_keys_.push_back(DecodeKey(key)); + current_->deleted_values_.push_back(DecodeValue(v)); + } + + static const int kCacheSize = 1000; + std::vector deleted_keys_; + std::vector deleted_values_; + Cache* cache_; + + CacheTest() : cache_(NewLRUCache(kCacheSize)) { current_ = this; } + + ~CacheTest() { delete cache_; } + + int Lookup(int key) { + Cache::Handle* handle = cache_->Lookup(EncodeKey(key)); + const int r = (handle == nullptr) ? 
-1 : DecodeValue(cache_->Value(handle)); + if (handle != nullptr) { + cache_->Release(handle); + } + return r; + } + + void Insert(int key, int value, int charge = 1) { + cache_->Release(cache_->Insert(EncodeKey(key), EncodeValue(value), charge, + &CacheTest::Deleter)); + } + + Cache::Handle* InsertAndReturnHandle(int key, int value, int charge = 1) { + return cache_->Insert(EncodeKey(key), EncodeValue(value), charge, + &CacheTest::Deleter); + } + + void Erase(int key) { cache_->Erase(EncodeKey(key)); } + static CacheTest* current_; +}; +CacheTest* CacheTest::current_; + +TEST_F(CacheTest, HitAndMiss) { + ASSERT_EQ(-1, Lookup(100)); + + Insert(100, 101); + ASSERT_EQ(101, Lookup(100)); + ASSERT_EQ(-1, Lookup(200)); + ASSERT_EQ(-1, Lookup(300)); + + Insert(200, 201); + ASSERT_EQ(101, Lookup(100)); + ASSERT_EQ(201, Lookup(200)); + ASSERT_EQ(-1, Lookup(300)); + + Insert(100, 102); + ASSERT_EQ(102, Lookup(100)); + ASSERT_EQ(201, Lookup(200)); + ASSERT_EQ(-1, Lookup(300)); + + ASSERT_EQ(1, deleted_keys_.size()); + ASSERT_EQ(100, deleted_keys_[0]); + ASSERT_EQ(101, deleted_values_[0]); +} + +TEST_F(CacheTest, Erase) { + Erase(200); + ASSERT_EQ(0, deleted_keys_.size()); + + Insert(100, 101); + Insert(200, 201); + Erase(100); + ASSERT_EQ(-1, Lookup(100)); + ASSERT_EQ(201, Lookup(200)); + ASSERT_EQ(1, deleted_keys_.size()); + ASSERT_EQ(100, deleted_keys_[0]); + ASSERT_EQ(101, deleted_values_[0]); + + Erase(100); + ASSERT_EQ(-1, Lookup(100)); + ASSERT_EQ(201, Lookup(200)); + ASSERT_EQ(1, deleted_keys_.size()); +} + +TEST_F(CacheTest, EntriesArePinned) { + Insert(100, 101); + Cache::Handle* h1 = cache_->Lookup(EncodeKey(100)); + ASSERT_EQ(101, DecodeValue(cache_->Value(h1))); + + Insert(100, 102); + Cache::Handle* h2 = cache_->Lookup(EncodeKey(100)); + ASSERT_EQ(102, DecodeValue(cache_->Value(h2))); + ASSERT_EQ(0, deleted_keys_.size()); + + cache_->Release(h1); + ASSERT_EQ(1, deleted_keys_.size()); + ASSERT_EQ(100, deleted_keys_[0]); + ASSERT_EQ(101, deleted_values_[0]); + + Erase(100); + ASSERT_EQ(-1, Lookup(100)); + ASSERT_EQ(1, deleted_keys_.size()); + + cache_->Release(h2); + ASSERT_EQ(2, deleted_keys_.size()); + ASSERT_EQ(100, deleted_keys_[1]); + ASSERT_EQ(102, deleted_values_[1]); +} + +TEST_F(CacheTest, EvictionPolicy) { + Insert(100, 101); + Insert(200, 201); + Insert(300, 301); + Cache::Handle* h = cache_->Lookup(EncodeKey(300)); + + // Frequently used entry must be kept around, + // as must things that are still in use. + for (int i = 0; i < kCacheSize + 100; i++) { + Insert(1000 + i, 2000 + i); + ASSERT_EQ(2000 + i, Lookup(1000 + i)); + ASSERT_EQ(101, Lookup(100)); + } + ASSERT_EQ(101, Lookup(100)); + ASSERT_EQ(-1, Lookup(200)); + ASSERT_EQ(301, Lookup(300)); + cache_->Release(h); +} + +TEST_F(CacheTest, UseExceedsCacheSize) { + // Overfill the cache, keeping handles on all inserted entries. + std::vector h; + for (int i = 0; i < kCacheSize + 100; i++) { + h.push_back(InsertAndReturnHandle(1000 + i, 2000 + i)); + } + + // Check that all the entries can be found in the cache. + for (int i = 0; i < h.size(); i++) { + ASSERT_EQ(2000 + i, Lookup(1000 + i)); + } + + for (int i = 0; i < h.size(); i++) { + cache_->Release(h[i]); + } +} + +TEST_F(CacheTest, HeavyEntries) { + // Add a bunch of light and heavy entries and then count the combined + // size of items still in the cache, which must be approximately the + // same as the total capacity. + const int kLight = 1; + const int kHeavy = 10; + int added = 0; + int index = 0; + while (added < 2 * kCacheSize) { + const int weight = (index & 1) ? 
kLight : kHeavy; + Insert(index, 1000 + index, weight); + added += weight; + index++; + } + + int cached_weight = 0; + for (int i = 0; i < index; i++) { + const int weight = (i & 1 ? kLight : kHeavy); + int r = Lookup(i); + if (r >= 0) { + cached_weight += weight; + ASSERT_EQ(1000 + i, r); + } + } + ASSERT_LE(cached_weight, kCacheSize + kCacheSize / 10); +} + +TEST_F(CacheTest, NewId) { + uint64_t a = cache_->NewId(); + uint64_t b = cache_->NewId(); + ASSERT_NE(a, b); +} + +TEST_F(CacheTest, Prune) { + Insert(1, 100); + Insert(2, 200); + + Cache::Handle* handle = cache_->Lookup(EncodeKey(1)); + ASSERT_TRUE(handle); + cache_->Prune(); + cache_->Release(handle); + + ASSERT_EQ(100, Lookup(1)); + ASSERT_EQ(-1, Lookup(2)); +} + +TEST_F(CacheTest, ZeroSizeCache) { + delete cache_; + cache_ = NewLRUCache(0); + + Insert(1, 100); + ASSERT_EQ(-1, Lookup(1)); +} + +} // namespace table +} // namespace tensorflow diff --git a/tensorflow/core/lib/io/table.cc b/tensorflow/core/lib/io/table.cc index 1e68493bfe9..6cd1b21c14d 100644 --- a/tensorflow/core/lib/io/table.cc +++ b/tensorflow/core/lib/io/table.cc @@ -18,6 +18,7 @@ limitations under the License. #include "tensorflow/core/lib/core/coding.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/io/block.h" +#include "tensorflow/core/lib/io/cache.h" #include "tensorflow/core/lib/io/format.h" #include "tensorflow/core/lib/io/table_options.h" #include "tensorflow/core/lib/io/two_level_iterator.h" @@ -32,7 +33,7 @@ struct Table::Rep { Options options; Status status; RandomAccessFile* file; - // XXX uint64 cache_id; + uint64 cache_id; BlockHandle metaindex_handle; // Handle to metaindex_block: saved from footer Block* index_block; @@ -60,21 +61,18 @@ Status Table::Open(const Options& options, RandomAccessFile* file, uint64 size, Block* index_block = nullptr; if (s.ok()) { s = ReadBlock(file, footer.index_handle(), &contents); - if (s.ok()) { - index_block = new Block(contents); - } } if (s.ok()) { // We've successfully read the footer and the index block: we're // ready to serve requests. + index_block = new Block(contents); Rep* rep = new Table::Rep; rep->options = options; rep->file = file; rep->metaindex_handle = footer.metaindex_handle(); rep->index_block = index_block; - // XXX rep->cache_id = (options.block_cache ? - // options.block_cache->NewId() : 0); + rep->cache_id = (options.block_cache ? options.block_cache->NewId() : 0); *table = new Table(rep); } else { if (index_block) delete index_block; @@ -89,13 +87,24 @@ static void DeleteBlock(void* arg, void* ignored) { delete reinterpret_cast(arg); } +static void DeleteCachedBlock(const absl::string_view&, void* value) { + Block* block = reinterpret_cast(value); + delete block; +} + +static void ReleaseBlock(void* arg, void* h) { + Cache* cache = reinterpret_cast(arg); + Cache::Handle* handle = reinterpret_cast(h); + cache->Release(handle); +} + // Convert an index iterator value (i.e., an encoded BlockHandle) // into an iterator over the contents of the corresponding block. 
Iterator* Table::BlockReader(void* arg, const StringPiece& index_value) { Table* table = reinterpret_cast(arg); - // Cache* block_cache = table->rep_->options.block_cache; + Cache* block_cache = table->rep_->options.block_cache; Block* block = nullptr; - // Cache::Handle* cache_handle = NULL; + Cache::Handle* cache_handle = NULL; BlockHandle handle; StringPiece input = index_value; @@ -105,16 +114,38 @@ Iterator* Table::BlockReader(void* arg, const StringPiece& index_value) { if (s.ok()) { BlockContents contents; - s = ReadBlock(table->rep_->file, handle, &contents); - if (s.ok()) { - block = new Block(contents); + if (block_cache != nullptr) { + char cache_key_buffer[16]; + core::EncodeFixed64(cache_key_buffer, table->rep_->cache_id); + core::EncodeFixed64(cache_key_buffer + 8, handle.offset()); + absl::string_view key(cache_key_buffer, sizeof(cache_key_buffer)); + cache_handle = block_cache->Lookup(key); + if (cache_handle != nullptr) { + block = reinterpret_cast(block_cache->Value(cache_handle)); + } else { + s = ReadBlock(table->rep_->file, handle, &contents); + if (s.ok()) { + block = new Block(contents); + cache_handle = block_cache->Insert(key, block, block->size(), + &DeleteCachedBlock); + } + } + } else { + s = ReadBlock(table->rep_->file, handle, &contents); + if (s.ok()) { + block = new Block(contents); + } } } Iterator* iter; if (block != nullptr) { iter = block->NewIterator(); - iter->RegisterCleanup(&DeleteBlock, block, nullptr); + if (cache_handle == nullptr) { + iter->RegisterCleanup(&DeleteBlock, block, nullptr); + } else { + iter->RegisterCleanup(&ReleaseBlock, block_cache, cache_handle); + } } else { iter = NewErrorIterator(s); } diff --git a/tensorflow/core/lib/io/table_options.h b/tensorflow/core/lib/io/table_options.h index 9a36bf16315..d1b43ae7d43 100644 --- a/tensorflow/core/lib/io/table_options.h +++ b/tensorflow/core/lib/io/table_options.h @@ -21,6 +21,8 @@ limitations under the License. namespace tensorflow { namespace table { +class Cache; + // DB contents are stored in a set of blocks, each of which holds a // sequence of key,value pairs. Each block may be compressed before // being stored in a file. The following enum describes which @@ -60,6 +62,12 @@ struct Options { // incompressible, the kSnappyCompression implementation will // efficiently detect that and will switch to uncompressed mode. CompressionType compression = kSnappyCompression; + + // Control over blocks (user data is stored in a set of blocks, and + // a block is the unit of reading from disk). + + // If non-null, use the specified cache for blocks. + Cache* block_cache = nullptr; }; } // namespace table diff --git a/tensorflow/core/util/tensor_bundle/tensor_bundle.cc b/tensorflow/core/util/tensor_bundle/tensor_bundle.cc index 3ebb2ea5dc8..339c28dfb66 100644 --- a/tensorflow/core/util/tensor_bundle/tensor_bundle.cc +++ b/tensorflow/core/util/tensor_bundle/tensor_bundle.cc @@ -40,6 +40,7 @@ limitations under the License. 
#include "tensorflow/core/lib/io/table_builder.h" #include "tensorflow/core/lib/random/random.h" #include "tensorflow/core/lib/strings/stringprintf.h" +#include "tensorflow/core/util/env_var.h" #include "tensorflow/core/util/saved_tensor_slice_util.h" #include "tensorflow/core/util/tensor_bundle/byte_swap.h" #include "tensorflow/core/util/tensor_slice_util.h" @@ -729,6 +730,7 @@ BundleReader::BundleReader(Env* env, StringPiece prefix) prefix_(prefix), metadata_(nullptr), table_(nullptr), + index_cache_(nullptr), iter_(nullptr), need_to_swap_bytes_(false) { const string filename = MetaFilename(prefix_); @@ -741,7 +743,17 @@ BundleReader::BundleReader(Env* env, StringPiece prefix) status_ = env_->NewRandomAccessFile(filename, &wrapper); if (!status_.ok()) return; metadata_ = wrapper.release(); - status_ = table::Table::Open(table::Options(), metadata_, file_size, &table_); + + table::Options o; + int64 cache_size; + Status s = + ReadInt64FromEnvVar("TF_TABLE_INDEX_CACHE_SIZE_IN_MB", 0, &cache_size); + if (s.ok() && cache_size > 0) { + index_cache_ = table::NewLRUCache(cache_size << 20); + o.block_cache = index_cache_; + } + + status_ = table::Table::Open(o, metadata_, file_size, &table_); if (!status_.ok()) return; iter_ = table_->NewIterator(); @@ -772,6 +784,9 @@ BundleReader::~BundleReader() { delete metadata_; delete iter_; delete table_; + if (index_cache_) { + delete index_cache_; + } // InputBuffer does not own the underlying RandomAccessFile. for (auto pair : data_) { if (pair.second != nullptr && pair.second->file() != nullptr) { diff --git a/tensorflow/core/util/tensor_bundle/tensor_bundle.h b/tensorflow/core/util/tensor_bundle/tensor_bundle.h index e1f39eccd17..c362dd41151 100644 --- a/tensorflow/core/util/tensor_bundle/tensor_bundle.h +++ b/tensorflow/core/util/tensor_bundle/tensor_bundle.h @@ -61,8 +61,6 @@ limitations under the License. #ifndef TENSORFLOW_CORE_UTIL_TENSOR_BUNDLE_TENSOR_BUNDLE_H_ #define TENSORFLOW_CORE_UTIL_TENSOR_BUNDLE_TENSOR_BUNDLE_H_ -#include "tensorflow/core/protobuf/tensor_bundle.pb.h" - #include #include #include @@ -72,12 +70,14 @@ limitations under the License. #include "tensorflow/core/framework/tensor_slice.h" #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/lib/gtl/array_slice.h" +#include "tensorflow/core/lib/io/cache.h" #include "tensorflow/core/lib/io/inputbuffer.h" #include "tensorflow/core/lib/io/table.h" #include "tensorflow/core/platform/env.h" #include "tensorflow/core/platform/file_system.h" #include "tensorflow/core/platform/macros.h" #include "tensorflow/core/platform/types.h" +#include "tensorflow/core/protobuf/tensor_bundle.pb.h" #include "tensorflow/core/util/tensor_bundle/naming.h" #include "tensorflow/core/util/tensor_slice_set.h" @@ -288,6 +288,7 @@ class BundleReader { Status status_; RandomAccessFile* metadata_; // Owned. table::Table* table_; + table::Cache* index_cache_; table::Iterator* iter_; // Owned the InputBuffer objects and their underlying RandomAccessFile's. std::unordered_map data_;