From d2d6c3f07a0b874e64a024c767deb7c9fb39b704 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Wed, 15 Jul 2020 23:20:19 -0700 Subject: [PATCH] Fix to handle Reshape Layer in experimental TFLite writer library. Changes: 1. Updated handling of ReshapeParams. 2. Added write_lib tests to check different scenarios. PiperOrigin-RevId: 321508374 Change-Id: I6e22be4d5fcfd6b771e0e5f1d28e9459deb49af7 --- .../writer/option_writer_generator.cc | 33 ++++++++ .../lite/experimental/writer/writer_lib.cc | 4 +- .../experimental/writer/writer_lib_test.cc | 75 +++++++++++++++++++ 3 files changed, 110 insertions(+), 2 deletions(-) diff --git a/tensorflow/lite/experimental/writer/option_writer_generator.cc b/tensorflow/lite/experimental/writer/option_writer_generator.cc index a565422457c..e484c5ba2f4 100644 --- a/tensorflow/lite/experimental/writer/option_writer_generator.cc +++ b/tensorflow/lite/experimental/writer/option_writer_generator.cc @@ -265,6 +265,32 @@ void GenerateImportForResizeBilinearOp(FILE* fp) { " }\n break;\n"); } +// Reshape Op infers output shape either from Parameter or from shape tensor +// that's is an additional input. When we have this additional shape tensor as +// input we don't have the parameter present in this layer. In case of more than +// one input we import an empty vector for the parameters. +void GenerateImportForReshapeOp(FILE* fp) { + fprintf(fp, + " case BuiltinOperator_RESHAPE: {\n" + " const auto* params = reinterpret_cast(builtin_op_data);\n" + " flatbuffers::Offset union_type;\n" + " if ((node.inputs->size > 1) &&\n" + " (params->num_dimensions < 0 ||\n" + " params->num_dimensions >= " + "TFLITE_RESHAPE_PARAMS_MAX_DIMENSION_COUNT)) {\n" + " union_type = CreateReshapeOptions(*fbb).Union();\n" + " } else {\n" + " auto val0 = fbb->CreateVector(std::vector(params->shape, " + "params->shape + params->num_dimensions));\n" + " union_type = CreateReshapeOptions(*fbb, " + "val0).Union();\n" + " }\n" + " return std::make_pair(BuiltinOptions_ReshapeOptions, " + "union_type);\n" + " }\n break;\n"); +} + void GenerateImportForOp(FILE* fp, const std::string& op_name, const std::string& option_name, const std::string& option_type, @@ -276,6 +302,13 @@ void GenerateImportForOp(FILE* fp, const std::string& op_name, return; } + // Special case Reshape that may have 'new_shape' field missing from the + // parameters. + if (struct_name == "TfLiteReshapeParams") { + GenerateImportForReshapeOp(fp); + return; + } + fprintf(fp, " case BuiltinOperator_%s: {\n", op_name.c_str()); if (options->num_elems != 0) { fprintf(fp, diff --git a/tensorflow/lite/experimental/writer/writer_lib.cc b/tensorflow/lite/experimental/writer/writer_lib.cc index 85f57527c31..2c71919724c 100644 --- a/tensorflow/lite/experimental/writer/writer_lib.cc +++ b/tensorflow/lite/experimental/writer/writer_lib.cc @@ -31,7 +31,7 @@ namespace tflite { std::pair> CreateBuiltinUnion( flatbuffers::FlatBufferBuilder* fbb, enum BuiltinOperator op, - void* builtin_op_data) { + void* builtin_op_data, const TfLiteNode& node) { switch (op) { #include "tensorflow/lite/experimental/writer/option_writer_generated.h" } @@ -82,7 +82,7 @@ SubgraphWriter::ExportOperators(flatbuffers::FlatBufferBuilder* fbb) { // builtin auto builtin_options_and_type = CreateBuiltinUnion( fbb, static_cast(registration.builtin_code), - node.builtin_data); + node.builtin_data, node); builtin_options = builtin_options_and_type.second; builtin_options_type = builtin_options_and_type.first; } else { diff --git a/tensorflow/lite/experimental/writer/writer_lib_test.cc b/tensorflow/lite/experimental/writer/writer_lib_test.cc index 41cca88ead7..4cab27ecb2d 100644 --- a/tensorflow/lite/experimental/writer/writer_lib_test.cc +++ b/tensorflow/lite/experimental/writer/writer_lib_test.cc @@ -15,6 +15,8 @@ limitations under the License. #include "tensorflow/lite/experimental/writer/writer_lib.h" +#include + #include #include "tensorflow/lite/c/common.h" #include "tensorflow/lite/interpreter.h" @@ -184,6 +186,79 @@ TEST(Writer, PerTensorQuantizedModelTest) { CHECK_EQ(new_interpreter->AllocateTensors(), kTfLiteOk); } +struct ReshapeTestPattern { + int num_inputs; + bool is_param_valid; +}; + +class ReshapeLayerTest : public ::testing::TestWithParam {}; + +TEST_P(ReshapeLayerTest, ReshapeLayerTest) { + const auto param = GetParam(); + Interpreter interpreter; + const int total_tensors = param.num_inputs + 1; + interpreter.AddTensors(total_tensors); + int output_shape[] = {1, 2, 3}; + interpreter.SetTensorParametersReadWrite(/*tensor_index=*/0, kTfLiteFloat32, + /*name=*/"a", /*dims=*/{6}, + TfLiteQuantization()); + ASSERT_LE(param.num_inputs, 2); + if (param.num_inputs == 2) { + interpreter.SetTensorParametersReadOnly( + /*tensor_index=*/1, kTfLiteInt32, /*name=*/"b", /*dims=*/{3}, + TfLiteQuantization(), reinterpret_cast(output_shape), + sizeof(output_shape)); + } + interpreter.SetTensorParametersReadWrite(/*tensor_index=*/total_tensors - 1, + kTfLiteFloat32, /*name=*/"c", + /*dims=*/{3}, TfLiteQuantization()); + + std::vector input_tensors(param.num_inputs); + std::iota(input_tensors.begin(), input_tensors.end(), 0); + + interpreter.SetInputs(input_tensors); + interpreter.SetOutputs({total_tensors - 1}); + const char* initial_data = ""; + tflite::ops::builtin::BuiltinOpResolver resolver; + TfLiteReshapeParams* builtin_data = reinterpret_cast( + malloc(sizeof(TfLiteReshapeParams))); + if (param.is_param_valid) { + builtin_data->num_dimensions = 3; + for (int dim = 0; dim < builtin_data->num_dimensions; ++dim) { + builtin_data->shape[dim] = output_shape[dim]; + } + } + const TfLiteRegistration* reg = resolver.FindOp(BuiltinOperator_RESHAPE, 1); + interpreter.AddNodeWithParameters(input_tensors, + /*outputs=*/{total_tensors - 1}, + initial_data, /*init_data_size=*/0, + reinterpret_cast(builtin_data), reg); + + SubgraphWriter writer(&interpreter.primary_subgraph()); + std::string filename = absl::StrCat("/tmp/test_reshape_", param.num_inputs, + "_", param.is_param_valid, ".tflite"); + writer.Write(filename); + std::unique_ptr model = + FlatBufferModel::BuildFromFile(filename.c_str()); + InterpreterBuilder builder(*model, resolver); + std::unique_ptr new_interpreter; + builder(&new_interpreter); + ASSERT_EQ(new_interpreter->AllocateTensors(), kTfLiteOk); +} + +INSTANTIATE_TEST_SUITE_P( + Writer, ReshapeLayerTest, + ::testing::Values(ReshapeTestPattern{/*num_inputs=*/2, + /*is_param_valid=*/true}, + ReshapeTestPattern{/*num_inputs=*/2, + /*is_param_valid=*/false}, + ReshapeTestPattern{/*num_inputs=*/1, + /*is_param_valid=*/true}), + [](const ::testing::TestParamInfo& info) { + std::string name = absl::StrCat("num_inputs_", info.param.num_inputs, + "_isvalid_", info.param.is_param_valid); + return name; + }); } // namespace tflite int main(int argc, char** argv) {