Require quantized_input_stats when not doing post-training quantization

If either inference_type or inference_input_type is set to int8/uint8 and post-training quantization is not enabled, quantized_input_stats is required.

PiperOrigin-RevId: 291441023
Change-Id: Iaee998f10dc90c66ddafc392de250d0f9234388c
parent 1554ffdc6a
commit d653517480
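Note: the heart of this change is a single predicate. The sketch below restates it standalone so the gating is easy to see; the imports mirror ones convert.py already uses, while requires_input_stats and _QUANTIZED are names local to this illustration.

    from tensorflow.lite.toco import toco_flags_pb2
    from tensorflow.lite.toco import types_pb2

    _QUANTIZED = (types_pb2.QUANTIZED_UINT8, types_pb2.INT8)

    def requires_input_stats(flags):
      # Input stats are required when either type field is quantized and the
      # post-training-quantization path is not taken.
      return ((flags.inference_type in _QUANTIZED or
               flags.inference_input_type in _QUANTIZED) and
              not flags.post_training_quantize)

    flags = toco_flags_pb2.TocoFlags()
    flags.inference_type = types_pb2.FLOAT
    flags.inference_input_type = types_pb2.QUANTIZED_UINT8
    assert requires_input_stats(flags)      # a quantized input type alone now triggers the check
    flags.post_training_quantize = True
    assert not requires_input_stats(flags)  # the PTQ path never needs input stats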
tensorflow/lite/python/convert.py
@@ -39,6 +39,16 @@ from tensorflow.python.platform import resource_loader as _resource_loader
 from tensorflow.python.util import deprecation
 from tensorflow.python.util.tf_export import tf_export as _tf_export
 
+_quantized_inference_types = [_types_pb2.QUANTIZED_UINT8, _types_pb2.INT8]
+
+
+# If the `inference_type` or the `inference_input_type` is the quantized type
+# and it is not post training quantization, the input quantization stats is
+# required.
+def _requires_input_stats(toco_flags):
+  return ((toco_flags.inference_type in _quantized_inference_types or
+           toco_flags.inference_input_type in _quantized_inference_types) and
+          not toco_flags.post_training_quantize)
+
+
 # Find the toco_from_protos binary using the resource loader if using from
 # bazel, otherwise we are in a pip where console_scripts already has
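Note: each (mean, std_dev) tuple in quantized_input_stats tells the converter how a quantized input maps back to real values, i.e. real_value = (quantized_value - mean) / std_dev. A small sketch of deriving the pair from a known float input range; stats_for_range is illustrative and not part of this change:

    def stats_for_range(rmin, rmax, num_bits=8):
      # Choose (mean, std_dev) so that real = (quantized - mean) / std_dev
      # maps the quantized domain [0, 2**num_bits - 1] onto [rmin, rmax].
      qmax = 2**num_bits - 1
      std_dev = qmax / (rmax - rmin)
      mean = -rmin * std_dev
      return mean, std_dev

    print(stats_for_range(-1.0, 1.0))  # (127.5, 127.5), the usual stats for inputs in [-1, 1]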
@@ -117,6 +127,7 @@ def toco_convert_protos(model_flags_str,
       information. (default None)
     enable_mlir_converter: Enables MLIR-based conversion instead of the default
       TOCO conversion. (default False)
+
   Returns:
     Converted model in serialized form (e.g. a TFLITE model is common).
   Raises:
@@ -151,8 +162,8 @@ Alternative, use virtualenv.""")
   # Windows and TemporaryFile are not that useful together,
   # since you cannot have two readers/writers. So we have to
   # make the temporaries and close and delete them explicitly.
-  toco_filename, model_filename, input_filename, output_filename = (
-      None, None, None, None)
+  toco_filename, model_filename, input_filename, output_filename = (None, None,
+                                                                    None, None)
   try:
     # Build all input files
     with _tempfile.NamedTemporaryFile(delete=False) as fp_toco, \
@@ -216,7 +227,8 @@ Alternative, use virtualenv.""")
   finally:
     # Must manually cleanup files.
     for filename in [
-        toco_filename, input_filename, model_filename, output_filename]:
+        toco_filename, input_filename, model_filename, output_filename
+    ]:
       try:
         _os.unlink(filename)
       except (OSError, TypeError):
@@ -257,12 +269,12 @@ def build_toco_convert_protos(input_tensors,
     inference_type: Target data type of real-number arrays in the output file.
       Must be `{tf.float32, tf.uint8, tf.int8}`. (default tf.float32)
     inference_input_type: Target data type of real-number input arrays. Allows
-      for a different type for input arrays in the case of quantization.
-      Must be `{tf.float32, tf.uint8, tf.int8}`. (default `inference_type`)
+      for a different type for input arrays in the case of quantization. Must be
+      `{tf.float32, tf.uint8, tf.int8}`. (default `inference_type`)
     input_format: Type of data to read Currently must be
       `{TENSORFLOW_GRAPHDEF}`. (default TENSORFLOW_GRAPHDEF)
-    input_shapes: Input array shape. It needs to be a list of the same length
-      as `input_tensors`, or None. (default None)
+    input_shapes: Input array shape. It needs to be a list of the same length as
+      `input_tensors`, or None. (default None)
     output_format: Output file format. Currently must be `{TFLITE,
       GRAPHVIZ_DOT}`. (default TFLITE)
     quantized_input_stats: List of tuples of floats representing the mean and
@@ -284,8 +296,8 @@ def build_toco_convert_protos(input_tensors,
     allow_custom_ops: Boolean indicating whether to allow custom operations.
       When false any unknown operation is an error. When true, custom ops are
       created for any op that is unknown. The developer will need to provide
-      these to the TensorFlow Lite runtime with a custom resolver.
-      (default False)
+      these to the TensorFlow Lite runtime with a custom resolver. (default
+      False)
     custom_opdefs: List of strings representing custom ops OpDefs that are
       included in the GraphDef. Required when using custom operations with the
       MLIR-based converter. (default None)
@@ -294,21 +306,19 @@ def build_toco_convert_protos(input_tensors,
       the ranges of concat operator overlap when true. (default False)
     post_training_quantize: Boolean indicating whether to quantize the weights
       of the converted float model. Model size will be reduced and there will be
-      latency improvements (at the cost of accuracy).
-      (default False)
-    quantize_to_float16: Boolean indicating whether to convert float buffers
-      to float16. (default False)
+      latency improvements (at the cost of accuracy). (default False)
+    quantize_to_float16: Boolean indicating whether to convert float buffers to
+      float16. (default False)
     dump_graphviz_dir: Full filepath of folder to dump the graphs at various
       stages of processing GraphViz .dot files. Preferred over
       --output_format=GRAPHVIZ_DOT in order to keep the requirements of the
       output file. (default None)
     dump_graphviz_video: Boolean indicating whether to dump the graph after
       every graph transformation. (default False)
-    target_ops: Experimental flag, subject to change. Set of OpsSet
-      options indicating which converter to use.
-      (default set([OpsSet.TFLITE_BUILTINS]))
-    allow_nonexistent_arrays: Allow specifying array names that don't exist
-      or are unused in the final graph. (default False)
+    target_ops: Experimental flag, subject to change. Set of OpsSet options
+      indicating which converter to use. (default set([OpsSet.TFLITE_BUILTINS]))
+    allow_nonexistent_arrays: Allow specifying array names that don't exist or
+      are unused in the final graph. (default False)
     debug_info: `GraphDebugInfo` proto containing the stack traces for the
       original nodes referred by the converted graph.
     conversion_summary_dir: A string, the path to the generated conversion logs.
@@ -363,11 +373,13 @@ def build_toco_convert_protos(input_tensors,
     input_array.data_type = util.convert_dtype_to_tflite_type(
         input_tensor.dtype)
 
-    if toco.inference_type in [_types_pb2.QUANTIZED_UINT8, _types_pb2.INT8]:
-      if not quantized_input_stats and not post_training_quantize:
-        raise ValueError("std_dev and mean must be defined when inference_type "
-                         "is QUANTIZED_UINT8 or INT8.")
-      input_array.mean_value, input_array.std_value = quantized_input_stats[idx]
+    if _requires_input_stats(toco):
+      if quantized_input_stats:
+        input_array.mean_value, input_array.std_value = quantized_input_stats[
+            idx]
+      else:
+        raise ValueError("std_dev and mean must be defined when inference_type "
+                         "or inference_input_type is QUANTIZED_UINT8 or INT8.")
 
     if input_shapes is None:
       shape = input_tensor.shape
     else:
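Note: the practical effect of the hunk above is that a call which quantizes only the input type is now rejected unless stats are given. A sketch of the newly-rejected call, modeled on the tests further down; the placeholder graph is assumed setup, not taken from this commit:

    import tensorflow as tf
    from tensorflow.lite.python import convert
    from tensorflow.lite.python import lite_constants

    with tf.compat.v1.Session() as sess:
      in_tensor = tf.compat.v1.placeholder(dtype=tf.float32, shape=[1, 4])
      out_tensor = in_tensor + in_tensor
      try:
        convert.toco_convert(
            sess.graph_def, [in_tensor], [out_tensor],
            inference_type=lite_constants.FLOAT,
            inference_input_type=lite_constants.QUANTIZED_UINT8)
      except ValueError as e:
        print(e)  # "std_dev and mean must be defined when inference_type or ..."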
@@ -396,7 +408,7 @@ def toco_convert_graph_def(input_data, input_arrays_with_shape, output_arrays,
     input_arrays_with_shape: Tuple of strings representing input tensor names
       and list of integers representing input shapes
       (e.g., [("foo" : [1, 16, 16, 3])]). Use only when graph cannot be loaded
       into TensorFlow and when `input_tensors` is None. (default None)
     output_arrays: List of output tensors to freeze graph with. Use only when
       graph cannot be loaded into TensorFlow and when `output_tensors` is None.
       (default None)
@@ -417,13 +429,11 @@ def toco_convert_graph_def(input_data, input_arrays_with_shape, output_arrays,
 
   for idx, (name, shape) in enumerate(input_arrays_with_shape):
     input_array = model_flags.input_arrays.add()
-    if toco_flags.inference_type in (
-        [_types_pb2.QUANTIZED_UINT8, _types_pb2.INT8]):
-      if ((("quantized_input_stats" not in kwargs) or
-           (not kwargs["quantized_input_stats"])) and
-          not toco_flags.post_training_quantize):
-        raise ValueError("std_dev and mean must be defined when "
-                         "inference_type is QUANTIZED_UINT8 or INT8.")
+    if _requires_input_stats(toco_flags):
+      if (("quantized_input_stats" not in kwargs) or
+          (not kwargs["quantized_input_stats"])):
+        raise ValueError("std_dev and mean must be defined when inference_type "
+                         "or inference_input_type is QUANTIZED_UINT8 or INT8.")
       input_array.mean_value, input_array.std_value = kwargs[
           "quantized_input_stats"][idx]
     input_array.name = name
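Note: the GraphDef path reads the stats out of kwargs, so the same rule applies when converting without live tensors. A hedged sketch of a satisfying call; graph_def, the array names, and the shape are illustrative assumptions:

    from tensorflow.lite.python import convert
    from tensorflow.lite.python import lite_constants

    # `graph_def` is assumed to be a frozen GraphDef with a float input "input"
    # and an output "add"; one (mean, std_dev) tuple is supplied per input array.
    tflite_model = convert.toco_convert_graph_def(
        graph_def,
        input_arrays_with_shape=[("input", [1, 16, 16, 3])],
        output_arrays=["add"],
        enable_mlir_converter=False,
        inference_type=lite_constants.QUANTIZED_UINT8,
        quantized_input_stats=[(0., 1.)])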
tensorflow/lite/python/convert_test.py
@@ -76,8 +76,9 @@ class ConvertTest(test_util.TensorFlowTestCase):
           sess.graph_def, [in_tensor], [out_tensor],
           inference_type=lite_constants.QUANTIZED_UINT8)
       self.assertEqual(
-          "std_dev and mean must be defined when inference_type is "
-          "QUANTIZED_UINT8 or INT8.", str(error.exception))
+          "std_dev and mean must be defined when inference_type or "
+          "inference_input_type is QUANTIZED_UINT8 or INT8.",
+          str(error.exception))
 
     with self.assertRaises(ValueError) as error:
       convert.toco_convert(
@@ -85,8 +86,9 @@ class ConvertTest(test_util.TensorFlowTestCase):
           inference_type=lite_constants.QUANTIZED_UINT8,
           inference_input_type=lite_constants.FLOAT)
       self.assertEqual(
-          "std_dev and mean must be defined when inference_type is "
-          "QUANTIZED_UINT8 or INT8.", str(error.exception))
+          "std_dev and mean must be defined when inference_type or "
+          "inference_input_type is QUANTIZED_UINT8 or INT8.",
+          str(error.exception))
 
   def testGraphDefBasic(self):
     with ops.Graph().as_default():
@@ -185,8 +187,9 @@ class ConvertTest(test_util.TensorFlowTestCase):
         enable_mlir_converter=False,
         inference_type=lite_constants.QUANTIZED_UINT8)
     self.assertEqual(
-        "std_dev and mean must be defined when inference_type is "
-        "QUANTIZED_UINT8 or INT8.", str(error.exception))
+        "std_dev and mean must be defined when inference_type or "
+        "inference_input_type is QUANTIZED_UINT8 or INT8.",
+        str(error.exception))
 
 
 class ConvertTestOpHint(test_util.TensorFlowTestCase):