From 17521bcffd230abe7f7f5ceec6279b9a5504529a Mon Sep 17 00:00:00 2001 From: Yu-Cheng Ling Date: Wed, 10 Jul 2019 12:37:54 -0700 Subject: [PATCH] Fix invoking while multiple times when body is dynamic. PiperOrigin-RevId: 257459721 --- tensorflow/lite/kernels/while.cc | 55 +++++++++++++++++++-------- tensorflow/lite/kernels/while_test.cc | 9 +++-- 2 files changed, 45 insertions(+), 19 deletions(-) diff --git a/tensorflow/lite/kernels/while.cc b/tensorflow/lite/kernels/while.cc index b3f00d3fe13..a6438558458 100644 --- a/tensorflow/lite/kernels/while.cc +++ b/tensorflow/lite/kernels/while.cc @@ -28,21 +28,36 @@ namespace { // Propagate tensor shapes and types from `src_tensor_indices` in `src_subgraph` // to `dst_tensor_indices` in `dst_subgraph`. +// +// When `resize_subgraph_inputs` is true, the function calls subgraphs's +// `ResizeInputTensor` function, and it may trigger the memory planner to +// reallocate memory. +// When `resize_subgraph_inputs` is false, it implies `context` belongs to +// `dst_subgraph`. The function calls `context->ResizeTensor`. This happens +// when resizing `While` op's outputs. template TfLiteStatus CopyTensorsShapeAndType(TfLiteContext* context, Subgraph* src_subgraph, const SrcVector& src_tensor_indices, Subgraph* dst_subgraph, - const DstVector& dst_tensor_indices) { + const DstVector& dst_tensor_indices, + bool resize_subgraph_inputs) { TF_LITE_ENSURE_EQ(context, src_tensor_indices.size(), dst_tensor_indices.size()); for (int i = 0; i < src_tensor_indices.size(); ++i) { const TfLiteTensor* src_tensor = src_subgraph->tensor(src_tensor_indices[i]); - std::vector dims(src_tensor->dims->data, - src_tensor->dims->data + src_tensor->dims->size); - dst_subgraph->ResizeInputTensor(dst_tensor_indices[i], dims); + TfLiteTensor* dst_tensor = dst_subgraph->tensor(dst_tensor_indices[i]); + if (resize_subgraph_inputs) { + std::vector dims(src_tensor->dims->data, + src_tensor->dims->data + src_tensor->dims->size); + dst_subgraph->ResizeInputTensor(dst_tensor_indices[i], dims); + } else { + TF_LITE_ENSURE_OK( + context, context->ResizeTensor(context, dst_tensor, + TfLiteIntArrayCopy(src_tensor->dims))); + } dst_tensor->type = src_tensor->type; } return kTfLiteOk; @@ -130,9 +145,9 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { // Prepare and check the condition subgraph. TF_LITE_ENSURE_OK( - context, CopyTensorsShapeAndType(context, this_subgraph, - TfLiteIntArrayView(node->inputs), - cond_subgraph, cond_subgraph->inputs())); + context, CopyTensorsShapeAndType( + context, this_subgraph, TfLiteIntArrayView(node->inputs), + cond_subgraph, cond_subgraph->inputs(), true)); TF_LITE_ENSURE_OK(context, cond_subgraph->AllocateTensors()); TfLiteTensor* cond_output = cond_subgraph->tensor(cond_subgraph->outputs()[0]); @@ -148,9 +163,9 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { // Prepare and check the body subgraph. TF_LITE_ENSURE_OK( - context, CopyTensorsShapeAndType(context, this_subgraph, - TfLiteIntArrayView(node->inputs), - body_subgraph, body_subgraph->inputs())); + context, CopyTensorsShapeAndType( + context, this_subgraph, TfLiteIntArrayView(node->inputs), + body_subgraph, body_subgraph->inputs(), true)); TF_LITE_ENSURE_OK(context, body_subgraph->AllocateTensors()); if (body_subgraph->HasDynamicTensors()) { op_data->body_has_dynamic_output_tensors = true; @@ -232,6 +247,16 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { // boundary. Currently we copy the input / output between the subgraphs. This // isn't optimized yet and a lot of redundant copies are made. // TODO(b/120234921): Optimize and avoid copying tensors between subgraphs. + + if (op_data->body_has_dynamic_output_tensors) { + // If body subgraph has dynamic outputs, the input of condition subgraph may + // be changed in the last invocation and may need resizing. + TF_LITE_ENSURE_OK( + context, CopyTensorsShapeAndType( + context, this_subgraph, TfLiteIntArrayView(node->inputs), + cond_subgraph, cond_subgraph->inputs(), true)); + TF_LITE_ENSURE_OK(context, cond_subgraph->AllocateTensors()); + } TF_LITE_ENSURE_OK( context, CopyTensorsData(context, this_subgraph, TfLiteIntArrayView(node->inputs), @@ -254,7 +279,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { TF_LITE_ENSURE_OK(context, CopyTensorsShapeAndType( context, cond_subgraph, cond_subgraph->inputs(), - body_subgraph, body_subgraph->inputs())); + body_subgraph, body_subgraph->inputs(), true)); TF_LITE_ENSURE_OK(context, body_subgraph->AllocateTensors()); } @@ -273,7 +298,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { TF_LITE_ENSURE_OK(context, CopyTensorsShapeAndType( context, body_subgraph, body_subgraph->outputs(), - cond_subgraph, cond_subgraph->inputs())); + cond_subgraph, cond_subgraph->inputs(), true)); TF_LITE_ENSURE_OK(context, cond_subgraph->AllocateTensors()); } @@ -287,9 +312,9 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { // TODO(b/120234921): Optimize and avoid copying tensors between subgraphs. if (op_data->body_has_dynamic_output_tensors) { TF_LITE_ENSURE_OK( - context, CopyTensorsShapeAndType(context, cond_subgraph, - cond_subgraph->inputs(), this_subgraph, - TfLiteIntArrayView(node->outputs))); + context, CopyTensorsShapeAndType( + context, cond_subgraph, cond_subgraph->inputs(), + this_subgraph, TfLiteIntArrayView(node->outputs), false)); } TF_LITE_ENSURE_OK( diff --git a/tensorflow/lite/kernels/while_test.cc b/tensorflow/lite/kernels/while_test.cc index a3a80ea6f50..1745f585ed0 100644 --- a/tensorflow/lite/kernels/while_test.cc +++ b/tensorflow/lite/kernels/while_test.cc @@ -59,8 +59,6 @@ TEST_F(WhileTest, TestTriangularNumberSequence) { } } -// This requires dynamic sized subgraphs and it's not supported right now. -// TODO(ycling): Support dynamic sized subgraphs. TEST_F(WhileTest, TestPadLoop) { interpreter_.reset(new Interpreter); interpreter_->AddSubgraphs(2); @@ -70,8 +68,6 @@ TEST_F(WhileTest, TestPadLoop) { interpreter_->ResizeInputTensor(interpreter_->inputs()[0], {1}); interpreter_->ResizeInputTensor(interpreter_->inputs()[1], {2}); - // This is not supported yet. The test ensures thatit doesn't crash and raises - // an error properly. ASSERT_EQ(interpreter_->AllocateTensors(), kTfLiteOk); FillIntTensor(interpreter_->tensor(interpreter_->inputs()[0]), {1}); @@ -82,6 +78,11 @@ TEST_F(WhileTest, TestPadLoop) { CheckIntTensor(output1, {1}, {4}); TfLiteTensor* output2 = interpreter_->tensor(interpreter_->outputs()[1]); CheckIntTensor(output2, {11}, {0, 0, 0, 5, 7, 0, 0, 0, 0, 0, 0}); + + // The extra invocation serves as a regiression test: There was a bug that + // invoking a while loop with dynamic shaped body makes the interpreter + // state uninvokable. + ASSERT_EQ(interpreter_->Invoke(), kTfLiteOk); } } // namespace