Support old TOCO generated reshape operator cases where the dimension of the

given shape input is non 1-D

In such cases, the TOCO converted models always have the correct value at the
new_shape attribute so we need to copy the new_shape attribute in order to
correctly rewrite.

This change closes #46423

PiperOrigin-RevId: 351920426
Change-Id: I4f7bea1ad6f59485d4011cf2a52c7bce989c52d7
This commit is contained in:
Jaesung Chung 2021-01-14 18:43:36 -08:00 committed by TensorFlower Gardener
parent 5b1e9b0e28
commit f12082d7af
2 changed files with 30 additions and 10 deletions

View File

@ -274,14 +274,17 @@ void GenerateImportForResizeBilinearOp(FILE* fp) {
// Reshape Op infers output shape either from Parameter or from shape tensor // Reshape Op infers output shape either from Parameter or from shape tensor
// that's is an additional input. When we have this additional shape tensor as // that's is an additional input. When we have this additional shape tensor as
// input we don't have the parameter present in this layer. In case of more than // input we don't have the parameter present in this layer. In case of more than
// one input we import an empty vector for the parameters. // one input and the shape parameter does not have a valid value, we import an
// empty vector for the parameters.
void GenerateImportForReshapeOp(FILE* fp) { void GenerateImportForReshapeOp(FILE* fp) {
fprintf(fp, fprintf(fp,
" case BuiltinOperator_RESHAPE: {\n" " case BuiltinOperator_RESHAPE: {\n"
" const auto* params = reinterpret_cast<const " " const auto* params = reinterpret_cast<const "
"TfLiteReshapeParams*>(builtin_op_data);\n" "TfLiteReshapeParams*>(builtin_op_data);\n"
" flatbuffers::Offset<void> union_type;\n" " flatbuffers::Offset<void> union_type;\n"
" if (node.inputs->size > 1) {\n" " if (node.inputs->size > 1 && (params->num_dimensions <= 0 || "
"params->num_dimensions > TFLITE_RESHAPE_PARAMS_MAX_DIMENSION_COUNT))"
" {\n"
" union_type = CreateReshapeOptions(*fbb).Union();\n" " union_type = CreateReshapeOptions(*fbb).Union();\n"
" } else {\n" " } else {\n"
" auto val0 = fbb->CreateVector(std::vector<int>(params->shape, " " auto val0 = fbb->CreateVector(std::vector<int>(params->shape, "

View File

@ -311,6 +311,7 @@ INSTANTIATE_TEST_SUITE_P(Writer, SingleSubgraphTest, ::testing::Bool());
struct ReshapeTestPattern { struct ReshapeTestPattern {
int num_inputs; int num_inputs;
bool is_param_valid; bool is_param_valid;
bool has_buggy_non_flatten_shape;
}; };
class ReshapeLayerTest : public ::testing::TestWithParam<ReshapeTestPattern> {}; class ReshapeLayerTest : public ::testing::TestWithParam<ReshapeTestPattern> {};
@ -326,11 +327,20 @@ TEST_P(ReshapeLayerTest, ReshapeLayerTest) {
TfLiteQuantization()); TfLiteQuantization());
ASSERT_LE(param.num_inputs, 2); ASSERT_LE(param.num_inputs, 2);
if (param.num_inputs == 2) { if (param.num_inputs == 2) {
// Some TOCO generated models have buggy shape arguments, which are required
// to be flatten, for example, dims={3, 1} instead of dims={3}.
if (param.has_buggy_non_flatten_shape) {
interpreter.SetTensorParametersReadOnly(
/*tensor_index=*/1, kTfLiteInt32, /*name=*/"b", /*dims=*/{3, 1},
TfLiteQuantization(), reinterpret_cast<char*>(output_shape),
sizeof(output_shape));
} else {
interpreter.SetTensorParametersReadOnly( interpreter.SetTensorParametersReadOnly(
/*tensor_index=*/1, kTfLiteInt32, /*name=*/"b", /*dims=*/{3}, /*tensor_index=*/1, kTfLiteInt32, /*name=*/"b", /*dims=*/{3},
TfLiteQuantization(), reinterpret_cast<char*>(output_shape), TfLiteQuantization(), reinterpret_cast<char*>(output_shape),
sizeof(output_shape)); sizeof(output_shape));
} }
}
interpreter.SetTensorParametersReadWrite(/*tensor_index=*/total_tensors - 1, interpreter.SetTensorParametersReadWrite(/*tensor_index=*/total_tensors - 1,
kTfLiteFloat32, /*name=*/"c", kTfLiteFloat32, /*name=*/"c",
/*dims=*/{3}, TfLiteQuantization()); /*dims=*/{3}, TfLiteQuantization());
@ -373,15 +383,22 @@ TEST_P(ReshapeLayerTest, ReshapeLayerTest) {
INSTANTIATE_TEST_SUITE_P( INSTANTIATE_TEST_SUITE_P(
Writer, ReshapeLayerTest, Writer, ReshapeLayerTest,
::testing::Values(ReshapeTestPattern{/*num_inputs=*/2, ::testing::Values(ReshapeTestPattern{/*num_inputs=*/2,
/*is_param_valid=*/true}, /*is_param_valid=*/true,
/*has_buggy_non_flatten_shape=*/false},
ReshapeTestPattern{/*num_inputs=*/2, ReshapeTestPattern{/*num_inputs=*/2,
/*is_param_valid=*/false}, /*is_param_valid=*/false,
/*has_buggy_non_flatten_shape=*/false},
ReshapeTestPattern{/*num_inputs=*/1, ReshapeTestPattern{/*num_inputs=*/1,
/*is_param_valid=*/true}), /*is_param_valid=*/true,
/*has_buggy_non_flatten_shape=*/false},
ReshapeTestPattern{/*num_inputs=*/2,
/*is_param_valid=*/true,
/*has_buggy_non_flatten_shape=*/true}),
[](const ::testing::TestParamInfo<ReshapeLayerTest::ParamType>& info) { [](const ::testing::TestParamInfo<ReshapeLayerTest::ParamType>& info) {
std::stringstream ss; std::stringstream ss;
ss << "num_inputs_" << info.param.num_inputs << "_valid_param_" ss << "num_inputs_" << info.param.num_inputs << "_valid_param_"
<< info.param.is_param_valid; << info.param.is_param_valid << "_buggy_shape_"
<< info.param.has_buggy_non_flatten_shape;
std::string name = ss.str(); std::string name = ss.str();
return name; return name;
}); });