tflite: Prevent from reallocation of persistent tensors

After ResetAllocationsAfter() is called, CalculateAllocations() could be called
again for nodes which have persistent temporary tensors. The logic should
prevent from reallocation of these tensors since they're not going to be
initialized again.

This issue could be reproduced easily with hybrid quantized models since some
hybrid kernels are using persistent temporary tensors.

This PR resolves GitHub issue #44520.

PiperOrigin-RevId: 345569003
Change-Id: I1b9777b33a664ebd0f09df8d3236c7ece0118b1a
This commit is contained in:
Terry Heo 2020-12-03 17:19:02 -08:00
parent ab0402bd18
commit 9b87db80e0
2 changed files with 37 additions and 1 deletions

View File

@ -323,7 +323,9 @@ TfLiteStatus ArenaPlanner::CalculateAllocations(int first_node, int last_node) {
tensor_index, alloc_node_[tensor_index],
dealloc_node_[tensor_index], &allocs_[tensor_index]));
}
if (tensor.allocation_type == kTfLiteArenaRwPersistent) {
// Check allocs_[].size to prevent from reallocation of persistent tensors.
if (tensor.allocation_type == kTfLiteArenaRwPersistent &&
allocs_[tensor_index].size == 0) {
TF_LITE_ENSURE_STATUS(persistent_arena_.Allocate(
context_, tensor_alignment_, tensor.bytes, tensor_index,
/*first_node=*/alloc_node_[tensor_index],

View File

@ -356,6 +356,40 @@ TEST_F(ArenaPlannerTest, SimpleGraphWithResetAllocationsAfter) {
EXPECT_TRUE(IsUnallocated(5));
}
TEST_F(ArenaPlannerTest, SimpleGraphWithPersistentResetAllocationsAfter) {
TestGraph graph({0, 1},
{
/* in, out, tmp */
{{0, 1}, {2}, {}}, // First op
{{2, 0}, {4}, {5}}, // Second op, with temporary
{{4}, {3}, {}} // Third op
},
{3});
// Make the tensor #5 persistent.
(*graph.tensors())[5].allocation_type = kTfLiteArenaRwPersistent;
SetGraph(&graph);
Execute(0, 10);
// Save the pointer of the persistent temporary tensor #5.
void* tensor5_ptr = (*graph.tensors())[5].data.raw;
// Reset allocations after the first node
ResetAllocationsAfter(0);
EXPECT_FALSE(IsUnallocated(0));
EXPECT_FALSE(IsUnallocated(1));
EXPECT_FALSE(IsUnallocated(2));
EXPECT_TRUE(IsUnallocated(3));
EXPECT_TRUE(IsUnallocated(4));
EXPECT_FALSE(IsUnallocated(5));
// Second run
Execute(0, 10);
// Check if the persistent pointer isn't changed.
EXPECT_TRUE(tensor5_ptr == (*graph.tensors())[5].data.raw);
}
TEST_F(ArenaPlannerTest, SimpleGraphWithOptionals) {
TestGraph graph({0, -1, 1},
{