Enable MLIR saved model import by default in TFLiteConverterV2's saved model API

PiperOrigin-RevId: 311792366
Change-Id: I98356499c0a1eb7c740104ca4b11af5d45c4a4a1
This commit is contained in:
Karim Nosir 2020-05-15 13:34:58 -07:00 committed by TensorFlower Gardener
parent 31583920dc
commit 27ac446be5
1 changed files with 12 additions and 18 deletions

View File

@ -386,13 +386,8 @@ class TFLiteConverterBase(object):
return True return True
return False return False
def _parse_saved_model_args(self, always_enable_saved_model_import=False): def _parse_saved_model_args(self):
"""Parses SavedModel arguments from the given Keras/RNN SavedModel. """Parses SavedModel arguments from the given Keras/RNN SavedModel."""
Args:
always_enable_saved_model_import: Bool. When the value is true, it enables
MLIR saved model import path regardless of checking the conditions.
"""
if not self.experimental_new_converter: if not self.experimental_new_converter:
self.saved_model_dir = None self.saved_model_dir = None
return return
@ -405,17 +400,16 @@ class TFLiteConverterBase(object):
# frozen graph def path. # frozen graph def path.
self.saved_model_dir = None self.saved_model_dir = None
return return
if (not always_enable_saved_model_import and if not self._contains_function_with_implements_attr(saved_model_proto):
not self._contains_function_with_implements_attr(saved_model_proto)):
self.saved_model_dir = None self.saved_model_dir = None
return else:
if not self._saved_model_exported_names:
if not self._saved_model_exported_names: self._saved_model_exported_names = []
self._saved_model_exported_names = [] self._saved_model_version = saved_model_proto.saved_model_schema_version
self._saved_model_version = saved_model_proto.saved_model_schema_version if self._saved_model_version not in [1, 2]:
if self._saved_model_version not in [1, 2]: raise ValueError(
raise ValueError("SavedModel file format({0}) is not supported".format( "SavedModel file format({0}) is not supported".format(
self._saved_model_version)) self._saved_model_version))
class TFLiteConverterBaseV2(TFLiteConverterBase): class TFLiteConverterBaseV2(TFLiteConverterBase):
@ -548,7 +542,7 @@ class TFLiteSavedModelConverterV2(TFLiteConverterBaseV2):
self._saved_model_tags = saved_model_tags self._saved_model_tags = saved_model_tags
self._saved_model_exported_names = saved_model_exported_names self._saved_model_exported_names = saved_model_exported_names
self._trackable_obj = trackable_obj self._trackable_obj = trackable_obj
self._parse_saved_model_args(always_enable_saved_model_import=True) self._parse_saved_model_args()
def convert(self): def convert(self):
"""Converts a TensorFlow GraphDef based on instance variables. """Converts a TensorFlow GraphDef based on instance variables.