From a7899d7544230fce8dae4895733d82623af2b934 Mon Sep 17 00:00:00 2001 From: Elena Zhelezina Date: Tue, 21 Jan 2020 13:18:55 +0000 Subject: [PATCH] Added an option TFLITE_BUILTINS_ACTIVATIONS_INT16_WEIGHTS_INT8 to enable sym quantization with activations in 16-bit and weigths in 8-bit. --- tensorflow/lite/python/convert.py | 6 + tensorflow/lite/python/lite.py | 13 +- tensorflow/lite/python/lite_constants.py | 3 + tensorflow/lite/python/lite_test.py | 14 +- .../python/optimize/calibration_wrapper.cc | 8 +- .../python/optimize/calibration_wrapper.h | 3 +- tensorflow/lite/python/optimize/calibrator.py | 6 +- .../lite/python/optimize/calibrator_test.py | 39 ++- .../lite/tools/optimize/operator_property.cc | 17 +- .../lite/tools/optimize/operator_property.h | 10 +- .../lite/tools/optimize/quantization_utils.cc | 102 +++++-- .../lite/tools/optimize/quantization_utils.h | 10 +- .../tools/optimize/quantization_utils_test.cc | 4 +- .../tools/optimize/quantization_wrapper.cc | 4 +- .../lite/tools/optimize/quantize_model.cc | 175 +++++++----- .../lite/tools/optimize/quantize_model.h | 7 +- .../tools/optimize/quantize_model_test.cc | 258 ++++++++++++------ 17 files changed, 477 insertions(+), 202 deletions(-) diff --git a/tensorflow/lite/python/convert.py b/tensorflow/lite/python/convert.py index 2fe4d172487..494f32a515c 100644 --- a/tensorflow/lite/python/convert.py +++ b/tensorflow/lite/python/convert.py @@ -93,6 +93,12 @@ class OpsSet(enum.Enum): # quantized implementations. TFLITE_BUILTINS_INT8 = "TFLITE_BUILTINS_INT8" + # Convert model using only TensorFlow Lite operations with quantized int8 weights + # and int16 activations. + # Specifying this will throw an error for operations that do not yet have + # quantized implementations. + TFLITE_BUILTINS_ACTIVATIONS_INT16_WEIGHTS_INT8 = "TFLITE_BUILTINS_ACTIVATIONS_INT16_WEIGHTS_INT8" + def __str__(self): return self.value diff --git a/tensorflow/lite/python/lite.py b/tensorflow/lite/python/lite.py index 657cfea1bb8..fc9c064faf0 100644 --- a/tensorflow/lite/python/lite.py +++ b/tensorflow/lite/python/lite.py @@ -224,6 +224,10 @@ class TFLiteConverterBase(object): self.target_spec.supported_ops) or self._smallest_supported_type() == constants.INT8) + def _is_int16x8_target_required(self): + return (set([OpsSet.TFLITE_BUILTINS_ACTIVATIONS_INT16_WEIGHTS_INT8]) == + set(self.target_spec.supported_ops)) + def _smallest_supported_type(self): if self.target_spec.supported_types: return min(self.target_spec.supported_types, key=lambda x: x.size) @@ -238,7 +242,9 @@ class TFLiteConverterBase(object): ])) def _is_post_training_optimize(self): - return self._is_int8_target_required() or self._any_optimization_enabled() + return self._is_int8_target_required() or \ + self._is_int16x8_target_required() or \ + self._any_optimization_enabled() def _is_int8_weight_only_quantize(self): return (self._is_post_training_optimize() and @@ -255,11 +261,12 @@ class TFLiteConverterBase(object): def _calibrate_quantize_model(self, result, inference_input_type, inference_output_type, enable_mlir_quantizer): - allow_float = not self._is_int8_target_required() + allow_float = not self._is_int8_target_required() and not self._is_int16x8_target_required() calibrate_quantize = _calibrator.Calibrator(result) + activations_type = constants.INT16 if self._is_int16x8_target_required() else constants.INT8 return calibrate_quantize.calibrate_and_quantize( self.representative_dataset.input_gen, inference_input_type, - inference_output_type, allow_float, enable_mlir_quantizer) + inference_output_type, allow_float, activations_type, enable_mlir_quantizer) def _get_base_converter_args(self): """Returns the base converter args. diff --git a/tensorflow/lite/python/lite_constants.py b/tensorflow/lite/python/lite_constants.py index d43452c775b..4902f23795e 100644 --- a/tensorflow/lite/python/lite_constants.py +++ b/tensorflow/lite/python/lite_constants.py @@ -30,6 +30,7 @@ INT64 = dtypes.int64 STRING = dtypes.string QUANTIZED_UINT8 = dtypes.uint8 INT8 = dtypes.int8 +INT16 = dtypes.int16 COMPLEX64 = dtypes.complex64 TENSORFLOW_GRAPHDEF = _toco_flags_pb2.TENSORFLOW_GRAPHDEF TFLITE = _toco_flags_pb2.TFLITE @@ -43,6 +44,7 @@ _tf_export(v1=["lite.constants.STRING"]).export_constant(__name__, "STRING") _tf_export(v1=["lite.constants.QUANTIZED_UINT8"]).export_constant( __name__, "QUANTIZED_UINT8") _tf_export(v1=["lite.constants.INT8"]).export_constant(__name__, "INT8") +_tf_export(v1=["lite.constants.INT16"]).export_constant(__name__, "INT16") _tf_export(v1=["lite.constants.TFLITE"]).export_constant(__name__, "TFLITE") _tf_export(v1=["lite.constants.GRAPHVIZ_DOT"]).export_constant( __name__, "GRAPHVIZ_DOT") @@ -62,6 +64,7 @@ _allowed_symbols = [ "STRING", "QUANTIZED_UINT8", "INT8", + "INT16", "COMPLEX64", "TENSORFLOW_GRAPHDEF", "TFLITE", diff --git a/tensorflow/lite/python/lite_test.py b/tensorflow/lite/python/lite_test.py index 16959c84146..ef5e5d1cdf4 100644 --- a/tensorflow/lite/python/lite_test.py +++ b/tensorflow/lite/python/lite_test.py @@ -769,9 +769,13 @@ class FromSessionTest(TestModels, parameterized.TestCase): self.assertLess(len(quantized_tflite), len(float_tflite)) @parameterized.named_parameters( - ('EnableMlirConverter', True), # enable mlir - ('DisableMlirConverter', False)) # disable mlir - def testCalibrateAndQuantizeBuiltinInt8(self, enable_mlir): + # Quantize model to Int8: with enable mlir + ('UseTfliteBuiltinsIntEnableMLIR', [lite.OpsSet.TFLITE_BUILTINS_INT8], True), + # Quantize model to Int8: with disable mlir + ('UseTfliteBuiltinsIntDisableMLIR', [lite.OpsSet.TFLITE_BUILTINS_INT8], False), + # Quantize model to Int16: with disable mlir + ('UseTfliteBuiltinsInt16DisableMLIR', [lite.OpsSet.TFLITE_BUILTINS_ACTIVATIONS_INT16_WEIGHTS_INT8], False)) + def testCalibrateAndQuantizeBuiltinInt(self, supported_ops, enable_mlir): with ops.Graph().as_default(): inp, output, calibration_gen = self._getCalibrationQuantizeModel() sess = session.Session() @@ -787,9 +791,7 @@ class FromSessionTest(TestModels, parameterized.TestCase): quantized_converter = lite.TFLiteConverter.from_session( sess, [inp], [output]) quantized_converter.experimental_new_converter = enable_mlir - quantized_converter.target_spec.supported_ops = [ - lite.OpsSet.TFLITE_BUILTINS_INT8 - ] + quantized_converter.target_spec.supported_ops = supported_ops quantized_converter.representative_dataset = calibration_gen quantized_tflite = quantized_converter.convert() self.assertTrue(quantized_tflite) diff --git a/tensorflow/lite/python/optimize/calibration_wrapper.cc b/tensorflow/lite/python/optimize/calibration_wrapper.cc index 89ffb3430ea..88995136726 100644 --- a/tensorflow/lite/python/optimize/calibration_wrapper.cc +++ b/tensorflow/lite/python/optimize/calibration_wrapper.cc @@ -204,6 +204,7 @@ PyObject* CalibrationWrapper::SetTensor(int index, PyObject* value) { PyObject* CalibrationWrapper::QuantizeModel(int input_py_type, int output_py_type, bool allow_float, + int activations_py_type, bool enable_mlir_quantizer) { if (NoOpModel(*model_)) { return python_utils::ConvertToPyString(model_str_->data(), @@ -212,6 +213,9 @@ PyObject* CalibrationWrapper::QuantizeModel(int input_py_type, TfLiteType input_type = python_utils::TfLiteTypeFromPyType(input_py_type); TfLiteType output_type = python_utils::TfLiteTypeFromPyType(output_py_type); + TfLiteType activations_type = + python_utils::TfLiteTypeFromPyType(activations_py_type); + if (input_type == kTfLiteNoType || output_type == kTfLiteNoType) { PyErr_SetString(PyExc_ValueError, "Input/output type cannot be kTfLiteNoType"); @@ -230,7 +234,7 @@ PyObject* CalibrationWrapper::QuantizeModel(int input_py_type, status = tflite::optimize::QuantizeModel( &builder, tflite_model.get(), TfLiteTypeToSchemaType(input_type), TfLiteTypeToSchemaType(output_type), allow_float, - error_reporter_.get()); + TfLiteTypeToSchemaType(activations_type), error_reporter_.get()); } if (status != kTfLiteOk) { @@ -262,7 +266,7 @@ PyObject* CalibrationWrapper::QuantizeModel(int input_py_type, auto status = tflite::optimize::QuantizeModel( &builder, tflite_model.get(), TfLiteTypeToSchemaType(input_type), TfLiteTypeToSchemaType(output_type), allow_float, {op_name}, - error_reporter_.get()); + TensorType_INT8, error_reporter_.get()); if (status != kTfLiteOk) { error_reporter_->exception(); return nullptr; diff --git a/tensorflow/lite/python/optimize/calibration_wrapper.h b/tensorflow/lite/python/optimize/calibration_wrapper.h index 0fefc29dd81..e72fe15e958 100644 --- a/tensorflow/lite/python/optimize/calibration_wrapper.h +++ b/tensorflow/lite/python/optimize/calibration_wrapper.h @@ -60,7 +60,8 @@ class CalibrationWrapper { PyObject* FeedTensor(PyObject* input_value); PyObject* QuantizeModel(int input_py_type, int output_py_type, - bool allow_float, bool enable_mlir_quantizer = false); + bool allow_float, int activations_py_type, + bool enable_mlir_quantizer = false); // Allows quantizing only the operator that produces the tensor with name // operator_output_name. (This can be used to help debug.). diff --git a/tensorflow/lite/python/optimize/calibrator.py b/tensorflow/lite/python/optimize/calibrator.py index 6d9a29236f0..1f962917551 100644 --- a/tensorflow/lite/python/optimize/calibrator.py +++ b/tensorflow/lite/python/optimize/calibrator.py @@ -19,6 +19,7 @@ from __future__ import print_function import numpy as np from tensorflow.python.util.lazy_loader import LazyLoader +from tensorflow.lite.python import lite_constants # Lazy load since some of the performance benchmark skylark rules # break dependencies. Must use double quotes to match code internal rewrite @@ -55,7 +56,8 @@ class Calibrator(object): raise ValueError("Failed to parse the model.") def calibrate_and_quantize(self, dataset_gen, input_type, output_type, - allow_float, enable_mlir_quantizer=False): + allow_float, activations_type = lite_constants.INT8, + enable_mlir_quantizer=False): """Calibrates the model with specified generator and then quantizes it. Returns: @@ -69,6 +71,7 @@ class Calibrator(object): computation, useful when targeting an integer-only backend. If False, an error will be thrown if an operation cannot be quantized, otherwise the model will fallback to float ops. + activations_type: A tf.dtype representing the desired type for activations enable_mlir_quantizer: A boolean. True if wants to use mlir quantizer to quantize the calibrated model. """ @@ -78,6 +81,7 @@ class Calibrator(object): return self._calibrator.QuantizeModel( np.dtype(input_type.as_numpy_dtype()).num, np.dtype(output_type.as_numpy_dtype()).num, allow_float, + np.dtype(activations_type.as_numpy_dtype()).num, enable_mlir_quantizer) def calibrate_and_quantize_single(self, dataset_gen, input_type, output_type, diff --git a/tensorflow/lite/python/optimize/calibrator_test.py b/tensorflow/lite/python/optimize/calibrator_test.py index 28e8723f23d..7ec5f8f526c 100644 --- a/tensorflow/lite/python/optimize/calibrator_test.py +++ b/tensorflow/lite/python/optimize/calibrator_test.py @@ -33,9 +33,13 @@ from tensorflow.python.platform import test class CalibratorTest(test_util.TensorFlowTestCase, parameterized.TestCase): @parameterized.named_parameters( - ('EnableMlirQuantizer', True), # enable mlir quantizer - ('DisableMlirQuantizer', False)) # disable mlir quantizer - def test_calibration_with_quantization(self, enable_mlir): + # Activation type Int8 - enable mlir quantizer + ('UseActivationTypeInt8EnabledMlir', constants.INT8, True), + # Activation type Int8 - disable mlir quantizer + ('UseActivationTypeInt8DisabledMlir', constants.INT8, False), + # Activation type Int16 + ('UseActivationTypeInt16', constants.INT16, False)) + def test_calibration_with_quantization(self, activations_type, enable_mlir): model_path = resource_loader.get_path_to_datafile( 'test_data/mobilenet_like_model.bin') float_model = open(model_path, 'rb').read() @@ -49,13 +53,18 @@ class CalibratorTest(test_util.TensorFlowTestCase, parameterized.TestCase): quantized_model = quantizer.calibrate_and_quantize(input_gen, constants.FLOAT, constants.FLOAT, False, + activations_type, enable_mlir) self.assertIsNotNone(quantized_model) @parameterized.named_parameters( - ('EnableMlirQuantizer', True), # enable mlir quantizer - ('DisableMlirQuantizer', False)) # disable mlir quantizer - def test_calibration_with_quantization_allow_float(self, enable_mlir): + # Activation type Int8 - enable mlir quantizer + ('UseActivationTypeInt8EnabledMlir', constants.INT8, True), + # Activation type Int8 - disable mlir quantizer + ('UseActivationTypeInt8DisableMlir', constants.INT8, False), + # Activation type Int16 - disable mlir quantizer + ('UseActivationTypeInt16', constants.INT16, False)) + def test_calibration_with_quantization_allow_float(self, activations_type, enable_mlir): model_path = resource_loader.get_path_to_datafile( 'test_data/mobilenet_like_model.bin') float_model = open(model_path, 'rb').read() @@ -69,6 +78,7 @@ class CalibratorTest(test_util.TensorFlowTestCase, parameterized.TestCase): quantized_model = quantizer.calibrate_and_quantize(input_gen, constants.FLOAT, constants.FLOAT, True, + activations_type, enable_mlir) self.assertIsNotNone(quantized_model) @@ -88,9 +98,13 @@ class CalibratorTest(test_util.TensorFlowTestCase, parameterized.TestCase): self.assertIsNotNone(quantized_model) @parameterized.named_parameters( - ('EnableMlirQuantizer', True), # enable mlir quantizer - ('DisableMlirQuantizer', False)) # disable mlir quantizer - def test_calibration_with_quantization_multiple_inputs(self, enable_mlir): + # Activation type Int8 - enable mlir quantizer + ('UseActivationTypeInt8 - EnableMlirQuantizer', constants.INT8, True), + # Activation type Int8 - disable mlir quantizer + ('UseActivationTypeInt8 - DisableMlirQuantizer', constants.INT8, False), + # Activation type Int16 - disable mlir quantizer + ('UseActivationTypeInt16 - DisableEnableMlirQuantizer', constants.INT16, False)) + def test_calibration_with_quantization_multiple_inputs(self, activations_type, enable_mlir): # Load multi add model from test data. # This model has 4 inputs of size (1, 8, 8, 3). model_path = resource_loader.get_path_to_datafile( @@ -106,6 +120,7 @@ class CalibratorTest(test_util.TensorFlowTestCase, parameterized.TestCase): quantized_model = quantizer.calibrate_and_quantize(input_gen, constants.FLOAT, constants.FLOAT, False, + activations_type, enable_mlir) self.assertIsNotNone(quantized_model) @@ -148,7 +163,8 @@ class CalibratorTest(test_util.TensorFlowTestCase, parameterized.TestCase): with self.assertRaisesRegex(ValueError, 'Size mismatch'): quantizer.calibrate_and_quantize(input_gen, constants.FLOAT, - constants.FLOAT, False, enable_mlir) + constants.FLOAT, False, + enable_mlir) @parameterized.named_parameters( ('EnableMlirQuantizer', True), # enable mlir quantizer @@ -166,7 +182,8 @@ class CalibratorTest(test_util.TensorFlowTestCase, parameterized.TestCase): with self.assertRaises(ValueError): quantizer.calibrate_and_quantize(input_gen, constants.FLOAT, - constants.FLOAT, False, enable_mlir) + constants.FLOAT, False, + constants.INT8, enable_mlir) if __name__ == '__main__': diff --git a/tensorflow/lite/tools/optimize/operator_property.cc b/tensorflow/lite/tools/optimize/operator_property.cc index 13f63092761..1f2d8bb4a4d 100644 --- a/tensorflow/lite/tools/optimize/operator_property.cc +++ b/tensorflow/lite/tools/optimize/operator_property.cc @@ -64,6 +64,7 @@ OperatorProperty GetOperatorProperty(const ModelT* model, int subgraph_index, property.inputs = {{0, {}}, {1, {}}}; property.outputs = {{0, {}}}; property.version = 2; + property.quantize_input_as_activations = true; break; case BuiltinOperator_ARG_MAX: property.inputs = {{0, {}}}; @@ -176,7 +177,7 @@ OperatorProperty GetOperatorProperty(const ModelT* model, int subgraph_index, // LogSoftmax requires output with 16/256 as scale and 127 as zero point. TensorProperty tensor_property; tensor_property.restriction = true; - tensor_property.restricted_value = {16.0 / 256.0, 127}; + tensor_property.restricted_value_int8 = {16.0 / 256.0, 127}; property.outputs = {{0, tensor_property}}; property.version = 2; break; @@ -186,7 +187,8 @@ OperatorProperty GetOperatorProperty(const ModelT* model, int subgraph_index, // Logistic requires output with 1/256 as scale and -128 as zero point. TensorProperty tensor_property; tensor_property.restriction = true; - tensor_property.restricted_value = {1 / 256.0, -128}; + tensor_property.restricted_value_int8 = {1 / 256.0, -128}; + tensor_property.restricted_value_int16 = {1 / 32768.0, 0}; property.outputs = {{0, tensor_property}}; property.version = 2; break; @@ -741,7 +743,7 @@ OperatorProperty GetOperatorProperty(const ModelT* model, int subgraph_index, // L2 Norm requires output with 1/128 as scale and 0 as zero point. TensorProperty tensor_property; tensor_property.restriction = true; - tensor_property.restricted_value = {1 / 128.0, 0}; + tensor_property.restricted_value_int8 = {1 / 128.0, 0}; property.outputs = {{0, tensor_property}}; property.version = 2; break; @@ -756,6 +758,7 @@ OperatorProperty GetOperatorProperty(const ModelT* model, int subgraph_index, property.arbitrary_inputs = true; property.outputs = {{0, {}}}; property.restrict_same_input_output_scale = true; + property.quantize_input_as_activations = true; property.version = 2; break; case BuiltinOperator_MEAN: @@ -767,6 +770,7 @@ OperatorProperty GetOperatorProperty(const ModelT* model, int subgraph_index, property.arbitrary_inputs = true; property.outputs = {{0, {}}}; property.restrict_same_input_output_scale = true; + property.quantize_input_as_activations = true; property.version = 2; break; case BuiltinOperator_MUL: @@ -778,6 +782,7 @@ OperatorProperty GetOperatorProperty(const ModelT* model, int subgraph_index, property.arbitrary_inputs = true; property.outputs = {{0, {}}}; property.restrict_same_input_output_scale = true; + property.restrict_same_input_output_scale = true; property.version = 2; break; case BuiltinOperator_PAD: @@ -840,7 +845,8 @@ OperatorProperty GetOperatorProperty(const ModelT* model, int subgraph_index, // Softmax requires output with 1/256 as scale and -128 as zero point. TensorProperty tensor_property; tensor_property.restriction = true; - tensor_property.restricted_value = {1 / 256.0, -128}; + tensor_property.restricted_value_int8 = {1 / 256.0, -128}; + tensor_property.restricted_value_int16 = {1 / 32768.0, 0}; property.outputs = {{0, tensor_property}}; property.version = 2; break; @@ -866,7 +872,8 @@ OperatorProperty GetOperatorProperty(const ModelT* model, int subgraph_index, // Tanh requires output with 1/128 as scale and 0 as zero point. TensorProperty tensor_property; tensor_property.restriction = true; - tensor_property.restricted_value = {1 / 128.0, 0}; + tensor_property.restricted_value_int8 = {1 / 128.0, 0}; + tensor_property.restricted_value_int16 = {1 / 32768.0, 0}; property.outputs = {{0, tensor_property}}; property.version = 2; break; diff --git a/tensorflow/lite/tools/optimize/operator_property.h b/tensorflow/lite/tools/optimize/operator_property.h index 5d37aa304e5..23052308568 100644 --- a/tensorflow/lite/tools/optimize/operator_property.h +++ b/tensorflow/lite/tools/optimize/operator_property.h @@ -43,7 +43,8 @@ struct TensorProperty { // Constraints. bool restriction = false; // scale/zero_point hardcoded. - std::pair restricted_value = {0.0, 0}; + std::pair restricted_value_int8 = {0.0, 0}; + std::pair restricted_value_int16 = {0.0, 0}; // Use derived scale. bool use_derived_scale = false; @@ -93,6 +94,13 @@ struct OperatorProperty { // Op version. int version = 1; + + // When we quantize activations into 16 bit and weights into 8 bit, + // we want to quantize all inputs, including constant tensors, + // for the operators like Add, Mul into 16-bit as well. The constant + // inputs are quantized as weights and this variable indicates + // that we want to do quantizations of these tensors as activations. + bool quantize_input_as_activations = false; }; OperatorProperty GetOperatorProperty(const ModelT* model, int subgraph_index, diff --git a/tensorflow/lite/tools/optimize/quantization_utils.cc b/tensorflow/lite/tools/optimize/quantization_utils.cc index 10680758d72..4bc9686ec2c 100644 --- a/tensorflow/lite/tools/optimize/quantization_utils.cc +++ b/tensorflow/lite/tools/optimize/quantization_utils.cc @@ -20,7 +20,6 @@ limitations under the License. #include #include "absl/memory/memory.h" -#include "third_party/eigen3/Eigen/Core" #include "tensorflow/lite/c/common.h" #include "tensorflow/lite/core/api/error_reporter.h" #include "tensorflow/lite/kernels/internal/quantization_util.h" @@ -30,6 +29,7 @@ limitations under the License. #include "tensorflow/lite/minimal_logging.h" #include "tensorflow/lite/schema/schema_generated.h" #include "tensorflow/lite/tools/optimize/model_utils.h" +#include "third_party/eigen3/Eigen/Core" namespace tflite { namespace optimize { @@ -85,6 +85,46 @@ void GetAsymmetricQuantizationParams( quantization_params->zero_point = std::vector(1, zero_point); } +void GetSymmetricQuantizationParams( + float min, float max, const int half_quant_range, + QuantizationParametersT* quantization_params) { + // Adjust the boundaries to guarantee 0 is included. + min = std::min(min, 0.0f); + max = std::max(max, 0.0f); + const float scale = std::max(std::abs(max), std::abs(min)) / half_quant_range; + int64_t zero_point = 0; + quantization_params->min = std::vector(1, min); + quantization_params->max = std::vector(1, max); + quantization_params->scale = std::vector(1, scale); + quantization_params->zero_point = std::vector(1, 0); +} + +TfLiteStatus GetQuantizationParams(TensorT* tensor, TensorType activations_type, + QuantizationParametersT* quantization_params, + ErrorReporter* error_reporter) { + if (activations_type == TensorType_INT8) { + GetAsymmetricQuantizationParams( + tensor->quantization->min[0], tensor->quantization->max[0], + std::numeric_limits::min(), std::numeric_limits::max(), + quantization_params); + } else if (activations_type == TensorType_INT16) { + float range = std::max(std::abs(tensor->quantization->min[0]), + std::abs(tensor->quantization->max[0])); + const float quantized_range = 32767.0; + const float scale = range / quantized_range; + quantization_params->min = std::vector(1, -range); + quantization_params->max = std::vector(1, range); + quantization_params->scale = std::vector(1, scale); + quantization_params->zero_point = std::vector(1, 0); + } else { + error_reporter->Report( + "Unsupported activation type for quantize-activation: %s", + activations_type); + return kTfLiteError; + } + return kTfLiteOk; +} + // Set the max and min quantization parameter for a single tensor given its // values. void FillSingleMinMax(const float* const input, const uint64_t input_size, @@ -536,6 +576,7 @@ TfLiteStatus SymmetricQuantizeTensorPerChannel(ModelT* model, TensorT* tensor, model, tensor, error_reporter); } +template TfLiteStatus SymmetricPerLayerBiasQuantize(ModelT* model, TensorT* tensor, float scaling_factor, ErrorReporter* error_reporter) { @@ -548,25 +589,38 @@ TfLiteStatus SymmetricPerLayerBiasQuantize(ModelT* model, TensorT* tensor, uint64_t num_elements; TF_LITE_ENSURE_STATUS(NumElements(*tensor, &num_elements)); - std::vector final_buffer(num_elements); - const int32_t kScale = std::numeric_limits::max(); + std::vector final_buffer(num_elements); + const BiasType kScale = std::numeric_limits::max(); for (size_t i = 0; i < num_elements; i++) { - const int32_t quantized_value = tflite::SafeCast( + const BiasType quantized_value = tflite::SafeCast( TfLiteRound(float_data[i] * scaling_factor_inv)); final_buffer[i] = std::min(kScale, std::max(-kScale, quantized_value)); } // Set the buffers and output type. uint8_t* uint8_buffer = reinterpret_cast(final_buffer.data()); - size_t buffer_size = num_elements * sizeof(int32_t); + size_t buffer_size = num_elements * sizeof(BiasType); std::vector scales(1, scaling_factor); std::vector zero_points(1, 0); + + auto output_type = std::is_same::value + ? TensorType_INT32 + : TensorType_INT64; return AddQuantizationParams(scales, zero_points, 0, uint8_buffer, - buffer_size, TensorType_INT32, model, tensor, + buffer_size, output_type, model, tensor, error_reporter); } +template TfLiteStatus SymmetricPerLayerBiasQuantize( + ModelT* model, TensorT* tensor, float scaling_factor, + ErrorReporter* error_reporter); + +template TfLiteStatus SymmetricPerLayerBiasQuantize( + ModelT* model, TensorT* tensor, float scaling_factor, + ErrorReporter* error_reporter); + +template TfLiteStatus SymmetricPerChannelBiasQuantize(ModelT* model, TensorT* tensor, float input_scale, const float* weight_scales, @@ -583,14 +637,14 @@ TfLiteStatus SymmetricPerChannelBiasQuantize(ModelT* model, TensorT* tensor, uint64_t num_elements; TF_LITE_ENSURE_STATUS(NumElements(*tensor, &num_elements)); - std::vector final_buffer(num_elements); - const int32_t kScale = std::numeric_limits::max(); + std::vector final_buffer(num_elements); + const BiasType kScale = std::numeric_limits::max(); for (int32_t channel_idx = 0; channel_idx < number_of_dimension; channel_idx++) { float scaling_factor = scales[channel_idx]; float scaling_factor_inv = (scaling_factor == 0) ? 0 : 1.0 / scaling_factor; - const int32_t quantized_value = tflite::SafeCast( + const BiasType quantized_value = tflite::SafeCast( TfLiteRound(float_data[channel_idx] * scaling_factor_inv)); final_buffer[channel_idx] = std::min(kScale, std::max(-kScale, quantized_value)); @@ -598,12 +652,26 @@ TfLiteStatus SymmetricPerChannelBiasQuantize(ModelT* model, TensorT* tensor, // Set the buffers and output type. uint8_t* uint8_buffer = reinterpret_cast(final_buffer.data()); - size_t buffer_size = num_elements * sizeof(int32_t); + size_t buffer_size = num_elements * sizeof(BiasType); std::vector zero_point(scales.size(), 0); + + auto output_type = std::is_same::value + ? TensorType_INT32 + : TensorType_INT64; return AddQuantizationParams(scales, zero_point, 0, uint8_buffer, buffer_size, - TensorType_INT32, model, tensor, error_reporter); + output_type, model, tensor, error_reporter); } +template TfLiteStatus SymmetricPerChannelBiasQuantize( + ModelT* model, TensorT* tensor, float input_scale, + const float* weight_scales, int number_of_dimension, + ErrorReporter* error_reporter); + +template TfLiteStatus SymmetricPerChannelBiasQuantize( + ModelT* model, TensorT* tensor, float input_scale, + const float* weight_scales, int number_of_dimension, + ErrorReporter* error_reporter); + TfLiteStatus QuantizeWeight(ModelT* model, TensorT* tensor, bool per_channel, int per_axis_index, ErrorReporter* error_reporter) { // TODO(suharshs): Currently we conflate quantizing weights and constants. Its @@ -645,12 +713,12 @@ float GetEffectiveScale(ModelT* model, SubGraphT* subgraph, int op_idx, return scale; } -void QuantizeActivation(TensorT* tensor) { - GetAsymmetricQuantizationParams( - tensor->quantization->min[0], tensor->quantization->max[0], - std::numeric_limits::min(), std::numeric_limits::max(), - tensor->quantization.get()); - tensor->type = TensorType_INT8; +TfLiteStatus QuantizeActivation(TensorT* tensor, TensorType activations_type, + ErrorReporter* error_reporter) { + TF_LITE_ENSURE_STATUS(GetQuantizationParams( + tensor, activations_type, tensor->quantization.get(), error_reporter)); + tensor->type = activations_type; + return kTfLiteOk; } TfLiteStatus QuantizeActivationToInt16(TensorT* tensor, float scale) { diff --git a/tensorflow/lite/tools/optimize/quantization_utils.h b/tensorflow/lite/tools/optimize/quantization_utils.h index 18ed707e175..752b4253250 100644 --- a/tensorflow/lite/tools/optimize/quantization_utils.h +++ b/tensorflow/lite/tools/optimize/quantization_utils.h @@ -113,12 +113,14 @@ TfLiteStatus SymmetricQuantizeFloatsToInt16(ModelT* model, TensorT* tensor, ErrorReporter* error_reporter); // Symmetrically quantized the bias for per-layer ops (i.e. FullyConnected). +template TfLiteStatus SymmetricPerLayerBiasQuantize(ModelT* model, TensorT* tensor, float scaling_factor, ErrorReporter* error_reporter); // Symmetrically quantizes the bias for ops like Conv and DepthwiseConv. // The scale of bias if weight_per_channel_scale[channel] * input_scale. +template TfLiteStatus SymmetricPerChannelBiasQuantize(ModelT* model, TensorT* tensor, float input_scale, const float* weight_scales, @@ -135,8 +137,14 @@ float GetEffectiveScale(ModelT* model, SubGraphT* subgraph, int op_idx, std::vector intermediate_index, std::vector factors); +// Return quantization parameters depending on activations type. +TfLiteStatus GetQuantizationParams(TensorT* tensor, TensorType activations_type, + QuantizationParametersT* quantization_params, + ErrorReporter* error_reporter); + // Quantize activation. -void QuantizeActivation(TensorT* tensor); +TfLiteStatus QuantizeActivation(TensorT* tensor, TensorType activations_type, + ErrorReporter* error_reporter); // Quantize activation to 16bit. TfLiteStatus QuantizeActivationToInt16(TensorT* tensor, float scale); diff --git a/tensorflow/lite/tools/optimize/quantization_utils_test.cc b/tensorflow/lite/tools/optimize/quantization_utils_test.cc index ece0123d166..49009e49600 100644 --- a/tensorflow/lite/tools/optimize/quantization_utils_test.cc +++ b/tensorflow/lite/tools/optimize/quantization_utils_test.cc @@ -701,7 +701,7 @@ TEST_F(QuantizationUtilsTest, SymmetricPerLayerBiasQuantize) { model->buffers.push_back(std::move(buffer)); // Call and verify. - EXPECT_EQ(SymmetricPerLayerBiasQuantize( + EXPECT_EQ(SymmetricPerLayerBiasQuantize( model.get(), model->subgraphs[0]->tensors[0].get(), input_scale * weight_scale, &error_reporter_), kTfLiteOk); @@ -759,7 +759,7 @@ TEST_F(QuantizationUtilsTest, SymmetricPerChannelBiasQuantize) { model->buffers.push_back(std::move(buffer)); // Call and verify. - EXPECT_EQ(SymmetricPerChannelBiasQuantize( + EXPECT_EQ(SymmetricPerChannelBiasQuantize( model.get(), model->subgraphs[0]->tensors[0].get(), input_scale, weight_scales.data(), 2, &error_reporter_), kTfLiteOk); diff --git a/tensorflow/lite/tools/optimize/quantization_wrapper.cc b/tensorflow/lite/tools/optimize/quantization_wrapper.cc index bd3331da6bf..5002c382bc7 100644 --- a/tensorflow/lite/tools/optimize/quantization_wrapper.cc +++ b/tensorflow/lite/tools/optimize/quantization_wrapper.cc @@ -42,7 +42,9 @@ bool CreateQuantizedModel(const std::string& path) { tflite::StderrReporter error_reporter; if (tflite::optimize::QuantizeModel( &builder, &model, tflite::TensorType_FLOAT32, - tflite::TensorType_FLOAT32, &error_reporter) != kTfLiteOk) { + tflite::TensorType_FLOAT32, + // TODO: Pass required activation type if needed + tflite::TensorType_INT8, &error_reporter) != kTfLiteOk) { return false; } return WriteFile(path, builder.GetBufferPointer(), builder.GetSize()); diff --git a/tensorflow/lite/tools/optimize/quantize_model.cc b/tensorflow/lite/tools/optimize/quantize_model.cc index 6fc19ff2a56..ee562fe9c4c 100644 --- a/tensorflow/lite/tools/optimize/quantize_model.cc +++ b/tensorflow/lite/tools/optimize/quantize_model.cc @@ -64,6 +64,7 @@ operator_property::OperatorProperty GetOperatorProperty( TfLiteStatus QuantizeBias(ModelT* model, const TensorT* input_tensor, const TensorT* weight_tensor, TensorT* bias_tensor, bool is_per_channel, int channel_dim_index, + const TensorType& activations_type, ErrorReporter* error_reporter) { if (bias_tensor->shape.size() != 1) { error_reporter->Report("Expected bias tensor shape to be 1."); @@ -92,9 +93,15 @@ TfLiteStatus QuantizeBias(ModelT* model, const TensorT* input_tensor, weight_scales.size()); return kTfLiteError; } - return utils::SymmetricPerChannelBiasQuantize( - model, bias_tensor, input_tensor->quantization->scale[0], - weight_scales.data(), channel_dim_size, error_reporter); + if (activations_type == tflite::TensorType_INT16) { + return utils::SymmetricPerChannelBiasQuantize( + model, bias_tensor, input_tensor->quantization->scale[0], + weight_scales.data(), channel_dim_size, error_reporter); + } else { + return utils::SymmetricPerChannelBiasQuantize( + model, bias_tensor, input_tensor->quantization->scale[0], + weight_scales.data(), channel_dim_size, error_reporter); + } } else { if (weight_scales.size() != 1) { error_reporter->Report( @@ -102,40 +109,54 @@ TfLiteStatus QuantizeBias(ModelT* model, const TensorT* input_tensor, weight_scales.size()); return kTfLiteError; } - return utils::SymmetricPerLayerBiasQuantize( - model, bias_tensor, - input_tensor->quantization->scale[0] * weight_scales[0], - error_reporter); + if (activations_type == tflite::TensorType_INT16) { + return utils::SymmetricPerLayerBiasQuantize( + model, bias_tensor, + input_tensor->quantization->scale[0] * weight_scales[0], + error_reporter); + } else { + return utils::SymmetricPerLayerBiasQuantize( + model, bias_tensor, + input_tensor->quantization->scale[0] * weight_scales[0], + error_reporter); + } } return kTfLiteError; } // True if the tensor type has to be modified. bool TensorTypeChangeRequired(const TensorT* tensor, const TensorType& type) { - // The quantized model is type INT8, so if the user provided type is INT8, we - // do not have to do any custom logic. Additionally, if the current tensor - // isn't INT8 quantized, the custom type doesn't apply. - return (type != TensorType_INT8 && tensor->type == TensorType_INT8 && - !tensor->quantization->scale.empty()); + // The quantized model is type INT8/INT16, so if the user provided type is + // INT8/INT16, we do not have to do any custom logic. Additionally, if the + // current tensor isn't INT8/INT16 quantized, the custom type doesn't apply. + bool int8check = type != TensorType_INT8 && tensor->type == TensorType_INT8 && + !tensor->quantization->scale.empty(); + bool int16check = type != TensorType_INT16 && + tensor->type == TensorType_INT16 && + !tensor->quantization->scale.empty(); + return (int8check || int16check); } // Sets the input type, adding a Leading Op node at the start of the model if // necessary. // Returns the new input tensor index. int32_t SetInputType(ModelT* model, SubGraphT* subgraph, - const int32_t tensor_idx, const TensorType& input_type) { + const int32_t tensor_idx, const TensorType& input_type, + const TensorType& activations_type) { TensorT* tensor = subgraph->tensors[tensor_idx].get(); if (!TensorTypeChangeRequired(tensor, input_type)) { return -1; } if (input_type == TensorType_FLOAT32 || input_type == TensorType_UINT8) { + std::string type_string = + activations_type == TensorType_INT16 ? "int16" : "int8"; // Create a new tensor to be the input of the leading Op. std::unique_ptr leading_op_input; if (input_type == TensorType_FLOAT32) { // Add tensor for quantize operator. Scales and zero points are not // needed. const string leading_op_name = tensor->name; - const string new_name_original_input = tensor->name + "_int8"; + const string new_name_original_input = tensor->name + "_" + type_string; tensor->name = new_name_original_input; utils::MakeTensor(leading_op_name, tensor->shape, input_type, &leading_op_input); @@ -150,7 +171,7 @@ int32_t SetInputType(ModelT* model, SubGraphT* subgraph, TFLITE_DCHECK_GE(zero_point, -128); TFLITE_DCHECK_LE(zero_point, 127); const string leading_op_name = tensor->name; - const string new_name_original_input = tensor->name + "_int8"; + const string new_name_original_input = tensor->name + "_" + type_string; tensor->name = new_name_original_input; utils::MakeTensorWithQuantParam(leading_op_name, tensor->shape, input_type, scale, zero_point + 128, @@ -177,17 +198,20 @@ int32_t SetInputType(ModelT* model, SubGraphT* subgraph, // necessary. // Returns the new output tensor index. int32_t SetOutputType(ModelT* model, SubGraphT* subgraph, - const int32_t tensor_idx, const TensorType& output_type) { + const int32_t tensor_idx, const TensorType& output_type, + const TensorType& activations_type) { TensorT* tensor = subgraph->tensors[tensor_idx].get(); if (!TensorTypeChangeRequired(tensor, output_type)) { return -1; } if (output_type == TensorType_FLOAT32 || output_type == TensorType_UINT8) { + std::string type_string = + activations_type == TensorType_INT16 ? "int16" : "int8"; // Create a new tensor to be the output of the tailing op. std::unique_ptr tailing_op_output; if (output_type == TensorType_FLOAT32) { const string tailing_op_name = tensor->name; - const string new_name_original_output = tensor->name + "_int8"; + const string new_name_original_output = tensor->name + "_" + type_string; tensor->name = new_name_original_output; utils::MakeTensor(tailing_op_name, tensor->shape, output_type, &tailing_op_output); @@ -202,7 +226,7 @@ int32_t SetOutputType(ModelT* model, SubGraphT* subgraph, TFLITE_DCHECK_GE(zero_point, -128); TFLITE_DCHECK_LE(zero_point, 127); const string tailing_op_name = tensor->name; - const string new_name_original_output = tensor->name + "_int8"; + const string new_name_original_output = tensor->name + "_" + type_string; tensor->name = new_name_original_output; utils::MakeTensorWithQuantParam(tailing_op_name, tensor->shape, output_type, scale, zero_point + 128, @@ -238,6 +262,7 @@ int32_t SetOutputType(ModelT* model, SubGraphT* subgraph, // uint8, can be thought as "requant"). TfLiteStatus SetInputAndOutputTypes(ModelT* model, const TensorType& input_type, const TensorType& output_type, + const TensorType& activations_type, ErrorReporter* error_reporter) { for (int subgraph_idx = 0; subgraph_idx < model->subgraphs.size(); subgraph_idx++) { @@ -253,8 +278,8 @@ TfLiteStatus SetInputAndOutputTypes(ModelT* model, const TensorType& input_type, EnumNameTensorType(tensor->type)); return kTfLiteError; } - const int32_t input_idx = - SetInputType(model, subgraph, subgraph->inputs[i], input_type); + const int32_t input_idx = SetInputType( + model, subgraph, subgraph->inputs[i], input_type, activations_type); if (input_idx < 0) { continue; } @@ -270,8 +295,8 @@ TfLiteStatus SetInputAndOutputTypes(ModelT* model, const TensorType& input_type, EnumNameTensorType(tensor->type)); return kTfLiteError; } - const int32_t output_idx = - SetOutputType(model, subgraph, subgraph->outputs[i], output_type); + const int32_t output_idx = SetOutputType( + model, subgraph, subgraph->outputs[i], output_type, activations_type); if (output_idx < 0) { continue; } @@ -287,6 +312,7 @@ TfLiteStatus SetInputAndOutputTypes(ModelT* model, const TensorType& input_type, // The other ones with constraints are handled in QuantizeWeightsAndInput. TfLiteStatus ApplyConstraints(ModelT* model, const std::unordered_set& operator_names, + TensorType activations_type, ErrorReporter* error_reporter) { for (int subgraph_idx = 0; subgraph_idx < model->subgraphs.size(); subgraph_idx++) { @@ -332,7 +358,7 @@ TfLiteStatus ApplyConstraints(ModelT* model, std::unique_ptr additional_tensor; const string requant_tensor_name = input_tensor->name + "_requantized"; utils::MakeTensorWithQuantParam( - requant_tensor_name, input_tensor->shape, TensorType_INT8, + requant_tensor_name, input_tensor->shape, activations_type, output_scale, output_zp, &additional_tensor); const int32_t additional_tensor_idx = subgraph->tensors.size(); subgraph->tensors.push_back(std::move(additional_tensor)); @@ -382,7 +408,8 @@ std::vector> GetOutputs( bool ShouldRestrictSameInputOutputScale( operator_property::OperatorProperty property) { - // Ops with multiple inputs (i.e. concat) gets restricted in ApplyConstraints. + // Ops with multiple inputs (i.e. concat, max and min) gets restricted in + // ApplyConstraints. return (!property.arbitrary_inputs && property.restrict_same_input_output_scale); } @@ -401,7 +428,7 @@ TfLiteStatus QuantizeOpInput( ModelT* model, int32_t subgraph_idx, size_t* op_idx, operator_property::OperatorProperty property, const std::pair& input, - ErrorReporter* error_reporter) { + const TensorType& activations_type, ErrorReporter* error_reporter) { int32_t input_idx = input.first; operator_property::TensorProperty tensor_property = input.second; SubGraphT* subgraph = model->subgraphs.at(subgraph_idx).get(); @@ -429,7 +456,9 @@ TfLiteStatus QuantizeOpInput( if (utils::HasBuffer(model, subgraph, tensor_idx)) { // TODO(suharshs): Look at consumers, throw error if one consumer is // per-channel and one per-layer. - if (tensor_property.number_of_bits == 8) { + bool quantize_const_input = property.quantize_input_as_activations && + activations_type == TensorType_INT16; + if (tensor_property.number_of_bits == 8 && !quantize_const_input) { if (tensor_property.use_derived_scale) { // Currently 8bit tensors in input do not accept derived scale. return kTfLiteError; @@ -444,7 +473,7 @@ TfLiteStatus QuantizeOpInput( *op_idx); return kTfLiteError; } - } else if (tensor_property.number_of_bits == 16) { + } else if (tensor_property.number_of_bits == 16 || quantize_const_input) { if (tensor_property.use_derived_scale) { // Currently 16bit tensors in input do not accept derived scale. return kTfLiteError; @@ -476,8 +505,8 @@ TfLiteStatus QuantizeOpInput( tensor_property.derived_scale.input_tensors, tensor_property.derived_scale.intermediate_tensors, tensor_property.derived_scale.factors); - return utils::SymmetricPerLayerBiasQuantize(model, tensor, scale, - error_reporter); + return utils::SymmetricPerLayerBiasQuantize( + model, tensor, scale, error_reporter); } else if (tensor_property.number_of_bits == 10) { // When the number of bits is 10 (instead of 16), quantize the tensor to @@ -514,7 +543,8 @@ TfLiteStatus QuantizeOpInput( // Currently 8bit tensors in input do not accept derived scale. return kTfLiteError; } - utils::QuantizeActivation(tensor); + TF_LITE_ENSURE_STATUS(utils::QuantizeActivation( + tensor, activations_type, error_reporter)); } else if (tensor_property.number_of_bits == 16) { TensorT* tensor = subgraph->tensors[tensor_idx].get(); float quantized_range = 32767.0; @@ -532,13 +562,16 @@ TfLiteStatus QuantizeOpInput( } else { // If the tensor is not a model input, we need to add a Quantize // operation since the preceding op may require a float output. + std::string type_string = + activations_type == TensorType_INT16 ? "int16" : "int8"; std::unique_ptr op_output; - utils::MakeTensor(tensor->name + "_int8", tensor->shape, - TensorType_INT8, &op_output); + utils::MakeTensor(tensor->name + "_" + type_string, tensor->shape, + activations_type, &op_output); op_output->quantization = absl::make_unique(); op_output->quantization->min.push_back(tensor->quantization->min[0]); op_output->quantization->max.push_back(tensor->quantization->max[0]); - utils::QuantizeActivation(op_output.get()); + TF_LITE_ENSURE_STATUS(utils::QuantizeActivation( + op_output.get(), activations_type, error_reporter)); const int32_t quant_op_output_idx = subgraph->tensors.size(); subgraph->tensors.push_back(std::move(op_output)); std::unique_ptr quant_op; @@ -580,7 +613,7 @@ TfLiteStatus QuantizeOpOutput( ModelT* model, int32_t subgraph_idx, int32_t op_idx, operator_property::OperatorProperty property, const std::pair& output, - ErrorReporter* error_reporter) { + TensorType activations_type, ErrorReporter* error_reporter) { int32_t output_idx = output.first; operator_property::TensorProperty tensor_property = output.second; // If the operator is not quantizable, we don't need to do anything for the @@ -644,18 +677,22 @@ TfLiteStatus QuantizeOpOutput( const float max = input_tensor->quantization->max[0]; output_tensor->quantization->max = {max}; } - output_tensor->type = TensorType_INT8; + output_tensor->type = activations_type; } else if (tensor_property.restriction) { - const auto scale_and_zp = tensor_property.restricted_value; + const auto scale_and_zp = activations_type == TensorType_INT16 + ? tensor_property.restricted_value_int16 + : tensor_property.restricted_value_int8; + // Apply to output. output_tensor->quantization = absl::make_unique(); output_tensor->quantization->scale.push_back(scale_and_zp.first); output_tensor->quantization->zero_point.push_back(scale_and_zp.second); - output_tensor->type = TensorType_INT8; + output_tensor->type = activations_type; } else { // Process regular output that doesn't have any restrictions. if (utils::HasMinMax(output_tensor)) { - utils::QuantizeActivation(output_tensor); + utils::QuantizeActivation(output_tensor, activations_type, + error_reporter); } else { error_reporter->Report( "Unable to find min/max value for output %d in %s in " @@ -668,6 +705,7 @@ TfLiteStatus QuantizeOpOutput( } TfLiteStatus QuantizeIntemediateTensors(ModelT* model, + TensorType activations_type, ErrorReporter* error_reporter) { for (size_t subgraph_idx = 0; subgraph_idx < model->subgraphs.size(); subgraph_idx++) { @@ -691,7 +729,8 @@ TfLiteStatus QuantizeIntemediateTensors(ModelT* model, input.second.symmetric == false) { TensorT* tensor = subgraph->tensors[index_global].get(); if (utils::HasMinMax(tensor)) { - utils::QuantizeActivation(tensor); + utils::QuantizeActivation(tensor, activations_type, + error_reporter); } else { error_reporter->Report( "Unable to find min/max value for output %d in %s in " @@ -793,7 +832,7 @@ TfLiteStatus QuantizeSharedRange(ModelT* model, ErrorReporter* error_reporter) { TfLiteStatus QuantizeWeightsInputOutput( ModelT* model, bool allow_float, const std::unordered_set& operator_names, - ErrorReporter* error_reporter) { + const TensorType& activations_type, ErrorReporter* error_reporter) { for (size_t subgraph_idx = 0; subgraph_idx < model->subgraphs.size(); subgraph_idx++) { SubGraphT* subgraph = model->subgraphs.at(subgraph_idx).get(); @@ -815,14 +854,16 @@ TfLiteStatus QuantizeWeightsInputOutput( for (const std::pair& input : GetInputs(op, property)) { TF_LITE_ENSURE_STATUS(QuantizeOpInput(model, subgraph_idx, &op_idx, - property, input, error_reporter)); + property, input, activations_type, + error_reporter)); } // Quantize operator outputs. for (const std::pair& output : GetOutputs(op, property)) { - TF_LITE_ENSURE_STATUS(QuantizeOpOutput( - model, subgraph_idx, op_idx, property, output, error_reporter)); + TF_LITE_ENSURE_STATUS( + QuantizeOpOutput(model, subgraph_idx, op_idx, property, output, + activations_type, error_reporter)); } } } @@ -832,6 +873,7 @@ TfLiteStatus QuantizeWeightsInputOutput( // Quantize bias. TfLiteStatus QuantizeBiases(ModelT* model, const std::unordered_set& operator_names, + const TensorType& activations_type, ErrorReporter* error_reporter) { for (size_t subgraph_idx = 0; subgraph_idx < model->subgraphs.size(); subgraph_idx++) { @@ -877,10 +919,10 @@ TfLiteStatus QuantizeBiases(ModelT* model, subgraph->tensors[op->inputs[property.inputs[1].first]].get(); operator_property::TensorProperty weight_property = property.inputs[1].second; - TF_LITE_ENSURE_STATUS( - QuantizeBias(model, input_tensor, weight_tensor, bias_tensor, - weight_property.per_axis, - weight_property.per_axis_index, error_reporter)); + TF_LITE_ENSURE_STATUS(QuantizeBias( + model, input_tensor, weight_tensor, bias_tensor, + weight_property.per_axis, weight_property.per_axis_index, + activations_type, error_reporter)); } } } @@ -1000,7 +1042,7 @@ TfLiteStatus FillQuantizationParams( // Check compatibility of activation, weight and bias scales. Adjust if needed. TfLiteStatus EnsureBiasScaleCompatibility( ModelT* model, const std::unordered_set& operator_names, - ErrorReporter* error_reporter) { + TensorType activations_type, ErrorReporter* error_reporter) { for (size_t subgraph_idx = 0; subgraph_idx < model->subgraphs.size(); subgraph_idx++) { SubGraphT* subgraph = model->subgraphs.at(subgraph_idx).get(); @@ -1049,11 +1091,9 @@ TfLiteStatus EnsureBiasScaleCompatibility( // Get input scale for assymmetric quantization. QuantizationParametersT temp_quant_params = QuantizationParametersT(); - utils::GetAsymmetricQuantizationParams( - input_tensor->quantization->min[0], - input_tensor->quantization->max[0], - std::numeric_limits::min(), - std::numeric_limits::max(), &temp_quant_params); + TF_LITE_ENSURE_STATUS( + utils::GetQuantizationParams(input_tensor, activations_type, + &temp_quant_params, error_reporter)); if (temp_quant_params.scale.size() != 1) { error_reporter->Report("Unexpected input quantization scale size."); return kTfLiteError; @@ -1132,21 +1172,24 @@ TfLiteStatus QuantizeModel(flatbuffers::FlatBufferBuilder* builder, ModelT* model, const TensorType& input_type, const TensorType& output_type, bool allow_float, const std::unordered_set& operator_names, + const TensorType& activations_type, ErrorReporter* error_reporter) { TF_LITE_ENSURE_STATUS( FillQuantizationParams(model, operator_names, error_reporter)); + TF_LITE_ENSURE_STATUS(EnsureBiasScaleCompatibility( + model, operator_names, activations_type, error_reporter)); TF_LITE_ENSURE_STATUS( - EnsureBiasScaleCompatibility(model, operator_names, error_reporter)); - TF_LITE_ENSURE_STATUS(QuantizeIntemediateTensors(model, error_reporter)); + QuantizeIntemediateTensors(model, activations_type, error_reporter)); TF_LITE_ENSURE_STATUS(QuantizeSharedRange(model, error_reporter)); TF_LITE_ENSURE_STATUS(QuantizeWeightsInputOutput( - model, allow_float, operator_names, error_reporter)); + model, allow_float, operator_names, activations_type, error_reporter)); + TF_LITE_ENSURE_STATUS(ApplyConstraints(model, operator_names, + activations_type, error_reporter)); TF_LITE_ENSURE_STATUS( - ApplyConstraints(model, operator_names, error_reporter)); - TF_LITE_ENSURE_STATUS(QuantizeBiases(model, operator_names, error_reporter)); + QuantizeBiases(model, operator_names, activations_type, error_reporter)); utils::SetOperatorCodeVersion(model); - TF_LITE_ENSURE_STATUS( - SetInputAndOutputTypes(model, input_type, output_type, error_reporter)); + TF_LITE_ENSURE_STATUS(SetInputAndOutputTypes( + model, input_type, output_type, activations_type, error_reporter)); flatbuffers::Offset output_model_location = Model::Pack(*builder, model); @@ -1158,23 +1201,27 @@ TfLiteStatus QuantizeModel(flatbuffers::FlatBufferBuilder* builder, TfLiteStatus QuantizeModel(flatbuffers::FlatBufferBuilder* builder, ModelT* model, const TensorType& input_type, const TensorType& output_type, bool allow_float, + const TensorType& activations_type, ErrorReporter* error_reporter) { return QuantizeModel(builder, model, input_type, output_type, allow_float, - GetAllOperatorOutputs(model), error_reporter); + GetAllOperatorOutputs(model), activations_type, + error_reporter); } TfLiteStatus QuantizeModel(flatbuffers::FlatBufferBuilder* builder, ModelT* model, const TensorType& input_type, const TensorType& output_type, + const TensorType& activations_type, ErrorReporter* error_reporter) { return QuantizeModel(builder, model, input_type, output_type, - /*allow_float=*/false, error_reporter); + /*allow_float=*/false, activations_type, error_reporter); } TfLiteStatus QuantizeModel(flatbuffers::FlatBufferBuilder* builder, - ModelT* model, ErrorReporter* error_reporter) { + ModelT* model, const TensorType& activations_type, + ErrorReporter* error_reporter) { return QuantizeModel(builder, model, TensorType_FLOAT32, TensorType_FLOAT32, - /*allow_float=*/false, error_reporter); + /*allow_float=*/false, activations_type, error_reporter); } } // namespace optimize diff --git a/tensorflow/lite/tools/optimize/quantize_model.h b/tensorflow/lite/tools/optimize/quantize_model.h index 9b0353f6b6b..cc801ec9870 100644 --- a/tensorflow/lite/tools/optimize/quantize_model.h +++ b/tensorflow/lite/tools/optimize/quantize_model.h @@ -35,7 +35,9 @@ namespace optimize { // // Note: This is a private API, subject to change. TfLiteStatus QuantizeModel(flatbuffers::FlatBufferBuilder* builder, - ModelT* input_model, ErrorReporter* error_reporter); + ModelT* input_model, + const TensorType& activations_type, + ErrorReporter* error_reporter); // Same as above, but the types of quantized inputs and outputs are // configurable. @@ -44,6 +46,7 @@ TfLiteStatus QuantizeModel(flatbuffers::FlatBufferBuilder* builder, TfLiteStatus QuantizeModel(flatbuffers::FlatBufferBuilder* builder, ModelT* input_model, const TensorType& input_type, const TensorType& output_type, + const TensorType& activations_type, ErrorReporter* error_reporter); // Same as above, but can enable allowing float intermediate operations for ops @@ -53,6 +56,7 @@ TfLiteStatus QuantizeModel(flatbuffers::FlatBufferBuilder* builder, TfLiteStatus QuantizeModel(flatbuffers::FlatBufferBuilder* builder, ModelT* input_model, const TensorType& input_type, const TensorType& output_type, bool allow_float, + const TensorType& activations_type, ErrorReporter* error_reporter); // Same as above, but enables only quantizing a whitelist of operations, @@ -63,6 +67,7 @@ TfLiteStatus QuantizeModel(flatbuffers::FlatBufferBuilder* builder, ModelT* input_model, const TensorType& input_type, const TensorType& output_type, bool allow_float, const std::unordered_set& operator_names, + const TensorType& activations_type, ErrorReporter* error_reporter); } // namespace optimize diff --git a/tensorflow/lite/tools/optimize/quantize_model_test.cc b/tensorflow/lite/tools/optimize/quantize_model_test.cc index da1b293c84b..166d60ecc66 100644 --- a/tensorflow/lite/tools/optimize/quantize_model_test.cc +++ b/tensorflow/lite/tools/optimize/quantize_model_test.cc @@ -80,28 +80,35 @@ class QuantizeModelTest : public testing::Test { internal::FailOnErrorReporter error_reporter_; }; -class QuantizeConvModelTest : public QuantizeModelTest { +class QuantizeConvModelTest : public QuantizeModelTest, + public testing::WithParamInterface { protected: QuantizeConvModelTest() { + tensor_type_ = GetParam(); input_model_ = ReadModel(internal::kConvModelWith0Plus10Weights); readonly_model_ = input_model_->GetModel(); readonly_model_->UnPackTo(&model_); } + TensorType tensor_type_; }; -TEST_F(QuantizeConvModelTest, QuantizationSucceeds) { - auto status = QuantizeModel(&builder_, &model_, TensorType_INT8, - TensorType_INT8, &error_reporter_); +INSTANTIATE_TEST_SUITE_P(QuantizeConvModelTestInst, QuantizeConvModelTest, + testing::ValuesIn({TensorType_INT8, + TensorType_INT16})); + +TEST_P(QuantizeConvModelTest, QuantizationSucceeds) { + auto status = QuantizeModel(&builder_, &model_, tensor_type_, tensor_type_, + tensor_type_, &error_reporter_); EXPECT_EQ(status, kTfLiteOk); const uint8_t* buffer = builder_.GetBufferPointer(); const Model* output_model = GetModel(buffer); ASSERT_TRUE(output_model); } -TEST_F(QuantizeConvModelTest, SkipUnspecifiedLayer) { +TEST_P(QuantizeConvModelTest, SkipUnspecifiedLayer) { auto status = QuantizeModel(&builder_, &model_, TensorType_FLOAT32, TensorType_FLOAT32, - /*allow_float=*/true, {}, &error_reporter_); + /*allow_float=*/true, {}, tensor_type_, &error_reporter_); EXPECT_EQ(status, kTfLiteOk); ASSERT_EQ(model_.subgraphs.size(), readonly_model_->subgraphs()->size()); // The resulting model should be the same. @@ -123,9 +130,9 @@ TEST_F(QuantizeConvModelTest, SkipUnspecifiedLayer) { } } -TEST_F(QuantizeConvModelTest, TensorShapesAndStructureIsUnchanged) { - auto status = QuantizeModel(&builder_, &model_, TensorType_INT8, - TensorType_INT8, &error_reporter_); +TEST_P(QuantizeConvModelTest, TensorShapesAndStructureIsUnchanged) { + auto status = QuantizeModel(&builder_, &model_, tensor_type_, tensor_type_, + tensor_type_, &error_reporter_); EXPECT_EQ(status, kTfLiteOk); ASSERT_EQ(model_.subgraphs.size(), readonly_model_->subgraphs()->size()); for (size_t subgraph_idx = 0; subgraph_idx < model_.subgraphs.size(); @@ -148,9 +155,9 @@ TEST_F(QuantizeConvModelTest, TensorShapesAndStructureIsUnchanged) { EXPECT_EQ(model_.operator_codes[0]->version, 3); } -TEST_F(QuantizeConvModelTest, OperatorsAreUnchanged) { - auto status = QuantizeModel(&builder_, &model_, TensorType_INT8, - TensorType_INT8, &error_reporter_); +TEST_P(QuantizeConvModelTest, OperatorsAreUnchanged) { + auto status = QuantizeModel(&builder_, &model_, tensor_type_, tensor_type_, + tensor_type_, &error_reporter_); EXPECT_EQ(status, kTfLiteOk); ASSERT_EQ(model_.operator_codes.size(), readonly_model_->operator_codes()->size()); @@ -182,20 +189,28 @@ TEST_F(QuantizeConvModelTest, OperatorsAreUnchanged) { } } -TEST_F(QuantizeConvModelTest, GraphIsFullyQuantized) { - auto status = QuantizeModel(&builder_, &model_, TensorType_INT8, - TensorType_INT8, &error_reporter_); +TEST_P(QuantizeConvModelTest, GraphIsFullyQuantized) { + auto status = QuantizeModel(&builder_, &model_, tensor_type_, tensor_type_, + tensor_type_, &error_reporter_); EXPECT_EQ(status, kTfLiteOk); for (const auto& subgraph : model_.subgraphs) { for (const auto& tensor : subgraph->tensors) { - EXPECT_TRUE(tensor->type == TensorType_INT32 || - tensor->type == TensorType_INT8); + if (tensor_type_ == TensorType_INT8) { + EXPECT_TRUE(tensor->type == TensorType_INT32 || + tensor->type == TensorType_INT8); + } else if (tensor_type_ == TensorType_INT16) { + EXPECT_TRUE(tensor->type == TensorType_INT64 || // bias + tensor->type == TensorType_INT8 || // weights + tensor->type == TensorType_INT16); // activations + } } } } -TEST_F(QuantizeConvModelTest, FloatInputAndOutput) { - auto status = QuantizeModel(&builder_, &model_, &error_reporter_); +TEST_P(QuantizeConvModelTest, FloatInputAndOutput) { + auto status = + QuantizeModel(&builder_, &model_, TensorType_FLOAT32, TensorType_FLOAT32, + tensor_type_, &error_reporter_); EXPECT_EQ(status, kTfLiteOk); for (int32_t subgraph_idx = 0; subgraph_idx < model_.subgraphs.size(); @@ -234,22 +249,33 @@ TEST_F(QuantizeConvModelTest, FloatInputAndOutput) { EXPECT_EQ(subgraph->tensors[output_idx]->type, TensorType_FLOAT32); EXPECT_EQ(subgraph->tensors[output_idx]->name, "output"); // The original input and output has been renamed. - EXPECT_EQ(subgraph->tensors[quant_op->outputs[0]]->name, "input_int8"); - EXPECT_EQ(subgraph->tensors[dequant_op->inputs[0]]->name, "output_int8"); + std::string control_suffix = + (tensor_type_ == TensorType_INT16) ? "int16" : "int8"; + EXPECT_EQ(subgraph->tensors[quant_op->outputs[0]]->name, + "input_" + control_suffix); + EXPECT_EQ(subgraph->tensors[dequant_op->inputs[0]]->name, + "output_" + control_suffix); for (int tensor_idx = 0; tensor_idx < subgraph->tensors.size(); ++tensor_idx) { const auto& tensor = subgraph->tensors[tensor_idx]; if (input_idx != tensor_idx && output_idx != tensor_idx) { - EXPECT_TRUE(tensor->type == TensorType_INT32 || - tensor->type == TensorType_INT8); + if (tensor_type_ == TensorType_INT8) { + EXPECT_TRUE(tensor->type == TensorType_INT32 || + tensor->type == TensorType_INT8); + } else if (tensor_type_ == TensorType_INT16) { + EXPECT_TRUE(tensor->type == TensorType_INT64 || // bias + tensor->type == TensorType_INT8 || // weights + tensor->type == TensorType_INT16); // activations + } } } } } -TEST_F(QuantizeConvModelTest, Uint8InputAndOutput) { - auto status = QuantizeModel(&builder_, &model_, TensorType_UINT8, - TensorType_UINT8, &error_reporter_); +TEST_P(QuantizeConvModelTest, Uint8InputAndOutput) { + auto status = + QuantizeModel(&builder_, &model_, TensorType_UINT8, TensorType_UINT8, + TensorType_INT8, &error_reporter_); EXPECT_EQ(status, kTfLiteOk); for (int32_t subgraph_idx = 0; subgraph_idx < model_.subgraphs.size(); @@ -326,21 +352,25 @@ class QuantizeConvNoBiasModelTest : public QuantizeModelTest { }; TEST_F(QuantizeConvNoBiasModelTest, QuantizationSucceeds) { - auto status = QuantizeModel(&builder_, &model_, TensorType_INT8, - TensorType_INT8, &error_reporter_); + auto status = + QuantizeModel(&builder_, &model_, TensorType_INT8, TensorType_INT8, + TensorType_INT8, &error_reporter_); EXPECT_EQ(status, kTfLiteOk); const uint8_t* buffer = builder_.GetBufferPointer(); const Model* output_model = GetModel(buffer); ASSERT_TRUE(output_model); } -class QuantizeConcatModelTest : public QuantizeModelTest { +class QuantizeConcatModelTest : public QuantizeModelTest, + public testing::WithParamInterface { protected: QuantizeConcatModelTest() { input_model_ = ReadModel(internal::kFloatConcatMax5Max10Max10); readonly_model_ = input_model_->GetModel(); readonly_model_->UnPackTo(&model_); } + + TensorType tensor_type_; }; // There are two inputs for concat, "input0" and "input1". "input0" has [0, 5] @@ -352,9 +382,9 @@ class QuantizeConcatModelTest : public QuantizeModelTest { // input0 -> requant -> input0_requant \ // concat - output // input1 / -TEST_F(QuantizeConcatModelTest, AddRequantBeforeConcat) { - auto status = QuantizeModel(&builder_, &model_, TensorType_INT8, - TensorType_INT8, &error_reporter_); +TEST_P(QuantizeConcatModelTest, AddRequantBeforeConcat) { + auto status = QuantizeModel(&builder_, &model_, tensor_type_, tensor_type_, + tensor_type_, &error_reporter_); EXPECT_EQ(status, kTfLiteOk); // There is only one subgraph. @@ -373,32 +403,51 @@ TEST_F(QuantizeConcatModelTest, AddRequantBeforeConcat) { EXPECT_EQ(model_.operator_codes[concat->opcode_index]->builtin_code, BuiltinOperator_CONCATENATION); + auto zero_point_control = tensor_type_ == TensorType_INT8 ? -128 : 0; + /* + input0_scale_control + INT8: (5-0) / (2^8 - 1) + INT16: (5-0) / (2^16 / 2 - 1) + input1_scale + INT8: (10-0) / (2^8 - 1) + INT16: (10-0) / (2^16 / 2 - 1) + */ + auto input0_scale_control = + tensor_type_ == TensorType_INT8 ? 0.019607844 : 0.00015259254; + auto input1_scale = + tensor_type_ == TensorType_INT8 ? 0.039215688 : 0.00030518509; + // There should be 4 tensors: input0, input1, input0_requantized, output. EXPECT_EQ(subgraph->tensors.size(), 4); - EXPECT_EQ(subgraph->tensors[0]->type, TensorType_INT8); + EXPECT_EQ(subgraph->tensors[0]->type, tensor_type_); EXPECT_EQ(subgraph->tensors[0]->name, "input0"); EXPECT_EQ(subgraph->tensors[0]->quantization->scale.size(), 1); EXPECT_EQ(subgraph->tensors[0]->quantization->zero_point.size(), 1); - EXPECT_FLOAT_EQ(subgraph->tensors[0]->quantization->scale[0], 0.019607844); - EXPECT_FLOAT_EQ(subgraph->tensors[0]->quantization->zero_point[0], -128); - EXPECT_EQ(subgraph->tensors[1]->type, TensorType_INT8); + EXPECT_FLOAT_EQ(subgraph->tensors[0]->quantization->scale[0], + input0_scale_control); + EXPECT_FLOAT_EQ(subgraph->tensors[0]->quantization->zero_point[0], + zero_point_control); + EXPECT_EQ(subgraph->tensors[1]->type, tensor_type_); EXPECT_EQ(subgraph->tensors[1]->name, "input1"); EXPECT_EQ(subgraph->tensors[1]->quantization->scale.size(), 1); EXPECT_EQ(subgraph->tensors[1]->quantization->zero_point.size(), 1); - EXPECT_FLOAT_EQ(subgraph->tensors[1]->quantization->scale[0], 0.039215688); - EXPECT_FLOAT_EQ(subgraph->tensors[1]->quantization->zero_point[0], -128); - EXPECT_EQ(subgraph->tensors[2]->type, TensorType_INT8); + EXPECT_FLOAT_EQ(subgraph->tensors[1]->quantization->scale[0], input1_scale); + EXPECT_FLOAT_EQ(subgraph->tensors[1]->quantization->zero_point[0], + zero_point_control); + EXPECT_EQ(subgraph->tensors[2]->type, tensor_type_); EXPECT_EQ(subgraph->tensors[2]->name, "output"); EXPECT_EQ(subgraph->tensors[2]->quantization->scale.size(), 1); EXPECT_EQ(subgraph->tensors[2]->quantization->zero_point.size(), 1); - EXPECT_FLOAT_EQ(subgraph->tensors[2]->quantization->scale[0], 0.039215688); - EXPECT_FLOAT_EQ(subgraph->tensors[2]->quantization->zero_point[0], -128); - EXPECT_EQ(subgraph->tensors[3]->type, TensorType_INT8); + EXPECT_FLOAT_EQ(subgraph->tensors[2]->quantization->scale[0], input1_scale); + EXPECT_FLOAT_EQ(subgraph->tensors[2]->quantization->zero_point[0], + zero_point_control); + EXPECT_EQ(subgraph->tensors[3]->type, tensor_type_); EXPECT_EQ(subgraph->tensors[3]->name, "input0_requantized"); EXPECT_EQ(subgraph->tensors[3]->quantization->scale.size(), 1); EXPECT_EQ(subgraph->tensors[3]->quantization->zero_point.size(), 1); - EXPECT_FLOAT_EQ(subgraph->tensors[3]->quantization->scale[0], 0.039215688); - EXPECT_FLOAT_EQ(subgraph->tensors[3]->quantization->zero_point[0], -128); + EXPECT_FLOAT_EQ(subgraph->tensors[3]->quantization->scale[0], input1_scale); + EXPECT_FLOAT_EQ(subgraph->tensors[3]->quantization->zero_point[0], + zero_point_control); // The connection should be what is described in the comment. EXPECT_EQ(requant->inputs.size(), 1); @@ -419,7 +468,9 @@ TEST_F(QuantizeConcatModelTest, AddRequantBeforeConcat) { EXPECT_EQ(model_.operator_codes[1]->builtin_code, BuiltinOperator_QUANTIZE); EXPECT_EQ(model_.operator_codes[1]->version, 2); } - +INSTANTIATE_TEST_SUITE_P(QuantizeConcatModelInst, QuantizeConcatModelTest, + testing::ValuesIn({TensorType_INT8, + TensorType_INT16})); class QuantizeSplitModelTest : public QuantizeModelTest { protected: QuantizeSplitModelTest() { @@ -432,8 +483,9 @@ class QuantizeSplitModelTest : public QuantizeModelTest { // There are two outputs for split with different scales, the resulting model // should have the scales be hardcodes to the input scale value. TEST_F(QuantizeSplitModelTest, QuantizeSplit) { - auto status = QuantizeModel(&builder_, &model_, TensorType_INT8, - TensorType_INT8, &error_reporter_); + auto status = + QuantizeModel(&builder_, &model_, TensorType_INT8, TensorType_INT8, + TensorType_INT8, &error_reporter_); EXPECT_EQ(status, kTfLiteOk); // There is only one subgraph. @@ -496,8 +548,9 @@ class QuantizeConvModel1Test : public QuantizeModelTest { }; TEST_F(QuantizeConvModel1Test, VerifyConvQuantizationWithUnitScale) { - auto status = QuantizeModel(&builder_, &model_, TensorType_INT8, - TensorType_INT8, &error_reporter_); + auto status = + QuantizeModel(&builder_, &model_, TensorType_INT8, TensorType_INT8, + TensorType_INT8, &error_reporter_); EXPECT_EQ(status, kTfLiteOk); const auto& subgraph = model_.subgraphs[0]; @@ -587,18 +640,25 @@ TEST_F(QuantizeConvModel1Test, VerifyConvQuantizationWithUnitScale) { EXPECT_EQ(model_.operator_codes[0]->version, 3); } -class QuantizeConvModel2Test : public QuantizeModelTest { +class QuantizeConvModel2Test : public QuantizeModelTest, + public testing::WithParamInterface { protected: QuantizeConvModel2Test() { + tensor_type_ = GetParam(); input_model_ = ReadModel(internal::kConvModelWith0Plus10Weights); readonly_model_ = input_model_->GetModel(); readonly_model_->UnPackTo(&model_); } -}; -TEST_F(QuantizeConvModel2Test, VerifyConvQuantization) { - auto status = QuantizeModel(&builder_, &model_, TensorType_INT8, - TensorType_INT8, &error_reporter_); + TensorType tensor_type_; +}; +INSTANTIATE_TEST_SUITE_P(QuantizeConvModel2TestInst, QuantizeConvModel2Test, + testing::ValuesIn({TensorType_INT8, + TensorType_INT16})); + +TEST_P(QuantizeConvModel2Test, VerifyConvQuantization) { + auto status = QuantizeModel(&builder_, &model_, tensor_type_, tensor_type_, + tensor_type_, &error_reporter_); ASSERT_EQ(kTfLiteOk, status); const auto& subgraph = model_.subgraphs[0]; auto conv_op = subgraph->operators[0].get(); @@ -615,8 +675,10 @@ TEST_F(QuantizeConvModel2Test, VerifyConvQuantization) { const auto output_tensor = subgraph->tensors[conv_op->outputs[output_tensor_idx]].get(); - EXPECT_EQ(bias_tensor->type, TensorType_INT32); - EXPECT_EQ(input_tensor->type, TensorType_INT8); + EXPECT_EQ(bias_tensor->type, tensor_type_ == TensorType_INT8 + ? TensorType_INT32 + : TensorType_INT64); + EXPECT_EQ(input_tensor->type, tensor_type_); EXPECT_EQ(weights_tensor->type, TensorType_INT8); ASSERT_TRUE(weights_tensor->quantization); @@ -644,17 +706,28 @@ TEST_F(QuantizeConvModel2Test, VerifyConvQuantization) { } const auto bias_buffer = model_.buffers[bias_tensor->buffer].get(); - ASSERT_EQ(bias_buffer->data.size(), sizeof(int32_t) * bias_tensor->shape[0]); - const int32_t* bias_values = - reinterpret_cast(bias_buffer->data.data()); + auto control_size = tensor_type_ == TensorType_INT8 + ? sizeof(int32_t) * bias_tensor->shape[0] + : sizeof(int64_t) * bias_tensor->shape[0]; + + ASSERT_EQ(bias_buffer->data.size(), control_size); const auto original_bias_buffer = readonly_model_->buffers()->Get(bias_tensor->buffer); const float* bias_float_buffer = reinterpret_cast(original_bias_buffer->data()->data()); - for (size_t i = 0; i < out_channel_size; i++) { - auto dequantized_value = bias_values[i] * bias_scales[i]; - EXPECT_NEAR(dequantized_value, bias_float_buffer[i], bias_scales[i] / 2); + if (tensor_type_ == TensorType_INT8) { + int32_t* bias_values = reinterpret_cast(bias_buffer->data.data()); + for (size_t i = 0; i < out_channel_size; i++) { + auto dequantized_value = bias_values[i] * bias_scales[i]; + EXPECT_NEAR(dequantized_value, bias_float_buffer[i], bias_scales[i] / 2); + } + } else if (tensor_type_ == TensorType_INT16) { + int64_t* bias_values = reinterpret_cast(bias_buffer->data.data()); + for (size_t i = 0; i < out_channel_size; i++) { + auto dequantized_value = bias_values[i] * bias_scales[i]; + EXPECT_NEAR(dequantized_value, bias_float_buffer[i], bias_scales[i] / 2); + } } const auto weights_buffer = model_.buffers[weights_tensor->buffer].get(); @@ -695,8 +768,9 @@ class QuantizeSoftmaxTest : public QuantizeModelTest { }; TEST_F(QuantizeSoftmaxTest, VerifySoftmaxQuantization) { - auto status = QuantizeModel(&builder_, &model_, TensorType_INT8, - TensorType_INT8, &error_reporter_); + auto status = + QuantizeModel(&builder_, &model_, TensorType_INT8, TensorType_INT8, + TensorType_INT8, &error_reporter_); ASSERT_EQ(kTfLiteOk, status); const auto& subgraph = model_.subgraphs[0]; @@ -755,8 +829,9 @@ class QuantizeAvgPoolTest : public QuantizeModelTest { }; TEST_F(QuantizeAvgPoolTest, VerifyAvgPoolQuantization) { - auto status = QuantizeModel(&builder_, &model_, TensorType_INT8, - TensorType_INT8, &error_reporter_); + auto status = + QuantizeModel(&builder_, &model_, TensorType_INT8, TensorType_INT8, + TensorType_INT8, &error_reporter_); ASSERT_EQ(kTfLiteOk, status); const auto& subgraph = model_.subgraphs[0]; @@ -816,8 +891,9 @@ class QuantizeMultiInputAddWithReshapeTest : public QuantizeModelTest { }; TEST_F(QuantizeMultiInputAddWithReshapeTest, VerifyReshapeQuantization) { - auto status = QuantizeModel(&builder_, &model_, TensorType_INT8, - TensorType_INT8, &error_reporter_); + auto status = + QuantizeModel(&builder_, &model_, TensorType_INT8, TensorType_INT8, + TensorType_INT8, &error_reporter_); ASSERT_EQ(kTfLiteOk, status); // Verify Reshape is quantized. @@ -863,8 +939,9 @@ TEST_F(QuantizeMultiInputAddWithReshapeTest, VerifyReshapeQuantization) { } TEST_F(QuantizeMultiInputAddWithReshapeTest, VerifyAddQuantization) { - auto status = QuantizeModel(&builder_, &model_, TensorType_INT8, - TensorType_INT8, &error_reporter_); + auto status = + QuantizeModel(&builder_, &model_, TensorType_INT8, TensorType_INT8, + TensorType_INT8, &error_reporter_); ASSERT_EQ(kTfLiteOk, status); // Verify ADD is quantized. @@ -923,8 +1000,9 @@ class QuantizeConstInputTest : public QuantizeModelTest { }; TEST_F(QuantizeConstInputTest, VerifyConstOpInput) { - auto status = QuantizeModel(&builder_, &model_, TensorType_INT8, - TensorType_INT8, &error_reporter_); + auto status = + QuantizeModel(&builder_, &model_, TensorType_INT8, TensorType_INT8, + TensorType_INT8, &error_reporter_); ASSERT_EQ(kTfLiteOk, status); // Verify ConstOp is quantized. @@ -965,8 +1043,9 @@ class QuantizeArgMaxTest : public QuantizeModelTest { }; TEST_F(QuantizeArgMaxTest, VerifyArgMax) { - auto status = QuantizeModel(&builder_, &model_, TensorType_INT8, - TensorType_INT8, &error_reporter_); + auto status = + QuantizeModel(&builder_, &model_, TensorType_INT8, TensorType_INT8, + TensorType_INT8, &error_reporter_); ASSERT_EQ(kTfLiteOk, status); const auto& subgraph = model_.subgraphs[0]; @@ -1008,8 +1087,9 @@ class QuantizeLSTMTest : public QuantizeModelTest { TEST_F(QuantizeLSTMTest, VerifyLSTM) { // Quantize model. - auto status = QuantizeModel(&builder_, &model_, TensorType_FLOAT32, - TensorType_FLOAT32, &error_reporter_); + auto status = + QuantizeModel(&builder_, &model_, TensorType_FLOAT32, TensorType_FLOAT32, + TensorType_INT8, &error_reporter_); ASSERT_EQ(kTfLiteOk, status); // Read expected model. @@ -1067,8 +1147,9 @@ class QuantizeLSTM2Test : public QuantizeModelTest { TEST_F(QuantizeLSTM2Test, VerifyLSTM) { // Quantize model. - auto status = QuantizeModel(&builder_, &model_, TensorType_FLOAT32, - TensorType_FLOAT32, &error_reporter_); + auto status = + QuantizeModel(&builder_, &model_, TensorType_FLOAT32, TensorType_FLOAT32, + TensorType_INT8, &error_reporter_); ASSERT_EQ(kTfLiteOk, status); // Read expected model. @@ -1126,8 +1207,9 @@ class QuantizeSVDFTest : public QuantizeModelTest { TEST_F(QuantizeSVDFTest, VerifySVDF) { // Quantize model. - auto status = QuantizeModel(&builder_, &model_, TensorType_INT8, - TensorType_INT8, &error_reporter_); + auto status = + QuantizeModel(&builder_, &model_, TensorType_INT8, TensorType_INT8, + TensorType_INT8, &error_reporter_); ASSERT_EQ(kTfLiteOk, status); // Read expected model. @@ -1184,8 +1266,9 @@ class QuantizeFCTest : public QuantizeModelTest { }; TEST_F(QuantizeFCTest, VerifyFC) { - auto status = QuantizeModel(&builder_, &model_, TensorType_INT8, - TensorType_INT8, &error_reporter_); + auto status = + QuantizeModel(&builder_, &model_, TensorType_INT8, TensorType_INT8, + TensorType_INT8, &error_reporter_); ASSERT_EQ(kTfLiteOk, status); const auto& subgraph = model_.subgraphs[0]; @@ -1236,7 +1319,7 @@ class QuantizeCustomOpTest : public QuantizeModelTest { TEST_F(QuantizeCustomOpTest, VerifyMixedQuantization) { auto status = QuantizeModel(&builder_, &model_, TensorType_INT8, TensorType_INT8, - /*allow_float=*/true, &error_reporter_); + /*allow_float=*/true, TensorType_INT8, &error_reporter_); ASSERT_EQ(kTfLiteOk, status); const auto& subgraph = model_.subgraphs[0]; auto float_graph = readonly_model_->subgraphs()->Get(0); @@ -1270,7 +1353,8 @@ class QuantizePackTest : public QuantizeModelTest { }; TEST_F(QuantizePackTest, VerifyPack) { - auto status = QuantizeModel(&builder_, &model_, &error_reporter_); + auto status = + QuantizeModel(&builder_, &model_, TensorType_INT8, &error_reporter_); ASSERT_EQ(kTfLiteOk, status); @@ -1334,7 +1418,8 @@ class QuantizeMinimumMaximumTest }; TEST_P(QuantizeMinimumMaximumTest, VerifyMinimumMaximum) { - auto status = QuantizeModel(&builder_, &model_, &error_reporter_); + auto status = + QuantizeModel(&builder_, &model_, TensorType_INT8, &error_reporter_); ASSERT_EQ(kTfLiteOk, status); const auto& subgraph = model_.subgraphs[0]; @@ -1415,7 +1500,8 @@ class QuantizeUnpackTest : public QuantizeModelTest { } }; TEST_F(QuantizeUnpackTest, VerifyUnpack) { - auto status = QuantizeModel(&builder_, &model_, &error_reporter_); + auto status = + QuantizeModel(&builder_, &model_, TensorType_INT8, &error_reporter_); ASSERT_EQ(kTfLiteOk, status);