Support old TOCO-generated reshape operator cases where the given shape input is not 1-D.

In such cases the TOCO-converted models always carry the correct value in the new_shape attribute, so we need to copy the new_shape attribute in order to rewrite the model correctly. This change closes #46423.

PiperOrigin-RevId: 351920426
Change-Id: I4f7bea1ad6f59485d4011cf2a52c7bce989c52d7
This commit is contained in: parent 5b1e9b0e28, commit f12082d7af
@@ -274,14 +274,17 @@ void GenerateImportForResizeBilinearOp(FILE* fp) {
 // Reshape Op infers output shape either from Parameter or from shape tensor
 // that's is an additional input. When we have this additional shape tensor as
 // input we don't have the parameter present in this layer. In case of more than
-// one input we import an empty vector for the parameters.
+// one input and the shape parameter does not have a valid value, we import an
+// empty vector for the parameters.
 void GenerateImportForReshapeOp(FILE* fp) {
   fprintf(fp,
           "    case BuiltinOperator_RESHAPE: {\n"
           "      const auto* params = reinterpret_cast<const "
           "TfLiteReshapeParams*>(builtin_op_data);\n"
           "      flatbuffers::Offset<void> union_type;\n"
-          "      if (node.inputs->size > 1) {\n"
+          "      if (node.inputs->size > 1 && (params->num_dimensions <= 0 || "
+          "params->num_dimensions > TFLITE_RESHAPE_PARAMS_MAX_DIMENSION_COUNT))"
+          " {\n"
           "        union_type = CreateReshapeOptions(*fbb).Union();\n"
           "      } else {\n"
           "        auto val0 = fbb->CreateVector(std::vector<int>(params->shape, "
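The widened condition above is the core of the change: the generated writer code now falls back to an empty ReshapeOptions only when the extra shape-tensor input is present and the new_shape attribute carried by the builtin params is unusable; otherwise the attribute is copied, which is what old TOCO-generated models with a non-1-D shape input rely on. A minimal standalone sketch of that decision, using simplified stand-in types rather than the real TfLite structs:

// Standalone sketch of the new decision, not the TFLite writer itself.
// ReshapeParams and kMaxReshapeDims are simplified stand-ins for
// TfLiteReshapeParams and TFLITE_RESHAPE_PARAMS_MAX_DIMENSION_COUNT.
#include <cstdio>

constexpr int kMaxReshapeDims = 8;  // assumed stand-in for the real macro value

struct ReshapeParams {
  int num_dimensions;
  int shape[kMaxReshapeDims];
};

// True  -> emit empty ReshapeOptions and rely on the extra shape-tensor input.
// False -> the new_shape attribute is valid, so copy params->shape into the
//          written ReshapeOptions (the old-TOCO case this commit supports).
bool UseEmptyReshapeOptions(int num_inputs, const ReshapeParams& params) {
  return num_inputs > 1 &&
         (params.num_dimensions <= 0 || params.num_dimensions > kMaxReshapeDims);
}

int main() {
  // Old TOCO-style model: a shape-tensor input and a valid new_shape attribute.
  ReshapeParams toco_params = {3, {1, 2, 3}};
  std::printf("old TOCO model -> empty options? %d\n",
              UseEmptyReshapeOptions(/*num_inputs=*/2, toco_params));  // prints 0

  // Shape-tensor input present but no usable new_shape attribute.
  ReshapeParams no_attribute{};
  std::printf("no attribute   -> empty options? %d\n",
              UseEmptyReshapeOptions(/*num_inputs=*/2, no_attribute));  // prints 1
  return 0;
}

The two calls in main() mirror the two model flavors the writer now distinguishes: the old TOCO output keeps its new_shape attribute, while a model without a usable attribute still gets empty options plus the shape-tensor input.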
@@ -311,6 +311,7 @@ INSTANTIATE_TEST_SUITE_P(Writer, SingleSubgraphTest, ::testing::Bool());
 struct ReshapeTestPattern {
   int num_inputs;
   bool is_param_valid;
+  bool has_buggy_non_flatten_shape;
 };
 
 class ReshapeLayerTest : public ::testing::TestWithParam<ReshapeTestPattern> {};
@@ -326,10 +327,19 @@ TEST_P(ReshapeLayerTest, ReshapeLayerTest) {
                                            TfLiteQuantization());
   ASSERT_LE(param.num_inputs, 2);
   if (param.num_inputs == 2) {
-    interpreter.SetTensorParametersReadOnly(
-        /*tensor_index=*/1, kTfLiteInt32, /*name=*/"b", /*dims=*/{3},
-        TfLiteQuantization(), reinterpret_cast<char*>(output_shape),
-        sizeof(output_shape));
+    // Some TOCO generated models have buggy shape arguments, which are required
+    // to be flatten, for example, dims={3, 1} instead of dims={3}.
+    if (param.has_buggy_non_flatten_shape) {
+      interpreter.SetTensorParametersReadOnly(
+          /*tensor_index=*/1, kTfLiteInt32, /*name=*/"b", /*dims=*/{3, 1},
+          TfLiteQuantization(), reinterpret_cast<char*>(output_shape),
+          sizeof(output_shape));
+    } else {
+      interpreter.SetTensorParametersReadOnly(
+          /*tensor_index=*/1, kTfLiteInt32, /*name=*/"b", /*dims=*/{3},
+          TfLiteQuantization(), reinterpret_cast<char*>(output_shape),
+          sizeof(output_shape));
+    }
   }
   interpreter.SetTensorParametersReadWrite(/*tensor_index=*/total_tensors - 1,
                                            kTfLiteFloat32, /*name=*/"c",
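The new test branch registers the shape tensor with dims={3, 1} instead of dims={3}, mimicking the buggy non-flattened output of old TOCO. Sketched below with a hypothetical stand-in type (not TFLite code): both layouts carry an identical three-element int32 payload and differ only in the recorded rank, which is why the writer has to trust the valid new_shape attribute rather than the rank of the shape input:

// Illustrative sketch only; FakeShapeTensor is a hypothetical stand-in for a
// TfLite int32 tensor that feeds a Reshape op as its second input.
#include <cstdio>
#include <vector>

struct FakeShapeTensor {
  std::vector<int> dims;    // tensor shape metadata
  std::vector<int> values;  // raw buffer contents (the target shape)
};

int main() {
  FakeShapeTensor flattened{{3}, {1, 2, 3}};  // well-formed: 1-D shape input
  FakeShapeTensor buggy{{3, 1}, {1, 2, 3}};   // old TOCO output: 2-D shape input

  // Prints rank 1 vs. rank 2 with an identical payload.
  std::printf("flattened rank=%zu, buggy rank=%zu, same payload=%d\n",
              flattened.dims.size(), buggy.dims.size(),
              flattened.values == buggy.values);
  return 0;
}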
@@ -373,15 +383,22 @@ TEST_P(ReshapeLayerTest, ReshapeLayerTest) {
 INSTANTIATE_TEST_SUITE_P(
     Writer, ReshapeLayerTest,
     ::testing::Values(ReshapeTestPattern{/*num_inputs=*/2,
-                                         /*is_param_valid=*/true},
+                                         /*is_param_valid=*/true,
+                                         /*has_buggy_non_flatten_shape=*/false},
                       ReshapeTestPattern{/*num_inputs=*/2,
-                                         /*is_param_valid=*/false},
+                                         /*is_param_valid=*/false,
+                                         /*has_buggy_non_flatten_shape=*/false},
                       ReshapeTestPattern{/*num_inputs=*/1,
-                                         /*is_param_valid=*/true}),
+                                         /*is_param_valid=*/true,
+                                         /*has_buggy_non_flatten_shape=*/false},
+                      ReshapeTestPattern{/*num_inputs=*/2,
+                                         /*is_param_valid=*/true,
+                                         /*has_buggy_non_flatten_shape=*/true}),
     [](const ::testing::TestParamInfo<ReshapeLayerTest::ParamType>& info) {
       std::stringstream ss;
       ss << "num_inputs_" << info.param.num_inputs << "_valid_param_"
-         << info.param.is_param_valid;
+         << info.param.is_param_valid << "_buggy_shape_"
+         << info.param.has_buggy_non_flatten_shape;
       std::string name = ss.str();
       return name;
     });
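For reference, the extended naming lambda turns each ReshapeTestPattern into a readable test-name suffix. A small standalone reproduction of that formatting (the struct and the parameter values mirror the instantiation above; the rest is scaffolding for illustration):

#include <iostream>
#include <sstream>
#include <string>
#include <vector>

struct ReshapeTestPattern {
  int num_inputs;
  bool is_param_valid;
  bool has_buggy_non_flatten_shape;
};

// Mirrors the name-generation lambda passed to INSTANTIATE_TEST_SUITE_P.
std::string ParamName(const ReshapeTestPattern& param) {
  std::stringstream ss;
  ss << "num_inputs_" << param.num_inputs << "_valid_param_"
     << param.is_param_valid << "_buggy_shape_"
     << param.has_buggy_non_flatten_shape;
  return ss.str();
}

int main() {
  std::vector<ReshapeTestPattern> params = {
      {2, true, false}, {2, false, false}, {1, true, false}, {2, true, true}};
  for (const auto& p : params) {
    std::cout << ParamName(p) << "\n";
  }
  return 0;
}

Running it prints names such as num_inputs_2_valid_param_1_buggy_shape_0 for the original cases and num_inputs_2_valid_param_1_buggy_shape_1 for the new buggy-shape case.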