diff --git a/tensorflow/python/keras/saving/save_test.py b/tensorflow/python/keras/saving/save_test.py index d0a3daab00f..5885e240e92 100644 --- a/tensorflow/python/keras/saving/save_test.py +++ b/tensorflow/python/keras/saving/save_test.py @@ -23,6 +23,7 @@ import os import shutil import sys import tempfile +import warnings from absl.testing import parameterized import numpy as np @@ -1036,6 +1037,54 @@ class TestWholeModelSaving(keras_parameterized.TestCase): # model. self.assertSequenceEqual(model.metrics_names, loaded.metrics_names) + @combinations.generate(combinations.combine(mode=['graph', 'eager'])) + def test_warning_when_saving_invalid_custom_mask_layer(self): + + class MyMasking(keras.layers.Layer): + + def call(self, inputs): + return inputs + + def compute_mask(self, inputs, mask=None): + mask = math_ops.not_equal(inputs, 0) + return mask + + class MyLayer(keras.layers.Layer): + + def call(self, inputs, mask=None): + return array_ops.identity(inputs) + + samples = np.random.random((2, 2)) + model = keras.Sequential([MyMasking(), MyLayer()]) + model.predict(samples) + with warnings.catch_warnings(record=True) as w: + model.save(self._save_model_dir(), testing_utils.get_save_format()) + self.assertIn(generic_utils.CustomMaskWarning, + {warning.category for warning in w}) + + # Test that setting up a custom mask correctly does not issue a warning. + class MyCorrectMasking(keras.layers.Layer): + + def call(self, inputs): + return inputs + + def compute_mask(self, inputs, mask=None): + mask = math_ops.not_equal(inputs, 0) + return mask + + # This get_config doesn't actually do anything because our mask is + # static and doesn't need any external information to work. We do need a + # dummy get_config method to prevent the warning from appearing, however. + def get_config(self, *args, **kwargs): + return {} + + model = keras.Sequential([MyCorrectMasking(), MyLayer()]) + model.predict(samples) + with warnings.catch_warnings(record=True) as w: + model.save(self._save_model_dir(), testing_utils.get_save_format()) + self.assertNotIn(generic_utils.CustomMaskWarning, + {warning.category for warning in w}) + # Factory functions to create models that will be serialized inside a Network. def _make_graph_network(input_size, output_size): diff --git a/tensorflow/python/keras/utils/generic_utils.py b/tensorflow/python/keras/utils/generic_utils.py index 89aeaf4ab28..b3371555652 100644 --- a/tensorflow/python/keras/utils/generic_utils.py +++ b/tensorflow/python/keras/utils/generic_utils.py @@ -27,6 +27,7 @@ import sys import threading import time import types as python_types +import warnings import weakref import numpy as np @@ -460,6 +461,12 @@ def get_registered_object(name, custom_objects=None, module_objects=None): return None +# pylint: disable=g-bad-exception-name +class CustomMaskWarning(Warning): + pass +# pylint: enable=g-bad-exception-name + + @keras_export('keras.utils.serialize_keras_object') def serialize_keras_object(instance): """Serialize a Keras object into a JSON-compatible representation. @@ -479,6 +486,20 @@ def serialize_keras_object(instance): if instance is None: return None + # pylint: disable=protected-access + # + # For v1 layers, checking supports_masking is not enough. We have to also + # check whether compute_mask has been overridden. + supports_masking = (getattr(instance, 'supports_masking', False) + or (hasattr(instance, 'compute_mask') + and not is_default(instance.compute_mask))) + if supports_masking and is_default(instance.get_config): + warnings.warn('Custom mask layers require a config and must override ' + 'get_config. When loading, the custom mask layer must be ' + 'passed to the custom_objects argument.', + category=CustomMaskWarning) + # pylint: enable=protected-access + if hasattr(instance, 'get_config'): name = get_registered_name(instance.__class__) try: