From f2100b9b516b1ecdbac2259c81766796f504dd9b Mon Sep 17 00:00:00 2001
From: Jaesung Chung <jaesung@google.com>
Date: Thu, 9 Apr 2020 20:24:15 -0700
Subject: [PATCH] Disable grappler and freezing steps when MLIR SavedModel
 conversion path is on

PiperOrigin-RevId: 305814559
Change-Id: I04528dcfdab7560531bc7c594ee22b0e5061bb59
---
 tensorflow/lite/python/lite.py | 107 ++++++++++++++++++++-------------
 1 file changed, 64 insertions(+), 43 deletions(-)

diff --git a/tensorflow/lite/python/lite.py b/tensorflow/lite/python/lite.py
index fef8c9ce3cf..96f3428efe3 100644
--- a/tensorflow/lite/python/lite.py
+++ b/tensorflow/lite/python/lite.py
@@ -72,6 +72,7 @@ from tensorflow.python.framework.errors_impl import NotFoundError as _NotFoundEr
 from tensorflow.python.framework.importer import import_graph_def as _import_graph_def
 from tensorflow.python.keras.saving import saving_utils as _saving_utils
 from tensorflow.python.lib.io import file_io as _file_io
+from tensorflow.python.saved_model import loader_impl as _loader_impl
 from tensorflow.python.saved_model import signature_constants as _signature_constants
 from tensorflow.python.saved_model import tag_constants as _tag_constants
 from tensorflow.python.saved_model.load import load as _load
@@ -597,28 +598,47 @@ class TFLiteConverterV2(TFLiteConverterBase):
     self._parse_saved_model_args()
 
     # graph_def is used here to preserve the node bug information
-    frozen_func, graph_def = (
-        _convert_to_constants.convert_variables_to_constants_v2_as_graph(
-            self._funcs[0], lower_control_flow=False))
-    self._graph_def = graph_def
-    input_tensors = [
-        tensor for tensor in frozen_func.inputs
-        if tensor.dtype != _dtypes.resource
-    ]
-    output_tensors = frozen_func.outputs
+    if self._saved_model_dir:
+      graph = _ops.Graph()
+      saved_model = _loader_impl.SavedModelLoader(self._saved_model_dir)
+      saved_model.load_graph(graph, tags=self._saved_model_tags)
+      meta_graph = saved_model.get_meta_graph_def_from_tags(
+          self._saved_model_tags)
+      signature_def = meta_graph.signature_def[
+          _signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY]
+      input_tensors = [
+          graph.get_tensor_by_name(signature_def.inputs[key].name)
+          for key in signature_def.inputs
+      ]
+      output_tensors = [
+          graph.get_tensor_by_name(signature_def.outputs[key].name)
+          for key in signature_def.outputs
+      ]
+      self._graph_def = graph_def = meta_graph.graph_def
+    else:
+      frozen_func, graph_def = (
+          _convert_to_constants.convert_variables_to_constants_v2_as_graph(
+              self._funcs[0], lower_control_flow=False))
+      self._graph_def = graph_def
 
-    # Run a Grappler pass.
-    grappler_config = self._grappler_config()
-    # Skip running grappler when there are no optimizers to run. If not,
-    # grappler will run with the default optimizer set and it will lead to
-    # causing an unexpected behavior.
-    if grappler_config.graph_options.rewrite_options.optimizers:
-      graph_def = _run_graph_optimizations(
-          graph_def,
-          input_tensors,
-          output_tensors,
-          config=grappler_config,
-          graph=frozen_func.graph)
+      input_tensors = [
+          tensor for tensor in frozen_func.inputs
+          if tensor.dtype != _dtypes.resource
+      ]
+      output_tensors = frozen_func.outputs
+
+      # Run a Grappler pass.
+      grappler_config = self._grappler_config()
+      # Skip running grappler when there are no optimizers to run. If not,
+      # grappler will run with the default optimizer set and it will lead to
+      # causing an unexpected behavior.
+      if grappler_config.graph_options.rewrite_options.optimizers:
+        graph_def = _run_graph_optimizations(
+            graph_def,
+            input_tensors,
+            output_tensors,
+            config=grappler_config,
+            graph=frozen_func.graph)
 
     quant_mode = QuantizationMode(self.optimizations, self.target_spec,
                                   self.representative_dataset, graph_def)
@@ -1231,28 +1251,29 @@ class TFLiteConverter(TFLiteConverterBase):
           "are not enabled.")
 
     optimized_graph = self._graph_def
-    # if it is not uint8 or int8 with post-training quantization, it is not
-    # quantization aware training, then graph optimization is applied.
-    # Graph optimization is disabled for quantization aware training.
-    if (self.inference_type != constants.QUANTIZED_UINT8 or
-        (self.inference_type == constants.INT8 and
-         (post_training_optimize or weight_only_quantize))):
-      try:
-        # TODO(b/150163103): Merge `disabling lower using switch merge' calls.
-        # Grappler will also try to lower while loop into switch merge
-        # representation which is undesired for Ophints, so we simply remove
-        # those attributes to prevent Grappler from doing so.
-        graph_def = _convert_to_constants.disable_lower_using_switch_merge(
-            optimized_graph)
-        # Run function inlining optimization to ensure any models generated
-        # through the from_frozen_graph path have been inlined.
-        optimized_graph = _run_graph_optimizations(
-            graph_def,
-            self._input_tensors,
-            self._output_tensors,
-            config=self._grappler_config(["function"]))
-      except Exception:
-        optimized_graph = self._graph_def
+    if not self._saved_model_dir:
+      # if it is not uint8 or int8 with post-training quantization, it is not
+      # quantization aware training, then graph optimization is applied.
+      # Graph optimization is disabled for quantization aware training.
+      if (self.inference_type != constants.QUANTIZED_UINT8 or
+          (self.inference_type == constants.INT8 and
+           (post_training_optimize or weight_only_quantize))):
+        try:
+          # TODO(b/150163103): Merge `disabling lower using switch merge' calls.
+          # Grappler will also try to lower while loop into switch merge
+          # representation which is undesired for Ophints, so we simply remove
+          # those attributes to prevent Grappler from doing so.
+          graph_def = _convert_to_constants.disable_lower_using_switch_merge(
+              optimized_graph)
+          # Run function inlining optimization to ensure any models generated
+          # through the from_frozen_graph path have been inlined.
+          optimized_graph = _run_graph_optimizations(
+              graph_def,
+              self._input_tensors,
+              self._output_tensors,
+              config=self._grappler_config(["function"]))
+        except Exception:
+          optimized_graph = self._graph_def
 
     self._debug_info = _get_debug_info(self._debug_info_func, optimized_graph)