Add proper error message when user does not call "super.__init__()" in custom layer.
The current error message is quite obscure and surprising. PiperOrigin-RevId: 290557761 Change-Id: Ibf91099bceb29b562e100e399aaf4757c74060d6
This commit is contained in:
parent
6bc6e8df2c
commit
6e85ba8898
@ -650,7 +650,11 @@ class Layer(module.Module):
|
||||
|
||||
Raises:
|
||||
ValueError: if the layer's `call` method returns None (an invalid value).
|
||||
RuntimeError: if `super().__init__()` was not called in the constructor.
|
||||
"""
|
||||
if not hasattr(self, '_thread_local'):
|
||||
raise RuntimeError(
|
||||
'You must call `super().__init__()` in the layer constructor.')
|
||||
call_context = base_layer_utils.call_context()
|
||||
input_list = nest.flatten(inputs)
|
||||
|
||||
|
@ -582,6 +582,17 @@ class BaseLayerTest(keras_parameterized.TestCase):
|
||||
model = keras.Sequential(dense)
|
||||
self.assertEqual(model.count_params(), 16 * 4 + 16)
|
||||
|
||||
def test_super_not_called(self):
|
||||
|
||||
class CustomLayerNotCallingSuper(keras.layers.Layer):
|
||||
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
layer = CustomLayerNotCallingSuper()
|
||||
with self.assertRaisesRegexp(RuntimeError, 'You must call `super()'):
|
||||
layer(np.random.random((10, 2)))
|
||||
|
||||
|
||||
class SymbolicSupportTest(test.TestCase):
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user