Provide a way to keep temporary tensors
Keep temporary tensors when 'preserve_intermediates_' is set. PiperOrigin-RevId: 346682289 Change-Id: I75dd0583466ec834a3a1eecaee38cb7ee668e28f
This commit is contained in:
parent
3ce8f8f39c
commit
a217fa5273
@ -205,7 +205,9 @@ TfLiteStatus ArenaPlanner::ExecuteAllocations(int first_node, int last_node) {
|
|||||||
for (int j = 0; j < node_temporaries->size; ++j) {
|
for (int j = 0; j < node_temporaries->size; ++j) {
|
||||||
int tensor_index = node_temporaries->data[j];
|
int tensor_index = node_temporaries->data[j];
|
||||||
alloc_node_[tensor_index] = i;
|
alloc_node_[tensor_index] = i;
|
||||||
dealloc_node_[tensor_index] = i;
|
if (!preserve_intermediates_) {
|
||||||
|
dealloc_node_[tensor_index] = i;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -164,12 +164,13 @@ void ReportError(TfLiteContext* context, const char* format, ...) {
|
|||||||
|
|
||||||
class ArenaPlannerTest : public ::testing::Test {
|
class ArenaPlannerTest : public ::testing::Test {
|
||||||
protected:
|
protected:
|
||||||
void SetGraph(TestGraph* graph, bool preserve_inputs = false) {
|
void SetGraph(TestGraph* graph, bool preserve_inputs = false,
|
||||||
|
bool preserve_intermediates = false) {
|
||||||
graph_ = graph;
|
graph_ = graph;
|
||||||
context_.ReportError = ReportError;
|
context_.ReportError = ReportError;
|
||||||
planner_.reset(new ArenaPlanner(
|
planner_.reset(new ArenaPlanner(
|
||||||
&context_, std::unique_ptr<GraphInfo>(new TestGraphInfo(graph)),
|
&context_, std::unique_ptr<GraphInfo>(new TestGraphInfo(graph)),
|
||||||
preserve_inputs, /*preserve intermediates*/ false, kTensorAlignment));
|
preserve_inputs, preserve_intermediates, kTensorAlignment));
|
||||||
CHECK(planner_->ResetAllocations() == kTfLiteOk);
|
CHECK(planner_->ResetAllocations() == kTfLiteOk);
|
||||||
CHECK(planner_->PlanAllocations() == kTfLiteOk);
|
CHECK(planner_->PlanAllocations() == kTfLiteOk);
|
||||||
}
|
}
|
||||||
@ -745,6 +746,35 @@ TEST_F(ArenaPlannerTest, GraphWithIntermediates) {
|
|||||||
EXPECT_EQ(GetOffset(2), GetOffsetAfter(5));
|
EXPECT_EQ(GetOffset(2), GetOffsetAfter(5));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST_F(ArenaPlannerTest, DebugTensors) {
|
||||||
|
TestGraph graph({0, 1},
|
||||||
|
{
|
||||||
|
/* in, out, tmp */
|
||||||
|
{{0, 1}, {2}, {5}}, // First op, with temporary
|
||||||
|
{{2, 0}, {4}, {6}}, // Second op, with temporary
|
||||||
|
{{4}, {3}, {7}} // Third op, with temporary
|
||||||
|
},
|
||||||
|
{3});
|
||||||
|
SetGraph(&graph, false, /*preserve_intermediates=*/false);
|
||||||
|
Execute(0, 10);
|
||||||
|
|
||||||
|
// Memory of temporary tensors are shared by default.
|
||||||
|
EXPECT_EQ(GetOffset(5), 0);
|
||||||
|
EXPECT_EQ(GetOffset(6), 0);
|
||||||
|
EXPECT_EQ(GetOffset(7), 0);
|
||||||
|
|
||||||
|
SetGraph(&graph, false, /*preserve_intermediates=*/true);
|
||||||
|
Execute(0, 10);
|
||||||
|
|
||||||
|
std::set<std::ptrdiff_t> tensorOffsets;
|
||||||
|
for (int i = 0; i < 8; i++) {
|
||||||
|
tensorOffsets.insert(GetOffset(i));
|
||||||
|
}
|
||||||
|
// Every tensor should have unique memory allocation with
|
||||||
|
// preserve_intermediates.
|
||||||
|
EXPECT_EQ(tensorOffsets.size(), 8);
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
} // namespace tflite
|
} // namespace tflite
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user