diff --git a/tensorflow/lite/arena_planner.cc b/tensorflow/lite/arena_planner.cc index b134a5de044..8b8913d7322 100644 --- a/tensorflow/lite/arena_planner.cc +++ b/tensorflow/lite/arena_planner.cc @@ -323,7 +323,9 @@ TfLiteStatus ArenaPlanner::CalculateAllocations(int first_node, int last_node) { tensor_index, alloc_node_[tensor_index], dealloc_node_[tensor_index], &allocs_[tensor_index])); } - if (tensor.allocation_type == kTfLiteArenaRwPersistent) { + // Check allocs_[].size to prevent from reallocation of persistent tensors. + if (tensor.allocation_type == kTfLiteArenaRwPersistent && + allocs_[tensor_index].size == 0) { TF_LITE_ENSURE_STATUS(persistent_arena_.Allocate( context_, tensor_alignment_, tensor.bytes, tensor_index, /*first_node=*/alloc_node_[tensor_index], diff --git a/tensorflow/lite/arena_planner_test.cc b/tensorflow/lite/arena_planner_test.cc index 47ecc68cf40..ca2127d8333 100644 --- a/tensorflow/lite/arena_planner_test.cc +++ b/tensorflow/lite/arena_planner_test.cc @@ -356,6 +356,40 @@ TEST_F(ArenaPlannerTest, SimpleGraphWithResetAllocationsAfter) { EXPECT_TRUE(IsUnallocated(5)); } +TEST_F(ArenaPlannerTest, SimpleGraphWithPersistentResetAllocationsAfter) { + TestGraph graph({0, 1}, + { + /* in, out, tmp */ + {{0, 1}, {2}, {}}, // First op + {{2, 0}, {4}, {5}}, // Second op, with temporary + {{4}, {3}, {}} // Third op + }, + {3}); + // Make the tensor #5 persistent. + (*graph.tensors())[5].allocation_type = kTfLiteArenaRwPersistent; + SetGraph(&graph); + Execute(0, 10); + + // Save the pointer of the persistent temporary tensor #5. + void* tensor5_ptr = (*graph.tensors())[5].data.raw; + + // Reset allocations after the first node + ResetAllocationsAfter(0); + + EXPECT_FALSE(IsUnallocated(0)); + EXPECT_FALSE(IsUnallocated(1)); + EXPECT_FALSE(IsUnallocated(2)); + EXPECT_TRUE(IsUnallocated(3)); + EXPECT_TRUE(IsUnallocated(4)); + EXPECT_FALSE(IsUnallocated(5)); + + // Second run + Execute(0, 10); + + // Check if the persistent pointer isn't changed. + EXPECT_TRUE(tensor5_ptr == (*graph.tensors())[5].data.raw); +} + TEST_F(ArenaPlannerTest, SimpleGraphWithOptionals) { TestGraph graph({0, -1, 1}, {