Fix to handle Reshape Layer in experimental TFLite writer library.
Changes: 1. Updated handling of ReshapeParams. 2. Added write_lib tests to check different scenarios. PiperOrigin-RevId: 321508374 Change-Id: I6e22be4d5fcfd6b771e0e5f1d28e9459deb49af7
This commit is contained in:
parent
9c20fbff6c
commit
d2d6c3f07a
@ -265,6 +265,32 @@ void GenerateImportForResizeBilinearOp(FILE* fp) {
|
|||||||
" }\n break;\n");
|
" }\n break;\n");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Reshape Op infers output shape either from Parameter or from shape tensor
|
||||||
|
// that's is an additional input. When we have this additional shape tensor as
|
||||||
|
// input we don't have the parameter present in this layer. In case of more than
|
||||||
|
// one input we import an empty vector for the parameters.
|
||||||
|
void GenerateImportForReshapeOp(FILE* fp) {
|
||||||
|
fprintf(fp,
|
||||||
|
" case BuiltinOperator_RESHAPE: {\n"
|
||||||
|
" const auto* params = reinterpret_cast<const "
|
||||||
|
"TfLiteReshapeParams*>(builtin_op_data);\n"
|
||||||
|
" flatbuffers::Offset<void> union_type;\n"
|
||||||
|
" if ((node.inputs->size > 1) &&\n"
|
||||||
|
" (params->num_dimensions < 0 ||\n"
|
||||||
|
" params->num_dimensions >= "
|
||||||
|
"TFLITE_RESHAPE_PARAMS_MAX_DIMENSION_COUNT)) {\n"
|
||||||
|
" union_type = CreateReshapeOptions(*fbb).Union();\n"
|
||||||
|
" } else {\n"
|
||||||
|
" auto val0 = fbb->CreateVector(std::vector<int>(params->shape, "
|
||||||
|
"params->shape + params->num_dimensions));\n"
|
||||||
|
" union_type = CreateReshapeOptions(*fbb, "
|
||||||
|
"val0).Union();\n"
|
||||||
|
" }\n"
|
||||||
|
" return std::make_pair(BuiltinOptions_ReshapeOptions, "
|
||||||
|
"union_type);\n"
|
||||||
|
" }\n break;\n");
|
||||||
|
}
|
||||||
|
|
||||||
void GenerateImportForOp(FILE* fp, const std::string& op_name,
|
void GenerateImportForOp(FILE* fp, const std::string& op_name,
|
||||||
const std::string& option_name,
|
const std::string& option_name,
|
||||||
const std::string& option_type,
|
const std::string& option_type,
|
||||||
@ -276,6 +302,13 @@ void GenerateImportForOp(FILE* fp, const std::string& op_name,
|
|||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Special case Reshape that may have 'new_shape' field missing from the
|
||||||
|
// parameters.
|
||||||
|
if (struct_name == "TfLiteReshapeParams") {
|
||||||
|
GenerateImportForReshapeOp(fp);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
fprintf(fp, " case BuiltinOperator_%s: {\n", op_name.c_str());
|
fprintf(fp, " case BuiltinOperator_%s: {\n", op_name.c_str());
|
||||||
if (options->num_elems != 0) {
|
if (options->num_elems != 0) {
|
||||||
fprintf(fp,
|
fprintf(fp,
|
||||||
|
@ -31,7 +31,7 @@ namespace tflite {
|
|||||||
|
|
||||||
std::pair<BuiltinOptions, flatbuffers::Offset<void>> CreateBuiltinUnion(
|
std::pair<BuiltinOptions, flatbuffers::Offset<void>> CreateBuiltinUnion(
|
||||||
flatbuffers::FlatBufferBuilder* fbb, enum BuiltinOperator op,
|
flatbuffers::FlatBufferBuilder* fbb, enum BuiltinOperator op,
|
||||||
void* builtin_op_data) {
|
void* builtin_op_data, const TfLiteNode& node) {
|
||||||
switch (op) {
|
switch (op) {
|
||||||
#include "tensorflow/lite/experimental/writer/option_writer_generated.h"
|
#include "tensorflow/lite/experimental/writer/option_writer_generated.h"
|
||||||
}
|
}
|
||||||
@ -82,7 +82,7 @@ SubgraphWriter::ExportOperators(flatbuffers::FlatBufferBuilder* fbb) {
|
|||||||
// builtin
|
// builtin
|
||||||
auto builtin_options_and_type = CreateBuiltinUnion(
|
auto builtin_options_and_type = CreateBuiltinUnion(
|
||||||
fbb, static_cast<enum BuiltinOperator>(registration.builtin_code),
|
fbb, static_cast<enum BuiltinOperator>(registration.builtin_code),
|
||||||
node.builtin_data);
|
node.builtin_data, node);
|
||||||
builtin_options = builtin_options_and_type.second;
|
builtin_options = builtin_options_and_type.second;
|
||||||
builtin_options_type = builtin_options_and_type.first;
|
builtin_options_type = builtin_options_and_type.first;
|
||||||
} else {
|
} else {
|
||||||
|
@ -15,6 +15,8 @@ limitations under the License.
|
|||||||
|
|
||||||
#include "tensorflow/lite/experimental/writer/writer_lib.h"
|
#include "tensorflow/lite/experimental/writer/writer_lib.h"
|
||||||
|
|
||||||
|
#include <numeric>
|
||||||
|
|
||||||
#include <gtest/gtest.h>
|
#include <gtest/gtest.h>
|
||||||
#include "tensorflow/lite/c/common.h"
|
#include "tensorflow/lite/c/common.h"
|
||||||
#include "tensorflow/lite/interpreter.h"
|
#include "tensorflow/lite/interpreter.h"
|
||||||
@ -184,6 +186,79 @@ TEST(Writer, PerTensorQuantizedModelTest) {
|
|||||||
CHECK_EQ(new_interpreter->AllocateTensors(), kTfLiteOk);
|
CHECK_EQ(new_interpreter->AllocateTensors(), kTfLiteOk);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
struct ReshapeTestPattern {
|
||||||
|
int num_inputs;
|
||||||
|
bool is_param_valid;
|
||||||
|
};
|
||||||
|
|
||||||
|
class ReshapeLayerTest : public ::testing::TestWithParam<ReshapeTestPattern> {};
|
||||||
|
|
||||||
|
TEST_P(ReshapeLayerTest, ReshapeLayerTest) {
|
||||||
|
const auto param = GetParam();
|
||||||
|
Interpreter interpreter;
|
||||||
|
const int total_tensors = param.num_inputs + 1;
|
||||||
|
interpreter.AddTensors(total_tensors);
|
||||||
|
int output_shape[] = {1, 2, 3};
|
||||||
|
interpreter.SetTensorParametersReadWrite(/*tensor_index=*/0, kTfLiteFloat32,
|
||||||
|
/*name=*/"a", /*dims=*/{6},
|
||||||
|
TfLiteQuantization());
|
||||||
|
ASSERT_LE(param.num_inputs, 2);
|
||||||
|
if (param.num_inputs == 2) {
|
||||||
|
interpreter.SetTensorParametersReadOnly(
|
||||||
|
/*tensor_index=*/1, kTfLiteInt32, /*name=*/"b", /*dims=*/{3},
|
||||||
|
TfLiteQuantization(), reinterpret_cast<char*>(output_shape),
|
||||||
|
sizeof(output_shape));
|
||||||
|
}
|
||||||
|
interpreter.SetTensorParametersReadWrite(/*tensor_index=*/total_tensors - 1,
|
||||||
|
kTfLiteFloat32, /*name=*/"c",
|
||||||
|
/*dims=*/{3}, TfLiteQuantization());
|
||||||
|
|
||||||
|
std::vector<int> input_tensors(param.num_inputs);
|
||||||
|
std::iota(input_tensors.begin(), input_tensors.end(), 0);
|
||||||
|
|
||||||
|
interpreter.SetInputs(input_tensors);
|
||||||
|
interpreter.SetOutputs({total_tensors - 1});
|
||||||
|
const char* initial_data = "";
|
||||||
|
tflite::ops::builtin::BuiltinOpResolver resolver;
|
||||||
|
TfLiteReshapeParams* builtin_data = reinterpret_cast<TfLiteReshapeParams*>(
|
||||||
|
malloc(sizeof(TfLiteReshapeParams)));
|
||||||
|
if (param.is_param_valid) {
|
||||||
|
builtin_data->num_dimensions = 3;
|
||||||
|
for (int dim = 0; dim < builtin_data->num_dimensions; ++dim) {
|
||||||
|
builtin_data->shape[dim] = output_shape[dim];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
const TfLiteRegistration* reg = resolver.FindOp(BuiltinOperator_RESHAPE, 1);
|
||||||
|
interpreter.AddNodeWithParameters(input_tensors,
|
||||||
|
/*outputs=*/{total_tensors - 1},
|
||||||
|
initial_data, /*init_data_size=*/0,
|
||||||
|
reinterpret_cast<void*>(builtin_data), reg);
|
||||||
|
|
||||||
|
SubgraphWriter writer(&interpreter.primary_subgraph());
|
||||||
|
std::string filename = absl::StrCat("/tmp/test_reshape_", param.num_inputs,
|
||||||
|
"_", param.is_param_valid, ".tflite");
|
||||||
|
writer.Write(filename);
|
||||||
|
std::unique_ptr<FlatBufferModel> model =
|
||||||
|
FlatBufferModel::BuildFromFile(filename.c_str());
|
||||||
|
InterpreterBuilder builder(*model, resolver);
|
||||||
|
std::unique_ptr<Interpreter> new_interpreter;
|
||||||
|
builder(&new_interpreter);
|
||||||
|
ASSERT_EQ(new_interpreter->AllocateTensors(), kTfLiteOk);
|
||||||
|
}
|
||||||
|
|
||||||
|
INSTANTIATE_TEST_SUITE_P(
|
||||||
|
Writer, ReshapeLayerTest,
|
||||||
|
::testing::Values(ReshapeTestPattern{/*num_inputs=*/2,
|
||||||
|
/*is_param_valid=*/true},
|
||||||
|
ReshapeTestPattern{/*num_inputs=*/2,
|
||||||
|
/*is_param_valid=*/false},
|
||||||
|
ReshapeTestPattern{/*num_inputs=*/1,
|
||||||
|
/*is_param_valid=*/true}),
|
||||||
|
[](const ::testing::TestParamInfo<ReshapeLayerTest::ParamType>& info) {
|
||||||
|
std::string name = absl::StrCat("num_inputs_", info.param.num_inputs,
|
||||||
|
"_isvalid_", info.param.is_param_valid);
|
||||||
|
return name;
|
||||||
|
});
|
||||||
} // namespace tflite
|
} // namespace tflite
|
||||||
|
|
||||||
int main(int argc, char** argv) {
|
int main(int argc, char** argv) {
|
||||||
|
Loading…
Reference in New Issue
Block a user