Disable grappler remapping, to avoid introducing unsupported ops. Also: avoid removing the input nodes of a graph.

PiperOrigin-RevId: 226712843
This commit is contained in:
A. Unique TensorFlower 2018-12-23 18:03:13 -08:00 committed by TensorFlower Gardener
parent 6e304fadf8
commit 9f23c56c7e

View File

@ -71,11 +71,12 @@ from tensorflow.python.util import deprecation as _deprecation
from tensorflow.python.util.tf_export import tf_export as _tf_export
def _run_graph_optimizations(graph_def, output_arrays):
def _run_graph_optimizations(graph_def, input_arrays, output_arrays):
"""Apply standard TensorFlow optimizations to the graph_def.
Args:
graph_def: Frozen GraphDef to be optimized.
input_arrays: List of arrays that are considered inputs of the graph.
output_arrays: List of arrays that are considered outputs of the graph.
Returns:
@ -86,13 +87,16 @@ def _run_graph_optimizations(graph_def, output_arrays):
# We need to add a collection called 'train_op' so that grappler
# knows what the outputs are.
fetch_collection = _meta_graph_pb2.CollectionDef()
for output in output_arrays:
fetch_collection.node_list.value.append(output)
for array in input_arrays + output_arrays:
fetch_collection.node_list.value.append(array.name)
meta_graph.collection_def["train_op"].CopyFrom(fetch_collection)
config = _config_pb2.ConfigProto()
rewrite_options = config.graph_options.rewrite_options
rewrite_options.layout_optimizer = _rewriter_config_pb2.RewriterConfig.ON
# Avoid remapping as it creates ops like _FusedConv2D, which are not
# supported by TF Lite.
rewrite_options.remapping = _rewriter_config_pb2.RewriterConfig.OFF
return _tf_optimizer.OptimizeGraph(config, meta_graph)
@ -482,7 +486,7 @@ class TFLiteConverter(object):
else:
try:
optimized_graph = _run_graph_optimizations(
self._graph_def, [t.name for t in self._output_tensors])
self._graph_def, self._input_tensors, self._output_tensors)
except Exception:
optimized_graph = self._graph_def