From 34f2782a79d339d2f59114210a0250ed6cdf8b7e Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Thu, 16 Jul 2020 16:18:12 -0700 Subject: [PATCH] Fixes a misleading error message when creating a functional `tf.keras.Model()`. The Keras functional API uses the keyword arguments `inputs` and `outputs`. Currently however, if someone misspells `outputs` (e.g. as `output`), they get the following error: `TypeError: ('Keyword argument not understood:', 'inputs')`. This is confusing as it suggests that there is a problem with `inputs`, not `output`. This error might ideally be surfaced where it is detected (and then ignored in a `try` statement) in `Functional.__init__()`, however that would require a larger change. Instead this error message is fixed by telling `Model.__init__()` to validate `"inputs'` and `'outputs'` with its `kwargs`. This is less ideal because these arguments are not supposed to be passed to `Model.__init__()`, but as the user thinks that they are calling simply calling `Model.__init__()` it should not cause them any confusion. This fixes Keras [issue 13743](https://github.com/keras-team/keras/issues/13743). PiperOrigin-RevId: 321668253 Change-Id: Ideff3cd9298f573b633a2e6e821fa77b1c862570 --- tensorflow/python/keras/engine/functional.py | 4 ---- .../python/keras/engine/functional_test.py | 20 +++++++++++++++++++ tensorflow/python/keras/engine/training.py | 6 ++++-- 3 files changed, 24 insertions(+), 6 deletions(-) diff --git a/tensorflow/python/keras/engine/functional.py b/tensorflow/python/keras/engine/functional.py index fd80e7f8bb4..6c725d0d795 100644 --- a/tensorflow/python/keras/engine/functional.py +++ b/tensorflow/python/keras/engine/functional.py @@ -107,10 +107,6 @@ class Functional(training_lib.Model): @trackable.no_automatic_dependency_tracking def __init__(self, inputs=None, outputs=None, name=None, trainable=True): - # generic_utils.validate_kwargs( - # kwargs, {'name', 'trainable'}, - # 'Functional models may only specify `name` and `trainable` keyword ' - # 'arguments during initialization. Got an unexpected argument:') super(Functional, self).__init__(name=name, trainable=trainable) self._init_graph_network(inputs, outputs) diff --git a/tensorflow/python/keras/engine/functional_test.py b/tensorflow/python/keras/engine/functional_test.py index f8a0c4103c5..b104668c9e1 100644 --- a/tensorflow/python/keras/engine/functional_test.py +++ b/tensorflow/python/keras/engine/functional_test.py @@ -2321,5 +2321,25 @@ class CacheCorrectnessTest(keras_parameterized.TestCase): # if training is not passed at runtime self.assertAllEqual(network(x), _call(x, None)) + +class InputsOutputsErrorTest(keras_parameterized.TestCase): + + @testing_utils.enable_v2_dtype_behavior + def test_input_error(self): + inputs = input_layer_lib.Input((10,)) + outputs = layers.Dense(10)(inputs) + with self.assertRaisesRegex( + TypeError, "('Keyword argument not understood:', 'input')"): + models.Model(input=inputs, outputs=outputs) + + @testing_utils.enable_v2_dtype_behavior + def test_output_error(self): + inputs = input_layer_lib.Input((10,)) + outputs = layers.Dense(10)(inputs) + with self.assertRaisesRegex( + TypeError, "('Keyword argument not understood:', 'output')"): + models.Model(inputs=inputs, output=outputs) + + if __name__ == '__main__': test.main() diff --git a/tensorflow/python/keras/engine/training.py b/tensorflow/python/keras/engine/training.py index 60b31e1ee21..ad72251ed9d 100644 --- a/tensorflow/python/keras/engine/training.py +++ b/tensorflow/python/keras/engine/training.py @@ -258,8 +258,10 @@ class Model(base_layer.Layer, version_utils.ModelVersionSelector): # The following are implemented as property functions: # self.trainable_weights # self.non_trainable_weights - generic_utils.validate_kwargs(kwargs, {'trainable', 'dtype', 'dynamic', - 'name', 'autocast'}) + # `inputs` / `outputs` will only appear in kwargs if either are misspelled. + generic_utils.validate_kwargs(kwargs, { + 'trainable', 'dtype', 'dynamic', 'name', 'autocast', 'inputs', 'outputs' + }) super(Model, self).__init__(**kwargs) # By default, Model is a subclass model, which is not in graph network. self._is_graph_network = False