Fix to handle Reshape Layer in experimental TFLite writer library.

Changes:
 1. Updated handling of ReshapeParams.
 2. Added write_lib tests to check different scenarios.
PiperOrigin-RevId: 321508374
Change-Id: I6e22be4d5fcfd6b771e0e5f1d28e9459deb49af7
This commit is contained in:
A. Unique TensorFlower 2020-07-15 23:20:19 -07:00 committed by TensorFlower Gardener
parent 9c20fbff6c
commit d2d6c3f07a
3 changed files with 110 additions and 2 deletions

View File

@ -265,6 +265,32 @@ void GenerateImportForResizeBilinearOp(FILE* fp) {
" }\n break;\n");
}
// Emits the importer case for RESHAPE. The op's output shape can come either
// from TfLiteReshapeParams or from an optional additional "shape" input
// tensor. When that extra tensor is present the params may be absent or
// invalid, so the generated code emits ReshapeOptions without a new_shape
// vector in that case; otherwise it serializes params->shape as new_shape.
void GenerateImportForReshapeOp(FILE* fp) {
  static const char kReshapeImportCase[] =
      " case BuiltinOperator_RESHAPE: {\n"
      " const auto* params = reinterpret_cast<const "
      "TfLiteReshapeParams*>(builtin_op_data);\n"
      " flatbuffers::Offset<void> union_type;\n"
      " if ((node.inputs->size > 1) &&\n"
      " (params->num_dimensions < 0 ||\n"
      " params->num_dimensions >= "
      "TFLITE_RESHAPE_PARAMS_MAX_DIMENSION_COUNT)) {\n"
      " union_type = CreateReshapeOptions(*fbb).Union();\n"
      " } else {\n"
      " auto val0 = fbb->CreateVector(std::vector<int>(params->shape, "
      "params->shape + params->num_dimensions));\n"
      " union_type = CreateReshapeOptions(*fbb, "
      "val0).Union();\n"
      " }\n"
      " return std::make_pair(BuiltinOptions_ReshapeOptions, "
      "union_type);\n"
      " }\n break;\n";
  // fputs writes the string verbatim, matching the original fprintf (the
  // template contains no format specifiers).
  fputs(kReshapeImportCase, fp);
}
void GenerateImportForOp(FILE* fp, const std::string& op_name,
const std::string& option_name,
const std::string& option_type,
@ -276,6 +302,13 @@ void GenerateImportForOp(FILE* fp, const std::string& op_name,
return;
}
// Special case Reshape that may have 'new_shape' field missing from the
// parameters.
if (struct_name == "TfLiteReshapeParams") {
GenerateImportForReshapeOp(fp);
return;
}
fprintf(fp, " case BuiltinOperator_%s: {\n", op_name.c_str());
if (options->num_elems != 0) {
fprintf(fp,

View File

@ -31,7 +31,7 @@ namespace tflite {
std::pair<BuiltinOptions, flatbuffers::Offset<void>> CreateBuiltinUnion(
flatbuffers::FlatBufferBuilder* fbb, enum BuiltinOperator op,
void* builtin_op_data) {
void* builtin_op_data, const TfLiteNode& node) {
switch (op) {
#include "tensorflow/lite/experimental/writer/option_writer_generated.h"
}
@ -82,7 +82,7 @@ SubgraphWriter::ExportOperators(flatbuffers::FlatBufferBuilder* fbb) {
// builtin
auto builtin_options_and_type = CreateBuiltinUnion(
fbb, static_cast<enum BuiltinOperator>(registration.builtin_code),
node.builtin_data);
node.builtin_data, node);
builtin_options = builtin_options_and_type.second;
builtin_options_type = builtin_options_and_type.first;
} else {

View File

@ -15,6 +15,8 @@ limitations under the License.
#include "tensorflow/lite/experimental/writer/writer_lib.h"

#include <cstdlib>
#include <cstring>
#include <numeric>

#include <gtest/gtest.h>
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/interpreter.h"
@ -184,6 +186,79 @@ TEST(Writer, PerTensorQuantizedModelTest) {
CHECK_EQ(new_interpreter->AllocateTensors(), kTfLiteOk);
}
// Parameterization for ReshapeLayerTest: controls how the RESHAPE node is
// built so the writer's different serialization paths are exercised.
struct ReshapeTestPattern {
  // Number of node inputs: 1 = data tensor only (shape comes from params),
  // 2 = data tensor plus an extra read-only shape tensor.
  int num_inputs;
  // Whether the TfLiteReshapeParams passed to the node carry a valid
  // new_shape.
  bool is_param_valid;
};
class ReshapeLayerTest : public ::testing::TestWithParam<ReshapeTestPattern> {};
// Round-trips a single-node RESHAPE graph through the writer: build an
// interpreter, serialize it to a .tflite file, re-import it, and verify the
// re-imported model allocates tensors successfully.
//
// Fix over the previous version: when is_param_valid was false the malloc'd
// TfLiteReshapeParams were left completely uninitialized, so the writer read
// an indeterminate num_dimensions (undefined behavior) and the serialization
// branch taken was nondeterministic. The params are now zeroed and explicitly
// marked invalid with num_dimensions = -1. The writer/model results are also
// status-checked instead of being ignored.
TEST_P(ReshapeLayerTest, ReshapeLayerTest) {
  const auto param = GetParam();
  Interpreter interpreter;
  // One output tensor in addition to the inputs.
  const int total_tensors = param.num_inputs + 1;
  interpreter.AddTensors(total_tensors);
  int output_shape[] = {1, 2, 3};
  // Tensor 0: the float data being reshaped (6 elements -> 1x2x3).
  interpreter.SetTensorParametersReadWrite(/*tensor_index=*/0, kTfLiteFloat32,
                                           /*name=*/"a", /*dims=*/{6},
                                           TfLiteQuantization());
  ASSERT_LE(param.num_inputs, 2);
  if (param.num_inputs == 2) {
    // Tensor 1: optional read-only shape tensor carrying the target shape.
    interpreter.SetTensorParametersReadOnly(
        /*tensor_index=*/1, kTfLiteInt32, /*name=*/"b", /*dims=*/{3},
        TfLiteQuantization(), reinterpret_cast<char*>(output_shape),
        sizeof(output_shape));
  }
  // Last tensor: the reshaped output.
  interpreter.SetTensorParametersReadWrite(/*tensor_index=*/total_tensors - 1,
                                           kTfLiteFloat32, /*name=*/"c",
                                           /*dims=*/{3}, TfLiteQuantization());
  std::vector<int> input_tensors(param.num_inputs);
  std::iota(input_tensors.begin(), input_tensors.end(), 0);
  interpreter.SetInputs(input_tensors);
  interpreter.SetOutputs({total_tensors - 1});
  const char* initial_data = "";
  tflite::ops::builtin::BuiltinOpResolver resolver;
  TfLiteReshapeParams* builtin_data = reinterpret_cast<TfLiteReshapeParams*>(
      malloc(sizeof(TfLiteReshapeParams)));
  // Zero the struct so no field is ever read uninitialized.
  memset(builtin_data, 0, sizeof(TfLiteReshapeParams));
  if (param.is_param_valid) {
    builtin_data->num_dimensions = 3;
    for (int dim = 0; dim < builtin_data->num_dimensions; ++dim) {
      builtin_data->shape[dim] = output_shape[dim];
    }
  } else {
    // Deterministically mark the params invalid so the writer must fall back
    // to the shape-tensor input when serializing.
    builtin_data->num_dimensions = -1;
  }
  const TfLiteRegistration* reg = resolver.FindOp(BuiltinOperator_RESHAPE, 1);
  // NOTE(review): builtin_data appears to transfer ownership to the
  // interpreter here (no free in this test) — confirm AddNodeWithParameters
  // releases it.
  interpreter.AddNodeWithParameters(input_tensors,
                                    /*outputs=*/{total_tensors - 1},
                                    initial_data, /*init_data_size=*/0,
                                    reinterpret_cast<void*>(builtin_data), reg);
  SubgraphWriter writer(&interpreter.primary_subgraph());
  std::string filename = absl::StrCat("/tmp/test_reshape_", param.num_inputs,
                                      "_", param.is_param_valid, ".tflite");
  ASSERT_EQ(writer.Write(filename), kTfLiteOk);
  std::unique_ptr<FlatBufferModel> model =
      FlatBufferModel::BuildFromFile(filename.c_str());
  ASSERT_NE(model, nullptr);
  InterpreterBuilder builder(*model, resolver);
  std::unique_ptr<Interpreter> new_interpreter;
  builder(&new_interpreter);
  ASSERT_EQ(new_interpreter->AllocateTensors(), kTfLiteOk);
}
// Instantiates the reshape round-trip test for three combinations:
// (2 inputs, valid params), (2 inputs, invalid params), and
// (1 input, valid params). The (1 input, invalid params) combination is
// absent — presumably because the op would then have no source for the
// output shape; confirm before adding it.
INSTANTIATE_TEST_SUITE_P(
    Writer, ReshapeLayerTest,
    ::testing::Values(ReshapeTestPattern{/*num_inputs=*/2,
                                         /*is_param_valid=*/true},
                      ReshapeTestPattern{/*num_inputs=*/2,
                                         /*is_param_valid=*/false},
                      ReshapeTestPattern{/*num_inputs=*/1,
                                         /*is_param_valid=*/true}),
    // Name generator: produces readable suffixes like "num_inputs_2_isvalid_1"
    // instead of bare indices.
    [](const ::testing::TestParamInfo<ReshapeLayerTest::ParamType>& info) {
      std::string name = absl::StrCat("num_inputs_", info.param.num_inputs,
                                      "_isvalid_", info.param.is_param_valid);
      return name;
    });
} // namespace tflite
int main(int argc, char** argv) {