Add proper error message when user does not call "super.__init__()" in custom layer.

The current error message is quite obscure and surprising.

PiperOrigin-RevId: 290557761
Change-Id: Ibf91099bceb29b562e100e399aaf4757c74060d6
This commit is contained in:
Francois Chollet 2020-01-19 22:04:38 -08:00 committed by TensorFlower Gardener
parent 6bc6e8df2c
commit 6e85ba8898
2 changed files with 15 additions and 0 deletions

View File

@ -650,7 +650,11 @@ class Layer(module.Module):
Raises:
ValueError: if the layer's `call` method returns None (an invalid value).
RuntimeError: if `super().__init__()` was not called in the constructor.
"""
if not hasattr(self, '_thread_local'):
raise RuntimeError(
'You must call `super().__init__()` in the layer constructor.')
call_context = base_layer_utils.call_context()
input_list = nest.flatten(inputs)

View File

@ -582,6 +582,17 @@ class BaseLayerTest(keras_parameterized.TestCase):
model = keras.Sequential(dense)
self.assertEqual(model.count_params(), 16 * 4 + 16)
def test_super_not_called(self):
class CustomLayerNotCallingSuper(keras.layers.Layer):
def __init__(self):
pass
layer = CustomLayerNotCallingSuper()
with self.assertRaisesRegexp(RuntimeError, 'You must call `super()'):
layer(np.random.random((10, 2)))
class SymbolicSupportTest(test.TestCase):