Adds tests to lite_test.py to test functions in 1.X.
1. Adds support for a Variable --> Identity --> ReadVariableOp graph to convert_variables_to_constants. 2. Adds a Grappler pass before freezing the graph in lite.py in order to inline. PiperOrigin-RevId: 234030833
This commit is contained in:
parent
e7d9786e66
commit
6249068668
@ -38,6 +38,10 @@ from six import PY3
|
||||
|
||||
from google.protobuf import text_format as _text_format
|
||||
from google.protobuf.message import DecodeError
|
||||
from tensorflow.core.framework import graph_pb2 as _graph_pb2
|
||||
from tensorflow.core.protobuf import config_pb2 as _config_pb2
|
||||
from tensorflow.core.protobuf import meta_graph_pb2 as _meta_graph_pb2
|
||||
from tensorflow.core.protobuf import rewriter_config_pb2 as _rewriter_config_pb2
|
||||
from tensorflow.lite.python import lite_constants as constants
|
||||
from tensorflow.lite.python.convert import build_toco_convert_protos # pylint: disable=unused-import
|
||||
from tensorflow.lite.python.convert import ConverterError # pylint: disable=unused-import
|
||||
@ -54,15 +58,12 @@ from tensorflow.lite.python.interpreter import Interpreter # pylint: disable=un
|
||||
from tensorflow.lite.python.op_hint import convert_op_hints_to_stubs # pylint: disable=unused-import
|
||||
from tensorflow.lite.python.op_hint import OpHint # pylint: disable=unused-import
|
||||
from tensorflow.lite.python.optimize import calibrator as _calibrator
|
||||
from tensorflow.core.framework import graph_pb2 as _graph_pb2
|
||||
from tensorflow.core.protobuf import rewriter_config_pb2 as _rewriter_config_pb2
|
||||
from tensorflow.core.protobuf import config_pb2 as _config_pb2
|
||||
from tensorflow.core.protobuf import meta_graph_pb2 as _meta_graph_pb2
|
||||
from tensorflow.python import keras as _keras
|
||||
from tensorflow.python.client import session as _session
|
||||
from tensorflow.python.framework import graph_util as _tf_graph_util
|
||||
from tensorflow.python.framework import ops as _ops
|
||||
from tensorflow.python.framework.errors_impl import NotFoundError as _NotFoundError
|
||||
from tensorflow.python.framework.errors_impl import OpError as _OpError
|
||||
from tensorflow.python.framework.importer import import_graph_def as _import_graph_def
|
||||
from tensorflow.python.grappler import tf_optimizer as _tf_optimizer
|
||||
from tensorflow.python.lib.io import file_io as _file_io
|
||||
@ -73,35 +74,6 @@ from tensorflow.python.util import deprecation as _deprecation
|
||||
from tensorflow.python.util.tf_export import tf_export as _tf_export
|
||||
|
||||
|
||||
def _run_graph_optimizations(graph_def, input_arrays, output_arrays):
|
||||
"""Apply standard TensorFlow optimizations to the graph_def.
|
||||
|
||||
Args:
|
||||
graph_def: Frozen GraphDef to be optimized.
|
||||
input_arrays: List of arrays that are considered inputs of the graph.
|
||||
output_arrays: List of arrays that are considered outputs of the graph.
|
||||
|
||||
Returns:
|
||||
A new, optimized GraphDef.
|
||||
"""
|
||||
meta_graph = _export_meta_graph(graph_def=graph_def)
|
||||
|
||||
# We need to add a collection called 'train_op' so that grappler
|
||||
# knows what the outputs are.
|
||||
fetch_collection = _meta_graph_pb2.CollectionDef()
|
||||
for array in input_arrays + output_arrays:
|
||||
fetch_collection.node_list.value.append(array.name)
|
||||
meta_graph.collection_def["train_op"].CopyFrom(fetch_collection)
|
||||
|
||||
config = _config_pb2.ConfigProto()
|
||||
rewrite_options = config.graph_options.rewrite_options
|
||||
rewrite_options.layout_optimizer = _rewriter_config_pb2.RewriterConfig.ON
|
||||
# Avoid remapping as it creates ops like _FusedConv2D, which are not
|
||||
# supported by TF Lite.
|
||||
rewrite_options.remapping = _rewriter_config_pb2.RewriterConfig.OFF
|
||||
return _tf_optimizer.OptimizeGraph(config, meta_graph)
|
||||
|
||||
|
||||
@_tf_export("lite.Optimize")
|
||||
class Optimize(enum.Enum):
|
||||
"""Enum defining the optimizations to apply when generating tflite graphs.
|
||||
@ -311,7 +283,7 @@ class TFLiteConverter(object):
|
||||
Returns:
|
||||
TFLiteConverter class.
|
||||
"""
|
||||
graph_def = _freeze_graph(sess, output_tensors)
|
||||
graph_def = _freeze_graph(sess, input_tensors, output_tensors)
|
||||
return cls(graph_def, input_tensors, output_tensors)
|
||||
|
||||
@classmethod
|
||||
@ -484,7 +456,7 @@ class TFLiteConverter(object):
|
||||
output_tensors = keras_model.outputs
|
||||
_set_tensor_shapes(input_tensors, input_shapes)
|
||||
|
||||
graph_def = _freeze_graph(sess, output_tensors)
|
||||
graph_def = _freeze_graph(sess, input_tensors, output_tensors)
|
||||
return cls(graph_def, input_tensors, output_tensors)
|
||||
|
||||
def __setattr__(self, name, value):
|
||||
@ -586,26 +558,26 @@ class TFLiteConverter(object):
|
||||
"dump_graphviz_video": self.dump_graphviz_video
|
||||
}
|
||||
|
||||
optimized_graph = None
|
||||
if self.inference_type == constants.QUANTIZED_UINT8:
|
||||
optimized_graph = self._graph_def
|
||||
else:
|
||||
# Run a Grappler pass if it is possible.
|
||||
graph_def = self._graph_def
|
||||
if self.inference_type != constants.QUANTIZED_UINT8:
|
||||
try:
|
||||
optimized_graph = _run_graph_optimizations(
|
||||
graph_def = _run_graph_optimizations(
|
||||
self._graph_def, self._input_tensors, self._output_tensors)
|
||||
except Exception:
|
||||
optimized_graph = self._graph_def
|
||||
except (_OpError, ValueError):
|
||||
print("Warning: Grappler optimization pass failed. "
|
||||
"If this behavior is unexpected, please file a bug.")
|
||||
|
||||
# Converts model.
|
||||
if self._has_valid_tensors():
|
||||
result = _toco_convert_impl(
|
||||
input_data=optimized_graph,
|
||||
input_data=graph_def,
|
||||
input_tensors=self._input_tensors,
|
||||
output_tensors=self._output_tensors,
|
||||
**converter_kwargs)
|
||||
else:
|
||||
result = _toco_convert_graph_def(
|
||||
input_data=optimized_graph,
|
||||
input_data=graph_def,
|
||||
input_arrays_with_shape=self._input_arrays_with_shape,
|
||||
output_arrays=self._output_arrays,
|
||||
**converter_kwargs)
|
||||
@ -710,6 +682,35 @@ class TocoConverter(object):
|
||||
input_shapes, output_arrays)
|
||||
|
||||
|
||||
def _run_graph_optimizations(graph_def, input_arrays, output_arrays):
|
||||
"""Apply standard TensorFlow optimizations to the graph_def.
|
||||
|
||||
Args:
|
||||
graph_def: Frozen GraphDef to be optimized.
|
||||
input_arrays: List of arrays that are considered inputs of the graph.
|
||||
output_arrays: List of arrays that are considered outputs of the graph.
|
||||
|
||||
Returns:
|
||||
A new, optimized GraphDef.
|
||||
"""
|
||||
meta_graph = _export_meta_graph(graph_def=graph_def)
|
||||
|
||||
# We need to add a collection called 'train_op' so that grappler
|
||||
# knows what the outputs are.
|
||||
fetch_collection = _meta_graph_pb2.CollectionDef()
|
||||
for array in input_arrays + output_arrays:
|
||||
fetch_collection.node_list.value.append(array.name)
|
||||
meta_graph.collection_def["train_op"].CopyFrom(fetch_collection)
|
||||
|
||||
config = _config_pb2.ConfigProto()
|
||||
rewrite_options = config.graph_options.rewrite_options
|
||||
rewrite_options.layout_optimizer = _rewriter_config_pb2.RewriterConfig.OFF
|
||||
# Avoid remapping as it creates ops like _FusedConv2D, which are not
|
||||
# supported by TF Lite.
|
||||
rewrite_options.remapping = _rewriter_config_pb2.RewriterConfig.OFF
|
||||
return _tf_optimizer.OptimizeGraph(config, meta_graph)
|
||||
|
||||
|
||||
def _is_frozen_graph(sess):
|
||||
"""Determines if the graph is frozen.
|
||||
|
||||
@ -728,22 +729,28 @@ def _is_frozen_graph(sess):
|
||||
return True
|
||||
|
||||
|
||||
def _freeze_graph(sess, output_tensors):
|
||||
def _freeze_graph(sess, input_tensors, output_tensors):
|
||||
"""Returns a frozen GraphDef.
|
||||
|
||||
Freezes a graph with Variables in it. Otherwise the existing GraphDef is
|
||||
returned.
|
||||
Runs a Grappler pass and freezes a graph with Variables in it. Otherwise the
|
||||
existing GraphDef is returned. The Grappler pass is only run on models that
|
||||
are frozen in order to inline the functions in the graph.
|
||||
|
||||
Args:
|
||||
sess: TensorFlow Session.
|
||||
input_tensors: List of input tensors.
|
||||
output_tensors: List of output tensors (only .name is used from this).
|
||||
|
||||
Returns:
|
||||
Frozen GraphDef.
|
||||
"""
|
||||
# Runs a Grappler pass in order to inline any functions in the graph.
|
||||
graph_def = _run_graph_optimizations(sess.graph_def, input_tensors,
|
||||
output_tensors)
|
||||
|
||||
if not _is_frozen_graph(sess):
|
||||
output_arrays = [_tensor_name(tensor) for tensor in output_tensors]
|
||||
return _tf_graph_util.convert_variables_to_constants(
|
||||
sess, sess.graph_def, output_arrays)
|
||||
sess, graph_def, output_arrays)
|
||||
else:
|
||||
return sess.graph_def
|
||||
|
@ -27,13 +27,16 @@ from tensorflow.lite.python import lite_constants
|
||||
from tensorflow.lite.python.interpreter import Interpreter
|
||||
from tensorflow.python import keras
|
||||
from tensorflow.python.client import session
|
||||
from tensorflow.python.eager import def_function
|
||||
from tensorflow.python.framework import constant_op
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.framework import test_util
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import math_ops
|
||||
from tensorflow.python.ops import nn_ops
|
||||
from tensorflow.python.ops import variable_scope
|
||||
from tensorflow.python.ops import variables
|
||||
from tensorflow.python.ops.variables import global_variables_initializer as _global_variables_initializer
|
||||
from tensorflow.python.platform import gfile
|
||||
from tensorflow.python.platform import resource_loader
|
||||
@ -597,6 +600,48 @@ class FromSessionTest(test_util.TensorFlowTestCase):
|
||||
interpreter = Interpreter(model_content=tflite_model)
|
||||
interpreter.allocate_tensors()
|
||||
|
||||
def testFunctions(self):
|
||||
"""Tests tf.function in 1.X."""
|
||||
|
||||
@def_function.function
|
||||
def plus_placeholder(x, placeholder):
|
||||
return x + placeholder
|
||||
|
||||
with ops.Graph().as_default():
|
||||
placeholder = array_ops.placeholder(
|
||||
dtype=dtypes.float32, shape=[1], name='input')
|
||||
variable_node = variables.Variable(1.0, name='variable_node')
|
||||
defun_node = plus_placeholder(variable_node, placeholder)
|
||||
output_node = math_ops.multiply(defun_node, 2.0, name='output_node')
|
||||
|
||||
# Initialize variables in the model.
|
||||
sess = session.Session()
|
||||
sess.run(variables.variables_initializer([variable_node]))
|
||||
|
||||
# Convert model and ensure model is not None.
|
||||
converter = lite.TFLiteConverter.from_session(sess, [placeholder],
|
||||
[output_node])
|
||||
tflite_model = converter.convert()
|
||||
self.assertTrue(tflite_model)
|
||||
|
||||
# Check values from converted model.
|
||||
interpreter = Interpreter(model_content=tflite_model)
|
||||
interpreter.allocate_tensors()
|
||||
|
||||
input_details = interpreter.get_input_details()
|
||||
self.assertEqual(1, len(input_details))
|
||||
self.assertEqual('input', input_details[0]['name'])
|
||||
self.assertEqual(np.float32, input_details[0]['dtype'])
|
||||
self.assertTrue(([1] == input_details[0]['shape']).all())
|
||||
self.assertEqual((0., 0.), input_details[0]['quantization'])
|
||||
|
||||
output_details = interpreter.get_output_details()
|
||||
self.assertEqual(1, len(output_details))
|
||||
self.assertEqual('output_node', output_details[0]['name'])
|
||||
self.assertEqual(np.float32, output_details[0]['dtype'])
|
||||
self.assertTrue(([1] == output_details[0]['shape']).all())
|
||||
self.assertEqual((0., 0.), output_details[0]['quantization'])
|
||||
|
||||
|
||||
@test_util.run_v1_only('b/120545219')
|
||||
class FromFrozenGraphFile(test_util.TensorFlowTestCase):
|
||||
|
@ -26,8 +26,10 @@ from tensorflow.lite.python import lite
|
||||
from tensorflow.lite.testing.model_coverage import model_coverage_lib as model_coverage
|
||||
from tensorflow.python import keras
|
||||
from tensorflow.python.client import session
|
||||
from tensorflow.python.eager import def_function
|
||||
from tensorflow.python.framework import constant_op
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.framework import test_util
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import math_ops
|
||||
@ -70,6 +72,26 @@ class EvaluateFrozenGraph(test.TestCase):
|
||||
model_coverage.test_frozen_graph(filename, ['inputA', 'inputB'],
|
||||
['add', 'Mean'])
|
||||
|
||||
@test_util.run_v1_only('b/120545219')
|
||||
def testFunctions(self):
|
||||
|
||||
@def_function.function
|
||||
def plus_placeholder(x, placeholder):
|
||||
return x + placeholder
|
||||
|
||||
with ops.Graph().as_default():
|
||||
placeholder = array_ops.placeholder(
|
||||
dtype=dtypes.float32, shape=[1], name='input')
|
||||
variable_node = constant_op.constant(1.0, name='variable_node')
|
||||
defun_node = plus_placeholder(variable_node, placeholder)
|
||||
_ = math_ops.multiply(defun_node, 2.0, name='output_node')
|
||||
|
||||
# Initialize variables in the model.
|
||||
sess = session.Session()
|
||||
|
||||
filename = self._saveFrozenGraph(sess)
|
||||
model_coverage.test_frozen_graph(filename, ['input'], ['output_node'])
|
||||
|
||||
def _getQuantizedModel(self):
|
||||
np.random.seed(0)
|
||||
with session.Session().as_default() as sess:
|
||||
|
@ -248,6 +248,15 @@ def convert_variables_to_constants(sess,
|
||||
found_variables = {}
|
||||
variable_names = []
|
||||
variable_dict_names = []
|
||||
identity_ops_input_map = {}
|
||||
|
||||
def is_found_variable(input_tensor_name):
|
||||
# Determines if the `input_tensor_name` is in `found_variables` or is an
|
||||
# Identity op with an input that is in `found_variables`.
|
||||
return ((input_tensor_name in found_variables) or
|
||||
(input_tensor_name in identity_ops_input_map and
|
||||
identity_ops_input_map[input_tensor_name] in found_variables))
|
||||
|
||||
for node in inference_graph.node:
|
||||
if node.op in ["Variable", "VariableV2", "VarHandleOp"]:
|
||||
variable_name = node.name
|
||||
@ -261,6 +270,9 @@ def convert_variables_to_constants(sess,
|
||||
variable_names.append(variable_name + "/Read/ReadVariableOp:0")
|
||||
else:
|
||||
variable_names.append(variable_name + ":0")
|
||||
elif node.op == "Identity":
|
||||
# Creates a map of Identity node names to the input names.
|
||||
identity_ops_input_map[node.name] = node.input[0].split(":")[0]
|
||||
if variable_names:
|
||||
returned_variables = sess.run(variable_names)
|
||||
else:
|
||||
@ -283,11 +295,15 @@ def convert_variables_to_constants(sess,
|
||||
tensor=tensor_util.make_tensor_proto(
|
||||
data, dtype=dtype.type, shape=data.shape)))
|
||||
how_many_converted += 1
|
||||
elif input_node.op == "ReadVariableOp" and (
|
||||
input_node.input[0] in found_variables):
|
||||
elif (input_node.op == "ReadVariableOp" and
|
||||
is_found_variable(input_node.input[0])):
|
||||
# The preceding branch converts all VarHandleOps of ResourceVariables to
|
||||
# constants, so we need to convert the associated ReadVariableOps to
|
||||
# Identity ops.
|
||||
#
|
||||
# Handles the following cases:
|
||||
# Variable --> ReadVariableOp
|
||||
# Variable --> Identity --> ReadVariableOp
|
||||
output_node.op = "Identity"
|
||||
output_node.name = input_node.name
|
||||
output_node.input.extend([input_node.input[0]])
|
||||
|
@ -217,16 +217,18 @@ class DeviceFunctionsTest(test.TestCase):
|
||||
self.assertNear(4.0, output, 0.00001)
|
||||
variable_graph_def = sess.graph.as_graph_def()
|
||||
|
||||
# First get the constant_graph_def when variable_names_whitelist is set,
|
||||
# note that if variable_names_whitelist is not set an error will be
|
||||
# thrown because unused_variable_node is not initialized.
|
||||
# Get the constant_graph_def.
|
||||
constant_graph_def = graph_util.convert_variables_to_constants(
|
||||
sess,
|
||||
variable_graph_def, ["output_node"],
|
||||
variable_names_whitelist=set(["variable_node"]))
|
||||
sess, variable_graph_def, ["output_node"])
|
||||
|
||||
# Ensure the library is copied and there are no variables after
|
||||
# freezing.
|
||||
self.assertEqual(variable_graph_def.library,
|
||||
constant_graph_def.library)
|
||||
for node in constant_graph_def.node:
|
||||
self.assertNotIn(
|
||||
node.op,
|
||||
["Variable", "VariableV2", "VarHandleOp", "ReadVariableOp"])
|
||||
|
||||
def testConvertVariablesToConsts(self):
|
||||
self._test_variable_to_const_conversion(use_resource=False)
|
||||
|
Loading…
Reference in New Issue
Block a user