added INT16 support in corresponding python wrappers

This commit is contained in:
Tamas Nyiri 2020-07-14 12:44:56 +01:00
parent 2f4444c1cf
commit 1fc9c95a10
3 changed files with 43 additions and 3 deletions

View File

@ -33,7 +33,7 @@ PYBIND11_MODULE(_pywrap_modify_model_interface, m) {
return tflite::optimize::ModifyModelInterface(
input_file, output_file,
static_cast<tflite::TensorType>(input_type),
static_cast<tflite::TensorType>(input_type));
static_cast<tflite::TensorType>(output_type));
});
}

View File

@ -23,6 +23,7 @@ from tensorflow.lite.python import lite_constants
STR_TO_TFLITE_TYPES = {
'INT8': lite_constants.INT8,
'INT16': lite_constants.INT16,
'UINT8': lite_constants.QUANTIZED_UINT8
}
TFLITE_TO_STR_TYPES = {v: k for k, v in STR_TO_TFLITE_TYPES.items()}

View File

@ -28,7 +28,9 @@ from tensorflow.python.framework import test_util
from tensorflow.python.platform import test
def build_tflite_model_with_full_integer_quantization():
def build_tflite_model_with_full_integer_quantization(supported_ops=
tf.lite.OpsSet.
TFLITE_BUILTINS_INT8):
# Define TF model
input_size = 3
model = tf.keras.Sequential([
@ -46,7 +48,7 @@ def build_tflite_model_with_full_integer_quantization():
yield [np.array([i] * input_size, dtype=np.float32)]
converter.representative_dataset = representative_dataset_gen
converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
converter.target_spec.supported_ops = [supported_ops]
tflite_model = converter.convert()
return tflite_model
@ -89,6 +91,43 @@ class ModifyModelInterfaceTest(test_util.TensorFlowTestCase):
self.assertEqual(final_input_dtype, np.int8)
self.assertEqual(final_output_dtype, np.int8)
def testInt16Interface(self):
# 1. SETUP
# Define the temporary directory and files
temp_dir = self.get_temp_dir()
initial_file = os.path.join(temp_dir, 'initial_model.tflite')
final_file = os.path.join(temp_dir, 'final_model.tflite')
# Define initial model
initial_model = build_tflite_model_with_full_integer_quantization(
supported_ops=tf.lite.OpsSet.
EXPERIMENTAL_TFLITE_BUILTINS_ACTIVATIONS_INT16_WEIGHTS_INT8)
with open(initial_file, 'wb') as model_file:
model_file.write(initial_model)
# 2. INVOKE
# Invoke the modify_model_interface function
modify_model_interface_lib.modify_model_interface(initial_file, final_file,
tf.int16, tf.int16)
# 3. VALIDATE
# Load TFLite model and allocate tensors.
initial_interpreter = tf.lite.Interpreter(model_path=initial_file)
initial_interpreter.allocate_tensors()
final_interpreter = tf.lite.Interpreter(model_path=final_file)
final_interpreter.allocate_tensors()
# Get input and output types.
initial_input_dtype = initial_interpreter.get_input_details()[0]['dtype']
initial_output_dtype = initial_interpreter.get_output_details()[0]['dtype']
final_input_dtype = final_interpreter.get_input_details()[0]['dtype']
final_output_dtype = final_interpreter.get_output_details()[0]['dtype']
# Validate the model interfaces
self.assertEqual(initial_input_dtype, np.float32)
self.assertEqual(initial_output_dtype, np.float32)
self.assertEqual(final_input_dtype, np.int16)
self.assertEqual(final_output_dtype, np.int16)
def testUInt8Interface(self):
# 1. SETUP
# Define the temporary directory and files