Add helpful information to base_layer.get_config NotImplementedError.

This makes it possible to identify which layer the problem originates from.

PiperOrigin-RevId: 295196040
Change-Id: I12ea152978a5a625b7c46f9c7713bec3206e3d87
This commit is contained in:
A. Unique TensorFlower 2020-02-14 12:03:12 -08:00 committed by TensorFlower Gardener
parent 62dfa424e4
commit 01a54f6674
2 changed files with 4 additions and 3 deletions

View File

@ -523,8 +523,9 @@ class Layer(module.Module, version_utils.LayerVersionSelector):
# Check that either the only argument in the `__init__` is `self`,
# or that `get_config` has been overridden:
if len(extra_args) > 1 and hasattr(self.get_config, '_is_default'):
raise NotImplementedError('Layers with arguments in `__init__` must '
'override `get_config`.')
raise NotImplementedError('Layer %s has arguments in `__init__` and '
'therefore must override `get_config`.' %
self.__class__.__name__)
return config
@classmethod

View File

@ -535,7 +535,7 @@ class BaseLayerTest(keras_parameterized.TestCase):
# `__init__` includes kwargs but `get_config` is not overridden, so
# an error should be thrown:
with self.assertRaises(NotImplementedError):
with self.assertRaisesRegexp(NotImplementedError, 'Layer MyLayer has'):
MyLayer('custom').get_config()
class MyLayerNew(keras.layers.Layer):