Added an option TFLITE_BUILTINS_ACTIVATIONS_INT16_WEIGHTS_INT8 to
enable symmetric quantization with activations in 16-bit and weights in 8-bit.
This commit is contained in:
parent 2e98e89091
commit a7899d7544
@@ -93,6 +93,12 @@ class OpsSet(enum.Enum):
   # quantized implementations.
   TFLITE_BUILTINS_INT8 = "TFLITE_BUILTINS_INT8"
 
+  # Convert model using only TensorFlow Lite operations with quantized int8 weights
+  # and int16 activations.
+  # Specifying this will throw an error for operations that do not yet have
+  # quantized implementations.
+  TFLITE_BUILTINS_ACTIVATIONS_INT16_WEIGHTS_INT8 = "TFLITE_BUILTINS_ACTIVATIONS_INT16_WEIGHTS_INT8"
+
   def __str__(self):
     return self.value
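A minimal usage sketch of the new ops set, assuming the TF 2.x converter entry point; the saved-model path, input shape, and calibration generator below are hypothetical, not part of this change:

```python
import numpy as np
import tensorflow as tf

def representative_dataset():
  # Hypothetical calibration generator: yields batches matching the
  # model's input signature.
  for _ in range(100):
    yield [np.random.rand(1, 224, 224, 3).astype(np.float32)]

converter = tf.lite.TFLiteConverter.from_saved_model("/tmp/model")  # assumed path
converter.optimizations = [tf.lite.Optimize.DEFAULT]
converter.representative_dataset = representative_dataset
converter.target_spec.supported_ops = [
    tf.lite.OpsSet.TFLITE_BUILTINS_ACTIVATIONS_INT16_WEIGHTS_INT8
]
tflite_model = converter.convert()
```

Note that the target-required check added below compares with exact set equality, so this ops set is meant to be listed on its own rather than mixed with other target ops sets.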
@@ -224,6 +224,10 @@ class TFLiteConverterBase(object):
             self.target_spec.supported_ops) or
            self._smallest_supported_type() == constants.INT8)
 
+  def _is_int16x8_target_required(self):
+    return (set([OpsSet.TFLITE_BUILTINS_ACTIVATIONS_INT16_WEIGHTS_INT8]) ==
+            set(self.target_spec.supported_ops))
+
   def _smallest_supported_type(self):
     if self.target_spec.supported_types:
       return min(self.target_spec.supported_types, key=lambda x: x.size)
@@ -238,7 +242,9 @@ class TFLiteConverterBase(object):
     ]))
 
   def _is_post_training_optimize(self):
-    return self._is_int8_target_required() or self._any_optimization_enabled()
+    return self._is_int8_target_required() or \
+        self._is_int16x8_target_required() or \
+        self._any_optimization_enabled()
 
   def _is_int8_weight_only_quantize(self):
     return (self._is_post_training_optimize() and
@@ -255,11 +261,12 @@ class TFLiteConverterBase(object):
 
   def _calibrate_quantize_model(self, result, inference_input_type,
                                 inference_output_type, enable_mlir_quantizer):
-    allow_float = not self._is_int8_target_required()
+    allow_float = not self._is_int8_target_required() and not self._is_int16x8_target_required()
     calibrate_quantize = _calibrator.Calibrator(result)
+    activations_type = constants.INT16 if self._is_int16x8_target_required() else constants.INT8
     return calibrate_quantize.calibrate_and_quantize(
         self.representative_dataset.input_gen, inference_input_type,
-        inference_output_type, allow_float, enable_mlir_quantizer)
+        inference_output_type, allow_float, activations_type, enable_mlir_quantizer)
 
   def _get_base_converter_args(self):
     """Returns the base converter args.
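A hedged, simplified restatement of how the patched `_calibrate_quantize_model` derives its calibration arguments from the target spec (stand-alone illustrative code, not the converter's internals; the int8 check is simplified relative to the real one):

```python
INT16X8_OPS_SET = "TFLITE_BUILTINS_ACTIVATIONS_INT16_WEIGHTS_INT8"
INT8_OPS_SET = "TFLITE_BUILTINS_INT8"

def calibration_args(supported_ops):
  # Mirrors the diff's logic: either integer-only target disables the
  # float fallback; the 16x8 target switches activations to int16.
  int8_required = {INT8_OPS_SET} == set(supported_ops)
  int16x8_required = {INT16X8_OPS_SET} == set(supported_ops)
  allow_float = not int8_required and not int16x8_required
  activations_type = "INT16" if int16x8_required else "INT8"
  return allow_float, activations_type

print(calibration_args([INT16X8_OPS_SET]))  # (False, 'INT16')
```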
@@ -30,6 +30,7 @@ INT64 = dtypes.int64
 STRING = dtypes.string
 QUANTIZED_UINT8 = dtypes.uint8
 INT8 = dtypes.int8
+INT16 = dtypes.int16
 COMPLEX64 = dtypes.complex64
 TENSORFLOW_GRAPHDEF = _toco_flags_pb2.TENSORFLOW_GRAPHDEF
 TFLITE = _toco_flags_pb2.TFLITE
@@ -43,6 +44,7 @@ _tf_export(v1=["lite.constants.STRING"]).export_constant(__name__, "STRING")
 _tf_export(v1=["lite.constants.QUANTIZED_UINT8"]).export_constant(
     __name__, "QUANTIZED_UINT8")
 _tf_export(v1=["lite.constants.INT8"]).export_constant(__name__, "INT8")
+_tf_export(v1=["lite.constants.INT16"]).export_constant(__name__, "INT16")
 _tf_export(v1=["lite.constants.TFLITE"]).export_constant(__name__, "TFLITE")
 _tf_export(v1=["lite.constants.GRAPHVIZ_DOT"]).export_constant(
     __name__, "GRAPHVIZ_DOT")
@@ -62,6 +64,7 @@ _allowed_symbols = [
     "STRING",
     "QUANTIZED_UINT8",
     "INT8",
+    "INT16",
     "COMPLEX64",
     "TENSORFLOW_GRAPHDEF",
     "TFLITE",
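A quick sanity check of the new constant (assumes a build containing the v1 export path added above):

```python
import tensorflow as tf

# lite_constants.INT16 is dtypes.int16, so the exported v1 symbol should
# compare equal to tf.int16.
assert tf.compat.v1.lite.constants.INT16 == tf.int16
```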
@@ -769,9 +769,13 @@ class FromSessionTest(TestModels, parameterized.TestCase):
     self.assertLess(len(quantized_tflite), len(float_tflite))
 
   @parameterized.named_parameters(
-      ('EnableMlirConverter', True),  # enable mlir
-      ('DisableMlirConverter', False))  # disable mlir
-  def testCalibrateAndQuantizeBuiltinInt8(self, enable_mlir):
+      # Quantize model to Int8: with enable mlir
+      ('UseTfliteBuiltinsIntEnableMLIR', [lite.OpsSet.TFLITE_BUILTINS_INT8], True),
+      # Quantize model to Int8: with disable mlir
+      ('UseTfliteBuiltinsIntDisableMLIR', [lite.OpsSet.TFLITE_BUILTINS_INT8], False),
+      # Quantize model to Int16: with disable mlir
+      ('UseTfliteBuiltinsInt16DisableMLIR', [lite.OpsSet.TFLITE_BUILTINS_ACTIVATIONS_INT16_WEIGHTS_INT8], False))
+  def testCalibrateAndQuantizeBuiltinInt(self, supported_ops, enable_mlir):
     with ops.Graph().as_default():
       inp, output, calibration_gen = self._getCalibrationQuantizeModel()
       sess = session.Session()
@@ -787,9 +791,7 @@ class FromSessionTest(TestModels, parameterized.TestCase):
     quantized_converter = lite.TFLiteConverter.from_session(
         sess, [inp], [output])
     quantized_converter.experimental_new_converter = enable_mlir
-    quantized_converter.target_spec.supported_ops = [
-        lite.OpsSet.TFLITE_BUILTINS_INT8
-    ]
+    quantized_converter.target_spec.supported_ops = supported_ops
     quantized_converter.representative_dataset = calibration_gen
     quantized_tflite = quantized_converter.convert()
     self.assertTrue(quantized_tflite)
@@ -204,6 +204,7 @@ PyObject* CalibrationWrapper::SetTensor(int index, PyObject* value) {
 PyObject* CalibrationWrapper::QuantizeModel(int input_py_type,
                                             int output_py_type,
                                             bool allow_float,
+                                            int activations_py_type,
                                             bool enable_mlir_quantizer) {
   if (NoOpModel(*model_)) {
     return python_utils::ConvertToPyString(model_str_->data(),
@@ -212,6 +213,9 @@ PyObject* CalibrationWrapper::QuantizeModel(int input_py_type,
 
   TfLiteType input_type = python_utils::TfLiteTypeFromPyType(input_py_type);
   TfLiteType output_type = python_utils::TfLiteTypeFromPyType(output_py_type);
+  TfLiteType activations_type =
+      python_utils::TfLiteTypeFromPyType(activations_py_type);
+
   if (input_type == kTfLiteNoType || output_type == kTfLiteNoType) {
     PyErr_SetString(PyExc_ValueError,
                     "Input/output type cannot be kTfLiteNoType");
@@ -230,7 +234,7 @@ PyObject* CalibrationWrapper::QuantizeModel(int input_py_type,
     status = tflite::optimize::QuantizeModel(
         &builder, tflite_model.get(), TfLiteTypeToSchemaType(input_type),
         TfLiteTypeToSchemaType(output_type), allow_float,
-        error_reporter_.get());
+        TfLiteTypeToSchemaType(activations_type), error_reporter_.get());
   }
 
   if (status != kTfLiteOk) {
@@ -262,7 +266,7 @@ PyObject* CalibrationWrapper::QuantizeModel(int input_py_type,
   auto status = tflite::optimize::QuantizeModel(
       &builder, tflite_model.get(), TfLiteTypeToSchemaType(input_type),
       TfLiteTypeToSchemaType(output_type), allow_float, {op_name},
-      error_reporter_.get());
+      TensorType_INT8, error_reporter_.get());
   if (status != kTfLiteOk) {
     error_reporter_->exception();
     return nullptr;
@@ -60,7 +60,8 @@ class CalibrationWrapper {
   PyObject* FeedTensor(PyObject* input_value);
 
   PyObject* QuantizeModel(int input_py_type, int output_py_type,
-                          bool allow_float, bool enable_mlir_quantizer = false);
+                          bool allow_float, int activations_py_type,
+                          bool enable_mlir_quantizer = false);
 
   // Allows quantizing only the operator that produces the tensor with name
   // operator_output_name. (This can be used to help debug.).
@@ -19,6 +19,7 @@ from __future__ import print_function
 
 import numpy as np
 from tensorflow.python.util.lazy_loader import LazyLoader
+from tensorflow.lite.python import lite_constants
 
 # Lazy load since some of the performance benchmark skylark rules
 # break dependencies. Must use double quotes to match code internal rewrite
@@ -55,7 +56,8 @@ class Calibrator(object):
       raise ValueError("Failed to parse the model.")
 
   def calibrate_and_quantize(self, dataset_gen, input_type, output_type,
-                             allow_float, enable_mlir_quantizer=False):
+                             allow_float, activations_type = lite_constants.INT8,
+                             enable_mlir_quantizer=False):
     """Calibrates the model with specified generator and then quantizes it.
 
     Returns:
@@ -69,6 +71,7 @@ class Calibrator(object):
         computation, useful when targeting an integer-only backend.
         If False, an error will be thrown if an operation cannot be
         quantized, otherwise the model will fallback to float ops.
+      activations_type: A tf.dtype representing the desired type for activations
       enable_mlir_quantizer: A boolean. True if wants to use mlir quantizer to
         quantize the calibrated model.
     """
@@ -78,6 +81,7 @@ class Calibrator(object):
     return self._calibrator.QuantizeModel(
         np.dtype(input_type.as_numpy_dtype()).num,
         np.dtype(output_type.as_numpy_dtype()).num, allow_float,
+        np.dtype(activations_type.as_numpy_dtype()).num,
        enable_mlir_quantizer)
 
   def calibrate_and_quantize_single(self, dataset_gen, input_type, output_type,
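An end-to-end sketch of driving the new parameter (the model file and input shape are hypothetical; this mirrors the calls exercised in the tests below):

```python
import numpy as np
from tensorflow.lite.python import lite_constants as constants
from tensorflow.lite.python.optimize import calibrator as _calibrator

float_model = open('/tmp/model.tflite', 'rb').read()  # assumed model file

def input_gen():
  for _ in range(10):
    yield [np.ones((1, 5, 5, 3), dtype=np.float32)]  # assumed input shape

quantizer = _calibrator.Calibrator(float_model)
quantized_model = quantizer.calibrate_and_quantize(
    input_gen, constants.FLOAT, constants.FLOAT,
    allow_float=False, activations_type=constants.INT16)
```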
@@ -33,9 +33,13 @@ from tensorflow.python.platform import test
 class CalibratorTest(test_util.TensorFlowTestCase, parameterized.TestCase):
 
   @parameterized.named_parameters(
-      ('EnableMlirQuantizer', True),  # enable mlir quantizer
-      ('DisableMlirQuantizer', False))  # disable mlir quantizer
-  def test_calibration_with_quantization(self, enable_mlir):
+      # Activation type Int8 - enable mlir quantizer
+      ('UseActivationTypeInt8EnabledMlir', constants.INT8, True),
+      # Activation type Int8 - disable mlir quantizer
+      ('UseActivationTypeInt8DisabledMlir', constants.INT8, False),
+      # Activation type Int16
+      ('UseActivationTypeInt16', constants.INT16, False))
+  def test_calibration_with_quantization(self, activations_type, enable_mlir):
     model_path = resource_loader.get_path_to_datafile(
         'test_data/mobilenet_like_model.bin')
     float_model = open(model_path, 'rb').read()
@@ -49,13 +53,18 @@ class CalibratorTest(test_util.TensorFlowTestCase, parameterized.TestCase):
     quantized_model = quantizer.calibrate_and_quantize(input_gen,
                                                        constants.FLOAT,
                                                        constants.FLOAT, False,
+                                                       activations_type,
                                                        enable_mlir)
     self.assertIsNotNone(quantized_model)
 
   @parameterized.named_parameters(
-      ('EnableMlirQuantizer', True),  # enable mlir quantizer
-      ('DisableMlirQuantizer', False))  # disable mlir quantizer
-  def test_calibration_with_quantization_allow_float(self, enable_mlir):
+      # Activation type Int8 - enable mlir quantizer
+      ('UseActivationTypeInt8EnabledMlir', constants.INT8, True),
+      # Activation type Int8 - disable mlir quantizer
+      ('UseActivationTypeInt8DisableMlir', constants.INT8, False),
+      # Activation type Int16 - disable mlir quantizer
+      ('UseActivationTypeInt16', constants.INT16, False))
+  def test_calibration_with_quantization_allow_float(self, activations_type, enable_mlir):
     model_path = resource_loader.get_path_to_datafile(
         'test_data/mobilenet_like_model.bin')
     float_model = open(model_path, 'rb').read()
@@ -69,6 +78,7 @@ class CalibratorTest(test_util.TensorFlowTestCase, parameterized.TestCase):
     quantized_model = quantizer.calibrate_and_quantize(input_gen,
                                                        constants.FLOAT,
                                                        constants.FLOAT, True,
+                                                       activations_type,
                                                        enable_mlir)
     self.assertIsNotNone(quantized_model)
 
@@ -88,9 +98,13 @@ class CalibratorTest(test_util.TensorFlowTestCase, parameterized.TestCase):
     self.assertIsNotNone(quantized_model)
 
   @parameterized.named_parameters(
-      ('EnableMlirQuantizer', True),  # enable mlir quantizer
-      ('DisableMlirQuantizer', False))  # disable mlir quantizer
-  def test_calibration_with_quantization_multiple_inputs(self, enable_mlir):
+      # Activation type Int8 - enable mlir quantizer
+      ('UseActivationTypeInt8 - EnableMlirQuantizer', constants.INT8, True),
+      # Activation type Int8 - disable mlir quantizer
+      ('UseActivationTypeInt8 - DisableMlirQuantizer', constants.INT8, False),
+      # Activation type Int16 - disable mlir quantizer
+      ('UseActivationTypeInt16 - DisableEnableMlirQuantizer', constants.INT16, False))
+  def test_calibration_with_quantization_multiple_inputs(self, activations_type, enable_mlir):
     # Load multi add model from test data.
     # This model has 4 inputs of size (1, 8, 8, 3).
     model_path = resource_loader.get_path_to_datafile(
@@ -106,6 +120,7 @@ class CalibratorTest(test_util.TensorFlowTestCase, parameterized.TestCase):
     quantized_model = quantizer.calibrate_and_quantize(input_gen,
                                                        constants.FLOAT,
                                                        constants.FLOAT, False,
+                                                       activations_type,
                                                        enable_mlir)
     self.assertIsNotNone(quantized_model)
 
@@ -148,7 +163,8 @@ class CalibratorTest(test_util.TensorFlowTestCase, parameterized.TestCase):
 
     with self.assertRaisesRegex(ValueError, 'Size mismatch'):
       quantizer.calibrate_and_quantize(input_gen, constants.FLOAT,
-                                       constants.FLOAT, False, enable_mlir)
+                                       constants.FLOAT, False,
+                                       enable_mlir)
 
   @parameterized.named_parameters(
       ('EnableMlirQuantizer', True),  # enable mlir quantizer
@@ -166,7 +182,8 @@ class CalibratorTest(test_util.TensorFlowTestCase, parameterized.TestCase):
 
     with self.assertRaises(ValueError):
       quantizer.calibrate_and_quantize(input_gen, constants.FLOAT,
-                                       constants.FLOAT, False, enable_mlir)
+                                       constants.FLOAT, False,
+                                       constants.INT8, enable_mlir)
 
 
 if __name__ == '__main__':
@@ -64,6 +64,7 @@ OperatorProperty GetOperatorProperty(const ModelT* model, int subgraph_index,
       property.inputs = {{0, {}}, {1, {}}};
       property.outputs = {{0, {}}};
       property.version = 2;
+      property.quantize_input_as_activations = true;
       break;
     case BuiltinOperator_ARG_MAX:
       property.inputs = {{0, {}}};
@@ -176,7 +177,7 @@ OperatorProperty GetOperatorProperty(const ModelT* model, int subgraph_index,
       // LogSoftmax requires output with 16/256 as scale and 127 as zero point.
       TensorProperty tensor_property;
       tensor_property.restriction = true;
-      tensor_property.restricted_value = {16.0 / 256.0, 127};
+      tensor_property.restricted_value_int8 = {16.0 / 256.0, 127};
       property.outputs = {{0, tensor_property}};
       property.version = 2;
       break;
@@ -186,7 +187,8 @@ OperatorProperty GetOperatorProperty(const ModelT* model, int subgraph_index,
       // Logistic requires output with 1/256 as scale and -128 as zero point.
       TensorProperty tensor_property;
       tensor_property.restriction = true;
-      tensor_property.restricted_value = {1 / 256.0, -128};
+      tensor_property.restricted_value_int8 = {1 / 256.0, -128};
+      tensor_property.restricted_value_int16 = {1 / 32768.0, 0};
       property.outputs = {{0, tensor_property}};
       property.version = 2;
       break;
@@ -741,7 +743,7 @@ OperatorProperty GetOperatorProperty(const ModelT* model, int subgraph_index,
       // L2 Norm requires output with 1/128 as scale and 0 as zero point.
       TensorProperty tensor_property;
       tensor_property.restriction = true;
-      tensor_property.restricted_value = {1 / 128.0, 0};
+      tensor_property.restricted_value_int8 = {1 / 128.0, 0};
       property.outputs = {{0, tensor_property}};
       property.version = 2;
       break;
@@ -756,6 +758,7 @@ OperatorProperty GetOperatorProperty(const ModelT* model, int subgraph_index,
       property.arbitrary_inputs = true;
       property.outputs = {{0, {}}};
       property.restrict_same_input_output_scale = true;
+      property.quantize_input_as_activations = true;
       property.version = 2;
       break;
     case BuiltinOperator_MEAN:
@@ -767,6 +770,7 @@ OperatorProperty GetOperatorProperty(const ModelT* model, int subgraph_index,
       property.arbitrary_inputs = true;
      property.outputs = {{0, {}}};
       property.restrict_same_input_output_scale = true;
+      property.quantize_input_as_activations = true;
       property.version = 2;
       break;
     case BuiltinOperator_MUL:
@@ -778,6 +782,7 @@ OperatorProperty GetOperatorProperty(const ModelT* model, int subgraph_index,
       property.arbitrary_inputs = true;
       property.outputs = {{0, {}}};
       property.restrict_same_input_output_scale = true;
+      property.quantize_input_as_activations = true;
       property.version = 2;
       break;
     case BuiltinOperator_PAD:
@@ -840,7 +845,8 @@ OperatorProperty GetOperatorProperty(const ModelT* model, int subgraph_index,
       // Softmax requires output with 1/256 as scale and -128 as zero point.
       TensorProperty tensor_property;
       tensor_property.restriction = true;
-      tensor_property.restricted_value = {1 / 256.0, -128};
+      tensor_property.restricted_value_int8 = {1 / 256.0, -128};
+      tensor_property.restricted_value_int16 = {1 / 32768.0, 0};
       property.outputs = {{0, tensor_property}};
       property.version = 2;
       break;
@@ -866,7 +872,8 @@ OperatorProperty GetOperatorProperty(const ModelT* model, int subgraph_index,
       // Tanh requires output with 1/128 as scale and 0 as zero point.
       TensorProperty tensor_property;
       tensor_property.restriction = true;
-      tensor_property.restricted_value = {1 / 128.0, 0};
+      tensor_property.restricted_value_int8 = {1 / 128.0, 0};
+      tensor_property.restricted_value_int16 = {1 / 32768.0, 0};
       property.outputs = {{0, tensor_property}};
       property.version = 2;
       break;
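For reference, the hardcoded output parameters above gathered per activation type (values copied from the hunks in this file; ops shown without an int16 entry were not given one in this change):

```python
# (op, activations type) -> (required output scale, zero point), per the
# restricted_value_int8 / restricted_value_int16 pairs above.
RESTRICTED_OUTPUT_PARAMS = {
    ("LOG_SOFTMAX", "INT8"): (16.0 / 256.0, 127),
    ("LOGISTIC", "INT8"): (1 / 256.0, -128),
    ("LOGISTIC", "INT16"): (1 / 32768.0, 0),
    ("L2_NORMALIZATION", "INT8"): (1 / 128.0, 0),
    ("SOFTMAX", "INT8"): (1 / 256.0, -128),
    ("SOFTMAX", "INT16"): (1 / 32768.0, 0),
    ("TANH", "INT8"): (1 / 128.0, 0),
    ("TANH", "INT16"): (1 / 32768.0, 0),
}
```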
@@ -43,7 +43,8 @@ struct TensorProperty {
   // Constraints.
   bool restriction = false;
   // scale/zero_point hardcoded.
-  std::pair<float, int> restricted_value = {0.0, 0};
+  std::pair<float, int> restricted_value_int8 = {0.0, 0};
+  std::pair<float, int> restricted_value_int16 = {0.0, 0};
 
   // Use derived scale.
   bool use_derived_scale = false;
@@ -93,6 +94,13 @@ struct OperatorProperty {
 
   // Op version.
   int version = 1;
+
+  // When we quantize activations into 16 bit and weights into 8 bit,
+  // we want to quantize all inputs, including constant tensors,
+  // for the operators like Add, Mul into 16-bit as well. The constant
+  // inputs are quantized as weights and this variable indicates
+  // that we want to do quantizations of these tensors as activations.
+  bool quantize_input_as_activations = false;
 };
 
 OperatorProperty GetOperatorProperty(const ModelT* model, int subgraph_index,
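A hedged restatement of the decision the new flag drives (illustrative names, not the library's API):

```python
def constant_input_target_type(quantize_input_as_activations, activations_type):
  # Under the 16x8 mode, ops that set quantize_input_as_activations (Add,
  # Mul, ...) get their constant inputs quantized like activations (int16)
  # instead of like weights (int8).
  if quantize_input_as_activations and activations_type == "INT16":
    return "INT16"
  return "INT8"

assert constant_input_target_type(True, "INT16") == "INT16"
assert constant_input_target_type(False, "INT16") == "INT8"
```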
@@ -20,7 +20,6 @@ limitations under the License.
 #include <string>
 
 #include "absl/memory/memory.h"
-#include "third_party/eigen3/Eigen/Core"
 #include "tensorflow/lite/c/common.h"
 #include "tensorflow/lite/core/api/error_reporter.h"
 #include "tensorflow/lite/kernels/internal/quantization_util.h"
@@ -30,6 +29,7 @@ limitations under the License.
 #include "tensorflow/lite/minimal_logging.h"
 #include "tensorflow/lite/schema/schema_generated.h"
 #include "tensorflow/lite/tools/optimize/model_utils.h"
+#include "third_party/eigen3/Eigen/Core"
 
 namespace tflite {
 namespace optimize {
@@ -85,6 +85,46 @@ void GetAsymmetricQuantizationParams(
   quantization_params->zero_point = std::vector<int64_t>(1, zero_point);
 }
 
+void GetSymmetricQuantizationParams(
+    float min, float max, const int half_quant_range,
+    QuantizationParametersT* quantization_params) {
+  // Adjust the boundaries to guarantee 0 is included.
+  min = std::min(min, 0.0f);
+  max = std::max(max, 0.0f);
+  const float scale = std::max(std::abs(max), std::abs(min)) / half_quant_range;
+  int64_t zero_point = 0;
+  quantization_params->min = std::vector<float>(1, min);
+  quantization_params->max = std::vector<float>(1, max);
+  quantization_params->scale = std::vector<float>(1, scale);
+  quantization_params->zero_point = std::vector<int64_t>(1, 0);
+}
+
+TfLiteStatus GetQuantizationParams(TensorT* tensor, TensorType activations_type,
+                                   QuantizationParametersT* quantization_params,
+                                   ErrorReporter* error_reporter) {
+  if (activations_type == TensorType_INT8) {
+    GetAsymmetricQuantizationParams(
+        tensor->quantization->min[0], tensor->quantization->max[0],
+        std::numeric_limits<int8_t>::min(), std::numeric_limits<int8_t>::max(),
+        quantization_params);
+  } else if (activations_type == TensorType_INT16) {
+    float range = std::max(std::abs(tensor->quantization->min[0]),
+                           std::abs(tensor->quantization->max[0]));
+    const float quantized_range = 32767.0;
+    const float scale = range / quantized_range;
+    quantization_params->min = std::vector<float>(1, -range);
+    quantization_params->max = std::vector<float>(1, range);
+    quantization_params->scale = std::vector<float>(1, scale);
+    quantization_params->zero_point = std::vector<int64_t>(1, 0);
+  } else {
+    error_reporter->Report(
+        "Unsupported activation type for quantize-activation: %s",
+        activations_type);
+    return kTfLiteError;
+  }
+  return kTfLiteOk;
+}
+
 // Set the max and min quantization parameter for a single tensor given its
 // values.
 void FillSingleMinMax(const float* const input, const uint64_t input_size,
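Numerically, the two branches above reduce to the following (a sketch; the int8 branch is a schematic version of GetAsymmetricQuantizationParams, not a line-for-line port):

```python
def int16_activation_params(min_val, max_val):
  # Symmetric: the scale covers the larger magnitude side; zero_point is 0.
  range_ = max(abs(min_val), abs(max_val))
  return range_ / 32767.0, 0

def int8_activation_params(min_val, max_val):
  # Asymmetric: stretch [min, max] to include 0, then map onto [-128, 127].
  min_val, max_val = min(min_val, 0.0), max(max_val, 0.0)
  scale = (max_val - min_val) / 255.0
  zero_point = -128 - round(min_val / scale) if scale else 0
  return scale, zero_point

print(int16_activation_params(-3.0, 6.0))  # (~0.000183, 0)
print(int8_activation_params(-3.0, 6.0))   # (~0.0353, -43)
```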
@@ -536,6 +576,7 @@ TfLiteStatus SymmetricQuantizeTensorPerChannel(ModelT* model, TensorT* tensor,
                                         model, tensor, error_reporter);
 }
 
+template <class BiasType>
 TfLiteStatus SymmetricPerLayerBiasQuantize(ModelT* model, TensorT* tensor,
                                            float scaling_factor,
                                            ErrorReporter* error_reporter) {
@@ -548,25 +589,38 @@ TfLiteStatus SymmetricPerLayerBiasQuantize(ModelT* model, TensorT* tensor,
   uint64_t num_elements;
   TF_LITE_ENSURE_STATUS(NumElements(*tensor, &num_elements));
 
-  std::vector<int32_t> final_buffer(num_elements);
-  const int32_t kScale = std::numeric_limits<int32_t>::max();
+  std::vector<BiasType> final_buffer(num_elements);
+  const BiasType kScale = std::numeric_limits<BiasType>::max();
 
   for (size_t i = 0; i < num_elements; i++) {
-    const int32_t quantized_value = tflite::SafeCast<int32_t>(
+    const BiasType quantized_value = tflite::SafeCast<BiasType>(
         TfLiteRound(float_data[i] * scaling_factor_inv));
     final_buffer[i] = std::min(kScale, std::max(-kScale, quantized_value));
   }
 
   // Set the buffers and output type.
   uint8_t* uint8_buffer = reinterpret_cast<uint8_t*>(final_buffer.data());
-  size_t buffer_size = num_elements * sizeof(int32_t);
+  size_t buffer_size = num_elements * sizeof(BiasType);
   std::vector<float> scales(1, scaling_factor);
   std::vector<int64_t> zero_points(1, 0);
 
+  auto output_type = std::is_same<BiasType, std::int32_t>::value
+                         ? TensorType_INT32
+                         : TensorType_INT64;
   return AddQuantizationParams(scales, zero_points, 0, uint8_buffer,
-                               buffer_size, TensorType_INT32, model, tensor,
+                               buffer_size, output_type, model, tensor,
                                error_reporter);
 }
 
+template TfLiteStatus SymmetricPerLayerBiasQuantize<std::int32_t>(
+    ModelT* model, TensorT* tensor, float scaling_factor,
+    ErrorReporter* error_reporter);
+
+template TfLiteStatus SymmetricPerLayerBiasQuantize<std::int64_t>(
+    ModelT* model, TensorT* tensor, float scaling_factor,
+    ErrorReporter* error_reporter);
+
+template <class BiasType>
 TfLiteStatus SymmetricPerChannelBiasQuantize(ModelT* model, TensorT* tensor,
                                              float input_scale,
                                              const float* weight_scales,
@@ -583,14 +637,14 @@ TfLiteStatus SymmetricPerChannelBiasQuantize(ModelT* model, TensorT* tensor,
   uint64_t num_elements;
   TF_LITE_ENSURE_STATUS(NumElements(*tensor, &num_elements));
 
-  std::vector<int32_t> final_buffer(num_elements);
-  const int32_t kScale = std::numeric_limits<int32_t>::max();
+  std::vector<BiasType> final_buffer(num_elements);
+  const BiasType kScale = std::numeric_limits<BiasType>::max();
 
   for (int32_t channel_idx = 0; channel_idx < number_of_dimension;
        channel_idx++) {
     float scaling_factor = scales[channel_idx];
     float scaling_factor_inv = (scaling_factor == 0) ? 0 : 1.0 / scaling_factor;
-    const int32_t quantized_value = tflite::SafeCast<int32_t>(
+    const BiasType quantized_value = tflite::SafeCast<BiasType>(
         TfLiteRound(float_data[channel_idx] * scaling_factor_inv));
     final_buffer[channel_idx] =
         std::min(kScale, std::max(-kScale, quantized_value));
@@ -598,12 +652,26 @@ TfLiteStatus SymmetricPerChannelBiasQuantize(ModelT* model, TensorT* tensor,
 
   // Set the buffers and output type.
   uint8_t* uint8_buffer = reinterpret_cast<uint8_t*>(final_buffer.data());
-  size_t buffer_size = num_elements * sizeof(int32_t);
+  size_t buffer_size = num_elements * sizeof(BiasType);
   std::vector<int64_t> zero_point(scales.size(), 0);
 
+  auto output_type = std::is_same<BiasType, std::int32_t>::value
+                         ? TensorType_INT32
+                         : TensorType_INT64;
   return AddQuantizationParams(scales, zero_point, 0, uint8_buffer, buffer_size,
-                               TensorType_INT32, model, tensor, error_reporter);
+                               output_type, model, tensor, error_reporter);
 }
 
+template TfLiteStatus SymmetricPerChannelBiasQuantize<std::int64_t>(
+    ModelT* model, TensorT* tensor, float input_scale,
+    const float* weight_scales, int number_of_dimension,
+    ErrorReporter* error_reporter);
+
+template TfLiteStatus SymmetricPerChannelBiasQuantize<std::int32_t>(
+    ModelT* model, TensorT* tensor, float input_scale,
+    const float* weight_scales, int number_of_dimension,
+    ErrorReporter* error_reporter);
+
 TfLiteStatus QuantizeWeight(ModelT* model, TensorT* tensor, bool per_channel,
                             int per_axis_index, ErrorReporter* error_reporter) {
   // TODO(suharshs): Currently we conflate quantizing weights and constants. Its
@@ -645,12 +713,12 @@ float GetEffectiveScale(ModelT* model, SubGraphT* subgraph, int op_idx,
   return scale;
 }
 
-void QuantizeActivation(TensorT* tensor) {
-  GetAsymmetricQuantizationParams(
-      tensor->quantization->min[0], tensor->quantization->max[0],
-      std::numeric_limits<int8_t>::min(), std::numeric_limits<int8_t>::max(),
-      tensor->quantization.get());
-  tensor->type = TensorType_INT8;
+TfLiteStatus QuantizeActivation(TensorT* tensor, TensorType activations_type,
+                                ErrorReporter* error_reporter) {
+  TF_LITE_ENSURE_STATUS(GetQuantizationParams(
+      tensor, activations_type, tensor->quantization.get(), error_reporter));
+  tensor->type = activations_type;
+  return kTfLiteOk;
 }
 
 TfLiteStatus QuantizeActivationToInt16(TensorT* tensor, float scale) {
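Because 16-bit activations leave less headroom in a 32-bit accumulator, biases move to int64 in that mode; a hedged numpy sketch of the per-layer path above:

```python
import numpy as np

def symmetric_per_layer_bias_quantize(bias, input_scale, weight_scale,
                                      activations_type="INT8"):
  # int16 activations -> int64 bias buffer; int8 activations -> int32.
  bias_dtype = np.int64 if activations_type == "INT16" else np.int32
  scale = input_scale * weight_scale
  k = np.iinfo(bias_dtype).max
  q = np.clip(np.round(np.asarray(bias, dtype=np.float64) / scale), -k, k)
  return q.astype(bias_dtype), scale

q, s = symmetric_per_layer_bias_quantize([0.25, -1.5], 0.02, 0.005, "INT16")
print(q, s)  # [  2500 -15000] 0.0001
```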
@@ -113,12 +113,14 @@ TfLiteStatus SymmetricQuantizeFloatsToInt16(ModelT* model, TensorT* tensor,
                                             ErrorReporter* error_reporter);
 
 // Symmetrically quantized the bias for per-layer ops (i.e. FullyConnected).
+template <typename BiasType>
 TfLiteStatus SymmetricPerLayerBiasQuantize(ModelT* model, TensorT* tensor,
                                            float scaling_factor,
                                            ErrorReporter* error_reporter);
 
 // Symmetrically quantizes the bias for ops like Conv and DepthwiseConv.
 // The scale of bias if weight_per_channel_scale[channel] * input_scale.
+template <typename BiasType>
 TfLiteStatus SymmetricPerChannelBiasQuantize(ModelT* model, TensorT* tensor,
                                              float input_scale,
                                              const float* weight_scales,
@@ -135,8 +137,14 @@ float GetEffectiveScale(ModelT* model, SubGraphT* subgraph, int op_idx,
                         std::vector<int> intermediate_index,
                         std::vector<float> factors);
 
+// Return quantization parameters depending on activations type.
+TfLiteStatus GetQuantizationParams(TensorT* tensor, TensorType activations_type,
+                                   QuantizationParametersT* quantization_params,
+                                   ErrorReporter* error_reporter);
+
 // Quantize activation.
-void QuantizeActivation(TensorT* tensor);
+TfLiteStatus QuantizeActivation(TensorT* tensor, TensorType activations_type,
+                                ErrorReporter* error_reporter);
 
 // Quantize activation to 16bit.
 TfLiteStatus QuantizeActivationToInt16(TensorT* tensor, float scale);
@@ -701,7 +701,7 @@ TEST_F(QuantizationUtilsTest, SymmetricPerLayerBiasQuantize) {
   model->buffers.push_back(std::move(buffer));
 
   // Call and verify.
-  EXPECT_EQ(SymmetricPerLayerBiasQuantize(
+  EXPECT_EQ(SymmetricPerLayerBiasQuantize<int32_t>(
                 model.get(), model->subgraphs[0]->tensors[0].get(),
                 input_scale * weight_scale, &error_reporter_),
             kTfLiteOk);
@@ -759,7 +759,7 @@ TEST_F(QuantizationUtilsTest, SymmetricPerChannelBiasQuantize) {
   model->buffers.push_back(std::move(buffer));
 
   // Call and verify.
-  EXPECT_EQ(SymmetricPerChannelBiasQuantize(
+  EXPECT_EQ(SymmetricPerChannelBiasQuantize<int32_t>(
                 model.get(), model->subgraphs[0]->tensors[0].get(), input_scale,
                 weight_scales.data(), 2, &error_reporter_),
             kTfLiteOk);
@@ -42,7 +42,9 @@ bool CreateQuantizedModel(const std::string& path) {
   tflite::StderrReporter error_reporter;
   if (tflite::optimize::QuantizeModel(
           &builder, &model, tflite::TensorType_FLOAT32,
-          tflite::TensorType_FLOAT32, &error_reporter) != kTfLiteOk) {
+          tflite::TensorType_FLOAT32,
+          // TODO: Pass required activation type if needed
+          tflite::TensorType_INT8, &error_reporter) != kTfLiteOk) {
     return false;
   }
   return WriteFile(path, builder.GetBufferPointer(), builder.GetSize());
@@ -64,6 +64,7 @@ operator_property::OperatorProperty GetOperatorProperty(
 TfLiteStatus QuantizeBias(ModelT* model, const TensorT* input_tensor,
                           const TensorT* weight_tensor, TensorT* bias_tensor,
                           bool is_per_channel, int channel_dim_index,
+                          const TensorType& activations_type,
                           ErrorReporter* error_reporter) {
   if (bias_tensor->shape.size() != 1) {
     error_reporter->Report("Expected bias tensor shape to be 1.");
@@ -92,9 +93,15 @@ TfLiteStatus QuantizeBias(ModelT* model, const TensorT* input_tensor,
           weight_scales.size());
       return kTfLiteError;
     }
-    return utils::SymmetricPerChannelBiasQuantize(
-        model, bias_tensor, input_tensor->quantization->scale[0],
-        weight_scales.data(), channel_dim_size, error_reporter);
+    if (activations_type == tflite::TensorType_INT16) {
+      return utils::SymmetricPerChannelBiasQuantize<std::int64_t>(
+          model, bias_tensor, input_tensor->quantization->scale[0],
+          weight_scales.data(), channel_dim_size, error_reporter);
+    } else {
+      return utils::SymmetricPerChannelBiasQuantize<std::int32_t>(
+          model, bias_tensor, input_tensor->quantization->scale[0],
+          weight_scales.data(), channel_dim_size, error_reporter);
+    }
   } else {
     if (weight_scales.size() != 1) {
       error_reporter->Report(
@@ -102,40 +109,54 @@ TfLiteStatus QuantizeBias(ModelT* model, const TensorT* input_tensor,
           weight_scales.size());
       return kTfLiteError;
     }
-    return utils::SymmetricPerLayerBiasQuantize(
-        model, bias_tensor,
-        input_tensor->quantization->scale[0] * weight_scales[0],
-        error_reporter);
+    if (activations_type == tflite::TensorType_INT16) {
+      return utils::SymmetricPerLayerBiasQuantize<std::int64_t>(
+          model, bias_tensor,
+          input_tensor->quantization->scale[0] * weight_scales[0],
+          error_reporter);
+    } else {
+      return utils::SymmetricPerLayerBiasQuantize<std::int32_t>(
+          model, bias_tensor,
+          input_tensor->quantization->scale[0] * weight_scales[0],
+          error_reporter);
+    }
   }
   return kTfLiteError;
 }
 
 // True if the tensor type has to be modified.
 bool TensorTypeChangeRequired(const TensorT* tensor, const TensorType& type) {
-  // The quantized model is type INT8, so if the user provided type is INT8, we
-  // do not have to do any custom logic. Additionally, if the current tensor
-  // isn't INT8 quantized, the custom type doesn't apply.
-  return (type != TensorType_INT8 && tensor->type == TensorType_INT8 &&
-          !tensor->quantization->scale.empty());
+  // The quantized model is type INT8/INT16, so if the user provided type is
+  // INT8/INT16, we do not have to do any custom logic. Additionally, if the
+  // current tensor isn't INT8/INT16 quantized, the custom type doesn't apply.
+  bool int8check = type != TensorType_INT8 && tensor->type == TensorType_INT8 &&
+                   !tensor->quantization->scale.empty();
+  bool int16check = type != TensorType_INT16 &&
+                    tensor->type == TensorType_INT16 &&
+                    !tensor->quantization->scale.empty();
+  return (int8check || int16check);
 }
 
 // Sets the input type, adding a Leading Op node at the start of the model if
 // necessary.
 // Returns the new input tensor index.
 int32_t SetInputType(ModelT* model, SubGraphT* subgraph,
-                     const int32_t tensor_idx, const TensorType& input_type) {
+                     const int32_t tensor_idx, const TensorType& input_type,
+                     const TensorType& activations_type) {
   TensorT* tensor = subgraph->tensors[tensor_idx].get();
   if (!TensorTypeChangeRequired(tensor, input_type)) {
     return -1;
   }
   if (input_type == TensorType_FLOAT32 || input_type == TensorType_UINT8) {
+    std::string type_string =
+        activations_type == TensorType_INT16 ? "int16" : "int8";
     // Create a new tensor to be the input of the leading Op.
     std::unique_ptr<TensorT> leading_op_input;
     if (input_type == TensorType_FLOAT32) {
       // Add tensor for quantize operator. Scales and zero points are not
       // needed.
       const string leading_op_name = tensor->name;
-      const string new_name_original_input = tensor->name + "_int8";
+      const string new_name_original_input = tensor->name + "_" + type_string;
       tensor->name = new_name_original_input;
       utils::MakeTensor(leading_op_name, tensor->shape, input_type,
                         &leading_op_input);
@@ -150,7 +171,7 @@ int32_t SetInputType(ModelT* model, SubGraphT* subgraph,
       TFLITE_DCHECK_GE(zero_point, -128);
       TFLITE_DCHECK_LE(zero_point, 127);
       const string leading_op_name = tensor->name;
-      const string new_name_original_input = tensor->name + "_int8";
+      const string new_name_original_input = tensor->name + "_" + type_string;
       tensor->name = new_name_original_input;
       utils::MakeTensorWithQuantParam(leading_op_name, tensor->shape,
                                       input_type, scale, zero_point + 128,
@@ -177,17 +198,20 @@ int32_t SetInputType(ModelT* model, SubGraphT* subgraph,
 // necessary.
 // Returns the new output tensor index.
 int32_t SetOutputType(ModelT* model, SubGraphT* subgraph,
-                      const int32_t tensor_idx, const TensorType& output_type) {
+                      const int32_t tensor_idx, const TensorType& output_type,
+                      const TensorType& activations_type) {
   TensorT* tensor = subgraph->tensors[tensor_idx].get();
   if (!TensorTypeChangeRequired(tensor, output_type)) {
     return -1;
   }
   if (output_type == TensorType_FLOAT32 || output_type == TensorType_UINT8) {
+    std::string type_string =
+        activations_type == TensorType_INT16 ? "int16" : "int8";
     // Create a new tensor to be the output of the tailing op.
     std::unique_ptr<TensorT> tailing_op_output;
     if (output_type == TensorType_FLOAT32) {
       const string tailing_op_name = tensor->name;
-      const string new_name_original_output = tensor->name + "_int8";
+      const string new_name_original_output = tensor->name + "_" + type_string;
       tensor->name = new_name_original_output;
       utils::MakeTensor(tailing_op_name, tensor->shape, output_type,
                         &tailing_op_output);
@@ -202,7 +226,7 @@ int32_t SetOutputType(ModelT* model, SubGraphT* subgraph,
       TFLITE_DCHECK_GE(zero_point, -128);
       TFLITE_DCHECK_LE(zero_point, 127);
       const string tailing_op_name = tensor->name;
-      const string new_name_original_output = tensor->name + "_int8";
+      const string new_name_original_output = tensor->name + "_" + type_string;
       tensor->name = new_name_original_output;
       utils::MakeTensorWithQuantParam(tailing_op_name, tensor->shape,
                                       output_type, scale, zero_point + 128,
@@ -238,6 +262,7 @@ int32_t SetOutputType(ModelT* model, SubGraphT* subgraph,
 // uint8, can be thought as "requant").
 TfLiteStatus SetInputAndOutputTypes(ModelT* model, const TensorType& input_type,
                                     const TensorType& output_type,
+                                    const TensorType& activations_type,
                                     ErrorReporter* error_reporter) {
   for (int subgraph_idx = 0; subgraph_idx < model->subgraphs.size();
        subgraph_idx++) {
@@ -253,8 +278,8 @@ TfLiteStatus SetInputAndOutputTypes(ModelT* model, const TensorType& input_type,
                                EnumNameTensorType(tensor->type));
         return kTfLiteError;
       }
-      const int32_t input_idx =
-          SetInputType(model, subgraph, subgraph->inputs[i], input_type);
+      const int32_t input_idx = SetInputType(
+          model, subgraph, subgraph->inputs[i], input_type, activations_type);
       if (input_idx < 0) {
         continue;
       }
@@ -270,8 +295,8 @@ TfLiteStatus SetInputAndOutputTypes(ModelT* model, const TensorType& input_type,
                                EnumNameTensorType(tensor->type));
         return kTfLiteError;
       }
-      const int32_t output_idx =
-          SetOutputType(model, subgraph, subgraph->outputs[i], output_type);
+      const int32_t output_idx = SetOutputType(
+          model, subgraph, subgraph->outputs[i], output_type, activations_type);
       if (output_idx < 0) {
        continue;
       }
|
||||
// The other ones with constraints are handled in QuantizeWeightsAndInput.
|
||||
TfLiteStatus ApplyConstraints(ModelT* model,
|
||||
const std::unordered_set<string>& operator_names,
|
||||
TensorType activations_type,
|
||||
ErrorReporter* error_reporter) {
|
||||
for (int subgraph_idx = 0; subgraph_idx < model->subgraphs.size();
|
||||
subgraph_idx++) {
|
||||
@ -332,7 +358,7 @@ TfLiteStatus ApplyConstraints(ModelT* model,
|
||||
std::unique_ptr<TensorT> additional_tensor;
|
||||
const string requant_tensor_name = input_tensor->name + "_requantized";
|
||||
utils::MakeTensorWithQuantParam(
|
||||
requant_tensor_name, input_tensor->shape, TensorType_INT8,
|
||||
requant_tensor_name, input_tensor->shape, activations_type,
|
||||
output_scale, output_zp, &additional_tensor);
|
||||
const int32_t additional_tensor_idx = subgraph->tensors.size();
|
||||
subgraph->tensors.push_back(std::move(additional_tensor));
|
||||
@ -382,7 +408,8 @@ std::vector<std::pair<int, operator_property::TensorProperty>> GetOutputs(
|
||||
|
||||
bool ShouldRestrictSameInputOutputScale(
|
||||
operator_property::OperatorProperty property) {
|
||||
// Ops with multiple inputs (i.e. concat) gets restricted in ApplyConstraints.
|
||||
// Ops with multiple inputs (i.e. concat, max and min) gets restricted in
|
||||
// ApplyConstraints.
|
||||
return (!property.arbitrary_inputs &&
|
||||
property.restrict_same_input_output_scale);
|
||||
}
|
||||
@@ -401,7 +428,7 @@ TfLiteStatus QuantizeOpInput(
     ModelT* model, int32_t subgraph_idx, size_t* op_idx,
     operator_property::OperatorProperty property,
     const std::pair<int32_t, operator_property::TensorProperty>& input,
-    ErrorReporter* error_reporter) {
+    const TensorType& activations_type, ErrorReporter* error_reporter) {
   int32_t input_idx = input.first;
   operator_property::TensorProperty tensor_property = input.second;
   SubGraphT* subgraph = model->subgraphs.at(subgraph_idx).get();
@@ -429,7 +456,9 @@ TfLiteStatus QuantizeOpInput(
     if (utils::HasBuffer(model, subgraph, tensor_idx)) {
       // TODO(suharshs): Look at consumers, throw error if one consumer is
       // per-channel and one per-layer.
-      if (tensor_property.number_of_bits == 8) {
+      bool quantize_const_input = property.quantize_input_as_activations &&
+                                  activations_type == TensorType_INT16;
+      if (tensor_property.number_of_bits == 8 && !quantize_const_input) {
         if (tensor_property.use_derived_scale) {
           // Currently 8bit tensors in input do not accept derived scale.
           return kTfLiteError;
@@ -444,7 +473,7 @@ TfLiteStatus QuantizeOpInput(
                                *op_idx);
           return kTfLiteError;
         }
-      } else if (tensor_property.number_of_bits == 16) {
+      } else if (tensor_property.number_of_bits == 16 || quantize_const_input) {
         if (tensor_property.use_derived_scale) {
           // Currently 16bit tensors in input do not accept derived scale.
           return kTfLiteError;
@@ -476,8 +505,8 @@ TfLiteStatus QuantizeOpInput(
             tensor_property.derived_scale.input_tensors,
             tensor_property.derived_scale.intermediate_tensors,
             tensor_property.derived_scale.factors);
-        return utils::SymmetricPerLayerBiasQuantize(model, tensor, scale,
-                                                    error_reporter);
+        return utils::SymmetricPerLayerBiasQuantize<std::int32_t>(
+            model, tensor, scale, error_reporter);
 
       } else if (tensor_property.number_of_bits == 10) {
         // When the number of bits is 10 (instead of 16), quantize the tensor to
@@ -514,7 +543,8 @@ TfLiteStatus QuantizeOpInput(
           // Currently 8bit tensors in input do not accept derived scale.
           return kTfLiteError;
         }
-        utils::QuantizeActivation(tensor);
+        TF_LITE_ENSURE_STATUS(utils::QuantizeActivation(
+            tensor, activations_type, error_reporter));
       } else if (tensor_property.number_of_bits == 16) {
         TensorT* tensor = subgraph->tensors[tensor_idx].get();
         float quantized_range = 32767.0;
@@ -532,13 +562,16 @@ TfLiteStatus QuantizeOpInput(
     } else {
       // If the tensor is not a model input, we need to add a Quantize
       // operation since the preceding op may require a float output.
+      std::string type_string =
+          activations_type == TensorType_INT16 ? "int16" : "int8";
       std::unique_ptr<TensorT> op_output;
-      utils::MakeTensor(tensor->name + "_int8", tensor->shape,
-                        TensorType_INT8, &op_output);
+      utils::MakeTensor(tensor->name + "_" + type_string, tensor->shape,
+                        activations_type, &op_output);
       op_output->quantization = absl::make_unique<QuantizationParametersT>();
       op_output->quantization->min.push_back(tensor->quantization->min[0]);
       op_output->quantization->max.push_back(tensor->quantization->max[0]);
-      utils::QuantizeActivation(op_output.get());
+      TF_LITE_ENSURE_STATUS(utils::QuantizeActivation(
+          op_output.get(), activations_type, error_reporter));
       const int32_t quant_op_output_idx = subgraph->tensors.size();
       subgraph->tensors.push_back(std::move(op_output));
       std::unique_ptr<OperatorT> quant_op;
@@ -580,7 +613,7 @@ TfLiteStatus QuantizeOpOutput(
     ModelT* model, int32_t subgraph_idx, int32_t op_idx,
     operator_property::OperatorProperty property,
     const std::pair<int32_t, operator_property::TensorProperty>& output,
-    ErrorReporter* error_reporter) {
+    TensorType activations_type, ErrorReporter* error_reporter) {
   int32_t output_idx = output.first;
   operator_property::TensorProperty tensor_property = output.second;
   // If the operator is not quantizable, we don't need to do anything for the
@@ -644,18 +677,22 @@ TfLiteStatus QuantizeOpOutput(
       const float max = input_tensor->quantization->max[0];
       output_tensor->quantization->max = {max};
     }
-    output_tensor->type = TensorType_INT8;
+    output_tensor->type = activations_type;
   } else if (tensor_property.restriction) {
-    const auto scale_and_zp = tensor_property.restricted_value;
+    const auto scale_and_zp = activations_type == TensorType_INT16
+                                  ? tensor_property.restricted_value_int16
+                                  : tensor_property.restricted_value_int8;
+
     // Apply to output.
     output_tensor->quantization = absl::make_unique<QuantizationParametersT>();
     output_tensor->quantization->scale.push_back(scale_and_zp.first);
     output_tensor->quantization->zero_point.push_back(scale_and_zp.second);
-    output_tensor->type = TensorType_INT8;
+    output_tensor->type = activations_type;
   } else {
     // Process regular output that doesn't have any restrictions.
     if (utils::HasMinMax(output_tensor)) {
-      utils::QuantizeActivation(output_tensor);
+      utils::QuantizeActivation(output_tensor, activations_type,
+                                error_reporter);
     } else {
       error_reporter->Report(
           "Unable to find min/max value for output %d in %s in "
@@ -668,6 +705,7 @@ TfLiteStatus QuantizeOpOutput(
 }
 
 TfLiteStatus QuantizeIntemediateTensors(ModelT* model,
+                                        TensorType activations_type,
                                         ErrorReporter* error_reporter) {
   for (size_t subgraph_idx = 0; subgraph_idx < model->subgraphs.size();
        subgraph_idx++) {
@@ -691,7 +729,8 @@ TfLiteStatus QuantizeIntemediateTensors(ModelT* model,
             input.second.symmetric == false) {
           TensorT* tensor = subgraph->tensors[index_global].get();
           if (utils::HasMinMax(tensor)) {
-            utils::QuantizeActivation(tensor);
+            utils::QuantizeActivation(tensor, activations_type,
+                                      error_reporter);
           } else {
             error_reporter->Report(
                 "Unable to find min/max value for output %d in %s in "
@@ -793,7 +832,7 @@ TfLiteStatus QuantizeSharedRange(ModelT* model, ErrorReporter* error_reporter) {
 TfLiteStatus QuantizeWeightsInputOutput(
     ModelT* model, bool allow_float,
     const std::unordered_set<string>& operator_names,
-    ErrorReporter* error_reporter) {
+    const TensorType& activations_type, ErrorReporter* error_reporter) {
   for (size_t subgraph_idx = 0; subgraph_idx < model->subgraphs.size();
        subgraph_idx++) {
     SubGraphT* subgraph = model->subgraphs.at(subgraph_idx).get();
@@ -815,14 +854,16 @@ TfLiteStatus QuantizeWeightsInputOutput(
       for (const std::pair<int, operator_property::TensorProperty>& input :
            GetInputs(op, property)) {
         TF_LITE_ENSURE_STATUS(QuantizeOpInput(model, subgraph_idx, &op_idx,
-                                              property, input, error_reporter));
+                                              property, input, activations_type,
+                                              error_reporter));
       }
 
       // Quantize operator outputs.
       for (const std::pair<int, operator_property::TensorProperty>& output :
            GetOutputs(op, property)) {
-        TF_LITE_ENSURE_STATUS(QuantizeOpOutput(
-            model, subgraph_idx, op_idx, property, output, error_reporter));
+        TF_LITE_ENSURE_STATUS(
+            QuantizeOpOutput(model, subgraph_idx, op_idx, property, output,
+                             activations_type, error_reporter));
       }
     }
   }
@@ -832,6 +873,7 @@ TfLiteStatus QuantizeWeightsInputOutput(
 // Quantize bias.
 TfLiteStatus QuantizeBiases(ModelT* model,
                             const std::unordered_set<string>& operator_names,
+                            const TensorType& activations_type,
                             ErrorReporter* error_reporter) {
   for (size_t subgraph_idx = 0; subgraph_idx < model->subgraphs.size();
        subgraph_idx++) {
@@ -877,10 +919,10 @@ TfLiteStatus QuantizeBiases(ModelT* model,
               subgraph->tensors[op->inputs[property.inputs[1].first]].get();
           operator_property::TensorProperty weight_property =
               property.inputs[1].second;
-          TF_LITE_ENSURE_STATUS(
-              QuantizeBias(model, input_tensor, weight_tensor, bias_tensor,
-                           weight_property.per_axis,
-                           weight_property.per_axis_index, error_reporter));
+          TF_LITE_ENSURE_STATUS(QuantizeBias(
+              model, input_tensor, weight_tensor, bias_tensor,
+              weight_property.per_axis, weight_property.per_axis_index,
+              activations_type, error_reporter));
         }
       }
     }
@@ -1000,7 +1042,7 @@ TfLiteStatus FillQuantizationParams(
 // Check compatibility of activation, weight and bias scales. Adjust if needed.
 TfLiteStatus EnsureBiasScaleCompatibility(
     ModelT* model, const std::unordered_set<string>& operator_names,
-    ErrorReporter* error_reporter) {
+    TensorType activations_type, ErrorReporter* error_reporter) {
   for (size_t subgraph_idx = 0; subgraph_idx < model->subgraphs.size();
        subgraph_idx++) {
     SubGraphT* subgraph = model->subgraphs.at(subgraph_idx).get();
@@ -1049,11 +1091,9 @@ TfLiteStatus EnsureBiasScaleCompatibility(
 
         // Get input scale for assymmetric quantization.
         QuantizationParametersT temp_quant_params = QuantizationParametersT();
-        utils::GetAsymmetricQuantizationParams(
-            input_tensor->quantization->min[0],
-            input_tensor->quantization->max[0],
-            std::numeric_limits<int8_t>::min(),
-            std::numeric_limits<int8_t>::max(), &temp_quant_params);
+        TF_LITE_ENSURE_STATUS(
+            utils::GetQuantizationParams(input_tensor, activations_type,
+                                         &temp_quant_params, error_reporter));
         if (temp_quant_params.scale.size() != 1) {
           error_reporter->Report("Unexpected input quantization scale size.");
           return kTfLiteError;
@@ -1132,21 +1172,24 @@ TfLiteStatus QuantizeModel(flatbuffers::FlatBufferBuilder* builder,
                            ModelT* model, const TensorType& input_type,
                            const TensorType& output_type, bool allow_float,
                            const std::unordered_set<string>& operator_names,
+                           const TensorType& activations_type,
                            ErrorReporter* error_reporter) {
   TF_LITE_ENSURE_STATUS(
       FillQuantizationParams(model, operator_names, error_reporter));
-  TF_LITE_ENSURE_STATUS(
-      EnsureBiasScaleCompatibility(model, operator_names, error_reporter));
-  TF_LITE_ENSURE_STATUS(QuantizeIntemediateTensors(model, error_reporter));
+  TF_LITE_ENSURE_STATUS(EnsureBiasScaleCompatibility(
+      model, operator_names, activations_type, error_reporter));
+  TF_LITE_ENSURE_STATUS(
+      QuantizeIntemediateTensors(model, activations_type, error_reporter));
   TF_LITE_ENSURE_STATUS(QuantizeSharedRange(model, error_reporter));
   TF_LITE_ENSURE_STATUS(QuantizeWeightsInputOutput(
-      model, allow_float, operator_names, error_reporter));
-  TF_LITE_ENSURE_STATUS(
-      ApplyConstraints(model, operator_names, error_reporter));
-  TF_LITE_ENSURE_STATUS(QuantizeBiases(model, operator_names, error_reporter));
+      model, allow_float, operator_names, activations_type, error_reporter));
+  TF_LITE_ENSURE_STATUS(ApplyConstraints(model, operator_names,
+                                         activations_type, error_reporter));
+  TF_LITE_ENSURE_STATUS(
+      QuantizeBiases(model, operator_names, activations_type, error_reporter));
   utils::SetOperatorCodeVersion(model);
-  TF_LITE_ENSURE_STATUS(
-      SetInputAndOutputTypes(model, input_type, output_type, error_reporter));
+  TF_LITE_ENSURE_STATUS(SetInputAndOutputTypes(
+      model, input_type, output_type, activations_type, error_reporter));
 
   flatbuffers::Offset<Model> output_model_location =
       Model::Pack(*builder, model);
@@ -1158,23 +1201,27 @@ TfLiteStatus QuantizeModel(flatbuffers::FlatBufferBuilder* builder,
 TfLiteStatus QuantizeModel(flatbuffers::FlatBufferBuilder* builder,
                            ModelT* model, const TensorType& input_type,
                            const TensorType& output_type, bool allow_float,
+                           const TensorType& activations_type,
                            ErrorReporter* error_reporter) {
   return QuantizeModel(builder, model, input_type, output_type, allow_float,
-                       GetAllOperatorOutputs(model), error_reporter);
+                       GetAllOperatorOutputs(model), activations_type,
+                       error_reporter);
 }
 
 TfLiteStatus QuantizeModel(flatbuffers::FlatBufferBuilder* builder,
                            ModelT* model, const TensorType& input_type,
                            const TensorType& output_type,
+                           const TensorType& activations_type,
                            ErrorReporter* error_reporter) {
   return QuantizeModel(builder, model, input_type, output_type,
-                       /*allow_float=*/false, error_reporter);
+                       /*allow_float=*/false, activations_type, error_reporter);
 }
 
 TfLiteStatus QuantizeModel(flatbuffers::FlatBufferBuilder* builder,
-                           ModelT* model, ErrorReporter* error_reporter) {
+                           ModelT* model, const TensorType& activations_type,
+                           ErrorReporter* error_reporter) {
   return QuantizeModel(builder, model, TensorType_FLOAT32, TensorType_FLOAT32,
-                       /*allow_float=*/false, error_reporter);
+                       /*allow_float=*/false, activations_type, error_reporter);
 }
 
 }  // namespace optimize
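For orientation, the pass order wired up in QuantizeModel above, with the passes that now take activations_type marked (a reference list, not an API):

```python
QUANTIZE_MODEL_PASSES = [
    "FillQuantizationParams",
    "EnsureBiasScaleCompatibility",  # now takes activations_type
    "QuantizeIntemediateTensors",    # now takes activations_type
    "QuantizeSharedRange",
    "QuantizeWeightsInputOutput",    # now takes activations_type
    "ApplyConstraints",              # now takes activations_type
    "QuantizeBiases",                # now takes activations_type
    "SetOperatorCodeVersion",
    "SetInputAndOutputTypes",        # now takes activations_type
]
```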
@@ -35,7 +35,9 @@ namespace optimize {
 //
 // Note: This is a private API, subject to change.
 TfLiteStatus QuantizeModel(flatbuffers::FlatBufferBuilder* builder,
-                           ModelT* input_model, ErrorReporter* error_reporter);
+                           ModelT* input_model,
+                           const TensorType& activations_type,
+                           ErrorReporter* error_reporter);
 
 // Same as above, but the types of quantized inputs and outputs are
 // configurable.
@@ -44,6 +46,7 @@ TfLiteStatus QuantizeModel(flatbuffers::FlatBufferBuilder* builder,
 TfLiteStatus QuantizeModel(flatbuffers::FlatBufferBuilder* builder,
                            ModelT* input_model, const TensorType& input_type,
                            const TensorType& output_type,
+                           const TensorType& activations_type,
                            ErrorReporter* error_reporter);
 
 // Same as above, but can enable allowing float intermediate operations for ops
@@ -53,6 +56,7 @@ TfLiteStatus QuantizeModel(flatbuffers::FlatBufferBuilder* builder,
 TfLiteStatus QuantizeModel(flatbuffers::FlatBufferBuilder* builder,
                            ModelT* input_model, const TensorType& input_type,
                            const TensorType& output_type, bool allow_float,
+                           const TensorType& activations_type,
                            ErrorReporter* error_reporter);
 
 // Same as above, but enables only quantizing a whitelist of operations,
@@ -63,6 +67,7 @@ TfLiteStatus QuantizeModel(flatbuffers::FlatBufferBuilder* builder,
                            ModelT* input_model, const TensorType& input_type,
                            const TensorType& output_type, bool allow_float,
                            const std::unordered_set<string>& operator_names,
+                           const TensorType& activations_type,
                            ErrorReporter* error_reporter);
 
 }  // namespace optimize
@ -80,28 +80,35 @@ class QuantizeModelTest : public testing::Test {
  internal::FailOnErrorReporter error_reporter_;
};

class QuantizeConvModelTest : public QuantizeModelTest {
class QuantizeConvModelTest : public QuantizeModelTest,
                              public testing::WithParamInterface<TensorType> {
 protected:
  QuantizeConvModelTest() {
    tensor_type_ = GetParam();
    input_model_ = ReadModel(internal::kConvModelWith0Plus10Weights);
    readonly_model_ = input_model_->GetModel();
    readonly_model_->UnPackTo(&model_);
  }
  TensorType tensor_type_;
};

TEST_F(QuantizeConvModelTest, QuantizationSucceeds) {
  auto status = QuantizeModel(&builder_, &model_, TensorType_INT8,
                              TensorType_INT8, &error_reporter_);
INSTANTIATE_TEST_SUITE_P(QuantizeConvModelTestInst, QuantizeConvModelTest,
                         testing::ValuesIn({TensorType_INT8,
                                            TensorType_INT16}));

TEST_P(QuantizeConvModelTest, QuantizationSucceeds) {
  auto status = QuantizeModel(&builder_, &model_, tensor_type_, tensor_type_,
                              tensor_type_, &error_reporter_);
  EXPECT_EQ(status, kTfLiteOk);
  const uint8_t* buffer = builder_.GetBufferPointer();
  const Model* output_model = GetModel(buffer);
  ASSERT_TRUE(output_model);
}

TEST_F(QuantizeConvModelTest, SkipUnspecifiedLayer) {
TEST_P(QuantizeConvModelTest, SkipUnspecifiedLayer) {
  auto status =
      QuantizeModel(&builder_, &model_, TensorType_FLOAT32, TensorType_FLOAT32,
                    /*allow_float=*/true, {}, &error_reporter_);
                    /*allow_float=*/true, {}, tensor_type_, &error_reporter_);
  EXPECT_EQ(status, kTfLiteOk);
  ASSERT_EQ(model_.subgraphs.size(), readonly_model_->subgraphs()->size());
  // The resulting model should be the same.
@ -123,9 +130,9 @@ TEST_F(QuantizeConvModelTest, SkipUnspecifiedLayer) {
  }
}

TEST_F(QuantizeConvModelTest, TensorShapesAndStructureIsUnchanged) {
  auto status = QuantizeModel(&builder_, &model_, TensorType_INT8,
                              TensorType_INT8, &error_reporter_);
TEST_P(QuantizeConvModelTest, TensorShapesAndStructureIsUnchanged) {
  auto status = QuantizeModel(&builder_, &model_, tensor_type_, tensor_type_,
                              tensor_type_, &error_reporter_);
  EXPECT_EQ(status, kTfLiteOk);
  ASSERT_EQ(model_.subgraphs.size(), readonly_model_->subgraphs()->size());
  for (size_t subgraph_idx = 0; subgraph_idx < model_.subgraphs.size();
@ -148,9 +155,9 @@ TEST_F(QuantizeConvModelTest, TensorShapesAndStructureIsUnchanged) {
  EXPECT_EQ(model_.operator_codes[0]->version, 3);
}

TEST_F(QuantizeConvModelTest, OperatorsAreUnchanged) {
  auto status = QuantizeModel(&builder_, &model_, TensorType_INT8,
                              TensorType_INT8, &error_reporter_);
TEST_P(QuantizeConvModelTest, OperatorsAreUnchanged) {
  auto status = QuantizeModel(&builder_, &model_, tensor_type_, tensor_type_,
                              tensor_type_, &error_reporter_);
  EXPECT_EQ(status, kTfLiteOk);
  ASSERT_EQ(model_.operator_codes.size(),
            readonly_model_->operator_codes()->size());
@ -182,20 +189,28 @@ TEST_F(QuantizeConvModelTest, OperatorsAreUnchanged) {
  }
}

TEST_F(QuantizeConvModelTest, GraphIsFullyQuantized) {
  auto status = QuantizeModel(&builder_, &model_, TensorType_INT8,
                              TensorType_INT8, &error_reporter_);
TEST_P(QuantizeConvModelTest, GraphIsFullyQuantized) {
  auto status = QuantizeModel(&builder_, &model_, tensor_type_, tensor_type_,
                              tensor_type_, &error_reporter_);
  EXPECT_EQ(status, kTfLiteOk);
  for (const auto& subgraph : model_.subgraphs) {
    for (const auto& tensor : subgraph->tensors) {
      EXPECT_TRUE(tensor->type == TensorType_INT32 ||
                  tensor->type == TensorType_INT8);
      if (tensor_type_ == TensorType_INT8) {
        EXPECT_TRUE(tensor->type == TensorType_INT32 ||
                    tensor->type == TensorType_INT8);
      } else if (tensor_type_ == TensorType_INT16) {
        EXPECT_TRUE(tensor->type == TensorType_INT64 ||  // bias
                    tensor->type == TensorType_INT8 ||   // weights
                    tensor->type == TensorType_INT16);   // activations
      }
    }
  }
}

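The type assignments that GraphIsFullyQuantized checks are the core of the 16x8 scheme; summarized per tensor role (the zero-point behaviour is exercised by the concat test further down):

role          8-bit mode   16x8 mode
activations   int8         int16, symmetric (zero_point = 0)
weights       int8         int8, symmetric
bias          int32        int64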
TEST_F(QuantizeConvModelTest, FloatInputAndOutput) {
  auto status = QuantizeModel(&builder_, &model_, &error_reporter_);
TEST_P(QuantizeConvModelTest, FloatInputAndOutput) {
  auto status =
      QuantizeModel(&builder_, &model_, TensorType_FLOAT32, TensorType_FLOAT32,
                    tensor_type_, &error_reporter_);
  EXPECT_EQ(status, kTfLiteOk);

  for (int32_t subgraph_idx = 0; subgraph_idx < model_.subgraphs.size();
@ -234,22 +249,33 @@ TEST_F(QuantizeConvModelTest, FloatInputAndOutput) {
    EXPECT_EQ(subgraph->tensors[output_idx]->type, TensorType_FLOAT32);
    EXPECT_EQ(subgraph->tensors[output_idx]->name, "output");
    // The original input and output has been renamed.
    EXPECT_EQ(subgraph->tensors[quant_op->outputs[0]]->name, "input_int8");
    EXPECT_EQ(subgraph->tensors[dequant_op->inputs[0]]->name, "output_int8");
    std::string control_suffix =
        (tensor_type_ == TensorType_INT16) ? "int16" : "int8";
    EXPECT_EQ(subgraph->tensors[quant_op->outputs[0]]->name,
              "input_" + control_suffix);
    EXPECT_EQ(subgraph->tensors[dequant_op->inputs[0]]->name,
              "output_" + control_suffix);
    for (int tensor_idx = 0; tensor_idx < subgraph->tensors.size();
         ++tensor_idx) {
      const auto& tensor = subgraph->tensors[tensor_idx];
      if (input_idx != tensor_idx && output_idx != tensor_idx) {
        EXPECT_TRUE(tensor->type == TensorType_INT32 ||
                    tensor->type == TensorType_INT8);
        if (tensor_type_ == TensorType_INT8) {
          EXPECT_TRUE(tensor->type == TensorType_INT32 ||
                      tensor->type == TensorType_INT8);
        } else if (tensor_type_ == TensorType_INT16) {
          EXPECT_TRUE(tensor->type == TensorType_INT64 ||  // bias
                      tensor->type == TensorType_INT8 ||   // weights
                      tensor->type == TensorType_INT16);   // activations
        }
      }
    }
  }
}

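Schematically, what FloatInputAndOutput verifies: with float32 inputs and outputs requested, the endpoint tensors stay float and keep their original names, while the quantizer inserts boundary ops whose inner tensors take the activation-type suffix:

input (float32) -> QUANTIZE -> input_int8 or input_int16 -> quantized graph -> output_int8 or output_int16 -> DEQUANTIZE -> output (float32)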
TEST_F(QuantizeConvModelTest, Uint8InputAndOutput) {
  auto status = QuantizeModel(&builder_, &model_, TensorType_UINT8,
                              TensorType_UINT8, &error_reporter_);
TEST_P(QuantizeConvModelTest, Uint8InputAndOutput) {
  auto status =
      QuantizeModel(&builder_, &model_, TensorType_UINT8, TensorType_UINT8,
                    TensorType_INT8, &error_reporter_);
  EXPECT_EQ(status, kTfLiteOk);

  for (int32_t subgraph_idx = 0; subgraph_idx < model_.subgraphs.size();
@ -326,21 +352,25 @@ class QuantizeConvNoBiasModelTest : public QuantizeModelTest {
};

TEST_F(QuantizeConvNoBiasModelTest, QuantizationSucceeds) {
  auto status = QuantizeModel(&builder_, &model_, TensorType_INT8,
                              TensorType_INT8, &error_reporter_);
  auto status =
      QuantizeModel(&builder_, &model_, TensorType_INT8, TensorType_INT8,
                    TensorType_INT8, &error_reporter_);
  EXPECT_EQ(status, kTfLiteOk);
  const uint8_t* buffer = builder_.GetBufferPointer();
  const Model* output_model = GetModel(buffer);
  ASSERT_TRUE(output_model);
}

class QuantizeConcatModelTest : public QuantizeModelTest {
class QuantizeConcatModelTest : public QuantizeModelTest,
                                public testing::WithParamInterface<TensorType> {
 protected:
  QuantizeConcatModelTest() {
    input_model_ = ReadModel(internal::kFloatConcatMax5Max10Max10);
    readonly_model_ = input_model_->GetModel();
    readonly_model_->UnPackTo(&model_);
  }

  TensorType tensor_type_;
};

// There are two inputs for concat, "input0" and "input1". "input0" has [0, 5]
@ -352,9 +382,9 @@ class QuantizeConcatModelTest : public QuantizeModelTest {
// input0 -> requant -> input0_requant \
//                                      concat - output
// input1 /
TEST_F(QuantizeConcatModelTest, AddRequantBeforeConcat) {
  auto status = QuantizeModel(&builder_, &model_, TensorType_INT8,
                              TensorType_INT8, &error_reporter_);
TEST_P(QuantizeConcatModelTest, AddRequantBeforeConcat) {
  auto status = QuantizeModel(&builder_, &model_, tensor_type_, tensor_type_,
                              tensor_type_, &error_reporter_);
  EXPECT_EQ(status, kTfLiteOk);

  // There is only one subgraph.
@ -373,32 +403,51 @@ TEST_F(QuantizeConcatModelTest, AddRequantBeforeConcat) {
  EXPECT_EQ(model_.operator_codes[concat->opcode_index]->builtin_code,
            BuiltinOperator_CONCATENATION);

  auto zero_point_control = tensor_type_ == TensorType_INT8 ? -128 : 0;
  /*
    input0_scale_control
      INT8: (5-0) / (2^8 - 1)
      INT16: (5-0) / (2^16 / 2 - 1)
    input1_scale
      INT8: (10-0) / (2^8 - 1)
      INT16: (10-0) / (2^16 / 2 - 1)
  */
  auto input0_scale_control =
      tensor_type_ == TensorType_INT8 ? 0.019607844 : 0.00015259254;
  auto input1_scale =
      tensor_type_ == TensorType_INT8 ? 0.039215688 : 0.00030518509;

  // There should be 4 tensors: input0, input1, input0_requantized, output.
  EXPECT_EQ(subgraph->tensors.size(), 4);
  EXPECT_EQ(subgraph->tensors[0]->type, TensorType_INT8);
  EXPECT_EQ(subgraph->tensors[0]->type, tensor_type_);
  EXPECT_EQ(subgraph->tensors[0]->name, "input0");
  EXPECT_EQ(subgraph->tensors[0]->quantization->scale.size(), 1);
  EXPECT_EQ(subgraph->tensors[0]->quantization->zero_point.size(), 1);
  EXPECT_FLOAT_EQ(subgraph->tensors[0]->quantization->scale[0], 0.019607844);
  EXPECT_FLOAT_EQ(subgraph->tensors[0]->quantization->zero_point[0], -128);
  EXPECT_EQ(subgraph->tensors[1]->type, TensorType_INT8);
  EXPECT_FLOAT_EQ(subgraph->tensors[0]->quantization->scale[0],
                  input0_scale_control);
  EXPECT_FLOAT_EQ(subgraph->tensors[0]->quantization->zero_point[0],
                  zero_point_control);
  EXPECT_EQ(subgraph->tensors[1]->type, tensor_type_);
  EXPECT_EQ(subgraph->tensors[1]->name, "input1");
  EXPECT_EQ(subgraph->tensors[1]->quantization->scale.size(), 1);
  EXPECT_EQ(subgraph->tensors[1]->quantization->zero_point.size(), 1);
  EXPECT_FLOAT_EQ(subgraph->tensors[1]->quantization->scale[0], 0.039215688);
  EXPECT_FLOAT_EQ(subgraph->tensors[1]->quantization->zero_point[0], -128);
  EXPECT_EQ(subgraph->tensors[2]->type, TensorType_INT8);
  EXPECT_FLOAT_EQ(subgraph->tensors[1]->quantization->scale[0], input1_scale);
  EXPECT_FLOAT_EQ(subgraph->tensors[1]->quantization->zero_point[0],
                  zero_point_control);
  EXPECT_EQ(subgraph->tensors[2]->type, tensor_type_);
  EXPECT_EQ(subgraph->tensors[2]->name, "output");
  EXPECT_EQ(subgraph->tensors[2]->quantization->scale.size(), 1);
  EXPECT_EQ(subgraph->tensors[2]->quantization->zero_point.size(), 1);
  EXPECT_FLOAT_EQ(subgraph->tensors[2]->quantization->scale[0], 0.039215688);
  EXPECT_FLOAT_EQ(subgraph->tensors[2]->quantization->zero_point[0], -128);
  EXPECT_EQ(subgraph->tensors[3]->type, TensorType_INT8);
  EXPECT_FLOAT_EQ(subgraph->tensors[2]->quantization->scale[0], input1_scale);
  EXPECT_FLOAT_EQ(subgraph->tensors[2]->quantization->zero_point[0],
                  zero_point_control);
  EXPECT_EQ(subgraph->tensors[3]->type, tensor_type_);
  EXPECT_EQ(subgraph->tensors[3]->name, "input0_requantized");
  EXPECT_EQ(subgraph->tensors[3]->quantization->scale.size(), 1);
  EXPECT_EQ(subgraph->tensors[3]->quantization->zero_point.size(), 1);
  EXPECT_FLOAT_EQ(subgraph->tensors[3]->quantization->scale[0], 0.039215688);
  EXPECT_FLOAT_EQ(subgraph->tensors[3]->quantization->zero_point[0], -128);
  EXPECT_FLOAT_EQ(subgraph->tensors[3]->quantization->scale[0], input1_scale);
  EXPECT_FLOAT_EQ(subgraph->tensors[3]->quantization->zero_point[0],
                  zero_point_control);

  // The connection should be what is described in the comment.
  EXPECT_EQ(requant->inputs.size(), 1);
@ -419,7 +468,9 @@ TEST_F(QuantizeConcatModelTest, AddRequantBeforeConcat) {
  EXPECT_EQ(model_.operator_codes[1]->builtin_code, BuiltinOperator_QUANTIZE);
  EXPECT_EQ(model_.operator_codes[1]->version, 2);
}

INSTANTIATE_TEST_SUITE_P(QuantizeConcatModelInst, QuantizeConcatModelTest,
                         testing::ValuesIn({TensorType_INT8,
                                            TensorType_INT16}));
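The expected constants in AddRequantBeforeConcat follow directly from the formulas in the test's own comment; a small self-contained check (a sketch, using 255 = 2^8 - 1 steps for asymmetric int8 and 32767 = 2^16 / 2 - 1 for symmetric int16):

#include <cstdio>

// Hedged sketch: reproduce the scale constants expected by the test above.
int main() {
  const double int8_steps = 255.0;     // 2^8 - 1
  const double int16_steps = 32767.0;  // 2^16 / 2 - 1
  std::printf("input0: int8 %.9f  int16 %.9f\n",
              (5.0 - 0.0) / int8_steps,    // ~0.019607844
              (5.0 - 0.0) / int16_steps);  // ~0.00015259254
  std::printf("input1: int8 %.9f  int16 %.9f\n",
              (10.0 - 0.0) / int8_steps,    // ~0.039215688
              (10.0 - 0.0) / int16_steps);  // ~0.00030518509
  // Zero points: int8 pins the bottom of the range to -128; int16 is
  // symmetric, so the zero point is 0 - matching zero_point_control.
  return 0;
}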
class QuantizeSplitModelTest : public QuantizeModelTest {
 protected:
  QuantizeSplitModelTest() {
@ -432,8 +483,9 @@ class QuantizeSplitModelTest : public QuantizeModelTest {
// There are two outputs for split with different scales, the resulting model
// should have the scales be hardcodes to the input scale value.
TEST_F(QuantizeSplitModelTest, QuantizeSplit) {
  auto status = QuantizeModel(&builder_, &model_, TensorType_INT8,
                              TensorType_INT8, &error_reporter_);
  auto status =
      QuantizeModel(&builder_, &model_, TensorType_INT8, TensorType_INT8,
                    TensorType_INT8, &error_reporter_);
  EXPECT_EQ(status, kTfLiteOk);

  // There is only one subgraph.
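The behaviour QuantizeSplit pins down is a parameter copy: each output of the split initially calibrates to its own scale, and the pass overwrites both outputs with the input tensor's quantization parameters so the kernel needs no internal requantization. A hedged sketch of that copy over the unpacked schema types used throughout these tests (the helper name and explicit value-input index are assumptions):

#include "tensorflow/lite/schema/schema_generated.h"

// Hedged sketch: propagate one tensor's quantization parameters to all
// outputs of an op, as the split test expects for SPLIT outputs.
void CopyQuantParamsToOutputs(tflite::SubGraphT* subgraph,
                              tflite::OperatorT* op, int value_input_idx) {
  const auto& input_q =
      subgraph->tensors[op->inputs[value_input_idx]]->quantization;
  for (int32_t output_idx : op->outputs) {
    auto& output_q = subgraph->tensors[output_idx]->quantization;
    output_q->scale = input_q->scale;            // reuse the input scale(s)
    output_q->zero_point = input_q->zero_point;  // and zero point(s)
  }
}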
@ -496,8 +548,9 @@ class QuantizeConvModel1Test : public QuantizeModelTest {
};

TEST_F(QuantizeConvModel1Test, VerifyConvQuantizationWithUnitScale) {
  auto status = QuantizeModel(&builder_, &model_, TensorType_INT8,
                              TensorType_INT8, &error_reporter_);
  auto status =
      QuantizeModel(&builder_, &model_, TensorType_INT8, TensorType_INT8,
                    TensorType_INT8, &error_reporter_);
  EXPECT_EQ(status, kTfLiteOk);
  const auto& subgraph = model_.subgraphs[0];

@ -587,18 +640,25 @@ TEST_F(QuantizeConvModel1Test, VerifyConvQuantizationWithUnitScale) {
  EXPECT_EQ(model_.operator_codes[0]->version, 3);
}

class QuantizeConvModel2Test : public QuantizeModelTest {
class QuantizeConvModel2Test : public QuantizeModelTest,
                               public testing::WithParamInterface<TensorType> {
 protected:
  QuantizeConvModel2Test() {
    tensor_type_ = GetParam();
    input_model_ = ReadModel(internal::kConvModelWith0Plus10Weights);
    readonly_model_ = input_model_->GetModel();
    readonly_model_->UnPackTo(&model_);
  }
};

TEST_F(QuantizeConvModel2Test, VerifyConvQuantization) {
  auto status = QuantizeModel(&builder_, &model_, TensorType_INT8,
                              TensorType_INT8, &error_reporter_);
  TensorType tensor_type_;
};
INSTANTIATE_TEST_SUITE_P(QuantizeConvModel2TestInst, QuantizeConvModel2Test,
                         testing::ValuesIn({TensorType_INT8,
                                            TensorType_INT16}));

TEST_P(QuantizeConvModel2Test, VerifyConvQuantization) {
  auto status = QuantizeModel(&builder_, &model_, tensor_type_, tensor_type_,
                              tensor_type_, &error_reporter_);
  ASSERT_EQ(kTfLiteOk, status);
  const auto& subgraph = model_.subgraphs[0];
  auto conv_op = subgraph->operators[0].get();
@ -615,8 +675,10 @@ TEST_F(QuantizeConvModel2Test, VerifyConvQuantization) {
  const auto output_tensor =
      subgraph->tensors[conv_op->outputs[output_tensor_idx]].get();

  EXPECT_EQ(bias_tensor->type, TensorType_INT32);
  EXPECT_EQ(input_tensor->type, TensorType_INT8);
  EXPECT_EQ(bias_tensor->type, tensor_type_ == TensorType_INT8
                                   ? TensorType_INT32
                                   : TensorType_INT64);
  EXPECT_EQ(input_tensor->type, tensor_type_);
  EXPECT_EQ(weights_tensor->type, TensorType_INT8);

  ASSERT_TRUE(weights_tensor->quantization);
@ -644,17 +706,28 @@ TEST_F(QuantizeConvModel2Test, VerifyConvQuantization) {
  }

  const auto bias_buffer = model_.buffers[bias_tensor->buffer].get();
  ASSERT_EQ(bias_buffer->data.size(), sizeof(int32_t) * bias_tensor->shape[0]);
  const int32_t* bias_values =
      reinterpret_cast<int32_t*>(bias_buffer->data.data());
  auto control_size = tensor_type_ == TensorType_INT8
                          ? sizeof(int32_t) * bias_tensor->shape[0]
                          : sizeof(int64_t) * bias_tensor->shape[0];

  ASSERT_EQ(bias_buffer->data.size(), control_size);
  const auto original_bias_buffer =
      readonly_model_->buffers()->Get(bias_tensor->buffer);
  const float* bias_float_buffer =
      reinterpret_cast<const float*>(original_bias_buffer->data()->data());

  for (size_t i = 0; i < out_channel_size; i++) {
    auto dequantized_value = bias_values[i] * bias_scales[i];
    EXPECT_NEAR(dequantized_value, bias_float_buffer[i], bias_scales[i] / 2);
  if (tensor_type_ == TensorType_INT8) {
    int32_t* bias_values = reinterpret_cast<int32_t*>(bias_buffer->data.data());
    for (size_t i = 0; i < out_channel_size; i++) {
      auto dequantized_value = bias_values[i] * bias_scales[i];
      EXPECT_NEAR(dequantized_value, bias_float_buffer[i], bias_scales[i] / 2);
    }
  } else if (tensor_type_ == TensorType_INT16) {
    int64_t* bias_values = reinterpret_cast<int64_t*>(bias_buffer->data.data());
    for (size_t i = 0; i < out_channel_size; i++) {
      auto dequantized_value = bias_values[i] * bias_scales[i];
      EXPECT_NEAR(dequantized_value, bias_float_buffer[i], bias_scales[i] / 2);
    }
  }

  const auto weights_buffer = model_.buffers[weights_tensor->buffer].get();
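The EXPECT_NEAR tolerance of bias_scales[i] / 2 is the worst-case rounding error: the bias is stored as an integer in steps of the per-channel bias scale (in TFLite's scheme, input scale times weight scale), widened to int64 in 16x8 mode. A hedged numeric sketch with made-up scales:

#include <cmath>
#include <cstdint>
#include <cstdio>

// Hedged sketch: round-trip a float bias through integer quantization and
// confirm the error stays within half a quantization step.
int main() {
  const double input_scale = 0.00015259254;  // int16 activation scale
  const double weight_scale = 0.02;          // example per-channel scale
  const double bias_scale = input_scale * weight_scale;
  const double real_bias = 0.125;            // example float bias value
  const int64_t quantized =
      static_cast<int64_t>(std::llround(real_bias / bias_scale));
  const double dequantized = quantized * bias_scale;
  // |dequantized - real_bias| <= bias_scale / 2, the EXPECT_NEAR bound above.
  std::printf("quantized=%lld error=%.3g (half-step=%.3g)\n",
              static_cast<long long>(quantized),
              std::fabs(dequantized - real_bias), bias_scale / 2);
  return 0;
}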
@ -695,8 +768,9 @@ class QuantizeSoftmaxTest : public QuantizeModelTest {
};

TEST_F(QuantizeSoftmaxTest, VerifySoftmaxQuantization) {
  auto status = QuantizeModel(&builder_, &model_, TensorType_INT8,
                              TensorType_INT8, &error_reporter_);
  auto status =
      QuantizeModel(&builder_, &model_, TensorType_INT8, TensorType_INT8,
                    TensorType_INT8, &error_reporter_);
  ASSERT_EQ(kTfLiteOk, status);

  const auto& subgraph = model_.subgraphs[0];
@ -755,8 +829,9 @@ class QuantizeAvgPoolTest : public QuantizeModelTest {
};

TEST_F(QuantizeAvgPoolTest, VerifyAvgPoolQuantization) {
  auto status = QuantizeModel(&builder_, &model_, TensorType_INT8,
                              TensorType_INT8, &error_reporter_);
  auto status =
      QuantizeModel(&builder_, &model_, TensorType_INT8, TensorType_INT8,
                    TensorType_INT8, &error_reporter_);
  ASSERT_EQ(kTfLiteOk, status);

  const auto& subgraph = model_.subgraphs[0];
@ -816,8 +891,9 @@ class QuantizeMultiInputAddWithReshapeTest : public QuantizeModelTest {
};

TEST_F(QuantizeMultiInputAddWithReshapeTest, VerifyReshapeQuantization) {
  auto status = QuantizeModel(&builder_, &model_, TensorType_INT8,
                              TensorType_INT8, &error_reporter_);
  auto status =
      QuantizeModel(&builder_, &model_, TensorType_INT8, TensorType_INT8,
                    TensorType_INT8, &error_reporter_);
  ASSERT_EQ(kTfLiteOk, status);

  // Verify Reshape is quantized.
@ -863,8 +939,9 @@ TEST_F(QuantizeMultiInputAddWithReshapeTest, VerifyReshapeQuantization) {
}

TEST_F(QuantizeMultiInputAddWithReshapeTest, VerifyAddQuantization) {
  auto status = QuantizeModel(&builder_, &model_, TensorType_INT8,
                              TensorType_INT8, &error_reporter_);
  auto status =
      QuantizeModel(&builder_, &model_, TensorType_INT8, TensorType_INT8,
                    TensorType_INT8, &error_reporter_);
  ASSERT_EQ(kTfLiteOk, status);

  // Verify ADD is quantized.
@ -923,8 +1000,9 @@ class QuantizeConstInputTest : public QuantizeModelTest {
};

TEST_F(QuantizeConstInputTest, VerifyConstOpInput) {
  auto status = QuantizeModel(&builder_, &model_, TensorType_INT8,
                              TensorType_INT8, &error_reporter_);
  auto status =
      QuantizeModel(&builder_, &model_, TensorType_INT8, TensorType_INT8,
                    TensorType_INT8, &error_reporter_);
  ASSERT_EQ(kTfLiteOk, status);

  // Verify ConstOp is quantized.
@ -965,8 +1043,9 @@ class QuantizeArgMaxTest : public QuantizeModelTest {
};

TEST_F(QuantizeArgMaxTest, VerifyArgMax) {
  auto status = QuantizeModel(&builder_, &model_, TensorType_INT8,
                              TensorType_INT8, &error_reporter_);
  auto status =
      QuantizeModel(&builder_, &model_, TensorType_INT8, TensorType_INT8,
                    TensorType_INT8, &error_reporter_);
  ASSERT_EQ(kTfLiteOk, status);

  const auto& subgraph = model_.subgraphs[0];
@ -1008,8 +1087,9 @@ class QuantizeLSTMTest : public QuantizeModelTest {

TEST_F(QuantizeLSTMTest, VerifyLSTM) {
  // Quantize model.
  auto status = QuantizeModel(&builder_, &model_, TensorType_FLOAT32,
                              TensorType_FLOAT32, &error_reporter_);
  auto status =
      QuantizeModel(&builder_, &model_, TensorType_FLOAT32, TensorType_FLOAT32,
                    TensorType_INT8, &error_reporter_);
  ASSERT_EQ(kTfLiteOk, status);

  // Read expected model.
@ -1067,8 +1147,9 @@ class QuantizeLSTM2Test : public QuantizeModelTest {

TEST_F(QuantizeLSTM2Test, VerifyLSTM) {
  // Quantize model.
  auto status = QuantizeModel(&builder_, &model_, TensorType_FLOAT32,
                              TensorType_FLOAT32, &error_reporter_);
  auto status =
      QuantizeModel(&builder_, &model_, TensorType_FLOAT32, TensorType_FLOAT32,
                    TensorType_INT8, &error_reporter_);
  ASSERT_EQ(kTfLiteOk, status);

  // Read expected model.
@ -1126,8 +1207,9 @@ class QuantizeSVDFTest : public QuantizeModelTest {

TEST_F(QuantizeSVDFTest, VerifySVDF) {
  // Quantize model.
  auto status = QuantizeModel(&builder_, &model_, TensorType_INT8,
                              TensorType_INT8, &error_reporter_);
  auto status =
      QuantizeModel(&builder_, &model_, TensorType_INT8, TensorType_INT8,
                    TensorType_INT8, &error_reporter_);
  ASSERT_EQ(kTfLiteOk, status);

  // Read expected model.
@ -1184,8 +1266,9 @@ class QuantizeFCTest : public QuantizeModelTest {
};

TEST_F(QuantizeFCTest, VerifyFC) {
  auto status = QuantizeModel(&builder_, &model_, TensorType_INT8,
                              TensorType_INT8, &error_reporter_);
  auto status =
      QuantizeModel(&builder_, &model_, TensorType_INT8, TensorType_INT8,
                    TensorType_INT8, &error_reporter_);
  ASSERT_EQ(kTfLiteOk, status);

  const auto& subgraph = model_.subgraphs[0];
@ -1236,7 +1319,7 @@ class QuantizeCustomOpTest : public QuantizeModelTest {
TEST_F(QuantizeCustomOpTest, VerifyMixedQuantization) {
  auto status =
      QuantizeModel(&builder_, &model_, TensorType_INT8, TensorType_INT8,
                    /*allow_float=*/true, &error_reporter_);
                    /*allow_float=*/true, TensorType_INT8, &error_reporter_);
  ASSERT_EQ(kTfLiteOk, status);
  const auto& subgraph = model_.subgraphs[0];
  auto float_graph = readonly_model_->subgraphs()->Get(0);
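VerifyMixedQuantization exercises the allow_float escape hatch: with /*allow_float=*/true, ops with no quantized implementation (here, the custom op) are left in float32 instead of failing the conversion, while activations_type still governs the ops that do quantize. A hedged sketch of the call shape (the wrapper name is an assumption):

#include "flatbuffers/flatbuffers.h"
#include "tensorflow/lite/tools/optimize/quantize_model.h"

// Hedged sketch: mixed quantization, tolerating float fallbacks.
TfLiteStatus QuantizeAllowingFloat(flatbuffers::FlatBufferBuilder* builder,
                                   tflite::ModelT* model,
                                   tflite::ErrorReporter* error_reporter) {
  return tflite::optimize::QuantizeModel(
      builder, model, tflite::TensorType_INT8, tflite::TensorType_INT8,
      /*allow_float=*/true, /*activations_type=*/tflite::TensorType_INT8,
      error_reporter);
}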
@ -1270,7 +1353,8 @@ class QuantizePackTest : public QuantizeModelTest {
};

TEST_F(QuantizePackTest, VerifyPack) {
  auto status = QuantizeModel(&builder_, &model_, &error_reporter_);
  auto status =
      QuantizeModel(&builder_, &model_, TensorType_INT8, &error_reporter_);

  ASSERT_EQ(kTfLiteOk, status);

@ -1334,7 +1418,8 @@ class QuantizeMinimumMaximumTest
};

TEST_P(QuantizeMinimumMaximumTest, VerifyMinimumMaximum) {
  auto status = QuantizeModel(&builder_, &model_, &error_reporter_);
  auto status =
      QuantizeModel(&builder_, &model_, TensorType_INT8, &error_reporter_);
  ASSERT_EQ(kTfLiteOk, status);
  const auto& subgraph = model_.subgraphs[0];

@ -1415,7 +1500,8 @@ class QuantizeUnpackTest : public QuantizeModelTest {
  }
};
TEST_F(QuantizeUnpackTest, VerifyUnpack) {
  auto status = QuantizeModel(&builder_, &model_, &error_reporter_);
  auto status =
      QuantizeModel(&builder_, &model_, TensorType_INT8, &error_reporter_);

  ASSERT_EQ(kTfLiteOk, status);