Added INT16 support in corresponding Python wrappers
parent 2f4444c1cf
commit 1fc9c95a10
@@ -33,7 +33,7 @@ PYBIND11_MODULE(_pywrap_modify_model_interface, m) {
         return tflite::optimize::ModifyModelInterface(
             input_file, output_file,
             static_cast<tflite::TensorType>(input_type),
-            static_cast<tflite::TensorType>(input_type));
+            static_cast<tflite::TensorType>(output_type));
       });
 }
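Note: this first hunk is a bug fix rather than new INT16 plumbing. The pybind11 wrapper previously cast input_type twice, so whatever output_type a caller passed was silently ignored and the input type was applied to both ends of the model interface. With the fix, the two arguments are forwarded independently to tflite::optimize::ModifyModelInterface.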
@@ -23,6 +23,7 @@ from tensorflow.lite.python import lite_constants
 
 STR_TO_TFLITE_TYPES = {
     'INT8': lite_constants.INT8,
+    'INT16': lite_constants.INT16,
     'UINT8': lite_constants.QUANTIZED_UINT8
 }
 TFLITE_TO_STR_TYPES = {v: k for k, v in STR_TO_TFLITE_TYPES.items()}
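The string-to-type table above drives argument parsing in the wrapper, and the reversed dict gives the round trip back to the string names. A minimal sketch of that round trip with the new entry, assuming lite_constants exposes INT16 as the hunk implies:

import tensorflow as tf
from tensorflow.lite.python import lite_constants

# Same tables as in the hunk above, with the new 'INT16' entry.
STR_TO_TFLITE_TYPES = {
    'INT8': lite_constants.INT8,
    'INT16': lite_constants.INT16,
    'UINT8': lite_constants.QUANTIZED_UINT8
}
TFLITE_TO_STR_TYPES = {v: k for k, v in STR_TO_TFLITE_TYPES.items()}

# Round trip: string name -> TFLite type constant -> string name.
assert TFLITE_TO_STR_TYPES[STR_TO_TFLITE_TYPES['INT16']] == 'INT16'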
@@ -28,7 +28,9 @@ from tensorflow.python.framework import test_util
 from tensorflow.python.platform import test
 
 
-def build_tflite_model_with_full_integer_quantization():
+def build_tflite_model_with_full_integer_quantization(supported_ops=
+                                                      tf.lite.OpsSet.
+                                                      TFLITE_BUILTINS_INT8):
   # Define TF model
   input_size = 3
   model = tf.keras.Sequential([
@@ -46,7 +48,7 @@ def build_tflite_model_with_full_integer_quantization():
     yield [np.array([i] * input_size, dtype=np.float32)]
 
   converter.representative_dataset = representative_dataset_gen
-  converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
+  converter.target_spec.supported_ops = [supported_ops]
   tflite_model = converter.convert()
 
   return tflite_model
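These two hunks thread a supported_ops parameter through the test helper so a single builder can produce either flavor of fully quantized model. A usage sketch, assuming the helper above is in scope (it lives in the test file, not a public module):

import tensorflow as tf

# Default behaves as before this commit: full int8 quantization.
int8_model = build_tflite_model_with_full_integer_quantization()

# New path: 16-bit activations with 8-bit weights ("16x8" quantization).
int16_model = build_tflite_model_with_full_integer_quantization(
    supported_ops=tf.lite.OpsSet.
    EXPERIMENTAL_TFLITE_BUILTINS_ACTIVATIONS_INT16_WEIGHTS_INT8)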
@@ -89,6 +91,43 @@ class ModifyModelInterfaceTest(test_util.TensorFlowTestCase):
     self.assertEqual(final_input_dtype, np.int8)
     self.assertEqual(final_output_dtype, np.int8)
 
+  def testInt16Interface(self):
+    # 1. SETUP
+    # Define the temporary directory and files
+    temp_dir = self.get_temp_dir()
+    initial_file = os.path.join(temp_dir, 'initial_model.tflite')
+    final_file = os.path.join(temp_dir, 'final_model.tflite')
+    # Define initial model
+    initial_model = build_tflite_model_with_full_integer_quantization(
+        supported_ops=tf.lite.OpsSet.
+        EXPERIMENTAL_TFLITE_BUILTINS_ACTIVATIONS_INT16_WEIGHTS_INT8)
+    with open(initial_file, 'wb') as model_file:
+      model_file.write(initial_model)
+
+    # 2. INVOKE
+    # Invoke the modify_model_interface function
+    modify_model_interface_lib.modify_model_interface(initial_file, final_file,
+                                                      tf.int16, tf.int16)
+
+    # 3. VALIDATE
+    # Load TFLite model and allocate tensors.
+    initial_interpreter = tf.lite.Interpreter(model_path=initial_file)
+    initial_interpreter.allocate_tensors()
+    final_interpreter = tf.lite.Interpreter(model_path=final_file)
+    final_interpreter.allocate_tensors()
+
+    # Get input and output types.
+    initial_input_dtype = initial_interpreter.get_input_details()[0]['dtype']
+    initial_output_dtype = initial_interpreter.get_output_details()[0]['dtype']
+    final_input_dtype = final_interpreter.get_input_details()[0]['dtype']
+    final_output_dtype = final_interpreter.get_output_details()[0]['dtype']
+
+    # Validate the model interfaces
+    self.assertEqual(initial_input_dtype, np.float32)
+    self.assertEqual(initial_output_dtype, np.float32)
+    self.assertEqual(final_input_dtype, np.int16)
+    self.assertEqual(final_output_dtype, np.int16)
+
   def testUInt8Interface(self):
     # 1. SETUP
     # Define the temporary directory and files
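testInt16Interface mirrors the existing int8 test: build a 16x8-quantized model (whose interface is still float32), rewrite the interface to int16, then compare the two interpreters' I/O dtypes. A condensed sketch of the same flow outside the test harness; the import path for modify_model_interface_lib is an assumption, and the file paths are placeholders:

import numpy as np
import tensorflow as tf
# Assumed import path; the extract only shows the module name.
from tensorflow.lite.tools.optimize.python import modify_model_interface_lib

# Rewrite a float32-interfaced 16x8 model to take/return int16 tensors.
modify_model_interface_lib.modify_model_interface(
    'initial_model.tflite', 'final_model.tflite', tf.int16, tf.int16)

# The rewritten model now exposes int16 inputs and outputs.
interpreter = tf.lite.Interpreter(model_path='final_model.tflite')
interpreter.allocate_tensors()
assert interpreter.get_input_details()[0]['dtype'] == np.int16
assert interpreter.get_output_details()[0]['dtype'] == np.int16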