diff --git a/tensorflow/lite/tools/optimize/modify_model_interface.cc b/tensorflow/lite/tools/optimize/modify_model_interface.cc index bc1e9cbe5a3..d173bb608aa 100644 --- a/tensorflow/lite/tools/optimize/modify_model_interface.cc +++ b/tensorflow/lite/tools/optimize/modify_model_interface.cc @@ -295,7 +295,7 @@ TfLiteStatus ModifyModelInterface(flatbuffers::FlatBufferBuilder* builder, GetOutputTensors(model, &error_reporter); if (output_type == TensorType_UINT8) { SetOutputTypeToUINT8(model, outputs); - } else if (input_type == TensorType_INT8) { + } else if (output_type == TensorType_INT8) { RemoveOutputTensor(model, outputs); } else { return kTfLiteError; diff --git a/tensorflow/lite/tools/optimize/modify_model_interface_test.cc b/tensorflow/lite/tools/optimize/modify_model_interface_test.cc index 26fe4fa6331..5a04f28f638 100644 --- a/tensorflow/lite/tools/optimize/modify_model_interface_test.cc +++ b/tensorflow/lite/tools/optimize/modify_model_interface_test.cc @@ -351,6 +351,28 @@ TEST(ModelInterface, Int8SingleInputOutput) { EXPECT_EQ(model->subgraphs[0]->outputs[0], 1); } +TEST(ModelInterface, MixedTypeSingleInputOutput) { + auto model = CreateModelSingleInputOutput(); + + // Change model type. + flatbuffers::FlatBufferBuilder builder; + EXPECT_EQ(ModifyModelInterface(&builder, model.get(), TensorType_UINT8, + TensorType_INT8), + kTfLiteOk); + + // Verify results. + EXPECT_EQ(model->operator_codes.size(), 3); + EXPECT_EQ(model->subgraphs.size(), 1); + EXPECT_EQ(model->subgraphs[0]->operators.size(), 2); + EXPECT_EQ(model->subgraphs[0]->tensors.size(), 3); + EXPECT_EQ(model->buffers.size(), 1); + + EXPECT_EQ(model->subgraphs[0]->inputs.size(), 1); + EXPECT_EQ(model->subgraphs[0]->inputs[0], 2); + EXPECT_EQ(model->subgraphs[0]->outputs.size(), 1); + EXPECT_EQ(model->subgraphs[0]->outputs[0], 1); +} + TEST(ModelInterface, Uint8MutipleInputOutput) { auto model = CreateModelMultipleInputOutput();