diff --git a/tensorflow/lite/micro/simple_memory_allocator.cc b/tensorflow/lite/micro/simple_memory_allocator.cc index 18e9c5d4711..40593734044 100644 --- a/tensorflow/lite/micro/simple_memory_allocator.cc +++ b/tensorflow/lite/micro/simple_memory_allocator.cc @@ -131,13 +131,13 @@ size_t SimpleMemoryAllocator::GetTailUsedBytes() const { } size_t SimpleMemoryAllocator::GetAvailableMemory(size_t alignment) const { - uint8_t* const aligned_head = AlignPointerUp(head_, alignment); + uint8_t* const aligned_temp = AlignPointerUp(temp_, alignment); uint8_t* const aligned_tail = AlignPointerDown(tail_, alignment); - return aligned_tail - aligned_head; + return aligned_tail - aligned_temp; } size_t SimpleMemoryAllocator::GetUsedBytes() const { - return GetBufferSize() - (tail_ - head_); + return GetBufferSize() - (tail_ - temp_); } size_t SimpleMemoryAllocator::GetBufferSize() const { diff --git a/tensorflow/lite/micro/simple_memory_allocator.h b/tensorflow/lite/micro/simple_memory_allocator.h index d6ef4180847..6c353c84d8a 100644 --- a/tensorflow/lite/micro/simple_memory_allocator.h +++ b/tensorflow/lite/micro/simple_memory_allocator.h @@ -81,9 +81,12 @@ class SimpleMemoryAllocator { size_t GetHeadUsedBytes() const; size_t GetTailUsedBytes() const; - // Returns the number of bytes available with a given alignment. + // Returns the number of bytes available with a given alignment. This number + // takes in account any temporary allocations. size_t GetAvailableMemory(size_t alignment) const; + // Returns the number of used bytes in the allocator. This number takes in + // account any temporary allocations. size_t GetUsedBytes() const; private: diff --git a/tensorflow/lite/micro/simple_memory_allocator_test.cc b/tensorflow/lite/micro/simple_memory_allocator_test.cc index cafe2bf3b47..e2a9f972fff 100644 --- a/tensorflow/lite/micro/simple_memory_allocator_test.cc +++ b/tensorflow/lite/micro/simple_memory_allocator_test.cc @@ -98,6 +98,78 @@ TF_LITE_MICRO_TEST(TestAdjustHeadSizeMisalignedHandlesCorrectBytesAvailable) { TF_LITE_MICRO_EXPECT_GE(aligned_available_bytes, arena_size - 1000 - 24); } +TF_LITE_MICRO_TEST(TestGetAvailableMemory) { + constexpr size_t arena_size = 1024; + uint8_t arena[arena_size]; + tflite::SimpleMemoryAllocator allocator(micro_test::reporter, arena, + arena_size); + + constexpr size_t allocation_size = 100; + allocator.SetHeadSize(/*size=*/allocation_size, + /*alignment=*/1); + allocator.AllocateFromTail(/*size=*/allocation_size, + /*alignment=*/1); + + TF_LITE_MICRO_EXPECT_EQ(allocator.GetAvailableMemory(/*alignment=*/1), + arena_size - allocation_size * 2); +} + +TF_LITE_MICRO_TEST(TestGetAvailableMemoryWithTempAllocations) { + constexpr size_t arena_size = 1024; + uint8_t arena[arena_size]; + tflite::SimpleMemoryAllocator allocator(micro_test::reporter, arena, + arena_size); + + constexpr size_t allocation_size = 100; + allocator.AllocateTemp(/*size=*/allocation_size, + /*alignment=*/1); + + TF_LITE_MICRO_EXPECT_EQ(allocator.GetAvailableMemory(/*alignment=*/1), + arena_size - allocation_size); + + // Reset temp allocations and ensure GetAvailableMemory() is back to the + // starting size: + allocator.ResetTempAllocations(); + + TF_LITE_MICRO_EXPECT_EQ(allocator.GetAvailableMemory(/*alignment=*/1), + arena_size); +} + +TF_LITE_MICRO_TEST(TestGetUsedBytes) { + constexpr size_t arena_size = 1024; + uint8_t arena[arena_size]; + tflite::SimpleMemoryAllocator allocator(micro_test::reporter, arena, + arena_size); + TF_LITE_MICRO_EXPECT_EQ(allocator.GetUsedBytes(), static_cast(0)); + + constexpr size_t allocation_size = 100; + allocator.SetHeadSize(/*size=*/allocation_size, + /*alignment=*/1); + allocator.AllocateFromTail(/*size=*/allocation_size, + /*alignment=*/1); + + TF_LITE_MICRO_EXPECT_EQ(allocator.GetUsedBytes(), allocation_size * 2); +} + +TF_LITE_MICRO_TEST(TestGetUsedBytesTempAllocations) { + constexpr size_t arena_size = 1024; + uint8_t arena[arena_size]; + tflite::SimpleMemoryAllocator allocator(micro_test::reporter, arena, + arena_size); + + constexpr size_t allocation_size = 100; + allocator.AllocateTemp(/*size=*/allocation_size, + /*alignment=*/1); + + TF_LITE_MICRO_EXPECT_EQ(allocator.GetUsedBytes(), allocation_size); + + // Reset temp allocations and ensure GetUsedBytes() is back to the starting + // size: + allocator.ResetTempAllocations(); + + TF_LITE_MICRO_EXPECT_EQ(allocator.GetUsedBytes(), static_cast(0)); +} + TF_LITE_MICRO_TEST(TestJustFits) { constexpr size_t arena_size = 1024; uint8_t arena[arena_size];