Merge pull request #40190 from TamasArm:16_bit_support_for_modify_model_interface
PiperOrigin-RevId: 320997124 Change-Id: I09242d15c14365d72d43dc1d8e911e96f12edf82
Commit 3fc7725cc2
@@ -45,7 +45,8 @@ struct TensorOpTensor {
 
 // Finds float tensors that are model inputs and is consumed by a quantize Op.
 // The returned TensorOpTensor should have reverse order.
-std::vector<TensorOpTensor> GetInputTensors(ModelT* model,
+std::vector<TensorOpTensor> GetInputTensors(const TensorType& input_type,
+                                            ModelT* model,
                                             ErrorReporter* error_reporter) {
   std::vector<TensorOpTensor> result;
   // Get all input tensors.
@@ -71,7 +72,7 @@ std::vector<TensorOpTensor> GetInputTensors(ModelT* model,
         continue;
       }
       if (op_code != BuiltinOperator_QUANTIZE) {
-        // Current only support INT8 quantized models.
+        // Currently only supports int8 and int16 quantized models.
         TF_LITE_REPORT_ERROR(
             error_reporter,
             "modify_model_interface called on a model without quant/dequant.");
@@ -85,10 +86,27 @@ std::vector<TensorOpTensor> GetInputTensors(ModelT* model,
       }
       const int model_input_index = input_tensors[input_tensor];
       TensorT* quant_output = subgraph->tensors[op->outputs[0]].get();
-      if (quant_output->type != TensorType_INT8) {
+      if (quant_output->type != TensorType_INT8 &&
+          quant_output->type != TensorType_INT16) {
         TF_LITE_REPORT_ERROR(error_reporter,
-                             "modify_model_interface currently only support "
-                             "int8 quantized models.");
+                             "modify_model_interface currently only supports "
+                             "int8 and int16 quantized models.");
       }
+
+      // The input type must be the same as the model quantization type
+      if (input_type != quant_output->type) {
+        // An exception, allow for UINT8 input type for INT8 quantized model.
+        if (!(input_type == TensorType_UINT8 &&
+              quant_output->type == TensorType_INT8)) {
+          TF_LITE_REPORT_ERROR(
+              error_reporter,
+              "The %s input type is incompatible with %s quantized models. "
+              "To resolve this error, change the input_type to a compatible "
+              "one. "
+              "See: modify_model_interface.cc",
+              EnumNameTensorType(input_type),
+              EnumNameTensorType(quant_output->type));
+        }
+      }
       if (quant_output->quantization == nullptr) {
         continue;
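
For illustration, the rule the new input-type check enforces can be written as a standalone predicate. The snippet below is a sketch rather than library code, and its small enum merely stands in for tflite::TensorType: the requested interface type must match the model's quantization type, except that a uint8 interface is still accepted on an int8-quantized model.

    // Sketch only; TensorType here is a stand-in for the tflite schema enum.
    #include <cstdio>

    enum class TensorType { UINT8, INT8, INT16 };

    bool InterfaceTypeCompatible(TensorType requested, TensorType quantized) {
      if (requested == quantized) return true;
      // The one exception carried over from the int8-only code path.
      return requested == TensorType::UINT8 && quantized == TensorType::INT8;
    }

    int main() {
      std::printf("%d\n", InterfaceTypeCompatible(TensorType::UINT8, TensorType::INT8));   // 1
      std::printf("%d\n", InterfaceTypeCompatible(TensorType::UINT8, TensorType::INT16));  // 0
      std::printf("%d\n", InterfaceTypeCompatible(TensorType::INT16, TensorType::INT16));  // 1
      return 0;
    }
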
@@ -102,7 +120,8 @@ std::vector<TensorOpTensor> GetInputTensors(ModelT* model,
 
 // Finds float tensors that are model output and is consumed by a dequantize Op.
 // The returned TensorOpTensor should have reverse order.
-std::vector<TensorOpTensor> GetOutputTensors(ModelT* model,
+std::vector<TensorOpTensor> GetOutputTensors(const TensorType& output_type,
+                                             ModelT* model,
                                              ErrorReporter* error_reporter) {
   std::vector<TensorOpTensor> result;
   // Get all output tensors.
@@ -128,7 +147,7 @@ std::vector<TensorOpTensor> GetOutputTensors(ModelT* model,
         continue;
       }
       if (op_code != BuiltinOperator_DEQUANTIZE) {
-        // Current only support INT8 quantized models.
+        // Currently only supports int8 and int16 quantized models.
        TF_LITE_REPORT_ERROR(
             error_reporter,
             "modify_model_interface called on a model without quant/dequant.");
@@ -142,13 +161,28 @@ std::vector<TensorOpTensor> GetOutputTensors(ModelT* model,
       }
       const int model_output_index = output_tensors[output_tensor];
       TensorT* dequant_input = subgraph->tensors[op->inputs[0]].get();
-      if (dequant_input->type != TensorType_INT8) {
-        // Current only support INT8 quantized models.
+      if (dequant_input->type != TensorType_INT8 &&
+          dequant_input->type != TensorType_INT16) {
+        // Currently only supports int8 and int16 quantized models.
         TF_LITE_REPORT_ERROR(error_reporter,
-                             "modify_model_interface currently only support "
-                             "int8 quantized models.");
+                             "modify_model_interface currently only supports "
+                             "int8 and int16 quantized models.");
         return {};
       }
+      if (output_type != dequant_input->type) {
+        // An exception, allow for UINT8 input type for INT8 quantized model.
+        if (!(output_type == TensorType_UINT8 &&
+              dequant_input->type == TensorType_INT8)) {
+          TF_LITE_REPORT_ERROR(
+              error_reporter,
+              "The %s output type is incompatible with %s quantized models. "
+              "To resolve this error, change the output_type to a compatible "
+              "one. "
+              "See: modify_model_interface.cc",
+              EnumNameTensorType(output_type),
+              EnumNameTensorType(dequant_input->type));
+        }
+      }
       if (dequant_input->quantization == nullptr) {
         continue;
       }
@@ -288,9 +322,13 @@ std::unique_ptr<tflite::ModelT> CreateMutableModelFromFile(
   return copied_model;
 }
 
-int GetOriginalNumberOfTensors(ModelT* model, ErrorReporter* error_reporter) {
-  std::vector<TensorOpTensor> outputs = GetOutputTensors(model, error_reporter);
-  std::vector<TensorOpTensor> inputs = GetInputTensors(model, error_reporter);
+int GetOriginalNumberOfTensors(const TensorType& input_type,
+                               const TensorType& output_type, ModelT* model,
+                               ErrorReporter* error_reporter) {
+  std::vector<TensorOpTensor> outputs =
+      GetOutputTensors(output_type, model, error_reporter);
+  std::vector<TensorOpTensor> inputs =
+      GetInputTensors(input_type, model, error_reporter);
   return model->subgraphs[0]->tensors.size() - outputs.size() - inputs.size();
 }
 
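
The helper still computes the pre-modification tensor count by subtracting the quantize/dequantize wrapper tensors found at the interface from the subgraph total; only the arguments it forwards have changed. A worked example with made-up numbers, not taken from the patch:

    #include <cstdio>

    int main() {
      // Hypothetical subgraph: 6 tensors, one float input feeding a QUANTIZE op
      // and one float output produced by a DEQUANTIZE op.
      const int total_tensors = 6;
      const int input_wrappers = 1;   // size of the GetInputTensors(...) result
      const int output_wrappers = 1;  // size of the GetOutputTensors(...) result
      std::printf("%d\n", total_tensors - output_wrappers - input_wrappers);  // prints 4
      return 0;
    }
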
@@ -300,30 +338,39 @@ TfLiteStatus ModifyModelInterface(flatbuffers::FlatBufferBuilder* builder,
                                   ModelT* model, const TensorType& input_type,
                                   const TensorType& output_type) {
   tflite::StderrReporter error_reporter;
-  const int original_number_tensors =
-      GetOriginalNumberOfTensors(model, &error_reporter);
-  // Find float tensors that are model output and is consumed by a float to int8
-  // quantize Op.
-  // Do output first since the tensors are added into input first.,
+  const int original_number_tensors = GetOriginalNumberOfTensors(
+      input_type, output_type, model, &error_reporter);
+  // Finds float tensors that are model output and are consumed by a float to
+  // int8/int16 quantize Op. Do output first since the tensors are added into
+  // input first.,
   std::vector<TensorOpTensor> outputs =
-      GetOutputTensors(model, &error_reporter);
-  if (output_type == TensorType_UINT8) {
-    SetOutputTypeToUINT8(model, outputs);
-  } else if (output_type == TensorType_INT8) {
-    RemoveOutputTensor(model, outputs, original_number_tensors);
-  } else {
-    return kTfLiteError;
+      GetOutputTensors(output_type, model, &error_reporter);
+  switch (output_type) {
+    case TensorType_UINT8:
+      SetOutputTypeToUINT8(model, outputs);
+      break;
+    case TensorType_INT8:
+    case TensorType_INT16:
+      RemoveOutputTensor(model, outputs, original_number_tensors);
+      break;
+    default:
+      return kTfLiteError;
   }
 
-  // Find float tensors that are model input and is consumed by a float to int8
-  // quantize Op.
-  std::vector<TensorOpTensor> inputs = GetInputTensors(model, &error_reporter);
-  if (input_type == TensorType_UINT8) {
-    SetInputTypeToUINT8(model, inputs);
-  } else if (input_type == TensorType_INT8) {
-    RemoveInputTensor(model, inputs, original_number_tensors);
-  } else {
-    return kTfLiteError;
+  // Find float tensors that are model input and is consumed by a float to
+  // int8/int16 quantize Op.
+  std::vector<TensorOpTensor> inputs =
+      GetInputTensors(input_type, model, &error_reporter);
+  switch (input_type) {
+    case TensorType_UINT8:
+      SetInputTypeToUINT8(model, inputs);
+      break;
+    case TensorType_INT8:
+    case TensorType_INT16:
+      RemoveInputTensor(model, inputs, original_number_tensors);
+      break;
+    default:
+      return kTfLiteError;
   }
 
   // Write to builder.
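
With the switch statements above, a hypothetical caller of the in-memory overload can now pass TensorType_INT16 directly; int16 takes the same RemoveInputTensor/RemoveOutputTensor path as int8. The sketch below is not part of the patch: the header path and the way the ModelT is obtained are assumptions.

    // Sketch only: assumes the header path below and a ModelT unpacked elsewhere.
    #include "flatbuffers/flatbuffers.h"
    #include "tensorflow/lite/tools/optimize/modify_model_interface.h"  // assumed path

    TfLiteStatus MakeInt16Interface(tflite::ModelT* model,
                                    flatbuffers::FlatBufferBuilder* builder) {
      // After this change, INT16 is accepted and handled like INT8.
      return tflite::optimize::ModifyModelInterface(
          builder, model, tflite::TensorType_INT16, tflite::TensorType_INT16);
    }
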
@@ -340,11 +387,13 @@ TfLiteStatus ModifyModelInterface(const string& input_file,
                                   const TensorType& output_type) {
   // Consistency Check
   if (input_type != tflite::TensorType_INT8 &&
-      input_type != tflite::TensorType_UINT8) {
+      input_type != tflite::TensorType_UINT8 &&
+      input_type != tflite::TensorType_INT16) {
     return kTfLiteError;
   }
   if (output_type != tflite::TensorType_INT8 &&
-      output_type != tflite::TensorType_UINT8) {
+      output_type != tflite::TensorType_UINT8 &&
+      output_type != tflite::TensorType_INT16) {
     return kTfLiteError;
   }
 
@@ -357,17 +406,8 @@ TfLiteStatus ModifyModelInterface(const string& input_file,
       absl::make_unique<flatbuffers::FlatBufferBuilder>();
   flatbuffers::FlatBufferBuilder builder;
 
-  tflite::TensorType input_override_type = tflite::TensorType_INT8;
-  if (input_type == tflite::TensorType_UINT8) {
-    input_override_type = tflite::TensorType_UINT8;
-  }
-  tflite::TensorType output_override_type = tflite::TensorType_INT8;
-  if (output_type == tflite::TensorType_UINT8) {
-    output_override_type = tflite::TensorType_UINT8;
-  }
-
-  auto status = ModifyModelInterface(&builder, tflite_model.get(),
-                                     input_override_type, output_override_type);
+  auto status = ModifyModelInterface(&builder, tflite_model.get(), input_type,
+                                     output_type);
   TFLITE_DCHECK_EQ(status, kTfLiteOk);
 
   WriteFile(output_file, builder.GetBufferPointer(), builder.GetSize());
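
The file-based overload now forwards the requested types unchanged, since the uint8-only override logic above is gone. A hedged end-to-end sketch of calling it; the header path and file names are placeholders:

    #include "tensorflow/lite/tools/optimize/modify_model_interface.h"  // assumed path

    int main() {
      // Rewrite an int16-quantized model so its float input/output wrappers are
      // dropped and the interface becomes int16.
      auto status = tflite::optimize::ModifyModelInterface(
          "/tmp/quantized_int16.tflite", "/tmp/quantized_int16_io.tflite",
          tflite::TensorType_INT16, tflite::TensorType_INT16);
      return status == kTfLiteOk ? 0 : 1;
    }
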
@@ -24,7 +24,7 @@ namespace optimize {
 // Changes the interface of a quantized model. This method allows the users to
 // replace float interface with other types.
 // This populates the builder with the new model.
-// Currently only int8 and unit8 are supported.
+// Currently only int8, int16 and uint8 are supported.
 //
 // Note: This is a private API, subject to change.
 TfLiteStatus ModifyModelInterface(flatbuffers::FlatBufferBuilder* builder,
@@ -25,24 +25,20 @@ int main(int argc, char** argv) {
     return 1;
   }
 
-  if (strcmp(argv[3], "uint8") && strcmp(argv[3], "int8")) {
-    printf("Only support uint8 and int8 for input interface");
-    return 1;
-  }
-
-  if (strcmp(argv[4], "uint8") && strcmp(argv[4], "int8")) {
-    printf("Only support uint8 and int8 for output interface");
-    return 1;
-  }
+  const std::unordered_map<std::string, tflite::TensorType> supported_types{
+      {"uint8", tflite::TensorType_UINT8},
+      {"int8", tflite::TensorType_INT8},
+      {"int16", tflite::TensorType_INT16}};
 
   tflite::TensorType input = tflite::TensorType_INT8;
   tflite::TensorType output = tflite::TensorType_INT8;
 
-  if (!strcmp(argv[3], "uint8")) {
-    input = tflite::TensorType_UINT8;
-  }
-  if (!strcmp(argv[4], "uint8")) {
-    output = tflite::TensorType_UINT8;
+  try {
+    input = supported_types.at(argv[3]);
+    output = supported_types.at(argv[4]);
+  } catch (const std::out_of_range&) {
+    printf("Only supports uint8, int8 and int16 for input and output types");
+    return 1;
   }
 
   tflite::optimize::ModifyModelInterface(argv[1], argv[2], input, output);
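
The strcmp chains are replaced with a single lookup table; std::unordered_map::at throws std::out_of_range for an unknown key, which is what the catch block turns into the error message. A standalone sketch of the same pattern with hypothetical values:

    #include <cstdio>
    #include <stdexcept>
    #include <string>
    #include <unordered_map>

    int main(int argc, char** argv) {
      // Placeholder ints stand in for the tflite::TensorType enum values.
      const std::unordered_map<std::string, int> supported_types{
          {"uint8", 0}, {"int8", 1}, {"int16", 2}};
      try {
        // at() throws std::out_of_range if the argument is not a supported type.
        const int input = supported_types.at(argc > 1 ? argv[1] : "float32");
        std::printf("type id: %d\n", input);
      } catch (const std::out_of_range&) {
        std::printf("Only uint8, int8 and int16 are supported\n");
        return 1;
      }
      return 0;
    }
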
@@ -26,8 +26,11 @@ namespace tflite {
 namespace optimize {
 namespace {
 
-// Create a quantized model with 1 quant, 1 FC, 1 dequant
-std::unique_ptr<ModelT> CreateQuantizedModelSingleInputOutput() {
+using ::testing::ElementsAreArray;
+
+// Create a model with 1 quant, 1 FC, 1 dequant
+std::unique_ptr<ModelT> CreateQuantizedModelSingleInputOutput(
+    const TensorType& quantization_type) {
   auto model = absl::make_unique<ModelT>();
   auto subgraph = absl::make_unique<tflite::SubGraphT>();
   auto buffer = absl::make_unique<tflite::BufferT>();
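
The test helpers now take the quantization type as a parameter so the same model builders serve the int8 and int16 cases, and (further down) the per-type tests are folded into gtest value-parameterized tests: one fixture deriving from ::testing::TestWithParam, TEST_P bodies that read GetParam(), and one INSTANTIATE_TEST_SUITE_P call that runs each body for both types. A minimal self-contained sketch of that pattern; the enum below is only a stand-in for tflite::TensorType:

    #include <gtest/gtest.h>

    enum FakeTensorType { kInt8, kInt16 };

    // One fixture, parameterized over the tensor type.
    struct ModelInterfaceSketch : ::testing::TestWithParam<FakeTensorType> {};

    TEST_P(ModelInterfaceSketch, ParamIsVisible) {
      const FakeTensorType quantization_type = GetParam();
      EXPECT_TRUE(quantization_type == kInt8 || quantization_type == kInt16);
    }

    // Runs every TEST_P body once per listed value.
    INSTANTIATE_TEST_SUITE_P(Int8AndInt16, ModelInterfaceSketch,
                             ::testing::Values(kInt8, kInt16));
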
@@ -87,7 +90,7 @@ std::unique_ptr<ModelT> CreateQuantizedModelSingleInputOutput() {
   tensor_1->quantization->zero_point.push_back(28);
   tensor_1->name = "tensor_1";
   tensor_1->shape = {};
-  tensor_1->type = TensorType_INT8;
+  tensor_1->type = quantization_type;
 
   auto tensor_2 = absl::make_unique<TensorT>();
   tensor_2->quantization = absl::make_unique<QuantizationParametersT>();
@@ -95,7 +98,7 @@ std::unique_ptr<ModelT> CreateQuantizedModelSingleInputOutput() {
   tensor_2->quantization->zero_point.push_back(50);
   tensor_2->name = "tensor_2";
   tensor_2->shape = {};
-  tensor_2->type = TensorType_INT8;
+  tensor_2->type = quantization_type;
 
   auto tensor_3 = absl::make_unique<TensorT>();
   tensor_3->name = "tensor_3";
@@ -113,8 +116,10 @@ std::unique_ptr<ModelT> CreateQuantizedModelSingleInputOutput() {
   return model;
 }
 
-// Create a quantized model with 2 quant, 1 FC, 2 dequant
-std::unique_ptr<ModelT> CreateQuantizedModelMultipleInputOutput() {
+// Create a model with 2 quant, 1 FC, 2 dequant
+// The model mimics the behavior of the quantize_model.cc.
+std::unique_ptr<ModelT> CreateQuantizedModelMultipleInputOutput(
+    const TensorType& quantization_type) {
   auto model = absl::make_unique<ModelT>();
   auto subgraph = absl::make_unique<tflite::SubGraphT>();
   auto buffer = absl::make_unique<tflite::BufferT>();
@@ -189,7 +194,7 @@ std::unique_ptr<ModelT> CreateQuantizedModelMultipleInputOutput() {
   tensor_2->quantization->zero_point.push_back(28);
   tensor_2->name = "tensor_2";
   tensor_2->shape = {};
-  tensor_2->type = TensorType_INT8;
+  tensor_2->type = quantization_type;
 
   auto tensor_3 = absl::make_unique<TensorT>();
   tensor_3->quantization = absl::make_unique<QuantizationParametersT>();
@@ -197,7 +202,7 @@ std::unique_ptr<ModelT> CreateQuantizedModelMultipleInputOutput() {
   tensor_3->quantization->zero_point.push_back(50);
   tensor_3->name = "tensor_3";
   tensor_3->shape = {};
-  tensor_3->type = TensorType_INT8;
+  tensor_3->type = quantization_type;
 
   auto tensor_4 = absl::make_unique<TensorT>();
   tensor_4->quantization = absl::make_unique<QuantizationParametersT>();
@@ -205,7 +210,7 @@ std::unique_ptr<ModelT> CreateQuantizedModelMultipleInputOutput() {
   tensor_4->quantization->zero_point.push_back(28);
   tensor_4->name = "tensor_4";
   tensor_4->shape = {};
-  tensor_4->type = TensorType_INT8;
+  tensor_4->type = quantization_type;
 
   auto tensor_5 = absl::make_unique<TensorT>();
   tensor_5->quantization = absl::make_unique<QuantizationParametersT>();
@@ -213,7 +218,7 @@ std::unique_ptr<ModelT> CreateQuantizedModelMultipleInputOutput() {
   tensor_5->quantization->zero_point.push_back(50);
   tensor_5->name = "tensor_5";
   tensor_5->shape = {};
-  tensor_5->type = TensorType_INT8;
+  tensor_5->type = quantization_type;
 
   auto tensor_6 = absl::make_unique<TensorT>();
   tensor_6->name = "tensor_6";
@@ -286,8 +291,141 @@ std::unique_ptr<ModelT> CreateFloatModel() {
   return model;
 }
 
+struct ModelInterface : ::testing::TestWithParam<tflite::TensorType> {};
+
+TEST_P(ModelInterface, SingleInputOutput) {
+  TensorType quantization_type = GetParam();
+
+  auto model = CreateQuantizedModelSingleInputOutput(quantization_type);
+
+  // Change model type.
+  flatbuffers::FlatBufferBuilder builder;
+  EXPECT_EQ(ModifyModelInterface(&builder, model.get(), quantization_type,
+                                 quantization_type),
+            kTfLiteOk);
+
+  // Verify results.
+  EXPECT_EQ(model->subgraphs.size(), 1);
+  // TODO(mnatraj): The float input tensor has not been removed.
+  // EXPECT_EQ(model->subgraphs[0]->tensors.size(), 2);
+  EXPECT_EQ(model->subgraphs[0]->tensors.size(), 3);
+  EXPECT_EQ(model->subgraphs[0]->inputs.size(), 1);
+  EXPECT_EQ(model->subgraphs[0]->inputs[0], 1);
+  EXPECT_EQ(model->subgraphs[0]->outputs.size(), 1);
+  EXPECT_EQ(model->subgraphs[0]->outputs[0], 2);
+  EXPECT_EQ(model->operator_codes.size(), 3);
+  EXPECT_EQ(model->subgraphs[0]->operators.size(), 1);
+  EXPECT_EQ(model->subgraphs[0]->operators[0]->opcode_index, 1);
+
+  auto fc_op = model->subgraphs[0]->operators[0].get();
+
+  auto input = model->subgraphs[0]->tensors[fc_op->inputs[0]].get();
+  EXPECT_EQ(input->name, "tensor_1");
+  EXPECT_EQ(input->type, quantization_type);
+  EXPECT_FLOAT_EQ(input->quantization->scale[0], 0.35);
+  EXPECT_EQ(input->quantization->zero_point[0], 28);
+
+  auto output = model->subgraphs[0]->tensors[fc_op->outputs[0]].get();
+  EXPECT_EQ(output->name, "tensor_2");
+  EXPECT_EQ(output->type, quantization_type);
+  EXPECT_FLOAT_EQ(output->quantization->scale[0], 0.12);
+  EXPECT_EQ(output->quantization->zero_point[0], 50);
+}
+
+TEST_P(ModelInterface, MutipleInputOutput) {
+  TensorType quantization_type = GetParam();
+
+  auto model = CreateQuantizedModelMultipleInputOutput(quantization_type);
+
+  // Change model type.
+  flatbuffers::FlatBufferBuilder builder;
+  EXPECT_EQ(ModifyModelInterface(&builder, model.get(), quantization_type,
+                                 quantization_type),
+            kTfLiteOk);
+
+  // Verify results.
+  EXPECT_EQ(model->subgraphs.size(), 1);
+  // TODO (b/158254056): Remove unused inputs and outputs from tensor list
+  // EXPECT_EQ(model->subgraphs[0]->tensors.size(), 4);
+  EXPECT_EQ(model->subgraphs[0]->tensors.size(), 6);
+  EXPECT_EQ(model->subgraphs[0]->inputs.size(), 2);
+  EXPECT_EQ(model->subgraphs[0]->inputs[0], 2);
+  EXPECT_EQ(model->subgraphs[0]->inputs[1], 3);
+  EXPECT_EQ(model->subgraphs[0]->outputs.size(), 2);
+  EXPECT_EQ(model->subgraphs[0]->outputs[0], 4);
+  EXPECT_EQ(model->subgraphs[0]->outputs[1], 5);
+  EXPECT_EQ(model->operator_codes.size(), 3);
+  EXPECT_EQ(model->subgraphs[0]->operators.size(), 1);
+  EXPECT_EQ(model->subgraphs[0]->operators[0]->opcode_index, 1);
+
+  auto fc_op = model->subgraphs[0]->operators[0].get();
+
+  auto input_1 = model->subgraphs[0]->tensors[fc_op->inputs[0]].get();
+  EXPECT_EQ(input_1->name, "tensor_2");
+  EXPECT_EQ(input_1->type, quantization_type);
+  EXPECT_FLOAT_EQ(input_1->quantization->scale[0], 0.35);
+  EXPECT_EQ(input_1->quantization->zero_point[0], 28);
+
+  auto input_2 = model->subgraphs[0]->tensors[fc_op->inputs[1]].get();
+  EXPECT_EQ(input_2->name, "tensor_3");
+  EXPECT_EQ(input_2->type, quantization_type);
+  EXPECT_FLOAT_EQ(input_2->quantization->scale[0], 0.12);
+  EXPECT_EQ(input_2->quantization->zero_point[0], 50);
+
+  auto output_1 = model->subgraphs[0]->tensors[fc_op->outputs[0]].get();
+  EXPECT_EQ(output_1->name, "tensor_4");
+  EXPECT_EQ(output_1->type, quantization_type);
+  EXPECT_FLOAT_EQ(output_1->quantization->scale[0], 0.45);
+  EXPECT_EQ(output_1->quantization->zero_point[0], 28);
+
+  auto output_2 = model->subgraphs[0]->tensors[fc_op->outputs[1]].get();
+  EXPECT_EQ(output_2->name, "tensor_5");
+  EXPECT_EQ(output_2->type, quantization_type);
+  EXPECT_FLOAT_EQ(output_2->quantization->scale[0], 0.22);
+  EXPECT_EQ(output_2->quantization->zero_point[0], 50);
+}
+
+INSTANTIATE_TEST_SUITE_P(MultipleInputOutputTests, ModelInterface,
+                         ::testing::Values(TensorType_INT8, TensorType_INT16));
+
+TEST(ModelInterface, MixedTypeSingleInputOutput) {
+  auto model = CreateQuantizedModelSingleInputOutput(TensorType_INT8);
+
+  // Change model type.
+  flatbuffers::FlatBufferBuilder builder;
+  EXPECT_EQ(ModifyModelInterface(&builder, model.get(), TensorType_UINT8,
+                                 TensorType_INT8),
+            kTfLiteOk);
+
+  // Verify results.
+  EXPECT_EQ(model->subgraphs.size(), 1);
+  EXPECT_EQ(model->subgraphs[0]->tensors.size(), 3);
+  EXPECT_EQ(model->subgraphs[0]->inputs.size(), 1);
+  EXPECT_EQ(model->subgraphs[0]->inputs[0], 0);
+  EXPECT_EQ(model->subgraphs[0]->outputs.size(), 1);
+  EXPECT_EQ(model->subgraphs[0]->outputs[0], 2);
+  EXPECT_EQ(model->operator_codes.size(), 3);
+  EXPECT_EQ(model->subgraphs[0]->operators.size(), 2);
+  EXPECT_EQ(model->subgraphs[0]->operators[0]->opcode_index, 0);
+  EXPECT_EQ(model->subgraphs[0]->operators[1]->opcode_index, 1);
+
+  auto quant_op = model->subgraphs[0]->operators[0].get();
+  auto input = model->subgraphs[0]->tensors[quant_op->inputs[0]].get();
+  EXPECT_EQ(input->name, "tensor_0");
+  EXPECT_EQ(input->type, TensorType_UINT8);
+  EXPECT_FLOAT_EQ(input->quantization->scale[0], 0.35);
+  EXPECT_EQ(input->quantization->zero_point[0], 156);
+
+  auto fc_op = model->subgraphs[0]->operators[1].get();
+  auto output = model->subgraphs[0]->tensors[fc_op->outputs[0]].get();
+  EXPECT_EQ(output->name, "tensor_2");
+  EXPECT_EQ(output->type, TensorType_INT8);
+  EXPECT_FLOAT_EQ(output->quantization->scale[0], 0.12);
+  EXPECT_EQ(output->quantization->zero_point[0], 50);
+}
+
 TEST(ModelInterface, Uint8SingleInputOutput) {
-  auto model = CreateQuantizedModelSingleInputOutput();
+  auto model = CreateQuantizedModelSingleInputOutput(TensorType_INT8);
 
   // Change model type.
   flatbuffers::FlatBufferBuilder builder;
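
The 156 expected above (and the 178 just below) follow from converting an int8 interface to uint8: the scale is kept and the zero point is shifted by +128, which matches the zero points of 28 and 50 set in the model builders. A quick check of the arithmetic, not library code:

    #include <cstdio>

    int main() {
      const int int8_input_zero_point = 28;   // from the single-input test model
      const int int8_output_zero_point = 50;
      std::printf("%d\n", int8_input_zero_point + 128);   // 156
      std::printf("%d\n", int8_output_zero_point + 128);  // 178
      return 0;
    }
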
@@ -323,81 +461,8 @@ TEST(ModelInterface, Uint8SingleInputOutput) {
   EXPECT_EQ(output->quantization->zero_point[0], 178);
 }
 
-TEST(ModelInterface, Int8SingleInputOutput) {
-  auto model = CreateQuantizedModelSingleInputOutput();
-
-  // Change model type.
-  flatbuffers::FlatBufferBuilder builder;
-  EXPECT_EQ(ModifyModelInterface(&builder, model.get(), TensorType_INT8,
-                                 TensorType_INT8),
-            kTfLiteOk);
-
-  // Verify results.
-  EXPECT_EQ(model->subgraphs.size(), 1);
-  // TODO(mnatraj): The float input tensor has not been removed.
-  // EXPECT_EQ(model->subgraphs[0]->tensors.size(), 2);
-  EXPECT_EQ(model->subgraphs[0]->tensors.size(), 3);
-  EXPECT_EQ(model->subgraphs[0]->inputs.size(), 1);
-  EXPECT_EQ(model->subgraphs[0]->inputs[0], 1);
-  EXPECT_EQ(model->subgraphs[0]->outputs.size(), 1);
-  EXPECT_EQ(model->subgraphs[0]->outputs[0], 2);
-  EXPECT_EQ(model->operator_codes.size(), 3);
-  EXPECT_EQ(model->subgraphs[0]->operators.size(), 1);
-  EXPECT_EQ(model->subgraphs[0]->operators[0]->opcode_index, 1);
-
-  auto fc_op = model->subgraphs[0]->operators[0].get();
-
-  auto input = model->subgraphs[0]->tensors[fc_op->inputs[0]].get();
-  EXPECT_EQ(input->name, "tensor_1");
-  EXPECT_EQ(input->type, TensorType_INT8);
-  EXPECT_FLOAT_EQ(input->quantization->scale[0], 0.35);
-  EXPECT_EQ(input->quantization->zero_point[0], 28);
-
-  auto output = model->subgraphs[0]->tensors[fc_op->outputs[0]].get();
-  EXPECT_EQ(output->name, "tensor_2");
-  EXPECT_EQ(output->type, TensorType_INT8);
-  EXPECT_FLOAT_EQ(output->quantization->scale[0], 0.12);
-  EXPECT_EQ(output->quantization->zero_point[0], 50);
-}
-
-TEST(ModelInterface, MixedTypeSingleInputOutput) {
-  auto model = CreateQuantizedModelSingleInputOutput();
-
-  // Change model type.
-  flatbuffers::FlatBufferBuilder builder;
-  EXPECT_EQ(ModifyModelInterface(&builder, model.get(), TensorType_UINT8,
-                                 TensorType_INT8),
-            kTfLiteOk);
-
-  // Verify results.
-  EXPECT_EQ(model->subgraphs.size(), 1);
-  EXPECT_EQ(model->subgraphs[0]->tensors.size(), 3);
-  EXPECT_EQ(model->subgraphs[0]->inputs.size(), 1);
-  EXPECT_EQ(model->subgraphs[0]->inputs[0], 0);
-  EXPECT_EQ(model->subgraphs[0]->outputs.size(), 1);
-  EXPECT_EQ(model->subgraphs[0]->outputs[0], 2);
-  EXPECT_EQ(model->operator_codes.size(), 3);
-  EXPECT_EQ(model->subgraphs[0]->operators.size(), 2);
-  EXPECT_EQ(model->subgraphs[0]->operators[0]->opcode_index, 0);
-  EXPECT_EQ(model->subgraphs[0]->operators[1]->opcode_index, 1);
-
-  auto quant_op = model->subgraphs[0]->operators[0].get();
-  auto input = model->subgraphs[0]->tensors[quant_op->inputs[0]].get();
-  EXPECT_EQ(input->name, "tensor_0");
-  EXPECT_EQ(input->type, TensorType_UINT8);
-  EXPECT_FLOAT_EQ(input->quantization->scale[0], 0.35);
-  EXPECT_EQ(input->quantization->zero_point[0], 156);
-
-  auto fc_op = model->subgraphs[0]->operators[1].get();
-  auto output = model->subgraphs[0]->tensors[fc_op->outputs[0]].get();
-  EXPECT_EQ(output->name, "tensor_2");
-  EXPECT_EQ(output->type, TensorType_INT8);
-  EXPECT_FLOAT_EQ(output->quantization->scale[0], 0.12);
-  EXPECT_EQ(output->quantization->zero_point[0], 50);
-}
-
 TEST(ModelInterface, Uint8MutipleInputOutput) {
-  auto model = CreateQuantizedModelMultipleInputOutput();
+  auto model = CreateQuantizedModelMultipleInputOutput(TensorType_INT8);
 
   // Change model type.
   flatbuffers::FlatBufferBuilder builder;
@@ -454,7 +519,7 @@ TEST(ModelInterface, Uint8MutipleInputOutput) {
 }
 
 TEST(ModelInterface, Int8MutipleInputOutput) {
-  auto model = CreateQuantizedModelMultipleInputOutput();
+  auto model = CreateQuantizedModelMultipleInputOutput(TensorType_INT8);
 
   // Change model type.
   flatbuffers::FlatBufferBuilder builder;