Add python wrapper to modify model interface for full integer quantization

PiperOrigin-RevId: 308359969
Change-Id: I5929e4778e5e12651b64665bfa9f3dbc6f82e72b
Author: Meghna Natraj 2020-04-24 17:52:11 -07:00 (committed by TensorFlower Gardener)
parent 490288d631
commit a07ca66517
6 changed files with 427 additions and 0 deletions
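
A minimal sketch of how the new wrapper is meant to be used from Python, based on the library and test added below (the model paths here are just placeholders):

    import tensorflow as tf
    from tensorflow.lite.tools.optimize.python import modify_model_interface_lib

    # Rewrite a fully-integer-quantized model so that its input and output
    # tensors are int8 instead of float32.
    modify_model_interface_lib.modify_model_interface(
        'quantized_float_io.tflite',  # existing quantized model with float I/O
        'quantized_int8_io.tflite',   # destination for the modified model
        tf.int8, tf.int8)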


tensorflow/lite/tools/optimize/python/BUILD
@@ -0,0 +1,67 @@
load("//tensorflow:tensorflow.bzl", "pybind_extension")
package(
default_visibility = [
"//visibility:public",
],
licenses = ["notice"], # Apache 2.0
)
py_binary(
name = "modify_model_interface",
srcs = ["modify_model_interface.py"],
python_version = "PY3",
srcs_version = "PY3",
deps = [
":modify_model_interface_constants",
":modify_model_interface_lib",
"//tensorflow/python:platform",
],
)
py_library(
name = "modify_model_interface_lib",
srcs = ["modify_model_interface_lib.py"],
srcs_version = "PY3",
deps = [
":_pywrap_modify_model_interface",
":modify_model_interface_constants",
"//tensorflow:tensorflow_py",
"//tensorflow/lite/python:schema_py",
],
)
py_test(
name = "modify_model_interface_lib_test",
srcs = ["modify_model_interface_lib_test.py"],
python_version = "PY3",
srcs_version = "PY3",
tags = [
"no_mac", # TODO(b/148247402): flatbuffers import broken on Mac OS.
],
deps = [
":modify_model_interface_lib",
"//tensorflow:tensorflow_py",
"//tensorflow/python:client_testlib",
"//tensorflow/python:framework_test_lib",
"//third_party/py/numpy",
],
)
py_library(
name = "modify_model_interface_constants",
srcs = ["modify_model_interface_constants.py"],
srcs_version = "PY3",
deps = ["//tensorflow/lite/python:lite_constants"],
)
pybind_extension(
name = "_pywrap_modify_model_interface",
srcs = ["modify_model_interface.cc"],
module_name = "_pywrap_modify_model_interface",
deps = [
"//tensorflow/lite/schema:schema_fbs",
"//tensorflow/lite/tools/optimize:modify_model_interface",
"@pybind11",
],
)
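
Assuming the package lives at tensorflow/lite/tools/optimize/python (the path used by the Python imports below; it is not shown in the diff itself), the new targets could be exercised roughly like this:

    # Run the unit tests for the wrapper library.
    bazel test //tensorflow/lite/tools/optimize/python:modify_model_interface_lib_test

    # Run the command-line entry point.
    bazel run //tensorflow/lite/tools/optimize/python:modify_model_interface -- \
        --input_file=/tmp/quantized_float_io.tflite \
        --output_file=/tmp/quantized_int8_io.tflite \
        --input_type=INT8 --output_type=INT8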


tensorflow/lite/tools/optimize/python/modify_model_interface.cc
@@ -0,0 +1,40 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// Python wrapper to modify model interface.
#include "tensorflow/lite/tools/optimize/modify_model_interface.h"
#include <string>
#include "pybind11/pybind11.h"
#include "tensorflow/lite/schema/schema_generated.h"
namespace pybind11 {
PYBIND11_MODULE(_pywrap_modify_model_interface, m) {
// A lambda that casts the Python integer type arguments to tflite::TensorType
// and invokes the C++ ModifyModelInterface function.
m.def("modify_model_interface",
[](const std::string& input_file, const std::string& output_file,
const int input_type, const int output_type) -> int {
return tflite::optimize::ModifyModelInterface(
input_file, output_file,
static_cast<tflite::TensorType>(input_type),
static_cast<tflite::TensorType>(output_type));
});
}
} // namespace pybind11
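
For reference, the library below consumes this extension roughly as follows; the enum values come from the generated schema module, the return value is the C++ status code (0 on success), and the file names here are placeholders:

    from tensorflow.lite.python import schema_py_generated as schema_fb
    from tensorflow.lite.tools.optimize.python import _pywrap_modify_model_interface

    # TensorType enum values are passed as plain ints; a non-zero return
    # indicates that the C++ ModifyModelInterface call failed.
    status = _pywrap_modify_model_interface.modify_model_interface(
        'in.tflite', 'out.tflite',
        schema_fb.TensorType.INT8, schema_fb.TensorType.INT8)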


tensorflow/lite/tools/optimize/python/modify_model_interface.py
@@ -0,0 +1,78 @@
# Lint as: python3
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
r"""Modify a quantized model's interface from float to integer.
Example usage:
python modify_model_interface.py \
--input_file=float_model.tflite \
--output_file=int_model.tflite \
--input_type=INT8 \
--output_type=INT8
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import argparse
import sys
from tensorflow.lite.tools.optimize.python import modify_model_interface_constants as mmi_constants
from tensorflow.lite.tools.optimize.python import modify_model_interface_lib as mmi_lib
from tensorflow.python.platform import app
def main(_):
"""Application run loop."""
parser = argparse.ArgumentParser(
description="Modify a quantized model's interface from float to integer.")
parser.add_argument(
'--input_file',
type=str,
required=True,
help='Full path name to the input tflite file.')
parser.add_argument(
'--output_file',
type=str,
required=True,
help='Full path name to the output tflite file.')
parser.add_argument(
'--input_type',
type=str.upper,
choices=mmi_constants.STR_TYPES,
default=mmi_constants.DEFAULT_STR_TYPE,
help='Modified input integer interface type.')
parser.add_argument(
'--output_type',
type=str.upper,
choices=mmi_constants.STR_TYPES,
default=mmi_constants.DEFAULT_STR_TYPE,
help='Modified output integer interface type.')
args = parser.parse_args()
input_type = mmi_constants.STR_TO_TFLITE_TYPES[args.input_type]
output_type = mmi_constants.STR_TO_TFLITE_TYPES[args.output_type]
mmi_lib.modify_model_interface(args.input_file, args.output_file, input_type,
output_type)
print('Successfully modified the model input type from FLOAT to '
'{input_type} and output type from FLOAT to {output_type}.'.format(
input_type=args.input_type, output_type=args.output_type))
if __name__ == '__main__':
app.run(main=main, argv=sys.argv[:1])


tensorflow/lite/tools/optimize/python/modify_model_interface_constants.py
@@ -0,0 +1,34 @@
# Lint as: python3
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Constants for modify_model_interface."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.lite.python import lite_constants
STR_TO_TFLITE_TYPES = {
'INT8': lite_constants.INT8,
'UINT8': lite_constants.QUANTIZED_UINT8
}
TFLITE_TO_STR_TYPES = {v: k for k, v in STR_TO_TFLITE_TYPES.items()}
STR_TYPES = STR_TO_TFLITE_TYPES.keys()
TFLITE_TYPES = STR_TO_TFLITE_TYPES.values()
DEFAULT_STR_TYPE = 'INT8'
DEFAULT_TFLITE_TYPE = STR_TO_TFLITE_TYPES[DEFAULT_STR_TYPE]
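
These mappings drive both the command-line flag parsing and the type validation in the library; a small sketch of how they resolve:

    from tensorflow.lite.tools.optimize.python import modify_model_interface_constants as mmi_constants

    # Flag string -> TF dtype, and back again.
    dtype = mmi_constants.STR_TO_TFLITE_TYPES['UINT8']  # lite_constants.QUANTIZED_UINT8
    name = mmi_constants.TFLITE_TO_STR_TYPES[dtype]     # 'UINT8'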


tensorflow/lite/tools/optimize/python/modify_model_interface_lib.py
@@ -0,0 +1,79 @@
# Lint as: python3
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Library to modify a quantized model's interface from float to integer."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.lite.python import schema_py_generated as schema_fb
from tensorflow.lite.tools.optimize.python import _pywrap_modify_model_interface
from tensorflow.lite.tools.optimize.python import modify_model_interface_constants as mmi_constants
def _parse_type_to_int(dtype, flag):
"""Converts a tflite type to it's integer representation.
Args:
dtype: tf.DType representing the inference type.
flag: str representing the flag name.
Returns:
integer, a tflite TensorType enum value.
Raises:
ValueError: Unsupported tflite type.
"""
# Validate if dtype is supported in tflite and is a valid interface type.
if dtype not in mmi_constants.TFLITE_TYPES:
raise ValueError(
"Unsupported value '{0}' for {1}. Only {2} are supported.".format(
dtype, flag, mmi_constants.TFLITE_TYPES))
dtype_str = mmi_constants.TFLITE_TO_STR_TYPES[dtype]
dtype_int = schema_fb.TensorType.__dict__[dtype_str]
return dtype_int
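# For example, tf.int8 is looked up as 'INT8' and mapped to
# schema_fb.TensorType.INT8, the integer value the pybind wrapper expects.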
def modify_model_interface(input_file, output_file, input_type, output_type):
"""Modify a quantized model's interface (input/output) from float to integer.
Args:
input_file: Full path name to the input tflite file.
output_file: Full path name to the output tflite file.
input_type: Final input interface type.
output_type: Final output interface type.
Raises:
RuntimeError: If the modification of the model interface was unsuccessful.
ValueError: If the input_type or output_type is unsupported.
"""
# Map the interface types to integer values
input_type_int = _parse_type_to_int(input_type, 'input_type')
output_type_int = _parse_type_to_int(output_type, 'output_type')
# Invoke the function to modify the model interface
status = _pywrap_modify_model_interface.modify_model_interface(
input_file, output_file, input_type_int, output_type_int)
# Throw an exception if the return status is an error.
if status != 0:
raise RuntimeError(
'Error occurred when trying to modify the model input type from float '
'to {input_type} and output type from float to {output_type}.'.format(
input_type=input_type, output_type=output_type))
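
Only the two types declared in modify_model_interface_constants are accepted, so any other dtype raises a ValueError before the C++ code runs. A small sketch of that error path (the file names are placeholders):

    import tensorflow as tf
    from tensorflow.lite.tools.optimize.python import modify_model_interface_lib

    try:
      modify_model_interface_lib.modify_model_interface(
          'model.tflite', 'model_int16.tflite', tf.int16, tf.int16)
    except ValueError:
      pass  # Only tf.int8 and tf.uint8 are supported interface types.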


tensorflow/lite/tools/optimize/python/modify_model_interface_lib_test.py
@@ -0,0 +1,129 @@
# Lint as: python3
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for modify_model_interface_lib.py."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import numpy as np
import tensorflow as tf
from tensorflow.lite.tools.optimize.python import modify_model_interface_lib
from tensorflow.python.framework import test_util
from tensorflow.python.platform import test
def build_tflite_model_with_full_integer_quantization():
# Define TF model
input_size = 3
model = tf.keras.Sequential([
tf.keras.layers.InputLayer(input_shape=(input_size,), dtype=tf.float32),
tf.keras.layers.Dense(units=5, activation=tf.nn.relu),
tf.keras.layers.Dense(units=2, activation=tf.nn.softmax)
])
# Convert TF Model to a Quantized TFLite Model
converter = tf.lite.TFLiteConverter.from_keras_model(model)
converter.optimizations = [tf.lite.Optimize.DEFAULT]
def representative_dataset_gen():
for i in range(10):
yield [np.array([i] * input_size, dtype=np.float32)]
converter.representative_dataset = representative_dataset_gen
converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
tflite_model = converter.convert()
return tflite_model
class ModifyModelInterfaceTest(test_util.TensorFlowTestCase):
def testInt8Interface(self):
# 1. SETUP
# Define the temporary directory and files
temp_dir = self.get_temp_dir()
initial_file = os.path.join(temp_dir, 'initial_model.tflite')
final_file = os.path.join(temp_dir, 'final_model.tflite')
# Define initial model
initial_model = build_tflite_model_with_full_integer_quantization()
with open(initial_file, 'wb') as model_file:
model_file.write(initial_model)
# 2. INVOKE
# Invoke the modify_model_interface function
modify_model_interface_lib.modify_model_interface(initial_file, final_file,
tf.int8, tf.int8)
# 3. VALIDATE
# Load TFLite model and allocate tensors.
initial_interpreter = tf.lite.Interpreter(model_path=initial_file)
initial_interpreter.allocate_tensors()
final_interpreter = tf.lite.Interpreter(model_path=final_file)
final_interpreter.allocate_tensors()
# Get input and output types.
initial_input_dtype = initial_interpreter.get_input_details()[0]['dtype']
initial_output_dtype = initial_interpreter.get_output_details()[0]['dtype']
final_input_dtype = final_interpreter.get_input_details()[0]['dtype']
final_output_dtype = final_interpreter.get_output_details()[0]['dtype']
# Validate the model interfaces
self.assertEqual(initial_input_dtype, np.float32)
self.assertEqual(initial_output_dtype, np.float32)
self.assertEqual(final_input_dtype, np.int8)
self.assertEqual(final_output_dtype, np.int8)
def testUInt8Interface(self):
# 1. SETUP
# Define the temporary directory and files
temp_dir = self.get_temp_dir()
initial_file = os.path.join(temp_dir, 'initial_model.tflite')
final_file = os.path.join(temp_dir, 'final_model.tflite')
# Define initial model
initial_model = build_tflite_model_with_full_integer_quantization()
with open(initial_file, 'wb') as model_file:
model_file.write(initial_model)
# 2. INVOKE
# Invoke the modify_model_interface function
modify_model_interface_lib.modify_model_interface(initial_file, final_file,
tf.uint8, tf.uint8)
# 3. VALIDATE
# Load TFLite model and allocate tensors.
initial_interpreter = tf.lite.Interpreter(model_path=initial_file)
initial_interpreter.allocate_tensors()
final_interpreter = tf.lite.Interpreter(model_path=final_file)
final_interpreter.allocate_tensors()
# Get input and output types.
initial_input_dtype = initial_interpreter.get_input_details()[0]['dtype']
initial_output_dtype = initial_interpreter.get_output_details()[0]['dtype']
final_input_dtype = final_interpreter.get_input_details()[0]['dtype']
final_output_dtype = final_interpreter.get_output_details()[0]['dtype']
# Validate the model interfaces
self.assertEqual(initial_input_dtype, np.float32)
self.assertEqual(initial_output_dtype, np.float32)
self.assertEqual(final_input_dtype, np.uint8)
self.assertEqual(final_output_dtype, np.uint8)
if __name__ == '__main__':
test.main()