Adds tests to lite_test.py to test tf.functions in 1.X.

1. Adds support for a Variable --> Identity --> ReadVariableOp graph to convert_variables_to_constants.
2. Adds a Grappler pass before freezing the graph in lite.py in order to inline functions (both changes are exercised end to end in the sketch below).

PiperOrigin-RevId: 234030833
Author: Nupur Garg, committed by TensorFlower Gardener
Date: 2019-02-14 14:31:43 -08:00
Commit: 6249068668 (parent: e7d9786e66)
5 changed files with 148 additions and 56 deletions
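The change is exercised end to end by the new testFunctions case added to lite_test.py below. The following is a minimal standalone sketch of the same flow, assuming the public TF 1.14-style API (tf.function, tf.lite.TFLiteConverter) rather than the internal modules the test imports:

import tensorflow as tf

# A defun mixed into a 1.X graph. Converting it relies on both changes in this
# commit: the pre-freeze Grappler pass inlines the function, and the variable
# read left behind (Variable --> Identity --> ReadVariableOp) is folded by
# convert_variables_to_constants.
@tf.function
def plus_placeholder(x, placeholder):
  return x + placeholder

with tf.Graph().as_default():
  placeholder = tf.placeholder(dtype=tf.float32, shape=[1], name='input')
  variable_node = tf.Variable(1.0, name='variable_node')
  defun_node = plus_placeholder(variable_node, placeholder)
  output_node = tf.multiply(defun_node, 2.0, name='output_node')

  sess = tf.Session()
  sess.run(tf.variables_initializer([variable_node]))

converter = tf.lite.TFLiteConverter.from_session(sess, [placeholder],
                                                 [output_node])
tflite_model = converter.convert()

interpreter = tf.lite.Interpreter(model_content=tflite_model)
interpreter.allocate_tensors()
print(interpreter.get_input_details()[0]['name'])   # 'input'
print(interpreter.get_output_details()[0]['name'])  # 'output_node'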


@@ -38,6 +38,10 @@ from six import PY3
 from google.protobuf import text_format as _text_format
 from google.protobuf.message import DecodeError
+from tensorflow.core.framework import graph_pb2 as _graph_pb2
+from tensorflow.core.protobuf import config_pb2 as _config_pb2
+from tensorflow.core.protobuf import meta_graph_pb2 as _meta_graph_pb2
+from tensorflow.core.protobuf import rewriter_config_pb2 as _rewriter_config_pb2
 from tensorflow.lite.python import lite_constants as constants
 from tensorflow.lite.python.convert import build_toco_convert_protos  # pylint: disable=unused-import
 from tensorflow.lite.python.convert import ConverterError  # pylint: disable=unused-import
@@ -54,15 +58,12 @@ from tensorflow.lite.python.interpreter import Interpreter  # pylint: disable=unused-import
 from tensorflow.lite.python.op_hint import convert_op_hints_to_stubs  # pylint: disable=unused-import
 from tensorflow.lite.python.op_hint import OpHint  # pylint: disable=unused-import
 from tensorflow.lite.python.optimize import calibrator as _calibrator
-from tensorflow.core.framework import graph_pb2 as _graph_pb2
-from tensorflow.core.protobuf import rewriter_config_pb2 as _rewriter_config_pb2
-from tensorflow.core.protobuf import config_pb2 as _config_pb2
-from tensorflow.core.protobuf import meta_graph_pb2 as _meta_graph_pb2
 from tensorflow.python import keras as _keras
 from tensorflow.python.client import session as _session
 from tensorflow.python.framework import graph_util as _tf_graph_util
 from tensorflow.python.framework import ops as _ops
 from tensorflow.python.framework.errors_impl import NotFoundError as _NotFoundError
+from tensorflow.python.framework.errors_impl import OpError as _OpError
 from tensorflow.python.framework.importer import import_graph_def as _import_graph_def
 from tensorflow.python.grappler import tf_optimizer as _tf_optimizer
 from tensorflow.python.lib.io import file_io as _file_io
@@ -73,35 +74,6 @@ from tensorflow.python.util import deprecation as _deprecation
 from tensorflow.python.util.tf_export import tf_export as _tf_export
 
 
-def _run_graph_optimizations(graph_def, input_arrays, output_arrays):
-  """Apply standard TensorFlow optimizations to the graph_def.
-
-  Args:
-    graph_def: Frozen GraphDef to be optimized.
-    input_arrays: List of arrays that are considered inputs of the graph.
-    output_arrays: List of arrays that are considered outputs of the graph.
-
-  Returns:
-    A new, optimized GraphDef.
-  """
-  meta_graph = _export_meta_graph(graph_def=graph_def)
-
-  # We need to add a collection called 'train_op' so that grappler
-  # knows what the outputs are.
-  fetch_collection = _meta_graph_pb2.CollectionDef()
-  for array in input_arrays + output_arrays:
-    fetch_collection.node_list.value.append(array.name)
-  meta_graph.collection_def["train_op"].CopyFrom(fetch_collection)
-
-  config = _config_pb2.ConfigProto()
-  rewrite_options = config.graph_options.rewrite_options
-  rewrite_options.layout_optimizer = _rewriter_config_pb2.RewriterConfig.ON
-  # Avoid remapping as it creates ops like _FusedConv2D, which are not
-  # supported by TF Lite.
-  rewrite_options.remapping = _rewriter_config_pb2.RewriterConfig.OFF
-  return _tf_optimizer.OptimizeGraph(config, meta_graph)
-
-
 @_tf_export("lite.Optimize")
 class Optimize(enum.Enum):
   """Enum defining the optimizations to apply when generating tflite graphs.
@@ -311,7 +283,7 @@ class TFLiteConverter(object):
     Returns:
       TFLiteConverter class.
     """
-    graph_def = _freeze_graph(sess, output_tensors)
+    graph_def = _freeze_graph(sess, input_tensors, output_tensors)
     return cls(graph_def, input_tensors, output_tensors)
 
   @classmethod
@@ -484,7 +456,7 @@ class TFLiteConverter(object):
       output_tensors = keras_model.outputs
     _set_tensor_shapes(input_tensors, input_shapes)
 
-    graph_def = _freeze_graph(sess, output_tensors)
+    graph_def = _freeze_graph(sess, input_tensors, output_tensors)
     return cls(graph_def, input_tensors, output_tensors)
 
   def __setattr__(self, name, value):
@@ -586,26 +558,26 @@ class TFLiteConverter(object):
         "dump_graphviz_video": self.dump_graphviz_video
     }
 
-    optimized_graph = None
-    if self.inference_type == constants.QUANTIZED_UINT8:
-      optimized_graph = self._graph_def
-    else:
+    # Run a Grappler pass if it is possible.
+    graph_def = self._graph_def
+    if self.inference_type != constants.QUANTIZED_UINT8:
       try:
-        optimized_graph = _run_graph_optimizations(
+        graph_def = _run_graph_optimizations(
            self._graph_def, self._input_tensors, self._output_tensors)
-      except Exception:
-        optimized_graph = self._graph_def
+      except (_OpError, ValueError):
+        print("Warning: Grappler optimization pass failed. "
+              "If this behavior is unexpected, please file a bug.")
 
     # Converts model.
     if self._has_valid_tensors():
       result = _toco_convert_impl(
-          input_data=optimized_graph,
+          input_data=graph_def,
           input_tensors=self._input_tensors,
           output_tensors=self._output_tensors,
           **converter_kwargs)
     else:
       result = _toco_convert_graph_def(
-          input_data=optimized_graph,
+          input_data=graph_def,
           input_arrays_with_shape=self._input_arrays_with_shape,
           output_arrays=self._output_arrays,
           **converter_kwargs)
@@ -710,6 +682,35 @@ class TocoConverter(object):
                                                  input_shapes, output_arrays)
 
 
+def _run_graph_optimizations(graph_def, input_arrays, output_arrays):
+  """Apply standard TensorFlow optimizations to the graph_def.
+
+  Args:
+    graph_def: Frozen GraphDef to be optimized.
+    input_arrays: List of arrays that are considered inputs of the graph.
+    output_arrays: List of arrays that are considered outputs of the graph.
+
+  Returns:
+    A new, optimized GraphDef.
+  """
+  meta_graph = _export_meta_graph(graph_def=graph_def)
+
+  # We need to add a collection called 'train_op' so that grappler
+  # knows what the outputs are.
+  fetch_collection = _meta_graph_pb2.CollectionDef()
+  for array in input_arrays + output_arrays:
+    fetch_collection.node_list.value.append(array.name)
+  meta_graph.collection_def["train_op"].CopyFrom(fetch_collection)
+
+  config = _config_pb2.ConfigProto()
+  rewrite_options = config.graph_options.rewrite_options
+  rewrite_options.layout_optimizer = _rewriter_config_pb2.RewriterConfig.OFF
+  # Avoid remapping as it creates ops like _FusedConv2D, which are not
+  # supported by TF Lite.
+  rewrite_options.remapping = _rewriter_config_pb2.RewriterConfig.OFF
+  return _tf_optimizer.OptimizeGraph(config, meta_graph)
+
+
 def _is_frozen_graph(sess):
   """Determines if the graph is frozen.
 
@@ -728,22 +729,28 @@ def _is_frozen_graph(sess):
   return True
 
 
-def _freeze_graph(sess, output_tensors):
+def _freeze_graph(sess, input_tensors, output_tensors):
   """Returns a frozen GraphDef.
 
-  Freezes a graph with Variables in it. Otherwise the existing GraphDef is
-  returned.
+  Runs a Grappler pass and freezes a graph with Variables in it. Otherwise the
+  existing GraphDef is returned. The Grappler pass is only run on models that
+  are frozen in order to inline the functions in the graph.
 
   Args:
     sess: TensorFlow Session.
+    input_tensors: List of input tensors.
     output_tensors: List of output tensors (only .name is used from this).
 
   Returns:
     Frozen GraphDef.
   """
+  # Runs a Grappler pass in order to inline any functions in the graph.
+  graph_def = _run_graph_optimizations(sess.graph_def, input_tensors,
+                                       output_tensors)
+
   if not _is_frozen_graph(sess):
     output_arrays = [_tensor_name(tensor) for tensor in output_tensors]
     return _tf_graph_util.convert_variables_to_constants(
-        sess, sess.graph_def, output_arrays)
+        sess, graph_def, output_arrays)
   else:
     return sess.graph_def


@@ -27,13 +27,16 @@ from tensorflow.lite.python import lite_constants
 from tensorflow.lite.python.interpreter import Interpreter
 from tensorflow.python import keras
 from tensorflow.python.client import session
+from tensorflow.python.eager import def_function
 from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import ops
 from tensorflow.python.framework import test_util
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import math_ops
 from tensorflow.python.ops import nn_ops
 from tensorflow.python.ops import variable_scope
+from tensorflow.python.ops import variables
 from tensorflow.python.ops.variables import global_variables_initializer as _global_variables_initializer
 from tensorflow.python.platform import gfile
 from tensorflow.python.platform import resource_loader
@@ -597,6 +600,48 @@ class FromSessionTest(test_util.TensorFlowTestCase):
     interpreter = Interpreter(model_content=tflite_model)
     interpreter.allocate_tensors()
 
+  def testFunctions(self):
+    """Tests tf.function in 1.X."""
+
+    @def_function.function
+    def plus_placeholder(x, placeholder):
+      return x + placeholder
+
+    with ops.Graph().as_default():
+      placeholder = array_ops.placeholder(
+          dtype=dtypes.float32, shape=[1], name='input')
+      variable_node = variables.Variable(1.0, name='variable_node')
+      defun_node = plus_placeholder(variable_node, placeholder)
+      output_node = math_ops.multiply(defun_node, 2.0, name='output_node')
+
+      # Initialize variables in the model.
+      sess = session.Session()
+      sess.run(variables.variables_initializer([variable_node]))
+
+    # Convert model and ensure model is not None.
+    converter = lite.TFLiteConverter.from_session(sess, [placeholder],
+                                                  [output_node])
+    tflite_model = converter.convert()
+    self.assertTrue(tflite_model)
+
+    # Check values from converted model.
+    interpreter = Interpreter(model_content=tflite_model)
+    interpreter.allocate_tensors()
+
+    input_details = interpreter.get_input_details()
+    self.assertEqual(1, len(input_details))
+    self.assertEqual('input', input_details[0]['name'])
+    self.assertEqual(np.float32, input_details[0]['dtype'])
+    self.assertTrue(([1] == input_details[0]['shape']).all())
+    self.assertEqual((0., 0.), input_details[0]['quantization'])
+
+    output_details = interpreter.get_output_details()
+    self.assertEqual(1, len(output_details))
+    self.assertEqual('output_node', output_details[0]['name'])
+    self.assertEqual(np.float32, output_details[0]['dtype'])
+    self.assertTrue(([1] == output_details[0]['shape']).all())
+    self.assertEqual((0., 0.), output_details[0]['quantization'])
+
 
 @test_util.run_v1_only('b/120545219')
 class FromFrozenGraphFile(test_util.TensorFlowTestCase):


@@ -26,8 +26,10 @@ from tensorflow.lite.python import lite
 from tensorflow.lite.testing.model_coverage import model_coverage_lib as model_coverage
 from tensorflow.python import keras
 from tensorflow.python.client import session
+from tensorflow.python.eager import def_function
 from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import ops
 from tensorflow.python.framework import test_util
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import math_ops
@@ -70,6 +72,26 @@ class EvaluateFrozenGraph(test.TestCase):
     model_coverage.test_frozen_graph(filename, ['inputA', 'inputB'],
                                      ['add', 'Mean'])
 
+  @test_util.run_v1_only('b/120545219')
+  def testFunctions(self):
+
+    @def_function.function
+    def plus_placeholder(x, placeholder):
+      return x + placeholder
+
+    with ops.Graph().as_default():
+      placeholder = array_ops.placeholder(
+          dtype=dtypes.float32, shape=[1], name='input')
+      variable_node = constant_op.constant(1.0, name='variable_node')
+      defun_node = plus_placeholder(variable_node, placeholder)
+      _ = math_ops.multiply(defun_node, 2.0, name='output_node')
+
+      # Initialize variables in the model.
+      sess = session.Session()
+
+    filename = self._saveFrozenGraph(sess)
+    model_coverage.test_frozen_graph(filename, ['input'], ['output_node'])
+
  def _getQuantizedModel(self):
    np.random.seed(0)
    with session.Session().as_default() as sess:


@@ -248,6 +248,15 @@ def convert_variables_to_constants(sess,
   found_variables = {}
   variable_names = []
   variable_dict_names = []
+  identity_ops_input_map = {}
+
+  def is_found_variable(input_tensor_name):
+    # Determines if the `input_tensor_name` is in `found_variables` or is an
+    # Identity op with an input that is in `found_variables`.
+    return ((input_tensor_name in found_variables) or
+            (input_tensor_name in identity_ops_input_map and
+             identity_ops_input_map[input_tensor_name] in found_variables))
+
   for node in inference_graph.node:
     if node.op in ["Variable", "VariableV2", "VarHandleOp"]:
       variable_name = node.name
@@ -261,6 +270,9 @@
         variable_names.append(variable_name + "/Read/ReadVariableOp:0")
       else:
         variable_names.append(variable_name + ":0")
+    elif node.op == "Identity":
+      # Creates a map of Identity node names to the input names.
+      identity_ops_input_map[node.name] = node.input[0].split(":")[0]
   if variable_names:
     returned_variables = sess.run(variable_names)
   else:
@@ -283,11 +295,15 @@
              tensor=tensor_util.make_tensor_proto(
                  data, dtype=dtype.type, shape=data.shape)))
      how_many_converted += 1
-    elif input_node.op == "ReadVariableOp" and (
-        input_node.input[0] in found_variables):
+    elif (input_node.op == "ReadVariableOp" and
+          is_found_variable(input_node.input[0])):
      # The preceding branch converts all VarHandleOps of ResourceVariables to
      # constants, so we need to convert the associated ReadVariableOps to
      # Identity ops.
+      #
+      # Handles the following cases:
+      #   Variable --> ReadVariableOp
+      #   Variable --> Identity --> ReadVariableOp
      output_node.op = "Identity"
      output_node.name = input_node.name
      output_node.input.extend([input_node.input[0]])
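
For reference, the graph pattern the new Identity-aware check targets can be sketched with hand-built protos; everything below (node names included) is illustrative only and not taken from the commit:

from tensorflow.core.framework import graph_pb2

graph = graph_pb2.GraphDef()

# Resource variable handle.
var = graph.node.add()
var.name = "v"
var.op = "VarHandleOp"

# Identity of the handle, e.g. left behind after Grappler inlines a defun that
# captured the variable.
ident = graph.node.add()
ident.name = "v/Identity"
ident.op = "Identity"
ident.input.append("v")

# The read goes through the Identity rather than the VarHandleOp directly.
read = graph.node.add()
read.name = "v/Read/ReadVariableOp"
read.op = "ReadVariableOp"
read.input.append("v/Identity")

# With this change, convert_variables_to_constants records
# identity_ops_input_map = {"v/Identity": "v"}, so is_found_variable("v/Identity")
# is True and the ReadVariableOp is rewritten to an Identity of the frozen Const
# node instead of being left pointing at the variable.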


@@ -217,16 +217,18 @@ class DeviceFunctionsTest(test.TestCase):
      self.assertNear(4.0, output, 0.00001)
      variable_graph_def = sess.graph.as_graph_def()
 
-      # First get the constant_graph_def when variable_names_whitelist is set,
-      # note that if variable_names_whitelist is not set an error will be
-      # thrown because unused_variable_node is not initialized.
+      # Get the constant_graph_def.
      constant_graph_def = graph_util.convert_variables_to_constants(
-          sess,
-          variable_graph_def, ["output_node"],
-          variable_names_whitelist=set(["variable_node"]))
+          sess, variable_graph_def, ["output_node"])
 
+      # Ensure the library is copied and there are no variables after
+      # freezing.
      self.assertEqual(variable_graph_def.library,
                       constant_graph_def.library)
+      for node in constant_graph_def.node:
+        self.assertNotIn(
+            node.op,
+            ["Variable", "VariableV2", "VarHandleOp", "ReadVariableOp"])
 
  def testConvertVariablesToConsts(self):
    self._test_variable_to_const_conversion(use_resource=False)