From dd116af48d19d8b79ddf0d72e44a98709bfe98a1 Mon Sep 17 00:00:00 2001
From: Thomas O'Malley
Date: Sun, 19 Jan 2020 23:33:52 -0800
Subject: [PATCH] Fix issue when a Layer's first argument isn't called "inputs".

PiperOrigin-RevId: 290563724
Change-Id: I55a5da8a4624dfc330c89e9ce5302501137b82cb
---
 tensorflow/python/keras/engine/base_layer.py  | 18 +++++++--
 .../python/keras/engine/base_layer_test.py    | 37 +++++++++++++++++++
 .../python/keras/saving/saving_utils.py       |  6 +--
 3 files changed, 54 insertions(+), 7 deletions(-)

diff --git a/tensorflow/python/keras/engine/base_layer.py b/tensorflow/python/keras/engine/base_layer.py
index 71d3084556a..cab0b04b44f 100644
--- a/tensorflow/python/keras/engine/base_layer.py
+++ b/tensorflow/python/keras/engine/base_layer.py
@@ -626,13 +626,12 @@ class Layer(module.Module):
       # carry over the input mask
       return mask
 
-  def __call__(self, inputs, *args, **kwargs):
+  def __call__(self, *args, **kwargs):
     """Wraps `call`, applying pre- and post-processing steps.
 
     Arguments:
-      inputs: input tensor(s).
-      *args: additional positional arguments to be passed to `self.call`.
-      **kwargs: additional keyword arguments to be passed to `self.call`.
+      *args: Positional arguments to be passed to `self.call`.
+      **kwargs: Keyword arguments to be passed to `self.call`.
 
     Returns:
       Output tensor(s).
@@ -655,6 +654,17 @@ class Layer(module.Module):
     if not hasattr(self, '_thread_local'):
       raise RuntimeError(
           'You must call `super().__init__()` in the layer constructor.')
+
+    # Grab the first positional or keyword argument.
+    if args:
+      inputs = args[0]
+      args = args[1:]
+    elif self._call_fn_args[0] in kwargs:
+      inputs = kwargs.pop(self._call_fn_args[0])
+    else:
+      raise ValueError(
+          'The first argument to `Layer.call` must always be passed.')
+
     call_context = base_layer_utils.call_context()
     input_list = nest.flatten(inputs)
 
diff --git a/tensorflow/python/keras/engine/base_layer_test.py b/tensorflow/python/keras/engine/base_layer_test.py
index 98ec9d86184..31193fe8b57 100644
--- a/tensorflow/python/keras/engine/base_layer_test.py
+++ b/tensorflow/python/keras/engine/base_layer_test.py
@@ -593,6 +593,43 @@ class BaseLayerTest(keras_parameterized.TestCase):
     with self.assertRaisesRegexp(RuntimeError, 'You must call `super()'):
       layer(np.random.random((10, 2)))
 
+  @test_util.run_in_graph_and_eager_modes
+  def test_first_arg_not_called_inputs(self):
+    x, y = array_ops.ones((10, 1)), array_ops.ones((10, 1))
+
+    class ArgLayer(keras.layers.Layer):
+
+      def call(self, x, y):
+        return x + y
+
+    layer = ArgLayer()
+    out = self.evaluate(layer(x=x, y=y))
+    self.assertAllClose(out, 2 * np.ones((10, 1)))
+
+    class KwargLayer(keras.layers.Layer):
+
+      def call(self, x=None, y=None):
+        return x + y
+
+    layer = KwargLayer()
+    out = self.evaluate(layer(x=x, y=y))
+    self.assertAllClose(out, 2 * np.ones((10, 1)))
+
+    with self.assertRaisesRegexp(ValueError, 'must always be passed'):
+      layer(y=y)
+
+    class TFFunctionLayer(keras.layers.Layer):
+
+      @def_function.function
+      def call(self, x, y=None):
+        if y is None:
+          return x
+        return x + y
+
+    layer = TFFunctionLayer()
+    out = self.evaluate(layer(x=x, y=y))
+    self.assertAllClose(out, 2 * np.ones((10, 1)))
+
 
 class SymbolicSupportTest(test.TestCase):
 
diff --git a/tensorflow/python/keras/saving/saving_utils.py b/tensorflow/python/keras/saving/saving_utils.py
index 0949aa10a2b..fe8d26485b9 100644
--- a/tensorflow/python/keras/saving/saving_utils.py
+++ b/tensorflow/python/keras/saving/saving_utils.py
@@ -147,7 +147,7 @@ def trace_model_call(model, input_signature=None):
 
     with base_layer_utils.call_context().enter(
         model, inputs=inputs, build_graph=False, training=False, saving=True):
-      outputs_list = nest.flatten(model(inputs=inputs, training=False))
+      outputs_list = nest.flatten(model(inputs, training=False))
 
     try:
       output_names = model.output_names
@@ -211,8 +211,8 @@ def model_metadata(model, include_optimizer=True, require_config=True):
       metadata['training_config']['optimizer_config'] = optimizer_config
     except AttributeError:
       pass  # If the model has an optimizer, but not all of the attributes
-            # loss, _compile_metrics, etc., then it was not compiled using
-            # model.compile. In this case, do not save the training config.
+      # loss, _compile_metrics, etc., then it was not compiled using
+      # model.compile. In this case, do not save the training config.
 
   return metadata
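
Note on usage: below is a minimal sketch of the behavior this patch enables, assuming a TensorFlow build that includes the change. The layer name `AddLayer` and the tensors are illustrative only (not taken from the patch); the error message quoted at the end is the one added to `Layer.__call__` above.

import numpy as np
import tensorflow as tf


class AddLayer(tf.keras.layers.Layer):
  """Custom layer whose `call` arguments are named `x` and `y`, not `inputs`."""

  def call(self, x, y):
    return x + y


layer = AddLayer()
x = tf.ones((10, 1))
y = tf.ones((10, 1))

# The first argument to `call` can now be passed by keyword even though it is
# not named `inputs`; `Layer.__call__` resolves it via `self._call_fn_args`.
out = layer(x=x, y=y)
assert np.allclose(out.numpy(), 2 * np.ones((10, 1)))

# Omitting the first argument entirely is rejected, e.g. `layer(y=y)` raises:
# ValueError: The first argument to `Layer.call` must always be passed.

The same behavior is exercised by the `ArgLayer`, `KwargLayer`, and `TFFunctionLayer` cases in the new `test_first_arg_not_called_inputs` test above.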