Improve GPU Timestamped Allocator and KernelTracker.
These are experimental features, off by default. Outside of XLA compiled units TensorFlow does dynamic allocation of memory for GPU kernels and for the destinations of copies from CPU RAM or RDMA from NICs. As an optimization GPU memory is usually allocated "Just Until Queued", i.e. we consider that a GPU kernel has completed as soon as it is queued on the single compute stream of the device. This is safe with respect to use by other kernels on that stream but not with respect to asynchronous i/o. For the latter case we have a convention of forcing the i/o to wait on completion of all pending compute kernels before begining. This is wasteful if the queue is deep and free memory could have been found that is not in potential use by anything in the compute queue. This change upgrades the GPU KernelTracker to allow more flexible tracking strategies. Previously, if turned on it would insert a tracking event after every single op and allow only limiting the number of outstanding ops. This has a significant negative performance impact where many short-duration kernels are queued in series. Now it is possible to insert events at variable strides depending on memory allocated. The BFCAllocator is also upgraded to pursue a more sophisticated approach to handling timestamped free chunks. When timestamping is not done, freed chunks are merged immediately with any free neighbor before returning to the free bins. When it is done they are reinserted in the appropriate free bin without merging, and kept on a list for later merging when their timestamp passes the safe frontier. The GPUKernelTracker informs the BFCAllocator every time the safe frontier advances. The AllocatorAttributes struct has an optional freed_by_func which is not defined for allocations that will be used by GPU kernels but it is defined for some i/o uses. When the incoming allocation request lacks this function and cannot be satified, instead of returning nullptr which could trigger an OOM error, we go ahead and try to merge in any outstanding chunks with unsafe timestamps. Their use is actually safe for kernels and the only downside to doing so is that it may cause a later i/o allocation to stall. PiperOrigin-RevId: 245446915
This commit is contained in:
parent
9c6cc1a077
commit
4c3a1fbaa8
@ -462,11 +462,13 @@ Status XlaDevice::MakeTensorFromProto(const TensorProto& tensor_proto,
|
|||||||
Allocator* allocator = GetAllocatorLocked(alloc_attrs);
|
Allocator* allocator = GetAllocatorLocked(alloc_attrs);
|
||||||
Tensor copy(allocator, parsed.dtype(), parsed.shape());
|
Tensor copy(allocator, parsed.dtype(), parsed.shape());
|
||||||
Notification n;
|
Notification n;
|
||||||
device_context->CopyCPUTensorToDevice(&parsed, this, ©,
|
device_context->CopyCPUTensorToDevice(
|
||||||
|
&parsed, this, ©,
|
||||||
[&n, &status](const Status& s) {
|
[&n, &status](const Status& s) {
|
||||||
status = s;
|
status = s;
|
||||||
n.Notify();
|
n.Notify();
|
||||||
});
|
},
|
||||||
|
true /*sync_dst_compute*/);
|
||||||
n.WaitForNotification();
|
n.WaitForNotification();
|
||||||
*tensor = copy;
|
*tensor = copy;
|
||||||
}
|
}
|
||||||
|
@ -106,7 +106,8 @@ void XlaDeviceContext::CopyTensorInSameDevice(const Tensor* input_tensor,
|
|||||||
void XlaDeviceContext::CopyCPUTensorToDevice(const Tensor* cpu_tensor,
|
void XlaDeviceContext::CopyCPUTensorToDevice(const Tensor* cpu_tensor,
|
||||||
Device* device,
|
Device* device,
|
||||||
Tensor* device_tensor,
|
Tensor* device_tensor,
|
||||||
StatusCallback done) const {
|
StatusCallback done,
|
||||||
|
bool sync_dst_compute) const {
|
||||||
if (cpu_tensor->NumElements() == 0) {
|
if (cpu_tensor->NumElements() == 0) {
|
||||||
VLOG(2) << "CopyCPUTensorToDevice empty tensor";
|
VLOG(2) << "CopyCPUTensorToDevice empty tensor";
|
||||||
done(Status::OK());
|
done(Status::OK());
|
||||||
|
@ -61,8 +61,8 @@ class XlaDeviceContext : public DeviceContext {
|
|||||||
thread::ThreadPool* thread_pool);
|
thread::ThreadPool* thread_pool);
|
||||||
|
|
||||||
void CopyCPUTensorToDevice(const Tensor* cpu_tensor, Device* device,
|
void CopyCPUTensorToDevice(const Tensor* cpu_tensor, Device* device,
|
||||||
Tensor* device_tensor,
|
Tensor* device_tensor, StatusCallback done,
|
||||||
StatusCallback done) const override;
|
bool sync_dst_compute) const override;
|
||||||
void CopyDeviceTensorToCPU(const Tensor* device_tensor,
|
void CopyDeviceTensorToCPU(const Tensor* device_tensor,
|
||||||
absl::string_view tensor_name, Device* device,
|
absl::string_view tensor_name, Device* device,
|
||||||
Tensor* cpu_tensor, StatusCallback done) override;
|
Tensor* cpu_tensor, StatusCallback done) override;
|
||||||
|
@ -165,8 +165,8 @@ class CollectiveAdapterImpl : public CollectiveAdapter {
|
|||||||
return t;
|
return t;
|
||||||
}
|
}
|
||||||
|
|
||||||
Tensor Scalar(Allocator* a) const override {
|
Tensor Scalar(Allocator* a, const AllocationAttributes& attr) const override {
|
||||||
Tensor t(a, dt_, TensorShape({}));
|
Tensor t(a, dt_, TensorShape({}), attr);
|
||||||
return t;
|
return t;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -58,7 +58,8 @@ class CollectiveAdapter {
|
|||||||
|
|
||||||
// Generate a scalar tensor of same DataType and on the same device
|
// Generate a scalar tensor of same DataType and on the same device
|
||||||
// as the backing tensor.
|
// as the backing tensor.
|
||||||
virtual Tensor Scalar(Allocator* a) const = 0;
|
virtual Tensor Scalar(Allocator* a,
|
||||||
|
const AllocationAttributes& attr) const = 0;
|
||||||
|
|
||||||
// Debugging string describing buffer location
|
// Debugging string describing buffer location
|
||||||
virtual string TBounds(const Tensor& t) const = 0;
|
virtual string TBounds(const Tensor& t) const = 0;
|
||||||
|
@ -13,10 +13,10 @@ See the License for the specific language governing permissions and
|
|||||||
limitations under the License.
|
limitations under the License.
|
||||||
==============================================================================*/
|
==============================================================================*/
|
||||||
|
|
||||||
#include <atomic>
|
|
||||||
|
|
||||||
#include "tensorflow/core/common_runtime/bfc_allocator.h"
|
#include "tensorflow/core/common_runtime/bfc_allocator.h"
|
||||||
|
|
||||||
|
#include <atomic>
|
||||||
|
|
||||||
#include "tensorflow/core/common_runtime/allocator_retry.h"
|
#include "tensorflow/core/common_runtime/allocator_retry.h"
|
||||||
#include "tensorflow/core/framework/device_base.h"
|
#include "tensorflow/core/framework/device_base.h"
|
||||||
#include "tensorflow/core/lib/core/bits.h"
|
#include "tensorflow/core/lib/core/bits.h"
|
||||||
@ -159,7 +159,7 @@ bool BFCAllocator::Extend(size_t alignment, size_t rounded_bytes) {
|
|||||||
c->allocation_id = -1;
|
c->allocation_id = -1;
|
||||||
c->prev = kInvalidChunkHandle;
|
c->prev = kInvalidChunkHandle;
|
||||||
c->next = kInvalidChunkHandle;
|
c->next = kInvalidChunkHandle;
|
||||||
c->freed_count = 0;
|
c->freed_at_count = 0;
|
||||||
|
|
||||||
region_manager_.set_handle(c->ptr, h);
|
region_manager_.set_handle(c->ptr, h);
|
||||||
|
|
||||||
@ -184,6 +184,8 @@ BFCAllocator::ChunkHandle BFCAllocator::AllocateChunk() {
|
|||||||
|
|
||||||
void BFCAllocator::DeallocateChunk(ChunkHandle h) {
|
void BFCAllocator::DeallocateChunk(ChunkHandle h) {
|
||||||
Chunk* c = ChunkFromHandle(h);
|
Chunk* c = ChunkFromHandle(h);
|
||||||
|
c->allocation_id = -1;
|
||||||
|
c->bin_num = kInvalidBinNum;
|
||||||
c->next = free_chunks_list_;
|
c->next = free_chunks_list_;
|
||||||
free_chunks_list_ = h;
|
free_chunks_list_ = h;
|
||||||
}
|
}
|
||||||
@ -194,7 +196,7 @@ void* BFCAllocator::AllocateRawInternalWithRetry(
|
|||||||
// Fast path: Try once to allocate without getting the retry_helper_ involved
|
// Fast path: Try once to allocate without getting the retry_helper_ involved
|
||||||
uint64 freed_by_count = 0;
|
uint64 freed_by_count = 0;
|
||||||
if (allocation_attr.freed_by_func != nullptr) {
|
if (allocation_attr.freed_by_func != nullptr) {
|
||||||
freed_by_count = allocation_attr.freed_by_func();
|
freed_by_count = (*allocation_attr.freed_by_func)();
|
||||||
}
|
}
|
||||||
void* r =
|
void* r =
|
||||||
AllocateRawInternal(unused_alignment, num_bytes, false, freed_by_count);
|
AllocateRawInternal(unused_alignment, num_bytes, false, freed_by_count);
|
||||||
@ -206,7 +208,7 @@ void* BFCAllocator::AllocateRawInternalWithRetry(
|
|||||||
[this, &allocation_attr](size_t a, size_t nb, bool v) {
|
[this, &allocation_attr](size_t a, size_t nb, bool v) {
|
||||||
uint64 freed_by_count = 0;
|
uint64 freed_by_count = 0;
|
||||||
if (allocation_attr.freed_by_func != nullptr) {
|
if (allocation_attr.freed_by_func != nullptr) {
|
||||||
freed_by_count = allocation_attr.freed_by_func();
|
freed_by_count = (*allocation_attr.freed_by_func)();
|
||||||
}
|
}
|
||||||
return AllocateRawInternal(a, nb, v, freed_by_count);
|
return AllocateRawInternal(a, nb, v, freed_by_count);
|
||||||
},
|
},
|
||||||
@ -224,7 +226,7 @@ void* BFCAllocator::AllocateRaw(size_t unused_alignment, size_t num_bytes,
|
|||||||
bool dump_log_on_failure = VLOG_IS_ON(2);
|
bool dump_log_on_failure = VLOG_IS_ON(2);
|
||||||
uint64 freed_by_count = 0;
|
uint64 freed_by_count = 0;
|
||||||
if (allocation_attr.freed_by_func != nullptr) {
|
if (allocation_attr.freed_by_func != nullptr) {
|
||||||
freed_by_count = allocation_attr.freed_by_func();
|
freed_by_count = (*allocation_attr.freed_by_func)();
|
||||||
}
|
}
|
||||||
void* result = AllocateRawInternal(unused_alignment, num_bytes,
|
void* result = AllocateRawInternal(unused_alignment, num_bytes,
|
||||||
dump_log_on_failure, freed_by_count);
|
dump_log_on_failure, freed_by_count);
|
||||||
@ -236,6 +238,8 @@ void* BFCAllocator::AllocateRaw(size_t unused_alignment, size_t num_bytes,
|
|||||||
LOG(WARNING)
|
LOG(WARNING)
|
||||||
<< "Allocator (" << Name() << ") ran out of memory trying "
|
<< "Allocator (" << Name() << ") ran out of memory trying "
|
||||||
<< "to allocate " << strings::HumanReadableNumBytes(num_bytes)
|
<< "to allocate " << strings::HumanReadableNumBytes(num_bytes)
|
||||||
|
<< " with freed_by_count=" << freed_by_count
|
||||||
|
|
||||||
<< ". The caller indicates that this is not a failure, but"
|
<< ". The caller indicates that this is not a failure, but"
|
||||||
<< " may mean that there could be performance gains if more"
|
<< " may mean that there could be performance gains if more"
|
||||||
<< " memory were available.";
|
<< " memory were available.";
|
||||||
@ -274,6 +278,10 @@ void* BFCAllocator::AllocateRawInternal(size_t unused_alignment,
|
|||||||
BinNum bin_num = BinNumForSize(rounded_bytes);
|
BinNum bin_num = BinNumForSize(rounded_bytes);
|
||||||
|
|
||||||
mutex_lock l(lock_);
|
mutex_lock l(lock_);
|
||||||
|
if (!timestamped_chunks_.empty()) {
|
||||||
|
// Merge timestamped chunks whose counts have become safe for general use.
|
||||||
|
MergeTimestampedChunks(0);
|
||||||
|
}
|
||||||
void* ptr = FindChunkPtr(bin_num, rounded_bytes, num_bytes, freed_before);
|
void* ptr = FindChunkPtr(bin_num, rounded_bytes, num_bytes, freed_before);
|
||||||
if (ptr != nullptr) {
|
if (ptr != nullptr) {
|
||||||
return ptr;
|
return ptr;
|
||||||
@ -287,13 +295,27 @@ void* BFCAllocator::AllocateRawInternal(size_t unused_alignment,
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if ((freed_before == 0) && (!timestamped_chunks_.empty())) {
|
||||||
|
// We're unable to satisfy an allocation request without a specific
|
||||||
|
// timestamp requirement. Rather than fail, try merging any held-out
|
||||||
|
// timestamped chunks more aggressively until a free chunk of the necessary
|
||||||
|
// size is formed.
|
||||||
|
if (MergeTimestampedChunks(rounded_bytes)) {
|
||||||
|
ptr = FindChunkPtr(bin_num, rounded_bytes, num_bytes, freed_before);
|
||||||
|
if (ptr != nullptr) {
|
||||||
|
return ptr;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// We searched all bins for an existing free chunk to use and
|
// We searched all bins for an existing free chunk to use and
|
||||||
// couldn't find one. This means we must have run out of memory,
|
// couldn't find one. This means we must have run out of memory,
|
||||||
// Dump the memory log for analysis.
|
// Dump the memory log for analysis.
|
||||||
if (dump_log_on_failure) {
|
if (dump_log_on_failure) {
|
||||||
LOG(WARNING) << "Allocator (" << Name() << ") ran out of memory trying "
|
LOG(WARNING) << "Allocator (" << Name() << ") ran out of memory trying "
|
||||||
<< "to allocate " << strings::HumanReadableNumBytes(num_bytes)
|
<< "to allocate " << strings::HumanReadableNumBytes(num_bytes)
|
||||||
<< ". Current allocation summary follows.";
|
<< " (rounded to " << rounded_bytes
|
||||||
|
<< "). Current allocation summary follows.";
|
||||||
DumpMemoryLog(rounded_bytes);
|
DumpMemoryLog(rounded_bytes);
|
||||||
LOG(WARNING) << RenderOccupancy();
|
LOG(WARNING) << RenderOccupancy();
|
||||||
}
|
}
|
||||||
@ -312,7 +334,7 @@ void* BFCAllocator::FindChunkPtr(BinNum bin_num, size_t rounded_bytes,
|
|||||||
const BFCAllocator::ChunkHandle h = (*citer);
|
const BFCAllocator::ChunkHandle h = (*citer);
|
||||||
BFCAllocator::Chunk* chunk = ChunkFromHandle(h);
|
BFCAllocator::Chunk* chunk = ChunkFromHandle(h);
|
||||||
DCHECK(!chunk->in_use());
|
DCHECK(!chunk->in_use());
|
||||||
if (freed_before > 0 && freed_before < chunk->freed_count) {
|
if (freed_before > 0 && freed_before < chunk->freed_at_count) {
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
if (chunk->size >= rounded_bytes) {
|
if (chunk->size >= rounded_bytes) {
|
||||||
@ -378,7 +400,7 @@ void BFCAllocator::SplitChunk(BFCAllocator::ChunkHandle h, size_t num_bytes) {
|
|||||||
new_chunk->allocation_id = -1;
|
new_chunk->allocation_id = -1;
|
||||||
|
|
||||||
// It inherits the freed time.
|
// It inherits the freed time.
|
||||||
new_chunk->freed_count = c->freed_count;
|
new_chunk->freed_at_count = c->freed_at_count;
|
||||||
|
|
||||||
// Maintain the pointers.
|
// Maintain the pointers.
|
||||||
// c <-> c_neighbor becomes
|
// c <-> c_neighbor becomes
|
||||||
@ -414,8 +436,15 @@ void BFCAllocator::DeallocateRawInternal(void* ptr) {
|
|||||||
BFCAllocator::ChunkHandle h = region_manager_.get_handle(ptr);
|
BFCAllocator::ChunkHandle h = region_manager_.get_handle(ptr);
|
||||||
CHECK(h != kInvalidChunkHandle);
|
CHECK(h != kInvalidChunkHandle);
|
||||||
|
|
||||||
|
MarkFree(h);
|
||||||
|
|
||||||
// Consider coalescing it.
|
// Consider coalescing it.
|
||||||
FreeAndMaybeCoalesce(h);
|
if (timing_counter_) {
|
||||||
|
InsertFreeChunkIntoBin(h);
|
||||||
|
timestamped_chunks_.push_back(h);
|
||||||
|
} else {
|
||||||
|
InsertFreeChunkIntoBin(TryToCoalesce(h, false));
|
||||||
|
}
|
||||||
|
|
||||||
if (VLOG_IS_ON(4)) {
|
if (VLOG_IS_ON(4)) {
|
||||||
LOG(INFO) << "F: " << RenderOccupancy();
|
LOG(INFO) << "F: " << RenderOccupancy();
|
||||||
@ -451,7 +480,7 @@ void BFCAllocator::Merge(BFCAllocator::ChunkHandle h1,
|
|||||||
c1->size += c2->size;
|
c1->size += c2->size;
|
||||||
|
|
||||||
// Pick latest free time.
|
// Pick latest free time.
|
||||||
c1->freed_count = std::max(c1->freed_count, c2->freed_count);
|
c1->freed_at_count = std::max(c1->freed_at_count, c2->freed_at_count);
|
||||||
|
|
||||||
DeleteChunk(h2);
|
DeleteChunk(h2);
|
||||||
}
|
}
|
||||||
@ -491,7 +520,7 @@ void BFCAllocator::RemoveFreeChunkFromBin(BFCAllocator::ChunkHandle h) {
|
|||||||
c->bin_num = kInvalidBinNum;
|
c->bin_num = kInvalidBinNum;
|
||||||
}
|
}
|
||||||
|
|
||||||
void BFCAllocator::FreeAndMaybeCoalesce(BFCAllocator::ChunkHandle h) {
|
void BFCAllocator::MarkFree(BFCAllocator::ChunkHandle h) {
|
||||||
Chunk* c = ChunkFromHandle(h);
|
Chunk* c = ChunkFromHandle(h);
|
||||||
CHECK(c->in_use() && (c->bin_num == kInvalidBinNum));
|
CHECK(c->in_use() && (c->bin_num == kInvalidBinNum));
|
||||||
|
|
||||||
@ -500,33 +529,128 @@ void BFCAllocator::FreeAndMaybeCoalesce(BFCAllocator::ChunkHandle h) {
|
|||||||
|
|
||||||
// Optionally record the free time.
|
// Optionally record the free time.
|
||||||
if (timing_counter_) {
|
if (timing_counter_) {
|
||||||
c->freed_count = timing_counter_->next();
|
c->freed_at_count = timing_counter_->next();
|
||||||
}
|
}
|
||||||
|
|
||||||
// Updates the stats.
|
// Updates the stats.
|
||||||
stats_.bytes_in_use -= c->size;
|
stats_.bytes_in_use -= c->size;
|
||||||
|
}
|
||||||
|
|
||||||
|
BFCAllocator::ChunkHandle BFCAllocator::TryToCoalesce(ChunkHandle h,
|
||||||
|
bool ignore_freed_at) {
|
||||||
|
Chunk* c = ChunkFromHandle(h);
|
||||||
|
if ((!ignore_freed_at) && c->freed_at_count > 0) return h;
|
||||||
ChunkHandle coalesced_chunk = h;
|
ChunkHandle coalesced_chunk = h;
|
||||||
|
|
||||||
// If the next chunk is free, merge it into c and delete it.
|
// If the next chunk is free, merge it into c and delete it.
|
||||||
if (c->next != kInvalidChunkHandle && !ChunkFromHandle(c->next)->in_use()) {
|
if (c->next != kInvalidChunkHandle && !ChunkFromHandle(c->next)->in_use()) {
|
||||||
// VLOG(8) << "Merging c->next " << ChunkFromHandle(c->next)->ptr
|
Chunk* n = ChunkFromHandle(c->next);
|
||||||
// << " with c " << c->ptr;
|
if ((n->freed_at_count == 0) || ignore_freed_at) {
|
||||||
|
VLOG(4) << "Merging c->next " << n->ptr << " with c " << c->ptr;
|
||||||
RemoveFreeChunkFromBin(c->next);
|
RemoveFreeChunkFromBin(c->next);
|
||||||
Merge(h, c->next);
|
Merge(h, c->next);
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// If the previous chunk is free, merge c into it and delete c.
|
// If the previous chunk is free, merge c into it and delete c.
|
||||||
if (c->prev != kInvalidChunkHandle && !ChunkFromHandle(c->prev)->in_use()) {
|
if (c->prev != kInvalidChunkHandle && !ChunkFromHandle(c->prev)->in_use()) {
|
||||||
// VLOG(8) << "Merging c " << c->ptr << " into c->prev "
|
Chunk* n = ChunkFromHandle(c->prev);
|
||||||
// << ChunkFromHandle(c->prev)->ptr;
|
if ((n->freed_at_count == 0) || ignore_freed_at) {
|
||||||
|
VLOG(4) << "Merging c " << c->ptr << " into c->prev " << n->ptr;
|
||||||
coalesced_chunk = c->prev;
|
coalesced_chunk = c->prev;
|
||||||
RemoveFreeChunkFromBin(c->prev);
|
RemoveFreeChunkFromBin(c->prev);
|
||||||
Merge(c->prev, h);
|
Merge(c->prev, h);
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
InsertFreeChunkIntoBin(coalesced_chunk);
|
return coalesced_chunk;
|
||||||
|
}
|
||||||
|
|
||||||
|
void BFCAllocator::SetSafeFrontier(uint64 count) {
|
||||||
|
uint64 current = safe_frontier_.load(std::memory_order_relaxed);
|
||||||
|
while (count > current) {
|
||||||
|
if (safe_frontier_.compare_exchange_strong(current, count)) {
|
||||||
|
retry_helper_.NotifyDealloc();
|
||||||
|
return;
|
||||||
|
} else {
|
||||||
|
current = safe_frontier_.load(std::memory_order_relaxed);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
bool BFCAllocator::MergeTimestampedChunks(size_t required_bytes) {
|
||||||
|
VLOG(1) << "MergeTimestampedChunks queue_len=" << timestamped_chunks_.size()
|
||||||
|
<< " required_bytes=" << required_bytes;
|
||||||
|
bool satisfied = (required_bytes == 0);
|
||||||
|
std::vector<void*> to_merge;
|
||||||
|
std::deque<ChunkHandle> new_ts_queue;
|
||||||
|
while (!timestamped_chunks_.empty()) {
|
||||||
|
ChunkHandle h = timestamped_chunks_.front();
|
||||||
|
timestamped_chunks_.pop_front();
|
||||||
|
DCHECK_NE(h, kInvalidChunkHandle);
|
||||||
|
Chunk* c = ChunkFromHandle(h);
|
||||||
|
// It's possible this chunk has already been merged so refetch and retest
|
||||||
|
// the handle.
|
||||||
|
h = region_manager_.get_handle(c->ptr);
|
||||||
|
if (h == kInvalidChunkHandle) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
if (c->in_use() || (c->bin_num == kInvalidBinNum)) {
|
||||||
|
// This chunk has already been reallocated.
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
if (c->freed_at_count == 0) {
|
||||||
|
to_merge.push_back(c->ptr);
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
// Chunk should be free and assigned to a bin.
|
||||||
|
DCHECK_NE(c->bin_num, kInvalidBinNum);
|
||||||
|
if (c->freed_at_count < safe_frontier_) {
|
||||||
|
c->freed_at_count = 0;
|
||||||
|
to_merge.push_back(c->ptr);
|
||||||
|
} else if (required_bytes > 0) {
|
||||||
|
to_merge.push_back(c->ptr);
|
||||||
|
} else {
|
||||||
|
new_ts_queue.push_back(h);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
DCHECK(timestamped_chunks_.empty());
|
||||||
|
std::swap(timestamped_chunks_, new_ts_queue);
|
||||||
|
|
||||||
|
// At this point all candidate chunks have been moved from timestamped_chunks_
|
||||||
|
// to to_merge. If this is a standard merge (required_bytes == 0) then
|
||||||
|
// merge them all, otherwise merge just until a Chunk of the required size
|
||||||
|
// is produced.
|
||||||
|
for (int ci = 0; ci < to_merge.size(); ++ci) {
|
||||||
|
void* ptr = to_merge[ci];
|
||||||
|
// It's possible that the Chunk associated with this memory location got
|
||||||
|
// merged and deallocated in a prior iteration so refetch the handle and
|
||||||
|
// retest.
|
||||||
|
ChunkHandle h = region_manager_.get_handle(ptr);
|
||||||
|
if (h == kInvalidChunkHandle) continue;
|
||||||
|
if (required_bytes == 0 || !satisfied) {
|
||||||
|
Chunk* c = ChunkFromHandle(h);
|
||||||
|
DCHECK_NE(c->bin_num, kInvalidBinNum);
|
||||||
|
DCHECK(!c->in_use());
|
||||||
|
RemoveFreeChunkFromBin(h);
|
||||||
|
ChunkHandle new_h = TryToCoalesce(h, (required_bytes > 0));
|
||||||
|
InsertFreeChunkIntoBin(new_h);
|
||||||
|
if (required_bytes > 0) {
|
||||||
|
c = ChunkFromHandle(new_h);
|
||||||
|
if (new_h != h && c->freed_at_count > 0) {
|
||||||
|
timestamped_chunks_.push_back(new_h);
|
||||||
|
}
|
||||||
|
if (c->size >= required_bytes) {
|
||||||
|
satisfied = true;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
// We were force merging Chunks with unsafe timestamps, but managed
|
||||||
|
// to create a satisfying Chunk so just requeue the rest.
|
||||||
|
timestamped_chunks_.push_back(h);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return satisfied;
|
||||||
}
|
}
|
||||||
|
|
||||||
bool BFCAllocator::TracksAllocationSizes() const { return true; }
|
bool BFCAllocator::TracksAllocationSizes() const { return true; }
|
||||||
@ -667,16 +791,17 @@ void BFCAllocator::DumpMemoryLog(size_t num_bytes) {
|
|||||||
// number by size.
|
// number by size.
|
||||||
std::map<size_t, int> in_use_by_size;
|
std::map<size_t, int> in_use_by_size;
|
||||||
for (const auto& region : region_manager_.regions()) {
|
for (const auto& region : region_manager_.regions()) {
|
||||||
|
LOG(INFO) << "Next region of size " << region.memory_size();
|
||||||
ChunkHandle h = region_manager_.get_handle(region.ptr());
|
ChunkHandle h = region_manager_.get_handle(region.ptr());
|
||||||
while (h != kInvalidChunkHandle) {
|
while (h != kInvalidChunkHandle) {
|
||||||
const Chunk* c = ChunkFromHandle(h);
|
const Chunk* c = ChunkFromHandle(h);
|
||||||
if (c->in_use()) {
|
if (c->in_use()) {
|
||||||
in_use_by_size[c->size]++;
|
in_use_by_size[c->size]++;
|
||||||
}
|
}
|
||||||
LOG(INFO) << (c->in_use() ? "Chunk" : "Free ") << " at " << c->ptr
|
LOG(INFO) << (c->in_use() ? "InUse" : "Free ") << " at " << c->ptr
|
||||||
<< " of size " << c->size
|
<< " next " << c->next << " of size " << c->size
|
||||||
<< (timing_counter_
|
<< (timing_counter_
|
||||||
? strings::StrCat(" freed_count ", c->freed_count)
|
? strings::StrCat(" freed_at_count ", c->freed_at_count)
|
||||||
: "");
|
: "");
|
||||||
h = c->next;
|
h = c->next;
|
||||||
}
|
}
|
||||||
@ -691,6 +816,12 @@ void BFCAllocator::DumpMemoryLog(size_t num_bytes) {
|
|||||||
}
|
}
|
||||||
LOG(INFO) << "Sum Total of in-use chunks: "
|
LOG(INFO) << "Sum Total of in-use chunks: "
|
||||||
<< strings::HumanReadableNumBytes(total_bytes);
|
<< strings::HumanReadableNumBytes(total_bytes);
|
||||||
|
LOG(INFO) << "total_region_allocated_bytes_: "
|
||||||
|
<< total_region_allocated_bytes_
|
||||||
|
<< " memory_limit_: " << memory_limit_ << " available bytes: "
|
||||||
|
<< (memory_limit_ - total_region_allocated_bytes_)
|
||||||
|
<< " curr_region_allocation_bytes_: "
|
||||||
|
<< curr_region_allocation_bytes_;
|
||||||
LOG(INFO) << "Stats: \n" << stats_.DebugString();
|
LOG(INFO) << "Stats: \n" << stats_.DebugString();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -17,6 +17,7 @@ limitations under the License.
|
|||||||
#define TENSORFLOW_CORE_COMMON_RUNTIME_BFC_ALLOCATOR_H_
|
#define TENSORFLOW_CORE_COMMON_RUNTIME_BFC_ALLOCATOR_H_
|
||||||
|
|
||||||
#include <array>
|
#include <array>
|
||||||
|
#include <deque>
|
||||||
#include <memory>
|
#include <memory>
|
||||||
#include <string>
|
#include <string>
|
||||||
#include <unordered_map>
|
#include <unordered_map>
|
||||||
@ -75,6 +76,8 @@ class BFCAllocator : public Allocator {
|
|||||||
|
|
||||||
void SetTimingCounter(SharedCounter* sc) { timing_counter_ = sc; }
|
void SetTimingCounter(SharedCounter* sc) { timing_counter_ = sc; }
|
||||||
|
|
||||||
|
void SetSafeFrontier(uint64 count) override;
|
||||||
|
|
||||||
private:
|
private:
|
||||||
struct Bin;
|
struct Bin;
|
||||||
|
|
||||||
@ -88,6 +91,23 @@ class BFCAllocator : public Allocator {
|
|||||||
|
|
||||||
void DeallocateRawInternal(void* ptr);
|
void DeallocateRawInternal(void* ptr);
|
||||||
|
|
||||||
|
// Chunks whose freed_at_count is later than the safe frontier value are kept
|
||||||
|
// on a special list and not subject to merging immediately upon being freed.
|
||||||
|
//
|
||||||
|
// This function sweeps that list looking for Chunks whose timestamp is now
|
||||||
|
// safe. When found their freed_at_count is set to 0 and we attempt to merge
|
||||||
|
// them with their neighbors.
|
||||||
|
//
|
||||||
|
// If required_bytes > 0 then this function is being called in the context of
|
||||||
|
// a need for this many bytes that could not be satisfied without merging
|
||||||
|
// unsafe chunks, so we go ahead and merge the unsafe chunks too, just up to
|
||||||
|
// the point that a free chunk of required_bytes is produced. Note that
|
||||||
|
// unsafe merged chunks adopt the most conservative timestamp from their
|
||||||
|
// constituents so they're only useful for allocations not requiring a
|
||||||
|
// particular timestamp.
|
||||||
|
bool MergeTimestampedChunks(size_t required_bytes)
|
||||||
|
EXCLUSIVE_LOCKS_REQUIRED(lock_);
|
||||||
|
|
||||||
// A ChunkHandle is an index into the chunks_ vector in BFCAllocator
|
// A ChunkHandle is an index into the chunks_ vector in BFCAllocator
|
||||||
// kInvalidChunkHandle means an invalid chunk
|
// kInvalidChunkHandle means an invalid chunk
|
||||||
typedef size_t ChunkHandle;
|
typedef size_t ChunkHandle;
|
||||||
@ -95,6 +115,7 @@ class BFCAllocator : public Allocator {
|
|||||||
|
|
||||||
typedef int BinNum;
|
typedef int BinNum;
|
||||||
static const int kInvalidBinNum = -1;
|
static const int kInvalidBinNum = -1;
|
||||||
|
// The following means that the largest bin'd chunk size is 256 << 21 = 512MB.
|
||||||
static const int kNumBins = 21;
|
static const int kNumBins = 21;
|
||||||
|
|
||||||
// A Chunk points to a piece of memory that's either entirely free or entirely
|
// A Chunk points to a piece of memory that's either entirely free or entirely
|
||||||
@ -141,7 +162,7 @@ class BFCAllocator : public Allocator {
|
|||||||
BinNum bin_num = kInvalidBinNum;
|
BinNum bin_num = kInvalidBinNum;
|
||||||
|
|
||||||
// Optional count when this chunk was most recently made free.
|
// Optional count when this chunk was most recently made free.
|
||||||
uint64 freed_count = 0;
|
uint64 freed_at_count = 0;
|
||||||
|
|
||||||
bool in_use() const { return allocation_id != -1; }
|
bool in_use() const { return allocation_id != -1; }
|
||||||
|
|
||||||
@ -151,7 +172,7 @@ class BFCAllocator : public Allocator {
|
|||||||
strings::StrAppend(
|
strings::StrAppend(
|
||||||
&dbg, " Size: ", strings::HumanReadableNumBytes(size),
|
&dbg, " Size: ", strings::HumanReadableNumBytes(size),
|
||||||
" | Requested Size: ", strings::HumanReadableNumBytes(requested_size),
|
" | Requested Size: ", strings::HumanReadableNumBytes(requested_size),
|
||||||
" | in_use: ", in_use());
|
" | in_use: ", in_use(), " | bin_num: ", bin_num);
|
||||||
if (recurse && prev != BFCAllocator::kInvalidChunkHandle) {
|
if (recurse && prev != BFCAllocator::kInvalidChunkHandle) {
|
||||||
Chunk* p = a->ChunkFromHandle(prev);
|
Chunk* p = a->ChunkFromHandle(prev);
|
||||||
strings::StrAppend(&dbg, ", prev: ", p->DebugString(a, false));
|
strings::StrAppend(&dbg, ", prev: ", p->DebugString(a, false));
|
||||||
@ -165,6 +186,7 @@ class BFCAllocator : public Allocator {
|
|||||||
};
|
};
|
||||||
|
|
||||||
// A Bin is a collection of similar-sized free chunks.
|
// A Bin is a collection of similar-sized free chunks.
|
||||||
|
// Allocated chunks are never in a Bin.
|
||||||
struct Bin {
|
struct Bin {
|
||||||
// All chunks in this bin have >= bin_size memory.
|
// All chunks in this bin have >= bin_size memory.
|
||||||
size_t bin_size = 0;
|
size_t bin_size = 0;
|
||||||
@ -201,10 +223,13 @@ class BFCAllocator : public Allocator {
|
|||||||
|
|
||||||
// BFCAllocator allocates memory into a collection of disjoint
|
// BFCAllocator allocates memory into a collection of disjoint
|
||||||
// AllocationRegions. Each AllocationRegion corresponds to one call to
|
// AllocationRegions. Each AllocationRegion corresponds to one call to
|
||||||
// SubAllocator::Alloc().
|
// SubAllocator::Alloc(). (Actually, if a subsequent call to
|
||||||
|
// SubAllocator::Alloc() returns another region immediately adjacent to the
|
||||||
|
// last, it will be used to extend the first AllocationRegion, not create a
|
||||||
|
// separate one.)
|
||||||
//
|
//
|
||||||
// An AllocationRegion contains one or more Chunks, covering all of its
|
// An AllocationRegion contains one or more Chunks, covering all of its
|
||||||
// memory. Its primary job is to map a pointers to ChunkHandles.
|
// memory. Its primary job is to map pointers to ChunkHandles.
|
||||||
//
|
//
|
||||||
// This class is thread-compatible.
|
// This class is thread-compatible.
|
||||||
class AllocationRegion {
|
class AllocationRegion {
|
||||||
@ -358,6 +383,8 @@ class BFCAllocator : public Allocator {
|
|||||||
|
|
||||||
// Removes a free chunk from the bin.
|
// Removes a free chunk from the bin.
|
||||||
void RemoveFreeChunkFromBin(ChunkHandle h) EXCLUSIVE_LOCKS_REQUIRED(lock_);
|
void RemoveFreeChunkFromBin(ChunkHandle h) EXCLUSIVE_LOCKS_REQUIRED(lock_);
|
||||||
|
void MaybeRemoveFreeChunkFromBin(ChunkHandle h)
|
||||||
|
EXCLUSIVE_LOCKS_REQUIRED(lock_);
|
||||||
|
|
||||||
// Removes the chunk metadata represented by 'h'.
|
// Removes the chunk metadata represented by 'h'.
|
||||||
void DeleteChunk(ChunkHandle h) EXCLUSIVE_LOCKS_REQUIRED(lock_);
|
void DeleteChunk(ChunkHandle h) EXCLUSIVE_LOCKS_REQUIRED(lock_);
|
||||||
@ -372,6 +399,11 @@ class BFCAllocator : public Allocator {
|
|||||||
const Chunk* ChunkFromHandle(ChunkHandle h) const
|
const Chunk* ChunkFromHandle(ChunkHandle h) const
|
||||||
EXCLUSIVE_LOCKS_REQUIRED(lock_);
|
EXCLUSIVE_LOCKS_REQUIRED(lock_);
|
||||||
|
|
||||||
|
void MarkFree(ChunkHandle h) EXCLUSIVE_LOCKS_REQUIRED(lock_);
|
||||||
|
|
||||||
|
ChunkHandle TryToCoalesce(ChunkHandle h, bool ignore_freed_at)
|
||||||
|
EXCLUSIVE_LOCKS_REQUIRED(lock_);
|
||||||
|
|
||||||
// Information about a Bin that is useful for debugging.
|
// Information about a Bin that is useful for debugging.
|
||||||
struct BinDebugInfo {
|
struct BinDebugInfo {
|
||||||
size_t total_bytes_in_use = 0;
|
size_t total_bytes_in_use = 0;
|
||||||
@ -441,6 +473,9 @@ class BFCAllocator : public Allocator {
|
|||||||
std::unique_ptr<SubAllocator> sub_allocator_;
|
std::unique_ptr<SubAllocator> sub_allocator_;
|
||||||
string name_;
|
string name_;
|
||||||
SharedCounter* timing_counter_ = nullptr;
|
SharedCounter* timing_counter_ = nullptr;
|
||||||
|
std::deque<ChunkHandle> timestamped_chunks_;
|
||||||
|
|
||||||
|
std::atomic<uint64> safe_frontier_ = {0};
|
||||||
|
|
||||||
// Structures mutable after construction
|
// Structures mutable after construction
|
||||||
mutable mutex lock_;
|
mutable mutex lock_;
|
||||||
|
@ -50,7 +50,8 @@ std::vector<RegistrationInfo>* MutableRegistry() {
|
|||||||
void CopyHostToDevice(const Tensor* input, Allocator* cpu_allocator,
|
void CopyHostToDevice(const Tensor* input, Allocator* cpu_allocator,
|
||||||
Allocator* out_allocator, StringPiece edge_name,
|
Allocator* out_allocator, StringPiece edge_name,
|
||||||
Device* dst, Tensor* output,
|
Device* dst, Tensor* output,
|
||||||
DeviceContext* recv_dev_context, StatusCallback done) {
|
DeviceContext* recv_dev_context, StatusCallback done,
|
||||||
|
bool sync_dst_compute) {
|
||||||
if (input->dtype() == DT_VARIANT) {
|
if (input->dtype() == DT_VARIANT) {
|
||||||
Tensor copy(cpu_allocator, DT_VARIANT, input->shape());
|
Tensor copy(cpu_allocator, DT_VARIANT, input->shape());
|
||||||
auto* status_cb = new ReffedStatusCallback(std::move(done));
|
auto* status_cb = new ReffedStatusCallback(std::move(done));
|
||||||
@ -62,13 +63,14 @@ void CopyHostToDevice(const Tensor* input, Allocator* cpu_allocator,
|
|||||||
};
|
};
|
||||||
auto copier = std::bind(
|
auto copier = std::bind(
|
||||||
[dst, recv_dev_context, out_allocator, status_cb, cpu_allocator,
|
[dst, recv_dev_context, out_allocator, status_cb, cpu_allocator,
|
||||||
edge_name](StatusCallback wrapped_done_,
|
edge_name, sync_dst_compute](StatusCallback wrapped_done_,
|
||||||
// Begin unbound arguments
|
// Begin unbound arguments
|
||||||
const Tensor& from, Tensor* to) {
|
const Tensor& from, Tensor* to) {
|
||||||
if (from.dtype() == DT_VARIANT) {
|
if (from.dtype() == DT_VARIANT) {
|
||||||
status_cb->Ref();
|
status_cb->Ref();
|
||||||
CopyHostToDevice(&from, cpu_allocator, out_allocator, edge_name,
|
CopyHostToDevice(&from, cpu_allocator, out_allocator, edge_name,
|
||||||
dst, to, recv_dev_context, wrapped_done_);
|
dst, to, recv_dev_context, wrapped_done_,
|
||||||
|
sync_dst_compute);
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
} else {
|
} else {
|
||||||
if (!DMAHelper::CanUseDMA(&from)) {
|
if (!DMAHelper::CanUseDMA(&from)) {
|
||||||
@ -82,8 +84,8 @@ void CopyHostToDevice(const Tensor* input, Allocator* cpu_allocator,
|
|||||||
if (status_cb->ok()) {
|
if (status_cb->ok()) {
|
||||||
status_cb->Ref();
|
status_cb->Ref();
|
||||||
*to = Tensor(out_allocator, from.dtype(), from.shape());
|
*to = Tensor(out_allocator, from.dtype(), from.shape());
|
||||||
recv_dev_context->CopyCPUTensorToDevice(&from, dst, to,
|
recv_dev_context->CopyCPUTensorToDevice(
|
||||||
wrapped_done_);
|
&from, dst, to, wrapped_done_, sync_dst_compute);
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
} else {
|
} else {
|
||||||
return status_cb->status();
|
return status_cb->status();
|
||||||
@ -107,8 +109,8 @@ void CopyHostToDevice(const Tensor* input, Allocator* cpu_allocator,
|
|||||||
*output = std::move(copy);
|
*output = std::move(copy);
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
recv_dev_context->CopyCPUTensorToDevice(input, dst, output,
|
recv_dev_context->CopyCPUTensorToDevice(input, dst, output, std::move(done),
|
||||||
std::move(done));
|
sync_dst_compute);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -251,7 +253,8 @@ void CopyTensor::ViaDMA(StringPiece edge_name, DeviceContext* send_dev_context,
|
|||||||
Device* dst, const AllocatorAttributes src_alloc_attr,
|
Device* dst, const AllocatorAttributes src_alloc_attr,
|
||||||
const AllocatorAttributes dst_alloc_attr,
|
const AllocatorAttributes dst_alloc_attr,
|
||||||
const Tensor* input, Tensor* output,
|
const Tensor* input, Tensor* output,
|
||||||
int dev_to_dev_stream_index, StatusCallback done) {
|
int dev_to_dev_stream_index, StatusCallback done,
|
||||||
|
bool sync_dst_compute) {
|
||||||
tracing::ScopedAnnotation annotation(edge_name);
|
tracing::ScopedAnnotation annotation(edge_name);
|
||||||
VLOG(1) << "Copy " << edge_name;
|
VLOG(1) << "Copy " << edge_name;
|
||||||
|
|
||||||
@ -304,7 +307,8 @@ void CopyTensor::ViaDMA(StringPiece edge_name, DeviceContext* send_dev_context,
|
|||||||
std::move(done), std::placeholders::_1);
|
std::move(done), std::placeholders::_1);
|
||||||
std::function<void(const Status&)> then_copy_to_other_device = std::bind(
|
std::function<void(const Status&)> then_copy_to_other_device = std::bind(
|
||||||
[delete_and_done, recv_dev_context, cpu_tensor, cpu_allocator,
|
[delete_and_done, recv_dev_context, cpu_tensor, cpu_allocator,
|
||||||
out_allocator, edge_name, dst, output](StatusCallback delete_and_done_,
|
out_allocator, edge_name, dst, output,
|
||||||
|
sync_dst_compute](StatusCallback delete_and_done_,
|
||||||
// Begin unbound arguments.
|
// Begin unbound arguments.
|
||||||
Status status) {
|
Status status) {
|
||||||
if (!status.ok()) {
|
if (!status.ok()) {
|
||||||
@ -313,7 +317,7 @@ void CopyTensor::ViaDMA(StringPiece edge_name, DeviceContext* send_dev_context,
|
|||||||
}
|
}
|
||||||
CopyHostToDevice(cpu_tensor, cpu_allocator, out_allocator, edge_name,
|
CopyHostToDevice(cpu_tensor, cpu_allocator, out_allocator, edge_name,
|
||||||
dst, output, recv_dev_context,
|
dst, output, recv_dev_context,
|
||||||
std::move(delete_and_done_));
|
std::move(delete_and_done_), sync_dst_compute);
|
||||||
},
|
},
|
||||||
std::move(delete_and_done), std::placeholders::_1);
|
std::move(delete_and_done), std::placeholders::_1);
|
||||||
CopyDeviceToHost(input, cpu_allocator, out_allocator, edge_name, src,
|
CopyDeviceToHost(input, cpu_allocator, out_allocator, edge_name, src,
|
||||||
@ -334,7 +338,8 @@ void CopyTensor::ViaDMA(StringPiece edge_name, DeviceContext* send_dev_context,
|
|||||||
if (!non_cpu_src && non_cpu_dst) {
|
if (!non_cpu_src && non_cpu_dst) {
|
||||||
// Host to Device copy.
|
// Host to Device copy.
|
||||||
CopyHostToDevice(input, cpu_allocator, out_allocator, edge_name, dst,
|
CopyHostToDevice(input, cpu_allocator, out_allocator, edge_name, dst,
|
||||||
output, recv_dev_context, std::move(done));
|
output, recv_dev_context, std::move(done),
|
||||||
|
sync_dst_compute);
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -45,7 +45,8 @@ class CopyTensor {
|
|||||||
const AllocatorAttributes src_alloc_attr,
|
const AllocatorAttributes src_alloc_attr,
|
||||||
const AllocatorAttributes dst_alloc_attr,
|
const AllocatorAttributes dst_alloc_attr,
|
||||||
const Tensor* input, Tensor* output,
|
const Tensor* input, Tensor* output,
|
||||||
int dev_to_dev_stream_index, StatusCallback done);
|
int dev_to_dev_stream_index, StatusCallback done,
|
||||||
|
bool sync_dst_compute = true);
|
||||||
|
|
||||||
// Object used to call Register() at static-initialization time.
|
// Object used to call Register() at static-initialization time.
|
||||||
// Note: This should only ever be used as a global-static object; no stack
|
// Note: This should only ever be used as a global-static object; no stack
|
||||||
|
@ -24,10 +24,9 @@ limitations under the License.
|
|||||||
|
|
||||||
#define EIGEN_USE_GPU
|
#define EIGEN_USE_GPU
|
||||||
|
|
||||||
#include "tensorflow/core/common_runtime/gpu/gpu_device.h"
|
|
||||||
|
|
||||||
#include <stdlib.h>
|
#include <stdlib.h>
|
||||||
#include <string.h>
|
#include <string.h>
|
||||||
|
|
||||||
#include <algorithm>
|
#include <algorithm>
|
||||||
#include <list>
|
#include <list>
|
||||||
#include <map>
|
#include <map>
|
||||||
@ -36,6 +35,7 @@ limitations under the License.
|
|||||||
|
|
||||||
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
|
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
|
||||||
#include "tensorflow/core/common_runtime/device_factory.h"
|
#include "tensorflow/core/common_runtime/device_factory.h"
|
||||||
|
#include "tensorflow/core/common_runtime/gpu/gpu_device.h"
|
||||||
#include "tensorflow/core/common_runtime/gpu/gpu_event_mgr.h"
|
#include "tensorflow/core/common_runtime/gpu/gpu_event_mgr.h"
|
||||||
#include "tensorflow/core/common_runtime/gpu/gpu_id.h"
|
#include "tensorflow/core/common_runtime/gpu/gpu_id.h"
|
||||||
#include "tensorflow/core/common_runtime/gpu/gpu_id_manager.h"
|
#include "tensorflow/core/common_runtime/gpu/gpu_id_manager.h"
|
||||||
@ -313,24 +313,6 @@ BaseGPUDevice::BaseGPUDevice(const SessionOptions& options, const string& name,
|
|||||||
sync_every_op_(sync_every_op),
|
sync_every_op_(sync_every_op),
|
||||||
max_streams_(max_streams) {
|
max_streams_(max_streams) {
|
||||||
GPUProcessState::singleton()->EnableGPUDevice();
|
GPUProcessState::singleton()->EnableGPUDevice();
|
||||||
pending_cap_ = options.config.gpu_options().experimental().pending_cap();
|
|
||||||
timestamped_allocator_ =
|
|
||||||
options.config.gpu_options().experimental().timestamped_allocator();
|
|
||||||
if (timestamped_allocator_ || pending_cap_ > 0) {
|
|
||||||
SharedCounter* timing_counter = nullptr;
|
|
||||||
if (timestamped_allocator_) {
|
|
||||||
// In this case the SharedCounter was already created and set in the
|
|
||||||
// associated Allocator, with ownership by GPUProcessState.
|
|
||||||
// The GPUKernelTracker will use this SharedCounter, instead of
|
|
||||||
// owning its own.
|
|
||||||
timing_counter =
|
|
||||||
GPUProcessState::singleton()->GPUAllocatorCounter(tf_gpu_id);
|
|
||||||
DCHECK(timing_counter);
|
|
||||||
} else {
|
|
||||||
DCHECK_GT(pending_cap_, 0);
|
|
||||||
}
|
|
||||||
kernel_tracker_.reset(new GPUKernelTracker(Env::Default(), timing_counter));
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
BaseGPUDevice::~BaseGPUDevice() {
|
BaseGPUDevice::~BaseGPUDevice() {
|
||||||
@ -379,7 +361,6 @@ Status BaseGPUDevice::Init(const SessionOptions& options) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
executor_ = executor_status.ValueOrDie();
|
executor_ = executor_status.ValueOrDie();
|
||||||
em_.reset(new EventMgr(executor_, options.config.gpu_options()));
|
|
||||||
|
|
||||||
if (max_streams_ < 1) {
|
if (max_streams_ < 1) {
|
||||||
return errors::InvalidArgument("Invalid value for max_streams.");
|
return errors::InvalidArgument("Invalid value for max_streams.");
|
||||||
@ -393,6 +374,39 @@ Status BaseGPUDevice::Init(const SessionOptions& options) {
|
|||||||
i, streams_.back()->compute, streams_.back()->host_to_device,
|
i, streams_.back()->compute, streams_.back()->host_to_device,
|
||||||
streams_.back()->device_to_host, streams_.back()->device_to_device));
|
streams_.back()->device_to_host, streams_.back()->device_to_device));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
em_.reset(new EventMgr(executor_, options.config.gpu_options()));
|
||||||
|
|
||||||
|
GPUKernelTracker::Params tracker_params(
|
||||||
|
options.config.gpu_options().experimental().kernel_tracker_max_interval(),
|
||||||
|
options.config.gpu_options().experimental().kernel_tracker_max_bytes(),
|
||||||
|
options.config.gpu_options().experimental().kernel_tracker_max_pending());
|
||||||
|
timestamped_allocator_ =
|
||||||
|
options.config.gpu_options().experimental().timestamped_allocator();
|
||||||
|
pending_cap_ = tracker_params.max_pending;
|
||||||
|
if (timestamped_allocator_ ||
|
||||||
|
(tracker_params.max_interval > 0 || tracker_params.max_bytes > 0 ||
|
||||||
|
tracker_params.max_pending > 0)) {
|
||||||
|
if (max_streams_ > 1) {
|
||||||
|
LOG(FATAL) << "max_streams > 1 was specified together with "
|
||||||
|
"timestamped_allocator and/or kernel tracking. This is an "
|
||||||
|
"unsupported combination.";
|
||||||
|
}
|
||||||
|
SharedCounter* timing_counter = nullptr;
|
||||||
|
if (timestamped_allocator_) {
|
||||||
|
// In this case the SharedCounter was already created and set in the
|
||||||
|
// associated Allocator, with ownership by GPUProcessState.
|
||||||
|
// The GPUKernelTracker will use this SharedCounter, instead of
|
||||||
|
// owning its own.
|
||||||
|
timing_counter =
|
||||||
|
GPUProcessState::singleton()->GPUAllocatorCounter(tf_gpu_id_);
|
||||||
|
DCHECK(timing_counter);
|
||||||
|
}
|
||||||
|
kernel_tracker_.reset(new GPUKernelTracker(
|
||||||
|
tracker_params, Env::Default(), streams_[0]->compute, timing_counter,
|
||||||
|
timestamped_allocator_ ? gpu_allocator_ : nullptr, em_.get()));
|
||||||
|
}
|
||||||
|
|
||||||
gpu_device_info_ = new GpuDeviceInfo;
|
gpu_device_info_ = new GpuDeviceInfo;
|
||||||
gpu_device_info_->stream = streams_[0]->compute;
|
gpu_device_info_->stream = streams_[0]->compute;
|
||||||
gpu_device_info_->default_context = device_contexts_[0];
|
gpu_device_info_->default_context = device_contexts_[0];
|
||||||
@ -569,10 +583,12 @@ void BaseGPUDevice::ComputeHelper(OpKernel* op_kernel,
|
|||||||
if (idc->stream() != stream) stream->ThenWaitFor(idc->stream());
|
if (idc->stream() != stream) stream->ThenWaitFor(idc->stream());
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
if (kernel_tracker_.get()) {
|
||||||
|
context->set_record_memory_consumption(true);
|
||||||
if (pending_cap_ > 0) {
|
if (pending_cap_ > 0) {
|
||||||
DCHECK(kernel_tracker_);
|
|
||||||
kernel_tracker_->PauseWhilePendingExceeds(pending_cap_);
|
kernel_tracker_->PauseWhilePendingExceeds(pending_cap_);
|
||||||
}
|
}
|
||||||
|
}
|
||||||
ScopedActivateExecutorContext scoped_activation{stream->parent()};
|
ScopedActivateExecutorContext scoped_activation{stream->parent()};
|
||||||
op_kernel->Compute(context);
|
op_kernel->Compute(context);
|
||||||
if (context->status().ok()) {
|
if (context->status().ok()) {
|
||||||
@ -593,11 +609,13 @@ void BaseGPUDevice::ComputeHelper(OpKernel* op_kernel,
|
|||||||
if (kernel_tracker_) {
|
if (kernel_tracker_) {
|
||||||
GPUKernelTracker* tracker = kernel_tracker_.get();
|
GPUKernelTracker* tracker = kernel_tracker_.get();
|
||||||
DCHECK(tracker);
|
DCHECK(tracker);
|
||||||
uint64 queued_count = tracker->RecordQueued();
|
uint64 queued_count = tracker->MaybeQueue(context);
|
||||||
em_->ThenExecute(stream, [op_kernel, tracker, queued_count]() {
|
if (queued_count > 0) {
|
||||||
|
em_->ThenExecute(stream, [tracker, queued_count]() {
|
||||||
tracker->RecordTerminated(queued_count);
|
tracker->RecordTerminated(queued_count);
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
}
|
||||||
} else {
|
} else {
|
||||||
if (vlog_1) {
|
if (vlog_1) {
|
||||||
VLOG(1) << "GpuDevice::ComputeHelper failed to schedule "
|
VLOG(1) << "GpuDevice::ComputeHelper failed to schedule "
|
||||||
@ -661,8 +679,17 @@ Status BaseGPUDevice::MaybeCopyTensorToGPU(
|
|||||||
done(err);
|
done(err);
|
||||||
return err;
|
return err;
|
||||||
}
|
}
|
||||||
auto* copy =
|
AllocationAttributes allocation_attr;
|
||||||
new Tensor(GetAllocator(alloc_attrs), from.dtype(), from.shape());
|
uint64 safe_alloc_frontier = 0;
|
||||||
|
std::function<uint64()> freed_by_func = [this, &safe_alloc_frontier]() {
|
||||||
|
safe_alloc_frontier = SafeAllocFrontier(safe_alloc_frontier);
|
||||||
|
return safe_alloc_frontier;
|
||||||
|
};
|
||||||
|
if (timestamped_allocator_) {
|
||||||
|
allocation_attr.freed_by_func = &freed_by_func;
|
||||||
|
}
|
||||||
|
auto* copy = new Tensor(GetAllocator(alloc_attrs), from.dtype(),
|
||||||
|
from.shape(), allocation_attr);
|
||||||
|
|
||||||
// If the tensor is not initialized, we likely ran out of memory.
|
// If the tensor is not initialized, we likely ran out of memory.
|
||||||
if (!copy->IsInitialized()) {
|
if (!copy->IsInitialized()) {
|
||||||
@ -687,8 +714,9 @@ Status BaseGPUDevice::MaybeCopyTensorToGPU(
|
|||||||
std::move(done), std::placeholders::_1);
|
std::move(done), std::placeholders::_1);
|
||||||
|
|
||||||
tracing::ScopedAnnotation annotation("MakeTensorFromProto");
|
tracing::ScopedAnnotation annotation("MakeTensorFromProto");
|
||||||
device_contexts_[0]->CopyCPUTensorToDevice(&from, this, copy,
|
device_contexts_[0]->CopyCPUTensorToDevice(
|
||||||
std::move(wrapped_done));
|
&from, this, copy, std::move(wrapped_done),
|
||||||
|
!timestamped_allocator_ /*sync_dst_compute*/);
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -894,8 +922,8 @@ int64 MinSystemMemory(int64 available_memory) {
|
|||||||
#endif
|
#endif
|
||||||
|
|
||||||
#if defined(ANDROID_TEGRA)
|
#if defined(ANDROID_TEGRA)
|
||||||
// 1GB system mem for NVIDIA Tegra devices since they use the same mem for RAM
|
// 1GB system mem for NVIDIA Tegra devices since they use the same mem for
|
||||||
// and Video RAM
|
// RAM and Video RAM
|
||||||
min_system_memory = 1 << 30;
|
min_system_memory = 1 << 30;
|
||||||
#endif
|
#endif
|
||||||
return min_system_memory;
|
return min_system_memory;
|
||||||
@ -1048,8 +1076,8 @@ Status BaseGPUDeviceFactory::CreateDevices(
|
|||||||
std::vector<PlatformGpuId> valid_platform_gpu_ids;
|
std::vector<PlatformGpuId> valid_platform_gpu_ids;
|
||||||
// If we aren't going to use any GPUs, don't initialize them.
|
// If we aren't going to use any GPUs, don't initialize them.
|
||||||
// We don't want to call ParseVisibleDeviceList if num_gpus_to_use is 0,
|
// We don't want to call ParseVisibleDeviceList if num_gpus_to_use is 0,
|
||||||
// because it treats an empty gpu_options.visible_device_list as 'all GPUs are
|
// because it treats an empty gpu_options.visible_device_list as 'all GPUs
|
||||||
// visible'.
|
// are visible'.
|
||||||
if (num_gpus_to_use > 0) {
|
if (num_gpus_to_use > 0) {
|
||||||
TF_RETURN_IF_ERROR(ParseVisibleDeviceList(gpu_options.visible_device_list(),
|
TF_RETURN_IF_ERROR(ParseVisibleDeviceList(gpu_options.visible_device_list(),
|
||||||
&visible_gpu_order));
|
&visible_gpu_order));
|
||||||
@ -1237,8 +1265,9 @@ static string GetShortDeviceDescription(PlatformGpuId platform_gpu_id,
|
|||||||
cc_minor = 0;
|
cc_minor = 0;
|
||||||
}
|
}
|
||||||
// LINT.IfChange
|
// LINT.IfChange
|
||||||
return strings::StrCat("device: ", platform_gpu_id.value(), ", name: ",
|
return strings::StrCat("device: ", platform_gpu_id.value(),
|
||||||
desc.name(), ", pci bus id: ", desc.pci_bus_id(),
|
", name: ", desc.name(),
|
||||||
|
", pci bus id: ", desc.pci_bus_id(),
|
||||||
", compute capability: ", cc_major, ".", cc_minor);
|
", compute capability: ", cc_major, ".", cc_minor);
|
||||||
// LINT.ThenChange(//tensorflow/python/platform/test.py)
|
// LINT.ThenChange(//tensorflow/python/platform/test.py)
|
||||||
#elif TENSORFLOW_USE_ROCM
|
#elif TENSORFLOW_USE_ROCM
|
||||||
@ -1279,13 +1308,14 @@ Status BaseGPUDeviceFactory::CreateGPUDevice(
|
|||||||
if (!stats) {
|
if (!stats) {
|
||||||
return errors::Internal("No allocator statistics");
|
return errors::Internal("No allocator statistics");
|
||||||
}
|
}
|
||||||
// 'memory_limit' is the required memory size, but if the allocator with given
|
// 'memory_limit' is the required memory size, but if the allocator with
|
||||||
// tf_gpu_id was created before, we'll use it instead of creating a new one
|
// given tf_gpu_id was created before, we'll use it instead of creating a
|
||||||
// (as TF gpu device is a shared resource), in which case the actual memory
|
// new one (as TF gpu device is a shared resource), in which case the actual
|
||||||
// limit represented by 'stats.bytes_limit' used by that allocator may be
|
// memory limit represented by 'stats.bytes_limit' used by that allocator
|
||||||
// different (which should be an error).
|
// may be different (which should be an error).
|
||||||
//
|
//
|
||||||
// TODO(laigd): report error if memory_limit doesn't match stats->bytes_limit.
|
// TODO(laigd): report error if memory_limit doesn't match
|
||||||
|
// stats->bytes_limit.
|
||||||
int64 bytes_limit = stats->bytes_limit ? *stats->bytes_limit : 0;
|
int64 bytes_limit = stats->bytes_limit ? *stats->bytes_limit : 0;
|
||||||
std::unique_ptr<BaseGPUDevice> gpu_device = CreateGPUDevice(
|
std::unique_ptr<BaseGPUDevice> gpu_device = CreateGPUDevice(
|
||||||
options, device_name, static_cast<Bytes>(bytes_limit), dev_locality,
|
options, device_name, static_cast<Bytes>(bytes_limit), dev_locality,
|
||||||
@ -1725,9 +1755,9 @@ Status BaseGPUDeviceFactory::GetValidDeviceIds(
|
|||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
||||||
uint64 BaseGPUDevice::SafeAllocFrontier() {
|
uint64 BaseGPUDevice::SafeAllocFrontier(uint64 old_value) {
|
||||||
if (timestamped_allocator_) {
|
if (timestamped_allocator_) {
|
||||||
return kernel_tracker_->LastTerminatedCount();
|
return kernel_tracker_->LastTerminatedCount(old_value);
|
||||||
} else {
|
} else {
|
||||||
return 0;
|
return 0;
|
||||||
}
|
}
|
||||||
@ -1740,19 +1770,50 @@ int BaseGPUDevice::PendingKernels() {
|
|||||||
return 0;
|
return 0;
|
||||||
}
|
}
|
||||||
|
|
||||||
uint64 GPUKernelTracker::RecordQueued() {
|
uint64 GPUKernelTracker::MaybeQueue(OpKernelContext* ctx) {
|
||||||
mutex_lock l(mu_);
|
mutex_lock l(mu_);
|
||||||
|
++ops_since_last_;
|
||||||
|
int64 mem_used =
|
||||||
|
ctx->persistent_memory_allocated() + ctx->temp_memory_allocated();
|
||||||
|
VLOG(2) << "kernel: " << ctx->op_kernel().name() << " mem_used: " << mem_used;
|
||||||
|
mem_since_last_ += mem_used;
|
||||||
|
int weight = 1;
|
||||||
|
// Note that if all {max_bytes, max_interval, max_pending} are zero then
|
||||||
|
// we we track every single kernel with no pending cap. This can happen
|
||||||
|
// if timestamped_allocator alone was specified.
|
||||||
|
if ((mem_since_last_ < params_.max_bytes) &&
|
||||||
|
(ops_since_last_ < params_.max_interval)) {
|
||||||
|
return 0;
|
||||||
|
} else {
|
||||||
|
weight = std::min(
|
||||||
|
params_.max_pending,
|
||||||
|
std::max(1, mem_since_last_ / std::max(16386, params_.max_bytes)));
|
||||||
|
mem_since_last_ = 0;
|
||||||
|
ops_since_last_ = 0;
|
||||||
|
}
|
||||||
uint64 queued_count = timing_counter_->next();
|
uint64 queued_count = timing_counter_->next();
|
||||||
|
RecordQueued(queued_count, weight);
|
||||||
|
return queued_count;
|
||||||
|
}
|
||||||
|
|
||||||
|
void GPUKernelTracker::RecordQueued(uint64 queued_count, int weight) {
|
||||||
VLOG(2) << "RecordQueued queued_count=" << queued_count
|
VLOG(2) << "RecordQueued queued_count=" << queued_count
|
||||||
<< " first_available_=" << first_available_
|
<< " first_available_=" << first_available_
|
||||||
<< " last_completed_=" << last_completed_
|
<< " last_completed_=" << last_completed_
|
||||||
<< " num_pending_=" << num_pending_;
|
<< " num_pending_=" << num_pending_;
|
||||||
pending_kernels_[first_available_].queued_count = queued_count;
|
pending_kernels_[first_available_].queued_count = queued_count;
|
||||||
|
pending_kernels_[first_available_].weight = weight;
|
||||||
pending_kernels_[first_available_].terminated = false;
|
pending_kernels_[first_available_].terminated = false;
|
||||||
++first_available_;
|
++first_available_;
|
||||||
++num_pending_;
|
num_pending_ += weight;
|
||||||
if (first_available_ >= pending_kernels_.size()) {
|
if (first_available_ >= pending_kernels_.size()) {
|
||||||
|
if (last_completed_ >= 0) {
|
||||||
|
// wrap
|
||||||
first_available_ = 0;
|
first_available_ = 0;
|
||||||
|
} else {
|
||||||
|
// enlarge the ring buffer
|
||||||
|
pending_kernels_.resize(2 * pending_kernels_.size());
|
||||||
|
}
|
||||||
}
|
}
|
||||||
if (first_available_ == last_completed_) {
|
if (first_available_ == last_completed_) {
|
||||||
// Ring buffer is full: double it. All of the same valid PendingKernel
|
// Ring buffer is full: double it. All of the same valid PendingKernel
|
||||||
@ -1771,12 +1832,30 @@ uint64 GPUKernelTracker::RecordQueued() {
|
|||||||
<< " num_pending_=" << num_pending_;
|
<< " num_pending_=" << num_pending_;
|
||||||
}
|
}
|
||||||
DCHECK_NE(first_available_, last_completed_) << "exhausted pending_kernels";
|
DCHECK_NE(first_available_, last_completed_) << "exhausted pending_kernels";
|
||||||
return queued_count;
|
}
|
||||||
|
|
||||||
|
// Called by LastTerminatedCount() when new_value is equal to old_value. This
|
||||||
|
// case can occur where an allocation failed and waited for memory to be freed,
|
||||||
|
// then when it retried the safe allocation frontier had not advanced because no
|
||||||
|
// tracking event had matured. Maybe GPU progress has stalled waiting on an i/o
|
||||||
|
// event, or maybe we're tracking at too infrequent an interval. In any case if
|
||||||
|
// the GPU compute queue is actually empty it's safe to advance the safe
|
||||||
|
// frontier so that this request can allocate from unrestricted (and better
|
||||||
|
// compacted) memory. So queue an event on the compute stream to ensure the
|
||||||
|
// frontier does advance.
|
||||||
|
void GPUKernelTracker::MaybeQueueProgressEvent() {
|
||||||
|
mutex_lock l(mu_);
|
||||||
|
if (num_pending_ == 0) {
|
||||||
|
uint64 new_count = timing_counter_->next();
|
||||||
|
RecordQueued(new_count, 1);
|
||||||
|
em_->ThenExecute(stream_,
|
||||||
|
[this, new_count]() { RecordTerminated(new_count); });
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
void GPUKernelTracker::RecordTerminated(uint64 queued_count) {
|
void GPUKernelTracker::RecordTerminated(uint64 queued_count) {
|
||||||
mutex_lock l(mu_);
|
mutex_lock l(mu_);
|
||||||
VLOG(2) << "RecordTerminated queued_count=" << queued_count
|
VLOG(2) << this << " RecordTerminated queued_count=" << queued_count
|
||||||
<< " first_available_=" << first_available_
|
<< " first_available_=" << first_available_
|
||||||
<< " last_completed_=" << last_completed_
|
<< " last_completed_=" << last_completed_
|
||||||
<< " num_pending_=" << num_pending_ << " LC="
|
<< " num_pending_=" << num_pending_ << " LC="
|
||||||
@ -1788,26 +1867,31 @@ void GPUKernelTracker::RecordTerminated(uint64 queued_count) {
|
|||||||
// Starting just past the last completed entry, find the entry with
|
// Starting just past the last completed entry, find the entry with
|
||||||
// this queued_count and mark it done.
|
// this queued_count and mark it done.
|
||||||
int index = (last_completed_ + 1) % pending_kernels_.size();
|
int index = (last_completed_ + 1) % pending_kernels_.size();
|
||||||
|
int weight = 1;
|
||||||
while (true) {
|
while (true) {
|
||||||
if (index == first_available_) {
|
if (index == first_available_) {
|
||||||
// This should never happen.
|
// This should never happen.
|
||||||
LOG(FATAL) << "Failed to find " << queued_count // Crash OK
|
LOG(FATAL) << "Failed to find " << queued_count // Crash OK
|
||||||
<< " in queue";
|
<< " in queue, last_completed_=" << last_completed_
|
||||||
|
<< " index=" << index
|
||||||
|
<< " first_available_=" << first_available_
|
||||||
|
<< " pending_kernels_.size()=" << pending_kernels_.size();
|
||||||
}
|
}
|
||||||
if (pending_kernels_[index].queued_count == queued_count) {
|
if (pending_kernels_[index].queued_count == queued_count) {
|
||||||
pending_kernels_[index].terminated = true;
|
pending_kernels_[index].terminated = true;
|
||||||
|
weight = pending_kernels_[index].weight;
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
index = (index + 1) % pending_kernels_.size();
|
index = (index + 1) % pending_kernels_.size();
|
||||||
}
|
}
|
||||||
// Next move last_completed_ forward past all completed kernels. In theory
|
// Next move last_completed_ forward past all completed kernels. In theory
|
||||||
// kernels should always complete in queued order so we should be able to
|
// kernels should always complete in queued order so we should be able to
|
||||||
// advance the completed frontier to the last queued PendingKernel. In
|
// advance the completed frontier to the just-completed PendingKernel. In
|
||||||
// practice we occassionally see the termination callbacks arrive out of order
|
// practice we occasionally see the termination callbacks arrive out of
|
||||||
// probably because of thread scheduling. Eventually we may support out-of-
|
// order probably because of thread scheduling. Eventually we may support
|
||||||
// order completion involving multple compute streams so here we follow a
|
// out-of- order completion involving multple compute streams so here we
|
||||||
// conservative approach and wait for every single callback to arrive before
|
// follow a conservative approach and wait for every single callback to
|
||||||
// advancing the frontier.
|
// arrive before advancing the frontier.
|
||||||
while (true) {
|
while (true) {
|
||||||
int next_index = (last_completed_ + 1) % pending_kernels_.size();
|
int next_index = (last_completed_ + 1) % pending_kernels_.size();
|
||||||
if (next_index == first_available_) break;
|
if (next_index == first_available_) break;
|
||||||
@ -1817,21 +1901,16 @@ void GPUKernelTracker::RecordTerminated(uint64 queued_count) {
|
|||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
if (last_completed_ >= 0) {
|
||||||
|
int64 v = pending_kernels_[last_completed_].queued_count;
|
||||||
|
last_terminated_count_ = v;
|
||||||
|
if (allocator_) {
|
||||||
|
allocator_->SetSafeFrontier(v);
|
||||||
|
}
|
||||||
|
}
|
||||||
// Last decrease num_pending before maybe waking a waiter.
|
// Last decrease num_pending before maybe waking a waiter.
|
||||||
--num_pending_;
|
num_pending_ -= weight;
|
||||||
pending_decreased_.notify_one();
|
pending_decreased_.notify_all();
|
||||||
}
|
|
||||||
|
|
||||||
uint64 GPUKernelTracker::LastTerminatedCount() {
|
|
||||||
mutex_lock l(mu_);
|
|
||||||
if (last_completed_ < 0) {
|
|
||||||
// This is an edge case that can be encountered only at the beginning of
|
|
||||||
// execution. There's not yet a safe threshold count. We don't want to
|
|
||||||
// return 0 since that bypasses the count mechanism in BFCAllocator, so
|
|
||||||
// return the least non-zero value.
|
|
||||||
return 1;
|
|
||||||
}
|
|
||||||
return pending_kernels_[last_completed_].queued_count;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace tensorflow
|
} // namespace tensorflow
|
||||||
|
@ -125,7 +125,7 @@ class BaseGPUDevice : public LocalDevice {
|
|||||||
|
|
||||||
// If returned value is > 0 then GPU Memory chunks freed before this count
|
// If returned value is > 0 then GPU Memory chunks freed before this count
|
||||||
// are guaranteed not to be in use by any kernel pending on this device.
|
// are guaranteed not to be in use by any kernel pending on this device.
|
||||||
uint64 SafeAllocFrontier() override;
|
uint64 SafeAllocFrontier(uint64 old_value) override;
|
||||||
|
|
||||||
// Returns the number of kernels that have been queued for execution on
|
// Returns the number of kernels that have been queued for execution on
|
||||||
// the compute stream and are not yet known to have completed.
|
// the compute stream and are not yet known to have completed.
|
||||||
@ -160,7 +160,7 @@ class BaseGPUDevice : public LocalDevice {
|
|||||||
std::unique_ptr<EventMgr> em_;
|
std::unique_ptr<EventMgr> em_;
|
||||||
std::unique_ptr<thread::ThreadPool> thread_pool_;
|
std::unique_ptr<thread::ThreadPool> thread_pool_;
|
||||||
std::unique_ptr<GPUKernelTracker> kernel_tracker_;
|
std::unique_ptr<GPUKernelTracker> kernel_tracker_;
|
||||||
int pending_cap_ = 0;
|
int32 pending_cap_ = 0;
|
||||||
bool timestamped_allocator_ = false;
|
bool timestamped_allocator_ = false;
|
||||||
|
|
||||||
// Initialize scractch buffers used by Eigen.
|
// Initialize scractch buffers used by Eigen.
|
||||||
@ -185,15 +185,43 @@ class BaseGPUDevice : public LocalDevice {
|
|||||||
};
|
};
|
||||||
|
|
||||||
// A per-compute-stream utility that keeps track of kernels that have been
|
// A per-compute-stream utility that keeps track of kernels that have been
|
||||||
// queued for execution but may not yet have terminated, and also the queued
|
// queued for execution but may not yet have terminated and also the queued
|
||||||
// time of the most recently terminated kernel.
|
// time of the most recently terminated kernel.
|
||||||
class GPUKernelTracker {
|
class GPUKernelTracker {
|
||||||
public:
|
public:
|
||||||
|
// Controls the strategy for inserting tracking events after GPU kernels.
|
||||||
|
// If max_interval >= 0, then insert an event after this many kernels
|
||||||
|
// if an event has not been inserted for another reason.
|
||||||
|
// If max_bytes > 0, then insert an event after kernels allocating this
|
||||||
|
// many bytes have been queued since the last event.
|
||||||
|
// If max_pending > 0, then track up to this many events at once. If
|
||||||
|
// this limit is reached the GPU::Compute() method will delay starting
|
||||||
|
// additional ops until some event completes. If 0 and one of the other
|
||||||
|
// fields is non-zero, then a reasonable default will be selected.
|
||||||
|
struct Params {
|
||||||
|
int max_interval = 0;
|
||||||
|
int max_bytes = 0;
|
||||||
|
int max_pending = 0;
|
||||||
|
Params(int mi, int mb, int mp)
|
||||||
|
: max_interval(mi), max_bytes(mb), max_pending(mp) {}
|
||||||
|
};
|
||||||
|
|
||||||
// If we're going to share a SharedCounter with an allocator, it's owned
|
// If we're going to share a SharedCounter with an allocator, it's owned
|
||||||
// by the allocator because allocators are initialized once per process.
|
// by the allocator because allocators are initialized once per process.
|
||||||
// Devices are per-session.
|
// Devices are per-session.
|
||||||
explicit GPUKernelTracker(Env* env, SharedCounter* timing_counter)
|
explicit GPUKernelTracker(const Params& params, Env* env,
|
||||||
: env_(env), timing_counter_(timing_counter), pending_kernels_(64) {
|
se::Stream* compute_stream,
|
||||||
|
SharedCounter* timing_counter, Allocator* allocator,
|
||||||
|
EventMgr* event_manager)
|
||||||
|
: params_(params),
|
||||||
|
env_(env),
|
||||||
|
stream_(compute_stream),
|
||||||
|
timing_counter_(timing_counter),
|
||||||
|
allocator_(allocator),
|
||||||
|
em_(event_manager),
|
||||||
|
pending_kernels_(
|
||||||
|
params.max_pending > 0 ? std::max(8, 2 * params.max_pending) : 64) {
|
||||||
|
mem_since_last_ = 0;
|
||||||
if (!timing_counter_) {
|
if (!timing_counter_) {
|
||||||
// There's not a preexisting counter owned by GPUProcessState, i.e.
|
// There's not a preexisting counter owned by GPUProcessState, i.e.
|
||||||
// pending_cap > 0 but timestamped_allocator == false.
|
// pending_cap > 0 but timestamped_allocator == false.
|
||||||
@ -202,19 +230,33 @@ class GPUKernelTracker {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Determine whether a GPU kernel should have a recording event queued
|
||||||
|
// immediately afterwards. If so, advance the counter and return the new
|
||||||
|
// counter value after enqueuing.
|
||||||
|
uint64 MaybeQueue(OpKernelContext* ctx);
|
||||||
|
|
||||||
// Record that a GPU kernel has just been enqueued on the compute stream.
|
// Record that a GPU kernel has just been enqueued on the compute stream.
|
||||||
// Inserts a new timing counter value in a new PendingKernel record appended
|
// Inserts the supplied counter value in a new PendingKernel record appended
|
||||||
// to the end of the ring buffer then returns that same count.
|
// to the end of the ring buffer then returns that same count.
|
||||||
uint64 RecordQueued();
|
// Caller is responsible for ensuring that RecordTerminate() is eventually
|
||||||
|
// called with the same counter value.
|
||||||
|
void RecordQueued(uint64 queued_count, int weight)
|
||||||
|
EXCLUSIVE_LOCKS_REQUIRED(mu_);
|
||||||
|
|
||||||
// Takes a count value returned by RecordQueued and finds the corresponding
|
// Takes a count value returned by RecordQueued and finds the corresponding
|
||||||
// PendingKernel record in the ring buffer. Marks the kernel as completed and
|
// PendingKernel record in the ring buffer. Marks the kernel as completed and
|
||||||
// advances the completion frontier accordingly.
|
// advances the completion frontier accordingly.
|
||||||
void RecordTerminated(uint64 at_count);
|
void RecordTerminated(uint64 queued_count);
|
||||||
|
|
||||||
// Returns the largest timing count such that all kernels queued no
|
// Returns the largest timing count such that all kernels queued no
|
||||||
// later than that count are known to have terminated.
|
// later than that count are known to have terminated.
|
||||||
uint64 LastTerminatedCount();
|
inline uint64 LastTerminatedCount(uint64 old_value) {
|
||||||
|
uint64 new_value = last_terminated_count_.load(std::memory_order_relaxed);
|
||||||
|
if (new_value == old_value) {
|
||||||
|
MaybeQueueProgressEvent();
|
||||||
|
}
|
||||||
|
return new_value;
|
||||||
|
}
|
||||||
|
|
||||||
// Returns the number of kernels enqueued that are not yet known to
|
// Returns the number of kernels enqueued that are not yet known to
|
||||||
// have terminated.
|
// have terminated.
|
||||||
@ -225,28 +267,42 @@ class GPUKernelTracker {
|
|||||||
|
|
||||||
// Yield current thread until number of pending kernels no longer
|
// Yield current thread until number of pending kernels no longer
|
||||||
// exceeds the cap.
|
// exceeds the cap.
|
||||||
void PauseWhilePendingExceeds(int cap) {
|
void PauseWhilePendingExceeds(int cap) LOCKS_EXCLUDED(mu_) {
|
||||||
mutex_lock l(mu_);
|
mutex_lock l(mu_);
|
||||||
while (num_pending_ > cap) {
|
while (num_pending_ > cap) {
|
||||||
|
VLOG(1) << "num_pending_=" << num_pending_ << " cap=" << cap;
|
||||||
pending_decreased_.wait(l);
|
pending_decreased_.wait(l);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
private:
|
private:
|
||||||
|
friend class GPUKernelTrackerTest;
|
||||||
|
Params params_;
|
||||||
Env* env_;
|
Env* env_;
|
||||||
|
se::Stream* stream_;
|
||||||
SharedCounter* timing_counter_;
|
SharedCounter* timing_counter_;
|
||||||
std::unique_ptr<SharedCounter> owned_counter_;
|
std::unique_ptr<SharedCounter> owned_counter_;
|
||||||
|
Allocator* allocator_ = nullptr;
|
||||||
|
EventMgr* em_ = nullptr;
|
||||||
|
std::atomic<uint64> last_terminated_count_ = {1};
|
||||||
|
|
||||||
|
void MaybeQueueProgressEvent();
|
||||||
|
|
||||||
// Records when a kernel was queued for execution. Kernel launches are
|
// Records when a kernel was queued for execution. Kernel launches are
|
||||||
// identified by a unique count value from a per-GPU device timing counter.
|
// identified by a unique count value from a per-GPU device timing counter.
|
||||||
struct PendingKernel {
|
struct PendingKernel {
|
||||||
uint64 queued_count;
|
uint64 queued_count;
|
||||||
|
int weight;
|
||||||
bool terminated;
|
bool terminated;
|
||||||
PendingKernel(const PendingKernel& pk)
|
PendingKernel(const PendingKernel& pk)
|
||||||
: queued_count(pk.queued_count), terminated(pk.terminated) {}
|
: queued_count(pk.queued_count),
|
||||||
PendingKernel() : queued_count(0), terminated(false) {}
|
weight(pk.weight),
|
||||||
|
terminated(pk.terminated) {}
|
||||||
|
PendingKernel() : queued_count(0), weight(0), terminated(false) {}
|
||||||
};
|
};
|
||||||
mutex mu_;
|
mutex mu_;
|
||||||
|
int32 mem_since_last_ GUARDED_BY(mu_);
|
||||||
|
int32 ops_since_last_ GUARDED_BY(mu_);
|
||||||
// Ring buffer of PendingKernel records.
|
// Ring buffer of PendingKernel records.
|
||||||
std::vector<PendingKernel> pending_kernels_ GUARDED_BY(mu_);
|
std::vector<PendingKernel> pending_kernels_ GUARDED_BY(mu_);
|
||||||
// Next unused slot in pending_kernels_.
|
// Next unused slot in pending_kernels_.
|
||||||
@ -254,9 +310,9 @@ class GPUKernelTracker {
|
|||||||
// Last completed PendingKernel such that all prior PendingKernels are
|
// Last completed PendingKernel such that all prior PendingKernels are
|
||||||
// also completed. With out-of-order completion there may be a mixture
|
// also completed. With out-of-order completion there may be a mixture
|
||||||
// of completed and uncompleted entries between last_completed_ and
|
// of completed and uncompleted entries between last_completed_ and
|
||||||
// first_available_, hence num_pending_ is not guaranteed equal to
|
// first_available_.
|
||||||
// their differerence.
|
|
||||||
int last_completed_ GUARDED_BY(mu_) = -1;
|
int last_completed_ GUARDED_BY(mu_) = -1;
|
||||||
|
// Sum of weights of the outstanding events marking tracked kernels.
|
||||||
int num_pending_ GUARDED_BY(mu_) = 0;
|
int num_pending_ GUARDED_BY(mu_) = 0;
|
||||||
condition_variable pending_decreased_ GUARDED_BY(mu_);
|
condition_variable pending_decreased_ GUARDED_BY(mu_);
|
||||||
};
|
};
|
||||||
|
@ -348,27 +348,36 @@ TEST_F(GPUDeviceTest, CopyTensorInSameDevice) {
|
|||||||
|
|
||||||
class GPUKernelTrackerTest : public ::testing::Test {
|
class GPUKernelTrackerTest : public ::testing::Test {
|
||||||
protected:
|
protected:
|
||||||
void SetUp() {
|
void Init(const GPUKernelTracker::Params& params) {
|
||||||
timing_counter_.reset(new SharedCounter);
|
timing_counter_.reset(new SharedCounter);
|
||||||
kernel_tracker_.reset(
|
kernel_tracker_.reset(new GPUKernelTracker(params, Env::Default(), nullptr,
|
||||||
new GPUKernelTracker(Env::Default(), timing_counter_.get()));
|
timing_counter_.get(), nullptr,
|
||||||
|
nullptr));
|
||||||
|
}
|
||||||
|
|
||||||
|
void RecordQueued(uint64 v) {
|
||||||
|
mutex_lock l(kernel_tracker_->mu_);
|
||||||
|
kernel_tracker_->RecordQueued(v, 1);
|
||||||
}
|
}
|
||||||
|
|
||||||
std::unique_ptr<GPUKernelTracker> kernel_tracker_;
|
std::unique_ptr<GPUKernelTracker> kernel_tracker_;
|
||||||
std::unique_ptr<SharedCounter> timing_counter_;
|
std::unique_ptr<SharedCounter> timing_counter_;
|
||||||
};
|
};
|
||||||
|
|
||||||
TEST_F(GPUKernelTrackerTest, basic) {
|
TEST_F(GPUKernelTrackerTest, CappingOnly) {
|
||||||
|
Init({0 /*max_interval*/, 0 /*max_bytes*/, 32 /*max_pending*/});
|
||||||
EXPECT_EQ(0, kernel_tracker_->NumPending());
|
EXPECT_EQ(0, kernel_tracker_->NumPending());
|
||||||
// 1 is the expected value when no kernels have yet terminated.
|
// 1 is the expected value when no kernels have yet terminated.
|
||||||
EXPECT_EQ(1, kernel_tracker_->LastTerminatedCount());
|
EXPECT_EQ(1, kernel_tracker_->LastTerminatedCount(0));
|
||||||
|
|
||||||
std::deque<int64> queued_counts;
|
std::deque<int64> queued_counts;
|
||||||
for (int i = 0; i < 32; ++i) {
|
for (int i = 0; i < 32; ++i) {
|
||||||
queued_counts.push_back(kernel_tracker_->RecordQueued());
|
uint64 queued_count = timing_counter_->next();
|
||||||
|
queued_counts.push_back(queued_count);
|
||||||
|
RecordQueued(queued_count);
|
||||||
}
|
}
|
||||||
EXPECT_EQ(32, kernel_tracker_->NumPending());
|
EXPECT_EQ(32, kernel_tracker_->NumPending());
|
||||||
EXPECT_EQ(1, kernel_tracker_->LastTerminatedCount());
|
EXPECT_EQ(1, kernel_tracker_->LastTerminatedCount(0));
|
||||||
|
|
||||||
// Mature the kernels in order until empty.
|
// Mature the kernels in order until empty.
|
||||||
while (!queued_counts.empty()) {
|
while (!queued_counts.empty()) {
|
||||||
@ -376,23 +385,25 @@ TEST_F(GPUKernelTrackerTest, basic) {
|
|||||||
queued_counts.pop_front();
|
queued_counts.pop_front();
|
||||||
kernel_tracker_->RecordTerminated(x);
|
kernel_tracker_->RecordTerminated(x);
|
||||||
EXPECT_EQ(queued_counts.size(), kernel_tracker_->NumPending());
|
EXPECT_EQ(queued_counts.size(), kernel_tracker_->NumPending());
|
||||||
EXPECT_EQ(x, kernel_tracker_->LastTerminatedCount());
|
EXPECT_EQ(x, kernel_tracker_->LastTerminatedCount(0));
|
||||||
}
|
}
|
||||||
EXPECT_EQ(timing_counter_->get(), kernel_tracker_->LastTerminatedCount());
|
EXPECT_EQ(timing_counter_->get(), kernel_tracker_->LastTerminatedCount(0));
|
||||||
|
|
||||||
// Next inject so many kernel events that the ring buffer needs
|
// Next inject so many kernel events that the ring buffer needs
|
||||||
// to grow a couple of times, while maturing a few in random order
|
// to grow a couple of times, while maturing a few in random order
|
||||||
// to introduce gaps between last_completed_ and first_available_.
|
// to introduce gaps between last_completed_ and first_available_.
|
||||||
int64 lower_bound = timing_counter_->get();
|
int64 lower_bound = timing_counter_->get();
|
||||||
for (int i = 0; i < 1111; ++i) {
|
for (int i = 0; i < 1111; ++i) {
|
||||||
queued_counts.push_back(kernel_tracker_->RecordQueued());
|
uint64 queued_count = timing_counter_->next();
|
||||||
|
queued_counts.push_back(queued_count);
|
||||||
|
RecordQueued(queued_count);
|
||||||
int64 upper_bound = timing_counter_->get();
|
int64 upper_bound = timing_counter_->get();
|
||||||
if (0 == (i % 16)) {
|
if (0 == (i % 16)) {
|
||||||
size_t index = (random::New64() % queued_counts.size());
|
size_t index = (random::New64() % queued_counts.size());
|
||||||
kernel_tracker_->RecordTerminated(queued_counts[index]);
|
kernel_tracker_->RecordTerminated(queued_counts[index]);
|
||||||
queued_counts.erase(queued_counts.begin() + index);
|
queued_counts.erase(queued_counts.begin() + index);
|
||||||
EXPECT_LE(lower_bound, kernel_tracker_->LastTerminatedCount());
|
EXPECT_LE(lower_bound, kernel_tracker_->LastTerminatedCount(0));
|
||||||
EXPECT_GE(upper_bound, kernel_tracker_->LastTerminatedCount());
|
EXPECT_GE(upper_bound, kernel_tracker_->LastTerminatedCount(0));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -405,9 +416,9 @@ TEST_F(GPUKernelTrackerTest, basic) {
|
|||||||
// There may be a gap here where we find a kernel that got terminated
|
// There may be a gap here where we find a kernel that got terminated
|
||||||
// out of order, earlier, so the LastTerminatedCount can actually
|
// out of order, earlier, so the LastTerminatedCount can actually
|
||||||
// jump past x.
|
// jump past x.
|
||||||
EXPECT_LE(x, kernel_tracker_->LastTerminatedCount());
|
EXPECT_LE(x, kernel_tracker_->LastTerminatedCount(0));
|
||||||
}
|
}
|
||||||
EXPECT_EQ(timing_counter_->get(), kernel_tracker_->LastTerminatedCount());
|
EXPECT_EQ(timing_counter_->get(), kernel_tracker_->LastTerminatedCount(0));
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace tensorflow
|
} // namespace tensorflow
|
||||||
|
@ -62,7 +62,7 @@ class TEST_EventMgrHelper {
|
|||||||
em_->QueueTensors(stream, tensors);
|
em_->QueueTensors(stream, tensors);
|
||||||
}
|
}
|
||||||
|
|
||||||
void PollEvents(bool is_dedicated_poller) {
|
void PollEvents() {
|
||||||
while (queue_size() > 0) {
|
while (queue_size() > 0) {
|
||||||
// For ordinary tensor frees, this function
|
// For ordinary tensor frees, this function
|
||||||
// should synchronously harvest all complete
|
// should synchronously harvest all complete
|
||||||
@ -70,15 +70,15 @@ class TEST_EventMgrHelper {
|
|||||||
EventMgr::ToFreeVector to_free;
|
EventMgr::ToFreeVector to_free;
|
||||||
{
|
{
|
||||||
mutex_lock l(em_->mu_);
|
mutex_lock l(em_->mu_);
|
||||||
em_->PollEvents(is_dedicated_poller, &to_free);
|
em_->PollEvents(true, &to_free);
|
||||||
}
|
}
|
||||||
em_->FreeMemory(to_free);
|
em_->FreeMemory(to_free);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
void StopPollingLoop() { em_->StopPollingLoop(); }
|
void StopPollingLoop() { return em_->StopPollingLoop(); }
|
||||||
|
|
||||||
void StartPollingLoop() { em_->StartPollingLoop(); }
|
void StartPollingLoop() { return em_->StartPollingLoop(); }
|
||||||
|
|
||||||
private:
|
private:
|
||||||
EventMgr* em_;
|
EventMgr* em_;
|
||||||
@ -140,7 +140,7 @@ TEST(EventMgr, DelayedPolling) {
|
|||||||
EXPECT_EQ(i + 1, th.queue_size());
|
EXPECT_EQ(i + 1, th.queue_size());
|
||||||
EXPECT_EQ(0, th.free_size());
|
EXPECT_EQ(0, th.free_size());
|
||||||
}
|
}
|
||||||
th.PollEvents(false);
|
th.PollEvents();
|
||||||
EXPECT_EQ(0, th.queue_size());
|
EXPECT_EQ(0, th.queue_size());
|
||||||
EXPECT_EQ(5, th.free_size());
|
EXPECT_EQ(5, th.free_size());
|
||||||
for (int j = 0; j < 2; ++j) {
|
for (int j = 0; j < 2; ++j) {
|
||||||
@ -151,7 +151,7 @@ TEST(EventMgr, DelayedPolling) {
|
|||||||
EXPECT_EQ(i + 1, th.queue_size());
|
EXPECT_EQ(i + 1, th.queue_size());
|
||||||
EXPECT_EQ(4 - i, th.free_size());
|
EXPECT_EQ(4 - i, th.free_size());
|
||||||
}
|
}
|
||||||
th.PollEvents(false);
|
th.PollEvents();
|
||||||
EXPECT_EQ(0, th.queue_size());
|
EXPECT_EQ(0, th.queue_size());
|
||||||
EXPECT_EQ(5, th.free_size());
|
EXPECT_EQ(5, th.free_size());
|
||||||
}
|
}
|
||||||
@ -169,7 +169,7 @@ TEST(EventMgr, FlushLargeTensorImmediately) {
|
|||||||
TensorReferenceVector v;
|
TensorReferenceVector v;
|
||||||
AddTensorReference(&v, 100 * 1048576);
|
AddTensorReference(&v, 100 * 1048576);
|
||||||
em.ThenDeleteTensors(stream.get(), v);
|
em.ThenDeleteTensors(stream.get(), v);
|
||||||
th.PollEvents(false); // Ensure things get registered to be freed by Poll
|
th.PollEvents(); // Ensure things get registered to be freed by Poll
|
||||||
EXPECT_EQ(0, live_tensor_bytes);
|
EXPECT_EQ(0, live_tensor_bytes);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -188,7 +188,7 @@ TEST(EventMgr, ManySmallTensorsFlushedImmediately) {
|
|||||||
AddTensorReference(&v, 100 * 1024);
|
AddTensorReference(&v, 100 * 1024);
|
||||||
}
|
}
|
||||||
em.ThenDeleteTensors(stream.get(), v);
|
em.ThenDeleteTensors(stream.get(), v);
|
||||||
th.PollEvents(false); // Harvest the tensors ready to be freed.
|
th.PollEvents(); // Harvest the tensors ready to be freed.
|
||||||
EXPECT_EQ(0, live_tensor_bytes);
|
EXPECT_EQ(0, live_tensor_bytes);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -210,7 +210,7 @@ TEST(EventMgr, StreamSwitchingFlushesImmediately) {
|
|||||||
AddTensorReference(&v2, 1024);
|
AddTensorReference(&v2, 1024);
|
||||||
int64 initial_live_bytes = live_tensor_bytes;
|
int64 initial_live_bytes = live_tensor_bytes;
|
||||||
em.ThenDeleteTensors(stream2.get(), v2);
|
em.ThenDeleteTensors(stream2.get(), v2);
|
||||||
th.PollEvents(false); // Ensure things get registered to be freed by Poll
|
th.PollEvents(); // Ensure things get registered to be freed by Poll
|
||||||
// Different stream should cause first tensor to get deleted
|
// Different stream should cause first tensor to get deleted
|
||||||
EXPECT_GT(initial_live_bytes, live_tensor_bytes);
|
EXPECT_GT(initial_live_bytes, live_tensor_bytes);
|
||||||
}
|
}
|
||||||
@ -229,7 +229,7 @@ TEST(EventMgr, ManySmallTensorsSeparateCallsFlushed) {
|
|||||||
AddTensorReference(&v, 100 * 1024);
|
AddTensorReference(&v, 100 * 1024);
|
||||||
em.ThenDeleteTensors(stream.get(), v);
|
em.ThenDeleteTensors(stream.get(), v);
|
||||||
}
|
}
|
||||||
th.PollEvents(false); // Ensure things get registered to be freed by Poll
|
th.PollEvents(); // Ensure things get registered to be freed by Poll
|
||||||
// Some of the tensors at least should be flushed
|
// Some of the tensors at least should be flushed
|
||||||
EXPECT_GT(1000 * 100 * 1024, live_tensor_bytes);
|
EXPECT_GT(1000 * 100 * 1024, live_tensor_bytes);
|
||||||
}
|
}
|
||||||
@ -264,6 +264,7 @@ TEST(EventMgr, WarnIfInCallback) {
|
|||||||
CHECK(stream);
|
CHECK(stream);
|
||||||
stream->Init();
|
stream->Init();
|
||||||
bool hit = false;
|
bool hit = false;
|
||||||
|
th.StartPollingLoop();
|
||||||
gpu_event_mgr::WarnIfInCallback([&hit] { hit = true; });
|
gpu_event_mgr::WarnIfInCallback([&hit] { hit = true; });
|
||||||
EXPECT_FALSE(hit);
|
EXPECT_FALSE(hit);
|
||||||
Notification note;
|
Notification note;
|
||||||
@ -281,7 +282,7 @@ TEST(EventMgr, WarnIfInCallback) {
|
|||||||
// Provides access to private resources of BaseGPUDevice.
|
// Provides access to private resources of BaseGPUDevice.
|
||||||
class GPUDeviceTestHelper {
|
class GPUDeviceTestHelper {
|
||||||
public:
|
public:
|
||||||
GPUDeviceTestHelper(size_t memory_limit) {
|
GPUDeviceTestHelper(size_t memory_limit, int pending_cap) {
|
||||||
SessionOptions sops;
|
SessionOptions sops;
|
||||||
device_ =
|
device_ =
|
||||||
DeviceFactory::NewDevice(DEVICE_GPU, sops, "/job:a/replica:0/task:0");
|
DeviceFactory::NewDevice(DEVICE_GPU, sops, "/job:a/replica:0/task:0");
|
||||||
@ -299,6 +300,7 @@ class GPUDeviceTestHelper {
|
|||||||
se::Stream* d2h_stream() { return gpu_->streams_[0]->device_to_host; }
|
se::Stream* d2h_stream() { return gpu_->streams_[0]->device_to_host; }
|
||||||
se::Stream* d2d_stream() { return gpu_->streams_[0]->device_to_device[0]; }
|
se::Stream* d2d_stream() { return gpu_->streams_[0]->device_to_device[0]; }
|
||||||
EventMgr* event_mgr() { return gpu_->em_.get(); }
|
EventMgr* event_mgr() { return gpu_->em_.get(); }
|
||||||
|
int pending_cap() { return gpu_->pending_cap_; }
|
||||||
|
|
||||||
private:
|
private:
|
||||||
std::unique_ptr<Device> device_;
|
std::unique_ptr<Device> device_;
|
||||||
@ -340,23 +342,23 @@ class EMBenchmarkHelper {
|
|||||||
|
|
||||||
EMBenchmarkHelper(GPUDeviceTestHelper* h) : gpu_helper_(h) {}
|
EMBenchmarkHelper(GPUDeviceTestHelper* h) : gpu_helper_(h) {}
|
||||||
|
|
||||||
void ReInit(int num_ops) {
|
void ReInit(int num_ops, int tensor_size) {
|
||||||
gpu_inputs_.clear();
|
gpu_inputs_.clear();
|
||||||
while (gpu_inputs_.size() < 2) {
|
while (gpu_inputs_.size() < 2) {
|
||||||
gpu_inputs_.push_back(Tensor(gpu_helper_->gpu_allocator(), DT_FLOAT,
|
gpu_inputs_.push_back(Tensor(gpu_helper_->gpu_allocator(), DT_FLOAT,
|
||||||
{kTDim}, AllocationAttributes()));
|
{tensor_size}, AllocationAttributes()));
|
||||||
}
|
}
|
||||||
gpu_outputs_.clear();
|
gpu_outputs_.clear();
|
||||||
while (gpu_outputs_.size() < 1) {
|
while (gpu_outputs_.size() < 1) {
|
||||||
gpu_outputs_.push_back(Tensor(gpu_helper_->gpu_allocator(), DT_FLOAT,
|
gpu_outputs_.push_back(Tensor(gpu_helper_->gpu_allocator(), DT_FLOAT,
|
||||||
{kTDim}, AllocationAttributes()));
|
{tensor_size}, AllocationAttributes()));
|
||||||
}
|
}
|
||||||
host_inputs_.clear();
|
host_inputs_.clear();
|
||||||
while (host_inputs_.size() < 2) {
|
while (host_inputs_.size() < 2) {
|
||||||
int instance_index = host_inputs_.size();
|
int instance_index = host_inputs_.size();
|
||||||
host_inputs_.push_back(Tensor(gpu_helper_->host_allocator(), DT_FLOAT,
|
host_inputs_.push_back(Tensor(gpu_helper_->host_allocator(), DT_FLOAT,
|
||||||
{kTDim}, AllocationAttributes()));
|
{tensor_size}, AllocationAttributes()));
|
||||||
for (int i = 0; i < kTDim; ++i) {
|
for (int i = 0; i < tensor_size; ++i) {
|
||||||
host_inputs_.back().flat<float>()(i) =
|
host_inputs_.back().flat<float>()(i) =
|
||||||
i * (1.0 + (0.5 * instance_index));
|
i * (1.0 + (0.5 * instance_index));
|
||||||
}
|
}
|
||||||
@ -364,8 +366,8 @@ class EMBenchmarkHelper {
|
|||||||
host_outputs_.clear();
|
host_outputs_.clear();
|
||||||
while (host_outputs_.size() < 1) {
|
while (host_outputs_.size() < 1) {
|
||||||
host_outputs_.push_back(Tensor(gpu_helper_->host_allocator(), DT_FLOAT,
|
host_outputs_.push_back(Tensor(gpu_helper_->host_allocator(), DT_FLOAT,
|
||||||
{kTDim}, AllocationAttributes()));
|
{tensor_size}, AllocationAttributes()));
|
||||||
for (int i = 0; i < kTDim; ++i) {
|
for (int i = 0; i < tensor_size; ++i) {
|
||||||
host_outputs_.back().flat<float>()(i) = -1;
|
host_outputs_.back().flat<float>()(i) = -1;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -583,7 +585,7 @@ static void BM_no_ops(int iters, int threads) {
|
|||||||
std::unique_ptr<se::Stream> stream(new se::Stream(stream_exec));
|
std::unique_ptr<se::Stream> stream(new se::Stream(stream_exec));
|
||||||
CHECK(stream);
|
CHECK(stream);
|
||||||
stream->Init();
|
stream->Init();
|
||||||
EventMgr em(stream_exec, GPUOptions()); //, stream.get());
|
EventMgr em(stream_exec, GPUOptions());
|
||||||
testing::StartTiming();
|
testing::StartTiming();
|
||||||
std::atomic<int> counter;
|
std::atomic<int> counter;
|
||||||
counter.store(0, std::memory_order_seq_cst);
|
counter.store(0, std::memory_order_seq_cst);
|
||||||
@ -615,10 +617,11 @@ EMBenchmarkHelper* bm_helper = nullptr;
|
|||||||
mutex helper_mu;
|
mutex helper_mu;
|
||||||
|
|
||||||
#ifdef PLATFORM_GOOGLE
|
#ifdef PLATFORM_GOOGLE
|
||||||
static void BM_chain_ops(int iters, int adds_per_round, bool event_after_add) {
|
static void BM_chain_ops(int iters, int tensor_size, int adds_per_round,
|
||||||
|
bool event_after_add, int pending_cap) {
|
||||||
#else
|
#else
|
||||||
static void BM_chain_ops(int iters, int adds_per_round, bool event_after_add,
|
static void BM_chain_ops(int iters, int tensor_size, int adds_per_round,
|
||||||
int threads) {
|
bool event_after_add, int pending_cap, int threads) {
|
||||||
#endif
|
#endif
|
||||||
testing::StopTiming();
|
testing::StopTiming();
|
||||||
#ifdef PLATFORM_GOOGLE
|
#ifdef PLATFORM_GOOGLE
|
||||||
@ -628,12 +631,19 @@ static void BM_chain_ops(int iters, int adds_per_round, bool event_after_add,
|
|||||||
#endif // PLATFORM_GOOGLE
|
#endif // PLATFORM_GOOGLE
|
||||||
{
|
{
|
||||||
mutex_lock l(helper_mu);
|
mutex_lock l(helper_mu);
|
||||||
|
if (gpu_helper && gpu_helper->pending_cap() != pending_cap) {
|
||||||
|
delete bm_helper;
|
||||||
|
bm_helper = nullptr;
|
||||||
|
delete gpu_helper;
|
||||||
|
gpu_helper = nullptr;
|
||||||
|
}
|
||||||
if (!gpu_helper) {
|
if (!gpu_helper) {
|
||||||
gpu_helper = new GPUDeviceTestHelper(1 << 20);
|
gpu_helper = new GPUDeviceTestHelper(1 << 24, pending_cap);
|
||||||
bm_helper = new EMBenchmarkHelper(gpu_helper);
|
bm_helper = new EMBenchmarkHelper(gpu_helper);
|
||||||
}
|
}
|
||||||
if (bm_helper->num_ops() != adds_per_round) {
|
if (bm_helper->num_ops() != adds_per_round ||
|
||||||
bm_helper->ReInit(adds_per_round);
|
bm_helper->tensor_size() != tensor_size) {
|
||||||
|
bm_helper->ReInit(adds_per_round, tensor_size);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
std::vector<EMBenchmarkHelper::TimeSet> times;
|
std::vector<EMBenchmarkHelper::TimeSet> times;
|
||||||
@ -648,7 +658,7 @@ static void BM_chain_ops(int iters, int adds_per_round, bool event_after_add,
|
|||||||
// First iter is always slow, so do one prior to the timed loop.
|
// First iter is always slow, so do one prior to the timed loop.
|
||||||
int expected = 1 + (event_after_add ? adds_per_round : 0);
|
int expected = 1 + (event_after_add ? adds_per_round : 0);
|
||||||
bm_helper->DoAddChain(adds_per_round, 1, event_after_add, callback, nullptr);
|
bm_helper->DoAddChain(adds_per_round, 1, event_after_add, callback, nullptr);
|
||||||
while (counter < 1) {
|
while (counter < expected) {
|
||||||
Env::Default()->SleepForMicroseconds(1);
|
Env::Default()->SleepForMicroseconds(1);
|
||||||
}
|
}
|
||||||
counter = 0;
|
counter = 0;
|
||||||
@ -677,71 +687,169 @@ static void BM_chain_ops(int iters, int adds_per_round, bool event_after_add,
|
|||||||
}
|
}
|
||||||
|
|
||||||
#ifdef PLATFORM_GOOGLE
|
#ifdef PLATFORM_GOOGLE
|
||||||
static void BM_chain_1_false(int iters) { BM_chain_ops(iters, 1, false); }
|
static void BM_chain_1024_1_false(int iters) {
|
||||||
|
BM_chain_ops(iters, 1024, 1, false, 0);
|
||||||
|
}
|
||||||
|
|
||||||
static void BM_chain_1_true(int iters) { BM_chain_ops(iters, 1, true); }
|
static void BM_chain_1024_1_true(int iters) {
|
||||||
|
BM_chain_ops(iters, 1024, 1, true, 0);
|
||||||
|
}
|
||||||
|
|
||||||
static void BM_chain_10_false(int iters) { BM_chain_ops(iters, 10, false); }
|
static void BM_chain_1024_10_false(int iters) {
|
||||||
|
BM_chain_ops(iters, 1024, 10, false, 0);
|
||||||
|
}
|
||||||
|
|
||||||
static void BM_chain_10_true(int iters) { BM_chain_ops(iters, 10, true); }
|
static void BM_chain_1024_10_true(int iters) {
|
||||||
|
BM_chain_ops(iters, 1024, 10, true, 0);
|
||||||
|
}
|
||||||
|
|
||||||
static void BM_chain_100_false(int iters) { BM_chain_ops(iters, 100, false); }
|
static void BM_chain_1024_100_false(int iters) {
|
||||||
|
BM_chain_ops(iters, 1024, 100, false, 0);
|
||||||
|
}
|
||||||
|
|
||||||
static void BM_chain_100_true(int iters) { BM_chain_ops(iters, 100, true); }
|
static void BM_chain_1024_100_true(int iters) {
|
||||||
|
BM_chain_ops(iters, 1024, 100, true, 0);
|
||||||
|
}
|
||||||
|
|
||||||
BENCHMARK(BM_chain_1_false)->Threads(1);
|
static void BM_chain_1M_1_false(int iters) {
|
||||||
BENCHMARK(BM_chain_1_true)->Threads(1);
|
BM_chain_ops(iters, 1 << 20, 1, false, 0);
|
||||||
BENCHMARK(BM_chain_1_false)->Threads(2);
|
}
|
||||||
BENCHMARK(BM_chain_1_true)->Threads(2);
|
|
||||||
BENCHMARK(BM_chain_1_false)->Threads(8);
|
static void BM_chain_1M_1_true(int iters) {
|
||||||
BENCHMARK(BM_chain_1_true)->Threads(8);
|
BM_chain_ops(iters, 1 << 20, 1, true, 0);
|
||||||
BENCHMARK(BM_chain_10_false)->Threads(1);
|
}
|
||||||
BENCHMARK(BM_chain_10_true)->Threads(1);
|
|
||||||
BENCHMARK(BM_chain_10_false)->Threads(8);
|
static void BM_chain_1M_10_false(int iters) {
|
||||||
BENCHMARK(BM_chain_10_true)->Threads(8);
|
BM_chain_ops(iters, 1 << 20, 10, false, 0);
|
||||||
BENCHMARK(BM_chain_100_false)->Threads(1);
|
}
|
||||||
BENCHMARK(BM_chain_100_true)->Threads(1);
|
|
||||||
BENCHMARK(BM_chain_100_false)->Threads(8);
|
static void BM_chain_1M_10_true(int iters) {
|
||||||
BENCHMARK(BM_chain_100_true)->Threads(8);
|
BM_chain_ops(iters, 1 << 20, 10, true, 0);
|
||||||
|
}
|
||||||
|
|
||||||
|
static void BM_chain_1M_100_false(int iters) {
|
||||||
|
BM_chain_ops(iters, 1 << 20, 100, false, 0);
|
||||||
|
}
|
||||||
|
|
||||||
|
static void BM_chain_1M_100_true(int iters) {
|
||||||
|
BM_chain_ops(iters, 1 << 20, 100, true, 0);
|
||||||
|
}
|
||||||
|
|
||||||
|
BENCHMARK(BM_chain_1024_1_false)->Threads(1);
|
||||||
|
BENCHMARK(BM_chain_1024_1_true)->Threads(1);
|
||||||
|
BENCHMARK(BM_chain_1024_1_false)->Threads(2);
|
||||||
|
BENCHMARK(BM_chain_1024_1_true)->Threads(2);
|
||||||
|
BENCHMARK(BM_chain_1024_1_false)->Threads(8);
|
||||||
|
BENCHMARK(BM_chain_1024_1_true)->Threads(8);
|
||||||
|
BENCHMARK(BM_chain_1024_10_false)->Threads(1);
|
||||||
|
BENCHMARK(BM_chain_1024_10_true)->Threads(1);
|
||||||
|
BENCHMARK(BM_chain_1024_10_false)->Threads(8);
|
||||||
|
BENCHMARK(BM_chain_1024_10_true)->Threads(8);
|
||||||
|
BENCHMARK(BM_chain_1024_100_false)->Threads(1);
|
||||||
|
BENCHMARK(BM_chain_1024_100_true)->Threads(1);
|
||||||
|
BENCHMARK(BM_chain_1024_100_false)->Threads(2);
|
||||||
|
BENCHMARK(BM_chain_1024_100_true)->Threads(2);
|
||||||
|
BENCHMARK(BM_chain_1024_100_false)->Threads(8);
|
||||||
|
BENCHMARK(BM_chain_1024_100_true)->Threads(8);
|
||||||
|
|
||||||
|
BENCHMARK(BM_chain_1M_1_false)->Threads(1);
|
||||||
|
BENCHMARK(BM_chain_1M_1_true)->Threads(1);
|
||||||
|
BENCHMARK(BM_chain_1M_1_false)->Threads(2);
|
||||||
|
BENCHMARK(BM_chain_1M_1_true)->Threads(2);
|
||||||
|
BENCHMARK(BM_chain_1M_1_false)->Threads(8);
|
||||||
|
BENCHMARK(BM_chain_1M_1_true)->Threads(8);
|
||||||
|
BENCHMARK(BM_chain_1M_10_false)->Threads(1);
|
||||||
|
BENCHMARK(BM_chain_1M_10_true)->Threads(1);
|
||||||
|
BENCHMARK(BM_chain_1M_10_false)->Threads(8);
|
||||||
|
BENCHMARK(BM_chain_1M_10_true)->Threads(8);
|
||||||
|
BENCHMARK(BM_chain_1M_100_false)->Threads(1);
|
||||||
|
BENCHMARK(BM_chain_1M_100_true)->Threads(1);
|
||||||
|
BENCHMARK(BM_chain_1M_100_false)->Threads(2);
|
||||||
|
BENCHMARK(BM_chain_1M_100_true)->Threads(2);
|
||||||
|
BENCHMARK(BM_chain_1M_100_false)->Threads(8);
|
||||||
|
BENCHMARK(BM_chain_1M_100_true)->Threads(8);
|
||||||
#else
|
#else
|
||||||
static void BM_chain_1_false(int iters, int threads) {
|
static void BM_chain_1024_1_false(int iters, int threads) {
|
||||||
BM_chain_ops(iters, 1, false, threads);
|
BM_chain_ops(iters, 1024, 1, false, 0, threads);
|
||||||
}
|
}
|
||||||
|
|
||||||
static void BM_chain_1_true(int iters, int threads) {
|
static void BM_chain_1024_1_true(int iters, int threads) {
|
||||||
BM_chain_ops(iters, 1, true, threads);
|
BM_chain_ops(iters, 1024, 1, true, 0, threads);
|
||||||
}
|
}
|
||||||
|
|
||||||
static void BM_chain_10_false(int iters, int threads) {
|
static void BM_chain_1024_10_false(int iters, int threads) {
|
||||||
BM_chain_ops(iters, 10, false, threads);
|
BM_chain_ops(iters, 1024, 10, false, 0, threads);
|
||||||
}
|
}
|
||||||
|
|
||||||
static void BM_chain_10_true(int iters, int threads) {
|
static void BM_chain_1024_10_true(int iters, int threads) {
|
||||||
BM_chain_ops(iters, 10, true, threads);
|
BM_chain_ops(iters, 1024, 10, true, 0, threads);
|
||||||
}
|
}
|
||||||
|
|
||||||
static void BM_chain_100_false(int iters, int threads) {
|
static void BM_chain_1024_100_false(int iters, int threads) {
|
||||||
BM_chain_ops(iters, 100, false, threads);
|
BM_chain_ops(iters, 1024, 100, false, 0, threads);
|
||||||
}
|
}
|
||||||
|
|
||||||
static void BM_chain_100_true(int iters, int threads) {
|
static void BM_chain_1024_100_true(int iters, int threads) {
|
||||||
BM_chain_ops(iters, 100, true, threads);
|
BM_chain_ops(iters, 1024, 100, true, 0, threads);
|
||||||
}
|
}
|
||||||
|
|
||||||
BENCHMARK(BM_chain_1_false)->Arg(1);
|
static void BM_chain_1M_1_false(int iters, int threads) {
|
||||||
BENCHMARK(BM_chain_1_true)->Arg(1);
|
BM_chain_ops(iters, 1 << 20, 1, false, 0, threads);
|
||||||
BENCHMARK(BM_chain_1_false)->Arg(2);
|
}
|
||||||
BENCHMARK(BM_chain_1_true)->Arg(2);
|
|
||||||
BENCHMARK(BM_chain_1_false)->Arg(8);
|
static void BM_chain_1M_1_true(int iters, int threads) {
|
||||||
BENCHMARK(BM_chain_1_true)->Arg(8);
|
BM_chain_ops(iters, 1 << 20, 1, true, 0, threads);
|
||||||
BENCHMARK(BM_chain_10_false)->Arg(1);
|
}
|
||||||
BENCHMARK(BM_chain_10_true)->Arg(1);
|
|
||||||
BENCHMARK(BM_chain_10_false)->Arg(8);
|
static void BM_chain_1M_10_false(int iters, int threads) {
|
||||||
BENCHMARK(BM_chain_10_true)->Arg(8);
|
BM_chain_ops(iters, 1 << 20, 10, false, 0, threads);
|
||||||
BENCHMARK(BM_chain_100_false)->Arg(1);
|
}
|
||||||
BENCHMARK(BM_chain_100_true)->Arg(1);
|
|
||||||
BENCHMARK(BM_chain_100_false)->Arg(8);
|
static void BM_chain_1M_10_true(int iters, int threads) {
|
||||||
BENCHMARK(BM_chain_100_true)->Arg(8);
|
BM_chain_ops(iters, 1 << 20, 10, true, 0, threads);
|
||||||
|
}
|
||||||
|
|
||||||
|
static void BM_chain_1M_100_false(int iters, int threads) {
|
||||||
|
BM_chain_ops(iters, 1 << 20, 100, false, 0, threads);
|
||||||
|
}
|
||||||
|
|
||||||
|
static void BM_chain_1M_100_true(int iters, int threads) {
|
||||||
|
BM_chain_ops(iters, 1 << 20, 100, true, 0, threads);
|
||||||
|
}
|
||||||
|
|
||||||
|
BENCHMARK(BM_chain_1024_1_false)->Arg(1);
|
||||||
|
BENCHMARK(BM_chain_1024_1_true)->Arg(1);
|
||||||
|
BENCHMARK(BM_chain_1024_1_false)->Arg(2);
|
||||||
|
BENCHMARK(BM_chain_1024_1_true)->Arg(2);
|
||||||
|
BENCHMARK(BM_chain_1024_1_false)->Arg(8);
|
||||||
|
BENCHMARK(BM_chain_1024_1_true)->Arg(8);
|
||||||
|
BENCHMARK(BM_chain_1024_10_false)->Arg(1);
|
||||||
|
BENCHMARK(BM_chain_1024_10_true)->Arg(1);
|
||||||
|
BENCHMARK(BM_chain_1024_10_false)->Arg(8);
|
||||||
|
BENCHMARK(BM_chain_1024_10_true)->Arg(8);
|
||||||
|
BENCHMARK(BM_chain_1024_100_false)->Arg(1);
|
||||||
|
BENCHMARK(BM_chain_1024_100_true)->Arg(1);
|
||||||
|
BENCHMARK(BM_chain_1024_100_false)->Arg(2);
|
||||||
|
BENCHMARK(BM_chain_1024_100_true)->Arg(2);
|
||||||
|
BENCHMARK(BM_chain_1024_100_false)->Arg(8);
|
||||||
|
BENCHMARK(BM_chain_1024_100_true)->Arg(8);
|
||||||
|
|
||||||
|
BENCHMARK(BM_chain_1M_1_false)->Arg(1);
|
||||||
|
BENCHMARK(BM_chain_1M_1_true)->Arg(1);
|
||||||
|
BENCHMARK(BM_chain_1M_1_false)->Arg(2);
|
||||||
|
BENCHMARK(BM_chain_1M_1_true)->Arg(2);
|
||||||
|
BENCHMARK(BM_chain_1M_1_false)->Arg(8);
|
||||||
|
BENCHMARK(BM_chain_1M_1_true)->Arg(8);
|
||||||
|
BENCHMARK(BM_chain_1M_10_false)->Arg(1);
|
||||||
|
BENCHMARK(BM_chain_1M_10_true)->Arg(1);
|
||||||
|
BENCHMARK(BM_chain_1M_10_false)->Arg(8);
|
||||||
|
BENCHMARK(BM_chain_1M_10_true)->Arg(8);
|
||||||
|
BENCHMARK(BM_chain_1M_100_false)->Arg(1);
|
||||||
|
BENCHMARK(BM_chain_1M_100_true)->Arg(1);
|
||||||
|
BENCHMARK(BM_chain_1M_100_false)->Arg(2);
|
||||||
|
BENCHMARK(BM_chain_1M_100_true)->Arg(2);
|
||||||
|
BENCHMARK(BM_chain_1M_100_false)->Arg(8);
|
||||||
|
BENCHMARK(BM_chain_1M_100_true)->Arg(8);
|
||||||
#endif
|
#endif
|
||||||
} // namespace
|
} // namespace
|
||||||
} // namespace tensorflow
|
} // namespace tensorflow
|
||||||
|
@ -147,7 +147,7 @@ Allocator* GPUProcessState::GetGPUAllocator(const GPUOptions& options,
|
|||||||
}
|
}
|
||||||
allocator_parts = {std::unique_ptr<Allocator>(gpu_allocator),
|
allocator_parts = {std::unique_ptr<Allocator>(gpu_allocator),
|
||||||
std::unique_ptr<SharedCounter>(timing_counter),
|
std::unique_ptr<SharedCounter>(timing_counter),
|
||||||
sub_allocator,
|
gpu_bfc_allocator, sub_allocator,
|
||||||
std::unique_ptr<Allocator>(recording_allocator)};
|
std::unique_ptr<Allocator>(recording_allocator)};
|
||||||
}
|
}
|
||||||
if (process_state_->ProcessState::FLAGS_brain_gpu_record_mem_types) {
|
if (process_state_->ProcessState::FLAGS_brain_gpu_record_mem_types) {
|
||||||
@ -169,10 +169,17 @@ SharedCounter* GPUProcessState::GPUAllocatorCounter(TfGpuId tf_gpu_id) {
|
|||||||
GpuIdUtil::CheckValidTfGpuId(tf_gpu_id);
|
GpuIdUtil::CheckValidTfGpuId(tf_gpu_id);
|
||||||
mutex_lock l(mu_);
|
mutex_lock l(mu_);
|
||||||
if (tf_gpu_id.value() >= static_cast<int64>(gpu_allocators_.size())) {
|
if (tf_gpu_id.value() >= static_cast<int64>(gpu_allocators_.size())) {
|
||||||
|
LOG(ERROR) << "Asked for counter for GPU allocator " << tf_gpu_id.value()
|
||||||
|
<< " but only have " << gpu_allocators_.size();
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
|
|
||||||
AllocatorParts& allocator_parts = gpu_allocators_[tf_gpu_id.value()];
|
AllocatorParts& allocator_parts = gpu_allocators_[tf_gpu_id.value()];
|
||||||
|
if (allocator_parts.counter.get() == nullptr) {
|
||||||
|
SharedCounter* timing_counter = new SharedCounter;
|
||||||
|
allocator_parts.bfc_allocator->SetTimingCounter(timing_counter);
|
||||||
|
allocator_parts.counter.reset(timing_counter);
|
||||||
|
}
|
||||||
return allocator_parts.counter.get();
|
return allocator_parts.counter.get();
|
||||||
#else
|
#else
|
||||||
return nullptr;
|
return nullptr;
|
||||||
@ -242,6 +249,7 @@ Allocator* GPUProcessState::GetGpuHostAllocator(int numa_node) {
|
|||||||
LOG(ERROR) << "GetGpuHostAllocator: " << status.error_message();
|
LOG(ERROR) << "GetGpuHostAllocator: " << status.error_message();
|
||||||
}
|
}
|
||||||
int64 gpu_host_mem_limit = gpu_host_mem_limit_in_mb * (1LL << 20);
|
int64 gpu_host_mem_limit = gpu_host_mem_limit_in_mb * (1LL << 20);
|
||||||
|
|
||||||
Allocator* allocator =
|
Allocator* allocator =
|
||||||
new BFCAllocator(sub_allocator, gpu_host_mem_limit,
|
new BFCAllocator(sub_allocator, gpu_host_mem_limit,
|
||||||
true /*allow_growth*/, "gpu_host_bfc" /*name*/);
|
true /*allow_growth*/, "gpu_host_bfc" /*name*/);
|
||||||
@ -253,7 +261,7 @@ Allocator* GPUProcessState::GetGpuHostAllocator(int numa_node) {
|
|||||||
}
|
}
|
||||||
gpu_host_allocators_.push_back({std::unique_ptr<Allocator>(allocator),
|
gpu_host_allocators_.push_back({std::unique_ptr<Allocator>(allocator),
|
||||||
std::unique_ptr<SharedCounter>(nullptr),
|
std::unique_ptr<SharedCounter>(nullptr),
|
||||||
sub_allocator,
|
nullptr, sub_allocator,
|
||||||
std::unique_ptr<Allocator>(nullptr)});
|
std::unique_ptr<Allocator>(nullptr)});
|
||||||
AllocatorParts& allocator_parts = gpu_host_allocators_.back();
|
AllocatorParts& allocator_parts = gpu_host_allocators_.back();
|
||||||
if (process_state_->ProcessState::FLAGS_brain_gpu_record_mem_types) {
|
if (process_state_->ProcessState::FLAGS_brain_gpu_record_mem_types) {
|
||||||
|
@ -33,6 +33,7 @@ limitations under the License.
|
|||||||
namespace tensorflow {
|
namespace tensorflow {
|
||||||
|
|
||||||
class Allocator;
|
class Allocator;
|
||||||
|
class GPUBFCAllocator;
|
||||||
class PoolAllocator;
|
class PoolAllocator;
|
||||||
class SharedCounter;
|
class SharedCounter;
|
||||||
|
|
||||||
@ -137,6 +138,7 @@ class GPUProcessState {
|
|||||||
struct AllocatorParts {
|
struct AllocatorParts {
|
||||||
std::unique_ptr<Allocator> allocator;
|
std::unique_ptr<Allocator> allocator;
|
||||||
std::unique_ptr<SharedCounter> counter;
|
std::unique_ptr<SharedCounter> counter;
|
||||||
|
GPUBFCAllocator* bfc_allocator;
|
||||||
SubAllocator* sub_allocator; // owned by allocator
|
SubAllocator* sub_allocator; // owned by allocator
|
||||||
std::unique_ptr<Allocator> recording_allocator;
|
std::unique_ptr<Allocator> recording_allocator;
|
||||||
};
|
};
|
||||||
|
@ -300,7 +300,7 @@ void GPUUtil::CopyGPUTensorToCPU(Device* gpu_device,
|
|||||||
void GPUUtil::CopyCPUTensorToGPU(const Tensor* cpu_tensor,
|
void GPUUtil::CopyCPUTensorToGPU(const Tensor* cpu_tensor,
|
||||||
const DeviceContext* device_context,
|
const DeviceContext* device_context,
|
||||||
Device* gpu_device, Tensor* gpu_tensor,
|
Device* gpu_device, Tensor* gpu_tensor,
|
||||||
StatusCallback done) {
|
StatusCallback done, bool sync_dst_compute) {
|
||||||
VLOG(1) << "CopyCPUTensorToGPU";
|
VLOG(1) << "CopyCPUTensorToGPU";
|
||||||
const DeviceBase::GpuDeviceInfo* dev_info = nullptr;
|
const DeviceBase::GpuDeviceInfo* dev_info = nullptr;
|
||||||
se::Stream* recv_stream = nullptr;
|
se::Stream* recv_stream = nullptr;
|
||||||
@ -319,7 +319,9 @@ void GPUUtil::CopyCPUTensorToGPU(const Tensor* cpu_tensor,
|
|||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
// Wait for the recv-stream to make sure the buffer is truly available.
|
// Wait for the recv-stream to make sure the buffer is truly available.
|
||||||
|
if (sync_dst_compute) {
|
||||||
recv_host_to_device_stream->ThenWaitFor(recv_stream);
|
recv_host_to_device_stream->ThenWaitFor(recv_stream);
|
||||||
|
}
|
||||||
|
|
||||||
const int64 total_bytes = cpu_tensor->TotalBytes();
|
const int64 total_bytes = cpu_tensor->TotalBytes();
|
||||||
// Note that 0-size tensors have no backing buffer.
|
// Note that 0-size tensors have no backing buffer.
|
||||||
|
@ -88,7 +88,7 @@ class GPUUtil {
|
|||||||
static void CopyCPUTensorToGPU(const Tensor* cpu_tensor,
|
static void CopyCPUTensorToGPU(const Tensor* cpu_tensor,
|
||||||
const DeviceContext* device_context,
|
const DeviceContext* device_context,
|
||||||
Device* gpu_device, Tensor* gpu_tensor,
|
Device* gpu_device, Tensor* gpu_tensor,
|
||||||
StatusCallback done);
|
StatusCallback done, bool sync_dst_compute);
|
||||||
|
|
||||||
static void DeviceToDeviceCopy(
|
static void DeviceToDeviceCopy(
|
||||||
DeviceContext* send_dev_context, DeviceContext* recv_dev_context,
|
DeviceContext* send_dev_context, DeviceContext* recv_dev_context,
|
||||||
|
@ -26,8 +26,10 @@ namespace tensorflow {
|
|||||||
void GPUDeviceContext::CopyCPUTensorToDevice(const Tensor* cpu_tensor,
|
void GPUDeviceContext::CopyCPUTensorToDevice(const Tensor* cpu_tensor,
|
||||||
Device* device,
|
Device* device,
|
||||||
Tensor* device_tensor,
|
Tensor* device_tensor,
|
||||||
StatusCallback done) const {
|
StatusCallback done,
|
||||||
GPUUtil::CopyCPUTensorToGPU(cpu_tensor, this, device, device_tensor, done);
|
bool sync_dst_compute) const {
|
||||||
|
GPUUtil::CopyCPUTensorToGPU(cpu_tensor, this, device, device_tensor, done,
|
||||||
|
sync_dst_compute);
|
||||||
}
|
}
|
||||||
|
|
||||||
void GPUDeviceContext::CopyDeviceTensorToCPU(const Tensor* device_tensor,
|
void GPUDeviceContext::CopyDeviceTensorToCPU(const Tensor* device_tensor,
|
||||||
|
@ -50,8 +50,8 @@ class GPUDeviceContext : public DeviceContext {
|
|||||||
int stream_id() const { return stream_id_; }
|
int stream_id() const { return stream_id_; }
|
||||||
|
|
||||||
void CopyCPUTensorToDevice(const Tensor* cpu_tensor, Device* device,
|
void CopyCPUTensorToDevice(const Tensor* cpu_tensor, Device* device,
|
||||||
Tensor* device_tensor,
|
Tensor* device_tensor, StatusCallback done,
|
||||||
StatusCallback done) const override;
|
bool sync_dst_compute) const override;
|
||||||
|
|
||||||
void CopyDeviceTensorToCPU(const Tensor* device_tensor, StringPiece edge_name,
|
void CopyDeviceTensorToCPU(const Tensor* device_tensor, StringPiece edge_name,
|
||||||
Device* device, Tensor* cpu_tensor,
|
Device* device, Tensor* cpu_tensor,
|
||||||
|
@ -101,16 +101,30 @@ void IntraProcessRendezvous::SameWorkerRecvDone(
|
|||||||
attr.set_gpu_compatible(send_args.alloc_attrs.gpu_compatible() ||
|
attr.set_gpu_compatible(send_args.alloc_attrs.gpu_compatible() ||
|
||||||
recv_args.alloc_attrs.gpu_compatible());
|
recv_args.alloc_attrs.gpu_compatible());
|
||||||
Allocator* out_allocator = dst_device->GetAllocator(attr);
|
Allocator* out_allocator = dst_device->GetAllocator(attr);
|
||||||
|
bool sync_dst_compute = true;
|
||||||
if (in.dtype() != DT_VARIANT) {
|
if (in.dtype() != DT_VARIANT) {
|
||||||
// Variants are handled by CopyTensor::ViaDMA.
|
// Variants are handled by CopyTensor::ViaDMA.
|
||||||
Tensor copy(out_allocator, in.dtype(), in.shape());
|
AllocationAttributes aa;
|
||||||
|
uint64 safe_alloc_frontier = dst_device->SafeAllocFrontier(0);
|
||||||
|
std::function<uint64()> freed_by_func = [dst_device,
|
||||||
|
&safe_alloc_frontier]() {
|
||||||
|
safe_alloc_frontier = dst_device->SafeAllocFrontier(safe_alloc_frontier);
|
||||||
|
return safe_alloc_frontier;
|
||||||
|
};
|
||||||
|
if (parsed.dst.type == "GPU" && safe_alloc_frontier > 0) {
|
||||||
|
// There's a timestamped allocator at work, so use it instead
|
||||||
|
// of sync_dst_compute.
|
||||||
|
aa.freed_by_func = &freed_by_func;
|
||||||
|
sync_dst_compute = false;
|
||||||
|
}
|
||||||
|
Tensor copy(out_allocator, in.dtype(), in.shape(), aa);
|
||||||
*out = copy;
|
*out = copy;
|
||||||
}
|
}
|
||||||
|
|
||||||
CopyTensor::ViaDMA(parsed.edge_name, send_args.device_context,
|
CopyTensor::ViaDMA(
|
||||||
recv_args.device_context, src_device, dst_device,
|
parsed.edge_name, send_args.device_context, recv_args.device_context,
|
||||||
send_args.alloc_attrs, recv_args.alloc_attrs, &in, out,
|
src_device, dst_device, send_args.alloc_attrs, recv_args.alloc_attrs, &in,
|
||||||
0 /*dev_to_dev_stream_index*/, std::move(done));
|
out, 0 /*dev_to_dev_stream_index*/, std::move(done), sync_dst_compute);
|
||||||
}
|
}
|
||||||
|
|
||||||
void IntraProcessRendezvous::RecvAsync(const ParsedKey& parsed,
|
void IntraProcessRendezvous::RecvAsync(const ParsedKey& parsed,
|
||||||
|
@ -124,17 +124,29 @@ void RingReducer::ContinueAfterInputCopy() {
|
|||||||
// can be provided to the kernel in host memory?
|
// can be provided to the kernel in host memory?
|
||||||
Tensor group_size_val = ca_->Scalar(group_size_);
|
Tensor group_size_val = ca_->Scalar(group_size_);
|
||||||
if (col_params_->group.device_type != "CPU") {
|
if (col_params_->group.device_type != "CPU") {
|
||||||
group_size_tensor_ = ca_->Scalar(col_ctx_->device->GetAllocator(
|
uint64 safe_alloc_frontier = col_ctx_->device->SafeAllocFrontier(0);
|
||||||
col_ctx_->op_ctx->input_alloc_attr(0)));
|
AllocationAttributes aa;
|
||||||
|
std::function<uint64()> freed_by_func = [this, &safe_alloc_frontier]() {
|
||||||
|
safe_alloc_frontier =
|
||||||
|
col_ctx_->device->SafeAllocFrontier(safe_alloc_frontier);
|
||||||
|
return safe_alloc_frontier;
|
||||||
|
};
|
||||||
|
if (safe_alloc_frontier > 0) {
|
||||||
|
aa.freed_by_func = &freed_by_func;
|
||||||
|
}
|
||||||
|
group_size_tensor_ = ca_->Scalar(
|
||||||
|
col_ctx_->device->GetAllocator(col_ctx_->op_ctx->input_alloc_attr(0)),
|
||||||
|
aa);
|
||||||
DeviceContext* op_dev_ctx = col_ctx_->op_ctx->op_device_context();
|
DeviceContext* op_dev_ctx = col_ctx_->op_ctx->op_device_context();
|
||||||
op_dev_ctx->CopyCPUTensorToDevice(&group_size_val, col_ctx_->device,
|
op_dev_ctx->CopyCPUTensorToDevice(
|
||||||
&group_size_tensor_,
|
&group_size_val, col_ctx_->device, &group_size_tensor_,
|
||||||
[this](const Status& s) {
|
[this](const Status& s) {
|
||||||
if (!s.ok()) {
|
if (!s.ok()) {
|
||||||
StartAbort(s);
|
StartAbort(s);
|
||||||
}
|
}
|
||||||
group_size_tensor_ready_.Notify();
|
group_size_tensor_ready_.Notify();
|
||||||
});
|
},
|
||||||
|
(safe_alloc_frontier == 0));
|
||||||
} else {
|
} else {
|
||||||
group_size_tensor_ = group_size_val;
|
group_size_tensor_ = group_size_val;
|
||||||
group_size_tensor_ready_.Notify();
|
group_size_tensor_ready_.Notify();
|
||||||
|
@ -269,19 +269,28 @@ void BaseRemoteRendezvous::SameWorkerRecvDone(
|
|||||||
attr.set_gpu_compatible(send_args.alloc_attrs.gpu_compatible() ||
|
attr.set_gpu_compatible(send_args.alloc_attrs.gpu_compatible() ||
|
||||||
recv_args.alloc_attrs.gpu_compatible());
|
recv_args.alloc_attrs.gpu_compatible());
|
||||||
Allocator* out_allocator = dst_device->GetAllocator(attr);
|
Allocator* out_allocator = dst_device->GetAllocator(attr);
|
||||||
|
AllocationAttributes allocation_attr;
|
||||||
|
uint64 safe_alloc_frontier = dst_device->SafeAllocFrontier(0);
|
||||||
|
bool sync_dst_compute = (safe_alloc_frontier == 0);
|
||||||
|
std::function<uint64()> freed_by_func = [dst_device, &safe_alloc_frontier]() {
|
||||||
|
safe_alloc_frontier = dst_device->SafeAllocFrontier(safe_alloc_frontier);
|
||||||
|
return safe_alloc_frontier;
|
||||||
|
};
|
||||||
|
if (!sync_dst_compute) {
|
||||||
|
allocation_attr.freed_by_func = &freed_by_func;
|
||||||
|
}
|
||||||
if (in.dtype() != DT_VARIANT) {
|
if (in.dtype() != DT_VARIANT) {
|
||||||
// Variants are handled by CopyTensor::ViaDMA.
|
// Variants are handled by CopyTensor::ViaDMA.
|
||||||
Tensor copy(out_allocator, in.dtype(), in.shape());
|
Tensor copy(out_allocator, in.dtype(), in.shape(), allocation_attr);
|
||||||
*out = copy;
|
*out = copy;
|
||||||
}
|
}
|
||||||
|
|
||||||
// The following function takes care of cpu->gpu, gpu->cpu, gpu->gpu copies,
|
// The following function takes care of cpu->gpu, gpu->cpu, gpu->gpu copies,
|
||||||
// etc.
|
// etc.
|
||||||
CopyTensor::ViaDMA(parsed.edge_name, send_args.device_context,
|
CopyTensor::ViaDMA(
|
||||||
recv_args.device_context, src_device, dst_device,
|
parsed.edge_name, send_args.device_context, recv_args.device_context,
|
||||||
send_args.alloc_attrs, recv_args.alloc_attrs, &in, out,
|
src_device, dst_device, send_args.alloc_attrs, recv_args.alloc_attrs, &in,
|
||||||
0 /*dev_to_dev_stream_index*/, std::move(done));
|
out, 0 /*dev_to_dev_stream_index*/, std::move(done), sync_dst_compute);
|
||||||
}
|
}
|
||||||
|
|
||||||
bool BaseRemoteRendezvous::IsSameWorker(DeviceNameUtils::ParsedName src,
|
bool BaseRemoteRendezvous::IsSameWorker(DeviceNameUtils::ParsedName src,
|
||||||
|
@ -22,7 +22,6 @@ limitations under the License.
|
|||||||
|
|
||||||
#include "absl/strings/string_view.h"
|
#include "absl/strings/string_view.h"
|
||||||
#include "absl/types/optional.h"
|
#include "absl/types/optional.h"
|
||||||
#include "tensorflow/core/framework/allocator.h"
|
|
||||||
#include "tensorflow/core/framework/numeric_types.h"
|
#include "tensorflow/core/framework/numeric_types.h"
|
||||||
#include "tensorflow/core/framework/resource_handle.h"
|
#include "tensorflow/core/framework/resource_handle.h"
|
||||||
#include "tensorflow/core/framework/type_traits.h"
|
#include "tensorflow/core/framework/type_traits.h"
|
||||||
@ -42,10 +41,10 @@ struct AllocationAttributes {
|
|||||||
AllocationAttributes() = default;
|
AllocationAttributes() = default;
|
||||||
|
|
||||||
AllocationAttributes(bool no_retry_on_failure, bool allocation_will_be_logged,
|
AllocationAttributes(bool no_retry_on_failure, bool allocation_will_be_logged,
|
||||||
std::function<uint64()> freed_by_func)
|
std::function<uint64()>* freed_by_func)
|
||||||
: no_retry_on_failure(no_retry_on_failure),
|
: no_retry_on_failure(no_retry_on_failure),
|
||||||
allocation_will_be_logged(allocation_will_be_logged),
|
allocation_will_be_logged(allocation_will_be_logged),
|
||||||
freed_by_func(std::move(freed_by_func)) {}
|
freed_by_func(freed_by_func) {}
|
||||||
|
|
||||||
// If the first attempt to allocate the memory fails, the allocation
|
// If the first attempt to allocate the memory fails, the allocation
|
||||||
// should return immediately without retrying.
|
// should return immediately without retrying.
|
||||||
@ -59,9 +58,9 @@ struct AllocationAttributes {
|
|||||||
// true.
|
// true.
|
||||||
bool allocation_will_be_logged = false;
|
bool allocation_will_be_logged = false;
|
||||||
// EXPERIMENTAL: If provided, then evaluates to a timing count such that only
|
// EXPERIMENTAL: If provided, then evaluates to a timing count such that only
|
||||||
// a memory chunk whose last-freed count is at this value or earlier may be
|
// a memory chunk whose freed_at_count is at this value or earlier may be
|
||||||
// returned.
|
// returned.
|
||||||
std::function<uint64()> freed_by_func = nullptr;
|
std::function<uint64()>* freed_by_func = nullptr; // Not owned.
|
||||||
|
|
||||||
TF_DISALLOW_COPY_AND_ASSIGN(AllocationAttributes);
|
TF_DISALLOW_COPY_AND_ASSIGN(AllocationAttributes);
|
||||||
};
|
};
|
||||||
@ -223,6 +222,8 @@ class Allocator {
|
|||||||
// Clears the internal stats except for the `in_use` field.
|
// Clears the internal stats except for the `in_use` field.
|
||||||
virtual void ClearStats() {}
|
virtual void ClearStats() {}
|
||||||
|
|
||||||
|
virtual void SetSafeFrontier(uint64 count) {}
|
||||||
|
|
||||||
private:
|
private:
|
||||||
// No constructors or destructors are run for simple types
|
// No constructors or destructors are run for simple types
|
||||||
template <typename T>
|
template <typename T>
|
||||||
|
@ -74,11 +74,11 @@ class DeviceContext : public core::RefCounted {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// "cpu_tensor" is a tensor on a CPU. Copies "cpu_tensor" into
|
// "cpu_tensor" is a tensor on a CPU. Copies "cpu_tensor" into
|
||||||
// "device_tensor" which is on a GPU device "device". "device_tensor"
|
// "device_tensor" which is on a non-CPU device "device". "device_tensor"
|
||||||
// must be allocated to be of the same size as "cpu_tensor".
|
// must be allocated to be of the same size as "cpu_tensor".
|
||||||
virtual void CopyCPUTensorToDevice(const Tensor* cpu_tensor, Device* device,
|
virtual void CopyCPUTensorToDevice(const Tensor* cpu_tensor, Device* device,
|
||||||
Tensor* device_tensor,
|
Tensor* device_tensor, StatusCallback done,
|
||||||
StatusCallback done) const {
|
bool sync_dst_compute = true) const {
|
||||||
done(errors::Internal("Unrecognized device type in CPU-to-device Copy"));
|
done(errors::Internal("Unrecognized device type in CPU-to-device Copy"));
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -253,7 +253,7 @@ class DeviceBase {
|
|||||||
// device memory tagged with an earlier freed-at count is really unencumbered
|
// device memory tagged with an earlier freed-at count is really unencumbered
|
||||||
// by pending uses. For this to be useful the device memory allocator must
|
// by pending uses. For this to be useful the device memory allocator must
|
||||||
// be tagging deallocated memory chunks using the same counter.
|
// be tagging deallocated memory chunks using the same counter.
|
||||||
virtual uint64 SafeAllocFrontier() { return 0; }
|
virtual uint64 SafeAllocFrontier(uint64 old_value) { return 0; }
|
||||||
|
|
||||||
// Copies `input_tensor` to `output_tensor`, where both tensors are on this
|
// Copies `input_tensor` to `output_tensor`, where both tensors are on this
|
||||||
// device. This function assumes that `output_tensor` has already been
|
// device. This function assumes that `output_tensor` has already been
|
||||||
|
@ -751,6 +751,9 @@ Status OpKernelContext::allocate_temp(
|
|||||||
int64 alloc_size = a->AllocatedSize(out_temp->tensor_data().data());
|
int64 alloc_size = a->AllocatedSize(out_temp->tensor_data().data());
|
||||||
record_temp_memory_allocation(alloc_size, *out_temp);
|
record_temp_memory_allocation(alloc_size, *out_temp);
|
||||||
}
|
}
|
||||||
|
} else if (record_memory_consumption_) {
|
||||||
|
mutex_lock l(stats_mu_);
|
||||||
|
temp_memory_allocated_ += out_temp->TotalBytes();
|
||||||
}
|
}
|
||||||
return s;
|
return s;
|
||||||
}
|
}
|
||||||
@ -775,6 +778,10 @@ Status OpKernelContext::allocate_persistent(DataType type,
|
|||||||
int64 alloc_id = a->AllocationId(t->tensor_data().data());
|
int64 alloc_id = a->AllocationId(t->tensor_data().data());
|
||||||
record_persistent_memory_allocation(alloc_size, alloc_id);
|
record_persistent_memory_allocation(alloc_size, alloc_id);
|
||||||
}
|
}
|
||||||
|
} else if (record_memory_consumption_) {
|
||||||
|
mutex_lock l(stats_mu_);
|
||||||
|
persistent_memory_allocated_ +=
|
||||||
|
out_persistent->AccessTensor(this)->TotalBytes();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return s;
|
return s;
|
||||||
|
@ -1225,6 +1225,8 @@ class OpKernelContext {
|
|||||||
|
|
||||||
bool input_is_ref(int index) const;
|
bool input_is_ref(int index) const;
|
||||||
|
|
||||||
|
void set_record_memory_consumption(bool v) { record_memory_consumption_ = v; }
|
||||||
|
|
||||||
// Used by OpKernel implementations to track actively running deferred ops.
|
// Used by OpKernel implementations to track actively running deferred ops.
|
||||||
//
|
//
|
||||||
// A deferred op is one whose Compute method returns (or whose ComputeAsync
|
// A deferred op is one whose Compute method returns (or whose ComputeAsync
|
||||||
@ -1245,6 +1247,7 @@ class OpKernelContext {
|
|||||||
|
|
||||||
private:
|
private:
|
||||||
Allocator* get_allocator(AllocatorAttributes attr);
|
Allocator* get_allocator(AllocatorAttributes attr);
|
||||||
|
bool record_memory_consumption_ = false;
|
||||||
|
|
||||||
// Internal method to add a tensor's buffer to the list of buffers
|
// Internal method to add a tensor's buffer to the list of buffers
|
||||||
// referenced during the execution of the Op, so that GPUs may
|
// referenced during the execution of the Op, so that GPUs may
|
||||||
|
@ -165,9 +165,26 @@ message GPUOptions {
|
|||||||
// is really not subject to pending use.
|
// is really not subject to pending use.
|
||||||
bool timestamped_allocator = 5;
|
bool timestamped_allocator = 5;
|
||||||
|
|
||||||
// If > 0 limit the number of pending kernels on any compute
|
// reserved id: 6
|
||||||
// stream to this number.
|
|
||||||
int32 pending_cap = 6;
|
// Parameters for GPUKernelTracker. By default no kernel tracking is done.
|
||||||
|
// Note that timestamped_allocator is only effective if some tracking is
|
||||||
|
// specified.
|
||||||
|
//
|
||||||
|
// If kernel_tracker_max_interval = n > 0, then a tracking event
|
||||||
|
// is inserted after every n kernels without an event.
|
||||||
|
int32 kernel_tracker_max_interval = 7;
|
||||||
|
// If kernel_tracker_max_bytes = n > 0, then a tracking event is
|
||||||
|
// inserted after every series of kernels allocating a sum of
|
||||||
|
// memory >= n. If one kernel allocates b * n bytes, then one
|
||||||
|
// event will be inserted after it, but it will count as b against
|
||||||
|
// the pending limit.
|
||||||
|
int32 kernel_tracker_max_bytes = 8;
|
||||||
|
// If kernel_tracker_max_pending > 0 then no more than this many
|
||||||
|
// tracking events can be outstanding at a time. An attempt to
|
||||||
|
// launch an additional kernel will stall until an event
|
||||||
|
// completes.
|
||||||
|
int32 kernel_tracker_max_pending = 9;
|
||||||
}
|
}
|
||||||
|
|
||||||
// Everything inside experimental is subject to change and is not subject
|
// Everything inside experimental is subject to change and is not subject
|
||||||
|
@ -91,8 +91,20 @@ tf_proto {
|
|||||||
type: TYPE_BOOL
|
type: TYPE_BOOL
|
||||||
}
|
}
|
||||||
field {
|
field {
|
||||||
name: "pending_cap"
|
name: "kernel_tracker_max_interval"
|
||||||
number: 6
|
number: 7
|
||||||
|
label: LABEL_OPTIONAL
|
||||||
|
type: TYPE_INT32
|
||||||
|
}
|
||||||
|
field {
|
||||||
|
name: "kernel_tracker_max_bytes"
|
||||||
|
number: 8
|
||||||
|
label: LABEL_OPTIONAL
|
||||||
|
type: TYPE_INT32
|
||||||
|
}
|
||||||
|
field {
|
||||||
|
name: "kernel_tracker_max_pending"
|
||||||
|
number: 9
|
||||||
label: LABEL_OPTIONAL
|
label: LABEL_OPTIONAL
|
||||||
type: TYPE_INT32
|
type: TYPE_INT32
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user