Add more model support to TocoConverter.
PiperOrigin-RevId: 210643904
This commit is contained in:
parent
c4099e6ee8
commit
2e7352e57c
@ -70,7 +70,7 @@ py_library(
|
||||
py_test(
|
||||
name = "lite_test",
|
||||
srcs = ["lite_test.py"],
|
||||
data = [":interpreter_test_data"],
|
||||
data = ["@tflite_mobilenet_ssd_quant_protobuf//:tflite_graph.pbtxt"],
|
||||
srcs_version = "PY2AND3",
|
||||
tags = [
|
||||
"no_oss",
|
||||
@ -130,6 +130,7 @@ py_test(
|
||||
],
|
||||
deps = [
|
||||
":convert",
|
||||
":interpreter",
|
||||
":op_hint",
|
||||
"//tensorflow/python:array_ops",
|
||||
"//tensorflow/python:client_testlib",
|
||||
|
@ -226,6 +226,54 @@ def build_toco_convert_protos(input_tensors,
|
||||
return model, toco
|
||||
|
||||
|
||||
def toco_convert_graph_def(input_data, input_arrays_with_shape, output_arrays,
|
||||
*args, **kwargs):
|
||||
""""Convert a model using TOCO.
|
||||
|
||||
This function is used to convert GraphDefs that cannot be loaded into
|
||||
TensorFlow to TFLite. Conversion can be customized by providing arguments
|
||||
that are forwarded to `build_toco_convert_protos` (see documentation for
|
||||
details).
|
||||
|
||||
Args:
|
||||
input_data: Input data (i.e. often `sess.graph_def`),
|
||||
input_arrays_with_shape: Tuple of strings representing input tensor names
|
||||
and list of integers representing input shapes
|
||||
(e.g., [("foo" : [1, 16, 16, 3])]). Use only when graph cannot be loaded
|
||||
into TensorFlow and when `input_tensors` is None. (default None)
|
||||
output_arrays: List of output tensors to freeze graph with. Use only when
|
||||
graph cannot be loaded into TensorFlow and when `output_tensors` is None.
|
||||
(default None)
|
||||
*args: See `build_toco_convert_protos`,
|
||||
**kwargs: See `build_toco_convert_protos`.
|
||||
|
||||
Returns:
|
||||
The converted data. For example if TFLite was the destination, then
|
||||
this will be a tflite flatbuffer in a bytes array.
|
||||
|
||||
Raises:
|
||||
Defined in `build_toco_convert_protos`.
|
||||
"""
|
||||
model_flags, toco_flags = build_toco_convert_protos(
|
||||
input_tensors=[], output_tensors=[], *args, **kwargs)
|
||||
|
||||
for idx, (name, shape) in enumerate(input_arrays_with_shape):
|
||||
input_array = model_flags.input_arrays.add()
|
||||
if kwargs["inference_type"] == lite_constants.QUANTIZED_UINT8:
|
||||
input_array.mean_value, input_array.std_value = kwargs[
|
||||
"quantized_input_stats"][idx]
|
||||
input_array.name = name
|
||||
input_array.shape.dims.extend(map(int, shape))
|
||||
|
||||
for name in output_arrays:
|
||||
model_flags.output_arrays.append(name)
|
||||
|
||||
data = toco_convert_protos(model_flags.SerializeToString(),
|
||||
toco_flags.SerializeToString(),
|
||||
input_data.SerializeToString())
|
||||
return data
|
||||
|
||||
|
||||
def toco_convert_impl(input_data, input_tensors, output_tensors, *args,
|
||||
**kwargs):
|
||||
""""Convert a model using TOCO.
|
||||
|
@ -17,9 +17,12 @@ from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import numpy as np
|
||||
|
||||
from tensorflow.contrib.lite.python import convert
|
||||
from tensorflow.contrib.lite.python import lite_constants
|
||||
from tensorflow.contrib.lite.python import op_hint
|
||||
from tensorflow.contrib.lite.python.interpreter import Interpreter
|
||||
from tensorflow.python.client import session
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import test_util
|
||||
@ -37,9 +40,12 @@ class ConvertTest(test_util.TensorFlowTestCase):
|
||||
dtype=dtypes.float32)
|
||||
out_tensor = in_tensor + in_tensor
|
||||
sess = session.Session()
|
||||
|
||||
# Try running on valid graph
|
||||
result = convert.toco_convert(sess.graph_def, [in_tensor], [out_tensor])
|
||||
self.assertTrue(result)
|
||||
tflite_model = convert.toco_convert(sess.graph_def, [in_tensor],
|
||||
[out_tensor])
|
||||
self.assertTrue(tflite_model)
|
||||
|
||||
# TODO(aselle): remove tests that fail (we must get TOCO to not fatal
|
||||
# all the time).
|
||||
# Try running on identity graph (known fail)
|
||||
@ -52,11 +58,85 @@ class ConvertTest(test_util.TensorFlowTestCase):
|
||||
out_tensor = array_ops.fake_quant_with_min_max_args(in_tensor + in_tensor,
|
||||
min=0., max=1.)
|
||||
sess = session.Session()
|
||||
result = convert.toco_convert(
|
||||
|
||||
tflite_model = convert.toco_convert(
|
||||
sess.graph_def, [in_tensor], [out_tensor],
|
||||
inference_type=lite_constants.QUANTIZED_UINT8,
|
||||
quantized_input_stats=[(0., 1.)])
|
||||
self.assertTrue(result)
|
||||
self.assertTrue(tflite_model)
|
||||
|
||||
def testGraphDefBasic(self):
|
||||
in_tensor = array_ops.placeholder(
|
||||
shape=[1, 16, 16, 3], dtype=dtypes.float32, name="input")
|
||||
_ = in_tensor + in_tensor
|
||||
sess = session.Session()
|
||||
|
||||
tflite_model = convert.toco_convert_graph_def(
|
||||
sess.graph_def, [("input", [1, 16, 16, 3])], ["add"],
|
||||
inference_type=lite_constants.FLOAT)
|
||||
self.assertTrue(tflite_model)
|
||||
|
||||
# Check values from converted model.
|
||||
interpreter = Interpreter(model_content=tflite_model)
|
||||
interpreter.allocate_tensors()
|
||||
|
||||
input_details = interpreter.get_input_details()
|
||||
self.assertEqual(1, len(input_details))
|
||||
self.assertEqual("input", input_details[0]["name"])
|
||||
self.assertEqual(np.float32, input_details[0]["dtype"])
|
||||
self.assertTrue(([1, 16, 16, 3] == input_details[0]["shape"]).all())
|
||||
self.assertEqual((0., 0.), input_details[0]["quantization"])
|
||||
|
||||
output_details = interpreter.get_output_details()
|
||||
self.assertEqual(1, len(output_details))
|
||||
self.assertEqual("add", output_details[0]["name"])
|
||||
self.assertEqual(np.float32, output_details[0]["dtype"])
|
||||
self.assertTrue(([1, 16, 16, 3] == output_details[0]["shape"]).all())
|
||||
self.assertEqual((0., 0.), output_details[0]["quantization"])
|
||||
|
||||
def testGraphDefQuantization(self):
|
||||
in_tensor_1 = array_ops.placeholder(
|
||||
shape=[1, 16, 16, 3], dtype=dtypes.float32, name="inputA")
|
||||
in_tensor_2 = array_ops.placeholder(
|
||||
shape=[1, 16, 16, 3], dtype=dtypes.float32, name="inputB")
|
||||
_ = array_ops.fake_quant_with_min_max_args(
|
||||
in_tensor_1 + in_tensor_2, min=0., max=1., name="output")
|
||||
sess = session.Session()
|
||||
|
||||
input_arrays_map = [("inputA", [1, 16, 16, 3]), ("inputB", [1, 16, 16, 3])]
|
||||
output_arrays = ["output"]
|
||||
tflite_model = convert.toco_convert_graph_def(
|
||||
sess.graph_def,
|
||||
input_arrays_map,
|
||||
output_arrays,
|
||||
inference_type=lite_constants.QUANTIZED_UINT8,
|
||||
quantized_input_stats=[(0., 1.), (0., 1.)])
|
||||
self.assertTrue(tflite_model)
|
||||
|
||||
# Check values from converted model.
|
||||
interpreter = Interpreter(model_content=tflite_model)
|
||||
interpreter.allocate_tensors()
|
||||
|
||||
input_details = interpreter.get_input_details()
|
||||
self.assertEqual(2, len(input_details))
|
||||
self.assertEqual("inputA", input_details[0]["name"])
|
||||
self.assertEqual(np.uint8, input_details[0]["dtype"])
|
||||
self.assertTrue(([1, 16, 16, 3] == input_details[0]["shape"]).all())
|
||||
self.assertEqual((1., 0.),
|
||||
input_details[0]["quantization"]) # scale, zero_point
|
||||
|
||||
self.assertEqual("inputB", input_details[1]["name"])
|
||||
self.assertEqual(np.uint8, input_details[1]["dtype"])
|
||||
self.assertTrue(([1, 16, 16, 3] == input_details[1]["shape"]).all())
|
||||
self.assertEqual((1., 0.),
|
||||
input_details[1]["quantization"]) # scale, zero_point
|
||||
|
||||
output_details = interpreter.get_output_details()
|
||||
self.assertEqual(1, len(output_details))
|
||||
self.assertEqual("output", output_details[0]["name"])
|
||||
self.assertEqual(np.uint8, output_details[0]["dtype"])
|
||||
self.assertTrue(([1, 16, 16, 3] == output_details[0]["shape"]).all())
|
||||
self.assertTrue(output_details[0]["quantization"][0] > 0) # scale
|
||||
|
||||
|
||||
class ConvertTestOpHint(test_util.TensorFlowTestCase):
|
||||
@ -243,7 +323,6 @@ class ConvertTestOpHint(test_util.TensorFlowTestCase):
|
||||
with self.test_session() as sess:
|
||||
stubbed_graphdef = op_hint.convert_op_hints_to_stubs(
|
||||
graph_def=sess.graph_def)
|
||||
print(stubbed_graphdef)
|
||||
self.assertCountEqual(
|
||||
self._getGraphOpTypes(
|
||||
stubbed_graphdef,
|
||||
|
@ -42,6 +42,7 @@ from tensorflow.contrib.lite.python import lite_constants as constants
|
||||
from tensorflow.contrib.lite.python.convert import build_toco_convert_protos # pylint: disable=unused-import
|
||||
from tensorflow.contrib.lite.python.convert import tensor_name as _tensor_name
|
||||
from tensorflow.contrib.lite.python.convert import toco_convert # pylint: disable=unused-import
|
||||
from tensorflow.contrib.lite.python.convert import toco_convert_graph_def as _toco_convert_graph_def
|
||||
from tensorflow.contrib.lite.python.convert import toco_convert_impl as _toco_convert_impl
|
||||
from tensorflow.contrib.lite.python.convert import toco_convert_protos # pylint: disable=unused-import
|
||||
from tensorflow.contrib.lite.python.convert_saved_model import freeze_saved_model as _freeze_saved_model
|
||||
@ -55,6 +56,7 @@ from tensorflow.python import keras as _keras
|
||||
from tensorflow.python.client import session as _session
|
||||
from tensorflow.python.framework import graph_util as _tf_graph_util
|
||||
from tensorflow.python.framework import ops as _ops
|
||||
from tensorflow.python.framework.errors_impl import NotFoundError as _NotFoundError
|
||||
from tensorflow.python.framework.importer import import_graph_def as _import_graph_def
|
||||
from tensorflow.python.saved_model import signature_constants as _signature_constants
|
||||
from tensorflow.python.saved_model import tag_constants as _tag_constants
|
||||
@ -133,7 +135,12 @@ class TocoConverter(object):
|
||||
```
|
||||
"""
|
||||
|
||||
def __init__(self, graph_def, input_tensors, output_tensors):
|
||||
def __init__(self,
|
||||
graph_def,
|
||||
input_tensors,
|
||||
output_tensors,
|
||||
input_arrays_with_shape=None,
|
||||
output_arrays=None):
|
||||
"""Constructor for TocoConverter.
|
||||
|
||||
Args:
|
||||
@ -142,6 +149,17 @@ class TocoConverter(object):
|
||||
input_tensors: List of input tensors. Type and shape are computed using
|
||||
`foo.get_shape()` and `foo.dtype`.
|
||||
output_tensors: List of output tensors (only .name is used from this).
|
||||
input_arrays_with_shape: Tuple of strings representing input tensor names
|
||||
and list of integers representing input shapes
|
||||
(e.g., [("foo" : [1, 16, 16, 3])]). Use only when graph cannot be loaded
|
||||
into TensorFlow and when `input_tensors` and `output_tensors` are None.
|
||||
(default None)
|
||||
output_arrays: List of output tensors to freeze graph with. Use only when
|
||||
graph cannot be loaded into TensorFlow and when `input_tensors` and
|
||||
`output_tensors` are None. (default None)
|
||||
|
||||
Raises:
|
||||
ValueError: Invalid arguments.
|
||||
"""
|
||||
self._graph_def = graph_def
|
||||
self._input_tensors = input_tensors
|
||||
@ -159,6 +177,15 @@ class TocoConverter(object):
|
||||
self.dump_graphviz_dir = None
|
||||
self.dump_graphviz_video = False
|
||||
|
||||
# Attributes are used by models that cannot be loaded into TensorFlow.
|
||||
if not self._has_valid_tensors():
|
||||
if not input_arrays_with_shape or not output_arrays:
|
||||
raise ValueError(
|
||||
"If input_tensors and output_tensors are None, both "
|
||||
"input_arrays_with_shape and output_arrays must be defined.")
|
||||
self._input_arrays_with_shape = input_arrays_with_shape
|
||||
self._output_arrays = output_arrays
|
||||
|
||||
@classmethod
|
||||
def from_session(cls, sess, input_tensors, output_tensors):
|
||||
"""Creates a TocoConverter class from a TensorFlow Session.
|
||||
@ -200,6 +227,7 @@ class TocoConverter(object):
|
||||
Unable to parse input file.
|
||||
The graph is not frozen.
|
||||
input_arrays or output_arrays contains an invalid tensor name.
|
||||
input_shapes is not correctly defined when required
|
||||
"""
|
||||
with _ops.Graph().as_default():
|
||||
with _session.Session() as sess:
|
||||
@ -222,20 +250,44 @@ class TocoConverter(object):
|
||||
except (_text_format.ParseError, DecodeError):
|
||||
raise ValueError(
|
||||
"Unable to parse input file '{}'.".format(graph_def_file))
|
||||
_import_graph_def(graph_def, name="")
|
||||
|
||||
# Get input and output tensors.
|
||||
input_tensors = _get_tensors_from_tensor_names(sess.graph, input_arrays)
|
||||
output_tensors = _get_tensors_from_tensor_names(sess.graph,
|
||||
output_arrays)
|
||||
_set_tensor_shapes(input_tensors, input_shapes)
|
||||
# Handles models with custom TFLite ops that cannot be resolved in
|
||||
# TensorFlow.
|
||||
load_model_in_session = True
|
||||
try:
|
||||
_import_graph_def(graph_def, name="")
|
||||
except _NotFoundError:
|
||||
load_model_in_session = False
|
||||
|
||||
# Check if graph is frozen.
|
||||
if not _is_frozen_graph(sess):
|
||||
raise ValueError("Please freeze the graph using freeze_graph.py.")
|
||||
if load_model_in_session:
|
||||
# Check if graph is frozen.
|
||||
if not _is_frozen_graph(sess):
|
||||
raise ValueError("Please freeze the graph using freeze_graph.py.")
|
||||
|
||||
# Create TocoConverter class.
|
||||
return cls(sess.graph_def, input_tensors, output_tensors)
|
||||
# Get input and output tensors.
|
||||
input_tensors = _get_tensors_from_tensor_names(
|
||||
sess.graph, input_arrays)
|
||||
output_tensors = _get_tensors_from_tensor_names(
|
||||
sess.graph, output_arrays)
|
||||
_set_tensor_shapes(input_tensors, input_shapes)
|
||||
|
||||
return cls(sess.graph_def, input_tensors, output_tensors)
|
||||
else:
|
||||
if not input_shapes:
|
||||
raise ValueError("input_shapes must be defined for this model.")
|
||||
if set(input_arrays) != set(input_shapes.keys()):
|
||||
raise ValueError("input_shapes must contain a value for each item "
|
||||
"in input_array.")
|
||||
|
||||
input_arrays_with_shape = [
|
||||
(name, input_shapes[name]) for name in input_arrays
|
||||
]
|
||||
return cls(
|
||||
graph_def,
|
||||
input_tensors=None,
|
||||
output_tensors=None,
|
||||
input_arrays_with_shape=input_arrays_with_shape,
|
||||
output_arrays=output_arrays)
|
||||
|
||||
@classmethod
|
||||
def from_saved_model(cls,
|
||||
@ -330,25 +382,25 @@ class TocoConverter(object):
|
||||
None value for dimension in input_tensor.
|
||||
"""
|
||||
# Checks dimensions in input tensor.
|
||||
for tensor in self._input_tensors:
|
||||
if not tensor.get_shape():
|
||||
raise ValueError("Provide an input shape for input array '{0}'.".format(
|
||||
_tensor_name(tensor)))
|
||||
shape = tensor.get_shape().as_list()
|
||||
if None in shape[1:]:
|
||||
raise ValueError(
|
||||
"None is only supported in the 1st dimension. Tensor '{0}' has "
|
||||
"invalid shape '{1}'.".format(_tensor_name(tensor), shape))
|
||||
elif shape[0] is None:
|
||||
self._set_batch_size(batch_size=1)
|
||||
if self._has_valid_tensors():
|
||||
for tensor in self._input_tensors:
|
||||
if not tensor.get_shape():
|
||||
raise ValueError("Provide an input shape for input array "
|
||||
"'{0}'.".format(_tensor_name(tensor)))
|
||||
shape = tensor.get_shape().as_list()
|
||||
if None in shape[1:]:
|
||||
raise ValueError(
|
||||
"None is only supported in the 1st dimension. Tensor '{0}' has "
|
||||
"invalid shape '{1}'.".format(_tensor_name(tensor), shape))
|
||||
elif shape[0] is None:
|
||||
self._set_batch_size(batch_size=1)
|
||||
|
||||
# Get quantization stats. Ensures there is one stat per name if the stats
|
||||
# are specified.
|
||||
if self.quantized_input_stats:
|
||||
quantized_stats = []
|
||||
invalid_stats = []
|
||||
for tensor in self._input_tensors:
|
||||
name = _tensor_name(tensor)
|
||||
for name in self.get_input_arrays():
|
||||
if name in self.quantized_input_stats:
|
||||
quantized_stats.append(self.quantized_input_stats[name])
|
||||
else:
|
||||
@ -360,24 +412,35 @@ class TocoConverter(object):
|
||||
else:
|
||||
quantized_stats = None
|
||||
|
||||
converter_kwargs = {
|
||||
"inference_type": self.inference_type,
|
||||
"inference_input_type": self.inference_input_type,
|
||||
"input_format": constants.TENSORFLOW_GRAPHDEF,
|
||||
"output_format": self.output_format,
|
||||
"quantized_input_stats": quantized_stats,
|
||||
"default_ranges_stats": self.default_ranges_stats,
|
||||
"drop_control_dependency": self.drop_control_dependency,
|
||||
"reorder_across_fake_quant": self.reorder_across_fake_quant,
|
||||
"change_concat_input_ranges": self.change_concat_input_ranges,
|
||||
"allow_custom_ops": self.allow_custom_ops,
|
||||
"quantize_weights": self.quantize_weights,
|
||||
"dump_graphviz_dir": self.dump_graphviz_dir,
|
||||
"dump_graphviz_video": self.dump_graphviz_video
|
||||
}
|
||||
|
||||
# Converts model.
|
||||
result = _toco_convert_impl(
|
||||
input_data=self._graph_def,
|
||||
input_tensors=self._input_tensors,
|
||||
output_tensors=self._output_tensors,
|
||||
inference_type=self.inference_type,
|
||||
inference_input_type=self.inference_input_type,
|
||||
input_format=constants.TENSORFLOW_GRAPHDEF,
|
||||
output_format=self.output_format,
|
||||
quantized_input_stats=quantized_stats,
|
||||
default_ranges_stats=self.default_ranges_stats,
|
||||
drop_control_dependency=self.drop_control_dependency,
|
||||
reorder_across_fake_quant=self.reorder_across_fake_quant,
|
||||
change_concat_input_ranges=self.change_concat_input_ranges,
|
||||
allow_custom_ops=self.allow_custom_ops,
|
||||
quantize_weights=self.quantize_weights,
|
||||
dump_graphviz_dir=self.dump_graphviz_dir,
|
||||
dump_graphviz_video=self.dump_graphviz_video)
|
||||
if self._has_valid_tensors():
|
||||
result = _toco_convert_impl(
|
||||
input_data=self._graph_def,
|
||||
input_tensors=self._input_tensors,
|
||||
output_tensors=self._output_tensors,
|
||||
**converter_kwargs)
|
||||
else:
|
||||
result = _toco_convert_graph_def(
|
||||
input_data=self._graph_def,
|
||||
input_arrays_with_shape=self._input_arrays_with_shape,
|
||||
output_arrays=self._output_arrays,
|
||||
**converter_kwargs)
|
||||
return result
|
||||
|
||||
def get_input_arrays(self):
|
||||
@ -386,7 +449,18 @@ class TocoConverter(object):
|
||||
Returns:
|
||||
List of strings.
|
||||
"""
|
||||
return [_tensor_name(tensor) for tensor in self._input_tensors]
|
||||
if self._has_valid_tensors():
|
||||
return [_tensor_name(tensor) for tensor in self._input_tensors]
|
||||
else:
|
||||
return [name for name, _ in self._input_arrays_with_shape]
|
||||
|
||||
def _has_valid_tensors(self):
|
||||
"""Checks if the input and output tensors have been initialized.
|
||||
|
||||
Returns:
|
||||
Bool.
|
||||
"""
|
||||
return self._input_tensors and self._output_tensors
|
||||
|
||||
def _set_batch_size(self, batch_size):
|
||||
"""Sets the first dimension of the input tensor to `batch_size`.
|
||||
@ -394,7 +468,14 @@ class TocoConverter(object):
|
||||
Args:
|
||||
batch_size: Batch size for the model. Replaces the first dimension of an
|
||||
input size array if undefined. (default 1)
|
||||
|
||||
Raises:
|
||||
ValueError: input_tensor is not defined.
|
||||
"""
|
||||
if not self._has_valid_tensors():
|
||||
raise ValueError("The batch size cannot be set for this model. Please "
|
||||
"use input_shapes parameter.")
|
||||
|
||||
for tensor in self._input_tensors:
|
||||
shape = tensor.get_shape().as_list()
|
||||
shape[0] = batch_size
|
||||
|
@ -35,11 +35,51 @@ from tensorflow.python.ops import math_ops
|
||||
from tensorflow.python.ops import variable_scope
|
||||
from tensorflow.python.ops.variables import global_variables_initializer as _global_variables_initializer
|
||||
from tensorflow.python.platform import gfile
|
||||
from tensorflow.python.platform import resource_loader
|
||||
from tensorflow.python.platform import test
|
||||
from tensorflow.python.saved_model import saved_model
|
||||
from tensorflow.python.training.training_util import write_graph
|
||||
|
||||
|
||||
class FromConstructor(test_util.TensorFlowTestCase):
|
||||
|
||||
# Tests invalid constructors using a dummy value for the GraphDef.
|
||||
def testInvalidConstructor(self):
|
||||
message = ('If input_tensors and output_tensors are None, both '
|
||||
'input_arrays_with_shape and output_arrays must be defined.')
|
||||
|
||||
# `output_arrays` is not defined.
|
||||
with self.assertRaises(ValueError) as error:
|
||||
lite.TocoConverter(
|
||||
None, None, [], input_arrays_with_shape=[('input', [3, 9])])
|
||||
self.assertEqual(message, str(error.exception))
|
||||
|
||||
# `input_arrays_with_shape` is not defined.
|
||||
with self.assertRaises(ValueError) as error:
|
||||
lite.TocoConverter(None, [], None, output_arrays=['output'])
|
||||
self.assertEqual(message, str(error.exception))
|
||||
|
||||
# Tests valid constructors using a dummy value for the GraphDef.
|
||||
def testValidConstructor(self):
|
||||
converter = lite.TocoConverter(
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
input_arrays_with_shape=[('input', [3, 9])],
|
||||
output_arrays=['output'])
|
||||
self.assertFalse(converter._has_valid_tensors())
|
||||
self.assertEqual(converter.get_input_arrays(), ['input'])
|
||||
|
||||
with self.assertRaises(ValueError) as error:
|
||||
converter._set_batch_size(1)
|
||||
self.assertEqual(
|
||||
'The batch size cannot be set for this model. Please use '
|
||||
'input_shapes parameter.', str(error.exception))
|
||||
|
||||
converter = lite.TocoConverter(None, ['input_tensor'], ['output_tensor'])
|
||||
self.assertTrue(converter._has_valid_tensors())
|
||||
|
||||
|
||||
class FromSessionTest(test_util.TensorFlowTestCase):
|
||||
|
||||
def testFloat(self):
|
||||
@ -490,6 +530,79 @@ class FromFrozenGraphFile(test_util.TensorFlowTestCase):
|
||||
'Unable to parse input file \'{}\'.'.format(graph_def_file),
|
||||
str(error.exception))
|
||||
|
||||
# TODO(nupurgarg): Test model loading in open source.
|
||||
def _initObjectDetectionArgs(self):
|
||||
# Initializes the arguments required for the object detection model.
|
||||
self._graph_def_file = resource_loader.get_path_to_datafile(
|
||||
'testdata/tflite_graph.pbtxt')
|
||||
self._input_arrays = ['normalized_input_image_tensor']
|
||||
self._output_arrays = [
|
||||
'TFLite_Detection_PostProcess', 'TFLite_Detection_PostProcess:1',
|
||||
'TFLite_Detection_PostProcess:2', 'TFLite_Detection_PostProcess:3'
|
||||
]
|
||||
self._input_shapes = {'normalized_input_image_tensor': [1, 300, 300, 3]}
|
||||
|
||||
def testTFLiteGraphDef(self):
|
||||
# Tests the object detection model that cannot be loaded in TensorFlow.
|
||||
self._initObjectDetectionArgs()
|
||||
|
||||
converter = lite.TocoConverter.from_frozen_graph(
|
||||
self._graph_def_file, self._input_arrays, self._output_arrays,
|
||||
self._input_shapes)
|
||||
converter.allow_custom_ops = True
|
||||
tflite_model = converter.convert()
|
||||
self.assertTrue(tflite_model)
|
||||
|
||||
# Check values from converted model.
|
||||
interpreter = Interpreter(model_content=tflite_model)
|
||||
interpreter.allocate_tensors()
|
||||
|
||||
input_details = interpreter.get_input_details()
|
||||
self.assertEqual(1, len(input_details))
|
||||
self.assertEqual('normalized_input_image_tensor', input_details[0]['name'])
|
||||
self.assertEqual(np.float32, input_details[0]['dtype'])
|
||||
self.assertTrue(([1, 300, 300, 3] == input_details[0]['shape']).all())
|
||||
self.assertEqual((0., 0.), input_details[0]['quantization'])
|
||||
|
||||
output_details = interpreter.get_output_details()
|
||||
self.assertEqual(4, len(output_details))
|
||||
self.assertEqual('TFLite_Detection_PostProcess', output_details[0]['name'])
|
||||
self.assertEqual(np.float32, output_details[0]['dtype'])
|
||||
self.assertTrue(([1, 10, 4] == output_details[0]['shape']).all())
|
||||
self.assertEqual((0., 0.), output_details[0]['quantization'])
|
||||
|
||||
self.assertEqual('TFLite_Detection_PostProcess:1',
|
||||
output_details[1]['name'])
|
||||
self.assertTrue(([1, 10] == output_details[1]['shape']).all())
|
||||
self.assertEqual('TFLite_Detection_PostProcess:2',
|
||||
output_details[2]['name'])
|
||||
self.assertTrue(([1, 10] == output_details[2]['shape']).all())
|
||||
self.assertEqual('TFLite_Detection_PostProcess:3',
|
||||
output_details[3]['name'])
|
||||
self.assertTrue(([1] == output_details[3]['shape']).all())
|
||||
|
||||
def testTFLiteGraphDefInvalid(self):
|
||||
# Tests invalid cases for the model that cannot be loaded in TensorFlow.
|
||||
self._initObjectDetectionArgs()
|
||||
|
||||
# Missing `input_shapes`.
|
||||
with self.assertRaises(ValueError) as error:
|
||||
lite.TocoConverter.from_frozen_graph(
|
||||
self._graph_def_file, self._input_arrays, self._output_arrays)
|
||||
self.assertEqual('input_shapes must be defined for this model.',
|
||||
str(error.exception))
|
||||
|
||||
# `input_shapes` does not contain the names in `input_arrays`.
|
||||
with self.assertRaises(ValueError) as error:
|
||||
lite.TocoConverter.from_frozen_graph(
|
||||
self._graph_def_file,
|
||||
self._input_arrays,
|
||||
self._output_arrays,
|
||||
input_shapes={'invalid-value': [1, 19]})
|
||||
self.assertEqual(
|
||||
'input_shapes must contain a value for each item in input_array.',
|
||||
str(error.exception))
|
||||
|
||||
|
||||
class FromSavedModelTest(test_util.TensorFlowTestCase):
|
||||
|
||||
|
@ -132,7 +132,8 @@ def _convert_model(flags):
|
||||
if flags.reorder_across_fake_quant:
|
||||
converter.reorder_across_fake_quant = flags.reorder_across_fake_quant
|
||||
if flags.change_concat_input_ranges:
|
||||
converter.change_concat_input_ranges = flags.change_concat_input_ranges
|
||||
converter.change_concat_input_ranges = (
|
||||
flags.change_concat_input_ranges == "TRUE")
|
||||
if flags.allow_custom_ops:
|
||||
converter.allow_custom_ops = flags.allow_custom_ops
|
||||
if flags.quantize_weights:
|
||||
@ -333,9 +334,14 @@ def run_main(_):
|
||||
"the graph. Results in a graph that differs from the quantized "
|
||||
"training graph, potentially causing differing arithmetic "
|
||||
"behavior. (default False)"))
|
||||
# Usage for this flag is --change_concat_input_ranges=true or
|
||||
# --change_concat_input_ranges=false in order to make it clear what the flag
|
||||
# is set to. This keeps the usage consistent with other usages of the flag
|
||||
# where the default is different. The default value here is False.
|
||||
parser.add_argument(
|
||||
"--change_concat_input_ranges",
|
||||
action="store_true",
|
||||
type=str.upper,
|
||||
choices=["TRUE", "FALSE"],
|
||||
help=("Boolean to change behavior of min/max ranges for inputs and "
|
||||
"outputs of the concat operator for quantized models. Changes the "
|
||||
"ranges of concat operator overlap when true. (default False)"))
|
||||
|
@ -767,6 +767,7 @@ def tf_workspace(path_prefix = "", tf_repo_name = ""):
|
||||
],
|
||||
build_file = str(Label("//third_party:tflite_mobilenet.BUILD")),
|
||||
)
|
||||
|
||||
tf_http_archive(
|
||||
name = "tflite_mobilenet_ssd_quant",
|
||||
sha256 = "a809cd290b4d6a2e8a9d5dad076e0bd695b8091974e0eed1052b480b2f21b6dc",
|
||||
@ -777,6 +778,17 @@ def tf_workspace(path_prefix = "", tf_repo_name = ""):
|
||||
build_file = str(Label("//third_party:tflite_mobilenet.BUILD")),
|
||||
)
|
||||
|
||||
tf_http_archive(
|
||||
name = "tflite_mobilenet_ssd_quant_protobuf",
|
||||
sha256 = "09280972c5777f1aa775ef67cb4ac5d5ed21970acd8535aeca62450ef14f0d79",
|
||||
urls = [
|
||||
"https://mirror.bazel.build/storage.googleapis.com/download.tensorflow.org/models/object_detection/ssd_mobilenet_v1_quantized_300x300_coco14_sync_2018_07_18.tar.gz",
|
||||
"http://storage.googleapis.com/download.tensorflow.org/models/object_detection/ssd_mobilenet_v1_quantized_300x300_coco14_sync_2018_07_18.tar.gz",
|
||||
],
|
||||
strip_prefix = "ssd_mobilenet_v1_quantized_300x300_coco14_sync_2018_07_18",
|
||||
build_file = str(Label("//third_party:tflite_mobilenet.BUILD")),
|
||||
)
|
||||
|
||||
tf_http_archive(
|
||||
name = "tflite_conv_actions_frozen",
|
||||
sha256 = "d947b38cba389b5e2d0bfc3ea6cc49c784e187b41a071387b3742d1acac7691e",
|
||||
|
Loading…
x
Reference in New Issue
Block a user