Modify model input/output in a separate utility for all quantized TF1/TF2 models.

PiperOrigin-RevId: 338133056
Change-Id: I27f2232d0f39f51565434695ad5a9e075770267f
Authored by Meghna Natraj on 2020-10-20 14:01:36 -07:00; committed by TensorFlower Gardener
parent d5ab30ca14
commit 5832fa274f
2 changed files with 27 additions and 36 deletions
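
For context, a minimal sketch of the TF2 post-training integer quantization flow that exercises this path, assuming a Keras `model` and a `representative_data_gen` calibration generator (both hypothetical placeholders):

import tensorflow as tf

converter = tf.lite.TFLiteConverter.from_keras_model(model)  # `model` is a placeholder Keras model
converter.optimizations = [tf.lite.Optimize.DEFAULT]
converter.representative_dataset = representative_data_gen  # placeholder calibration generator
# These two flags are what the consolidated utility applies to the converted
# flatbuffer, for both the TF1 and TF2 converter paths.
converter.inference_input_type = tf.int8
converter.inference_output_type = tf.int8
tflite_model = converter.convert()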
tensorflow/lite

@@ -211,16 +211,14 @@ class QuantizationMode(object):
     return (self.post_training_int16x8_no_float() or
             self.post_training_int16x8_allow_float())
 
-  def is_post_training_integer_quantize(self):
-    """Post training integer quantization."""
+  def is_integer_quantize(self):
     return (self.is_post_training_integer_quantize_8() or
-            self.is_post_training_integer_quantize_16x8())
+            self.is_post_training_integer_quantize_16x8() or
+            self.is_training_time_int8_allow_float())
 
-  def training_time_int8_allow_float(self):
-    """Training-time int8 quantize, allow float fallback."""
+  def is_training_time_int8_allow_float(self):
     return (self._any_optimization_enabled() and
-            not self.post_training_dynamic_range_int8() and
-            not self.post_training_fp16())
+            self.contains_training_quant_op())
 
   def post_training_int16x8_no_float(self):
     """Post training int16x8 quantize, disallow float fallback."""
@@ -249,11 +247,7 @@ class QuantizationMode(object):
   def fp32_execution(self):
     """If none of the above are true."""
-    return not (self.post_training_int8_no_float() or
-                self.post_training_int8_allow_float() or
-                self.training_time_int8_allow_float() or
-                self.post_training_int16x8_no_float() or
-                self.post_training_int16x8_allow_float() or
+    return not (self.is_integer_quantize() or
                 self.post_training_dynamic_range_int8() or
                 self.post_training_fp16())
@@ -263,17 +257,12 @@ class QuantizationMode(object):
   def converter_flags(self, inference_ty=None, inference_input_ty=None):
     """Flags to the converter."""
-    if self.is_post_training_integer_quantize():
-      # The inference_input_type is for the quantizer, then we need to keep the
-      # converter inference_input_type to float.
-      inference_input_ty = _dtypes.float32
-    if self.training_time_int8_allow_float():
+    if self.is_integer_quantize():
       return {
           "inference_type": inference_ty if inference_ty else \
             self.activations_type(),
-          "inference_input_type":
-              inference_input_ty if inference_input_ty else _dtypes.float32,
+          "inference_input_type": _dtypes.float32,
           "post_training_quantize": False,  # disable dynamic range quantization
           "quantize_to_float16": False  # disable float16 quantization
       }
@@ -326,19 +315,13 @@ class QuantizationMode(object):
     else:
       return False, None
 
-  def flags_modify_model_io_type(
-      self, input_type=_dtypes.float32, output_type=_dtypes.float32):
+  def flags_modify_model_io_type(self, input_ty=None, output_ty=None):
     """Flags for modifying the input and output type of a tflite model."""
-    is_post_training_quantize = self.quantizer_flags(input_type, output_type)[0]
-    is_training_time_only_quantize = self.training_time_int8_allow_float() and \
-      not is_post_training_quantize
-    # TODO(b/153576658): Consolidate post/during training quantization workflows
-    # to modify model input/output type after MLIR conversion.
-    if is_training_time_only_quantize:
+    if self.is_integer_quantize():
       return {
-          "inference_input_type": input_type,
-          "inference_output_type": output_type,
+          "inference_input_type": input_ty if input_ty else _dtypes.float32,
+          "inference_output_type": output_ty if output_ty else _dtypes.float32,
       }
     else:
       return None
@@ -563,7 +546,7 @@ class TFLiteConverterBaseV2(TFLiteConverterBase):
     """Validate inference_input_type and inference_output_type flags."""
     default_types = [_dtypes.float32]
     # We support integer input/output for integer quantized models only.
-    if quant_mode.training_time_int8_allow_float():
+    if quant_mode.is_integer_quantize():
       if quant_mode.is_post_training_integer_quantize_16x8():
         all_types = default_types + [_dtypes.int16]
       else:
@@ -645,8 +628,7 @@ class TFLiteConverterBaseV2(TFLiteConverterBase):
         output_tensors=output_tensors,
         **converter_kwargs)
-    calibrate_and_quantize, flags = quant_mode.quantizer_flags(
-        self.inference_input_type, self.inference_output_type)
+    calibrate_and_quantize, flags = quant_mode.quantizer_flags()
     if calibrate_and_quantize:
       result = self._calibrate_quantize_model(result, **flags)
@@ -756,8 +738,7 @@ class TFLiteSavedModelConverterV2(TFLiteConverterBaseV2):
     converter_kwargs.update(quant_mode.converter_flags())
     result = _convert_saved_model(**converter_kwargs)
-    calibrate_and_quantize, flags = quant_mode.quantizer_flags(
-        self.inference_input_type, self.inference_output_type)
+    calibrate_and_quantize, flags = quant_mode.quantizer_flags()
     if calibrate_and_quantize:
       result = self._calibrate_quantize_model(result, **flags)
@@ -1307,8 +1288,11 @@ class TFLiteConverterBaseV1(TFLiteConverterBase):
                    "please file a bug. You can opt-out "
                    "by setting experimental_new_converter=False")
 
-    calibrate_quantize, flags = quant_mode.quantizer_flags(
-        self.inference_input_type, self.inference_output_type)
+    if not self.experimental_new_converter:
+      calibrate_quantize, flags = quant_mode.quantizer_flags(
+          self.inference_input_type, self.inference_output_type)
+    else:
+      calibrate_quantize, flags = quant_mode.quantizer_flags()
 
     self._validate_quantized_input_stats(converter_kwargs, calibrate_quantize)
@@ -1329,6 +1313,12 @@ class TFLiteConverterBaseV1(TFLiteConverterBase):
     if calibrate_quantize:
       result = self._calibrate_quantize_model(result, **flags)
 
+    if self.experimental_new_converter:
+      flags_modify_model_io_type = quant_mode.flags_modify_model_io_type(
+          self.inference_input_type, self.inference_output_type)
+      if flags_modify_model_io_type:
+        result = _modify_model_io_type(result, **flags_modify_model_io_type)
+
     if self._experimental_sparsify_model:
       result = _mlir_sparsify(result)
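
A corresponding sketch of the TF1 path touched above: with experimental_new_converter enabled, the requested input/output types are now applied by the shared model-I/O utility after MLIR conversion rather than through the quantizer flags. The file path, array names, and `representative_data_gen` below are placeholders:

import tensorflow as tf

converter = tf.compat.v1.lite.TFLiteConverter.from_frozen_graph(
    "frozen_graph.pb", input_arrays=["input"], output_arrays=["output"])
converter.experimental_new_converter = True
converter.optimizations = [tf.lite.Optimize.DEFAULT]
converter.representative_dataset = representative_data_gen  # placeholder calibration generator
# With the new converter, these flags feed flags_modify_model_io_type() and are
# applied to the converted model by _modify_model_io_type().
converter.inference_input_type = tf.uint8
converter.inference_output_type = tf.uint8
tflite_model = converter.convert()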

@@ -114,6 +114,7 @@ def toco_convert(options, graph_def, input_tensors, output_tensors, **kwargs):
     converter = tf.compat.v1.lite.TFLiteConverter.from_frozen_graph(
         graphdef_file.name, input_arrays, output_tensors, input_shapes)
     converter.experimental_new_converter = options.use_experimental_converter
+    converter.optimizations = [tf.lite.Optimize.DEFAULT]
     if fully_quantize: