From f2100b9b516b1ecdbac2259c81766796f504dd9b Mon Sep 17 00:00:00 2001
From: Jaesung Chung <jaesung@google.com>
Date: Thu, 9 Apr 2020 20:24:15 -0700
Subject: [PATCH] Disable grappler and freezing steps when MLIR SavedModel
 conversion path is on

PiperOrigin-RevId: 305814559
Change-Id: I04528dcfdab7560531bc7c594ee22b0e5061bb59
---
 tensorflow/lite/python/lite.py | 107 ++++++++++++++++++++-------------
 1 file changed, 64 insertions(+), 43 deletions(-)

diff --git a/tensorflow/lite/python/lite.py b/tensorflow/lite/python/lite.py
index fef8c9ce3cf..96f3428efe3 100644
--- a/tensorflow/lite/python/lite.py
+++ b/tensorflow/lite/python/lite.py
@@ -72,6 +72,7 @@ from tensorflow.python.framework.errors_impl import NotFoundError as _NotFoundEr
 from tensorflow.python.framework.importer import import_graph_def as _import_graph_def
 from tensorflow.python.keras.saving import saving_utils as _saving_utils
 from tensorflow.python.lib.io import file_io as _file_io
+from tensorflow.python.saved_model import loader_impl as _loader_impl
 from tensorflow.python.saved_model import signature_constants as _signature_constants
 from tensorflow.python.saved_model import tag_constants as _tag_constants
 from tensorflow.python.saved_model.load import load as _load
@@ -597,28 +598,47 @@ class TFLiteConverterV2(TFLiteConverterBase):
     self._parse_saved_model_args()
 
     # graph_def is used here to preserve the node bug information
-    frozen_func, graph_def = (
-        _convert_to_constants.convert_variables_to_constants_v2_as_graph(
-            self._funcs[0], lower_control_flow=False))
-    self._graph_def = graph_def
-    input_tensors = [
-        tensor for tensor in frozen_func.inputs
-        if tensor.dtype != _dtypes.resource
-    ]
-    output_tensors = frozen_func.outputs
+    if self._saved_model_dir:
+      graph = _ops.Graph()
+      saved_model = _loader_impl.SavedModelLoader(self._saved_model_dir)
+      saved_model.load_graph(graph, tags=self._saved_model_tags)
+      meta_graph = saved_model.get_meta_graph_def_from_tags(
+          self._saved_model_tags)
+      signature_def = meta_graph.signature_def[
+          _signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY]
+      input_tensors = [
+          graph.get_tensor_by_name(signature_def.inputs[key].name)
+          for key in signature_def.inputs
+      ]
+      output_tensors = [
+          graph.get_tensor_by_name(signature_def.outputs[key].name)
+          for key in signature_def.outputs
+      ]
+      self._graph_def = graph_def = meta_graph.graph_def
+    else:
+      frozen_func, graph_def = (
+          _convert_to_constants.convert_variables_to_constants_v2_as_graph(
+              self._funcs[0], lower_control_flow=False))
+      self._graph_def = graph_def
 
-    # Run a Grappler pass.
-    grappler_config = self._grappler_config()
-    # Skip running grappler when there are no optimizers to run. If not,
-    # grappler will run with the default optimizer set and it will lead to
-    # causing an unexpected behavior.
-    if grappler_config.graph_options.rewrite_options.optimizers:
-      graph_def = _run_graph_optimizations(
-          graph_def,
-          input_tensors,
-          output_tensors,
-          config=grappler_config,
-          graph=frozen_func.graph)
+      input_tensors = [
+          tensor for tensor in frozen_func.inputs
+          if tensor.dtype != _dtypes.resource
+      ]
+      output_tensors = frozen_func.outputs
+
+      # Run a Grappler pass.
+      grappler_config = self._grappler_config()
+      # Skip running grappler when there are no optimizers to run. Otherwise,
+      # grappler will run with the default optimizer set, which leads to
+      # unexpected behavior.
+      if grappler_config.graph_options.rewrite_options.optimizers:
+        graph_def = _run_graph_optimizations(
+            graph_def,
+            input_tensors,
+            output_tensors,
+            config=grappler_config,
+            graph=frozen_func.graph)
 
     quant_mode = QuantizationMode(self.optimizations, self.target_spec,
                                   self.representative_dataset, graph_def)
@@ -1231,28 +1251,29 @@ class TFLiteConverter(TFLiteConverterBase):
               "are not enabled.")
 
     optimized_graph = self._graph_def
-    # if it is not uint8 or int8 with post-training quantization, it is not
-    # quantization aware training, then graph optimization is applied.
-    # Graph optimization is disabled for quantization aware training.
-    if (self.inference_type != constants.QUANTIZED_UINT8 or
-        (self.inference_type == constants.INT8 and
-         (post_training_optimize or weight_only_quantize))):
-      try:
-        # TODO(b/150163103): Merge `disabling lower using switch merge' calls.
-        # Grappler will also try to lower while loop into switch merge
-        # representation which is undesired for Ophints, so we simply remove
-        # those attributes to prevent Grappler from doing so.
-        graph_def = _convert_to_constants.disable_lower_using_switch_merge(
-            optimized_graph)
-        # Run function inlining optimization to ensure any models generated
-        # through the from_frozen_graph path have been inlined.
-        optimized_graph = _run_graph_optimizations(
-            graph_def,
-            self._input_tensors,
-            self._output_tensors,
-            config=self._grappler_config(["function"]))
-      except Exception:
-        optimized_graph = self._graph_def
+    if not self._saved_model_dir:
+      # If the inference type is not uint8, or is int8 with post-training
+      # quantization, this is not quantization-aware training and graph
+      # optimization is applied. It is disabled for quantization-aware training.
+      if (self.inference_type != constants.QUANTIZED_UINT8 or
+          (self.inference_type == constants.INT8 and
+           (post_training_optimize or weight_only_quantize))):
+        try:
+          # TODO(b/150163103): Merge 'disabling lower using switch merge' calls.
+          # Grappler will also try to lower while loops into the switch-merge
+          # representation, which is undesired for Ophints, so we simply remove
+          # those attributes to prevent Grappler from doing so.
+          graph_def = _convert_to_constants.disable_lower_using_switch_merge(
+              optimized_graph)
+          # Run function inlining optimization to ensure any models generated
+          # through the from_frozen_graph path have been inlined.
+          optimized_graph = _run_graph_optimizations(
+              graph_def,
+              self._input_tensors,
+              self._output_tensors,
+              config=self._grappler_config(["function"]))
+        except Exception:
+          optimized_graph = self._graph_def
 
     self._debug_info = _get_debug_info(self._debug_info_func,
                                        optimized_graph)
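Reviewer note: for anyone who wants to exercise the new SavedModel branch in
isolation, below is a minimal sketch, not part of the patch. The model path is
hypothetical and must point at a SavedModel exporting the default serving
signature; the loader and signature calls mirror the lines added above, and
tf.lite.TFLiteConverter.from_saved_model is the public entry point that
populates self._saved_model_dir (at this revision the MLIR converter is
opt-in via experimental_new_converter).

    import tensorflow as tf
    from tensorflow.python.framework import ops
    from tensorflow.python.saved_model import loader_impl
    from tensorflow.python.saved_model import signature_constants
    from tensorflow.python.saved_model import tag_constants

    saved_model_dir = "/tmp/my_model"  # hypothetical path to a SavedModel
    tags = set([tag_constants.SERVING])

    # Mirror the new branch: read input/output tensors straight from the
    # MetaGraphDef signature instead of freezing a concrete function, so no
    # variable freezing and no grappler pass happen here.
    graph = ops.Graph()
    loader = loader_impl.SavedModelLoader(saved_model_dir)
    loader.load_graph(graph, tags=tags)
    meta_graph = loader.get_meta_graph_def_from_tags(tags)
    signature_def = meta_graph.signature_def[
        signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY]
    input_tensors = [
        graph.get_tensor_by_name(signature_def.inputs[key].name)
        for key in signature_def.inputs
    ]
    output_tensors = [
        graph.get_tensor_by_name(signature_def.outputs[key].name)
        for key in signature_def.outputs
    ]
    print([t.name for t in input_tensors], [t.name for t in output_tensors])

    # Public API path that records self._saved_model_dir and therefore takes
    # the new branch when the MLIR converter is enabled.
    converter = tf.lite.TFLiteConverter.from_saved_model(saved_model_dir)
    converter.experimental_new_converter = True  # MLIR conversion path
    tflite_model = converter.convert()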