/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
|
|
|
|
Licensed under the Apache License, Version 2.0 (the "License");
|
|
you may not use this file except in compliance with the License.
|
|
You may obtain a copy of the License at
|
|
|
|
http://www.apache.org/licenses/LICENSE-2.0
|
|
|
|
Unless required by applicable law or agreed to in writing, software
|
|
distributed under the License is distributed on an "AS IS" BASIS,
|
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
See the License for the specific language governing permissions and
|
|
limitations under the License.
|
|
==============================================================================*/
|
|
|
|
#include "tensorflow/compiler/xrt/xrt_compilation_cache.h"
|
|
|
|
#include <stdlib.h>
|
|
|
|
#include <string>
|
|
|
|
#include "absl/synchronization/mutex.h"
|
|
#include "tensorflow/compiler/xla/client/local_client.h"
|
|
#include "tensorflow/core/lib/core/errors.h"
|
|
#include "tensorflow/core/lib/random/random.h"
|
|
|
|
namespace tensorflow {
|
|
|
|
namespace {
|
|
|
|
// Produces a random, non-negative identifier for a new cache entry. The
// mask clears the sign bit so the value is always representable as a
// positive int64.
int64 get_uid() {
  return static_cast<int64>(random::New64() & INT64_MAX);
}
|
|
|
|
// Returns the maximum number of entries the compilation cache may hold.
//
// Reads the TF_XRT_COMPILATION_CACHE_SIZE environment variable and falls
// back to a default of 1024 entries when the variable is unset or does not
// parse as a base-10 integer. (The previous std::stol call would throw --
// and, with exceptions disabled, abort the process -- on a malformed value.)
int64 GetCompilationCacheSizeFromEnv() {
  constexpr int64 kDefaultCacheSize = 1024;
  const char* env = getenv("TF_XRT_COMPILATION_CACHE_SIZE");
  if (env == nullptr) {
    return kDefaultCacheSize;
  }
  char* end = nullptr;
  const long parsed = strtol(env, &end, 10);  // NOLINT(runtime/int)
  if (end == env || *end != '\0') {
    // Not a number (or trailing garbage): ignore the variable rather than
    // crashing during resource creation.
    return kDefaultCacheSize;
  }
  return static_cast<int64>(parsed);
}
|
|
|
|
} // namespace
|
|
|
|
// Name under which the cache is registered in the ResourceMgr's default
// container (see GetOrCreateCompilationCache below).
const char* kXRTCompilationCacheResourceName = "xrt_compilation_cache";
|
|
|
|
// Takes out a client reference on `entry`: bumps the entry's refcount so the
// underlying compiled program stays alive for as long as this ref does.
XRTCompilationCache::EntryRefImpl::EntryRefImpl(XRTCompilationCache* parent,
                                                CompiledSubgraph* entry)
    : parent_(parent), entry_(entry) {
  entry_->Ref();
}
|
|
|
|
// Returns the reference taken in the constructor to the parent cache, which
// may delete the entry if this was the last outstanding reference.
XRTCompilationCache::EntryRefImpl::~EntryRefImpl() {
  parent_->DiscardEntryRef(entry_);
}
|
|
|
|
// Exposes the compiled executable as a non-owning view; the raw pointer is
// only valid while this reference object is alive.
XRTCompilationCacheEntry XRTCompilationCache::EntryRefImpl::get() {
  return XRTCompilationCacheEntry(entry_->program.get());
}
|
|
|
|
// Constructs a cache that keeps at most `max_number_of_entries` entries that
// are not marked for eviction (see LookupEntryMarkedForEviction).
XRTCompilationCache::XRTCompilationCache(int max_number_of_entries)
    : max_cache_entries_(max_number_of_entries) {
  // A negative limit is a caller bug; zero is handled upstream by
  // GetOrCreateCompilationCache, which substitutes the env-derived default.
  CHECK_GE(max_cache_entries_, 0);
  VLOG(1) << "Created compilation cache max " << max_cache_entries_
          << " entries.";
}
|
|
|
|
// Tears down the cache. By this point no well-behaved client should hold
// references, but any that remain are defensively stripped so storage is not
// leaked.
XRTCompilationCache::~XRTCompilationCache() {
  VLOG(1) << "XRTCompilationCache::~XRTCompilationCache()";
  // A buggy client may be holding onto a reference, or a client might have
  // crashed while holding onto a reference. In either case, discard all
  // outstanding client references to avoid leaking storage.
  for (const auto& entry : entries_by_uid_) {
    while (!entry.second->RefCountIsOne()) {
      entry.second->Unref();
    }
  }
  // Drop the cache's own reference to every entry still in the LRU table;
  // DiscardEntryRefLocked then removes each entry from cache_ and
  // entries_by_uid_ and deletes it on the final Unref.
  while (!entries_by_last_use_.empty()) {
    MarkOldestEntryForEviction();
  }
  // All bookkeeping should now be empty; these CHECKs catch reference-count
  // accounting bugs elsewhere in the class.
  CHECK_EQ(cache_.size(), 0);
  CHECK_EQ(entries_by_uid_.size(), 0);
  CHECK_EQ(cache_entries_, 0);
  CHECK_EQ(marked_for_eviction_entries_, 0);
}
|
|
|
|
// Releases the caller-owned reference on the entry identified by `uid` --
// the reference handed out by CompileIfKeyAbsent. Returns NotFound if no
// entry with that uid exists.
Status XRTCompilationCache::Release(int64 uid) {
  absl::MutexLock lock(&mu_);
  auto iter = entries_by_uid_.find(uid);

  if (iter == entries_by_uid_.end()) {
    return errors::NotFound("No cache entry found for uid ", uid);
  }

  // May delete the entry outright if this was its last reference.
  DiscardEntryRefLocked(iter->second);

  VLOG(1) << "After releasing entry " << uid << " refs cache is "
          << cache_.size() << " entries ("
          << cache_entries_ + marked_for_eviction_entries_
          << "), marked for eviction "
          << (cache_.size() - entries_by_last_use_.size()) << " entries ("
          << marked_for_eviction_entries_ << ").";

  return Status::OK();
}
|
|
|
|
// Acquires the cache lock and forwards to DiscardEntryRefLocked. Called by
// EntryRefImpl's destructor when a client drops its reference.
void XRTCompilationCache::DiscardEntryRef(CompiledSubgraph* entry) {
  absl::MutexLock lock(&mu_);
  DiscardEntryRefLocked(entry);
}
|
|
|
|
// Drops one reference to `entry` with mu_ held. When the final reference is
// going away, the entry is first unlinked from cache_ and entries_by_uid_ so
// a later lookup cannot resurrect it, then deleted by the last Unref().
void XRTCompilationCache::DiscardEntryRefLocked(CompiledSubgraph* entry) {
  if (entry->RefCountIsOne()) {
    // The last reference to this entry is going away, so really delete it from
    // the cache in such a way that it can't be restored by being looked up
    // again.

    // Sanity-check that it has been marked for eviction: an entry whose only
    // remaining reference is the cache's must not still be in the LRU table.
    CHECK(entries_by_last_use_.find(entry->last_use) ==
          entries_by_last_use_.end());
    // Update the counter tracking how much space is taken up by entries that
    // are marked for eviction.
    --marked_for_eviction_entries_;

    // Remove the entry from the cache.
    auto erased = cache_.erase(entry->key);
    if (erased == 0) {
      LOG(FATAL) << "Tried to discard nonexistent cache entry";
    }
    erased = entries_by_uid_.erase(entry->uid);
    CHECK_EQ(erased, 1);
  }
  entry->Unref();
}
|
|
|
|
// Transfers the least-recently-used live entry into the 'marked for
// eviction' state: it leaves the LRU table, the entry counters are moved
// over, and the cache's reference to it is dropped. Requires
// entries_by_last_use_ to be non-empty. Called with mu_ held (or from the
// destructor, when no other thread can touch the cache).
void XRTCompilationCache::MarkOldestEntryForEviction() {
  // Smallest last_use value == least recently used.
  CompiledSubgraph* entry_to_mark = entries_by_last_use_.begin()->second;
  VLOG(1) << "Marking " << entry_to_mark->key << " for eviction";
  entries_by_last_use_.erase(entry_to_mark->last_use);
  --cache_entries_;
  ++marked_for_eviction_entries_;
  // Discard the cache's reference to entry. If steps are holding onto
  // references to entry it won't be deleted until the last step holding it
  // completes. It stays in the cache in the meantime and can be resurrected
  // by a call to CompileIfKeyAbsent if that occurs before the last reference
  // expires.
  DiscardEntryRefLocked(entry_to_mark);
}
|
|
|
|
// Re-activates `entry`, which was previously marked for eviction (or was
// newly created in that state), after it has been looked up again. Called
// with mu_ held.
void XRTCompilationCache::LookupEntryMarkedForEviction(
    CompiledSubgraph* entry) {
  // The entry was previously marked for eviction (or is newly created) so
  // unmark it. Add a reference (owned by the cache), update the cache size, and
  // mark something old for eviction if necessary.
  entry->Ref();
  --marked_for_eviction_entries_;
  ++cache_entries_;

  // Mark the least-recently-used non-marked entry for eviction. Never mark the
  // most-recently used entry (i.e., do nothing if entries_by_last_use_ == 1
  // which means there's only one entry not already marked for eviction), so
  // that an entry persists in the cache even if it is larger than the allocated
  // cache size.
  while (entries_by_last_use_.size() > 1 &&
         cache_entries_ > max_cache_entries_) {
    MarkOldestEntryForEviction();
  }
}
|
|
|
|
// Creates, inserts, and compiles a new cache entry for `key`. Called with
// mu_ held; the lock is temporarily released while `initialize_program`
// runs so other cache operations can proceed during the (potentially slow)
// compilation. The returned entry is left in the 'marked for eviction' state
// and carries one reference owned by the caller.
XRTCompilationCache::CompiledSubgraph* XRTCompilationCache::InitializeEntry(
    const string& key,
    const std::function<Status(std::unique_ptr<xla::LocalExecutable>*)>&
        initialize_program) {
  CompiledSubgraph* entry = new CompiledSubgraph();
  entry->parent = this;
  entry->key = key;
  entry->uid = get_uid();
  // Add the entry to the cache. Once the computation has been compiled,
  // UpdateEntryAfterCompilation will be called to potentially mark old entries
  // that don't fit any more for eviction.
  //
  // At this point there is one reference to entry, which is owned by the caller
  // who created the entry. A second reference, owned by the cache, will be
  // added below since we leave the entry in the 'marked for eviction' state
  // here.
  auto cache_inserted =
      cache_.insert(std::pair<string, CompiledSubgraph*>(key, entry));
  CHECK(cache_inserted.second);

  // Initialize the program outside the lock so that other cache operations
  // can proceed during the (potentially lengthy) initialization. While mu_ is
  // released the entry is already visible in cache_ with initialized == false;
  // concurrent CompileIfKeyAbsent calls that find it wait on that flag.
  Status s;
  std::unique_ptr<xla::LocalExecutable> program;
  {
    mu_.Unlock();
    { s = initialize_program(&program); }
    mu_.Lock();
  }

  // Add the entry to the uid index. Done after the lock is re-acquired, so
  // uid-based operations (Release/Lookup) cannot see the entry before
  // initialization has finished.
  auto uid_inserted = entries_by_uid_.insert(
      std::pair<int64, CompiledSubgraph*>(entry->uid, entry));
  CHECK(uid_inserted.second);

  // Publish the result. A failed compile leaves program unset and records
  // the error, which callers observe via entry->initialization_status.
  entry->initialized = true;
  entry->initialization_status = s;
  if (s.ok()) {
    entry->program = std::move(program);
  }
  // Add the entry to marked_for_eviction_entries_ since it will be adjusted
  // down again when the newly-created entry gets unmarked.
  ++marked_for_eviction_entries_;
  return entry;
}
|
|
|
|
// Returns (via *uid) the cache entry for `key`, compiling one with
// `compile_function` if it is not already present, and refreshes the entry's
// LRU position. The entry gains one reference owned by the caller, who must
// balance it with Release(*uid). The returned Status is the entry's
// compilation status, so a failed compile is reported to every caller that
// looks the key up.
Status XRTCompilationCache::CompileIfKeyAbsent(
    const string& key, int64* uid,
    const std::function<Status(std::unique_ptr<xla::LocalExecutable>*)>&
        compile_function) {
  CompiledSubgraph* entry = nullptr;

  absl::MutexLock lock(&mu_);
  auto iter = cache_.find(key);

  if (iter == cache_.end()) {
    // The single ref on the newly-created entry is owned by the caller.
    VLOG(1) << "Before adding new entry for key " << key << " cache is "
            << cache_.size() << " entries ("
            << cache_entries_ + marked_for_eviction_entries_ << "), "
            << " marked for eviction "
            << (cache_.size() - entries_by_last_use_.size()) << " entries ("
            << marked_for_eviction_entries_ << ").";
    // InitializeEntry drops and re-acquires mu_ around the compile itself.
    entry = InitializeEntry(key, compile_function);
  } else {
    VLOG(1) << "Before refreshing entry for key " << key << " cache is "
            << cache_.size() << " entries ("
            << cache_entries_ + marked_for_eviction_entries_ << "), "
            << " marked for eviction "
            << (cache_.size() - entries_by_last_use_.size()) << " entries ("
            << marked_for_eviction_entries_ << ").";
    entry = iter->second;
    // Make a new reference that is owned by the caller.
    entry->Ref();
    // Block if necessary until the subgraph has been initialized. The entry
    // may have been inserted by a concurrent call whose compile is still in
    // flight (see InitializeEntry).
    mu_.Await(absl::Condition(
        +[](CompiledSubgraph* e) { return e->initialized; }, entry));
  }

  // Let the caller know the uid of the entry.
  *uid = entry->uid;

  // Remove the old LRU-table entry if it wasn't already marked for eviction.
  auto erased = entries_by_last_use_.erase(entry->last_use);
  // Update the LRU table indicating this entry is the most recently used.
  entry->last_use = use_counter_++;
  entries_by_last_use_[entry->last_use] = entry;
  if (erased == 0) {
    // The entry had been marked for eviction, or is newly created.
    LookupEntryMarkedForEviction(entry);
  }

  VLOG(1) << "After refreshing entry for key " << key << " cache is "
          << cache_.size() << " entries ("
          << cache_entries_ + marked_for_eviction_entries_ << "), "
          << " marked for eviction "
          << (cache_.size() - entries_by_last_use_.size()) << " entries ("
          << marked_for_eviction_entries_ << ").";

  return entry->initialization_status;
}
|
|
|
|
// Hands out a new client reference to the entry identified by `uid`, or
// NotFound if no such entry exists. On success *entry owns the reference;
// its destruction returns the reference to the cache.
Status XRTCompilationCache::Lookup(
    int64 uid, std::unique_ptr<XRTCompilationCacheEntryRef>* entry) {
  // Clear the output first so a failed lookup never leaves a stale ref.
  entry->reset();

  absl::MutexLock lock(&mu_);
  const auto it = entries_by_uid_.find(uid);
  if (it == entries_by_uid_.end()) {
    return errors::NotFound("No executable found for uid ", uid);
  }
  // EntryRefImpl's constructor takes the reference on the entry.
  entry->reset(new EntryRefImpl(this, it->second));
  return Status::OK();
}
|
|
|
|
// Identifies this resource type in ResourceMgr debug output.
string XRTCompilationCache::DebugString() const {
  return "XRTCompilationCache";
}
|
|
|
|
// Fetches the shared XRT compilation cache from `rm`'s default container,
// creating it on first use. Passing 0 for max_number_of_entries means "use
// the TF_XRT_COMPILATION_CACHE_SIZE environment variable's value" (see
// GetCompilationCacheSizeFromEnv).
xla::StatusOr<RefPtr<XRTCompilationCache>> GetOrCreateCompilationCache(
    ResourceMgr* rm, int64 max_number_of_entries) {
  const int64 cache_size = max_number_of_entries != 0
                               ? max_number_of_entries
                               : GetCompilationCacheSizeFromEnv();
  XRTCompilationCache* cache;
  TF_RETURN_IF_ERROR(rm->LookupOrCreate<XRTCompilationCache>(
      rm->default_container(), kXRTCompilationCacheResourceName, &cache,
      // Only runs when the resource does not already exist.
      [cache_size](XRTCompilationCache** new_cache) {
        *new_cache = new XRTCompilationCache(cache_size);
        return Status::OK();
      }));
  return RefPtr<XRTCompilationCache>(cache);
}
|
|
|
|
} // namespace tensorflow
|