From 27ac446be5b10ee68900696a2c5184fce727e86d Mon Sep 17 00:00:00 2001
From: Karim Nosir
Date: Fri, 15 May 2020 13:34:58 -0700
Subject: [PATCH] Enable MLIR saved model import by default in TFLiteConverterV2's saved model API

PiperOrigin-RevId: 311792366
Change-Id: I98356499c0a1eb7c740104ca4b11af5d45c4a4a1
---
 tensorflow/lite/python/lite.py | 30 ++++++++++++------------------
 1 file changed, 12 insertions(+), 18 deletions(-)

diff --git a/tensorflow/lite/python/lite.py b/tensorflow/lite/python/lite.py
index 99be58f4376..ce59c56a1d0 100644
--- a/tensorflow/lite/python/lite.py
+++ b/tensorflow/lite/python/lite.py
@@ -386,13 +386,8 @@ class TFLiteConverterBase(object):
         return True
     return False
 
-  def _parse_saved_model_args(self, always_enable_saved_model_import=False):
-    """Parses SavedModel arguments from the given Keras/RNN SavedModel.
-
-    Args:
-      always_enable_saved_model_import: Bool. When the value is true, it enables
-        MLIR saved model import path regardless of checking the conditions.
-    """
+  def _parse_saved_model_args(self):
+    """Parses SavedModel arguments from the given Keras/RNN SavedModel."""
     if not self.experimental_new_converter:
       self.saved_model_dir = None
       return
@@ -405,17 +400,16 @@ class TFLiteConverterBase(object):
       # frozen graph def path.
       self.saved_model_dir = None
       return
-    if (not always_enable_saved_model_import and
-        not self._contains_function_with_implements_attr(saved_model_proto)):
+    if not self._contains_function_with_implements_attr(saved_model_proto):
       self.saved_model_dir = None
-      return
-
-    if not self._saved_model_exported_names:
-      self._saved_model_exported_names = []
-    self._saved_model_version = saved_model_proto.saved_model_schema_version
-    if self._saved_model_version not in [1, 2]:
-      raise ValueError("SavedModel file format({0}) is not supported".format(
-          self._saved_model_version))
+    else:
+      if not self._saved_model_exported_names:
+        self._saved_model_exported_names = []
+      self._saved_model_version = saved_model_proto.saved_model_schema_version
+      if self._saved_model_version not in [1, 2]:
+        raise ValueError(
+            "SavedModel file format({0}) is not supported".format(
+                self._saved_model_version))
 
 
 class TFLiteConverterBaseV2(TFLiteConverterBase):
@@ -548,7 +542,7 @@ class TFLiteSavedModelConverterV2(TFLiteConverterBaseV2):
     self._saved_model_tags = saved_model_tags
     self._saved_model_exported_names = saved_model_exported_names
     self._trackable_obj = trackable_obj
-    self._parse_saved_model_args(always_enable_saved_model_import=True)
+    self._parse_saved_model_args()
 
   def convert(self):
     """Converts a TensorFlow GraphDef based on instance variables.