# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Creates TOCO options to process a model."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
import tempfile
import traceback

import numpy as np
import tensorflow.compat.v1 as tf

from tensorflow.lite.testing import zip_test_utils


def toco_options(data_types,
                 input_arrays,
                 output_arrays,
                 shapes,
                 extra_toco_options=None):
  """Create TOCO options to process a model.

  Args:
    data_types: input and inference types used by TOCO.
    input_arrays: names of the input tensors.
    output_arrays: names of the output tensors.
    shapes: shapes of the input tensors.
    extra_toco_options: additional TOCO options.

  Returns:
    The options in a string.
  """
  if extra_toco_options is None:
    extra_toco_options = zip_test_utils.ExtraTocoOptions()

  shape_str = ":".join([",".join(str(y) for y in x) for x in shapes if x])
  inference_type = "FLOAT"
  # TODO(ahentz): if we get multi-input quantization to work we need this
  # to change
  if data_types[0] == "QUANTIZED_UINT8":
    inference_type = "QUANTIZED_UINT8"
  s = (" --input_data_types=%s" % ",".join(data_types) +
       " --inference_type=%s" % inference_type +
       " --input_format=TENSORFLOW_GRAPHDEF" + " --output_format=TFLITE" +
       " --input_arrays=%s" % ",".join(input_arrays) +
       " --output_arrays=%s" % ",".join(output_arrays))
  if shape_str:
    s += (" --input_shapes=%s" % shape_str)
  if extra_toco_options.drop_control_dependency:
    s += " --drop_control_dependency"
  if extra_toco_options.allow_custom_ops:
    s += " --allow_custom_ops"
  if extra_toco_options.rnn_states:
    s += (" --rnn_states='" + extra_toco_options.rnn_states + "'")
  if extra_toco_options.split_tflite_lstm_inputs is not None:
    if extra_toco_options.split_tflite_lstm_inputs:
      s += " --split_tflite_lstm_inputs=true"
    else:
      s += " --split_tflite_lstm_inputs=false"
  return s
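

# A minimal usage sketch (illustrative only; the tensor names and shape below
# are hypothetical, not taken from a real test):
#
#   toco_options(
#       data_types=["FLOAT"],
#       input_arrays=["input"],
#       output_arrays=["output"],
#       shapes=[[1, 224, 224, 3]])
#
# returns a flag string along the lines of
#   " --input_data_types=FLOAT --inference_type=FLOAT"
#   " --input_format=TENSORFLOW_GRAPHDEF --output_format=TFLITE"
#   " --input_arrays=input --output_arrays=output --input_shapes=1,224,224,3"
# (a single string, wrapped here for readability).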


def toco_convert(options, graph_def, input_tensors, output_tensors, **kwargs):
  """Convert a model's graph def into a tflite model.

  NOTE: this currently shells out to the toco binary, but we would like to
  convert to Python API tooling in the future.

  Args:
    options: An Options instance.
    graph_def: A GraphDef object.
    input_tensors: List of input tensor tuples `(name, shape, type)`.
    output_tensors: List of output tensors (names).
    **kwargs: Extra options to be passed.

  Returns:
    A `(tflite_model, log_txt)` tuple from the conversion, or `(None, log_txt)`
    if the model did not convert properly.
  """
  # Convert ophint ops if present.
  graph_def = tf.compat.v1.lite.experimental.convert_op_hints_to_stubs(
      graph_def=graph_def)
  graph_def_str = graph_def.SerializeToString()

  extra_toco_options = kwargs.get("extra_toco_options",
                                  zip_test_utils.ExtraTocoOptions())
  test_params = kwargs.get("test_params", {})
  input_arrays = [x[0] for x in input_tensors]
  data_types = [zip_test_utils.TF_TYPE_INFO[x[2]][1] for x in input_tensors]

  fully_quantize = test_params.get("fully_quantize", False)
  dynamic_range_quantize = test_params.get("dynamic_range_quantize", False)
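  # test_params decides which path we take: dynamic-range/full quantization
  # goes through the Python TFLiteConverter API below, while everything else
  # shells out to the standalone toco binary in the `else` branch.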
  if dynamic_range_quantize or fully_quantize:
    with tempfile.NamedTemporaryFile() as graphdef_file:
      graphdef_file.write(graph_def_str)
      graphdef_file.flush()

      input_shapes = zip_test_utils.get_input_shapes_map(input_tensors)
      converter = tf.compat.v1.lite.TFLiteConverter.from_frozen_graph(
          graphdef_file.name, input_arrays, output_tensors, input_shapes)

      converter.optimizations = [tf.lite.Optimize.DEFAULT]

      if fully_quantize:
        # Read the input range for the representative dataset from parameters.
        min_value, max_value = test_params.get("input_range", (-1, 1))

        def representative_dataset(input_tensors):
          calibration_inputs = []
          for _, shape, _ in input_tensors:
            if shape:
              dims = [dim.value for dim in shape.dims]
              calibration_inputs.append(
                  np.random.uniform(min_value, max_value,
                                    tuple(dims)).astype(np.float32))
          return calibration_inputs
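
        # representative_dataset_gen is the calibration hook handed to the
        # converter: it yields 100 batches of random inputs drawn uniformly
        # from the configured input range.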
        def representative_dataset_gen():
          for _ in range(100):
            yield representative_dataset(input_tensors)
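
        # Restricting the target ops to the int8 builtins (together with the
        # representative dataset) requests a fully quantized model rather than
        # a float fallback.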
        converter.target_spec.supported_ops = [
            tf.lite.OpsSet.TFLITE_BUILTINS_INT8
        ]
        converter.representative_dataset = representative_dataset_gen
        if extra_toco_options.inference_input_type:
          converter.inference_input_type = (
              extra_toco_options.inference_input_type)
        if extra_toco_options.inference_output_type:
          converter.inference_output_type = (
              extra_toco_options.inference_output_type)
        else:
          converter.inference_output_type = tf.int8

      try:
        tflite_model = converter.convert()
        return tflite_model, ""
      except Exception as e:
        log = "{0}\n{1}".format(str(e), traceback.format_exc())
        return None, log

  else:
    opts = toco_options(
        data_types=data_types,
        input_arrays=input_arrays,
        shapes=[x[1] for x in input_tensors],
        output_arrays=output_tensors,
        extra_toco_options=extra_toco_options)

    with tempfile.NamedTemporaryFile() as graphdef_file, \
         tempfile.NamedTemporaryFile() as output_file, \
         tempfile.NamedTemporaryFile("w+") as stdout_file:
      graphdef_file.write(graph_def_str)
      graphdef_file.flush()

      # TODO(aselle): Switch this to subprocess at some point.
      if options.run_with_flex:
        opts += " --enable_select_tf_ops --force_select_tf_ops"
      cmd = ("%s --input_file=%s --output_file=%s %s > %s 2>&1" %
             (options.toco, graphdef_file.name, output_file.name, opts,
              stdout_file.name))
      exit_code = os.system(cmd)
      log = (
          cmd + " exited with code %d" % exit_code + "\n------------------\n" +
          stdout_file.read())
      return (None if exit_code != 0 else output_file.read()), log
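

# A rough usage sketch (illustrative only; `my_options` stands for the test
# harness Options object, which must provide at least the `toco` binary path
# and the `run_with_flex` flag used above, and the tensor names/shapes are
# hypothetical):
#
#   tflite_model, log = toco_convert(
#       my_options,
#       graph_def,
#       input_tensors=[("input", (1, 8, 8, 3), tf.float32)],
#       output_tensors=["output"],
#       test_params={"fully_quantize": False})
#   if tflite_model is None:
#     print("Conversion failed:\n" + log)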