Check the dynamic tensor existence of the subgraph outputs at the beginning

Forwarding inputs without modification won't be not evaluated in the
subgraph's operators. So, it needs to look up the subgraph's output tensors at
the beginning.

PiperOrigin-RevId: 352703717
Change-Id: If299e706cf5c9d4c722e0f85751e7d085a0cd4c3
This commit is contained in:
Jaesung Chung 2021-01-19 19:26:28 -08:00 committed by TensorFlower Gardener
parent 87933916c3
commit 2d40649f53
4 changed files with 37 additions and 1 deletions

View File

@ -639,6 +639,7 @@ cc_test(
"testdata/test_min_runtime.bin",
"testdata/test_model.bin",
"testdata/test_model_broken.bin",
"testdata/while_op_with_forwarding_input.bin",
],
tags = [
"tflite_not_portable",
@ -647,6 +648,7 @@ cc_test(
deps = [
":framework",
":interpreter_test_util",
"//tensorflow/lite:string_util",
"//tensorflow/lite/core/api",
"//tensorflow/lite/kernels:builtin_ops",
"//tensorflow/lite/testing:util",

View File

@ -886,7 +886,10 @@ TfLiteStatus Subgraph::PrepareOpsStartingAt(
int first_execution_plan_index, const std::vector<int>& execution_plan,
int* last_execution_plan_index_prepared) {
if (first_execution_plan_index == 0) {
has_dynamic_tensors_ = false;
// Forwarding inputs without modification won't be not evaluated in the
// operators. So, it needs to look up the subgraph's output tensors at the
// beginning.
has_dynamic_tensors_ = HasDynamicTensorImpl(context_, outputs());
}
for (int execution_plan_index = first_execution_plan_index;
execution_plan_index < execution_plan.size(); execution_plan_index++) {

View File

@ -26,6 +26,7 @@ limitations under the License.
#include "tensorflow/lite/core/api/error_reporter.h"
#include "tensorflow/lite/interpreter_test_util.h"
#include "tensorflow/lite/kernels/register.h"
#include "tensorflow/lite/string_util.h"
#include "tensorflow/lite/testing/util.h"
// Comparison for TfLiteRegistration. Since TfLiteRegistration is a C object,
@ -550,6 +551,36 @@ TEST(TestAddDelegateOwnership, AddDelegateDoesNotTakeOwnership) {
EXPECT_TRUE(destroyed);
}
// The model contains a while loop with a forwarding string input. This test
// makes sure that the dynamic tensor existence in the while subgraph's outputs
// is detected. If not, the while loop will be failed at handling the dynamic
// tensor handling as a static tensor.
TEST(BasicFlatBufferModel, TestHandleModelWithWhileOpContainsForwardingInput) {
const auto model_path =
"tensorflow/lite/testdata/while_op_with_forwarding_input.bin";
std::unique_ptr<tflite::FlatBufferModel> model =
FlatBufferModel::BuildFromFile(model_path);
ASSERT_NE(model, nullptr);
tflite::ops::builtin::BuiltinOpResolver resolver;
InterpreterBuilder builder(*model, resolver);
std::unique_ptr<Interpreter> interpreter;
ASSERT_EQ(builder(&interpreter), kTfLiteOk);
ASSERT_NE(interpreter, nullptr);
ASSERT_EQ(interpreter->AllocateTensors(), kTfLiteOk);
int32_t* tensor_data = interpreter->typed_tensor<int32_t>(0);
tensor_data[0] = 20;
auto tensor = interpreter->tensor(1);
DynamicBuffer buf;
buf.AddString("a", 1);
buf.WriteToTensor(tensor, /*new_shape=*/nullptr);
ASSERT_EQ(interpreter->Invoke(), kTfLiteOk);
}
// TODO(aselle): Add tests for serialization of builtin op data types.
// These tests will occur with the evaluation tests of individual operators,
// not here.

Binary file not shown.