Make sure that keras conversions work when storing to saved model is failed
PiperOrigin-RevId: 317028099 Change-Id: I544d2c5170644791fe644a595d7f4e91b9cd9d3d
This commit is contained in:
parent
4f341bb742
commit
ed6eba7389
|
@ -690,21 +690,20 @@ class TFLiteKerasModelConverterV2(TFLiteConverterBaseV2):
|
|||
self._keras_model = keras_model
|
||||
self._trackable_obj = trackable_obj
|
||||
|
||||
def convert(self):
|
||||
"""Converts a keras model based on instance variables.
|
||||
def _convert_as_saved_model(self):
|
||||
"""Converts a Keras model as a saved model.
|
||||
|
||||
Returns:
|
||||
The converted data in serialized format.
|
||||
|
||||
Raises:
|
||||
ValueError:
|
||||
Multiple concrete functions are specified.
|
||||
Input shape is not specified.
|
||||
Invalid quantization parameters.
|
||||
"""
|
||||
temp_dir = tempfile.mkdtemp()
|
||||
try:
|
||||
try:
|
||||
self._keras_model.save(temp_dir, save_format="tf")
|
||||
except Exception: # pylint: disable=broad-except
|
||||
# When storing the given keras model to a saved model is failed, let's
|
||||
# use original keras model conversion pipeline.
|
||||
return None
|
||||
self.saved_model_dir = temp_dir
|
||||
self._saved_model_tags = set([_tag_constants.SERVING])
|
||||
self._saved_model_exported_names = [
|
||||
|
@ -735,6 +734,22 @@ class TFLiteKerasModelConverterV2(TFLiteConverterBaseV2):
|
|||
finally:
|
||||
shutil.rmtree(temp_dir, True)
|
||||
|
||||
def convert(self):
|
||||
"""Converts a keras model based on instance variables.
|
||||
|
||||
Returns:
|
||||
The converted data in serialized format.
|
||||
|
||||
Raises:
|
||||
ValueError:
|
||||
Multiple concrete functions are specified.
|
||||
Input shape is not specified.
|
||||
Invalid quantization parameters.
|
||||
"""
|
||||
saved_model_convert_result = self._convert_as_saved_model()
|
||||
if saved_model_convert_result:
|
||||
return saved_model_convert_result
|
||||
|
||||
input_signature = None
|
||||
# If the model's call is not a `tf.function`, then we need to first get its
|
||||
# input signature from `model_input_signature` method. We can't directly
|
||||
|
@ -1473,21 +1488,20 @@ class TFLiteKerasModelConverter(TFLiteConverterBaseV1):
|
|||
self._output_tensors = output_tensors
|
||||
self._debug_info_func = _build_debug_info_func(sess.graph)
|
||||
|
||||
def convert(self):
|
||||
"""Converts a Keras model based on instance variables.
|
||||
def _convert_as_saved_model(self):
|
||||
"""Converts a Keras model as a saved model.
|
||||
|
||||
Returns:
|
||||
The converted data in serialized format. Either a TFLite Flatbuffer or a
|
||||
Graphviz graph depending on value in `output_format`.
|
||||
|
||||
Raises:
|
||||
ValueError:
|
||||
Input shape is not specified.
|
||||
None value for dimension in input_tensor.
|
||||
The converted data in serialized format.
|
||||
"""
|
||||
temp_dir = tempfile.mkdtemp()
|
||||
try:
|
||||
try:
|
||||
self._keras_model.save(temp_dir, save_format="tf")
|
||||
except Exception: # pylint: disable=broad-except
|
||||
# When storing the given keras model to a saved model is failed, let's
|
||||
# use original keras model conversion pipeline.
|
||||
return None
|
||||
tag_set = set([_tag_constants.SERVING])
|
||||
signature_key = _signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY
|
||||
result = _freeze_saved_model(temp_dir, None, None, None, tag_set,
|
||||
|
@ -1506,6 +1520,22 @@ class TFLiteKerasModelConverter(TFLiteConverterBaseV1):
|
|||
finally:
|
||||
shutil.rmtree(temp_dir, True)
|
||||
|
||||
def convert(self):
|
||||
"""Converts a Keras model based on instance variables.
|
||||
|
||||
Returns:
|
||||
The converted data in serialized format. Either a TFLite Flatbuffer or a
|
||||
Graphviz graph depending on value in `output_format`.
|
||||
|
||||
Raises:
|
||||
ValueError:
|
||||
Input shape is not specified.
|
||||
None value for dimension in input_tensor.
|
||||
"""
|
||||
saved_model_convert_result = self._convert_as_saved_model()
|
||||
if saved_model_convert_result:
|
||||
return saved_model_convert_result
|
||||
|
||||
return super(TFLiteKerasModelConverter, self).convert()
|
||||
|
||||
|
||||
|
|
|
@ -769,6 +769,37 @@ class FromKerasModelTest(lite_v2_test_util.ModelTest):
|
|||
converter.convert()
|
||||
self._assertValidDebugInfo(converter._debug_info)
|
||||
|
||||
@test_util.run_v2_only
|
||||
def testKerasFallbackPath(self):
|
||||
"""Test keras model which failed when exporting to the saved model."""
|
||||
input_data = tf.constant(
|
||||
np.array(np.random.random_sample((20)), dtype=np.float32))
|
||||
|
||||
class Model(tf.keras.Model):
|
||||
|
||||
def __init__(self):
|
||||
super(Model, self).__init__()
|
||||
# A None name will cause a failure in exporting to a saved model.
|
||||
self.shared_weights = self.add_weight(
|
||||
name=None,
|
||||
shape=(20, 1),
|
||||
dtype=tf.float32,
|
||||
initializer=tf.random_normal_initializer(
|
||||
mean=0.0, stddev=300**(-0.5)))
|
||||
|
||||
def call(self, x):
|
||||
return tf.add(self.shared_weights, x)
|
||||
|
||||
# Building the model.
|
||||
model = Model()
|
||||
model.compile(optimizer='sgd', loss='mean_squared_error')
|
||||
model.fit(input_data, input_data, epochs=1)
|
||||
|
||||
# Convert model.
|
||||
converter = lite.TFLiteConverterV2.from_keras_model(model)
|
||||
tflite_model = converter.convert()
|
||||
self.assertTrue(tflite_model)
|
||||
|
||||
|
||||
class ControlFlowTest(lite_v2_test_util.ModelTest):
|
||||
|
||||
|
|
Loading…
Reference in New Issue