Requires quantized_input_stats when it is not post training quantization
If either inference_type or inference_input_type is set to int8/uint8 and post-training quantization is not enabled, quantized_input_stats is required. PiperOrigin-RevId: 291441023 Change-Id: Iaee998f10dc90c66ddafc392de250d0f9234388c
This commit is contained in:
parent
1554ffdc6a
commit
d653517480
@ -39,6 +39,16 @@ from tensorflow.python.platform import resource_loader as _resource_loader
|
||||
from tensorflow.python.util import deprecation
|
||||
from tensorflow.python.util.tf_export import tf_export as _tf_export
|
||||
|
||||
_quantized_inference_types = [_types_pb2.QUANTIZED_UINT8, _types_pb2.INT8]
|
||||
|
||||
|
||||
# If the `inference_type` or the `inference_input_type` is the quantized type
|
||||
# and it is not post-training quantization, the input quantization stats are
|
||||
# required.
|
||||
def _requires_input_stats(toco_flags):
  """Reports whether input quantization stats must be supplied.

  Stats are needed when `inference_type` or `inference_input_type` is one of
  `_quantized_inference_types` and `post_training_quantize` is not set.

  Args:
    toco_flags: Flags object exposing `inference_type`,
      `inference_input_type` and `post_training_quantize`.

  Returns:
    True if the stats are required, False otherwise.
  """
  # Neither type is quantized: stats are never needed.
  if (toco_flags.inference_type not in _quantized_inference_types and
      toco_flags.inference_input_type not in _quantized_inference_types):
    return False
  # A quantized type was requested; only post-training quantization lifts
  # the requirement.
  return not toco_flags.post_training_quantize
|
||||
|
||||
# Find the toco_from_protos binary using the resource loader if using from
|
||||
# bazel, otherwise we are in a pip where console_scripts already has
|
||||
@ -117,6 +127,7 @@ def toco_convert_protos(model_flags_str,
|
||||
information. (default None)
|
||||
enable_mlir_converter: Enables MLIR-based conversion instead of the default
|
||||
TOCO conversion. (default False)
|
||||
|
||||
Returns:
|
||||
Converted model in serialized form (e.g. a TFLITE model is common).
|
||||
Raises:
|
||||
@ -151,8 +162,8 @@ Alternative, use virtualenv.""")
|
||||
# Windows and TemporaryFile are not that useful together,
|
||||
# since you cannot have two readers/writers. So we have to
|
||||
# make the temporaries and close and delete them explicitly.
|
||||
toco_filename, model_filename, input_filename, output_filename = (
|
||||
None, None, None, None)
|
||||
toco_filename, model_filename, input_filename, output_filename = (None, None,
|
||||
None, None)
|
||||
try:
|
||||
# Build all input files
|
||||
with _tempfile.NamedTemporaryFile(delete=False) as fp_toco, \
|
||||
@ -216,7 +227,8 @@ Alternative, use virtualenv.""")
|
||||
finally:
|
||||
# Must manually cleanup files.
|
||||
for filename in [
|
||||
toco_filename, input_filename, model_filename, output_filename]:
|
||||
toco_filename, input_filename, model_filename, output_filename
|
||||
]:
|
||||
try:
|
||||
_os.unlink(filename)
|
||||
except (OSError, TypeError):
|
||||
@ -257,12 +269,12 @@ def build_toco_convert_protos(input_tensors,
|
||||
inference_type: Target data type of real-number arrays in the output file.
|
||||
Must be `{tf.float32, tf.uint8, tf.int8}`. (default tf.float32)
|
||||
inference_input_type: Target data type of real-number input arrays. Allows
|
||||
for a different type for input arrays in the case of quantization.
|
||||
Must be `{tf.float32, tf.uint8, tf.int8}`. (default `inference_type`)
|
||||
for a different type for input arrays in the case of quantization. Must be
|
||||
`{tf.float32, tf.uint8, tf.int8}`. (default `inference_type`)
|
||||
input_format: Type of data to read. Currently must be
|
||||
`{TENSORFLOW_GRAPHDEF}`. (default TENSORFLOW_GRAPHDEF)
|
||||
input_shapes: Input array shape. It needs to be a list of the same length
|
||||
as `input_tensors`, or None. (default None)
|
||||
input_shapes: Input array shape. It needs to be a list of the same length as
|
||||
`input_tensors`, or None. (default None)
|
||||
output_format: Output file format. Currently must be `{TFLITE,
|
||||
GRAPHVIZ_DOT}`. (default TFLITE)
|
||||
quantized_input_stats: List of tuples of floats representing the mean and
|
||||
@ -284,8 +296,8 @@ def build_toco_convert_protos(input_tensors,
|
||||
allow_custom_ops: Boolean indicating whether to allow custom operations.
|
||||
When false any unknown operation is an error. When true, custom ops are
|
||||
created for any op that is unknown. The developer will need to provide
|
||||
these to the TensorFlow Lite runtime with a custom resolver.
|
||||
(default False)
|
||||
these to the TensorFlow Lite runtime with a custom resolver. (default
|
||||
False)
|
||||
custom_opdefs: List of strings representing custom ops OpDefs that are
|
||||
included in the GraphDef. Required when using custom operations with the
|
||||
MLIR-based converter. (default None)
|
||||
@ -294,21 +306,19 @@ def build_toco_convert_protos(input_tensors,
|
||||
the ranges of concat operator overlap when true. (default False)
|
||||
post_training_quantize: Boolean indicating whether to quantize the weights
|
||||
of the converted float model. Model size will be reduced and there will be
|
||||
latency improvements (at the cost of accuracy).
|
||||
(default False)
|
||||
quantize_to_float16: Boolean indicating whether to convert float buffers
|
||||
to float16. (default False)
|
||||
latency improvements (at the cost of accuracy). (default False)
|
||||
quantize_to_float16: Boolean indicating whether to convert float buffers to
|
||||
float16. (default False)
|
||||
dump_graphviz_dir: Full filepath of folder to dump the graphs at various
|
||||
stages of processing GraphViz .dot files. Preferred over
|
||||
--output_format=GRAPHVIZ_DOT in order to keep the requirements of the
|
||||
output file. (default None)
|
||||
dump_graphviz_video: Boolean indicating whether to dump the graph after
|
||||
every graph transformation. (default False)
|
||||
target_ops: Experimental flag, subject to change. Set of OpsSet
|
||||
options indicating which converter to use.
|
||||
(default set([OpsSet.TFLITE_BUILTINS]))
|
||||
allow_nonexistent_arrays: Allow specifying array names that don't exist
|
||||
or are unused in the final graph. (default False)
|
||||
target_ops: Experimental flag, subject to change. Set of OpsSet options
|
||||
indicating which converter to use. (default set([OpsSet.TFLITE_BUILTINS]))
|
||||
allow_nonexistent_arrays: Allow specifying array names that don't exist or
|
||||
are unused in the final graph. (default False)
|
||||
debug_info: `GraphDebugInfo` proto containing the stack traces for the
|
||||
original nodes referred by the converted graph.
|
||||
conversion_summary_dir: A string, the path to the generated conversion logs.
|
||||
@ -363,11 +373,13 @@ def build_toco_convert_protos(input_tensors,
|
||||
input_array.data_type = util.convert_dtype_to_tflite_type(
|
||||
input_tensor.dtype)
|
||||
|
||||
if toco.inference_type in [_types_pb2.QUANTIZED_UINT8, _types_pb2.INT8]:
|
||||
if not quantized_input_stats and not post_training_quantize:
|
||||
if _requires_input_stats(toco):
|
||||
if quantized_input_stats:
|
||||
input_array.mean_value, input_array.std_value = quantized_input_stats[
|
||||
idx]
|
||||
else:
|
||||
raise ValueError("std_dev and mean must be defined when inference_type "
|
||||
"is QUANTIZED_UINT8 or INT8.")
|
||||
input_array.mean_value, input_array.std_value = quantized_input_stats[idx]
|
||||
"or inference_input_type is QUANTIZED_UINT8 or INT8.")
|
||||
if input_shapes is None:
|
||||
shape = input_tensor.shape
|
||||
else:
|
||||
@ -396,7 +408,7 @@ def toco_convert_graph_def(input_data, input_arrays_with_shape, output_arrays,
|
||||
input_arrays_with_shape: Tuple of strings representing input tensor names
|
||||
and list of integers representing input shapes
|
||||
(e.g., [("foo", [1, 16, 16, 3])]). Use only when graph cannot be loaded
|
||||
into TensorFlow and when `input_tensors` is None. (default None)
|
||||
into TensorFlow and when `input_tensors` is None. (default None)
|
||||
output_arrays: List of output tensors to freeze graph with. Use only when
|
||||
graph cannot be loaded into TensorFlow and when `output_tensors` is None.
|
||||
(default None)
|
||||
@ -417,13 +429,11 @@ def toco_convert_graph_def(input_data, input_arrays_with_shape, output_arrays,
|
||||
|
||||
for idx, (name, shape) in enumerate(input_arrays_with_shape):
|
||||
input_array = model_flags.input_arrays.add()
|
||||
if toco_flags.inference_type in (
|
||||
[_types_pb2.QUANTIZED_UINT8, _types_pb2.INT8]):
|
||||
if ((("quantized_input_stats" not in kwargs) or
|
||||
(not kwargs["quantized_input_stats"])) and
|
||||
not toco_flags.post_training_quantize):
|
||||
raise ValueError("std_dev and mean must be defined when "
|
||||
"inference_type is QUANTIZED_UINT8 or INT8.")
|
||||
if _requires_input_stats(toco_flags):
|
||||
if (("quantized_input_stats" not in kwargs) or
|
||||
(not kwargs["quantized_input_stats"])):
|
||||
raise ValueError("std_dev and mean must be defined when inference_type "
|
||||
"or inference_input_type is QUANTIZED_UINT8 or INT8.")
|
||||
input_array.mean_value, input_array.std_value = kwargs[
|
||||
"quantized_input_stats"][idx]
|
||||
input_array.name = name
|
||||
|
@ -76,8 +76,9 @@ class ConvertTest(test_util.TensorFlowTestCase):
|
||||
sess.graph_def, [in_tensor], [out_tensor],
|
||||
inference_type=lite_constants.QUANTIZED_UINT8)
|
||||
self.assertEqual(
|
||||
"std_dev and mean must be defined when inference_type is "
|
||||
"QUANTIZED_UINT8 or INT8.", str(error.exception))
|
||||
"std_dev and mean must be defined when inference_type or "
|
||||
"inference_input_type is QUANTIZED_UINT8 or INT8.",
|
||||
str(error.exception))
|
||||
|
||||
with self.assertRaises(ValueError) as error:
|
||||
convert.toco_convert(
|
||||
@ -85,8 +86,9 @@ class ConvertTest(test_util.TensorFlowTestCase):
|
||||
inference_type=lite_constants.QUANTIZED_UINT8,
|
||||
inference_input_type=lite_constants.FLOAT)
|
||||
self.assertEqual(
|
||||
"std_dev and mean must be defined when inference_type is "
|
||||
"QUANTIZED_UINT8 or INT8.", str(error.exception))
|
||||
"std_dev and mean must be defined when inference_type or "
|
||||
"inference_input_type is QUANTIZED_UINT8 or INT8.",
|
||||
str(error.exception))
|
||||
|
||||
def testGraphDefBasic(self):
|
||||
with ops.Graph().as_default():
|
||||
@ -185,8 +187,9 @@ class ConvertTest(test_util.TensorFlowTestCase):
|
||||
enable_mlir_converter=False,
|
||||
inference_type=lite_constants.QUANTIZED_UINT8)
|
||||
self.assertEqual(
|
||||
"std_dev and mean must be defined when inference_type is "
|
||||
"QUANTIZED_UINT8 or INT8.", str(error.exception))
|
||||
"std_dev and mean must be defined when inference_type or "
|
||||
"inference_input_type is QUANTIZED_UINT8 or INT8.",
|
||||
str(error.exception))
|
||||
|
||||
|
||||
class ConvertTestOpHint(test_util.TensorFlowTestCase):
|
||||
|
Loading…
Reference in New Issue
Block a user