From 6249068668847933066ee2119cc6c88e4c57009f Mon Sep 17 00:00:00 2001 From: Nupur Garg Date: Thu, 14 Feb 2019 14:31:43 -0800 Subject: [PATCH] Adds tests to lite_test.py to test functions in 1.X. 1. Adds support for a Variable --> Identity --> ReadVariableOp graph to convert_variables_to_constants. 2. Adds a Grappler pass before freezing the graph in lite.py in order to inline. PiperOrigin-RevId: 234030833 --- tensorflow/lite/python/lite.py | 103 ++++++++++-------- tensorflow/lite/python/lite_test.py | 45 ++++++++ .../model_coverage/model_coverage_lib_test.py | 22 ++++ .../python/framework/graph_util_impl.py | 20 +++- .../python/framework/graph_util_test.py | 14 ++- 5 files changed, 148 insertions(+), 56 deletions(-) diff --git a/tensorflow/lite/python/lite.py b/tensorflow/lite/python/lite.py index 40efbe53925..a05dc28f799 100644 --- a/tensorflow/lite/python/lite.py +++ b/tensorflow/lite/python/lite.py @@ -38,6 +38,10 @@ from six import PY3 from google.protobuf import text_format as _text_format from google.protobuf.message import DecodeError +from tensorflow.core.framework import graph_pb2 as _graph_pb2 +from tensorflow.core.protobuf import config_pb2 as _config_pb2 +from tensorflow.core.protobuf import meta_graph_pb2 as _meta_graph_pb2 +from tensorflow.core.protobuf import rewriter_config_pb2 as _rewriter_config_pb2 from tensorflow.lite.python import lite_constants as constants from tensorflow.lite.python.convert import build_toco_convert_protos # pylint: disable=unused-import from tensorflow.lite.python.convert import ConverterError # pylint: disable=unused-import @@ -54,15 +58,12 @@ from tensorflow.lite.python.interpreter import Interpreter # pylint: disable=un from tensorflow.lite.python.op_hint import convert_op_hints_to_stubs # pylint: disable=unused-import from tensorflow.lite.python.op_hint import OpHint # pylint: disable=unused-import from tensorflow.lite.python.optimize import calibrator as _calibrator -from tensorflow.core.framework import graph_pb2 as _graph_pb2 -from tensorflow.core.protobuf import rewriter_config_pb2 as _rewriter_config_pb2 -from tensorflow.core.protobuf import config_pb2 as _config_pb2 -from tensorflow.core.protobuf import meta_graph_pb2 as _meta_graph_pb2 from tensorflow.python import keras as _keras from tensorflow.python.client import session as _session from tensorflow.python.framework import graph_util as _tf_graph_util from tensorflow.python.framework import ops as _ops from tensorflow.python.framework.errors_impl import NotFoundError as _NotFoundError +from tensorflow.python.framework.errors_impl import OpError as _OpError from tensorflow.python.framework.importer import import_graph_def as _import_graph_def from tensorflow.python.grappler import tf_optimizer as _tf_optimizer from tensorflow.python.lib.io import file_io as _file_io @@ -73,35 +74,6 @@ from tensorflow.python.util import deprecation as _deprecation from tensorflow.python.util.tf_export import tf_export as _tf_export -def _run_graph_optimizations(graph_def, input_arrays, output_arrays): - """Apply standard TensorFlow optimizations to the graph_def. - - Args: - graph_def: Frozen GraphDef to be optimized. - input_arrays: List of arrays that are considered inputs of the graph. - output_arrays: List of arrays that are considered outputs of the graph. - - Returns: - A new, optimized GraphDef. - """ - meta_graph = _export_meta_graph(graph_def=graph_def) - - # We need to add a collection called 'train_op' so that grappler - # knows what the outputs are. - fetch_collection = _meta_graph_pb2.CollectionDef() - for array in input_arrays + output_arrays: - fetch_collection.node_list.value.append(array.name) - meta_graph.collection_def["train_op"].CopyFrom(fetch_collection) - - config = _config_pb2.ConfigProto() - rewrite_options = config.graph_options.rewrite_options - rewrite_options.layout_optimizer = _rewriter_config_pb2.RewriterConfig.ON - # Avoid remapping as it creates ops like _FusedConv2D, which are not - # supported by TF Lite. - rewrite_options.remapping = _rewriter_config_pb2.RewriterConfig.OFF - return _tf_optimizer.OptimizeGraph(config, meta_graph) - - @_tf_export("lite.Optimize") class Optimize(enum.Enum): """Enum defining the optimizations to apply when generating tflite graphs. @@ -311,7 +283,7 @@ class TFLiteConverter(object): Returns: TFLiteConverter class. """ - graph_def = _freeze_graph(sess, output_tensors) + graph_def = _freeze_graph(sess, input_tensors, output_tensors) return cls(graph_def, input_tensors, output_tensors) @classmethod @@ -484,7 +456,7 @@ class TFLiteConverter(object): output_tensors = keras_model.outputs _set_tensor_shapes(input_tensors, input_shapes) - graph_def = _freeze_graph(sess, output_tensors) + graph_def = _freeze_graph(sess, input_tensors, output_tensors) return cls(graph_def, input_tensors, output_tensors) def __setattr__(self, name, value): @@ -586,26 +558,26 @@ class TFLiteConverter(object): "dump_graphviz_video": self.dump_graphviz_video } - optimized_graph = None - if self.inference_type == constants.QUANTIZED_UINT8: - optimized_graph = self._graph_def - else: + # Run a Grappler pass if it is possible. + graph_def = self._graph_def + if self.inference_type != constants.QUANTIZED_UINT8: try: - optimized_graph = _run_graph_optimizations( + graph_def = _run_graph_optimizations( self._graph_def, self._input_tensors, self._output_tensors) - except Exception: - optimized_graph = self._graph_def + except (_OpError, ValueError): + print("Warning: Grappler optimization pass failed. " + "If this behavior is unexpected, please file a bug.") # Converts model. if self._has_valid_tensors(): result = _toco_convert_impl( - input_data=optimized_graph, + input_data=graph_def, input_tensors=self._input_tensors, output_tensors=self._output_tensors, **converter_kwargs) else: result = _toco_convert_graph_def( - input_data=optimized_graph, + input_data=graph_def, input_arrays_with_shape=self._input_arrays_with_shape, output_arrays=self._output_arrays, **converter_kwargs) @@ -710,6 +682,35 @@ class TocoConverter(object): input_shapes, output_arrays) +def _run_graph_optimizations(graph_def, input_arrays, output_arrays): + """Apply standard TensorFlow optimizations to the graph_def. + + Args: + graph_def: Frozen GraphDef to be optimized. + input_arrays: List of arrays that are considered inputs of the graph. + output_arrays: List of arrays that are considered outputs of the graph. + + Returns: + A new, optimized GraphDef. + """ + meta_graph = _export_meta_graph(graph_def=graph_def) + + # We need to add a collection called 'train_op' so that grappler + # knows what the outputs are. + fetch_collection = _meta_graph_pb2.CollectionDef() + for array in input_arrays + output_arrays: + fetch_collection.node_list.value.append(array.name) + meta_graph.collection_def["train_op"].CopyFrom(fetch_collection) + + config = _config_pb2.ConfigProto() + rewrite_options = config.graph_options.rewrite_options + rewrite_options.layout_optimizer = _rewriter_config_pb2.RewriterConfig.OFF + # Avoid remapping as it creates ops like _FusedConv2D, which are not + # supported by TF Lite. + rewrite_options.remapping = _rewriter_config_pb2.RewriterConfig.OFF + return _tf_optimizer.OptimizeGraph(config, meta_graph) + + def _is_frozen_graph(sess): """Determines if the graph is frozen. @@ -728,22 +729,28 @@ def _is_frozen_graph(sess): return True -def _freeze_graph(sess, output_tensors): +def _freeze_graph(sess, input_tensors, output_tensors): """Returns a frozen GraphDef. - Freezes a graph with Variables in it. Otherwise the existing GraphDef is - returned. + Runs a Grappler pass and freezes a graph with Variables in it. Otherwise the + existing GraphDef is returned. The Grappler pass is only run on models that + are frozen in order to inline the functions in the graph. Args: sess: TensorFlow Session. + input_tensors: List of input tensors. output_tensors: List of output tensors (only .name is used from this). Returns: Frozen GraphDef. """ + # Runs a Grappler pass in order to inline any functions in the graph. + graph_def = _run_graph_optimizations(sess.graph_def, input_tensors, + output_tensors) + if not _is_frozen_graph(sess): output_arrays = [_tensor_name(tensor) for tensor in output_tensors] return _tf_graph_util.convert_variables_to_constants( - sess, sess.graph_def, output_arrays) + sess, graph_def, output_arrays) else: return sess.graph_def diff --git a/tensorflow/lite/python/lite_test.py b/tensorflow/lite/python/lite_test.py index d41b7a75fd1..810de4be542 100644 --- a/tensorflow/lite/python/lite_test.py +++ b/tensorflow/lite/python/lite_test.py @@ -27,13 +27,16 @@ from tensorflow.lite.python import lite_constants from tensorflow.lite.python.interpreter import Interpreter from tensorflow.python import keras from tensorflow.python.client import session +from tensorflow.python.eager import def_function from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops from tensorflow.python.framework import test_util from tensorflow.python.ops import array_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import nn_ops from tensorflow.python.ops import variable_scope +from tensorflow.python.ops import variables from tensorflow.python.ops.variables import global_variables_initializer as _global_variables_initializer from tensorflow.python.platform import gfile from tensorflow.python.platform import resource_loader @@ -597,6 +600,48 @@ class FromSessionTest(test_util.TensorFlowTestCase): interpreter = Interpreter(model_content=tflite_model) interpreter.allocate_tensors() + def testFunctions(self): + """Tests tf.function in 1.X.""" + + @def_function.function + def plus_placeholder(x, placeholder): + return x + placeholder + + with ops.Graph().as_default(): + placeholder = array_ops.placeholder( + dtype=dtypes.float32, shape=[1], name='input') + variable_node = variables.Variable(1.0, name='variable_node') + defun_node = plus_placeholder(variable_node, placeholder) + output_node = math_ops.multiply(defun_node, 2.0, name='output_node') + + # Initialize variables in the model. + sess = session.Session() + sess.run(variables.variables_initializer([variable_node])) + + # Convert model and ensure model is not None. + converter = lite.TFLiteConverter.from_session(sess, [placeholder], + [output_node]) + tflite_model = converter.convert() + self.assertTrue(tflite_model) + + # Check values from converted model. + interpreter = Interpreter(model_content=tflite_model) + interpreter.allocate_tensors() + + input_details = interpreter.get_input_details() + self.assertEqual(1, len(input_details)) + self.assertEqual('input', input_details[0]['name']) + self.assertEqual(np.float32, input_details[0]['dtype']) + self.assertTrue(([1] == input_details[0]['shape']).all()) + self.assertEqual((0., 0.), input_details[0]['quantization']) + + output_details = interpreter.get_output_details() + self.assertEqual(1, len(output_details)) + self.assertEqual('output_node', output_details[0]['name']) + self.assertEqual(np.float32, output_details[0]['dtype']) + self.assertTrue(([1] == output_details[0]['shape']).all()) + self.assertEqual((0., 0.), output_details[0]['quantization']) + @test_util.run_v1_only('b/120545219') class FromFrozenGraphFile(test_util.TensorFlowTestCase): diff --git a/tensorflow/lite/testing/model_coverage/model_coverage_lib_test.py b/tensorflow/lite/testing/model_coverage/model_coverage_lib_test.py index 4e329ac97d7..d7dd4c43a34 100644 --- a/tensorflow/lite/testing/model_coverage/model_coverage_lib_test.py +++ b/tensorflow/lite/testing/model_coverage/model_coverage_lib_test.py @@ -26,8 +26,10 @@ from tensorflow.lite.python import lite from tensorflow.lite.testing.model_coverage import model_coverage_lib as model_coverage from tensorflow.python import keras from tensorflow.python.client import session +from tensorflow.python.eager import def_function from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops from tensorflow.python.framework import test_util from tensorflow.python.ops import array_ops from tensorflow.python.ops import math_ops @@ -70,6 +72,26 @@ class EvaluateFrozenGraph(test.TestCase): model_coverage.test_frozen_graph(filename, ['inputA', 'inputB'], ['add', 'Mean']) + @test_util.run_v1_only('b/120545219') + def testFunctions(self): + + @def_function.function + def plus_placeholder(x, placeholder): + return x + placeholder + + with ops.Graph().as_default(): + placeholder = array_ops.placeholder( + dtype=dtypes.float32, shape=[1], name='input') + variable_node = constant_op.constant(1.0, name='variable_node') + defun_node = plus_placeholder(variable_node, placeholder) + _ = math_ops.multiply(defun_node, 2.0, name='output_node') + + # Initialize variables in the model. + sess = session.Session() + + filename = self._saveFrozenGraph(sess) + model_coverage.test_frozen_graph(filename, ['input'], ['output_node']) + def _getQuantizedModel(self): np.random.seed(0) with session.Session().as_default() as sess: diff --git a/tensorflow/python/framework/graph_util_impl.py b/tensorflow/python/framework/graph_util_impl.py index a46fccc513c..50cdb7a15d1 100644 --- a/tensorflow/python/framework/graph_util_impl.py +++ b/tensorflow/python/framework/graph_util_impl.py @@ -248,6 +248,15 @@ def convert_variables_to_constants(sess, found_variables = {} variable_names = [] variable_dict_names = [] + identity_ops_input_map = {} + + def is_found_variable(input_tensor_name): + # Determines if the `input_tensor_name` is in `found_variables` or is an + # Identity op with an input that is in `found_variables`. + return ((input_tensor_name in found_variables) or + (input_tensor_name in identity_ops_input_map and + identity_ops_input_map[input_tensor_name] in found_variables)) + for node in inference_graph.node: if node.op in ["Variable", "VariableV2", "VarHandleOp"]: variable_name = node.name @@ -261,6 +270,9 @@ def convert_variables_to_constants(sess, variable_names.append(variable_name + "/Read/ReadVariableOp:0") else: variable_names.append(variable_name + ":0") + elif node.op == "Identity": + # Creates a map of Identity node names to the input names. + identity_ops_input_map[node.name] = node.input[0].split(":")[0] if variable_names: returned_variables = sess.run(variable_names) else: @@ -283,11 +295,15 @@ def convert_variables_to_constants(sess, tensor=tensor_util.make_tensor_proto( data, dtype=dtype.type, shape=data.shape))) how_many_converted += 1 - elif input_node.op == "ReadVariableOp" and ( - input_node.input[0] in found_variables): + elif (input_node.op == "ReadVariableOp" and + is_found_variable(input_node.input[0])): # The preceding branch converts all VarHandleOps of ResourceVariables to # constants, so we need to convert the associated ReadVariableOps to # Identity ops. + # + # Handles the following cases: + # Variable --> ReadVariableOp + # Variable --> Identity --> ReadVariableOp output_node.op = "Identity" output_node.name = input_node.name output_node.input.extend([input_node.input[0]]) diff --git a/tensorflow/python/framework/graph_util_test.py b/tensorflow/python/framework/graph_util_test.py index dd26b8a78e9..6802586ef65 100644 --- a/tensorflow/python/framework/graph_util_test.py +++ b/tensorflow/python/framework/graph_util_test.py @@ -217,16 +217,18 @@ class DeviceFunctionsTest(test.TestCase): self.assertNear(4.0, output, 0.00001) variable_graph_def = sess.graph.as_graph_def() - # First get the constant_graph_def when variable_names_whitelist is set, - # note that if variable_names_whitelist is not set an error will be - # thrown because unused_variable_node is not initialized. + # Get the constant_graph_def. constant_graph_def = graph_util.convert_variables_to_constants( - sess, - variable_graph_def, ["output_node"], - variable_names_whitelist=set(["variable_node"])) + sess, variable_graph_def, ["output_node"]) + # Ensure the library is copied and there are no variables after + # freezing. self.assertEqual(variable_graph_def.library, constant_graph_def.library) + for node in constant_graph_def.node: + self.assertNotIn( + node.op, + ["Variable", "VariableV2", "VarHandleOp", "ReadVariableOp"]) def testConvertVariablesToConsts(self): self._test_variable_to_const_conversion(use_resource=False)