From d33cb73389c4198c01d8dac55cbbd6620abe7d4b Mon Sep 17 00:00:00 2001 From: Feng Liu Date: Wed, 13 May 2020 23:48:03 -0700 Subject: [PATCH] Expose inference type in the mlir quantizer This is to prepare the 16 bits activation quantization release. The data type specified by this flag is only applied on the activations. PiperOrigin-RevId: 311478782 Change-Id: I5f63f0508011cc0b1b47a0debb35c17d3284eae9 --- tensorflow/compiler/mlir/lite/ir/tfl_ops.td | 6 ++-- .../lite/quantization/lite/quantize_model.cc | 10 ++++-- .../lite/quantization/lite/quantize_model.h | 4 ++- .../lite/quantization/lite/tfl_quantizer.cc | 3 +- tensorflow/lite/python/convert.py | 7 ++-- tensorflow/lite/python/lite_v2_test.py | 36 +++++++++++++++++++ tensorflow/lite/python/wrap_toco.py | 6 ++-- tensorflow/lite/toco/python/BUILD | 1 + .../lite/toco/python/toco_python_api.cc | 21 +++++++++-- tensorflow/lite/toco/python/toco_python_api.h | 2 +- .../python/lite/toco_python_api_wrapper.cc | 7 ++-- 11 files changed, 84 insertions(+), 19 deletions(-) diff --git a/tensorflow/compiler/mlir/lite/ir/tfl_ops.td b/tensorflow/compiler/mlir/lite/ir/tfl_ops.td index 8a949a45e2d..a585b8e1520 100644 --- a/tensorflow/compiler/mlir/lite/ir/tfl_ops.td +++ b/tensorflow/compiler/mlir/lite/ir/tfl_ops.td @@ -414,9 +414,9 @@ class TFL_ConvOp : }]; let arguments = ( - ins TFL_TensorOf<[F32, QI8, QUI8]>:$input, + ins TFL_TensorOf<[F32, QI8, QUI8, QI16]>:$input, TFL_TensorOf<[F32, QI8, QUI8]>:$filter, - TFL_TensorOfOrNone<[F32, I32]>:$bias, + TFL_TensorOfOrNone<[F32, I32, I64]>:$bias, I32Attr:$dilation_h_factor, I32Attr:$dilation_w_factor, TFL_AFAttr:$fused_activation_function, @@ -425,7 +425,7 @@ class TFL_ConvOp : I32Attr:$stride_w ); - let results = (outs TFL_TensorOf<[F32, QI8, QUI8]>:$output); + let results = (outs TFL_TensorOf<[F32, QI8, QUI8, QI16]>:$output); let hasOptions = 0b1; } diff --git a/tensorflow/compiler/mlir/lite/quantization/lite/quantize_model.cc b/tensorflow/compiler/mlir/lite/quantization/lite/quantize_model.cc index 0ac3fa419bc..a2e3c065113 100644 --- a/tensorflow/compiler/mlir/lite/quantization/lite/quantize_model.cc +++ b/tensorflow/compiler/mlir/lite/quantization/lite/quantize_model.cc @@ -30,6 +30,7 @@ limitations under the License. #include "tensorflow/compiler/mlir/lite/utils/convert_type.h" #include "tensorflow/compiler/mlir/tensorflow/utils/error_util.h" #include "tensorflow/core/framework/types.pb.h" +#include "tensorflow/lite/schema/schema_generated.h" namespace mlir { namespace lite { @@ -38,6 +39,7 @@ namespace lite { TfLiteStatus QuantizeModel( const tflite::ModelT& input_model, const tflite::TensorType& input_type, const tflite::TensorType& output_type, + const tflite::TensorType& inference_type, const std::unordered_set& operator_names, bool disable_per_channel, bool fully_quantize, flatbuffers::FlatBufferBuilder* builder, @@ -73,7 +75,7 @@ TfLiteStatus QuantizeModel( // Apply quantization passes PassManager pm(module->getContext()); TFL::QuantizationSpecs quant_specs; - quant_specs.inference_type = tensorflow::DT_QINT8; + quant_specs.inference_type = tflite::TflTypeToTfType(inference_type); quant_specs.post_training_quantization = true; quant_specs.disable_per_channel = disable_per_channel; @@ -81,8 +83,10 @@ TfLiteStatus QuantizeModel( auto input_tf_type = tflite::TflTypeToTfType(input_type); if (input_tf_type == tensorflow::DT_FLOAT) { emit_adaptor = true; - } else if (input_tf_type == tensorflow::DT_UINT8) { - quant_specs.inference_type = tensorflow::DT_QUINT8; + } else if (input_tf_type == tensorflow::DT_UINT8 || + input_tf_type == tensorflow::DT_INT8 || + input_tf_type == tensorflow::DT_INT16) { + quant_specs.inference_type = input_tf_type; } pm.addPass(TFL::CreatePrepareQuantizePass(quant_specs)); diff --git a/tensorflow/compiler/mlir/lite/quantization/lite/quantize_model.h b/tensorflow/compiler/mlir/lite/quantization/lite/quantize_model.h index 578aa6438de..d60df56b473 100644 --- a/tensorflow/compiler/mlir/lite/quantization/lite/quantize_model.h +++ b/tensorflow/compiler/mlir/lite/quantization/lite/quantize_model.h @@ -26,11 +26,13 @@ namespace mlir { namespace lite { // Quantize the `input_model` and write the result to a flatbuffer `builder`. -// The `input_type` and `output_type` can be float32/qint8/int8. +// The `input_type`, `output_type` and `inference_type` can be +// float32/qint8/int8/int16. // Return partially quantized model if `fully_quantize` is false. TfLiteStatus QuantizeModel( const tflite::ModelT& input_model, const tflite::TensorType& input_type, const tflite::TensorType& output_type, + const tflite::TensorType& inference_type, const std::unordered_set& operator_names, bool disable_per_channel, bool fully_quantize, flatbuffers::FlatBufferBuilder* builder, diff --git a/tensorflow/compiler/mlir/lite/quantization/lite/tfl_quantizer.cc b/tensorflow/compiler/mlir/lite/quantization/lite/tfl_quantizer.cc index 77bd87a3c03..5bd1b71e631 100644 --- a/tensorflow/compiler/mlir/lite/quantization/lite/tfl_quantizer.cc +++ b/tensorflow/compiler/mlir/lite/quantization/lite/tfl_quantizer.cc @@ -46,7 +46,8 @@ TfLiteStatus QuantizeAnnotatedModel(llvm::StringRef buffer, tflite::StderrReporter error_reporter; return mlir::lite::QuantizeModel( - *model, tflite::TensorType_INT8, tflite::TensorType_INT8, {}, + *model, tflite::TensorType_INT8, tflite::TensorType_INT8, + tflite::TensorType_INT8, {}, /*disable_per_channel=*/false, /*fully_quantize=*/true, builder, &error_reporter); } diff --git a/tensorflow/lite/python/convert.py b/tensorflow/lite/python/convert.py index ae70afd6962..6b7a32f1bcc 100644 --- a/tensorflow/lite/python/convert.py +++ b/tensorflow/lite/python/convert.py @@ -108,7 +108,8 @@ class ConverterError(Exception): pass -def mlir_quantize(input_data_str, disable_per_channel=False): +def mlir_quantize(input_data_str, disable_per_channel=False, + inference_type=_types_pb2.INT8): """Quantize `input_data_str` with calibration results. Args: @@ -116,13 +117,15 @@ def mlir_quantize(input_data_str, disable_per_channel=False): calibration results). disable_per_channel: Bool indicating whether to do per-channel or per-tensor quantization + inference_type: Data type for the activations. The default value is int8. Returns: Quantized model in serialized form (e.g. a TFLITE model) with floating-point inputs and outputs. """ return wrap_toco.wrapped_experimental_mlir_quantize(input_data_str, - disable_per_channel) + disable_per_channel, + inference_type) def mlir_sparsify(input_data_str): diff --git a/tensorflow/lite/python/lite_v2_test.py b/tensorflow/lite/python/lite_v2_test.py index 4768892f359..9af37df2975 100644 --- a/tensorflow/lite/python/lite_v2_test.py +++ b/tensorflow/lite/python/lite_v2_test.py @@ -29,7 +29,9 @@ import tensorflow as tf from tensorflow.lite.python import lite from tensorflow.lite.python import lite_v2_test_util +from tensorflow.lite.python.convert import mlir_quantize from tensorflow.lite.python.interpreter import Interpreter +from tensorflow.lite.toco import types_pb2 as _types_pb2 from tensorflow.python.framework import ops from tensorflow.python.framework import test_util from tensorflow.python.keras.layers import recurrent @@ -204,6 +206,40 @@ class FromConcreteFunctionTest(lite_v2_test_util.ModelTest): # Ensure that the quantized weights tflite model is smaller. self.assertLess(len(quantized_tflite), len(float_tflite)) + def testCalibrateAndQuantizeBuiltinInt16(self): + func, calibration_gen = self._getCalibrationQuantizeModel() + + # Convert float model. + float_converter = lite.TFLiteConverterV2.from_concrete_functions([func]) + float_tflite = float_converter.convert() + self.assertTrue(float_tflite) + + converter = lite.TFLiteConverterV2.from_concrete_functions([func]) + # TODO(b/156309549): We should add INT16 to the builtin types. + converter.target_spec.supported_ops = [ + lite.OpsSet.TFLITE_BUILTINS_INT8 + ] + converter.representative_dataset = calibration_gen + converter._experimental_calibrate_only = True + calibrated_tflite = converter.convert() + quantized_tflite = mlir_quantize(calibrated_tflite, + inference_type=_types_pb2.QUANTIZED_INT16) + + self.assertTrue(quantized_tflite) + + # The default input and output types should be float. + interpreter = Interpreter(model_content=quantized_tflite) + interpreter.allocate_tensors() + input_details = interpreter.get_input_details() + self.assertLen(input_details, 1) + self.assertEqual(np.float32, input_details[0]['dtype']) + output_details = interpreter.get_output_details() + self.assertLen(output_details, 1) + self.assertEqual(np.float32, output_details[0]['dtype']) + + # Ensure that the quantized weights tflite model is smaller. + self.assertLess(len(quantized_tflite), len(float_tflite)) + def _getTrainingTimeQuantizedModel(self): class QLinear(tf.keras.layers.Layer): diff --git a/tensorflow/lite/python/wrap_toco.py b/tensorflow/lite/python/wrap_toco.py index 3c1f98ff42d..8f72cc8cbbd 100644 --- a/tensorflow/lite/python/wrap_toco.py +++ b/tensorflow/lite/python/wrap_toco.py @@ -43,10 +43,12 @@ def wrapped_get_potentially_supported_ops(): return _pywrap_toco_api.TocoGetPotentiallySupportedOps() -def wrapped_experimental_mlir_quantize(input_data_str, disable_per_channel): +def wrapped_experimental_mlir_quantize(input_data_str, disable_per_channel, + inference_type): """Wraps experimental mlir quantize model.""" return _pywrap_toco_api.ExperimentalMlirQuantizeModel(input_data_str, - disable_per_channel) + disable_per_channel, + inference_type) def wrapped_experimental_mlir_sparsify(input_data_str): diff --git a/tensorflow/lite/toco/python/BUILD b/tensorflow/lite/toco/python/BUILD index bea582d83a5..7dfa714d1d6 100644 --- a/tensorflow/lite/toco/python/BUILD +++ b/tensorflow/lite/toco/python/BUILD @@ -54,6 +54,7 @@ cc_library( "//tensorflow/compiler/mlir/lite/python:saved_model_to_tfl_flatbuffer", "//tensorflow/compiler/mlir/lite/quantization/lite:quantize_model", "//tensorflow/compiler/mlir/lite/sparsity:sparsify_model", + "//tensorflow/lite/toco:types_proto_cc", ] + select({ # This is required when running `tflite_convert` from `bazel`. # It requires to link with TensorFlow Ops to get the op definitions. diff --git a/tensorflow/lite/toco/python/toco_python_api.cc b/tensorflow/lite/toco/python/toco_python_api.cc index aafd14f9da8..441aabf0ffe 100644 --- a/tensorflow/lite/toco/python/toco_python_api.cc +++ b/tensorflow/lite/toco/python/toco_python_api.cc @@ -41,6 +41,7 @@ limitations under the License. #include "tensorflow/lite/toco/toco_tooling.h" #include "tensorflow/lite/toco/toco_types.h" #include "tensorflow/lite/toco/tooling_util.h" +#include "tensorflow/lite/toco/types.pb.h" namespace toco { @@ -229,7 +230,7 @@ PyObject* TocoGetPotentiallySupportedOps() { } PyObject* MlirQuantizeModel(PyObject* data, bool disable_per_channel, - bool fully_quantize) { + bool fully_quantize, int inference_type) { using tflite::interpreter_wrapper::PythonErrorReporter; char* buf = nullptr; Py_ssize_t length; @@ -249,11 +250,25 @@ PyObject* MlirQuantizeModel(PyObject* data, bool disable_per_channel, auto tflite_model = absl::make_unique(); model->GetModel()->UnPackTo(tflite_model.get(), nullptr); + tflite::TensorType inference_tensor_type; + switch (inference_type) { + case toco::IODataType::QUANTIZED_INT16: + inference_tensor_type = tflite::TensorType_INT16; + break; + case toco::IODataType::QUANTIZED_UINT8: + inference_tensor_type = tflite::TensorType_UINT8; + break; + case toco::IODataType::INT8: + inference_tensor_type = tflite::TensorType_INT8; + break; + default: + return nullptr; + } flatbuffers::FlatBufferBuilder builder; auto status = mlir::lite::QuantizeModel( *tflite_model, tflite::TensorType::TensorType_FLOAT32, - tflite::TensorType::TensorType_FLOAT32, {}, disable_per_channel, - fully_quantize, &builder, error_reporter.get()); + tflite::TensorType::TensorType_FLOAT32, inference_tensor_type, {}, + disable_per_channel, fully_quantize, &builder, error_reporter.get()); if (status != kTfLiteOk) { error_reporter->exception(); diff --git a/tensorflow/lite/toco/python/toco_python_api.h b/tensorflow/lite/toco/python/toco_python_api.h index 7afb097fd4a..058ae9fb942 100644 --- a/tensorflow/lite/toco/python/toco_python_api.h +++ b/tensorflow/lite/toco/python/toco_python_api.h @@ -44,7 +44,7 @@ PyObject* TocoGetPotentiallySupportedOps(); // is specified by the calibration data are not sufficient to quantize the // model. PyObject* MlirQuantizeModel(PyObject* data, bool disable_per_channel, - bool fully_quantize); + bool fully_quantize, int inference_type); // Sparsifies model to encode sparse tensors with proper format. Throws error if // sparsification fails. diff --git a/tensorflow/python/lite/toco_python_api_wrapper.cc b/tensorflow/python/lite/toco_python_api_wrapper.cc index e6e0e111ec4..b77200a3bee 100644 --- a/tensorflow/python/lite/toco_python_api_wrapper.cc +++ b/tensorflow/python/lite/toco_python_api_wrapper.cc @@ -57,12 +57,13 @@ PYBIND11_MODULE(_pywrap_toco_api, m) { m.def( "ExperimentalMlirQuantizeModel", [](py::object input_contents_txt_raw, bool disable_per_channel, - bool fully_quantize) { + bool fully_quantize, int inference_type) { return tensorflow::PyoOrThrow(toco::MlirQuantizeModel( - input_contents_txt_raw.ptr(), disable_per_channel, fully_quantize)); + input_contents_txt_raw.ptr(), disable_per_channel, fully_quantize, + inference_type)); }, py::arg("input_contents_txt_raw"), py::arg("disable_per_channel") = false, - py::arg("fully_quantize") = true, + py::arg("fully_quantize") = true, py::arg("inference_type") = 9, R"pbdoc( Returns a quantized model. )pbdoc");