Add helpful information to base_layer.get_config NotImplementedError.
This makes it possible to identify which layer the problem originates from. PiperOrigin-RevId: 295196040 Change-Id: I12ea152978a5a625b7c46f9c7713bec3206e3d87
This commit is contained in:
parent
62dfa424e4
commit
01a54f6674
@ -523,8 +523,9 @@ class Layer(module.Module, version_utils.LayerVersionSelector):
|
||||
# Check that either the only argument in the `__init__` is `self`,
|
||||
# or that `get_config` has been overridden:
|
||||
if len(extra_args) > 1 and hasattr(self.get_config, '_is_default'):
|
||||
raise NotImplementedError('Layers with arguments in `__init__` must '
|
||||
'override `get_config`.')
|
||||
raise NotImplementedError('Layer %s has arguments in `__init__` and '
|
||||
'therefore must override `get_config`.' %
|
||||
self.__class__.__name__)
|
||||
return config
|
||||
|
||||
@classmethod
|
||||
|
@ -535,7 +535,7 @@ class BaseLayerTest(keras_parameterized.TestCase):
|
||||
|
||||
# `__init__` includes kwargs but `get_config` is not overridden, so
|
||||
# an error should be thrown:
|
||||
with self.assertRaises(NotImplementedError):
|
||||
with self.assertRaisesRegexp(NotImplementedError, 'Layer MyLayer has'):
|
||||
MyLayer('custom').get_config()
|
||||
|
||||
class MyLayerNew(keras.layers.Layer):
|
||||
|
Loading…
x
Reference in New Issue
Block a user