From 1fc9c95a100af0689134f8834c04603b7b482dd5 Mon Sep 17 00:00:00 2001 From: Tamas Nyiri Date: Tue, 14 Jul 2020 12:44:56 +0100 Subject: [PATCH] added INT16 support in corresponding python wrappers --- .../optimize/python/modify_model_interface.cc | 2 +- .../modify_model_interface_constants.py | 1 + .../python/modify_model_interface_lib_test.py | 43 ++++++++++++++++++- 3 files changed, 43 insertions(+), 3 deletions(-) diff --git a/tensorflow/lite/tools/optimize/python/modify_model_interface.cc b/tensorflow/lite/tools/optimize/python/modify_model_interface.cc index ed67b07cb0f..cd2e4a192e9 100644 --- a/tensorflow/lite/tools/optimize/python/modify_model_interface.cc +++ b/tensorflow/lite/tools/optimize/python/modify_model_interface.cc @@ -33,7 +33,7 @@ PYBIND11_MODULE(_pywrap_modify_model_interface, m) { return tflite::optimize::ModifyModelInterface( input_file, output_file, static_cast(input_type), - static_cast(input_type)); + static_cast(output_type)); }); } diff --git a/tensorflow/lite/tools/optimize/python/modify_model_interface_constants.py b/tensorflow/lite/tools/optimize/python/modify_model_interface_constants.py index 42767268e48..cbe1aa92022 100644 --- a/tensorflow/lite/tools/optimize/python/modify_model_interface_constants.py +++ b/tensorflow/lite/tools/optimize/python/modify_model_interface_constants.py @@ -23,6 +23,7 @@ from tensorflow.lite.python import lite_constants STR_TO_TFLITE_TYPES = { 'INT8': lite_constants.INT8, + 'INT16': lite_constants.INT16, 'UINT8': lite_constants.QUANTIZED_UINT8 } TFLITE_TO_STR_TYPES = {v: k for k, v in STR_TO_TFLITE_TYPES.items()} diff --git a/tensorflow/lite/tools/optimize/python/modify_model_interface_lib_test.py b/tensorflow/lite/tools/optimize/python/modify_model_interface_lib_test.py index e97f0db9bbb..70ae0ad4376 100644 --- a/tensorflow/lite/tools/optimize/python/modify_model_interface_lib_test.py +++ b/tensorflow/lite/tools/optimize/python/modify_model_interface_lib_test.py @@ -28,7 +28,9 @@ from tensorflow.python.framework import test_util from tensorflow.python.platform import test -def build_tflite_model_with_full_integer_quantization(): +def build_tflite_model_with_full_integer_quantization(supported_ops= + tf.lite.OpsSet. + TFLITE_BUILTINS_INT8): # Define TF model input_size = 3 model = tf.keras.Sequential([ @@ -46,7 +48,7 @@ def build_tflite_model_with_full_integer_quantization(): yield [np.array([i] * input_size, dtype=np.float32)] converter.representative_dataset = representative_dataset_gen - converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8] + converter.target_spec.supported_ops = [supported_ops] tflite_model = converter.convert() return tflite_model @@ -89,6 +91,43 @@ class ModifyModelInterfaceTest(test_util.TensorFlowTestCase): self.assertEqual(final_input_dtype, np.int8) self.assertEqual(final_output_dtype, np.int8) + def testInt16Interface(self): + # 1. SETUP + # Define the temporary directory and files + temp_dir = self.get_temp_dir() + initial_file = os.path.join(temp_dir, 'initial_model.tflite') + final_file = os.path.join(temp_dir, 'final_model.tflite') + # Define initial model + initial_model = build_tflite_model_with_full_integer_quantization( + supported_ops=tf.lite.OpsSet. + EXPERIMENTAL_TFLITE_BUILTINS_ACTIVATIONS_INT16_WEIGHTS_INT8) + with open(initial_file, 'wb') as model_file: + model_file.write(initial_model) + + # 2. INVOKE + # Invoke the modify_model_interface function + modify_model_interface_lib.modify_model_interface(initial_file, final_file, + tf.int16, tf.int16) + + # 3. VALIDATE + # Load TFLite model and allocate tensors. + initial_interpreter = tf.lite.Interpreter(model_path=initial_file) + initial_interpreter.allocate_tensors() + final_interpreter = tf.lite.Interpreter(model_path=final_file) + final_interpreter.allocate_tensors() + + # Get input and output types. + initial_input_dtype = initial_interpreter.get_input_details()[0]['dtype'] + initial_output_dtype = initial_interpreter.get_output_details()[0]['dtype'] + final_input_dtype = final_interpreter.get_input_details()[0]['dtype'] + final_output_dtype = final_interpreter.get_output_details()[0]['dtype'] + + # Validate the model interfaces + self.assertEqual(initial_input_dtype, np.float32) + self.assertEqual(initial_output_dtype, np.float32) + self.assertEqual(final_input_dtype, np.int16) + self.assertEqual(final_output_dtype, np.int16) + def testUInt8Interface(self): # 1. SETUP # Define the temporary directory and files