diff --git a/tensorflow/lite/python/convert.py b/tensorflow/lite/python/convert.py
index a5fbb88132e..9ce88ec6c96 100644
--- a/tensorflow/lite/python/convert.py
+++ b/tensorflow/lite/python/convert.py
@@ -108,15 +108,19 @@ class ConverterError(Exception):
   pass
 
 
-def mlir_quantize(input_data_str, disable_per_channel=False,
+def mlir_quantize(input_data_str,
+                  disable_per_channel=False,
+                  fully_quantize=False,
                   inference_type=_types_pb2.INT8):
   """Quantize `input_data_str` with calibration results.
 
   Args:
     input_data_str: Input data in serialized form (e.g. a TFLITE model with
-      calibration results).
-    disable_per_channel: Bool indicating whether to do per-channel or
-      per-tensor quantization
+      calibration results).
+    disable_per_channel: Bool indicating whether to do per-channel or per-tensor
+      quantization
+    fully_quantize: Bool indicating whether to fully quantize the model. Besides
+      model body, the input/output will be quantized as well.
     inference_type: Data type for the activations. The default value is int8.
 
   Returns:
@@ -125,6 +129,7 @@ def mlir_quantize(input_data_str, disable_per_channel=False,
   """
   return wrap_toco.wrapped_experimental_mlir_quantize(input_data_str,
                                                       disable_per_channel,
+                                                      fully_quantize,
                                                       inference_type)
 
 
diff --git a/tensorflow/lite/python/lite_test.py b/tensorflow/lite/python/lite_test.py
index 9611bda2594..f7f1e3cfbee 100644
--- a/tensorflow/lite/python/lite_test.py
+++ b/tensorflow/lite/python/lite_test.py
@@ -32,6 +32,7 @@ from six.moves import range
 from tensorflow.lite.python import lite
 from tensorflow.lite.python import lite_constants
 from tensorflow.lite.python.convert import ConverterError
+from tensorflow.lite.python.convert import mlir_quantize
 from tensorflow.lite.python.interpreter import Interpreter
 from tensorflow.python import keras
 from tensorflow.python.client import session
@@ -1175,9 +1176,21 @@ class FromSessionTest(TestModels, parameterized.TestCase):
     converter._experimental_new_quantizer = True
     quantized_tflite = converter.convert()
     self.assertTrue(quantized_tflite)
-
     self.assertLess(len(quantized_tflite), len(float_tflite))
 
+    # calibration only api
+    converter._experimental_calibrate_only = True
+    calibrated_tflite = converter.convert()
+    quantized_tflite = mlir_quantize(calibrated_tflite, fully_quantize=True)
+    interpreter = Interpreter(model_content=quantized_tflite)
+    interpreter.allocate_tensors()
+    input_details = interpreter.get_input_details()
+    self.assertEqual(np.int8, input_details[0]['dtype'])
+    self.assertEqual((1., 0.), input_details[0]['quantization'])
+
+    output_details = interpreter.get_output_details()
+    self.assertEqual(np.int8, output_details[0]['dtype'])
+
   def testFloatTocoConverter(self):
     """Tests deprecated test TocoConverter."""
     with ops.Graph().as_default():
diff --git a/tensorflow/lite/python/wrap_toco.py b/tensorflow/lite/python/wrap_toco.py
index 8f72cc8cbbd..c6176275d81 100644
--- a/tensorflow/lite/python/wrap_toco.py
+++ b/tensorflow/lite/python/wrap_toco.py
@@ -44,10 +44,11 @@ def wrapped_get_potentially_supported_ops():
 
 
 def wrapped_experimental_mlir_quantize(input_data_str, disable_per_channel,
-                                       inference_type):
+                                       fully_quantize, inference_type):
   """Wraps experimental mlir quantize model."""
   return _pywrap_toco_api.ExperimentalMlirQuantizeModel(input_data_str,
                                                         disable_per_channel,
+                                                        fully_quantize,
                                                         inference_type)
 
 
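Taken together, the Python-side changes above enable a two-step flow: convert with calibration only, then hand the calibrated flatbuffer to the MLIR quantizer. The sketch below mirrors the new test; the frozen graph path, tensor names, input shape, and representative dataset are placeholder assumptions for illustration, not part of this change.

import numpy as np

from tensorflow.lite.python import lite
from tensorflow.lite.python.convert import mlir_quantize

# Hypothetical model and dataset: stand-ins for a real calibration setup.
converter = lite.TFLiteConverter.from_frozen_graph(
    'frozen_graph.pb', input_arrays=['input'], output_arrays=['output'])
converter.optimizations = [lite.Optimize.DEFAULT]
converter.representative_dataset = lambda: (
    [np.random.uniform(-1, 1, size=(1, 224, 224, 3)).astype(np.float32)]
    for _ in range(10))

# Step 1: convert and calibrate only. The result is a flatbuffer that still
# has float weights but carries the recorded calibration statistics.
converter._experimental_new_quantizer = True
converter._experimental_calibrate_only = True
calibrated_tflite = converter.convert()

# Step 2: quantize with the new flag. fully_quantize=True quantizes the
# model's inputs and outputs as well, not just the model body.
quantized_tflite = mlir_quantize(calibrated_tflite, fully_quantize=True)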
diff --git a/tensorflow/lite/toco/python/toco_python_api.cc b/tensorflow/lite/toco/python/toco_python_api.cc
index 441aabf0ffe..3f3d301a40d 100644
--- a/tensorflow/lite/toco/python/toco_python_api.cc
+++ b/tensorflow/lite/toco/python/toco_python_api.cc
@@ -264,11 +264,13 @@ PyObject* MlirQuantizeModel(PyObject* data, bool disable_per_channel,
     default:
      return nullptr;
   }
+  tflite::TensorType inference_io_type =
+      fully_quantize ? inference_tensor_type : tflite::TensorType_FLOAT32;
   flatbuffers::FlatBufferBuilder builder;
   auto status = mlir::lite::QuantizeModel(
-      *tflite_model, tflite::TensorType::TensorType_FLOAT32,
-      tflite::TensorType::TensorType_FLOAT32, inference_tensor_type, {},
-      disable_per_channel, fully_quantize, &builder, error_reporter.get());
+      *tflite_model, inference_io_type, inference_io_type,
+      inference_tensor_type, {}, disable_per_channel, fully_quantize, &builder,
+      error_reporter.get());
   if (status != kTfLiteOk) {
     error_reporter->exception();
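The C++ change is the crux: the model's input/output tensor types now follow fully_quantize, staying float32 when it is false (the old behavior) and taking the requested inference_tensor_type when it is true. With fully quantized I/O, the caller has to quantize inputs and dequantize outputs by hand. A minimal sketch, assuming the quantized_tflite buffer produced in the example above; the random input is illustrative only.

import numpy as np

from tensorflow.lite.python.interpreter import Interpreter

interpreter = Interpreter(model_content=quantized_tflite)
interpreter.allocate_tensors()
inp = interpreter.get_input_details()[0]
out = interpreter.get_output_details()[0]

# With fully_quantize=True the input tensor is int8, so the float input is
# mapped through the reported (scale, zero_point) quantization parameters.
scale, zero_point = inp['quantization']
float_input = np.random.uniform(-1, 1, size=inp['shape']).astype(np.float32)
int8_input = np.clip(
    np.round(float_input / scale + zero_point), -128, 127).astype(np.int8)

interpreter.set_tensor(inp['index'], int8_input)
interpreter.invoke()

# The output is int8 as well; dequantize it back to float for consumption.
out_scale, out_zero_point = out['quantization']
float_output = (
    interpreter.get_tensor(out['index']).astype(np.float32) -
    out_zero_point) * out_scale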