diff --git a/tensorflow/lite/BUILD b/tensorflow/lite/BUILD index bc37c645fb7..e9ca062dc3c 100644 --- a/tensorflow/lite/BUILD +++ b/tensorflow/lite/BUILD @@ -639,6 +639,7 @@ cc_test( "testdata/test_min_runtime.bin", "testdata/test_model.bin", "testdata/test_model_broken.bin", + "testdata/while_op_with_forwarding_input.bin", ], tags = [ "tflite_not_portable", @@ -647,6 +648,7 @@ cc_test( deps = [ ":framework", ":interpreter_test_util", + "//tensorflow/lite:string_util", "//tensorflow/lite/core/api", "//tensorflow/lite/kernels:builtin_ops", "//tensorflow/lite/testing:util", diff --git a/tensorflow/lite/core/subgraph.cc b/tensorflow/lite/core/subgraph.cc index 1219e9709da..7e031472277 100644 --- a/tensorflow/lite/core/subgraph.cc +++ b/tensorflow/lite/core/subgraph.cc @@ -886,7 +886,10 @@ TfLiteStatus Subgraph::PrepareOpsStartingAt( int first_execution_plan_index, const std::vector<int>& execution_plan, int* last_execution_plan_index_prepared) { if (first_execution_plan_index == 0) { - has_dynamic_tensors_ = false; + // Forwarding inputs without modification won't be not evaluated in the + // operators. So, it needs to look up the subgraph's output tensors at the + // beginning. + has_dynamic_tensors_ = HasDynamicTensorImpl(context_, outputs()); } for (int execution_plan_index = first_execution_plan_index; execution_plan_index < execution_plan.size(); execution_plan_index++) { diff --git a/tensorflow/lite/model_test.cc b/tensorflow/lite/model_test.cc index a0115d0d630..c419ad218ff 100644 --- a/tensorflow/lite/model_test.cc +++ b/tensorflow/lite/model_test.cc @@ -26,6 +26,7 @@ limitations under the License. #include "tensorflow/lite/core/api/error_reporter.h" #include "tensorflow/lite/interpreter_test_util.h" #include "tensorflow/lite/kernels/register.h" +#include "tensorflow/lite/string_util.h" #include "tensorflow/lite/testing/util.h" // Comparison for TfLiteRegistration. Since TfLiteRegistration is a C object, @@ -550,6 +551,36 @@ TEST(TestAddDelegateOwnership, AddDelegateDoesNotTakeOwnership) { EXPECT_TRUE(destroyed); } +// The model contains a while loop with a forwarding string input. This test +// makes sure that the dynamic tensor existence in the while subgraph's outputs +// is detected. If not, the while loop will be failed at handling the dynamic +// tensor handling as a static tensor. +TEST(BasicFlatBufferModel, TestHandleModelWithWhileOpContainsForwardingInput) { + const auto model_path = + "tensorflow/lite/testdata/while_op_with_forwarding_input.bin"; + + std::unique_ptr<tflite::FlatBufferModel> model = + FlatBufferModel::BuildFromFile(model_path); + ASSERT_NE(model, nullptr); + + tflite::ops::builtin::BuiltinOpResolver resolver; + InterpreterBuilder builder(*model, resolver); + std::unique_ptr<Interpreter> interpreter; + ASSERT_EQ(builder(&interpreter), kTfLiteOk); + ASSERT_NE(interpreter, nullptr); + ASSERT_EQ(interpreter->AllocateTensors(), kTfLiteOk); + + int32_t* tensor_data = interpreter->typed_tensor<int32_t>(0); + tensor_data[0] = 20; + + auto tensor = interpreter->tensor(1); + DynamicBuffer buf; + buf.AddString("a", 1); + buf.WriteToTensor(tensor, /*new_shape=*/nullptr); + + ASSERT_EQ(interpreter->Invoke(), kTfLiteOk); +} + // TODO(aselle): Add tests for serialization of builtin op data types. // These tests will occur with the evaluation tests of individual operators, // not here. diff --git a/tensorflow/lite/testdata/while_op_with_forwarding_input.bin b/tensorflow/lite/testdata/while_op_with_forwarding_input.bin new file mode 100644 index 00000000000..e0eb9914001 Binary files /dev/null and b/tensorflow/lite/testdata/while_op_with_forwarding_input.bin differ