Only run Grappler optimizations when necessary in TFLiteConverter.

PiperOrigin-RevId: 246023134
This commit is contained in:
Nupur Garg 2019-04-30 14:35:21 -07:00 committed by TensorFlower Gardener
parent 7e025006dd
commit 1c6d02d8b8
2 changed files with 28 additions and 34 deletions

View File

@ -157,7 +157,12 @@ class TFLiteConverterBase(object):
def _grappler_config(self, target_ops):
is_only_flex_enabled = set([OpsSet.SELECT_TF_OPS]) == target_ops
return _get_grappler_config(enable_layout_optimizer=is_only_flex_enabled)
if is_only_flex_enabled:
# The layout optimizer turns NHCW to NCHW. This provides performance
# optimizations when Flex mode is enabled. However, this is not compatible
# with builtin ops.
return _get_grappler_config(["layout"])
return None
def _validate_representative_dataset(self):
if self.representative_dataset:
@ -339,12 +344,15 @@ class TFLiteConverterV2(TFLiteConverterBase):
output_tensors = frozen_func.outputs
# Run a Grappler pass.
graph_def = _run_graph_optimizations(
frozen_func.graph.as_graph_def(),
input_tensors,
output_tensors,
self._grappler_config(self.target_spec.supported_ops),
graph=frozen_func.graph)
graph_def = frozen_func.graph.as_graph_def()
config = self._grappler_config(self.target_spec.supported_ops)
if config:
graph_def = _run_graph_optimizations(
graph_def,
input_tensors,
output_tensors,
config,
graph=frozen_func.graph)
# Checks dimensions in input tensor.
for tensor in input_tensors:
@ -862,14 +870,15 @@ class TFLiteConverter(TFLiteConverterBase):
"dump_graphviz_video": self.dump_graphviz_video
}
optimized_graph = None
if self.inference_type == constants.QUANTIZED_UINT8:
optimized_graph = self._graph_def
else:
optimized_graph = self._graph_def
if self.inference_type != constants.QUANTIZED_UINT8:
try:
optimized_graph = _run_graph_optimizations(
self._graph_def, self._input_tensors, self._output_tensors,
self._grappler_config(self.target_ops))
config = self._grappler_config(self.target_ops)
if config:
optimized_graph = _run_graph_optimizations(self._graph_def,
self._input_tensors,
self._output_tensors,
config)
except Exception:
optimized_graph = self._graph_def

View File

@ -20,7 +20,6 @@ from __future__ import print_function
from tensorflow.core.protobuf import config_pb2 as _config_pb2
from tensorflow.core.protobuf import meta_graph_pb2 as _meta_graph_pb2
from tensorflow.core.protobuf import rewriter_config_pb2 as _rewriter_config_pb2
from tensorflow.lite.python.op_hint import convert_op_hints_to_stubs
from tensorflow.lite.python.op_hint import find_all_hinted_output_nodes
from tensorflow.lite.toco import types_pb2 as _types_pb2
@ -148,33 +147,19 @@ def set_tensor_shapes(tensors, shapes):
raise ValueError(message)
def get_grappler_config(enable_layout_optimizer=False, function_only=False):
def get_grappler_config(optimizers_list):
"""Creates a tf.compat.v1.ConfigProto for configuring Grappler.
Args:
enable_layout_optimizer: Bool indicating whether to run the layout
optimizer. This turns NHCW to NCHW. This provides performance
optimizations when Flex mode is enabled. (default False)
function_only: Bool indiciating whether to only run the function optimizer.
This inlines functions and is required for freezing models with functions.
(default False)
optimizers_list: List of strings that represents the list of optimizers.
Returns:
tf.ConfigProto.
"""
config = _config_pb2.ConfigProto()
rewrite_options = config.graph_options.rewrite_options
if function_only:
rewrite_options.optimizers.append("function")
else:
if enable_layout_optimizer:
rewrite_options.layout_optimizer = _rewriter_config_pb2.RewriterConfig.ON
else:
rewrite_options.layout_optimizer = _rewriter_config_pb2.RewriterConfig.OFF
# Avoid remapping as it creates ops like _FusedConv2D, which are not
# supported by TFLite.
rewrite_options.remapping = _rewriter_config_pb2.RewriterConfig.OFF
for optimizer in optimizers_list:
rewrite_options.optimizers.append(optimizer)
return config
@ -242,7 +227,7 @@ def freeze_graph(sess, input_tensors, output_tensors):
return _convert_op_hints_if_present(sess, output_tensors)
# Runs a Grappler pass in order to inline any functions in the graph.
config = get_grappler_config(function_only=True)
config = get_grappler_config(["function"])
graph_def = run_graph_optimizations(
sess.graph_def, input_tensors, output_tensors, config, graph=sess.graph)