diff --git a/tensorflow/lite/tools/optimize/BUILD b/tensorflow/lite/tools/optimize/BUILD
index c74c3f495d3..c3318f1ab26 100644
--- a/tensorflow/lite/tools/optimize/BUILD
+++ b/tensorflow/lite/tools/optimize/BUILD
@@ -17,6 +17,7 @@ cc_library(
     srcs = ["modify_model_interface.cc"],
     hdrs = ["modify_model_interface.h"],
     deps = [
+        ":model_utils",
         "//tensorflow/lite:framework",
         "//tensorflow/lite/c:common",
         "//tensorflow/lite/kernels/internal:compatibility",
diff --git a/tensorflow/lite/tools/optimize/modify_model_interface.cc b/tensorflow/lite/tools/optimize/modify_model_interface.cc
index 7d51bc03434..bc1e9cbe5a3 100644
--- a/tensorflow/lite/tools/optimize/modify_model_interface.cc
+++ b/tensorflow/lite/tools/optimize/modify_model_interface.cc
@@ -26,6 +26,7 @@ limitations under the License.
 #include "tensorflow/lite/kernels/internal/compatibility.h"
 #include "tensorflow/lite/model.h"
 #include "tensorflow/lite/schema/schema_generated.h"
+#include "tensorflow/lite/tools/optimize/model_utils.h"
 
 namespace tflite {
 namespace optimize {
@@ -360,5 +361,108 @@ TfLiteStatus ModifyModelInterface(const string& input_file,
   return kTfLiteOk;
 }
 
+namespace {
+void AddUint8Dequant(
+    const std::unordered_map<string, std::pair<float, int32_t>>& quant_params,
+    ModelT* model) {
+  for (size_t subgraph_idx = 0; subgraph_idx < model->subgraphs.size();
+       subgraph_idx++) {
+    SubGraphT* subgraph = model->subgraphs.at(subgraph_idx).get();
+    // Add dequant to input tensors.
+    for (size_t input_idx = 0; input_idx < subgraph->inputs.size();
+         input_idx++) {
+      const int32_t tensor_idx = subgraph->inputs[input_idx];
+      TensorT* tensor = subgraph->tensors[tensor_idx].get();
+      if (tensor->type != TensorType_FLOAT32) {
+        continue;
+      }
+      if (quant_params.find(tensor->name) != quant_params.end()) {
+        // Add uint8 tensor
+        const string added_tensor_name = tensor->name + "_uint8";
+        std::unique_ptr<TensorT> leading_op_input;
+        const std::pair<float, int32_t>& provided_quant_params =
+            quant_params.at(string(tensor->name));
+        utils::MakeTensorWithQuantParam(
+            added_tensor_name, tensor->shape, TensorType_UINT8,
+            provided_quant_params.first, provided_quant_params.second,
+            &leading_op_input);
+        const int32_t leading_op_input_idx = subgraph->tensors.size();
+        subgraph->tensors.push_back(std::move(leading_op_input));
+
+        // Create the leading op, which is a Dequantize op.
+        std::unique_ptr<OperatorT> leading_op;
+        utils::MakeDequantizeOperator(model, &leading_op, leading_op_input_idx,
+                                      tensor_idx);
+
+        // Insert the new op at the start of the model.
+        subgraph->operators.insert(subgraph->operators.begin(),
+                                   std::move(leading_op));
+      }
+    }
+  }
+}
+
+void AddUint8Quant(
+    const std::unordered_map<string, std::pair<float, int32_t>>& quant_params,
+    ModelT* model) {
+  for (size_t subgraph_idx = 0; subgraph_idx < model->subgraphs.size();
+       subgraph_idx++) {
+    SubGraphT* subgraph = model->subgraphs.at(subgraph_idx).get();
+    // Add quant to output tensors.
+    for (size_t output_idx = 0; output_idx < subgraph->outputs.size();
+         output_idx++) {
+      const int32_t tensor_idx = subgraph->outputs[output_idx];
+      TensorT* tensor = subgraph->tensors[tensor_idx].get();
+      if (tensor->type != TensorType_FLOAT32) {
+        continue;
+      }
+      if (quant_params.find(tensor->name) != quant_params.end()) {
+        // Add uint8 tensor
+        const string added_tensor_name = tensor->name + "_uint8";
+        std::unique_ptr<TensorT> tailing_op_output;
+        const std::pair<float, int32_t>& provided_quant_params =
+            quant_params.at(string(tensor->name));
+        utils::MakeTensorWithQuantParam(
+            added_tensor_name, tensor->shape, TensorType_UINT8,
+            provided_quant_params.first, provided_quant_params.second,
+            &tailing_op_output);
+        const int32_t tailing_op_output_idx = subgraph->tensors.size();
+        subgraph->tensors.push_back(std::move(tailing_op_output));
+
+        // Create the tailing op, which is a Quantize op.
+        std::unique_ptr<OperatorT> tailing_op;
+        utils::MakeQuantizeOperator(model, &tailing_op, tensor_idx,
+                                    tailing_op_output_idx);
+
+        // Insert the new op at the end of the model.
+        subgraph->operators.push_back(std::move(tailing_op));
+      }
+    }
+  }
+}
+}  // namespace
+
+TfLiteStatus Uint8QuantizeModelInputsOutputs(
+    flatbuffers::FlatBufferBuilder* builder, const Model* input_model,
+    const std::unordered_map<string, std::pair<float, int32_t>>&
+        input_quant_params,
+    const std::unordered_map<string, std::pair<float, int32_t>>&
+        output_quant_params) {
+  std::unique_ptr<ModelT> model;
+  model.reset(input_model->UnPack());
+  // Add Dequant for inputs.
+  AddUint8Dequant(input_quant_params, model.get());
+
+  // Add Quant for outputs.
+  AddUint8Quant(output_quant_params, model.get());
+
+  // Output model.
+  flatbuffers::Offset<Model> output_model_location =
+      Model::Pack(*builder, model.get());
+  FinishModelBuffer(*builder, output_model_location);
+
+  return kTfLiteOk;
+}
+
 }  // namespace optimize
 }  // namespace tflite
diff --git a/tensorflow/lite/tools/optimize/modify_model_interface.h b/tensorflow/lite/tools/optimize/modify_model_interface.h
index cfe4f41ff90..170e0e73a67 100644
--- a/tensorflow/lite/tools/optimize/modify_model_interface.h
+++ b/tensorflow/lite/tools/optimize/modify_model_interface.h
@@ -39,6 +39,24 @@ TfLiteStatus ModifyModelInterface(const string& input_file,
                                   const TensorType& input_type,
                                   const TensorType& output_type);
 
+// Adds uint8 dequantize ops for specified inputs and uint8 quantize ops for
+// specified outputs of a float model. The scale and zero point of the uint8
+// tensors are provided through quant_params.
+// - input_quant_params has a map between tensor name and the
+//   <scale, zero_point> pair for inputs.
+// - output_quant_params has a map between tensor name and the
+//   <scale, zero_point> pair for outputs.
+// For the input/output tensors of the model, if quantization parameters are
+// not provided, that tensor is not affected.
+//
+// Note: This is a private API, subject to change.
+TfLiteStatus Uint8QuantizeModelInputsOutputs(
+    flatbuffers::FlatBufferBuilder* builder, const Model* input_model,
+    const std::unordered_map<string, std::pair<float, int32_t>>&
+        input_quant_params,
+    const std::unordered_map<string, std::pair<float, int32_t>>&
+        output_quant_params);
+
 }  // namespace optimize
 }  // namespace tflite
diff --git a/tensorflow/lite/tools/optimize/modify_model_interface_test.cc b/tensorflow/lite/tools/optimize/modify_model_interface_test.cc
index 01d0775c953..7e2744fb1e1 100644
--- a/tensorflow/lite/tools/optimize/modify_model_interface_test.cc
+++ b/tensorflow/lite/tools/optimize/modify_model_interface_test.cc
@@ -14,6 +14,9 @@ limitations under the License.
 ==============================================================================*/
 #include "tensorflow/lite/tools/optimize/modify_model_interface.h"
 
+#include <memory>
+
+#include <gmock/gmock.h>
 #include <gtest/gtest.h>
 #include "absl/memory/memory.h"
 #include "tensorflow/lite/model.h"
@@ -23,6 +26,8 @@ namespace tflite {
 namespace optimize {
 namespace {
 
+using ::testing::ElementsAreArray;
+
 // Create a model with 1 quant, 1 FC, 1 dequant
 std::unique_ptr<ModelT> CreateModelSingleInputOutput() {
   auto model = absl::make_unique<ModelT>();
@@ -238,7 +243,53 @@ std::unique_ptr<ModelT> CreateModelMultipleInputOutput() {
   return model;
 }
 
-TEST(ModelInference, Uint8SingleInputOutput) {
+// Create a model with 1 FC.
+std::unique_ptr<ModelT> CreateFloatModel() {
+  auto model = absl::make_unique<ModelT>();
+  auto subgraph = absl::make_unique<SubGraphT>();
+  auto buffer = absl::make_unique<BufferT>();
+  auto fc_op_code = absl::make_unique<OperatorCodeT>();
+  auto fc_op = absl::make_unique<OperatorT>();
+
+  model->subgraphs.push_back(std::move(subgraph));
+
+  // Op code
+  fc_op_code->builtin_code = BuiltinOperator_FULLY_CONNECTED;
+  fc_op_code->version = 2;
+
+  // Op.
+  fc_op->opcode_index = 0;
+  fc_op->inputs = {0};
+  fc_op->outputs = {1};
+
+  model->subgraphs[0]->operators.push_back(std::move(fc_op));
+  model->operator_codes.push_back(std::move(fc_op_code));
+
+  // Model input/output.
+  model->subgraphs[0]->inputs = {0};
+  model->subgraphs[0]->outputs = {1};
+
+  // Tensors
+  auto tensor_0 = absl::make_unique<TensorT>();
+  tensor_0->name = "tensor_0";
+  tensor_0->shape = {};
+  tensor_0->type = TensorType_FLOAT32;
+
+  auto tensor_1 = absl::make_unique<TensorT>();
+  tensor_1->name = "tensor_1";
+  tensor_1->shape = {};
+  tensor_1->type = TensorType_FLOAT32;
+
+  model->subgraphs[0]->tensors.push_back(std::move(tensor_0));
+  model->subgraphs[0]->tensors.push_back(std::move(tensor_1));
+
+  // Buffer
+  model->buffers.push_back(std::move(buffer));
+
+  return model;
+}
+
+TEST(ModelInterface, Uint8SingleInputOutput) {
   auto model = CreateModelSingleInputOutput();
 
   // Ops.
@@ -277,7 +328,7 @@ TEST(ModelInference, Uint8SingleInputOutput) {
   EXPECT_EQ(model->subgraphs[0]->operators[2]->opcode_index, 0);
 }
 
-TEST(ModelInference, Int8SingleInputOutput) {
+TEST(ModelInterface, Int8SingleInputOutput) {
   auto model = CreateModelSingleInputOutput();
 
   // Change model type.
@@ -299,7 +350,7 @@ TEST(ModelInference, Int8SingleInputOutput) {
   EXPECT_EQ(model->subgraphs[0]->outputs[0], 2);
 }
 
-TEST(ModelInference, Uint8MutipleInputOutput) {
+TEST(ModelInterface, Uint8MutipleInputOutput) {
   auto model = CreateModelMultipleInputOutput();
 
   // Ops.
@@ -362,7 +413,7 @@ TEST(ModelInference, Uint8MutipleInputOutput) {
   EXPECT_EQ(model->subgraphs[0]->operators[4]->opcode_index, 0);
 }
 
-TEST(ModelInference, Int8MutipleInputOutput) {
+TEST(ModelInterface, Int8MutipleInputOutput) {
   auto model = CreateModelMultipleInputOutput();
 
   // Change model type.
@@ -413,6 +464,72 @@ TEST(ModelInference, Int8MutipleInputOutput) {
   EXPECT_EQ(model->subgraphs[0]->operators[0]->opcode_index, 1);
 }
 
+TEST(ModelInterface, Float) {
+  // Create the model.
+  std::unique_ptr<ModelT> input_model_t = CreateFloatModel();
+  flatbuffers::FlatBufferBuilder builder_temp;
+  flatbuffers::Offset<Model> output_model_location =
+      Model::Pack(builder_temp, input_model_t.get());
+  FinishModelBuffer(builder_temp, output_model_location);
+  const uint8_t* buffer_temp = builder_temp.GetBufferPointer();
+  const Model* input_model = GetModel(buffer_temp);
+
+  // Quantize the model's inputs and outputs.
+  flatbuffers::FlatBufferBuilder builder;
+  EXPECT_EQ(Uint8QuantizeModelInputsOutputs(&builder, input_model,
+                                            {{"tensor_0", {0.4, 2}}},
+                                            {{"tensor_1", {0.5, -5}}}),
+            kTfLiteOk);
+
+  const uint8_t* buffer = builder.GetBufferPointer();
+  const Model* output_model = GetModel(buffer);
+  std::unique_ptr<ModelT> model;
+  model.reset(output_model->UnPack());
+
+  // Verify results.
+  EXPECT_EQ(model->operator_codes.size(), 3);
+  EXPECT_EQ(model->subgraphs.size(), 1);
+  EXPECT_EQ(model->subgraphs[0]->operators.size(), 3);
+  EXPECT_EQ(model->subgraphs[0]->tensors.size(), 4);
+  EXPECT_EQ(model->buffers.size(), 1);
+
+  // Ops.
+  EXPECT_EQ(model->operator_codes[0]->builtin_code,
+            BuiltinOperator_FULLY_CONNECTED);
+  EXPECT_EQ(model->operator_codes[1]->builtin_code, BuiltinOperator_DEQUANTIZE);
+  EXPECT_EQ(model->operator_codes[2]->builtin_code, BuiltinOperator_QUANTIZE);
+
+  EXPECT_EQ(model->subgraphs[0]->operators[0]->opcode_index, 1);
+  EXPECT_EQ(model->subgraphs[0]->operators[1]->opcode_index, 0);
+  EXPECT_EQ(model->subgraphs[0]->operators[2]->opcode_index, 2);
+
+  EXPECT_THAT(model->subgraphs[0]->operators[0]->inputs, ElementsAreArray({2}));
+  EXPECT_THAT(model->subgraphs[0]->operators[0]->outputs,
+              ElementsAreArray({0}));
+  EXPECT_THAT(model->subgraphs[0]->operators[1]->inputs, ElementsAreArray({0}));
+  EXPECT_THAT(model->subgraphs[0]->operators[1]->outputs,
+              ElementsAreArray({1}));
+  EXPECT_THAT(model->subgraphs[0]->operators[2]->inputs, ElementsAreArray({1}));
+  EXPECT_THAT(model->subgraphs[0]->operators[2]->outputs,
+              ElementsAreArray({3}));
+
+  // Tensors.
+  EXPECT_EQ(model->subgraphs[0]->tensors[0]->name, "tensor_0");
+  EXPECT_EQ(model->subgraphs[0]->tensors[0]->type, TensorType_FLOAT32);
+  EXPECT_EQ(model->subgraphs[0]->tensors[1]->name, "tensor_1");
+  EXPECT_EQ(model->subgraphs[0]->tensors[1]->type, TensorType_FLOAT32);
+
+  EXPECT_EQ(model->subgraphs[0]->tensors[2]->name, "tensor_0_uint8");
+  EXPECT_EQ(model->subgraphs[0]->tensors[2]->type, TensorType_UINT8);
+  EXPECT_FLOAT_EQ(model->subgraphs[0]->tensors[2]->quantization->scale[0], 0.4);
+  EXPECT_EQ(model->subgraphs[0]->tensors[2]->quantization->zero_point[0], 2);
+
+  EXPECT_EQ(model->subgraphs[0]->tensors[3]->name, "tensor_1_uint8");
+  EXPECT_EQ(model->subgraphs[0]->tensors[3]->type, TensorType_UINT8);
+  EXPECT_FLOAT_EQ(model->subgraphs[0]->tensors[3]->quantization->scale[0], 0.5);
+  EXPECT_EQ(model->subgraphs[0]->tensors[3]->quantization->zero_point[0], -5);
+}
+
 }  // namespace
 }  // namespace optimize
 }  // namespace tflite
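---

For reviewers, a minimal sketch (not part of the diff) of how the new entry point is expected to be driven. The wrapper function, tensor names, and quantization values below are hypothetical placeholders; only `Uint8QuantizeModelInputsOutputs` comes from this change.

```cpp
#include <cstdint>
#include <string>
#include <unordered_map>
#include <utility>

#include "tensorflow/lite/model.h"
#include "tensorflow/lite/tools/optimize/modify_model_interface.h"

// Sketch: make the interface of a serialized float model uint8.
// `float_model` must stay alive for the call; `builder` receives the result.
TfLiteStatus QuantizeInterfaceSketch(const tflite::Model* float_model,
                                     flatbuffers::FlatBufferBuilder* builder) {
  // Hypothetical tensor names; use the model's real input/output names and
  // the <scale, zero_point> values obtained from calibration.
  std::unordered_map<std::string, std::pair<float, int32_t>> input_params = {
      {"input_tensor", {0.4f, 2}}};
  std::unordered_map<std::string, std::pair<float, int32_t>> output_params = {
      {"output_tensor", {0.5f, -5}}};

  // Adds a uint8 tensor plus a leading Dequantize op per listed input and a
  // uint8 tensor plus a trailing Quantize op per listed output; tensors not
  // listed in the maps are left untouched.
  return tflite::optimize::Uint8QuantizeModelInputsOutputs(
      builder, float_model, input_params, output_params);
}
```

The graph body stays float; only the boundary tensors gain uint8 wrappers, which is what the `ModelInterface.Float` test above verifies.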