Add python wrapper to modify model interface for full integer quantization
PiperOrigin-RevId: 308359969 Change-Id: I5929e4778e5e12651b64665bfa9f3dbc6f82e72b
This commit is contained in:
parent
490288d631
commit
a07ca66517
67
tensorflow/lite/tools/optimize/python/BUILD
Normal file
67
tensorflow/lite/tools/optimize/python/BUILD
Normal file
@ -0,0 +1,67 @@
|
||||
load("//tensorflow:tensorflow.bzl", "pybind_extension")
|
||||
|
||||
package(
|
||||
default_visibility = [
|
||||
"//visibility:public",
|
||||
],
|
||||
licenses = ["notice"], # Apache 2.0
|
||||
)
|
||||
|
||||
py_binary(
|
||||
name = "modify_model_interface",
|
||||
srcs = ["modify_model_interface.py"],
|
||||
python_version = "PY3",
|
||||
srcs_version = "PY3",
|
||||
deps = [
|
||||
":modify_model_interface_constants",
|
||||
":modify_model_interface_lib",
|
||||
"//tensorflow/python:platform",
|
||||
],
|
||||
)
|
||||
|
||||
py_library(
|
||||
name = "modify_model_interface_lib",
|
||||
srcs = ["modify_model_interface_lib.py"],
|
||||
srcs_version = "PY3",
|
||||
deps = [
|
||||
":_pywrap_modify_model_interface",
|
||||
":modify_model_interface_constants",
|
||||
"//tensorflow:tensorflow_py",
|
||||
"//tensorflow/lite/python:schema_py",
|
||||
],
|
||||
)
|
||||
|
||||
py_test(
|
||||
name = "modify_model_interface_lib_test",
|
||||
srcs = ["modify_model_interface_lib_test.py"],
|
||||
python_version = "PY3",
|
||||
srcs_version = "PY3",
|
||||
tags = [
|
||||
"no_mac", # TODO(b/148247402): flatbuffers import broken on Mac OS.
|
||||
],
|
||||
deps = [
|
||||
":modify_model_interface_lib",
|
||||
"//tensorflow:tensorflow_py",
|
||||
"//tensorflow/python:client_testlib",
|
||||
"//tensorflow/python:framework_test_lib",
|
||||
"//third_party/py/numpy",
|
||||
],
|
||||
)
|
||||
|
||||
py_library(
|
||||
name = "modify_model_interface_constants",
|
||||
srcs = ["modify_model_interface_constants.py"],
|
||||
srcs_version = "PY3",
|
||||
deps = ["//tensorflow/lite/python:lite_constants"],
|
||||
)
|
||||
|
||||
pybind_extension(
|
||||
name = "_pywrap_modify_model_interface",
|
||||
srcs = ["modify_model_interface.cc"],
|
||||
module_name = "_pywrap_modify_model_interface",
|
||||
deps = [
|
||||
"//tensorflow/lite/schema:schema_fbs",
|
||||
"//tensorflow/lite/tools/optimize:modify_model_interface",
|
||||
"@pybind11",
|
||||
],
|
||||
)
|
@ -0,0 +1,40 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
// Python wrapper to modify model interface.
|
||||
|
||||
#include "tensorflow/lite/tools/optimize/modify_model_interface.h"
|
||||
|
||||
#include <string>
|
||||
|
||||
#include "pybind11/pybind11.h"
|
||||
#include "tensorflow/lite/schema/schema_generated.h"
|
||||
|
||||
namespace pybind11 {
|
||||
|
||||
PYBIND11_MODULE(_pywrap_modify_model_interface, m) {
|
||||
// An anonymous function that invokes the C++ function
|
||||
// after applying transformations to the python function arguments
|
||||
m.def("modify_model_interface",
|
||||
[](const std::string& input_file, const std::string& output_file,
|
||||
const int input_type, const int output_type) -> int {
|
||||
return tflite::optimize::ModifyModelInterface(
|
||||
input_file, output_file,
|
||||
static_cast<tflite::TensorType>(input_type),
|
||||
static_cast<tflite::TensorType>(input_type));
|
||||
});
|
||||
}
|
||||
|
||||
} // namespace pybind11
|
@ -0,0 +1,78 @@
|
||||
# Lint as: python3
|
||||
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
r"""Modify a quantized model's interface from float to integer.
|
||||
|
||||
Example usage:
|
||||
python modify_model_interface_main.py \
|
||||
--input_file=float_model.tflite \
|
||||
--output_file=int_model.tflite \
|
||||
--input_type=INT8 \
|
||||
--output_type=INT8
|
||||
"""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import argparse
|
||||
import sys
|
||||
|
||||
from tensorflow.lite.tools.optimize.python import modify_model_interface_constants as mmi_constants
|
||||
from tensorflow.lite.tools.optimize.python import modify_model_interface_lib as mmi_lib
|
||||
from tensorflow.python.platform import app
|
||||
|
||||
|
||||
def main(_):
|
||||
"""Application run loop."""
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Modify a quantized model's interface from float to integer.")
|
||||
parser.add_argument(
|
||||
'--input_file',
|
||||
type=str,
|
||||
required=True,
|
||||
help='Full path name to the input tflite file.')
|
||||
parser.add_argument(
|
||||
'--output_file',
|
||||
type=str,
|
||||
required=True,
|
||||
help='Full path name to the output tflite file.')
|
||||
parser.add_argument(
|
||||
'--input_type',
|
||||
type=str.upper,
|
||||
choices=mmi_constants.STR_TYPES,
|
||||
default=mmi_constants.DEFAULT_STR_TYPE,
|
||||
help='Modified input integer interface type.')
|
||||
parser.add_argument(
|
||||
'--output_type',
|
||||
type=str.upper,
|
||||
choices=mmi_constants.STR_TYPES,
|
||||
default=mmi_constants.DEFAULT_STR_TYPE,
|
||||
help='Modified output integer interface type.')
|
||||
args = parser.parse_args()
|
||||
|
||||
input_type = mmi_constants.STR_TO_TFLITE_TYPES[args.input_type]
|
||||
output_type = mmi_constants.STR_TO_TFLITE_TYPES[args.output_type]
|
||||
|
||||
mmi_lib.modify_model_interface(args.input_file, args.output_file, input_type,
|
||||
output_type)
|
||||
|
||||
print('Successfully modified the model input type from FLOAT to '
|
||||
'{input_type} and output type from FLOAT to {output_type}.'.format(
|
||||
input_type=args.input_type, output_type=args.output_type))
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
app.run(main=main, argv=sys.argv[:1])
|
@ -0,0 +1,34 @@
|
||||
# Lint as: python3
|
||||
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Constants for modify_model_interface."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from tensorflow.lite.python import lite_constants
|
||||
|
||||
STR_TO_TFLITE_TYPES = {
|
||||
'INT8': lite_constants.INT8,
|
||||
'UINT8': lite_constants.QUANTIZED_UINT8
|
||||
}
|
||||
TFLITE_TO_STR_TYPES = {v: k for k, v in STR_TO_TFLITE_TYPES.items()}
|
||||
|
||||
STR_TYPES = STR_TO_TFLITE_TYPES.keys()
|
||||
TFLITE_TYPES = STR_TO_TFLITE_TYPES.values()
|
||||
|
||||
DEFAULT_STR_TYPE = 'INT8'
|
||||
DEFAULT_TFLITE_TYPE = STR_TO_TFLITE_TYPES[DEFAULT_STR_TYPE]
|
@ -0,0 +1,79 @@
|
||||
# Lint as: python3
|
||||
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Library to modify a quantized model's interface from float to integer."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from tensorflow.lite.python import schema_py_generated as schema_fb
|
||||
from tensorflow.lite.tools.optimize.python import _pywrap_modify_model_interface
|
||||
from tensorflow.lite.tools.optimize.python import modify_model_interface_constants as mmi_constants
|
||||
|
||||
|
||||
def _parse_type_to_int(dtype, flag):
|
||||
"""Converts a tflite type to it's integer representation.
|
||||
|
||||
Args:
|
||||
dtype: tf.DType representing the inference type.
|
||||
flag: str representing the flag name.
|
||||
|
||||
Returns:
|
||||
integer, a tflite TensorType enum value.
|
||||
|
||||
Raises:
|
||||
ValueError: Unsupported tflite type.
|
||||
"""
|
||||
# Validate if dtype is supported in tflite and is a valid interface type.
|
||||
if dtype not in mmi_constants.TFLITE_TYPES:
|
||||
raise ValueError(
|
||||
"Unsupported value '{0}' for {1}. Only {2} are supported.".format(
|
||||
dtype, flag, mmi_constants.TFLITE_TYPES))
|
||||
|
||||
dtype_str = mmi_constants.TFLITE_TO_STR_TYPES[dtype]
|
||||
dtype_int = schema_fb.TensorType.__dict__[dtype_str]
|
||||
|
||||
return dtype_int
|
||||
|
||||
|
||||
def modify_model_interface(input_file, output_file, input_type, output_type):
|
||||
"""Modify a quantized model's interface (input/output) from float to integer.
|
||||
|
||||
Args:
|
||||
input_file: Full path name to the input tflite file.
|
||||
output_file: Full path name to the output tflite file.
|
||||
input_type: Final input interface type.
|
||||
output_type: Final output interface type.
|
||||
|
||||
Raises:
|
||||
RuntimeError: If the modification of the model interface was unsuccessful.
|
||||
ValueError: If the input_type or output_type is unsupported.
|
||||
|
||||
"""
|
||||
# Map the interface types to integer values
|
||||
input_type_int = _parse_type_to_int(input_type, 'input_type')
|
||||
output_type_int = _parse_type_to_int(output_type, 'output_type')
|
||||
|
||||
# Invoke the function to modify the model interface
|
||||
status = _pywrap_modify_model_interface.modify_model_interface(
|
||||
input_file, output_file, input_type_int, output_type_int)
|
||||
|
||||
# Throw an exception if the return status is an error.
|
||||
if status != 0:
|
||||
raise RuntimeError(
|
||||
'Error occured when trying to modify the model input type from float '
|
||||
'to {input_type} and output type from float to {output_type}.'.format(
|
||||
input_type=input_type, output_type=output_type))
|
@ -0,0 +1,129 @@
|
||||
# Lint as: python3
|
||||
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Tests for modify_model_interface_lib.py."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import os
|
||||
import numpy as np
|
||||
import tensorflow as tf
|
||||
|
||||
from tensorflow.lite.tools.optimize.python import modify_model_interface_lib
|
||||
from tensorflow.python.framework import test_util
|
||||
from tensorflow.python.platform import test
|
||||
|
||||
|
||||
def build_tflite_model_with_full_integer_quantization():
|
||||
# Define TF model
|
||||
input_size = 3
|
||||
model = tf.keras.Sequential([
|
||||
tf.keras.layers.InputLayer(input_shape=(input_size,), dtype=tf.float32),
|
||||
tf.keras.layers.Dense(units=5, activation=tf.nn.relu),
|
||||
tf.keras.layers.Dense(units=2, activation=tf.nn.softmax)
|
||||
])
|
||||
|
||||
# Convert TF Model to a Quantized TFLite Model
|
||||
converter = tf.lite.TFLiteConverter.from_keras_model(model)
|
||||
converter.optimizations = [tf.lite.Optimize.DEFAULT]
|
||||
|
||||
def representative_dataset_gen():
|
||||
for i in range(10):
|
||||
yield [np.array([i] * input_size, dtype=np.float32)]
|
||||
|
||||
converter.representative_dataset = representative_dataset_gen
|
||||
converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
|
||||
tflite_model = converter.convert()
|
||||
|
||||
return tflite_model
|
||||
|
||||
|
||||
class ModifyModelInterfaceTest(test_util.TensorFlowTestCase):
|
||||
|
||||
def testInt8Interface(self):
|
||||
# 1. SETUP
|
||||
# Define the temporary directory and files
|
||||
temp_dir = self.get_temp_dir()
|
||||
initial_file = os.path.join(temp_dir, 'initial_model.tflite')
|
||||
final_file = os.path.join(temp_dir, 'final_model.tflite')
|
||||
# Define initial model
|
||||
initial_model = build_tflite_model_with_full_integer_quantization()
|
||||
with open(initial_file, 'wb') as model_file:
|
||||
model_file.write(initial_model)
|
||||
|
||||
# 2. INVOKE
|
||||
# Invoke the modify_model_interface function
|
||||
modify_model_interface_lib.modify_model_interface(initial_file, final_file,
|
||||
tf.int8, tf.int8)
|
||||
|
||||
# 3. VALIDATE
|
||||
# Load TFLite model and allocate tensors.
|
||||
initial_interpreter = tf.lite.Interpreter(model_path=initial_file)
|
||||
initial_interpreter.allocate_tensors()
|
||||
final_interpreter = tf.lite.Interpreter(model_path=final_file)
|
||||
final_interpreter.allocate_tensors()
|
||||
|
||||
# Get input and output types.
|
||||
initial_input_dtype = initial_interpreter.get_input_details()[0]['dtype']
|
||||
initial_output_dtype = initial_interpreter.get_output_details()[0]['dtype']
|
||||
final_input_dtype = final_interpreter.get_input_details()[0]['dtype']
|
||||
final_output_dtype = final_interpreter.get_output_details()[0]['dtype']
|
||||
|
||||
# Validate the model interfaces
|
||||
self.assertEqual(initial_input_dtype, np.float32)
|
||||
self.assertEqual(initial_output_dtype, np.float32)
|
||||
self.assertEqual(final_input_dtype, np.int8)
|
||||
self.assertEqual(final_output_dtype, np.int8)
|
||||
|
||||
def testUInt8Interface(self):
|
||||
# 1. SETUP
|
||||
# Define the temporary directory and files
|
||||
temp_dir = self.get_temp_dir()
|
||||
initial_file = os.path.join(temp_dir, 'initial_model.tflite')
|
||||
final_file = os.path.join(temp_dir, 'final_model.tflite')
|
||||
# Define initial model
|
||||
initial_model = build_tflite_model_with_full_integer_quantization()
|
||||
with open(initial_file, 'wb') as model_file:
|
||||
model_file.write(initial_model)
|
||||
|
||||
# 2. INVOKE
|
||||
# Invoke the modify_model_interface function
|
||||
modify_model_interface_lib.modify_model_interface(initial_file, final_file,
|
||||
tf.uint8, tf.uint8)
|
||||
|
||||
# 3. VALIDATE
|
||||
# Load TFLite model and allocate tensors.
|
||||
initial_interpreter = tf.lite.Interpreter(model_path=initial_file)
|
||||
initial_interpreter.allocate_tensors()
|
||||
final_interpreter = tf.lite.Interpreter(model_path=final_file)
|
||||
final_interpreter.allocate_tensors()
|
||||
|
||||
# Get input and output types.
|
||||
initial_input_dtype = initial_interpreter.get_input_details()[0]['dtype']
|
||||
initial_output_dtype = initial_interpreter.get_output_details()[0]['dtype']
|
||||
final_input_dtype = final_interpreter.get_input_details()[0]['dtype']
|
||||
final_output_dtype = final_interpreter.get_output_details()[0]['dtype']
|
||||
|
||||
# Validate the model interfaces
|
||||
self.assertEqual(initial_input_dtype, np.float32)
|
||||
self.assertEqual(initial_output_dtype, np.float32)
|
||||
self.assertEqual(final_input_dtype, np.uint8)
|
||||
self.assertEqual(final_output_dtype, np.uint8)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test.main()
|
Loading…
Reference in New Issue
Block a user