diff --git a/tensorflow/lite/python/lite.py b/tensorflow/lite/python/lite.py index 99be58f4376..ce59c56a1d0 100644 --- a/tensorflow/lite/python/lite.py +++ b/tensorflow/lite/python/lite.py @@ -386,13 +386,8 @@ class TFLiteConverterBase(object): return True return False - def _parse_saved_model_args(self, always_enable_saved_model_import=False): - """Parses SavedModel arguments from the given Keras/RNN SavedModel. - - Args: - always_enable_saved_model_import: Bool. When the value is true, it enables - MLIR saved model import path regardless of checking the conditions. - """ + def _parse_saved_model_args(self): + """Parses SavedModel arguments from the given Keras/RNN SavedModel.""" if not self.experimental_new_converter: self.saved_model_dir = None return @@ -405,17 +400,16 @@ class TFLiteConverterBase(object): # frozen graph def path. self.saved_model_dir = None return - if (not always_enable_saved_model_import and - not self._contains_function_with_implements_attr(saved_model_proto)): + if not self._contains_function_with_implements_attr(saved_model_proto): self.saved_model_dir = None - return - - if not self._saved_model_exported_names: - self._saved_model_exported_names = [] - self._saved_model_version = saved_model_proto.saved_model_schema_version - if self._saved_model_version not in [1, 2]: - raise ValueError("SavedModel file format({0}) is not supported".format( - self._saved_model_version)) + else: + if not self._saved_model_exported_names: + self._saved_model_exported_names = [] + self._saved_model_version = saved_model_proto.saved_model_schema_version + if self._saved_model_version not in [1, 2]: + raise ValueError( + "SavedModel file format({0}) is not supported".format( + self._saved_model_version)) class TFLiteConverterBaseV2(TFLiteConverterBase): @@ -548,7 +542,7 @@ class TFLiteSavedModelConverterV2(TFLiteConverterBaseV2): self._saved_model_tags = saved_model_tags self._saved_model_exported_names = saved_model_exported_names self._trackable_obj = trackable_obj - self._parse_saved_model_args(always_enable_saved_model_import=True) + self._parse_saved_model_args() def convert(self): """Converts a TensorFlow GraphDef based on instance variables.