Fix issue when a Layer's first argument isn't called "inputs".
PiperOrigin-RevId: 290563724
Change-Id: I55a5da8a4624dfc330c89e9ce5302501137b82cb
parent 6e85ba8898
commit dd116af48d
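For context, a minimal sketch of what this change enables (the layer name below is hypothetical; assumes a TF 2.x build that includes this fix): a layer whose `call` names its first argument something other than `inputs` can now be invoked with that argument passed as a keyword.

    import tensorflow as tf

    class AddLayer(tf.keras.layers.Layer):
      # Hypothetical layer: the first `call` argument is named `x`, not `inputs`.

      def call(self, x, y):
        return x + y

    layer = AddLayer()
    a, b = tf.ones((10, 1)), tf.ones((10, 1))

    # Previously, Layer.__call__(self, inputs, *args, **kwargs) assumed the first
    # argument was bound to a parameter literally named `inputs`, so keyword calls
    # like this one could misbehave. After this change, the first argument is
    # resolved by position or by the name of `call`'s first parameter.
    out = layer(x=a, y=b)

    # Omitting the first argument now raises a clear error:
    #   layer(y=b)  ->  ValueError: The first argument to `Layer.call` must always be passed.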
@@ -626,13 +626,12 @@ class Layer(module.Module):
       # carry over the input mask
       return mask

-  def __call__(self, inputs, *args, **kwargs):
+  def __call__(self, *args, **kwargs):
     """Wraps `call`, applying pre- and post-processing steps.

     Arguments:
-      inputs: input tensor(s).
-      *args: additional positional arguments to be passed to `self.call`.
-      **kwargs: additional keyword arguments to be passed to `self.call`.
+      *args: Positional arguments to be passed to `self.call`.
+      **kwargs: Keyword arguments to be passed to `self.call`.

     Returns:
       Output tensor(s).
@@ -655,6 +654,17 @@ class Layer(module.Module):
     if not hasattr(self, '_thread_local'):
       raise RuntimeError(
           'You must call `super().__init__()` in the layer constructor.')

+    # Grab the first positional or keyword argument.
+    if args:
+      inputs = args[0]
+      args = args[1:]
+    elif self._call_fn_args[0] in kwargs:
+      inputs = kwargs.pop(self._call_fn_args[0])
+    else:
+      raise ValueError(
+          'The first argument to `Layer.call` must always be passed.')
+
     call_context = base_layer_utils.call_context()
     input_list = nest.flatten(inputs)
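A note on `self._call_fn_args` above: judging from its use here, it lists the argument names of the layer's `call` method with `self` excluded, so `_call_fn_args[0]` is the name of the first user-defined argument. A rough standalone sketch of the same resolution logic, using `inspect` (the helper names `first_call_arg_name` and `grab_inputs` are hypothetical, not part of this change):

    import inspect

    def first_call_arg_name(layer):
      # Approximates Layer._call_fn_args[0]: the name of the first parameter
      # of `call`. `layer.call` is a bound method, so `self` is already excluded.
      return next(iter(inspect.signature(layer.call).parameters))

    def grab_inputs(layer, args, kwargs):
      # Mirrors the added block above: prefer the first positional argument,
      # fall back to a keyword matching `call`'s first parameter name, and
      # otherwise fail loudly.
      if args:
        return args[0], args[1:], kwargs
      name = first_call_arg_name(layer)
      if name in kwargs:
        kwargs = dict(kwargs)  # avoid mutating the caller's dict
        return kwargs.pop(name), args, kwargs
      raise ValueError('The first argument to `Layer.call` must always be passed.')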
@@ -593,6 +593,43 @@ class BaseLayerTest(keras_parameterized.TestCase):
     with self.assertRaisesRegexp(RuntimeError, 'You must call `super()'):
       layer(np.random.random((10, 2)))

+  @test_util.run_in_graph_and_eager_modes
+  def test_first_arg_not_called_inputs(self):
+    x, y = array_ops.ones((10, 1)), array_ops.ones((10, 1))
+
+    class ArgLayer(keras.layers.Layer):
+
+      def call(self, x, y):
+        return x + y
+
+    layer = ArgLayer()
+    out = self.evaluate(layer(x=x, y=y))
+    self.assertAllClose(out, 2 * np.ones((10, 1)))
+
+    class KwargLayer(keras.layers.Layer):
+
+      def call(self, x=None, y=None):
+        return x + y
+
+    layer = KwargLayer()
+    out = self.evaluate(layer(x=x, y=y))
+    self.assertAllClose(out, 2 * np.ones((10, 1)))
+
+    with self.assertRaisesRegexp(ValueError, 'must always be passed'):
+      layer(y=y)
+
+    class TFFunctionLayer(keras.layers.Layer):
+
+      @def_function.function
+      def call(self, x, y=None):
+        if y is None:
+          return x
+        return x + y
+
+    layer = TFFunctionLayer()
+    out = self.evaluate(layer(x=x, y=y))
+    self.assertAllClose(out, 2 * np.ones((10, 1)))
+
+
 class SymbolicSupportTest(test.TestCase):
@@ -147,7 +147,7 @@ def trace_model_call(model, input_signature=None):
     with base_layer_utils.call_context().enter(
         model, inputs=inputs, build_graph=False, training=False, saving=True):
-      outputs_list = nest.flatten(model(inputs=inputs, training=False))
+      outputs_list = nest.flatten(model(inputs, training=False))

   try:
     output_names = model.output_names
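This call-site change follows from the signature change above: with `inputs` removed from `Layer.__call__`, the keyword form `model(inputs=inputs, ...)` would only bind correctly when the model's `call` actually names its first parameter `inputs`, whereas passing it positionally works for any parameter name.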
@@ -211,8 +211,8 @@ def model_metadata(model, include_optimizer=True, require_config=True):
       metadata['training_config']['optimizer_config'] = optimizer_config
   except AttributeError:
     pass  # If the model has an optimizer, but not all of the attributes
           # loss, _compile_metrics, etc., then it was not compiled using
           # model.compile. In this case, do not save the training config.
   return metadata