Provide a way to keep temporary tensors

Keep temporary tensors when 'preserve_intermediates_' is set.

PiperOrigin-RevId: 346682289
Change-Id: I75dd0583466ec834a3a1eecaee38cb7ee668e28f
This commit is contained in:
Terry Heo 2020-12-09 18:23:00 -08:00 committed by TensorFlower Gardener
parent 3ce8f8f39c
commit a217fa5273
2 changed files with 35 additions and 3 deletions

View File

@ -205,7 +205,9 @@ TfLiteStatus ArenaPlanner::ExecuteAllocations(int first_node, int last_node) {
for (int j = 0; j < node_temporaries->size; ++j) {
int tensor_index = node_temporaries->data[j];
alloc_node_[tensor_index] = i;
dealloc_node_[tensor_index] = i;
if (!preserve_intermediates_) {
dealloc_node_[tensor_index] = i;
}
}
}

View File

@ -164,12 +164,13 @@ void ReportError(TfLiteContext* context, const char* format, ...) {
class ArenaPlannerTest : public ::testing::Test {
protected:
void SetGraph(TestGraph* graph, bool preserve_inputs = false) {
void SetGraph(TestGraph* graph, bool preserve_inputs = false,
bool preserve_intermediates = false) {
graph_ = graph;
context_.ReportError = ReportError;
planner_.reset(new ArenaPlanner(
&context_, std::unique_ptr<GraphInfo>(new TestGraphInfo(graph)),
preserve_inputs, /*preserve intermediates*/ false, kTensorAlignment));
preserve_inputs, preserve_intermediates, kTensorAlignment));
CHECK(planner_->ResetAllocations() == kTfLiteOk);
CHECK(planner_->PlanAllocations() == kTfLiteOk);
}
@ -745,6 +746,35 @@ TEST_F(ArenaPlannerTest, GraphWithIntermediates) {
EXPECT_EQ(GetOffset(2), GetOffsetAfter(5));
}
TEST_F(ArenaPlannerTest, DebugTensors) {
TestGraph graph({0, 1},
{
/* in, out, tmp */
{{0, 1}, {2}, {5}}, // First op, with temporary
{{2, 0}, {4}, {6}}, // Second op, with temporary
{{4}, {3}, {7}} // Third op, with temporary
},
{3});
SetGraph(&graph, false, /*preserve_intermediates=*/false);
Execute(0, 10);
// Memory of temporary tensors are shared by default.
EXPECT_EQ(GetOffset(5), 0);
EXPECT_EQ(GetOffset(6), 0);
EXPECT_EQ(GetOffset(7), 0);
SetGraph(&graph, false, /*preserve_intermediates=*/true);
Execute(0, 10);
std::set<std::ptrdiff_t> tensorOffsets;
for (int i = 0; i < 8; i++) {
tensorOffsets.insert(GetOffset(i));
}
// Every tensor should have unique memory allocation with
// preserve_intermediates.
EXPECT_EQ(tensorOffsets.size(), 8);
}
} // namespace
} // namespace tflite