diff --git a/RELEASE.md b/RELEASE.md index 12b5168954b..7895a0ba113 100644 --- a/RELEASE.md +++ b/RELEASE.md @@ -56,6 +56,8 @@ * `tf.lite`: * Better support for ops with high-dimensional broadcasting inputs by adding `BroadcastTo` ops when necessary. + * `TFLiteConverter`: + * Support optional flags `inference_input_type` and `inference_output_type` for full integer quantized models. This allows users to modify the model input and output type to integer types (tf.int8, tf.uint8) instead of defaulting to float type (tf.float32). * `tf.random`: * <ADD RELEASE NOTES HERE> * Math and Linear Algebra: @@ -68,7 +70,7 @@ * <ADD RELEASE NOTES HERE> * Other: * We have replaced uses of "whitelist" and "blacklist" with "allowlist" - and "denylist" where possible. Please see + and "denylist" where possible. Please see https://developers.google.com/style/word-list#blacklist for more context. * <ADD RELEASE NOTES HERE> diff --git a/tensorflow/lite/python/BUILD b/tensorflow/lite/python/BUILD index e26000c810a..55a2a69675d 100644 --- a/tensorflow/lite/python/BUILD +++ b/tensorflow/lite/python/BUILD @@ -212,8 +212,11 @@ py_library( deps = [ ":lite_constants", ":op_hint", + ":schema_py", "//tensorflow/python:tf_optimizer", "//tensorflow/python/eager:wrap_function", + "@absl_py//absl/logging", + "@flatbuffers//:runtime_py", "@six_archive//:six", ], ) @@ -224,12 +227,24 @@ py_test( python_version = "PY3", srcs_version = "PY2AND3", tags = [ + "no_mac", "no_windows", ], deps = [ + ":lite_constants", ":util", + "//tensorflow:tensorflow_py", + "//tensorflow/python:array_ops", "//tensorflow/python:client_testlib", + "//tensorflow/python:control_flow_ops", + "//tensorflow/python:convert_to_constants", + "//tensorflow/python:dtypes", + "//tensorflow/python:framework_ops", "//tensorflow/python:framework_test_lib", + "//tensorflow/python:math_ops", + "//tensorflow/python:session", + "//third_party/py/numpy", + "@absl_py//absl/testing:parameterized", "@six_archive//:six", ], ) diff --git a/tensorflow/lite/python/lite.py b/tensorflow/lite/python/lite.py index e919aa4b00f..a08b40bbed6 100644 --- a/tensorflow/lite/python/lite.py +++ b/tensorflow/lite/python/lite.py @@ -61,6 +61,7 @@ from tensorflow.lite.python.util import get_grappler_config as _get_grappler_con from tensorflow.lite.python.util import get_tensor_name as _get_tensor_name from tensorflow.lite.python.util import get_tensors_from_tensor_names as _get_tensors_from_tensor_names from tensorflow.lite.python.util import is_frozen_graph as _is_frozen_graph +from tensorflow.lite.python.util import modify_integer_quantized_model_io_type as _modify_integer_quantized_model_io_type from tensorflow.lite.python.util import run_graph_optimizations as _run_graph_optimizations from tensorflow.lite.python.util import set_tensor_shapes as _set_tensor_shapes from tensorflow.python import keras as _keras @@ -314,6 +315,23 @@ class QuantizationMode(object): else: return False, None + def flags_modify_model_io_type( + self, input_type=constants.FLOAT, output_type=constants.FLOAT): + """Flags for modifying the input and output type of a tflite model.""" + is_post_training_quantize = self.quantizer_flags(input_type, output_type)[0] + is_training_time_only_quantize = self.training_time_int8_allow_float() and \ + not is_post_training_quantize + + # TODO(b/153576658): Consolidate post/during training quantization workflows + # to modify model input/output type after MLIR conversion. 
+ if is_training_time_only_quantize: + return { + "inference_input_type": input_type, + "inference_output_type": output_type, + } + else: + return None + # Below are helpers for the above functions. def _validate_int8_required(self): @@ -557,9 +575,8 @@ class TFLiteConverterBaseV2(TFLiteConverterBase): def _validate_inference_input_output_types(self, quant_mode): """Validate inference_input_type and inference_output_type flags.""" default_types = [constants.FLOAT, None] - # We only support integer types for post training integer quantization - # as we have statistical information to quantize the input and output. - if quant_mode.is_post_training_integer_quantize(): + # We support integer input/output for integer quantized models only. + if quant_mode.training_time_int8_allow_float(): all_types = default_types + [constants.INT8, constants.QUANTIZED_UINT8] if self.inference_input_type not in all_types or \ self.inference_output_type not in all_types: @@ -643,6 +660,12 @@ class TFLiteConverterBaseV2(TFLiteConverterBase): if calibrate_and_quantize: result = self._calibrate_quantize_model(result, **flags) + flags_modify_model_io_type = quant_mode.flags_modify_model_io_type( + self.inference_input_type, self.inference_output_type) + if flags_modify_model_io_type: + result = _modify_integer_quantized_model_io_type( + result, **flags_modify_model_io_type) + if self._experimental_sparsify_model: result = _mlir_sparsify(result) diff --git a/tensorflow/lite/python/lite_v2_test.py b/tensorflow/lite/python/lite_v2_test.py index 6fab4fd6086..4093a9d5bb4 100644 --- a/tensorflow/lite/python/lite_v2_test.py +++ b/tensorflow/lite/python/lite_v2_test.py @@ -374,8 +374,12 @@ class FromConcreteFunctionTest(lite_v2_test_util.ModelTest): return tf.keras.Sequential(QLinear(3, input_shape=(2,))) + @parameterized.named_parameters( + ('_DefaultFLOAT32InputOutput', lite.constants.FLOAT), + ('_INT8InputOutput', lite.constants.INT8), + ('_UINT8InputOutput', lite.constants.QUANTIZED_UINT8)) @test_util.run_v2_only - def testTrainingTimeQuantization(self): + def testTrainingTimeQuantization(self, inference_input_output_type): model = self._getTrainingTimeQuantizedModel() float_converter = lite.TFLiteConverterV2.from_keras_model(model) @@ -384,37 +388,24 @@ class FromConcreteFunctionTest(lite_v2_test_util.ModelTest): quantized_converter = lite.TFLiteConverterV2.from_keras_model(model) quantized_converter.optimizations = [lite.Optimize.DEFAULT] + quantized_converter.inference_input_type = inference_input_output_type + quantized_converter.inference_output_type = inference_input_output_type quantized_tflite = quantized_converter.convert() self.assertTrue(quantized_tflite) - # Ensure that the quantized weights tflite model is smaller. 
- self.assertLess(len(quantized_tflite), len(float_tflite)) - interpreter = Interpreter(model_content=quantized_tflite) - self.assertEqual(np.float32, interpreter.get_input_details()[0]['dtype']) + interpreter.allocate_tensors() + input_details = interpreter.get_input_details() + self.assertLen(input_details, 1) + self.assertEqual(inference_input_output_type.as_numpy_dtype, + input_details[0]['dtype']) + output_details = interpreter.get_output_details() + self.assertLen(output_details, 1) + self.assertEqual(inference_input_output_type.as_numpy_dtype, + output_details[0]['dtype']) - @parameterized.named_parameters( - ('_INT8InputOutput', lite.constants.INT8), - ('_UINT8InputOutput', lite.constants.QUANTIZED_UINT8)) - def testInvalidTrainingTimeQuantization(self, inference_input_output_type): - # We currently don't support integer inference_input_type and - # inference_output_type flags for training time quantization. - - model = self._getTrainingTimeQuantizedModel() - - converter = lite.TFLiteConverterV2.from_keras_model(model) - tflite_model = converter.convert() - self.assertTrue(tflite_model) - - quantized_converter = lite.TFLiteConverterV2.from_keras_model(model) - quantized_converter.optimizations = [lite.Optimize.DEFAULT] - with self.assertRaises(ValueError) as error: - quantized_converter.inference_input_type = inference_input_output_type - quantized_converter.inference_output_type = inference_input_output_type - quantized_converter.convert() - self.assertEqual( - 'The inference_input_type and inference_output_type ' - 'must be tf.float32.', str(error.exception)) + # Ensure that the quantized tflite model is smaller. + self.assertLess(len(quantized_tflite), len(float_tflite)) @test_util.run_v2_only def testNewQuantizer(self): diff --git a/tensorflow/lite/python/util.py b/tensorflow/lite/python/util.py index ff7caad0f88..9f84681c12b 100644 --- a/tensorflow/lite/python/util.py +++ b/tensorflow/lite/python/util.py @@ -19,15 +19,21 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import copy import datetime import sys +from absl import logging + import six from six.moves import range +from flatbuffers.python import flatbuffers from tensorflow.core.protobuf import config_pb2 as _config_pb2 from tensorflow.core.protobuf import graph_debug_info_pb2 from tensorflow.core.protobuf import meta_graph_pb2 as _meta_graph_pb2 +from tensorflow.lite.python import lite_constants as _lite_constants +from tensorflow.lite.python import schema_py_generated as _schema_fb from tensorflow.lite.python.op_hint import convert_op_hints_to_stubs from tensorflow.lite.python.op_hint import find_all_hinted_output_nodes from tensorflow.lite.toco import types_pb2 as _types_pb2 @@ -55,6 +61,25 @@ _MAP_TF_TO_TFLITE_TYPES = { dtypes.bool: _types_pb2.BOOL, } +_MAP_TFLITE_ENUM_TO_TF_TYPES = { + 0: dtypes.float32, + 1: dtypes.float16, + 2: dtypes.int32, + 3: dtypes.uint8, + 4: dtypes.int64, + 5: dtypes.string, + 6: dtypes.bool, + 7: dtypes.int16, + 8: dtypes.complex64, + 9: dtypes.int8, + 10: dtypes.float64, +} + +_TFLITE_FILE_IDENTIFIER = b"TFL3" + +_TFLITE_MODEL_INPUT_OUTPUT_TYPES = (_lite_constants.FLOAT, _lite_constants.INT8, + _lite_constants.QUANTIZED_UINT8) + def convert_dtype_to_tflite_type(tf_dtype): """Converts tf.dtype to TFLite proto type. @@ -74,6 +99,31 @@ def convert_dtype_to_tflite_type(tf_dtype): return result +def _convert_tflite_enum_type_to_tf_type(tflite_enum_type): + """Converts tflite enum type (eg: 0) to tf type (eg: tf.float32). 
+ + Args: + tflite_enum_type: tflite enum type (eg: 0, that corresponds to float32) + + Raises: + ValueError: If an invalid tflite enum type is provided. + + Returns: + tf type (eg: tf.float32) + """ + tf_type = _MAP_TFLITE_ENUM_TO_TF_TYPES.get(tflite_enum_type) + if tf_type is None: + raise ValueError( + "Unsupported enum {}. The valid map of enum to tf.dtypes is : {}" + .format(tflite_enum_type, _MAP_TFLITE_ENUM_TO_TF_TYPES)) + return tf_type + + +def _get_dtype_name(tf_type): + """Converts tf.dtype (eg: tf.float32) to str (eg: "tf.float32").""" + return "tf." + tf_type.name + + def get_tensor_name(tensor): """Returns name of the input tensor. @@ -514,3 +564,218 @@ extern const int {array_name}_len; license_text=license_text) return source_text, header_text + + +def _convert_model_from_bytearray_to_object(model_bytearray): + """Converts a tflite model from a bytearray into a parsable object.""" + model_object = _schema_fb.Model.GetRootAsModel(model_bytearray, 0) + model_object = _schema_fb.ModelT.InitFromObj(model_object) + model_object = copy.deepcopy(model_object) + model_object.subgraphs[0].inputs[0] = model_object.subgraphs[0].inputs[0] + return model_object + + +def _convert_model_from_object_to_bytearray(model_object): + """Converts a tflite model from a parsable object into a bytearray.""" + # Initial size of the buffer, which will grow automatically if needed + builder = flatbuffers.Builder(1024) + model_offset = model_object.Pack(builder) + builder.Finish(model_offset, file_identifier=_TFLITE_FILE_IDENTIFIER) + return bytes(builder.Output()) + + +def _remove_tensors_from_model(model, remove_tensors_idxs): + """Remove tensors from model.""" + if not remove_tensors_idxs: + return + if len(model.subgraphs) > 1: + raise ValueError("Model must only have one subgraph. Instead, it has " + "{} subgraphs.".format(len(model.subgraphs))) + subgraph = model.subgraphs[0] + tensors = subgraph.tensors + operators = subgraph.operators + + logging.debug("Removing tensors at indices : %s", remove_tensors_idxs) + # An optimized check to validate if "remove_tensors_idxs" (eg: [4,5,6]) is an + # exact subset, with ordering, of "tensors" indices (eg: [0,1,2,3,4,5,6]). + if min(remove_tensors_idxs) == len(tensors) - len(remove_tensors_idxs): + logging.debug("Removing tensors only at the end of the tensor list") + del tensors[min(remove_tensors_idxs):] + else: + logging.debug("Removing tensors requires updating the model") + # Map the old tensor indices to new tensor indices + d_old_to_new_tensors = {} + left_shift_by = 0 + for idx in range(len(tensors)): + if idx in remove_tensors_idxs: + left_shift_by += 1 + else: + d_old_to_new_tensors[idx] = idx - left_shift_by + logging.debug("Old to new tensors map: %s", d_old_to_new_tensors.__str__()) + # Update tensor indices referenced throughout the model + def update_tensors(tensor_idxs): + for i, ti in enumerate(tensor_idxs): + tensor_idxs[i] = d_old_to_new_tensors.get(ti, -1) + update_tensors(subgraph.inputs) + update_tensors(subgraph.outputs) + for op in operators: + update_tensors(op.inputs) + update_tensors(op.outputs) + # Delete the tensors + for idx in sorted(remove_tensors_idxs, reverse=True): + tensors.pop(idx) + logging.debug("Removed tensors marked for deletion") + + +def _validate_and_find_int8_quantized_inputs_outputs(model): + """Validate that model input is quantized and output is dequantized.""" + if len(model.subgraphs) > 1: + raise ValueError("Model must only have one subgraph. 
Instead, it has " + "{} subgraphs.".format(len(model.subgraphs))) + subgraph = model.subgraphs[0] + tensors = subgraph.tensors + operators = subgraph.operators + + # Ensure model has atleast one quantize and dequantize operator + quant_opcode_idx, dequant_opcode_idx = None, None + for idx, opcode in enumerate(model.operatorCodes): + if opcode.builtinCode == _schema_fb.BuiltinOperator.QUANTIZE: + quant_opcode_idx = idx + elif opcode.builtinCode == _schema_fb.BuiltinOperator.DEQUANTIZE: + dequant_opcode_idx = idx + if quant_opcode_idx is not None and dequant_opcode_idx is not None: + break + if quant_opcode_idx is None and dequant_opcode_idx is None: + raise ValueError("Model is not integer quantized as it does not " + "contain quantize/dequantize operators.") + + # Ensure model inputs and outputs are integer quantized + input_quant_ops, output_dequant_ops = [], [] + for op in operators: + # Find input quantize operator + if op.opcodeIndex == quant_opcode_idx and op.inputs[0] in subgraph.inputs: + pos, float_tensor, int_tensor = \ + "input", tensors[op.inputs[0]], tensors[op.outputs[0]] + input_quant_ops.append(op) + # Find output dequantize operator + elif op.opcodeIndex == dequant_opcode_idx and \ + op.outputs[0] in subgraph.outputs: + pos, float_tensor, int_tensor = \ + "output", tensors[op.outputs[0]], tensors[op.inputs[0]] + output_dequant_ops.append(op) + # Otherwise, ignore + else: + continue + # If found, validate the input/output tensor type + if float_tensor.type != _schema_fb.TensorType.FLOAT32: + raise ValueError( + "Model {} type must be tf.float32. Expected type for tensor with " + "name '{}' is tf.float32, instead type is tf.{}".format( + pos, float_tensor.name, + _convert_tflite_enum_type_to_tf_type(float_tensor.type).name)) + if int_tensor.type != _schema_fb.TensorType.INT8: + raise ValueError( + "Model is not integer quantized. Expected type for tensor with " + "name '{}' is tf.int8, instead type is tf.{}".format( + int_tensor.name, + _convert_tflite_enum_type_to_tf_type(int_tensor.type).name)) + + return input_quant_ops, output_dequant_ops + + +def modify_integer_quantized_model_io_type( + model, inference_input_type=_lite_constants.FLOAT, + inference_output_type=_lite_constants.FLOAT): + """Modify the float input/output type of an integer quantized model. + + Args: + model: An int8 quantized tflite model with float input and output. + inference_input_type: tf.DType representing final input type. + (default tf.float32) + inference_output_type: tf.DType representing final output type. + (default tf.float32) + + Returns: + An int8 quantized tflite model with modified input and/or output type. + + Raises: + ValueError: If the model is not int8 quantized or the inference_input_type + and/or inference_input_type is unsupported. + RuntimeError: If the modification was unsuccessful. 
+ + """ + # Return if input and output types default to float + if inference_input_type == _lite_constants.FLOAT and \ + inference_output_type == _lite_constants.FLOAT: + return model + + # Validate input and output types + if inference_input_type not in _TFLITE_MODEL_INPUT_OUTPUT_TYPES: + raise ValueError("The `inference_input_type` should be in {}".format( + tuple(_get_dtype_name(t) for t in _TFLITE_MODEL_INPUT_OUTPUT_TYPES))) + if inference_output_type not in _TFLITE_MODEL_INPUT_OUTPUT_TYPES: + raise ValueError("The `inference_output_type` should be in {}".format( + tuple(_get_dtype_name(t) for t in _TFLITE_MODEL_INPUT_OUTPUT_TYPES))) + + logging.debug(("Attempting to modify the model input from tf.float32 to %s " + "and output from tf.float32 to %s"), + _get_dtype_name(inference_input_type), + _get_dtype_name(inference_output_type)) + # Convert the model to an object + model = _convert_model_from_bytearray_to_object(model) + + # Validate the integer quantized model + input_quant_ops, output_dequant_ops = \ + _validate_and_find_int8_quantized_inputs_outputs(model) + + # Initialize references and variables + if len(model.subgraphs) > 1: + raise ValueError("Model must only have one subgraph. Instead, it has " + "{} subgraphs.".format(len(model.subgraphs))) + subgraph = model.subgraphs[0] + tensors = subgraph.tensors + operators = subgraph.operators + remove_tensors_idxs = set() + + # Modify model input type + if inference_input_type == _lite_constants.QUANTIZED_UINT8: + # Change quant op (float to int8) to quant op (uint8 to int8) + for op in input_quant_ops: + int8_quantization = tensors[op.outputs[0]].quantization + uint8_quantization = _schema_fb.QuantizationParametersT() + uint8_quantization.scale = [int8_quantization.scale[0]] + uint8_quantization.zeroPoint = [int8_quantization.zeroPoint[0] + 128] + tensors[op.inputs[0]].quantization = uint8_quantization + tensors[op.inputs[0]].type = _schema_fb.TensorType.UINT8 + elif inference_input_type == _lite_constants.INT8: + # Remove the inputs and the quant operator + for op in input_quant_ops: + subgraph.inputs[subgraph.inputs == op.inputs[0]] = op.outputs[0] + remove_tensors_idxs.add(op.inputs[0]) + operators.remove(op) + + # Modify model output type + if inference_output_type == _lite_constants.QUANTIZED_UINT8: + # Change dequant op (int8 to float) to quant op (int8 to uint8) + for op in output_dequant_ops: + op.opcodeIndex = input_quant_ops[0].opcodeIndex + int8_quantization = tensors[op.inputs[0]].quantization + uint8_quantization = _schema_fb.QuantizationParametersT() + uint8_quantization.scale = [int8_quantization.scale[0]] + uint8_quantization.zeroPoint = [int8_quantization.zeroPoint[0] + 128] + tensors[op.outputs[0]].quantization = uint8_quantization + tensors[op.outputs[0]].type = _schema_fb.TensorType.UINT8 + elif inference_output_type == _lite_constants.INT8: + # Remove the outputs and the dequant operator + for op in output_dequant_ops: + subgraph.outputs[subgraph.outputs == op.outputs[0]] = op.inputs[0] + remove_tensors_idxs.add(op.outputs[0]) + operators.remove(op) + + # Remove tensors marked for deletion. 
+ _remove_tensors_from_model(model, remove_tensors_idxs) + + # Convert the model to a bytearray + model = _convert_model_from_object_to_bytearray(model) + + return model diff --git a/tensorflow/lite/python/util_test.py b/tensorflow/lite/python/util_test.py index f3c287dd7fc..0e9cbc1e58a 100644 --- a/tensorflow/lite/python/util_test.py +++ b/tensorflow/lite/python/util_test.py @@ -19,7 +19,10 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +from absl.testing import parameterized +import numpy as np from six.moves import range +import tensorflow as tf from tensorflow.lite.python import lite_constants from tensorflow.lite.python import util @@ -61,6 +64,31 @@ class UtilTest(test_util.TensorFlowTestCase): self.assertEqual( util.convert_dtype_to_tflite_type(dtypes.bool), _types_pb2.BOOL) + def testConvertEnumToDtype(self): + self.assertEqual( + util._convert_tflite_enum_type_to_tf_type(0), dtypes.float32) + self.assertEqual( + util._convert_tflite_enum_type_to_tf_type(1), dtypes.float16) + self.assertEqual(util._convert_tflite_enum_type_to_tf_type(2), dtypes.int32) + self.assertEqual(util._convert_tflite_enum_type_to_tf_type(3), dtypes.uint8) + self.assertEqual(util._convert_tflite_enum_type_to_tf_type(4), dtypes.int64) + self.assertEqual( + util._convert_tflite_enum_type_to_tf_type(5), dtypes.string) + self.assertEqual(util._convert_tflite_enum_type_to_tf_type(6), dtypes.bool) + self.assertEqual(util._convert_tflite_enum_type_to_tf_type(7), dtypes.int16) + self.assertEqual( + util._convert_tflite_enum_type_to_tf_type(8), dtypes.complex64) + self.assertEqual(util._convert_tflite_enum_type_to_tf_type(9), dtypes.int8) + self.assertEqual( + util._convert_tflite_enum_type_to_tf_type(10), dtypes.float64) + with self.assertRaises(ValueError) as error: + util._convert_tflite_enum_type_to_tf_type(11) + self.assertEqual( + "Unsupported enum 11. The valid map of enum to tf.dtypes is : " + "{0: tf.float32, 1: tf.float16, 2: tf.int32, 3: tf.uint8, 4: tf.int64, " + "5: tf.string, 6: tf.bool, 7: tf.int16, 8: tf.complex64, 9: tf.int8, " + "10: tf.float64}", str(error.exception)) + def testTensorName(self): with ops.Graph().as_default(): in_tensor = array_ops.placeholder(shape=[4], dtype=dtypes.float32) @@ -195,5 +223,140 @@ class TensorFunctionsTest(test_util.TensorFlowTestCase): self.assertEqual([None, 3, 5], tensor.shape.as_list()) +def _generate_integer_tflite_model(): + """Define an integer post-training quantized tflite model.""" + # Load MNIST dataset + n = 10 # Number of samples + (train_images, train_labels), (test_images, test_labels) = \ + tf.keras.datasets.mnist.load_data() + train_images, train_labels, test_images, test_labels = \ + train_images[:n], train_labels[:n], test_images[:n], test_labels[:n] + + # Normalize the input image so that each pixel value is between 0 to 1. 
+ train_images = train_images / 255.0 + test_images = test_images / 255.0 + + # Define TF model + model = tf.keras.Sequential([ + tf.keras.layers.InputLayer(input_shape=(28, 28)), + tf.keras.layers.Reshape(target_shape=(28, 28, 1)), + tf.keras.layers.Conv2D(filters=12, kernel_size=(3, 3), activation="relu"), + tf.keras.layers.MaxPooling2D(pool_size=(2, 2)), + tf.keras.layers.Flatten(), + tf.keras.layers.Dense(10) + ]) + + # Train + model.compile( + optimizer="adam", + loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True), + metrics=["accuracy"]) + + model.fit( + train_images, + train_labels, + epochs=1, + validation_split=0.1, + ) + + # Convert TF Model to an Integer Quantized TFLite Model + converter = tf.lite.TFLiteConverter.from_keras_model(model) + converter.optimizations = {tf.lite.Optimize.DEFAULT} + def representative_dataset_gen(): + for _ in range(2): + yield [ + np.random.uniform(low=0, high=1, size=(1, 28, 28)).astype( + np.float32) + ] + converter.representative_dataset = representative_dataset_gen + converter.target_spec.supported_ops = {tf.lite.OpsSet.TFLITE_BUILTINS_INT8} + tflite_model = converter.convert() + + return tflite_model + + +def _test_param_modify_integer_model_io_type(): + """Function to generate parameterized inputs for testing.""" + params = [] + str_template = "_{}{}{}" + map_model_type = { + "PostTraining": True, + # "DuringTraining": False, + } + map_types = { + "": lite_constants.FLOAT, + "INT8": lite_constants.INT8, + "UINT8": lite_constants.QUANTIZED_UINT8 + } + for k1, v1 in map_model_type.items(): + for k2, v2 in map_types.items(): + istr = "_Input{}".format(k2) if k2 else "" + for k3, v3 in map_types.items(): + ostr = "_Output{}".format(k3) if k3 else "" if istr else "_NoUpdate" + params.append((str_template.format(k1, istr, ostr), v1, v2, v3)) + return params + + +# TODO(b/161174063): Merge tests for integer input/output type +class UtilModifyIntegerQuantizedModelIOTypeTest( + test_util.TensorFlowTestCase, parameterized.TestCase): + + @classmethod + def setUpClass(cls): + super(UtilModifyIntegerQuantizedModelIOTypeTest, cls).setUpClass() + cls.post_train_integer_model = _generate_integer_tflite_model() + + @parameterized.named_parameters(_test_param_modify_integer_model_io_type()) + def test(self, is_post_train, in_tftype, out_tftype): + """Modify the float input/output type of an integer quantized model.""" + + def _run_tflite_inference(model, in_tftype, out_tftype): + """Run inference on a model with a specific input/output type.""" + # Load TFLite model and allocate tensors. 
+ interpreter = tf.lite.Interpreter(model_content=model) + interpreter.allocate_tensors() + input_details = interpreter.get_input_details()[0] + output_details = interpreter.get_output_details()[0] + + # Validate TFLite model input and output types + self.assertEqual(input_details["dtype"], in_tftype.as_numpy_dtype) + self.assertEqual(output_details["dtype"], out_tftype.as_numpy_dtype) + + # Define Input + np.random.seed(0) + input_data = np.random.uniform(low=0, high=1, size=(1, 28, 28)) + input_data = input_data.astype(np.float32) + if input_details["dtype"] != np.float32: + # quantize float to int + scale, zero_point = input_details["quantization"] + input_data = input_data / scale + zero_point + input_data = input_data.astype(input_details["dtype"]) + + # Run Inference + interpreter.set_tensor(input_details["index"], input_data) + interpreter.invoke() + + # Get output + output_data = interpreter.get_tensor(output_details["index"])[0] + if output_details["dtype"] != np.float32: + # dequantize int to float + scale, zero_point = output_details["quantization"] + output_data = output_data.astype(np.float32) + output_data = (output_data - zero_point) * scale + + return output_data + + model = self.__class__.post_train_integer_model if is_post_train else None + # Run model inference with float input output type + output_data = _run_tflite_inference(model, tf.float32, tf.float32) + # Run model inference with modified integer input output type + model_io = util.modify_integer_quantized_model_io_type( + model, in_tftype, out_tftype) + output_io_data = _run_tflite_inference(model_io, in_tftype, out_tftype) + + # Validate that both the outputs are the same + self.assertTrue(np.allclose(output_data, output_io_data, atol=1.0)) + + if __name__ == "__main__": test.main()
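
Example usage of the flags introduced in this change (a minimal sketch, not part of the patch itself): it assumes a trained Keras `model` and a `representative_dataset_gen` generator along the lines of `_generate_integer_tflite_model` in `util_test.py`, and the flag values mirror the constants exercised in `lite_v2_test.py`. The printed dtypes are illustrative of what `Interpreter.get_input_details()` reports once the model input/output types have been rewritten to integers.

```python
# Sketch: produce a full-integer quantized TFLite model whose runtime
# input/output tensors are tf.int8 instead of the default tf.float32.
# Assumes `model` is a trained tf.keras.Model and `representative_dataset_gen`
# yields representative float32 samples (hypothetical names for illustration).
import tensorflow as tf

converter = tf.lite.TFLiteConverter.from_keras_model(model)
converter.optimizations = [tf.lite.Optimize.DEFAULT]
converter.representative_dataset = representative_dataset_gen
converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
# New in this change: optional integer input/output types (tf.int8 or tf.uint8).
# Leaving these unset keeps the previous behavior, i.e. float32 input/output.
converter.inference_input_type = tf.int8
converter.inference_output_type = tf.int8
tflite_model = converter.convert()

# The resulting model now exposes int8 input and output tensors.
interpreter = tf.lite.Interpreter(model_content=tflite_model)
interpreter.allocate_tensors()
print(interpreter.get_input_details()[0]["dtype"])   # expected: <class 'numpy.int8'>
print(interpreter.get_output_details()[0]["dtype"])  # expected: <class 'numpy.int8'>
```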