Check the dynamic tensor existence of the subgraph outputs at the beginning
Forwarding inputs without modification won't be not evaluated in the subgraph's operators. So, it needs to look up the subgraph's output tensors at the beginning. PiperOrigin-RevId: 352703717 Change-Id: If299e706cf5c9d4c722e0f85751e7d085a0cd4c3
This commit is contained in:
parent
87933916c3
commit
2d40649f53
@ -639,6 +639,7 @@ cc_test(
|
||||
"testdata/test_min_runtime.bin",
|
||||
"testdata/test_model.bin",
|
||||
"testdata/test_model_broken.bin",
|
||||
"testdata/while_op_with_forwarding_input.bin",
|
||||
],
|
||||
tags = [
|
||||
"tflite_not_portable",
|
||||
@ -647,6 +648,7 @@ cc_test(
|
||||
deps = [
|
||||
":framework",
|
||||
":interpreter_test_util",
|
||||
"//tensorflow/lite:string_util",
|
||||
"//tensorflow/lite/core/api",
|
||||
"//tensorflow/lite/kernels:builtin_ops",
|
||||
"//tensorflow/lite/testing:util",
|
||||
|
@ -886,7 +886,10 @@ TfLiteStatus Subgraph::PrepareOpsStartingAt(
|
||||
int first_execution_plan_index, const std::vector<int>& execution_plan,
|
||||
int* last_execution_plan_index_prepared) {
|
||||
if (first_execution_plan_index == 0) {
|
||||
has_dynamic_tensors_ = false;
|
||||
// Forwarding inputs without modification won't be not evaluated in the
|
||||
// operators. So, it needs to look up the subgraph's output tensors at the
|
||||
// beginning.
|
||||
has_dynamic_tensors_ = HasDynamicTensorImpl(context_, outputs());
|
||||
}
|
||||
for (int execution_plan_index = first_execution_plan_index;
|
||||
execution_plan_index < execution_plan.size(); execution_plan_index++) {
|
||||
|
@ -26,6 +26,7 @@ limitations under the License.
|
||||
#include "tensorflow/lite/core/api/error_reporter.h"
|
||||
#include "tensorflow/lite/interpreter_test_util.h"
|
||||
#include "tensorflow/lite/kernels/register.h"
|
||||
#include "tensorflow/lite/string_util.h"
|
||||
#include "tensorflow/lite/testing/util.h"
|
||||
|
||||
// Comparison for TfLiteRegistration. Since TfLiteRegistration is a C object,
|
||||
@ -550,6 +551,36 @@ TEST(TestAddDelegateOwnership, AddDelegateDoesNotTakeOwnership) {
|
||||
EXPECT_TRUE(destroyed);
|
||||
}
|
||||
|
||||
// The model contains a while loop with a forwarding string input. This test
|
||||
// makes sure that the dynamic tensor existence in the while subgraph's outputs
|
||||
// is detected. If not, the while loop will be failed at handling the dynamic
|
||||
// tensor handling as a static tensor.
|
||||
TEST(BasicFlatBufferModel, TestHandleModelWithWhileOpContainsForwardingInput) {
|
||||
const auto model_path =
|
||||
"tensorflow/lite/testdata/while_op_with_forwarding_input.bin";
|
||||
|
||||
std::unique_ptr<tflite::FlatBufferModel> model =
|
||||
FlatBufferModel::BuildFromFile(model_path);
|
||||
ASSERT_NE(model, nullptr);
|
||||
|
||||
tflite::ops::builtin::BuiltinOpResolver resolver;
|
||||
InterpreterBuilder builder(*model, resolver);
|
||||
std::unique_ptr<Interpreter> interpreter;
|
||||
ASSERT_EQ(builder(&interpreter), kTfLiteOk);
|
||||
ASSERT_NE(interpreter, nullptr);
|
||||
ASSERT_EQ(interpreter->AllocateTensors(), kTfLiteOk);
|
||||
|
||||
int32_t* tensor_data = interpreter->typed_tensor<int32_t>(0);
|
||||
tensor_data[0] = 20;
|
||||
|
||||
auto tensor = interpreter->tensor(1);
|
||||
DynamicBuffer buf;
|
||||
buf.AddString("a", 1);
|
||||
buf.WriteToTensor(tensor, /*new_shape=*/nullptr);
|
||||
|
||||
ASSERT_EQ(interpreter->Invoke(), kTfLiteOk);
|
||||
}
|
||||
|
||||
// TODO(aselle): Add tests for serialization of builtin op data types.
|
||||
// These tests will occur with the evaluation tests of individual operators,
|
||||
// not here.
|
||||
|
BIN
tensorflow/lite/testdata/while_op_with_forwarding_input.bin
vendored
Normal file
BIN
tensorflow/lite/testdata/while_op_with_forwarding_input.bin
vendored
Normal file
Binary file not shown.
Loading…
Reference in New Issue
Block a user