From f12082d7af509e9549d8e8fb2b514ccd0db0e84e Mon Sep 17 00:00:00 2001 From: Jaesung Chung Date: Thu, 14 Jan 2021 18:43:36 -0800 Subject: [PATCH] Support old TOCO generated reshape operator cases where the dimension of the given shape input is non 1-D In such cases, the TOCO converted models always have the correct value at the new_shape attribute so we need to copy the new_shape attribute in order to correctly rewrite. This change closes #46423 PiperOrigin-RevId: 351920426 Change-Id: I4f7bea1ad6f59485d4011cf2a52c7bce989c52d7 --- .../serialization/option_writer_generator.cc | 7 ++-- .../tools/serialization/writer_lib_test.cc | 33 ++++++++++++++----- 2 files changed, 30 insertions(+), 10 deletions(-) diff --git a/tensorflow/lite/tools/serialization/option_writer_generator.cc b/tensorflow/lite/tools/serialization/option_writer_generator.cc index bd26ec6ed6e..ee10d69b4c8 100644 --- a/tensorflow/lite/tools/serialization/option_writer_generator.cc +++ b/tensorflow/lite/tools/serialization/option_writer_generator.cc @@ -274,14 +274,17 @@ void GenerateImportForResizeBilinearOp(FILE* fp) { // Reshape Op infers output shape either from Parameter or from shape tensor // that's is an additional input. When we have this additional shape tensor as // input we don't have the parameter present in this layer. In case of more than -// one input we import an empty vector for the parameters. +// one input and the shape parameter does not have a valid value, we import an +// empty vector for the parameters. void GenerateImportForReshapeOp(FILE* fp) { fprintf(fp, " case BuiltinOperator_RESHAPE: {\n" " const auto* params = reinterpret_cast(builtin_op_data);\n" " flatbuffers::Offset union_type;\n" - " if (node.inputs->size > 1) {\n" + " if (node.inputs->size > 1 && (params->num_dimensions <= 0 || " + "params->num_dimensions > TFLITE_RESHAPE_PARAMS_MAX_DIMENSION_COUNT))" + " {\n" " union_type = CreateReshapeOptions(*fbb).Union();\n" " } else {\n" " auto val0 = fbb->CreateVector(std::vector(params->shape, " diff --git a/tensorflow/lite/tools/serialization/writer_lib_test.cc b/tensorflow/lite/tools/serialization/writer_lib_test.cc index 395c40c2e48..c15277534d7 100644 --- a/tensorflow/lite/tools/serialization/writer_lib_test.cc +++ b/tensorflow/lite/tools/serialization/writer_lib_test.cc @@ -311,6 +311,7 @@ INSTANTIATE_TEST_SUITE_P(Writer, SingleSubgraphTest, ::testing::Bool()); struct ReshapeTestPattern { int num_inputs; bool is_param_valid; + bool has_buggy_non_flatten_shape; }; class ReshapeLayerTest : public ::testing::TestWithParam {}; @@ -326,10 +327,19 @@ TEST_P(ReshapeLayerTest, ReshapeLayerTest) { TfLiteQuantization()); ASSERT_LE(param.num_inputs, 2); if (param.num_inputs == 2) { - interpreter.SetTensorParametersReadOnly( - /*tensor_index=*/1, kTfLiteInt32, /*name=*/"b", /*dims=*/{3}, - TfLiteQuantization(), reinterpret_cast(output_shape), - sizeof(output_shape)); + // Some TOCO generated models have buggy shape arguments, which are required + // to be flatten, for example, dims={3, 1} instead of dims={3}. + if (param.has_buggy_non_flatten_shape) { + interpreter.SetTensorParametersReadOnly( + /*tensor_index=*/1, kTfLiteInt32, /*name=*/"b", /*dims=*/{3, 1}, + TfLiteQuantization(), reinterpret_cast(output_shape), + sizeof(output_shape)); + } else { + interpreter.SetTensorParametersReadOnly( + /*tensor_index=*/1, kTfLiteInt32, /*name=*/"b", /*dims=*/{3}, + TfLiteQuantization(), reinterpret_cast(output_shape), + sizeof(output_shape)); + } } interpreter.SetTensorParametersReadWrite(/*tensor_index=*/total_tensors - 1, kTfLiteFloat32, /*name=*/"c", @@ -373,15 +383,22 @@ TEST_P(ReshapeLayerTest, ReshapeLayerTest) { INSTANTIATE_TEST_SUITE_P( Writer, ReshapeLayerTest, ::testing::Values(ReshapeTestPattern{/*num_inputs=*/2, - /*is_param_valid=*/true}, + /*is_param_valid=*/true, + /*has_buggy_non_flatten_shape=*/false}, ReshapeTestPattern{/*num_inputs=*/2, - /*is_param_valid=*/false}, + /*is_param_valid=*/false, + /*has_buggy_non_flatten_shape=*/false}, ReshapeTestPattern{/*num_inputs=*/1, - /*is_param_valid=*/true}), + /*is_param_valid=*/true, + /*has_buggy_non_flatten_shape=*/false}, + ReshapeTestPattern{/*num_inputs=*/2, + /*is_param_valid=*/true, + /*has_buggy_non_flatten_shape=*/true}), [](const ::testing::TestParamInfo& info) { std::stringstream ss; ss << "num_inputs_" << info.param.num_inputs << "_valid_param_" - << info.param.is_param_valid; + << info.param.is_param_valid << "_buggy_shape_" + << info.param.has_buggy_non_flatten_shape; std::string name = ss.str(); return name; });