Enable MLIR saved model import by default in TFLiteConverterV2's saved model API
PiperOrigin-RevId: 311792366 Change-Id: I98356499c0a1eb7c740104ca4b11af5d45c4a4a1
This commit is contained in:
parent
31583920dc
commit
27ac446be5
|
@ -386,13 +386,8 @@ class TFLiteConverterBase(object):
|
||||||
return True
|
return True
|
||||||
return False
|
return False
|
||||||
|
|
||||||
def _parse_saved_model_args(self, always_enable_saved_model_import=False):
|
def _parse_saved_model_args(self):
|
||||||
"""Parses SavedModel arguments from the given Keras/RNN SavedModel.
|
"""Parses SavedModel arguments from the given Keras/RNN SavedModel."""
|
||||||
|
|
||||||
Args:
|
|
||||||
always_enable_saved_model_import: Bool. When the value is true, it enables
|
|
||||||
MLIR saved model import path regardless of checking the conditions.
|
|
||||||
"""
|
|
||||||
if not self.experimental_new_converter:
|
if not self.experimental_new_converter:
|
||||||
self.saved_model_dir = None
|
self.saved_model_dir = None
|
||||||
return
|
return
|
||||||
|
@ -405,17 +400,16 @@ class TFLiteConverterBase(object):
|
||||||
# frozen graph def path.
|
# frozen graph def path.
|
||||||
self.saved_model_dir = None
|
self.saved_model_dir = None
|
||||||
return
|
return
|
||||||
if (not always_enable_saved_model_import and
|
if not self._contains_function_with_implements_attr(saved_model_proto):
|
||||||
not self._contains_function_with_implements_attr(saved_model_proto)):
|
|
||||||
self.saved_model_dir = None
|
self.saved_model_dir = None
|
||||||
return
|
else:
|
||||||
|
if not self._saved_model_exported_names:
|
||||||
if not self._saved_model_exported_names:
|
self._saved_model_exported_names = []
|
||||||
self._saved_model_exported_names = []
|
self._saved_model_version = saved_model_proto.saved_model_schema_version
|
||||||
self._saved_model_version = saved_model_proto.saved_model_schema_version
|
if self._saved_model_version not in [1, 2]:
|
||||||
if self._saved_model_version not in [1, 2]:
|
raise ValueError(
|
||||||
raise ValueError("SavedModel file format({0}) is not supported".format(
|
"SavedModel file format({0}) is not supported".format(
|
||||||
self._saved_model_version))
|
self._saved_model_version))
|
||||||
|
|
||||||
|
|
||||||
class TFLiteConverterBaseV2(TFLiteConverterBase):
|
class TFLiteConverterBaseV2(TFLiteConverterBase):
|
||||||
|
@ -548,7 +542,7 @@ class TFLiteSavedModelConverterV2(TFLiteConverterBaseV2):
|
||||||
self._saved_model_tags = saved_model_tags
|
self._saved_model_tags = saved_model_tags
|
||||||
self._saved_model_exported_names = saved_model_exported_names
|
self._saved_model_exported_names = saved_model_exported_names
|
||||||
self._trackable_obj = trackable_obj
|
self._trackable_obj = trackable_obj
|
||||||
self._parse_saved_model_args(always_enable_saved_model_import=True)
|
self._parse_saved_model_args()
|
||||||
|
|
||||||
def convert(self):
|
def convert(self):
|
||||||
"""Converts a TensorFlow GraphDef based on instance variables.
|
"""Converts a TensorFlow GraphDef based on instance variables.
|
||||||
|
|
Loading…
Reference in New Issue