Add inference_input_type and inference_output_type flags to the TF 2.x TFLiteConverter (backward compatible with TF 1.x) to support integer (tf.int8, tf.uint8) input and output types in post-training full-integer quantized models.

PiperOrigin-RevId: 313668965
Change-Id: Iea684507f58651b34dada0285b00a82e80066aab
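
For context, a minimal usage sketch of the flags described above. This is illustrative only and not part of the change: the toy conv model, the name `representative_dataset_gen`, and the choice of tf.uint8 are assumptions; the converter calls mirror the public TF 2.x API exercised by the tests in this diff.

```python
import numpy as np
import tensorflow as tf

# A toy float model to quantize (illustrative only).
@tf.function(input_signature=[tf.TensorSpec(shape=[1, 5, 5, 3], dtype=tf.float32)])
def toy_model(x):
  kernel = tf.ones([3, 3, 3, 8])
  return tf.nn.relu(tf.nn.conv2d(x, kernel, strides=[1, 1, 1, 1], padding='SAME'))

def representative_dataset_gen():
  # Calibration data used to derive the quantization parameters.
  for _ in range(10):
    yield [np.random.random_sample((1, 5, 5, 3)).astype(np.float32)]

converter = tf.lite.TFLiteConverter.from_concrete_functions(
    [toy_model.get_concrete_function()])
converter.optimizations = [tf.lite.Optimize.DEFAULT]
converter.representative_dataset = representative_dataset_gen
converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
# The flags this change adds: request integer model input/output instead of
# the default float32.
converter.inference_input_type = tf.uint8   # or tf.int8
converter.inference_output_type = tf.uint8  # or tf.int8
tflite_quant_model = converter.convert()
```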
A. Unique TensorFlower 2020-05-28 15:15:50 -07:00 committed by TensorFlower Gardener
parent 60c828a70e
commit 7d605fb0e2
2 changed files with 32 additions and 217 deletions
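
The tests removed in the second file below verified the flags by inspecting the interpreter's tensor details; a sketch of that verification pattern, assuming `tflite_quant_model` was produced as in the snippet above:

```python
import numpy as np
import tensorflow as tf

# Assumes `tflite_quant_model` comes from the conversion sketch above.
interpreter = tf.lite.Interpreter(model_content=tflite_quant_model)
interpreter.allocate_tensors()
input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()
# With integer inference_input_type/inference_output_type, the model's
# input and output tensors should be uint8 rather than float32.
assert input_details[0]['dtype'] == np.uint8
assert output_details[0]['dtype'] == np.uint8
```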

View File

@@ -201,11 +201,6 @@ class QuantizationMode(object):
self._representative_dataset is not None and
self._smallest_supported_type() == constants.INT8)
def is_post_training_integer_quantize(self):
"""Post training integer quantization."""
return (self.post_training_int8_no_float() or
self.post_training_int8_allow_float())
def training_time_int8_allow_float(self):
"""Training-time int8 quantize, allow float fallback."""
return (self._any_optimization_enabled() and
@@ -418,56 +413,7 @@ class TFLiteConverterBase(object):
class TFLiteConverterBaseV2(TFLiteConverterBase):
"""Converter subclass to share functionality between V2 converters.
Attributes:
allow_custom_ops: Boolean indicating whether to allow custom operations.
When False, any unknown operation is an error. When True, custom ops are
created for any op that is unknown. The developer needs to provide these
to the TensorFlow Lite runtime with a custom resolver. (default False)
optimizations: Experimental flag, subject to change. A list of optimizations
to apply when converting the model. E.g. `[Optimize.DEFAULT]`
representative_dataset: A representative dataset that can be used to
generate input and output samples for the model. The converter can use the
dataset to evaluate different optimizations. Note that this is an optional
attribute, but it is required if INT8 is the only supported builtin ops set
in the target spec.
target_spec: Experimental flag, subject to change. Specification of target
device.
inference_input_type: Data type of the input layer. Note that integer types
(tf.int8 and tf.uint8) are currently only supported for post training
integer quantization. (default tf.float32, must be in {tf.float32,
tf.int8, tf.uint8})
inference_output_type: Data type of the output layer. Note that integer
types (tf.int8 and tf.uint8) are currently only supported for post
training integer quantization. (default tf.float32, must be in
{tf.float32, tf.int8, tf.uint8})
experimental_new_converter: Experimental flag, subject to change. Enables
MLIR-based conversion instead of TOCO conversion.
"""
def __init__(self):
"""Constructor for TFLiteConverter."""
super(TFLiteConverterBaseV2, self).__init__()
self.inference_input_type = constants.FLOAT
self.inference_output_type = constants.FLOAT
def _validate_inference_input_output_types(self, quant_mode):
"""Validate inference_input_type and inference_output_type flags."""
default_types = [constants.FLOAT, None]
# We only support integer types for post training integer quantization
# as we have statistical information to quantize the input and output.
if quant_mode.is_post_training_integer_quantize():
all_types = default_types + [constants.INT8, constants.QUANTIZED_UINT8]
if self.inference_input_type not in all_types or \
self.inference_output_type not in all_types:
all_types_names = ["tf." + t.name for t in all_types if t is not None]
raise ValueError("The inference_input_type and inference_output_type "
"must be in {}.".format(all_types_names))
elif self.inference_input_type not in default_types or \
self.inference_output_type not in default_types:
raise ValueError("The inference_input_type and inference_output_type "
"must be tf.float32.")
"""Converter subclass to share functionality between V2 converters."""
def convert(self, graph_def, input_tensors, output_tensors):
"""Converts a TensorFlow GraphDef based on instance variables.
@@ -491,8 +437,6 @@ class TFLiteConverterBaseV2(TFLiteConverterBase):
quant_mode = QuantizationMode(self.optimizations, self.target_spec,
self.representative_dataset, graph_def)
self._validate_inference_input_output_types(quant_mode)
if not self._is_unknown_shapes_allowed():
# Checks dimensions in input tensor.
for tensor in input_tensors:
@@ -535,9 +479,6 @@ class TFLiteConverterBaseV2(TFLiteConverterBase):
"quantize_to_float16": True,
})
# Converter requires that the inference_input_type flag is set to FLOAT
converter_kwargs.update({"inference_input_type": constants.FLOAT})
if not self.experimental_new_converter:
logging.warning(
"Please consider switching to use new converter by setting "
@@ -557,11 +498,11 @@ class TFLiteConverterBaseV2(TFLiteConverterBase):
**converter_kwargs)
if quant_mode.post_training_int8_no_float():
result = self._calibrate_quantize_model(result, self.inference_input_type,
self.inference_output_type, False)
result = self._calibrate_quantize_model(result, constants.FLOAT,
constants.FLOAT, False)
elif quant_mode.post_training_int8_allow_float():
result = self._calibrate_quantize_model(result, self.inference_input_type,
self.inference_output_type, True)
result = self._calibrate_quantize_model(result, constants.FLOAT,
constants.FLOAT, True)
if self._experimental_sparsify_model:
result = _mlir_sparsify(result)
@@ -817,9 +758,12 @@ class TFLiteConverterV2(TFLiteFrozenGraphConverterV2):
Attributes:
allow_custom_ops: Boolean indicating whether to allow custom operations.
When False, any unknown operation is an error. When True, custom ops are
created for any op that is unknown. The developer needs to provide these
to the TensorFlow Lite runtime with a custom resolver. (default False)
When false any unknown operation is an error. When true, custom ops are
created for any op that is unknown. The developer will need to provide
these to the TensorFlow Lite runtime with a custom resolver.
(default False)
target_spec: Experimental flag, subject to change. Specification of target
device.
optimizations: Experimental flag, subject to change. A list of optimizations
to apply when converting the model. E.g. `[Optimize.DEFAULT]`
representative_dataset: A representative dataset that can be used to
@@ -827,19 +771,8 @@ class TFLiteConverterV2(TFLiteFrozenGraphConverterV2):
dataset to evaluate different optimizations. Note that this is an optional
attribute, but it is required if INT8 is the only supported builtin ops set
in the target spec.
target_spec: Experimental flag, subject to change. Specification of target
device.
inference_input_type: Data type of the input layer. Note that integer types
(tf.int8 and tf.uint8) are currently only supported for post training
integer quantization. (default tf.float32, must be in {tf.float32,
tf.int8, tf.uint8})
inference_output_type: Data type of the output layer. Note that integer
types (tf.int8 and tf.uint8) are currently only supported for post
training integer quantization. (default tf.float32, must be in
{tf.float32, tf.int8, tf.uint8})
experimental_new_converter: Experimental flag, subject to change. Enables
MLIR-based conversion instead of TOCO conversion.
experimental_new_converter: Experimental flag, subject to change.
Enables MLIR-based conversion instead of TOCO conversion.
Example usage:
```python

View File

@@ -71,27 +71,6 @@ class FromConcreteFunctionTest(lite_v2_test_util.ModelTest):
actual_value = self._evaluateTFLiteModel(tflite_model, [input_data])
self.assertEqual(expected_value.numpy(), actual_value)
@parameterized.named_parameters(
('_INT8InputOutput', lite.constants.INT8),
('_UINT8InputOutput', lite.constants.QUANTIZED_UINT8))
@test_util.run_v2_only
def testInvalidFloat(self, inference_input_output_type):
root = self._getSimpleVariableModel()
input_data = tf.constant(1., shape=[1])
concrete_func = root.f.get_concrete_function(input_data)
# Convert model.
converter = lite.TFLiteConverterV2.from_concrete_functions([concrete_func])
# We don't support integer types as we don't have statistical information
# to quantize (only supported for post training integer quantization).
with self.assertRaises(ValueError) as error:
converter.inference_input_type = inference_input_output_type
converter.inference_output_type = inference_input_output_type
converter.convert()
self.assertEqual(
'The inference_input_type and inference_output_type '
'must be tf.float32.', str(error.exception))
@test_util.run_v2_only
def testScalarInput(self):
root = self._getSimpleVariableModel()
@@ -193,113 +172,39 @@ class FromConcreteFunctionTest(lite_v2_test_util.ModelTest):
self.assertLess(len(quantized_tflite), len(float_tflite))
@parameterized.named_parameters(
('_INT8InputOutput', lite.constants.INT8),
('_UINT8InputOutput', lite.constants.QUANTIZED_UINT8))
@test_util.run_v2_only
def testInvalidPostTrainingDynamicRangeQuantization(
self, inference_input_output_type):
func, _ = self._getCalibrationQuantizeModel()
# Convert float model.
converter = lite.TFLiteConverterV2.from_concrete_functions([func])
tflite_model = converter.convert()
self.assertTrue(tflite_model)
# Convert quantized model.
quantized_converter = lite.TFLiteConverterV2.from_concrete_functions([func])
quantized_converter.optimizations = [lite.Optimize.DEFAULT]
# We don't support integer types as we don't have statistical information
# to quantize (only supported for post training integer quantization).
with self.assertRaises(ValueError) as error:
quantized_converter.inference_input_type = inference_input_output_type
quantized_converter.inference_output_type = inference_input_output_type
quantized_converter.convert()
self.assertEqual(
'The inference_input_type and inference_output_type '
'must be tf.float32.', str(error.exception))
@parameterized.named_parameters(
('_DefaultFLOAT32InputOutput', lite.constants.FLOAT),
('_INT8InputOutput', lite.constants.INT8),
('_UINT8InputOutput', lite.constants.QUANTIZED_UINT8))
def testPostTrainingIntegerAllowFloatQuantization(
self, inference_input_output_type):
('EnableMlirQuantizer', True), # enable mlir quantizer
('DisableMlirQuantizer', False)) # disable mlir quantizer
def testCalibrateAndQuantizeBuiltinInt8(self, mlir_quantizer):
func, calibration_gen = self._getCalibrationQuantizeModel()
# Convert float model.
converter = lite.TFLiteConverterV2.from_concrete_functions([func])
tflite_model = converter.convert()
self.assertTrue(tflite_model)
# Convert quantized model.
quantized_converter = lite.TFLiteConverterV2.from_concrete_functions([func])
quantized_converter.optimizations = [lite.Optimize.DEFAULT]
quantized_converter.representative_dataset = calibration_gen
quantized_converter.inference_input_type = inference_input_output_type
quantized_converter.inference_output_type = inference_input_output_type
quantized_tflite_model = quantized_converter.convert()
self.assertTrue(quantized_tflite_model)
interpreter = Interpreter(model_content=quantized_tflite_model)
interpreter.allocate_tensors()
input_details = interpreter.get_input_details()
self.assertLen(input_details, 1)
self.assertEqual(inference_input_output_type.as_numpy_dtype,
input_details[0]['dtype'])
output_details = interpreter.get_output_details()
self.assertLen(output_details, 1)
self.assertEqual(inference_input_output_type.as_numpy_dtype,
output_details[0]['dtype'])
# Ensure that the quantized tflite model is smaller.
self.assertLess(len(quantized_tflite_model), len(tflite_model))
@parameterized.named_parameters(
('_DefaultFLOAT32InputOutput_UseTargetTypesFlag',
lite.constants.FLOAT, False),
('_DefaultFLOAT32InputOutput', lite.constants.FLOAT, True),
('_INT8InputOutput', lite.constants.INT8, True),
('_UINT8InputOutput', lite.constants.QUANTIZED_UINT8, True))
@test_util.run_v2_only
def testPostTrainingIntegerNoFloatQuantization(self,
inference_input_output_type,
use_target_ops_flag):
func, calibration_gen = self._getCalibrationQuantizeModel()
# Convert float model.
converter = lite.TFLiteConverterV2.from_concrete_functions([func])
tflite_model = converter.convert()
self.assertTrue(tflite_model)
float_converter = lite.TFLiteConverterV2.from_concrete_functions([func])
float_tflite = float_converter.convert()
self.assertTrue(float_tflite)
# Convert model by specifying target spec (instead of optimizations), since
# when targeting an integer only backend, quantization is mandatory.
quantized_converter = lite.TFLiteConverterV2.from_concrete_functions([func])
quantized_converter.optimizations = [lite.Optimize.DEFAULT]
quantized_converter.target_spec.supported_ops = [
lite.OpsSet.TFLITE_BUILTINS_INT8
]
quantized_converter.representative_dataset = calibration_gen
if use_target_ops_flag:
quantized_converter.target_spec.supported_ops = [
lite.OpsSet.TFLITE_BUILTINS_INT8
]
else:
quantized_converter.target_spec.supported_types = [lite.constants.INT8]
quantized_converter.inference_input_type = inference_input_output_type
quantized_converter.inference_output_type = inference_input_output_type
quantized_tflite_model = quantized_converter.convert()
self.assertTrue(quantized_tflite_model)
quantized_converter._experimental_new_quantizer = mlir_quantizer
quantized_tflite = quantized_converter.convert()
self.assertTrue(quantized_tflite)
interpreter = Interpreter(model_content=quantized_tflite_model)
# The default input and output types should be float.
interpreter = Interpreter(model_content=quantized_tflite)
interpreter.allocate_tensors()
input_details = interpreter.get_input_details()
self.assertLen(input_details, 1)
self.assertEqual(inference_input_output_type.as_numpy_dtype,
input_details[0]['dtype'])
self.assertEqual(np.float32, input_details[0]['dtype'])
output_details = interpreter.get_output_details()
self.assertLen(output_details, 1)
self.assertEqual(inference_input_output_type.as_numpy_dtype,
output_details[0]['dtype'])
self.assertEqual(np.float32, output_details[0]['dtype'])
# Ensure that the quantized tflite model is smaller.
self.assertLess(len(quantized_tflite_model), len(tflite_model))
# Ensure that the quantized weights tflite model is smaller.
self.assertLess(len(quantized_tflite), len(float_tflite))
def testCalibrateAndQuantizeBuiltinInt16(self):
func, calibration_gen = self._getCalibrationQuantizeModel()
@@ -374,7 +279,7 @@ class FromConcreteFunctionTest(lite_v2_test_util.ModelTest):
return tf.keras.Sequential(QLinear(3, input_shape=(2,)))
@test_util.run_v2_only
def testTrainingTimeQuantization(self):
def testTrainingTimeQuantizeConversion(self):
model = self._getTrainingTimeQuantizedModel()
float_converter = lite.TFLiteConverterV2.from_keras_model(model)
@@ -392,29 +297,6 @@ class FromConcreteFunctionTest(lite_v2_test_util.ModelTest):
interpreter = Interpreter(model_content=quantized_tflite)
self.assertEqual(np.float32, interpreter.get_input_details()[0]['dtype'])
@parameterized.named_parameters(
('_INT8InputOutput', lite.constants.INT8),
('_UINT8InputOutput', lite.constants.QUANTIZED_UINT8))
def testInvalidTrainingTimeQuantization(self, inference_input_output_type):
# We currently don't support integer inference_input_type and
# inference_output_type flags for training time quantization.
model = self._getTrainingTimeQuantizedModel()
converter = lite.TFLiteConverterV2.from_keras_model(model)
tflite_model = converter.convert()
self.assertTrue(tflite_model)
quantized_converter = lite.TFLiteConverterV2.from_keras_model(model)
quantized_converter.optimizations = [lite.Optimize.DEFAULT]
with self.assertRaises(ValueError) as error:
quantized_converter.inference_input_type = inference_input_output_type
quantized_converter.inference_output_type = inference_input_output_type
quantized_converter.convert()
self.assertEqual(
'The inference_input_type and inference_output_type '
'must be tf.float32.', str(error.exception))
@test_util.run_v2_only
def testNewQuantizer(self):
"""Test the model quantized by the new converter."""