diff --git a/tensorflow/lite/experimental/writer/option_writer_generator.cc b/tensorflow/lite/experimental/writer/option_writer_generator.cc
index a565422457c..898f4a95ef6 100644
--- a/tensorflow/lite/experimental/writer/option_writer_generator.cc
+++ b/tensorflow/lite/experimental/writer/option_writer_generator.cc
@@ -265,6 +265,29 @@ void GenerateImportForResizeBilinearOp(FILE* fp) {
           "    }\n    break;\n");
 }
 
+// The Reshape op infers its output shape either from its parameters or from a
+// shape tensor passed as an additional input. When that shape tensor is
+// present, the parameters may be absent from the node, so for more than one
+// input we write empty ReshapeOptions instead of serializing the parameters.
+void GenerateImportForReshapeOp(FILE* fp) {
+  fprintf(fp,
+          "    case BuiltinOperator_RESHAPE: {\n"
+          "      const auto* params = reinterpret_cast<const "
+          "TfLiteReshapeParams*>(builtin_op_data);\n"
+          "      flatbuffers::Offset<void> union_type;\n"
+          "      if (node.inputs->size > 1) {\n"
+          "        union_type = CreateReshapeOptions(*fbb).Union();\n"
+          "      } else {\n"
+          "        auto val0 = fbb->CreateVector(std::vector<int>(params->shape, "
+          "params->shape + params->num_dimensions));\n"
+          "        union_type = CreateReshapeOptions(*fbb, "
+          "val0).Union();\n"
+          "      }\n"
+          "      return std::make_pair(BuiltinOptions_ReshapeOptions, "
+          "union_type);\n"
+          "    }\n    break;\n");
+}
+
 void GenerateImportForOp(FILE* fp, const std::string& op_name,
                          const std::string& option_name,
                          const std::string& option_type,
@@ -276,6 +299,13 @@ void GenerateImportForOp(FILE* fp, const std::string& op_name,
     return;
   }
 
+  // Special-case Reshape, whose 'new_shape' field may be missing from the
+  // parameters.
+  if (struct_name == "TfLiteReshapeParams") {
+    GenerateImportForReshapeOp(fp);
+    return;
+  }
+
   fprintf(fp, "    case BuiltinOperator_%s: {\n", op_name.c_str());
   if (options->num_elems != 0) {
     fprintf(fp,
diff --git a/tensorflow/lite/experimental/writer/writer_lib.cc b/tensorflow/lite/experimental/writer/writer_lib.cc
index ed26c7f9038..2f509daa9cb 100644
--- a/tensorflow/lite/experimental/writer/writer_lib.cc
+++ b/tensorflow/lite/experimental/writer/writer_lib.cc
@@ -31,7 +31,7 @@ namespace tflite {
 
 std::pair<BuiltinOptions, flatbuffers::Offset<void>> CreateBuiltinUnion(
     flatbuffers::FlatBufferBuilder* fbb, enum BuiltinOperator op,
-    void* builtin_op_data) {
+    void* builtin_op_data, const TfLiteNode& node) {
   switch (op) {
 #include "tensorflow/lite/experimental/writer/option_writer_generated.h"
   }
@@ -82,7 +82,7 @@ SubgraphWriter::ExportOperators(flatbuffers::FlatBufferBuilder* fbb) {
       // builtin
       auto builtin_options_and_type = CreateBuiltinUnion(
           fbb, static_cast<BuiltinOperator>(registration.builtin_code),
-          node.builtin_data);
+          node.builtin_data, node);
       builtin_options = builtin_options_and_type.second;
       builtin_options_type = builtin_options_and_type.first;
     } else {
diff --git a/tensorflow/lite/experimental/writer/writer_lib_test.cc b/tensorflow/lite/experimental/writer/writer_lib_test.cc
index 41cca88ead7..fb59482f705 100644
--- a/tensorflow/lite/experimental/writer/writer_lib_test.cc
+++ b/tensorflow/lite/experimental/writer/writer_lib_test.cc
@@ -15,6 +15,9 @@ limitations under the License.
#include "tensorflow/lite/experimental/writer/writer_lib.h" +#include +#include + #include #include "tensorflow/lite/c/common.h" #include "tensorflow/lite/interpreter.h" @@ -184,6 +187,83 @@ TEST(Writer, PerTensorQuantizedModelTest) { CHECK_EQ(new_interpreter->AllocateTensors(), kTfLiteOk); } +struct ReshapeTestPattern { + int num_inputs; + bool is_param_valid; +}; + +class ReshapeLayerTest : public ::testing::TestWithParam {}; + +TEST_P(ReshapeLayerTest, ReshapeLayerTest) { + const auto param = GetParam(); + Interpreter interpreter; + const int total_tensors = param.num_inputs + 1; + interpreter.AddTensors(total_tensors); + int output_shape[] = {1, 2, 3}; + interpreter.SetTensorParametersReadWrite(/*tensor_index=*/0, kTfLiteFloat32, + /*name=*/"a", /*dims=*/{6}, + TfLiteQuantization()); + ASSERT_LE(param.num_inputs, 2); + if (param.num_inputs == 2) { + interpreter.SetTensorParametersReadOnly( + /*tensor_index=*/1, kTfLiteInt32, /*name=*/"b", /*dims=*/{3}, + TfLiteQuantization(), reinterpret_cast(output_shape), + sizeof(output_shape)); + } + interpreter.SetTensorParametersReadWrite(/*tensor_index=*/total_tensors - 1, + kTfLiteFloat32, /*name=*/"c", + /*dims=*/{3}, TfLiteQuantization()); + + std::vector input_tensors(param.num_inputs); + std::iota(input_tensors.begin(), input_tensors.end(), 0); + + interpreter.SetInputs(input_tensors); + interpreter.SetOutputs({total_tensors - 1}); + const char* initial_data = ""; + tflite::ops::builtin::BuiltinOpResolver resolver; + TfLiteReshapeParams* builtin_data = reinterpret_cast( + malloc(sizeof(TfLiteReshapeParams))); + if (param.is_param_valid) { + builtin_data->num_dimensions = 3; + for (int dim = 0; dim < builtin_data->num_dimensions; ++dim) { + builtin_data->shape[dim] = output_shape[dim]; + } + } + const TfLiteRegistration* reg = resolver.FindOp(BuiltinOperator_RESHAPE, 1); + interpreter.AddNodeWithParameters(input_tensors, + /*outputs=*/{total_tensors - 1}, + initial_data, /*init_data_size=*/0, + reinterpret_cast(builtin_data), reg); + + SubgraphWriter writer(&interpreter.primary_subgraph()); + std::stringstream ss; + ss << "/tmp/test_reshape_" << param.num_inputs << param.is_param_valid + << ".tflite"; + std::string filename = ss.str(); + writer.Write(filename); + std::unique_ptr model = + FlatBufferModel::BuildFromFile(filename.c_str()); + InterpreterBuilder builder(*model, resolver); + std::unique_ptr new_interpreter; + builder(&new_interpreter); + ASSERT_EQ(new_interpreter->AllocateTensors(), kTfLiteOk); +} + +INSTANTIATE_TEST_SUITE_P( + Writer, ReshapeLayerTest, + ::testing::Values(ReshapeTestPattern{/*num_inputs=*/2, + /*is_param_valid=*/true}, + ReshapeTestPattern{/*num_inputs=*/2, + /*is_param_valid=*/false}, + ReshapeTestPattern{/*num_inputs=*/1, + /*is_param_valid=*/true}), + [](const ::testing::TestParamInfo& info) { + std::stringstream ss; + ss << "num_inputs_" << info.param.num_inputs << "_valid_param_" + << info.param.is_param_valid; + std::string name = ss.str(); + return name; + }); } // namespace tflite int main(int argc, char** argv) {