diff --git a/tensorflow/lite/tools/optimize/modify_model_interface.cc b/tensorflow/lite/tools/optimize/modify_model_interface.cc
index 91c9b7e8b74..d40d4455e24 100644
--- a/tensorflow/lite/tools/optimize/modify_model_interface.cc
+++ b/tensorflow/lite/tools/optimize/modify_model_interface.cc
@@ -45,7 +45,8 @@ struct TensorOpTensor {
 
 // Finds float tensors that are model inputs and is consumed by a quantize Op.
 // The returned TensorOpTensor should have reverse order.
-std::vector<TensorOpTensor> GetInputTensors(ModelT* model,
+std::vector<TensorOpTensor> GetInputTensors(const TensorType& input_type,
+                                            ModelT* model,
                                             ErrorReporter* error_reporter) {
   std::vector<TensorOpTensor> result;
   // Get all input tensors.
@@ -71,7 +72,7 @@ std::vector<TensorOpTensor> GetInputTensors(ModelT* model,
       continue;
     }
     if (op_code != BuiltinOperator_QUANTIZE) {
-      // Current only support INT8 quantized models.
+      // Currently only supports int8 and int16 quantized models.
       TF_LITE_REPORT_ERROR(
          error_reporter,
          "modify_model_interface called on a model without quant/dequant.");
@@ -85,10 +86,27 @@ std::vector<TensorOpTensor> GetInputTensors(ModelT* model,
     }
     const int model_input_index = input_tensors[input_tensor];
     TensorT* quant_output = subgraph->tensors[op->outputs[0]].get();
-    if (quant_output->type != TensorType_INT8) {
+    if (quant_output->type != TensorType_INT8 &&
+        quant_output->type != TensorType_INT16) {
       TF_LITE_REPORT_ERROR(error_reporter,
-                           "modify_model_interface currently only support "
-                           "int8 quantized models.");
+                           "modify_model_interface currently only supports "
+                           "int8 and int16 quantized models.");
+    }
+
+    // The input type must be the same as the model quantization type
+    if (input_type != quant_output->type) {
+      // An exception, allow for UINT8 input type for INT8 quantized model.
+      if (!(input_type == TensorType_UINT8 &&
+            quant_output->type == TensorType_INT8)) {
+        TF_LITE_REPORT_ERROR(
+            error_reporter,
+            "The %s input type is incompatible with %s quantized models. "
+            "To resolve this error, change the input_type to a compatible "
+            "one. "
+            "See: modify_model_interface.cc",
+            EnumNameTensorType(input_type),
+            EnumNameTensorType(quant_output->type));
+      }
     }
     if (quant_output->quantization == nullptr) {
       continue;
@@ -102,7 +120,8 @@ std::vector<TensorOpTensor> GetInputTensors(ModelT* model,
 
 // Finds float tensors that are model output and is consumed by a dequantize Op.
 // The returned TensorOpTensor should have reverse order.
-std::vector<TensorOpTensor> GetOutputTensors(ModelT* model,
+std::vector<TensorOpTensor> GetOutputTensors(const TensorType& output_type,
+                                             ModelT* model,
                                              ErrorReporter* error_reporter) {
   std::vector<TensorOpTensor> result;
   // Get all output tensors.
@@ -128,7 +147,7 @@ std::vector<TensorOpTensor> GetOutputTensors(ModelT* model,
       continue;
     }
     if (op_code != BuiltinOperator_DEQUANTIZE) {
-      // Current only support INT8 quantized models.
+      // Currently only supports int8 and int16 quantized models.
       TF_LITE_REPORT_ERROR(
          error_reporter,
          "modify_model_interface called on a model without quant/dequant.");
@@ -142,13 +161,28 @@ std::vector<TensorOpTensor> GetOutputTensors(ModelT* model,
     }
     const int model_output_index = output_tensors[output_tensor];
     TensorT* dequant_input = subgraph->tensors[op->inputs[0]].get();
-    if (dequant_input->type != TensorType_INT8) {
-      // Current only support INT8 quantized models.
+    if (dequant_input->type != TensorType_INT8 &&
+        dequant_input->type != TensorType_INT16) {
+      // Currently only supports int8 and int16 quantized models.
       TF_LITE_REPORT_ERROR(error_reporter,
-                           "modify_model_interface currently only support "
-                           "int8 quantized models.");
+                           "modify_model_interface currently only supports "
+                           "int8 and int16 quantized models.");
       return {};
     }
+    if (output_type != dequant_input->type) {
+      // An exception, allow for UINT8 input type for INT8 quantized model.
+      if (!(output_type == TensorType_UINT8 &&
+            dequant_input->type == TensorType_INT8)) {
+        TF_LITE_REPORT_ERROR(
+            error_reporter,
+            "The %s output type is incompatible with %s quantized models. "
+            "To resolve this error, change the output_type to a compatible "
+            "one. "
+            "See: modify_model_interface.cc",
+            EnumNameTensorType(output_type),
+            EnumNameTensorType(dequant_input->type));
+      }
+    }
     if (dequant_input->quantization == nullptr) {
       continue;
     }
@@ -288,9 +322,13 @@ std::unique_ptr<ModelT> CreateMutableModelFromFile(
   return copied_model;
 }
 
-int GetOriginalNumberOfTensors(ModelT* model, ErrorReporter* error_reporter) {
-  std::vector<TensorOpTensor> outputs = GetOutputTensors(model, error_reporter);
-  std::vector<TensorOpTensor> inputs = GetInputTensors(model, error_reporter);
+int GetOriginalNumberOfTensors(const TensorType& input_type,
+                               const TensorType& output_type, ModelT* model,
+                               ErrorReporter* error_reporter) {
+  std::vector<TensorOpTensor> outputs =
+      GetOutputTensors(output_type, model, error_reporter);
+  std::vector<TensorOpTensor> inputs =
+      GetInputTensors(input_type, model, error_reporter);
   return model->subgraphs[0]->tensors.size() - outputs.size() - inputs.size();
 }
 
@@ -300,30 +338,39 @@ TfLiteStatus ModifyModelInterface(flatbuffers::FlatBufferBuilder* builder,
                                   ModelT* model, const TensorType& input_type,
                                   const TensorType& output_type) {
   tflite::StderrReporter error_reporter;
-  const int original_number_tensors =
-      GetOriginalNumberOfTensors(model, &error_reporter);
-  // Find float tensors that are model output and is consumed by a float to int8
-  // quantize Op.
-  // Do output first since the tensors are added into input first.,
+  const int original_number_tensors = GetOriginalNumberOfTensors(
+      input_type, output_type, model, &error_reporter);
+  // Finds float tensors that are model output and are consumed by a float to
+  // int8/int16 quantize Op. Do output first since the tensors are added into
+  // input first.,
   std::vector<TensorOpTensor> outputs =
-      GetOutputTensors(model, &error_reporter);
-  if (output_type == TensorType_UINT8) {
-    SetOutputTypeToUINT8(model, outputs);
-  } else if (output_type == TensorType_INT8) {
-    RemoveOutputTensor(model, outputs, original_number_tensors);
-  } else {
-    return kTfLiteError;
+      GetOutputTensors(output_type, model, &error_reporter);
+  switch (output_type) {
+    case TensorType_UINT8:
+      SetOutputTypeToUINT8(model, outputs);
+      break;
+    case TensorType_INT8:
+    case TensorType_INT16:
+      RemoveOutputTensor(model, outputs, original_number_tensors);
+      break;
+    default:
+      return kTfLiteError;
   }
-  // Find float tensors that are model input and is consumed by a float to int8
-  // quantize Op.
-  std::vector<TensorOpTensor> inputs = GetInputTensors(model, &error_reporter);
-  if (input_type == TensorType_UINT8) {
-    SetInputTypeToUINT8(model, inputs);
-  } else if (input_type == TensorType_INT8) {
-    RemoveInputTensor(model, inputs, original_number_tensors);
-  } else {
-    return kTfLiteError;
+  // Find float tensors that are model input and is consumed by a float to
+  // int8/int16 quantize Op.
+  std::vector<TensorOpTensor> inputs =
+      GetInputTensors(input_type, model, &error_reporter);
+  switch (input_type) {
+    case TensorType_UINT8:
+      SetInputTypeToUINT8(model, inputs);
+      break;
+    case TensorType_INT8:
+    case TensorType_INT16:
+      RemoveInputTensor(model, inputs, original_number_tensors);
+      break;
+    default:
+      return kTfLiteError;
   }
 
   // Write to builder.
@@ -340,11 +387,13 @@ TfLiteStatus ModifyModelInterface(const string& input_file,
                                   const TensorType& output_type) {
   // Consistency Check
   if (input_type != tflite::TensorType_INT8 &&
-      input_type != tflite::TensorType_UINT8) {
+      input_type != tflite::TensorType_UINT8 &&
+      input_type != tflite::TensorType_INT16) {
    return kTfLiteError;
  }
  if (output_type != tflite::TensorType_INT8 &&
-      output_type != tflite::TensorType_UINT8) {
+      output_type != tflite::TensorType_UINT8 &&
+      output_type != tflite::TensorType_INT16) {
    return kTfLiteError;
  }
 
@@ -357,17 +406,8 @@ TfLiteStatus ModifyModelInterface(const string& input_file,
       absl::make_unique<flatbuffers::FlatBufferBuilder>();
   flatbuffers::FlatBufferBuilder builder;
 
-  tflite::TensorType input_override_type = tflite::TensorType_INT8;
-  if (input_type == tflite::TensorType_UINT8) {
-    input_override_type = tflite::TensorType_UINT8;
-  }
-  tflite::TensorType output_override_type = tflite::TensorType_INT8;
-  if (output_type == tflite::TensorType_UINT8) {
-    output_override_type = tflite::TensorType_UINT8;
-  }
-
-  auto status = ModifyModelInterface(&builder, tflite_model.get(),
-                                     input_override_type, output_override_type);
+  auto status = ModifyModelInterface(&builder, tflite_model.get(), input_type,
+                                     output_type);
   TFLITE_DCHECK_EQ(status, kTfLiteOk);
 
   WriteFile(output_file, builder.GetBufferPointer(), builder.GetSize());
diff --git a/tensorflow/lite/tools/optimize/modify_model_interface.h b/tensorflow/lite/tools/optimize/modify_model_interface.h
index 170e0e73a67..5711a615812 100644
--- a/tensorflow/lite/tools/optimize/modify_model_interface.h
+++ b/tensorflow/lite/tools/optimize/modify_model_interface.h
@@ -24,7 +24,7 @@ namespace optimize {
 // Changes the interface of a quantized model. This method allows the users to
 // replace float interface with other types.
 // This populates the builder with the new model.
-// Currently only int8 and unit8 are supported.
+// Currently only int8, int16 and uint8 are supported.
 //
 // Note: This is a private API, subject to change.
 TfLiteStatus ModifyModelInterface(flatbuffers::FlatBufferBuilder* builder,
diff --git a/tensorflow/lite/tools/optimize/modify_model_interface_main.cc b/tensorflow/lite/tools/optimize/modify_model_interface_main.cc
index 24674a1b341..940c0d98b82 100644
--- a/tensorflow/lite/tools/optimize/modify_model_interface_main.cc
+++ b/tensorflow/lite/tools/optimize/modify_model_interface_main.cc
@@ -25,24 +25,20 @@ int main(int argc, char** argv) {
     return 1;
   }
 
-  if (strcmp(argv[3], "uint8") && strcmp(argv[3], "int8")) {
-    printf("Only support uint8 and int8 for input interface");
-    return 1;
-  }
-
-  if (strcmp(argv[4], "uint8") && strcmp(argv[4], "int8")) {
-    printf("Only support uint8 and int8 for output interface");
-    return 1;
-  }
+  const std::unordered_map<std::string, tflite::TensorType> supported_types{
+      {"uint8", tflite::TensorType_UINT8},
+      {"int8", tflite::TensorType_INT8},
+      {"int16", tflite::TensorType_INT16}};
 
   tflite::TensorType input = tflite::TensorType_INT8;
   tflite::TensorType output = tflite::TensorType_INT8;
-  if (!strcmp(argv[3], "uint8")) {
-    input = tflite::TensorType_UINT8;
-  }
-  if (!strcmp(argv[4], "uint8")) {
-    output = tflite::TensorType_UINT8;
+  try {
+    input = supported_types.at(argv[3]);
+    output = supported_types.at(argv[4]);
+  } catch (const std::out_of_range&) {
+    printf("Only supports uint8, int8 and int16 for input and output types");
+    return 1;
   }
 
   tflite::optimize::ModifyModelInterface(argv[1], argv[2], input, output);
diff --git a/tensorflow/lite/tools/optimize/modify_model_interface_test.cc b/tensorflow/lite/tools/optimize/modify_model_interface_test.cc
index 55147cec1ec..99e0ad35b2d 100644
--- a/tensorflow/lite/tools/optimize/modify_model_interface_test.cc
+++ b/tensorflow/lite/tools/optimize/modify_model_interface_test.cc
@@ -26,8 +26,11 @@ namespace tflite {
 namespace optimize {
 namespace {
 
-// Create a quantized model with 1 quant, 1 FC, 1 dequant
-std::unique_ptr<ModelT> CreateQuantizedModelSingleInputOutput() {
+using ::testing::ElementsAreArray;
+
+// Create a model with 1 quant, 1 FC, 1 dequant
+std::unique_ptr<ModelT> CreateQuantizedModelSingleInputOutput(
+    const TensorType& quantization_type) {
   auto model = absl::make_unique<ModelT>();
   auto subgraph = absl::make_unique<SubGraphT>();
   auto buffer = absl::make_unique<BufferT>();
@@ -87,7 +90,7 @@ std::unique_ptr<ModelT> CreateQuantizedModelSingleInputOutput() {
   tensor_1->quantization->zero_point.push_back(28);
   tensor_1->name = "tensor_1";
   tensor_1->shape = {};
-  tensor_1->type = TensorType_INT8;
+  tensor_1->type = quantization_type;
 
   auto tensor_2 = absl::make_unique<TensorT>();
   tensor_2->quantization = absl::make_unique<QuantizationParametersT>();
@@ -95,7 +98,7 @@ std::unique_ptr<ModelT> CreateQuantizedModelSingleInputOutput() {
   tensor_2->quantization->zero_point.push_back(50);
   tensor_2->name = "tensor_2";
   tensor_2->shape = {};
-  tensor_2->type = TensorType_INT8;
+  tensor_2->type = quantization_type;
 
   auto tensor_3 = absl::make_unique<TensorT>();
   tensor_3->name = "tensor_3";
@@ -113,8 +116,10 @@ std::unique_ptr<ModelT> CreateQuantizedModelSingleInputOutput() {
   return model;
 }
 
-// Create a quantized model with 2 quant, 1 FC, 2 dequant
-std::unique_ptr<ModelT> CreateQuantizedModelMultipleInputOutput() {
+// Create a model with 2 quant, 1 FC, 2 dequant
+// The model mimics the behavior of the quantize_model.cc.
+std::unique_ptr<ModelT> CreateQuantizedModelMultipleInputOutput(
+    const TensorType& quantization_type) {
   auto model = absl::make_unique<ModelT>();
   auto subgraph = absl::make_unique<SubGraphT>();
   auto buffer = absl::make_unique<BufferT>();
@@ -189,7 +194,7 @@ std::unique_ptr<ModelT> CreateQuantizedModelMultipleInputOutput() {
   tensor_2->quantization->zero_point.push_back(28);
   tensor_2->name = "tensor_2";
   tensor_2->shape = {};
-  tensor_2->type = TensorType_INT8;
+  tensor_2->type = quantization_type;
 
   auto tensor_3 = absl::make_unique<TensorT>();
   tensor_3->quantization = absl::make_unique<QuantizationParametersT>();
@@ -197,7 +202,7 @@ std::unique_ptr<ModelT> CreateQuantizedModelMultipleInputOutput() {
   tensor_3->quantization->zero_point.push_back(50);
   tensor_3->name = "tensor_3";
   tensor_3->shape = {};
-  tensor_3->type = TensorType_INT8;
+  tensor_3->type = quantization_type;
 
   auto tensor_4 = absl::make_unique<TensorT>();
   tensor_4->quantization = absl::make_unique<QuantizationParametersT>();
@@ -205,7 +210,7 @@ std::unique_ptr<ModelT> CreateQuantizedModelMultipleInputOutput() {
   tensor_4->quantization->zero_point.push_back(28);
   tensor_4->name = "tensor_4";
   tensor_4->shape = {};
-  tensor_4->type = TensorType_INT8;
+  tensor_4->type = quantization_type;
 
   auto tensor_5 = absl::make_unique<TensorT>();
   tensor_5->quantization = absl::make_unique<QuantizationParametersT>();
@@ -213,7 +218,7 @@ std::unique_ptr<ModelT> CreateQuantizedModelMultipleInputOutput() {
   tensor_5->quantization->zero_point.push_back(50);
   tensor_5->name = "tensor_5";
   tensor_5->shape = {};
-  tensor_5->type = TensorType_INT8;
+  tensor_5->type = quantization_type;
 
   auto tensor_6 = absl::make_unique<TensorT>();
   tensor_6->name = "tensor_6";
@@ -286,8 +291,141 @@ std::unique_ptr<ModelT> CreateFloatModel() {
   return model;
 }
 
+struct ModelInterface : ::testing::TestWithParam<TensorType> {};
+
+TEST_P(ModelInterface, SingleInputOutput) {
+  TensorType quantization_type = GetParam();
+
+  auto model = CreateQuantizedModelSingleInputOutput(quantization_type);
+
+  // Change model type.
+  flatbuffers::FlatBufferBuilder builder;
+  EXPECT_EQ(ModifyModelInterface(&builder, model.get(), quantization_type,
+                                 quantization_type),
+            kTfLiteOk);
+
+  // Verify results.
+  EXPECT_EQ(model->subgraphs.size(), 1);
+  // TODO(mnatraj): The float input tensor has not been removed.
+  // EXPECT_EQ(model->subgraphs[0]->tensors.size(), 2);
+  EXPECT_EQ(model->subgraphs[0]->tensors.size(), 3);
+  EXPECT_EQ(model->subgraphs[0]->inputs.size(), 1);
+  EXPECT_EQ(model->subgraphs[0]->inputs[0], 1);
+  EXPECT_EQ(model->subgraphs[0]->outputs.size(), 1);
+  EXPECT_EQ(model->subgraphs[0]->outputs[0], 2);
+  EXPECT_EQ(model->operator_codes.size(), 3);
+  EXPECT_EQ(model->subgraphs[0]->operators.size(), 1);
+  EXPECT_EQ(model->subgraphs[0]->operators[0]->opcode_index, 1);
+
+  auto fc_op = model->subgraphs[0]->operators[0].get();
+
+  auto input = model->subgraphs[0]->tensors[fc_op->inputs[0]].get();
+  EXPECT_EQ(input->name, "tensor_1");
+  EXPECT_EQ(input->type, quantization_type);
+  EXPECT_FLOAT_EQ(input->quantization->scale[0], 0.35);
+  EXPECT_EQ(input->quantization->zero_point[0], 28);
+
+  auto output = model->subgraphs[0]->tensors[fc_op->outputs[0]].get();
+  EXPECT_EQ(output->name, "tensor_2");
+  EXPECT_EQ(output->type, quantization_type);
+  EXPECT_FLOAT_EQ(output->quantization->scale[0], 0.12);
+  EXPECT_EQ(output->quantization->zero_point[0], 50);
+}
+
+TEST_P(ModelInterface, MutipleInputOutput) {
+  TensorType quantization_type = GetParam();
+
+  auto model = CreateQuantizedModelMultipleInputOutput(quantization_type);
+
+  // Change model type.
+  flatbuffers::FlatBufferBuilder builder;
+  EXPECT_EQ(ModifyModelInterface(&builder, model.get(), quantization_type,
+                                 quantization_type),
+            kTfLiteOk);
+
+  // Verify results.
+  EXPECT_EQ(model->subgraphs.size(), 1);
+  // TODO (b/158254056): Remove unused inputs and outputs from tensor list
+  // EXPECT_EQ(model->subgraphs[0]->tensors.size(), 4);
+  EXPECT_EQ(model->subgraphs[0]->tensors.size(), 6);
+  EXPECT_EQ(model->subgraphs[0]->inputs.size(), 2);
+  EXPECT_EQ(model->subgraphs[0]->inputs[0], 2);
+  EXPECT_EQ(model->subgraphs[0]->inputs[1], 3);
+  EXPECT_EQ(model->subgraphs[0]->outputs.size(), 2);
+  EXPECT_EQ(model->subgraphs[0]->outputs[0], 4);
+  EXPECT_EQ(model->subgraphs[0]->outputs[1], 5);
+  EXPECT_EQ(model->operator_codes.size(), 3);
+  EXPECT_EQ(model->subgraphs[0]->operators.size(), 1);
+  EXPECT_EQ(model->subgraphs[0]->operators[0]->opcode_index, 1);
+
+  auto fc_op = model->subgraphs[0]->operators[0].get();
+
+  auto input_1 = model->subgraphs[0]->tensors[fc_op->inputs[0]].get();
+  EXPECT_EQ(input_1->name, "tensor_2");
+  EXPECT_EQ(input_1->type, quantization_type);
+  EXPECT_FLOAT_EQ(input_1->quantization->scale[0], 0.35);
+  EXPECT_EQ(input_1->quantization->zero_point[0], 28);
+
+  auto input_2 = model->subgraphs[0]->tensors[fc_op->inputs[1]].get();
+  EXPECT_EQ(input_2->name, "tensor_3");
+  EXPECT_EQ(input_2->type, quantization_type);
+  EXPECT_FLOAT_EQ(input_2->quantization->scale[0], 0.12);
+  EXPECT_EQ(input_2->quantization->zero_point[0], 50);
+
+  auto output_1 = model->subgraphs[0]->tensors[fc_op->outputs[0]].get();
+  EXPECT_EQ(output_1->name, "tensor_4");
+  EXPECT_EQ(output_1->type, quantization_type);
+  EXPECT_FLOAT_EQ(output_1->quantization->scale[0], 0.45);
+  EXPECT_EQ(output_1->quantization->zero_point[0], 28);
+
+  auto output_2 = model->subgraphs[0]->tensors[fc_op->outputs[1]].get();
+  EXPECT_EQ(output_2->name, "tensor_5");
+  EXPECT_EQ(output_2->type, quantization_type);
+  EXPECT_FLOAT_EQ(output_2->quantization->scale[0], 0.22);
+  EXPECT_EQ(output_2->quantization->zero_point[0], 50);
+}
+
+INSTANTIATE_TEST_SUITE_P(MultipleInputOutputTests, ModelInterface,
+                         ::testing::Values(TensorType_INT8, TensorType_INT16));
+
+TEST(ModelInterface, MixedTypeSingleInputOutput) {
+  auto model = CreateQuantizedModelSingleInputOutput(TensorType_INT8);
+
+  // Change model type.
+  flatbuffers::FlatBufferBuilder builder;
+  EXPECT_EQ(ModifyModelInterface(&builder, model.get(), TensorType_UINT8,
+                                 TensorType_INT8),
+            kTfLiteOk);
+
+  // Verify results.
+  EXPECT_EQ(model->subgraphs.size(), 1);
+  EXPECT_EQ(model->subgraphs[0]->tensors.size(), 3);
+  EXPECT_EQ(model->subgraphs[0]->inputs.size(), 1);
+  EXPECT_EQ(model->subgraphs[0]->inputs[0], 0);
+  EXPECT_EQ(model->subgraphs[0]->outputs.size(), 1);
+  EXPECT_EQ(model->subgraphs[0]->outputs[0], 2);
+  EXPECT_EQ(model->operator_codes.size(), 3);
+  EXPECT_EQ(model->subgraphs[0]->operators.size(), 2);
+  EXPECT_EQ(model->subgraphs[0]->operators[0]->opcode_index, 0);
+  EXPECT_EQ(model->subgraphs[0]->operators[1]->opcode_index, 1);
+
+  auto quant_op = model->subgraphs[0]->operators[0].get();
+  auto input = model->subgraphs[0]->tensors[quant_op->inputs[0]].get();
+  EXPECT_EQ(input->name, "tensor_0");
+  EXPECT_EQ(input->type, TensorType_UINT8);
+  EXPECT_FLOAT_EQ(input->quantization->scale[0], 0.35);
+  EXPECT_EQ(input->quantization->zero_point[0], 156);
+
+  auto fc_op = model->subgraphs[0]->operators[1].get();
+  auto output = model->subgraphs[0]->tensors[fc_op->outputs[0]].get();
+  EXPECT_EQ(output->name, "tensor_2");
+  EXPECT_EQ(output->type, TensorType_INT8);
+  EXPECT_FLOAT_EQ(output->quantization->scale[0], 0.12);
+  EXPECT_EQ(output->quantization->zero_point[0], 50);
+}
+
 TEST(ModelInterface, Uint8SingleInputOutput) {
-  auto model = CreateQuantizedModelSingleInputOutput();
+  auto model = CreateQuantizedModelSingleInputOutput(TensorType_INT8);
 
   // Change model type.
   flatbuffers::FlatBufferBuilder builder;
   EXPECT_EQ(ModifyModelInterface(&builder, model.get(), TensorType_UINT8,
@@ -323,81 +461,8 @@ TEST(ModelInterface, Uint8SingleInputOutput) {
   EXPECT_EQ(output->quantization->zero_point[0], 178);
 }
 
-TEST(ModelInterface, Int8SingleInputOutput) {
-  auto model = CreateQuantizedModelSingleInputOutput();
-
-  // Change model type.
-  flatbuffers::FlatBufferBuilder builder;
-  EXPECT_EQ(ModifyModelInterface(&builder, model.get(), TensorType_INT8,
-                                 TensorType_INT8),
-            kTfLiteOk);
-
-  // Verify results.
-  EXPECT_EQ(model->subgraphs.size(), 1);
-  // TODO(mnatraj): The float input tensor has not been removed.
-  // EXPECT_EQ(model->subgraphs[0]->tensors.size(), 2);
-  EXPECT_EQ(model->subgraphs[0]->tensors.size(), 3);
-  EXPECT_EQ(model->subgraphs[0]->inputs.size(), 1);
-  EXPECT_EQ(model->subgraphs[0]->inputs[0], 1);
-  EXPECT_EQ(model->subgraphs[0]->outputs.size(), 1);
-  EXPECT_EQ(model->subgraphs[0]->outputs[0], 2);
-  EXPECT_EQ(model->operator_codes.size(), 3);
-  EXPECT_EQ(model->subgraphs[0]->operators.size(), 1);
-  EXPECT_EQ(model->subgraphs[0]->operators[0]->opcode_index, 1);
-
-  auto fc_op = model->subgraphs[0]->operators[0].get();
-
-  auto input = model->subgraphs[0]->tensors[fc_op->inputs[0]].get();
-  EXPECT_EQ(input->name, "tensor_1");
-  EXPECT_EQ(input->type, TensorType_INT8);
-  EXPECT_FLOAT_EQ(input->quantization->scale[0], 0.35);
-  EXPECT_EQ(input->quantization->zero_point[0], 28);
-
-  auto output = model->subgraphs[0]->tensors[fc_op->outputs[0]].get();
-  EXPECT_EQ(output->name, "tensor_2");
-  EXPECT_EQ(output->type, TensorType_INT8);
-  EXPECT_FLOAT_EQ(output->quantization->scale[0], 0.12);
-  EXPECT_EQ(output->quantization->zero_point[0], 50);
-}
-
-TEST(ModelInterface, MixedTypeSingleInputOutput) {
-  auto model = CreateQuantizedModelSingleInputOutput();
-
-  // Change model type.
-  flatbuffers::FlatBufferBuilder builder;
-  EXPECT_EQ(ModifyModelInterface(&builder, model.get(), TensorType_UINT8,
-                                 TensorType_INT8),
-            kTfLiteOk);
-
-  // Verify results.
-  EXPECT_EQ(model->subgraphs.size(), 1);
-  EXPECT_EQ(model->subgraphs[0]->tensors.size(), 3);
-  EXPECT_EQ(model->subgraphs[0]->inputs.size(), 1);
-  EXPECT_EQ(model->subgraphs[0]->inputs[0], 0);
-  EXPECT_EQ(model->subgraphs[0]->outputs.size(), 1);
-  EXPECT_EQ(model->subgraphs[0]->outputs[0], 2);
-  EXPECT_EQ(model->operator_codes.size(), 3);
-  EXPECT_EQ(model->subgraphs[0]->operators.size(), 2);
-  EXPECT_EQ(model->subgraphs[0]->operators[0]->opcode_index, 0);
-  EXPECT_EQ(model->subgraphs[0]->operators[1]->opcode_index, 1);
-
-  auto quant_op = model->subgraphs[0]->operators[0].get();
-  auto input = model->subgraphs[0]->tensors[quant_op->inputs[0]].get();
-  EXPECT_EQ(input->name, "tensor_0");
-  EXPECT_EQ(input->type, TensorType_UINT8);
-  EXPECT_FLOAT_EQ(input->quantization->scale[0], 0.35);
-  EXPECT_EQ(input->quantization->zero_point[0], 156);
-
-  auto fc_op = model->subgraphs[0]->operators[1].get();
-  auto output = model->subgraphs[0]->tensors[fc_op->outputs[0]].get();
-  EXPECT_EQ(output->name, "tensor_2");
-  EXPECT_EQ(output->type, TensorType_INT8);
-  EXPECT_FLOAT_EQ(output->quantization->scale[0], 0.12);
-  EXPECT_EQ(output->quantization->zero_point[0], 50);
-}
-
 TEST(ModelInterface, Uint8MutipleInputOutput) {
-  auto model = CreateQuantizedModelMultipleInputOutput();
+  auto model = CreateQuantizedModelMultipleInputOutput(TensorType_INT8);
 
   // Change model type.
   flatbuffers::FlatBufferBuilder builder;
@@ -454,7 +519,7 @@ TEST(ModelInterface, Uint8MutipleInputOutput) {
 }
 
 TEST(ModelInterface, Int8MutipleInputOutput) {
-  auto model = CreateQuantizedModelMultipleInputOutput();
+  auto model = CreateQuantizedModelMultipleInputOutput(TensorType_INT8);
 
   // Change model type.
   flatbuffers::FlatBufferBuilder builder;