Corrected after merge with master.
Tested: strict mode and non-strict mode.
Change-Id: I7e03d08133f39cc65a18875e65ce5cdddaf2d6a4
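
For context, a minimal sketch of how the 16x8 post-training mode wired up below would be requested from the converter. This is a hedged sketch: `saved_model_dir` and `representative_dataset_gen` are placeholder names, and the `OpsSet` member is spelled exactly as it appears in this diff (the public alias may differ).

    import tensorflow as tf

    converter = tf.lite.TFLiteConverter.from_saved_model(saved_model_dir)
    converter.optimizations = [tf.lite.Optimize.DEFAULT]
    # Calibration data is required for the no-float 16x8 path.
    converter.representative_dataset = representative_dataset_gen
    # 16-bit activations with 8-bit weights; OpsSet name taken from the diff below.
    converter.target_spec.supported_ops = [
        tf.lite.OpsSet.TFLITE_BUILTINS_ACTIVATIONS_INT16_WEIGHTS_INT8
    ]
    tflite_model = converter.convert()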
parent 5b9a467e3d
commit db0d468121
@@ -180,11 +180,13 @@ class QuantizationMode(object):
   def post_training_int8_no_float(self):
     """Post training int8 quantize, disallow float fallback."""
     return (self._is_int8_target_required() and
+            not self._is_int16x8_target_required() and
             self._representative_dataset is not None)
 
   def post_training_int8_allow_float(self):
     """Post training int8 quantize, allow float fallback."""
     return (self._any_optimization_enabled() and
+            not self._is_int16x8_target_required() and
             self._representative_dataset is not None and
             self._smallest_supported_type() == constants.INT8)
 
@@ -193,6 +195,18 @@ class QuantizationMode(object):
     return (self._any_optimization_enabled() and
             self._contains_training_quant_op())
 
+  def post_training_int16x8_no_float(self):
+    """Post training int16x8 quantize, disallow float fallback."""
+    return (not self._is_int8_target_required() and
+            self._is_int16x8_target_required() and
+            not self._is_allow_float() and
+            self._representative_dataset is not None)
+
+  def post_training_int16x8_allow_float(self):
+    """Post training int16x8 quantize, allow float fallback."""
+    return (self._is_int16x8_target_required() and
+            self._is_allow_float())
+
   def post_training_dynamic_range_int8(self):
     """Post training int8 const, on-the-fly int8 quantize of dynamic tensors."""
     # Post-training dynamic range quantization is only enabled if post-training
@@ -212,9 +226,14 @@ class QuantizationMode(object):
     return not (self.post_training_int8_no_float() or
                 self.post_training_int8_allow_float() or
                 self.training_time_int8_allow_float() or
+                self.post_training_int16x8_no_float() or
+                self.post_training_int16x8_allow_float() or
                 self.post_training_dynamic_range_int8() or
                 self.post_training_fp16())
 
+  def activations_type(self):
+    return constants.INT16 if self._is_int16x8_target_required() else constants.INT8
+
   # Below are helpers for the above functions.
 
   def _validate_int8_required(self):
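
The predicates above are meant to be mutually exclusive. A standalone sketch of the same selection logic (all names hypothetical, independent of the TF Lite codebase; inputs mirror what the helpers read) can be useful for sanity-checking which calibration path a given configuration lands on:

    # Standalone mimic of the mode selection encoded by the predicates above.
    def select_mode(int8_required, int16x8_required, allow_float, has_dataset):
      if int8_required and not int16x8_required and has_dataset:
        return "post_training_int8_no_float"
      if (not int8_required and int16x8_required and
          not allow_float and has_dataset):
        return "post_training_int16x8_no_float"
      if int16x8_required and allow_float:
        return "post_training_int16x8_allow_float"
      return "none of the calibration-based modes"

    # The 16x8 ops set alone, with a representative dataset:
    print(select_mode(False, True, False, True))
    # -> post_training_int16x8_no_float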
@@ -244,6 +263,18 @@ class QuantizationMode(object):
             self._target_spec.supported_ops) or
         set(self._target_spec.supported_types) == set([constants.INT8]))
 
+  def _is_int16x8_target_required(self):
+    return bool(
+        set(self._target_spec.supported_ops).intersection([
+            OpsSet.TFLITE_BUILTINS_ACTIVATIONS_INT16_WEIGHTS_INT8
+        ]))
+
+  def _is_allow_float(self):
+    return bool(
+        set(self._target_spec.supported_ops).intersection([
+            OpsSet.TFLITE_BUILTINS
+        ]))
+
   def _any_optimization_enabled(self):
     return bool(
         set(self._optimizations).intersection([
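
Note how `_is_allow_float()` keys off `OpsSet.TFLITE_BUILTINS` being present in the same `supported_ops` set: requesting float fallback for the 16x8 mode means listing both ops sets in the target spec. A hedged sketch (converter construction elided; `tf.lite` aliases assumed):

    converter.target_spec.supported_ops = [
        tf.lite.OpsSet.TFLITE_BUILTINS_ACTIVATIONS_INT16_WEIGHTS_INT8,
        # Presence of the plain builtins set is what _is_allow_float() detects.
        tf.lite.OpsSet.TFLITE_BUILTINS,
    ]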
@@ -309,13 +340,13 @@ class TFLiteConverterBase(object):
     return _get_grappler_config(optimizers)
 
   def _calibrate_quantize_model(self, result, inference_input_type,
-                                inference_output_type, allow_float):
+                                inference_output_type, activations_type,
+                                allow_float):
     if not isinstance(self.representative_dataset, RepresentativeDataset):
       self.representative_dataset = RepresentativeDataset(
           self.representative_dataset)
 
     calibrate_quantize = _calibrator.Calibrator(result)
-    activations_type = constants.INT16 if self._is_int16x8_target_required() else constants.INT8
     if self.experimental_calibrate_only:
       return calibrate_quantize.calibrate(self.representative_dataset.input_gen)
     else:
@@ -608,12 +639,20 @@ class TFLiteConverterV2(TFLiteConverterBase):
         output_tensors=output_tensors,
         **converter_kwargs)
 
+    activations_type = quant_mode.activations_type()
+
     if quant_mode.post_training_int8_no_float():
       result = self._calibrate_quantize_model(result, constants.FLOAT,
-                                              constants.FLOAT, False)
+                                              constants.FLOAT, activations_type, False)
     elif quant_mode.post_training_int8_allow_float():
       result = self._calibrate_quantize_model(result, constants.FLOAT,
-                                              constants.FLOAT, True)
+                                              constants.FLOAT, activations_type, True)
+    elif quant_mode.post_training_int16x8_no_float():
+      result = self._calibrate_quantize_model(result, constants.FLOAT,
+                                              constants.FLOAT, activations_type, False)
+    elif quant_mode.post_training_int16x8_allow_float():
+      result = self._calibrate_quantize_model(result, constants.FLOAT,
+                                              constants.FLOAT, activations_type, True)
 
     return result
@@ -1114,6 +1153,8 @@ class TFLiteConverter(TFLiteConverterBase):
         quant_mode.post_training_int8_no_float() or
         quant_mode.post_training_int8_allow_float() or
         quant_mode.post_training_dynamic_range_int8() or
+        quant_mode.post_training_int16x8_no_float() or
+        quant_mode.post_training_int16x8_allow_float() or
         quant_mode.post_training_fp16())
     if post_training_optimize:
       # Post training optimizations require that TOCO outputs a float model.
@@ -1223,12 +1264,20 @@ class TFLiteConverter(TFLiteConverterBase):
         output_arrays=self._output_arrays,
         **converter_kwargs)
 
+    activations_type = quant_mode.activations_type()
+
    if quant_mode.post_training_int8_no_float():
       result = self._calibrate_quantize_model(result, inference_input_type,
-                                              inference_output_type, False)
+                                              inference_output_type, activations_type, False)
     elif quant_mode.post_training_int8_allow_float():
       result = self._calibrate_quantize_model(result, inference_input_type,
-                                              inference_output_type, True)
+                                              inference_output_type, activations_type, True)
+    elif quant_mode.post_training_int16x8_no_float():
+      result = self._calibrate_quantize_model(result, inference_input_type,
+                                              inference_output_type, activations_type, False)
+    elif quant_mode.post_training_int16x8_allow_float():
+      result = self._calibrate_quantize_model(result, inference_input_type,
+                                              inference_output_type, activations_type, True)
 
     return result
@@ -1334,7 +1383,6 @@ class TocoConverter(object):
 
   @classmethod
   @_deprecation.deprecated(
-      None, "Use `lite.TFLiteConverter.from_keras_model_file` instead.")
   def from_keras_model_file(cls,
                             model_file,
                             input_arrays=None,
@@ -40,17 +40,17 @@ PYBIND11_MODULE(_pywrap_tensorflow_lite_calibration_wrapper, m) {
       })
       .def("QuantizeModel",
            [](CalibrationWrapper& self, int input_py_type, int output_py_type,
-              bool allow_float, bool enable_mlir_quantizer) {
+              bool allow_float, int activations_py_type, bool enable_mlir_quantizer) {
             return tensorflow::pyo_or_throw(
                 self.QuantizeModel(input_py_type, output_py_type, allow_float,
-                                   enable_mlir_quantizer));
+                                   activations_py_type, enable_mlir_quantizer));
           })
       .def("QuantizeModel",
            [](CalibrationWrapper& self, int input_py_type, int output_py_type,
-              bool allow_float) {
+              bool allow_float, int activations_py_type) {
             return tensorflow::pyo_or_throw(
                 self.QuantizeModel(input_py_type, output_py_type, allow_float,
-                                   /*enable_mlir_quantizer=*/false));
+                                   activations_py_type, /*enable_mlir_quantizer=*/false));
           })
       .def("QuantizeModel", [](CalibrationWrapper& self, int input_py_type,
                                int output_py_type, bool allow_float,
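
On the Python side, both `QuantizeModel` overloads now take the activations type as an extra int after `allow_float`. A hedged sketch of the call shape (the `wrapper` object and the use of numpy type numbers are assumptions about how the calibration wrapper is typically driven, not confirmed by this diff):

    import numpy as np

    # wrapper: an instance of the pybind CalibrationWrapper bound above.
    quantized_model = wrapper.QuantizeModel(
        np.dtype(np.float32).num,  # input_py_type
        np.dtype(np.float32).num,  # output_py_type
        False,                     # allow_float
        np.dtype(np.int16).num,    # activations_py_type (new in this change)
    )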