Fix to handle Reshape Layer in experimental TFLite writer library.
Changes: 1. Updated handling of ReshapeParams. 2. Added write_lib tests to check different scenarios. PiperOrigin-RevId: 321508374 Change-Id: I6e22be4d5fcfd6b771e0e5f1d28e9459deb49af7
This commit is contained in:
parent
9c20fbff6c
commit
d2d6c3f07a
@ -265,6 +265,32 @@ void GenerateImportForResizeBilinearOp(FILE* fp) {
|
||||
" }\n break;\n");
|
||||
}
|
||||
|
||||
// Reshape Op infers output shape either from Parameter or from shape tensor
|
||||
// that's is an additional input. When we have this additional shape tensor as
|
||||
// input we don't have the parameter present in this layer. In case of more than
|
||||
// one input we import an empty vector for the parameters.
|
||||
void GenerateImportForReshapeOp(FILE* fp) {
|
||||
fprintf(fp,
|
||||
" case BuiltinOperator_RESHAPE: {\n"
|
||||
" const auto* params = reinterpret_cast<const "
|
||||
"TfLiteReshapeParams*>(builtin_op_data);\n"
|
||||
" flatbuffers::Offset<void> union_type;\n"
|
||||
" if ((node.inputs->size > 1) &&\n"
|
||||
" (params->num_dimensions < 0 ||\n"
|
||||
" params->num_dimensions >= "
|
||||
"TFLITE_RESHAPE_PARAMS_MAX_DIMENSION_COUNT)) {\n"
|
||||
" union_type = CreateReshapeOptions(*fbb).Union();\n"
|
||||
" } else {\n"
|
||||
" auto val0 = fbb->CreateVector(std::vector<int>(params->shape, "
|
||||
"params->shape + params->num_dimensions));\n"
|
||||
" union_type = CreateReshapeOptions(*fbb, "
|
||||
"val0).Union();\n"
|
||||
" }\n"
|
||||
" return std::make_pair(BuiltinOptions_ReshapeOptions, "
|
||||
"union_type);\n"
|
||||
" }\n break;\n");
|
||||
}
|
||||
|
||||
void GenerateImportForOp(FILE* fp, const std::string& op_name,
|
||||
const std::string& option_name,
|
||||
const std::string& option_type,
|
||||
@ -276,6 +302,13 @@ void GenerateImportForOp(FILE* fp, const std::string& op_name,
|
||||
return;
|
||||
}
|
||||
|
||||
// Special case Reshape that may have 'new_shape' field missing from the
|
||||
// parameters.
|
||||
if (struct_name == "TfLiteReshapeParams") {
|
||||
GenerateImportForReshapeOp(fp);
|
||||
return;
|
||||
}
|
||||
|
||||
fprintf(fp, " case BuiltinOperator_%s: {\n", op_name.c_str());
|
||||
if (options->num_elems != 0) {
|
||||
fprintf(fp,
|
||||
|
@ -31,7 +31,7 @@ namespace tflite {
|
||||
|
||||
std::pair<BuiltinOptions, flatbuffers::Offset<void>> CreateBuiltinUnion(
|
||||
flatbuffers::FlatBufferBuilder* fbb, enum BuiltinOperator op,
|
||||
void* builtin_op_data) {
|
||||
void* builtin_op_data, const TfLiteNode& node) {
|
||||
switch (op) {
|
||||
#include "tensorflow/lite/experimental/writer/option_writer_generated.h"
|
||||
}
|
||||
@ -82,7 +82,7 @@ SubgraphWriter::ExportOperators(flatbuffers::FlatBufferBuilder* fbb) {
|
||||
// builtin
|
||||
auto builtin_options_and_type = CreateBuiltinUnion(
|
||||
fbb, static_cast<enum BuiltinOperator>(registration.builtin_code),
|
||||
node.builtin_data);
|
||||
node.builtin_data, node);
|
||||
builtin_options = builtin_options_and_type.second;
|
||||
builtin_options_type = builtin_options_and_type.first;
|
||||
} else {
|
||||
|
@ -15,6 +15,8 @@ limitations under the License.
|
||||
|
||||
#include "tensorflow/lite/experimental/writer/writer_lib.h"
|
||||
|
||||
#include <numeric>
|
||||
|
||||
#include <gtest/gtest.h>
|
||||
#include "tensorflow/lite/c/common.h"
|
||||
#include "tensorflow/lite/interpreter.h"
|
||||
@ -184,6 +186,79 @@ TEST(Writer, PerTensorQuantizedModelTest) {
|
||||
CHECK_EQ(new_interpreter->AllocateTensors(), kTfLiteOk);
|
||||
}
|
||||
|
||||
struct ReshapeTestPattern {
|
||||
int num_inputs;
|
||||
bool is_param_valid;
|
||||
};
|
||||
|
||||
class ReshapeLayerTest : public ::testing::TestWithParam<ReshapeTestPattern> {};
|
||||
|
||||
TEST_P(ReshapeLayerTest, ReshapeLayerTest) {
|
||||
const auto param = GetParam();
|
||||
Interpreter interpreter;
|
||||
const int total_tensors = param.num_inputs + 1;
|
||||
interpreter.AddTensors(total_tensors);
|
||||
int output_shape[] = {1, 2, 3};
|
||||
interpreter.SetTensorParametersReadWrite(/*tensor_index=*/0, kTfLiteFloat32,
|
||||
/*name=*/"a", /*dims=*/{6},
|
||||
TfLiteQuantization());
|
||||
ASSERT_LE(param.num_inputs, 2);
|
||||
if (param.num_inputs == 2) {
|
||||
interpreter.SetTensorParametersReadOnly(
|
||||
/*tensor_index=*/1, kTfLiteInt32, /*name=*/"b", /*dims=*/{3},
|
||||
TfLiteQuantization(), reinterpret_cast<char*>(output_shape),
|
||||
sizeof(output_shape));
|
||||
}
|
||||
interpreter.SetTensorParametersReadWrite(/*tensor_index=*/total_tensors - 1,
|
||||
kTfLiteFloat32, /*name=*/"c",
|
||||
/*dims=*/{3}, TfLiteQuantization());
|
||||
|
||||
std::vector<int> input_tensors(param.num_inputs);
|
||||
std::iota(input_tensors.begin(), input_tensors.end(), 0);
|
||||
|
||||
interpreter.SetInputs(input_tensors);
|
||||
interpreter.SetOutputs({total_tensors - 1});
|
||||
const char* initial_data = "";
|
||||
tflite::ops::builtin::BuiltinOpResolver resolver;
|
||||
TfLiteReshapeParams* builtin_data = reinterpret_cast<TfLiteReshapeParams*>(
|
||||
malloc(sizeof(TfLiteReshapeParams)));
|
||||
if (param.is_param_valid) {
|
||||
builtin_data->num_dimensions = 3;
|
||||
for (int dim = 0; dim < builtin_data->num_dimensions; ++dim) {
|
||||
builtin_data->shape[dim] = output_shape[dim];
|
||||
}
|
||||
}
|
||||
const TfLiteRegistration* reg = resolver.FindOp(BuiltinOperator_RESHAPE, 1);
|
||||
interpreter.AddNodeWithParameters(input_tensors,
|
||||
/*outputs=*/{total_tensors - 1},
|
||||
initial_data, /*init_data_size=*/0,
|
||||
reinterpret_cast<void*>(builtin_data), reg);
|
||||
|
||||
SubgraphWriter writer(&interpreter.primary_subgraph());
|
||||
std::string filename = absl::StrCat("/tmp/test_reshape_", param.num_inputs,
|
||||
"_", param.is_param_valid, ".tflite");
|
||||
writer.Write(filename);
|
||||
std::unique_ptr<FlatBufferModel> model =
|
||||
FlatBufferModel::BuildFromFile(filename.c_str());
|
||||
InterpreterBuilder builder(*model, resolver);
|
||||
std::unique_ptr<Interpreter> new_interpreter;
|
||||
builder(&new_interpreter);
|
||||
ASSERT_EQ(new_interpreter->AllocateTensors(), kTfLiteOk);
|
||||
}
|
||||
|
||||
INSTANTIATE_TEST_SUITE_P(
|
||||
Writer, ReshapeLayerTest,
|
||||
::testing::Values(ReshapeTestPattern{/*num_inputs=*/2,
|
||||
/*is_param_valid=*/true},
|
||||
ReshapeTestPattern{/*num_inputs=*/2,
|
||||
/*is_param_valid=*/false},
|
||||
ReshapeTestPattern{/*num_inputs=*/1,
|
||||
/*is_param_valid=*/true}),
|
||||
[](const ::testing::TestParamInfo<ReshapeLayerTest::ParamType>& info) {
|
||||
std::string name = absl::StrCat("num_inputs_", info.param.num_inputs,
|
||||
"_isvalid_", info.param.is_param_valid);
|
||||
return name;
|
||||
});
|
||||
} // namespace tflite
|
||||
|
||||
int main(int argc, char** argv) {
|
||||
|
Loading…
Reference in New Issue
Block a user