Add python wrapper to modify model interface for full integer quantization
PiperOrigin-RevId: 308359969 Change-Id: I5929e4778e5e12651b64665bfa9f3dbc6f82e72b
This commit is contained in:
parent
490288d631
commit
a07ca66517
67
tensorflow/lite/tools/optimize/python/BUILD
Normal file
67
tensorflow/lite/tools/optimize/python/BUILD
Normal file
@ -0,0 +1,67 @@
|
|||||||
|
load("//tensorflow:tensorflow.bzl", "pybind_extension")
|
||||||
|
|
||||||
|
package(
|
||||||
|
default_visibility = [
|
||||||
|
"//visibility:public",
|
||||||
|
],
|
||||||
|
licenses = ["notice"], # Apache 2.0
|
||||||
|
)
|
||||||
|
|
||||||
|
py_binary(
|
||||||
|
name = "modify_model_interface",
|
||||||
|
srcs = ["modify_model_interface.py"],
|
||||||
|
python_version = "PY3",
|
||||||
|
srcs_version = "PY3",
|
||||||
|
deps = [
|
||||||
|
":modify_model_interface_constants",
|
||||||
|
":modify_model_interface_lib",
|
||||||
|
"//tensorflow/python:platform",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
py_library(
|
||||||
|
name = "modify_model_interface_lib",
|
||||||
|
srcs = ["modify_model_interface_lib.py"],
|
||||||
|
srcs_version = "PY3",
|
||||||
|
deps = [
|
||||||
|
":_pywrap_modify_model_interface",
|
||||||
|
":modify_model_interface_constants",
|
||||||
|
"//tensorflow:tensorflow_py",
|
||||||
|
"//tensorflow/lite/python:schema_py",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
py_test(
|
||||||
|
name = "modify_model_interface_lib_test",
|
||||||
|
srcs = ["modify_model_interface_lib_test.py"],
|
||||||
|
python_version = "PY3",
|
||||||
|
srcs_version = "PY3",
|
||||||
|
tags = [
|
||||||
|
"no_mac", # TODO(b/148247402): flatbuffers import broken on Mac OS.
|
||||||
|
],
|
||||||
|
deps = [
|
||||||
|
":modify_model_interface_lib",
|
||||||
|
"//tensorflow:tensorflow_py",
|
||||||
|
"//tensorflow/python:client_testlib",
|
||||||
|
"//tensorflow/python:framework_test_lib",
|
||||||
|
"//third_party/py/numpy",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
py_library(
|
||||||
|
name = "modify_model_interface_constants",
|
||||||
|
srcs = ["modify_model_interface_constants.py"],
|
||||||
|
srcs_version = "PY3",
|
||||||
|
deps = ["//tensorflow/lite/python:lite_constants"],
|
||||||
|
)
|
||||||
|
|
||||||
|
pybind_extension(
|
||||||
|
name = "_pywrap_modify_model_interface",
|
||||||
|
srcs = ["modify_model_interface.cc"],
|
||||||
|
module_name = "_pywrap_modify_model_interface",
|
||||||
|
deps = [
|
||||||
|
"//tensorflow/lite/schema:schema_fbs",
|
||||||
|
"//tensorflow/lite/tools/optimize:modify_model_interface",
|
||||||
|
"@pybind11",
|
||||||
|
],
|
||||||
|
)
|
@ -0,0 +1,40 @@
|
|||||||
|
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
// Python wrapper to modify model interface.
|
||||||
|
|
||||||
|
#include "tensorflow/lite/tools/optimize/modify_model_interface.h"
|
||||||
|
|
||||||
|
#include <string>
|
||||||
|
|
||||||
|
#include "pybind11/pybind11.h"
|
||||||
|
#include "tensorflow/lite/schema/schema_generated.h"
|
||||||
|
|
||||||
|
namespace pybind11 {
|
||||||
|
|
||||||
|
PYBIND11_MODULE(_pywrap_modify_model_interface, m) {
|
||||||
|
// An anonymous function that invokes the C++ function
|
||||||
|
// after applying transformations to the python function arguments
|
||||||
|
m.def("modify_model_interface",
|
||||||
|
[](const std::string& input_file, const std::string& output_file,
|
||||||
|
const int input_type, const int output_type) -> int {
|
||||||
|
return tflite::optimize::ModifyModelInterface(
|
||||||
|
input_file, output_file,
|
||||||
|
static_cast<tflite::TensorType>(input_type),
|
||||||
|
static_cast<tflite::TensorType>(input_type));
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace pybind11
|
@ -0,0 +1,78 @@
|
|||||||
|
# Lint as: python3
|
||||||
|
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ==============================================================================
|
||||||
|
r"""Modify a quantized model's interface from float to integer.
|
||||||
|
|
||||||
|
Example usage:
|
||||||
|
python modify_model_interface_main.py \
|
||||||
|
--input_file=float_model.tflite \
|
||||||
|
--output_file=int_model.tflite \
|
||||||
|
--input_type=INT8 \
|
||||||
|
--output_type=INT8
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import absolute_import
|
||||||
|
from __future__ import division
|
||||||
|
from __future__ import print_function
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import sys
|
||||||
|
|
||||||
|
from tensorflow.lite.tools.optimize.python import modify_model_interface_constants as mmi_constants
|
||||||
|
from tensorflow.lite.tools.optimize.python import modify_model_interface_lib as mmi_lib
|
||||||
|
from tensorflow.python.platform import app
|
||||||
|
|
||||||
|
|
||||||
|
def main(_):
|
||||||
|
"""Application run loop."""
|
||||||
|
parser = argparse.ArgumentParser(
|
||||||
|
description="Modify a quantized model's interface from float to integer.")
|
||||||
|
parser.add_argument(
|
||||||
|
'--input_file',
|
||||||
|
type=str,
|
||||||
|
required=True,
|
||||||
|
help='Full path name to the input tflite file.')
|
||||||
|
parser.add_argument(
|
||||||
|
'--output_file',
|
||||||
|
type=str,
|
||||||
|
required=True,
|
||||||
|
help='Full path name to the output tflite file.')
|
||||||
|
parser.add_argument(
|
||||||
|
'--input_type',
|
||||||
|
type=str.upper,
|
||||||
|
choices=mmi_constants.STR_TYPES,
|
||||||
|
default=mmi_constants.DEFAULT_STR_TYPE,
|
||||||
|
help='Modified input integer interface type.')
|
||||||
|
parser.add_argument(
|
||||||
|
'--output_type',
|
||||||
|
type=str.upper,
|
||||||
|
choices=mmi_constants.STR_TYPES,
|
||||||
|
default=mmi_constants.DEFAULT_STR_TYPE,
|
||||||
|
help='Modified output integer interface type.')
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
input_type = mmi_constants.STR_TO_TFLITE_TYPES[args.input_type]
|
||||||
|
output_type = mmi_constants.STR_TO_TFLITE_TYPES[args.output_type]
|
||||||
|
|
||||||
|
mmi_lib.modify_model_interface(args.input_file, args.output_file, input_type,
|
||||||
|
output_type)
|
||||||
|
|
||||||
|
print('Successfully modified the model input type from FLOAT to '
|
||||||
|
'{input_type} and output type from FLOAT to {output_type}.'.format(
|
||||||
|
input_type=args.input_type, output_type=args.output_type))
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
app.run(main=main, argv=sys.argv[:1])
|
@ -0,0 +1,34 @@
|
|||||||
|
# Lint as: python3
|
||||||
|
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ==============================================================================
|
||||||
|
"""Constants for modify_model_interface."""
|
||||||
|
|
||||||
|
from __future__ import absolute_import
|
||||||
|
from __future__ import division
|
||||||
|
from __future__ import print_function
|
||||||
|
|
||||||
|
from tensorflow.lite.python import lite_constants
|
||||||
|
|
||||||
|
STR_TO_TFLITE_TYPES = {
|
||||||
|
'INT8': lite_constants.INT8,
|
||||||
|
'UINT8': lite_constants.QUANTIZED_UINT8
|
||||||
|
}
|
||||||
|
TFLITE_TO_STR_TYPES = {v: k for k, v in STR_TO_TFLITE_TYPES.items()}
|
||||||
|
|
||||||
|
STR_TYPES = STR_TO_TFLITE_TYPES.keys()
|
||||||
|
TFLITE_TYPES = STR_TO_TFLITE_TYPES.values()
|
||||||
|
|
||||||
|
DEFAULT_STR_TYPE = 'INT8'
|
||||||
|
DEFAULT_TFLITE_TYPE = STR_TO_TFLITE_TYPES[DEFAULT_STR_TYPE]
|
@ -0,0 +1,79 @@
|
|||||||
|
# Lint as: python3
|
||||||
|
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ==============================================================================
|
||||||
|
"""Library to modify a quantized model's interface from float to integer."""
|
||||||
|
|
||||||
|
from __future__ import absolute_import
|
||||||
|
from __future__ import division
|
||||||
|
from __future__ import print_function
|
||||||
|
|
||||||
|
from tensorflow.lite.python import schema_py_generated as schema_fb
|
||||||
|
from tensorflow.lite.tools.optimize.python import _pywrap_modify_model_interface
|
||||||
|
from tensorflow.lite.tools.optimize.python import modify_model_interface_constants as mmi_constants
|
||||||
|
|
||||||
|
|
||||||
|
def _parse_type_to_int(dtype, flag):
|
||||||
|
"""Converts a tflite type to it's integer representation.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
dtype: tf.DType representing the inference type.
|
||||||
|
flag: str representing the flag name.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
integer, a tflite TensorType enum value.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: Unsupported tflite type.
|
||||||
|
"""
|
||||||
|
# Validate if dtype is supported in tflite and is a valid interface type.
|
||||||
|
if dtype not in mmi_constants.TFLITE_TYPES:
|
||||||
|
raise ValueError(
|
||||||
|
"Unsupported value '{0}' for {1}. Only {2} are supported.".format(
|
||||||
|
dtype, flag, mmi_constants.TFLITE_TYPES))
|
||||||
|
|
||||||
|
dtype_str = mmi_constants.TFLITE_TO_STR_TYPES[dtype]
|
||||||
|
dtype_int = schema_fb.TensorType.__dict__[dtype_str]
|
||||||
|
|
||||||
|
return dtype_int
|
||||||
|
|
||||||
|
|
||||||
|
def modify_model_interface(input_file, output_file, input_type, output_type):
|
||||||
|
"""Modify a quantized model's interface (input/output) from float to integer.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
input_file: Full path name to the input tflite file.
|
||||||
|
output_file: Full path name to the output tflite file.
|
||||||
|
input_type: Final input interface type.
|
||||||
|
output_type: Final output interface type.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
RuntimeError: If the modification of the model interface was unsuccessful.
|
||||||
|
ValueError: If the input_type or output_type is unsupported.
|
||||||
|
|
||||||
|
"""
|
||||||
|
# Map the interface types to integer values
|
||||||
|
input_type_int = _parse_type_to_int(input_type, 'input_type')
|
||||||
|
output_type_int = _parse_type_to_int(output_type, 'output_type')
|
||||||
|
|
||||||
|
# Invoke the function to modify the model interface
|
||||||
|
status = _pywrap_modify_model_interface.modify_model_interface(
|
||||||
|
input_file, output_file, input_type_int, output_type_int)
|
||||||
|
|
||||||
|
# Throw an exception if the return status is an error.
|
||||||
|
if status != 0:
|
||||||
|
raise RuntimeError(
|
||||||
|
'Error occured when trying to modify the model input type from float '
|
||||||
|
'to {input_type} and output type from float to {output_type}.'.format(
|
||||||
|
input_type=input_type, output_type=output_type))
|
@ -0,0 +1,129 @@
|
|||||||
|
# Lint as: python3
|
||||||
|
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ==============================================================================
|
||||||
|
"""Tests for modify_model_interface_lib.py."""
|
||||||
|
|
||||||
|
from __future__ import absolute_import
|
||||||
|
from __future__ import division
|
||||||
|
from __future__ import print_function
|
||||||
|
|
||||||
|
import os
|
||||||
|
import numpy as np
|
||||||
|
import tensorflow as tf
|
||||||
|
|
||||||
|
from tensorflow.lite.tools.optimize.python import modify_model_interface_lib
|
||||||
|
from tensorflow.python.framework import test_util
|
||||||
|
from tensorflow.python.platform import test
|
||||||
|
|
||||||
|
|
||||||
|
def build_tflite_model_with_full_integer_quantization():
|
||||||
|
# Define TF model
|
||||||
|
input_size = 3
|
||||||
|
model = tf.keras.Sequential([
|
||||||
|
tf.keras.layers.InputLayer(input_shape=(input_size,), dtype=tf.float32),
|
||||||
|
tf.keras.layers.Dense(units=5, activation=tf.nn.relu),
|
||||||
|
tf.keras.layers.Dense(units=2, activation=tf.nn.softmax)
|
||||||
|
])
|
||||||
|
|
||||||
|
# Convert TF Model to a Quantized TFLite Model
|
||||||
|
converter = tf.lite.TFLiteConverter.from_keras_model(model)
|
||||||
|
converter.optimizations = [tf.lite.Optimize.DEFAULT]
|
||||||
|
|
||||||
|
def representative_dataset_gen():
|
||||||
|
for i in range(10):
|
||||||
|
yield [np.array([i] * input_size, dtype=np.float32)]
|
||||||
|
|
||||||
|
converter.representative_dataset = representative_dataset_gen
|
||||||
|
converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
|
||||||
|
tflite_model = converter.convert()
|
||||||
|
|
||||||
|
return tflite_model
|
||||||
|
|
||||||
|
|
||||||
|
class ModifyModelInterfaceTest(test_util.TensorFlowTestCase):
|
||||||
|
|
||||||
|
def testInt8Interface(self):
|
||||||
|
# 1. SETUP
|
||||||
|
# Define the temporary directory and files
|
||||||
|
temp_dir = self.get_temp_dir()
|
||||||
|
initial_file = os.path.join(temp_dir, 'initial_model.tflite')
|
||||||
|
final_file = os.path.join(temp_dir, 'final_model.tflite')
|
||||||
|
# Define initial model
|
||||||
|
initial_model = build_tflite_model_with_full_integer_quantization()
|
||||||
|
with open(initial_file, 'wb') as model_file:
|
||||||
|
model_file.write(initial_model)
|
||||||
|
|
||||||
|
# 2. INVOKE
|
||||||
|
# Invoke the modify_model_interface function
|
||||||
|
modify_model_interface_lib.modify_model_interface(initial_file, final_file,
|
||||||
|
tf.int8, tf.int8)
|
||||||
|
|
||||||
|
# 3. VALIDATE
|
||||||
|
# Load TFLite model and allocate tensors.
|
||||||
|
initial_interpreter = tf.lite.Interpreter(model_path=initial_file)
|
||||||
|
initial_interpreter.allocate_tensors()
|
||||||
|
final_interpreter = tf.lite.Interpreter(model_path=final_file)
|
||||||
|
final_interpreter.allocate_tensors()
|
||||||
|
|
||||||
|
# Get input and output types.
|
||||||
|
initial_input_dtype = initial_interpreter.get_input_details()[0]['dtype']
|
||||||
|
initial_output_dtype = initial_interpreter.get_output_details()[0]['dtype']
|
||||||
|
final_input_dtype = final_interpreter.get_input_details()[0]['dtype']
|
||||||
|
final_output_dtype = final_interpreter.get_output_details()[0]['dtype']
|
||||||
|
|
||||||
|
# Validate the model interfaces
|
||||||
|
self.assertEqual(initial_input_dtype, np.float32)
|
||||||
|
self.assertEqual(initial_output_dtype, np.float32)
|
||||||
|
self.assertEqual(final_input_dtype, np.int8)
|
||||||
|
self.assertEqual(final_output_dtype, np.int8)
|
||||||
|
|
||||||
|
def testUInt8Interface(self):
|
||||||
|
# 1. SETUP
|
||||||
|
# Define the temporary directory and files
|
||||||
|
temp_dir = self.get_temp_dir()
|
||||||
|
initial_file = os.path.join(temp_dir, 'initial_model.tflite')
|
||||||
|
final_file = os.path.join(temp_dir, 'final_model.tflite')
|
||||||
|
# Define initial model
|
||||||
|
initial_model = build_tflite_model_with_full_integer_quantization()
|
||||||
|
with open(initial_file, 'wb') as model_file:
|
||||||
|
model_file.write(initial_model)
|
||||||
|
|
||||||
|
# 2. INVOKE
|
||||||
|
# Invoke the modify_model_interface function
|
||||||
|
modify_model_interface_lib.modify_model_interface(initial_file, final_file,
|
||||||
|
tf.uint8, tf.uint8)
|
||||||
|
|
||||||
|
# 3. VALIDATE
|
||||||
|
# Load TFLite model and allocate tensors.
|
||||||
|
initial_interpreter = tf.lite.Interpreter(model_path=initial_file)
|
||||||
|
initial_interpreter.allocate_tensors()
|
||||||
|
final_interpreter = tf.lite.Interpreter(model_path=final_file)
|
||||||
|
final_interpreter.allocate_tensors()
|
||||||
|
|
||||||
|
# Get input and output types.
|
||||||
|
initial_input_dtype = initial_interpreter.get_input_details()[0]['dtype']
|
||||||
|
initial_output_dtype = initial_interpreter.get_output_details()[0]['dtype']
|
||||||
|
final_input_dtype = final_interpreter.get_input_details()[0]['dtype']
|
||||||
|
final_output_dtype = final_interpreter.get_output_details()[0]['dtype']
|
||||||
|
|
||||||
|
# Validate the model interfaces
|
||||||
|
self.assertEqual(initial_input_dtype, np.float32)
|
||||||
|
self.assertEqual(initial_output_dtype, np.float32)
|
||||||
|
self.assertEqual(final_input_dtype, np.uint8)
|
||||||
|
self.assertEqual(final_output_dtype, np.uint8)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
test.main()
|
Loading…
Reference in New Issue
Block a user