Support integer input and output type for Quantize-Aware Trained models

PiperOrigin-RevId: 322658564
Change-Id: I388d625fe22df0099dc2ed5a5e87db30a4a9d647
Meghna Natraj 2020-07-22 14:40:56 -07:00 committed by TensorFlower Gardener
parent e3be70aa9d
commit a0c12335d3
6 changed files with 490 additions and 31 deletions

RELEASE.md

@@ -56,6 +56,8 @@
* `tf.lite`:
  * Better support for ops with high-dimensional broadcasting inputs by adding
    `BroadcastTo` ops when necessary.
  * `TFLiteConverter`:
    * Support optional flags `inference_input_type` and `inference_output_type` for full integer quantized models. This allows users to modify the model input and output type to integer types (tf.int8, tf.uint8) instead of defaulting to float type (tf.float32).
* `tf.random`:
  * <ADD RELEASE NOTES HERE>
* Math and Linear Algebra:
@@ -68,7 +70,7 @@
  * <ADD RELEASE NOTES HERE>
* Other:
  * We have replaced uses of "whitelist" and "blacklist" with "allowlist"
    and "denylist" where possible. Please see
    https://developers.google.com/style/word-list#blacklist for more context.
  * <ADD RELEASE NOTES HERE>
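A minimal usage sketch of the `inference_input_type`/`inference_output_type` flags described in the release note above (not part of the diff; the quantize-aware trained `qat_model` is assumed, e.g. a Keras model wrapped with the tensorflow_model_optimization QAT APIs):

import tensorflow as tf

def convert_qat_model_to_integer_io(qat_model):
  converter = tf.lite.TFLiteConverter.from_keras_model(qat_model)
  converter.optimizations = [tf.lite.Optimize.DEFAULT]
  # New in this change: integer types are now accepted here for
  # quantize-aware trained (training-time quantized) models as well.
  converter.inference_input_type = tf.int8   # or tf.uint8
  converter.inference_output_type = tf.int8  # or tf.uint8
  return converter.convert()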

tensorflow/lite/python/BUILD

@@ -212,8 +212,11 @@ py_library(
    deps = [
        ":lite_constants",
        ":op_hint",
        ":schema_py",
        "//tensorflow/python:tf_optimizer",
        "//tensorflow/python/eager:wrap_function",
        "@absl_py//absl/logging",
        "@flatbuffers//:runtime_py",
        "@six_archive//:six",
    ],
)
@@ -224,12 +227,24 @@ py_test(
    python_version = "PY3",
    srcs_version = "PY2AND3",
    tags = [
        "no_mac",
        "no_windows",
    ],
    deps = [
        ":lite_constants",
        ":util",
        "//tensorflow:tensorflow_py",
        "//tensorflow/python:array_ops",
        "//tensorflow/python:client_testlib",
        "//tensorflow/python:control_flow_ops",
        "//tensorflow/python:convert_to_constants",
        "//tensorflow/python:dtypes",
        "//tensorflow/python:framework_ops",
        "//tensorflow/python:framework_test_lib",
        "//tensorflow/python:math_ops",
        "//tensorflow/python:session",
        "//third_party/py/numpy",
        "@absl_py//absl/testing:parameterized",
        "@six_archive//:six",
    ],
)

tensorflow/lite/python/lite.py

@@ -61,6 +61,7 @@ from tensorflow.lite.python.util import get_grappler_config as _get_grappler_con
from tensorflow.lite.python.util import get_tensor_name as _get_tensor_name
from tensorflow.lite.python.util import get_tensors_from_tensor_names as _get_tensors_from_tensor_names
from tensorflow.lite.python.util import is_frozen_graph as _is_frozen_graph
from tensorflow.lite.python.util import modify_integer_quantized_model_io_type as _modify_integer_quantized_model_io_type
from tensorflow.lite.python.util import run_graph_optimizations as _run_graph_optimizations
from tensorflow.lite.python.util import set_tensor_shapes as _set_tensor_shapes
from tensorflow.python import keras as _keras
@@ -314,6 +315,23 @@ class QuantizationMode(object):
    else:
      return False, None

  def flags_modify_model_io_type(
      self, input_type=constants.FLOAT, output_type=constants.FLOAT):
    """Flags for modifying the input and output type of a tflite model."""
    is_post_training_quantize = self.quantizer_flags(input_type, output_type)[0]
    is_training_time_only_quantize = self.training_time_int8_allow_float() and \
        not is_post_training_quantize

    # TODO(b/153576658): Consolidate post/during training quantization workflows
    # to modify model input/output type after MLIR conversion.
    if is_training_time_only_quantize:
      return {
          "inference_input_type": input_type,
          "inference_output_type": output_type,
      }
    else:
      return None

  # Below are helpers for the above functions.

  def _validate_int8_required(self):
@@ -557,9 +575,8 @@ class TFLiteConverterBaseV2(TFLiteConverterBase):
  def _validate_inference_input_output_types(self, quant_mode):
    """Validate inference_input_type and inference_output_type flags."""
    default_types = [constants.FLOAT, None]
-    # We only support integer types for post training integer quantization
-    # as we have statistical information to quantize the input and output.
-    if quant_mode.is_post_training_integer_quantize():
+    # We support integer input/output for integer quantized models only.
+    if quant_mode.training_time_int8_allow_float():
      all_types = default_types + [constants.INT8, constants.QUANTIZED_UINT8]
      if self.inference_input_type not in all_types or \
          self.inference_output_type not in all_types:
@@ -643,6 +660,12 @@ class TFLiteConverterBaseV2(TFLiteConverterBase):
    if calibrate_and_quantize:
      result = self._calibrate_quantize_model(result, **flags)

    flags_modify_model_io_type = quant_mode.flags_modify_model_io_type(
        self.inference_input_type, self.inference_output_type)
    if flags_modify_model_io_type:
      result = _modify_integer_quantized_model_io_type(
          result, **flags_modify_model_io_type)

    if self._experimental_sparsify_model:
      result = _mlir_sparsify(result)
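The conversion-time flow added above can be paraphrased as follows (an illustrative sketch, not code from the diff; the helper name `_maybe_modify_io` is hypothetical):

def _maybe_modify_io(result, quant_mode, input_type, output_type):
  # flags_modify_model_io_type() returns a dict only for training-time-only
  # (quantize-aware) int8 quantization; the post-training path already handles
  # integer input/output inside the calibrating quantizer.
  flags = quant_mode.flags_modify_model_io_type(input_type, output_type)
  if flags:
    result = _modify_integer_quantized_model_io_type(result, **flags)
  return result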

tensorflow/lite/python/lite_v2_test.py

@@ -374,8 +374,12 @@ class FromConcreteFunctionTest(lite_v2_test_util.ModelTest):
    return tf.keras.Sequential(QLinear(3, input_shape=(2,)))

+  @parameterized.named_parameters(
+      ('_DefaultFLOAT32InputOutput', lite.constants.FLOAT),
+      ('_INT8InputOutput', lite.constants.INT8),
+      ('_UINT8InputOutput', lite.constants.QUANTIZED_UINT8))
  @test_util.run_v2_only
-  def testTrainingTimeQuantization(self):
+  def testTrainingTimeQuantization(self, inference_input_output_type):
    model = self._getTrainingTimeQuantizedModel()

    float_converter = lite.TFLiteConverterV2.from_keras_model(model)
@@ -384,37 +388,24 @@ class FromConcreteFunctionTest(lite_v2_test_util.ModelTest):
    quantized_converter = lite.TFLiteConverterV2.from_keras_model(model)
    quantized_converter.optimizations = [lite.Optimize.DEFAULT]
+    quantized_converter.inference_input_type = inference_input_output_type
+    quantized_converter.inference_output_type = inference_input_output_type
    quantized_tflite = quantized_converter.convert()
    self.assertTrue(quantized_tflite)

-    # Ensure that the quantized weights tflite model is smaller.
-    self.assertLess(len(quantized_tflite), len(float_tflite))

    interpreter = Interpreter(model_content=quantized_tflite)
-    self.assertEqual(np.float32, interpreter.get_input_details()[0]['dtype'])
+    interpreter.allocate_tensors()
+    input_details = interpreter.get_input_details()
+    self.assertLen(input_details, 1)
+    self.assertEqual(inference_input_output_type.as_numpy_dtype,
+                     input_details[0]['dtype'])
+    output_details = interpreter.get_output_details()
+    self.assertLen(output_details, 1)
+    self.assertEqual(inference_input_output_type.as_numpy_dtype,
+                     output_details[0]['dtype'])

+    # Ensure that the quantized tflite model is smaller.
+    self.assertLess(len(quantized_tflite), len(float_tflite))

-  @parameterized.named_parameters(
-      ('_INT8InputOutput', lite.constants.INT8),
-      ('_UINT8InputOutput', lite.constants.QUANTIZED_UINT8))
-  def testInvalidTrainingTimeQuantization(self, inference_input_output_type):
-    # We currently don't support integer inference_input_type and
-    # inference_output_type flags for training time quantization.
-    model = self._getTrainingTimeQuantizedModel()
-    converter = lite.TFLiteConverterV2.from_keras_model(model)
-    tflite_model = converter.convert()
-    self.assertTrue(tflite_model)
-
-    quantized_converter = lite.TFLiteConverterV2.from_keras_model(model)
-    quantized_converter.optimizations = [lite.Optimize.DEFAULT]
-    with self.assertRaises(ValueError) as error:
-      quantized_converter.inference_input_type = inference_input_output_type
-      quantized_converter.inference_output_type = inference_input_output_type
-      quantized_converter.convert()
-    self.assertEqual(
-        'The inference_input_type and inference_output_type '
-        'must be tf.float32.', str(error.exception))

  @test_util.run_v2_only
  def testNewQuantizer(self):
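Outside the test, feeding integer data to such a model uses the interpreter's quantization parameters; a rough sketch (not part of the diff), assuming `quantized_tflite` holds a model converted with tf.int8 input/output:

import numpy as np
import tensorflow as tf

interpreter = tf.lite.Interpreter(model_content=quantized_tflite)
interpreter.allocate_tensors()
input_details = interpreter.get_input_details()[0]

# Quantize a float input using the input tensor's scale and zero point.
scale, zero_point = input_details['quantization']
float_input = np.random.rand(*input_details['shape']).astype(np.float32)
int_input = (float_input / scale + zero_point).astype(input_details['dtype'])

interpreter.set_tensor(input_details['index'], int_input)
interpreter.invoke()
int_output = interpreter.get_tensor(interpreter.get_output_details()[0]['index'])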

tensorflow/lite/python/util.py

@@ -19,15 +19,21 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import copy
import datetime
import sys

from absl import logging
import six
from six.moves import range

from flatbuffers.python import flatbuffers
from tensorflow.core.protobuf import config_pb2 as _config_pb2
from tensorflow.core.protobuf import graph_debug_info_pb2
from tensorflow.core.protobuf import meta_graph_pb2 as _meta_graph_pb2
from tensorflow.lite.python import lite_constants as _lite_constants
from tensorflow.lite.python import schema_py_generated as _schema_fb
from tensorflow.lite.python.op_hint import convert_op_hints_to_stubs
from tensorflow.lite.python.op_hint import find_all_hinted_output_nodes
from tensorflow.lite.toco import types_pb2 as _types_pb2
@@ -55,6 +61,25 @@ _MAP_TF_TO_TFLITE_TYPES = {
    dtypes.bool: _types_pb2.BOOL,
}

_MAP_TFLITE_ENUM_TO_TF_TYPES = {
    0: dtypes.float32,
    1: dtypes.float16,
    2: dtypes.int32,
    3: dtypes.uint8,
    4: dtypes.int64,
    5: dtypes.string,
    6: dtypes.bool,
    7: dtypes.int16,
    8: dtypes.complex64,
    9: dtypes.int8,
    10: dtypes.float64,
}

_TFLITE_FILE_IDENTIFIER = b"TFL3"

_TFLITE_MODEL_INPUT_OUTPUT_TYPES = (_lite_constants.FLOAT, _lite_constants.INT8,
                                    _lite_constants.QUANTIZED_UINT8)


def convert_dtype_to_tflite_type(tf_dtype):
  """Converts tf.dtype to TFLite proto type.
@@ -74,6 +99,31 @@ def convert_dtype_to_tflite_type(tf_dtype):
  return result


def _convert_tflite_enum_type_to_tf_type(tflite_enum_type):
  """Converts tflite enum type (eg: 0) to tf type (eg: tf.float32).

  Args:
    tflite_enum_type: tflite enum type (eg: 0, that corresponds to float32)

  Raises:
    ValueError: If an invalid tflite enum type is provided.

  Returns:
    tf type (eg: tf.float32)
  """
  tf_type = _MAP_TFLITE_ENUM_TO_TF_TYPES.get(tflite_enum_type)
  if tf_type is None:
    raise ValueError(
        "Unsupported enum {}. The valid map of enum to tf.dtypes is : {}"
        .format(tflite_enum_type, _MAP_TFLITE_ENUM_TO_TF_TYPES))
  return tf_type


def _get_dtype_name(tf_type):
  """Converts tf.dtype (eg: tf.float32) to str (eg: "tf.float32")."""
  return "tf." + tf_type.name


def get_tensor_name(tensor):
  """Returns name of the input tensor.
@@ -514,3 +564,218 @@ extern const int {array_name}_len;
      license_text=license_text)

  return source_text, header_text

def _convert_model_from_bytearray_to_object(model_bytearray):
"""Converts a tflite model from a bytearray into a parsable object."""
model_object = _schema_fb.Model.GetRootAsModel(model_bytearray, 0)
model_object = _schema_fb.ModelT.InitFromObj(model_object)
model_object = copy.deepcopy(model_object)
model_object.subgraphs[0].inputs[0] = model_object.subgraphs[0].inputs[0]
return model_object
def _convert_model_from_object_to_bytearray(model_object):
"""Converts a tflite model from a parsable object into a bytearray."""
# Initial size of the buffer, which will grow automatically if needed
builder = flatbuffers.Builder(1024)
model_offset = model_object.Pack(builder)
builder.Finish(model_offset, file_identifier=_TFLITE_FILE_IDENTIFIER)
return bytes(builder.Output())
def _remove_tensors_from_model(model, remove_tensors_idxs):
"""Remove tensors from model."""
if not remove_tensors_idxs:
return
if len(model.subgraphs) > 1:
raise ValueError("Model must only have one subgraph. Instead, it has "
"{} subgraphs.".format(len(model.subgraphs)))
subgraph = model.subgraphs[0]
tensors = subgraph.tensors
operators = subgraph.operators
logging.debug("Removing tensors at indices : %s", remove_tensors_idxs)
# An optimized check to validate if "remove_tensors_idxs" (eg: [4,5,6]) is an
# exact subset, with ordering, of "tensors" indices (eg: [0,1,2,3,4,5,6]).
if min(remove_tensors_idxs) == len(tensors) - len(remove_tensors_idxs):
logging.debug("Removing tensors only at the end of the tensor list")
del tensors[min(remove_tensors_idxs):]
else:
logging.debug("Removing tensors requires updating the model")
# Map the old tensor indices to new tensor indices
d_old_to_new_tensors = {}
left_shift_by = 0
for idx in range(len(tensors)):
if idx in remove_tensors_idxs:
left_shift_by += 1
else:
d_old_to_new_tensors[idx] = idx - left_shift_by
logging.debug("Old to new tensors map: %s", d_old_to_new_tensors.__str__())
# Update tensor indices referenced throughout the model
def update_tensors(tensor_idxs):
for i, ti in enumerate(tensor_idxs):
tensor_idxs[i] = d_old_to_new_tensors.get(ti, -1)
update_tensors(subgraph.inputs)
update_tensors(subgraph.outputs)
for op in operators:
update_tensors(op.inputs)
update_tensors(op.outputs)
# Delete the tensors
for idx in sorted(remove_tensors_idxs, reverse=True):
tensors.pop(idx)
logging.debug("Removed tensors marked for deletion")
def _validate_and_find_int8_quantized_inputs_outputs(model):
"""Validate that model input is quantized and output is dequantized."""
if len(model.subgraphs) > 1:
raise ValueError("Model must only have one subgraph. Instead, it has "
"{} subgraphs.".format(len(model.subgraphs)))
subgraph = model.subgraphs[0]
tensors = subgraph.tensors
operators = subgraph.operators
# Ensure model has at least one quantize and one dequantize operator
quant_opcode_idx, dequant_opcode_idx = None, None
for idx, opcode in enumerate(model.operatorCodes):
if opcode.builtinCode == _schema_fb.BuiltinOperator.QUANTIZE:
quant_opcode_idx = idx
elif opcode.builtinCode == _schema_fb.BuiltinOperator.DEQUANTIZE:
dequant_opcode_idx = idx
if quant_opcode_idx is not None and dequant_opcode_idx is not None:
break
if quant_opcode_idx is None and dequant_opcode_idx is None:
raise ValueError("Model is not integer quantized as it does not "
"contain quantize/dequantize operators.")
# Ensure model inputs and outputs are integer quantized
input_quant_ops, output_dequant_ops = [], []
for op in operators:
# Find input quantize operator
if op.opcodeIndex == quant_opcode_idx and op.inputs[0] in subgraph.inputs:
pos, float_tensor, int_tensor = \
"input", tensors[op.inputs[0]], tensors[op.outputs[0]]
input_quant_ops.append(op)
# Find output dequantize operator
elif op.opcodeIndex == dequant_opcode_idx and \
op.outputs[0] in subgraph.outputs:
pos, float_tensor, int_tensor = \
"output", tensors[op.outputs[0]], tensors[op.inputs[0]]
output_dequant_ops.append(op)
# Otherwise, ignore
else:
continue
# If found, validate the input/output tensor type
if float_tensor.type != _schema_fb.TensorType.FLOAT32:
raise ValueError(
"Model {} type must be tf.float32. Expected type for tensor with "
"name '{}' is tf.float32, instead type is tf.{}".format(
pos, float_tensor.name,
_convert_tflite_enum_type_to_tf_type(float_tensor.type).name))
if int_tensor.type != _schema_fb.TensorType.INT8:
raise ValueError(
"Model is not integer quantized. Expected type for tensor with "
"name '{}' is tf.int8, instead type is tf.{}".format(
int_tensor.name,
_convert_tflite_enum_type_to_tf_type(int_tensor.type).name))
return input_quant_ops, output_dequant_ops
def modify_integer_quantized_model_io_type(
model, inference_input_type=_lite_constants.FLOAT,
inference_output_type=_lite_constants.FLOAT):
"""Modify the float input/output type of an integer quantized model.
Args:
model: An int8 quantized tflite model with float input and output.
inference_input_type: tf.DType representing final input type.
(default tf.float32)
inference_output_type: tf.DType representing final output type.
(default tf.float32)
Returns:
An int8 quantized tflite model with modified input and/or output type.
Raises:
ValueError: If the model is not int8 quantized or the inference_input_type
and/or inference_output_type is unsupported.
RuntimeError: If the modification was unsuccessful.
"""
# Return if input and output types default to float
if inference_input_type == _lite_constants.FLOAT and \
inference_output_type == _lite_constants.FLOAT:
return model
# Validate input and output types
if inference_input_type not in _TFLITE_MODEL_INPUT_OUTPUT_TYPES:
raise ValueError("The `inference_input_type` should be in {}".format(
tuple(_get_dtype_name(t) for t in _TFLITE_MODEL_INPUT_OUTPUT_TYPES)))
if inference_output_type not in _TFLITE_MODEL_INPUT_OUTPUT_TYPES:
raise ValueError("The `inference_output_type` should be in {}".format(
tuple(_get_dtype_name(t) for t in _TFLITE_MODEL_INPUT_OUTPUT_TYPES)))
logging.debug(("Attempting to modify the model input from tf.float32 to %s "
"and output from tf.float32 to %s"),
_get_dtype_name(inference_input_type),
_get_dtype_name(inference_output_type))
# Convert the model to an object
model = _convert_model_from_bytearray_to_object(model)
# Validate the integer quantized model
input_quant_ops, output_dequant_ops = \
_validate_and_find_int8_quantized_inputs_outputs(model)
# Initialize references and variables
if len(model.subgraphs) > 1:
raise ValueError("Model must only have one subgraph. Instead, it has "
"{} subgraphs.".format(len(model.subgraphs)))
subgraph = model.subgraphs[0]
tensors = subgraph.tensors
operators = subgraph.operators
remove_tensors_idxs = set()
# Modify model input type
if inference_input_type == _lite_constants.QUANTIZED_UINT8:
# Change quant op (float to int8) to quant op (uint8 to int8)
for op in input_quant_ops:
int8_quantization = tensors[op.outputs[0]].quantization
uint8_quantization = _schema_fb.QuantizationParametersT()
uint8_quantization.scale = [int8_quantization.scale[0]]
uint8_quantization.zeroPoint = [int8_quantization.zeroPoint[0] + 128]
tensors[op.inputs[0]].quantization = uint8_quantization
tensors[op.inputs[0]].type = _schema_fb.TensorType.UINT8
elif inference_input_type == _lite_constants.INT8:
# Remove the inputs and the quant operator
for op in input_quant_ops:
subgraph.inputs[subgraph.inputs == op.inputs[0]] = op.outputs[0]
remove_tensors_idxs.add(op.inputs[0])
operators.remove(op)
# Modify model output type
if inference_output_type == _lite_constants.QUANTIZED_UINT8:
# Change dequant op (int8 to float) to quant op (int8 to uint8)
for op in output_dequant_ops:
op.opcodeIndex = input_quant_ops[0].opcodeIndex
int8_quantization = tensors[op.inputs[0]].quantization
uint8_quantization = _schema_fb.QuantizationParametersT()
uint8_quantization.scale = [int8_quantization.scale[0]]
uint8_quantization.zeroPoint = [int8_quantization.zeroPoint[0] + 128]
tensors[op.outputs[0]].quantization = uint8_quantization
tensors[op.outputs[0]].type = _schema_fb.TensorType.UINT8
elif inference_output_type == _lite_constants.INT8:
# Remove the outputs and the dequant operator
for op in output_dequant_ops:
subgraph.outputs[subgraph.outputs == op.outputs[0]] = op.inputs[0]
remove_tensors_idxs.add(op.outputs[0])
operators.remove(op)
# Remove tensors marked for deletion.
_remove_tensors_from_model(model, remove_tensors_idxs)
# Convert the model to a bytearray
model = _convert_model_from_object_to_bytearray(model)
return model
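As a usage sketch (not part of the diff): rewriting the float interface of an int8 fully quantized model to uint8. Note that the code above keeps the original scale and only shifts the zero point by 128 when converting an int8 interface tensor to uint8. Here `tflite_model` is assumed to hold the serialized bytes of such a model:

from tensorflow.lite.python import lite_constants
from tensorflow.lite.python import util

uint8_io_model = util.modify_integer_quantized_model_io_type(
    tflite_model,
    inference_input_type=lite_constants.QUANTIZED_UINT8,
    inference_output_type=lite_constants.QUANTIZED_UINT8)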

tensorflow/lite/python/util_test.py

@@ -19,7 +19,10 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from absl.testing import parameterized
import numpy as np
from six.moves import range
import tensorflow as tf

from tensorflow.lite.python import lite_constants
from tensorflow.lite.python import util
@@ -61,6 +64,31 @@ class UtilTest(test_util.TensorFlowTestCase):
    self.assertEqual(
        util.convert_dtype_to_tflite_type(dtypes.bool), _types_pb2.BOOL)

def testConvertEnumToDtype(self):
self.assertEqual(
util._convert_tflite_enum_type_to_tf_type(0), dtypes.float32)
self.assertEqual(
util._convert_tflite_enum_type_to_tf_type(1), dtypes.float16)
self.assertEqual(util._convert_tflite_enum_type_to_tf_type(2), dtypes.int32)
self.assertEqual(util._convert_tflite_enum_type_to_tf_type(3), dtypes.uint8)
self.assertEqual(util._convert_tflite_enum_type_to_tf_type(4), dtypes.int64)
self.assertEqual(
util._convert_tflite_enum_type_to_tf_type(5), dtypes.string)
self.assertEqual(util._convert_tflite_enum_type_to_tf_type(6), dtypes.bool)
self.assertEqual(util._convert_tflite_enum_type_to_tf_type(7), dtypes.int16)
self.assertEqual(
util._convert_tflite_enum_type_to_tf_type(8), dtypes.complex64)
self.assertEqual(util._convert_tflite_enum_type_to_tf_type(9), dtypes.int8)
self.assertEqual(
util._convert_tflite_enum_type_to_tf_type(10), dtypes.float64)
with self.assertRaises(ValueError) as error:
util._convert_tflite_enum_type_to_tf_type(11)
self.assertEqual(
"Unsupported enum 11. The valid map of enum to tf.dtypes is : "
"{0: tf.float32, 1: tf.float16, 2: tf.int32, 3: tf.uint8, 4: tf.int64, "
"5: tf.string, 6: tf.bool, 7: tf.int16, 8: tf.complex64, 9: tf.int8, "
"10: tf.float64}", str(error.exception))

  def testTensorName(self):
    with ops.Graph().as_default():
      in_tensor = array_ops.placeholder(shape=[4], dtype=dtypes.float32)
@@ -195,5 +223,140 @@ class TensorFunctionsTest(test_util.TensorFlowTestCase):
    self.assertEqual([None, 3, 5], tensor.shape.as_list())

def _generate_integer_tflite_model():
"""Define an integer post-training quantized tflite model."""
# Load MNIST dataset
n = 10 # Number of samples
(train_images, train_labels), (test_images, test_labels) = \
tf.keras.datasets.mnist.load_data()
train_images, train_labels, test_images, test_labels = \
train_images[:n], train_labels[:n], test_images[:n], test_labels[:n]
# Normalize the input images so that each pixel value is between 0 and 1.
train_images = train_images / 255.0
test_images = test_images / 255.0
# Define TF model
model = tf.keras.Sequential([
tf.keras.layers.InputLayer(input_shape=(28, 28)),
tf.keras.layers.Reshape(target_shape=(28, 28, 1)),
tf.keras.layers.Conv2D(filters=12, kernel_size=(3, 3), activation="relu"),
tf.keras.layers.MaxPooling2D(pool_size=(2, 2)),
tf.keras.layers.Flatten(),
tf.keras.layers.Dense(10)
])
# Train
model.compile(
optimizer="adam",
loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
metrics=["accuracy"])
model.fit(
train_images,
train_labels,
epochs=1,
validation_split=0.1,
)
# Convert TF Model to an Integer Quantized TFLite Model
converter = tf.lite.TFLiteConverter.from_keras_model(model)
converter.optimizations = {tf.lite.Optimize.DEFAULT}
def representative_dataset_gen():
for _ in range(2):
yield [
np.random.uniform(low=0, high=1, size=(1, 28, 28)).astype(
np.float32)
]
converter.representative_dataset = representative_dataset_gen
converter.target_spec.supported_ops = {tf.lite.OpsSet.TFLITE_BUILTINS_INT8}
tflite_model = converter.convert()
return tflite_model
def _test_param_modify_integer_model_io_type():
"""Function to generate parameterized inputs for testing."""
params = []
str_template = "_{}{}{}"
map_model_type = {
"PostTraining": True,
# "DuringTraining": False,
}
map_types = {
"": lite_constants.FLOAT,
"INT8": lite_constants.INT8,
"UINT8": lite_constants.QUANTIZED_UINT8
}
for k1, v1 in map_model_type.items():
for k2, v2 in map_types.items():
istr = "_Input{}".format(k2) if k2 else ""
for k3, v3 in map_types.items():
ostr = "_Output{}".format(k3) if k3 else "" if istr else "_NoUpdate"
params.append((str_template.format(k1, istr, ostr), v1, v2, v3))
return params
# TODO(b/161174063): Merge tests for integer input/output type
class UtilModifyIntegerQuantizedModelIOTypeTest(
test_util.TensorFlowTestCase, parameterized.TestCase):
@classmethod
def setUpClass(cls):
super(UtilModifyIntegerQuantizedModelIOTypeTest, cls).setUpClass()
cls.post_train_integer_model = _generate_integer_tflite_model()
@parameterized.named_parameters(_test_param_modify_integer_model_io_type())
def test(self, is_post_train, in_tftype, out_tftype):
"""Modify the float input/output type of an integer quantized model."""
def _run_tflite_inference(model, in_tftype, out_tftype):
"""Run inference on a model with a specific input/output type."""
# Load TFLite model and allocate tensors.
interpreter = tf.lite.Interpreter(model_content=model)
interpreter.allocate_tensors()
input_details = interpreter.get_input_details()[0]
output_details = interpreter.get_output_details()[0]
# Validate TFLite model input and output types
self.assertEqual(input_details["dtype"], in_tftype.as_numpy_dtype)
self.assertEqual(output_details["dtype"], out_tftype.as_numpy_dtype)
# Define Input
np.random.seed(0)
input_data = np.random.uniform(low=0, high=1, size=(1, 28, 28))
input_data = input_data.astype(np.float32)
if input_details["dtype"] != np.float32:
# quantize float to int
scale, zero_point = input_details["quantization"]
input_data = input_data / scale + zero_point
input_data = input_data.astype(input_details["dtype"])
# Run Inference
interpreter.set_tensor(input_details["index"], input_data)
interpreter.invoke()
# Get output
output_data = interpreter.get_tensor(output_details["index"])[0]
if output_details["dtype"] != np.float32:
# dequantize int to float
scale, zero_point = output_details["quantization"]
output_data = output_data.astype(np.float32)
output_data = (output_data - zero_point) * scale
return output_data
model = self.__class__.post_train_integer_model if is_post_train else None
# Run model inference with float input output type
output_data = _run_tflite_inference(model, tf.float32, tf.float32)
# Run model inference with modified integer input output type
model_io = util.modify_integer_quantized_model_io_type(
model, in_tftype, out_tftype)
output_io_data = _run_tflite_inference(model_io, in_tftype, out_tftype)
# Validate that both the outputs are the same
self.assertTrue(np.allclose(output_data, output_io_data, atol=1.0))

if __name__ == "__main__":
  test.main()
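As a quick illustration of the naming scheme produced by _test_param_modify_integer_model_io_type() above (a sketch, assuming it runs in the same module):

# Prints entries such as ('_PostTraining_NoUpdate', True, tf.float32, tf.float32)
# and ('_PostTraining_InputINT8_OutputUINT8', True, tf.int8, tf.uint8).
for test_name, use_post_training, in_type, out_type in (
    _test_param_modify_integer_model_io_type()):
  print(test_name, use_post_training, in_type, out_type)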