Corrected after merge with master.

Tested: strict mode and non-strict mode.

Change-Id: I7e03d08133f39cc65a18875e65ce5cdddaf2d6a4
This commit is contained in:
Elena Zhelezina 2020-03-17 16:30:32 +00:00
parent 5b9a467e3d
commit db0d468121
2 changed files with 59 additions and 11 deletions

View File

@ -180,11 +180,13 @@ class QuantizationMode(object):
def post_training_int8_no_float(self): def post_training_int8_no_float(self):
"""Post training int8 quantize, disallow float fallback.""" """Post training int8 quantize, disallow float fallback."""
return (self._is_int8_target_required() and return (self._is_int8_target_required() and
not self._is_int16x8_target_required() and
self._representative_dataset is not None) self._representative_dataset is not None)
def post_training_int8_allow_float(self): def post_training_int8_allow_float(self):
"""Post training int8 quantize, allow float fallback.""" """Post training int8 quantize, allow float fallback."""
return (self._any_optimization_enabled() and return (self._any_optimization_enabled() and
not self._is_int16x8_target_required() and
self._representative_dataset is not None and self._representative_dataset is not None and
self._smallest_supported_type() == constants.INT8) self._smallest_supported_type() == constants.INT8)
@ -193,6 +195,18 @@ class QuantizationMode(object):
return (self._any_optimization_enabled() and return (self._any_optimization_enabled() and
self._contains_training_quant_op()) self._contains_training_quant_op())
def post_training_int16x8_no_float(self):
"""Post training int16x8 quantize, disallow float fallback."""
return (not self._is_int8_target_required() and
self._is_int16x8_target_required() and
not self._is_allow_float() and
self._representative_dataset is not None)
def post_training_int16x8_allow_float(self):
"""Post training int16x8 quantize, allow float fallback."""
return (self._is_int16x8_target_required() and
self._is_allow_float())
def post_training_dynamic_range_int8(self): def post_training_dynamic_range_int8(self):
"""Post training int8 const, on-the-fly int8 quantize of dynamic tensors.""" """Post training int8 const, on-the-fly int8 quantize of dynamic tensors."""
# Post-training dynamic range quantization is only enabled if post-training # Post-training dynamic range quantization is only enabled if post-training
@ -212,9 +226,14 @@ class QuantizationMode(object):
return not (self.post_training_int8_no_float() or return not (self.post_training_int8_no_float() or
self.post_training_int8_allow_float() or self.post_training_int8_allow_float() or
self.training_time_int8_allow_float() or self.training_time_int8_allow_float() or
self.post_training_int16x8_no_float() or
self.post_training_int16x8_allow_float() or
self.post_training_dynamic_range_int8() or self.post_training_dynamic_range_int8() or
self.post_training_fp16()) self.post_training_fp16())
def activations_type(self):
return constants.INT16 if self._is_int16x8_target_required() else constants.INT8
# Below are helpers for the above functions. # Below are helpers for the above functions.
def _validate_int8_required(self): def _validate_int8_required(self):
@ -244,6 +263,18 @@ class QuantizationMode(object):
self._target_spec.supported_ops) or self._target_spec.supported_ops) or
set(self._target_spec.supported_types) == set([constants.INT8])) set(self._target_spec.supported_types) == set([constants.INT8]))
def _is_int16x8_target_required(self):
return bool(
set(self._target_spec.supported_ops).intersection([
OpsSet.TFLITE_BUILTINS_ACTIVATIONS_INT16_WEIGHTS_INT8
]))
def _is_allow_float(self):
return bool(
set(self._target_spec.supported_ops).intersection([
OpsSet.TFLITE_BUILTINS
]))
def _any_optimization_enabled(self): def _any_optimization_enabled(self):
return bool( return bool(
set(self._optimizations).intersection([ set(self._optimizations).intersection([
@ -309,13 +340,13 @@ class TFLiteConverterBase(object):
return _get_grappler_config(optimizers) return _get_grappler_config(optimizers)
def _calibrate_quantize_model(self, result, inference_input_type, def _calibrate_quantize_model(self, result, inference_input_type,
inference_output_type, allow_float): inference_output_type, activations_type, allow_float):
if not isinstance(self.representative_dataset, RepresentativeDataset): if not isinstance(self.representative_dataset, RepresentativeDataset):
self.representative_dataset = RepresentativeDataset( self.representative_dataset = RepresentativeDataset(
self.representative_dataset) self.representative_dataset)
calibrate_quantize = _calibrator.Calibrator(result) calibrate_quantize = _calibrator.Calibrator(result)
activations_type = constants.INT16 if self._is_int16x8_target_required() else constants.INT8
if (self.experimental_calibrate_only: if (self.experimental_calibrate_only:
return calibrate_quantize.calibrate(self.representative_dataset.input_gen) return calibrate_quantize.calibrate(self.representative_dataset.input_gen)
else: else:
@ -608,12 +639,20 @@ class TFLiteConverterV2(TFLiteConverterBase):
output_tensors=output_tensors, output_tensors=output_tensors,
**converter_kwargs) **converter_kwargs)
activations_type = quant_mode.activations_type()
if quant_mode.post_training_int8_no_float(): if quant_mode.post_training_int8_no_float():
result = self._calibrate_quantize_model(result, constants.FLOAT, result = self._calibrate_quantize_model(result, constants.FLOAT,
constants.FLOAT, False) constants.FLOAT, activations_type, False)
elif quant_mode.post_training_int8_allow_float(): elif quant_mode.post_training_int8_allow_float():
result = self._calibrate_quantize_model(result, constants.FLOAT, result = self._calibrate_quantize_model(result, constants.FLOAT,
constants.FLOAT, True) constants.FLOAT, activations_type, True)
elif quant_mode.post_training_int16x8_no_float():
result = self._calibrate_quantize_model(result, constants.FLOAT,
constants.FLOAT, activations_type, False)
elif quant_mode.post_training_int16x8_allow_float():
result = self._calibrate_quantize_model(result, constants.FLOAT,
constants.FLOAT, activations_type, True)
return result return result
@ -1114,6 +1153,8 @@ class TFLiteConverter(TFLiteConverterBase):
quant_mode.post_training_int8_no_float() or quant_mode.post_training_int8_no_float() or
quant_mode.post_training_int8_allow_float() or quant_mode.post_training_int8_allow_float() or
quant_mode.post_training_dynamic_range_int8() or quant_mode.post_training_dynamic_range_int8() or
quant_mode.post_training_int16x8_no_float() or
quant_mode.post_training_int16x8_allow_float() or
quant_mode.post_training_fp16()) quant_mode.post_training_fp16())
if post_training_optimize: if post_training_optimize:
# Post training optimizations require that TOCO outputs a float model. # Post training optimizations require that TOCO outputs a float model.
@ -1223,12 +1264,20 @@ class TFLiteConverter(TFLiteConverterBase):
output_arrays=self._output_arrays, output_arrays=self._output_arrays,
**converter_kwargs) **converter_kwargs)
activations_type = quant_mode.activations_type()
if quant_mode.post_training_int8_no_float(): if quant_mode.post_training_int8_no_float():
result = self._calibrate_quantize_model(result, inference_input_type, result = self._calibrate_quantize_model(result, inference_input_type,
inference_output_type, False) inference_output_type, activations_type, False)
elif quant_mode.post_training_int8_allow_float(): elif quant_mode.post_training_int8_allow_float():
result = self._calibrate_quantize_model(result, inference_input_type, result = self._calibrate_quantize_model(result, inference_input_type,
inference_output_type, True) inference_output_type, activations_type, True)
elif quant_mode.post_training_int16x8_no_float():
result = self._calibrate_quantize_model(result, inference_input_type,
inference_output_type, activations_type, False)
elif quant_mode.post_training_int16x8_allow_float():
result = self._calibrate_quantize_model(result, inference_input_type,
inference_output_type, activations_type, True)
return result return result
@ -1334,7 +1383,6 @@ class TocoConverter(object):
@classmethod @classmethod
@_deprecation.deprecated( @_deprecation.deprecated(
None, "Use `lite.TFLiteConverter.from_keras_model_file` instead.")
def from_keras_model_file(cls, def from_keras_model_file(cls,
model_file, model_file,
input_arrays=None, input_arrays=None,

View File

@ -40,17 +40,17 @@ PYBIND11_MODULE(_pywrap_tensorflow_lite_calibration_wrapper, m) {
}) })
.def("QuantizeModel", .def("QuantizeModel",
[](CalibrationWrapper& self, int input_py_type, int output_py_type, [](CalibrationWrapper& self, int input_py_type, int output_py_type,
bool allow_float, bool enable_mlir_quantizer) { bool allow_float, int activations_py_type, bool enable_mlir_quantizer) {
return tensorflow::pyo_or_throw( return tensorflow::pyo_or_throw(
self.QuantizeModel(input_py_type, output_py_type, allow_float, self.QuantizeModel(input_py_type, output_py_type, allow_float,
enable_mlir_quantizer)); activations_py_type, enable_mlir_quantizer));
}) })
.def("QuantizeModel", .def("QuantizeModel",
[](CalibrationWrapper& self, int input_py_type, int output_py_type, [](CalibrationWrapper& self, int input_py_type, int output_py_type,
bool allow_float) { bool allow_float, int activations_py_type) {
return tensorflow::pyo_or_throw( return tensorflow::pyo_or_throw(
self.QuantizeModel(input_py_type, output_py_type, allow_float, self.QuantizeModel(input_py_type, output_py_type, allow_float,
/*enable_mlir_quantizer=*/false)); activations_py_type, /*enable_mlir_quantizer=*/false));
}) })
.def("QuantizeModel", [](CalibrationWrapper& self, int input_py_type, .def("QuantizeModel", [](CalibrationWrapper& self, int input_py_type,
int output_py_type, bool allow_float, int output_py_type, bool allow_float,