Added an option TFLITE_BUILTINS_ACTIVATIONS_INT16_WEIGHTS_INT8 to
enable symmetric quantization with activations in 16-bit and weights in 8-bit.
This commit is contained in:
parent
2e98e89091
commit
a7899d7544
@ -93,6 +93,12 @@ class OpsSet(enum.Enum):
|
|||||||
# quantized implementations.
|
# quantized implementations.
|
||||||
TFLITE_BUILTINS_INT8 = "TFLITE_BUILTINS_INT8"
|
TFLITE_BUILTINS_INT8 = "TFLITE_BUILTINS_INT8"
|
||||||
|
|
||||||
|
# Convert model using only TensorFlow Lite operations with quantized int8 weights
|
||||||
|
# and int16 activations.
|
||||||
|
# Specifying this will throw an error for operations that do not yet have
|
||||||
|
# quantized implementations.
|
||||||
|
TFLITE_BUILTINS_ACTIVATIONS_INT16_WEIGHTS_INT8 = "TFLITE_BUILTINS_ACTIVATIONS_INT16_WEIGHTS_INT8"
|
||||||
|
|
||||||
def __str__(self):
|
def __str__(self):
|
||||||
return self.value
|
return self.value
|
||||||
|
|
||||||
|
@ -224,6 +224,10 @@ class TFLiteConverterBase(object):
|
|||||||
self.target_spec.supported_ops) or
|
self.target_spec.supported_ops) or
|
||||||
self._smallest_supported_type() == constants.INT8)
|
self._smallest_supported_type() == constants.INT8)
|
||||||
|
|
||||||
|
def _is_int16x8_target_required(self):
  """Returns True when the 16x8 ops set is the only requested ops set.

  The check is an exact set comparison: it holds only if
  `self.target_spec.supported_ops` contains
  `OpsSet.TFLITE_BUILTINS_ACTIVATIONS_INT16_WEIGHTS_INT8` and nothing else
  (duplicates in `supported_ops` are ignored).
  """
  int16x8_only = {OpsSet.TFLITE_BUILTINS_ACTIVATIONS_INT16_WEIGHTS_INT8}
  requested_ops = set(self.target_spec.supported_ops)
  return int16x8_only == requested_ops
|
||||||
|
|
||||||
def _smallest_supported_type(self):
|
def _smallest_supported_type(self):
|
||||||
if self.target_spec.supported_types:
|
if self.target_spec.supported_types:
|
||||||
return min(self.target_spec.supported_types, key=lambda x: x.size)
|
return min(self.target_spec.supported_types, key=lambda x: x.size)
|
||||||
@ -238,7 +242,9 @@ class TFLiteConverterBase(object):
|
|||||||
]))
|
]))
|
||||||
|
|
||||||
def _is_post_training_optimize(self):
|
def _is_post_training_optimize(self):
|
||||||
return self._is_int8_target_required() or self._any_optimization_enabled()
|
return self._is_int8_target_required() or \
|
||||||
|
self._is_int16x8_target_required() or \
|
||||||
|
self._any_optimization_enabled()
|
||||||
|
|
||||||
def _is_int8_weight_only_quantize(self):
|
def _is_int8_weight_only_quantize(self):
|
||||||
return (self._is_post_training_optimize() and
|
return (self._is_post_training_optimize() and
|
||||||
@ -255,11 +261,12 @@ class TFLiteConverterBase(object):
|
|||||||
|
|
||||||
def _calibrate_quantize_model(self, result, inference_input_type,
|
def _calibrate_quantize_model(self, result, inference_input_type,
|
||||||
inference_output_type, enable_mlir_quantizer):
|
inference_output_type, enable_mlir_quantizer):
|
||||||
allow_float = not self._is_int8_target_required()
|
allow_float = not self._is_int8_target_required() and not self._is_int16x8_target_required()
|
||||||
calibrate_quantize = _calibrator.Calibrator(result)
|
calibrate_quantize = _calibrator.Calibrator(result)
|
||||||
|
activations_type = constants.INT16 if self._is_int16x8_target_required() else constants.INT8
|
||||||
return calibrate_quantize.calibrate_and_quantize(
|
return calibrate_quantize.calibrate_and_quantize(
|
||||||
self.representative_dataset.input_gen, inference_input_type,
|
self.representative_dataset.input_gen, inference_input_type,
|
||||||
inference_output_type, allow_float, enable_mlir_quantizer)
|
inference_output_type, allow_float, activations_type, enable_mlir_quantizer)
|
||||||
|
|
||||||
def _get_base_converter_args(self):
|
def _get_base_converter_args(self):
|
||||||
"""Returns the base converter args.
|
"""Returns the base converter args.
|
||||||
|
@ -30,6 +30,7 @@ INT64 = dtypes.int64
|
|||||||
STRING = dtypes.string
|
STRING = dtypes.string
|
||||||
QUANTIZED_UINT8 = dtypes.uint8
|
QUANTIZED_UINT8 = dtypes.uint8
|
||||||
INT8 = dtypes.int8
|
INT8 = dtypes.int8
|
||||||
|
INT16 = dtypes.int16
|
||||||
COMPLEX64 = dtypes.complex64
|
COMPLEX64 = dtypes.complex64
|
||||||
TENSORFLOW_GRAPHDEF = _toco_flags_pb2.TENSORFLOW_GRAPHDEF
|
TENSORFLOW_GRAPHDEF = _toco_flags_pb2.TENSORFLOW_GRAPHDEF
|
||||||
TFLITE = _toco_flags_pb2.TFLITE
|
TFLITE = _toco_flags_pb2.TFLITE
|
||||||
@ -43,6 +44,7 @@ _tf_export(v1=["lite.constants.STRING"]).export_constant(__name__, "STRING")
|
|||||||
_tf_export(v1=["lite.constants.QUANTIZED_UINT8"]).export_constant(
|
_tf_export(v1=["lite.constants.QUANTIZED_UINT8"]).export_constant(
|
||||||
__name__, "QUANTIZED_UINT8")
|
__name__, "QUANTIZED_UINT8")
|
||||||
_tf_export(v1=["lite.constants.INT8"]).export_constant(__name__, "INT8")
|
_tf_export(v1=["lite.constants.INT8"]).export_constant(__name__, "INT8")
|
||||||
|
_tf_export(v1=["lite.constants.INT16"]).export_constant(__name__, "INT16")
|
||||||
_tf_export(v1=["lite.constants.TFLITE"]).export_constant(__name__, "TFLITE")
|
_tf_export(v1=["lite.constants.TFLITE"]).export_constant(__name__, "TFLITE")
|
||||||
_tf_export(v1=["lite.constants.GRAPHVIZ_DOT"]).export_constant(
|
_tf_export(v1=["lite.constants.GRAPHVIZ_DOT"]).export_constant(
|
||||||
__name__, "GRAPHVIZ_DOT")
|
__name__, "GRAPHVIZ_DOT")
|
||||||
@ -62,6 +64,7 @@ _allowed_symbols = [
|
|||||||
"STRING",
|
"STRING",
|
||||||
"QUANTIZED_UINT8",
|
"QUANTIZED_UINT8",
|
||||||
"INT8",
|
"INT8",
|
||||||
|
"INT16",
|
||||||
"COMPLEX64",
|
"COMPLEX64",
|
||||||
"TENSORFLOW_GRAPHDEF",
|
"TENSORFLOW_GRAPHDEF",
|
||||||
"TFLITE",
|
"TFLITE",
|
||||||
|
@ -769,9 +769,13 @@ class FromSessionTest(TestModels, parameterized.TestCase):
|
|||||||
self.assertLess(len(quantized_tflite), len(float_tflite))
|
self.assertLess(len(quantized_tflite), len(float_tflite))
|
||||||
|
|
||||||
@parameterized.named_parameters(
|
@parameterized.named_parameters(
|
||||||
('EnableMlirConverter', True), # enable mlir
|
# Quantize model to Int8: with enable mlir
|
||||||
('DisableMlirConverter', False)) # disable mlir
|
('UseTfliteBuiltinsIntEnableMLIR', [lite.OpsSet.TFLITE_BUILTINS_INT8], True),
|
||||||
def testCalibrateAndQuantizeBuiltinInt8(self, enable_mlir):
|
# Quantize model to Int8: with disable mlir
|
||||||
|
('UseTfliteBuiltinsIntDisableMLIR', [lite.OpsSet.TFLITE_BUILTINS_INT8], False),
|
||||||
|
# Quantize model to Int16: with disable mlir
|
||||||
|
('UseTfliteBuiltinsInt16DisableMLIR', [lite.OpsSet.TFLITE_BUILTINS_ACTIVATIONS_INT16_WEIGHTS_INT8], False))
|
||||||
|
def testCalibrateAndQuantizeBuiltinInt(self, supported_ops, enable_mlir):
|
||||||
with ops.Graph().as_default():
|
with ops.Graph().as_default():
|
||||||
inp, output, calibration_gen = self._getCalibrationQuantizeModel()
|
inp, output, calibration_gen = self._getCalibrationQuantizeModel()
|
||||||
sess = session.Session()
|
sess = session.Session()
|
||||||
@ -787,9 +791,7 @@ class FromSessionTest(TestModels, parameterized.TestCase):
|
|||||||
quantized_converter = lite.TFLiteConverter.from_session(
|
quantized_converter = lite.TFLiteConverter.from_session(
|
||||||
sess, [inp], [output])
|
sess, [inp], [output])
|
||||||
quantized_converter.experimental_new_converter = enable_mlir
|
quantized_converter.experimental_new_converter = enable_mlir
|
||||||
quantized_converter.target_spec.supported_ops = [
|
quantized_converter.target_spec.supported_ops = supported_ops
|
||||||
lite.OpsSet.TFLITE_BUILTINS_INT8
|
|
||||||
]
|
|
||||||
quantized_converter.representative_dataset = calibration_gen
|
quantized_converter.representative_dataset = calibration_gen
|
||||||
quantized_tflite = quantized_converter.convert()
|
quantized_tflite = quantized_converter.convert()
|
||||||
self.assertTrue(quantized_tflite)
|
self.assertTrue(quantized_tflite)
|
||||||
|
@ -204,6 +204,7 @@ PyObject* CalibrationWrapper::SetTensor(int index, PyObject* value) {
|
|||||||
PyObject* CalibrationWrapper::QuantizeModel(int input_py_type,
|
PyObject* CalibrationWrapper::QuantizeModel(int input_py_type,
|
||||||
int output_py_type,
|
int output_py_type,
|
||||||
bool allow_float,
|
bool allow_float,
|
||||||
|
int activations_py_type,
|
||||||
bool enable_mlir_quantizer) {
|
bool enable_mlir_quantizer) {
|
||||||
if (NoOpModel(*model_)) {
|
if (NoOpModel(*model_)) {
|
||||||
return python_utils::ConvertToPyString(model_str_->data(),
|
return python_utils::ConvertToPyString(model_str_->data(),
|
||||||
@ -212,6 +213,9 @@ PyObject* CalibrationWrapper::QuantizeModel(int input_py_type,
|
|||||||
|
|
||||||
TfLiteType input_type = python_utils::TfLiteTypeFromPyType(input_py_type);
|
TfLiteType input_type = python_utils::TfLiteTypeFromPyType(input_py_type);
|
||||||
TfLiteType output_type = python_utils::TfLiteTypeFromPyType(output_py_type);
|
TfLiteType output_type = python_utils::TfLiteTypeFromPyType(output_py_type);
|
||||||
|
TfLiteType activations_type =
|
||||||
|
python_utils::TfLiteTypeFromPyType(activations_py_type);
|
||||||
|
|
||||||
if (input_type == kTfLiteNoType || output_type == kTfLiteNoType) {
|
if (input_type == kTfLiteNoType || output_type == kTfLiteNoType) {
|
||||||
PyErr_SetString(PyExc_ValueError,
|
PyErr_SetString(PyExc_ValueError,
|
||||||
"Input/output type cannot be kTfLiteNoType");
|
"Input/output type cannot be kTfLiteNoType");
|
||||||
@ -230,7 +234,7 @@ PyObject* CalibrationWrapper::QuantizeModel(int input_py_type,
|
|||||||
status = tflite::optimize::QuantizeModel(
|
status = tflite::optimize::QuantizeModel(
|
||||||
&builder, tflite_model.get(), TfLiteTypeToSchemaType(input_type),
|
&builder, tflite_model.get(), TfLiteTypeToSchemaType(input_type),
|
||||||
TfLiteTypeToSchemaType(output_type), allow_float,
|
TfLiteTypeToSchemaType(output_type), allow_float,
|
||||||
error_reporter_.get());
|
TfLiteTypeToSchemaType(activations_type), error_reporter_.get());
|
||||||
}
|
}
|
||||||
|
|
||||||
if (status != kTfLiteOk) {
|
if (status != kTfLiteOk) {
|
||||||
@ -262,7 +266,7 @@ PyObject* CalibrationWrapper::QuantizeModel(int input_py_type,
|
|||||||
auto status = tflite::optimize::QuantizeModel(
|
auto status = tflite::optimize::QuantizeModel(
|
||||||
&builder, tflite_model.get(), TfLiteTypeToSchemaType(input_type),
|
&builder, tflite_model.get(), TfLiteTypeToSchemaType(input_type),
|
||||||
TfLiteTypeToSchemaType(output_type), allow_float, {op_name},
|
TfLiteTypeToSchemaType(output_type), allow_float, {op_name},
|
||||||
error_reporter_.get());
|
TensorType_INT8, error_reporter_.get());
|
||||||
if (status != kTfLiteOk) {
|
if (status != kTfLiteOk) {
|
||||||
error_reporter_->exception();
|
error_reporter_->exception();
|
||||||
return nullptr;
|
return nullptr;
|
||||||
|
@ -60,7 +60,8 @@ class CalibrationWrapper {
|
|||||||
PyObject* FeedTensor(PyObject* input_value);
|
PyObject* FeedTensor(PyObject* input_value);
|
||||||
|
|
||||||
PyObject* QuantizeModel(int input_py_type, int output_py_type,
|
PyObject* QuantizeModel(int input_py_type, int output_py_type,
|
||||||
bool allow_float, bool enable_mlir_quantizer = false);
|
bool allow_float, int activations_py_type,
|
||||||
|
bool enable_mlir_quantizer = false);
|
||||||
|
|
||||||
// Allows quantizing only the operator that produces the tensor with name
|
// Allows quantizing only the operator that produces the tensor with name
|
||||||
// operator_output_name. (This can be used to help debug.).
|
// operator_output_name. (This can be used to help debug.).
|
||||||
|
@ -19,6 +19,7 @@ from __future__ import print_function
|
|||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from tensorflow.python.util.lazy_loader import LazyLoader
|
from tensorflow.python.util.lazy_loader import LazyLoader
|
||||||
|
from tensorflow.lite.python import lite_constants
|
||||||
|
|
||||||
# Lazy load since some of the performance benchmark skylark rules
|
# Lazy load since some of the performance benchmark skylark rules
|
||||||
# break dependencies. Must use double quotes to match code internal rewrite
|
# break dependencies. Must use double quotes to match code internal rewrite
|
||||||
@ -55,7 +56,8 @@ class Calibrator(object):
|
|||||||
raise ValueError("Failed to parse the model.")
|
raise ValueError("Failed to parse the model.")
|
||||||
|
|
||||||
def calibrate_and_quantize(self, dataset_gen, input_type, output_type,
|
def calibrate_and_quantize(self, dataset_gen, input_type, output_type,
|
||||||
allow_float, enable_mlir_quantizer=False):
|
allow_float, activations_type = lite_constants.INT8,
|
||||||
|
enable_mlir_quantizer=False):
|
||||||
"""Calibrates the model with specified generator and then quantizes it.
|
"""Calibrates the model with specified generator and then quantizes it.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
@ -69,6 +71,7 @@ class Calibrator(object):
|
|||||||
computation, useful when targeting an integer-only backend.
|
computation, useful when targeting an integer-only backend.
|
||||||
If False, an error will be thrown if an operation cannot be
|
If False, an error will be thrown if an operation cannot be
|
||||||
quantized, otherwise the model will fallback to float ops.
|
quantized, otherwise the model will fallback to float ops.
|
||||||
|
activations_type: A tf.dtype representing the desired type for activations
|
||||||
enable_mlir_quantizer: A boolean. True if wants to use mlir quantizer to
|
enable_mlir_quantizer: A boolean. True if wants to use mlir quantizer to
|
||||||
quantize the calibrated model.
|
quantize the calibrated model.
|
||||||
"""
|
"""
|
||||||
@ -78,6 +81,7 @@ class Calibrator(object):
|
|||||||
return self._calibrator.QuantizeModel(
|
return self._calibrator.QuantizeModel(
|
||||||
np.dtype(input_type.as_numpy_dtype()).num,
|
np.dtype(input_type.as_numpy_dtype()).num,
|
||||||
np.dtype(output_type.as_numpy_dtype()).num, allow_float,
|
np.dtype(output_type.as_numpy_dtype()).num, allow_float,
|
||||||
|
np.dtype(activations_type.as_numpy_dtype()).num,
|
||||||
enable_mlir_quantizer)
|
enable_mlir_quantizer)
|
||||||
|
|
||||||
def calibrate_and_quantize_single(self, dataset_gen, input_type, output_type,
|
def calibrate_and_quantize_single(self, dataset_gen, input_type, output_type,
|
||||||
|
@ -33,9 +33,13 @@ from tensorflow.python.platform import test
|
|||||||
class CalibratorTest(test_util.TensorFlowTestCase, parameterized.TestCase):
|
class CalibratorTest(test_util.TensorFlowTestCase, parameterized.TestCase):
|
||||||
|
|
||||||
@parameterized.named_parameters(
|
@parameterized.named_parameters(
|
||||||
('EnableMlirQuantizer', True), # enable mlir quantizer
|
# Activation type Int8 - enable mlir quantizer
|
||||||
('DisableMlirQuantizer', False)) # disable mlir quantizer
|
('UseActivationTypeInt8EnabledMlir', constants.INT8, True),
|
||||||
def test_calibration_with_quantization(self, enable_mlir):
|
# Activation type Int8 - disable mlir quantizer
|
||||||
|
('UseActivationTypeInt8DisabledMlir', constants.INT8, False),
|
||||||
|
# Activation type Int16
|
||||||
|
('UseActivationTypeInt16', constants.INT16, False))
|
||||||
|
def test_calibration_with_quantization(self, activations_type, enable_mlir):
|
||||||
model_path = resource_loader.get_path_to_datafile(
|
model_path = resource_loader.get_path_to_datafile(
|
||||||
'test_data/mobilenet_like_model.bin')
|
'test_data/mobilenet_like_model.bin')
|
||||||
float_model = open(model_path, 'rb').read()
|
float_model = open(model_path, 'rb').read()
|
||||||
@ -49,13 +53,18 @@ class CalibratorTest(test_util.TensorFlowTestCase, parameterized.TestCase):
|
|||||||
quantized_model = quantizer.calibrate_and_quantize(input_gen,
|
quantized_model = quantizer.calibrate_and_quantize(input_gen,
|
||||||
constants.FLOAT,
|
constants.FLOAT,
|
||||||
constants.FLOAT, False,
|
constants.FLOAT, False,
|
||||||
|
activations_type,
|
||||||
enable_mlir)
|
enable_mlir)
|
||||||
self.assertIsNotNone(quantized_model)
|
self.assertIsNotNone(quantized_model)
|
||||||
|
|
||||||
@parameterized.named_parameters(
|
@parameterized.named_parameters(
|
||||||
('EnableMlirQuantizer', True), # enable mlir quantizer
|
# Activation type Int8 - enable mlir quantizer
|
||||||
('DisableMlirQuantizer', False)) # disable mlir quantizer
|
('UseActivationTypeInt8EnabledMlir', constants.INT8, True),
|
||||||
def test_calibration_with_quantization_allow_float(self, enable_mlir):
|
# Activation type Int8 - disable mlir quantizer
|
||||||
|
('UseActivationTypeInt8DisableMlir', constants.INT8, False),
|
||||||
|
# Activation type Int16 - disable mlir quantizer
|
||||||
|
('UseActivationTypeInt16', constants.INT16, False))
|
||||||
|
def test_calibration_with_quantization_allow_float(self, activations_type, enable_mlir):
|
||||||
model_path = resource_loader.get_path_to_datafile(
|
model_path = resource_loader.get_path_to_datafile(
|
||||||
'test_data/mobilenet_like_model.bin')
|
'test_data/mobilenet_like_model.bin')
|
||||||
float_model = open(model_path, 'rb').read()
|
float_model = open(model_path, 'rb').read()
|
||||||
@ -69,6 +78,7 @@ class CalibratorTest(test_util.TensorFlowTestCase, parameterized.TestCase):
|
|||||||
quantized_model = quantizer.calibrate_and_quantize(input_gen,
|
quantized_model = quantizer.calibrate_and_quantize(input_gen,
|
||||||
constants.FLOAT,
|
constants.FLOAT,
|
||||||
constants.FLOAT, True,
|
constants.FLOAT, True,
|
||||||
|
activations_type,
|
||||||
enable_mlir)
|
enable_mlir)
|
||||||
self.assertIsNotNone(quantized_model)
|
self.assertIsNotNone(quantized_model)
|
||||||
|
|
||||||
@ -88,9 +98,13 @@ class CalibratorTest(test_util.TensorFlowTestCase, parameterized.TestCase):
|
|||||||
self.assertIsNotNone(quantized_model)
|
self.assertIsNotNone(quantized_model)
|
||||||
|
|
||||||
@parameterized.named_parameters(
|
@parameterized.named_parameters(
|
||||||
('EnableMlirQuantizer', True), # enable mlir quantizer
|
# Activation type Int8 - enable mlir quantizer
|
||||||
('DisableMlirQuantizer', False)) # disable mlir quantizer
|
('UseActivationTypeInt8 - EnableMlirQuantizer', constants.INT8, True),
|
||||||
def test_calibration_with_quantization_multiple_inputs(self, enable_mlir):
|
# Activation type Int8 - disable mlir quantizer
|
||||||
|
('UseActivationTypeInt8 - DisableMlirQuantizer', constants.INT8, False),
|
||||||
|
# Activation type Int16 - disable mlir quantizer
|
||||||
|
('UseActivationTypeInt16 - DisableEnableMlirQuantizer', constants.INT16, False))
|
||||||
|
def test_calibration_with_quantization_multiple_inputs(self, activations_type, enable_mlir):
|
||||||
# Load multi add model from test data.
|
# Load multi add model from test data.
|
||||||
# This model has 4 inputs of size (1, 8, 8, 3).
|
# This model has 4 inputs of size (1, 8, 8, 3).
|
||||||
model_path = resource_loader.get_path_to_datafile(
|
model_path = resource_loader.get_path_to_datafile(
|
||||||
@ -106,6 +120,7 @@ class CalibratorTest(test_util.TensorFlowTestCase, parameterized.TestCase):
|
|||||||
quantized_model = quantizer.calibrate_and_quantize(input_gen,
|
quantized_model = quantizer.calibrate_and_quantize(input_gen,
|
||||||
constants.FLOAT,
|
constants.FLOAT,
|
||||||
constants.FLOAT, False,
|
constants.FLOAT, False,
|
||||||
|
activations_type,
|
||||||
enable_mlir)
|
enable_mlir)
|
||||||
self.assertIsNotNone(quantized_model)
|
self.assertIsNotNone(quantized_model)
|
||||||
|
|
||||||
@ -148,7 +163,8 @@ class CalibratorTest(test_util.TensorFlowTestCase, parameterized.TestCase):
|
|||||||
|
|
||||||
with self.assertRaisesRegex(ValueError, 'Size mismatch'):
|
with self.assertRaisesRegex(ValueError, 'Size mismatch'):
|
||||||
quantizer.calibrate_and_quantize(input_gen, constants.FLOAT,
|
quantizer.calibrate_and_quantize(input_gen, constants.FLOAT,
|
||||||
constants.FLOAT, False, enable_mlir)
|
constants.FLOAT, False,
|
||||||
|
enable_mlir)
|
||||||
|
|
||||||
@parameterized.named_parameters(
|
@parameterized.named_parameters(
|
||||||
('EnableMlirQuantizer', True), # enable mlir quantizer
|
('EnableMlirQuantizer', True), # enable mlir quantizer
|
||||||
@ -166,7 +182,8 @@ class CalibratorTest(test_util.TensorFlowTestCase, parameterized.TestCase):
|
|||||||
|
|
||||||
with self.assertRaises(ValueError):
|
with self.assertRaises(ValueError):
|
||||||
quantizer.calibrate_and_quantize(input_gen, constants.FLOAT,
|
quantizer.calibrate_and_quantize(input_gen, constants.FLOAT,
|
||||||
constants.FLOAT, False, enable_mlir)
|
constants.FLOAT, False,
|
||||||
|
constants.INT8, enable_mlir)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
|
@ -64,6 +64,7 @@ OperatorProperty GetOperatorProperty(const ModelT* model, int subgraph_index,
|
|||||||
property.inputs = {{0, {}}, {1, {}}};
|
property.inputs = {{0, {}}, {1, {}}};
|
||||||
property.outputs = {{0, {}}};
|
property.outputs = {{0, {}}};
|
||||||
property.version = 2;
|
property.version = 2;
|
||||||
|
property.quantize_input_as_activations = true;
|
||||||
break;
|
break;
|
||||||
case BuiltinOperator_ARG_MAX:
|
case BuiltinOperator_ARG_MAX:
|
||||||
property.inputs = {{0, {}}};
|
property.inputs = {{0, {}}};
|
||||||
@ -176,7 +177,7 @@ OperatorProperty GetOperatorProperty(const ModelT* model, int subgraph_index,
|
|||||||
// LogSoftmax requires output with 16/256 as scale and 127 as zero point.
|
// LogSoftmax requires output with 16/256 as scale and 127 as zero point.
|
||||||
TensorProperty tensor_property;
|
TensorProperty tensor_property;
|
||||||
tensor_property.restriction = true;
|
tensor_property.restriction = true;
|
||||||
tensor_property.restricted_value = {16.0 / 256.0, 127};
|
tensor_property.restricted_value_int8 = {16.0 / 256.0, 127};
|
||||||
property.outputs = {{0, tensor_property}};
|
property.outputs = {{0, tensor_property}};
|
||||||
property.version = 2;
|
property.version = 2;
|
||||||
break;
|
break;
|
||||||
@ -186,7 +187,8 @@ OperatorProperty GetOperatorProperty(const ModelT* model, int subgraph_index,
|
|||||||
// Logistic requires output with 1/256 as scale and -128 as zero point.
|
// Logistic requires output with 1/256 as scale and -128 as zero point.
|
||||||
TensorProperty tensor_property;
|
TensorProperty tensor_property;
|
||||||
tensor_property.restriction = true;
|
tensor_property.restriction = true;
|
||||||
tensor_property.restricted_value = {1 / 256.0, -128};
|
tensor_property.restricted_value_int8 = {1 / 256.0, -128};
|
||||||
|
tensor_property.restricted_value_int16 = {1 / 32768.0, 0};
|
||||||
property.outputs = {{0, tensor_property}};
|
property.outputs = {{0, tensor_property}};
|
||||||
property.version = 2;
|
property.version = 2;
|
||||||
break;
|
break;
|
||||||
@ -741,7 +743,7 @@ OperatorProperty GetOperatorProperty(const ModelT* model, int subgraph_index,
|
|||||||
// L2 Norm requires output with 1/128 as scale and 0 as zero point.
|
// L2 Norm requires output with 1/128 as scale and 0 as zero point.
|
||||||
TensorProperty tensor_property;
|
TensorProperty tensor_property;
|
||||||
tensor_property.restriction = true;
|
tensor_property.restriction = true;
|
||||||
tensor_property.restricted_value = {1 / 128.0, 0};
|
tensor_property.restricted_value_int8 = {1 / 128.0, 0};
|
||||||
property.outputs = {{0, tensor_property}};
|
property.outputs = {{0, tensor_property}};
|
||||||
property.version = 2;
|
property.version = 2;
|
||||||
break;
|
break;
|
||||||
@ -756,6 +758,7 @@ OperatorProperty GetOperatorProperty(const ModelT* model, int subgraph_index,
|
|||||||
property.arbitrary_inputs = true;
|
property.arbitrary_inputs = true;
|
||||||
property.outputs = {{0, {}}};
|
property.outputs = {{0, {}}};
|
||||||
property.restrict_same_input_output_scale = true;
|
property.restrict_same_input_output_scale = true;
|
||||||
|
property.quantize_input_as_activations = true;
|
||||||
property.version = 2;
|
property.version = 2;
|
||||||
break;
|
break;
|
||||||
case BuiltinOperator_MEAN:
|
case BuiltinOperator_MEAN:
|
||||||
@ -767,6 +770,7 @@ OperatorProperty GetOperatorProperty(const ModelT* model, int subgraph_index,
|
|||||||
property.arbitrary_inputs = true;
|
property.arbitrary_inputs = true;
|
||||||
property.outputs = {{0, {}}};
|
property.outputs = {{0, {}}};
|
||||||
property.restrict_same_input_output_scale = true;
|
property.restrict_same_input_output_scale = true;
|
||||||
|
property.quantize_input_as_activations = true;
|
||||||
property.version = 2;
|
property.version = 2;
|
||||||
break;
|
break;
|
||||||
case BuiltinOperator_MUL:
|
case BuiltinOperator_MUL:
|
||||||
@ -778,6 +782,7 @@ OperatorProperty GetOperatorProperty(const ModelT* model, int subgraph_index,
|
|||||||
property.arbitrary_inputs = true;
|
property.arbitrary_inputs = true;
|
||||||
property.outputs = {{0, {}}};
|
property.outputs = {{0, {}}};
|
||||||
property.restrict_same_input_output_scale = true;
|
property.restrict_same_input_output_scale = true;
|
||||||
|
property.restrict_same_input_output_scale = true;
|
||||||
property.version = 2;
|
property.version = 2;
|
||||||
break;
|
break;
|
||||||
case BuiltinOperator_PAD:
|
case BuiltinOperator_PAD:
|
||||||
@ -840,7 +845,8 @@ OperatorProperty GetOperatorProperty(const ModelT* model, int subgraph_index,
|
|||||||
// Softmax requires output with 1/256 as scale and -128 as zero point.
|
// Softmax requires output with 1/256 as scale and -128 as zero point.
|
||||||
TensorProperty tensor_property;
|
TensorProperty tensor_property;
|
||||||
tensor_property.restriction = true;
|
tensor_property.restriction = true;
|
||||||
tensor_property.restricted_value = {1 / 256.0, -128};
|
tensor_property.restricted_value_int8 = {1 / 256.0, -128};
|
||||||
|
tensor_property.restricted_value_int16 = {1 / 32768.0, 0};
|
||||||
property.outputs = {{0, tensor_property}};
|
property.outputs = {{0, tensor_property}};
|
||||||
property.version = 2;
|
property.version = 2;
|
||||||
break;
|
break;
|
||||||
@ -866,7 +872,8 @@ OperatorProperty GetOperatorProperty(const ModelT* model, int subgraph_index,
|
|||||||
// Tanh requires output with 1/128 as scale and 0 as zero point.
|
// Tanh requires output with 1/128 as scale and 0 as zero point.
|
||||||
TensorProperty tensor_property;
|
TensorProperty tensor_property;
|
||||||
tensor_property.restriction = true;
|
tensor_property.restriction = true;
|
||||||
tensor_property.restricted_value = {1 / 128.0, 0};
|
tensor_property.restricted_value_int8 = {1 / 128.0, 0};
|
||||||
|
tensor_property.restricted_value_int16 = {1 / 32768.0, 0};
|
||||||
property.outputs = {{0, tensor_property}};
|
property.outputs = {{0, tensor_property}};
|
||||||
property.version = 2;
|
property.version = 2;
|
||||||
break;
|
break;
|
||||||
|
@ -43,7 +43,8 @@ struct TensorProperty {
|
|||||||
// Constraints.
|
// Constraints.
|
||||||
bool restriction = false;
|
bool restriction = false;
|
||||||
// scale/zero_point hardcoded.
|
// scale/zero_point hardcoded.
|
||||||
std::pair<float, int> restricted_value = {0.0, 0};
|
std::pair<float, int> restricted_value_int8 = {0.0, 0};
|
||||||
|
std::pair<float, int> restricted_value_int16 = {0.0, 0};
|
||||||
|
|
||||||
// Use derived scale.
|
// Use derived scale.
|
||||||
bool use_derived_scale = false;
|
bool use_derived_scale = false;
|
||||||
@ -93,6 +94,13 @@ struct OperatorProperty {
|
|||||||
|
|
||||||
// Op version.
|
// Op version.
|
||||||
int version = 1;
|
int version = 1;
|
||||||
|
|
||||||
|
// When we quantize activations into 16 bit and weights into 8 bit,
|
||||||
|
// we want to quantize all inputs, including constant tensors,
|
||||||
|
// for the operators like Add, Mul into 16-bit as well. The constant
|
||||||
|
// inputs are quantized as weights and this variable indicates
|
||||||
|
// that we want to do quantizations of these tensors as activations.
|
||||||
|
bool quantize_input_as_activations = false;
|
||||||
};
|
};
|
||||||
|
|
||||||
OperatorProperty GetOperatorProperty(const ModelT* model, int subgraph_index,
|
OperatorProperty GetOperatorProperty(const ModelT* model, int subgraph_index,
|
||||||
|
@ -20,7 +20,6 @@ limitations under the License.
|
|||||||
#include <string>
|
#include <string>
|
||||||
|
|
||||||
#include "absl/memory/memory.h"
|
#include "absl/memory/memory.h"
|
||||||
#include "third_party/eigen3/Eigen/Core"
|
|
||||||
#include "tensorflow/lite/c/common.h"
|
#include "tensorflow/lite/c/common.h"
|
||||||
#include "tensorflow/lite/core/api/error_reporter.h"
|
#include "tensorflow/lite/core/api/error_reporter.h"
|
||||||
#include "tensorflow/lite/kernels/internal/quantization_util.h"
|
#include "tensorflow/lite/kernels/internal/quantization_util.h"
|
||||||
@ -30,6 +29,7 @@ limitations under the License.
|
|||||||
#include "tensorflow/lite/minimal_logging.h"
|
#include "tensorflow/lite/minimal_logging.h"
|
||||||
#include "tensorflow/lite/schema/schema_generated.h"
|
#include "tensorflow/lite/schema/schema_generated.h"
|
||||||
#include "tensorflow/lite/tools/optimize/model_utils.h"
|
#include "tensorflow/lite/tools/optimize/model_utils.h"
|
||||||
|
#include "third_party/eigen3/Eigen/Core"
|
||||||
|
|
||||||
namespace tflite {
|
namespace tflite {
|
||||||
namespace optimize {
|
namespace optimize {
|
||||||
@ -85,6 +85,46 @@ void GetAsymmetricQuantizationParams(
|
|||||||
quantization_params->zero_point = std::vector<int64_t>(1, zero_point);
|
quantization_params->zero_point = std::vector<int64_t>(1, zero_point);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void GetSymmetricQuantizationParams(
|
||||||
|
float min, float max, const int half_quant_range,
|
||||||
|
QuantizationParametersT* quantization_params) {
|
||||||
|
// Adjust the boundaries to guarantee 0 is included.
|
||||||
|
min = std::min(min, 0.0f);
|
||||||
|
max = std::max(max, 0.0f);
|
||||||
|
const float scale = std::max(std::abs(max), std::abs(min)) / half_quant_range;
|
||||||
|
int64_t zero_point = 0;
|
||||||
|
quantization_params->min = std::vector<float>(1, min);
|
||||||
|
quantization_params->max = std::vector<float>(1, max);
|
||||||
|
quantization_params->scale = std::vector<float>(1, scale);
|
||||||
|
quantization_params->zero_point = std::vector<int64_t>(1, 0);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Computes activation quantization parameters for `tensor` from its recorded
// min/max and writes them to `quantization_params`.
//
// - TensorType_INT8: asymmetric parameters over the full int8_t range, via
//   GetAsymmetricQuantizationParams.
// - TensorType_INT16: symmetric parameters; the range is max(|min|, |max|),
//   the scale is range / 32767 and the zero point is 0.
// - Any other type: an error is reported and kTfLiteError is returned.
//
// Reads tensor->quantization->min[0] and ->max[0] without checking that they
// exist, so the caller must ensure they are populated before calling.
//
// Returns kTfLiteOk on success.
TfLiteStatus GetQuantizationParams(TensorT* tensor, TensorType activations_type,
                                   QuantizationParametersT* quantization_params,
                                   ErrorReporter* error_reporter) {
  if (activations_type == TensorType_INT8) {
    GetAsymmetricQuantizationParams(
        tensor->quantization->min[0], tensor->quantization->max[0],
        std::numeric_limits<int8_t>::min(), std::numeric_limits<int8_t>::max(),
        quantization_params);
  } else if (activations_type == TensorType_INT16) {
    float range = std::max(std::abs(tensor->quantization->min[0]),
                           std::abs(tensor->quantization->max[0]));
    const float quantized_range = 32767.0;
    const float scale = range / quantized_range;
    quantization_params->min = std::vector<float>(1, -range);
    quantization_params->max = std::vector<float>(1, range);
    quantization_params->scale = std::vector<float>(1, scale);
    quantization_params->zero_point = std::vector<int64_t>(1, 0);
  } else {
    // `activations_type` is an enum, not a C string: print it as an integer.
    // Passing it to a %s conversion is undefined behaviour.
    error_reporter->Report(
        "Unsupported activation type for quantize-activation: %d",
        static_cast<int>(activations_type));
    return kTfLiteError;
  }
  return kTfLiteOk;
}
|
||||||
|
|
||||||
// Set the max and min quantization parameter for a single tensor given its
|
// Set the max and min quantization parameter for a single tensor given its
|
||||||
// values.
|
// values.
|
||||||
void FillSingleMinMax(const float* const input, const uint64_t input_size,
|
void FillSingleMinMax(const float* const input, const uint64_t input_size,
|
||||||
@ -536,6 +576,7 @@ TfLiteStatus SymmetricQuantizeTensorPerChannel(ModelT* model, TensorT* tensor,
|
|||||||
model, tensor, error_reporter);
|
model, tensor, error_reporter);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
template <class BiasType>
|
||||||
TfLiteStatus SymmetricPerLayerBiasQuantize(ModelT* model, TensorT* tensor,
|
TfLiteStatus SymmetricPerLayerBiasQuantize(ModelT* model, TensorT* tensor,
|
||||||
float scaling_factor,
|
float scaling_factor,
|
||||||
ErrorReporter* error_reporter) {
|
ErrorReporter* error_reporter) {
|
||||||
@ -548,25 +589,38 @@ TfLiteStatus SymmetricPerLayerBiasQuantize(ModelT* model, TensorT* tensor,
|
|||||||
uint64_t num_elements;
|
uint64_t num_elements;
|
||||||
TF_LITE_ENSURE_STATUS(NumElements(*tensor, &num_elements));
|
TF_LITE_ENSURE_STATUS(NumElements(*tensor, &num_elements));
|
||||||
|
|
||||||
std::vector<int32_t> final_buffer(num_elements);
|
std::vector<BiasType> final_buffer(num_elements);
|
||||||
const int32_t kScale = std::numeric_limits<int32_t>::max();
|
const BiasType kScale = std::numeric_limits<BiasType>::max();
|
||||||
|
|
||||||
for (size_t i = 0; i < num_elements; i++) {
|
for (size_t i = 0; i < num_elements; i++) {
|
||||||
const int32_t quantized_value = tflite::SafeCast<int32_t>(
|
const BiasType quantized_value = tflite::SafeCast<BiasType>(
|
||||||
TfLiteRound(float_data[i] * scaling_factor_inv));
|
TfLiteRound(float_data[i] * scaling_factor_inv));
|
||||||
final_buffer[i] = std::min(kScale, std::max(-kScale, quantized_value));
|
final_buffer[i] = std::min(kScale, std::max(-kScale, quantized_value));
|
||||||
}
|
}
|
||||||
|
|
||||||
// Set the buffers and output type.
|
// Set the buffers and output type.
|
||||||
uint8_t* uint8_buffer = reinterpret_cast<uint8_t*>(final_buffer.data());
|
uint8_t* uint8_buffer = reinterpret_cast<uint8_t*>(final_buffer.data());
|
||||||
size_t buffer_size = num_elements * sizeof(int32_t);
|
size_t buffer_size = num_elements * sizeof(BiasType);
|
||||||
std::vector<float> scales(1, scaling_factor);
|
std::vector<float> scales(1, scaling_factor);
|
||||||
std::vector<int64_t> zero_points(1, 0);
|
std::vector<int64_t> zero_points(1, 0);
|
||||||
|
|
||||||
|
auto output_type = std::is_same<BiasType, std::int32_t>::value
|
||||||
|
? TensorType_INT32
|
||||||
|
: TensorType_INT64;
|
||||||
return AddQuantizationParams(scales, zero_points, 0, uint8_buffer,
|
return AddQuantizationParams(scales, zero_points, 0, uint8_buffer,
|
||||||
buffer_size, TensorType_INT32, model, tensor,
|
buffer_size, output_type, model, tensor,
|
||||||
error_reporter);
|
error_reporter);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
template TfLiteStatus SymmetricPerLayerBiasQuantize<std::int32_t>(
|
||||||
|
ModelT* model, TensorT* tensor, float scaling_factor,
|
||||||
|
ErrorReporter* error_reporter);
|
||||||
|
|
||||||
|
template TfLiteStatus SymmetricPerLayerBiasQuantize<std::int64_t>(
|
||||||
|
ModelT* model, TensorT* tensor, float scaling_factor,
|
||||||
|
ErrorReporter* error_reporter);
|
||||||
|
|
||||||
|
template <class BiasType>
|
||||||
TfLiteStatus SymmetricPerChannelBiasQuantize(ModelT* model, TensorT* tensor,
|
TfLiteStatus SymmetricPerChannelBiasQuantize(ModelT* model, TensorT* tensor,
|
||||||
float input_scale,
|
float input_scale,
|
||||||
const float* weight_scales,
|
const float* weight_scales,
|
||||||
@ -583,14 +637,14 @@ TfLiteStatus SymmetricPerChannelBiasQuantize(ModelT* model, TensorT* tensor,
|
|||||||
uint64_t num_elements;
|
uint64_t num_elements;
|
||||||
TF_LITE_ENSURE_STATUS(NumElements(*tensor, &num_elements));
|
TF_LITE_ENSURE_STATUS(NumElements(*tensor, &num_elements));
|
||||||
|
|
||||||
std::vector<int32_t> final_buffer(num_elements);
|
std::vector<BiasType> final_buffer(num_elements);
|
||||||
const int32_t kScale = std::numeric_limits<int32_t>::max();
|
const BiasType kScale = std::numeric_limits<BiasType>::max();
|
||||||
|
|
||||||
for (int32_t channel_idx = 0; channel_idx < number_of_dimension;
|
for (int32_t channel_idx = 0; channel_idx < number_of_dimension;
|
||||||
channel_idx++) {
|
channel_idx++) {
|
||||||
float scaling_factor = scales[channel_idx];
|
float scaling_factor = scales[channel_idx];
|
||||||
float scaling_factor_inv = (scaling_factor == 0) ? 0 : 1.0 / scaling_factor;
|
float scaling_factor_inv = (scaling_factor == 0) ? 0 : 1.0 / scaling_factor;
|
||||||
const int32_t quantized_value = tflite::SafeCast<int32_t>(
|
const BiasType quantized_value = tflite::SafeCast<BiasType>(
|
||||||
TfLiteRound(float_data[channel_idx] * scaling_factor_inv));
|
TfLiteRound(float_data[channel_idx] * scaling_factor_inv));
|
||||||
final_buffer[channel_idx] =
|
final_buffer[channel_idx] =
|
||||||
std::min(kScale, std::max(-kScale, quantized_value));
|
std::min(kScale, std::max(-kScale, quantized_value));
|
||||||
@ -598,12 +652,26 @@ TfLiteStatus SymmetricPerChannelBiasQuantize(ModelT* model, TensorT* tensor,
|
|||||||
|
|
||||||
// Set the buffers and output type.
|
// Set the buffers and output type.
|
||||||
uint8_t* uint8_buffer = reinterpret_cast<uint8_t*>(final_buffer.data());
|
uint8_t* uint8_buffer = reinterpret_cast<uint8_t*>(final_buffer.data());
|
||||||
size_t buffer_size = num_elements * sizeof(int32_t);
|
size_t buffer_size = num_elements * sizeof(BiasType);
|
||||||
std::vector<int64_t> zero_point(scales.size(), 0);
|
std::vector<int64_t> zero_point(scales.size(), 0);
|
||||||
|
|
||||||
|
auto output_type = std::is_same<BiasType, std::int32_t>::value
|
||||||
|
? TensorType_INT32
|
||||||
|
: TensorType_INT64;
|
||||||
return AddQuantizationParams(scales, zero_point, 0, uint8_buffer, buffer_size,
|
return AddQuantizationParams(scales, zero_point, 0, uint8_buffer, buffer_size,
|
||||||
TensorType_INT32, model, tensor, error_reporter);
|
output_type, model, tensor, error_reporter);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
template TfLiteStatus SymmetricPerChannelBiasQuantize<std::int64_t>(
|
||||||
|
ModelT* model, TensorT* tensor, float input_scale,
|
||||||
|
const float* weight_scales, int number_of_dimension,
|
||||||
|
ErrorReporter* error_reporter);
|
||||||
|
|
||||||
|
template TfLiteStatus SymmetricPerChannelBiasQuantize<std::int32_t>(
|
||||||
|
ModelT* model, TensorT* tensor, float input_scale,
|
||||||
|
const float* weight_scales, int number_of_dimension,
|
||||||
|
ErrorReporter* error_reporter);
|
||||||
|
|
||||||
TfLiteStatus QuantizeWeight(ModelT* model, TensorT* tensor, bool per_channel,
|
TfLiteStatus QuantizeWeight(ModelT* model, TensorT* tensor, bool per_channel,
|
||||||
int per_axis_index, ErrorReporter* error_reporter) {
|
int per_axis_index, ErrorReporter* error_reporter) {
|
||||||
// TODO(suharshs): Currently we conflate quantizing weights and constants. It's
|
// TODO(suharshs): Currently we conflate quantizing weights and constants. It's
|
||||||
@ -645,12 +713,12 @@ float GetEffectiveScale(ModelT* model, SubGraphT* subgraph, int op_idx,
|
|||||||
return scale;
|
return scale;
|
||||||
}
|
}
|
||||||
|
|
||||||
void QuantizeActivation(TensorT* tensor) {
|
TfLiteStatus QuantizeActivation(TensorT* tensor, TensorType activations_type,
|
||||||
GetAsymmetricQuantizationParams(
|
ErrorReporter* error_reporter) {
|
||||||
tensor->quantization->min[0], tensor->quantization->max[0],
|
TF_LITE_ENSURE_STATUS(GetQuantizationParams(
|
||||||
std::numeric_limits<int8_t>::min(), std::numeric_limits<int8_t>::max(),
|
tensor, activations_type, tensor->quantization.get(), error_reporter));
|
||||||
tensor->quantization.get());
|
tensor->type = activations_type;
|
||||||
tensor->type = TensorType_INT8;
|
return kTfLiteOk;
|
||||||
}
|
}
|
||||||
|
|
||||||
TfLiteStatus QuantizeActivationToInt16(TensorT* tensor, float scale) {
|
TfLiteStatus QuantizeActivationToInt16(TensorT* tensor, float scale) {
|
||||||
|
@ -113,12 +113,14 @@ TfLiteStatus SymmetricQuantizeFloatsToInt16(ModelT* model, TensorT* tensor,
|
|||||||
ErrorReporter* error_reporter);
|
ErrorReporter* error_reporter);
|
||||||
|
|
||||||
// Symmetrically quantizes the bias for per-layer ops (e.g. FullyConnected).
|
// Symmetrically quantizes the bias for per-layer ops (e.g. FullyConnected).
|
||||||
|
template <typename BiasType>
|
||||||
TfLiteStatus SymmetricPerLayerBiasQuantize(ModelT* model, TensorT* tensor,
|
TfLiteStatus SymmetricPerLayerBiasQuantize(ModelT* model, TensorT* tensor,
|
||||||
float scaling_factor,
|
float scaling_factor,
|
||||||
ErrorReporter* error_reporter);
|
ErrorReporter* error_reporter);
|
||||||
|
|
||||||
// Symmetrically quantizes the bias for ops like Conv and DepthwiseConv.
|
// Symmetrically quantizes the bias for ops like Conv and DepthwiseConv.
|
||||||
// The scale of bias is weight_per_channel_scale[channel] * input_scale.
|
// The scale of bias is weight_per_channel_scale[channel] * input_scale.
|
||||||
|
template <typename BiasType>
|
||||||
TfLiteStatus SymmetricPerChannelBiasQuantize(ModelT* model, TensorT* tensor,
|
TfLiteStatus SymmetricPerChannelBiasQuantize(ModelT* model, TensorT* tensor,
|
||||||
float input_scale,
|
float input_scale,
|
||||||
const float* weight_scales,
|
const float* weight_scales,
|
||||||
@ -135,8 +137,14 @@ float GetEffectiveScale(ModelT* model, SubGraphT* subgraph, int op_idx,
|
|||||||
std::vector<int> intermediate_index,
|
std::vector<int> intermediate_index,
|
||||||
std::vector<float> factors);
|
std::vector<float> factors);
|
||||||
|
|
||||||
|
// Return quantization parameters depending on activations type.
|
||||||
|
TfLiteStatus GetQuantizationParams(TensorT* tensor, TensorType activations_type,
|
||||||
|
QuantizationParametersT* quantization_params,
|
||||||
|
ErrorReporter* error_reporter);
|
||||||
|
|
||||||
// Quantize activation.
|
// Quantize activation.
|
||||||
void QuantizeActivation(TensorT* tensor);
|
TfLiteStatus QuantizeActivation(TensorT* tensor, TensorType activations_type,
|
||||||
|
ErrorReporter* error_reporter);
|
||||||
|
|
||||||
// Quantize activation to 16bit.
|
// Quantize activation to 16bit.
|
||||||
TfLiteStatus QuantizeActivationToInt16(TensorT* tensor, float scale);
|
TfLiteStatus QuantizeActivationToInt16(TensorT* tensor, float scale);
|
||||||
|
@ -701,7 +701,7 @@ TEST_F(QuantizationUtilsTest, SymmetricPerLayerBiasQuantize) {
|
|||||||
model->buffers.push_back(std::move(buffer));
|
model->buffers.push_back(std::move(buffer));
|
||||||
|
|
||||||
// Call and verify.
|
// Call and verify.
|
||||||
EXPECT_EQ(SymmetricPerLayerBiasQuantize(
|
EXPECT_EQ(SymmetricPerLayerBiasQuantize<int32_t>(
|
||||||
model.get(), model->subgraphs[0]->tensors[0].get(),
|
model.get(), model->subgraphs[0]->tensors[0].get(),
|
||||||
input_scale * weight_scale, &error_reporter_),
|
input_scale * weight_scale, &error_reporter_),
|
||||||
kTfLiteOk);
|
kTfLiteOk);
|
||||||
@ -759,7 +759,7 @@ TEST_F(QuantizationUtilsTest, SymmetricPerChannelBiasQuantize) {
|
|||||||
model->buffers.push_back(std::move(buffer));
|
model->buffers.push_back(std::move(buffer));
|
||||||
|
|
||||||
// Call and verify.
|
// Call and verify.
|
||||||
EXPECT_EQ(SymmetricPerChannelBiasQuantize(
|
EXPECT_EQ(SymmetricPerChannelBiasQuantize<int32_t>(
|
||||||
model.get(), model->subgraphs[0]->tensors[0].get(), input_scale,
|
model.get(), model->subgraphs[0]->tensors[0].get(), input_scale,
|
||||||
weight_scales.data(), 2, &error_reporter_),
|
weight_scales.data(), 2, &error_reporter_),
|
||||||
kTfLiteOk);
|
kTfLiteOk);
|
||||||
|
@ -42,7 +42,9 @@ bool CreateQuantizedModel(const std::string& path) {
|
|||||||
tflite::StderrReporter error_reporter;
|
tflite::StderrReporter error_reporter;
|
||||||
if (tflite::optimize::QuantizeModel(
|
if (tflite::optimize::QuantizeModel(
|
||||||
&builder, &model, tflite::TensorType_FLOAT32,
|
&builder, &model, tflite::TensorType_FLOAT32,
|
||||||
tflite::TensorType_FLOAT32, &error_reporter) != kTfLiteOk) {
|
tflite::TensorType_FLOAT32,
|
||||||
|
// TODO: Pass required activation type if needed
|
||||||
|
tflite::TensorType_INT8, &error_reporter) != kTfLiteOk) {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
return WriteFile(path, builder.GetBufferPointer(), builder.GetSize());
|
return WriteFile(path, builder.GetBufferPointer(), builder.GetSize());
|
||||||
|
@ -64,6 +64,7 @@ operator_property::OperatorProperty GetOperatorProperty(
|
|||||||
TfLiteStatus QuantizeBias(ModelT* model, const TensorT* input_tensor,
|
TfLiteStatus QuantizeBias(ModelT* model, const TensorT* input_tensor,
|
||||||
const TensorT* weight_tensor, TensorT* bias_tensor,
|
const TensorT* weight_tensor, TensorT* bias_tensor,
|
||||||
bool is_per_channel, int channel_dim_index,
|
bool is_per_channel, int channel_dim_index,
|
||||||
|
const TensorType& activations_type,
|
||||||
ErrorReporter* error_reporter) {
|
ErrorReporter* error_reporter) {
|
||||||
if (bias_tensor->shape.size() != 1) {
|
if (bias_tensor->shape.size() != 1) {
|
||||||
error_reporter->Report("Expected bias tensor shape to be 1.");
|
error_reporter->Report("Expected bias tensor shape to be 1.");
|
||||||
@ -92,9 +93,15 @@ TfLiteStatus QuantizeBias(ModelT* model, const TensorT* input_tensor,
|
|||||||
weight_scales.size());
|
weight_scales.size());
|
||||||
return kTfLiteError;
|
return kTfLiteError;
|
||||||
}
|
}
|
||||||
return utils::SymmetricPerChannelBiasQuantize(
|
if (activations_type == tflite::TensorType_INT16) {
|
||||||
|
return utils::SymmetricPerChannelBiasQuantize<std::int64_t>(
|
||||||
model, bias_tensor, input_tensor->quantization->scale[0],
|
model, bias_tensor, input_tensor->quantization->scale[0],
|
||||||
weight_scales.data(), channel_dim_size, error_reporter);
|
weight_scales.data(), channel_dim_size, error_reporter);
|
||||||
|
} else {
|
||||||
|
return utils::SymmetricPerChannelBiasQuantize<std::int32_t>(
|
||||||
|
model, bias_tensor, input_tensor->quantization->scale[0],
|
||||||
|
weight_scales.data(), channel_dim_size, error_reporter);
|
||||||
|
}
|
||||||
} else {
|
} else {
|
||||||
if (weight_scales.size() != 1) {
|
if (weight_scales.size() != 1) {
|
||||||
error_reporter->Report(
|
error_reporter->Report(
|
||||||
@ -102,40 +109,54 @@ TfLiteStatus QuantizeBias(ModelT* model, const TensorT* input_tensor,
|
|||||||
weight_scales.size());
|
weight_scales.size());
|
||||||
return kTfLiteError;
|
return kTfLiteError;
|
||||||
}
|
}
|
||||||
return utils::SymmetricPerLayerBiasQuantize(
|
if (activations_type == tflite::TensorType_INT16) {
|
||||||
|
return utils::SymmetricPerLayerBiasQuantize<std::int64_t>(
|
||||||
model, bias_tensor,
|
model, bias_tensor,
|
||||||
input_tensor->quantization->scale[0] * weight_scales[0],
|
input_tensor->quantization->scale[0] * weight_scales[0],
|
||||||
error_reporter);
|
error_reporter);
|
||||||
|
} else {
|
||||||
|
return utils::SymmetricPerLayerBiasQuantize<std::int32_t>(
|
||||||
|
model, bias_tensor,
|
||||||
|
input_tensor->quantization->scale[0] * weight_scales[0],
|
||||||
|
error_reporter);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
return kTfLiteError;
|
return kTfLiteError;
|
||||||
}
|
}
|
||||||
|
|
||||||
// True if the tensor type has to be modified.
|
// True if the tensor type has to be modified.
|
||||||
bool TensorTypeChangeRequired(const TensorT* tensor, const TensorType& type) {
|
bool TensorTypeChangeRequired(const TensorT* tensor, const TensorType& type) {
|
||||||
// The quantized model is type INT8, so if the user provided type is INT8, we
|
// The quantized model is type INT8/INT16, so if the user provided type is
|
||||||
// do not have to do any custom logic. Additionally, if the current tensor
|
// INT8/INT16, we do not have to do any custom logic. Additionally, if the
|
||||||
// isn't INT8 quantized, the custom type doesn't apply.
|
// current tensor isn't INT8/INT16 quantized, the custom type doesn't apply.
|
||||||
return (type != TensorType_INT8 && tensor->type == TensorType_INT8 &&
|
bool int8check = type != TensorType_INT8 && tensor->type == TensorType_INT8 &&
|
||||||
!tensor->quantization->scale.empty());
|
!tensor->quantization->scale.empty();
|
||||||
|
bool int16check = type != TensorType_INT16 &&
|
||||||
|
tensor->type == TensorType_INT16 &&
|
||||||
|
!tensor->quantization->scale.empty();
|
||||||
|
return (int8check || int16check);
|
||||||
}
|
}
|
||||||
|
|
||||||
// Sets the input type, adding a Leading Op node at the start of the model if
|
// Sets the input type, adding a Leading Op node at the start of the model if
|
||||||
// necessary.
|
// necessary.
|
||||||
// Returns the new input tensor index.
|
// Returns the new input tensor index.
|
||||||
int32_t SetInputType(ModelT* model, SubGraphT* subgraph,
|
int32_t SetInputType(ModelT* model, SubGraphT* subgraph,
|
||||||
const int32_t tensor_idx, const TensorType& input_type) {
|
const int32_t tensor_idx, const TensorType& input_type,
|
||||||
|
const TensorType& activations_type) {
|
||||||
TensorT* tensor = subgraph->tensors[tensor_idx].get();
|
TensorT* tensor = subgraph->tensors[tensor_idx].get();
|
||||||
if (!TensorTypeChangeRequired(tensor, input_type)) {
|
if (!TensorTypeChangeRequired(tensor, input_type)) {
|
||||||
return -1;
|
return -1;
|
||||||
}
|
}
|
||||||
if (input_type == TensorType_FLOAT32 || input_type == TensorType_UINT8) {
|
if (input_type == TensorType_FLOAT32 || input_type == TensorType_UINT8) {
|
||||||
|
std::string type_string =
|
||||||
|
activations_type == TensorType_INT16 ? "int16" : "int8";
|
||||||
// Create a new tensor to be the input of the leading Op.
|
// Create a new tensor to be the input of the leading Op.
|
||||||
std::unique_ptr<TensorT> leading_op_input;
|
std::unique_ptr<TensorT> leading_op_input;
|
||||||
if (input_type == TensorType_FLOAT32) {
|
if (input_type == TensorType_FLOAT32) {
|
||||||
// Add tensor for quantize operator. Scales and zero points are not
|
// Add tensor for quantize operator. Scales and zero points are not
|
||||||
// needed.
|
// needed.
|
||||||
const string leading_op_name = tensor->name;
|
const string leading_op_name = tensor->name;
|
||||||
const string new_name_original_input = tensor->name + "_int8";
|
const string new_name_original_input = tensor->name + "_" + type_string;
|
||||||
tensor->name = new_name_original_input;
|
tensor->name = new_name_original_input;
|
||||||
utils::MakeTensor(leading_op_name, tensor->shape, input_type,
|
utils::MakeTensor(leading_op_name, tensor->shape, input_type,
|
||||||
&leading_op_input);
|
&leading_op_input);
|
||||||
@ -150,7 +171,7 @@ int32_t SetInputType(ModelT* model, SubGraphT* subgraph,
|
|||||||
TFLITE_DCHECK_GE(zero_point, -128);
|
TFLITE_DCHECK_GE(zero_point, -128);
|
||||||
TFLITE_DCHECK_LE(zero_point, 127);
|
TFLITE_DCHECK_LE(zero_point, 127);
|
||||||
const string leading_op_name = tensor->name;
|
const string leading_op_name = tensor->name;
|
||||||
const string new_name_original_input = tensor->name + "_int8";
|
const string new_name_original_input = tensor->name + "_" + type_string;
|
||||||
tensor->name = new_name_original_input;
|
tensor->name = new_name_original_input;
|
||||||
utils::MakeTensorWithQuantParam(leading_op_name, tensor->shape,
|
utils::MakeTensorWithQuantParam(leading_op_name, tensor->shape,
|
||||||
input_type, scale, zero_point + 128,
|
input_type, scale, zero_point + 128,
|
||||||
@ -177,17 +198,20 @@ int32_t SetInputType(ModelT* model, SubGraphT* subgraph,
|
|||||||
// necessary.
|
// necessary.
|
||||||
// Returns the new output tensor index.
|
// Returns the new output tensor index.
|
||||||
int32_t SetOutputType(ModelT* model, SubGraphT* subgraph,
|
int32_t SetOutputType(ModelT* model, SubGraphT* subgraph,
|
||||||
const int32_t tensor_idx, const TensorType& output_type) {
|
const int32_t tensor_idx, const TensorType& output_type,
|
||||||
|
const TensorType& activations_type) {
|
||||||
TensorT* tensor = subgraph->tensors[tensor_idx].get();
|
TensorT* tensor = subgraph->tensors[tensor_idx].get();
|
||||||
if (!TensorTypeChangeRequired(tensor, output_type)) {
|
if (!TensorTypeChangeRequired(tensor, output_type)) {
|
||||||
return -1;
|
return -1;
|
||||||
}
|
}
|
||||||
if (output_type == TensorType_FLOAT32 || output_type == TensorType_UINT8) {
|
if (output_type == TensorType_FLOAT32 || output_type == TensorType_UINT8) {
|
||||||
|
std::string type_string =
|
||||||
|
activations_type == TensorType_INT16 ? "int16" : "int8";
|
||||||
// Create a new tensor to be the output of the tailing op.
|
// Create a new tensor to be the output of the tailing op.
|
||||||
std::unique_ptr<TensorT> tailing_op_output;
|
std::unique_ptr<TensorT> tailing_op_output;
|
||||||
if (output_type == TensorType_FLOAT32) {
|
if (output_type == TensorType_FLOAT32) {
|
||||||
const string tailing_op_name = tensor->name;
|
const string tailing_op_name = tensor->name;
|
||||||
const string new_name_original_output = tensor->name + "_int8";
|
const string new_name_original_output = tensor->name + "_" + type_string;
|
||||||
tensor->name = new_name_original_output;
|
tensor->name = new_name_original_output;
|
||||||
utils::MakeTensor(tailing_op_name, tensor->shape, output_type,
|
utils::MakeTensor(tailing_op_name, tensor->shape, output_type,
|
||||||
&tailing_op_output);
|
&tailing_op_output);
|
||||||
@ -202,7 +226,7 @@ int32_t SetOutputType(ModelT* model, SubGraphT* subgraph,
|
|||||||
TFLITE_DCHECK_GE(zero_point, -128);
|
TFLITE_DCHECK_GE(zero_point, -128);
|
||||||
TFLITE_DCHECK_LE(zero_point, 127);
|
TFLITE_DCHECK_LE(zero_point, 127);
|
||||||
const string tailing_op_name = tensor->name;
|
const string tailing_op_name = tensor->name;
|
||||||
const string new_name_original_output = tensor->name + "_int8";
|
const string new_name_original_output = tensor->name + "_" + type_string;
|
||||||
tensor->name = new_name_original_output;
|
tensor->name = new_name_original_output;
|
||||||
utils::MakeTensorWithQuantParam(tailing_op_name, tensor->shape,
|
utils::MakeTensorWithQuantParam(tailing_op_name, tensor->shape,
|
||||||
output_type, scale, zero_point + 128,
|
output_type, scale, zero_point + 128,
|
||||||
@ -238,6 +262,7 @@ int32_t SetOutputType(ModelT* model, SubGraphT* subgraph,
|
|||||||
// uint8, can be thought as "requant").
|
// uint8, can be thought as "requant").
|
||||||
TfLiteStatus SetInputAndOutputTypes(ModelT* model, const TensorType& input_type,
|
TfLiteStatus SetInputAndOutputTypes(ModelT* model, const TensorType& input_type,
|
||||||
const TensorType& output_type,
|
const TensorType& output_type,
|
||||||
|
const TensorType& activations_type,
|
||||||
ErrorReporter* error_reporter) {
|
ErrorReporter* error_reporter) {
|
||||||
for (int subgraph_idx = 0; subgraph_idx < model->subgraphs.size();
|
for (int subgraph_idx = 0; subgraph_idx < model->subgraphs.size();
|
||||||
subgraph_idx++) {
|
subgraph_idx++) {
|
||||||
@ -253,8 +278,8 @@ TfLiteStatus SetInputAndOutputTypes(ModelT* model, const TensorType& input_type,
|
|||||||
EnumNameTensorType(tensor->type));
|
EnumNameTensorType(tensor->type));
|
||||||
return kTfLiteError;
|
return kTfLiteError;
|
||||||
}
|
}
|
||||||
const int32_t input_idx =
|
const int32_t input_idx = SetInputType(
|
||||||
SetInputType(model, subgraph, subgraph->inputs[i], input_type);
|
model, subgraph, subgraph->inputs[i], input_type, activations_type);
|
||||||
if (input_idx < 0) {
|
if (input_idx < 0) {
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
@ -270,8 +295,8 @@ TfLiteStatus SetInputAndOutputTypes(ModelT* model, const TensorType& input_type,
|
|||||||
EnumNameTensorType(tensor->type));
|
EnumNameTensorType(tensor->type));
|
||||||
return kTfLiteError;
|
return kTfLiteError;
|
||||||
}
|
}
|
||||||
const int32_t output_idx =
|
const int32_t output_idx = SetOutputType(
|
||||||
SetOutputType(model, subgraph, subgraph->outputs[i], output_type);
|
model, subgraph, subgraph->outputs[i], output_type, activations_type);
|
||||||
if (output_idx < 0) {
|
if (output_idx < 0) {
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
@ -287,6 +312,7 @@ TfLiteStatus SetInputAndOutputTypes(ModelT* model, const TensorType& input_type,
|
|||||||
// The other ones with constraints are handled in QuantizeWeightsAndInput.
|
// The other ones with constraints are handled in QuantizeWeightsAndInput.
|
||||||
TfLiteStatus ApplyConstraints(ModelT* model,
|
TfLiteStatus ApplyConstraints(ModelT* model,
|
||||||
const std::unordered_set<string>& operator_names,
|
const std::unordered_set<string>& operator_names,
|
||||||
|
TensorType activations_type,
|
||||||
ErrorReporter* error_reporter) {
|
ErrorReporter* error_reporter) {
|
||||||
for (int subgraph_idx = 0; subgraph_idx < model->subgraphs.size();
|
for (int subgraph_idx = 0; subgraph_idx < model->subgraphs.size();
|
||||||
subgraph_idx++) {
|
subgraph_idx++) {
|
||||||
@ -332,7 +358,7 @@ TfLiteStatus ApplyConstraints(ModelT* model,
|
|||||||
std::unique_ptr<TensorT> additional_tensor;
|
std::unique_ptr<TensorT> additional_tensor;
|
||||||
const string requant_tensor_name = input_tensor->name + "_requantized";
|
const string requant_tensor_name = input_tensor->name + "_requantized";
|
||||||
utils::MakeTensorWithQuantParam(
|
utils::MakeTensorWithQuantParam(
|
||||||
requant_tensor_name, input_tensor->shape, TensorType_INT8,
|
requant_tensor_name, input_tensor->shape, activations_type,
|
||||||
output_scale, output_zp, &additional_tensor);
|
output_scale, output_zp, &additional_tensor);
|
||||||
const int32_t additional_tensor_idx = subgraph->tensors.size();
|
const int32_t additional_tensor_idx = subgraph->tensors.size();
|
||||||
subgraph->tensors.push_back(std::move(additional_tensor));
|
subgraph->tensors.push_back(std::move(additional_tensor));
|
||||||
@ -382,7 +408,8 @@ std::vector<std::pair<int, operator_property::TensorProperty>> GetOutputs(
|
|||||||
|
|
||||||
bool ShouldRestrictSameInputOutputScale(
|
bool ShouldRestrictSameInputOutputScale(
|
||||||
operator_property::OperatorProperty property) {
|
operator_property::OperatorProperty property) {
|
||||||
// Ops with multiple inputs (i.e. concat) gets restricted in ApplyConstraints.
|
// Ops with multiple inputs (e.g. concat, max and min) get restricted in
|
||||||
|
// ApplyConstraints.
|
||||||
return (!property.arbitrary_inputs &&
|
return (!property.arbitrary_inputs &&
|
||||||
property.restrict_same_input_output_scale);
|
property.restrict_same_input_output_scale);
|
||||||
}
|
}
|
||||||
@ -401,7 +428,7 @@ TfLiteStatus QuantizeOpInput(
|
|||||||
ModelT* model, int32_t subgraph_idx, size_t* op_idx,
|
ModelT* model, int32_t subgraph_idx, size_t* op_idx,
|
||||||
operator_property::OperatorProperty property,
|
operator_property::OperatorProperty property,
|
||||||
const std::pair<int32_t, operator_property::TensorProperty>& input,
|
const std::pair<int32_t, operator_property::TensorProperty>& input,
|
||||||
ErrorReporter* error_reporter) {
|
const TensorType& activations_type, ErrorReporter* error_reporter) {
|
||||||
int32_t input_idx = input.first;
|
int32_t input_idx = input.first;
|
||||||
operator_property::TensorProperty tensor_property = input.second;
|
operator_property::TensorProperty tensor_property = input.second;
|
||||||
SubGraphT* subgraph = model->subgraphs.at(subgraph_idx).get();
|
SubGraphT* subgraph = model->subgraphs.at(subgraph_idx).get();
|
||||||
@ -429,7 +456,9 @@ TfLiteStatus QuantizeOpInput(
|
|||||||
if (utils::HasBuffer(model, subgraph, tensor_idx)) {
|
if (utils::HasBuffer(model, subgraph, tensor_idx)) {
|
||||||
// TODO(suharshs): Look at consumers, throw error if one consumer is
|
// TODO(suharshs): Look at consumers, throw error if one consumer is
|
||||||
// per-channel and one per-layer.
|
// per-channel and one per-layer.
|
||||||
if (tensor_property.number_of_bits == 8) {
|
bool quantize_const_input = property.quantize_input_as_activations &&
|
||||||
|
activations_type == TensorType_INT16;
|
||||||
|
if (tensor_property.number_of_bits == 8 && !quantize_const_input) {
|
||||||
if (tensor_property.use_derived_scale) {
|
if (tensor_property.use_derived_scale) {
|
||||||
// Currently 8bit tensors in input do not accept derived scale.
|
// Currently 8bit tensors in input do not accept derived scale.
|
||||||
return kTfLiteError;
|
return kTfLiteError;
|
||||||
@ -444,7 +473,7 @@ TfLiteStatus QuantizeOpInput(
|
|||||||
*op_idx);
|
*op_idx);
|
||||||
return kTfLiteError;
|
return kTfLiteError;
|
||||||
}
|
}
|
||||||
} else if (tensor_property.number_of_bits == 16) {
|
} else if (tensor_property.number_of_bits == 16 || quantize_const_input) {
|
||||||
if (tensor_property.use_derived_scale) {
|
if (tensor_property.use_derived_scale) {
|
||||||
// Currently 16bit tensors in input do not accept derived scale.
|
// Currently 16bit tensors in input do not accept derived scale.
|
||||||
return kTfLiteError;
|
return kTfLiteError;
|
||||||
@ -476,8 +505,8 @@ TfLiteStatus QuantizeOpInput(
|
|||||||
tensor_property.derived_scale.input_tensors,
|
tensor_property.derived_scale.input_tensors,
|
||||||
tensor_property.derived_scale.intermediate_tensors,
|
tensor_property.derived_scale.intermediate_tensors,
|
||||||
tensor_property.derived_scale.factors);
|
tensor_property.derived_scale.factors);
|
||||||
return utils::SymmetricPerLayerBiasQuantize(model, tensor, scale,
|
return utils::SymmetricPerLayerBiasQuantize<std::int32_t>(
|
||||||
error_reporter);
|
model, tensor, scale, error_reporter);
|
||||||
|
|
||||||
} else if (tensor_property.number_of_bits == 10) {
|
} else if (tensor_property.number_of_bits == 10) {
|
||||||
// When the number of bits is 10 (instead of 16), quantize the tensor to
|
// When the number of bits is 10 (instead of 16), quantize the tensor to
|
||||||
@ -514,7 +543,8 @@ TfLiteStatus QuantizeOpInput(
|
|||||||
// Currently 8bit tensors in input do not accept derived scale.
|
// Currently 8bit tensors in input do not accept derived scale.
|
||||||
return kTfLiteError;
|
return kTfLiteError;
|
||||||
}
|
}
|
||||||
utils::QuantizeActivation(tensor);
|
TF_LITE_ENSURE_STATUS(utils::QuantizeActivation(
|
||||||
|
tensor, activations_type, error_reporter));
|
||||||
} else if (tensor_property.number_of_bits == 16) {
|
} else if (tensor_property.number_of_bits == 16) {
|
||||||
TensorT* tensor = subgraph->tensors[tensor_idx].get();
|
TensorT* tensor = subgraph->tensors[tensor_idx].get();
|
||||||
float quantized_range = 32767.0;
|
float quantized_range = 32767.0;
|
||||||
@ -532,13 +562,16 @@ TfLiteStatus QuantizeOpInput(
|
|||||||
} else {
|
} else {
|
||||||
// If the tensor is not a model input, we need to add a Quantize
|
// If the tensor is not a model input, we need to add a Quantize
|
||||||
// operation since the preceding op may require a float output.
|
// operation since the preceding op may require a float output.
|
||||||
|
std::string type_string =
|
||||||
|
activations_type == TensorType_INT16 ? "int16" : "int8";
|
||||||
std::unique_ptr<TensorT> op_output;
|
std::unique_ptr<TensorT> op_output;
|
||||||
utils::MakeTensor(tensor->name + "_int8", tensor->shape,
|
utils::MakeTensor(tensor->name + "_" + type_string, tensor->shape,
|
||||||
TensorType_INT8, &op_output);
|
activations_type, &op_output);
|
||||||
op_output->quantization = absl::make_unique<QuantizationParametersT>();
|
op_output->quantization = absl::make_unique<QuantizationParametersT>();
|
||||||
op_output->quantization->min.push_back(tensor->quantization->min[0]);
|
op_output->quantization->min.push_back(tensor->quantization->min[0]);
|
||||||
op_output->quantization->max.push_back(tensor->quantization->max[0]);
|
op_output->quantization->max.push_back(tensor->quantization->max[0]);
|
||||||
utils::QuantizeActivation(op_output.get());
|
TF_LITE_ENSURE_STATUS(utils::QuantizeActivation(
|
||||||
|
op_output.get(), activations_type, error_reporter));
|
||||||
const int32_t quant_op_output_idx = subgraph->tensors.size();
|
const int32_t quant_op_output_idx = subgraph->tensors.size();
|
||||||
subgraph->tensors.push_back(std::move(op_output));
|
subgraph->tensors.push_back(std::move(op_output));
|
||||||
std::unique_ptr<OperatorT> quant_op;
|
std::unique_ptr<OperatorT> quant_op;
|
||||||
@ -580,7 +613,7 @@ TfLiteStatus QuantizeOpOutput(
|
|||||||
ModelT* model, int32_t subgraph_idx, int32_t op_idx,
|
ModelT* model, int32_t subgraph_idx, int32_t op_idx,
|
||||||
operator_property::OperatorProperty property,
|
operator_property::OperatorProperty property,
|
||||||
const std::pair<int32_t, operator_property::TensorProperty>& output,
|
const std::pair<int32_t, operator_property::TensorProperty>& output,
|
||||||
ErrorReporter* error_reporter) {
|
TensorType activations_type, ErrorReporter* error_reporter) {
|
||||||
int32_t output_idx = output.first;
|
int32_t output_idx = output.first;
|
||||||
operator_property::TensorProperty tensor_property = output.second;
|
operator_property::TensorProperty tensor_property = output.second;
|
||||||
// If the operator is not quantizable, we don't need to do anything for the
|
// If the operator is not quantizable, we don't need to do anything for the
|
||||||
@ -644,18 +677,22 @@ TfLiteStatus QuantizeOpOutput(
|
|||||||
const float max = input_tensor->quantization->max[0];
|
const float max = input_tensor->quantization->max[0];
|
||||||
output_tensor->quantization->max = {max};
|
output_tensor->quantization->max = {max};
|
||||||
}
|
}
|
||||||
output_tensor->type = TensorType_INT8;
|
output_tensor->type = activations_type;
|
||||||
} else if (tensor_property.restriction) {
|
} else if (tensor_property.restriction) {
|
||||||
const auto scale_and_zp = tensor_property.restricted_value;
|
const auto scale_and_zp = activations_type == TensorType_INT16
|
||||||
|
? tensor_property.restricted_value_int16
|
||||||
|
: tensor_property.restricted_value_int8;
|
||||||
|
|
||||||
// Apply to output.
|
// Apply to output.
|
||||||
output_tensor->quantization = absl::make_unique<QuantizationParametersT>();
|
output_tensor->quantization = absl::make_unique<QuantizationParametersT>();
|
||||||
output_tensor->quantization->scale.push_back(scale_and_zp.first);
|
output_tensor->quantization->scale.push_back(scale_and_zp.first);
|
||||||
output_tensor->quantization->zero_point.push_back(scale_and_zp.second);
|
output_tensor->quantization->zero_point.push_back(scale_and_zp.second);
|
||||||
output_tensor->type = TensorType_INT8;
|
output_tensor->type = activations_type;
|
||||||
} else {
|
} else {
|
||||||
// Process regular output that doesn't have any restrictions.
|
// Process regular output that doesn't have any restrictions.
|
||||||
if (utils::HasMinMax(output_tensor)) {
|
if (utils::HasMinMax(output_tensor)) {
|
||||||
utils::QuantizeActivation(output_tensor);
|
utils::QuantizeActivation(output_tensor, activations_type,
|
||||||
|
error_reporter);
|
||||||
} else {
|
} else {
|
||||||
error_reporter->Report(
|
error_reporter->Report(
|
||||||
"Unable to find min/max value for output %d in %s in "
|
"Unable to find min/max value for output %d in %s in "
|
||||||
@ -668,6 +705,7 @@ TfLiteStatus QuantizeOpOutput(
|
|||||||
}
|
}
|
||||||
|
|
||||||
TfLiteStatus QuantizeIntemediateTensors(ModelT* model,
|
TfLiteStatus QuantizeIntemediateTensors(ModelT* model,
|
||||||
|
TensorType activations_type,
|
||||||
ErrorReporter* error_reporter) {
|
ErrorReporter* error_reporter) {
|
||||||
for (size_t subgraph_idx = 0; subgraph_idx < model->subgraphs.size();
|
for (size_t subgraph_idx = 0; subgraph_idx < model->subgraphs.size();
|
||||||
subgraph_idx++) {
|
subgraph_idx++) {
|
||||||
@ -691,7 +729,8 @@ TfLiteStatus QuantizeIntemediateTensors(ModelT* model,
|
|||||||
input.second.symmetric == false) {
|
input.second.symmetric == false) {
|
||||||
TensorT* tensor = subgraph->tensors[index_global].get();
|
TensorT* tensor = subgraph->tensors[index_global].get();
|
||||||
if (utils::HasMinMax(tensor)) {
|
if (utils::HasMinMax(tensor)) {
|
||||||
utils::QuantizeActivation(tensor);
|
utils::QuantizeActivation(tensor, activations_type,
|
||||||
|
error_reporter);
|
||||||
} else {
|
} else {
|
||||||
error_reporter->Report(
|
error_reporter->Report(
|
||||||
"Unable to find min/max value for output %d in %s in "
|
"Unable to find min/max value for output %d in %s in "
|
||||||
@ -793,7 +832,7 @@ TfLiteStatus QuantizeSharedRange(ModelT* model, ErrorReporter* error_reporter) {
|
|||||||
TfLiteStatus QuantizeWeightsInputOutput(
|
TfLiteStatus QuantizeWeightsInputOutput(
|
||||||
ModelT* model, bool allow_float,
|
ModelT* model, bool allow_float,
|
||||||
const std::unordered_set<string>& operator_names,
|
const std::unordered_set<string>& operator_names,
|
||||||
ErrorReporter* error_reporter) {
|
const TensorType& activations_type, ErrorReporter* error_reporter) {
|
||||||
for (size_t subgraph_idx = 0; subgraph_idx < model->subgraphs.size();
|
for (size_t subgraph_idx = 0; subgraph_idx < model->subgraphs.size();
|
||||||
subgraph_idx++) {
|
subgraph_idx++) {
|
||||||
SubGraphT* subgraph = model->subgraphs.at(subgraph_idx).get();
|
SubGraphT* subgraph = model->subgraphs.at(subgraph_idx).get();
|
||||||
@ -815,14 +854,16 @@ TfLiteStatus QuantizeWeightsInputOutput(
|
|||||||
for (const std::pair<int, operator_property::TensorProperty>& input :
|
for (const std::pair<int, operator_property::TensorProperty>& input :
|
||||||
GetInputs(op, property)) {
|
GetInputs(op, property)) {
|
||||||
TF_LITE_ENSURE_STATUS(QuantizeOpInput(model, subgraph_idx, &op_idx,
|
TF_LITE_ENSURE_STATUS(QuantizeOpInput(model, subgraph_idx, &op_idx,
|
||||||
property, input, error_reporter));
|
property, input, activations_type,
|
||||||
|
error_reporter));
|
||||||
}
|
}
|
||||||
|
|
||||||
// Quantize operator outputs.
|
// Quantize operator outputs.
|
||||||
for (const std::pair<int, operator_property::TensorProperty>& output :
|
for (const std::pair<int, operator_property::TensorProperty>& output :
|
||||||
GetOutputs(op, property)) {
|
GetOutputs(op, property)) {
|
||||||
TF_LITE_ENSURE_STATUS(QuantizeOpOutput(
|
TF_LITE_ENSURE_STATUS(
|
||||||
model, subgraph_idx, op_idx, property, output, error_reporter));
|
QuantizeOpOutput(model, subgraph_idx, op_idx, property, output,
|
||||||
|
activations_type, error_reporter));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -832,6 +873,7 @@ TfLiteStatus QuantizeWeightsInputOutput(
|
|||||||
// Quantize bias.
|
// Quantize bias.
|
||||||
TfLiteStatus QuantizeBiases(ModelT* model,
|
TfLiteStatus QuantizeBiases(ModelT* model,
|
||||||
const std::unordered_set<string>& operator_names,
|
const std::unordered_set<string>& operator_names,
|
||||||
|
const TensorType& activations_type,
|
||||||
ErrorReporter* error_reporter) {
|
ErrorReporter* error_reporter) {
|
||||||
for (size_t subgraph_idx = 0; subgraph_idx < model->subgraphs.size();
|
for (size_t subgraph_idx = 0; subgraph_idx < model->subgraphs.size();
|
||||||
subgraph_idx++) {
|
subgraph_idx++) {
|
||||||
@ -877,10 +919,10 @@ TfLiteStatus QuantizeBiases(ModelT* model,
|
|||||||
subgraph->tensors[op->inputs[property.inputs[1].first]].get();
|
subgraph->tensors[op->inputs[property.inputs[1].first]].get();
|
||||||
operator_property::TensorProperty weight_property =
|
operator_property::TensorProperty weight_property =
|
||||||
property.inputs[1].second;
|
property.inputs[1].second;
|
||||||
TF_LITE_ENSURE_STATUS(
|
TF_LITE_ENSURE_STATUS(QuantizeBias(
|
||||||
QuantizeBias(model, input_tensor, weight_tensor, bias_tensor,
|
model, input_tensor, weight_tensor, bias_tensor,
|
||||||
weight_property.per_axis,
|
weight_property.per_axis, weight_property.per_axis_index,
|
||||||
weight_property.per_axis_index, error_reporter));
|
activations_type, error_reporter));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -1000,7 +1042,7 @@ TfLiteStatus FillQuantizationParams(
|
|||||||
// Check compatibility of activation, weight and bias scales. Adjust if needed.
|
// Check compatibility of activation, weight and bias scales. Adjust if needed.
|
||||||
TfLiteStatus EnsureBiasScaleCompatibility(
|
TfLiteStatus EnsureBiasScaleCompatibility(
|
||||||
ModelT* model, const std::unordered_set<string>& operator_names,
|
ModelT* model, const std::unordered_set<string>& operator_names,
|
||||||
ErrorReporter* error_reporter) {
|
TensorType activations_type, ErrorReporter* error_reporter) {
|
||||||
for (size_t subgraph_idx = 0; subgraph_idx < model->subgraphs.size();
|
for (size_t subgraph_idx = 0; subgraph_idx < model->subgraphs.size();
|
||||||
subgraph_idx++) {
|
subgraph_idx++) {
|
||||||
SubGraphT* subgraph = model->subgraphs.at(subgraph_idx).get();
|
SubGraphT* subgraph = model->subgraphs.at(subgraph_idx).get();
|
||||||
@ -1049,11 +1091,9 @@ TfLiteStatus EnsureBiasScaleCompatibility(
|
|||||||
|
|
||||||
// Get input scale for assymmetric quantization.
|
// Get input scale for assymmetric quantization.
|
||||||
QuantizationParametersT temp_quant_params = QuantizationParametersT();
|
QuantizationParametersT temp_quant_params = QuantizationParametersT();
|
||||||
utils::GetAsymmetricQuantizationParams(
|
TF_LITE_ENSURE_STATUS(
|
||||||
input_tensor->quantization->min[0],
|
utils::GetQuantizationParams(input_tensor, activations_type,
|
||||||
input_tensor->quantization->max[0],
|
&temp_quant_params, error_reporter));
|
||||||
std::numeric_limits<int8_t>::min(),
|
|
||||||
std::numeric_limits<int8_t>::max(), &temp_quant_params);
|
|
||||||
if (temp_quant_params.scale.size() != 1) {
|
if (temp_quant_params.scale.size() != 1) {
|
||||||
error_reporter->Report("Unexpected input quantization scale size.");
|
error_reporter->Report("Unexpected input quantization scale size.");
|
||||||
return kTfLiteError;
|
return kTfLiteError;
|
||||||
@ -1132,21 +1172,24 @@ TfLiteStatus QuantizeModel(flatbuffers::FlatBufferBuilder* builder,
|
|||||||
ModelT* model, const TensorType& input_type,
|
ModelT* model, const TensorType& input_type,
|
||||||
const TensorType& output_type, bool allow_float,
|
const TensorType& output_type, bool allow_float,
|
||||||
const std::unordered_set<string>& operator_names,
|
const std::unordered_set<string>& operator_names,
|
||||||
|
const TensorType& activations_type,
|
||||||
ErrorReporter* error_reporter) {
|
ErrorReporter* error_reporter) {
|
||||||
TF_LITE_ENSURE_STATUS(
|
TF_LITE_ENSURE_STATUS(
|
||||||
FillQuantizationParams(model, operator_names, error_reporter));
|
FillQuantizationParams(model, operator_names, error_reporter));
|
||||||
|
TF_LITE_ENSURE_STATUS(EnsureBiasScaleCompatibility(
|
||||||
|
model, operator_names, activations_type, error_reporter));
|
||||||
TF_LITE_ENSURE_STATUS(
|
TF_LITE_ENSURE_STATUS(
|
||||||
EnsureBiasScaleCompatibility(model, operator_names, error_reporter));
|
QuantizeIntemediateTensors(model, activations_type, error_reporter));
|
||||||
TF_LITE_ENSURE_STATUS(QuantizeIntemediateTensors(model, error_reporter));
|
|
||||||
TF_LITE_ENSURE_STATUS(QuantizeSharedRange(model, error_reporter));
|
TF_LITE_ENSURE_STATUS(QuantizeSharedRange(model, error_reporter));
|
||||||
TF_LITE_ENSURE_STATUS(QuantizeWeightsInputOutput(
|
TF_LITE_ENSURE_STATUS(QuantizeWeightsInputOutput(
|
||||||
model, allow_float, operator_names, error_reporter));
|
model, allow_float, operator_names, activations_type, error_reporter));
|
||||||
|
TF_LITE_ENSURE_STATUS(ApplyConstraints(model, operator_names,
|
||||||
|
activations_type, error_reporter));
|
||||||
TF_LITE_ENSURE_STATUS(
|
TF_LITE_ENSURE_STATUS(
|
||||||
ApplyConstraints(model, operator_names, error_reporter));
|
QuantizeBiases(model, operator_names, activations_type, error_reporter));
|
||||||
TF_LITE_ENSURE_STATUS(QuantizeBiases(model, operator_names, error_reporter));
|
|
||||||
utils::SetOperatorCodeVersion(model);
|
utils::SetOperatorCodeVersion(model);
|
||||||
TF_LITE_ENSURE_STATUS(
|
TF_LITE_ENSURE_STATUS(SetInputAndOutputTypes(
|
||||||
SetInputAndOutputTypes(model, input_type, output_type, error_reporter));
|
model, input_type, output_type, activations_type, error_reporter));
|
||||||
|
|
||||||
flatbuffers::Offset<Model> output_model_location =
|
flatbuffers::Offset<Model> output_model_location =
|
||||||
Model::Pack(*builder, model);
|
Model::Pack(*builder, model);
|
||||||
@ -1158,23 +1201,27 @@ TfLiteStatus QuantizeModel(flatbuffers::FlatBufferBuilder* builder,
|
|||||||
TfLiteStatus QuantizeModel(flatbuffers::FlatBufferBuilder* builder,
|
TfLiteStatus QuantizeModel(flatbuffers::FlatBufferBuilder* builder,
|
||||||
ModelT* model, const TensorType& input_type,
|
ModelT* model, const TensorType& input_type,
|
||||||
const TensorType& output_type, bool allow_float,
|
const TensorType& output_type, bool allow_float,
|
||||||
|
const TensorType& activations_type,
|
||||||
ErrorReporter* error_reporter) {
|
ErrorReporter* error_reporter) {
|
||||||
return QuantizeModel(builder, model, input_type, output_type, allow_float,
|
return QuantizeModel(builder, model, input_type, output_type, allow_float,
|
||||||
GetAllOperatorOutputs(model), error_reporter);
|
GetAllOperatorOutputs(model), activations_type,
|
||||||
|
error_reporter);
|
||||||
}
|
}
|
||||||
|
|
||||||
TfLiteStatus QuantizeModel(flatbuffers::FlatBufferBuilder* builder,
|
TfLiteStatus QuantizeModel(flatbuffers::FlatBufferBuilder* builder,
|
||||||
ModelT* model, const TensorType& input_type,
|
ModelT* model, const TensorType& input_type,
|
||||||
const TensorType& output_type,
|
const TensorType& output_type,
|
||||||
|
const TensorType& activations_type,
|
||||||
ErrorReporter* error_reporter) {
|
ErrorReporter* error_reporter) {
|
||||||
return QuantizeModel(builder, model, input_type, output_type,
|
return QuantizeModel(builder, model, input_type, output_type,
|
||||||
/*allow_float=*/false, error_reporter);
|
/*allow_float=*/false, activations_type, error_reporter);
|
||||||
}
|
}
|
||||||
|
|
||||||
TfLiteStatus QuantizeModel(flatbuffers::FlatBufferBuilder* builder,
|
TfLiteStatus QuantizeModel(flatbuffers::FlatBufferBuilder* builder,
|
||||||
ModelT* model, ErrorReporter* error_reporter) {
|
ModelT* model, const TensorType& activations_type,
|
||||||
|
ErrorReporter* error_reporter) {
|
||||||
return QuantizeModel(builder, model, TensorType_FLOAT32, TensorType_FLOAT32,
|
return QuantizeModel(builder, model, TensorType_FLOAT32, TensorType_FLOAT32,
|
||||||
/*allow_float=*/false, error_reporter);
|
/*allow_float=*/false, activations_type, error_reporter);
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace optimize
|
} // namespace optimize
|
||||||
|
@ -35,7 +35,9 @@ namespace optimize {
|
|||||||
//
|
//
|
||||||
// Note: This is a private API, subject to change.
|
// Note: This is a private API, subject to change.
|
||||||
TfLiteStatus QuantizeModel(flatbuffers::FlatBufferBuilder* builder,
|
TfLiteStatus QuantizeModel(flatbuffers::FlatBufferBuilder* builder,
|
||||||
ModelT* input_model, ErrorReporter* error_reporter);
|
ModelT* input_model,
|
||||||
|
const TensorType& activations_type,
|
||||||
|
ErrorReporter* error_reporter);
|
||||||
|
|
||||||
// Same as above, but the types of quantized inputs and outputs are
|
// Same as above, but the types of quantized inputs and outputs are
|
||||||
// configurable.
|
// configurable.
|
||||||
@ -44,6 +46,7 @@ TfLiteStatus QuantizeModel(flatbuffers::FlatBufferBuilder* builder,
|
|||||||
TfLiteStatus QuantizeModel(flatbuffers::FlatBufferBuilder* builder,
|
TfLiteStatus QuantizeModel(flatbuffers::FlatBufferBuilder* builder,
|
||||||
ModelT* input_model, const TensorType& input_type,
|
ModelT* input_model, const TensorType& input_type,
|
||||||
const TensorType& output_type,
|
const TensorType& output_type,
|
||||||
|
const TensorType& activations_type,
|
||||||
ErrorReporter* error_reporter);
|
ErrorReporter* error_reporter);
|
||||||
|
|
||||||
// Same as above, but can enable allowing float intermediate operations for ops
|
// Same as above, but can enable allowing float intermediate operations for ops
|
||||||
@ -53,6 +56,7 @@ TfLiteStatus QuantizeModel(flatbuffers::FlatBufferBuilder* builder,
|
|||||||
TfLiteStatus QuantizeModel(flatbuffers::FlatBufferBuilder* builder,
|
TfLiteStatus QuantizeModel(flatbuffers::FlatBufferBuilder* builder,
|
||||||
ModelT* input_model, const TensorType& input_type,
|
ModelT* input_model, const TensorType& input_type,
|
||||||
const TensorType& output_type, bool allow_float,
|
const TensorType& output_type, bool allow_float,
|
||||||
|
const TensorType& activations_type,
|
||||||
ErrorReporter* error_reporter);
|
ErrorReporter* error_reporter);
|
||||||
|
|
||||||
// Same as above, but enables only quantizing a whitelist of operations,
|
// Same as above, but enables only quantizing a whitelist of operations,
|
||||||
@ -63,6 +67,7 @@ TfLiteStatus QuantizeModel(flatbuffers::FlatBufferBuilder* builder,
|
|||||||
ModelT* input_model, const TensorType& input_type,
|
ModelT* input_model, const TensorType& input_type,
|
||||||
const TensorType& output_type, bool allow_float,
|
const TensorType& output_type, bool allow_float,
|
||||||
const std::unordered_set<string>& operator_names,
|
const std::unordered_set<string>& operator_names,
|
||||||
|
const TensorType& activations_type,
|
||||||
ErrorReporter* error_reporter);
|
ErrorReporter* error_reporter);
|
||||||
|
|
||||||
} // namespace optimize
|
} // namespace optimize
|
||||||
|
@ -80,28 +80,35 @@ class QuantizeModelTest : public testing::Test {
|
|||||||
internal::FailOnErrorReporter error_reporter_;
|
internal::FailOnErrorReporter error_reporter_;
|
||||||
};
|
};
|
||||||
|
|
||||||
class QuantizeConvModelTest : public QuantizeModelTest {
|
class QuantizeConvModelTest : public QuantizeModelTest,
|
||||||
|
public testing::WithParamInterface<TensorType> {
|
||||||
protected:
|
protected:
|
||||||
QuantizeConvModelTest() {
|
QuantizeConvModelTest() {
|
||||||
|
tensor_type_ = GetParam();
|
||||||
input_model_ = ReadModel(internal::kConvModelWith0Plus10Weights);
|
input_model_ = ReadModel(internal::kConvModelWith0Plus10Weights);
|
||||||
readonly_model_ = input_model_->GetModel();
|
readonly_model_ = input_model_->GetModel();
|
||||||
readonly_model_->UnPackTo(&model_);
|
readonly_model_->UnPackTo(&model_);
|
||||||
}
|
}
|
||||||
|
TensorType tensor_type_;
|
||||||
};
|
};
|
||||||
|
|
||||||
TEST_F(QuantizeConvModelTest, QuantizationSucceeds) {
|
INSTANTIATE_TEST_SUITE_P(QuantizeConvModelTestInst, QuantizeConvModelTest,
|
||||||
auto status = QuantizeModel(&builder_, &model_, TensorType_INT8,
|
testing::ValuesIn({TensorType_INT8,
|
||||||
TensorType_INT8, &error_reporter_);
|
TensorType_INT16}));
|
||||||
|
|
||||||
|
TEST_P(QuantizeConvModelTest, QuantizationSucceeds) {
|
||||||
|
auto status = QuantizeModel(&builder_, &model_, tensor_type_, tensor_type_,
|
||||||
|
tensor_type_, &error_reporter_);
|
||||||
EXPECT_EQ(status, kTfLiteOk);
|
EXPECT_EQ(status, kTfLiteOk);
|
||||||
const uint8_t* buffer = builder_.GetBufferPointer();
|
const uint8_t* buffer = builder_.GetBufferPointer();
|
||||||
const Model* output_model = GetModel(buffer);
|
const Model* output_model = GetModel(buffer);
|
||||||
ASSERT_TRUE(output_model);
|
ASSERT_TRUE(output_model);
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(QuantizeConvModelTest, SkipUnspecifiedLayer) {
|
TEST_P(QuantizeConvModelTest, SkipUnspecifiedLayer) {
|
||||||
auto status =
|
auto status =
|
||||||
QuantizeModel(&builder_, &model_, TensorType_FLOAT32, TensorType_FLOAT32,
|
QuantizeModel(&builder_, &model_, TensorType_FLOAT32, TensorType_FLOAT32,
|
||||||
/*allow_float=*/true, {}, &error_reporter_);
|
/*allow_float=*/true, {}, tensor_type_, &error_reporter_);
|
||||||
EXPECT_EQ(status, kTfLiteOk);
|
EXPECT_EQ(status, kTfLiteOk);
|
||||||
ASSERT_EQ(model_.subgraphs.size(), readonly_model_->subgraphs()->size());
|
ASSERT_EQ(model_.subgraphs.size(), readonly_model_->subgraphs()->size());
|
||||||
// The resulting model should be the same.
|
// The resulting model should be the same.
|
||||||
@ -123,9 +130,9 @@ TEST_F(QuantizeConvModelTest, SkipUnspecifiedLayer) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(QuantizeConvModelTest, TensorShapesAndStructureIsUnchanged) {
|
TEST_P(QuantizeConvModelTest, TensorShapesAndStructureIsUnchanged) {
|
||||||
auto status = QuantizeModel(&builder_, &model_, TensorType_INT8,
|
auto status = QuantizeModel(&builder_, &model_, tensor_type_, tensor_type_,
|
||||||
TensorType_INT8, &error_reporter_);
|
tensor_type_, &error_reporter_);
|
||||||
EXPECT_EQ(status, kTfLiteOk);
|
EXPECT_EQ(status, kTfLiteOk);
|
||||||
ASSERT_EQ(model_.subgraphs.size(), readonly_model_->subgraphs()->size());
|
ASSERT_EQ(model_.subgraphs.size(), readonly_model_->subgraphs()->size());
|
||||||
for (size_t subgraph_idx = 0; subgraph_idx < model_.subgraphs.size();
|
for (size_t subgraph_idx = 0; subgraph_idx < model_.subgraphs.size();
|
||||||
@ -148,9 +155,9 @@ TEST_F(QuantizeConvModelTest, TensorShapesAndStructureIsUnchanged) {
|
|||||||
EXPECT_EQ(model_.operator_codes[0]->version, 3);
|
EXPECT_EQ(model_.operator_codes[0]->version, 3);
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(QuantizeConvModelTest, OperatorsAreUnchanged) {
|
TEST_P(QuantizeConvModelTest, OperatorsAreUnchanged) {
|
||||||
auto status = QuantizeModel(&builder_, &model_, TensorType_INT8,
|
auto status = QuantizeModel(&builder_, &model_, tensor_type_, tensor_type_,
|
||||||
TensorType_INT8, &error_reporter_);
|
tensor_type_, &error_reporter_);
|
||||||
EXPECT_EQ(status, kTfLiteOk);
|
EXPECT_EQ(status, kTfLiteOk);
|
||||||
ASSERT_EQ(model_.operator_codes.size(),
|
ASSERT_EQ(model_.operator_codes.size(),
|
||||||
readonly_model_->operator_codes()->size());
|
readonly_model_->operator_codes()->size());
|
||||||
@ -182,20 +189,28 @@ TEST_F(QuantizeConvModelTest, OperatorsAreUnchanged) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(QuantizeConvModelTest, GraphIsFullyQuantized) {
|
TEST_P(QuantizeConvModelTest, GraphIsFullyQuantized) {
|
||||||
auto status = QuantizeModel(&builder_, &model_, TensorType_INT8,
|
auto status = QuantizeModel(&builder_, &model_, tensor_type_, tensor_type_,
|
||||||
TensorType_INT8, &error_reporter_);
|
tensor_type_, &error_reporter_);
|
||||||
EXPECT_EQ(status, kTfLiteOk);
|
EXPECT_EQ(status, kTfLiteOk);
|
||||||
for (const auto& subgraph : model_.subgraphs) {
|
for (const auto& subgraph : model_.subgraphs) {
|
||||||
for (const auto& tensor : subgraph->tensors) {
|
for (const auto& tensor : subgraph->tensors) {
|
||||||
|
if (tensor_type_ == TensorType_INT8) {
|
||||||
EXPECT_TRUE(tensor->type == TensorType_INT32 ||
|
EXPECT_TRUE(tensor->type == TensorType_INT32 ||
|
||||||
tensor->type == TensorType_INT8);
|
tensor->type == TensorType_INT8);
|
||||||
|
} else if (tensor_type_ == TensorType_INT16) {
|
||||||
|
EXPECT_TRUE(tensor->type == TensorType_INT64 || // bias
|
||||||
|
tensor->type == TensorType_INT8 || // weights
|
||||||
|
tensor->type == TensorType_INT16); // activations
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(QuantizeConvModelTest, FloatInputAndOutput) {
|
TEST_P(QuantizeConvModelTest, FloatInputAndOutput) {
|
||||||
auto status = QuantizeModel(&builder_, &model_, &error_reporter_);
|
auto status =
|
||||||
|
QuantizeModel(&builder_, &model_, TensorType_FLOAT32, TensorType_FLOAT32,
|
||||||
|
tensor_type_, &error_reporter_);
|
||||||
EXPECT_EQ(status, kTfLiteOk);
|
EXPECT_EQ(status, kTfLiteOk);
|
||||||
|
|
||||||
for (int32_t subgraph_idx = 0; subgraph_idx < model_.subgraphs.size();
|
for (int32_t subgraph_idx = 0; subgraph_idx < model_.subgraphs.size();
|
||||||
@ -234,22 +249,33 @@ TEST_F(QuantizeConvModelTest, FloatInputAndOutput) {
|
|||||||
EXPECT_EQ(subgraph->tensors[output_idx]->type, TensorType_FLOAT32);
|
EXPECT_EQ(subgraph->tensors[output_idx]->type, TensorType_FLOAT32);
|
||||||
EXPECT_EQ(subgraph->tensors[output_idx]->name, "output");
|
EXPECT_EQ(subgraph->tensors[output_idx]->name, "output");
|
||||||
// The original input and output has been renamed.
|
// The original input and output has been renamed.
|
||||||
EXPECT_EQ(subgraph->tensors[quant_op->outputs[0]]->name, "input_int8");
|
std::string control_suffix =
|
||||||
EXPECT_EQ(subgraph->tensors[dequant_op->inputs[0]]->name, "output_int8");
|
(tensor_type_ == TensorType_INT16) ? "int16" : "int8";
|
||||||
|
EXPECT_EQ(subgraph->tensors[quant_op->outputs[0]]->name,
|
||||||
|
"input_" + control_suffix);
|
||||||
|
EXPECT_EQ(subgraph->tensors[dequant_op->inputs[0]]->name,
|
||||||
|
"output_" + control_suffix);
|
||||||
for (int tensor_idx = 0; tensor_idx < subgraph->tensors.size();
|
for (int tensor_idx = 0; tensor_idx < subgraph->tensors.size();
|
||||||
++tensor_idx) {
|
++tensor_idx) {
|
||||||
const auto& tensor = subgraph->tensors[tensor_idx];
|
const auto& tensor = subgraph->tensors[tensor_idx];
|
||||||
if (input_idx != tensor_idx && output_idx != tensor_idx) {
|
if (input_idx != tensor_idx && output_idx != tensor_idx) {
|
||||||
|
if (tensor_type_ == TensorType_INT8) {
|
||||||
EXPECT_TRUE(tensor->type == TensorType_INT32 ||
|
EXPECT_TRUE(tensor->type == TensorType_INT32 ||
|
||||||
tensor->type == TensorType_INT8);
|
tensor->type == TensorType_INT8);
|
||||||
|
} else if (tensor_type_ == TensorType_INT16) {
|
||||||
|
EXPECT_TRUE(tensor->type == TensorType_INT64 || // bias
|
||||||
|
tensor->type == TensorType_INT8 || // weights
|
||||||
|
tensor->type == TensorType_INT16); // activations
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(QuantizeConvModelTest, Uint8InputAndOutput) {
|
TEST_P(QuantizeConvModelTest, Uint8InputAndOutput) {
|
||||||
auto status = QuantizeModel(&builder_, &model_, TensorType_UINT8,
|
auto status =
|
||||||
TensorType_UINT8, &error_reporter_);
|
QuantizeModel(&builder_, &model_, TensorType_UINT8, TensorType_UINT8,
|
||||||
|
TensorType_INT8, &error_reporter_);
|
||||||
EXPECT_EQ(status, kTfLiteOk);
|
EXPECT_EQ(status, kTfLiteOk);
|
||||||
|
|
||||||
for (int32_t subgraph_idx = 0; subgraph_idx < model_.subgraphs.size();
|
for (int32_t subgraph_idx = 0; subgraph_idx < model_.subgraphs.size();
|
||||||
@ -326,7 +352,8 @@ class QuantizeConvNoBiasModelTest : public QuantizeModelTest {
|
|||||||
};
|
};
|
||||||
|
|
||||||
TEST_F(QuantizeConvNoBiasModelTest, QuantizationSucceeds) {
|
TEST_F(QuantizeConvNoBiasModelTest, QuantizationSucceeds) {
|
||||||
auto status = QuantizeModel(&builder_, &model_, TensorType_INT8,
|
auto status =
|
||||||
|
QuantizeModel(&builder_, &model_, TensorType_INT8, TensorType_INT8,
|
||||||
TensorType_INT8, &error_reporter_);
|
TensorType_INT8, &error_reporter_);
|
||||||
EXPECT_EQ(status, kTfLiteOk);
|
EXPECT_EQ(status, kTfLiteOk);
|
||||||
const uint8_t* buffer = builder_.GetBufferPointer();
|
const uint8_t* buffer = builder_.GetBufferPointer();
|
||||||
@ -334,13 +361,16 @@ TEST_F(QuantizeConvNoBiasModelTest, QuantizationSucceeds) {
|
|||||||
ASSERT_TRUE(output_model);
|
ASSERT_TRUE(output_model);
|
||||||
}
|
}
|
||||||
|
|
||||||
class QuantizeConcatModelTest : public QuantizeModelTest {
|
class QuantizeConcatModelTest : public QuantizeModelTest,
|
||||||
|
public testing::WithParamInterface<TensorType> {
|
||||||
protected:
|
protected:
|
||||||
QuantizeConcatModelTest() {
|
QuantizeConcatModelTest() {
|
||||||
input_model_ = ReadModel(internal::kFloatConcatMax5Max10Max10);
|
input_model_ = ReadModel(internal::kFloatConcatMax5Max10Max10);
|
||||||
readonly_model_ = input_model_->GetModel();
|
readonly_model_ = input_model_->GetModel();
|
||||||
readonly_model_->UnPackTo(&model_);
|
readonly_model_->UnPackTo(&model_);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TensorType tensor_type_;
|
||||||
};
|
};
|
||||||
|
|
||||||
// There are two inputs for concat, "input0" and "input1". "input0" has [0, 5]
|
// There are two inputs for concat, "input0" and "input1". "input0" has [0, 5]
|
||||||
@ -352,9 +382,9 @@ class QuantizeConcatModelTest : public QuantizeModelTest {
|
|||||||
// input0 -> requant -> input0_requant \
|
// input0 -> requant -> input0_requant \
|
||||||
// concat - output
|
// concat - output
|
||||||
// input1 /
|
// input1 /
|
||||||
TEST_F(QuantizeConcatModelTest, AddRequantBeforeConcat) {
|
TEST_P(QuantizeConcatModelTest, AddRequantBeforeConcat) {
|
||||||
auto status = QuantizeModel(&builder_, &model_, TensorType_INT8,
|
auto status = QuantizeModel(&builder_, &model_, tensor_type_, tensor_type_,
|
||||||
TensorType_INT8, &error_reporter_);
|
tensor_type_, &error_reporter_);
|
||||||
EXPECT_EQ(status, kTfLiteOk);
|
EXPECT_EQ(status, kTfLiteOk);
|
||||||
|
|
||||||
// There is only one subgraph.
|
// There is only one subgraph.
|
||||||
@ -373,32 +403,51 @@ TEST_F(QuantizeConcatModelTest, AddRequantBeforeConcat) {
|
|||||||
EXPECT_EQ(model_.operator_codes[concat->opcode_index]->builtin_code,
|
EXPECT_EQ(model_.operator_codes[concat->opcode_index]->builtin_code,
|
||||||
BuiltinOperator_CONCATENATION);
|
BuiltinOperator_CONCATENATION);
|
||||||
|
|
||||||
|
auto zero_point_control = tensor_type_ == TensorType_INT8 ? -128 : 0;
|
||||||
|
/*
|
||||||
|
input0_scale_control
|
||||||
|
INT8: (5-0) / (2^8 - 1)
|
||||||
|
INT16: (5-0) / (2^16 / 2 - 1)
|
||||||
|
input1_scale
|
||||||
|
INT8: (10-0) / (2^8 - 1)
|
||||||
|
INT16: (10-0) / (2^16 / 2 - 1)
|
||||||
|
*/
|
||||||
|
auto input0_scale_control =
|
||||||
|
tensor_type_ == TensorType_INT8 ? 0.019607844 : 0.00015259254;
|
||||||
|
auto input1_scale =
|
||||||
|
tensor_type_ == TensorType_INT8 ? 0.039215688 : 0.00030518509;
|
||||||
|
|
||||||
// There should be 4 tensors: input0, input1, input0_requantized, output.
|
// There should be 4 tensors: input0, input1, input0_requantized, output.
|
||||||
EXPECT_EQ(subgraph->tensors.size(), 4);
|
EXPECT_EQ(subgraph->tensors.size(), 4);
|
||||||
EXPECT_EQ(subgraph->tensors[0]->type, TensorType_INT8);
|
EXPECT_EQ(subgraph->tensors[0]->type, tensor_type_);
|
||||||
EXPECT_EQ(subgraph->tensors[0]->name, "input0");
|
EXPECT_EQ(subgraph->tensors[0]->name, "input0");
|
||||||
EXPECT_EQ(subgraph->tensors[0]->quantization->scale.size(), 1);
|
EXPECT_EQ(subgraph->tensors[0]->quantization->scale.size(), 1);
|
||||||
EXPECT_EQ(subgraph->tensors[0]->quantization->zero_point.size(), 1);
|
EXPECT_EQ(subgraph->tensors[0]->quantization->zero_point.size(), 1);
|
||||||
EXPECT_FLOAT_EQ(subgraph->tensors[0]->quantization->scale[0], 0.019607844);
|
EXPECT_FLOAT_EQ(subgraph->tensors[0]->quantization->scale[0],
|
||||||
EXPECT_FLOAT_EQ(subgraph->tensors[0]->quantization->zero_point[0], -128);
|
input0_scale_control);
|
||||||
EXPECT_EQ(subgraph->tensors[1]->type, TensorType_INT8);
|
EXPECT_FLOAT_EQ(subgraph->tensors[0]->quantization->zero_point[0],
|
||||||
|
zero_point_control);
|
||||||
|
EXPECT_EQ(subgraph->tensors[1]->type, tensor_type_);
|
||||||
EXPECT_EQ(subgraph->tensors[1]->name, "input1");
|
EXPECT_EQ(subgraph->tensors[1]->name, "input1");
|
||||||
EXPECT_EQ(subgraph->tensors[1]->quantization->scale.size(), 1);
|
EXPECT_EQ(subgraph->tensors[1]->quantization->scale.size(), 1);
|
||||||
EXPECT_EQ(subgraph->tensors[1]->quantization->zero_point.size(), 1);
|
EXPECT_EQ(subgraph->tensors[1]->quantization->zero_point.size(), 1);
|
||||||
EXPECT_FLOAT_EQ(subgraph->tensors[1]->quantization->scale[0], 0.039215688);
|
EXPECT_FLOAT_EQ(subgraph->tensors[1]->quantization->scale[0], input1_scale);
|
||||||
EXPECT_FLOAT_EQ(subgraph->tensors[1]->quantization->zero_point[0], -128);
|
EXPECT_FLOAT_EQ(subgraph->tensors[1]->quantization->zero_point[0],
|
||||||
EXPECT_EQ(subgraph->tensors[2]->type, TensorType_INT8);
|
zero_point_control);
|
||||||
|
EXPECT_EQ(subgraph->tensors[2]->type, tensor_type_);
|
||||||
EXPECT_EQ(subgraph->tensors[2]->name, "output");
|
EXPECT_EQ(subgraph->tensors[2]->name, "output");
|
||||||
EXPECT_EQ(subgraph->tensors[2]->quantization->scale.size(), 1);
|
EXPECT_EQ(subgraph->tensors[2]->quantization->scale.size(), 1);
|
||||||
EXPECT_EQ(subgraph->tensors[2]->quantization->zero_point.size(), 1);
|
EXPECT_EQ(subgraph->tensors[2]->quantization->zero_point.size(), 1);
|
||||||
EXPECT_FLOAT_EQ(subgraph->tensors[2]->quantization->scale[0], 0.039215688);
|
EXPECT_FLOAT_EQ(subgraph->tensors[2]->quantization->scale[0], input1_scale);
|
||||||
EXPECT_FLOAT_EQ(subgraph->tensors[2]->quantization->zero_point[0], -128);
|
EXPECT_FLOAT_EQ(subgraph->tensors[2]->quantization->zero_point[0],
|
||||||
EXPECT_EQ(subgraph->tensors[3]->type, TensorType_INT8);
|
zero_point_control);
|
||||||
|
EXPECT_EQ(subgraph->tensors[3]->type, tensor_type_);
|
||||||
EXPECT_EQ(subgraph->tensors[3]->name, "input0_requantized");
|
EXPECT_EQ(subgraph->tensors[3]->name, "input0_requantized");
|
||||||
EXPECT_EQ(subgraph->tensors[3]->quantization->scale.size(), 1);
|
EXPECT_EQ(subgraph->tensors[3]->quantization->scale.size(), 1);
|
||||||
EXPECT_EQ(subgraph->tensors[3]->quantization->zero_point.size(), 1);
|
EXPECT_EQ(subgraph->tensors[3]->quantization->zero_point.size(), 1);
|
||||||
EXPECT_FLOAT_EQ(subgraph->tensors[3]->quantization->scale[0], 0.039215688);
|
EXPECT_FLOAT_EQ(subgraph->tensors[3]->quantization->scale[0], input1_scale);
|
||||||
EXPECT_FLOAT_EQ(subgraph->tensors[3]->quantization->zero_point[0], -128);
|
EXPECT_FLOAT_EQ(subgraph->tensors[3]->quantization->zero_point[0],
|
||||||
|
zero_point_control);
|
||||||
|
|
||||||
// The connection should be what is described in the comment.
|
// The connection should be what is described in the comment.
|
||||||
EXPECT_EQ(requant->inputs.size(), 1);
|
EXPECT_EQ(requant->inputs.size(), 1);
|
||||||
@ -419,7 +468,9 @@ TEST_F(QuantizeConcatModelTest, AddRequantBeforeConcat) {
|
|||||||
EXPECT_EQ(model_.operator_codes[1]->builtin_code, BuiltinOperator_QUANTIZE);
|
EXPECT_EQ(model_.operator_codes[1]->builtin_code, BuiltinOperator_QUANTIZE);
|
||||||
EXPECT_EQ(model_.operator_codes[1]->version, 2);
|
EXPECT_EQ(model_.operator_codes[1]->version, 2);
|
||||||
}
|
}
|
||||||
|
INSTANTIATE_TEST_SUITE_P(QuantizeConcatModelInst, QuantizeConcatModelTest,
|
||||||
|
testing::ValuesIn({TensorType_INT8,
|
||||||
|
TensorType_INT16}));
|
||||||
class QuantizeSplitModelTest : public QuantizeModelTest {
|
class QuantizeSplitModelTest : public QuantizeModelTest {
|
||||||
protected:
|
protected:
|
||||||
QuantizeSplitModelTest() {
|
QuantizeSplitModelTest() {
|
||||||
@ -432,7 +483,8 @@ class QuantizeSplitModelTest : public QuantizeModelTest {
|
|||||||
// There are two outputs for split with different scales, the resulting model
|
// There are two outputs for split with different scales, the resulting model
|
||||||
// should have the scales be hardcodes to the input scale value.
|
// should have the scales be hardcodes to the input scale value.
|
||||||
TEST_F(QuantizeSplitModelTest, QuantizeSplit) {
|
TEST_F(QuantizeSplitModelTest, QuantizeSplit) {
|
||||||
auto status = QuantizeModel(&builder_, &model_, TensorType_INT8,
|
auto status =
|
||||||
|
QuantizeModel(&builder_, &model_, TensorType_INT8, TensorType_INT8,
|
||||||
TensorType_INT8, &error_reporter_);
|
TensorType_INT8, &error_reporter_);
|
||||||
EXPECT_EQ(status, kTfLiteOk);
|
EXPECT_EQ(status, kTfLiteOk);
|
||||||
|
|
||||||
@ -496,7 +548,8 @@ class QuantizeConvModel1Test : public QuantizeModelTest {
|
|||||||
};
|
};
|
||||||
|
|
||||||
TEST_F(QuantizeConvModel1Test, VerifyConvQuantizationWithUnitScale) {
|
TEST_F(QuantizeConvModel1Test, VerifyConvQuantizationWithUnitScale) {
|
||||||
auto status = QuantizeModel(&builder_, &model_, TensorType_INT8,
|
auto status =
|
||||||
|
QuantizeModel(&builder_, &model_, TensorType_INT8, TensorType_INT8,
|
||||||
TensorType_INT8, &error_reporter_);
|
TensorType_INT8, &error_reporter_);
|
||||||
EXPECT_EQ(status, kTfLiteOk);
|
EXPECT_EQ(status, kTfLiteOk);
|
||||||
const auto& subgraph = model_.subgraphs[0];
|
const auto& subgraph = model_.subgraphs[0];
|
||||||
@ -587,18 +640,25 @@ TEST_F(QuantizeConvModel1Test, VerifyConvQuantizationWithUnitScale) {
|
|||||||
EXPECT_EQ(model_.operator_codes[0]->version, 3);
|
EXPECT_EQ(model_.operator_codes[0]->version, 3);
|
||||||
}
|
}
|
||||||
|
|
||||||
class QuantizeConvModel2Test : public QuantizeModelTest {
|
class QuantizeConvModel2Test : public QuantizeModelTest,
|
||||||
|
public testing::WithParamInterface<TensorType> {
|
||||||
protected:
|
protected:
|
||||||
QuantizeConvModel2Test() {
|
QuantizeConvModel2Test() {
|
||||||
|
tensor_type_ = GetParam();
|
||||||
input_model_ = ReadModel(internal::kConvModelWith0Plus10Weights);
|
input_model_ = ReadModel(internal::kConvModelWith0Plus10Weights);
|
||||||
readonly_model_ = input_model_->GetModel();
|
readonly_model_ = input_model_->GetModel();
|
||||||
readonly_model_->UnPackTo(&model_);
|
readonly_model_->UnPackTo(&model_);
|
||||||
}
|
}
|
||||||
};
|
|
||||||
|
|
||||||
TEST_F(QuantizeConvModel2Test, VerifyConvQuantization) {
|
TensorType tensor_type_;
|
||||||
auto status = QuantizeModel(&builder_, &model_, TensorType_INT8,
|
};
|
||||||
TensorType_INT8, &error_reporter_);
|
INSTANTIATE_TEST_SUITE_P(QuantizeConvModel2TestInst, QuantizeConvModel2Test,
|
||||||
|
testing::ValuesIn({TensorType_INT8,
|
||||||
|
TensorType_INT16}));
|
||||||
|
|
||||||
|
TEST_P(QuantizeConvModel2Test, VerifyConvQuantization) {
|
||||||
|
auto status = QuantizeModel(&builder_, &model_, tensor_type_, tensor_type_,
|
||||||
|
tensor_type_, &error_reporter_);
|
||||||
ASSERT_EQ(kTfLiteOk, status);
|
ASSERT_EQ(kTfLiteOk, status);
|
||||||
const auto& subgraph = model_.subgraphs[0];
|
const auto& subgraph = model_.subgraphs[0];
|
||||||
auto conv_op = subgraph->operators[0].get();
|
auto conv_op = subgraph->operators[0].get();
|
||||||
@ -615,8 +675,10 @@ TEST_F(QuantizeConvModel2Test, VerifyConvQuantization) {
|
|||||||
const auto output_tensor =
|
const auto output_tensor =
|
||||||
subgraph->tensors[conv_op->outputs[output_tensor_idx]].get();
|
subgraph->tensors[conv_op->outputs[output_tensor_idx]].get();
|
||||||
|
|
||||||
EXPECT_EQ(bias_tensor->type, TensorType_INT32);
|
EXPECT_EQ(bias_tensor->type, tensor_type_ == TensorType_INT8
|
||||||
EXPECT_EQ(input_tensor->type, TensorType_INT8);
|
? TensorType_INT32
|
||||||
|
: TensorType_INT64);
|
||||||
|
EXPECT_EQ(input_tensor->type, tensor_type_);
|
||||||
EXPECT_EQ(weights_tensor->type, TensorType_INT8);
|
EXPECT_EQ(weights_tensor->type, TensorType_INT8);
|
||||||
|
|
||||||
ASSERT_TRUE(weights_tensor->quantization);
|
ASSERT_TRUE(weights_tensor->quantization);
|
||||||
@ -644,18 +706,29 @@ TEST_F(QuantizeConvModel2Test, VerifyConvQuantization) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
const auto bias_buffer = model_.buffers[bias_tensor->buffer].get();
|
const auto bias_buffer = model_.buffers[bias_tensor->buffer].get();
|
||||||
ASSERT_EQ(bias_buffer->data.size(), sizeof(int32_t) * bias_tensor->shape[0]);
|
auto control_size = tensor_type_ == TensorType_INT8
|
||||||
const int32_t* bias_values =
|
? sizeof(int32_t) * bias_tensor->shape[0]
|
||||||
reinterpret_cast<int32_t*>(bias_buffer->data.data());
|
: sizeof(int64_t) * bias_tensor->shape[0];
|
||||||
|
|
||||||
|
ASSERT_EQ(bias_buffer->data.size(), control_size);
|
||||||
const auto original_bias_buffer =
|
const auto original_bias_buffer =
|
||||||
readonly_model_->buffers()->Get(bias_tensor->buffer);
|
readonly_model_->buffers()->Get(bias_tensor->buffer);
|
||||||
const float* bias_float_buffer =
|
const float* bias_float_buffer =
|
||||||
reinterpret_cast<const float*>(original_bias_buffer->data()->data());
|
reinterpret_cast<const float*>(original_bias_buffer->data()->data());
|
||||||
|
|
||||||
|
if (tensor_type_ == TensorType_INT8) {
|
||||||
|
int32_t* bias_values = reinterpret_cast<int32_t*>(bias_buffer->data.data());
|
||||||
for (size_t i = 0; i < out_channel_size; i++) {
|
for (size_t i = 0; i < out_channel_size; i++) {
|
||||||
auto dequantized_value = bias_values[i] * bias_scales[i];
|
auto dequantized_value = bias_values[i] * bias_scales[i];
|
||||||
EXPECT_NEAR(dequantized_value, bias_float_buffer[i], bias_scales[i] / 2);
|
EXPECT_NEAR(dequantized_value, bias_float_buffer[i], bias_scales[i] / 2);
|
||||||
}
|
}
|
||||||
|
} else if (tensor_type_ == TensorType_INT16) {
|
||||||
|
int64_t* bias_values = reinterpret_cast<int64_t*>(bias_buffer->data.data());
|
||||||
|
for (size_t i = 0; i < out_channel_size; i++) {
|
||||||
|
auto dequantized_value = bias_values[i] * bias_scales[i];
|
||||||
|
EXPECT_NEAR(dequantized_value, bias_float_buffer[i], bias_scales[i] / 2);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
const auto weights_buffer = model_.buffers[weights_tensor->buffer].get();
|
const auto weights_buffer = model_.buffers[weights_tensor->buffer].get();
|
||||||
const auto original_weights_buffer =
|
const auto original_weights_buffer =
|
||||||
@ -695,7 +768,8 @@ class QuantizeSoftmaxTest : public QuantizeModelTest {
|
|||||||
};
|
};
|
||||||
|
|
||||||
TEST_F(QuantizeSoftmaxTest, VerifySoftmaxQuantization) {
|
TEST_F(QuantizeSoftmaxTest, VerifySoftmaxQuantization) {
|
||||||
auto status = QuantizeModel(&builder_, &model_, TensorType_INT8,
|
auto status =
|
||||||
|
QuantizeModel(&builder_, &model_, TensorType_INT8, TensorType_INT8,
|
||||||
TensorType_INT8, &error_reporter_);
|
TensorType_INT8, &error_reporter_);
|
||||||
ASSERT_EQ(kTfLiteOk, status);
|
ASSERT_EQ(kTfLiteOk, status);
|
||||||
|
|
||||||
@ -755,7 +829,8 @@ class QuantizeAvgPoolTest : public QuantizeModelTest {
|
|||||||
};
|
};
|
||||||
|
|
||||||
TEST_F(QuantizeAvgPoolTest, VerifyAvgPoolQuantization) {
|
TEST_F(QuantizeAvgPoolTest, VerifyAvgPoolQuantization) {
|
||||||
auto status = QuantizeModel(&builder_, &model_, TensorType_INT8,
|
auto status =
|
||||||
|
QuantizeModel(&builder_, &model_, TensorType_INT8, TensorType_INT8,
|
||||||
TensorType_INT8, &error_reporter_);
|
TensorType_INT8, &error_reporter_);
|
||||||
ASSERT_EQ(kTfLiteOk, status);
|
ASSERT_EQ(kTfLiteOk, status);
|
||||||
|
|
||||||
@ -816,7 +891,8 @@ class QuantizeMultiInputAddWithReshapeTest : public QuantizeModelTest {
|
|||||||
};
|
};
|
||||||
|
|
||||||
TEST_F(QuantizeMultiInputAddWithReshapeTest, VerifyReshapeQuantization) {
|
TEST_F(QuantizeMultiInputAddWithReshapeTest, VerifyReshapeQuantization) {
|
||||||
auto status = QuantizeModel(&builder_, &model_, TensorType_INT8,
|
auto status =
|
||||||
|
QuantizeModel(&builder_, &model_, TensorType_INT8, TensorType_INT8,
|
||||||
TensorType_INT8, &error_reporter_);
|
TensorType_INT8, &error_reporter_);
|
||||||
ASSERT_EQ(kTfLiteOk, status);
|
ASSERT_EQ(kTfLiteOk, status);
|
||||||
|
|
||||||
@ -863,7 +939,8 @@ TEST_F(QuantizeMultiInputAddWithReshapeTest, VerifyReshapeQuantization) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(QuantizeMultiInputAddWithReshapeTest, VerifyAddQuantization) {
|
TEST_F(QuantizeMultiInputAddWithReshapeTest, VerifyAddQuantization) {
|
||||||
auto status = QuantizeModel(&builder_, &model_, TensorType_INT8,
|
auto status =
|
||||||
|
QuantizeModel(&builder_, &model_, TensorType_INT8, TensorType_INT8,
|
||||||
TensorType_INT8, &error_reporter_);
|
TensorType_INT8, &error_reporter_);
|
||||||
ASSERT_EQ(kTfLiteOk, status);
|
ASSERT_EQ(kTfLiteOk, status);
|
||||||
|
|
||||||
@ -923,7 +1000,8 @@ class QuantizeConstInputTest : public QuantizeModelTest {
|
|||||||
};
|
};
|
||||||
|
|
||||||
TEST_F(QuantizeConstInputTest, VerifyConstOpInput) {
|
TEST_F(QuantizeConstInputTest, VerifyConstOpInput) {
|
||||||
auto status = QuantizeModel(&builder_, &model_, TensorType_INT8,
|
auto status =
|
||||||
|
QuantizeModel(&builder_, &model_, TensorType_INT8, TensorType_INT8,
|
||||||
TensorType_INT8, &error_reporter_);
|
TensorType_INT8, &error_reporter_);
|
||||||
ASSERT_EQ(kTfLiteOk, status);
|
ASSERT_EQ(kTfLiteOk, status);
|
||||||
|
|
||||||
@ -965,7 +1043,8 @@ class QuantizeArgMaxTest : public QuantizeModelTest {
|
|||||||
};
|
};
|
||||||
|
|
||||||
TEST_F(QuantizeArgMaxTest, VerifyArgMax) {
|
TEST_F(QuantizeArgMaxTest, VerifyArgMax) {
|
||||||
auto status = QuantizeModel(&builder_, &model_, TensorType_INT8,
|
auto status =
|
||||||
|
QuantizeModel(&builder_, &model_, TensorType_INT8, TensorType_INT8,
|
||||||
TensorType_INT8, &error_reporter_);
|
TensorType_INT8, &error_reporter_);
|
||||||
ASSERT_EQ(kTfLiteOk, status);
|
ASSERT_EQ(kTfLiteOk, status);
|
||||||
|
|
||||||
@ -1008,8 +1087,9 @@ class QuantizeLSTMTest : public QuantizeModelTest {
|
|||||||
|
|
||||||
TEST_F(QuantizeLSTMTest, VerifyLSTM) {
|
TEST_F(QuantizeLSTMTest, VerifyLSTM) {
|
||||||
// Quantize model.
|
// Quantize model.
|
||||||
auto status = QuantizeModel(&builder_, &model_, TensorType_FLOAT32,
|
auto status =
|
||||||
TensorType_FLOAT32, &error_reporter_);
|
QuantizeModel(&builder_, &model_, TensorType_FLOAT32, TensorType_FLOAT32,
|
||||||
|
TensorType_INT8, &error_reporter_);
|
||||||
ASSERT_EQ(kTfLiteOk, status);
|
ASSERT_EQ(kTfLiteOk, status);
|
||||||
|
|
||||||
// Read expected model.
|
// Read expected model.
|
||||||
@ -1067,8 +1147,9 @@ class QuantizeLSTM2Test : public QuantizeModelTest {
|
|||||||
|
|
||||||
TEST_F(QuantizeLSTM2Test, VerifyLSTM) {
|
TEST_F(QuantizeLSTM2Test, VerifyLSTM) {
|
||||||
// Quantize model.
|
// Quantize model.
|
||||||
auto status = QuantizeModel(&builder_, &model_, TensorType_FLOAT32,
|
auto status =
|
||||||
TensorType_FLOAT32, &error_reporter_);
|
QuantizeModel(&builder_, &model_, TensorType_FLOAT32, TensorType_FLOAT32,
|
||||||
|
TensorType_INT8, &error_reporter_);
|
||||||
ASSERT_EQ(kTfLiteOk, status);
|
ASSERT_EQ(kTfLiteOk, status);
|
||||||
|
|
||||||
// Read expected model.
|
// Read expected model.
|
||||||
@ -1126,7 +1207,8 @@ class QuantizeSVDFTest : public QuantizeModelTest {
|
|||||||
|
|
||||||
TEST_F(QuantizeSVDFTest, VerifySVDF) {
|
TEST_F(QuantizeSVDFTest, VerifySVDF) {
|
||||||
// Quantize model.
|
// Quantize model.
|
||||||
auto status = QuantizeModel(&builder_, &model_, TensorType_INT8,
|
auto status =
|
||||||
|
QuantizeModel(&builder_, &model_, TensorType_INT8, TensorType_INT8,
|
||||||
TensorType_INT8, &error_reporter_);
|
TensorType_INT8, &error_reporter_);
|
||||||
ASSERT_EQ(kTfLiteOk, status);
|
ASSERT_EQ(kTfLiteOk, status);
|
||||||
|
|
||||||
@ -1184,7 +1266,8 @@ class QuantizeFCTest : public QuantizeModelTest {
|
|||||||
};
|
};
|
||||||
|
|
||||||
TEST_F(QuantizeFCTest, VerifyFC) {
|
TEST_F(QuantizeFCTest, VerifyFC) {
|
||||||
auto status = QuantizeModel(&builder_, &model_, TensorType_INT8,
|
auto status =
|
||||||
|
QuantizeModel(&builder_, &model_, TensorType_INT8, TensorType_INT8,
|
||||||
TensorType_INT8, &error_reporter_);
|
TensorType_INT8, &error_reporter_);
|
||||||
ASSERT_EQ(kTfLiteOk, status);
|
ASSERT_EQ(kTfLiteOk, status);
|
||||||
|
|
||||||
@ -1236,7 +1319,7 @@ class QuantizeCustomOpTest : public QuantizeModelTest {
|
|||||||
TEST_F(QuantizeCustomOpTest, VerifyMixedQuantization) {
|
TEST_F(QuantizeCustomOpTest, VerifyMixedQuantization) {
|
||||||
auto status =
|
auto status =
|
||||||
QuantizeModel(&builder_, &model_, TensorType_INT8, TensorType_INT8,
|
QuantizeModel(&builder_, &model_, TensorType_INT8, TensorType_INT8,
|
||||||
/*allow_float=*/true, &error_reporter_);
|
/*allow_float=*/true, TensorType_INT8, &error_reporter_);
|
||||||
ASSERT_EQ(kTfLiteOk, status);
|
ASSERT_EQ(kTfLiteOk, status);
|
||||||
const auto& subgraph = model_.subgraphs[0];
|
const auto& subgraph = model_.subgraphs[0];
|
||||||
auto float_graph = readonly_model_->subgraphs()->Get(0);
|
auto float_graph = readonly_model_->subgraphs()->Get(0);
|
||||||
@ -1270,7 +1353,8 @@ class QuantizePackTest : public QuantizeModelTest {
|
|||||||
};
|
};
|
||||||
|
|
||||||
TEST_F(QuantizePackTest, VerifyPack) {
|
TEST_F(QuantizePackTest, VerifyPack) {
|
||||||
auto status = QuantizeModel(&builder_, &model_, &error_reporter_);
|
auto status =
|
||||||
|
QuantizeModel(&builder_, &model_, TensorType_INT8, &error_reporter_);
|
||||||
|
|
||||||
ASSERT_EQ(kTfLiteOk, status);
|
ASSERT_EQ(kTfLiteOk, status);
|
||||||
|
|
||||||
@ -1334,7 +1418,8 @@ class QuantizeMinimumMaximumTest
|
|||||||
};
|
};
|
||||||
|
|
||||||
TEST_P(QuantizeMinimumMaximumTest, VerifyMinimumMaximum) {
|
TEST_P(QuantizeMinimumMaximumTest, VerifyMinimumMaximum) {
|
||||||
auto status = QuantizeModel(&builder_, &model_, &error_reporter_);
|
auto status =
|
||||||
|
QuantizeModel(&builder_, &model_, TensorType_INT8, &error_reporter_);
|
||||||
ASSERT_EQ(kTfLiteOk, status);
|
ASSERT_EQ(kTfLiteOk, status);
|
||||||
const auto& subgraph = model_.subgraphs[0];
|
const auto& subgraph = model_.subgraphs[0];
|
||||||
|
|
||||||
@ -1415,7 +1500,8 @@ class QuantizeUnpackTest : public QuantizeModelTest {
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
TEST_F(QuantizeUnpackTest, VerifyUnpack) {
|
TEST_F(QuantizeUnpackTest, VerifyUnpack) {
|
||||||
auto status = QuantizeModel(&builder_, &model_, &error_reporter_);
|
auto status =
|
||||||
|
QuantizeModel(&builder_, &model_, TensorType_INT8, &error_reporter_);
|
||||||
|
|
||||||
ASSERT_EQ(kTfLiteOk, status);
|
ASSERT_EQ(kTfLiteOk, status);
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user