Accept "custom_objects" as arguments to TFLiteConverter.from_keras_model

This would be needed to, for example, load a keras model containing a `tensorflow_hub.KerasLayer`

PiperOrigin-RevId: 241821677
This commit is contained in:
Mark Daoust 2019-04-03 15:47:14 -07:00 committed by TensorFlower Gardener
parent ea5004e52b
commit 09deaeb03c
3 changed files with 59 additions and 4 deletions

View File

@ -579,7 +579,8 @@ class TFLiteConverter(object):
model_file,
input_arrays=None,
input_shapes=None,
output_arrays=None):
output_arrays=None,
custom_objects=None):
"""Creates a TFLiteConverter class from a tf.keras model file.
Args:
@ -592,13 +593,15 @@ class TFLiteConverter(object):
None}). (default None)
output_arrays: List of output tensors to freeze graph with. Uses output
arrays from SignatureDef when none are provided. (default None)
custom_objects: Dict mapping names (strings) to custom classes or
functions to be considered during model deserialization. (default None)
Returns:
TFLiteConverter class.
"""
_keras.backend.clear_session()
_keras.backend.set_learning_phase(False)
keras_model = _keras.models.load_model(model_file)
keras_model = _keras.models.load_model(model_file, custom_objects)
sess = _keras.backend.get_session()
# Get input and output tensors.

View File

@ -1066,16 +1066,33 @@ class FromSavedModelTest(test_util.TensorFlowTestCase):
interpreter.allocate_tensors()
class MyAddLayer(keras.layers.Layer):
def __init__(self, increment, **kwargs):
super(MyAddLayer, self).__init__(**kwargs)
self._increment = increment
def call(self, inputs):
return inputs + self._increment
def get_config(self):
config = super(MyAddLayer, self).get_config()
config['increment'] = self._increment
return config
@test_util.run_v1_only('b/120545219')
class FromKerasFile(test_util.TensorFlowTestCase):
def setUp(self):
keras.backend.clear_session()
def _getSequentialModel(self):
def _getSequentialModel(self, include_custom_layer=False):
with session.Session().as_default():
model = keras.models.Sequential()
model.add(keras.layers.Dense(2, input_shape=(3,)))
if include_custom_layer:
model.add(MyAddLayer(1.0))
model.add(keras.layers.RepeatVector(3))
model.add(keras.layers.TimeDistributed(keras.layers.Dense(3)))
model.compile(
@ -1093,6 +1110,10 @@ class FromKerasFile(test_util.TensorFlowTestCase):
keras.models.save_model(model, keras_file)
finally:
os.close(fd)
if include_custom_layer:
custom_objects = {'MyAddLayer': MyAddLayer}
return keras_file, custom_objects
return keras_file
def testSequentialModel(self):
@ -1133,6 +1154,37 @@ class FromKerasFile(test_util.TensorFlowTestCase):
np.testing.assert_almost_equal(tflite_result, keras_result, 5)
os.remove(keras_file)
def testCustomLayer(self):
"""Test a Sequential tf.keras model with default inputs."""
keras_file, custom_objects = self._getSequentialModel(
include_custom_layer=True)
converter = lite.TFLiteConverter.from_keras_model_file(
keras_file, custom_objects=custom_objects)
tflite_model = converter.convert()
self.assertTrue(tflite_model)
# Check tensor details of converted model.
interpreter = Interpreter(model_content=tflite_model)
interpreter.allocate_tensors()
input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()
# Check inference of converted model.
input_data = np.array([[1, 2, 3]], dtype=np.float32)
interpreter.set_tensor(input_details[0]['index'], input_data)
interpreter.invoke()
tflite_result = interpreter.get_tensor(output_details[0]['index'])
keras_model = keras.models.load_model(
keras_file, custom_objects=custom_objects)
keras_result = keras_model.predict(input_data)
np.testing.assert_almost_equal(tflite_result, keras_result, 5)
os.remove(keras_file)
def testSequentialModelInputArray(self):
"""Test a Sequential tf.keras model testing input arrays argument."""
keras_file = self._getSequentialModel()

View File

@ -16,7 +16,7 @@ tf_class {
}
member_method {
name: "from_keras_model_file"
argspec: "args=[\'cls\', \'model_file\', \'input_arrays\', \'input_shapes\', \'output_arrays\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], "
argspec: "args=[\'cls\', \'model_file\', \'input_arrays\', \'input_shapes\', \'output_arrays\', \'custom_objects\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], "
}
member_method {
name: "from_saved_model"