Support QAT conversion using TFLiteConverterV2

The V1 converter requires inference_type, inference_input_type
and quantized_input_stats to be set for quantized conversion, whereas
the V2 converter derives both the conversion mode and the input
information from the FakeQuant (FQ) ops inside the graph.
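
For example, a minimal V1-style invocation has to spell the quantization
parameters out explicitly (a sketch only, not part of this change; the
session, tensors and the "input" array name are placeholders):

  converter = tf.compat.v1.lite.TFLiteConverter.from_session(
      sess, [in_tensor], [out_tensor])
  converter.inference_type = tf.uint8
  # Per-input (mean, std_dev) stats; the V1 converter requires these for
  # quantized inference types.
  converter.quantized_input_stats = {"input": (0., 1.)}
  tflite_model = converter.convert()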

This CL does the following:
  1. Move the input_stats check from the convert code into
     the V1 converter, since it is now specific to it.
  2. Improve the condition checks that distinguish post-training
     calibration quantize, weight-only quantize and training-time
     quantize.
  3. Actually handle training-time quantize by passing the
     necessary flags to TOCO.

Importantly, this approach leaves open the option of applying both
QAT and post-training calibration quantization together in the same
conversion.
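
With the V2 converter, a QAT model converts from the FakeQuant ops alone,
and a representative dataset can optionally be layered on top so that
post-training calibration covers any ops the FakeQuant ops miss. A sketch
under assumed names (quantized_model and rep_data_gen are placeholders):

  converter = tf.lite.TFLiteConverter.from_keras_model(quantized_model)
  converter.optimizations = [tf.lite.Optimize.DEFAULT]
  # Optional: combine QAT with post-training calibration in one conversion.
  # converter.representative_dataset = rep_data_gen
  tflite_model = converter.convert()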

PiperOrigin-RevId: 298533518
Change-Id: I48ec5b8db8f20242522ca7af70dcbe339b79aa2f
Author: Pulkit Bhuwalka, 2020-03-02 23:06:44 -08:00 (committed by TensorFlower Gardener)
Commit: 867c320558 (parent 48393637f8)
5 changed files with 150 additions and 35 deletions


@@ -374,13 +374,9 @@ def build_toco_convert_protos(input_tensors,
     input_array.data_type = util.convert_dtype_to_tflite_type(
         input_tensor.dtype)
-    if _requires_input_stats(toco):
-      if quantized_input_stats:
-        input_array.mean_value, input_array.std_value = quantized_input_stats[
-            idx]
-      else:
-        raise ValueError("std_dev and mean must be defined when inference_type "
-                         "or inference_input_type is QUANTIZED_UINT8 or INT8.")
+    if _requires_input_stats(toco) and quantized_input_stats:
+      input_array.mean_value, input_array.std_value = quantized_input_stats[idx]
     if input_shapes is None:
       shape = input_tensor.shape
     else:


@@ -63,33 +63,6 @@ class ConvertTest(test_util.TensorFlowTestCase):
         quantized_input_stats=[(0., 1.)])
     self.assertTrue(tflite_model)

-  def testQuantizationInvalid(self):
-    with ops.Graph().as_default():
-      in_tensor = array_ops.placeholder(
-          shape=[1, 16, 16, 3], dtype=dtypes.float32)
-      out_tensor = array_ops.fake_quant_with_min_max_args(
-          in_tensor + in_tensor, min=0., max=1.)
-      sess = session.Session()
-
-    with self.assertRaises(ValueError) as error:
-      convert.toco_convert(
-          sess.graph_def, [in_tensor], [out_tensor],
-          inference_type=lite_constants.QUANTIZED_UINT8)
-    self.assertEqual(
-        "std_dev and mean must be defined when inference_type or "
-        "inference_input_type is QUANTIZED_UINT8 or INT8.",
-        str(error.exception))
-    with self.assertRaises(ValueError) as error:
-      convert.toco_convert(
-          sess.graph_def, [in_tensor], [out_tensor],
-          inference_type=lite_constants.QUANTIZED_UINT8,
-          inference_input_type=lite_constants.FLOAT)
-    self.assertEqual(
-        "std_dev and mean must be defined when inference_type or "
-        "inference_input_type is QUANTIZED_UINT8 or INT8.",
-        str(error.exception))
-
   def testGraphDefBasic(self):
     with ops.Graph().as_default():
       in_tensor = array_ops.placeholder(


@@ -211,6 +211,7 @@ class TFLiteConverterBase(object):
         raise ValueError(
             "Provide an input generator for representative_dataset")
     elif self._is_int8_target_required():
+      # TODO(b/150661651): Relax this check for QAT
       raise ValueError("representative_dataset is required when specifying "
                        "TFLITE_BUILTINS_INT8 or INT8 supported types.")
@@ -239,12 +240,35 @@ class TFLiteConverterBase(object):
             Optimize.DEFAULT
         ]))

+  def _contains_training_quant_op(self, graph_def):
+    """Checks if the graph contains any training-time quantization ops.
+
+    This is one of the simplest ways to detect whether the model is
+    training-time quantized, since FakeQuant ops are added only during
+    quantization aware training.
+
+    Args:
+      graph_def: GraphDef representing the TF graph.
+
+    Returns:
+      True/False
+    """
+    training_quant_ops = frozenset({
+        "FakeQuantWithMinMaxVars", "FakeQuantWithMinMaxVarsPerChannel",
+        "QuantizeAndDequantizeV2", "QuantizeAndDequantizeV3"})
+    for node_def in graph_def.node:
+      if any([op in node_def.name for op in training_quant_ops]):
+        return True
+    return False
+
   def _is_post_training_optimize(self):
     return self._is_int8_target_required() or self._any_optimization_enabled()

   def _is_int8_weight_only_quantize(self):
     return (self._is_post_training_optimize() and
-            (self.representative_dataset is None))
+            (self.representative_dataset is None) and
+            not self._contains_training_quant_op(self._graph_def))

   def _is_float16_quantize(self):
     return self._any_optimization_enabled() and (
@@ -255,6 +279,10 @@ class TFLiteConverterBase(object):
             self.representative_dataset and
             self._smallest_supported_type() != constants.FLOAT16)

+  def _is_training_time_quantize(self):
+    return (self._contains_training_quant_op(self._graph_def) and
+            self._any_optimization_enabled())
+
   def _calibrate_quantize_model(self, result, inference_input_type,
                                 inference_output_type):
     allow_float = not self._is_int8_target_required()
@@ -462,6 +490,7 @@ class TFLiteConverterV2(TFLiteConverterBase):
     frozen_func, graph_def = (
         _convert_to_constants.convert_variables_to_constants_v2_as_graph(
             self._funcs[0], lower_control_flow=False))
+    self._graph_def = graph_def
     input_tensors = [
         tensor for tensor in frozen_func.inputs
         if tensor.dtype != _dtypes.resource
@@ -504,6 +533,12 @@ class TFLiteConverterV2(TFLiteConverterBase):
     converter_kwargs = self._get_base_converter_args()

+    if self._is_training_time_quantize():
+      converter_kwargs.update({
+          "inference_type": constants.INT8,
+          "inference_input_type": constants.FLOAT,
+      })
+
     if not self.experimental_new_converter:
       logging.warning(
           "Please consider switching to use new converter by setting "
@@ -953,6 +988,21 @@ class TFLiteConverter(TFLiteConverterBase):
       return self.target_spec.supported_ops
     return object.__getattribute__(self, name)

+  def _validate_quantized_input_stats(self, converter_kwargs):
+    """Ensure quantized_input_stats provided if required."""
+    quantized_types = frozenset({constants.INT8, constants.QUANTIZED_UINT8})
+
+    requires_quantized_input_stats = (
+        (converter_kwargs["inference_type"] in quantized_types or
+         converter_kwargs["inference_input_type"] in quantized_types) and
+        not converter_kwargs["post_training_quantize"])
+
+    if (requires_quantized_input_stats and
+        not converter_kwargs["quantized_input_stats"]):
+      raise ValueError("std_dev and mean must be defined when inference_type "
+                       "or inference_input_type is QUANTIZED_UINT8 or INT8.")
+
   def convert(self):
     """Converts a TensorFlow GraphDef based on instance variables.
@@ -1085,6 +1135,8 @@ class TFLiteConverter(TFLiteConverterBase):
           "please file a bug. You can opt-out "
           "by setting experimental_new_converter=False")

+    self._validate_quantized_input_stats(converter_kwargs)
+
     # Converts model.
     if self._has_valid_tensors():
       result = _toco_convert_impl(


@@ -1082,6 +1082,46 @@ class FromSessionTest(TestModels, parameterized.TestCase):
     # Ensure that the quantized weights tflite model is smaller.
     self.assertTrue(len(quantized_tflite) < len(float_tflite))

+  @parameterized.named_parameters(
+      ('InferenceType_INT8', lite_constants.INT8),
+      ('InferenceType_QUANTIZED_INT8', lite_constants.QUANTIZED_UINT8))
+  def testRequiresInputStatsForTrainingTimeQuantization(self, quantized_type):
+    with ops.Graph().as_default():
+      in_tensor = array_ops.placeholder(
+          shape=[1, 16, 16, 3], dtype=dtypes.float32)
+      out_tensor = array_ops.fake_quant_with_min_max_args(
+          in_tensor + in_tensor, min=0., max=1.)
+      sess = session.Session()
+
+    quantized_converter = lite.TFLiteConverter.from_session(
+        sess, [in_tensor], [out_tensor])
+
+    with self.assertRaises(ValueError) as error:
+      quantized_converter.inference_type = quantized_type
+      quantized_converter.convert()
+    self.assertEqual(
+        'std_dev and mean must be defined when inference_type or '
+        'inference_input_type is QUANTIZED_UINT8 or INT8.',
+        str(error.exception))
+
+    with self.assertRaises(ValueError) as error:
+      quantized_converter.inference_type = lite_constants.FLOAT
+      quantized_converter.inference_input_type = quantized_type
+      quantized_converter.convert()
+    self.assertEqual(
+        'std_dev and mean must be defined when inference_type or '
+        'inference_input_type is QUANTIZED_UINT8 or INT8.',
+        str(error.exception))
+
+    quantized_converter.inference_type = quantized_type
+    quantized_converter.inference_input_type = quantized_type
+    input_arrays = quantized_converter.get_input_arrays()
+    quantized_converter.quantized_input_stats = {
+        input_arrays[0]: (0., 1.)
+    }
+    quantized_converter.convert()
+
   def testFloatTocoConverter(self):
     """Tests deprecated test TocoConverter."""
     with ops.Graph().as_default():


@@ -284,6 +284,60 @@ class FromConcreteFunctionTest(TestModels):
     # Ensure that the quantized weights tflite model is smaller.
     self.assertLess(len(quantized_tflite), len(float_tflite))

+  def _getTrainingTimeQuantizedModel(self):
+
+    class QLinear(keras.layers.Layer):
+
+      def __init__(self, units=3, **kwargs):
+        super(QLinear, self).__init__(**kwargs)
+        self.units = units
+
+      def build(self, input_shape):
+        self.w = self.add_weight(shape=(input_shape[-1], self.units),
+                                 initializer='random_normal',
+                                 trainable=True)
+        self.min_var = self.add_weight(
+            'min',
+            initializer=keras.initializers.Constant(-6.0),
+            trainable=False)
+        self.max_var = self.add_weight(
+            'max',
+            initializer=keras.initializers.Constant(6.0),
+            trainable=False)
+
+      def call(self, inputs):
+        x = array_ops.fake_quant_with_min_max_vars(
+            inputs, self.min_var, self.max_var)
+        w_fq = array_ops.fake_quant_with_min_max_vars(
+            self.w, self.min_var, self.max_var)
+        x = math_ops.matmul(x, w_fq)
+        x = array_ops.fake_quant_with_min_max_vars(
+            x, self.min_var, self.max_var)
+        return x
+
+    return keras.Sequential(QLinear(3, input_shape=(2,)))
+
+  @test_util.run_v2_only
+  def testTrainingTimeQuantizeConversion(self):
+    model = self._getTrainingTimeQuantizedModel()
+
+    float_converter = lite.TFLiteConverterV2.from_keras_model(model)
+    float_tflite = float_converter.convert()
+    self.assertTrue(float_tflite)
+
+    quantized_converter = lite.TFLiteConverterV2.from_keras_model(model)
+    quantized_converter.optimizations = [lite.Optimize.DEFAULT]
+    quantized_tflite = quantized_converter.convert()
+    self.assertTrue(quantized_tflite)
+
+    # Ensure that the quantized weights tflite model is smaller.
+    self.assertLess(len(quantized_tflite), len(float_tflite))
+
+    interpreter = Interpreter(model_content=quantized_tflite)
+    self.assertEqual(np.float32, interpreter.get_input_details()[0]['dtype'])
+
   @test_util.run_v2_only
   def testNewQuantizer(self):
     """Test the model quantized by the new converter."""