From 09deaeb03ca4ceb40cf600a337083e2054e65390 Mon Sep 17 00:00:00 2001 From: Mark Daoust Date: Wed, 3 Apr 2019 15:47:14 -0700 Subject: [PATCH] Accept "custom_objects" as arguments to `TFLiteConverter.from_keras_model` This would be needed to, for example, load a keras model containing a `tensorflow_hub.KerasLayer` PiperOrigin-RevId: 241821677 --- tensorflow/lite/python/lite.py | 7 ++- tensorflow/lite/python/lite_test.py | 54 ++++++++++++++++++- .../tensorflow.lite.-t-f-lite-converter.pbtxt | 2 +- 3 files changed, 59 insertions(+), 4 deletions(-) diff --git a/tensorflow/lite/python/lite.py b/tensorflow/lite/python/lite.py index 751f07bb209..b2f77ea9a4e 100644 --- a/tensorflow/lite/python/lite.py +++ b/tensorflow/lite/python/lite.py @@ -579,7 +579,8 @@ class TFLiteConverter(object): model_file, input_arrays=None, input_shapes=None, - output_arrays=None): + output_arrays=None, + custom_objects=None): """Creates a TFLiteConverter class from a tf.keras model file. Args: @@ -592,13 +593,15 @@ class TFLiteConverter(object): None}). (default None) output_arrays: List of output tensors to freeze graph with. Uses output arrays from SignatureDef when none are provided. (default None) + custom_objects: Dict mapping names (strings) to custom classes or + functions to be considered during model deserialization. (default None) Returns: TFLiteConverter class. """ _keras.backend.clear_session() _keras.backend.set_learning_phase(False) - keras_model = _keras.models.load_model(model_file) + keras_model = _keras.models.load_model(model_file, custom_objects) sess = _keras.backend.get_session() # Get input and output tensors. diff --git a/tensorflow/lite/python/lite_test.py b/tensorflow/lite/python/lite_test.py index 0ff6f5865ee..34a628e733d 100644 --- a/tensorflow/lite/python/lite_test.py +++ b/tensorflow/lite/python/lite_test.py @@ -1066,16 +1066,33 @@ class FromSavedModelTest(test_util.TensorFlowTestCase): interpreter.allocate_tensors() +class MyAddLayer(keras.layers.Layer): + + def __init__(self, increment, **kwargs): + super(MyAddLayer, self).__init__(**kwargs) + self._increment = increment + + def call(self, inputs): + return inputs + self._increment + + def get_config(self): + config = super(MyAddLayer, self).get_config() + config['increment'] = self._increment + return config + + @test_util.run_v1_only('b/120545219') class FromKerasFile(test_util.TensorFlowTestCase): def setUp(self): keras.backend.clear_session() - def _getSequentialModel(self): + def _getSequentialModel(self, include_custom_layer=False): with session.Session().as_default(): model = keras.models.Sequential() model.add(keras.layers.Dense(2, input_shape=(3,))) + if include_custom_layer: + model.add(MyAddLayer(1.0)) model.add(keras.layers.RepeatVector(3)) model.add(keras.layers.TimeDistributed(keras.layers.Dense(3))) model.compile( @@ -1093,6 +1110,10 @@ class FromKerasFile(test_util.TensorFlowTestCase): keras.models.save_model(model, keras_file) finally: os.close(fd) + + if include_custom_layer: + custom_objects = {'MyAddLayer': MyAddLayer} + return keras_file, custom_objects return keras_file def testSequentialModel(self): @@ -1133,6 +1154,37 @@ class FromKerasFile(test_util.TensorFlowTestCase): np.testing.assert_almost_equal(tflite_result, keras_result, 5) os.remove(keras_file) + def testCustomLayer(self): + """Test a Sequential tf.keras model with default inputs.""" + keras_file, custom_objects = self._getSequentialModel( + include_custom_layer=True) + + converter = lite.TFLiteConverter.from_keras_model_file( + keras_file, custom_objects=custom_objects) + + tflite_model = converter.convert() + self.assertTrue(tflite_model) + + # Check tensor details of converted model. + interpreter = Interpreter(model_content=tflite_model) + interpreter.allocate_tensors() + + input_details = interpreter.get_input_details() + output_details = interpreter.get_output_details() + + # Check inference of converted model. + input_data = np.array([[1, 2, 3]], dtype=np.float32) + interpreter.set_tensor(input_details[0]['index'], input_data) + interpreter.invoke() + tflite_result = interpreter.get_tensor(output_details[0]['index']) + + keras_model = keras.models.load_model( + keras_file, custom_objects=custom_objects) + keras_result = keras_model.predict(input_data) + + np.testing.assert_almost_equal(tflite_result, keras_result, 5) + os.remove(keras_file) + def testSequentialModelInputArray(self): """Test a Sequential tf.keras model testing input arrays argument.""" keras_file = self._getSequentialModel() diff --git a/tensorflow/tools/api/golden/v1/tensorflow.lite.-t-f-lite-converter.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.lite.-t-f-lite-converter.pbtxt index c955b1a04a4..791031c1611 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.lite.-t-f-lite-converter.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.lite.-t-f-lite-converter.pbtxt @@ -16,7 +16,7 @@ tf_class { } member_method { name: "from_keras_model_file" - argspec: "args=[\'cls\', \'model_file\', \'input_arrays\', \'input_shapes\', \'output_arrays\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], " + argspec: "args=[\'cls\', \'model_file\', \'input_arrays\', \'input_shapes\', \'output_arrays\', \'custom_objects\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], " } member_method { name: "from_saved_model"