Make sure that keras conversions work when storing to saved model is failed

PiperOrigin-RevId: 317028099
Change-Id: I544d2c5170644791fe644a595d7f4e91b9cd9d3d
This commit is contained in:
Jaesung Chung 2020-06-17 21:21:04 -07:00 committed by TensorFlower Gardener
parent 4f341bb742
commit ed6eba7389
2 changed files with 80 additions and 19 deletions

View File

@ -690,21 +690,20 @@ class TFLiteKerasModelConverterV2(TFLiteConverterBaseV2):
self._keras_model = keras_model
self._trackable_obj = trackable_obj
def convert(self):
"""Converts a keras model based on instance variables.
def _convert_as_saved_model(self):
"""Converts a Keras model as a saved model.
Returns:
The converted data in serialized format.
Raises:
ValueError:
Multiple concrete functions are specified.
Input shape is not specified.
Invalid quantization parameters.
"""
temp_dir = tempfile.mkdtemp()
try:
try:
self._keras_model.save(temp_dir, save_format="tf")
except Exception: # pylint: disable=broad-except
# When storing the given keras model to a saved model is failed, let's
# use original keras model conversion pipeline.
return None
self.saved_model_dir = temp_dir
self._saved_model_tags = set([_tag_constants.SERVING])
self._saved_model_exported_names = [
@ -735,6 +734,22 @@ class TFLiteKerasModelConverterV2(TFLiteConverterBaseV2):
finally:
shutil.rmtree(temp_dir, True)
def convert(self):
"""Converts a keras model based on instance variables.
Returns:
The converted data in serialized format.
Raises:
ValueError:
Multiple concrete functions are specified.
Input shape is not specified.
Invalid quantization parameters.
"""
saved_model_convert_result = self._convert_as_saved_model()
if saved_model_convert_result:
return saved_model_convert_result
input_signature = None
# If the model's call is not a `tf.function`, then we need to first get its
# input signature from `model_input_signature` method. We can't directly
@ -1473,21 +1488,20 @@ class TFLiteKerasModelConverter(TFLiteConverterBaseV1):
self._output_tensors = output_tensors
self._debug_info_func = _build_debug_info_func(sess.graph)
def convert(self):
"""Converts a Keras model based on instance variables.
def _convert_as_saved_model(self):
"""Converts a Keras model as a saved model.
Returns:
The converted data in serialized format. Either a TFLite Flatbuffer or a
Graphviz graph depending on value in `output_format`.
Raises:
ValueError:
Input shape is not specified.
None value for dimension in input_tensor.
The converted data in serialized format.
"""
temp_dir = tempfile.mkdtemp()
try:
try:
self._keras_model.save(temp_dir, save_format="tf")
except Exception: # pylint: disable=broad-except
# When storing the given keras model to a saved model is failed, let's
# use original keras model conversion pipeline.
return None
tag_set = set([_tag_constants.SERVING])
signature_key = _signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY
result = _freeze_saved_model(temp_dir, None, None, None, tag_set,
@ -1506,6 +1520,22 @@ class TFLiteKerasModelConverter(TFLiteConverterBaseV1):
finally:
shutil.rmtree(temp_dir, True)
def convert(self):
"""Converts a Keras model based on instance variables.
Returns:
The converted data in serialized format. Either a TFLite Flatbuffer or a
Graphviz graph depending on value in `output_format`.
Raises:
ValueError:
Input shape is not specified.
None value for dimension in input_tensor.
"""
saved_model_convert_result = self._convert_as_saved_model()
if saved_model_convert_result:
return saved_model_convert_result
return super(TFLiteKerasModelConverter, self).convert()

View File

@ -769,6 +769,37 @@ class FromKerasModelTest(lite_v2_test_util.ModelTest):
converter.convert()
self._assertValidDebugInfo(converter._debug_info)
@test_util.run_v2_only
def testKerasFallbackPath(self):
"""Test keras model which failed when exporting to the saved model."""
input_data = tf.constant(
np.array(np.random.random_sample((20)), dtype=np.float32))
class Model(tf.keras.Model):
def __init__(self):
super(Model, self).__init__()
# A None name will cause a failure in exporting to a saved model.
self.shared_weights = self.add_weight(
name=None,
shape=(20, 1),
dtype=tf.float32,
initializer=tf.random_normal_initializer(
mean=0.0, stddev=300**(-0.5)))
def call(self, x):
return tf.add(self.shared_weights, x)
# Building the model.
model = Model()
model.compile(optimizer='sgd', loss='mean_squared_error')
model.fit(input_data, input_data, epochs=1)
# Convert model.
converter = lite.TFLiteConverterV2.from_keras_model(model)
tflite_model = converter.convert()
self.assertTrue(tflite_model)
class ControlFlowTest(lite_v2_test_util.ModelTest):