Accept "custom_objects" as arguments to TFLiteConverter.from_keras_model
This would be needed to, for example, load a keras model containing a `tensorflow_hub.KerasLayer` PiperOrigin-RevId: 241821677
This commit is contained in:
parent
ea5004e52b
commit
09deaeb03c
@ -579,7 +579,8 @@ class TFLiteConverter(object):
|
|||||||
model_file,
|
model_file,
|
||||||
input_arrays=None,
|
input_arrays=None,
|
||||||
input_shapes=None,
|
input_shapes=None,
|
||||||
output_arrays=None):
|
output_arrays=None,
|
||||||
|
custom_objects=None):
|
||||||
"""Creates a TFLiteConverter class from a tf.keras model file.
|
"""Creates a TFLiteConverter class from a tf.keras model file.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@ -592,13 +593,15 @@ class TFLiteConverter(object):
|
|||||||
None}). (default None)
|
None}). (default None)
|
||||||
output_arrays: List of output tensors to freeze graph with. Uses output
|
output_arrays: List of output tensors to freeze graph with. Uses output
|
||||||
arrays from SignatureDef when none are provided. (default None)
|
arrays from SignatureDef when none are provided. (default None)
|
||||||
|
custom_objects: Dict mapping names (strings) to custom classes or
|
||||||
|
functions to be considered during model deserialization. (default None)
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
TFLiteConverter class.
|
TFLiteConverter class.
|
||||||
"""
|
"""
|
||||||
_keras.backend.clear_session()
|
_keras.backend.clear_session()
|
||||||
_keras.backend.set_learning_phase(False)
|
_keras.backend.set_learning_phase(False)
|
||||||
keras_model = _keras.models.load_model(model_file)
|
keras_model = _keras.models.load_model(model_file, custom_objects)
|
||||||
sess = _keras.backend.get_session()
|
sess = _keras.backend.get_session()
|
||||||
|
|
||||||
# Get input and output tensors.
|
# Get input and output tensors.
|
||||||
|
@ -1066,16 +1066,33 @@ class FromSavedModelTest(test_util.TensorFlowTestCase):
|
|||||||
interpreter.allocate_tensors()
|
interpreter.allocate_tensors()
|
||||||
|
|
||||||
|
|
||||||
|
class MyAddLayer(keras.layers.Layer):
|
||||||
|
|
||||||
|
def __init__(self, increment, **kwargs):
|
||||||
|
super(MyAddLayer, self).__init__(**kwargs)
|
||||||
|
self._increment = increment
|
||||||
|
|
||||||
|
def call(self, inputs):
|
||||||
|
return inputs + self._increment
|
||||||
|
|
||||||
|
def get_config(self):
|
||||||
|
config = super(MyAddLayer, self).get_config()
|
||||||
|
config['increment'] = self._increment
|
||||||
|
return config
|
||||||
|
|
||||||
|
|
||||||
@test_util.run_v1_only('b/120545219')
|
@test_util.run_v1_only('b/120545219')
|
||||||
class FromKerasFile(test_util.TensorFlowTestCase):
|
class FromKerasFile(test_util.TensorFlowTestCase):
|
||||||
|
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
keras.backend.clear_session()
|
keras.backend.clear_session()
|
||||||
|
|
||||||
def _getSequentialModel(self):
|
def _getSequentialModel(self, include_custom_layer=False):
|
||||||
with session.Session().as_default():
|
with session.Session().as_default():
|
||||||
model = keras.models.Sequential()
|
model = keras.models.Sequential()
|
||||||
model.add(keras.layers.Dense(2, input_shape=(3,)))
|
model.add(keras.layers.Dense(2, input_shape=(3,)))
|
||||||
|
if include_custom_layer:
|
||||||
|
model.add(MyAddLayer(1.0))
|
||||||
model.add(keras.layers.RepeatVector(3))
|
model.add(keras.layers.RepeatVector(3))
|
||||||
model.add(keras.layers.TimeDistributed(keras.layers.Dense(3)))
|
model.add(keras.layers.TimeDistributed(keras.layers.Dense(3)))
|
||||||
model.compile(
|
model.compile(
|
||||||
@ -1093,6 +1110,10 @@ class FromKerasFile(test_util.TensorFlowTestCase):
|
|||||||
keras.models.save_model(model, keras_file)
|
keras.models.save_model(model, keras_file)
|
||||||
finally:
|
finally:
|
||||||
os.close(fd)
|
os.close(fd)
|
||||||
|
|
||||||
|
if include_custom_layer:
|
||||||
|
custom_objects = {'MyAddLayer': MyAddLayer}
|
||||||
|
return keras_file, custom_objects
|
||||||
return keras_file
|
return keras_file
|
||||||
|
|
||||||
def testSequentialModel(self):
|
def testSequentialModel(self):
|
||||||
@ -1133,6 +1154,37 @@ class FromKerasFile(test_util.TensorFlowTestCase):
|
|||||||
np.testing.assert_almost_equal(tflite_result, keras_result, 5)
|
np.testing.assert_almost_equal(tflite_result, keras_result, 5)
|
||||||
os.remove(keras_file)
|
os.remove(keras_file)
|
||||||
|
|
||||||
|
def testCustomLayer(self):
|
||||||
|
"""Test a Sequential tf.keras model with default inputs."""
|
||||||
|
keras_file, custom_objects = self._getSequentialModel(
|
||||||
|
include_custom_layer=True)
|
||||||
|
|
||||||
|
converter = lite.TFLiteConverter.from_keras_model_file(
|
||||||
|
keras_file, custom_objects=custom_objects)
|
||||||
|
|
||||||
|
tflite_model = converter.convert()
|
||||||
|
self.assertTrue(tflite_model)
|
||||||
|
|
||||||
|
# Check tensor details of converted model.
|
||||||
|
interpreter = Interpreter(model_content=tflite_model)
|
||||||
|
interpreter.allocate_tensors()
|
||||||
|
|
||||||
|
input_details = interpreter.get_input_details()
|
||||||
|
output_details = interpreter.get_output_details()
|
||||||
|
|
||||||
|
# Check inference of converted model.
|
||||||
|
input_data = np.array([[1, 2, 3]], dtype=np.float32)
|
||||||
|
interpreter.set_tensor(input_details[0]['index'], input_data)
|
||||||
|
interpreter.invoke()
|
||||||
|
tflite_result = interpreter.get_tensor(output_details[0]['index'])
|
||||||
|
|
||||||
|
keras_model = keras.models.load_model(
|
||||||
|
keras_file, custom_objects=custom_objects)
|
||||||
|
keras_result = keras_model.predict(input_data)
|
||||||
|
|
||||||
|
np.testing.assert_almost_equal(tflite_result, keras_result, 5)
|
||||||
|
os.remove(keras_file)
|
||||||
|
|
||||||
def testSequentialModelInputArray(self):
|
def testSequentialModelInputArray(self):
|
||||||
"""Test a Sequential tf.keras model testing input arrays argument."""
|
"""Test a Sequential tf.keras model testing input arrays argument."""
|
||||||
keras_file = self._getSequentialModel()
|
keras_file = self._getSequentialModel()
|
||||||
|
@ -16,7 +16,7 @@ tf_class {
|
|||||||
}
|
}
|
||||||
member_method {
|
member_method {
|
||||||
name: "from_keras_model_file"
|
name: "from_keras_model_file"
|
||||||
argspec: "args=[\'cls\', \'model_file\', \'input_arrays\', \'input_shapes\', \'output_arrays\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], "
|
argspec: "args=[\'cls\', \'model_file\', \'input_arrays\', \'input_shapes\', \'output_arrays\', \'custom_objects\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], "
|
||||||
}
|
}
|
||||||
member_method {
|
member_method {
|
||||||
name: "from_saved_model"
|
name: "from_saved_model"
|
||||||
|
Loading…
Reference in New Issue
Block a user