Corrected after merge with master.
Tested: strict mode and non-strict mode. Change-Id: I7e03d08133f39cc65a18875e65ce5cdddaf2d6a4
This commit is contained in:
parent
5b9a467e3d
commit
db0d468121
@ -180,11 +180,13 @@ class QuantizationMode(object):
|
||||
def post_training_int8_no_float(self):
|
||||
"""Post training int8 quantize, disallow float fallback."""
|
||||
return (self._is_int8_target_required() and
|
||||
not self._is_int16x8_target_required() and
|
||||
self._representative_dataset is not None)
|
||||
|
||||
def post_training_int8_allow_float(self):
|
||||
"""Post training int8 quantize, allow float fallback."""
|
||||
return (self._any_optimization_enabled() and
|
||||
not self._is_int16x8_target_required() and
|
||||
self._representative_dataset is not None and
|
||||
self._smallest_supported_type() == constants.INT8)
|
||||
|
||||
@ -193,6 +195,18 @@ class QuantizationMode(object):
|
||||
return (self._any_optimization_enabled() and
|
||||
self._contains_training_quant_op())
|
||||
|
||||
def post_training_int16x8_no_float(self):
|
||||
"""Post training int16x8 quantize, disallow float fallback."""
|
||||
return (not self._is_int8_target_required() and
|
||||
self._is_int16x8_target_required() and
|
||||
not self._is_allow_float() and
|
||||
self._representative_dataset is not None)
|
||||
|
||||
def post_training_int16x8_allow_float(self):
|
||||
"""Post training int16x8 quantize, allow float fallback."""
|
||||
return (self._is_int16x8_target_required() and
|
||||
self._is_allow_float())
|
||||
|
||||
def post_training_dynamic_range_int8(self):
|
||||
"""Post training int8 const, on-the-fly int8 quantize of dynamic tensors."""
|
||||
# Post-training dynamic range quantization is only enabled if post-training
|
||||
@ -212,9 +226,14 @@ class QuantizationMode(object):
|
||||
return not (self.post_training_int8_no_float() or
|
||||
self.post_training_int8_allow_float() or
|
||||
self.training_time_int8_allow_float() or
|
||||
self.post_training_int16x8_no_float() or
|
||||
self.post_training_int16x8_allow_float() or
|
||||
self.post_training_dynamic_range_int8() or
|
||||
self.post_training_fp16())
|
||||
|
||||
def activations_type(self):
|
||||
return constants.INT16 if self._is_int16x8_target_required() else constants.INT8
|
||||
|
||||
# Below are helpers for the above functions.
|
||||
|
||||
def _validate_int8_required(self):
|
||||
@ -244,6 +263,18 @@ class QuantizationMode(object):
|
||||
self._target_spec.supported_ops) or
|
||||
set(self._target_spec.supported_types) == set([constants.INT8]))
|
||||
|
||||
def _is_int16x8_target_required(self):
|
||||
return bool(
|
||||
set(self._target_spec.supported_ops).intersection([
|
||||
OpsSet.TFLITE_BUILTINS_ACTIVATIONS_INT16_WEIGHTS_INT8
|
||||
]))
|
||||
|
||||
def _is_allow_float(self):
|
||||
return bool(
|
||||
set(self._target_spec.supported_ops).intersection([
|
||||
OpsSet.TFLITE_BUILTINS
|
||||
]))
|
||||
|
||||
def _any_optimization_enabled(self):
|
||||
return bool(
|
||||
set(self._optimizations).intersection([
|
||||
@ -309,13 +340,13 @@ class TFLiteConverterBase(object):
|
||||
return _get_grappler_config(optimizers)
|
||||
|
||||
def _calibrate_quantize_model(self, result, inference_input_type,
|
||||
inference_output_type, allow_float):
|
||||
inference_output_type, activations_type, allow_float):
|
||||
if not isinstance(self.representative_dataset, RepresentativeDataset):
|
||||
self.representative_dataset = RepresentativeDataset(
|
||||
self.representative_dataset)
|
||||
|
||||
calibrate_quantize = _calibrator.Calibrator(result)
|
||||
activations_type = constants.INT16 if self._is_int16x8_target_required() else constants.INT8
|
||||
|
||||
if (self.experimental_calibrate_only:
|
||||
return calibrate_quantize.calibrate(self.representative_dataset.input_gen)
|
||||
else:
|
||||
@ -608,12 +639,20 @@ class TFLiteConverterV2(TFLiteConverterBase):
|
||||
output_tensors=output_tensors,
|
||||
**converter_kwargs)
|
||||
|
||||
activations_type = quant_mode.activations_type()
|
||||
|
||||
if quant_mode.post_training_int8_no_float():
|
||||
result = self._calibrate_quantize_model(result, constants.FLOAT,
|
||||
constants.FLOAT, False)
|
||||
constants.FLOAT, activations_type, False)
|
||||
elif quant_mode.post_training_int8_allow_float():
|
||||
result = self._calibrate_quantize_model(result, constants.FLOAT,
|
||||
constants.FLOAT, True)
|
||||
constants.FLOAT, activations_type, True)
|
||||
elif quant_mode.post_training_int16x8_no_float():
|
||||
result = self._calibrate_quantize_model(result, constants.FLOAT,
|
||||
constants.FLOAT, activations_type, False)
|
||||
elif quant_mode.post_training_int16x8_allow_float():
|
||||
result = self._calibrate_quantize_model(result, constants.FLOAT,
|
||||
constants.FLOAT, activations_type, True)
|
||||
|
||||
return result
|
||||
|
||||
@ -1114,6 +1153,8 @@ class TFLiteConverter(TFLiteConverterBase):
|
||||
quant_mode.post_training_int8_no_float() or
|
||||
quant_mode.post_training_int8_allow_float() or
|
||||
quant_mode.post_training_dynamic_range_int8() or
|
||||
quant_mode.post_training_int16x8_no_float() or
|
||||
quant_mode.post_training_int16x8_allow_float() or
|
||||
quant_mode.post_training_fp16())
|
||||
if post_training_optimize:
|
||||
# Post training optimizations require that TOCO outputs a float model.
|
||||
@ -1223,12 +1264,20 @@ class TFLiteConverter(TFLiteConverterBase):
|
||||
output_arrays=self._output_arrays,
|
||||
**converter_kwargs)
|
||||
|
||||
activations_type = quant_mode.activations_type()
|
||||
|
||||
if quant_mode.post_training_int8_no_float():
|
||||
result = self._calibrate_quantize_model(result, inference_input_type,
|
||||
inference_output_type, False)
|
||||
inference_output_type, activations_type, False)
|
||||
elif quant_mode.post_training_int8_allow_float():
|
||||
result = self._calibrate_quantize_model(result, inference_input_type,
|
||||
inference_output_type, True)
|
||||
inference_output_type, activations_type, True)
|
||||
elif quant_mode.post_training_int16x8_no_float():
|
||||
result = self._calibrate_quantize_model(result, inference_input_type,
|
||||
inference_output_type, activations_type, False)
|
||||
elif quant_mode.post_training_int16x8_allow_float():
|
||||
result = self._calibrate_quantize_model(result, inference_input_type,
|
||||
inference_output_type, activations_type, True)
|
||||
|
||||
return result
|
||||
|
||||
@ -1334,7 +1383,6 @@ class TocoConverter(object):
|
||||
|
||||
@classmethod
|
||||
@_deprecation.deprecated(
|
||||
None, "Use `lite.TFLiteConverter.from_keras_model_file` instead.")
|
||||
def from_keras_model_file(cls,
|
||||
model_file,
|
||||
input_arrays=None,
|
||||
|
@ -40,17 +40,17 @@ PYBIND11_MODULE(_pywrap_tensorflow_lite_calibration_wrapper, m) {
|
||||
})
|
||||
.def("QuantizeModel",
|
||||
[](CalibrationWrapper& self, int input_py_type, int output_py_type,
|
||||
bool allow_float, bool enable_mlir_quantizer) {
|
||||
bool allow_float, int activations_py_type, bool enable_mlir_quantizer) {
|
||||
return tensorflow::pyo_or_throw(
|
||||
self.QuantizeModel(input_py_type, output_py_type, allow_float,
|
||||
enable_mlir_quantizer));
|
||||
activations_py_type, enable_mlir_quantizer));
|
||||
})
|
||||
.def("QuantizeModel",
|
||||
[](CalibrationWrapper& self, int input_py_type, int output_py_type,
|
||||
bool allow_float) {
|
||||
bool allow_float, int activations_py_type) {
|
||||
return tensorflow::pyo_or_throw(
|
||||
self.QuantizeModel(input_py_type, output_py_type, allow_float,
|
||||
/*enable_mlir_quantizer=*/false));
|
||||
activations_py_type, /*enable_mlir_quantizer=*/false));
|
||||
})
|
||||
.def("QuantizeModel", [](CalibrationWrapper& self, int input_py_type,
|
||||
int output_py_type, bool allow_float,
|
||||
|
Loading…
Reference in New Issue
Block a user