Only run Grappler optimizations when necessary in TFLiteConverter.
PiperOrigin-RevId: 246023134
This commit is contained in:
parent
7e025006dd
commit
1c6d02d8b8
@ -157,7 +157,12 @@ class TFLiteConverterBase(object):
|
||||
|
||||
def _grappler_config(self, target_ops):
|
||||
is_only_flex_enabled = set([OpsSet.SELECT_TF_OPS]) == target_ops
|
||||
return _get_grappler_config(enable_layout_optimizer=is_only_flex_enabled)
|
||||
if is_only_flex_enabled:
|
||||
# The layout optimizer turns NHCW to NCHW. This provides performance
|
||||
# optimizations when Flex mode is enabled. However, this is not compatible
|
||||
# with builtin ops.
|
||||
return _get_grappler_config(["layout"])
|
||||
return None
|
||||
|
||||
def _validate_representative_dataset(self):
|
||||
if self.representative_dataset:
|
||||
@ -339,12 +344,15 @@ class TFLiteConverterV2(TFLiteConverterBase):
|
||||
output_tensors = frozen_func.outputs
|
||||
|
||||
# Run a Grappler pass.
|
||||
graph_def = _run_graph_optimizations(
|
||||
frozen_func.graph.as_graph_def(),
|
||||
input_tensors,
|
||||
output_tensors,
|
||||
self._grappler_config(self.target_spec.supported_ops),
|
||||
graph=frozen_func.graph)
|
||||
graph_def = frozen_func.graph.as_graph_def()
|
||||
config = self._grappler_config(self.target_spec.supported_ops)
|
||||
if config:
|
||||
graph_def = _run_graph_optimizations(
|
||||
graph_def,
|
||||
input_tensors,
|
||||
output_tensors,
|
||||
config,
|
||||
graph=frozen_func.graph)
|
||||
|
||||
# Checks dimensions in input tensor.
|
||||
for tensor in input_tensors:
|
||||
@ -862,14 +870,15 @@ class TFLiteConverter(TFLiteConverterBase):
|
||||
"dump_graphviz_video": self.dump_graphviz_video
|
||||
}
|
||||
|
||||
optimized_graph = None
|
||||
if self.inference_type == constants.QUANTIZED_UINT8:
|
||||
optimized_graph = self._graph_def
|
||||
else:
|
||||
optimized_graph = self._graph_def
|
||||
if self.inference_type != constants.QUANTIZED_UINT8:
|
||||
try:
|
||||
optimized_graph = _run_graph_optimizations(
|
||||
self._graph_def, self._input_tensors, self._output_tensors,
|
||||
self._grappler_config(self.target_ops))
|
||||
config = self._grappler_config(self.target_ops)
|
||||
if config:
|
||||
optimized_graph = _run_graph_optimizations(self._graph_def,
|
||||
self._input_tensors,
|
||||
self._output_tensors,
|
||||
config)
|
||||
except Exception:
|
||||
optimized_graph = self._graph_def
|
||||
|
||||
|
@ -20,7 +20,6 @@ from __future__ import print_function
|
||||
|
||||
from tensorflow.core.protobuf import config_pb2 as _config_pb2
|
||||
from tensorflow.core.protobuf import meta_graph_pb2 as _meta_graph_pb2
|
||||
from tensorflow.core.protobuf import rewriter_config_pb2 as _rewriter_config_pb2
|
||||
from tensorflow.lite.python.op_hint import convert_op_hints_to_stubs
|
||||
from tensorflow.lite.python.op_hint import find_all_hinted_output_nodes
|
||||
from tensorflow.lite.toco import types_pb2 as _types_pb2
|
||||
@ -148,33 +147,19 @@ def set_tensor_shapes(tensors, shapes):
|
||||
raise ValueError(message)
|
||||
|
||||
|
||||
def get_grappler_config(enable_layout_optimizer=False, function_only=False):
|
||||
def get_grappler_config(optimizers_list):
|
||||
"""Creates a tf.compat.v1.ConfigProto for configuring Grappler.
|
||||
|
||||
Args:
|
||||
enable_layout_optimizer: Bool indicating whether to run the layout
|
||||
optimizer. This turns NHCW to NCHW. This provides performance
|
||||
optimizations when Flex mode is enabled. (default False)
|
||||
function_only: Bool indiciating whether to only run the function optimizer.
|
||||
This inlines functions and is required for freezing models with functions.
|
||||
(default False)
|
||||
optimizers_list: List of strings that represents the list of optimizers.
|
||||
|
||||
Returns:
|
||||
tf.ConfigProto.
|
||||
"""
|
||||
config = _config_pb2.ConfigProto()
|
||||
rewrite_options = config.graph_options.rewrite_options
|
||||
if function_only:
|
||||
rewrite_options.optimizers.append("function")
|
||||
else:
|
||||
if enable_layout_optimizer:
|
||||
rewrite_options.layout_optimizer = _rewriter_config_pb2.RewriterConfig.ON
|
||||
else:
|
||||
rewrite_options.layout_optimizer = _rewriter_config_pb2.RewriterConfig.OFF
|
||||
|
||||
# Avoid remapping as it creates ops like _FusedConv2D, which are not
|
||||
# supported by TFLite.
|
||||
rewrite_options.remapping = _rewriter_config_pb2.RewriterConfig.OFF
|
||||
for optimizer in optimizers_list:
|
||||
rewrite_options.optimizers.append(optimizer)
|
||||
return config
|
||||
|
||||
|
||||
@ -242,7 +227,7 @@ def freeze_graph(sess, input_tensors, output_tensors):
|
||||
return _convert_op_hints_if_present(sess, output_tensors)
|
||||
|
||||
# Runs a Grappler pass in order to inline any functions in the graph.
|
||||
config = get_grappler_config(function_only=True)
|
||||
config = get_grappler_config(["function"])
|
||||
graph_def = run_graph_optimizations(
|
||||
sess.graph_def, input_tensors, output_tensors, config, graph=sess.graph)
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user