From ed6eba7389f390fddc456c6d28396878c326b369 Mon Sep 17 00:00:00 2001 From: Jaesung Chung Date: Wed, 17 Jun 2020 21:21:04 -0700 Subject: [PATCH] Make sure that keras conversions work when storing to saved model is failed PiperOrigin-RevId: 317028099 Change-Id: I544d2c5170644791fe644a595d7f4e91b9cd9d3d --- tensorflow/lite/python/lite.py | 68 +++++++++++++++++++------- tensorflow/lite/python/lite_v2_test.py | 31 ++++++++++++ 2 files changed, 80 insertions(+), 19 deletions(-) diff --git a/tensorflow/lite/python/lite.py b/tensorflow/lite/python/lite.py index 93cca1a6af5..b0bd53cb9b2 100644 --- a/tensorflow/lite/python/lite.py +++ b/tensorflow/lite/python/lite.py @@ -690,21 +690,20 @@ class TFLiteKerasModelConverterV2(TFLiteConverterBaseV2): self._keras_model = keras_model self._trackable_obj = trackable_obj - def convert(self): - """Converts a keras model based on instance variables. + def _convert_as_saved_model(self): + """Converts a Keras model as a saved model. Returns: The converted data in serialized format. - - Raises: - ValueError: - Multiple concrete functions are specified. - Input shape is not specified. - Invalid quantization parameters. """ temp_dir = tempfile.mkdtemp() try: - self._keras_model.save(temp_dir, save_format="tf") + try: + self._keras_model.save(temp_dir, save_format="tf") + except Exception: # pylint: disable=broad-except + # When storing the given keras model to a saved model is failed, let's + # use original keras model conversion pipeline. + return None self.saved_model_dir = temp_dir self._saved_model_tags = set([_tag_constants.SERVING]) self._saved_model_exported_names = [ @@ -735,6 +734,22 @@ class TFLiteKerasModelConverterV2(TFLiteConverterBaseV2): finally: shutil.rmtree(temp_dir, True) + def convert(self): + """Converts a keras model based on instance variables. + + Returns: + The converted data in serialized format. + + Raises: + ValueError: + Multiple concrete functions are specified. + Input shape is not specified. + Invalid quantization parameters. + """ + saved_model_convert_result = self._convert_as_saved_model() + if saved_model_convert_result: + return saved_model_convert_result + input_signature = None # If the model's call is not a `tf.function`, then we need to first get its # input signature from `model_input_signature` method. We can't directly @@ -1473,21 +1488,20 @@ class TFLiteKerasModelConverter(TFLiteConverterBaseV1): self._output_tensors = output_tensors self._debug_info_func = _build_debug_info_func(sess.graph) - def convert(self): - """Converts a Keras model based on instance variables. + def _convert_as_saved_model(self): + """Converts a Keras model as a saved model. Returns: - The converted data in serialized format. Either a TFLite Flatbuffer or a - Graphviz graph depending on value in `output_format`. - - Raises: - ValueError: - Input shape is not specified. - None value for dimension in input_tensor. + The converted data in serialized format. """ temp_dir = tempfile.mkdtemp() try: - self._keras_model.save(temp_dir, save_format="tf") + try: + self._keras_model.save(temp_dir, save_format="tf") + except Exception: # pylint: disable=broad-except + # When storing the given keras model to a saved model is failed, let's + # use original keras model conversion pipeline. + return None tag_set = set([_tag_constants.SERVING]) signature_key = _signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY result = _freeze_saved_model(temp_dir, None, None, None, tag_set, @@ -1506,6 +1520,22 @@ class TFLiteKerasModelConverter(TFLiteConverterBaseV1): finally: shutil.rmtree(temp_dir, True) + def convert(self): + """Converts a Keras model based on instance variables. + + Returns: + The converted data in serialized format. Either a TFLite Flatbuffer or a + Graphviz graph depending on value in `output_format`. + + Raises: + ValueError: + Input shape is not specified. + None value for dimension in input_tensor. + """ + saved_model_convert_result = self._convert_as_saved_model() + if saved_model_convert_result: + return saved_model_convert_result + return super(TFLiteKerasModelConverter, self).convert() diff --git a/tensorflow/lite/python/lite_v2_test.py b/tensorflow/lite/python/lite_v2_test.py index f56f85d0ba4..ea8db15abc2 100644 --- a/tensorflow/lite/python/lite_v2_test.py +++ b/tensorflow/lite/python/lite_v2_test.py @@ -769,6 +769,37 @@ class FromKerasModelTest(lite_v2_test_util.ModelTest): converter.convert() self._assertValidDebugInfo(converter._debug_info) + @test_util.run_v2_only + def testKerasFallbackPath(self): + """Test keras model which failed when exporting to the saved model.""" + input_data = tf.constant( + np.array(np.random.random_sample((20)), dtype=np.float32)) + + class Model(tf.keras.Model): + + def __init__(self): + super(Model, self).__init__() + # A None name will cause a failure in exporting to a saved model. + self.shared_weights = self.add_weight( + name=None, + shape=(20, 1), + dtype=tf.float32, + initializer=tf.random_normal_initializer( + mean=0.0, stddev=300**(-0.5))) + + def call(self, x): + return tf.add(self.shared_weights, x) + + # Building the model. + model = Model() + model.compile(optimizer='sgd', loss='mean_squared_error') + model.fit(input_data, input_data, epochs=1) + + # Convert model. + converter = lite.TFLiteConverterV2.from_keras_model(model) + tflite_model = converter.convert() + self.assertTrue(tflite_model) + class ControlFlowTest(lite_v2_test_util.ModelTest):