diff --git a/tensorflow/lite/python/lite.py b/tensorflow/lite/python/lite.py index 0f9b1312c71..5ffa8c426b9 100644 --- a/tensorflow/lite/python/lite.py +++ b/tensorflow/lite/python/lite.py @@ -71,11 +71,12 @@ from tensorflow.python.util import deprecation as _deprecation from tensorflow.python.util.tf_export import tf_export as _tf_export -def _run_graph_optimizations(graph_def, output_arrays): +def _run_graph_optimizations(graph_def, input_arrays, output_arrays): """Apply standard TensorFlow optimizations to the graph_def. Args: graph_def: Frozen GraphDef to be optimized. + input_arrays: List of arrays that are considered inputs of the graph. output_arrays: List of arrays that are considered outputs of the graph. Returns: @@ -86,13 +87,16 @@ def _run_graph_optimizations(graph_def, output_arrays): # We need to add a collection called 'train_op' so that grappler # knows what the outputs are. fetch_collection = _meta_graph_pb2.CollectionDef() - for output in output_arrays: - fetch_collection.node_list.value.append(output) + for array in input_arrays + output_arrays: + fetch_collection.node_list.value.append(array.name) meta_graph.collection_def["train_op"].CopyFrom(fetch_collection) config = _config_pb2.ConfigProto() rewrite_options = config.graph_options.rewrite_options rewrite_options.layout_optimizer = _rewriter_config_pb2.RewriterConfig.ON + # Avoid remapping as it creates ops like _FusedConv2D, which are not + # supported by TF Lite. + rewrite_options.remapping = _rewriter_config_pb2.RewriterConfig.OFF return _tf_optimizer.OptimizeGraph(config, meta_graph) @@ -482,7 +486,7 @@ class TFLiteConverter(object): else: try: optimized_graph = _run_graph_optimizations( - self._graph_def, [t.name for t in self._output_tensors]) + self._graph_def, self._input_tensors, self._output_tensors) except Exception: optimized_graph = self._graph_def