Reduce redundancy by replacing TFLite data types with TF data types
PiperOrigin-RevId: 337004382
Change-Id: I48b5734781527138705bdda5e6ea80120e3a33a1

This commit is contained in:
parent e485bb64c5
commit 2396803aa1

@@ -104,8 +104,8 @@ tflite_convert \
   --std_dev_values=127.7
 ```

-*If you're setting `--inference_type=QUANTIZED_UINT8` then update
-`--mean_values=128` and `--std_dev_values=127`*
+*If you're setting `--inference_type=UINT8` then update `--mean_values=128` and
+`--std_dev_values=127`*

 #### Convert a model with "dummy-quantization" into a quantized TensorFlow Lite model

@@ -134,8 +134,8 @@ tflite_convert \
   --default_ranges_max=6
 ```

-*If you're setting `--inference_type=QUANTIZED_UINT8` then update
-`--mean_values=128` and `--std_dev_values=127`*
+*If you're setting `--inference_type=UINT8` then update `--mean_values=128` and
+`--std_dev_values=127`*

 #### Convert a model with select TensorFlow operators.

@@ -63,8 +63,7 @@ based on index.
     has a shape of [2, 3] and "bar" has a shape of [4, 5, 6].
 *   `--std_dev_values`, `--mean_values`. Type: comma-separated list of floats.
     These specify the (de-)quantization parameters of the input array, when it
-    is quantized. This is only needed if `inference_input_type` is `INT8` or
-    `QUANTIZED_UINT8`.
+    is quantized. Only needed if `inference_input_type` is `INT8` or `UINT8`.
     *   The meaning of `mean_values` and `std_dev_values` is as follows: each
         quantized value in the quantized input array will be interpreted as a
         mathematical real number (i.e. as an input activation value) according
@@ -75,12 +74,12 @@ based on index.
         the inference code according to the above formula, before proceeding
         with float inference.
     *   When performing quantized inference (`inference_type`
-        is `INT8` or `QUANTIZED_UINT8`), no dequantization is performed by the
-        inference code. However, the quantization parameters of all arrays,
-        including those of the input arrays as specified
-        by `mean_value` and `std_dev_value`, determine the fixed-point multipliers
-        used in the quantized inference code. `mean_value` must be an integer
-        when performing quantized inference.
+        is `INT8` or `UINT8`), no dequantization is performed by the inference
+        code. However, the quantization parameters of all arrays, including
+        those of the input arrays as specified by `mean_value` and `std_dev_value`,
+        determine the fixed-point multipliers used in the quantized inference
+        code. `mean_value` must be an integer when performing quantized
+        inference.

 ## Transformation flags

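The dequantization rule described in the hunk above reduces to `real_input_value = (quantized_input_value - mean_value) / std_dev_value`, the same formula quoted later in the `lite.py` docstring. A minimal illustrative sketch in Python (the function names are ours, not part of any TFLite API):

```python
def dequantize(quantized_value, mean_value, std_dev_value):
  # real_input_value = (quantized_input_value - mean_value) / std_dev_value
  return (quantized_value - mean_value) / std_dev_value


def quantize(real_value, mean_value, std_dev_value):
  # Inverse mapping, used when feeding uint8 image data to a quantized input.
  return int(round(real_value * std_dev_value + mean_value))


# With --mean_values=128 and --std_dev_values=127 (the values suggested above),
# the uint8 input value 128 corresponds to the real activation 0.0.
assert dequantize(128, 128, 127.0) == 0.0
```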
@@ -90,7 +89,7 @@ have.

 *   `--inference_type`. Type: string. Default: `FLOAT`. Data type of all
     real-number arrays in the output file except for input arrays (defined by
-    `--inference_input_type`). Must be `{FLOAT, INT8, QUANTIZED_UINT8}`.
+    `--inference_input_type`). Must be `{FLOAT, INT8, UINT8}`.

     This flag only impacts real-number arrays including float and quantized
     arrays. This excludes all other data types including plain integer arrays
@@ -102,16 +101,15 @@ have.
     *   If `INT8`, then real-numbers arrays will be quantized as int8 in the
         output file. If they were float in the input file, then they get
         quantized.
-    *   If `QUANTIZED_UINT8`, then real-numbers arrays will be quantized as
-        uint8 in the output file. If they were float in the input file, then
-        they get quantized.
+    *   If `UINT8`, then real-numbers arrays will be quantized as uint8 in the
+        output file. If they were float in the input file, then they get
+        quantized.

 *   `--inference_input_type`. Type: string. Data type of a real-number input
     array in the output file. By default the `--inference_type` is used as type
     of all of the input arrays. Flag is primarily intended for generating a
     float-point graph with a quantized input array. A Dequantized operator is
-    added immediately after the input array. Must be `{FLOAT, INT8,
-    QUANTIZED_UINT8}`.
+    added immediately after the input array. Must be `{FLOAT, INT8, UINT8}`.

     The flag is typically used for vision models taking a bitmap as input but
     requiring floating-point inference. For such image models, the uint8 input

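For the `--inference_input_type` case described above (a float graph with a quantized uint8 input array and an inserted Dequantize operator), the equivalent Python converter calls look roughly as follows. This is a hedged sketch mirroring the TF1 test changes later in this diff; `sess`, `in_tensor`, `out_tensor` and the `'Placeholder'` input name are assumed to exist.

```python
import tensorflow as tf

converter = tf.compat.v1.lite.TFLiteConverter.from_session(
    sess, [in_tensor], [out_tensor])
converter.inference_type = tf.float32       # real-number arrays stay float
converter.inference_input_type = tf.uint8   # quantized input array
converter.quantized_input_stats = {'Placeholder': (128., 127.)}  # (mean, std_dev)
tflite_model = converter.convert()
```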
@@ -99,7 +99,7 @@ py_test(
     ],
     python_version = "PY3",
     # Increased thread count for reducing timeout failures.
-    shard_count = 4,
+    shard_count = 10,
     srcs_version = "PY2AND3",
     tags = [
         "no_oss",
@@ -109,9 +109,26 @@ py_test(
         "notsan",  # b/160824139
     ],
     deps = [
         ":convert",
         ":tflite_convert",
         "//tensorflow/python:array_ops",
         "//tensorflow/python:client_testlib",
         "//tensorflow/python:constant_op",
         "//tensorflow/python:dtypes",
         "//tensorflow/python:framework",
         "//tensorflow/python:framework_ops",
         "//tensorflow/python:framework_test_lib",
         "//tensorflow/python:platform",
         "//tensorflow/python:random_ops",
         "//tensorflow/python:session",
         "//tensorflow/python:tf2",
         "//tensorflow/python/eager:def_function",
         "//tensorflow/python/keras",
         "//tensorflow/python/saved_model",
         "//tensorflow/python/saved_model:save",
         "//tensorflow/python/training:training_util",
         "//tensorflow/python/training/tracking",
         "//third_party/py/numpy",
     ],
 )

@@ -35,6 +35,7 @@ from tensorflow.lite.python import wrap_toco
 from tensorflow.lite.toco import model_flags_pb2 as _model_flags_pb2
 from tensorflow.lite.toco import toco_flags_pb2 as _toco_flags_pb2
 from tensorflow.lite.toco import types_pb2 as _types_pb2
+from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import tensor_shape
 from tensorflow.python.platform import resource_loader as _resource_loader
 from tensorflow.python.util import deprecation
@@ -301,7 +302,7 @@ Alternative, use virtualenv.""")
     pass


-def build_toco_flags(inference_type=lite_constants.FLOAT,
+def build_toco_flags(inference_type=dtypes.float32,
                      inference_input_type=None,
                      input_format=lite_constants.TENSORFLOW_GRAPHDEF,
                      output_format=lite_constants.TFLITE,
@@ -352,7 +353,7 @@ def build_toco_flags(inference_type=lite_constants.FLOAT,

 def build_toco_convert_protos(input_tensors,
                               output_tensors,
-                              inference_type=lite_constants.FLOAT,
+                              inference_type=dtypes.float32,
                               inference_input_type=None,
                               input_format=lite_constants.TENSORFLOW_GRAPHDEF,
                               input_shapes=None,

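The mechanical pattern in this and the following hunks is a one-for-one swap of `lite_constants` aliases for the TF dtypes they already resolve to, which is what makes the change a pure de-duplication. A hedged summary (the equalities reflect how `lite_constants` is defined at this point in the tree):

```python
from tensorflow.lite.python import lite_constants
from tensorflow.python.framework import dtypes

# Each legacy TFLite constant is an alias of the corresponding TF dtype.
assert lite_constants.FLOAT == dtypes.float32
assert lite_constants.FLOAT16 == dtypes.float16
assert lite_constants.INT8 == dtypes.int8
assert lite_constants.INT16 == dtypes.int16
assert lite_constants.INT32 == dtypes.int32
assert lite_constants.QUANTIZED_UINT8 == dtypes.uint8
```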
@@ -20,7 +20,6 @@ from __future__ import print_function
 import numpy as np

 from tensorflow.lite.python import convert
-from tensorflow.lite.python import lite_constants
 from tensorflow.lite.python import op_hint
 from tensorflow.lite.python.interpreter import Interpreter
 from tensorflow.python.client import session
@@ -59,7 +58,7 @@ class ConvertTest(test_util.TensorFlowTestCase):

     tflite_model = convert.toco_convert(
         sess.graph_def, [in_tensor], [out_tensor],
-        inference_type=lite_constants.QUANTIZED_UINT8,
+        inference_type=dtypes.uint8,
         quantized_input_stats=[(0., 1.)])
     self.assertTrue(tflite_model)

@@ -73,7 +72,7 @@ class ConvertTest(test_util.TensorFlowTestCase):
     tflite_model = convert.toco_convert_graph_def(
         sess.graph_def, [("input", [1, 16, 16, 3])], ["add"],
         enable_mlir_converter=False,
-        inference_type=lite_constants.FLOAT)
+        inference_type=dtypes.float32)
     self.assertTrue(tflite_model)

     # Check values from converted model.
@@ -111,7 +110,7 @@ class ConvertTest(test_util.TensorFlowTestCase):
         input_arrays_map,
         output_arrays,
         enable_mlir_converter=False,
-        inference_type=lite_constants.QUANTIZED_UINT8,
+        inference_type=dtypes.uint8,
         quantized_input_stats=[(0., 1.), (0., 1.)])
     self.assertTrue(tflite_model)

@@ -158,7 +157,7 @@ class ConvertTest(test_util.TensorFlowTestCase):
         input_arrays_map,
         output_arrays,
         enable_mlir_converter=False,
-        inference_type=lite_constants.QUANTIZED_UINT8)
+        inference_type=dtypes.uint8)
     self.assertEqual(
         "std_dev and mean must be defined when inference_type or "
         "inference_input_type is QUANTIZED_UINT8 or INT8.",

@@ -163,9 +163,8 @@ class TargetSpec(object):
     supported_ops: Experimental flag, subject to change. Set of OpsSet options
       supported by the device. (default set([OpsSet.TFLITE_BUILTINS]))
     supported_types: List of types for constant values on the target device.
-      Supported values are types exported by lite.constants. Frequently, an
-      optimization choice is driven by the most compact (i.e. smallest) type in
-      this list (default [constants.FLOAT])
+      Frequently, an optimization choice is driven by the most compact
+      (i.e. smallest) type in this list (default [tf.float32])
   """

   def __init__(self, supported_ops=None, supported_types=None):
@@ -200,7 +199,7 @@ class QuantizationMode(object):
     return (self._any_optimization_enabled() and
             not self._is_int16x8_target_required() and
             self._representative_dataset is not None and
-            self._smallest_supported_type() == constants.INT8)
+            self._smallest_supported_type() == _dtypes.int8)

   def is_post_training_integer_quantize_8(self):
     """Post training integer 8 quantization."""
@@ -241,12 +240,12 @@ class QuantizationMode(object):
     return (self._any_optimization_enabled() and
             self._representative_dataset is None and
             not self.contains_training_quant_op() and
-            self._smallest_supported_type() == constants.INT8)
+            self._smallest_supported_type() == _dtypes.int8)

   def post_training_fp16(self):
     """Post training fp16 quantize."""
     return (self._any_optimization_enabled() and
-            self._smallest_supported_type() == constants.FLOAT16)
+            self._smallest_supported_type() == _dtypes.float16)

   def fp32_execution(self):
     """If none of the above are true."""
@@ -259,36 +258,36 @@ class QuantizationMode(object):
             self.post_training_fp16())

   def activations_type(self):
-    return constants.INT16 if self._is_int16x8_target_required() \
-        else constants.INT8
+    return _dtypes.int16 if self._is_int16x8_target_required() \
+        else _dtypes.int8

   def converter_flags(self, inference_ty=None, inference_input_ty=None):
     """Flags to the converter."""
     if self.is_post_training_integer_quantize():
       # The inference_input_type is for the quantizer, then we need to keep the
       # converter inference_input_type to float.
-      inference_input_ty = constants.FLOAT
+      inference_input_ty = _dtypes.float32

     if self.training_time_int8_allow_float():
       return {
           "inference_type": inference_ty if inference_ty else \
               self.activations_type(),
           "inference_input_type":
-              inference_input_ty if inference_input_ty else constants.FLOAT,
+              inference_input_ty if inference_input_ty else _dtypes.float32,
           "post_training_quantize": False,  # disable dynamic range quantization
           "quantize_to_float16": False  # disable float16 quantization
       }
     elif self.post_training_dynamic_range_int8():
       return {
-          "inference_type": constants.FLOAT,
-          "inference_input_type": constants.FLOAT,
+          "inference_type": _dtypes.float32,
+          "inference_input_type": _dtypes.float32,
           "post_training_quantize": True,  # enable dynamic range quantization
           "quantize_to_float16": False  # disable float16 quantization
       }
     elif self.post_training_fp16():
       return {
-          "inference_type": constants.FLOAT,
-          "inference_input_type": constants.FLOAT,
+          "inference_type": _dtypes.float32,
+          "inference_input_type": _dtypes.float32,
           "post_training_quantize": True,
           "quantize_to_float16": True  # enable float16 quantization
       }
@@ -296,7 +295,7 @@ class QuantizationMode(object):
     # Note this might still trigger (uint8) quantization to be compatible with
     # TOCO.
     return {
-        "inference_type": inference_ty if inference_ty else constants.FLOAT,
+        "inference_type": inference_ty if inference_ty else _dtypes.float32,
        "inference_input_type": inference_input_ty,
        "post_training_quantize": False,  # enable dynamic range quantization
        "quantize_to_float16": False  # disable float16 quantization
@@ -305,8 +304,8 @@ class QuantizationMode(object):
   def quantizer_flags(self, input_ty=None, output_ty=None):
     """Default flags to the TFMOT quantizer."""

-    inference_input_type = input_ty if input_ty else constants.FLOAT
-    inference_output_type = output_ty if output_ty else constants.FLOAT
+    inference_input_type = input_ty if input_ty else _dtypes.float32
+    inference_output_type = output_ty if output_ty else _dtypes.float32

     if self.post_training_int8_no_float() \
         or self.post_training_int16x8_no_float():
@@ -327,9 +326,8 @@ class QuantizationMode(object):
     else:
       return False, None

-  def flags_modify_model_io_type(self,
-                                 input_type=constants.FLOAT,
-                                 output_type=constants.FLOAT):
+  def flags_modify_model_io_type(
+      self, input_type=_dtypes.float32, output_type=_dtypes.float32):
     """Flags for modifying the input and output type of a tflite model."""
     is_post_training_quantize = self.quantizer_flags(input_type, output_type)[0]
     is_training_time_only_quantize = self.training_time_int8_allow_float() and \
@@ -353,7 +351,7 @@ class QuantizationMode(object):
       return

     if self._target_spec.supported_types and (self._smallest_supported_type() !=
-                                              constants.INT8):
+                                              _dtypes.int8):
       raise ValueError("TFLITE_BUILTINS_INT8 requires smallest supported "
                        "type to be INT8.")

@@ -372,7 +370,7 @@ class QuantizationMode(object):
   def _is_int8_target_required(self):
     return (set([OpsSet.TFLITE_BUILTINS_INT8]) == set(
         self._target_spec.supported_ops) or
-            set(self._target_spec.supported_types) == set([constants.INT8]))
+            set(self._target_spec.supported_types) == set([_dtypes.int8]))

   def _is_int16x8_target_required(self):
     return bool(
@@ -397,7 +395,7 @@ class QuantizationMode(object):
       return min(self._target_spec.supported_types, key=lambda x: x.size)
     else:
       # The default smallest supported type is INT8.
-      return constants.INT8
+      return _dtypes.int8

   def contains_training_quant_op(self):
     """Checks if the graph contains any training-time quantization ops."""
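After this change, `TargetSpec.supported_types` and the smallest-type checks above are expressed in TF dtypes. A short, hedged usage sketch, assuming a SavedModel directory; it mirrors the float16 test updates further down:

```python
import tensorflow as tf

converter = tf.lite.TFLiteConverter.from_saved_model(saved_model_dir)  # assumed
converter.optimizations = [tf.lite.Optimize.DEFAULT]
# Previously lite.constants.FLOAT16; now a plain TF dtype.
converter.target_spec.supported_types = [tf.float16]
tflite_model = converter.convert()
```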
@@ -556,18 +554,18 @@ class TFLiteConverterBaseV2(TFLiteConverterBase):
   def __init__(self):
     """Constructor for TFLiteConverter."""
     super(TFLiteConverterBaseV2, self).__init__()
-    self.inference_input_type = constants.FLOAT
-    self.inference_output_type = constants.FLOAT
+    self.inference_input_type = _dtypes.float32
+    self.inference_output_type = _dtypes.float32

   def _validate_inference_input_output_types(self, quant_mode):
     """Validate inference_input_type and inference_output_type flags."""
-    default_types = [constants.FLOAT]
+    default_types = [_dtypes.float32]
     # We support integer input/output for integer quantized models only.
     if quant_mode.training_time_int8_allow_float():
       if quant_mode.is_post_training_integer_quantize_16x8():
-        all_types = default_types + [constants.INT16]
+        all_types = default_types + [_dtypes.int16]
       else:
-        all_types = default_types + [constants.INT8, constants.QUANTIZED_UINT8]
+        all_types = default_types + [_dtypes.int8, _dtypes.uint8]
       if self.inference_input_type not in all_types or \
           self.inference_output_type not in all_types:
         all_types_names = ["tf." + t.name for t in all_types]
@@ -1148,7 +1146,7 @@ class TFLiteConverterBaseV1(TFLiteConverterBase):
         graph debug info for a set of nodes from the `graph_def`.
     """
     super(TFLiteConverterBaseV1, self).__init__()
-    self.inference_type = constants.FLOAT
+    self.inference_type = _dtypes.float32
     self.inference_input_type = None
     self.inference_output_type = None
     self.output_format = constants.TFLITE
@@ -1195,7 +1193,7 @@ class TFLiteConverterBaseV1(TFLiteConverterBase):
   def _validate_quantized_input_stats(self, converter_kwargs, calibrate):
     """Ensure the `quantized_input_stats` flag is provided if required."""

-    quantized_types = frozenset({constants.INT8, constants.QUANTIZED_UINT8})
+    quantized_types = frozenset({_dtypes.int8, _dtypes.uint8})

     requires_quantized_input_stats = (
         (converter_kwargs["inference_type"] in quantized_types or
@@ -1645,8 +1643,8 @@ class TFLiteConverter(TFLiteFrozenGraphConverter):
     quantized_input_stats: Dict of strings representing input tensor names
       mapped to tuple of floats representing the mean and standard deviation
       of the training data (e.g., {"foo" : (0., 1.)}). Only need if
-      `inference_input_type` is `QUANTIZED_UINT8`. real_input_value =
-      (quantized_input_value - mean_value) / std_dev_value. (default {})
+      `inference_input_type` is `QUANTIZED_UINT8`. real_input_value =
+      (quantized_input_value - mean_value) / std_dev_value. (default {})
     default_ranges_stats: Tuple of integers representing (min, max) range values
       for all arrays without a specified range. Intended for experimenting with
       quantization via "dummy quantization". (default None)

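The V2 converter defaults and validation above now compare `inference_input_type` and `inference_output_type` against TF dtypes directly. A hedged end-to-end sketch of post-training integer quantization with uint8 input and output, modeled on the `lite_v2_test` parameters below; `func` and `calibration_samples` are assumed to exist:

```python
import tensorflow as tf

def representative_dataset():
  for sample in calibration_samples:  # assumed list of input arrays
    yield [sample]

converter = tf.lite.TFLiteConverter.from_concrete_functions([func])
converter.optimizations = [tf.lite.Optimize.DEFAULT]
converter.representative_dataset = representative_dataset
converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
converter.inference_input_type = tf.uint8    # was lite.constants.QUANTIZED_UINT8
converter.inference_output_type = tf.uint8
quantized_model = converter.convert()
```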
@@ -155,8 +155,8 @@ class FromSessionTest(TestModels, parameterized.TestCase):
     # Convert model and ensure model is not None.
     converter = lite.TFLiteConverter.from_session(sess, [in_tensor],
                                                   [out_tensor])
-    converter.inference_input_type = lite_constants.QUANTIZED_UINT8
-    converter.inference_type = lite_constants.FLOAT
+    converter.inference_input_type = dtypes.uint8
+    converter.inference_type = dtypes.float32
     converter.quantized_input_stats = {'Placeholder': (0., 1.)}  # mean, std_dev
     tflite_model = converter.convert()
     self.assertIsNotNone(tflite_model)
@@ -788,8 +788,8 @@ class FromSessionTest(TestModels, parameterized.TestCase):
     quantized_converter = lite.TFLiteConverter.from_session(
         sess, [inp], [output])
     quantized_converter.experimental_new_converter = enable_mlir_converter
-    quantized_converter.inference_input_type = lite_constants.INT8
-    quantized_converter.inference_output_type = lite_constants.INT8
+    quantized_converter.inference_input_type = dtypes.int8
+    quantized_converter.inference_output_type = dtypes.int8
     quantized_converter.optimizations = [lite.Optimize.DEFAULT]
     quantized_converter.representative_dataset = calibration_gen
     quantized_tflite_model = quantized_converter.convert()
@@ -832,7 +832,7 @@ class FromSessionTest(TestModels, parameterized.TestCase):
     quantized_converter.experimental_new_converter = enable_mlir_converter
     quantized_converter.optimizations = [lite.Optimize.DEFAULT]
     # Restricting to int8 type only
-    quantized_converter.target_spec.supported_types = [lite.constants.INT8]
+    quantized_converter.target_spec.supported_types = [dtypes.int8]
     # A representative dataset is required for full fixed point quantization.
     with self.assertRaises(ValueError) as error:
       quantized_converter.convert()
@@ -857,7 +857,7 @@ class FromSessionTest(TestModels, parameterized.TestCase):
     converter = lite.TFLiteConverter.from_session(sess,
                                                   [in_tensor_1, in_tensor_2],
                                                   [out_tensor])
-    converter.inference_type = lite_constants.QUANTIZED_UINT8
+    converter.inference_type = dtypes.uint8
     converter.quantized_input_stats = {
         'inputA': (0., 1.),
         'inputB': (0., 1.)
@@ -898,7 +898,7 @@ class FromSessionTest(TestModels, parameterized.TestCase):
     # Convert model and ensure model is not None.
     converter = lite.TFLiteConverter.from_session(sess, [in_tensor],
                                                   [out_tensor])
-    converter.inference_type = lite_constants.QUANTIZED_UINT8
+    converter.inference_type = dtypes.uint8
     converter.quantized_input_stats = {'Placeholder': (0., 1.)}  # mean, std_dev
     converter.default_ranges_stats = (0, 6)  # min, max
     tflite_model = converter.convert()
@@ -954,16 +954,15 @@ class FromSessionTest(TestModels, parameterized.TestCase):
     interpreter.allocate_tensors()
     self.assertEqual(interpreter.get_tensor_details()[idx]['name'], node_name)
     self.assertEqual(interpreter.get_tensor_details()[idx]['dtype'],
-                     lite.constants.FLOAT)
+                     dtypes.float32)
     # Convert model to quantized version
     quantized_converter = lite.TFLiteConverter.from_session(
         sess, [inp], [output])
     quantized_converter.experimental_new_converter = enable_mlir_converter
     quantized_converter.optimizations = [lite.Optimize.DEFAULT]
-    quantized_converter.target_spec.supported_types = [lite.constants.FLOAT16]
+    quantized_converter.target_spec.supported_types = [dtypes.float16]
     if include_int8:
-      quantized_converter.target_spec.supported_types.append(
-          lite.constants.INT8)
+      quantized_converter.target_spec.supported_types.append(dtypes.int8)
     if use_rep_data:
       quantized_converter.representative_dataset = calibration_gen

@@ -984,11 +983,11 @@ class FromSessionTest(TestModels, parameterized.TestCase):
     if is_float16_quantized:
       # Verify that bias constant is float16 type.
       self.assertEqual(interpreter.get_tensor_details()[idx]['dtype'],
-                       lite.constants.FLOAT16)
+                       dtypes.float16)
     elif is_post_training_quantized:
       # Verify that bias constants is int32 type.
       self.assertEqual(interpreter.get_tensor_details()[idx]['dtype'],
-                       lite.constants.INT32)
+                       dtypes.int32)
     else:
       raise ValueError('Invalid test options.')

@@ -1005,7 +1004,7 @@ class FromSessionTest(TestModels, parameterized.TestCase):
         sess, [inp], [output])
     quantized_converter.experimental_new_converter = enable_mlir_converter
     quantized_converter.optimizations = [lite.Optimize.DEFAULT]
-    quantized_converter.target_spec.supported_types = [lite.constants.FLOAT16]
+    quantized_converter.target_spec.supported_types = [dtypes.float16]
     # Specify only int8 builtin ops
     quantized_converter.target_spec.supported_ops = [
         lite.OpsSet.TFLITE_BUILTINS_INT8
@@ -1017,8 +1016,8 @@ class FromSessionTest(TestModels, parameterized.TestCase):
                   str(error.exception))

   @parameterized.named_parameters(
-      ('InferenceType_INT8', lite_constants.INT8),
-      ('InferenceType_UINT8', lite_constants.QUANTIZED_UINT8))
+      ('InferenceType_INT8', dtypes.int8),
+      ('InferenceType_UINT8', dtypes.uint8))
   def testInvalidQuantizeQATModelRequiresInputStats(self, quantized_type):
     with ops.Graph().as_default():
       in_tensor = array_ops.placeholder(
@@ -1039,7 +1038,7 @@ class FromSessionTest(TestModels, parameterized.TestCase):
         'flag is set to tf.uint8 or tf.int8.', str(error.exception))

     with self.assertRaises(ValueError) as error:
-      quantized_converter.inference_type = lite_constants.FLOAT
+      quantized_converter.inference_type = dtypes.float32
       quantized_converter.inference_input_type = quantized_type
       quantized_converter.convert()
     self.assertEqual(
@@ -1070,7 +1069,7 @@ class FromSessionTest(TestModels, parameterized.TestCase):
     converter = lite.TFLiteConverter.from_session(sess,
                                                   [in_tensor_1, in_tensor_2],
                                                   [out_tensor])
-    converter.inference_type = lite_constants.QUANTIZED_UINT8
+    converter.inference_type = dtypes.uint8
     converter.quantized_input_stats = {'inputA': (0., 1.)}  # mean, std_dev
     with self.assertRaises(ValueError) as error:
       converter.convert()
@@ -1091,9 +1090,9 @@ class FromSessionTest(TestModels, parameterized.TestCase):
     converter = lite.TFLiteConverter.from_session(sess, [inp], [output])

     # extra flags to trigger training time quantization conversion
-    converter.inference_type = lite_constants.INT8
-    converter.inference_input_type = lite_constants.FLOAT
-    converter.inference_output_type = lite_constants.FLOAT
+    converter.inference_type = dtypes.int8
+    converter.inference_input_type = dtypes.float32
+    converter.inference_output_type = dtypes.float32
     input_arrays = converter.get_input_arrays()
     converter.quantized_input_stats = {
         input_arrays[0]: (0., 1.)
@@ -1255,7 +1254,7 @@ class FromSessionTest(TestModels, parameterized.TestCase):
     # Convert model and ensure model is not None.
     converter = lite.TFLiteConverter.from_session(sess, [in_tensor],
                                                   [out_tensor])
-    converter.inference_type = lite_constants.QUANTIZED_UINT8
+    converter.inference_type = dtypes.uint8
     converter.quantized_input_stats = {'Placeholder': (0., 1.)}  # mean, std_dev
     tflite_model = converter.convert()
     self.assertIsNotNone(tflite_model)
@@ -2334,7 +2333,7 @@ class DefaultConverterAttrsTest(LiteTest):
     self.assertEqual(converter.output_format, lite_constants.TFLITE)

     # Assert the default inference type is float.
-    self.assertEqual(converter.inference_type, lite_constants.FLOAT)
+    self.assertEqual(converter.inference_type, dtypes.float32)

     # Assert the default inference type overrides are None.
     self.assertIsNone(converter.inference_input_type)

@@ -32,6 +32,7 @@ from tensorflow.lite.python import lite_v2_test_util
 from tensorflow.lite.python.convert import mlir_quantize
 from tensorflow.lite.python.interpreter import Interpreter
 from tensorflow.lite.toco import types_pb2 as _types_pb2
+from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import ops
 from tensorflow.python.framework import test_util
 from tensorflow.python.keras.layers import recurrent
@@ -74,9 +75,9 @@ class FromConcreteFunctionTest(lite_v2_test_util.ModelTest):
     self.assertEqual(expected_value.numpy(), actual_value)

   @parameterized.named_parameters(
-      ('_INT8InputOutput', lite.constants.INT8),
-      ('_UINT8InputOutput', lite.constants.QUANTIZED_UINT8),
-      ('_INT16InputOutput', lite.constants.INT16))
+      ('_INT8InputOutput', dtypes.int8),
+      ('_UINT8InputOutput', dtypes.uint8),
+      ('_INT16InputOutput', dtypes.int16))
   @test_util.run_v2_only
   def testInvalidFloat(self, inference_input_output_type):
     root = self._getSimpleVariableModel()
@@ -194,9 +195,9 @@ class FromConcreteFunctionTest(lite_v2_test_util.ModelTest):
     self.assertLess(len(quantized_tflite_model), len(float_tflite_model))

   @parameterized.named_parameters(
-      ('_INT8InputOutput', lite.constants.INT8),
-      ('_UINT8InputOutput', lite.constants.QUANTIZED_UINT8),
-      ('_INT16InputOutput', lite.constants.INT16))
+      ('_INT8InputOutput', dtypes.int8),
+      ('_UINT8InputOutput', dtypes.uint8),
+      ('_INT16InputOutput', dtypes.int16))
   @test_util.run_v2_only
   def testInvalidPostTrainingDynamicRangeQuantization(
       self, inference_input_output_type):
@@ -219,18 +220,18 @@ class FromConcreteFunctionTest(lite_v2_test_util.ModelTest):
                   'must be tf.float32.', str(error.exception))

   @parameterized.named_parameters(
-      ('_Default', False, False, lite.constants.FLOAT),
-      ('_INT8InputOutput', False, False, lite.constants.INT8),
-      ('_UINT8InputOutput', False, False, lite.constants.QUANTIZED_UINT8),
-      ('_INT16Quantize', False, True, lite.constants.FLOAT),
-      ('_INT16Quantize_INT16InputOutput', False, True, lite.constants.INT16),
-      ('_IntOnly', True, False, lite.constants.FLOAT),
-      ('_IntOnly_INT8InputOutput', True, False, lite.constants.INT8),
+      ('_Default', False, False, dtypes.float32),
+      ('_INT8InputOutput', False, False, dtypes.int8),
+      ('_UINT8InputOutput', False, False, dtypes.uint8),
+      ('_INT16Quantize', False, True, dtypes.float32),
+      ('_INT16Quantize_INT16InputOutput', False, True, dtypes.int16),
+      ('_IntOnly', True, False, dtypes.float32),
+      ('_IntOnly_INT8InputOutput', True, False, dtypes.int8),
       ('_IntOnly_UINT8InputOutput', True, False,
-       lite.constants.QUANTIZED_UINT8),
-      ('_IntOnly_INT16Quantize', True, True, lite.constants.FLOAT),
+       dtypes.uint8),
+      ('_IntOnly_INT16Quantize', True, True, dtypes.float32),
       ('_IntOnly_INT16Quantize_INT16InputOutput', True, True,
-       lite.constants.INT16))
+       dtypes.int16))
   def testIntegerQuantization(self, is_int_only, is_int16_quantize,
                               inference_input_output_type):
     func, calibration_gen = self._getIntegerQuantizeModel()
@@ -281,7 +282,7 @@ class FromConcreteFunctionTest(lite_v2_test_util.ModelTest):
     self.assertLess(len(quantized_tflite_model), len(tflite_model))

   @parameterized.named_parameters(
-      ('_INT16Quantize_INT8InputOutput', True, lite.constants.INT8))
+      ('_INT16Quantize_INT8InputOutput', True, dtypes.int8))
   def testInvalidIntegerQuantization(self, is_int16_quantize,
                                      inference_input_output_type):
     func, calibration_gen = self._getIntegerQuantizeModel()
@@ -297,8 +298,8 @@ class FromConcreteFunctionTest(lite_v2_test_util.ModelTest):
         lite.OpsSet.TFLITE_BUILTINS
     ]
     with self.assertRaises(ValueError) as error:
-      quantized_converter.inference_input_type = lite.constants.INT8
-      quantized_converter.inference_output_type = lite.constants.INT8
+      quantized_converter.inference_input_type = dtypes.int8
+      quantized_converter.inference_output_type = dtypes.int8
       quantized_converter.convert()
     self.assertEqual(
         "The inference_input_type and inference_output_type "
@@ -377,9 +378,9 @@ class FromConcreteFunctionTest(lite_v2_test_util.ModelTest):
     return tf.keras.Sequential(QLinear(3, input_shape=(2,)))

   @parameterized.named_parameters(
-      ('_DefaultFLOAT32InputOutput', lite.constants.FLOAT),
-      ('_INT8InputOutput', lite.constants.INT8),
-      ('_UINT8InputOutput', lite.constants.QUANTIZED_UINT8))
+      ('_DefaultFLOAT32InputOutput', dtypes.float32),
+      ('_INT8InputOutput', dtypes.int8),
+      ('_UINT8InputOutput', dtypes.uint8))
   @test_util.run_v2_only
   def testTrainingTimeQuantization(self, inference_input_output_type):
     model = self._getTrainingTimeQuantizedModel()

@@ -50,6 +50,7 @@ py_library(
     visibility = ["//visibility:public"],
     deps = [
         ":_pywrap_tensorflow_lite_calibration_wrapper",  # buildcleaner: keep
+        "//tensorflow/python:dtypes",
         "//tensorflow/python:util",
         "//third_party/py/numpy",
     ],
@@ -67,8 +68,8 @@ py_test(
     tags = ["no_oss"],
     deps = [
         ":calibrator",
-        "//tensorflow/lite/python:lite_constants",
         "//tensorflow/python:client_testlib",
+        "//tensorflow/python:dtypes",
         "//tensorflow/python:framework_test_lib",
         "//tensorflow/python:platform",
         "//third_party/py/numpy",

@@ -18,8 +18,9 @@ from __future__ import division
 from __future__ import print_function

 import numpy as np

+from tensorflow.python.framework import dtypes
 from tensorflow.python.util.lazy_loader import LazyLoader
 from tensorflow.lite.python import lite_constants

 # Lazy load since some of the performance benchmark skylark rules
 # break dependencies. Must use double quotes to match code internal rewrite
@@ -60,7 +61,7 @@ class Calibrator(object):
                              input_type,
                              output_type,
                              allow_float,
-                             activations_type=lite_constants.INT8,
+                             activations_type=dtypes.int8,
                              resize_input=True):
     """Calibrates the model with specified generator and then quantizes it.

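For reference, the updated `calibrate_and_quantize` signature is exercised like this (a condensed, hedged sketch based on the calibrator tests below; `model_content` is assumed to hold a float TFLite flatbuffer):

```python
import numpy as np
from tensorflow.lite.python.optimize import calibrator as _calibrator
from tensorflow.python.framework import dtypes

quantizer = _calibrator.Calibrator(model_content)

def input_gen():
  for _ in range(10):
    yield [np.ones(shape=(1, 5, 5, 3), dtype=np.float32)]

# Positional args: generator, input_type, output_type, allow_float,
# followed by the activations_type that now defaults to dtypes.int8.
quantized_model = quantizer.calibrate_and_quantize(
    input_gen, dtypes.float32, dtypes.float32, False, dtypes.int8)
```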
@@ -23,8 +23,8 @@ from absl.testing import parameterized
 import numpy as np
 from six.moves import range

-from tensorflow.lite.python import lite_constants as constants
 from tensorflow.lite.python.optimize import calibrator as _calibrator
+from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import test_util
 from tensorflow.python.platform import resource_loader
 from tensorflow.python.platform import test
@@ -34,9 +34,9 @@ class CalibratorTest(test_util.TensorFlowTestCase, parameterized.TestCase):

   @parameterized.named_parameters(
       # Activation type Int8
-      ('UseActivationTypeInt8', constants.INT8),
+      ('UseActivationTypeInt8', dtypes.int8),
       # Activation type Int16
-      ('UseActivationTypeInt16', constants.INT16))
+      ('UseActivationTypeInt16', dtypes.int16))
   def test_calibration_with_quantization(self, activations_type):
     model_path = resource_loader.get_path_to_datafile(
         'test_data/mobilenet_like_model.bin')
@@ -49,16 +49,17 @@ class CalibratorTest(test_util.TensorFlowTestCase, parameterized.TestCase):
       yield [np.ones(shape=(1, 5, 5, 3), dtype=np.float32)]

     quantized_model = quantizer.calibrate_and_quantize(input_gen,
-                                                       constants.FLOAT,
-                                                       constants.FLOAT, False,
+                                                       dtypes.float32,
+                                                       dtypes.float32,
+                                                       False,
                                                        activations_type)
     self.assertIsNotNone(quantized_model)

   @parameterized.named_parameters(
       # Activation type Int8
-      ('UseActivationTypeInt8', constants.INT8),
+      ('UseActivationTypeInt8', dtypes.int8),
       # Activation type Int16
-      ('UseActivationTypeInt16', constants.INT16))
+      ('UseActivationTypeInt16', dtypes.int16))
   def test_calibration_with_quantization_allow_float(self, activations_type):
     model_path = resource_loader.get_path_to_datafile(
         'test_data/mobilenet_like_model.bin')
@@ -71,8 +72,9 @@ class CalibratorTest(test_util.TensorFlowTestCase, parameterized.TestCase):
       yield [np.ones(shape=(1, 5, 5, 3), dtype=np.float32)]

     quantized_model = quantizer.calibrate_and_quantize(input_gen,
-                                                       constants.FLOAT,
-                                                       constants.FLOAT, True,
+                                                       dtypes.float32,
+                                                       dtypes.float32,
+                                                       True,
                                                        activations_type)
     self.assertIsNotNone(quantized_model)

@@ -88,7 +90,7 @@ class CalibratorTest(test_util.TensorFlowTestCase, parameterized.TestCase):
       yield [np.ones(shape=(1, 5, 5, 3), dtype=np.float32)]

     quantized_model = quantizer.calibrate_and_quantize_single(
-        input_gen, constants.FLOAT, constants.FLOAT, True, 'conv2d_8/BiasAdd')
+        input_gen, dtypes.float32, dtypes.float32, True, 'conv2d_8/BiasAdd')
     self.assertIsNotNone(quantized_model)

   def test_calibration_with_string_input(self):
@@ -103,14 +105,14 @@ class CalibratorTest(test_util.TensorFlowTestCase, parameterized.TestCase):
       yield [np.array(u'Test' + str(i))]

     quantized_model = quantizer.calibrate_and_quantize_single(
-        input_gen, constants.FLOAT, constants.FLOAT, True, 'Identity')
+        input_gen, dtypes.float32, dtypes.float32, True, 'Identity')
     self.assertIsNotNone(quantized_model)

   @parameterized.named_parameters(
       # Activation type Int8
-      ('UseActivationTypeInt8 - EnableMlirQuantizer', constants.INT8),
+      ('UseActivationTypeInt8 - EnableMlirQuantizer', dtypes.int8),
       # Activation type Int16
-      ('UseActivationTypeInt16 - DisableEnableMlirQuantizer', constants.INT16))
+      ('UseActivationTypeInt16 - DisableEnableMlirQuantizer', dtypes.int16))
   def test_calibration_with_quantization_multiple_inputs(
       self, activations_type):
     # Load multi add model from test data.
@@ -126,8 +128,9 @@ class CalibratorTest(test_util.TensorFlowTestCase, parameterized.TestCase):
       yield [np.ones(shape=(1, 8, 8, 3), dtype=np.float32) for _ in range(4)]

     quantized_model = quantizer.calibrate_and_quantize(input_gen,
-                                                       constants.FLOAT,
-                                                       constants.FLOAT, False,
+                                                       dtypes.float32,
+                                                       dtypes.float32,
+                                                       False,
                                                        activations_type)
     self.assertIsNotNone(quantized_model)

@@ -148,8 +151,8 @@ class CalibratorTest(test_util.TensorFlowTestCase, parameterized.TestCase):
       yield i

     with self.assertRaises(RuntimeError):
-      quantizer.calibrate_and_quantize(empty_input_gen, constants.FLOAT,
-                                       constants.FLOAT, False)
+      quantizer.calibrate_and_quantize(empty_input_gen, dtypes.float32,
+                                       dtypes.float32, False)

   def test_invalid_shape_calibrator_gen(self):
     model_path = resource_loader.get_path_to_datafile(
@@ -163,8 +166,8 @@ class CalibratorTest(test_util.TensorFlowTestCase, parameterized.TestCase):
       yield [np.ones(shape=(1, 2, 2, 3), dtype=np.float32)]

     with self.assertRaisesRegex(ValueError, 'Size mismatch'):
-      quantizer.calibrate_and_quantize(input_gen, constants.FLOAT,
-                                       constants.FLOAT, False, constants.INT8,
+      quantizer.calibrate_and_quantize(input_gen, dtypes.float32,
+                                       dtypes.float32, False, dtypes.int8,
                                        False)

   def test_invalid_type_calibrator_gen(self):
@@ -179,8 +182,8 @@ class CalibratorTest(test_util.TensorFlowTestCase, parameterized.TestCase):
       yield [np.ones(shape=(1, 5, 5, 3), dtype=np.int32)]

     with self.assertRaises(ValueError):
-      quantizer.calibrate_and_quantize(input_gen, constants.FLOAT,
-                                       constants.FLOAT, False, constants.INT8)
+      quantizer.calibrate_and_quantize(input_gen, dtypes.float32,
+                                       dtypes.float32, False, dtypes.int8)

   def test_calibration(self):
     model_path = resource_loader.get_path_to_datafile(

@@ -28,12 +28,12 @@ import six
 from six.moves import zip

 from tensorflow.lite.python import lite
-from tensorflow.lite.python import lite_constants
 from tensorflow.lite.python.convert import register_custom_opdefs
 from tensorflow.lite.toco import toco_flags_pb2 as _toco_flags_pb2
 from tensorflow.lite.toco.logging import gen_html
 from tensorflow.python import keras
 from tensorflow.python import tf2
+from tensorflow.python.framework import dtypes
 from tensorflow.python.platform import app


@@ -63,13 +63,14 @@ def _parse_inference_type(value, flag):
     ValueError: Unsupported value.
   """
   if value == "FLOAT":
-    return lite_constants.FLOAT
-  if value == "QUANTIZED_UINT8":
-    return lite_constants.QUANTIZED_UINT8
+    return dtypes.float32
   if value == "INT8":
-    return lite_constants.INT8
-  raise ValueError("Unsupported value for --{0}. Only FLOAT and "
-                   "QUANTIZED_UINT8 are supported.".format(flag))
+    return dtypes.int8
+  if value == "UINT8" or value == "QUANTIZED_UINT8":
+    return dtypes.uint8
+  raise ValueError(
+      "Unsupported value for `{}` flag. Expected FLOAT, INT8 or UINT8, instead "
+      "got {}.".format(flag, value))


 def _get_tflite_converter(flags):
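Expected behavior of the rewritten parser above, written as a few asserts (a sketch; `_parse_inference_type` is the private helper shown in the hunk):

```python
from tensorflow.lite.python.tflite_convert import _parse_inference_type
from tensorflow.python.framework import dtypes

assert _parse_inference_type("FLOAT", "inference_type") == dtypes.float32
assert _parse_inference_type("INT8", "inference_type") == dtypes.int8
# Both spellings map to uint8, so scripts passing QUANTIZED_UINT8 keep working.
assert _parse_inference_type("UINT8", "inference_type") == dtypes.uint8
assert _parse_inference_type("QUANTIZED_UINT8", "inference_type") == dtypes.uint8
```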
@@ -151,10 +152,10 @@ def _convert_tf1_model(flags):

     # In quantized inference, mean_value has to be integer so that the real
     # value 0.0 is exactly representable.
-    if converter.inference_type == lite_constants.QUANTIZED_UINT8:
-      mean_values = _parse_array(flags.mean_values, type_fn=int)
-    else:
-      mean_values = _parse_array(flags.mean_values, type_fn=float)
+    if converter.inference_type == dtypes.float32:
+      mean_values = _parse_array(flags.mean_values, type_fn=float)
+    else:
+      mean_values = _parse_array(flags.mean_values, type_fn=int)
     quant_stats = list(zip(mean_values, std_dev_values))
     if ((not flags.input_arrays and len(input_arrays) > 1) or
         (len(input_arrays) != len(quant_stats))):
@@ -193,13 +194,13 @@ def _convert_tf1_model(flags):

   if flags.post_training_quantize:
     converter.optimizations = [lite.Optimize.DEFAULT]
-    if converter.inference_type == lite_constants.QUANTIZED_UINT8:
+    if converter.inference_type != dtypes.float32:
       print("--post_training_quantize quantizes a graph of inference_type "
-            "FLOAT. Overriding inference type QUANTIZED_UINT8 to FLOAT.")
-      converter.inference_type = lite_constants.FLOAT
+            "FLOAT. Overriding inference_type to FLOAT.")
+      converter.inference_type = dtypes.float32

   if flags.quantize_to_float16:
-    converter.target_spec.supported_types = [lite.constants.FLOAT16]
+    converter.target_spec.supported_types = [dtypes.float16]
     if not flags.post_training_quantize:
       print("--quantize_to_float16 will only take effect with the "
             "--post_training_quantize flag enabled.")
@@ -358,14 +359,15 @@ def _get_tf1_flags(parser):
   parser.add_argument(
       "--inference_type",
       type=str.upper,
-      choices=["FLOAT", "QUANTIZED_UINT8", "INT8"],
-      help="Target data type of real-number arrays in the output file.")
+      default="FLOAT",
+      help=("Target data type of real-number arrays in the output file. "
+            "Must be either FLOAT, INT8 or UINT8."))
   parser.add_argument(
       "--inference_input_type",
       type=str.upper,
-      choices=["FLOAT", "QUANTIZED_UINT8", "INT8"],
       help=("Target data type of real-number input arrays. Allows for a "
-            "different type for input arrays in the case of quantization."))
+            "different type for input arrays in the case of quantization. "
+            "Must be either FLOAT, INT8 or UINT8."))

   # Input and output arrays flags.
   parser.add_argument(

@@ -168,6 +168,37 @@ class TfLiteConvertV1Test(TestModels):
     self._run(flags_str, should_succeed=True)
     os.remove(graph_def_file)

+  def testQATFrozenGraphDefUInt8(self):
+    with ops.Graph().as_default():
+      in_tensor_1 = array_ops.placeholder(
+          shape=[1, 16, 16, 3], dtype=dtypes.float32, name='inputA')
+      in_tensor_2 = array_ops.placeholder(
+          shape=[1, 16, 16, 3], dtype=dtypes.float32, name='inputB')
+      _ = array_ops.fake_quant_with_min_max_args(
+          in_tensor_1 + in_tensor_2, min=0., max=1., name='output')
+      sess = session.Session()
+
+    # Write graph to file.
+    graph_def_file = self._getFilepath('model.pb')
+    write_graph(sess.graph_def, '', graph_def_file, False)
+    sess.close()
+
+    # Define converter flags
+    flags_str = ('--std_dev_values=128,128 --mean_values=128,128 '
+                 '--graph_def_file={0} --input_arrays={1} '
+                 '--output_arrays={2}'.format(
+                     graph_def_file, 'inputA,inputB', 'output'))
+
+    # Set inference_type UINT8 and (default) inference_input_type UINT8
+    flags_str_1 = flags_str + ' --inference_type=UINT8'
+    self._run(flags_str_1, should_succeed=True)
+
+    # Set inference_type UINT8 and inference_input_type FLOAT
+    flags_str_2 = flags_str_1 + ' --inference_input_type=FLOAT'
+    self._run(flags_str_2, should_succeed=True)
+
+    os.remove(graph_def_file)
+
   def testSavedModel(self):
     saved_model_dir = self._getFilepath('model')
     with ops.Graph().as_default():

@@ -24,7 +24,6 @@ import numpy as np
 from six.moves import range
 import tensorflow as tf

-from tensorflow.lite.python import lite_constants
 from tensorflow.lite.python import util
 from tensorflow.lite.toco import types_pb2 as _types_pb2
 from tensorflow.python.client import session
@@ -292,9 +291,9 @@ def _test_param_modify_integer_model_io_type():
       # "DuringTraining": False,
   }
   map_types = {
-      "": lite_constants.FLOAT,
-      "INT8": lite_constants.INT8,
-      "UINT8": lite_constants.QUANTIZED_UINT8
+      "": dtypes.float32,
+      "INT8": dtypes.int8,
+      "UINT8": dtypes.uint8,
   }
   for k1, v1 in map_model_type.items():
     for k2, v2 in map_types.items():

@@ -28,11 +28,11 @@ from google.protobuf.message import DecodeError
 from tensorflow.core.framework import graph_pb2 as _graph_pb2
 from tensorflow.lite.python import convert_saved_model as _convert_saved_model
 from tensorflow.lite.python import lite as _lite
-from tensorflow.lite.python import lite_constants as constants
 from tensorflow.lite.python import util as _util
 from tensorflow.python import keras as _keras
 from tensorflow.python.client import session as _session
 from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import ops
 from tensorflow.python.framework.importer import import_graph_def as _import_graph_def
 from tensorflow.python.keras.preprocessing import image
@@ -97,7 +97,7 @@ def _convert(converter, **kwargs):
   if "post_training_quantize" in kwargs:
     converter.optimizations = [_lite.Optimize.DEFAULT]
     if kwargs.get("quantize_to_float16", False):
-      converter.target_spec.supported_types = [constants.FLOAT16]
+      converter.target_spec.supported_types = [dtypes.float16]
   if kwargs.get("post_training_quantize_16x8", False):
     input_size = kwargs.get("model_input_size")

@@ -52,7 +52,7 @@ py_library(
     name = "modify_model_interface_constants",
     srcs = ["modify_model_interface_constants.py"],
     srcs_version = "PY3",
-    deps = ["//tensorflow/lite/python:lite_constants"],
+    deps = ["//tensorflow/python:dtypes"],
 )

 pybind_extension(

@@ -19,12 +19,12 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function

-from tensorflow.lite.python import lite_constants
+from tensorflow.python.framework import dtypes

 STR_TO_TFLITE_TYPES = {
-    'INT8': lite_constants.INT8,
-    'INT16': lite_constants.INT16,
-    'UINT8': lite_constants.QUANTIZED_UINT8
+    'INT8': dtypes.int8,
+    'UINT8': dtypes.uint8,
+    'INT16': dtypes.int16,
 }
 TFLITE_TO_STR_TYPES = {v: k for k, v in STR_TO_TFLITE_TYPES.items()}

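A small usage sketch of the remapped constants above; the dictionaries are restated so the snippet stands alone:

```python
from tensorflow.python.framework import dtypes

# Restated from the hunk above.
STR_TO_TFLITE_TYPES = {
    'INT8': dtypes.int8,
    'UINT8': dtypes.uint8,
    'INT16': dtypes.int16,
}
TFLITE_TO_STR_TYPES = {v: k for k, v in STR_TO_TFLITE_TYPES.items()}

# String flag -> TF dtype, and back.
assert STR_TO_TFLITE_TYPES['UINT8'] == dtypes.uint8
assert TFLITE_TO_STR_TYPES[dtypes.int16] == 'INT16'
```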