Expose the fully_quantize flag for the new mlir quantizer

PiperOrigin-RevId: 314547190
Change-Id: I14dfb095eefb5f9a565f726eb1ea760a8d0129b7
commit 2fc2651747 (parent b75ea8de71)
--- a/tensorflow/lite/python/convert.py
+++ b/tensorflow/lite/python/convert.py
@@ -108,15 +108,19 @@ class ConverterError(Exception):
   pass
 
 
-def mlir_quantize(input_data_str, disable_per_channel=False,
-                  inference_type=_types_pb2.INT8):
+def mlir_quantize(input_data_str,
+                  disable_per_channel=False,
+                  fully_quantize=False,
+                  inference_type=_types_pb2.INT8):
   """Quantize `input_data_str` with calibration results.
 
   Args:
     input_data_str: Input data in serialized form (e.g. a TFLITE model with
       calibration results).
-    disable_per_channel: Bool indicating whether to do per-channel or
-      per-tensor quantization
+    disable_per_channel: Bool indicating whether to do per-channel or per-tensor
+      quantization
+    fully_quantize: Bool indicating whether to fully quantize the model. Besides
+      model body, the input/output will be quantized as well.
     inference_type: Data type for the activations. The default value is int8.
 
   Returns:
@@ -125,6 +129,7 @@ def mlir_quantize(input_data_str, disable_per_channel=False,
   """
   return wrap_toco.wrapped_experimental_mlir_quantize(input_data_str,
                                                       disable_per_channel,
+                                                      fully_quantize,
                                                       inference_type)
 
 
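Note: the intended calibrate-then-quantize flow with the new flag looks roughly like the sketch below. It mirrors the test added further down in this commit; the saved-model path, input shape, and dataset are hypothetical placeholders, not part of the change.

import numpy as np
import tensorflow as tf

from tensorflow.lite.python.convert import mlir_quantize

def representative_dataset():
  # Hypothetical calibration data; the shape must match the model's input.
  for _ in range(10):
    yield [np.random.uniform(size=(1, 5, 5, 3)).astype(np.float32)]

converter = tf.lite.TFLiteConverter.from_saved_model('/tmp/model')  # hypothetical path
converter.optimizations = [tf.lite.Optimize.DEFAULT]
converter.representative_dataset = representative_dataset
converter._experimental_new_quantizer = True

# Step 1: calibrate only, producing a model annotated with min/max
# statistics but not yet quantized.
converter._experimental_calibrate_only = True
calibrated_tflite = converter.convert()

# Step 2: quantize the calibrated model; fully_quantize=True quantizes the
# input/output tensors too, instead of leaving them float32.
quantized_tflite = mlir_quantize(calibrated_tflite, fully_quantize=True)

Splitting calibration from quantization this way means one calibrated artifact can be re-quantized with different options (per-channel, fully_quantize) without re-running the representative dataset.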
--- a/tensorflow/lite/python/lite_test.py
+++ b/tensorflow/lite/python/lite_test.py
@@ -32,6 +32,7 @@ from six.moves import range
 from tensorflow.lite.python import lite
 from tensorflow.lite.python import lite_constants
 from tensorflow.lite.python.convert import ConverterError
+from tensorflow.lite.python.convert import mlir_quantize
 from tensorflow.lite.python.interpreter import Interpreter
 from tensorflow.python import keras
 from tensorflow.python.client import session
@@ -1175,9 +1176,21 @@ class FromSessionTest(TestModels, parameterized.TestCase):
     converter._experimental_new_quantizer = True
     quantized_tflite = converter.convert()
     self.assertTrue(quantized_tflite)
 
     self.assertLess(len(quantized_tflite), len(float_tflite))
 
+    # calibration only api
+    converter._experimental_calibrate_only = True
+    calibrated_tflite = converter.convert()
+    quantized_tflite = mlir_quantize(calibrated_tflite, fully_quantize=True)
+    interpreter = Interpreter(model_content=quantized_tflite)
+    interpreter.allocate_tensors()
+    input_details = interpreter.get_input_details()
+    self.assertEqual(np.int8, input_details[0]['dtype'])
+    self.assertEqual((1., 0.), input_details[0]['quantization'])
+
+    output_details = interpreter.get_output_details()
+    self.assertEqual(np.int8, output_details[0]['dtype'])
 
   def testFloatTocoConverter(self):
     """Tests deprecated test TocoConverter."""
     with ops.Graph().as_default():
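Note: `input_details[0]['quantization']` is the (scale, zero_point) pair, so the `(1., 0.)` assertion pins identity quantization parameters on the int8 input. Under TFLite's affine scheme, real_value = scale * (q - zero_point); a minimal illustration of that convention (not code from this commit):

import numpy as np

def dequantize(q, scale, zero_point):
  # TFLite affine dequantization: real_value = scale * (q - zero_point).
  return scale * (q.astype(np.float32) - zero_point)

q = np.array([-128, 0, 127], dtype=np.int8)
print(dequantize(q, scale=1.0, zero_point=0))  # -> [-128.   0.  127.]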
--- a/tensorflow/lite/python/wrap_toco.py
+++ b/tensorflow/lite/python/wrap_toco.py
@@ -44,10 +44,11 @@ def wrapped_get_potentially_supported_ops():
 
 
 def wrapped_experimental_mlir_quantize(input_data_str, disable_per_channel,
-                                       inference_type):
+                                       fully_quantize, inference_type):
   """Wraps experimental mlir quantize model."""
   return _pywrap_toco_api.ExperimentalMlirQuantizeModel(input_data_str,
                                                         disable_per_channel,
+                                                        fully_quantize,
                                                         inference_type)
 
 
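Note: the wrapper simply threads the new argument through to the pybind entry point. Calling it directly would look like the sketch below; it assumes `calibrated_tflite` from the earlier convert.py example, and the `types_pb2` import path is the one convert.py uses at this revision.

from tensorflow.lite.python import wrap_toco
from tensorflow.lite.toco import types_pb2 as _types_pb2

# calibrated_tflite: serialized TFLite flatbuffer with calibration stats,
# e.g. produced by the calibrate-only flow sketched earlier.
quantized = wrap_toco.wrapped_experimental_mlir_quantize(
    calibrated_tflite,
    False,            # disable_per_channel
    True,             # fully_quantize
    _types_pb2.INT8)  # inference_type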
--- a/tensorflow/lite/toco/python/toco_python_api.cc
+++ b/tensorflow/lite/toco/python/toco_python_api.cc
@@ -264,11 +264,13 @@ PyObject* MlirQuantizeModel(PyObject* data, bool disable_per_channel,
     default:
       return nullptr;
   }
+  tflite::TensorType inference_io_type =
+      fully_quantize ? inference_tensor_type : tflite::TensorType_FLOAT32;
   flatbuffers::FlatBufferBuilder builder;
   auto status = mlir::lite::QuantizeModel(
-      *tflite_model, tflite::TensorType::TensorType_FLOAT32,
-      tflite::TensorType::TensorType_FLOAT32, inference_tensor_type, {},
-      disable_per_channel, fully_quantize, &builder, error_reporter.get());
+      *tflite_model, inference_io_type, inference_io_type,
+      inference_tensor_type, {}, disable_per_channel, fully_quantize, &builder,
+      error_reporter.get());
 
   if (status != kTfLiteOk) {
     error_reporter->exception();
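Note: the substance of the C++ change is the I/O type selection. With fully_quantize=false the model body is quantized but the input/output tensors stay float32 (float conversion happens at the graph boundaries); with fully_quantize=true the I/O tensors take the inference type itself. A Python paraphrase of that ternary, for illustration only:

def select_io_type(fully_quantize, inference_tensor_type):
  # Mirrors: fully_quantize ? inference_tensor_type : TensorType_FLOAT32
  return inference_tensor_type if fully_quantize else 'FLOAT32'

assert select_io_type(False, 'INT8') == 'FLOAT32'  # float I/O, quantized body
assert select_io_type(True, 'INT8') == 'INT8'      # fully quantized, int8 I/O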