Print a warning when trying to serialize an invalid custom mask.
PiperOrigin-RevId: 359559873 Change-Id: Ief77cdff3c3befc880e506bc25ce3fc0e527b5ac
This commit is contained in:
parent
63e5ccf8f7
commit
076b5be77d
@ -23,6 +23,7 @@ import os
|
||||
import shutil
|
||||
import sys
|
||||
import tempfile
|
||||
import warnings
|
||||
|
||||
from absl.testing import parameterized
|
||||
import numpy as np
|
||||
@ -1036,6 +1037,54 @@ class TestWholeModelSaving(keras_parameterized.TestCase):
|
||||
# model.
|
||||
self.assertSequenceEqual(model.metrics_names, loaded.metrics_names)
|
||||
|
||||
@combinations.generate(combinations.combine(mode=['graph', 'eager']))
|
||||
def test_warning_when_saving_invalid_custom_mask_layer(self):
|
||||
|
||||
class MyMasking(keras.layers.Layer):
|
||||
|
||||
def call(self, inputs):
|
||||
return inputs
|
||||
|
||||
def compute_mask(self, inputs, mask=None):
|
||||
mask = math_ops.not_equal(inputs, 0)
|
||||
return mask
|
||||
|
||||
class MyLayer(keras.layers.Layer):
|
||||
|
||||
def call(self, inputs, mask=None):
|
||||
return array_ops.identity(inputs)
|
||||
|
||||
samples = np.random.random((2, 2))
|
||||
model = keras.Sequential([MyMasking(), MyLayer()])
|
||||
model.predict(samples)
|
||||
with warnings.catch_warnings(record=True) as w:
|
||||
model.save(self._save_model_dir(), testing_utils.get_save_format())
|
||||
self.assertIn(generic_utils.CustomMaskWarning,
|
||||
{warning.category for warning in w})
|
||||
|
||||
# Test that setting up a custom mask correctly does not issue a warning.
|
||||
class MyCorrectMasking(keras.layers.Layer):
|
||||
|
||||
def call(self, inputs):
|
||||
return inputs
|
||||
|
||||
def compute_mask(self, inputs, mask=None):
|
||||
mask = math_ops.not_equal(inputs, 0)
|
||||
return mask
|
||||
|
||||
# This get_config doesn't actually do anything because our mask is
|
||||
# static and doesn't need any external information to work. We do need a
|
||||
# dummy get_config method to prevent the warning from appearing, however.
|
||||
def get_config(self, *args, **kwargs):
|
||||
return {}
|
||||
|
||||
model = keras.Sequential([MyCorrectMasking(), MyLayer()])
|
||||
model.predict(samples)
|
||||
with warnings.catch_warnings(record=True) as w:
|
||||
model.save(self._save_model_dir(), testing_utils.get_save_format())
|
||||
self.assertNotIn(generic_utils.CustomMaskWarning,
|
||||
{warning.category for warning in w})
|
||||
|
||||
|
||||
# Factory functions to create models that will be serialized inside a Network.
|
||||
def _make_graph_network(input_size, output_size):
|
||||
|
@ -27,6 +27,7 @@ import sys
|
||||
import threading
|
||||
import time
|
||||
import types as python_types
|
||||
import warnings
|
||||
import weakref
|
||||
|
||||
import numpy as np
|
||||
@ -460,6 +461,12 @@ def get_registered_object(name, custom_objects=None, module_objects=None):
|
||||
return None
|
||||
|
||||
|
||||
# pylint: disable=g-bad-exception-name
|
||||
class CustomMaskWarning(Warning):
|
||||
pass
|
||||
# pylint: enable=g-bad-exception-name
|
||||
|
||||
|
||||
@keras_export('keras.utils.serialize_keras_object')
|
||||
def serialize_keras_object(instance):
|
||||
"""Serialize a Keras object into a JSON-compatible representation.
|
||||
@ -479,6 +486,20 @@ def serialize_keras_object(instance):
|
||||
if instance is None:
|
||||
return None
|
||||
|
||||
# pylint: disable=protected-access
|
||||
#
|
||||
# For v1 layers, checking supports_masking is not enough. We have to also
|
||||
# check whether compute_mask has been overridden.
|
||||
supports_masking = (getattr(instance, 'supports_masking', False)
|
||||
or (hasattr(instance, 'compute_mask')
|
||||
and not is_default(instance.compute_mask)))
|
||||
if supports_masking and is_default(instance.get_config):
|
||||
warnings.warn('Custom mask layers require a config and must override '
|
||||
'get_config. When loading, the custom mask layer must be '
|
||||
'passed to the custom_objects argument.',
|
||||
category=CustomMaskWarning)
|
||||
# pylint: enable=protected-access
|
||||
|
||||
if hasattr(instance, 'get_config'):
|
||||
name = get_registered_name(instance.__class__)
|
||||
try:
|
||||
|
Loading…
Reference in New Issue
Block a user