Fix issue when a Layer's first argument isn't called "inputs".

PiperOrigin-RevId: 290563724
Change-Id: I55a5da8a4624dfc330c89e9ce5302501137b82cb
This commit is contained in:
Thomas O'Malley 2020-01-19 23:33:52 -08:00 committed by TensorFlower Gardener
parent 6e85ba8898
commit dd116af48d
3 changed files with 54 additions and 7 deletions

View File

@@ -626,13 +626,12 @@ class Layer(module.Module):
# carry over the input mask
return mask
def __call__(self, inputs, *args, **kwargs):
def __call__(self, *args, **kwargs):
"""Wraps `call`, applying pre- and post-processing steps.
Arguments:
inputs: input tensor(s).
*args: additional positional arguments to be passed to `self.call`.
**kwargs: additional keyword arguments to be passed to `self.call`.
*args: Positional arguments to be passed to `self.call`.
**kwargs: Keyword arguments to be passed to `self.call`.
Returns:
Output tensor(s).
@@ -655,6 +654,17 @@
if not hasattr(self, '_thread_local'):
raise RuntimeError(
'You must call `super().__init__()` in the layer constructor.')
# Grab the first positional or keyword argument.
if args:
inputs = args[0]
args = args[1:]
elif self._call_fn_args[0] in kwargs:
inputs = kwargs.pop(self._call_fn_args[0])
else:
raise ValueError(
'The first argument to `Layer.call` must always be passed.')
call_context = base_layer_utils.call_context()
input_list = nest.flatten(inputs)

View File

@@ -593,6 +593,43 @@ class BaseLayerTest(keras_parameterized.TestCase):
with self.assertRaisesRegexp(RuntimeError, 'You must call `super()'):
layer(np.random.random((10, 2)))
@test_util.run_in_graph_and_eager_modes
def test_first_arg_not_called_inputs(self):
  """`Layer.__call__` works when `call`'s first argument isn't `inputs`."""
  a, b = array_ops.ones((10, 1)), array_ops.ones((10, 1))
  expected = 2 * np.ones((10, 1))

  # Plain positional signature: call(self, x, y).
  class ArgLayer(keras.layers.Layer):

    def call(self, x, y):
      return x + y

  self.assertAllClose(self.evaluate(ArgLayer()(x=a, y=b)), expected)

  # All-keyword signature with defaults: call(self, x=None, y=None).
  class KwargLayer(keras.layers.Layer):

    def call(self, x=None, y=None):
      return x + y

  kwarg_layer = KwargLayer()
  self.assertAllClose(self.evaluate(kwarg_layer(x=a, y=b)), expected)

  # Omitting the first argument entirely must raise a ValueError.
  with self.assertRaisesRegexp(ValueError, 'must always be passed'):
    kwarg_layer(y=b)

  # `call` wrapped in a tf.function still dispatches correctly.
  class TFFunctionLayer(keras.layers.Layer):

    @def_function.function
    def call(self, x, y=None):
      if y is None:
        return x
      return x + y

  self.assertAllClose(self.evaluate(TFFunctionLayer()(x=a, y=b)), expected)
class SymbolicSupportTest(test.TestCase):

View File

@@ -147,7 +147,7 @@ def trace_model_call(model, input_signature=None):
with base_layer_utils.call_context().enter(
model, inputs=inputs, build_graph=False, training=False, saving=True):
outputs_list = nest.flatten(model(inputs=inputs, training=False))
outputs_list = nest.flatten(model(inputs, training=False))
try:
output_names = model.output_names
@@ -211,8 +211,8 @@ def model_metadata(model, include_optimizer=True, require_config=True):
metadata['training_config']['optimizer_config'] = optimizer_config
except AttributeError:
pass # If the model has an optimizer, but not all of the attributes
# loss, _compile_metrics, etc., then it was not compiled using
# model.compile. In this case, do not save the training config.
# loss, _compile_metrics, etc., then it was not compiled using
# model.compile. In this case, do not save the training config.
return metadata