Support old TOCO generated reshape operator cases where the dimension of the
given shape input is non 1-D In such cases, the TOCO converted models always have the correct value at the new_shape attribute so we need to copy the new_shape attribute in order to correctly rewrite. This change closes #46423 PiperOrigin-RevId: 351920426 Change-Id: I4f7bea1ad6f59485d4011cf2a52c7bce989c52d7
This commit is contained in:
parent
5b1e9b0e28
commit
f12082d7af
tensorflow/lite/tools/serialization
@ -274,14 +274,17 @@ void GenerateImportForResizeBilinearOp(FILE* fp) {
|
||||
// Reshape Op infers output shape either from Parameter or from shape tensor
|
||||
// that's is an additional input. When we have this additional shape tensor as
|
||||
// input we don't have the parameter present in this layer. In case of more than
|
||||
// one input we import an empty vector for the parameters.
|
||||
// one input and the shape parameter does not have a valid value, we import an
|
||||
// empty vector for the parameters.
|
||||
void GenerateImportForReshapeOp(FILE* fp) {
|
||||
fprintf(fp,
|
||||
" case BuiltinOperator_RESHAPE: {\n"
|
||||
" const auto* params = reinterpret_cast<const "
|
||||
"TfLiteReshapeParams*>(builtin_op_data);\n"
|
||||
" flatbuffers::Offset<void> union_type;\n"
|
||||
" if (node.inputs->size > 1) {\n"
|
||||
" if (node.inputs->size > 1 && (params->num_dimensions <= 0 || "
|
||||
"params->num_dimensions > TFLITE_RESHAPE_PARAMS_MAX_DIMENSION_COUNT))"
|
||||
" {\n"
|
||||
" union_type = CreateReshapeOptions(*fbb).Union();\n"
|
||||
" } else {\n"
|
||||
" auto val0 = fbb->CreateVector(std::vector<int>(params->shape, "
|
||||
|
@ -311,6 +311,7 @@ INSTANTIATE_TEST_SUITE_P(Writer, SingleSubgraphTest, ::testing::Bool());
|
||||
struct ReshapeTestPattern {
|
||||
int num_inputs;
|
||||
bool is_param_valid;
|
||||
bool has_buggy_non_flatten_shape;
|
||||
};
|
||||
|
||||
class ReshapeLayerTest : public ::testing::TestWithParam<ReshapeTestPattern> {};
|
||||
@ -326,10 +327,19 @@ TEST_P(ReshapeLayerTest, ReshapeLayerTest) {
|
||||
TfLiteQuantization());
|
||||
ASSERT_LE(param.num_inputs, 2);
|
||||
if (param.num_inputs == 2) {
|
||||
interpreter.SetTensorParametersReadOnly(
|
||||
/*tensor_index=*/1, kTfLiteInt32, /*name=*/"b", /*dims=*/{3},
|
||||
TfLiteQuantization(), reinterpret_cast<char*>(output_shape),
|
||||
sizeof(output_shape));
|
||||
// Some TOCO generated models have buggy shape arguments, which are required
|
||||
// to be flatten, for example, dims={3, 1} instead of dims={3}.
|
||||
if (param.has_buggy_non_flatten_shape) {
|
||||
interpreter.SetTensorParametersReadOnly(
|
||||
/*tensor_index=*/1, kTfLiteInt32, /*name=*/"b", /*dims=*/{3, 1},
|
||||
TfLiteQuantization(), reinterpret_cast<char*>(output_shape),
|
||||
sizeof(output_shape));
|
||||
} else {
|
||||
interpreter.SetTensorParametersReadOnly(
|
||||
/*tensor_index=*/1, kTfLiteInt32, /*name=*/"b", /*dims=*/{3},
|
||||
TfLiteQuantization(), reinterpret_cast<char*>(output_shape),
|
||||
sizeof(output_shape));
|
||||
}
|
||||
}
|
||||
interpreter.SetTensorParametersReadWrite(/*tensor_index=*/total_tensors - 1,
|
||||
kTfLiteFloat32, /*name=*/"c",
|
||||
@ -373,15 +383,22 @@ TEST_P(ReshapeLayerTest, ReshapeLayerTest) {
|
||||
INSTANTIATE_TEST_SUITE_P(
|
||||
Writer, ReshapeLayerTest,
|
||||
::testing::Values(ReshapeTestPattern{/*num_inputs=*/2,
|
||||
/*is_param_valid=*/true},
|
||||
/*is_param_valid=*/true,
|
||||
/*has_buggy_non_flatten_shape=*/false},
|
||||
ReshapeTestPattern{/*num_inputs=*/2,
|
||||
/*is_param_valid=*/false},
|
||||
/*is_param_valid=*/false,
|
||||
/*has_buggy_non_flatten_shape=*/false},
|
||||
ReshapeTestPattern{/*num_inputs=*/1,
|
||||
/*is_param_valid=*/true}),
|
||||
/*is_param_valid=*/true,
|
||||
/*has_buggy_non_flatten_shape=*/false},
|
||||
ReshapeTestPattern{/*num_inputs=*/2,
|
||||
/*is_param_valid=*/true,
|
||||
/*has_buggy_non_flatten_shape=*/true}),
|
||||
[](const ::testing::TestParamInfo<ReshapeLayerTest::ParamType>& info) {
|
||||
std::stringstream ss;
|
||||
ss << "num_inputs_" << info.param.num_inputs << "_valid_param_"
|
||||
<< info.param.is_param_valid;
|
||||
<< info.param.is_param_valid << "_buggy_shape_"
|
||||
<< info.param.has_buggy_non_flatten_shape;
|
||||
std::string name = ss.str();
|
||||
return name;
|
||||
});
|
||||
|
Loading…
Reference in New Issue
Block a user