Report allocation failures in SimpleMemoryAllocator.

PiperOrigin-RevId: 306899890
Change-Id: I3fcd27eb68034587d9a5fd3eb28158f64e77ade9
This commit is contained in:
Robert David 2020-04-16 12:31:29 -07:00 committed by TensorFlower Gardener
parent faa58d6e94
commit e4eb45ee8a
5 changed files with 44 additions and 23 deletions

View File

@ -444,9 +444,8 @@ MicroAllocator::MicroAllocator(TfLiteContext* context, const Model* model,
// Creates a root memory allocator managing the arena. The allocator itself // Creates a root memory allocator managing the arena. The allocator itself
// also locates in the arena buffer. This allocator doesn't need to be // also locates in the arena buffer. This allocator doesn't need to be
// destructed as it's the root allocator. // destructed as it's the root allocator.
SimpleMemoryAllocator* aligned_allocator = memory_allocator_ = CreateInPlaceSimpleMemoryAllocator(
CreateInPlaceSimpleMemoryAllocator(aligned_arena, aligned_arena_size); error_reporter, aligned_arena, aligned_arena_size);
memory_allocator_ = aligned_allocator;
TfLiteStatus status = Init(); TfLiteStatus status = Init();
// TODO(b/147871299): Consider improving this code. A better way of handling // TODO(b/147871299): Consider improving this code. A better way of handling
// failures in the constructor is to have a static function that returns a // failures in the constructor is to have a static function that returns a
@ -558,7 +557,8 @@ TfLiteStatus MicroAllocator::FinishTensorAllocation() {
// Note that AllocationInfo is only needed for creating the plan. It will be // Note that AllocationInfo is only needed for creating the plan. It will be
// thrown away when the child allocator (tmp_allocator) goes out of scope. // thrown away when the child allocator (tmp_allocator) goes out of scope.
{ {
SimpleMemoryAllocator tmp_allocator(memory_allocator_->GetHead(), SimpleMemoryAllocator tmp_allocator(error_reporter_,
memory_allocator_->GetHead(),
memory_allocator_->GetTail()); memory_allocator_->GetTail());
AllocationInfoBuilder builder(error_reporter_, &tmp_allocator); AllocationInfoBuilder builder(error_reporter_, &tmp_allocator);

View File

@ -68,7 +68,8 @@ TF_LITE_MICRO_TEST(TestInitializeRuntimeTensor) {
TfLiteContext context; TfLiteContext context;
constexpr size_t arena_size = 1024; constexpr size_t arena_size = 1024;
uint8_t arena[arena_size]; uint8_t arena[arena_size];
tflite::SimpleMemoryAllocator simple_allocator(arena, arena_size); tflite::SimpleMemoryAllocator simple_allocator(micro_test::reporter, arena,
arena_size);
const tflite::Tensor* tensor = tflite::testing::Create1dFlatbufferTensor(100); const tflite::Tensor* tensor = tflite::testing::Create1dFlatbufferTensor(100);
const flatbuffers::Vector<flatbuffers::Offset<tflite::Buffer>>* buffers = const flatbuffers::Vector<flatbuffers::Offset<tflite::Buffer>>* buffers =
@ -92,7 +93,8 @@ TF_LITE_MICRO_TEST(TestInitializeQuantizedTensor) {
TfLiteContext context; TfLiteContext context;
constexpr size_t arena_size = 1024; constexpr size_t arena_size = 1024;
uint8_t arena[arena_size]; uint8_t arena[arena_size];
tflite::SimpleMemoryAllocator simple_allocator(arena, arena_size); tflite::SimpleMemoryAllocator simple_allocator(micro_test::reporter, arena,
arena_size);
const tflite::Tensor* tensor = const tflite::Tensor* tensor =
tflite::testing::CreateQuantizedFlatbufferTensor(100); tflite::testing::CreateQuantizedFlatbufferTensor(100);
@ -117,7 +119,8 @@ TF_LITE_MICRO_TEST(TestMissingQuantization) {
TfLiteContext context; TfLiteContext context;
constexpr size_t arena_size = 1024; constexpr size_t arena_size = 1024;
uint8_t arena[arena_size]; uint8_t arena[arena_size];
tflite::SimpleMemoryAllocator simple_allocator(arena, arena_size); tflite::SimpleMemoryAllocator simple_allocator(micro_test::reporter, arena,
arena_size);
const tflite::Tensor* tensor = const tflite::Tensor* tensor =
tflite::testing::CreateMissingQuantizationFlatbufferTensor(100); tflite::testing::CreateMissingQuantizationFlatbufferTensor(100);

View File

@ -22,9 +22,10 @@ limitations under the License.
namespace tflite { namespace tflite {
SimpleMemoryAllocator* CreateInPlaceSimpleMemoryAllocator(uint8_t* buffer, SimpleMemoryAllocator* CreateInPlaceSimpleMemoryAllocator(
size_t buffer_size) { ErrorReporter* error_reporter, uint8_t* buffer, size_t buffer_size) {
SimpleMemoryAllocator tmp = SimpleMemoryAllocator(buffer, buffer_size); SimpleMemoryAllocator tmp =
SimpleMemoryAllocator(error_reporter, buffer, buffer_size);
SimpleMemoryAllocator* in_place_allocator = SimpleMemoryAllocator* in_place_allocator =
reinterpret_cast<SimpleMemoryAllocator*>(tmp.AllocateFromTail( reinterpret_cast<SimpleMemoryAllocator*>(tmp.AllocateFromTail(
sizeof(SimpleMemoryAllocator), alignof(SimpleMemoryAllocator))); sizeof(SimpleMemoryAllocator), alignof(SimpleMemoryAllocator)));
@ -34,10 +35,13 @@ SimpleMemoryAllocator* CreateInPlaceSimpleMemoryAllocator(uint8_t* buffer,
uint8_t* SimpleMemoryAllocator::AllocateFromHead(size_t size, uint8_t* SimpleMemoryAllocator::AllocateFromHead(size_t size,
size_t alignment) { size_t alignment) {
uint8_t* aligned_result = AlignPointerUp(head_, alignment); uint8_t* const aligned_result = AlignPointerUp(head_, alignment);
size_t available_memory = tail_ - aligned_result; const size_t available_memory = tail_ - aligned_result;
if (available_memory < size) { if (available_memory < size) {
// TODO(petewarden): Add error reporting beyond returning null! TF_LITE_REPORT_ERROR(
error_reporter_,
"Failed to allocate memory. Requested: %u, available %u, missing: %u",
size, available_memory, size - available_memory);
return nullptr; return nullptr;
} }
head_ = aligned_result + size; head_ = aligned_result + size;
@ -46,8 +50,13 @@ uint8_t* SimpleMemoryAllocator::AllocateFromHead(size_t size,
uint8_t* SimpleMemoryAllocator::AllocateFromTail(size_t size, uint8_t* SimpleMemoryAllocator::AllocateFromTail(size_t size,
size_t alignment) { size_t alignment) {
uint8_t* aligned_result = AlignPointerDown(tail_ - size, alignment); uint8_t* const aligned_result = AlignPointerDown(tail_ - size, alignment);
if (aligned_result < head_) { if (aligned_result < head_) {
const size_t missing_memory = head_ - aligned_result;
TF_LITE_REPORT_ERROR(
error_reporter_,
"Failed to allocate memory. Requested: %u, available %u, missing: %u",
size, size - missing_memory, missing_memory);
return nullptr; return nullptr;
} }
tail_ = aligned_result; tail_ = aligned_result;

View File

@ -19,6 +19,7 @@ limitations under the License.
#include <cstdint> #include <cstdint>
#include "tensorflow/lite/c/common.h" #include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/core/api/error_reporter.h"
namespace tflite { namespace tflite {
@ -27,10 +28,14 @@ namespace tflite {
// This makes it pretty wasteful, so we should use a more intelligent method. // This makes it pretty wasteful, so we should use a more intelligent method.
class SimpleMemoryAllocator { class SimpleMemoryAllocator {
public: public:
SimpleMemoryAllocator(uint8_t* buffer_head, uint8_t* buffer_tail) SimpleMemoryAllocator(ErrorReporter* error_reporter, uint8_t* buffer_head,
: head_(buffer_head), tail_(buffer_tail) {} uint8_t* buffer_tail)
SimpleMemoryAllocator(uint8_t* buffer, size_t buffer_size) : error_reporter_(error_reporter),
: SimpleMemoryAllocator(buffer, buffer + buffer_size) {} head_(buffer_head),
tail_(buffer_tail) {}
SimpleMemoryAllocator(ErrorReporter* error_reporter, uint8_t* buffer,
size_t buffer_size)
: SimpleMemoryAllocator(error_reporter, buffer, buffer + buffer_size) {}
// Allocates memory starting at the head of the arena (lowest address and // Allocates memory starting at the head of the arena (lowest address and
// moving upwards). // moving upwards).
@ -44,14 +49,15 @@ class SimpleMemoryAllocator {
size_t GetAvailableMemory() const { return tail_ - head_; } size_t GetAvailableMemory() const { return tail_ - head_; }
private: private:
ErrorReporter* error_reporter_;
uint8_t* head_; uint8_t* head_;
uint8_t* tail_; uint8_t* tail_;
}; };
// Allocate a SimpleMemoryAllocator from the buffer and then return the pointer // Allocate a SimpleMemoryAllocator from the buffer and then return the pointer
// to this allocator. // to this allocator.
SimpleMemoryAllocator* CreateInPlaceSimpleMemoryAllocator(uint8_t* buffer, SimpleMemoryAllocator* CreateInPlaceSimpleMemoryAllocator(
size_t buffer_size); ErrorReporter* error_reporter, uint8_t* buffer, size_t buffer_size);
} // namespace tflite } // namespace tflite

View File

@ -25,7 +25,8 @@ TF_LITE_MICRO_TESTS_BEGIN
TF_LITE_MICRO_TEST(TestJustFits) { TF_LITE_MICRO_TEST(TestJustFits) {
constexpr size_t arena_size = 1024; constexpr size_t arena_size = 1024;
uint8_t arena[arena_size]; uint8_t arena[arena_size];
tflite::SimpleMemoryAllocator allocator(arena, arena_size); tflite::SimpleMemoryAllocator allocator(micro_test::reporter, arena,
arena_size);
uint8_t* result = allocator.AllocateFromTail(arena_size, 1); uint8_t* result = allocator.AllocateFromTail(arena_size, 1);
TF_LITE_MICRO_EXPECT_NE(nullptr, result); TF_LITE_MICRO_EXPECT_NE(nullptr, result);
@ -34,7 +35,8 @@ TF_LITE_MICRO_TEST(TestJustFits) {
TF_LITE_MICRO_TEST(TestAligned) { TF_LITE_MICRO_TEST(TestAligned) {
constexpr size_t arena_size = 1024; constexpr size_t arena_size = 1024;
uint8_t arena[arena_size]; uint8_t arena[arena_size];
tflite::SimpleMemoryAllocator allocator(arena, arena_size); tflite::SimpleMemoryAllocator allocator(micro_test::reporter, arena,
arena_size);
uint8_t* result = allocator.AllocateFromTail(1, 1); uint8_t* result = allocator.AllocateFromTail(1, 1);
TF_LITE_MICRO_EXPECT_NE(nullptr, result); TF_LITE_MICRO_EXPECT_NE(nullptr, result);
@ -47,7 +49,8 @@ TF_LITE_MICRO_TEST(TestAligned) {
TF_LITE_MICRO_TEST(TestMultipleTooLarge) { TF_LITE_MICRO_TEST(TestMultipleTooLarge) {
constexpr size_t arena_size = 1024; constexpr size_t arena_size = 1024;
uint8_t arena[arena_size]; uint8_t arena[arena_size];
tflite::SimpleMemoryAllocator allocator(arena, arena_size); tflite::SimpleMemoryAllocator allocator(micro_test::reporter, arena,
arena_size);
uint8_t* result = allocator.AllocateFromTail(768, 1); uint8_t* result = allocator.AllocateFromTail(768, 1);
TF_LITE_MICRO_EXPECT_NE(nullptr, result); TF_LITE_MICRO_EXPECT_NE(nullptr, result);