Added an option TFLITE_BUILTINS_ACTIVATIONS_INT16_WEIGHTS_INT8 to

enable sym quantization with activations in 16-bit and weigths in 8-bit.
This commit is contained in:
Elena Zhelezina 2020-01-21 13:18:55 +00:00
parent 2e98e89091
commit a7899d7544
17 changed files with 477 additions and 202 deletions

View File

@ -93,6 +93,12 @@ class OpsSet(enum.Enum):
# quantized implementations.
TFLITE_BUILTINS_INT8 = "TFLITE_BUILTINS_INT8"
# Convert model using only TensorFlow Lite operations with quantized int8 weights
# and int16 activations.
# Specifying this will throw an error for operations that do not yet have
# quantized implementations.
TFLITE_BUILTINS_ACTIVATIONS_INT16_WEIGHTS_INT8 = "TFLITE_BUILTINS_ACTIVATIONS_INT16_WEIGHTS_INT8"
def __str__(self):
return self.value

View File

@ -224,6 +224,10 @@ class TFLiteConverterBase(object):
self.target_spec.supported_ops) or
self._smallest_supported_type() == constants.INT8)
def _is_int16x8_target_required(self):
return (set([OpsSet.TFLITE_BUILTINS_ACTIVATIONS_INT16_WEIGHTS_INT8]) ==
set(self.target_spec.supported_ops))
def _smallest_supported_type(self):
if self.target_spec.supported_types:
return min(self.target_spec.supported_types, key=lambda x: x.size)
@ -238,7 +242,9 @@ class TFLiteConverterBase(object):
]))
def _is_post_training_optimize(self):
return self._is_int8_target_required() or self._any_optimization_enabled()
return self._is_int8_target_required() or \
self._is_int16x8_target_required() or \
self._any_optimization_enabled()
def _is_int8_weight_only_quantize(self):
return (self._is_post_training_optimize() and
@ -255,11 +261,12 @@ class TFLiteConverterBase(object):
def _calibrate_quantize_model(self, result, inference_input_type,
inference_output_type, enable_mlir_quantizer):
allow_float = not self._is_int8_target_required()
allow_float = not self._is_int8_target_required() and not self._is_int16x8_target_required()
calibrate_quantize = _calibrator.Calibrator(result)
activations_type = constants.INT16 if self._is_int16x8_target_required() else constants.INT8
return calibrate_quantize.calibrate_and_quantize(
self.representative_dataset.input_gen, inference_input_type,
inference_output_type, allow_float, enable_mlir_quantizer)
inference_output_type, allow_float, activations_type, enable_mlir_quantizer)
def _get_base_converter_args(self):
"""Returns the base converter args.

View File

@ -30,6 +30,7 @@ INT64 = dtypes.int64
STRING = dtypes.string
QUANTIZED_UINT8 = dtypes.uint8
INT8 = dtypes.int8
INT16 = dtypes.int16
COMPLEX64 = dtypes.complex64
TENSORFLOW_GRAPHDEF = _toco_flags_pb2.TENSORFLOW_GRAPHDEF
TFLITE = _toco_flags_pb2.TFLITE
@ -43,6 +44,7 @@ _tf_export(v1=["lite.constants.STRING"]).export_constant(__name__, "STRING")
_tf_export(v1=["lite.constants.QUANTIZED_UINT8"]).export_constant(
__name__, "QUANTIZED_UINT8")
_tf_export(v1=["lite.constants.INT8"]).export_constant(__name__, "INT8")
_tf_export(v1=["lite.constants.INT16"]).export_constant(__name__, "INT16")
_tf_export(v1=["lite.constants.TFLITE"]).export_constant(__name__, "TFLITE")
_tf_export(v1=["lite.constants.GRAPHVIZ_DOT"]).export_constant(
__name__, "GRAPHVIZ_DOT")
@ -62,6 +64,7 @@ _allowed_symbols = [
"STRING",
"QUANTIZED_UINT8",
"INT8",
"INT16",
"COMPLEX64",
"TENSORFLOW_GRAPHDEF",
"TFLITE",

View File

@ -769,9 +769,13 @@ class FromSessionTest(TestModels, parameterized.TestCase):
self.assertLess(len(quantized_tflite), len(float_tflite))
@parameterized.named_parameters(
('EnableMlirConverter', True), # enable mlir
('DisableMlirConverter', False)) # disable mlir
def testCalibrateAndQuantizeBuiltinInt8(self, enable_mlir):
# Quantize model to Int8: with enable mlir
('UseTfliteBuiltinsIntEnableMLIR', [lite.OpsSet.TFLITE_BUILTINS_INT8], True),
# Quantize model to Int8: with disable mlir
('UseTfliteBuiltinsIntDisableMLIR', [lite.OpsSet.TFLITE_BUILTINS_INT8], False),
# Quantize model to Int16: with disable mlir
('UseTfliteBuiltinsInt16DisableMLIR', [lite.OpsSet.TFLITE_BUILTINS_ACTIVATIONS_INT16_WEIGHTS_INT8], False))
def testCalibrateAndQuantizeBuiltinInt(self, supported_ops, enable_mlir):
with ops.Graph().as_default():
inp, output, calibration_gen = self._getCalibrationQuantizeModel()
sess = session.Session()
@ -787,9 +791,7 @@ class FromSessionTest(TestModels, parameterized.TestCase):
quantized_converter = lite.TFLiteConverter.from_session(
sess, [inp], [output])
quantized_converter.experimental_new_converter = enable_mlir
quantized_converter.target_spec.supported_ops = [
lite.OpsSet.TFLITE_BUILTINS_INT8
]
quantized_converter.target_spec.supported_ops = supported_ops
quantized_converter.representative_dataset = calibration_gen
quantized_tflite = quantized_converter.convert()
self.assertTrue(quantized_tflite)

View File

@ -204,6 +204,7 @@ PyObject* CalibrationWrapper::SetTensor(int index, PyObject* value) {
PyObject* CalibrationWrapper::QuantizeModel(int input_py_type,
int output_py_type,
bool allow_float,
int activations_py_type,
bool enable_mlir_quantizer) {
if (NoOpModel(*model_)) {
return python_utils::ConvertToPyString(model_str_->data(),
@ -212,6 +213,9 @@ PyObject* CalibrationWrapper::QuantizeModel(int input_py_type,
TfLiteType input_type = python_utils::TfLiteTypeFromPyType(input_py_type);
TfLiteType output_type = python_utils::TfLiteTypeFromPyType(output_py_type);
TfLiteType activations_type =
python_utils::TfLiteTypeFromPyType(activations_py_type);
if (input_type == kTfLiteNoType || output_type == kTfLiteNoType) {
PyErr_SetString(PyExc_ValueError,
"Input/output type cannot be kTfLiteNoType");
@ -230,7 +234,7 @@ PyObject* CalibrationWrapper::QuantizeModel(int input_py_type,
status = tflite::optimize::QuantizeModel(
&builder, tflite_model.get(), TfLiteTypeToSchemaType(input_type),
TfLiteTypeToSchemaType(output_type), allow_float,
error_reporter_.get());
TfLiteTypeToSchemaType(activations_type), error_reporter_.get());
}
if (status != kTfLiteOk) {
@ -262,7 +266,7 @@ PyObject* CalibrationWrapper::QuantizeModel(int input_py_type,
auto status = tflite::optimize::QuantizeModel(
&builder, tflite_model.get(), TfLiteTypeToSchemaType(input_type),
TfLiteTypeToSchemaType(output_type), allow_float, {op_name},
error_reporter_.get());
TensorType_INT8, error_reporter_.get());
if (status != kTfLiteOk) {
error_reporter_->exception();
return nullptr;

View File

@ -60,7 +60,8 @@ class CalibrationWrapper {
PyObject* FeedTensor(PyObject* input_value);
PyObject* QuantizeModel(int input_py_type, int output_py_type,
bool allow_float, bool enable_mlir_quantizer = false);
bool allow_float, int activations_py_type,
bool enable_mlir_quantizer = false);
// Allows quantizing only the operator that produces the tensor with name
// operator_output_name. (This can be used to help debug.).

View File

@ -19,6 +19,7 @@ from __future__ import print_function
import numpy as np
from tensorflow.python.util.lazy_loader import LazyLoader
from tensorflow.lite.python import lite_constants
# Lazy load since some of the performance benchmark skylark rules
# break dependencies. Must use double quotes to match code internal rewrite
@ -55,7 +56,8 @@ class Calibrator(object):
raise ValueError("Failed to parse the model.")
def calibrate_and_quantize(self, dataset_gen, input_type, output_type,
allow_float, enable_mlir_quantizer=False):
allow_float, activations_type = lite_constants.INT8,
enable_mlir_quantizer=False):
"""Calibrates the model with specified generator and then quantizes it.
Returns:
@ -69,6 +71,7 @@ class Calibrator(object):
computation, useful when targeting an integer-only backend.
If False, an error will be thrown if an operation cannot be
quantized, otherwise the model will fallback to float ops.
activations_type: A tf.dtype representing the desired type for activations
enable_mlir_quantizer: A boolean. True if wants to use mlir quantizer to
quantize the calibrated model.
"""
@ -78,6 +81,7 @@ class Calibrator(object):
return self._calibrator.QuantizeModel(
np.dtype(input_type.as_numpy_dtype()).num,
np.dtype(output_type.as_numpy_dtype()).num, allow_float,
np.dtype(activations_type.as_numpy_dtype()).num,
enable_mlir_quantizer)
def calibrate_and_quantize_single(self, dataset_gen, input_type, output_type,

View File

@ -33,9 +33,13 @@ from tensorflow.python.platform import test
class CalibratorTest(test_util.TensorFlowTestCase, parameterized.TestCase):
@parameterized.named_parameters(
('EnableMlirQuantizer', True), # enable mlir quantizer
('DisableMlirQuantizer', False)) # disable mlir quantizer
def test_calibration_with_quantization(self, enable_mlir):
# Activation type Int8 - enable mlir quantizer
('UseActivationTypeInt8EnabledMlir', constants.INT8, True),
# Activation type Int8 - disable mlir quantizer
('UseActivationTypeInt8DisabledMlir', constants.INT8, False),
# Activation type Int16
('UseActivationTypeInt16', constants.INT16, False))
def test_calibration_with_quantization(self, activations_type, enable_mlir):
model_path = resource_loader.get_path_to_datafile(
'test_data/mobilenet_like_model.bin')
float_model = open(model_path, 'rb').read()
@ -49,13 +53,18 @@ class CalibratorTest(test_util.TensorFlowTestCase, parameterized.TestCase):
quantized_model = quantizer.calibrate_and_quantize(input_gen,
constants.FLOAT,
constants.FLOAT, False,
activations_type,
enable_mlir)
self.assertIsNotNone(quantized_model)
@parameterized.named_parameters(
('EnableMlirQuantizer', True), # enable mlir quantizer
('DisableMlirQuantizer', False)) # disable mlir quantizer
def test_calibration_with_quantization_allow_float(self, enable_mlir):
# Activation type Int8 - enable mlir quantizer
('UseActivationTypeInt8EnabledMlir', constants.INT8, True),
# Activation type Int8 - disable mlir quantizer
('UseActivationTypeInt8DisableMlir', constants.INT8, False),
# Activation type Int16 - disable mlir quantizer
('UseActivationTypeInt16', constants.INT16, False))
def test_calibration_with_quantization_allow_float(self, activations_type, enable_mlir):
model_path = resource_loader.get_path_to_datafile(
'test_data/mobilenet_like_model.bin')
float_model = open(model_path, 'rb').read()
@ -69,6 +78,7 @@ class CalibratorTest(test_util.TensorFlowTestCase, parameterized.TestCase):
quantized_model = quantizer.calibrate_and_quantize(input_gen,
constants.FLOAT,
constants.FLOAT, True,
activations_type,
enable_mlir)
self.assertIsNotNone(quantized_model)
@ -88,9 +98,13 @@ class CalibratorTest(test_util.TensorFlowTestCase, parameterized.TestCase):
self.assertIsNotNone(quantized_model)
@parameterized.named_parameters(
('EnableMlirQuantizer', True), # enable mlir quantizer
('DisableMlirQuantizer', False)) # disable mlir quantizer
def test_calibration_with_quantization_multiple_inputs(self, enable_mlir):
# Activation type Int8 - enable mlir quantizer
('UseActivationTypeInt8 - EnableMlirQuantizer', constants.INT8, True),
# Activation type Int8 - disable mlir quantizer
('UseActivationTypeInt8 - DisableMlirQuantizer', constants.INT8, False),
# Activation type Int16 - disable mlir quantizer
('UseActivationTypeInt16 - DisableEnableMlirQuantizer', constants.INT16, False))
def test_calibration_with_quantization_multiple_inputs(self, activations_type, enable_mlir):
# Load multi add model from test data.
# This model has 4 inputs of size (1, 8, 8, 3).
model_path = resource_loader.get_path_to_datafile(
@ -106,6 +120,7 @@ class CalibratorTest(test_util.TensorFlowTestCase, parameterized.TestCase):
quantized_model = quantizer.calibrate_and_quantize(input_gen,
constants.FLOAT,
constants.FLOAT, False,
activations_type,
enable_mlir)
self.assertIsNotNone(quantized_model)
@ -148,7 +163,8 @@ class CalibratorTest(test_util.TensorFlowTestCase, parameterized.TestCase):
with self.assertRaisesRegex(ValueError, 'Size mismatch'):
quantizer.calibrate_and_quantize(input_gen, constants.FLOAT,
constants.FLOAT, False, enable_mlir)
constants.FLOAT, False,
enable_mlir)
@parameterized.named_parameters(
('EnableMlirQuantizer', True), # enable mlir quantizer
@ -166,7 +182,8 @@ class CalibratorTest(test_util.TensorFlowTestCase, parameterized.TestCase):
with self.assertRaises(ValueError):
quantizer.calibrate_and_quantize(input_gen, constants.FLOAT,
constants.FLOAT, False, enable_mlir)
constants.FLOAT, False,
constants.INT8, enable_mlir)
if __name__ == '__main__':

View File

@ -64,6 +64,7 @@ OperatorProperty GetOperatorProperty(const ModelT* model, int subgraph_index,
property.inputs = {{0, {}}, {1, {}}};
property.outputs = {{0, {}}};
property.version = 2;
property.quantize_input_as_activations = true;
break;
case BuiltinOperator_ARG_MAX:
property.inputs = {{0, {}}};
@ -176,7 +177,7 @@ OperatorProperty GetOperatorProperty(const ModelT* model, int subgraph_index,
// LogSoftmax requires output with 16/256 as scale and 127 as zero point.
TensorProperty tensor_property;
tensor_property.restriction = true;
tensor_property.restricted_value = {16.0 / 256.0, 127};
tensor_property.restricted_value_int8 = {16.0 / 256.0, 127};
property.outputs = {{0, tensor_property}};
property.version = 2;
break;
@ -186,7 +187,8 @@ OperatorProperty GetOperatorProperty(const ModelT* model, int subgraph_index,
// Logistic requires output with 1/256 as scale and -128 as zero point.
TensorProperty tensor_property;
tensor_property.restriction = true;
tensor_property.restricted_value = {1 / 256.0, -128};
tensor_property.restricted_value_int8 = {1 / 256.0, -128};
tensor_property.restricted_value_int16 = {1 / 32768.0, 0};
property.outputs = {{0, tensor_property}};
property.version = 2;
break;
@ -741,7 +743,7 @@ OperatorProperty GetOperatorProperty(const ModelT* model, int subgraph_index,
// L2 Norm requires output with 1/128 as scale and 0 as zero point.
TensorProperty tensor_property;
tensor_property.restriction = true;
tensor_property.restricted_value = {1 / 128.0, 0};
tensor_property.restricted_value_int8 = {1 / 128.0, 0};
property.outputs = {{0, tensor_property}};
property.version = 2;
break;
@ -756,6 +758,7 @@ OperatorProperty GetOperatorProperty(const ModelT* model, int subgraph_index,
property.arbitrary_inputs = true;
property.outputs = {{0, {}}};
property.restrict_same_input_output_scale = true;
property.quantize_input_as_activations = true;
property.version = 2;
break;
case BuiltinOperator_MEAN:
@ -767,6 +770,7 @@ OperatorProperty GetOperatorProperty(const ModelT* model, int subgraph_index,
property.arbitrary_inputs = true;
property.outputs = {{0, {}}};
property.restrict_same_input_output_scale = true;
property.quantize_input_as_activations = true;
property.version = 2;
break;
case BuiltinOperator_MUL:
@ -778,6 +782,7 @@ OperatorProperty GetOperatorProperty(const ModelT* model, int subgraph_index,
property.arbitrary_inputs = true;
property.outputs = {{0, {}}};
property.restrict_same_input_output_scale = true;
property.restrict_same_input_output_scale = true;
property.version = 2;
break;
case BuiltinOperator_PAD:
@ -840,7 +845,8 @@ OperatorProperty GetOperatorProperty(const ModelT* model, int subgraph_index,
// Softmax requires output with 1/256 as scale and -128 as zero point.
TensorProperty tensor_property;
tensor_property.restriction = true;
tensor_property.restricted_value = {1 / 256.0, -128};
tensor_property.restricted_value_int8 = {1 / 256.0, -128};
tensor_property.restricted_value_int16 = {1 / 32768.0, 0};
property.outputs = {{0, tensor_property}};
property.version = 2;
break;
@ -866,7 +872,8 @@ OperatorProperty GetOperatorProperty(const ModelT* model, int subgraph_index,
// Tanh requires output with 1/128 as scale and 0 as zero point.
TensorProperty tensor_property;
tensor_property.restriction = true;
tensor_property.restricted_value = {1 / 128.0, 0};
tensor_property.restricted_value_int8 = {1 / 128.0, 0};
tensor_property.restricted_value_int16 = {1 / 32768.0, 0};
property.outputs = {{0, tensor_property}};
property.version = 2;
break;

View File

@ -43,7 +43,8 @@ struct TensorProperty {
// Constraints.
bool restriction = false;
// scale/zero_point hardcoded.
std::pair<float, int> restricted_value = {0.0, 0};
std::pair<float, int> restricted_value_int8 = {0.0, 0};
std::pair<float, int> restricted_value_int16 = {0.0, 0};
// Use derived scale.
bool use_derived_scale = false;
@ -93,6 +94,13 @@ struct OperatorProperty {
// Op version.
int version = 1;
// When we quantize activations into 16 bit and weights into 8 bit,
// we want to quantize all inputs, including constant tensors,
// for the operators like Add, Mul into 16-bit as well. The constant
// inputs are quantized as weights and this variable indicates
// that we want to do quantizations of these tensors as activations.
bool quantize_input_as_activations = false;
};
OperatorProperty GetOperatorProperty(const ModelT* model, int subgraph_index,

View File

@ -20,7 +20,6 @@ limitations under the License.
#include <string>
#include "absl/memory/memory.h"
#include "third_party/eigen3/Eigen/Core"
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/core/api/error_reporter.h"
#include "tensorflow/lite/kernels/internal/quantization_util.h"
@ -30,6 +29,7 @@ limitations under the License.
#include "tensorflow/lite/minimal_logging.h"
#include "tensorflow/lite/schema/schema_generated.h"
#include "tensorflow/lite/tools/optimize/model_utils.h"
#include "third_party/eigen3/Eigen/Core"
namespace tflite {
namespace optimize {
@ -85,6 +85,46 @@ void GetAsymmetricQuantizationParams(
quantization_params->zero_point = std::vector<int64_t>(1, zero_point);
}
void GetSymmetricQuantizationParams(
float min, float max, const int half_quant_range,
QuantizationParametersT* quantization_params) {
// Adjust the boundaries to guarantee 0 is included.
min = std::min(min, 0.0f);
max = std::max(max, 0.0f);
const float scale = std::max(std::abs(max), std::abs(min)) / half_quant_range;
int64_t zero_point = 0;
quantization_params->min = std::vector<float>(1, min);
quantization_params->max = std::vector<float>(1, max);
quantization_params->scale = std::vector<float>(1, scale);
quantization_params->zero_point = std::vector<int64_t>(1, 0);
}
TfLiteStatus GetQuantizationParams(TensorT* tensor, TensorType activations_type,
QuantizationParametersT* quantization_params,
ErrorReporter* error_reporter) {
if (activations_type == TensorType_INT8) {
GetAsymmetricQuantizationParams(
tensor->quantization->min[0], tensor->quantization->max[0],
std::numeric_limits<int8_t>::min(), std::numeric_limits<int8_t>::max(),
quantization_params);
} else if (activations_type == TensorType_INT16) {
float range = std::max(std::abs(tensor->quantization->min[0]),
std::abs(tensor->quantization->max[0]));
const float quantized_range = 32767.0;
const float scale = range / quantized_range;
quantization_params->min = std::vector<float>(1, -range);
quantization_params->max = std::vector<float>(1, range);
quantization_params->scale = std::vector<float>(1, scale);
quantization_params->zero_point = std::vector<int64_t>(1, 0);
} else {
error_reporter->Report(
"Unsupported activation type for quantize-activation: %s",
activations_type);
return kTfLiteError;
}
return kTfLiteOk;
}
// Set the max and min quantization parameter for a single tensor given its
// values.
void FillSingleMinMax(const float* const input, const uint64_t input_size,
@ -536,6 +576,7 @@ TfLiteStatus SymmetricQuantizeTensorPerChannel(ModelT* model, TensorT* tensor,
model, tensor, error_reporter);
}
template <class BiasType>
TfLiteStatus SymmetricPerLayerBiasQuantize(ModelT* model, TensorT* tensor,
float scaling_factor,
ErrorReporter* error_reporter) {
@ -548,25 +589,38 @@ TfLiteStatus SymmetricPerLayerBiasQuantize(ModelT* model, TensorT* tensor,
uint64_t num_elements;
TF_LITE_ENSURE_STATUS(NumElements(*tensor, &num_elements));
std::vector<int32_t> final_buffer(num_elements);
const int32_t kScale = std::numeric_limits<int32_t>::max();
std::vector<BiasType> final_buffer(num_elements);
const BiasType kScale = std::numeric_limits<BiasType>::max();
for (size_t i = 0; i < num_elements; i++) {
const int32_t quantized_value = tflite::SafeCast<int32_t>(
const BiasType quantized_value = tflite::SafeCast<BiasType>(
TfLiteRound(float_data[i] * scaling_factor_inv));
final_buffer[i] = std::min(kScale, std::max(-kScale, quantized_value));
}
// Set the buffers and output type.
uint8_t* uint8_buffer = reinterpret_cast<uint8_t*>(final_buffer.data());
size_t buffer_size = num_elements * sizeof(int32_t);
size_t buffer_size = num_elements * sizeof(BiasType);
std::vector<float> scales(1, scaling_factor);
std::vector<int64_t> zero_points(1, 0);
auto output_type = std::is_same<BiasType, std::int32_t>::value
? TensorType_INT32
: TensorType_INT64;
return AddQuantizationParams(scales, zero_points, 0, uint8_buffer,
buffer_size, TensorType_INT32, model, tensor,
buffer_size, output_type, model, tensor,
error_reporter);
}
template TfLiteStatus SymmetricPerLayerBiasQuantize<std::int32_t>(
ModelT* model, TensorT* tensor, float scaling_factor,
ErrorReporter* error_reporter);
template TfLiteStatus SymmetricPerLayerBiasQuantize<std::int64_t>(
ModelT* model, TensorT* tensor, float scaling_factor,
ErrorReporter* error_reporter);
template <class BiasType>
TfLiteStatus SymmetricPerChannelBiasQuantize(ModelT* model, TensorT* tensor,
float input_scale,
const float* weight_scales,
@ -583,14 +637,14 @@ TfLiteStatus SymmetricPerChannelBiasQuantize(ModelT* model, TensorT* tensor,
uint64_t num_elements;
TF_LITE_ENSURE_STATUS(NumElements(*tensor, &num_elements));
std::vector<int32_t> final_buffer(num_elements);
const int32_t kScale = std::numeric_limits<int32_t>::max();
std::vector<BiasType> final_buffer(num_elements);
const BiasType kScale = std::numeric_limits<BiasType>::max();
for (int32_t channel_idx = 0; channel_idx < number_of_dimension;
channel_idx++) {
float scaling_factor = scales[channel_idx];
float scaling_factor_inv = (scaling_factor == 0) ? 0 : 1.0 / scaling_factor;
const int32_t quantized_value = tflite::SafeCast<int32_t>(
const BiasType quantized_value = tflite::SafeCast<BiasType>(
TfLiteRound(float_data[channel_idx] * scaling_factor_inv));
final_buffer[channel_idx] =
std::min(kScale, std::max(-kScale, quantized_value));
@ -598,12 +652,26 @@ TfLiteStatus SymmetricPerChannelBiasQuantize(ModelT* model, TensorT* tensor,
// Set the buffers and output type.
uint8_t* uint8_buffer = reinterpret_cast<uint8_t*>(final_buffer.data());
size_t buffer_size = num_elements * sizeof(int32_t);
size_t buffer_size = num_elements * sizeof(BiasType);
std::vector<int64_t> zero_point(scales.size(), 0);
auto output_type = std::is_same<BiasType, std::int32_t>::value
? TensorType_INT32
: TensorType_INT64;
return AddQuantizationParams(scales, zero_point, 0, uint8_buffer, buffer_size,
TensorType_INT32, model, tensor, error_reporter);
output_type, model, tensor, error_reporter);
}
template TfLiteStatus SymmetricPerChannelBiasQuantize<std::int64_t>(
ModelT* model, TensorT* tensor, float input_scale,
const float* weight_scales, int number_of_dimension,
ErrorReporter* error_reporter);
template TfLiteStatus SymmetricPerChannelBiasQuantize<std::int32_t>(
ModelT* model, TensorT* tensor, float input_scale,
const float* weight_scales, int number_of_dimension,
ErrorReporter* error_reporter);
TfLiteStatus QuantizeWeight(ModelT* model, TensorT* tensor, bool per_channel,
int per_axis_index, ErrorReporter* error_reporter) {
// TODO(suharshs): Currently we conflate quantizing weights and constants. Its
@ -645,12 +713,12 @@ float GetEffectiveScale(ModelT* model, SubGraphT* subgraph, int op_idx,
return scale;
}
void QuantizeActivation(TensorT* tensor) {
GetAsymmetricQuantizationParams(
tensor->quantization->min[0], tensor->quantization->max[0],
std::numeric_limits<int8_t>::min(), std::numeric_limits<int8_t>::max(),
tensor->quantization.get());
tensor->type = TensorType_INT8;
TfLiteStatus QuantizeActivation(TensorT* tensor, TensorType activations_type,
ErrorReporter* error_reporter) {
TF_LITE_ENSURE_STATUS(GetQuantizationParams(
tensor, activations_type, tensor->quantization.get(), error_reporter));
tensor->type = activations_type;
return kTfLiteOk;
}
TfLiteStatus QuantizeActivationToInt16(TensorT* tensor, float scale) {

View File

@ -113,12 +113,14 @@ TfLiteStatus SymmetricQuantizeFloatsToInt16(ModelT* model, TensorT* tensor,
ErrorReporter* error_reporter);
// Symmetrically quantized the bias for per-layer ops (i.e. FullyConnected).
template <typename BiasType>
TfLiteStatus SymmetricPerLayerBiasQuantize(ModelT* model, TensorT* tensor,
float scaling_factor,
ErrorReporter* error_reporter);
// Symmetrically quantizes the bias for ops like Conv and DepthwiseConv.
// The scale of bias if weight_per_channel_scale[channel] * input_scale.
template <typename BiasType>
TfLiteStatus SymmetricPerChannelBiasQuantize(ModelT* model, TensorT* tensor,
float input_scale,
const float* weight_scales,
@ -135,8 +137,14 @@ float GetEffectiveScale(ModelT* model, SubGraphT* subgraph, int op_idx,
std::vector<int> intermediate_index,
std::vector<float> factors);
// Return quantization parameters depending on activations type.
TfLiteStatus GetQuantizationParams(TensorT* tensor, TensorType activations_type,
QuantizationParametersT* quantization_params,
ErrorReporter* error_reporter);
// Quantize activation.
void QuantizeActivation(TensorT* tensor);
TfLiteStatus QuantizeActivation(TensorT* tensor, TensorType activations_type,
ErrorReporter* error_reporter);
// Quantize activation to 16bit.
TfLiteStatus QuantizeActivationToInt16(TensorT* tensor, float scale);

View File

@ -701,7 +701,7 @@ TEST_F(QuantizationUtilsTest, SymmetricPerLayerBiasQuantize) {
model->buffers.push_back(std::move(buffer));
// Call and verify.
EXPECT_EQ(SymmetricPerLayerBiasQuantize(
EXPECT_EQ(SymmetricPerLayerBiasQuantize<int32_t>(
model.get(), model->subgraphs[0]->tensors[0].get(),
input_scale * weight_scale, &error_reporter_),
kTfLiteOk);
@ -759,7 +759,7 @@ TEST_F(QuantizationUtilsTest, SymmetricPerChannelBiasQuantize) {
model->buffers.push_back(std::move(buffer));
// Call and verify.
EXPECT_EQ(SymmetricPerChannelBiasQuantize(
EXPECT_EQ(SymmetricPerChannelBiasQuantize<int32_t>(
model.get(), model->subgraphs[0]->tensors[0].get(), input_scale,
weight_scales.data(), 2, &error_reporter_),
kTfLiteOk);

View File

@ -42,7 +42,9 @@ bool CreateQuantizedModel(const std::string& path) {
tflite::StderrReporter error_reporter;
if (tflite::optimize::QuantizeModel(
&builder, &model, tflite::TensorType_FLOAT32,
tflite::TensorType_FLOAT32, &error_reporter) != kTfLiteOk) {
tflite::TensorType_FLOAT32,
// TODO: Pass required activation type if needed
tflite::TensorType_INT8, &error_reporter) != kTfLiteOk) {
return false;
}
return WriteFile(path, builder.GetBufferPointer(), builder.GetSize());

View File

@ -64,6 +64,7 @@ operator_property::OperatorProperty GetOperatorProperty(
TfLiteStatus QuantizeBias(ModelT* model, const TensorT* input_tensor,
const TensorT* weight_tensor, TensorT* bias_tensor,
bool is_per_channel, int channel_dim_index,
const TensorType& activations_type,
ErrorReporter* error_reporter) {
if (bias_tensor->shape.size() != 1) {
error_reporter->Report("Expected bias tensor shape to be 1.");
@ -92,9 +93,15 @@ TfLiteStatus QuantizeBias(ModelT* model, const TensorT* input_tensor,
weight_scales.size());
return kTfLiteError;
}
return utils::SymmetricPerChannelBiasQuantize(
model, bias_tensor, input_tensor->quantization->scale[0],
weight_scales.data(), channel_dim_size, error_reporter);
if (activations_type == tflite::TensorType_INT16) {
return utils::SymmetricPerChannelBiasQuantize<std::int64_t>(
model, bias_tensor, input_tensor->quantization->scale[0],
weight_scales.data(), channel_dim_size, error_reporter);
} else {
return utils::SymmetricPerChannelBiasQuantize<std::int32_t>(
model, bias_tensor, input_tensor->quantization->scale[0],
weight_scales.data(), channel_dim_size, error_reporter);
}
} else {
if (weight_scales.size() != 1) {
error_reporter->Report(
@ -102,40 +109,54 @@ TfLiteStatus QuantizeBias(ModelT* model, const TensorT* input_tensor,
weight_scales.size());
return kTfLiteError;
}
return utils::SymmetricPerLayerBiasQuantize(
model, bias_tensor,
input_tensor->quantization->scale[0] * weight_scales[0],
error_reporter);
if (activations_type == tflite::TensorType_INT16) {
return utils::SymmetricPerLayerBiasQuantize<std::int64_t>(
model, bias_tensor,
input_tensor->quantization->scale[0] * weight_scales[0],
error_reporter);
} else {
return utils::SymmetricPerLayerBiasQuantize<std::int32_t>(
model, bias_tensor,
input_tensor->quantization->scale[0] * weight_scales[0],
error_reporter);
}
}
return kTfLiteError;
}
// True if the tensor type has to be modified.
bool TensorTypeChangeRequired(const TensorT* tensor, const TensorType& type) {
// The quantized model is type INT8, so if the user provided type is INT8, we
// do not have to do any custom logic. Additionally, if the current tensor
// isn't INT8 quantized, the custom type doesn't apply.
return (type != TensorType_INT8 && tensor->type == TensorType_INT8 &&
!tensor->quantization->scale.empty());
// The quantized model is type INT8/INT16, so if the user provided type is
// INT8/INT16, we do not have to do any custom logic. Additionally, if the
// current tensor isn't INT8/INT16 quantized, the custom type doesn't apply.
bool int8check = type != TensorType_INT8 && tensor->type == TensorType_INT8 &&
!tensor->quantization->scale.empty();
bool int16check = type != TensorType_INT16 &&
tensor->type == TensorType_INT16 &&
!tensor->quantization->scale.empty();
return (int8check || int16check);
}
// Sets the input type, adding a Leading Op node at the start of the model if
// necessary.
// Returns the new input tensor index.
int32_t SetInputType(ModelT* model, SubGraphT* subgraph,
const int32_t tensor_idx, const TensorType& input_type) {
const int32_t tensor_idx, const TensorType& input_type,
const TensorType& activations_type) {
TensorT* tensor = subgraph->tensors[tensor_idx].get();
if (!TensorTypeChangeRequired(tensor, input_type)) {
return -1;
}
if (input_type == TensorType_FLOAT32 || input_type == TensorType_UINT8) {
std::string type_string =
activations_type == TensorType_INT16 ? "int16" : "int8";
// Create a new tensor to be the input of the leading Op.
std::unique_ptr<TensorT> leading_op_input;
if (input_type == TensorType_FLOAT32) {
// Add tensor for quantize operator. Scales and zero points are not
// needed.
const string leading_op_name = tensor->name;
const string new_name_original_input = tensor->name + "_int8";
const string new_name_original_input = tensor->name + "_" + type_string;
tensor->name = new_name_original_input;
utils::MakeTensor(leading_op_name, tensor->shape, input_type,
&leading_op_input);
@ -150,7 +171,7 @@ int32_t SetInputType(ModelT* model, SubGraphT* subgraph,
TFLITE_DCHECK_GE(zero_point, -128);
TFLITE_DCHECK_LE(zero_point, 127);
const string leading_op_name = tensor->name;
const string new_name_original_input = tensor->name + "_int8";
const string new_name_original_input = tensor->name + "_" + type_string;
tensor->name = new_name_original_input;
utils::MakeTensorWithQuantParam(leading_op_name, tensor->shape,
input_type, scale, zero_point + 128,
@ -177,17 +198,20 @@ int32_t SetInputType(ModelT* model, SubGraphT* subgraph,
// necessary.
// Returns the new output tensor index.
int32_t SetOutputType(ModelT* model, SubGraphT* subgraph,
const int32_t tensor_idx, const TensorType& output_type) {
const int32_t tensor_idx, const TensorType& output_type,
const TensorType& activations_type) {
TensorT* tensor = subgraph->tensors[tensor_idx].get();
if (!TensorTypeChangeRequired(tensor, output_type)) {
return -1;
}
if (output_type == TensorType_FLOAT32 || output_type == TensorType_UINT8) {
std::string type_string =
activations_type == TensorType_INT16 ? "int16" : "int8";
// Create a new tensor to be the output of the tailing op.
std::unique_ptr<TensorT> tailing_op_output;
if (output_type == TensorType_FLOAT32) {
const string tailing_op_name = tensor->name;
const string new_name_original_output = tensor->name + "_int8";
const string new_name_original_output = tensor->name + "_" + type_string;
tensor->name = new_name_original_output;
utils::MakeTensor(tailing_op_name, tensor->shape, output_type,
&tailing_op_output);
@ -202,7 +226,7 @@ int32_t SetOutputType(ModelT* model, SubGraphT* subgraph,
TFLITE_DCHECK_GE(zero_point, -128);
TFLITE_DCHECK_LE(zero_point, 127);
const string tailing_op_name = tensor->name;
const string new_name_original_output = tensor->name + "_int8";
const string new_name_original_output = tensor->name + "_" + type_string;
tensor->name = new_name_original_output;
utils::MakeTensorWithQuantParam(tailing_op_name, tensor->shape,
output_type, scale, zero_point + 128,
@ -238,6 +262,7 @@ int32_t SetOutputType(ModelT* model, SubGraphT* subgraph,
// uint8, can be thought as "requant").
TfLiteStatus SetInputAndOutputTypes(ModelT* model, const TensorType& input_type,
const TensorType& output_type,
const TensorType& activations_type,
ErrorReporter* error_reporter) {
for (int subgraph_idx = 0; subgraph_idx < model->subgraphs.size();
subgraph_idx++) {
@ -253,8 +278,8 @@ TfLiteStatus SetInputAndOutputTypes(ModelT* model, const TensorType& input_type,
EnumNameTensorType(tensor->type));
return kTfLiteError;
}
const int32_t input_idx =
SetInputType(model, subgraph, subgraph->inputs[i], input_type);
const int32_t input_idx = SetInputType(
model, subgraph, subgraph->inputs[i], input_type, activations_type);
if (input_idx < 0) {
continue;
}
@ -270,8 +295,8 @@ TfLiteStatus SetInputAndOutputTypes(ModelT* model, const TensorType& input_type,
EnumNameTensorType(tensor->type));
return kTfLiteError;
}
const int32_t output_idx =
SetOutputType(model, subgraph, subgraph->outputs[i], output_type);
const int32_t output_idx = SetOutputType(
model, subgraph, subgraph->outputs[i], output_type, activations_type);
if (output_idx < 0) {
continue;
}
@ -287,6 +312,7 @@ TfLiteStatus SetInputAndOutputTypes(ModelT* model, const TensorType& input_type,
// The other ones with constraints are handled in QuantizeWeightsAndInput.
TfLiteStatus ApplyConstraints(ModelT* model,
const std::unordered_set<string>& operator_names,
TensorType activations_type,
ErrorReporter* error_reporter) {
for (int subgraph_idx = 0; subgraph_idx < model->subgraphs.size();
subgraph_idx++) {
@ -332,7 +358,7 @@ TfLiteStatus ApplyConstraints(ModelT* model,
std::unique_ptr<TensorT> additional_tensor;
const string requant_tensor_name = input_tensor->name + "_requantized";
utils::MakeTensorWithQuantParam(
requant_tensor_name, input_tensor->shape, TensorType_INT8,
requant_tensor_name, input_tensor->shape, activations_type,
output_scale, output_zp, &additional_tensor);
const int32_t additional_tensor_idx = subgraph->tensors.size();
subgraph->tensors.push_back(std::move(additional_tensor));
@ -382,7 +408,8 @@ std::vector<std::pair<int, operator_property::TensorProperty>> GetOutputs(
bool ShouldRestrictSameInputOutputScale(
operator_property::OperatorProperty property) {
// Ops with multiple inputs (i.e. concat) gets restricted in ApplyConstraints.
// Ops with multiple inputs (i.e. concat, max and min) gets restricted in
// ApplyConstraints.
return (!property.arbitrary_inputs &&
property.restrict_same_input_output_scale);
}
@ -401,7 +428,7 @@ TfLiteStatus QuantizeOpInput(
ModelT* model, int32_t subgraph_idx, size_t* op_idx,
operator_property::OperatorProperty property,
const std::pair<int32_t, operator_property::TensorProperty>& input,
ErrorReporter* error_reporter) {
const TensorType& activations_type, ErrorReporter* error_reporter) {
int32_t input_idx = input.first;
operator_property::TensorProperty tensor_property = input.second;
SubGraphT* subgraph = model->subgraphs.at(subgraph_idx).get();
@ -429,7 +456,9 @@ TfLiteStatus QuantizeOpInput(
if (utils::HasBuffer(model, subgraph, tensor_idx)) {
// TODO(suharshs): Look at consumers, throw error if one consumer is
// per-channel and one per-layer.
if (tensor_property.number_of_bits == 8) {
bool quantize_const_input = property.quantize_input_as_activations &&
activations_type == TensorType_INT16;
if (tensor_property.number_of_bits == 8 && !quantize_const_input) {
if (tensor_property.use_derived_scale) {
// Currently 8bit tensors in input do not accept derived scale.
return kTfLiteError;
@ -444,7 +473,7 @@ TfLiteStatus QuantizeOpInput(
*op_idx);
return kTfLiteError;
}
} else if (tensor_property.number_of_bits == 16) {
} else if (tensor_property.number_of_bits == 16 || quantize_const_input) {
if (tensor_property.use_derived_scale) {
// Currently 16bit tensors in input do not accept derived scale.
return kTfLiteError;
@ -476,8 +505,8 @@ TfLiteStatus QuantizeOpInput(
tensor_property.derived_scale.input_tensors,
tensor_property.derived_scale.intermediate_tensors,
tensor_property.derived_scale.factors);
return utils::SymmetricPerLayerBiasQuantize(model, tensor, scale,
error_reporter);
return utils::SymmetricPerLayerBiasQuantize<std::int32_t>(
model, tensor, scale, error_reporter);
} else if (tensor_property.number_of_bits == 10) {
// When the number of bits is 10 (instead of 16), quantize the tensor to
@ -514,7 +543,8 @@ TfLiteStatus QuantizeOpInput(
// Currently 8bit tensors in input do not accept derived scale.
return kTfLiteError;
}
utils::QuantizeActivation(tensor);
TF_LITE_ENSURE_STATUS(utils::QuantizeActivation(
tensor, activations_type, error_reporter));
} else if (tensor_property.number_of_bits == 16) {
TensorT* tensor = subgraph->tensors[tensor_idx].get();
float quantized_range = 32767.0;
@ -532,13 +562,16 @@ TfLiteStatus QuantizeOpInput(
} else {
// If the tensor is not a model input, we need to add a Quantize
// operation since the preceding op may require a float output.
std::string type_string =
activations_type == TensorType_INT16 ? "int16" : "int8";
std::unique_ptr<TensorT> op_output;
utils::MakeTensor(tensor->name + "_int8", tensor->shape,
TensorType_INT8, &op_output);
utils::MakeTensor(tensor->name + "_" + type_string, tensor->shape,
activations_type, &op_output);
op_output->quantization = absl::make_unique<QuantizationParametersT>();
op_output->quantization->min.push_back(tensor->quantization->min[0]);
op_output->quantization->max.push_back(tensor->quantization->max[0]);
utils::QuantizeActivation(op_output.get());
TF_LITE_ENSURE_STATUS(utils::QuantizeActivation(
op_output.get(), activations_type, error_reporter));
const int32_t quant_op_output_idx = subgraph->tensors.size();
subgraph->tensors.push_back(std::move(op_output));
std::unique_ptr<OperatorT> quant_op;
@ -580,7 +613,7 @@ TfLiteStatus QuantizeOpOutput(
ModelT* model, int32_t subgraph_idx, int32_t op_idx,
operator_property::OperatorProperty property,
const std::pair<int32_t, operator_property::TensorProperty>& output,
ErrorReporter* error_reporter) {
TensorType activations_type, ErrorReporter* error_reporter) {
int32_t output_idx = output.first;
operator_property::TensorProperty tensor_property = output.second;
// If the operator is not quantizable, we don't need to do anything for the
@ -644,18 +677,22 @@ TfLiteStatus QuantizeOpOutput(
const float max = input_tensor->quantization->max[0];
output_tensor->quantization->max = {max};
}
output_tensor->type = TensorType_INT8;
output_tensor->type = activations_type;
} else if (tensor_property.restriction) {
const auto scale_and_zp = tensor_property.restricted_value;
const auto scale_and_zp = activations_type == TensorType_INT16
? tensor_property.restricted_value_int16
: tensor_property.restricted_value_int8;
// Apply to output.
output_tensor->quantization = absl::make_unique<QuantizationParametersT>();
output_tensor->quantization->scale.push_back(scale_and_zp.first);
output_tensor->quantization->zero_point.push_back(scale_and_zp.second);
output_tensor->type = TensorType_INT8;
output_tensor->type = activations_type;
} else {
// Process regular output that doesn't have any restrictions.
if (utils::HasMinMax(output_tensor)) {
utils::QuantizeActivation(output_tensor);
utils::QuantizeActivation(output_tensor, activations_type,
error_reporter);
} else {
error_reporter->Report(
"Unable to find min/max value for output %d in %s in "
@ -668,6 +705,7 @@ TfLiteStatus QuantizeOpOutput(
}
TfLiteStatus QuantizeIntemediateTensors(ModelT* model,
TensorType activations_type,
ErrorReporter* error_reporter) {
for (size_t subgraph_idx = 0; subgraph_idx < model->subgraphs.size();
subgraph_idx++) {
@ -691,7 +729,8 @@ TfLiteStatus QuantizeIntemediateTensors(ModelT* model,
input.second.symmetric == false) {
TensorT* tensor = subgraph->tensors[index_global].get();
if (utils::HasMinMax(tensor)) {
utils::QuantizeActivation(tensor);
utils::QuantizeActivation(tensor, activations_type,
error_reporter);
} else {
error_reporter->Report(
"Unable to find min/max value for output %d in %s in "
@ -793,7 +832,7 @@ TfLiteStatus QuantizeSharedRange(ModelT* model, ErrorReporter* error_reporter) {
TfLiteStatus QuantizeWeightsInputOutput(
ModelT* model, bool allow_float,
const std::unordered_set<string>& operator_names,
ErrorReporter* error_reporter) {
const TensorType& activations_type, ErrorReporter* error_reporter) {
for (size_t subgraph_idx = 0; subgraph_idx < model->subgraphs.size();
subgraph_idx++) {
SubGraphT* subgraph = model->subgraphs.at(subgraph_idx).get();
@ -815,14 +854,16 @@ TfLiteStatus QuantizeWeightsInputOutput(
for (const std::pair<int, operator_property::TensorProperty>& input :
GetInputs(op, property)) {
TF_LITE_ENSURE_STATUS(QuantizeOpInput(model, subgraph_idx, &op_idx,
property, input, error_reporter));
property, input, activations_type,
error_reporter));
}
// Quantize operator outputs.
for (const std::pair<int, operator_property::TensorProperty>& output :
GetOutputs(op, property)) {
TF_LITE_ENSURE_STATUS(QuantizeOpOutput(
model, subgraph_idx, op_idx, property, output, error_reporter));
TF_LITE_ENSURE_STATUS(
QuantizeOpOutput(model, subgraph_idx, op_idx, property, output,
activations_type, error_reporter));
}
}
}
@ -832,6 +873,7 @@ TfLiteStatus QuantizeWeightsInputOutput(
// Quantize bias.
TfLiteStatus QuantizeBiases(ModelT* model,
const std::unordered_set<string>& operator_names,
const TensorType& activations_type,
ErrorReporter* error_reporter) {
for (size_t subgraph_idx = 0; subgraph_idx < model->subgraphs.size();
subgraph_idx++) {
@ -877,10 +919,10 @@ TfLiteStatus QuantizeBiases(ModelT* model,
subgraph->tensors[op->inputs[property.inputs[1].first]].get();
operator_property::TensorProperty weight_property =
property.inputs[1].second;
TF_LITE_ENSURE_STATUS(
QuantizeBias(model, input_tensor, weight_tensor, bias_tensor,
weight_property.per_axis,
weight_property.per_axis_index, error_reporter));
TF_LITE_ENSURE_STATUS(QuantizeBias(
model, input_tensor, weight_tensor, bias_tensor,
weight_property.per_axis, weight_property.per_axis_index,
activations_type, error_reporter));
}
}
}
@ -1000,7 +1042,7 @@ TfLiteStatus FillQuantizationParams(
// Check compatibility of activation, weight and bias scales. Adjust if needed.
TfLiteStatus EnsureBiasScaleCompatibility(
ModelT* model, const std::unordered_set<string>& operator_names,
ErrorReporter* error_reporter) {
TensorType activations_type, ErrorReporter* error_reporter) {
for (size_t subgraph_idx = 0; subgraph_idx < model->subgraphs.size();
subgraph_idx++) {
SubGraphT* subgraph = model->subgraphs.at(subgraph_idx).get();
@ -1049,11 +1091,9 @@ TfLiteStatus EnsureBiasScaleCompatibility(
// Get input scale for assymmetric quantization.
QuantizationParametersT temp_quant_params = QuantizationParametersT();
utils::GetAsymmetricQuantizationParams(
input_tensor->quantization->min[0],
input_tensor->quantization->max[0],
std::numeric_limits<int8_t>::min(),
std::numeric_limits<int8_t>::max(), &temp_quant_params);
TF_LITE_ENSURE_STATUS(
utils::GetQuantizationParams(input_tensor, activations_type,
&temp_quant_params, error_reporter));
if (temp_quant_params.scale.size() != 1) {
error_reporter->Report("Unexpected input quantization scale size.");
return kTfLiteError;
@ -1132,21 +1172,24 @@ TfLiteStatus QuantizeModel(flatbuffers::FlatBufferBuilder* builder,
ModelT* model, const TensorType& input_type,
const TensorType& output_type, bool allow_float,
const std::unordered_set<string>& operator_names,
const TensorType& activations_type,
ErrorReporter* error_reporter) {
TF_LITE_ENSURE_STATUS(
FillQuantizationParams(model, operator_names, error_reporter));
TF_LITE_ENSURE_STATUS(EnsureBiasScaleCompatibility(
model, operator_names, activations_type, error_reporter));
TF_LITE_ENSURE_STATUS(
EnsureBiasScaleCompatibility(model, operator_names, error_reporter));
TF_LITE_ENSURE_STATUS(QuantizeIntemediateTensors(model, error_reporter));
QuantizeIntemediateTensors(model, activations_type, error_reporter));
TF_LITE_ENSURE_STATUS(QuantizeSharedRange(model, error_reporter));
TF_LITE_ENSURE_STATUS(QuantizeWeightsInputOutput(
model, allow_float, operator_names, error_reporter));
model, allow_float, operator_names, activations_type, error_reporter));
TF_LITE_ENSURE_STATUS(ApplyConstraints(model, operator_names,
activations_type, error_reporter));
TF_LITE_ENSURE_STATUS(
ApplyConstraints(model, operator_names, error_reporter));
TF_LITE_ENSURE_STATUS(QuantizeBiases(model, operator_names, error_reporter));
QuantizeBiases(model, operator_names, activations_type, error_reporter));
utils::SetOperatorCodeVersion(model);
TF_LITE_ENSURE_STATUS(
SetInputAndOutputTypes(model, input_type, output_type, error_reporter));
TF_LITE_ENSURE_STATUS(SetInputAndOutputTypes(
model, input_type, output_type, activations_type, error_reporter));
flatbuffers::Offset<Model> output_model_location =
Model::Pack(*builder, model);
@ -1158,23 +1201,27 @@ TfLiteStatus QuantizeModel(flatbuffers::FlatBufferBuilder* builder,
TfLiteStatus QuantizeModel(flatbuffers::FlatBufferBuilder* builder,
ModelT* model, const TensorType& input_type,
const TensorType& output_type, bool allow_float,
const TensorType& activations_type,
ErrorReporter* error_reporter) {
return QuantizeModel(builder, model, input_type, output_type, allow_float,
GetAllOperatorOutputs(model), error_reporter);
GetAllOperatorOutputs(model), activations_type,
error_reporter);
}
TfLiteStatus QuantizeModel(flatbuffers::FlatBufferBuilder* builder,
ModelT* model, const TensorType& input_type,
const TensorType& output_type,
const TensorType& activations_type,
ErrorReporter* error_reporter) {
return QuantizeModel(builder, model, input_type, output_type,
/*allow_float=*/false, error_reporter);
/*allow_float=*/false, activations_type, error_reporter);
}
TfLiteStatus QuantizeModel(flatbuffers::FlatBufferBuilder* builder,
ModelT* model, ErrorReporter* error_reporter) {
ModelT* model, const TensorType& activations_type,
ErrorReporter* error_reporter) {
return QuantizeModel(builder, model, TensorType_FLOAT32, TensorType_FLOAT32,
/*allow_float=*/false, error_reporter);
/*allow_float=*/false, activations_type, error_reporter);
}
} // namespace optimize

View File

@ -35,7 +35,9 @@ namespace optimize {
//
// Note: This is a private API, subject to change.
TfLiteStatus QuantizeModel(flatbuffers::FlatBufferBuilder* builder,
ModelT* input_model, ErrorReporter* error_reporter);
ModelT* input_model,
const TensorType& activations_type,
ErrorReporter* error_reporter);
// Same as above, but the types of quantized inputs and outputs are
// configurable.
@ -44,6 +46,7 @@ TfLiteStatus QuantizeModel(flatbuffers::FlatBufferBuilder* builder,
TfLiteStatus QuantizeModel(flatbuffers::FlatBufferBuilder* builder,
ModelT* input_model, const TensorType& input_type,
const TensorType& output_type,
const TensorType& activations_type,
ErrorReporter* error_reporter);
// Same as above, but can enable allowing float intermediate operations for ops
@ -53,6 +56,7 @@ TfLiteStatus QuantizeModel(flatbuffers::FlatBufferBuilder* builder,
TfLiteStatus QuantizeModel(flatbuffers::FlatBufferBuilder* builder,
ModelT* input_model, const TensorType& input_type,
const TensorType& output_type, bool allow_float,
const TensorType& activations_type,
ErrorReporter* error_reporter);
// Same as above, but enables only quantizing a whitelist of operations,
@ -63,6 +67,7 @@ TfLiteStatus QuantizeModel(flatbuffers::FlatBufferBuilder* builder,
ModelT* input_model, const TensorType& input_type,
const TensorType& output_type, bool allow_float,
const std::unordered_set<string>& operator_names,
const TensorType& activations_type,
ErrorReporter* error_reporter);
} // namespace optimize

View File

@ -80,28 +80,35 @@ class QuantizeModelTest : public testing::Test {
internal::FailOnErrorReporter error_reporter_;
};
class QuantizeConvModelTest : public QuantizeModelTest {
class QuantizeConvModelTest : public QuantizeModelTest,
public testing::WithParamInterface<TensorType> {
protected:
QuantizeConvModelTest() {
tensor_type_ = GetParam();
input_model_ = ReadModel(internal::kConvModelWith0Plus10Weights);
readonly_model_ = input_model_->GetModel();
readonly_model_->UnPackTo(&model_);
}
TensorType tensor_type_;
};
TEST_F(QuantizeConvModelTest, QuantizationSucceeds) {
auto status = QuantizeModel(&builder_, &model_, TensorType_INT8,
TensorType_INT8, &error_reporter_);
INSTANTIATE_TEST_SUITE_P(QuantizeConvModelTestInst, QuantizeConvModelTest,
testing::ValuesIn({TensorType_INT8,
TensorType_INT16}));
TEST_P(QuantizeConvModelTest, QuantizationSucceeds) {
auto status = QuantizeModel(&builder_, &model_, tensor_type_, tensor_type_,
tensor_type_, &error_reporter_);
EXPECT_EQ(status, kTfLiteOk);
const uint8_t* buffer = builder_.GetBufferPointer();
const Model* output_model = GetModel(buffer);
ASSERT_TRUE(output_model);
}
TEST_F(QuantizeConvModelTest, SkipUnspecifiedLayer) {
TEST_P(QuantizeConvModelTest, SkipUnspecifiedLayer) {
auto status =
QuantizeModel(&builder_, &model_, TensorType_FLOAT32, TensorType_FLOAT32,
/*allow_float=*/true, {}, &error_reporter_);
/*allow_float=*/true, {}, tensor_type_, &error_reporter_);
EXPECT_EQ(status, kTfLiteOk);
ASSERT_EQ(model_.subgraphs.size(), readonly_model_->subgraphs()->size());
// The resulting model should be the same.
@ -123,9 +130,9 @@ TEST_F(QuantizeConvModelTest, SkipUnspecifiedLayer) {
}
}
TEST_F(QuantizeConvModelTest, TensorShapesAndStructureIsUnchanged) {
auto status = QuantizeModel(&builder_, &model_, TensorType_INT8,
TensorType_INT8, &error_reporter_);
TEST_P(QuantizeConvModelTest, TensorShapesAndStructureIsUnchanged) {
auto status = QuantizeModel(&builder_, &model_, tensor_type_, tensor_type_,
tensor_type_, &error_reporter_);
EXPECT_EQ(status, kTfLiteOk);
ASSERT_EQ(model_.subgraphs.size(), readonly_model_->subgraphs()->size());
for (size_t subgraph_idx = 0; subgraph_idx < model_.subgraphs.size();
@ -148,9 +155,9 @@ TEST_F(QuantizeConvModelTest, TensorShapesAndStructureIsUnchanged) {
EXPECT_EQ(model_.operator_codes[0]->version, 3);
}
TEST_F(QuantizeConvModelTest, OperatorsAreUnchanged) {
auto status = QuantizeModel(&builder_, &model_, TensorType_INT8,
TensorType_INT8, &error_reporter_);
TEST_P(QuantizeConvModelTest, OperatorsAreUnchanged) {
auto status = QuantizeModel(&builder_, &model_, tensor_type_, tensor_type_,
tensor_type_, &error_reporter_);
EXPECT_EQ(status, kTfLiteOk);
ASSERT_EQ(model_.operator_codes.size(),
readonly_model_->operator_codes()->size());
@ -182,20 +189,28 @@ TEST_F(QuantizeConvModelTest, OperatorsAreUnchanged) {
}
}
TEST_F(QuantizeConvModelTest, GraphIsFullyQuantized) {
auto status = QuantizeModel(&builder_, &model_, TensorType_INT8,
TensorType_INT8, &error_reporter_);
TEST_P(QuantizeConvModelTest, GraphIsFullyQuantized) {
auto status = QuantizeModel(&builder_, &model_, tensor_type_, tensor_type_,
tensor_type_, &error_reporter_);
EXPECT_EQ(status, kTfLiteOk);
for (const auto& subgraph : model_.subgraphs) {
for (const auto& tensor : subgraph->tensors) {
EXPECT_TRUE(tensor->type == TensorType_INT32 ||
tensor->type == TensorType_INT8);
if (tensor_type_ == TensorType_INT8) {
EXPECT_TRUE(tensor->type == TensorType_INT32 ||
tensor->type == TensorType_INT8);
} else if (tensor_type_ == TensorType_INT16) {
EXPECT_TRUE(tensor->type == TensorType_INT64 || // bias
tensor->type == TensorType_INT8 || // weights
tensor->type == TensorType_INT16); // activations
}
}
}
}
TEST_F(QuantizeConvModelTest, FloatInputAndOutput) {
auto status = QuantizeModel(&builder_, &model_, &error_reporter_);
TEST_P(QuantizeConvModelTest, FloatInputAndOutput) {
auto status =
QuantizeModel(&builder_, &model_, TensorType_FLOAT32, TensorType_FLOAT32,
tensor_type_, &error_reporter_);
EXPECT_EQ(status, kTfLiteOk);
for (int32_t subgraph_idx = 0; subgraph_idx < model_.subgraphs.size();
@ -234,22 +249,33 @@ TEST_F(QuantizeConvModelTest, FloatInputAndOutput) {
EXPECT_EQ(subgraph->tensors[output_idx]->type, TensorType_FLOAT32);
EXPECT_EQ(subgraph->tensors[output_idx]->name, "output");
// The original input and output has been renamed.
EXPECT_EQ(subgraph->tensors[quant_op->outputs[0]]->name, "input_int8");
EXPECT_EQ(subgraph->tensors[dequant_op->inputs[0]]->name, "output_int8");
std::string control_suffix =
(tensor_type_ == TensorType_INT16) ? "int16" : "int8";
EXPECT_EQ(subgraph->tensors[quant_op->outputs[0]]->name,
"input_" + control_suffix);
EXPECT_EQ(subgraph->tensors[dequant_op->inputs[0]]->name,
"output_" + control_suffix);
for (int tensor_idx = 0; tensor_idx < subgraph->tensors.size();
++tensor_idx) {
const auto& tensor = subgraph->tensors[tensor_idx];
if (input_idx != tensor_idx && output_idx != tensor_idx) {
EXPECT_TRUE(tensor->type == TensorType_INT32 ||
tensor->type == TensorType_INT8);
if (tensor_type_ == TensorType_INT8) {
EXPECT_TRUE(tensor->type == TensorType_INT32 ||
tensor->type == TensorType_INT8);
} else if (tensor_type_ == TensorType_INT16) {
EXPECT_TRUE(tensor->type == TensorType_INT64 || // bias
tensor->type == TensorType_INT8 || // weights
tensor->type == TensorType_INT16); // activations
}
}
}
}
}
TEST_F(QuantizeConvModelTest, Uint8InputAndOutput) {
auto status = QuantizeModel(&builder_, &model_, TensorType_UINT8,
TensorType_UINT8, &error_reporter_);
TEST_P(QuantizeConvModelTest, Uint8InputAndOutput) {
auto status =
QuantizeModel(&builder_, &model_, TensorType_UINT8, TensorType_UINT8,
TensorType_INT8, &error_reporter_);
EXPECT_EQ(status, kTfLiteOk);
for (int32_t subgraph_idx = 0; subgraph_idx < model_.subgraphs.size();
@ -326,21 +352,25 @@ class QuantizeConvNoBiasModelTest : public QuantizeModelTest {
};
TEST_F(QuantizeConvNoBiasModelTest, QuantizationSucceeds) {
auto status = QuantizeModel(&builder_, &model_, TensorType_INT8,
TensorType_INT8, &error_reporter_);
auto status =
QuantizeModel(&builder_, &model_, TensorType_INT8, TensorType_INT8,
TensorType_INT8, &error_reporter_);
EXPECT_EQ(status, kTfLiteOk);
const uint8_t* buffer = builder_.GetBufferPointer();
const Model* output_model = GetModel(buffer);
ASSERT_TRUE(output_model);
}
class QuantizeConcatModelTest : public QuantizeModelTest {
class QuantizeConcatModelTest : public QuantizeModelTest,
public testing::WithParamInterface<TensorType> {
protected:
QuantizeConcatModelTest() {
input_model_ = ReadModel(internal::kFloatConcatMax5Max10Max10);
readonly_model_ = input_model_->GetModel();
readonly_model_->UnPackTo(&model_);
}
TensorType tensor_type_;
};
// There are two inputs for concat, "input0" and "input1". "input0" has [0, 5]
@ -352,9 +382,9 @@ class QuantizeConcatModelTest : public QuantizeModelTest {
// input0 -> requant -> input0_requant \
// concat - output
// input1 /
TEST_F(QuantizeConcatModelTest, AddRequantBeforeConcat) {
auto status = QuantizeModel(&builder_, &model_, TensorType_INT8,
TensorType_INT8, &error_reporter_);
TEST_P(QuantizeConcatModelTest, AddRequantBeforeConcat) {
auto status = QuantizeModel(&builder_, &model_, tensor_type_, tensor_type_,
tensor_type_, &error_reporter_);
EXPECT_EQ(status, kTfLiteOk);
// There is only one subgraph.
@ -373,32 +403,51 @@ TEST_F(QuantizeConcatModelTest, AddRequantBeforeConcat) {
EXPECT_EQ(model_.operator_codes[concat->opcode_index]->builtin_code,
BuiltinOperator_CONCATENATION);
auto zero_point_control = tensor_type_ == TensorType_INT8 ? -128 : 0;
/*
input0_scale_control
INT8: (5-0) / (2^8 - 1)
INT16: (5-0) / (2^16 / 2 - 1)
input1_scale
INT8: (10-0) / (2^8 - 1)
INT16: (10-0) / (2^16 / 2 - 1)
*/
auto input0_scale_control =
tensor_type_ == TensorType_INT8 ? 0.019607844 : 0.00015259254;
auto input1_scale =
tensor_type_ == TensorType_INT8 ? 0.039215688 : 0.00030518509;
// There should be 4 tensors: input0, input1, input0_requantized, output.
EXPECT_EQ(subgraph->tensors.size(), 4);
EXPECT_EQ(subgraph->tensors[0]->type, TensorType_INT8);
EXPECT_EQ(subgraph->tensors[0]->type, tensor_type_);
EXPECT_EQ(subgraph->tensors[0]->name, "input0");
EXPECT_EQ(subgraph->tensors[0]->quantization->scale.size(), 1);
EXPECT_EQ(subgraph->tensors[0]->quantization->zero_point.size(), 1);
EXPECT_FLOAT_EQ(subgraph->tensors[0]->quantization->scale[0], 0.019607844);
EXPECT_FLOAT_EQ(subgraph->tensors[0]->quantization->zero_point[0], -128);
EXPECT_EQ(subgraph->tensors[1]->type, TensorType_INT8);
EXPECT_FLOAT_EQ(subgraph->tensors[0]->quantization->scale[0],
input0_scale_control);
EXPECT_FLOAT_EQ(subgraph->tensors[0]->quantization->zero_point[0],
zero_point_control);
EXPECT_EQ(subgraph->tensors[1]->type, tensor_type_);
EXPECT_EQ(subgraph->tensors[1]->name, "input1");
EXPECT_EQ(subgraph->tensors[1]->quantization->scale.size(), 1);
EXPECT_EQ(subgraph->tensors[1]->quantization->zero_point.size(), 1);
EXPECT_FLOAT_EQ(subgraph->tensors[1]->quantization->scale[0], 0.039215688);
EXPECT_FLOAT_EQ(subgraph->tensors[1]->quantization->zero_point[0], -128);
EXPECT_EQ(subgraph->tensors[2]->type, TensorType_INT8);
EXPECT_FLOAT_EQ(subgraph->tensors[1]->quantization->scale[0], input1_scale);
EXPECT_FLOAT_EQ(subgraph->tensors[1]->quantization->zero_point[0],
zero_point_control);
EXPECT_EQ(subgraph->tensors[2]->type, tensor_type_);
EXPECT_EQ(subgraph->tensors[2]->name, "output");
EXPECT_EQ(subgraph->tensors[2]->quantization->scale.size(), 1);
EXPECT_EQ(subgraph->tensors[2]->quantization->zero_point.size(), 1);
EXPECT_FLOAT_EQ(subgraph->tensors[2]->quantization->scale[0], 0.039215688);
EXPECT_FLOAT_EQ(subgraph->tensors[2]->quantization->zero_point[0], -128);
EXPECT_EQ(subgraph->tensors[3]->type, TensorType_INT8);
EXPECT_FLOAT_EQ(subgraph->tensors[2]->quantization->scale[0], input1_scale);
EXPECT_FLOAT_EQ(subgraph->tensors[2]->quantization->zero_point[0],
zero_point_control);
EXPECT_EQ(subgraph->tensors[3]->type, tensor_type_);
EXPECT_EQ(subgraph->tensors[3]->name, "input0_requantized");
EXPECT_EQ(subgraph->tensors[3]->quantization->scale.size(), 1);
EXPECT_EQ(subgraph->tensors[3]->quantization->zero_point.size(), 1);
EXPECT_FLOAT_EQ(subgraph->tensors[3]->quantization->scale[0], 0.039215688);
EXPECT_FLOAT_EQ(subgraph->tensors[3]->quantization->zero_point[0], -128);
EXPECT_FLOAT_EQ(subgraph->tensors[3]->quantization->scale[0], input1_scale);
EXPECT_FLOAT_EQ(subgraph->tensors[3]->quantization->zero_point[0],
zero_point_control);
// The connection should be what is described in the comment.
EXPECT_EQ(requant->inputs.size(), 1);
@ -419,7 +468,9 @@ TEST_F(QuantizeConcatModelTest, AddRequantBeforeConcat) {
EXPECT_EQ(model_.operator_codes[1]->builtin_code, BuiltinOperator_QUANTIZE);
EXPECT_EQ(model_.operator_codes[1]->version, 2);
}
INSTANTIATE_TEST_SUITE_P(QuantizeConcatModelInst, QuantizeConcatModelTest,
testing::ValuesIn({TensorType_INT8,
TensorType_INT16}));
class QuantizeSplitModelTest : public QuantizeModelTest {
protected:
QuantizeSplitModelTest() {
@ -432,8 +483,9 @@ class QuantizeSplitModelTest : public QuantizeModelTest {
// There are two outputs for split with different scales, the resulting model
// should have the scales be hardcodes to the input scale value.
TEST_F(QuantizeSplitModelTest, QuantizeSplit) {
auto status = QuantizeModel(&builder_, &model_, TensorType_INT8,
TensorType_INT8, &error_reporter_);
auto status =
QuantizeModel(&builder_, &model_, TensorType_INT8, TensorType_INT8,
TensorType_INT8, &error_reporter_);
EXPECT_EQ(status, kTfLiteOk);
// There is only one subgraph.
@ -496,8 +548,9 @@ class QuantizeConvModel1Test : public QuantizeModelTest {
};
TEST_F(QuantizeConvModel1Test, VerifyConvQuantizationWithUnitScale) {
auto status = QuantizeModel(&builder_, &model_, TensorType_INT8,
TensorType_INT8, &error_reporter_);
auto status =
QuantizeModel(&builder_, &model_, TensorType_INT8, TensorType_INT8,
TensorType_INT8, &error_reporter_);
EXPECT_EQ(status, kTfLiteOk);
const auto& subgraph = model_.subgraphs[0];
@ -587,18 +640,25 @@ TEST_F(QuantizeConvModel1Test, VerifyConvQuantizationWithUnitScale) {
EXPECT_EQ(model_.operator_codes[0]->version, 3);
}
class QuantizeConvModel2Test : public QuantizeModelTest {
class QuantizeConvModel2Test : public QuantizeModelTest,
public testing::WithParamInterface<TensorType> {
protected:
QuantizeConvModel2Test() {
tensor_type_ = GetParam();
input_model_ = ReadModel(internal::kConvModelWith0Plus10Weights);
readonly_model_ = input_model_->GetModel();
readonly_model_->UnPackTo(&model_);
}
};
TEST_F(QuantizeConvModel2Test, VerifyConvQuantization) {
auto status = QuantizeModel(&builder_, &model_, TensorType_INT8,
TensorType_INT8, &error_reporter_);
TensorType tensor_type_;
};
INSTANTIATE_TEST_SUITE_P(QuantizeConvModel2TestInst, QuantizeConvModel2Test,
testing::ValuesIn({TensorType_INT8,
TensorType_INT16}));
TEST_P(QuantizeConvModel2Test, VerifyConvQuantization) {
auto status = QuantizeModel(&builder_, &model_, tensor_type_, tensor_type_,
tensor_type_, &error_reporter_);
ASSERT_EQ(kTfLiteOk, status);
const auto& subgraph = model_.subgraphs[0];
auto conv_op = subgraph->operators[0].get();
@ -615,8 +675,10 @@ TEST_F(QuantizeConvModel2Test, VerifyConvQuantization) {
const auto output_tensor =
subgraph->tensors[conv_op->outputs[output_tensor_idx]].get();
EXPECT_EQ(bias_tensor->type, TensorType_INT32);
EXPECT_EQ(input_tensor->type, TensorType_INT8);
EXPECT_EQ(bias_tensor->type, tensor_type_ == TensorType_INT8
? TensorType_INT32
: TensorType_INT64);
EXPECT_EQ(input_tensor->type, tensor_type_);
EXPECT_EQ(weights_tensor->type, TensorType_INT8);
ASSERT_TRUE(weights_tensor->quantization);
@ -644,17 +706,28 @@ TEST_F(QuantizeConvModel2Test, VerifyConvQuantization) {
}
const auto bias_buffer = model_.buffers[bias_tensor->buffer].get();
ASSERT_EQ(bias_buffer->data.size(), sizeof(int32_t) * bias_tensor->shape[0]);
const int32_t* bias_values =
reinterpret_cast<int32_t*>(bias_buffer->data.data());
auto control_size = tensor_type_ == TensorType_INT8
? sizeof(int32_t) * bias_tensor->shape[0]
: sizeof(int64_t) * bias_tensor->shape[0];
ASSERT_EQ(bias_buffer->data.size(), control_size);
const auto original_bias_buffer =
readonly_model_->buffers()->Get(bias_tensor->buffer);
const float* bias_float_buffer =
reinterpret_cast<const float*>(original_bias_buffer->data()->data());
for (size_t i = 0; i < out_channel_size; i++) {
auto dequantized_value = bias_values[i] * bias_scales[i];
EXPECT_NEAR(dequantized_value, bias_float_buffer[i], bias_scales[i] / 2);
if (tensor_type_ == TensorType_INT8) {
int32_t* bias_values = reinterpret_cast<int32_t*>(bias_buffer->data.data());
for (size_t i = 0; i < out_channel_size; i++) {
auto dequantized_value = bias_values[i] * bias_scales[i];
EXPECT_NEAR(dequantized_value, bias_float_buffer[i], bias_scales[i] / 2);
}
} else if (tensor_type_ == TensorType_INT16) {
int64_t* bias_values = reinterpret_cast<int64_t*>(bias_buffer->data.data());
for (size_t i = 0; i < out_channel_size; i++) {
auto dequantized_value = bias_values[i] * bias_scales[i];
EXPECT_NEAR(dequantized_value, bias_float_buffer[i], bias_scales[i] / 2);
}
}
const auto weights_buffer = model_.buffers[weights_tensor->buffer].get();
@ -695,8 +768,9 @@ class QuantizeSoftmaxTest : public QuantizeModelTest {
};
TEST_F(QuantizeSoftmaxTest, VerifySoftmaxQuantization) {
auto status = QuantizeModel(&builder_, &model_, TensorType_INT8,
TensorType_INT8, &error_reporter_);
auto status =
QuantizeModel(&builder_, &model_, TensorType_INT8, TensorType_INT8,
TensorType_INT8, &error_reporter_);
ASSERT_EQ(kTfLiteOk, status);
const auto& subgraph = model_.subgraphs[0];
@ -755,8 +829,9 @@ class QuantizeAvgPoolTest : public QuantizeModelTest {
};
TEST_F(QuantizeAvgPoolTest, VerifyAvgPoolQuantization) {
auto status = QuantizeModel(&builder_, &model_, TensorType_INT8,
TensorType_INT8, &error_reporter_);
auto status =
QuantizeModel(&builder_, &model_, TensorType_INT8, TensorType_INT8,
TensorType_INT8, &error_reporter_);
ASSERT_EQ(kTfLiteOk, status);
const auto& subgraph = model_.subgraphs[0];
@ -816,8 +891,9 @@ class QuantizeMultiInputAddWithReshapeTest : public QuantizeModelTest {
};
TEST_F(QuantizeMultiInputAddWithReshapeTest, VerifyReshapeQuantization) {
auto status = QuantizeModel(&builder_, &model_, TensorType_INT8,
TensorType_INT8, &error_reporter_);
auto status =
QuantizeModel(&builder_, &model_, TensorType_INT8, TensorType_INT8,
TensorType_INT8, &error_reporter_);
ASSERT_EQ(kTfLiteOk, status);
// Verify Reshape is quantized.
@ -863,8 +939,9 @@ TEST_F(QuantizeMultiInputAddWithReshapeTest, VerifyReshapeQuantization) {
}
TEST_F(QuantizeMultiInputAddWithReshapeTest, VerifyAddQuantization) {
auto status = QuantizeModel(&builder_, &model_, TensorType_INT8,
TensorType_INT8, &error_reporter_);
auto status =
QuantizeModel(&builder_, &model_, TensorType_INT8, TensorType_INT8,
TensorType_INT8, &error_reporter_);
ASSERT_EQ(kTfLiteOk, status);
// Verify ADD is quantized.
@ -923,8 +1000,9 @@ class QuantizeConstInputTest : public QuantizeModelTest {
};
TEST_F(QuantizeConstInputTest, VerifyConstOpInput) {
auto status = QuantizeModel(&builder_, &model_, TensorType_INT8,
TensorType_INT8, &error_reporter_);
auto status =
QuantizeModel(&builder_, &model_, TensorType_INT8, TensorType_INT8,
TensorType_INT8, &error_reporter_);
ASSERT_EQ(kTfLiteOk, status);
// Verify ConstOp is quantized.
@ -965,8 +1043,9 @@ class QuantizeArgMaxTest : public QuantizeModelTest {
};
TEST_F(QuantizeArgMaxTest, VerifyArgMax) {
auto status = QuantizeModel(&builder_, &model_, TensorType_INT8,
TensorType_INT8, &error_reporter_);
auto status =
QuantizeModel(&builder_, &model_, TensorType_INT8, TensorType_INT8,
TensorType_INT8, &error_reporter_);
ASSERT_EQ(kTfLiteOk, status);
const auto& subgraph = model_.subgraphs[0];
@ -1008,8 +1087,9 @@ class QuantizeLSTMTest : public QuantizeModelTest {
TEST_F(QuantizeLSTMTest, VerifyLSTM) {
// Quantize model.
auto status = QuantizeModel(&builder_, &model_, TensorType_FLOAT32,
TensorType_FLOAT32, &error_reporter_);
auto status =
QuantizeModel(&builder_, &model_, TensorType_FLOAT32, TensorType_FLOAT32,
TensorType_INT8, &error_reporter_);
ASSERT_EQ(kTfLiteOk, status);
// Read expected model.
@ -1067,8 +1147,9 @@ class QuantizeLSTM2Test : public QuantizeModelTest {
TEST_F(QuantizeLSTM2Test, VerifyLSTM) {
// Quantize model.
auto status = QuantizeModel(&builder_, &model_, TensorType_FLOAT32,
TensorType_FLOAT32, &error_reporter_);
auto status =
QuantizeModel(&builder_, &model_, TensorType_FLOAT32, TensorType_FLOAT32,
TensorType_INT8, &error_reporter_);
ASSERT_EQ(kTfLiteOk, status);
// Read expected model.
@ -1126,8 +1207,9 @@ class QuantizeSVDFTest : public QuantizeModelTest {
TEST_F(QuantizeSVDFTest, VerifySVDF) {
// Quantize model.
auto status = QuantizeModel(&builder_, &model_, TensorType_INT8,
TensorType_INT8, &error_reporter_);
auto status =
QuantizeModel(&builder_, &model_, TensorType_INT8, TensorType_INT8,
TensorType_INT8, &error_reporter_);
ASSERT_EQ(kTfLiteOk, status);
// Read expected model.
@ -1184,8 +1266,9 @@ class QuantizeFCTest : public QuantizeModelTest {
};
TEST_F(QuantizeFCTest, VerifyFC) {
auto status = QuantizeModel(&builder_, &model_, TensorType_INT8,
TensorType_INT8, &error_reporter_);
auto status =
QuantizeModel(&builder_, &model_, TensorType_INT8, TensorType_INT8,
TensorType_INT8, &error_reporter_);
ASSERT_EQ(kTfLiteOk, status);
const auto& subgraph = model_.subgraphs[0];
@ -1236,7 +1319,7 @@ class QuantizeCustomOpTest : public QuantizeModelTest {
TEST_F(QuantizeCustomOpTest, VerifyMixedQuantization) {
auto status =
QuantizeModel(&builder_, &model_, TensorType_INT8, TensorType_INT8,
/*allow_float=*/true, &error_reporter_);
/*allow_float=*/true, TensorType_INT8, &error_reporter_);
ASSERT_EQ(kTfLiteOk, status);
const auto& subgraph = model_.subgraphs[0];
auto float_graph = readonly_model_->subgraphs()->Get(0);
@ -1270,7 +1353,8 @@ class QuantizePackTest : public QuantizeModelTest {
};
TEST_F(QuantizePackTest, VerifyPack) {
auto status = QuantizeModel(&builder_, &model_, &error_reporter_);
auto status =
QuantizeModel(&builder_, &model_, TensorType_INT8, &error_reporter_);
ASSERT_EQ(kTfLiteOk, status);
@ -1334,7 +1418,8 @@ class QuantizeMinimumMaximumTest
};
TEST_P(QuantizeMinimumMaximumTest, VerifyMinimumMaximum) {
auto status = QuantizeModel(&builder_, &model_, &error_reporter_);
auto status =
QuantizeModel(&builder_, &model_, TensorType_INT8, &error_reporter_);
ASSERT_EQ(kTfLiteOk, status);
const auto& subgraph = model_.subgraphs[0];
@ -1415,7 +1500,8 @@ class QuantizeUnpackTest : public QuantizeModelTest {
}
};
TEST_F(QuantizeUnpackTest, VerifyUnpack) {
auto status = QuantizeModel(&builder_, &model_, &error_reporter_);
auto status =
QuantizeModel(&builder_, &model_, TensorType_INT8, &error_reporter_);
ASSERT_EQ(kTfLiteOk, status);