Support integer input and output types for quantization-aware trained models

PiperOrigin-RevId: 322658564
Change-Id: I388d625fe22df0099dc2ed5a5e87db30a4a9d647
Meghna Natraj 2020-07-22 14:40:56 -07:00 committed by TensorFlower Gardener
parent e3be70aa9d
commit a0c12335d3
6 changed files with 490 additions and 31 deletions

View File

@ -56,6 +56,8 @@
* `tf.lite`:
* Better support for ops with high-dimensional broadcasting inputs by adding
`BroadcastTo` ops when necessary.
* `TFLiteConverter`:
* Support optional flags `inference_input_type` and `inference_output_type` for full integer quantized models. These allow users to set the model's input and output types to an integer type (tf.int8, tf.uint8) instead of the default float type (tf.float32); see the usage sketch after these notes.
* `tf.random`:
* <ADD RELEASE NOTES HERE>
* Math and Linear Algebra:
@ -68,7 +70,7 @@
* <ADD RELEASE NOTES HERE>
* Other:
* We have replaced uses of "whitelist" and "blacklist" with "allowlist"
and "denylist" where possible. Please see
and "denylist" where possible. Please see
https://developers.google.com/style/word-list#blacklist for more context.
* <ADD RELEASE NOTES HERE>
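For reference, a minimal usage sketch of the new flags (illustrative only; it assumes `qat_model` is a Keras model produced by quantization-aware training, which is not part of this diff):

import tensorflow as tf

# `qat_model` is assumed to be a quantization-aware trained Keras model,
# e.g. one prepared with the TensorFlow Model Optimization toolkit.
converter = tf.lite.TFLiteConverter.from_keras_model(qat_model)
converter.optimizations = [tf.lite.Optimize.DEFAULT]
# Request integer model input/output instead of the default tf.float32.
converter.inference_input_type = tf.int8   # or tf.uint8
converter.inference_output_type = tf.int8  # or tf.uint8
tflite_model = converter.convert()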

View File

@ -212,8 +212,11 @@ py_library(
deps = [
":lite_constants",
":op_hint",
":schema_py",
"//tensorflow/python:tf_optimizer",
"//tensorflow/python/eager:wrap_function",
"@absl_py//absl/logging",
"@flatbuffers//:runtime_py",
"@six_archive//:six",
],
)
@ -224,12 +227,24 @@ py_test(
python_version = "PY3",
srcs_version = "PY2AND3",
tags = [
"no_mac",
"no_windows",
],
deps = [
":lite_constants",
":util",
"//tensorflow:tensorflow_py",
"//tensorflow/python:array_ops",
"//tensorflow/python:client_testlib",
"//tensorflow/python:control_flow_ops",
"//tensorflow/python:convert_to_constants",
"//tensorflow/python:dtypes",
"//tensorflow/python:framework_ops",
"//tensorflow/python:framework_test_lib",
"//tensorflow/python:math_ops",
"//tensorflow/python:session",
"//third_party/py/numpy",
"@absl_py//absl/testing:parameterized",
"@six_archive//:six",
],
)

View File

@ -61,6 +61,7 @@ from tensorflow.lite.python.util import get_grappler_config as _get_grappler_con
from tensorflow.lite.python.util import get_tensor_name as _get_tensor_name
from tensorflow.lite.python.util import get_tensors_from_tensor_names as _get_tensors_from_tensor_names
from tensorflow.lite.python.util import is_frozen_graph as _is_frozen_graph
from tensorflow.lite.python.util import modify_integer_quantized_model_io_type as _modify_integer_quantized_model_io_type
from tensorflow.lite.python.util import run_graph_optimizations as _run_graph_optimizations
from tensorflow.lite.python.util import set_tensor_shapes as _set_tensor_shapes
from tensorflow.python import keras as _keras
@ -314,6 +315,23 @@ class QuantizationMode(object):
else:
return False, None
def flags_modify_model_io_type(
self, input_type=constants.FLOAT, output_type=constants.FLOAT):
"""Flags for modifying the input and output type of a tflite model."""
is_post_training_quantize = self.quantizer_flags(input_type, output_type)[0]
is_training_time_only_quantize = self.training_time_int8_allow_float() and \
not is_post_training_quantize
# TODO(b/153576658): Consolidate post/during training quantization workflows
# to modify model input/output type after MLIR conversion.
if is_training_time_only_quantize:
return {
"inference_input_type": input_type,
"inference_output_type": output_type,
}
else:
return None
# Below are helpers for the above functions.
def _validate_int8_required(self):
@ -557,9 +575,8 @@ class TFLiteConverterBaseV2(TFLiteConverterBase):
def _validate_inference_input_output_types(self, quant_mode):
"""Validate inference_input_type and inference_output_type flags."""
default_types = [constants.FLOAT, None]
# We only support integer types for post training integer quantization
# as we have statistical information to quantize the input and output.
if quant_mode.is_post_training_integer_quantize():
# We support integer input/output for integer quantized models only.
if quant_mode.training_time_int8_allow_float():
all_types = default_types + [constants.INT8, constants.QUANTIZED_UINT8]
if self.inference_input_type not in all_types or \
self.inference_output_type not in all_types:
@ -643,6 +660,12 @@ class TFLiteConverterBaseV2(TFLiteConverterBase):
if calibrate_and_quantize:
result = self._calibrate_quantize_model(result, **flags)
flags_modify_model_io_type = quant_mode.flags_modify_model_io_type(
self.inference_input_type, self.inference_output_type)
if flags_modify_model_io_type:
result = _modify_integer_quantized_model_io_type(
result, **flags_modify_model_io_type)
if self._experimental_sparsify_model:
result = _mlir_sparsify(result)
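To make the new code path above concrete, a brief sketch of the two workflows it distinguishes (illustrative, simplified):

# Post-training integer quantization: the requested I/O types are applied by the
# calibrate-and-quantize step, so flags_modify_model_io_type() returns None.
# Training-time (quantization-aware trained) models: flags_modify_model_io_type()
# returns, e.g., {"inference_input_type": tf.int8, "inference_output_type": tf.int8},
# which is then splatted into _modify_integer_quantized_model_io_type(result, **flags).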

View File

@ -374,8 +374,12 @@ class FromConcreteFunctionTest(lite_v2_test_util.ModelTest):
return tf.keras.Sequential(QLinear(3, input_shape=(2,)))
@parameterized.named_parameters(
('_DefaultFLOAT32InputOutput', lite.constants.FLOAT),
('_INT8InputOutput', lite.constants.INT8),
('_UINT8InputOutput', lite.constants.QUANTIZED_UINT8))
@test_util.run_v2_only
def testTrainingTimeQuantization(self):
def testTrainingTimeQuantization(self, inference_input_output_type):
model = self._getTrainingTimeQuantizedModel()
float_converter = lite.TFLiteConverterV2.from_keras_model(model)
@ -384,37 +388,24 @@ class FromConcreteFunctionTest(lite_v2_test_util.ModelTest):
quantized_converter = lite.TFLiteConverterV2.from_keras_model(model)
quantized_converter.optimizations = [lite.Optimize.DEFAULT]
quantized_converter.inference_input_type = inference_input_output_type
quantized_converter.inference_output_type = inference_input_output_type
quantized_tflite = quantized_converter.convert()
self.assertTrue(quantized_tflite)
# Ensure that the quantized weights tflite model is smaller.
self.assertLess(len(quantized_tflite), len(float_tflite))
interpreter = Interpreter(model_content=quantized_tflite)
self.assertEqual(np.float32, interpreter.get_input_details()[0]['dtype'])
interpreter.allocate_tensors()
input_details = interpreter.get_input_details()
self.assertLen(input_details, 1)
self.assertEqual(inference_input_output_type.as_numpy_dtype,
input_details[0]['dtype'])
output_details = interpreter.get_output_details()
self.assertLen(output_details, 1)
self.assertEqual(inference_input_output_type.as_numpy_dtype,
output_details[0]['dtype'])
@parameterized.named_parameters(
('_INT8InputOutput', lite.constants.INT8),
('_UINT8InputOutput', lite.constants.QUANTIZED_UINT8))
def testInvalidTrainingTimeQuantization(self, inference_input_output_type):
# We currently don't support integer inference_input_type and
# inference_output_type flags for training time quantization.
model = self._getTrainingTimeQuantizedModel()
converter = lite.TFLiteConverterV2.from_keras_model(model)
tflite_model = converter.convert()
self.assertTrue(tflite_model)
quantized_converter = lite.TFLiteConverterV2.from_keras_model(model)
quantized_converter.optimizations = [lite.Optimize.DEFAULT]
with self.assertRaises(ValueError) as error:
quantized_converter.inference_input_type = inference_input_output_type
quantized_converter.inference_output_type = inference_input_output_type
quantized_converter.convert()
self.assertEqual(
'The inference_input_type and inference_output_type '
'must be tf.float32.', str(error.exception))
# Ensure that the quantized tflite model is smaller.
self.assertLess(len(quantized_tflite), len(float_tflite))
@test_util.run_v2_only
def testNewQuantizer(self):

View File

@ -19,15 +19,21 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import copy
import datetime
import sys
from absl import logging
import six
from six.moves import range
from flatbuffers.python import flatbuffers
from tensorflow.core.protobuf import config_pb2 as _config_pb2
from tensorflow.core.protobuf import graph_debug_info_pb2
from tensorflow.core.protobuf import meta_graph_pb2 as _meta_graph_pb2
from tensorflow.lite.python import lite_constants as _lite_constants
from tensorflow.lite.python import schema_py_generated as _schema_fb
from tensorflow.lite.python.op_hint import convert_op_hints_to_stubs
from tensorflow.lite.python.op_hint import find_all_hinted_output_nodes
from tensorflow.lite.toco import types_pb2 as _types_pb2
@ -55,6 +61,25 @@ _MAP_TF_TO_TFLITE_TYPES = {
dtypes.bool: _types_pb2.BOOL,
}
_MAP_TFLITE_ENUM_TO_TF_TYPES = {
0: dtypes.float32,
1: dtypes.float16,
2: dtypes.int32,
3: dtypes.uint8,
4: dtypes.int64,
5: dtypes.string,
6: dtypes.bool,
7: dtypes.int16,
8: dtypes.complex64,
9: dtypes.int8,
10: dtypes.float64,
}
_TFLITE_FILE_IDENTIFIER = b"TFL3"
_TFLITE_MODEL_INPUT_OUTPUT_TYPES = (_lite_constants.FLOAT, _lite_constants.INT8,
_lite_constants.QUANTIZED_UINT8)
def convert_dtype_to_tflite_type(tf_dtype):
"""Converts tf.dtype to TFLite proto type.
@ -74,6 +99,31 @@ def convert_dtype_to_tflite_type(tf_dtype):
return result
def _convert_tflite_enum_type_to_tf_type(tflite_enum_type):
"""Converts tflite enum type (eg: 0) to tf type (eg: tf.float32).
Args:
tflite_enum_type: tflite enum type (eg: 0, that corresponds to float32)
Raises:
ValueError: If an invalid tflite enum type is provided.
Returns:
tf type (eg: tf.float32)
"""
tf_type = _MAP_TFLITE_ENUM_TO_TF_TYPES.get(tflite_enum_type)
if tf_type is None:
raise ValueError(
"Unsupported enum {}. The valid map of enum to tf.dtypes is : {}"
.format(tflite_enum_type, _MAP_TFLITE_ENUM_TO_TF_TYPES))
return tf_type
def _get_dtype_name(tf_type):
"""Converts tf.dtype (eg: tf.float32) to str (eg: "tf.float32")."""
return "tf." + tf_type.name
def get_tensor_name(tensor):
"""Returns name of the input tensor.
@ -514,3 +564,218 @@ extern const int {array_name}_len;
license_text=license_text)
return source_text, header_text
def _convert_model_from_bytearray_to_object(model_bytearray):
"""Converts a tflite model from a bytearray into a parsable object."""
model_object = _schema_fb.Model.GetRootAsModel(model_bytearray, 0)
model_object = _schema_fb.ModelT.InitFromObj(model_object)
model_object = copy.deepcopy(model_object)
model_object.subgraphs[0].inputs[0] = model_object.subgraphs[0].inputs[0]
return model_object
def _convert_model_from_object_to_bytearray(model_object):
"""Converts a tflite model from a parsable object into a bytearray."""
# Initial size of the buffer, which will grow automatically if needed
builder = flatbuffers.Builder(1024)
model_offset = model_object.Pack(builder)
builder.Finish(model_offset, file_identifier=_TFLITE_FILE_IDENTIFIER)
return bytes(builder.Output())
def _remove_tensors_from_model(model, remove_tensors_idxs):
"""Remove tensors from model."""
if not remove_tensors_idxs:
return
if len(model.subgraphs) > 1:
raise ValueError("Model must only have one subgraph. Instead, it has "
"{} subgraphs.".format(len(model.subgraphs)))
subgraph = model.subgraphs[0]
tensors = subgraph.tensors
operators = subgraph.operators
logging.debug("Removing tensors at indices : %s", remove_tensors_idxs)
# An optimized check to validate if "remove_tensors_idxs" (eg: [4,5,6]) is an
# exact subset, with ordering, of "tensors" indices (eg: [0,1,2,3,4,5,6]).
if min(remove_tensors_idxs) == len(tensors) - len(remove_tensors_idxs):
logging.debug("Removing tensors only at the end of the tensor list")
del tensors[min(remove_tensors_idxs):]
else:
logging.debug("Removing tensors requires updating the model")
# Map the old tensor indices to new tensor indices
d_old_to_new_tensors = {}
left_shift_by = 0
for idx in range(len(tensors)):
if idx in remove_tensors_idxs:
left_shift_by += 1
else:
d_old_to_new_tensors[idx] = idx - left_shift_by
logging.debug("Old to new tensors map: %s", d_old_to_new_tensors.__str__())
# Update tensor indices referenced throughout the model
def update_tensors(tensor_idxs):
for i, ti in enumerate(tensor_idxs):
tensor_idxs[i] = d_old_to_new_tensors.get(ti, -1)
update_tensors(subgraph.inputs)
update_tensors(subgraph.outputs)
for op in operators:
update_tensors(op.inputs)
update_tensors(op.outputs)
# Delete the tensors
for idx in sorted(remove_tensors_idxs, reverse=True):
tensors.pop(idx)
logging.debug("Removed tensors marked for deletion")
def _validate_and_find_int8_quantized_inputs_outputs(model):
"""Validate that model input is quantized and output is dequantized."""
if len(model.subgraphs) > 1:
raise ValueError("Model must only have one subgraph. Instead, it has "
"{} subgraphs.".format(len(model.subgraphs)))
subgraph = model.subgraphs[0]
tensors = subgraph.tensors
operators = subgraph.operators
# Ensure the model has at least one quantize and one dequantize operator
quant_opcode_idx, dequant_opcode_idx = None, None
for idx, opcode in enumerate(model.operatorCodes):
if opcode.builtinCode == _schema_fb.BuiltinOperator.QUANTIZE:
quant_opcode_idx = idx
elif opcode.builtinCode == _schema_fb.BuiltinOperator.DEQUANTIZE:
dequant_opcode_idx = idx
if quant_opcode_idx is not None and dequant_opcode_idx is not None:
break
if quant_opcode_idx is None and dequant_opcode_idx is None:
raise ValueError("Model is not integer quantized as it does not "
"contain quantize/dequantize operators.")
# Ensure model inputs and outputs are integer quantized
input_quant_ops, output_dequant_ops = [], []
for op in operators:
# Find input quantize operator
if op.opcodeIndex == quant_opcode_idx and op.inputs[0] in subgraph.inputs:
pos, float_tensor, int_tensor = \
"input", tensors[op.inputs[0]], tensors[op.outputs[0]]
input_quant_ops.append(op)
# Find output dequantize operator
elif op.opcodeIndex == dequant_opcode_idx and \
op.outputs[0] in subgraph.outputs:
pos, float_tensor, int_tensor = \
"output", tensors[op.outputs[0]], tensors[op.inputs[0]]
output_dequant_ops.append(op)
# Otherwise, ignore
else:
continue
# If found, validate the input/output tensor type
if float_tensor.type != _schema_fb.TensorType.FLOAT32:
raise ValueError(
"Model {} type must be tf.float32. Expected type for tensor with "
"name '{}' is tf.float32, instead type is tf.{}".format(
pos, float_tensor.name,
_convert_tflite_enum_type_to_tf_type(float_tensor.type).name))
if int_tensor.type != _schema_fb.TensorType.INT8:
raise ValueError(
"Model is not integer quantized. Expected type for tensor with "
"name '{}' is tf.int8, instead type is tf.{}".format(
int_tensor.name,
_convert_tflite_enum_type_to_tf_type(int_tensor.type).name))
return input_quant_ops, output_dequant_ops
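# Illustrative note: a default integer quantized model keeps float boundaries,
# i.e. float32 input -> QUANTIZE -> int8 subgraph -> DEQUANTIZE -> float32 output.
# The function below either retargets these boundary ops (for uint8 I/O) or
# removes them entirely (for int8 I/O).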
def modify_integer_quantized_model_io_type(
model, inference_input_type=_lite_constants.FLOAT,
inference_output_type=_lite_constants.FLOAT):
"""Modify the float input/output type of an integer quantized model.
Args:
model: An int8 quantized tflite model with float input and output.
inference_input_type: tf.DType representing final input type.
(default tf.float32)
inference_output_type: tf.DType representing final output type.
(default tf.float32)
Returns:
An int8 quantized tflite model with modified input and/or output type.
Raises:
ValueError: If the model is not int8 quantized or the inference_input_type
and/or inference_output_type is unsupported.
RuntimeError: If the modification was unsuccessful.
"""
# Return if input and output types default to float
if inference_input_type == _lite_constants.FLOAT and \
inference_output_type == _lite_constants.FLOAT:
return model
# Validate input and output types
if inference_input_type not in _TFLITE_MODEL_INPUT_OUTPUT_TYPES:
raise ValueError("The `inference_input_type` should be in {}".format(
tuple(_get_dtype_name(t) for t in _TFLITE_MODEL_INPUT_OUTPUT_TYPES)))
if inference_output_type not in _TFLITE_MODEL_INPUT_OUTPUT_TYPES:
raise ValueError("The `inference_output_type` should be in {}".format(
tuple(_get_dtype_name(t) for t in _TFLITE_MODEL_INPUT_OUTPUT_TYPES)))
logging.debug(("Attempting to modify the model input from tf.float32 to %s "
"and output from tf.float32 to %s"),
_get_dtype_name(inference_input_type),
_get_dtype_name(inference_output_type))
# Convert the model to an object
model = _convert_model_from_bytearray_to_object(model)
# Validate the integer quantized model
input_quant_ops, output_dequant_ops = \
_validate_and_find_int8_quantized_inputs_outputs(model)
# Initialize references and variables
if len(model.subgraphs) > 1:
raise ValueError("Model must only have one subgraph. Instead, it has "
"{} subgraphs.".format(len(model.subgraphs)))
subgraph = model.subgraphs[0]
tensors = subgraph.tensors
operators = subgraph.operators
remove_tensors_idxs = set()
# Modify model input type
if inference_input_type == _lite_constants.QUANTIZED_UINT8:
# Change quant op (float to int8) to quant op (uint8 to int8)
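# (int8 and uint8 share the same scale; the zero point shifts by +128 because
# the int8 range [-128, 127] and the uint8 range [0, 255] are offset by 128.)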
for op in input_quant_ops:
int8_quantization = tensors[op.outputs[0]].quantization
uint8_quantization = _schema_fb.QuantizationParametersT()
uint8_quantization.scale = [int8_quantization.scale[0]]
uint8_quantization.zeroPoint = [int8_quantization.zeroPoint[0] + 128]
tensors[op.inputs[0]].quantization = uint8_quantization
tensors[op.inputs[0]].type = _schema_fb.TensorType.UINT8
elif inference_input_type == _lite_constants.INT8:
# Remove the inputs and the quant operator
for op in input_quant_ops:
subgraph.inputs[subgraph.inputs == op.inputs[0]] = op.outputs[0]
remove_tensors_idxs.add(op.inputs[0])
operators.remove(op)
# Modify model output type
if inference_output_type == _lite_constants.QUANTIZED_UINT8:
# Change dequant op (int8 to float) to quant op (int8 to uint8)
for op in output_dequant_ops:
op.opcodeIndex = input_quant_ops[0].opcodeIndex
int8_quantization = tensors[op.inputs[0]].quantization
uint8_quantization = _schema_fb.QuantizationParametersT()
uint8_quantization.scale = [int8_quantization.scale[0]]
uint8_quantization.zeroPoint = [int8_quantization.zeroPoint[0] + 128]
tensors[op.outputs[0]].quantization = uint8_quantization
tensors[op.outputs[0]].type = _schema_fb.TensorType.UINT8
elif inference_output_type == _lite_constants.INT8:
# Remove the outputs and the dequant operator
for op in output_dequant_ops:
subgraph.outputs[subgraph.outputs == op.outputs[0]] = op.inputs[0]
remove_tensors_idxs.add(op.outputs[0])
operators.remove(op)
# Remove tensors marked for deletion.
_remove_tensors_from_model(model, remove_tensors_idxs)
# Convert the model to a bytearray
model = _convert_model_from_object_to_bytearray(model)
return model
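As a usage sketch for the new helper (illustrative only; `quantized_model` is assumed to be the serialized output of an integer quantized conversion whose input and output are still float32):

import numpy as np
import tensorflow as tf
from tensorflow.lite.python import lite_constants, util

int8_io_model = util.modify_integer_quantized_model_io_type(
    quantized_model,
    inference_input_type=lite_constants.INT8,
    inference_output_type=lite_constants.INT8)

# The rewritten flatbuffer now reports integer input/output types.
interpreter = tf.lite.Interpreter(model_content=int8_io_model)
interpreter.allocate_tensors()
assert interpreter.get_input_details()[0]["dtype"] == np.int8
assert interpreter.get_output_details()[0]["dtype"] == np.int8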

View File

@ -19,7 +19,10 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from absl.testing import parameterized
import numpy as np
from six.moves import range
import tensorflow as tf
from tensorflow.lite.python import lite_constants
from tensorflow.lite.python import util
@ -61,6 +64,31 @@ class UtilTest(test_util.TensorFlowTestCase):
self.assertEqual(
util.convert_dtype_to_tflite_type(dtypes.bool), _types_pb2.BOOL)
def testConvertEnumToDtype(self):
self.assertEqual(
util._convert_tflite_enum_type_to_tf_type(0), dtypes.float32)
self.assertEqual(
util._convert_tflite_enum_type_to_tf_type(1), dtypes.float16)
self.assertEqual(util._convert_tflite_enum_type_to_tf_type(2), dtypes.int32)
self.assertEqual(util._convert_tflite_enum_type_to_tf_type(3), dtypes.uint8)
self.assertEqual(util._convert_tflite_enum_type_to_tf_type(4), dtypes.int64)
self.assertEqual(
util._convert_tflite_enum_type_to_tf_type(5), dtypes.string)
self.assertEqual(util._convert_tflite_enum_type_to_tf_type(6), dtypes.bool)
self.assertEqual(util._convert_tflite_enum_type_to_tf_type(7), dtypes.int16)
self.assertEqual(
util._convert_tflite_enum_type_to_tf_type(8), dtypes.complex64)
self.assertEqual(util._convert_tflite_enum_type_to_tf_type(9), dtypes.int8)
self.assertEqual(
util._convert_tflite_enum_type_to_tf_type(10), dtypes.float64)
with self.assertRaises(ValueError) as error:
util._convert_tflite_enum_type_to_tf_type(11)
self.assertEqual(
"Unsupported enum 11. The valid map of enum to tf.dtypes is : "
"{0: tf.float32, 1: tf.float16, 2: tf.int32, 3: tf.uint8, 4: tf.int64, "
"5: tf.string, 6: tf.bool, 7: tf.int16, 8: tf.complex64, 9: tf.int8, "
"10: tf.float64}", str(error.exception))
def testTensorName(self):
with ops.Graph().as_default():
in_tensor = array_ops.placeholder(shape=[4], dtype=dtypes.float32)
@ -195,5 +223,140 @@ class TensorFunctionsTest(test_util.TensorFlowTestCase):
self.assertEqual([None, 3, 5], tensor.shape.as_list())
def _generate_integer_tflite_model():
"""Define an integer post-training quantized tflite model."""
# Load MNIST dataset
n = 10 # Number of samples
(train_images, train_labels), (test_images, test_labels) = \
tf.keras.datasets.mnist.load_data()
train_images, train_labels, test_images, test_labels = \
train_images[:n], train_labels[:n], test_images[:n], test_labels[:n]
# Normalize the input image so that each pixel value is between 0 and 1.
train_images = train_images / 255.0
test_images = test_images / 255.0
# Define TF model
model = tf.keras.Sequential([
tf.keras.layers.InputLayer(input_shape=(28, 28)),
tf.keras.layers.Reshape(target_shape=(28, 28, 1)),
tf.keras.layers.Conv2D(filters=12, kernel_size=(3, 3), activation="relu"),
tf.keras.layers.MaxPooling2D(pool_size=(2, 2)),
tf.keras.layers.Flatten(),
tf.keras.layers.Dense(10)
])
# Train
model.compile(
optimizer="adam",
loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
metrics=["accuracy"])
model.fit(
train_images,
train_labels,
epochs=1,
validation_split=0.1,
)
# Convert TF Model to an Integer Quantized TFLite Model
converter = tf.lite.TFLiteConverter.from_keras_model(model)
converter.optimizations = {tf.lite.Optimize.DEFAULT}
def representative_dataset_gen():
for _ in range(2):
yield [
np.random.uniform(low=0, high=1, size=(1, 28, 28)).astype(
np.float32)
]
converter.representative_dataset = representative_dataset_gen
converter.target_spec.supported_ops = {tf.lite.OpsSet.TFLITE_BUILTINS_INT8}
tflite_model = converter.convert()
return tflite_model
def _test_param_modify_integer_model_io_type():
"""Function to generate parameterized inputs for testing."""
params = []
str_template = "_{}{}{}"
map_model_type = {
"PostTraining": True,
# "DuringTraining": False,
}
map_types = {
"": lite_constants.FLOAT,
"INT8": lite_constants.INT8,
"UINT8": lite_constants.QUANTIZED_UINT8
}
for k1, v1 in map_model_type.items():
for k2, v2 in map_types.items():
istr = "_Input{}".format(k2) if k2 else ""
for k3, v3 in map_types.items():
ostr = "_Output{}".format(k3) if k3 else "" if istr else "_NoUpdate"
params.append((str_template.format(k1, istr, ostr), v1, v2, v3))
return params
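# Illustrative: with the maps above this yields parameter names such as
# "_PostTraining_InputINT8_OutputUINT8" and "_PostTraining_NoUpdate" (float
# input and output, i.e. nothing is modified).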
# TODO(b/161174063): Merge tests for integer input/output type
class UtilModifyIntegerQuantizedModelIOTypeTest(
test_util.TensorFlowTestCase, parameterized.TestCase):
@classmethod
def setUpClass(cls):
super(UtilModifyIntegerQuantizedModelIOTypeTest, cls).setUpClass()
cls.post_train_integer_model = _generate_integer_tflite_model()
@parameterized.named_parameters(_test_param_modify_integer_model_io_type())
def test(self, is_post_train, in_tftype, out_tftype):
"""Modify the float input/output type of an integer quantized model."""
def _run_tflite_inference(model, in_tftype, out_tftype):
"""Run inference on a model with a specific input/output type."""
# Load TFLite model and allocate tensors.
interpreter = tf.lite.Interpreter(model_content=model)
interpreter.allocate_tensors()
input_details = interpreter.get_input_details()[0]
output_details = interpreter.get_output_details()[0]
# Validate TFLite model input and output types
self.assertEqual(input_details["dtype"], in_tftype.as_numpy_dtype)
self.assertEqual(output_details["dtype"], out_tftype.as_numpy_dtype)
# Define Input
np.random.seed(0)
input_data = np.random.uniform(low=0, high=1, size=(1, 28, 28))
input_data = input_data.astype(np.float32)
if input_details["dtype"] != np.float32:
# Quantize the float input: q = input / scale + zero_point
scale, zero_point = input_details["quantization"]
input_data = input_data / scale + zero_point
input_data = input_data.astype(input_details["dtype"])
# Run Inference
interpreter.set_tensor(input_details["index"], input_data)
interpreter.invoke()
# Get output
output_data = interpreter.get_tensor(output_details["index"])[0]
if output_details["dtype"] != np.float32:
# Dequantize the integer output: output = (q - zero_point) * scale
scale, zero_point = output_details["quantization"]
output_data = output_data.astype(np.float32)
output_data = (output_data - zero_point) * scale
return output_data
model = self.__class__.post_train_integer_model if is_post_train else None
# Run model inference with float input output type
output_data = _run_tflite_inference(model, tf.float32, tf.float32)
# Run model inference with modified integer input output type
model_io = util.modify_integer_quantized_model_io_type(
model, in_tftype, out_tftype)
output_io_data = _run_tflite_inference(model_io, in_tftype, out_tftype)
# Validate that both the outputs are the same
self.assertTrue(np.allclose(output_data, output_io_data, atol=1.0))
if __name__ == "__main__":
test.main()