Fix issue when a Layer's first argument isn't called "inputs".
PiperOrigin-RevId: 290563724 Change-Id: I55a5da8a4624dfc330c89e9ce5302501137b82cb
This commit is contained in:
parent
6e85ba8898
commit
dd116af48d
@ -626,13 +626,12 @@ class Layer(module.Module):
|
||||
# carry over the input mask
|
||||
return mask
|
||||
|
||||
def __call__(self, inputs, *args, **kwargs):
|
||||
def __call__(self, *args, **kwargs):
|
||||
"""Wraps `call`, applying pre- and post-processing steps.
|
||||
|
||||
Arguments:
|
||||
inputs: input tensor(s).
|
||||
*args: additional positional arguments to be passed to `self.call`.
|
||||
**kwargs: additional keyword arguments to be passed to `self.call`.
|
||||
*args: Positional arguments to be passed to `self.call`.
|
||||
**kwargs: Keyword arguments to be passed to `self.call`.
|
||||
|
||||
Returns:
|
||||
Output tensor(s).
|
||||
@ -655,6 +654,17 @@ class Layer(module.Module):
|
||||
if not hasattr(self, '_thread_local'):
|
||||
raise RuntimeError(
|
||||
'You must call `super().__init__()` in the layer constructor.')
|
||||
|
||||
# Grab the first positional or keyword argument.
|
||||
if args:
|
||||
inputs = args[0]
|
||||
args = args[1:]
|
||||
elif self._call_fn_args[0] in kwargs:
|
||||
inputs = kwargs.pop(self._call_fn_args[0])
|
||||
else:
|
||||
raise ValueError(
|
||||
'The first argument to `Layer.call` must always be passed.')
|
||||
|
||||
call_context = base_layer_utils.call_context()
|
||||
input_list = nest.flatten(inputs)
|
||||
|
||||
|
@ -593,6 +593,43 @@ class BaseLayerTest(keras_parameterized.TestCase):
|
||||
with self.assertRaisesRegexp(RuntimeError, 'You must call `super()'):
|
||||
layer(np.random.random((10, 2)))
|
||||
|
||||
@test_util.run_in_graph_and_eager_modes
|
||||
def test_first_arg_not_called_inputs(self):
|
||||
x, y = array_ops.ones((10, 1)), array_ops.ones((10, 1))
|
||||
|
||||
class ArgLayer(keras.layers.Layer):
|
||||
|
||||
def call(self, x, y):
|
||||
return x + y
|
||||
|
||||
layer = ArgLayer()
|
||||
out = self.evaluate(layer(x=x, y=y))
|
||||
self.assertAllClose(out, 2 * np.ones((10, 1)))
|
||||
|
||||
class KwargLayer(keras.layers.Layer):
|
||||
|
||||
def call(self, x=None, y=None):
|
||||
return x + y
|
||||
|
||||
layer = KwargLayer()
|
||||
out = self.evaluate(layer(x=x, y=y))
|
||||
self.assertAllClose(out, 2 * np.ones((10, 1)))
|
||||
|
||||
with self.assertRaisesRegexp(ValueError, 'must always be passed'):
|
||||
layer(y=y)
|
||||
|
||||
class TFFunctionLayer(keras.layers.Layer):
|
||||
|
||||
@def_function.function
|
||||
def call(self, x, y=None):
|
||||
if y is None:
|
||||
return x
|
||||
return x + y
|
||||
|
||||
layer = TFFunctionLayer()
|
||||
out = self.evaluate(layer(x=x, y=y))
|
||||
self.assertAllClose(out, 2 * np.ones((10, 1)))
|
||||
|
||||
|
||||
class SymbolicSupportTest(test.TestCase):
|
||||
|
||||
|
@ -147,7 +147,7 @@ def trace_model_call(model, input_signature=None):
|
||||
|
||||
with base_layer_utils.call_context().enter(
|
||||
model, inputs=inputs, build_graph=False, training=False, saving=True):
|
||||
outputs_list = nest.flatten(model(inputs=inputs, training=False))
|
||||
outputs_list = nest.flatten(model(inputs, training=False))
|
||||
|
||||
try:
|
||||
output_names = model.output_names
|
||||
@ -211,8 +211,8 @@ def model_metadata(model, include_optimizer=True, require_config=True):
|
||||
metadata['training_config']['optimizer_config'] = optimizer_config
|
||||
except AttributeError:
|
||||
pass # If the model has an optimizer, but not all of the attributes
|
||||
# loss, _compile_metrics, etc., then it was not compiled using
|
||||
# model.compile. In this case, do not save the training config.
|
||||
# loss, _compile_metrics, etc., then it was not compiled using
|
||||
# model.compile. In this case, do not save the training config.
|
||||
return metadata
|
||||
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user