Merge pull request #45391 from terryheo/r2.4-fix-tflite-memoryplanner
tflite: Prevent from reallocation of persistent tensors
This commit is contained in:
commit
dcc57e7bf7
@ -323,7 +323,9 @@ TfLiteStatus ArenaPlanner::CalculateAllocations(int first_node, int last_node) {
|
||||
tensor_index, alloc_node_[tensor_index],
|
||||
dealloc_node_[tensor_index], &allocs_[tensor_index]));
|
||||
}
|
||||
if (tensor.allocation_type == kTfLiteArenaRwPersistent) {
|
||||
// Check allocs_[].size to prevent from reallocation of persistent tensors.
|
||||
if (tensor.allocation_type == kTfLiteArenaRwPersistent &&
|
||||
allocs_[tensor_index].size == 0) {
|
||||
TF_LITE_ENSURE_STATUS(persistent_arena_.Allocate(
|
||||
context_, tensor_alignment_, tensor.bytes, tensor_index,
|
||||
/*first_node=*/alloc_node_[tensor_index],
|
||||
|
@ -356,6 +356,40 @@ TEST_F(ArenaPlannerTest, SimpleGraphWithResetAllocationsAfter) {
|
||||
EXPECT_TRUE(IsUnallocated(5));
|
||||
}
|
||||
|
||||
TEST_F(ArenaPlannerTest, SimpleGraphWithPersistentResetAllocationsAfter) {
|
||||
TestGraph graph({0, 1},
|
||||
{
|
||||
/* in, out, tmp */
|
||||
{{0, 1}, {2}, {}}, // First op
|
||||
{{2, 0}, {4}, {5}}, // Second op, with temporary
|
||||
{{4}, {3}, {}} // Third op
|
||||
},
|
||||
{3});
|
||||
// Make the tensor #5 persistent.
|
||||
(*graph.tensors())[5].allocation_type = kTfLiteArenaRwPersistent;
|
||||
SetGraph(&graph);
|
||||
Execute(0, 10);
|
||||
|
||||
// Save the pointer of the persistent temporary tensor #5.
|
||||
void* tensor5_ptr = (*graph.tensors())[5].data.raw;
|
||||
|
||||
// Reset allocations after the first node
|
||||
ResetAllocationsAfter(0);
|
||||
|
||||
EXPECT_FALSE(IsUnallocated(0));
|
||||
EXPECT_FALSE(IsUnallocated(1));
|
||||
EXPECT_FALSE(IsUnallocated(2));
|
||||
EXPECT_TRUE(IsUnallocated(3));
|
||||
EXPECT_TRUE(IsUnallocated(4));
|
||||
EXPECT_FALSE(IsUnallocated(5));
|
||||
|
||||
// Second run
|
||||
Execute(0, 10);
|
||||
|
||||
// Check if the persistent pointer isn't changed.
|
||||
EXPECT_TRUE(tensor5_ptr == (*graph.tensors())[5].data.raw);
|
||||
}
|
||||
|
||||
TEST_F(ArenaPlannerTest, SimpleGraphWithOptionals) {
|
||||
TestGraph graph({0, -1, 1},
|
||||
{
|
||||
|
Loading…
Reference in New Issue
Block a user