Added an option TFLITE_BUILTINS_ACTIVATIONS_INT16_WEIGHTS_INT8 to
enable symmetric quantization with 16-bit activations and 8-bit weights.
Elena Zhelezina 2020-01-21 13:18:55 +00:00
parent 2e98e89091
commit a7899d7544
17 changed files with 477 additions and 202 deletions
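
For context, a minimal usage sketch of the new mode through the public converter API, patterned on the parameterized converter test in this commit; the saved-model path and the dataset generator are placeholders, not part of the change:

import tensorflow as tf

# Placeholder: any float model source works (from_saved_model, from_session, ...).
converter = tf.lite.TFLiteConverter.from_saved_model('/tmp/my_model')

# Representative data drives calibration of the activation ranges.
def representative_dataset_gen():
  for _ in range(100):
    yield [my_input_batch()]  # placeholder generator of float32 inputs

converter.representative_dataset = representative_dataset_gen
# Restrict conversion to the 16x8 mode added by this commit.
converter.target_spec.supported_ops = [
    tf.lite.OpsSet.TFLITE_BUILTINS_ACTIVATIONS_INT16_WEIGHTS_INT8
]
tflite_model = converter.convert()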


@@ -93,6 +93,12 @@ class OpsSet(enum.Enum):
   # quantized implementations.
   TFLITE_BUILTINS_INT8 = "TFLITE_BUILTINS_INT8"
+  # Convert model using only TensorFlow Lite operations with quantized int8 weights
+  # and int16 activations.
+  # Specifying this will throw an error for operations that do not yet have
+  # quantized implementations.
+  TFLITE_BUILTINS_ACTIVATIONS_INT16_WEIGHTS_INT8 = "TFLITE_BUILTINS_ACTIVATIONS_INT16_WEIGHTS_INT8"

   def __str__(self):
     return self.value


@@ -224,6 +224,10 @@ class TFLiteConverterBase(object):
             self.target_spec.supported_ops) or
             self._smallest_supported_type() == constants.INT8)

+  def _is_int16x8_target_required(self):
+    return (set([OpsSet.TFLITE_BUILTINS_ACTIVATIONS_INT16_WEIGHTS_INT8]) ==
+            set(self.target_spec.supported_ops))
+
   def _smallest_supported_type(self):
     if self.target_spec.supported_types:
       return min(self.target_spec.supported_types, key=lambda x: x.size)
@@ -238,7 +242,9 @@ class TFLiteConverterBase(object):
         ]))

   def _is_post_training_optimize(self):
-    return self._is_int8_target_required() or self._any_optimization_enabled()
+    return self._is_int8_target_required() or \
+        self._is_int16x8_target_required() or \
+        self._any_optimization_enabled()

   def _is_int8_weight_only_quantize(self):
     return (self._is_post_training_optimize() and
@@ -255,11 +261,12 @@ class TFLiteConverterBase(object):
   def _calibrate_quantize_model(self, result, inference_input_type,
                                 inference_output_type, enable_mlir_quantizer):
-    allow_float = not self._is_int8_target_required()
+    allow_float = not self._is_int8_target_required() and not self._is_int16x8_target_required()
     calibrate_quantize = _calibrator.Calibrator(result)
+    activations_type = constants.INT16 if self._is_int16x8_target_required() else constants.INT8
     return calibrate_quantize.calibrate_and_quantize(
         self.representative_dataset.input_gen, inference_input_type,
-        inference_output_type, allow_float, enable_mlir_quantizer)
+        inference_output_type, allow_float, activations_type, enable_mlir_quantizer)

   def _get_base_converter_args(self):
     """Returns the base converter args.


@@ -30,6 +30,7 @@ INT64 = dtypes.int64
 STRING = dtypes.string
 QUANTIZED_UINT8 = dtypes.uint8
 INT8 = dtypes.int8
+INT16 = dtypes.int16
 COMPLEX64 = dtypes.complex64
 TENSORFLOW_GRAPHDEF = _toco_flags_pb2.TENSORFLOW_GRAPHDEF
 TFLITE = _toco_flags_pb2.TFLITE
@@ -43,6 +44,7 @@ _tf_export(v1=["lite.constants.STRING"]).export_constant(__name__, "STRING")
 _tf_export(v1=["lite.constants.QUANTIZED_UINT8"]).export_constant(
     __name__, "QUANTIZED_UINT8")
 _tf_export(v1=["lite.constants.INT8"]).export_constant(__name__, "INT8")
+_tf_export(v1=["lite.constants.INT16"]).export_constant(__name__, "INT16")
 _tf_export(v1=["lite.constants.TFLITE"]).export_constant(__name__, "TFLITE")
 _tf_export(v1=["lite.constants.GRAPHVIZ_DOT"]).export_constant(
     __name__, "GRAPHVIZ_DOT")
@@ -62,6 +64,7 @@ _allowed_symbols = [
     "STRING",
     "QUANTIZED_UINT8",
     "INT8",
+    "INT16",
     "COMPLEX64",
     "TENSORFLOW_GRAPHDEF",
     "TFLITE",


@@ -769,9 +769,13 @@ class FromSessionTest(TestModels, parameterized.TestCase):
     self.assertLess(len(quantized_tflite), len(float_tflite))

   @parameterized.named_parameters(
-      ('EnableMlirConverter', True),  # enable mlir
-      ('DisableMlirConverter', False))  # disable mlir
-  def testCalibrateAndQuantizeBuiltinInt8(self, enable_mlir):
+      # Quantize model to Int8: with enable mlir
+      ('UseTfliteBuiltinsIntEnableMLIR', [lite.OpsSet.TFLITE_BUILTINS_INT8], True),
+      # Quantize model to Int8: with disable mlir
+      ('UseTfliteBuiltinsIntDisableMLIR', [lite.OpsSet.TFLITE_BUILTINS_INT8], False),
+      # Quantize model to Int16: with disable mlir
+      ('UseTfliteBuiltinsInt16DisableMLIR', [lite.OpsSet.TFLITE_BUILTINS_ACTIVATIONS_INT16_WEIGHTS_INT8], False))
+  def testCalibrateAndQuantizeBuiltinInt(self, supported_ops, enable_mlir):
     with ops.Graph().as_default():
       inp, output, calibration_gen = self._getCalibrationQuantizeModel()
       sess = session.Session()
@@ -787,9 +791,7 @@ class FromSessionTest(TestModels, parameterized.TestCase):
     quantized_converter = lite.TFLiteConverter.from_session(
         sess, [inp], [output])
     quantized_converter.experimental_new_converter = enable_mlir
-    quantized_converter.target_spec.supported_ops = [
-        lite.OpsSet.TFLITE_BUILTINS_INT8
-    ]
+    quantized_converter.target_spec.supported_ops = supported_ops
     quantized_converter.representative_dataset = calibration_gen
     quantized_tflite = quantized_converter.convert()
     self.assertTrue(quantized_tflite)


@@ -204,6 +204,7 @@ PyObject* CalibrationWrapper::SetTensor(int index, PyObject* value) {
 PyObject* CalibrationWrapper::QuantizeModel(int input_py_type,
                                             int output_py_type,
                                             bool allow_float,
+                                            int activations_py_type,
                                             bool enable_mlir_quantizer) {
   if (NoOpModel(*model_)) {
     return python_utils::ConvertToPyString(model_str_->data(),
@@ -212,6 +213,9 @@ PyObject* CalibrationWrapper::QuantizeModel(int input_py_type,
   TfLiteType input_type = python_utils::TfLiteTypeFromPyType(input_py_type);
   TfLiteType output_type = python_utils::TfLiteTypeFromPyType(output_py_type);
+  TfLiteType activations_type =
+      python_utils::TfLiteTypeFromPyType(activations_py_type);
   if (input_type == kTfLiteNoType || output_type == kTfLiteNoType) {
     PyErr_SetString(PyExc_ValueError,
                     "Input/output type cannot be kTfLiteNoType");
@@ -230,7 +234,7 @@ PyObject* CalibrationWrapper::QuantizeModel(int input_py_type,
     status = tflite::optimize::QuantizeModel(
         &builder, tflite_model.get(), TfLiteTypeToSchemaType(input_type),
         TfLiteTypeToSchemaType(output_type), allow_float,
-        error_reporter_.get());
+        TfLiteTypeToSchemaType(activations_type), error_reporter_.get());
   }

   if (status != kTfLiteOk) {
@@ -262,7 +266,7 @@ PyObject* CalibrationWrapper::QuantizeModel(int input_py_type,
   auto status = tflite::optimize::QuantizeModel(
       &builder, tflite_model.get(), TfLiteTypeToSchemaType(input_type),
       TfLiteTypeToSchemaType(output_type), allow_float, {op_name},
-      error_reporter_.get());
+      TensorType_INT8, error_reporter_.get());
   if (status != kTfLiteOk) {
     error_reporter_->exception();
     return nullptr;


@@ -60,7 +60,8 @@ class CalibrationWrapper {
   PyObject* FeedTensor(PyObject* input_value);

   PyObject* QuantizeModel(int input_py_type, int output_py_type,
-                          bool allow_float, bool enable_mlir_quantizer = false);
+                          bool allow_float, int activations_py_type,
+                          bool enable_mlir_quantizer = false);

   // Allows quantizing only the operator that produces the tensor with name
   // operator_output_name. (This can be used to help debug.).


@@ -19,6 +19,7 @@ from __future__ import print_function

 import numpy as np
 from tensorflow.python.util.lazy_loader import LazyLoader
+from tensorflow.lite.python import lite_constants

 # Lazy load since some of the performance benchmark skylark rules
 # break dependencies. Must use double quotes to match code internal rewrite
@@ -55,7 +56,8 @@ class Calibrator(object):
       raise ValueError("Failed to parse the model.")

   def calibrate_and_quantize(self, dataset_gen, input_type, output_type,
-                             allow_float, enable_mlir_quantizer=False):
+                             allow_float, activations_type=lite_constants.INT8,
+                             enable_mlir_quantizer=False):
     """Calibrates the model with specified generator and then quantizes it.

     Returns:
@@ -69,6 +71,7 @@ class Calibrator(object):
         computation, useful when targeting an integer-only backend.
         If False, an error will be thrown if an operation cannot be
         quantized, otherwise the model will fallback to float ops.
+      activations_type: A tf.dtype representing the desired type for activations.
      enable_mlir_quantizer: A boolean. True if wants to use mlir quantizer to
         quantize the calibrated model.
     """
@@ -78,6 +81,7 @@ class Calibrator(object):
     return self._calibrator.QuantizeModel(
         np.dtype(input_type.as_numpy_dtype()).num,
         np.dtype(output_type.as_numpy_dtype()).num, allow_float,
+        np.dtype(activations_type.as_numpy_dtype()).num,
         enable_mlir_quantizer)

   def calibrate_and_quantize_single(self, dataset_gen, input_type, output_type,
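
A short sketch of driving the updated Calibrator API directly with the new argument, patterned on the parameterized tests further down; the model path and input shape are placeholders, not part of this commit:

import numpy as np
from tensorflow.lite.python import lite_constants as constants
from tensorflow.lite.python.optimize import calibrator as _calibrator

# Placeholder path: any float TFLite flatbuffer whose input matches input_gen.
float_model = open('/tmp/model.tflite', 'rb').read()
quantizer = _calibrator.Calibrator(float_model)

def input_gen():
  for _ in range(10):
    yield [np.random.uniform(-1, 1, size=(1, 5, 5, 3)).astype(np.float32)]

# Calibrate, then quantize with int16 activations and int8 weights.
quantized_model = quantizer.calibrate_and_quantize(
    input_gen, constants.FLOAT, constants.FLOAT, False,
    activations_type=constants.INT16)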


@@ -33,9 +33,13 @@ from tensorflow.python.platform import test
 class CalibratorTest(test_util.TensorFlowTestCase, parameterized.TestCase):

   @parameterized.named_parameters(
-      ('EnableMlirQuantizer', True),  # enable mlir quantizer
-      ('DisableMlirQuantizer', False))  # disable mlir quantizer
-  def test_calibration_with_quantization(self, enable_mlir):
+      # Activation type Int8 - enable mlir quantizer
+      ('UseActivationTypeInt8EnabledMlir', constants.INT8, True),
+      # Activation type Int8 - disable mlir quantizer
+      ('UseActivationTypeInt8DisabledMlir', constants.INT8, False),
+      # Activation type Int16
+      ('UseActivationTypeInt16', constants.INT16, False))
+  def test_calibration_with_quantization(self, activations_type, enable_mlir):
     model_path = resource_loader.get_path_to_datafile(
         'test_data/mobilenet_like_model.bin')
     float_model = open(model_path, 'rb').read()
@@ -49,13 +53,18 @@ class CalibratorTest(test_util.TensorFlowTestCase, parameterized.TestCase):
     quantized_model = quantizer.calibrate_and_quantize(input_gen,
                                                        constants.FLOAT,
                                                        constants.FLOAT, False,
+                                                       activations_type,
                                                        enable_mlir)
     self.assertIsNotNone(quantized_model)

   @parameterized.named_parameters(
-      ('EnableMlirQuantizer', True),  # enable mlir quantizer
-      ('DisableMlirQuantizer', False))  # disable mlir quantizer
-  def test_calibration_with_quantization_allow_float(self, enable_mlir):
+      # Activation type Int8 - enable mlir quantizer
+      ('UseActivationTypeInt8EnabledMlir', constants.INT8, True),
+      # Activation type Int8 - disable mlir quantizer
+      ('UseActivationTypeInt8DisableMlir', constants.INT8, False),
+      # Activation type Int16 - disable mlir quantizer
+      ('UseActivationTypeInt16', constants.INT16, False))
+  def test_calibration_with_quantization_allow_float(self, activations_type, enable_mlir):
     model_path = resource_loader.get_path_to_datafile(
         'test_data/mobilenet_like_model.bin')
     float_model = open(model_path, 'rb').read()
@@ -69,6 +78,7 @@ class CalibratorTest(test_util.TensorFlowTestCase, parameterized.TestCase):
     quantized_model = quantizer.calibrate_and_quantize(input_gen,
                                                        constants.FLOAT,
                                                        constants.FLOAT, True,
+                                                       activations_type,
                                                        enable_mlir)
     self.assertIsNotNone(quantized_model)

@@ -88,9 +98,13 @@ class CalibratorTest(test_util.TensorFlowTestCase, parameterized.TestCase):
     self.assertIsNotNone(quantized_model)

   @parameterized.named_parameters(
-      ('EnableMlirQuantizer', True),  # enable mlir quantizer
-      ('DisableMlirQuantizer', False))  # disable mlir quantizer
-  def test_calibration_with_quantization_multiple_inputs(self, enable_mlir):
+      # Activation type Int8 - enable mlir quantizer
+      ('UseActivationTypeInt8 - EnableMlirQuantizer', constants.INT8, True),
+      # Activation type Int8 - disable mlir quantizer
+      ('UseActivationTypeInt8 - DisableMlirQuantizer', constants.INT8, False),
+      # Activation type Int16 - disable mlir quantizer
+      ('UseActivationTypeInt16 - DisableEnableMlirQuantizer', constants.INT16, False))
+  def test_calibration_with_quantization_multiple_inputs(self, activations_type, enable_mlir):
     # Load multi add model from test data.
     # This model has 4 inputs of size (1, 8, 8, 3).
     model_path = resource_loader.get_path_to_datafile(
@@ -106,6 +120,7 @@ class CalibratorTest(test_util.TensorFlowTestCase, parameterized.TestCase):
     quantized_model = quantizer.calibrate_and_quantize(input_gen,
                                                        constants.FLOAT,
                                                        constants.FLOAT, False,
+                                                       activations_type,
                                                        enable_mlir)
     self.assertIsNotNone(quantized_model)

@@ -148,7 +163,8 @@ class CalibratorTest(test_util.TensorFlowTestCase, parameterized.TestCase):
     with self.assertRaisesRegex(ValueError, 'Size mismatch'):
       quantizer.calibrate_and_quantize(input_gen, constants.FLOAT,
-                                       constants.FLOAT, False, enable_mlir)
+                                       constants.FLOAT, False,
+                                       enable_mlir)

   @parameterized.named_parameters(
       ('EnableMlirQuantizer', True),  # enable mlir quantizer
@@ -166,7 +182,8 @@ class CalibratorTest(test_util.TensorFlowTestCase, parameterized.TestCase):
     with self.assertRaises(ValueError):
       quantizer.calibrate_and_quantize(input_gen, constants.FLOAT,
-                                       constants.FLOAT, False, enable_mlir)
+                                       constants.FLOAT, False,
+                                       constants.INT8, enable_mlir)

 if __name__ == '__main__':


@@ -64,6 +64,7 @@ OperatorProperty GetOperatorProperty(const ModelT* model, int subgraph_index,
       property.inputs = {{0, {}}, {1, {}}};
       property.outputs = {{0, {}}};
       property.version = 2;
+      property.quantize_input_as_activations = true;
       break;
     case BuiltinOperator_ARG_MAX:
       property.inputs = {{0, {}}};
@@ -176,7 +177,7 @@ OperatorProperty GetOperatorProperty(const ModelT* model, int subgraph_index,
       // LogSoftmax requires output with 16/256 as scale and 127 as zero point.
       TensorProperty tensor_property;
       tensor_property.restriction = true;
-      tensor_property.restricted_value = {16.0 / 256.0, 127};
+      tensor_property.restricted_value_int8 = {16.0 / 256.0, 127};
       property.outputs = {{0, tensor_property}};
       property.version = 2;
       break;
@@ -186,7 +187,8 @@ OperatorProperty GetOperatorProperty(const ModelT* model, int subgraph_index,
       // Logistic requires output with 1/256 as scale and -128 as zero point.
       TensorProperty tensor_property;
       tensor_property.restriction = true;
-      tensor_property.restricted_value = {1 / 256.0, -128};
+      tensor_property.restricted_value_int8 = {1 / 256.0, -128};
+      tensor_property.restricted_value_int16 = {1 / 32768.0, 0};
       property.outputs = {{0, tensor_property}};
       property.version = 2;
       break;
@@ -741,7 +743,7 @@ OperatorProperty GetOperatorProperty(const ModelT* model, int subgraph_index,
       // L2 Norm requires output with 1/128 as scale and 0 as zero point.
       TensorProperty tensor_property;
       tensor_property.restriction = true;
-      tensor_property.restricted_value = {1 / 128.0, 0};
+      tensor_property.restricted_value_int8 = {1 / 128.0, 0};
       property.outputs = {{0, tensor_property}};
       property.version = 2;
       break;
@@ -756,6 +758,7 @@ OperatorProperty GetOperatorProperty(const ModelT* model, int subgraph_index,
       property.arbitrary_inputs = true;
       property.outputs = {{0, {}}};
       property.restrict_same_input_output_scale = true;
+      property.quantize_input_as_activations = true;
       property.version = 2;
       break;
     case BuiltinOperator_MEAN:
@@ -767,6 +770,7 @@ OperatorProperty GetOperatorProperty(const ModelT* model, int subgraph_index,
       property.arbitrary_inputs = true;
       property.outputs = {{0, {}}};
       property.restrict_same_input_output_scale = true;
+      property.quantize_input_as_activations = true;
       property.version = 2;
       break;
     case BuiltinOperator_MUL:
@@ -778,6 +782,7 @@ OperatorProperty GetOperatorProperty(const ModelT* model, int subgraph_index,
       property.arbitrary_inputs = true;
       property.outputs = {{0, {}}};
       property.restrict_same_input_output_scale = true;
+      property.restrict_same_input_output_scale = true;
       property.version = 2;
       break;
     case BuiltinOperator_PAD:
@@ -840,7 +845,8 @@ OperatorProperty GetOperatorProperty(const ModelT* model, int subgraph_index,
       // Softmax requires output with 1/256 as scale and -128 as zero point.
       TensorProperty tensor_property;
       tensor_property.restriction = true;
-      tensor_property.restricted_value = {1 / 256.0, -128};
+      tensor_property.restricted_value_int8 = {1 / 256.0, -128};
+      tensor_property.restricted_value_int16 = {1 / 32768.0, 0};
       property.outputs = {{0, tensor_property}};
       property.version = 2;
       break;
@@ -866,7 +872,8 @@ OperatorProperty GetOperatorProperty(const ModelT* model, int subgraph_index,
       // Tanh requires output with 1/128 as scale and 0 as zero point.
       TensorProperty tensor_property;
       tensor_property.restriction = true;
-      tensor_property.restricted_value = {1 / 128.0, 0};
+      tensor_property.restricted_value_int8 = {1 / 128.0, 0};
+      tensor_property.restricted_value_int16 = {1 / 32768.0, 0};
       property.outputs = {{0, tensor_property}};
       property.version = 2;
       break;


@@ -43,7 +43,8 @@ struct TensorProperty {
   // Constraints.
   bool restriction = false;
   // scale/zero_point hardcoded.
-  std::pair<float, int> restricted_value = {0.0, 0};
+  std::pair<float, int> restricted_value_int8 = {0.0, 0};
+  std::pair<float, int> restricted_value_int16 = {0.0, 0};

   // Use derived scale.
   bool use_derived_scale = false;
@@ -93,6 +94,13 @@ struct OperatorProperty {
   // Op version.
   int version = 1;
+
+  // When we quantize activations into 16 bit and weights into 8 bit,
+  // we want to quantize all inputs, including constant tensors,
+  // for the operators like Add, Mul into 16-bit as well. The constant
+  // inputs are quantized as weights and this variable indicates
+  // that we want to do quantizations of these tensors as activations.
+  bool quantize_input_as_activations = false;
 };

 OperatorProperty GetOperatorProperty(const ModelT* model, int subgraph_index,


@@ -20,7 +20,6 @@ limitations under the License.
 #include <string>

 #include "absl/memory/memory.h"
-#include "third_party/eigen3/Eigen/Core"
 #include "tensorflow/lite/c/common.h"
 #include "tensorflow/lite/core/api/error_reporter.h"
 #include "tensorflow/lite/kernels/internal/quantization_util.h"
@@ -30,6 +29,7 @@ limitations under the License.
 #include "tensorflow/lite/minimal_logging.h"
 #include "tensorflow/lite/schema/schema_generated.h"
 #include "tensorflow/lite/tools/optimize/model_utils.h"
+#include "third_party/eigen3/Eigen/Core"

 namespace tflite {
 namespace optimize {
@@ -85,6 +85,46 @@ void GetAsymmetricQuantizationParams(
   quantization_params->zero_point = std::vector<int64_t>(1, zero_point);
 }

+void GetSymmetricQuantizationParams(
+    float min, float max, const int half_quant_range,
+    QuantizationParametersT* quantization_params) {
+  // Adjust the boundaries to guarantee 0 is included.
+  min = std::min(min, 0.0f);
+  max = std::max(max, 0.0f);
+  const float scale = std::max(std::abs(max), std::abs(min)) / half_quant_range;
+  int64_t zero_point = 0;
+  quantization_params->min = std::vector<float>(1, min);
+  quantization_params->max = std::vector<float>(1, max);
+  quantization_params->scale = std::vector<float>(1, scale);
+  quantization_params->zero_point = std::vector<int64_t>(1, 0);
+}
+
+TfLiteStatus GetQuantizationParams(TensorT* tensor, TensorType activations_type,
+                                   QuantizationParametersT* quantization_params,
+                                   ErrorReporter* error_reporter) {
+  if (activations_type == TensorType_INT8) {
+    GetAsymmetricQuantizationParams(
+        tensor->quantization->min[0], tensor->quantization->max[0],
+        std::numeric_limits<int8_t>::min(), std::numeric_limits<int8_t>::max(),
+        quantization_params);
+  } else if (activations_type == TensorType_INT16) {
+    float range = std::max(std::abs(tensor->quantization->min[0]),
+                           std::abs(tensor->quantization->max[0]));
+    const float quantized_range = 32767.0;
+    const float scale = range / quantized_range;
+    quantization_params->min = std::vector<float>(1, -range);
+    quantization_params->max = std::vector<float>(1, range);
+    quantization_params->scale = std::vector<float>(1, scale);
+    quantization_params->zero_point = std::vector<int64_t>(1, 0);
+  } else {
+    error_reporter->Report(
+        "Unsupported activation type for quantize-activation: %s",
+        activations_type);
+    return kTfLiteError;
+  }
+  return kTfLiteOk;
+}
+
 // Set the max and min quantization parameter for a single tensor given its
 // values.
 void FillSingleMinMax(const float* const input, const uint64_t input_size,
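
As a rough worked example of the int16 branch above (the numbers are made up, not from the commit): if calibration recorded min = -6.0 and max = 3.0, the code takes range = max(|-6.0|, |3.0|) = 6.0, giving scale = 6.0 / 32767 with a zero point of 0. The same arithmetic in a few lines of Python:

# Hypothetical calibration statistics for one activation tensor.
tensor_min, tensor_max = -6.0, 3.0
range_val = max(abs(tensor_min), abs(tensor_max))  # 6.0
scale = range_val / 32767.0                        # ~1.831e-04, zero point stays 0
print(scale)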
@@ -536,6 +576,7 @@ TfLiteStatus SymmetricQuantizeTensorPerChannel(ModelT* model, TensorT* tensor,
                                           model, tensor, error_reporter);
 }

+template <class BiasType>
 TfLiteStatus SymmetricPerLayerBiasQuantize(ModelT* model, TensorT* tensor,
                                            float scaling_factor,
                                            ErrorReporter* error_reporter) {
@@ -548,25 +589,38 @@ TfLiteStatus SymmetricPerLayerBiasQuantize(ModelT* model, TensorT* tensor,
   uint64_t num_elements;
   TF_LITE_ENSURE_STATUS(NumElements(*tensor, &num_elements));

-  std::vector<int32_t> final_buffer(num_elements);
-  const int32_t kScale = std::numeric_limits<int32_t>::max();
+  std::vector<BiasType> final_buffer(num_elements);
+  const BiasType kScale = std::numeric_limits<BiasType>::max();

   for (size_t i = 0; i < num_elements; i++) {
-    const int32_t quantized_value = tflite::SafeCast<int32_t>(
+    const BiasType quantized_value = tflite::SafeCast<BiasType>(
         TfLiteRound(float_data[i] * scaling_factor_inv));
     final_buffer[i] = std::min(kScale, std::max(-kScale, quantized_value));
   }

   // Set the buffers and output type.
   uint8_t* uint8_buffer = reinterpret_cast<uint8_t*>(final_buffer.data());
-  size_t buffer_size = num_elements * sizeof(int32_t);
+  size_t buffer_size = num_elements * sizeof(BiasType);
   std::vector<float> scales(1, scaling_factor);
   std::vector<int64_t> zero_points(1, 0);
+
+  auto output_type = std::is_same<BiasType, std::int32_t>::value
+                         ? TensorType_INT32
+                         : TensorType_INT64;
   return AddQuantizationParams(scales, zero_points, 0, uint8_buffer,
-                               buffer_size, TensorType_INT32, model, tensor,
+                               buffer_size, output_type, model, tensor,
                                error_reporter);
 }

+template TfLiteStatus SymmetricPerLayerBiasQuantize<std::int32_t>(
+    ModelT* model, TensorT* tensor, float scaling_factor,
+    ErrorReporter* error_reporter);
+
+template TfLiteStatus SymmetricPerLayerBiasQuantize<std::int64_t>(
+    ModelT* model, TensorT* tensor, float scaling_factor,
+    ErrorReporter* error_reporter);
+
+template <class BiasType>
 TfLiteStatus SymmetricPerChannelBiasQuantize(ModelT* model, TensorT* tensor,
                                              float input_scale,
                                              const float* weight_scales,
@@ -583,14 +637,14 @@ TfLiteStatus SymmetricPerChannelBiasQuantize(ModelT* model, TensorT* tensor,
   uint64_t num_elements;
   TF_LITE_ENSURE_STATUS(NumElements(*tensor, &num_elements));

-  std::vector<int32_t> final_buffer(num_elements);
-  const int32_t kScale = std::numeric_limits<int32_t>::max();
+  std::vector<BiasType> final_buffer(num_elements);
+  const BiasType kScale = std::numeric_limits<BiasType>::max();

   for (int32_t channel_idx = 0; channel_idx < number_of_dimension;
        channel_idx++) {
     float scaling_factor = scales[channel_idx];
     float scaling_factor_inv = (scaling_factor == 0) ? 0 : 1.0 / scaling_factor;
-    const int32_t quantized_value = tflite::SafeCast<int32_t>(
+    const BiasType quantized_value = tflite::SafeCast<BiasType>(
         TfLiteRound(float_data[channel_idx] * scaling_factor_inv));
     final_buffer[channel_idx] =
         std::min(kScale, std::max(-kScale, quantized_value));
@@ -598,12 +652,26 @@ TfLiteStatus SymmetricPerChannelBiasQuantize(ModelT* model, TensorT* tensor,
   // Set the buffers and output type.
   uint8_t* uint8_buffer = reinterpret_cast<uint8_t*>(final_buffer.data());
-  size_t buffer_size = num_elements * sizeof(int32_t);
+  size_t buffer_size = num_elements * sizeof(BiasType);
   std::vector<int64_t> zero_point(scales.size(), 0);
+
+  auto output_type = std::is_same<BiasType, std::int32_t>::value
+                         ? TensorType_INT32
+                         : TensorType_INT64;
   return AddQuantizationParams(scales, zero_point, 0, uint8_buffer, buffer_size,
-                               TensorType_INT32, model, tensor, error_reporter);
+                               output_type, model, tensor, error_reporter);
 }

+template TfLiteStatus SymmetricPerChannelBiasQuantize<std::int64_t>(
+    ModelT* model, TensorT* tensor, float input_scale,
+    const float* weight_scales, int number_of_dimension,
+    ErrorReporter* error_reporter);
+
+template TfLiteStatus SymmetricPerChannelBiasQuantize<std::int32_t>(
+    ModelT* model, TensorT* tensor, float input_scale,
+    const float* weight_scales, int number_of_dimension,
+    ErrorReporter* error_reporter);
+
 TfLiteStatus QuantizeWeight(ModelT* model, TensorT* tensor, bool per_channel,
                             int per_axis_index, ErrorReporter* error_reporter) {
   // TODO(suharshs): Currently we conflate quantizing weights and constants. Its
@@ -645,12 +713,12 @@ float GetEffectiveScale(ModelT* model, SubGraphT* subgraph, int op_idx,
   return scale;
 }

-void QuantizeActivation(TensorT* tensor) {
-  GetAsymmetricQuantizationParams(
-      tensor->quantization->min[0], tensor->quantization->max[0],
-      std::numeric_limits<int8_t>::min(), std::numeric_limits<int8_t>::max(),
-      tensor->quantization.get());
-  tensor->type = TensorType_INT8;
+TfLiteStatus QuantizeActivation(TensorT* tensor, TensorType activations_type,
+                                ErrorReporter* error_reporter) {
+  TF_LITE_ENSURE_STATUS(GetQuantizationParams(
+      tensor, activations_type, tensor->quantization.get(), error_reporter));
+  tensor->type = activations_type;
+  return kTfLiteOk;
 }

 TfLiteStatus QuantizeActivationToInt16(TensorT* tensor, float scale) {


@@ -113,12 +113,14 @@ TfLiteStatus SymmetricQuantizeFloatsToInt16(ModelT* model, TensorT* tensor,
                                             ErrorReporter* error_reporter);

 // Symmetrically quantized the bias for per-layer ops (i.e. FullyConnected).
+template <typename BiasType>
 TfLiteStatus SymmetricPerLayerBiasQuantize(ModelT* model, TensorT* tensor,
                                            float scaling_factor,
                                            ErrorReporter* error_reporter);

 // Symmetrically quantizes the bias for ops like Conv and DepthwiseConv.
 // The scale of bias if weight_per_channel_scale[channel] * input_scale.
+template <typename BiasType>
 TfLiteStatus SymmetricPerChannelBiasQuantize(ModelT* model, TensorT* tensor,
                                              float input_scale,
                                              const float* weight_scales,
@@ -135,8 +137,14 @@ float GetEffectiveScale(ModelT* model, SubGraphT* subgraph, int op_idx,
                         std::vector<int> intermediate_index,
                         std::vector<float> factors);

+// Return quantization parameters depending on activations type.
+TfLiteStatus GetQuantizationParams(TensorT* tensor, TensorType activations_type,
+                                   QuantizationParametersT* quantization_params,
+                                   ErrorReporter* error_reporter);
+
 // Quantize activation.
-void QuantizeActivation(TensorT* tensor);
+TfLiteStatus QuantizeActivation(TensorT* tensor, TensorType activations_type,
+                                ErrorReporter* error_reporter);

 // Quantize activation to 16bit.
 TfLiteStatus QuantizeActivationToInt16(TensorT* tensor, float scale);


@@ -701,7 +701,7 @@ TEST_F(QuantizationUtilsTest, SymmetricPerLayerBiasQuantize) {
   model->buffers.push_back(std::move(buffer));

   // Call and verify.
-  EXPECT_EQ(SymmetricPerLayerBiasQuantize(
+  EXPECT_EQ(SymmetricPerLayerBiasQuantize<int32_t>(
                 model.get(), model->subgraphs[0]->tensors[0].get(),
                 input_scale * weight_scale, &error_reporter_),
             kTfLiteOk);
@@ -759,7 +759,7 @@ TEST_F(QuantizationUtilsTest, SymmetricPerChannelBiasQuantize) {
   model->buffers.push_back(std::move(buffer));

   // Call and verify.
-  EXPECT_EQ(SymmetricPerChannelBiasQuantize(
+  EXPECT_EQ(SymmetricPerChannelBiasQuantize<int32_t>(
                 model.get(), model->subgraphs[0]->tensors[0].get(), input_scale,
                 weight_scales.data(), 2, &error_reporter_),
             kTfLiteOk);


@@ -42,7 +42,9 @@ bool CreateQuantizedModel(const std::string& path) {
   tflite::StderrReporter error_reporter;
   if (tflite::optimize::QuantizeModel(
           &builder, &model, tflite::TensorType_FLOAT32,
-          tflite::TensorType_FLOAT32, &error_reporter) != kTfLiteOk) {
+          tflite::TensorType_FLOAT32,
+          // TODO: Pass required activation type if needed
+          tflite::TensorType_INT8, &error_reporter) != kTfLiteOk) {
     return false;
   }
   return WriteFile(path, builder.GetBufferPointer(), builder.GetSize());


@@ -64,6 +64,7 @@ operator_property::OperatorProperty GetOperatorProperty(
 TfLiteStatus QuantizeBias(ModelT* model, const TensorT* input_tensor,
                           const TensorT* weight_tensor, TensorT* bias_tensor,
                           bool is_per_channel, int channel_dim_index,
+                          const TensorType& activations_type,
                           ErrorReporter* error_reporter) {
   if (bias_tensor->shape.size() != 1) {
     error_reporter->Report("Expected bias tensor shape to be 1.");
@@ -92,9 +93,15 @@ TfLiteStatus QuantizeBias(ModelT* model, const TensorT* input_tensor,
           weight_scales.size());
       return kTfLiteError;
     }
-    return utils::SymmetricPerChannelBiasQuantize(
-        model, bias_tensor, input_tensor->quantization->scale[0],
-        weight_scales.data(), channel_dim_size, error_reporter);
+    if (activations_type == tflite::TensorType_INT16) {
+      return utils::SymmetricPerChannelBiasQuantize<std::int64_t>(
+          model, bias_tensor, input_tensor->quantization->scale[0],
+          weight_scales.data(), channel_dim_size, error_reporter);
+    } else {
+      return utils::SymmetricPerChannelBiasQuantize<std::int32_t>(
+          model, bias_tensor, input_tensor->quantization->scale[0],
+          weight_scales.data(), channel_dim_size, error_reporter);
+    }
   } else {
     if (weight_scales.size() != 1) {
       error_reporter->Report(
@@ -102,40 +109,54 @@ TfLiteStatus QuantizeBias(ModelT* model, const TensorT* input_tensor,
           weight_scales.size());
       return kTfLiteError;
     }
-    return utils::SymmetricPerLayerBiasQuantize(
-        model, bias_tensor,
-        input_tensor->quantization->scale[0] * weight_scales[0],
-        error_reporter);
+    if (activations_type == tflite::TensorType_INT16) {
+      return utils::SymmetricPerLayerBiasQuantize<std::int64_t>(
+          model, bias_tensor,
+          input_tensor->quantization->scale[0] * weight_scales[0],
+          error_reporter);
+    } else {
+      return utils::SymmetricPerLayerBiasQuantize<std::int32_t>(
+          model, bias_tensor,
+          input_tensor->quantization->scale[0] * weight_scales[0],
+          error_reporter);
+    }
   }
   return kTfLiteError;
 }
 // True if the tensor type has to be modified.
 bool TensorTypeChangeRequired(const TensorT* tensor, const TensorType& type) {
-  // The quantized model is type INT8, so if the user provided type is INT8, we
-  // do not have to do any custom logic. Additionally, if the current tensor
-  // isn't INT8 quantized, the custom type doesn't apply.
-  return (type != TensorType_INT8 && tensor->type == TensorType_INT8 &&
-          !tensor->quantization->scale.empty());
+  // The quantized model is type INT8/INT16, so if the user provided type is
+  // INT8/INT16, we do not have to do any custom logic. Additionally, if the
+  // current tensor isn't INT8/INT16 quantized, the custom type doesn't apply.
+  bool int8check = type != TensorType_INT8 && tensor->type == TensorType_INT8 &&
+                   !tensor->quantization->scale.empty();
+  bool int16check = type != TensorType_INT16 &&
+                    tensor->type == TensorType_INT16 &&
+                    !tensor->quantization->scale.empty();
+  return (int8check || int16check);
 }

 // Sets the input type, adding a Leading Op node at the start of the model if
 // necessary.
 // Returns the new input tensor index.
 int32_t SetInputType(ModelT* model, SubGraphT* subgraph,
-                     const int32_t tensor_idx, const TensorType& input_type) {
+                     const int32_t tensor_idx, const TensorType& input_type,
+                     const TensorType& activations_type) {
   TensorT* tensor = subgraph->tensors[tensor_idx].get();
   if (!TensorTypeChangeRequired(tensor, input_type)) {
     return -1;
   }
   if (input_type == TensorType_FLOAT32 || input_type == TensorType_UINT8) {
+    std::string type_string =
+        activations_type == TensorType_INT16 ? "int16" : "int8";
     // Create a new tensor to be the input of the leading Op.
     std::unique_ptr<TensorT> leading_op_input;
     if (input_type == TensorType_FLOAT32) {
       // Add tensor for quantize operator. Scales and zero points are not
       // needed.
       const string leading_op_name = tensor->name;
-      const string new_name_original_input = tensor->name + "_int8";
+      const string new_name_original_input = tensor->name + "_" + type_string;
       tensor->name = new_name_original_input;
       utils::MakeTensor(leading_op_name, tensor->shape, input_type,
                         &leading_op_input);
@@ -150,7 +171,7 @@ int32_t SetInputType(ModelT* model, SubGraphT* subgraph,
       TFLITE_DCHECK_GE(zero_point, -128);
       TFLITE_DCHECK_LE(zero_point, 127);
       const string leading_op_name = tensor->name;
-      const string new_name_original_input = tensor->name + "_int8";
+      const string new_name_original_input = tensor->name + "_" + type_string;
       tensor->name = new_name_original_input;
       utils::MakeTensorWithQuantParam(leading_op_name, tensor->shape,
                                       input_type, scale, zero_point + 128,
@@ -177,17 +198,20 @@ int32_t SetInputType(ModelT* model, SubGraphT* subgraph,
 // necessary.
 // Returns the new output tensor index.
 int32_t SetOutputType(ModelT* model, SubGraphT* subgraph,
-                      const int32_t tensor_idx, const TensorType& output_type) {
+                      const int32_t tensor_idx, const TensorType& output_type,
+                      const TensorType& activations_type) {
   TensorT* tensor = subgraph->tensors[tensor_idx].get();
   if (!TensorTypeChangeRequired(tensor, output_type)) {
     return -1;
   }
   if (output_type == TensorType_FLOAT32 || output_type == TensorType_UINT8) {
+    std::string type_string =
+        activations_type == TensorType_INT16 ? "int16" : "int8";
     // Create a new tensor to be the output of the tailing op.
     std::unique_ptr<TensorT> tailing_op_output;
     if (output_type == TensorType_FLOAT32) {
       const string tailing_op_name = tensor->name;
-      const string new_name_original_output = tensor->name + "_int8";
+      const string new_name_original_output = tensor->name + "_" + type_string;
       tensor->name = new_name_original_output;
       utils::MakeTensor(tailing_op_name, tensor->shape, output_type,
                         &tailing_op_output);
@@ -202,7 +226,7 @@ int32_t SetOutputType(ModelT* model, SubGraphT* subgraph,
       TFLITE_DCHECK_GE(zero_point, -128);
       TFLITE_DCHECK_LE(zero_point, 127);
       const string tailing_op_name = tensor->name;
-      const string new_name_original_output = tensor->name + "_int8";
+      const string new_name_original_output = tensor->name + "_" + type_string;
       tensor->name = new_name_original_output;
       utils::MakeTensorWithQuantParam(tailing_op_name, tensor->shape,
                                       output_type, scale, zero_point + 128,
@@ -238,6 +262,7 @@ int32_t SetOutputType(ModelT* model, SubGraphT* subgraph,
 // uint8, can be thought as "requant").
 TfLiteStatus SetInputAndOutputTypes(ModelT* model, const TensorType& input_type,
                                     const TensorType& output_type,
+                                    const TensorType& activations_type,
                                     ErrorReporter* error_reporter) {
   for (int subgraph_idx = 0; subgraph_idx < model->subgraphs.size();
        subgraph_idx++) {
@@ -253,8 +278,8 @@ TfLiteStatus SetInputAndOutputTypes(ModelT* model, const TensorType& input_type,
             EnumNameTensorType(tensor->type));
         return kTfLiteError;
       }
-      const int32_t input_idx =
-          SetInputType(model, subgraph, subgraph->inputs[i], input_type);
+      const int32_t input_idx = SetInputType(
+          model, subgraph, subgraph->inputs[i], input_type, activations_type);
       if (input_idx < 0) {
         continue;
       }
@@ -270,8 +295,8 @@ TfLiteStatus SetInputAndOutputTypes(ModelT* model, const TensorType& input_type,
             EnumNameTensorType(tensor->type));
         return kTfLiteError;
       }
-      const int32_t output_idx =
-          SetOutputType(model, subgraph, subgraph->outputs[i], output_type);
+      const int32_t output_idx = SetOutputType(
+          model, subgraph, subgraph->outputs[i], output_type, activations_type);
       if (output_idx < 0) {
         continue;
       }
@@ -287,6 +312,7 @@ TfLiteStatus SetInputAndOutputTypes(ModelT* model, const TensorType& input_type,
 // The other ones with constraints are handled in QuantizeWeightsAndInput.
 TfLiteStatus ApplyConstraints(ModelT* model,
                               const std::unordered_set<string>& operator_names,
+                              TensorType activations_type,
                               ErrorReporter* error_reporter) {
   for (int subgraph_idx = 0; subgraph_idx < model->subgraphs.size();
        subgraph_idx++) {
@@ -332,7 +358,7 @@ TfLiteStatus ApplyConstraints(ModelT* model,
       std::unique_ptr<TensorT> additional_tensor;
       const string requant_tensor_name = input_tensor->name + "_requantized";
       utils::MakeTensorWithQuantParam(
-          requant_tensor_name, input_tensor->shape, TensorType_INT8,
+          requant_tensor_name, input_tensor->shape, activations_type,
           output_scale, output_zp, &additional_tensor);
       const int32_t additional_tensor_idx = subgraph->tensors.size();
       subgraph->tensors.push_back(std::move(additional_tensor));
@@ -382,7 +408,8 @@ std::vector<std::pair<int, operator_property::TensorProperty>> GetOutputs(
 bool ShouldRestrictSameInputOutputScale(
     operator_property::OperatorProperty property) {
-  // Ops with multiple inputs (i.e. concat) gets restricted in ApplyConstraints.
+  // Ops with multiple inputs (i.e. concat, max and min) gets restricted in
+  // ApplyConstraints.
   return (!property.arbitrary_inputs &&
           property.restrict_same_input_output_scale);
 }
@@ -401,7 +428,7 @@ TfLiteStatus QuantizeOpInput(
     ModelT* model, int32_t subgraph_idx, size_t* op_idx,
     operator_property::OperatorProperty property,
     const std::pair<int32_t, operator_property::TensorProperty>& input,
-    ErrorReporter* error_reporter) {
+    const TensorType& activations_type, ErrorReporter* error_reporter) {
   int32_t input_idx = input.first;
   operator_property::TensorProperty tensor_property = input.second;
   SubGraphT* subgraph = model->subgraphs.at(subgraph_idx).get();
@@ -429,7 +456,9 @@ TfLiteStatus QuantizeOpInput(
     if (utils::HasBuffer(model, subgraph, tensor_idx)) {
       // TODO(suharshs): Look at consumers, throw error if one consumer is
       // per-channel and one per-layer.
-      if (tensor_property.number_of_bits == 8) {
+      bool quantize_const_input = property.quantize_input_as_activations &&
+                                  activations_type == TensorType_INT16;
+      if (tensor_property.number_of_bits == 8 && !quantize_const_input) {
        if (tensor_property.use_derived_scale) {
          // Currently 8bit tensors in input do not accept derived scale.
          return kTfLiteError;
@@ -444,7 +473,7 @@ TfLiteStatus QuantizeOpInput(
               *op_idx);
           return kTfLiteError;
         }
-      } else if (tensor_property.number_of_bits == 16) {
+      } else if (tensor_property.number_of_bits == 16 || quantize_const_input) {
         if (tensor_property.use_derived_scale) {
           // Currently 16bit tensors in input do not accept derived scale.
           return kTfLiteError;
@@ -476,8 +505,8 @@ TfLiteStatus QuantizeOpInput(
             tensor_property.derived_scale.input_tensors,
             tensor_property.derived_scale.intermediate_tensors,
             tensor_property.derived_scale.factors);
-        return utils::SymmetricPerLayerBiasQuantize(model, tensor, scale,
-                                                    error_reporter);
+        return utils::SymmetricPerLayerBiasQuantize<std::int32_t>(
+            model, tensor, scale, error_reporter);
       } else if (tensor_property.number_of_bits == 10) {
         // When the number of bits is 10 (instead of 16), quantize the tensor to
@@ -514,7 +543,8 @@ TfLiteStatus QuantizeOpInput(
           // Currently 8bit tensors in input do not accept derived scale.
           return kTfLiteError;
         }
-        utils::QuantizeActivation(tensor);
+        TF_LITE_ENSURE_STATUS(utils::QuantizeActivation(
+            tensor, activations_type, error_reporter));
       } else if (tensor_property.number_of_bits == 16) {
         TensorT* tensor = subgraph->tensors[tensor_idx].get();
         float quantized_range = 32767.0;
@@ -532,13 +562,16 @@ TfLiteStatus QuantizeOpInput(
     } else {
       // If the tensor is not a model input, we need to add a Quantize
       // operation since the preceding op may require a float output.
+      std::string type_string =
+          activations_type == TensorType_INT16 ? "int16" : "int8";
       std::unique_ptr<TensorT> op_output;
-      utils::MakeTensor(tensor->name + "_int8", tensor->shape,
-                        TensorType_INT8, &op_output);
+      utils::MakeTensor(tensor->name + "_" + type_string, tensor->shape,
+                        activations_type, &op_output);
       op_output->quantization = absl::make_unique<QuantizationParametersT>();
       op_output->quantization->min.push_back(tensor->quantization->min[0]);
       op_output->quantization->max.push_back(tensor->quantization->max[0]);
-      utils::QuantizeActivation(op_output.get());
+      TF_LITE_ENSURE_STATUS(utils::QuantizeActivation(
+          op_output.get(), activations_type, error_reporter));
       const int32_t quant_op_output_idx = subgraph->tensors.size();
       subgraph->tensors.push_back(std::move(op_output));
       std::unique_ptr<OperatorT> quant_op;
@@ -580,7 +613,7 @@ TfLiteStatus QuantizeOpOutput(
     ModelT* model, int32_t subgraph_idx, int32_t op_idx,
     operator_property::OperatorProperty property,
     const std::pair<int32_t, operator_property::TensorProperty>& output,
-    ErrorReporter* error_reporter) {
+    TensorType activations_type, ErrorReporter* error_reporter) {
   int32_t output_idx = output.first;
   operator_property::TensorProperty tensor_property = output.second;
   // If the operator is not quantizable, we don't need to do anything for the
@@ -644,18 +677,22 @@ TfLiteStatus QuantizeOpOutput(
       const float max = input_tensor->quantization->max[0];
       output_tensor->quantization->max = {max};
     }
-    output_tensor->type = TensorType_INT8;
+    output_tensor->type = activations_type;
   } else if (tensor_property.restriction) {
-    const auto scale_and_zp = tensor_property.restricted_value;
+    const auto scale_and_zp = activations_type == TensorType_INT16
+                                  ? tensor_property.restricted_value_int16
+                                  : tensor_property.restricted_value_int8;
     // Apply to output.
     output_tensor->quantization = absl::make_unique<QuantizationParametersT>();
     output_tensor->quantization->scale.push_back(scale_and_zp.first);
     output_tensor->quantization->zero_point.push_back(scale_and_zp.second);
-    output_tensor->type = TensorType_INT8;
+    output_tensor->type = activations_type;
   } else {
     // Process regular output that doesn't have any restrictions.
     if (utils::HasMinMax(output_tensor)) {
-      utils::QuantizeActivation(output_tensor);
+      utils::QuantizeActivation(output_tensor, activations_type,
+                                error_reporter);
     } else {
       error_reporter->Report(
           "Unable to find min/max value for output %d in %s in "
@@ -668,6 +705,7 @@ TfLiteStatus QuantizeOpOutput(
 }

 TfLiteStatus QuantizeIntemediateTensors(ModelT* model,
+                                        TensorType activations_type,
                                         ErrorReporter* error_reporter) {
   for (size_t subgraph_idx = 0; subgraph_idx < model->subgraphs.size();
        subgraph_idx++) {
@@ -691,7 +729,8 @@ TfLiteStatus QuantizeIntemediateTensors(ModelT* model,
             input.second.symmetric == false) {
           TensorT* tensor = subgraph->tensors[index_global].get();
           if (utils::HasMinMax(tensor)) {
-            utils::QuantizeActivation(tensor);
+            utils::QuantizeActivation(tensor, activations_type,
+                                      error_reporter);
           } else {
             error_reporter->Report(
                 "Unable to find min/max value for output %d in %s in "
@ -793,7 +832,7 @@ TfLiteStatus QuantizeSharedRange(ModelT* model, ErrorReporter* error_reporter) {
TfLiteStatus QuantizeWeightsInputOutput( TfLiteStatus QuantizeWeightsInputOutput(
ModelT* model, bool allow_float, ModelT* model, bool allow_float,
const std::unordered_set<string>& operator_names, const std::unordered_set<string>& operator_names,
ErrorReporter* error_reporter) { const TensorType& activations_type, ErrorReporter* error_reporter) {
for (size_t subgraph_idx = 0; subgraph_idx < model->subgraphs.size(); for (size_t subgraph_idx = 0; subgraph_idx < model->subgraphs.size();
subgraph_idx++) { subgraph_idx++) {
SubGraphT* subgraph = model->subgraphs.at(subgraph_idx).get(); SubGraphT* subgraph = model->subgraphs.at(subgraph_idx).get();
@ -815,14 +854,16 @@ TfLiteStatus QuantizeWeightsInputOutput(
for (const std::pair<int, operator_property::TensorProperty>& input : for (const std::pair<int, operator_property::TensorProperty>& input :
GetInputs(op, property)) { GetInputs(op, property)) {
TF_LITE_ENSURE_STATUS(QuantizeOpInput(model, subgraph_idx, &op_idx, TF_LITE_ENSURE_STATUS(QuantizeOpInput(model, subgraph_idx, &op_idx,
property, input, error_reporter)); property, input, activations_type,
error_reporter));
} }
// Quantize operator outputs. // Quantize operator outputs.
for (const std::pair<int, operator_property::TensorProperty>& output : for (const std::pair<int, operator_property::TensorProperty>& output :
GetOutputs(op, property)) { GetOutputs(op, property)) {
TF_LITE_ENSURE_STATUS(QuantizeOpOutput( TF_LITE_ENSURE_STATUS(
model, subgraph_idx, op_idx, property, output, error_reporter)); QuantizeOpOutput(model, subgraph_idx, op_idx, property, output,
activations_type, error_reporter));
} }
} }
} }
@ -832,6 +873,7 @@ TfLiteStatus QuantizeWeightsInputOutput(
// Quantize bias. // Quantize bias.
TfLiteStatus QuantizeBiases(ModelT* model, TfLiteStatus QuantizeBiases(ModelT* model,
const std::unordered_set<string>& operator_names, const std::unordered_set<string>& operator_names,
const TensorType& activations_type,
ErrorReporter* error_reporter) { ErrorReporter* error_reporter) {
for (size_t subgraph_idx = 0; subgraph_idx < model->subgraphs.size(); for (size_t subgraph_idx = 0; subgraph_idx < model->subgraphs.size();
subgraph_idx++) { subgraph_idx++) {
@ -877,10 +919,10 @@ TfLiteStatus QuantizeBiases(ModelT* model,
subgraph->tensors[op->inputs[property.inputs[1].first]].get(); subgraph->tensors[op->inputs[property.inputs[1].first]].get();
operator_property::TensorProperty weight_property = operator_property::TensorProperty weight_property =
property.inputs[1].second; property.inputs[1].second;
TF_LITE_ENSURE_STATUS( TF_LITE_ENSURE_STATUS(QuantizeBias(
QuantizeBias(model, input_tensor, weight_tensor, bias_tensor, model, input_tensor, weight_tensor, bias_tensor,
weight_property.per_axis, weight_property.per_axis, weight_property.per_axis_index,
weight_property.per_axis_index, error_reporter)); activations_type, error_reporter));
} }
} }
} }
@ -1000,7 +1042,7 @@ TfLiteStatus FillQuantizationParams(
// Check compatibility of activation, weight and bias scales. Adjust if needed. // Check compatibility of activation, weight and bias scales. Adjust if needed.
TfLiteStatus EnsureBiasScaleCompatibility( TfLiteStatus EnsureBiasScaleCompatibility(
ModelT* model, const std::unordered_set<string>& operator_names, ModelT* model, const std::unordered_set<string>& operator_names,
ErrorReporter* error_reporter) { TensorType activations_type, ErrorReporter* error_reporter) {
for (size_t subgraph_idx = 0; subgraph_idx < model->subgraphs.size(); for (size_t subgraph_idx = 0; subgraph_idx < model->subgraphs.size();
subgraph_idx++) { subgraph_idx++) {
SubGraphT* subgraph = model->subgraphs.at(subgraph_idx).get(); SubGraphT* subgraph = model->subgraphs.at(subgraph_idx).get();
@ -1049,11 +1091,9 @@ TfLiteStatus EnsureBiasScaleCompatibility(
// Get input scale for assymmetric quantization. // Get input scale for assymmetric quantization.
QuantizationParametersT temp_quant_params = QuantizationParametersT(); QuantizationParametersT temp_quant_params = QuantizationParametersT();
utils::GetAsymmetricQuantizationParams( TF_LITE_ENSURE_STATUS(
input_tensor->quantization->min[0], utils::GetQuantizationParams(input_tensor, activations_type,
input_tensor->quantization->max[0], &temp_quant_params, error_reporter));
std::numeric_limits<int8_t>::min(),
std::numeric_limits<int8_t>::max(), &temp_quant_params);
if (temp_quant_params.scale.size() != 1) { if (temp_quant_params.scale.size() != 1) {
error_reporter->Report("Unexpected input quantization scale size."); error_reporter->Report("Unexpected input quantization scale size.");
return kTfLiteError; return kTfLiteError;
@ -1132,21 +1172,24 @@ TfLiteStatus QuantizeModel(flatbuffers::FlatBufferBuilder* builder,
ModelT* model, const TensorType& input_type, ModelT* model, const TensorType& input_type,
const TensorType& output_type, bool allow_float, const TensorType& output_type, bool allow_float,
const std::unordered_set<string>& operator_names, const std::unordered_set<string>& operator_names,
const TensorType& activations_type,
ErrorReporter* error_reporter) { ErrorReporter* error_reporter) {
TF_LITE_ENSURE_STATUS( TF_LITE_ENSURE_STATUS(
FillQuantizationParams(model, operator_names, error_reporter)); FillQuantizationParams(model, operator_names, error_reporter));
TF_LITE_ENSURE_STATUS(EnsureBiasScaleCompatibility(
model, operator_names, activations_type, error_reporter));
TF_LITE_ENSURE_STATUS( TF_LITE_ENSURE_STATUS(
EnsureBiasScaleCompatibility(model, operator_names, error_reporter)); QuantizeIntemediateTensors(model, activations_type, error_reporter));
TF_LITE_ENSURE_STATUS(QuantizeIntemediateTensors(model, error_reporter));
TF_LITE_ENSURE_STATUS(QuantizeSharedRange(model, error_reporter)); TF_LITE_ENSURE_STATUS(QuantizeSharedRange(model, error_reporter));
TF_LITE_ENSURE_STATUS(QuantizeWeightsInputOutput( TF_LITE_ENSURE_STATUS(QuantizeWeightsInputOutput(
model, allow_float, operator_names, error_reporter)); model, allow_float, operator_names, activations_type, error_reporter));
TF_LITE_ENSURE_STATUS(ApplyConstraints(model, operator_names,
activations_type, error_reporter));
TF_LITE_ENSURE_STATUS( TF_LITE_ENSURE_STATUS(
ApplyConstraints(model, operator_names, error_reporter)); QuantizeBiases(model, operator_names, activations_type, error_reporter));
TF_LITE_ENSURE_STATUS(QuantizeBiases(model, operator_names, error_reporter));
utils::SetOperatorCodeVersion(model); utils::SetOperatorCodeVersion(model);
TF_LITE_ENSURE_STATUS( TF_LITE_ENSURE_STATUS(SetInputAndOutputTypes(
SetInputAndOutputTypes(model, input_type, output_type, error_reporter)); model, input_type, output_type, activations_type, error_reporter));
flatbuffers::Offset<Model> output_model_location = flatbuffers::Offset<Model> output_model_location =
Model::Pack(*builder, model); Model::Pack(*builder, model);
@ -1158,23 +1201,27 @@ TfLiteStatus QuantizeModel(flatbuffers::FlatBufferBuilder* builder,
TfLiteStatus QuantizeModel(flatbuffers::FlatBufferBuilder* builder, TfLiteStatus QuantizeModel(flatbuffers::FlatBufferBuilder* builder,
ModelT* model, const TensorType& input_type, ModelT* model, const TensorType& input_type,
const TensorType& output_type, bool allow_float, const TensorType& output_type, bool allow_float,
const TensorType& activations_type,
ErrorReporter* error_reporter) { ErrorReporter* error_reporter) {
return QuantizeModel(builder, model, input_type, output_type, allow_float, return QuantizeModel(builder, model, input_type, output_type, allow_float,
GetAllOperatorOutputs(model), error_reporter); GetAllOperatorOutputs(model), activations_type,
error_reporter);
} }
TfLiteStatus QuantizeModel(flatbuffers::FlatBufferBuilder* builder, TfLiteStatus QuantizeModel(flatbuffers::FlatBufferBuilder* builder,
ModelT* model, const TensorType& input_type, ModelT* model, const TensorType& input_type,
const TensorType& output_type, const TensorType& output_type,
const TensorType& activations_type,
ErrorReporter* error_reporter) { ErrorReporter* error_reporter) {
return QuantizeModel(builder, model, input_type, output_type, return QuantizeModel(builder, model, input_type, output_type,
/*allow_float=*/false, error_reporter); /*allow_float=*/false, activations_type, error_reporter);
} }
TfLiteStatus QuantizeModel(flatbuffers::FlatBufferBuilder* builder, TfLiteStatus QuantizeModel(flatbuffers::FlatBufferBuilder* builder,
ModelT* model, ErrorReporter* error_reporter) { ModelT* model, const TensorType& activations_type,
ErrorReporter* error_reporter) {
return QuantizeModel(builder, model, TensorType_FLOAT32, TensorType_FLOAT32, return QuantizeModel(builder, model, TensorType_FLOAT32, TensorType_FLOAT32,
/*allow_float=*/false, error_reporter); /*allow_float=*/false, activations_type, error_reporter);
} }
} // namespace optimize } // namespace optimize

View File

@ -35,7 +35,9 @@ namespace optimize {
// //
// Note: This is a private API, subject to change. // Note: This is a private API, subject to change.
TfLiteStatus QuantizeModel(flatbuffers::FlatBufferBuilder* builder, TfLiteStatus QuantizeModel(flatbuffers::FlatBufferBuilder* builder,
ModelT* input_model, ErrorReporter* error_reporter); ModelT* input_model,
const TensorType& activations_type,
ErrorReporter* error_reporter);
// Same as above, but the types of quantized inputs and outputs are // Same as above, but the types of quantized inputs and outputs are
// configurable. // configurable.
@ -44,6 +46,7 @@ TfLiteStatus QuantizeModel(flatbuffers::FlatBufferBuilder* builder,
TfLiteStatus QuantizeModel(flatbuffers::FlatBufferBuilder* builder, TfLiteStatus QuantizeModel(flatbuffers::FlatBufferBuilder* builder,
ModelT* input_model, const TensorType& input_type, ModelT* input_model, const TensorType& input_type,
const TensorType& output_type, const TensorType& output_type,
const TensorType& activations_type,
ErrorReporter* error_reporter); ErrorReporter* error_reporter);
// Same as above, but can enable allowing float intermediate operations for ops // Same as above, but can enable allowing float intermediate operations for ops
@ -53,6 +56,7 @@ TfLiteStatus QuantizeModel(flatbuffers::FlatBufferBuilder* builder,
TfLiteStatus QuantizeModel(flatbuffers::FlatBufferBuilder* builder, TfLiteStatus QuantizeModel(flatbuffers::FlatBufferBuilder* builder,
ModelT* input_model, const TensorType& input_type, ModelT* input_model, const TensorType& input_type,
const TensorType& output_type, bool allow_float, const TensorType& output_type, bool allow_float,
const TensorType& activations_type,
ErrorReporter* error_reporter); ErrorReporter* error_reporter);
// Same as above, but enables only quantizing a whitelist of operations, // Same as above, but enables only quantizing a whitelist of operations,
@ -63,6 +67,7 @@ TfLiteStatus QuantizeModel(flatbuffers::FlatBufferBuilder* builder,
ModelT* input_model, const TensorType& input_type, ModelT* input_model, const TensorType& input_type,
const TensorType& output_type, bool allow_float, const TensorType& output_type, bool allow_float,
const std::unordered_set<string>& operator_names, const std::unordered_set<string>& operator_names,
const TensorType& activations_type,
ErrorReporter* error_reporter); ErrorReporter* error_reporter);
} // namespace optimize } // namespace optimize

View File

@ -80,28 +80,35 @@ class QuantizeModelTest : public testing::Test {
internal::FailOnErrorReporter error_reporter_; internal::FailOnErrorReporter error_reporter_;
}; };
class QuantizeConvModelTest : public QuantizeModelTest { class QuantizeConvModelTest : public QuantizeModelTest,
public testing::WithParamInterface<TensorType> {
protected: protected:
QuantizeConvModelTest() { QuantizeConvModelTest() {
tensor_type_ = GetParam();
input_model_ = ReadModel(internal::kConvModelWith0Plus10Weights); input_model_ = ReadModel(internal::kConvModelWith0Plus10Weights);
readonly_model_ = input_model_->GetModel(); readonly_model_ = input_model_->GetModel();
readonly_model_->UnPackTo(&model_); readonly_model_->UnPackTo(&model_);
} }
TensorType tensor_type_;
}; };
TEST_F(QuantizeConvModelTest, QuantizationSucceeds) { INSTANTIATE_TEST_SUITE_P(QuantizeConvModelTestInst, QuantizeConvModelTest,
auto status = QuantizeModel(&builder_, &model_, TensorType_INT8, testing::ValuesIn({TensorType_INT8,
TensorType_INT8, &error_reporter_); TensorType_INT16}));
TEST_P(QuantizeConvModelTest, QuantizationSucceeds) {
auto status = QuantizeModel(&builder_, &model_, tensor_type_, tensor_type_,
tensor_type_, &error_reporter_);
EXPECT_EQ(status, kTfLiteOk); EXPECT_EQ(status, kTfLiteOk);
const uint8_t* buffer = builder_.GetBufferPointer(); const uint8_t* buffer = builder_.GetBufferPointer();
const Model* output_model = GetModel(buffer); const Model* output_model = GetModel(buffer);
ASSERT_TRUE(output_model); ASSERT_TRUE(output_model);
} }
TEST_F(QuantizeConvModelTest, SkipUnspecifiedLayer) { TEST_P(QuantizeConvModelTest, SkipUnspecifiedLayer) {
auto status = auto status =
QuantizeModel(&builder_, &model_, TensorType_FLOAT32, TensorType_FLOAT32, QuantizeModel(&builder_, &model_, TensorType_FLOAT32, TensorType_FLOAT32,
/*allow_float=*/true, {}, &error_reporter_); /*allow_float=*/true, {}, tensor_type_, &error_reporter_);
EXPECT_EQ(status, kTfLiteOk); EXPECT_EQ(status, kTfLiteOk);
ASSERT_EQ(model_.subgraphs.size(), readonly_model_->subgraphs()->size()); ASSERT_EQ(model_.subgraphs.size(), readonly_model_->subgraphs()->size());
// The resulting model should be the same. // The resulting model should be the same.
@ -123,9 +130,9 @@ TEST_F(QuantizeConvModelTest, SkipUnspecifiedLayer) {
} }
} }
TEST_F(QuantizeConvModelTest, TensorShapesAndStructureIsUnchanged) { TEST_P(QuantizeConvModelTest, TensorShapesAndStructureIsUnchanged) {
auto status = QuantizeModel(&builder_, &model_, TensorType_INT8, auto status = QuantizeModel(&builder_, &model_, tensor_type_, tensor_type_,
TensorType_INT8, &error_reporter_); tensor_type_, &error_reporter_);
EXPECT_EQ(status, kTfLiteOk); EXPECT_EQ(status, kTfLiteOk);
ASSERT_EQ(model_.subgraphs.size(), readonly_model_->subgraphs()->size()); ASSERT_EQ(model_.subgraphs.size(), readonly_model_->subgraphs()->size());
for (size_t subgraph_idx = 0; subgraph_idx < model_.subgraphs.size(); for (size_t subgraph_idx = 0; subgraph_idx < model_.subgraphs.size();
@ -148,9 +155,9 @@ TEST_F(QuantizeConvModelTest, TensorShapesAndStructureIsUnchanged) {
EXPECT_EQ(model_.operator_codes[0]->version, 3); EXPECT_EQ(model_.operator_codes[0]->version, 3);
} }
TEST_F(QuantizeConvModelTest, OperatorsAreUnchanged) { TEST_P(QuantizeConvModelTest, OperatorsAreUnchanged) {
auto status = QuantizeModel(&builder_, &model_, TensorType_INT8, auto status = QuantizeModel(&builder_, &model_, tensor_type_, tensor_type_,
TensorType_INT8, &error_reporter_); tensor_type_, &error_reporter_);
EXPECT_EQ(status, kTfLiteOk); EXPECT_EQ(status, kTfLiteOk);
ASSERT_EQ(model_.operator_codes.size(), ASSERT_EQ(model_.operator_codes.size(),
readonly_model_->operator_codes()->size()); readonly_model_->operator_codes()->size());
@ -182,20 +189,28 @@ TEST_F(QuantizeConvModelTest, OperatorsAreUnchanged) {
} }
} }
TEST_F(QuantizeConvModelTest, GraphIsFullyQuantized) { TEST_P(QuantizeConvModelTest, GraphIsFullyQuantized) {
auto status = QuantizeModel(&builder_, &model_, TensorType_INT8, auto status = QuantizeModel(&builder_, &model_, tensor_type_, tensor_type_,
TensorType_INT8, &error_reporter_); tensor_type_, &error_reporter_);
EXPECT_EQ(status, kTfLiteOk); EXPECT_EQ(status, kTfLiteOk);
for (const auto& subgraph : model_.subgraphs) { for (const auto& subgraph : model_.subgraphs) {
for (const auto& tensor : subgraph->tensors) { for (const auto& tensor : subgraph->tensors) {
if (tensor_type_ == TensorType_INT8) {
EXPECT_TRUE(tensor->type == TensorType_INT32 || EXPECT_TRUE(tensor->type == TensorType_INT32 ||
tensor->type == TensorType_INT8); tensor->type == TensorType_INT8);
} else if (tensor_type_ == TensorType_INT16) {
EXPECT_TRUE(tensor->type == TensorType_INT64 || // bias
tensor->type == TensorType_INT8 || // weights
tensor->type == TensorType_INT16); // activations
}
} }
} }
} }
TEST_F(QuantizeConvModelTest, FloatInputAndOutput) { TEST_P(QuantizeConvModelTest, FloatInputAndOutput) {
auto status = QuantizeModel(&builder_, &model_, &error_reporter_); auto status =
QuantizeModel(&builder_, &model_, TensorType_FLOAT32, TensorType_FLOAT32,
tensor_type_, &error_reporter_);
EXPECT_EQ(status, kTfLiteOk); EXPECT_EQ(status, kTfLiteOk);
for (int32_t subgraph_idx = 0; subgraph_idx < model_.subgraphs.size(); for (int32_t subgraph_idx = 0; subgraph_idx < model_.subgraphs.size();
@ -234,22 +249,33 @@ TEST_F(QuantizeConvModelTest, FloatInputAndOutput) {
EXPECT_EQ(subgraph->tensors[output_idx]->type, TensorType_FLOAT32); EXPECT_EQ(subgraph->tensors[output_idx]->type, TensorType_FLOAT32);
EXPECT_EQ(subgraph->tensors[output_idx]->name, "output"); EXPECT_EQ(subgraph->tensors[output_idx]->name, "output");
// The original input and output has been renamed. // The original input and output has been renamed.
EXPECT_EQ(subgraph->tensors[quant_op->outputs[0]]->name, "input_int8"); std::string control_suffix =
EXPECT_EQ(subgraph->tensors[dequant_op->inputs[0]]->name, "output_int8"); (tensor_type_ == TensorType_INT16) ? "int16" : "int8";
EXPECT_EQ(subgraph->tensors[quant_op->outputs[0]]->name,
"input_" + control_suffix);
EXPECT_EQ(subgraph->tensors[dequant_op->inputs[0]]->name,
"output_" + control_suffix);
for (int tensor_idx = 0; tensor_idx < subgraph->tensors.size(); for (int tensor_idx = 0; tensor_idx < subgraph->tensors.size();
++tensor_idx) { ++tensor_idx) {
const auto& tensor = subgraph->tensors[tensor_idx]; const auto& tensor = subgraph->tensors[tensor_idx];
if (input_idx != tensor_idx && output_idx != tensor_idx) { if (input_idx != tensor_idx && output_idx != tensor_idx) {
if (tensor_type_ == TensorType_INT8) {
EXPECT_TRUE(tensor->type == TensorType_INT32 || EXPECT_TRUE(tensor->type == TensorType_INT32 ||
tensor->type == TensorType_INT8); tensor->type == TensorType_INT8);
} else if (tensor_type_ == TensorType_INT16) {
EXPECT_TRUE(tensor->type == TensorType_INT64 || // bias
tensor->type == TensorType_INT8 || // weights
tensor->type == TensorType_INT16); // activations
}
} }
} }
} }
} }
TEST_F(QuantizeConvModelTest, Uint8InputAndOutput) { TEST_P(QuantizeConvModelTest, Uint8InputAndOutput) {
auto status = QuantizeModel(&builder_, &model_, TensorType_UINT8, auto status =
TensorType_UINT8, &error_reporter_); QuantizeModel(&builder_, &model_, TensorType_UINT8, TensorType_UINT8,
TensorType_INT8, &error_reporter_);
EXPECT_EQ(status, kTfLiteOk); EXPECT_EQ(status, kTfLiteOk);
for (int32_t subgraph_idx = 0; subgraph_idx < model_.subgraphs.size(); for (int32_t subgraph_idx = 0; subgraph_idx < model_.subgraphs.size();
@ -326,7 +352,8 @@ class QuantizeConvNoBiasModelTest : public QuantizeModelTest {
}; };
TEST_F(QuantizeConvNoBiasModelTest, QuantizationSucceeds) { TEST_F(QuantizeConvNoBiasModelTest, QuantizationSucceeds) {
auto status = QuantizeModel(&builder_, &model_, TensorType_INT8, auto status =
QuantizeModel(&builder_, &model_, TensorType_INT8, TensorType_INT8,
TensorType_INT8, &error_reporter_); TensorType_INT8, &error_reporter_);
EXPECT_EQ(status, kTfLiteOk); EXPECT_EQ(status, kTfLiteOk);
const uint8_t* buffer = builder_.GetBufferPointer(); const uint8_t* buffer = builder_.GetBufferPointer();
@ -334,13 +361,16 @@ TEST_F(QuantizeConvNoBiasModelTest, QuantizationSucceeds) {
ASSERT_TRUE(output_model); ASSERT_TRUE(output_model);
} }
class QuantizeConcatModelTest : public QuantizeModelTest { class QuantizeConcatModelTest : public QuantizeModelTest,
public testing::WithParamInterface<TensorType> {
protected: protected:
QuantizeConcatModelTest() { QuantizeConcatModelTest() {
input_model_ = ReadModel(internal::kFloatConcatMax5Max10Max10); input_model_ = ReadModel(internal::kFloatConcatMax5Max10Max10);
readonly_model_ = input_model_->GetModel(); readonly_model_ = input_model_->GetModel();
readonly_model_->UnPackTo(&model_); readonly_model_->UnPackTo(&model_);
} }
TensorType tensor_type_;
}; };
// There are two inputs for concat, "input0" and "input1". "input0" has [0, 5] // There are two inputs for concat, "input0" and "input1". "input0" has [0, 5]
@ -352,9 +382,9 @@ class QuantizeConcatModelTest : public QuantizeModelTest {
// input0 -> requant -> input0_requant \ // input0 -> requant -> input0_requant \
// concat - output // concat - output
// input1 / // input1 /
TEST_F(QuantizeConcatModelTest, AddRequantBeforeConcat) { TEST_P(QuantizeConcatModelTest, AddRequantBeforeConcat) {
auto status = QuantizeModel(&builder_, &model_, TensorType_INT8, auto status = QuantizeModel(&builder_, &model_, tensor_type_, tensor_type_,
TensorType_INT8, &error_reporter_); tensor_type_, &error_reporter_);
EXPECT_EQ(status, kTfLiteOk); EXPECT_EQ(status, kTfLiteOk);
// There is only one subgraph. // There is only one subgraph.
@ -373,32 +403,51 @@ TEST_F(QuantizeConcatModelTest, AddRequantBeforeConcat) {
EXPECT_EQ(model_.operator_codes[concat->opcode_index]->builtin_code, EXPECT_EQ(model_.operator_codes[concat->opcode_index]->builtin_code,
BuiltinOperator_CONCATENATION); BuiltinOperator_CONCATENATION);
auto zero_point_control = tensor_type_ == TensorType_INT8 ? -128 : 0;
/*
input0_scale_control
INT8: (5-0) / (2^8 - 1)
INT16: (5-0) / (2^16 / 2 - 1)
input1_scale
INT8: (10-0) / (2^8 - 1)
INT16: (10-0) / (2^16 / 2 - 1)
*/
auto input0_scale_control =
tensor_type_ == TensorType_INT8 ? 0.019607844 : 0.00015259254;
auto input1_scale =
tensor_type_ == TensorType_INT8 ? 0.039215688 : 0.00030518509;
// There should be 4 tensors: input0, input1, input0_requantized, output. // There should be 4 tensors: input0, input1, input0_requantized, output.
EXPECT_EQ(subgraph->tensors.size(), 4); EXPECT_EQ(subgraph->tensors.size(), 4);
EXPECT_EQ(subgraph->tensors[0]->type, TensorType_INT8); EXPECT_EQ(subgraph->tensors[0]->type, tensor_type_);
EXPECT_EQ(subgraph->tensors[0]->name, "input0"); EXPECT_EQ(subgraph->tensors[0]->name, "input0");
EXPECT_EQ(subgraph->tensors[0]->quantization->scale.size(), 1); EXPECT_EQ(subgraph->tensors[0]->quantization->scale.size(), 1);
EXPECT_EQ(subgraph->tensors[0]->quantization->zero_point.size(), 1); EXPECT_EQ(subgraph->tensors[0]->quantization->zero_point.size(), 1);
EXPECT_FLOAT_EQ(subgraph->tensors[0]->quantization->scale[0], 0.019607844); EXPECT_FLOAT_EQ(subgraph->tensors[0]->quantization->scale[0],
EXPECT_FLOAT_EQ(subgraph->tensors[0]->quantization->zero_point[0], -128); input0_scale_control);
EXPECT_EQ(subgraph->tensors[1]->type, TensorType_INT8); EXPECT_FLOAT_EQ(subgraph->tensors[0]->quantization->zero_point[0],
zero_point_control);
EXPECT_EQ(subgraph->tensors[1]->type, tensor_type_);
EXPECT_EQ(subgraph->tensors[1]->name, "input1"); EXPECT_EQ(subgraph->tensors[1]->name, "input1");
EXPECT_EQ(subgraph->tensors[1]->quantization->scale.size(), 1); EXPECT_EQ(subgraph->tensors[1]->quantization->scale.size(), 1);
EXPECT_EQ(subgraph->tensors[1]->quantization->zero_point.size(), 1); EXPECT_EQ(subgraph->tensors[1]->quantization->zero_point.size(), 1);
EXPECT_FLOAT_EQ(subgraph->tensors[1]->quantization->scale[0], 0.039215688); EXPECT_FLOAT_EQ(subgraph->tensors[1]->quantization->scale[0], input1_scale);
EXPECT_FLOAT_EQ(subgraph->tensors[1]->quantization->zero_point[0], -128); EXPECT_FLOAT_EQ(subgraph->tensors[1]->quantization->zero_point[0],
EXPECT_EQ(subgraph->tensors[2]->type, TensorType_INT8); zero_point_control);
EXPECT_EQ(subgraph->tensors[2]->type, tensor_type_);
EXPECT_EQ(subgraph->tensors[2]->name, "output"); EXPECT_EQ(subgraph->tensors[2]->name, "output");
EXPECT_EQ(subgraph->tensors[2]->quantization->scale.size(), 1); EXPECT_EQ(subgraph->tensors[2]->quantization->scale.size(), 1);
EXPECT_EQ(subgraph->tensors[2]->quantization->zero_point.size(), 1); EXPECT_EQ(subgraph->tensors[2]->quantization->zero_point.size(), 1);
EXPECT_FLOAT_EQ(subgraph->tensors[2]->quantization->scale[0], 0.039215688); EXPECT_FLOAT_EQ(subgraph->tensors[2]->quantization->scale[0], input1_scale);
EXPECT_FLOAT_EQ(subgraph->tensors[2]->quantization->zero_point[0], -128); EXPECT_FLOAT_EQ(subgraph->tensors[2]->quantization->zero_point[0],
EXPECT_EQ(subgraph->tensors[3]->type, TensorType_INT8); zero_point_control);
EXPECT_EQ(subgraph->tensors[3]->type, tensor_type_);
EXPECT_EQ(subgraph->tensors[3]->name, "input0_requantized"); EXPECT_EQ(subgraph->tensors[3]->name, "input0_requantized");
EXPECT_EQ(subgraph->tensors[3]->quantization->scale.size(), 1); EXPECT_EQ(subgraph->tensors[3]->quantization->scale.size(), 1);
EXPECT_EQ(subgraph->tensors[3]->quantization->zero_point.size(), 1); EXPECT_EQ(subgraph->tensors[3]->quantization->zero_point.size(), 1);
EXPECT_FLOAT_EQ(subgraph->tensors[3]->quantization->scale[0], 0.039215688); EXPECT_FLOAT_EQ(subgraph->tensors[3]->quantization->scale[0], input1_scale);
EXPECT_FLOAT_EQ(subgraph->tensors[3]->quantization->zero_point[0], -128); EXPECT_FLOAT_EQ(subgraph->tensors[3]->quantization->zero_point[0],
zero_point_control);
// The connection should be what is described in the comment. // The connection should be what is described in the comment.
EXPECT_EQ(requant->inputs.size(), 1); EXPECT_EQ(requant->inputs.size(), 1);
@ -419,7 +468,9 @@ TEST_F(QuantizeConcatModelTest, AddRequantBeforeConcat) {
EXPECT_EQ(model_.operator_codes[1]->builtin_code, BuiltinOperator_QUANTIZE); EXPECT_EQ(model_.operator_codes[1]->builtin_code, BuiltinOperator_QUANTIZE);
EXPECT_EQ(model_.operator_codes[1]->version, 2); EXPECT_EQ(model_.operator_codes[1]->version, 2);
} }
INSTANTIATE_TEST_SUITE_P(QuantizeConcatModelInst, QuantizeConcatModelTest,
testing::ValuesIn({TensorType_INT8,
TensorType_INT16}));
class QuantizeSplitModelTest : public QuantizeModelTest { class QuantizeSplitModelTest : public QuantizeModelTest {
protected: protected:
QuantizeSplitModelTest() { QuantizeSplitModelTest() {
@ -432,7 +483,8 @@ class QuantizeSplitModelTest : public QuantizeModelTest {
// There are two outputs for split with different scales, the resulting model // There are two outputs for split with different scales, the resulting model
// should have the scales be hardcodes to the input scale value. // should have the scales be hardcodes to the input scale value.
TEST_F(QuantizeSplitModelTest, QuantizeSplit) { TEST_F(QuantizeSplitModelTest, QuantizeSplit) {
auto status = QuantizeModel(&builder_, &model_, TensorType_INT8, auto status =
QuantizeModel(&builder_, &model_, TensorType_INT8, TensorType_INT8,
TensorType_INT8, &error_reporter_); TensorType_INT8, &error_reporter_);
EXPECT_EQ(status, kTfLiteOk); EXPECT_EQ(status, kTfLiteOk);
@ -496,7 +548,8 @@ class QuantizeConvModel1Test : public QuantizeModelTest {
}; };
TEST_F(QuantizeConvModel1Test, VerifyConvQuantizationWithUnitScale) { TEST_F(QuantizeConvModel1Test, VerifyConvQuantizationWithUnitScale) {
auto status = QuantizeModel(&builder_, &model_, TensorType_INT8, auto status =
QuantizeModel(&builder_, &model_, TensorType_INT8, TensorType_INT8,
TensorType_INT8, &error_reporter_); TensorType_INT8, &error_reporter_);
EXPECT_EQ(status, kTfLiteOk); EXPECT_EQ(status, kTfLiteOk);
const auto& subgraph = model_.subgraphs[0]; const auto& subgraph = model_.subgraphs[0];
@ -587,18 +640,25 @@ TEST_F(QuantizeConvModel1Test, VerifyConvQuantizationWithUnitScale) {
EXPECT_EQ(model_.operator_codes[0]->version, 3); EXPECT_EQ(model_.operator_codes[0]->version, 3);
} }
class QuantizeConvModel2Test : public QuantizeModelTest { class QuantizeConvModel2Test : public QuantizeModelTest,
public testing::WithParamInterface<TensorType> {
protected: protected:
QuantizeConvModel2Test() { QuantizeConvModel2Test() {
tensor_type_ = GetParam();
input_model_ = ReadModel(internal::kConvModelWith0Plus10Weights); input_model_ = ReadModel(internal::kConvModelWith0Plus10Weights);
readonly_model_ = input_model_->GetModel(); readonly_model_ = input_model_->GetModel();
readonly_model_->UnPackTo(&model_); readonly_model_->UnPackTo(&model_);
} }
};
TEST_F(QuantizeConvModel2Test, VerifyConvQuantization) { TensorType tensor_type_;
auto status = QuantizeModel(&builder_, &model_, TensorType_INT8, };
TensorType_INT8, &error_reporter_); INSTANTIATE_TEST_SUITE_P(QuantizeConvModel2TestInst, QuantizeConvModel2Test,
testing::ValuesIn({TensorType_INT8,
TensorType_INT16}));
TEST_P(QuantizeConvModel2Test, VerifyConvQuantization) {
auto status = QuantizeModel(&builder_, &model_, tensor_type_, tensor_type_,
tensor_type_, &error_reporter_);
ASSERT_EQ(kTfLiteOk, status); ASSERT_EQ(kTfLiteOk, status);
const auto& subgraph = model_.subgraphs[0]; const auto& subgraph = model_.subgraphs[0];
auto conv_op = subgraph->operators[0].get(); auto conv_op = subgraph->operators[0].get();
@ -615,8 +675,10 @@ TEST_F(QuantizeConvModel2Test, VerifyConvQuantization) {
const auto output_tensor = const auto output_tensor =
subgraph->tensors[conv_op->outputs[output_tensor_idx]].get(); subgraph->tensors[conv_op->outputs[output_tensor_idx]].get();
EXPECT_EQ(bias_tensor->type, TensorType_INT32); EXPECT_EQ(bias_tensor->type, tensor_type_ == TensorType_INT8
EXPECT_EQ(input_tensor->type, TensorType_INT8); ? TensorType_INT32
: TensorType_INT64);
EXPECT_EQ(input_tensor->type, tensor_type_);
EXPECT_EQ(weights_tensor->type, TensorType_INT8); EXPECT_EQ(weights_tensor->type, TensorType_INT8);
ASSERT_TRUE(weights_tensor->quantization); ASSERT_TRUE(weights_tensor->quantization);
@ -644,18 +706,29 @@ TEST_F(QuantizeConvModel2Test, VerifyConvQuantization) {
} }
const auto bias_buffer = model_.buffers[bias_tensor->buffer].get(); const auto bias_buffer = model_.buffers[bias_tensor->buffer].get();
ASSERT_EQ(bias_buffer->data.size(), sizeof(int32_t) * bias_tensor->shape[0]); auto control_size = tensor_type_ == TensorType_INT8
const int32_t* bias_values = ? sizeof(int32_t) * bias_tensor->shape[0]
reinterpret_cast<int32_t*>(bias_buffer->data.data()); : sizeof(int64_t) * bias_tensor->shape[0];
ASSERT_EQ(bias_buffer->data.size(), control_size);
const auto original_bias_buffer = const auto original_bias_buffer =
readonly_model_->buffers()->Get(bias_tensor->buffer); readonly_model_->buffers()->Get(bias_tensor->buffer);
const float* bias_float_buffer = const float* bias_float_buffer =
reinterpret_cast<const float*>(original_bias_buffer->data()->data()); reinterpret_cast<const float*>(original_bias_buffer->data()->data());
if (tensor_type_ == TensorType_INT8) {
int32_t* bias_values = reinterpret_cast<int32_t*>(bias_buffer->data.data());
for (size_t i = 0; i < out_channel_size; i++) { for (size_t i = 0; i < out_channel_size; i++) {
auto dequantized_value = bias_values[i] * bias_scales[i]; auto dequantized_value = bias_values[i] * bias_scales[i];
EXPECT_NEAR(dequantized_value, bias_float_buffer[i], bias_scales[i] / 2); EXPECT_NEAR(dequantized_value, bias_float_buffer[i], bias_scales[i] / 2);
} }
} else if (tensor_type_ == TensorType_INT16) {
int64_t* bias_values = reinterpret_cast<int64_t*>(bias_buffer->data.data());
for (size_t i = 0; i < out_channel_size; i++) {
auto dequantized_value = bias_values[i] * bias_scales[i];
EXPECT_NEAR(dequantized_value, bias_float_buffer[i], bias_scales[i] / 2);
}
}
const auto weights_buffer = model_.buffers[weights_tensor->buffer].get(); const auto weights_buffer = model_.buffers[weights_tensor->buffer].get();
const auto original_weights_buffer = const auto original_weights_buffer =
@ -695,7 +768,8 @@ class QuantizeSoftmaxTest : public QuantizeModelTest {
}; };
TEST_F(QuantizeSoftmaxTest, VerifySoftmaxQuantization) { TEST_F(QuantizeSoftmaxTest, VerifySoftmaxQuantization) {
auto status = QuantizeModel(&builder_, &model_, TensorType_INT8, auto status =
QuantizeModel(&builder_, &model_, TensorType_INT8, TensorType_INT8,
TensorType_INT8, &error_reporter_); TensorType_INT8, &error_reporter_);
ASSERT_EQ(kTfLiteOk, status); ASSERT_EQ(kTfLiteOk, status);
@ -755,7 +829,8 @@ class QuantizeAvgPoolTest : public QuantizeModelTest {
}; };
TEST_F(QuantizeAvgPoolTest, VerifyAvgPoolQuantization) { TEST_F(QuantizeAvgPoolTest, VerifyAvgPoolQuantization) {
auto status = QuantizeModel(&builder_, &model_, TensorType_INT8, auto status =
QuantizeModel(&builder_, &model_, TensorType_INT8, TensorType_INT8,
TensorType_INT8, &error_reporter_); TensorType_INT8, &error_reporter_);
ASSERT_EQ(kTfLiteOk, status); ASSERT_EQ(kTfLiteOk, status);
@ -816,7 +891,8 @@ class QuantizeMultiInputAddWithReshapeTest : public QuantizeModelTest {
}; };
TEST_F(QuantizeMultiInputAddWithReshapeTest, VerifyReshapeQuantization) { TEST_F(QuantizeMultiInputAddWithReshapeTest, VerifyReshapeQuantization) {
auto status = QuantizeModel(&builder_, &model_, TensorType_INT8, auto status =
QuantizeModel(&builder_, &model_, TensorType_INT8, TensorType_INT8,
TensorType_INT8, &error_reporter_); TensorType_INT8, &error_reporter_);
ASSERT_EQ(kTfLiteOk, status); ASSERT_EQ(kTfLiteOk, status);
@ -863,7 +939,8 @@ TEST_F(QuantizeMultiInputAddWithReshapeTest, VerifyReshapeQuantization) {
} }
TEST_F(QuantizeMultiInputAddWithReshapeTest, VerifyAddQuantization) { TEST_F(QuantizeMultiInputAddWithReshapeTest, VerifyAddQuantization) {
auto status = QuantizeModel(&builder_, &model_, TensorType_INT8, auto status =
QuantizeModel(&builder_, &model_, TensorType_INT8, TensorType_INT8,
TensorType_INT8, &error_reporter_); TensorType_INT8, &error_reporter_);
ASSERT_EQ(kTfLiteOk, status); ASSERT_EQ(kTfLiteOk, status);
@ -923,7 +1000,8 @@ class QuantizeConstInputTest : public QuantizeModelTest {
}; };
TEST_F(QuantizeConstInputTest, VerifyConstOpInput) { TEST_F(QuantizeConstInputTest, VerifyConstOpInput) {
auto status = QuantizeModel(&builder_, &model_, TensorType_INT8, auto status =
QuantizeModel(&builder_, &model_, TensorType_INT8, TensorType_INT8,
TensorType_INT8, &error_reporter_); TensorType_INT8, &error_reporter_);
ASSERT_EQ(kTfLiteOk, status); ASSERT_EQ(kTfLiteOk, status);
@ -965,7 +1043,8 @@ class QuantizeArgMaxTest : public QuantizeModelTest {
}; };
TEST_F(QuantizeArgMaxTest, VerifyArgMax) { TEST_F(QuantizeArgMaxTest, VerifyArgMax) {
auto status = QuantizeModel(&builder_, &model_, TensorType_INT8, auto status =
QuantizeModel(&builder_, &model_, TensorType_INT8, TensorType_INT8,
TensorType_INT8, &error_reporter_); TensorType_INT8, &error_reporter_);
ASSERT_EQ(kTfLiteOk, status); ASSERT_EQ(kTfLiteOk, status);
@ -1008,8 +1087,9 @@ class QuantizeLSTMTest : public QuantizeModelTest {
TEST_F(QuantizeLSTMTest, VerifyLSTM) { TEST_F(QuantizeLSTMTest, VerifyLSTM) {
// Quantize model. // Quantize model.
auto status = QuantizeModel(&builder_, &model_, TensorType_FLOAT32, auto status =
TensorType_FLOAT32, &error_reporter_); QuantizeModel(&builder_, &model_, TensorType_FLOAT32, TensorType_FLOAT32,
TensorType_INT8, &error_reporter_);
ASSERT_EQ(kTfLiteOk, status); ASSERT_EQ(kTfLiteOk, status);
// Read expected model. // Read expected model.
@ -1067,8 +1147,9 @@ class QuantizeLSTM2Test : public QuantizeModelTest {
TEST_F(QuantizeLSTM2Test, VerifyLSTM) { TEST_F(QuantizeLSTM2Test, VerifyLSTM) {
// Quantize model. // Quantize model.
auto status = QuantizeModel(&builder_, &model_, TensorType_FLOAT32, auto status =
TensorType_FLOAT32, &error_reporter_); QuantizeModel(&builder_, &model_, TensorType_FLOAT32, TensorType_FLOAT32,
TensorType_INT8, &error_reporter_);
ASSERT_EQ(kTfLiteOk, status); ASSERT_EQ(kTfLiteOk, status);
// Read expected model. // Read expected model.
@ -1126,7 +1207,8 @@ class QuantizeSVDFTest : public QuantizeModelTest {
TEST_F(QuantizeSVDFTest, VerifySVDF) { TEST_F(QuantizeSVDFTest, VerifySVDF) {
// Quantize model. // Quantize model.
auto status = QuantizeModel(&builder_, &model_, TensorType_INT8, auto status =
QuantizeModel(&builder_, &model_, TensorType_INT8, TensorType_INT8,
TensorType_INT8, &error_reporter_); TensorType_INT8, &error_reporter_);
ASSERT_EQ(kTfLiteOk, status); ASSERT_EQ(kTfLiteOk, status);
@ -1184,7 +1266,8 @@ class QuantizeFCTest : public QuantizeModelTest {
}; };
TEST_F(QuantizeFCTest, VerifyFC) { TEST_F(QuantizeFCTest, VerifyFC) {
auto status = QuantizeModel(&builder_, &model_, TensorType_INT8, auto status =
QuantizeModel(&builder_, &model_, TensorType_INT8, TensorType_INT8,
TensorType_INT8, &error_reporter_); TensorType_INT8, &error_reporter_);
ASSERT_EQ(kTfLiteOk, status); ASSERT_EQ(kTfLiteOk, status);
@ -1236,7 +1319,7 @@ class QuantizeCustomOpTest : public QuantizeModelTest {
TEST_F(QuantizeCustomOpTest, VerifyMixedQuantization) { TEST_F(QuantizeCustomOpTest, VerifyMixedQuantization) {
auto status = auto status =
QuantizeModel(&builder_, &model_, TensorType_INT8, TensorType_INT8, QuantizeModel(&builder_, &model_, TensorType_INT8, TensorType_INT8,
/*allow_float=*/true, &error_reporter_); /*allow_float=*/true, TensorType_INT8, &error_reporter_);
ASSERT_EQ(kTfLiteOk, status); ASSERT_EQ(kTfLiteOk, status);
const auto& subgraph = model_.subgraphs[0]; const auto& subgraph = model_.subgraphs[0];
auto float_graph = readonly_model_->subgraphs()->Get(0); auto float_graph = readonly_model_->subgraphs()->Get(0);
@ -1270,7 +1353,8 @@ class QuantizePackTest : public QuantizeModelTest {
}; };
TEST_F(QuantizePackTest, VerifyPack) { TEST_F(QuantizePackTest, VerifyPack) {
auto status = QuantizeModel(&builder_, &model_, &error_reporter_); auto status =
QuantizeModel(&builder_, &model_, TensorType_INT8, &error_reporter_);
ASSERT_EQ(kTfLiteOk, status); ASSERT_EQ(kTfLiteOk, status);
@ -1334,7 +1418,8 @@ class QuantizeMinimumMaximumTest
}; };
TEST_P(QuantizeMinimumMaximumTest, VerifyMinimumMaximum) { TEST_P(QuantizeMinimumMaximumTest, VerifyMinimumMaximum) {
auto status = QuantizeModel(&builder_, &model_, &error_reporter_); auto status =
QuantizeModel(&builder_, &model_, TensorType_INT8, &error_reporter_);
ASSERT_EQ(kTfLiteOk, status); ASSERT_EQ(kTfLiteOk, status);
const auto& subgraph = model_.subgraphs[0]; const auto& subgraph = model_.subgraphs[0];
@ -1415,7 +1500,8 @@ class QuantizeUnpackTest : public QuantizeModelTest {
} }
}; };
TEST_F(QuantizeUnpackTest, VerifyUnpack) { TEST_F(QuantizeUnpackTest, VerifyUnpack) {
auto status = QuantizeModel(&builder_, &model_, &error_reporter_); auto status =
QuantizeModel(&builder_, &model_, TensorType_INT8, &error_reporter_);
ASSERT_EQ(kTfLiteOk, status); ASSERT_EQ(kTfLiteOk, status);