Enable MLIR saved model import by default in TFLiteConverterV2's saved model API

PiperOrigin-RevId: 311792366
Change-Id: I98356499c0a1eb7c740104ca4b11af5d45c4a4a1
Karim Nosir 2020-05-15 13:34:58 -07:00 committed by TensorFlower Gardener
parent 31583920dc
commit 27ac446be5
1 changed file with 12 additions and 18 deletions
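
For context, the code path touched by this change is the SavedModel entry point of the TF 2.x converter API. The snippet below is a minimal usage sketch, not part of the commit: it assumes a TensorFlow build from around this revision, uses a hypothetical /tmp/example_saved_model directory, and relies only on public APIs (tf.saved_model.save, tf.lite.TFLiteConverter.from_saved_model, experimental_new_converter). Whether the MLIR SavedModel importer is actually taken additionally depends on the checks in _parse_saved_model_args shown in the diff below.

import tensorflow as tf

# Export a trivial Keras model as a SavedModel so that the converter's
# SavedModel import path (_parse_saved_model_args) is exercised.
# "/tmp/example_saved_model" is a hypothetical path used for illustration.
model = tf.keras.Sequential([tf.keras.layers.Dense(1, input_shape=(4,))])
tf.saved_model.save(model, "/tmp/example_saved_model")

# from_saved_model() builds a TFLiteConverterV2 for the SavedModel directory;
# experimental_new_converter selects the MLIR-based conversion pipeline.
converter = tf.lite.TFLiteConverter.from_saved_model("/tmp/example_saved_model")
converter.experimental_new_converter = True
tflite_model = converter.convert()

with open("/tmp/example_model.tflite", "wb") as f:
  f.write(tflite_model)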

@@ -386,13 +386,8 @@ class TFLiteConverterBase(object):
         return True
     return False
 
-  def _parse_saved_model_args(self, always_enable_saved_model_import=False):
-    """Parses SavedModel arguments from the given Keras/RNN SavedModel.
-
-    Args:
-      always_enable_saved_model_import: Bool. When the value is true, it enables
-        MLIR saved model import path regardless of checking the conditions.
-    """
+  def _parse_saved_model_args(self):
+    """Parses SavedModel arguments from the given Keras/RNN SavedModel."""
     if not self.experimental_new_converter:
       self.saved_model_dir = None
       return
@@ -405,16 +400,15 @@ class TFLiteConverterBase(object):
         # frozen graph def path.
         self.saved_model_dir = None
         return
-      if (not always_enable_saved_model_import and
-          not self._contains_function_with_implements_attr(saved_model_proto)):
+      if not self._contains_function_with_implements_attr(saved_model_proto):
         self.saved_model_dir = None
         return
       else:
         if not self._saved_model_exported_names:
           self._saved_model_exported_names = []
         self._saved_model_version = saved_model_proto.saved_model_schema_version
         if self._saved_model_version not in [1, 2]:
-          raise ValueError("SavedModel file format({0}) is not supported".format(
-              self._saved_model_version))
+          raise ValueError(
+              "SavedModel file format({0}) is not supported".format(
+                  self._saved_model_version))
@@ -548,7 +542,7 @@ class TFLiteSavedModelConverterV2(TFLiteConverterBaseV2):
     self._saved_model_tags = saved_model_tags
     self._saved_model_exported_names = saved_model_exported_names
     self._trackable_obj = trackable_obj
-    self._parse_saved_model_args(always_enable_saved_model_import=True)
+    self._parse_saved_model_args()
 
   def convert(self):
     """Converts a TensorFlow GraphDef based on instance variables.