Raise error when calling loaded model with layers that are not fully serialized.
PiperOrigin-RevId: 316742578 Change-Id: Iefc40d21374388ed99f7ff40bb09436830b46cbe
This commit is contained in:
parent
2a0ad47926
commit
89df3ddcd5
@ -690,18 +690,22 @@ def _finalize_saved_model_layers(layers):
|
||||
layer, _get_keras_attr(layer).call_and_return_conditional_losses,
|
||||
return_method=True)
|
||||
layer._init_call_fn_args()
|
||||
else:
|
||||
layer.call = types.MethodType(
|
||||
_unable_to_call_layer_due_to_serialization_issue, layer)
|
||||
|
||||
for layer in layers:
|
||||
# 2. Set model inputs and outputs.
|
||||
if isinstance(layer, RevivedNetwork):
|
||||
_set_network_attributes_from_metadata(layer)
|
||||
|
||||
call_fn = _get_keras_attr(layer).call_and_return_conditional_losses
|
||||
if call_fn.input_signature is None:
|
||||
inputs = infer_inputs_from_restored_call_function(call_fn)
|
||||
else:
|
||||
inputs = call_fn.input_signature[0]
|
||||
layer._set_inputs(inputs)
|
||||
if hasattr(_get_keras_attr(layer), 'call_and_return_conditional_losses'):
|
||||
call_fn = _get_keras_attr(layer).call_and_return_conditional_losses
|
||||
if call_fn.input_signature is None:
|
||||
inputs = infer_inputs_from_restored_call_function(call_fn)
|
||||
else:
|
||||
inputs = call_fn.input_signature[0]
|
||||
layer._set_inputs(inputs) # pylint: disable=protected-access
|
||||
|
||||
# 3. Add losses that aren't generated by the layer.call function.
|
||||
_restore_layer_unconditional_losses(layer)
|
||||
@ -713,6 +717,41 @@ def _finalize_saved_model_layers(layers):
|
||||
# pylint: enable=protected-access
|
||||
|
||||
|
||||
def _unable_to_call_layer_due_to_serialization_issue(
|
||||
layer, *unused_args, **unused_kwargs):
|
||||
"""Replaces the `layer.call` if the layer was not fully serialized.
|
||||
|
||||
Keras Model/Layer serialization is relatively relaxed because SavedModels
|
||||
are not always loaded back as keras models. Thus, when there is an issue
|
||||
tracing a non-signature function, a warning is logged instead of raising an
|
||||
error. This results in a SavedModel where the model's call function is saved,
|
||||
but the internal layer call functions are not.
|
||||
|
||||
When deserialized with `tf.keras.models.load_model`, the internal layers
|
||||
which do not have serialized call functions should raise an error when called.
|
||||
|
||||
Args:
|
||||
layer: Layer without the serialized call function.
|
||||
|
||||
Raises:
|
||||
ValueError
|
||||
"""
|
||||
|
||||
raise ValueError(
|
||||
'Cannot call {} ({}), because the call function was not serialized to '
|
||||
'the SavedModel (due to lack information about the inputs). Please try '
|
||||
'one of the following methods to fix the serialization:'
|
||||
'\n\n(1) Implement `get_config` and `from_config` in the layer/model '
|
||||
'class, and pass the object to the `custom_objects` argument when '
|
||||
'loading the model. For more details, see: '
|
||||
'https://www.tensorflow.org/guide/keras/save_and_serialize'
|
||||
'\n\n(2) Ensure that the subclassed model or layer overwrites `call` '
|
||||
'and not `__call__`. The input shape and dtype will be automatically '
|
||||
'recorded when the object is called, and used when saving. To manually '
|
||||
'specify the input shape/dtype, decorate the call function with '
|
||||
'`@tf.function(input_signature=...)`.'.format(layer.name, layer))
|
||||
|
||||
|
||||
def _finalize_config_layers(layers):
|
||||
"""Runs the final steps of loading Keras Layers from config."""
|
||||
for layer in layers:
|
||||
|
@ -809,6 +809,36 @@ class TestModelSavingAndLoadingV2(keras_parameterized.TestCase):
|
||||
self.evaluate(variables.variables_initializer(loaded.variables))
|
||||
self.assertAllClose(model.predict(f), loaded.predict(f))
|
||||
|
||||
def test_load_with_partially_failed_serialization(self):
|
||||
|
||||
class BadCustomLayer(keras.layers.Layer):
|
||||
|
||||
def __call__(self, inputs):
|
||||
return inputs
|
||||
|
||||
class Model(keras.models.Model):
|
||||
|
||||
def __init__(self):
|
||||
super(Model, self).__init__()
|
||||
self.layer = BadCustomLayer()
|
||||
|
||||
@def_function.function(
|
||||
input_signature=[tensor_spec.TensorSpec([None, 1])])
|
||||
def call(self, inputs):
|
||||
return self.layer(inputs)
|
||||
|
||||
model = Model()
|
||||
inp = constant_op.constant([[1.0]])
|
||||
model(inp)
|
||||
saved_model_dir = self._save_model_dir()
|
||||
tf_save.save(model, saved_model_dir)
|
||||
|
||||
loaded = keras_load.load(saved_model_dir)
|
||||
self.assertAllEqual([[1.0]], self.evaluate(loaded(inp)))
|
||||
with self.assertRaisesRegexp(ValueError,
|
||||
'call function was not serialized'):
|
||||
loaded.layer(inp)
|
||||
|
||||
|
||||
class TestLayerCallTracing(test.TestCase, parameterized.TestCase):
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user