Support QAT conversion using TFLiteConverterV2
The V1 converter requires inference_type, inference_input_type and quantized_input_stats for conversion, whereas the V2 converter uses the FakeQuant (FQ) ops inside the graph for the conversion and input information. This CL does the following:
1. Moves the input_stats check from the convert code into the V1 converter, since that check is now specific to it.
2. Improves the condition checking for post-training calibration vs. weight-only quantization vs. training-time quantization.
3. Actually handles training-time quantization by passing the necessary flags to TOCO.
Importantly, this approach leaves the option open for QAT and post-training calibration quantization to be applied together in the same conversion.
PiperOrigin-RevId: 298533518
Change-Id: I48ec5b8db8f20242522ca7af70dcbe339b79aa2f
parent 48393637f8
commit 867c320558
tensorflow/lite/python
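For context, here is a minimal sketch of the user-facing flow this change enables with the V2 converter: a Keras model whose graph already contains FakeQuant ops (the toy model construction below is illustrative, not taken from this CL) is converted by only setting an optimization flag, with no inference_type or quantized_input_stats; inputs and outputs stay float, as the testTrainingTimeQuantizeConversion test further down also asserts.

import tensorflow as tf

# Toy stand-in for a quantization-aware-trained model: the forward pass
# contains a FakeQuant op, which is what the converter now detects.
# (Illustrative only; the tests in this diff build a QLinear layer instead.)
inputs = tf.keras.Input(shape=(2,))
x = tf.keras.layers.Dense(3)(inputs)
outputs = tf.keras.layers.Lambda(
    lambda t: tf.quantization.fake_quant_with_min_max_args(t, min=-6.0, max=6.0))(x)
model = tf.keras.Model(inputs, outputs)

# With the V2 converter, no inference_type/quantized_input_stats are needed:
# the FakeQuant ops in the graph supply the quantization parameters.
converter = tf.lite.TFLiteConverter.from_keras_model(model)
converter.optimizations = [tf.lite.Optimize.DEFAULT]
tflite_model = converter.convert()

# Input/output tensors remain float.
interpreter = tf.lite.Interpreter(model_content=tflite_model)
interpreter.allocate_tensors()
print(interpreter.get_input_details()[0]['dtype'])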
@@ -374,13 +374,9 @@ def build_toco_convert_protos(input_tensors,
     input_array.data_type = util.convert_dtype_to_tflite_type(
         input_tensor.dtype)

-    if _requires_input_stats(toco):
-      if quantized_input_stats:
-        input_array.mean_value, input_array.std_value = quantized_input_stats[
-            idx]
-      else:
-        raise ValueError("std_dev and mean must be defined when inference_type "
-                         "or inference_input_type is QUANTIZED_UINT8 or INT8.")
+    if _requires_input_stats(toco) and quantized_input_stats:
+      input_array.mean_value, input_array.std_value = quantized_input_stats[idx]

     if input_shapes is None:
       shape = input_tensor.shape
     else:
@@ -63,33 +63,6 @@ class ConvertTest(test_util.TensorFlowTestCase):
         quantized_input_stats=[(0., 1.)])
     self.assertTrue(tflite_model)

-  def testQuantizationInvalid(self):
-    with ops.Graph().as_default():
-      in_tensor = array_ops.placeholder(
-          shape=[1, 16, 16, 3], dtype=dtypes.float32)
-      out_tensor = array_ops.fake_quant_with_min_max_args(
-          in_tensor + in_tensor, min=0., max=1.)
-      sess = session.Session()
-
-    with self.assertRaises(ValueError) as error:
-      convert.toco_convert(
-          sess.graph_def, [in_tensor], [out_tensor],
-          inference_type=lite_constants.QUANTIZED_UINT8)
-    self.assertEqual(
-        "std_dev and mean must be defined when inference_type or "
-        "inference_input_type is QUANTIZED_UINT8 or INT8.",
-        str(error.exception))
-
-    with self.assertRaises(ValueError) as error:
-      convert.toco_convert(
-          sess.graph_def, [in_tensor], [out_tensor],
-          inference_type=lite_constants.QUANTIZED_UINT8,
-          inference_input_type=lite_constants.FLOAT)
-    self.assertEqual(
-        "std_dev and mean must be defined when inference_type or "
-        "inference_input_type is QUANTIZED_UINT8 or INT8.",
-        str(error.exception))
-
   def testGraphDefBasic(self):
     with ops.Graph().as_default():
       in_tensor = array_ops.placeholder(
@@ -211,6 +211,7 @@ class TFLiteConverterBase(object):
         raise ValueError(
             "Provide an input generator for representative_dataset")
     elif self._is_int8_target_required():
+      # TODO(b/150661651): Relax this check for QAT
       raise ValueError("representative_dataset is required when specifying "
                        "TFLITE_BUILTINS_INT8 or INT8 supported types.")

@@ -239,12 +240,35 @@ class TFLiteConverterBase(object):
             Optimize.DEFAULT
         ]))

+  def _contains_training_quant_op(self, graph_def):
+    """Checks if the graph contains any training-time quantization ops.
+
+    This is one of the simplest ways to detect whether the model is
+    training-time quantized, since FakeQuant ops are added only during
+    quantization aware training.
+
+    Args:
+      graph_def: GraphDef representing the TF graph.
+
+    Returns:
+      True/False
+    """
+    training_quant_ops = frozenset({
+        "FakeQuantWithMinMaxVars", "FakeQuantWithMinMaxVarsPerChannel",
+        "QuantizeAndDequantizeV2", "QuantizeAndDequantizeV3"})
+
+    for node_def in graph_def.node:
+      if any([op in node_def.name for op in training_quant_ops]):
+        return True
+    return False
+
   def _is_post_training_optimize(self):
     return self._is_int8_target_required() or self._any_optimization_enabled()

   def _is_int8_weight_only_quantize(self):
     return (self._is_post_training_optimize() and
-            (self.representative_dataset is None))
+            (self.representative_dataset is None) and
+            not self._contains_training_quant_op(self._graph_def))

   def _is_float16_quantize(self):
     return self._any_optimization_enabled() and (
@@ -255,6 +279,10 @@ class TFLiteConverterBase(object):
         self.representative_dataset and
         self._smallest_supported_type() != constants.FLOAT16)

+  def _is_training_time_quantize(self):
+    return (self._contains_training_quant_op(self._graph_def) and
+            self._any_optimization_enabled())
+
   def _calibrate_quantize_model(self, result, inference_input_type,
                                 inference_output_type):
     allow_float = not self._is_int8_target_required()
@@ -462,6 +490,7 @@ class TFLiteConverterV2(TFLiteConverterBase):
     frozen_func, graph_def = (
         _convert_to_constants.convert_variables_to_constants_v2_as_graph(
             self._funcs[0], lower_control_flow=False))
+    self._graph_def = graph_def
     input_tensors = [
         tensor for tensor in frozen_func.inputs
         if tensor.dtype != _dtypes.resource
@@ -504,6 +533,12 @@ class TFLiteConverterV2(TFLiteConverterBase):

     converter_kwargs = self._get_base_converter_args()

+    if self._is_training_time_quantize():
+      converter_kwargs.update({
+          "inference_type": constants.INT8,
+          "inference_input_type": constants.FLOAT,
+      })
+
     if not self.experimental_new_converter:
       logging.warning(
           "Please consider switching to use new converter by setting "
@@ -953,6 +988,21 @@ class TFLiteConverter(TFLiteConverterBase):
       return self.target_spec.supported_ops
     return object.__getattribute__(self, name)

+  def _validate_quantized_input_stats(self, converter_kwargs):
+    """Ensure quantized_input_stats provided if required."""
+
+    quantized_types = frozenset({constants.INT8, constants.QUANTIZED_UINT8})
+
+    requires_quantized_input_stats = (
+        (converter_kwargs["inference_type"] in quantized_types or
+         converter_kwargs["inference_input_type"] in quantized_types) and
+        not converter_kwargs["post_training_quantize"])
+
+    if (requires_quantized_input_stats and
+        not converter_kwargs["quantized_input_stats"]):
+      raise ValueError("std_dev and mean must be defined when inference_type "
+                       "or inference_input_type is QUANTIZED_UINT8 or INT8.")
+
   def convert(self):
     """Converts a TensorFlow GraphDef based on instance variables.

@@ -1085,6 +1135,8 @@ class TFLiteConverter(TFLiteConverterBase):
           "please file a bug. You can opt-out "
          "by setting experimental_new_converter=False")

+    self._validate_quantized_input_stats(converter_kwargs)
+
     # Converts model.
     if self._has_valid_tensors():
       result = _toco_convert_impl(
@@ -1082,6 +1082,46 @@ class FromSessionTest(TestModels, parameterized.TestCase):
     # Ensure that the quantized weights tflite model is smaller.
     self.assertTrue(len(quantized_tflite) < len(float_tflite))

+  @parameterized.named_parameters(
+      ('InferenceType_INT8', lite_constants.INT8),
+      ('InferenceType_QUANTIZED_INT8', lite_constants.QUANTIZED_UINT8))
+  def testRequiresInputStatsForTrainingTimeQuantization(self, quantized_type):
+    with ops.Graph().as_default():
+      in_tensor = array_ops.placeholder(
+          shape=[1, 16, 16, 3], dtype=dtypes.float32)
+      out_tensor = array_ops.fake_quant_with_min_max_args(
+          in_tensor + in_tensor, min=0., max=1.)
+      sess = session.Session()
+
+    quantized_converter = lite.TFLiteConverter.from_session(
+        sess, [in_tensor], [out_tensor])
+
+    with self.assertRaises(ValueError) as error:
+      quantized_converter.inference_type = quantized_type
+      quantized_converter.convert()
+    self.assertEqual(
+        'std_dev and mean must be defined when inference_type or '
+        'inference_input_type is QUANTIZED_UINT8 or INT8.',
+        str(error.exception))
+
+    with self.assertRaises(ValueError) as error:
+      quantized_converter.inference_type = lite_constants.FLOAT
+      quantized_converter.inference_input_type = quantized_type
+      quantized_converter.convert()
+    self.assertEqual(
+        'std_dev and mean must be defined when inference_type or '
+        'inference_input_type is QUANTIZED_UINT8 or INT8.',
+        str(error.exception))
+
+    quantized_converter.inference_type = quantized_type
+    quantized_converter.inference_input_type = quantized_type
+
+    input_arrays = quantized_converter.get_input_arrays()
+    quantized_converter.quantized_input_stats = {
+        input_arrays[0]: (0., 1.)
+    }
+    quantized_converter.convert()
+
   def testFloatTocoConverter(self):
     """Tests deprecated test TocoConverter."""
     with ops.Graph().as_default():
@@ -284,6 +284,60 @@ class FromConcreteFunctionTest(TestModels):
     # Ensure that the quantized weights tflite model is smaller.
     self.assertLess(len(quantized_tflite), len(float_tflite))

+  def _getTrainingTimeQuantizedModel(self):
+    class QLinear(keras.layers.Layer):
+
+      def __init__(self, units=3, **kwargs):
+        super(QLinear, self).__init__(**kwargs)
+        self.units = units
+
+      def build(self, input_shape):
+        self.w = self.add_weight(shape=(input_shape[-1], self.units),
+                                 initializer='random_normal',
+                                 trainable=True)
+        self.min_var = self.add_weight(
+            'min',
+            initializer=keras.initializers.Constant(-6.0),
+            trainable=False)
+        self.max_var = self.add_weight(
+            'max',
+            initializer=keras.initializers.Constant(6.0),
+            trainable=False)
+
+      def call(self, inputs):
+        x = array_ops.fake_quant_with_min_max_vars(
+            inputs, self.min_var, self.max_var)
+
+        w_fq = array_ops.fake_quant_with_min_max_vars(
+            self.w, self.min_var, self.max_var)
+        x = math_ops.matmul(x, w_fq)
+
+        x = array_ops.fake_quant_with_min_max_vars(
+            x, self.min_var, self.max_var)
+
+        return x
+
+    return keras.Sequential(QLinear(3, input_shape=(2,)))
+
+  @test_util.run_v2_only
+  def testTrainingTimeQuantizeConversion(self):
+    model = self._getTrainingTimeQuantizedModel()
+
+    float_converter = lite.TFLiteConverterV2.from_keras_model(model)
+    float_tflite = float_converter.convert()
+    self.assertTrue(float_tflite)
+
+    quantized_converter = lite.TFLiteConverterV2.from_keras_model(model)
+    quantized_converter.optimizations = [lite.Optimize.DEFAULT]
+    quantized_tflite = quantized_converter.convert()
+    self.assertTrue(quantized_tflite)
+
+    # Ensure that the quantized weights tflite model is smaller.
+    self.assertLess(len(quantized_tflite), len(float_tflite))
+
+    interpreter = Interpreter(model_content=quantized_tflite)
+    self.assertEqual(np.float32, interpreter.get_input_details()[0]['dtype'])
+
   @test_util.run_v2_only
   def testNewQuantizer(self):
     """Test the model quantized by the new converter."""
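The commit message notes that this approach leaves training-time quantization and post-training calibration free to be applied together in the same conversion. A hedged sketch of what that combined call might look like follows; the toy FakeQuant model, the representative_data_gen generator, and the shapes are illustrative assumptions, not part of this change.

import numpy as np
import tensorflow as tf

# Toy QAT-style model: the FakeQuant op is what marks the graph as
# training-time quantized (illustrative only).
inputs = tf.keras.Input(shape=(2,))
x = tf.keras.layers.Dense(3)(inputs)
outputs = tf.keras.layers.Lambda(
    lambda t: tf.quantization.fake_quant_with_min_max_args(t, min=-6.0, max=6.0))(x)
model = tf.keras.Model(inputs, outputs)

# Hypothetical calibration data generator (name and contents are assumptions).
def representative_data_gen():
  for _ in range(100):
    yield [np.random.uniform(size=(1, 2)).astype(np.float32)]

converter = tf.lite.TFLiteConverter.from_keras_model(model)
converter.optimizations = [tf.lite.Optimize.DEFAULT]
# Supplying a representative dataset on top of a FakeQuant graph also runs
# post-training calibration; the CL explicitly keeps this combination open.
converter.representative_dataset = representative_data_gen
tflite_model = converter.convert()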