Disable grappler remapping, to avoid introducing unsupported ops. Also: avoid removing the input nodes of a graph.
PiperOrigin-RevId: 226712843
This commit is contained in:
parent
6e304fadf8
commit
9f23c56c7e
@ -71,11 +71,12 @@ from tensorflow.python.util import deprecation as _deprecation
|
||||
from tensorflow.python.util.tf_export import tf_export as _tf_export
|
||||
|
||||
|
||||
def _run_graph_optimizations(graph_def, output_arrays):
|
||||
def _run_graph_optimizations(graph_def, input_arrays, output_arrays):
|
||||
"""Apply standard TensorFlow optimizations to the graph_def.
|
||||
|
||||
Args:
|
||||
graph_def: Frozen GraphDef to be optimized.
|
||||
input_arrays: List of arrays that are considered inputs of the graph.
|
||||
output_arrays: List of arrays that are considered outputs of the graph.
|
||||
|
||||
Returns:
|
||||
@ -86,13 +87,16 @@ def _run_graph_optimizations(graph_def, output_arrays):
|
||||
# We need to add a collection called 'train_op' so that grappler
|
||||
# knows what the outputs are.
|
||||
fetch_collection = _meta_graph_pb2.CollectionDef()
|
||||
for output in output_arrays:
|
||||
fetch_collection.node_list.value.append(output)
|
||||
for array in input_arrays + output_arrays:
|
||||
fetch_collection.node_list.value.append(array.name)
|
||||
meta_graph.collection_def["train_op"].CopyFrom(fetch_collection)
|
||||
|
||||
config = _config_pb2.ConfigProto()
|
||||
rewrite_options = config.graph_options.rewrite_options
|
||||
rewrite_options.layout_optimizer = _rewriter_config_pb2.RewriterConfig.ON
|
||||
# Avoid remapping as it creates ops like _FusedConv2D, which are not
|
||||
# supported by TF Lite.
|
||||
rewrite_options.remapping = _rewriter_config_pb2.RewriterConfig.OFF
|
||||
return _tf_optimizer.OptimizeGraph(config, meta_graph)
|
||||
|
||||
|
||||
@ -482,7 +486,7 @@ class TFLiteConverter(object):
|
||||
else:
|
||||
try:
|
||||
optimized_graph = _run_graph_optimizations(
|
||||
self._graph_def, [t.name for t in self._output_tensors])
|
||||
self._graph_def, self._input_tensors, self._output_tensors)
|
||||
except Exception:
|
||||
optimized_graph = self._graph_def
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user