Requires quantized_input_stats when it is not post-training quantization

If either inference_type or inference_input_type is set to int8/uint8 and
post-training quantization is not enabled, quantized_input_stats is required.

PiperOrigin-RevId: 291441023
Change-Id: Iaee998f10dc90c66ddafc392de250d0f9234388c
Author: Feng Liu, 2020-01-24 14:15:19 -08:00; committed by TensorFlower Gardener
parent 1554ffdc6a
commit d653517480
2 changed files with 49 additions and 36 deletions
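Concretely, the change means that a caller targeting quantized inference must now supply per-input (mean, std_dev) stats unless post-training quantization is enabled. A minimal sketch of an affected call, assuming the TF 1.x `tf.compat.v1.lite.TFLiteConverter` API of this era (the model path and array names are hypothetical):

    import tensorflow as tf

    # Build a converter for a frozen graph (path and arrays are illustrative).
    converter = tf.compat.v1.lite.TFLiteConverter.from_frozen_graph(
        "model.pb", input_arrays=["input"], output_arrays=["output"])

    # Quantized inference without post-training quantization:
    # quantized_input_stats is now required, otherwise conversion raises
    # "std_dev and mean must be defined when inference_type or
    #  inference_input_type is QUANTIZED_UINT8 or INT8."
    converter.inference_type = tf.uint8
    converter.quantized_input_stats = {"input": (128.0, 127.0)}  # (mean, std_dev)
    tflite_model = converter.convert()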


@@ -39,6 +39,16 @@ from tensorflow.python.platform import resource_loader as _resource_loader
 from tensorflow.python.util import deprecation
 from tensorflow.python.util.tf_export import tf_export as _tf_export
 
+_quantized_inference_types = [_types_pb2.QUANTIZED_UINT8, _types_pb2.INT8]
+
+
+# If the `inference_type` or the `inference_input_type` is the quantized type
+# and it is not post training quantization, the input quantization stats is
+# required.
+def _requires_input_stats(toco_flags):
+  return ((toco_flags.inference_type in _quantized_inference_types or
+           toco_flags.inference_input_type in _quantized_inference_types) and
+          not toco_flags.post_training_quantize)
 
 # Find the toco_from_protos binary using the resource loader if using from
 # bazel, otherwise we are in a pip where console_scripts already has
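The new helper is a pure predicate over the TOCO flags. A standalone sketch of its truth table, mirroring the logic above with a stand-in flags object (the real `toco_flags` is a proto; names here are illustrative):

    from collections import namedtuple

    # Stand-in for the TocoFlags proto, for illustration only.
    Flags = namedtuple(
        "Flags", ["inference_type", "inference_input_type", "post_training_quantize"])

    # Stands in for _quantized_inference_types.
    QUANTIZED = ("QUANTIZED_UINT8", "INT8")

    def requires_input_stats(f):
      return ((f.inference_type in QUANTIZED or
               f.inference_input_type in QUANTIZED) and
              not f.post_training_quantize)

    assert requires_input_stats(Flags("QUANTIZED_UINT8", "FLOAT", False))
    assert requires_input_stats(Flags("FLOAT", "INT8", False))    # new: input type alone triggers it
    assert not requires_input_stats(Flags("INT8", "INT8", True))  # post-training quantization is exempt
    assert not requires_input_stats(Flags("FLOAT", "FLOAT", False))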
@@ -117,6 +127,7 @@ def toco_convert_protos(model_flags_str,
       information. (default None)
     enable_mlir_converter: Enables MLIR-based conversion instead of the default
       TOCO conversion. (default False)
+
   Returns:
     Converted model in serialized form (e.g. a TFLITE model is common).
   Raises:
@@ -151,8 +162,8 @@ Alternative, use virtualenv.""")
   # Windows and TemporaryFile are not that useful together,
   # since you cannot have two readers/writers. So we have to
   # make the temporaries and close and delete them explicitly.
-  toco_filename, model_filename, input_filename, output_filename = (
-      None, None, None, None)
+  toco_filename, model_filename, input_filename, output_filename = (None, None,
+                                                                    None, None)
   try:
     # Build all input files
     with _tempfile.NamedTemporaryFile(delete=False) as fp_toco, \
@@ -216,7 +227,8 @@ Alternative, use virtualenv.""")
   finally:
     # Must manually cleanup files.
     for filename in [
-        toco_filename, input_filename, model_filename, output_filename]:
+        toco_filename, input_filename, model_filename, output_filename
+    ]:
       try:
         _os.unlink(filename)
       except (OSError, TypeError):
@@ -257,12 +269,12 @@ def build_toco_convert_protos(input_tensors,
     inference_type: Target data type of real-number arrays in the output file.
       Must be `{tf.float32, tf.uint8, tf.int8}`. (default tf.float32)
     inference_input_type: Target data type of real-number input arrays. Allows
-      for a different type for input arrays in the case of quantization.
-      Must be `{tf.float32, tf.uint8, tf.int8}`. (default `inference_type`)
+      for a different type for input arrays in the case of quantization. Must be
+      `{tf.float32, tf.uint8, tf.int8}`. (default `inference_type`)
     input_format: Type of data to read Currently must be
       `{TENSORFLOW_GRAPHDEF}`. (default TENSORFLOW_GRAPHDEF)
-    input_shapes: Input array shape. It needs to be a list of the same length
-      as `input_tensors`, or None. (default None)
+    input_shapes: Input array shape. It needs to be a list of the same length as
+      `input_tensors`, or None. (default None)
     output_format: Output file format. Currently must be `{TFLITE,
       GRAPHVIZ_DOT}`. (default TFLITE)
     quantized_input_stats: List of tuples of floats representing the mean and
@@ -284,8 +296,8 @@ def build_toco_convert_protos(input_tensors,
     allow_custom_ops: Boolean indicating whether to allow custom operations.
      When false any unknown operation is an error. When true, custom ops are
      created for any op that is unknown. The developer will need to provide
-      these to the TensorFlow Lite runtime with a custom resolver.
-      (default False)
+      these to the TensorFlow Lite runtime with a custom resolver. (default
+      False)
     custom_opdefs: List of strings representing custom ops OpDefs that are
       included in the GraphDef. Required when using custom operations with the
       MLIR-based converter. (default None)
@@ -294,21 +306,19 @@ def build_toco_convert_protos(input_tensors,
       the ranges of concat operator overlap when true. (default False)
     post_training_quantize: Boolean indicating whether to quantize the weights
       of the converted float model. Model size will be reduced and there will be
-      latency improvements (at the cost of accuracy).
-      (default False)
-    quantize_to_float16: Boolean indicating whether to convert float buffers
-      to float16. (default False)
+      latency improvements (at the cost of accuracy). (default False)
+    quantize_to_float16: Boolean indicating whether to convert float buffers to
+      float16. (default False)
     dump_graphviz_dir: Full filepath of folder to dump the graphs at various
       stages of processing GraphViz .dot files. Preferred over
       --output_format=GRAPHVIZ_DOT in order to keep the requirements of the
       output file. (default None)
     dump_graphviz_video: Boolean indicating whether to dump the graph after
       every graph transformation. (default False)
-    target_ops: Experimental flag, subject to change. Set of OpsSet
-      options indicating which converter to use.
-      (default set([OpsSet.TFLITE_BUILTINS]))
-    allow_nonexistent_arrays: Allow specifying array names that don't exist
-      or are unused in the final graph. (default False)
+    target_ops: Experimental flag, subject to change. Set of OpsSet options
+      indicating which converter to use. (default set([OpsSet.TFLITE_BUILTINS]))
+    allow_nonexistent_arrays: Allow specifying array names that don't exist or
+      are unused in the final graph. (default False)
     debug_info: `GraphDebugInfo` proto containing the stack traces for the
       original nodes referred by the converted graph.
     conversion_summary_dir: A string, the path to the generated conversion logs.
@@ -363,11 +373,13 @@ def build_toco_convert_protos(input_tensors,
     input_array.data_type = util.convert_dtype_to_tflite_type(
         input_tensor.dtype)
 
-    if toco.inference_type in [_types_pb2.QUANTIZED_UINT8, _types_pb2.INT8]:
-      if not quantized_input_stats and not post_training_quantize:
-        raise ValueError("std_dev and mean must be defined when inference_type "
-                         "is QUANTIZED_UINT8 or INT8.")
-      input_array.mean_value, input_array.std_value = quantized_input_stats[idx]
+    if _requires_input_stats(toco):
+      if quantized_input_stats:
+        input_array.mean_value, input_array.std_value = quantized_input_stats[
+            idx]
+      else:
+        raise ValueError("std_dev and mean must be defined when inference_type "
+                         "or inference_input_type is QUANTIZED_UINT8 or INT8.")
     if input_shapes is None:
       shape = input_tensor.shape
     else:
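For reference, the `mean_value`/`std_value` pair assigned here encodes the input quantization as real_value = (quantized_value - mean) / std_dev, i.e. scale = 1 / std_dev and zero_point = mean. A small sketch of that arithmetic (the stats values are illustrative, covering reals in [-1, 1] for a uint8 input):

    # Illustrative input stats for a uint8 input over the real range [-1, 1].
    mean, std_dev = 127.5, 127.5

    scale = 1.0 / std_dev          # ~0.00784
    zero_point = int(round(mean))  # 128 after rounding

    def dequantize(q):
      # real_value = (quantized_value - mean) / std_dev
      return (q - mean) / std_dev

    assert dequantize(0) == -1.0
    assert dequantize(255) == 1.0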
@@ -396,7 +408,7 @@ def toco_convert_graph_def(input_data, input_arrays_with_shape, output_arrays,
     input_arrays_with_shape: Tuple of strings representing input tensor names
       and list of integers representing input shapes
       (e.g., [("foo" : [1, 16, 16, 3])]). Use only when graph cannot be loaded
-        into TensorFlow and when `input_tensors` is None. (default None)
+      into TensorFlow and when `input_tensors` is None. (default None)
     output_arrays: List of output tensors to freeze graph with. Use only when
       graph cannot be loaded into TensorFlow and when `output_tensors` is None.
       (default None)
@@ -417,13 +429,11 @@ def toco_convert_graph_def(input_data, input_arrays_with_shape, output_arrays,
     for idx, (name, shape) in enumerate(input_arrays_with_shape):
       input_array = model_flags.input_arrays.add()
-      if toco_flags.inference_type in (
-          [_types_pb2.QUANTIZED_UINT8, _types_pb2.INT8]):
-        if ((("quantized_input_stats" not in kwargs) or
-             (not kwargs["quantized_input_stats"])) and
-            not toco_flags.post_training_quantize):
-          raise ValueError("std_dev and mean must be defined when "
-                           "inference_type is QUANTIZED_UINT8 or INT8.")
+      if _requires_input_stats(toco_flags):
+        if (("quantized_input_stats" not in kwargs) or
+            (not kwargs["quantized_input_stats"])):
+          raise ValueError("std_dev and mean must be defined when inference_type "
+                           "or inference_input_type is QUANTIZED_UINT8 or INT8.")
         input_array.mean_value, input_array.std_value = kwargs[
             "quantized_input_stats"][idx]
       input_array.name = name
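The same requirement now applies on the `toco_convert_graph_def` path, where the stats arrive through `**kwargs`. A sketch of a satisfying call (the trivial identity graph and array names are illustrative; a real graph would have ops worth quantizing):

    import tensorflow as tf
    from tensorflow.lite.python import convert, lite_constants

    # Build a trivial frozen graph with one input and one output.
    with tf.Graph().as_default() as g:
      inp = tf.compat.v1.placeholder(tf.float32, [1, 16, 16, 3], name="input")
      out = tf.identity(inp, name="output")

    tflite_model = convert.toco_convert_graph_def(
        g.as_graph_def(),
        [("input", [1, 16, 16, 3])],
        ["output"],
        enable_mlir_converter=False,
        inference_type=lite_constants.QUANTIZED_UINT8,
        quantized_input_stats=[(128.0, 127.0)])  # one (mean, std_dev) per input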


@@ -76,8 +76,9 @@ class ConvertTest(test_util.TensorFlowTestCase):
           sess.graph_def, [in_tensor], [out_tensor],
           inference_type=lite_constants.QUANTIZED_UINT8)
     self.assertEqual(
-        "std_dev and mean must be defined when inference_type is "
-        "QUANTIZED_UINT8 or INT8.", str(error.exception))
+        "std_dev and mean must be defined when inference_type or "
+        "inference_input_type is QUANTIZED_UINT8 or INT8.",
+        str(error.exception))
 
     with self.assertRaises(ValueError) as error:
       convert.toco_convert(
@@ -85,8 +86,9 @@ class ConvertTest(test_util.TensorFlowTestCase):
           inference_type=lite_constants.QUANTIZED_UINT8,
           inference_input_type=lite_constants.FLOAT)
     self.assertEqual(
-        "std_dev and mean must be defined when inference_type is "
-        "QUANTIZED_UINT8 or INT8.", str(error.exception))
+        "std_dev and mean must be defined when inference_type or "
+        "inference_input_type is QUANTIZED_UINT8 or INT8.",
+        str(error.exception))
 
   def testGraphDefBasic(self):
     with ops.Graph().as_default():
@@ -185,8 +187,9 @@ class ConvertTest(test_util.TensorFlowTestCase):
           enable_mlir_converter=False,
           inference_type=lite_constants.QUANTIZED_UINT8)
     self.assertEqual(
-        "std_dev and mean must be defined when inference_type is "
-        "QUANTIZED_UINT8 or INT8.", str(error.exception))
+        "std_dev and mean must be defined when inference_type or "
+        "inference_input_type is QUANTIZED_UINT8 or INT8.",
+        str(error.exception))
 
 
 class ConvertTestOpHint(test_util.TensorFlowTestCase):