Enable Float16 conversion of model constants through Python API
PiperOrigin-RevId: 256460833
parent e43de87265
commit cd7f680dcd
@@ -194,6 +194,7 @@ def build_toco_convert_protos(input_tensors,
                               allow_custom_ops=False,
                               change_concat_input_ranges=False,
                               post_training_quantize=False,
+                              quantize_to_float16=False,
                               dump_graphviz_dir=None,
                               dump_graphviz_video=False,
                               target_ops=None,
@@ -247,6 +248,8 @@ def build_toco_convert_protos(input_tensors,
       of the converted float model. Model size will be reduced and there will be
       latency improvements (at the cost of accuracy).
       (default False)
+    quantize_to_float16: Boolean indicating whether to convert float buffers
+      to float16. (default False)
     dump_graphviz_dir: Full filepath of folder to dump the graphs at various
       stages of processing GraphViz .dot files. Preferred over
       --output_format=GRAPHVIZ_DOT in order to keep the requirements of the
@@ -285,6 +288,7 @@ def build_toco_convert_protos(input_tensors,
   toco.reorder_across_fake_quant = reorder_across_fake_quant
   toco.allow_custom_ops = allow_custom_ops
   toco.post_training_quantize = post_training_quantize
+  toco.quantize_to_float16 = quantize_to_float16
   if default_ranges_stats:
     toco.default_ranges_min = default_ranges_stats[0]
     toco.default_ranges_max = default_ranges_stats[1]
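The hunks above thread the new quantize_to_float16 argument through build_toco_convert_protos and onto the TOCO flags, alongside post_training_quantize. A minimal sketch of exercising that path directly is below; the tiny placeholder graph, the tensor names, and treating the returned protos as opaque are illustrative assumptions, not part of this change.

import tensorflow as tf
from tensorflow.lite.python import convert

# Build a small TF1-style graph so there is something to describe in the protos.
with tf.compat.v1.Graph().as_default():
  inp = tf.compat.v1.placeholder(tf.float32, shape=[1, 4], name="input")
  out = tf.identity(inp, name="output")

  # post_training_quantize keeps weight quantization on; quantize_to_float16
  # (the flag added here) requests float16 rather than int8 constants.
  protos = convert.build_toco_convert_protos(
      input_tensors=[inp],
      output_tensors=[out],
      post_training_quantize=True,
      quantize_to_float16=True)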
@@ -140,12 +140,19 @@ class TargetSpec(object):
   Attributes:
     supported_ops: Experimental flag, subject to change. Set of OpsSet options
       supported by the device. (default set([OpsSet.TFLITE_BUILTINS]))
+    supported_types: List of types for constant values on the target device.
+      Supported values are types exported by lite.constants. Frequently, an
+      optimization choice is driven by the most compact (i.e. smallest)
+      type in this list (default [constants.FLOAT])
   """
 
-  def __init__(self, supported_ops=None):
+  def __init__(self, supported_ops=None, supported_types=None):
     if supported_ops is None:
       supported_ops = set([OpsSet.TFLITE_BUILTINS])
     self.supported_ops = supported_ops
+    if supported_types is None:
+      supported_types = []
+    self.supported_types = supported_types
 
 
 class TFLiteConverterBase(object):
@@ -174,30 +181,53 @@ class TFLiteConverterBase(object):
       if self.representative_dataset.input_gen is None:
         raise ValueError(
             "Provide an input generator for representative_dataset")
-    elif self._int8_target_required():
+    elif self._is_int8_target_required():
       raise ValueError("representative_dataset is required when specifying "
-                       "TFLITE_BUILTINS_INT8 target.")
+                       "TFLITE_BUILTINS_INT8 or INT8 supported types.")
 
-  def _int8_target_required(self):
-    return set([OpsSet.TFLITE_BUILTINS_INT8]) == set(self._target_ops)
+  def _validate_quantization(self):
+    if self._is_int8_target_required():
+      if self.target_spec.supported_types and (self._smallest_supported_type()
+                                               != constants.INT8):
+        raise ValueError("TFLITE_BUILTINS_INT8 requires smallest supported "
+                         "type to be INT8.")
 
-  def _is_post_training_optimize(self):
-    return (self._int8_target_required() or bool(
+  def _is_int8_target_required(self):
+    return (set([OpsSet.TFLITE_BUILTINS_INT8]) == set(self._target_ops) or
+            self._smallest_supported_type() == constants.INT8)
+
+  def _smallest_supported_type(self):
+    if self.target_spec.supported_types:
+      return min(self.target_spec.supported_types, key=lambda x: x.size)
+    else:
+      return None
+
+  def _any_optimization_enabled(self):
+    return bool(
         set(self.optimizations).intersection([
             Optimize.OPTIMIZE_FOR_LATENCY, Optimize.OPTIMIZE_FOR_SIZE,
             Optimize.DEFAULT
-        ])))
+        ]))
 
-  def _is_weight_only_quantize(self):
+  def _is_post_training_optimize(self):
+    return self._is_int8_target_required() or self._any_optimization_enabled()
+
+  def _is_int8_weight_only_quantize(self):
     return (self._is_post_training_optimize() and
             (self.representative_dataset is None))
 
+  def _is_float16_quantize(self):
+    return self._any_optimization_enabled() and (
+        self._smallest_supported_type() == constants.FLOAT16)
+
   def _is_calibration_quantize(self):
-    return self._is_post_training_optimize() and self.representative_dataset
+    return (self._is_post_training_optimize() and
+            self.representative_dataset and
+            self._smallest_supported_type() != constants.FLOAT16)
 
   def _calibrate_quantize_model(self, result, inference_input_type,
                                 inference_output_type):
-    allow_float = not self._int8_target_required()
+    allow_float = not self._is_int8_target_required()
     calibrate_quantize = _calibrator.Calibrator(result)
     return calibrate_quantize.calibrate_and_quantize(
         self.representative_dataset.input_gen, inference_input_type,
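Taken together, the helpers above pick the float16 path when an optimization is enabled and FLOAT16 is the smallest entry in TargetSpec.supported_types; no representative dataset is needed for it, while adding INT8 to supported_types instead forces the full integer path. A usage sketch in the spirit of the new tests follows; the toy matmul graph (33x33, so the constant is large enough to be quantized) is an assumption used only to give the converter a weight tensor.

import numpy as np
import tensorflow as tf
from tensorflow.lite.python import lite

with tf.compat.v1.Graph().as_default(), tf.compat.v1.Session() as sess:
  inp = tf.compat.v1.placeholder(tf.float32, shape=[1, 33], name="input")
  weights = tf.constant(np.random.uniform(-1., 1., size=(33, 33)), tf.float32)
  out = tf.matmul(inp, weights, name="output")

  converter = lite.TFLiteConverter.from_session(sess, [inp], [out])
  converter.optimizations = [lite.Optimize.DEFAULT]
  # FLOAT16 as the smallest supported type selects the float16 path;
  # no representative_dataset is required for it.
  converter.target_spec.supported_types = [lite.constants.FLOAT16]
  tflite_model = converter.convert()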
@@ -380,16 +410,26 @@ class TFLiteConverterV2(TFLiteConverterBase):
         shape[0] = 1
         tensor.set_shape(shape)
 
+    self._validate_quantization()
     self._validate_representative_dataset()
     self._debug_info = _get_debug_info(
         _build_debug_info_func(self._funcs[0].graph), graph_def)
 
+    float16_quantize = self._is_float16_quantize()
+
     converter_kwargs = {
-        "input_format": constants.TENSORFLOW_GRAPHDEF,
-        "allow_custom_ops": self.allow_custom_ops,
-        "post_training_quantize": self._is_weight_only_quantize(),
-        "target_ops": self.target_spec.supported_ops,
-        "debug_info": self._debug_info
+        "input_format":
+            constants.TENSORFLOW_GRAPHDEF,
+        "allow_custom_ops":
+            self.allow_custom_ops,
+        "post_training_quantize":
+            self._is_int8_weight_only_quantize() or float16_quantize,
+        "quantize_to_float16":
+            float16_quantize,
+        "target_ops":
+            self.target_spec.supported_ops,
+        "debug_info":
+            self._debug_info
     }
 
     # Converts model.
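The 2.0 converter follows the same selection: convert() now derives both post_training_quantize and quantize_to_float16 from _is_float16_quantize(). A hedged sketch of the SavedModel route is below; the directory path is a placeholder and any TF 2.x SavedModel would do.

from tensorflow.lite.python import lite

converter = lite.TFLiteConverterV2.from_saved_model("/tmp/my_saved_model")
converter.optimizations = [lite.Optimize.DEFAULT]
# With FLOAT16 as the only supported type, constants are emitted as float16.
converter.target_spec.supported_types = [lite.constants.FLOAT16]
tflite_model = converter.convert()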
@@ -871,6 +911,7 @@ class TFLiteConverter(TFLiteConverterBase):
     else:
       quantized_stats = None
 
+    self._validate_quantization()
     self._validate_representative_dataset()
 
     toco_inference_input_type = self.inference_input_type
@@ -889,7 +930,7 @@ class TFLiteConverter(TFLiteConverterBase):
     if inference_output_type is None:
       inference_output_type = constants.FLOAT
 
-    weight_only_quantize = self._is_weight_only_quantize()
+    weight_only_quantize = self._is_int8_weight_only_quantize()
     if weight_only_quantize:
       # Currently, weight only quantization requires float inputs and outputs.
       if (inference_input_type != constants.FLOAT or
@@ -898,6 +939,8 @@ class TFLiteConverter(TFLiteConverterBase):
             "Provide an inference_input_type and inference_output_type of type "
            "tf.float32.")
 
+    float16_quantize = self._is_float16_quantize()
+
     if not post_training_optimize and self.inference_output_type is not None:
       raise ValueError(
           "inference_output_type is currently not supported if optimizations "
@@ -914,7 +957,8 @@ class TFLiteConverter(TFLiteConverterBase):
         "reorder_across_fake_quant": self.reorder_across_fake_quant,
         "change_concat_input_ranges": self.change_concat_input_ranges,
         "allow_custom_ops": self.allow_custom_ops,
-        "post_training_quantize": weight_only_quantize,
+        "post_training_quantize": weight_only_quantize or float16_quantize,
+        "quantize_to_float16": float16_quantize,
         "target_ops": self._target_ops,
         "dump_graphviz_dir": self.dump_graphviz_dir,
         "dump_graphviz_video": self.dump_graphviz_video
@@ -24,6 +24,7 @@ from tensorflow.python.util.all_util import remove_undocumented
 from tensorflow.python.util.tf_export import tf_export as _tf_export
 
 FLOAT = dtypes.float32
+FLOAT16 = dtypes.float16
 INT32 = dtypes.int32
 INT64 = dtypes.int64
 STRING = dtypes.string
@@ -35,6 +36,7 @@ TFLITE = _toco_flags_pb2.TFLITE
 GRAPHVIZ_DOT = _toco_flags_pb2.GRAPHVIZ_DOT
 
 _tf_export(v1=["lite.constants.FLOAT"]).export_constant(__name__, "FLOAT")
+_tf_export(v1=["lite.constants.FLOAT16"]).export_constant(__name__, "FLOAT16")
 _tf_export(v1=["lite.constants.INT32"]).export_constant(__name__, "INT32")
 _tf_export(v1=["lite.constants.INT64"]).export_constant(__name__, "INT64")
 _tf_export(v1=["lite.constants.STRING"]).export_constant(__name__, "STRING")
@@ -54,6 +56,7 @@ EXPERIMENTAL_USE_TOCO_API_DIRECTLY = False
 
 _allowed_symbols = [
     "FLOAT",
+    "FLOAT16",
     "INT32",
     "INT64",
     "STRING",
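FLOAT16 is simply dtypes.float16 re-exported next to the other lite.constants symbols (and added to the v1 tf.lite.constants namespace). A quick sanity check, assuming nothing beyond a standard TensorFlow install:

import tensorflow as tf
from tensorflow.lite.python import lite_constants

# The new constant is the ordinary half-precision dtype.
assert lite_constants.FLOAT16 == tf.float16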
@@ -102,7 +102,7 @@ class FromConstructor(TestModels):
 
 
 @test_util.run_v1_only('Incompatible with 2.0.')
-class FromSessionTest(TestModels):
+class FromSessionTest(TestModels, parameterized.TestCase):
 
   def testFloat(self):
     in_tensor = array_ops.placeholder(
@@ -636,6 +636,137 @@ class FromSessionTest(TestModels):
     # Ensure that the quantized weights tflite model is smaller.
     self.assertLess(len(quantized_tflite), len(float_tflite))
 
+  @parameterized.named_parameters(
+      # Quantize to Float16 even if rep data provided.
+      ('UseRepresentativeData', True, False, True, False, False),
+      # Quantize to Float16 if no rep data provided.
+      ('NoRepresentativeData', False, False, True, False, False),
+      # Post training quantization if both rep data and int8 included.
+      ('UseRepresentativeDataIncludeInt8', True, True, False, False, True),
+      # Error if no rep data and int8 included.
+      ('NoRepresentativeDataIncludeInt8', False, True, False, True, False))
+  def testQuantizeFloat16(self, use_rep_data, include_int8,
+                          is_float16_quantized, is_error,
+                          is_post_training_quantized):
+    inp, output, calibration_gen = self._getCalibrationQuantizeModel()
+    sess = session.Session()
+
+    # Convert float model.
+    float_converter = lite.TFLiteConverter.from_session(sess, [inp], [output])
+    float_tflite = float_converter.convert()
+    self.assertTrue(float_tflite)
+    interpreter = Interpreter(model_content=float_tflite)
+    interpreter.allocate_tensors()
+    self.assertEqual(interpreter.get_tensor_details()[0]['name'], 'Conv2D_bias')
+    self.assertEqual(interpreter.get_tensor_details()[0]['dtype'],
+                     lite.constants.FLOAT)
+    # Convert model to quantized version
+    quantized_converter = lite.TFLiteConverter.from_session(
+        sess, [inp], [output])
+    quantized_converter.optimizations = [lite.Optimize.DEFAULT]
+    quantized_converter.target_spec.supported_types = [lite.constants.FLOAT16]
+    if include_int8:
+      quantized_converter.target_spec.supported_types.append(
+          lite.constants.INT8)
+    if use_rep_data:
+      quantized_converter.representative_dataset = calibration_gen
+
+    if is_error:
+      with self.assertRaises(ValueError) as error:
+        quantized_converter.convert()
+      self.assertEqual(
+          'representative_dataset is required when specifying '
+          'TFLITE_BUILTINS_INT8 or INT8 supported types.', str(error.exception))
+
+    else:
+      quantized_tflite = quantized_converter.convert()
+      self.assertTrue(quantized_tflite)
+      interpreter = Interpreter(model_content=quantized_tflite)
+      interpreter.allocate_tensors()
+      self.assertEqual(interpreter.get_tensor_details()[0]['name'],
+                       'Conv2D_bias')
+
+      if is_float16_quantized:
+        # Verify that bias constant is float16 type.
+        self.assertEqual(interpreter.get_tensor_details()[0]['dtype'],
+                         lite.constants.FLOAT16)
+      elif is_post_training_quantized:
+        # Verify that bias constants is int32 type.
+        self.assertEqual(interpreter.get_tensor_details()[0]['dtype'],
+                         lite.constants.INT32)
+      else:
+        raise ValueError('Invalid test options.')
+
+  def testInvalidQuantizeFloat16(self):
+    inp, output, _ = self._getCalibrationQuantizeModel()
+    sess = session.Session()
+
+    # Specify float16 quantization
+    quantized_converter = lite.TFLiteConverter.from_session(
+        sess, [inp], [output])
+    quantized_converter.optimizations = [lite.Optimize.DEFAULT]
+    quantized_converter.target_spec.supported_types = [lite.constants.FLOAT16]
+    # Specifiy only int8 builtin ops
+    quantized_converter.target_spec.supported_ops = [
+        lite.OpsSet.TFLITE_BUILTINS_INT8
+    ]
+    with self.assertRaises(ValueError) as error:
+      quantized_converter.convert()
+    self.assertEqual(
+        'TFLITE_BUILTINS_INT8 requires smallest supported type to be INT8.',
+        str(error.exception))
+
+  def testInvalidPostTrainingQuantize(self):
+    np.random.seed(0)
+    # We need the tensor to have more than 1024 elements for quantize_weights
+    # to kick in. Thus, the [33, 33] shape.
+    in_tensor_1 = array_ops.placeholder(
+        shape=[33, 33], dtype=dtypes.float32, name='inputA')
+    in_tensor_2 = constant_op.constant(
+        np.random.uniform(low=-10., high=10., size=(33, 33)),
+        shape=[33, 33],
+        dtype=dtypes.float32,
+        name='inputB')
+    out_tensor = math_ops.matmul(in_tensor_1, in_tensor_2, name='output')
+    sess = session.Session()
+
+    # Attempt to convert to quantized weights model.
+    quantized_converter = lite.TFLiteConverter.from_session(
+        sess, [in_tensor_1], [out_tensor])
+    quantized_converter.optimizations = [lite.Optimize.DEFAULT]
+    # Restricting to int8 type only
+    quantized_converter.target_spec.supported_types = [lite.constants.INT8]
+    # A representative dataset is required for full fixed point quantization.
+    with self.assertRaises(ValueError) as error:
+      quantized_converter.convert()
+    self.assertEqual(
+        'representative_dataset is required when specifying '
+        'TFLITE_BUILTINS_INT8 or INT8 supported types.', str(error.exception))
+
+  def testPostTrainingCalibrateAndQuantizeFloatNotAllowed(self):
+    inp, output, calibration_gen = self._getCalibrationQuantizeModel()
+    sess = session.Session()
+
+    # Convert float model.
+    float_converter = lite.TFLiteConverter.from_session(sess, [inp], [output])
+    float_tflite = float_converter.convert()
+    self.assertTrue(float_tflite)
+
+    # Convert quantized model.
+    quantized_converter = lite.TFLiteConverter.from_session(
+        sess, [inp], [output])
+    quantized_converter.optimizations = [lite.Optimize.DEFAULT]
+    quantized_converter.representative_dataset = calibration_gen
+    quantized_converter.target_spec.supported_types = [lite.constants.INT8]
+    quantized_tflite = quantized_converter.convert()
+    self.assertTrue(quantized_tflite)
+    # Ensure that restricting supported types to int8 forces
+    # all fixed point ops/tensors in converter.
+    self.assertTrue(quantized_converter._is_int8_target_required())
+
+    # Ensure that the quantized weights tflite model is smaller.
+    self.assertLess(len(quantized_tflite), len(float_tflite))
+
   def testPostTrainingCalibrateAndQuantizeInt8Inputs(self):
     inp, output, calibration_gen = self._getCalibrationQuantizeModel()
     sess = session.Session()
@@ -27,6 +27,7 @@ from google.protobuf.message import DecodeError
 from tensorflow.core.framework import graph_pb2 as _graph_pb2
 from tensorflow.lite.python import convert_saved_model as _convert_saved_model
 from tensorflow.lite.python import lite as _lite
+from tensorflow.lite.python import lite_constants as constants
 from tensorflow.lite.python import util as _util
 from tensorflow.python import keras as _keras
 from tensorflow.python.client import session as _session
@@ -81,7 +82,7 @@ def _convert(converter, **kwargs):
   Args:
     converter: TFLiteConverter object.
     **kwargs: Additional arguments to be passed into the converter. Supported
-      flags are {"target_ops", "post_training_quantize"}.
+      flags are {"target_ops", "post_training_quantize", "quantize_to_float16"}.
 
   Returns:
     The converted TFLite model in serialized format.
@@ -92,7 +93,9 @@ def _convert(converter, **kwargs):
   if "target_ops" in kwargs:
     converter.target_spec.supported_ops = kwargs["target_ops"]
   if "post_training_quantize" in kwargs:
-    converter.post_training_quantize = kwargs["post_training_quantize"]
+    converter.optimizations = [_lite.Optimize.DEFAULT]
+  if kwargs.get("quantize_to_float16", False):
+    converter.target_spec.supported_types = [constants.FLOAT16]
   return converter.convert()
 
 
@@ -362,7 +365,10 @@ def test_frozen_graph_quant(filename,
       for float_tensor in float_tensors)
   has_quant_tensor = num_tensors_float != num_tensors_same_dtypes
 
+  # For the "flex" case, post_training_quantize should not alter the graph,
+  # unless we are quantizing to float16.
   if ("target_ops" in kwargs and
+      not kwargs.get("quantize_to_float16", False) and
       set(kwargs["target_ops"]) == set([_lite.OpsSet.SELECT_TF_OPS])):
     if has_quant_tensor:
       raise ValueError("--post_training_quantize flag unexpectedly altered the "
@@ -465,6 +471,42 @@ def test_saved_model_v2(directory,
   compare_models_v2(tflite_model, concrete_func, input_data=input_data)
 
 
+def test_saved_model_v2_quant_float16(directory, **kwargs):
+  """Validates the TensorFlow SavedModel converts to a TFLite model."""
+
+  converter = _lite.TFLiteConverterV2.from_saved_model(directory)
+  tflite_model_float = _convert(converter, version=2, **kwargs)
+
+  interpreter_float = _lite.Interpreter(model_content=tflite_model_float)
+  interpreter_float.allocate_tensors()
+  float_tensors = interpreter_float.get_tensor_details()
+
+  tflite_model_quant = _convert(
+      converter,
+      version=2,
+      post_training_quantize=True,
+      quantize_to_float16=True,
+      **kwargs)
+
+  interpreter_quant = _lite.Interpreter(model_content=tflite_model_quant)
+  interpreter_quant.allocate_tensors()
+  quant_tensors = interpreter_quant.get_tensor_details()
+  quant_tensors_map = {
+      tensor_detail["name"]: tensor_detail for tensor_detail in quant_tensors
+  }
+
+  # Check if weights are of different types in the float and quantized models.
+  num_tensors_float = len(float_tensors)
+  num_tensors_same_dtypes = sum(
+      float_tensor["dtype"] == quant_tensors_map[float_tensor["name"]]["dtype"]
+      for float_tensor in float_tensors)
+  has_quant_tensor = num_tensors_float != num_tensors_same_dtypes
+
+  if not has_quant_tensor:
+    raise ValueError("--post_training_quantize flag was unable to quantize the "
+                     "graph as expected.")
+
+
 def test_keras_model(filename,
                      input_arrays=None,
                      input_shapes=None,
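For the coverage helpers above, quantize_to_float16 rides through _convert() into TargetSpec.supported_types. A sketch of driving the new SavedModel check is below; the import path and the SavedModel directory are assumptions about where this module and a test model live.

from tensorflow.lite.testing.model_coverage import model_coverage_lib as model_coverage

# Converts the SavedModel twice (float and float16-quantized) and raises if no
# tensor changed dtype after quantization.
model_coverage.test_saved_model_v2_quant_float16("/tmp/my_saved_model")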
@@ -4,6 +4,6 @@ tf_class {
   is_instance: "<type \'object\'>"
   member_method {
     name: "__init__"
-    argspec: "args=[\'self\', \'supported_ops\'], varargs=None, keywords=None, defaults=[\'None\'], "
+    argspec: "args=[\'self\', \'supported_ops\', \'supported_types\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
   }
 }
@@ -4,6 +4,10 @@ tf_module {
     name: "FLOAT"
     mtype: "<class \'tensorflow.python.framework.dtypes.DType\'>"
   }
+  member {
+    name: "FLOAT16"
+    mtype: "<class \'tensorflow.python.framework.dtypes.DType\'>"
+  }
   member {
     name: "GRAPHVIZ_DOT"
     mtype: "<type \'int\'>"
@@ -4,6 +4,6 @@ tf_class {
   is_instance: "<type \'object\'>"
   member_method {
     name: "__init__"
-    argspec: "args=[\'self\', \'supported_ops\'], varargs=None, keywords=None, defaults=[\'None\'], "
+    argspec: "args=[\'self\', \'supported_ops\', \'supported_types\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
   }
 }