Merge pull request #36251 from wwwind:interface_16x8
PiperOrigin-RevId: 317232781
This commit is contained in:
commit
3d868aa1c6
|
@ -142,6 +142,50 @@ The disadvantages of float16 quantization are as follows:
|
|||
to float32 when run on the CPU. (Note that the GPU delegate will not perform
|
||||
this dequantization, since it can operate on float16 data.)
|
||||
|
||||
### Integer only: 16-bit activations with 8-bit weights (experimental)
|
||||
|
||||
This is an experimental quantization scheme. It is similar to the "integer only"
|
||||
scheme, but activations are quantized based on their range to 16-bits, weights
|
||||
are quantized in 8-bit integer and bias is quantized into 64-bit integer. This
|
||||
is referred to as 16x8 quantization further.
|
||||
|
||||
The main advantage of this quantization is that it can improve accuracy
|
||||
significantly, but only slightly increase model size.
|
||||
|
||||
<pre>
|
||||
import tensorflow as tf
|
||||
converter = tf.lite.TFLiteConverter.from_saved_model(saved_model_dir)
|
||||
<b>converter.optimizations = [tf.lite.Optimize.DEFAULT]
|
||||
converter.target_spec.supported_types = [tf.lite.constants.EXPERIMENTAL_TFLITE_BUILTINS_ACTIVATIONS_INT16_WEIGHTS_INT8]</b>
|
||||
tflite_quant_model = converter.convert()
|
||||
</pre>
|
||||
|
||||
If 16x8 quantization is not supported for some operators in the model,
|
||||
then the model still can be quantized, but unsupported operators kept in float.
|
||||
The following option should be added to the target_spec to allow this.
|
||||
<pre>
|
||||
import tensorflow as tf
|
||||
converter = tf.lite.TFLiteConverter.from_saved_model(saved_model_dir)
|
||||
converter.optimizations = [tf.lite.Optimize.DEFAULT]
|
||||
converter.target_spec.supported_types = [tf.lite.constants.EXPERIMENTAL_TFLITE_BUILTINS_ACTIVATIONS_INT16_WEIGHTS_INT8,
|
||||
<b>tf.lite.OpsSet.TFLITE_BUILTINS</b>]
|
||||
tflite_quant_model = converter.convert()
|
||||
</pre>
|
||||
|
||||
Examples of the use cases where accuracy improvements provided by this
|
||||
quantization scheme include: * super-resolution, * audio signal processing such
|
||||
as noise cancelling and beamforming, * image de-noising, * HDR reconstruction
|
||||
from a single image.
|
||||
|
||||
The disadvantage of this quantization is:
|
||||
|
||||
* Currently inference is noticeably slower than 8-bit full integer due to the
|
||||
lack of optimized kernel implementation.
|
||||
* Currently it is incompatible with the existing hardware accelerated TFLite
|
||||
delegates.
|
||||
|
||||
Note: This is an experimental feature.
|
||||
|
||||
### Model accuracy
|
||||
|
||||
Since weights are quantized post training, there could be an accuracy loss,
|
||||
|
|
|
@ -94,6 +94,20 @@ class OpsSet(enum.Enum):
|
|||
# quantized implementations.
|
||||
TFLITE_BUILTINS_INT8 = "TFLITE_BUILTINS_INT8"
|
||||
|
||||
# Convert model using only TensorFlow Lite operations with quantized int8
|
||||
# weights, int16 activations and int64 bias.
|
||||
# Specifying this will throw an error for operations that do not yet have
|
||||
# quantized implementations.
|
||||
# This quantization mode may be used in models for super-resolution,
|
||||
# audio signal processing or image de-noising. It improves accuracy
|
||||
# significantly, but only slightly increases the model size.
|
||||
# WARNING: These ops are currently experimental and have not yet been
|
||||
# finalized.
|
||||
# They are only compatible with CPU execution, and have not been optimized for
|
||||
# production.
|
||||
EXPERIMENTAL_TFLITE_BUILTINS_ACTIVATIONS_INT16_WEIGHTS_INT8 = \
|
||||
"EXPERIMENTAL_TFLITE_BUILTINS_ACTIVATIONS_INT16_WEIGHTS_INT8"
|
||||
|
||||
def __str__(self):
|
||||
return self.value
|
||||
|
||||
|
|
|
@ -193,11 +193,13 @@ class QuantizationMode(object):
|
|||
def post_training_int8_no_float(self):
|
||||
"""Post training int8 quantize, disallow float fallback."""
|
||||
return (self._is_int8_target_required() and
|
||||
not self._is_int16x8_target_required() and
|
||||
self._representative_dataset is not None)
|
||||
|
||||
def post_training_int8_allow_float(self):
|
||||
"""Post training int8 quantize, allow float fallback."""
|
||||
return (self._any_optimization_enabled() and
|
||||
not self._is_int16x8_target_required() and
|
||||
self._representative_dataset is not None and
|
||||
self._smallest_supported_type() == constants.INT8)
|
||||
|
||||
|
@ -212,6 +214,17 @@ class QuantizationMode(object):
|
|||
not self.post_training_dynamic_range_int8() and
|
||||
not self.post_training_fp16())
|
||||
|
||||
def post_training_int16x8_no_float(self):
|
||||
"""Post training int16x8 quantize, disallow float fallback."""
|
||||
return (not self._is_int8_target_required() and
|
||||
self._is_int16x8_target_required() and
|
||||
not self._is_allow_float() and
|
||||
self._representative_dataset is not None)
|
||||
|
||||
def post_training_int16x8_allow_float(self):
|
||||
"""Post training int16x8 quantize, allow float fallback."""
|
||||
return (self._is_int16x8_target_required() and self._is_allow_float())
|
||||
|
||||
def post_training_dynamic_range_int8(self):
|
||||
"""Post training int8 const, on-the-fly int8 quantize of dynamic tensors."""
|
||||
# Post-training dynamic range quantization is only enabled if post-training
|
||||
|
@ -231,9 +244,15 @@ class QuantizationMode(object):
|
|||
return not (self.post_training_int8_no_float() or
|
||||
self.post_training_int8_allow_float() or
|
||||
self.training_time_int8_allow_float() or
|
||||
self.post_training_int16x8_no_float() or
|
||||
self.post_training_int16x8_allow_float() or
|
||||
self.post_training_dynamic_range_int8() or
|
||||
self.post_training_fp16())
|
||||
|
||||
def activations_type(self):
|
||||
return constants.INT16 if self._is_int16x8_target_required() \
|
||||
else constants.INT8
|
||||
|
||||
def converter_flags(self, inference_ty=None, inference_input_ty=None):
|
||||
"""Flags to the converter."""
|
||||
if self.is_post_training_integer_quantize():
|
||||
|
@ -243,7 +262,8 @@ class QuantizationMode(object):
|
|||
|
||||
if self.training_time_int8_allow_float():
|
||||
return {
|
||||
"inference_type": inference_ty if inference_ty else constants.INT8,
|
||||
"inference_type": inference_ty if inference_ty else \
|
||||
self.activations_type(),
|
||||
"inference_input_type":
|
||||
inference_input_ty if inference_input_ty else constants.FLOAT,
|
||||
"post_training_quantize": False, # disable dynamic range quantization
|
||||
|
@ -278,16 +298,21 @@ class QuantizationMode(object):
|
|||
|
||||
inference_input_type = input_ty if input_ty else constants.FLOAT
|
||||
inference_output_type = output_ty if output_ty else constants.FLOAT
|
||||
if self.post_training_int8_no_float():
|
||||
|
||||
if self.post_training_int8_no_float() \
|
||||
or self.post_training_int16x8_no_float():
|
||||
return True, {
|
||||
"inference_input_type": inference_input_type,
|
||||
"inference_output_type": inference_output_type,
|
||||
"activations_type": self.activations_type(),
|
||||
"allow_float": False
|
||||
}
|
||||
elif self.post_training_int8_allow_float():
|
||||
elif self.post_training_int8_allow_float() \
|
||||
or self.post_training_int16x8_allow_float():
|
||||
return True, {
|
||||
"inference_input_type": inference_input_type,
|
||||
"inference_output_type": inference_output_type,
|
||||
"activations_type": self.activations_type(),
|
||||
"allow_float": True
|
||||
}
|
||||
else:
|
||||
|
@ -322,6 +347,17 @@ class QuantizationMode(object):
|
|||
self._target_spec.supported_ops) or
|
||||
set(self._target_spec.supported_types) == set([constants.INT8]))
|
||||
|
||||
def _is_int16x8_target_required(self):
|
||||
return bool(
|
||||
set(self._target_spec.supported_ops).intersection([
|
||||
OpsSet.EXPERIMENTAL_TFLITE_BUILTINS_ACTIVATIONS_INT16_WEIGHTS_INT8
|
||||
]))
|
||||
|
||||
def _is_allow_float(self):
|
||||
return bool(
|
||||
set(self._target_spec.supported_ops).intersection(
|
||||
[OpsSet.TFLITE_BUILTINS]))
|
||||
|
||||
def _any_optimization_enabled(self):
|
||||
return bool(
|
||||
set(self._optimizations).intersection([
|
||||
|
@ -394,11 +430,13 @@ class TFLiteConverterBase(object):
|
|||
return _get_grappler_config(optimizers)
|
||||
|
||||
def _calibrate_quantize_model(self, result, inference_input_type,
|
||||
inference_output_type, allow_float):
|
||||
inference_output_type, activations_type,
|
||||
allow_float):
|
||||
"""Calibrate and quantize the model."""
|
||||
if not isinstance(self.representative_dataset, RepresentativeDataset):
|
||||
self.representative_dataset = RepresentativeDataset(
|
||||
self.representative_dataset)
|
||||
|
||||
calibrate_quantize = _calibrator.Calibrator(result)
|
||||
if self._experimental_calibrate_only or self._experimental_new_quantizer:
|
||||
calibrated = calibrate_quantize.calibrate(
|
||||
|
@ -411,7 +449,7 @@ class TFLiteConverterBase(object):
|
|||
else:
|
||||
return calibrate_quantize.calibrate_and_quantize(
|
||||
self.representative_dataset.input_gen, inference_input_type,
|
||||
inference_output_type, allow_float)
|
||||
inference_output_type, allow_float, activations_type)
|
||||
|
||||
def _is_unknown_shapes_allowed(self):
|
||||
# Unknown dimensions are only allowed with the new converter.
|
||||
|
@ -1931,7 +1969,6 @@ class TFLiteConverter(TFLiteFrozenGraphConverter):
|
|||
"""
|
||||
return super(TFLiteConverter, self).convert()
|
||||
|
||||
|
||||
@_tf_export(v1=["lite.TocoConverter"])
|
||||
class TocoConverter(object):
|
||||
"""Convert a TensorFlow model into `output_format` using TOCO.
|
||||
|
|
|
@ -30,6 +30,7 @@ INT64 = dtypes.int64
|
|||
STRING = dtypes.string
|
||||
QUANTIZED_UINT8 = dtypes.uint8
|
||||
INT8 = dtypes.int8
|
||||
INT16 = dtypes.int16
|
||||
COMPLEX64 = dtypes.complex64
|
||||
TENSORFLOW_GRAPHDEF = _toco_flags_pb2.TENSORFLOW_GRAPHDEF
|
||||
TFLITE = _toco_flags_pb2.TFLITE
|
||||
|
@ -43,6 +44,7 @@ _tf_export(v1=["lite.constants.STRING"]).export_constant(__name__, "STRING")
|
|||
_tf_export(v1=["lite.constants.QUANTIZED_UINT8"]).export_constant(
|
||||
__name__, "QUANTIZED_UINT8")
|
||||
_tf_export(v1=["lite.constants.INT8"]).export_constant(__name__, "INT8")
|
||||
_tf_export(v1=["lite.constants.INT16"]).export_constant(__name__, "INT16")
|
||||
_tf_export(v1=["lite.constants.TFLITE"]).export_constant(__name__, "TFLITE")
|
||||
_tf_export(v1=["lite.constants.GRAPHVIZ_DOT"]).export_constant(
|
||||
__name__, "GRAPHVIZ_DOT")
|
||||
|
@ -62,6 +64,7 @@ _allowed_symbols = [
|
|||
"STRING",
|
||||
"QUANTIZED_UINT8",
|
||||
"INT8",
|
||||
"INT16",
|
||||
"COMPLEX64",
|
||||
"TENSORFLOW_GRAPHDEF",
|
||||
"TFLITE",
|
||||
|
|
|
@ -881,9 +881,22 @@ class FromSessionTest(TestModels, parameterized.TestCase):
|
|||
self.assertLess(len(quantized_tflite), len(float_tflite))
|
||||
|
||||
@parameterized.named_parameters(
|
||||
('EnableMlirConverter', True), # enable mlir
|
||||
('DisableMlirConverter', False)) # disable mlir
|
||||
def testCalibrateAndQuantizeBuiltinInt8(self, enable_mlir):
|
||||
# Quantize model to Int8: with enable mlir
|
||||
('UseTfliteBuiltinsIntEnableMLIR',
|
||||
[lite.OpsSet.TFLITE_BUILTINS_INT8], True),
|
||||
# Quantize model to Int8: with disable mlir
|
||||
('UseTfliteBuiltinsIntDisableMLIR',
|
||||
[lite.OpsSet.TFLITE_BUILTINS_INT8], False),
|
||||
# Quantize model to Int16: with disable mlir
|
||||
('UseTfliteBuiltinsInt16DisableMLIR',
|
||||
[lite.OpsSet.\
|
||||
EXPERIMENTAL_TFLITE_BUILTINS_ACTIVATIONS_INT16_WEIGHTS_INT8],
|
||||
False),
|
||||
('UseTfliteBuiltinsInt16EnableMLIR',
|
||||
[lite.OpsSet.\
|
||||
EXPERIMENTAL_TFLITE_BUILTINS_ACTIVATIONS_INT16_WEIGHTS_INT8],
|
||||
True))
|
||||
def testCalibrateAndQuantizeBuiltinInt(self, supported_ops, enable_mlir):
|
||||
with ops.Graph().as_default():
|
||||
inp, output, calibration_gen = self._getCalibrationQuantizeModel()
|
||||
sess = session.Session()
|
||||
|
@ -899,9 +912,7 @@ class FromSessionTest(TestModels, parameterized.TestCase):
|
|||
quantized_converter = lite.TFLiteConverter.from_session(
|
||||
sess, [inp], [output])
|
||||
quantized_converter.experimental_new_converter = enable_mlir
|
||||
quantized_converter.target_spec.supported_ops = [
|
||||
lite.OpsSet.TFLITE_BUILTINS_INT8
|
||||
]
|
||||
quantized_converter.target_spec.supported_ops = supported_ops
|
||||
quantized_converter.representative_dataset = calibration_gen
|
||||
quantized_tflite = quantized_converter.convert()
|
||||
self.assertTrue(quantized_tflite)
|
||||
|
|
|
@ -269,7 +269,8 @@ PyObject* CalibrationWrapper::Calibrate() {
|
|||
|
||||
PyObject* CalibrationWrapper::QuantizeModel(int input_py_type,
|
||||
int output_py_type,
|
||||
bool allow_float) {
|
||||
bool allow_float,
|
||||
int activations_py_type) {
|
||||
if (NoOpModel(*model_)) {
|
||||
return python_utils::ConvertToPyString(model_str_->data(),
|
||||
model_str_->size());
|
||||
|
@ -277,6 +278,9 @@ PyObject* CalibrationWrapper::QuantizeModel(int input_py_type,
|
|||
|
||||
TfLiteType input_type = python_utils::TfLiteTypeFromPyType(input_py_type);
|
||||
TfLiteType output_type = python_utils::TfLiteTypeFromPyType(output_py_type);
|
||||
TfLiteType activations_type =
|
||||
python_utils::TfLiteTypeFromPyType(activations_py_type);
|
||||
|
||||
if (input_type == kTfLiteNoType || output_type == kTfLiteNoType) {
|
||||
PyErr_SetString(PyExc_ValueError,
|
||||
"Input/output type cannot be kTfLiteNoType");
|
||||
|
@ -286,9 +290,11 @@ PyObject* CalibrationWrapper::QuantizeModel(int input_py_type,
|
|||
reader_->AddCalibrationToModel(tflite_model.get(), /*update=*/false);
|
||||
flatbuffers::FlatBufferBuilder builder;
|
||||
auto status = kTfLiteOk;
|
||||
status = tflite::optimize::QuantizeModel(
|
||||
|
||||
status = tflite::optimize::QuantizeModelAllOperators(
|
||||
&builder, tflite_model.get(), TfLiteTypeToSchemaType(input_type),
|
||||
TfLiteTypeToSchemaType(output_type), allow_float, error_reporter_.get());
|
||||
TfLiteTypeToSchemaType(output_type), allow_float,
|
||||
TfLiteTypeToSchemaType(activations_type), error_reporter_.get());
|
||||
|
||||
if (status != kTfLiteOk) {
|
||||
error_reporter_->exception();
|
||||
|
@ -319,7 +325,7 @@ PyObject* CalibrationWrapper::QuantizeModel(int input_py_type,
|
|||
auto status = tflite::optimize::QuantizeModel(
|
||||
&builder, tflite_model.get(), TfLiteTypeToSchemaType(input_type),
|
||||
TfLiteTypeToSchemaType(output_type), allow_float, {op_name},
|
||||
error_reporter_.get());
|
||||
TensorType_INT8, error_reporter_.get());
|
||||
if (status != kTfLiteOk) {
|
||||
error_reporter_->exception();
|
||||
return nullptr;
|
||||
|
|
|
@ -62,7 +62,7 @@ class CalibrationWrapper {
|
|||
PyObject* FeedTensor(PyObject* input_value);
|
||||
|
||||
PyObject* QuantizeModel(int input_py_type, int output_py_type,
|
||||
bool allow_float);
|
||||
bool allow_float, int activations_py_type);
|
||||
|
||||
// Allows quantizing only the operator that produces the tensor with name
|
||||
// operator_output_name. (This can be used to help debug.).
|
||||
|
|
|
@ -43,15 +43,18 @@ PYBIND11_MODULE(_pywrap_tensorflow_lite_calibration_wrapper, m) {
|
|||
})
|
||||
.def("QuantizeModel",
|
||||
[](CalibrationWrapper& self, int input_py_type, int output_py_type,
|
||||
bool allow_float, bool enable_mlir_quantizer) {
|
||||
return tensorflow::PyoOrThrow(self.QuantizeModel(
|
||||
input_py_type, output_py_type, allow_float));
|
||||
bool allow_float, int activations_py_type,
|
||||
bool enable_mlir_quantizer) {
|
||||
return tensorflow::PyoOrThrow(
|
||||
self.QuantizeModel(input_py_type, output_py_type, allow_float,
|
||||
activations_py_type));
|
||||
})
|
||||
.def("QuantizeModel",
|
||||
[](CalibrationWrapper& self, int input_py_type, int output_py_type,
|
||||
bool allow_float) {
|
||||
return tensorflow::PyoOrThrow(self.QuantizeModel(
|
||||
input_py_type, output_py_type, allow_float));
|
||||
bool allow_float, int activations_py_type) {
|
||||
return tensorflow::PyoOrThrow(
|
||||
self.QuantizeModel(input_py_type, output_py_type, allow_float,
|
||||
activations_py_type));
|
||||
})
|
||||
.def("QuantizeModel",
|
||||
[](CalibrationWrapper& self, int input_py_type, int output_py_type,
|
||||
|
|
|
@ -19,6 +19,7 @@ from __future__ import print_function
|
|||
|
||||
import numpy as np
|
||||
from tensorflow.python.util.lazy_loader import LazyLoader
|
||||
from tensorflow.lite.python import lite_constants
|
||||
|
||||
# Lazy load since some of the performance benchmark skylark rules
|
||||
# break dependencies. Must use double quotes to match code internal rewrite
|
||||
|
@ -59,6 +60,7 @@ class Calibrator(object):
|
|||
input_type,
|
||||
output_type,
|
||||
allow_float,
|
||||
activations_type=lite_constants.INT8,
|
||||
resize_input=True):
|
||||
"""Calibrates the model with specified generator and then quantizes it.
|
||||
|
||||
|
@ -73,9 +75,11 @@ class Calibrator(object):
|
|||
input_type: A tf.dtype representing the desired real-value input type.
|
||||
output_type: A tf.dtype representing the desired real-value output type.
|
||||
allow_float: A boolean. False if the resulting model cannot perform float
|
||||
computation, useful when targeting an integer-only backend. If False, an
|
||||
error will be thrown if an operation cannot be quantized, otherwise the
|
||||
model will fallback to float ops.
|
||||
computation, useful when targeting an integer-only backend.
|
||||
If False, an error will be thrown if an operation cannot be
|
||||
quantized, otherwise the model will fallback to float ops.
|
||||
activations_type: A tf.dtype representing the desired type for
|
||||
activations.
|
||||
resize_input: A boolean. True if the shape of the sample data is different
|
||||
from the input.
|
||||
"""
|
||||
|
@ -90,7 +94,8 @@ class Calibrator(object):
|
|||
self._calibrator.FeedTensor(sample)
|
||||
return self._calibrator.QuantizeModel(
|
||||
np.dtype(input_type.as_numpy_dtype()).num,
|
||||
np.dtype(output_type.as_numpy_dtype()).num, allow_float)
|
||||
np.dtype(output_type.as_numpy_dtype()).num, allow_float,
|
||||
np.dtype(activations_type.as_numpy_dtype()).num)
|
||||
|
||||
def calibrate_and_quantize_single(self,
|
||||
dataset_gen,
|
||||
|
|
|
@ -32,7 +32,12 @@ from tensorflow.python.platform import test
|
|||
|
||||
class CalibratorTest(test_util.TensorFlowTestCase, parameterized.TestCase):
|
||||
|
||||
def test_calibration_with_quantization(self):
|
||||
@parameterized.named_parameters(
|
||||
# Activation type Int8
|
||||
('UseActivationTypeInt8', constants.INT8),
|
||||
# Activation type Int16
|
||||
('UseActivationTypeInt16', constants.INT16))
|
||||
def test_calibration_with_quantization(self, activations_type):
|
||||
model_path = resource_loader.get_path_to_datafile(
|
||||
'test_data/mobilenet_like_model.bin')
|
||||
float_model = open(model_path, 'rb').read()
|
||||
|
@ -45,10 +50,16 @@ class CalibratorTest(test_util.TensorFlowTestCase, parameterized.TestCase):
|
|||
|
||||
quantized_model = quantizer.calibrate_and_quantize(input_gen,
|
||||
constants.FLOAT,
|
||||
constants.FLOAT, False)
|
||||
constants.FLOAT, False,
|
||||
activations_type)
|
||||
self.assertIsNotNone(quantized_model)
|
||||
|
||||
def test_calibration_with_quantization_allow_float(self):
|
||||
@parameterized.named_parameters(
|
||||
# Activation type Int8
|
||||
('UseActivationTypeInt8', constants.INT8),
|
||||
# Activation type Int16
|
||||
('UseActivationTypeInt16', constants.INT16))
|
||||
def test_calibration_with_quantization_allow_float(self, activations_type):
|
||||
model_path = resource_loader.get_path_to_datafile(
|
||||
'test_data/mobilenet_like_model.bin')
|
||||
float_model = open(model_path, 'rb').read()
|
||||
|
@ -61,7 +72,8 @@ class CalibratorTest(test_util.TensorFlowTestCase, parameterized.TestCase):
|
|||
|
||||
quantized_model = quantizer.calibrate_and_quantize(input_gen,
|
||||
constants.FLOAT,
|
||||
constants.FLOAT, True)
|
||||
constants.FLOAT, True,
|
||||
activations_type)
|
||||
self.assertIsNotNone(quantized_model)
|
||||
|
||||
def test_calibration_with_quantization_single_op(self):
|
||||
|
@ -79,7 +91,13 @@ class CalibratorTest(test_util.TensorFlowTestCase, parameterized.TestCase):
|
|||
input_gen, constants.FLOAT, constants.FLOAT, True, 'conv2d_8/BiasAdd')
|
||||
self.assertIsNotNone(quantized_model)
|
||||
|
||||
def test_calibration_with_quantization_multiple_inputs(self):
|
||||
@parameterized.named_parameters(
|
||||
# Activation type Int8
|
||||
('UseActivationTypeInt8 - EnableMlirQuantizer', constants.INT8),
|
||||
# Activation type Int16
|
||||
('UseActivationTypeInt16 - DisableEnableMlirQuantizer', constants.INT16))
|
||||
def test_calibration_with_quantization_multiple_inputs(
|
||||
self, activations_type):
|
||||
# Load multi add model from test data.
|
||||
# This model has 4 inputs of size (1, 8, 8, 3).
|
||||
model_path = resource_loader.get_path_to_datafile(
|
||||
|
@ -94,7 +112,8 @@ class CalibratorTest(test_util.TensorFlowTestCase, parameterized.TestCase):
|
|||
|
||||
quantized_model = quantizer.calibrate_and_quantize(input_gen,
|
||||
constants.FLOAT,
|
||||
constants.FLOAT, False)
|
||||
constants.FLOAT, False,
|
||||
activations_type)
|
||||
self.assertIsNotNone(quantized_model)
|
||||
|
||||
def test_invalid_model_buffer(self):
|
||||
|
@ -130,7 +149,8 @@ class CalibratorTest(test_util.TensorFlowTestCase, parameterized.TestCase):
|
|||
|
||||
with self.assertRaisesRegex(ValueError, 'Size mismatch'):
|
||||
quantizer.calibrate_and_quantize(input_gen, constants.FLOAT,
|
||||
constants.FLOAT, False, False)
|
||||
constants.FLOAT, False, constants.INT8,
|
||||
False)
|
||||
|
||||
def test_invalid_type_calibrator_gen(self):
|
||||
model_path = resource_loader.get_path_to_datafile(
|
||||
|
@ -145,7 +165,7 @@ class CalibratorTest(test_util.TensorFlowTestCase, parameterized.TestCase):
|
|||
|
||||
with self.assertRaises(ValueError):
|
||||
quantizer.calibrate_and_quantize(input_gen, constants.FLOAT,
|
||||
constants.FLOAT, False)
|
||||
constants.FLOAT, False, constants.INT8)
|
||||
|
||||
def test_calibration(self):
|
||||
model_path = resource_loader.get_path_to_datafile(
|
||||
|
|
|
@ -49,6 +49,7 @@ _MAP_TF_TO_TFLITE_TYPES = {
|
|||
dtypes.string: _types_pb2.STRING,
|
||||
dtypes.uint8: _types_pb2.QUANTIZED_UINT8,
|
||||
dtypes.int8: _types_pb2.INT8,
|
||||
dtypes.int16: _types_pb2.QUANTIZED_INT16,
|
||||
dtypes.complex64: _types_pb2.COMPLEX64,
|
||||
dtypes.bool: _types_pb2.BOOL,
|
||||
}
|
||||
|
|
|
@ -286,6 +286,7 @@ tf_cc_test(
|
|||
"//tensorflow/lite/tools/optimize:testdata/maximum.bin",
|
||||
"//tensorflow/lite/tools/optimize:testdata/minimum.bin",
|
||||
"//tensorflow/lite/tools/optimize:testdata/mixed.bin",
|
||||
"//tensorflow/lite/tools/optimize:testdata/mixed16x8.bin",
|
||||
"//tensorflow/lite/tools/optimize:testdata/multi_input_add_reshape.bin",
|
||||
"//tensorflow/lite/tools/optimize:testdata/pack.bin",
|
||||
"//tensorflow/lite/tools/optimize:testdata/single_avg_pool_min_minus_5_max_plus_5.bin",
|
||||
|
|
|
@ -74,11 +74,13 @@ OperatorProperty GetOperatorProperty(const ModelT* model, int subgraph_index,
|
|||
property.inputs = {{0, {}}, {1, {}}};
|
||||
property.outputs = {{0, {}}};
|
||||
property.version = 2;
|
||||
property.quantize_input_as_activations = true;
|
||||
break;
|
||||
case BuiltinOperator_ARG_MAX:
|
||||
property.inputs = {{0, {}}};
|
||||
// ArgMax has no quantizable output.
|
||||
property.version = 2;
|
||||
property.quantizable_int16 = false;
|
||||
break;
|
||||
case BuiltinOperator_AVERAGE_POOL_2D:
|
||||
property.inputs = {{0, {}}};
|
||||
|
@ -94,6 +96,7 @@ OperatorProperty GetOperatorProperty(const ModelT* model, int subgraph_index,
|
|||
property.outputs = {{0, {}}};
|
||||
property.restrict_same_input_output_scale = true;
|
||||
property.version = 2;
|
||||
property.quantizable_int16 = false;
|
||||
break;
|
||||
case BuiltinOperator_SPLIT:
|
||||
// We skip input 0 since it is the split dim which is not real valued.
|
||||
|
@ -159,6 +162,7 @@ OperatorProperty GetOperatorProperty(const ModelT* model, int subgraph_index,
|
|||
property.inputs = {{0, {}}, {1, {}}};
|
||||
// Comparisons have no quantizable outputs.
|
||||
property.version = 2;
|
||||
property.quantizable_int16 = false;
|
||||
break;
|
||||
case BuiltinOperator_EXPAND_DIMS:
|
||||
// We skip input 1 as it is not real valued (it's the index of axis) and
|
||||
|
@ -181,11 +185,13 @@ OperatorProperty GetOperatorProperty(const ModelT* model, int subgraph_index,
|
|||
property.outputs = {{0, {}}};
|
||||
property.restrict_same_input_output_scale = true;
|
||||
property.version = 2;
|
||||
property.quantizable_int16 = false;
|
||||
break;
|
||||
case BuiltinOperator_HARD_SWISH: {
|
||||
property.inputs = {{0, {}}};
|
||||
property.outputs = {{0, {}}};
|
||||
property.version = 1;
|
||||
property.quantizable_int16 = false;
|
||||
break;
|
||||
}
|
||||
case BuiltinOperator_LOG_SOFTMAX: {
|
||||
|
@ -193,9 +199,10 @@ OperatorProperty GetOperatorProperty(const ModelT* model, int subgraph_index,
|
|||
// LogSoftmax requires output with 16/256 as scale and 127 as zero point.
|
||||
TensorProperty tensor_property;
|
||||
tensor_property.restriction = true;
|
||||
tensor_property.restricted_value = {16.0f / 256.0f, 127};
|
||||
tensor_property.restricted_value_int8 = {16.0f / 256.0f, 127};
|
||||
property.outputs = {{0, tensor_property}};
|
||||
property.version = 2;
|
||||
property.quantizable_int16 = false;
|
||||
break;
|
||||
}
|
||||
case BuiltinOperator_LOGISTIC: {
|
||||
|
@ -203,7 +210,8 @@ OperatorProperty GetOperatorProperty(const ModelT* model, int subgraph_index,
|
|||
// Logistic requires output with 1/256 as scale and -128 as zero point.
|
||||
TensorProperty tensor_property;
|
||||
tensor_property.restriction = true;
|
||||
tensor_property.restricted_value = {1 / 256.0f, -128};
|
||||
tensor_property.restricted_value_int8 = {1 / 256.0f, -128};
|
||||
tensor_property.restricted_value_int16 = {1 / 32768.0f, 0};
|
||||
property.outputs = {{0, tensor_property}};
|
||||
property.version = 2;
|
||||
break;
|
||||
|
@ -757,6 +765,7 @@ OperatorProperty GetOperatorProperty(const ModelT* model, int subgraph_index,
|
|||
property.restrict_scale = {{18, 0}};
|
||||
property.version = 2;
|
||||
}
|
||||
property.quantizable_int16 = false;
|
||||
break;
|
||||
}
|
||||
case BuiltinOperator_L2_NORMALIZATION: {
|
||||
|
@ -764,9 +773,10 @@ OperatorProperty GetOperatorProperty(const ModelT* model, int subgraph_index,
|
|||
// L2 Norm requires output with 1/128 as scale and 0 as zero point.
|
||||
TensorProperty tensor_property;
|
||||
tensor_property.restriction = true;
|
||||
tensor_property.restricted_value = {1 / 128.0f, 0};
|
||||
tensor_property.restricted_value_int8 = {1 / 128.0f, 0};
|
||||
property.outputs = {{0, tensor_property}};
|
||||
property.version = 2;
|
||||
property.quantizable_int16 = false;
|
||||
break;
|
||||
}
|
||||
case BuiltinOperator_MAX_POOL_2D:
|
||||
|
@ -779,28 +789,33 @@ OperatorProperty GetOperatorProperty(const ModelT* model, int subgraph_index,
|
|||
property.arbitrary_inputs = true;
|
||||
property.outputs = {{0, {}}};
|
||||
property.restrict_same_input_output_scale = true;
|
||||
property.quantize_input_as_activations = true;
|
||||
property.version = 2;
|
||||
break;
|
||||
case BuiltinOperator_MEAN:
|
||||
property.inputs = {{0, {}}};
|
||||
property.outputs = {{0, {}}};
|
||||
property.version = 2;
|
||||
property.quantizable_int16 = false;
|
||||
break;
|
||||
case BuiltinOperator_MINIMUM:
|
||||
property.arbitrary_inputs = true;
|
||||
property.outputs = {{0, {}}};
|
||||
property.restrict_same_input_output_scale = true;
|
||||
property.quantize_input_as_activations = true;
|
||||
property.version = 2;
|
||||
break;
|
||||
case BuiltinOperator_MUL:
|
||||
property.inputs = {{0, {}}, {1, {}}};
|
||||
property.outputs = {{0, {}}};
|
||||
property.quantize_input_as_activations = true;
|
||||
property.version = 2;
|
||||
break;
|
||||
case BuiltinOperator_PACK:
|
||||
property.arbitrary_inputs = true;
|
||||
property.outputs = {{0, {}}};
|
||||
property.restrict_same_input_output_scale = true;
|
||||
property.restrict_same_input_output_scale = true;
|
||||
property.version = 2;
|
||||
break;
|
||||
case BuiltinOperator_PAD:
|
||||
|
@ -809,6 +824,7 @@ OperatorProperty GetOperatorProperty(const ModelT* model, int subgraph_index,
|
|||
property.outputs = {{0, {}}};
|
||||
property.restrict_same_input_output_scale = true;
|
||||
property.version = 2;
|
||||
property.quantizable_int16 = false;
|
||||
break;
|
||||
case BuiltinOperator_QUANTIZE:
|
||||
property.inputs = {{0, {}}};
|
||||
|
@ -831,11 +847,13 @@ OperatorProperty GetOperatorProperty(const ModelT* model, int subgraph_index,
|
|||
property.inputs = {{0, {}}};
|
||||
property.outputs = {{0, {}}};
|
||||
property.version = 2;
|
||||
property.quantizable_int16 = false;
|
||||
break;
|
||||
case BuiltinOperator_RELU_N1_TO_1:
|
||||
property.inputs = {{0, {}}};
|
||||
property.outputs = {{0, {}}};
|
||||
property.version = 1;
|
||||
property.quantizable_int16 = false;
|
||||
break;
|
||||
case BuiltinOperator_RESHAPE:
|
||||
property.inputs = {{0, {}}};
|
||||
|
@ -849,6 +867,7 @@ OperatorProperty GetOperatorProperty(const ModelT* model, int subgraph_index,
|
|||
property.outputs = {{0, {}}};
|
||||
property.restrict_same_input_output_scale = true;
|
||||
property.version = 2;
|
||||
property.quantizable_int16 = false;
|
||||
break;
|
||||
case BuiltinOperator_SHAPE:
|
||||
property.inputs = {{0, {}}};
|
||||
|
@ -874,7 +893,8 @@ OperatorProperty GetOperatorProperty(const ModelT* model, int subgraph_index,
|
|||
// Softmax requires output with 1/256 as scale and -128 as zero point.
|
||||
TensorProperty tensor_property;
|
||||
tensor_property.restriction = true;
|
||||
tensor_property.restricted_value = {1 / 256.0f, -128};
|
||||
tensor_property.restricted_value_int8 = {1 / 256.0f, -128};
|
||||
tensor_property.restricted_value_int16 = {1 / 32768.0f, 0};
|
||||
property.outputs = {{0, tensor_property}};
|
||||
property.version = 2;
|
||||
break;
|
||||
|
@ -894,13 +914,15 @@ OperatorProperty GetOperatorProperty(const ModelT* model, int subgraph_index,
|
|||
property.inputs = {{0, {}}};
|
||||
property.outputs = {{0, {}}};
|
||||
property.version = 2;
|
||||
property.quantizable_int16 = false;
|
||||
break;
|
||||
case BuiltinOperator_TANH: {
|
||||
property.inputs = {{0, {}}};
|
||||
// Tanh requires output with 1/128 as scale and 0 as zero point.
|
||||
TensorProperty tensor_property;
|
||||
tensor_property.restriction = true;
|
||||
tensor_property.restricted_value = {1 / 128.0f, 0};
|
||||
tensor_property.restricted_value_int8 = {1 / 128.0f, 0};
|
||||
tensor_property.restricted_value_int16 = {1 / 32768.0f, 0};
|
||||
property.outputs = {{0, tensor_property}};
|
||||
property.version = 2;
|
||||
break;
|
||||
|
@ -926,6 +948,7 @@ OperatorProperty GetOperatorProperty(const ModelT* model, int subgraph_index,
|
|||
{3, tensor_property_bias}};
|
||||
property.outputs = {{0, {}}};
|
||||
property.version = 3;
|
||||
property.quantizable_int16 = false;
|
||||
break;
|
||||
}
|
||||
case BuiltinOperator_TRANSPOSE:
|
||||
|
@ -949,6 +972,7 @@ OperatorProperty GetOperatorProperty(const ModelT* model, int subgraph_index,
|
|||
default:
|
||||
// No quantized implementation exists for this operation.
|
||||
property.quantizable = false;
|
||||
property.quantizable_int16 = false;
|
||||
}
|
||||
return property;
|
||||
}
|
||||
|
|
|
@ -43,7 +43,8 @@ struct TensorProperty {
|
|||
// Constraints.
|
||||
bool restriction = false;
|
||||
// scale/zero_point hardcoded.
|
||||
std::pair<float, int> restricted_value = {0.0f, 0};
|
||||
std::pair<float, int> restricted_value_int8 = {0.0f, 0};
|
||||
std::pair<float, int> restricted_value_int16 = {0.0f, 0};
|
||||
|
||||
// Use derived scale.
|
||||
bool use_derived_scale = false;
|
||||
|
@ -64,7 +65,8 @@ struct TensorProperty {
|
|||
struct OperatorProperty {
|
||||
// Is a quantized operations currently supported.
|
||||
bool quantizable = true;
|
||||
|
||||
// Is a quantized operations currently supported for 16x8
|
||||
bool quantizable_int16 = true;
|
||||
// Op has arbitrary number of inputs, such as concat.
|
||||
bool arbitrary_inputs = false;
|
||||
// Op has arbitrary number of outputs, such as slice.
|
||||
|
@ -93,6 +95,13 @@ struct OperatorProperty {
|
|||
|
||||
// Op version.
|
||||
int version = 1;
|
||||
|
||||
// When we quantize activations into 16 bit and weights into 8 bit,
|
||||
// we want to quantize all inputs, including constant tensors,
|
||||
// for the operators like Add, Mul into 16-bit as well. The constant
|
||||
// inputs are quantized as weights and this variable indicates
|
||||
// that we want to do quantizations of these tensors as activations.
|
||||
bool quantize_input_as_activations = false;
|
||||
};
|
||||
|
||||
OperatorProperty GetOperatorProperty(const ModelT* model, int subgraph_index,
|
||||
|
|
|
@ -85,6 +85,42 @@ void GetAsymmetricQuantizationParams(
|
|||
quantization_params->zero_point = std::vector<int64_t>(1, zero_point);
|
||||
}
|
||||
|
||||
void GetSymmetricQuantizationParams(
|
||||
float min, float max, const int half_quant_range,
|
||||
QuantizationParametersT* quantization_params) {
|
||||
// Adjust the boundaries to guarantee 0 is included.
|
||||
min = std::min(min, 0.0f);
|
||||
max = std::max(max, 0.0f);
|
||||
const float scale = std::max(std::abs(max), std::abs(min)) / half_quant_range;
|
||||
quantization_params->min = std::vector<float>(1, min);
|
||||
quantization_params->max = std::vector<float>(1, max);
|
||||
quantization_params->scale = std::vector<float>(1, scale);
|
||||
quantization_params->zero_point = std::vector<int64_t>(1, 0);
|
||||
}
|
||||
|
||||
TfLiteStatus GetQuantizationParams(TensorT* tensor, TensorType activations_type,
|
||||
QuantizationParametersT* quantization_params,
|
||||
ErrorReporter* error_reporter) {
|
||||
if (activations_type == TensorType_INT8) {
|
||||
GetAsymmetricQuantizationParams(
|
||||
tensor->quantization->min[0], tensor->quantization->max[0],
|
||||
std::numeric_limits<int8_t>::min(), std::numeric_limits<int8_t>::max(),
|
||||
quantization_params);
|
||||
} else if (activations_type == TensorType_INT16) {
|
||||
const float quantized_range = 32767.0;
|
||||
GetSymmetricQuantizationParams(tensor->quantization->min[0],
|
||||
tensor->quantization->max[0],
|
||||
quantized_range, quantization_params);
|
||||
} else {
|
||||
TF_LITE_REPORT_ERROR(
|
||||
error_reporter,
|
||||
"Unsupported activation type for quantize-activation: %d",
|
||||
activations_type);
|
||||
return kTfLiteError;
|
||||
}
|
||||
return kTfLiteOk;
|
||||
}
|
||||
|
||||
// Set the max and min quantization parameter for a single tensor given its
|
||||
// values.
|
||||
void FillSingleMinMax(const float* const input, const uint64_t input_size,
|
||||
|
@ -548,6 +584,7 @@ TfLiteStatus SymmetricQuantizeTensorPerChannel(ModelT* model, TensorT* tensor,
|
|||
model, tensor, error_reporter);
|
||||
}
|
||||
|
||||
template <class BiasType>
|
||||
TfLiteStatus SymmetricPerLayerBiasQuantize(ModelT* model, TensorT* tensor,
|
||||
float scaling_factor,
|
||||
ErrorReporter* error_reporter) {
|
||||
|
@ -560,25 +597,38 @@ TfLiteStatus SymmetricPerLayerBiasQuantize(ModelT* model, TensorT* tensor,
|
|||
uint64_t num_elements;
|
||||
TF_LITE_ENSURE_STATUS(NumElements(*tensor, &num_elements));
|
||||
|
||||
std::vector<int32_t> final_buffer(num_elements);
|
||||
const int32_t kScale = std::numeric_limits<int32_t>::max();
|
||||
std::vector<BiasType> final_buffer(num_elements);
|
||||
const BiasType kScale = std::numeric_limits<BiasType>::max();
|
||||
|
||||
for (size_t i = 0; i < num_elements; i++) {
|
||||
const int32_t quantized_value = tflite::SafeCast<int32_t>(
|
||||
const BiasType quantized_value = tflite::SafeCast<BiasType>(
|
||||
TfLiteRound(float_data[i] * scaling_factor_inv));
|
||||
final_buffer[i] = std::min(kScale, std::max(-kScale, quantized_value));
|
||||
}
|
||||
|
||||
// Set the buffers and output type.
|
||||
uint8_t* uint8_buffer = reinterpret_cast<uint8_t*>(final_buffer.data());
|
||||
size_t buffer_size = num_elements * sizeof(int32_t);
|
||||
size_t buffer_size = num_elements * sizeof(BiasType);
|
||||
std::vector<float> scales(1, scaling_factor);
|
||||
std::vector<int64_t> zero_points(1, 0);
|
||||
|
||||
auto output_type = std::is_same<BiasType, std::int32_t>::value
|
||||
? TensorType_INT32
|
||||
: TensorType_INT64;
|
||||
return AddQuantizationParams(scales, zero_points, 0, uint8_buffer,
|
||||
buffer_size, TensorType_INT32, model, tensor,
|
||||
buffer_size, output_type, model, tensor,
|
||||
error_reporter);
|
||||
}
|
||||
|
||||
template TfLiteStatus SymmetricPerLayerBiasQuantize<std::int32_t>(
|
||||
ModelT* model, TensorT* tensor, float scaling_factor,
|
||||
ErrorReporter* error_reporter);
|
||||
|
||||
template TfLiteStatus SymmetricPerLayerBiasQuantize<std::int64_t>(
|
||||
ModelT* model, TensorT* tensor, float scaling_factor,
|
||||
ErrorReporter* error_reporter);
|
||||
|
||||
template <class BiasType>
|
||||
TfLiteStatus SymmetricPerChannelBiasQuantize(ModelT* model, TensorT* tensor,
|
||||
float input_scale,
|
||||
const float* weight_scales,
|
||||
|
@ -595,14 +645,14 @@ TfLiteStatus SymmetricPerChannelBiasQuantize(ModelT* model, TensorT* tensor,
|
|||
uint64_t num_elements;
|
||||
TF_LITE_ENSURE_STATUS(NumElements(*tensor, &num_elements));
|
||||
|
||||
std::vector<int32_t> final_buffer(num_elements);
|
||||
const int32_t kScale = std::numeric_limits<int32_t>::max();
|
||||
std::vector<BiasType> final_buffer(num_elements);
|
||||
const BiasType kScale = std::numeric_limits<BiasType>::max();
|
||||
|
||||
for (int32_t channel_idx = 0; channel_idx < number_of_dimension;
|
||||
channel_idx++) {
|
||||
float scaling_factor = scales[channel_idx];
|
||||
float scaling_factor_inv = (scaling_factor == 0) ? 0 : 1.0 / scaling_factor;
|
||||
const int32_t quantized_value = tflite::SafeCast<int32_t>(
|
||||
const BiasType quantized_value = tflite::SafeCast<BiasType>(
|
||||
TfLiteRound(float_data[channel_idx] * scaling_factor_inv));
|
||||
final_buffer[channel_idx] =
|
||||
std::min(kScale, std::max(-kScale, quantized_value));
|
||||
|
@ -610,12 +660,26 @@ TfLiteStatus SymmetricPerChannelBiasQuantize(ModelT* model, TensorT* tensor,
|
|||
|
||||
// Set the buffers and output type.
|
||||
uint8_t* uint8_buffer = reinterpret_cast<uint8_t*>(final_buffer.data());
|
||||
size_t buffer_size = num_elements * sizeof(int32_t);
|
||||
size_t buffer_size = num_elements * sizeof(BiasType);
|
||||
std::vector<int64_t> zero_point(scales.size(), 0);
|
||||
|
||||
auto output_type = std::is_same<BiasType, std::int32_t>::value
|
||||
? TensorType_INT32
|
||||
: TensorType_INT64;
|
||||
return AddQuantizationParams(scales, zero_point, 0, uint8_buffer, buffer_size,
|
||||
TensorType_INT32, model, tensor, error_reporter);
|
||||
output_type, model, tensor, error_reporter);
|
||||
}
|
||||
|
||||
template TfLiteStatus SymmetricPerChannelBiasQuantize<std::int64_t>(
|
||||
ModelT* model, TensorT* tensor, float input_scale,
|
||||
const float* weight_scales, int number_of_dimension,
|
||||
ErrorReporter* error_reporter);
|
||||
|
||||
template TfLiteStatus SymmetricPerChannelBiasQuantize<std::int32_t>(
|
||||
ModelT* model, TensorT* tensor, float input_scale,
|
||||
const float* weight_scales, int number_of_dimension,
|
||||
ErrorReporter* error_reporter);
|
||||
|
||||
TfLiteStatus QuantizeWeight(ModelT* model, TensorT* tensor, bool per_channel,
|
||||
int per_axis_index, ErrorReporter* error_reporter) {
|
||||
// TODO(suharshs): Currently we conflate quantizing weights and constants. Its
|
||||
|
@ -657,12 +721,12 @@ float GetEffectiveScale(ModelT* model, SubGraphT* subgraph, int op_idx,
|
|||
return scale;
|
||||
}
|
||||
|
||||
void QuantizeActivation(TensorT* tensor) {
|
||||
GetAsymmetricQuantizationParams(
|
||||
tensor->quantization->min[0], tensor->quantization->max[0],
|
||||
std::numeric_limits<int8_t>::min(), std::numeric_limits<int8_t>::max(),
|
||||
tensor->quantization.get());
|
||||
tensor->type = TensorType_INT8;
|
||||
TfLiteStatus QuantizeActivation(TensorT* tensor, TensorType activations_type,
|
||||
ErrorReporter* error_reporter) {
|
||||
TF_LITE_ENSURE_STATUS(GetQuantizationParams(
|
||||
tensor, activations_type, tensor->quantization.get(), error_reporter));
|
||||
tensor->type = activations_type;
|
||||
return kTfLiteOk;
|
||||
}
|
||||
|
||||
TfLiteStatus QuantizeActivationToInt16(TensorT* tensor, float scale) {
|
||||
|
|
|
@ -113,12 +113,14 @@ TfLiteStatus SymmetricQuantizeFloatsToInt16(ModelT* model, TensorT* tensor,
|
|||
ErrorReporter* error_reporter);
|
||||
|
||||
// Symmetrically quantized the bias for per-layer ops (i.e. FullyConnected).
|
||||
template <typename BiasType>
|
||||
TfLiteStatus SymmetricPerLayerBiasQuantize(ModelT* model, TensorT* tensor,
|
||||
float scaling_factor,
|
||||
ErrorReporter* error_reporter);
|
||||
|
||||
// Symmetrically quantizes the bias for ops like Conv and DepthwiseConv.
|
||||
// The scale of bias if weight_per_channel_scale[channel] * input_scale.
|
||||
template <typename BiasType>
|
||||
TfLiteStatus SymmetricPerChannelBiasQuantize(ModelT* model, TensorT* tensor,
|
||||
float input_scale,
|
||||
const float* weight_scales,
|
||||
|
@ -135,8 +137,14 @@ float GetEffectiveScale(ModelT* model, SubGraphT* subgraph, int op_idx,
|
|||
std::vector<int> intermediate_index,
|
||||
std::vector<float> factors);
|
||||
|
||||
// Return quantization parameters depending on activations type.
|
||||
TfLiteStatus GetQuantizationParams(TensorT* tensor, TensorType activations_type,
|
||||
QuantizationParametersT* quantization_params,
|
||||
ErrorReporter* error_reporter);
|
||||
|
||||
// Quantize activation.
|
||||
void QuantizeActivation(TensorT* tensor);
|
||||
TfLiteStatus QuantizeActivation(TensorT* tensor, TensorType activations_type,
|
||||
ErrorReporter* error_reporter);
|
||||
|
||||
// Quantize activation to 16bit.
|
||||
TfLiteStatus QuantizeActivationToInt16(TensorT* tensor, float scale);
|
||||
|
|
|
@ -701,7 +701,7 @@ TEST_F(QuantizationUtilsTest, SymmetricPerLayerBiasQuantize) {
|
|||
model->buffers.push_back(std::move(buffer));
|
||||
|
||||
// Call and verify.
|
||||
EXPECT_EQ(SymmetricPerLayerBiasQuantize(
|
||||
EXPECT_EQ(SymmetricPerLayerBiasQuantize<int32_t>(
|
||||
model.get(), model->subgraphs[0]->tensors[0].get(),
|
||||
input_scale * weight_scale, &error_reporter_),
|
||||
kTfLiteOk);
|
||||
|
@ -759,7 +759,7 @@ TEST_F(QuantizationUtilsTest, SymmetricPerChannelBiasQuantize) {
|
|||
model->buffers.push_back(std::move(buffer));
|
||||
|
||||
// Call and verify.
|
||||
EXPECT_EQ(SymmetricPerChannelBiasQuantize(
|
||||
EXPECT_EQ(SymmetricPerChannelBiasQuantize<int32_t>(
|
||||
model.get(), model->subgraphs[0]->tensors[0].get(), input_scale,
|
||||
weight_scales.data(), 2, &error_reporter_),
|
||||
kTfLiteOk);
|
||||
|
|
|
@ -42,7 +42,9 @@ bool CreateQuantizedModel(const std::string& path) {
|
|||
tflite::StderrReporter error_reporter;
|
||||
if (tflite::optimize::QuantizeModel(
|
||||
&builder, &model, tflite::TensorType_FLOAT32,
|
||||
tflite::TensorType_FLOAT32, &error_reporter) != kTfLiteOk) {
|
||||
tflite::TensorType_FLOAT32,
|
||||
// TODO(b/159351372): Pass required activation type if needed
|
||||
tflite::TensorType_INT8, &error_reporter) != kTfLiteOk) {
|
||||
return false;
|
||||
}
|
||||
return WriteFile(path, builder.GetBufferPointer(), builder.GetSize());
|
||||
|
|
|
@ -52,13 +52,17 @@ bool IsFloatTensor(const SubGraphT* subgraph, int32_t tensor_idx) {
|
|||
// operator_names.
|
||||
operator_property::OperatorProperty GetOperatorProperty(
|
||||
const std::unordered_set<string>& operator_names, const ModelT* model,
|
||||
int subgraph_index, int op_idx, const string& operator_name) {
|
||||
int subgraph_index, int op_idx, const string& operator_name,
|
||||
const TensorType& activations_type) {
|
||||
operator_property::OperatorProperty property =
|
||||
operator_property::GetOperatorProperty(model, subgraph_index, op_idx);
|
||||
const SubGraphT* subgraph = model->subgraphs[subgraph_index].get();
|
||||
const OperatorT* op = subgraph->operators[op_idx].get();
|
||||
const BuiltinOperator op_code =
|
||||
model->operator_codes[op->opcode_index]->builtin_code;
|
||||
if (activations_type == TensorType_INT16 && !property.quantizable_int16) {
|
||||
property.quantizable = false;
|
||||
}
|
||||
// The algorithm adds Dequantize and Quantize, so we don't require them to be
|
||||
// in the operator_names.
|
||||
if (op_code != BuiltinOperator_DEQUANTIZE &&
|
||||
|
@ -78,7 +82,8 @@ bool IsRealValueOp(const std::unordered_set<string>& real_value_op_set,
|
|||
// Creates a set that contains all quantizable ops that happen to take a
|
||||
// non-float type in the source graph.
|
||||
std::unordered_set<string> PopulateRealValueOpSet(
|
||||
ModelT* model, const std::unordered_set<string>& operator_names) {
|
||||
ModelT* model, const std::unordered_set<string>& operator_names,
|
||||
const TensorType& activations_type) {
|
||||
std::unordered_set<string> real_value_op_set;
|
||||
for (size_t subgraph_idx = 0; subgraph_idx < model->subgraphs.size();
|
||||
subgraph_idx++) {
|
||||
|
@ -86,8 +91,9 @@ std::unordered_set<string> PopulateRealValueOpSet(
|
|||
for (size_t op_idx = 0; op_idx < subgraph->operators.size(); op_idx++) {
|
||||
OperatorT* op = subgraph->operators[op_idx].get();
|
||||
const string operator_name = subgraph->tensors[op->outputs[0]]->name;
|
||||
operator_property::OperatorProperty property = GetOperatorProperty(
|
||||
operator_names, model, subgraph_idx, op_idx, operator_name);
|
||||
operator_property::OperatorProperty property =
|
||||
GetOperatorProperty(operator_names, model, subgraph_idx, op_idx,
|
||||
operator_name, activations_type);
|
||||
|
||||
if (!property.quantizable) {
|
||||
real_value_op_set.insert(operator_name);
|
||||
|
@ -134,6 +140,7 @@ std::unordered_set<string> PopulateRealValueOpSet(
|
|||
TfLiteStatus QuantizeBias(ModelT* model, const TensorT* input_tensor,
|
||||
const TensorT* weight_tensor, TensorT* bias_tensor,
|
||||
bool is_per_channel, int channel_dim_index,
|
||||
const TensorType& activations_type,
|
||||
ErrorReporter* error_reporter) {
|
||||
if (bias_tensor->shape.size() != 1) {
|
||||
TF_LITE_REPORT_ERROR(error_reporter, "Expected bias tensor shape to be 1.");
|
||||
|
@ -165,9 +172,15 @@ TfLiteStatus QuantizeBias(ModelT* model, const TensorT* input_tensor,
|
|||
weight_scales.size());
|
||||
return kTfLiteError;
|
||||
}
|
||||
return utils::SymmetricPerChannelBiasQuantize(
|
||||
model, bias_tensor, input_tensor->quantization->scale[0],
|
||||
weight_scales.data(), channel_dim_size, error_reporter);
|
||||
if (activations_type == tflite::TensorType_INT16) {
|
||||
return utils::SymmetricPerChannelBiasQuantize<std::int64_t>(
|
||||
model, bias_tensor, input_tensor->quantization->scale[0],
|
||||
weight_scales.data(), channel_dim_size, error_reporter);
|
||||
} else {
|
||||
return utils::SymmetricPerChannelBiasQuantize<std::int32_t>(
|
||||
model, bias_tensor, input_tensor->quantization->scale[0],
|
||||
weight_scales.data(), channel_dim_size, error_reporter);
|
||||
}
|
||||
} else {
|
||||
if (weight_scales.size() != 1) {
|
||||
TF_LITE_REPORT_ERROR(
|
||||
|
@ -176,40 +189,54 @@ TfLiteStatus QuantizeBias(ModelT* model, const TensorT* input_tensor,
|
|||
weight_scales.size());
|
||||
return kTfLiteError;
|
||||
}
|
||||
return utils::SymmetricPerLayerBiasQuantize(
|
||||
model, bias_tensor,
|
||||
input_tensor->quantization->scale[0] * weight_scales[0],
|
||||
error_reporter);
|
||||
if (activations_type == tflite::TensorType_INT16) {
|
||||
return utils::SymmetricPerLayerBiasQuantize<std::int64_t>(
|
||||
model, bias_tensor,
|
||||
input_tensor->quantization->scale[0] * weight_scales[0],
|
||||
error_reporter);
|
||||
} else {
|
||||
return utils::SymmetricPerLayerBiasQuantize<std::int32_t>(
|
||||
model, bias_tensor,
|
||||
input_tensor->quantization->scale[0] * weight_scales[0],
|
||||
error_reporter);
|
||||
}
|
||||
}
|
||||
return kTfLiteError;
|
||||
}
|
||||
|
||||
// True if the tensor type has to be modified.
|
||||
bool TensorTypeChangeRequired(const TensorT* tensor, const TensorType& type) {
|
||||
// The quantized model is type INT8, so if the user provided type is INT8, we
|
||||
// do not have to do any custom logic. Additionally, if the current tensor
|
||||
// isn't INT8 quantized, the custom type doesn't apply.
|
||||
return (type != TensorType_INT8 && tensor->type == TensorType_INT8 &&
|
||||
!tensor->quantization->scale.empty());
|
||||
// The quantized model is type INT8/INT16, so if the user provided type is
|
||||
// INT8/INT16, we do not have to do any custom logic. Additionally, if the
|
||||
// current tensor isn't INT8/INT16 quantized, the custom type doesn't apply.
|
||||
bool int8check = type != TensorType_INT8 && tensor->type == TensorType_INT8 &&
|
||||
!tensor->quantization->scale.empty();
|
||||
bool int16check = type != TensorType_INT16 &&
|
||||
tensor->type == TensorType_INT16 &&
|
||||
!tensor->quantization->scale.empty();
|
||||
return (int8check || int16check);
|
||||
}
|
||||
|
||||
// Sets the input type, adding a Leading Op node at the start of the model if
|
||||
// necessary.
|
||||
// Returns the new input tensor index.
|
||||
int32_t SetInputType(ModelT* model, SubGraphT* subgraph,
|
||||
const int32_t tensor_idx, const TensorType& input_type) {
|
||||
const int32_t tensor_idx, const TensorType& input_type,
|
||||
const TensorType& activations_type) {
|
||||
TensorT* tensor = subgraph->tensors[tensor_idx].get();
|
||||
if (!TensorTypeChangeRequired(tensor, input_type)) {
|
||||
return -1;
|
||||
}
|
||||
if (input_type == TensorType_FLOAT32 || input_type == TensorType_UINT8) {
|
||||
std::string type_string =
|
||||
activations_type == TensorType_INT16 ? "int16" : "int8";
|
||||
// Create a new tensor to be the input of the leading Op.
|
||||
std::unique_ptr<TensorT> leading_op_input;
|
||||
if (input_type == TensorType_FLOAT32) {
|
||||
// Add tensor for quantize operator. Scales and zero points are not
|
||||
// needed.
|
||||
const string leading_op_name = tensor->name;
|
||||
const string new_name_original_input = tensor->name + "_int8";
|
||||
const string new_name_original_input = tensor->name + "_" + type_string;
|
||||
tensor->name = new_name_original_input;
|
||||
utils::MakeTensor(leading_op_name, tensor->shape, tensor->shape_signature,
|
||||
input_type, &leading_op_input);
|
||||
|
@ -224,7 +251,7 @@ int32_t SetInputType(ModelT* model, SubGraphT* subgraph,
|
|||
TFLITE_DCHECK_GE(zero_point, -128);
|
||||
TFLITE_DCHECK_LE(zero_point, 127);
|
||||
const string leading_op_name = tensor->name;
|
||||
const string new_name_original_input = tensor->name + "_int8";
|
||||
const string new_name_original_input = tensor->name + "_" + type_string;
|
||||
tensor->name = new_name_original_input;
|
||||
utils::MakeTensorWithQuantParam(
|
||||
leading_op_name, tensor->shape, tensor->shape_signature, input_type,
|
||||
|
@ -251,17 +278,20 @@ int32_t SetInputType(ModelT* model, SubGraphT* subgraph,
|
|||
// necessary.
|
||||
// Returns the new output tensor index.
|
||||
int32_t SetOutputType(ModelT* model, SubGraphT* subgraph,
|
||||
const int32_t tensor_idx, const TensorType& output_type) {
|
||||
const int32_t tensor_idx, const TensorType& output_type,
|
||||
const TensorType& activations_type) {
|
||||
TensorT* tensor = subgraph->tensors[tensor_idx].get();
|
||||
if (!TensorTypeChangeRequired(tensor, output_type)) {
|
||||
return -1;
|
||||
}
|
||||
if (output_type == TensorType_FLOAT32 || output_type == TensorType_UINT8) {
|
||||
std::string type_string =
|
||||
activations_type == TensorType_INT16 ? "int16" : "int8";
|
||||
// Create a new tensor to be the output of the tailing op.
|
||||
std::unique_ptr<TensorT> tailing_op_output;
|
||||
if (output_type == TensorType_FLOAT32) {
|
||||
const string tailing_op_name = tensor->name;
|
||||
const string new_name_original_output = tensor->name + "_int8";
|
||||
const string new_name_original_output = tensor->name + "_" + type_string;
|
||||
tensor->name = new_name_original_output;
|
||||
utils::MakeTensor(tailing_op_name, tensor->shape, tensor->shape_signature,
|
||||
output_type, &tailing_op_output);
|
||||
|
@ -276,7 +306,7 @@ int32_t SetOutputType(ModelT* model, SubGraphT* subgraph,
|
|||
TFLITE_DCHECK_GE(zero_point, -128);
|
||||
TFLITE_DCHECK_LE(zero_point, 127);
|
||||
const string tailing_op_name = tensor->name;
|
||||
const string new_name_original_output = tensor->name + "_int8";
|
||||
const string new_name_original_output = tensor->name + "_" + type_string;
|
||||
tensor->name = new_name_original_output;
|
||||
utils::MakeTensorWithQuantParam(
|
||||
tailing_op_name, tensor->shape, tensor->shape_signature, output_type,
|
||||
|
@ -312,6 +342,7 @@ int32_t SetOutputType(ModelT* model, SubGraphT* subgraph,
|
|||
// uint8, can be thought as "requant").
|
||||
TfLiteStatus SetInputAndOutputTypes(ModelT* model, const TensorType& input_type,
|
||||
const TensorType& output_type,
|
||||
const TensorType& activations_type,
|
||||
ErrorReporter* error_reporter) {
|
||||
for (int subgraph_idx = 0; subgraph_idx < model->subgraphs.size();
|
||||
subgraph_idx++) {
|
||||
|
@ -328,8 +359,8 @@ TfLiteStatus SetInputAndOutputTypes(ModelT* model, const TensorType& input_type,
|
|||
EnumNameTensorType(tensor->type));
|
||||
return kTfLiteError;
|
||||
}
|
||||
const int32_t input_idx =
|
||||
SetInputType(model, subgraph, subgraph->inputs[i], input_type);
|
||||
const int32_t input_idx = SetInputType(
|
||||
model, subgraph, subgraph->inputs[i], input_type, activations_type);
|
||||
if (input_idx < 0) {
|
||||
continue;
|
||||
}
|
||||
|
@ -346,8 +377,8 @@ TfLiteStatus SetInputAndOutputTypes(ModelT* model, const TensorType& input_type,
|
|||
EnumNameTensorType(tensor->type));
|
||||
return kTfLiteError;
|
||||
}
|
||||
const int32_t output_idx =
|
||||
SetOutputType(model, subgraph, subgraph->outputs[i], output_type);
|
||||
const int32_t output_idx = SetOutputType(
|
||||
model, subgraph, subgraph->outputs[i], output_type, activations_type);
|
||||
if (output_idx < 0) {
|
||||
continue;
|
||||
}
|
||||
|
@ -364,7 +395,7 @@ TfLiteStatus SetInputAndOutputTypes(ModelT* model, const TensorType& input_type,
|
|||
TfLiteStatus ApplyConstraints(
|
||||
ModelT* model, const std::unordered_set<string>& operator_names,
|
||||
const std::unordered_set<string>& real_value_op_set,
|
||||
ErrorReporter* error_reporter) {
|
||||
TensorType activations_type, ErrorReporter* error_reporter) {
|
||||
for (int subgraph_idx = 0; subgraph_idx < model->subgraphs.size();
|
||||
subgraph_idx++) {
|
||||
SubGraphT* subgraph = model->subgraphs.at(subgraph_idx).get();
|
||||
|
@ -372,8 +403,9 @@ TfLiteStatus ApplyConstraints(
|
|||
for (int op_idx = subgraph->operators.size() - 1; op_idx >= 0; op_idx--) {
|
||||
OperatorT* op = subgraph->operators[op_idx].get();
|
||||
const string operator_name = subgraph->tensors[op->outputs[0]]->name;
|
||||
operator_property::OperatorProperty property = GetOperatorProperty(
|
||||
operator_names, model, subgraph_idx, op_idx, operator_name);
|
||||
operator_property::OperatorProperty property =
|
||||
GetOperatorProperty(operator_names, model, subgraph_idx, op_idx,
|
||||
operator_name, activations_type);
|
||||
if (!property.quantizable ||
|
||||
!IsRealValueOp(real_value_op_set, operator_name)) {
|
||||
continue;
|
||||
|
@ -413,7 +445,7 @@ TfLiteStatus ApplyConstraints(
|
|||
const string requant_tensor_name = input_tensor->name + "_requantized";
|
||||
utils::MakeTensorWithQuantParam(
|
||||
requant_tensor_name, input_tensor->shape,
|
||||
input_tensor->shape_signature, TensorType_INT8, output_scale,
|
||||
input_tensor->shape_signature, activations_type, output_scale,
|
||||
output_zp, &additional_tensor);
|
||||
const int32_t additional_tensor_idx = subgraph->tensors.size();
|
||||
subgraph->tensors.push_back(std::move(additional_tensor));
|
||||
|
@ -463,7 +495,8 @@ std::vector<std::pair<int, operator_property::TensorProperty>> GetOutputs(
|
|||
|
||||
bool ShouldRestrictSameInputOutputScale(
|
||||
operator_property::OperatorProperty property) {
|
||||
// Ops with multiple inputs (i.e. concat) gets restricted in ApplyConstraints.
|
||||
// Ops with multiple inputs (i.e. concat, max and min) gets restricted in
|
||||
// ApplyConstraints.
|
||||
return (!property.arbitrary_inputs &&
|
||||
property.restrict_same_input_output_scale);
|
||||
}
|
||||
|
@ -482,7 +515,7 @@ TfLiteStatus QuantizeOpInput(
|
|||
ModelT* model, int32_t subgraph_idx, size_t* op_idx,
|
||||
operator_property::OperatorProperty property,
|
||||
const std::pair<int32_t, operator_property::TensorProperty>& input,
|
||||
ErrorReporter* error_reporter) {
|
||||
const TensorType& activations_type, ErrorReporter* error_reporter) {
|
||||
int32_t input_idx = input.first;
|
||||
operator_property::TensorProperty tensor_property = input.second;
|
||||
SubGraphT* subgraph = model->subgraphs.at(subgraph_idx).get();
|
||||
|
@ -511,7 +544,9 @@ TfLiteStatus QuantizeOpInput(
|
|||
if (utils::HasBuffer(model, subgraph, tensor_idx)) {
|
||||
// TODO(suharshs): Look at consumers, throw error if one consumer is
|
||||
// per-channel and one per-layer.
|
||||
if (tensor_property.number_of_bits == 8) {
|
||||
bool quantize_const_input = property.quantize_input_as_activations &&
|
||||
activations_type == TensorType_INT16;
|
||||
if (tensor_property.number_of_bits == 8 && !quantize_const_input) {
|
||||
if (tensor_property.use_derived_scale) {
|
||||
// Currently 8bit tensors in input do not accept derived scale.
|
||||
return kTfLiteError;
|
||||
|
@ -527,7 +562,7 @@ TfLiteStatus QuantizeOpInput(
|
|||
*op_idx);
|
||||
return kTfLiteError;
|
||||
}
|
||||
} else if (tensor_property.number_of_bits == 16) {
|
||||
} else if (tensor_property.number_of_bits == 16 || quantize_const_input) {
|
||||
if (tensor_property.use_derived_scale) {
|
||||
// Currently 16bit tensors in input do not accept derived scale.
|
||||
return kTfLiteError;
|
||||
|
@ -559,8 +594,8 @@ TfLiteStatus QuantizeOpInput(
|
|||
tensor_property.derived_scale.input_tensors,
|
||||
tensor_property.derived_scale.intermediate_tensors,
|
||||
tensor_property.derived_scale.factors);
|
||||
return utils::SymmetricPerLayerBiasQuantize(model, tensor, scale,
|
||||
error_reporter);
|
||||
return utils::SymmetricPerLayerBiasQuantize<std::int32_t>(
|
||||
model, tensor, scale, error_reporter);
|
||||
|
||||
} else if (tensor_property.number_of_bits == 10) {
|
||||
// When the number of bits is 10 (instead of 16), quantize the tensor to
|
||||
|
@ -598,7 +633,8 @@ TfLiteStatus QuantizeOpInput(
|
|||
// Currently 8bit tensors in input do not accept derived scale.
|
||||
return kTfLiteError;
|
||||
}
|
||||
utils::QuantizeActivation(tensor);
|
||||
TF_LITE_ENSURE_STATUS(utils::QuantizeActivation(
|
||||
tensor, activations_type, error_reporter));
|
||||
} else if (tensor_property.number_of_bits == 16) {
|
||||
TensorT* tensor = subgraph->tensors[tensor_idx].get();
|
||||
float quantized_range = 32767.0;
|
||||
|
@ -616,13 +652,17 @@ TfLiteStatus QuantizeOpInput(
|
|||
} else {
|
||||
// If the tensor is not a model input, we need to add a Quantize
|
||||
// operation since the preceding op may require a float output.
|
||||
std::string type_string =
|
||||
activations_type == TensorType_INT16 ? "int16" : "int8";
|
||||
std::unique_ptr<TensorT> op_output;
|
||||
utils::MakeTensor(tensor->name + "_int8", tensor->shape,
|
||||
tensor->shape_signature, TensorType_INT8, &op_output);
|
||||
utils::MakeTensor(tensor->name + "_" + type_string, tensor->shape,
|
||||
tensor->shape_signature, activations_type,
|
||||
&op_output);
|
||||
op_output->quantization = absl::make_unique<QuantizationParametersT>();
|
||||
op_output->quantization->min.push_back(tensor->quantization->min[0]);
|
||||
op_output->quantization->max.push_back(tensor->quantization->max[0]);
|
||||
utils::QuantizeActivation(op_output.get());
|
||||
TF_LITE_ENSURE_STATUS(utils::QuantizeActivation(
|
||||
op_output.get(), activations_type, error_reporter));
|
||||
const int32_t quant_op_output_idx = subgraph->tensors.size();
|
||||
subgraph->tensors.push_back(std::move(op_output));
|
||||
std::unique_ptr<OperatorT> quant_op;
|
||||
|
@ -665,7 +705,7 @@ TfLiteStatus QuantizeOpOutput(
|
|||
ModelT* model, int32_t subgraph_idx, int32_t op_idx,
|
||||
operator_property::OperatorProperty property,
|
||||
const std::pair<int32_t, operator_property::TensorProperty>& output,
|
||||
ErrorReporter* error_reporter) {
|
||||
TensorType activations_type, ErrorReporter* error_reporter) {
|
||||
int32_t output_idx = output.first;
|
||||
operator_property::TensorProperty tensor_property = output.second;
|
||||
// If the operator is not quantizable, we don't need to do anything for the
|
||||
|
@ -732,18 +772,22 @@ TfLiteStatus QuantizeOpOutput(
|
|||
const float max = input_tensor->quantization->max[0];
|
||||
output_tensor->quantization->max = {max};
|
||||
}
|
||||
output_tensor->type = TensorType_INT8;
|
||||
output_tensor->type = activations_type;
|
||||
} else if (tensor_property.restriction) {
|
||||
const auto scale_and_zp = tensor_property.restricted_value;
|
||||
const auto scale_and_zp = activations_type == TensorType_INT16
|
||||
? tensor_property.restricted_value_int16
|
||||
: tensor_property.restricted_value_int8;
|
||||
|
||||
// Apply to output.
|
||||
output_tensor->quantization = absl::make_unique<QuantizationParametersT>();
|
||||
output_tensor->quantization->scale.push_back(scale_and_zp.first);
|
||||
output_tensor->quantization->zero_point.push_back(scale_and_zp.second);
|
||||
output_tensor->type = TensorType_INT8;
|
||||
output_tensor->type = activations_type;
|
||||
} else {
|
||||
// Process regular output that doesn't have any restrictions.
|
||||
if (utils::HasMinMax(output_tensor)) {
|
||||
utils::QuantizeActivation(output_tensor);
|
||||
utils::QuantizeActivation(output_tensor, activations_type,
|
||||
error_reporter);
|
||||
} else {
|
||||
TF_LITE_REPORT_ERROR(
|
||||
error_reporter,
|
||||
|
@ -757,6 +801,7 @@ TfLiteStatus QuantizeOpOutput(
|
|||
}
|
||||
|
||||
TfLiteStatus QuantizeIntemediateTensors(ModelT* model,
|
||||
TensorType activations_type,
|
||||
ErrorReporter* error_reporter) {
|
||||
for (size_t subgraph_idx = 0; subgraph_idx < model->subgraphs.size();
|
||||
subgraph_idx++) {
|
||||
|
@ -780,7 +825,8 @@ TfLiteStatus QuantizeIntemediateTensors(ModelT* model,
|
|||
input.second.symmetric == false) {
|
||||
TensorT* tensor = subgraph->tensors[index_global].get();
|
||||
if (utils::HasMinMax(tensor)) {
|
||||
utils::QuantizeActivation(tensor);
|
||||
utils::QuantizeActivation(tensor, activations_type,
|
||||
error_reporter);
|
||||
} else {
|
||||
TF_LITE_REPORT_ERROR(
|
||||
error_reporter,
|
||||
|
@ -884,7 +930,7 @@ TfLiteStatus QuantizeWeightsInputOutput(
|
|||
ModelT* model, bool allow_float,
|
||||
const std::unordered_set<string>& operator_names,
|
||||
const std::unordered_set<string>& real_value_op_set,
|
||||
ErrorReporter* error_reporter) {
|
||||
const TensorType& activations_type, ErrorReporter* error_reporter) {
|
||||
for (size_t subgraph_idx = 0; subgraph_idx < model->subgraphs.size();
|
||||
subgraph_idx++) {
|
||||
SubGraphT* subgraph = model->subgraphs.at(subgraph_idx).get();
|
||||
|
@ -893,15 +939,23 @@ TfLiteStatus QuantizeWeightsInputOutput(
|
|||
const BuiltinOperator op_code =
|
||||
model->operator_codes[op->opcode_index]->builtin_code;
|
||||
const string operator_name = subgraph->tensors[op->outputs[0]]->name;
|
||||
operator_property::OperatorProperty property = GetOperatorProperty(
|
||||
operator_names, model, subgraph_idx, op_idx, operator_name);
|
||||
operator_property::OperatorProperty property =
|
||||
GetOperatorProperty(operator_names, model, subgraph_idx, op_idx,
|
||||
operator_name, activations_type);
|
||||
if (!IsRealValueOp(real_value_op_set, operator_name)) {
|
||||
continue;
|
||||
}
|
||||
|
||||
if (!property.quantizable && !allow_float) {
|
||||
if (activations_type == TensorType_INT16 && !property.quantizable &&
|
||||
!allow_float) {
|
||||
TF_LITE_REPORT_ERROR(
|
||||
error_reporter,
|
||||
"Quantization to 16x8-bit not yet supported for op: %",
|
||||
EnumNameBuiltinOperator(op_code));
|
||||
return kTfLiteError;
|
||||
} else if (!property.quantizable && !allow_float) {
|
||||
TF_LITE_REPORT_ERROR(error_reporter,
|
||||
"Quantization not yet supported for op: %s",
|
||||
"Quantization not yet supported for op: %",
|
||||
EnumNameBuiltinOperator(op_code));
|
||||
return kTfLiteError;
|
||||
}
|
||||
|
@ -910,14 +964,16 @@ TfLiteStatus QuantizeWeightsInputOutput(
|
|||
for (const std::pair<int, operator_property::TensorProperty>& input :
|
||||
GetInputs(op, property)) {
|
||||
TF_LITE_ENSURE_STATUS(QuantizeOpInput(model, subgraph_idx, &op_idx,
|
||||
property, input, error_reporter));
|
||||
property, input, activations_type,
|
||||
error_reporter));
|
||||
}
|
||||
|
||||
// Quantize operator outputs.
|
||||
for (const std::pair<int, operator_property::TensorProperty>& output :
|
||||
GetOutputs(op, property)) {
|
||||
TF_LITE_ENSURE_STATUS(QuantizeOpOutput(
|
||||
model, subgraph_idx, op_idx, property, output, error_reporter));
|
||||
TF_LITE_ENSURE_STATUS(
|
||||
QuantizeOpOutput(model, subgraph_idx, op_idx, property, output,
|
||||
activations_type, error_reporter));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -928,6 +984,7 @@ TfLiteStatus QuantizeWeightsInputOutput(
|
|||
TfLiteStatus QuantizeBiases(ModelT* model,
|
||||
const std::unordered_set<string>& operator_names,
|
||||
const std::unordered_set<string>& real_value_op_set,
|
||||
const TensorType& activations_type,
|
||||
ErrorReporter* error_reporter) {
|
||||
for (size_t subgraph_idx = 0; subgraph_idx < model->subgraphs.size();
|
||||
subgraph_idx++) {
|
||||
|
@ -937,8 +994,9 @@ TfLiteStatus QuantizeBiases(ModelT* model,
|
|||
const BuiltinOperator op_code =
|
||||
model->operator_codes[op->opcode_index]->builtin_code;
|
||||
const string operator_name = subgraph->tensors[op->outputs[0]]->name;
|
||||
operator_property::OperatorProperty property = GetOperatorProperty(
|
||||
operator_names, model, subgraph_idx, op_idx, operator_name);
|
||||
operator_property::OperatorProperty property =
|
||||
GetOperatorProperty(operator_names, model, subgraph_idx, op_idx,
|
||||
operator_name, activations_type);
|
||||
if (!property.quantizable ||
|
||||
!IsRealValueOp(real_value_op_set, operator_name)) {
|
||||
continue;
|
||||
|
@ -968,10 +1026,10 @@ TfLiteStatus QuantizeBiases(ModelT* model,
|
|||
subgraph->tensors[op->inputs[property.inputs[1].first]].get();
|
||||
operator_property::TensorProperty weight_property =
|
||||
property.inputs[1].second;
|
||||
TF_LITE_ENSURE_STATUS(
|
||||
QuantizeBias(model, input_tensor, weight_tensor, bias_tensor,
|
||||
weight_property.per_axis,
|
||||
weight_property.per_axis_index, error_reporter));
|
||||
TF_LITE_ENSURE_STATUS(QuantizeBias(
|
||||
model, input_tensor, weight_tensor, bias_tensor,
|
||||
weight_property.per_axis, weight_property.per_axis_index,
|
||||
activations_type, error_reporter));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -1001,15 +1059,16 @@ std::unordered_set<string> GetAllOperatorOutputs(ModelT* model) {
|
|||
TfLiteStatus FillQuantizationParams(
|
||||
ModelT* model, const std::unordered_set<string>& operator_names,
|
||||
const std::unordered_set<string>& real_value_op_set,
|
||||
ErrorReporter* error_reporter) {
|
||||
const TensorType& activations_type, ErrorReporter* error_reporter) {
|
||||
for (size_t subgraph_idx = 0; subgraph_idx < model->subgraphs.size();
|
||||
subgraph_idx++) {
|
||||
SubGraphT* subgraph = model->subgraphs.at(subgraph_idx).get();
|
||||
for (size_t op_idx = 0; op_idx < subgraph->operators.size(); op_idx++) {
|
||||
OperatorT* op = subgraph->operators[op_idx].get();
|
||||
const string operator_name = subgraph->tensors[op->outputs[0]]->name;
|
||||
operator_property::OperatorProperty property = GetOperatorProperty(
|
||||
operator_names, model, subgraph_idx, op_idx, operator_name);
|
||||
operator_property::OperatorProperty property =
|
||||
GetOperatorProperty(operator_names, model, subgraph_idx, op_idx,
|
||||
operator_name, activations_type);
|
||||
if (!IsRealValueOp(real_value_op_set, operator_name)) {
|
||||
continue;
|
||||
}
|
||||
|
@ -1110,15 +1169,16 @@ TfLiteStatus FillQuantizationParams(
|
|||
TfLiteStatus EnsureBiasScaleCompatibility(
|
||||
ModelT* model, const std::unordered_set<string>& operator_names,
|
||||
const std::unordered_set<string>& real_value_op_set,
|
||||
ErrorReporter* error_reporter) {
|
||||
const TensorType& activations_type, ErrorReporter* error_reporter) {
|
||||
for (size_t subgraph_idx = 0; subgraph_idx < model->subgraphs.size();
|
||||
subgraph_idx++) {
|
||||
SubGraphT* subgraph = model->subgraphs.at(subgraph_idx).get();
|
||||
for (size_t op_idx = 0; op_idx < subgraph->operators.size(); op_idx++) {
|
||||
OperatorT* op = subgraph->operators[op_idx].get();
|
||||
const string operator_name = subgraph->tensors[op->outputs[0]]->name;
|
||||
operator_property::OperatorProperty property = GetOperatorProperty(
|
||||
operator_names, model, subgraph_idx, op_idx, operator_name);
|
||||
operator_property::OperatorProperty property =
|
||||
GetOperatorProperty(operator_names, model, subgraph_idx, op_idx,
|
||||
operator_name, activations_type);
|
||||
if (!IsRealValueOp(real_value_op_set, operator_name)) {
|
||||
continue;
|
||||
}
|
||||
|
@ -1166,11 +1226,9 @@ TfLiteStatus EnsureBiasScaleCompatibility(
|
|||
|
||||
// Get input scale for asymmetric quantization.
|
||||
QuantizationParametersT temp_quant_params = QuantizationParametersT();
|
||||
utils::GetAsymmetricQuantizationParams(
|
||||
input_tensor->quantization->min[0],
|
||||
input_tensor->quantization->max[0],
|
||||
std::numeric_limits<int8_t>::min(),
|
||||
std::numeric_limits<int8_t>::max(), &temp_quant_params);
|
||||
TF_LITE_ENSURE_STATUS(
|
||||
utils::GetQuantizationParams(input_tensor, activations_type,
|
||||
&temp_quant_params, error_reporter));
|
||||
if (temp_quant_params.scale.size() != 1) {
|
||||
TF_LITE_REPORT_ERROR(error_reporter,
|
||||
"Unexpected input quantization scale size.");
|
||||
|
@ -1256,23 +1314,30 @@ TfLiteStatus QuantizeModel(flatbuffers::FlatBufferBuilder* builder,
|
|||
ModelT* model, const TensorType& input_type,
|
||||
const TensorType& output_type, bool allow_float,
|
||||
const std::unordered_set<string>& operator_names,
|
||||
const TensorType& activations_type,
|
||||
ErrorReporter* error_reporter) {
|
||||
auto real_value_op_set = PopulateRealValueOpSet(model, operator_names);
|
||||
TF_LITE_ENSURE_STATUS(FillQuantizationParams(
|
||||
model, operator_names, real_value_op_set, error_reporter));
|
||||
TF_LITE_ENSURE_STATUS(EnsureBiasScaleCompatibility(
|
||||
model, operator_names, real_value_op_set, error_reporter));
|
||||
TF_LITE_ENSURE_STATUS(QuantizeIntemediateTensors(model, error_reporter));
|
||||
auto real_value_op_set =
|
||||
PopulateRealValueOpSet(model, operator_names, activations_type);
|
||||
TF_LITE_ENSURE_STATUS(
|
||||
FillQuantizationParams(model, operator_names, real_value_op_set,
|
||||
activations_type, error_reporter));
|
||||
TF_LITE_ENSURE_STATUS(
|
||||
EnsureBiasScaleCompatibility(model, operator_names, real_value_op_set,
|
||||
activations_type, error_reporter));
|
||||
TF_LITE_ENSURE_STATUS(
|
||||
QuantizeIntemediateTensors(model, activations_type, error_reporter));
|
||||
TF_LITE_ENSURE_STATUS(QuantizeSharedRange(model, error_reporter));
|
||||
TF_LITE_ENSURE_STATUS(QuantizeWeightsInputOutput(
|
||||
model, allow_float, operator_names, real_value_op_set, error_reporter));
|
||||
model, allow_float, operator_names, real_value_op_set, activations_type,
|
||||
error_reporter));
|
||||
TF_LITE_ENSURE_STATUS(ApplyConstraints(model, operator_names,
|
||||
real_value_op_set, error_reporter));
|
||||
TF_LITE_ENSURE_STATUS(
|
||||
QuantizeBiases(model, operator_names, real_value_op_set, error_reporter));
|
||||
real_value_op_set, activations_type,
|
||||
error_reporter));
|
||||
TF_LITE_ENSURE_STATUS(QuantizeBiases(model, operator_names, real_value_op_set,
|
||||
activations_type, error_reporter));
|
||||
utils::SetOperatorCodeVersion(model);
|
||||
TF_LITE_ENSURE_STATUS(
|
||||
SetInputAndOutputTypes(model, input_type, output_type, error_reporter));
|
||||
TF_LITE_ENSURE_STATUS(SetInputAndOutputTypes(
|
||||
model, input_type, output_type, activations_type, error_reporter));
|
||||
|
||||
flatbuffers::Offset<Model> output_model_location =
|
||||
Model::Pack(*builder, model);
|
||||
|
@ -1281,12 +1346,25 @@ TfLiteStatus QuantizeModel(flatbuffers::FlatBufferBuilder* builder,
|
|||
return kTfLiteOk;
|
||||
}
|
||||
|
||||
TfLiteStatus QuantizeModelAllOperators(flatbuffers::FlatBufferBuilder* builder,
|
||||
ModelT* model,
|
||||
const TensorType& input_type,
|
||||
const TensorType& output_type,
|
||||
bool allow_float,
|
||||
const TensorType& activations_type,
|
||||
ErrorReporter* error_reporter) {
|
||||
return QuantizeModel(builder, model, input_type, output_type, allow_float,
|
||||
GetAllOperatorOutputs(model), activations_type,
|
||||
error_reporter);
|
||||
}
|
||||
|
||||
TfLiteStatus QuantizeModel(flatbuffers::FlatBufferBuilder* builder,
|
||||
ModelT* model, const TensorType& input_type,
|
||||
const TensorType& output_type, bool allow_float,
|
||||
ErrorReporter* error_reporter) {
|
||||
return QuantizeModel(builder, model, input_type, output_type, allow_float,
|
||||
GetAllOperatorOutputs(model), error_reporter);
|
||||
GetAllOperatorOutputs(model), TensorType_INT8,
|
||||
error_reporter);
|
||||
}
|
||||
|
||||
TfLiteStatus QuantizeModel(flatbuffers::FlatBufferBuilder* builder,
|
||||
|
|
|
@ -65,6 +65,28 @@ TfLiteStatus QuantizeModel(flatbuffers::FlatBufferBuilder* builder,
|
|||
const std::unordered_set<string>& operator_names,
|
||||
ErrorReporter* error_reporter);
|
||||
|
||||
// Same as above, but enables to provide activation type, which
|
||||
// could be TensorType_INT16 or TensorType_INT8.
|
||||
//
|
||||
// Note: This is a private API, subject to change.
|
||||
TfLiteStatus QuantizeModelAllOperators(flatbuffers::FlatBufferBuilder* builder,
|
||||
ModelT* model,
|
||||
const TensorType& input_type,
|
||||
const TensorType& output_type,
|
||||
bool allow_float,
|
||||
const TensorType& activations_type,
|
||||
ErrorReporter* error_reporter);
|
||||
|
||||
// Quantizes input_model and populates the provided builder with the new model
|
||||
// with all possible input parameters.
|
||||
// All functions above call this function underneath.
|
||||
TfLiteStatus QuantizeModel(flatbuffers::FlatBufferBuilder* builder,
|
||||
ModelT* model, const TensorType& input_type,
|
||||
const TensorType& output_type, bool allow_float,
|
||||
const std::unordered_set<string>& operator_names,
|
||||
const TensorType& activations_type,
|
||||
ErrorReporter* error_reporter);
|
||||
|
||||
} // namespace optimize
|
||||
} // namespace tflite
|
||||
|
||||
|
|
|
@ -80,28 +80,36 @@ class QuantizeModelTest : public testing::Test {
|
|||
internal::FailOnErrorReporter error_reporter_;
|
||||
};
|
||||
|
||||
class QuantizeConvModelTest : public QuantizeModelTest {
|
||||
class QuantizeConvModelTest : public QuantizeModelTest,
|
||||
public testing::WithParamInterface<TensorType> {
|
||||
protected:
|
||||
QuantizeConvModelTest() {
|
||||
tensor_type_ = GetParam();
|
||||
input_model_ = ReadModel(internal::kConvModelWith0Plus10Weights);
|
||||
readonly_model_ = input_model_->GetModel();
|
||||
readonly_model_->UnPackTo(&model_);
|
||||
}
|
||||
TensorType tensor_type_;
|
||||
};
|
||||
|
||||
TEST_F(QuantizeConvModelTest, QuantizationSucceeds) {
|
||||
auto status = QuantizeModel(&builder_, &model_, TensorType_INT8,
|
||||
TensorType_INT8, &error_reporter_);
|
||||
INSTANTIATE_TEST_SUITE_P(QuantizeConvModelTestInst, QuantizeConvModelTest,
|
||||
testing::ValuesIn({TensorType_INT8,
|
||||
TensorType_INT16}));
|
||||
|
||||
TEST_P(QuantizeConvModelTest, QuantizationSucceeds) {
|
||||
auto status =
|
||||
QuantizeModelAllOperators(&builder_, &model_, tensor_type_, tensor_type_,
|
||||
false, tensor_type_, &error_reporter_);
|
||||
EXPECT_EQ(status, kTfLiteOk);
|
||||
const uint8_t* buffer = builder_.GetBufferPointer();
|
||||
const Model* output_model = GetModel(buffer);
|
||||
ASSERT_TRUE(output_model);
|
||||
}
|
||||
|
||||
TEST_F(QuantizeConvModelTest, SkipUnspecifiedLayer) {
|
||||
auto status =
|
||||
QuantizeModel(&builder_, &model_, TensorType_FLOAT32, TensorType_FLOAT32,
|
||||
/*allow_float=*/true, {}, &error_reporter_);
|
||||
TEST_P(QuantizeConvModelTest, SkipUnspecifiedLayer) {
|
||||
auto status = QuantizeModel(
|
||||
&builder_, &model_, TensorType_FLOAT32, TensorType_FLOAT32,
|
||||
/*allow_float=*/true, {}, TensorType_FLOAT32, &error_reporter_);
|
||||
EXPECT_EQ(status, kTfLiteOk);
|
||||
ASSERT_EQ(model_.subgraphs.size(), readonly_model_->subgraphs()->size());
|
||||
// The resulting model should be the same.
|
||||
|
@ -123,9 +131,10 @@ TEST_F(QuantizeConvModelTest, SkipUnspecifiedLayer) {
|
|||
}
|
||||
}
|
||||
|
||||
TEST_F(QuantizeConvModelTest, TensorShapesAndStructureIsUnchanged) {
|
||||
auto status = QuantizeModel(&builder_, &model_, TensorType_INT8,
|
||||
TensorType_INT8, &error_reporter_);
|
||||
TEST_P(QuantizeConvModelTest, TensorShapesAndStructureIsUnchanged) {
|
||||
auto status =
|
||||
QuantizeModelAllOperators(&builder_, &model_, tensor_type_, tensor_type_,
|
||||
false, tensor_type_, &error_reporter_);
|
||||
EXPECT_EQ(status, kTfLiteOk);
|
||||
ASSERT_EQ(model_.subgraphs.size(), readonly_model_->subgraphs()->size());
|
||||
for (size_t subgraph_idx = 0; subgraph_idx < model_.subgraphs.size();
|
||||
|
@ -148,9 +157,10 @@ TEST_F(QuantizeConvModelTest, TensorShapesAndStructureIsUnchanged) {
|
|||
EXPECT_EQ(model_.operator_codes[0]->version, 3);
|
||||
}
|
||||
|
||||
TEST_F(QuantizeConvModelTest, OperatorsAreUnchanged) {
|
||||
auto status = QuantizeModel(&builder_, &model_, TensorType_INT8,
|
||||
TensorType_INT8, &error_reporter_);
|
||||
TEST_P(QuantizeConvModelTest, OperatorsAreUnchanged) {
|
||||
auto status =
|
||||
QuantizeModelAllOperators(&builder_, &model_, tensor_type_, tensor_type_,
|
||||
false, tensor_type_, &error_reporter_);
|
||||
EXPECT_EQ(status, kTfLiteOk);
|
||||
ASSERT_EQ(model_.operator_codes.size(),
|
||||
readonly_model_->operator_codes()->size());
|
||||
|
@ -182,20 +192,29 @@ TEST_F(QuantizeConvModelTest, OperatorsAreUnchanged) {
|
|||
}
|
||||
}
|
||||
|
||||
TEST_F(QuantizeConvModelTest, GraphIsFullyQuantized) {
|
||||
auto status = QuantizeModel(&builder_, &model_, TensorType_INT8,
|
||||
TensorType_INT8, &error_reporter_);
|
||||
TEST_P(QuantizeConvModelTest, GraphIsFullyQuantized) {
|
||||
auto status = QuantizeModelAllOperators(
|
||||
&builder_, &model_, tensor_type_, tensor_type_,
|
||||
/*allow_float*/ false, tensor_type_, &error_reporter_);
|
||||
EXPECT_EQ(status, kTfLiteOk);
|
||||
for (const auto& subgraph : model_.subgraphs) {
|
||||
for (const auto& tensor : subgraph->tensors) {
|
||||
EXPECT_TRUE(tensor->type == TensorType_INT32 ||
|
||||
tensor->type == TensorType_INT8);
|
||||
if (tensor_type_ == TensorType_INT8) {
|
||||
EXPECT_TRUE(tensor->type == TensorType_INT32 ||
|
||||
tensor->type == TensorType_INT8);
|
||||
} else if (tensor_type_ == TensorType_INT16) {
|
||||
EXPECT_TRUE(tensor->type == TensorType_INT64 || // bias
|
||||
tensor->type == TensorType_INT8 || // weights
|
||||
tensor->type == TensorType_INT16); // activations
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
TEST_F(QuantizeConvModelTest, FloatInputAndOutput) {
|
||||
auto status = QuantizeModel(&builder_, &model_, &error_reporter_);
|
||||
TEST_P(QuantizeConvModelTest, FloatInputAndOutput) {
|
||||
auto status = QuantizeModelAllOperators(
|
||||
&builder_, &model_, TensorType_FLOAT32, TensorType_FLOAT32,
|
||||
/*allow_float*/ false, tensor_type_, &error_reporter_);
|
||||
EXPECT_EQ(status, kTfLiteOk);
|
||||
|
||||
for (int32_t subgraph_idx = 0; subgraph_idx < model_.subgraphs.size();
|
||||
|
@ -234,22 +253,33 @@ TEST_F(QuantizeConvModelTest, FloatInputAndOutput) {
|
|||
EXPECT_EQ(subgraph->tensors[output_idx]->type, TensorType_FLOAT32);
|
||||
EXPECT_EQ(subgraph->tensors[output_idx]->name, "output");
|
||||
// The original input and output has been renamed.
|
||||
EXPECT_EQ(subgraph->tensors[quant_op->outputs[0]]->name, "input_int8");
|
||||
EXPECT_EQ(subgraph->tensors[dequant_op->inputs[0]]->name, "output_int8");
|
||||
std::string control_suffix =
|
||||
(tensor_type_ == TensorType_INT16) ? "int16" : "int8";
|
||||
EXPECT_EQ(subgraph->tensors[quant_op->outputs[0]]->name,
|
||||
"input_" + control_suffix);
|
||||
EXPECT_EQ(subgraph->tensors[dequant_op->inputs[0]]->name,
|
||||
"output_" + control_suffix);
|
||||
for (int tensor_idx = 0; tensor_idx < subgraph->tensors.size();
|
||||
++tensor_idx) {
|
||||
const auto& tensor = subgraph->tensors[tensor_idx];
|
||||
if (input_idx != tensor_idx && output_idx != tensor_idx) {
|
||||
EXPECT_TRUE(tensor->type == TensorType_INT32 ||
|
||||
tensor->type == TensorType_INT8);
|
||||
if (tensor_type_ == TensorType_INT8) {
|
||||
EXPECT_TRUE(tensor->type == TensorType_INT32 ||
|
||||
tensor->type == TensorType_INT8);
|
||||
} else if (tensor_type_ == TensorType_INT16) {
|
||||
EXPECT_TRUE(tensor->type == TensorType_INT64 || // bias
|
||||
tensor->type == TensorType_INT8 || // weights
|
||||
tensor->type == TensorType_INT16); // activations
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
TEST_F(QuantizeConvModelTest, Uint8InputAndOutput) {
|
||||
auto status = QuantizeModel(&builder_, &model_, TensorType_UINT8,
|
||||
TensorType_UINT8, &error_reporter_);
|
||||
TEST_P(QuantizeConvModelTest, Uint8InputAndOutput) {
|
||||
auto status = QuantizeModelAllOperators(&builder_, &model_, TensorType_UINT8,
|
||||
TensorType_UINT8, false,
|
||||
TensorType_INT8, &error_reporter_);
|
||||
EXPECT_EQ(status, kTfLiteOk);
|
||||
|
||||
for (int32_t subgraph_idx = 0; subgraph_idx < model_.subgraphs.size();
|
||||
|
@ -326,21 +356,27 @@ class QuantizeConvNoBiasModelTest : public QuantizeModelTest {
|
|||
};
|
||||
|
||||
TEST_F(QuantizeConvNoBiasModelTest, QuantizationSucceeds) {
|
||||
auto status = QuantizeModel(&builder_, &model_, TensorType_INT8,
|
||||
TensorType_INT8, &error_reporter_);
|
||||
auto status = QuantizeModelAllOperators(&builder_, &model_, TensorType_INT8,
|
||||
TensorType_INT8, false,
|
||||
TensorType_INT8, &error_reporter_);
|
||||
EXPECT_EQ(status, kTfLiteOk);
|
||||
const uint8_t* buffer = builder_.GetBufferPointer();
|
||||
const Model* output_model = GetModel(buffer);
|
||||
ASSERT_TRUE(output_model);
|
||||
}
|
||||
|
||||
class QuantizeConcatModelTest : public QuantizeModelTest {
|
||||
class QuantizeConcatModelTest : public QuantizeModelTest,
|
||||
public testing::WithParamInterface<TensorType> {
|
||||
protected:
|
||||
QuantizeConcatModelTest() {
|
||||
input_model_ = ReadModel(internal::kFloatConcatMax5Max10Max10);
|
||||
readonly_model_ = input_model_->GetModel();
|
||||
readonly_model_->UnPackTo(&model_);
|
||||
}
|
||||
|
||||
void SetUp() override { tensor_type_ = GetParam(); }
|
||||
|
||||
TensorType tensor_type_;
|
||||
};
|
||||
|
||||
// There are two inputs for concat, "input0" and "input1". "input0" has [0, 5]
|
||||
|
@ -352,9 +388,10 @@ class QuantizeConcatModelTest : public QuantizeModelTest {
|
|||
// input0 -> requant -> input0_requant \
|
||||
// concat - output
|
||||
// input1 /
|
||||
TEST_F(QuantizeConcatModelTest, AddRequantBeforeConcat) {
|
||||
auto status = QuantizeModel(&builder_, &model_, TensorType_INT8,
|
||||
TensorType_INT8, &error_reporter_);
|
||||
TEST_P(QuantizeConcatModelTest, AddRequantBeforeConcat) {
|
||||
auto status =
|
||||
QuantizeModelAllOperators(&builder_, &model_, tensor_type_, tensor_type_,
|
||||
false, tensor_type_, &error_reporter_);
|
||||
EXPECT_EQ(status, kTfLiteOk);
|
||||
|
||||
// There is only one subgraph.
|
||||
|
@ -373,32 +410,51 @@ TEST_F(QuantizeConcatModelTest, AddRequantBeforeConcat) {
|
|||
EXPECT_EQ(model_.operator_codes[concat->opcode_index]->builtin_code,
|
||||
BuiltinOperator_CONCATENATION);
|
||||
|
||||
auto zero_point_control = tensor_type_ == TensorType_INT8 ? -128 : 0;
|
||||
/*
|
||||
input0_scale_control
|
||||
INT8: (5-0) / (2^8 - 1)
|
||||
INT16: (5-0) / (2^16 / 2 - 1)
|
||||
input1_scale
|
||||
INT8: (10-0) / (2^8 - 1)
|
||||
INT16: (10-0) / (2^16 / 2 - 1)
|
||||
*/
|
||||
auto input0_scale_control =
|
||||
tensor_type_ == TensorType_INT8 ? 0.019607844 : 0.00015259254;
|
||||
auto input1_scale =
|
||||
tensor_type_ == TensorType_INT8 ? 0.039215688 : 0.00030518509;
|
||||
|
||||
// There should be 4 tensors: input0, input1, input0_requantized, output.
|
||||
EXPECT_EQ(subgraph->tensors.size(), 4);
|
||||
EXPECT_EQ(subgraph->tensors[0]->type, TensorType_INT8);
|
||||
EXPECT_EQ(subgraph->tensors[0]->type, tensor_type_);
|
||||
EXPECT_EQ(subgraph->tensors[0]->name, "input0");
|
||||
EXPECT_EQ(subgraph->tensors[0]->quantization->scale.size(), 1);
|
||||
EXPECT_EQ(subgraph->tensors[0]->quantization->zero_point.size(), 1);
|
||||
EXPECT_FLOAT_EQ(subgraph->tensors[0]->quantization->scale[0], 0.019607844);
|
||||
EXPECT_FLOAT_EQ(subgraph->tensors[0]->quantization->zero_point[0], -128);
|
||||
EXPECT_EQ(subgraph->tensors[1]->type, TensorType_INT8);
|
||||
EXPECT_FLOAT_EQ(subgraph->tensors[0]->quantization->scale[0],
|
||||
input0_scale_control);
|
||||
EXPECT_FLOAT_EQ(subgraph->tensors[0]->quantization->zero_point[0],
|
||||
zero_point_control);
|
||||
EXPECT_EQ(subgraph->tensors[1]->type, tensor_type_);
|
||||
EXPECT_EQ(subgraph->tensors[1]->name, "input1");
|
||||
EXPECT_EQ(subgraph->tensors[1]->quantization->scale.size(), 1);
|
||||
EXPECT_EQ(subgraph->tensors[1]->quantization->zero_point.size(), 1);
|
||||
EXPECT_FLOAT_EQ(subgraph->tensors[1]->quantization->scale[0], 0.039215688);
|
||||
EXPECT_FLOAT_EQ(subgraph->tensors[1]->quantization->zero_point[0], -128);
|
||||
EXPECT_EQ(subgraph->tensors[2]->type, TensorType_INT8);
|
||||
EXPECT_FLOAT_EQ(subgraph->tensors[1]->quantization->scale[0], input1_scale);
|
||||
EXPECT_FLOAT_EQ(subgraph->tensors[1]->quantization->zero_point[0],
|
||||
zero_point_control);
|
||||
EXPECT_EQ(subgraph->tensors[2]->type, tensor_type_);
|
||||
EXPECT_EQ(subgraph->tensors[2]->name, "output");
|
||||
EXPECT_EQ(subgraph->tensors[2]->quantization->scale.size(), 1);
|
||||
EXPECT_EQ(subgraph->tensors[2]->quantization->zero_point.size(), 1);
|
||||
EXPECT_FLOAT_EQ(subgraph->tensors[2]->quantization->scale[0], 0.039215688);
|
||||
EXPECT_FLOAT_EQ(subgraph->tensors[2]->quantization->zero_point[0], -128);
|
||||
EXPECT_EQ(subgraph->tensors[3]->type, TensorType_INT8);
|
||||
EXPECT_FLOAT_EQ(subgraph->tensors[2]->quantization->scale[0], input1_scale);
|
||||
EXPECT_FLOAT_EQ(subgraph->tensors[2]->quantization->zero_point[0],
|
||||
zero_point_control);
|
||||
EXPECT_EQ(subgraph->tensors[3]->type, tensor_type_);
|
||||
EXPECT_EQ(subgraph->tensors[3]->name, "input0_requantized");
|
||||
EXPECT_EQ(subgraph->tensors[3]->quantization->scale.size(), 1);
|
||||
EXPECT_EQ(subgraph->tensors[3]->quantization->zero_point.size(), 1);
|
||||
EXPECT_FLOAT_EQ(subgraph->tensors[3]->quantization->scale[0], 0.039215688);
|
||||
EXPECT_FLOAT_EQ(subgraph->tensors[3]->quantization->zero_point[0], -128);
|
||||
EXPECT_FLOAT_EQ(subgraph->tensors[3]->quantization->scale[0], input1_scale);
|
||||
EXPECT_FLOAT_EQ(subgraph->tensors[3]->quantization->zero_point[0],
|
||||
zero_point_control);
|
||||
|
||||
// The connection should be what is described in the comment.
|
||||
EXPECT_EQ(requant->inputs.size(), 1);
|
||||
|
@ -419,7 +475,9 @@ TEST_F(QuantizeConcatModelTest, AddRequantBeforeConcat) {
|
|||
EXPECT_EQ(model_.operator_codes[1]->builtin_code, BuiltinOperator_QUANTIZE);
|
||||
EXPECT_EQ(model_.operator_codes[1]->version, 2);
|
||||
}
|
||||
|
||||
INSTANTIATE_TEST_SUITE_P(QuantizeConcatModelInst, QuantizeConcatModelTest,
|
||||
testing::ValuesIn({TensorType_INT8,
|
||||
TensorType_INT16}));
|
||||
class QuantizeSplitModelTest : public QuantizeModelTest {
|
||||
protected:
|
||||
QuantizeSplitModelTest() {
|
||||
|
@ -432,8 +490,9 @@ class QuantizeSplitModelTest : public QuantizeModelTest {
|
|||
// There are two outputs for split with different scales, the resulting model
|
||||
// should have the scales be hardcodes to the input scale value.
|
||||
TEST_F(QuantizeSplitModelTest, QuantizeSplit) {
|
||||
auto status = QuantizeModel(&builder_, &model_, TensorType_INT8,
|
||||
TensorType_INT8, &error_reporter_);
|
||||
auto status = QuantizeModelAllOperators(&builder_, &model_, TensorType_INT8,
|
||||
TensorType_INT8, false,
|
||||
TensorType_INT8, &error_reporter_);
|
||||
EXPECT_EQ(status, kTfLiteOk);
|
||||
|
||||
// There is only one subgraph.
|
||||
|
@ -496,8 +555,9 @@ class QuantizeConvModel1Test : public QuantizeModelTest {
|
|||
};
|
||||
|
||||
TEST_F(QuantizeConvModel1Test, VerifyConvQuantizationWithUnitScale) {
|
||||
auto status = QuantizeModel(&builder_, &model_, TensorType_INT8,
|
||||
TensorType_INT8, &error_reporter_);
|
||||
auto status = QuantizeModelAllOperators(&builder_, &model_, TensorType_INT8,
|
||||
TensorType_INT8, false,
|
||||
TensorType_INT8, &error_reporter_);
|
||||
EXPECT_EQ(status, kTfLiteOk);
|
||||
const auto& subgraph = model_.subgraphs[0];
|
||||
|
||||
|
@ -587,18 +647,26 @@ TEST_F(QuantizeConvModel1Test, VerifyConvQuantizationWithUnitScale) {
|
|||
EXPECT_EQ(model_.operator_codes[0]->version, 3);
|
||||
}
|
||||
|
||||
class QuantizeConvModel2Test : public QuantizeModelTest {
|
||||
class QuantizeConvModel2Test : public QuantizeModelTest,
|
||||
public testing::WithParamInterface<TensorType> {
|
||||
protected:
|
||||
QuantizeConvModel2Test() {
|
||||
tensor_type_ = GetParam();
|
||||
input_model_ = ReadModel(internal::kConvModelWith0Plus10Weights);
|
||||
readonly_model_ = input_model_->GetModel();
|
||||
readonly_model_->UnPackTo(&model_);
|
||||
}
|
||||
};
|
||||
|
||||
TEST_F(QuantizeConvModel2Test, VerifyConvQuantization) {
|
||||
auto status = QuantizeModel(&builder_, &model_, TensorType_INT8,
|
||||
TensorType_INT8, &error_reporter_);
|
||||
TensorType tensor_type_;
|
||||
};
|
||||
INSTANTIATE_TEST_SUITE_P(QuantizeConvModel2TestInst, QuantizeConvModel2Test,
|
||||
testing::ValuesIn({TensorType_INT8,
|
||||
TensorType_INT16}));
|
||||
|
||||
TEST_P(QuantizeConvModel2Test, VerifyConvQuantization) {
|
||||
auto status =
|
||||
QuantizeModelAllOperators(&builder_, &model_, tensor_type_, tensor_type_,
|
||||
false, tensor_type_, &error_reporter_);
|
||||
ASSERT_EQ(kTfLiteOk, status);
|
||||
const auto& subgraph = model_.subgraphs[0];
|
||||
auto conv_op = subgraph->operators[0].get();
|
||||
|
@ -615,8 +683,10 @@ TEST_F(QuantizeConvModel2Test, VerifyConvQuantization) {
|
|||
const auto output_tensor =
|
||||
subgraph->tensors[conv_op->outputs[output_tensor_idx]].get();
|
||||
|
||||
EXPECT_EQ(bias_tensor->type, TensorType_INT32);
|
||||
EXPECT_EQ(input_tensor->type, TensorType_INT8);
|
||||
EXPECT_EQ(bias_tensor->type, tensor_type_ == TensorType_INT8
|
||||
? TensorType_INT32
|
||||
: TensorType_INT64);
|
||||
EXPECT_EQ(input_tensor->type, tensor_type_);
|
||||
EXPECT_EQ(weights_tensor->type, TensorType_INT8);
|
||||
|
||||
ASSERT_TRUE(weights_tensor->quantization);
|
||||
|
@ -644,17 +714,28 @@ TEST_F(QuantizeConvModel2Test, VerifyConvQuantization) {
|
|||
}
|
||||
|
||||
const auto bias_buffer = model_.buffers[bias_tensor->buffer].get();
|
||||
ASSERT_EQ(bias_buffer->data.size(), sizeof(int32_t) * bias_tensor->shape[0]);
|
||||
const int32_t* bias_values =
|
||||
reinterpret_cast<int32_t*>(bias_buffer->data.data());
|
||||
auto control_size = tensor_type_ == TensorType_INT8
|
||||
? sizeof(int32_t) * bias_tensor->shape[0]
|
||||
: sizeof(int64_t) * bias_tensor->shape[0];
|
||||
|
||||
ASSERT_EQ(bias_buffer->data.size(), control_size);
|
||||
const auto original_bias_buffer =
|
||||
readonly_model_->buffers()->Get(bias_tensor->buffer);
|
||||
const float* bias_float_buffer =
|
||||
reinterpret_cast<const float*>(original_bias_buffer->data()->data());
|
||||
|
||||
for (size_t i = 0; i < out_channel_size; i++) {
|
||||
auto dequantized_value = bias_values[i] * bias_scales[i];
|
||||
EXPECT_NEAR(dequantized_value, bias_float_buffer[i], bias_scales[i] / 2);
|
||||
if (tensor_type_ == TensorType_INT8) {
|
||||
int32_t* bias_values = reinterpret_cast<int32_t*>(bias_buffer->data.data());
|
||||
for (size_t i = 0; i < out_channel_size; i++) {
|
||||
auto dequantized_value = bias_values[i] * bias_scales[i];
|
||||
EXPECT_NEAR(dequantized_value, bias_float_buffer[i], bias_scales[i] / 2);
|
||||
}
|
||||
} else if (tensor_type_ == TensorType_INT16) {
|
||||
int64_t* bias_values = reinterpret_cast<int64_t*>(bias_buffer->data.data());
|
||||
for (size_t i = 0; i < out_channel_size; i++) {
|
||||
auto dequantized_value = bias_values[i] * bias_scales[i];
|
||||
EXPECT_NEAR(dequantized_value, bias_float_buffer[i], bias_scales[i] / 2);
|
||||
}
|
||||
}
|
||||
|
||||
const auto weights_buffer = model_.buffers[weights_tensor->buffer].get();
|
||||
|
@ -695,8 +776,9 @@ class QuantizeSoftmaxTest : public QuantizeModelTest {
|
|||
};
|
||||
|
||||
TEST_F(QuantizeSoftmaxTest, VerifySoftmaxQuantization) {
|
||||
auto status = QuantizeModel(&builder_, &model_, TensorType_INT8,
|
||||
TensorType_INT8, &error_reporter_);
|
||||
auto status = QuantizeModelAllOperators(&builder_, &model_, TensorType_INT8,
|
||||
TensorType_INT8, false,
|
||||
TensorType_INT8, &error_reporter_);
|
||||
ASSERT_EQ(kTfLiteOk, status);
|
||||
|
||||
const auto& subgraph = model_.subgraphs[0];
|
||||
|
@ -755,8 +837,9 @@ class QuantizeAvgPoolTest : public QuantizeModelTest {
|
|||
};
|
||||
|
||||
TEST_F(QuantizeAvgPoolTest, VerifyAvgPoolQuantization) {
|
||||
auto status = QuantizeModel(&builder_, &model_, TensorType_INT8,
|
||||
TensorType_INT8, &error_reporter_);
|
||||
auto status = QuantizeModelAllOperators(&builder_, &model_, TensorType_INT8,
|
||||
TensorType_INT8, false,
|
||||
TensorType_INT8, &error_reporter_);
|
||||
ASSERT_EQ(kTfLiteOk, status);
|
||||
|
||||
const auto& subgraph = model_.subgraphs[0];
|
||||
|
@ -816,8 +899,9 @@ class QuantizeMultiInputAddWithReshapeTest : public QuantizeModelTest {
|
|||
};
|
||||
|
||||
TEST_F(QuantizeMultiInputAddWithReshapeTest, VerifyReshapeQuantization) {
|
||||
auto status = QuantizeModel(&builder_, &model_, TensorType_INT8,
|
||||
TensorType_INT8, &error_reporter_);
|
||||
auto status = QuantizeModelAllOperators(&builder_, &model_, TensorType_INT8,
|
||||
TensorType_INT8, false,
|
||||
TensorType_INT8, &error_reporter_);
|
||||
ASSERT_EQ(kTfLiteOk, status);
|
||||
|
||||
// Verify Reshape is quantized.
|
||||
|
@ -863,8 +947,9 @@ TEST_F(QuantizeMultiInputAddWithReshapeTest, VerifyReshapeQuantization) {
|
|||
}
|
||||
|
||||
TEST_F(QuantizeMultiInputAddWithReshapeTest, VerifyAddQuantization) {
|
||||
auto status = QuantizeModel(&builder_, &model_, TensorType_INT8,
|
||||
TensorType_INT8, &error_reporter_);
|
||||
auto status = QuantizeModelAllOperators(&builder_, &model_, TensorType_INT8,
|
||||
TensorType_INT8, false,
|
||||
TensorType_INT8, &error_reporter_);
|
||||
ASSERT_EQ(kTfLiteOk, status);
|
||||
|
||||
// Verify ADD is quantized.
|
||||
|
@ -923,8 +1008,9 @@ class QuantizeConstInputTest : public QuantizeModelTest {
|
|||
};
|
||||
|
||||
TEST_F(QuantizeConstInputTest, VerifyConstOpInput) {
|
||||
auto status = QuantizeModel(&builder_, &model_, TensorType_INT8,
|
||||
TensorType_INT8, &error_reporter_);
|
||||
auto status = QuantizeModelAllOperators(&builder_, &model_, TensorType_INT8,
|
||||
TensorType_INT8, false,
|
||||
TensorType_INT8, &error_reporter_);
|
||||
ASSERT_EQ(kTfLiteOk, status);
|
||||
|
||||
// Verify ConstOp is quantized.
|
||||
|
@ -965,8 +1051,9 @@ class QuantizeArgMaxTest : public QuantizeModelTest {
|
|||
};
|
||||
|
||||
TEST_F(QuantizeArgMaxTest, VerifyArgMax) {
|
||||
auto status = QuantizeModel(&builder_, &model_, TensorType_INT8,
|
||||
TensorType_INT8, &error_reporter_);
|
||||
auto status = QuantizeModelAllOperators(&builder_, &model_, TensorType_INT8,
|
||||
TensorType_INT8, false,
|
||||
TensorType_INT8, &error_reporter_);
|
||||
ASSERT_EQ(kTfLiteOk, status);
|
||||
|
||||
const auto& subgraph = model_.subgraphs[0];
|
||||
|
@ -1008,8 +1095,9 @@ class QuantizeLSTMTest : public QuantizeModelTest {
|
|||
|
||||
TEST_F(QuantizeLSTMTest, VerifyLSTM) {
|
||||
// Quantize model.
|
||||
auto status = QuantizeModel(&builder_, &model_, TensorType_FLOAT32,
|
||||
TensorType_FLOAT32, &error_reporter_);
|
||||
auto status = QuantizeModelAllOperators(
|
||||
&builder_, &model_, TensorType_FLOAT32, TensorType_FLOAT32, false,
|
||||
TensorType_INT8, &error_reporter_);
|
||||
ASSERT_EQ(kTfLiteOk, status);
|
||||
|
||||
// Read expected model.
|
||||
|
@ -1067,8 +1155,9 @@ class QuantizeLSTM2Test : public QuantizeModelTest {
|
|||
|
||||
TEST_F(QuantizeLSTM2Test, VerifyLSTM) {
|
||||
// Quantize model.
|
||||
auto status = QuantizeModel(&builder_, &model_, TensorType_FLOAT32,
|
||||
TensorType_FLOAT32, &error_reporter_);
|
||||
auto status = QuantizeModelAllOperators(
|
||||
&builder_, &model_, TensorType_FLOAT32, TensorType_FLOAT32, false,
|
||||
TensorType_INT8, &error_reporter_);
|
||||
ASSERT_EQ(kTfLiteOk, status);
|
||||
|
||||
// Read expected model.
|
||||
|
@ -1126,8 +1215,9 @@ class QuantizeSVDFTest : public QuantizeModelTest {
|
|||
|
||||
TEST_F(QuantizeSVDFTest, VerifySVDF) {
|
||||
// Quantize model.
|
||||
auto status = QuantizeModel(&builder_, &model_, TensorType_INT8,
|
||||
TensorType_INT8, &error_reporter_);
|
||||
auto status = QuantizeModelAllOperators(&builder_, &model_, TensorType_INT8,
|
||||
TensorType_INT8, false,
|
||||
TensorType_INT8, &error_reporter_);
|
||||
ASSERT_EQ(kTfLiteOk, status);
|
||||
|
||||
// Read expected model.
|
||||
|
@ -1184,8 +1274,9 @@ class QuantizeFCTest : public QuantizeModelTest {
|
|||
};
|
||||
|
||||
TEST_F(QuantizeFCTest, VerifyFC) {
|
||||
auto status = QuantizeModel(&builder_, &model_, TensorType_INT8,
|
||||
TensorType_INT8, &error_reporter_);
|
||||
auto status = QuantizeModelAllOperators(&builder_, &model_, TensorType_INT8,
|
||||
TensorType_INT8, false,
|
||||
TensorType_INT8, &error_reporter_);
|
||||
ASSERT_EQ(kTfLiteOk, status);
|
||||
|
||||
const auto& subgraph = model_.subgraphs[0];
|
||||
|
@ -1224,7 +1315,9 @@ TEST_F(QuantizeFCTest, VerifyFC) {
|
|||
EXPECT_EQ(model_.operator_codes[1]->version, 1);
|
||||
}
|
||||
|
||||
class QuantizeCustomOpTest : public QuantizeModelTest {
|
||||
class QuantizeCustomOpTest
|
||||
: public QuantizeModelTest,
|
||||
public ::testing::WithParamInterface<tflite::TensorType> {
|
||||
protected:
|
||||
QuantizeCustomOpTest() {
|
||||
input_model_ = ReadModel(internal::kModelMixed);
|
||||
|
@ -1233,10 +1326,10 @@ class QuantizeCustomOpTest : public QuantizeModelTest {
|
|||
}
|
||||
};
|
||||
|
||||
TEST_F(QuantizeCustomOpTest, VerifyMixedQuantization) {
|
||||
auto status =
|
||||
QuantizeModel(&builder_, &model_, TensorType_INT8, TensorType_INT8,
|
||||
/*allow_float=*/true, &error_reporter_);
|
||||
TEST_P(QuantizeCustomOpTest, VerifyMixedQuantization) {
|
||||
auto status = QuantizeModelAllOperators(
|
||||
&builder_, &model_, GetParam(), GetParam(),
|
||||
/*allow_float=*/true, GetParam(), &error_reporter_);
|
||||
ASSERT_EQ(kTfLiteOk, status);
|
||||
const auto& subgraph = model_.subgraphs[0];
|
||||
auto float_graph = readonly_model_->subgraphs()->Get(0);
|
||||
|
@ -1250,8 +1343,45 @@ TEST_F(QuantizeCustomOpTest, VerifyMixedQuantization) {
|
|||
BuiltinOperator_CUSTOM, BuiltinOperator_CUSTOM,
|
||||
BuiltinOperator_QUANTIZE, BuiltinOperator_SQUEEZE};
|
||||
const std::vector<TensorType> op_input_types = {
|
||||
TensorType_INT8, TensorType_INT8, TensorType_FLOAT32,
|
||||
TensorType_FLOAT32, TensorType_FLOAT32, TensorType_INT8};
|
||||
GetParam(), GetParam(), TensorType_FLOAT32,
|
||||
TensorType_FLOAT32, TensorType_FLOAT32, GetParam()};
|
||||
for (int i = 0; i < subgraph->operators.size(); ++i) {
|
||||
OperatorT* op = subgraph->operators[i].get();
|
||||
ASSERT_EQ(model_.operator_codes[op->opcode_index]->builtin_code,
|
||||
op_codes[i]);
|
||||
ASSERT_EQ(subgraph->tensors[op->inputs[0]]->type, op_input_types[i]);
|
||||
}
|
||||
}
|
||||
|
||||
INSTANTIATE_TEST_SUITE_P(QuantizeCustomOpTest, QuantizeCustomOpTest,
|
||||
::testing::Values(TensorType_INT8, TensorType_INT16));
|
||||
|
||||
class QuantizeOp16x8Test : public QuantizeModelTest {
|
||||
protected:
|
||||
QuantizeOp16x8Test() {
|
||||
input_model_ = ReadModel(internal::kModelMixed16x8);
|
||||
readonly_model_ = input_model_->GetModel();
|
||||
readonly_model_->UnPackTo(&model_);
|
||||
}
|
||||
};
|
||||
|
||||
TEST_F(QuantizeOp16x8Test, VerifyMixedQuantization16x8) {
|
||||
auto status = QuantizeModelAllOperators(
|
||||
&builder_, &model_, TensorType_INT16, TensorType_FLOAT32,
|
||||
/*allow_float=*/true, TensorType_INT16, &error_reporter_);
|
||||
ASSERT_EQ(kTfLiteOk, status);
|
||||
const auto& subgraph = model_.subgraphs[0];
|
||||
auto float_graph = readonly_model_->subgraphs()->Get(0);
|
||||
// The original model conv_2d->log_softmax
|
||||
ASSERT_EQ(float_graph->operators()->size(), 2);
|
||||
// The resulting model should be:
|
||||
// conv_2d->dequantize->log_softmax
|
||||
ASSERT_EQ(subgraph->operators.size(), 3);
|
||||
const std::vector<BuiltinOperator> op_codes = {BuiltinOperator_CONV_2D,
|
||||
BuiltinOperator_DEQUANTIZE,
|
||||
BuiltinOperator_LOG_SOFTMAX};
|
||||
const std::vector<TensorType> op_input_types = {
|
||||
TensorType_INT16, TensorType_INT16, TensorType_FLOAT32};
|
||||
for (int i = 0; i < subgraph->operators.size(); ++i) {
|
||||
OperatorT* op = subgraph->operators[i].get();
|
||||
ASSERT_EQ(model_.operator_codes[op->opcode_index]->builtin_code,
|
||||
|
|
|
@ -48,6 +48,7 @@ const char* kModelWithArgMaxOp = "argmax.bin";
|
|||
const char* kModelWithFCOp = "fc.bin";
|
||||
|
||||
const char* kModelMixed = "mixed.bin";
|
||||
const char* kModelMixed16x8 = "mixed16x8.bin";
|
||||
|
||||
const char* kModelSplit = "split.bin";
|
||||
|
||||
|
|
|
@ -76,6 +76,11 @@ extern const char* kModelWithFCOp;
|
|||
// reshape->custom->custom->squeeze.
|
||||
extern const char* kModelMixed;
|
||||
|
||||
// Test model with mixed quantizable and
|
||||
// and un-quantizable ops for
|
||||
// activations in 16-bit.
|
||||
extern const char* kModelMixed16x8;
|
||||
|
||||
// Test model with split op.
|
||||
extern const char* kModelSplit;
|
||||
|
||||
|
|
Binary file not shown.
|
@ -1,6 +1,10 @@
|
|||
path: "tensorflow.lite.OpsSet"
|
||||
tf_class {
|
||||
is_instance: "<enum \'OpsSet\'>"
|
||||
member {
|
||||
name: "EXPERIMENTAL_TFLITE_BUILTINS_ACTIVATIONS_INT16_WEIGHTS_INT8"
|
||||
mtype: "<enum \'OpsSet\'>"
|
||||
}
|
||||
member {
|
||||
name: "SELECT_TF_OPS"
|
||||
mtype: "<enum \'OpsSet\'>"
|
||||
|
|
|
@ -12,6 +12,10 @@ tf_module {
|
|||
name: "GRAPHVIZ_DOT"
|
||||
mtype: "<type \'int\'>"
|
||||
}
|
||||
member {
|
||||
name: "INT16"
|
||||
mtype: "<class \'tensorflow.python.framework.dtypes.DType\'>"
|
||||
}
|
||||
member {
|
||||
name: "INT32"
|
||||
mtype: "<class \'tensorflow.python.framework.dtypes.DType\'>"
|
||||
|
|
|
@ -1,6 +1,10 @@
|
|||
path: "tensorflow.lite.OpsSet"
|
||||
tf_class {
|
||||
is_instance: "<enum \'OpsSet\'>"
|
||||
member {
|
||||
name: "EXPERIMENTAL_TFLITE_BUILTINS_ACTIVATIONS_INT16_WEIGHTS_INT8"
|
||||
mtype: "<enum \'OpsSet\'>"
|
||||
}
|
||||
member {
|
||||
name: "SELECT_TF_OPS"
|
||||
mtype: "<enum \'OpsSet\'>"
|
||||
|
|
Loading…
Reference in New Issue