Merge pull request #40190 from TamasArm:16_bit_support_for_modify_model_interface

PiperOrigin-RevId: 320997124
Change-Id: I09242d15c14365d72d43dc1d8e911e96f12edf82
TensorFlower Gardener 2020-07-13 11:37:07 -07:00
commit 3fc7725cc2
4 changed files with 250 additions and 149 deletions
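In short, this change extends modify_model_interface so that the model's input and output interface can be set to int16 in addition to the existing int8 and uint8 options. As a rough illustration, here is a minimal sketch of the file-based entry point after this change, assuming an int16-quantized model and placeholder file paths (the include path is the assumed location of the tool's header):

#include "tensorflow/lite/tools/optimize/modify_model_interface.h"  // assumed path

int main() {
  // Rewrites the float input/output interface of an int16-quantized model to
  // int16. The file names are placeholders, not part of this change.
  TfLiteStatus status = tflite::optimize::ModifyModelInterface(
      "/tmp/model_int16_quant.tflite",      // existing quantized model
      "/tmp/model_int16_interface.tflite",  // rewritten model
      tflite::TensorType_INT16,             // new input interface type
      tflite::TensorType_INT16);            // new output interface type
  return status == kTfLiteOk ? 0 : 1;
}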


@@ -45,7 +45,8 @@ struct TensorOpTensor {
// Finds float tensors that are model inputs and is consumed by a quantize Op.
// The returned TensorOpTensor should have reverse order.
std::vector<TensorOpTensor> GetInputTensors(ModelT* model,
std::vector<TensorOpTensor> GetInputTensors(const TensorType& input_type,
ModelT* model,
ErrorReporter* error_reporter) {
std::vector<TensorOpTensor> result;
// Get all input tensors.
@@ -71,7 +72,7 @@ std::vector<TensorOpTensor> GetInputTensors(ModelT* model,
continue;
}
if (op_code != BuiltinOperator_QUANTIZE) {
// Current only support INT8 quantized models.
// Currently only supports int8 and int16 quantized models.
TF_LITE_REPORT_ERROR(
error_reporter,
"modify_model_interface called on a model without quant/dequant.");
@@ -85,10 +86,27 @@ std::vector<TensorOpTensor> GetInputTensors(ModelT* model,
}
const int model_input_index = input_tensors[input_tensor];
TensorT* quant_output = subgraph->tensors[op->outputs[0]].get();
if (quant_output->type != TensorType_INT8) {
if (quant_output->type != TensorType_INT8 &&
quant_output->type != TensorType_INT16) {
TF_LITE_REPORT_ERROR(error_reporter,
"modify_model_interface currently only support "
"int8 quantized models.");
"modify_model_interface currently only supports "
"int8 and int16 quantized models.");
}
// The input type must be the same as the model quantization type
if (input_type != quant_output->type) {
// One exception: a UINT8 input type is allowed for an INT8 quantized model.
if (!(input_type == TensorType_UINT8 &&
quant_output->type == TensorType_INT8)) {
TF_LITE_REPORT_ERROR(
error_reporter,
"The %s input type is incompatible with %s quantized models. "
"To resolve this error, change the input_type to a compatible "
"one. "
"See: modify_model_interface.cc",
EnumNameTensorType(input_type),
EnumNameTensorType(quant_output->type));
}
}
if (quant_output->quantization == nullptr) {
continue;
@@ -102,7 +120,8 @@ std::vector<TensorOpTensor> GetInputTensors(ModelT* model,
// Finds float tensors that are model output and is consumed by a dequantize Op.
// The returned TensorOpTensor should have reverse order.
std::vector<TensorOpTensor> GetOutputTensors(ModelT* model,
std::vector<TensorOpTensor> GetOutputTensors(const TensorType& output_type,
ModelT* model,
ErrorReporter* error_reporter) {
std::vector<TensorOpTensor> result;
// Get all output tensors.
@@ -128,7 +147,7 @@ std::vector<TensorOpTensor> GetOutputTensors(ModelT* model,
continue;
}
if (op_code != BuiltinOperator_DEQUANTIZE) {
// Current only support INT8 quantized models.
// Currently only supports int8 and int16 quantized models.
TF_LITE_REPORT_ERROR(
error_reporter,
"modify_model_interface called on a model without quant/dequant.");
@@ -142,13 +161,28 @@ std::vector<TensorOpTensor> GetOutputTensors(ModelT* model,
}
const int model_output_index = output_tensors[output_tensor];
TensorT* dequant_input = subgraph->tensors[op->inputs[0]].get();
if (dequant_input->type != TensorType_INT8) {
// Current only support INT8 quantized models.
if (dequant_input->type != TensorType_INT8 &&
dequant_input->type != TensorType_INT16) {
// Currently only supports int8 and int16 quantized models.
TF_LITE_REPORT_ERROR(error_reporter,
"modify_model_interface currently only support "
"int8 quantized models.");
"modify_model_interface currently only supports "
"int8 and int16 quantized models.");
return {};
}
if (output_type != dequant_input->type) {
// One exception: a UINT8 output type is allowed for an INT8 quantized model.
if (!(output_type == TensorType_UINT8 &&
dequant_input->type == TensorType_INT8)) {
TF_LITE_REPORT_ERROR(
error_reporter,
"The %s output type is incompatible with %s quantized models. "
"To resolve this error, change the output_type to a compatible "
"one. "
"See: modify_model_interface.cc",
EnumNameTensorType(output_type),
EnumNameTensorType(dequant_input->type));
}
}
if (dequant_input->quantization == nullptr) {
continue;
}
@@ -288,9 +322,13 @@ std::unique_ptr<tflite::ModelT> CreateMutableModelFromFile(
return copied_model;
}
int GetOriginalNumberOfTensors(ModelT* model, ErrorReporter* error_reporter) {
std::vector<TensorOpTensor> outputs = GetOutputTensors(model, error_reporter);
std::vector<TensorOpTensor> inputs = GetInputTensors(model, error_reporter);
int GetOriginalNumberOfTensors(const TensorType& input_type,
const TensorType& output_type, ModelT* model,
ErrorReporter* error_reporter) {
std::vector<TensorOpTensor> outputs =
GetOutputTensors(output_type, model, error_reporter);
std::vector<TensorOpTensor> inputs =
GetInputTensors(input_type, model, error_reporter);
return model->subgraphs[0]->tensors.size() - outputs.size() - inputs.size();
}
@@ -300,30 +338,39 @@ TfLiteStatus ModifyModelInterface(flatbuffers::FlatBufferBuilder* builder,
ModelT* model, const TensorType& input_type,
const TensorType& output_type) {
tflite::StderrReporter error_reporter;
const int original_number_tensors =
GetOriginalNumberOfTensors(model, &error_reporter);
// Find float tensors that are model output and is consumed by a float to int8
// quantize Op.
// Do output first since the tensors are added into input first.,
const int original_number_tensors = GetOriginalNumberOfTensors(
input_type, output_type, model, &error_reporter);
// Finds float tensors that are model outputs and are produced by an
// int8/int16 to float dequantize Op. Do outputs first since the tensors are
// added into inputs first.
std::vector<TensorOpTensor> outputs =
GetOutputTensors(model, &error_reporter);
if (output_type == TensorType_UINT8) {
SetOutputTypeToUINT8(model, outputs);
} else if (output_type == TensorType_INT8) {
RemoveOutputTensor(model, outputs, original_number_tensors);
} else {
return kTfLiteError;
GetOutputTensors(output_type, model, &error_reporter);
switch (output_type) {
case TensorType_UINT8:
SetOutputTypeToUINT8(model, outputs);
break;
case TensorType_INT8:
case TensorType_INT16:
RemoveOutputTensor(model, outputs, original_number_tensors);
break;
default:
return kTfLiteError;
}
// Find float tensors that are model input and is consumed by a float to int8
// quantize Op.
std::vector<TensorOpTensor> inputs = GetInputTensors(model, &error_reporter);
if (input_type == TensorType_UINT8) {
SetInputTypeToUINT8(model, inputs);
} else if (input_type == TensorType_INT8) {
RemoveInputTensor(model, inputs, original_number_tensors);
} else {
return kTfLiteError;
// Finds float tensors that are model inputs and are consumed by a float to
// int8/int16 quantize Op.
std::vector<TensorOpTensor> inputs =
GetInputTensors(input_type, model, &error_reporter);
switch (input_type) {
case TensorType_UINT8:
SetInputTypeToUINT8(model, inputs);
break;
case TensorType_INT8:
case TensorType_INT16:
RemoveInputTensor(model, inputs, original_number_tensors);
break;
default:
return kTfLiteError;
}
// Write to builder.
@@ -340,11 +387,13 @@ TfLiteStatus ModifyModelInterface(const string& input_file,
const TensorType& output_type) {
// Consistency Check
if (input_type != tflite::TensorType_INT8 &&
input_type != tflite::TensorType_UINT8) {
input_type != tflite::TensorType_UINT8 &&
input_type != tflite::TensorType_INT16) {
return kTfLiteError;
}
if (output_type != tflite::TensorType_INT8 &&
output_type != tflite::TensorType_UINT8) {
output_type != tflite::TensorType_UINT8 &&
output_type != tflite::TensorType_INT16) {
return kTfLiteError;
}
@@ -357,17 +406,8 @@ TfLiteStatus ModifyModelInterface(const string& input_file,
absl::make_unique<flatbuffers::FlatBufferBuilder>();
flatbuffers::FlatBufferBuilder builder;
tflite::TensorType input_override_type = tflite::TensorType_INT8;
if (input_type == tflite::TensorType_UINT8) {
input_override_type = tflite::TensorType_UINT8;
}
tflite::TensorType output_override_type = tflite::TensorType_INT8;
if (output_type == tflite::TensorType_UINT8) {
output_override_type = tflite::TensorType_UINT8;
}
auto status = ModifyModelInterface(&builder, tflite_model.get(),
input_override_type, output_override_type);
auto status = ModifyModelInterface(&builder, tflite_model.get(), input_type,
output_type);
TFLITE_DCHECK_EQ(status, kTfLiteOk);
WriteFile(output_file, builder.GetBufferPointer(), builder.GetSize());


@@ -24,7 +24,7 @@ namespace optimize {
// Changes the interface of a quantized model. This method allows the users to
// replace float interface with other types.
// This populates the builder with the new model.
// Currently only int8 and unit8 are supported.
// Currently only int8, int16 and uint8 are supported.
//
// Note: This is a private API, subject to change.
TfLiteStatus ModifyModelInterface(flatbuffers::FlatBufferBuilder* builder,
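The declaration above continues with the ModelT* and the input/output TensorType parameters shown in the .cc hunk earlier. For reference, a minimal sketch of calling this builder-based overload with the new int16 support, assuming `model` is a mutable ModelT* for an int16-quantized graph that has been loaded elsewhere:

flatbuffers::FlatBufferBuilder builder;
// `model` is assumed to be a ModelT* obtained elsewhere; it is modified in
// place and the rewritten model is serialized into `builder` on success.
TfLiteStatus status = tflite::optimize::ModifyModelInterface(
    &builder, model,
    /*input_type=*/tflite::TensorType_INT16,
    /*output_type=*/tflite::TensorType_INT16);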


@@ -25,24 +25,20 @@ int main(int argc, char** argv) {
return 1;
}
if (strcmp(argv[3], "uint8") && strcmp(argv[3], "int8")) {
printf("Only support uint8 and int8 for input interface");
return 1;
}
if (strcmp(argv[4], "uint8") && strcmp(argv[4], "int8")) {
printf("Only support uint8 and int8 for output interface");
return 1;
}
const std::unordered_map<std::string, tflite::TensorType> supported_types{
{"uint8", tflite::TensorType_UINT8},
{"int8", tflite::TensorType_INT8},
{"int16", tflite::TensorType_INT16}};
tflite::TensorType input = tflite::TensorType_INT8;
tflite::TensorType output = tflite::TensorType_INT8;
if (!strcmp(argv[3], "uint8")) {
input = tflite::TensorType_UINT8;
}
if (!strcmp(argv[4], "uint8")) {
output = tflite::TensorType_UINT8;
try {
input = supported_types.at(argv[3]);
output = supported_types.at(argv[4]);
} catch (const std::out_of_range&) {
printf("Only supports uint8, int8 and int16 for input and output types");
return 1;
}
tflite::optimize::ModifyModelInterface(argv[1], argv[2], input, output);
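The command-line wrapper above maps its positional arguments onto the same API: argv[1] and argv[2] are the input and output model paths, while argv[3] and argv[4] select the input and output interface types and must now each be one of uint8, int8, or int16. A hypothetical invocation (the binary name is illustrative) would therefore be: modify_model_interface model_quant.tflite model_int16_io.tflite int16 int16.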


@@ -26,8 +26,11 @@ namespace tflite {
namespace optimize {
namespace {
// Create a quantized model with 1 quant, 1 FC, 1 dequant
std::unique_ptr<ModelT> CreateQuantizedModelSingleInputOutput() {
using ::testing::ElementsAreArray;
// Create a model with 1 quant, 1 FC, 1 dequant
std::unique_ptr<ModelT> CreateQuantizedModelSingleInputOutput(
const TensorType& quantization_type) {
auto model = absl::make_unique<ModelT>();
auto subgraph = absl::make_unique<tflite::SubGraphT>();
auto buffer = absl::make_unique<tflite::BufferT>();
@@ -87,7 +90,7 @@ std::unique_ptr<ModelT> CreateQuantizedModelSingleInputOutput() {
tensor_1->quantization->zero_point.push_back(28);
tensor_1->name = "tensor_1";
tensor_1->shape = {};
tensor_1->type = TensorType_INT8;
tensor_1->type = quantization_type;
auto tensor_2 = absl::make_unique<TensorT>();
tensor_2->quantization = absl::make_unique<QuantizationParametersT>();
@@ -95,7 +98,7 @@ std::unique_ptr<ModelT> CreateQuantizedModelSingleInputOutput() {
tensor_2->quantization->zero_point.push_back(50);
tensor_2->name = "tensor_2";
tensor_2->shape = {};
tensor_2->type = TensorType_INT8;
tensor_2->type = quantization_type;
auto tensor_3 = absl::make_unique<TensorT>();
tensor_3->name = "tensor_3";
@@ -113,8 +116,10 @@ std::unique_ptr<ModelT> CreateQuantizedModelSingleInputOutput() {
return model;
}
// Create a quantized model with 2 quant, 1 FC, 2 dequant
std::unique_ptr<ModelT> CreateQuantizedModelMultipleInputOutput() {
// Create a model with 2 quant, 1 FC, 2 dequant
// The model mimics the behavior of quantize_model.cc.
std::unique_ptr<ModelT> CreateQuantizedModelMultipleInputOutput(
const TensorType& quantization_type) {
auto model = absl::make_unique<ModelT>();
auto subgraph = absl::make_unique<tflite::SubGraphT>();
auto buffer = absl::make_unique<tflite::BufferT>();
@@ -189,7 +194,7 @@ std::unique_ptr<ModelT> CreateQuantizedModelMultipleInputOutput() {
tensor_2->quantization->zero_point.push_back(28);
tensor_2->name = "tensor_2";
tensor_2->shape = {};
tensor_2->type = TensorType_INT8;
tensor_2->type = quantization_type;
auto tensor_3 = absl::make_unique<TensorT>();
tensor_3->quantization = absl::make_unique<QuantizationParametersT>();
@@ -197,7 +202,7 @@ std::unique_ptr<ModelT> CreateQuantizedModelMultipleInputOutput() {
tensor_3->quantization->zero_point.push_back(50);
tensor_3->name = "tensor_3";
tensor_3->shape = {};
tensor_3->type = TensorType_INT8;
tensor_3->type = quantization_type;
auto tensor_4 = absl::make_unique<TensorT>();
tensor_4->quantization = absl::make_unique<QuantizationParametersT>();
@@ -205,7 +210,7 @@ std::unique_ptr<ModelT> CreateQuantizedModelMultipleInputOutput() {
tensor_4->quantization->zero_point.push_back(28);
tensor_4->name = "tensor_4";
tensor_4->shape = {};
tensor_4->type = TensorType_INT8;
tensor_4->type = quantization_type;
auto tensor_5 = absl::make_unique<TensorT>();
tensor_5->quantization = absl::make_unique<QuantizationParametersT>();
@@ -213,7 +218,7 @@ std::unique_ptr<ModelT> CreateQuantizedModelMultipleInputOutput() {
tensor_5->quantization->zero_point.push_back(50);
tensor_5->name = "tensor_5";
tensor_5->shape = {};
tensor_5->type = TensorType_INT8;
tensor_5->type = quantization_type;
auto tensor_6 = absl::make_unique<TensorT>();
tensor_6->name = "tensor_6";
@@ -286,8 +291,141 @@ std::unique_ptr<ModelT> CreateFloatModel() {
return model;
}
struct ModelInterface : ::testing::TestWithParam<tflite::TensorType> {};
TEST_P(ModelInterface, SingleInputOutput) {
TensorType quantization_type = GetParam();
auto model = CreateQuantizedModelSingleInputOutput(quantization_type);
// Change model type.
flatbuffers::FlatBufferBuilder builder;
EXPECT_EQ(ModifyModelInterface(&builder, model.get(), quantization_type,
quantization_type),
kTfLiteOk);
// Verify results.
EXPECT_EQ(model->subgraphs.size(), 1);
// TODO(mnatraj): The float input tensor has not been removed.
// EXPECT_EQ(model->subgraphs[0]->tensors.size(), 2);
EXPECT_EQ(model->subgraphs[0]->tensors.size(), 3);
EXPECT_EQ(model->subgraphs[0]->inputs.size(), 1);
EXPECT_EQ(model->subgraphs[0]->inputs[0], 1);
EXPECT_EQ(model->subgraphs[0]->outputs.size(), 1);
EXPECT_EQ(model->subgraphs[0]->outputs[0], 2);
EXPECT_EQ(model->operator_codes.size(), 3);
EXPECT_EQ(model->subgraphs[0]->operators.size(), 1);
EXPECT_EQ(model->subgraphs[0]->operators[0]->opcode_index, 1);
auto fc_op = model->subgraphs[0]->operators[0].get();
auto input = model->subgraphs[0]->tensors[fc_op->inputs[0]].get();
EXPECT_EQ(input->name, "tensor_1");
EXPECT_EQ(input->type, quantization_type);
EXPECT_FLOAT_EQ(input->quantization->scale[0], 0.35);
EXPECT_EQ(input->quantization->zero_point[0], 28);
auto output = model->subgraphs[0]->tensors[fc_op->outputs[0]].get();
EXPECT_EQ(output->name, "tensor_2");
EXPECT_EQ(output->type, quantization_type);
EXPECT_FLOAT_EQ(output->quantization->scale[0], 0.12);
EXPECT_EQ(output->quantization->zero_point[0], 50);
}
TEST_P(ModelInterface, MutipleInputOutput) {
TensorType quantization_type = GetParam();
auto model = CreateQuantizedModelMultipleInputOutput(quantization_type);
// Change model type.
flatbuffers::FlatBufferBuilder builder;
EXPECT_EQ(ModifyModelInterface(&builder, model.get(), quantization_type,
quantization_type),
kTfLiteOk);
// Verify results.
EXPECT_EQ(model->subgraphs.size(), 1);
// TODO (b/158254056): Remove unused inputs and outputs from tensor list
// EXPECT_EQ(model->subgraphs[0]->tensors.size(), 4);
EXPECT_EQ(model->subgraphs[0]->tensors.size(), 6);
EXPECT_EQ(model->subgraphs[0]->inputs.size(), 2);
EXPECT_EQ(model->subgraphs[0]->inputs[0], 2);
EXPECT_EQ(model->subgraphs[0]->inputs[1], 3);
EXPECT_EQ(model->subgraphs[0]->outputs.size(), 2);
EXPECT_EQ(model->subgraphs[0]->outputs[0], 4);
EXPECT_EQ(model->subgraphs[0]->outputs[1], 5);
EXPECT_EQ(model->operator_codes.size(), 3);
EXPECT_EQ(model->subgraphs[0]->operators.size(), 1);
EXPECT_EQ(model->subgraphs[0]->operators[0]->opcode_index, 1);
auto fc_op = model->subgraphs[0]->operators[0].get();
auto input_1 = model->subgraphs[0]->tensors[fc_op->inputs[0]].get();
EXPECT_EQ(input_1->name, "tensor_2");
EXPECT_EQ(input_1->type, quantization_type);
EXPECT_FLOAT_EQ(input_1->quantization->scale[0], 0.35);
EXPECT_EQ(input_1->quantization->zero_point[0], 28);
auto input_2 = model->subgraphs[0]->tensors[fc_op->inputs[1]].get();
EXPECT_EQ(input_2->name, "tensor_3");
EXPECT_EQ(input_2->type, quantization_type);
EXPECT_FLOAT_EQ(input_2->quantization->scale[0], 0.12);
EXPECT_EQ(input_2->quantization->zero_point[0], 50);
auto output_1 = model->subgraphs[0]->tensors[fc_op->outputs[0]].get();
EXPECT_EQ(output_1->name, "tensor_4");
EXPECT_EQ(output_1->type, quantization_type);
EXPECT_FLOAT_EQ(output_1->quantization->scale[0], 0.45);
EXPECT_EQ(output_1->quantization->zero_point[0], 28);
auto output_2 = model->subgraphs[0]->tensors[fc_op->outputs[1]].get();
EXPECT_EQ(output_2->name, "tensor_5");
EXPECT_EQ(output_2->type, quantization_type);
EXPECT_FLOAT_EQ(output_2->quantization->scale[0], 0.22);
EXPECT_EQ(output_2->quantization->zero_point[0], 50);
}
INSTANTIATE_TEST_SUITE_P(MultipleInputOutputTests, ModelInterface,
::testing::Values(TensorType_INT8, TensorType_INT16));
TEST(ModelInterface, MixedTypeSingleInputOutput) {
auto model = CreateQuantizedModelSingleInputOutput(TensorType_INT8);
// Change model type.
flatbuffers::FlatBufferBuilder builder;
EXPECT_EQ(ModifyModelInterface(&builder, model.get(), TensorType_UINT8,
TensorType_INT8),
kTfLiteOk);
// Verify results.
EXPECT_EQ(model->subgraphs.size(), 1);
EXPECT_EQ(model->subgraphs[0]->tensors.size(), 3);
EXPECT_EQ(model->subgraphs[0]->inputs.size(), 1);
EXPECT_EQ(model->subgraphs[0]->inputs[0], 0);
EXPECT_EQ(model->subgraphs[0]->outputs.size(), 1);
EXPECT_EQ(model->subgraphs[0]->outputs[0], 2);
EXPECT_EQ(model->operator_codes.size(), 3);
EXPECT_EQ(model->subgraphs[0]->operators.size(), 2);
EXPECT_EQ(model->subgraphs[0]->operators[0]->opcode_index, 0);
EXPECT_EQ(model->subgraphs[0]->operators[1]->opcode_index, 1);
auto quant_op = model->subgraphs[0]->operators[0].get();
auto input = model->subgraphs[0]->tensors[quant_op->inputs[0]].get();
EXPECT_EQ(input->name, "tensor_0");
EXPECT_EQ(input->type, TensorType_UINT8);
EXPECT_FLOAT_EQ(input->quantization->scale[0], 0.35);
EXPECT_EQ(input->quantization->zero_point[0], 156);
auto fc_op = model->subgraphs[0]->operators[1].get();
auto output = model->subgraphs[0]->tensors[fc_op->outputs[0]].get();
EXPECT_EQ(output->name, "tensor_2");
EXPECT_EQ(output->type, TensorType_INT8);
EXPECT_FLOAT_EQ(output->quantization->scale[0], 0.12);
EXPECT_EQ(output->quantization->zero_point[0], 50);
}
TEST(ModelInterface, Uint8SingleInputOutput) {
auto model = CreateQuantizedModelSingleInputOutput();
auto model = CreateQuantizedModelSingleInputOutput(TensorType_INT8);
// Change model type.
flatbuffers::FlatBufferBuilder builder;
@@ -323,81 +461,8 @@ TEST(ModelInterface, Uint8SingleInputOutput) {
EXPECT_EQ(output->quantization->zero_point[0], 178);
}
TEST(ModelInterface, Int8SingleInputOutput) {
auto model = CreateQuantizedModelSingleInputOutput();
// Change model type.
flatbuffers::FlatBufferBuilder builder;
EXPECT_EQ(ModifyModelInterface(&builder, model.get(), TensorType_INT8,
TensorType_INT8),
kTfLiteOk);
// Verify results.
EXPECT_EQ(model->subgraphs.size(), 1);
// TODO(mnatraj): The float input tensor has not been removed.
// EXPECT_EQ(model->subgraphs[0]->tensors.size(), 2);
EXPECT_EQ(model->subgraphs[0]->tensors.size(), 3);
EXPECT_EQ(model->subgraphs[0]->inputs.size(), 1);
EXPECT_EQ(model->subgraphs[0]->inputs[0], 1);
EXPECT_EQ(model->subgraphs[0]->outputs.size(), 1);
EXPECT_EQ(model->subgraphs[0]->outputs[0], 2);
EXPECT_EQ(model->operator_codes.size(), 3);
EXPECT_EQ(model->subgraphs[0]->operators.size(), 1);
EXPECT_EQ(model->subgraphs[0]->operators[0]->opcode_index, 1);
auto fc_op = model->subgraphs[0]->operators[0].get();
auto input = model->subgraphs[0]->tensors[fc_op->inputs[0]].get();
EXPECT_EQ(input->name, "tensor_1");
EXPECT_EQ(input->type, TensorType_INT8);
EXPECT_FLOAT_EQ(input->quantization->scale[0], 0.35);
EXPECT_EQ(input->quantization->zero_point[0], 28);
auto output = model->subgraphs[0]->tensors[fc_op->outputs[0]].get();
EXPECT_EQ(output->name, "tensor_2");
EXPECT_EQ(output->type, TensorType_INT8);
EXPECT_FLOAT_EQ(output->quantization->scale[0], 0.12);
EXPECT_EQ(output->quantization->zero_point[0], 50);
}
TEST(ModelInterface, MixedTypeSingleInputOutput) {
auto model = CreateQuantizedModelSingleInputOutput();
// Change model type.
flatbuffers::FlatBufferBuilder builder;
EXPECT_EQ(ModifyModelInterface(&builder, model.get(), TensorType_UINT8,
TensorType_INT8),
kTfLiteOk);
// Verify results.
EXPECT_EQ(model->subgraphs.size(), 1);
EXPECT_EQ(model->subgraphs[0]->tensors.size(), 3);
EXPECT_EQ(model->subgraphs[0]->inputs.size(), 1);
EXPECT_EQ(model->subgraphs[0]->inputs[0], 0);
EXPECT_EQ(model->subgraphs[0]->outputs.size(), 1);
EXPECT_EQ(model->subgraphs[0]->outputs[0], 2);
EXPECT_EQ(model->operator_codes.size(), 3);
EXPECT_EQ(model->subgraphs[0]->operators.size(), 2);
EXPECT_EQ(model->subgraphs[0]->operators[0]->opcode_index, 0);
EXPECT_EQ(model->subgraphs[0]->operators[1]->opcode_index, 1);
auto quant_op = model->subgraphs[0]->operators[0].get();
auto input = model->subgraphs[0]->tensors[quant_op->inputs[0]].get();
EXPECT_EQ(input->name, "tensor_0");
EXPECT_EQ(input->type, TensorType_UINT8);
EXPECT_FLOAT_EQ(input->quantization->scale[0], 0.35);
EXPECT_EQ(input->quantization->zero_point[0], 156);
auto fc_op = model->subgraphs[0]->operators[1].get();
auto output = model->subgraphs[0]->tensors[fc_op->outputs[0]].get();
EXPECT_EQ(output->name, "tensor_2");
EXPECT_EQ(output->type, TensorType_INT8);
EXPECT_FLOAT_EQ(output->quantization->scale[0], 0.12);
EXPECT_EQ(output->quantization->zero_point[0], 50);
}
TEST(ModelInterface, Uint8MutipleInputOutput) {
auto model = CreateQuantizedModelMultipleInputOutput();
auto model = CreateQuantizedModelMultipleInputOutput(TensorType_INT8);
// Change model type.
flatbuffers::FlatBufferBuilder builder;
@@ -454,7 +519,7 @@ TEST(ModelInterface, Uint8MutipleInputOutput) {
}
TEST(ModelInterface, Int8MutipleInputOutput) {
auto model = CreateQuantizedModelMultipleInputOutput();
auto model = CreateQuantizedModelMultipleInputOutput(TensorType_INT8);
// Change model type.
flatbuffers::FlatBufferBuilder builder;