Fix the output type in modify_model_inference.

PiperOrigin-RevId: 307155237
Change-Id: I5020db2ee599aea9afcae35114e2a9a6fafffe75
This commit is contained in:
Jian Li 2020-04-17 19:30:10 -07:00 committed by TensorFlower Gardener
parent 5fcc70306d
commit b8651b5dd2
2 changed files with 23 additions and 1 deletions

View File

@ -295,7 +295,7 @@ TfLiteStatus ModifyModelInterface(flatbuffers::FlatBufferBuilder* builder,
GetOutputTensors(model, &error_reporter);
if (output_type == TensorType_UINT8) {
SetOutputTypeToUINT8(model, outputs);
} else if (input_type == TensorType_INT8) {
} else if (output_type == TensorType_INT8) {
RemoveOutputTensor(model, outputs);
} else {
return kTfLiteError;

View File

@ -351,6 +351,28 @@ TEST(ModelInterface, Int8SingleInputOutput) {
EXPECT_EQ(model->subgraphs[0]->outputs[0], 1);
}
TEST(ModelInterface, MixedTypeSingleInputOutput) {
auto model = CreateModelSingleInputOutput();
// Change model type.
flatbuffers::FlatBufferBuilder builder;
EXPECT_EQ(ModifyModelInterface(&builder, model.get(), TensorType_UINT8,
TensorType_INT8),
kTfLiteOk);
// Verify results.
EXPECT_EQ(model->operator_codes.size(), 3);
EXPECT_EQ(model->subgraphs.size(), 1);
EXPECT_EQ(model->subgraphs[0]->operators.size(), 2);
EXPECT_EQ(model->subgraphs[0]->tensors.size(), 3);
EXPECT_EQ(model->buffers.size(), 1);
EXPECT_EQ(model->subgraphs[0]->inputs.size(), 1);
EXPECT_EQ(model->subgraphs[0]->inputs[0], 2);
EXPECT_EQ(model->subgraphs[0]->outputs.size(), 1);
EXPECT_EQ(model->subgraphs[0]->outputs[0], 1);
}
TEST(ModelInterface, Uint8MutipleInputOutput) {
auto model = CreateModelMultipleInputOutput();