Fixes a misleading error message when creating a functional tf.keras.Model()
.
The Keras functional API uses the keyword arguments `inputs` and `outputs`. Currently however, if someone misspells `outputs` (e.g. as `output`), they get the following error: `TypeError: ('Keyword argument not understood:', 'inputs')`. This is confusing as it suggests that there is a problem with `inputs`, not `output`. This error might ideally be surfaced where it is detected (and then ignored in a `try` statement) in `Functional.__init__()`, however that would require a larger change. Instead this error message is fixed by telling `Model.__init__()` to validate `"inputs'` and `'outputs'` with its `kwargs`. This is less ideal because these arguments are not supposed to be passed to `Model.__init__()`, but as the user thinks that they are calling simply calling `Model.__init__()` it should not cause them any confusion. This fixes Keras [issue 13743](https://github.com/keras-team/keras/issues/13743). PiperOrigin-RevId: 321668253 Change-Id: Ideff3cd9298f573b633a2e6e821fa77b1c862570
This commit is contained in:
parent
3cc65294f8
commit
34f2782a79
@ -107,10 +107,6 @@ class Functional(training_lib.Model):
|
||||
|
||||
@trackable.no_automatic_dependency_tracking
|
||||
def __init__(self, inputs=None, outputs=None, name=None, trainable=True):
|
||||
# generic_utils.validate_kwargs(
|
||||
# kwargs, {'name', 'trainable'},
|
||||
# 'Functional models may only specify `name` and `trainable` keyword '
|
||||
# 'arguments during initialization. Got an unexpected argument:')
|
||||
super(Functional, self).__init__(name=name, trainable=trainable)
|
||||
self._init_graph_network(inputs, outputs)
|
||||
|
||||
|
@ -2321,5 +2321,25 @@ class CacheCorrectnessTest(keras_parameterized.TestCase):
|
||||
# if training is not passed at runtime
|
||||
self.assertAllEqual(network(x), _call(x, None))
|
||||
|
||||
|
||||
class InputsOutputsErrorTest(keras_parameterized.TestCase):
|
||||
|
||||
@testing_utils.enable_v2_dtype_behavior
|
||||
def test_input_error(self):
|
||||
inputs = input_layer_lib.Input((10,))
|
||||
outputs = layers.Dense(10)(inputs)
|
||||
with self.assertRaisesRegex(
|
||||
TypeError, "('Keyword argument not understood:', 'input')"):
|
||||
models.Model(input=inputs, outputs=outputs)
|
||||
|
||||
@testing_utils.enable_v2_dtype_behavior
|
||||
def test_output_error(self):
|
||||
inputs = input_layer_lib.Input((10,))
|
||||
outputs = layers.Dense(10)(inputs)
|
||||
with self.assertRaisesRegex(
|
||||
TypeError, "('Keyword argument not understood:', 'output')"):
|
||||
models.Model(inputs=inputs, output=outputs)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test.main()
|
||||
|
@ -258,8 +258,10 @@ class Model(base_layer.Layer, version_utils.ModelVersionSelector):
|
||||
# The following are implemented as property functions:
|
||||
# self.trainable_weights
|
||||
# self.non_trainable_weights
|
||||
generic_utils.validate_kwargs(kwargs, {'trainable', 'dtype', 'dynamic',
|
||||
'name', 'autocast'})
|
||||
# `inputs` / `outputs` will only appear in kwargs if either are misspelled.
|
||||
generic_utils.validate_kwargs(kwargs, {
|
||||
'trainable', 'dtype', 'dynamic', 'name', 'autocast', 'inputs', 'outputs'
|
||||
})
|
||||
super(Model, self).__init__(**kwargs)
|
||||
# By default, Model is a subclass model, which is not in graph network.
|
||||
self._is_graph_network = False
|
||||
|
Loading…
Reference in New Issue
Block a user