diff --git a/tensorflow/lite/micro/micro_allocator.cc b/tensorflow/lite/micro/micro_allocator.cc
index 7cd40e54435..16b3b986a52 100644
--- a/tensorflow/lite/micro/micro_allocator.cc
+++ b/tensorflow/lite/micro/micro_allocator.cc
@@ -375,10 +375,9 @@ TfLiteStatus CreatePlan(ErrorReporter* error_reporter,
             planner->AddBuffer(error_reporter, aligned_bytes_required,
                                current->first_created, current->last_used));
       } else {
-        TF_LITE_ENSURE_STATUS(
-            planner->AddBuffer(error_reporter, aligned_bytes_required,
-                               current->first_created, current->last_used,
-                               current->offline_offset));
+        TF_LITE_ENSURE_STATUS(planner->AddBuffer(
+            error_reporter, aligned_bytes_required, current->first_created,
+            current->last_used, current->offline_offset));
       }
     }
   }
@@ -647,7 +646,7 @@ TfLiteStatus MicroAllocator::FinishModelAllocation(const Model* model,
   const SubGraph* subgraph = GetSubGraphFromModel(model);
   TFLITE_DCHECK(subgraph != nullptr);
 
-  TF_LITE_ENSURE_STATUS(CommitStaticMemoryPlan(subgraph, context));
+  TF_LITE_ENSURE_STATUS(CommitStaticMemoryPlan(model, subgraph, context));
   TF_LITE_ENSURE_STATUS(AllocateVariables(subgraph->tensors(), context->tensors,
                                           memory_allocator_));
 
@@ -874,7 +873,8 @@ const SubGraph* MicroAllocator::GetSubGraphFromModel(const Model* model) {
   return (*subgraphs)[0];
 }
 
-TfLiteStatus MicroAllocator::CommitStaticMemoryPlan(const SubGraph* subgraph,
+TfLiteStatus MicroAllocator::CommitStaticMemoryPlan(const Model* model,
+                                                    const SubGraph* subgraph,
                                                     TfLiteContext* context) {
   // Create static memory plan
   // 1. Calculate AllocationInfo to know the lifetime of each tensor/buffer.
@@ -891,7 +891,13 @@ TfLiteStatus MicroAllocator::CommitStaticMemoryPlan(const SubGraph* subgraph,
   AllocationInfoBuilder builder(error_reporter_, &tmp_allocator);
   TF_LITE_ENSURE_STATUS(
       builder.Init(subgraph->tensors()->size(), scratch_buffer_count_));
-  TF_LITE_ENSURE_STATUS(builder.AddTensors(subgraph, context->tensors));
+
+  int32_t* offline_planner_offsets = nullptr;
+  TF_LITE_ENSURE_STATUS(
+      builder.GetOfflinePlannedOffsets(model, &offline_planner_offsets));
+  TF_LITE_ENSURE_STATUS(builder.AddTensors(subgraph, offline_planner_offsets,
+                                           context->tensors));
+
   TF_LITE_ENSURE_STATUS(builder.AddScratchBuffers(scratch_buffer_handles_));
   const AllocationInfo* allocation_info = builder.Finish();
diff --git a/tensorflow/lite/micro/micro_allocator.h b/tensorflow/lite/micro/micro_allocator.h
index 7fc091196a5..9cfc1793fc7 100644
--- a/tensorflow/lite/micro/micro_allocator.h
+++ b/tensorflow/lite/micro/micro_allocator.h
@@ -185,7 +185,8 @@ class MicroAllocator {
 
   // Commits a memory plan for all non-persistent buffer allocations in the
   // 'head' section of the memory arena.
-  virtual TfLiteStatus CommitStaticMemoryPlan(const SubGraph* subgraph,
+  virtual TfLiteStatus CommitStaticMemoryPlan(const Model* model,
+                                              const SubGraph* subgraph,
                                               TfLiteContext* context);
 
   // A simple memory allocator that always allocate from the arena tail or head.
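
The allocator changes above thread the Model through CommitStaticMemoryPlan so that AllocationInfoBuilder can fetch offline-planned tensor offsets (GetOfflinePlannedOffsets) before AddTensors runs. For orientation, the offline plan is carried in the model as a flat int32 metadata buffer: a three-entry header (version, subgraph index, tensor count) followed by one byte offset per tensor, and the tests below use -1 for tensors that should fall back to the online planner. The sketch that follows only illustrates that layout as exercised by the updated tests; OfflinePlanView and ParseOfflinePlan are hypothetical names for illustration, not part of MicroAllocator's API.

// Illustration of the offline-plan metadata layout assumed by the tests below.
// The actual consumption of this buffer is presumably handled by
// AllocationInfoBuilder::GetOfflinePlannedOffsets() in this patch.
#include <cstdint>

constexpr int kOfflinePlannerHeaderSize = 3;  // version, subgraph index, tensor count

struct OfflinePlanView {
  int32_t version;
  int32_t subgraph;
  int32_t tensor_count;
  const int32_t* offsets;  // one byte offset per tensor; the tests use -1 to
                           // mean "leave this tensor to the online planner"
};

inline OfflinePlanView ParseOfflinePlan(const int32_t* metadata_buffer) {
  OfflinePlanView view;
  view.version = metadata_buffer[0];
  view.subgraph = metadata_buffer[1];
  view.tensor_count = metadata_buffer[2];
  view.offsets = metadata_buffer + kOfflinePlannerHeaderSize;
  return view;
}
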
diff --git a/tensorflow/lite/micro/micro_allocator_test.cc b/tensorflow/lite/micro/micro_allocator_test.cc
index 04f4732b9d3..f3f3f32611e 100644
--- a/tensorflow/lite/micro/micro_allocator_test.cc
+++ b/tensorflow/lite/micro/micro_allocator_test.cc
@@ -312,6 +312,8 @@ TF_LITE_MICRO_TEST(OfflinePlannerBranchesAllOnline) {
   int version = 1;
   int subgraph = 0;
   constexpr int nbr_tensors = 4;
+  tflite::testing::MockOpResolver mock_resolver;
+  tflite::NodeAndRegistration* node_and_registration;
   const int32_t metadata_buffer[tflite::testing::kOfflinePlannerHeaderSize +
                                 nbr_tensors] = {version, subgraph, nbr_tensors,  // header
@@ -340,9 +342,14 @@
   TfLiteContext context;
   constexpr size_t arena_size = 4096;
   uint8_t arena[arena_size];
-  tflite::MicroAllocator* allocator = tflite::MicroAllocator::Create(
-      &context, model, arena, arena_size, micro_test::reporter);
-  TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, allocator->FinishTensorAllocation());
+  tflite::MicroAllocator* allocator =
+      tflite::MicroAllocator::Create(arena, arena_size, micro_test::reporter);
+
+  TF_LITE_MICRO_EXPECT_EQ(
+      kTfLiteOk, allocator->StartModelAllocation(model, &context, mock_resolver,
+                                                 &node_and_registration));
+  TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk,
+                          allocator->FinishModelAllocation(model, &context));
 
   // Since all of the tensors are online planned and the model structure is
   // identical to that in TestAllocationForModelsWithBranches,
@@ -357,6 +364,8 @@ TF_LITE_MICRO_TEST(OfflinePlannerBasic) {
   constexpr int nbr_tensors = 4;
+  tflite::testing::MockOpResolver mock_resolver;
+  tflite::NodeAndRegistration* node_and_registration;
   const int32_t metadata_buffer[tflite::testing::kOfflinePlannerHeaderSize +
                                 nbr_tensors] = {1, 0, nbr_tensors,
                                                 0,  // t0
@@ -389,9 +398,14 @@
   TfLiteContext context;
   constexpr size_t arena_size = 4096;
   uint8_t arena[arena_size];
-  tflite::MicroAllocator* allocator = tflite::MicroAllocator::Create(
-      &context, model, arena, arena_size, micro_test::reporter);
-  TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, allocator->FinishTensorAllocation());
+  tflite::MicroAllocator* allocator =
+      tflite::MicroAllocator::Create(arena, arena_size, micro_test::reporter);
+
+  TF_LITE_MICRO_EXPECT_EQ(
+      kTfLiteOk, allocator->StartModelAllocation(model, &context, mock_resolver,
+                                                 &node_and_registration));
+  TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk,
+                          allocator->FinishModelAllocation(model, &context));
 
   uint8_t* start = context.tensors[0].data.uint8;
   TF_LITE_MICRO_EXPECT_EQ(0, context.tensors[0].data.uint8 - start);
@@ -402,6 +416,8 @@ TF_LITE_MICRO_TEST(OfflinePlannerOverlappingAllocation) {
   constexpr int nbr_tensors = 4;
+  tflite::testing::MockOpResolver mock_resolver;
+  tflite::NodeAndRegistration* node_and_registration;
   const int32_t metadata_buffer[tflite::testing::kOfflinePlannerHeaderSize +
                                 nbr_tensors] = {
       1, 0, nbr_tensors,  // header: version, subgraph, nbr tensors
@@ -434,9 +450,14 @@
   TfLiteContext context;
   constexpr size_t arena_size = 4096;
   uint8_t arena[arena_size];
-  tflite::MicroAllocator* allocator = tflite::MicroAllocator::Create(
-      &context, model, arena, arena_size, micro_test::reporter);
-  TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, allocator->FinishTensorAllocation());
+  tflite::MicroAllocator* allocator =
+      tflite::MicroAllocator::Create(arena, arena_size, micro_test::reporter);
+
+  TF_LITE_MICRO_EXPECT_EQ(
+      kTfLiteOk, allocator->StartModelAllocation(model, &context, mock_resolver,
+                                                 &node_and_registration));
+  TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk,
+                          allocator->FinishModelAllocation(model, &context));
 
   uint8_t* start = context.tensors[0].data.uint8;
   TF_LITE_MICRO_EXPECT_EQ(0, context.tensors[0].data.uint8 - start);
@@ -448,6 +469,8 @@ TF_LITE_MICRO_TEST(OfflinePlannerOfflineOnline) {
   constexpr int nbr_tensors = 5;
+  tflite::testing::MockOpResolver mock_resolver;
+  tflite::NodeAndRegistration* node_and_registration;
   const int32_t metadata_buffer[tflite::testing::kOfflinePlannerHeaderSize +
                                 nbr_tensors] = {
       1, 0, nbr_tensors,  // header: version, subgraph, nbr tensors
@@ -482,9 +505,14 @@
   TfLiteContext context;
   constexpr size_t arena_size = 4096;
   uint8_t arena[arena_size];
-  tflite::MicroAllocator* allocator = tflite::MicroAllocator::Create(
-      &context, model, arena, arena_size, micro_test::reporter);
-  TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, allocator->FinishTensorAllocation());
+  tflite::MicroAllocator* allocator =
+      tflite::MicroAllocator::Create(arena, arena_size, micro_test::reporter);
+
+  TF_LITE_MICRO_EXPECT_EQ(
+      kTfLiteOk, allocator->StartModelAllocation(model, &context, mock_resolver,
+                                                 &node_and_registration));
+  TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk,
+                          allocator->FinishModelAllocation(model, &context));
 
   uint8_t* start = context.tensors[0].data.uint8;
   TF_LITE_MICRO_EXPECT_EQ(0, context.tensors[0].data.uint8 - start);
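
Taken together, the test changes exercise the allocator's two-phase calling convention: an arena-only Create() followed by StartModelAllocation() and FinishModelAllocation(), replacing the old per-model Create() plus FinishTensorAllocation(). Below is a condensed usage sketch mirroring the test code above; the resolver type is left generic, error handling is simplified, and production code would normally drive this through MicroInterpreter rather than calling MicroAllocator directly.

#include <cstddef>
#include <cstdint>

#include "tensorflow/lite/micro/micro_allocator.h"

// Sketch of the two-phase allocation flow shown in the tests above.
// `model`, `op_resolver`, and `error_reporter` are assumed to be supplied by
// the caller; OpResolverT stands in for whatever resolver type
// StartModelAllocation() expects (the tests pass a MockOpResolver).
template <typename OpResolverT>
TfLiteStatus AllocateForModel(const tflite::Model* model,
                              const OpResolverT& op_resolver,
                              tflite::ErrorReporter* error_reporter,
                              uint8_t* arena, size_t arena_size,
                              TfLiteContext* context) {
  tflite::NodeAndRegistration* node_and_registration = nullptr;

  tflite::MicroAllocator* allocator =
      tflite::MicroAllocator::Create(arena, arena_size, error_reporter);

  // Phase 1: pull tensor and node/registration bookkeeping out of the
  // flatbuffer. Operator Prepare() calls (and any scratch buffer requests)
  // would happen between the two phases.
  TfLiteStatus status = allocator->StartModelAllocation(
      model, context, op_resolver, &node_and_registration);
  if (status != kTfLiteOk) return status;

  // Phase 2: commit the static memory plan (honoring any offline-planned
  // offsets read from the model metadata) and allocate variable tensors.
  return allocator->FinishModelAllocation(model, context);
}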