Modify model input/output in a separate utility for all quantized TF1/TF2 models.
PiperOrigin-RevId: 338133056
Change-Id: I27f2232d0f39f51565434695ad5a9e075770267f
parent d5ab30ca14
commit 5832fa274f
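For context, this is the user-facing TF2 flow whose input/output type handling the commit consolidates into one utility. A minimal sketch, assuming a throwaway Keras model and a synthetic representative dataset (both stand-ins, not taken from this commit):

import tensorflow as tf

# Hypothetical model and calibration data, for illustration only.
model = tf.keras.Sequential([tf.keras.layers.Dense(10, input_shape=(4,))])

def representative_dataset():
  for _ in range(100):
    yield [tf.random.normal((1, 4))]

converter = tf.lite.TFLiteConverter.from_keras_model(model)
converter.optimizations = [tf.lite.Optimize.DEFAULT]
converter.representative_dataset = representative_dataset
# The integer I/O types requested here are what now get applied through a
# shared "modify model input/output type" step after conversion.
converter.inference_input_type = tf.int8
converter.inference_output_type = tf.int8
tflite_model = converter.convert()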
tensorflow/lite
@@ -211,16 +211,14 @@ class QuantizationMode(object):
     return (self.post_training_int16x8_no_float() or
             self.post_training_int16x8_allow_float())
 
-  def is_post_training_integer_quantize(self):
-    """Post training integer quantization."""
+  def is_integer_quantize(self):
     return (self.is_post_training_integer_quantize_8() or
-            self.is_post_training_integer_quantize_16x8())
+            self.is_post_training_integer_quantize_16x8() or
+            self.is_training_time_int8_allow_float())
 
-  def training_time_int8_allow_float(self):
-    """Training-time int8 quantize, allow float fallback."""
+  def is_training_time_int8_allow_float(self):
     return (self._any_optimization_enabled() and
-            not self.post_training_dynamic_range_int8() and
-            not self.post_training_fp16())
+            self.contains_training_quant_op())
 
   def post_training_int16x8_no_float(self):
     """Post training int16x8 quantize, disallow float fallback."""
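The renamed predicates above fold training-time (quantization-aware training) models into the same integer-quantize path. A rough sketch of that path, assuming the tensorflow_model_optimization package, which is not part of this change, is available:

import tensorflow as tf
import tensorflow_model_optimization as tfmot  # assumed available

# Hypothetical float Keras model.
model = tf.keras.Sequential([tf.keras.layers.Dense(10, input_shape=(4,))])

# quantize_model inserts fake-quant ops; a graph carrying such ops is what
# contains_training_quant_op(), whose implementation is outside this diff,
# is meant to detect, making is_training_time_int8_allow_float() True.
qat_model = tfmot.quantization.keras.quantize_model(model)
# ... train qat_model as usual ...

converter = tf.lite.TFLiteConverter.from_keras_model(qat_model)
converter.optimizations = [tf.lite.Optimize.DEFAULT]
tflite_model = converter.convert()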
@@ -249,11 +247,7 @@ class QuantizationMode(object):
 
   def fp32_execution(self):
     """If none of the above are true."""
-    return not (self.post_training_int8_no_float() or
-                self.post_training_int8_allow_float() or
-                self.training_time_int8_allow_float() or
-                self.post_training_int16x8_no_float() or
-                self.post_training_int16x8_allow_float() or
+    return not (self.is_integer_quantize() or
                 self.post_training_dynamic_range_int8() or
                 self.post_training_fp16())
 
@@ -263,17 +257,12 @@ class QuantizationMode(object):
 
   def converter_flags(self, inference_ty=None, inference_input_ty=None):
     """Flags to the converter."""
-    if self.is_post_training_integer_quantize():
-      # The inference_input_type is for the quantizer, then we need to keep the
-      # converter inference_input_type to float.
-      inference_input_ty = _dtypes.float32
 
-    if self.training_time_int8_allow_float():
+    if self.is_integer_quantize():
       return {
           "inference_type": inference_ty if inference_ty else \
             self.activations_type(),
-          "inference_input_type":
-              inference_input_ty if inference_input_ty else _dtypes.float32,
+          "inference_input_type": _dtypes.float32,
           "post_training_quantize": False,  # disable dynamic range quantization
           "quantize_to_float16": False  # disable float16 quantization
       }
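To make the new branch concrete: for any integer-quantized model the converter itself is now handed a float32 input type, and the types the user asked for are applied later by the model I/O utility. An illustrative, not authoritative, view of the dict converter_flags() would return for an int8 model, with tf dtypes standing in for the module's _dtypes aliases:

import tensorflow as tf

flags = {
    "inference_type": tf.int8,           # what activations_type() yields for int8
    "inference_input_type": tf.float32,  # converter-level input stays float
    "post_training_quantize": False,     # dynamic range quantization disabled
    "quantize_to_float16": False,        # float16 quantization disabled
}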
@@ -326,19 +315,13 @@ class QuantizationMode(object):
     else:
       return False, None
 
-  def flags_modify_model_io_type(
-      self, input_type=_dtypes.float32, output_type=_dtypes.float32):
+  def flags_modify_model_io_type(self, input_ty=None, output_ty=None):
     """Flags for modifying the input and output type of a tflite model."""
-    is_post_training_quantize = self.quantizer_flags(input_type, output_type)[0]
-    is_training_time_only_quantize = self.training_time_int8_allow_float() and \
-      not is_post_training_quantize
 
-    # TODO(b/153576658): Consolidate post/during training quantization workflows
-    # to modify model input/output type after MLIR conversion.
-    if is_training_time_only_quantize:
+    if self.is_integer_quantize():
       return {
-          "inference_input_type": input_type,
-          "inference_output_type": output_type,
+          "inference_input_type": input_ty if input_ty else _dtypes.float32,
+          "inference_output_type": output_ty if output_ty else _dtypes.float32,
       }
     else:
       return None
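With the signature change above, flags_modify_model_io_type() now applies to any integer-quantized model and falls back to float32 when no type is requested. A small self-contained illustration of that fallback, with tf dtypes again standing in for _dtypes:

import tensorflow as tf

input_ty, output_ty = None, tf.int8  # hypothetical requested types
flags = {
    "inference_input_type": input_ty if input_ty else tf.float32,
    "inference_output_type": output_ty if output_ty else tf.float32,
}
# flags == {"inference_input_type": tf.float32, "inference_output_type": tf.int8}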
@@ -563,7 +546,7 @@ class TFLiteConverterBaseV2(TFLiteConverterBase):
     """Validate inference_input_type and inference_output_type flags."""
     default_types = [_dtypes.float32]
     # We support integer input/output for integer quantized models only.
-    if quant_mode.training_time_int8_allow_float():
+    if quant_mode.is_integer_quantize():
       if quant_mode.is_post_training_integer_quantize_16x8():
         all_types = default_types + [_dtypes.int16]
       else:
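The 16x8 branch kept above still allows int16 input/output types. A hedged sketch of that mode from the public API, assuming the experimental ops set name exposed by tf.lite.OpsSet at the time of this change; the model and calibration data are placeholders:

import tensorflow as tf

model = tf.keras.Sequential([tf.keras.layers.Dense(10, input_shape=(4,))])

def representative_dataset():
  for _ in range(100):
    yield [tf.random.normal((1, 4))]

converter = tf.lite.TFLiteConverter.from_keras_model(model)
converter.optimizations = [tf.lite.Optimize.DEFAULT]
converter.representative_dataset = representative_dataset
converter.target_spec.supported_ops = [
    tf.lite.OpsSet.EXPERIMENTAL_TFLITE_BUILTINS_ACTIVATIONS_INT16_WEIGHTS_INT8
]
# Per the validation above, int16 joins float32 as an accepted I/O type here.
converter.inference_input_type = tf.int16
converter.inference_output_type = tf.int16
tflite_model = converter.convert()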
@@ -645,8 +628,7 @@ class TFLiteConverterBaseV2(TFLiteConverterBase):
         output_tensors=output_tensors,
         **converter_kwargs)
 
-    calibrate_and_quantize, flags = quant_mode.quantizer_flags(
-        self.inference_input_type, self.inference_output_type)
+    calibrate_and_quantize, flags = quant_mode.quantizer_flags()
     if calibrate_and_quantize:
       result = self._calibrate_quantize_model(result, **flags)
 
@@ -756,8 +738,7 @@ class TFLiteSavedModelConverterV2(TFLiteConverterBaseV2):
     converter_kwargs.update(quant_mode.converter_flags())
 
     result = _convert_saved_model(**converter_kwargs)
-    calibrate_and_quantize, flags = quant_mode.quantizer_flags(
-        self.inference_input_type, self.inference_output_type)
+    calibrate_and_quantize, flags = quant_mode.quantizer_flags()
     if calibrate_and_quantize:
       result = self._calibrate_quantize_model(result, **flags)
 
@@ -1307,8 +1288,11 @@ class TFLiteConverterBaseV1(TFLiteConverterBase):
           "please file a bug. You can opt-out "
           "by setting experimental_new_converter=False")
 
-    calibrate_quantize, flags = quant_mode.quantizer_flags(
-        self.inference_input_type, self.inference_output_type)
+    if not self.experimental_new_converter:
+      calibrate_quantize, flags = quant_mode.quantizer_flags(
+          self.inference_input_type, self.inference_output_type)
+    else:
+      calibrate_quantize, flags = quant_mode.quantizer_flags()
 
     self._validate_quantized_input_stats(converter_kwargs, calibrate_quantize)
 
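For the TF1 converter, the branch above means the old converter keeps passing the requested I/O types into quantizer_flags(), while the new converter defers them to the post-conversion step added in the next hunk. A rough TF1-style sketch; the frozen graph path and tensor names are placeholders:

import numpy as np
import tensorflow.compat.v1 as tf

def representative_dataset():
  for _ in range(100):
    yield [np.random.rand(1, 4).astype(np.float32)]

# Placeholder frozen graph and tensor names, for illustration only.
converter = tf.lite.TFLiteConverter.from_frozen_graph(
    "/path/to/frozen_graph.pb", ["input"], ["output"])
converter.experimental_new_converter = True  # takes the new else-branch above
converter.optimizations = [tf.lite.Optimize.DEFAULT]
converter.representative_dataset = representative_dataset
# With the new converter these types are applied after conversion by the
# shared model I/O utility rather than passed into quantizer_flags().
converter.inference_input_type = tf.int8
converter.inference_output_type = tf.int8
tflite_model = converter.convert()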
@@ -1329,6 +1313,12 @@ class TFLiteConverterBaseV1(TFLiteConverterBase):
     if calibrate_quantize:
       result = self._calibrate_quantize_model(result, **flags)
 
+    if self.experimental_new_converter:
+      flags_modify_model_io_type = quant_mode.flags_modify_model_io_type(
+          self.inference_input_type, self.inference_output_type)
+      if flags_modify_model_io_type:
+        result = _modify_model_io_type(result, **flags_modify_model_io_type)
+
     if self._experimental_sparsify_model:
       result = _mlir_sparsify(result)
 
@@ -114,6 +114,7 @@ def toco_convert(options, graph_def, input_tensors, output_tensors, **kwargs):
     converter = tf.compat.v1.lite.TFLiteConverter.from_frozen_graph(
         graphdef_file.name, input_arrays, output_tensors, input_shapes)
 
+    converter.experimental_new_converter = options.use_experimental_converter
     converter.optimizations = [tf.lite.Optimize.DEFAULT]
 
     if fully_quantize: