Enable Float16 conversion of model constants through Python API

PiperOrigin-RevId: 256460833
A. Unique TensorFlower 2019-07-03 16:46:17 -07:00 committed by TensorFlower Gardener
parent e43de87265
commit cd7f680dcd
8 changed files with 251 additions and 23 deletions

View File

@@ -194,6 +194,7 @@ def build_toco_convert_protos(input_tensors,
allow_custom_ops=False,
change_concat_input_ranges=False,
post_training_quantize=False,
quantize_to_float16=False,
dump_graphviz_dir=None,
dump_graphviz_video=False,
target_ops=None,
@@ -247,6 +248,8 @@ def build_toco_convert_protos(input_tensors,
of the converted float model. Model size will be reduced and there will be
latency improvements (at the cost of accuracy).
(default False)
quantize_to_float16: Boolean indicating whether to convert float buffers
to float16. (default False)
dump_graphviz_dir: Full filepath of folder to dump the graphs at various
stages of processing GraphViz .dot files. Preferred over
--output_format=GRAPHVIZ_DOT in order to keep the requirements of the
@@ -285,6 +288,7 @@ def build_toco_convert_protos(input_tensors,
toco.reorder_across_fake_quant = reorder_across_fake_quant
toco.allow_custom_ops = allow_custom_ops
toco.post_training_quantize = post_training_quantize
toco.quantize_to_float16 = quantize_to_float16
if default_ranges_stats:
toco.default_ranges_min = default_ranges_stats[0]
toco.default_ranges_max = default_ranges_stats[1]
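As a quick illustration of the new flag at the lowest level, here is a minimal sketch (not part of this commit) of calling build_toco_convert_protos with float16 conversion enabled. It assumes the function lives in tensorflow.lite.python.convert, that in_tensor/out_tensor come from an existing graph, and it deliberately does not assume the exact return value, which varies by TensorFlow version.

from tensorflow.lite.python import convert

# Sketch only: the new quantize_to_float16 flag rides alongside
# post_training_quantize, mirroring how the converter classes below set both.
# `in_tensor` and `out_tensor` are assumed to be tensors from an existing graph.
protos = convert.build_toco_convert_protos(
    input_tensors=[in_tensor],
    output_tensors=[out_tensor],
    post_training_quantize=True,
    quantize_to_float16=True)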

View File

@@ -140,12 +140,19 @@ class TargetSpec(object):
Attributes:
supported_ops: Experimental flag, subject to change. Set of OpsSet options
supported by the device. (default set([OpsSet.TFLITE_BUILTINS]))
supported_types: List of types for constant values on the target device.
Supported values are types exported by lite.constants. Frequently, an
optimization choice is driven by the most compact (i.e. smallest)
type in this list (default [constants.FLOAT])
"""
def __init__(self, supported_ops=None):
def __init__(self, supported_ops=None, supported_types=None):
if supported_ops is None:
supported_ops = set([OpsSet.TFLITE_BUILTINS])
self.supported_ops = supported_ops
if supported_types is None:
supported_types = []
self.supported_types = supported_types
class TFLiteConverterBase(object):
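Usage sketch of the new supported_types attribute, mirroring the tests added later in this commit (using the same `lite` module alias as the tests; `sess`, `in_tensor`, and `out_tensor` are assumed to be an existing TF 1.x session and tensors). Restricting supported types to FLOAT16 together with an optimization selects the float16 path.

# Sketch only: float16 post-training quantization via TargetSpec.supported_types.
converter = lite.TFLiteConverter.from_session(sess, [in_tensor], [out_tensor])
converter.optimizations = [lite.Optimize.DEFAULT]
converter.target_spec.supported_types = [lite.constants.FLOAT16]
tflite_fp16_model = converter.convert()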
@@ -174,30 +181,53 @@ class TFLiteConverterBase(object):
if self.representative_dataset.input_gen is None:
raise ValueError(
"Provide an input generator for representative_dataset")
elif self._int8_target_required():
elif self._is_int8_target_required():
raise ValueError("representative_dataset is required when specifying "
"TFLITE_BUILTINS_INT8 target.")
"TFLITE_BUILTINS_INT8 or INT8 supported types.")
def _int8_target_required(self):
return set([OpsSet.TFLITE_BUILTINS_INT8]) == set(self._target_ops)
def _validate_quantization(self):
if self._is_int8_target_required():
if self.target_spec.supported_types and (self._smallest_supported_type()
!= constants.INT8):
raise ValueError("TFLITE_BUILTINS_INT8 requires smallest supported "
"type to be INT8.")
def _is_post_training_optimize(self):
return (self._int8_target_required() or bool(
def _is_int8_target_required(self):
return (set([OpsSet.TFLITE_BUILTINS_INT8]) == set(self._target_ops) or
self._smallest_supported_type() == constants.INT8)
def _smallest_supported_type(self):
if self.target_spec.supported_types:
return min(self.target_spec.supported_types, key=lambda x: x.size)
else:
return None
def _any_optimization_enabled(self):
return bool(
set(self.optimizations).intersection([
Optimize.OPTIMIZE_FOR_LATENCY, Optimize.OPTIMIZE_FOR_SIZE,
Optimize.DEFAULT
])))
]))
def _is_weight_only_quantize(self):
def _is_post_training_optimize(self):
return self._is_int8_target_required() or self._any_optimization_enabled()
def _is_int8_weight_only_quantize(self):
return (self._is_post_training_optimize() and
(self.representative_dataset is None))
def _is_float16_quantize(self):
return self._any_optimization_enabled() and (
self._smallest_supported_type() == constants.FLOAT16)
def _is_calibration_quantize(self):
return self._is_post_training_optimize() and self.representative_dataset
return (self._is_post_training_optimize() and
self.representative_dataset and
self._smallest_supported_type() != constants.FLOAT16)
def _calibrate_quantize_model(self, result, inference_input_type,
inference_output_type):
allow_float = not self._int8_target_required()
allow_float = not self._is_int8_target_required()
calibrate_quantize = _calibrator.Calibrator(result)
return calibrate_quantize.calibrate_and_quantize(
self.representative_dataset.input_gen, inference_input_type,
@@ -380,16 +410,26 @@ class TFLiteConverterV2(TFLiteConverterBase):
shape[0] = 1
tensor.set_shape(shape)
self._validate_quantization()
self._validate_representative_dataset()
self._debug_info = _get_debug_info(
_build_debug_info_func(self._funcs[0].graph), graph_def)
float16_quantize = self._is_float16_quantize()
converter_kwargs = {
"input_format": constants.TENSORFLOW_GRAPHDEF,
"allow_custom_ops": self.allow_custom_ops,
"post_training_quantize": self._is_weight_only_quantize(),
"target_ops": self.target_spec.supported_ops,
"debug_info": self._debug_info
"input_format":
constants.TENSORFLOW_GRAPHDEF,
"allow_custom_ops":
self.allow_custom_ops,
"post_training_quantize":
self._is_int8_weight_only_quantize() or float16_quantize,
"quantize_to_float16":
float16_quantize,
"target_ops":
self.target_spec.supported_ops,
"debug_info":
self._debug_info
}
# Converts model.
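The same two converter flags are derived for the 2.x API from optimizations plus supported_types. A hedged sketch follows; `concrete_func` is assumed to be obtained from a tf.function elsewhere, and tf.float16 is assumed to match lite.constants.FLOAT16.

# Sketch only: float16 quantization through the V2 converter API.
converter = tf.lite.TFLiteConverter.from_concrete_functions([concrete_func])
converter.optimizations = [tf.lite.Optimize.DEFAULT]
converter.target_spec.supported_types = [tf.float16]
tflite_fp16_model = converter.convert()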
@@ -871,6 +911,7 @@ class TFLiteConverter(TFLiteConverterBase):
else:
quantized_stats = None
self._validate_quantization()
self._validate_representative_dataset()
toco_inference_input_type = self.inference_input_type
@@ -889,7 +930,7 @@ class TFLiteConverter(TFLiteConverterBase):
if inference_output_type is None:
inference_output_type = constants.FLOAT
weight_only_quantize = self._is_weight_only_quantize()
weight_only_quantize = self._is_int8_weight_only_quantize()
if weight_only_quantize:
# Currently, weight only quantization requires float inputs and outputs.
if (inference_input_type != constants.FLOAT or
@@ -898,6 +939,8 @@ class TFLiteConverter(TFLiteConverterBase):
"Provide an inference_input_type and inference_output_type of type "
"tf.float32.")
float16_quantize = self._is_float16_quantize()
if not post_training_optimize and self.inference_output_type is not None:
raise ValueError(
"inference_output_type is currently not supported if optimizations "
@@ -914,7 +957,8 @@ class TFLiteConverter(TFLiteConverterBase):
"reorder_across_fake_quant": self.reorder_across_fake_quant,
"change_concat_input_ranges": self.change_concat_input_ranges,
"allow_custom_ops": self.allow_custom_ops,
"post_training_quantize": weight_only_quantize,
"post_training_quantize": weight_only_quantize or float16_quantize,
"quantize_to_float16": float16_quantize,
"target_ops": self._target_ops,
"dump_graphviz_dir": self.dump_graphviz_dir,
"dump_graphviz_video": self.dump_graphviz_video

View File

@@ -24,6 +24,7 @@ from tensorflow.python.util.all_util import remove_undocumented
from tensorflow.python.util.tf_export import tf_export as _tf_export
FLOAT = dtypes.float32
FLOAT16 = dtypes.float16
INT32 = dtypes.int32
INT64 = dtypes.int64
STRING = dtypes.string
@@ -35,6 +36,7 @@ TFLITE = _toco_flags_pb2.TFLITE
GRAPHVIZ_DOT = _toco_flags_pb2.GRAPHVIZ_DOT
_tf_export(v1=["lite.constants.FLOAT"]).export_constant(__name__, "FLOAT")
_tf_export(v1=["lite.constants.FLOAT16"]).export_constant(__name__, "FLOAT16")
_tf_export(v1=["lite.constants.INT32"]).export_constant(__name__, "INT32")
_tf_export(v1=["lite.constants.INT64"]).export_constant(__name__, "INT64")
_tf_export(v1=["lite.constants.STRING"]).export_constant(__name__, "STRING")
@@ -54,6 +56,7 @@ EXPERIMENTAL_USE_TOCO_API_DIRECTLY = False
_allowed_symbols = [
"FLOAT",
"FLOAT16",
"INT32",
"INT64",
"STRING",

View File

@@ -102,7 +102,7 @@ class FromConstructor(TestModels):
@test_util.run_v1_only('Incompatible with 2.0.')
class FromSessionTest(TestModels):
class FromSessionTest(TestModels, parameterized.TestCase):
def testFloat(self):
in_tensor = array_ops.placeholder(
@@ -636,6 +636,137 @@ class FromSessionTest(TestModels):
# Ensure that the quantized weights tflite model is smaller.
self.assertLess(len(quantized_tflite), len(float_tflite))
@parameterized.named_parameters(
# Quantize to Float16 even if rep data provided.
('UseRepresentativeData', True, False, True, False, False),
# Quantize to Float16 if no rep data provided.
('NoRepresentativeData', False, False, True, False, False),
# Post training quantization if both rep data and int8 included.
('UseRepresentativeDataIncludeInt8', True, True, False, False, True),
# Error if no rep data and int8 included.
('NoRepresentativeDataIncludeInt8', False, True, False, True, False))
def testQuantizeFloat16(self, use_rep_data, include_int8,
is_float16_quantized, is_error,
is_post_training_quantized):
inp, output, calibration_gen = self._getCalibrationQuantizeModel()
sess = session.Session()
# Convert float model.
float_converter = lite.TFLiteConverter.from_session(sess, [inp], [output])
float_tflite = float_converter.convert()
self.assertTrue(float_tflite)
interpreter = Interpreter(model_content=float_tflite)
interpreter.allocate_tensors()
self.assertEqual(interpreter.get_tensor_details()[0]['name'], 'Conv2D_bias')
self.assertEqual(interpreter.get_tensor_details()[0]['dtype'],
lite.constants.FLOAT)
# Convert model to quantized version
quantized_converter = lite.TFLiteConverter.from_session(
sess, [inp], [output])
quantized_converter.optimizations = [lite.Optimize.DEFAULT]
quantized_converter.target_spec.supported_types = [lite.constants.FLOAT16]
if include_int8:
quantized_converter.target_spec.supported_types.append(
lite.constants.INT8)
if use_rep_data:
quantized_converter.representative_dataset = calibration_gen
if is_error:
with self.assertRaises(ValueError) as error:
quantized_converter.convert()
self.assertEqual(
'representative_dataset is required when specifying '
'TFLITE_BUILTINS_INT8 or INT8 supported types.', str(error.exception))
else:
quantized_tflite = quantized_converter.convert()
self.assertTrue(quantized_tflite)
interpreter = Interpreter(model_content=quantized_tflite)
interpreter.allocate_tensors()
self.assertEqual(interpreter.get_tensor_details()[0]['name'],
'Conv2D_bias')
if is_float16_quantized:
# Verify that bias constant is float16 type.
self.assertEqual(interpreter.get_tensor_details()[0]['dtype'],
lite.constants.FLOAT16)
elif is_post_training_quantized:
# Verify that the bias constant is int32 type.
self.assertEqual(interpreter.get_tensor_details()[0]['dtype'],
lite.constants.INT32)
else:
raise ValueError('Invalid test options.')
def testInvalidQuantizeFloat16(self):
inp, output, _ = self._getCalibrationQuantizeModel()
sess = session.Session()
# Specify float16 quantization
quantized_converter = lite.TFLiteConverter.from_session(
sess, [inp], [output])
quantized_converter.optimizations = [lite.Optimize.DEFAULT]
quantized_converter.target_spec.supported_types = [lite.constants.FLOAT16]
# Specify only int8 builtin ops
quantized_converter.target_spec.supported_ops = [
lite.OpsSet.TFLITE_BUILTINS_INT8
]
with self.assertRaises(ValueError) as error:
quantized_converter.convert()
self.assertEqual(
'TFLITE_BUILTINS_INT8 requires smallest supported type to be INT8.',
str(error.exception))
def testInvalidPostTrainingQuantize(self):
np.random.seed(0)
# We need the tensor to have more than 1024 elements for quantize_weights
# to kick in. Thus, the [33, 33] shape.
in_tensor_1 = array_ops.placeholder(
shape=[33, 33], dtype=dtypes.float32, name='inputA')
in_tensor_2 = constant_op.constant(
np.random.uniform(low=-10., high=10., size=(33, 33)),
shape=[33, 33],
dtype=dtypes.float32,
name='inputB')
out_tensor = math_ops.matmul(in_tensor_1, in_tensor_2, name='output')
sess = session.Session()
# Attempt to convert to quantized weights model.
quantized_converter = lite.TFLiteConverter.from_session(
sess, [in_tensor_1], [out_tensor])
quantized_converter.optimizations = [lite.Optimize.DEFAULT]
# Restricting to int8 type only
quantized_converter.target_spec.supported_types = [lite.constants.INT8]
# A representative dataset is required for full fixed point quantization.
with self.assertRaises(ValueError) as error:
quantized_converter.convert()
self.assertEqual(
'representative_dataset is required when specifying '
'TFLITE_BUILTINS_INT8 or INT8 supported types.', str(error.exception))
def testPostTrainingCalibrateAndQuantizeFloatNotAllowed(self):
inp, output, calibration_gen = self._getCalibrationQuantizeModel()
sess = session.Session()
# Convert float model.
float_converter = lite.TFLiteConverter.from_session(sess, [inp], [output])
float_tflite = float_converter.convert()
self.assertTrue(float_tflite)
# Convert quantized model.
quantized_converter = lite.TFLiteConverter.from_session(
sess, [inp], [output])
quantized_converter.optimizations = [lite.Optimize.DEFAULT]
quantized_converter.representative_dataset = calibration_gen
quantized_converter.target_spec.supported_types = [lite.constants.INT8]
quantized_tflite = quantized_converter.convert()
self.assertTrue(quantized_tflite)
# Ensure that restricting supported types to int8 forces
# all ops/tensors in the converter to fixed point.
self.assertTrue(quantized_converter._is_int8_target_required())
# Ensure that the quantized weights tflite model is smaller.
self.assertLess(len(quantized_tflite), len(float_tflite))
def testPostTrainingCalibrateAndQuantizeInt8Inputs(self):
inp, output, calibration_gen = self._getCalibrationQuantizeModel()
sess = session.Session()

View File

@@ -27,6 +27,7 @@ from google.protobuf.message import DecodeError
from tensorflow.core.framework import graph_pb2 as _graph_pb2
from tensorflow.lite.python import convert_saved_model as _convert_saved_model
from tensorflow.lite.python import lite as _lite
from tensorflow.lite.python import lite_constants as constants
from tensorflow.lite.python import util as _util
from tensorflow.python import keras as _keras
from tensorflow.python.client import session as _session
@@ -81,7 +82,7 @@ def _convert(converter, **kwargs):
Args:
converter: TFLiteConverter object.
**kwargs: Additional arguments to be passed into the converter. Supported
flags are {"target_ops", "post_training_quantize"}.
flags are {"target_ops", "post_training_quantize", "quantize_to_float16"}.
Returns:
The converted TFLite model in serialized format.
@@ -92,7 +93,9 @@ def _convert(converter, **kwargs):
if "target_ops" in kwargs:
converter.target_spec.supported_ops = kwargs["target_ops"]
if "post_training_quantize" in kwargs:
converter.post_training_quantize = kwargs["post_training_quantize"]
converter.optimizations = [_lite.Optimize.DEFAULT]
if kwargs.get("quantize_to_float16", False):
converter.target_spec.supported_types = [constants.FLOAT16]
return converter.convert()
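For example, a coverage test could then call the helper roughly as follows (the converter itself is assumed to be built elsewhere):

# Hypothetical call to the _convert helper above; enables weight quantization
# and requests float16 buffers, matching the flags documented in its docstring.
tflite_fp16 = _convert(
    converter, post_training_quantize=True, quantize_to_float16=True)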
@@ -362,7 +365,10 @@ def test_frozen_graph_quant(filename,
for float_tensor in float_tensors)
has_quant_tensor = num_tensors_float != num_tensors_same_dtypes
# For the "flex" case, post_training_quantize should not alter the graph,
# unless we are quantizing to float16.
if ("target_ops" in kwargs and
not kwargs.get("quantize_to_float16", False) and
set(kwargs["target_ops"]) == set([_lite.OpsSet.SELECT_TF_OPS])):
if has_quant_tensor:
raise ValueError("--post_training_quantize flag unexpectedly altered the "
@@ -465,6 +471,42 @@ def test_saved_model_v2(directory,
compare_models_v2(tflite_model, concrete_func, input_data=input_data)
def test_saved_model_v2_quant_float16(directory, **kwargs):
"""Validates the TensorFlow SavedModel converts to a TFLite model."""
converter = _lite.TFLiteConverterV2.from_saved_model(directory)
tflite_model_float = _convert(converter, version=2, **kwargs)
interpreter_float = _lite.Interpreter(model_content=tflite_model_float)
interpreter_float.allocate_tensors()
float_tensors = interpreter_float.get_tensor_details()
tflite_model_quant = _convert(
converter,
version=2,
post_training_quantize=True,
quantize_to_float16=True,
**kwargs)
interpreter_quant = _lite.Interpreter(model_content=tflite_model_quant)
interpreter_quant.allocate_tensors()
quant_tensors = interpreter_quant.get_tensor_details()
quant_tensors_map = {
tensor_detail["name"]: tensor_detail for tensor_detail in quant_tensors
}
# Check if weights are of different types in the float and quantized models.
num_tensors_float = len(float_tensors)
num_tensors_same_dtypes = sum(
float_tensor["dtype"] == quant_tensors_map[float_tensor["name"]]["dtype"]
for float_tensor in float_tensors)
has_quant_tensor = num_tensors_float != num_tensors_same_dtypes
if not has_quant_tensor:
raise ValueError("--post_training_quantize flag was unable to quantize the "
"graph as expected.")
def test_keras_model(filename,
input_arrays=None,
input_shapes=None,

View File

@@ -4,6 +4,6 @@ tf_class {
is_instance: "<type \'object\'>"
member_method {
name: "__init__"
argspec: "args=[\'self\', \'supported_ops\'], varargs=None, keywords=None, defaults=[\'None\'], "
argspec: "args=[\'self\', \'supported_ops\', \'supported_types\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
}
}

View File

@@ -4,6 +4,10 @@ tf_module {
name: "FLOAT"
mtype: "<class \'tensorflow.python.framework.dtypes.DType\'>"
}
member {
name: "FLOAT16"
mtype: "<class \'tensorflow.python.framework.dtypes.DType\'>"
}
member {
name: "GRAPHVIZ_DOT"
mtype: "<type \'int\'>"

View File

@@ -4,6 +4,6 @@ tf_class {
is_instance: "<type \'object\'>"
member_method {
name: "__init__"
argspec: "args=[\'self\', \'supported_ops\'], varargs=None, keywords=None, defaults=[\'None\'], "
argspec: "args=[\'self\', \'supported_ops\', \'supported_types\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
}
}