Merge pull request #45391 from terryheo/r2.4-fix-tflite-memoryplanner

tflite: Prevent from reallocation of persistent tensors
This commit is contained in:
Mihai Maruseac 2020-12-08 16:59:37 -08:00 committed by GitHub
commit dcc57e7bf7
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 37 additions and 1 deletions

View File

@ -323,7 +323,9 @@ TfLiteStatus ArenaPlanner::CalculateAllocations(int first_node, int last_node) {
tensor_index, alloc_node_[tensor_index],
dealloc_node_[tensor_index], &allocs_[tensor_index]));
}
if (tensor.allocation_type == kTfLiteArenaRwPersistent) {
// Check allocs_[].size to prevent from reallocation of persistent tensors.
if (tensor.allocation_type == kTfLiteArenaRwPersistent &&
allocs_[tensor_index].size == 0) {
TF_LITE_ENSURE_STATUS(persistent_arena_.Allocate(
context_, tensor_alignment_, tensor.bytes, tensor_index,
/*first_node=*/alloc_node_[tensor_index],

View File

@ -356,6 +356,40 @@ TEST_F(ArenaPlannerTest, SimpleGraphWithResetAllocationsAfter) {
EXPECT_TRUE(IsUnallocated(5));
}
TEST_F(ArenaPlannerTest, SimpleGraphWithPersistentResetAllocationsAfter) {
TestGraph graph({0, 1},
{
/* in, out, tmp */
{{0, 1}, {2}, {}}, // First op
{{2, 0}, {4}, {5}}, // Second op, with temporary
{{4}, {3}, {}} // Third op
},
{3});
// Make the tensor #5 persistent.
(*graph.tensors())[5].allocation_type = kTfLiteArenaRwPersistent;
SetGraph(&graph);
Execute(0, 10);
// Save the pointer of the persistent temporary tensor #5.
void* tensor5_ptr = (*graph.tensors())[5].data.raw;
// Reset allocations after the first node
ResetAllocationsAfter(0);
EXPECT_FALSE(IsUnallocated(0));
EXPECT_FALSE(IsUnallocated(1));
EXPECT_FALSE(IsUnallocated(2));
EXPECT_TRUE(IsUnallocated(3));
EXPECT_TRUE(IsUnallocated(4));
EXPECT_FALSE(IsUnallocated(5));
// Second run
Execute(0, 10);
// Check if the persistent pointer isn't changed.
EXPECT_TRUE(tensor5_ptr == (*graph.tensors())[5].data.raw);
}
TEST_F(ArenaPlannerTest, SimpleGraphWithOptionals) {
TestGraph graph({0, -1, 1},
{