Add TensorBuffer::GetAllocatedBytes()
method.
We now call `Tensor::AllocatedBytes()` for every tensor fed into and fetched from a session. The current implementation incurs protobuf overhead to create the requisite AllocationDescription, then typically falls back to `Tensor::TotalBytes()` anyway. PiperOrigin-RevId: 266933172
This commit is contained in:
parent
89ce4c791c
commit
4fc96cee6e
@ -66,6 +66,17 @@ namespace tensorflow {
|
|||||||
// code).
|
// code).
|
||||||
REGISTER_UNARY_VARIANT_DECODE_FUNCTION(Tensor, "tensorflow::Tensor");
|
REGISTER_UNARY_VARIANT_DECODE_FUNCTION(Tensor, "tensorflow::Tensor");
|
||||||
|
|
||||||
|
bool TensorBuffer::GetAllocatedBytes(size_t* out_bytes) const {
|
||||||
|
AllocationDescription allocation_description;
|
||||||
|
FillAllocationDescription(&allocation_description);
|
||||||
|
if (allocation_description.allocated_bytes() > 0) {
|
||||||
|
*out_bytes = allocation_description.allocated_bytes();
|
||||||
|
return true;
|
||||||
|
} else {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
|
|
||||||
// An un-templated base class for Buffer.
|
// An un-templated base class for Buffer.
|
||||||
@ -75,6 +86,16 @@ class BufferBase : public TensorBuffer {
|
|||||||
: TensorBuffer(data_ptr), alloc_(alloc) {}
|
: TensorBuffer(data_ptr), alloc_(alloc) {}
|
||||||
|
|
||||||
TensorBuffer* root_buffer() override { return this; }
|
TensorBuffer* root_buffer() override { return this; }
|
||||||
|
|
||||||
|
bool GetAllocatedBytes(size_t* out_bytes) const override {
|
||||||
|
if (alloc_->TracksAllocationSizes()) {
|
||||||
|
*out_bytes = alloc_->AllocatedSize(data());
|
||||||
|
return *out_bytes > 0;
|
||||||
|
} else {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
void FillAllocationDescription(AllocationDescription* proto) const override {
|
void FillAllocationDescription(AllocationDescription* proto) const override {
|
||||||
void* data_ptr = data();
|
void* data_ptr = data();
|
||||||
int64 rb = size();
|
int64 rb = size();
|
||||||
@ -784,6 +805,13 @@ static Allocator* get_default_cpu_allocator() {
|
|||||||
Tensor::Tensor(DataType type, const TensorShape& shape)
|
Tensor::Tensor(DataType type, const TensorShape& shape)
|
||||||
: Tensor(get_default_cpu_allocator(), type, shape) {}
|
: Tensor(get_default_cpu_allocator(), type, shape) {}
|
||||||
|
|
||||||
|
bool Tensor::HostScalarTensorBufferBase::GetAllocatedBytes(
|
||||||
|
size_t* out_bytes) const {
|
||||||
|
// `this->FillAllocationDescription()` never sets allocated bytes information,
|
||||||
|
// so we can short-circuit the construction of an `AllocationDescription`.
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
void Tensor::HostScalarTensorBufferBase::FillAllocationDescription(
|
void Tensor::HostScalarTensorBufferBase::FillAllocationDescription(
|
||||||
AllocationDescription* proto) const {
|
AllocationDescription* proto) const {
|
||||||
proto->set_requested_bytes(size());
|
proto->set_requested_bytes(size());
|
||||||
@ -811,6 +839,9 @@ class SubBuffer : public TensorBuffer {
|
|||||||
|
|
||||||
size_t size() const override { return sizeof(T) * elem_; }
|
size_t size() const override { return sizeof(T) * elem_; }
|
||||||
TensorBuffer* root_buffer() override { return root_; }
|
TensorBuffer* root_buffer() override { return root_; }
|
||||||
|
bool GetAllocatedBytes(size_t* out_bytes) const override {
|
||||||
|
return root_->GetAllocatedBytes(out_bytes);
|
||||||
|
}
|
||||||
void FillAllocationDescription(AllocationDescription* proto) const override {
|
void FillAllocationDescription(AllocationDescription* proto) const override {
|
||||||
root_->FillAllocationDescription(proto);
|
root_->FillAllocationDescription(proto);
|
||||||
}
|
}
|
||||||
@ -937,15 +968,13 @@ size_t Tensor::TotalBytes() const {
|
|||||||
}
|
}
|
||||||
|
|
||||||
size_t Tensor::AllocatedBytes() const {
|
size_t Tensor::AllocatedBytes() const {
|
||||||
TensorDescription tensor_description;
|
if (buf_) {
|
||||||
FillDescription(&tensor_description);
|
size_t ret;
|
||||||
if (tensor_description.has_allocation_description() &&
|
if (buf_->GetAllocatedBytes(&ret)) {
|
||||||
tensor_description.allocation_description().allocated_bytes() > 0) {
|
return ret;
|
||||||
return tensor_description.allocation_description().allocated_bytes();
|
}
|
||||||
} else {
|
|
||||||
// Fall back to TotalBytes() if the allocator doesn't have its size.
|
|
||||||
return TotalBytes();
|
|
||||||
}
|
}
|
||||||
|
return TotalBytes();
|
||||||
}
|
}
|
||||||
|
|
||||||
bool Tensor::CanUseDMA() const {
|
bool Tensor::CanUseDMA() const {
|
||||||
|
@ -79,6 +79,8 @@ class TensorBuffer : public core::RefCounted {
|
|||||||
virtual void FillAllocationDescription(
|
virtual void FillAllocationDescription(
|
||||||
AllocationDescription* proto) const = 0;
|
AllocationDescription* proto) const = 0;
|
||||||
|
|
||||||
|
virtual bool GetAllocatedBytes(size_t* out_bytes) const;
|
||||||
|
|
||||||
/// \brief Helper method to reinterpret the buffer as an array of `T`.
|
/// \brief Helper method to reinterpret the buffer as an array of `T`.
|
||||||
template <typename T>
|
template <typename T>
|
||||||
T* base() const {
|
T* base() const {
|
||||||
@ -940,6 +942,7 @@ inline Tensor::Tensor(Tensor&& other)
|
|||||||
class Tensor::HostScalarTensorBufferBase : public TensorBuffer {
|
class Tensor::HostScalarTensorBufferBase : public TensorBuffer {
|
||||||
public:
|
public:
|
||||||
using TensorBuffer::TensorBuffer;
|
using TensorBuffer::TensorBuffer;
|
||||||
|
bool GetAllocatedBytes(size_t* out_bytes) const final;
|
||||||
void FillAllocationDescription(AllocationDescription* proto) const final;
|
void FillAllocationDescription(AllocationDescription* proto) const final;
|
||||||
};
|
};
|
||||||
|
|
||||||
@ -950,7 +953,8 @@ template <typename T>
|
|||||||
struct Tensor::ValueAndTensorBuffer {
|
struct Tensor::ValueAndTensorBuffer {
|
||||||
class HostScalarTensorBuffer : public Tensor::HostScalarTensorBufferBase {
|
class HostScalarTensorBuffer : public Tensor::HostScalarTensorBufferBase {
|
||||||
public:
|
public:
|
||||||
HostScalarTensorBuffer(void* data) : HostScalarTensorBufferBase(data) {}
|
explicit HostScalarTensorBuffer(void* data)
|
||||||
|
: HostScalarTensorBufferBase(data) {}
|
||||||
size_t size() const final { return sizeof(T); }
|
size_t size() const final { return sizeof(T); }
|
||||||
TensorBuffer* root_buffer() final { return this; }
|
TensorBuffer* root_buffer() final { return this; }
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user