Expose the fully_quantize flag for the new mlir quantizer

PiperOrigin-RevId: 314547190
Change-Id: I14dfb095eefb5f9a565f726eb1ea760a8d0129b7
Feng Liu 2020-06-03 09:39:12 -07:00 committed by TensorFlower Gardener
parent b75ea8de71
commit 2fc2651747
4 changed files with 30 additions and 9 deletions

@@ -108,15 +108,19 @@ class ConverterError(Exception):
   pass
-def mlir_quantize(input_data_str, disable_per_channel=False,
+def mlir_quantize(input_data_str,
+                  disable_per_channel=False,
+                  fully_quantize=False,
                   inference_type=_types_pb2.INT8):
   """Quantize `input_data_str` with calibration results.
   Args:
     input_data_str: Input data in serialized form (e.g. a TFLITE model with
       calibration results).
-    disable_per_channel: Bool indicating whether to do per-channel or
-      per-tensor quantization
+    disable_per_channel: Bool indicating whether to do per-channel or per-tensor
+      quantization
+    fully_quantize: Bool indicating whether to fully quantize the model. Besides
+      model body, the input/output will be quantized as well.
     inference_type: Data type for the activations. The default value is int8.
   Returns:
@@ -125,6 +129,7 @@ def mlir_quantize(input_data_str, disable_per_channel=False,
   """
   return wrap_toco.wrapped_experimental_mlir_quantize(input_data_str,
                                                       disable_per_channel,
+                                                      fully_quantize,
                                                       inference_type)
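
A minimal usage sketch of the new flag, mirroring the calibrate-then-quantize flow exercised by the new lite_test.py case further down. It assumes `converter` is a TFLiteConverter that has already been configured with post-training optimizations and a representative dataset for calibration; that setup is not part of this change.

    import numpy as np

    from tensorflow.lite.python.convert import mlir_quantize
    from tensorflow.lite.python.interpreter import Interpreter

    # `converter` is assumed to be configured elsewhere with optimizations and
    # a representative_dataset; only the steps relevant to this change follow.
    converter._experimental_new_quantizer = True
    converter._experimental_calibrate_only = True
    calibrated = converter.convert()  # serialized model with calibration results

    # fully_quantize=True quantizes the model inputs/outputs as well as the body.
    quantized = mlir_quantize(calibrated, fully_quantize=True)

    interpreter = Interpreter(model_content=quantized)
    interpreter.allocate_tensors()
    assert interpreter.get_input_details()[0]['dtype'] == np.int8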

@@ -32,6 +32,7 @@ from six.moves import range
 from tensorflow.lite.python import lite
 from tensorflow.lite.python import lite_constants
 from tensorflow.lite.python.convert import ConverterError
+from tensorflow.lite.python.convert import mlir_quantize
 from tensorflow.lite.python.interpreter import Interpreter
 from tensorflow.python import keras
 from tensorflow.python.client import session
@@ -1175,9 +1176,21 @@ class FromSessionTest(TestModels, parameterized.TestCase):
     converter._experimental_new_quantizer = True
     quantized_tflite = converter.convert()
     self.assertTrue(quantized_tflite)
     self.assertLess(len(quantized_tflite), len(float_tflite))
+    # calibration only api
+    converter._experimental_calibrate_only = True
+    calibrated_tflite = converter.convert()
+    quantized_tflite = mlir_quantize(calibrated_tflite, fully_quantize=True)
+    interpreter = Interpreter(model_content=quantized_tflite)
+    interpreter.allocate_tensors()
+    input_details = interpreter.get_input_details()
+    self.assertEqual(np.int8, input_details[0]['dtype'])
+    self.assertEqual((1., 0.), input_details[0]['quantization'])
+    output_details = interpreter.get_output_details()
+    self.assertEqual(np.int8, output_details[0]['dtype'])
   def testFloatTocoConverter(self):
     """Tests deprecated test TocoConverter."""
     with ops.Graph().as_default():

@@ -44,10 +44,11 @@ def wrapped_get_potentially_supported_ops():
 def wrapped_experimental_mlir_quantize(input_data_str, disable_per_channel,
-                                       inference_type):
+                                       fully_quantize, inference_type):
   """Wraps experimental mlir quantize model."""
   return _pywrap_toco_api.ExperimentalMlirQuantizeModel(input_data_str,
                                                         disable_per_channel,
+                                                        fully_quantize,
                                                         inference_type)

@@ -264,11 +264,13 @@ PyObject* MlirQuantizeModel(PyObject* data, bool disable_per_channel,
     default:
      return nullptr;
   }
+  tflite::TensorType inference_io_type =
+      fully_quantize ? inference_tensor_type : tflite::TensorType_FLOAT32;
   flatbuffers::FlatBufferBuilder builder;
   auto status = mlir::lite::QuantizeModel(
-      *tflite_model, tflite::TensorType::TensorType_FLOAT32,
-      tflite::TensorType::TensorType_FLOAT32, inference_tensor_type, {},
-      disable_per_channel, fully_quantize, &builder, error_reporter.get());
+      *tflite_model, inference_io_type, inference_io_type,
+      inference_tensor_type, {}, disable_per_channel, fully_quantize, &builder,
+      error_reporter.get());
   if (status != kTfLiteOk) {
     error_reporter->exception();
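
In short, the io-type selection above is the whole effect of the new flag on the C++ side: with fully_quantize left false, QuantizeModel still receives float32 as the input and output type, so converted models keep float interfaces exactly as before this change. With fully_quantize true, the inference type (int8 by default) is applied to the model's inputs and outputs as well, which is what the new lite_test.py assertions above verify through the interpreter's input and output details.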