Accept "custom_objects" as arguments to TFLiteConverter.from_keras_model
This would be needed to, for example, load a keras model containing a `tensorflow_hub.KerasLayer` PiperOrigin-RevId: 241821677
This commit is contained in:
parent
ea5004e52b
commit
09deaeb03c
@ -579,7 +579,8 @@ class TFLiteConverter(object):
|
||||
model_file,
|
||||
input_arrays=None,
|
||||
input_shapes=None,
|
||||
output_arrays=None):
|
||||
output_arrays=None,
|
||||
custom_objects=None):
|
||||
"""Creates a TFLiteConverter class from a tf.keras model file.
|
||||
|
||||
Args:
|
||||
@ -592,13 +593,15 @@ class TFLiteConverter(object):
|
||||
None}). (default None)
|
||||
output_arrays: List of output tensors to freeze graph with. Uses output
|
||||
arrays from SignatureDef when none are provided. (default None)
|
||||
custom_objects: Dict mapping names (strings) to custom classes or
|
||||
functions to be considered during model deserialization. (default None)
|
||||
|
||||
Returns:
|
||||
TFLiteConverter class.
|
||||
"""
|
||||
_keras.backend.clear_session()
|
||||
_keras.backend.set_learning_phase(False)
|
||||
keras_model = _keras.models.load_model(model_file)
|
||||
keras_model = _keras.models.load_model(model_file, custom_objects)
|
||||
sess = _keras.backend.get_session()
|
||||
|
||||
# Get input and output tensors.
|
||||
|
@ -1066,16 +1066,33 @@ class FromSavedModelTest(test_util.TensorFlowTestCase):
|
||||
interpreter.allocate_tensors()
|
||||
|
||||
|
||||
class MyAddLayer(keras.layers.Layer):
|
||||
|
||||
def __init__(self, increment, **kwargs):
|
||||
super(MyAddLayer, self).__init__(**kwargs)
|
||||
self._increment = increment
|
||||
|
||||
def call(self, inputs):
|
||||
return inputs + self._increment
|
||||
|
||||
def get_config(self):
|
||||
config = super(MyAddLayer, self).get_config()
|
||||
config['increment'] = self._increment
|
||||
return config
|
||||
|
||||
|
||||
@test_util.run_v1_only('b/120545219')
|
||||
class FromKerasFile(test_util.TensorFlowTestCase):
|
||||
|
||||
def setUp(self):
|
||||
keras.backend.clear_session()
|
||||
|
||||
def _getSequentialModel(self):
|
||||
def _getSequentialModel(self, include_custom_layer=False):
|
||||
with session.Session().as_default():
|
||||
model = keras.models.Sequential()
|
||||
model.add(keras.layers.Dense(2, input_shape=(3,)))
|
||||
if include_custom_layer:
|
||||
model.add(MyAddLayer(1.0))
|
||||
model.add(keras.layers.RepeatVector(3))
|
||||
model.add(keras.layers.TimeDistributed(keras.layers.Dense(3)))
|
||||
model.compile(
|
||||
@ -1093,6 +1110,10 @@ class FromKerasFile(test_util.TensorFlowTestCase):
|
||||
keras.models.save_model(model, keras_file)
|
||||
finally:
|
||||
os.close(fd)
|
||||
|
||||
if include_custom_layer:
|
||||
custom_objects = {'MyAddLayer': MyAddLayer}
|
||||
return keras_file, custom_objects
|
||||
return keras_file
|
||||
|
||||
def testSequentialModel(self):
|
||||
@ -1133,6 +1154,37 @@ class FromKerasFile(test_util.TensorFlowTestCase):
|
||||
np.testing.assert_almost_equal(tflite_result, keras_result, 5)
|
||||
os.remove(keras_file)
|
||||
|
||||
def testCustomLayer(self):
|
||||
"""Test a Sequential tf.keras model with default inputs."""
|
||||
keras_file, custom_objects = self._getSequentialModel(
|
||||
include_custom_layer=True)
|
||||
|
||||
converter = lite.TFLiteConverter.from_keras_model_file(
|
||||
keras_file, custom_objects=custom_objects)
|
||||
|
||||
tflite_model = converter.convert()
|
||||
self.assertTrue(tflite_model)
|
||||
|
||||
# Check tensor details of converted model.
|
||||
interpreter = Interpreter(model_content=tflite_model)
|
||||
interpreter.allocate_tensors()
|
||||
|
||||
input_details = interpreter.get_input_details()
|
||||
output_details = interpreter.get_output_details()
|
||||
|
||||
# Check inference of converted model.
|
||||
input_data = np.array([[1, 2, 3]], dtype=np.float32)
|
||||
interpreter.set_tensor(input_details[0]['index'], input_data)
|
||||
interpreter.invoke()
|
||||
tflite_result = interpreter.get_tensor(output_details[0]['index'])
|
||||
|
||||
keras_model = keras.models.load_model(
|
||||
keras_file, custom_objects=custom_objects)
|
||||
keras_result = keras_model.predict(input_data)
|
||||
|
||||
np.testing.assert_almost_equal(tflite_result, keras_result, 5)
|
||||
os.remove(keras_file)
|
||||
|
||||
def testSequentialModelInputArray(self):
|
||||
"""Test a Sequential tf.keras model testing input arrays argument."""
|
||||
keras_file = self._getSequentialModel()
|
||||
|
@ -16,7 +16,7 @@ tf_class {
|
||||
}
|
||||
member_method {
|
||||
name: "from_keras_model_file"
|
||||
argspec: "args=[\'cls\', \'model_file\', \'input_arrays\', \'input_shapes\', \'output_arrays\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], "
|
||||
argspec: "args=[\'cls\', \'model_file\', \'input_arrays\', \'input_shapes\', \'output_arrays\', \'custom_objects\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "from_saved_model"
|
||||
|
Loading…
Reference in New Issue
Block a user