Improve error message when Keras is unable to find a custom object.
PiperOrigin-RevId: 342918274 Change-Id: I882e96c55d75fa1edf9688215b6738624ad93c63
This commit is contained in:
parent
f096affea0
commit
5186e8295c
@ -293,7 +293,12 @@ def class_and_config_for_serialized_keras_object(
|
||||
class_name = config['class_name']
|
||||
cls = get_registered_object(class_name, custom_objects, module_objects)
|
||||
if cls is None:
|
||||
raise ValueError('Unknown ' + printable_module_name + ': ' + class_name)
|
||||
raise ValueError(
|
||||
'Unknown {}: {}. Please ensure this object is '
|
||||
'passed to the `custom_objects` argument. See '
|
||||
'https://www.tensorflow.org/guide/keras/save_and_serialize'
|
||||
'#registering_the_custom_object for details.'
|
||||
.format(printable_module_name, class_name))
|
||||
|
||||
cls_config = config['config']
|
||||
# Check if `cls_config` is a list. If it is a list, return the class and the
|
||||
@ -375,7 +380,12 @@ def deserialize_keras_object(identifier,
|
||||
obj = module_objects.get(object_name)
|
||||
if obj is None:
|
||||
raise ValueError(
|
||||
'Unknown ' + printable_module_name + ': ' + object_name)
|
||||
'Unknown {}: {}. Please ensure this object is '
|
||||
'passed to the `custom_objects` argument. See '
|
||||
'https://www.tensorflow.org/guide/keras/save_and_serialize'
|
||||
'#registering_the_custom_object for details.'
|
||||
.format(printable_module_name, object_name))
|
||||
|
||||
# Classes passed by name are instantiated with no args, functions are
|
||||
# returned as-is.
|
||||
if tf_inspect.isclass(obj):
|
||||
|
||||
@ -354,6 +354,20 @@ class SerializeKerasObjectTest(test.TestCase):
|
||||
expected_output = new_model.predict(input_data)
|
||||
self.assertAllEqual(output, expected_output)
|
||||
|
||||
def test_deserialize_unknown_object(self):
|
||||
|
||||
class CustomLayer(keras.layers.Layer):
|
||||
pass
|
||||
|
||||
layer = CustomLayer()
|
||||
config = keras.utils.generic_utils.serialize_keras_object(layer)
|
||||
with self.assertRaisesRegexp(ValueError,
|
||||
'passed to the `custom_objects` arg'):
|
||||
keras.utils.generic_utils.deserialize_keras_object(config)
|
||||
restored = keras.utils.generic_utils.deserialize_keras_object(
|
||||
config, custom_objects={'CustomLayer': CustomLayer})
|
||||
self.assertIsInstance(restored, CustomLayer)
|
||||
|
||||
|
||||
class SliceArraysTest(test.TestCase):
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user