Reduce redundancy by replacing TFLite data types with TF data types
PiperOrigin-RevId: 337004382
Change-Id: I48b5734781527138705bdda5e6ea80120e3a33a1
Parent: e485bb64c5
Commit: 2396803aa1
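Collected from the hunks below, the substitutions applied by this change amount to a one-to-one mapping from the removed `lite_constants` names onto standard TensorFlow dtypes. A minimal reference sketch (the mapping dictionary itself is editorial, not part of the commit):

# Sketch: lite_constants -> tf dtypes substitutions made throughout this change.
from tensorflow.python.framework import dtypes

LITE_CONSTANT_TO_DTYPE = {
    "lite_constants.FLOAT": dtypes.float32,
    "lite_constants.FLOAT16": dtypes.float16,
    "lite_constants.INT8": dtypes.int8,
    "lite_constants.INT16": dtypes.int16,
    "lite_constants.INT32": dtypes.int32,
    "lite_constants.QUANTIZED_UINT8": dtypes.uint8,
}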
@@ -104,8 +104,8 @@ tflite_convert \
 --std_dev_values=127.7
 ```

-*If you're setting `--inference_type=QUANTIZED_UINT8` then update
-`--mean_values=128` and `--std_dev_values=127`*
+*If you're setting `--inference_type=UINT8` then update `--mean_values=128` and
+`--std_dev_values=127`*

 #### Convert a model with "dummy-quantization" into a quantized TensorFlow Lite model


@@ -134,8 +134,8 @@ tflite_convert \
 --default_ranges_max=6
 ```

-*If you're setting `--inference_type=QUANTIZED_UINT8` then update
-`--mean_values=128` and `--std_dev_values=127`*
+*If you're setting `--inference_type=UINT8` then update `--mean_values=128` and
+`--std_dev_values=127`*

 #### Convert a model with select TensorFlow operators.


@@ -63,8 +63,7 @@ based on index.
 has a shape of [2, 3] and "bar" has a shape of [4, 5, 6].
 * `--std_dev_values`, `--mean_values`. Type: comma-separated list of floats.
 These specify the (de-)quantization parameters of the input array, when it
-is quantized. This is only needed if `inference_input_type` is `INT8` or
-`QUANTIZED_UINT8`.
+is quantized. Only needed if `inference_input_type` is `INT8` or `UINT8`.
 * The meaning of `mean_values` and `std_dev_values` is as follows: each
 quantized value in the quantized input array will be interpreted as a
 mathematical real number (i.e. as an input activation value) according

@@ -75,12 +74,12 @@ based on index.
 the inference code according to the above formula, before proceeding
 with float inference.
 * When performing quantized inference (`inference_type`
-is`INT8`or`QUANTIZED_UINT8`), no dequantization is performed by the
-inference code. However, the quantization parameters of all arrays,
-including those of the input arrays as specified
-by`mean_value`and`std_dev_value`, determine the fixed-point multipliers
-used in the quantized inference code.`mean_value` must be an integer
-when performing quantized inference.
+is`INT8`or`UINT8`), no dequantization is performed by the inference
+code. However, the quantization parameters of all arrays, including
+those of the input arrays as specified by`mean_value`and`std_dev_value`,
+determine the fixed-point multipliers used in the quantized inference
+code.`mean_value` must be an integer when performing quantized
+inference.

 ## Transformation flags

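A worked example of the formula referenced above may help: with the documented values `mean_values=128` and `std_dev_values=127`, a quantized uint8 input maps onto real activations roughly in [-1, 1]. A minimal sketch of the arithmetic (the helper name is illustrative only):

# Sketch: input (de-)quantization as described above.
# real_value = (quantized_value - mean_value) / std_dev_value
def dequantize_input(quantized_value, mean_value=128, std_dev_value=127.0):
    return (quantized_value - mean_value) / std_dev_value

print(dequantize_input(128))  # 0.0   (the mean maps to exactly zero)
print(dequantize_input(255))  # 1.0
print(dequantize_input(0))    # ~-1.008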
@@ -90,7 +89,7 @@ have.

 * `--inference_type`. Type: string. Default: `FLOAT`. Data type of all
 real-number arrays in the output file except for input arrays (defined by
-`--inference_input_type`). Must be `{FLOAT, INT8, QUANTIZED_UINT8}`.
+`--inference_input_type`). Must be `{FLOAT, INT8, UINT8}`.

 This flag only impacts real-number arrays including float and quantized
 arrays. This excludes all other data types including plain integer arrays

@@ -102,16 +101,15 @@ have.
 * If `INT8`, then real-numbers arrays will be quantized as int8 in the
 output file. If they were float in the input file, then they get
 quantized.
-* If `QUANTIZED_UINT8`, then real-numbers arrays will be quantized as
-uint8 in the output file. If they were float in the input file, then
-they get quantized.
+* If `UINT8`, then real-numbers arrays will be quantized as uint8 in the
+output file. If they were float in the input file, then they get
+quantized.

 * `--inference_input_type`. Type: string. Data type of a real-number input
 array in the output file. By default the `--inference_type` is used as type
 of all of the input arrays. Flag is primarily intended for generating a
 float-point graph with a quantized input array. A Dequantized operator is
-added immediately after the input array. Must be `{FLOAT, INT8,
-QUANTIZED_UINT8}`.
+added immediately after the input array. Must be `{FLOAT, INT8, UINT8}`.

 The flag is typically used for vision models taking a bitmap as input but
 requiring floating-point inference. For such image models, the uint8 input
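The same renaming applies to the Python converter API: wherever `QUANTIZED_UINT8` was passed before, the plain TensorFlow uint8 dtype is now used. A hedged sketch mirroring the TF1 test changes later in this diff (the input name and stats are placeholders):

# Sketch: configuring quantized (uint8) inference on a TF1 converter.
# `converter` is assumed to be a tf.compat.v1.lite.TFLiteConverter instance.
import tensorflow as tf

def configure_uint8_inference(converter):
    converter.inference_type = tf.uint8  # was lite_constants.QUANTIZED_UINT8
    converter.quantized_input_stats = {"input": (128., 127.)}  # (mean, std_dev)
    return converter.convert()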
@@ -99,7 +99,7 @@ py_test(
 ],
 python_version = "PY3",
 # Increased thread count for reducing timeout failures.
-shard_count = 4,
+shard_count = 10,
 srcs_version = "PY2AND3",
 tags = [
 "no_oss",

@@ -109,9 +109,26 @@ py_test(
 "notsan", # b/160824139
 ],
 deps = [
+":convert",
 ":tflite_convert",
+"//tensorflow/python:array_ops",
 "//tensorflow/python:client_testlib",
+"//tensorflow/python:constant_op",
+"//tensorflow/python:dtypes",
+"//tensorflow/python:framework",
+"//tensorflow/python:framework_ops",
 "//tensorflow/python:framework_test_lib",
+"//tensorflow/python:platform",
+"//tensorflow/python:random_ops",
+"//tensorflow/python:session",
+"//tensorflow/python:tf2",
+"//tensorflow/python/eager:def_function",
+"//tensorflow/python/keras",
+"//tensorflow/python/saved_model",
+"//tensorflow/python/saved_model:save",
+"//tensorflow/python/training:training_util",
+"//tensorflow/python/training/tracking",
+"//third_party/py/numpy",
 ],
 )


@@ -35,6 +35,7 @@ from tensorflow.lite.python import wrap_toco
 from tensorflow.lite.toco import model_flags_pb2 as _model_flags_pb2
 from tensorflow.lite.toco import toco_flags_pb2 as _toco_flags_pb2
 from tensorflow.lite.toco import types_pb2 as _types_pb2
+from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import tensor_shape
 from tensorflow.python.platform import resource_loader as _resource_loader
 from tensorflow.python.util import deprecation

@@ -301,7 +302,7 @@ Alternative, use virtualenv.""")
 pass


-def build_toco_flags(inference_type=lite_constants.FLOAT,
+def build_toco_flags(inference_type=dtypes.float32,
 inference_input_type=None,
 input_format=lite_constants.TENSORFLOW_GRAPHDEF,
 output_format=lite_constants.TFLITE,

@@ -352,7 +353,7 @@ def build_toco_flags(inference_type=lite_constants.FLOAT,

 def build_toco_convert_protos(input_tensors,
 output_tensors,
-inference_type=lite_constants.FLOAT,
+inference_type=dtypes.float32,
 inference_input_type=None,
 input_format=lite_constants.TENSORFLOW_GRAPHDEF,
 input_shapes=None,
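With the new defaults, the low-level conversion helpers take TensorFlow dtypes directly. A hedged, self-contained sketch built around the same pattern as the `toco_convert` test updates below (the graph contents are placeholders):

# Sketch: calling the low-level converter with tf dtypes instead of
# lite_constants. Mirrors the updated tests; the graph is a placeholder.
import tensorflow.compat.v1 as tf
from tensorflow.lite.python import convert
from tensorflow.python.framework import dtypes

with tf.Graph().as_default(), tf.Session() as sess:
    in_tensor = tf.placeholder(dtype=tf.float32, shape=[1, 16, 16, 3])
    out_tensor = in_tensor + in_tensor
    tflite_model = convert.toco_convert(
        sess.graph_def, [in_tensor], [out_tensor],
        inference_type=dtypes.uint8,        # was lite_constants.QUANTIZED_UINT8
        quantized_input_stats=[(0., 1.)])   # (mean, std_dev) per input array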
@@ -20,7 +20,6 @@ from __future__ import print_function
 import numpy as np

 from tensorflow.lite.python import convert
-from tensorflow.lite.python import lite_constants
 from tensorflow.lite.python import op_hint
 from tensorflow.lite.python.interpreter import Interpreter
 from tensorflow.python.client import session

@@ -59,7 +58,7 @@ class ConvertTest(test_util.TensorFlowTestCase):

 tflite_model = convert.toco_convert(
 sess.graph_def, [in_tensor], [out_tensor],
-inference_type=lite_constants.QUANTIZED_UINT8,
+inference_type=dtypes.uint8,
 quantized_input_stats=[(0., 1.)])
 self.assertTrue(tflite_model)

@@ -73,7 +72,7 @@ class ConvertTest(test_util.TensorFlowTestCase):
 tflite_model = convert.toco_convert_graph_def(
 sess.graph_def, [("input", [1, 16, 16, 3])], ["add"],
 enable_mlir_converter=False,
-inference_type=lite_constants.FLOAT)
+inference_type=dtypes.float32)
 self.assertTrue(tflite_model)

 # Check values from converted model.

@@ -111,7 +110,7 @@ class ConvertTest(test_util.TensorFlowTestCase):
 input_arrays_map,
 output_arrays,
 enable_mlir_converter=False,
-inference_type=lite_constants.QUANTIZED_UINT8,
+inference_type=dtypes.uint8,
 quantized_input_stats=[(0., 1.), (0., 1.)])
 self.assertTrue(tflite_model)


@@ -158,7 +157,7 @@ class ConvertTest(test_util.TensorFlowTestCase):
 input_arrays_map,
 output_arrays,
 enable_mlir_converter=False,
-inference_type=lite_constants.QUANTIZED_UINT8)
+inference_type=dtypes.uint8)
 self.assertEqual(
 "std_dev and mean must be defined when inference_type or "
 "inference_input_type is QUANTIZED_UINT8 or INT8.",

@@ -163,9 +163,8 @@ class TargetSpec(object):
 supported_ops: Experimental flag, subject to change. Set of OpsSet options
 supported by the device. (default set([OpsSet.TFLITE_BUILTINS]))
 supported_types: List of types for constant values on the target device.
-Supported values are types exported by lite.constants. Frequently, an
-optimization choice is driven by the most compact (i.e. smallest) type in
-this list (default [constants.FLOAT])
+Frequently, an optimization choice is driven by the most compact
+(i.e. smallest) type in this list (default [tf.float32])
 """

 def __init__(self, supported_ops=None, supported_types=None):
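Following the docstring change above, `supported_types` is now populated with plain TensorFlow dtypes rather than `lite.constants` values. A hedged usage sketch (converter construction elided), matching the test updates further down in this diff:

# Sketch: restricting target types with tf dtypes on TargetSpec.
# `converter` is assumed to be an existing TFLiteConverter instance.
import tensorflow as tf

def restrict_to_float16(converter, include_int8=False):
    converter.target_spec.supported_types = [tf.float16]  # was lite.constants.FLOAT16
    if include_int8:
        converter.target_spec.supported_types.append(tf.int8)
    converter.optimizations = [tf.lite.Optimize.DEFAULT]
    return converter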
@@ -200,7 +199,7 @@ class QuantizationMode(object):
 return (self._any_optimization_enabled() and
 not self._is_int16x8_target_required() and
 self._representative_dataset is not None and
-self._smallest_supported_type() == constants.INT8)
+self._smallest_supported_type() == _dtypes.int8)

 def is_post_training_integer_quantize_8(self):
 """Post training integer 8 quantization."""

@@ -241,12 +240,12 @@ class QuantizationMode(object):
 return (self._any_optimization_enabled() and
 self._representative_dataset is None and
 not self.contains_training_quant_op() and
-self._smallest_supported_type() == constants.INT8)
+self._smallest_supported_type() == _dtypes.int8)

 def post_training_fp16(self):
 """Post training fp16 quantize."""
 return (self._any_optimization_enabled() and
-self._smallest_supported_type() == constants.FLOAT16)
+self._smallest_supported_type() == _dtypes.float16)

 def fp32_execution(self):
 """If none of the above are true."""

@@ -259,36 +258,36 @@ class QuantizationMode(object):
 self.post_training_fp16())

 def activations_type(self):
-return constants.INT16 if self._is_int16x8_target_required() \
-else constants.INT8
+return _dtypes.int16 if self._is_int16x8_target_required() \
+else _dtypes.int8

 def converter_flags(self, inference_ty=None, inference_input_ty=None):
 """Flags to the converter."""
 if self.is_post_training_integer_quantize():
 # The inference_input_type is for the quantizer, then we need to keep the
 # converter inference_input_type to float.
-inference_input_ty = constants.FLOAT
+inference_input_ty = _dtypes.float32

 if self.training_time_int8_allow_float():
 return {
 "inference_type": inference_ty if inference_ty else \
 self.activations_type(),
 "inference_input_type":
-inference_input_ty if inference_input_ty else constants.FLOAT,
+inference_input_ty if inference_input_ty else _dtypes.float32,
 "post_training_quantize": False,  # disable dynamic range quantization
 "quantize_to_float16": False  # disable float16 quantization
 }
 elif self.post_training_dynamic_range_int8():
 return {
-"inference_type": constants.FLOAT,
-"inference_input_type": constants.FLOAT,
+"inference_type": _dtypes.float32,
+"inference_input_type": _dtypes.float32,
 "post_training_quantize": True,  # enable dynamic range quantization
 "quantize_to_float16": False  # disable float16 quantization
 }
 elif self.post_training_fp16():
 return {
-"inference_type": constants.FLOAT,
-"inference_input_type": constants.FLOAT,
+"inference_type": _dtypes.float32,
+"inference_input_type": _dtypes.float32,
 "post_training_quantize": True,
 "quantize_to_float16": True  # enable float16 quantization
 }
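Condensing the branches above, each quantization mode maps to a fixed converter flag set. The sketch below restates the returned dictionaries with plain labels for reference (editorial, not part of the change; "activations" stands for int8, or int16 on 16x8 targets, and "float32" for _dtypes.float32):

# Sketch: converter flag sets returned by QuantizationMode.converter_flags(),
# condensed from the branches above.
CONVERTER_FLAGS_BY_MODE = {
    "training_time_int8_allow_float": {
        "inference_type": "activations", "inference_input_type": "float32",
        "post_training_quantize": False, "quantize_to_float16": False},
    "post_training_dynamic_range_int8": {
        "inference_type": "float32", "inference_input_type": "float32",
        "post_training_quantize": True, "quantize_to_float16": False},
    "post_training_fp16": {
        "inference_type": "float32", "inference_input_type": "float32",
        "post_training_quantize": True, "quantize_to_float16": True},
}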
@@ -296,7 +295,7 @@ class QuantizationMode(object):
 # Note this might still trigger (uint8) quantization to be compatible with
 # TOCO.
 return {
-"inference_type": inference_ty if inference_ty else constants.FLOAT,
+"inference_type": inference_ty if inference_ty else _dtypes.float32,
 "inference_input_type": inference_input_ty,
 "post_training_quantize": False,  # enable dynamic range quantization
 "quantize_to_float16": False  # disable float16 quantization

@@ -305,8 +304,8 @@ class QuantizationMode(object):
 def quantizer_flags(self, input_ty=None, output_ty=None):
 """Default flags to the TFMOT quantizer."""

-inference_input_type = input_ty if input_ty else constants.FLOAT
-inference_output_type = output_ty if output_ty else constants.FLOAT
+inference_input_type = input_ty if input_ty else _dtypes.float32
+inference_output_type = output_ty if output_ty else _dtypes.float32

 if self.post_training_int8_no_float() \
 or self.post_training_int16x8_no_float():

@@ -327,9 +326,8 @@ class QuantizationMode(object):
 else:
 return False, None

-def flags_modify_model_io_type(self,
-input_type=constants.FLOAT,
-output_type=constants.FLOAT):
+def flags_modify_model_io_type(
+self, input_type=_dtypes.float32, output_type=_dtypes.float32):
 """Flags for modifying the input and output type of a tflite model."""
 is_post_training_quantize = self.quantizer_flags(input_type, output_type)[0]
 is_training_time_only_quantize = self.training_time_int8_allow_float() and \

@@ -353,7 +351,7 @@ class QuantizationMode(object):
 return

 if self._target_spec.supported_types and (self._smallest_supported_type() !=
-constants.INT8):
+_dtypes.int8):
 raise ValueError("TFLITE_BUILTINS_INT8 requires smallest supported "
 "type to be INT8.")


@@ -372,7 +370,7 @@ class QuantizationMode(object):
 def _is_int8_target_required(self):
 return (set([OpsSet.TFLITE_BUILTINS_INT8]) == set(
 self._target_spec.supported_ops) or
-set(self._target_spec.supported_types) == set([constants.INT8]))
+set(self._target_spec.supported_types) == set([_dtypes.int8]))

 def _is_int16x8_target_required(self):
 return bool(

@@ -397,7 +395,7 @@ class QuantizationMode(object):
 return min(self._target_spec.supported_types, key=lambda x: x.size)
 else:
 # The default smallest supported type is INT8.
-return constants.INT8
+return _dtypes.int8

 def contains_training_quant_op(self):
 """Checks if the graph contains any training-time quantization ops."""

@@ -556,18 +554,18 @@ class TFLiteConverterBaseV2(TFLiteConverterBase):
 def __init__(self):
 """Constructor for TFLiteConverter."""
 super(TFLiteConverterBaseV2, self).__init__()
-self.inference_input_type = constants.FLOAT
-self.inference_output_type = constants.FLOAT
+self.inference_input_type = _dtypes.float32
+self.inference_output_type = _dtypes.float32

 def _validate_inference_input_output_types(self, quant_mode):
 """Validate inference_input_type and inference_output_type flags."""
-default_types = [constants.FLOAT]
+default_types = [_dtypes.float32]
 # We support integer input/output for integer quantized models only.
 if quant_mode.training_time_int8_allow_float():
 if quant_mode.is_post_training_integer_quantize_16x8():
-all_types = default_types + [constants.INT16]
+all_types = default_types + [_dtypes.int16]
 else:
-all_types = default_types + [constants.INT8, constants.QUANTIZED_UINT8]
+all_types = default_types + [_dtypes.int8, _dtypes.uint8]
 if self.inference_input_type not in all_types or \
 self.inference_output_type not in all_types:
 all_types_names = ["tf." + t.name for t in all_types]
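The validation above accepts only a small set of inference input/output dtypes per quantization mode; restated as data for reference (an editorial sketch, not code from the change):

# Sketch: dtypes accepted by the validation above for quantized models.
from tensorflow.python.framework import dtypes

ALLOWED_IO_TYPES_16X8 = [dtypes.float32, dtypes.int16]          # 16x8 quantization
ALLOWED_IO_TYPES_INT8 = [dtypes.float32, dtypes.int8, dtypes.uint8]  # int8 quantization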
@@ -1148,7 +1146,7 @@ class TFLiteConverterBaseV1(TFLiteConverterBase):
 graph debug info for a set of nodes from the `graph_def`.
 """
 super(TFLiteConverterBaseV1, self).__init__()
-self.inference_type = constants.FLOAT
+self.inference_type = _dtypes.float32
 self.inference_input_type = None
 self.inference_output_type = None
 self.output_format = constants.TFLITE

@@ -1195,7 +1193,7 @@ class TFLiteConverterBaseV1(TFLiteConverterBase):
 def _validate_quantized_input_stats(self, converter_kwargs, calibrate):
 """Ensure the `quantized_input_stats` flag is provided if required."""

-quantized_types = frozenset({constants.INT8, constants.QUANTIZED_UINT8})
+quantized_types = frozenset({_dtypes.int8, _dtypes.uint8})

 requires_quantized_input_stats = (
 (converter_kwargs["inference_type"] in quantized_types or

@@ -1645,8 +1643,8 @@ class TFLiteConverter(TFLiteFrozenGraphConverter):
 quantized_input_stats: Dict of strings representing input tensor names
 mapped to tuple of floats representing the mean and standard deviation
 of the training data (e.g., {"foo" : (0., 1.)}). Only need if
 `inference_input_type` is `QUANTIZED_UINT8`. real_input_value =
 (quantized_input_value - mean_value) / std_dev_value. (default {})
 default_ranges_stats: Tuple of integers representing (min, max) range values
 for all arrays without a specified range. Intended for experimenting with
 quantization via "dummy quantization". (default None)

@@ -155,8 +155,8 @@ class FromSessionTest(TestModels, parameterized.TestCase):
 # Convert model and ensure model is not None.
 converter = lite.TFLiteConverter.from_session(sess, [in_tensor],
 [out_tensor])
-converter.inference_input_type = lite_constants.QUANTIZED_UINT8
-converter.inference_type = lite_constants.FLOAT
+converter.inference_input_type = dtypes.uint8
+converter.inference_type = dtypes.float32
 converter.quantized_input_stats = {'Placeholder': (0., 1.)}  # mean, std_dev
 tflite_model = converter.convert()
 self.assertIsNotNone(tflite_model)

@@ -788,8 +788,8 @@ class FromSessionTest(TestModels, parameterized.TestCase):
 quantized_converter = lite.TFLiteConverter.from_session(
 sess, [inp], [output])
 quantized_converter.experimental_new_converter = enable_mlir_converter
-quantized_converter.inference_input_type = lite_constants.INT8
-quantized_converter.inference_output_type = lite_constants.INT8
+quantized_converter.inference_input_type = dtypes.int8
+quantized_converter.inference_output_type = dtypes.int8
 quantized_converter.optimizations = [lite.Optimize.DEFAULT]
 quantized_converter.representative_dataset = calibration_gen
 quantized_tflite_model = quantized_converter.convert()

@@ -832,7 +832,7 @@ class FromSessionTest(TestModels, parameterized.TestCase):
 quantized_converter.experimental_new_converter = enable_mlir_converter
 quantized_converter.optimizations = [lite.Optimize.DEFAULT]
 # Restricting to int8 type only
-quantized_converter.target_spec.supported_types = [lite.constants.INT8]
+quantized_converter.target_spec.supported_types = [dtypes.int8]
 # A representative dataset is required for full fixed point quantization.
 with self.assertRaises(ValueError) as error:
 quantized_converter.convert()

@@ -857,7 +857,7 @@ class FromSessionTest(TestModels, parameterized.TestCase):
 converter = lite.TFLiteConverter.from_session(sess,
 [in_tensor_1, in_tensor_2],
 [out_tensor])
-converter.inference_type = lite_constants.QUANTIZED_UINT8
+converter.inference_type = dtypes.uint8
 converter.quantized_input_stats = {
 'inputA': (0., 1.),
 'inputB': (0., 1.)

@@ -898,7 +898,7 @@ class FromSessionTest(TestModels, parameterized.TestCase):
 # Convert model and ensure model is not None.
 converter = lite.TFLiteConverter.from_session(sess, [in_tensor],
 [out_tensor])
-converter.inference_type = lite_constants.QUANTIZED_UINT8
+converter.inference_type = dtypes.uint8
 converter.quantized_input_stats = {'Placeholder': (0., 1.)}  # mean, std_dev
 converter.default_ranges_stats = (0, 6)  # min, max
 tflite_model = converter.convert()

@@ -954,16 +954,15 @@ class FromSessionTest(TestModels, parameterized.TestCase):
 interpreter.allocate_tensors()
 self.assertEqual(interpreter.get_tensor_details()[idx]['name'], node_name)
 self.assertEqual(interpreter.get_tensor_details()[idx]['dtype'],
-lite.constants.FLOAT)
+dtypes.float32)
 # Convert model to quantized version
 quantized_converter = lite.TFLiteConverter.from_session(
 sess, [inp], [output])
 quantized_converter.experimental_new_converter = enable_mlir_converter
 quantized_converter.optimizations = [lite.Optimize.DEFAULT]
-quantized_converter.target_spec.supported_types = [lite.constants.FLOAT16]
+quantized_converter.target_spec.supported_types = [dtypes.float16]
 if include_int8:
-quantized_converter.target_spec.supported_types.append(
-lite.constants.INT8)
+quantized_converter.target_spec.supported_types.append(dtypes.int8)
 if use_rep_data:
 quantized_converter.representative_dataset = calibration_gen


@@ -984,11 +983,11 @@ class FromSessionTest(TestModels, parameterized.TestCase):
 if is_float16_quantized:
 # Verify that bias constant is float16 type.
 self.assertEqual(interpreter.get_tensor_details()[idx]['dtype'],
-lite.constants.FLOAT16)
+dtypes.float16)
 elif is_post_training_quantized:
 # Verify that bias constants is int32 type.
 self.assertEqual(interpreter.get_tensor_details()[idx]['dtype'],
-lite.constants.INT32)
+dtypes.int32)
 else:
 raise ValueError('Invalid test options.')


@@ -1005,7 +1004,7 @@ class FromSessionTest(TestModels, parameterized.TestCase):
 sess, [inp], [output])
 quantized_converter.experimental_new_converter = enable_mlir_converter
 quantized_converter.optimizations = [lite.Optimize.DEFAULT]
-quantized_converter.target_spec.supported_types = [lite.constants.FLOAT16]
+quantized_converter.target_spec.supported_types = [dtypes.float16]
 # Specify only int8 builtin ops
 quantized_converter.target_spec.supported_ops = [
 lite.OpsSet.TFLITE_BUILTINS_INT8

@@ -1017,8 +1016,8 @@ class FromSessionTest(TestModels, parameterized.TestCase):
 str(error.exception))

 @parameterized.named_parameters(
-('InferenceType_INT8', lite_constants.INT8),
-('InferenceType_UINT8', lite_constants.QUANTIZED_UINT8))
+('InferenceType_INT8', dtypes.int8),
+('InferenceType_UINT8', dtypes.uint8))
 def testInvalidQuantizeQATModelRequiresInputStats(self, quantized_type):
 with ops.Graph().as_default():
 in_tensor = array_ops.placeholder(

@@ -1039,7 +1038,7 @@ class FromSessionTest(TestModels, parameterized.TestCase):
 'flag is set to tf.uint8 or tf.int8.', str(error.exception))

 with self.assertRaises(ValueError) as error:
-quantized_converter.inference_type = lite_constants.FLOAT
+quantized_converter.inference_type = dtypes.float32
 quantized_converter.inference_input_type = quantized_type
 quantized_converter.convert()
 self.assertEqual(

@@ -1070,7 +1069,7 @@ class FromSessionTest(TestModels, parameterized.TestCase):
 converter = lite.TFLiteConverter.from_session(sess,
 [in_tensor_1, in_tensor_2],
 [out_tensor])
-converter.inference_type = lite_constants.QUANTIZED_UINT8
+converter.inference_type = dtypes.uint8
 converter.quantized_input_stats = {'inputA': (0., 1.)}  # mean, std_dev
 with self.assertRaises(ValueError) as error:
 converter.convert()

@@ -1091,9 +1090,9 @@ class FromSessionTest(TestModels, parameterized.TestCase):
 converter = lite.TFLiteConverter.from_session(sess, [inp], [output])

 # extra flags to trigger training time quantization conversion
-converter.inference_type = lite_constants.INT8
-converter.inference_input_type = lite_constants.FLOAT
-converter.inference_output_type = lite_constants.FLOAT
+converter.inference_type = dtypes.int8
+converter.inference_input_type = dtypes.float32
+converter.inference_output_type = dtypes.float32
 input_arrays = converter.get_input_arrays()
 converter.quantized_input_stats = {
 input_arrays[0]: (0., 1.)

@@ -1255,7 +1254,7 @@ class FromSessionTest(TestModels, parameterized.TestCase):
 # Convert model and ensure model is not None.
 converter = lite.TFLiteConverter.from_session(sess, [in_tensor],
 [out_tensor])
-converter.inference_type = lite_constants.QUANTIZED_UINT8
+converter.inference_type = dtypes.uint8
 converter.quantized_input_stats = {'Placeholder': (0., 1.)}  # mean, std_dev
 tflite_model = converter.convert()
 self.assertIsNotNone(tflite_model)

@@ -2334,7 +2333,7 @@ class DefaultConverterAttrsTest(LiteTest):
 self.assertEqual(converter.output_format, lite_constants.TFLITE)

 # Assert the default inference type is float.
-self.assertEqual(converter.inference_type, lite_constants.FLOAT)
+self.assertEqual(converter.inference_type, dtypes.float32)

 # Assert the default inference type overrides are None.
 self.assertIsNone(converter.inference_input_type)

@@ -32,6 +32,7 @@ from tensorflow.lite.python import lite_v2_test_util
 from tensorflow.lite.python.convert import mlir_quantize
 from tensorflow.lite.python.interpreter import Interpreter
 from tensorflow.lite.toco import types_pb2 as _types_pb2
+from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import ops
 from tensorflow.python.framework import test_util
 from tensorflow.python.keras.layers import recurrent

@@ -74,9 +75,9 @@ class FromConcreteFunctionTest(lite_v2_test_util.ModelTest):
 self.assertEqual(expected_value.numpy(), actual_value)

 @parameterized.named_parameters(
-('_INT8InputOutput', lite.constants.INT8),
-('_UINT8InputOutput', lite.constants.QUANTIZED_UINT8),
-('_INT16InputOutput', lite.constants.INT16))
+('_INT8InputOutput', dtypes.int8),
+('_UINT8InputOutput', dtypes.uint8),
+('_INT16InputOutput', dtypes.int16))
 @test_util.run_v2_only
 def testInvalidFloat(self, inference_input_output_type):
 root = self._getSimpleVariableModel()

@@ -194,9 +195,9 @@ class FromConcreteFunctionTest(lite_v2_test_util.ModelTest):
 self.assertLess(len(quantized_tflite_model), len(float_tflite_model))

 @parameterized.named_parameters(
-('_INT8InputOutput', lite.constants.INT8),
-('_UINT8InputOutput', lite.constants.QUANTIZED_UINT8),
-('_INT16InputOutput', lite.constants.INT16))
+('_INT8InputOutput', dtypes.int8),
+('_UINT8InputOutput', dtypes.uint8),
+('_INT16InputOutput', dtypes.int16))
 @test_util.run_v2_only
 def testInvalidPostTrainingDynamicRangeQuantization(
 self, inference_input_output_type):

@@ -219,18 +220,18 @@ class FromConcreteFunctionTest(lite_v2_test_util.ModelTest):
 'must be tf.float32.', str(error.exception))

 @parameterized.named_parameters(
-('_Default', False, False, lite.constants.FLOAT),
-('_INT8InputOutput', False, False, lite.constants.INT8),
-('_UINT8InputOutput', False, False, lite.constants.QUANTIZED_UINT8),
-('_INT16Quantize', False, True, lite.constants.FLOAT),
-('_INT16Quantize_INT16InputOutput', False, True, lite.constants.INT16),
-('_IntOnly', True, False, lite.constants.FLOAT),
-('_IntOnly_INT8InputOutput', True, False, lite.constants.INT8),
+('_Default', False, False, dtypes.float32),
+('_INT8InputOutput', False, False, dtypes.int8),
+('_UINT8InputOutput', False, False, dtypes.uint8),
+('_INT16Quantize', False, True, dtypes.float32),
+('_INT16Quantize_INT16InputOutput', False, True, dtypes.int16),
+('_IntOnly', True, False, dtypes.float32),
+('_IntOnly_INT8InputOutput', True, False, dtypes.int8),
 ('_IntOnly_UINT8InputOutput', True, False,
-lite.constants.QUANTIZED_UINT8),
-('_IntOnly_INT16Quantize', True, True, lite.constants.FLOAT),
+dtypes.uint8),
+('_IntOnly_INT16Quantize', True, True, dtypes.float32),
 ('_IntOnly_INT16Quantize_INT16InputOutput', True, True,
-lite.constants.INT16))
+dtypes.int16))
 def testIntegerQuantization(self, is_int_only, is_int16_quantize,
 inference_input_output_type):
 func, calibration_gen = self._getIntegerQuantizeModel()

@@ -281,7 +282,7 @@ class FromConcreteFunctionTest(lite_v2_test_util.ModelTest):
 self.assertLess(len(quantized_tflite_model), len(tflite_model))

 @parameterized.named_parameters(
-('_INT16Quantize_INT8InputOutput', True, lite.constants.INT8))
+('_INT16Quantize_INT8InputOutput', True, dtypes.int8))
 def testInvalidIntegerQuantization(self, is_int16_quantize,
 inference_input_output_type):
 func, calibration_gen = self._getIntegerQuantizeModel()

@@ -297,8 +298,8 @@ class FromConcreteFunctionTest(lite_v2_test_util.ModelTest):
 lite.OpsSet.TFLITE_BUILTINS
 ]
 with self.assertRaises(ValueError) as error:
-quantized_converter.inference_input_type = lite.constants.INT8
-quantized_converter.inference_output_type = lite.constants.INT8
+quantized_converter.inference_input_type = dtypes.int8
+quantized_converter.inference_output_type = dtypes.int8
 quantized_converter.convert()
 self.assertEqual(
 "The inference_input_type and inference_output_type "

@@ -377,9 +378,9 @@ class FromConcreteFunctionTest(lite_v2_test_util.ModelTest):
 return tf.keras.Sequential(QLinear(3, input_shape=(2,)))

 @parameterized.named_parameters(
-('_DefaultFLOAT32InputOutput', lite.constants.FLOAT),
-('_INT8InputOutput', lite.constants.INT8),
-('_UINT8InputOutput', lite.constants.QUANTIZED_UINT8))
+('_DefaultFLOAT32InputOutput', dtypes.float32),
+('_INT8InputOutput', dtypes.int8),
+('_UINT8InputOutput', dtypes.uint8))
 @test_util.run_v2_only
 def testTrainingTimeQuantization(self, inference_input_output_type):
 model = self._getTrainingTimeQuantizedModel()

@@ -50,6 +50,7 @@ py_library(
 visibility = ["//visibility:public"],
 deps = [
 ":_pywrap_tensorflow_lite_calibration_wrapper", # buildcleaner: keep
+"//tensorflow/python:dtypes",
 "//tensorflow/python:util",
 "//third_party/py/numpy",
 ],

@@ -67,8 +68,8 @@ py_test(
 tags = ["no_oss"],
 deps = [
 ":calibrator",
-"//tensorflow/lite/python:lite_constants",
 "//tensorflow/python:client_testlib",
+"//tensorflow/python:dtypes",
 "//tensorflow/python:framework_test_lib",
 "//tensorflow/python:platform",
 "//third_party/py/numpy",

@@ -18,8 +18,9 @@ from __future__ import division
 from __future__ import print_function

 import numpy as np

+from tensorflow.python.framework import dtypes
 from tensorflow.python.util.lazy_loader import LazyLoader
-from tensorflow.lite.python import lite_constants

 # Lazy load since some of the performance benchmark skylark rules
 # break dependencies. Must use double quotes to match code internal rewrite

@@ -60,7 +61,7 @@ class Calibrator(object):
 input_type,
 output_type,
 allow_float,
-activations_type=lite_constants.INT8,
+activations_type=dtypes.int8,
 resize_input=True):
 """Calibrates the model with specified generator and then quantizes it.

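A hedged usage sketch of the updated calibration entry point with TensorFlow dtypes; the model path and input generator below are placeholders chosen to match the test updates that follow:

# Sketch: Calibrator.calibrate_and_quantize with tf dtypes (placeholder model).
import numpy as np
from tensorflow.lite.python.optimize import calibrator as _calibrator
from tensorflow.python.framework import dtypes

def quantize(model_path="model.tflite"):
    with open(model_path, "rb") as f:
        quantizer = _calibrator.Calibrator(f.read())

    def input_gen():
        for _ in range(10):
            yield [np.ones(shape=(1, 5, 5, 3), dtype=np.float32)]

    return quantizer.calibrate_and_quantize(
        input_gen,
        dtypes.float32,  # input_type
        dtypes.float32,  # output_type
        False,           # allow_float
        dtypes.int8)     # activations_type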
@ -23,8 +23,8 @@ from absl.testing import parameterized
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
from six.moves import range
|
from six.moves import range
|
||||||
|
|
||||||
from tensorflow.lite.python import lite_constants as constants
|
|
||||||
from tensorflow.lite.python.optimize import calibrator as _calibrator
|
from tensorflow.lite.python.optimize import calibrator as _calibrator
|
||||||
|
from tensorflow.python.framework import dtypes
|
||||||
from tensorflow.python.framework import test_util
|
from tensorflow.python.framework import test_util
|
||||||
from tensorflow.python.platform import resource_loader
|
from tensorflow.python.platform import resource_loader
|
||||||
from tensorflow.python.platform import test
|
from tensorflow.python.platform import test
|
||||||
@ -34,9 +34,9 @@ class CalibratorTest(test_util.TensorFlowTestCase, parameterized.TestCase):
|
|||||||
|
|
||||||
@parameterized.named_parameters(
|
@parameterized.named_parameters(
|
||||||
# Activation type Int8
|
# Activation type Int8
|
||||||
('UseActivationTypeInt8', constants.INT8),
|
('UseActivationTypeInt8', dtypes.int8),
|
||||||
# Activation type Int16
|
# Activation type Int16
|
||||||
('UseActivationTypeInt16', constants.INT16))
|
('UseActivationTypeInt16', dtypes.int16))
|
||||||
def test_calibration_with_quantization(self, activations_type):
|
def test_calibration_with_quantization(self, activations_type):
|
||||||
model_path = resource_loader.get_path_to_datafile(
|
model_path = resource_loader.get_path_to_datafile(
|
||||||
'test_data/mobilenet_like_model.bin')
|
'test_data/mobilenet_like_model.bin')
|
||||||
@ -49,16 +49,17 @@ class CalibratorTest(test_util.TensorFlowTestCase, parameterized.TestCase):
|
|||||||
yield [np.ones(shape=(1, 5, 5, 3), dtype=np.float32)]
|
yield [np.ones(shape=(1, 5, 5, 3), dtype=np.float32)]
|
||||||
|
|
||||||
quantized_model = quantizer.calibrate_and_quantize(input_gen,
|
quantized_model = quantizer.calibrate_and_quantize(input_gen,
|
||||||
constants.FLOAT,
|
dtypes.float32,
|
||||||
constants.FLOAT, False,
|
dtypes.float32,
|
||||||
|
False,
|
||||||
activations_type)
|
activations_type)
|
||||||
self.assertIsNotNone(quantized_model)
|
self.assertIsNotNone(quantized_model)
|
||||||
|
|
||||||
@parameterized.named_parameters(
|
@parameterized.named_parameters(
|
||||||
# Activation type Int8
|
# Activation type Int8
|
||||||
('UseActivationTypeInt8', constants.INT8),
|
('UseActivationTypeInt8', dtypes.int8),
|
||||||
# Activation type Int16
|
# Activation type Int16
|
||||||
('UseActivationTypeInt16', constants.INT16))
|
('UseActivationTypeInt16', dtypes.int16))
|
||||||
def test_calibration_with_quantization_allow_float(self, activations_type):
|
def test_calibration_with_quantization_allow_float(self, activations_type):
|
||||||
model_path = resource_loader.get_path_to_datafile(
|
model_path = resource_loader.get_path_to_datafile(
|
||||||
'test_data/mobilenet_like_model.bin')
|
'test_data/mobilenet_like_model.bin')
|
||||||
@ -71,8 +72,9 @@ class CalibratorTest(test_util.TensorFlowTestCase, parameterized.TestCase):
|
|||||||
yield [np.ones(shape=(1, 5, 5, 3), dtype=np.float32)]
|
yield [np.ones(shape=(1, 5, 5, 3), dtype=np.float32)]
|
||||||
|
|
||||||
quantized_model = quantizer.calibrate_and_quantize(input_gen,
|
quantized_model = quantizer.calibrate_and_quantize(input_gen,
|
||||||
constants.FLOAT,
|
dtypes.float32,
|
||||||
constants.FLOAT, True,
|
dtypes.float32,
|
||||||
|
True,
|
||||||
activations_type)
|
activations_type)
|
||||||
self.assertIsNotNone(quantized_model)
|
self.assertIsNotNone(quantized_model)
|
||||||
|
|
||||||
@ -88,7 +90,7 @@ class CalibratorTest(test_util.TensorFlowTestCase, parameterized.TestCase):
|
|||||||
yield [np.ones(shape=(1, 5, 5, 3), dtype=np.float32)]
|
yield [np.ones(shape=(1, 5, 5, 3), dtype=np.float32)]
|
||||||
|
|
||||||
quantized_model = quantizer.calibrate_and_quantize_single(
|
quantized_model = quantizer.calibrate_and_quantize_single(
|
||||||
input_gen, constants.FLOAT, constants.FLOAT, True, 'conv2d_8/BiasAdd')
|
input_gen, dtypes.float32, dtypes.float32, True, 'conv2d_8/BiasAdd')
|
||||||
self.assertIsNotNone(quantized_model)
|
self.assertIsNotNone(quantized_model)
|
||||||
|
|
||||||
def test_calibration_with_string_input(self):
|
def test_calibration_with_string_input(self):
|
||||||
@ -103,14 +105,14 @@ class CalibratorTest(test_util.TensorFlowTestCase, parameterized.TestCase):
|
|||||||
       yield [np.array(u'Test' + str(i))]

     quantized_model = quantizer.calibrate_and_quantize_single(
-        input_gen, constants.FLOAT, constants.FLOAT, True, 'Identity')
+        input_gen, dtypes.float32, dtypes.float32, True, 'Identity')
     self.assertIsNotNone(quantized_model)

   @parameterized.named_parameters(
       # Activation type Int8
-      ('UseActivationTypeInt8 - EnableMlirQuantizer', constants.INT8),
+      ('UseActivationTypeInt8 - EnableMlirQuantizer', dtypes.int8),
       # Activation type Int16
-      ('UseActivationTypeInt16 - DisableEnableMlirQuantizer', constants.INT16))
+      ('UseActivationTypeInt16 - DisableEnableMlirQuantizer', dtypes.int16))
   def test_calibration_with_quantization_multiple_inputs(
       self, activations_type):
     # Load multi add model from test data.
@@ -126,8 +128,9 @@ class CalibratorTest(test_util.TensorFlowTestCase, parameterized.TestCase):
       yield [np.ones(shape=(1, 8, 8, 3), dtype=np.float32) for _ in range(4)]

     quantized_model = quantizer.calibrate_and_quantize(input_gen,
-                                                       constants.FLOAT,
-                                                       constants.FLOAT, False,
+                                                       dtypes.float32,
+                                                       dtypes.float32,
+                                                       False,
                                                        activations_type)
     self.assertIsNotNone(quantized_model)

@@ -148,8 +151,8 @@ class CalibratorTest(test_util.TensorFlowTestCase, parameterized.TestCase):
       yield i

     with self.assertRaises(RuntimeError):
-      quantizer.calibrate_and_quantize(empty_input_gen, constants.FLOAT,
-                                       constants.FLOAT, False)
+      quantizer.calibrate_and_quantize(empty_input_gen, dtypes.float32,
+                                       dtypes.float32, False)

   def test_invalid_shape_calibrator_gen(self):
     model_path = resource_loader.get_path_to_datafile(
@@ -163,8 +166,8 @@ class CalibratorTest(test_util.TensorFlowTestCase, parameterized.TestCase):
       yield [np.ones(shape=(1, 2, 2, 3), dtype=np.float32)]

     with self.assertRaisesRegex(ValueError, 'Size mismatch'):
-      quantizer.calibrate_and_quantize(input_gen, constants.FLOAT,
-                                       constants.FLOAT, False, constants.INT8,
+      quantizer.calibrate_and_quantize(input_gen, dtypes.float32,
+                                       dtypes.float32, False, dtypes.int8,
                                        False)

   def test_invalid_type_calibrator_gen(self):
@@ -179,8 +182,8 @@ class CalibratorTest(test_util.TensorFlowTestCase, parameterized.TestCase):
       yield [np.ones(shape=(1, 5, 5, 3), dtype=np.int32)]

     with self.assertRaises(ValueError):
-      quantizer.calibrate_and_quantize(input_gen, constants.FLOAT,
-                                       constants.FLOAT, False, constants.INT8)
+      quantizer.calibrate_and_quantize(input_gen, dtypes.float32,
+                                       dtypes.float32, False, dtypes.int8)

   def test_calibration(self):
     model_path = resource_loader.get_path_to_datafile(
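Outside the test harness, the same calibration call now takes standard TensorFlow dtypes instead of `lite_constants`. Below is a minimal sketch mirroring the calls exercised above; the model path, the representative dataset, and constructing `Calibrator` from raw flatbuffer bytes are assumptions for illustration, not part of this change.

```python
import numpy as np

from tensorflow.lite.python.optimize import calibrator as _calibrator
from tensorflow.python.framework import dtypes

# Assumed: a float TFLite flatbuffer on disk that we want to calibrate.
with open('/tmp/model_float.tflite', 'rb') as f:
  float_model = f.read()

quantizer = _calibrator.Calibrator(float_model)

def input_gen():
  # Representative dataset; shapes must match the model's input tensors.
  for _ in range(10):
    yield [np.ones(shape=(1, 8, 8, 3), dtype=np.float32)]

# dtypes.float32 replaces the old constants.FLOAT for the input/output types;
# the last positional argument selects the activation type (here int8).
quantized_model = quantizer.calibrate_and_quantize(
    input_gen, dtypes.float32, dtypes.float32, False, dtypes.int8)
```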
@@ -28,12 +28,12 @@ import six
 from six.moves import zip

 from tensorflow.lite.python import lite
-from tensorflow.lite.python import lite_constants
 from tensorflow.lite.python.convert import register_custom_opdefs
 from tensorflow.lite.toco import toco_flags_pb2 as _toco_flags_pb2
 from tensorflow.lite.toco.logging import gen_html
 from tensorflow.python import keras
 from tensorflow.python import tf2
+from tensorflow.python.framework import dtypes
 from tensorflow.python.platform import app


@@ -63,13 +63,14 @@ def _parse_inference_type(value, flag):
     ValueError: Unsupported value.
   """
   if value == "FLOAT":
-    return lite_constants.FLOAT
-  if value == "QUANTIZED_UINT8":
-    return lite_constants.QUANTIZED_UINT8
+    return dtypes.float32
   if value == "INT8":
-    return lite_constants.INT8
-  raise ValueError("Unsupported value for --{0}. Only FLOAT and "
-                   "QUANTIZED_UINT8 are supported.".format(flag))
+    return dtypes.int8
+  if value == "UINT8" or value == "QUANTIZED_UINT8":
+    return dtypes.uint8
+  raise ValueError(
+      "Unsupported value for `{}` flag. Expected FLOAT, INT8 or UINT8, instead "
+      "got {}.".format(flag, value))


 def _get_tflite_converter(flags):
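The rewritten `_parse_inference_type` above maps both the new and the legacy flag spellings onto TensorFlow dtypes. As a quick reference, the expected mapping looks like this (the dictionary below is illustrative only, not code from the change):

```python
from tensorflow.python.framework import dtypes

# Expected flag-string -> dtype mapping; "QUANTIZED_UINT8" is kept as a
# legacy alias for "UINT8". Any other value raises ValueError.
expected = {
    'FLOAT': dtypes.float32,
    'INT8': dtypes.int8,
    'UINT8': dtypes.uint8,
    'QUANTIZED_UINT8': dtypes.uint8,
}
```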
@@ -151,10 +152,10 @@ def _convert_tf1_model(flags):

     # In quantized inference, mean_value has to be integer so that the real
     # value 0.0 is exactly representable.
-    if converter.inference_type == lite_constants.QUANTIZED_UINT8:
-      mean_values = _parse_array(flags.mean_values, type_fn=int)
-    else:
+    if converter.inference_type == dtypes.float32:
       mean_values = _parse_array(flags.mean_values, type_fn=float)
+    else:
+      mean_values = _parse_array(flags.mean_values, type_fn=int)
     quant_stats = list(zip(mean_values, std_dev_values))
     if ((not flags.input_arrays and len(input_arrays) > 1) or
         (len(input_arrays) != len(quant_stats))):
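The branch above decides how `--mean_values` is parsed: float inference keeps fractional means (they are only used to dequantize inputs), while quantized inference requires integer means so that the real value 0.0 stays exactly representable. A standalone sketch of that decision, with a hypothetical helper name:

```python
from tensorflow.python.framework import dtypes

def parse_quant_stats(mean_values, std_dev_values, inference_type):
  """Hypothetical helper mirroring the parsing logic in _convert_tf1_model."""
  # Means stay float for float inference, become int for quantized inference.
  type_fn = float if inference_type == dtypes.float32 else int
  means = [type_fn(v) for v in mean_values.split(',')]
  std_devs = [float(v) for v in std_dev_values.split(',')]
  return list(zip(means, std_devs))

# Quantized (uint8) inference: means must be integers.
print(parse_quant_stats('128,128', '127.7,127.7', dtypes.uint8))
# -> [(128, 127.7), (128, 127.7)]
```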
@@ -193,13 +194,13 @@ def _convert_tf1_model(flags):

   if flags.post_training_quantize:
     converter.optimizations = [lite.Optimize.DEFAULT]
-    if converter.inference_type == lite_constants.QUANTIZED_UINT8:
+    if converter.inference_type != dtypes.float32:
       print("--post_training_quantize quantizes a graph of inference_type "
-            "FLOAT. Overriding inference type QUANTIZED_UINT8 to FLOAT.")
-      converter.inference_type = lite_constants.FLOAT
+            "FLOAT. Overriding inference_type to FLOAT.")
+      converter.inference_type = dtypes.float32

   if flags.quantize_to_float16:
-    converter.target_spec.supported_types = [lite.constants.FLOAT16]
+    converter.target_spec.supported_types = [dtypes.float16]
     if not flags.post_training_quantize:
       print("--quantize_to_float16 will only take effect with the "
             "--post_training_quantize flag enabled.")
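With this hunk, `--quantize_to_float16` plus `--post_training_quantize` feeds `dtypes.float16` straight into the converter instead of `lite.constants.FLOAT16`. A minimal end-to-end sketch using the public TF1 converter API (the frozen-graph path and array names are assumptions):

```python
import tensorflow as tf
from tensorflow.python.framework import dtypes

# Assumed: a frozen GraphDef with one input and one output tensor.
converter = tf.compat.v1.lite.TFLiteConverter.from_frozen_graph(
    '/tmp/frozen_graph.pb', input_arrays=['input'], output_arrays=['output'])

# Equivalent of --post_training_quantize --quantize_to_float16 on the CLI.
converter.optimizations = [tf.lite.Optimize.DEFAULT]
converter.target_spec.supported_types = [dtypes.float16]

tflite_model = converter.convert()
```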
@@ -358,14 +359,15 @@ def _get_tf1_flags(parser):
   parser.add_argument(
       "--inference_type",
       type=str.upper,
-      choices=["FLOAT", "QUANTIZED_UINT8", "INT8"],
-      help="Target data type of real-number arrays in the output file.")
+      default="FLOAT",
+      help=("Target data type of real-number arrays in the output file. "
+            "Must be either FLOAT, INT8 or UINT8."))
   parser.add_argument(
       "--inference_input_type",
       type=str.upper,
-      choices=["FLOAT", "QUANTIZED_UINT8", "INT8"],
       help=("Target data type of real-number input arrays. Allows for a "
-            "different type for input arrays in the case of quantization."))
+            "different type for input arrays in the case of quantization. "
+            "Must be either FLOAT, INT8 or UINT8."))

   # Input and output arrays flags.
   parser.add_argument(
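Dropping `choices=[...]` here means value validation now happens in `_parse_inference_type` rather than at argparse time, while `type=str.upper` still normalizes the flag. A small illustrative sketch of that pattern (standalone, not the converter's actual parser):

```python
import argparse

parser = argparse.ArgumentParser()
# No choices list: lowercase and legacy spellings pass through argparse and
# are validated later by the inference-type parser.
parser.add_argument('--inference_type', type=str.upper, default='FLOAT')

args = parser.parse_args(['--inference_type=uint8'])
print(args.inference_type)  # 'UINT8'
```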
@@ -168,6 +168,37 @@ class TfLiteConvertV1Test(TestModels):
     self._run(flags_str, should_succeed=True)
     os.remove(graph_def_file)

+  def testQATFrozenGraphDefUInt8(self):
+    with ops.Graph().as_default():
+      in_tensor_1 = array_ops.placeholder(
+          shape=[1, 16, 16, 3], dtype=dtypes.float32, name='inputA')
+      in_tensor_2 = array_ops.placeholder(
+          shape=[1, 16, 16, 3], dtype=dtypes.float32, name='inputB')
+      _ = array_ops.fake_quant_with_min_max_args(
+          in_tensor_1 + in_tensor_2, min=0., max=1., name='output')
+      sess = session.Session()
+
+    # Write graph to file.
+    graph_def_file = self._getFilepath('model.pb')
+    write_graph(sess.graph_def, '', graph_def_file, False)
+    sess.close()
+
+    # Define converter flags
+    flags_str = ('--std_dev_values=128,128 --mean_values=128,128 '
+                 '--graph_def_file={0} --input_arrays={1} '
+                 '--output_arrays={2}'.format(
+                     graph_def_file, 'inputA,inputB', 'output'))
+
+    # Set inference_type UINT8 and (default) inference_input_type UINT8
+    flags_str_1 = flags_str + ' --inference_type=UINT8'
+    self._run(flags_str_1, should_succeed=True)
+
+    # Set inference_type UINT8 and inference_input_type FLOAT
+    flags_str_2 = flags_str_1 + ' --inference_input_type=FLOAT'
+    self._run(flags_str_2, should_succeed=True)
+
+    os.remove(graph_def_file)
+
   def testSavedModel(self):
     saved_model_dir = self._getFilepath('model')
     with ops.Graph().as_default():
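The new test drives the UINT8 path through the `tflite_convert` CLI; the same conversion can be expressed with the converter API. A hedged sketch of the equivalent programmatic call (the graph file and tensor names follow the test above; the exact stats values are assumptions):

```python
import tensorflow as tf
from tensorflow.python.framework import dtypes

converter = tf.compat.v1.lite.TFLiteConverter.from_frozen_graph(
    '/tmp/model.pb',
    input_arrays=['inputA', 'inputB'],
    output_arrays=['output'])

# Mirrors --inference_type=UINT8 --mean_values=128,128 --std_dev_values=128,128.
converter.inference_type = dtypes.uint8
converter.quantized_input_stats = {
    'inputA': (128., 128.),
    'inputB': (128., 128.),
}

tflite_model = converter.convert()
```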
@@ -24,7 +24,6 @@ import numpy as np
 from six.moves import range
 import tensorflow as tf

-from tensorflow.lite.python import lite_constants
 from tensorflow.lite.python import util
 from tensorflow.lite.toco import types_pb2 as _types_pb2
 from tensorflow.python.client import session
@@ -292,9 +291,9 @@ def _test_param_modify_integer_model_io_type():
       # "DuringTraining": False,
   }
   map_types = {
-      "": lite_constants.FLOAT,
-      "INT8": lite_constants.INT8,
-      "UINT8": lite_constants.QUANTIZED_UINT8
+      "": dtypes.float32,
+      "INT8": dtypes.int8,
+      "UINT8": dtypes.uint8,
   }
   for k1, v1 in map_model_type.items():
     for k2, v2 in map_types.items():
@@ -28,11 +28,11 @@ from google.protobuf.message import DecodeError
 from tensorflow.core.framework import graph_pb2 as _graph_pb2
 from tensorflow.lite.python import convert_saved_model as _convert_saved_model
 from tensorflow.lite.python import lite as _lite
-from tensorflow.lite.python import lite_constants as constants
 from tensorflow.lite.python import util as _util
 from tensorflow.python import keras as _keras
 from tensorflow.python.client import session as _session
 from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import ops
 from tensorflow.python.framework.importer import import_graph_def as _import_graph_def
 from tensorflow.python.keras.preprocessing import image
@@ -97,7 +97,7 @@ def _convert(converter, **kwargs):
   if "post_training_quantize" in kwargs:
     converter.optimizations = [_lite.Optimize.DEFAULT]
     if kwargs.get("quantize_to_float16", False):
-      converter.target_spec.supported_types = [constants.FLOAT16]
+      converter.target_spec.supported_types = [dtypes.float16]
   if kwargs.get("post_training_quantize_16x8", False):
     input_size = kwargs.get("model_input_size")

@@ -52,7 +52,7 @@ py_library(
     name = "modify_model_interface_constants",
     srcs = ["modify_model_interface_constants.py"],
     srcs_version = "PY3",
-    deps = ["//tensorflow/lite/python:lite_constants"],
+    deps = ["//tensorflow/python:dtypes"],
 )

 pybind_extension(
@@ -19,12 +19,12 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function

-from tensorflow.lite.python import lite_constants
+from tensorflow.python.framework import dtypes

 STR_TO_TFLITE_TYPES = {
-    'INT8': lite_constants.INT8,
-    'INT16': lite_constants.INT16,
-    'UINT8': lite_constants.QUANTIZED_UINT8
+    'INT8': dtypes.int8,
+    'UINT8': dtypes.uint8,
+    'INT16': dtypes.int16,
 }
 TFLITE_TO_STR_TYPES = {v: k for k, v in STR_TO_TFLITE_TYPES.items()}

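With the maps above now built from `dtypes`, a quick round trip between the CLI strings and the TensorFlow dtypes looks like this (a sketch that simply re-declares the two dictionaries from the hunk):

```python
from tensorflow.python.framework import dtypes

STR_TO_TFLITE_TYPES = {
    'INT8': dtypes.int8,
    'UINT8': dtypes.uint8,
    'INT16': dtypes.int16,
}
TFLITE_TO_STR_TYPES = {v: k for k, v in STR_TO_TFLITE_TYPES.items()}

assert STR_TO_TFLITE_TYPES['UINT8'] == dtypes.uint8
assert TFLITE_TO_STR_TYPES[dtypes.int16] == 'INT16'
```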