Expose the fully_quantize flag for the new mlir quantizer
PiperOrigin-RevId: 314547190 Change-Id: I14dfb095eefb5f9a565f726eb1ea760a8d0129b7
This commit is contained in:
parent
b75ea8de71
commit
2fc2651747
|
@ -108,15 +108,19 @@ class ConverterError(Exception):
|
|||
pass
|
||||
|
||||
|
||||
def mlir_quantize(input_data_str, disable_per_channel=False,
|
||||
def mlir_quantize(input_data_str,
|
||||
disable_per_channel=False,
|
||||
fully_quantize=False,
|
||||
inference_type=_types_pb2.INT8):
|
||||
"""Quantize `input_data_str` with calibration results.
|
||||
|
||||
Args:
|
||||
input_data_str: Input data in serialized form (e.g. a TFLITE model with
|
||||
calibration results).
|
||||
disable_per_channel: Bool indicating whether to do per-channel or
|
||||
per-tensor quantization
|
||||
disable_per_channel: Bool indicating whether to do per-channel or per-tensor
|
||||
quantization
|
||||
fully_quantize: Bool indicating whether to fully quantize the model. Besides
|
||||
model body, the input/output will be quantized as well.
|
||||
inference_type: Data type for the activations. The default value is int8.
|
||||
|
||||
Returns:
|
||||
|
@ -125,6 +129,7 @@ def mlir_quantize(input_data_str, disable_per_channel=False,
|
|||
"""
|
||||
return wrap_toco.wrapped_experimental_mlir_quantize(input_data_str,
|
||||
disable_per_channel,
|
||||
fully_quantize,
|
||||
inference_type)
|
||||
|
||||
|
||||
|
|
|
@ -32,6 +32,7 @@ from six.moves import range
|
|||
from tensorflow.lite.python import lite
|
||||
from tensorflow.lite.python import lite_constants
|
||||
from tensorflow.lite.python.convert import ConverterError
|
||||
from tensorflow.lite.python.convert import mlir_quantize
|
||||
from tensorflow.lite.python.interpreter import Interpreter
|
||||
from tensorflow.python import keras
|
||||
from tensorflow.python.client import session
|
||||
|
@ -1175,9 +1176,21 @@ class FromSessionTest(TestModels, parameterized.TestCase):
|
|||
converter._experimental_new_quantizer = True
|
||||
quantized_tflite = converter.convert()
|
||||
self.assertTrue(quantized_tflite)
|
||||
|
||||
self.assertLess(len(quantized_tflite), len(float_tflite))
|
||||
|
||||
# calibration only api
|
||||
converter._experimental_calibrate_only = True
|
||||
calibrated_tflite = converter.convert()
|
||||
quantized_tflite = mlir_quantize(calibrated_tflite, fully_quantize=True)
|
||||
interpreter = Interpreter(model_content=quantized_tflite)
|
||||
interpreter.allocate_tensors()
|
||||
input_details = interpreter.get_input_details()
|
||||
self.assertEqual(np.int8, input_details[0]['dtype'])
|
||||
self.assertEqual((1., 0.), input_details[0]['quantization'])
|
||||
|
||||
output_details = interpreter.get_output_details()
|
||||
self.assertEqual(np.int8, output_details[0]['dtype'])
|
||||
|
||||
def testFloatTocoConverter(self):
|
||||
"""Tests deprecated test TocoConverter."""
|
||||
with ops.Graph().as_default():
|
||||
|
|
|
@ -44,10 +44,11 @@ def wrapped_get_potentially_supported_ops():
|
|||
|
||||
|
||||
def wrapped_experimental_mlir_quantize(input_data_str, disable_per_channel,
|
||||
inference_type):
|
||||
fully_quantize, inference_type):
|
||||
"""Wraps experimental mlir quantize model."""
|
||||
return _pywrap_toco_api.ExperimentalMlirQuantizeModel(input_data_str,
|
||||
disable_per_channel,
|
||||
fully_quantize,
|
||||
inference_type)
|
||||
|
||||
|
||||
|
|
|
@ -264,11 +264,13 @@ PyObject* MlirQuantizeModel(PyObject* data, bool disable_per_channel,
|
|||
default:
|
||||
return nullptr;
|
||||
}
|
||||
tflite::TensorType inference_io_type =
|
||||
fully_quantize ? inference_tensor_type : tflite::TensorType_FLOAT32;
|
||||
flatbuffers::FlatBufferBuilder builder;
|
||||
auto status = mlir::lite::QuantizeModel(
|
||||
*tflite_model, tflite::TensorType::TensorType_FLOAT32,
|
||||
tflite::TensorType::TensorType_FLOAT32, inference_tensor_type, {},
|
||||
disable_per_channel, fully_quantize, &builder, error_reporter.get());
|
||||
*tflite_model, inference_io_type, inference_io_type,
|
||||
inference_tensor_type, {}, disable_per_channel, fully_quantize, &builder,
|
||||
error_reporter.get());
|
||||
|
||||
if (status != kTfLiteOk) {
|
||||
error_reporter->exception();
|
||||
|
|
Loading…
Reference in New Issue