Add TensorBuffer::GetAllocatedBytes() method.

We now call `Tensor::AllocatedBytes()` for every tensor fed into and fetched from a session. The current implementation incurs protobuf overhead to create the requisite AllocationDescription, then typically falls back to `Tensor::TotalBytes()` anyway.

PiperOrigin-RevId: 266933172
Derek Murray 2019-09-03 08:24:01 -07:00 committed by TensorFlower Gardener
parent 89ce4c791c
commit 4fc96cee6e
2 changed files with 42 additions and 9 deletions
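
For context, a minimal sketch (not part of this commit) of the caller-facing behavior the change targets: `Tensor::TotalBytes()` is computed from dtype and shape, while `Tensor::AllocatedBytes()` now queries the underlying buffer through the new `GetAllocatedBytes()` virtual and falls back to `TotalBytes()` only when the allocator does not track sizes. The program below is illustrative only and assumes the TensorFlow C++ headers and framework library are available.

// Illustrative only: compares TotalBytes() and AllocatedBytes() for a small
// host tensor allocated with the default CPU allocator.
#include <cstdio>

#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/types.pb.h"

int main() {
  tensorflow::Tensor t(tensorflow::DT_FLOAT, tensorflow::TensorShape({2, 3}));
  // TotalBytes() is derived from dtype and shape (6 * 4 = 24 bytes here).
  // AllocatedBytes() now asks the TensorBuffer via GetAllocatedBytes() and
  // only falls back to TotalBytes() if the allocator does not track sizes.
  std::printf("total=%zu allocated=%zu\n", t.TotalBytes(), t.AllocatedBytes());
  return 0;
}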


@@ -66,6 +66,17 @@ namespace tensorflow {
 // code).
 REGISTER_UNARY_VARIANT_DECODE_FUNCTION(Tensor, "tensorflow::Tensor");
+bool TensorBuffer::GetAllocatedBytes(size_t* out_bytes) const {
+  AllocationDescription allocation_description;
+  FillAllocationDescription(&allocation_description);
+  if (allocation_description.allocated_bytes() > 0) {
+    *out_bytes = allocation_description.allocated_bytes();
+    return true;
+  } else {
+    return false;
+  }
+}
 namespace {
 // An un-templated base class for Buffer.
@@ -75,6 +86,16 @@ class BufferBase : public TensorBuffer {
       : TensorBuffer(data_ptr), alloc_(alloc) {}
   TensorBuffer* root_buffer() override { return this; }
+  bool GetAllocatedBytes(size_t* out_bytes) const override {
+    if (alloc_->TracksAllocationSizes()) {
+      *out_bytes = alloc_->AllocatedSize(data());
+      return *out_bytes > 0;
+    } else {
+      return false;
+    }
+  }
   void FillAllocationDescription(AllocationDescription* proto) const override {
     void* data_ptr = data();
     int64 rb = size();
@@ -784,6 +805,13 @@ static Allocator* get_default_cpu_allocator() {
 Tensor::Tensor(DataType type, const TensorShape& shape)
     : Tensor(get_default_cpu_allocator(), type, shape) {}
+bool Tensor::HostScalarTensorBufferBase::GetAllocatedBytes(
+    size_t* out_bytes) const {
+  // `this->FillAllocationDescription()` never sets allocated bytes information,
+  // so we can short-circuit the construction of an `AllocationDescription`.
+  return false;
+}
 void Tensor::HostScalarTensorBufferBase::FillAllocationDescription(
     AllocationDescription* proto) const {
   proto->set_requested_bytes(size());
@@ -811,6 +839,9 @@ class SubBuffer : public TensorBuffer {
   size_t size() const override { return sizeof(T) * elem_; }
   TensorBuffer* root_buffer() override { return root_; }
+  bool GetAllocatedBytes(size_t* out_bytes) const override {
+    return root_->GetAllocatedBytes(out_bytes);
+  }
   void FillAllocationDescription(AllocationDescription* proto) const override {
     root_->FillAllocationDescription(proto);
   }
@@ -937,15 +968,13 @@ size_t Tensor::TotalBytes() const {
 }
 size_t Tensor::AllocatedBytes() const {
-  TensorDescription tensor_description;
-  FillDescription(&tensor_description);
-  if (tensor_description.has_allocation_description() &&
-      tensor_description.allocation_description().allocated_bytes() > 0) {
-    return tensor_description.allocation_description().allocated_bytes();
-  } else {
-    // Fall back to TotalBytes() if the allocator doesn't have its size.
-    return TotalBytes();
-  }
+  if (buf_) {
+    size_t ret;
+    if (buf_->GetAllocatedBytes(&ret)) {
+      return ret;
+    }
+  }
+  return TotalBytes();
 }
 bool Tensor::CanUseDMA() const {


@@ -79,6 +79,8 @@ class TensorBuffer : public core::RefCounted {
   virtual void FillAllocationDescription(
       AllocationDescription* proto) const = 0;
+  virtual bool GetAllocatedBytes(size_t* out_bytes) const;
   /// \brief Helper method to reinterpret the buffer as an array of `T`.
   template <typename T>
   T* base() const {
@@ -940,6 +942,7 @@ inline Tensor::Tensor(Tensor&& other)
 class Tensor::HostScalarTensorBufferBase : public TensorBuffer {
  public:
   using TensorBuffer::TensorBuffer;
+  bool GetAllocatedBytes(size_t* out_bytes) const final;
   void FillAllocationDescription(AllocationDescription* proto) const final;
 };
@@ -950,7 +953,8 @@ template <typename T>
 struct Tensor::ValueAndTensorBuffer {
   class HostScalarTensorBuffer : public Tensor::HostScalarTensorBufferBase {
    public:
-    HostScalarTensorBuffer(void* data) : HostScalarTensorBufferBase(data) {}
+    explicit HostScalarTensorBuffer(void* data)
+        : HostScalarTensorBufferBase(data) {}
     size_t size() const final { return sizeof(T); }
     TensorBuffer* root_buffer() final { return this; }
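
To illustrate the new extension point, here is a hedged sketch (not part of this commit) of a custom `TensorBuffer` subclass that overrides `GetAllocatedBytes()`, mirroring the `BufferBase` and `SubBuffer` overrides above. The class name `FixedSizeBuffer` and its wrapping scheme are hypothetical; it assumes the TensorFlow headers shown in this diff.

// Hypothetical buffer that wraps a caller-provided allocation of known size.
// It overrides GetAllocatedBytes() so Tensor::AllocatedBytes() can report the
// real allocation without building an AllocationDescription proto.
#include "tensorflow/core/framework/allocation_description.pb.h"
#include "tensorflow/core/framework/tensor.h"

class FixedSizeBuffer : public tensorflow::TensorBuffer {
 public:
  FixedSizeBuffer(void* data, size_t size)
      : TensorBuffer(data), size_(size) {}

  size_t size() const override { return size_; }
  TensorBuffer* root_buffer() override { return this; }

  bool GetAllocatedBytes(size_t* out_bytes) const override {
    *out_bytes = size_;  // The exact allocation size is known up front.
    return true;
  }

  void FillAllocationDescription(
      tensorflow::AllocationDescription* proto) const override {
    proto->set_requested_bytes(size_);
    proto->set_allocated_bytes(size_);
  }

 private:
  size_t size_;
};

A buffer like this could back a tensor through the existing `Tensor(DataType, const TensorShape&, TensorBuffer*)` constructor, in which case `Tensor::AllocatedBytes()` would report `size_` directly instead of falling back to `TotalBytes()`.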