Make sure that keras conversions work when storing to saved model is failed
PiperOrigin-RevId: 317028099 Change-Id: I544d2c5170644791fe644a595d7f4e91b9cd9d3d
This commit is contained in:
parent
4f341bb742
commit
ed6eba7389
|
@ -690,21 +690,20 @@ class TFLiteKerasModelConverterV2(TFLiteConverterBaseV2):
|
||||||
self._keras_model = keras_model
|
self._keras_model = keras_model
|
||||||
self._trackable_obj = trackable_obj
|
self._trackable_obj = trackable_obj
|
||||||
|
|
||||||
def convert(self):
|
def _convert_as_saved_model(self):
|
||||||
"""Converts a keras model based on instance variables.
|
"""Converts a Keras model as a saved model.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
The converted data in serialized format.
|
The converted data in serialized format.
|
||||||
|
|
||||||
Raises:
|
|
||||||
ValueError:
|
|
||||||
Multiple concrete functions are specified.
|
|
||||||
Input shape is not specified.
|
|
||||||
Invalid quantization parameters.
|
|
||||||
"""
|
"""
|
||||||
temp_dir = tempfile.mkdtemp()
|
temp_dir = tempfile.mkdtemp()
|
||||||
|
try:
|
||||||
try:
|
try:
|
||||||
self._keras_model.save(temp_dir, save_format="tf")
|
self._keras_model.save(temp_dir, save_format="tf")
|
||||||
|
except Exception: # pylint: disable=broad-except
|
||||||
|
# When storing the given keras model to a saved model is failed, let's
|
||||||
|
# use original keras model conversion pipeline.
|
||||||
|
return None
|
||||||
self.saved_model_dir = temp_dir
|
self.saved_model_dir = temp_dir
|
||||||
self._saved_model_tags = set([_tag_constants.SERVING])
|
self._saved_model_tags = set([_tag_constants.SERVING])
|
||||||
self._saved_model_exported_names = [
|
self._saved_model_exported_names = [
|
||||||
|
@ -735,6 +734,22 @@ class TFLiteKerasModelConverterV2(TFLiteConverterBaseV2):
|
||||||
finally:
|
finally:
|
||||||
shutil.rmtree(temp_dir, True)
|
shutil.rmtree(temp_dir, True)
|
||||||
|
|
||||||
|
def convert(self):
|
||||||
|
"""Converts a keras model based on instance variables.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The converted data in serialized format.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError:
|
||||||
|
Multiple concrete functions are specified.
|
||||||
|
Input shape is not specified.
|
||||||
|
Invalid quantization parameters.
|
||||||
|
"""
|
||||||
|
saved_model_convert_result = self._convert_as_saved_model()
|
||||||
|
if saved_model_convert_result:
|
||||||
|
return saved_model_convert_result
|
||||||
|
|
||||||
input_signature = None
|
input_signature = None
|
||||||
# If the model's call is not a `tf.function`, then we need to first get its
|
# If the model's call is not a `tf.function`, then we need to first get its
|
||||||
# input signature from `model_input_signature` method. We can't directly
|
# input signature from `model_input_signature` method. We can't directly
|
||||||
|
@ -1473,21 +1488,20 @@ class TFLiteKerasModelConverter(TFLiteConverterBaseV1):
|
||||||
self._output_tensors = output_tensors
|
self._output_tensors = output_tensors
|
||||||
self._debug_info_func = _build_debug_info_func(sess.graph)
|
self._debug_info_func = _build_debug_info_func(sess.graph)
|
||||||
|
|
||||||
def convert(self):
|
def _convert_as_saved_model(self):
|
||||||
"""Converts a Keras model based on instance variables.
|
"""Converts a Keras model as a saved model.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
The converted data in serialized format. Either a TFLite Flatbuffer or a
|
The converted data in serialized format.
|
||||||
Graphviz graph depending on value in `output_format`.
|
|
||||||
|
|
||||||
Raises:
|
|
||||||
ValueError:
|
|
||||||
Input shape is not specified.
|
|
||||||
None value for dimension in input_tensor.
|
|
||||||
"""
|
"""
|
||||||
temp_dir = tempfile.mkdtemp()
|
temp_dir = tempfile.mkdtemp()
|
||||||
|
try:
|
||||||
try:
|
try:
|
||||||
self._keras_model.save(temp_dir, save_format="tf")
|
self._keras_model.save(temp_dir, save_format="tf")
|
||||||
|
except Exception: # pylint: disable=broad-except
|
||||||
|
# When storing the given keras model to a saved model is failed, let's
|
||||||
|
# use original keras model conversion pipeline.
|
||||||
|
return None
|
||||||
tag_set = set([_tag_constants.SERVING])
|
tag_set = set([_tag_constants.SERVING])
|
||||||
signature_key = _signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY
|
signature_key = _signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY
|
||||||
result = _freeze_saved_model(temp_dir, None, None, None, tag_set,
|
result = _freeze_saved_model(temp_dir, None, None, None, tag_set,
|
||||||
|
@ -1506,6 +1520,22 @@ class TFLiteKerasModelConverter(TFLiteConverterBaseV1):
|
||||||
finally:
|
finally:
|
||||||
shutil.rmtree(temp_dir, True)
|
shutil.rmtree(temp_dir, True)
|
||||||
|
|
||||||
|
def convert(self):
|
||||||
|
"""Converts a Keras model based on instance variables.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The converted data in serialized format. Either a TFLite Flatbuffer or a
|
||||||
|
Graphviz graph depending on value in `output_format`.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError:
|
||||||
|
Input shape is not specified.
|
||||||
|
None value for dimension in input_tensor.
|
||||||
|
"""
|
||||||
|
saved_model_convert_result = self._convert_as_saved_model()
|
||||||
|
if saved_model_convert_result:
|
||||||
|
return saved_model_convert_result
|
||||||
|
|
||||||
return super(TFLiteKerasModelConverter, self).convert()
|
return super(TFLiteKerasModelConverter, self).convert()
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -769,6 +769,37 @@ class FromKerasModelTest(lite_v2_test_util.ModelTest):
|
||||||
converter.convert()
|
converter.convert()
|
||||||
self._assertValidDebugInfo(converter._debug_info)
|
self._assertValidDebugInfo(converter._debug_info)
|
||||||
|
|
||||||
|
@test_util.run_v2_only
|
||||||
|
def testKerasFallbackPath(self):
|
||||||
|
"""Test keras model which failed when exporting to the saved model."""
|
||||||
|
input_data = tf.constant(
|
||||||
|
np.array(np.random.random_sample((20)), dtype=np.float32))
|
||||||
|
|
||||||
|
class Model(tf.keras.Model):
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
super(Model, self).__init__()
|
||||||
|
# A None name will cause a failure in exporting to a saved model.
|
||||||
|
self.shared_weights = self.add_weight(
|
||||||
|
name=None,
|
||||||
|
shape=(20, 1),
|
||||||
|
dtype=tf.float32,
|
||||||
|
initializer=tf.random_normal_initializer(
|
||||||
|
mean=0.0, stddev=300**(-0.5)))
|
||||||
|
|
||||||
|
def call(self, x):
|
||||||
|
return tf.add(self.shared_weights, x)
|
||||||
|
|
||||||
|
# Building the model.
|
||||||
|
model = Model()
|
||||||
|
model.compile(optimizer='sgd', loss='mean_squared_error')
|
||||||
|
model.fit(input_data, input_data, epochs=1)
|
||||||
|
|
||||||
|
# Convert model.
|
||||||
|
converter = lite.TFLiteConverterV2.from_keras_model(model)
|
||||||
|
tflite_model = converter.convert()
|
||||||
|
self.assertTrue(tflite_model)
|
||||||
|
|
||||||
|
|
||||||
class ControlFlowTest(lite_v2_test_util.ModelTest):
|
class ControlFlowTest(lite_v2_test_util.ModelTest):
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue