Enable MLIR saved model import by default in TFLiteConverterV2's saved model API
PiperOrigin-RevId: 311792366 Change-Id: I98356499c0a1eb7c740104ca4b11af5d45c4a4a1
This commit is contained in:
parent
31583920dc
commit
27ac446be5
|
@ -386,13 +386,8 @@ class TFLiteConverterBase(object):
|
|||
return True
|
||||
return False
|
||||
|
||||
def _parse_saved_model_args(self, always_enable_saved_model_import=False):
|
||||
"""Parses SavedModel arguments from the given Keras/RNN SavedModel.
|
||||
|
||||
Args:
|
||||
always_enable_saved_model_import: Bool. When the value is true, it enables
|
||||
MLIR saved model import path regardless of checking the conditions.
|
||||
"""
|
||||
def _parse_saved_model_args(self):
|
||||
"""Parses SavedModel arguments from the given Keras/RNN SavedModel."""
|
||||
if not self.experimental_new_converter:
|
||||
self.saved_model_dir = None
|
||||
return
|
||||
|
@ -405,17 +400,16 @@ class TFLiteConverterBase(object):
|
|||
# frozen graph def path.
|
||||
self.saved_model_dir = None
|
||||
return
|
||||
if (not always_enable_saved_model_import and
|
||||
not self._contains_function_with_implements_attr(saved_model_proto)):
|
||||
if not self._contains_function_with_implements_attr(saved_model_proto):
|
||||
self.saved_model_dir = None
|
||||
return
|
||||
|
||||
if not self._saved_model_exported_names:
|
||||
self._saved_model_exported_names = []
|
||||
self._saved_model_version = saved_model_proto.saved_model_schema_version
|
||||
if self._saved_model_version not in [1, 2]:
|
||||
raise ValueError("SavedModel file format({0}) is not supported".format(
|
||||
self._saved_model_version))
|
||||
else:
|
||||
if not self._saved_model_exported_names:
|
||||
self._saved_model_exported_names = []
|
||||
self._saved_model_version = saved_model_proto.saved_model_schema_version
|
||||
if self._saved_model_version not in [1, 2]:
|
||||
raise ValueError(
|
||||
"SavedModel file format({0}) is not supported".format(
|
||||
self._saved_model_version))
|
||||
|
||||
|
||||
class TFLiteConverterBaseV2(TFLiteConverterBase):
|
||||
|
@ -548,7 +542,7 @@ class TFLiteSavedModelConverterV2(TFLiteConverterBaseV2):
|
|||
self._saved_model_tags = saved_model_tags
|
||||
self._saved_model_exported_names = saved_model_exported_names
|
||||
self._trackable_obj = trackable_obj
|
||||
self._parse_saved_model_args(always_enable_saved_model_import=True)
|
||||
self._parse_saved_model_args()
|
||||
|
||||
def convert(self):
|
||||
"""Converts a TensorFlow GraphDef based on instance variables.
|
||||
|
|
Loading…
Reference in New Issue