Provide a way to keep temporary tensors
Keep temporary tensors when 'preserve_intermediates_' is set. PiperOrigin-RevId: 346682289 Change-Id: I75dd0583466ec834a3a1eecaee38cb7ee668e28f
This commit is contained in:
parent
3ce8f8f39c
commit
a217fa5273
tensorflow/lite
@ -205,7 +205,9 @@ TfLiteStatus ArenaPlanner::ExecuteAllocations(int first_node, int last_node) {
|
||||
for (int j = 0; j < node_temporaries->size; ++j) {
|
||||
int tensor_index = node_temporaries->data[j];
|
||||
alloc_node_[tensor_index] = i;
|
||||
dealloc_node_[tensor_index] = i;
|
||||
if (!preserve_intermediates_) {
|
||||
dealloc_node_[tensor_index] = i;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -164,12 +164,13 @@ void ReportError(TfLiteContext* context, const char* format, ...) {
|
||||
|
||||
class ArenaPlannerTest : public ::testing::Test {
|
||||
protected:
|
||||
void SetGraph(TestGraph* graph, bool preserve_inputs = false) {
|
||||
void SetGraph(TestGraph* graph, bool preserve_inputs = false,
|
||||
bool preserve_intermediates = false) {
|
||||
graph_ = graph;
|
||||
context_.ReportError = ReportError;
|
||||
planner_.reset(new ArenaPlanner(
|
||||
&context_, std::unique_ptr<GraphInfo>(new TestGraphInfo(graph)),
|
||||
preserve_inputs, /*preserve intermediates*/ false, kTensorAlignment));
|
||||
preserve_inputs, preserve_intermediates, kTensorAlignment));
|
||||
CHECK(planner_->ResetAllocations() == kTfLiteOk);
|
||||
CHECK(planner_->PlanAllocations() == kTfLiteOk);
|
||||
}
|
||||
@ -745,6 +746,35 @@ TEST_F(ArenaPlannerTest, GraphWithIntermediates) {
|
||||
EXPECT_EQ(GetOffset(2), GetOffsetAfter(5));
|
||||
}
|
||||
|
||||
TEST_F(ArenaPlannerTest, DebugTensors) {
|
||||
TestGraph graph({0, 1},
|
||||
{
|
||||
/* in, out, tmp */
|
||||
{{0, 1}, {2}, {5}}, // First op, with temporary
|
||||
{{2, 0}, {4}, {6}}, // Second op, with temporary
|
||||
{{4}, {3}, {7}} // Third op, with temporary
|
||||
},
|
||||
{3});
|
||||
SetGraph(&graph, false, /*preserve_intermediates=*/false);
|
||||
Execute(0, 10);
|
||||
|
||||
// Memory of temporary tensors are shared by default.
|
||||
EXPECT_EQ(GetOffset(5), 0);
|
||||
EXPECT_EQ(GetOffset(6), 0);
|
||||
EXPECT_EQ(GetOffset(7), 0);
|
||||
|
||||
SetGraph(&graph, false, /*preserve_intermediates=*/true);
|
||||
Execute(0, 10);
|
||||
|
||||
std::set<std::ptrdiff_t> tensorOffsets;
|
||||
for (int i = 0; i < 8; i++) {
|
||||
tensorOffsets.insert(GetOffset(i));
|
||||
}
|
||||
// Every tensor should have unique memory allocation with
|
||||
// preserve_intermediates.
|
||||
EXPECT_EQ(tensorOffsets.size(), 8);
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace tflite
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user