The V1 converter requires inference_type, inference_input_type, and quantized_input_stats for conversion, whereas the V2 converter uses FQ ops inside the graph for conversion and input information. This CL does the following: 1. Moves the input_stats check from the convert code into the V1 converter, since that check is now specific to it. 2. Improves the condition checking for post-training calibrate vs. weight-only quantize vs. training-time quantize. 3. Actually handles training-time quantize by passing the necessary flags to TOCO. Important to note, this approach leaves open the option for both QAT and post-training calibrate quantize to be applied together in the same conversion. PiperOrigin-RevId: 298533518 Change-Id: I48ec5b8db8f20242522ca7af70dcbe339b79aa2f
398 lines
16 KiB
Python
398 lines
16 KiB
Python
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
# ==============================================================================
|
|
"""TensorFlow Lite Python Interface: Sanity check."""
|
|
from __future__ import absolute_import
|
|
from __future__ import division
|
|
from __future__ import print_function
|
|
|
|
import numpy as np
|
|
|
|
from tensorflow.lite.python import convert
|
|
from tensorflow.lite.python import lite_constants
|
|
from tensorflow.lite.python import op_hint
|
|
from tensorflow.lite.python.interpreter import Interpreter
|
|
from tensorflow.python.client import session
|
|
from tensorflow.python.framework import dtypes
|
|
from tensorflow.python.framework import ops
|
|
from tensorflow.python.framework import test_util
|
|
from tensorflow.python.framework.graph_util_impl import _bfs_for_reachable_nodes
|
|
from tensorflow.python.framework.graph_util_impl import _extract_graph_summary
|
|
from tensorflow.python.framework.graph_util_impl import _node_name
|
|
from tensorflow.python.ops import array_ops
|
|
from tensorflow.python.ops import math_ops
|
|
from tensorflow.python.platform import test
|
|
|
|
|
|
class ConvertTest(test_util.TensorFlowTestCase):
  """Sanity checks for the TOCO conversion entry points in `convert`."""

  def testBasic(self):
    """Converts a minimal float add graph via `toco_convert`."""
    with ops.Graph().as_default():
      in_tensor = array_ops.placeholder(
          shape=[1, 16, 16, 3], dtype=dtypes.float32)
      out_tensor = in_tensor + in_tensor
      sess = session.Session()

    # Try running on valid graph
    tflite_model = convert.toco_convert(sess.graph_def, [in_tensor],
                                        [out_tensor])
    self.assertTrue(tflite_model)

  def testQuantization(self):
    """Converts a fake-quantized graph with QUANTIZED_UINT8 inference."""
    with ops.Graph().as_default():
      in_tensor = array_ops.placeholder(
          shape=[1, 16, 16, 3], dtype=dtypes.float32)
      out_tensor = array_ops.fake_quant_with_min_max_args(
          in_tensor + in_tensor, min=0., max=1.)
      sess = session.Session()

    # quantized_input_stats provides the (mean, std_dev) pair that the
    # invalid-case test below shows is mandatory for quantized inference.
    tflite_model = convert.toco_convert(
        sess.graph_def, [in_tensor], [out_tensor],
        inference_type=lite_constants.QUANTIZED_UINT8,
        quantized_input_stats=[(0., 1.)])
    self.assertTrue(tflite_model)

  def testGraphDefBasic(self):
    """Converts a GraphDef by array names and validates model input/output."""
    with ops.Graph().as_default():
      in_tensor = array_ops.placeholder(
          shape=[1, 16, 16, 3], dtype=dtypes.float32, name="input")
      _ = in_tensor + in_tensor
      sess = session.Session()

    tflite_model = convert.toco_convert_graph_def(
        sess.graph_def, [("input", [1, 16, 16, 3])], ["add"],
        enable_mlir_converter=False,
        inference_type=lite_constants.FLOAT)
    self.assertTrue(tflite_model)

    # Check values from converted model.
    interpreter = Interpreter(model_content=tflite_model)
    interpreter.allocate_tensors()

    input_details = interpreter.get_input_details()
    self.assertEqual(1, len(input_details))
    self.assertEqual("input", input_details[0]["name"])
    self.assertEqual(np.float32, input_details[0]["dtype"])
    self.assertTrue(([1, 16, 16, 3] == input_details[0]["shape"]).all())
    self.assertEqual((0., 0.), input_details[0]["quantization"])

    output_details = interpreter.get_output_details()
    self.assertEqual(1, len(output_details))
    self.assertEqual("add", output_details[0]["name"])
    self.assertEqual(np.float32, output_details[0]["dtype"])
    self.assertTrue(([1, 16, 16, 3] == output_details[0]["shape"]).all())
    self.assertEqual((0., 0.), output_details[0]["quantization"])

  def testGraphDefQuantization(self):
    """Converts a two-input fake-quantized GraphDef and checks uint8 I/O."""
    with ops.Graph().as_default():
      in_tensor_1 = array_ops.placeholder(
          shape=[1, 16, 16, 3], dtype=dtypes.float32, name="inputA")
      in_tensor_2 = array_ops.placeholder(
          shape=[1, 16, 16, 3], dtype=dtypes.float32, name="inputB")
      _ = array_ops.fake_quant_with_min_max_args(
          in_tensor_1 + in_tensor_2, min=0., max=1., name="output")
      sess = session.Session()

    input_arrays_map = [("inputA", [1, 16, 16, 3]), ("inputB", [1, 16, 16, 3])]
    output_arrays = ["output"]
    tflite_model = convert.toco_convert_graph_def(
        sess.graph_def,
        input_arrays_map,
        output_arrays,
        enable_mlir_converter=False,
        inference_type=lite_constants.QUANTIZED_UINT8,
        quantized_input_stats=[(0., 1.), (0., 1.)])
    self.assertTrue(tflite_model)

    # Check values from converted model.
    interpreter = Interpreter(model_content=tflite_model)
    interpreter.allocate_tensors()

    input_details = interpreter.get_input_details()
    self.assertEqual(2, len(input_details))
    self.assertEqual("inputA", input_details[0]["name"])
    self.assertEqual(np.uint8, input_details[0]["dtype"])
    self.assertTrue(([1, 16, 16, 3] == input_details[0]["shape"]).all())
    self.assertEqual((1., 0.),
                     input_details[0]["quantization"])  # scale, zero_point

    self.assertEqual("inputB", input_details[1]["name"])
    self.assertEqual(np.uint8, input_details[1]["dtype"])
    self.assertTrue(([1, 16, 16, 3] == input_details[1]["shape"]).all())
    self.assertEqual((1., 0.),
                     input_details[1]["quantization"])  # scale, zero_point

    output_details = interpreter.get_output_details()
    self.assertEqual(1, len(output_details))
    self.assertEqual("output", output_details[0]["name"])
    self.assertEqual(np.uint8, output_details[0]["dtype"])
    self.assertTrue(([1, 16, 16, 3] == output_details[0]["shape"]).all())
    self.assertTrue(output_details[0]["quantization"][0] > 0)  # scale

  def testGraphDefQuantizationInvalid(self):
    """Quantized conversion without input stats must raise ValueError."""
    with ops.Graph().as_default():
      in_tensor_1 = array_ops.placeholder(
          shape=[1, 16, 16, 3], dtype=dtypes.float32, name="inputA")
      in_tensor_2 = array_ops.placeholder(
          shape=[1, 16, 16, 3], dtype=dtypes.float32, name="inputB")
      _ = array_ops.fake_quant_with_min_max_args(
          in_tensor_1 + in_tensor_2, min=0., max=1., name="output")
      sess = session.Session()

    input_arrays_map = [("inputA", [1, 16, 16, 3]), ("inputB", [1, 16, 16, 3])]
    output_arrays = ["output"]
    # quantized_input_stats is deliberately omitted to trigger the error.
    with self.assertRaises(ValueError) as error:
      convert.toco_convert_graph_def(
          sess.graph_def,
          input_arrays_map,
          output_arrays,
          enable_mlir_converter=False,
          inference_type=lite_constants.QUANTIZED_UINT8)
    self.assertEqual(
        "std_dev and mean must be defined when inference_type or "
        "inference_input_type is QUANTIZED_UINT8 or INT8.",
        str(error.exception))
|
|
|
|
|
|
class ConvertTestOpHint(test_util.TensorFlowTestCase):
  """Test the hint to stub functionality."""

  def _getGraphOpTypes(self, graphdef, output_nodes):
    """Returns used op types in `graphdef` reachable from `output_nodes`.

    This is used to check that after the stub transformation the expected
    nodes are there.

    NOTE: this is not a exact test that the graph is the correct output, but
    it balances compact expressibility of test with sanity checking.

    Args:
      graphdef: TensorFlow proto graphdef.
      output_nodes: A list of output node names that we need to reach.

    Returns:
      A set of node types reachable from `output_nodes`.
    """
    name_to_input_name, name_to_node, _ = (
        _extract_graph_summary(graphdef))
    # Find all nodes that are needed by the outputs
    used_node_names = _bfs_for_reachable_nodes(output_nodes, name_to_input_name)
    return set([name_to_node[node_name].op for node_name in used_node_names])

  def _countIdentities(self, nodes):
    """Count the number of "Identity" op types in the list of proto nodes.

    Args:
      nodes: NodeDefs of the graph.

    Returns:
      The number of nodes with op type "Identity" found.
    """
    return len([x for x in nodes if x.op == "Identity"])

  def testSwishLiteHint(self):
    """Makes a custom op swish and makes sure it gets converted as a unit."""
    with ops.Graph().as_default():
      image = array_ops.constant([1., 2., 3., 4.])
      swish_scale = array_ops.constant(1.0)

      def _swish(input_tensor, scale):
        # Wrap the swish computation in an OpHint so it can be stubbed out.
        custom = op_hint.OpHint("cool_activation")
        input_tensor, scale = custom.add_inputs(input_tensor, scale)
        output = math_ops.sigmoid(input_tensor) * input_tensor * scale
        output, = custom.add_outputs(output)
        return output

      output = array_ops.identity(
          _swish(image, swish_scale), name="ModelOutput")

      with self.cached_session() as sess:
        # check if identities have been put into the graph (2 input, 1 output,
        # and 1 final output).
        self.assertEqual(self._countIdentities(sess.graph_def.node), 4)

        stubbed_graphdef = op_hint.convert_op_hints_to_stubs(
            graph_def=sess.graph_def)

        self.assertEqual(
            self._getGraphOpTypes(
                stubbed_graphdef,
                output_nodes=[op_hint._tensor_name_base(output.name)]),
            set(["cool_activation", "Const", "Identity"]))

  def testScaleAndBiasAndIdentity(self):
    """This tests a scaled add which has 3 inputs and 2 outputs."""
    with ops.Graph().as_default():
      a = array_ops.constant(1.)
      x = array_ops.constant([2., 3.])
      b = array_ops.constant([4., 5.])

      def _scaled_and_bias_and_identity(a, x, b):
        custom = op_hint.OpHint("scale_and_bias_and_identity")
        a, x, b = custom.add_inputs(a, x, b)
        return custom.add_outputs(a * x + b, x)

      output = array_ops.identity(
          _scaled_and_bias_and_identity(a, x, b), name="ModelOutput")

      with self.cached_session() as sess:
        # make sure one identity for each input (3) and output (2) => 3 + 2 = 5
        # +1 for the final output
        self.assertEqual(self._countIdentities(sess.graph_def.node), 6)

        stubbed_graphdef = op_hint.convert_op_hints_to_stubs(
            graph_def=sess.graph_def)

        self.assertEqual(
            self._getGraphOpTypes(
                stubbed_graphdef,
                output_nodes=[op_hint._tensor_name_base(output.name)]),
            set(["scale_and_bias_and_identity", "Const", "Identity", "Pack"]))

  def testTwoFunctions(self):
    """Tests if two functions are converted correctly."""
    with ops.Graph().as_default():
      a = array_ops.constant([1.])
      b = array_ops.constant([1.])

      def _double_values(x):
        custom = op_hint.OpHint("add_test")
        x, = custom.add_inputs(x)
        output = math_ops.multiply(x, x)
        output, = custom.add_outputs(output)
        return output

      output = array_ops.identity(
          math_ops.add(_double_values(a), _double_values(b)),
          name="ModelOutput")

      with self.cached_session() as sess:
        # make sure one identity for each input (2) and output (2) => 2 + 2
        # +1 for the final output
        self.assertEqual(self._countIdentities(sess.graph_def.node), 5)
        stubbed_graphdef = op_hint.convert_op_hints_to_stubs(
            graph_def=sess.graph_def)
        self.assertEqual(
            self._getGraphOpTypes(
                stubbed_graphdef,
                output_nodes=[op_hint._tensor_name_base(output.name)]),
            set(["add_test", "Const", "Identity", "Add"]))

  def _get_input_index(self, x):
    """Returns the hint input index recorded in `x`'s NodeDef attrs."""
    return x.op.node_def.attr[op_hint.OpHint.FUNCTION_INPUT_INDEX_ATTR].i

  def _get_output_index(self, x):
    """Returns the hint output index recorded in `x`'s NodeDef attrs."""
    return x.op.node_def.attr[op_hint.OpHint.FUNCTION_OUTPUT_INDEX_ATTR].i

  def _get_sort_index(self, x):
    """Returns the hint sort index recorded in `x`'s NodeDef attrs."""
    return x.op.node_def.attr[op_hint.OpHint.FUNCTION_SORT_INDEX_ATTR].i

  def testTags(self):
    """Test if multiple args with the same tag are grouped."""
    with ops.Graph().as_default():
      a = array_ops.constant([1.])
      b = array_ops.constant([2.])
      c = array_ops.constant([3.])
      d = array_ops.constant([4.])
      custom = op_hint.OpHint("test_tag")
      # a and c share "mytag" and should be grouped; b is untagged; d uses a
      # different tag.
      a = custom.add_input(
          a, tag="mytag", aggregate=op_hint.OpHint.AGGREGATE_STACK)
      b, = custom.add_inputs(b)
      c = custom.add_input(
          c, tag="mytag", aggregate=op_hint.OpHint.AGGREGATE_STACK)
      d = custom.add_input(
          d, tag="mytag2", aggregate=op_hint.OpHint.AGGREGATE_STACK)
      res = math_ops.add(math_ops.mul(a, b), math_ops.mul(c, b))
      custom.add_outputs([res])
      with self.cached_session():
        self.assertEqual(self._get_input_index(a), 0)
        self.assertEqual(self._get_sort_index(a), 0)
        self.assertEqual(self._get_input_index(b), 1)
        self.assertEqual(self._get_sort_index(b), 0)
        self.assertEqual(self._get_input_index(c), 0)
        self.assertEqual(self._get_sort_index(c), 1)

  def testOverrideIndex(self):
    """Tests that `index_override` pins an input index explicitly."""
    with ops.Graph().as_default():
      a = array_ops.constant([1.])
      b = array_ops.constant([2.])
      c = array_ops.constant([3.])
      custom = op_hint.OpHint("test_override")
      b = custom.add_input(b)  # should auto assign 0
      a = custom.add_input(a, index_override=1)
      c = custom.add_input(c)  # should auto assign 2
      with self.cached_session():
        self.assertEqual(self._get_input_index(a), 1)
        self.assertEqual(self._get_input_index(b), 0)
        self.assertEqual(self._get_input_index(c), 2)

  def testAggregate(self):
    """Tests stacked (AGGREGATE_STACK) hint inputs and outputs."""
    with ops.Graph().as_default():
      a = array_ops.constant([3., 4.])
      b = array_ops.constant([5., 6.])
      hint = op_hint.OpHint("agg")
      a0, a1 = array_ops.unstack(a)
      b0, b1 = array_ops.unstack(b)

      a0 = hint.add_input(a0, tag="c", aggregate=op_hint.OpHint.AGGREGATE_STACK)
      b0 = hint.add_input(b0, tag="n", aggregate=op_hint.OpHint.AGGREGATE_STACK)
      a1 = hint.add_input(a1, tag="c", aggregate=op_hint.OpHint.AGGREGATE_STACK)
      b1 = hint.add_input(b1, tag="n", aggregate=op_hint.OpHint.AGGREGATE_STACK)

      c0 = math_ops.add(a0, b0, name="addleft")
      c1 = math_ops.add(a1, b1, name="addright")
      c0 = hint.add_output(
          c0, tag="out", aggregate=op_hint.OpHint.AGGREGATE_STACK)
      c1 = hint.add_output(
          c1, tag="out", aggregate=op_hint.OpHint.AGGREGATE_STACK)

      curr = array_ops.stack([c0, c1])
      output = array_ops.identity(curr, name="FINAL_OUTPUT")
      with self.cached_session() as sess:
        stubbed_graphdef = op_hint.convert_op_hints_to_stubs(
            graph_def=sess.graph_def)
        self.assertEqual(
            self._getGraphOpTypes(
                stubbed_graphdef,
                output_nodes=[op_hint._tensor_name_base(output.name)]),
            set(["agg", "Const", "Identity"]))

  def testFindHintedOutputNodes(self):
    """Test if all hinted output nodes are correctly found."""
    with ops.Graph().as_default():

      def _build_ophinted_op(name, input1, input2):
        # Builds a hinted multiply op and returns its hinted output tensor.
        custom_op = op_hint.OpHint(name)
        input1 = custom_op.add_input(input1)
        input2 = custom_op.add_input(input2)
        output = math_ops.mul(input1, input2)
        return custom_op.add_output(output)

      output_1 = _build_ophinted_op("custom_op_1", array_ops.constant([1.]),
                                    array_ops.constant([2.]))
      output_2 = _build_ophinted_op("custom_op_2", array_ops.constant([3.]),
                                    array_ops.constant([4.]))
      with self.cached_session() as sess:
        hinted_outputs_nodes = op_hint.find_all_hinted_output_nodes(sess)
        expected_hinted_output_nodes = [
            _node_name(output_1.name),
            _node_name(output_2.name)
        ]
        self.assertEqual(
            len(hinted_outputs_nodes), len(expected_hinted_output_nodes))
|
|
|
|
# Script entry point: run all test cases defined in this module.
if __name__ == "__main__":
  test.main()