Raise error when calling loaded model with layers that are not fully serialized.

PiperOrigin-RevId: 316742578
Change-Id: Iefc40d21374388ed99f7ff40bb09436830b46cbe
This commit is contained in:
Katherine Wu 2020-06-16 13:08:10 -07:00 committed by TensorFlower Gardener
parent 2a0ad47926
commit 89df3ddcd5
2 changed files with 75 additions and 6 deletions

View File

@ -690,18 +690,22 @@ def _finalize_saved_model_layers(layers):
layer, _get_keras_attr(layer).call_and_return_conditional_losses,
return_method=True)
layer._init_call_fn_args()
else:
layer.call = types.MethodType(
_unable_to_call_layer_due_to_serialization_issue, layer)
for layer in layers:
# 2. Set model inputs and outputs.
if isinstance(layer, RevivedNetwork):
_set_network_attributes_from_metadata(layer)
call_fn = _get_keras_attr(layer).call_and_return_conditional_losses
if call_fn.input_signature is None:
inputs = infer_inputs_from_restored_call_function(call_fn)
else:
inputs = call_fn.input_signature[0]
layer._set_inputs(inputs)
if hasattr(_get_keras_attr(layer), 'call_and_return_conditional_losses'):
call_fn = _get_keras_attr(layer).call_and_return_conditional_losses
if call_fn.input_signature is None:
inputs = infer_inputs_from_restored_call_function(call_fn)
else:
inputs = call_fn.input_signature[0]
layer._set_inputs(inputs) # pylint: disable=protected-access
# 3. Add losses that aren't generated by the layer.call function.
_restore_layer_unconditional_losses(layer)
@ -713,6 +717,41 @@ def _finalize_saved_model_layers(layers):
# pylint: enable=protected-access
def _unable_to_call_layer_due_to_serialization_issue(
layer, *unused_args, **unused_kwargs):
"""Replaces the `layer.call` if the layer was not fully serialized.
Keras Model/Layer serialization is relatively relaxed because SavedModels
are not always loaded back as keras models. Thus, when there is an issue
tracing a non-signature function, a warning is logged instead of raising an
error. This results in a SavedModel where the model's call function is saved,
but the internal layer call functions are not.
When deserialized with `tf.keras.models.load_model`, the internal layers
which do not have serialized call functions should raise an error when called.
Args:
layer: Layer without the serialized call function.
Raises:
ValueError
"""
raise ValueError(
'Cannot call {} ({}), because the call function was not serialized to '
'the SavedModel (due to lack information about the inputs). Please try '
'one of the following methods to fix the serialization:'
'\n\n(1) Implement `get_config` and `from_config` in the layer/model '
'class, and pass the object to the `custom_objects` argument when '
'loading the model. For more details, see: '
'https://www.tensorflow.org/guide/keras/save_and_serialize'
'\n\n(2) Ensure that the subclassed model or layer overwrites `call` '
'and not `__call__`. The input shape and dtype will be automatically '
'recorded when the object is called, and used when saving. To manually '
'specify the input shape/dtype, decorate the call function with '
'`@tf.function(input_signature=...)`.'.format(layer.name, layer))
def _finalize_config_layers(layers):
"""Runs the final steps of loading Keras Layers from config."""
for layer in layers:

View File

@ -809,6 +809,36 @@ class TestModelSavingAndLoadingV2(keras_parameterized.TestCase):
self.evaluate(variables.variables_initializer(loaded.variables))
self.assertAllClose(model.predict(f), loaded.predict(f))
def test_load_with_partially_failed_serialization(self):
class BadCustomLayer(keras.layers.Layer):
def __call__(self, inputs):
return inputs
class Model(keras.models.Model):
def __init__(self):
super(Model, self).__init__()
self.layer = BadCustomLayer()
@def_function.function(
input_signature=[tensor_spec.TensorSpec([None, 1])])
def call(self, inputs):
return self.layer(inputs)
model = Model()
inp = constant_op.constant([[1.0]])
model(inp)
saved_model_dir = self._save_model_dir()
tf_save.save(model, saved_model_dir)
loaded = keras_load.load(saved_model_dir)
self.assertAllEqual([[1.0]], self.evaluate(loaded(inp)))
with self.assertRaisesRegexp(ValueError,
'call function was not serialized'):
loaded.layer(inp)
class TestLayerCallTracing(test.TestCase, parameterized.TestCase):