Fixes a misleading error message when creating a functional tf.keras.Model().

The Keras functional API uses the keyword arguments `inputs` and `outputs`. Currently however, if someone misspells `outputs` (e.g. as `output`), they get the following error: `TypeError: ('Keyword argument not understood:', 'inputs')`. This is confusing as it suggests that there is a problem with `inputs`, not `output`.

This error might ideally be surfaced where it is detected (and then ignored in a `try` statement) in `Functional.__init__()`, however that would require a larger change.

Instead this error message is fixed by telling `Model.__init__()` to validate `"inputs'` and `'outputs'` with its `kwargs`. This is less ideal because these arguments are not supposed to be passed to `Model.__init__()`, but as the user thinks that they are calling simply calling `Model.__init__()` it should not cause them any confusion.

This fixes Keras [issue 13743](https://github.com/keras-team/keras/issues/13743).

PiperOrigin-RevId: 321668253
Change-Id: Ideff3cd9298f573b633a2e6e821fa77b1c862570
This commit is contained in:
A. Unique TensorFlower 2020-07-16 16:18:12 -07:00 committed by TensorFlower Gardener
parent 3cc65294f8
commit 34f2782a79
3 changed files with 24 additions and 6 deletions

View File

@ -107,10 +107,6 @@ class Functional(training_lib.Model):
@trackable.no_automatic_dependency_tracking
def __init__(self, inputs=None, outputs=None, name=None, trainable=True):
# generic_utils.validate_kwargs(
# kwargs, {'name', 'trainable'},
# 'Functional models may only specify `name` and `trainable` keyword '
# 'arguments during initialization. Got an unexpected argument:')
super(Functional, self).__init__(name=name, trainable=trainable)
self._init_graph_network(inputs, outputs)

View File

@ -2321,5 +2321,25 @@ class CacheCorrectnessTest(keras_parameterized.TestCase):
# if training is not passed at runtime
self.assertAllEqual(network(x), _call(x, None))
class InputsOutputsErrorTest(keras_parameterized.TestCase):
@testing_utils.enable_v2_dtype_behavior
def test_input_error(self):
inputs = input_layer_lib.Input((10,))
outputs = layers.Dense(10)(inputs)
with self.assertRaisesRegex(
TypeError, "('Keyword argument not understood:', 'input')"):
models.Model(input=inputs, outputs=outputs)
@testing_utils.enable_v2_dtype_behavior
def test_output_error(self):
inputs = input_layer_lib.Input((10,))
outputs = layers.Dense(10)(inputs)
with self.assertRaisesRegex(
TypeError, "('Keyword argument not understood:', 'output')"):
models.Model(inputs=inputs, output=outputs)
if __name__ == '__main__':
test.main()

View File

@ -258,8 +258,10 @@ class Model(base_layer.Layer, version_utils.ModelVersionSelector):
# The following are implemented as property functions:
# self.trainable_weights
# self.non_trainable_weights
generic_utils.validate_kwargs(kwargs, {'trainable', 'dtype', 'dynamic',
'name', 'autocast'})
# `inputs` / `outputs` will only appear in kwargs if either are misspelled.
generic_utils.validate_kwargs(kwargs, {
'trainable', 'dtype', 'dynamic', 'name', 'autocast', 'inputs', 'outputs'
})
super(Model, self).__init__(**kwargs)
# By default, Model is a subclass model, which is not in graph network.
self._is_graph_network = False