Extend Keras Lambda layers to work with functions of any signature rather than only functions that take one argument.
Any *args and **kwargs passed when calling the lambda layer will be forwarded directly to the underlying lambda. PiperOrigin-RevId: 311789009 Change-Id: Ic072d2252038330cc944d7f565f14806753d7436
This commit is contained in:
parent
2bbf57217f
commit
b3bf8bd856
@ -53,7 +53,7 @@ from tensorflow.python.ops import variable_scope
|
||||
from tensorflow.python.platform import tf_logging
|
||||
from tensorflow.python.training.tracking import base as trackable
|
||||
from tensorflow.python.util import nest
|
||||
from tensorflow.python.util import tf_decorator
|
||||
from tensorflow.python.util import tf_inspect
|
||||
from tensorflow.python.util.tf_export import keras_export
|
||||
|
||||
|
||||
@ -738,8 +738,7 @@ class Lambda(Layer):
|
||||
models. `Lambda` layers are best suited for simple operations or
|
||||
quick experimentation. For more advanced use cases, follow
|
||||
[this guide](https://www.tensorflow.org/guide/keras/custom_layers_and_models)
|
||||
for subclassing `tf.keras.layers.Layer`. (Do not subclass
|
||||
`tf.keras.layers.Lamba`.)
|
||||
for subclassing `tf.keras.layers.Layer`.
|
||||
|
||||
The main reason to subclass `tf.keras.layers.Layer` instead of using a
|
||||
`Lambda` layer is saving and inspecting a Model. `Lambda` layers
|
||||
@ -799,7 +798,8 @@ class Lambda(Layer):
|
||||
computation, but anything more complex should use a subclass Layer instead.
|
||||
|
||||
Arguments:
|
||||
function: The function to evaluate when the layer is called.
|
||||
function: The function to be evaluated. Takes input tensor as first
|
||||
argument.
|
||||
output_shape: Expected output shape from function. This argument can be
|
||||
inferred if not explicitly provided. Can be a tuple or function. If a
|
||||
tuple, it only specifies the first dimension onward;
|
||||
@ -812,8 +812,8 @@ class Lambda(Layer):
|
||||
mask: Either None (indicating no masking) or a callable with the same
|
||||
signature as the `compute_mask` layer method, or a tensor that will be
|
||||
returned as output mask regardless of what the input is.
|
||||
arguments: Optional dictionary of keyword arguments to pass by default to
|
||||
the function when those arguments are not passed to the layer call.
|
||||
arguments: Optional dictionary of keyword arguments to be passed to the
|
||||
function.
|
||||
Input shape: Arbitrary. Use the keyword argument input_shape (tuple of
|
||||
integers, does not include the samples axis) when using this layer as the
|
||||
first layer in a model.
|
||||
@ -823,16 +823,11 @@ class Lambda(Layer):
|
||||
@trackable.no_automatic_dependency_tracking
|
||||
def __init__(self, function, output_shape=None, mask=None, arguments=None,
|
||||
**kwargs):
|
||||
super(Lambda, self).__init__(**kwargs)
|
||||
|
||||
self.arguments = arguments or {}
|
||||
self.function = function
|
||||
|
||||
# Decorate the function to produce this layer's call method
|
||||
def _call_wrapper(*args, **kwargs):
|
||||
return self._call_wrapper(*args, **kwargs)
|
||||
self.call = tf_decorator.make_decorator(function, _call_wrapper)
|
||||
|
||||
super(Lambda, self).__init__(**kwargs)
|
||||
|
||||
if mask is not None:
|
||||
self.supports_masking = True
|
||||
self.mask = mask
|
||||
@ -841,8 +836,9 @@ class Lambda(Layer):
|
||||
# Warning on every invocation will be quite irksome in Eager mode.
|
||||
self._already_warned = False
|
||||
|
||||
self._expects_training_arg = 'training' in self._call_fn_args
|
||||
self._expects_mask_arg = 'mask' in self._call_fn_args
|
||||
function_args = tf_inspect.getfullargspec(function).args
|
||||
self._fn_expects_training_arg = 'training' in function_args
|
||||
self._fn_expects_mask_arg = 'mask' in function_args
|
||||
|
||||
@tf_utils.shape_type_conversion
|
||||
def compute_output_shape(self, input_shape):
|
||||
@ -873,22 +869,23 @@ class Lambda(Layer):
|
||||
output_shapes = tf_utils.convert_shapes(self._output_shape, to_tuples=False)
|
||||
return nest.map_structure(_add_batch, output_shapes)
|
||||
|
||||
def _call_wrapper(self, *args, **kwargs):
|
||||
def call(self, inputs, mask=None, training=None):
|
||||
# We must copy for thread safety, but it only needs to be a shallow copy.
|
||||
call_kwargs = {k: v for k, v in self.arguments.items()}
|
||||
|
||||
# override default kwargs with the args passed to the layer call
|
||||
call_kwargs.update(kwargs)
|
||||
kwargs = {k: v for k, v in self.arguments.items()}
|
||||
if self._fn_expects_mask_arg:
|
||||
kwargs['mask'] = mask
|
||||
if self._fn_expects_training_arg:
|
||||
kwargs['training'] = training
|
||||
|
||||
created_variables = []
|
||||
def _variable_creator(next_creator, **creator_kwargs):
|
||||
var = next_creator(**creator_kwargs)
|
||||
def _variable_creator(next_creator, **kwargs):
|
||||
var = next_creator(**kwargs)
|
||||
created_variables.append(var)
|
||||
return var
|
||||
|
||||
with backprop.GradientTape(watch_accessed_variables=True) as tape,\
|
||||
variable_scope.variable_creator_scope(_variable_creator):
|
||||
result = self.function(*args, **call_kwargs)
|
||||
result = self.function(inputs, **kwargs)
|
||||
self._check_variables(created_variables, tape.watched_variables())
|
||||
return result
|
||||
|
||||
|
@ -139,26 +139,6 @@ class LambdaLayerTest(keras_parameterized.TestCase):
|
||||
out = ld([x1, x2])
|
||||
self.assertAllEqual(out.shape, [3, 2])
|
||||
|
||||
def test_lambda_multiple_args(self):
|
||||
ld = keras.layers.Lambda(lambda x, y: x[0] + y)
|
||||
x1 = np.ones([3, 2], np.float32)
|
||||
x2 = np.ones([3, 5], np.float32)
|
||||
|
||||
expected_result = x1 * 2
|
||||
self.assertAllEqual(ld([x1, x2], x1), expected_result)
|
||||
self.assertAllEqual(ld([x1, x2], y=x1), expected_result)
|
||||
self.assertAllEqual(ld(x=[x1, x2], y=x1), expected_result)
|
||||
|
||||
def test_lambda_constructor_args_and_multiple_args(self):
|
||||
x1 = np.ones([3, 2], np.float32)
|
||||
x2 = np.ones([3, 5], np.float32)
|
||||
ld = keras.layers.Lambda(lambda x, y: x[0] + y, arguments={'y': x1*2})
|
||||
|
||||
self.assertAllEqual(ld([x1, x2]), x1 * 3)
|
||||
self.assertAllEqual(ld([x1, x2], y=x1), x1 * 2)
|
||||
self.assertAllEqual(ld(x=[x1, x2]), x1 * 3)
|
||||
self.assertAllEqual(ld(x=[x1, x2], y=x1), x1 * 2)
|
||||
|
||||
def test_lambda_output_shape(self):
|
||||
l = keras.layers.Lambda(lambda x: x + 1, output_shape=(1, 1))
|
||||
l(keras.backend.variable(np.ones((1, 1))))
|
||||
|
@ -145,7 +145,7 @@ tf_class {
|
||||
}
|
||||
member_method {
|
||||
name: "call"
|
||||
argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=kwargs, defaults=None"
|
||||
argspec: "args=[\'self\', \'inputs\', \'mask\', \'training\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "compute_mask"
|
||||
|
@ -145,7 +145,7 @@ tf_class {
|
||||
}
|
||||
member_method {
|
||||
name: "call"
|
||||
argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=kwargs, defaults=None"
|
||||
argspec: "args=[\'self\', \'inputs\', \'mask\', \'training\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "compute_mask"
|
||||
|
Loading…
Reference in New Issue
Block a user