STT-tensorflow/tensorflow/compiler/xrt/xrt_memory_manager.cc


/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/xrt/xrt_memory_manager.h"
#include <algorithm>
#include <list>
#include <unordered_map>
#include "absl/memory/memory.h"
#include "tensorflow/compiler/xrt/xrt_metrics.h"
#include "tensorflow/core/lib/monitoring/timed.h"
#include "tensorflow/core/lib/random/random.h"
#include "tensorflow/core/profiler/lib/traceme.h"
namespace tensorflow {
namespace {
// We use kDeviceBits bits to store the device ordinal in the handle. We store
// the device in the upper part of the int64 handle so that the random bits end
// up in the lower part, which makes the handle a better key for unordered
// maps.
const int kDeviceBits = 12;
int64 MakeDeviceHandle(int64 device_ordinal, int64 rnd_value) {
const int64 kUidMask = (static_cast<int64>(1) << (64 - kDeviceBits)) - 1;
return (device_ordinal << (64 - kDeviceBits)) | (rnd_value & kUidMask);
}
int GetDeviceFromHandle(int64 handle) {
return (handle >> (64 - kDeviceBits)) & ((1 << kDeviceBits) - 1);
}
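// Illustrative sketch (not from the original file): with kDeviceBits == 12,
// the device ordinal lives in bits [52, 63] of the handle and a 52-bit random
// uid lives in bits [0, 51]. For a hypothetical device ordinal of 3:
//
//   int64 handle = MakeDeviceHandle(/*device_ordinal=*/3, /*rnd_value=*/0x1234);
//   // handle == (int64{3} << 52) | 0x1234
//   // GetDeviceFromHandle(handle) == 3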
} // namespace
class XRTMemoryManager::DeviceContext {
struct Alloc {
explicit Alloc(RefPtr<XRTTupleAllocation> tuple)
: tuple(std::move(tuple)) {}
RefPtr<XRTTupleAllocation> tuple;
};
using AllocList = std::list<Alloc>;
public:
int64 Register(RefPtr<XRTTupleAllocation> tuple) {
while (true) {
int64 handle = MakeDeviceHandle(tuple->device_ordinal(), CreateUid());
mutex_lock lock(lock_);
allocs_.emplace_front(tuple);
if (alloc_map_.emplace(handle, allocs_.begin()).second) {
return handle;
}
// The chances of colliding with an existing handle are so remote that it is
// more convenient to add to the list up front and remove the entry again in
// the rare case the map insertion fails.
allocs_.erase(allocs_.begin());
}
}
bool Release(int64 handle) {
mutex_lock lock(lock_);
auto it = alloc_map_.find(handle);
if (it == alloc_map_.end()) {
return false;
}
allocs_.erase(it->second);
alloc_map_.erase(it);
return true;
}
RefPtr<XRTTupleAllocation> Lookup(int64 handle) {
mutex_lock lock(lock_);
auto it = alloc_map_.find(handle);
if (it == alloc_map_.end()) {
return nullptr;
}
// Move the entry to the front of the list to maintain LRU order (most
// recently used first).
allocs_.splice(allocs_.begin(), allocs_, it->second);
return it->second->tuple;
}
void Clear() {
mutex_lock lock(lock_);
alloc_map_.clear();
allocs_.clear();
}
Status CompactAllocations(XRTMemoryManager* memory_manager,
xla::Backend* backend) {
profiler::TraceMe trace_me("XRTMemoryManager::CompactAllocations",
/*level=*/2);
auto timed = monitoring::MakeTimed(xrt_metrics::GetMemoryCompactCell());
VLOG(4) << "CompactAllocations started";
mutex_lock lock(lock_);
Status status;
std::vector<AllocList::iterator> swapped;
// We swap out starting from the most recently used allocations. This is
// desirable since the most recently used allocations tend to sit at the bottom
// of the allocation space. Since those are also more likely to be pinned, a
// further trim by a following TryFreeMemory() call will eventually drop the
// allocations located higher up, with a better chance of reducing
// fragmentation.
// Also, by swapping out the pinned allocations first, they will be the first
// to be restored, so if we ever hit OOM on the way back in, we are more likely
// to be swapping in unpinned ones.
for (auto it = allocs_.begin(); it != allocs_.end(); ++it) {
// We are compacting all the allocations, so we will temporarily swap out
// even pinned allocations.
auto swap_result_or = it->tuple->SwapOut(backend, /*swap_pinned=*/true);
if (!swap_result_or.ok()) {
status = swap_result_or.status();
break;
}
if (swap_result_or.ValueOrDie()) {
swapped.push_back(it);
}
}
// At this point we have released all the device memory we could release.
// Load back the tuple allocations we have swapped out above.
for (auto& it : swapped) {
auto swap_result_or = it->tuple->SwapIn(memory_manager, backend);
if (!swap_result_or.ok()) {
// If we failed to restore a pinned allocation, it is better to CHECK here
// than to wonder later why XRTTupleAllocation calls fail with errors about
// missing buffers.
CHECK(!it->tuple->IsPinned()); // Crash OK
if (status.ok()) {
status = swap_result_or.status();
}
}
}
VLOG(4) << "CompactAllocations finished: " << status;
return status;
}
// Tries to free 'size' bytes by swapping out unpinned device memory. Returns
// the amount of memory it was able to free.
xla::StatusOr<size_t> TryFreeMemory(xla::Backend* backend, size_t size) {
profiler::TraceMe trace_me("XRTMemoryManager::TryFreeMemory", /*level=*/2);
auto timed = monitoring::MakeTimed(xrt_metrics::GetTryFreeMemoryCell());
mutex_lock lock(lock_);
size_t swapped_size = 0;
for (auto it = allocs_.rbegin(); it != allocs_.rend(); ++it) {
TF_ASSIGN_OR_RETURN(bool swap_result,
it->tuple->SwapOut(backend, /*swap_pinned=*/false));
if (swap_result) {
swapped_size += it->tuple->GetDeviceMemorySize();
if (swapped_size >= size) {
break;
}
}
}
VLOG(3) << "Swapped out " << swapped_size << " bytes";
return swapped_size;
}
private:
static int64 CreateUid() {
int64 uid;
do {
uid = random::New64() & INT64_MAX;
} while (uid == InvalidKey());
return uid;
}
// We store the Alloc records inside an std::list<Alloc> so we can maintain
// LRU order, and we store the list iterators in the handle map, since list
// iterators are not invalidated by removals of other elements or by position
// swaps.
mutex lock_;
AllocList allocs_;
std::unordered_map<int64, AllocList::iterator> alloc_map_;
};
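// Illustrative sketch (hypothetical usage, assuming an already populated
// RefPtr<XRTTupleAllocation> named `tuple`; DeviceContext is internal to
// XRTMemoryManager, so this only shows the LRU behavior): the most recently
// used allocations stay at the front of allocs_, so TryFreeMemory(), which
// walks the list from the back, swaps out the least recently used ones first.
//
//   XRTMemoryManager::DeviceContext context;
//   int64 handle = context.Register(tuple);   // inserted at the front (MRU)
//   auto found = context.Lookup(handle);      // moves the entry back to the front
//   CHECK(context.Release(handle));           // drops the record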
XRTMemoryManager::WorkingSet::WorkingSet(
RefPtr<XRTMemoryManager> memory_manager)
: memory_manager_(std::move(memory_manager)) {}
XRTMemoryManager::WorkingSet::~WorkingSet() {
for (auto& tuple : pinned_tuples_) {
tuple->Unpin();
}
}
Status XRTMemoryManager::WorkingSet::LookupAndPin(xla::Backend* backend,
int64 handle) {
TF_ASSIGN_OR_RETURN(auto tuple, memory_manager_->Lookup(handle));
TF_RETURN_IF_ERROR(
tuple->PinAndSwapIn(memory_manager_.get(), backend).status());
pinned_tuples_.push_back(std::move(tuple));
return Status::OK();
}
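// Illustrative usage sketch (hypothetical `backend` and `handle` values): a
// WorkingSet pins every allocation it successfully looks up and unpins all of
// them in its destructor, so callers do not have to track pins manually.
//
//   XRTMemoryManager::WorkingSet working_set(memory_manager);
//   TF_RETURN_IF_ERROR(working_set.LookupAndPin(backend, handle));
//   // ... use the pinned allocation(s) ...
//   // `working_set` going out of scope unpins them.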
/* static */ RefPtr<XRTMemoryManager> XRTMemoryManager::Get(ResourceMgr* rm) {
static string* container = new string("XrtState");
static string* name = new string("MemoryManager");
XRTMemoryManager* memory_manager = nullptr;
TF_CHECK_OK(rm->LookupOrCreate<XRTMemoryManager>(
*container, *name, &memory_manager, [](XRTMemoryManager** ret) {
*ret = new XRTMemoryManager();
return Status::OK();
}));
return memory_manager;
}
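// Illustrative sketch (hypothetical `ctx`, assumed to be an OpKernelContext*):
// Get() returns the XRTMemoryManager shared through the ResourceMgr, creating
// it on first use.
//
//   ResourceMgr* rm = ctx->resource_manager();
//   RefPtr<XRTMemoryManager> memory_manager = XRTMemoryManager::Get(rm);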
int64 XRTMemoryManager::Register(RefPtr<XRTTupleAllocation> tuple) {
DeviceContext* device_context = GetDeviceContext(tuple->device_ordinal(),
/*create_if_missing=*/true);
return device_context->Register(std::move(tuple));
}
xla::StatusOr<RefPtr<XRTTupleAllocation>> XRTMemoryManager::Lookup(
int64 handle) {
int device_ordinal = GetDeviceFromHandle(handle);
DeviceContext* device_context = GetDeviceContext(device_ordinal,
/*create_if_missing=*/false);
if (device_context == nullptr) {
return errors::NotFound("XRT memory handle not found: ", handle);
}
RefPtr<XRTTupleAllocation> tuple = device_context->Lookup(handle);
if (tuple == nullptr) {
return errors::NotFound("XRT memory handle not found: ", handle);
}
return std::move(tuple);
}
Status XRTMemoryManager::Release(int64 handle) {
int device_ordinal = GetDeviceFromHandle(handle);
DeviceContext* device_context = GetDeviceContext(device_ordinal,
/*create_if_missing=*/false);
if (device_context == nullptr || !device_context->Release(handle)) {
return errors::NotFound("XRT memory handle not found: ", handle);
}
return Status::OK();
}
Status XRTMemoryManager::CompactAllocations(xla::Backend* backend,
int device_ordinal) {
DeviceContext* device_context = GetDeviceContext(device_ordinal,
/*create_if_missing=*/false);
return device_context != nullptr
? device_context->CompactAllocations(this, backend)
: Status::OK();
}
void XRTMemoryManager::ReleaseAllAllocations() {
mutex_lock lock(lock_);
for (auto& device_context : device_contexts_) {
if (device_context != nullptr) {
device_context->Clear();
}
}
}
xla::StatusOr<se::OwningDeviceMemory> XRTMemoryManager::Allocate(
xla::Backend* backend, int device_ordinal, size_t size) {
se::DeviceMemoryAllocator* allocator = backend->memory_allocator();
auto memory_or =
allocator->Allocate(device_ordinal, size, /*retry_on_failure=*/false);
if (memory_or.status().code() == error::RESOURCE_EXHAUSTED) {
VLOG(4) << "Allocate of " << size << " bytes failed on device "
<< device_ordinal;
DeviceContext* device_context =
GetDeviceContext(device_ordinal,
/*create_if_missing=*/false);
if (device_context != nullptr) {
Status status = device_context->TryFreeMemory(backend, size).status();
if (status.ok()) {
// As long as there is no error, we retry the allocation even if the
// TryFreeMemory() call ended up freeing less memory than the requested size.
// Because of fragmentation, the allocation could still succeed even though
// the amount of memory freed was smaller.
memory_or = allocator->Allocate(device_ordinal, size,
/*retry_on_failure=*/false);
} else if (status.code() != error::RESOURCE_EXHAUSTED) {
VLOG(4) << "Allocate of " << size << " bytes on device "
<< device_ordinal << ": " << status;
return status;
}
}
}
return memory_or;
}
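// Illustrative sketch (hypothetical `backend`, `device_ordinal` and `size`
// values): callers receive an owning memory handle, and a RESOURCE_EXHAUSTED
// error only if the allocation still fails after the swap-out attempt above.
//
//   TF_ASSIGN_OR_RETURN(
//       se::OwningDeviceMemory memory,
//       memory_manager->Allocate(backend, device_ordinal, size));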
string XRTMemoryManager::DebugString() const {
// We might want to emit more detailed information here, like per-device
// memory allocations.
return "XRTMemoryManager";
}
XRTMemoryManager::DeviceContext* XRTMemoryManager::GetDeviceContext(
int device_ordinal, bool create_if_missing) {
mutex_lock lock(lock_);
if (device_ordinal >= device_contexts_.size()) {
if (!create_if_missing) {
return nullptr;
}
device_contexts_.resize(device_ordinal + 1);
}
DeviceContext* device_context = device_contexts_[device_ordinal].get();
if (device_context == nullptr && create_if_missing) {
device_contexts_[device_ordinal] = absl::make_unique<DeviceContext>();
device_context = device_contexts_[device_ordinal].get();
}
return device_context;
}
Status XRTMemoryManager::TryFreeMemoryStep(MemoryReclaimContext* mrctx,
const Status& status) {
DeviceContext* device_context = GetDeviceContext(mrctx->device_ordinal,
/*create_if_missing=*/false);
if (device_context == nullptr) {
return status;
}
if (!mrctx->done_freeing) {
// If the caller passed us a zero requested_free_size, we try to free chunks
// of kMaxFreeSize memory, until either the run function succeeds, or we run
// out of freeable memory.
const size_t kMaxFreeSize = 1000000000;
size_t free_size =
(mrctx->requested_free_size > 0)
? std::min<size_t>(mrctx->requested_free_size - mrctx->free_size,
kMaxFreeSize)
: kMaxFreeSize;
if (free_size > 0) {
auto free_size_or =
device_context->TryFreeMemory(mrctx->backend, free_size);
if (!free_size_or.ok()) {
return status;
}
size_t size = free_size_or.ValueOrDie();
mrctx->free_size += size;
if (size > 0) {
return Status::OK();
}
}
mrctx->done_freeing = true;
}
if (!mrctx->done_compacting) {
mrctx->done_compacting = true;
if (device_context->CompactAllocations(this, mrctx->backend).ok()) {
return Status::OK();
}
}
return status;
}
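// Illustrative sketch of the retry loop that drives TryFreeMemoryStep()
// (hypothetical: the MemoryReclaimContext constructor arguments and the
// `run_allocation_fn` name are assumptions, not taken from this file). On a
// RESOURCE_EXHAUSTED failure the loop first swaps out unpinned memory and,
// once nothing more can be swapped, compacts the device allocations before
// retrying:
//
//   MemoryReclaimContext mrctx(backend, device_ordinal, requested_free_size);
//   while (true) {
//     auto result_or = run_allocation_fn();  // hypothetical caller function
//     if (result_or.status().code() != error::RESOURCE_EXHAUSTED) {
//       return result_or;
//     }
//     TF_RETURN_IF_ERROR(TryFreeMemoryStep(&mrctx, result_or.status()));
//   }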
} // namespace tensorflow